import torch

from deepafx_st.processors.proxy.proxy_system import ProxySystem
from deepafx_st.utils import DSPMode

class ProxyChannel(torch.nn.Module):
    """Chain of proxy networks emulating a channel of audio effects.

    Each proxy is a ProxySystem wrapping a single processor (e.g. a parametric
    EQ, a compressor, or a full channel). forward() applies the proxies in
    series, slicing the control parameter vector across them.
    """

    def __init__(
        self,
        proxy_system_ckpts: list,
        freeze_proxies: bool = True,
        dsp_mode: DSPMode = DSPMode.NONE,
        num_tcns: int = 2,
        tcn_nblocks: int = 4,
        tcn_dilation_growth: int = 8,
        tcn_channel_width: int = 64,
        tcn_kernel_size: int = 13,
        sample_rate: int = 24000,
    ):
        super().__init__()
        self.freeze_proxies = freeze_proxies
        self.dsp_mode = dsp_mode
        self.num_tcns = num_tcns

        # load the proxies from the provided checkpoints
        self.proxies = torch.nn.ModuleList()
        self.num_control_params = 0
        self.ports = []
        for proxy_system_ckpt in proxy_system_ckpts:
            proxy = ProxySystem.load_from_checkpoint(proxy_system_ckpt)
            # optionally freeze proxy parameters so they are not updated during training
            if freeze_proxies:
                for param in proxy.parameters():
                    param.requires_grad = False
            self.proxies.append(proxy)
            # a full "channel" proxy already provides the list of port lists
            if proxy.hparams.processor == "channel":
                self.ports = proxy.processor.ports
            else:
                self.ports.append(proxy.processor.ports)
            self.num_control_params += proxy.processor.num_control_params
        # if no checkpoints are provided, build untrained TCN proxies from scratch
        if len(proxy_system_ckpts) == 0:
            if self.num_tcns == 2:
                # separate proxies for the parametric EQ and the compressor
                peq_proxy = ProxySystem(
                    processor="peq",
                    output_gain=False,
                    nblocks=tcn_nblocks,
                    dilation_growth=tcn_dilation_growth,
                    kernel_size=tcn_kernel_size,
                    channel_width=tcn_channel_width,
                    sample_rate=sample_rate,
                )
                self.proxies.append(peq_proxy)
                self.ports.append(peq_proxy.processor.ports)
                self.num_control_params += peq_proxy.processor.num_control_params

                comp_proxy = ProxySystem(
                    processor="comp",
                    output_gain=True,
                    nblocks=tcn_nblocks,
                    dilation_growth=tcn_dilation_growth,
                    kernel_size=tcn_kernel_size,
                    channel_width=tcn_channel_width,
                    sample_rate=sample_rate,
                )
                self.proxies.append(comp_proxy)
                self.ports.append(comp_proxy.processor.ports)
                self.num_control_params += comp_proxy.processor.num_control_params
            elif self.num_tcns == 1:
                # a single proxy that models the entire channel
                channel_proxy = ProxySystem(
                    processor="channel",
                    output_gain=True,
                    nblocks=tcn_nblocks,
                    dilation_growth=tcn_dilation_growth,
                    kernel_size=tcn_kernel_size,
                    channel_width=tcn_channel_width,
                    sample_rate=sample_rate,
                )
                self.proxies.append(channel_proxy)
                for port_list in channel_proxy.processor.ports:
                    self.ports.append(port_list)
                self.num_control_params += channel_proxy.processor.num_control_params
            else:
                raise ValueError(f"num_tcns must be 1 or 2. Asked for {self.num_tcns}.")
    def forward(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        dsp_mode: DSPMode = DSPMode.NONE,
        sample_rate: int = 24000,
        **kwargs,
    ):
        # loop over the proxies, passing each one its slice of the parameters
        stop_idx = 0
        for proxy in self.proxies:
            start_idx = stop_idx
            stop_idx += proxy.processor.num_control_params
            p_subset = p[:, start_idx:stop_idx]
            if dsp_mode.name == DSPMode.NONE.name:
                # run the neural proxy only (differentiable)
                x = proxy(
                    x,
                    p_subset,
                    use_dsp=False,
                )
            elif dsp_mode.name == DSPMode.INFER.name:
                # run the underlying DSP implementation only
                x = proxy(
                    x,
                    p_subset,
                    use_dsp=True,
                    sample_rate=sample_rate,
                )
            elif dsp_mode.name == DSPMode.TRAIN_INFER.name:
                # Mimic the Gumbel-softmax straight-through estimator: use the DSP
                # output in the forward pass, but route gradients through the proxy.
                # https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
                x_hard = proxy(
                    x,
                    p_subset,
                    use_dsp=True,
                    sample_rate=sample_rate,
                )
                x = proxy(
                    x,
                    p_subset,
                    use_dsp=False,
                    sample_rate=sample_rate,
                )
                x = (x_hard - x).detach() + x
            else:
                raise ValueError(f"Invalid DSP mode for proxy: {dsp_mode}")

        return x
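
# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module).
# A minimal example of driving ProxyChannel end to end, assuming that
# ProxySystem can be constructed with only the arguments used above (no
# pretrained checkpoints) and that audio tensors are shaped
# (batch, channels, samples). Batch size and signal length are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    channel = ProxyChannel(
        proxy_system_ckpts=[],  # no checkpoints -> build PEQ + comp proxies from scratch
        freeze_proxies=False,
        num_tcns=2,
        sample_rate=24000,
    )

    batch_size = 2
    x = torch.randn(batch_size, 1, 24000)  # assumed: 1 s of mono audio at 24 kHz
    # one control value per parameter, concatenated across all proxies
    p = torch.rand(batch_size, channel.num_control_params)

    # run the neural proxies only (no DSP), as during controller training
    y = channel(x, p, dsp_mode=DSPMode.NONE)
    print(y.shape, channel.num_control_params)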