"""Breakout / Fall entry + exit detector for the paper trading leg.

Pure-ish class: entry evaluation is stateless; exit evaluation maintains
per-symbol state (previous RSI for reversal detection, ROC history for
momentum-deceleration detection) keyed by symbol.

All methods are synchronous and free of I/O — the caller (MicroTradingManager)
fetches OHLCV + tickers and passes them in.
"""

from __future__ import annotations

from collections import deque
from datetime import datetime, timezone
from typing import Deque

import structlog

from src.micro_scanner.config import MicroConfig
from src.micro_scanner.indicators import calculate_rsi

# Module-level structlog logger; BreakoutDetector emits its entry/exit decision
# events through it.
log = structlog.get_logger(__name__)


class BreakoutDetector:
    """Breakout entry / exit rules for paper-traded positions.

    Entry evaluation (``evaluate_entry`` / ``compute_entry_score``) depends only
    on its arguments and the config. Exit evaluation (``evaluate_exit``) also
    keeps per-symbol state across calls:

    * ``_prev_rsi`` -- RSI(14) observed on the previous call, used to detect an
      RSI reversal;
    * ``_roc_history`` -- the last three 1h rate-of-change values observed on
      successive calls, used to detect momentum deceleration.

    NOTE(review): nothing in this class ever removes entries from that state,
    so a later position in the same symbol starts with the readings left over
    from an earlier one -- confirm that is acceptable.
    """

    def __init__(self, config: MicroConfig) -> None:
        self._cfg = config
        # symbol -> RSI(14) seen on the previous evaluate_exit call
        self._prev_rsi: dict[str, float] = {}
        # symbol -> last (up to) three 1h ROC values seen by evaluate_exit
        self._roc_history: dict[str, Deque[float]] = {}

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _closes(candles: list[dict]) -> list[float]:
        """Return the ``"close"`` field of every candle as a float."""
        return [float(c["close"]) for c in candles]

    @staticmethod
    def _pct_change(a: float, b: float) -> float:
        """Return the percentage change from *b* to *a* (0.0 when *b* is 0)."""
        if b == 0:
            return 0.0
        return (a - b) / b * 100.0

    @staticmethod
    def _roc(a: float, b: float) -> float:
        """Return the fractional rate of change from *b* to *a* (0.0 when *b* is 0)."""
        if b == 0:
            return 0.0
        return (a - b) / b

    def _momentum_acceleration(self, closes: list[float]) -> float | None:
        """Return the change in ROC over the last two candle steps.

        Computed as ``ROC(closes[-1], closes[-2]) - ROC(closes[-2], closes[-3])``;
        a positive value means the latest ROC exceeds the previous one.
        Returns ``None`` when fewer than three closes are given.
        """
        if len(closes) < 3:
            return None
        current_roc = self._roc(closes[-1], closes[-2])
        prev_roc = self._roc(closes[-2], closes[-3])
        return current_roc - prev_roc

    @staticmethod
    def _elapsed_seconds(entry_time: object, now: datetime) -> float | None:
        """Return the seconds elapsed from *entry_time* to *now*.

        Accepts a ``datetime`` (a naive value is treated as UTC), an ISO-8601
        string (a trailing ``Z`` is accepted and a naive string is treated as
        UTC), or any other object that can be subtracted from *now*. Returns
        ``None`` when the value cannot be interpreted.
        """
        if isinstance(entry_time, str):
            text = entry_time.strip()
            # datetime.fromisoformat only accepts "Z" from Python 3.11 on.
            if text[-1:] in ("Z", "z"):
                text = text[:-1] + "+00:00"
            try:
                entry_time = datetime.fromisoformat(text)
            except ValueError:
                return None
        if getattr(entry_time, "tzinfo", timezone.utc) is None:
            entry_time = entry_time.replace(tzinfo=timezone.utc)
        try:
            return (now - entry_time).total_seconds()
        except (TypeError, AttributeError, OverflowError):
            return None

    # ------------------------------------------------------------------
    # Entry evaluation
    # ------------------------------------------------------------------

    def evaluate_entry(
        self,
        symbol: str,
        candles_1h: list[dict],
        candles_2h: list[dict],
    ) -> bool:
        """Return True when *symbol* passes every breakout entry filter.

        Filters, in order (each rejection is logged at debug level):

        1. at least two 1h candles;
        2. the last 1h change lies strictly between ``breakout_1h_min_pct``
           and ``breakout_1h_max_pct``;
        3. the 2h change is below ``breakout_2h_max_pct`` (treated as 0.0 when
           fewer than three 2h candles are supplied);
        4. RSI(14) over the 1h closes is available and below
           ``breakout_rsi_max``;
        5. 1h momentum acceleration is available and above
           ``momentum_accel_min``.

        On success an info-level signal including ``compute_entry_score`` is
        logged.
        """
        cfg = self._cfg

        if len(candles_1h) < 2:
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="insufficient_1h_candles",
                candles=len(candles_1h),
            )
            return False

        closes_1h = self._closes(candles_1h)
        change_1h = self._pct_change(closes_1h[-1], closes_1h[-2])

        if not (cfg.breakout_1h_min_pct < change_1h < cfg.breakout_1h_max_pct):
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="1h_change_out_of_band",
                change_1h=round(change_1h, 4),
                min=cfg.breakout_1h_min_pct,
                max=cfg.breakout_1h_max_pct,
            )
            return False

        if len(candles_2h) >= 3:
            closes_2h = self._closes(candles_2h)
            # NOTE(review): compares the last close with the one two candles
            # earlier; if these really are 2h candles that spans ~4h -- confirm.
            change_2h = self._pct_change(closes_2h[-1], closes_2h[-3])
        else:
            change_2h = 0.0

        if change_2h >= cfg.breakout_2h_max_pct:
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="2h_change_too_high",
                change_2h=round(change_2h, 4),
                max=cfg.breakout_2h_max_pct,
            )
            return False

        rsi = calculate_rsi(closes_1h, period=14)
        if rsi is None:
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="rsi_insufficient_data",
                closes=len(closes_1h),
            )
            return False

        if rsi >= cfg.breakout_rsi_max:
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="rsi_too_high",
                rsi=round(rsi, 4),
                max=cfg.breakout_rsi_max,
            )
            return False

        momentum_accel = self._momentum_acceleration(closes_1h)
        if momentum_accel is None:
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="momentum_insufficient_data",
                closes=len(closes_1h),
            )
            return False

        if momentum_accel <= cfg.momentum_accel_min:
            log.debug(
                "breakout_entry_rejected",
                symbol=symbol,
                reason="momentum_not_accelerating",
                momentum_accel=round(momentum_accel, 6),
            )
            return False

        score = self.compute_entry_score(symbol, candles_1h)
        log.info(
            "breakout_entry_signal",
            symbol=symbol,
            change_1h=round(change_1h, 4),
            change_2h=round(change_2h, 4),
            rsi=round(rsi, 4),
            momentum_accel=round(momentum_accel, 6),
            score=round(score, 4),
        )
        return True

    def compute_entry_score(self, symbol: str, candles_1h: list[dict]) -> float:
        """Return a breakout quality score in ``[0.0, 1.0]``.

        Weighted sum of:

        * 50% -- last 1h change relative to ``breakout_1h_max_pct``, clamped
          to [0, 1];
        * 30% -- RSI headroom below ``breakout_rsi_max`` (0 when at/above it);
        * 20% -- positive momentum acceleration, scaled by 100 and capped at 1.
          Deceleration contributes 0, matching the entry filter that rejects
          non-accelerating momentum.

        Returns 0.0 when fewer than three candles are given or RSI / momentum
        cannot be computed. *symbol* is currently unused.
        """
        cfg = self._cfg
        if len(candles_1h) < 3:
            return 0.0

        closes_1h = self._closes(candles_1h)
        change_1h = self._pct_change(closes_1h[-1], closes_1h[-2])

        rsi = calculate_rsi(closes_1h, period=14)
        if rsi is None:
            return 0.0

        momentum_accel = self._momentum_acceleration(closes_1h)
        if momentum_accel is None:
            return 0.0

        # Guard the divisions against a zero-valued config entry.
        max_1h = cfg.breakout_1h_max_pct or 1.0
        max_rsi = cfg.breakout_rsi_max or 1.0

        change_1h_score = min(1.0, max(0.0, change_1h / max_1h))
        rsi_headroom_score = max(0.0, (cfg.breakout_rsi_max - rsi) / max_rsi)
        # Only acceleration is rewarded; abs() here used to score deceleration
        # as highly as acceleration.
        momentum_accel_score = min(1.0, max(0.0, momentum_accel) * 100.0)

        score = (
            change_1h_score * 0.50
            + rsi_headroom_score * 0.30
            + momentum_accel_score * 0.20
        )
        return max(0.0, min(1.0, score))

    # ------------------------------------------------------------------
    # Exit evaluation
    # ------------------------------------------------------------------

    def evaluate_exit(
        self,
        symbol: str,
        position: dict,
        current_price: float,
        candles_1h: list[dict],
    ) -> str | None:
        """Return the first exit reason that applies to *position*, or None.

        Checks run in this order, and the first hit wins: ``"stop_loss"``,
        ``"trailing_stop"``, ``"time_limit"``, ``"rsi_reversal"``,
        ``"momentum_decel"``. Each hit is logged as ``breakout_exit_signal``.

        *position* must contain ``"entry_price"`` and ``"entry_time"``. An
        optional ``"peak_price"`` falls back to the entry price when it is
        missing or falsy. A missing required key raises ``KeyError``.

        Side effects: updates this symbol's ``_prev_rsi`` whenever RSI can be
        computed, and appends to its ``_roc_history`` whenever there are at
        least two 1h candles. Neither is updated when an earlier check has
        already fired.
        """
        cfg = self._cfg

        entry_price = float(position["entry_price"])
        peak_price = float(position.get("peak_price") or entry_price)
        entry_time = position["entry_time"]

        def _pnl_pct() -> float:
            if entry_price <= 0:
                return 0.0
            return round(self._pct_change(current_price, entry_price), 4)

        def _exit(reason: str, **extra: object) -> str:
            # Single place that logs an exit decision and hands back its reason.
            log.info(
                "breakout_exit_signal",
                symbol=symbol,
                reason=reason,
                entry_price=entry_price,
                exit_price=current_price,
                pnl_pct=_pnl_pct(),
                **extra,
            )
            return reason

        # 1. Hard stop: more than paper_stop_loss_pct below entry.
        if entry_price > 0 and current_price < entry_price * (
            1 - cfg.paper_stop_loss_pct / 100.0
        ):
            return _exit("stop_loss")

        # 2. Trailing stop: more than trail_stop_pct below the peak.
        if peak_price > 0 and current_price < peak_price * (
            1 - cfg.trail_stop_pct / 100.0
        ):
            return _exit("trailing_stop")

        # 3. Time limit.
        elapsed_s = self._elapsed_seconds(entry_time, datetime.now(timezone.utc))
        if elapsed_s is None:
            # An uninterpretable entry_time disables the time limit for this
            # call. Log it so the problem is visible instead of silent.
            log.warning(
                "breakout_exit_entry_time_unparseable",
                symbol=symbol,
                entry_time=repr(entry_time),
            )
            elapsed_s = 0.0
        if elapsed_s > cfg.max_hold_hours * 3600:
            return _exit("time_limit")

        # 4. RSI reversal: the previous reading was above rsi_overbought and
        #    the current one is below rsi_reversal.
        closes_1h = self._closes(candles_1h) if candles_1h else []
        current_rsi = calculate_rsi(closes_1h, period=14) if closes_1h else None

        if current_rsi is not None:
            prev_rsi = self._prev_rsi.get(symbol)
            # Store the new reading before deciding, so the next call compares
            # against this one whether or not we exit now.
            self._prev_rsi[symbol] = current_rsi
            if (
                prev_rsi is not None
                and prev_rsi > cfg.rsi_overbought
                and current_rsi < cfg.rsi_reversal
            ):
                return _exit(
                    "rsi_reversal",
                    prev_rsi=round(prev_rsi, 4),
                    current_rsi=round(current_rsi, 4),
                )

        # 5. Momentum deceleration: the last three recorded ROCs are strictly
        #    decreasing. The history grows by one entry per call, not per
        #    candle. NOTE(review): if this is called several times per candle
        #    while the last candle is still forming, three falling ticks within
        #    that candle are enough to fire -- confirm that is intended.
        if len(closes_1h) >= 2:
            roc = self._roc(closes_1h[-1], closes_1h[-2])
            hist = self._roc_history.setdefault(symbol, deque(maxlen=3))
            hist.append(roc)
            if len(hist) == 3 and hist[0] > hist[1] > hist[2]:
                return _exit(
                    "momentum_decel",
                    roc_history=[round(r, 6) for r in hist],
                )

        return None
