Spaces:
Runtime error
Runtime error
gfhayworth
committed on
Commit
•
8a2c414
1
Parent(s):
094be94
Update greg_funcs.py
Browse files
change data acquisition to use vector db
show chain of thought in the chat output
- greg_funcs.py +93 -66
greg_funcs.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
2 |
-
from torch import tensor as torch_tensor
|
3 |
-
from datasets import load_dataset
|
4 |
|
5 |
from langchain.llms import OpenAI
|
6 |
from langchain.docstore.document import Document
|
@@ -13,6 +13,7 @@ from langchain.agents import initialize_agent, Tool
|
|
13 |
import sqlite3
|
14 |
#import pandas as pd
|
15 |
import json
|
|
|
16 |
|
17 |
# database
|
18 |
cxn = sqlite3.connect('./data/mbr.db')
|
@@ -29,81 +30,82 @@ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
|
29 |
|
30 |
"""# import datasets"""
|
31 |
|
32 |
-
dataset = load_dataset("gfhayworth/hack_policy", split='train')
|
33 |
-
mypassages = list(dataset.to_pandas()['psg'])
|
34 |
-
|
35 |
-
dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
|
36 |
-
dataset_embed_pd = dataset_embed.to_pandas()
|
37 |
-
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
def
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
|
|
71 |
"""# LLM based qa functions"""
|
72 |
|
73 |
-
template = """You are a friendly AI assistant for the insurance company Humana.
|
|
|
74 |
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
|
75 |
-
If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana
|
76 |
QUESTION: {question}
|
77 |
=========
|
78 |
-
{
|
79 |
=========
|
80 |
FINAL ANSWER:"""
|
81 |
-
PROMPT = PromptTemplate(template=template, input_variables=["
|
82 |
|
83 |
-
chain_qa =
|
84 |
-
|
85 |
-
def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
|
86 |
-
predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
|
87 |
-
prediction_text = []
|
88 |
-
for hit in predictions:
|
89 |
-
page_content = passages[hit['corpus_id']]
|
90 |
-
metadata = {"source": hit['corpus_id']}
|
91 |
-
result = Document(page_content=page_content, metadata=metadata)
|
92 |
-
prediction_text.append(result)
|
93 |
-
return prediction_text
|
94 |
|
95 |
def get_llm_response(message):
|
96 |
mydocs = get_text_fmt(message)
|
97 |
-
responses = chain_qa
|
98 |
return responses
|
99 |
|
100 |
-
# for x in xmpl_list:
|
101 |
-
# print(32*'=')
|
102 |
-
# print(x)
|
103 |
-
# print(32*'=')
|
104 |
-
# r = get_llm_response(x)
|
105 |
-
# print(r)
|
106 |
-
|
107 |
"""# Database query"""
|
108 |
|
109 |
db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")
|
@@ -113,10 +115,10 @@ llm = OpenAI(temperature=0)
|
|
113 |
# model_name: str = "text-davinci-003"
|
114 |
# instruction fine-tuned, sometimes referred to as GPT-3.5
|
115 |
|
116 |
-
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)
|
117 |
|
118 |
def db_qry(qry):
|
119 |
-
responses = db_chain
|
120 |
return responses
|
121 |
|
122 |
#db_qry('how many footcare visits have I had?')
|
@@ -178,13 +180,38 @@ def mrkl_rspnd(qry):
|
|
178 |
response = mrkl({"input":str(qry) })
|
179 |
return response
|
180 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
def chat(message, history):
|
182 |
history = history or []
|
183 |
message = message.lower()
|
184 |
|
185 |
response = mrkl_rspnd(message)
|
|
|
186 |
history.append((message, response['output']))
|
187 |
-
return history, history
|
188 |
|
189 |
|
190 |
|
|
|
1 |
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
2 |
+
#from torch import tensor as torch_tensor
|
3 |
+
#from datasets import load_dataset
|
4 |
|
5 |
from langchain.llms import OpenAI
|
6 |
from langchain.docstore.document import Document
|
|
|
13 |
import sqlite3
|
14 |
#import pandas as pd
|
15 |
import json
|
16 |
+
import chromadb
|
17 |
|
18 |
# database
|
19 |
cxn = sqlite3.connect('./data/mbr.db')
|
|
|
30 |
|
31 |
"""# import datasets"""
|
32 |
|
33 |
+
# dataset = load_dataset("gfhayworth/hack_policy", split='train')
|
34 |
+
# mypassages = list(dataset.to_pandas()['psg'])
|
35 |
+
|
36 |
+
# dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
|
37 |
+
# dataset_embed_pd = dataset_embed.to_pandas()
|
38 |
+
# mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
|
39 |
+
###########################################################################################################################
|
40 |
+
"""# set up vector db"""
|
41 |
+
from chromadb.config import Settings
|
42 |
+
|
43 |
+
chroma_client = chromadb.Client(settings=Settings(
|
44 |
+
chroma_db_impl="duckdb+parquet",
|
45 |
+
persist_directory="./data/mychromadb/" # Optional, defaults to .chromadb/ in the current directory
|
46 |
+
))
|
47 |
+
collection = chroma_client.get_collection(name="benefit_collection")
|
48 |
+
|
49 |
+
def vdb_rslt(qry, src, top_k=20):
    """Query the vector collection for entries similar to ``qry``.

    Args:
        qry: Query text; it is encoded with the module-level ``bi_encoder``.
        src: Value the ``source`` metadata field must equal.
        top_k: Number of results requested from the collection.

    Returns:
        Whatever ``collection.query`` returns for the single encoded query.
    """
    query_vector = bi_encoder.encode(qry)
    source_filter = {"source": src}
    return collection.query(
        query_embeddings=[query_vector],
        n_results=top_k,
        where=source_filter,
    )
|
56 |
+
##################################################################################################################################
|
57 |
+
# Semantic Search Functions
|
58 |
+
def rtrv(qry, src='H1036236000SB23.pdf', top_k=20):
    """Retrieve candidate passages for ``qry`` from one source document.

    Args:
        qry: Query text.
        src: ``source`` metadata value to restrict the search to.
        top_k: Number of candidates to request.

    Returns:
        The result of ``vdb_rslt`` for the given query, source and ``top_k``.
    """
    # Forward top_k explicitly; previously it was accepted but dropped, so
    # every call silently used vdb_rslt's own default.
    return vdb_rslt(qry, src, top_k=top_k)
|
61 |
+
|
62 |
+
def _order_by_ids(result, ordered_ids):
    """Return a copy of ``result`` whose parallel lists follow ``ordered_ids``."""
    returned_ids = result.get('ids') or []
    position = {doc_id: idx for idx, doc_id in enumerate(returned_ids)}
    picks = [position[doc_id] for doc_id in ordered_ids if doc_id in position]
    reordered = dict(result)
    # Only the fields this module requests are rearranged; anything else the
    # client returns is passed through untouched.
    for key in ('ids', 'documents', 'metadatas'):
        values = result.get(key)
        if values is not None:
            reordered[key] = [values[idx] for idx in picks]
    return reordered


def rernk(query, collection=collection, top_k=20, top_n = 5):
    """Retrieve candidates for ``query`` and re-rank them with the cross-encoder.

    Args:
        query: Query text.
        collection: Collection the final records are fetched from.
        top_k: Number of candidates requested from ``rtrv``.
        top_n: Number of best-scoring candidates to keep.

    Returns:
        A mapping with ``ids``, ``documents`` and ``metadatas`` for the
        ``top_n`` best candidates, ordered from highest to lowest
        cross-encoder score. Empty lists are returned when retrieval finds
        nothing.
    """
    rtrv_rslts = rtrv(query, top_k=top_k)
    rtrv_ids = rtrv_rslts.get('ids')[0]
    rtrv_docs = rtrv_rslts.get('documents')[0]

    # Nothing retrieved: skip the cross-encoder and avoid calling get() with
    # an empty id list.
    if not rtrv_ids:
        return {'ids': [], 'documents': [], 'metadatas': []}

    ##### Re-Ranking #####
    cross_scores = cross_encoder.predict([[query, doc] for doc in rtrv_docs])

    # Sort candidate ids by cross-encoder score, best first.
    ranked = sorted(zip(rtrv_ids, cross_scores), key=lambda pair: pair[1], reverse=True)
    sorted_ids = [doc_id for doc_id, _ in ranked[:top_n]]

    predictions = collection.get(ids=sorted_ids, include=["documents", "metadatas"])
    # NOTE(review): get() is not assumed to preserve the order of ``ids``;
    # re-apply the ranking so callers actually see best-first results.
    return _order_by_ids(predictions, sorted_ids)
|
77 |
+
|
78 |
+
def get_text_fmt(qry):
    """Return the re-ranked passages for ``qry`` as ``Document`` objects.

    Args:
        qry: Query text.

    Returns:
        A list of ``Document`` instances, one per entry returned by
        ``rernk`` (top 5 of 20 candidates), each carrying the passage text as
        ``page_content`` and its metadata entry as ``metadata``.
    """
    predictions = rernk(qry, collection=collection, top_k=20, top_n=5)
    # documents and metadatas are parallel lists; pair them directly instead
    # of indexing both by position.
    return [
        Document(page_content=text, metadata=meta)
        for text, meta in zip(predictions['documents'], predictions['metadatas'])
    ]
|
87 |
|
88 |
+
##################################################################################################################################
|
89 |
"""# LLM based qa functions"""
|
90 |
|
91 |
+
template = """You are a friendly AI assistant for the insurance company Humana.
|
92 |
+
Given the following extracted parts of a long document and a question, create a succinct final answer.
|
93 |
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
|
94 |
+
If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana.
|
95 |
QUESTION: {question}
|
96 |
=========
|
97 |
+
{summaries}
|
98 |
=========
|
99 |
FINAL ANSWER:"""
|
100 |
+
PROMPT = PromptTemplate(template=template, input_variables=["summaries", "question"])
|
101 |
|
102 |
+
chain_qa = load_qa_with_sources_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT, verbose=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
def get_llm_response(message):
    """Answer ``message`` with the QA chain over the passages retrieved for it.

    Args:
        message: The user's question.

    Returns:
        The result of calling ``chain_qa`` with the retrieved documents and
        the question.
    """
    chain_inputs = {
        "input_documents": get_text_fmt(message),
        "question": message,
    }
    return chain_qa(chain_inputs)
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
"""# Database query"""
|
110 |
|
111 |
db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")
|
|
|
115 |
# model_name: str = "text-davinci-003"
|
116 |
# instruction fine-tuned, sometimes referred to as GPT-3.5
|
117 |
|
118 |
+
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, return_intermediate_steps=True)
|
119 |
|
120 |
def db_qry(qry, mbr_id=456):
    """Run a natural-language question against the member database chain.

    Args:
        qry: The question; it is converted with ``str``.
        mbr_id: Member id stated to the chain ahead of the question.
            Defaults to 456, the id previously hard-coded for the demo.

    Returns:
        The result of calling ``db_chain`` with the prefixed question.
    """
    # The member id is supplied to the chain as part of the prompt text.
    prompt = 'my mbr_id is ' + str(mbr_id) + ' ;' + str(qry)
    return db_chain(prompt)
|
123 |
|
124 |
#db_qry('how many footcare visits have I had?')
|
|
|
180 |
response = mrkl({"input":str(qry) })
|
181 |
return response
|
182 |
|
183 |
+
def get_cot(r):
    """Render an agent result's intermediate steps as an HTML fragment.

    Args:
        r: Mapping with ``'input'`` and ``'intermediate_steps'``; the latter
            is iterated as ``(action, observation)`` pairs where ``action``
            has a ``log`` string attribute.

    Returns:
        A string starting with ``<p>`` and ending with ``</p>``. Rendering is
        best-effort: if anything is missing or malformed, whatever was
        rendered up to that point is returned (``'<p></p>'`` at minimum).
    """
    # Function-scope imports keep this block self-contained.
    import html
    import logging

    def esc(value):
        # All dynamic text (user input, LLM output, document content, URLs)
        # is escaped so it cannot inject markup into the chat page.
        return html.escape(str(value))

    parts = ['<p>']
    try:
        steps = r['intermediate_steps']
        parts.append('<b>Input:</b> ' + esc(r['input']) + '<br>')
        for agnt_action, obs in steps:
            # Keep the log's line structure by turning newlines into <br>.
            log_html = '<br> '.join(esc(line) for line in agnt_action.log.split('\n'))
            parts.append('<b>AI chain of thought:</b> ' + log_html + '<br>')

            if not isinstance(obs, dict):
                parts.append('<b>Observation:</b> ' + esc(obs) + '<br><br>')
            elif obs.get('input_documents') is not None:
                # Document-QA observation: list each source passage with a
                # link, then the chain's answer text.
                for d in obs['input_documents']:
                    parts.append(
                        ' <i>- ' + esc(d.page_content) + '</i>'
                        + ' <a href="' + esc(d.metadata['url']) + '">'
                        + esc(d.metadata['page']) + '</a> <br>'
                    )
                parts.append('<b>Observation:</b> ' + esc(obs['output_text']) + '<br><br>')
            elif obs.get('intermediate_steps') is not None:
                # Observation that carries its own intermediate steps.
                parts.append('<b>Query:</b> ' + esc(obs.get('intermediate_steps')) + '<br><br>')
    except Exception:
        # Deliberate best-effort: a rendering problem must not break the chat
        # reply. Keep the partial output but leave a trace for debugging.
        logging.getLogger(__name__).debug(
            "chain-of-thought rendering stopped early", exc_info=True
        )
    parts.append('</p>')
    return ''.join(parts)
|
206 |
+
|
207 |
def chat(message, history):
    """Handle one chat turn and return the updated history plus reasoning HTML.

    Args:
        message: The user's message; it is lower-cased before use.
        history: Prior ``(message, reply)`` pairs, or a falsy value to start
            a new conversation.

    Returns:
        A 3-tuple ``(history, history, cot)``: the same history list twice,
        followed by the HTML produced by ``get_cot`` for this turn.
    """
    if not history:
        history = []
    message = message.lower()

    agent_result = mrkl_rspnd(message)
    reasoning_html = get_cot(agent_result)
    history.append((message, agent_result['output']))
    return history, history, reasoning_html
|
215 |
|
216 |
|
217 |
|