OFA-Visual_Grounding/fairseq/examples/latent_depth/latent_depth_src/multilingual_translation_latent_depth.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.tasks import register_task
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from fairseq.utils import safe_hasattr

from .loss.latent_depth import LatentLayersKLLoss, LatentLayersSparsityLoss


@register_task("multilingual_translation_latent_depth")
class MultilingualTranslationTaskLatentDepth(MultilingualTranslationTask):
"""A task for multiple translation with latent depth.
See `"Deep Transformer with Latent Depth"
(Li et al., 2020) <https://arxiv.org/pdf/2009.13102.pdf>`_.
"""
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        MultilingualTranslationTask.add_args(parser)
        parser.add_argument('--encoder-latent-layer', action='store_true', help='latent layer selection in encoder')
        parser.add_argument('--decoder-latent-layer', action='store_true', help='latent layer selection in decoder')
        parser.add_argument('--target-layers', default=-1, type=int,
                            help='number of effective layers to learn; -1 means no constraint')
        parser.add_argument('--sparsity-weight', default=0.0, type=float,
                            help='weight for sparsity loss')
        parser.add_argument('--share-weight', default=0.0, type=float,
                            help='weight for sharing loss')
        parser.add_argument('--soft-update', default=1, type=int,
                            help='number of updates with soft sampling')
        parser.add_argument('--anneal-updates', default=1, type=int,
                            help='number of updates to anneal the KL loss weight')
        parser.add_argument('--prior', default="uniform", type=str,
                            help='prior used for computing KL loss')
        # fmt: on

    def __init__(self, args, dicts, training):
        super().__init__(args, dicts, training)
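        # Source/target language lists; their order defines the per-language indices
        # consumed by src_lang_idx_dict / tgt_lang_idx_dict below.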
        self.src_langs, self.tgt_langs = zip(
            *[(lang.split("-")[0], lang.split("-")[1]) for lang in args.lang_pairs]
        )
        if self.training and self.encoder_latent_layer:
            assert self.args.share_encoders
        if self.training and self.decoder_latent_layer:
            assert self.args.share_decoders
        if training or self.encoder_latent_layer or self.decoder_latent_layer:
            self.lang_pairs = args.lang_pairs
        else:
            self.lang_pairs = ["{}-{}".format(args.source_lang, args.target_lang)]
        self.eval_lang_pairs = self.lang_pairs
        self.model_lang_pairs = self.lang_pairs
        if self.training and (self.encoder_latent_layer or self.decoder_latent_layer):
            self.kl_loss = LatentLayersKLLoss(self.args)
            self.sparsity_loss = LatentLayersSparsityLoss(self.args)

    def _per_lang_pair_train_loss(
        self, lang_pair, model, update_num, criterion, sample, optimizer, ignore_grad
    ):
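        """Forward/backward for a single language pair, adding the KL term on the
        latent layer-selection posterior to the criterion loss."""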
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
model.models[lang_pair].encoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
model.models[lang_pair].decoder.layer_select.hard_select = (
update_num > self.args.soft_update
)
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
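        # KL regularization of the layer-selection posterior toward the configured
        # prior; with the aggregated-posterior prior ("agged_posterior") the term is
        # skipped until every language has contributed layer samples.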
        if self.encoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0
                for x in model.models[lang_pair].encoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].encoder.layer_select.layer_samples,
                    src_lang_idx,
                    update_num,
                    sample_size,
                )
        if self.decoder_latent_layer:
            none_samples = sum(
                1 if x is None else 0
                for x in model.models[lang_pair].decoder.layer_select.layer_samples
            )
            if none_samples == 0 or self.args.prior != "agged_posterior":
                loss += self.kl_loss(
                    model.models[lang_pair].decoder.layer_select.layer_samples,
                    tgt_lang_idx,
                    update_num,
                    sample_size,
                )
        if ignore_grad:
            loss *= 0

        if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
            # need to retain the graph if sparsity loss needs to be added
            loss.backward(retain_graph=True)
        else:
            optimizer.backward(loss)

        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
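        """Run the multilingual train step, then backprop the sparsity loss computed
        on the layer-selection samples shared across all language pairs."""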
        agg_loss, agg_sample_size, agg_logging_output = super().train_step(
            sample, model, criterion, optimizer, update_num, ignore_grad
        )
        # compute auxiliary loss from layer sparsity, based on all samples from all languages
if hasattr(self, "sparsity_loss") and self.sparsity_loss.is_valid(update_num):
sparsity_loss = 0
if self.encoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(
iter(model.models.values())
).encoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if self.decoder_latent_layer:
sparsity_loss += self.sparsity_loss(
next(
iter(model.models.values())
).decoder.layer_select.layer_samples,
update_num,
agg_sample_size,
)
if sparsity_loss > 0:
optimizer.backward(sparsity_loss)
return agg_loss, agg_sample_size, agg_logging_output

    def _per_lang_pair_valid_loss(self, lang_pair, model, criterion, sample):
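        """Validation loss for a single language pair with the language-specific
        latent layer selection set; no KL or sparsity terms are added."""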
src, tgt = lang_pair.split("-")
if self.encoder_latent_layer:
src_lang_idx = self.src_lang_idx_dict[src]
model.models[lang_pair].encoder.set_lang_idx(src_lang_idx)
if self.decoder_latent_layer:
tgt_lang_idx = self.tgt_lang_idx_dict[tgt]
model.models[lang_pair].decoder.set_lang_idx(tgt_lang_idx)
loss, sample_size, logging_output = criterion(
model.models[lang_pair], sample[lang_pair]
)
return loss, sample_size, logging_output

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
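        """Point each model's latent layer selection at the configured source/target
        languages before delegating to the parent inference step."""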
        if self.encoder_latent_layer or self.decoder_latent_layer:
            for model in models:
                if self.encoder_latent_layer:
                    assert model.encoder.layer_select is not None
                    src_lang_idx = self.src_lang_idx_dict[self.args.source_lang]
                    model.encoder.set_lang_idx(src_lang_idx)
                if self.decoder_latent_layer:
                    assert model.decoder.layer_select is not None
                    tgt_lang_idx = self.tgt_lang_idx_dict[self.args.target_lang]
                    model.decoder.set_lang_idx(tgt_lang_idx)
        return super().inference_step(
            generator, models, sample, prefix_tokens, constraints
        )

    @property
    def encoder_latent_layer(self):
        return (
            safe_hasattr(self.args, "encoder_latent_layer")
            and self.args.encoder_latent_layer
        )

    @property
    def decoder_latent_layer(self):
        return (
            safe_hasattr(self.args, "decoder_latent_layer")
            and self.args.decoder_latent_layer
        )

    @property
    def src_lang_idx_dict(self):
        return {lang: lang_idx for lang_idx, lang in enumerate(self.src_langs)}

    @property
    def tgt_lang_idx_dict(self):
        return {lang: lang_idx for lang_idx, lang in enumerate(self.tgt_langs)}
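

# Illustrative training invocation (a sketch, not a command taken from this repo):
# the task name and the latent-depth flags are registered in this file, while the
# data path, language pairs, architecture, and optimizer settings are placeholders.
#
#   fairseq-train <data-bin> \
#       --user-dir examples/latent_depth/latent_depth_src \
#       --task multilingual_translation_latent_depth \
#       --lang-pairs "de-en,fr-en" \
#       --arch multilingual_transformer \
#       --share-decoders --decoder-langtok \
#       --decoder-latent-layer --target-layers 12 \
#       --sparsity-weight 0.1 --share-weight 0.1 \
#       --soft-update 500 --anneal-updates 5000 \
#       --optimizer adam --lr 0.0015 --max-tokens 4096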