# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F

from fairseq.models.nat import (
    _apply_del_words,
    _apply_ins_masks,
    _apply_ins_words,
    _fill,
    _skip,
    _skip_encoder_out,
)
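

# The three decode steps below each combine per-model scores by averaging in
# the log domain. A minimal self-contained sketch of that identity
# (illustrative only, not part of fairseq; the probabilities are made up):
def _log_prob_mean_sketch():
    """Show that logsumexp(stack(log_p), 0) - log(K) is the log of the
    arithmetic mean of K probability distributions."""
    log_p1 = torch.tensor([0.7, 0.3]).log()  # model 1, in log space
    log_p2 = torch.tensor([0.5, 0.5]).log()  # model 2, in log space
    avg = torch.logsumexp(torch.stack([log_p1, log_p2], dim=0), dim=0) - math.log(2)
    assert torch.allclose(avg.exp(), torch.tensor([0.6, 0.4]))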


class _EnsembleModelEncoder(object):
    def __init__(self, models):
        self.models = models

    def reorder_encoder_out(self, encoder_outs, new_order):
        encoder_outs = [
            model.encoder.reorder_encoder_out(encoder_out, new_order)
            for model, encoder_out in zip(self.models, encoder_outs)
        ]
        return encoder_outs


class BasicEnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        self.bos = self.models[0].decoder.dictionary.bos()
        self.eos = self.models[0].decoder.dictionary.eos()
        self.pad = self.models[0].decoder.dictionary.pad()
        self.unk = self.models[0].decoder.dictionary.unk()
        self.encoder = _EnsembleModelEncoder(self.models)

    def has_encoder(self):
        return hasattr(self.models[0], "encoder")

    def max_decoder_positions(self):
        return min(m.max_decoder_positions() for m in self.models)

    def forward_encoder(self, encoder_input):
        if not self.has_encoder():
            return None
        return [model.forward_encoder(encoder_input) for model in self.models]

    def forward_decoder(self, *inputs):
        raise NotImplementedError

    def initialize_output_tokens(self, *inputs):
        raise NotImplementedError


class EnsembleLevT(BasicEnsembleModel):
    """A wrapper around an ensemble of Levenshtein Transformer models."""

    def __init__(self, models):
        super().__init__(models)

    def forward_decoder(
        self, decoder_out, encoder_outs, eos_penalty=0.0, max_ratio=None, **kwargs
    ):
        # LevT ensembling.
        # A pipeline of three steps: deletion, placeholder insertion, and word
        # insertion. Scores must be averaged step by step rather than once at
        # the end, because each step depends on the previous step's output.
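        # Illustrative toy trace of one refinement step (the token strings
        # here are invented for exposition, not real model output):
        #   start:           <s> a a b </s>
        #   1. deletion:     <s> a b </s>        drop the duplicated "a"
        #   2. placeholders: <s> a <unk> b </s>  open one slot to fill
        #   3. word ins.:    <s> a c b </s>      fill each <unk> with a word
        # Every step decides using the average of all models' log-probs.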
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        attn = decoder_out.attn

        bsz = output_tokens.size(0)
        if max_ratio is None:
            # no ratio given: allow a generous fixed budget of 255 tokens
            max_lens = output_tokens.new(bsz).fill_(255)
        else:
            if not encoder_outs[0]["encoder_padding_mask"]:
                src_lens = (
                    encoder_outs[0]["encoder_out"][0]
                    .new(bsz)
                    .fill_(encoder_outs[0]["encoder_out"][0].size(1))
                )
            else:
                src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1)
            max_lens = (src_lens * max_ratio).clamp(min=10).long()
        # delete words
        # never delete from a sequence that is already just <s> </s>
        can_del_word = output_tokens.ne(self.pad).sum(1) > 2
        if can_del_word.sum() != 0:  # skip the step if no sentence qualifies
            output_tokens, output_scores, attn = self.forward_word_del(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_del_word,
            )

        # insert placeholders
        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
        if can_ins_mask.sum() != 0:
            output_tokens, output_scores = self.forward_mask_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                can_ins_mask,
                eos_penalty,
                max_lens,
            )

        # insert words
        can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
        if can_ins_word.sum() != 0:
            output_tokens, output_scores, attn = self.forward_word_ins(
                encoder_outs,
                output_tokens,
                output_scores,
                attn,
                can_ins_word,
            )

        # trim unnecessary trailing padding
        cut_off = output_tokens.ne(self.pad).sum(1).max()
        output_tokens = output_tokens[:, :cut_off]
        output_scores = output_scores[:, :cut_off]
        attn = None if attn is None else attn[:, :cut_off, :]
        return decoder_out._replace(
            output_tokens=output_tokens,
            output_scores=output_scores,
            attn=attn,
            history=None,
        )

    def forward_word_del(
        self, encoder_outs, output_tokens, output_scores, attn, can_del_word
    ):
        word_del_score_avg = []
        word_del_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            # _skip restricts the batch to sentences where deletion applies;
            # _fill below scatters the results back into the full batch.
            word_del_out, word_del_attn = model.decoder.forward_word_del(
                _skip(output_tokens, can_del_word),
                _skip_encoder_out(model.encoder, encoder_out, can_del_word),
            )
            word_del_score = F.log_softmax(word_del_out, 2)
            word_del_score_avg.append(word_del_score)
            word_del_attn_avg.append(word_del_attn)
        # average the per-model distributions in log space
        word_del_score_avg = torch.logsumexp(
            torch.stack(word_del_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        word_del_pred = word_del_score_avg.max(-1)[1].bool()
        if word_del_attn_avg[0] is not None:
            # average attentions across models
            word_del_attn_avg = torch.stack(word_del_attn_avg, dim=0).mean(dim=0)
        else:
            word_del_attn_avg = None

        _tokens, _scores, _attn = _apply_del_words(
            output_tokens[can_del_word],
            output_scores[can_del_word],
            word_del_attn_avg,
            word_del_pred,
            self.pad,
            self.bos,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_del_word, _scores, 0)
        attn = _fill(attn, can_del_word, _attn, 0.0)
        return output_tokens, output_scores, attn

    def forward_mask_ins(
        self,
        encoder_outs,
        output_tokens,
        output_scores,
        can_ins_mask,
        eos_penalty,
        max_lens,
    ):
        mask_ins_score_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            mask_ins_out, _ = model.decoder.forward_mask_ins(
                _skip(output_tokens, can_ins_mask),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_mask),
            )
            mask_ins_score = F.log_softmax(mask_ins_out, 2)
            if eos_penalty > 0.0:
                # penalize predicting zero insertions, biasing the ensemble
                # toward longer outputs
                mask_ins_score[:, :, 0] -= eos_penalty
            mask_ins_score_avg.append(mask_ins_score)
        # average the per-model distributions in log space
        mask_ins_score_avg = torch.logsumexp(
            torch.stack(mask_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        mask_ins_pred = mask_ins_score_avg.max(-1)[1]
        # cap the predicted insertion counts so sequences stay within max_lens
        mask_ins_pred = torch.min(
            mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
        )

        _tokens, _scores = _apply_ins_masks(
            output_tokens[can_ins_mask],
            output_scores[can_ins_mask],
            mask_ins_pred,
            self.pad,
            self.unk,
            self.eos,
        )
        output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
        return output_tokens, output_scores

    def forward_word_ins(
        self, encoder_outs, output_tokens, output_scores, attn, can_ins_word
    ):
        word_ins_score_avg = []
        word_ins_attn_avg = []
        for model, encoder_out in zip(self.models, encoder_outs):
            word_ins_out, word_ins_attn = model.decoder.forward_word_ins(
                _skip(output_tokens, can_ins_word),
                _skip_encoder_out(model.encoder, encoder_out, can_ins_word),
            )
            word_ins_score = F.log_softmax(word_ins_out, 2)
            word_ins_score_avg.append(word_ins_score)
            word_ins_attn_avg.append(word_ins_attn)
        # average the per-model distributions in log space
        word_ins_score_avg = torch.logsumexp(
            torch.stack(word_ins_score_avg, dim=0), dim=0
        ) - math.log(len(self.models))
        if word_ins_attn_avg[0] is not None:
            # average attentions across models
            word_ins_attn_avg = torch.stack(word_ins_attn_avg, dim=0).mean(dim=0)
        else:
            word_ins_attn_avg = None
        word_ins_score_max, word_ins_pred = word_ins_score_avg.max(-1)

        _tokens, _scores = _apply_ins_words(
            output_tokens[can_ins_word],
            output_scores[can_ins_word],
            word_ins_pred,
            word_ins_score_max,
            self.unk,
        )
        output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
        output_scores = _fill(output_scores, can_ins_word, _scores, 0)
        # scatter back the ensemble-averaged attention (not the last model's)
        attn = _fill(attn, can_ins_word, word_ins_attn_avg, 0.0)
        return output_tokens, output_scores, attn

    def initialize_output_tokens(self, encoder_outs, src_tokens):
        # LevT doesn't do length prediction, so any single model can initialize.
        return self.models[0].initialize_output_tokens(encoder_outs[0], src_tokens)
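

# Usage sketch (illustrative only; the checkpoint paths are placeholders, and
# ``encoder_input``, ``src_tokens``, and ``max_iter`` stand in for values that
# fairseq's iterative refinement generator prepares in practice):
#
#   from fairseq import checkpoint_utils
#
#   models, _ = checkpoint_utils.load_model_ensemble(
#       ["levt_checkpoint_a.pt", "levt_checkpoint_b.pt"]
#   )
#   ensemble = EnsembleLevT(models)
#   encoder_outs = ensemble.forward_encoder(encoder_input)
#   decoder_out = ensemble.initialize_output_tokens(encoder_outs, src_tokens)
#   for _ in range(max_iter):
#       decoder_out = ensemble.forward_decoder(decoder_out, encoder_outs)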