import json
import os
import time
import random
import string
import shutil
import base64
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
import streamlit as st
from streamlit_option_menu import option_menu
from streamlit_lottie import st_lottie, st_lottie_spinner
import torch
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.document_loaders import PyPDFLoader, TextLoader, WikipediaLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders.merge import MergedDataLoader
from langchain_groq import ChatGroq
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
# langchain-openai==0.1.6
# langchain-text-splitters==0.0.1
# langdetect==1.0.9
# langsmith==0.1.53
# --------------------------------------------------------- Helper Function/Class 1
def loadLottieFile(filePath: str):
    """Read a Lottie animation JSON file and return it as a Python object."""
    return json.loads(Path(filePath).read_text())
# --------------------------------------------------------- Helper Function/Class 2
def loadItOnce(container, animation, height, quality='high'):
    """Render a Lottie animation a single time inside the given Streamlit container."""
    placeholder = container.container()
    with placeholder:
        st_lottie(animation_source=animation, height=height, quality=quality)
# --------------------------------------------------------- Helper Function/Class 3
def initialRampUp(llamaAnimation):
    """One-time session bootstrap, shown behind a full-page Lottie spinner.

    Seeds the session_state keys the app depends on, picks a random Chroma
    collection name for this session, and resets the local 'Dump' scratch
    directory.

    llamaAnimation: parsed Lottie JSON object (as from loadLottieFile).
    """
    # BUG FIX: streamlit has no `lottie_spinner` attribute; the spinner context
    # manager is `st_lottie_spinner` from the streamlit_lottie package.
    with st_lottie_spinner(llamaAnimation, height=700):
        if 'filesUploadedRecords' not in st.session_state:  # To Maintain Active File Records
            st.session_state.filesUploadedRecords = None
        if 'vdbBuilt' not in st.session_state:  # Sources already embedded into the vector DB
            st.session_state.vdbBuilt = {}
        if 'vectorDatabase' not in st.session_state:  # To hold the vector database connection
            st.session_state.vectorDatabase = {}
        if 'collectionName' not in st.session_state:
            # Random per-session Chroma collection name (letters only).
            st.session_state.collectionName = ''.join(random.choices(string.ascii_letters, k=25))
        # Reset the scratch dir; ignore_errors replaces the old bare `except: pass`
        # and only swallows filesystem errors (e.g. dir missing on first run).
        shutil.rmtree('Dump', ignore_errors=True)
        os.makedirs('Dump', exist_ok=True)
        # Keep the spinner on screen briefly so the animation doesn't flash.
        time.sleep(3)
# --------------------------------------------------------- Helper Function/Class 4
def image_to_base64(image_path):
    """Return the contents of the file at image_path as a base64 string."""
    raw_bytes = Path(image_path).read_bytes()
    return base64.b64encode(raw_bytes).decode("utf-8")
def set_background_image(base64_image, opacity):
    """Inject CSS that paints base64_image as the app's page background.

    base64_image: base64-encoded image payload (as returned by image_to_base64).
    opacity: CSS opacity value in [0, 1] applied to the background layer only.
    """
    # BUG FIX: the original emitted an empty <style> block, so both parameters
    # were unused and the function was a no-op. Standard Streamlit pattern:
    # paint the image on a fixed full-viewport ::before layer so `opacity`
    # dims only the background, not the widgets drawn on top of it.
    custom_css = f"""
    <style>
    .stApp::before {{
        content: "";
        position: fixed;
        inset: 0;
        z-index: -1;
        background-image: url("data:image/png;base64,{base64_image}");
        background-size: cover;
        background-position: center;
        opacity: {opacity};
    }}
    </style>
    """
    # Display custom CSS using markdown (HTML must be explicitly allowed).
    st.markdown(custom_css, unsafe_allow_html=True)
# --------------------------------------------------------- Helper Function/Class 5
def navigationBar():
    """Render the horizontal top navigation menu and return the selected label.

    Icon names are Bootstrap icons — browse them at https://getbootstrap.com/
    """
    labels = [
        "Retrieval Augmented Generation",
        "Fine Tuning LLMs (Coming Soon)",
        "Forecasting LLMs (Coming Soon)",
    ]
    icons = ["bezier", "gpu-card", "graph-up-arrow"]
    # CSS-like styling for the option_menu component, one entry per element.
    nav_styles = {
        "container": {
            "display": "flex",
            "flex-direction": "column",
            "justify-content": "center",
            "padding": "20px 40px 20px 40px",
            "background-color": "#222",
            "border-radius": "20px",
            "width": "100%",
            "box-shadow": "0px 2px 10px rgba(0, 0, 0, 0.2)",
            "margin": "auto",
            "overflow-x": "auto",
        },
        "menu-title": {
            "font-size": "36px",
            "font-weight": "bold",
            "background-color": "#222",
            "color": "#FFFFFF",
            "margin-bottom": "20px",
        },
        "menu-icon": {
            "color": "#FFD700",
            "font-size": "36px",
            "margin-right": "10px",
        },
        "icon": {
            "color": "#FFD700",
            "font-size": "36px",
            "margin-right": "10px",
        },
        "nav-link": {
            "font-size": "20px",
            "text-align": "center",
            "color": "#FFFFFF",
            "padding": "10px 20px",
            "border-radius": "15px",
            "transition": "background-color 0.3s ease",
        },
        "nav-link-selected": {
            "background-color": "#FF6347",
            "color": "#FFFFFF",
        },
    }
    return option_menu(
        menu_title=None,  # no title text above the bar
        options=labels,
        icons=icons,
        menu_icon="lightbulb-fill",
        default_index=0,
        orientation="horizontal",
        styles=nav_styles,
    )
# --------------------------------------------------------- Helper Function/Class 6
@st.cache_resource(show_spinner=False)
def loadEmbeddingsModels():
    """Build (once per Streamlit session, via cache_resource) the instructor
    embeddings model that Chroma uses to embed documents and queries."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Embeddings function handed to ChromaDB for computing vector embeddings.
    model_config = dict(
        model_name="hkunlp/instructor-base",
        query_instruction="Represent the query for retrieval: ",
        model_kwargs={"device": device},
    )
    return HuggingFaceInstructEmbeddings(**model_config)
# --------------------------------------------------------- Helper Function/Class 7
def removedOrAdded(files):
    """Diff the uploader's current file list against the tracked session records.

    files: current list of Streamlit UploadedFile objects (may be empty).
    Returns a {file_id: UploadedFile} dict of every previously-tracked file
    that is no longer present, and updates
    st.session_state.filesUploadedRecords accordingly.
    """
    removed = {}
    tracked = st.session_state.filesUploadedRecords
    # Uploader is empty (app start, or the user cleared it): every file still
    # tracked counts as removed.
    if not files:
        if tracked:
            # BUG FIX: the original reported only tracked[0] as removed, losing
            # the rest when multiple files were cleared at once.
            removed = {f.file_id: f for f in tracked}
        st.session_state.filesUploadedRecords = None
        return removed
    # Files that were just removed (tracked but no longer in the uploader).
    current_ids = {f.file_id for f in files}
    # Guard against tracked being None on the very first upload (the original
    # would raise TypeError iterating None here).
    for tracked_file in (tracked or []):
        if tracked_file.file_id not in current_ids:
            removed[tracked_file.file_id] = tracked_file
    # Drop the crossed-off files from the active-files record we maintain.
    for gone in removed.values():
        st.session_state.filesUploadedRecords.remove(gone)
    return removed
# --------------------------------------------------------- Helper Function/Class 8
# To parse PDF and turn to vector embeddings
def buildVectorDatabase(files, query, addOrRemove):
    """
    Build (or prune) the session-level Chroma vector store.

    files: Either a list of fileUploader Objects with file details
        or that one filename which was just removed by the user and should be from the vector Database as well
    query: iterable of Wikipedia keywords to ingest (when adding) or of
        source identifiers whose chunks should be deleted (when removing)
    addOrRemove: if true, add else remove
    """
    # Create Embeddings and Add to the Vector Store
    if addOrRemove:
        embeddings = loadEmbeddingsModels()
        collName = st.session_state.collectionName
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        # File-extension -> LangChain document-loader class.
        loader_map = {
            '.pdf': PyPDFLoader,# Add more mappings as needed
            '.txt': TextLoader,
            'wiki' : WikipediaLoader
        }
        # NOTE(review): `doc` is used as a filesystem path (Path(doc).suffix),
        # so `files` must be path strings here — confirm against the caller.
        docsMergedPDF = [
            loader_map[Path(doc).suffix.lower()](file_path=doc)
            for doc in files
        ]
        loadersPDF = MergedDataLoader(loaders=docsMergedPDF)
        pagesPDF = loadersPDF.load_and_split(text_splitter) # Split into pages
        pagesWiki = []
        for wikiQ in query:
            loader = WikipediaLoader(query=wikiQ, load_max_docs=2)
            try:
                pagesW = loader.load_and_split(text_splitter)
                if len(pagesW) > 0:
                    # Record the unique source URLs for this keyword so the UI
                    # can list them and later delete their chunks by source.
                    references = list(set(list(map(lambda x: x.metadata['source'], pagesW))))
                    st.session_state.vdbBuilt['Keyword ; ' + wikiQ.strip()] = references
                    pagesWiki += pagesW
            except Exception as e:
                # Wikipedia lookups can fail (network, disambiguation, ...):
                # surface a toast and drop any stale record for this keyword.
                message = str(e) + '\n' + f'Looks like we could not search for your key word {wikiQ}'
                st.toast(body=message, icon="⚠️")
                _ = st.session_state.vdbBuilt.pop('Keyword ; ' + wikiQ.strip(), None)
        pages = pagesPDF + pagesWiki
        # NOTE(review): this rebinds the session's Chroma handle on every add,
        # embedding the combined PDF + Wikipedia pages into collection collName.
        st.session_state.vectorDatabase = Chroma.from_documents(documents= pages,
                                                                embedding= embeddings,
                                                                collection_name= collName,
                                                                ) # Load the pages into vector database (Build ChromaDB)
    # Delete corresponding embeddings from the vector store
    else:
        # Delete every chunk whose metadata `source` matches the removed file,
        # then every chunk whose `source` matches an entry of `query`.
        if files is not None:
            st.session_state.vectorDatabase._collection.delete(where={"source": {'$eq':files}})
        for src in query:
            st.session_state.vectorDatabase._collection.delete(where={"source": {'$eq':src}})
# --------------------------------------------------------- Helper Function/Class 9
# Silence the HuggingFace tokenizers fork-parallelism warning (the embeddings
# model loads tokenizers, and Streamlit's process model triggers the warning).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class RetrievalChainGenerator:
    def __init__(self, model_name, vectorStore):
        """Store the model/store config and build the retrieval chain eagerly.

        model_name: Groq model identifier passed through to ChatGroq.
        vectorStore: vector store used by the retrieval chain.
        Raises KeyError if GROQ_API_KEY is not set in the environment.
        """
        self.model_name = model_name
        # os.environ[...] (not .get) so a missing key fails fast with KeyError.
        self.groq_api_key = os.environ['GROQ_API_KEY']#os.getenv('GROQ_API_KEY')
        self.vectorStore = vectorStore
        # Populated by generate_retrieval_chain() below.
        self.chain = None
        self.generate_retrieval_chain()
def generate_retrieval_chain(self):
llm = ChatGroq(groq_api_key=self.groq_api_key, model_name=self.model_name)
prompt = ChatPromptTemplate.from_template("""
Answer the following question based only on the provided context.
Think step by step before providing a detailed answer.
If using finance terms, briefly explain them for clarity.
Always return your complete response in html format.
I will tip you $200 if the user finds the answer helpful.