"""
modules/drive.py
----------------
Utilidades de SOLO LECTURA para Google Drive usando Service Account.

Objetivo:
- Minimizar RAM y latencia (cache local + indexación de carpeta + filtros pushdown en Parquet)
- Contratos consistentes (bytes in / DataFrame out)
- Soportar binarios (parquet/xlsx/csv) y Google Sheets (export -> xlsx)

Requisitos:
- google-api-python-client
- google-auth
- pandas
- openpyxl
- pyarrow
"""

from __future__ import annotations

import io
import os
import re
import json
import hashlib
import unicodedata
from pathlib import Path
from typing import Any, Iterator
import pandas as pd
import streamlit as st
import pyarrow as pa
import pyarrow.parquet as pq
from google.oauth2 import service_account
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload
from modules.config import get_setting

# =============================================================================
# Constantes
# =============================================================================
# Read-only scope is enough: this module only reads metadata and file contents.
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]

# Google Sheets must be exported (to XLSX); any other file is downloaded as-is.
GOOGLE_SHEET_MIME = "application/vnd.google-apps.spreadsheet"
XLSX_MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

# Local on-disk cache for downloaded/exported bytes; created at import time.
CACHE_DIR = Path("/tmp") / "cabanna_drive_cache"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Drive IDs read from the project settings.
FOLDER_ID_FOLIOS = get_setting("FOLDER_ID_FOLIOS")
FOLDER_ID_LINEAS = get_setting("FOLDER_ID_LINEAS")
FILE_ID = get_setting("FILE_ID")


# =============================================================================
# Auth / Service
# =============================================================================
def _resolve_sa_path(path_str: str) -> str:
    p = Path(path_str)
    if not p.is_absolute():
        ROOT = Path(__file__).resolve().parents[1]  # Cabanna/
        p = (ROOT / p).resolve()
    return str(p)


def _looks_like_json(s: str) -> bool:
    s = (s or "").strip()
    return s.startswith("{") and s.endswith("}")


def _normalize_private_key(info: dict) -> dict:
    info = dict(info)
    pk = info.get("private_key")
    if isinstance(pk, str):
        if "\\n" in pk:
            pk = pk.replace("\\n", "\n")
        if not pk.endswith("\n"):
            pk += "\n"
        info["private_key"] = pk
    return info


def _load_sa_credentials_from_env():
    """
    Build read-only service-account credentials (``scopes=SCOPES``).

    Sources, in priority order:
    1) st.secrets["gcp_service_account"] (dict-like table)
    2) ENV GDRIVE_SA_JSON containing the JSON inline
    3) ENV GDRIVE_SA_JSON containing a path to the .json file (relative paths
       are resolved against the project root by _resolve_sa_path)

    Raises:
        FileNotFoundError: no source is configured, or the JSON path does not exist.
        json.JSONDecodeError: GDRIVE_SA_JSON looks like JSON but does not parse.
    """
    secrets_error: Exception | None = None
    try:
        has_secret = "gcp_service_account" in st.secrets
    except Exception as err:
        # NOTE(review): without a secrets.toml (typical local run) Streamlit may
        # raise on any st.secrets access instead of answering False. Treat that as
        # "no secret" so the ENV fallbacks below stay reachable; the original error
        # is chained onto the final FileNotFoundError if nothing else is configured.
        secrets_error = err
        has_secret = False

    if has_secret:
        info = _normalize_private_key(dict(st.secrets["gcp_service_account"]))
        return service_account.Credentials.from_service_account_info(info, scopes=SCOPES)

    raw = os.getenv("GDRIVE_SA_JSON")
    if not raw:
        raise FileNotFoundError(
            "Falta credencial. Define:\n"
            "- Streamlit Secrets: [gcp_service_account]\n"
            "o\n"
            "- ENV: GDRIVE_SA_JSON (JSON completo o ruta al .json)"
        ) from secrets_error

    raw = raw.strip().replace("\r\n", "\n")

    # Inline JSON takes precedence over interpreting the value as a path.
    if _looks_like_json(raw):
        info = _normalize_private_key(json.loads(raw))
        return service_account.Credentials.from_service_account_info(info, scopes=SCOPES)

    json_path = _resolve_sa_path(raw)
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"No encuentro el JSON del Service Account en: {json_path}")

    return service_account.Credentials.from_service_account_file(json_path, scopes=SCOPES)


@st.cache_resource
def get_drive_service_cached():
    """
    Drive v3 client built from the service-account credentials.

    ``st.cache_resource`` keeps one instance that every caller reuses;
    ``cache_discovery=False`` disables googleapiclient's discovery-document cache.
    """
    credentials = _load_sa_credentials_from_env()
    service = build("drive", "v3", credentials=credentials, cache_discovery=False)
    return service


def get_drive_service():
    """Public accessor for the shared (cached) Drive v3 client."""
    service = get_drive_service_cached()
    return service


# =============================================================================
# Metadata / Indexación
# =============================================================================
@st.cache_data(ttl=600, show_spinner=False)
def get_drive_meta(file_id: str) -> dict[str, Any]:
    """
    Fetch the minimal metadata needed for caching and branching.

    Returns the Drive ``files.get`` payload restricted to
    id, name, modifiedTime, mimeType and size (shared drives supported).
    Cached by Streamlit for 10 minutes per file_id.
    """
    files_api = get_drive_service().files()
    request = files_api.get(
        fileId=file_id,
        fields="id,name,modifiedTime,mimeType,size",
        supportsAllDrives=True,
    )
    return request.execute()


@st.cache_data(ttl=600, show_spinner=False)
def index_folder_files(folder_id: str) -> dict[str, dict]:
    """
    Index a Drive folder once: file name -> {id, modifiedTime, size, mimeType}.

    A single paginated ``files.list`` replaces N per-file lookups. Only direct,
    non-trashed children of *folder_id* are included (shared drives supported).

    Drive allows several files with the same name in one folder. In that case
    the most recently modified one wins, so the result does not depend on the
    (unspecified) listing order.
    """
    service = get_drive_service()
    # Escape backslashes and single quotes as required by the Drive query language.
    safe_folder_id = str(folder_id).replace("\\", "\\\\").replace("'", "\\'")
    q = f"'{safe_folder_id}' in parents and trashed=false"

    out: dict[str, dict] = {}
    page_token = None

    while True:
        res = (
            service.files()
            .list(
                q=q,
                fields="nextPageToken, files(id,name,modifiedTime,size,mimeType)",
                pageSize=1000,
                pageToken=page_token,
                supportsAllDrives=True,
                includeItemsFromAllDrives=True,
            )
            .execute()
        )

        for f in res.get("files", []):
            rec = {
                "id": f["id"],
                "modifiedTime": f.get("modifiedTime"),
                "size": f.get("size"),
                "mimeType": f.get("mimeType"),
            }
            prev = out.get(f["name"])
            # Drive returns modifiedTime as RFC 3339 UTC strings with a fixed
            # layout, so string comparison follows chronological order.
            if prev is None or (rec["modifiedTime"] or "") >= (prev["modifiedTime"] or ""):
                out[f["name"]] = rec

        page_token = res.get("nextPageToken")
        if not page_token:
            break

    return out


def find_file_id_in_folder(folder_id: str, filename: str) -> str | None:
    """Return the Drive ID of *filename* inside *folder_id*, or None if absent (uses the folder index)."""
    record = index_folder_files(folder_id).get(filename)
    if not record:
        return None
    return record["id"]


# =============================================================================
# Cache local de bytes
# =============================================================================
def _cache_path_for(file_id: str, filename: str) -> Path:
    """Cache file for a (file_id, filename) pair: ``CACHE_DIR/<md5("file_id|filename")>.bin``."""
    return _cache_path_for_generic(f"{file_id}|{filename}")


def _cache_path_for_generic(key: str) -> Path:
    """Map an arbitrary cache key to ``CACHE_DIR/<md5(key)>.bin`` (MD5 only used as a filename hash)."""
    digest = hashlib.md5(key.encode("utf-8")).hexdigest()
    return CACHE_DIR.joinpath(digest + ".bin")


def _read_cached_bytes(p: Path, meta_p: Path, modifiedTime: str | None) -> bytes | None:
    """
    Return the bytes cached at *p* if the sidecar *meta_p* holds exactly
    *modifiedTime*; otherwise None (cache miss).

    Without a modifiedTime there is nothing to validate against, so it is a miss.
    Missing, unreadable or undecodable cache files are also treated as a miss.
    """
    if not modifiedTime:
        return None
    try:
        if meta_p.read_text().strip() != modifiedTime:
            return None
        return p.read_bytes()
    except (OSError, ValueError):
        return None


def _write_cached_bytes(p: Path, meta_p: Path, data: bytes, modifiedTime: str | None) -> None:
    """
    Best-effort store of *data* at *p* plus its *modifiedTime* sidecar.

    - The old sidecar is deleted first, so stale metadata can never end up
      paired with new bytes (e.g. when modifiedTime is None or its write fails).
    - The payload is written to a temp file in the same directory and moved into
      place with os.replace (atomic on POSIX), so concurrent readers get either
      the previous file or the complete new one, never a half-written file.
    Filesystem errors are swallowed on purpose: the cache is only an optimization.
    """
    import tempfile

    tmp_name: str | None = None
    try:
        meta_p.unlink(missing_ok=True)
        fd, tmp_name = tempfile.mkstemp(dir=p.parent, prefix=f"{p.stem}.", suffix=".tmp")
        with os.fdopen(fd, "wb") as fh:
            fh.write(data)
        os.replace(tmp_name, p)
        tmp_name = None  # now owned by p; nothing to clean up
        if modifiedTime:
            meta_p.write_text(modifiedTime)
    except OSError:
        pass
    finally:
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass


def _download_media_bytes(request) -> bytes:
    """Run a Drive media request (get_media / export_media) to completion in 8 MiB chunks."""
    fh = io.BytesIO()
    downloader = MediaIoBaseDownload(fh, request, chunksize=8 * 1024 * 1024)
    done = False
    while not done:
        _, done = downloader.next_chunk()
    return fh.getvalue()


def download_file_bytes(file_id: str, filename: str, modifiedTime: str | None) -> bytes:
    """
    Download a binary Drive file, using the local cache keyed by modifiedTime.

    IMPORTANT: only for files downloadable with get_media (not Google Sheets/Docs;
    use export_sheet_xlsx_bytes for Sheets).

    Args:
        file_id: Drive file ID.
        filename: file name, only used to build the cache key.
        modifiedTime: Drive modifiedTime; when None the cache is never read.

    Returns:
        The file contents.
    """
    p = _cache_path_for(file_id, filename)
    meta_p = p.with_suffix(".meta")

    cached = _read_cached_bytes(p, meta_p, modifiedTime)
    if cached is not None:
        return cached

    # supportsAllDrives for consistency with get_drive_meta (shared-drive files).
    request = get_drive_service().files().get_media(fileId=file_id, supportsAllDrives=True)
    data = _download_media_bytes(request)
    _write_cached_bytes(p, meta_p, data, modifiedTime)
    return data


def export_sheet_xlsx_bytes(file_id: str, filename: str, modifiedTime: str | None) -> bytes:
    """
    Export a Google Sheet to XLSX bytes, using the local cache keyed by modifiedTime.

    The cache key is prefixed with ``export_xlsx|`` so it never collides with
    the cache entry of a direct download of the same file.
    """
    p = _cache_path_for_generic(f"export_xlsx|{file_id}|{filename}")
    meta_p = p.with_suffix(".meta")

    cached = _read_cached_bytes(p, meta_p, modifiedTime)
    if cached is not None:
        return cached

    request = get_drive_service().files().export_media(fileId=file_id, mimeType=XLSX_MIME)
    data = _download_media_bytes(request)
    _write_cached_bytes(p, meta_p, data, modifiedTime)
    return data

# =============================================================================
# Lecturas específicas
# =============================================================================
def read_excel_from_drive(file_id: str, sheet_name: str | int = 0) -> pd.DataFrame:
    """
    Read an Excel workbook stored in Drive into a DataFrame.

    Google Sheets are exported to XLSX first; any other file is downloaded
    as-is (expected to be an XLSX binary). Both paths use the local
    modifiedTime cache. *sheet_name* is passed straight to pd.read_excel.
    """
    meta = get_drive_meta(file_id)
    name = meta.get("name", file_id)
    modified = meta.get("modifiedTime")

    fetch = (
        export_sheet_xlsx_bytes
        if meta.get("mimeType") == GOOGLE_SHEET_MIME
        else download_file_bytes
    )
    raw = fetch(file_id, name, modified)
    return pd.read_excel(io.BytesIO(raw), sheet_name=sheet_name)

# =============================================================================
# Transformación: Comensales (Excel multi-hoja)
# =============================================================================
MESES_MAP = {
    "ENERO": 1, "FEBRERO": 2, "MARZO": 3, "ABRIL": 4, "MAYO": 5, "JUNIO": 6,
    "JULIO": 7, "AGOSTO": 8, "SEPTIEMBRE": 9, "OCTUBRE": 10, "NOVIEMBRE": 11, "DICIEMBRE": 12,
    "ENE": 1, "FEB": 2, "MAR": 3, "ABR": 4, "MAY": 5, "JUN": 6,
    "JUL": 7, "AGO": 8, "SEP": 9, "SEPT": 9, "OCT": 10, "NOV": 11, "DIC": 12,
}
_SHEET_YM_RE = re.compile(r"^\s*(\d{4})\s*[-_/ ]\s*([A-Za-zÁÉÍÓÚÜÑáéíóúüñ\.]+)\s*$", re.UNICODE)

def _strip_accents_upper(s: str) -> str:
    s = str(s).strip().upper().replace(".", "")
    s = "".join(c for c in unicodedata.normalize("NFKD", s) if not unicodedata.combining(c))
    return s

def _parse_sheet_year_month(sheet_name: str) -> tuple[int | None, int | None]:
    """
    Parse a sheet name into (year, month).

    "2025-Enero" -> (2025, 1). A bare month name ("ENERO") gives (None, 1).
    Anything unrecognized gives a None month.
    """
    key = _strip_accents_upper(sheet_name)
    match = _SHEET_YM_RE.match(key)
    if match is None:
        # Compat: month name without a year.
        return None, MESES_MAP.get(key)

    year_txt, month_txt = match.groups()
    return int(year_txt), MESES_MAP.get(_strip_accents_upper(month_txt))

def _detect_sucursal_col(df: pd.DataFrame) -> str | None:
    for c in df.columns:
        if str(c).strip().upper() == "SUCURSAL":
            return c
    return None

def _detect_day_cols(df: pd.DataFrame) -> list[Any]:
    day_cols: list[Any] = []
    for c in df.columns:
        s = str(c).strip()
        if s.isdigit():
            d = int(s)
            if 1 <= d <= 31:
                day_cols.append(c)
    return day_cols

def _comensales_sheet_to_long(df: pd.DataFrame, year: int, mes_num: int) -> pd.DataFrame | None:
    """
    Turn one wide sheet (SUCURSAL | 1 | 2 | ... | 31) into long rows
    fecha | NombreSucursal | comensales, or None if the sheet lacks those columns.

    Rows without a branch name, non-numeric counts and impossible dates
    (e.g. day 31 of a 30-day month, coerced to NaT) are dropped.
    """
    col_suc = _detect_sucursal_col(df)
    day_cols = _detect_day_cols(df)
    if col_suc is None or not day_cols:
        return None

    wide = df[[col_suc] + day_cols].rename(columns={col_suc: "NombreSucursal"})
    # Blank SUCURSAL cells are NaN; astype(str) would turn them into the literal "nan".
    wide = wide[wide["NombreSucursal"].notna()]
    wide = wide.assign(NombreSucursal=wide["NombreSucursal"].astype(str).str.strip())
    wide = wide[wide["NombreSucursal"] != ""]

    long_df = wide.melt(
        id_vars=["NombreSucursal"],
        value_vars=day_cols,
        var_name="dia",
        value_name="comensales",
    )
    long_df["dia"] = long_df["dia"].astype(str).str.strip().astype(int)
    long_df["comensales"] = pd.to_numeric(long_df["comensales"], errors="coerce")
    long_df["fecha"] = pd.to_datetime(
        {"year": year, "month": mes_num, "day": long_df["dia"]},
        errors="coerce",
    )

    long_df = long_df.dropna(subset=["fecha", "comensales"])
    long_df = long_df.assign(comensales=long_df["comensales"].round(0).astype(int))
    return long_df[["fecha", "NombreSucursal", "comensales"]]


def transformar_comensales_excel(excel_bytes_or_path, default_year: int | None = None) -> pd.DataFrame:
    """
    Convert the multi-sheet comensales workbook into a long DataFrame.

    Supported sheet names:
      - "2025-Enero", "2025 Enero", "2025_Ene", ... (year + month; case, accents
        and dots are ignored).
      - (compat) "ENERO" ... "DICIEMBRE" without a year. A date needs a year, so
        these sheets are used only when *default_year* is given AND no sheet in
        the workbook has an explicit year (never guess between mixed layouts).
        With the default (None) they are always skipped.

    Expected layout per sheet:
      SUCURSAL | 1 | 2 | ... | 31

    Args:
        excel_bytes_or_path: anything pd.ExcelFile accepts (path or file-like, e.g. BytesIO).
        default_year: year applied to year-less month sheets (see above).

    Returns:
        DataFrame with columns fecha | NombreSucursal | comensales, sorted by
        (fecha, NombreSucursal); empty with those columns if nothing is usable.
    """
    frames: list[pd.DataFrame] = []

    # Context manager releases the workbook handle once all sheets are read.
    with pd.ExcelFile(excel_bytes_or_path) as xls:
        parsed = []
        for sheet in xls.sheet_names:
            year, mes_num = _parse_sheet_year_month(sheet)
            if mes_num is not None:
                parsed.append((sheet, year, mes_num))

        has_explicit_years = any(y is not None for _, y, _ in parsed)

        for sheet, year, mes_num in parsed:
            if year is None:
                # Decide before reading the sheet: without a year no date can be built.
                if has_explicit_years or default_year is None:
                    continue
                year = default_year

            long_df = _comensales_sheet_to_long(pd.read_excel(xls, sheet_name=sheet), year, mes_num)
            if long_df is not None:
                frames.append(long_df)

    if not frames:
        return pd.DataFrame(columns=["fecha", "NombreSucursal", "comensales"])

    out = pd.concat(frames, ignore_index=True)
    out["NombreSucursal"] = out["NombreSucursal"].str.replace(r"\s+", " ", regex=True).str.strip()
    return out.sort_values(["fecha", "NombreSucursal"]).reset_index(drop=True)

@st.cache_data(show_spinner="Cargando comensales desde Google Drive...", ttl=600)
def load_comensales(file_id: str) -> pd.DataFrame:
    """
    Load the comensales workbook from Drive and return it normalized.

    Output columns:
      Fecha (datetime normalized to midnight) | NombreSucursal (category,
      upper-case canonical branch names, TOTAL rows removed) | comensales (numeric).

    Google Sheets are exported to XLSX and other files downloaded directly; both
    go through the local modifiedTime cache. ttl=600 (same as the Drive metadata
    cache) lets edits made in Drive show up; without a ttl the result would be
    kept for the whole process lifetime.
    """
    meta = get_drive_meta(file_id)
    filename = meta.get("name", file_id)
    modifiedTime = meta.get("modifiedTime")

    if meta.get("mimeType") == GOOGLE_SHEET_MIME:
        content = export_sheet_xlsx_bytes(file_id, filename, modifiedTime)
    else:
        content = download_file_bytes(file_id, filename, modifiedTime)

    df = transformar_comensales_excel(io.BytesIO(content))
    if df.empty:
        return df

    # Normalize the date column name to "Fecha".
    if "fecha" in df.columns and "Fecha" not in df.columns:
        df = df.rename(columns={"fecha": "Fecha"})

    df["Fecha"] = pd.to_datetime(df["Fecha"], errors="coerce").dt.normalize()

    # Sheet labels -> canonical branch codes; unmapped names are just upper-cased.
    map_suc = {
        "Cd. Juárez": "CD JUAREZ",
        "Culiacán": "CULIACAN",
        "Guadalajara - Av. México": "AV MEXICO",
        "Guadalajara - Gourmetería": "GOURMETERIA",
        "Mexicali": "MEXICALI",
        "Monterrey": "METROPOLITAN",
        "Polanco": "POLANCO",
        "Puebla": "PUEBLA",
        "Tijuana": "TIJUANA",
    }

    names = (
        df["NombreSucursal"].astype("string")
        .str.replace(r"\s+", " ", regex=True)
        .str.strip()
        .replace(map_suc)
        .str.upper()
    )

    # Drop aggregate rows AFTER normalization so "Total" / " total " are caught
    # too and "TOTAL" never becomes a category. Missing names are kept.
    keep = names.ne("TOTAL").fillna(True).astype(bool)
    df = df.loc[keep].copy()
    df["NombreSucursal"] = names.loc[keep].astype("category")

    df["comensales"] = pd.to_numeric(df["comensales"], errors="coerce")
    return df

# =============================================================================
# Folios (Parquet mensual) - carga rápida
# =============================================================================
def _months_between(start: pd.Timestamp, end: pd.Timestamp) -> Iterator[pd.Timestamp]:
    cur = start.normalize().replace(day=1)
    end = pd.to_datetime(end)
    while cur < end:
        yield cur
        cur = (cur + pd.offsets.MonthBegin(1)).normalize()

@st.cache_data(show_spinner="Cargando datos de ventas...", ttl=600, max_entries=8)
def load_folios(ini_q, fin_q) -> pd.DataFrame:
    """
    Load folios rows with ini_q <= Fecha < fin_q from the monthly parquet files
    ``folios_YYYY_MM.parquet`` in the FOLDER_ID_FOLIOS Drive folder.

    Per month: download (local modifiedTime cache), then pyarrow read with the
    Fecha filter pushed down. If that read raises, fall back to pandas and
    filter in memory. Months without a file are skipped.

    Returns:
        The concatenated rows (missing columns filled with nulls), or an empty
        DataFrame when nothing matches.

    Raises:
        RuntimeError: FOLDER_ID_FOLIOS is not configured.
    """
    start_dt = pd.to_datetime(ini_q)
    end_dt = pd.to_datetime(fin_q)

    if not FOLDER_ID_FOLIOS:
        raise RuntimeError("Falta FOLDER_ID_FOLIOS (ID de la carpeta 'folios' en Drive).")

    idx = index_folder_files(FOLDER_ID_FOLIOS)

    filters = [
        ("Fecha", ">=", start_dt.to_pydatetime()),
        ("Fecha", "<", end_dt.to_pydatetime()),
    ]

    tables: list[pa.Table] = []

    for m in _months_between(start_dt, end_dt):
        fname = f"folios_{m.strftime('%Y_%m')}.parquet"
        rec = idx.get(fname)
        if not rec:
            continue

        content = download_file_bytes(rec["id"], fname, rec.get("modifiedTime"))
        if not content:
            continue

        try:
            tbl = pq.read_table(io.BytesIO(content), filters=filters, use_threads=True)
        except Exception:
            # Fallback: pandas filters in memory (slower, but tolerant).
            df_m = pd.read_parquet(io.BytesIO(content))
            if "Fecha" in df_m.columns:
                df_m["Fecha"] = pd.to_datetime(df_m["Fecha"], errors="coerce")
                df_m = df_m[(df_m["Fecha"] >= start_dt) & (df_m["Fecha"] < end_dt)]
            if not df_m.empty:
                tables.append(pa.Table.from_pandas(df_m, preserve_index=False))
            continue

        if tbl.num_rows:
            tables.append(tbl)

    if not tables:
        return pd.DataFrame()

    # pyarrow >= 14 replaced the (deprecated) promote=True flag with promote_options.
    if int(pa.__version__.split(".")[0]) >= 14:
        out_tbl = pa.concat_tables(tables, promote_options="default")
    else:
        out_tbl = pa.concat_tables(tables, promote=True)
    # Release the per-month tables so self_destruct can actually free Arrow
    # buffers during conversion (it only helps when nothing else references them).
    del tables
    df = out_tbl.to_pandas(self_destruct=True, split_blocks=True)
    del out_tbl

    if "Fecha" in df.columns:
        df["Fecha"] = pd.to_datetime(df["Fecha"], errors="coerce")

    return df

# =============================================================================
# Líneas (Parquet mensual) - carga rápida
# =============================================================================
# `_months_between` (month iterator used by load_lineas) is defined once, in the
# Folios section above; it is intentionally not redefined here.


@st.cache_data(show_spinner="Indexando parquets de líneas en Drive...", ttl=600, max_entries=8)
def _lineas_index() -> dict[str, dict]:
    """Cached index of the FOLDER_ID_LINEAS folder: file name -> {id, modifiedTime, size, mimeType}."""
    folder = FOLDER_ID_LINEAS
    return index_folder_files(folder)


@st.cache_data(show_spinner="Cargando líneas desde Drive...", ttl=600, max_entries=8)
def load_lineas(ini_q, fin_q) -> pd.DataFrame:
    """
    Load líneas rows with ini_q <= Fecha < fin_q from the monthly parquet files
    ``lineas_YYYY_MM.parquet`` in the FOLDER_ID_LINEAS Drive folder.

    Per month: download (local modifiedTime cache), then pyarrow read with the
    Fecha filter pushed down. If that read raises, fall back to pandas and
    filter in memory. Months without a file are skipped.

    Cache settings mirror load_folios (ttl=600, max_entries=8), so each distinct
    date range does not stay in memory forever.

    Returns:
        The concatenated rows (missing columns filled with nulls), or an empty
        DataFrame when nothing matches.

    Raises:
        RuntimeError: FOLDER_ID_LINEAS is not configured.
    """
    start_dt = pd.to_datetime(ini_q)
    end_dt = pd.to_datetime(fin_q)

    if not FOLDER_ID_LINEAS:
        raise RuntimeError("Falta FOLDER_ID_LINEAS (ID de la carpeta de líneas en Drive).")

    idx = _lineas_index()

    # Pushdown filter (effective when Fecha is stored as a timestamp in the parquet).
    filters = [
        ("Fecha", ">=", start_dt.to_pydatetime()),
        ("Fecha", "<", end_dt.to_pydatetime()),
    ]

    tables: list[pa.Table] = []

    for m in _months_between(start_dt, end_dt):
        fname = f"lineas_{m.strftime('%Y_%m')}.parquet"
        rec = idx.get(fname)
        if not rec:
            continue

        content = download_file_bytes(rec["id"], fname, rec.get("modifiedTime"))
        if not content:
            continue

        # Attempt 1: Arrow pushdown (fast and RAM-friendly).
        try:
            tbl = pq.read_table(io.BytesIO(content), filters=filters, use_threads=True)
        except Exception:
            tbl = None  # fall back to pandas below

        if tbl is not None:
            if tbl.num_rows:
                tables.append(tbl)
            continue

        # Fallback: pandas filters in memory (slower, but tolerant).
        df_m = pd.read_parquet(io.BytesIO(content))
        if "Fecha" in df_m.columns:
            df_m["Fecha"] = pd.to_datetime(df_m["Fecha"], errors="coerce")
            df_m = df_m[(df_m["Fecha"] >= start_dt) & (df_m["Fecha"] < end_dt)]
        if not df_m.empty:
            tables.append(pa.Table.from_pandas(df_m, preserve_index=False))

    if not tables:
        return pd.DataFrame()

    # pyarrow >= 14 replaced the (deprecated) promote=True flag with promote_options.
    if int(pa.__version__.split(".")[0]) >= 14:
        out_tbl = pa.concat_tables(tables, promote_options="default")
    else:
        out_tbl = pa.concat_tables(tables, promote=True)
    # Release the per-month tables so self_destruct can actually free Arrow
    # buffers during conversion (reduces the RAM peak of to_pandas).
    del tables
    out = out_tbl.to_pandas(self_destruct=True, split_blocks=True)
    del out_tbl

    if "Fecha" in out.columns:
        out["Fecha"] = pd.to_datetime(out["Fecha"], errors="coerce")

    return out
