Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 -u | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import argparse | |
import copy | |
import logging | |
import os | |
from typing import Any, Dict, Iterator, List | |
import torch | |
from fairseq import utils | |
from fairseq.data import encoders | |
from omegaconf import open_dict | |
from torch import nn | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs
):
    """Load a model ensemble together with its args and task.

    Args:
        model_name_or_path: path/name passed to ``file_utils.load_archive_file``.
            It may also be a key of ``archive_map``. If the mapped value is a
            dict, its ``"path"`` entry is used as the location. Its optional
            ``"checkpoint_file"`` entry replaces *checkpoint_file*. Any other
            entries become default arg overrides, unless already in *kwargs*.
        checkpoint_file: checkpoint file name(s) inside the model directory.
            Multiple names may be joined with ``os.pathsep`` to load an ensemble.
        data_name_or_path: data location. A value starting with ``"."`` is
            resolved relative to the model directory. Any other value is passed
            to ``file_utils.load_archive_file``. ``None`` leaves the ``data``
            override unset (any ``data`` given in *kwargs* is kept).
        archive_map: optional mapping from short names to paths or dicts.
        **kwargs: arg overrides forwarded to
            ``checkpoint_utils.load_model_ensemble_and_task``.

    Returns:
        dict with keys ``"args"``, ``"task"`` and ``"models"``.
    """
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for k, v in model_name_or_path.items():
                if k == "checkpoint_file":
                    checkpoint_file = v
                elif (
                    k != "path"
                    # only set kwargs that don't already have overrides
                    and k not in kwargs
                ):
                    kwargs[k] = v
            model_name_or_path = model_name_or_path["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive.
    # None is explicitly tolerated above, so it must not reach .startswith().
    if data_name_or_path is not None:
        if data_name_or_path.startswith("."):
            kwargs["data"] = os.path.abspath(
                os.path.join(model_path, data_name_or_path)
            )
        else:
            kwargs["data"] = file_utils.load_archive_file(data_name_or_path)

    # If a known BPE/vocab file ships inside the model directory, point the
    # corresponding override at it.
    for filename, arg in {
        "code": "bpe_codes",
        "bpecodes": "bpe_codes",
        "sentencepiece.bpe.model": "sentencepiece_model",
        "merges.txt": "bpe_merges",
        "vocab.json": "bpe_vocab",
    }.items():
        path = os.path.join(model_path, filename)
        if os.path.exists(path):
            kwargs[arg] = path

    if "user_dir" in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs["user_dir"]))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(os.pathsep)],
        arg_overrides=kwargs,
    )

    return {
        "args": args,
        "task": task,
        "models": models,
    }
class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.
    """

    def __init__(self, cfg, task, models):
        super().__init__()
        self.cfg = cfg
        self.task = task
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        # optimize model for generation
        for model in self.models:
            model.prepare_for_inference_(cfg)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)

        self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
        self.bpe = encoders.build_bpe(cfg.bpe)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        """Device holding this module's buffers.

        This is a property because ``generate`` reads ``self.device`` as an
        attribute when moving batches.
        """
        return self._float_tensor.device

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Same as :meth:`sample`, but with a default beam size of 5."""
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Encode *sentences*, generate, and decode each top hypothesis.

        A single ``str`` input returns a single ``str``.
        """
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[str], **kwargs):
        """Return the top hypothesis dict per sentence.

        Generation runs with ``score_reference=True``. A single ``str`` input
        returns a single dict.
        """
        if isinstance(sentences, str):
            return self.score([sentences], **kwargs)[0]
        # NOTE: this doesn't support translation tasks currently
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        return [
            hypos[0]
            for hypos in self.generate(
                tokenized_sentences, score_reference=True, **kwargs
            )
        ]

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        prefix_allowed_tokens_fn=None,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Run the task's generator over already-binarized sentences.

        Args:
            tokenized_sentences: token tensors, one per sentence. A single 1-D
                tensor is treated as one sentence, and its hypothesis list is
                returned directly.
            beam: beam size written into a copy of ``cfg.generation``.
            verbose: log source, hypotheses, positional scores and alignments.
            skip_invalid_size_inputs: forwarded to the batch iterator as
                ``ignore_invalid_inputs``.
            inference_step_args: extra kwargs for ``task.inference_step``.
            prefix_allowed_tokens_fn: forwarded to ``task.build_generator``.
            **kwargs: further overrides set on the copied generation config.

        Returns:
            One hypothesis list per processed sentence, sorted by batch id.
        """
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            # Add a batch dimension, recurse, then unwrap. Every option is
            # forwarded so none is silently dropped on this path.
            return self.generate(
                tokenized_sentences.unsqueeze(0),
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                **kwargs
            )[0]

        # Nothing to generate: avoid building a generator / empty dataset.
        if len(tokenized_sentences) == 0:
            return []

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(
            self.models,
            gen_args,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        )

        inference_step_args = inference_step_args or {}
        device = self.device
        results = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(device), batch)
            translations = self.task.inference_step(
                generator, self.models, batch, **inference_step_args
            )
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                results.append((sample_id, hypos))

        # sort output to match input order (stable, as sorted() was)
        results.sort(key=lambda pair: pair[0])
        outputs = [hypos for _, hypos in results]

        if verbose:
            self._log_hypotheses(tokenized_sentences, results, gen_args)

        return outputs

    def _log_hypotheses(self, tokenized_sentences, results, gen_args):
        """Log S/H/P(/A) lines for each ``(sample_id, hypos)`` pair in *results*.

        Sources are looked up by ``sample_id`` rather than zipped by position.
        This keeps them aligned when invalid-size inputs were skipped. It
        assumes batch ids index into *tokenized_sentences*, the same
        assumption the id sort in ``generate`` relies on.
        """

        def getarg(name, default):
            return getattr(gen_args, name, getattr(self.cfg, name, default))

        for sample_id, target_hypotheses in results:
            src_str_with_unk = self.string(tokenized_sentences[sample_id])
            logger.info("S\t{}".format(src_str_with_unk))
            for hypo in target_hypotheses:
                hypo_str = self.decode(hypo["tokens"])
                logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                logger.info(
                    "P\t{}".format(
                        " ".join(
                            "{:.4f}".format(x)
                            for x in hypo["positional_scores"].tolist()
                        )
                    )
                )
                if hypo["alignment"] is not None and getarg(
                    "print_alignment", False
                ):
                    logger.info(
                        "A\t{}".format(
                            " ".join(
                                "{}-{}".format(src_idx, tgt_idx)
                                for src_idx, tgt_idx in hypo["alignment"]
                            )
                        )
                    )

    def encode(self, sentence: str) -> torch.LongTensor:
        """Tokenize, apply BPE, then binarize with the source dictionary."""
        sentence = self.tokenize(sentence)
        sentence = self.apply_bpe(sentence)
        return self.binarize(sentence)

    def decode(self, tokens: torch.LongTensor) -> str:
        """Stringify with the target dictionary, remove BPE, then detokenize."""
        sentence = self.string(tokens)
        sentence = self.remove_bpe(sentence)
        return self.detokenize(sentence)

    def tokenize(self, sentence: str) -> str:
        """Apply the tokenizer if one is configured; otherwise return input."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.encode(sentence)
        return sentence

    def detokenize(self, sentence: str) -> str:
        """Invert the tokenizer if one is configured; otherwise return input."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.decode(sentence)
        return sentence

    def apply_bpe(self, sentence: str) -> str:
        """Apply BPE if configured; otherwise return input unchanged."""
        if self.bpe is not None:
            sentence = self.bpe.encode(sentence)
        return sentence

    def remove_bpe(self, sentence: str) -> str:
        """Remove BPE if configured; otherwise return input unchanged."""
        if self.bpe is not None:
            sentence = self.bpe.decode(sentence)
        return sentence

    def binarize(self, sentence: str) -> torch.LongTensor:
        """Map a string to source-dictionary ids without growing the dictionary."""
        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()

    def string(self, tokens: torch.LongTensor) -> str:
        """Convert token ids to a string using the target dictionary."""
        return self.tgt_dict.string(tokens)

    def _build_batches(
        self, tokens: List[torch.LongTensor], skip_invalid_size_inputs: bool
    ) -> Iterator[Dict[str, Any]]:
        """Return an unshuffled batch iterator over an inference dataset for *tokens*."""
        lengths = torch.LongTensor([t.numel() for t in tokens])
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=self.cfg.dataset.batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator
class BPEHubInterface:
    """PyTorch Hub interface for Byte-Pair Encoding (BPE).

    Builds a BPE object via ``encoders.build_bpe`` and delegates string
    encoding/decoding to it.
    """

    def __init__(self, bpe, **kwargs):
        super().__init__()
        # build_bpe takes an argparse-style namespace; extra keyword arguments
        # become attributes of that namespace alongside ``bpe``.
        self.bpe = encoders.build_bpe(argparse.Namespace(bpe=bpe, **kwargs))
        assert self.bpe is not None

    def encode(self, sentence: str) -> str:
        """Delegate to the underlying BPE object's ``encode``."""
        encoded = self.bpe.encode(sentence)
        return encoded

    def decode(self, sentence: str) -> str:
        """Delegate to the underlying BPE object's ``decode``."""
        decoded = self.bpe.decode(sentence)
        return decoded
class TokenizerHubInterface:
    """PyTorch Hub interface for tokenization.

    Builds a tokenizer via ``encoders.build_tokenizer`` and delegates string
    encoding/decoding to it.
    """

    def __init__(self, tokenizer, **kwargs):
        super().__init__()
        # build_tokenizer takes an argparse-style namespace; extra keyword
        # arguments become attributes of that namespace alongside ``tokenizer``.
        self.tokenizer = encoders.build_tokenizer(
            argparse.Namespace(tokenizer=tokenizer, **kwargs)
        )
        assert self.tokenizer is not None

    def encode(self, sentence: str) -> str:
        """Delegate to the underlying tokenizer's ``encode``."""
        tokenized = self.tokenizer.encode(sentence)
        return tokenized

    def decode(self, sentence: str) -> str:
        """Delegate to the underlying tokenizer's ``decode``."""
        detokenized = self.tokenizer.decode(sentence)
        return detokenized