# RemFx — remfx/tcn.py
# Source: mattricesound's RemFx ("Update to latest classifier inference", commit 568c3f1)
# This code is based on the following repository written by Christian J. Steinmetz
# https://github.com/csteinmetz1/micro-tcn
from typing import Callable
import torch
import torch.nn as nn
from torch import Tensor
from remfx.utils import causal_crop, center_crop
class TCNBlock(nn.Module):
    """One TCN block: dilated Conv1d + PReLU with a cropped 1x1 residual path."""

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: int = 3,
        dilation: int = 1,
        stride: int = 1,
        crop_fn: Callable = causal_crop,
    ) -> None:
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.crop_fn = crop_fn
        # Main path: dilated convolution with no padding, so the output is
        # shorter than the input by (kernel_size - 1) * dilation samples.
        self.conv1 = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            bias=True,
        )
        # Residual path: 1x1 convolution to match the channel count.
        self.res = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=1,
            groups=1,
            stride=stride,
            bias=False,
        )
        self.relu = nn.PReLU(out_ch)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv + activation, then add the cropped residual branch."""
        main = self.relu(self.conv1(x))
        # The un-padded conv shortens the signal; crop the residual to match
        # (causal or centered, depending on crop_fn).
        shortcut = self.crop_fn(self.res(x), main.shape[-1])
        return main + shortcut
class TCN(nn.Module):
    """Temporal Convolutional Network: a stack of dilated ``TCNBlock``s.

    Per-block dilation grows as ``dilation_growth ** (n % stack_size)``; the
    stack output goes through a 1x1 conv and a tanh squashing nonlinearity.
    """

    def __init__(
        self,
        ninputs: int = 1,
        noutputs: int = 1,
        nblocks: int = 4,
        channel_growth: int = 0,
        channel_width: int = 32,
        kernel_size: int = 13,
        stack_size: int = 10,
        dilation_growth: int = 10,
        condition: bool = False,
        latent_dim: int = 2,
        norm_type: str = "identity",
        causal: bool = False,
        estimate_loudness: bool = False,
    ) -> None:
        super().__init__()
        self.ninputs = ninputs
        self.noutputs = noutputs
        self.nblocks = nblocks
        self.channel_growth = channel_growth
        self.channel_width = channel_width
        self.kernel_size = kernel_size
        self.stack_size = stack_size
        self.dilation_growth = dilation_growth
        self.condition = condition
        self.latent_dim = latent_dim
        self.norm_type = norm_type
        self.causal = causal
        self.estimate_loudness = estimate_loudness
        # Causal models crop from the left; non-causal models crop centered.
        self.crop_fn = causal_crop if self.causal else center_crop
        if estimate_loudness:
            self.loudness = torch.nn.Linear(latent_dim, 1)

        # Audio model: build the stack of dilated conv blocks.
        self.process_blocks = torch.nn.ModuleList()
        out_ch = -1
        for n in range(nblocks):
            in_ch = out_ch if n > 0 else ninputs
            # channel_growth > 1 multiplies channels per block; otherwise the
            # width stays fixed at channel_width.
            out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width
            dilation = dilation_growth ** (n % stack_size)
            self.process_blocks.append(
                TCNBlock(
                    in_ch,
                    out_ch,
                    kernel_size,
                    dilation,
                    stride=1,
                    crop_fn=self.crop_fn,
                )
            )
        self.output = nn.Conv1d(out_ch, noutputs, kernel_size=1)

        # Model configuration.
        self.receptive_field = self.compute_receptive_field()
        self.block_size = 2048
        # NOTE(review): plain tensor attribute, not registered via
        # register_buffer, so it will NOT follow the module across
        # .to(device) / state_dict — confirm this is intentional.
        self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every block, then the 1x1 output conv and tanh."""
        # Iterate the blocks directly; the enumerate index was never used.
        for block in self.process_blocks:
            x = block(x)
        return torch.tanh(self.output(x))

    def compute_receptive_field(self) -> int:
        """Compute the receptive field in samples.

        Block 0 contributes ``kernel_size``; each later block n adds
        ``(kernel_size - 1) * dilation_growth ** (n % stack_size)``.
        """
        rf = self.kernel_size
        for n in range(1, self.nblocks):
            dilation = self.dilation_growth ** (n % self.stack_size)
            rf += (self.kernel_size - 1) * dilation
        return rf