# MobileStyleGAN / model.py
# Author: hysts (Hugging Face staff) - commit 8575998 ("Update"), 1.54 kB
import sys
import torch
import torch.nn as nn
sys.path.insert(0, "MobileStyleGAN.pytorch")
from core.models.mapping_network import MappingNetwork
from core.models.mobile_synthesis_network import MobileSynthesisNetwork
from core.models.synthesis_network import SynthesisNetwork
class Model(nn.Module):
    """Wrap a mapping network together with a teacher and a student generator.

    The teacher is a full ``SynthesisNetwork`` configured for 1024-pixel output;
    the student is a ``MobileSynthesisNetwork`` sharing the same style dimension.
    Both consume styles produced by the shared ``MappingNetwork``.
    """

    def __init__(self):
        super().__init__()
        # teacher model
        mapping_net_params = {"style_dim": 512, "n_layers": 8, "lr_mlp": 0.01}
        synthesis_net_params = {
            "size": 1024,
            "style_dim": 512,
            "blur_kernel": [1, 3, 3, 1],
            "channels": [512, 512, 512, 512, 512, 256, 128, 64, 32],
        }
        # NOTE: .eval() only sets the mode at construction time; calling
        # .train() on this wrapper later would switch these submodules back.
        self.mapping_net = MappingNetwork(**mapping_net_params).eval()
        self.synthesis_net = SynthesisNetwork(**synthesis_net_params).eval()
        # student network: receives every teacher channel width except the last
        self.student = MobileSynthesisNetwork(
            style_dim=self.mapping_net.style_dim,
            channels=synthesis_net_params["channels"][:-1],
        )
        # Centre of the truncation interpolation, one row of style_dim values.
        # Initialized to zeros and frozen; presumably overwritten when a
        # checkpoint state dict is loaded - confirm against the loading code.
        self.style_mean = nn.Parameter(
            torch.zeros((1, mapping_net_params["style_dim"])),
            requires_grad=False,
        )

    def forward(
        self,
        var: torch.Tensor,
        truncation_psi: float = 0.5,
        generator: str = "student",
    ) -> torch.Tensor:
        """Map latent inputs to styles, apply truncation, and synthesize images.

        Args:
            var: Latent input passed to ``mapping_net``; presumably shaped
                (batch, 512) - TODO confirm against MappingNetwork.
            truncation_psi: Interpolation factor toward ``style_mean``.
                1.0 leaves styles unchanged; 0.0 replaces them with ``style_mean``.
            generator: ``"student"`` to use the MobileSynthesisNetwork, or
                ``"teacher"`` to use the full SynthesisNetwork.

        Returns:
            The ``"img"`` entry of the selected generator's output.

        Raises:
            ValueError: If ``generator`` is neither ``"student"`` nor ``"teacher"``.
        """
        # Choose the generator before any compute so an invalid name fails
        # fast with a descriptive message.
        if generator == "student":
            synthesis = self.student
        elif generator == "teacher":
            synthesis = self.synthesis_net
        else:
            raise ValueError(
                f"generator must be 'student' or 'teacher', got {generator!r}"
            )
        style = self.mapping_net(var)
        # Truncation: pull each style toward style_mean by truncation_psi.
        style = self.style_mean + truncation_psi * (style - self.style_mean)
        return synthesis(style)["img"]