#!/usr/bin/env python3
"""
UDP server + Tkinter viewer for IMU wand path tracking with Kalman filtering.

The ESP32S3 firmware sends comma-separated UTF-8 text packets of several types:
- QUAT: quaternion (qw,qx,qy,qz)
- ACCEL: accelerometer (x,y,z)
- GYRO: gyroscope (x,y,z)
- GRAV: gravity vector (x,y,z)

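This script derives the wand tip position from the orientation quaternion, projects it onto a
2D plane (xz or yz), and smooths the resulting path with a Kalman filter. ACCEL, GYRO and GRAV
packets are parsed but do not affect the path.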
"""

from __future__ import annotations

import math
import queue
import socket
import threading
import time
from collections import deque
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import os
import tkinter as tk
from dotenv import load_dotenv

# Load variables from a .env file (if present) into the environment before reading settings.
load_dotenv()

# ---- Network settings ------------------------------------------------------
LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")  # bind to all interfaces
LISTEN_PORT = int(os.getenv("LISTEN_PORT", "5005"))  # must match udp_port on the ESP32S3

# ---- Visualization settings ------------------------------------------------
CANVAS_SIZE = int(os.getenv("CANVAS_SIZE", "800"))  # canvas width and height
REFRESH_HZ = float(os.getenv("REFRESH_HZ", "60.0"))  # viewer redraw rate (frames per second)
WAND_LENGTH = float(os.getenv("WAND_LENGTH", "1.0"))  # Length of wand in arbitrary units
MAX_PATH_POINTS = int(os.getenv("MAX_PATH_POINTS", "2000"))  # Maximum path points to keep

_VALID_PROJECTION_PLANES = ("xz", "yz")
PROJECTION_PLANE = os.getenv("PROJECTION_PLANE", "yz").strip().lower()  # "xz" or "yz"
if PROJECTION_PLANE not in _VALID_PROJECTION_PLANES:
    # The tip projection treats any unknown plane as xz; normalise the setting itself so
    # the on-screen title and axis labels describe what is actually being drawn.
    print(f"Unknown PROJECTION_PLANE {PROJECTION_PLANE!r}; falling back to 'xz'")
    PROJECTION_PLANE = "xz"

# ---- Kalman filter settings -------------------------------------------------
# Passed to KalmanFilter2D as process_noise / measurement_noise.
KALMAN_PROCESS_NOISE = float(os.getenv("KALMAN_PROCESS_NOISE", "0.01"))
KALMAN_MEASUREMENT_NOISE = float(os.getenv("KALMAN_MEASUREMENT_NOISE", "0.1"))


@dataclass
class Quaternion:
    """Quaternion with scalar part ``w`` and vector part ``(x, y, z)``."""

    w: float
    x: float
    y: float
    z: float

    def normalized(self) -> "Quaternion":
        """Return a unit-length copy; the all-zero quaternion maps to the identity."""
        norm = (self.w**2 + self.x**2 + self.y**2 + self.z**2) ** 0.5
        if norm != 0:
            return Quaternion(self.w / norm, self.x / norm, self.y / norm, self.z / norm)
        return Quaternion(1.0, 0.0, 0.0, 0.0)

    def to_rotation_matrix(self) -> Tuple[Tuple[float, float, float], ...]:
        """Return the 3x3 rotation matrix as three row tuples.

        The standard unit-quaternion formula is used, so the quaternion should
        already be normalised for the result to be a pure rotation.
        """
        w, x, y, z = self.w, self.x, self.y, self.z
        xx, yy, zz = x * x, y * y, z * z
        xy, xz, yz = x * y, x * z, y * z
        wx, wy, wz = x * w, y * w, z * w

        top = (1 - 2 * (yy + zz), 2 * (xy - wz), 2 * (xz + wy))
        middle = (2 * (xy + wz), 1 - 2 * (xx + zz), 2 * (yz - wx))
        bottom = (2 * (xz - wy), 2 * (yz + wx), 1 - 2 * (xx + yy))
        return top, middle, bottom

    def rotate_vector(self, vector: Tuple[float, float, float]) -> Tuple[float, float, float]:
        """Apply this quaternion's rotation to a 3D vector and return the result."""
        vx, vy, vz = vector
        r0, r1, r2 = self.to_rotation_matrix()
        return (
            r0[0] * vx + r0[1] * vy + r0[2] * vz,
            r1[0] * vx + r1[1] * vy + r1[2] * vz,
            r2[0] * vx + r2[1] * vy + r2[2] * vz,
        )


def slerp(q1: Quaternion, q2: Quaternion, t: float) -> Quaternion:
    """Interpolate along the shorter great-circle arc between two quaternions.

    Args:
        q1: Orientation at ``t == 0``.
        q2: Orientation at ``t == 1`` (sign-flipped internally when that gives the
            shorter arc, since ``q`` and ``-q`` are the same rotation).
        t: Interpolation factor in [0, 1].

    Returns:
        The interpolated quaternion, normalised to unit length. When the inputs are
        nearly parallel a normalised linear blend is used instead, because
        ``sin(angle)`` approaches zero there.
    """
    start = (q1.w, q1.x, q1.y, q1.z)
    end = (q2.w, q2.x, q2.y, q2.z)
    cos_angle = q1.w * q2.w + q1.x * q2.x + q1.y * q2.y + q1.z * q2.z

    if cos_angle < 0.0:
        end = tuple(-c for c in end)
        cos_angle = -cos_angle

    if cos_angle > 0.9995:
        blended = [a + t * (b - a) for a, b in zip(start, end)]
    else:
        angle = math.acos(cos_angle)
        sin_angle = math.sin(angle)
        partial = angle * t
        sin_partial = math.sin(partial)
        weight_start = math.cos(partial) - cos_angle * sin_partial / sin_angle
        weight_end = sin_partial / sin_angle
        blended = [(weight_start * a) + (weight_end * b) for a, b in zip(start, end)]

    return Quaternion(*blended).normalized()


@dataclass
class IMUData:
    """Holder for IMU readings; each optional field stays ``None`` until a packet supplies it."""

    # Orientation (from QUAT packets).
    quaternion: Quaternion | None = None
    # Raw integer x/y/z readings as sent by the firmware (ACCEL / GYRO / GRAV packets).
    accelerometer: Tuple[int, int, int] | None = None
    gyroscope: Tuple[int, int, int] | None = None
    gravity: Tuple[int, int, int] | None = None
    # Packet arrival time in seconds (the UDP listener stores time.time()); 0.0 means unset.
    timestamp: float = 0.0


class KalmanFilter2D:
    """Constant-velocity Kalman filter for a 2D position.

    The state is ``[x, vx, y, vy]``; only ``x`` and ``y`` are measured.

    ``process_noise`` and ``measurement_noise`` scale the process (Q) and
    measurement (R) covariances. With the default arguments the filter uses
    ``Q = diag(1e-4, 1e-3, 1e-4, 1e-3)`` and ``R = diag(1e-2, 1e-2)``. Raising
    ``process_noise`` makes the filter follow motion more eagerly. Raising
    ``measurement_noise`` makes it smooth more.
    """

    # Per-state weights multiplied by process_noise (velocity variance is 10x position).
    _Q_WEIGHTS = (0.01, 0.1, 0.01, 0.1)
    # Weight multiplied by measurement_noise for both observed positions.
    _R_WEIGHT = 0.1

    def __init__(self, process_noise: float = 0.01, measurement_noise: float = 0.1) -> None:
        """Create a filter at the origin with unit initial covariance.

        Raises:
            ValueError: if either noise parameter is negative.
        """
        if process_noise < 0 or measurement_noise < 0:
            raise ValueError("Kalman noise parameters must be non-negative")

        # State: [x, vx, y, vy] (position and velocity for two axes)
        self.state = np.zeros(4)

        # Unit initial uncertainty so the first measurements are trusted quickly.
        self.P = np.eye(4)

        # Process and measurement noise covariances, driven by the constructor arguments.
        self.Q = np.diag([process_noise * w for w in self._Q_WEIGHTS])
        self.R = np.eye(2) * (measurement_noise * self._R_WEIGHT)

        # State transition matrix; rebuilt with the actual dt in update().
        self.F = np.eye(4)

        # Measurement matrix (we only observe position)
        self.H = np.array([
            [1, 0, 0, 0],
            [0, 0, 1, 0],
        ])

    def update(self, measurement: Tuple[float, float], dt: float) -> Tuple[float, float]:
        """Run one predict/correct cycle and return the filtered ``(x, y)`` position.

        Args:
            measurement: Observed ``(x, y)`` position.
            dt: Seconds since the previous update. It is clamped to [0.001, 0.1],
                so bursts of packets or gaps such as packet loss cannot destabilise
                the prediction step.

        Returns:
            The corrected position estimate ``(x, y)``.
        """
        dt = max(0.001, min(0.1, dt))

        # Constant-velocity model: each position advances by its velocity * dt.
        self.F = np.array([
            [1.0, dt, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, dt],
            [0.0, 0.0, 0.0, 1.0],
        ])

        # Predict step
        self.state = self.F @ self.state
        self.P = self.F @ self.P @ self.F.T + self.Q

        # Update step
        z = np.array([measurement[0], measurement[1]], dtype=float)
        innovation = z - self.H @ self.state
        S = self.H @ self.P @ self.H.T + self.R  # innovation covariance
        # Kalman gain K = P H^T S^-1, via a linear solve rather than an explicit inverse.
        PHt = self.P @ self.H.T
        K = np.linalg.solve(S.T, PHt.T).T

        self.state = self.state + K @ innovation
        self.P = (np.eye(4) - K @ self.H) @ self.P

        return float(self.state[0]), float(self.state[2])


class WandPathCanvas:
    """Tkinter view that draws the 2D path traced by the wand tip.

    On every frame (``REFRESH_HZ`` per second) the view does the following:
    1. Drain the packets queued by the UDP listener.
    2. Smooth the orientation with SLERP.
    3. Project the wand tip onto ``PROJECTION_PLANE`` and filter it with
       :class:`KalmanFilter2D`.
    4. Redraw the raw projection as a thin gray line and the filtered path as
       a cyan line.
    """

    def __init__(self, root: tk.Tk, data_queue: "queue.Queue[IMUData]") -> None:
        self.root = root
        self.canvas = tk.Canvas(root, width=CANVAS_SIZE, height=CANVAS_SIZE, bg="black")
        self.canvas.pack()
        self.data_queue = data_queue

        # Latest reading of every sensor, merged from individual packets.
        self.current_data = IMUData()
        self.current_data.quaternion = Quaternion(1.0, 0.0, 0.0, 0.0)

        # SLERP-smoothed orientation: each new quaternion moves it slerp_alpha
        # (0.15 = 15% new, 85% old) of the way towards the new reading.
        self.smoothed_quat = Quaternion(1.0, 0.0, 0.0, 0.0)
        self.slerp_alpha = 0.15
        # Until a real quaternion arrives smoothed_quat is only a placeholder. The
        # first reading replaces it outright so the path does not sweep in from identity.
        self._have_orientation = False

        # Wand base position (assumed at origin)
        self.wand_base = np.array([0.0, 0.0, 0.0])
        # Initial wand direction (assumes wand points in +Z direction when quaternion is identity)
        self.wand_direction = np.array([0.0, 0.0, 1.0])

        # Kalman filter for path smoothing
        self.kalman = self._make_kalman()

        # Path history in plane coordinates: raw projection and Kalman output.
        self.path_points: deque[Tuple[float, float]] = deque(maxlen=MAX_PATH_POINTS)
        self.raw_path_points: deque[Tuple[float, float]] = deque(maxlen=MAX_PATH_POINTS)

        # Plane -> canvas transform; path_scale is recomputed by auto-scaling.
        self.path_scale = 100.0
        self.path_offset_x = CANVAS_SIZE / 2
        self.path_offset_y = CANVAS_SIZE / 2

        # Largest distance from the origin seen so far (drives auto-scaling).
        self.running_max_distance = 0.1

        # Arrival time of the quaternion used by the previous filter update.
        self.last_packet_time: Optional[float] = None

        # UI elements
        self.text_id = self.canvas.create_text(10, 10, anchor="nw", text="", font=("Helvetica", 9), fill="white")

        clear_btn = tk.Button(root, text="Clear Path", command=self.clear_path)
        clear_btn.pack(pady=5)

        self._schedule_next_frame()

    @staticmethod
    def _make_kalman() -> KalmanFilter2D:
        """Create a path filter configured from the KALMAN_* settings."""
        return KalmanFilter2D(
            process_noise=KALMAN_PROCESS_NOISE,
            measurement_noise=KALMAN_MEASUREMENT_NOISE,
        )

    def clear_path(self) -> None:
        """Forget the drawn path and restart the filter, auto-scaling and dt tracking.

        The smoothed orientation is deliberately kept. Resetting it to identity
        would make the next points sweep in from the identity pose.
        """
        self.path_points.clear()
        self.raw_path_points.clear()
        self.kalman = self._make_kalman()
        self.running_max_distance = 0.1
        self.last_packet_time = None
        self._draw_path()

    def _schedule_next_frame(self) -> None:
        delay_ms = int(1000.0 / REFRESH_HZ)
        self.root.after(delay_ms, self._update_frame)

    def _update_frame(self) -> None:
        """Process queued packets, extend the path if a new orientation arrived, redraw."""
        try:
            got_quat, quat_time = self._drain_queue()
            # Only a fresh orientation produces a new path point. Otherwise idle frames
            # would pad the history with copies of the last position.
            if got_quat:
                self._update_path(self._next_dt(quat_time))
            self._draw_path()
        finally:
            # Reschedule even if this frame raised, so one bad frame can't freeze the viewer.
            self._schedule_next_frame()

    def _drain_queue(self) -> Tuple[bool, Optional[float]]:
        """Merge every queued packet into the current state.

        Returns:
            ``(got_quaternion, newest_quaternion_arrival_time)``. The time is ``None``
            if no quaternion packet carried a timestamp.
        """
        got_quat = False
        quat_time: Optional[float] = None
        while True:
            try:
                packet = self.data_queue.get_nowait()
            except queue.Empty:
                return got_quat, quat_time

            if packet.quaternion is not None:
                got_quat = True
                self.current_data.quaternion = packet.quaternion
                if self._have_orientation:
                    self.smoothed_quat = slerp(self.smoothed_quat, packet.quaternion, self.slerp_alpha)
                else:
                    self.smoothed_quat = packet.quaternion
                    self._have_orientation = True
                # Only quaternion packets are timed. Other packet types can arrive
                # between them, and their timestamps would make dt meaningless.
                if packet.timestamp > 0:
                    quat_time = packet.timestamp
            if packet.accelerometer is not None:
                self.current_data.accelerometer = packet.accelerometer
            if packet.gyroscope is not None:
                self.current_data.gyroscope = packet.gyroscope
            if packet.gravity is not None:
                self.current_data.gravity = packet.gravity

    def _next_dt(self, quat_time: Optional[float]) -> float:
        """Return seconds between this update's quaternion and the previous update's.

        Falls back to one frame interval when either arrival time is unknown. That
        happens on the first update after start or clear, or when a packet has no
        timestamp.
        """
        previous = self.last_packet_time
        if quat_time is not None:
            self.last_packet_time = quat_time
        if quat_time is None or previous is None:
            return 1.0 / REFRESH_HZ
        return quat_time - previous

    def _update_path(self, dt: float) -> None:
        """Project the wand tip for the smoothed orientation and extend both paths.

        Args:
            dt: Time delta since the previous filter update (in seconds).
        """
        direction = np.array(self.smoothed_quat.rotate_vector(tuple(self.wand_direction)))
        tip_3d = self.wand_base + direction * WAND_LENGTH

        # Horizontal axis is y for the yz plane; any other setting projects onto xz.
        horizontal = 1 if PROJECTION_PLANE == "yz" else 0
        raw_2d = (float(tip_3d[horizontal]), float(tip_3d[2]))

        # Auto-scale so the farthest point seen so far lies 35% of the canvas size
        # from the centre (the path spans about 70% of the canvas).
        self.running_max_distance = max(self.running_max_distance, math.hypot(*raw_2d))
        if self.running_max_distance > 1e-6:
            self.path_scale = CANVAS_SIZE * 0.35 / self.running_max_distance

        filtered_2d = self.kalman.update(raw_2d, dt)

        self.raw_path_points.append(raw_2d)
        self.path_points.append(filtered_2d)

    def _to_canvas(self, point: Tuple[float, float]) -> Tuple[float, float]:
        """Map a plane point to canvas pixels; y is flipped because canvas y grows downward."""
        return (
            self.path_offset_x + point[0] * self.path_scale,
            self.path_offset_y - point[1] * self.path_scale,
        )

    def _draw_path(self) -> None:
        """Redraw the raw path, the filtered path, the current-position marker and the info text."""
        self.canvas.delete("path")

        # Raw path (thin, dim)
        if len(self.raw_path_points) > 1:
            raw_coords = [c for point in self.raw_path_points for c in self._to_canvas(point)]
            self.canvas.create_line(
                *raw_coords, fill="gray", width=1, tags="path", smooth=True
            )

        # Filtered path (thick, bright) with dots fading from half intensity (oldest)
        # to white (newest), so the most recent motion stands out.
        n_points = len(self.path_points)
        if n_points > 1:
            filtered_coords = []
            for i, point in enumerate(self.path_points):
                x, y = self._to_canvas(point)
                filtered_coords.extend((x, y))
                if i == 0:
                    continue
                intensity = int(255 * (0.5 + 0.5 * i / n_points))
                color = f"#{intensity:02x}{intensity:02x}ff"
                self.canvas.create_oval(
                    x - 2, y - 2, x + 2, y + 2,
                    fill=color, outline=color, tags="path"
                )

            self.canvas.create_line(
                *filtered_coords, fill="cyan", width=3, tags="path", smooth=True
            )

        text_lines = [
            f"Wand Path Tracker ({PROJECTION_PLANE.upper()} plane)",
            "",
            "Quaternion (smoothed):",
            f"  w: {self.smoothed_quat.w:+.4f}",
            f"  x: {self.smoothed_quat.x:+.4f}",
            f"  y: {self.smoothed_quat.y:+.4f}",
            f"  z: {self.smoothed_quat.z:+.4f}",
            "",
            f"Path Points: {n_points}",
            f"Wand Length: {WAND_LENGTH:.2f}",
            f"Scale: {self.path_scale:.1f}",
            "",
            "Controls:",
            "  - Clear button: Reset path",
        ]

        if self.path_points:
            current = self.path_points[-1]
            x, y = self._to_canvas(current)
            self.canvas.create_oval(
                x - 6, y - 6, x + 6, y + 6,
                fill="yellow", outline="red", width=2, tags="path"
            )

            # Label must match the axis _update_path actually projected onto.
            horizontal_label = "Y" if PROJECTION_PLANE == "yz" else "X"
            text_lines += [
                "",
                "Current Position:",
                f"  {horizontal_label}: {current[0]:+.3f}",
                f"  Z: {current[1]:+.3f}",
            ]

        self.canvas.itemconfig(self.text_id, text="\n".join(text_lines))


def udp_listener(data_queue: "queue.Queue[IMUData]", stop_event: threading.Event) -> None:
    """Receive IMU packets over UDP and queue them as :class:`IMUData`.

    Each datagram is expected to be UTF-8 text of the form ``TYPE,v1,v2,...``:
    either ``QUAT,w,x,y,z`` (floats) or ``ACCEL|GYRO|GRAV,x,y,z`` (integers).
    Every accepted packet is stamped with its arrival time (``time.time()``),
    logged and put on *data_queue*. Unknown or malformed packets are skipped.

    Runs until *stop_event* is set (re-checked at least every 0.5 s thanks to the
    socket timeout) or the socket raises an ``OSError``.
    """
    # Vector packet type -> (IMUData field to fill, log label padded to align columns).
    vector_packets = {
        "ACCEL": ("accelerometer", "ACCEL:"),
        "GYRO": ("gyroscope", "GYRO: "),
        "GRAV": ("gravity", "GRAV: "),
    }
    packet_count = 0

    # The context manager closes the socket even if bind() or the loop raises.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind((LISTEN_HOST, LISTEN_PORT))
        sock.settimeout(0.5)
        print(f"Listening for UDP packets on {LISTEN_HOST}:{LISTEN_PORT}")
        print("Waiting for IMU data...\n")

        while not stop_event.is_set():
            try:
                data, _addr = sock.recvfrom(1024)
            except socket.timeout:
                continue  # periodic wake-up to re-check stop_event
            except OSError:
                break

            # Decode separately: UnicodeDecodeError is a ValueError, and catching it in
            # the parse handler below would reference an unbound or stale `decoded`.
            try:
                decoded = data.decode("utf-8").strip()
            except UnicodeDecodeError as e:
                print(f"Error decoding packet: {data[:50]!r}... ({e})")
                continue

            parts = decoded.split(",")
            if len(parts) < 2:
                continue

            packet_type, values = parts[0], parts[1:]
            imu_data = IMUData()
            imu_data.timestamp = time.time()  # Set timestamp when packet arrives

            try:
                if packet_type == "QUAT" and len(values) == 4:
                    qw, qx, qy, qz = (float(v) for v in values)
                    # float() accepts "nan"/"inf". A single such value would propagate
                    # through normalisation and smoothing into the filter state for good.
                    if not all(math.isfinite(c) for c in (qw, qx, qy, qz)):
                        raise ValueError("non-finite quaternion component")
                    imu_data.quaternion = Quaternion(qw, qx, qy, qz).normalized()
                    packet_count += 1
                    print(f"[{packet_count}] QUAT: w={qw:+.6f}, x={qx:+.6f}, y={qy:+.6f}, z={qz:+.6f}")
                elif packet_type in vector_packets and len(values) == 3:
                    field_name, label = vector_packets[packet_type]
                    vx, vy, vz = (int(v) for v in values)
                    setattr(imu_data, field_name, (vx, vy, vz))
                    packet_count += 1
                    print(f"[{packet_count}] {label} x={vx:6d}, y={vy:6d}, z={vz:6d}")
                else:
                    continue  # unknown packet type or wrong number of fields
            except ValueError as e:
                print(f"Error parsing packet: {decoded[:50]}... ({e})")
                continue

            data_queue.put(imu_data)

    print("\nUDP listener stopped.")


def main() -> None:
    """Run the tracker: UDP listener on a daemon thread, Tk viewer on the main thread."""
    packets: "queue.Queue[IMUData]" = queue.Queue()
    shutdown = threading.Event()

    receiver = threading.Thread(target=udp_listener, args=(packets, shutdown), daemon=True)
    receiver.start()

    window = tk.Tk()
    window.title("ESP32S3 Wand Path Tracker")
    # No reference is kept: the view's scheduled after() callback and button command hold it.
    WandPathCanvas(window, packets)

    try:
        window.mainloop()
    finally:
        # Signal the listener to stop, wait briefly for it, then pause a moment before returning.
        shutdown.set()
        receiver.join(timeout=1.0)
        time.sleep(0.05)


if __name__ == "__main__":
    main()
