import torch


class Conv1d(torch.nn.Conv1d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        # w_init_gain must be stored before super().__init__() because the
        # parent constructor calls reset_parameters().
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain is None:
            pass
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU must be an even number.'
            # First half of the output channels is the linear path, second half is the sigmoid gate.
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of gate must be an even number.'
            # First half of the output channels is the tanh path, second half is the sigmoid gate.
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)


class ConvTranspose1d(torch.nn.ConvTranspose1d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU must be an even number.'
            # The transposed convolution weight is [in_channels, out_channels // groups, kernel_size],
            # so the output-channel split is taken along dim 1.
            torch.nn.init.kaiming_uniform_(self.weight[:, :self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[:, self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of gate must be an even number.'
            torch.nn.init.xavier_uniform_(self.weight[:, :self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[:, self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)


class Conv2d(torch.nn.Conv2d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU must be an even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of gate must be an even number.'
            torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)
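# A minimal usage sketch of the w_init_gain-aware convolutions above. The
# channel counts and tensor sizes are arbitrary; the point is that the gain
# keyword comes first and the usual torch.nn.Conv1d arguments follow as
# keywords, and that 'glu'/'gate' need an even number of output channels.
def _example_gain_aware_conv():
    glu_conv = Conv1d(
        w_init_gain= 'glu',
        in_channels= 16,
        out_channels= 32,   # even, as required by the 'glu' initializer
        kernel_size= 3,
        padding= 1
        )
    relu_conv = Conv1d(
        w_init_gain= 'relu',
        in_channels= 16,
        out_channels= 16,
        kernel_size= 3,
        padding= 1
        )
    x = torch.randn(4, 16, 100)     # [Batch, Channels, Time]
    return glu_conv(x).shape, relu_conv(x).shape    # ([4, 32, 100], [4, 16, 100])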
class ConvTranspose2d(torch.nn.ConvTranspose2d):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            assert self.out_channels % 2 == 0, 'The out_channels of GLU must be an even number.'
            # The transposed convolution weight is [in_channels, out_channels // groups, kH, kW],
            # so the output-channel split is taken along dim 1.
            torch.nn.init.kaiming_uniform_(self.weight[:, :self.out_channels // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[:, self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        elif self.w_init_gain == 'gate':
            assert self.out_channels % 2 == 0, 'The out_channels of gate must be an even number.'
            torch.nn.init.xavier_uniform_(self.weight[:, :self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
            torch.nn.init.xavier_uniform_(self.weight[:, self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)


class Linear(torch.nn.Linear):
    def __init__(self, w_init_gain= 'linear', *args, **kwargs):
        self.w_init_gain = w_init_gain
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        if self.w_init_gain in ['zero']:
            torch.nn.init.zeros_(self.weight)
        elif self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # torch.nn.Linear has out_features rather than out_channels.
            assert self.out_features % 2 == 0, 'The out_features of GLU must be an even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.out_features // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.out_features // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)


class Lambda(torch.nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        # Applies the stored callable to the input.
        return self.lambd(x)


class Residual(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Passes all arguments straight through to the wrapped module; the
        # residual sum itself is expected to be handled by the wrapped module
        # or by the caller.
        return self.module(*args, **kwargs)


class LayerNorm(torch.nn.Module):
    def __init__(self, num_features: int, eps: float= 1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(num_features))
        self.beta = torch.nn.Parameter(torch.zeros(num_features))

    def forward(self, inputs: torch.Tensor):
        # Channel-wise layer normalization: statistics are computed over
        # dim 1 of a [Batch, Channels, ...] tensor.
        means = inputs.mean(dim= 1, keepdim= True)
        variances = (inputs - means).pow(2.0).mean(dim= 1, keepdim= True)
        x = (inputs - means) * (variances + self.eps).rsqrt()

        shape = [1, -1] + [1] * (x.ndim - 2)

        return x * self.gamma.view(*shape) + self.beta.view(*shape)
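# A minimal sketch of the channel-wise LayerNorm above: statistics are taken
# over dim 1, so a [Batch, Channels, Time] tensor can be normalized without
# transposing it first (unlike torch.nn.LayerNorm, which normalizes trailing
# dimensions). The sizes are arbitrary.
def _example_channel_layer_norm():
    norm = LayerNorm(num_features= 80)
    x = torch.randn(4, 80, 120)     # [Batch, Channels, Time]
    y = norm(x)                     # per (batch, time) position, the channel
                                    # statistics are standardized before gamma/beta
    return y.shape                  # torch.Size([4, 80, 120])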
class LightweightConv1d(torch.nn.Module):
    '''
    Args:
        input_size: # of channels of the input and output
        kernel_size: size of the convolution kernel
        padding: padding
        num_heads: number of heads used. The weight is of shape
            `(num_heads, 1, kernel_size)`
        weight_softmax: normalize the weight with softmax before the convolution

    Shape:
        Input: BxCxT, i.e. (batch_size, input_size, timesteps)
        Output: BxCxT, i.e. (batch_size, input_size, timesteps)

    Attributes:
        weight: the learnable weights of the module of shape
            `(num_heads, 1, kernel_size)`
        bias: the learnable bias of the module of shape `(input_size)`
    '''
    def __init__(
        self,
        input_size,
        kernel_size= 1,
        padding= 0,
        num_heads= 1,
        weight_softmax= False,
        bias= False,
        weight_dropout= 0.0,
        w_init_gain= 'linear'
        ):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.num_heads = num_heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        self.weight = torch.nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
        self.w_init_gain = w_init_gain

        if bias:
            self.bias = torch.nn.Parameter(torch.Tensor(input_size))
        else:
            self.bias = None
        self.weight_dropout_module = FairseqDropout(
            weight_dropout, module_name=self.__class__.__name__
        )
        self.reset_parameters()

    def reset_parameters(self):
        if self.w_init_gain in ['relu', 'leaky_relu']:
            torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
        elif self.w_init_gain == 'glu':
            # This module has no out_channels; the gated split is applied
            # across the head dimension of the [num_heads, 1, kernel_size] weight.
            assert self.num_heads % 2 == 0, 'The num_heads of GLU must be an even number.'
            torch.nn.init.kaiming_uniform_(self.weight[:self.num_heads // 2], nonlinearity= 'linear')
            torch.nn.init.xavier_uniform_(self.weight[self.num_heads // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
        else:
            torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, input):
        """
        input size: B x C x T
        output size: B x C x T
        """
        B, C, T = input.size()
        H = self.num_heads

        weight = self.weight
        if self.weight_softmax:
            weight = weight.softmax(dim=-1)

        weight = self.weight_dropout_module(weight)
        # Merge every C/H entries into the batch dimension (C = self.input_size)
        # B x C x T -> (B * C/H) x H x T
        # One can also expand the weight to C x 1 x K by a factor of C/H
        # and not reshape the input instead, which is slow though
        input = input.view(-1, H, T)
        output = torch.nn.functional.conv1d(input, weight, padding=self.padding, groups=self.num_heads)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)

        return output


class FairseqDropout(torch.nn.Module):
    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        # Dropout that can optionally stay active at inference time when
        # apply_during_inference is set to True.
        if self.training or self.apply_during_inference:
            return torch.nn.functional.dropout(x, p=self.p, training=True, inplace=inplace)
        else:
            return x
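# A minimal sketch of LightweightConv1d: each of the num_heads kernels is
# shared by input_size // num_heads channels, so the parameter count depends
# only on num_heads and kernel_size. The sizes below are arbitrary.
def _example_lightweight_conv():
    conv = LightweightConv1d(
        input_size= 64,
        kernel_size= 3,
        padding= 1,
        num_heads= 8,
        weight_softmax= True,
        bias= True
        )
    x = torch.randn(2, 64, 50)      # [Batch, Channels, Time]
    y = conv(x)                     # same shape as the input
    return y.shape, conv.weight.shape   # (torch.Size([2, 64, 50]), torch.Size([8, 1, 3]))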
class LinearAttention(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        calc_channels: int,
        num_heads: int,
        dropout_rate: float= 0.1,
        use_scale: bool= True,
        use_residual: bool= True,
        use_norm: bool= True
        ):
        super().__init__()
        assert calc_channels % num_heads == 0
        self.calc_channels = calc_channels
        self.num_heads = num_heads
        self.use_scale = use_scale
        self.use_residual = use_residual
        self.use_norm = use_norm

        self.prenet = Conv1d(
            in_channels= channels,
            out_channels= calc_channels * 3,
            kernel_size= 1,
            bias= False,
            w_init_gain= 'linear'
            )
        self.projection = Conv1d(
            in_channels= calc_channels,
            out_channels= channels,
            kernel_size= 1,
            w_init_gain= 'linear'
            )
        self.dropout = torch.nn.Dropout(p= dropout_rate)

        if use_scale:
            # Scale initialized to zero, so the attention branch contributes
            # nothing at initialization.
            self.scale = torch.nn.Parameter(torch.zeros(1))

        if use_norm:
            self.norm = LayerNorm(num_features= channels)

    def forward(self, x: torch.Tensor, *args, **kwargs):
        '''
        x: [Batch, Enc_d, Enc_t]
        '''
        residuals = x

        x = self.prenet(x)  # [Batch, Calc_d * 3, Enc_t]
        x = x.view(x.size(0), self.num_heads, x.size(1) // self.num_heads, x.size(2))   # [Batch, Head, Calc_d // Head * 3, Enc_t]
        queries, keys, values = x.chunk(chunks= 3, dim= 2)  # [Batch, Head, Calc_d // Head, Enc_t] * 3
        keys = (keys + 1e-5).softmax(dim= 3)

        contexts = keys @ values.permute(0, 1, 3, 2)    # [Batch, Head, Calc_d // Head, Calc_d // Head]
        contexts = contexts.permute(0, 1, 3, 2) @ queries   # [Batch, Head, Calc_d // Head, Enc_t]
        contexts = contexts.view(contexts.size(0), contexts.size(1) * contexts.size(2), contexts.size(3))  # [Batch, Calc_d, Enc_t]
        contexts = self.projection(contexts)    # [Batch, Enc_d, Enc_t]

        if self.use_scale:
            contexts = self.scale * contexts

        contexts = self.dropout(contexts)

        if self.use_residual:
            contexts = contexts + residuals

        if self.use_norm:
            contexts = self.norm(contexts)

        return contexts
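# A minimal sketch of LinearAttention: the module keeps the [Batch, Enc_d,
# Enc_t] layout end to end, and the key-value contraction is a
# [Calc_d // Head, Calc_d // Head] matrix per head, so the cost grows linearly
# with the number of timesteps. The sizes below are arbitrary.
def _example_linear_attention():
    attention = LinearAttention(
        channels= 256,
        calc_channels= 256,
        num_heads= 4,
        dropout_rate= 0.1
        )
    x = torch.randn(3, 256, 200)    # [Batch, Enc_d, Enc_t]
    y = attention(x)                # residual and LayerNorm applied by default
    return y.shape                  # torch.Size([3, 256, 200])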