Spaces:
Sleeping
Sleeping
# embedding.py | |
import logging | |
import pandas as pd | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from chroma_setup import initialize_client | |
import uuid | |
# Creamos una instancia del modelo local de sentence-transformers | |
# (se descargará y cacheará la primera vez que se ejecute) | |
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') | |
def embed_text_chunks(pages_and_chunks: list[dict]) -> pd.DataFrame: | |
""" | |
Genera embeddings para cada chunk de texto usando un modelo local | |
de sentence-transformers. | |
""" | |
for item in pages_and_chunks: | |
text_chunk = item["sentence_chunk"] | |
try: | |
# encode() acepta una lista de strings y retorna una lista de embeddings (ndarray). | |
embedding_array = model.encode([text_chunk]) | |
# Devuelve una matriz shape (1, 384) si es all-MiniLM-L6-v2, así que tomamos el [0] | |
embedding = embedding_array[0].tolist() | |
# embedding ahora es una lista de floats | |
item["embedding"] = embedding | |
except Exception as e: | |
logging.error(f"Fallo al generar embedding para: {text_chunk}. Error: {e}") | |
item["embedding"] = None | |
return pd.DataFrame(pages_and_chunks) | |
def save_to_chroma_db(embeddings_df: pd.DataFrame, user_id: str, document_id: str): | |
""" | |
Guarda en ChromaDB los embeddings generados. | |
""" | |
client = initialize_client() | |
# Creas o recuperas la colección. Asegúrate de usar el mismo nombre | |
# que luego usarás en tus queries. | |
collection = client.get_or_create_collection(name=f"text_embeddings_{user_id}") | |
combined_key = f"{user_id}_{document_id}" | |
ids = [f"{combined_key}_{i}" for i in range(len(embeddings_df))] | |
documents = embeddings_df["sentence_chunk"].tolist() | |
embeddings = embeddings_df["embedding"].tolist() | |
# Verificamos que ninguno sea None | |
for idx, emb in enumerate(embeddings): | |
if emb is None: | |
raise ValueError( | |
f"El chunk con ID {ids[idx]} no tiene embedding válido (None)." | |
) | |
# ¡Ahora todos deben ser listas de floats! | |
# Podemos añadirlos a la colección: | |
collection.add( | |
documents=documents, | |
embeddings=embeddings, | |
ids=ids, | |
metadatas=[{"combined_key": combined_key} for _ in range(len(embeddings_df))] | |
) | |
def generate_document_id() -> str: | |
return str(uuid.uuid4()) | |
def query_chroma_db(user_id: str, document_id: str, query: str): | |
client = initialize_client() | |
collection = client.get_collection(name=f"text_embeddings_{user_id}") | |
combined_key = f"{user_id}_{document_id}" | |
results = collection.query( | |
query_texts=[query], | |
n_results=5, | |
where={"combined_key": combined_key}, | |
) | |
documents = results.get("documents", []) | |
if not documents: | |
return "No se encontraron documentos" | |
# Aplanar la lista de documentos | |
relevant_docs = [doc for sublist in documents for doc in sublist] | |
return "\n\n".join(relevant_docs) | |