Spaces:
Build error
Build error
File size: 998 Bytes
98f685a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import warnings

import torch
import torch.nn as nn
class Film(nn.Module):
    """Condition a feature map on an embedding vector via a learned per-channel bias.

    Despite the name, only the additive (shift) half of feature-wise linear
    modulation is implemented: ``cond_vec`` is mapped by a small MLP to one
    value per channel, and that value is added to ``data``. No multiplicative
    scale is applied.

    :param channels: number of channels on dim 1 of the ``data`` given to ``forward``.
    :param cond_embedding_dim: size of the last dim of ``cond_vec``.
    """

    def __init__(self, channels, cond_embedding_dim):
        super().__init__()
        # MLP: cond_embedding_dim -> 2 * channels -> channels.
        # NOTE(review): the trailing ReLU makes every bias value >= 0, so the
        # layer can only shift features upward -- confirm this is intended.
        self.linear = nn.Sequential(
            nn.Linear(cond_embedding_dim, channels * 2),
            nn.ReLU(inplace=True),
            nn.Linear(channels * 2, channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, data, cond_vec):
        """Add a condition-dependent per-channel bias to ``data``.

        :param data: tensor with channels on dim 1, e.g. [batchsize, channels, samples],
            [batchsize, channels, T, F] or [batchsize, channels, F, T]. Any rank >= 2 is
            accepted; the bias is broadcast over every dim after the channel dim.
        :param cond_vec: [batchsize, cond_embedding_dim]
        :return: ``data`` plus the per-channel bias (same shape as ``data`` for the
            shapes above). If ``data`` has fewer than 2 dims, a ``UserWarning`` is
            issued and ``data`` is returned unchanged.
        """
        # Number of dims after [batchsize, channels] that the bias must broadcast over.
        extra_dims = data.dim() - 2
        if extra_dims < 0:
            # Best-effort fallback kept from the original: warn (once per call site
            # by default, instead of printing on every forward pass) and pass through.
            warnings.warn(
                f"The size of input tensor, {tuple(data.size())}, is not correct. "
                "Film is not working."
            )
            return data

        bias = self.linear(cond_vec)  # [batchsize, channels]
        # Append one singleton axis per trailing dim of data, e.g. [B, C] -> [B, C, 1, 1].
        return data + bias[(...,) + (None,) * extra_dims]