File size: 4,392 Bytes
c643e19
1feeea0
 
 
 
 
 
 
 
 
 
c643e19
787200f
 
 
 
 
 
 
 
 
d6bdb65
787200f
9db002b
 
 
 
 
787200f
 
 
 
 
63484ff
d6bdb65
787200f
 
9db002b
d6bdb65
787200f
 
d6bdb65
63484ff
 
d6bdb65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
787200f
 
 
d6bdb65
787200f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import importlib.util
import os
import subprocess
import sys

# Build llama-cpp-python from source with cuBLAS (GPU) support.
# The build flags are handed to pip through the environment instead of a
# shell-style `VAR=value cmd` prefix, so no shell is needed to run it.
build_env = {**os.environ, "CMAKE_ARGS": "-DLLAMA_CUBLAS=on", "FORCE_CMAKE": "1"}

# `sys.executable -m pip` installs into the interpreter that is running this
# app; a bare `pip` looked up on PATH gives no such guarantee.
command = [sys.executable, "-m", "pip", "install", "llama-cpp-python", "--no-cache-dir"]

# Streamlit re-executes this script on every user interaction, so only spawn
# pip when the package is actually missing (pip would be a no-op otherwise).
if importlib.util.find_spec("llama_cpp") is None:
    try:
        subprocess.run(command, env=build_env, check=True)
        print("Command executed successfully.")
    except (subprocess.CalledProcessError, OSError) as e:
        # Best effort: report the failure and let the app continue starting.
        print(f"Error: {e}")

import streamlit as st
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.chains import RetrievalQA
from langchain.embeddings import FastEmbedEmbeddings
from langchain.vectorstores import Chroma
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 
from langchain import hub
from langchain.prompts import PromptTemplate

# Append an assistant reply to the chat history and draw it on the page
def update_chat_output(msg):
    """Store `msg` as an assistant message in session state, then render it."""
    assistant_entry = {"role": "assistant", "content": msg}
    st.session_state.messages.append(assistant_entry)
    bubble = st.chat_message("assistant")
    bubble.write(msg)

def init_retriever():
    """
    Build and return the retrieval QA chain that answers HR questions.

    Loads the local Llama 2 GGUF model through LlamaCpp, opens the Chroma
    store persisted under ./vectordb/ with FastEmbed embeddings, and combines
    them in a RetrievalQA chain that uses the custom HR prompt below.

    Returns:
        The RetrievalQA chain, with the streaming-stdout callback manager and
        a ConversationBufferMemory attached.
    """
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    # `temperature` was previously misspelled as `template`, so the sampling
    # temperature was never actually set to 0.4.
    # NOTE(review): max_tokens equals n_ctx, which leaves no headroom for the
    # prompt itself -- confirm this is intended.
    llm = LlamaCpp(model_path="./models/llama-2-13b-chat.Q4_K_S.gguf", 
                   temperature=0.4,
                   n_ctx=4000, 
                   max_tokens=4000,
                   n_gpu_layers = 10, 
                   n_batch = 512,
                   callback_manager=callback_manager,
                   verbose=True)
    
    embeddings = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5", cache_dir="./embeddings/")
    db = Chroma(persist_directory="./vectordb/", embedding_function=embeddings)

    # Prompt template: {context} receives the retrieved policy text and
    # {question} the employee's question.
    template = """
    You are a Experience human Resource Manager. When the employee asks you a question, you will have to refer the company policy and respond in a professional way. Make sure to sound Empethetic while being professional and sound like a Human!

    Try to summarise the content and keep the answer to the point.


    If you don't know the answer, just say that you don't know, don't try to make up an answer.

    When generating answer for the given question make sure to follow the example template!

    Example: 
    Question : how many paid leaves do i have ?
    Answer : The number of paid leaves varies depending on the type of leave, like privilege leave you're entitled to a maximum of 21 days in a calendar year. Other leaves might have different entitlements. thanks for asking!

    make sure to add "thanks for asking!" after every answer 

    {context}
    Question: {question}
    Answer:
    """

    rag_prompt_custom = PromptTemplate.from_template(template)


    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=db.as_retriever(),
        chain_type_kwargs={"prompt": rag_prompt_custom},
    )
    qa_chain.callback_manager = callback_manager
    # NOTE(review): the prompt has no history placeholder, so this memory is
    # recorded but presumably never fed back to the model -- confirm.
    qa_chain.memory = ConversationBufferMemory()
    
    return qa_chain

# Build the QA chain only when this session does not hold one yet, and keep
# it in session state under the "retriever" key for later script runs
if "retriever" not in st.session_state:
    st.session_state["retriever"] = init_retriever()

# Function to show an image with rounded corners using CSS
def add_rounded_edges(image_path="./randstad_featuredimage.png", radius=30):
    """
    Render an image and inject a CSS rule that rounds its corners.

    Args:
        image_path: Image handed unchanged to ``st.image``.
        radius: Corner radius in pixels, used for ``border-radius``.
    """
    # st.image offers no way to attach a CSS class, so a rule for
    # `.rounded-img` alone never matches the rendered image. The rule is kept
    # and extended to the <img> inside Streamlit's image container.
    # NOTE(review): `data-testid="stImage"` is an internal Streamlit attribute
    # and may differ between versions -- confirm against the installed one.
    # The rule is page-wide, so it presumably rounds every st.image output.
    st.markdown(
        f'<style>.rounded-img, [data-testid="stImage"] img'
        f'{{border-radius: {radius}px; overflow: hidden;}}</style>',
        unsafe_allow_html=True,
    )
    st.image(image_path, use_column_width=True, output_format='auto')

# Sidebar: show the Randstad logo with rounded corners
with st.sidebar:
    add_rounded_edges()

# Page header. The emoji were stored as mis-decoded bytes and showed up as
# garbage characters; they are written here as the intended code points.
st.title("💬 HR Chatbot")
st.caption("🚀 A chatbot powered by Local LLM")

# Clear button: wipe the on-screen transcript and the chain's memory.
# The reset has to happen right here. st.button is True only during the rerun
# its click triggers, while st.chat_input returns text only during the rerun a
# submission triggers, so a flag checked inside the chat-input branch below
# would never be acted on.
if st.button("Clear Chat History"):
    st.session_state.messages = []
    chain_memory = getattr(st.session_state.retriever, "memory", None)
    if chain_memory is not None:
        chain_memory.clear()

# Seed a brand-new session with the assistant's greeting
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "How can I help you?"}]

# Replay the stored transcript; every rerun redraws the page from scratch
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    st.chat_message("user").write(prompt)
    chain = st.session_state.retriever
    # Send only the question text: the chain uses its input both as the
    # retriever query and as {question} in the prompt, so passing the whole
    # list of message dicts would corrupt both.
    answer = chain.run(prompt)
    update_chat_output(answer)