#include <AccelStepper.h>
#include <Arduino.h>
#include "esp_camera.h"
#include <romeo.giudici-project-1_inferencing.h>
#include "edge-impulse-sdk/dsp/image/image.hpp"

// Run/pause flag. loop() sets it on the serial "start" command and clears it on
// "stop". While false, loop() still captures, classifies and prints decisions,
// but never calls the motor functions.
static bool is_running = false; // serial logic to actuate motors

// STEP / DIR pin pair for each of the three stepper drivers (board D-pin names).
const int M1_STEP = D0;
const int M1_DIR  = D1;
const int M2_STEP = D5;
const int M2_DIR  = D6;
const int M3_STEP = D7;
const int M3_DIR  = D8;

// AccelStepper in DRIVER mode: one STEP pin and one DIR pin per motor.
// Note: every command in this file moves stepper3 by 0 steps; it is configured
// but never actually driven.
AccelStepper stepper1(AccelStepper::DRIVER, M1_STEP, M1_DIR);
AccelStepper stepper2(AccelStepper::DRIVER, M2_STEP, M2_DIR);
AccelStepper stepper3(AccelStepper::DRIVER, M3_STEP, M3_DIR);

// Motion parameters. setup() applies SPEED_FAST via setMaxSpeed() (steps/s) and
// ACCEL via setAcceleration() (steps/s^2). STEP_SIZE is the relative move, in
// steps, of one forward command; each turn command moves STEP_SIZE/5.
const int SPEED_FAST = 2000;
const int ACCEL      = 1500;
const int STEP_SIZE  = 3000;

// XIAO S3 Sense Pinouts
// GPIO numbers for the camera interface, consumed by camera_config below.
// PWDN/RESET are -1: no GPIO is assigned to them (esp32-camera treats -1 as
// "not connected").
#define PWDN_GPIO_NUM     -1
#define RESET_GPIO_NUM    -1
// Sensor master clock output.
#define XCLK_GPIO_NUM     10
// SCCB (I2C-style control bus) data and clock.
#define SIOD_GPIO_NUM     40  
#define SIOC_GPIO_NUM     39 
// 8-bit parallel pixel bus: Y9..Y2 are wired to camera_config pin_d7..pin_d0.
#define Y9_GPIO_NUM       48
#define Y8_GPIO_NUM       11
#define Y7_GPIO_NUM       12
#define Y6_GPIO_NUM       14
#define Y5_GPIO_NUM       16
#define Y4_GPIO_NUM       18
#define Y3_GPIO_NUM       17
#define Y2_GPIO_NUM       15
// Frame / line sync and pixel clock.
#define VSYNC_GPIO_NUM    38
#define HREF_GPIO_NUM     47
#define PCLK_GPIO_NUM     13


/**
 * Step all three motors toward their current targets until every one of them
 * reports distanceToGo() == 0.
 *
 * This call blocks, and not briefly. With SPEED_FAST = 2000 steps/s and
 * ACCEL = 1500 steps/s^2, a full STEP_SIZE (3000-step) move takes roughly
 * 3 s, and a STEP_SIZE/5 turn takes over 1 s. Serial input is not read while
 * the call is in progress.
 */
void runAllSteppers() {
    for (;;) {
        const bool allArrived = stepper1.distanceToGo() == 0
                             && stepper2.distanceToGo() == 0
                             && stepper3.distanceToGo() == 0;
        if (allArrived) {
            break;
        }
        stepper1.run();
        stepper2.run();
        stepper3.run();
    }
}

void stopMotors() {
    // Stops motors by commanding a zero step and running the step function once
    stepper1.move(0);
    stepper2.move(0);
    stepper3.move(0);
    runAllSteppers();
    Serial.println("Action: STOP");
}

/**
 * Drive forward one STEP_SIZE stride and block until the move completes.
 *
 * M1 gets +STEP_SIZE and M2 gets -STEP_SIZE (original note: M1 CCW, M2 CW).
 * The opposite signs presumably compensate for mirror-mounted motors
 * (TODO confirm against the chassis). M3 holds position.
 */
void moveForward() {
    const long stride = STEP_SIZE;
    stepper1.move(stride);
    stepper2.move(-stride);
    stepper3.move(0);
    runAllSteppers();
    Serial.println("Action: FORWARD");
}

/**
 * Spin left (CCW) in place by moving M1 and M2 the same -STEP_SIZE/5 steps.
 * Same-sign motion on the two drive motors is the opposite of moveForward(),
 * so the robot rotates. M3 holds. Blocks until done.
 */
void turnLeft() {
    const long spin = -(STEP_SIZE / 5);
    stepper1.move(spin);
    stepper2.move(spin);
    stepper3.move(0);
    runAllSteppers();
    Serial.println("Action: TURN LEFT");
}

/**
 * Spin right (CW) in place: the mirror image of turnLeft(). M1 and M2 both
 * move +STEP_SIZE/5 steps and M3 holds. Blocks until done.
 */
void turnRight() {
    const long spin = STEP_SIZE / 5;
    stepper1.move(spin);
    stepper2.move(spin);
    stepper3.move(0);
    runAllSteppers();
    Serial.println("Action: TURN RIGHT");
}


// Geometry of the decoded camera frame stored in snapshot_buf: 320 x 240 pixels
// (the FRAMESIZE_QVGA size requested in camera_config) at 3 bytes per pixel
// (fmt2rgb888() output; ei_camera_get_data() reads it 3 bytes at a time).
#define EI_CAMERA_RAW_FRAME_BUFFER_COLS           320
#define EI_CAMERA_RAW_FRAME_BUFFER_ROWS           240
#define EI_CAMERA_FRAME_BYTE_SIZE                 3

// Passed as the debug flag to run_classifier() in loop().
static bool debug_nn = false; 
// Set by ei_camera_init() after esp_camera_init() succeeds; ei_camera_capture()
// refuses to run while it is false.
static bool is_initialised = false;
// Decoded-frame buffer, malloc'ed in setup(). It stays nullptr until then, or
// if that allocation fails. loop() also passes it as the resize destination,
// and ei_camera_get_data() reads pixels from it.
uint8_t *snapshot_buf; 

// esp32-camera driver configuration used by ei_camera_init().
// NOTE: designated initializers here must stay in camera_config_t's declaration order.
static camera_config_t camera_config = {
    .pin_pwdn = PWDN_GPIO_NUM,
    .pin_reset = RESET_GPIO_NUM,
    .pin_xclk = XCLK_GPIO_NUM,
    .pin_sscb_sda = SIOD_GPIO_NUM,
    .pin_sscb_scl = SIOC_GPIO_NUM,
    .pin_d7 = Y9_GPIO_NUM,
    .pin_d6 = Y8_GPIO_NUM,
    .pin_d5 = Y7_GPIO_NUM,
    .pin_d4 = Y6_GPIO_NUM,
    .pin_d3 = Y5_GPIO_NUM,
    .pin_d2 = Y4_GPIO_NUM,
    .pin_d1 = Y3_GPIO_NUM,
    .pin_d0 = Y2_GPIO_NUM,
    .pin_vsync = VSYNC_GPIO_NUM,
    .pin_href = HREF_GPIO_NUM,
    .pin_pclk = PCLK_GPIO_NUM,
    // 20 MHz sensor master clock.
    .xclk_freq_hz = 20000000,
    .ledc_timer = LEDC_TIMER_0,
    .ledc_channel = LEDC_CHANNEL_0,
    // Frames arrive as JPEG; ei_camera_capture() decodes them with fmt2rgb888().
    .pixel_format = PIXFORMAT_JPEG, 
    // 320x240: must match EI_CAMERA_RAW_FRAME_BUFFER_COLS/ROWS, because the
    // decoded frame is written into the snapshot_buf sized from those macros.
    .frame_size = FRAMESIZE_QVGA, 
    // JPEG quality setting (per esp32-camera docs: 0-63, lower = higher quality).
    .jpeg_quality = 12, 
    // Single frame buffer, placed in PSRAM.
    .fb_count = 1, 
    .fb_location = CAMERA_FB_IN_PSRAM,
    .grab_mode = CAMERA_GRAB_WHEN_EMPTY,
};

// Camera Initialization
bool ei_camera_init(void) {
    if (is_initialised) return true;
    esp_err_t err = esp_camera_init(&camera_config);
    if (err != ESP_OK) {
        Serial.printf("Camera init failed with error 0x%x\n", err);
        return false;
    }
    sensor_t *s = esp_camera_sensor_get();
    
    // Default OV2640 fixes
    if (s->id.PID == OV2640_PID) s->set_vflip(s, 1);
    
    is_initialised = true;
    return true;
}

/**
 * Grab one JPEG frame, decode it into snapshot_buf
 * (EI_CAMERA_RAW_FRAME_BUFFER_COLS x _ROWS, 3 bytes per pixel), and, if the
 * requested size differs, crop/interpolate it into out_buf.
 *
 * @param img_width  Width the classifier expects.
 * @param img_height Height the classifier expects.
 * @param out_buf    Resize destination. loop() passes snapshot_buf itself, so
 *                   the resize runs in place. When no resize is needed, nothing
 *                   is written to out_buf; the frame is only in snapshot_buf,
 *                   which is where ei_camera_get_data() reads from.
 * @return true on success; false if the camera is not initialised, a buffer
 *         is missing, no frame was available, or JPEG decoding failed.
 */
bool ei_camera_capture(uint32_t img_width, uint32_t img_height, uint8_t *out_buf) {
    if (!is_initialised) return false;

    // setup() can leave snapshot_buf null (malloc failure) even though the
    // camera initialised; decoding into it would write through a null pointer.
    if (snapshot_buf == nullptr || out_buf == nullptr) {
        Serial.println("ERR: Capture buffer not allocated");
        return false;
    }

    camera_fb_t *fb = esp_camera_fb_get();
    if (!fb) {
        Serial.println("ERR: Camera frame grab failed");
        return false;
    }

    // Hand the frame buffer back to the driver whether or not decoding worked.
    const bool converted = fmt2rgb888(fb->buf, fb->len, PIXFORMAT_JPEG, snapshot_buf);
    esp_camera_fb_return(fb);
    if (!converted) {
        Serial.println("ERR: JPEG decode failed");
        return false;
    }

    const bool do_resize = (img_width != EI_CAMERA_RAW_FRAME_BUFFER_COLS) ||
                           (img_height != EI_CAMERA_RAW_FRAME_BUFFER_ROWS);
    if (do_resize) {
        ei::image::processing::crop_and_interpolate_rgb888(
            snapshot_buf, EI_CAMERA_RAW_FRAME_BUFFER_COLS, EI_CAMERA_RAW_FRAME_BUFFER_ROWS,
            out_buf, img_width, img_height);
    }
    return true;
}

/**
 * Edge Impulse signal callback: convert `length` pixels of snapshot_buf,
 * starting at pixel index `offset`, into one float per pixel.
 *
 * A pixel's three bytes b0, b1, b2 are packed as (b2 << 16) | (b1 << 8) | b0.
 * Byte 2 therefore lands in the top 8 bits, presumably because fmt2rgb888()
 * emits B,G,R order (TODO confirm). The 24-bit integer is stored as a float,
 * which represents it exactly.
 *
 * @return Always 0.
 */
static int ei_camera_get_data(size_t offset, size_t length, float *out_ptr) {
    for (size_t n = 0; n < length; ++n) {
        const uint8_t *px = &snapshot_buf[(offset + n) * 3];
        const uint32_t packed = (static_cast<uint32_t>(px[2]) << 16)
                              | (static_cast<uint32_t>(px[1]) << 8)
                              |  static_cast<uint32_t>(px[0]);
        out_ptr[n] = static_cast<float>(packed);
    }
    return 0;
}


void setup() {
    
    Serial.begin(115200);
    Serial.println("WALL-AZY Starting...");

    // Motor Setup
    pinMode(M1_STEP, OUTPUT); pinMode(M1_DIR, OUTPUT);
    pinMode(M2_STEP, OUTPUT); pinMode(M2_DIR, OUTPUT);
    pinMode(M3_STEP, OUTPUT); pinMode(M3_DIR, OUTPUT);

    stepper1.setMaxSpeed(SPEED_FAST); stepper1.setAcceleration(ACCEL);
    stepper2.setMaxSpeed(SPEED_FAST); stepper2.setAcceleration(ACCEL);
    stepper3.setMaxSpeed(SPEED_FAST); stepper3.setAcceleration(ACCEL);

    // Camera Init
    if (ei_camera_init() == false) {
        Serial.println("ERR: Failed to initialize Camera!");
    } else {
        // Buffer Init
        snapshot_buf = (uint8_t*)malloc(EI_CAMERA_RAW_FRAME_BUFFER_COLS * EI_CAMERA_RAW_FRAME_BUFFER_ROWS * EI_CAMERA_FRAME_BYTE_SIZE);
        if(snapshot_buf == nullptr) {
            Serial.println("ERR: Failed to allocate snapshot buffer! (Check PSRAM)");
            return;
        }
        Serial.println("Camera initialized OK and Buffer allocated.");
    }
    
    Serial.println("System Ready. Send 'start' to begin tracking, or 'stop' to pause.");
    
    // is_running remains false initially, handled by loop
}


void loop() {
    if (Serial.available()) {
        String command = Serial.readStringUntil('\n');
        command.trim(); 
        command.toLowerCase();
        
        if (command.equals("start")) {
            is_running = true;
            Serial.println("--- STARTING ROBOT TRACKING ---");
        } else if (command.equals("stop")) {
            is_running = false;
            stopMotors(); // Ensure motors stop immediately
            Serial.println("--- PAUSED: Send 'start' to resume ---");
        }
    }

    // Image Capture and Inference Setup
    if (ei_camera_capture((size_t)EI_CLASSIFIER_INPUT_WIDTH, (size_t)EI_CLASSIFIER_INPUT_HEIGHT, snapshot_buf) == false) {
        Serial.println("Failed to capture image");
        return;
    }

    ei::signal_t signal;
    signal.total_length = EI_CLASSIFIER_INPUT_WIDTH * EI_CLASSIFIER_INPUT_HEIGHT;
    signal.get_data = &ei_camera_get_data;
    ei_impulse_result_t result = { 0 };
    
    EI_IMPULSE_ERROR err = run_classifier(&signal, &result, debug_nn);
    if (err != EI_IMPULSE_OK) {
        Serial.printf("ERR: Failed to run classifier (%d)\n", err);
        return;
    }

    // Decision Variables
    float bestConfidence = 0.0;
    int bestX = 0; // Center X coordinate of the best bounding box
    
    // Width constants
    const int CENTER_X_PIXEL = EI_CLASSIFIER_INPUT_WIDTH / 2; // 48
    const int DEAD_ZONE = 30; // 23 - 73 Pixels to allow for centering
    const float ACTION_THRESHOLD = 0.4; // Minimum confidence to trigger movement

    // Find the Best "red_ball" Detection
    for (uint32_t i = 0; i < result.bounding_boxes_count; i++) {
        ei_impulse_result_bounding_box_t bb = result.bounding_boxes[i];
        
        if (strcmp(bb.label, "red_ball") == 0 && bb.value > bestConfidence) {
             bestConfidence = bb.value;
             // Calculate center of the bounding box
             bestX = bb.x + (bb.width / 2);
        }
    }
    
    // TARGET NOT FOUND OR LOW CONFIDENCE
    if (bestConfidence < ACTION_THRESHOLD) {
        Serial.printf("Decision: STOP (Target Lost, Conf: %.2f < %.2f)\n", bestConfidence, ACTION_THRESHOLD);
        
        if (is_running) {
             stopMotors(); 
        }
        return;
    }
    
    // TARGET FOUND
    Serial.printf("Target Conf: %.2f, Center X: %d\n", bestConfidence, bestX);
    
    // ALIGNMENT & FORWARD MOVEMENT
    if (bestX > CENTER_X_PIXEL + DEAD_ZONE) {
        // Ball is on the right side of the center dead zone
        Serial.println("Decision: TURN RIGHT");
        if (is_running) {
            turnRight();
        }
    } 
    else if (bestX < CENTER_X_PIXEL - DEAD_ZONE) {
        // Ball is on the left side of the center dead zone
        Serial.println("Decision: TURN LEFT");
        if (is_running) {
            turnLeft();
        }
    }
    else {
        // Ball is centered in the dead zone
        Serial.println("Decision: FORWARD");
        if (is_running) {
            moveForward();
        }
    }
}