"""
Gestor de sesiones basado en Redis para persistencia y escalabilidad.
Reemplaza la implementación anterior en memoria.
"""
import json
import logging
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
import redis.asyncio as redis

from langchain_core.messages import messages_to_dict, messages_from_dict
from config.settings import settings

# Module-level logger named after this module, so its level/handlers can be configured per module.
logger = logging.getLogger(__name__)

class SessionData:
    """
    Data for a single client session (DTO).

    Holds data only; any locking is handled outside this class (e.g. by the
    session manager).

    Attributes:
        uuid_conversation: Identifier of the conversation this session belongs to.
        state: Mutable session state dict. If no state is given, a fresh dict with
            the default keys below is created for this instance.
    """

    def __init__(self, uuid_conversation: str, state: Optional[Dict[str, Any]] = None):
        self.uuid_conversation = uuid_conversation
        if state is not None:
            # Use the caller's dict as-is (even an empty one); no copy is made.
            self.state = state
            return
        # Built per instance so separate sessions never share the mutable lists.
        self.state = {
            "messages": [],
            "missing_fields": [],
            "intent": None,
            "customer_id": None,
            "product_id": None,
            "category": None,
            "brand": None,
            "last_tool_result": None,
            "error": None
        }

class RedisSessionManager:
    """
    Persistent session manager backed by Redis.

    Each session is stored as JSON under the key ``session:<uuid_conversation>``
    and written with a TTL of ``settings.REDIS_SESSION_TTL`` seconds. Redis
    failures are logged and degrade to an empty session / ``False`` / a partial
    count instead of raising. Task cancellation (``asyncio.CancelledError``) is
    always propagated to the caller, after recreating the client.
    """

    # Hard upper bound (seconds) for one Redis command awaited by this class,
    # on top of the client's own socket_timeout=3 configured in __init__.
    _OP_TIMEOUT = 4.0

    def __init__(self):
        self._redis_kwargs: Dict[str, Any] = dict(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            password=settings.REDIS_PASSWORD or None,
            db=settings.REDIS_DB,
            decode_responses=True,
            socket_timeout=3,
            socket_connect_timeout=3,
            socket_keepalive=True,
            retry_on_timeout=False,    # No retries on timeout
            retry_on_error=[],         # No retries on connection errors
        )
        if settings.REDIS_SSL:
            self._redis_kwargs["ssl"] = True
            # NOTE(review): "none" disables verification of the *server* certificate
            # (it is not about client certificates), which leaves the TLS link open
            # to man-in-the-middle attacks. Consider supplying a CA bundle instead.
            self._redis_kwargs["ssl_cert_reqs"] = "none"
            # Connecting by IP rather than hostname, so certificate hostname
            # matching is turned off.
            self._redis_kwargs["ssl_check_hostname"] = False
        self.client = redis.Redis(**self._redis_kwargs)
        self.ttl = settings.REDIS_SESSION_TTL
        logger.info(
            f"🔌 Redis configurado: {settings.REDIS_HOST}:{settings.REDIS_PORT} "
            f"db={settings.REDIS_DB} ssl={settings.REDIS_SSL}"
        )

        # Local asyncio locks for concurrency within this same process.
        self._local_locks: Dict[str, asyncio.Lock] = {}
        # Not used inside this class; presumably reserved for lock cleanup — verify.
        self._lock_cleaner_lock = asyncio.Lock()
        # Strong references to background close tasks of discarded clients, so
        # they are not garbage-collected before they finish.
        self._pending_closes: set = set()

    @staticmethod
    def _close_client(client):
        """
        Return the coroutine that closes ``client``.

        Prefers ``aclose()`` (redis-py >= 5.0.1, where ``close()`` is deprecated)
        and falls back to ``close()`` when ``aclose`` is not available.
        """
        closer = getattr(client, "aclose", None) or client.close
        return closer()

    def _forget_close_task(self, task: asyncio.Task) -> None:
        """Done-callback: drop the task reference and log any close failure."""
        self._pending_closes.discard(task)
        if task.cancelled():
            return
        exc = task.exception()  # also marks the exception as retrieved
        if exc is not None:
            logger.warning(
                f"⚠️  Error cerrando cliente Redis anterior: {type(exc).__name__}: {exc}"
            )

    def _recreate_client(self):
        """
        Replace the Redis client to discard a possibly corrupted connection pool.

        The new client is installed immediately. The old one is closed in a
        background task when an event loop is running. If that cannot be
        scheduled, the failure is logged and the old client is simply dropped.
        """
        old_client = self.client
        self.client = redis.Redis(**self._redis_kwargs)
        try:
            task = asyncio.get_running_loop().create_task(self._close_client(old_client))
        except Exception as e:
            # Best effort: e.g. no running event loop.
            logger.warning(
                f"⚠️  No se pudo programar el cierre del cliente Redis anterior: "
                f"{type(e).__name__}: {e}"
            )
        else:
            self._pending_closes.add(task)
            task.add_done_callback(self._forget_close_task)
        logger.info("♻️  Cliente Redis recreado (pool limpiado)")

    def get_lock(self, uuid_conversation: str) -> asyncio.Lock:
        """
        Return the local asyncio lock for a conversation, creating it on first use.

        The lock only serializes coroutines in this process; it does not
        coordinate across processes.

        NOTE(review): nothing in this class removes entries from ``_local_locks``,
        so the dict grows with every conversation id seen by this process.
        """
        lock = self._local_locks.get(uuid_conversation)
        if lock is None:
            lock = self._local_locks[uuid_conversation] = asyncio.Lock()
        return lock

    async def get_session(self, uuid_conversation: str) -> "SessionData":
        """
        Load a session from Redis.

        Returns:
            The stored session, with LangChain messages rebuilt from dicts. A fresh,
            empty ``SessionData`` is returned if the key does not exist. It is also
            returned, after logging, on timeout or any other error, such as
            undecodable JSON. A timeout additionally recreates the client.

        Raises:
            asyncio.CancelledError: re-raised after recreating the client, in case
                the interrupted command left the connection pool in a bad state.
        """
        try:
            data = await asyncio.wait_for(
                self.client.get(f"session:{uuid_conversation}"),
                timeout=self._OP_TIMEOUT,
            )
            if not data:
                logger.info(f"📝 Creando nueva sesión en Redis para: {uuid_conversation}")
                return SessionData(uuid_conversation)

            session_dict = json.loads(data)
            state = session_dict["state"]
            # Messages are stored as plain dicts; rebuild LangChain message objects.
            if "messages" in state:
                state["messages"] = messages_from_dict(state["messages"])
            return SessionData(uuid_conversation, state)
        except asyncio.TimeoutError:
            logger.error(f"❌ Timeout cargando sesión Redis {uuid_conversation} — recreando cliente")
            self._recreate_client()
            return SessionData(uuid_conversation)
        except asyncio.CancelledError:
            # Never swallow cancellation: clean up the client, then propagate.
            self._recreate_client()
            raise
        except Exception as e:
            logger.error(f"❌ Error cargando sesión Redis {uuid_conversation}: {type(e).__name__}: {e}")
            return SessionData(uuid_conversation)

    async def save_session(self, uuid_conversation: str, session: "SessionData"):
        """
        Persist a session in Redis under ``session:<uuid_conversation>``.

        The state is shallow-copied and LangChain messages are converted to dicts.
        The JSON payload (plus a local, timezone-naive ``updated_at`` timestamp) is
        written with SETEX using ``self.ttl``. Errors are logged and not raised
        (best effort), and a timeout also recreates the client.

        Raises:
            asyncio.CancelledError: re-raised after recreating the client.
        """
        try:
            # Shallow copy so the caller's state keeps its LangChain message objects.
            state_to_save = session.state.copy()
            if state_to_save.get("messages"):
                state_to_save["messages"] = messages_to_dict(state_to_save["messages"])

            data = {
                "uuid_conversation": uuid_conversation,
                "state": state_to_save,
                "updated_at": datetime.now().isoformat()
            }

            await asyncio.wait_for(
                self.client.setex(
                    f"session:{uuid_conversation}",
                    self.ttl,
                    json.dumps(data)
                ),
                timeout=self._OP_TIMEOUT,
            )
        except asyncio.TimeoutError:
            logger.error(f"❌ Timeout guardando sesión Redis {uuid_conversation} — recreando cliente")
            self._recreate_client()
        except asyncio.CancelledError:
            # Never swallow cancellation: clean up the client, then propagate.
            self._recreate_client()
            raise
        except Exception as e:
            logger.error(f"❌ Error guardando sesión Redis {uuid_conversation}: {type(e).__name__}: {e}")

    async def remove_session(self, uuid_conversation: str) -> bool:
        """
        Delete a session from Redis.

        Returns:
            True if the DELETE command completed (even if the key did not exist),
            False on timeout (which also recreates the client) or any other error.

        Raises:
            asyncio.CancelledError: re-raised after recreating the client.
        """
        try:
            await asyncio.wait_for(
                self.client.delete(f"session:{uuid_conversation}"),
                timeout=self._OP_TIMEOUT,
            )
            return True
        except asyncio.TimeoutError:
            logger.error(f"❌ Timeout borrando sesión Redis {uuid_conversation} — recreando cliente")
            self._recreate_client()
            return False
        except asyncio.CancelledError:
            self._recreate_client()
            raise
        except Exception as e:
            logger.error(f"❌ Error borrando sesión Redis {uuid_conversation}: {type(e).__name__}: {e}")
            return False

    async def get_active_sessions_count(self) -> int:
        """
        Count keys matching ``session:*`` by iterating SCAN. Expensive; use sparingly.

        On error the failure is logged and the count reached so far is returned.
        Cancellation propagates.
        """
        count = 0
        try:
            async for _ in self.client.scan_iter("session:*"):
                count += 1
        except Exception as e:
            logger.error(f"❌ Error contando sesiones Redis: {type(e).__name__}: {e}")
        return count

    async def close(self):
        """Close the current client after waiting for pending background closes."""
        pending = list(self._pending_closes)
        if pending:
            # Failures are already logged by _forget_close_task.
            await asyncio.gather(*pending, return_exceptions=True)
        await self._close_client(self.client)

# Global instance: a module-level singleton created at import time, so importing
# this module reads `settings` and constructs the Redis client immediately.
session_manager = RedisSessionManager()
