sonsus commited on
Commit
f25ae2f
1 Parent(s): 38283e3

Update harim_plus.py

Browse files
Files changed (1) hide show
  1. harim_plus.py +9 -8
harim_plus.py CHANGED
@@ -171,7 +171,7 @@ class Harimplus_Scorer:
171
  bsz:int=32,
172
  use_aggregator:bool=False,
173
  return_details:bool=False,
174
- tokenwise_score:bool=False,
175
  ):
176
  '''
177
  returns harim+ score (List[float]) for predictions (summaries) and references (articles)
@@ -238,15 +238,15 @@ class Harimplus_Scorer:
238
  harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
239
  harim = harim_tok.sum(-1) / sent_lengths
240
 
241
- harim_plus_normalized = ll + self._lambda * harim # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
242
 
243
  scores['harim+'].extend(harim_plus_normalized.tolist())
244
  scores['harim'].extend(harim.tolist())
245
  scores['log_ppl'].extend(ll.tolist())
246
 
247
- if tokenwise_score:
248
- scores['tok_harim+'].append(harim_tok*self._lambda + ll_tok)
249
- scores['tok_predictions'].append( [self._tokenizer.convert_ids_to_token(idxs) for idxs in src_in.labels] )
250
 
251
  if use_aggregator: # after
252
  for k, v in scores.items():
@@ -314,13 +314,14 @@ class Harimplus(evaluate.Metric):
314
  references=None,
315
  use_aggregator=False,
316
  bsz=32,
317
- tokenwise_score=False,
318
- return_details=False):
 
319
  summaries = predictions
320
  articles = references
321
  scores = self.scorer.compute(predictions=summaries,
322
  references=articles,
323
  use_aggregator=use_aggregator,
324
- bsz=bsz, tokenwise_score=tokenwise_score,
325
  return_details=return_details)
326
  return scores
 
171
  bsz:int=32,
172
  use_aggregator:bool=False,
173
  return_details:bool=False,
174
+ # tokenwise_score:bool=False,
175
  ):
176
  '''
177
  returns harim+ score (List[float]) for predictions (summaries) and references (articles)
 
238
  harim_tok = self.harim(s2s_logits, lm_logits, tgt_in.input_ids, tgt_mask)
239
  harim = harim_tok.sum(-1) / sent_lengths
240
 
241
+ harim_plus_normalized = (ll + self._lambda * harim)/sent_lengths # loglikelihood + lambda * negative_harim (negative harim=-1* risk)
242
 
243
  scores['harim+'].extend(harim_plus_normalized.tolist())
244
  scores['harim'].extend(harim.tolist())
245
  scores['log_ppl'].extend(ll.tolist())
246
 
247
+ # if tokenwise_score:
248
+ # scores['tok_harim+'].append(harim_tok*self._lambda + ll_tok)
249
+ # scores['tok_predictions'].append( [self._tokenizer.convert_ids_to_token(idxs) for idxs in src_in.labels] )
250
 
251
  if use_aggregator: # after
252
  for k, v in scores.items():
 
314
  references=None,
315
  use_aggregator=False,
316
  bsz=32,
317
+ return_details=False):
318
+ # tokenwise_score=False,
319
+
320
  summaries = predictions
321
  articles = references
322
  scores = self.scorer.compute(predictions=summaries,
323
  references=articles,
324
  use_aggregator=use_aggregator,
325
+ bsz=bsz, #tokenwise_score=tokenwise_score,
326
  return_details=return_details)
327
  return scores