import warnings
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, field_validator, model_validator
from redis.commands.search.aggregation import AggregateRequest, Desc
from typing_extensions import Self

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType
from redisvl.utils.full_text_query_helper import FullTextQueryHelper
from redisvl.utils.utils import lazy_import

# Optional NLP dependencies, wrapped with lazy_import (presumably so that
# importing this module does not eagerly import nltk — confirm against
# redisvl.utils.utils.lazy_import).
# NOTE(review): neither name is referenced in the code visible here; stopword
# handling appears to be delegated to FullTextQueryHelper. Confirm these
# module-level names are still needed (other modules may import them).
nltk = lazy_import("nltk")
nltk_stopwords = lazy_import("nltk.corpus.stopwords")


class Vector(BaseModel):
    """
    Arguments describing one vector clause of a multi vector query.

    Args:
        vector: The query vector, either as a list of floats or as raw bytes.
            A list is packed into bytes (using ``dtype``) after validation.
        field_name: Name of the indexed vector field to search.
        dtype: Data type of the vector, case-insensitive member of
            VectorDataType (default: "float32").
        weight: Weight of this vector's similarity in the combined score
            (default: 1.0).
        max_distance: Upper bound for the vector range search, inclusive
            range [0.0, 2.0] (default: 2.0).
    """

    vector: Union[List[float], bytes]
    field_name: str
    dtype: str = "float32"
    weight: float = 1.0
    # Bounds are enforced both by Field(ge/le) and by validate_max_distance.
    max_distance: float = Field(default=2.0, ge=0.0, le=2.0)

    @field_validator("dtype")
    @classmethod
    def validate_dtype(cls, dtype: str) -> str:
        """Ensure ``dtype`` names a VectorDataType member, ignoring case.

        The original (un-uppercased) string is returned unchanged.
        """
        try:
            VectorDataType(dtype.upper())
        except ValueError:
            supported = [member.lower() for member in VectorDataType]
            raise ValueError(
                f"Invalid data type: {dtype}. Supported types are: {supported}"
            )
        return dtype

    @field_validator("max_distance")
    @classmethod
    def validate_max_distance(cls, max_distance: float) -> float:
        """Ensure ``max_distance`` is numeric and lies within [0.0, 2.0]."""
        message = "max_distance must be a value between 0.0 and 2.0"
        is_number = isinstance(max_distance, (float, int))
        if not is_number or max_distance < 0.0 or max_distance > 2.0:
            raise ValueError(message)
        return max_distance

    @model_validator(mode="after")
    def validate_vector(self) -> Self:
        """Pack a list-of-floats vector into a byte buffer using ``dtype``."""
        if not isinstance(self.vector, bytes):
            self.vector = array_to_buffer(self.vector, self.dtype)
        return self


class AggregationQuery(AggregateRequest):
    """
    Common base for the aggregation queries defined in this module.

    It adds nothing to ``AggregateRequest`` beyond requiring the query string
    as a positional argument; subclasses build that string and then configure
    the aggregation pipeline on ``self``.
    """

    def __init__(self, query_string: str) -> None:
        # Pure delegation: the query string becomes the request's base query.
        super().__init__(query_string)


class AggregateHybridQuery(AggregationQuery):
    """
    AggregateHybridQuery combines text and vector search in Redis.
    It allows you to perform a hybrid search using both text and vector similarity.
    It scores documents based on a weighted combination of text and vector similarity.

    .. code-block:: python

        from redisvl.query import AggregateHybridQuery
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        query = AggregateHybridQuery(
            text="example text",
            text_field_name="text_field",
            vector=[0.1, 0.2, 0.3],
            vector_field_name="vector_field",
            text_scorer="BM25STD",
            filter_expression=None,
            alpha=0.7,
            dtype="float32",
            num_results=10,
            return_fields=["field1", "field2"],
            stopwords="english",
            dialect=2,
        )

        results = index.query(query)

    """

    DISTANCE_ID: str = "vector_distance"
    VECTOR_PARAM: str = "vector"

    def __init__(
        self,
        text: str,
        text_field_name: str,
        vector: Union[bytes, List[float]],
        vector_field_name: str,
        text_scorer: str = "BM25STD",
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        alpha: float = 0.7,
        dtype: str = "float32",
        num_results: int = 10,
        return_fields: Optional[List[str]] = None,
        stopwords: Optional[Union[str, Set[str]]] = "english",
        dialect: int = 2,
        text_weights: Optional[Dict[str, float]] = None,
    ):
        """
        Instantiates an AggregateHybridQuery object.

        Args:
            text (str): The text to search for. Must contain a non-whitespace
                character.
            text_field_name (str): The text field name to search in.
            vector (Union[bytes, List[float]]): The query vector. A list is
                packed into bytes with ``dtype`` when ``params`` is read.
            vector_field_name (str): The vector field name to search in.
            text_scorer (str, optional): The text scorer to use. Options are
                {TFIDF, TFIDF.DOCNORM, BM25, DISMAX, DOCSCORE, BM25STD}.
                Defaults to "BM25STD".
            filter_expression (Optional[Union[str, FilterExpression]], optional):
                The filter expression to use. Defaults to None.
            alpha (float, optional): The weight of the vector similarity.
                Documents are scored as
                hybrid_score = (alpha) * vector_score + (1-alpha) * text_score.
                Defaults to 0.7.
            dtype (str, optional): The data type of the vector. Defaults to "float32".
            num_results (int, optional): The number of results to return; also
                used as K in the KNN clause. Defaults to 10.
            return_fields (Optional[List[str]], optional): The fields to load
                into the results. Defaults to None.
            stopwords (Optional[Union[str, Set[str]]], optional): Stopwords to
                remove from ``text`` before searching. A language name such as
                "english" or "german" selects that language's default set; a
                list, set, or tuple of strings is used as-is; None disables
                stopword removal. Defaults to "english".

                Note: This parameter controls query-time stopword filtering (client-side).
                For index-level stopwords configuration (server-side), see IndexInfo.stopwords.
                Using query-time stopwords with index-level STOPWORDS 0 is counterproductive.
            dialect (int, optional): The Redis dialect version. Defaults to 2.
            text_weights (Optional[Dict[str, float]]): Per-word importance
                weights applied within the query text. Defaults to None, which
                leaves the text_scorer score unmodified.

        Note:
            AggregateHybridQuery uses FT.AGGREGATE commands which do NOT support runtime
            parameters. For runtime parameter support (ef_runtime, search_window_size, etc.),
            use VectorQuery or VectorRangeQuery which use FT.SEARCH commands.

        Raises:
            ValueError: If the text string is empty, or if the text string becomes empty after
                stopwords are removed.
            TypeError: If the stopwords are not a set, list, or tuple of strings
                (validated by FullTextQueryHelper).
        """

        if not text.strip():
            raise ValueError("text string cannot be empty")

        # State consumed later by _build_query_string() and the params property.
        self._text = text
        self._text_field = text_field_name
        self._vector = vector
        self._vector_field = vector_field_name
        self._filter_expression = filter_expression
        self._alpha = alpha
        self._dtype = dtype
        self._num_results = num_results

        self._ft_helper = FullTextQueryHelper(
            stopwords=stopwords,
            text_weights=text_weights,
        )

        super().__init__(self._build_query_string())

        # Map the yielded distance to a similarity, then blend it with the
        # text score using alpha.
        similarity_expr = f"(2 - @{self.DISTANCE_ID})/2"
        hybrid_expr = f"{1 - alpha}*@text_score + {alpha}*@vector_similarity"

        # The order of these calls defines the order of the aggregation plan.
        self.scorer(text_scorer)
        self.add_scores()
        self.apply(vector_similarity=similarity_expr, text_score="@__score")
        self.apply(hybrid_score=hybrid_expr)
        self.sort_by(Desc("@hybrid_score"), max=num_results)  # type: ignore
        self.dialect(dialect)
        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Query parameters to send alongside the aggregation.

        Returns:
            Dict[str, Any]: ``{VECTOR_PARAM: <vector bytes>}``; a list vector is
            packed with the configured dtype on each access.
        """
        raw = self._vector
        encoded = (
            array_to_buffer(raw, dtype=self._dtype) if isinstance(raw, list) else raw
        )
        return {self.VECTOR_PARAM: encoded}

    @property
    def stopwords(self) -> Set[str]:
        """Stopwords removed from the query text.

        Returns:
            Set[str]: The stopword set held by the full-text helper.
        """
        return self._ft_helper.stopwords

    @property
    def text_weights(self) -> Dict[str, float]:
        """Per-word weights applied to the query text.

        Returns:
            Dictionary of word:weight mappings held by the full-text helper.
        """
        return self._ft_helper.text_weights

    def set_text_weights(self, weights: Dict[str, float]):
        """Replace the text weights and rebuild the query string to match.

        Args:
            weights: Dictionary of word:weight mappings
        """
        self._ft_helper.set_text_weights(weights)
        # _query is the base query string stored by AggregateRequest.
        self._query = self._build_query_string()

    def _build_query_string(self) -> str:
        """Compose ``<text clause>=>[KNN k @field $vector AS vector_distance]``."""
        text_clause = self._ft_helper.build_query_string(
            self._text, self._text_field, self._filter_expression
        )
        knn_clause = (
            f"KNN {self._num_results} @{self._vector_field} ${self.VECTOR_PARAM}"
            f" AS {self.DISTANCE_ID}"
        )
        return f"{text_clause}=>[{knn_clause}]"

    def __str__(self) -> str:
        """Render the aggregation arguments as one space-separated string."""
        return " ".join(map(str, self.build_args()))


class MultiVectorQuery(AggregationQuery):
    """
    MultiVectorQuery allows for search over multiple vector fields in a document simultaneously.
    The final score will be a weighted combination of the individual vector similarity scores
    following the formula:

    score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

    where score_i = (2 - distance_i) / 2 is the similarity derived from the
    range-search distance of the i-th vector.

    Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

    .. code-block:: python

        from redisvl.query import MultiVectorQuery, Vector
        from redisvl.index import SearchIndex

        index = SearchIndex.from_yaml("path/to/index.yaml")

        vector_1 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float32",
            weight=0.7,
        )
        vector_2 = Vector(
            vector=[0.5, 0.5],
            field_name="image_vector",
            dtype="bfloat16",
            weight=0.2,
        )
        vector_3 = Vector(
            vector=[0.1, 0.2, 0.3],
            field_name="text_vector",
            dtype="float64",
            weight=0.5,
        )

        query = MultiVectorQuery(
            vectors=[vector_1, vector_2, vector_3],
            filter_expression=None,
            num_results=10,
            return_fields=["field1", "field2"],
            dialect=2,
        )

        results = index.query(query)
    """

    _vectors: List[Vector]

    def __init__(
        self,
        vectors: Union[Vector, List[Vector]],
        return_fields: Optional[List[str]] = None,
        filter_expression: Optional[Union[str, FilterExpression]] = None,
        num_results: int = 10,
        dialect: int = 2,
    ):
        """
        Instantiates a MultiVectorQuery object.

        Args:
            vectors (Union[Vector, List[Vector]]): The Vector, or non-empty
                sequence of Vectors, to perform vector similarity search with.
                The sequence is copied, so later changes to it do not affect
                this query.
            return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
            filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
                Defaults to None.
            num_results (int, optional): The number of results to return. Defaults to 10.
            dialect (int, optional): The Redis dialect version. Defaults to 2.

        Raises:
            TypeError: If ``vectors`` is not a Vector or an iterable of Vectors.
            ValueError: If ``vectors`` is an empty sequence.
        """

        self._filter_expression = filter_expression
        self._num_results = num_results

        if isinstance(vectors, Vector):
            vectors = [vectors]
        try:
            # Take a private copy: a one-shot iterator would otherwise be
            # exhausted by the type check below, and later mutation of the
            # caller's list would make the query string and params disagree.
            self._vectors = list(vectors)
        except TypeError:
            raise TypeError(
                "vector argument must be a Vector object or list of Vector objects."
            ) from None

        if not all(isinstance(v, Vector) for v in self._vectors):
            raise TypeError(
                "vector argument must be a Vector object or list of Vector objects."
            )
        if not self._vectors:
            # An empty list would yield an empty query string and an empty
            # combined_score expression, which only fails later inside Redis.
            raise ValueError("vectors must contain at least one Vector object.")

        query_string = self._build_query_string()
        super().__init__(query_string)

        # Convert each yielded distance into a similarity (score_i). Then
        # combine the similarities as a weighted sum. All score_i APPLY steps
        # come before combined_score, which references them.
        weighted_terms = []
        for i, vec in enumerate(self._vectors):
            self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})
            weighted_terms.append(f"@score_{i} * {vec.weight}")

        self.apply(combined_score=" + ".join(weighted_terms))

        self.sort_by(Desc("@combined_score"), max=num_results)  # type: ignore
        self.dialect(dialect)
        if return_fields:
            self.load(*return_fields)  # type: ignore[arg-type]

    @property
    def params(self) -> Dict[str, Any]:
        """Return the parameters for the aggregation.

        Returns:
            Dict[str, Any]: ``{"vector_<i>": <vector bytes>}`` for each Vector,
            matching the ``$vector_<i>`` placeholders in the query string.
        """
        return {f"vector_{i}": v.vector for i, v in enumerate(self._vectors)}

    def _build_query_string(self) -> str:
        """Build the vector range query string, with an optional filter.

        Each Vector contributes one ``VECTOR_RANGE`` clause that yields its
        distance as ``distance_<i>``.
        """
        range_queries = [
            f"@{v.field_name}:[VECTOR_RANGE {v.max_distance} $vector_{i}]"
            f"=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
            for i, v in enumerate(self._vectors)
        ]

        # NOTE(review): Redis Search intersects clauses by whitespace and has
        # no AND operator. The literal "AND" token is presumably dropped as a
        # default stopword. Verify this on indexes created with STOPWORDS 0,
        # where it may be treated as a required term.
        range_query = " AND ".join(range_queries)

        filter_expression = self._filter_expression
        if isinstance(filter_expression, FilterExpression):
            filter_expression = str(filter_expression)

        if filter_expression:
            return f"({range_query}) AND ({filter_expression})"
        return range_query

    def __str__(self) -> str:
        """Return the string representation of the query."""
        return " ".join(str(x) for x in self.build_args())
