
# almha_dashboard.py
# =============================================================================
# Almha • Business Intelligence Dashboard (Prototipo)
# -----------------------------------------------------------------------------
# - App multipágina en Streamlit con datos simulados (mock) realistas
# - Secciones: Resumen Ejecutivo, Farmacias, Laboratorio, Hospital, Fertilidad
# - Librerías: streamlit, pandas, numpy, plotly_express (y graph_objects)
# - Ejecución: streamlit run almha_dashboard.py
# =============================================================================

from __future__ import annotations

import os
import base64
from dataclasses import dataclass
from datetime import datetime, timedelta, date

import numpy as np
import pandas as pd
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go

# ---------------------------------- CONFIG ----------------------------------
st.set_page_config(
    page_title="Almha • BI Dashboard",
    page_icon="🏥",
    layout="wide",
    initial_sidebar_state="expanded",
)

PRIMARY = "#2563eb"     # azul profesional
ACCENT = "#22c55e"      # verde para positivos
DANGER = "#ef4444"      # rojo para alertas
MUTED  = "#64748b"      # gris azulado
BG     = "#0b1220"      # fondo oscuro elegante
FG     = "#e5e7eb"      # texto
CARD   = "#111827"      # cartas

st.markdown(f"""
<style>
    .stApp {{
        background: radial-gradient(1200px 600px at 10% 5%, rgba(37,99,235,0.10), transparent 50%),
                    radial-gradient(1100px 600px at 90% 10%, rgba(34,197,94,0.08), transparent 50%),
                    {BG};
        color: {FG};
    }}
    .metric-card {{
        background: {CARD};
        padding: 16px 18px;
        border-radius: 16px;
        border: 1px solid rgba(255,255,255,0.06);
        box-shadow: 0 4px 18px rgba(0,0,0,0.20);
    }}
    .callout {{
        border-left: 6px solid {PRIMARY};
        padding: 12px 16px;
        background: rgba(255,255,255,0.03);
        border-radius: 8px;
    }}
    .small {{
        color: {MUTED};
        font-size: 0.88rem;
    }}
    .ok   {{ color: {ACCENT}; }}
    .bad  {{ color: {DANGER}; }}
</style>
""", unsafe_allow_html=True)


# ------------------------------- UTILIDADES ---------------------------------
@dataclass
class DateRange:
    start: pd.Timestamp
    end: pd.Timestamp

def money(x: float) -> str:
    """Formatea montos en MXN con separadores."""
    if pd.isna(x):
        return "$0"
    return "${:,.0f}".format(x)

def pct(x: float) -> str:
    if pd.isna(x):
        return "0%"
    return "{:.1f}%".format(100 * x)

def try_load_logo(encoded_path_list: list[str]) -> str | None:
    """Carga imagen (si existe) y la devuelve en base64 para HTML <img>."""
    for p in encoded_path_list:
        if os.path.exists(p):
            with open(p, "rb") as f:
                b64 = base64.b64encode(f.read()).decode("utf-8")
            return f"data:image/png;base64,{b64}"
    return None

@st.cache_data(show_spinner=False)
def generate_mock_data(seed: int = 42) -> dict[str, pd.DataFrame]:
    """
    Genera datos simulados con estacionalidad semanal/anual y variación por unidad.
    Cubre ~18 meses hasta la fecha actual.
    """
    rng = np.random.default_rng(seed)

    # Horizonte temporal
    end = pd.Timestamp(date.today())
    start = end - pd.Timedelta(days=540)  # ~18 meses
    days = pd.date_range(start, end, freq="D")
    n_days = len(days)

    # ------------------------ FARMACIAS (ventas) -------------------------
    branches = ["Centro Mazatlán", "Marina", "Roma CDMX", "GDL Andares",
                "Santa Fe", "Monterrey Valle", "Tijuana Río", "Culiacán Forum"]
    categories = ["Analgésicos", "Antibióticos", "Crónicos",
                  "Respiratorios", "Vitaminas", "Dermatológicos"]
    meds = {
        "Analgésicos": ["Paracetamol 500mg", "Ibuprofeno 400mg", "Naproxeno 550mg",
                        "Diclofenaco 50mg", "Ketorolaco 10mg", "Tramadol 50mg"],
        "Antibióticos": ["Amoxicilina 500mg", "Azitromicina 500mg", "Ciprofloxacino 500mg",
                         "Cefalexina 500mg", "Claritromicina 500mg", "Doxiciclina 100mg"],
        "Crónicos": ["Metformina 850mg", "Losartán 50mg", "Amlodipino 5mg",
                     "Atorvastatina 20mg", "Levotiroxina 50mcg", "Enalapril 10mg"],
        "Respiratorios": ["Salbutamol Inhalador", "Budesonida Inhalador",
                          "Loratadina 10mg", "Montelukast 10mg", "Ambroxol Jarabe"],
        "Vitaminas": ["Vitamina C 1g", "Complejo B", "Vitamina D 400UI",
                      "Omega 3", "Multivitamínico Adulto"],
        "Dermatológicos": ["Clotrimazol Crema", "Hidrocortisona Crema", "Ácido Hialurónico Gel",
                           "Bepanthen", "Aciclovir Crema"]
    }

    # Precio base por categoría (MXN)
    price_cat = {
        "Analgésicos": 80, "Antibióticos": 180, "Crónicos": 220,
        "Respiratorios": 160, "Vitaminas": 140, "Dermatológicos": 200
    }

    # Demanda base semanal y estacionalidad (invierno: +respiratorio)
    weekday_effect = np.array([0.9, 0.95, 1.0, 1.05, 1.15, 1.35, 1.1])  # L..D
    day_of_week = np.array([d.weekday() for d in days])
    week_mult = weekday_effect[day_of_week]

    # Efecto estacional anual (más respiratorios en nov-feb)
    month = np.array([d.month for d in days])
    resp_mult = np.where((month >= 11) | (month <= 2), 1.35, 0.95)

    ph_rows = []
    ph_hourly_rows = []
    for b in branches:
        base_trx = rng.normal(320, 40)  # transacciones diarias por sucursal
        for i, dt in enumerate(days):
            w = week_mult[i]
            # Respiratorios se ven amplificados por estacionalidad
            demand_resp = rng.poisson(lam=max(30, 35 * w * resp_mult[i]))
            demand_other = rng.poisson(lam=max(210, base_trx * w - demand_resp))

            # Repartir demanda entre categorías
            cat_weights = np.array([1.0, 0.9, 1.15, 1.0, 0.8, 0.6])
            cat_weights = cat_weights / cat_weights.sum()
            # Ajustar respiratorios
            cat_weights[3] *= resp_mult[i]
            cat_weights = cat_weights / cat_weights.sum()

            trx_total = int(demand_other + demand_resp)
            # Simular ticket promedio alrededor de $220
            avg_ticket = max(120, rng.normal(220, 25))
            revenue_day = avg_ticket * trx_total

            # Por categoría elegir medicamentos y cantidades
            for ci, c in enumerate(categories):
                # proporción por categoría
                cat_trx = int(trx_total * cat_weights[ci])
                # precio base con ruido
                price = max(40, rng.normal(price_cat[c], price_cat[c] * 0.15))
                # escoger medicamentos de la categoría
                choices = meds[c]
                probs = np.full(len(choices), 1/len(choices))
                # repartir cantidad por medicamento
                if cat_trx > 0:
                    alloc = rng.multinomial(cat_trx, probs)
                else:
                    alloc = np.zeros(len(choices), dtype=int)

                for med, qty in zip(choices, alloc):
                    if qty <= 0:
                        continue
                    rev = qty * price
                    ph_rows.append([dt, b, med, c, qty, price, rev, cat_trx])

            # --- Matriz HOURLY para mapa de calor ---
            # Curva horaria típica (pico medio día y tarde)
            hourly_profile = np.array([0.02,0.02,0.01,0.01,0.01,0.02,0.03,0.04,0.06,0.08,0.09,0.10,
                                       0.10,0.09,0.08,0.07,0.06,0.05,0.04,0.03,0.03,0.02,0.02,0.02])
            hourly_profile = hourly_profile / hourly_profile.sum()
            hourly_rev = rng.multinomial(int(revenue_day), hourly_profile)
            for hr in range(24):
                ph_hourly_rows.append([dt, b, dt.weekday(), hr, hourly_rev[hr]])

    pharmacy = pd.DataFrame(ph_rows, columns=[
        "fecha", "sucursal", "medicamento", "categoria", "cantidad", "precio_unit", "ingreso", "trx_cat"
    ])
    pharmacy_hourly = pd.DataFrame(ph_hourly_rows, columns=[
        "fecha", "sucursal", "weekday", "hora", "ingreso"
    ])

    # Generar número de transacciones por día/sucursal (para KPI ticket)
    trx_by_day = (pharmacy.groupby(["fecha","sucursal"])["cantidad"]
                  .sum().rename("trx").reset_index())
    revenue_by_day = (pharmacy.groupby(["fecha","sucursal"])["ingreso"]
                      .sum().rename("ingreso").reset_index())
    pharmacy_daily = trx_by_day.merge(revenue_by_day, on=["fecha","sucursal"], how="left")

    # --------------------------- LABORATORIO ----------------------------
    studies = ["Química Sanguínea", "Hemograma", "Perfil Tiroideo",
               "Orina Completa", "Panel Respiratorio", "Antígeno COVID",
               "Marcadores Tumorales", "Prenatal"]
    base_price = {
        "Química Sanguínea": 350, "Hemograma": 280, "Perfil Tiroideo": 550,
        "Orina Completa": 160, "Panel Respiratorio": 850, "Antígeno COVID": 350,
        "Marcadores Tumorales": 1600, "Prenatal": 900
    }
    lab_rows = []
    for stype in studies:
        for i, dt in enumerate(days):
            w = week_mult[i]
            # Volumen base por estudio
            base_mu = {
                "Química Sanguínea": 50, "Hemograma": 45, "Perfil Tiroideo": 20,
                "Orina Completa": 35, "Panel Respiratorio": 18, "Antígeno COVID": 12,
                "Marcadores Tumorales": 5, "Prenatal": 10
            }[stype]

            # Estacionalidad respiratoria
            seasonal = resp_mult[i] if stype in ("Panel Respiratorio", "Antígeno COVID") else 1.0
            count = int(max(0, rng.poisson(lam=base_mu * w * seasonal)))
            if count == 0:
                continue
            price = max(80, rng.normal(base_price[stype], base_price[stype]*0.1))
            # Tiempos de entrega (horas), con colas para algunos
            if stype in ("Marcadores Tumorales", "Perfil Tiroideo"):
                tat = np.clip(rng.normal(48, 12, size=count), 8, 120)  # más largos
            else:
                tat = np.clip(rng.normal(18, 6, size=count), 2, 72)
            revenue = price * count
            lab_rows.append([dt, stype, count, revenue, tat.mean()])

    lab = pd.DataFrame(lab_rows, columns=["fecha","estudio","n_estudios","ingreso","tat_promedio_h"])

    # ----------------------------- HOSPITAL ------------------------------
    specialties = ["Cardiología","Pediatría","Ginecología","Traumatología",
                   "Oncología","Medicina Interna","UCI"]
    # capacidad de camas por especialidad
    cap = {"Cardiología": 18, "Pediatría": 22, "Ginecología": 16,
           "Traumatología": 20, "Oncología": 14, "Medicina Interna": 28, "UCI": 12}
    price_per_adm = {"Cardiología": 18000, "Pediatría": 12000, "Ginecología": 16000,
                     "Traumatología": 15000, "Oncología": 24000, "Medicina Interna": 14000, "UCI": 32000}
    hosp_rows = []
    for spec in specialties:
        c = cap[spec]
        for i, dt in enumerate(days):
            # Ocupación con estacionalidad y algo de tendencia
            seasonal = 0.06*np.sin(2*np.pi*i/365) + 0.03*np.cos(2*np.pi*i/30)
            occ_rate = np.clip(0.72 + seasonal + rng.normal(0, 0.05), 0.40, 0.98)
            beds_used = int(np.round(c * occ_rate))
            # Ingresos/altas ~ a camas usadas + ruido
            admissions = max(0, int(rng.poisson(lam=max(2, beds_used * 0.6))))
            discharges = max(0, int(rng.poisson(lam=max(2, beds_used * 0.55))))
            los = np.clip(rng.normal(3.6, 0.8), 1.2, 9.0)  # días
            satisfaction = np.clip(86 - (los-3.6)*2 + rng.normal(0,3), 60, 98)
            revenue = admissions * price_per_adm[spec]
            hosp_rows.append([dt, spec, c, beds_used, admissions, discharges, los, satisfaction, revenue])

    hospital = pd.DataFrame(hosp_rows, columns=[
        "fecha","especialidad","camas_cap","camas_usadas","ingresos","altas","los_promedio","satisfaccion","ingreso"
    ])

    # ----------------------------- FERTILIDAD ----------------------------
    procedures = ["FIV", "Inseminación Artificial", "Preservación de Óvulos", "PGT-A"]
    age_groups = ["<30","30-34","35-39","40-44","45+"]
    fert_rows = []
    for proc in procedures:
        base_conv = {"FIV":0.45, "Inseminación Artificial":0.28, "Preservación de Óvulos":0.70, "PGT-A":0.55}[proc]
        for g in age_groups:
            # Prob de éxito por grupo
            success_mult = {"<30":1.0, "30-34":0.95, "35-39":0.75, "40-44":0.45, "45+":0.20}[g]
            for i, dt in enumerate(days):
                demand = max(0, int(rng.poisson(lam=8 if proc=="FIV" else 5)))
                consults = demand + int(rng.poisson(lam=3))
                cycles = int(np.round(consults * base_conv * rng.uniform(0.85,1.15)))
                # Éxitos según edad y procedimiento
                success_rate = np.clip(0.55 * success_mult if proc=="FIV" else 0.35 * success_mult, 0.05, 0.8)
                successes = int(np.round(cycles * success_rate * rng.uniform(0.85,1.15)))
                fert_rows.append([dt, proc, g, consults, cycles, successes])

    fertility = pd.DataFrame(fert_rows, columns=[
        "fecha","procedimiento","grupo_edad","consultas","ciclos","exitos"
    ])

    # --------------------- INGRESOS CONSOLIDADOS (por unidad) ------------------
    # Hospital ya tiene "ingreso" diario por especialidad; sumamos
    hosp_rev_daily = hospital.groupby("fecha")["ingreso"].sum().rename("Hospital")
    lab_rev_daily  = lab.groupby("fecha")["ingreso"].sum().rename("Laboratorio")
    ph_rev_daily   = pharmacy.groupby("fecha")["ingreso"].sum().rename("Farmacia")
    # Fertilidad: estimar ingreso por ciclo y consulta
    fert_rev_daily = (fertility.assign(ingreso = fertility["ciclos"]*18000 + fertility["consultas"]*800)
                      .groupby("fecha")["ingreso"].sum().rename("Fertilidad"))
    revenue_all = pd.concat([hosp_rev_daily, lab_rev_daily, ph_rev_daily, fert_rev_daily], axis=1).fillna(0)
    revenue_all["Total"] = revenue_all.sum(axis=1)

    return {
        "pharmacy": pharmacy,
        "pharmacy_daily": pharmacy_daily,
        "pharmacy_hourly": pharmacy_hourly,
        "lab": lab,
        "hospital": hospital,
        "fertility": fertility,
        "revenue": revenue_all,
        "meta": {
            "branches": branches,
            "categories": categories,
            "studies": studies,
            "specialties": specialties,
            "procedures": procedures,
            "age_groups": age_groups,
            "date_range": DateRange(start=days.min(), end=days.max())
        }
    }

def default_dates(meta: dict, months_back: int = 3) -> tuple[date,date]:
    """Rango por defecto (últimos N meses) en date objects para widgets."""
    end = meta["date_range"].end.date()
    start = (meta["date_range"].end - pd.DateOffset(months=months_back)).date()
    return start, end

def forecast_next_7_days(series: pd.Series) -> pd.DataFrame:
    """
    Proyección simple para los próximos 7 días combinando:
    - Tendencia lineal últimos 30 días (polyfit)
    - Promedio móvil últimos 7 días
    Retorna DataFrame con 'fecha', 'yhat'.
    """
    s = series.dropna().astype(float)
    s = s[s > 0]
    if len(s) < 14:
        future_dates = pd.date_range(s.index.max() + pd.Timedelta(days=1), periods=7)
        return pd.DataFrame({"fecha": future_dates, "yhat": np.repeat(s.mean() if len(s) else 0, 7)})

    tail = s.tail(30) if len(s) >= 30 else s
    x = np.arange(len(tail))
    m, b = np.polyfit(x, tail.values, 1)
    last_mean = s.tail(7).mean()
    y0 = tail.values[-1]
    future = []
    for i in range(1, 8):
        trend = b + m * (len(tail) + i - 1)
        blended = max(0, 0.55 * trend + 0.45 * last_mean)
        # Suavizado con el último valor
        blended = 0.7 * blended + 0.3 * y0
        future.append(blended)
    future_dates = pd.date_range(s.index.max() + pd.Timedelta(days=1), periods=7)
    return pd.DataFrame({"fecha": future_dates, "yhat": future})

def insight_of_the_day(data: dict) -> str:
    """
    Regla simple que busca señales destacables recientes.
    Prioridad: aumento en estudios respiratorios -> cambio en ticket farmacia -> ocupación.
    """
    # 1) Laboratorio respiratorio (últimos 7d vs previos 7d)
    lab = data["lab"]
    max_date = lab["fecha"].max()
    w1 = (lab["fecha"] > max_date - pd.Timedelta(days=7))
    w0 = (lab["fecha"] <= max_date - pd.Timedelta(days=7)) & (lab["fecha"] > max_date - pd.Timedelta(days=14))
    mask_resp = lab["estudio"].isin(["Panel Respiratorio","Antígeno COVID"])
    v1 = lab[w1 & mask_resp]["n_estudios"].sum()
    v0 = lab[w0 & mask_resp]["n_estudios"].sum()
    if v0 > 0:
        delta = (v1 - v0) / v0
        if delta >= 0.15:
            return f"Se detecta un aumento de {delta*100:.0f}% en estudios respiratorios en la última semana vs. la anterior."

    # 2) Ticket promedio farmacia
    ph = data["pharmacy_daily"]
    maxd = ph["fecha"].max()
    a = ph[ph["fecha"] > maxd - pd.Timedelta(days=7)].groupby("fecha")[["ingreso","trx"]].sum().sum()
    b = ph[(ph["fecha"] <= maxd - pd.Timedelta(days=7)) & (ph["fecha"] > maxd - pd.Timedelta(days=14))]\
          .groupby("fecha")[["ingreso","trx"]].sum().sum()
    if (a["trx"] > 0) and (b["trx"] > 0):
        t1, t0 = a["ingreso"]/a["trx"], b["ingreso"]/b["trx"]
        d = (t1 - t0) / t0 if t0 > 0 else 0
        if abs(d) >= 0.08:
            sign = "aumento" if d > 0 else "disminución"
            return f"El ticket promedio en farmacias muestra un {sign} de {abs(d)*100:.1f}% en la última semana."
    # 3) Ocupación hospitalaria promedio 7d vs 30d
    hosp = data["hospital"]
    maxh = hosp["fecha"].max()
    occ7 = (hosp[hosp["fecha"] > maxh - pd.Timedelta(days=7)]["camas_usadas"].sum() /
            hosp[hosp["fecha"] > maxh - pd.Timedelta(days=7)]["camas_cap"].sum())
    occ30 = (hosp[hosp["fecha"] > maxh - pd.Timedelta(days=30)]["camas_usadas"].sum() /
             hosp[hosp["fecha"] > maxh - pd.Timedelta(days=30)]["camas_cap"].sum())
    d = occ7 - occ30
    if abs(d) >= 0.05:
        sign = "alza" if d > 0 else "baja"
        return f"La ocupación hospitalaria registra una {sign} de {abs(d)*100:.1f}% vs. promedio 30 días."
    return "Demanda estable. Sin señales atípicas relevantes en la última semana."


# ---------------------------- CARGA DE DATOS -------------------------------
DATA = generate_mock_data(seed=123)
META = DATA["meta"]

# ------------------------------ SIDEBAR UI ---------------------------------
logo_data = try_load_logo(["assets/almha_logo.png", "almha_logo.png"])
if logo_data:
    st.sidebar.markdown(f'<img src="{logo_data}" style="width: 60%; border-radius: 12px; margin-bottom: 8px;" />', unsafe_allow_html=True)
else:
    st.sidebar.markdown("## Almha\n**El hospital que abraza tu salud**")

st.sidebar.caption("Dashboard de Business Intelligence — Prototipo")

page = st.sidebar.radio("Navegación", [
    "🏠 Resumen Ejecutivo",
    "💊 Análisis de Farmacias",
    "🧪 Análisis de Laboratorio",
    "🏥 Análisis Hospitalario",
    "👶 Análisis de Fertilidad",
])

st.sidebar.markdown("---")
st.sidebar.markdown(
    "<div class='small'>*Nota:* Datos simulados con fines demostrativos. "
    "Puedes reemplazar el generador por tus fuentes reales (SQL/CSV/APIs) y adaptar los filtros.</div>",
    unsafe_allow_html=True
)


# ---------------------------- PÁGINA: RESUMEN ------------------------------
def page_resumen(data: dict):
    st.markdown("## 🏠 Resumen Ejecutivo")
    st.markdown("Una vista 360° del estado general del hospital y sus unidades de negocio.")

    revenue = data["revenue"]
    hosp = data["hospital"]
    lab = data["lab"]
    fert = data["fertility"]

    # Último mes
    end = revenue.index.max()
    start = end - pd.Timedelta(days=30)

    # KPIs
    # Pacientes atendidos = ingresos hospitalarios + estudios de laboratorio + consultas de fertilidad (aprox.)
    patients = (
        hosp[(hosp["fecha"] > start) & (hosp["fecha"] <= end)]["ingresos"].sum() +
        lab[(lab["fecha"] > start) & (lab["fecha"] <= end)]["n_estudios"].sum() +
        fert[(fert["fecha"] > start) & (fert["fecha"] <= end)]["consultas"].sum()
    )

    # Ocupación (promedio último mes)
    occ = (
        hosp[(hosp["fecha"] > start) & (hosp["fecha"] <= end)]["camas_usadas"].sum() /
        hosp[(hosp["fecha"] > start) & (hosp["fecha"] <= end)]["camas_cap"].sum()
    )

    # Ingresos consolidados (último mes)
    total_rev_30d = revenue[(revenue.index > start) & (revenue.index <= end)]["Total"].sum()

    # Satisfacción promedio (hospital)
    sat = hosp[(hosp["fecha"] > start) & (hosp["fecha"] <= end)]["satisfaccion"].mean()

    # KPIs en 4 columnas
    c1, c2, c3, c4 = st.columns(4)
    with c1:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Pacientes atendidos (últ. 30 días)", f"{int(patients):,}")
        st.markdown("</div>", unsafe_allow_html=True)
    with c2:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ocupación hospitalaria", pct(occ))
        st.markdown("</div>", unsafe_allow_html=True)
    with c3:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ingresos consolidados", money(total_rev_30d))
        st.markdown("</div>", unsafe_allow_html=True)
    with c4:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Satisfacción promedio", f"{sat:.1f}/100")
        st.markdown("</div>", unsafe_allow_html=True)

    # Serie de ingresos por unidad
    st.markdown("### Evolución de ingresos por unidad")
    rev_long = (
        revenue.reset_index()
               .melt(id_vars="fecha", value_vars=["Hospital","Farmacia","Laboratorio","Fertilidad"],
                     var_name="Unidad", value_name="Ingreso")
    )
    fig = px.line(rev_long, x="fecha", y="Ingreso", color="Unidad", markers=False,
                  template="plotly_dark", title="Ingresos diarios por unidad")
    fig.update_layout(height=420, margin=dict(l=10,r=10,t=50,b=10))
    st.plotly_chart(fig, use_container_width=True)

    # Insight del día
    st.markdown("### 💡 Insight del día")
    st.markdown(f"<div class='callout'>{insight_of_the_day(data)}</div>", unsafe_allow_html=True)


# -------------------------- PÁGINA: FARMACIAS ------------------------------
def page_farmacias(data: dict):
    st.markdown("## 💊 Análisis de Farmacias")

    ph = data["pharmacy"]
    ph_daily = data["pharmacy_daily"]
    ph_hourly = data["pharmacy_hourly"]
    branches = data["meta"]["branches"]
    start_def, end_def = default_dates(data["meta"], months_back=3)

    c1, c2 = st.columns([2,2])
    with c1:
        dr = st.date_input("Rango de fechas", (start_def, end_def), min_value=data["meta"]["date_range"].start.date(),
                           max_value=data["meta"]["date_range"].end.date())
        start, end = pd.to_datetime(dr[0]), pd.to_datetime(dr[1])
    with c2:
        suc = st.multiselect("Sucursales", options=branches, default=branches)

    # Filtrado
    phf = ph[(ph["fecha"] >= start) & (ph["fecha"] <= end) & (ph["sucursal"].isin(suc))]
    phdf = ph_daily[(ph_daily["fecha"] >= start) & (ph_daily["fecha"] <= end) & (ph_daily["sucursal"].isin(suc))]
    phhf = ph_hourly[(ph_hourly["fecha"] >= start) & (ph_hourly["fecha"] <= end) & (ph_hourly["sucursal"].isin(suc))]

    if phf.empty:
        st.warning("No hay datos para el filtro seleccionado.")
        return

    # KPIs
    total_ingreso = phf["ingreso"].sum()
    total_trx = phdf["trx"].sum()
    ticket_prom = total_ingreso / total_trx if total_trx > 0 else np.nan

    k1, k2, k3 = st.columns(3)
    with k1:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ingresos (farmacia)", money(total_ingreso))
        st.markdown("</div>", unsafe_allow_html=True)
    with k2:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Transacciones", f"{int(total_trx):,}")
        st.markdown("</div>", unsafe_allow_html=True)
    with k3:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ticket promedio", money(ticket_prom))
        st.markdown("</div>", unsafe_allow_html=True)

    # Top 10 medicamentos
    st.markdown("### Top 10 Medicamentos más vendidos (por ingreso)")
    top10 = (phf.groupby(["medicamento"])["ingreso"].sum()
               .sort_values(ascending=False).head(10).reset_index())
    fig1 = px.bar(top10, x="medicamento", y="ingreso", template="plotly_dark")
    fig1.update_layout(height=420, xaxis_title="", yaxis_title="Ingreso (MXN)", margin=dict(l=10,r=10,t=10,b=100))
    st.plotly_chart(fig1, use_container_width=True)

    # Pie por categoría
    st.markdown("### Distribución de ventas por categoría")
    cat = (phf.groupby("categoria")["ingreso"].sum().reset_index())
    fig2 = px.pie(cat, names="categoria", values="ingreso", hole=0.45, template="plotly_dark")
    fig2.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=10))
    st.plotly_chart(fig2, use_container_width=True)

    # Heatmap día x hora
    st.markdown("### Mapa de calor — picos de ventas (día de semana × hora)")
    # 0=Lunes ... 6=Domingo
    heat = phhf.copy()
    heat["dow"] = heat["weekday"].map({0:"Lun",1:"Mar",2:"Mié",3:"Jue",4:"Vie",5:"Sáb",6:"Dom"})
    fig3 = px.density_heatmap(heat, x="hora", y="dow", z="ingreso", nbinsx=24, histfunc="sum",
                              color_continuous_scale="Blues", template="plotly_dark")
    fig3.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=10), xaxis_nticks=24)
    st.plotly_chart(fig3, use_container_width=True)

    # Forecast próximos 7 días
    st.markdown("### Proyección de demanda — próximos 7 días")
    daily_rev = (phdf.groupby("fecha")["ingreso"].sum().sort_index())
    fc = forecast_next_7_days(daily_rev)
    hist = daily_rev.reset_index().rename(columns={"ingreso":"y"})
    hist_recent = hist[hist["fecha"] >= hist["fecha"].max() - pd.Timedelta(days=120)]

    fig4 = go.Figure()
    fig4.add_trace(go.Scatter(x=hist_recent["fecha"], y=hist_recent["y"], mode="lines", name="Histórico"))
    fig4.add_trace(go.Scatter(x=fc["fecha"], y=fc["yhat"], mode="lines+markers", name="Forecast", line=dict(dash="dash")))
    fig4.update_layout(template="plotly_dark", height=420, margin=dict(l=10,r=10,t=10,b=10),
                       yaxis_title="Ingreso (MXN)")
    st.plotly_chart(fig4, use_container_width=True)


# ------------------------- PÁGINA: LABORATORIO -----------------------------
def page_laboratorio(data: dict):
    st.markdown("## 🧪 Análisis de Laboratorio")

    lab = data["lab"]
    studies = data["meta"]["studies"]
    start_def, end_def = default_dates(data["meta"], months_back=3)

    c1, c2 = st.columns([2,2])
    with c1:
        dr = st.date_input("Rango de fechas", (start_def, end_def),
                           min_value=data["meta"]["date_range"].start.date(),
                           max_value=data["meta"]["date_range"].end.date())
        start, end = pd.to_datetime(dr[0]), pd.to_datetime(dr[1])
    with c2:
        sel = st.multiselect("Tipo de estudio", options=studies, default=studies)

    lf = lab[(lab["fecha"] >= start) & (lab["fecha"] <= end) & (lab["estudio"].isin(sel))]
    if lf.empty:
        st.warning("No hay datos para el filtro seleccionado.")
        return

    # KPIs
    k1, k2, k3 = st.columns(3)
    with k1:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Estudios realizados", f"{int(lf['n_estudios'].sum()):,}")
        st.markdown("</div>", unsafe_allow_html=True)
    with k2:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Tiempo promedio de entrega", f"{lf['tat_promedio_h'].mean():.1f} h")
        st.markdown("</div>", unsafe_allow_html=True)
    with k3:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ingresos generados", money(lf["ingreso"].sum()))
        st.markdown("</div>", unsafe_allow_html=True)

    # Top estudios
    st.markdown("### Estudios más solicitados")
    top = lf.groupby("estudio")["n_estudios"].sum().sort_values(ascending=False).reset_index()
    fig1 = px.bar(top, x="estudio", y="n_estudios", template="plotly_dark")
    fig1.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=100), xaxis_title="", yaxis_title="Número de estudios")
    st.plotly_chart(fig1, use_container_width=True)

    # Histograma TAT
    st.markdown("### Distribución de tiempos de entrega (horas)")
    # Expandir por conteo aprox: repetimos tat_promedio_h según n_estudios (muestra) para demo
    # (en un caso real usaríamos observaciones a nivel estudio)
    tat_sample = np.repeat(lf["tat_promedio_h"].values, np.clip((lf["n_estudios"]/5).astype(int), 1, 40))
    fig2 = px.histogram(x=tat_sample, nbins=30, template="plotly_dark")
    fig2.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=10), xaxis_title="Horas", yaxis_title="Frecuencia")
    st.plotly_chart(fig2, use_container_width=True)

    # Serie diaria de volumen
    st.markdown("### Volumen diario de estudios")
    daily = lf.groupby("fecha")["n_estudios"].sum().reset_index()
    fig3 = px.line(daily, x="fecha", y="n_estudios", template="plotly_dark")
    fig3.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=10), yaxis_title="Estudios")
    st.plotly_chart(fig3, use_container_width=True)


# ------------------------- PÁGINA: HOSPITALARIO ----------------------------
def page_hospitalario(data: dict):
    st.markdown("## 🏥 Análisis Hospitalario")

    hosp = data["hospital"]
    specs = data["meta"]["specialties"]
    start_def, end_def = default_dates(data["meta"], months_back=3)

    c1, c2 = st.columns([2,2])
    with c1:
        dr = st.date_input("Rango de fechas", (start_def, end_def),
                           min_value=data["meta"]["date_range"].start.date(),
                           max_value=data["meta"]["date_range"].end.date())
        start, end = pd.to_datetime(dr[0]), pd.to_datetime(dr[1])
    with c2:
        sel = st.multiselect("Especialidad médica", options=specs, default=specs)

    hf = hosp[(hosp["fecha"] >= start) & (hosp["fecha"] <= end) & (hosp["especialidad"].isin(sel))]
    if hf.empty:
        st.warning("No hay datos para el filtro seleccionado.")
        return

    # KPIs básicos
    occ = hf["camas_usadas"].sum() / hf["camas_cap"].sum()
    los = hf["los_promedio"].mean()
    adm = hf["ingresos"].sum()
    dis = hf["altas"].sum()

    k1,k2,k3,k4 = st.columns(4)
    with k1:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ocupación actual (promedio)", pct(occ))
        st.markdown("</div>", unsafe_allow_html=True)
    with k2:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Estancia promedio", f"{los:.1f} días")
        st.markdown("</div>", unsafe_allow_html=True)
    with k3:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ingresos", f"{int(adm):,}")
        st.markdown("</div>", unsafe_allow_html=True)
    with k4:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Altas", f"{int(dis):,}")
        st.markdown("</div>", unsafe_allow_html=True)

    # Gauge de ocupación (último día del rango)
    last_day = hf["fecha"].max()
    hlast = hf[hf["fecha"] == last_day]
    occ_last = hlast["camas_usadas"].sum() / hlast["camas_cap"].sum()
    st.markdown("### Tasa de ocupación (último día del rango)")
    gauge = go.Figure(go.Indicator(
        mode="gauge+number",
        value=occ_last * 100,
        number={'suffix': "%"},
        gauge={
            'axis': {'range': [0,100]},
            'bar': {'color': PRIMARY},
            'steps': [
                {'range': [0, 80], 'color': "rgba(34,197,94,0.30)"},
                {'range': [80, 90], 'color': "rgba(234,179,8,0.35)"},
                {'range': [90, 100], 'color': "rgba(239,68,68,0.40)"},
            ],
            'threshold': {'line': {'color': DANGER, 'width': 4}, 'thickness': 0.75, 'value': 90}
        }
    ))
    gauge.update_layout(height=260, margin=dict(l=10,r=10,t=10,b=10), template="plotly_dark")
    st.plotly_chart(gauge, use_container_width=True)

    # Pacientes por especialidad
    st.markdown("### Pacientes atendidos por especialidad")
    by_spec = hf.groupby("especialidad")["ingresos"].sum().sort_values(ascending=False).reset_index()
    fig1 = px.bar(by_spec, x="especialidad", y="ingresos", template="plotly_dark")
    fig1.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=100), xaxis_title="", yaxis_title="Pacientes")
    st.plotly_chart(fig1, use_container_width=True)

    # Dispersión LOS vs Satisfacción (agregado por especialidad)
    st.markdown("### Estancia (días) vs. Satisfacción del paciente")
    agg = hf.groupby("especialidad").agg(los=("los_promedio","mean"),
                                         sat=("satisfaccion","mean"),
                                         n=("ingresos","sum")).reset_index()
    fig2 = px.scatter(agg, x="los", y="sat", size="n", color="especialidad", hover_name="especialidad",
                      labels={"los":"LOS promedio (días)","sat":"Satisfacción promedio"},
                      template="plotly_dark")
    fig2.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=10))
    st.plotly_chart(fig2, use_container_width=True)


# -------------------------- PÁGINA: FERTILIDAD -----------------------------
def page_fertilidad(data: dict):
    st.markdown("## 👶 Análisis de Fertilidad")

    ft = data["fertility"]
    procs = data["meta"]["procedures"]
    ages = data["meta"]["age_groups"]
    start_def, end_def = default_dates(data["meta"], months_back=6)

    c1, c2, c3 = st.columns([2,2,2])
    with c1:
        dr = st.date_input("Rango de fechas", (start_def, end_def),
                           min_value=data["meta"]["date_range"].start.date(),
                           max_value=data["meta"]["date_range"].end.date())
        start, end = pd.to_datetime(dr[0]), pd.to_datetime(dr[1])
    with c2:
        sel_proc = st.multiselect("Procedimiento", options=procs, default=procs)
    with c3:
        sel_age = st.multiselect("Grupo de edad", options=ages, default=ages)

    ff = ft[(ft["fecha"] >= start) & (ft["fecha"] <= end) &
            (ft["procedimiento"].isin(sel_proc)) & (ft["grupo_edad"].isin(sel_age))]
    if ff.empty:
        st.warning("No hay datos para el filtro seleccionado.")
        return

    # KPIs
    cycles = ff["ciclos"].sum()
    success = ff["exitos"].sum()
    consults = ff["consultas"].sum()
    tasa_exito = success / cycles if cycles > 0 else np.nan
    conversion = cycles / consults if consults > 0 else np.nan

    k1,k2,k3 = st.columns(3)
    with k1:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Tasa de éxito", pct(tasa_exito))
        st.markdown("</div>", unsafe_allow_html=True)
    with k2:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Ciclos iniciados", f"{int(cycles):,}")
        st.markdown("</div>", unsafe_allow_html=True)
    with k3:
        st.markdown("<div class='metric-card'>", unsafe_allow_html=True)
        st.metric("Tasa de conversión (consulta → tratamiento)", pct(conversion))
        st.markdown("</div>", unsafe_allow_html=True)

    # Éxito por procedimiento
    st.markdown("### Tasa de éxito por procedimiento")
    by_proc = ff.groupby("procedimiento").agg(cic=("ciclos","sum"), exi=("exitos","sum"))
    by_proc["tasa"] = np.where(by_proc["cic"]>0, by_proc["exi"]/by_proc["cic"], np.nan)
    fig1 = px.bar(by_proc.reset_index(), x="procedimiento", y="tasa", template="plotly_dark")
    fig1.update_layout(height=420, yaxis_tickformat=".0%", margin=dict(l=10,r=10,t=10,b=100),
                       xaxis_title="", yaxis_title="Tasa de éxito")
    st.plotly_chart(fig1, use_container_width=True)

    # Histograma de edad de pacientes (aprox: expandimos grupos a edades simuladas)
    st.markdown("### Distribución de edad de los pacientes")
    # Expandir a 1000 muestras ponderadas por consultas
    weights = ff.groupby("grupo_edad")["consultas"].sum().reindex(ages).fillna(0).values
    if weights.sum() == 0:
        ages_samples = np.array([])
    else:
        n = 1000
        probs = weights / weights.sum() if weights.sum()>0 else np.ones_like(weights)/len(weights)
        groups_draw = np.random.choice(ages, size=n, p=probs)
        bounds = {"<30":(22,29), "30-34":(30,34), "35-39":(35,39), "40-44":(40,44), "45+":(45,49)}
        ages_samples = np.array([np.random.randint(*bounds[g]) for g in groups_draw])
    fig2 = px.histogram(x=ages_samples, nbins=20, template="plotly_dark")
    fig2.update_layout(height=420, margin=dict(l=10,r=10,t=10,b=10), xaxis_title="Edad", yaxis_title="Pacientes (sim.)")
    st.plotly_chart(fig2, use_container_width=True)

    # Funnel de conversión
    st.markdown("### Embudo de conversión")
    steps = {
        "Consultas": consults,
        "Ciclos": cycles,
        "Embarazos exitosos": success
    }
    fig3 = go.Figure(go.Funnel(
        y=list(steps.keys()),
        x=list(steps.values()),
        textinfo="value+percent initial",
        marker={"color": [PRIMARY, "#60a5fa", "#93c5fd"]}
    ))
    fig3.update_layout(template="plotly_dark", height=420, margin=dict(l=10,r=10,t=10,b=10))
    st.plotly_chart(fig3, use_container_width=True)


# ------------------------------ ROUTER -------------------------------------
if page.startswith("🏠"):
    page_resumen(DATA)
elif page.startswith("💊"):
    page_farmacias(DATA)
elif page.startswith("🧪"):
    page_laboratorio(DATA)
elif page.startswith("🏥"):
    page_hospitalario(DATA)
elif page.startswith("👶"):
    page_fertilidad(DATA)
else:
    st.write("Página no encontrada.")

# ------------------------------ FOOTER -------------------------------------
st.markdown("---")
st.caption("© Almha — Prototipo de Business Intelligence. Datos simulados.")
