inferencing using Gemma 270 model
f05e8f9
from __future__ import annotations

import hashlib
import json
import os
from pathlib import Path
from typing import List

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from config.settings import settings
from src.utils.logger import get_logger

logger = get_logger(__name__)
class ChromaVectorDBManager:
    """Corporate-friendly ChromaDB manager - completely offline."""

    def __init__(self, model_name: str | None = None, db_path: str | None = None):
        self.model_name = model_name or getattr(
            settings, 'EMBEDDING_MODEL', 'sentence-transformers/all-MiniLM-L6-v2'
        )
        # Embeddings are computed locally; no network calls are needed once
        # the model weights are cached.
        self.embedding_model = SentenceTransformer(self.model_name)

        self.db_path = db_path or getattr(settings, 'CHROMADB_PATH', './chroma_db')
        os.makedirs(self.db_path, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=self.db_path,
            settings=Settings(
                anonymized_telemetry=False,
                allow_reset=True,
                is_persistent=True
            )
        )
        self.collection_name = getattr(settings, 'COLLECTION_NAME', 'rag_chunks')
        self.collection = self._get_collection()
        logger.info(f"ChromaDB initialized at: {self.db_path}")

    def _get_collection(self):
        """Get or create the collection without an embedding function."""
        try:
            return self.client.get_collection(name=self.collection_name)
        except Exception:
            # The collection is missing or unusable; drop any stale remnant
            # and recreate it from scratch.
            try:
                self.client.delete_collection(name=self.collection_name)
            except Exception:
                pass
            return self.client.create_collection(
                name=self.collection_name,
                metadata={"description": "RAG chunks"}
            )

    def generate_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings using local sentence-transformers."""
        embeddings = self.embedding_model.encode(
            texts,
            batch_size=32,
            show_progress_bar=len(texts) > 100,
            convert_to_tensor=False  # return a numpy array, not a torch tensor
        )
        # ChromaDB expects plain Python lists, not numpy arrays.
        return embeddings.tolist()
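
    # Sketch of the output shape (the dimension depends on the configured
    # model; the default all-MiniLM-L6-v2 produces 384-dimensional vectors):
    #   generate_embeddings(["hello", "world"]) -> [[0.01, ...], [0.02, ...]]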

    def add_chunks_to_db(self, chunks: list, source_file: str) -> bool:
        """Add chunks (list of dicts) to ChromaDB with manual embedding generation."""
        if not chunks:
            return True

        texts, ids, metadatas = [], [], []
        seen_hashes = set()
        for chunk in chunks:
            text = chunk.get("text", "").strip()
            if not text:
                continue

            # Deduplicate by content hash, both within this batch and
            # (via the ID lookup below) against chunks already stored.
            text_hash = hashlib.md5(text.encode()).hexdigest()
            if text_hash in seen_hashes:
                continue
            seen_hashes.add(text_hash)

            chunk_id = f"{source_file}_{chunk.get('chunk_id', 0)}_{text_hash[:8]}"
            try:
                if self.collection.get(ids=[chunk_id])['ids']:
                    continue  # already ingested in a previous run
            except Exception:
                pass

            texts.append(text)
            ids.append(chunk_id)
            metadatas.append({
                "source_file": source_file,
                "chunk_index": chunk.get("chunk_id", 0),
                "char_length": len(text),
                "text_hash": text_hash
            })

        if not texts:
            return True

        embeddings = self.generate_embeddings(texts)
        self.collection.add(
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
            ids=ids
        )
        logger.info(f"Added {len(texts)} chunks from {source_file} to ChromaDB")
        return True
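
    # Example call (a sketch; "text" and "chunk_id" are the only keys this
    # method reads, so any extra fields in a chunk dict are ignored):
    #   db.add_chunks_to_db(
    #       [{"chunk_id": 0, "text": "First passage..."},
    #        {"chunk_id": 1, "text": "Second passage..."}],
    #       source_file="report.pdf",
    #   )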

    def search_for_rag(
        self,
        query: str,
        n_results: int = 5,
        use_truncated: bool = False,
        filter_128_context: bool = False
    ) -> List[dict]:
        """Search using manual query embedding generation - completely offline.

        Note: `use_truncated` and `filter_128_context` are accepted for
        interface compatibility but are not currently used.
        """
        query_embedding = self.generate_embeddings([query])[0]
        # Over-fetch (up to 2x, capped at 20) so that any skipped rows still
        # leave enough candidates to fill n_results.
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(n_results * 2, 20),
            include=["documents", "metadatas", "distances"]
        )

        search_results = []
        for i, (doc, metadata, distance) in enumerate(zip(
            results['documents'][0], results['metadatas'][0], results['distances'][0]
        )):
            if len(search_results) >= n_results:
                break
            # Map distance to a similarity-like score in (0, 1]; a smaller
            # distance yields a higher score.
            similarity_score = 1 / (1 + distance)
            search_results.append({
                "id": results['ids'][0][i],
                "score": similarity_score,
                "distance": distance,
                "text": doc,
                "source_file": metadata["source_file"],
                "chunk_index": metadata["chunk_index"]
            })
        return search_results
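
    # Example usage (sketch; query text is illustrative):
    #   hits = db.search_for_rag("What is the refund policy?", n_results=3)
    #   for hit in hits:
    #       print(f"{hit['score']:.3f}  {hit['source_file']}  {hit['text'][:80]}")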

    def reset_database(self):
        """Delete and recreate the collection, discarding all stored chunks."""
        try:
            self.client.delete_collection(name=self.collection_name)
            self.collection = self._get_collection()
            logger.info(f"Reset collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Failed to reset database: {e}")
            return False

    def get_collection_stats(self) -> dict:
        """Get collection statistics."""
        count = self.collection.count()
        try:
            # Sum the on-disk size of every file under the DB directory.
            db_size_bytes = sum(
                f.stat().st_size for f in Path(self.db_path).rglob("*") if f.is_file()
            )
            db_size_mb = db_size_bytes / (1024 * 1024)
        except Exception:
            db_size_mb = 0
        return {
            "total_chunks": count,
            "collection_name": self.collection_name,
            "embedding_model": self.model_name,
            "db_path": self.db_path,
            "db_size_mb": db_size_mb
        }

    def process_all_chunks(self, chunks_dir: str | None = None) -> bool:
        """Process all *_extracted.json files and build ChromaDB."""
        if not chunks_dir:
            chunks_dir = getattr(settings, 'PROCESSED_TEXT_DIR', './data/processed_text')

        chunk_files = list(Path(chunks_dir).glob("*_extracted.json"))
        logger.info(f"Found {len(chunk_files)} extracted JSON files to process")

        total_processed = 0
        for chunk_file in chunk_files:
            try:
                with open(chunk_file, "r", encoding="utf-8") as f:
                    data = json.load(f)

                # Handle both layouts of the extracted JSON files.
                if isinstance(data, dict) and "initial_chunks" in data:
                    # New format: { "source_info": {...}, "initial_chunks": [...] }
                    chunks = data["initial_chunks"]
                elif isinstance(data, list):
                    # Old format: a list of chunks directly
                    chunks = data
                else:
                    logger.warning(f"Unexpected format in {chunk_file.name}")
                    continue

                if chunks and self.add_chunks_to_db(chunks, source_file=chunk_file.name):
                    total_processed += len(chunks)
                    logger.info(f"Processed {chunk_file.name}: {len(chunks)} chunks")
            except Exception as e:
                logger.error(f"Error processing {chunk_file}: {e}")
                continue

        logger.info(f"Successfully processed {total_processed} total chunks")
        return total_processed > 0
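

# Minimal end-to-end usage sketch (assumes `config.settings` provides the
# paths used above and that *_extracted.json files already exist in the
# processed-text directory):
if __name__ == "__main__":
    db = ChromaVectorDBManager()
    db.process_all_chunks()
    print(db.get_collection_stats())
    for hit in db.search_for_rag("example query", n_results=3):
        print(hit["score"], hit["source_file"])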