gfhayworth committed
Commit 8a2c414
1 Parent(s): 094be94

Update greg_funcs.py


Change data acquisition to use a vector DB (Chroma).
Show the agent's chain of thought in the chat output.
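
Note: the diff below only reads from an existing Chroma collection (chroma_client.get_collection(name="benefit_collection")); building and persisting that collection happens in a separate ingest step that is not part of this commit. What follows is a minimal sketch of such a step, mirroring the duckdb+parquet settings and the "source" metadata key used by the query code; the bi-encoder model, passage text, and ids are assumptions, not taken from this repo.

# Hypothetical one-time ingest for the vector db queried in greg_funcs.py.
# Settings, collection name, and the "source" metadata key mirror the query code;
# the bi-encoder model, passage text, and ids below are placeholders.
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')  # assumed model

chroma_client = chromadb.Client(settings=Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory="./data/mychromadb/",
))
collection = chroma_client.get_or_create_collection(name="benefit_collection")

passages = ["<benefit passage text>"]  # placeholder documents
collection.add(
    documents=passages,
    embeddings=[bi_encoder.encode(p).tolist() for p in passages],
    metadatas=[{"source": "H1036236000SB23.pdf"}] * len(passages),
    ids=[f"psg_{i}" for i in range(len(passages))],
)
chroma_client.persist()  # write the duckdb+parquet files to disk

A second sketch, after the diff, shows how the new third return value of chat() could be surfaced in the UI.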

Files changed (1)
  1. greg_funcs.py +93 -66
greg_funcs.py CHANGED
@@ -1,6 +1,6 @@
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
-from torch import tensor as torch_tensor
-from datasets import load_dataset
+#from torch import tensor as torch_tensor
+#from datasets import load_dataset

 from langchain.llms import OpenAI
 from langchain.docstore.document import Document
@@ -13,6 +13,7 @@ from langchain.agents import initialize_agent, Tool
 import sqlite3
 #import pandas as pd
 import json
+import chromadb

 # database
 cxn = sqlite3.connect('./data/mbr.db')
@@ -29,81 +30,82 @@ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

 """# import datasets"""

-dataset = load_dataset("gfhayworth/hack_policy", split='train')
-mypassages = list(dataset.to_pandas()['psg'])
-
-dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
-dataset_embed_pd = dataset_embed.to_pandas()
-mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
-
-def search(query, passages = mypassages, doc_embedding = mycorpus_embeddings, top_k=20, top_n = 1):
-    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
-    question_embedding = question_embedding #.cuda()
-    hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)
-    hits = hits[0]  # Get the hits for the first query
-
-    ##### Re-Ranking #####
-    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
-    cross_scores = cross_encoder.predict(cross_inp)
-
-    # Sort results by the cross-encoder scores
-    for idx in range(len(cross_scores)):
-        hits[idx]['cross-score'] = cross_scores[idx]
-
-    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-    predictions = hits[:top_n]
-    return predictions
-    # for hit in hits[0:3]:
-    #     print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
-
-
-
-def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
-    predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
-    prediction_text = []
-    for hit in predictions:
-        page_content = passages[hit['corpus_id']]
-        metadata = {"source": hit['corpus_id']}
-        result = Document(page_content=page_content, metadata=metadata)
-        prediction_text.append(result)
-    return prediction_text
+# dataset = load_dataset("gfhayworth/hack_policy", split='train')
+# mypassages = list(dataset.to_pandas()['psg'])
+
+# dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
+# dataset_embed_pd = dataset_embed.to_pandas()
+# mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
+###########################################################################################################################
+"""# set up vector db"""
+from chromadb.config import Settings
+
+chroma_client = chromadb.Client(settings=Settings(
+    chroma_db_impl="duckdb+parquet",
+    persist_directory="./data/mychromadb/"  # Optional, defaults to .chromadb/ in the current directory
+))
+collection = chroma_client.get_collection(name="benefit_collection")
+
+def vdb_rslt(qry, src, top_k=20):
+    results = collection.query(
+        query_embeddings=[ bi_encoder.encode(qry) ],
+        n_results=top_k,
+        where={"source": src},
+    )
+    return results
+##################################################################################################################################
+# Semantic Search Functions
+def rtrv(qry, src = 'H1036236000SB23.pdf', top_k=20):
+    rslts = vdb_rslt(qry, src)
+    return rslts
+
+def rernk(query, collection=collection, top_k=20, top_n = 5):
+    rtrv_rslts = rtrv(query, top_k=top_k)
+    rtrv_ids = rtrv_rslts.get('ids')[0]
+    rtrv_docs = rtrv_rslts.get('documents')[0]
+
+    ##### Re-Ranking #####
+    cross_inp = [[query, doc] for doc in rtrv_docs]
+    cross_scores = cross_encoder.predict(cross_inp)
+
+    # Sort results by the cross-encoder scores
+    combined = list(zip(rtrv_ids, list(cross_scores)))
+    sorted_tuples = sorted(combined, key=lambda x: x[1], reverse=True)
+    sorted_ids = [t[0] for t in sorted_tuples[:top_n]]
+    predictions = collection.get(ids=sorted_ids, include=["documents","metadatas"])
+    return predictions
+
+def get_text_fmt(qry):
+    prediction_text = []
+    predictions = rernk(qry, collection=collection, top_k=20, top_n = 5)
+    docs = predictions['documents']
+    meta = predictions['metadatas']
+    for i in range(len(docs)):
+        result = Document(page_content=docs[i], metadata=meta[i])
+        prediction_text.append(result)
+    return prediction_text

+##################################################################################################################################
 """# LLM based qa functions"""

-template = """You are a friendly AI assistant for the insurance company Humana. Given the following extracted parts of a long document and a question, create a succinct final answer.
+template = """You are a friendly AI assistant for the insurance company Humana.
+Given the following extracted parts of a long document and a question, create a succinct final answer.
 If you don't know the answer, just say that you don't know. Don't try to make up an answer.
-If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana benefits.
+If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana.
 QUESTION: {question}
 =========
-{context}
+{summaries}
 =========
 FINAL ANSWER:"""
-PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])
+PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"])

-chain_qa = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT)
-
-def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
-    predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
-    prediction_text = []
-    for hit in predictions:
-        page_content = passages[hit['corpus_id']]
-        metadata = {"source": hit['corpus_id']}
-        result = Document(page_content=page_content, metadata=metadata)
-        prediction_text.append(result)
-    return prediction_text
+chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT, verbose=True)

 def get_llm_response(message):
     mydocs = get_text_fmt(message)
-    responses = chain_qa.run(input_documents=mydocs, question=message)
+    responses = chain_qa({"input_documents":mydocs, "question":message})
     return responses

-# for x in xmpl_list:
-#     print(32*'=')
-#     print(x)
-#     print(32*'=')
-#     r = get_llm_response(x)
-#     print(r)
-
 """# Database query"""

 db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")
@@ -113,10 +115,10 @@ llm = OpenAI(temperature=0)
 # model_name: str = "text-davinci-003"
 # instruction fine-tuned, sometimes referred to as GPT-3.5

-db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)
+db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, return_intermediate_steps=True)

 def db_qry(qry):
-    responses = db_chain.run(query='my mbr_id is 456 ;'+str(qry) ) ############### hardcode mbr id 456 for demo
+    responses = db_chain('my mbr_id is 456 ;'+str(qry) ) ############### hardcode mbr id 456 for demo
     return responses

 #db_qry('how many footcare visits have I had?')
@@ -178,13 +180,38 @@ def mrkl_rspnd(qry):
     response = mrkl({"input":str(qry) })
     return response

+def get_cot(r):
+    cot = '<p>'
+    try:
+        intermedObj = r['intermediate_steps']
+        cot += '<b>Input:</b> '+r['input']+'<br>'
+        for agnt_action, obs in intermedObj:
+            al = '<br> '.join(agnt_action.log.split('\n') )
+            cot += '<b>AI chain of thought:</b> '+ al +'<br>'
+            if type(obs) is dict:
+                if obs.get('input_documents') is not None:
+                    for d in obs['input_documents']:
+                        cot += '&nbsp;&nbsp;&nbsp;&nbsp;'+'<i>- '+str(d.page_content)+'</i>'+' <a href="'+ str(d.metadata['url']) +'">'+str(d.metadata['page'])+'</a> '+'<br>'
+                    cot += '<b>Observation:</b> '+str(obs['output_text']) +'<br><br>'
+                elif obs.get('intermediate_steps') is not None:
+                    cot += '<b>Query:</b> '+str(obs.get('intermediate_steps')) +'<br><br>'
+                else:
+                    pass
+            else:
+                cot += '<b>Observation:</b> '+str(obs) +'<br><br>'
+    except:
+        pass
+    cot += '</p>'
+    return cot
+
 def chat(message, history):
     history = history or []
     message = message.lower()

     response = mrkl_rspnd(message)
+    cot = get_cot(response)
     history.append((message, response['output']))
-    return history, history
+    return history, history, cot

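
The new chat() returns a third value, the chain-of-thought HTML built by get_cot(), so the calling UI needs a matching output component. The app file is not part of this commit; the wiring below is a sketch under the assumption of a Gradio Blocks app, with every component name hypothetical.

# Hypothetical Gradio wiring for chat()'s new third return value.
# The actual app file is not in this commit; component names are assumed.
import gradio as gr
from greg_funcs import chat

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()   # renders the (message, output) history tuples
    cot_html = gr.HTML()     # renders the <p>...</p> string built by get_cot()
    state = gr.State([])     # persistent copy of the history
    msg = gr.Textbox()
    # chat() returns (history, history, cot): one history for the Chatbot display,
    # one for the State, and the chain-of-thought HTML for the new panel.
    msg.submit(chat, inputs=[msg, state], outputs=[chatbot, state, cot_html])

demo.launch()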