import gradio as gr
import os
import torch
import faiss
import pickle
import json
import logging
import re
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# -------------------------------
# 1. Configuration
# -------------------------------
# Paths
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
VECTOR_DB_DIR = os.path.join(CURRENT_DIR, "model_cache", "vector_db")
# Model Configuration
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2" # Default embedding model
GENERATION_MODEL = "google/flan-t5-base" # Default generation model
MAX_OUTPUT_LENGTH = 512 # Output token limit
MAX_CONTEXT_LENGTH = 800 # Maximum number of tokens per context chunk
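# VECTOR_DB_DIR is expected to contain index.faiss, chunks.pkl, and optionally
# model_info.json, produced by a separate indexing step (not part of this file).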
# -------------------------------
# 2. Load Components
# -------------------------------
def load_vector_db():
"""Load the vector store for retrieval"""
try:
# Load model info
model_info_path = os.path.join(VECTOR_DB_DIR, "model_info.json")
if os.path.exists(model_info_path):
with open(model_info_path, "r") as f:
model_info = json.load(f)
model_name = model_info.get("name", EMBEDDING_MODEL)
else:
model_name = EMBEDDING_MODEL
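        # Using the model recorded at index-build time keeps query embeddings in the
        # same vector space as the indexed chunks.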
# Load embedding model
logger.info(f"Loading embedding model: {model_name}")
retriever_model = SentenceTransformer(model_name)
# Load FAISS index
logger.info(f"Loading FAISS index from {VECTOR_DB_DIR}")
index = faiss.read_index(os.path.join(VECTOR_DB_DIR, "index.faiss"))
# Load chunks
with open(os.path.join(VECTOR_DB_DIR, "chunks.pkl"), "rb") as f:
chunks = pickle.load(f)
logger.info(f"Loaded vector store with {len(chunks)} chunks")
return retriever_model, index, chunks
except Exception as e:
logger.error(f"Error loading vector store: {str(e)}")
return None, None, None
def load_generator():
"""Load the generation model"""
try:
logger.info(f"Loading generation model: {GENERATION_MODEL}")
# Load tokenizer and model separately
tokenizer = AutoTokenizer.from_pretrained(GENERATION_MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(GENERATION_MODEL)
# Create the pipeline with the tokenizer and model
generator = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
max_length=MAX_OUTPUT_LENGTH,
        device=0 if torch.cuda.is_available() else -1  # pipeline expects a device index: 0 = first GPU, -1 = CPU
)
return generator, tokenizer
except Exception as e:
logger.error(f"Error loading generation model: {str(e)}")
return None, None
# Load components
retriever_model, index, chunks = load_vector_db()
generator, tokenizer = load_generator()
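# Components are loaded once at import time so every Gradio request reuses the same
# in-memory models; if a loader fails, the corresponding globals stay None and the
# query handlers report the problem instead of raising.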
def chunk_context(context, tokenizer, max_length=MAX_CONTEXT_LENGTH):
    """Split context into chunks that fit the model's token limit.

    Returns the original string unchanged if it already fits, otherwise a list of chunk strings.
    """
# Quick check if tokenizer is missing
if not tokenizer:
logger.warning("Tokenizer not available, returning context as is")
return context
# Tokenize the context to get token counts
encoded = tokenizer.encode(context)
# If context fits within the limit, return it as is
if len(encoded) <= max_length:
return context
    # Otherwise, split into sentences (on ., !, or ? followed by whitespace) and build chunks
sentences = re.split(r'(?<=[.!?])\s+', context)
chunks = []
current_chunk = []
current_length = 0
for sentence in sentences:
sentence_tokens = len(tokenizer.encode(sentence))
if current_length + sentence_tokens > max_length:
# This sentence would make the chunk too long, start a new chunk
if current_chunk:
chunks.append(' '.join(current_chunk))
current_chunk = [sentence]
current_length = sentence_tokens
else:
# Add this sentence to the current chunk
current_chunk.append(sentence)
current_length += sentence_tokens
# Add the last chunk if not empty
if current_chunk:
chunks.append(' '.join(current_chunk))
return chunks
# -------------------------------
# 3. RAG Pipeline
# -------------------------------
def retrieve(query, top_k=5):
"""Retrieve relevant passages for the query"""
if not retriever_model or not index or not chunks:
logger.error("Vector store components not loaded")
return []
try:
# Generate query embedding
query_embedding = retriever_model.encode([query], convert_to_numpy=True)
query_embedding = query_embedding.astype('float32')
faiss.normalize_L2(query_embedding)
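        # Normalizing the query embedding assumes the FAISS index was built over
        # L2-normalized vectors (cosine / inner-product search); if the index uses raw
        # L2 distance instead, this step should match how the index was built.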
# Search index
scores, indices = index.search(query_embedding, top_k)
        # Filter out invalid indices (FAISS pads missing results with -1)
        valid_indices = [idx for idx in indices[0] if 0 <= idx < len(chunks)]
retrieved_chunks = [chunks[idx] for idx in valid_indices]
return retrieved_chunks
except Exception as e:
logger.error(f"Error during retrieval: {str(e)}")
return []
def answer_query(query):
"""Process a query through RAG pipeline"""
if not generator:
return "Generation model could not be loaded. Please check your installation."
try:
# Retrieve relevant context
retrieved_chunks = retrieve(query)
if not retrieved_chunks:
            return "No relevant information was found in the knowledge base, or the retrieval system is unavailable."
# Build context from chunks with source information
context_parts = []
sources = []
        for i, chunk in enumerate(retrieved_chunks):
            # Get the content and source information; chunks may be dicts or plain strings
            is_dict = isinstance(chunk, dict)
            content = chunk.get('content', str(chunk)) if is_dict else str(chunk)
            # Extract source information - look for multiple possible source fields
            source = None
            if is_dict:
                for source_field in ['source', 'title', 'metadata', 'filename', 'document']:
                    if source_field in chunk:
                        source = chunk[source_field]
                        break
# If no source field found, create a reference ID
if not source:
source = f"Reference {i+1}"
# Add source identifier within the content for better context
labeled_content = f"[From {source}]: {content}"
context_parts.append(labeled_content)
sources.append(source)
# Join all content into one context
full_context = "\n\n".join(context_parts)
# Split context if it's too long for the model
context_chunks = chunk_context(full_context, tokenizer)
# If context is a string, make it a list for consistent handling
if isinstance(context_chunks, str):
context_chunks = [context_chunks]
# Generate answers for each context chunk
answers = []
for i, ctx in enumerate(context_chunks):
# Format prompt with explicit instructions to cite sources
prompt = f"""Answer the following legal question based on the provided context.
You MUST include specific citations to the sources in your answer,
clearly indicating which information comes from which source.
Context: {ctx}
Question: {query}
Provide a detailed answer with explicit source citations:"""
# Generate answer
result = generator(
prompt,
max_length=MAX_OUTPUT_LENGTH,
temperature=0.7,
do_sample=True,
num_return_sequences=1
)
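            # Sampling with temperature=0.7 produces slightly varied answers on each run;
            # set do_sample=False for deterministic output.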
# Extract response
if isinstance(result, list) and len(result) > 0:
if 'generated_text' in result[0]:
answers.append(result[0]['generated_text'].strip())
else:
answers.append(str(result[0]).strip())
else:
answers.append(str(result).strip())
# Combine answers
combined_answer = "\n\n".join(answers)
        # Always add a sources section at the end (deduplicated, order preserved),
        # regardless of whether sources were mentioned in the generated text
        unique_sources = list(dict.fromkeys(str(source) for source in sources))
        source_text = "\n\nSources used to generate this answer:\n" + "\n".join(f"- {source}" for source in unique_sources)
combined_answer += source_text
return combined_answer
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
return f"Error generating response: {str(e)}"
# -------------------------------
# 4. Gradio Interface
# -------------------------------
title = "⚖️ Indian Labor Law Assistant"
description = """
This AI assistant uses a local Retrieval-Augmented Generation (RAG) system to answer questions about Indian labor laws.
Ask about employment regulations, workers' rights, or legal provisions in India. Please note that this is not a substitute for professional legal advice; it is a tool to make legal information more accessible.
Please be patient: responses may take a few seconds to generate, as we are using the cheapest GPU available for this model. For faster responses, please clone the repository and run it locally.
"""
examples = [
["What are the key provisions of the Industrial Disputes Act?"],
["What is the minimum wage according to Indian labor law?"],
["What are the rules for overtime pay in India?"],
["What constitutes unfair labor practice under Indian law?"],
["Explain maternity leave entitlements in India"]
]
interface = gr.Interface(
fn=answer_query,
inputs=gr.Textbox(
lines=2,
placeholder="Ask your Indian labor law question here...",
label="Legal Query"
),
outputs=gr.Textbox(
lines=10,
placeholder="Answer will appear here...",
label="Generated Response"
),
title=title,
description=description,
examples=examples,
theme="soft",
cache_examples=False
)
# -------------------------------
# 5. Launch App
# -------------------------------
if __name__ == "__main__":
interface.launch()