from pymilvus import (
    connections, 
    FieldSchema, 
    CollectionSchema, 
    DataType, 
    Collection, 
    utility
)

from dotenv import load_dotenv
import os
# Populate os.environ from a .env file (python-dotenv), then read the settings
# this module uses. Each constant is None when the variable is not set.
load_dotenv()
API_KEY_OPENAI = os.environ.get("API_KEY_OPENAI")
HOST_MILVUS = os.environ.get("HOST_MILVUS")
DB_NAME_MILVUS = os.environ.get("DB_NAME_MILVUS")

from concurrent.futures import ThreadPoolExecutor

def procesar_productos(data):
    """Split product mappings into the four parallel column lists for Milvus.

    For each product in *data* (in order):
      * code     -- the product's ``"_id"`` value, unchanged;
      * text     -- ``str()`` of a shallow copy of the product with the
                    ``"embeddings"`` key removed (a Python repr-style string,
                    not JSON);
      * document -- the constant ``"productos"``;
      * vector   -- the product's ``"embeddings"`` value, unchanged.

    The work is pure in-memory dict manipulation, so it runs sequentially:
    a thread pool adds startup/scheduling overhead without any parallel
    speed-up for CPU-bound Python code under the GIL.

    Args:
        data: iterable of product mappings, each with ``"_id"`` and
            ``"embeddings"`` keys. The input products are not modified.

    Returns:
        Tuple ``(codes, texts, documents, vectors)`` of equal-length lists in
        the same order as *data*.

    Raises:
        KeyError: if a product lacks ``"_id"`` or ``"embeddings"``.
    """
    codes = []
    texts = []
    documents = []
    vectors = []

    for producto in data:
        code = producto["_id"]

        # Copy (preserving the mapping type) and drop the embeddings so the
        # large vector is not serialized into the text column.
        producto_sin_embeddings = producto.copy()
        producto_sin_embeddings.pop("embeddings", None)

        codes.append(code)
        texts.append(str(producto_sin_embeddings))
        documents.append("productos")  # constant namespace value
        vectors.append(producto["embeddings"])

    return codes, texts, documents, vectors

def insertar_productos_en_milvus(data):
    """Recreate the ``PRODUCTOS_MARCIMEX`` Milvus collection and load *data* into it.

    Steps: connect to Milvus (``HOST_MILVUS``:19530, database ``DB_NAME_MILVUS``);
    drop the collection if it already exists; create it with a fresh schema
    and an IVF_FLAT/COSINE index on ``vector``; insert the rows produced by
    ``procesar_productos(data)``; load the collection; and always disconnect,
    even if a step fails.

    Args:
        data: iterable of product mappings accepted by ``procesar_productos``.
            Vectors are stored in a 3072-dim FLOAT_VECTOR field, so embeddings
            presumably must have length 3072 — confirm against the producer.
    """
    print("Procesando productos...")

    # Connect and disconnect on the same explicit alias so the connection
    # opened here is the one that gets closed.
    alias = "default"
    connections.connect(alias=alias, db_name=DB_NAME_MILVUS, host=HOST_MILVUS, port="19530")
    collection_name = "PRODUCTOS_MARCIMEX"

    try:
        # The collection is always rebuilt from scratch.
        if utility.has_collection(collection_name, using=alias):
            print("colección ya existe")
            Collection(collection_name, using=alias).drop()
            print("colección eliminada")

        # Schema: pk is auto-generated, so inserted columns are
        # code, text, namespace, vector (in that order).
        fields = [
            FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
            FieldSchema(name="code", dtype=DataType.VARCHAR, max_length=20),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=10000),
            FieldSchema(name="namespace", dtype=DataType.VARCHAR, max_length=10000),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=3072),
        ]
        schema = CollectionSchema(fields, description="productos")
        collection = Collection(name=collection_name, schema=schema, using=alias)
        print("colección creada")

        # A freshly created collection has no index; load() requires one.
        index_params = {
            "metric_type": "COSINE",
            "index_type": "IVF_FLAT",
            "params": {"nlist": 36},
        }
        collection.create_index(field_name="vector", index_params=index_params)
        print("Índice creado.")

        # Column-oriented insert; skip when there are no products so an
        # empty batch is never sent.
        codes, texts, documents, vectors = procesar_productos(data)
        if codes:
            collection.insert([codes, texts, documents, vectors])
        print("terminado")

        collection.load()
    finally:
        connections.disconnect(alias)

