"""Semantic cache middleware for LLM responses.

This module provides a middleware that caches LLM responses based on
semantic similarity using Redis and vector embeddings. Compatible with
LangChain's AgentMiddleware protocol for use with create_agent.
"""

import json
import logging
import uuid
from typing import Any, Awaitable, Callable, List, Union

from langchain.agents.middleware.types import (
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)
from langchain_core.messages import AIMessage
from redisvl.extensions.cache.llm import SemanticCache

from langgraph.checkpoint.redis.jsonplus_redis import JsonPlusRedisSerializer

from .aio import AsyncRedisMiddleware
from .types import SemanticCacheConfig

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Shared serializer instance. _deserialize_response uses it to revive cached
# payloads stored in LangChain "constructor" JSON format back into objects.
_serializer = JsonPlusRedisSerializer()


def _strip_content_ids(content: Any) -> Any:
    """Return ``content`` with the ``"id"`` key removed from dict blocks.

    With the OpenAI Responses API, message content is a list of blocks that
    carry embedded item IDs (``rs_``/``msg_`` prefixes). Replaying those IDs
    from a cached message causes duplicate-ID errors, so they are dropped.

    Args:
        content: Message content. Anything that is not a list is returned
            unchanged.

    Returns:
        A new list in which every dict block that had an ``"id"`` key is
        replaced by a copy without it; all other blocks are passed through
        as the same objects. Non-list input is returned as-is.
    """
    if not isinstance(content, list):
        return content
    return [
        {key: value for key, value in item.items() if key != "id"}
        if isinstance(item, dict) and "id" in item
        else item
        for item in content
    ]


def _serialize_response(response: Any) -> str:
    """Turn a model response into a JSON string suitable for cache storage.

    The payload is chosen by the first rule that matches:

    1. ``response.result`` is a list (ModelResponse-like): walk it from the
       end and use the first message that has ``to_json()``; failing that,
       a message with a ``content`` attribute is stored as
       ``{"content": ...}``. If no message qualifies, ``{"content": ""}``.
    2. ``response`` itself has ``to_json()``: use its return value.
    3. ``response`` is a dict: use it directly.
    4. Otherwise: ``{"content": response.content}``, or the ``str()`` of the
       response when it has no ``content`` attribute.

    Args:
        response: The response to serialize.

    Returns:
        JSON string representation of the selected payload.
    """
    candidates = getattr(response, "result", None)

    if isinstance(candidates, list):
        # Default used when the list is empty or nothing in it is usable.
        payload: Any = {"content": ""}
        for message in reversed(candidates):
            if hasattr(message, "to_json"):
                payload = message.to_json()
                break
            if hasattr(message, "content"):
                payload = {"content": message.content}
                break
    elif hasattr(response, "to_json"):
        payload = response.to_json()
    elif isinstance(response, dict):
        payload = response
    else:
        payload = {"content": getattr(response, "content", str(response))}

    return json.dumps(payload)


def _deserialize_response(cached_str: str) -> ModelResponse:
    """Rebuild a ModelResponse from a cached JSON string.

    A ModelResponse is returned for every input shape so agent routing sees a
    uniform type:

    * LangChain constructor-format dict (``lc`` in (1, 2) and
      ``type == "constructor"``): revived via the module's
      JsonPlusRedisSerializer. A revived AIMessage is copied with a fresh id,
      content-block ids stripped, ``additional_kwargs`` replaced by
      ``{"cached": True}`` and ``response_metadata`` emptied. Any other
      revived object contributes its ``content`` attribute (or its ``str()``)
      to a new AIMessage.
    * Any other dict: its ``"content"`` value (default ``""``) is wrapped.
    * Non-dict JSON: its ``str()`` is wrapped.
    * Invalid JSON: the raw string is wrapped.

    IMPORTANT: every call generates a NEW message id (UUID). Per the design
    of this module, unique ids keep a streaming frontend from de-duplicating
    repeated cache hits. Every returned message carries
    ``additional_kwargs={"cached": True}`` so consumers can detect cache hits.

    Args:
        cached_str: The cached JSON string.

    Returns:
        A ModelResponse containing one AIMessage with a unique id.
    """
    # One id per cache hit, shared by whichever branch builds the message.
    fresh_id = str(uuid.uuid4())

    def _as_cached(content: Any) -> ModelResponse:
        # Wrap arbitrary content in a cache-marked AIMessage response.
        marked = AIMessage(
            content=content,
            id=fresh_id,
            additional_kwargs={"cached": True},
        )
        return ModelResponse(result=[marked], structured_response=None)

    try:
        payload = json.loads(cached_str)

        if not isinstance(payload, dict):
            # Valid JSON but not an object (number, list, string, ...).
            return _as_cached(str(payload))

        is_lc_constructor = (
            payload.get("lc") in (1, 2) and payload.get("type") == "constructor"
        )
        if not is_lc_constructor:
            # Plain dict: only its "content" entry is used.
            return _as_cached(payload.get("content", ""))

        revived = _serializer._revive_if_needed(payload)
        if not isinstance(revived, AIMessage):
            # Revived into something else; keep just its content.
            return _as_cached(getattr(revived, "content", str(revived)))

        # Copy the revived AIMessage, overriding only the fields listed here.
        overrides = {
            "id": fresh_id,
            "content": _strip_content_ids(revived.content),
            "additional_kwargs": {"cached": True},
            "response_metadata": {},
        }
        return ModelResponse(
            result=[revived.model_copy(update=overrides)],
            structured_response=None,
        )
    except json.JSONDecodeError:
        # Not JSON at all: treat the cached string as the message text.
        return _as_cached(cached_str)


class SemanticCacheMiddleware(AsyncRedisMiddleware):
    """Middleware that caches LLM responses based on semantic similarity.

    Uses ``redisvl.extensions.cache.llm.SemanticCache`` to store and retrieve
    cached responses. When a request is semantically similar to a previous
    request (within the distance threshold), the cached response is returned
    without calling the LLM.

    The cache key is the text of the last user message in the request. When
    ``cache_final_only`` is enabled in the config, only "final" responses
    (those without tool_calls) are stored, so intermediate responses that
    still require tool execution are never cached.

    Requests that contain results from tools not listed in
    ``deterministic_tools`` bypass the cache entirely (no lookup, no store).

    Example:
        ```python
        from langgraph.middleware.redis import (
            SemanticCacheMiddleware,
            SemanticCacheConfig,
        )

        config = SemanticCacheConfig(
            redis_url="redis://localhost:6379",
            distance_threshold=0.1,
            ttl_seconds=3600,
        )

        middleware = SemanticCacheMiddleware(config)

        async def call_llm(request):
            # Your LLM call here
            return response

        # Use middleware
        result = await middleware.awrap_model_call(request, call_llm)
        ```
    """

    # Created in _setup_async, not in __init__.
    _cache: SemanticCache
    _config: SemanticCacheConfig

    def __init__(self, config: SemanticCacheConfig) -> None:
        """Initialize the semantic cache middleware.

        Args:
            config: Configuration for the semantic cache.
        """
        super().__init__(config)
        self._config = config

    async def _setup_async(self) -> None:
        """Set up the SemanticCache instance.

        Builds the keyword arguments from the config and constructs the
        cache. No Redis client object is passed in: the cache receives
        ``redis_url`` (preferred) or ``connection_kwargs`` and manages its
        own connection. Vectorizer and TTL are forwarded only when set.
        """
        cache_kwargs: dict[str, Any] = {
            "name": self._config.name,
            "distance_threshold": self._config.distance_threshold,
        }

        # redis_url takes precedence over explicit connection arguments.
        if self._config.redis_url:
            cache_kwargs["redis_url"] = self._config.redis_url
        elif self._config.connection_args:
            cache_kwargs["connection_kwargs"] = self._config.connection_args

        if self._config.vectorizer is not None:
            cache_kwargs["vectorizer"] = self._config.vectorizer

        if self._config.ttl_seconds is not None:
            cache_kwargs["ttl"] = self._config.ttl_seconds

        self._cache = SemanticCache(**cache_kwargs)

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content into plain text.

        Args:
            content: A string, ``None``, a list of content blocks (strings or
                dicts), or any other value.

        Returns:
            The string itself; ``""`` for ``None``; for a list, the
            concatenation of string blocks and of the ``"text"`` value of
            dict blocks whose ``"type"`` is ``"text"`` (other blocks, e.g.
            images, are ignored); otherwise ``str(content)``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: List[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    parts.append(block["text"])
            return "".join(parts)
        return str(content)

    def _extract_prompt(self, messages: List[Union[dict[str, Any], Any]]) -> str:
        """Extract the prompt to use for cache lookup.

        Scans the messages from newest to oldest and returns the text of the
        first user message found. Both dict-style messages and message
        objects are supported; for either, the kind is read from ``type``
        and then ``role`` and must be ``"user"`` or ``"human"``.

        Args:
            messages: List of messages from the request.

        Returns:
            The user message text as a string (content-block lists are
            flattened to their text parts), or ``""`` when there are no
            messages or no user message.
        """
        if not messages:
            return ""

        for message in reversed(messages):
            if isinstance(message, dict):
                kind = message.get("type") or message.get("role")
                content = message.get("content", "")
            else:
                kind = getattr(message, "type", None) or getattr(
                    message, "role", None
                )
                content = getattr(message, "content", "")
            if kind in ("user", "human"):
                # Always hand the cache a string, never raw content blocks.
                return self._content_to_text(content)

        return ""

    def _is_final_response(self, response: Any) -> bool:
        """Check if the response is a final response (no tool calls).

        Args:
            response: The model response to check (dict or LangChain type).

        Returns:
            True if the response is final (should be cached), False otherwise.
        """
        if isinstance(response, dict):
            tool_calls = response.get("tool_calls")
        else:
            tool_calls = getattr(response, "tool_calls", None)
            # ModelResponse has no tool_calls of its own; inspect the first
            # message of its result list instead.
            if tool_calls is None and hasattr(response, "result"):
                result = response.result
                if result:
                    tool_calls = getattr(result[0], "tool_calls", None)
        # Missing or empty tool_calls both mean "final".
        return not tool_calls

    def _get_tool_names_from_results(
        self, messages: List[Union[dict[str, Any], Any]]
    ) -> List[str]:
        """Extract tool names from tool result messages.

        A dict message counts as a tool result when its ``role`` (or, if
        that is empty, its ``type``) equals ``"tool"``. An object counts
        when its ``type``/``role`` is ``"tool"`` or its class is named
        ``ToolMessage``.

        Args:
            messages: List of messages from the request.

        Returns:
            One entry per tool result message, in order. An entry may be
            empty (or ``None``) when the message carries no name.
        """
        tool_names = []
        for message in messages:
            if isinstance(message, dict):
                role = message.get("role", "") or message.get("type", "")
                if role == "tool":
                    tool_names.append(message.get("name", ""))
            else:
                msg_type = getattr(message, "type", None) or getattr(
                    message, "role", None
                )
                if msg_type == "tool" or message.__class__.__name__ == "ToolMessage":
                    tool_names.append(getattr(message, "name", ""))
        return tool_names

    def _should_skip_cache_for_tool_results(
        self, messages: List[Union[dict[str, Any], Any]]
    ) -> bool:
        """Check if cache should be skipped due to tool results.

        When tool results are present, we check if ALL named tools are in
        the deterministic_tools list. If so, caching is safe. Otherwise,
        we skip the cache to avoid returning stale responses.

        Args:
            messages: List of messages from the request.

        Returns:
            True if cache should be skipped, False if caching is OK.
        """
        tool_names = self._get_tool_names_from_results(messages)

        if not tool_names:
            # No tool results - caching is OK
            return False

        deterministic = self._config.deterministic_tools
        # If deterministic_tools is not configured, always skip cache
        # when tool results are present (safest default)
        if deterministic is None:
            return True

        # NOTE(review): tool results without a name are ignored here, i.e.
        # treated as deterministic - confirm this is intended.
        return any(name and name not in deterministic for name in tool_names)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelCallResult:
        """Wrap a model call with semantic caching.

        Checks the cache for a semantically similar request. If found,
        returns the cached response. Otherwise, calls the handler and
        optionally caches the result.

        The cache is bypassed completely (no lookup and no store) when the
        request has no user prompt, or when it contains results from
        non-deterministic tools: a response derived from such results must
        not be replayed for a later request with the same prompt.

        Args:
            request: The model request containing messages.
            handler: The async function to call the model.

        Returns:
            The model response (from cache or handler).

        Raises:
            Exception: If graceful_degradation is False and cache operations fail.
        """
        await self._ensure_initialized_async()

        # Support both dict-style and LangChain ModelRequest types
        if isinstance(request, dict):
            messages = request.get("messages", [])
        else:
            messages = getattr(request, "messages", [])
        prompt = self._extract_prompt(messages)

        if not prompt:
            # Nothing to key the cache on; just call the handler.
            return await handler(request)

        if self._should_skip_cache_for_tool_results(messages):
            # Neither read nor write: storing this response under the user
            # prompt would serve stale tool-derived output on a later hit.
            logger.debug(
                "Skipping cache - request contains non-deterministic tool results"
            )
            return await handler(request)

        # Try to get from cache using async method
        try:
            cached = await self._cache.acheck(prompt=prompt)
            if cached:
                cached_response = cached[0].get("response")
                if cached_response:
                    logger.debug("Cache hit for prompt: %s...", prompt[:50])
                    return _deserialize_response(cached_response)
        except Exception as e:
            if not self._graceful_degradation:
                raise
            logger.warning("Cache check failed, calling handler: %s", e)

        # Cache miss - call handler
        response = await handler(request)

        # Store unless only final responses are wanted and this one still
        # requests tool calls.
        should_cache = not self._config.cache_final_only or self._is_final_response(
            response
        )

        if should_cache:
            try:
                response_str = _serialize_response(response)
                await self._cache.astore(prompt=prompt, response=response_str)
                logger.debug("Cached response for prompt: %s...", prompt[:50])
            except Exception as e:
                if not self._graceful_degradation:
                    raise
                logger.warning("Cache store failed: %s", e)

        return response
