# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import json
import logging
import math
import os
from argparse import ArgumentError, Namespace
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import fairseq
from fairseq import metrics, options, utils
from fairseq.data import (
    FairseqDataset,
    LanguagePairDataset,
    NoisingDataset,
    PrependTokenDataset,
    RoundRobinZipDatasets,
    TransformEosLangPairDataset,
    data_utils,
    encoders,
)
from fairseq.sequence_generator import SequenceGenerator
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask, load_langpair_dataset


logger = logging.getLogger(__name__)


class PiecewiseLinearFn:
    """Piecewise linear function. Can be configured with a string."""

    def __init__(self, pieces: Sequence[Tuple[int, float]]):
        assert pieces == sorted(
            pieces
        ), f"PiecewiseLinearFn configuration should be sorted, received: {pieces}"

        self.pieces = pieces

    def __call__(self, x: int) -> float:
        for i, (x_a, y_a) in enumerate(self.pieces[:-1]):
            x_b, y_b = self.pieces[i + 1]
            if x_a <= x <= x_b:
                return y_a + (x - x_a) * (y_b - y_a) / (x_b - x_a)

        return self.pieces[-1][1]
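
    # Example (illustrative): PiecewiseLinearFn.from_string("0:1,1000:0")(500) == 0.5;
    # values are linearly interpolated between the configured pieces and held constant
    # past the last piece.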
    @staticmethod
    def from_string(configuration: str) -> "PiecewiseLinearFn":
        """
        Parse the configuration of lambda coefficient (for scheduling).
        x = "3"                  # lambda will be a constant equal to x
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease
                                 # to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000
                                 # iterations, then will linearly increase to 1 until iteration 2000
        """
        if isinstance(configuration, float):
            return PiecewiseLinearFn([(0, configuration)])

        try:
            parts = configuration.split(",")
            if len(parts) == 1:
                v = float(configuration)
                return PiecewiseLinearFn([(0, v)])

            split = [s.split(":") for s in parts]
            pieces = [(int(t), float(v)) for t, v in split]
            return PiecewiseLinearFn(pieces)
        except Exception:
            raise ValueError(
                f"Invalid PiecewiseLinearFn configuration: {configuration!r}"
            )

    @staticmethod
    def one() -> "PiecewiseLinearFn":
        return PiecewiseLinearFn([(0, 1.0)])


@register_task("online_backtranslation")
class OnlineBackTranslationTask(TranslationTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        # Generic translation args
        parser.add_argument('data', help='colon separated path to data directories list, \
                            will be iterated upon during epochs in round-robin manner; \
                            however, valid and test data are always in the first directory to \
                            avoid the need for repeating them in all directories')
        parser.add_argument('--mono-langs', metavar='MONO_LANGS',
                            help='monolingual languages for training')
        parser.add_argument('--valid-lang-pairs', default=None, metavar='VALID_LANG_PAIRS',
                            help='language pairs for validation')
        parser.add_argument('--load-alignments', action='store_true',
                            help='load the binarized alignments')
        parser.add_argument('--left-pad-source', default='False', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        try:
            parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the source sequence')
            parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                                help='max number of tokens in the target sequence')
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass
        parser.add_argument('--truncate-source', action='store_true', default=False,
                            help='truncate source to max-source-positions')
        parser.add_argument('--num-batch-buckets', default=0, type=int, metavar='N',
                            help='if >0, then bucket source and target lengths into N '
                                 'buckets and pad accordingly; this is useful on TPUs '
                                 'to minimize the number of compilations')
        # Denoising args
        parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
                            help='maximum word shuffle distance for denoising autoencoding data generation')
        parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
                            help='word dropout probability for denoising autoencoding data generation')
        parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
                            help='word blanking probability for denoising autoencoding data generation')
        # Backtranslation args
        parser.add_argument('--lambda-bt', default="1.0", type=str, metavar='N',
                            help='back-translation weight')
        parser.add_argument('--lambda-dae', default="1.0", type=str, metavar='N',
                            help='denoising auto-encoder weight')
        # Evaluation args
        parser.add_argument('--generate-one-by-one', action='store_true',
                            help='generate one sentence at a time for backtranslation')
        parser.add_argument('--eval-bleu', action='store_true',
                            help='evaluation with BLEU scores')
        parser.add_argument('--eval-bleu-detok', type=str, default="space",
                            help='detokenize before computing BLEU (e.g., "moses"); '
                                 'required if using --eval-bleu; use "space" to '
                                 'disable detokenization; see fairseq.data.encoders '
                                 'for other options')
        parser.add_argument('--eval-bleu-detok-args', type=str, metavar='JSON',
                            help='args for building the tokenizer, if needed')
        parser.add_argument('--eval-tokenized-bleu', action='store_true', default=False,
                            help='compute tokenized BLEU instead of sacrebleu')
        parser.add_argument('--eval-bleu-remove-bpe', nargs='?', const='@@ ', default=None,
                            help='remove BPE before computing BLEU')
        parser.add_argument('--eval-bleu-args', type=str, metavar='JSON',
                            help='generation args for BLEU scoring, '
                                 'e.g., \'{"beam": 4, "lenpen": 0.6}\'')
        parser.add_argument('--eval-bleu-print-samples', action='store_true',
                            help='print sample generations during validation')
        # fmt: on
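
    # Example (illustrative) invocation using this task; the flags correspond to the
    # arguments defined above and the values are placeholders:
    #
    #   fairseq-train <data-bin> --task online_backtranslation \
    #       --mono-langs en,ro --valid-lang-pairs en-ro \
    #       --lambda-bt 1.0 --lambda-dae "0:1,100000:0.1" ...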

    def __init__(self, args, common_dict, mono_langs, valid_lang_pairs):
        super().__init__(args, common_dict, common_dict)
        self.common_dict = common_dict
        self.mono_langs = mono_langs
        self.valid_lang_pairs = valid_lang_pairs

        self.SHOW_SAMPLES_INTERVAL = 1000
        # Start by showing samples
        self._show_samples_ctr = self.SHOW_SAMPLES_INTERVAL
        self.SHOW_SAMPLES_NUMBER = 5
        self.lambda_bt = PiecewiseLinearFn.from_string(args.lambda_bt)
        self.lambda_dae = PiecewiseLinearFn.from_string(args.lambda_dae)

        self.args = args
        self.data = utils.split_paths(self.args.data)
        if len(self.data) == 1:
            shards = list(Path(self.data[0]).glob("shard*"))
            if len(shards) > 0:
                # keep this as strings, since it can also be a manifold path
                old_data = self.data
                self.data = [str(shard) for shard in shards]
                logging.warning(f"Expanded data directory {old_data} to {self.data}")

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        assert args.mono_langs is not None

        mono_langs = args.mono_langs.split(",")
        valid_lang_pairs = args.valid_lang_pairs.split(",")

        # load dictionary
        dict_path = os.path.join(paths[0], "dict.txt")
        common_dict = cls.load_dictionary(dict_path)

        return cls(args, common_dict, mono_langs, valid_lang_pairs)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs) -> FairseqDataset:
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if split == "train":
            data_path = self.data[(epoch - 1) % len(self.data)]
            dataset = self.load_train_dataset(data_path)
        else:
            # valid/test should always be the same.
            dataset = self.load_translation_dataset(split, self.data[0])

        self.datasets[split] = dataset
        return dataset

    def load_train_dataset(self, data_path: str) -> FairseqDataset:
        """The training dataset is made of a backtranslation dataset and a denoising dataset."""
        data = []
        for lang in self.mono_langs:
            train_path = os.path.join(data_path, lang, "train")
            # TODO: could we do the BT using denoise sample ?
            # this would halve the data loading work
            data.append((f"{lang}-BT", self.load_bt_dataset(train_path, lang)))
            data.append(
                (f"{lang}-DENOISE", self.load_denoise_dataset(train_path, lang))
            )
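
        # The resulting RoundRobinZipDatasets is keyed by "<lang>-BT" and
        # "<lang>-DENOISE"; train_step() below splits these keys to pick the
        # per-subtask loss weight (lambda_bt / lambda_dae).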
        return RoundRobinZipDatasets(OrderedDict(data))

    def _langpair_dataset(
        self, src: FairseqDataset, tgt: FairseqDataset
    ) -> LanguagePairDataset:
        return LanguagePairDataset(
            src,
            src.sizes,
            self.dictionary,
            tgt=tgt,
            tgt_sizes=tgt.sizes,
            tgt_dict=self.dictionary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            # TODO: should we shuffle ? we are already sorting batch by sizes so ?
            # shuffle=True,
        )

    def _prepend_lang_bos_to_target(
        self, dataset: LanguagePairDataset, lang: str
    ) -> LanguagePairDataset:
        bos = _lang_token_index(self.dictionary, lang)
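        # In fairseq, prev_output_tokens starts with EOS; swap that leading EOS on the
        # target side for the language token so the decoder knows which language to produce.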
        return TransformEosLangPairDataset(
            dataset,
            src_eos=self.dictionary.eos(),
            new_src_eos=self.dictionary.eos(),
            tgt_bos=self.dictionary.eos(),
            new_tgt_bos=bos,
        )

    def load_bt_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """The BT dataset is generated with (tgt, tgt) pairs.

        The actual translation to a (generated_src, tgt) pair
        is done on the fly during training.
        """
        mono_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        assert mono_dataset is not None, f"No dataset found for {lang}"

        mono_dataset_src = PrependTokenDataset(
            mono_dataset, _lang_token_index(self.dictionary, lang)
        )

        mono_dataset_bt = self._langpair_dataset(mono_dataset_src, mono_dataset)
        logger.info(
            f"mono_lang = {lang} "
            f"lang token index = {_lang_token_index(self.dictionary, lang)} "
            f"lang token = {_lang_token(lang)}"
        )

        mono_dataset_bt = self._prepend_lang_bos_to_target(mono_dataset_bt, lang)
        return mono_dataset_bt

    def load_denoise_dataset(self, data_path: str, lang: str) -> FairseqDataset:
        """Classic denoising dataset"""
        dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        noisy_dataset = NoisingDataset(
            dataset,
            self.dictionary,
            seed=1,
            max_word_shuffle_distance=self.args.max_word_shuffle_distance,
            word_dropout_prob=self.args.word_dropout_prob,
            word_blanking_prob=self.args.word_blanking_prob,
        )
        noisy_dataset = PrependTokenDataset(
            noisy_dataset, _lang_token_index(self.dictionary, lang)
        )

        clean_dataset = data_utils.load_indexed_dataset(
            data_path, self.common_dict, self.args.dataset_impl
        )
        denoising_dataset = self._langpair_dataset(noisy_dataset, clean_dataset)
        denoising_dataset = self._prepend_lang_bos_to_target(denoising_dataset, lang)

        return denoising_dataset

    def load_translation_dataset(
        self, split: str, data_path: str, combine: bool = False
    ):
        # only judging with one language pair for the moment,
        # since ConcatDataset doesn't work as expected
        assert len(self.valid_lang_pairs) == 1, "For now..."
        valid_lang_pair = self.valid_lang_pairs[0]
        src, tgt = valid_lang_pair.split("-")

        # use the same function as TranslationTask
        src_tgt_dt = load_langpair_dataset(
            data_path,
            split,
            src,
            self.common_dict,
            tgt,
            self.common_dict,
            combine=combine,
            dataset_impl=self.args.dataset_impl,
            upsample_primary=self.args.upsample_primary,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
            load_alignments=self.args.load_alignments,
            truncate_source=self.args.truncate_source,
            num_buckets=self.args.num_batch_buckets,
            shuffle=(split != "test"),
            prepend_bos_src=_lang_token_index(self.dictionary, src),
        )

        src_tgt_eos_dt = self._prepend_lang_bos_to_target(src_tgt_dt, tgt)
        src_tgt_eos_dt.args = self.args
        return src_tgt_eos_dt

    def build_dataset_for_inference(self, src_tokens, src_lengths, constraints=None):
        raise NotImplementedError

    def build_model(self, args, from_checkpoint=False):
        # torch.autograd.set_detect_anomaly(True)
        model = super().build_model(args, from_checkpoint)

        add_special_tokens_to_dict_and_model(self.common_dict, model, self.mono_langs)

        self.sequence_generators = {}
        for mono_lang in self.mono_langs:
            self.sequence_generators[mono_lang] = SequenceGenerator(
                [model],
                tgt_dict=self.dictionary,
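                # beam_size=1 -> greedy decoding for the on-the-fly back-translation pass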
                beam_size=1,
                max_len_a=1.3,
                max_len_b=5,
                min_len=5,
                # keep 1 to be able to prepend bos
                max_len=model.max_decoder_positions() - 1,
            )

        if getattr(args, "eval_bleu", False):
            assert getattr(args, "eval_bleu_detok", None) is not None, (
                "--eval-bleu-detok is required if using --eval-bleu; "
                "try --eval-bleu-detok=moses (or --eval-bleu-detok=space "
                "to disable detokenization, e.g., when using sentencepiece)"
            )
            detok_args = json.loads(getattr(args, "eval_bleu_detok_args", "{}") or "{}")
            self.tokenizer = encoders.build_tokenizer(
                Namespace(
                    tokenizer=getattr(args, "eval_bleu_detok", None), **detok_args
                )
            )

            gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}")
            self.bleu_sequence_generator = self.build_generator(
                [model], Namespace(**gen_args)
            )

        return model

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.common_dict

    def display_samples_once_in_a_while(self, smp, mono_lang, other_lang):
        self._show_samples_ctr += 1
        if self._show_samples_ctr < self.SHOW_SAMPLES_INTERVAL:
            return
        self._show_samples_ctr = 0

        ln = smp["net_input"]["src_tokens"].shape[0]

        logger.info(
            f"(r:{self.args.distributed_rank}) : "
            f"{other_lang} ---> {mono_lang} "
            f"({other_lang} was generated by back-translation.) {ln} samples"
        )

        for i in range(min(ln, self.SHOW_SAMPLES_NUMBER)):
            src_tokens = smp["net_input"]["src_tokens"][i]
            tgt_tokens = smp["target"][i]

            src_str = self.dictionary.string(src_tokens, "sentencepiece")
            tgt_str = self.dictionary.string(tgt_tokens, "sentencepiece")
            logger.info(
                f"\n{i}\t\t[{other_lang} generated] {src_str}\n"
                f"\t\t[{mono_lang} original ] {tgt_str}\n"
                f"\t\t[ src tokens] {src_tokens}\n"
            )

    def backtranslate_sample(self, smp, orig_lang, other_lang) -> None:
        """
        * WARNING: smp is modified in place.
        * At the start of this function, `smp` has the same input and target:
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] |  smp['target']        |
          | (from data) __en__ hello world |  __en__ hello world   |
          |--------------------------------------------------------|
        * We call generator.generate(smp, bos_token = token("ro")),
          and copy the result as input
        * At the end, `smp` has the translation to the other language.
          |--------------------------------------------------------|
          | smp['net_input']['src_tokens'] |  smp['target']        |
          | (generated) __ro__ salut lume  |  __en__ hello world   |
          |--------------------------------------------------------|
        """
        bos_token = _lang_token_index(self.dictionary, other_lang)
        generated = self.sequence_generators[orig_lang].generate(
            models=[], sample=smp, bos_token=bos_token
        )

        max_length = max([gn[0]["tokens"].size(0) for gn in generated])
        net_input = smp["net_input"]
        n_src_tokens = torch.empty(
            size=(len(generated), max_length + 1), dtype=net_input["src_tokens"].dtype
        )
        n_src_lengths = torch.empty(
            len(generated), dtype=net_input["src_lengths"].dtype
        )

        for i, gn in enumerate(generated):
            tokens = gn[0]["tokens"]
            tokens_size = tokens.size(0)
            padding_needed = max_length - tokens_size
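            # Prepend the target-language token, then right-pad up to the longest
            # generated hypothesis in the batch.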
            tokens = torch.cat([tokens.new([bos_token]), tokens])
            tokens = F.pad(tokens, (0, padding_needed), value=self.dictionary.pad())
            n_src_tokens[i] = tokens
            n_src_lengths[i] = tokens_size + 1

        device = net_input["src_tokens"].device
        # This seems to be important
        del net_input["src_tokens"]
        del net_input["src_lengths"]
        net_input["src_tokens"] = n_src_tokens.to(device)
        net_input["src_lengths"] = n_src_lengths.to(device)

    def generate(self, smp, model):
        model.eval()
        orig_lang = (
            self.dictionary[smp["net_input"]["src_tokens"][0][0]]
            .replace(" ", "")
            .replace("_", "")
        )
        bos_token = smp["net_input"]["prev_output_tokens"][0][0]
        with torch.no_grad():
            generated = self.sequence_generators[orig_lang].generate(
                models=[model], sample=smp, bos_token=bos_token
            )
        return generated

    def get_other_lang(self, lang):
        # TODO: allow more complex mapping
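        # Current behaviour: any language other than the first one is paired with the
        # first language; the first language is paired with a randomly chosen other
        # language (or the single other one when only two languages are configured).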
        if lang != self.mono_langs[0]:
            return self.mono_langs[0]
        if len(self.mono_langs) == 2:
            return self.mono_langs[1]
        return self.mono_langs[np.random.randint(1, len(self.mono_langs))]

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        model.train()
        model.set_num_updates(update_num)

        agg_loss, agg_sample_size = 0.0, 0.0
        agg_logging_output: Dict[str, float] = defaultdict(float)

        dataset_keys = self.datasets["train"].datasets.keys()

        weights = {
            "BT": self.lambda_bt(update_num),
            "DENOISE": self.lambda_dae(update_num),
        }
        log_keys = {"BT": "bt_", "DENOISE": "dae_"}

        for dataset_key in dataset_keys:
            smp = sample[dataset_key]
            mono_lang, task_subtype = dataset_key.split("-")
            if weights[task_subtype] == 0:
                continue

            if task_subtype == "BT":
                with torch.autograd.profiler.record_function("backtranslation"):
                    model.eval()
                    # TODO: Could we translate to several languages at once ?
                    # this would allow to share encoder_out and maximize GPU usage.
                    other_lang = self.get_other_lang(mono_lang)
                    self.backtranslate_sample(smp, mono_lang, other_lang)
                    self.display_samples_once_in_a_while(smp, mono_lang, other_lang)
                    model.train()

            # Like in FairseqTask.train_step
            with torch.autograd.profiler.record_function("forward"):
                loss, sample_size, logging_output = criterion(model, smp)
            loss *= weights[task_subtype]
            if ignore_grad:
                loss *= 0
            with torch.autograd.profiler.record_function("backward"):
                optimizer.backward(loss)

            agg_loss += loss.item()
            agg_sample_size += sample_size
            for k in logging_output:
                agg_logging_output[log_keys[task_subtype] + k] += logging_output[k]
                agg_logging_output[k] += logging_output[k]

        return agg_loss, agg_sample_size, agg_logging_output

    def get_bos_token_from_sample(self, sample):
        net_input = sample["net_input"]
        source_lang_token_id = torch.unique(net_input["src_tokens"][:, 0]).item()
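        # .item() implicitly asserts that every sentence in the batch starts with the
        # same language token, i.e. batches are not mixed across source languages.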
        source_lang_token = self.dictionary[source_lang_token_id].replace("_", "")
        target_lang_token_id = _lang_token_index(
            self.dictionary, self.get_other_lang(source_lang_token)
        )

        return target_lang_token_id

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        bt_sample_size = sum(x.get("bt_sample_size", 0) for x in logging_outputs)
        if bt_sample_size:
            bt_loss_sum = sum(x.get("bt_loss", 0) for x in logging_outputs)
            bt_loss_sum *= 1 / bt_sample_size / math.log(2)
            metrics.log_scalar("bt_loss", bt_loss_sum, bt_sample_size, round=3)

            bt_nll_loss_sum = sum(x.get("bt_nll_loss", 0) for x in logging_outputs)
            bt_ntokens = sum(x.get("bt_ntokens", 0) for x in logging_outputs)
            bt_nll_loss_sum *= 1 / bt_ntokens / math.log(2)
            metrics.log_scalar("bt_nll_loss", bt_nll_loss_sum, bt_ntokens, round=3)
            metrics.log_derived(
                "bt_ppl", lambda meters: utils.get_perplexity(meters["bt_nll_loss"].avg)
            )

        dae_sample_size = sum(x.get("dae_sample_size", 0) for x in logging_outputs)
        if dae_sample_size:
            dae_loss_sum = sum(x.get("dae_loss", 0) for x in logging_outputs)
            dae_loss_sum *= 1 / dae_sample_size / math.log(2)
            metrics.log_scalar("dae_loss", dae_loss_sum, dae_sample_size, round=3)

            dae_nll_loss_sum = sum(x.get("dae_nll_loss", 0) for x in logging_outputs)
            dae_ntokens = sum(x.get("dae_ntokens", 0) for x in logging_outputs)
            dae_nll_loss_sum *= 1 / dae_ntokens / math.log(2)
            metrics.log_scalar("dae_nll_loss", dae_nll_loss_sum, dae_ntokens, round=3)
            metrics.log_derived(
                "dae_ppl",
                lambda meters: utils.get_perplexity(meters["dae_nll_loss"].avg),
            )


def extend_embedding(
    emb: nn.Module, new_vocab_size: int, copy_from_token_id: int
) -> None:
    old_emb_data = emb.weight.data
    (old_vocab_size, dim) = old_emb_data.shape
    assert new_vocab_size >= old_vocab_size

    if new_vocab_size > old_vocab_size:
        emb.weight.data = torch.zeros((new_vocab_size, dim))
        emb.weight.data[:old_vocab_size, :] = old_emb_data
        # initialize new embeddings
        emb.weight.data[old_vocab_size:, :] = old_emb_data[copy_from_token_id]
        if hasattr(emb, "num_embeddings"):
            emb.num_embeddings = new_vocab_size
        if hasattr(emb, "out_features"):
            emb.out_features = new_vocab_size

    if getattr(emb, "bias", None) is None:
        return

    # Fix the bias.
    # Bias shape can be different from the previous vocab size
    # if the weight matrix was shared and already extended but not the bias.
    (old_vocab_size,) = emb.bias.shape
    assert new_vocab_size >= old_vocab_size
    if new_vocab_size > old_vocab_size:
        old_bias = emb.bias.data
        new_bias = torch.zeros(
            (new_vocab_size,), dtype=old_bias.dtype, device=old_bias.device
        )
        new_bias[:old_vocab_size] = old_bias
        emb.bias.data = new_bias


def add_special_tokens_to_dict_and_model(
    dictionary: "fairseq.data.Dictionary",
    model: nn.Module,
    mono_langs: Sequence[str],
) -> None:
    embs = model.encoder.embed_tokens
    vocab_size, embedding_dim = embs.weight.shape

    # The model may or may not have a '<mask>' embedding yet
    assert (
        len(dictionary) <= vocab_size <= len(dictionary) + 1
    ), f"Dictionary len ({len(dictionary)}) doesn't match embs shape ({embs.weight.shape})"
    # TODO: we should reuse the pretrained model dict which already has <mask>
    dictionary.add_symbol("<mask>")

    for lang in mono_langs:
        lang_token = _lang_token(lang)
        dictionary.add_symbol(lang_token)
    logger.info(
        f"dictionary: {vocab_size} -> {len(dictionary)} tokens "
        f"after adding {len(mono_langs)} lang tokens."
    )

    if len(dictionary) <= vocab_size:
        return

    extend_embedding(embs, len(dictionary), dictionary.bos())
    dec_embs = model.decoder.embed_tokens
    extend_embedding(dec_embs, len(dictionary), dictionary.bos())
    lm_head = model.decoder.output_projection
    extend_embedding(lm_head, len(dictionary), dictionary.bos())
    assert lm_head.weight.shape == (len(dictionary), embedding_dim)


def _lang_token(lang: str) -> str:
    return f"__{lang}__"


def _lang_token_index(dictionary, lang: str) -> int:
    return dictionary.index(_lang_token(lang))


@contextlib.contextmanager
def assert_weights_have_changed(model: nn.Module):
    def checksum(model: nn.Module) -> float:
        return sum(p.sum().item() for p in model.parameters())

    initial_checksum = checksum(model)
    yield model
    final_checksum = checksum(model)
    logger.info(
        f"initial_checksum={initial_checksum} -> final_checksum={final_checksum}"
    )
    assert initial_checksum != final_checksum, "Model hasn't changed!"