pminervini committed on
Commit
19d09c1
1 Parent(s): a0d8a50
src/backend/tasks/cnndm/__pycache__/task.cpython-39.pyc DELETED
Binary file (4.27 kB)
 
src/backend/tasks/cnndm/__pycache__/utils.cpython-39.pyc DELETED
Binary file (2.81 kB)
 
src/backend/tasks/xsum/task.py CHANGED
@@ -3,6 +3,7 @@ from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_task
 from lm_eval.api.metrics import mean
 
+import torch
 import sacrebleu
 from rouge_score import rouge_scorer, scoring
 
@@ -61,11 +62,11 @@ class XSum(Task):
         self.factkb_tokenizer = None
         self.factkb_model = None
 
-    def init_factkb(self):
-        from transformers import AutoTokenizer, AutoModelForSequenceClassification
-
-        self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
-        self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2)
+    def maybe_init_factkb(self):
+        if self.factkb_tokenizer is None or self.factkb_model is None:
+            from transformers import AutoTokenizer, AutoModelForSequenceClassification
+            self.factkb_tokenizer = AutoTokenizer.from_pretrained("roberta-base", padding="max_length", truncation=True)
+            self.factkb_model = AutoModelForSequenceClassification.from_pretrained("bunsenfeng/FactKB", num_labels=2, device_map="auto")
 
     def has_training_docs(self):
         return True
@@ -114,7 +115,8 @@ class XSum(Task):
             Instance(
                 request_type="generate_until",
                 doc=doc,
-                arguments=(ctx, {"until": ["\n", "."]}),
+                # arguments=(ctx, {"until": ["\n", "."]}),
+                arguments=(ctx, {"until": ["\n"]}),
                 idx=0,
                 **kwargs
             )
@@ -123,28 +125,34 @@ class XSum(Task):
     def process_results(self, doc, results):
         completion = results[0]
 
-        # document = doc["document"]
+        document = doc["document"]
         true_refs = [doc["summary"]]
         all_refs = true_refs
 
         # ROUGE-N
         rouge_scores = [rouge([ref], [completion]) for ref in all_refs]
-
         # ROUGE-1
         rouge1_scores = [score["rouge1"] for score in rouge_scores]
-
         # ROUGE-2
         rouge2_scores = [score["rouge2"] for score in rouge_scores]
-
         # ROUGE-L
         rougeL_scores = [score["rougeLsum"] for score in rouge_scores]
 
+        self.maybe_init_factkb()
+        input_factkb = [[completion, document]]
+        factkb_tokens = self.factkb_tokenizer(input_factkb, return_tensors="pt", padding="max_length", truncation=True).to(self.factkb_model.device)
+        factkb_logits = self.factkb_model(**factkb_tokens).logits
+        factkb_res = torch.softmax(factkb_logits, dim=1)
+
         res = {
             "rouge1": rouge1_scores[0],
             "rouge2": rouge2_scores[0],
             "rougeL": rougeL_scores[0],
+            "factKB": float(factkb_res[0][1])
         }
 
+        # breakpoint()
+
         return res
 
     def aggregation(self):
@@ -153,7 +161,7 @@ class XSum(Task):
         A dictionary where keys are the names of submetrics and values are
         functions that aggregate a list of metrics
         """
-        return {k: mean for k in ["rouge1", "rouge2", "rougeL"]}
+        return {k: mean for k in ["rouge1", "rouge2", "rougeL", "factKB"]}
 
     def higher_is_better(self):
         """
@@ -161,4 +169,4 @@ class XSum(Task):
         A dictionary where keys are the names of submetrics and values are
         whether a higher value of the submetric is better
         """
-        return {k: True for k in ["rouge1", "rouge2", "rougeL"]}
+        return {k: True for k in ["rouge1", "rouge2", "rougeL", "factKB"]}
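
Note on the new metric: this commit wires a FactKB factual-consistency score into the XSum task by lazily loading the `bunsenfeng/FactKB` classifier the first time `process_results` runs, scoring each (generated summary, source document) pair, and averaging the result with `mean` alongside the ROUGE metrics. The standalone sketch below mirrors those calls so the scoring path can be checked outside the harness; the `document`/`summary` strings are illustrative placeholders, and `device_map="auto"` assumes the `accelerate` package is installed.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Mirrors maybe_init_factkb(): load the tokenizer and the FactKB classifier once.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = AutoModelForSequenceClassification.from_pretrained(
    "bunsenfeng/FactKB", num_labels=2, device_map="auto"
)

# Illustrative inputs (placeholders, not taken from the XSum dataset).
document = "The company reported record revenue of two billion euros in 2023."
summary = "The company reported record revenue in 2023."

# Mirrors process_results(): FactKB scores (summary, document) pairs; the
# softmax probability at index 1 is used as the factual-consistency score.
inputs = tokenizer(
    [[summary, document]], return_tensors="pt", padding="max_length", truncation=True
).to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
factkb_score = float(torch.softmax(logits, dim=1)[0][1])
print(f"factKB: {factkb_score:.3f}")

Guarding the load behind `maybe_init_factkb` keeps the classifier out of memory until the task actually evaluates a completion, which is presumably why the commit replaces the unconditional `init_factkb` with the lazy variant.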