Spaces:
Sleeping
Sleeping
import subprocess | |
import os | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
import streamlit as st | |
from langchain.llms import LlamaCpp | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import RetrievalQA | |
from langchain.embeddings import FastEmbedEmbeddings | |
from langchain.vectorstores import Chroma | |
from langchain.callbacks.manager import CallbackManager | |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from langchain import hub | |
from langchain.prompts import PromptTemplate | |
def update_chat_output(msg):
    """Add *msg* to the transcript as an assistant turn and display it.

    The entry is stored in ``st.session_state.messages`` (the list the script
    replays on each rerun) and then rendered right away in an assistant bubble.
    """
    entry = {"role": "assistant", "content": msg}
    st.session_state.messages.append(entry)
    bubble = st.chat_message("assistant")
    bubble.write(msg)
def init_retriever():
    """
    Build and return the RetrievalQA chain that answers HR policy questions.

    Steps:
      1. Load the local Llama-2 13B chat GGUF model through ``LlamaCpp``.
      2. Open the Chroma store persisted in ``./vectordb/`` using FastEmbed
         ``BAAI/bge-small-en-v1.5`` embeddings (model files cached in
         ``./embeddings/``).
      3. Combine both in a ``RetrievalQA`` chain driven by a custom HR prompt
         with ``{context}`` and ``{question}`` placeholders.

    Returns:
        The configured ``RetrievalQA`` chain; ``chain.run(question)`` returns
        the answer text.
    """
    # Echo generated tokens to stdout as they stream out of the model.
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = LlamaCpp(
        model_path="./models/llama-2-13b-chat.Q4_K_S.gguf",
        # Previously passed as ``template=0.4``; LlamaCpp has no such field,
        # so the intended sampling temperature was never applied.
        temperature=0.4,
        n_ctx=4000,
        # NOTE(review): max_tokens equals n_ctx, leaving no headroom for the
        # prompt plus retrieved context; consider lowering it.
        max_tokens=4000,
        n_gpu_layers=50,  # number of model layers offloaded to the GPU
        n_batch=512,
        callback_manager=callback_manager,
        verbose=True,
    )
    embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5", cache_dir="./embeddings/")
    db = Chroma(persist_directory="./vectordb/", embedding_function=embeddings)

    # Prompt template: {context} is filled with the retrieved policy passages,
    # {question} with the employee's question.
    template = """
You are an experienced Human Resources Manager. When an employee asks you a question, you will have to refer to the company policy and respond in a professional way. Make sure to sound empathetic while being professional and sound like a human!
Try to summarise the content and keep the answer to the point.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
When generating answer for the given question make sure to follow the example template!
Example:
Question : how many paid leaves do i have ?
Answer : The number of paid leaves varies depending on the type of leave, like privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. thanks for asking!
make sure to add "thanks for asking!" after every answer
{context}
Question: {question}
Answer:
"""
    rag_prompt_custom = PromptTemplate.from_template(template)
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=db.as_retriever(),
        chain_type_kwargs={"prompt": rag_prompt_custom},
    )
    qa_chain.callback_manager = callback_manager
    # NOTE: the prompt above has no {history} placeholder, so this buffer
    # records the conversation but is not fed back into the answers.
    qa_chain.memory = ConversationBufferMemory()
    return qa_chain
@st.cache_resource
def _install_llama_cpp_cublas():
    """
    (Re)install llama-cpp-python built with cuBLAS, at most once per process.

    ``st.cache_resource`` memoizes the call across all sessions of this server
    process, so the slow source build no longer re-runs for every new browser
    session (the previous guard was the per-session ``st.session_state``).

    The command is passed as an argv list with the build flags in ``env``,
    which avoids ``shell=True``. ``sys.executable -m pip`` targets the
    interpreter running this app.

    Returns:
        True if pip exited successfully, False otherwise. Failures are
        reported and tolerated (best effort), as before.
    """
    import sys

    env = {**os.environ, "CMAKE_ARGS": "-DLLAMA_CUBLAS=on", "FORCE_CMAKE": "1"}
    # NOTE(review): pip treats an already-installed llama-cpp-python as
    # satisfied and will not rebuild it; add "--force-reinstall" if a CUDA
    # rebuild of a preinstalled wheel is intended.
    command = [sys.executable, "-m", "pip", "install", "llama-cpp-python", "--no-cache-dir"]
    try:
        subprocess.run(command, env=env, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error: {e}")
        return False
    print("Command executed successfully.")
    return True


# Build the QA chain once per browser session and keep it in session state.
if "retriever" not in st.session_state:
    _install_llama_cpp_cublas()
    st.session_state.retriever = init_retriever()
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """
    Render an image with rounded corners.

    A ``.rounded-img`` CSS rule is injected, and for a local image file the
    picture is embedded as an inline ``<img class="rounded-img">``. This is
    needed because ``st.image`` does not attach that class to its output, so
    styling it through the class has no effect. Any other source (URL, array,
    missing file) falls back to plain ``st.image`` without rounding.

    Args:
        image_path: image to display; a local file path gets rounded corners,
            anything else is handed to ``st.image`` unchanged.
        radius: corner radius in pixels.
    """
    import base64
    import mimetypes

    st.markdown(
        f'<style>.rounded-img{{border-radius: {radius}px; overflow: hidden;}}</style>',
        unsafe_allow_html=True,
    )
    if isinstance(image_path, (str, os.PathLike)) and os.path.isfile(image_path):
        mime = mimetypes.guess_type(os.fspath(image_path))[0] or "image/png"
        with open(image_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode("ascii")
        # width:100% mirrors st.image(..., use_column_width=True).
        st.markdown(
            f'<img class="rounded-img" src="data:{mime};base64,{encoded}" style="width: 100%;">',
            unsafe_allow_html=True,
        )
    else:
        st.image(image_path, use_column_width=True, output_format='auto')
# Sidebar: logo, title and the clear-history control.
with st.sidebar:
    # add Randstad logo
    add_rounded_edges()
    st.title("💬 HR Chatbot")
    st.caption("🚀 A chatbot powered by Local LLM")
    # Clear the visible transcript and the chain's conversation memory.
    # RetrievalQA has no clean() method (the old call raised AttributeError),
    # so the memory object attached to the chain is cleared instead.
    if st.button("Clear Chat History"):
        st.session_state.messages = []
        memory = getattr(st.session_state.get("retriever"), "memory", None)
        if memory is not None:
            memory.clear()

# Seed a fresh session with a greeting.
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

# Streamlit reruns the whole script on each interaction, so replay the transcript.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    chain = st.session_state.retriever
    # Ask only the latest question; previously the whole list of message
    # dicts was sent as the query to the retriever and the prompt.
    answer = chain.run(prompt)
    update_chat_output(answer)