from enum import Enum
from typing import Any

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore
from pydantic import Field, model_validator
from typing_extensions import override

from langchain_classic.storage._lc_store import create_kv_docstore


class SearchType(str, Enum):
    """Search strategies a retriever can run against its vector store.

    Each member is also a plain `str`, so it compares equal to (and can be
    constructed from) its string value, e.g. `SearchType("mmr")`.
    """

    similarity = "similarity"
    """Plain similarity search."""

    similarity_score_threshold = "similarity_score_threshold"
    """Similarity search filtered by a score threshold."""

    mmr = "mmr"
    """Similarity search reranked with Maximal Marginal Relevance."""


class MultiVectorRetriever(BaseRetriever):
    """Retriever that supports multiple embeddings per parent document.

    This retriever is designed for scenarios where documents are split into
    smaller chunks for embedding and vector search, but retrieval returns
    the original parent documents rather than individual chunks.

    It works by:
    - Performing similarity, score-threshold similarity, or MMR search over
      the embedded child chunks (selected by `search_type`)
    - Collecting unique parent document IDs from chunk metadata (read from the
      `id_key` entry; chunks without that entry are ignored)
    - Fetching the corresponding parent documents from the docstore and
      returning those that exist (IDs the docstore cannot resolve are dropped)

    This pattern is commonly used in RAG pipelines to improve answer grounding
    while preserving full document context.
    """

    vectorstore: VectorStore
    """The underlying `VectorStore` to use to store small chunks
    and their embedding vectors"""

    byte_store: ByteStore | None = None
    """The lower-level backing storage layer for the parent documents"""

    docstore: BaseStore[str, Document]
    """The storage interface for the parent documents"""

    id_key: str = "doc_id"
    """Metadata key on each child chunk that holds its parent document's ID."""

    search_kwargs: dict = Field(default_factory=dict)
    """Keyword arguments to pass to the search function."""

    search_type: SearchType = SearchType.similarity
    """Type of search to perform
    (similarity / similarity_score_threshold / mmr)"""

    @model_validator(mode="before")
    @classmethod
    def _shim_docstore(cls, values: dict) -> Any:
        """Populate `docstore` from `byte_store` before field validation.

        If a `byte_store` is supplied, the docstore is built from it with
        `create_kv_docstore`, taking precedence over any `docstore` that was
        also passed. Otherwise the supplied `docstore` is used as-is.

        Args:
            values: Raw input values for the model.

        Returns:
            A new mapping with `docstore` set. The input mapping is left
            untouched so a caller-owned dict is never mutated.

        Raises:
            ValueError: If neither `byte_store` nor `docstore` is provided.
        """
        byte_store = values.get("byte_store")
        docstore = values.get("docstore")
        if byte_store is not None:
            docstore = create_kv_docstore(byte_store)
        elif docstore is None:
            msg = "You must pass a `byte_store` parameter."
            raise ValueError(msg)
        # Build a fresh dict instead of assigning into `values`: the input may
        # be an object owned by the caller (e.g. passed to `model_validate`).
        return {**values, "docstore": docstore}

    def _collect_parent_ids(self, sub_docs: list[Document]) -> list[str]:
        """Return the unique parent IDs referenced by `sub_docs`, in first-seen order.

        Chunks whose metadata does not contain `id_key` are skipped.

        Args:
            sub_docs: Child chunks returned by the vector store search.

        Returns:
            De-duplicated parent document IDs, ordered by first occurrence.
        """
        id_key = self.id_key
        # `dict` preserves insertion order, giving an ordered de-duplication in
        # linear time. IDs are docstore keys (`str`), hence hashable.
        return list(
            dict.fromkeys(
                sub_doc.metadata[id_key]
                for sub_doc in sub_docs
                if id_key in sub_doc.metadata
            )
        )

    @override
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant parent documents, ordered by the first matching
            child chunk; parents missing from the docstore are omitted.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = self.vectorstore.max_marginal_relevance_search(
                query,
                **self.search_kwargs,
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                self.vectorstore.similarity_search_with_relevance_scores(
                    query,
                    **self.search_kwargs,
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = self.vectorstore.similarity_search(query, **self.search_kwargs)

        ids = self._collect_parent_ids(sub_docs)
        docs = self.docstore.mget(ids)
        # `mget` yields `None` for IDs it cannot resolve; drop those.
        return [d for d in docs if d is not None]

    @override
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Asynchronously get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use
        Returns:
            List of relevant parent documents, ordered by the first matching
            child chunk; parents missing from the docstore are omitted.
        """
        if self.search_type == SearchType.mmr:
            sub_docs = await self.vectorstore.amax_marginal_relevance_search(
                query,
                **self.search_kwargs,
            )
        elif self.search_type == SearchType.similarity_score_threshold:
            sub_docs_and_similarities = (
                await self.vectorstore.asimilarity_search_with_relevance_scores(
                    query,
                    **self.search_kwargs,
                )
            )
            # Only the documents are needed; the scores are discarded.
            sub_docs = [sub_doc for sub_doc, _ in sub_docs_and_similarities]
        else:
            sub_docs = await self.vectorstore.asimilarity_search(
                query,
                **self.search_kwargs,
            )

        ids = self._collect_parent_ids(sub_docs)
        docs = await self.docstore.amget(ids)
        # `amget` yields `None` for IDs it cannot resolve; drop those.
        return [d for d in docs if d is not None]
