saridormi commited on
Commit
5d84797
1 Parent(s): cdf268e

Remove extra import

Browse files
src/evaluation/commit_message_generation/cmg_metrics.py CHANGED
@@ -3,25 +3,19 @@ from typing import Dict, List
3
  import evaluate # type: ignore[import]
4
 
5
  from ..base_task_metrics import BaseTaskMetrics
6
- from .b_norm import BNorm
7
 
8
 
9
  class CMGMetrics(BaseTaskMetrics):
10
  def __init__(self):
11
- self.bnorm = BNorm()
12
  self.bleu = evaluate.load("sacrebleu")
13
  self.chrf = evaluate.load("chrf")
14
  self.rouge = evaluate.load("rouge")
15
  self.bertscore = evaluate.load("bertscore")
16
  self.bertscore_normalized = evaluate.load("bertscore")
17
 
18
- def reset(self):
19
- self.bnorm.reset()
20
-
21
- def update(
22
  self, predictions: List[str], references: List[str], *args, **kwargs
23
  ) -> None:
24
- self.bnorm.update(predictions=predictions, references=references)
25
  self.bleu.add_batch(
26
  predictions=predictions, references=[[ref] for ref in references]
27
  )
@@ -41,7 +35,6 @@ class CMGMetrics(BaseTaskMetrics):
41
  lang="en", rescale_with_baseline=True
42
  )
43
  return {
44
- "bnorm": self.bnorm.compute(),
45
  "bleu": self.bleu.compute(tokenize="13a")["score"],
46
  "chrf": self.chrf.compute()["score"],
47
  "rouge1": rouge["rouge1"] * 100,
 
3
  import evaluate # type: ignore[import]
4
 
5
  from ..base_task_metrics import BaseTaskMetrics
 
6
 
7
 
8
  class CMGMetrics(BaseTaskMetrics):
9
  def __init__(self):
 
10
  self.bleu = evaluate.load("sacrebleu")
11
  self.chrf = evaluate.load("chrf")
12
  self.rouge = evaluate.load("rouge")
13
  self.bertscore = evaluate.load("bertscore")
14
  self.bertscore_normalized = evaluate.load("bertscore")
15
 
16
+ def add_batch(
 
 
 
17
  self, predictions: List[str], references: List[str], *args, **kwargs
18
  ) -> None:
 
19
  self.bleu.add_batch(
20
  predictions=predictions, references=[[ref] for ref in references]
21
  )
 
35
  lang="en", rescale_with_baseline=True
36
  )
37
  return {
 
38
  "bleu": self.bleu.compute(tokenize="13a")["score"],
39
  "chrf": self.chrf.compute()["score"],
40
  "rouge1": rouge["rouge1"] * 100,