# Copyright 2022 Christian J. Steinmetz # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # TCN implementation adapted from: # https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/tcn.py import torch from argparse import ArgumentParser from deepafx_st.utils import center_crop, causal_crop class FiLM(torch.nn.Module): def __init__(self, num_features, cond_dim): super().__init__() self.num_features = num_features self.bn = torch.nn.BatchNorm1d(num_features, affine=False) self.adaptor = torch.nn.Linear(cond_dim, num_features * 2) def forward(self, x, cond): # project conditioning to 2 x num. conv channels cond = self.adaptor(cond) # split the projection into gain and bias g, b = torch.chunk(cond, 2, dim=-1) # add virtual channel dim if needed if g.ndim == 2: g = g.unsqueeze(1) b = b.unsqueeze(1) # reshape for application g = g.permute(0, 2, 1) b = b.permute(0, 2, 1) x = self.bn(x) # apply BatchNorm without affine x = (x * g) + b # then apply conditional affine return x class ConditionalTCNBlock(torch.nn.Module): def __init__( self, in_ch, out_ch, cond_dim, kernel_size=3, dilation=1, causal=False, **kwargs ): super().__init__() self.in_ch = in_ch self.out_ch = out_ch self.kernel_size = kernel_size self.dilation = dilation self.causal = causal self.conv1 = torch.nn.Conv1d( in_ch, out_ch, kernel_size=kernel_size, padding=0, dilation=dilation, bias=True, ) self.film = FiLM(out_ch, cond_dim) self.relu = torch.nn.PReLU(out_ch) self.res = torch.nn.Conv1d( in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False ) def forward(self, x, p): x_in = x x = self.conv1(x) x = self.film(x, p) # apply FiLM conditioning x = self.relu(x) x_res = self.res(x_in) if self.causal: x = x + causal_crop(x_res, x.shape[-1]) else: x = x + center_crop(x_res, x.shape[-1]) return x class ConditionalTCN(torch.nn.Module): """Temporal convolutional network with conditioning module. Args: sample_rate (float): Audio sample rate. num_control_params (int, optional): Dimensionality of the conditioning signal. Default: 24 ninputs (int, optional): Number of input channels (mono = 1, stereo 2). Default: 1 noutputs (int, optional): Number of output channels (mono = 1, stereo 2). Default: 1 nblocks (int, optional): Number of total TCN blocks. Default: 10 kernel_size (int, optional: Width of the convolutional kernels. Default: 3 dialation_growth (int, optional): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1 channel_growth (int, optional): Compute the output channels at each black as in_ch * channel_growth. Default: 2 channel_width (int, optional): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64 stack_size (int, optional): Number of blocks that constitute a single stack of blocks. Default: 10 causal (bool, optional): Causal TCN configuration does not consider future input values. Default: False """ def __init__( self, sample_rate, num_control_params=24, ninputs=1, noutputs=1, nblocks=10, kernel_size=15, dilation_growth=2, channel_growth=1, channel_width=64, stack_size=10, causal=False, skip_connections=False, **kwargs, ): super().__init__() self.num_control_params = num_control_params self.ninputs = ninputs self.noutputs = noutputs self.nblocks = nblocks self.kernel_size = kernel_size self.dilation_growth = dilation_growth self.channel_growth = channel_growth self.channel_width = channel_width self.stack_size = stack_size self.causal = causal self.skip_connections = skip_connections self.sample_rate = sample_rate self.blocks = torch.nn.ModuleList() for n in range(nblocks): in_ch = out_ch if n > 0 else ninputs if self.channel_growth > 1: out_ch = in_ch * self.channel_growth else: out_ch = self.channel_width dilation = self.dilation_growth ** (n % self.stack_size) self.blocks.append( ConditionalTCNBlock( in_ch, out_ch, self.num_control_params, kernel_size=self.kernel_size, dilation=dilation, padding="same" if self.causal else "valid", causal=self.causal, ) ) self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1) self.receptive_field = self.compute_receptive_field() # print( # f"TCN receptive field: {self.receptive_field} samples", # f" or {(self.receptive_field/self.sample_rate)*1e3:0.3f} ms", # ) def forward(self, x, p, **kwargs): # causally pad input signal x = torch.nn.functional.pad(x, (self.receptive_field - 1, 0)) # iterate over blocks passing conditioning for idx, block in enumerate(self.blocks): x = block(x, p) if self.skip_connections: if idx == 0: skips = x else: skips = center_crop(skips, x[-1]) + x else: skips = 0 # final 1x1 convolution to collapse channels out = self.output(x + skips) return out def compute_receptive_field(self): """Compute the receptive field in samples.""" rf = self.kernel_size for n in range(1, self.nblocks): dilation = self.dilation_growth ** (n % self.stack_size) rf = rf + ((self.kernel_size - 1) * dilation) return rf