""" |
|
Translate raw text with a trained model. Batches data on-the-fly. |
|
""" |

import ast
import fileinput
import logging
import math
import os
import sys
import time
from argparse import Namespace
from collections import namedtuple

import numpy as np
import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output


logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.interactive")


Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")


def buffered_read(input, buffer_size):
    """Read lines from `input`, yielding lists of at most `buffer_size` lines."""
    buffer = []
    with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
        for src_str in h:
            buffer.append(src_str.strip())
            if len(buffer) >= buffer_size:
                yield buffer
                buffer = []

    if len(buffer) > 0:
        yield buffer


def make_batches(lines, cfg, task, max_positions, encode_fn):
    def encode_fn_target(x):
        return encode_fn(x)

    if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from the input lines,
        # and store them in batch_constraints
        batch_constraints = [list() for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")

        # Convert each List[str] to List[Tensor]
        for i, constraint_list in enumerate(batch_constraints):
            batch_constraints[i] = [
                task.target_dictionary.encode_line(
                    encode_fn_target(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]

    if cfg.generation.constraints:
        constraints_tensor = pack_constraints(batch_constraints)
    else:
        constraints_tensor = None

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        ids = batch["id"]
        src_tokens = batch["net_input"]["src_tokens"]
        src_lengths = batch["net_input"]["src_lengths"]
        constraints = batch.get("constraints", None)

        yield Batch(
            ids=ids,
            src_tokens=src_tokens,
            src_lengths=src_lengths,
            constraints=constraints,
        )


def main(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(cfg.common)

    if cfg.interactive.buffer_size < 1:
        cfg.interactive.buffer_size = 1
    if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
        cfg.dataset.batch_size = 1

    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        not cfg.dataset.batch_size
        or cfg.dataset.batch_size <= cfg.interactive.buffer_size
    ), "--batch-size cannot be larger than --buffer-size"

    logger.info(cfg)

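    # Fix seed for stochastic decoding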
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

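    # Setup task, e.g., translation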
    task = tasks.setup_task(cfg.task)

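    # Load ensemble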
    overrides = ast.literal_eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        task=task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

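    # Set dictionaries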
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

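    # Optimize ensemble for generation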
    for model in models:
        if model is None:
            continue
        if cfg.common.fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

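    # Initialize generator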
    generator = task.build_generator(models, cfg.generation)

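    # Handle tokenization and BPE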
    tokenizer = task.build_tokenizer(cfg.tokenizer)
    bpe = task.build_bpe(cfg.bpe)

    def encode_fn(x):
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

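    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)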
    align_dict = utils.load_align_dict(cfg.generation.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    if cfg.generation.constraints:
        logger.warning(
            "NOTE: Constrained decoding currently assumes a shared subword vocabulary."
        )

    if cfg.interactive.buffer_size > 1:
        logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size)
    logger.info("NOTE: hypothesis and token scores are output in base 2")
    logger.info("Type the input sentence and press return:")
    start_id = 0
    for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size):
        results = []
        for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(
                generator, models, sample, constraints=constraints
            )
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time
            list_constraints = [[] for _ in range(bsz)]
            if cfg.generation.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                constraints = list_constraints[i]
                results.append(
                    (
                        start_id + id,
                        src_tokens_i,
                        hypos,
                        {
                            "constraints": constraints,
                            "time": translate_time / len(translations),
                        },
                    )
                )

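        # sort output to match input order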
        for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
                print("S-{}\t{}".format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print(
                        "C-{}\t{}".format(
                            id_, tgt_dict.string(constraint, cfg.common_eval.post_process)
                        )
                    )

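            # Process top predictions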
            for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=cfg.common_eval.post_process,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo["score"] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print("H-{}\t{}\t{}".format(id_, score, hypo_str))
                # detokenized hypothesis
                print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str))
                print(
                    "P-{}\t{}".format(
                        id_,
                        " ".join(
                            map(
                                lambda x: "{:.4f}".format(x),
                                # convert from base e to base 2
                                hypo["positional_scores"].div_(math.log(2)).tolist(),
                            )
                        ),
                    )
                )
                if cfg.generation.print_alignment:
                    alignment_str = " ".join(
                        ["{}-{}".format(src, tgt) for src, tgt in alignment]
                    )
                    print("A-{}\t{}".format(id_, alignment_str))

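        # update running id_ counter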
        start_id += len(inputs)

    logger.info(
        "Total time: {:.3f} seconds; translation time: {:.3f}".format(
            time.time() - start_time, total_translate_time
        )
    )


def cli_main():
    parser = options.get_interactive_generation_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)


if __name__ == "__main__":
    cli_main()