import gradio as gr
import os
import re
import torch
import faiss
import pickle
import json
import logging
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# -------------------------------
# 1. Configuration
# -------------------------------

# Paths
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_DB_DIR = os.path.join(CURRENT_DIR, "model_cache", "vector_db")

# Model configuration
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"  # Default embedding model
GENERATION_MODEL = "google/flan-t5-base"  # Default generation model
MAX_OUTPUT_LENGTH = 512  # Output token limit
MAX_CONTEXT_LENGTH = 800  # Maximum token length for context chunks
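
# The vector store directory is expected to contain the artifacts read below:
# index.faiss (the FAISS index), chunks.pkl (the pickled text chunks), and
# optionally model_info.json recording which embedding model built the index.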
# -------------------------------
# 2. Load Components
# -------------------------------
def load_vector_db():
    """Load the vector store for retrieval"""
    try:
        # Load model info
        model_info_path = os.path.join(VECTOR_DB_DIR, "model_info.json")
        if os.path.exists(model_info_path):
            with open(model_info_path, "r") as f:
                model_info = json.load(f)
            model_name = model_info.get("name", EMBEDDING_MODEL)
        else:
            model_name = EMBEDDING_MODEL

        # Load embedding model
        logger.info(f"Loading embedding model: {model_name}")
        retriever_model = SentenceTransformer(model_name)

        # Load FAISS index
        logger.info(f"Loading FAISS index from {VECTOR_DB_DIR}")
        index = faiss.read_index(os.path.join(VECTOR_DB_DIR, "index.faiss"))

        # Load chunks
        with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "rb") as f:
            chunks = pickle.load(f)

        logger.info(f"Loaded vector store with {len(chunks)} chunks")
        return retriever_model, index, chunks
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        return None, None, None
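
# Note: this app only reads a prebuilt vector store; nothing here builds one.
# The sketch below shows one way the expected artifacts could be produced
# offline, assuming an IndexFlatIP index over L2-normalized embeddings (the
# same assumption retrieve() makes when it normalizes the query embedding).
# It is illustrative only and is never called by the app.
def build_vector_db_sketch(chunk_dicts):
    """Illustrative offline builder; chunk_dicts is assumed to be a list of
    {'content': ..., 'source': ...} dicts matching what answer_query() expects."""
    os.makedirs(VECTOR_DB_DIR, exist_ok=True)
    model = SentenceTransformer(EMBEDDING_MODEL)

    # Embed the chunk contents and normalize so inner product equals cosine similarity
    texts = [c.get("content", str(c)) for c in chunk_dicts]
    embeddings = model.encode(texts, convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(embeddings)

    # Build and persist the index alongside the chunks and model info
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, os.path.join(VECTOR_DB_DIR, "index.faiss"))
    with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "wb") as f:
        pickle.dump(chunk_dicts, f)
    with open(os.path.join(VECTOR_DB_DIR, "model_info.json"), "w") as f:
        json.dump({"name": EMBEDDING_MODEL}, f)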
def load_generator():
    """Load the generation model"""
    try:
        logger.info(f"Loading generation model: {GENERATION_MODEL}")
        # Load tokenizer and model separately
        tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
        model = AutoModelForSeq2SeqLM.from_pretrained(GENERATION_MODEL)

        # Create the pipeline with the tokenizer and model
        # (device=0 selects the first GPU when available, -1 keeps the pipeline on CPU)
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_length=MAX_OUTPUT_LENGTH,
            device=0 if torch.cuda.is_available() else -1
        )
        return generator, tokenizer
    except Exception as e:
        logger.error(f"Error loading generation model: {str(e)}")
        return None, None
# Load components
retriever_model, index, chunks = load_vector_db()
generator, tokenizer = load_generator()
def chunk_context(context, tokenizer, max_length=MAX_CONTEXT_LENGTH):
    """Split context into manageable chunks that won't exceed token limits"""
    # Quick check if tokenizer is missing
    if not tokenizer:
        logger.warning("Tokenizer not available, returning context as is")
        return context

    # Tokenize the context to get token counts
    encoded = tokenizer.encode(context)

    # If context fits within the limit, return it as is
    if len(encoded) <= max_length:
        return context

    # Otherwise, split into sentences and build chunks
    sentences = re.split(r'(?<=[.!?])\s+', context)
    context_chunks = []
    current_chunk = []
    current_length = 0
    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        if current_length + sentence_tokens > max_length:
            # This sentence would make the chunk too long, start a new chunk
            if current_chunk:
                context_chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_length = sentence_tokens
        else:
            # Add this sentence to the current chunk
            current_chunk.append(sentence)
            current_length += sentence_tokens

    # Add the last chunk if not empty
    if current_chunk:
        context_chunks.append(' '.join(current_chunk))
    return context_chunks
# -------------------------------
# 3. RAG Pipeline
# -------------------------------
def retrieve(query, top_k=5):
    """Retrieve relevant passages for the query"""
    if not retriever_model or not index or not chunks:
        logger.error("Vector store components not loaded")
        return []

    try:
        # Generate query embedding
        query_embedding = retriever_model.encode([query], convert_to_numpy=True)
        query_embedding = query_embedding.astype('float32')
        faiss.normalize_L2(query_embedding)

        # Search index
        scores, indices = index.search(query_embedding, top_k)

        # Keep only valid indices (FAISS returns -1 when fewer than top_k results exist)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(chunks)]
        retrieved_chunks = [chunks[idx] for idx in valid_indices]
        return retrieved_chunks
    except Exception as e:
        logger.error(f"Error during retrieval: {str(e)}")
        return []
def answer_query(query):
    """Process a query through the RAG pipeline"""
    if not generator:
        return "Generation model could not be loaded. Please check your installation."

    try:
        # Retrieve relevant context
        retrieved_chunks = retrieve(query)
        if not retrieved_chunks:
            return "No relevant information found in the knowledge base, or the retrieval system is not working."

        # Build context from chunks with source information
        context_parts = []
        sources = []
        for i, chunk in enumerate(retrieved_chunks):
            # Chunks may be dicts with content/source metadata or plain strings
            if isinstance(chunk, dict):
                content = chunk.get('content', str(chunk))
                # Extract source information - look for multiple possible source fields
                source = None
                for source_field in ['source', 'title', 'metadata', 'filename', 'document']:
                    if source_field in chunk:
                        source = chunk[source_field]
                        break
            else:
                content = str(chunk)
                source = None

            # If no source field found, create a reference ID
            if not source:
                source = f"Reference {i+1}"

            # Add source identifier within the content for better context
            labeled_content = f"[From {source}]: {content}"
            context_parts.append(labeled_content)
            sources.append(source)

        # Join all content into one context
        full_context = "\n\n".join(context_parts)
        # Split context if it's too long for the model
        context_chunks = chunk_context(full_context, tokenizer)

        # If context is a string, make it a list for consistent handling
        if isinstance(context_chunks, str):
            context_chunks = [context_chunks]

        # Generate answers for each context chunk
        answers = []
        for i, ctx in enumerate(context_chunks):
            # Format prompt with explicit instructions to cite sources
            prompt = f"""Answer the following legal question based on the provided context.
You MUST include specific citations to the sources in your answer,
clearly indicating which information comes from which source.
Context: {ctx}
Question: {query}
Provide a detailed answer with explicit source citations:"""

            # Generate answer
            result = generator(
                prompt,
                max_length=MAX_OUTPUT_LENGTH,
                temperature=0.7,
                do_sample=True,
                num_return_sequences=1
            )

            # Extract response
            if isinstance(result, list) and len(result) > 0:
                if 'generated_text' in result[0]:
                    answers.append(result[0]['generated_text'].strip())
                else:
                    answers.append(str(result[0]).strip())
            else:
                answers.append(str(result).strip())

        # Combine answers
        combined_answer = "\n\n".join(answers)

        # Always add a sources section at the end, regardless of whether sources
        # were mentioned in the generated text
        source_text = "\n\nSources used to generate this answer:\n" + "\n".join([f"- {source}" for source in sources])
        combined_answer += source_text
        return combined_answer
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        return f"Error generating response: {str(e)}"
# -------------------------------
# 4. Gradio Interface
# -------------------------------
title = "⚖️ Indian Labor Law Assistant"
description = """
This AI assistant uses a local Retrieval-Augmented Generation (RAG) system to answer questions about Indian labor laws.
Ask about employment regulations, workers' rights, or legal provisions in India. Please note that this is not a substitute for professional legal advice; it is a tool to make legal information more accessible.
Please be patient with the model: it may take a few seconds to generate a response, as we are using the cheapest GPU available for this model. For faster performance, please clone the repository and run it locally.
"""
examples = [
    ["What are the key provisions of the Industrial Disputes Act?"],
    ["What is the minimum wage according to Indian labor law?"],
    ["What are the rules for overtime pay in India?"],
    ["What constitutes unfair labor practice under Indian law?"],
    ["Explain maternity leave entitlements in India"]
]
interface = gr.Interface(
    fn=answer_query,
    inputs=gr.Textbox(
        lines=2,
        placeholder="Ask your Indian labor law question here...",
        label="Legal Query"
    ),
    outputs=gr.Textbox(
        lines=10,
        placeholder="Answer will appear here...",
        label="Generated Response"
    ),
    title=title,
    description=description,
    examples=examples,
    theme="soft",
    cache_examples=False
)
# -------------------------------
# 5. Launch App
# -------------------------------
if __name__ == "__main__":
    interface.launch()