import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json
# Load environment variables
load_dotenv()

######## Embedding Model ########
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512
chunk_size, chunk_overlap, top_k_default = 2000, 200, 5
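# NOTE: chunk_size/chunk_overlap are measured in characters, while
# max_seq_length is measured in tokens; at roughly 4 characters per token a
# 2000-character chunk sits near the 512-token limit, so the tail of a dense
# chunk may be silently truncated by the encoder.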
######## Groq to LLM Connect ########
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)

llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768"
}
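# Per-model completion-token caps; the keys must mirror llm_model exactly,
# since the same UI label is used to look up both dicts below.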
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768
}
## Time-based cleanup settings
EXPIRATION_TIME = timedelta(hours=6)
UPLOAD_DIR = "Uploaded"
VECTOR_DB_DIR = "vectorDB"
LOG_FILE = "upload_log.json"

## Initialize Streamlit app
st.set_page_config(page_title="ChatPDF", layout="wide")
st.markdown("<h2 style='text-align: center;'>chatPDF</h2>", unsafe_allow_html=True)
## Function to log upload time
def log_upload_time(unique_id):
    upload_time = datetime.now().isoformat()
    log_entry = {unique_id: upload_time}
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
        log_data.update(log_entry)
    else:
        log_data = log_entry
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
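# The log is a flat JSON object mapping each session's unique_id to its upload
# timestamp, e.g. (illustrative values only):
#   {"9f1c2d3e-...": "2025-01-01T12:00:00.000000"}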
## Cleanup expired files based on log
def cleanup_expired_files():
    current_time = datetime.now()
    # Load upload log; nothing to clean up if it does not exist yet
    if not os.path.exists(LOG_FILE):
        return
    with open(LOG_FILE, "r") as f:
        log_data = json.load(f)
    keys_to_delete = []  # Keep track of expired keys to delete after iteration
    # Check each entry in the log
    for unique_id, upload_time in log_data.items():
        upload_time_dt = datetime.fromisoformat(upload_time)
        if current_time - upload_time_dt > EXPIRATION_TIME:
            keys_to_delete.append(unique_id)
            # Remove the expired upload and its vector store
            pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
            vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)
    # Now delete the keys from log_data after iteration
    for key in keys_to_delete:
        del log_data[key]
    # Save updated log
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
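# cleanup_expired_files() is invoked at the bottom of this script, so expired
# uploads and their Chroma stores are purged on every rerun, from any session.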
## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Upload PDF:")
    research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields separated by commas")
    option = ''
    if not research_field:
        st.info("Please enter a research field to proceed.")
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=True)
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
    else:
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'))
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)
    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)
## Initialize unique ID, db_client, db_path, and timestamp if not already in session state
if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)
    # Log the upload time
    log_upload_time(unique_id)

# Access session-stored variables
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']
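# st.session_state is scoped to one browser session, so each visitor gets a
# fresh UUID and a private Chroma store under vectorDB/<uuid> that survives
# Streamlit's per-interaction script reruns.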
if 'document_text' not in st.session_state:
    st.session_state['document_text'] = None
if 'text_embeddings' not in st.session_state:
    st.session_state['text_embeddings'] = None
## Handle PDF Upload and Processing
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())
    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text
    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings
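# A minimal sketch of the character-window chunking that contextChunks
# (modules/rag.py, not shown here) presumably performs. Assumed behavior,
# kept as an uncalled reference only -- not the module's actual code.
def _chunks_sketch(text, chunk_size, chunk_overlap):
    # Slide a fixed-size window, stepping by chunk_size - chunk_overlap.
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]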
if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    st.stop()
q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    if option == "Chat":
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)
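        # similarityChroma presumably wraps a chromadb nearest-neighbour query
        # along these lines (assumed collection handling, illustration only):
        #   collection.query(query_embeddings=[query_embedding], n_results=top_k)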
        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        llm_max_tokens = max_tokens[selected_llm_model]  # keep the max_tokens dict intact instead of shadowing it
        top_p = 1
        stream = True
        stop = None
        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, llm_max_tokens, top_p, stream, stop)
        result = groq_completion.create_completion()
        with st.spinner("Processing..."):
            chat_response(q_input, result)
## Run cleanup at the end of every script rerun
cleanup_expired_files()