Spaces: demo (Runtime error)
#2 by alohajason - opened
- app.py +8 -10
- data/mychromadb/chroma-collections.parquet +0 -3
- data/mychromadb/chroma-embeddings.parquet +0 -3
- data/mychromadb/index/id_to_uuid_aca62790-d606-4764-91d9-08324ea54984.pkl +0 -3
- data/mychromadb/index/id_to_uuid_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.pkl +0 -3
- data/mychromadb/index/index_aca62790-d606-4764-91d9-08324ea54984.bin +0 -3
- data/mychromadb/index/index_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.bin +0 -3
- data/mychromadb/index/index_metadata_aca62790-d606-4764-91d9-08324ea54984.pkl +0 -3
- data/mychromadb/index/index_metadata_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.pkl +0 -3
- data/mychromadb/index/uuid_to_id_aca62790-d606-4764-91d9-08324ea54984.pkl +0 -3
- data/mychromadb/index/uuid_to_id_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.pkl +0 -3
- greg_funcs.py +66 -93
- requirements.txt +1 -2
app.py
CHANGED
@@ -38,10 +38,10 @@ import requests
 import os
 import gradio as gr
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
-
-
+from torch import tensor as torch_tensor
+from datasets import load_dataset

-from greg_funcs import mrkl_rspnd
+from greg_funcs import mrkl_rspnd


 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -397,7 +397,6 @@ class ChatWrapper:


 response = mrkl_rspnd(inp)
-cot = get_cot(response)
 output = response['output']

 """
@@ -464,7 +463,7 @@ class ChatWrapper:
 raise e
 finally:
 self.lock.release()
-return history, history, html_video, temp_file, html_audio, temp_aud_file,
+return history, history, html_video, temp_file, html_audio, temp_aud_file, ""
 # return history, history, html_audio, temp_aud_file, ""


@@ -656,8 +655,7 @@ with gr.Blocks(css=CSS) as block:

 with gr.Column(scale=7):
 chatbot = gr.Chatbot()
-
-ai_cot = gr.HTML(show_label=False)
+
 with gr.Row():
 message = gr.Textbox(label="What's on your mind??",
 placeholder=PLACEHOLDER,
@@ -854,7 +852,7 @@ with gr.Blocks(css=CSS) as block:
 anticipation_level_state, joy_level_state, trust_level_state, fear_level_state,
 surprise_level_state, sadness_level_state, disgust_level_state, anger_level_state,
 lang_level_state, translate_to_state, literary_style_state],
-outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file,
+outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
 # outputs=[chatbot, history_state, audio_html, tmp_aud_file, message])

 submit.click(chat, inputs=[message, history_state, chain_state, trace_chain_state,
@@ -863,8 +861,8 @@ with gr.Blocks(css=CSS) as block:
 anticipation_level_state, joy_level_state, trust_level_state, fear_level_state,
 surprise_level_state, sadness_level_state, disgust_level_state, anger_level_state,
 lang_level_state, translate_to_state, literary_style_state],
-outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file,
-
+outputs=[chatbot, history_state, video_html, my_file, audio_html, tmp_aud_file, message])
+# outputs=[chatbot, history_state, audio_html, tmp_aud_file, message])


 block.launch(debug=True)
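For reference, the app.py change adds `message` to the `outputs=` list and a trailing `""` to the chat handler's return, so the textbox clears after each submit; Gradio expects one return value per component listed in `outputs`. A minimal sketch of that pattern (the component and handler names below are illustrative, not the Space's actual ones):

```python
# Minimal sketch: a Gradio event handler must return one value per component
# in `outputs`. Returning "" for the Textbox clears it after submission.
import gradio as gr

def chat(message, history):
    history = history or []
    history.append((message, "echo: " + message))
    # one value per output component: chatbot, state, and "" to clear the textbox
    return history, history, ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    history_state = gr.State([])
    message = gr.Textbox(label="What's on your mind??")
    message.submit(chat, inputs=[message, history_state],
                   outputs=[chatbot, history_state, message])

# demo.launch(debug=True)
```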
data/mychromadb/chroma-collections.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:f53c5c7bfaf512515b19c15daf41df9f862138f1fad64bfa27490123a6ae0630
-size 592

data/mychromadb/chroma-embeddings.parquet
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:00c915a2802e985b9b5a95b9c90d1c541cea4b671e64ac2b557e9f4dec0c9648
-size 3352718

data/mychromadb/index/id_to_uuid_aca62790-d606-4764-91d9-08324ea54984.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8a21d750f00302ae6e2e80f188dd5910f14cdc773dcfe8cc420a9b226b1a06ed
-size 44935

data/mychromadb/index/id_to_uuid_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:39ba580b21452d0c0363b9621737aa4f819cb6e8d811888905aeb08c1370dd46
-size 24704

data/mychromadb/index/index_aca62790-d606-4764-91d9-08324ea54984.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1dd287418a49ded3f5e07a60872c421073d37942ff0b2c43cd69cdccdd069040
-size 2341688

data/mychromadb/index/index_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.bin
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eb8680a43caa02daca0e4fd5ce8abadab705b2ba994bdb3d056164406c606d00
-size 1292260

data/mychromadb/index/index_metadata_aca62790-d606-4764-91d9-08324ea54984.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:92cfe00d3fcdf3bffa64e13342d5210490b113a67a82fe54460622777e35bf2b
-size 74

data/mychromadb/index/index_metadata_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:505fa51457f99d2d4bdf6860f2326abc7d2d45f6f0462657b6de1aa564945472
-size 74

data/mychromadb/index/uuid_to_id_aca62790-d606-4764-91d9-08324ea54984.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:dfa0915ece0f552a799dbce3e57e44fe612eb118ab09fe17ab71b98ab1a71a32
-size 52582

data/mychromadb/index/uuid_to_id_f75ac8ed-ecb5-4656-84d1-90e84f6f083a.pkl
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0cfd370f60ead46494e5acd164bccd9dc2a1808e498a98bded54d94dedd73085
-size 28906
greg_funcs.py
CHANGED
@@ -1,6 +1,6 @@
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
-
-
+from torch import tensor as torch_tensor
+from datasets import load_dataset

 from langchain.llms import OpenAI
 from langchain.docstore.document import Document
@@ -13,7 +13,6 @@ from langchain.agents import initialize_agent, Tool
 import sqlite3
 #import pandas as pd
 import json
-import chromadb

 # database
 cxn = sqlite3.connect('./data/mbr.db')
@@ -30,82 +29,81 @@ cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

 """# import datasets"""

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-)
-
-
-#
-
-
-
-
-def
-
-
-
-
-
-
-
-
-# Sort results by the cross-encoder scores
-combined = list(zip(rtrv_ids, list(cross_scores)))
-sorted_tuples = sorted(combined, key=lambda x: x[1], reverse=True)
-sorted_ids = [t[0] for t in sorted_tuples[:top_n]]
-predictions = collection.get(ids=sorted_ids, include=["documents","metadatas"])
-return predictions
-
-def get_text_fmt(qry):
-prediction_text = []
-predictions = rernk(qry, collection=collection, top_k=20, top_n = 5)
-docs = predictions['documents']
-meta = predictions['metadatas']
-for i in range(len(docs)):
-result = Document(page_content=docs[i], metadata=meta[i])
-prediction_text.append(result)
-return prediction_text
+dataset = load_dataset("gfhayworth/hack_policy", split='train')
+mypassages = list(dataset.to_pandas()['psg'])
+
+dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
+dataset_embed_pd = dataset_embed.to_pandas()
+mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
+
+def search(query, passages = mypassages, doc_embedding = mycorpus_embeddings, top_k=20, top_n = 1):
+question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+question_embedding = question_embedding #.cuda()
+hits = util.semantic_search(question_embedding, doc_embedding, top_k=top_k)
+hits = hits[0] # Get the hits for the first query
+
+##### Re-Ranking #####
+cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
+cross_scores = cross_encoder.predict(cross_inp)
+
+# Sort results by the cross-encoder scores
+for idx in range(len(cross_scores)):
+hits[idx]['cross-score'] = cross_scores[idx]
+
+hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
+predictions = hits[:top_n]
+return predictions
+# for hit in hits[0:3]:
+# print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
+
+
+
+def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
+predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
+prediction_text = []
+for hit in predictions:
+page_content = passages[hit['corpus_id']]
+metadata = {"source": hit['corpus_id']}
+result = Document(page_content=page_content, metadata=metadata)
+prediction_text.append(result)
+return prediction_text

-##################################################################################################################################
 """# LLM based qa functions"""

-template = """You are a friendly AI assistant for the insurance company Humana.
-Given the following extracted parts of a long document and a question, create a succinct final answer.
+template = """You are a friendly AI assistant for the insurance company Humana. Given the following extracted parts of a long document and a question, create a succinct final answer.
 If you don't know the answer, just say that you don't know. Don't try to make up an answer.
-If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana.
+If the question is not about Humana, politely inform the user that you are tuned to only answer questions about Humana benefits.
 QUESTION: {question}
 =========
-{
+{context}
 =========
 FINAL ANSWER:"""
-PROMPT = PromptTemplate(template=template, input_variables=["
+PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])

-chain_qa =
+chain_qa = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT)
+
+def get_text_fmt(qry, passages = mypassages, doc_embedding=mycorpus_embeddings):
+predictions = search(qry, passages = passages, doc_embedding = doc_embedding, top_n=5, )
+prediction_text = []
+for hit in predictions:
+page_content = passages[hit['corpus_id']]
+metadata = {"source": hit['corpus_id']}
+result = Document(page_content=page_content, metadata=metadata)
+prediction_text.append(result)
+return prediction_text

 def get_llm_response(message):
 mydocs = get_text_fmt(message)
-responses = chain_qa(
+responses = chain_qa.run(input_documents=mydocs, question=message)
 return responses

+# for x in xmpl_list:
+# print(32*'=')
+# print(x)
+# print(32*'=')
+# r = get_llm_response(x)
+# print(r)
+
 """# Database query"""

 db = SQLDatabase.from_uri("sqlite:///./data/mbr.db")
@@ -115,10 +113,10 @@ llm = OpenAI(temperature=0)
 # model_name: str = "text-davinci-003"
 # instruction fine-tuned, sometimes referred to as GPT-3.5

-db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True
+db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

 def db_qry(qry):
-responses = db_chain('my mbr_id is 456 ;'+str(qry) ) ############### hardcode mbr id 456 for demo
+responses = db_chain.run(query='my mbr_id is 456 ;'+str(qry) ) ############### hardcode mbr id 456 for demo
 return responses

 #db_qry('how many footcare visits have I had?')
@@ -180,38 +178,13 @@ def mrkl_rspnd(qry):
 response = mrkl({"input":str(qry) })
 return response

-def get_cot(r):
-cot = '<p>'
-try:
-intermedObj = r['intermediate_steps']
-cot +='<b>Input:</b> '+r['input']+'<br>'
-for agnt_action, obs in intermedObj:
-al = '<br> '.join(agnt_action.log.split('\n') )
-cot += '<b>AI chain of thought:</b> '+ al +'<br>'
-if type(obs) is dict:
-if obs.get('input_documents') is not None:
-for d in obs['input_documents']:
-cot += ' '+'<i>- '+str(d.page_content)+'</i>'+' <a href="'+ str(d.metadata['url']) +'">'+'''<span style="color: blue;">'''+str(d.metadata['page'])+'</span></a> '+'<br>'
-cot += '<b>Observation:</b> '+str(obs['output_text']) +'<br><br>'
-elif obs.get('intermediate_steps') is not None:
-cot += '<b>Query:</b> '+str(obs.get('intermediate_steps')) +'<br><br>'
-else:
-pass
-else:
-cot += '<b>Observation:</b> '+str(obs) +'<br><br>'
-except:
-pass
-cot += '</p>'
-return cot
-
 def chat(message, history):
 history = history or []
 message = message.lower()

 response = mrkl_rspnd(message)
-cot = get_cot(response)
 history.append((message, response['output']))
-return history, history
+return history, history


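For reference, greg_funcs.py now pulls passages and pre-computed embeddings from the gfhayworth/hack_policy and gfhayworth/hack_policy_embed datasets instead of a chromadb collection, retrieves with the bi-encoder, re-ranks with the cross-encoder, and answers with a LangChain "stuff" QA chain. A rough, self-contained sketch of that flow; the bi-encoder model name and the toy in-memory corpus are assumptions, and the prompt text is abbreviated, while the cross-encoder name, prompt variables, and chain setup come from the diff:

```python
# Rough sketch of the retrieval + QA flow added in greg_funcs.py:
# bi-encoder semantic search, cross-encoder re-ranking, then a LangChain
# "stuff" chain over the top passages. The Space loads its real passages and
# embeddings from gfhayworth/hack_policy and gfhayworth/hack_policy_embed.
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from langchain.llms import OpenAI
from langchain.docstore.document import Document
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate

bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')          # assumed model name
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')   # from the diff

# toy corpus for illustration only
passages = [
    "Routine eye exams are covered once per year.",
    "Foot care visits are limited to six per calendar year.",
]
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True)

def search(query, top_k=20, top_n=5):
    # dense retrieval with the bi-encoder
    q_emb = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(q_emb, corpus_embeddings, top_k=top_k)[0]
    # re-rank the candidates with the cross-encoder
    cross_scores = cross_encoder.predict([[query, passages[h['corpus_id']]] for h in hits])
    for i, s in enumerate(cross_scores):
        hits[i]['cross-score'] = s
    return sorted(hits, key=lambda h: h['cross-score'], reverse=True)[:top_n]

def get_text_fmt(query):
    # wrap the top passages as LangChain Documents, as the new get_text_fmt does
    return [Document(page_content=passages[h['corpus_id']],
                     metadata={"source": h['corpus_id']}) for h in search(query)]

template = """Answer the question using only the context below.
QUESTION: {question}
=========
{context}
=========
FINAL ANSWER:"""
PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])
chain_qa = load_qa_chain(OpenAI(temperature=0), chain_type="stuff", prompt=PROMPT)

def get_llm_response(message):
    return chain_qa.run(input_documents=get_text_fmt(message), question=message)

# Requires OPENAI_API_KEY in the environment:
# print(get_llm_response("how many foot care visits are covered?"))
```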
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
 sentence-transformers
-
+datasets
 openai==0.26.1
 gradio
 # google-search-results
@@ -9,4 +9,3 @@ langchain
 requests==2.28.2
 git+https://github.com/openai/whisper.git
 boto3
-chromadb