|
import streamlit as st |
|
import os |
|
import uuid |
|
import shutil |
|
from datetime import datetime, timedelta |
|
from dotenv import load_dotenv |
|
from chatMode import chat_response |
|
from modules.pdfExtractor import PdfConverter |
|
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma |
|
from sentence_transformers import SentenceTransformer |
|
from modules.llm import GroqClient, GroqCompletion |
|
import chromadb |
|
import json |
|
|
|
|
|
# Load environment variables (GROQ_API_KEY) from a local .env file.
load_dotenv()

# Sentence-transformer embedding model shipped with the app in ./embeddingModel.
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512  # truncate embedding inputs to 512 tokens

# RAG chunking parameters: characters per chunk, overlap, default top-k matches.
chunk_size, chunk_overlap, top_k_default = 1000, 300, 5

# Groq client used for all LLM completions.
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)
|
# Display name (sidebar selectbox) -> Groq model identifier.
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768"
}

# Maximum completion tokens per model. Keys must mirror llm_model exactly:
# the original dict listed "LLama3-70B" twice and omitted
# "LLama3-70B-Preview", which raised KeyError when that model was selected.
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768
}
|
|
|
|
|
# Uploaded PDFs and their vector stores are deleted after this interval.
EXPIRATION_TIME = timedelta(hours=6)
UPLOAD_DIR = "Uploaded"        # saved uploaded PDFs, named "<session id>_paper.pdf"
VECTOR_DB_DIR = "vectorDB"     # one persistent Chroma store per session id
LOG_FILE = "upload_log.json"   # JSON map: session id -> upload time (ISO-8601)

st.set_page_config(page_title="Ospyn AI", layout="wide")
st.markdown("<h2 style='text-align: center;'>Ospyn AI</h2>", unsafe_allow_html=True)
|
|
|
|
|
def log_upload_time(unique_id):
    """Record the creation timestamp for *unique_id* in LOG_FILE.

    The log is a JSON object mapping session ids to ISO-8601 timestamps;
    cleanup_expired_files() later reads it to expire old uploads.

    Fix over the original: an empty or corrupt LOG_FILE no longer crashes
    the app with json.JSONDecodeError — the log is rebuilt instead.
    """
    upload_time = datetime.now().isoformat()

    log_data = {}
    if os.path.exists(LOG_FILE):
        try:
            with open(LOG_FILE, "r") as f:
                log_data = json.load(f)
        except (json.JSONDecodeError, OSError):
            # Unreadable log: start fresh rather than abort the session.
            log_data = {}

    log_data[unique_id] = upload_time

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
|
|
|
|
|
def cleanup_expired_files():
    """Remove on-disk artifacts of sessions older than EXPIRATION_TIME.

    For every expired entry in LOG_FILE, deletes the session's uploaded PDF
    (UPLOAD_DIR/<id>_paper.pdf) and its Chroma directory (VECTOR_DB_DIR/<id>),
    then rewrites the log without the expired entries.

    Fixes over the original: a corrupt/empty log no longer raises
    json.JSONDecodeError, and artifact deletion is explicitly scoped to
    entries that have actually expired.
    """
    if not os.path.exists(LOG_FILE):
        return

    try:
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
    except (json.JSONDecodeError, OSError):
        # Unreadable log: nothing can be safely cleaned up this pass.
        return

    current_time = datetime.now()
    keys_to_delete = []

    for unique_id, upload_time in log_data.items():
        upload_time_dt = datetime.fromisoformat(upload_time)
        if current_time - upload_time_dt > EXPIRATION_TIME:
            keys_to_delete.append(unique_id)

            # Delete this session's artifacts; guards make re-runs harmless
            # when the files were already removed.
            pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
            vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)

    for key in keys_to_delete:
        del log_data[key]

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
|
|
|
|
|
with st.sidebar:
    st.title("Upload PDF:")

    # The research field gates the rest of the controls: mode selector and
    # uploader stay disabled until it is filled in.
    research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields with commas")
    option = ''

    if not research_field:
        st.info("Please enter a research field to proceed.")
        option = st.selectbox('Select Mode', ('Chat', 'Code'), disabled=True)
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
    else:
        option = st.selectbox('Select Mode', ('Chat', 'Code'))
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)

    # Generation controls: sampling temperature, LLM choice (index=3 defaults
    # to the fourth llm_model key), and how many retrieved chunks feed the prompt.
    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)
|
|
|
|
|
# One-time per browser session: allocate a unique id, create its on-disk
# Chroma store, and log the creation time so the artifacts can expire later.
if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)

    # Record creation time for cleanup_expired_files(); must run inside this
    # branch — on reruns unique_id is only defined from session state below.
    log_upload_time(unique_id)

# Rebind from session state so every rerun sees the same store and id.
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']
|
|
|
# Cached extraction results: ensure the PDF is parsed and embedded only once
# per session instead of on every Streamlit rerun.
if 'document_text' not in st.session_state:
    st.session_state['document_text'] = None

if 'text_embeddings' not in st.session_state:
    st.session_state['text_embeddings'] = None
|
|
|
|
|
# First upload of this session: persist the PDF, convert it to markdown,
# split it into overlapping chunks, and store chunk embeddings in Chroma.
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())

    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text

    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings

if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    # No document processed yet — halt the script here until a PDF is uploaded.
    st.stop()
|
|
|
q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    if option == "Chat":
        # Embed the question and fetch the top-k most similar stored chunks.
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        # Assemble the completion request from sidebar settings and retrieval.
        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        # Fix: the original did `max_tokens = max_tokens[...]`, rebinding the
        # module-level dict to an int; a distinct name avoids the shadowing.
        # A leftover debug print of the value is also removed.
        max_token_limit = max_tokens[selected_llm_model]
        top_p = 1
        stream = True
        stop = None

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, max_token_limit, top_p, stream, stop)
        result = groq_completion.create_completion()

        with st.spinner("Processing..."):
            chat_response(q_input, result)

# Best-effort housekeeping on every rerun: purge expired uploads and stores.
cleanup_expired_files()
|
|
|
|
|
|