import sys import torch import torch.nn as nn sys.path.insert(0, 'MobileStyleGAN.pytorch') from core.models.mapping_network import MappingNetwork from core.models.mobile_synthesis_network import MobileSynthesisNetwork from core.models.synthesis_network import SynthesisNetwork class Model(nn.Module): def __init__(self): super().__init__() # teacher model mapping_net_params = {'style_dim': 512, 'n_layers': 8, 'lr_mlp': 0.01} synthesis_net_params = { 'size': 1024, 'style_dim': 512, 'blur_kernel': [1, 3, 3, 1], 'channels': [512, 512, 512, 512, 512, 256, 128, 64, 32] } self.mapping_net = MappingNetwork(**mapping_net_params).eval() self.synthesis_net = SynthesisNetwork(**synthesis_net_params).eval() # student network self.student = MobileSynthesisNetwork( style_dim=self.mapping_net.style_dim, channels=synthesis_net_params['channels'][:-1]) self.style_mean = nn.Parameter(torch.zeros((1, 512)), requires_grad=False) def forward(self, var: torch.Tensor, truncation_psi: float = 0.5, generator: str = 'student') -> torch.Tensor: style = self.mapping_net(var) style = self.style_mean + truncation_psi * (style - self.style_mean) if generator == 'student': img = self.student(style)['img'] elif generator == 'teacher': img = self.synthesis_net(style)['img'] else: raise ValueError return img