import logging
from dataclasses import dataclass
from typing import Optional

import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.utils.parametrizations import weight_norm

from ...common import Normalizer

logger = logging.getLogger(__name__)


@dataclass
class IRMAEOutput:
    latent: Tensor
    decoded: Optional[Tensor]  # None when decoding is skipped


class ResBlock(nn.Sequential):
    """Pre-activation residual block of dilated 1-d convolutions."""

    def __init__(self, channels, dilations=(1, 2, 4, 8)):
        # One (GroupNorm -> GELU -> weight-normalized dilated conv) unit per dilation.
        super().__init__(
            *(
                layer
                for dilation in dilations
                for layer in (
                    nn.GroupNorm(32, channels),
                    nn.GELU(),
                    weight_norm(
                        nn.Conv1d(channels, channels, 3, padding="same", dilation=dilation)
                    ),
                )
            )
        )

    def forward(self, x: Tensor):
        # Residual connection around the whole stack.
        return x + super().forward(x)


class IRMAE(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        hidden_dim=1024,
        num_irms=4,
    ):
        """
        Args:
            input_dim: input dimension
            output_dim: output dimension
            latent_dim: latent dimension
            hidden_dim: hidden layer dimension
            num_irms: number of implicit rank minimization matrices
        """
        super().__init__()
        self.input_dim = input_dim

        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            *[
                nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False)
                for i in range(num_irms)
            ],
            nn.Tanh(),
        )
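        # The trailing bias-free 1x1 convolutions are the implicit rank
        # minimization (IRM) matrices: a stack of purely linear maps between
        # the encoder and the latent code that implicitly penalizes the rank
        # of the latent representation.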

        self.decoder = nn.Sequential(
            nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
            *[ResBlock(hidden_dim) for _ in range(4)],
            nn.Conv1d(hidden_dim, output_dim, 1),
        )
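
        # Projection head: maps decoded features back to the input space so
        # that forward() can compute the reconstruction (MSE) loss against x.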
        self.head = nn.Sequential(
            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
            nn.GELU(),
            nn.Conv1d(hidden_dim, input_dim, 1),
        )

        self.estimator = Normalizer()

    def encode(self, x):
        """
        Args:
            x: (b c t) tensor
        """
        z = self.encoder(x)
        _ = self.estimator(z)  # Side effect: update the Normalizer's running statistics.
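        # Diagnostic latent statistics, presumably for logging; the quantile
        # levels 0.6827 / 0.9545 / 0.9973 match the Gaussian 1-, 2-, and
        # 3-sigma coverage probabilities.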
        self.stats = {}
        self.stats["z_mean"] = z.mean().item()
        self.stats["z_std"] = z.std().item()
        z_float = z.float()
        self.stats["z_abs_68"] = z_float.abs().quantile(0.6827).item()
        self.stats["z_abs_95"] = z_float.abs().quantile(0.9545).item()
        self.stats["z_abs_99"] = z_float.abs().quantile(0.9973).item()
        return z

    def decode(self, z):
        """
        Args:
            z: (b c t) tensor
        """
        return self.decoder(z)

    def forward(self, x, skip_decoding=False):
        """
        Args:
            x: (b c t) tensor
            skip_decoding: if True, skip the decoding step
        """
        z = self.encode(x)

        if skip_decoding:
            decoded = None
        else:
            decoded = self.decode(z)
            predicted = self.head(decoded)
            self.losses = dict(mse=F.mse_loss(predicted, x))

        return IRMAEOutput(latent=z, decoded=decoded)
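

# A minimal usage sketch, not part of the original module: the dimensions and
# shapes below are illustrative assumptions. Because of the relative import at
# the top, run this as a module (python -m <package>.<this_module>) rather than
# as a standalone script.
if __name__ == "__main__":
    import torch

    model = IRMAE(input_dim=128, output_dim=128, latent_dim=16)
    x = torch.randn(2, 128, 100)  # (b c t)
    output = model(x)
    assert output.latent.shape == (2, 16, 100)
    assert output.decoded.shape == (2, 128, 100)
    print("mse:", model.losses["mse"].item())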