Spaces:
Runtime error
Runtime error
pminervini
commited on
Commit
•
590fea3
1
Parent(s):
7fe1886
update
Browse files
src/backend/tasks/cnndm/task.py
CHANGED
@@ -69,6 +69,7 @@ class CnnDm(Task):
|
|
69 |
super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
|
70 |
self.factkb_tokenizer = None
|
71 |
self.factkb_model = None
|
|
|
72 |
|
73 |
def maybe_init_factkb(self):
|
74 |
if self.factkb_tokenizer is None or self.factkb_model is None:
|
@@ -76,6 +77,11 @@ class CnnDm(Task):
|
|
76 |
self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
|
77 |
self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2, device_map="auto")
|
78 |
|
|
|
|
|
|
|
|
|
|
|
79 |
def has_training_docs(self):
|
80 |
return True
|
81 |
|
@@ -153,11 +159,17 @@ class CnnDm(Task):
|
|
153 |
factkb_logits = self.factkb_model(**factkb_tokens).logits
|
154 |
factkb_res = torch.softmax(factkb_logits, dim=1)
|
155 |
|
|
|
|
|
|
|
156 |
res = {
|
157 |
"rouge1": rouge1_scores[0],
|
158 |
"rouge2": rouge2_scores[0],
|
159 |
"rougeL": rougeL_scores[0],
|
160 |
-
"factKB": float(factkb_res[0][1])
|
|
|
|
|
|
|
161 |
}
|
162 |
|
163 |
return res
|
|
|
69 |
super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
|
70 |
self.factkb_tokenizer = None
|
71 |
self.factkb_model = None
|
72 |
+
self.bert_score = None
|
73 |
|
74 |
def maybe_init_factkb(self):
|
75 |
if self.factkb_tokenizer is None or self.factkb_model is None:
|
|
|
77 |
self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
|
78 |
self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2, device_map="auto")
|
79 |
|
80 |
+
def maybe_init_bertscore(self):
|
81 |
+
if self.bert_score is None:
|
82 |
+
from evaluate import load
|
83 |
+
self.bert_score = load("bertscore")
|
84 |
+
|
85 |
def has_training_docs(self):
|
86 |
return True
|
87 |
|
|
|
159 |
factkb_logits = self.factkb_model(**factkb_tokens).logits
|
160 |
factkb_res = torch.softmax(factkb_logits, dim=1)
|
161 |
|
162 |
+
self.maybe_init_bertscore()
|
163 |
+
bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], model_type="microsoft/deberta-xlarge-mnli", lang="en")
|
164 |
+
|
165 |
res = {
|
166 |
"rouge1": rouge1_scores[0],
|
167 |
"rouge2": rouge2_scores[0],
|
168 |
"rougeL": rougeL_scores[0],
|
169 |
+
"factKB": float(factkb_res[0][1]),
|
170 |
+
"bertscore_precision": float(bert_score_res["precision"][0]),
|
171 |
+
"bertscore_recall": float(bert_score_res["recall"][0]),
|
172 |
+
"bertscore_f1": float(bert_score_res["f1"][0])
|
173 |
}
|
174 |
|
175 |
return res
|
src/backend/tasks/xsum/task.py
CHANGED
@@ -153,7 +153,7 @@ class XSum(Task):
|
|
153 |
factkb_res = torch.softmax(factkb_logits, dim=1)
|
154 |
|
155 |
self.maybe_init_bertscore()
|
156 |
-
bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], lang="en")
|
157 |
|
158 |
res = {
|
159 |
"rouge1": rouge1_scores[0],
|
|
|
153 |
factkb_res = torch.softmax(factkb_logits, dim=1)
|
154 |
|
155 |
self.maybe_init_bertscore()
|
156 |
+
bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], model_type="microsoft/deberta-xlarge-mnli", lang="en")
|
157 |
|
158 |
res = {
|
159 |
"rouge1": rouge1_scores[0],
|