import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM
from langchain import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationSummaryBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

def add_text(history, text):
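    # Append the user's message to the chat history and temporarily disable the textbox while the answer is generated.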
    history = history + [[text, None]]
    return history, gr.update(value="", interactive=False)

def process_input(history):
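    # Generate the bot's answer for the latest user message; the context is reset on the first turn of a new conversation.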
    inp = history[-1][0]
    # response = "I have received your input, which is: \n" + inp
    if len(history) <= 2:
        chat_bot.reset_context()
    response = chat_bot.chat(inp)
    history[-1][1] = response
    return history

def build_qa_chain():
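  # Assemble the full QA chain: Dolly for generation, t5-small to summarize the chat history,
  # and a prompt that receives the retrieved documents as {context}.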
  torch.cuda.empty_cache()
  # Defining our prompt content.
  # langchain will load our similar documents as {context}
  template = """You are a chatbot having a conversation with a human. You are asked to answer career questions, and you are helping the human apply for jobs.
  Given the following extracted parts of a long document and a question, answer the user question. If you don't know, say that you do not know. 
  
  {context}
 
  {chat_history}
 
  {human_input}
 
  Response:
  """
  prompt = PromptTemplate(input_variables=['context', 'human_input', 'chat_history'], template=template)
 
  # Increase max_new_tokens for a longer response.
  # Other settings might give better results, so feel free to experiment.
  model_name = "databricks/dolly-v2-3b" # dolly-v2-3b is the smallest and fastest; dolly-v2-7b or dolly-v2-12b give better answers at the cost of slower inference and more GPU memory.
  instruct_pipeline = pipeline(model=model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", 
                               return_full_text=True, max_new_tokens=256, top_p=0.95, top_k=50)
  hf_pipe = HuggingFacePipeline(pipeline=instruct_pipeline)
 
  # Add a summarizer to our conversation memory.
  # We make sure we don't summarize the discussion too aggressively, to avoid losing too much of its content.

  # Model we'll use to summarize our chat history.
  # Any summarization model could work (https://huggingface.co/models?filter=summarization); facebook/bart-large-cnn gives great results, but we'll use t5-small for the memory summarization.
  summarize_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
  summarize_tokenizer = AutoTokenizer.from_pretrained("t5-small", padding_side="left", model_max_length = 512)
  pipe_summary = pipeline("summarization", model=summarize_model, tokenizer=summarize_tokenizer) #, max_new_tokens=500, min_new_tokens=300
  # langchain's HuggingFacePipeline doesn't support summarization yet; we added it as a temporary fix in the companion notebook _resources/00-init.
  hf_summary = HuggingFacePipeline(pipeline=pipe_summary)
  # Keep up to 500 tokens of history before asking for a summary. Remove the human/AI prefixes, as our model isn't trained on specific chat prefixes and can get confused.
  memory = ConversationSummaryBufferMemory(llm=hf_summary, memory_key="chat_history", input_key="human_input", max_token_limit=500, human_prefix = "", ai_prefix = "")
 
  # Set verbose=True to see the full prompt:
  print("loading chain, this can take some time...")
  return load_qa_chain(llm=hf_pipe, chain_type="stuff", verbose=True, prompt=prompt, memory=memory)

class ChatBot:
  def __init__(self, db):
    self.reset_context()
    self.db = db
 
  def reset_context(self):
    self.sources = []
    self.discussion = []
    # Building the chain will load Dolly and can take some time depending on the model size and your GPU
    self.qa_chain = build_qa_chain()
 
  def get_similar_docs(self, question, similar_doc_count):
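    # Return the documents from the vector store that are most similar to the question.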
    return self.db.similarity_search(question, k=similar_doc_count)
 
  def chat(self, question):
    # Use the last 3 turns of the discussion to search for similar content
    self.discussion.append(question)
    similar_docs = self.get_similar_docs(" \n".join(self.discussion[-3:]), similar_doc_count=2)
    # Drop similar docs that were already used for the last questions (they're already part of the history)
    similar_docs = [doc for doc in similar_docs if doc.metadata['source'] not in self.sources[-3:]]
 
    result = self.qa_chain({"input_documents": similar_docs, "human_input": question})
    # Clean up the answer for better display:
    answer = result['output_text'].strip().capitalize()
    result_html = f"<p><blockquote style=\"font-size:18px\">{answer}</blockquote></p>"
    result_html += "<p><hr/></p>"
    for d in result["input_documents"]:
      source_id = d.metadata["source"]
      self.sources.append(source_id)
      result_html += f"<p>(Source: <a href=\"https://workplace.stackexchange.com/a/{source_id}\">{source_id}</a>)</p>"
    return result_html
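
# A minimal usage sketch of ChatBot outside of Gradio (assumes the persisted "workplace_db" Chroma store
# used below already exists; the question is just an illustrative example):
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
#   db = Chroma(collection_name="workplace_docs", embedding_function=embeddings, persist_directory="workplace_db")
#   bot = ChatBot(db)
#   html_answer = bot.chat("How should I follow up after a job interview?")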

with gr.Blocks() as demo:
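    # Build the Gradio UI: a chat window, a textbox for the user's question, and a Clear button.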
    global chat_bot
    workplace_vector_db_path = "workplace_db"

    hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    chroma_db = Chroma(collection_name="workplace_docs", embedding_function=hf_embed, persist_directory=workplace_vector_db_path)
    chat_bot = ChatBot(chroma_db)
    with gr.Row():
        output_box = gr.Chatbot([[None, "Welcome! What can I help you with today?"]], show_label=False).style(height=450)
    with gr.Row(): # TODO: Box or Group instead of row?
        with gr.Column(scale=7):
            input_box = gr.Textbox(show_label=False, placeholder="Ask something here and press enter...").style(container=False)
        with gr.Column(scale=1):
            clear_btn = gr.Button(value="Clear")
    
    txt_msg = input_box.submit(add_text, inputs=[output_box, input_box], outputs=[output_box, input_box],
                               queue=False).then(process_input, output_box, output_box)
    txt_msg.then(lambda: gr.update(interactive=True), inputs=None, outputs=input_box, queue=False)

    clear_btn.click(lambda: None, inputs=None, outputs=output_box, queue=False)

demo.launch()  # e.g. demo.launch(server_port=7860, show_api=False, share=False, inline=True)
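
# The app assumes the "workplace_db" Chroma collection was already built and persisted (e.g. by a separate
# ingestion script or notebook). A minimal sketch of how it could be created with the same embedding model;
# the document contents and source ids below are hypothetical placeholders:
#
#   from langchain.docstore.document import Document
#
#   docs = [Document(page_content="Answer text from workplace.stackexchange.com...", metadata={"source": "12345"})]
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
#   db = Chroma.from_documents(docs, embedding=embeddings, collection_name="workplace_docs",
#                              persist_directory="workplace_db")
#   db.persist()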