xu1998hz commited on
Commit
64fd7e3
1 Parent(s): cc92fc5

fix small bugs in name for kept

Browse files
Files changed (1) hide show
  1. sescore.py +10 -1
sescore.py CHANGED
@@ -129,7 +129,16 @@ class SEScore(evaluate.Metric):
129
  else:
130
  destination = snapshot_download(repo_id=self.config_name, revision="main")
131
  suffix = self.config_name.split('/')[-1]
132
- self.scorer = load_from_checkpoint(f'{destination}/checkpoint/{suffix}.ckpt')
 
 
 
 
 
 
 
 
 
133
 
134
  def _compute(self, predictions, references, gpus=None, progress_bar=False):
135
  if gpus is None:
 
129
  else:
130
  destination = snapshot_download(repo_id=self.config_name, revision="main")
131
  suffix = self.config_name.split('/')[-1]
132
+ if suffix == 'sescore_english_mt':
133
+ self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_english_mt.ckpt')
134
+ elif suffix == 'sescore_german_mt':
135
+ self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_german.ckpt')
136
+ elif suffix == 'sescore_english_webnlg17':
137
+ self.scorer = load_from_checkpoint(f'{destination}/checkpoint/webnlg.ckpt')
138
+ elif suffix == 'sescore_english_coco':
139
+ self.scorer = load_from_checkpoint(f'{destination}/checkpoint/caption.ckpt')
140
+ else:
141
+ self.scorer = load_from_checkpoint(f'{destination}/checkpoint/{suffix}.ckpt')
142
 
143
  def _compute(self, predictions, references, gpus=None, progress_bar=False):
144
  if gpus is None: