from torch import nn

from .encoder import Encoder
from .styledecoder import Synthesis


class Generator(nn.Module):
    """Generator that chains an ``Encoder`` with a ``Synthesis`` decoder.

    The encoder is stored as ``self.enc`` and the decoder as ``self.dec``.
    ``forward`` feeds the encoder's three outputs straight into the decoder.

    NOTE(review): the tensor shapes and exact meaning of ``wa``, ``alpha`` and
    ``feats`` are defined by ``Encoder``/``Synthesis`` and are not visible
    here -- confirm against those modules.
    """

    def __init__(self, size, style_dim=512, motion_dim=20, channel_multiplier=1,
                 blur_kernel=None):
        """Build the encoder and decoder sub-modules.

        Args:
            size: Forwarded to both ``Encoder`` and ``Synthesis``
                (presumably the image resolution -- TODO confirm).
            style_dim: Forwarded to both sub-modules.
            motion_dim: Forwarded to both sub-modules.
            channel_multiplier: Forwarded to ``Synthesis`` only.
            blur_kernel: Forwarded to ``Synthesis`` only. ``None`` (the
                default) selects ``[1, 3, 3, 1]``.
        """
        super().__init__()

        # A fresh list per instance: a list literal used as the default value
        # would be one shared object across every Generator constructed with
        # the default, so in-place changes downstream would leak between them.
        if blur_kernel is None:
            blur_kernel = [1, 3, 3, 1]

        self.enc = Encoder(size, style_dim, motion_dim)
        # Synthesis takes blur_kernel *before* channel_multiplier, which is the
        # reverse of this constructor's own parameter order.
        self.dec = Synthesis(size, style_dim, motion_dim, blur_kernel, channel_multiplier)

    def get_direction(self):
        """Return the result of calling the decoder's ``direction`` with ``None``."""
        return self.dec.direction(None)

    def synthesis(self, wa, alpha, feat):
        """Run only the decoder on precomputed encoder outputs.

        Args:
            wa, alpha, feat: Passed positionally to ``self.dec`` in this order.

        Returns:
            Whatever ``self.dec`` returns for these inputs.
        """
        return self.dec(wa, alpha, feat)

    def forward(self, img_source, img_drive, h_start=None):
        """Encode the inputs and decode the result.

        Args:
            img_source: First positional input to the encoder.
            img_drive: Second positional input to the encoder.
            h_start: Optional third input to the encoder; defaults to ``None``.

        Returns:
            The decoder output for the encoder's ``(wa, alpha, feats)``.
        """
        wa, alpha, feats = self.enc(img_source, img_drive, h_start)
        img_recon = self.dec(wa, alpha, feats)
        return img_recon