import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json

# Load environment variables
load_dotenv()
######## Embedding Model ########
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512
chunk_size, chunk_overlap, top_k_default = 2000, 200, 5
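# NOTE: with max_seq_length = 512 the encoder truncates anything longer than
# 512 tokens, so each chunk should stay within that budget once tokenized.
# A sketch of the presumed chunk boundaries (assuming contextChunks splits on
# characters with the given overlap; its implementation is not shown here):
#   chunk 1 -> document_text[0:2000]
#   chunk 2 -> document_text[1800:3800]  # 200 characters shared with chunk 1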
######## Groq to LLM Connect ########
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768"
}
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768
}
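# Both dicts above are keyed by the same display names: llm_model maps a name
# to its Groq model ID, max_tokens maps it to the completion-token cap passed
# to the API. Every key in llm_model needs a matching max_tokens entry, since
# a chat query looks up both and would otherwise raise a KeyError.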
## Time-based cleanup settings
EXPIRATION_TIME = timedelta(hours=6)
UPLOAD_DIR = "Uploaded"
VECTOR_DB_DIR = "vectorDB"
LOG_FILE = "upload_log.json"
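# Each session leaves two artifacts on disk, both keyed by the session UUID
# and swept once they are older than EXPIRATION_TIME:
#   Uploaded/<uuid>_paper.pdf  -- the raw PDF upload
#   vectorDB/<uuid>/           -- that session's persistent Chroma store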
## Initialize Streamlit app
st.set_page_config(page_title="ChatPDF", layout="wide")
st.markdown("<h2 style='text-align: center;'>chatPDF</h2>", unsafe_allow_html=True)
## Function to log upload time
def log_upload_time(unique_id):
    upload_time = datetime.now().isoformat()
    log_entry = {unique_id: upload_time}
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
        log_data.update(log_entry)
    else:
        log_data = log_entry
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
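# upload_log.json ends up as a flat map from session UUID to an ISO-8601
# timestamp; illustrative contents (values are made up):
#   {"3f2b9c1e-...": "2025-01-01T12:00:00.000000"}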
## Cleanup expired files based on log
def cleanup_expired_files():
    current_time = datetime.now()
    # Nothing to clean up if no uploads have been logged yet
    if not os.path.exists(LOG_FILE):
        return
    # Load upload log
    with open(LOG_FILE, "r") as f:
        log_data = json.load(f)
    keys_to_delete = []  # List to keep track of keys to delete
    # Check each entry in the log
    for unique_id, upload_time in log_data.items():
        upload_time_dt = datetime.fromisoformat(upload_time)
        if current_time - upload_time_dt > EXPIRATION_TIME:
            # Add key to the list for deletion
            keys_to_delete.append(unique_id)
            # Remove the expired PDF and its vector store
            pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
            vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)
    # Now delete the keys from log_data after iteration
    for key in keys_to_delete:
        del log_data[key]
    # Save updated log
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
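# This function is called at the bottom of the script, so expired sessions are
# swept on every Streamlit rerun rather than by a background timer. The
# read-modify-write of upload_log.json is unsynchronized, so concurrent
# sessions could in principle race on it.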
## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Upload PDF:")
    research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields, separated by commas")
    option = ''
    if not research_field:
        st.info("Please enter a research field to proceed.")
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=True)
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
    else:
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'))
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)
    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)
## Initialize unique ID, db_client, db_path, and timestamp if not already in session state
if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)
    # Log the session creation time so cleanup can expire its files later
    log_upload_time(unique_id)

# Access session-stored variables
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']
if 'document_text' not in st.session_state:
    st.session_state['document_text'] = None
if 'text_embeddings' not in st.session_state:
    st.session_state['text_embeddings'] = None
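# Streamlit re-executes this entire script on every interaction, so anything
# that must survive across reruns (the Chroma client, the extracted text, the
# chunk embeddings) lives in st.session_state; the None defaults above let
# later code tell "no PDF processed yet" apart from a cached result.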
## Handle PDF Upload and Processing
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())
    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text
    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings
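# Ingestion pipeline for a fresh upload: save the PDF under a session-unique
# name, convert it to markdown, split it into overlapping chunks, embed each
# chunk, and persist the vectors in the per-session Chroma store. The
# session_state guard above makes this a one-time cost per uploaded PDF.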
if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    # Halt this rerun until a PDF has been uploaded and embedded
    st.stop()
q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    if option == "Chat":
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        completion_max_tokens = max_tokens[selected_llm_model]  # per-model completion-token cap
        top_p = 1
        stream = True
        stop = None

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, completion_max_tokens, top_p, stream, stop)
        result = groq_completion.create_completion()
        with st.spinner("Processing..."):
            chat_response(q_input, result)
## Sweep expired uploads at the end of every script run
cleanup_expired_files()
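# To run locally (assuming this file is saved as app.py, GROQ_API_KEY is set
# in .env, and the sentence-transformers weights live in ./embeddingModel):
#   streamlit run app.py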