lvwerra HF staff committed on
Commit
8a39ecb
1 Parent(s): e137d75

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. frugalscore.py +25 -13
  2. requirements.txt +1 -1
frugalscore.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """FrugalScore metric."""
15
 
 
 
 
16
  import datasets
17
  import torch
18
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
@@ -54,13 +57,28 @@ Examples:
54
  """
55
 
56
 
 
 
 
 
 
 
 
 
 
 
57
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
58
  class FRUGALSCORE(evaluate.Metric):
59
- def _info(self):
 
 
 
 
60
  return evaluate.MetricInfo(
61
  description=_DESCRIPTION,
62
  citation=_CITATION,
63
  inputs_description=_KWARGS_DESCRIPTION,
 
64
  features=datasets.Features(
65
  {
66
  "predictions": datasets.Value("string"),
@@ -78,26 +96,20 @@ class FRUGALSCORE(evaluate.Metric):
78
  self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
79
  self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
80
 
81
- def _compute(
82
- self,
83
- predictions,
84
- references,
85
- batch_size=32,
86
- max_length=128,
87
- device=None,
88
- ):
89
  """Returns the scores"""
90
  assert len(predictions) == len(
91
  references
92
  ), "predictions and references should have the same number of sentences."
93
- if device is not None:
94
- assert device in ["gpu", "cpu"], "device should be either gpu or cpu."
 
95
  else:
96
  device = "gpu" if torch.cuda.is_available() else "cpu"
97
  training_args = TrainingArguments(
98
  "trainer",
99
  fp16=(device == "gpu"),
100
- per_device_eval_batch_size=batch_size,
101
  report_to="all",
102
  no_cuda=(device == "cpu"),
103
  log_level="warning",
@@ -107,7 +119,7 @@ class FRUGALSCORE(evaluate.Metric):
107
 
108
  def tokenize_function(data):
109
  return self.tokenizer(
110
- data["sentence1"], data["sentence2"], max_length=max_length, truncation=True, padding=True
111
  )
112
 
113
  tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
 
13
  # limitations under the License.
14
  """FrugalScore metric."""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
  import datasets
20
  import torch
21
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
 
57
  """
58
 
59
 
60
+ @dataclass
61
+ class FRUGALSCOREConfig(evaluate.info.Config):
62
+
63
+ name: str = "default"
64
+
65
+ batch_size: int = 32
66
+ max_length: int = 128
67
+ device: Optional[str] = None
68
+
69
+
70
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
71
  class FRUGALSCORE(evaluate.Metric):
72
+
73
+ CONFIG_CLASS = FRUGALSCOREConfig
74
+ ALLOWED_CONFIG_NAMES = ["default"]
75
+
76
+ def _info(self, config):
77
  return evaluate.MetricInfo(
78
  description=_DESCRIPTION,
79
  citation=_CITATION,
80
  inputs_description=_KWARGS_DESCRIPTION,
81
+ config=config,
82
  features=datasets.Features(
83
  {
84
  "predictions": datasets.Value("string"),
 
96
  self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
97
  self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
98
 
99
+ def _compute(self, predictions, references):
 
 
 
 
 
 
 
100
  """Returns the scores"""
101
  assert len(predictions) == len(
102
  references
103
  ), "predictions and references should have the same number of sentences."
104
+ if self.config.device is not None:
105
+ assert self.config.device in ["gpu", "cpu"], "device should be either gpu or cpu."
106
+ device = self.config.device
107
  else:
108
  device = "gpu" if torch.cuda.is_available() else "cpu"
109
  training_args = TrainingArguments(
110
  "trainer",
111
  fp16=(device == "gpu"),
112
+ per_device_eval_batch_size=self.config.batch_size,
113
  report_to="all",
114
  no_cuda=(device == "cpu"),
115
  log_level="warning",
 
119
 
120
  def tokenize_function(data):
121
  return self.tokenizer(
122
+ data["sentence1"], data["sentence2"], max_length=self.config.max_length, truncation=True, padding=True
123
  )
124
 
125
  tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  torch
3
  transformers
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  torch
3
  transformers