| | |
| | |
| | |
| | |
| |
|
| |
|
| | from typing import List |
| |
|
| | import torch |
| | from fairseq.modules.quant_noise import quant_noise |
| | from torch import nn |
| |
|
| |
|
class AdaptiveInput(nn.Module):
    """Embed token ids with per-band embedding tables of decreasing width.

    The vocabulary is split into contiguous bands by ``cutoff``. Band ``i``
    covers ids ``[cutoff[i - 1], cutoff[i])`` (the first band starts at 0).
    Each band owns:

    * an ``nn.Embedding`` of width ``int(initial_dim // factor ** i)``, and
    * a bias-free ``nn.Linear`` (wrapped with ``quant_noise``) that projects
      the band to the shared ``output_dim``.

    Args:
        vocab_size: total number of token ids.
        padding_idx: id of the padding token. Only the first band's
            embedding is created with this ``padding_idx``.
        initial_dim: embedding width of the first band.
        factor: divisor applied to the width for each successive band.
        output_dim: width of every output vector (exposed as
            ``self.embedding_dim``).
        cutoff: band upper bounds. If its last entry is smaller than
            ``vocab_size``, ``vocab_size`` is appended as a final band. The
            caller's list is never modified.
        q_noise: forwarded to ``quant_noise`` for each projection.
        qn_block_size: forwarded to ``quant_noise`` for each projection.

    Raises:
        AssertionError: if ``cutoff[-1]`` is larger than ``vocab_size``.
    """

    def __init__(
        self,
        vocab_size: int,
        padding_idx: int,
        initial_dim: int,
        factor: float,
        output_dim: int,
        cutoff: List[int],
        q_noise: float = 0,
        qn_block_size: int = 8,
    ):
        super().__init__()

        if vocab_size > cutoff[-1]:
            # Build a new list so the caller's ``cutoff`` is left untouched.
            cutoff = cutoff + [vocab_size]
        else:
            assert (
                vocab_size == cutoff[-1]
            ), "cannot specify cutoff larger than vocab size"

        self.cutoff = cutoff
        self.embedding_dim = output_dim
        self.padding_idx = padding_idx

        self.embeddings = nn.ModuleList()
        prev = 0
        for i, upper in enumerate(self.cutoff):
            size = upper - prev
            dim = int(initial_dim // (factor ** i))
            seq = nn.Sequential(
                # Only the first band gets a padding row. In later bands,
                # local index ``padding_idx`` is an ordinary token.
                nn.Embedding(size, dim, padding_idx if i == 0 else None),
                quant_noise(
                    nn.Linear(dim, output_dim, bias=False), q_noise, qn_block_size
                ),
            )
            self.embeddings.append(seq)
            prev = upper

        def init_weights(m):
            if isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
                # Use the embedding's own padding row (None outside band 0),
                # so no real token's vector in a later band is zeroed.
                if m.padding_idx is not None:
                    nn.init.constant_(m.weight[m.padding_idx], 0)
            elif hasattr(m, "weight"):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        # Registered as a buffer so it follows the module's dtype/device.
        # ``forward`` allocates its output from it.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def weights_for_band(self, band: int):
        """Return ``(embedding_weight, projection_weight)`` for band ``band``."""
        return self.embeddings[band][0].weight, self.embeddings[band][1].weight

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Embed ``input`` ids into a tensor of shape ``input.shape + (embedding_dim,)``.

        Each id is looked up in the band that contains it, after being
        shifted to that band's local index. Ids ``>= cutoff[-1]`` fall in no
        band and are left as zero vectors.
        """
        # Zero-fill instead of leaving memory uninitialized, so positions
        # that no band covers are deterministic. ``new_zeros`` inherits the
        # buffer's dtype and device.
        result = self._float_tensor.new_zeros(input.shape + (self.embedding_dim,))
        lower = 0
        for i, upper in enumerate(self.cutoff):
            mask = input.lt(upper)
            if i > 0:
                mask &= input.ge(lower)
            if mask.any():
                result[mask] = self.embeddings[i](input[mask] - lower)
            lower = upper
        return result
| |
|