|
import subprocess |
|
import streamlit as st  # imported early so the setup step below can be cached


@st.cache_resource(show_spinner=False)
def run_setup_script(script_path="/setup.sh"):
    """Run the environment setup shell script once per Streamlit server process.

    Streamlit re-executes this entire file on every user interaction. Calling
    ``subprocess.run`` at module level therefore re-ran the setup script on
    every rerun. ``st.cache_resource`` memoizes the call so it executes once
    per distinct ``script_path``.

    Args:
        script_path: Path of the shell script to execute with ``bash``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    subprocess.run(["bash", script_path], check=True)


run_setup_script()
|
|
|
import streamlit as st |
|
from langchain.llms import LlamaCpp |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.chains import RetrievalQA |
|
from langchain.embeddings import FastEmbedEmbeddings |
|
from langchain.vectorstores import Chroma |
|
from langchain.callbacks.manager import CallbackManager |
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler |
|
from langchain import hub |
|
from langchain.prompts import PromptTemplate |
|
|
|
# Prompt used by the "stuff" QA chain; ``{context}`` receives the retrieved
# policy passages and ``{question}`` the employee's question.
_HR_PROMPT_TEMPLATE = """
You are an experienced Human Resources Manager. When an employee asks you a question, you will have to refer to the company policy and respond in a professional way. Make sure to sound empathetic while being professional, and sound like a human!

Try to summarise the content and keep the answer to the point.

If you don't know the answer, just say that you don't know, don't try to make up an answer.

When generating an answer for the given question, make sure to follow the example template!

Example:

Question : how many paid leaves do i have ?

Answer : The number of paid leaves varies depending on the type of leave, like privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. thanks for asking!

make sure to add "thanks for asking!" after every answer

{context}

Question: {question}

Answer:
"""


def init_retriever(
    model_path="./models/llama-2-13b-chat.Q4_K_S.gguf",
    persist_directory="./vectordb/",
    temperature=0.4,
):
    """Initialize and return the retrieval-QA chain used by the chatbot.

    Loads a local llama.cpp model, opens the persisted Chroma vector store
    (embedded with FastEmbed ``BAAI/bge-small-en-v1.5``), and wires both into
    a ``RetrievalQA`` chain that answers with the HR-manager prompt above.

    Args:
        model_path: Path to the GGUF model file loaded by ``LlamaCpp``.
        persist_directory: Directory holding the persisted Chroma collection.
        temperature: Sampling temperature for the model.

    Returns:
        The configured ``RetrievalQA`` chain, with a ``ConversationBufferMemory``
        attached as ``.memory``.
    """
    # Streams generated tokens to the server's stdout as they are produced.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

    llm = LlamaCpp(
        model_path=model_path,
        # The original passed ``template=0.4``. LlamaCpp has no ``template``
        # option; ``temperature`` was clearly intended.
        temperature=temperature,
        n_ctx=4000,
        # NOTE(review): max_tokens equals n_ctx, which leaves no room for the
        # prompt inside the context window. Confirm llama.cpp clamps this as intended.
        max_tokens=4000,
        n_gpu_layers=50,
        n_batch=512,
        callback_manager=callback_manager,
        verbose=True,
    )

    embeddings = FastEmbedEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        cache_dir="./embeddings/",
    )
    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)

    rag_prompt_custom = PromptTemplate.from_template(_HR_PROMPT_TEMPLATE)

    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=db.as_retriever(),
        chain_type_kwargs={"prompt": rag_prompt_custom},
    )
    qa_chain.callback_manager = callback_manager
    # NOTE(review): the prompt has no history variable, so this memory is
    # recorded but never fed back into answers. Confirm whether
    # conversational context is actually wanted.
    qa_chain.memory = ConversationBufferMemory()

    return qa_chain
|
|
|
|
|
# Streamlit re-runs this script on every interaction; build the QA chain only
# the first time a browser session reaches this point and reuse it afterwards.
_session = st.session_state
if "retriever" not in _session:
    _session["retriever"] = init_retriever()
|
|
|
|
|
def image_data_uri(image_path):
    """Return the file at ``image_path`` encoded as a base64 ``data:`` URI.

    The MIME type is guessed from the file extension. If it cannot be
    guessed, ``image/png`` is assumed.

    Raises:
        OSError: If ``image_path`` cannot be read.
    """
    import base64
    import mimetypes
    from pathlib import Path

    mime_type = mimetypes.guess_type(str(image_path))[0] or "image/png"
    payload = base64.b64encode(Path(image_path).read_bytes()).decode("ascii")
    return f"data:{mime_type};base64,{payload}"


def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """Render a local image with rounded corners at full container width.

    The previous version injected a ``.rounded-img`` CSS rule and then drew
    the picture with ``st.image``. ``st.image`` never attaches that class, so
    the radius had no effect. Here the image is inlined as a data URI inside
    an ``<img>`` tag that carries the class, so the rule actually applies.

    Args:
        image_path: Path to a local image file.
        radius: Corner radius in CSS pixels.

    Raises:
        OSError: If ``image_path`` cannot be read.
    """
    st.markdown(
        f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>',
        unsafe_allow_html=True,
    )
    # width: 100% mirrors the former ``use_column_width=True`` sizing.
    st.markdown(
        f'<img class="rounded-img" src="{image_data_uri(image_path)}" style="width: 100%;">',
        unsafe_allow_html=True,
    )
|
|
|
|
|
# Greeting shown at the start of a session and after the history is cleared.
GREETING = "How can I help you?"

with st.sidebar:
    add_rounded_edges()
    # Emoji repaired from mojibake ("π¬" / "π") in the original source.
    st.title("💬 HR Chatbot")
    st.caption("🚀 A chatbot powered by Local LLM")

    # True only on the rerun triggered by clicking the button.
    clear = False
    if st.button("Clear Chat History"):
        clear = True
        # Reset to the greeting so a cleared chat looks like a fresh session.
        st.session_state.messages = [{"role": "assistant", "content": GREETING}]
        # The original called a non-existent ``chain.clean()``. Clear the
        # chain's attached memory instead, if it has one.
        memory = getattr(st.session_state.get("retriever"), "memory", None)
        if memory is not None:
            memory.clear()

if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": GREETING}]

# Replay the stored conversation; Streamlit redraws the page on every rerun.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)

    chain = st.session_state.retriever
    # RetrievalQA takes the question text as its single input. Passing the
    # whole list of message dicts (as before) used that list as the query.
    msg = chain.run(prompt)

    st.session_state.messages.append({"role": "assistant", "content": msg})
    st.chat_message("assistant").write(msg)
|
|
|
|