"""Risk management — 12-step check pipeline, circuit breaker, kill switch."""

from __future__ import annotations

import psycopg
from datetime import datetime, timezone
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Literal, cast, Any

import structlog

from src.config import SigilConfig
from src.contracts import AlertEvent, RiskDecision, TradeSignal

if TYPE_CHECKING:
    from src.executor.position_tracker import PositionTracker

log = structlog.get_logger(__name__)


def get_stop_multiplier(utc_hour: int, market_cap: float | None) -> float:
    """Pick the stop-loss distance multiplier for a UTC hour and market cap.

    The rules are tried top to bottom and the first match wins:

    1. ``0.75`` when the hour is in a low-liquidity window (02-06 or 21-23
       UTC, inclusive) *and* the market cap is known and below $50M.
    2. ``0.85`` when the hour is 00-03 UTC (inclusive).
    3. ``1.0`` in every other case.

    Parameters
    ----------
    utc_hour:
        Hour of the day in UTC, expected in ``0..23``.
    market_cap:
        Token market cap in USD.  ``None`` (unknown) is never treated as
        small-cap, so rule 1 cannot fire without a market cap.

    Returns
    -------
    float multiplier to apply to the raw stop-loss distance.
    """
    # An unknown market cap is deliberately *not* classified as small-cap.
    is_small_cap = False if market_cap is None else market_cap < 50_000_000

    # Two disjoint windows: early morning and late evening UTC.
    thin_windows = (range(2, 7), range(21, 24))
    in_thin_window = any(utc_hour in window for window in thin_windows)

    if is_small_cap and in_thin_window:
        return 0.75

    # Hours 02-03 overlap with the first window; they land here only when
    # the token is not small-cap (or its market cap is unknown).
    return 0.85 if utc_hour in range(0, 4) else 1.0


class RiskManager:
    """12-step risk check pipeline with circuit breaker and kill switch.

    Checks are evaluated in order; the first failing check terminates the
    pipeline and returns a rejected RiskDecision with the step number and
    reason attached.
    """

    def __init__(
        self,
        config: SigilConfig,
        position_tracker: PositionTracker,
        db_conn: psycopg.Connection,
        alert_fn: Callable[[AlertEvent], None],
    ) -> None:
        self._config = config
        self._tracker = position_tracker
        self._db = db_conn
        self._alert = alert_fn
        self._log = log.bind(component="risk_manager")

        # In-memory circuit breaker state
        self._circuit_breaker_open: bool = False
        self._circuit_breaker_reason: str = ""
        self._circuit_breaker_opened_at: datetime | None = None

        # Recovery hysteresis tracking
        # After circuit-breaker recovery first `_recovery_trades_required`
        # trades run at reduced size; we count successful reduced trades here.
        self._recovery_mode: bool = False
        self._recovery_trades_completed: int = 0

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def evaluate(self, signal: TradeSignal) -> RiskDecision:
        """Run all 12 risk checks in order.

        Returns an approved RiskDecision (possibly with scaled-down quantity)
        or a rejected RiskDecision with the failing step and reason.
        """
        checks = [
            self._check_kill_switch,  # 1
            self._check_draft_mode,  # 2
            self._check_price_staleness,  # 3
            self._check_daily_loss,  # 4
            self._check_max_drawdown,  # 5
            self._check_position_size,  # 6
            self._check_trade_rate,  # 7
            self._check_total_exposure,  # 8
            self._check_cash_reserve,  # 9
            self._check_symbol_cooldown,  # 10
            self._check_circuit_breaker,  # 11
            self._check_confidence,  # 12
            self._check_duplicate_signal,  # 13
        ]

        # Start with the full requested quantity
        quantity = Decimal(str(signal.suggested_quantity_pct))

        for step_number, check in enumerate(checks, start=1):
            passed, reason, quantity = check(signal, step_number, quantity)
            if not passed:
                self._log.warning(
                    "risk_check_rejected",
                    signal_id=str(signal.signal_id),
                    symbol=signal.symbol,
                    step=step_number,
                    reason=reason,
                )
                return RiskDecision(
                    signal_id=signal.signal_id,
                    approved=False,
                    adjusted_quantity=None,
                    rejection_reason=reason,
                    rejection_step=step_number,
                )

        # Apply recovery-mode scaling (after circuit breaker recovery)
        if self._recovery_mode:
            quantity = (
                quantity * Decimal(str(self._config.risk.recovery_position_scale))
            ).quantize(Decimal("0.00000001"))
            self._log.info(
                "risk_recovery_scaling_applied",
                signal_id=str(signal.signal_id),
                symbol=signal.symbol,
                recovery_trades_remaining=(
                    self._config.risk.recovery_trades_required
                    - self._recovery_trades_completed
                ),
            )

        stop_loss_price = self._compute_stop_loss(signal, quantity)

        self._log.info(
            "risk_check_approved",
            signal_id=str(signal.signal_id),
            symbol=signal.symbol,
            adjusted_quantity_pct=str(quantity),
            stop_loss_price=str(stop_loss_price) if stop_loss_price else None,
            recovery_mode=self._recovery_mode,
        )

        return RiskDecision(
            signal_id=signal.signal_id,
            approved=True,
            adjusted_quantity=quantity,
            stop_loss_price=stop_loss_price,
        )

    def is_circuit_breaker_active(self) -> bool:
        """Return True if the circuit breaker is currently open."""
        return self._circuit_breaker_open

    def trigger_circuit_breaker(self, reason: str) -> None:
        """Open the circuit breaker, halt trading, and alert."""
        if self._circuit_breaker_open:
            # Already open — just update the reason
            self._circuit_breaker_reason = reason
            return

        self._circuit_breaker_open = True
        self._circuit_breaker_reason = reason
        self._circuit_breaker_opened_at = datetime.now(timezone.utc)

        self._log.critical(
            "circuit_breaker_triggered",
            reason=reason,
            opened_at=self._circuit_breaker_opened_at.isoformat(),
        )

        self._record_circuit_breaker_event(reason)

        self._alert(
            AlertEvent(
                event_type="circuit_breaker_open",
                severity="critical",
                message=f"Circuit breaker triggered: {reason}",
                mode=cast(Literal["draft", "live"], self._config.mode),
            )
        )

    def reset_circuit_breaker(self, resumed_by: str = "auto") -> None:
        """Close the circuit breaker and enter recovery mode."""
        if not self._circuit_breaker_open:
            return

        self._circuit_breaker_open = False
        self._recovery_mode = True
        self._recovery_trades_completed = 0

        self._log.info(
            "circuit_breaker_reset",
            resumed_by=resumed_by,
            recovery_trades_required=self._config.risk.recovery_trades_required,
        )

        self._update_circuit_breaker_resumed(resumed_by)

    def record_successful_trade(self) -> None:
        """Notify the risk manager that a trade completed successfully.

        Used to advance the recovery-mode counter.
        """
        # Called by OrderManager after successful trade execution
        if not self._recovery_mode:
            return

        self._recovery_trades_completed += 1
        self._log.info(
            "recovery_trade_recorded",
            completed=self._recovery_trades_completed,
            required=self._config.risk.recovery_trades_required,
        )

        if (
            self._recovery_trades_completed
            >= self._config.risk.recovery_trades_required
        ):
            self._recovery_mode = False
            self._recovery_trades_completed = 0
            self._log.info("recovery_mode_complete")

    def check_kill_switch(self) -> bool:
        """Return True if the filesystem kill switch flag is present."""
        return Path(self._config.risk.kill_switch_path).exists()

    # ------------------------------------------------------------------
    # Individual risk checks
    # Each check signature: (signal, step, quantity) -> (passed, reason, quantity)
    # ------------------------------------------------------------------

    def _check_kill_switch(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        if self.check_kill_switch():
            return False, "Kill switch is active", quantity
        return True, "", quantity

    def _check_draft_mode(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        # Draft mode always passes risk (routed to simulator downstream)
        if self._config.mode == "draft":
            return True, "", quantity
        return True, "", quantity

    def _check_price_staleness(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        now = datetime.now(timezone.utc)
        age_seconds = (now - signal.emitted_at).total_seconds()
        if age_seconds > self._config.risk.price_staleness_seconds:
            return (
                False,
                f"Signal price stale: {age_seconds:.1f}s > {self._config.risk.price_staleness_seconds}s",
                quantity,
            )
        return True, "", quantity

    def _check_daily_loss(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        daily_loss_pct = self._get_daily_loss_pct()
        limit = Decimal(str(self._config.risk.daily_loss_limit_pct))
        if daily_loss_pct >= limit:
            self.trigger_circuit_breaker(
                f"Daily loss limit reached: {daily_loss_pct:.2f}% >= {limit}%"
            )
            return (
                False,
                f"Daily loss limit exceeded: {daily_loss_pct:.2f}% >= {limit}%",
                quantity,
            )
        return True, "", quantity

    def _check_max_drawdown(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        drawdown_pct = self._get_max_drawdown_pct()
        limit = Decimal(str(self._config.risk.max_drawdown_pct))
        if drawdown_pct >= limit:
            self.trigger_circuit_breaker(
                f"Max drawdown exceeded: {drawdown_pct:.2f}% >= {limit}%"
            )
            return (
                False,
                f"Max drawdown exceeded: {drawdown_pct:.2f}% >= {limit}%",
                quantity,
            )
        return True, "", quantity

    def _check_position_size(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        if signal.strategy_tier == "pump":
            max_pct = Decimal(str(self._config.pump.max_position_pct))
        else:
            max_pct = Decimal(str(self._config.core.max_position_pct))

        if quantity > max_pct:
            # Scale down rather than reject
            self._log.info(
                "position_size_scaled_down",
                signal_id=str(signal.signal_id),
                symbol=signal.symbol,
                original=str(quantity),
                scaled_to=str(max_pct),
            )
            return True, "", max_pct
        return True, "", quantity

    def _check_trade_rate(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        row = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT COUNT(*) AS trade_count FROM trades
            WHERE mode = %s
              AND created_at > NOW() - INTERVAL '1 hour'
              AND order_state IN ('FILLED', 'SIMULATED')
            """,
            (self._config.mode,),
        ).fetchone())

        count = row["trade_count"] if row is not None else 0
        limit = self._config.risk.max_trades_per_hour
        if count >= limit:
            return (
                False,
                f"Trade rate limit reached: {count} trades in last hour >= {limit}",
                quantity,
            )
        return True, "", quantity

    def _check_total_exposure(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        # Sells reduce exposure — always allow them through
        if signal.direction == "sell":
            return True, "", quantity

        # Per-tier exposure: only count positions in the same tier
        exposure = self._tracker.get_total_exposure(
            self._config.mode, strategy_tier=signal.strategy_tier
        )
        # Use tier-specific allocation cap instead of the global risk limit
        if signal.strategy_tier == "pump":
            limit = Decimal(str(self._config.pump.max_total_exposure_pct))
        elif signal.strategy_tier == "micro":
            limit = Decimal(str(self._config.micro.max_total_exposure_pct))
        else:
            limit = Decimal(str(self._config.core.max_total_exposure_pct))
        if exposure >= limit:
            return (
                False,
                f"Total exposure {exposure:.2f}% >= {limit}% for tier '{signal.strategy_tier}'",
                quantity,
            )
        # Scale down quantity if it would exceed the tier cap
        remaining = limit - exposure
        if quantity > remaining:
            quantity = remaining
        return True, "", quantity

    def _check_cash_reserve(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        cash = self._tracker.get_cash_balance(self._config.mode)
        # Cash is stored as a percentage of total portfolio
        min_cash = Decimal(str(self._config.risk.min_cash_reserve_pct))
        if cash < min_cash:
            return (
                False,
                f"Cash reserve {cash:.2f}% < {min_cash}% minimum",
                quantity,
            )
        return True, "", quantity

    def _check_symbol_cooldown(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        if signal.strategy_tier == "pump":
            cooldown_hours = self._config.pump.reentry_cooldown_hours
        else:
            cooldown_hours = self._config.core.reentry_cooldown_hours

        # Check sell-side cooldown: enforce reentry cooldown after a sell fill
        sell_row = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT updated_at FROM trades
            WHERE symbol = %s AND mode = %s AND side = 'sell'
              AND order_state IN ('FILLED', 'SIMULATED')
            ORDER BY updated_at DESC LIMIT 1
            """,
            (signal.symbol, self._config.mode),
        ).fetchone())

        if sell_row is not None:
            last_exit_str = sell_row["updated_at"]
            try:
                last_exit = datetime.fromisoformat(last_exit_str).replace(
                    tzinfo=timezone.utc
                )
                elapsed_hours = (
                    datetime.now(timezone.utc) - last_exit
                ).total_seconds() / 3600
                if elapsed_hours < cooldown_hours:
                    return (
                        False,
                        f"Symbol cooldown active: {elapsed_hours:.1f}h elapsed of {cooldown_hours}h",
                        quantity,
                    )
            except (ValueError, TypeError):
                pass

        # Check buy-side cooldown: prevent rapid-fire buys within 5 minutes
        _BUY_COOLDOWN_MINUTES = 5
        buy_row = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT updated_at FROM trades
            WHERE symbol = %s AND mode = %s AND side = 'buy'
              AND order_state IN ('FILLED', 'SIMULATED')
            ORDER BY updated_at DESC LIMIT 1
            """,
            (signal.symbol, self._config.mode),
        ).fetchone())

        if buy_row is not None:
            last_buy_str = buy_row["updated_at"]
            try:
                last_buy = datetime.fromisoformat(last_buy_str).replace(
                    tzinfo=timezone.utc
                )
                elapsed_minutes = (
                    datetime.now(timezone.utc) - last_buy
                ).total_seconds() / 60
                if elapsed_minutes < _BUY_COOLDOWN_MINUTES:
                    return (
                        False,
                        f"Buy cooldown active: {elapsed_minutes:.1f}m elapsed of {_BUY_COOLDOWN_MINUTES}m",
                        quantity,
                    )
            except (ValueError, TypeError):
                pass

        return True, "", quantity

    def _check_circuit_breaker(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        if self._circuit_breaker_open:
            return (
                False,
                f"Circuit breaker open: {self._circuit_breaker_reason}",
                quantity,
            )
        return True, "", quantity

    def _check_confidence(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        if signal.strategy_tier == "pump":
            min_confidence = self._config.pump.min_confidence
        else:
            min_confidence = self._config.core.min_confidence

        if signal.confidence < min_confidence:
            return (
                False,
                f"Confidence {signal.confidence:.3f} < threshold {min_confidence:.3f}",
                quantity,
            )
        return True, "", quantity

    def _check_duplicate_signal(
        self, signal: TradeSignal, step: int, quantity: Decimal
    ) -> tuple[bool, str, Decimal]:
        # Check for an already-open order on this symbol in the same mode
        row = self._db.execute(
            """
            SELECT id FROM trades
            WHERE symbol = %s AND mode = %s AND side = 'buy'
              AND order_state IN ('SUBMITTED', 'PARTIAL_ACTIVE')
            LIMIT 1
            """,
            (signal.symbol, self._config.mode),
        ).fetchone()

        if row is not None:
            return (
                False,
                f"Duplicate: open buy order already exists for {signal.symbol}",
                quantity,
            )

        # Also check for duplicate signal_id
        sig_row = self._db.execute(
            "SELECT id FROM trades WHERE signal_id = %s LIMIT 1",
            (str(signal.signal_id),),
        ).fetchone()
        if sig_row is not None:
            return (
                False,
                f"Duplicate signal_id already processed: {signal.signal_id}",
                quantity,
            )

        return True, "", quantity

    # ------------------------------------------------------------------
    # Portfolio data helpers
    # ------------------------------------------------------------------

    def _get_daily_loss_pct(self) -> Decimal:
        """Compute today's P&L as a percentage from portfolio snapshots."""
        rows = cast(list[dict[str, Any]], self._db.execute(
            """
            SELECT total_value_usd FROM portfolio_snapshots
            WHERE mode = %s AND snapshot_at::date = CURRENT_DATE
            ORDER BY snapshot_at ASC
            LIMIT 2
            """,
            (self._config.mode,),
        ).fetchall())

        if len(rows) < 2:
            # Not enough snapshots today — check against yesterday's close
            yesterday = cast(dict[str, Any] | None, self._db.execute(
                """
                SELECT total_value_usd FROM portfolio_snapshots
                WHERE mode = %s AND snapshot_at::date = CURRENT_DATE - INTERVAL '1 day'
                ORDER BY snapshot_at DESC LIMIT 1
                """,
                (self._config.mode,),
            ).fetchone())

            today_latest = cast(dict[str, Any] | None, self._db.execute(
                """
                SELECT total_value_usd FROM portfolio_snapshots
                WHERE mode = %s ORDER BY snapshot_at DESC LIMIT 1
                """,
                (self._config.mode,),
            ).fetchone())

            if yesterday is None or today_latest is None:
                return Decimal("0")

            try:
                start = Decimal(str(yesterday["total_value_usd"]))
                current = Decimal(str(today_latest["total_value_usd"]))
            except Exception:
                return Decimal("0")
        else:
            try:
                start = Decimal(str(rows[0]["total_value_usd"]))
                current = Decimal(str(rows[-1]["total_value_usd"]))
            except Exception:
                return Decimal("0")

        if start == 0:
            return Decimal("0")

        pnl_pct = (start - current) / start * Decimal("100")
        return max(pnl_pct, Decimal("0"))  # Only report losses

    def _get_max_drawdown_pct(self) -> Decimal:
        """Compute drawdown from all-time-high portfolio value."""
        ath_row = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT MAX(CAST(total_value_usd AS NUMERIC)) AS ath
            FROM portfolio_snapshots WHERE mode = %s
            """,
            (self._config.mode,),
        ).fetchone())

        current_row = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT total_value_usd FROM portfolio_snapshots
            WHERE mode = %s ORDER BY snapshot_at DESC LIMIT 1
            """,
            (self._config.mode,),
        ).fetchone())

        if ath_row is None or ath_row["ath"] is None or current_row is None:
            return Decimal("0")

        try:
            ath = Decimal(str(ath_row["ath"]))
            current = Decimal(str(current_row["total_value_usd"]))
        except Exception:
            return Decimal("0")

        if ath == 0:
            return Decimal("0")

        drawdown = (ath - current) / ath * Decimal("100")
        return max(drawdown, Decimal("0"))

    def _compute_stop_loss(
        self, signal: TradeSignal, quantity: Decimal
    ) -> Decimal | None:
        """Return stop-loss price placeholder.

        The actual price is not available here (no ticker access in
        risk_manager); callers that need a real stop-loss price should
        override via the order placement layer.  We return None so the
        OrderManager can fill it in.

        When a concrete stop distance is available, callers must multiply it by
        ``get_stop_multiplier(datetime.now(timezone.utc).hour, market_cap)``
        before use — this widens the stop during low-liquidity hours for
        small-cap tokens to reduce stop-hunting losses.
        """
        market_cap = getattr(signal, "market_cap", None)
        multiplier = get_stop_multiplier(datetime.now(timezone.utc).hour, market_cap)
        # Log the multiplier for observability; actual application happens at
        # the order-placement layer where the real entry price is available.
        if multiplier != 1.0:
            self._log.debug(
                "stop_multiplier_applied",
                symbol=signal.symbol,
                multiplier=multiplier,
                utc_hour=datetime.now(timezone.utc).hour,
                market_cap=market_cap,
            )
        return None

    # ------------------------------------------------------------------
    # DB helpers
    # ------------------------------------------------------------------

    def _record_circuit_breaker_event(self, reason: str) -> None:
        try:
            self._db.execute(
                """
                INSERT INTO circuit_breaker_events (trigger_reason)
                VALUES (%s)
                """,
                (reason,),
            )
            self._db.commit()
        except Exception:
            self._log.exception("circuit_breaker_event_record_failed", reason=reason)

    def _update_circuit_breaker_resumed(self, resumed_by: str) -> None:
        try:
            self._db.execute(
                """
                UPDATE circuit_breaker_events
                SET trading_resumed_at = NOW(), resumed_by = %s
                WHERE id = (
                    SELECT id FROM circuit_breaker_events
                    WHERE trading_resumed_at IS NULL
                    ORDER BY trading_halted_at DESC
                    LIMIT 1
                )
                """,
                (resumed_by,),
            )
            self._db.commit()
        except Exception:
            self._log.exception("circuit_breaker_resume_record_failed")
