#!/usr/bin/env python3
"""
Hogwarts-themed battle game with IMU wand path tracking and CLIP spell classification.

Controls:
    - Menu: Arrow Up/Down (or W/S) to choose, Enter to confirm
    - In battle: draw a spell with the IMU wand (5 seconds per turn)
    - ESC during a battle returns to the main menu
    - After a win/loss, press Enter to return to the menu

Assets used:
    background: assets/background/background.jpg
    hero sheet: assets/characters/harry.jpg
    boss:      assets/characters/neil-front.jpg
    spells:    assets/spells/lumos.png, expelliarmus.png, engorgio.png, protego.png
    music:
        - Menu:  assets/music/hp-theme.mp3
        - Battle: assets/music/battle-music.mp3
"""

from __future__ import annotations

import math
import os
import queue
import random
import socket
import threading
import time
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pygame
from dotenv import load_dotenv

# Load environment variables (e.g. LISTEN_HOST, LISTEN_PORT, WAND_LENGTH,
# PROJECTION_PLANE read below) from a .env file before the settings are read.
load_dotenv()

# --------------------------------------------------------------------------- #
# Configuration
# --------------------------------------------------------------------------- #
BASE_DIR = Path(__file__).resolve().parent
ASSET_DIR = BASE_DIR / "assets"

SCREEN_WIDTH = 960
SCREEN_HEIGHT = 540
FPS = 60

# IMU settings (overridable through the environment)
LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
LISTEN_PORT = int(os.getenv("LISTEN_PORT", "5005"))
WAND_LENGTH = float(os.getenv("WAND_LENGTH", "1.0"))
# Plane the 3D wand tip is projected onto to obtain the 2D drawing path.
_SUPPORTED_PROJECTION_PLANES = ("yz", "xz")
PROJECTION_PLANE = os.getenv("PROJECTION_PLANE", "yz").strip().lower()
if PROJECTION_PLANE not in _SUPPORTED_PROJECTION_PLANES:
    # Fall back to the default plane, but say so, so a typo in the
    # environment does not silently change how the path is drawn.
    print(f"⚠️  Unsupported PROJECTION_PLANE {PROJECTION_PLANE!r}; using 'yz' instead.")
    PROJECTION_PLANE = "yz"

# Turn settings
TURN_DURATION_SECONDS = 5.0

# Rectangles (x, y, width, height) used to crop poses out of harry.jpg
HERO_IDLE_RECT = pygame.Rect(120, 220, 540, 1150)
HERO_CAST_RECT = pygame.Rect(1180, 420, 520, 900)

ATTACK_POSE_DURATION_MS = 450
MESSAGE_DURATION_MS = 2000

MAX_HERO_HP = 180
MAX_BOSS_HP = 220

# Path canvas and spell panel settings
PANEL_HEIGHT = int(SCREEN_HEIGHT * 0.30)  # 30% of screen height
PANEL_WIDTH = SCREEN_WIDTH  # Full width panel
CANVAS_WIDTH = SCREEN_WIDTH // 3  # 1/3 of panel width
SPELL_GRID_WIDTH = (SCREEN_WIDTH * 2) // 3  # 2/3 of panel width
PANEL_Y = SCREEN_HEIGHT - PANEL_HEIGHT  # Position at bottom
CANVAS_X = 0  # Canvas at left of panel
SPELL_GRID_X = CANVAS_WIDTH  # Spell grid at right of canvas
PATH_CANVAS_BG_COLOR = (216, 188, 144)  # #D8BC90
PATH_COLOR = (0, 0, 0)  # Black
SPELL_PANEL_BG_COLOR = (15, 30, 60)
SPELL_PANEL_BORDER_COLOR = (70, 110, 160)


# --------------------------------------------------------------------------- #
# IMU Data Structures
# --------------------------------------------------------------------------- #
@dataclass
class Quaternion:
    """Rotation quaternion with scalar part ``w`` and vector part ``(x, y, z)``."""

    w: float
    x: float
    y: float
    z: float

    def normalized(self) -> "Quaternion":
        """Return a unit-length copy of this quaternion.

        A zero quaternion has no meaningful direction, so the identity
        rotation ``(1, 0, 0, 0)`` is returned instead of dividing by zero.
        """
        norm = (self.w**2 + self.x**2 + self.y**2 + self.z**2) ** 0.5
        if norm != 0:
            return Quaternion(self.w / norm, self.x / norm, self.y / norm, self.z / norm)
        return Quaternion(1.0, 0.0, 0.0, 0.0)

    def rotate_vector(self, vector: Tuple[float, float, float]) -> Tuple[float, float, float]:
        """Rotate a 3D vector by this quaternion, i.e. compute ``q * v * conj(q)``.

        The conjugate stands in for the inverse, so the quaternion is expected
        to be unit length (see ``normalized``).
        """
        vx, vy, vz = vector
        w, qx, qy, qz = self.w, self.x, self.y, self.z

        # t = q * (0, v): product with the pure-vector quaternion of v.
        tw = -qx * vx - qy * vy - qz * vz
        tx = w * vx + qy * vz - qz * vy
        ty = w * vy + qz * vx - qx * vz
        tz = w * vz + qx * vy - qy * vx

        # t * conj(q); only the vector part of the product is needed.
        return (
            tx * w + tw * -qx + ty * -qz - tz * -qy,
            ty * w + tw * -qy + tz * -qx - tx * -qz,
            tz * w + tw * -qz + tx * -qy - ty * -qx,
        )


def slerp(q1: Quaternion, q2: Quaternion, t: float) -> Quaternion:
    """Spherically interpolate from ``q1`` (t=0) toward ``q2`` (t=1).

    The result is always re-normalized. When the two quaternions are almost
    parallel, a normalized linear interpolation is used instead, because the
    spherical weights would divide by a sine close to zero.
    """
    parallel_threshold = 0.9995
    cos_omega = q1.w * q2.w + q1.x * q2.x + q1.y * q2.y + q1.z * q2.z

    # q and -q encode the same rotation; flip q2 so the shorter arc is taken.
    if cos_omega < 0.0:
        q2 = Quaternion(-q2.w, -q2.x, -q2.y, -q2.z)
        cos_omega = -cos_omega

    if cos_omega > parallel_threshold:
        return Quaternion(
            q1.w + t * (q2.w - q1.w),
            q1.x + t * (q2.x - q1.x),
            q1.y + t * (q2.y - q1.y),
            q1.z + t * (q2.z - q1.z),
        ).normalized()

    omega = math.acos(cos_omega)
    sin_omega = math.sin(omega)
    partial = omega * t
    sin_partial = math.sin(partial)
    weight_far = sin_partial / sin_omega
    weight_near = math.cos(partial) - cos_omega * sin_partial / sin_omega

    return Quaternion(
        (weight_near * q1.w) + (weight_far * q2.w),
        (weight_near * q1.x) + (weight_far * q2.x),
        (weight_near * q1.y) + (weight_far * q2.y),
        (weight_near * q1.z) + (weight_far * q2.z),
    ).normalized()


@dataclass
class IMUData:
    """A single decoded IMU sample; fields the packet did not carry stay ``None``."""

    # Orientation from a "QUAT" packet (the UDP listener normalizes it).
    quaternion: Optional[Quaternion] = None
    # Raw sensor triples; not filled in by the visible packet parser.
    accelerometer: Optional[Tuple[int, int, int]] = None
    gyroscope: Optional[Tuple[int, int, int]] = None
    gravity: Optional[Tuple[int, int, int]] = None
    # Receive time from time.time(), set by the UDP listener; 0.0 means unset.
    timestamp: float = 0.0


class KalmanFilter2D:
    """Constant-velocity Kalman filter for a 2D point.

    The state vector is ``[x, vx, y, vy]``; only the positions ``x`` and ``y``
    are measured.
    """

    def __init__(self) -> None:
        self.state = np.zeros(4)
        self.P = np.eye(4)  # state covariance
        self.Q = np.diag([1e-4, 1e-3, 1e-4, 1e-3])  # process noise
        self.R = np.diag([1e-2, 1e-2])  # measurement noise
        self.F = np.eye(4)  # transition matrix, rebuilt for each dt
        self.H = np.array([[1, 0, 0, 0], [0, 0, 1, 0]])  # picks x and y

    def update(self, measurement: Tuple[float, float], dt: float) -> Tuple[float, float]:
        """Run one predict/correct step and return the filtered ``(x, y)``.

        ``dt`` is clamped to [0.001, 0.1] seconds to keep the prediction sane
        after stalls or duplicate timestamps.
        """
        step = max(0.001, min(0.1, dt))
        transition = np.eye(4)
        transition[0, 1] = step
        transition[2, 3] = step
        self.F = transition

        # Predict.
        predicted = transition @ self.state
        predicted_cov = transition @ self.P @ transition.T + self.Q

        # Correct with the observed position.
        observed = np.array([measurement[0], measurement[1]])
        residual = observed - self.H @ predicted
        residual_cov = self.H @ predicted_cov @ self.H.T + self.R
        gain = predicted_cov @ self.H.T @ np.linalg.inv(residual_cov)

        self.state = predicted + gain @ residual
        self.P = (np.eye(4) - gain @ self.H) @ predicted_cov
        return (float(self.state[0]), float(self.state[2]))

    def reset(self) -> None:
        """Return to the initial state (origin, zero velocity, unit covariance)."""
        self.state = np.zeros(4)
        self.P = np.eye(4)


# --------------------------------------------------------------------------- #
# Spell Definitions
# --------------------------------------------------------------------------- #
@dataclass
class Spell:
    """Static description of a castable spell."""

    # Display name; also used as the key for the loaded spell images.
    name: str
    # Icon file shown in the spell grid and the cast panel.
    image_path: Path
    # Inclusive (min, max) damage rolled when the spell is cast.
    damage_range: Tuple[int, int]
    # Short flavour text.
    description: str


# The order matters: it is the index order of the CLIP text prompts and of
# the 2x2 spell grid (row-major).
SPELLS: List[Spell] = [
    Spell(spell_name, ASSET_DIR / "spells" / icon_file, damage, blurb)
    for spell_name, icon_file, damage, blurb in (
        ("Lumos", "lumos.png", (15, 25), "Light charm"),
        ("Expelliarmus", "expelliarmus.png", (22, 34), "Disarming charm"),
        ("Engorgio", "engorgio.png", (18, 28), "Enlargement charm"),
        ("Protego", "protego.png", (12, 24), "Shield charm"),
    )
]


# --------------------------------------------------------------------------- #
# CLIP Classification
# --------------------------------------------------------------------------- #
def classify_path_with_clip(path_image: pygame.Surface) -> Optional[Spell]:
    """Classify a drawn wand path as one of ``SPELLS`` with zero-shot CLIP.

    The CLIP model and processor ("openai/clip-vit-base-patch32") are loaded
    on the first call and cached as attributes of this function, so later
    calls reuse them. Loading runs synchronously on the calling thread.

    Args:
        path_image: Surface holding the drawn path.

    Returns:
        The spell whose text prompt scores highest, if its softmax
        probability exceeds 0.25. Otherwise ``None``, so the caller can fall
        back to a random spell. ``None`` is also returned when torch,
        transformers or Pillow are missing, or when anything else fails:
        classification is best-effort and never raises.
    """
    try:
        import torch
        from PIL import Image
        from transformers import CLIPModel, CLIPProcessor

        # Load CLIP model (lazy loading)
        if not hasattr(classify_path_with_clip, "model"):
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Loading CLIP model on {device}...")
            model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Publish the cache only once everything loaded, and set `model`
            # (the attribute checked above) last. A failed processor load is
            # then retried next call instead of leaving a half-built cache.
            classify_path_with_clip.processor = processor
            classify_path_with_clip.device = device
            classify_path_with_clip.model = model

        # tostring() returns raw, unencoded RGB pixels. Image.open() only
        # understands encoded image files (PNG, JPEG, ...), so the image is
        # rebuilt directly from the pixel buffer and the surface size.
        raw_rgb = pygame.image.tostring(path_image, "RGB")
        img = Image.frombytes("RGB", path_image.get_size(), raw_rgb)
        img = img.resize((224, 224))  # CLIP input size

        # One text prompt per spell, in SPELLS order.
        spell_texts = [f"a drawing of the spell {spell.name.lower()}" for spell in SPELLS]

        inputs = classify_path_with_clip.processor(
            text=spell_texts,
            images=img,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(classify_path_with_clip.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = classify_path_with_clip.model(**inputs)
            # A single image was passed, so take its row: one score per prompt.
            probs = outputs.logits_per_image.softmax(dim=1)[0]

        best_idx = int(probs.argmax().item())
        confidence = float(probs[best_idx].item())

        print(f"CLIP classification: {SPELLS[best_idx].name} (confidence: {confidence:.3f})")

        # Only accept reasonably confident matches; None lets the caller randomize.
        if confidence > 0.25:
            return SPELLS[best_idx]
        return None

    except ImportError:
        print("⚠️  CLIP not available. Install with: pip install transformers torch pillow")
        return None
    except Exception as e:
        # Best-effort boundary: report and fall back rather than crash the game.
        print(f"⚠️  CLIP classification error: {e}")
        return None


# --------------------------------------------------------------------------- #
# UDP Listener
# --------------------------------------------------------------------------- #
def _parse_imu_packet(payload: bytes) -> Optional[IMUData]:
    """Decode one UDP datagram of the form ``QUAT,w,x,y,z`` into an IMUData.

    Returns ``None`` in three cases:

    * the packet is any other type, or is malformed;
    * the payload is not valid UTF-8;
    * a component is not finite (NaN or inf).

    Non-finite components are rejected because they would turn the smoothed
    orientation, and everything derived from it, into NaN permanently.
    """
    try:
        parts = payload.decode("utf-8").strip().split(",")
    except UnicodeDecodeError:
        return None

    if len(parts) != 5 or parts[0] != "QUAT":
        return None

    try:
        components = [float(p) for p in parts[1:5]]
    except ValueError:
        return None
    if not all(math.isfinite(c) for c in components):
        return None

    qw, qx, qy, qz = components
    return IMUData(
        quaternion=Quaternion(qw, qx, qy, qz).normalized(),
        timestamp=time.time(),
    )


def udp_listener(data_queue: queue.Queue[IMUData], stop_event: threading.Event) -> None:
    """Receive wand packets (ESP32S3) over UDP and push decoded samples onto a queue.

    Runs until ``stop_event`` is set. The socket uses a 0.5 s timeout, so the
    event is checked at least that often. If the port cannot be bound, the
    error is printed, ``stop_event`` is set and the function returns. The
    socket is closed on every exit path.

    Args:
        data_queue: Receives one ``IMUData`` per valid ``QUAT`` packet.
        stop_event: Set by the owner to stop the loop.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        try:
            sock.bind((LISTEN_HOST, LISTEN_PORT))
        except OSError as e:
            print(f"⚠️  Could not bind to {LISTEN_HOST}:{LISTEN_PORT}: {e}")
            stop_event.set()
            return
        sock.settimeout(0.5)
        print(f"Listening for UDP packets on {LISTEN_HOST}:{LISTEN_PORT}")

        while not stop_event.is_set():
            try:
                payload, _addr = sock.recvfrom(1024)
            except socket.timeout:
                # No packet within the timeout; loop to re-check stop_event.
                continue
            except OSError:
                break

            sample = _parse_imu_packet(payload)
            if sample is not None:
                data_queue.put(sample)

    print("UDP listener stopped.")


# --------------------------------------------------------------------------- #
# Battle Game
# --------------------------------------------------------------------------- #
class BattleGame:
    """Pygame front end for the wand battle: menu, turn-based battle and result screen.

    A daemon thread running ``udp_listener`` pushes IMU samples into
    ``imu_data_queue``. The game loop drains that queue every frame. During a
    player's turn it turns the smoothed wand orientation into a 2D path, and
    when the turn timer expires it asks ``classify_path_with_clip`` which spell
    the path resembles.
    """

    # Main-menu entries, top to bottom (index 0 starts a battle).
    MENU_OPTIONS: Tuple[str, ...] = ("Start Game", "Quit")
    # Upper bound (seconds) on the per-frame dt fed to the game timers. The
    # CLIP call blocks the loop (for seconds on first use, while the model
    # loads). Without the cap, the next frame's dt would expire the
    # attack-pose and message timers at once.
    MAX_FRAME_DT = 0.1
    # A turn is only resolved once the path holds at least this many samples.
    MIN_PATH_POINTS = 10
    # Padding (px) between a spell-grid cell and its framed inner box.
    GRID_PADDING = 10
    # Margin (px) between the inner box border and its icon.
    GRID_ICON_MARGIN = 4

    def __init__(self) -> None:
        """Open the window, load assets, start menu music and the UDP listener thread."""
        pygame.init()
        self._init_mixer()
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Hogwarts Battle - IMU Wand Edition")
        self.clock = pygame.time.Clock()
        self.font = pygame.font.SysFont("arial", 20)
        self.big_font = pygame.font.SysFont("arial", 48, bold=True)
        self.timer_font = pygame.font.SysFont("arial", 72, bold=True)

        self.state = "menu"  # one of "menu", "battle", "result"
        self.menu_index = 0
        self.message: str = ""
        self.message_timer: int = 0  # ms until `message` clears; 0 keeps it shown
        self.attack_pose_timer: int = 0  # ms left in the casting pose

        self.hero_hp = MAX_HERO_HP
        self.boss_hp = MAX_BOSS_HP
        self.hero_pose = "idle"

        # IMU tracking
        self.imu_data_queue: queue.Queue[IMUData] = queue.Queue()
        self.stop_event = threading.Event()
        self.current_data = IMUData()
        self.current_data.quaternion = Quaternion(1.0, 0.0, 0.0, 0.0)
        self.smoothed_quat = Quaternion(1.0, 0.0, 0.0, 0.0)
        self.slerp_alpha = 0.15  # per-sample blend toward the newest orientation
        self.wand_base = np.array([0.0, 0.0, 0.0])
        self.wand_direction = np.array([0.0, 0.0, 1.0])  # wand axis before rotation
        self.kalman = KalmanFilter2D()
        self.path_points: deque[Tuple[float, float]] = deque(maxlen=1000)
        self.last_packet_time: Optional[float] = None
        self.running_max_distance = 0.1
        self.path_scale = 50.0

        # Turn-based gameplay
        self.turn_start_time: Optional[float] = None
        self.turn_active = False
        self.current_spell: Optional[Spell] = None
        self.classifying = False

        # Path canvas (1/3 width, 30% height)
        self.path_canvas = pygame.Surface((CANVAS_WIDTH, PANEL_HEIGHT))
        self.path_canvas.fill(PATH_CANVAS_BG_COLOR)

        self._load_assets()
        self._play_music(self.menu_music_path)

        # Start UDP listener
        self.listener_thread = threading.Thread(
            target=udp_listener,
            args=(self.imu_data_queue, self.stop_event),
            daemon=True,
        )
        self.listener_thread.start()

    def _init_mixer(self) -> None:
        """Initialise audio; on failure record ``mixer_ready = False`` and continue silently."""
        try:
            pygame.mixer.init()
            self.mixer_ready = True
        except pygame.error:
            print("⚠️  Audio mixer could not be initialised. Continuing without sound.")
            self.mixer_ready = False

    def _load_assets(self) -> None:
        """Load background, character and spell images, and record the music paths.

        Must run after the display mode is set (it uses ``convert`` and
        ``convert_alpha``). A missing spell icon is reported and skipped; the
        grid and the cast panel then show only text for that spell.
        """
        bg_path = ASSET_DIR / "background" / "background.jpg"
        hero_path = ASSET_DIR / "characters" / "harry.jpg"
        boss_path = ASSET_DIR / "characters" / "neil-front.jpg"

        self.menu_music_path = ASSET_DIR / "music" / "hp-theme.mp3"
        self.battle_music_path = ASSET_DIR / "music" / "battle-music.mp3"

        self.background_img = pygame.transform.smoothscale(
            pygame.image.load(bg_path).convert(),
            (SCREEN_WIDTH, SCREEN_HEIGHT),
        )

        hero_sheet = pygame.image.load(hero_path).convert_alpha()
        self.hero_idle_img = self._extract_sprite(hero_sheet, HERO_IDLE_RECT, (260, 360))
        self.hero_cast_img = self._extract_sprite(hero_sheet, HERO_CAST_RECT, (280, 360))
        self.hero_sprite = self.hero_idle_img

        boss_img = pygame.image.load(boss_path).convert_alpha()
        boss_scale = (260, 360)
        self.boss_sprite = pygame.transform.smoothscale(boss_img, boss_scale)

        # Grid icons are sized to fit inside a grid cell: half the panel
        # height, minus the cell padding and the icon margin on both sides.
        grid_icon_px = max(
            1, PANEL_HEIGHT // 2 - 2 * self.GRID_PADDING - 2 * self.GRID_ICON_MARGIN
        )
        self.spell_images: Dict[str, pygame.Surface] = {}
        self.spell_images_small: Dict[str, pygame.Surface] = {}  # For after-cast display
        for spell in SPELLS:
            if spell.image_path.exists():
                img = pygame.image.load(spell.image_path).convert_alpha()
                self.spell_images[spell.name] = pygame.transform.smoothscale(
                    img, (grid_icon_px, grid_icon_px)
                )
                self.spell_images_small[spell.name] = pygame.transform.smoothscale(img, (80, 80))
            else:
                print(f"⚠️  Spell image not found: {spell.image_path}")

    def _extract_sprite(
        self,
        sheet: pygame.Surface,
        rect: pygame.Rect,
        output_size: Tuple[int, int] | None = None,
    ) -> pygame.Surface:
        """Crop ``rect`` out of ``sheet`` and optionally scale it to ``output_size``.

        The crop is clipped to the sheet bounds, so a sheet smaller than the
        hard-coded rectangles does not crash ``subsurface``. If nothing
        overlaps, the whole sheet is used.
        """
        crop = rect.clip(sheet.get_rect())
        if crop.width == 0 or crop.height == 0:
            crop = sheet.get_rect()
        sub = sheet.subsurface(crop).copy()
        if output_size:
            sub = pygame.transform.smoothscale(sub, output_size)
        return sub

    def _play_music(self, path: Path) -> None:
        """Loop the track at ``path`` at 60% volume; does nothing without a mixer."""
        if not self.mixer_ready:
            return
        try:
            pygame.mixer.music.load(path)
            pygame.mixer.music.set_volume(0.6)
            pygame.mixer.music.play(-1)
        except pygame.error as exc:
            print(f"⚠️  Could not play music {path}: {exc}")

    def reset_battle(self) -> None:
        """Restore full HP and clear all battle state, then start the first turn."""
        self.hero_hp = MAX_HERO_HP
        self.boss_hp = MAX_BOSS_HP
        self.hero_pose = "idle"
        self.hero_sprite = self.hero_idle_img
        self.message = ""
        self.message_timer = 0
        self.attack_pose_timer = 0
        self.path_points.clear()
        self.kalman.reset()
        self.running_max_distance = 0.1
        self.last_packet_time = None
        # The queue is drained in every state, so the latest raw orientation is
        # current. Seed the smoother with it, so the first path does not sweep
        # in from the identity orientation.
        latest = self.current_data.quaternion
        self.smoothed_quat = latest if latest is not None else Quaternion(1.0, 0.0, 0.0, 0.0)
        self.turn_start_time = None
        self.turn_active = False
        self.current_spell = None
        self.classifying = False
        self._start_turn()

    def _start_turn(self) -> None:
        """Start a new player turn: reset the timer, the path and the filter."""
        self.turn_start_time = time.time()
        self.turn_active = True
        self.path_points.clear()
        self.kalman.reset()
        self.running_max_distance = 0.1
        self.current_spell = None
        self.classifying = False

    def _drain_imu_queue(self) -> None:
        """Consume every queued IMU sample and update the orientation estimate.

        ``run`` calls this every frame, whatever the game state. The listener
        produces samples continuously, so the queue would otherwise grow
        without bound while the menu or result screen is shown.
        """
        while True:
            try:
                sample = self.imu_data_queue.get_nowait()
            except queue.Empty:
                return
            if sample.quaternion is not None:
                self.current_data.quaternion = sample.quaternion
                self.smoothed_quat = slerp(self.smoothed_quat, sample.quaternion, self.slerp_alpha)
            if sample.timestamp > 0:
                self.last_packet_time = sample.timestamp

    def _update_imu_path(self, dt: float) -> None:
        """Append one filtered wand-tip sample to the path.

        ``dt`` is the frame time. The Kalman filter steps once per frame, so
        that is the interval its prediction must span.
        """
        self._drain_imu_queue()

        if self.current_data.quaternion is None:
            return

        # Wand tip = base + rotated wand axis scaled to WAND_LENGTH.
        quat = self.smoothed_quat
        rotated_dir = np.array(quat.rotate_vector(tuple(self.wand_direction)))
        tip_3d = self.wand_base + rotated_dir * WAND_LENGTH

        # Project to 2D
        if PROJECTION_PLANE == "xz":
            raw_2d = (float(tip_3d[0]), float(tip_3d[2]))
        else:  # yz
            raw_2d = (float(tip_3d[1]), float(tip_3d[2]))

        # Scale so the farthest point seen this turn fits the canvas.
        distance = math.hypot(raw_2d[0], raw_2d[1])
        self.running_max_distance = max(self.running_max_distance, distance)
        if self.running_max_distance > 1e-6:
            target_radius = min(CANVAS_WIDTH, PANEL_HEIGHT) * 0.35
            self.path_scale = target_radius / self.running_max_distance

        if dt <= 0:
            dt = 1.0 / FPS
        filtered_2d = self.kalman.update(raw_2d, dt)
        self.path_points.append(filtered_2d)

    def _render_path_canvas(self) -> None:
        """Redraw the recorded path into ``self.path_canvas`` (off-screen)."""
        self.path_canvas.fill(PATH_CANVAS_BG_COLOR)
        if len(self.path_points) < 2:
            return

        center_x = CANVAS_WIDTH / 2
        center_y = PANEL_HEIGHT / 2
        scale = self.path_scale
        # Screen y grows downward, so the path's y axis is inverted.
        canvas_points = [
            (int(center_x + px * scale), int(center_y - py * scale))
            for px, py in self.path_points
        ]
        pygame.draw.lines(self.path_canvas, PATH_COLOR, False, canvas_points, 2)
        # Mark the current wand position.
        pygame.draw.circle(self.path_canvas, PATH_COLOR, canvas_points[-1], 3)

    def _draw_path_canvas(self) -> None:
        """Render the path canvas and blit it at the bottom-left of the panel."""
        self._render_path_canvas()
        self.screen.blit(self.path_canvas, (CANVAS_X, PANEL_Y))

    def _draw_spell_grid(self) -> None:
        """Draw the spells as a 2x2 grid on the right-hand part of the panel.

        Each cell shows the spell icon (if loaded) on the left, with the name
        and damage range stacked to its right, vertically centred in the cell.
        """
        spell_panel = pygame.Rect(SPELL_GRID_X, PANEL_Y, SPELL_GRID_WIDTH, PANEL_HEIGHT)
        pygame.draw.rect(self.screen, SPELL_PANEL_BG_COLOR, spell_panel)
        pygame.draw.rect(self.screen, SPELL_PANEL_BORDER_COLOR, spell_panel, 3)

        cell_width = SPELL_GRID_WIDTH // 2
        cell_height = PANEL_HEIGHT // 2
        padding = self.GRID_PADDING

        for idx, spell in enumerate(SPELLS):
            row, col = divmod(idx, 2)
            cell_rect = pygame.Rect(
                SPELL_GRID_X + col * cell_width + padding,
                PANEL_Y + row * cell_height + padding,
                cell_width - 2 * padding,
                cell_height - 2 * padding,
            )
            pygame.draw.rect(self.screen, (25, 45, 80), cell_rect)
            pygame.draw.rect(self.screen, SPELL_PANEL_BORDER_COLOR, cell_rect, 2)

            text_x = cell_rect.x + 10
            icon = self.spell_images.get(spell.name)
            if icon is not None:
                icon_x = cell_rect.x + self.GRID_ICON_MARGIN
                self.screen.blit(icon, (icon_x, cell_rect.centery - icon.get_height() // 2))
                text_x = icon_x + icon.get_width() + 10

            name_text = self.font.render(spell.name, True, (255, 255, 255))
            damage_text = self.font.render(
                f"{spell.damage_range[0]}-{spell.damage_range[1]} dmg",
                True, (255, 200, 100),
            )
            block_height = name_text.get_height() + damage_text.get_height()
            text_top = cell_rect.centery - block_height // 2
            self.screen.blit(name_text, (text_x, text_top))
            self.screen.blit(damage_text, (text_x, text_top + name_text.get_height()))

    def _classify_and_cast_spell(self) -> None:
        """Resolve the turn: classify the path, damage the boss, then let the boss retaliate.

        Does nothing if a classification is already underway or the path has
        fewer than ``MIN_PATH_POINTS`` samples (the turn stays active). An
        unrecognised path casts a random spell. If both fighters survive, the
        next turn is scheduled via a one-shot ``USEREVENT`` after
        ``MESSAGE_DURATION_MS``.
        """
        if self.classifying or len(self.path_points) < self.MIN_PATH_POINTS:
            return

        self.classifying = True
        self.turn_active = False

        # The on-screen canvas was drawn during the previous frame; re-render
        # it so CLIP sees every recorded sample.
        self._render_path_canvas()
        spell = classify_path_with_clip(self.path_canvas)

        # If no match, randomize
        if spell is None:
            spell = random.choice(SPELLS)
            print(f"No CLIP match, randomizing to: {spell.name}")

        self.current_spell = spell

        # Cast spell
        damage = random.randint(*spell.damage_range)
        self.boss_hp = max(0, self.boss_hp - damage)
        self.hero_pose = "attack"
        self.hero_sprite = self.hero_cast_img
        self.attack_pose_timer = ATTACK_POSE_DURATION_MS

        self.message = f"{spell.name}! {damage} dmg dealt."
        self.message_timer = MESSAGE_DURATION_MS

        if self.boss_hp <= 0:
            self.message = f"{spell.name}! {damage} dmg. Victory!"
            self.message_timer = 0  # keep the message on the result screen
            self.state = "result"
            return

        # Enemy counter-attack
        retaliation = random.randint(12, 24)
        self.hero_hp = max(0, self.hero_hp - retaliation)
        self.message += f" Neil retaliates (-{retaliation})."

        if self.hero_hp <= 0:
            self.state = "result"
        else:
            # Start next turn after a delay
            pygame.time.set_timer(pygame.USEREVENT, int(MESSAGE_DURATION_MS), 1)

    def run(self) -> None:
        """Run the main loop until the window is closed or "Quit" is chosen.

        On exit, signals the UDP listener to stop, waits briefly for it to
        close its socket, and shuts pygame down.
        """
        running = True
        while running:
            dt_ms = self.clock.tick(FPS)
            dt = min(dt_ms / 1000.0, self.MAX_FRAME_DT)

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.USEREVENT:
                    # Timer event for next turn
                    if self.state == "battle" and not self.turn_active:
                        self._start_turn()
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE and self.state == "battle":
                        self.state = "menu"
                        self._play_music(self.menu_music_path)
                    elif self.state == "menu":
                        self._handle_menu_input(event)
                    elif self.state == "result" and event.key == pygame.K_RETURN:
                        self.state = "menu"
                        self._play_music(self.menu_music_path)

            # Consume IMU packets in every state so the queue stays bounded.
            self._drain_imu_queue()

            if self.state == "menu":
                self._draw_menu()
            elif self.state == "battle":
                self._update_battle(dt)
                self._draw_battle()
            elif self.state == "result":
                self._draw_result_screen()

            pygame.display.flip()

        self.stop_event.set()
        # The listener polls stop_event at least every 0.5 s (its socket timeout).
        self.listener_thread.join(timeout=1.0)
        pygame.quit()

    def _handle_menu_input(self, event: pygame.event.Event) -> None:
        """Move the menu selection (Up/Down or W/S) or activate it with Enter."""
        option_count = len(self.MENU_OPTIONS)
        if event.key in (pygame.K_UP, pygame.K_w):
            self.menu_index = (self.menu_index - 1) % option_count
        elif event.key in (pygame.K_DOWN, pygame.K_s):
            self.menu_index = (self.menu_index + 1) % option_count
        elif event.key == pygame.K_RETURN:
            if self.menu_index == 0:
                self.reset_battle()
                self.state = "battle"
                self._play_music(self.battle_music_path)
            else:
                pygame.event.post(pygame.event.Event(pygame.QUIT))

    def _update_battle(self, dt: float) -> None:
        """Advance timers, sample the wand while a turn is active, and end expired turns."""
        if self.attack_pose_timer > 0:
            self.attack_pose_timer -= int(dt * 1000)
            if self.attack_pose_timer <= 0:
                self.hero_pose = "idle"
                self.hero_sprite = self.hero_idle_img

        if self.message_timer > 0:
            self.message_timer -= int(dt * 1000)
            if self.message_timer <= 0:
                self.message = ""

        if self.turn_active:
            self._update_imu_path(dt)

            if self.turn_start_time is not None:
                elapsed = time.time() - self.turn_start_time
                if elapsed >= TURN_DURATION_SECONDS:
                    self._classify_and_cast_spell()

    def _draw_menu(self) -> None:
        """Draw the title, the menu options (selection highlighted) and a hint line."""
        self.screen.fill((5, 5, 15))
        title = self.big_font.render("Hogwarts Battle", True, (240, 220, 120))
        self.screen.blit(title, title.get_rect(center=(SCREEN_WIDTH // 2, 120)))

        for idx, text in enumerate(self.MENU_OPTIONS):
            color = (255, 255, 255) if idx == self.menu_index else (150, 150, 150)
            option_surface = self.font.render(text, True, color)
            self.screen.blit(
                option_surface,
                option_surface.get_rect(center=(SCREEN_WIDTH // 2, 250 + idx * 40)),
            )

        hint = self.font.render("Use ↑/↓ + Enter. Draw spells with IMU wand!", True, (180, 180, 180))
        self.screen.blit(hint, hint.get_rect(center=(SCREEN_WIDTH // 2, SCREEN_HEIGHT - 60)))

    def _draw_battle(self) -> None:
        """Draw background, both characters, the bottom panel and the UI overlay."""
        self.screen.blit(self.background_img, (0, 0))

        hero_pos = (80, SCREEN_HEIGHT - self.hero_sprite.get_height() - 60)
        boss_pos = (
            SCREEN_WIDTH - self.boss_sprite.get_width() - 60,
            SCREEN_HEIGHT - self.boss_sprite.get_height() - 60,
        )
        self.screen.blit(self.hero_sprite, hero_pos)
        # Mirror the boss horizontally when drawing it.
        flipped_boss = pygame.transform.flip(self.boss_sprite, True, False)
        self.screen.blit(flipped_boss, boss_pos)

        self._draw_path_canvas()
        self._draw_spell_grid()
        self._draw_ui()

    def _draw_ui(self) -> None:
        """Draw health bars, the turn countdown, the last-cast panel and any message."""
        self._draw_health_bar(70, 40, 320, 18, self.hero_hp, MAX_HERO_HP, (120, 200, 255), "Harry")
        self._draw_health_bar(
            SCREEN_WIDTH - 390,
            40,
            320,
            18,
            self.boss_hp,
            MAX_BOSS_HP,
            (255, 180, 120),
            "Neil (Boss)",
        )

        if self.turn_active and self.turn_start_time is not None:
            elapsed = time.time() - self.turn_start_time
            remaining = max(0, TURN_DURATION_SECONDS - elapsed)

            turn_text = self.big_font.render("Your Turn", True, (255, 255, 180))
            turn_rect = turn_text.get_rect(center=(SCREEN_WIDTH // 2, SCREEN_HEIGHT // 2 - 60))
            self.screen.blit(turn_text, turn_rect)

            countdown_text = self.timer_font.render(f"{int(remaining) + 1}", True, (255, 100, 100))
            countdown_rect = countdown_text.get_rect(center=(SCREEN_WIDTH // 2, SCREEN_HEIGHT // 2))
            self.screen.blit(countdown_text, countdown_rect)

        # Panel for the last cast spell, just above the bottom panel.
        if self.current_spell:
            spell_panel = pygame.Rect(SCREEN_WIDTH - 250, PANEL_Y - 130, 230, 120)
            pygame.draw.rect(self.screen, SPELL_PANEL_BG_COLOR, spell_panel)
            pygame.draw.rect(self.screen, SPELL_PANEL_BORDER_COLOR, spell_panel, 3)

            if self.current_spell.name in self.spell_images_small:
                spell_img = self.spell_images_small[self.current_spell.name]
                self.screen.blit(spell_img, (spell_panel.x + 10, spell_panel.y + 10))

            name_text = self.font.render(self.current_spell.name, True, (255, 255, 255))
            self.screen.blit(name_text, (spell_panel.x + 100, spell_panel.y + 15))

            damage_text = self.font.render(
                f"Damage: {self.current_spell.damage_range[0]}-{self.current_spell.damage_range[1]}",
                True, (255, 200, 100),
            )
            self.screen.blit(damage_text, (spell_panel.x + 100, spell_panel.y + 45))

        if self.message:
            message_surface = self.font.render(self.message, True, (255, 255, 180))
            self.screen.blit(message_surface, (SCREEN_WIDTH // 2 - message_surface.get_width() // 2, 100))

    def _draw_health_bar(
        self,
        x: int,
        y: int,
        width: int,
        height: int,
        current_hp: int,
        max_hp: int,
        color: Tuple[int, int, int],
        label: str,
    ) -> None:
        """Draw a framed bar filled in proportion to ``current_hp / max_hp``, labelled above."""
        pygame.draw.rect(self.screen, (40, 40, 40), (x, y, width, height))
        ratio = max(0, min(1.0, current_hp / max_hp))
        pygame.draw.rect(self.screen, color, (x, y, int(width * ratio), height))
        pygame.draw.rect(self.screen, (255, 255, 255), (x, y, width, height), 2)

        label_surface = self.font.render(f"{label}: {current_hp}/{max_hp}", True, (255, 255, 255))
        self.screen.blit(label_surface, (x, y - 24))

    def _draw_result_screen(self) -> None:
        """Draw the final battle frame under a dark overlay with the outcome and a prompt."""
        self._draw_battle()
        overlay = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT), pygame.SRCALPHA)
        overlay.fill((0, 0, 0, 180))
        self.screen.blit(overlay, (0, 0))

        result_text = "Victory! Hogwarts is safe." if self.boss_hp <= 0 else "Defeat! Train harder."
        prompt = "Press Enter to return to the menu."

        result_surface = self.big_font.render(result_text, True, (255, 230, 180))
        prompt_surface = self.font.render(prompt, True, (255, 255, 255))

        self.screen.blit(result_surface, result_surface.get_rect(center=(SCREEN_WIDTH // 2, 180)))
        self.screen.blit(prompt_surface, prompt_surface.get_rect(center=(SCREEN_WIDTH // 2, 230)))


def main() -> None:
    """Script entry point: create the game window and run until the player quits."""
    game = BattleGame()
    game.run()


if __name__ == "__main__":
    main()
