import asyncio
from typing import List, Optional
from langchain_core.tools import StructuredTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from config.settings import settings

class MCPClient:
    """Process-wide MCP client for Agrota's Spring AI (Java) MCP server.

    ``MCPClient()`` always returns the same instance: the first construction
    builds a ``MultiServerMCPClient`` pointed at ``settings.MCP_SERVER_URL``
    using the SSE transport, and later constructions reuse it. Tools fetched
    from the server are wrapped (see ``_wrap_tool``) and cached on the
    instance after the first successful, non-empty fetch.
    """

    _instance = None
    # Set by get_tools_async() once the server returns at least one tool.
    _tools_cache = None

    def __new__(cls):
        # Lazily create the single shared instance on first construction.
        # NOTE(review): not lock-protected; concurrent first constructions
        # from several threads could each build a client.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._init_client()
        return cls._instance

    def _init_client(self):
        """Build the underlying multi-server MCP client from ``settings``."""
        self.server_url = settings.MCP_SERVER_URL
        # Fall back to a fixed connection name when CLIENT_NAME is empty/unset.
        self.client_name = settings.CLIENT_NAME or "agrota-server"

        # Spring AI's SSE endpoint is usually /mcp/sse; the full URL is taken
        # as-is from MCP_SERVER_URL.
        self.client = MultiServerMCPClient({
            self.client_name: {
                "url": self.server_url,
                "transport": "sse"
            }
        })

    @staticmethod
    def _schema_has_request_field(schema) -> bool:
        """Return True if *schema* declares a top-level ``request`` argument.

        Accepts a JSON-schema dict (checked via its ``properties``), a
        pydantic v2 model class (``model_fields``) or a pydantic v1 model
        class (``__fields__``). Anything else yields False.
        """
        if isinstance(schema, dict):
            return "request" in (schema.get("properties") or {})
        # Prefer the v2 attribute and only fall back to ``__fields__`` when it
        # is absent: reading ``__fields__`` on a v2 model emits a deprecation
        # warning.
        fields = getattr(schema, "model_fields", None)
        if fields is None:
            fields = getattr(schema, "__fields__", None)
        return fields is not None and "request" in fields

    def _wrap_tool(self, tool: StructuredTool) -> StructuredTool:
        """Return a copy of *tool* that nests flat arguments under ``request``.

        Some server tools take a single ``request`` object parameter. When
        the tool's args schema has such a field and the caller passes flat
        arguments (no ``request`` key), they are sent as ``{"request": {...}}``
        to the original tool; otherwise arguments are forwarded unchanged.
        The returned tool copies ``name``, ``description`` and ``args_schema``
        from *tool*.

        Returns:
            A new ``StructuredTool`` supporting both ``invoke`` and ``ainvoke``.
        """
        original_ainvoke = tool.ainvoke
        needs_wrap = self._schema_has_request_field(tool.args_schema)

        # Inner functions intentionally have no docstrings:
        # StructuredTool.from_function may use __doc__ as the tool description.
        async def wrapped_ainvoke(**input_args):
            final_args = input_args
            if needs_wrap and "request" not in input_args:
                final_args = {"request": input_args}
            return await original_ainvoke(final_args)

        def invoke_sync(**kwargs):
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop running in this thread: run on a fresh loop.
                return asyncio.run(wrapped_ainvoke(**kwargs))
            # asyncio.run() refuses to start inside a running loop, so run the
            # coroutine on its own loop in a helper thread and wait for it.
            # NOTE(review): this blocks the caller's loop while the tool runs;
            # async callers should prefer ainvoke().
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                future = pool.submit(
                    lambda: asyncio.run(wrapped_ainvoke(**kwargs))
                )
                return future.result()

        return StructuredTool.from_function(
            func=invoke_sync,
            coroutine=wrapped_ainvoke,
            name=tool.name,
            description=tool.description,
            args_schema=tool.args_schema,
        )

    async def get_tools_async(self) -> List[StructuredTool]:
        """Fetch, wrap and cache the tools exposed by the Agrota MCP server.

        Subsequent calls return the cached list. An empty result is not
        cached, so the next call asks the server again. Any exception raised
        while fetching or wrapping is printed and an empty list is returned
        instead of propagating.

        Returns:
            The wrapped tools, or ``[]`` when none were found or an error
            occurred.
        """
        if self._tools_cache is not None:
            return self._tools_cache

        try:
            # With the adapter version used by lhia-v3, this waits until the
            # connection is established.
            raw_tools = await self.client.get_tools()
            if raw_tools:
                self._tools_cache = [self._wrap_tool(t) for t in raw_tools]
                tool_names = [t.name for t in self._tools_cache]
                print(f"\n{'='*50}\n🛠️  MCP TOOLS AGROTA: {tool_names}\n{'='*50}\n", flush=True)
                return self._tools_cache
            print("⚠️ No se encontraron herramientas en el servidor MCP de Agrota.")
            return []
        except Exception as e:
            # Deliberate best-effort: report and degrade to "no tools".
            print(f"❌ Error MCP Tools en Agrota: {e}")
            return []
