# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import torch.nn as nn

from .conv import StreamableConv1d, StreamableConvTranspose1d


class StreamableLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        # Convolutional layout (B, C, T) -> LSTM layout (T, B, C).
        x = x.permute(2, 0, 1)
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        # Back to convolutional layout (B, C, T).
        y = y.permute(1, 2, 0)
        return y


class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                             causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)


class SEANetDecoder(nn.Module):
    """SEANet decoder: maps latent frames back to a waveform.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers per upsampling block.
        ratios (list): Upsampling ratios; each transposed convolution uses stride `ratio`
            and kernel size `2 * ratio`.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str, optional): Final activation after all convolutions (e.g. 'Tanh').
        final_activation_params (dict, optional): Parameters to provide to the final activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers after the initial convolution.
        disable_norm_outer_blocks (int): Number of outermost blocks for which norm is disabled.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution
            under the causal setup.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU',
                 activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2,
                 causal: bool = False, pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2,
                 lstm: int = 0, disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios  # guard against accidental use of the local name below
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (e.g. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        # z: (B, dimension, T) latent frames -> (B, channels, T * hop_length) waveform.
        y = self.model(z)
        return y
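

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module).
# A minimal smoke test of the decoder, assuming this file lives in a package
# next to conv.py so the relative import above resolves (run it with
# `python -m <package>.seanet`, or import SEANetDecoder from a script).
# With the default ratios [8, 5, 4, 2], each latent frame is upsampled by
# hop_length = 8 * 5 * 4 * 2 = 320 output samples.

if __name__ == "__main__":
    import torch

    decoder = SEANetDecoder(channels=1, dimension=128, n_filters=32,
                            n_residual_layers=3, ratios=[8, 5, 4, 2])
    z = torch.randn(1, 128, 35)  # (batch, dimension, latent frames)
    with torch.no_grad():
        y = decoder(z)
    # Expected output shape under the default (non-causal, reflect-padded)
    # setup: (1, 1, 35 * 320) = (1, 1, 11200).
    print(z.shape, "->", y.shape)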