File size: 2,496 Bytes
b5f78f9
 
 
 
 
 
a89ac38
b5f78f9
a89ac38
b5f78f9
3a583fb
b5f78f9
 
 
 
a89ac38
b5f78f9
 
 
a89ac38
 
 
 
 
 
7672bfa
96fbc4e
b5f78f9
dd1a630
b5f78f9
 
 
a89ac38
e5cc6c5
dd1a630
 
e5cc6c5
b5f78f9
 
 
e5cc6c5
 
 
 
 
 
 
 
 
 
 
 
 
b5f78f9
 
a89ac38
 
b5f78f9
 
 
 
 
 
dd1a630
b5f78f9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
import os

from langchain import OpenAI, ConversationChain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains import RetrievalQAWithSourcesChain

from langchain.chains.conversation.memory import ConversationEntityMemory
from langchain.chains.conversation.prompt import ENTITY_MEMORY_CONVERSATION_TEMPLATE

from langchain import LLMChain

# Conversation memory attached to the QA chain. The chain takes a "question"
# input and produces two outputs ("answer" and "sources"), so the memory must be
# told which keys to record; without output_key, saving the exchange raises
# "One output key expected".
# NOTE(review): the default "stuff" QA-with-sources prompt has no {chat_history}
# slot, so this history is stored but not fed back to the LLM — confirm intent.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
)

# Directory of the persisted Chroma vector store to query.
persist_directory = "db"
# temperature=0 keeps answers as deterministic as the model allows.
llm = OpenAI(model_name="text-davinci-003", temperature=0)
model_name = "hkunlp/instructor-large"
# Instructor embeddings take task instructions for documents and for queries.
embed_instruction = "Represent the text from the BMW website for retrieval"
query_instruction = "Query the most relevant text from the BMW website"
embeddings = HuggingFaceInstructEmbeddings(
    model_name=model_name,
    embed_instruction=embed_instruction,
    query_instruction=query_instruction,
)
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
chain = RetrievalQAWithSourcesChain.from_chain_type(
    llm,
    chain_type="stuff",
    retriever=vectordb.as_retriever(),
    memory=memory,
)


def chat(message, history):
    """Answer one user message and append the exchange to the chat history.

    Args:
        message: The user's question from the textbox.
        history: List of ``(user_message, bot_markdown)`` pairs kept in a
            ``gr.State``; ``None`` (or empty) on the first call.

    Returns:
        The updated history twice: once for the ``Chatbot`` display and once
        to be stored back into the ``State``.
    """
    history = history or []
    try:
        # RetrievalQAWithSourcesChain has two output keys ("answer" and
        # "sources"), so ``chain.run`` is not supported; call the chain with
        # its "question" input key and receive the full output dict.
        response = chain({"question": message})
        markdown = generate_markdown(response)
    except Exception as e:
        # UI callback boundary: log the failure and show a message in the
        # chat instead of leaving ``markdown`` unbound and crashing.
        print(f"Error: {e}")
        markdown = "Sorry, something went wrong while answering that question."
    history.append((message, markdown))

    return history, history

def generate_markdown(obj):
    """Render a QA-with-sources chain result as Markdown for the chat window.

    Args:
        obj: The chain output. When it is a mapping, its optional ``'answer'``
            value is shown under an "Answer" heading and its optional
            ``'sources'`` value (newline-separated) is rendered as a numbered
            list; blank source lines are skipped. Any other non-``None`` value
            is shown verbatim via ``str()``.

    Returns:
        The Markdown string; ``""`` when there is nothing to show.
    """
    from collections.abc import Mapping

    if not isinstance(obj, Mapping):
        # A plain value (e.g. a string answer): show it as-is rather than
        # substring-testing it for 'answer'/'sources'.
        return "" if obj is None else str(obj)

    sections = []

    if 'answer' in obj:
        sections.append(f"**Answer:**\n\n{obj['answer']}\n")

    raw_sources = obj.get('sources') or ""
    sources = [line.strip() for line in str(raw_sources).splitlines() if line.strip()]
    if sources:
        numbered = "".join(f"{i}. {src}\n" for i, src in enumerate(sources, start=1))
        sections.append(f"**Sources:**\n\n{numbered}")

    # Join with an extra newline so a blank line separates the answer
    # paragraph from the "Sources" heading; otherwise Markdown merges them.
    return "\n".join(sections)

# Gradio UI: a chat transcript plus a question box and a "Run" button.
with gr.Blocks() as demo:
    gr.Markdown("<h3><center>BMW Chat Bot</center></h3>")
    gr.Markdown("<p><center>Ask questions about BMW</center></p>")
    chatbot = gr.Chatbot()
    with gr.Row():
        inp = gr.Textbox(placeholder="Question", label=None)
        btn = gr.Button("Run").style(full_width=False)
        # Per-session history of (question, answer_markdown) pairs that
        # ``chat`` reads and returns; starts as None.
        state = gr.State()
        # Clicking "Run" and pressing Enter in the textbox both ask the question.
        btn.click(chat, [inp, state], [chatbot, state])
        inp.submit(chat, [inp, state], [chatbot, state])

if __name__ == '__main__':
    demo.launch()