# rachel.ai / app.py
# Committed by: shresthasingh
# Last commit: 8b80009 "Update app.py" (verified)
# (Hugging Face file-page residue: raw · history · blame · 9.54 kB)
import os
import shutil
import requests
import json
import gradio as gr
import PyPDF2
import chromadb
import csv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
# Constants
API_KEY = os.getenv("togetherai")  # Together AI key from the "togetherai" env var; None if unset
BASE_URL = "https://api.together.xyz/v1/chat/completions"
# Passed to RecursiveCharacterTextSplitter as chunk_size. With the splitter's
# default length function (len) this is measured in characters, not words.
CHUNK_SIZE = 6000
TEMP_SUMMARY_FILE = "temp_summaries.txt"
COLLECTIONS_FILE = "collections.csv"  # CSV registry: collection name, uploaded file names
# Function to convert PDF to text
def pdf_to_text(file_path):
    """Extract the text of every page of a PDF file.

    Args:
        file_path: Path to the PDF on disk.

    Returns:
        The per-page texts joined with newlines. The separator keeps the last
        word of one page from being glued to the first word of the next.
        Pages with no extractable text contribute an empty string.
    """
    with open(file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Defensive: some PyPDF2 versions may return None for image-only pages.
        return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
# Function to summarize text using LLM
def summarize_text(text):
    """Ask the LLM to summarize one chunk of document text.

    Args:
        text: Raw text to summarize; it is embedded verbatim at the end of the prompt.

    Returns:
        Whatever ``call_llm`` returns for the summarization prompt.
    """
    instructions = (
        "\n"
        "You are an expert in legal language and document summarization. "
        "Your task is to provide a concise and accurate summary of the given document.\n"
        "Keep the summary concise, ideally in 2000 words, while covering all essential points. "
        "Here is the document to summarize:\n"
    )
    return call_llm(instructions + f"{text}\n")
# Function to handle file upload, summarization, and saving to ChromaDB
def handle_file_upload(files, collection_name):
    """Copy uploaded PDFs, summarize them with the LLM, and index the summaries.

    Creates a new ChromaDB collection named *collection_name* under ./db and
    stores the summary chunks in it as ``summary_<i>``. It then records the
    collection and its file names in the collections CSV.

    Args:
        files: Items from the Gradio ``Files`` component. Each item may be a
            path string or an object exposing the path as ``.name``.
        collection_name: Name of the collection to create. It must not already exist.

    Returns:
        A status message for the UI.

    Raises:
        Any error from PDF parsing, the LLM call, or embedding/indexing. The
        error is re-raised after the partially built collection is deleted,
        so the same name can be reused on retry.
    """
    if not collection_name:
        return "Please provide a collection name."
    if not files:
        return "Please upload at least one PDF file."
    os.makedirs('uploaded_pdfs', exist_ok=True)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=100)
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.create_collection(name=collection_name)
    except ValueError as e:
        # NOTE(review): newer chromadb releases may raise a different exception
        # type for duplicate/invalid names — confirm against the installed version.
        return f"Error creating collection: {str(e)}. Please try a different collection name."
    try:
        file_names, summaries = _summarize_uploaded_files(files, text_splitter)
        _index_summaries(collection, embeddings, text_splitter.split_text(summaries))
    except Exception:
        # Roll back so a failed upload doesn't leave an empty collection
        # occupying the name.
        client.delete_collection(name=collection_name)
        raise
    # Update collections.csv
    update_collections_csv(collection_name, file_names)
    return "Files uploaded, summarized, and processed successfully."


def _summarize_uploaded_files(files, text_splitter):
    """Copy each upload into ``uploaded_pdfs`` and summarize it chunk by chunk.

    Summaries are kept in memory rather than in a shared scratch file, so
    concurrent uploads cannot overwrite each other's intermediate output.

    Returns:
        (file_names, summaries): the base names of the copied files, and a
        single string holding every chunk summary. Each summary is prefixed
        with a "Summary of <file> (Part <n>):" header.
    """
    file_names = []
    parts = []
    for file in files:
        # Accept both plain path strings and file objects that carry the path in .name.
        source_path = getattr(file, "name", file)
        file_name = os.path.basename(source_path)
        file_names.append(file_name)
        file_path = os.path.join('uploaded_pdfs', file_name)
        shutil.copy(source_path, file_path)
        chunks = text_splitter.split_text(pdf_to_text(file_path))
        for i, chunk in enumerate(chunks):
            summary = summarize_text(chunk)
            parts.append(f"Summary of {file_name} (Part {i+1}):\n{summary}\n\n")
    return file_names, "".join(parts)


def _index_summaries(collection, embeddings, summary_chunks):
    """Embed *summary_chunks* in one batch and add them to *collection* as ``summary_<i>``."""
    if not summary_chunks:
        return  # nothing to index
    vectors = embeddings.embed_documents(summary_chunks)
    collection.add(
        embeddings=vectors,
        documents=summary_chunks,
        ids=[f"summary_{i}" for i in range(len(summary_chunks))],
    )
# Function to update collections.csv
def update_collections_csv(collection_name, file_names):
    """Append one row to the collections registry CSV.

    Args:
        collection_name: Name of the collection that was just created.
        file_names: File names stored in that collection. They are written as
            a single ", "-separated field.
    """
    row = (collection_name, ", ".join(file_names))
    with open(COLLECTIONS_FILE, 'a', newline='') as registry:
        csv.writer(registry).writerow(row)
# Function to read collections.csv
def read_collections(path=None):
    """Render the collections registry CSV as text for the UI.

    Args:
        path: CSV file to read. Defaults to ``COLLECTIONS_FILE``.

    Returns:
        One "Collection: <name>\\nFiles: <files>\\n\\n" entry per well-formed
        row. Returns "No collections found." when the file is missing or has
        no usable rows.
    """
    if path is None:
        path = COLLECTIONS_FILE
    if not os.path.exists(path):
        return "No collections found."
    # newline='' lets the csv module handle quoted fields correctly.
    with open(path, 'r', newline='') as csvfile:
        entries = [
            f"Collection: {row[0]}\nFiles: {row[1]}\n\n"
            for row in csv.reader(csvfile)
            # Skip blank or truncated rows instead of failing with IndexError.
            if len(row) >= 2
        ]
    return "".join(entries) or "No collections found."
# Function to search vector database
def search_vector_database(query, collection_name, n_results=2):
    """Return the stored summary chunks most similar to *query*.

    Args:
        query: Free-text search query.
        collection_name: Name of an existing ChromaDB collection under ./db.
        n_results: Maximum number of chunks to return (default 2).

    Returns:
        The matching documents joined by blank lines. If the inputs are
        missing or the collection cannot be opened, returns a message for the
        user instead.
    """
    if not collection_name:
        return "Please provide a collection name."
    # Reject blank queries before paying for the embedding-model load.
    if not (query and query.strip()):
        return "Please provide a search query."
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.get_collection(name=collection_name)
    except ValueError as e:
        # NOTE(review): newer chromadb releases may raise a different exception
        # type for a missing collection — confirm against the installed version.
        return f"Error accessing collection: {str(e)}. Make sure the collection name is correct."
    query_vector = embeddings.embed_query(query)
    results = collection.query(query_embeddings=[query_vector], n_results=n_results, include=["documents"])
    return "\n\n".join(results["documents"][0])
# Function to call LLM
def call_llm(prompt, timeout=(10, 600)):
    """Send *prompt* as a single user message to the Together chat-completions API.

    Args:
        prompt: Full prompt text.
        timeout: ``requests`` timeout as (connect, read) seconds. Without a
            timeout, a stalled connection would block the caller indefinitely.

    Returns:
        The content of the first choice's message.

    Raises:
        RuntimeError: If the API key (env var ``togetherai``) is not set.
        requests.HTTPError: On a non-2xx response.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    if not API_KEY:
        # Fail fast with a clear message instead of sending "Bearer None".
        raise RuntimeError("Together API key missing: set the 'togetherai' environment variable.")
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "top_p": 0.7,
        "top_k": 50,
        "repetition_penalty": 1,
        "stop": ["\"\""],
        "stream": False
    }
    response = requests.post(BASE_URL, headers=headers, json=data, timeout=timeout)
    response.raise_for_status()
    return response.json()['choices'][0]['message']['content']
# Function to answer questions using Rachel.AI
def answer_question(question, collection_name):
    """Answer *question* with the LLM, using summaries retrieved from *collection_name*.

    Args:
        question: The user's legal question.
        collection_name: Collection to retrieve context from.

    Returns:
        The LLM's answer. If the inputs are missing or the collection cannot
        be opened, returns a message for the user instead; in those cases no
        LLM call is made.
    """
    if not collection_name:
        return "Please provide a collection name."
    if not (question and question.strip()):
        return "Please provide a question."
    context = search_vector_database(question, collection_name)
    # Retrieval failures come back as plain strings. Return them to the user
    # instead of sending the error text to the LLM as if it were legal context.
    if context.startswith("Error accessing collection:"):
        return context
    prompt = (
        "\n"
        "You are a paralegal AI assistant. Your role is to assist with legal inquiries by providing clear and concise answers based on the provided question and legal context. Always maintain a highly professional tone, ensuring that your responses are well-reasoned and legally accurate.\n"
        f"Question: {question}\n"
        f"Legal Context: {context}\n"
        "Please provide a detailed response considering the above information.\n"
    )
    return call_llm(prompt)
# Gradio interface
def gradio_interface():
    """Build the rachel.ai Gradio UI and launch it on 0.0.0.0:7860.

    Tabs:
      - "Document Upload and Search": create a collection from PDFs
        (handle_file_upload) and search it (search_vector_database).
      - "Rachel.AI": ask a question answered from a collection (answer_question).
      - "Collections": list the recorded collections (read_collections).
    """
    with gr.Blocks(theme='gl198976/The-Rounded') as interface:
        gr.Markdown("# rachel.ai backend")
        gr.Markdown("""
### Warning
If you encounter an error when uploading files, try changing the collection name and upload again.
Each collection name must be unique.
""")
        # Layout follows the nesting and order of these `with` blocks.
        with gr.Tab("Document Upload and Search"):
            with gr.Row():
                # Left column: upload PDFs into a new collection.
                with gr.Column():
                    collection_name_input = gr.Textbox(label="Collection Name", placeholder="Enter a unique name for this collection")
                    file_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                    upload_btn = gr.Button("Upload, Summarize, and Process Files")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
                # Right column: similarity search over an existing collection.
                with gr.Column():
                    search_query_input = gr.Textbox(label="Search Query")
                    search_collection_name = gr.Textbox(label="Collection Name for Search", placeholder="Enter the collection name to search")
                    search_output = gr.Textbox(label="Search Results", lines=10)
                    search_btn = gr.Button("Search")
                    # Static usage notes for calling the search handler via gradio_client.
                    api_details = gr.Markdown("""
### API Endpoint Details
- **URL:** http://0.0.0.0:7860/search_vector_database
- **Method:** POST
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
"search query", # str in 'Search Query' Textbox component
"name of collection given in ui", # str in 'Collection Name' Textbox component
api_name="/search_vector_database"
)
print(result)
```
""")
        with gr.Tab("Rachel.AI"):
            question_input = gr.Textbox(label="Ask a question")
            rachel_collection_name = gr.Textbox(label="Collection Name", placeholder="Enter the collection name to search")
            answer_output = gr.Textbox(label="Answer", lines=10)
            ask_btn = gr.Button("Ask Rachel.AI")
            # Static usage notes for calling the Q&A handler via gradio_client.
            rachel_api_details = gr.Markdown("""
### API Endpoint Details for Rachel.AI
- **URL:** http://0.0.0.0:7860/answer_question
- **Method:** POST
- **Example Usage:**
```python
from gradio_client import Client
client = Client("http://0.0.0.0:7860/")
result = client.predict(
"question", # str in 'Ask a question' Textbox component
"collection_name", # str in 'Collection Name' Textbox component
api_name="/answer_question"
)
print(result)
```
""")
        with gr.Tab("Collections"):
            collections_output = gr.Textbox(label="Collections and Files", lines=20)
            refresh_btn = gr.Button("Refresh Collections")
        # Event wiring: each button calls its handler with the listed input
        # components' values and writes the return value to the output component.
        upload_btn.click(handle_file_upload, inputs=[file_upload, collection_name_input], outputs=[upload_status])
        search_btn.click(search_vector_database, inputs=[search_query_input, search_collection_name], outputs=[search_output])
        ask_btn.click(answer_question, inputs=[question_input, rachel_collection_name], outputs=[answer_output])
        refresh_btn.click(read_collections, inputs=[], outputs=[collections_output])
    # Listen on all interfaces, port 7860.
    interface.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Build and launch the UI only when run as a script, not on import.
    gradio_interface()