from langchain_openai import ChatOpenAI
from langchain_core.prompts.chat import PromptTemplate
from langchain_milvus import Milvus
from langchain_openai import OpenAIEmbeddings
import statistics
import concurrent.futures


def milvus_busqueda(collection_name: str, db_name: str, uri: str):
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-large",
        openai_api_key="sk-agrota-v7nqf1DdSh1EYdGJic5ST3BlbkFJr5HoyVHSYxf462AMV6hs",
    )

    conexion = Milvus(
        embedding_function=embeddings,
        collection_name=collection_name,
        connection_args={"db_name": db_name, "uri": uri, "token": "root:Milvus"},
    )
    return conexion


def procesar_colecciones(colecciones, db_name, uri, pregunta, k, max_workers=100):
    lista_colecciones = []

    def procesar_coleccion(coleccion):
        try:
            coleccion_docs = milvus_busqueda(
                coleccion, db_name, uri
            ).similarity_search_with_score(pregunta, k=k)
            resultados_filtrados = [doc for doc in coleccion_docs if doc[1] > 0.3]
            if resultados_filtrados:
                return resultados_filtrados
        except Exception as e:
            print(f"Error procesando la colección {coleccion}: {e}")
        return None

    # Configurar el número de hilos a usar
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futuros = {
            executor.submit(procesar_coleccion, coleccion): coleccion
            for coleccion in colecciones
        }
        for futuro in concurrent.futures.as_completed(futuros):
            resultado = futuro.result()
            if resultado:
                lista_colecciones.append(resultado)

    return lista_colecciones


def procesar_colecciones_codigo(
    colecciones, db_name, uri, pregunta, k, max_workers=100
):
    lista_colecciones = []

    def procesar_coleccion_codigo(coleccion):
        try:
            coleccion_docs = milvus_busqueda(
                coleccion, db_name, uri
            ).similarity_search_with_score(pregunta, k=k)
            resultados_filtrados = [doc for doc in coleccion_docs if doc[1] > 0.3]
            if resultados_filtrados:
                return resultados_filtrados
        except Exception as e:
            print(f"Error procesando la colección {coleccion}: {e}")
        return None

    # Ejecutar con un ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futuros = {
            executor.submit(procesar_coleccion_codigo, coleccion): coleccion
            for coleccion in colecciones
        }
        for futuro in concurrent.futures.as_completed(futuros):
            resultado = futuro.result()
            if resultado:
                lista_colecciones.append(resultado)
                print(f"Procesadas {len(lista_colecciones)} colecciones.")

    return lista_colecciones
