# Retrieval-Augmented-Generation / streamlit_app.py
# Author: arif97
# Last commit: "Final Touch" (d987e6a)
#------------------------------------------------------------------------------ Importing necessary libraries and loading those variables
import streamlit as st
from streamlit_lottie import st_lottie, st_lottie_spinner
import os
from pathlib import Path
from dotenv import load_dotenv
# Populate os.environ from a local .env file, if present; the model name read
# later via os.environ['LLAMA3MODEL8B'] presumably comes from there.
load_dotenv()

# Configure the browser tab and use the full page width. Streamlit has
# historically required this to run before any other st.* call.
st.set_page_config(
    page_title="RAG With Llama 3",
    page_icon="Lottie Animations/LlamaIcon.jpeg",
    layout="wide",
)
#------------------------------------------------------------------------------ Importing Backend Functions from _helper_functions py file
from _helper_functions import loadLottieFile
from _helper_functions import initialRampUp
from _helper_functions import navigationBar
from _helper_functions import removedOrAdded
from _helper_functions import buildVectorDatabase
from _helper_functions import RetrievalChainGenerator
from _helper_functions import loadItOnce
#------------------------------------------------------------------------------ Importing Frontend Functions from _helper_functions py file
from _helper_functions import display_main_title
from _helper_functions import display_alert_note
from _helper_functions import display_attention_text
from _helper_functions import display_custom_arrow
from _helper_functions import display_heading_box
from _helper_functions import display_error_message
from _helper_functions import display_small_text
from _helper_functions import display_response_message
from _helper_functions import display_question_box
from _helper_functions import display_allCitations
#------------------------------------------------------------------------------ Loading all lottie animations to showcase progress
# Animation files are resolved against a "Lottie Animations" folder under the
# process's current working directory (not the script's own directory).
cwd = Path.cwd()
filePath = cwd / "Lottie Animations"

# Load every animation once per script run, in the same order as the names on
# the left-hand side of the unpacking.
llama3, finetuning, forecasting, buildingDatabase, fancyload, citations, vdbList = (
    loadLottieFile(filePath / animationFile)
    for animationFile in (
        "llama3.json",
        "finetuning.json",
        "forecasting.json",
        "buildingDatabase.json",
        "fancyloading.json",
        "citations.json",
        "knowledgeBase1.json",
    )
)
#------------------------------------------------------------------------------ Session variables that live for the whole browser session
# Stays True until the RAG page has run its one-time initialRampUp() call.
if "initialRampUp" not in st.session_state:
    st.session_state["initialRampUp"] = True

#------------------------------------------------------------------------------ Header shared by every page
# Title, then the navigation bar (its return value selects the page below),
# then a horizontal divider.
display_main_title("Let's Chat With Llama 3!!!", st)
selected_option = navigationBar()
st.divider()
#------------------------------------------------------------------------------ Helpers for the Retrieval Augmented Generation page
# NOTE(review): st.session_state.vdbBuilt is never created in this file; it is
# assumed to be initialised (as a dict) by initialRampUp() -- TODO confirm.
# Its keys are either an upload's file_id (value: saved path
# 'Dump/<file_id>---<name>') or 'Keyword ; <keyword>' (value: whatever the
# vector-DB build stores for that keyword, initially None).

KEYWORD_PREFIX = "Keyword ; "


def _sync_removed_uploads(uploadedFiles):
    """Delete knowledge-base content for every file removed from the uploader.

    The result of the project helper ``removedOrAdded`` is treated as a mapping
    whose keys are upload ``file_id``s, matching how uploads are keyed in
    ``st.session_state.vdbBuilt``. All removed entries are handled, not only
    the first one.
    """
    removed = removedOrAdded(uploadedFiles)
    # Remember the current uploader state. Presumably removedOrAdded compares
    # against this on the next rerun -- TODO confirm in _helper_functions.
    st.session_state.filesUploadedRecords = uploadedFiles
    for fileId in list(removed.keys()):
        fileName = st.session_state.vdbBuilt.pop(fileId, None)
        if fileName is not None:
            # addOrRemove=False deletes content. The file goes in as a str and
            # query as an empty list, so keyword content is left untouched.
            buildVectorDatabase(files=str(fileName), addOrRemove=False, query=[])


def _build_knowledge_base(uploadedFiles, queryInputs, container):
    """Embed any uploads and Wikipedia keywords that are not yet in the knowledge base.

    Each new keyword is registered in ``vdbBuilt`` as ``'Keyword ; <kw>'`` with
    a ``None`` placeholder. Each new upload is written to ``Dump/`` and
    registered under its ``file_id``. When nothing new was supplied, an error
    message is shown in *container* instead.
    """
    wikiQueries = []
    for rawKeyword in (queryInputs or "").split(";"):
        keyword = rawKeyword.strip()
        vdbKey = KEYWORD_PREFIX + keyword
        if keyword and vdbKey not in st.session_state.vdbBuilt:
            # Placeholder value. Presumably replaced with the source metadata
            # once the Wikipedia content is loaded -- TODO confirm.
            st.session_state.vdbBuilt[vdbKey] = None
            wikiQueries.append(keyword)

    fileNames = []
    for file in uploadedFiles or []:
        if file.file_id in st.session_state.vdbBuilt:
            continue
        # The upload id prefix keeps same-named files distinct. Only the base
        # name of the client-supplied file name is used, so it cannot point
        # outside Dump/.
        fileName = "Dump/" + file.file_id + "---" + os.path.basename(file.name)
        os.makedirs("Dump", exist_ok=True)  # open() fails if the folder is missing
        with open(fileName, "wb") as f:
            f.write(file.getvalue())
        fileNames.append(fileName)
        st.session_state.vdbBuilt[file.file_id] = fileName

    if fileNames or wikiQueries:
        buildVectorDatabase(files=fileNames, addOrRemove=True, query=wikiQueries)
    else:
        display_error_message(
            message="You have no new files or Keywords to create/add vector databases!",
            container=container,
        )


def _remove_deselected_keyword():
    """Remove the first keyword the user crossed off in the multiselect, then rerun.

    Only one keyword is removed per run. ``st.rerun()`` restarts the script, so
    any other crossed-off keywords are handled on the following runs.
    """
    deselected = next(
        (
            key
            for key in st.session_state.vdbBuilt
            if key.startswith(KEYWORD_PREFIX)
            and key.split(";")[-1].strip() not in st.session_state.keywordsDBOptions
        ),
        None,
    )
    if deselected:
        # NOTE(review): the value passed here is None if the build never filled
        # in the placeholder -- confirm buildVectorDatabase tolerates that.
        buildVectorDatabase(
            files=None,
            addOrRemove=False,
            query=st.session_state.vdbBuilt[deselected],
        )
        st.session_state.vdbBuilt.pop(deselected, None)
        st.rerun()


def _knowledge_base_label(key, value):
    """Return the text shown in the right column for one ``vdbBuilt`` entry."""
    if key.startswith(KEYWORD_PREFIX):
        return key.split(";")[-1].strip()
    # Saved paths look like 'Dump/<file_id>---<name>'. Split on the first
    # separator only, so names that themselves contain '---' are kept whole.
    return value.split("---", 1)[-1].strip()


#------------------------------------------------------------------------------ First Page: Retrieval Augmented Generation
if selected_option == "Retrieval Augmented Generation":
    # One-time setup per browser session. Presumably this (re)creates the
    # vector store and vdbBuilt -- TODO confirm in _helper_functions.
    if st.session_state.initialRampUp:
        initialRampUp(llamaAnimation=llama3)
        st.session_state.initialRampUp = False

    #-------------------------------------------------------------------------- MAIN CONTAINER: citations | build area | knowledge-base contents
    with st.container():
        leftCol, upld, rightCol = st.columns((3, 4, 3))

        display_heading_box(message="Knowledge Base Contents", container=rightCol)
        loadItOnce(container=rightCol, animation=vdbList, height=200, quality='low')

        display_heading_box(message="Citations for responses", container=leftCol)
        loadItOnce(container=leftCol, animation=citations, height=200, quality='low')

        upld.markdown("#")
        upld.markdown("#")
        display_alert_note(
            message=(
                "Note: Multiple Files with same names will be considered unique while "
                "constructing Vector Embeddings. It takes a little bit of time for "
                "Vector Embeddings to be built, BE PATIENT!"
            ),
            container=upld,
        )
        upld.markdown("#")
        upld.markdown("#")

        display_attention_text(text="Build your Knowledge Base (Vector DB)!", container=upld)
        with upld.container():
            # Child container 1: arrow | uploader + keyword box | arrow
            upldLeft, upldcenter, upldRight = st.columns((1, 5, 1))

            upldLeft.markdown("###")
            upldLeft.markdown("###")
            display_custom_arrow(direction="right", container=upldLeft)

            uploadedFiles = upldcenter.file_uploader(
                label="Upload or Add Documents",
                type=['pdf', 'txt'],
                accept_multiple_files=True,
                key="fileUpload",
            )

            upldRight.markdown("###")
            upldRight.markdown("###")
            display_custom_arrow(direction="left", container=upldRight)

            queryInputs = upldcenter.text_input(
                label="Type keywords to enhance augmented search & generation! Separate each keyword with ';'",
                placeholder="Domain Information will be scraped from wikipedia based on key words you enter",
            )

        # Purge embeddings for any files the user removed from the uploader.
        _sync_removed_uploads(uploadedFiles)

        with upld.container():
            # Child container 2: centred build button
            _, crtOrAd, _ = st.columns((2, 2, 2))
            createOrAdd = crtOrAd.button("Create/Add to the knowledge base")
        if createOrAdd:
            with st_lottie_spinner(buildingDatabase, height=700):
                _build_knowledge_base(uploadedFiles, queryInputs, container=upld)

        #---------------------------------------------------------------------- Let the user cross off keyword-based content
        keyWordOptions = [
            key.split(";")[-1].strip()
            for key in st.session_state.vdbBuilt
            if key.startswith(KEYWORD_PREFIX)
        ]
        if keyWordOptions:
            optionsMessage = "You can cross off certain keywords from below if you'd prefer to remove contents relevant to your entered keywords, be removed from knowledge base"
        else:
            optionsMessage = "Your Knowledge Base does not have any keywords based contents scraped from internet"
        upldcenter.multiselect(
            label=optionsMessage,
            options=keyWordOptions,
            default=keyWordOptions,
            key='keywordsDBOptions',
        )
        _remove_deselected_keyword()

        #---------------------------------------------------------------------- List knowledge-base contents in the right column
        for key, value in st.session_state.vdbBuilt.items():
            display_small_text(_knowledge_base_label(key, value), rightCol)

        with upld.container():
            st.markdown("###")
            st.divider()
            st.markdown("###")
            _, c1, c2, _ = st.columns((2, 1, 2, 2))
            loadItOnce(container=c1, animation=llama3, height=150, quality='low')
            display_question_box(c2)

    #-------------------------------------------------------------------------- Chat input: disabled until the knowledge base has content
    # NOTE(review): the original indentation was lost. The chat input is placed at
    # page level here, where Streamlit pins it to the bottom -- confirm intent.
    if not st.session_state.vdbBuilt:
        st.chat_input(placeholder="Type your query here once you have built vector databases!",
                      disabled=True)
    else:
        query = st.chat_input(placeholder="Type your query here once you have built vector databases!",
                              disabled=False)
        if query:
            # Build the retrieval chain only when there is a question to answer,
            # not on every rerun.
            generatorLlama3_8b = RetrievalChainGenerator(
                model_name=os.environ['LLAMA3MODEL8B'],
                vectorStore=st.session_state.vectorDatabase,
            )
            with st.container():
                with st_lottie_spinner(fancyload, height=400):
                    response = generatorLlama3_8b.chain.invoke({"input": query})
                # Answer goes in the centre column, its citations in the left column.
                display_response_message(response['answer'], upld)
                display_allCitations(response, leftCol)

#------------------------------------------------------------------------------ Second Page: Fine Tuning
elif selected_option == "Fine Tuning LLMs (Coming Soon)":
    st_lottie(finetuning, quality='medium', height=700)

#------------------------------------------------------------------------------ Third Page: Forecasting
elif selected_option == "Forecasting LLMs (Coming Soon)":
    st_lottie(forecasting, quality='high', height=700)