|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from fairseq.dataclass import FairseqDataclass |
|
|
from fairseq.scoring import BaseScorer, register_scorer |
|
|
|
|
|
|
|
|
@dataclass
class BertScoreScorerConfig(FairseqDataclass):
    """Options consumed by the ``bert_score`` scorer."""

    # Language code passed as ``lang=`` to ``bert_score.score``; defaults to "en".
    bert_score_lang: str = field(
        default="en",
        metadata={"help": "BERTScore language"},
    )
|
|
|
|
|
|
|
|
@register_scorer("bert_score", dataclass=BertScoreScorerConfig)
class BertScoreScorer(BaseScorer):
    """Scorer registered as ``bert_score`` that rates predictions with BERTScore.

    Reference/prediction pairs are accumulated with :meth:`add_string` and
    scored in a single batch when :meth:`score` is called.

    NOTE(review): ``self.ref`` and ``self.pred`` are not created in this
    class; they are presumably initialised as empty lists by
    ``BaseScorer.__init__`` -- confirm against the base class.
    """

    def __init__(self, cfg):
        """Store the config and lazily import the optional ``bert_score`` package.

        Args:
            cfg: scorer configuration; ``cfg.bert_score_lang`` is read by
                :meth:`score`.

        Raises:
            ImportError: if the ``bert_score`` package is not installed.
        """
        super().__init__(cfg)
        # Imported here rather than at module level so the dependency is only
        # required when this scorer is actually instantiated.
        try:
            import bert_score as _bert_score
        except ImportError as err:
            # Chain explicitly so the original failure stays attached as the
            # cause of the user-facing installation hint.
            raise ImportError(
                "Please install BERTScore: pip install bert-score"
            ) from err

        self.cfg = cfg
        self._bert_score = _bert_score
        # Per-pair scores from the most recent ``score()`` call (None until then).
        self.scores = None

    def add_string(self, ref, pred):
        """Queue one reference/prediction pair for the next scoring run.

        Args:
            ref: reference text.
            pred: predicted (hypothesis) text.
        """
        self.ref.append(ref)
        self.pred.append(pred)

    def score(self, order=4):
        """Run BERTScore over all accumulated pairs and return the mean score.

        The third value returned by ``bert_score.score`` (the F1 tensor, per
        the bert-score documentation) is converted to a NumPy array and kept
        in ``self.scores``.

        Args:
            order: unused; kept so the signature stays compatible with callers
                that pass it.

        Returns:
            The mean of the per-pair scores as computed by ``np.mean``.
        """
        _, _, f1 = self._bert_score.score(
            self.pred, self.ref, lang=self.cfg.bert_score_lang
        )
        # Assign only after conversion so ``self.scores`` never holds a tensor.
        self.scores = f1.numpy()
        return np.mean(self.scores)

    def result_string(self, order=4):
        """Return the mean score formatted as ``"BERTScore: <value>"``.

        Calls :meth:`score`, so the full BERTScore computation is re-run on
        every invocation.

        Args:
            order: unused; kept for signature compatibility.
        """
        return f"BERTScore: {self.score():.4f}"
|
|
|