""" |
|
Evaluate the perplexity of a trained language model. |
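
Example (a hypothetical invocation; the data directory and checkpoint path are
placeholders):

    fairseq-eval-lm data-bin/wikitext-103 \
        --path checkpoints/checkpoint_best.pt \
        --gen-subset test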
|
""" |
|
|
|
import logging |
|
import math |
|
import os |
|
import sys |
|
from argparse import Namespace |
|
from typing import Iterable, List, Optional |
|
|
|
import torch |
|
import fairseq |
|
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils |
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
|
from fairseq.logging import progress_bar |
|
from fairseq.logging.meters import StopwatchMeter |
|
from fairseq.sequence_scorer import SequenceScorer |
|
from omegaconf import DictConfig |
|
|
|
|
|
logging.basicConfig( |
|
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", |
|
datefmt="%Y-%m-%d %H:%M:%S", |
|
level=os.environ.get("LOGLEVEL", "INFO").upper(), |
|
stream=sys.stdout, |
|
) |
|
logger = logging.getLogger("fairseq_cli.eval_lm") |
|
|
|
|
|
def eval_lm( |
|
models: List[fairseq.models.FairseqModel], |
|
source_dictionary: fairseq.data.Dictionary, |
|
batch_iterator: Iterable, |
|
post_process: Optional[str] = None, |
|
output_word_probs: bool = False, |
|
output_word_stats: bool = False, |
|
target_dictionary: Optional[fairseq.data.Dictionary] = None, |
|
softmax_batch: int = 0, |
|
remove_bos_token: bool = False, |
|
device: Optional[torch.device] = None, |
|
): |
|
""" |
|
Args: |
|
models (List[~fairseq.models.FairseqModel]): list of models to |
|
evaluate. Models are essentially `nn.Module` instances, but |
|
must be compatible with fairseq's `SequenceScorer`. |
|
source_dictionary (~fairseq.data.Dictionary): dictionary for |
|
            applying any relevant post-processing or outputting word
|
probs/stats. |
|
        batch_iterator (Iterable): yields batches of data
|
post_process (Optional[str]): post-process text by removing BPE, |
|
letter segmentation, etc. Valid options can be found in |
|
fairseq.data.utils.post_process, although not all options |
|
are implemented here. |
|
output_word_probs (Optional[bool]): output words and their |
|
predicted log probabilities |
|
output_word_stats (Optional[bool]): output word statistics such |
|
as word count and average probability |
|
target_dictionary (Optional[~fairseq.data.Dictionary]): output |
|
dictionary (defaults to *source_dictionary*) |
|
        softmax_batch (Optional[int]): if BxT is more than this, will
|
batch the softmax over vocab to this amount of tokens, in |
|
order to fit into GPU memory |
|
remove_bos_token (Optional[bool]): if True, confirm that the |
|
first token is the beginning-of-sentence symbol (according |
|
to the relevant dictionary) and remove it from the output |
|
device (Optional[torch.device]): device to use for evaluation |
|
(defaults to device of first model parameter) |
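
    Example (a minimal sketch, assuming a trained LM checkpoint and a prepared
    language-modeling task; the checkpoint path is a placeholder):

        >>> models, _, task = checkpoint_utils.load_model_ensemble_and_task(
        ...     ["/path/to/checkpoint.pt"]
        ... )
        >>> task.load_dataset("test")
        >>> itr = task.eval_lm_dataloader(dataset=task.dataset("test"), max_tokens=36000)
        >>> results = eval_lm(models, task.source_dictionary, itr)
        >>> results["loss"], results["perplexity"]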
|
""" |
|
if target_dictionary is None: |
|
target_dictionary = source_dictionary |
|
if device is None: |
|
device = next(models[0].parameters()).device |
|
|
|
gen_timer = StopwatchMeter() |
|
scorer = SequenceScorer(target_dictionary, softmax_batch) |
|
|
|
score_sum = 0.0 |
|
count = 0 |
|
|
|
if post_process is not None: |
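        # Pre-compute the set of dictionary indices that end with the BPE
        # continuation marker, so subword scores can be merged into whole-word
        # scores below.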
|
if post_process in {"subword_nmt", "@@ "}: |
|
bpe_cont = post_process.rstrip() |
|
bpe_toks = { |
|
i |
|
for i in range(len(source_dictionary)) |
|
if source_dictionary[i].endswith(bpe_cont) |
|
} |
|
else: |
|
raise NotImplementedError( |
|
"--post-process={post_process} is not implemented" |
|
) |
|
bpe_len = len(bpe_cont) |
|
else: |
|
bpe_toks = None |
|
bpe_len = 0 |
|
|
|
word_stats = dict() |
|
|
|
for sample in batch_iterator: |
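        # Skip dummy batches (e.g. from sharded iterators) that carry no model
        # input.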
|
if "net_input" not in sample: |
|
continue |
|
|
|
sample = utils.move_to_cuda(sample, device=device) |
|
|
|
gen_timer.start() |
|
hypos = scorer.generate(models, sample) |
|
gen_timer.stop(sample["ntokens"]) |
|
|
|
for i, hypos_i in enumerate(hypos): |
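            # The scorer returns one "hypothesis" per sample: the target tokens
            # paired with their per-position log probabilities.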
|
hypo = hypos_i[0] |
|
sample_id = sample["id"][i] |
|
|
|
tokens = hypo["tokens"] |
|
tgt_len = tokens.numel() |
|
pos_scores = hypo["positional_scores"].float() |
|
|
|
if remove_bos_token: |
|
assert hypo["tokens"][0].item() == target_dictionary.bos() |
|
tokens = tokens[1:] |
|
pos_scores = pos_scores[1:] |
|
|
|
skipped_toks = 0 |
|
if bpe_toks is not None: |
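                # Merge the score of each BPE continuation piece into the
                # following position, so every whole word ends up carrying the
                # sum of its subword log probabilities.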
|
                for i in range(len(tokens) - 1):
|
if tokens[i].item() in bpe_toks: |
|
skipped_toks += 1 |
|
pos_scores[i + 1] += pos_scores[i] |
|
pos_scores[i] = 0 |
|
|
|
inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf")) |
|
if inf_scores.any(): |
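                # Exclude tokens whose scores overflowed to +/-inf so they do
                # not corrupt the running totals.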
|
logger.info( |
|
"skipping tokens with inf scores:", |
|
target_dictionary.string(tokens[inf_scores.nonzero()]), |
|
) |
|
pos_scores = pos_scores[(~inf_scores).nonzero()] |
|
score_sum += pos_scores.sum().cpu() |
|
count += pos_scores.numel() - skipped_toks |
|
|
|
if output_word_probs or output_word_stats: |
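                # Reconstruct whole words from their subword pieces and record
                # per-word log probabilities and statistics.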
|
w = "" |
|
word_prob = [] |
|
is_bpe = False |
|
for i in range(len(tokens)): |
|
w_ind = tokens[i].item() |
|
w += source_dictionary[w_ind] |
|
if bpe_toks is not None and w_ind in bpe_toks: |
|
w = w[:-bpe_len] |
|
is_bpe = True |
|
else: |
|
word_prob.append((w, pos_scores[i].item())) |
|
|
|
next_prob = None |
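                        # Look ahead for the next position with a non-zero
                        # score, i.e. the log probability of the next whole
                        # word (positions zeroed by the BPE merge are skipped).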
|
ind = i + 1 |
|
while ind < len(tokens): |
|
if pos_scores[ind].item() != 0: |
|
next_prob = pos_scores[ind] |
|
break |
|
ind += 1 |
|
|
|
word_stats.setdefault(w, WordStat(w, is_bpe)).add( |
|
pos_scores[i].item(), next_prob |
|
) |
|
is_bpe = False |
|
w = "" |
|
if output_word_probs: |
|
logger.info( |
|
str(int(sample_id)) |
|
+ " " |
|
+ ( |
|
"\t".join( |
|
"{} [{:2f}]".format(x[0], x[1]) for x in word_prob |
|
) |
|
) |
|
) |
|
|
|
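    # Positional scores are natural-log probabilities; dividing the average
    # negative log-likelihood by log(2) converts it to base 2, so perplexity
    # below is 2 ** loss.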
avg_nll_loss = ( |
|
-score_sum / count / math.log(2) if count > 0 else 0 |
|
) |
|
logger.info( |
|
"Evaluated {:,} tokens in {:.1f}s ({:.2f} tokens/s)".format( |
|
gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 |
|
) |
|
) |
|
|
|
if output_word_stats: |
|
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): |
|
logger.info(ws) |
|
|
|
return { |
|
"loss": avg_nll_loss, |
|
"perplexity": 2 ** avg_nll_loss, |
|
} |
|
|
|
|
|
class WordStat(object): |
|
def __init__(self, word, is_bpe): |
|
self.word = word |
|
self.is_bpe = is_bpe |
|
self.log_prob = 0 |
|
self.next_word_prob = 0 |
|
self.count = 0 |
|
self.missing_next_words = 0 |
|
|
|
def add(self, log_prob, next_word_prob): |
|
"""increments counters for the sum of log probs of current word and next |
|
word (given context ending at current word). Since the next word might be at the end of the example, |
|
or it might be not counted because it is not an ending subword unit, |
|
also keeps track of how many of those we have seen""" |
|
if next_word_prob is not None: |
|
self.next_word_prob += next_word_prob |
|
else: |
|
self.missing_next_words += 1 |
|
self.log_prob += log_prob |
|
self.count += 1 |
|
|
|
def __str__(self): |
|
return "{}\t{}\t{}\t{}\t{}\t{}".format( |
|
self.word, |
|
self.count, |
|
self.log_prob, |
|
self.is_bpe, |
|
self.next_word_prob, |
|
self.count - self.missing_next_words, |
|
) |
|
|
|
|
|
def main(cfg: DictConfig, **unused_kwargs): |
|
if isinstance(cfg, Namespace): |
|
cfg = convert_namespace_to_omegaconf(cfg) |
|
|
|
utils.import_user_module(cfg.common) |
|
|
|
logger.info(cfg) |
|
|
|
if cfg.eval_lm.context_window > 0: |
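        # Reduce tokens_per_sample by the context window size so that each
        # evaluated segment plus its extra context still fits the model's
        # maximum sequence length.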
|
|
|
cfg.task.tokens_per_sample -= cfg.eval_lm.context_window |
|
|
|
|
|
task = tasks.setup_task(cfg.task) |
|
|
|
|
|
logger.info("loading model(s) from {}".format(cfg.common_eval.path)) |
|
models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( |
|
[cfg.common_eval.path], |
|
arg_overrides=eval(cfg.common_eval.model_overrides), |
|
suffix=cfg.checkpoint.checkpoint_suffix, |
|
strict=(cfg.checkpoint.checkpoint_shard_count == 1), |
|
num_shards=cfg.checkpoint.checkpoint_shard_count, |
|
task=task, |
|
) |
|
|
|
use_fp16 = cfg.common.fp16 |
|
use_cuda = torch.cuda.is_available() and not cfg.common.cpu |
|
if use_cuda: |
|
torch.cuda.set_device(cfg.distributed_training.device_id) |
|
|
|
|
|
|
|
for model in models: |
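        # Cast to fp16 and/or move to GPU as configured, then apply any
        # model-specific inference-time preparation.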
|
if use_fp16: |
|
model.half() |
|
if use_cuda and not cfg.distributed_training.pipeline_model_parallel: |
|
model.cuda() |
|
model.prepare_for_inference_(cfg) |
|
|
|
assert len(models) > 0 |
|
|
|
logger.info( |
|
"num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters())) |
|
) |
|
|
|
|
|
task.load_dataset(cfg.dataset.gen_subset) |
|
dataset = task.dataset(cfg.dataset.gen_subset) |
|
logger.info( |
|
"{} {} {:,} examples".format( |
|
cfg.task.data, cfg.dataset.gen_subset, len(dataset) |
|
) |
|
) |
|
|
|
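    # Build the LM evaluation dataloader; sharding follows whichever of the
    # dataset settings or the distributed world size / rank requests more
    # shards.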
itr = task.eval_lm_dataloader( |
|
dataset=dataset, |
|
max_tokens=cfg.dataset.max_tokens or 36000, |
|
batch_size=cfg.dataset.batch_size, |
|
max_positions=utils.resolve_max_positions( |
|
*[model.max_positions() for model in models] |
|
), |
|
num_shards=max( |
|
cfg.dataset.num_shards, |
|
cfg.distributed_training.distributed_world_size, |
|
), |
|
shard_id=max( |
|
cfg.dataset.shard_id, |
|
cfg.distributed_training.distributed_rank, |
|
), |
|
num_workers=cfg.dataset.num_workers, |
|
data_buffer_size=cfg.dataset.data_buffer_size, |
|
context_window=cfg.eval_lm.context_window, |
|
) |
|
|
|
itr = progress_bar.progress_bar( |
|
itr, |
|
log_format=cfg.common.log_format, |
|
log_interval=cfg.common.log_interval, |
|
default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), |
|
) |
|
|
|
results = eval_lm( |
|
models=models, |
|
source_dictionary=task.source_dictionary, |
|
batch_iterator=itr, |
|
post_process=cfg.common_eval.post_process, |
|
output_word_probs=cfg.eval_lm.output_word_probs, |
|
output_word_stats=cfg.eval_lm.output_word_stats, |
|
target_dictionary=task.target_dictionary, |
|
softmax_batch=cfg.eval_lm.softmax_batch, |
|
remove_bos_token=getattr(cfg.task, "add_bos_token", False), |
|
) |
|
|
|
logger.info( |
|
"Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( |
|
results["loss"], results["perplexity"] |
|
) |
|
) |
|
|
|
return results |
|
|
|
|
|
def cli_main(): |
|
parser = options.get_eval_lm_parser() |
|
args = options.parse_args_and_arch(parser) |
|
|
|
distributed_utils.call_main(convert_namespace_to_omegaconf(args), main) |
|
|
|
|
|
if __name__ == "__main__": |
|
cli_main() |
|
|