zhenyundeng committed
Commit 0db7c43
1 Parent(s): 0c5727b
Files changed (1):
  1. app.py  +13 -5
app.py CHANGED
@@ -6,6 +6,7 @@ from fastapi import FastAPI
 from pydantic import BaseModel
 # from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
 import uvicorn
+# import spaces
 
 app = FastAPI()
 
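The only code change in this hunk is the commented-out `# import spaces`. Together with the `# @spaces.GPU` markers added further down, it points at the Hugging Face `spaces` package used on ZeroGPU Spaces, where GPU-bound work is wrapped in a decorated function so a GPU is attached only while that function runs. A minimal sketch of how this could be wired up if the comments were enabled, reusing the `veracity_model` and `device` globals defined later in app.py (the `run_veracity` helper is hypothetical, not part of this commit):

import spaces  # Hugging Face Spaces SDK; only meaningful when the app runs on ZeroGPU hardware
import torch

@spaces.GPU  # request a GPU only for the duration of this call
def run_veracity(tokenized_strings, attention_mask):
    logits = veracity_model(tokenized_strings.to(device),
                            attention_mask=attention_mask.to(device)).logits
    return torch.argmax(logits, dim=1)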
 
@@ -35,6 +36,7 @@ wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
 
 import nltk
 nltk.download('punkt')
+nltk.download('punkt_tab')
 from nltk import pos_tag, word_tokenize, sent_tokenize
 
 import spacy
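The added `nltk.download('punkt_tab')` tracks newer NLTK releases, where the Punkt tokenizer data ships as the `punkt_tab` resource; without it, `sent_tokenize` and `word_tokenize` raise a missing-resource LookupError on recent NLTK versions, while older versions still use `punkt`. A minimal check of the tokenizers this file imports:

import nltk
nltk.download('punkt')      # pickled Punkt models used by older NLTK releases
nltk.download('punkt_tab')  # tabular Punkt data required by newer NLTK releases

from nltk import sent_tokenize, word_tokenize

text = "The claim was made in 2020. It has not been verified."
print(sent_tokenize(text))      # ['The claim was made in 2020.', 'It has not been verified.']
print(word_tokenize(text)[:4])  # ['The', 'claim', 'was', 'made']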
@@ -74,15 +76,19 @@ LABEL = [
 ]
 
 # Veracity
+
 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
 veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
+# veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
 veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
+
 # Justification
 justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
 bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
-best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
+# justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
 justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
 # ---------------------------------------------------------------------------
 
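This hunk keeps the hard-coded `.to('cuda')` calls only as comments and routes everything through the `device` fallback, so the checkpoints still load on CPU-only hardware. A minimal sketch of that pattern with a plain Hugging Face classifier standing in for the project's Lightning modules (the example input string is made up):

import torch
from transformers import BertTokenizer, BertForSequenceClassification

device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4, problem_type="single_label_classification"
).to(device)

# Inputs must live on the same device as the model weights before the forward pass.
batch = tokenizer(["An example claim paired with its evidence."],
                  return_tensors="pt", padding=True, truncation=True).to(device)
with torch.no_grad():
    logits = model(**batch).logits  # shape: (1, 4)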
 
@@ -259,7 +265,7 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
             + bool_explanation
         )
 
-
+# @spaces.GPU
 def veracity_prediction(claim, evidence):
     dataLoader = SequenceClassificationDataLoader(
         tokenizer=veracity_tokenizer,
@@ -277,8 +283,8 @@ def veracity_prediction(claim, evidence):
         return pred_label
 
     tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
-    example_support = torch.argmax(
-        veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+    # example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
+    example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
 
     has_unanswerable = False
     has_true = False
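The rewritten `example_support` line folds the old two-line statement into a single batched call: one argmax over the classifier logits per tokenized claim-evidence string. A tiny self-contained example of that step (the logits values are invented; `dim=1` here matches the `axis=1` spelling used in the diff):

import torch

# Fake logits for 3 evidence strings over 4 veracity classes.
logits = torch.tensor([[2.1, 0.3, -1.0, 0.5],
                       [0.1, 1.9, 0.2, -0.4],
                       [0.0, 0.2, 0.1, 2.5]])

example_support = torch.argmax(logits, dim=1)
print(example_support)  # tensor([0, 1, 3])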
@@ -335,11 +341,12 @@ def extract_claim_str(claim, evidence, verdict_label):
 
     return claim_str
 
-
+# @spaces.GPU
 def justification_generation(claim, evidence, verdict_label):
     #
     claim_str = extract_claim_str(claim, evidence, verdict_label)
     claim_str.strip()
+    # pred_justification = justification_model.generate(claim_str, device='cuda')
     pred_justification = justification_model.generate(claim_str, device=device)
 
     return pred_justification.strip()
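`justification_model.generate(claim_str, device=device)` goes through the project's JustificationGenerationModule wrapper, which this diff does not show. Purely as an illustration of the Hugging Face API that such a BART wrapper builds on, a generic conditional-generation call looks roughly like this (prompt content and generation settings are assumptions, not taken from the project):

import torch
from transformers import BartTokenizer, BartForConditionalGeneration

device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large").to(device)

# Placeholder prompt; in app.py the real string comes from extract_claim_str().
claim_str = "Claim, retrieved evidence and predicted verdict concatenated into one prompt."
inputs = tokenizer(claim_str, return_tensors="pt", truncation=True, max_length=1024).to(device)
with torch.no_grad():
    output_ids = model.generate(**inputs, num_beams=2, max_length=100)
justification = tokenizer.decode(output_ids[0], skip_special_tokens=True)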
@@ -362,6 +369,7 @@ def log_on_azure(file, logs, azure_share_client):
     file_client.upload_file(logs)
 
 
+# @spaces.GPU
 @app.post("/predict/")
 def fact_checking(item: Item):
     # claim = item['claim']
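For context on the endpoint the commented `# @spaces.GPU` would decorate: `fact_checking` is a regular FastAPI POST handler that receives a pydantic `Item`. A stripped-down sketch of that shape; the `claim` field is inferred from the commented `# claim = item['claim']` line, and the response keys are illustrative only:

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class Item(BaseModel):
    claim: str  # field name inferred from the commented-out access in fact_checking

@app.post("/predict/")
def fact_checking(item: Item):
    # In app.py this is where evidence retrieval, veracity_prediction and
    # justification_generation are chained together.
    return {"claim": item.claim, "verdict": "Not checked", "justification": ""}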
 