zhenyundeng committed
Commit 016ab20
1 Parent(s): 200e5b6

update files

Files changed (3):
  1. .gitattributes +3 -1
  2. app.py +179 -11
  3. utils.py +3 -1
.gitattributes CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
+
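These attribute lines are what the git lfs track command writes into .gitattributes: any file matching a pattern is stored through Git LFS instead of directly in the repository. This commit starts tracking *.json and *.db (and stops tracking *.tar), presumably so that large data files such as averitec/data/train.json, which app.py now loads at import time, go through LFS. Assuming git-lfs is installed, the equivalent commands would be:

    git lfs track "*.json"
    git lfs track "*.db"
    git lfs untrack "*.tar"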
app.py CHANGED
@@ -69,7 +69,9 @@ nlp = spacy.load("en_core_web_sm")
 # ---------------------------------------------------------------------------
 # Load sample dict for AVeriTeC search
 # all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
+train_examples = json.load(open('averitec/data/train.json', 'r'))
 
+print(train_examples[0]['claim'])
 # ---------------------------------------------------------------------------
 # ---------- Load pretrained models ----------
 # ---------- load Evidence retrieval model ----------
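train.json is now parsed once at import time (with a print of the first claim as a sanity check); the per-call loads inside generate_reference_corpus and generate_step2_reference_corpus are commented out in the hunks below, so those functions now rely on this module-level train_examples. A slightly more defensive version of the load, a sketch rather than what the commit does, would surface the common failure mode where the LFS file was never pulled and the file on disk is still a pointer stub:

    import json

    TRAIN_PATH = "averitec/data/train.json"
    try:
        with open(TRAIN_PATH) as f:
            train_examples = json.load(f)  # expected: a list of AVeriTeC records, each with a "claim" field
    except (FileNotFoundError, json.JSONDecodeError) as err:
        # A JSONDecodeError here usually means the file is a git-lfs pointer stub, not the real data.
        raise RuntimeError(f"Could not load {TRAIN_PATH}: {err}")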
@@ -424,9 +426,8 @@ def QAprediction(claim, evidence, sources):
 
 # ----------GoogleAPIretriever---------
 def generate_reference_corpus(reference_file):
-    with open(reference_file) as f:
-        #
-        train_examples = json.load(f)
+    # with open(reference_file) as f:
+    #     train_examples = json.load(f)
 
     all_data_corpus = []
     tokenized_corpus = []
@@ -578,6 +579,12 @@ def get_and_store(url_link, fp, worker, worker_stack):
     gc.collect()
 
 
+def get_text_from_link(url_link):
+    page_lines = url2lines(url_link)
+
+    return "\n".join([url_link] + page_lines)
+
+
 def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
     search_results = []
     for i in range(3):
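The new get_text_from_link helper returns one string whose first line is the URL itself, followed by the page lines produced by url2lines. The rewritten decorate_with_questions further down depends on exactly this layout: it splits on "\n" and treats the first line as location_url. An illustration with a hypothetical URL:

    web_text = get_text_from_link("https://example.com/article")
    lines = web_text.split("\n")
    location_url, page_lines = lines[0], lines[1:]  # lines[0] is the source URL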
@@ -599,7 +606,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
     return search_results
 
 
-def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
+def averitec_search_michael(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
     # default config
     api_key = os.environ["GOOGLE_API_KEY"]
     search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
@@ -651,7 +658,6 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
         for page_num in range(n_pages):
             search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                        this_search_string, page=page_num)
-            search_results = search_results[:5]
 
             for result in search_results:
                 link = str(result["link"])
@@ -668,8 +674,6 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
                 if link.endswith(".pdf") or link.endswith(".doc"):
                     continue
 
-                store_file_path = ""
-
                 if link in visited:
                     store_file_path = visited[link]
                 else:
@@ -678,7 +682,7 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
                                                          store_counter) + ".store"
                     visited[link] = store_file_path
 
-                    while len(worker_stack) == 0:  # Wait for a wrrker to become available. Check every second.
+                    while len(worker_stack) == 0:  # Wait for a worker to become available. Check every second.
                         sleep(1)
 
                     worker = worker_stack.pop()
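The worker pool in this (now renamed) function is a plain list of ids drained with a one-second polling loop. A common alternative, sketched here under the assumption of threaded workers and not taken from this commit, is a blocking queue, which hands a worker over as soon as one is returned instead of waking up every second:

    from queue import Queue

    worker_queue = Queue()
    for worker_id in range(10):
        worker_queue.put(worker_id)

    worker = worker_queue.get()   # blocks until a worker id is available
    try:
        pass                      # fetch and store the page here
    finally:
        worker_queue.put(worker)  # return the worker id to the pool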
@@ -692,6 +696,89 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
     return retrieve_evidence
 
 
+def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
+    # default config
+    api_key = os.environ["GOOGLE_API_KEY"]
+    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
+
+    blacklist = [
+        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
+        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
+        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
+        "nlp.cs.princeton.edu",
+        "huggingface.co"
+    ]
+
+    blacklist_files = [  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
+        "/glove.",
+        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
+        "https://web.mit.edu/adamrose/Public/googlelist",
+    ]
+
+    # save to folder
+    store_folder = "averitec/data/store/retrieved_docs"
+    #
+    index = 0
+    questions = [q["question"] for q in generate_question]
+
+    # check the date of the claim
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"
+
+    #
+    search_strings = []
+    search_types = []
+
+    search_string_2 = string_to_search_query(claim, None)
+    search_strings += [search_string_2, claim, ]
+    search_types += ["claim", "claim-noformat", ]
+
+    search_strings += questions
+    search_types += ["question" for _ in questions]
+
+    # start to search
+    search_results = []
+    visited = {}
+    store_counter = 0
+    worker_stack = list(range(10))
+
+    retrieve_evidence = []
+
+    for this_search_string, this_search_type in zip(search_strings, search_types):
+        for page_num in range(n_pages):
+            search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
+                                                       this_search_string, page=page_num)
+            search_results = search_results[:5]
+
+            for result in search_results:
+                link = str(result["link"])
+                domain = get_domain_name(link)
+
+                if domain in blacklist:
+                    continue
+                broken = False
+                for b_file in blacklist_files:
+                    if b_file in link:
+                        broken = True
+                if broken:
+                    continue
+                if link.endswith(".pdf") or link.endswith(".doc"):
+                    continue
+
+                store_file_path = ""
+
+                if link in visited:
+                    web_text = visited[link]
+                else:
+                    web_text = get_text_from_link(link)
+                    visited[link] = web_text
+
+                line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, web_text]
+                retrieve_evidence.append(line)
+
+    return retrieve_evidence
+
+
 def claim2prompts(example):
     claim = example["claim"]
 
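Unlike averitec_search_michael above, the re-added averitec_search keeps everything in memory: pages are fetched inline through get_text_from_link and cached per link in the visited dict, so the last element of each retrieve_evidence row is the page text itself rather than a .store file path (store_folder and store_file_path remain but are now unused). A usage sketch with an illustrative claim and question:

    questions = [{"question": "Who is quoted as saying this?"}]
    rows = averitec_search("Example claim to be checked.", questions, n_pages=1)
    for idx, claim_text, link, page, query, qtype, web_text in rows:
        print(link, "->", len(web_text.split("\n")) - 1, "lines of page text")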
@@ -725,8 +812,8 @@ def claim2prompts(example):
 
 
 def generate_step2_reference_corpus(reference_file):
-    with open(reference_file) as f:
-        train_examples = json.load(f)
+    # with open(reference_file) as f:
+    #     train_examples = json.load(f)
 
     prompt_corpus = []
     tokenized_corpus = []
@@ -762,6 +849,87 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=10):  # top_k=100
     tokenized_corpus = []
     all_data_corpus = []
 
+    for retri_evi in tqdm.tqdm(retrieve_evidence):
+        # store_file = retri_evi[-1]
+        # with open(store_file, 'r') as f:
+        web_text = retri_evi[-1]
+        lines_in_web = web_text.split("\n")
+
+        first = True
+        for line in lines_in_web:
+            # for line in f:
+            line = line.strip()
+
+            if first:
+                first = False
+                location_url = line
+                continue
+
+            if len(line) > 3:
+                entry = nltk.word_tokenize(line)
+                if (location_url, line) not in all_data_corpus:
+                    tokenized_corpus.append(entry)
+                    all_data_corpus.append((location_url, line))
+
+    if len(tokenized_corpus) == 0:
+        print("")
+
+    bm25 = BM25Okapi(tokenized_corpus)
+    s = bm25.get_scores(nltk.word_tokenize(claim))
+    top_n = np.argsort(s)[::-1][:top_k]
+    docs = [all_data_corpus[i] for i in top_n]
+
+    generate_qa_pairs = []
+    # Then, generate questions for those top 50:
+    for doc in tqdm.tqdm(docs):
+        # prompt_lookup_str = example["claim"] + " " + doc[1]
+        prompt_lookup_str = doc[1]
+
+        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
+        prompt_n = 10
+        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
+        prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
+
+        claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
+        prompt = "\n\n".join(prompt_docs + [claim_prompt])
+        sentences = [prompt]
+
+        inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+        outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
+                                    early_stopping=True)
+
+        tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
+        # We are not allowed to generate more than 250 characters:
+        tgt_text = tgt_text[:250]
+
+        qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
+        generate_qa_pairs.append(qa_pair)
+
+    return generate_qa_pairs
+
+
+def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10):  # top_k=100
+    #
+    reference_file = "averitec/data/train.json"
+    tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
+    prompt_bm25 = BM25Okapi(tokenized_corpus)
+
+    # Define the bloom model:
+    accelerator = Accelerator()
+    accel_device = accelerator.device
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # model = BloomForCausalLM.from_pretrained(
+    #     "bigscience/bloom-7b1",
+    #     device_map="auto",
+    #     torch_dtype=torch.bfloat16,
+    #     offload_folder="./offload"
+    # )
+
+    #
+    tokenized_corpus = []
+    all_data_corpus = []
+
     for retri_evi in tqdm.tqdm(retrieve_evidence):
         store_file = retri_evi[-1]
 
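The rewritten decorate_with_questions reads each page's text straight out of retrieve_evidence, ranks the collected sentences against the claim with BM25, and sends only the top_k hits to the question-generation model. A minimal, self-contained sketch of that ranking step, using a toy corpus but the same rank_bm25 API the file already uses:

    import nltk
    import numpy as np
    from rank_bm25 import BM25Okapi

    # nltk.download("punkt")  # word_tokenize needs this data once
    corpus = ["The sky is blue.", "BM25 ranks sentences by term overlap.", "Cats sleep a lot."]
    tokenized = [nltk.word_tokenize(sentence) for sentence in corpus]

    bm25 = BM25Okapi(tokenized)
    scores = bm25.get_scores(nltk.word_tokenize("How does BM25 rank sentences?"))
    top_n = np.argsort(scores)[::-1][:2]  # indices of the two best matches
    print([corpus[i] for i in top_n])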
 
@@ -1222,7 +1390,7 @@ def chat(claim, history, sources):
     try:
         # Log answer on Azure Blob Storage
         # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
-        if bool(os.environ["AZURE_ISSAVE"]):
+        if os.environ["AZURE_ISSAVE"] == "TRUE":
             timestamp = str(datetime.now().timestamp())
             # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
             file = timestamp + ".json"
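This is a genuine bug fix: bool() on an environment variable only tests for non-emptiness, so the old check also enabled saving when AZURE_ISSAVE was set to "FALSE" or "0". For example:

    bool("FALSE")  # True - any non-empty string is truthy
    bool("")       # False
    os.environ.get("AZURE_ISSAVE") == "TRUE"  # True only when explicitly enabled

(The .get variant is a hypothetical hardening; the committed code keeps the direct os.environ lookup, which raises KeyError when the variable is unset.)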
 
utils.py CHANGED
@@ -2,11 +2,13 @@ import numpy as np
 import random
 import string
 import uuid
+from datetime import datetime
 
 
 def create_user_id():
     """Create user_id
     str: String to id user
     """
+    current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     user_id = str(uuid.uuid4())
-    return user_id
+    return current_date + '_' + user_id
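With the timestamp prefix, the generated ids, and hence the Azure log filenames built from them in app.py, sort chronologically. Illustrative output with a made-up UUID:

    >>> create_user_id()
    '2024-07-01-12-30-45_550e8400-e29b-41d4-a716-446655440000'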