lvwerra committed
Commit 1698f0a
1 Parent(s): 8a39ecb

Update Space (evaluate main: c447fc8e)

Files changed (2)
  1. frugalscore.py +13 -25
  2. requirements.txt +1 -1
frugalscore.py CHANGED
@@ -13,9 +13,6 @@
 # limitations under the License.
 """FrugalScore metric."""
 
-from dataclasses import dataclass
-from typing import Optional
-
 import datasets
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
@@ -57,28 +54,13 @@ Examples:
 """
 
 
-@dataclass
-class FRUGALSCOREConfig(evaluate.info.Config):
-
-    name: str = "default"
-
-    batch_size: int = 32
-    max_length: int = 128
-    device: Optional[str] = None
-
-
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class FRUGALSCORE(evaluate.Metric):
-
-    CONFIG_CLASS = FRUGALSCOREConfig
-    ALLOWED_CONFIG_NAMES = ["default"]
-
-    def _info(self, config):
+    def _info(self):
         return evaluate.MetricInfo(
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            config=config,
             features=datasets.Features(
                 {
                     "predictions": datasets.Value("string"),
@@ -96,20 +78,26 @@ class FRUGALSCORE(evaluate.Metric):
         self.model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
         self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 
-    def _compute(self, predictions, references):
+    def _compute(
+        self,
+        predictions,
+        references,
+        batch_size=32,
+        max_length=128,
+        device=None,
+    ):
         """Returns the scores"""
         assert len(predictions) == len(
             references
         ), "predictions and references should have the same number of sentences."
-        if self.config.device is not None:
-            assert self.config.device in ["gpu", "cpu"], "device should be either gpu or cpu."
-            device = self.config.device
+        if device is not None:
+            assert device in ["gpu", "cpu"], "device should be either gpu or cpu."
         else:
             device = "gpu" if torch.cuda.is_available() else "cpu"
         training_args = TrainingArguments(
             "trainer",
             fp16=(device == "gpu"),
-            per_device_eval_batch_size=self.config.batch_size,
+            per_device_eval_batch_size=batch_size,
             report_to="all",
             no_cuda=(device == "cpu"),
             log_level="warning",
@@ -119,7 +107,7 @@ class FRUGALSCORE(evaluate.Metric):
 
         def tokenize_function(data):
            return self.tokenizer(
-                data["sentence1"], data["sentence2"], max_length=self.config.max_length, truncation=True, padding=True
+                data["sentence1"], data["sentence2"], max_length=max_length, truncation=True, padding=True
            )
 
        tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
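
In effect, this commit drops the FRUGALSCOREConfig dataclass and exposes batch_size, max_length, and device directly as keyword arguments of _compute. A minimal usage sketch of the new call pattern, assuming the standard evaluate.load entry point, the default checkpoint, and that results are returned under a "scores" key (the sentences below are illustrative only):

import evaluate

# load the FrugalScore metric from the Hub (downloads the default checkpoint on first use)
frugalscore = evaluate.load("frugalscore")

# batch_size, max_length, and device are now plain keyword arguments of compute()
# rather than fields on a FRUGALSCOREConfig object
results = frugalscore.compute(
    predictions=["hello there", "general kenobi"],
    references=["hello there", "master kenobi"],
    batch_size=16,    # default 32
    max_length=128,   # default 128
    device="cpu",     # "gpu" or "cpu"; auto-detected when left as None
)
print(results["scores"])  # assumed output key: one score per sentence pair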
requirements.txt CHANGED
@@ -1,3 +1,3 @@
-git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
+git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
 torch
 transformers