# chatPDF-RAG / app.py
# bipin-saha
# deployment 0.1
# 9376ebb
import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json
# Load environment variables
load_dotenv()


######## Embedding Model ########
@st.cache_resource(show_spinner=False)
def _load_embedding_model(model_dir: str) -> SentenceTransformer:
    """Load the local SentenceTransformer once per server process.

    Streamlit re-executes this whole script on every widget interaction, so a
    bare module-level constructor would reload the model from disk on every
    rerun. ``st.cache_resource`` keeps one instance (keyed on ``model_dir``)
    that is shared by all sessions. ``show_spinner=False`` keeps the call from
    rendering anything before ``st.set_page_config`` runs.
    """
    model = SentenceTransformer(model_dir)
    # Maximum number of tokens the encoder processes per input text.
    model.max_seq_length = 512
    return model


embeddModel = _load_embedding_model(os.path.join(os.getcwd(), "embeddingModel"))
# Chunking parameters for contextChunks() and the default number of retrieved chunks.
chunk_size, chunk_overlap, top_k_default = 2000, 200, 5
######## Groq to LLM Connect ########
# API key comes from the environment (populated from .env by load_dotenv above).
api_key = os.environ.get("GROQ_API_KEY")
groq_client = GroqClient(api_key)
# Display name (shown in the sidebar) -> Groq model id.
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768",
}
# Display name -> max_tokens passed to GroqCompletion. Must have exactly the same
# keys as llm_model: the sidebar offers every llm_model key, and the chat handler
# looks the limit up by that same key. (This previously listed "LLama3-70B" twice
# and omitted "LLama3-70B-Preview", so selecting that model raised KeyError.)
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768,
}
## Time-based cleanup settings
# Logged sessions older than this have their PDF and vector store deleted.
EXPIRATION_TIME = timedelta(seconds=6 * 60 * 60)
# On-disk locations, relative to the working directory.
UPLOAD_DIR, VECTOR_DB_DIR = "Uploaded", "vectorDB"
# JSON object mapping session unique_id -> ISO-8601 timestamp string.
LOG_FILE = "upload_log.json"
## Initialize Streamlit app
# Keep this as the first Streamlit command in the script (older Streamlit
# versions reject set_page_config after any other element has rendered).
st.set_page_config(layout="wide", page_title="ChatPDF")
st.markdown(
    "<h2 style='text-align: center;'>chatPDF</h2>",
    unsafe_allow_html=True,
)
## Function to log upload time
def _read_upload_log():
    """Return the upload log as a dict of unique_id -> ISO timestamp string.

    A missing, empty, or unparseable log yields ``{}`` instead of raising, so a
    single damaged file cannot crash every session that shares it.
    """
    try:
        with open(LOG_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except (json.JSONDecodeError, UnicodeDecodeError):
        return {}
    return data if isinstance(data, dict) else {}


def _write_upload_log(log_data):
    """Atomically replace LOG_FILE with ``log_data`` serialized as JSON.

    ``open(LOG_FILE, "w")`` truncates the file before writing, so a session
    reading concurrently could see an empty file. Writing to a temp file in the
    same directory and then ``os.replace``-ing it means readers see either the
    old or the new content.
    """
    import tempfile

    log_dir = os.path.dirname(os.path.abspath(LOG_FILE))
    fd, tmp_path = tempfile.mkstemp(dir=log_dir, prefix=".upload_log-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(log_data, f)
        os.replace(tmp_path, LOG_FILE)
    except BaseException:
        # Best-effort removal of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def log_upload_time(unique_id):
    """Record the current local time for ``unique_id`` in the upload log.

    Args:
        unique_id: Session identifier; it also names the session's uploaded PDF
            (``<unique_id>_paper.pdf``) and vector-store directory.

    NOTE(review): the read-modify-write is not locked, so two sessions logging at
    the same instant can still lose one entry (that session's files would then
    never be cleaned up). The file itself can no longer be left truncated.
    """
    log_data = _read_upload_log()
    log_data[unique_id] = datetime.now().isoformat()
    _write_upload_log(log_data)


## Cleanup expired files based on log
def cleanup_expired_files():
    """Delete the PDF and vector store of every logged session older than EXPIRATION_TIME.

    Expired entries, and entries whose timestamp cannot be parsed, are removed
    from the log once their vector-store directory is gone. Several sessions may
    run this concurrently, so files that are already missing count as cleaned up.
    """
    log_data = _read_upload_log()
    if not log_data:
        return

    current_time = datetime.now()
    expired_ids = []
    for unique_id, upload_time in log_data.items():
        try:
            expired = current_time - datetime.fromisoformat(upload_time) > EXPIRATION_TIME
        except (TypeError, ValueError):
            # A malformed timestamp would otherwise never expire; purge it now.
            expired = True
        if expired:
            expired_ids.append(unique_id)
    if not expired_ids:
        return

    for unique_id in expired_ids:
        try:
            os.remove(os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf"))
        except FileNotFoundError:
            pass  # never uploaded, or another session already removed it
        vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
        # ignore_errors: the directory may already be gone, or may still be held
        # open (e.g. by a live Chroma client on Windows). A failed delete must not
        # crash the app.
        shutil.rmtree(vector_db_path, ignore_errors=True)
        if os.path.isdir(vector_db_path):
            continue  # keep the log entry so a later run retries the delete
        del log_data[unique_id]
    _write_upload_log(log_data)
## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Upload PDF:")
    research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields with commas")
    # Mode selection and upload stay disabled until a research field is entered.
    inputs_locked = not research_field
    if inputs_locked:
        st.info("Please enter a research field to proceed.")
    option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=inputs_locked)
    # Streamlit warns about empty labels for accessibility. The sidebar title above
    # already names this widget, so the label is kept but visually collapsed.
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type=["pdf"],
        disabled=inputs_locked,
        label_visibility="collapsed",
    )
    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    # index=3 selects "LLama3.1-70B" as the default model.
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=top_k_default)
## Initialize unique ID, db_client, db_path, and timestamp if not already in session state
# cleanup_expired_files() deletes vector stores by age alone, so a browser tab
# left open longer than EXPIRATION_TIME can find its store deleted. In that case
# drop the stale state so a fresh ID/store is created below. (The upload handler
# re-processes a still-selected PDF because document_text is reset to None.)
if 'db_path' in st.session_state and not os.path.isdir(st.session_state['db_path']):
    for stale_key in ('db_client', 'unique_id', 'db_path', 'document_text', 'text_embeddings'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]

if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)
    # Record the session's creation time; cleanup_expired_files() ages entries
    # from this timestamp.
    log_upload_time(unique_id)

# Access session-stored variables
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']
for state_key in ('document_text', 'text_embeddings'):
    if state_key not in st.session_state:
        st.session_state[state_key] = None
## Handle PDF Upload and Processing
# Ingest the PDF once per session: save it, convert it to markdown, chunk it, and
# embed the chunks into this session's Chroma store. While document_text is set,
# later uploads are not processed.
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())
    # Conversion and embedding can take a while; show progress instead of a frozen page.
    with st.spinner("Processing PDF..."):
        document_text = PdfConverter(file_path).convert_to_markdown()
        text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
        text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    # Store results only after both steps succeed. Previously document_text was
    # saved first, so a failed embedding left it set and ingestion was never
    # retried (every later rerun hit st.stop()).
    st.session_state['document_text'] = document_text
    st.session_state['text_embeddings'] = text_contents_embeddings

# Nothing to chat about until a document has been ingested: halt this rerun here.
# NOTE(review): the truthiness test on text_embeddings assumes contextEmbeddingChroma
# returns a list-like or None value; a NumPy array would raise here. Confirm.
if not (st.session_state['document_text'] and st.session_state['text_embeddings']):
    st.stop()
document_text = st.session_state['document_text']
text_contents_embeddings = st.session_state['text_embeddings']
q_input = st.chat_input(key="input", placeholder="Ask your question")
if q_input:
    if option == "Chat":
        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        # Distinct name: rebinding ``max_tokens`` would replace the model -> limit
        # dict with an int. The .get() fallback keeps a model missing from that
        # table from raising KeyError.
        model_max_tokens = max_tokens.get(selected_llm_model, 8192)
        top_p = 1
        stream = True
        stop = None
        # Show the spinner during retrieval and the LLM request too, not only
        # while the answer renders.
        with st.spinner("Processing..."):
            query_embedding = ragQuery(embeddModel, q_input)
            top_k_matches = similarityChroma(query_embedding, db_client, top_k)
            user_content = top_k_matches
            groq_completion = GroqCompletion(
                groq_client, LLMmodel, domain, prompt_template, user_content,
                temperature, model_max_tokens, top_p, stream, stop,
            )
            result = groq_completion.create_completion()
            chat_response(q_input, result)
    else:
        # The other modes are offered in the sidebar but have no handler here;
        # tell the user instead of silently ignoring the question.
        st.info(f"'{option}' mode is not available yet. Switch to 'Chat' to ask questions.")

## Call the cleanup function periodically
# This is not a background timer: it runs on each script rerun that gets past
# the st.stop() guard above, i.e. only once a document is loaded.
cleanup_expired_files()