"""PostgreSQL-native database layer for micro-scanner."""

from __future__ import annotations

import os
import threading
import time
from pathlib import Path
from typing import Any, cast

import psycopg
import structlog
from psycopg.conninfo import conninfo_to_dict, make_conninfo
from psycopg.rows import dict_row
from psycopg_pool import ConnectionPool

log = structlog.get_logger()

# Process-wide connection pool: set by init_pool(), reset to None by close_pool().
_pool: ConnectionPool | None = None

# Table names init_micro_tables() expects to find in the ``public`` schema.
_REQUIRED_MICRO_TABLES: frozenset[str] = frozenset(
    (
        "micro_cooldowns",
        "micro_coin_candidates",
        "micro_positions",
        "micro_signals",
    )
)


def init_pool(
    conninfo: str | None = None, *, min_size: int = 1, max_size: int = 2
) -> ConnectionPool:
    """Create the module-wide connection pool, store it, and return it.

    The DSN comes from *conninfo* or, when that is ``None``/empty, from the
    ``SIGIL_DATABASE_URL`` environment variable. Both URI
    (``postgresql://...``) and key/value (``host=... dbname=...``) DSNs are
    accepted. ``connect_timeout=5`` is added unless the DSN already sets
    ``connect_timeout``. Pooled connections use ``dict_row`` rows and have
    autocommit disabled.

    Calling this again replaces the stored pool. The previous pool is closed
    so its idle connections and background workers are released.

    Args:
        conninfo: Explicit connection string, overriding the environment.
        min_size: Minimum number of connections the pool keeps.
        max_size: Maximum number of connections the pool may open.

    Returns:
        The newly created, opened pool.

    Raises:
        ValueError: If no DSN is given and ``SIGIL_DATABASE_URL`` is unset
            or empty.
        psycopg.ProgrammingError: If the DSN cannot be parsed.
    """
    global _pool
    dsn = conninfo or os.environ.get("SIGIL_DATABASE_URL", "")
    if not dsn:
        raise ValueError("SIGIL_DATABASE_URL not set")
    # Parse instead of appending "?connect_timeout=5". That suffix is only
    # valid for URI DSNs and corrupts key/value ones ("dbname=x?..."). A raw
    # substring test could also be fooled by unrelated text, e.g. a password.
    if "connect_timeout" not in conninfo_to_dict(dsn):
        dsn = make_conninfo(dsn, connect_timeout=5)
    previous = _pool
    _pool = ConnectionPool(
        conninfo=dsn,
        min_size=min_size,
        max_size=max_size,
        kwargs={"row_factory": dict_row, "autocommit": False},
        open=True,
    )
    if previous is not None:
        # Re-initialisation previously dropped the old pool without closing
        # it, leaking its connections and worker threads.
        previous.close()
    log.info("micro_pg_pool_initialized", min_size=min_size, max_size=max_size)
    return _pool


def get_pool() -> ConnectionPool:
    """Return the pool previously created by init_pool().

    Raises:
        RuntimeError: If no pool is currently stored, either because
            init_pool() was never called or because close_pool() cleared it.
    """
    pool = _pool
    if pool is not None:
        return pool
    raise RuntimeError("Pool not initialised — call init_pool() first")


def get_connection() -> psycopg.Connection:
    """Check a connection out of the module-wide pool.

    Hand it back with return_connection(). Raises RuntimeError (via
    get_pool()) when no pool has been initialised.
    """
    pool = get_pool()
    return pool.getconn()


def return_connection(conn: psycopg.Connection) -> None:
    """Hand *conn* back to the module-wide pool without ever raising.

    Returning is best-effort. If the pool is not initialised, or ``putconn``
    rejects the connection, the failure is logged and *conn* is closed so it
    does not leak. Previously such failures were silently swallowed.
    """
    try:
        get_pool().putconn(conn)
    except Exception:
        log.warning("micro_pg_return_connection_failed", exc_info=True)
        try:
            conn.close()
        except Exception:
            # Closing is a last-resort cleanup; nothing useful to do if it fails.
            pass


def close_pool() -> None:
    """Close the module-wide pool, if one exists, and forget it.

    Does nothing when no pool is stored, so repeated calls are harmless.
    """
    global _pool
    if _pool is None:
        return
    _pool.close()
    _pool = None


def init_micro_tables(conn: psycopg.Connection) -> None:
    """Check that every required ``micro_*`` table exists in ``public``.

    Despite the name, no DDL is executed. If any required table is absent, a
    ``micro_tables_missing`` warning listing the missing names is logged.
    Otherwise ``micro_tables_verified`` is logged at info level.

    Args:
        conn: Open connection whose rows are accessed by column name, so it
            must use a dict row factory (pool connections from init_pool do).
    """
    # Ask only for the names we need. The old ``LIKE 'micro_%%'`` was sent
    # verbatim (psycopg does no placeholder processing without params), so
    # the ``%%`` was never un-escaped. Its bare ``_`` was also a LIKE wildcard.
    cursor = conn.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'public' AND table_name::text = ANY(%s)",
        (sorted(_REQUIRED_MICRO_TABLES),),
    )
    rows = cast(list[dict[str, Any]], cursor.fetchall())
    existing = {row["table_name"] for row in rows}
    missing = _REQUIRED_MICRO_TABLES - existing
    if missing:
        log.warning("micro_tables_missing", missing=sorted(missing))
    else:
        log.info("micro_tables_verified", tables=sorted(_REQUIRED_MICRO_TABLES))


def write_heartbeat(path: str) -> None:
    """Record the current Unix time (``time.time()``) as text in *path*.

    The value is first written to a sibling temporary file and then moved
    into place with ``os.replace``. A concurrent reader such as
    read_heartbeat() therefore sees either the previous timestamp or the new
    one, never a truncated or empty file.
    """
    target = Path(path)
    # The name is unique per process and thread, so concurrent writers never
    # share a temp file. It is a sibling of the target so the rename stays on
    # one filesystem.
    tmp = target.with_name(
        f".{target.name}.{os.getpid()}.{threading.get_ident()}.tmp"
    )
    try:
        tmp.write_text(str(time.time()))
        os.replace(tmp, target)
    finally:
        # After a successful replace the temp file is gone, so this is a
        # no-op. If the write or rename failed, it removes the stray file.
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass


def read_heartbeat(path: str) -> float | None:
    p = Path(path)
    if p.exists():
        try:
            return float(p.read_text().strip())
        except (ValueError, OSError):
            return None
    return None
