/**
 * vision/ocr.h - OCR pipeline for scorecard reading
 * 
 * Ported from desktop test_detector.cpp
 * 
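 * Pipeline:
 * 1. Crop margins (top 20%, bottom 21%, left 19%)
 * 2. Adaptive binary threshold (local mean minus offset)
 * 3. L-mark corner detection with convolution kernels
 * 4. Warp the card region (bilinear interpolation between the four corners)
 * 5. Connected component analysis
 * 6. Rule-based digit recognition, read row by row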
 */

#ifndef VISION_OCR_H
#define VISION_OCR_H

#include "image.h"
#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>

// ============================================================================
// Constants
// ============================================================================

#define OCR_KERNEL_SIZE  21   // side length (pixels) of the square L-mark kernels
#define OCR_ARM_LENGTH   8    // arm length used when building the L-mark kernels
#define OCR_MAX_ROWS     5    // capacity of OCRResult::numbers
#define OCR_WARP_WIDTH   300  // width of the warped card image
#define OCR_WARP_HEIGHT  420  // height of the warped card image

// ============================================================================
// Structures
// ============================================================================

/** 2-D point with float coordinates; default-constructs to the origin. */
struct Point2f {
    float x;
    float y;

    Point2f() : x(0.0f), y(0.0f) {}
    Point2f(float px, float py) : x(px), y(py) {}
};

struct OCRQuad {
    Point2f tl, tr, bl, br;
    bool valid;
    OCRQuad() : valid(false) {}
};

/** Outcome of an L-mark search (see ocrFindLMark). All fields start at zero/false. */
struct LMarkResult {
    int x;        // best-scoring kernel-centre column
    int y;        // best-scoring kernel-centre row
    float score;  // match score at (x, y)
    bool found;   // score exceeded the acceptance threshold

    LMarkResult() : x(), y(), score(), found() {}
};

struct OCRResult {
    int numbers[OCR_MAX_ROWS];
    int count;
    bool success;
    OCRResult() : count(0), success(false) {
        for (int i = 0; i < OCR_MAX_ROWS; i++) numbers[i] = 0;
    }
};

// ============================================================================
// Image Processing
// ============================================================================

/**
 * Crop image by percentages.
 *
 * Removes the given fraction of the image size from each edge. For example,
 * topPct = 0.2 drops the top 20% of rows. Fractions are truncated to whole pixels.
 *
 * @param src        source image
 * @param dst        receives the cropped copy (allocated via dst.alloc)
 * @param topPct     fraction of src.height removed from the top, in [0, 1]
 * @param bottomPct  fraction of src.height removed from the bottom, in [0, 1]
 * @param leftPct    fraction of src.width removed from the left, in [0, 1]
 * @param rightPct   fraction of src.width removed from the right, in [0, 1]
 * @return false if any fraction is outside [0, 1] (or NaN), if nothing would remain,
 *         or if dst cannot be allocated.
 */
inline bool ocrCropImage(const GrayImage& src, GrayImage& dst,
                         float topPct, float bottomPct, float leftPct, float rightPct) {
    // A negative fraction would start the copy window outside the image (out-of-bounds
    // read), and a huge one overflows the int conversion. NaN compares false and is
    // rejected as well.
    auto inRange = [](float f) { return f >= 0.0f && f <= 1.0f; };
    if (!(inRange(topPct) && inRange(bottomPct) && inRange(leftPct) && inRange(rightPct)))
        return false;

    const int cropTop = static_cast<int>(src.height * topPct);
    const int cropBottom = static_cast<int>(src.height * bottomPct);
    const int cropLeft = static_cast<int>(src.width * leftPct);
    const int cropRight = static_cast<int>(src.width * rightPct);

    const int newW = src.width - cropLeft - cropRight;
    const int newH = src.height - cropTop - cropBottom;

    if (newW <= 0 || newH <= 0) return false;
    if (!dst.alloc(newW, newH)) return false;

    // Each destination row is one contiguous run of its source row
    // (byte count == newW, i.e. assumes one byte per pixel).
    for (int y = 0; y < newH; y++) {
        std::memcpy(dst.data + y * newW,
                    src.data + (y + cropTop) * src.width + cropLeft,
                    newW);
    }
    return true;
}

/**
 * Adaptive binary threshold (mean of a local window minus an offset).
 *
 * For each pixel, the mean of the blockSize x blockSize window centred on it
 * (clipped to the image) is computed with integer division. The output pixel is
 * 0 if the source pixel is below (mean - C) and 255 otherwise.
 *
 * Runs in O(width * height) regardless of blockSize. It keeps per-column sums of
 * the rows inside the current vertical window and takes a prefix sum over them
 * for each row. Results are identical to summing every window directly.
 *
 * @param src        grayscale input
 * @param dst        receives the binary image (same size as src)
 * @param blockSize  window side length; values below 1 behave like 1
 * @param C          offset subtracted from the local mean
 * @return false if dst or the temporary row buffer cannot be allocated.
 */
inline bool ocrAdaptiveThreshold(const GrayImage& src, GrayImage& dst, 
                                  int blockSize = 35, int C = 10) {
    if (!dst.alloc(src.width, src.height)) return false;

    const int w = src.width;
    const int h = src.height;
    // Clamp so a negative blockSize cannot produce an empty window (division by zero).
    const int halfBlock = std::max(0, blockSize / 2);

    // colSum[x]: sum of column x over rows [winTop, winBot).
    // prefix[x]: sum of colSum[0..x).
    std::unique_ptr<long long[]> buf(new (std::nothrow) long long[2 * w + 1]);
    if (!buf) return false;
    long long* colSum = buf.get();
    long long* prefix = colSum + w;
    for (int x = 0; x < w; x++) colSum[x] = 0;

    int winTop = 0, winBot = 0;
    for (int y = 0; y < h; y++) {
        const int y0 = std::max(0, y - halfBlock);
        const int y1 = std::min(h, y + halfBlock + 1);

        // The window only slides downward, so each row is added once and removed once.
        for (; winBot < y1; winBot++) {
            for (int x = 0; x < w; x++) colSum[x] += src.at(x, winBot);
        }
        for (; winTop < y0; winTop++) {
            for (int x = 0; x < w; x++) colSum[x] -= src.at(x, winTop);
        }

        prefix[0] = 0;
        for (int x = 0; x < w; x++) prefix[x + 1] = prefix[x] + colSum[x];

        const int rows = y1 - y0;
        for (int x = 0; x < w; x++) {
            const int x0 = std::max(0, x - halfBlock);
            const int x1 = std::min(w, x + halfBlock + 1);
            const long long sum = prefix[x1] - prefix[x0];
            const int thresh = static_cast<int>(sum / (rows * (x1 - x0))) - C;
            dst.set(x, y, src.at(x, y) < thresh ? 0 : 255);
        }
    }
    return true;
}

// ============================================================================
// L-Mark Kernel Detection
// ============================================================================

enum OCRLMarkType {
    OCR_LMARK_TL = 0,  // top-left corner mark
    OCR_LMARK_TR = 1,  // top-right corner mark
    OCR_LMARK_BL = 2,  // bottom-left corner mark
    OCR_LMARK_BR = 3   // bottom-right corner mark
};

/**
 * Generate L-mark convolution kernel.
 *
 * Fills the K x K kernel (K = OCR_KERNEL_SIZE) with weights interpreted by
 * ocrApplyLKernel: positive = expect dark pixel, negative = expect light pixel,
 * 0 = ignored.
 *
 * Layout:
 * - The L's outer corner sits at cell (1,1) for TL, (K-2,1) for TR, (1,K-2) for BL
 *   and (K-2,K-2) for BR.
 * - Offsets (i, j) below step away from that corner along hDir / vDir.
 * - Final weights with arm = OCR_ARM_LENGTH (8) and thick = 3:
 *   - +1    where i < 3 or j < 3 (within 0..10): the two dark bars of the L;
 *   - +0.5  where 3 <= i, j < 10: the region between the bars;
 *   - -1    only the parts of the light bands not overwritten later:
 *           j = 10, 3 <= i < 8 and i = 10, 3 <= j < 8;
 *   - 0     everywhere else.
 *
 * Because the kernel centre (K/2, K/2) is what ocrApplyLKernel aligns with the
 * search position, a match position lies K/2 - 1 pixels from the L's outer corner
 * along each axis, toward the arms.
 *
 * An unknown `type` leaves the kernel all zero, so ocrApplyLKernel scores 0 everywhere.
 */
inline void ocrGenerateLKernel(float kernel[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE], OCRLMarkType type) {
    const int K = OCR_KERNEL_SIZE;
    const int arm = OCR_ARM_LENGTH;
    const int thick = 3;

    // Clear kernel
    for (int y = 0; y < K; y++) {
        for (int x = 0; x < K; x++) {
            kernel[y][x] = 0;
        }
    }

    // Corner cell and the directions (1 or -1) in which the arms extend
    int cx, cy, hDir, vDir;
    switch (type) {
        case OCR_LMARK_TL: cx = 1;     cy = 1;     hDir = 1;  vDir = 1;  break;
        case OCR_LMARK_TR: cx = K - 2; cy = 1;     hDir = -1; vDir = 1;  break;
        case OCR_LMARK_BL: cx = 1;     cy = K - 2; hDir = 1;  vDir = -1; break;
        case OCR_LMARK_BR: cx = K - 2; cy = K - 2; hDir = -1; vDir = -1; break;
        default:
            // Previously fell through with uninitialised corner values (UB).
            return;
    }

    // Write weight `w` at offset (i, j) from the corner; i steps along hDir, j along vDir.
    auto put = [&](int i, int j, float w) {
        const int x = cx + hDir * i;
        const int y = cy + vDir * j;
        if (x >= 0 && x < K && y >= 0 && y < K) kernel[y][x] = w;
    };

    // Light band: j in [arm, arm + thick), i in [0, arm).
    // Partly overwritten by the two writes below.
    for (int i = 0; i < arm; i++) {
        for (int t = 0; t < thick; t++) put(i, arm + t, -1);
    }

    // Light band: i in [arm, arm + thick), j in [0, arm) (mirror of the above)
    for (int j = 0; j < arm; j++) {
        for (int t = 0; t < thick; t++) put(arm + t, j, -1);
    }

    // Dark L: two bars `thick` wide and `arm + thick` long meeting at the corner
    for (int i = 0; i < arm + thick; i++) {
        for (int j = 0; j < arm + thick; j++) {
            if (i < thick || j < thick) put(i, j, 1);
        }
    }

    // Region between the bars.
    // NOTE(review): the original comment says this area "should be white/light",
    // but ocrApplyLKernel scores positive weights as "expect dark". Left unchanged
    // because the 0.6 detection threshold may be tuned against it; confirm before
    // flipping the sign to -0.5f.
    for (int i = thick; i < arm + thick - 1; i++) {
        for (int j = thick; j < arm + thick - 1; j++) put(i, j, 0.5f);
    }
}

/**
 * Apply L-mark kernel at position
 */
inline float ocrApplyLKernel(const GrayImage& binary, int px, int py,
                              const float kernel[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE]) {
    const int K = OCR_KERNEL_SIZE;
    const int half = K / 2;
    float score = 0;
    float maxScore = 0;
    
    for (int ky = 0; ky < K; ky++) {
        for (int kx = 0; kx < K; kx++) {
            float w = kernel[ky][kx];
            if (w == 0) continue;
            
            int x = px + kx - half;
            int y = py + ky - half;
            float pixel = binary.get(x, y) / 255.0f;  // 0=black, 1=white
            
            if (w > 0) {
                // Expect black (low pixel value)
                score += w * (1.0f - pixel);
            } else {
                // Expect white (high pixel value)
                score += (-w) * pixel;
            }
            maxScore += fabsf(w);
        }
    }
    
    return maxScore > 0 ? score / maxScore : 0;
}

/**
 * Search the rectangle [x0, x1) x [y0, y1) of `binary` for the best match of `kernel`.
 *
 * Search passes:
 * - Coarse pass: every 3rd position in x and y.
 * - Fine pass (only if the coarse best score is positive): every position within
 *   +-3 of the coarse winner. These positions may lie outside the rectangle.
 *
 * A position replaces the current best only when its score is strictly higher.
 *
 * @return best kernel-centre position and its score; `found` is true when score > 0.6.
 */
inline LMarkResult ocrFindLMark(const GrayImage& binary,
                                 const float kernel[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE],
                                 int x0, int y0, int x1, int y1) {
    const int stride = 3;
    const float acceptScore = 0.6f;
    LMarkResult match;

    // Evaluate one position and keep it if it beats the current best.
    auto consider = [&](int px, int py) {
        const float s = ocrApplyLKernel(binary, px, py, kernel);
        if (s > match.score) {
            match.score = s;
            match.x = px;
            match.y = py;
        }
    };

    for (int py = y0; py < y1; py += stride) {
        for (int px = x0; px < x1; px += stride) consider(px, py);
    }

    if (match.score > 0) {
        const int centreX = match.x;
        const int centreY = match.y;
        for (int oy = -stride; oy <= stride; ++oy) {
            for (int ox = -stride; ox <= stride; ++ox) consider(centreX + ox, centreY + oy);
        }
    }

    match.found = match.score > acceptScore;
    return match;
}

// ============================================================================
// Perspective Transform
// ============================================================================

/**
 * Warp the quadrilateral `q` of `src` into an outW x outH image.
 *
 * This is a bilinear corner mapping, not a true projective homography. The two
 * coincide only when the quad is a parallelogram. For output pixel (u, v):
 * - tx = u / (outW - 1) and ty = v / (outH - 1);
 * - take the points at fraction ty along the left edge (tl -> bl) and the right
 *   edge (tr -> br);
 * - the source position is the point at fraction tx between them.
 *
 * A 1-pixel-wide or 1-pixel-high output uses fraction 0 on that axis. Samples are
 * read with src.sample(). Positions outside [0, width-1) x [0, height-1) keep the
 * fill value 128.
 *
 * @return false if dst cannot be allocated.
 */
inline bool ocrPerspectiveTransform(const GrayImage& src, GrayImage& dst,
                                     const OCRQuad& q, int outW, int outH) {
    if (!dst.alloc(outW, outH)) return false;
    dst.clear(128);

    for (int v = 0; v < outH; v++) {
        // Guard outH == 1: 0/0 would give NaN, fail every bounds test and leave the row blank.
        const float ty = (outH > 1) ? static_cast<float>(v) / (outH - 1) : 0.0f;

        // Points at fraction ty down the left and right edges of the quad
        const float leftX = q.tl.x + ty * (q.bl.x - q.tl.x);
        const float leftY = q.tl.y + ty * (q.bl.y - q.tl.y);
        const float rightX = q.tr.x + ty * (q.br.x - q.tr.x);
        const float rightY = q.tr.y + ty * (q.br.y - q.tr.y);

        for (int u = 0; u < outW; u++) {
            const float tx = (outW > 1) ? static_cast<float>(u) / (outW - 1) : 0.0f;

            const float srcX = leftX + tx * (rightX - leftX);
            const float srcY = leftY + tx * (rightY - leftY);

            // Keep one pixel of margin on the far side, presumably for sample()'s
            // neighbour read.
            if (srcX >= 0 && srcX < src.width - 1 && srcY >= 0 && srcY < src.height - 1) {
                dst.set(u, v, src.sample(srcX, srcY));
            }
        }
    }
    return true;
}

// ============================================================================
// Connected Components
// ============================================================================

/**
 * Statistics of one connected component: inclusive bounding box and pixel count.
 * The min fields start at 9999 and the max fields at 0, so the first pixel folded
 * in with std::min / std::max establishes the box.
 */
struct OCRComponent {
    int minX, maxX, minY, maxY;  // inclusive bounding box
    int pixelCount;              // number of pixels in the component

    OCRComponent()
        : minX(9999), maxX(0),
          minY(9999), maxY(0),
          pixelCount(0) {}

    int width() const { return 1 + (maxX - minX); }
    int height() const { return 1 + (maxY - minY); }
};

/**
 * Return the root of `i` in the union-find forest stored in `parent`.
 * Applies path halving while walking: each visited node is re-pointed at its
 * grandparent.
 */
inline int ocrFind(int* parent, int i) {
    int node = i;
    for (; parent[node] != node; node = parent[node]) {
        parent[node] = parent[parent[node]];
    }
    return node;
}

/**
 * Merge the sets containing `a` and `b`, making the smaller root the parent of
 * the larger. If `parent` started as the identity mapping and is modified only by
 * ocrFind / ocrUnion, each set's root is therefore its smallest member.
 */
inline void ocrUnion(int* parent, int a, int b) {
    const int rootA = ocrFind(parent, a);
    const int rootB = ocrFind(parent, b);
    if (rootA == rootB) return;
    parent[std::max(rootA, rootB)] = std::min(rootA, rootB);
}

/**
 * Label 4-connected components of black pixels (value 0) with two-pass union-find.
 *
 * @param binary         image in which pixels equal to 0 are foreground
 * @param labels         caller-owned buffer of binary.width * binary.height ints.
 *                       On return each foreground pixel holds its component label
 *                       and background pixels hold 0. Labels are 1, 2, 3, ... in
 *                       raster order of each component's first pixel.
 * @param components     caller-owned array of maxComponents entries. Entry l receives
 *                       the bounding box and pixel count of label l for
 *                       1 <= l < maxComponents; entry 0 (background) stays empty.
 * @param maxComponents  capacity of `components`
 * @return one past the highest label recorded in `components` (0 if none), so
 *         callers iterate l = 1 .. result-1.
 *
 * Labels are compacted. Before, statistics were stored under raw provisional
 * root labels, which quickly exceed maxComponents, so components found late in
 * raster order were silently dropped. Components whose compact label is still
 * >= maxComponents keep their label in `labels` but get no statistics. If the
 * temporary buffer cannot be allocated, returns 0 with `labels` all zero.
 */
inline int ocrLabelComponents(const GrayImage& binary, int* labels, 
                               OCRComponent* components, int maxComponents) {
    const int w = binary.width, h = binary.height;
    const int numPixels = w * h;

    for (int i = 0; i < numPixels; i++) labels[i] = 0;
    for (int i = 0; i < maxComponents; i++) components[i] = OCRComponent();

    // Union-find parent array indexed by provisional label (index 0 unused)
    std::unique_ptr<int[]> parentBuf(new (std::nothrow) int[numPixels + 1]);
    if (!parentBuf) return 0;
    int* parent = parentBuf.get();
    for (int i = 0; i <= numPixels; i++) parent[i] = i;

    int nextLabel = 1;

    // First pass: provisional labels from the left/up neighbours; record equivalences
    for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
            if (binary.at(x, y) != 0) continue;  // skip white

            const int idx = y * w + x;
            const int left = (x > 0) ? labels[idx - 1] : 0;
            const int up = (y > 0) ? labels[idx - w] : 0;

            if (left == 0 && up == 0) {
                labels[idx] = nextLabel++;
            } else if (left != 0 && up == 0) {
                labels[idx] = left;
            } else if (left == 0 && up != 0) {
                labels[idx] = up;
            } else {
                labels[idx] = std::min(left, up);
                ocrUnion(parent, left, up);
            }
        }
    }

    // Map every provisional label to a compact final label, in place.
    // ocrUnion always hangs the larger root under the smaller, and path halving
    // only moves a parent to an ancestor, so parent[l] <= l (equality iff l is a
    // root). Walking l upward, a non-root's parent has therefore already been
    // replaced by its final label.
    int numFinal = 0;
    for (int l = 1; l < nextLabel; l++) {
        parent[l] = (parent[l] == l) ? ++numFinal : parent[parent[l]];
    }

    // Second pass: write final labels and accumulate per-component statistics
    int numComponents = 0;
    for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
            const int idx = y * w + x;
            if (labels[idx] == 0) continue;

            const int lbl = parent[labels[idx]];
            labels[idx] = lbl;
            if (lbl >= maxComponents) continue;  // no room for statistics

            OCRComponent& c = components[lbl];
            c.pixelCount++;
            c.minX = std::min(c.minX, x);
            c.maxX = std::max(c.maxX, x);
            c.minY = std::min(c.minY, y);
            c.maxY = std::max(c.maxY, y);
            if (lbl >= numComponents) numComponents = lbl + 1;
        }
    }

    return numComponents;
}

// ============================================================================
// Digit Recognition
// ============================================================================

/**
 * Recognize single digit from region
 * Returns 0-9 or -1 if unrecognized
 *
 * Hand-written rule classifier over shape statistics of the black (0) pixels in the
 * box [x0, x0+w) x [y0, y0+h) of `binary`. Boxes with w < 3 or h < 5, or with no
 * black pixel, return -1.
 *
 * Features:
 * - aspect = w / h; density = black pixels / (w * h).
 * - topRatio / midRatio / botRatio: share of black pixels in each vertical third.
 *   The thirds are h / 3 rows each, and the bottom third also takes the remainder.
 * - leftRatio / rightRatio: share of black pixels left / right of column w / 2.
 * - topLine / midLine / botLine: black pixels within +-2 rows of h/6, h/2 and
 *   h - h/6, divided by w. These can exceed 1 because up to 5 rows are counted.
 * - leftEdge / rightEdge: fraction of rows where either of the two outermost
 *   columns on that side is black.
 *
 * Rules are checked in order and the first match wins, so reordering them changes
 * results. The thresholds are presumably tuned on sample scorecards (cannot be
 * confirmed from this file).
 *
 * NOTE(review): pixels are read with binary.get(); its out-of-bounds behaviour is
 * defined in image.h. Confirm it is safe for boxes touching the image edge.
 */
inline int ocrRecognizeDigit(const GrayImage& binary, int x0, int y0, int w, int h) {
    if (w < 3 || h < 5) return -1;
    
    // Compute features
    int totalPixels = 0;
    int topPixels = 0, midPixels = 0, botPixels = 0;
    int leftPixels = 0, rightPixels = 0;
    
    int thirdH = h / 3;
    int halfW = w / 2;
    
    // Rows around which the three "horizontal stroke" probes are centred
    int topLinePixels = 0, midLinePixels = 0, botLinePixels = 0;
    int topY = h / 6, midY = h / 2, botY = h - h / 6;
    
    for (int dy = 0; dy < h; dy++) {
        for (int dx = 0; dx < w; dx++) {
            if (binary.get(x0 + dx, y0 + dy) == 0) {  // black pixel
                totalPixels++;
                
                // Vertical thirds (remainder rows fall into the bottom third)
                if (dy < thirdH) topPixels++;
                else if (dy < 2 * thirdH) midPixels++;
                else botPixels++;
                
                // Horizontal halves
                if (dx < halfW) leftPixels++;
                else rightPixels++;
                
                // Horizontal lines: 5-row bands (+-2) around each probe row
                if (abs(dy - topY) <= 2) topLinePixels++;
                if (abs(dy - midY) <= 2) midLinePixels++;
                if (abs(dy - botY) <= 2) botLinePixels++;
            }
        }
    }
    
    if (totalPixels == 0) return -1;
    
    // Normalised features; ratios are shares of totalPixels, line features are per column
    float aspect = (float)w / h;
    float density = (float)totalPixels / (w * h);
    float topRatio = (float)topPixels / totalPixels;
    float midRatio = (float)midPixels / totalPixels;
    float botRatio = (float)botPixels / totalPixels;
    float leftRatio = (float)leftPixels / totalPixels;
    float rightRatio = (float)rightPixels / totalPixels;
    float topLine = (float)topLinePixels / w;
    float midLine = (float)midLinePixels / w;
    float botLine = (float)botLinePixels / w;
    
    // Left/right edge presence: rows with ink in either of the two outermost columns
    int leftEdgePixels = 0, rightEdgePixels = 0;
    for (int dy = 0; dy < h; dy++) {
        if (binary.get(x0, y0 + dy) == 0 || binary.get(x0 + 1, y0 + dy) == 0)
            leftEdgePixels++;
        if (binary.get(x0 + w - 1, y0 + dy) == 0 || binary.get(x0 + w - 2, y0 + dy) == 0)
            rightEdgePixels++;
    }
    float leftEdge = (float)leftEdgePixels / h;
    float rightEdge = (float)rightEdgePixels / h;
    
    // ===== Classification rules =====
    // Evaluated top to bottom; the first match wins.
    
    // Very narrow bottom-heavy → 6
    if (aspect < 0.45f && botRatio > 0.38f && topRatio < 0.32f && leftEdge > 0.45f)
        return 6;
    
    // Narrow/medium top-heavy → 7
    if (aspect < 0.55f && topRatio > 0.48f && botRatio < 0.30f)
        return 7;
    
    // Very narrow → 1
    if (aspect < 0.30f && density > 0.40f && fabsf(topRatio - botRatio) < 0.15f)
        return 1;
    
    // Wide 7
    if (aspect > 0.50f && topRatio > 0.50f && botRatio < 0.30f && rightRatio > 0.55f)
        return 7;
    
    // Wide 6
    if (aspect > 0.50f && botRatio > 0.35f && topRatio < 0.32f)
        return 6;
    
    // 0: oval shape with hole (low density, balanced top/bottom and left/right)
    if (aspect > 0.50f && density < 0.35f &&
        fabsf(topRatio - botRatio) < 0.20f &&
        fabsf(leftRatio - rightRatio) < 0.20f)
        return 0;
    
    // 8: two loops, balanced
    if (aspect > 0.35f && density > 0.35f && density < 0.55f &&
        fabsf(topRatio - botRatio) < 0.15f && midRatio > 0.28f)
        return 8;
    
    // 3: right-heavy
    if (rightRatio > 0.52f && topLine > 0.20f && botLine > 0.20f)
        return 3;
    
    if (aspect >= 0.28f && aspect < 0.50f && rightRatio > 0.55f && leftRatio < 0.45f)
        return 3;
    
    // 9: top-heavy with right edge
    if (topRatio > 0.38f && rightEdge > 0.40f && botRatio < 0.35f)
        return 9;
    
    // 4: strong middle horizontal
    if (midLine > 0.30f && rightEdge > 0.40f && midRatio > 0.28f)
        return 4;
    
    if (rightEdge > 0.45f && midRatio > 0.32f && topRatio < 0.38f)
        return 4;
    
    // 5: top line, left-heavy
    if (topLine > 0.30f && leftRatio > 0.42f && botRatio > 0.28f)
        return 5;
    
    // 2: bottom line strong
    if (botLine > 0.35f && topRatio > 0.28f && midRatio < 0.38f)
        return 2;
    
    // Fallbacks: looser versions of the rules above, still first-match-wins
    if (topRatio > 0.48f && botRatio < 0.28f) return 7;
    if (botRatio > 0.40f && topRatio < 0.28f) return 6;
    if (rightRatio > 0.58f) return 3;
    if (topRatio > 0.35f && rightRatio > 0.55f) return 9;
    if (fabsf(topRatio - botRatio) < 0.12f && density > 0.30f && density < 0.55f) return 8;
    if (aspect < 0.40f && fabsf(topRatio - botRatio) < 0.15f) return 1;
    
    return -1;
}

// ============================================================================
// Full OCR Pipeline
// ============================================================================

/**
 * Run full OCR pipeline on image
 * Returns detected numbers from right column of scorecard
 */
inline OCRResult ocrProcess(const GrayImage& input) {
    OCRResult result;
    
    // Step 1: Crop
    GrayImage cropped;
    if (!ocrCropImage(input, cropped, 0.20f, 0.21f, 0.19f, 0.0f)) {
        Serial.println("OCR: crop failed");
        return result;
    }
    
    // Step 2: Binary threshold
    GrayImage binary;
    if (!ocrAdaptiveThreshold(cropped, binary, 35, 10)) {
        Serial.println("OCR: threshold failed");
        return result;
    }
    
    // Step 3: L-mark detection
    float kernelTL[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE];
    float kernelTR[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE];
    float kernelBL[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE];
    float kernelBR[OCR_KERNEL_SIZE][OCR_KERNEL_SIZE];
    
    ocrGenerateLKernel(kernelTL, OCR_LMARK_TL);
    ocrGenerateLKernel(kernelTR, OCR_LMARK_TR);
    ocrGenerateLKernel(kernelBL, OCR_LMARK_BL);
    ocrGenerateLKernel(kernelBR, OCR_LMARK_BR);
    
    int cornerW = binary.width / 4;
    int cornerH = binary.height / 5;
    
    LMarkResult tlResult = ocrFindLMark(binary, kernelTL, 0, 0, cornerW, cornerH);
    LMarkResult trResult = ocrFindLMark(binary, kernelTR, binary.width - cornerW, 0, binary.width, cornerH);
    LMarkResult blResult = ocrFindLMark(binary, kernelBL, 0, binary.height - cornerH, cornerW, binary.height);
    LMarkResult brResult = ocrFindLMark(binary, kernelBR, binary.width - cornerW, binary.height - cornerH, binary.width, binary.height);
    
    Serial.printf("OCR L-marks: TL=%d TR=%d BL=%d BR=%d\n", 
                  tlResult.found, trResult.found, blResult.found, brResult.found);
    
    // Step 4: Form quad and warp
    OCRQuad q;
    if (tlResult.found && trResult.found && blResult.found && brResult.found) {
        q.tl = Point2f(tlResult.x, tlResult.y);
        q.tr = Point2f(trResult.x, trResult.y);
        q.bl = Point2f(blResult.x, blResult.y);
        q.br = Point2f(brResult.x, brResult.y);
        q.valid = true;
    } else {
        // Fallback to image corners
        q.tl = Point2f(10, 10);
        q.tr = Point2f(binary.width - 10, 10);
        q.bl = Point2f(10, binary.height - 10);
        q.br = Point2f(binary.width - 10, binary.height - 10);
        q.valid = true;
    }
    
    // Expand quad
    float expandX = 10, expandYTop = 5, expandYBot = 50;
    q.tl.x -= expandX; q.tl.y -= expandYTop;
    q.tr.x += expandX; q.tr.y -= expandYTop;
    q.bl.x -= expandX; q.bl.y += expandYBot;
    q.br.x += expandX; q.br.y += expandYBot;
    
    GrayImage warped;
    if (!ocrPerspectiveTransform(cropped, warped, q, OCR_WARP_WIDTH, OCR_WARP_HEIGHT)) {
        Serial.println("OCR: warp failed");
        return result;
    }
    
    // Re-threshold warped image
    GrayImage warpedClean;
    if (!ocrAdaptiveThreshold(warped, warpedClean, 25, 8)) {
        Serial.println("OCR: clean threshold failed");
        return result;
    }
    
    // Step 5: Find connected components
    int maxLabels = 500;
    int* labels = new int[warpedClean.width * warpedClean.height];
    OCRComponent* components = new OCRComponent[maxLabels];
    
    int numComponents = ocrLabelComponents(warpedClean, labels, components, maxLabels);
    
    // Filter digit candidates
    int minH = warpedClean.height / 12;
    int maxH = warpedClean.height / 4;
    int topExclude = warpedClean.height * 12 / 100;
    int rightExclude = warpedClean.width * 80 / 100;
    int leftFilterX = warpedClean.width * 22 / 100;
    int rightFilterX = warpedClean.width * 85 / 100;
    
    // Collect valid digit candidates
    struct DigitCandidate {
        int x, y, w, h;
    };
    DigitCandidate candidates[50];
    int numCandidates = 0;
    
    for (int i = 1; i < numComponents && numCandidates < 50; i++) {
        OCRComponent& c = components[i];
        if (c.pixelCount == 0) continue;
        
        int h = c.height();
        int w = c.width();
        
        // Size filter
        if (h < minH || h > maxH || w < minH / 4 || w >= h * 1.2f)
            continue;
        
        // Exclude corners
        if (c.minY < topExclude && c.maxX > rightExclude)
            continue;
        
        // Right column filter
        if (c.minX <= leftFilterX || c.maxX >= rightFilterX)
            continue;
        
        candidates[numCandidates++] = {c.minX, c.minY, w, h};
    }
    
    // Sort by Y
    for (int i = 0; i < numCandidates - 1; i++) {
        for (int j = i + 1; j < numCandidates; j++) {
            if (candidates[j].y < candidates[i].y) {
                DigitCandidate tmp = candidates[i];
                candidates[i] = candidates[j];
                candidates[j] = tmp;
            }
        }
    }
    
    // Group into rows and recognize
    int rowThresh = warpedClean.height / 10;
    int lastY = -1000;
    int currentNumber = 0;
    
    for (int i = 0; i < numCandidates; i++) {
        // New row?
        if (candidates[i].y - lastY > rowThresh) {
            if (currentNumber > 0 && result.count < OCR_MAX_ROWS) {
                result.numbers[result.count++] = currentNumber;
            }
            currentNumber = 0;
        }
        
        int digit = ocrRecognizeDigit(warpedClean, candidates[i].x, candidates[i].y,
                                       candidates[i].w, candidates[i].h);
        if (digit >= 0) {
            currentNumber = currentNumber * 10 + digit;
        }
        
        lastY = candidates[i].y;
    }
    
    // Last row
    if (currentNumber > 0 && result.count < OCR_MAX_ROWS) {
        result.numbers[result.count++] = currentNumber;
    }
    
    delete[] labels;
    delete[] components;
    
    result.success = result.count > 0;
    
    Serial.printf("OCR result: %d numbers", result.count);
    for (int i = 0; i < result.count; i++) {
        Serial.printf(" %d", result.numbers[i]);
    }
    Serial.println();
    
    return result;
}

#endif

