"""Position state management and exchange reconciliation."""

from __future__ import annotations

import psycopg
from datetime import datetime, timezone
from decimal import Decimal
from typing import Callable, Any, cast

import structlog

from src.config import SigilConfig
from src.contracts import AlertEvent, PositionInfo
from src.symbols import to_ccxt_symbol

# Module-level structured logger; PositionTracker binds its component name onto it.
log = structlog.get_logger(__name__)

# Alert threshold expressed in percent (1 == 1%): a position whose exchange
# balance differs from the DB quantity by strictly more than this triggers an alert.
_DISCREPANCY_ALERT_PCT = Decimal("1")


class PositionTracker:
    """Manages position state and reconciliation against the exchange."""

    def __init__(
        self,
        config: SigilConfig,
        exchange,  # ccxt exchange instance — no static type to avoid ccxt import dep
        db_conn: psycopg.Connection,
        alert_fn: Callable[[AlertEvent], None],
    ) -> None:
        self._config = config
        self._exchange = exchange
        self._db = db_conn
        self._alert = alert_fn
        self._log = log.bind(component="position_tracker")

    # ------------------------------------------------------------------
    # Reconciliation
    # ------------------------------------------------------------------

    async def reconcile_orders(self) -> None:
        """Fetch open orders from exchange and update SQLite to match."""
        if self._config.mode == "draft":
            self._log.debug("reconcile_orders_skipped", reason="draft_mode")
            return

        try:
            open_orders = await self._exchange.fetch_open_orders()
        except Exception:
            self._log.exception("reconcile_orders_exchange_fetch_failed")
            return

        exchange_client_ids: set[str] = set()

        for order in open_orders:
            client_id = order.get("clientOrderId") or order.get("id", "")
            exchange_id = str(order.get("id", ""))
            filled = Decimal(str(order.get("filled", 0) or 0))
            status = order.get("status", "open")

            exchange_client_ids.add(client_id)

            order_state = "PARTIAL_ACTIVE" if filled > 0 else "SUBMITTED"
            if status == "closed":
                order_state = "FILLED"
            elif status == "canceled":
                order_state = "CANCELLED"

            self._db.execute(
                """
                UPDATE trades
                SET order_state = %s,
                    filled_quantity = %s,
                    exchange_order_id = %s,
                    updated_at = NOW()
                WHERE client_order_id = %s AND mode = 'live'
                  AND order_state NOT IN ('FILLED', 'CANCELLED', 'REJECTED')
                """,
                (order_state, str(filled), exchange_id, client_id),
            )

        # Mark any orders we think are open but the exchange no longer has
        rows = cast(list[dict[str, Any]], self._db.execute(
            """
            SELECT client_order_id, symbol FROM trades
            WHERE mode = 'live'
              AND order_state IN ('SUBMITTED', 'PARTIAL_ACTIVE')
            """
        ).fetchall())

        for row in rows:
            cid = row["client_order_id"]
            if cid not in exchange_client_ids:
                self._log.warning(
                    "order_not_found_on_exchange",
                    client_order_id=cid,
                    symbol=row["symbol"],
                )
                # Conservatively mark as CANCELLED; downstream audit will confirm
                self._db.execute(
                    """
                    UPDATE trades SET order_state = 'CANCELLED', updated_at = NOW()
                    WHERE client_order_id = %s AND mode = 'live'
                      AND order_state IN ('SUBMITTED', 'PARTIAL_ACTIVE')
                    """,
                    (cid,),
                )

        self._db.commit()
        self._log.info(
            "reconcile_orders_complete",
            open_orders_on_exchange=len(open_orders),
        )

    async def reconcile_balances(self) -> None:
        """Fetch balances from exchange; alert on >1% discrepancy vs positions."""
        if self._config.mode == "draft":
            self._log.debug("reconcile_balances_skipped", reason="draft_mode")
            return

        try:
            balance = await self._exchange.fetch_balance()
        except Exception:
            self._log.exception("reconcile_balances_exchange_fetch_failed")
            return

        positions = self.get_all_positions("live")

        for pos in positions:
            # Derive base currency from symbol (e.g. BTCUSDT -> BTC)
            try:
                ccxt_sym = to_ccxt_symbol(pos.symbol)
                base = ccxt_sym.split("/")[0]
            except ValueError:
                continue

            exchange_qty = Decimal(str(balance.get(base, {}).get("total", 0) or 0))
            db_qty = pos.quantity

            if db_qty == 0:
                continue

            discrepancy_pct = abs(exchange_qty - db_qty) / db_qty * Decimal("100")
            if discrepancy_pct > _DISCREPANCY_ALERT_PCT:
                self._log.warning(
                    "balance_discrepancy_detected",
                    symbol=pos.symbol,
                    exchange_qty=str(exchange_qty),
                    db_qty=str(db_qty),
                    discrepancy_pct=str(discrepancy_pct),
                )
                self._alert(
                    AlertEvent(
                        event_type="balance_discrepancy",
                        symbol=pos.symbol,
                        severity="warn",
                        message=(
                            f"Balance discrepancy for {pos.symbol}: "
                            f"exchange={exchange_qty}, db={db_qty}, "
                            f"diff={discrepancy_pct:.2f}%"
                        ),
                        mode="live",
                    )
                )

        self._log.info("reconcile_balances_complete", positions_checked=len(positions))

    async def full_reconciliation(self) -> None:
        """Run order then balance reconciliation (startup sequence)."""
        self._log.info("full_reconciliation_start")
        await self.reconcile_orders()
        await self.reconcile_balances()
        self._log.info("full_reconciliation_complete")

    # ------------------------------------------------------------------
    # Position reads
    # ------------------------------------------------------------------

    def get_position(self, symbol: str, mode: str) -> PositionInfo | None:
        """Return a single open position, or None if not found."""
        row = self._db.execute(
            """
            SELECT symbol, mode, quantity, avg_entry_price,
                   strategy_tier, stop_loss_price, take_profit_price,
                   max_hold_until, exchange_order_id
            FROM positions
            WHERE symbol = %s AND mode = %s AND (quantity IS NULL OR CAST(quantity AS NUMERIC) > 0)
            """,
            (symbol, mode),
        ).fetchone()

        if row is None:
            return None

        return self._row_to_position(row)

    def get_all_positions(self, mode: str) -> list[PositionInfo]:
        """Return all open positions for the given mode."""
        rows = self._db.execute(
            """
            SELECT symbol, mode, quantity, avg_entry_price,
                   strategy_tier, stop_loss_price, take_profit_price,
                   max_hold_until, exchange_order_id
            FROM positions
            WHERE mode = %s AND CAST(quantity AS NUMERIC) > 0
            """,
            (mode,),
        ).fetchall()

        return [self._row_to_position(row) for row in rows]

    def get_total_exposure(
        self, mode: str, strategy_tier: str | None = None
    ) -> Decimal:
        """Return total exposure as a percentage of portfolio (sum of position pcts).

        Exposure is the sum of (quantity * avg_entry_price) for all open
        positions divided by total portfolio value, expressed as a percentage.
        When no portfolio snapshot is available returns 0.

        When ``strategy_tier`` is provided only positions belonging to that
        tier are included, giving per-tier exposure rather than aggregate.
        """
        if strategy_tier is not None:
            rows = self._db.execute(
                """
                SELECT symbol, mode, quantity, avg_entry_price,
                       strategy_tier, stop_loss_price, take_profit_price,
                       max_hold_until, exchange_order_id
                FROM positions
                WHERE mode = %s AND CAST(quantity AS NUMERIC) > 0
                  AND strategy_tier = %s
                """,
                (mode, strategy_tier),
            ).fetchall()
            positions = [self._row_to_position(row) for row in rows]
        else:
            positions = self.get_all_positions(mode)
        if not positions:
            return Decimal("0")

        total_position_value = sum(
            pos.quantity * pos.avg_entry_price for pos in positions
        )

        snapshot = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT total_value_usd FROM portfolio_snapshots
            WHERE mode = %s ORDER BY snapshot_at DESC LIMIT 1
            """,
            (mode,),
        ).fetchone())

        if snapshot is None:
            return Decimal("0")

        portfolio_value = Decimal(snapshot["total_value_usd"])
        if portfolio_value == 0:
            return Decimal("0")

        return (total_position_value / portfolio_value * Decimal("100")).quantize(
            Decimal("0.01")
        )

    def get_cash_balance(self, mode: str) -> Decimal:
        """Return cash balance as a percentage of total portfolio.

        Derived from the most recent portfolio snapshot.
        """
        snapshot = cast(dict[str, Any] | None, self._db.execute(
            """
            SELECT cash_balance, total_value_usd FROM portfolio_snapshots
            WHERE mode = %s ORDER BY snapshot_at DESC LIMIT 1
            """,
            (mode,),
        ).fetchone())

        if snapshot is None:
            # No snapshot — assume fully liquid (conservative for risk checks)
            return Decimal("100")

        cash = Decimal(snapshot["cash_balance"])
        total = Decimal(snapshot["total_value_usd"])

        if total == 0:
            return Decimal("100")

        return (cash / total * Decimal("100")).quantize(Decimal("0.01"))

    # ------------------------------------------------------------------
    # Position writes
    # ------------------------------------------------------------------

    def update_position(self, symbol: str, mode: str, **kwargs) -> None:
        """Update arbitrary fields on a position row.

        Accepted keyword arguments map directly to positions table columns.
        Always updates the updated_at timestamp.
        """
        allowed_cols = {
            "quantity",
            "avg_entry_price",
            "current_value",
            "unrealised_pnl",
            "strategy_tier",
            "stop_loss_price",
            "take_profit_price",
            "max_hold_until",
            "exchange_order_id",
            "last_reconciled_at",
        }
        updates = {k: v for k, v in kwargs.items() if k in allowed_cols}
        if not updates:
            return

        set_clause = ", ".join(f"{col} = %s" for col in updates)
        values: list = [
            str(v) if isinstance(v, Decimal) else v for v in updates.values()
        ]
        values.extend([symbol, mode])

        # Ensure row exists, then update
        self._db.execute(
            """
            INSERT INTO positions (symbol, mode, quantity, updated_at)
            VALUES (%s, %s, '0', NOW())
            ON CONFLICT DO NOTHING
            """,
            (symbol, mode),
        )
        query = cast(Any, f"""
            UPDATE positions SET {set_clause}, updated_at = NOW()
            WHERE symbol = %s AND mode = %s
            """)
        self._db.execute(query, values)
        self._db.commit()

        self._log.debug(
            "position_updated",
            symbol=symbol,
            mode=mode,
            fields=list(updates.keys()),
        )

    def close_position(self, symbol: str, mode: str, exit_reason: str) -> None:
        """Mark a position as closed (quantity = 0) with exit reason."""
        self._db.execute(
            """
            UPDATE positions
            SET quantity = '0',
                closed_by = %s,
                updated_at = NOW()
            WHERE symbol = %s AND mode = %s
            """,
            (exit_reason, symbol, mode),
        )
        self._db.commit()

        self._log.info(
            "position_closed",
            symbol=symbol,
            mode=mode,
            exit_reason=exit_reason,
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _row_to_position(row) -> PositionInfo:
        """Convert a ``positions`` row (mapping-style access) to ``PositionInfo``.

        Missing values are normalised: NULL/empty ``quantity`` and
        ``avg_entry_price`` become ``Decimal("0")``, a NULL/empty
        ``strategy_tier`` becomes ``"core"``, and falsy stop-loss /
        take-profit prices become ``None``.

        ``max_hold_until`` is returned as a timezone-aware UTC datetime. It
        may arrive as an ISO-8601 string or as a ``datetime``; a naive value
        is taken to already be UTC, while a value carrying an offset is
        converted to UTC so the instant is preserved. A value that cannot be
        parsed is logged and treated as "no max hold".
        """
        max_hold: datetime | None = None
        raw_hold = row["max_hold_until"]
        if raw_hold:
            parsed: datetime | None = None
            try:
                parsed = (
                    raw_hold
                    if isinstance(raw_hold, datetime)
                    else datetime.fromisoformat(raw_hold)
                )
            except (ValueError, TypeError):
                # Best-effort: keep loading the position, but make the
                # dropped deadline visible.
                log.warning(
                    "position_max_hold_unparseable",
                    symbol=row["symbol"],
                    value=repr(raw_hold),
                )
            if parsed is not None:
                if parsed.tzinfo is None:
                    # Naive timestamps are taken to be UTC.
                    max_hold = parsed.replace(tzinfo=timezone.utc)
                else:
                    # Convert rather than overwrite the offset, which would
                    # shift the deadline by the offset amount.
                    max_hold = parsed.astimezone(timezone.utc)

        stop_loss = row["stop_loss_price"]
        take_profit = row["take_profit_price"]

        return PositionInfo(
            symbol=row["symbol"],
            mode=row["mode"],
            quantity=Decimal(row["quantity"] or "0"),
            avg_entry_price=Decimal(row["avg_entry_price"] or "0"),
            strategy_tier=row["strategy_tier"] or "core",
            stop_loss_price=Decimal(stop_loss) if stop_loss else None,
            take_profit_price=Decimal(take_profit) if take_profit else None,
            max_hold_until=max_hold,
            exchange_order_id=row["exchange_order_id"],
        )
