# vector_store.py - Vector store integration with Pinecone
"""
Vector store integration for legal document embeddings using InLegalBERT and Pinecone
"""
import os
import numpy as np
from typing import List, Dict, Any
from langchain_pinecone import PineconeVectorStore
from langchain_core.embeddings import Embeddings


class InLegalBERTEmbeddings(Embeddings):
    """Custom LangChain embeddings wrapper for InLegalBERT"""

    def __init__(self, model):
        self.model = model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents"""
        return self.model.encode(texts).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query"""
        return self.model.encode([text])[0].tolist()
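    # Example (a minimal sketch, not part of the module's API): any model that
    # exposes a SentenceTransformer-style .encode() returning numpy arrays works
    # here; the model name below is an assumption about how InLegalBERT is
    # loaded in this project.
    #
    #   from sentence_transformers import SentenceTransformer
    #   embeddings = InLegalBERTEmbeddings(SentenceTransformer("law-ai/InLegalBERT"))
    #   query_vector = embeddings.embed_query("termination for convenience")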


class LegalDocumentVectorStore:
    """Manages vector storage for legal documents"""

    def __init__(self):
        self.index_name = 'legal-documents'
        self.dimension = 768  # InLegalBERT embedding dimension
        self._initialized = False
        self.clause_tagger = None
        self.pc = None
    def _initialize_pinecone(self):
        """Initialize the Pinecone connection (lazily, on first use)"""
        if self._initialized:
            return

        PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
        if not PINECONE_API_KEY:
            raise ValueError("PINECONE_API_KEY environment variable not set")

        # Use the modern Pinecone client API
        from pinecone import Pinecone, ServerlessSpec
        self.pc = Pinecone(api_key=PINECONE_API_KEY)

        # Create the index if it doesn't exist yet
        existing_indexes = [index_info["name"] for index_info in self.pc.list_indexes()]
        if self.index_name not in existing_indexes:
            self.pc.create_index(
                name=self.index_name,
                dimension=self.dimension,
                metric='cosine',
                spec=ServerlessSpec(cloud='aws', region='us-east-1')
            )
            print(f"✅ Created Pinecone index: {self.index_name}")

        self._initialized = True
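    # For quick diagnostics after initialisation one might inspect the index,
    # e.g. (a sketch; assumes the index has already been created):
    #
    #   vector_store._initialize_pinecone()
    #   stats = vector_store.pc.Index("legal-documents").describe_index_stats()
    #   print(stats.total_vector_count)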
    def _normalize_embedding(self, embedding):
        """Ensure an embedding is always a plain list of floats"""
        if embedding is None:
            return None
        if isinstance(embedding, np.ndarray):
            return embedding.tolist()
        if isinstance(embedding, list):
            # Already a Python list
            return embedding
        # Fallback: try converting if it's a torch tensor or similar
        try:
            return embedding.tolist()
        except Exception:
            return list(embedding)
    def save_document_embeddings_optimized(
        self,
        chunk_data: List[Dict],
        document_id: str,
        analysis_results: Dict[str, Any]
    ) -> bool:
        """Save embeddings using pre-computed vectors, storing chunk text in metadata"""
        try:
            self._initialize_pinecone()

            # Each chunk dict is expected to carry "text" and "embedding" keys;
            # chunks without an embedding are skipped so chunk_index stays
            # aligned with the original chunk order.
            vectors = []
            for i, chunk_info in enumerate(chunk_data):
                normalized_embedding = self._normalize_embedding(chunk_info.get("embedding"))
                if normalized_embedding is None:
                    continue
                metadata = {
                    'document_id': document_id,
                    'chunk_index': i,
                    'total_chunks': len(chunk_data),
                    'source': 'legal_document',
                    'has_key_clauses': len(analysis_results.get('key_clauses', [])) > 0,
                    'risk_count': len(analysis_results.get('risky_terms', [])),
                    'embedding_model': 'InLegalBERT',
                    'timestamp': str(np.datetime64('now')),
                    'text': chunk_info["text"]  # Store text in metadata for retrieval
                }
                vectors.append({
                    "id": f"{document_id}_chunk_{i}",
                    "values": normalized_embedding,
                    "metadata": metadata
                })

            if not vectors:
                print("⚠️ No embeddings found in chunk_data")
                return False

            # Upsert to Pinecone in batches to stay under per-request size limits
            index = self.pc.Index(self.index_name)
            batch_size = 100
            for start in range(0, len(vectors), batch_size):
                index.upsert(vectors=vectors[start:start + batch_size])

            print(f"✅ Saved {len(vectors)} pre-computed embeddings with text to Pinecone")
            return True
        except Exception as e:
            print(f"❌ Error saving pre-computed embeddings: {e}")
            return False
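    # Example payload (a sketch of the shape this method reads; the upstream
    # chunking/embedding pipeline that produces it lives outside this module):
    #
    #   chunk_data = [
    #       {"text": "Clause 7.1: The supplier shall ...",
    #        "embedding": np.random.rand(768)},   # 768-dim InLegalBERT vector
    #   ]
    #   vector_store.save_document_embeddings_optimized(
    #       chunk_data, "contract-123", {"key_clauses": [], "risky_terms": []}
    #   )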
    def get_retriever(self, clause_tagger, document_id: str = None):
        """Get retriever for chat functionality with improved settings"""
        try:
            self._initialize_pinecone()

            legal_embeddings = InLegalBERTEmbeddings(clause_tagger.embedding_model)
            index = self.pc.Index(self.index_name)
            vectorstore = PineconeVectorStore(
                index=index,
                embedding=legal_embeddings,
                text_key="text"  # Use text stored in metadata
            )

            # Configure search parameters
            search_kwargs = {'k': 10}
            if document_id:
                search_kwargs['filter'] = {'document_id': document_id}

            return vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs=search_kwargs
            )
        except Exception as e:
            print(f"❌ Error creating retriever: {e}")
            return None


# Global instance
vector_store = LegalDocumentVectorStore()
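

# --- Usage sketch (illustrative, not part of the module) --------------------
# A minimal, hedged example of wiring the pieces above together. It assumes
# PINECONE_API_KEY is set and stands in a SentenceTransformer-loaded
# "law-ai/InLegalBERT" model for the project's clause tagger; the model name,
# the SimpleNamespace stand-in, and the sample text are assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("law-ai/InLegalBERT")  # assumed 768-dim encoder
    sample_chunks = ["The lessee shall indemnify the lessor against all claims."]
    chunk_data = [
        {"text": text, "embedding": model.encode([text])[0]} for text in sample_chunks
    ]

    saved = vector_store.save_document_embeddings_optimized(
        chunk_data=chunk_data,
        document_id="demo-doc-1",
        analysis_results={"key_clauses": [], "risky_terms": []},
    )
    print("Saved:", saved)

    # get_retriever() only needs an object exposing `.embedding_model`
    retriever = vector_store.get_retriever(
        SimpleNamespace(embedding_model=model), document_id="demo-doc-1"
    )
    if retriever:
        docs = retriever.invoke("Who bears the indemnification risk?")
        print([doc.page_content for doc in docs])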