Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
class Film(nn.Module):
    """Condition a feature map on an embedding by adding a learned per-channel bias.

    A small MLP maps ``cond_vec`` to one value per channel. That value is broadcast
    over every non-channel axis of ``data`` and added to it. Only an additive shift
    is applied (there is no multiplicative scale). Because the MLP ends in a ReLU,
    the bias is always non-negative.
    """

    def __init__(self, channels, cond_embedding_dim):
        """
        :param channels: number of channels (dim 1) of the ``data`` tensors passed to ``forward``
        :param cond_embedding_dim: size of the last dimension of ``cond_vec``
        """
        super().__init__()
        # The attribute name and layer order are kept unchanged so that state_dict
        # keys (linear.0.*, linear.2.*) stay compatible with saved weights.
        self.linear = nn.Sequential(
            nn.Linear(cond_embedding_dim, channels * 2),
            nn.ReLU(inplace=True),
            nn.Linear(channels * 2, channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, data, cond_vec):
        """
        Add a conditioning bias to ``data``.

        :param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
        :param cond_vec: [batchsize, cond_embedding_dim]
        :return: ``data + bias`` with the bias broadcast over all axes after the
            channel axis. For any other input rank a warning is emitted and
            ``data`` is returned unchanged.
        """
        bias = self.linear(cond_vec)  # [batchsize, channels]
        rank = data.dim()
        if rank == 3:
            # [B, C] -> [B, C, 1] so the bias broadcasts over the sample axis
            return data + bias[..., None]
        if rank == 4:
            # [B, C] -> [B, C, 1, 1] so the bias broadcasts over both trailing axes
            return data + bias[..., None, None]

        # Best-effort pass-through. warnings.warn (instead of print) lets callers
        # filter or escalate this message, and by default it is shown once per
        # call site rather than on every forward pass.
        import warnings

        warnings.warn(
            f"The size of input tensor, {data.size()} is not correct. Film is not working.",
            stacklevel=2,
        )
        return data