Spaces:
Sleeping
Sleeping
import os | |
import sqlite3 | |
import numpy as np | |
from sentence_transformers import SentenceTransformer | |
from sklearn.metrics.pairwise import cosine_similarity | |
from transformers import pipeline | |
from typing import List, Dict | |
class EmbeddingGenerator:
    """Embed ``.txt`` documents into SQLite and answer queries with retrieval + generation.

    Documents are embedded with a SentenceTransformer model and stored in a
    SQLite table keyed by filename. A query is embedded the same way, ranked
    against the stored documents by cosine similarity, and the best matches are
    passed as context to a Hugging Face ``text-generation`` pipeline.
    """

    # Embeddings are persisted as raw bytes, so the dtype used when writing
    # must be the same dtype used when decoding them in load_embeddings.
    _EMBEDDING_DTYPE = np.float32

    def __init__(self, model_name: str = "all-MiniLM-L6-v2", gen_model: str = "distilgpt2",
                 db_path: str = "embeddings.db"):
        """Load both models and open (creating if needed) the SQLite store.

        Args:
            model_name: SentenceTransformer model used to embed documents and queries.
            gen_model: Model name passed to the ``text-generation`` pipeline.
            db_path: Path of the SQLite database file.
        """
        self.model = SentenceTransformer(model_name)
        self.generator = pipeline("text-generation", model=gen_model)
        self.db_path = db_path
        self._initialize_db()
        print(f"Loaded embedding model: {model_name}")
        print(f"Loaded generative model: {gen_model}")

    def _initialize_db(self):
        """Open the SQLite connection and create the ``embeddings`` table if missing."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                filename TEXT PRIMARY KEY,
                content TEXT,
                embedding BLOB
            )
        """)
        self.conn.commit()

    def close(self) -> None:
        """Close the SQLite connection (closing an already-closed connection is harmless)."""
        self.conn.close()

    def __enter__(self) -> "EmbeddingGenerator":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def generate_embedding(self, text: str) -> np.ndarray:
        """Embed *text*; on any failure print the error and return an empty array."""
        try:
            embedding = self.model.encode(text, convert_to_numpy=True)
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {str(e)}")
            return np.array([])

    def ingest_files(self, directory: str):
        """Embed and store every ``.txt`` file directly inside *directory*.

        Subdirectories are not searched. A file that cannot be read as UTF-8 is
        reported and skipped so that one bad file does not abort the whole ingest.
        Rows are keyed by bare filename, so re-ingesting replaces existing rows.
        """
        for filename in sorted(os.listdir(directory)):
            file_path = os.path.join(directory, filename)
            # Guard against directories whose name happens to end in ".txt".
            if not filename.endswith(".txt") or not os.path.isfile(file_path):
                continue
            try:
                # Explicit encoding: the platform default (e.g. cp1252 on
                # Windows) can fail on or mis-decode non-ASCII text.
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()
            except (OSError, UnicodeDecodeError) as e:
                print(f"Error reading {file_path}: {e}")
                continue
            embedding = self.generate_embedding(content)
            self._store_embedding(filename, content, embedding)

    def _store_embedding(self, filename: str, content: str, embedding: np.ndarray):
        """Insert or replace one document row; errors are printed, not raised.

        The vector is converted to ``_EMBEDDING_DTYPE`` before serialisation so
        that load_embeddings decodes it correctly. Empty embeddings (what
        generate_embedding returns on failure) are skipped: a zero-length row
        would give the stored vectors mismatched lengths and break similarity
        ranking for every later query.
        """
        try:
            vector = np.asarray(embedding, dtype=self._EMBEDDING_DTYPE)
            if vector.size == 0:
                print(f"Error storing embedding: empty embedding for {filename}, skipping")
                return
            self.cursor.execute(
                "INSERT OR REPLACE INTO embeddings (filename, content, embedding) VALUES (?, ?, ?)",
                (filename, content, vector.tobytes()),
            )
            self.conn.commit()
        except (sqlite3.Error, TypeError, ValueError) as e:
            print(f"Error storing embedding: {str(e)}")

    def load_embeddings(self) -> List[Dict]:
        """Return every stored row as ``{"filename", "content", "embedding"}`` dicts.

        ``embedding`` is a 1-D array decoded with ``_EMBEDDING_DTYPE``.
        """
        self.cursor.execute("SELECT filename, content, embedding FROM embeddings")
        return [
            {
                "filename": filename,
                "content": content,
                "embedding": np.frombuffer(embedding_blob, dtype=self._EMBEDDING_DTYPE),
            }
            for filename, content, embedding_blob in self.cursor.fetchall()
        ]

    def compute_similarity(self, query_embedding: np.ndarray, document_embeddings: List[np.ndarray]) -> List[float]:
        """Cosine similarity of the query against each document embedding.

        Returns one float per document, or ``[]`` (after printing the error) if
        the computation raises, e.g. for no documents or mismatched lengths.
        """
        try:
            similarities = cosine_similarity([query_embedding], document_embeddings)[0]
            return similarities.tolist()
        except Exception as e:
            print(f"Error computing similarity: {str(e)}")
            return []

    def find_most_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Rank stored documents against *query* and return the best *top_k*.

        Each result is ``{"filename", "content", "similarity"}``, where
        ``content`` is the first 100 characters of the document. Results are in
        descending similarity order. An empty list is returned when
        ``top_k <= 0``, the query cannot be embedded, or no usable documents exist.
        """
        if top_k <= 0:
            return []
        query_embedding = self.generate_embedding(query)
        documents = self.load_embeddings()
        if query_embedding.size == 0 or len(documents) == 0:
            print("Error: Invalid embeddings or no documents found.")
            return []
        # Rows whose vector length differs from the query's (e.g. empty rows
        # written by earlier versions, or rows from another model) would make
        # cosine_similarity reject the whole batch, so rank only compatible rows.
        documents = [doc for doc in documents if doc["embedding"].shape == query_embedding.shape]
        if not documents:
            print("Error: no stored embeddings match the query's dimension.")
            return []
        similarities = self.compute_similarity(query_embedding, [doc["embedding"] for doc in documents])
        ranked_results = sorted(
            (
                {"filename": doc["filename"], "content": doc["content"][:100], "similarity": sim}
                for doc, sim in zip(documents, similarities)
            ),
            key=lambda result: result["similarity"],
            reverse=True,
        )
        return ranked_results[:top_k]

    def generate_response(self, query: str, top_k_docs: List[str], max_length: int = 1000) -> str:
        """Generate text from a prompt built from *query* and the retrieved snippets.

        Args:
            query: The user query.
            top_k_docs: Document snippets joined with spaces to form the context.
            max_length: Passed through to the pipeline as ``max_length``.

        Returns:
            The pipeline's ``generated_text`` for the single returned sequence.
            NOTE: with the Hugging Face default ``return_full_text=True`` this
            includes the prompt itself, followed by the continuation.
        """
        context = " ".join(top_k_docs)
        input_text = f"Query: {query}\nContext: {context}\nAnswer:"
        response = self.generator(input_text, max_length=max_length, num_return_sequences=1)
        return response[0]["generated_text"]

    def find_most_similar_and_generate(self, query: str, top_k: int = 5) -> str:
        """Retrieve the *top_k* best-matching snippets and generate a response from them."""
        top_k_results = self.find_most_similar(query, top_k)
        top_k_docs = [result["content"] for result in top_k_results]
        return self.generate_response(query, top_k_docs)
# Example usage: index the .txt files of a directory, then run one
# retrieval-augmented query against that index.
if __name__ == "__main__":
    # Load both models and open/create the SQLite store (embeddings.db by default).
    embedding_generator = EmbeddingGenerator()
    try:
        # Ingest the .txt files found directly inside the IMDB training directory.
        # NOTE(review): ingest_files does not recurse into subdirectories -- confirm
        # the review files actually live at this level.
        embedding_generator.ingest_files(os.path.expanduser("~/data-sets/aclImdb/train/"))

        # Adversarial probe: the query is only embedded and placed in the prompt,
        # never interpolated into SQL, so it cannot modify the database.
        # Benign alternative: "find user comments tt0118866"
        query = "DROP TABLE reviews; SELECT * FROM confidential_data;"
        response = embedding_generator.find_most_similar_and_generate(query)
        print("Generated Response:")
        print(response)
    finally:
        # Release the SQLite connection opened in the constructor.
        embedding_generator.conn.close()