#include <Wire.h>
#include <Adafruit_GFX.h>
#include <Adafruit_SSD1306.h>

#include <SPI.h>
#include <Adafruit_ST7735.h>

#include <MAX30105.h>
#include "heartRate.h"

// =================== I2C PINS ===================
// GPIO numbers of the single I2C bus shared by the OLED, the MAX30102 and the IMU
// (passed to Wire.begin() in setup()).
#define SDA_PIN 5
#define SCL_PIN 6

// =================== OLED =======================
// SSD1306 128x64 panel on the shared Wire bus at I2C address 0x3C.
// Last constructor argument is the reset pin; -1 = no dedicated reset GPIO.
#define OLED_WIDTH 128
#define OLED_HEIGHT 64
#define OLED_ADDR 0x3C
Adafruit_SSD1306 display(OLED_WIDTH, OLED_HEIGHT, &Wire, -1);

// =================== ST7735 TFT (wired through header JP6) ===================
// JP6 pin 2 = D8  -> TFT SCK
// JP6 pin 3 = D10 -> TFT MOSI (SDA/SDI)
// JP6 pin 4 = D1  -> TFT DC (A0)
// JP6 pin 5 = D2  -> TFT RST
// JP6 pin 6 = D9  -> TFT CS
#define TFT_SCK   D8
#define TFT_MOSI  D10
#define TFT_DC    D1
#define TFT_RST   D2
#define TFT_CS    D9
// JP6 pin 1 presumably feeds the backlight (BL/LED) through a resistor to 3V3, so the
// backlight is always on. If your TFT's BL pin is wired to a GPIO instead, define TFT_BL
// and drive it HIGH.
// #define TFT_BL  <some_gpio>

// Hardware-SPI constructor: only CS/DC/RST are given here. The SCK/MOSI pins are
// assigned by the SPI.begin(TFT_SCK, -1, TFT_MOSI, TFT_CS) call in setup(), which is
// made before tft.initR().
Adafruit_ST7735 tft = Adafruit_ST7735(TFT_CS, TFT_DC, TFT_RST);

// =================== MPU6050 (raw register map) ===================
// The IMU is driven directly over I2C (no IMU library) using these register addresses.
// 7-bit I2C address of the IMU.
#define MPU_ADDR      0x68
// Power management 1 (written 0x00 in initMPU6050_like() to wake the device).
#define REG_PWR_MGMT1  0x6B
// Sample-rate divider.
#define REG_SMPLRTDIV  0x19
// Config register (digital low-pass filter setting).
#define REG_CONFIG     0x1A
// Gyro full-scale range.
#define REG_GYRO_CFG   0x1B
// Accel full-scale range.
#define REG_ACCEL_CFG  0x1C
// Start of the 14-byte burst read: accel X/Y/Z, temperature, gyro X/Y/Z,
// each a big-endian 16-bit word (decoded by readMPU()).
#define REG_ACCEL_XOUT 0x3B   // 0x3B..0x48 (14 bytes)

// Raw-count-to-unit scale factors. They are only valid for the ±2 g / ±250 dps ranges
// that initMPU6050_like() programs (ACCEL_CFG = GYRO_CFG = 0x00).
static const float ACCEL_LSB_PER_G  = 16384.0f; // ±2g
static const float GYRO_LSB_PER_DPS = 131.0f;   // ±250 dps

// =================== MAX30102 ===================
// The MAX30102 is driven through the MAX30105 library class.
MAX30105 sensor;
long lastBeat = 0;      // millis() timestamp of the most recent detected beat (0 before any beat)
float bpm = 0;          // rate derived from the last beat-to-beat interval
bool beatFlag = false;  // true only during the loop() iteration in which a beat was detected

// =================== SpO2 (simple RMS ratio) ====
// Circular window of the last BUFFER_SIZE raw samples per channel; loop() adds one
// sample per iteration and computeSpO2() evaluates the full window.
const int BUFFER_SIZE = 50;
long irBuf[BUFFER_SIZE];
long redBuf[BUFFER_SIZE];
int bufIndex = 0;                  // next slot to overwrite
int bufCount = 0;                  // number of valid samples, saturates at BUFFER_SIZE
float spo2 = 0.0f;                 // latest estimate in percent; stays 0 until computeSpO2() first succeeds
unsigned long lastSpO2Update = 0;  // millis() of the last computeSpO2() call

// =================== Timing =====================
// Minimum interval between redraws of each display (checked in loop()).
unsigned long lastOLED = 0;
unsigned long lastTFT  = 0;
const unsigned long OLED_INTERVAL_MS = 150;
const unsigned long TFT_INTERVAL_MS  = 150;

// Estimate SpO2 from the ratio-of-ratios R = (AC_red/DC_red) / (AC_ir/DC_ir) over the
// sample window, where DC is the window mean and AC the RMS deviation from that mean.
// The estimate 110 - 25*R is clamped to [70, 100] and stored in the global `spo2`.
// Leaves `spo2` untouched until the window is full, or when any DC/AC term is non-positive
// (e.g. a flat signal).
// NOTE(review): uncalibrated linear approximation — not suitable for medical use.
void computeSpO2() {
  if (bufCount < BUFFER_SIZE) return;  // wait for a full window

  // Mean (DC) and RMS deviation from the mean (AC) of one channel's window.
  auto dcAndRms = [](const long* samples, double& dc, double& rms) {
    double total = 0;
    for (int k = 0; k < BUFFER_SIZE; ++k) total += samples[k];
    dc = total / BUFFER_SIZE;

    double squares = 0;
    for (int k = 0; k < BUFFER_SIZE; ++k) {
      const double dev = samples[k] - dc;
      squares += dev * dev;
    }
    rms = sqrt(squares / BUFFER_SIZE);
  };

  double irDc, irRms, redDc, redRms;
  dcAndRms(irBuf, irDc, irRms);
  dcAndRms(redBuf, redDc, redRms);

  const bool usable = irDc > 0 && redDc > 0 && irRms > 0 && redRms > 0;
  if (!usable) return;

  const double ratio = (redRms / redDc) / (irRms / irDc);
  double estimate = 110.0 - 25.0 * ratio;
  estimate = (estimate < 70.0) ? 70.0 : ((estimate > 100.0) ? 100.0 : estimate);

  spo2 = static_cast<float>(estimate);
}

// ---------------- I2C helpers -------------------
// Write one byte `val` to register `reg` of the I2C device at `addr`.
// Returns true when the transaction ends with status 0 (acknowledged).
bool writeReg(uint8_t addr, uint8_t reg, uint8_t val) {
  const uint8_t frame[2] = {reg, val};  // register pointer followed by data byte
  Wire.beginTransmission(addr);
  Wire.write(frame, sizeof frame);
  const uint8_t status = Wire.endTransmission();
  return status == 0;
}

// Read `n` consecutive bytes starting at register `reg` of device `addr` into `out`.
// The register pointer is written without a STOP (endTransmission(false)) so the read
// follows as a repeated start. Returns false if the pointer write fails or fewer than `n`
// bytes arrive; `out` is only filled on success.
bool readBytes(uint8_t addr, uint8_t reg, uint8_t* out, uint8_t n) {
  Wire.beginTransmission(addr);
  Wire.write(reg);
  const bool pointerSet = (Wire.endTransmission(false) == 0);
  if (!pointerSet) return false;

  const uint8_t received = Wire.requestFrom(addr, n);
  if (received != n) return false;

  for (uint8_t* dst = out; dst != out + n; ++dst) *dst = Wire.read();
  return true;
}

// Decode a big-endian (high byte first) two's-complement 16-bit value from b[0..1].
int16_t be16(const uint8_t* b) {
  const uint16_t hi = b[0];
  const uint16_t lo = b[1];
  const uint16_t word = static_cast<uint16_t>((hi << 8) | lo);
  return static_cast<int16_t>(word);
}

// ---------------- MPU init/read -----------------
// Wake an MPU6050-compatible IMU and program its sample rate, low-pass filter and
// full-scale ranges (±250 dps, ±2 g — the ranges ACCEL_LSB_PER_G / GYRO_LSB_PER_DPS assume).
// Returns false if the wake-up write is not acknowledged (no device at MPU_ADDR) or if
// any configuration write fails. Previously those later failures were ignored and a
// half-configured IMU was reported as ready.
bool initMPU6050_like() {
  // PWR_MGMT_1 = 0x00 wakes the device; a failure here means nothing answered at MPU_ADDR.
  if (!writeReg(MPU_ADDR, REG_PWR_MGMT1, 0x00)) return false;
  delay(10);  // give the device time to wake before configuring it

  // Every write is attempted (writeReg is evaluated before `&& ok`, so no short-circuit),
  // and any failure is reported to the caller.
  bool ok = true;
  ok = writeReg(MPU_ADDR, REG_SMPLRTDIV, 0x07) && ok; // ~125Hz
  ok = writeReg(MPU_ADDR, REG_CONFIG,    0x03) && ok; // DLPF ~44Hz
  ok = writeReg(MPU_ADDR, REG_GYRO_CFG,  0x00) && ok; // ±250dps
  ok = writeReg(MPU_ADDR, REG_ACCEL_CFG, 0x00) && ok; // ±2g
  return ok;
}

// Burst-read the IMU's 14 data bytes and convert them to physical units:
// acceleration in g, angular rate in deg/s and die temperature in °C.
// Returns false (outputs untouched) if the I2C read fails.
bool readMPU(float& ax_g, float& ay_g, float& az_g,
             float& gx_dps, float& gy_dps, float& gz_dps,
             float& temp_c) {
  uint8_t raw[14];
  const bool got = readBytes(MPU_ADDR, REG_ACCEL_XOUT, raw, sizeof raw);
  if (!got) return false;

  // Seven big-endian words in register order: ax, ay, az, temp, gx, gy, gz.
  int16_t sample[7];
  for (int w = 0; w < 7; ++w) sample[w] = be16(raw + 2 * w);

  float* const accelOut[3] = {&ax_g, &ay_g, &az_g};
  for (int axis = 0; axis < 3; ++axis) *accelOut[axis] = sample[axis] / ACCEL_LSB_PER_G;

  float* const gyroOut[3] = {&gx_dps, &gy_dps, &gz_dps};
  for (int axis = 0; axis < 3; ++axis) *gyroOut[axis] = sample[4 + axis] / GYRO_LSB_PER_DPS;

  // Temperature conversion formula used for the MPU6050.
  temp_c = (sample[3] / 340.0f) + 36.53f;
  return true;
}

// ---------------- OLED render -------------------
// Compose a full frame in the SSD1306 buffer and push it with display().
// Layout:
//   y=0   rounded BPM (size 2) + "BPM" | beat dot at (62,8) | rounded SpO2 (size 2) + "%"
//   y=22  accelerometer x/y/z in g, 2 decimals
//   y=34  gyro x/y/z in deg/s, 1 decimal
//   y=46  IMU temperature in C, 1 decimal
//   y=56  raw IR and RED counts
void drawOLED(float axg, float ayg, float azg,
              float gxd, float gyd, float gzd,
              float tc, long ir, long red,
              float bpmVal, float spo2Val, bool beat) {
  // Rounded reading in size-2 text at (valueX, 0), then a size-1 unit label at (unitX, 4).
  auto headline = [](int16_t valueX, float value, int16_t unitX, const char* unit) {
    display.setTextSize(2);
    display.setCursor(valueX, 0);
    display.print((int)(value + 0.5f));
    display.setTextSize(1);
    display.setCursor(unitX, 4);
    display.print(unit);
  };

  // Three values separated by single spaces, each with `decimals` fractional digits.
  auto triple = [](float x, float y, float z, int decimals) {
    display.print(x, decimals); display.print(" ");
    display.print(y, decimals); display.print(" ");
    display.print(z, decimals);
  };

  display.clearDisplay();
  display.setTextColor(SSD1306_WHITE);

  headline(0, bpmVal, 40, "BPM");
  headline(70, spo2Val, 112, "%");
  if (beat) display.fillCircle(62, 8, 3, SSD1306_WHITE);

  display.setTextSize(1);
  display.setCursor(0, 22);
  display.print("A(g) ");
  triple(axg, ayg, azg, 2);

  display.setCursor(0, 34);
  display.print("G ");
  triple(gxd, gyd, gzd, 1);

  display.setCursor(0, 46);
  display.print("T ");
  display.print(tc, 1);
  display.print("C");

  display.setCursor(0, 56);
  display.print("IR "); display.print(ir);
  display.print(" R "); display.print(red);

  display.display();
}

// ---------------- TFT render --------------------
// Redraw every TFT row in place (landscape, setRotation(1) in setup()).
// Rows: title (y=0), BPM + beat marker (y=14, size 2), SpO2 (y=38, size 2),
// accel (y=62), gyro (y=74), temperature (y=86), raw IR (y=98), raw RED (y=110).
//
// The previous version cleared the whole panel with fillScreen() on every frame, a
// full-screen SPI transfer that makes the display visibly flicker. Now each row is drawn
// with an opaque text background and the rest of the row is blanked, so shorter values
// never leave stale characters behind.
void drawTFT(float axg, float ayg, float azg,
             float gxd, float gyd, float gzd,
             float tc, long ir, long red,
             float bpmVal, float spo2Val, bool beat) {
  // Blank from the current cursor to the right edge for a text row starting at y.
  // Assumes the built-in 8-px-tall font scaled by `size`. The guard matters because
  // fillRect() would interpret a negative width as extending leftwards.
  auto clearToEol = [](int16_t y, uint8_t size) {
    const int16_t x = tft.getCursorX();
    if (x < tft.width()) tft.fillRect(x, y, tft.width() - x, 8 * size, ST77XX_BLACK);
  };

  tft.setTextWrap(false);                        // clip over-long rows instead of wrapping onto the next row
  tft.setTextColor(ST77XX_WHITE, ST77XX_BLACK);  // opaque glyph cells repaint their own background

  tft.setTextSize(1);
  tft.setCursor(0, 0);
  tft.print("Vitals + IMU");
  clearToEol(0, 1);

  tft.setTextSize(2);
  tft.setCursor(0, 14);
  tft.print("BPM ");
  tft.print((int)(bpmVal + 0.5f));
  clearToEol(14, 2);
  // The beat marker lies inside the BPM row, so it is drawn after that row was blanked.
  if (beat) {
    tft.fillRect(150, 18, 8, 8, ST77XX_WHITE);
  }

  tft.setCursor(0, 38);
  tft.print("O2  ");
  tft.print((int)(spo2Val + 0.5f));
  tft.print("%");
  clearToEol(38, 2);

  tft.setTextSize(1);

  tft.setCursor(0, 62);
  tft.print("A(g): ");
  tft.print(axg, 2); tft.print(", ");
  tft.print(ayg, 2); tft.print(", ");
  tft.print(azg, 2);
  clearToEol(62, 1);

  tft.setCursor(0, 74);
  tft.print("G(dps): ");
  tft.print(gxd, 1); tft.print(", ");
  tft.print(gyd, 1); tft.print(", ");
  tft.print(gzd, 1);
  clearToEol(74, 1);

  tft.setCursor(0, 86);
  tft.print("Temp: ");
  tft.print(tc, 1); tft.print("C");
  clearToEol(86, 1);

  tft.setCursor(0, 98);
  tft.print("IR: "); tft.print(ir);
  clearToEol(98, 1);

  tft.setCursor(0, 110);
  tft.print("RED: "); tft.print(red);
  clearToEol(110, 1);
}

void setup() {
  Serial.begin(115200);
  delay(300);

  // I2C for OLED + MAX30102 + IMU
  Wire.begin(SDA_PIN, SCL_PIN);
  Wire.setClock(400000);

  // OLED init
  if (!display.begin(SSD1306_SWITCHCAPVCC, OLED_ADDR)) {
    Serial.println("OLED not found");
    while (1) delay(10);
  }
  display.clearDisplay();
  display.setTextSize(1);
  display.setCursor(0, 0);
  display.println("Booting...");
  display.display();

  // TFT init (SPI)
  // IMPORTANT: explicitly start SPI with your SCK/MOSI pins from JP6
  SPI.begin(TFT_SCK, -1, TFT_MOSI, TFT_CS);

  // Try BLACKTAB first; if you still get white, try REDTAB/GREENTAB.
  tft.initR(INITR_BLACKTAB);
  // tft.initR(INITR_REDTAB);
  // tft.initR(INITR_GREENTAB);

  tft.setRotation(1);
  tft.fillScreen(ST77XX_BLACK);
  tft.setTextColor(ST77XX_WHITE);
  tft.setTextSize(1);
  tft.setCursor(0, 0);
  tft.println("Booting...");

  // MAX30102 init
  Serial.println("Initializing MAX30102...");
  if (!sensor.begin(Wire, I2C_SPEED_FAST)) {
    Serial.println("MAX30102 not found");
    display.clearDisplay(); display.setCursor(0, 0); display.println("MAX30102 missing"); display.display();
    tft.fillScreen(ST77XX_BLACK); tft.setCursor(0, 0); tft.println("MAX30102 missing");
    while (1) delay(10);
  }
  sensor.setup();
  sensor.setLEDMode(2);
  sensor.setADCRange(16384);
  sensor.setSampleRate(100);
  sensor.setPulseWidth(411);
  sensor.setPulseAmplitudeRed(0x3F);
  sensor.setPulseAmplitudeIR(0x3F);

  // IMU init
  if (!initMPU6050_like()) {
    Serial.println("IMU init failed");
    display.clearDisplay(); display.setCursor(0, 0); display.println("IMU init failed"); display.display();
    tft.fillScreen(ST77XX_BLACK); tft.setCursor(0, 0); tft.println("IMU init failed");
    while (1) delay(10);
  }

  Serial.println("Ready.");
}

// One iteration:
//   1. Sample the MAX30102 and run beat detection.
//   2. Feed the SpO2 window (re-estimated at most every ~300 ms).
//   3. Read the IMU.
//   4. Stream a label:value line for the Serial Plotter.
//   5. Redraw each display when its interval has elapsed.
void loop() {
  // ---- MAX30102 ----
  // NOTE(review): in the SparkFun MAX3010x library getIR() and getRed() each appear to wait
  // for a *new* FIFO sample, so ir/red may come from consecutive samples rather than one.
  // Confirm, and consider check()/available()/getFIFOIR()/getFIFORed()/nextSample() to read
  // matched pairs.
  long ir  = sensor.getIR();
  long red = sensor.getRed();

  // beatFlag is true for a single iteration, but the displays refresh only every ~150 ms,
  // so a beat would usually be missed. Latch it per display until that display has shown it.
  static bool oledBeatPending = false;
  static bool tftBeatPending  = false;

  beatFlag = false;
  if (checkForBeat(ir)) {
    const unsigned long beatTime = millis();
    const bool havePreviousBeat = (lastBeat != 0);
    const long delta = beatTime - lastBeat;
    lastBeat = beatTime;
    // Skip the first beat: with lastBeat == 0 the interval would be "ms since boot".
    // Also reject implausible rates; noise can fire beats milliseconds apart.
    if (havePreviousBeat && delta > 0) {
      const float candidate = 60.0f / (delta / 1000.0f);
      if (candidate > 20.0f && candidate < 255.0f) bpm = candidate;
    }
    beatFlag = true;
    oledBeatPending = true;
    tftBeatPending  = true;
  }

  // ---- SpO2 window (circular buffer) ----
  irBuf[bufIndex]  = ir;
  redBuf[bufIndex] = red;
  bufIndex = (bufIndex + 1) % BUFFER_SIZE;
  if (bufCount < BUFFER_SIZE) bufCount++;

  if (millis() - lastSpO2Update > 300) {
    computeSpO2();
    lastSpO2Update = millis();
  }

  // ---- IMU (zeros on a failed read) ----
  float axg, ayg, azg, gxd, gyd, gzd, tc;
  bool ok = readMPU(axg, ayg, azg, gxd, gyd, gzd, tc);
  if (!ok) axg = ayg = azg = gxd = gyd = gzd = tc = 0.0f;

  // Serial Plotter (label:value); "beat" reflects this iteration only.
  Serial.print("ax_g:");   Serial.print(axg, 4);  Serial.print(" ");
  Serial.print("ay_g:");   Serial.print(ayg, 4);  Serial.print(" ");
  Serial.print("az_g:");   Serial.print(azg, 4);  Serial.print(" ");
  Serial.print("gy_dps:"); Serial.print(gyd, 3);  Serial.print(" ");
  Serial.print("BPM:");    Serial.print(bpm, 2);  Serial.print(" ");
  Serial.print("SpO2:");   Serial.print(spo2, 1); Serial.print(" ");
  Serial.print("IR:");     Serial.print(ir);      Serial.print(" ");
  Serial.print("RED:");    Serial.print(red);     Serial.print(" ");
  Serial.print("beat:");   Serial.println(beatFlag ? 1 : 0);

  // ---- Displays ----
  unsigned long now = millis();

  if (now - lastOLED >= OLED_INTERVAL_MS) {
    lastOLED = now;
    drawOLED(axg, ayg, azg, gxd, gyd, gzd, tc, ir, red, bpm, spo2, oledBeatPending);
    oledBeatPending = false;
  }

  if (now - lastTFT >= TFT_INTERVAL_MS) {
    lastTFT = now;
    drawTFT(axg, ayg, azg, gxd, gyd, gzd, tc, ir, red, bpm, spo2, tftBeatPending);
    tftBeatPending = false;
  }

  delay(20);
}
