Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
from examples.speech_recognition.data.replabels import pack_replabels | |
from fairseq import utils | |
from fairseq.criterions import FairseqCriterion, register_criterion | |
# NOTE(review): the registered name "asg_loss" follows upstream fairseq; the
# decorator line was missing from this copy although `register_criterion` is
# imported -- confirm the name against your configs.
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
    """fairseq criterion that scores encoder emissions with flashlight's ``ASGLoss``.

    Targets are post-processed before being handed to the loss: a trailing EOS
    is dropped or replaced by the silence token, and the result is run through
    ``pack_replabels``. For the first ``linseg_updates`` training updates the
    targets are additionally stretched evenly over the time axis ("LinSeg").
    """

    @staticmethod
    def add_args(parser):
        """Add the criterion-specific command-line arguments to ``parser``."""
        group = parser.add_argument_group("ASG Loss")
        group.add_argument(
            "--asg-transitions-init",
            help="initial diagonal value of transition matrix",
            type=float,
            default=0.0,
        )
        group.add_argument(
            "--max-replabel", help="maximum # of replabels", type=int, default=2
        )
        group.add_argument(
            "--linseg-updates",
            help="# of training updates to use LinSeg initialization",
            type=int,
            default=0,
        )
        group.add_argument(
            "--hide-linseg-messages",
            help="hide messages about LinSeg initialization",
            action="store_true",
        )

    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
        sentence_avg=False,
    ):
        """Build the criterion.

        Args:
            task: fairseq task; only ``task.target_dictionary`` is read here.
            silence_token: symbol substituted for a trailing EOS; ignored when
                it is not present in the target dictionary.
            asg_transitions_init: value placed on the diagonal of the initial
                transition matrix.
            max_replabel: forwarded to ``pack_replabels``.
            linseg_updates: number of training updates that use LinSeg targets.
            hide_linseg_messages: when true, suppress the LinSeg start/finish
                messages.
            sentence_avg: when true, ``forward`` reports the number of
                sentences as the sample size instead of ``sample["ntokens"]``.
        """
        # Imported lazily so the module can be imported without flashlight.
        from flashlight.lib.sequence.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        self.silence = (
            self.tgt_dict.index(silence_token)
            if silence_token in self.tgt_dict
            else None
        )
        self.max_replabel = max_replabel
        # The base class is initialised with ``task`` only, so the flag is kept
        # on the criterion itself rather than read from ``self.args``.
        self.sentence_avg = sentence_avg

        num_labels = len(self.tgt_dict)
        self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
        # Trainable (num_labels x num_labels) transition matrix, diagonal init.
        self.asg.trans = torch.nn.Parameter(
            asg_transitions_init * torch.eye(num_labels), requires_grad=True
        )

        # Stored as a non-trainable Parameter; presumably so the counter is
        # saved with the criterion state -- TODO confirm.
        self.linseg_progress = torch.nn.Parameter(
            torch.tensor([0], dtype=torch.int), requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        # Message state machine: "start" -> "finish" -> "none".
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @classmethod
    def build_criterion(cls, args, task):
        """Construct the criterion from parsed arguments and a task."""
        return cls(
            task,
            args.silence_token,
            args.asg_transitions_init,
            args.max_replabel,
            args.linseg_updates,
            args.hide_linseg_messages,
            sentence_avg=getattr(args, "sentence_avg", False),
        )

    def linseg_step(self):
        """Advance the LinSeg counter and report whether LinSeg is active.

        Returns:
            True while training and fewer than ``linseg_maximum`` updates have
            been counted (the counter is incremented as a side effect);
            False otherwise, including whenever the module is not training.
        """
        if not self.training:
            return False
        if self.linseg_progress.item() < self.linseg_maximum:
            if self.linseg_message_state == "start":
                print("| using LinSeg to initialize ASG")
                self.linseg_message_state = "finish"
            self.linseg_progress.add_(1)
            return True
        elif self.linseg_message_state == "finish":
            print("| finished LinSeg initialization")
            self.linseg_message_state = "none"
        return False

    def replace_eos_with_silence(self, tgt):
        """Strip a trailing EOS from ``tgt`` (a list of label indices).

        The EOS is replaced by the silence index unless no silence index is
        configured or the label before the EOS is already silence, in which
        case the EOS is simply dropped. Lists that are empty or do not end in
        EOS are returned unchanged.
        """
        if not tgt or tgt[-1] != self.eos:
            return tgt
        elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
            return tgt[:-1]
        else:
            return tgt[:-1] + [self.silence]

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        Raises:
            ValueError: if any entry of ``sample["target_lengths"]`` is zero.
        """
        net_output = model(**sample["net_input"])
        # After the transpose dim 0 is used as batch and dim 1 as time;
        # presumably encoder_out arrives as (T, B, C) -- TODO confirm.
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        # Zero-filled so positions beyond target_size[b] hold a defined value.
        target = torch.zeros(B, T, dtype=torch.int)
        target_size = torch.zeros(B, dtype=torch.int)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            tgt = sample["target"][b, :initial_target_size].tolist()
            tgt = self.replace_eos_with_silence(tgt)
            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
            # A target cannot be longer than the number of emission frames.
            tgt = tgt[:T]

            if using_linseg:
                # Stretch the target to exactly T labels by repeating each
                # label over an (approximately) equal share of the frames.
                tgt = [tgt[t * len(tgt) // T] for t in range(T)]

            target[b, : len(tgt)] = torch.tensor(tgt, dtype=torch.int)
            target_size[b] = len(tgt)

        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))

        if reduce:
            loss = torch.sum(loss)

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training.

        The reported loss is the summed loss divided by the total number of
        sentences; it is 0.0 when no sentences were logged.
        """
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_output = {
            # Guard against an empty list of logging outputs.
            "loss": loss_sum / nsentences if nsentences > 0 else 0.0,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return agg_output