DeCRED-base / embeddings.py
Lakoc's picture
Upload JointCTCAttentionEncoderDecoder
9b4bf4d verified
raw
history blame
3.01 kB
import torch
from torch import nn
class AdaptiveEmbedding(nn.Module):
    """Token embedding whose width may shrink for higher vocabulary bands.

    The vocabulary ``[0, n_token)`` is split into consecutive bands at the
    boundaries given by ``cutoffs``.  With ``div_val == 1`` a single
    ``n_token x d_embed`` table is used (plus one projection to ``d_proj``
    when ``d_proj != d_embed``).  Otherwise band ``i`` gets its own table of
    width ``d_embed // div_val**i`` and its own projection to ``d_proj``.

    Args:
        n_token: Vocabulary size.
        d_embed: Embedding width of the first band (or of the single table).
        d_proj: Width of the returned embeddings.
        cutoffs: Sequence of band boundaries (any sequence type; it is copied
            into a list and ``n_token`` is appended as the final boundary).
        div_val: Divisor applied per band to shrink the embedding width.
        sample_softmax: When ``> 0`` the single table (``div_val == 1``) uses
            sparse gradients.

    NOTE: the projection parameters in ``emb_projs`` are allocated but NOT
    initialised here; they must be initialised or loaded from a checkpoint
    before the module is used.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        super().__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        # list(...) accepts tuples and other sequences and never aliases the
        # caller's object.
        self.cutoffs = list(cutoffs) + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        # Outputs are scaled by sqrt(d_proj) in forward().
        self.emb_scale = d_proj**0.5

        # cutoff_ends[i] .. cutoff_ends[i + 1] is the id range of band i.
        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0))
            if d_proj != d_embed:
                # torch.empty follows the default dtype/device exactly like
                # nn.Embedding does, so table and projection always agree.
                self.emb_projs.append(nn.Parameter(torch.empty(d_proj, d_embed)))
        else:
            bands = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bands):
                d_emb_i = d_embed // (div_val**i)
                self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.empty(d_proj, d_emb_i)))

    def forward(self, inp):
        """Embed integer token ids.

        Args:
            inp: Integer tensor of token ids, any shape.

        Returns:
            Tensor of shape ``inp.shape + (d_proj,)`` scaled by
            ``sqrt(d_proj)``.  In the multi-band path (``div_val != 1``), ids
            that fall in no band keep an all-zero vector.
        """
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = nn.functional.linear(embed, self.emb_projs[0])
        else:
            # Borrow dtype/device from the first registered parameter.
            param = next(self.parameters())
            # reshape (not view) so non-contiguous id tensors are accepted.
            inp_flat = inp.reshape(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device)
            bands = zip(self.cutoff_ends[:-1], self.cutoff_ends[1:])
            for i, (l_idx, r_idx) in enumerate(bands):
                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                # Squeeze only the trailing axis so that a single hit still
                # yields a 1-D index tensor rather than a 0-dim scalar.
                indices_i = mask_i.nonzero().squeeze(-1)

                if indices_i.numel() == 0:
                    continue

                # Shift ids so they index into this band's own table.
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = nn.functional.linear(emb_i, self.emb_projs[i])

                # Scatter the projected rows back to their flat positions.
                emb_flat.index_copy_(0, indices_i, emb_i)

            embed_shape = inp.size() + (self.d_proj,)
            embed = emb_flat.view(embed_shape)

        embed.mul_(self.emb_scale)

        return embed
class PositionalEmbeddingAux(nn.Module):
    """Sinusoidal position encoding computed on the fly from position values.

    A buffer ``inv_freq`` holds ``1 / 10000 ** (k / demb)`` for
    ``k = 0, 2, 4, ...`` below ``demb``.
    """

    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        exponents = torch.arange(0.0, demb, 2.0) / demb
        self.register_buffer("inv_freq", 1 / (10000**exponents))

    def forward(self, pos_seq, bsz=None):
        """Encode the positions in ``pos_seq``.

        Args:
            pos_seq: Position values; fed to ``torch.outer`` together with
                ``inv_freq``, so a 1-D tensor is expected.
            bsz: Optional size to which the inserted middle axis is expanded.

        Returns:
            Tensor of shape ``(len(pos_seq), 1, 2 * len(inv_freq))`` with all
            sines followed by all cosines on the last axis; the middle axis is
            expanded to ``bsz`` when ``bsz`` is given.
        """
        angles = torch.outer(pos_seq, self.inv_freq)
        table = torch.cat([angles.sin(), angles.cos()], dim=-1).unsqueeze(1)

        if bsz is None:
            return table
        return table.expand(-1, bsz, -1)
class PositionalEmbedding(PositionalEmbeddingAux):
    """Variant of ``PositionalEmbeddingAux`` that squeezes its input and output."""

    def forward(self, pos_seq, bsz=None):
        """Encode positions, dropping size-1 axes at both ends of the call.

        Axis 0 of ``pos_seq`` is squeezed before it is handed to the parent
        ``forward``; axis 1 of the parent's result is squeezed afterwards.
        Both squeezes only take effect when that axis has size 1.
        """
        squeezed_positions = pos_seq.squeeze(0)
        encoded = super().forward(squeezed_positions, bsz=bsz)
        return encoded.squeeze(1)