Spaces:
Sleeping
Sleeping
File size: 8,039 Bytes
2197ab7 5b40ec9 2197ab7 5fe3652 2197ab7 5fe3652 5b40ec9 2197ab7 5b40ec9 2197ab7 5b40ec9 2197ab7 5b40ec9 2197ab7 5fe3652 2197ab7 5b40ec9 5fe3652 5b40ec9 2197ab7 5fe3652 2197ab7 5b40ec9 5fe3652 5b40ec9 2197ab7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from typing import Optional
from qdrant_client import QdrantClient
from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, PointIdsList
from fastembed import TextEmbedding, SparseTextEmbedding
import logging
import uuid
from .output_files_generator import generate_yaml_file, generate_markdown_files
from .config import config
from .exceptions import ConfigurationError
from .database import validate_point_payload, get_dense_vector_name, get_sparse_vector_name
# Module-level logger shared by all functions in this module; handlers/level are
# configured by the application, not here.
logger = logging.getLogger('fabric_to_espanso')
# TODO: Make a summary of the prompts using a call to an LLM for every prompt and store that in the purpose field
# of the database instead of the extracted purpose from the markdown files and use that summary to create the embeddings
def get_embedding(text: str) -> tuple:
    """
    Generate dense and sparse embedding vectors for the given text using FastEmbed.

    The embedding models are created lazily on first call and cached as
    function attributes, so repeated calls reuse the already-loaded models.

    Args:
        text (str): Text to generate embeddings for.

    Returns:
        tuple: ``(dense_embedding, sparse_embedding)`` where ``dense_embedding``
            is the dense vector produced by the dense model and
            ``sparse_embedding`` is a dict with ``'indices'`` and ``'values'``
            lists suitable for a Qdrant sparse vector.

    Raises:
        ConfigurationError: If ``use_fastembed`` is disabled in the configuration.
    """
    if not config.embedding.use_fastembed:
        msg = "Embedding model not initialized. Set use_fastembed to True in the configuration."
        logger.error(msg)
        raise ConfigurationError(msg)
    # Models are lazily initialized only when needed and cached on the function
    # object so subsequent calls skip the (expensive) model load.
    if not hasattr(get_embedding, '_dense_model'):
        get_embedding._dense_model = TextEmbedding(model_name=config.embedding.dense_model_name)
    if not hasattr(get_embedding, '_sparse_model'):
        get_embedding._sparse_model = SparseTextEmbedding(model_name=config.embedding.sparse_model_name)
    # embed() yields one embedding per input document; take the first without
    # materializing the whole generator into a list.
    dense_embedding = next(iter(get_embedding._dense_model.embed(text)))
    # NOTE: the previous `return_dense=False` kwarg was dropped — it is not a
    # SparseTextEmbedding.embed parameter and was silently ignored via **kwargs.
    sparse_embedding = next(iter(get_embedding._sparse_model.embed(text)))
    return dense_embedding, {
        'indices': sparse_embedding.indices.tolist(),
        'values': sparse_embedding.values.tolist(),
    }
def _find_point_id(client: QdrantClient, collection_name: str, filename: str) -> Optional[str]:
    """Return the id of the first stored point whose payload filename matches, or None.

    Args:
        client (QdrantClient): An initialized Qdrant client.
        collection_name (str): Name of the collection to search.
        filename (str): Payload ``filename`` value to match exactly.
    """
    points, _ = client.scroll(
        collection_name=collection_name,
        scroll_filter=Filter(
            must=[FieldCondition(key="filename", match=MatchValue(value=filename))]
        ),
        limit=1
    )
    # TODO: Add handling of cases of multiple entries with the same filename
    return points[0].id if points else None


def update_qdrant_database(client: QdrantClient, collection_name: str, new_files: list, modified_files: list, deleted_files: list):
    """
    Update the Qdrant database based on detected file changes.

    After the database is updated, the espanso YAML file and the Obsidian
    markdown files are regenerated from the collection contents.

    Args:
        client (QdrantClient): An initialized Qdrant client.
        collection_name (str): Name of the collection to update.
        new_files (list): List of new files to be added to the database.
        modified_files (list): List of modified files to be updated in the database.
        deleted_files (list): List of deleted filenames to be removed from the database.

    Raises:
        Exception: Re-raises any unexpected error after logging it.
    """
    if not config.embedding.use_fastembed:
        msg = "Embedding model not initialized. Set use_fastembed to True in the configuration."
        logger.info(msg)
        return
    try:
        # Vector names depend only on the collection configuration, so look
        # them up once instead of once per file.
        dense_vector_name = get_dense_vector_name(client, collection_name)
        sparse_vector_name = get_sparse_vector_name(client, collection_name)

        # Add new files
        for file in new_files:
            try:
                payload_new = validate_point_payload(file)
                # Embed once and reuse both results (previously the models ran twice).
                dense_vec, sparse_vec = get_embedding(payload_new['purpose'])
                point = PointStruct(
                    id=str(uuid.uuid4()),  # Generate a new UUID for each point
                    vector={
                        dense_vector_name: dense_vec,
                        sparse_vector_name: sparse_vec
                    },
                    payload={
                        "filename": payload_new['filename'],
                        "content": payload_new['content'],
                        "purpose": payload_new['purpose'],
                        "date": payload_new['last_modified'],
                        "filesize": payload_new['filesize'],
                        "trigger": payload_new['trigger'],
                    }
                )
                client.upsert(collection_name=collection_name, points=[point])
                logger.info(f"Added new file to database: {file['filename']}")
            except ConfigurationError as e:
                logger.error(f"Skipping new file: {str(e)}")

        # Update modified files
        for file in modified_files:
            try:
                point_id = _find_point_id(client, collection_name, file['filename'])
                if point_id is not None:
                    payload_current = validate_point_payload(file, point_id)
                    # Embed once and reuse both results.
                    dense_vec, sparse_vec = get_embedding(payload_current['purpose'])
                    point = PointStruct(
                        id=point_id,
                        vector={
                            dense_vector_name: dense_vec,
                            sparse_vector_name: sparse_vec
                        },
                        payload={
                            "filename": payload_current['filename'],
                            "content": file['content'],
                            "purpose": file['purpose'],
                            "date": file['last_modified'],
                            "filesize": file['filesize'],
                            # Keep the trigger already stored in the database.
                            "trigger": payload_current['trigger'],
                        }
                    )
                    client.upsert(collection_name=collection_name, points=[point])
                    logger.info(f"Updated modified file in database: {payload_current['filename']}")
                else:
                    logger.warning(f"File not found in database for update: {file['filename']}")
            except ConfigurationError as e:
                logger.error(f"Skipping modified file: {str(e)}")

        # Delete removed files
        for filename in deleted_files:
            point_id = _find_point_id(client, collection_name, filename)
            if point_id is not None:
                client.delete(
                    collection_name=collection_name,
                    points_selector=PointIdsList(points=[point_id])
                )
                logger.info(f"Deleted file from database: {filename}")
            else:
                logger.warning(f"File not found in database for deletion: {filename}")

        logger.info("Database update completed successfully")
        # Generate new YAML file for use with espanso after database update
        print("Generating YAML file...")
        generate_yaml_file(client, config.embedding.collection_name, config.yaml_output_folder)
        # Generate markdown files for use with obsidian after database update
        print("Generating markdown files...")
        generate_markdown_files(client, config.embedding.collection_name, config.obsidian_output_folder)
    except Exception as e:
        logger.error(f"Error updating Qdrant database: {str(e)}", exc_info=True)
        raise