"""Capital pool management for micro-scanner.

Architecture Design §2.3, FR-05 fix.

Micro capital is isolated from core/pump capital at startup.
Only draws from MICRO_CAPITAL_POOL, never from other tiers.

Reads portfolio value from the main bot's ``portfolio_snapshots`` table
(READ-ONLY access — never writes to main bot tables).
"""

from __future__ import annotations

import math
import psycopg
from decimal import Decimal
from typing import TYPE_CHECKING, Any, cast

import structlog

if TYPE_CHECKING:
    from src.micro_scanner.config import MicroConfig

# Module-level structured logger used by CapitalManager below.
log = structlog.get_logger(__name__)

# Minimum cash reserve across the whole portfolio (20%).
# NOTE(review): not referenced by CapitalManager in this module; confirm
# whether this reserve is meant to be enforced when sizing positions.
_CASH_RESERVE_PCT = Decimal("0.20")


class CapitalManager:
    """Manage the micro-scanner capital pool.

    At construction, reserves ``capital_pct`` percent of the current portfolio
    as the micro pool. All position sizing draws from this isolated pool.
    """

    def __init__(self, config: "MicroConfig", conn: psycopg.Connection) -> None:
        self._config = config
        self._conn = conn
        self._portfolio_value: Decimal = Decimal("0")
        self._micro_pool: Decimal = Decimal("0")

    # ------------------------------------------------------------------
    # Startup
    # ------------------------------------------------------------------

    def initialise(self) -> None:
        """Read portfolio value from DB and reserve the micro pool.

        Must be called once during startup, after DB connection is available.
        Falls back to zero if the table doesn't exist (draft / fresh DB).
        """
        self._portfolio_value = self._read_portfolio_value()
        self._micro_pool = self._portfolio_value * Decimal(
            str(self._config.capital_pct / 100.0)
        )
        log.info(
            "capital_pool_initialised",
            portfolio_usd=str(self._portfolio_value),
            micro_pool_usd=str(self._micro_pool),
            capital_pct=self._config.capital_pct,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def available_cash(self) -> Decimal:
        """Return micro pool minus the notional value of open positions.

        Queries live open micro_positions for up-to-date exposure.
        """
        open_exposure = self._open_position_exposure()
        available = self._micro_pool - open_exposure
        # Never return negative
        return max(available, Decimal("0"))

    def calculate_position_size(self, score: float) -> Decimal:
        """Return position size in USDT for the given composite score.

        - score >= config.score_full_position (0.70): use base_position_pct (5%)
        - score >= config.score_entry_threshold (0.55): use half_position_pct (2.5%)
        - below entry threshold: return 0 (should not be called in this case)

        Always floors to config.min_notional_usd ($5).
        """
        avail = self.available_cash()

        if score >= self._config.score_full_position:
            pct = Decimal(str(self._config.base_position_pct / 100.0))
        elif score >= self._config.score_entry_threshold:
            pct = Decimal(str(self._config.half_position_pct / 100.0))
        else:
            log.debug("position_size_below_threshold", score=score)
            return Decimal("0")

        size = avail * pct
        floor = Decimal(str(self._config.min_notional_usd))
        result = max(size, floor)

        log.debug(
            "position_size_calculated",
            score=score,
            available_cash=str(avail),
            position_pct=str(pct),
            size_usd=str(result),
        )
        return result

    def max_positions(self) -> int:
        """Return the maximum number of simultaneous micro positions.

        Architecture Design §2.3: floor(portfolio / $15), capped at max_simultaneous.
        """
        if self._portfolio_value <= 0:
            return 0
        portfolio_float = float(self._portfolio_value)
        dynamic = math.floor(portfolio_float / 15.0)
        return min(dynamic, self._config.max_simultaneous)

    def refresh_portfolio(self) -> None:
        """Re-read portfolio value from DB.  Call periodically (e.g., each scan cycle)."""
        self._portfolio_value = self._read_portfolio_value()
        self._micro_pool = self._portfolio_value * Decimal(
            str(self._config.capital_pct / 100.0)
        )

    @property
    def micro_pool(self) -> Decimal:
        return self._micro_pool

    @property
    def portfolio_value(self) -> Decimal:
        return self._portfolio_value

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _read_portfolio_value(self) -> Decimal:
        """Read latest portfolio total value from portfolio_snapshots.

        Returns 0 if the table doesn't exist or is empty.
        """
        try:
            cursor = self._conn.execute(
                """
                SELECT total_value_usd
                FROM portfolio_snapshots
                ORDER BY snapshot_at DESC
                LIMIT 1
                """
            )
            row = cursor.fetchone()
            row = cast(dict[str, Any] | None, row)
            if row:
                return Decimal(str(row["total_value_usd"]))
        except Exception as exc:
            log.warning(
                "portfolio_read_error",
                error=str(exc),
                hint="portfolio_snapshots may not exist yet; using 0",
            )
        return Decimal("0")

    def _open_position_exposure(self) -> Decimal:
        """Sum the notional value of all open micro positions.

        Uses avg_fill_price * filled_qty for each open position.
        """
        try:
            cursor = self._conn.execute(
                """
                SELECT avg_fill_price, filled_qty
                FROM micro_positions
                WHERE status = 'open'
                """
            )
            total = Decimal("0")
            for row in cast(list[dict[str, Any]], cursor.fetchall()):
                price = Decimal(str(row["avg_fill_price"]))
                qty = Decimal(str(row["filled_qty"]))
                total += price * qty
            return total
        except Exception as exc:
            log.warning("open_exposure_read_error", error=str(exc))
            return Decimal("0")
