zhenyundeng commited on
Commit
dcdc5f5
·
1 Parent(s): 075c300
Files changed (2) hide show
  1. app.py +5 -12
  2. requirements.txt +3 -3
app.py CHANGED
@@ -6,7 +6,6 @@ from fastapi import FastAPI
6
  from pydantic import BaseModel
7
  # from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
8
  import uvicorn
9
- # import spaces
10
 
11
  app = FastAPI()
12
 
@@ -75,19 +74,15 @@ LABEL = [
75
  ]
76
 
77
  # Veracity
78
-
79
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
80
  veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
81
  bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
82
  veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
83
- # veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
84
  veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
85
-
86
  # Justification
87
  justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
88
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
89
- best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
90
- # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
91
  justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
92
  # ---------------------------------------------------------------------------
93
 
@@ -264,7 +259,7 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
264
  + bool_explanation
265
  )
266
 
267
- # @spaces.GPU
268
  def veracity_prediction(claim, evidence):
269
  dataLoader = SequenceClassificationDataLoader(
270
  tokenizer=veracity_tokenizer,
@@ -282,8 +277,8 @@ def veracity_prediction(claim, evidence):
282
  return pred_label
283
 
284
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
285
- # example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
286
- example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
287
 
288
  has_unanswerable = False
289
  has_true = False
@@ -340,12 +335,11 @@ def extract_claim_str(claim, evidence, verdict_label):
340
 
341
  return claim_str
342
 
343
- # @spaces.GPU
344
  def justification_generation(claim, evidence, verdict_label):
345
  #
346
  claim_str = extract_claim_str(claim, evidence, verdict_label)
347
  claim_str.strip()
348
- # pred_justification = justification_model.generate(claim_str, device='cuda')
349
  pred_justification = justification_model.generate(claim_str, device=device)
350
 
351
  return pred_justification.strip()
@@ -368,7 +362,6 @@ def log_on_azure(file, logs, azure_share_client):
368
  file_client.upload_file(logs)
369
 
370
 
371
- # @spaces.GPU
372
  @app.post("/predict/")
373
  def fact_checking(item: Item):
374
  claim = item['claim']
 
6
  from pydantic import BaseModel
7
  # from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
8
  import uvicorn
 
9
 
10
  app = FastAPI()
11
 
 
74
  ]
75
 
76
  # Veracity
 
77
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
78
  veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
79
  bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
80
  veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
 
81
  veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
 
82
  # Justification
83
  justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
84
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
85
+ best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
 
86
  justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
87
  # ---------------------------------------------------------------------------
88
 
 
259
  + bool_explanation
260
  )
261
 
262
+
263
  def veracity_prediction(claim, evidence):
264
  dataLoader = SequenceClassificationDataLoader(
265
  tokenizer=veracity_tokenizer,
 
277
  return pred_label
278
 
279
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
280
+ example_support = torch.argmax(
281
+ veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
282
 
283
  has_unanswerable = False
284
  has_true = False
 
335
 
336
  return claim_str
337
 
338
+
339
  def justification_generation(claim, evidence, verdict_label):
340
  #
341
  claim_str = extract_claim_str(claim, evidence, verdict_label)
342
  claim_str.strip()
 
343
  pred_justification = justification_model.generate(claim_str, device=device)
344
 
345
  return pred_justification.strip()
 
362
  file_client.upload_file(logs)
363
 
364
 
 
365
  @app.post("/predict/")
366
  def fact_checking(item: Item):
367
  claim = item['claim']
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
  gradio
2
- nltk==3.8.1
3
  rank_bm25
4
  accelerate
5
  trafilatura
6
- spacy==3.7.5
7
  pytorch_lightning
8
  transformers==4.29.2
9
  datasets
@@ -20,4 +20,4 @@ azure-storage-file-share
20
  azure-storage-blob
21
  bm25s
22
  PyStemmer
23
- lxml_html_clean
 
1
  gradio
2
+ nltk
3
  rank_bm25
4
  accelerate
5
  trafilatura
6
+ spacy
7
  pytorch_lightning
8
  transformers==4.29.2
9
  datasets
 
20
  azure-storage-blob
21
  bm25s
22
  PyStemmer
23
+ lxml_html_clean