anandaa committed on
Commit cc1e9f0
•
1 Parent(s): 9bf0dc9

integrate model in app.py

Files changed (1):
  app.py +89 -5
app.py CHANGED
@@ -1,5 +1,12 @@
 import gradio as gr
-# from webapp import webapp
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM
+from langchain import PromptTemplate
+from langchain.llms import HuggingFacePipeline
+from langchain.chains.question_answering import load_qa_chain
+from langchain.memory import ConversationSummaryBufferMemory
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import Chroma
 
 def add_text(history, text):
     history = history + [[text, None]]
@@ -7,11 +14,91 @@ def add_text(history, text):
 
 def process_input(history):
     inp = history[-1][0]
-    response = "I have received your input, which is: \n" + inp
+    # response = "I have received your input, which is: \n" + inp
+    response = chat_bot.chat(inp)
     history[-1][1] = response
     return history
 
+def build_qa_chain():
+    torch.cuda.empty_cache()
+    # Define the prompt content.
+    # langchain will load our similar documents as {context}.
+    template = """You are a chatbot having a conversation with a human. You are asked to answer career questions, and you are helping the human apply for jobs.
+Given the following extracted parts of a long document and a question, answer the user's question. If you don't know, say that you do not know.
+
+{context}
+
+{chat_history}
+
+{human_input}
+
+Response:
+"""
+    prompt = PromptTemplate(input_variables=['context', 'human_input', 'chat_history'], template=template)
+
+    # Increase max_new_tokens for a longer response.
+    # Other settings might give better results! Play around.
+    model_name = "databricks/dolly-v2-3b"  # dolly-v2-3b, dolly-v2-7b or dolly-v2-12b; the smaller models give faster inference.
+    instruct_pipeline = pipeline(model=model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto",
+                                 return_full_text=True, max_new_tokens=256, top_p=0.95, top_k=50)
+    hf_pipe = HuggingFacePipeline(pipeline=instruct_pipeline)
+
+    # Add a summarizer to our conversation memory.
+    # Let's make sure we don't summarize the discussion too much, to avoid losing too much of the content.
+
+    # Model we'll use to summarize our chat history.
+    # We could use any of these models: https://huggingface.co/models?filter=summarization. facebook/bart-large-cnn gives great results; we'll use t5-small for memory.
+    summarize_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
+    summarize_tokenizer = AutoTokenizer.from_pretrained("t5-small", padding_side="left", model_max_length=512)
+    pipe_summary = pipeline("summarization", model=summarize_model, tokenizer=summarize_tokenizer)  # , max_new_tokens=500, min_new_tokens=300
+    # langchain's pipeline wrapper doesn't support summarization yet; we added a temp fix in the companion notebook _resources/00-init.
+    hf_summary = HuggingFacePipeline(pipeline=pipe_summary)
+    # Keep 500 tokens, then ask for a summary. Remove the prefixes, as our model isn't trained on a specific chat prefix and can get confused.
+    memory = ConversationSummaryBufferMemory(llm=hf_summary, memory_key="chat_history", input_key="human_input", max_token_limit=500, human_prefix="", ai_prefix="")
+
+    # Set verbose=True to see the full prompt:
+    print("loading chain, this can take some time...")
+    return load_qa_chain(llm=hf_pipe, chain_type="stuff", verbose=True, prompt=prompt, memory=memory)
+
+class ChatBot():
+    def __init__(self, db):
+        self.reset_context()
+        self.db = db
+
+    def reset_context(self):
+        self.sources = []
+        self.discussion = []
+        # Building the chain will load Dolly and can take some time depending on the model size and your GPU.
+        self.qa_chain = build_qa_chain()
+
+    def get_similar_docs(self, question, similar_doc_count):
+        return self.db.similarity_search(question, k=similar_doc_count)
+
+    def chat(self, question):
+        # Keep the last 3 discussion turns when searching for similar content.
+        self.discussion.append(question)
+        similar_docs = self.get_similar_docs(" \n".join(self.discussion[-3:]), similar_doc_count=2)
+        # Drop similar docs whose sources backed the last questions (they're already in the history).
+        similar_docs = [doc for doc in similar_docs if doc.metadata['source'] not in self.sources[-3:]]
+
+        result = self.qa_chain({"input_documents": similar_docs, "human_input": question})
+        # Clean up the answer for better display:
+        answer = result['output_text'].strip().capitalize()
+        result_html = f"<p><blockquote style=\"font-size:18px\">{answer}</blockquote></p>"
+        result_html += "<p><hr/></p>"
+        for d in result["input_documents"]:
+            source_id = d.metadata["source"]
+            self.sources.append(source_id)
+            result_html += f"<p>(Source: <a href=\"https://workplace.stackexchange.com/a/{source_id}\">{source_id}</a>)</p>"
+        return result_html
+
 with gr.Blocks() as demo:
+    global chat_bot
+    workplace_vector_db_path = "workplace_db"
+
+    hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+    chroma_db = Chroma(collection_name="workplace_docs", embedding_function=hf_embed, persist_directory=workplace_vector_db_path)
+    chat_bot = ChatBot(chroma_db)
     gr.Markdown('''
     ## **CareerPal**
     here to ease your anxiety about your future
@@ -31,6 +118,3 @@ with gr.Blocks() as demo:
     clear_btn.click(lambda: None, inputs=None, outputs=output_box, queue=False)
 
 demo.launch() # server_port=7860, show_api=False, share=False, inline=True) # , share = True, inline = True)
-
-# set FLASK_APP=app.py
-# flask run -h localhost -p 7860
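
A note on what the diff assumes: app.py only loads the Chroma collection; the workplace_db directory must already exist on disk, and nothing in this commit creates it. Below is a minimal sketch of how such a store could be built with the same embedding model and collection name. The example documents and their source ids (which ChatBot.chat turns into workplace.stackexchange links) are hypothetical; only the names and paths come from the diff.

    # Hypothetical ingestion script for the persisted store app.py expects.
    # Only "workplace_docs", "workplace_db", and the embedding model come from
    # the diff; the documents and their "source" metadata are made up here.
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.schema import Document
    from langchain.vectorstores import Chroma

    hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    docs = [
        Document(page_content="When negotiating a salary offer, anchor on market data...",
                 metadata={"source": "12345"}),  # hypothetical stackexchange answer id
        Document(page_content="A career switch is easier if you inventory transferable skills...",
                 metadata={"source": "67890"}),
    ]
    db = Chroma.from_documents(docs, embedding=hf_embed,
                               collection_name="workplace_docs",
                               persist_directory="workplace_db")
    db.persist()  # write the collection to disk so app.py can reload it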
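The second hunk's context jumps from the gr.Markdown banner (new line 102) to clear_btn.click (new line 118), so the event wiring between add_text and process_input is not shown in this diff. Given the [[text, None]] history format and the output_box/clear_btn names, the usual Gradio 3.x pattern would look roughly like this; every name except output_box and clear_btn is a guess at the unshown lines:

    # Sketch of the elided wiring, assuming output_box is a gr.Chatbot and
    # add_text / process_input are the functions defined in app.py above.
    import gradio as gr

    with gr.Blocks() as demo:
        output_box = gr.Chatbot()
        txt = gr.Textbox(show_label=False, placeholder="Ask a career question")
        clear_btn = gr.Button("Clear")

        # add_text records the user turn, then process_input fills in the reply.
        # .then() chains the second callback after the first (Gradio >= 3.25).
        txt.submit(add_text, inputs=[output_box, txt], outputs=output_box).then(
            process_input, inputs=output_box, outputs=output_box)
        clear_btn.click(lambda: None, inputs=None, outputs=output_box, queue=False)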
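Finally, the comment at new line 54 says langchain's pipeline wrapper doesn't support summarization and points to a temp fix in the companion notebook _resources/00-init, which is not part of this commit. One way such a shim could look, using langchain's custom-LLM extension point rather than whatever the notebook actually does (the class name and implementation below are ours, not the notebook's):

    from typing import Any, List, Optional

    from langchain.llms.base import LLM

    class SummarizationPipelineLLM(LLM):
        """Hypothetical shim: expose a transformers summarization pipeline as a langchain LLM."""

        pipeline: Any  # a transformers pipeline("summarization", ...)

        @property
        def _llm_type(self) -> str:
            return "summarization_pipeline"

        def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
            # Summarization pipelines return [{"summary_text": "..."}].
            return self.pipeline(prompt)[0]["summary_text"]

    # Drop-in replacement for the wrapper used in build_qa_chain():
    # hf_summary = SummarizationPipelineLLM(pipeline=pipe_summary)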