"""Startup state machine and graceful shutdown for Sigil bot."""

from __future__ import annotations

import asyncio
import signal
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import structlog

from src.observability.health import (
    set_reconciliation_complete,
    update_component_status,
)
from src.storage.database import (
    close_pool,
    get_connection,
    init_db,
    init_pool,
    return_connection,
    write_heartbeat,
)

if TYPE_CHECKING:
    from src.config import SigilConfig

# Module-level structlog logger used by every Lifecycle startup/shutdown step.
log = structlog.get_logger()


class Lifecycle:
    """Manages the 10-step startup state machine and graceful shutdown.

    ``startup()`` runs the steps in a fixed order and reports each one to the
    health module via ``update_component_status``. The first failing step
    aborts startup by re-raising its exception. ``shutdown()`` logs any
    still-open orders, then hands the DB connection back and closes the pool.
    """

    def __init__(self, config: SigilConfig) -> None:
        self.config = config
        self._shutdown_event = asyncio.Event()
        # Connection borrowed from the pool in the LOAD step; released in shutdown().
        self._db_conn = None
        # Objects produced during startup (e.g. "db", "strategy") for the orchestrator.
        self._components: dict[str, Any] = {}

    async def startup(self) -> dict[str, Any]:
        """Execute the 10-step startup state machine.

        Steps run strictly in order. Each step's outcome is published through
        ``update_component_status`` as ``"ok"`` or ``"failed: <error>"``.

        Returns:
            Dict of initialised components for the orchestrator.

        Raises:
            Exception: whatever the first failing step raised, re-raised
                unchanged after being logged and reported.
        """
        steps = [
            ("LOAD", self._step_load),
            ("MODEL", self._step_model),
            ("ORDERS", self._step_reconcile_orders),
            ("BALANCE", self._step_reconcile_balances),
            ("KILL", self._step_check_kill),
            ("SCHED", self._step_scheduler),
            ("HTTP", self._step_http),
            ("HEART", self._step_heartbeat),
            ("READY", self._step_ready),
            ("ALERT", self._step_alert),
        ]

        for step_name, step_fn in steps:
            log.info("startup_step", step=step_name, status="starting")
            try:
                await step_fn()
                update_component_status(step_name, "ok")
                log.info("startup_step", step=step_name, status="complete")
            except Exception as e:
                log.error("startup_step_failed", step=step_name, error=str(e))
                update_component_status(step_name, f"failed: {e}")
                raise

        return self._components

    async def _step_load(self) -> None:
        """Step 1: Log config, init the DB pool and ``init_db()``, borrow a connection.

        Also verifies the config HMAC when ``config.config_hmac_key`` is set.
        """
        log.info(
            "config_loaded",
            mode=self.config.mode,
            exchange=self.config.exchange.exchange_id,
            core_symbols=self.config.core.symbols,
        )

        init_pool()
        init_db()
        self._db_conn = get_connection()
        self._components["db"] = self._db_conn

        # Verify config HMAC if key is set. A mismatch is only logged and
        # startup continues — NOTE(review): confirm warn-only is the intended policy.
        if self.config.config_hmac_key:
            from src.config import verify_config_hmac

            # Relative path: resolved against the process working directory.
            config_path = Path("config")
            if not config_path.exists():
                # A key is configured, so don't skip the integrity check silently.
                log.warning(
                    "config_hmac_skipped",
                    path=str(config_path),
                    reason="config path not found",
                )
            elif not verify_config_hmac(config_path, self.config.config_hmac_key):
                log.warning("config_hmac_mismatch", path=str(config_path))

    async def _step_model(self) -> None:
        """Step 2: Use the ML strategy if enabled and loadable, otherwise rules.

        Sets ``_components["strategy"]`` to ``"ml"`` (plus ``"ml_strategy"``)
        or ``"rules"``.
        """
        if self.config.ml.enabled and self.config.ml.model_path:
            from src.signals.ml_strategy import MLStrategy

            ml = MLStrategy(self.config)
            # MLStrategy.__init__ calls load_model if config.ml.model_path is set
            if ml._ml_enabled:
                log.info(
                    "model_step",
                    strategy="ml-v1",
                    model_path=self.config.ml.model_path,
                )
                self._components["strategy"] = "ml"
                self._components["ml_strategy"] = ml
                return

            log.warning(
                "model_step",
                strategy="rules-v1",
                note="ML enabled but model failed to load — falling back to rules",
            )
        else:
            log.info(
                "model_step",
                strategy="rules-v1",
                note="ML not enabled — using rules strategy",
            )
        self._components["strategy"] = "rules"

    async def _step_reconcile_orders(self) -> None:
        """Step 3: Reconcile exchange open orders with the DB.

        Skipped in draft mode; otherwise currently a log-only placeholder.
        """
        if self.config.mode == "draft":
            log.info("reconcile_orders_skip", reason="draft mode, no exchange orders")
            return

        # In live mode, this would query exchange open orders
        # and update the trades table to match reality
        log.info("reconcile_orders", mode=self.config.mode)

    async def _step_reconcile_balances(self) -> None:
        """Step 4: Reconcile balances against stored portfolio state.

        In draft mode, seeds a ``portfolio_snapshots`` row with the configured
        ``draft_starting_capital`` (default 50.0) when no draft snapshot exists,
        otherwise logs the latest snapshot's total value. Other modes are
        currently a log-only placeholder.

        Raises:
            RuntimeError: if no DB connection is available (LOAD step not run)
                or the snapshot count query returns no row.
        """
        conn = self._db_conn
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if conn is None:
            raise RuntimeError("reconcile_balances requires a DB connection from the LOAD step")

        if self.config.mode != "draft":
            log.info("reconcile_balances", mode=self.config.mode)
            return

        existing = cast(dict[str, Any] | None, conn.execute(
            "SELECT COUNT(*) as cnt FROM portfolio_snapshots WHERE mode = 'draft'"
        ).fetchone())
        if existing is None:
            # COUNT(*) always yields exactly one row; anything else is unexpected.
            raise RuntimeError("portfolio_snapshots count query returned no row")

        if existing["cnt"] == 0:
            # Seed initial portfolio snapshot since none exists yet.
            starting_capital = getattr(
                self.config.core, "draft_starting_capital", 50.0
            )
            try:
                conn.execute(
                    """INSERT INTO portfolio_snapshots
                       (total_value_usd, cash_balance, positions_value, drawdown_from_ath, mode)
                       VALUES (%s, %s, '0', 0.0, 'draft')""",
                    (str(starting_capital), str(starting_capital)),
                )
                conn.commit()
            except Exception:
                # Don't leave a failed transaction open on a connection that
                # shutdown() still queries.
                conn.rollback()
                raise
            log.info("draft_capital_seeded", amount=starting_capital)
        else:
            latest = cast(dict[str, Any] | None, conn.execute(
                "SELECT total_value_usd FROM portfolio_snapshots WHERE mode='draft' ORDER BY snapshot_at DESC LIMIT 1"
            ).fetchone())
            log.info(
                "draft_capital_existing",
                balance=latest["total_value_usd"] if latest else "unknown",
            )

    async def _step_check_kill(self) -> None:
        """Step 5: Abort startup if the kill switch file exists.

        Raises:
            RuntimeError: when ``config.risk.kill_switch_path`` exists.
        """
        kill_path = Path(self.config.risk.kill_switch_path)
        if kill_path.exists():
            log.critical("kill_switch_active", path=str(kill_path))
            raise RuntimeError(f"Kill switch active: {kill_path}")

    async def _step_scheduler(self) -> None:
        """Step 6: Start APScheduler jobs (placeholder — wired in main.py)."""
        log.info("scheduler_ready")

    async def _step_http(self) -> None:
        """Step 7: Start aiohttp server (placeholder — wired in main.py)."""
        log.info("http_ready")

    async def _step_heartbeat(self) -> None:
        """Step 8: Write the initial heartbeat to ``config.heartbeat_path``."""
        write_heartbeat(self.config.heartbeat_path)
        log.info("heartbeat_written", path=self.config.heartbeat_path)

    async def _step_ready(self) -> None:
        """Step 9: Mark reconciliation complete in the health module and log readiness."""
        set_reconciliation_complete()
        log.info("bot_ready", mode=self.config.mode)

    async def _step_alert(self) -> None:
        """Step 10: Log the startup banner (mode shown upper-cased)."""
        mode_tag = self.config.mode.upper()
        log.info(
            "bot_started", mode=mode_tag, message=f"Sigil bot started [{mode_tag}]"
        )
        # Alert sending is handled by the orchestrator after components are wired

    async def shutdown(self) -> None:
        """Best-effort graceful shutdown (e.g. after SIGTERM/SIGINT).

        Logs trades still in SUBMITTED/PARTIAL_ACTIVE state, then returns the
        DB connection and closes the pool. The pool is closed even if
        returning the connection fails. Repeated calls skip the DB steps.
        """
        log.info("shutdown_initiated")

        conn = self._db_conn

        # Step 1: Log open orders
        if conn is not None:
            try:
                open_orders = conn.execute(
                    "SELECT id, symbol, order_state FROM trades WHERE order_state IN ('SUBMITTED', 'PARTIAL_ACTIVE')"
                ).fetchall()
                if open_orders:
                    log.warning(
                        "shutdown_open_orders",
                        count=len(open_orders),
                        orders=[dict(o) for o in open_orders],
                    )
                else:
                    log.info("shutdown_no_open_orders")
            except Exception as e:
                log.error("shutdown_order_check_failed", error=str(e))

        # Step 2: Cancel pending orders (would cancel via exchange in live mode)
        log.info("shutdown_cancelling_pending")

        # Step 3: Return connection and close pool. Clear the attribute first so
        # a repeated shutdown() cannot hand the same connection back twice.
        if conn is not None:
            self._db_conn = None
            try:
                return_connection(conn)
            except Exception as e:
                log.error("shutdown_return_connection_failed", error=str(e))
            finally:
                close_pool()
            log.info("shutdown_db_closed")

        # Step 4: Final log
        log.info("shutdown_complete")

    def setup_signal_handlers(self, loop: asyncio.AbstractEventLoop) -> None:
        """Register SIGTERM/SIGINT handlers that set ``shutdown_event``.

        ``loop.add_signal_handler`` is Unix-only; Windows event loops raise
        ``NotImplementedError``.
        """
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, self._shutdown_event.set)

    @property
    def shutdown_event(self) -> asyncio.Event:
        """Event set by the registered signal handlers to request shutdown."""
        return self._shutdown_event
