# Embedding generation and SQLite-backed semantic document search.
import os
import sqlite3
from typing import Dict, List

import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
class EmbeddingGenerator:
    """Generate sentence embeddings and persist them in SQLite for similarity search.

    Embeddings are stored as raw float32 bytes keyed by filename; queries are
    ranked by cosine similarity against every stored document.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", db_path: str = "embeddings.db"):
        """Load the sentence-transformer model and open/create the embeddings DB.

        Args:
            model_name: Hugging Face sentence-transformers model identifier.
            db_path: Path of the SQLite database file (created if missing).
        """
        self.model = SentenceTransformer(model_name)
        self.db_path = db_path
        self._initialize_db()
        print(f"Loaded embedding model: {model_name}")

    def _initialize_db(self):
        """Connect to SQLite and create the embeddings table if it does not exist."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                filename TEXT PRIMARY KEY,
                content TEXT,
                embedding BLOB
            )
        """)
        self.conn.commit()

    def close(self):
        """Close the underlying SQLite connection (safe to call more than once)."""
        if getattr(self, "conn", None) is not None:
            self.conn.close()
            self.conn = None

    def generate_embedding(self, text: str) -> np.ndarray:
        """Encode *text* into a 1-D embedding vector.

        Returns an empty array on failure so callers can detect and skip it.
        """
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            return np.array([])

    def ingest_files(self, directory: str):
        """Embed every ``.txt`` file directly inside *directory* and store the results.

        Non-``.txt`` files are ignored; files whose embedding fails are skipped
        rather than stored as empty blobs.
        """
        for filename in os.listdir(directory):
            if not filename.endswith(".txt"):
                continue
            file_path = os.path.join(directory, filename)
            # Explicit encoding; errors="replace" so one undecodable byte
            # doesn't abort the whole ingestion run.
            with open(file_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
            embedding = self.generate_embedding(content)
            if embedding.size == 0:
                # generate_embedding already printed the cause; don't persist junk.
                print(f"Skipping {filename}: embedding generation failed.")
                continue
            self._store_embedding(filename, content, embedding)

    def _store_embedding(self, filename: str, content: str, embedding: np.ndarray):
        """Upsert one (filename, content, embedding) row.

        The vector is cast to float32 before serialization so it round-trips
        exactly with the ``np.frombuffer(..., dtype=np.float32)`` in
        ``load_embeddings`` regardless of the model's native dtype.
        """
        try:
            blob = np.asarray(embedding, dtype=np.float32).tobytes()
            self.cursor.execute(
                "INSERT OR REPLACE INTO embeddings (filename, content, embedding) VALUES (?, ?, ?)",
                (filename, content, blob),
            )
            self.conn.commit()
        except sqlite3.Error as e:
            print(f"Error storing embedding: {str(e)}")

    def load_embeddings(self) -> List[Dict]:
        """Return all stored documents as dicts with filename, content, embedding."""
        self.cursor.execute("SELECT filename, content, embedding FROM embeddings")
        rows = self.cursor.fetchall()
        documents = []
        for filename, content, embedding_blob in rows:
            # Blobs are written as float32 by _store_embedding.
            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
            documents.append({"filename": filename, "content": content, "embedding": embedding})
        return documents

    def compute_similarity(self, query_embedding: np.ndarray, document_embeddings: List[np.ndarray]) -> List[float]:
        """Return cosine similarity of *query_embedding* against each document vector.

        Returns an empty list on failure (e.g. mismatched vector lengths).
        """
        try:
            similarities = cosine_similarity([query_embedding], document_embeddings)[0]
            return similarities.tolist()
        except Exception as e:
            print(f"Error computing similarity: {str(e)}")
            return []

    def find_most_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Embed *query* and return the *top_k* most similar stored documents.

        Each result dict has: filename, a 100-char content snippet, similarity.
        Returns [] when the query embedding fails or no documents are stored.
        """
        query_embedding = self.generate_embedding(query)
        documents = self.load_embeddings()
        if query_embedding.size == 0 or len(documents) == 0:
            print("Error: Invalid embeddings or no documents found.")
            return []
        document_embeddings = [doc["embedding"] for doc in documents]
        similarities = self.compute_similarity(query_embedding, document_embeddings)
        ranked_results = sorted(
            [{"filename": doc["filename"], "content": doc["content"][:100], "similarity": sim}
             for doc, sim in zip(documents, similarities)],
            key=lambda x: x["similarity"],
            reverse=True
        )
        return ranked_results[:top_k]
# Example usage: ingest a directory of .txt files, then run a similarity query.
if __name__ == "__main__":
    # Build the index from the training split (only top-level .txt files are read).
    embedding_generator = EmbeddingGenerator()
    embedding_generator.ingest_files(os.path.expanduser("~/data-sets/aclImdb/train/"))

    # Perform a search query.
    query = "What can be used for document search?"
    results = embedding_generator.find_most_similar(query, top_k=3)

    print("Search Results:")
    for result in results:
        print(f"Filename: {result['filename']}, Similarity: {result['similarity']:.4f}")
        print(f"Snippet: {result['content']}\n")