|
import torch |
|
from torch import nn |
|
from torch.nn import Module |
|
|
|
from models.StyleCLIP.models.stylegan2.model import EqualLinear, PixelNorm |
|
|
|
|
|
class Mapper(Module):
    """Feed-forward stack: one ``PixelNorm`` followed by four ``EqualLinear`` layers.

    Every ``EqualLinear`` is built as ``EqualLinear(512, 512, lr_mul=0.01,
    activation='fused_lrelu')``, so the input's feature dimension is presumably
    512 (TODO confirm against ``EqualLinear``'s definition).
    """

    def __init__(self, opts):
        """Build the layer stack.

        Args:
            opts: Options object; it is only stored on the instance here and
                is not read while building the layers.
        """
        super().__init__()
        self.opts = opts

        # PixelNorm is constructed first, then the four linear layers in
        # order, matching the order in which they are applied.
        self.mapping = nn.Sequential(
            PixelNorm(),
            *(
                EqualLinear(512, 512, lr_mul=0.01, activation='fused_lrelu')
                for _ in range(4)
            ),
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.mapping(x)
|
|
|
|
|
class SingleMapper(Module):
    """Wrapper module that applies a single ``Mapper`` to its whole input."""

    def __init__(self, opts):
        """Store ``opts`` and create the wrapped ``Mapper`` from it.

        Args:
            opts: Options object, kept on the instance and forwarded to
                ``Mapper``.
        """
        super().__init__()
        self.opts = opts
        self.mapping = Mapper(opts)

    def forward(self, x):
        """Return the wrapped mapper's output for ``x``."""
        return self.mapping(x)
|
|
|
|
|
class LevelsMapper(Module):
    """Apply separate ``Mapper`` modules to three slices of the input.

    The input is split along dimension 1 into indices ``[:4]`` (coarse),
    ``[4:8]`` (medium) and ``[8:]`` (fine). Each group is either passed
    through its own mapper or, when that mapper is disabled through
    ``opts``, replaced by zeros of the same shape.
    """

    def __init__(self, opts):
        """Create a mapper for every group that is not disabled.

        Args:
            opts: Options object providing the boolean flags
                ``no_coarse_mapper``, ``no_medium_mapper`` and
                ``no_fine_mapper``.
        """
        super().__init__()
        self.opts = opts

        # NOTE: the attribute is spelled ``course_mapping`` (not "coarse").
        # It is kept as is, since the attribute name determines the
        # submodule/parameter names of this module.
        if not opts.no_coarse_mapper:
            self.course_mapping = Mapper(opts)
        if not opts.no_medium_mapper:
            self.medium_mapping = Mapper(opts)
        if not opts.no_fine_mapper:
            self.fine_mapping = Mapper(opts)

    def forward(self, x):
        """Map each group of ``x`` and concatenate the results along dim 1.

        Args:
            x: Tensor indexed with three indices; the groups are taken
                along dimension 1.

        Returns:
            Tensor built by concatenating the coarse, medium and fine
            results (in that order) along dimension 1.
        """
        opts = self.opts
        coarse_in = x[:, :4, :]
        medium_in = x[:, 4:8, :]
        fine_in = x[:, 8:, :]

        # A disabled group contributes zeros instead of a mapped tensor.
        coarse_out = (
            torch.zeros_like(coarse_in)
            if opts.no_coarse_mapper
            else self.course_mapping(coarse_in)
        )
        medium_out = (
            torch.zeros_like(medium_in)
            if opts.no_medium_mapper
            else self.medium_mapping(medium_in)
        )
        fine_out = (
            torch.zeros_like(fine_in)
            if opts.no_fine_mapper
            else self.fine_mapping(fine_in)
        )

        return torch.cat([coarse_out, medium_out, fine_out], dim=1)
|
|
|
|