"""
auth.py
-------
Auth0 login guard para apps Streamlit.

Flujo:
- Si no hay usuario en session_state:
  - Si viene callback (code + state), intercambia por tokens
  - Si no, construye URL de login y pide iniciar sesión
- Valida roles desde:
  - access token (si viene en claim) o
  - /userinfo + claim de roles
"""
from __future__ import annotations
import os
import secrets
from urllib.parse import urlencode
import requests
import streamlit as st
import jwt
from jwt import PyJWKClient
# -----------------------------
# Config
# -----------------------------
def _cfg() -> dict:
    """
    Return the ``[auth0]`` section of the Streamlit secrets as a plain dict.

    Raises:
        KeyError: if the secrets contain no ``auth0`` section.
    """
    if "auth0" in st.secrets:
        # Copy into a regular dict so callers can use .get() and friends.
        return dict(st.secrets["auth0"])
    raise KeyError("Falta [auth0] en secrets de Streamlit Cloud (o secrets.toml local).")

def _base_url() -> str:
    """Return the configured ``redirect_uri`` with trailing slashes removed."""
    # On Streamlit Cloud the redirect URI is normally the app's base URL.
    redirect_uri = _cfg()["redirect_uri"]
    return redirect_uri.rstrip("/")

# -----------------------------
# Login URL
# -----------------------------
def build_login_url() -> str:
    """
    Build the Auth0 ``/authorize`` URL for the authorization-code flow.

    Side effect: a fresh random ``state`` value is generated on every call
    and stored in ``st.session_state["oauth_state"]``.

    Returns:
        The absolute login URL, including the URL-encoded query string.
    """
    settings = _cfg()
    tenant = settings["domain"]
    app_id = settings["client_id"]
    callback = _base_url()

    # Remember the state in the session so the callback can compare it.
    csrf_state = secrets.token_urlsafe(24)
    st.session_state["oauth_state"] = csrf_state

    query = dict(
        response_type="code",
        client_id=app_id,
        redirect_uri=callback,
        scope="openid profile email",
        state=csrf_state,
    )

    # Audience is optional: only needed when an API is used and the
    # access token should be issued for that audience.
    api_audience = settings.get("audience")
    if api_audience:
        query["audience"] = api_audience

    return "https://" + tenant + "/authorize?" + urlencode(query)

# -----------------------------
# Callback: exchange code -> tokens
# -----------------------------
def _exchange_code_for_tokens(code: str) -> dict:
    """
    Exchange an authorization ``code`` for tokens at Auth0's ``/oauth/token``.

    Args:
        code: The authorization code received on the callback URL.

    Returns:
        The decoded JSON body of the token response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
    """
    settings = _cfg()
    tenant = settings["domain"]
    body = {
        "grant_type": "authorization_code",
        "client_id": settings["client_id"],
        "client_secret": settings["client_secret"],
        "code": code,
        "redirect_uri": _base_url(),
    }

    response = requests.post(f"https://{tenant}/oauth/token", json=body, timeout=30)
    response.raise_for_status()
    return response.json()

def _fetch_userinfo(access_token: str) -> dict:
    """
    Call Auth0's ``/userinfo`` endpoint with the given bearer token.

    Args:
        access_token: Token sent in the ``Authorization: Bearer`` header.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
    """
    tenant = _cfg()["domain"]
    auth_header = {"Authorization": f"Bearer {access_token}"}
    response = requests.get(
        f"https://{tenant}/userinfo", headers=auth_header, timeout=30
    )
    response.raise_for_status()
    return response.json()

def _verify_jwt(access_token: str) -> dict:
    """
    Verify the RS256 signature of ``access_token`` (when it is a JWT).

    The signing key is looked up in the tenant's JWKS document. The issuer
    is always checked; the audience is checked only when one is configured.

    Returns:
        The decoded claims, or ``{}`` when the token is not a JWT or
        verification fails for any reason.
    """
    settings = _cfg()
    tenant = settings["domain"]
    # Without a configured audience the token is still verified, minus ``aud``.
    expected_audience = settings.get("audience")

    try:
        key_client = PyJWKClient(f"https://{tenant}/.well-known/jwks.json")
        public_key = key_client.get_signing_key_from_jwt(access_token).key
        return jwt.decode(
            access_token,
            public_key,
            algorithms=["RS256"],
            audience=expected_audience or None,
            issuer=f"https://{tenant}/",
            options={"verify_aud": bool(expected_audience)},
        )
    except Exception:
        # Deliberate best-effort: any failure is reported as "no claims".
        return {}

def _extract_roles(claims: dict, userinfo: dict) -> list[str]:
    """
    Ajusta aquí según cómo guardas roles:
    - Si usas Auth0 RBAC + Add Permissions in Access Token, quizá viene en access_token.
    - Si usas custom claim, típicamente:
      "https://ideasfrescas.com.mx/roles": [...]
    """
    namespace = "dev-ajifaa3ovgdvp4ty.us.auth0.com"
    candidates = [
        claims.get("roles"),
        claims.get(f"{namespace}/roles"),
        userinfo.get("roles"),
        userinfo.get(f"{namespace}/roles"),
    ]
    for c in candidates:
        if isinstance(c, list):
            return [str(x) for x in c]
    return []

def handle_callback() -> dict | None:
    """
    Process an Auth0 callback (``?code=...&state=...`` in the URL).

    On success the user dict is stored in ``st.session_state["user"]`` and
    the query parameters are cleared so the callback is not reprocessed on
    the next rerun.

    Returns:
        The user dict (keys ``email``, ``name``, ``roles``, ``access_token``,
        ``id_token``), or ``None`` when the URL carries no callback or the
        ``state`` does not match the one stored in the session.

    Raises:
        requests.RequestException: if the token exchange or the ``/userinfo``
            request fails. The query parameters are cleared before the
            exception propagates.
    """
    qp = st.query_params
    code = qp.get("code")
    state = qp.get("state")

    if not code or not state:
        return None

    # NOTE(review): the state is only enforced when this session still holds
    # an ``oauth_state``; if the session has none, the check is skipped.
    # Confirm this leniency is intended before tightening it.
    expected_state = st.session_state.get("oauth_state")
    if expected_state and state != expected_state:
        st.error("State inválido. Intenta iniciar sesión otra vez.")
        return None

    try:
        tokens = _exchange_code_for_tokens(code)
        access_token = tokens.get("access_token")
        id_token = tokens.get("id_token")

        # Try to verify the access token (JWT); otherwise claims stay empty.
        claims = _verify_jwt(access_token) if access_token else {}

        userinfo = _fetch_userinfo(access_token) if access_token else {}
    except requests.RequestException:
        # An authorization code is single-use. Leaving it in the URL would
        # make every rerun retry the exchange and fail again, so drop it
        # before letting the original error propagate.
        st.query_params.clear()
        raise

    roles = _extract_roles(claims, userinfo)

    user = {
        "email": userinfo.get("email") or claims.get("email"),
        "name": userinfo.get("name") or claims.get("name"),
        "roles": roles,
        "access_token": access_token,
        "id_token": id_token,
    }
    st.session_state["user"] = user

    # Clear the query params so the callback is not reprocessed on each rerun.
    st.query_params.clear()

    return user

def get_current_user() -> dict | None:
    """Return ``st.session_state["user"]``, or ``None`` when it is not set."""
    session = st.session_state
    return session.get("user")

def logout_url() -> str:
    """
    Build the Auth0 ``/v2/logout`` URL.

    The URL carries the configured ``client_id`` and uses the app's base URL
    as ``returnTo``.
    """
    settings = _cfg()
    tenant = settings["domain"]
    query = urlencode(
        {
            "client_id": settings["client_id"],
            "returnTo": _base_url(),
        }
    )
    return "https://" + tenant + "/v2/logout?" + query
