#https://medium.com/@csakash03/hybrid-search-is-a-method-to-optimize-rag-implementation-98d9d0911341
#https://medium.com/etoai/hybrid-search-combining-bm25-and-semantic-search-for-better-results-with-lan-1358038fe7e6
import gradio as gr
import zipfile
import os
import re
from pathlib import Path
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
import hashlib
import nltk
from rank_bm25 import BM25Okapi
import numpy as np
from langchain.schema import Document
from dotenv import load_dotenv
# Load environment variables (expects HF_TOKEN in .env or the environment)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# Download the required NLTK tokenizer data
nltk.download('punkt')

# Define embeddings using Hugging Face models (default sentence-transformers model)
embeddings = HuggingFaceEmbeddings()
# Initialize the Chroma vector store (Chroma manages its own persistent client,
# so no separate chromadb client is needed)
persist_directory = "./chroma_langchain_db"
vector_store = Chroma(
    collection_name="whatsapp_collection",
    embedding_function=embeddings,
    persist_directory=persist_directory,
)
# Global state (note: processed_files lives only in memory, so uploads are
# re-indexed after a restart even though Chroma itself persists on disk)
bm25 = None
all_texts = []
processed_files = {}  # Maps hashes of processed files to their names
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.3",
    huggingfacehub_api_token=HF_TOKEN.strip(),
    temperature=0.1,
    max_new_tokens=200,
)
# Function to clean the text before chunking
def clean_text(text):
    # Remove emojis and any other non-ASCII characters
    text = re.sub(r'[^\x00-\x7F]+', '', text)
    # Collapse runs of whitespace into single spaces
    text = re.sub(r'\s+', ' ', text).strip()
    return text
# Function to compute a file hash for identifying duplicate uploads
def compute_file_hash(file_path):
    hasher = hashlib.md5()
    with open(file_path, 'rb') as f:
        hasher.update(f.read())
    return hasher.hexdigest()
# Function to process a WhatsApp chat zip and index its contents in Chroma
def process_and_upload_zip(zip_file):
    global bm25, all_texts, processed_files
    temp_dir = Path("temp")
    temp_dir.mkdir(exist_ok=True)

    # Compute a hash to check whether this file has already been processed
    zip_file_hash = compute_file_hash(zip_file.name)

    # If the file has been processed before, skip re-uploading
    if zip_file_hash in processed_files:
        return f"File '{zip_file.name}' already processed. Using existing Chroma storage."

    # Remove leftover chat files from earlier uploads so they are not re-indexed
    for old_file in temp_dir.glob("*.txt"):
        old_file.unlink()

    # Extract the zip file
    with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
        zip_ref.extractall(temp_dir)

    # Load and clean the chat text, splitting each file into
    # 2500-character chunks with 200-character overlap
    chunk_splitter = RecursiveCharacterTextSplitter(chunk_size=2500, chunk_overlap=200)
    chat_files = list(temp_dir.glob("*.txt"))
    metadata = []
    all_texts = []
    for chat_file in chat_files:
        with open(chat_file, 'r', encoding='utf-8') as file:
            page_content = file.read()
        clean_content = clean_text(page_content)
        chunks = chunk_splitter.split_text(clean_content)
        for chunk_index, chunk in enumerate(chunks):
            metadata.append({
                "context": chunk,
                "document_id": chat_file.stem,
                "chunk_index": chunk_index,
            })
            all_texts.append(chunk)

    # Initialize BM25 for sparse retrieval over whitespace-tokenized chunks
    bm25 = BM25Okapi([doc.split() for doc in all_texts])

    # Store the chunks in Chroma; dense embeddings are computed by the store's
    # embedding_function, and deterministic ids keep re-uploads idempotent
    ids = [f"{m['document_id']}_chunk_{m['chunk_index']}" for m in metadata]
    documents = [Document(page_content=m["context"], metadata=m) for m in metadata]
    vector_store.add_documents(documents=documents, ids=ids)

    # Record the hash of the processed file to avoid reprocessing
    processed_files[zip_file_hash] = zip_file.name
    return "Data uploaded and stored in Chroma successfully."
# Perform hybrid retrieval: BM25 (sparse) plus Chroma similarity search (dense)
def hybrid_search(query):
    global bm25, all_texts

    # BM25 sparse retrieval
    query_terms = query.split()
    bm25_scores = bm25.get_scores(query_terms)
    bm25_top_n_indices = np.argsort(bm25_scores)[::-1][:5]  # Top 5 results
    sparse_results = [all_texts[i] for i in bm25_top_n_indices]

    # Dense retrieval using Chroma
    dense_results = vector_store.similarity_search(query, k=5)

    # Combine the results by simple concatenation; the two lists can overlap,
    # so duplicates are possible (see the rank-fusion sketch below for one way
    # to merge them more carefully)
    combined_results = sparse_results + [result.page_content for result in dense_results]
    response = "\n\n".join(combined_results)
    return f"Hybrid Search Results:\n\n{response}"
# Gradio handler: index the upload, run hybrid search, then answer with the LLM
def query_interface(zip_file, query):
    upload_status = process_and_upload_zip(zip_file)
    search_results = hybrid_search(query)
    prompt = (f"Here is a summary of WhatsApp chat contents based on the search for the query: '{query}'. "
              f"The chat content includes important messages:\n\n"
              f"{search_results}\n\n"
              f"Now, based on this chat content, answer the following question as an expert. "
              f"Please provide a complete and precise answer in **100 words**.\n\n"
              f"Question: {query}")
    # Generate the answer using the LLM
    response = llm.invoke(prompt)
    return f"{upload_status}\n\n{search_results}", response
interface = gr.Interface(
    fn=query_interface,
    inputs=[gr.File(label="Upload WhatsApp Chat Zip File"), gr.Textbox(label="Enter your query")],
    outputs=[
        gr.Textbox(label="Chat Content"),      # Displays the retrieved chat content
        gr.Textbox(label="Generated Answer"),  # Displays the LLM's answer
    ],
    title="WhatsApp Chat Upload and Hybrid Search",
    description="Upload a zip file containing WhatsApp chat data. This app processes the data and performs hybrid search with BM25 + Chroma.",
)
if __name__ == "__main__":
    interface.launch()