gfhayworth committed on
Commit
4d7e790
1 Parent(s): 88f055d

Upload 2 files

Files changed (2)
  1. app.py +311 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,311 @@
# -*- coding: utf-8 -*-

#!pip install gradio
#!pip install -U sentence-transformers
#!pip install langchain
#!pip install openai
#!pip install -U chromadb

import gradio as gr
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain import LLMMathChain, SQLDatabase, SQLDatabaseChain, LLMChain
from langchain.agents import initialize_agent, Tool

# import sqlite3
import pandas as pd
import json

import chromadb
import os

# cxn = sqlite3.connect('./data/mbr.db')

"""# import models"""

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256  # truncate long passages to 256 tokens

# The bi-encoder retrieves the top_k documents; a cross-encoder could then re-rank that list to improve quality.
# cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

"""# setup vector db
- chromadb
- https://docs.trychroma.com/getting-started
"""

from chromadb.config import Settings

chroma_client = chromadb.Client(settings=Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="./data/mychromadb/"  # optional; defaults to .chromadb/ in the current directory
))

#!ls ./data/mychromadb/
# collection = chroma_client.create_collection(name="benefit_collection")
# note: the queries below pass query_embeddings explicitly, so this embedding_function is not actually invoked
collection = chroma_client.get_collection(name="healthy_opt_collection", embedding_function=bi_encoder)

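## The script assumes "healthy_opt_collection" already exists under ./data/mychromadb/.
## A minimal sketch of how it could have been built (the passage text, metadata values, and
## ids here are hypothetical; the 'url' and 'page' fields are the ones get_cot() reads later):
# collection = chroma_client.create_collection(name="healthy_opt_collection")
# passages = ["Your Healthy Options allowance covers approved grocery items such as fish."]
# metadata = [{"url": "https://www.humana.com/", "page": "1"}]
# collection.add(
#     documents=passages,
#     metadatas=metadata,
#     ids=["id0"],
#     embeddings=[bi_encoder.encode(p).tolist() for p in passages],
# )
# chroma_client.persist()  # write the parquet files so get_collection() can re-load them
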
"""### vector db search examples"""

def rtrv(qry, top_k=20):
    results = collection.query(
        query_embeddings=[bi_encoder.encode(qry)],
        n_results=top_k,
    )
    return results

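## For reference: chromadb's collection.query() returns parallel lists nested one level
## per query, shaped roughly like this (illustrative values only):
# {'ids': [['id0', 'id7']],
#  'documents': [['first passage text', 'second passage text']],
#  'metadatas': [[{'url': 'https://www.humana.com/', 'page': '1'},
#                 {'url': 'https://www.humana.com/', 'page': '2'}]],
#  'distances': [[0.31, 0.44]]}
## hence the [0] indexing in get_text_fmt() and the explode() in vdb_qry() below.
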
def vdb_qry(qry, top_k=10):
    results = collection.query(
        query_embeddings=[bi_encoder.encode(qry)],
        n_results=top_k,
        include=["metadatas", "documents", "distances", "embeddings"]
    )
    rslt_pd = pd.DataFrame(results).explode(['ids', 'documents', 'metadatas', 'distances', 'embeddings'])
    rslt_fmt = pd.concat([rslt_pd.drop(['metadatas'], axis=1), rslt_pd['metadatas'].apply(pd.Series)], axis=1)
    return rslt_fmt

# qry = 'what should I do with my old card'
# rslt_fmt = vdb_qry(qry, top_k=10)
# rslt_fmt

# doc_lst = rslt_fmt[['documents']].values.tolist()
# len(doc_lst)

## important if you want to save the data for re-use:
# chroma_client.persist()

"""# Introduction
- example of the kind of question answering that is possible with this tool
- assumes we are answering for a member with a Healthy Options Card

*When will I get my card?*

# semantic search functions
"""

## choosing not to use re-ranking for this use case

# def rernk(query, collection=collection, top_k=20, top_n=5):
#     rtrv_rslts = rtrv(query, top_k=top_k)
#     rtrv_ids = rtrv_rslts.get('ids')[0]
#     rtrv_docs = rtrv_rslts.get('documents')[0]
#
#     ##### Re-Ranking #####
#     cross_inp = [[query, doc] for doc in rtrv_docs]
#     cross_scores = cross_encoder.predict(cross_inp)
#
#     # Sort results by the cross-encoder scores
#     combined = list(zip(rtrv_ids, list(cross_scores)))
#     sorted_tuples = sorted(combined, key=lambda x: x[1], reverse=True)
#     sorted_ids = [t[0] for t in sorted_tuples[:top_n]]
#     predictions = collection.get(ids=sorted_ids, include=["documents", "metadatas"])
#     return predictions
#     # return cross_scores

def get_text_fmt(qry):
    """Retrieve the top passages for a query and wrap them as langchain Documents."""
    prediction_text = []
    predictions = rtrv(qry, top_k=5)
    docs = predictions['documents'][0]
    meta = predictions['metadatas'][0]
    for i in range(len(docs)):
        result = Document(page_content=docs[i], metadata=meta[i])
        prediction_text.append(result)
    return prediction_text

# get_text_fmt('can I buy fish?')

"""# LLM based qa functions"""

llm = OpenAI(temperature=0)  # assumes OPENAI_API_KEY is set in the environment
# default model
# model_name: str = "text-davinci-003"
# instruction fine-tuned, sometimes referred to as GPT-3.5

template = """You are a friendly AI assistant for the insurance company Humana.
Given the following extracted parts of a long document and a question, create a succinct final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
If the question is not about Humana or what you can buy with the card, politely inform the user that you are tuned to only answer questions about Humana Healthy Options.
QUESTION: {question}
=========
{summaries}
=========
FINAL ANSWER:"""
PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"])

# "stuff" packs all retrieved documents into the {summaries} slot of a single prompt
chain_qa = load_qa_with_sources_chain(llm=llm, chain_type="stuff", prompt=PROMPT, verbose=False)

def get_llm_response(message):
    mydocs = get_text_fmt(message)
    responses = chain_qa({"input_documents": mydocs, "question": message})
    return responses

# rslt = get_llm_response('can I buy shrimp?')
# rslt['output_text']

# for d in rslt['input_documents']:
#     print(d.page_content)
#     print(d.metadata['url'])

"""# Database query"""

# db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")

# db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, return_intermediate_steps=True)

# def db_qry(qry):
#     responses = db_chain('my mbr_id is 456 ;' + str(qry))  # mbr_id hardcoded to 456 for the demo
#     return responses

# r = db_qry('how many footcare visits have I had?')
# r['intermediate_steps']

"""# Math
- default version
"""

# llm_math_chain = LLMMathChain(llm=llm, verbose=True)

# llm_math_chain.run('what is the square root of 49?')

"""# Greeting"""

template = """You are an AI assistant for the insurance company Humana.
Your name is Jarvis and you were created on February 13, 2020.
Offer polite, friendly greetings and brief small talk.
Respond to thanks with, 'Glad to help.'
If the question is not about Humana, politely guide the user to ask questions about Humana Healthy Options benefits.
QUESTION: {question}
=========
FINAL ANSWER:"""
greet_prompt = PromptTemplate(template=template, input_variables=["question"])

greet_llm = LLMChain(prompt=greet_prompt, llm=llm, verbose=True)

# smoke tests; note that these run (and call the OpenAI API) when the script starts
greet_llm.run('will it snow in Louisville tomorrow')

greet_llm.run('Thanks, that was great')

"""# MRKL Chain"""

tools = [
    Tool(
        name="Benefit",
        func=get_llm_response,
        description='''Useful for confirming what items can be bought with the healthy options card.
        Useful for when you need to answer questions about the healthy options allowance.
        You should ask targeted questions.'''
    ),
    # Tool(
    #     name="Calculator",
    #     func=llm_math_chain.run,
    #     description="useful for when you need to answer questions about math"
    # ),
    # Tool(
    #     name="Member DB",
    #     func=db_qry,
    #     description='''useful for when you need to answer questions about member details such as their name, id, and accumulated use of services.
    #     This tool shows how much of a benefit has already been consumed.
    #     Input should be in the form of a question containing full context'''
    # ),
    Tool(
        name="Greeting",
        func=greet_llm.run,
        description="useful for when you need to respond to greetings, thanks, make small talk, or answer questions about yourself"
    ),
]

mrkl = initialize_agent(tools, llm, agent="zero-shot-react-description", verbose=False,
                        return_intermediate_steps=True, max_iterations=5, early_stopping_method="generate")

def mrkl_rspnd(qry):
    response = mrkl({"input": str(qry)})
    return response

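## With return_intermediate_steps=True the agent response is a dict shaped roughly like the
## sketch below (illustrative values; the observation is whatever the selected tool returned,
## so for the Benefit tool it is the full chain_qa output dict):
# {'input': 'can i buy fish with the card?',
#  'output': 'Yes, fish is an approved item.',
#  'intermediate_steps': [(AgentAction(tool='Benefit', tool_input='can I buy fish?',
#                                      log='Thought: I should check the benefit...'),
#                          {'input_documents': [...], 'question': '...', 'output_text': '...'})]}
## get_cot() below walks these (action, observation) pairs to render the chain of thought.
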
# r = mrkl_rspnd("can I buy fish with the card?")
# print(r['output'])

# print(json.dumps(r['intermediate_steps'], indent=2))

# r['intermediate_steps']

# r.keys()

# from IPython.core.display import display, HTML

def get_cot(r):
    """Render the agent's input, chain of thought, and observations as an HTML string."""
    cot = '<p>'
    try:
        intermedObj = r['intermediate_steps']
        cot += '<b>Input:</b> ' + r['input'] + '<br>'
        for agnt_action, obs in intermedObj:
            al = '<br> '.join(agnt_action.log.split('\n'))
            cot += '<b>AI chain of thought:</b> ' + al + '<br>'
            if isinstance(obs, dict):
                if obs.get('input_documents') is not None:  # NOTE: this check doesn't always work
                    for d in obs['input_documents']:
                        cot += '&nbsp;&nbsp;&nbsp;&nbsp;' + '<i>- ' + str(d.page_content) + '</i>' + ' <a href="' + str(d.metadata['url']) + '">' + str(d.metadata['page']) + '</a> ' + '<br>'
                    cot += '<b>Observation:</b> ' + str(obs['output_text']) + '<br><br>'
                elif obs.get('intermediate_steps') is not None:
                    cot += '<b>Query:</b> ' + str(obs.get('intermediate_steps')) + '<br><br>'
                else:
                    pass
            else:
                cot += '<b>Observation:</b> ' + str(obs) + '<br><br>'
    except Exception:
        pass
    cot += '</p>'
    return cot

# cot = get_cot(r)
# display(HTML(cot))

"""# chat example"""

def chat(message, history):
    history = history or []
    message = message.lower()

    response = mrkl_rspnd(message)
    cot = get_cot(response)
    history.append((message, response['output']))
    return history, history, cot

css = ".gradio-container {background-color: lightgray}"

xmpl_list = ["How do I activate my spending account card?",
             "Can I use my card for copays at the doctor?",
             "Can I get fish with this card?",
             "Can I buy vitamins?",
             "Can I use this card with Uber?"]

with gr.Blocks(css=css) as demo:
    history_state = gr.State()
    response_state = gr.State()
    gr.Markdown('# Hack QA')
    title = 'Benefit Chatbot'
    description = 'chatbot with search on Health Benefits'
    with gr.Row():
        chatbot = gr.Chatbot()
    # with gr.Row():
    with gr.Accordion(label='Show AI chain of thought: ', open=False):
        ai_cot = gr.HTML(show_label=False)
    with gr.Row():
        message = gr.Textbox(label='Input your question here:',
                             placeholder='What is the name of the plan described by this summary of benefits?',
                             lines=1)
        submit = gr.Button(value='Send',
                           variant='secondary').style(full_width=False)
    submit.click(chat,
                 inputs=[message, history_state],
                 outputs=[chatbot, history_state, ai_cot])
    gr.Examples(
        examples=xmpl_list,
        inputs=message
    )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
sentence-transformers==2.2.2
openai==0.27.0
gradio==3.19.1
langchain==0.0.100
chromadb==0.3.10