import os
import re
import json
import pickle
import logging

import torch
import faiss
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# -------------------------------
# 1. Configuration
# -------------------------------

# Paths
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_DB_DIR = os.path.join(CURRENT_DIR, "model_cache", "vector_db")

# Model Configuration
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"  # Default embedding model
GENERATION_MODEL = "google/flan-t5-base"  # Default generation model
MAX_OUTPUT_LENGTH = 512  # Output token limit
MAX_CONTEXT_LENGTH = 800  # Maximum token length for context chunks

# -------------------------------
# 2. Load Components
# -------------------------------

def load_vector_db():
    """Load the vector store for retrieval."""
    try:
        # Load model info
        model_info_path = os.path.join(VECTOR_DB_DIR, "model_info.json")
        if os.path.exists(model_info_path):
            with open(model_info_path, "r") as f:
                model_info = json.load(f)
            model_name = model_info.get("name", EMBEDDING_MODEL)
        else:
            model_name = EMBEDDING_MODEL

        # Load embedding model
        logger.info(f"Loading embedding model: {model_name}")
        retriever_model = SentenceTransformer(model_name)

        # Load FAISS index
        logger.info(f"Loading FAISS index from {VECTOR_DB_DIR}")
        index = faiss.read_index(os.path.join(VECTOR_DB_DIR, "index.faiss"))

        # Load chunks
        with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "rb") as f:
            chunks = pickle.load(f)

        logger.info(f"Loaded vector store with {len(chunks)} chunks")
        return retriever_model, index, chunks
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None, None, None


def load_generator():
    """Load the generation model."""
    try:
        logger.info(f"Loading generation model: {GENERATION_MODEL}")

        # Load tokenizer and model separately
        tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(GENERATION_MODEL)

        # Create the pipeline with the tokenizer and model
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=MAX_OUTPUT_LENGTH,
            device=0 if torch.cuda.is_available() else -1  # first GPU if available, else CPU
        )
        return generator, tokenizer
    except Exception as e:
        logger.error(f"Error loading generation model: {str(e)}")
        return None, None


# Load components at startup
retriever_model, index, chunks = load_vector_db()
generator, tokenizer = load_generator()


def chunk_context(context, tokenizer, max_length=MAX_CONTEXT_LENGTH):
    """Split context into manageable chunks that won't exceed token limits."""
    # Quick check if tokenizer is missing
    if not tokenizer:
        logger.warning("Tokenizer not available, returning context as is")
        return context

    # Tokenize the context to get token counts
    encoded = tokenizer.encode(context)

    # If context fits within the limit, return it as is
    if len(encoded) <= max_length:
        return context

    # Otherwise, split into sentences and build chunks
    sentences = re.split(r'(?<=[.!?])\s+', context)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        if current_length + sentence_tokens > max_length:
            # This sentence would make the chunk too long, start a new chunk
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = sentence_tokens
        else:
            # Add this sentence to the current chunk
            current_chunk.append(sentence)
            current_length += sentence_tokens

    # Add the last chunk if not empty
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks


# -------------------------------
# 3. RAG Pipeline
# -------------------------------

def retrieve(query, top_k=5):
    """Retrieve relevant passages for the query."""
    if not retriever_model or not index or not chunks:
        logger.error("Vector store components not loaded")
        return []

    try:
        # Generate and normalize the query embedding
        query_embedding = retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)

        # Search index
        scores, indices = index.search(query_embedding, top_k)

        # Filter valid indices (FAISS returns -1 for missing results)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(chunks)]
        retrieved_chunks = [chunks[idx] for idx in valid_indices]

        return retrieved_chunks
    except Exception as e:
        logger.error(f"Error during retrieval: {str(e)}")
        return []


def answer_query(query):
    """Process a query through the RAG pipeline."""
    if not generator:
        return "Generation model could not be loaded. Please check your installation."

    try:
        # Retrieve relevant context
        retrieved_chunks = retrieve(query)
        if not retrieved_chunks:
            return "No relevant information found in the knowledge base, or the retrieval system is not working."

        # Build context from chunks with source information
        context_parts = []
        sources = []

        for i, chunk in enumerate(retrieved_chunks):
            # Get the content and source information (chunks may be dicts or plain strings)
            if isinstance(chunk, dict):
                content = chunk.get('content', str(chunk))
                # Extract source information - look for multiple possible source fields
                source = None
                for source_field in ['source', 'title', 'metadata', 'filename', 'document']:
                    if source_field in chunk:
                        source = chunk[source_field]
                        break
            else:
                content = str(chunk)
                source = None

            # If no source field found, create a reference ID
            if not source:
                source = f"Reference {i+1}"

            # Add source identifier within the content for better context
            labeled_content = f"[From {source}]: {content}"
            context_parts.append(labeled_content)
            sources.append(source)

        # Join all content into one context
        full_context = "\n\n".join(context_parts)

        # Split context if it's too long for the model
        context_chunks = chunk_context(full_context, tokenizer)

        # If context is a string, make it a list for consistent handling
        if isinstance(context_chunks, str):
            context_chunks = [context_chunks]

        # Generate answers for each context chunk
        answers = []
        for i, ctx in enumerate(context_chunks):
            # Format prompt with explicit instructions to cite sources
            prompt = f"""Answer the following legal question based on the provided context.
You MUST include specific citations to the sources in your answer, clearly indicating which information comes from which source.
Context:
{ctx}

Question: {query}

Provide a detailed answer with explicit source citations:"""

            # Generate answer
            result = generator(
                prompt,
                max_length=MAX_OUTPUT_LENGTH,
                temperature=0.7,
                do_sample=True,
                num_return_sequences=1
            )

            # Extract response
            if isinstance(result, list) and len(result) > 0:
                if 'generated_text' in result[0]:
                    answers.append(result[0]['generated_text'].strip())
                else:
                    answers.append(str(result[0]).strip())
            else:
                answers.append(str(result).strip())

        # Combine answers
        combined_answer = "\n\n".join(answers)

        # Always add a sources section at the end, regardless of whether sources were
        # mentioned in the generated text
        source_text = "\n\nSources used to generate this answer:\n" + "\n".join(
            [f"- {source}" for source in sources]
        )
        combined_answer += source_text

        return combined_answer
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        return f"Error generating response: {str(e)}"


# -------------------------------
# 4. Gradio Interface
# -------------------------------

title = "⚖️ Indian Labor Law Assistant"
description = """
This AI assistant uses a local Retrieval-Augmented Generation (RAG) system to answer questions about Indian labor laws.
Ask about employment regulations, workers' rights, or legal provisions in India.

Please note that this is not a substitute for professional legal advice; it is a tool to make legal information more accessible.

Please be patient with the model, as it may take a few seconds to generate a response - we are using the cheapest GPU available for this model.
For faster performance, please clone the repository and run it locally.
"""

examples = [
    ["What are the key provisions of the Industrial Disputes Act?"],
    ["What is the minimum wage according to Indian labor law?"],
    ["What are the rules for overtime pay in India?"],
    ["What constitutes unfair labor practice under Indian law?"],
    ["Explain maternity leave entitlements in India"]
]

interface = gr.Interface(
    fn=answer_query,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Ask your Indian labor law question here...",
        label="Legal Query"
    ),
    outputs=gr.Textbox(
        lines=10,
        placeholder="Answer will appear here...",
        label="Generated Response"
    ),
    title=title,
    description=description,
    examples=examples,
    theme="soft",
    cache_examples=False
)

# -------------------------------
# 5. Launch App
# -------------------------------

if __name__ == "__main__":
    interface.launch()
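# -------------------------------
# 6. (Optional) Building the vector store
# -------------------------------
# The app expects `model_cache/vector_db/` to contain `index.faiss`, `chunks.pkl`,
# and optionally `model_info.json`; those artifacts are produced by a separate
# indexing step that is not part of this file. The commented snippet below is a
# minimal sketch of how such a store could be built so that it stays compatible
# with `retrieve()` above (cosine similarity via normalized vectors on an
# inner-product index, chunks stored as dicts with "content"/"source"). The
# document list and its contents are placeholders, not part of the original app.
#
#   import os, json, pickle
#   import faiss
#   from sentence_transformers import SentenceTransformer
#
#   docs = [{"content": "Text of a law section...", "source": "Industrial Disputes Act"}]
#
#   model = SentenceTransformer(EMBEDDING_MODEL)
#   embeddings = model.encode([d["content"] for d in docs], convert_to_numpy=True).astype("float32")
#   faiss.normalize_L2(embeddings)                      # normalize so inner product == cosine
#   index = faiss.IndexFlatIP(embeddings.shape[1])      # inner-product index
#   index.add(embeddings)
#
#   os.makedirs(VECTOR_DB_DIR, exist_ok=True)
#   faiss.write_index(index, os.path.join(VECTOR_DB_DIR, "index.faiss"))
#   with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "wb") as f:
#       pickle.dump(docs, f)
#   with open(os.path.join(VECTOR_DB_DIR, "model_info.json"), "w") as f:
#       json.dump({"name": EMBEDDING_MODEL}, f)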