Spaces: Build error

Commit 016ab20: "update files"
Committed by: zhenyundeng
Parent: 200e5b6

Files changed:
- .gitattributes (+3 / -1)
- app.py (+179 / -11)
- utils.py (+3 / -1)
.gitattributes
CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
+
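With *.json and *.db now routed through Git LFS, JSON data files such as the averitec/data/train.json that app.py loads at startup are stored in the repo as LFS pointer files. A minimal sketch of guarding that load against an unfetched pointer; the is_lfs_pointer helper is illustrative and not part of this commit:

import json

def is_lfs_pointer(path):
    # Git LFS pointer files begin with this version line instead of real content.
    with open(path, "rb") as f:
        return f.readline().startswith(b"version https://git-lfs.github.com/spec/v1")

data_path = "averitec/data/train.json"  # the file app.py loads at startup in this commit
if is_lfs_pointer(data_path):
    raise RuntimeError(data_path + " is an unfetched Git LFS pointer; run `git lfs pull` first.")
train_examples = json.load(open(data_path, "r"))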
app.py
CHANGED
@@ -69,7 +69,9 @@ nlp = spacy.load("en_core_web_sm")
 # ---------------------------------------------------------------------------
 # Load sample dict for AVeriTeC search
 # all_samples_dict = json.load(open('averitec/data/all_samples.json', 'r'))
+train_examples = json.load(open('averitec/data/train.json', 'r'))
 
+print(train_examples[0]['claim'])
 # ---------------------------------------------------------------------------
 # ---------- Load pretrained models ----------
 # ---------- load Evidence retrieval model ----------
@@ -424,9 +426,8 @@ def QAprediction(claim, evidence, sources):
 
 # ----------GoogleAPIretriever---------
 def generate_reference_corpus(reference_file):
-    with open(reference_file) as f:
-
-        train_examples = json.load(f)
+    # with open(reference_file) as f:
+    #     train_examples = json.load(f)
 
     all_data_corpus = []
     tokenized_corpus = []
@@ -578,6 +579,12 @@ def get_and_store(url_link, fp, worker, worker_stack):
     gc.collect()
 
 
+def get_text_from_link(url_link):
+    page_lines = url2lines(url_link)
+
+    return "\n".join([url_link] + page_lines)
+
+
 def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
     search_results = []
     for i in range(3):
@@ -599,7 +606,7 @@ def get_google_search_results(api_key, search_engine_id, google_search, sort_date, search_string, page=0):
     return search_results
 
 
-def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
+def averitec_search_michael(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
     # default config
     api_key = os.environ["GOOGLE_API_KEY"]
     search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
@@ -651,7 +658,6 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
         for page_num in range(n_pages):
             search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
                                                        this_search_string, page=page_num)
-            search_results = search_results[:5]
 
             for result in search_results:
                 link = str(result["link"])
@@ -668,8 +674,6 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
                 if link.endswith(".pdf") or link.endswith(".doc"):
                     continue
 
-                store_file_path = ""
-
                 if link in visited:
                     store_file_path = visited[link]
                 else:
@@ -678,7 +682,7 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
                         store_counter) + ".store"
                     visited[link] = store_file_path
 
-                    while len(worker_stack) == 0:  # Wait for a
+                    while len(worker_stack) == 0:  # Wait for a worker to become available. Check every second.
                         sleep(1)
 
                     worker = worker_stack.pop()
@@ -692,6 +696,89 @@ def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):
     return retrieve_evidence
 
 
+def averitec_search(claim, generate_question, speaker="they", check_date="2024-07-01", n_pages=1):  # n_pages=3
+    # default config
+    api_key = os.environ["GOOGLE_API_KEY"]
+    search_engine_id = os.environ["GOOGLE_SEARCH_ENGINE_ID"]
+
+    blacklist = [
+        "jstor.org",  # Blacklisted because their pdfs are not labelled as such, and clog up the download
+        "facebook.com",  # Blacklisted because only post titles can be scraped, but the scraper doesn't know this,
+        "ftp.cs.princeton.edu",  # Blacklisted because it hosts many large NLP corpora that keep showing up
+        "nlp.cs.princeton.edu",
+        "huggingface.co"
+    ]
+
+    blacklist_files = [  # Blacklisted some NLP nonsense that crashes my machine with OOM errors
+        "/glove.",
+        "ftp://ftp.cs.princeton.edu/pub/cs226/autocomplete/words-333333.txt",
+        "https://web.mit.edu/adamrose/Public/googlelist",
+    ]
+
+    # save to folder
+    store_folder = "averitec/data/store/retrieved_docs"
+    #
+    index = 0
+    questions = [q["question"] for q in generate_question]
+
+    # check the date of the claim
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    sort_date = check_claim_date(current_date)  # check_date="2022-01-01"
+
+    #
+    search_strings = []
+    search_types = []
+
+    search_string_2 = string_to_search_query(claim, None)
+    search_strings += [search_string_2, claim, ]
+    search_types += ["claim", "claim-noformat", ]
+
+    search_strings += questions
+    search_types += ["question" for _ in questions]
+
+    # start to search
+    search_results = []
+    visited = {}
+    store_counter = 0
+    worker_stack = list(range(10))
+
+    retrieve_evidence = []
+
+    for this_search_string, this_search_type in zip(search_strings, search_types):
+        for page_num in range(n_pages):
+            search_results = get_google_search_results(api_key, search_engine_id, google_search, sort_date,
+                                                       this_search_string, page=page_num)
+            search_results = search_results[:5]
+
+            for result in search_results:
+                link = str(result["link"])
+                domain = get_domain_name(link)
+
+                if domain in blacklist:
+                    continue
+                broken = False
+                for b_file in blacklist_files:
+                    if b_file in link:
+                        broken = True
+                if broken:
+                    continue
+                if link.endswith(".pdf") or link.endswith(".doc"):
+                    continue
+
+                store_file_path = ""
+
+                if link in visited:
+                    web_text = visited[link]
+                else:
+                    web_text = get_text_from_link(link)
+                    visited[link] = web_text
+
+                line = [str(index), claim, link, str(page_num), this_search_string, this_search_type, web_text]
+                retrieve_evidence.append(line)
+
+    return retrieve_evidence
+
+
 def claim2prompts(example):
     claim = example["claim"]
 
@@ -725,8 +812,8 @@ def claim2prompts(example):
 
 
 def generate_step2_reference_corpus(reference_file):
-    with open(reference_file) as f:
-        train_examples = json.load(f)
+    # with open(reference_file) as f:
+    #     train_examples = json.load(f)
 
     prompt_corpus = []
     tokenized_corpus = []
@@ -762,6 +849,87 @@ def decorate_with_questions(claim, retrieve_evidence, top_k=10):  # top_k=100
     tokenized_corpus = []
     all_data_corpus = []
 
+    for retri_evi in tqdm.tqdm(retrieve_evidence):
+        # store_file = retri_evi[-1]
+        # with open(store_file, 'r') as f:
+        web_text = retri_evi[-1]
+        lines_in_web = web_text.split("\n")
+
+        first = True
+        for line in lines_in_web:
+            # for line in f:
+            line = line.strip()
+
+            if first:
+                first = False
+                location_url = line
+                continue
+
+            if len(line) > 3:
+                entry = nltk.word_tokenize(line)
+                if (location_url, line) not in all_data_corpus:
+                    tokenized_corpus.append(entry)
+                    all_data_corpus.append((location_url, line))
+
+    if len(tokenized_corpus) == 0:
+        print("")
+
+    bm25 = BM25Okapi(tokenized_corpus)
+    s = bm25.get_scores(nltk.word_tokenize(claim))
+    top_n = np.argsort(s)[::-1][:top_k]
+    docs = [all_data_corpus[i] for i in top_n]
+
+    generate_qa_pairs = []
+    # Then, generate questions for those top 50:
+    for doc in tqdm.tqdm(docs):
+        # prompt_lookup_str = example["claim"] + " " + doc[1]
+        prompt_lookup_str = doc[1]
+
+        prompt_s = prompt_bm25.get_scores(nltk.word_tokenize(prompt_lookup_str))
+        prompt_n = 10
+        prompt_top_n = np.argsort(prompt_s)[::-1][:prompt_n]
+        prompt_docs = [prompt_corpus[i] for i in prompt_top_n]
+
+        claim_prompt = "Evidence: " + doc[1].replace("\n", " ") + "\nQuestion answered: "
+        prompt = "\n\n".join(prompt_docs + [claim_prompt])
+        sentences = [prompt]
+
+        inputs = qg_tokenizer(sentences, padding=True, return_tensors="pt").to(device)
+        outputs = qg_model.generate(inputs["input_ids"], max_length=5000, num_beams=2, no_repeat_ngram_size=2,
+                                    early_stopping=True)
+
+        tgt_text = qg_tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
+        # We are not allowed to generate more than 250 characters:
+        tgt_text = tgt_text[:250]
+
+        qa_pair = [tgt_text.strip().split("?")[0].replace("\n", " ") + "?", doc[1].replace("\n", " "), doc[0]]
+        generate_qa_pairs.append(qa_pair)
+
+    return generate_qa_pairs
+
+
+def decorate_with_questions_michale(claim, retrieve_evidence, top_k=10):  # top_k=100
+    #
+    reference_file = "averitec/data/train.json"
+    tokenized_corpus, prompt_corpus = generate_step2_reference_corpus(reference_file)
+    prompt_bm25 = BM25Okapi(tokenized_corpus)
+
+    # Define the bloom model:
+    accelerator = Accelerator()
+    accel_device = accelerator.device
+    # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    # tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-7b1")
+    # model = BloomForCausalLM.from_pretrained(
+    #     "bigscience/bloom-7b1",
+    #     device_map="auto",
+    #     torch_dtype=torch.bfloat16,
+    #     offload_folder="./offload"
+    # )
+
+    #
+    tokenized_corpus = []
+    all_data_corpus = []
+
 for retri_evi in tqdm.tqdm(retrieve_evidence):
     store_file = retri_evi[-1]
 
@@ -1222,7 +1390,7 @@ def chat(claim, history, sources):
     try:
        # Log answer on Azure Blob Storage
        # IF AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
-        if
+        if os.environ["AZURE_ISSAVE"] == "TRUE":
            timestamp = str(datetime.now().timestamp())
            # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            file = timestamp + ".json"
utils.py
CHANGED
@@ -2,11 +2,13 @@ import numpy as np
 import random
 import string
 import uuid
+from datetime import datetime
 
 
 def create_user_id():
     """Create user_id
     str: String to id user
     """
+    current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
     user_id = str(uuid.uuid4())
-    return user_id
+    return current_date + '_' +user_id