pminervini committed
Commit
590fea3
1 Parent(s): 7fe1886
src/backend/tasks/cnndm/task.py CHANGED
@@ -69,6 +69,7 @@ class CnnDm(Task):
         super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
         self.factkb_tokenizer = None
         self.factkb_model = None
+        self.bert_score = None
 
     def maybe_init_factkb(self):
         if self.factkb_tokenizer is None or self.factkb_model is None:
@@ -76,6 +77,11 @@ class CnnDm(Task):
             self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
             self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2, device_map="auto")
 
+    def maybe_init_bertscore(self):
+        if self.bert_score is None:
+            from evaluate import load
+            self.bert_score = load("bertscore")
+
     def has_training_docs(self):
         return True
 
@@ -153,11 +159,17 @@ class CnnDm(Task):
         factkb_logits = self.factkb_model(**factkb_tokens).logits
         factkb_res = torch.softmax(factkb_logits, dim=1)
 
+        self.maybe_init_bertscore()
+        bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], model_type="microsoft/deberta-xlarge-mnli", lang="en")
+
         res = {
             "rouge1": rouge1_scores[0],
             "rouge2": rouge2_scores[0],
             "rougeL": rougeL_scores[0],
-            "factKB": float(factkb_res[0][1])
+            "factKB": float(factkb_res[0][1]),
+            "bertscore_precision": float(bert_score_res["precision"][0]),
+            "bertscore_recall": float(bert_score_res["recall"][0]),
+            "bertscore_f1": float(bert_score_res["f1"][0])
         }
 
         return res
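For reference, below is a minimal standalone sketch of the evaluate-library BERTScore usage the new code relies on. The load("bertscore") and compute(...) calls, the model_type checkpoint, and the returned precision / recall / f1 lists are taken from the diff above; the example texts are invented for illustration only.

from evaluate import load

# Load the metric once and reuse it, mirroring maybe_init_bertscore() above.
bert_score = load("bertscore")

# Hypothetical prediction/reference pair, for illustration only.
completion = "A storm hit the coast overnight, causing flooding."
gold_summary = "Overnight storms caused flooding along the coast."

out = bert_score.compute(
    predictions=[completion],
    references=[gold_summary],
    model_type="microsoft/deberta-xlarge-mnli",
    lang="en",
)

# compute() returns parallel lists, one entry per prediction.
print(float(out["precision"][0]), float(out["recall"][0]), float(out["f1"][0]))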
src/backend/tasks/xsum/task.py CHANGED
@@ -153,7 +153,7 @@ class XSum(Task):
         factkb_res = torch.softmax(factkb_logits, dim=1)
 
         self.maybe_init_bertscore()
-        bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], lang="en")
+        bert_score_res = self.bert_score.compute(predictions=[completion], references=[gold_summary], model_type="microsoft/deberta-xlarge-mnli", lang="en")
 
         res = {
             "rouge1": rouge1_scores[0],