zhenyundeng committed on
Commit
dd80156
1 Parent(s): d8442ad

update app.py

.git-credentials ADDED
@@ -0,0 +1 @@
+ https://your_access_token:x-oauth-basic@huggingface.co
.gitattributes CHANGED
@@ -25,7 +25,6 @@
  *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.json filter=lfs diff=lfs merge=lfs -text
+ *.db filter=lfs diff=lfs merge=lfs -text
+
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .env
+ __pycache__/app.cpython-38.pyc
+ __pycache__/app.cpython-39.pyc
+ __pycache__/utils.cpython-38.pyc
+
+ notebooks/
+ *.pyc
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Averitec Api Gpu
- emoji: 🌍
+ emoji: 🏆
  colorFrom: indigo
- colorTo: gray
+ colorTo: yellow
  sdk: gradio
  sdk_version: 4.43.0
  app_file: app.py
app.py ADDED
@@ -0,0 +1,464 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ # Created by zd302 at 17/07/2024
+
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ # from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
+ import uvicorn
+ import spaces
+
+ app = FastAPI()
+
+ # ---------------------------------------------------------------------------------------------------------------------
+ import gradio as gr
+ import os
+ import torch
+ import json
+ import numpy as np
+ import requests
+ from rank_bm25 import BM25Okapi
+ from bs4 import BeautifulSoup
+ from datetime import datetime
+
+ from transformers import BartTokenizer, BartForConditionalGeneration
+ from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
+ import pytorch_lightning as pl
+
+ from averitec.models.DualEncoderModule import DualEncoderModule
+ from averitec.models.SequenceClassificationModule import SequenceClassificationModule
+ from averitec.models.JustificationGenerationModule import JustificationGenerationModule
+
+ # ---------------------------------------------------------------------------------------------------------------------
+ import wikipediaapi
+ wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
+
+ import nltk
+ nltk.download('punkt')
+ nltk.download('punkt_tab')
+ from nltk import pos_tag, word_tokenize, sent_tokenize
+
+ import spacy
+ os.system("python -m spacy download en_core_web_sm")
+ nlp = spacy.load("en_core_web_sm")
+
+ # ---------------------------------------------------------------------------------------------------------------------
+ # load .env
+ from utils import create_user_id
+ user_id = create_user_id()
+
+ from azure.storage.fileshare import ShareServiceClient
+ try:
+     from dotenv import load_dotenv
+     load_dotenv()
+ except Exception:
+     pass
+
+ account_url = os.environ["AZURE_ACCOUNT_URL"]
+ credential = {
+     "account_key": os.environ['AZURE_ACCOUNT_KEY'],
+     "account_name": os.environ['AZURE_ACCOUNT_NAME']
+ }
+
+ file_share_name = "averitec"
+ azure_service = ShareServiceClient(account_url=account_url, credential=credential)
+ azure_share_client = azure_service.get_share_client(file_share_name)
+
+ # ---------- Setting ----------
+ # ---------- Load Veracity and Justification prediction model ----------
+ LABEL = [
+     "Supported",
+     "Refuted",
+     "Not Enough Evidence",
+     "Conflicting Evidence/Cherrypicking",
+ ]
+
+ if torch.cuda.is_available():
+     # Veracity
+     # device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+     bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
+     veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
+     veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path, tokenizer=veracity_tokenizer, model=bert_model)
+     # veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path, tokenizer=veracity_tokenizer, model=bert_model).to(device)
+
+     # Justification
+     justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
+     bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+     best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+     justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model)
+     # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
+ # ---------------------------------------------------------------------------
+
+
+ # ----------------------------------------------------------------------------
+ class Docs:
+     def __init__(self, metadata=dict(), page_content=""):
+         self.metadata = metadata
+         self.page_content = page_content
+
+
+ # ------------------------------ Googleretriever -----------------------------
+ def Googleretriever(claim):
+     # Google-based retrieval is not implemented in this commit; return an empty
+     # evidence list so veracity_prediction falls back to "Not Enough Evidence".
+     return []
+ # ------------------------------ Googleretriever -----------------------------
+
+
+ # ------------------------------ Wikipediaretriever --------------------------
+ def search_entity_wikipeida(entity):
+     find_evidence = []
+
+     page_py = wiki_wiki.page(entity)
+     if page_py.exists():
+         introduction = page_py.summary
+         find_evidence.append([str(entity), introduction])
+
+     return find_evidence
+
+
+ def clean_str(p):
+     return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
+
+
+ def find_similar_wikipedia(entity, relevant_wikipages):
+     # If fewer than five relevant Wikipedia pages were found for the entity, search for similar pages.
+     ent_ = entity.replace(" ", "+")
+     search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
+     response_text = requests.get(search_url).text
+     soup = BeautifulSoup(response_text, features="html.parser")
+     result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
+
+     if result_divs:
+         result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
+         similar_titles = result_titles[:5]
+
+         saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
+         for _t in similar_titles:
+             if _t not in saved_titles and len(relevant_wikipages) < 5:
+                 _evi = search_entity_wikipeida(_t)
+                 # _evi = search_step(_t)
+                 relevant_wikipages.extend(_evi)
+
+     return relevant_wikipages
+
+
+ def find_evidence_from_wikipedia(claim):
+     doc = nlp(claim)
+     wikipedia_page = []
+     for ent in doc.ents:
+         relevant_wikipages = search_entity_wikipeida(ent)
+
+         if len(relevant_wikipages) < 5:
+             relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
+
+         wikipedia_page.extend(relevant_wikipages)
+
+     return wikipedia_page
+
+
+ def bm25_retriever(query, corpus, topk=3):
+     bm25 = BM25Okapi(corpus)
+     query_tokens = word_tokenize(query)
+     scores = bm25.get_scores(query_tokens)
+     top_n = np.argsort(scores)[::-1][:topk]
+     top_n_scores = [scores[i] for i in top_n]
+
+     return top_n, top_n_scores
+
+
+ def relevant_sentence_retrieval(query, wiki_intro, k):
+     # 1. Create corpus here
+     corpus, sentences = [], []
+     titles = []
+     for i, (title, intro) in enumerate(wiki_intro):
+         sents_in_intro = sent_tokenize(intro)
+
+         for sent in sents_in_intro:
+             corpus.append(word_tokenize(sent))
+             sentences.append(sent)
+             titles.append(title)
+
+     # ----- BM25
+     bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
+     bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
+     bm25_top_n_titles = [titles[i] for i in bm25_top_n]
+
+     return bm25_top_n_sents, bm25_top_n_titles
+ # ------------------------------ Wikipediaretriever -----------------------------
+
+
+ def Wikipediaretriever(claim):
+     # 1. extract relevant wikipedia pages from wikipedia dumps
+     wikipedia_page = find_evidence_from_wikipedia(claim)
+
+     # 2. extract relevant sentences from extracted wikipedia pages
+     sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
+
+     results = []
+     for i, (sent, title) in enumerate(zip(sents, titles)):
+         metadata = dict()
+         metadata['name'] = claim
+         metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+         metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+         metadata['short_name'] = "Evidence {}".format(i + 1)
+         metadata['page_number'] = ""
+         metadata['query'] = sent
+         metadata['title'] = title
+         metadata['evidence'] = sent
+         metadata['answer'] = ""
+         metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
+         page_content = f"""{metadata['page_content']}"""
+
+         results.append(Docs(metadata, page_content))
+
+     return results
+
+
+ # ------------------------------ Veracity Prediction ------------------------------
+ class SequenceClassificationDataLoader(pl.LightningDataModule):
+     def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.data_file = data_file
+         self.batch_size = batch_size
+         self.add_extra_nee = add_extra_nee
+
+     def tokenize_strings(
+         self,
+         source_sentences,
+         max_length=400,
+         pad_to_max_length=False,
+         return_tensors="pt",
+     ):
+         encoded_dict = self.tokenizer(
+             source_sentences,
+             max_length=max_length,
+             padding="max_length" if pad_to_max_length else "longest",
+             truncation=True,
+             return_tensors=return_tensors,
+         )
+
+         input_ids = encoded_dict["input_ids"]
+         attention_masks = encoded_dict["attention_mask"]
+
+         return input_ids, attention_masks
+
+     def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
+         if bool_explanation is not None and len(bool_explanation) > 0:
+             bool_explanation = ", because " + bool_explanation.lower().strip()
+         else:
+             bool_explanation = ""
+         return (
+             "[CLAIM] "
+             + claim.strip()
+             + " [QUESTION] "
+             + question.strip()
+             + " "
+             + answer.strip()
+             + bool_explanation
+         )
+
+
+ @spaces.GPU
+ def veracity_prediction(claim, evidence):
+     dataLoader = SequenceClassificationDataLoader(
+         tokenizer=veracity_tokenizer,
+         data_file="this_is_discontinued",
+         batch_size=32,
+         add_extra_nee=False,
+     )
+
+     evidence_strings = []
+     for evi in evidence:
+         evidence_strings.append(dataLoader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], ""))
+
+     if len(evidence_strings) == 0:  # If we found no evidence, e.g. because Google returned 0 pages, just output NEI.
+         pred_label = "Not Enough Evidence"
+         return pred_label
+
+     tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
+     example_support = torch.argmax(veracity_model(tokenized_strings.to(veracity_model.device), attention_mask=attention_mask.to(veracity_model.device)).logits, axis=1)
+     # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+
+     has_unanswerable = False
+     has_true = False
+     has_false = False
+
+     for v in example_support:
+         if v == 0:
+             has_true = True
+         if v == 1:
+             has_false = True
+         if v in (2, 3):  # TODO another hack -- we can't have different labels for train and test, so we do this
+             has_unanswerable = True
+
+     if has_unanswerable:
+         answer = 2
+     elif has_true and not has_false:
+         answer = 0
+     elif not has_true and has_false:
+         answer = 1
+     else:
+         answer = 3
+
+     pred_label = LABEL[answer]
+
+     return pred_label
+
+
+ # ------------------------------ Justification Generation ------------------------------
+ def extract_claim_str(claim, evidence, verdict_label):
+     claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
+
+     for evi in evidence:
+         q_text = evi.metadata['query'].strip()
+
+         if len(q_text) == 0:
+             continue
+
+         if not q_text[-1] == "?":
+             q_text += "?"
+
+         answer_strings = []
+         answer_strings.append(evi.metadata['answer'])
+
+         claim_str += q_text
+         for a_text in answer_strings:
+             if a_text:
+                 if not a_text[-1] == ".":
+                     a_text += "."
+                 claim_str += " " + a_text.strip()
+
+         claim_str += " "
+
+     claim_str += " [VERDICT] " + verdict_label
+
+     return claim_str
+
+
+ @spaces.GPU
+ def justification_generation(claim, evidence, verdict_label):
+     claim_str = extract_claim_str(claim, evidence, verdict_label)
+     claim_str = claim_str.strip()
+     pred_justification = justification_model.generate(claim_str, device=justification_model.device)
+     # pred_justification = justification_model.generate(claim_str, device=device)
+
+     return pred_justification.strip()
+
+
+ # ---------------------------------------------------------------------------------------------------------------------
+ class Item(BaseModel):
+     claim: str
+     source: str
+
+
+ @app.get("/")
+ @spaces.GPU
+ def greet_json():
+     return {"Hello": "World!"}
+
+
+ def log_on_azure(file, logs, azure_share_client):
+     logs = json.dumps(logs)
+     file_client = azure_share_client.get_file_client(file)
+     file_client.upload_file(logs)
+
+
+ @app.post("/predict/")
+ @spaces.GPU
+ def fact_checking(item: Item):
+     # claim = item['claim']
+     # source = item['source']
+     claim = item.claim
+     source = item.source
+
+     # Step 1: Evidence Retrieval
+     if source == "Wikipedia":
+         evidence = Wikipediaretriever(claim)
+     elif source == "Google":
+         evidence = Googleretriever(claim)
+     else:
+         evidence = []  # unknown source: degrade gracefully to "Not Enough Evidence"
+
+     # Step 2: Veracity Prediction and Justification Generation
+     verdict_label = veracity_prediction(claim, evidence)
+     justification_label = justification_generation(claim, evidence, verdict_label)
+
+     evidence_list = []
+     for evi in evidence:
+         title_str = evi.metadata['title']
+         evi_str = evi.metadata['evidence']
+         url_str = evi.metadata['url']
+         evidence_list.append([title_str, evi_str, url_str])
+
+     try:
+         # If AZURE_ISSAVE=TRUE, save the logs into the Azure share client.
+         if os.environ["AZURE_ISSAVE"] == "TRUE":
+             timestamp = str(datetime.now().timestamp())
+             # timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+             file = timestamp + ".json"
+             logs = {
+                 "user_id": str(user_id),
+                 "claim": claim,
+                 "sources": source,
+                 "evidence": evidence_list,
+                 "answer": [verdict_label, justification_label],
+                 "time": timestamp,
+             }
+             log_on_azure(file, logs, azure_share_client)
+     except Exception as e:
+         print(f"Error logging on Azure Blob Storage: {e}")
+         raise gr.Error(
+             f"AVeriTeC Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
+
+     return {"Verdict": verdict_label, "Justification": justification_label, "Evidence": evidence_list}
+
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
+
+
+ # if __name__ == "__main__":
+ #     item = {
+ #         "claim": "England won the Euro 2024.",
+ #         "source": "Wikipedia",
+ #     }
+ #
+ #     results = fact_checking(item)
+ #
+ #     print(results)
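Once the Space is up, the `/predict/` endpoint can be exercised directly. A minimal client sketch, assuming a local run on the host/port from the `uvicorn.run` call above; the payload fields mirror the `Item` model:

import requests

payload = {"claim": "England won the Euro 2024.", "source": "Wikipedia"}
resp = requests.post("http://localhost:7860/predict/", json=payload)  # hypothetical local URL
resp.raise_for_status()
result = resp.json()  # {"Verdict": ..., "Justification": ..., "Evidence": [[title, sentence, url], ...]}
print(result["Verdict"])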
averitec/data/all_samples.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef79bab962c2b17d56eb2582b9919bfe8023858fa13ba20c591900857b561854
+ size 11444395
averitec/data/sample_claims.py ADDED
@@ -0,0 +1,39 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ # Created by zd302 at 09/05/2024
+
+
+ CLAIMS_Type = {
+     "Claim": [
+         "England won the Euro 2024.",
+         "Albert Einstein works in the field of computer science.",
+     ],
+     "Event/Property Claim": [
+         'Hunter Biden had no experience in Ukraine or in the energy sector when he joined the board of Burisma.',
+         "After the police shooting of Jacob Blake, Gov. Tony Evers & Lt. Gov. Mandela Barnes did not call for peace or encourage calm.",
+         "President Trump fully co-operated with the investigation into Russian interference in the 2016 U.S presidential campaign.",
+     ],
+     "Causal Claim": [
+         "Anxiety levels among young teenagers dropped during the coronavirus pandemic, a study has suggested",
+         "Auto workers across Michigan could have lost their jobs if not for Barack Obama and Joe Biden",
+     ],
+     "Numerical Claim": [
+         "Sweden, despite never having had lockdown, has a lower COVID-19 death rate than Spain, Italy, and the United Kingdom.",
+         "According to Harry Roque, even if 10,000 people die, 10 million COVID-19 cases in the country will not be a loss.",
+     ]
+ }
+
+ CLAIMS_FACT_CHECKING_STRATEGY = {
+     "Written Evidence": [
+         "Pretty Little Thing's terms and conditions state that its products may contain chemicals that can cause cancer, birth defects or other reproductive harm.",
+         "Pretty Little Thing products may contain chemicals that can cause cancer, birth defects or other reproductive harm.",
+     ],
+     "Numerical Comparison": [
+         "Congress party claims regarding shortfall in Government earnings",
+         "On average, one person dies by suicide every 22 hours in West Virginia, United States.",
+     ],
+     "Consultation": [
+         "Your reaction to an optical illusion is an indication of your state of mind.",
+         "The last time people created a Hollywood blacklist, people ended up killing themselves. They were accused, and they lost their right to work.",
+     ]
+ }
averitec/data/train.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae5eda7c42ddf1695ef185a7ba1bc716928f5adf57103e4f78aae5f9afe00f9c
+ size 10184813
averitec/models/AveritecModule.py ADDED
@@ -0,0 +1,312 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ # Created by zd302 at 17/07/2024
+
+ import torch
+ import numpy as np
+ import requests
+ from rank_bm25 import BM25Okapi
+ from bs4 import BeautifulSoup
+
+ from transformers import BartTokenizer, BartForConditionalGeneration
+ from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
+ import pytorch_lightning as pl
+
+ from averitec.models.DualEncoderModule import DualEncoderModule
+ from averitec.models.SequenceClassificationModule import SequenceClassificationModule
+ from averitec.models.JustificationGenerationModule import JustificationGenerationModule
+
+
+ import wikipediaapi
+ wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
+ import os
+ import nltk
+ nltk.download('punkt')
+ from nltk import pos_tag, word_tokenize, sent_tokenize
+
+ import spacy
+ os.system("python -m spacy download en_core_web_sm")
+ nlp = spacy.load("en_core_web_sm")
+
+ # ---------- Load Veracity and Justification prediction model ----------
+ LABEL = [
+     "Supported",
+     "Refuted",
+     "Not Enough Evidence",
+     "Conflicting Evidence/Cherrypicking",
+ ]
+
+ # Veracity
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
+ veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
+                                                                    tokenizer=veracity_tokenizer, model=bert_model).to(device)
+ # Justification
+ justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
+ bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+ best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+ justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
+ # ---------------------------------------------------------------------------
+
+
+ # ----------------------------------------------------------------------------
+ class Docs:
+     def __init__(self, metadata=dict(), page_content=""):
+         self.metadata = metadata
+         self.page_content = page_content
+
+
+ # ------------------------------ Googleretriever -----------------------------
+ def Googleretriever(claim):
+     # Google-based retrieval is not implemented in this commit; return an empty
+     # evidence list so veracity_prediction falls back to "Not Enough Evidence".
+     return []
+ # ------------------------------ Googleretriever -----------------------------
+
+
+ # ------------------------------ Wikipediaretriever --------------------------
+ def search_entity_wikipeida(entity):
+     find_evidence = []
+
+     page_py = wiki_wiki.page(entity)
+     if page_py.exists():
+         introduction = page_py.summary
+         find_evidence.append([str(entity), introduction])
+
+     return find_evidence
+
+
+ def clean_str(p):
+     return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
+
+
+ def find_similar_wikipedia(entity, relevant_wikipages):
+     # If fewer than five relevant Wikipedia pages were found for the entity, search for similar pages.
+     ent_ = entity.replace(" ", "+")
+     search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
+     response_text = requests.get(search_url).text
+     soup = BeautifulSoup(response_text, features="html.parser")
+     result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
+
+     if result_divs:
+         result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
+         similar_titles = result_titles[:5]
+
+         saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
+         for _t in similar_titles:
+             if _t not in saved_titles and len(relevant_wikipages) < 5:
+                 _evi = search_entity_wikipeida(_t)
+                 # _evi = search_step(_t)
+                 relevant_wikipages.extend(_evi)
+
+     return relevant_wikipages
+
+
+ def find_evidence_from_wikipedia(claim):
+     doc = nlp(claim)
+     wikipedia_page = []
+     for ent in doc.ents:
+         relevant_wikipages = search_entity_wikipeida(ent)
+
+         if len(relevant_wikipages) < 5:
+             relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
+
+         wikipedia_page.extend(relevant_wikipages)
+
+     return wikipedia_page
+
+
+ def bm25_retriever(query, corpus, topk=3):
+     bm25 = BM25Okapi(corpus)
+     query_tokens = word_tokenize(query)
+     scores = bm25.get_scores(query_tokens)
+     top_n = np.argsort(scores)[::-1][:topk]
+     top_n_scores = [scores[i] for i in top_n]
+
+     return top_n, top_n_scores
+
+
+ def relevant_sentence_retrieval(query, wiki_intro, k):
+     # 1. Create corpus here
+     corpus, sentences = [], []
+     titles = []
+     for i, (title, intro) in enumerate(wiki_intro):
+         sents_in_intro = sent_tokenize(intro)
+
+         for sent in sents_in_intro:
+             corpus.append(word_tokenize(sent))
+             sentences.append(sent)
+             titles.append(title)
+
+     # ----- BM25
+     bm25_top_n, bm25_top_n_scores = bm25_retriever(query, corpus, topk=k)
+     bm25_top_n_sents = [sentences[i] for i in bm25_top_n]
+     bm25_top_n_titles = [titles[i] for i in bm25_top_n]
+
+     return bm25_top_n_sents, bm25_top_n_titles
+ # ------------------------------ Wikipediaretriever -----------------------------
+
+
+ def Wikipediaretriever(claim):
+     # 1. extract relevant wikipedia pages from wikipedia dumps
+     wikipedia_page = find_evidence_from_wikipedia(claim)
+
+     # 2. extract relevant sentences from extracted wikipedia pages
+     sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)
+
+     results = []
+     for i, (sent, title) in enumerate(zip(sents, titles)):
+         metadata = dict()
+         metadata['name'] = claim
+         metadata['url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+         metadata['cached_source_url'] = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
+         metadata['short_name'] = "Evidence {}".format(i + 1)
+         metadata['page_number'] = ""
+         metadata['query'] = sent
+         metadata['title'] = title
+         metadata['evidence'] = sent
+         metadata['answer'] = ""
+         metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
+         page_content = f"""{metadata['page_content']}"""
+
+         results.append(Docs(metadata, page_content))
+
+     return results
+
+
+ # ------------------------------ Veracity Prediction ------------------------------
+ class SequenceClassificationDataLoader(pl.LightningDataModule):
+     def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.data_file = data_file
+         self.batch_size = batch_size
+         self.add_extra_nee = add_extra_nee
+
+     def tokenize_strings(
+         self,
+         source_sentences,
+         max_length=400,
+         pad_to_max_length=False,
+         return_tensors="pt",
+     ):
+         encoded_dict = self.tokenizer(
+             source_sentences,
+             max_length=max_length,
+             padding="max_length" if pad_to_max_length else "longest",
+             truncation=True,
+             return_tensors=return_tensors,
+         )
+
+         input_ids = encoded_dict["input_ids"]
+         attention_masks = encoded_dict["attention_mask"]
+
+         return input_ids, attention_masks
+
+     def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
+         if bool_explanation is not None and len(bool_explanation) > 0:
+             bool_explanation = ", because " + bool_explanation.lower().strip()
+         else:
+             bool_explanation = ""
+         return (
+             "[CLAIM] "
+             + claim.strip()
+             + " [QUESTION] "
+             + question.strip()
+             + " "
+             + answer.strip()
+             + bool_explanation
+         )
+
+
+ def veracity_prediction(claim, evidence):
+     dataLoader = SequenceClassificationDataLoader(
+         tokenizer=veracity_tokenizer,
+         data_file="this_is_discontinued",
+         batch_size=32,
+         add_extra_nee=False,
+     )
+
+     evidence_strings = []
+     for evi in evidence:
+         evidence_strings.append(dataLoader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], ""))
+
+     if len(evidence_strings) == 0:  # If we found no evidence, e.g. because Google returned 0 pages, just output NEI.
+         pred_label = "Not Enough Evidence"
+         return pred_label
+
+     tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
+     example_support = torch.argmax(
+         veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+
+     has_unanswerable = False
+     has_true = False
+     has_false = False
+
+     for v in example_support:
+         if v == 0:
+             has_true = True
+         if v == 1:
+             has_false = True
+         if v in (2, 3):  # TODO another hack -- we can't have different labels for train and test, so we do this
+             has_unanswerable = True
+
+     if has_unanswerable:
+         answer = 2
+     elif has_true and not has_false:
+         answer = 0
+     elif not has_true and has_false:
+         answer = 1
+     else:
+         answer = 3
+
+     pred_label = LABEL[answer]
+
+     return pred_label
+
+
+ # ------------------------------ Justification Generation ------------------------------
+ def extract_claim_str(claim, evidence, verdict_label):
+     claim_str = "[CLAIM] " + claim + " [EVIDENCE] "
+
+     for evi in evidence:
+         q_text = evi.metadata['query'].strip()
+
+         if len(q_text) == 0:
+             continue
+
+         if not q_text[-1] == "?":
+             q_text += "?"
+
+         answer_strings = []
+         answer_strings.append(evi.metadata['answer'])
+
+         claim_str += q_text
+         for a_text in answer_strings:
+             if a_text:
+                 if not a_text[-1] == ".":
+                     a_text += "."
+                 claim_str += " " + a_text.strip()
+
+         claim_str += " "
+
+     claim_str += " [VERDICT] " + verdict_label
+
+     return claim_str
+
+
+ def justification_generation(claim, evidence, verdict_label):
+     claim_str = extract_claim_str(claim, evidence, verdict_label)
+     claim_str = claim_str.strip()
+     pred_justification = justification_model.generate(claim_str, device=device)
+
+     return pred_justification.strip()
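AveritecModule packages the same retrieve, classify, justify pipeline that app.py wires into FastAPI. A direct-call sketch, assuming the pretrained checkpoints listed below are present (importing the module loads both models at module level):

from averitec.models.AveritecModule import Wikipediaretriever, veracity_prediction, justification_generation

claim = "England won the Euro 2024."
evidence = Wikipediaretriever(claim)                # BM25 over Wikipedia intro sentences
verdict = veracity_prediction(claim, evidence)      # one of the four LABEL strings
justification = justification_generation(claim, evidence, verdict)
print(verdict, justification)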
averitec/models/DualEncoderModule.py ADDED
@@ -0,0 +1,143 @@
+ import pytorch_lightning as pl
+ import torch
+ from transformers.optimization import AdamW
+ import torchmetrics
+
+
+ class DualEncoderModule(pl.LightningModule):
+
+     def __init__(self, tokenizer, model, learning_rate=1e-3):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.model = model
+         self.learning_rate = learning_rate
+
+         self.train_acc = torchmetrics.Accuracy(
+             task="multiclass", num_classes=model.num_labels
+         )
+         self.val_acc = torchmetrics.Accuracy(
+             task="multiclass", num_classes=model.num_labels
+         )
+         self.test_acc = torchmetrics.Accuracy(
+             task="multiclass", num_classes=model.num_labels
+         )
+
+     def forward(self, input_ids, **kwargs):
+         return self.model(input_ids, **kwargs)
+
+     def configure_optimizers(self):
+         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
+         return optimizer
+
+     def training_step(self, batch, batch_idx):
+         pos_ids, pos_mask, neg_ids, neg_mask = batch
+
+         neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
+         neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
+
+         pos_outputs = self(
+             pos_ids,
+             attention_mask=pos_mask,
+             labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
+                 pos_ids.get_device()
+             ),
+         )
+         neg_outputs = self(
+             neg_ids,
+             attention_mask=neg_mask,
+             labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
+                 neg_ids.get_device()
+             ),
+         )
+
+         loss_scale = 1.0
+         loss = pos_outputs.loss + loss_scale * neg_outputs.loss
+
+         pos_logits = pos_outputs.logits
+         pos_preds = torch.argmax(pos_logits, axis=1)
+         self.train_acc(
+             pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
+         )
+
+         neg_logits = neg_outputs.logits
+         neg_preds = torch.argmax(neg_logits, axis=1)
+         self.train_acc(
+             neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
+         )
+
+         return {"loss": loss}
+
+     def validation_step(self, batch, batch_idx):
+         pos_ids, pos_mask, neg_ids, neg_mask = batch
+
+         neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
+         neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
+
+         pos_outputs = self(
+             pos_ids,
+             attention_mask=pos_mask,
+             labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
+                 pos_ids.get_device()
+             ),
+         )
+         neg_outputs = self(
+             neg_ids,
+             attention_mask=neg_mask,
+             labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
+                 neg_ids.get_device()
+             ),
+         )
+
+         loss_scale = 1.0
+         loss = pos_outputs.loss + loss_scale * neg_outputs.loss
+
+         pos_logits = pos_outputs.logits
+         pos_preds = torch.argmax(pos_logits, axis=1)
+         self.val_acc(
+             pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
+         )
+
+         neg_logits = neg_outputs.logits
+         neg_preds = torch.argmax(neg_logits, axis=1)
+         self.val_acc(
+             neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
+         )
+
+         self.log("val_acc", self.val_acc)
+
+         return {"loss": loss}
+
+     def test_step(self, batch, batch_idx):
+         pos_ids, pos_mask, neg_ids, neg_mask = batch
+
+         neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
+         neg_mask = neg_mask.view(-1, neg_mask.shape[-1])
+
+         pos_outputs = self(
+             pos_ids,
+             attention_mask=pos_mask,
+             labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
+                 pos_ids.get_device()
+             ),
+         )
+         neg_outputs = self(
+             neg_ids,
+             attention_mask=neg_mask,
+             labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
+                 neg_ids.get_device()
+             ),
+         )
+
+         pos_logits = pos_outputs.logits
+         pos_preds = torch.argmax(pos_logits, axis=1)
+         self.test_acc(
+             pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
+         )
+
+         neg_logits = neg_outputs.logits
+         neg_preds = torch.argmax(neg_logits, axis=1)
+         self.test_acc(
+             neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
+         )
+
+         self.log("test_acc", self.test_acc)
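All three steps above expect batches of (pos_ids, pos_mask, neg_ids, neg_mask), with several negatives packed per positive and flattened before the forward pass. A toy illustration of the expected shapes (the sizes and the 30522 vocab bound are hypothetical):

import torch

batch_size, n_neg, seq_len = 4, 3, 128
pos_ids = torch.randint(0, 30522, (batch_size, seq_len))
pos_mask = torch.ones_like(pos_ids)
neg_ids = torch.randint(0, 30522, (batch_size, n_neg, seq_len))
neg_mask = torch.ones_like(neg_ids)
# training_step flattens the negatives exactly like this:
assert neg_ids.view(-1, neg_ids.shape[-1]).shape == (batch_size * n_neg, seq_len)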
averitec/models/JustificationGenerationModule.py ADDED
@@ -0,0 +1,193 @@
+ import pytorch_lightning as pl
+ import torch
+ import numpy as np
+ import datasets
+ from transformers import MaxLengthCriteria, StoppingCriteriaList
+ from transformers.optimization import AdamW
+ import itertools
+ from averitec.models.utils import count_stats, f1_metric, pairwise_meteor
+ from torchmetrics.text.rouge import ROUGEScore
+ import torch.nn.functional as F
+ import torchmetrics
+ from torchmetrics.classification import F1Score
+
+
+ def freeze_params(model):
+     for layer in model.parameters():
+         layer.requires_grad = False
+
+
+ class JustificationGenerationModule(pl.LightningModule):
+
+     def __init__(self, tokenizer, model, learning_rate=1e-3, gen_num_beams=2, gen_max_length=100, should_pad_gen=True):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.model = model
+         self.learning_rate = learning_rate
+
+         self.gen_num_beams = gen_num_beams
+         self.gen_max_length = gen_max_length
+         self.should_pad_gen = should_pad_gen
+
+         # self.metrics = datasets.load_metric('meteor')
+
+         freeze_params(self.model.get_encoder())
+         self.freeze_embeds()
+
+     def freeze_embeds(self):
+         '''Freeze the positional embedding parameters of the model; adapted from finetune.py.'''
+         freeze_params(self.model.model.shared)
+         for d in [self.model.model.encoder, self.model.model.decoder]:
+             freeze_params(d.embed_positions)
+             freeze_params(d.embed_tokens)
+
+     # Do a forward pass through the model
+     def forward(self, input_ids, **kwargs):
+         return self.model(input_ids, **kwargs)
+
+     def configure_optimizers(self):
+         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
+         return optimizer
+
+     def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+         """
+         Shift input ids one token to the right.
+         https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py
+         """
+         shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+         shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+         shifted_input_ids[:, 0] = decoder_start_token_id
+
+         if pad_token_id is None:
+             raise ValueError("self.model.config.pad_token_id has to be defined.")
+         # replace possible -100 values in labels by `pad_token_id`
+         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+         return shifted_input_ids
+
+     def run_model(self, batch):
+         src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
+
+         decoder_input_ids = self.shift_tokens_right(
+             tgt_ids, self.tokenizer.pad_token_id, self.tokenizer.pad_token_id  # BART uses the EOS token to start generation as well. Might have to change for other models.
+         )
+
+         outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+         return outputs
+
+     def compute_loss(self, batch):
+         tgt_ids = batch[2]
+         logits = self.run_model(batch)[0]
+
+         cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
+         loss = cross_entropy(logits.view(-1, logits.shape[-1]), tgt_ids.view(-1))
+
+         return loss
+
+     def training_step(self, batch, batch_idx):
+         loss = self.compute_loss(batch)
+
+         self.log("train_loss", loss, on_epoch=True)
+
+         return {'loss': loss}
+
+     def validation_step(self, batch, batch_idx):
+         preds, loss, tgts = self.generate_and_compute_loss_and_tgts(batch)
+         if self.should_pad_gen:
+             preds = F.pad(preds, pad=(0, self.gen_max_length - preds.shape[1]), value=self.tokenizer.pad_token_id)
+
+         self.log('val_loss', loss, prog_bar=True, sync_dist=True)
+
+         return {'loss': loss, 'pred': preds, 'target': tgts}
+
+     def test_step(self, batch, batch_idx):
+         test_preds, test_loss, test_tgts = self.generate_and_compute_loss_and_tgts(batch)
+         if self.should_pad_gen:
+             test_preds = F.pad(test_preds, pad=(0, self.gen_max_length - test_preds.shape[1]), value=self.tokenizer.pad_token_id)
+
+         self.log('test_loss', test_loss, prog_bar=True, sync_dist=True)
+
+         return {'loss': test_loss, 'pred': test_preds, 'target': test_tgts}
+
+     def test_epoch_end(self, outputs):
+         self.handle_end_of_epoch_scoring(outputs, "test")
+
+     def validation_epoch_end(self, outputs):
+         self.handle_end_of_epoch_scoring(outputs, "val")
+
+     def handle_end_of_epoch_scoring(self, outputs, prefix):
+         gen = {}
+         tgt = {}
+         rouge = ROUGEScore()
+         rouge_metric = lambda x, y: rouge(x, y)["rougeL_precision"]
+         for out in outputs:
+             preds = out['pred']
+             tgts = out['target']
+
+             preds = self.do_batch_detokenize(preds)
+             tgts = self.do_batch_detokenize(tgts)
+
+             for pred, t in zip(preds, tgts):
+                 rouge_d = rouge_metric(pred, t)
+                 self.log(prefix + "_rouge", rouge_d)
+
+                 meteor_d = pairwise_meteor(pred, t)
+                 self.log(prefix + "_meteor", meteor_d)
+
+     def generate_and_compute_loss_and_tgts(self, batch):
+         src_ids = batch[0]
+         loss = self.compute_loss(batch)
+         pred_ids, _ = self.generate_for_batch(src_ids)
+
+         tgt_ids = batch[2]
+
+         return pred_ids, loss, tgt_ids
+
+     def do_batch_detokenize(self, batch):
+         tokens = self.tokenizer.batch_decode(
+             batch,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True
+         )
+
+         # Huggingface skipping of special tokens doesn't work for all models, so we do it manually as well for safety:
+         tokens = [p.replace("<pad>", "") for p in tokens]
+         tokens = [p.replace("<s>", "") for p in tokens]
+         tokens = [p.replace("</s>", "") for p in tokens]
+
+         return [t for t in tokens if t != ""]
+
+     def generate_for_batch(self, batch):
+         generated_ids = self.model.generate(
+             batch,
+             decoder_start_token_id=self.tokenizer.pad_token_id,
+             num_beams=self.gen_num_beams,
+             max_length=self.gen_max_length
+         )
+
+         generated_tokens = self.tokenizer.batch_decode(
+             generated_ids,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True
+         )
+
+         return generated_ids, generated_tokens
+
+     def generate(self, text, max_input_length=512, device=None):
+         encoded_dict = self.tokenizer(
+             [text],
+             max_length=max_input_length,
+             padding="longest",
+             truncation=True,
+             return_tensors="pt",
+             add_prefix_space=True
+         )
+
+         input_ids = encoded_dict['input_ids']
+
+         if device is not None:
+             input_ids = input_ids.to(device)
+
+         with torch.no_grad():
+             _, generated_tokens = self.generate_for_batch(input_ids)
+
+         return generated_tokens[0]
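shift_tokens_right follows the BART convention of starting decoding from the pad/EOS token rather than a dedicated BOS. A worked toy example of what it produces (token ids are made up; BART's pad id is 1):

import torch

labels = torch.tensor([[5, 6, 7, 2]])    # target ids ending in EOS (2)
shifted = labels.new_zeros(labels.shape)
shifted[:, 1:] = labels[:, :-1].clone()
shifted[:, 0] = 1                         # decoder_start_token_id = pad_token_id
print(shifted)                            # tensor([[1, 5, 6, 7]])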
averitec/models/NaiveSeqClassModule.py ADDED
@@ -0,0 +1,145 @@
+ import pytorch_lightning as pl
+ import torch
+ import numpy as np
+ import datasets
+ from transformers import MaxLengthCriteria, StoppingCriteriaList
+ from transformers.optimization import AdamW
+ import itertools
+ from utils import count_stats, f1_metric, pairwise_meteor
+ from torchmetrics.text.rouge import ROUGEScore
+ import torch.nn.functional as F
+ import torchmetrics
+ from torchmetrics.classification import F1Score
+
+
+ class NaiveSeqClassModule(pl.LightningModule):
+     # Instantiate the model
+     def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.model = model
+         self.learning_rate = learning_rate
+
+         self.train_acc = torchmetrics.Accuracy()
+         self.val_acc = torchmetrics.Accuracy()
+         self.test_acc = torchmetrics.Accuracy()
+
+         self.train_f1 = F1Score(num_classes=4, average="macro")
+         self.val_f1 = F1Score(num_classes=4, average=None)
+         self.test_f1 = F1Score(num_classes=4, average=None)
+
+         self.use_question_stance_approach = use_question_stance_approach
+
+     # Do a forward pass through the model
+     def forward(self, input_ids, **kwargs):
+         return self.model(input_ids, **kwargs)
+
+     def configure_optimizers(self):
+         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
+         return optimizer
+
+     def training_step(self, batch, batch_idx):
+         x, x_mask, y = batch
+
+         outputs = self(x, attention_mask=x_mask, labels=y)
+         logits = outputs.logits
+         loss = outputs.loss
+
+         # cross_entropy = torch.nn.CrossEntropyLoss()
+         # loss = cross_entropy(logits, y)
+
+         preds = torch.argmax(logits, axis=1)
+
+         self.train_acc(preds.cpu(), y.cpu())
+         self.train_f1(preds.cpu(), y.cpu())
+
+         self.log("train_loss", loss)
+
+         return {'loss': loss}
+
+     def training_epoch_end(self, outs):
+         self.log('train_acc_epoch', self.train_acc)
+         self.log('train_f1_epoch', self.train_f1)
+
+     def validation_step(self, batch, batch_idx):
+         x, x_mask, y = batch
+
+         outputs = self(x, attention_mask=x_mask, labels=y)
+         logits = outputs.logits
+         loss = outputs.loss
+
+         preds = torch.argmax(logits, axis=1)
+
+         if not self.use_question_stance_approach:
+             self.val_acc(preds, y)
+             self.log('val_acc_step', self.val_acc)
+             self.val_f1(preds, y)
+
+         self.log("val_loss", loss)
+
+         return {'val_loss': loss, "src": x, "pred": preds, "target": y}
+
+     def validation_epoch_end(self, outs):
+         if self.use_question_stance_approach:
+             self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
+
+         self.log('val_acc_epoch', self.val_acc)
+
+         f1 = self.val_f1.compute()
+         self.val_f1.reset()
+
+         self.log('val_f1_epoch', torch.mean(f1))
+
+         class_names = ["supported", "refuted", "nei", "conflicting"]
+         for i, c_name in enumerate(class_names):
+             self.log("val_f1_" + c_name, f1[i])
+
+     def test_step(self, batch, batch_idx):
+         x, x_mask, y = batch
+
+         outputs = self(x, attention_mask=x_mask)
+         logits = outputs.logits
+
+         preds = torch.argmax(logits, axis=1)
+
+         if not self.use_question_stance_approach:
+             self.test_acc(preds, y)
+             self.log('test_acc_step', self.test_acc)
+             self.test_f1(preds, y)
+
+         return {"src": x, "pred": preds, "target": y}
+
+     def test_epoch_end(self, outs):
+         if self.use_question_stance_approach:
+             self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
+
+         self.log('test_acc_epoch', self.test_acc)
+
+         f1 = self.test_f1.compute()
+         self.test_f1.reset()
+         self.log('test_f1_epoch', torch.mean(f1))
+
+         class_names = ["supported", "refuted", "nei", "conflicting"]
+         for i, c_name in enumerate(class_names):
+             self.log("test_f1_" + c_name, f1[i])
+
+     def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
+         gold_labels = {}
+         question_support = {}
+         for out in outputs:
+             srcs = out['src']
+             preds = out['pred']
+             tgts = out['target']
+
+             tokens = self.tokenizer.batch_decode(
+                 srcs,
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True
+             )
+
+             for src, pred, tgt in zip(tokens, preds, tgts):
+                 acc_scorer(torch.as_tensor([pred]).to("cuda:0"), torch.as_tensor([tgt]).to("cuda:0"))
+                 f1_scorer(torch.as_tensor([pred]).to("cuda:0"), torch.as_tensor([tgt]).to("cuda:0"))
averitec/models/SequenceClassificationModule.py ADDED
@@ -0,0 +1,179 @@
+ import pytorch_lightning as pl
+ import torch
+ import numpy as np
+ import datasets
+ from transformers import MaxLengthCriteria, StoppingCriteriaList
+ from transformers.optimization import AdamW
+ import itertools
+ # from utils import count_stats, f1_metric, pairwise_meteor
+ from torchmetrics.text.rouge import ROUGEScore
+ import torch.nn.functional as F
+ import torchmetrics
+ from torchmetrics.classification import F1Score
+
+
+ class SequenceClassificationModule(pl.LightningModule):
+     # Instantiate the model
+     def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.model = model
+         self.learning_rate = learning_rate
+
+         self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
+         self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
+         self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
+
+         self.train_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average="macro")
+         self.val_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average=None)
+         self.test_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average=None)
+         # self.train_acc = torchmetrics.Accuracy()
+         # self.val_acc = torchmetrics.Accuracy()
+         # self.test_acc = torchmetrics.Accuracy()
+
+         # self.train_f1 = F1Score(num_classes=4, average="macro")
+         # self.val_f1 = F1Score(num_classes=4, average=None)
+         # self.test_f1 = F1Score(num_classes=4, average=None)
+
+         self.use_question_stance_approach = use_question_stance_approach
+
+     # Do a forward pass through the model
+     def forward(self, input_ids, **kwargs):
+         return self.model(input_ids, **kwargs)
+
+     def configure_optimizers(self):
+         optimizer = AdamW(self.parameters(), lr=self.learning_rate)
+         return optimizer
+
+     def training_step(self, batch, batch_idx):
+         x, x_mask, y = batch
+
+         outputs = self(x, attention_mask=x_mask, labels=y)
+         logits = outputs.logits
+         loss = outputs.loss
+
+         # cross_entropy = torch.nn.CrossEntropyLoss()
+         # loss = cross_entropy(logits, y)
+
+         preds = torch.argmax(logits, axis=1)
+
+         self.log("train_loss", loss)
+
+         return {'loss': loss}
+
+     def validation_step(self, batch, batch_idx):
+         x, x_mask, y = batch
+
+         outputs = self(x, attention_mask=x_mask, labels=y)
+         logits = outputs.logits
+         loss = outputs.loss
+
+         preds = torch.argmax(logits, axis=1)
+
+         if not self.use_question_stance_approach:
+             self.val_acc(preds, y)
+             self.log('val_acc_step', self.val_acc)
+             self.val_f1(preds, y)
+
+         self.log("val_loss", loss)
+
+         return {'val_loss': loss, "src": x, "pred": preds, "target": y}
+
+     def validation_epoch_end(self, outs):
+         if self.use_question_stance_approach:
+             self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
+
+         self.log('val_acc_epoch', self.val_acc)
+
+         f1 = self.val_f1.compute()
+         self.val_f1.reset()
+
+         self.log('val_f1_epoch', torch.mean(f1))
+
+         class_names = ["supported", "refuted", "nei", "conflicting"]
+         for i, c_name in enumerate(class_names):
+             self.log("val_f1_" + c_name, f1[i])
+
+     def test_step(self, batch, batch_idx):
+         x, x_mask, y = batch
+
+         outputs = self(x, attention_mask=x_mask)
+         logits = outputs.logits
+
+         preds = torch.argmax(logits, axis=1)
+
+         if not self.use_question_stance_approach:
+             self.test_acc(preds, y)
+             self.log('test_acc_step', self.test_acc)
+             self.test_f1(preds, y)
+
+         return {"src": x, "pred": preds, "target": y}
+
+     def test_epoch_end(self, outs):
+         if self.use_question_stance_approach:
+             self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
+
+         self.log('test_acc_epoch', self.test_acc)
+
+         f1 = self.test_f1.compute()
+         self.test_f1.reset()
+         self.log('test_f1_epoch', torch.mean(f1))
+
+         class_names = ["supported", "refuted", "nei", "conflicting"]
+         for i, c_name in enumerate(class_names):
+             self.log("test_f1_" + c_name, f1[i])
+
+     def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
+         gold_labels = {}
+         question_support = {}
+         for out in outputs:
+             srcs = out['src']
+             preds = out['pred']
+             tgts = out['target']
+
+             tokens = self.tokenizer.batch_decode(
+                 srcs,
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True
+             )
+
+             for src, pred, tgt in zip(tokens, preds, tgts):
+                 claim_id = src.split("[ question ]")[0]
+
+                 if claim_id not in gold_labels:
+                     gold_labels[claim_id] = tgt
+                     question_support[claim_id] = []
+
+                 question_support[claim_id].append(pred)
+
+         for k, gold_label in gold_labels.items():
+             support = question_support[k]
+
+             has_unansw = False
+             has_true = False
+             has_false = False
+
+             for v in support:
+                 if v == 0:
+                     has_true = True
+                 if v == 1:
+                     has_false = True
+                 if v == 2 or v == 3:  # TODO very ugly hack -- we can't have different numbers of labels for train and test, so we do this
+                     has_unansw = True
+
+             if has_unansw:
+                 answer = 2
+             elif has_true and not has_false:
+                 answer = 0
+             elif has_false and not has_true:
+                 answer = 1
+             elif has_true and has_false:
+                 answer = 3
+
+             # TODO this is very hacky and won't work if the device is anything other than cuda:0
+             acc_scorer(torch.as_tensor([answer]).to("cuda:0"), torch.as_tensor([gold_label]).to("cuda:0"))
+             f1_scorer(torch.as_tensor([answer]).to("cuda:0"), torch.as_tensor([gold_label]).to("cuda:0"))
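The stance-aggregation rule in handle_end_of_epoch_scoring (and duplicated in veracity_prediction in app.py) maps per-question predictions onto a single claim-level verdict. A standalone restatement of the same rule, for clarity:

def aggregate(support):
    # per-question labels: 0 = supported, 1 = refuted, 2/3 = unanswerable
    has_true = any(v == 0 for v in support)
    has_false = any(v == 1 for v in support)
    has_unansw = any(v in (2, 3) for v in support)
    if has_unansw:
        return 2    # Not Enough Evidence
    if has_true and not has_false:
        return 0    # Supported
    if has_false and not has_true:
        return 1    # Refuted
    return 3        # Conflicting Evidence/Cherrypicking

assert aggregate([0, 0]) == 0 and aggregate([0, 1]) == 3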
averitec/models/utils.py ADDED
@@ -0,0 +1,119 @@
+ import os
+ import nltk
+ from nltk import word_tokenize
+ import numpy as np
+ from leven import levenshtein
+ from sklearn.cluster import DBSCAN, dbscan
+
+
+ def delete_if_exists(filepath):
+     if os.path.exists(filepath):
+         os.remove(filepath)
+
+
+ def pairwise_meteor(candidate, reference):  # TODO: this is not thread safe
+     return nltk.translate.meteor_score.single_meteor_score(word_tokenize(reference), word_tokenize(candidate))
+
+
+ def count_stats(candidate_dict, reference_dict):
+     count_match = [0 for _ in candidate_dict]
+     count_diff = [0 for _ in candidate_dict]
+
+     for i, k in enumerate(candidate_dict.keys()):
+         pred_parts = candidate_dict[k]
+         tgt_parts = reference_dict[k]
+
+         if len(pred_parts) == len(tgt_parts):
+             count_match[i] = 1
+
+         count_diff[i] = abs(len(pred_parts) - len(tgt_parts))
+
+     count_match_score = np.mean(count_match)
+     count_diff_score = np.mean(count_diff)
+
+     return {
+         "count_match_score": count_match_score,
+         "count_diff_score": count_diff_score
+     }
+
+
+ def f1_metric(candidate_dict, reference_dict, pairwise_metric):
+     all_best_p = [0 for _ in candidate_dict]
+     all_best_t = [0 for _ in candidate_dict]
+     p_unnorm = []
+
+     for i, k in enumerate(candidate_dict.keys()):
+         pred_parts = candidate_dict[k]
+         tgt_parts = reference_dict[k]
+
+         best_p_score = [0 for _ in pred_parts]
+         best_t_score = [0 for _ in tgt_parts]
+
+         for p_idx in range(len(pred_parts)):
+             for t_idx in range(len(tgt_parts)):
+                 # meteor_score = pairwise_meteor(pred_parts[p_idx], tgt_parts[t_idx])
+                 metric_score = pairwise_metric(pred_parts[p_idx], tgt_parts[t_idx])
+
+                 if metric_score > best_p_score[p_idx]:
+                     best_p_score[p_idx] = metric_score
+
+                 if metric_score > best_t_score[t_idx]:
+                     best_t_score[t_idx] = metric_score
+
+         all_best_p[i] = np.mean(best_p_score) if len(best_p_score) > 0 else 1.0
+         all_best_t[i] = np.mean(best_t_score) if len(best_t_score) > 0 else 1.0
+
+         p_unnorm.extend(best_p_score)
+
+     p_score = np.mean(all_best_p)
+     r_score = np.mean(all_best_t)
+     avg_score = (p_score + r_score) / 2
+     f1_score = 2 * p_score * r_score / (p_score + r_score + 1e-8)
+
+     p_unnorm_score = np.mean(p_unnorm)
+
+     return {
+         "p": p_score,
+         "r": r_score,
+         "avg": avg_score,
+         "f1": f1_score,
+         "p_unnorm": p_unnorm_score,
+     }
+
+
+ def edit_distance_dbscan(data):
+     # Inspired by https://scikit-learn.org/stable/faq.html#how-do-i-deal-with-string-data-or-trees-graphs
+     def lev_metric(x, y):
+         i, j = int(x[0]), int(y[0])
+         return levenshtein(data[i], data[j])
+
+     X = np.arange(len(data)).reshape(-1, 1)
+
+     clustering = dbscan(X, metric=lev_metric, eps=20, min_samples=2, algorithm='brute')
+     return clustering
+
+
+ def compute_all_pairwise_edit_distances(data):
+     X = np.empty((len(data), len(data)))
+
+     for i in range(len(data)):
+         for j in range(len(data)):
+             X[i][j] = levenshtein(data[i], data[j])
+
+     return X
+
+
+ def compute_all_pairwise_scores(src_data, tgt_data, metric):
+     X = np.empty((len(src_data), len(tgt_data)))
+
+     for i in range(len(src_data)):
+         for j in range(len(tgt_data)):
+             X[i][j] = metric(src_data[i], tgt_data[j])
+
+     return X
+
+
+ def compute_all_pairwise_meteor_scores(data):
+     X = np.empty((len(data), len(data)))
+
+     for i in range(len(data)):
+         for j in range(len(data)):
+             X[i][j] = (pairwise_meteor(data[i], data[j]) + pairwise_meteor(data[j], data[i])) / 2
+
+     return X
+
+
+ def edit_distance_custom(data, X, eps=0.5, min_samples=3):
+     clustering = DBSCAN(metric="precomputed", eps=eps, min_samples=min_samples).fit(X)
+     return clustering.labels_
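f1_metric computes a soft precision/recall over the best pairwise matches between generated and reference sets. A toy call (made-up data; assumes the NLTK resources that pairwise_meteor needs, such as wordnet, are installed):

from averitec.models.utils import f1_metric, pairwise_meteor

candidates = {"claim_1": ["Who won Euro 2024?"]}
references = {"claim_1": ["Which team won Euro 2024?", "When was the final played?"]}
scores = f1_metric(candidates, references, pairwise_meteor)
print(scores["p"], scores["r"], scores["f1"])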
averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e4b7bf02daaf10b3443f4f2cbe79c3c9f10c453dfdf818a4d14e44b2b4311cf4
+ size 4876206567
averitec/pretrained_models/bert_dual_encoder.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fee6737f655f4f1dfb46cc1bb812b5eaf9a72cfc0b69d4e5c05cde27ea7b6051
+ size 1314015751
averitec/pretrained_models/bert_veracity.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8ddb8132a28ceff149904dd3ad3c3edd3e5f0c7de0169819207104a80e425c9a
+ size 1314034311
requirements.txt ADDED
@@ -0,0 +1,25 @@
+ gradio
+ nltk
+ rank_bm25
+ accelerate
+ trafilatura
+ spacy
+ pytorch_lightning
+ transformers==4.29.2
+ SentencePiece
+ datasets
+ leven
+ scikit-learn
+ pexpect
+ elasticsearch
+ torch
+ huggingface_hub
+ google-api-python-client
+ wikipedia-api
+ beautifulsoup4
+ azure-storage-file-share
+ azure-storage-blob
+ bm25s
+ PyStemmer
+ lxml_html_clean
+ spaces
utils.py ADDED
@@ -0,0 +1,14 @@
+ import numpy as np
+ import random
+ import string
+ import uuid
+ from datetime import datetime
+
+
+ def create_user_id():
+     """Create a user id.
+
+     Returns:
+         str: timestamp-prefixed UUID string used to identify the user in the logs.
+     """
+     current_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+     user_id = str(uuid.uuid4())
+     return current_date + '_' + user_id
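For illustration, create_user_id concatenates a timestamp with a random UUID, so the JSON logs uploaded to the Azure share can be traced back to a session:

>>> from utils import create_user_id
>>> create_user_id()
'2024-07-17-10-30-00_9f1c2a3b-...'   # illustrative shape, not a real value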