Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Kartheik Iyer
committed on
Commit
•
18e51e3
1
Parent(s):
036767e
update dataset and security fixes
Browse files- app_gradio.py +75 -45
app_gradio.py
CHANGED
@@ -34,7 +34,7 @@ from typing import List, Literal
|
|
34 |
|
35 |
from nltk.corpus import stopwords
|
36 |
import nltk
|
37 |
-
from openai import OpenAI
|
38 |
# import anthropic
|
39 |
import cohere
|
40 |
import faiss
|
@@ -64,6 +64,12 @@ embed_model = "text-embedding-3-small"
|
|
64 |
embeddings = OpenAIEmbeddings(model = embed_model, api_key = openai_key)
|
65 |
nlp = load_nlp()
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
def get_keywords(text, nlp=nlp):
|
69 |
result = []
|
@@ -77,8 +83,12 @@ def get_keywords(text, nlp=nlp):
|
|
77 |
return result
|
78 |
|
79 |
def load_arxiv_corpus():
|
80 |
-
arxiv_corpus = load_from_disk('data/')
|
81 |
-
arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
|
|
|
|
|
|
|
|
|
82 |
print('loading arxiv corpus from disk')
|
83 |
return arxiv_corpus
|
84 |
|
@@ -344,6 +354,23 @@ def guess_question_type(query: str):
|
|
344 |
messages = [("system",question_categorization_prompt,),("human", query),]
|
345 |
return gen_client.invoke(messages).content
|
346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
class OverallConsensusEvaluation(BaseModel):
|
348 |
rewritten_statement: str = Field(
|
349 |
...,
|
@@ -459,48 +486,51 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
|
|
459 |
search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
|
460 |
gen_text_list = ['making the LLM talk to the papers...','invoking arcane rituals...','gone to library, please wait...','is there really an answer to this...']
|
461 |
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
ec.
|
468 |
-
ec.
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
|
|
|
|
|
|
504 |
|
505 |
def create_interface():
|
506 |
custom_css = """
|
|
|
34 |
|
35 |
from nltk.corpus import stopwords
|
36 |
import nltk
|
37 |
+
from openai import OpenAI, moderations
|
38 |
# import anthropic
|
39 |
import cohere
|
40 |
import faiss
|
|
|
64 |
embeddings = OpenAIEmbeddings(model = embed_model, api_key = openai_key)
|
65 |
nlp = load_nlp()
|
66 |
|
67 |
+
def check_mod(query):
    """Return True if the OpenAI moderation endpoint flags `query`.

    Used as a gate before running a user query through the pipeline.

    Parameters
    ----------
    query : str
        Raw user input to screen.

    Returns
    -------
    bool
        True when any moderation category is flagged, False otherwise.
    """
    mod_report = moderations.create(input=query)
    # Iterating the categories model yields (category_name, flagged) pairs;
    # the query is rejected if any single category tripped.
    return any(flagged for _category, flagged in mod_report.results[0].categories)
|
73 |
|
74 |
def get_keywords(text, nlp=nlp):
|
75 |
result = []
|
|
|
83 |
return result
|
84 |
|
85 |
def load_arxiv_corpus():
    """Load the pathfinder arxiv dataset and attach a FAISS index.

    Downloads (or reads from the HF cache) the `kiyer/pathfinder_arxiv_data`
    dataset from the Hugging Face Hub — keeping the app in sync with the
    published dataset — and builds an in-memory FAISS index over the
    precomputed `embed` column for nearest-neighbour retrieval.

    Returns
    -------
    datasets.Dataset
        The train split with a FAISS index registered on column 'embed'.
    """
    arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
    arxiv_corpus.add_faiss_index(column='embed')
    print('loaded arxiv corpus from the HF hub')
    return arxiv_corpus
|
94 |
|
|
|
354 |
messages = [("system",question_categorization_prompt,),("human", query),]
|
355 |
return gen_client.invoke(messages).content
|
356 |
|
357 |
+
def log_to_gist(strings):
    """Append a timestamped entry to a private GitHub gist used as a query log.

    Query logs exist to detect and account for possible malicious use; they
    may be deleted periodically when no longer needed.

    Parameters
    ----------
    strings : list[str]
        Fragments joined with spaces into a single log line.

    Raises
    ------
    KeyError
        If the `github_token` or `gist_id` environment variables are unset.
    """
    github_token = os.environ['github_token']
    gist_id = os.environ['gist_id']
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    content = f"\n{timestamp}: {' '.join(strings)}\n"
    headers = {
        'Authorization': f'token {github_token}',
        'Accept': 'application/vnd.github.v3+json',
    }
    # Fetch the current log so the new entry is appended rather than
    # overwriting history; a non-200 response just starts a fresh log body.
    response = requests.get(f'https://api.github.com/gists/{gist_id}',
                            headers=headers, timeout=10)
    if response.status_code == 200:
        existing_content = response.json()['files']['log.txt']['content']
        content = existing_content + content
    data = {
        "description": "Logged Strings",
        "public": False,
        "files": {"log.txt": {"content": content}},
    }
    # PATCH updates the existing gist in place.
    requests.patch(f'https://api.github.com/gists/{gist_id}', headers=headers,
                   data=json.dumps(data), timeout=10)
    return
|
373 |
+
|
374 |
class OverallConsensusEvaluation(BaseModel):
|
375 |
rewritten_statement: str = Field(
|
376 |
...,
|
|
|
486 |
search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
|
487 |
gen_text_list = ['making the LLM talk to the papers...','invoking arcane rituals...','gone to library, please wait...','is there really an answer to this...']
|
488 |
|
489 |
+
log_to_gist(['[mod flag: '+str(check_mod(query))+']', query])
|
490 |
+
if check_mod(query) == False:
|
491 |
+
|
492 |
+
input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
|
493 |
+
query_keywords = get_keywords(query)
|
494 |
+
ec.query_input_keywords = input_keywords+query_keywords
|
495 |
+
ec.toggles = toggles
|
496 |
+
if rag_type == "Semantic Search":
|
497 |
+
ec.hyde = False
|
498 |
+
ec.rerank = False
|
499 |
+
elif rag_type == "Semantic + HyDE":
|
500 |
+
ec.hyde = True
|
501 |
+
ec.rerank = False
|
502 |
+
elif rag_type == "Semantic + HyDE + CoHERE":
|
503 |
+
ec.hyde = True
|
504 |
+
ec.rerank = True
|
505 |
+
|
506 |
+
progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
|
507 |
+
rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
|
508 |
+
formatted_df = ec.return_formatted_df(rs, small_df)
|
509 |
+
yield formatted_df, None, None, None, None
|
510 |
+
|
511 |
+
progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
|
512 |
+
rag_answer = run_rag_qa(query, formatted_df, prompt_type)
|
513 |
+
yield formatted_df, rag_answer['answer'], None, None, None
|
514 |
+
|
515 |
+
progress(0.6, desc="Generating consensus")
|
516 |
+
consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
|
517 |
+
consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
|
518 |
+
yield formatted_df, rag_answer['answer'], consensus, None, None
|
519 |
+
|
520 |
+
progress(0.8, desc="Analyzing question type")
|
521 |
+
question_type_gen = guess_question_type(query)
|
522 |
+
if '<categorization>' in question_type_gen:
|
523 |
+
question_type_gen = question_type_gen.split('<categorization>')[1]
|
524 |
+
if '</categorization>' in question_type_gen:
|
525 |
+
question_type_gen = question_type_gen.split('</categorization>')[0]
|
526 |
+
question_type_gen = question_type_gen.replace('\n',' \n')
|
527 |
+
qn_type = question_type_gen
|
528 |
+
yield formatted_df, rag_answer['answer'], consensus, qn_type, None
|
529 |
+
|
530 |
+
progress(1.0, desc="Visualizing embeddings")
|
531 |
+
fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
|
532 |
+
|
533 |
+
yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
|
534 |
|
535 |
def create_interface():
|
536 |
custom_css = """
|