import torch
from torch import nn
from functools import partial

from einops import rearrange
from einops.layers.torch import Rearrange


def pair(x):
    """Return x unchanged if it is already a tuple, else duplicate it into (x, x)."""
    return x if isinstance(x, tuple) else (x, x)
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: returns x + fn(LayerNorm(x))."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x
def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    """Mixer feed-forward block; the `dense` constructor picks which axis gets mixed."""
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, dim),
        nn.Dropout(dropout)
    )
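# Illustration only (not from the original source): in a Mixer block,
# `partial(nn.Conv1d, kernel_size=1)` mixes across the token axis (dim 1) of a
# (batch, tokens, dim) tensor, while `nn.Linear` mixes across the channel axis
# (dim -1). A minimal shape check with assumed sizes:
#
#     x = torch.randn(2, 272, 512)  # (batch, tokens, dim)
#     token_mix = FeedForward(272, dense=partial(nn.Conv1d, kernel_size=1))
#     channel_mix = FeedForward(512, dense=nn.Linear)
#     assert token_mix(x).shape == channel_mix(x).shape == x.shape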
class MappingSub2W(nn.Module):
    """Mixer that maps per-region latents (B, N, 34, dim) to codes of shape (B, N, dim)."""

    def __init__(self, N=8, dim=512, depth=6, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super().__init__()
        num_patches = N * 34
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.layer = nn.Sequential(
            Rearrange('b c h w -> b (c h) w'),  # (B, N, 34, dim) -> (B, N*34, dim)
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            Rearrange('b c h -> b h c'),        # (B, N*34, dim) -> (B, dim, N*34)
            nn.Linear(34 * N, 34 * N),
            nn.LayerNorm(34 * N),
            nn.GELU(),
            nn.Linear(34 * N, N),               # collapse the 34 region tokens
            Rearrange('b h c -> b c h')         # (B, dim, N) -> (B, N, dim)
        )

    def forward(self, x):
        return self.layer(x)
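# Shape sketch for MappingSub2W (assumed defaults N=8, dim=512; illustration
# only, not part of the original source):
#
#     sub2w = MappingSub2W()
#     sub = torch.randn(2, 8, 34, 512)  # (B, N, 34, dim) per-region latents
#     w = sub2w(sub)                    # (B, N, dim) == (2, 8, 512)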
class MappingW2Sub(nn.Module):
    """Mixer VAE head mapping (B, N, dim) codes back to per-region latents (B, N, 34, dim)."""

    def __init__(self, N=8, dim=512, depth=8, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super().__init__()
        self.N = N
        num_patches = N * 34
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.layer = nn.Sequential(
            Rearrange('b c h -> b h c'),   # (B, N, dim) -> (B, dim, N)
            nn.Linear(N, num_patches),     # expand the N codes to N*34 region tokens
            Rearrange('b h c -> b c h'),   # (B, dim, N*34) -> (B, N*34, dim)
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim)
        )
        self.mu_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(2)],
            nn.LayerNorm(dim),
            nn.Tanh(),
            Rearrange('b c h -> b (c h)')  # flatten to (B, N*34*dim)
        )
        self.var_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(2)],
            nn.LayerNorm(dim),
            nn.Tanh(),
            Rearrange('b c h -> b (c h)')  # flatten to (B, N*34*dim)
        )
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) using N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
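    # Sketch of the usual VAE KL regularizer these outputs would feed during
    # training (an assumption for illustration; the loss is not defined here):
    #
    #     kld = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())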
    def forward(self, x):
        f = self.layer(x)                    # (B, N*34, dim)
        mu = self.mu_fc(f)                   # (B, N*34*dim)
        log_var = self.var_fc(f)             # (B, N*34*dim)
        z = self.reparameterize(mu, log_var)
        z = rearrange(z, 'a (b c d) -> a b c d', b=self.N, c=34)
        mu = rearrange(mu, 'a (b c d) -> a b c d', b=self.N, c=34)
        log_var = rearrange(log_var, 'a (b c d) -> a b c d', b=self.N, c=34)
        return mu, log_var, z
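# Shape sketch for MappingW2Sub (assumed defaults N=8, dim=512; illustration
# only):
#
#     w2sub = MappingW2Sub()
#     w = torch.randn(2, 8, 512)   # (B, N, dim)
#     mu, log_var, z = w2sub(w)    # each (B, N, 34, dim) == (2, 8, 34, 512)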
class HeadEncoder(nn.Module):
    """Encodes one masked (B, 32, 64, 64) feature map into N latent tokens (mu, log-variance, sample)."""

    def __init__(self, N=8, dim=512, depth=2, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super().__init__()
        channels = [32, 64, 64, 64]
        self.N = N
        num_patches = N
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        # Three stride-2 convs: a (B, 32, 64, 64) input becomes (B, 64, 8, 8),
        # matching the Linear(64 * 8 * 8, 512) head below.
        self.s1 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[1]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[2]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[2], channels[3], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[3]),
            nn.LeakyReLU())
        self.mlp1 = nn.Linear(channels[3] * 8 * 8, 512)
        self.up_N = nn.Linear(1, N)  # broadcast the single pooled code to N tokens
        self.mu_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Tanh()
        )
        self.var_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Tanh()
        )
    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) using N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu
    def forward(self, x):
        feature = self.s1(x)                  # (B, 64, 8, 8)
        s2 = torch.flatten(feature, start_dim=1)
        s2 = self.mlp1(s2).unsqueeze(2)       # (B, 512, 1)
        s2 = self.up_N(s2)                    # (B, 512, N)
        s2 = rearrange(s2, 'b h c -> b c h')  # (B, N, 512)
        mu = self.mu_fc(s2)
        log_var = self.var_fc(s2)
        z = self.reparameterize(mu, log_var)
        return mu, log_var, z
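# Shape sketch for HeadEncoder (assumed defaults; the Linear(64 * 8 * 8, 512)
# head implies 64x64 input feature maps; illustration only):
#
#     head = HeadEncoder(N=8)
#     feat = torch.randn(2, 32, 64, 64)  # one masked backbone feature map
#     mu, log_var, z = head(feat)        # each (B, N, dim) == (2, 8, 512)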
class RegionEncoder(nn.Module):
    """Shared conv backbone plus 34 per-region heads; returns stacked (B, N, 34, dim) mu, log-variance, and samples."""

    def __init__(self, N=8):
        super().__init__()
        channels = [8, 16, 32, 32, 64, 64]
        self.s1 = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, stride=2)
        self.s2 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels[1]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels[2]),
            nn.LeakyReLU()
        )
        # One encoder head per mask channel (34 regions).
        self.heads = nn.ModuleList(HeadEncoder(N=N) for _ in range(34))
    def forward(self, x, all_mask):
        # all_mask: (B, 34, h, w) region masks at the resolution of the s2 feature map.
        s1 = self.s1(x)
        s2 = self.s2(s1)                     # (B, 32, h, w) shared feature map
        result, mus, log_vars = [], [], []
        for i, head in enumerate(self.heads):
            m = all_mask[:, i].unsqueeze(1)  # (B, 1, h, w) mask for region i
            mu, log_var, z = head(s2 * m)
            result.append(z.unsqueeze(2))
            mus.append(mu.unsqueeze(2))
            log_vars.append(log_var.unsqueeze(2))
        return torch.cat(mus, dim=2), torch.cat(log_vars, dim=2), torch.cat(result, dim=2)
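# Minimal end-to-end smoke test (assumed shapes: a 512x512 RGB input gives a
# 64x64 shared feature map, with 34 region masks at that resolution; an
# illustration, not part of the original source):
if __name__ == "__main__":
    enc = RegionEncoder(N=8)
    img = torch.randn(2, 3, 512, 512)
    masks = torch.ones(2, 34, 64, 64)
    mu, log_var, z = enc(img, masks)          # each (2, 8, 34, 512)
    w = MappingSub2W(N=8)(z)                  # (2, 8, 512)
    mu2, log_var2, z2 = MappingW2Sub(N=8)(w)  # each (2, 8, 34, 512)
    print(z.shape, w.shape, z2.shape)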