Spaces:
Runtime error
Runtime error
import gin | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class FiLM(nn.Module): | |
def forward(self, x, gamma, beta): | |
return gamma * x + beta | |
class TimeDistributedLayerNorm(nn.Module): | |
def __init__(self, size: int): | |
super().__init__() | |
self.layer_norm = nn.LayerNorm(size) | |
def forward(self, x): | |
return self.layer_norm(x.transpose(1, 2)).transpose(1, 2) | |
class TimeDistributedMLP(nn.Module): | |
def __init__(self, in_size: int, hidden_size: int, out_size: int, depth: int = 3): | |
super().__init__() | |
assert depth >= 3, "Depth must be at least 3" | |
layers = [] | |
for i in range(depth): | |
layers.append( | |
nn.Conv1d( | |
in_size if i == 0 else hidden_size, | |
hidden_size if i < depth - 1 else out_size, | |
1, | |
) | |
) | |
if i < depth - 1: | |
layers.append(TimeDistributedLayerNorm(hidden_size)) | |
layers.append(nn.LeakyReLU()) | |
self.net = nn.Sequential(*layers) | |
def forward(self, x): | |
return self.net(x) | |