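"""Gradio app: a retrieval-augmented chatbot for the Bharatiya Nyaya Sanhita, 2023.

Document chunks are embedded into an in-memory ChromaDB collection; the top
matches for each query are passed as context to a TinyLlama text-generation
pipeline via LangChain.
"""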
import gradio as gr
from typing import List, Dict
from langchain_huggingface import HuggingFacePipeline
from langchain_core.prompts import ChatPromptTemplate
from transformers import pipeline
import chromadb
from chromadb.utils import embedding_functions
class LegalChatbot:
    def __init__(self):
        print("Initializing Legal Chatbot...")

        # In-memory ChromaDB client (the collection is rebuilt on every restart)
        self.chroma_client = chromadb.Client()

        # Sentence-transformer embeddings, run on CPU
        self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name="all-MiniLM-L6-v2",
            device="cpu"
        )

        # Collection for the document chunks; get_or_create avoids a
        # "collection already exists" error if __init__ runs more than once
        self.collection = self.chroma_client.get_or_create_collection(
            name="text_collection",
            embedding_function=self.embedding_function,
            metadata={"hnsw:space": "cosine"}
        )
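        # With hnsw:space="cosine", Chroma returns cosine distances for
        # queries, so 1 - distance (used in _search_database) is a similarity.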
        # Initialize the model - a small model suitable for CPU inference.
        # return_full_text=False keeps the prompt from being echoed back
        # in the generated output.
        pipe = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            return_full_text=False,
            device="cpu"
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
        # Create prompt template
        self.template = """
IMPORTANT: You are a helpful assistant that provides information about the Bharatiya Nyaya Sanhita, 2023 based on the retrieved context.
STRICT RULES:
1. Base your response ONLY on the provided context
2. If you cannot find relevant information, respond with: "I apologize, but I cannot find information about that in the database."
3. Do not make assumptions or use external knowledge
4. Be concise and accurate in your responses
5. If quoting from the context, clearly indicate it
Context: {context}
Chat History: {chat_history}
Question: {question}
Answer:"""
        self.prompt = ChatPromptTemplate.from_template(self.template)

        self.chat_history = ""
        self.initialized = False
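    # The collection is populated lazily: _initialize_database runs on the
    # first query (see the interface description below) rather than at startup.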
    def _initialize_database(self) -> bool:
        """Initialize the database with document content"""
        try:
            if self.initialized:
                return True

            print("Loading documents into database...")

            # Read the main text file (the Act)
            with open('a2023-45.txt', 'r', encoding='utf-8') as f:
                text_content = f.read()

            # Read the index file (one section label per line)
            with open('index.txt', 'r', encoding='utf-8') as f:
                index_lines = f.readlines()

            # Split the text into fixed-size character chunks (no overlap)
            chunk_size = 512
            chunks = [
                text_content[i:i + chunk_size]
                for i in range(0, len(text_content), chunk_size)
            ]

            # Add documents to the collection in batches
            batch_size = 50
            for i in range(0, len(chunks), batch_size):
                batch = chunks[i:i + batch_size]
                batch_ids = [f"doc_{j}" for j in range(i, i + len(batch))]
                batch_metadata = [{
                    "index": index_lines[j].strip() if j < len(index_lines) else f"Chunk {j+1}",
                    "chunk_number": j
                } for j in range(i, i + len(batch))]
                self.collection.add(
                    documents=batch,
                    ids=batch_ids,
                    metadatas=batch_metadata
                )

            self.initialized = True
            return True
        except Exception as e:
            print(f"Error initializing database: {str(e)}")
            return False
    def _search_database(self, query: str) -> List[Dict]:
        """Search the database for relevant documents"""
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=3,
                include=["documents", "metadatas", "distances"]
            )
            # Chroma returns one result list per query text; index [0] picks
            # the lists for our single query
            return [
                {
                    "content": doc,
                    "metadata": meta,
                    "score": 1 - dist
                }
                for doc, meta, dist in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0]
                )
            ]
        except Exception as e:
            print(f"Error searching database: {str(e)}")
            return []
    def chat(self, query: str, history) -> str:
        """Process a query and return a response"""
        try:
            # Initialize database if needed
            if not self.initialized and not self._initialize_database():
                return "Error: Unable to initialize the database. Please try again."

            # Search for relevant content
            search_results = self._search_database(query)
            if not search_results:
                return "I apologize, but I cannot find information about that in the database."

            # Combine the retrieved chunks into a single context block
            context = "\n\n".join(
                f"[Section {r['metadata']['index']}]\n{r['content']}"
                for r in search_results
            )

            # Generate response using the prompt | LLM chain
            chain = self.prompt | self.llm
            result = chain.invoke({
                "context": context,
                "chat_history": self.chat_history,
                "question": query
            })

            # Update chat history
            self.chat_history += f"\nUser: {query}\nAI: {result}\n"
            return result
        except Exception as e:
            return f"Error processing query: {str(e)}"
# Initialize the chatbot
chatbot = LegalChatbot()

# Create the Gradio interface
iface = gr.ChatInterface(
    chatbot.chat,
    title="Bharatiya Nyaya Sanhita, 2023 - Legal Assistant",
    description="Ask questions about the Bharatiya Nyaya Sanhita, 2023. The system will initialize on your first query.",
    examples=[
        "What is criminal conspiracy?",
        "What are the punishments for corruption?",
        "Explain the concept of culpable homicide",
        "What constitutes theft under the act?"
    ],
    theme=gr.themes.Soft()
)
# Launch the interface
if __name__ == "__main__":
    iface.launch(
        share=False,
        show_error=True
    )
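# Example programmatic usage (assumes 'a2023-45.txt' and 'index.txt' are
# present in the working directory):
#   bot = LegalChatbot()
#   print(bot.chat("What is criminal conspiracy?", []))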