from langchain_openai import ChatOpenAI
from langchain_core.prompts.chat import PromptTemplate
from langchain_community.vectorstores import Milvus
from langchain_openai import OpenAIEmbeddings
import statistics
import concurrent.futures
from dotenv import load_dotenv
import os
load_dotenv()
API_KEY_OPENAI=os.getenv("API_KEY_OPENAI")
HOST_MILVUS=os.getenv("HOST_MILVUS")
DB_NAME_MILVUS=os.getenv("DB_NAME_MILVUS")

def milvus_busqueda(collection_name: str):
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-large",
        openai_api_key=API_KEY_OPENAI,
    )
    return Milvus(
        embeddings,
        collection_name=collection_name,
        connection_args={
            "db_name": DB_NAME_MILVUS,
            "host": HOST_MILVUS,
            "port": "19530",
        }
        
    )


def vectores(pregunta: str):
    colecciones = ["PRODUCTOS_MARCIMEX"]
    
    for coleccion in colecciones:
        coleccion_docs = milvus_busqueda(coleccion).similarity_search_with_score(pregunta, k=5)
        puntajes = []  # Inicializamos puntajes antes del bucle
        for documento, puntaje in coleccion_docs:
            puntajes.append(puntaje)
        if puntajes:  # Verificamos que puntajes no esté vacío
            promedio = sum(puntajes) / len(puntajes)
            if promedio > 0.55:
                return coleccion_docs
    
    return milvus_busqueda("PRODUCTOS_MARCIMEX").similarity_search_with_score(pregunta, k=5)


def vectors_name(question: str,collection: str,amount:int):
    colecciones = [collection]
    for coleccion in colecciones:
        coleccion_docs = milvus_busqueda(coleccion).similarity_search_with_score(question, k=amount)
        puntajes = []  # Inicializamos puntajes antes del bucle
        for documento, puntaje in coleccion_docs:
            puntajes.append(puntaje)
        if puntajes:  # Verificamos que puntajes no esté vacío
            promedio = sum(puntajes) / len(puntajes)
            if promedio > 0.55:
                return coleccion_docs
    
    return milvus_busqueda(collection).similarity_search_with_score(question, k=amount)



def procesar_colecciones(colecciones, question: str, amount: int):
    lista_colecciones = []
    def procesar_coleccion(coleccion):
        coleccion_docs = milvus_busqueda(coleccion).similarity_search_with_score(
            question, k=amount
        )
        for doc in coleccion_docs:
            puntaje = doc[1]
            if puntaje > 0.40:
                return coleccion_docs
        return None

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futuros = {
            executor.submit(procesar_coleccion, coleccion): coleccion
            for coleccion in colecciones
        }
        for futuro in concurrent.futures.as_completed(futuros):
            resultado = futuro.result()
            if resultado:
                lista_colecciones.append(resultado)
    return lista_colecciones