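"""Offline vector-store layer for a RAG pipeline.

Wraps a persistent ChromaDB collection and generates embeddings locally with
sentence-transformers, so no network access is needed at index or query time.
"""
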
from __future__ import annotations

import hashlib
import json
import os
from pathlib import Path
from typing import List

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from src.utils.logger import get_logger
from config.settings import settings

logger = get_logger(__name__)


class ChromaVectorDBManager:
    """Corporate-friendly ChromaDB manager - completely offline."""

    def __init__(self, model_name: str | None = None, db_path: str | None = None):
        self.model_name = model_name or getattr(
            settings, 'EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2'
        )
        self.embedding_model = SentenceTransformer(self.model_name)
        self.db_path = db_path or getattr(settings, 'CHROMADB_PATH', './chroma_db')
        os.makedirs(self.db_path, exist_ok=True)
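        # Persistent on-disk client; telemetry is disabled so nothing leaves
        # the machine, matching the offline constraint in the class docstring.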
        self.client = chromadb.PersistentClient(
            path=self.db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
                is_persistent=True
            )
        )
        self.collection_name = getattr(settings, 'COLLECTION_NAME', 'rag_chunks')
        self.collection = self._get_collection()
        logger.info(f"ChromaDB initialized at: {self.db_path}")

    def _get_collection(self):
        """Get or create the collection without an embedding function."""
        try:
            return self.client.get_collection(name=self.collection_name)
        except Exception:
            # Collection missing or unreadable: clear any stale remnant, then recreate.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception:
                pass
            return self.client.create_collection(
                name=self.collection_name,
                metadata={"description": "RAG chunks"}
            )

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings locally with sentence-transformers."""
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=len(texts) > 100,  # progress bar only for large batches
            convert_to_tensor=False  # return a numpy array rather than a torch tensor
        )
        return embeddings.tolist()

    def add_chunks_to_db(self, chunks: list, source_file: str) -> bool:
        """Add chunks (list of dicts) to ChromaDB with manual embedding generation."""
        if not chunks:
            return True
        texts, ids, metadatas = [], [], []
        seen_hashes = set()
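        # Deduplicate within the batch: chunks with identical text share an
        # MD5 content hash, so only the first occurrence is kept.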
        for chunk in chunks:
            text = chunk.get("text", "").strip()
            if not text:
                continue
            text_hash = hashlib.md5(text.encode()).hexdigest()
            if text_hash in seen_hashes:
                continue
            seen_hashes.add(text_hash)
            chunk_id = f"{source_file}_{chunk.get('chunk_id', 0)}_{text_hash[:8]}"
            try:
                # Skip chunks whose ID is already stored in the collection.
                if self.collection.get(ids=[chunk_id])['ids']:
                    continue
            except Exception:
                pass
            texts.append(text)
            ids.append(chunk_id)
            metadatas.append({
                "source_file": source_file,
                "chunk_index": chunk.get("chunk_id", 0),
                "char_length": len(text),
                "text_hash": text_hash
            })
        if not texts:
            return True
        embeddings = self.generate_embeddings(texts)
        self.collection.add(
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
            ids=ids
        )
        logger.info(f"Added {len(texts)} chunks from {source_file} to ChromaDB")
        return True

    def search_for_rag(
        self,
        query: str,
        n_results: int = 5,
        use_truncated: bool = False,
        filter_128_context: bool = False
    ) -> List[dict]:
        """Search with a locally generated query embedding - completely offline.

        Note: use_truncated and filter_128_context are accepted for API
        compatibility but are not currently applied.
        """
        query_embedding = self.generate_embeddings([query])[0]
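        # Over-fetch candidates (up to 2x n_results, capped at 20); the loop
        # below trims the list back down to n_results.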
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(n_results * 2, 20),
            include=["documents", "metadatas", "distances"]
        )
        search_results = []
        for i, (doc, metadata, distance) in enumerate(zip(
            results['documents'][0], results['metadatas'][0], results['distances'][0]
        )):
            if len(search_results) >= n_results:
                break
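            # Chroma returns a distance (lower = closer); map it to a bounded
            # similarity in (0, 1] where identical vectors score 1.0.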
            similarity_score = 1 / (1 + distance)
            search_results.append({
                "id": results['ids'][0][i],
                "score": similarity_score,
                "distance": distance,
                "text": doc,
                "source_file": metadata["source_file"],
                "chunk_index": metadata["chunk_index"]
            })
        return search_results

    def reset_database(self):
        """Delete the existing collection and recreate it empty."""
        try:
            self.client.delete_collection(name=self.collection_name)
            self.collection = self._get_collection()
            logger.info(f"Reset collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Failed to reset database: {e}")
            return False

    def get_collection_stats(self) -> dict:
        """Return chunk count, on-disk size, and configuration details."""
        count = self.collection.count()
        db_size_mb = 0.0
        try:
            # Sum every file under the DB directory, then convert bytes to MB.
            size_bytes = sum(
                p.stat().st_size for p in Path(self.db_path).rglob("*") if p.is_file()
            )
            db_size_mb = size_bytes / (1024 * 1024)
        except Exception:
            db_size_mb = 0.0
        return {
            "total_chunks": count,
            "collection_name": self.collection_name,
            "embedding_model": self.model_name,
            "db_path": self.db_path,
            "db_size_mb": db_size_mb
        }

    def process_all_chunks(self, chunks_dir: str | None = None) -> bool:
        """Process all *_extracted.json files and build the ChromaDB collection."""
        if not chunks_dir:
            chunks_dir = getattr(settings, 'PROCESSED_TEXT_DIR', './data/processed_text')
        chunk_files = list(Path(chunks_dir).glob("*_extracted.json"))
        logger.info(f"Found {len(chunk_files)} extracted JSON files to process")
        total_processed = 0
        for chunk_file in chunk_files:
            try:
                with open(chunk_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
                # Handle both shapes of extracted JSON files.
                if isinstance(data, dict) and "initial_chunks" in data:
                    # New format: {"source_info": {...}, "initial_chunks": [...]}
                    chunks = data["initial_chunks"]
                elif isinstance(data, list):
                    # Old format: a bare list of chunk dicts.
                    chunks = data
                else:
                    logger.warning(f"Unexpected format in {chunk_file.name}")
                    continue
                if chunks and self.add_chunks_to_db(chunks, source_file=chunk_file.name):
                    total_processed += len(chunks)
                    logger.info(f"Processed {chunk_file.name}: {len(chunks)} chunks")
            except Exception as e:
                logger.error(f"Error processing {chunk_file}: {e}")
                continue
        logger.info(f"Successfully processed {total_processed} total chunks")
        return total_processed > 0
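

if __name__ == "__main__":
    # Minimal usage sketch. Assumes config.settings supplies the paths and
    # model name referenced above; the query string is illustrative only.
    manager = ChromaVectorDBManager()
    if manager.process_all_chunks():
        stats = manager.get_collection_stats()
        print(f"Indexed {stats['total_chunks']} chunks ({stats['db_size_mb']:.1f} MB)")
        for hit in manager.search_for_rag("example query", n_results=3):
            print(f"{hit['score']:.3f}  {hit['source_file']}  {hit['text'][:80]}")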