# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import operator

import torch
import torch.nn.functional as F
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import nn


class TiedLinear(nn.Module):
    """Bias-free linear layer that uses a weight tensor supplied by the caller.

    The ``weight`` object is stored as given (it is not copied here), so the
    layer shares it with whoever else holds the same tensor.

    Args:
        weight: 2D weight tensor used by ``F.linear``.
        transpose: if true, the transposed weight is used in ``forward``.
    """

    def __init__(self, weight, transpose):
        super().__init__()
        self.weight = weight
        self.transpose = transpose

    def forward(self, input):
        """Apply ``F.linear`` with the (optionally transposed) tied weight."""
        weight = self.weight.t() if self.transpose else self.weight
        return F.linear(input, weight)


class TiedHeadModule(nn.Module):
    """Head projection that scores words with a tied embedding matrix.

    The output concatenates, along the last dimension, ``num_words`` word
    scores (computed with the tied embedding) and ``num_classes`` scores
    computed by a separate linear layer.

    Args:
        weights: pair ``(tied_emb, _)``; only the first element is used. Its
            size is read as ``(num_words, emb_dim)``.
        input_dim: size of the last dimension of the input features.
        num_classes: number of extra output columns appended after the words.
        q_noise: forwarded to ``quant_noise`` for every projection.
        qn_block_size: forwarded to ``quant_noise`` for every projection.
    """

    def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size):
        super().__init__()
        tied_emb, _ = weights
        self.num_words, emb_dim = tied_emb.size()

        self.word_proj = quant_noise(
            TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size
        )
        if input_dim != emb_dim:
            # The tied embedding expects emb_dim features, so map the input
            # down/up to that size first.
            self.word_proj = nn.Sequential(
                quant_noise(
                    nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size
                ),
                self.word_proj,
            )

        self.class_proj = quant_noise(
            nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size
        )
        self.out_dim = self.num_words + num_classes

        # Only used as a dtype/device template for allocating the output in
        # forward(). Initialised explicitly so that the value stored in
        # checkpoints is deterministic rather than uninitialised memory.
        self.register_buffer("_float_tensor", torch.zeros(1, dtype=torch.float))

    def forward(self, input):
        """Return a ``(prod(input.shape[:-1]), out_dim)`` tensor of scores."""
        inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1)
        # Spell out the feature size instead of using -1 so that an input
        # with zero rows can still be reshaped (-1 would be ambiguous).
        flat = input.view(inp_sz, input.size(-1))

        # Allocated from the buffer so the output follows its dtype/device;
        # both column ranges are fully overwritten below.
        out = self._float_tensor.new(inp_sz, self.out_dim)
        out[:, : self.num_words] = self.word_proj(flat)
        out[:, self.num_words :] = self.class_proj(flat)
        return out


class AdaptiveSoftmax(nn.Module):
    """
    This is an implementation of the efficient softmax approximation for
    graphics processing units (GPU), described in the paper "Efficient softmax
    approximation for GPUs" (http://arxiv.org/abs/1609.04309).

    Args:
        vocab_size: total number of output words.
        input_dim: size of the last dimension of the input features.
        cutoff: sequence of band boundaries. ``cutoff[0]`` words are scored
            directly by the head; each following pair of boundaries defines
            one tail band. ``vocab_size`` is appended when it is larger than
            the last boundary.
        dropout: dropout probability handed to ``FairseqDropout``; the same
            probability is used inside every tail band.
        factor: the inner dimension of tail band ``i`` is
            ``input_dim // factor ** (i + 1)``.
        adaptive_inputs: optional object exposing ``weights_for_band(i)``,
            which must return an ``(embedding, projection)`` pair whose
            tensors are reused for the output layers.
        tie_proj: when ``adaptive_inputs`` provides a projection for a band,
            reuse it (transposed) instead of creating a new linear layer of
            the same shape.
        q_noise: forwarded to ``quant_noise``.
        qn_block_size: forwarded to ``quant_noise``.

    Raises:
        ValueError: if the last cutoff is larger than ``vocab_size``.
    """

    def __init__(
        self,
        vocab_size,
        input_dim,
        cutoff,
        dropout,
        factor=4.0,
        adaptive_inputs=None,
        tie_proj=False,
        q_noise=0,
        qn_block_size=8,
    ):
        super().__init__()

        # Work on a private list: this accepts tuples and guarantees that
        # self.cutoff never aliases (or mutates) the caller's sequence.
        cutoff = list(cutoff)
        if vocab_size > cutoff[-1]:
            cutoff.append(vocab_size)
        elif vocab_size != cutoff[-1]:
            # Raised explicitly (not asserted) so the check survives `-O`.
            raise ValueError("cannot specify cutoff larger than vocab size")

        # The head scores cutoff[0] words plus one entry per tail band.
        output_dim = cutoff[0] + len(cutoff) - 1

        self.vocab_size = vocab_size
        self.cutoff = cutoff
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.input_dim = input_dim
        self.factor = factor
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.lsm = nn.LogSoftmax(dim=1)

        if adaptive_inputs is not None:
            self.head = TiedHeadModule(
                adaptive_inputs.weights_for_band(0),
                input_dim,
                len(cutoff) - 1,
                self.q_noise,
                self.qn_block_size,
            )
        else:
            self.head = quant_noise(
                nn.Linear(input_dim, output_dim, bias=False),
                self.q_noise,
                self.qn_block_size,
            )

        self._make_tail(adaptive_inputs, tie_proj)

        def init_weights(m):
            # Tied modules hold weights owned elsewhere; leave those untouched.
            if hasattr(m, "weight") and not isinstance(
                m, (TiedLinear, TiedHeadModule)
            ):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        # Presence of this buffer is what upgrade_state_dict_named checks for.
        self.register_buffer("version", torch.LongTensor([1]))

    def _make_tail(self, adaptive_inputs=None, tie_proj=False):
        """Build ``self.tail``: one projection/dropout/output stack per band."""
        self.tail = nn.ModuleList()
        for i in range(len(self.cutoff) - 1):
            dim = int(self.input_dim // self.factor ** (i + 1))

            tied_emb, tied_proj = (
                adaptive_inputs.weights_for_band(i + 1)
                if adaptive_inputs is not None
                else (None, None)
            )

            if tied_proj is not None:
                if tie_proj:
                    # Reuse the provided projection, transposed.
                    proj = quant_noise(
                        TiedLinear(tied_proj, transpose=True),
                        self.q_noise,
                        self.qn_block_size,
                    )
                else:
                    # Fresh weights, sized from the provided projection.
                    proj = quant_noise(
                        nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False),
                        self.q_noise,
                        self.qn_block_size,
                    )
            else:
                proj = quant_noise(
                    nn.Linear(self.input_dim, dim, bias=False),
                    self.q_noise,
                    self.qn_block_size,
                )

            if tied_emb is None:
                out_proj = nn.Linear(
                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False
                )
            else:
                out_proj = TiedLinear(tied_emb, transpose=False)

            m = nn.Sequential(
                proj,
                nn.Dropout(self.dropout_module.p),
                quant_noise(out_proj, self.q_noise, self.qn_block_size),
            )

            self.tail.append(m)

    def upgrade_state_dict_named(self, state_dict, name):
        """Reject state dicts that do not contain ``<name>.version``.

        Raises:
            RuntimeError: if the version entry is missing from ``state_dict``.
        """
        version_name = name + ".version"
        if version_name not in state_dict:
            raise RuntimeError("This version of the model is no longer supported")

    def adapt_target(self, target):
        """
        In order to be efficient, the AdaptiveSoftMax does not compute the
        scores for all the words of the vocabulary for all the examples. It is
        thus necessary to call the method adapt_target of the AdaptiveSoftMax
        layer inside each forward pass.

        Returns:
            A pair ``(new_target, target_idxs)``.

            ``new_target[0]`` is a flattened copy of ``target`` in which every
            index belonging to tail band ``i`` is replaced by
            ``cutoff[0] + i``. ``new_target[i + 1]`` holds the targets of band
            ``i`` shifted to start at zero, or ``None`` if no target falls in
            that band.

            ``target_idxs[i]`` holds the positions (in the flattened target)
            of the targets in band ``i``, or ``None`` if there are none.
        """
        target = target.view(-1)
        new_target = [target.clone()]
        target_idxs = []

        bands = zip(self.cutoff[:-1], self.cutoff[1:])
        for i, (low, high) in enumerate(bands):
            mask = (target >= low) & (target < high)
            # In the head, the whole band is represented by a single entry.
            new_target[0][mask] = self.cutoff[0] + i

            if mask.any():
                target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1))
                new_target.append(target[mask].add(-low))
            else:
                target_idxs.append(None)
                new_target.append(None)

        return new_target, target_idxs

    def forward(self, input, target):
        """
        Args:
            input: (b x t x d)
            target: (b x t)
        Returns:
            2 lists: output for each cutoff section and new targets by cut off
        """
        input = input.contiguous().view(-1, input.size(-1))
        input = self.dropout_module(input)

        new_target, target_idxs = self.adapt_target(target)
        output = [self.head(input)]

        # A tail band is only evaluated on the rows whose target lies in it.
        for tail, idxs in zip(self.tail, target_idxs):
            if idxs is None:
                output.append(None)
            else:
                output.append(tail(input.index_select(0, idxs)))

        return output, new_target

    def get_log_prob(self, input, target):
        """
        Computes the log probabilities for all the words of the vocabulary,
        given a 2D tensor of hidden vectors.

        Args:
            input: 3D tensor unpacked as (bsz x length x dim).
            target: optional targets. When ``None``, every tail band is
                evaluated for every row. Otherwise a tail band is only
                evaluated for the rows whose target lies in that band; the
                remaining entries of that band's columns are not filled in
                and are not valid log-probabilities.

        Returns:
            Tensor of shape (bsz x length x vocab_size).
        """
        bsz, length, dim = input.size()
        input = input.contiguous().view(-1, dim)

        if target is not None:
            _, target_idxs = self.adapt_target(target)
        else:
            target_idxs = None

        head_y = self.head(input)
        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)

        head_sz = self.cutoff[0] + len(self.tail)
        log_probs[:, :head_sz] = self.lsm(head_y)
        # The columns holding the per-band log-priors are overwritten by the
        # first tail band below, so keep a copy of them.
        tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone()

        for i, tail in enumerate(self.tail):
            start = self.cutoff[i]
            end = self.cutoff[i + 1]

            if target_idxs is None:
                # Basic slicing gives a view, so copy_ writes into log_probs.
                tail_out = log_probs[:, start:end]
                tail_out.copy_(tail(input))
                log_probs[:, start:end] = self.lsm(tail_out).add_(
                    tail_priors[:, i, None]
                )
            elif target_idxs[i] is not None:
                idxs = target_idxs[i]
                # Indexing with a tensor gives a copy; the result is written
                # back to log_probs explicitly.
                tail_out = log_probs[idxs, start:end]
                tail_out.copy_(tail(input[idxs]))
                log_probs[idxs, start:end] = self.lsm(tail_out).add_(
                    tail_priors[idxs, i, None]
                )

        log_probs = log_probs.view(bsz, length, -1)
        return log_probs