Spaces:
Sleeping
Sleeping
zhenyundeng
commited on
Commit
·
dcdc5f5
1
Parent(s):
075c300
update
Browse files- app.py +5 -12
- requirements.txt +3 -3
app.py
CHANGED
@@ -6,7 +6,6 @@ from fastapi import FastAPI
|
|
6 |
from pydantic import BaseModel
|
7 |
# from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
|
8 |
import uvicorn
|
9 |
-
# import spaces
|
10 |
|
11 |
app = FastAPI()
|
12 |
|
@@ -75,19 +74,15 @@ LABEL = [
|
|
75 |
]
|
76 |
|
77 |
# Veracity
|
78 |
-
|
79 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
80 |
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
81 |
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
|
82 |
veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
|
83 |
-
# veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
|
84 |
veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
|
85 |
-
|
86 |
# Justification
|
87 |
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
|
88 |
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
89 |
-
best_checkpoint = os.getcwd()
|
90 |
-
# justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
|
91 |
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
92 |
# ---------------------------------------------------------------------------
|
93 |
|
@@ -264,7 +259,7 @@ class SequenceClassificationDataLoader(pl.LightningDataModule):
|
|
264 |
+ bool_explanation
|
265 |
)
|
266 |
|
267 |
-
|
268 |
def veracity_prediction(claim, evidence):
|
269 |
dataLoader = SequenceClassificationDataLoader(
|
270 |
tokenizer=veracity_tokenizer,
|
@@ -282,8 +277,8 @@ def veracity_prediction(claim, evidence):
|
|
282 |
return pred_label
|
283 |
|
284 |
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
|
285 |
-
|
286 |
-
|
287 |
|
288 |
has_unanswerable = False
|
289 |
has_true = False
|
@@ -340,12 +335,11 @@ def extract_claim_str(claim, evidence, verdict_label):
|
|
340 |
|
341 |
return claim_str
|
342 |
|
343 |
-
|
344 |
def justification_generation(claim, evidence, verdict_label):
|
345 |
#
|
346 |
claim_str = extract_claim_str(claim, evidence, verdict_label)
|
347 |
claim_str.strip()
|
348 |
-
# pred_justification = justification_model.generate(claim_str, device='cuda')
|
349 |
pred_justification = justification_model.generate(claim_str, device=device)
|
350 |
|
351 |
return pred_justification.strip()
|
@@ -368,7 +362,6 @@ def log_on_azure(file, logs, azure_share_client):
|
|
368 |
file_client.upload_file(logs)
|
369 |
|
370 |
|
371 |
-
# @spaces.GPU
|
372 |
@app.post("/predict/")
|
373 |
def fact_checking(item: Item):
|
374 |
claim = item['claim']
|
|
|
6 |
from pydantic import BaseModel
|
7 |
# from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
|
8 |
import uvicorn
|
|
|
9 |
|
10 |
app = FastAPI()
|
11 |
|
|
|
74 |
]
|
75 |
|
76 |
# Veracity
|
|
|
77 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
78 |
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
79 |
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
|
80 |
veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
|
|
|
81 |
veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
|
|
|
82 |
# Justification
|
83 |
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
|
84 |
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
85 |
+
best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
|
|
86 |
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
87 |
# ---------------------------------------------------------------------------
|
88 |
|
|
|
259 |
+ bool_explanation
|
260 |
)
|
261 |
|
262 |
+
|
263 |
def veracity_prediction(claim, evidence):
|
264 |
dataLoader = SequenceClassificationDataLoader(
|
265 |
tokenizer=veracity_tokenizer,
|
|
|
277 |
return pred_label
|
278 |
|
279 |
tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
|
280 |
+
example_support = torch.argmax(
|
281 |
+
veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
|
282 |
|
283 |
has_unanswerable = False
|
284 |
has_true = False
|
|
|
335 |
|
336 |
return claim_str
|
337 |
|
338 |
+
|
339 |
def justification_generation(claim, evidence, verdict_label):
|
340 |
#
|
341 |
claim_str = extract_claim_str(claim, evidence, verdict_label)
|
342 |
claim_str.strip()
|
|
|
343 |
pred_justification = justification_model.generate(claim_str, device=device)
|
344 |
|
345 |
return pred_justification.strip()
|
|
|
362 |
file_client.upload_file(logs)
|
363 |
|
364 |
|
|
|
365 |
@app.post("/predict/")
|
366 |
def fact_checking(item: Item):
|
367 |
claim = item['claim']
|
requirements.txt
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
gradio
|
2 |
-
nltk
|
3 |
rank_bm25
|
4 |
accelerate
|
5 |
trafilatura
|
6 |
-
spacy
|
7 |
pytorch_lightning
|
8 |
transformers==4.29.2
|
9 |
datasets
|
@@ -20,4 +20,4 @@ azure-storage-file-share
|
|
20 |
azure-storage-blob
|
21 |
bm25s
|
22 |
PyStemmer
|
23 |
-
lxml_html_clean
|
|
|
1 |
gradio
|
2 |
+
nltk
|
3 |
rank_bm25
|
4 |
accelerate
|
5 |
trafilatura
|
6 |
+
spacy
|
7 |
pytorch_lightning
|
8 |
transformers==4.29.2
|
9 |
datasets
|
|
|
20 |
azure-storage-blob
|
21 |
bm25s
|
22 |
PyStemmer
|
23 |
+
lxml_html_clean
|