File size: 5,026 Bytes
fe57ff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad24004
fe57ff5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
import os

import streamlit as st

from langchain_community.vectorstores import FAISS
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA


def format_docs(docs):
    """Join retrieved documents into one context string for the prompt.

    Args:
        docs: Iterable of document objects, each exposing a ``page_content``
            string (as produced by the retriever).

    Returns:
        str: The documents' ``page_content`` values separated by blank lines,
        or an empty string when ``docs`` is empty.
    """
    # Log at debug level instead of printing, so retrieved content is only
    # emitted when debug logging is enabled rather than on every query.
    logging.getLogger(__name__).debug("Retrieved documents: %s", docs)
    return "\n\n".join(doc.page_content for doc in docs)


@st.cache_resource
def _load_retrievers():
    """Load the FAISS index and build the retrievers once per server process.

    Streamlit re-executes this script on every user interaction. Caching the
    result avoids re-reading the index from disk and re-creating the embedding
    and reranker clients on each rerun.

    Returns:
        tuple: ``(embeddings, db, retriever, compressor, compression_retriever)``.
    """
    emb = NVIDIAEmbeddings(model="nvidia/nv-embedqa-mistral-7b-v2")
    # NOTE: allow_dangerous_deserialization lets load_local unpickle the saved
    # index data; only point this at a "faiss_index" directory you produced.
    store = FAISS.load_local("faiss_index", emb, allow_dangerous_deserialization=True)
    base = store.as_retriever()
    reranker = FlashrankRerank()
    reranking = ContextualCompressionRetriever(
        base_compressor=reranker, base_retriever=base
    )
    return emb, store, base, reranker, reranking


embeddings, db, retriever, compressor, compression_retriever = _load_retrievers()

# Page header plus a collapsible disclaimer about the chatbot's limitations.
st.title("KCE Chatbot")
with st.expander("Disclaimer", icon="ℹ️"):
    # NOTE(review): the literal keeps its source indentation; this relies on
    # Streamlit dedenting the text before rendering it as Markdown (otherwise
    # 4-space-indented lines would render as a code block) — confirm.
    # The explicit "\n" after item 2 inserts a blank line (paragraph break).
    st.info("""
    We appreciate your engagement with our chatbot! We hope this chatbot can help you with the questions you have regarding the KCE company.
    This chatbot is a demonstration preview. While the system is designed to provide helpful and informative responses by retrieving and generating relevant information, it is important to note the following:
    1. Potential for Inaccuracies: The chatbot may sometimes produce incorrect or misleading information. The responses generated by the LLM are based on patterns in the data it has been trained on and the information retrieved, which might not always be accurate or up-to-date.
    2. Hallucinations: The LLM might generate responses that seem plausible but are entirely fabricated. These "hallucinations" are a known limitation of current LLM technology and can occur despite the retrieval mechanism.\n
    By interacting with this chatbot, you acknowledge and accept these limitations and agree to use the information provided responsibly.
    """)

# Sidebar label -> model id passed to ChatNVIDIA(model=...) in response_generator.
# Labels follow the "vendor/family-version-size" pattern of their model ids.
models_dict = {
    "meta/llama-3.1-405b": "meta/llama-3.1-405b-instruct",
    "meta/llama-3.1-70b": "meta/llama-3.1-70b-instruct",
    "meta/llama-3.1-8b": "meta/llama-3.1-8b-instruct",
    "google/gemma-2-27b": "google/gemma-2-27b-it",
    "google/gemma-7b": "google/gemma-7b",
    "microsoft/phi-3-mini-128k": "microsoft/phi-3-mini-128k-instruct",
    "microsoft/phi-3-medium-4k": "microsoft/phi-3-medium-4k-instruct",
}


# Sidebar model picker: the options are the labels (keys) of models_dict, and
# the chosen label is echoed back underneath the picker.
_model_labels = tuple(models_dict)
model = st.sidebar.selectbox("Choose model", _model_labels, label_visibility="visible")
st.sidebar.write(f"Selected model: {model}")



def response_generator(message):
    """Stream the chatbot's answer to ``message``.

    A retrieval-augmented chain is built on each call. The module-level
    ``retriever`` fetches documents for the question, ``format_docs`` joins
    them into the prompt's ``{context}`` slot, the chat model selected in the
    sidebar (``models_dict[model]``) generates a reply, and
    ``StrOutputParser`` turns the reply into plain text.

    Args:
        message: The user's question. It is passed unchanged to both the
            retriever and the prompt's ``{question}`` slot.

    Yields:
        str: Text chunks of the reply, in the order the chain streams them.
    """
    llm = ChatNVIDIA(model=models_dict[model])
    prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            # Adjacent string literals are joined with no separator, so each
            # sentence carries its own trailing space.
            "You are a KCE chatbot, and you are assisting customers with their inquiries about the company. "
            "Answer the questions with the provided context. "
            "Do not include phrases such as 'based on the context' or 'based on the documents' in your answer. "
            "Remember that your job is to represent KCE company. "
            "Please say you do not know if you do not know or cannot find the information needed."
            "\n Question: {question} \nContext: {context}",
        ),
        ("user", "{question}"),
    ])

    # NOTE(review): this chain uses the plain ``retriever``. The module also
    # builds ``compression_retriever`` (FlashRank reranking), but nothing
    # visible here uses it. Confirm which retriever is intended.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )

    # Each streamed chunk holds only the newly generated text; forward it as-is.
    yield from rag_chain.stream(message)


# Conversation transcript for this browser session: a list of
# {"role": ..., "content": ...} dicts kept in session state across reruns.
st.session_state.setdefault("messages", [])

# The script re-runs on every interaction, so redraw the stored transcript
# before handling any new input.
for message in st.session_state.messages:
    st.chat_message(message["role"]).markdown(message["content"])

# Handle a newly submitted question, if any: record it, show it, then stream
# the assistant's reply into its own bubble and record the full reply text.
if prompt := st.chat_input("Please type your question here"):
    history = st.session_state.messages
    history.append({"role": "user", "content": prompt})
    st.chat_message("user").markdown(prompt)

    response = st.chat_message("assistant").write_stream(response_generator(prompt))
    history.append({"role": "assistant", "content": response})