# chagu-demo / falocon_api / embeddingGenerator.py
# Last commit: 73321dd by talexm — "update RAG query improvements"
# Original file size: 4.33 kB
import os
import sqlite3
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict
class EmbeddingGenerator:
    """Embed text files with a SentenceTransformer model and search them.

    Embeddings are persisted in a SQLite table ``embeddings`` (one row per
    file name) as raw float32 bytes, and a text query is ranked against the
    stored vectors by cosine similarity.

    The instance owns an open SQLite connection; call ``close()`` or use the
    instance as a context manager (``with EmbeddingGenerator() as gen:``) to
    release it.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", db_path: str = "embeddings.db"):
        """Load the embedding model and open (or create) the SQLite store.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
            db_path: Path of the SQLite database file.
        """
        self.model = SentenceTransformer(model_name)
        self.db_path = db_path
        self._initialize_db()
        print(f"Loaded embedding model: {model_name}")

    def _initialize_db(self):
        """Open the SQLite connection and create the ``embeddings`` table if missing."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                filename TEXT PRIMARY KEY,
                content TEXT,
                embedding BLOB
            )
        """)
        self.conn.commit()

    def close(self) -> None:
        """Close the SQLite connection. Calling it more than once is harmless."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
            self.conn = None
            self.cursor = None

    def __enter__(self) -> "EmbeddingGenerator":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def generate_embedding(self, text: str) -> np.ndarray:
        """Encode *text* with the loaded model.

        Returns:
            The model's embedding as a NumPy array, or an empty float32 array
            if encoding raised (the error is printed, not propagated).
        """
        try:
            return self.model.encode(text, convert_to_numpy=True)
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            return np.array([], dtype=np.float32)

    def ingest_files(self, directory: str):
        """Embed and store every ``.txt`` file directly inside *directory*.

        Subdirectories are not traversed. Files are decoded as UTF-8 so the
        result does not depend on the platform's default locale encoding.
        """
        for filename in os.listdir(directory):
            if not filename.endswith(".txt"):
                continue
            file_path = os.path.join(directory, filename)
            # A directory can also be named "*.txt"; only regular files are read.
            if not os.path.isfile(file_path):
                continue
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            embedding = self.generate_embedding(content)
            self._store_embedding(filename, content, embedding)

    def _store_embedding(self, filename: str, content: str, embedding: np.ndarray):
        """Insert or replace the row for *filename*.

        The vector is cast to float32 before serialization because
        ``load_embeddings`` decodes every blob as float32. An empty vector
        (the failure value of ``generate_embedding``) is not stored: it would
        otherwise replace a good row and break similarity search for all
        queries.
        """
        vector = np.asarray(embedding, dtype=np.float32)
        if vector.size == 0:
            print(f"Skipping {filename}: empty embedding")
            return
        try:
            self.cursor.execute(
                "INSERT OR REPLACE INTO embeddings (filename, content, embedding) VALUES (?, ?, ?)",
                (filename, content, vector.tobytes()),
            )
            self.conn.commit()
        except sqlite3.Error as e:
            print(f"Error storing embedding: {str(e)}")

    def load_embeddings(self) -> List[Dict]:
        """Return every stored row as a dict with keys filename/content/embedding.

        The ``embedding`` value is the BLOB decoded as a float32 array.
        """
        self.cursor.execute("SELECT filename, content, embedding FROM embeddings")
        return [
            {
                "filename": filename,
                "content": content,
                "embedding": np.frombuffer(embedding_blob, dtype=np.float32),
            }
            for filename, content, embedding_blob in self.cursor.fetchall()
        ]

    def compute_similarity(self, query_embedding: np.ndarray, document_embeddings: List[np.ndarray]) -> List[float]:
        """Cosine similarity of *query_embedding* against each document embedding.

        Returns:
            One similarity per document, in input order, or an empty list
            (after printing the error) if the computation raises, e.g. on
            mismatched vector lengths.
        """
        try:
            similarities = cosine_similarity([query_embedding], document_embeddings)[0]
            return similarities.tolist()
        except Exception as e:
            print(f"Error computing similarity: {str(e)}")
            return []

    def find_most_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to *top_k* stored documents most similar to *query*.

        Each result has ``filename``, ``content`` (first 100 characters) and
        ``similarity``, sorted by descending similarity. Stored vectors whose
        length differs from the query vector (e.g. rows written with another
        model) are ignored rather than making the whole search fail.
        """
        if top_k <= 0:
            return []
        query_embedding = self.generate_embedding(query)
        documents = []
        if query_embedding.size != 0:
            documents = [
                doc for doc in self.load_embeddings()
                if doc["embedding"].size == query_embedding.size
            ]
        if query_embedding.size == 0 or len(documents) == 0:
            print("Error: Invalid embeddings or no documents found.")
            return []
        document_embeddings = [doc["embedding"] for doc in documents]
        similarities = self.compute_similarity(query_embedding, document_embeddings)
        ranked_results = sorted(
            [{"filename": doc["filename"], "content": doc["content"][:100], "similarity": sim}
             for doc, sim in zip(documents, similarities)],
            key=lambda x: x["similarity"],
            reverse=True,
        )
        return ranked_results[:top_k]
# Example usage: index a local dataset, then run one semantic search over it.
if __name__ == "__main__":
    # Embed every top-level .txt file in the aclImdb training directory.
    generator = EmbeddingGenerator()
    dataset_dir = os.path.expanduser("~/data-sets/aclImdb/train/")
    generator.ingest_files(dataset_dir)

    # Query the index and print the three closest matches.
    search_text = "What can be used for document search?"
    matches = generator.find_most_similar(search_text, top_k=3)

    print("Search Results:")
    for match in matches:
        name, score, snippet = match["filename"], match["similarity"], match["content"]
        print(f"Filename: {name}, Similarity: {score:.4f}")
        print(f"Snippet: {snippet}\n")