Spaces:
Sleeping
Sleeping
from typing import Optional | |
from qdrant_client import QdrantClient | |
from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, PointIdsList | |
from fastembed import TextEmbedding, SparseTextEmbedding | |
import logging | |
import uuid | |
from .output_files_generator import generate_yaml_file, generate_markdown_files | |
from .config import config | |
from .exceptions import ConfigurationError | |
from .database import validate_point_payload, get_dense_vector_name, get_sparse_vector_name | |
logger = logging.getLogger('fabric_to_espanso') | |
# TODO: Summarize each prompt with an LLM call and store that summary in the purpose field of the
# database (instead of the purpose extracted from the markdown files), then build the embeddings from that summary.
def get_embedding(text: str) -> tuple:
    """
    Generate the dense and sparse embeddings for the given text using FastEmbed.

    The dense and sparse models are created on the first call and cached as
    attributes on this function, so later calls reuse the same instances.

    Args:
        text (str): Text to generate embeddings for.

    Returns:
        tuple: ``(dense_embedding, sparse_embedding)``. ``dense_embedding`` is
            the first item produced by the dense model for ``text``;
            ``sparse_embedding`` is a dict with the keys ``'indices'`` and
            ``'values'``, each holding a plain list.

    Raises:
        ConfigurationError: If ``config.embedding.use_fastembed`` is falsy.
    """
    if not config.embedding.use_fastembed:
        msg = "Embedding model not initialized. Set use_fastembed to True in the configuration."
        logger.error(msg)
        raise ConfigurationError(msg)

    # Models are lazily initialized only when needed
    if not hasattr(get_embedding, '_dense_model'):
        get_embedding._dense_model = TextEmbedding(model_name=config.embedding.dense_model_name)
    if not hasattr(get_embedding, '_sparse_model'):
        get_embedding._sparse_model = SparseTextEmbedding(model_name=config.embedding.sparse_model_name)

    # embed() yields one result per input document; a single text yields one item.
    dense_embeddings = list(get_embedding._dense_model.embed(text))[0]
    sparse_embedding = list(get_embedding._sparse_model.embed(text, return_dense=False))[0]

    # Convert the sparse arrays to plain lists so the result is a plain dict.
    return dense_embeddings, {
        'indices': sparse_embedding.indices.tolist(),
        'values': sparse_embedding.values.tolist()
    }
def _find_point_id_by_filename(client, collection_name, filename):
    """Return the id of the first point whose ``filename`` payload matches, or None."""
    # TODO: Add handling of cases of multiple entries with the same filename
    records = client.scroll(
        collection_name=collection_name,
        scroll_filter=Filter(
            must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
        ),
        limit=1
    )[0]  # scroll() returns a pair; the first element holds the matching records
    return records[0].id if records else None


def _build_named_vectors(client, collection_name, text):
    """Embed ``text`` once and key both embeddings by the collection's vector names."""
    # Get vector names from the collection configuration
    dense_vector_name = get_dense_vector_name(client, collection_name)
    sparse_vector_name = get_sparse_vector_name(client, collection_name)
    # A single call yields both embeddings; calling get_embedding() once per
    # vector would run both models twice on the same text.
    dense_embedding, sparse_embedding = get_embedding(text)
    return {
        dense_vector_name: dense_embedding,
        sparse_vector_name: sparse_embedding,
    }


def update_qdrant_database(client: QdrantClient, collection_name: str, new_files: list, modified_files: list, deleted_files: list):
    """
    Update the Qdrant database based on detected file changes.

    New files are inserted under a fresh UUID, modified files overwrite the
    existing point with the same filename, and deleted files have their point
    removed. Afterwards the YAML and markdown output files are regenerated
    from the same collection.

    If ``config.embedding.use_fastembed`` is falsy, a message is logged and the
    function returns without touching the database or generating any output.

    Args:
        client (QdrantClient): An initialized Qdrant client.
        collection_name (str): Name of the collection to update.
        new_files (list): List of new files to be added to the database.
        modified_files (list): List of modified files to be updated in the database.
        deleted_files (list): List of filenames to be removed from the database.

    Raises:
        Exception: Any error not handled per file is logged with its traceback
            and re-raised. A ``ConfigurationError`` raised while adding or
            updating a single file is logged and that file is skipped.
    """
    if not config.embedding.use_fastembed:
        msg = "Embedding model not initialized. Set use_fastembed to True in the configuration."
        logger.info(msg)
        return

    try:
        # Add new files
        for file in new_files:
            try:
                payload_new = validate_point_payload(file)
                point = PointStruct(
                    id=str(uuid.uuid4()),  # Generate a new UUID for each point
                    vector=_build_named_vectors(client, collection_name, payload_new['purpose']),
                    payload={
                        "filename": payload_new['filename'],
                        "content": payload_new['content'],
                        "purpose": payload_new['purpose'],
                        "date": payload_new['last_modified'],
                        "filesize": payload_new['filesize'],
                        "trigger": payload_new['trigger'],
                    }
                )
                client.upsert(collection_name=collection_name, points=[point])  # Update the database with the new file
                logger.info(f"Added new file to database: {file['filename']}")
            except ConfigurationError as e:
                logger.error(f"Skipping new file: {str(e)}")

        # Update modified files
        for file in modified_files:
            try:
                point_id = _find_point_id_by_filename(client, collection_name, file['filename'])
                if point_id is None:
                    logger.warning(f"File not found in database for update: {file['filename']}")
                    continue

                payload_current = validate_point_payload(file, point_id)
                # Reusing the existing id makes the upsert overwrite that point.
                # NOTE(review): the embedding is built from the validated purpose while
                # content/purpose/date/filesize are stored from the raw file dict —
                # confirm this asymmetry is intentional.
                point = PointStruct(
                    id=point_id,
                    vector=_build_named_vectors(client, collection_name, payload_current['purpose']),
                    payload={
                        "filename": payload_current['filename'],
                        "content": file['content'],
                        "purpose": file['purpose'],
                        "date": file['last_modified'],
                        "filesize": file['filesize'],
                        "trigger": payload_current['trigger'],
                    }
                )
                client.upsert(collection_name=collection_name, points=[point])
                logger.info(f"Updated modified file in database: {payload_current['filename']}")
            except ConfigurationError as e:
                logger.error(f"Skipping modified file: {str(e)}")

        # Delete removed files
        for filename in deleted_files:
            point_id = _find_point_id_by_filename(client, collection_name, filename)
            if point_id is None:
                logger.warning(f"File not found in database for deletion: {filename}")
                continue
            client.delete(
                collection_name=collection_name,
                points_selector=PointIdsList(points=[point_id])
            )
            logger.info(f"Deleted file from database: {filename}")

        logger.info("Database update completed successfully")

        # Regenerate the output from the collection that was just updated, so the
        # generated files always reflect the `collection_name` argument.
        # Generate new YAML file for use with espanso after database update
        print("Generating YAML file...")
        generate_yaml_file(client, collection_name, config.yaml_output_folder)
        # Generate markdown files for use with obsidian after database update
        print("Generating markdown files...")
        generate_markdown_files(client, collection_name, config.obsidian_output_folder)
    except Exception as e:
        logger.error(f"Error updating Qdrant database: {str(e)}", exc_info=True)
        raise