# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import operator

import torch
import torch.nn.functional as F
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import nn


class TiedLinear(nn.Module):
    """Linear layer whose weight is tied to an externally owned tensor."""

    def __init__(self, weight, transpose):
        super().__init__()
        self.weight = weight
        self.transpose = transpose

    def forward(self, input):
        return F.linear(input, self.weight.t() if self.transpose else self.weight)


class TiedHeadModule(nn.Module):
    """Head projection that reuses tied embedding weights (re-projected if the
    model dimension differs) for the most frequent words, plus one output per
    tail cluster."""

    def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size):
        super().__init__()
        tied_emb, _ = weights
        self.num_words, emb_dim = tied_emb.size()

        self.word_proj = quant_noise(
            TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size
        )
        if input_dim != emb_dim:
            self.word_proj = nn.Sequential(
                quant_noise(
                    nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size
                ),
                self.word_proj,
            )

        self.class_proj = quant_noise(
            nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size
        )
        self.out_dim = self.num_words + num_classes

        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, input):
        inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1)
        out = self._float_tensor.new(inp_sz, self.out_dim)
        out[:, : self.num_words] = self.word_proj(input.view(inp_sz, -1))
        out[:, self.num_words :] = self.class_proj(input.view(inp_sz, -1))
        return out


class AdaptiveSoftmax(nn.Module):
    """
    This is an implementation of the efficient softmax approximation for
    graphics processing units (GPUs), described in the paper "Efficient
    softmax approximation for GPUs" (http://arxiv.org/abs/1609.04309).
    """

    def __init__(
        self,
        vocab_size,
        input_dim,
        cutoff,
        dropout,
        factor=4.0,
        adaptive_inputs=None,
        tie_proj=False,
        q_noise=0,
        qn_block_size=8,
    ):
        super().__init__()

        if vocab_size > cutoff[-1]:
            cutoff = cutoff + [vocab_size]
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        # Head outputs: the cutoff[0] most frequent words plus one logit per
        # tail cluster.
        output_dim = cutoff[0] + len(cutoff) - 1

        self.vocab_size = vocab_size
        self.cutoff = cutoff
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.input_dim = input_dim
        self.factor = factor
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.lsm = nn.LogSoftmax(dim=1)

        if adaptive_inputs is not None:
            self.head = TiedHeadModule(
                adaptive_inputs.weights_for_band(0),
                input_dim,
                len(cutoff) - 1,
                self.q_noise,
                self.qn_block_size,
            )
        else:
            self.head = quant_noise(
                nn.Linear(input_dim, output_dim, bias=False),
                self.q_noise,
                self.qn_block_size,
            )

        self._make_tail(adaptive_inputs, tie_proj)

        def init_weights(m):
            if (
                hasattr(m, "weight")
                and not isinstance(m, TiedLinear)
                and not isinstance(m, TiedHeadModule)
            ):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        self.register_buffer("version", torch.LongTensor([1]))

    def _make_tail(self, adaptive_inputs=None, tie_proj=False):
        self.tail = nn.ModuleList()
        for i in range(len(self.cutoff) - 1):
            # Each successive tail cluster gets a smaller projection dimension
            # (input_dim shrunk by factor ** (i + 1)): rarer words get less
            # capacity.
            dim = int(self.input_dim // self.factor ** (i + 1))

            tied_emb, tied_proj = (
                adaptive_inputs.weights_for_band(i + 1)
                if adaptive_inputs is not None
                else (None, None)
            )

            if tied_proj is not None:
                if tie_proj:
                    proj = quant_noise(
                        TiedLinear(tied_proj, transpose=True),
                        self.q_noise,
                        self.qn_block_size,
                    )
                else:
                    proj = quant_noise(
                        nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False),
                        self.q_noise,
                        self.qn_block_size,
                    )
            else:
                proj = quant_noise(
                    nn.Linear(self.input_dim, dim, bias=False),
                    self.q_noise,
                    self.qn_block_size,
                )

            if tied_emb is None:
                out_proj = nn.Linear(
                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False
                )
            else:
                out_proj = TiedLinear(tied_emb, transpose=False)

            m = nn.Sequential(
                proj,
                nn.Dropout(self.dropout_module.p),
                quant_noise(out_proj, self.q_noise, self.qn_block_size),
            )

            self.tail.append(m)

    def upgrade_state_dict_named(self, state_dict, name):
        version_name = name + ".version"
        if version_name not in state_dict:
            raise Exception("This version of the model is no longer supported")

    def adapt_target(self, target):
        """
        In order to be efficient, the AdaptiveSoftmax does not compute the
        scores for all the words of the vocabulary for all the examples. It is
        thus necessary to call the method adapt_target of the AdaptiveSoftmax
        layer inside each forward pass.
        """
        target = target.view(-1)
        new_target = [target.clone()]
        target_idxs = []

        for i in range(len(self.cutoff) - 1):
            mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
            new_target[0][mask] = self.cutoff[0] + i

            if mask.any():
                target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1))
                new_target.append(target[mask].add(-self.cutoff[i]))
            else:
                target_idxs.append(None)
                new_target.append(None)

        return new_target, target_idxs

    def forward(self, input, target):
        """
        Args:
            input: (b x t x d)
            target: (b x t)
        Returns:
            2 lists: output for each cutoff section and new targets by cutoff
        """
        input = input.contiguous().view(-1, input.size(-1))
        input = self.dropout_module(input)

        new_target, target_idxs = self.adapt_target(target)
        output = [self.head(input)]

        for i in range(len(target_idxs)):
            if target_idxs[i] is not None:
                output.append(self.tail[i](input.index_select(0, target_idxs[i])))
            else:
                output.append(None)

        return output, new_target

    def get_log_prob(self, input, target):
        """
        Computes the log probabilities for all the words of the vocabulary,
        given a 3D tensor of hidden vectors (bsz x length x dim).
        """
        bsz, length, dim = input.size()
        input = input.contiguous().view(-1, dim)

        if target is not None:
            _, target_idxs = self.adapt_target(target)
        else:
            target_idxs = None

        head_y = self.head(input)
        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)

        head_sz = self.cutoff[0] + len(self.tail)
        log_probs[:, :head_sz] = self.lsm(head_y)
        # Clone the cluster priors before the first tail section overwrites
        # the slots that hold them (the slice [cutoff[0]:head_sz] overlaps
        # the first tail's output range).
        tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone()

        for i in range(len(self.tail)):
            start = self.cutoff[i]
            end = self.cutoff[i + 1]

            if target_idxs is None:
                tail_out = log_probs[:, start:end]
                tail_out.copy_(self.tail[i](input))
                log_probs[:, start:end] = self.lsm(tail_out).add_(
                    tail_priors[:, i, None]
                )
            elif target_idxs[i] is not None:
                idxs = target_idxs[i]
                tail_out = log_probs[idxs, start:end]
                tail_out.copy_(self.tail[i](input[idxs]))
                log_probs[idxs, start:end] = self.lsm(tail_out).add_(
                    tail_priors[idxs, i, None]
                )

        log_probs = log_probs.view(bsz, length, -1)
        return log_probs
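

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the upstream
    # fairseq module. The vocabulary size, cutoffs, and tensor shapes below
    # are assumptions chosen only to exercise forward() and get_log_prob().
    # Requires fairseq to be importable (for FairseqDropout and quant_noise
    # imported above).
    torch.manual_seed(0)

    model = AdaptiveSoftmax(
        vocab_size=1000, input_dim=64, cutoff=[10, 100], dropout=0.1
    )
    model.eval()  # disable dropout so the normalization check is deterministic

    hidden = torch.randn(2, 5, 64)  # (b x t x d)
    target = torch.randint(0, 1000, (2, 5))  # (b x t)

    # Training-style forward: per-cluster outputs paired with remapped targets.
    output, new_target = model(hidden, target)
    print([o.shape if o is not None else None for o in output])

    # Full-vocabulary log-probabilities; each row should normalize, i.e.
    # logsumexp over the vocabulary should be ~0.
    log_probs = model.get_log_prob(hidden, target=None)
    print(log_probs.shape)  # torch.Size([2, 5, 1000])
    print(log_probs.view(-1, 1000).logsumexp(dim=-1)[:3])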