Spaces: Runtime error

michal committed • Commit 272ec4b • 1 Parent(s): 6f03bef

refactor

Browse files:
- app.py +4 -104
- wiki_funcs.py +70 -0
app.py CHANGED
@@ -43,111 +43,12 @@ from datasets import load_dataset
 
 
 from greg_funcs import get_llm_response
+from wiki_funcs import mysearch, mygreetings
 
-"""# import models"""
-
-bi_encoder = SentenceTransformer(
-    'sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
-bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
-
-#
-
-
-"""# import datasets"""
-dataset = load_dataset("gfhayworth/hack_policy", split='train')
-
-mypassages = list(dataset.to_pandas()['psg'])
-
-dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
-
-dataset_embed_pd = dataset_embed.to_pandas()
-mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
-
-
-def search(query, top_k=20, top_n=1):
-    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
-    hits = util.semantic_search(
-        question_embedding, mycorpus_embeddings, top_k=top_k)
-    hits = hits[0]  # Get the hits for the first query
-
-    ##### Re-Ranking #####
-    cross_inp = [[query, mypassages[hit['corpus_id']]] for hit in hits]
-    cross_scores = cross_encoder.predict(cross_inp)
-
-    # Sort results by the cross-encoder scores
-    for idx in range(len(cross_scores)):
-        hits[idx]['cross-score'] = cross_scores[idx]
-
-    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
-    predictions = hits[:top_n]
-    return predictions
-    # for hit in hits[0:3]:
-    #     print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
-
-
-def get_text(qry):
-    # predictions = greg_search(qry)
-    predictions = search(qry)
-    prediction_text = []
-    for hit in predictions:
-        prediction_text.append("{}".format(mypassages[hit['corpus_id']]))
-    return prediction_text
-
-
-@tool
-def mysearch(query: str) -> str:
-    """Query our own datasets.
-    """
-    rslt = get_text(query)
-    return '\n'.join(rslt)
-
-
-@tool
-def mygreetings(greeting: str) -> str:
-    """Let us do our greetings
-    """
-
-    return "how are you?"
-
-# mysearch("who is the best rapper in the world?")
-
-# """# chat example"""
-# def chat(message, history):
-#     history = history or []
-#     message = message.lower()
-
-#     responses = get_text(message)
-#     for response in responses:
-#         history.append((message, response))
-#     return history, history
-
-
-# with gr.Blocks(css=CSS) as demo:
-#     history_state = gr.State()
-#     gr.Markdown('# WikiBot')
-#     title = 'Wikipedia Chatbot'
-#     description = 'chatbot with search on Wikipedia'
-#     with gr.Row():
-#         chatbot = gr.Chatbot()
-#     with gr.Row():
-#         message = gr.Textbox(label='Input your question here:',
-#                              placeholder='How many countries are in Europe?',
-#                              lines=1)
-#         submit = gr.Button(value='Send',
-#                            variant='secondary').style(full_width=False)
-#     submit.click(chat,
-#                  inputs=[message, history_state],
-#                  outputs=[chatbot, history_state])
-#     gr.Examples(
-#         examples=["How many countries are in Europe?",
-#                   "Was Roman Emperor Constantine I a Christian?",
-#                   "Who is the best rapper in the world?"],
-#         inputs=message
-#     )
-
-# demo.launch()
-
-OPENAI_API_KEY = "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"
+# OPENAI_API_KEY = "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"
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # "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"
 
 # AWS keys
 aws_access_key_id = "AKIA3JRWKI2EE5ZFN5NZ"
@@ -160,7 +61,7 @@ os.environ["AWS_DEFAULT_REGION"] = aws_region_name
 
 # exhumana api key
 # todo: may need to pay to get one
-os.environ['EXHUMAN_API_KEY'] = ''
+os.environ['EXHUMAN_API_KEY'] = ''  # XXX remove, we are not using the talking head because it costs money and doesnt work.
 
 # news, tmdb keys
 os.environ["NEWS_API_KEY"] = ''
@@ -171,7 +72,6 @@ tmdb_bearer_token = os.environ["TMDB_BEARER_TOKEN"]
 
 TOOLS_LIST = ['serpapi', 'wolfram-alpha', 'pal-math', 'pal-colored-objects', 'news-api', 'tmdb-api',
               'open-meteo-api']  # 'google-search'
-# TOOLS_DEFAULT_LIST = ['mysearch', 'serpapi', 'pal-math']
 TOOLS_DEFAULT_LIST = ['mysearch']
 BUG_FOUND_MSG = "Congratulations, you've found a bug in this application!"
 AUTH_ERR_MSG = "Please paste your OpenAI key from openai.com to use this application. It is not necessary to hit a button or key after pasting it."
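With this change, app.py no longer hard-codes the OpenAI key and instead reads it from the environment. A minimal sketch of the pattern the refactored file now relies on; the explicit fail-fast guard is illustrative only, not part of the commit, and assumes the key is configured as a Space secret:

    import os

    from wiki_funcs import mysearch, mygreetings  # retrieval tools now live in wiki_funcs.py

    # Read the key from the environment instead of committing it to the repo.
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    if OPENAI_API_KEY is None:
        # Illustrative guard: fail early if the secret is not configured.
        raise RuntimeError("OPENAI_API_KEY environment variable is not set")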
wiki_funcs.py ADDED
@@ -0,0 +1,70 @@
+from langchain.agents import tool
+
+from torch import tensor as torch_tensor
+from datasets import load_dataset
+from sentence_transformers import SentenceTransformer, CrossEncoder, util
+
+"""# import models"""
+
+bi_encoder = SentenceTransformer(
+    'sentence-transformers/multi-qa-MiniLM-L6-cos-v1')
+bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens
+
+# The bi-encoder will retrieve top_k documents. We use a cross-encoder to re-rank the results list to improve the quality
+cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
+
+"""# import datasets"""
+dataset = load_dataset("gfhayworth/wiki_mini", split='train')
+
+mypassages = list(dataset.to_pandas()['psg'])
+
+dataset_embed = load_dataset("gfhayworth/wiki_mini_embed", split='train')
+
+dataset_embed_pd = dataset_embed.to_pandas()
+mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)
+
+
+def search(query, top_k=20, top_n=1):
+    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
+    hits = util.semantic_search(
+        question_embedding, mycorpus_embeddings, top_k=top_k)
+    hits = hits[0]  # Get the hits for the first query
+
+    ##### Re-Ranking #####
+    cross_inp = [[query, mypassages[hit['corpus_id']]] for hit in hits]
+    cross_scores = cross_encoder.predict(cross_inp)
+
+    # Sort results by the cross-encoder scores
+    for idx in range(len(cross_scores)):
+        hits[idx]['cross-score'] = cross_scores[idx]
+
+    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
+    predictions = hits[:top_n]
+    return predictions
+    # for hit in hits[0:3]:
+    #     print("\t{:.3f}\t{}".format(hit['cross-score'], mypassages[hit['corpus_id']].replace("\n", " ")))
+
+
+def get_text(qry):
+    # predictions = greg_search(qry)
+    predictions = search(qry)
+    prediction_text = []
+    for hit in predictions:
+        prediction_text.append("{}".format(mypassages[hit['corpus_id']]))
+    return prediction_text
+
+
+@tool
+def mysearch(query: str) -> str:
+    """Query our own datasets.
+    """
+    rslt = get_text(query)
+    return '\n'.join(rslt)
+
+
+@tool
+def mygreetings(greeting: str) -> str:
+    """Let us do our greetings
+    """
+
+    return "how are you?"
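wiki_funcs.py packages a retrieve-then-rerank pipeline: the bi-encoder embeds the query and pulls the top_k most similar passages from the precomputed corpus embeddings, then the cross-encoder scores each (query, passage) pair so only the best top_n survive. A minimal usage sketch, assuming the model and dataset downloads above succeed; since mysearch is wrapped by LangChain's @tool decorator, it is invoked through the tool interface rather than called directly:

    from wiki_funcs import mysearch, search

    # Direct call into the pipeline: 20 bi-encoder candidates, best 1 after re-ranking.
    hits = search("How many countries are in Europe?", top_k=20, top_n=1)

    # Through the LangChain tool wrapper that app.py registers as 'mysearch'.
    print(mysearch.run("How many countries are in Europe?"))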