# app.py — author: Syed Junaid Iqbal; commit 090e23f ("Update app.py"); 3.92 kB
import subprocess
# Run the environment setup script once; check=True makes a non-zero exit
# abort the app instead of continuing with a half-prepared environment.
# Streamlit re-executes this whole file on every widget interaction, so an
# unguarded call would re-run /setup.sh for every chat message. os.environ is
# process-wide and survives those reruns, so it serves as a "done" marker.
import os

if os.environ.get("HR_CHATBOT_SETUP_DONE") != "1":
    subprocess.run(['bash', "/setup.sh"], check=True)
    os.environ["HR_CHATBOT_SETUP_DONE"] = "1"
import streamlit as st
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA
from langchain.embeddings import FastEmbedEmbeddings
from langchain.vectorstores import Chroma
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain import hub
from langchain.prompts import PromptTemplate
def init_retriever():
    """
    Build and return the retrieval-augmented QA chain used by the chatbot.

    Steps:
      * load the local Llama-2 GGUF model through ``LlamaCpp``, streaming
        tokens to stdout;
      * open the persisted Chroma store in ``./vectordb/`` using FastEmbed
        ``BAAI/bge-small-en-v1.5`` embeddings (cached in ``./embeddings/``);
      * combine both in a ``RetrievalQA`` chain that uses a custom
        HR-assistant prompt, and attach a ``ConversationBufferMemory``.

    Returns:
        RetrievalQA: the configured chain; call ``.run(question)`` on it.
    """
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = LlamaCpp(
        model_path="./models/llama-2-13b-chat.Q4_K_S.gguf",
        # Was ``template=0.4``: ``template`` is not a LlamaCpp parameter, so
        # the intended sampling temperature was never applied.
        temperature=0.4,
        n_ctx=4000,
        # NOTE(review): max_tokens equals n_ctx, which leaves no token budget
        # for the prompt plus retrieved context; consider lowering — confirm.
        max_tokens=4000,
        n_gpu_layers=50,
        n_batch=512,
        callback_manager=callback_manager,
        verbose=True,
    )

    embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5", cache_dir="./embeddings/")
    db = Chroma(persist_directory="./vectordb/", embedding_function=embeddings)

    # Prompt template. The text below is sent verbatim to the LLM; {context}
    # receives the retrieved documents and {question} the user's query.
    template = """
You are a Experience human Resource Manager. When the employee asks you a question, you will have to refer the company policy and respond in a professional way. Make sure to sound Empethetic while being professional and sound like a Human!
Try to summarise the content and keep the answer to the point.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
When generating answer for the given question make sure to follow the example template!
Example:
Question : how many paid leaves do i have ?
Answer : The number of paid leaves varies depending on the type of leave, like privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. thanks for asking!
make sure to add "thanks for asking!" after every answer
{context}
Question: {question}
Answer:
"""
    rag_prompt_custom = PromptTemplate.from_template(template)

    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=db.as_retriever(),
        chain_type_kwargs={"prompt": rag_prompt_custom},
    )
    qa_chain.callback_manager = callback_manager
    # NOTE(review): the prompt above has no {history} placeholder, so the
    # stored conversation presumably does not influence answers — confirm.
    qa_chain.memory = ConversationBufferMemory()
    return qa_chain
# Build the QA chain at most once per browser session: Streamlit reruns this
# script on every interaction, and init_retriever() loads the local LLM and
# the vector store.
_session = st.session_state
if "retriever" not in _session:
    _session["retriever"] = init_retriever()
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """Inject a ``.rounded-img`` CSS rule, then render an image full-width.

    Args:
        image_path: Image to display; passed straight to ``st.image``.
        radius: Corner radius, in pixels, written into the CSS rule.

    NOTE(review): the ``.rounded-img`` class is never attached to the image
    rendered by ``st.image``, so the rule probably has no visible effect —
    confirm.
    """
    css = f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>'
    st.markdown(css, unsafe_allow_html=True)
    st.image(
        image_path,
        output_format='auto',
        use_column_width=True,
    )
# Sidebar: company logo.
with st.sidebar:
    # add Randstad logo
    add_rounded_edges()

# Header. The emoji were previously stored mis-encoded ("πŸ’¬", "πŸš€").
st.title("💬 HR Chatbot")
st.caption("🚀 A chatbot powered by Local LLM")

# Whether the "Clear Chat History" button was pressed in this script run.
clear = False
if st.button("Clear Chat History"):
    clear = True
    st.session_state.messages = []
    # Each widget reports its event only on the rerun it triggers, so a button
    # click and a chat submission never land in the same run. Reset the
    # chain's conversation memory here rather than when the next question
    # is processed.
    _chain = st.session_state.get("retriever")
    _memory = getattr(_chain, "memory", None)
    if _memory is not None:
        _memory.clear()
# On the first run of a session, seed the conversation with a greeting. Then
# replay every stored message, because Streamlit redraws the page from
# scratch on each rerun.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

for message in st.session_state.messages:
    role, content = message["role"], message["content"]
    st.chat_message(role).write(content)
# Handle a newly submitted question: echo it, ask the QA chain, show the answer.
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    chain = st.session_state.retriever
    if clear and chain.memory is not None:
        # RetrievalQA has no ``clean()`` method; reset its memory instead.
        chain.memory.clear()

    # RetrievalQA takes the question text as its single input. Passing the
    # whole message list (a list of dicts) handed that list to the retriever
    # as the search query.
    msg = chain.run(prompt)
    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)