from functools import partial

import numpy as np
import torch
import torch.nn as nn

from modules.layers.simswap.pg_modules.blocks import DownBlock, DownBlockPatch, conv2d
from modules.layers.simswap.pg_modules.projector import F_RandomProj
from modules.layers.simswap.pg_modules.diffaug import DiffAugment
class SingleDisc(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False):
        super().__init__()
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # snap start_sz to the nearest resolution in channel_dict
        # (handles start sizes that are not powers of two)
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if ndf is given, allocate all layers with the same ndf
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down Blocks: halve the resolution until end_sz is reached
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2

        layers.append(conv2d(nfc[end_sz], 1, 4, 1, 0, bias=False))
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        # c is unused here; it is kept to match SingleDiscCond's interface
        return self.main(x)
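
# Shape sketch (illustrative; assumes each DownBlock halves spatial resolution,
# as in the Projected GAN reference blocks):
#   SingleDisc(nc=32, start_sz=256, end_sz=8) on a (N, 32, 256, 256) feature map
#   stacks five DownBlocks (256 -> 128 -> 64 -> 32 -> 16 -> 8) and a final 4x4
#   conv, yielding a (N, 1, 5, 5) map of patch logits.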
class SingleDiscCond(nn.Module):
    def __init__(self, nc=None, ndf=None, start_sz=256, end_sz=8, head=None, separable=False, patch=False,
                 c_dim=1000, cmap_dim=64, embedding_dim=128):
        super().__init__()
        self.cmap_dim = cmap_dim

        # midas channels
        channel_dict = {4: 512, 8: 512, 16: 256, 32: 128, 64: 64, 128: 64,
                        256: 32, 512: 16, 1024: 8}

        # snap start_sz to the nearest resolution in channel_dict
        # (handles start sizes that are not powers of two)
        if start_sz not in channel_dict.keys():
            sizes = np.array(list(channel_dict.keys()))
            start_sz = sizes[np.argmin(abs(sizes - start_sz))]
        self.start_sz = start_sz

        # if ndf is given, allocate all layers with the same ndf
        if ndf is None:
            nfc = channel_dict
        else:
            nfc = {k: ndf for k, v in channel_dict.items()}

        # for feature map discriminators with nfc not in channel_dict
        # this is the case for the pretrained backbone (midas.pretrained)
        if nc is not None and head is None:
            nfc[start_sz] = nc

        layers = []

        # Head if the initial input is the full modality
        if head:
            layers += [conv2d(nc, nfc[256], 3, 1, 1, bias=False),
                       nn.LeakyReLU(0.2, inplace=True)]

        # Down Blocks
        DB = partial(DownBlockPatch, separable=separable) if patch else partial(DownBlock, separable=separable)
        while start_sz > end_sz:
            layers.append(DB(nfc[start_sz], nfc[start_sz // 2]))
            start_sz = start_sz // 2
        self.main = nn.Sequential(*layers)

        # additions for conditioning on class information
        self.cls = conv2d(nfc[end_sz], self.cmap_dim, 4, 1, 0, bias=False)
        self.embed = nn.Embedding(num_embeddings=c_dim, embedding_dim=embedding_dim)
        self.embed_proj = nn.Sequential(
            nn.Linear(self.embed.embedding_dim, self.cmap_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, c):
        h = self.main(x)
        out = self.cls(h)

        # conditioning via projection: channel-wise inner product between the
        # feature map and the projected class embedding, scaled by 1/sqrt(cmap_dim)
        cmap = self.embed_proj(self.embed(c.argmax(1))).unsqueeze(-1).unsqueeze(-1)
        out = (out * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return out
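
# Conditioning sketch (illustrative, assuming the feature map reaches 8x8 at the
# last block): self.cls produces a (N, cmap_dim, 5, 5) projection, the class
# embedding is broadcast as (N, cmap_dim, 1, 1), and the projection-style inner
# product collapses them to a (N, 1, 5, 5) conditional logit map.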
class MultiScaleD(nn.Module):
    def __init__(
        self,
        channels,
        resolutions,
        num_discs=4,
        proj_type=2,  # 0 = no projection, 1 = cross channel mixing, 2 = cross scale mixing
        cond=0,
        separable=False,
        patch=False,
        **kwargs,
    ):
        super().__init__()

        assert num_discs in [1, 2, 3, 4]

        # the first disc is on the lowest level of the backbone
        self.disc_in_channels = channels[:num_discs]
        self.disc_in_res = resolutions[:num_discs]

        Disc = SingleDiscCond if cond else SingleDisc

        mini_discs = []
        for i, (cin, res) in enumerate(zip(self.disc_in_channels, self.disc_in_res)):
            start_sz = res if not patch else 16
            mini_discs.append([str(i), Disc(nc=cin, start_sz=start_sz, end_sz=8, separable=separable, patch=patch)])
        self.mini_discs = nn.ModuleDict(mini_discs)

    def forward(self, features, c):
        all_logits = []
        # features is expected to be a dict keyed by the same string indices
        # ('0', '1', ...) used to register the mini discriminators
        for k, disc in self.mini_discs.items():
            res = disc(features[k], c).view(features[k].size(0), -1)
            all_logits.append(res)

        all_logits = torch.cat(all_logits, dim=1)
        return all_logits
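
# Output sketch (illustrative): each mini discriminator yields a flattened
# patch-logit vector (e.g. 5 * 5 = 25 values when its feature map ends at 8x8),
# and the vectors from all scales are concatenated along dim=1.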
class ProjectedDiscriminator(torch.nn.Module):
    def __init__(
        self,
        diffaug=True,
        interp224=True,
        backbone_kwargs={},
        **kwargs
    ):
        super().__init__()
        self.diffaug = diffaug
        self.interp224 = interp224
        self.feature_network = F_RandomProj(**backbone_kwargs)
        self.discriminator = MultiScaleD(
            channels=self.feature_network.CHANNELS,
            resolutions=self.feature_network.RESOLUTIONS,
            **backbone_kwargs,
        )

    def train(self, mode=True):
        # the pretrained feature network is always kept in eval mode;
        # only the multi-scale discriminator heads follow the requested mode
        self.feature_network = self.feature_network.train(False)
        self.discriminator = self.discriminator.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def get_feature(self, x):
        features = self.feature_network(x, get_features=True)
        return features

    def forward(self, x, c):
        # if self.diffaug:
        #     x = DiffAugment(x, policy='color,translation,cutout')

        # if self.interp224:
        #     x = F.interpolate(x, 224, mode='bilinear', align_corners=False)

        features, backbone_features = self.feature_network(x)
        logits = self.discriminator(features, c)

        return logits, backbone_features
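

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustrative only; the channel/resolution
    # lists below are made up to mimic a backbone feature pyramid, and each
    # DownBlock is assumed to halve spatial resolution as in Projected GAN).
    disc = MultiScaleD(channels=[64, 128, 256, 512], resolutions=[64, 32, 16, 8])
    feats = {
        "0": torch.randn(2, 64, 64, 64),
        "1": torch.randn(2, 128, 32, 32),
        "2": torch.randn(2, 256, 16, 16),
        "3": torch.randn(2, 512, 8, 8),
    }
    print(disc(feats, c=None).shape)  # expected under these assumptions: torch.Size([2, 100])

    # For the full pipeline, ProjectedDiscriminator wraps F_RandomProj and
    # MultiScaleD, e.g. (assuming the backbone's default kwargs are valid here):
    #     D = ProjectedDiscriminator(backbone_kwargs={})
    #     logits, backbone_features = D(torch.randn(2, 3, 256, 256), c=None)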