Kartheik Iyer committed on
Commit
18e51e3
1 Parent(s): 036767e

update dataset and security fixes

Browse files
Files changed (1) hide show
  1. app_gradio.py +75 -45
app_gradio.py CHANGED
@@ -34,7 +34,7 @@ from typing import List, Literal
34
 
35
  from nltk.corpus import stopwords
36
  import nltk
37
- from openai import OpenAI
38
  # import anthropic
39
  import cohere
40
  import faiss
@@ -64,6 +64,12 @@ embed_model = "text-embedding-3-small"
64
  embeddings = OpenAIEmbeddings(model = embed_model, api_key = openai_key)
65
  nlp = load_nlp()
66
 
 
 
 
 
 
 
67
 
68
  def get_keywords(text, nlp=nlp):
69
  result = []
@@ -77,8 +83,12 @@ def get_keywords(text, nlp=nlp):
77
  return result
78
 
79
  def load_arxiv_corpus():
80
- arxiv_corpus = load_from_disk('data/')
81
- arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
 
 
 
 
82
  print('loading arxiv corpus from disk')
83
  return arxiv_corpus
84
 
@@ -344,6 +354,23 @@ def guess_question_type(query: str):
344
  messages = [("system",question_categorization_prompt,),("human", query),]
345
  return gen_client.invoke(messages).content
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  class OverallConsensusEvaluation(BaseModel):
348
  rewritten_statement: str = Field(
349
  ...,
@@ -459,48 +486,51 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
459
  search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
460
  gen_text_list = ['making the LLM talk to the papers...','invoking arcane rituals...','gone to library, please wait...','is there really an answer to this...']
461
 
462
- input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
463
- query_keywords = get_keywords(query)
464
- ec.query_input_keywords = input_keywords+query_keywords
465
- ec.toggles = toggles
466
- if rag_type == "Semantic Search":
467
- ec.hyde = False
468
- ec.rerank = False
469
- elif rag_type == "Semantic + HyDE":
470
- ec.hyde = True
471
- ec.rerank = False
472
- elif rag_type == "Semantic + HyDE + CoHERE":
473
- ec.hyde = True
474
- ec.rerank = True
475
-
476
- progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
477
- rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
478
- formatted_df = ec.return_formatted_df(rs, small_df)
479
- yield formatted_df, None, None, None, None
480
-
481
- progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
482
- rag_answer = run_rag_qa(query, formatted_df, prompt_type)
483
- yield formatted_df, rag_answer['answer'], None, None, None
484
-
485
- progress(0.6, desc="Generating consensus")
486
- consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
487
- consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
488
- yield formatted_df, rag_answer['answer'], consensus, None, None
489
-
490
- progress(0.8, desc="Analyzing question type")
491
- question_type_gen = guess_question_type(query)
492
- if '<categorization>' in question_type_gen:
493
- question_type_gen = question_type_gen.split('<categorization>')[1]
494
- if '</categorization>' in question_type_gen:
495
- question_type_gen = question_type_gen.split('</categorization>')[0]
496
- question_type_gen = question_type_gen.replace('\n',' \n')
497
- qn_type = question_type_gen
498
- yield formatted_df, rag_answer['answer'], consensus, qn_type, None
499
-
500
- progress(1.0, desc="Visualizing embeddings")
501
- fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
502
-
503
- yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
 
 
 
504
 
505
  def create_interface():
506
  custom_css = """
 
34
 
35
  from nltk.corpus import stopwords
36
  import nltk
37
+ from openai import OpenAI, moderations
38
  # import anthropic
39
  import cohere
40
  import faiss
 
64
  embeddings = OpenAIEmbeddings(model = embed_model, api_key = openai_key)
65
  nlp = load_nlp()
66
 
67
def check_mod(query):
    """Return True if the OpenAI moderation endpoint flags `query`.

    Parameters
    ----------
    query : str
        The user-supplied text to screen before running the pipeline.

    Returns
    -------
    bool
        True if any moderation category is flagged, False otherwise.

    NOTE(review): assumes `mod_report.results[0].categories` iterates as
    (category_name, flagged) pairs — confirm against the installed openai
    SDK version, since iteration behavior of the categories model has
    changed between releases.
    """
    mod_report = moderations.create(input=query)
    # Flag the query if any single category tripped; `any()` replaces the
    # original manual loop with its non-idiomatic `== True` comparison
    # (truthiness is the correct check for a boolean flag).
    return any(flagged for _, flagged in mod_report.results[0].categories)
73
 
74
  def get_keywords(text, nlp=nlp):
75
  result = []
 
83
  return result
84
 
85
def load_arxiv_corpus():
    """Load the Pathfinder arXiv corpus and attach its FAISS index.

    The corpus is pulled from the HuggingFace Hub so the app stays in sync
    with the published dataset, and a FAISS index is built in memory over
    the precomputed `embed` column.

    Returns
    -------
    datasets.Dataset
        The train split of `kiyer/pathfinder_arxiv_data` with a FAISS
        index registered on the 'embed' column.
    """
    # Previous behavior (local snapshot + prebuilt index), kept for reference:
    #   arxiv_corpus = load_from_disk('data/')
    #   arxiv_corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
    arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
    arxiv_corpus.add_faiss_index(column='embed')
    # Message fixed: the corpus is now fetched from the Hub, not local disk.
    print('loaded arxiv corpus from the HuggingFace Hub')
    return arxiv_corpus
94
 
 
354
  messages = [("system",question_categorization_prompt,),("human", query),]
355
  return gen_client.invoke(messages).content
356
 
357
def log_to_gist(strings):
    """Append a timestamped log line to a private GitHub gist.

    Query logs are kept to detect and account for possible malicious use;
    they are deleted periodically if not needed.

    Parameters
    ----------
    strings : list[str]
        Fragments joined with spaces into a single log line.

    Raises
    ------
    KeyError
        If the `github_token` or `gist_id` environment variables are unset.

    NOTE(review): timestamp is naive local time — consider
    `datetime.now(timezone.utc)` so log entries are comparable across hosts.
    """
    github_token = os.environ['github_token']
    gist_id = os.environ['gist_id']
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    content = f"\n{timestamp}: {' '.join(strings)}\n"
    headers = {
        'Authorization': f'token {github_token}',
        'Accept': 'application/vnd.github.v3+json',
    }
    # Fetch the current log so the new line is appended rather than
    # overwriting the gist; a missing 'log.txt' file no longer raises
    # (the original chained [] lookups could KeyError and drop the write).
    response = requests.get(
        f'https://api.github.com/gists/{gist_id}', headers=headers, timeout=10
    )
    if response.status_code == 200:
        existing = response.json().get('files', {}).get('log.txt', {})
        content = existing.get('content', '') + content
    data = {
        "description": "Logged Strings",
        "public": False,
        "files": {"log.txt": {"content": content}},
    }
    # Update the existing gist. `json=` lets requests serialize and set the
    # Content-Type header; the original rebuilt an identical `headers` dict
    # and called json.dumps by hand.
    requests.patch(
        f'https://api.github.com/gists/{gist_id}',
        headers=headers,
        json=data,
        timeout=10,
    )
373
+
374
  class OverallConsensusEvaluation(BaseModel):
375
  rewritten_statement: str = Field(
376
  ...,
 
486
  search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
487
  gen_text_list = ['making the LLM talk to the papers...','invoking arcane rituals...','gone to library, please wait...','is there really an answer to this...']
488
 
489
+ log_to_gist(['[mod flag: '+str(check_mod(query))+']', query])
490
+ if check_mod(query) == False:
491
+
492
+ input_keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
493
+ query_keywords = get_keywords(query)
494
+ ec.query_input_keywords = input_keywords+query_keywords
495
+ ec.toggles = toggles
496
+ if rag_type == "Semantic Search":
497
+ ec.hyde = False
498
+ ec.rerank = False
499
+ elif rag_type == "Semantic + HyDE":
500
+ ec.hyde = True
501
+ ec.rerank = False
502
+ elif rag_type == "Semantic + HyDE + CoHERE":
503
+ ec.hyde = True
504
+ ec.rerank = True
505
+
506
+ progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
507
+ rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
508
+ formatted_df = ec.return_formatted_df(rs, small_df)
509
+ yield formatted_df, None, None, None, None
510
+
511
+ progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
512
+ rag_answer = run_rag_qa(query, formatted_df, prompt_type)
513
+ yield formatted_df, rag_answer['answer'], None, None, None
514
+
515
+ progress(0.6, desc="Generating consensus")
516
+ consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
517
+ consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
518
+ yield formatted_df, rag_answer['answer'], consensus, None, None
519
+
520
+ progress(0.8, desc="Analyzing question type")
521
+ question_type_gen = guess_question_type(query)
522
+ if '<categorization>' in question_type_gen:
523
+ question_type_gen = question_type_gen.split('<categorization>')[1]
524
+ if '</categorization>' in question_type_gen:
525
+ question_type_gen = question_type_gen.split('</categorization>')[0]
526
+ question_type_gen = question_type_gen.replace('\n',' \n')
527
+ qn_type = question_type_gen
528
+ yield formatted_df, rag_answer['answer'], consensus, qn_type, None
529
+
530
+ progress(1.0, desc="Visualizing embeddings")
531
+ fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
532
+
533
+ yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
534
 
535
  def create_interface():
536
  custom_css = """