Wendaxu committed on
Commit
97de291
1 Parent(s): 578bca9

fix the data type

Browse files
Files changed (3) hide show
  1. app.py +1 -2
  2. requirements.txt +1 -1
  3. sescore.py +5 -8
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import evaluate
2
  from evaluate.utils import launch_gradio_widget
3
 
4
-
5
  module = evaluate.load("xu1998hz/sescore")
6
- launch_gradio_widget(module)
1
  import evaluate
2
  from evaluate.utils import launch_gradio_widget
3
 
 
4
  module = evaluate.load("xu1998hz/sescore")
5
+ launch_gradio_widget(module)
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
  git+https://github.com/huggingface/evaluate@main
2
- gdown
3
  unbabel-comet
 
1
  git+https://github.com/huggingface/evaluate@main
 
2
  unbabel-comet
3
+ torch
sescore.py CHANGED
@@ -107,8 +107,8 @@ class SEScore(evaluate.Metric):
107
  inputs_description=_KWARGS_DESCRIPTION,
108
  # This defines the format of each prediction and reference
109
  features=datasets.Features({
110
- 'predictions': datasets.Value('int64'),
111
- 'references': datasets.Value('int64'),
112
  }),
113
  # Homepage of the module for documentation
114
  homepage="http://module.homepage",
@@ -121,23 +121,20 @@ class SEScore(evaluate.Metric):
121
  """download SEScore checkpoints to compute the scores"""
122
  # Download SEScore checkpoint
123
  from comet import load_from_checkpoint
124
- import gdown
125
  import os
126
  from huggingface_hub import snapshot_download
127
  # initialize roberta into str2encoder
128
  comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
129
- # url = "https://drive.google.com/uc?id=1QgMP_Y4QCbvDMTeVacYt0J76OYvwWK9V&export=download&confirm=true"
130
- # output = 'sescore_ckpt.gz'
131
- # gdown.download(url, output, quiet=False)
132
- # cmd = 'tar -xvf sescore_ckpt.gz'
133
- # os.system(cmd)
134
  destination = snapshot_download(repo_id="xu1998hz/sescore_english_mt", revision="main")
135
  self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_english_mt.ckpt')
136
 
137
  def _compute(self, predictions, references, gpus=None, progress_bar=False):
138
  if gpus is None:
139
  gpus = 1 if torch.cuda.is_available() else 0
 
140
  data = {"src": references, "mt": predictions}
 
141
  data = [dict(zip(data, t)) for t in zip(*data.values())]
 
142
  scores, mean_score = self.scorer.predict(data, gpus=gpus, progress_bar=progress_bar)
143
  return {"mean_score": mean_score, "scores": scores}
107
  inputs_description=_KWARGS_DESCRIPTION,
108
  # This defines the format of each prediction and reference
109
  features=datasets.Features({
110
+ 'predictions': datasets.Value("string", id="sequence"),
111
+ 'references': datasets.Value("string", id="sequence"),
112
  }),
113
  # Homepage of the module for documentation
114
  homepage="http://module.homepage",
121
  """download SEScore checkpoints to compute the scores"""
122
  # Download SEScore checkpoint
123
  from comet import load_from_checkpoint
 
124
  import os
125
  from huggingface_hub import snapshot_download
126
  # initialize roberta into str2encoder
127
  comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
 
 
 
 
 
128
  destination = snapshot_download(repo_id="xu1998hz/sescore_english_mt", revision="main")
129
  self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_english_mt.ckpt')
130
 
131
  def _compute(self, predictions, references, gpus=None, progress_bar=False):
132
  if gpus is None:
133
  gpus = 1 if torch.cuda.is_available() else 0
134
+
135
  data = {"src": references, "mt": predictions}
136
+ print(data)
137
  data = [dict(zip(data, t)) for t in zip(*data.values())]
138
+ print(data)
139
  scores, mean_score = self.scorer.predict(data, gpus=gpus, progress_bar=progress_bar)
140
  return {"mean_score": mean_score, "scores": scores}