JustinLin610
update
8437114
raw history blame
No virus
8.79 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import operator
import torch
import torch.nn.functional as F
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import nn
class TiedLinear(nn.Module):
    """Bias-free linear map that applies a weight tensor supplied by the caller.

    The tensor passed to the constructor is stored as-is (no copy is made), so
    the layer keeps using whatever object the caller handed in.
    """

    def __init__(self, weight, transpose):
        """
        Args:
            weight: the weight tensor to apply
            transpose: when truthy, ``weight.t()`` is applied instead of
                ``weight``
        """
        super().__init__()
        self.transpose = transpose
        self.weight = weight

    def forward(self, input):
        """Return ``F.linear(input, W)`` with ``W`` the (optionally transposed) weight."""
        if self.transpose:
            kernel = self.weight.t()
        else:
            kernel = self.weight
        return F.linear(input, kernel)
class TiedHeadModule(nn.Module):
    """Adaptive-softmax head whose word scores reuse a supplied embedding matrix.

    The output concatenates, along the last dimension, ``num_words`` word
    scores (from the tied embedding) and ``num_classes`` scores from a
    separate, untied linear projection.
    """

    def __init__(self, weights, input_dim, num_classes, q_noise, qn_block_size):
        """
        Args:
            weights: a pair whose first element is the embedding matrix to tie
                to; the second element is ignored here
            input_dim: size of the last dimension of the inputs to ``forward``
            num_classes: number of extra (untied) outputs appended after the
                word scores
            q_noise: forwarded to ``quant_noise`` for every projection
            qn_block_size: forwarded to ``quant_noise`` for every projection
        """
        super().__init__()
        tied_emb, _ = weights
        self.num_words, emb_dim = tied_emb.size()

        self.word_proj = quant_noise(
            TiedLinear(tied_emb, transpose=False), q_noise, qn_block_size
        )
        if input_dim != emb_dim:
            # The embedding width differs from the hidden width: map the
            # hidden vector to the embedding width before the tied projection.
            self.word_proj = nn.Sequential(
                quant_noise(
                    nn.Linear(input_dim, emb_dim, bias=False), q_noise, qn_block_size
                ),
                self.word_proj,
            )

        self.class_proj = quant_noise(
            nn.Linear(input_dim, num_classes, bias=False), q_noise, qn_block_size
        )
        self.out_dim = self.num_words + num_classes

        # Only used in ``forward`` as a dtype/device template for the output.
        # Zero-initialised so the saved buffer does not hold arbitrary memory.
        self.register_buffer("_float_tensor", torch.zeros(1, dtype=torch.float32))

    def forward(self, input):
        """Flatten all leading dimensions of ``input`` and score every row.

        Returns:
            a ``(prod(input.shape[:-1]), out_dim)`` tensor with the dtype and
            device of the ``_float_tensor`` buffer; the first ``num_words``
            columns come from ``word_proj`` and the rest from ``class_proj``.
        """
        inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1)
        if input.numel() == 0:
            # "-1" cannot be inferred for an empty tensor, so spell the width out.
            flat = input.reshape(inp_sz, input.size(-1))
        else:
            # reshape (rather than view) also accepts non-contiguous inputs.
            flat = input.reshape(inp_sz, -1)

        out = self._float_tensor.new_empty((inp_sz, self.out_dim))
        out[:, : self.num_words] = self.word_proj(flat)
        out[:, self.num_words :] = self.class_proj(flat)
        return out
class AdaptiveSoftmax(nn.Module):
    """
    This is an implementation of the efficient softmax approximation for
    graphical processing units (GPU), described in the paper "Efficient softmax
    approximation for GPUs" (http://arxiv.org/abs/1609.04309).

    Class ids below ``cutoff[0]`` are scored directly by the head. Every band
    ``[cutoff[i], cutoff[i + 1])`` is represented in the head by one extra
    output (index ``cutoff[0] + i``) and is scored by its own tail module.
    """

    def __init__(
        self,
        vocab_size,
        input_dim,
        cutoff,
        dropout,
        factor=4.0,
        adaptive_inputs=None,
        tie_proj=False,
        q_noise=0,
        qn_block_size=8,
    ):
        """
        Args:
            vocab_size: total number of output classes
            input_dim: size of the incoming hidden vectors
            cutoff: band boundaries (list or tuple), consumed pairwise as
                ``[cutoff[i], cutoff[i + 1])``; ``vocab_size`` is appended
                when it is larger than the last entry
            dropout: probability used for the input dropout and for the
                dropout inside every tail
            factor: tail ``i`` uses an inner width of
                ``input_dim // factor ** (i + 1)`` when no projection is tied
            adaptive_inputs: optional object providing
                ``weights_for_band(i)``, which returns an
                ``(embedding, projection)`` pair to tie the output weights to
            tie_proj: when a band provides a projection, reuse it (transposed)
                instead of creating a fresh layer of the same shape
            q_noise: forwarded to ``quant_noise``
            qn_block_size: forwarded to ``quant_noise``
        """
        super().__init__()

        # Private copy: accepts tuples and never aliases the caller's sequence.
        cutoff = list(cutoff)
        if vocab_size > cutoff[-1]:
            cutoff.append(vocab_size)
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        # Head outputs: the first band's classes plus one slot per tail band.
        output_dim = cutoff[0] + len(cutoff) - 1

        self.vocab_size = vocab_size
        self.cutoff = cutoff
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )
        self.input_dim = input_dim
        self.factor = factor
        self.q_noise = q_noise
        self.qn_block_size = qn_block_size

        self.lsm = nn.LogSoftmax(dim=1)

        if adaptive_inputs is not None:
            self.head = TiedHeadModule(
                adaptive_inputs.weights_for_band(0),
                input_dim,
                len(cutoff) - 1,
                self.q_noise,
                self.qn_block_size,
            )
        else:
            self.head = quant_noise(
                nn.Linear(input_dim, output_dim, bias=False),
                self.q_noise,
                self.qn_block_size,
            )

        self._make_tail(adaptive_inputs, tie_proj)

        def init_weights(m):
            # Tied modules expose weights owned elsewhere; leave those alone.
            if (
                hasattr(m, "weight")
                and not isinstance(m, TiedLinear)
                and not isinstance(m, TiedHeadModule)
            ):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        self.register_buffer("version", torch.LongTensor([1]))

    def _make_tail(self, adaptive_inputs=None, tie_proj=False):
        """Create ``self.tail``: one ``proj -> dropout -> out_proj`` stack per band."""
        self.tail = nn.ModuleList()
        for i in range(len(self.cutoff) - 1):
            dim = int(self.input_dim // self.factor ** (i + 1))

            if adaptive_inputs is not None:
                tied_emb, tied_proj = adaptive_inputs.weights_for_band(i + 1)
            else:
                tied_emb, tied_proj = None, None

            if tied_proj is not None:
                if tie_proj:
                    proj = quant_noise(
                        TiedLinear(tied_proj, transpose=True),
                        self.q_noise,
                        self.qn_block_size,
                    )
                else:
                    # Same shape as the available projection, but untied.
                    proj = quant_noise(
                        nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False),
                        self.q_noise,
                        self.qn_block_size,
                    )
            else:
                proj = quant_noise(
                    nn.Linear(self.input_dim, dim, bias=False),
                    self.q_noise,
                    self.qn_block_size,
                )

            if tied_emb is None:
                out_proj = nn.Linear(
                    dim, self.cutoff[i + 1] - self.cutoff[i], bias=False
                )
            else:
                out_proj = TiedLinear(tied_emb, transpose=False)

            m = nn.Sequential(
                proj,
                nn.Dropout(self.dropout_module.p),
                quant_noise(out_proj, self.q_noise, self.qn_block_size),
            )

            self.tail.append(m)

    def upgrade_state_dict_named(self, state_dict, name):
        """Reject state dicts that lack the ``<name>.version`` entry.

        Raises:
            Exception: if ``name + ".version"`` is not a key of ``state_dict``.
        """
        version_name = name + ".version"
        if version_name not in state_dict:
            raise Exception("This version of the model is no longer supported")

    def adapt_target(self, target):
        """
        In order to be efficient, the AdaptiveSoftMax does not compute the
        scores for all the word of the vocabulary for all the examples. It is
        thus necessary to call the method adapt_target of the AdaptiveSoftMax
        layer inside each forward pass.

        Returns:
            new_target: list whose first entry is the flattened target with
                every id that falls in tail band ``i`` replaced by
                ``cutoff[0] + i``; entry ``i + 1`` holds the ids of band ``i``
                shifted by ``-cutoff[i]``, or None when no target falls in it
            target_idxs: per band, the flat positions of the targets in that
                band, or None when there are none
        """
        # reshape (rather than view) so non-contiguous targets are accepted.
        target = target.reshape(-1)
        new_target = [target.clone()]
        target_idxs = []

        for i, (low, high) in enumerate(zip(self.cutoff, self.cutoff[1:])):
            mask = target.ge(low) & target.lt(high)
            # Masks come from the untouched ``target``, so relabelling the
            # head copy cannot affect later bands.
            new_target[0][mask] = self.cutoff[0] + i

            if mask.any():
                target_idxs.append(mask.nonzero(as_tuple=False).squeeze(1))
                new_target.append(target[mask].add(-low))
            else:
                target_idxs.append(None)
                new_target.append(None)

        return new_target, target_idxs

    def forward(self, input, target):
        """
        Args:
            input: (b x t x d)
            target: (b x t)
        Returns:
            2 lists: output for each cutoff section and new targets by cut off
        """
        input = input.contiguous().view(-1, input.size(-1))
        input = self.dropout_module(input)

        new_target, target_idxs = self.adapt_target(target)
        output = [self.head(input)]

        # Each tail only scores the rows whose target falls in its band.
        for tail, idxs in zip(self.tail, target_idxs):
            if idxs is not None:
                output.append(tail(input.index_select(0, idxs)))
            else:
                output.append(None)

        return output, new_target

    def get_log_prob(self, input, target):
        """
        Computes the log probabilities for all the words of the vocabulary,
        given a 3D tensor of hidden vectors (bsz x length x dim).

        When ``target`` is given, each tail band is only evaluated on the
        rows whose target falls in that band; the remaining entries of that
        band are not filled in with tail scores.

        Returns:
            tensor of shape (bsz x length x vocab_size)
        """
        bsz, length, dim = input.size()
        input = input.contiguous().view(-1, dim)

        if target is not None:
            _, target_idxs = self.adapt_target(target)
        else:
            target_idxs = None

        head_y = self.head(input)
        log_probs = head_y.new_zeros(input.size(0), self.vocab_size)

        head_sz = self.cutoff[0] + len(self.tail)
        log_probs[:, :head_sz] = self.lsm(head_y)
        # Keep the per-band log priors: their columns are overwritten below.
        tail_priors = log_probs[:, self.cutoff[0] : head_sz].clone()

        for i, tail in enumerate(self.tail):
            start = self.cutoff[i]
            end = self.cutoff[i + 1]

            if target_idxs is None:
                tail_out = log_probs[:, start:end]
                tail_out.copy_(tail(input))
                log_probs[:, start:end] = self.lsm(tail_out).add_(
                    tail_priors[:, i, None]
                )
            elif target_idxs[i] is not None:
                idxs = target_idxs[i]
                tail_out = log_probs[idxs, start:end]
                tail_out.copy_(tail(input[idxs]))
                log_probs[idxs, start:end] = self.lsm(tail_out).add_(
                    tail_priors[idxs, i, None]
                )

        log_probs = log_probs.view(bsz, length, -1)
        return log_probs