# app.py — Aires MX • Dashboard con Forecast y Alertas (versión depurada)
# ======================================================================
# Requiere: streamlit, plotly, pandas, numpy, statsmodels
# pip install streamlit plotly pandas numpy statsmodels

import numpy as np
import pandas as pd
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import seaborn as sns
import matplotlib.pyplot as plt

st.set_page_config(page_title="Aires MX • Dashboard", page_icon="❄️", layout="wide")
np.random.seed(42)

# ----------------------------- Catálogos MX -----------------------------
ZONA_POR_ESTADO = {
    "Baja California": "Noroeste", "Baja California Sur": "Noroeste", "Sonora": "Noroeste",
    "Chihuahua": "Noreste", "Coahuila": "Noreste", "Nuevo León": "Noreste", "Tamaulipas": "Noreste",
    "Sinaloa": "Occidente", "Nayarit": "Occidente", "Jalisco": "Occidente", "Colima": "Occidente", "Michoacán": "Occidente",
    "Aguascalientes": "Bajío", "Guanajuato": "Bajío", "Querétaro": "Bajío", "San Luis Potosí": "Bajío", "Zacatecas": "Bajío",
    "Ciudad de México": "Centro", "Estado de México": "Centro", "Morelos": "Centro", "Hidalgo": "Centro", "Puebla": "Centro", "Tlaxcala": "Centro",
    "Veracruz": "Sureste", "Tabasco": "Sureste", "Oaxaca": "Sureste", "Chiapas": "Sureste",
    "Yucatán": "Sureste", "Quintana Roo": "Sureste", "Campeche": "Sureste",
    "Durango": "Noroeste", "Guerrero": "Occidente"
}

SUCURSALES = [
    # Noroeste
    ("Tijuana","Baja California",32.5149,-117.0382), ("Mexicali","Baja California",32.6245,-115.4523),
    ("Ensenada","Baja California",31.8667,-116.5964), ("La Paz","Baja California Sur",24.1426,-110.3128),
    ("Cabo San Lucas","Baja California Sur",22.8905,-109.9167), ("Hermosillo","Sonora",29.0729,-110.9559),
    # Noreste
    ("Chihuahua","Chihuahua",28.6320,-106.0691), ("Ciudad Juárez","Chihuahua",31.6904,-106.4245),
    ("Torreón","Coahuila",25.5428,-103.4068), ("Saltillo","Coahuila",25.4380,-100.9730),
    ("Monterrey","Nuevo León",25.6866,-100.3161), ("Nuevo Laredo","Tamaulipas",27.4763,-99.5164),
    ("Reynosa","Tamaulipas",26.0922,-98.2773),
    # Occidente
    ("Culiacán","Sinaloa",24.7995,-107.3897), ("Mazatlán","Sinaloa",23.2494,-106.4111),
    ("Tepic","Nayarit",21.5095,-104.8957), ("Guadalajara","Jalisco",20.6597,-103.3496),
    ("Puerto Vallarta","Jalisco",20.6534,-105.2253), ("Colima","Colima",19.2433,-103.7250),
    ("Morelia","Michoacán",19.7059,-101.1942),
    # Bajío
    ("León","Guanajuato",21.1250,-101.6860), ("Celaya","Guanajuato",20.5222,-100.8122),
    ("Querétaro","Querétaro",20.5881,-100.3881), ("Aguascalientes","Aguascalientes",21.8853,-102.2916),
    ("San Luis Potosí","San Luis Potosí",22.1565,-100.9855), ("Zacatecas","Zacatecas",22.7709,-102.5833),
    # Centro
    ("CDMX","Ciudad de México",19.4326,-99.1332), ("Toluca","Estado de México",19.2826,-99.6557),
    ("Cuernavaca","Morelos",18.9242,-99.2216), ("Pachuca","Hidalgo",20.1011,-98.7591),
    ("Puebla","Puebla",19.0437,-98.1981), ("Tlaxcala","Tlaxcala",19.3182,-98.2370),
    # Golfo / Sur
    ("Xalapa","Veracruz",19.5438,-96.9104), ("Veracruz","Veracruz",19.1738,-96.1342),
    ("Villahermosa","Tabasco",17.9895,-92.9475), ("Oaxaca","Oaxaca",17.0732,-96.7266),
    ("Tuxtla Gutiérrez","Chiapas",16.75,-93.1167), ("Durango","Durango",24.0277,-104.6532),
    ("Acapulco","Guerrero",16.8634,-99.8901),
    # Península
    ("Mérida","Yucatán",20.9674,-89.5926), ("Cancún","Quintana Roo",21.1619,-86.8515),
    ("Chetumal","Quintana Roo",18.499,-88.303), ("Campeche","Campeche",19.8301,-90.5349),
]

# ----------------------------- Datos sintéticos -----------------------------
@st.cache_data
def generar_datos(ini="2024-01-01", fin="2025-12-01"):
    rng = pd.date_range(ini, fin, freq="MS")
    rows = []
    rng_month = np.array([d.month for d in rng])
    for ciudad, estado, lat, lon in SUCURSALES:
        zona = ZONA_POR_ESTADO.get(estado, "Centro")
        zona_boost = {"Noroeste":1.20, "Noreste":1.15, "Occidente":1.05, "Bajío":1.00, "Centro":0.95, "Sureste":1.25}.get(zona,1.0)
        season = np.where((rng_month>=5)&(rng_month<=9), 1.3, 0.9)  # verano ↑
        trend = np.linspace(0.95, 1.10, len(rng)) * np.random.uniform(0.9, 1.1)
        base_ventas = np.random.randint(60, 140)
        precio_unit = np.random.randint(8_000, 15_000)
        costo_unit = np.random.randint(5_000, 7_200)
        for i, fecha in enumerate(rng):
            ruido = np.random.normal(1.0, 0.12)
            ventas = int(max(10, base_ventas * zona_boost * season[i] * trend[i] * ruido))
            ingresos = ventas * precio_unit
            costos = ventas * costo_unit
            utilidad = ingresos - costos
            margen = (utilidad / ingresos) if ingresos > 0 else 0.0
            demanda_est = int(np.round(ventas * np.random.uniform(1.05, 1.25)))
            rows.append({
                "Fecha": fecha, "Ciudad": ciudad, "Estado": estado, "Zona": zona,
                "Lat": lat, "Lon": lon,
                "Ventas": ventas, "Precio_Unit": precio_unit, "Costo_Unit": costo_unit,
                "Ingresos": ingresos, "Costos": costos, "Utilidad": utilidad,
                "Margen_porcentaje": round(margen*100, 2), "Demanda_Estimada": demanda_est
            })
    return pd.DataFrame(rows)

df = generar_datos()

# ----------------------------- Simulación de inventario -----------------------------
@st.cache_data
def sim_inventario(df_in):
    df = df_in.sort_values(["Ciudad","Fecha"]).copy()
    df["Stock_Inicial"] = 0
    df["Reabasto"] = 0
    df["Stock_Final"] = 0
    for ciudad, g in df.groupby("Ciudad", sort=False):
        stock = np.random.randint(80, 250)  # stock inicial aleatorio
        for idx in g.index:
            exp = int(df.loc[idx, "Demanda_Estimada"] * np.random.uniform(0.8, 1.0))
            reorder = int(1.5*exp) if stock < 1.2*exp else int(0.3*exp)  # política simple
            stock_ini = stock
            ventas = df.loc[idx, "Ventas"]
            stock_fin = max(0, stock_ini + reorder - ventas)
            df.at[idx, "Stock_Inicial"] = stock_ini
            df.at[idx, "Reabasto"] = reorder
            df.at[idx, "Stock_Final"] = stock_fin
            stock = stock_fin
    df["Demanda_Siguiente"] = df.groupby("Ciudad")["Demanda_Estimada"].shift(-1)
    df["Cobertura_Meses"] = df["Stock_Final"] / df["Demanda_Siguiente"]
    # esperado por estacionalidad
    df["Mes"] = df["Fecha"].dt.month
    season = (df.groupby("Mes")["Ingresos"].mean() / df["Ingresos"].mean()).rename("Factor_Estacional")
    df = df.merge(season, on="Mes", how="left")
    df["Ingresos_MA6"] = df.groupby("Ciudad")["Ingresos"].transform(lambda s: s.rolling(6, min_periods=3).mean())
    df["Ingresos_Esperados"] = (df["Ingresos_MA6"] * df["Factor_Estacional"]).fillna(df["Ingresos_MA6"])
    df["Underperf_%"] = (df["Ingresos"] - df["Ingresos_Esperados"]) / df["Ingresos_Esperados"]
    return df

df = sim_inventario(df)

# ----------------------------- Sidebar -----------------------------
st.sidebar.title("Filtros")
z_sel = st.sidebar.multiselect("Zonas", sorted(df["Zona"].unique()), default=sorted(df["Zona"].unique()))
e_sel = st.sidebar.multiselect("Estados", sorted(df[df["Zona"].isin(z_sel)]["Estado"].unique()))
c_sel = st.sidebar.multiselect("Ciudades", sorted(df[(df["Zona"].isin(z_sel)) & (df["Estado"].isin(e_sel) if e_sel else [True])]["Ciudad"].unique()))
fecha_min, fecha_max = st.sidebar.date_input(
    "Rango de fechas",
    value=(df["Fecha"].min().date(), df["Fecha"].max().date()),
    min_value=df["Fecha"].min().date(),
    max_value=df["Fecha"].max().date()
)
metric_map = {"Ventas":"Ventas","Ingresos":"Ingresos","Utilidad":"Utilidad","Margen (%)":"Margen_porcentaje"}
metric_sel = st.sidebar.selectbox("Métrica para mapa/rankings", list(metric_map.keys()))

st.sidebar.markdown("---")
st.sidebar.subheader("Umbrales de alertas")
margen_min = st.sidebar.number_input("Margen % mínimo", 0.0, 100.0, 25.0, 0.5)
cover_min = st.sidebar.number_input("Cobertura mínima (meses)", 0.0, 3.0, 0.6, 0.1)
underperf_thr = st.sidebar.number_input("Bajo desempeño (desvío %)", 0.0, 100.0, 10.0, 1.0) / 100.0

st.sidebar.markdown("---")
st.sidebar.subheader("Forecast")
ciudad_forecast = st.sidebar.selectbox("Ciudad para proyección", sorted(df["Ciudad"].unique()))
horizonte = st.sidebar.slider("Horizonte (meses)", 3, 18, 6)

# ----------------------------- Filtros -----------------------------
mask = (df["Zona"].isin(z_sel)) & (df["Fecha"].between(pd.to_datetime(fecha_min), pd.to_datetime(fecha_max)))
if e_sel:
    mask &= df["Estado"].isin(e_sel)
if c_sel:
    mask &= df["Ciudad"].isin(c_sel)
dff = df[mask].copy()

if dff.empty:
    st.warning("No hay datos con los filtros actuales. Ajusta Zona/Estado/Ciudad/Fechas.")
    st.stop()

ultimo_mes = dff["Fecha"].max()

# ----------------------------- KPIs -----------------------------
k1, k2, k3, k5, k4 = st.columns(5)
k1.metric("Ventas", f"{dff['Ventas'].sum():,.0f}")
k2.metric("Ingresos", f"${dff['Ingresos'].sum():,.0f}")
k3.metric("Costos", f"${dff['Costos'].sum():,.0f}")
k5.metric("Utilidad", f"${dff['Utilidad'].sum():,.0f}")
margen_total = (dff["Utilidad"].sum()/dff["Ingresos"].sum()*100) if dff["Ingresos"].sum() > 0 else 0
k4.metric("Margen promedio", f"{margen_total:,.1f}%")

st.markdown("---")

# ----------------------------- Mapa por ciudad -----------------------------
st.subheader(f" Mapa por ciudad — {metric_sel}")
agg_city = (dff.groupby(["Ciudad","Estado","Zona","Lat","Lon"], as_index=False)
            .agg(Ventas=("Ventas","sum"), Ingresos=("Ingresos","sum"), Utilidad=("Utilidad","sum"), Margen_porcentaje=("Margen_porcentaje","mean")))
fig_map = px.scatter_geo(
    agg_city, lat="Lat", lon="Lon", scope="north america",
    hover_name="Ciudad",
    hover_data={"Estado":True,"Zona":True,"Ventas":True,"Ingresos":":,","Utilidad":":,","Margen_porcentaje":":.1f"},
    size=metric_map[metric_sel], color=metric_map[metric_sel], size_max=30,
    title=f"Mapa (burbuja proporcional) — {metric_sel}"
)
fig_map.update_geos(fitbounds="locations", showcountries=True, countrycolor="#888", showland=True, landcolor="#f2f2f2")
st.plotly_chart(fig_map, use_container_width=True)

# ----------------------------- Tendencias -----------------------------
st.subheader("Tendencias")
colA, colB = st.columns(2)
with colA:
    st.plotly_chart(px.line(dff, x="Fecha", y="Ventas", color="Zona", markers=True, title="Ventas por zona"), use_container_width=True)
with colB:
    st.plotly_chart(px.line(dff, x="Fecha", y="Utilidad", color="Zona", markers=True, title="Utilidad por zona"), use_container_width=True)

# ----------------------------- Rankings -----------------------------
st.subheader("Rankings")
rank_city = (dff.groupby(["Ciudad","Zona"], as_index=False)
             .agg(Ventas=("Ventas","sum"),Ingresos=("Ingresos","sum"),Utilidad=("Utilidad","sum"),Margen_porcentaje=("Margen_porcentaje","mean")))
rank_metric_col = metric_map[metric_sel]
rank_city = rank_city.sort_values(rank_metric_col, ascending=False).head(15)

col1, col2 = st.columns(2)
with col1:
    fig_r1 = px.bar(rank_city, x=rank_metric_col, y="Ciudad", color="Zona", orientation="h",
                    title=f"Top sucursales por {metric_sel}", text=rank_metric_col)
    fig_r1.update_layout(yaxis={"categoryorder":"total ascending"})
    st.plotly_chart(fig_r1, use_container_width=True)

rank_zona = (dff.groupby("Zona", as_index=False)
             .agg(Ventas=("Ventas","sum"),Ingresos=("Ingresos","sum"),Utilidad=("Utilidad","sum"),Margen_porcentaje=("Margen_porcentaje","mean"))
             .sort_values(rank_metric_col, ascending=False))
with col2:
    st.plotly_chart(px.bar(rank_zona, x="Zona", y=rank_metric_col, title=f"Ranking por zona — {metric_sel}",
                           text=rank_metric_col), use_container_width=True)

# ----------------------------- Finanzas -----------------------------
st.subheader("Finanzas — Ingresos vs Costos")
agg_month = dff.groupby("Fecha", as_index=False).agg(Ingresos=("Ingresos","sum"), Costos=("Costos","sum"))
fig_fin = go.Figure()
fig_fin.add_trace(go.Bar(x=agg_month["Fecha"], y=agg_month["Ingresos"], name="Ingresos"))
fig_fin.add_trace(go.Bar(x=agg_month["Fecha"], y=agg_month["Costos"], name="Costos"))
fig_fin.update_layout(barmode="group", title="Ingresos vs Costos (agregado)")
st.plotly_chart(fig_fin, use_container_width=True)

# ----------------------------- Centro de Alertas -----------------------------
st.subheader("Centro de Alertas (priorizado por impacto en $)")
base_alertas = dff[dff["Fecha"] == ultimo_mes].copy()

base_alertas["Unidades_Riesgo"] = (base_alertas["Demanda_Siguiente"] - base_alertas["Stock_Final"]).clip(lower=0).fillna(0)
base_alertas["Impacto_Quiebre_MXN"] = base_alertas["Unidades_Riesgo"] * base_alertas["Precio_Unit"]

base_alertas["Impacto_Margen_MXN"] = ((margen_min/100) - (base_alertas["Margen_porcentaje"]/100)).clip(lower=0) * base_alertas["Ingresos"]

base_alertas["Gap_MXN"] = (base_alertas["Ingresos_Esperados"] - base_alertas["Ingresos"]).clip(lower=0)
alert_quiebre = base_alertas[base_alertas["Cobertura_Meses"] < cover_min].copy()
alert_quiebre["Alerta"] = "Riesgo quiebre de stock"

alert_margen = base_alertas[base_alertas["Margen_porcentaje"] < margen_min].copy()
alert_margen["Alerta"] = "Margen bajo"

alert_under = base_alertas[base_alertas["Underperf_%"] < -underperf_thr].copy()
alert_under["Alerta"] = "Bajo desempeño vs esperado"
alert_under["Underperf_%_abs"] = (alert_under["Underperf_%"].abs() * 100).round(1)

cols_base = ["Fecha","Zona","Estado","Ciudad","Ventas","Ingresos","Utilidad","Margen_porcentaje","Cobertura_Meses"]
alertas = pd.concat([
    alert_quiebre[cols_base + ["Unidades_Riesgo","Impacto_Quiebre_MXN","Alerta"]],
    alert_margen[cols_base + ["Impacto_Margen_MXN","Alerta"]],
    alert_under[cols_base + ["Underperf_%","Underperf_%_abs","Gap_MXN","Alerta"]]
], ignore_index=True).fillna({"Underperf_%_abs":0,"Unidades_Riesgo":0,"Impacto_Quiebre_MXN":0,"Impacto_Margen_MXN":0,"Gap_MXN":0})

alertas["Impacto_Total_MXN"] = alertas[["Impacto_Quiebre_MXN","Impacto_Margen_MXN","Gap_MXN"]].sum(axis=1)
alertas = alertas.sort_values("Impacto_Total_MXN", ascending=False)

cA, cB = st.columns([2,1])
with cA:
    st.dataframe(alertas, use_container_width=True)
with cB:
    if not alertas.empty:
        top_imp = (alertas.groupby("Ciudad", as_index=False)["Impacto_Total_MXN"].sum()
                   .sort_values("Impacto_Total_MXN", ascending=False).head(10))
        st.plotly_chart(px.bar(top_imp, x="Impacto_Total_MXN", y="Ciudad", orientation="h",
                               title="Top ciudades por impacto económico (MXN)"), use_container_width=True)
    csv = alertas.to_csv(index=False).encode("utf-8")
    st.download_button("⬇️ Descargar alertas (CSV)", data=csv, file_name="alertas_aires_mx.csv", mime="text/csv")

with st.expander("¿Cómo se calculan las alertas?"):
    st.markdown("""
**Quiebre de stock:** `Cobertura_Meses = Stock_Final / Demanda_Siguiente`.  
Impacto ≈ `max(Demanda_Siguiente − Stock_Final, 0) * Precio_Unit`.

**Margen bajo:** impacto ≈ `(Margen_obj − Margen_actual) * Ingresos`.

**Bajo desempeño:** Esperado = `MA(6m) * Factor estacional`.  
Gap % = `(Ingresos − Esperados) / Esperados`; impacto = `Esperados − Ingresos`.
""")

# ----------------------------- Forecast por ciudad (SARIMAX) -----------------------------
st.subheader("Proyección de demanda por ciudad")
serie = (df[df["Ciudad"] == ciudad_forecast]
         .sort_values("Fecha")[["Fecha","Demanda_Estimada"]]
         .set_index("Fecha").asfreq("MS"))
hist = serie["Demanda_Estimada"].astype(float)

def sarimax_forecast(y, steps=6):
    try:
        mod = SARIMAX(y, order=(1,1,1), seasonal_order=(0,1,1,12),
                      enforce_stationarity=False, enforce_invertibility=False)
        res = mod.fit(disp=False)
        fc = res.get_forecast(steps=steps)
        pred = fc.predicted_mean
        conf = fc.conf_int()
        conf.columns = ["lo","hi"]
        return pred, conf
    except Exception:
        idx_future = pd.date_range(y.index[-1]+pd.offsets.MonthBegin(), periods=steps, freq="MS")
        est = y.groupby(y.index.month).mean()
        pred = pd.Series([est.get(m, y.mean()) for m in idx_future.month], index=idx_future)
        conf = pd.DataFrame({"lo": pred*0.9, "hi": pred*1.1}, index=idx_future)
        return pred, conf

pred, conf = sarimax_forecast(hist, steps=horizonte)

df_plot = pd.concat([
    pd.DataFrame({"Fecha": hist.index, "Valor": hist.values, "Serie": "Histórico"}),
    pd.DataFrame({"Fecha": pred.index, "Valor": pred.values, "Serie": "Pronóstico"})
])
fig_fc = px.line(df_plot, x="Fecha", y="Valor", color="Serie", markers=True,
                 title=f"Demanda Estimada — {ciudad_forecast}")
fig_fc.add_traces([
    go.Scatter(x=conf.index, y=conf["lo"], name="LI 95%", mode="lines", line=dict(dash="dash")),
    go.Scatter(x=conf.index, y=conf["hi"], name="LS 95%", mode="lines", line=dict(dash="dash"), fill="tonexty")
])
st.plotly_chart(fig_fc, use_container_width=True)

# ----------------------------- Pedidos sugeridos -----------------------------
st.subheader("Sugerencia de pedidos")
tmp = dff.sort_values("Fecha").copy()
avg2 = tmp.groupby("Ciudad")["Ventas"].rolling(2).mean().reset_index(level=0, drop=True)
tmp["Ventas_prom_2m"] = avg2
tmp["Pedido_Sugerido"] = (tmp["Demanda_Estimada"] - tmp["Ventas_prom_2m"].fillna(tmp["Ventas"])).clip(lower=0).round()
pedidos = (tmp.groupby(["Ciudad","Zona"], as_index=False).agg(Pedido_Sugerido=("Pedido_Sugerido","mean"))
           .sort_values("Pedido_Sugerido", ascending=False).head(15))
st.plotly_chart(px.bar(pedidos, x="Pedido_Sugerido", y="Ciudad", color="Zona",
                       orientation="h", title="Top ciudades por pedido sugerido (promedio)"),
                use_container_width=True)

# ----------------------------- Análisis Estratégico -----------------------------
st.subheader("Análisis Estratégico")
col_strat1, col_strat2 = st.columns(2)

with col_strat1:
    # --- Análisis de Pareto (80/20) ---
    st.markdown("##### Principio de Pareto (80/20) sobre Ingresos")
    pareto_df = agg_city.sort_values("Ingresos", ascending=False)
    pareto_df["Ingresos_Acum"] = pareto_df["Ingresos"].cumsum()
    pareto_df["Ingresos_Acum_%"] = 100 * pareto_df["Ingresos_Acum"] / pareto_df["Ingresos"].sum()

    # Encontrar el punto del 80%
    num_sucursales_80 = pareto_df[pareto_df["Ingresos_Acum_%"] <= 80].shape[0] + 1
    total_sucursales = len(SUCURSALES)
    porc_sucursales_80 = round(100 * num_sucursales_80 / total_sucursales)

    st.info(f"""El **{porc_sucursales_80}%** de las sucursales (aprox. **{num_sucursales_80}** de **{total_sucursales}**) generan el **80%** de los ingresos totales.""")
    
    fig_pareto = go.Figure()
    fig_pareto.add_trace(go.Bar(x=pareto_df["Ciudad"], y=pareto_df["Ingresos"], name="Ingresos por Sucursal"))
    fig_pareto.add_trace(go.Scatter(x=pareto_df["Ciudad"], y=pareto_df["Ingresos_Acum_%"], name="% Acumulado", yaxis="y2", mode="lines+markers"))
    fig_pareto.update_layout(
        title="Distribución de Ingresos por Sucursal",
        yaxis=dict(title="Ingresos (MXN)"),
        yaxis2=dict(title="% Acumulado", overlaying="y", side="right", range=[0, 110]),
        legend=dict(x=0.01, y=0.98)
    )
    st.plotly_chart(fig_pareto, use_container_width=True)

with col_strat2:
    # --- Matriz de Correlación ---
    st.markdown("##### Matriz de Correlación de Métricas Clave")
    corr_cols = ["Ventas", "Precio_Unit", "Costo_Unit", "Ingresos", "Utilidad", "Margen_porcentaje"]
    corr_matrix = dff[corr_cols].corr()
    
    fig_corr, ax = plt.subplots()
    sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", fmt=".2f", ax=ax)
    ax.set_title("Correlación entre Variables")
    st.pyplot(fig_corr)
    st.info("""
    **Interpretación:**
    - **Valores cercanos a 1:** Fuerte correlación positiva (si uno sube, el otro también).
    - **Valores cercanos a -1:** Fuerte correlación negativa (si uno sube, el otro baja).
    - **Valores cercanos a 0:** Poca o ninguna correlación lineal.
    """)

# ----------------------------- Segmentación de Sucursales (K-Means) -----------------------------
st.subheader("Segmentación de Sucursales con Machine Learning")

# 1. Preparar los datos para el modelo
# Agregamos métricas de volatilidad y crecimiento para enriquecer el modelo
agg_city_growth = dff.sort_values("Fecha").groupby("Ciudad").agg(
    Crecimiento_Ingresos=("Ingresos", lambda x: (x.iloc[-1] - x.iloc[0]) / x.iloc[0] if len(x) > 1 and x.iloc[0] != 0 else 0),
    Volatilidad_Ventas=("Ventas", lambda x: x.std() / x.mean() if x.mean() != 0 else 0)
).reset_index()

cluster_data = agg_city.merge(agg_city_growth, on="Ciudad")
features = ["Ingresos", "Margen_porcentaje", "Ventas", "Crecimiento_Ingresos", "Volatilidad_Ventas"]
X = cluster_data[features].fillna(0)

# 2. Escalar los datos (importante para K-Means)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 3. Entrenar el modelo K-Means
n_clusters = st.slider("Número de Clusters a generar", 2, 8, 4)
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
cluster_data["Cluster"] = kmeans.fit_predict(X_scaled)
cluster_data["Cluster"] = cluster_data["Cluster"].astype("str") # Para colores en plotly

# 4. Analizar y describir los clusters
cluster_profiles = cluster_data.groupby("Cluster")[features].mean().reset_index()
st.markdown("##### Perfil Promedio de Cada Cluster")
st.dataframe(cluster_profiles.style.background_gradient(cmap='viridis'))

st.info("""
**Sugerencia de Interpretación de Clusters:**
- **Cluster con altos Ingresos/Ventas y alto Margen/Crecimiento:**  **Sucursales Estrella**. Proteger, invertir y replicar su éxito.
- **Cluster con altos Ingresos/Ventas pero bajo Margen/Crecimiento:**  **Vacas Lecheras**. Foco en optimizar costos y márgenes.
- **Cluster con bajos Ingresos/Ventas pero alto Margen/Crecimiento:**  **Sucursales Potenciales**. Invertir en marketing para aumentar volumen.
- **Cluster con bajos Ingresos/Ventas y bajo Margen/Crecimiento:**  **Sucursales a Revisar**. Requieren un análisis profundo de su viabilidad.
""")


# 5. Visualizar los clusters
col_cluster1, col_cluster2 = st.columns(2)

with col_cluster1:
    # Scatter plot para visualizar los segmentos
    fig_cluster = px.scatter(
        cluster_data,
        x="Ingresos",
        y="Margen_porcentaje",
        color="Cluster",
        size="Ventas",
        hover_name="Ciudad",
        hover_data=["Crecimiento_Ingresos", "Volatilidad_Ventas"],
        title="Segmentación de Sucursales por Desempeño"
    )
    st.plotly_chart(fig_cluster, use_container_width=True)

with col_cluster2:
    # Mapa geográfico de los clusters
    fig_map_cluster = px.scatter_geo(
        cluster_data, lat="Lat", lon="Lon", scope="north america",
        hover_name="Ciudad",
        color="Cluster",
        size="Ingresos", size_max=30,
        title="Distribución Geográfica de Clusters"
    )
    fig_map_cluster.update_geos(fitbounds="locations", showland=True, landcolor="#f2f2f2")
    st.plotly_chart(fig_map_cluster, use_container_width=True)




# ----------------------------- Tabla detalle -----------------------------
st.subheader("Detalle (ciudad-mes)")
mostrar_cols = ["Fecha","Zona","Estado","Ciudad","Ventas","Ingresos","Costos","Utilidad","Margen_porcentaje","Demanda_Estimada",
                "Stock_Inicial","Reabasto","Stock_Final","Cobertura_Meses","Ingresos_Esperados","Underperf_%"]
st.dataframe(dff[mostrar_cols].sort_values(["Zona","Estado","Ciudad","Fecha"]), use_container_width=True)
