import asyncio
import threading
from langchain_core.tools import StructuredTool


def _schema_has_request_field(schema) -> bool:
    """Return True if ``schema`` declares a top-level ``request`` field.

    Accepts a model class exposing ``model_fields`` (pydantic v2) or
    ``__fields__`` (pydantic v1), or a JSON-schema ``dict`` whose
    ``properties`` mapping is inspected. Anything else (including a falsy
    schema) yields False.
    """
    if not schema:
        return False
    if hasattr(schema, "model_fields"):
        return "request" in schema.model_fields
    if hasattr(schema, "__fields__"):
        return "request" in schema.__fields__
    if isinstance(schema, dict):
        # ``or {}`` also tolerates an explicit ``"properties": None``.
        return "request" in (schema.get("properties") or {})
    return False


def _run_coroutine_blocking(make_coro, timeout, label):
    """Run the coroutine produced by ``make_coro()`` to completion and return its result.

    If the calling thread has no running event loop, the coroutine is driven
    with ``asyncio.run``. If a loop *is* running (e.g. a sync tool call made
    from inside FastAPI/uvicorn), that loop cannot be re-entered, so the
    coroutine runs on a fresh loop in a daemon worker thread and this call
    blocks until it finishes.

    Args:
        make_coro: Zero-argument callable returning the coroutine. It is called
            in the thread that runs it, so no un-awaited coroutine is created
            on paths that never run it.
        timeout: Seconds to wait for the worker thread, or None to wait
            indefinitely. Only applies to the worker-thread path.
        label: Human-readable name used in the thread name and timeout error.

    Raises:
        TimeoutError: The worker thread did not finish within ``timeout``.
        BaseException: Whatever the coroutine raised is re-raised unchanged.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: drive the coroutine directly.
        # Kept outside the try so errors from the tool itself are never
        # mistaken for "no running loop".
        return asyncio.run(make_coro())

    outcome: dict = {}

    def worker():
        try:
            outcome["result"] = asyncio.run(make_coro())
        except BaseException as exc:  # re-raised in the calling thread below
            # Capture everything (incl. CancelledError, a BaseException) so the
            # caller sees the real failure instead of a silent ``None``.
            outcome["error"] = exc

    # Daemon so a hung tool call cannot block interpreter shutdown after we
    # have stopped waiting on it.
    thread = threading.Thread(target=worker, name=f"tool-sync-{label}", daemon=True)
    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        raise TimeoutError(f"Tool '{label}' did not finish within {timeout} seconds")
    if "error" in outcome:
        raise outcome["error"]
    return outcome["result"]


def create_tool_wrapper(
    original_tool: StructuredTool, *, timeout: "float | None" = 60.0
) -> StructuredTool:
    """
    Wraps an MCP tool so it works correctly in both sync and async contexts.
    Handles Spring AI's 'request' argument convention automatically.

    Args:
        original_tool: The tool to wrap; its ``name``, ``description`` and
            ``args_schema`` are copied onto the wrapper.
        timeout: Maximum seconds a *sync* call made while an event loop is
            running in the calling thread waits for the tool (default 60).
            ``None`` waits indefinitely.

    Returns:
        A ``StructuredTool`` whose sync entry point works with or without a
        running event loop, and whose async entry point awaits
        ``original_tool.ainvoke`` directly on the caller's loop.

    Raises (from the sync entry point):
        TimeoutError: The tool did not finish within ``timeout``.
    """
    # Resolved once: whether the schema expects all arguments nested under a
    # single "request" field (the Spring AI convention).
    needs_wrap = _schema_has_request_field(getattr(original_tool, "args_schema", None))

    async def _ainvoke(**kwargs):
        if needs_wrap and "request" not in kwargs:
            kwargs = {"request": kwargs}
        return await original_tool.ainvoke(kwargs)

    def _invoke_sync(**kwargs):
        return _run_coroutine_blocking(
            lambda: _ainvoke(**kwargs), timeout, original_tool.name
        )

    return StructuredTool.from_function(
        func=_invoke_sync,
        # Async callers await the original tool on their own loop instead of
        # being routed through a thread executor and a fresh event loop.
        coroutine=_ainvoke,
        name=original_tool.name,
        description=original_tool.description,
        args_schema=original_tool.args_schema,
    )
