import gradio as gr
import os
import re
import torch
import faiss
import pickle
import json
import logging
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# -------------------------------
# 1. Configuration
# -------------------------------

# Paths
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_DB_DIR = os.path.join(CURRENT_DIR, "model_cache", "vector_db")

# Model Configuration
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"  # Default embedding model
GENERATION_MODEL = "google/flan-t5-base"  # Default generation model
MAX_OUTPUT_LENGTH = 512   # Output token limit
MAX_CONTEXT_LENGTH = 800  # Maximum length for context chunks
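
# Expected vector store layout (an assumption inferred from load_vector_db() below;
# the store itself is presumably built by a separate indexing step not included here):
#   model_cache/vector_db/index.faiss      - FAISS index over the chunk embeddings
#   model_cache/vector_db/chunks.pkl       - pickled list of chunks (dicts with a
#                                            'content' field plus source metadata,
#                                            or plain strings)
#   model_cache/vector_db/model_info.json  - optional; {"name": "..."} overrides
#                                            EMBEDDING_MODEL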

# -------------------------------
# 2. Load Components
# -------------------------------

def load_vector_db():
    """Load the vector store for retrieval."""
    try:
        # Load model info (falls back to the default embedding model)
        model_info_path = os.path.join(VECTOR_DB_DIR, "model_info.json")
        if os.path.exists(model_info_path):
            with open(model_info_path, "r") as f:
                model_info = json.load(f)
            model_name = model_info.get("name", EMBEDDING_MODEL)
        else:
            model_name = EMBEDDING_MODEL

        # Load embedding model
        logger.info(f"Loading embedding model: {model_name}")
        retriever_model = SentenceTransformer(model_name)

        # Load FAISS index
        logger.info(f"Loading FAISS index from {VECTOR_DB_DIR}")
        index = faiss.read_index(os.path.join(VECTOR_DB_DIR, "index.faiss"))

        # Load chunks
        with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "rb") as f:
            chunks = pickle.load(f)

        logger.info(f"Loaded vector store with {len(chunks)} chunks")
        return retriever_model, index, chunks
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None, None, None

def load_generator():
    """Load the generation model."""
    try:
        logger.info(f"Loading generation model: {GENERATION_MODEL}")

        # Load tokenizer and model separately
        tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(GENERATION_MODEL)

        # Create the pipeline with the tokenizer and model, placing it on GPU when
        # available via the pipeline's `device` argument (0 = first GPU, -1 = CPU)
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=MAX_OUTPUT_LENGTH,
            device=0 if torch.cuda.is_available() else -1
        )
        return generator, tokenizer
    except Exception as e:
        logger.error(f"Error loading generation model: {str(e)}")
        return None, None

# Load components
retriever_model, index, chunks = load_vector_db()
generator, tokenizer = load_generator()
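
# Note: the loaders above return None values on failure rather than raising, so the
# Gradio app still launches; retrieve() and answer_query() below check for these
# None/empty cases and respond with an error message instead of crashing.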

def chunk_context(context, tokenizer, max_length=MAX_CONTEXT_LENGTH):
    """Split context into manageable chunks that won't exceed token limits."""
    # Quick check if tokenizer is missing
    if not tokenizer:
        logger.warning("Tokenizer not available, returning context as is")
        return context

    # Tokenize the context to get the token count
    encoded = tokenizer.encode(context)

    # If the context fits within the limit, return it as is
    if len(encoded) <= max_length:
        return context

    # Otherwise, split into sentences and build chunks
    sentences = re.split(r'(?<=[.!?])\s+', context)
    chunks = []
    current_chunk = []
    current_length = 0

    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        if current_length + sentence_tokens > max_length:
            # This sentence would make the chunk too long; start a new chunk
            if current_chunk:
                chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = sentence_tokens
        else:
            # Add this sentence to the current chunk
            current_chunk.append(sentence)
            current_length += sentence_tokens

    # Add the last chunk if not empty
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks
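
# Note: chunk_context() returns either a single string (when the context already fits
# within max_length tokens, or when no tokenizer is available) or a list of strings
# (when it had to be split). answer_query() below normalizes both cases into a list
# before generation, e.g. (hypothetical values):
#   chunk_context("Short text.", tokenizer)  -> "Short text."
#   chunk_context(long_text, tokenizer)      -> ["first chunk ...", "second chunk ..."]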

# -------------------------------
# 3. RAG Pipeline
# -------------------------------

def retrieve(query, top_k=5):
    """Retrieve relevant passages for the query."""
    if not retriever_model or not index or not chunks:
        logger.error("Vector store components not loaded")
        return []
    try:
        # Generate and normalize the query embedding (this assumes the index was
        # built from L2-normalized embeddings, i.e. cosine similarity via inner product)
        query_embedding = retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)

        # Search the index
        scores, indices = index.search(query_embedding, top_k)

        # Keep only valid indices (FAISS pads missing results with -1)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(chunks)]
        retrieved_chunks = [chunks[idx] for idx in valid_indices]
        return retrieved_chunks
    except Exception as e:
        logger.error(f"Error during retrieval: {str(e)}")
        return []

def answer_query(query):
    """Process a query through the RAG pipeline."""
    if not generator:
        return "Generation model could not be loaded. Please check your installation."
    try:
        # Retrieve relevant context
        retrieved_chunks = retrieve(query)
        if not retrieved_chunks:
            return "No relevant information was found in the knowledge base, or the retrieval system is not working."

        # Build context from chunks with source information
        context_parts = []
        sources = []
        for i, chunk in enumerate(retrieved_chunks):
            # Get the content and source information (chunks may be dicts or plain strings)
            if isinstance(chunk, dict):
                content = chunk.get('content', str(chunk))
                # Extract source information - look for multiple possible source fields
                source = None
                for source_field in ['source', 'title', 'metadata', 'filename', 'document']:
                    if source_field in chunk:
                        source = chunk[source_field]
                        break
            else:
                content = str(chunk)
                source = None

            # If no source field was found, create a reference ID
            if not source:
                source = f"Reference {i + 1}"

            # Add a source identifier within the content for better context
            labeled_content = f"[From {source}]: {content}"
            context_parts.append(labeled_content)
            sources.append(source)

        # Join all content into one context
        full_context = "\n\n".join(context_parts)

        # Split the context if it is too long for the model
        context_chunks = chunk_context(full_context, tokenizer)

        # If the context is a single string, wrap it in a list for consistent handling
        if isinstance(context_chunks, str):
            context_chunks = [context_chunks]

        # Generate an answer for each context chunk
        answers = []
        for ctx in context_chunks:
            # Format the prompt with explicit instructions to cite sources
            prompt = f"""Answer the following legal question based on the provided context.
You MUST include specific citations to the sources in your answer,
clearly indicating which information comes from which source.

Context: {ctx}

Question: {query}

Provide a detailed answer with explicit source citations:"""

            # Generate the answer
            result = generator(
                prompt,
                max_length=MAX_OUTPUT_LENGTH,
                temperature=0.7,
                do_sample=True,
                num_return_sequences=1
            )

            # Extract the response
            if isinstance(result, list) and len(result) > 0:
                if 'generated_text' in result[0]:
                    answers.append(result[0]['generated_text'].strip())
                else:
                    answers.append(str(result[0]).strip())
            else:
                answers.append(str(result).strip())

        # Combine the per-chunk answers
        combined_answer = "\n\n".join(answers)

        # Always append a sources section, regardless of whether sources were
        # cited in the generated text
        source_text = "\n\nSources used to generate this answer:\n" + "\n".join([f"- {source}" for source in sources])
        combined_answer += source_text

        return combined_answer
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        return f"Error generating response: {str(e)}"

# -------------------------------
# 4. Gradio Interface
# -------------------------------

title = "⚖️ Indian Labor Law Assistant"

description = """
This AI assistant uses a local Retrieval-Augmented Generation (RAG) system to answer questions about Indian labor laws.
Ask about employment regulations, workers' rights, or legal provisions in India. Please note that this is not a substitute for professional legal advice; it is a tool to make legal information more accessible.
Please be patient: responses may take a few seconds to generate, as this Space runs on the cheapest available GPU. For faster load times and responses, please clone the repository and run it locally.
"""

examples = [
    ["What are the key provisions of the Industrial Disputes Act?"],
    ["What is the minimum wage according to Indian labor law?"],
    ["What are the rules for overtime pay in India?"],
    ["What constitutes unfair labor practice under Indian law?"],
    ["Explain maternity leave entitlements in India"]
]

interface = gr.Interface(
    fn=answer_query,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Ask your Indian labor law question here...",
        label="Legal Query"
    ),
    outputs=gr.Textbox(
        lines=10,
        placeholder="Answer will appear here...",
        label="Generated Response"
    ),
    title=title,
    description=description,
    examples=examples,
    theme="soft",
    cache_examples=False
)

# -------------------------------
# 5. Launch App
# -------------------------------

if __name__ == "__main__":
    interface.launch()
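
# Running locally (a sketch; package versions are not pinned here, and faiss-gpu may
# be preferred on machines with a CUDA GPU):
#   pip install gradio torch faiss-cpu sentence-transformers transformers
#   python app.py   (assuming this file is saved as app.py, the usual Spaces entry point)
# The app expects the pre-built vector store under model_cache/vector_db/ (see the note
# in the configuration section); without it the UI still launches, but queries return an
# error message instead of an answer.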