gfhayworth commited on
Commit
6970ab5
1 Parent(s): 8dd4ecb

Upload app.py

Browse files

replace to use LLM and new data sources

Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """wiki_chat_3_hack.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1chXsWeq1LzbvYIs6H73gibYmNDRbIgkD
8
+ """
9
+
10
+ #!pip install gradio
11
+
12
+ #!pip install -U sentence-transformers
13
+
14
+ #!pip install datasets
15
+
16
+ #!pip install langchain
17
+
18
+ #!pip install openai
19
+
20
+ #!pip install faiss-cpu
21
+
22
+ #import numpy as np
23
+ import gradio as gr
24
+ #import random
25
+ from sentence_transformers import SentenceTransformer, CrossEncoder, util
26
+ from torch import tensor as torch_tensor
27
+ from datasets import load_dataset
28
+
29
+ """# import models"""
30
+
31
+ bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
32
+ bi_encoder.max_seq_length = 256 #Truncate long passages to 256 tokens
33
+
34
+ #The bi-encoder will retrieve top_k documents. We use a cross-encoder, to re-rank the results list to improve the quality
35
+ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
36
+
37
+ """# import datasets"""
38
+
39
+ dataset = load_dataset("gfhayworth/hack_policy", split='train')
40
+ mypassages = list(dataset.to_pandas()['psg'])
41
+
42
+ dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
43
+ dataset_embed_pd = dataset_embed.to_pandas()
44
+ mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
45
+
46
+ def search(query, passages = mypassages, doc_embedding = mycorpus_embeddings, top_k=20, top_n = 1):
47
+ question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
48
+ question_embedding = question_embedding #.cuda()
49
+ hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)
50
+ hits = hits[0] # Get the hits for the first query
51
+
52
+ ##### Re-Ranking #####
53
+ cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
54
+ cross_scores = cross_encoder.predict(cross_inp)
55
+
56
+ # Sort results by the cross-encoder scores
57
+ for idx in range(len(cross_scores)):
58
+ hits[idx]['cross-score'] = cross_scores[idx]
59
+
60
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
61
+ predictions = hits[:top_n]
62
+ return predictions
63
+ # for hit in hits[0:3]:
64
+ # print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
65
+
66
+ def get_text(qry):
67
+ predictions = search(qry)
68
+ prediction_text = []
69
+ for hit in predictions:
70
+ prediction_text.append("{}".format(mypassages[hit['corpus_id']]))
71
+ return prediction_text
72
+
73
+ # def prt_rslt(qry):
74
+ # rslt = get_text(qry)
75
+ # for r in rslt:
76
+ # print(r)
77
+
78
+ # prt_rslt("What is the name of the plan described by this summary of benefits?")
79
+
80
+ """# new LLM based functions"""
81
+
82
+ import os
83
+ os.environ["OPENAI_API_KEY"] = "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"
84
+
85
+ from langchain.llms import OpenAI
86
+ from langchain.embeddings.openai import OpenAIEmbeddings
87
+ from langchain.embeddings import HuggingFaceEmbeddings
88
+
89
+ from langchain.text_splitter import CharacterTextSplitter
90
+ #from langchain.vectorstores.faiss import FAISS
91
+ from langchain.docstore.document import Document
92
+ from langchain.prompts import PromptTemplate
93
+ from langchain.chains.question_answering import load_qa_chain
94
+ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
95
+ from langchain.chains import VectorDBQAWithSourcesChain
96
+
97
+ chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff")
98
+
99
+ def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
100
+ predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
101
+ prediction_text = []
102
+ for hit in predictions:
103
+ page_content = passages[hit['corpus_id']]
104
+ metadata = {"source": hit['corpus_id']}
105
+ result = Document(page_content=page_content, metadata=metadata)
106
+ prediction_text.append(result)
107
+ return prediction_text
108
+
109
+ #mypassages[0]
110
+
111
+ #mycorpus_embeddings[0][:5]
112
+
113
+ # query = "What is the name of the plan described by this summary of benefits?"
114
+ # mydocs = get_text_fmt(query)
115
+ # print(len(mydocs))
116
+ # for d in mydocs:
117
+ # print(d)
118
+
119
+ # chain_qa.run(input_documents=mydocs, question=query)
120
+
121
+ def get_llm_response(message):
122
+ mydocs = get_text_fmt(message)
123
+ responses = chain_qa.run(input_documents=mydocs, question=message)
124
+ return responses
125
+
126
+ """# chat example"""
127
+
128
+ def chat(message, history):
129
+ history = history or []
130
+ message = message.lower()
131
+
132
+ response = get_llm_response(message)
133
+ history.append((message, response))
134
+ return history, history
135
+
136
+ css=".gradio-container {background-color: lightgray}"
137
+
138
+ with gr.Blocks(css=css) as demo:
139
+ history_state = gr.State()
140
+ gr.Markdown('# Hack QA')
141
+ title='Benefit Chatbot'
142
+ description='chatbot with search on Health Benefits'
143
+ with gr.Row():
144
+ chatbot = gr.Chatbot()
145
+ with gr.Row():
146
+ message = gr.Textbox(label='Input your question here:',
147
+ placeholder='What is the name of the plan described by this summary of benefits?',
148
+ lines=1)
149
+ submit = gr.Button(value='Send',
150
+ variant='secondary').style(full_width=False)
151
+ submit.click(chat,
152
+ inputs=[message, history_state],
153
+ outputs=[chatbot, history_state])
154
+ gr.Examples(
155
+ examples=["What is the name of the plan described by this summary of benefits?",
156
+ "How much is the monthly premium?",
157
+ "How much do I have to pay if I am admitted to the hospital?"],
158
+ inputs=message
159
+ )
160
+
161
+ demo.launch()
162
+