Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes | |
import math | |
import sys | |
from dataclasses import dataclass, field | |
import torch | |
from fairseq.dataclass import FairseqDataclass | |
from fairseq.scoring import BaseScorer, register_scorer | |
from fairseq.scoring.tokenizer import EvaluationTokenizer | |
class BleuStat(ctypes.Structure):
    """Running BLEU statistics shared with the native ``libbleu`` library.

    Instances are handed to C functions via ``ctypes.byref``, so the order of
    ``_fields_`` is the C struct layout and must not change:
    ``reflen``, ``predlen``, then interleaved ``matchN`` / ``countN`` pairs
    for n-gram orders 1 through 4, every member a ``size_t``.
    """

    _fields_ = [("reflen", ctypes.c_size_t), ("predlen", ctypes.c_size_t)] + [
        (f"{prefix}{order}", ctypes.c_size_t)
        for order in range(1, 5)
        for prefix in ("match", "count")
    ]
@dataclass
class SacrebleuConfig(FairseqDataclass):
    """Options for :class:`SacrebleuScorer`.

    ``@dataclass`` is required here: without it the ``field(...)``
    declarations below are never turned into dataclass fields, and reading
    e.g. ``cfg.sacrebleu_tokenizer`` would return the ``dataclasses.Field``
    object itself instead of its default value.
    """

    # Passed to EvaluationTokenizer as ``tokenizer_type``.
    sacrebleu_tokenizer: EvaluationTokenizer.ALL_TOKENIZER_TYPES = field(
        default="13a", metadata={"help": "tokenizer"}
    )
    # Passed to EvaluationTokenizer as ``lowercase``.
    sacrebleu_lowercase: bool = field(
        default=False, metadata={"help": "apply lowercasing"}
    )
    # Passed to EvaluationTokenizer as ``character_tokenization``.
    sacrebleu_char_level: bool = field(
        default=False, metadata={"help": "evaluate at character level"}
    )
# NOTE(review): ``register_scorer`` is imported in this file but not applied to
# this class here; confirm whether a registration decorator was intended.
class SacrebleuScorer(BaseScorer):
    """Corpus-level BLEU computed with the ``sacrebleu`` package.

    Each sentence is tokenized (and optionally lowercased or split into
    characters) by :class:`EvaluationTokenizer` when it is added, so
    ``sacrebleu`` itself is called with ``tokenize="none"``.
    """

    def __init__(self, cfg):
        """Build the scorer from a :class:`SacrebleuConfig`-like ``cfg``."""
        super(SacrebleuScorer, self).__init__(cfg)
        # Imported lazily so sacrebleu is only needed when this scorer is built.
        import sacrebleu

        self.sacrebleu = sacrebleu
        self.tokenizer = EvaluationTokenizer(
            tokenizer_type=cfg.sacrebleu_tokenizer,
            lowercase=cfg.sacrebleu_lowercase,
            character_tokenization=cfg.sacrebleu_char_level,
        )

    def add_string(self, ref, pred):
        """Tokenize and store one reference / hypothesis string pair.

        ``self.ref`` and ``self.pred`` are presumably list attributes
        initialized by ``BaseScorer.__init__`` (not visible here).
        """
        self.ref.append(self.tokenizer.tokenize(ref))
        self.pred.append(self.tokenizer.tokenize(pred))

    def _score(self, order=4):
        """Run ``sacrebleu.corpus_bleu`` over everything added so far.

        Returns the sacrebleu result object (not a string), so callers can
        read ``.score`` or call ``.format()`` on it.

        Raises:
            NotImplementedError: if ``order`` is not 4.
        """
        if order != 4:
            raise NotImplementedError
        # tokenization and lowercasing are performed by self.tokenizer instead.
        return self.sacrebleu.corpus_bleu(self.pred, [self.ref], tokenize="none")

    def score(self, order=4):
        """Return the numeric ``score`` attribute of the sacrebleu result.

        Previously this read ``.score`` off the *formatted string* returned by
        :meth:`result_string`, which always raised ``AttributeError``.
        """
        return self._score(order).score

    def result_string(self, order=4):
        """Return sacrebleu's formatted summary string for the corpus BLEU."""
        return self._score(order).format()
@dataclass
class BleuConfig(FairseqDataclass):
    """Special-token indices consumed by :class:`Scorer`.

    ``@dataclass`` is required so the ``field(...)`` declarations become real
    dataclass fields; without it ``cfg.pad`` etc. would evaluate to the
    ``dataclasses.Field`` objects rather than their integer defaults.
    """

    pad: int = field(default=1, metadata={"help": "padding index"})
    eos: int = field(default=2, metadata={"help": "eos index"})
    # Reference tokens equal to this id are masked so they never count as matches.
    unk: int = field(default=3, metadata={"help": "unk index"})
class Scorer(object):
    """Corpus-level BLEU over integer token-id tensors, backed by ``libbleu``.

    N-gram match/count statistics are kept in a :class:`BleuStat` struct that
    C functions from the compiled ``fairseq.libbleu`` extension update in
    place; :meth:`precision`, :meth:`brevity` and :meth:`score` then derive
    BLEU from those counters in Python.
    """

    # NOTE(review): ``register_scorer`` is imported in this file but not applied
    # to this class here; confirm whether a registration decorator was intended.

    def __init__(self, cfg):
        """Load ``libbleu`` and zero the statistics.

        Args:
            cfg: object exposing ``pad``, ``eos`` and ``unk`` token indices
                (e.g. a :class:`BleuConfig`).

        Raises:
            ImportError: if the compiled ``fairseq.libbleu`` module is missing.
        """
        self.stat = BleuStat()
        self.pad = cfg.pad
        self.eos = cfg.eos
        self.unk = cfg.unk

        try:
            from fairseq import libbleu
        except ImportError:
            sys.stderr.write(
                "ERROR: missing libbleu.so. run `pip install --editable .`\n"
            )
            raise

        self.C = ctypes.cdll.LoadLibrary(libbleu.__file__)

        self.reset()

    def reset(self, one_init=False):
        """Reinitialize the statistics through ``libbleu``.

        Args:
            one_init: call ``bleu_one_init`` instead of ``bleu_zero_init``
                (presumably initializes counters to one rather than zero —
                confirm against libbleu).
        """
        if one_init:
            self.C.bleu_one_init(ctypes.byref(self.stat))
        else:
            self.C.bleu_zero_init(ctypes.byref(self.stat))

    def add(self, ref, pred):
        """Accumulate n-gram statistics for one reference/hypothesis pair.

        Both tensors are flattened to 1-D before being handed to ``libbleu``
        together with the pad and eos indices.

        Args:
            ref: reference token ids as a ``torch.IntTensor``.
            pred: hypothesis token ids as a ``torch.IntTensor``.

        Raises:
            TypeError: if either argument is not a ``torch.IntTensor``.
            ValueError: if ``ref`` contains negative token ids.
        """
        if not isinstance(ref, torch.IntTensor):
            raise TypeError("ref must be a torch.IntTensor (got {})".format(type(ref)))
        if not isinstance(pred, torch.IntTensor):
            raise TypeError("pred must be a torch.IntTensor (got {})".format(type(pred)))

        # don't match unknown words: <unk> in the reference is replaced by a
        # negative sentinel so it cannot match a (presumably non-negative)
        # hypothesis token. Real reference ids must therefore be non-negative.
        rref = ref.clone()
        if rref.lt(0).any():
            raise ValueError("ref must not contain negative token ids")
        rref[rref.eq(self.unk)] = -999

        rref = rref.contiguous().view(-1)
        pred = pred.contiguous().view(-1)

        self.C.bleu_add(
            ctypes.byref(self.stat),
            ctypes.c_size_t(rref.size(0)),
            ctypes.c_void_p(rref.data_ptr()),
            ctypes.c_size_t(pred.size(0)),
            ctypes.c_void_p(pred.data_ptr()),
            ctypes.c_int(self.pad),
            ctypes.c_int(self.eos),
        )

    def score(self, order=4):
        """Return corpus BLEU on a 0-100 scale using n-gram orders ``1..order``.

        The geometric mean of the first ``order`` precisions is multiplied by
        the brevity penalty; any zero precision makes the score 0.
        """
        psum = sum(
            math.log(p) if p > 0 else float("-Inf") for p in self.precision()[:order]
        )
        return self.brevity() * math.exp(psum / order) * 100

    def precision(self):
        """Return ``[p1, p2, p3, p4]`` with ``pN = matchN / countN``.

        An order whose ``countN`` is 0 contributes a precision of 0.
        """

        def ratio(a, b):
            return a / b if b > 0 else 0

        return [
            ratio(self.stat.match1, self.stat.count1),
            ratio(self.stat.match2, self.stat.count2),
            ratio(self.stat.match3, self.stat.count3),
            ratio(self.stat.match4, self.stat.count4),
        ]

    def brevity(self):
        """Return the brevity penalty ``min(1, exp(1 - reflen / predlen))``.

        With no hypothesis tokens (``predlen == 0``) the penalty is 0 if there
        is any reference text and 1 otherwise. Previously this case raised
        ``ZeroDivisionError`` (and so did :meth:`score` / :meth:`result_string`).
        """
        if self.stat.predlen == 0:
            return 0.0 if self.stat.reflen > 0 else 1
        r = self.stat.reflen / self.stat.predlen
        return min(1, math.exp(1 - r))

    def result_string(self, order=4):
        """Format BLEU, per-order precisions and length statistics.

        Output shape: ``"BLEU<order> = <score>, <p1>/.../<pN> (BP=<bp>,
        ratio=<syslen/reflen>, syslen=<n>, reflen=<m>)"``. The length ratio
        is reported as 0 when there are no reference tokens.

        Raises:
            ValueError: if ``order`` is greater than 4.
        """
        if order > 4:
            raise ValueError("BLEU scores for order > 4 aren't supported")
        fmt = "BLEU{} = {:2.2f}, {:2.1f}"
        for _ in range(1, order):
            fmt += "/{:2.1f}"
        fmt += " (BP={:.3f}, ratio={:.3f}, syslen={}, reflen={})"
        bleup = [p * 100 for p in self.precision()[:order]]
        # Guard against an empty reference corpus instead of dividing by zero.
        len_ratio = (
            self.stat.predlen / self.stat.reflen if self.stat.reflen > 0 else 0.0
        )
        return fmt.format(
            order,
            self.score(order=order),
            *bleup,
            self.brevity(),
            len_ratio,
            self.stat.predlen,
            self.stat.reflen,
        )