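# NOTE(review): the next three lines look like web-page residue (apparently an
# avatar caption, a commit message and a revision id); kept as comments so the
# module parses.
# haakohu's picture
# fix
# 44539fc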
import torch
import numpy as np
import tqdm
import tops
from ..layers import Module
from ..layers.sg2_layers import FullyConnectedLayer
class BaseGenerator(Module):
    """Base class for generators that are driven by a latent vector ``z``.

    Stores the latent dimensionality and offers helpers for drawing
    (optionally truncated) latent vectors and for sampling from the
    generator with a truncation value.
    """

    def __init__(self, z_channels: int):
        """Store the latent dimensionality.

        Args:
            z_channels: Number of channels of the latent vector ``z``.
        """
        super().__init__()
        self.z_channels = z_channels
        # Name of the latent space this generator samples in.
        self.latent_space = "Z"

    @torch.no_grad()
    def get_z(
            self,
            x: torch.Tensor = None,
            z: torch.Tensor = None,
            truncation_value: float = None,
            batch_size: int = None,
            dtype=None, device=None) -> torch.Tensor:
        """Generates a latent variable for generator.

        Args:
            x: Optional reference tensor. When given, its first dimension,
                dtype and device override ``batch_size``, ``dtype`` and
                ``device``.
            z: Optional pre-computed latent; returned unchanged when given.
            truncation_value: ``None`` draws from a standard normal. ``0``
                returns an all-zero latent. A positive value redraws every
                element whose magnitude exceeds it until none remain.
            batch_size: Number of latents to draw when ``x`` is not given.
            dtype: dtype of the result when ``x`` is not given.
            device: Device of the result when ``x`` is not given; falls back
                to ``tops.get_device()``.

        Returns:
            Tensor of shape ``(batch_size, z_channels)``, or ``z`` itself when
            it was supplied.

        Raises:
            ValueError: If ``truncation_value`` is negative (no sample could
                ever satisfy the bound).
        """
        if z is not None:
            return z
        if x is not None:
            batch_size = x.shape[0]
            dtype = x.dtype
            device = x.device
        if device is None:
            device = tops.get_device()
        shape = (batch_size, self.z_channels)
        if truncation_value is None:
            return torch.randn(shape, device=device, dtype=dtype)
        if truncation_value < 0:
            # |z| > negative bound is always true, so resampling would never end.
            raise ValueError(
                f"truncation_value must be non-negative, got {truncation_value}")
        if truncation_value == 0:
            return torch.zeros(shape, device=device, dtype=dtype)
        z = torch.randn(shape, device=device, dtype=dtype)
        # Rejection sampling: redraw out-of-range elements from the same
        # normal distribution so the result is a truncated normal.
        # Using a mask with .any() also keeps an empty batch from failing.
        outside = z.abs() > truncation_value
        while outside.any():
            z[outside] = torch.randn_like(z)[outside]
            outside = z.abs() > truncation_value
        return z

    def sample(self, truncation_value, z=None, **kwargs):
        """
        Samples via interpolating to the mean (0).

        Args:
            truncation_value: ``None`` disables truncation. Otherwise the
                value is clamped to ``[0, 1]`` and multiplied onto ``z``.
            z: Optional latent. When truncating and ``z`` is missing, one is
                drawn with ``get_z`` using ``kwargs["condition"]`` as the
                reference tensor.
            **kwargs: Forwarded to ``forward``.

        Returns:
            Whatever ``forward`` returns.
        """
        if truncation_value is None:
            # Keep a caller-supplied latent instead of silently dropping it.
            if z is None:
                return self.forward(**kwargs)
            return self.forward(**kwargs, z=z)
        truncation_value = min(max(0, truncation_value), 1)
        if z is None:
            z = self.get_z(kwargs["condition"])
        z = z * truncation_value
        return self.forward(**kwargs, z=z)
class SG2StyleNet(torch.nn.Module):
    """Mapping network from the input latent (Z) to the intermediate latent (W).

    Consists of ``num_layers`` ``FullyConnectedLayer`` modules stored as
    ``fc0``, ``fc1``, ... and a ``w_avg`` buffer holding a moving average of
    the produced W vectors, which is used for truncation.
    """

    def __init__(self,
                 z_dim,  # Input latent (Z) dimensionality.
                 w_dim,  # Intermediate latent (W) dimensionality.
                 num_layers=2,  # Number of mapping layers.
                 lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
                 w_avg_beta=0.998,  # Decay for tracking the moving average of W during training.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Construct layers: z_dim -> w_dim -> ... -> w_dim.
        features = [self.z_dim] + [self.w_dim] * self.num_layers
        for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def _random_z(self, batch_size):
        """Draw ``batch_size`` standard-normal latents on this module's device."""
        # w_avg is a registered buffer, so it follows the module across
        # .to()/.cuda() calls; this keeps z on the same device as the layers
        # instead of relying on a global default device.
        return torch.randn((batch_size, self.z_dim), device=self.w_avg.device)

    def forward(self, z, update_emas=False, **kwargs):
        """Map ``z`` to ``w``.

        Args:
            z: Latent tensor of shape ``(batch, z_dim)``.
            update_emas: When true, fold the batch mean of the output into
                the ``w_avg`` buffer using ``w_avg_beta`` as decay.
            **kwargs: Accepted and ignored.

        Returns:
            The output of the last mapping layer.
        """
        tops.assert_shape(z, [None, self.z_dim])

        # Cast to float32 and scale each row to unit second moment.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()

        # Execute layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        # Update moving average of W:
        # w_avg <- beta * w_avg + (1 - beta) * batch_mean.
        if update_emas:
            self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
        return x

    def extra_repr(self):
        """Return the dimensions shown in the module's ``repr``."""
        return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'

    @torch.no_grad()
    def update_w(self, n=int(10e3), batch_size=32):
        """
        Calculate w_ema from roughly n random latents.
        Useful in cases where w_ema is calculated incorrectly during training.

        Args:
            n: Total number of latents to draw; ``n // batch_size`` batches
                are run, so nothing happens when ``n < batch_size``.
            batch_size: Number of latents per forward pass.
        """
        num_batches = n // batch_size
        # Gradients are never needed here; only the w_avg buffer is updated.
        for _ in tqdm.trange(num_batches, desc="Updating w"):
            self(self._random_z(batch_size), update_emas=True)

    def get_truncated(self, truncation_value, condition, z=None, **kwargs):
        """Return W vectors interpolated towards ``w_avg``.

        Args:
            truncation_value: Clamped to ``[0, 1]``; ``0`` yields ``w_avg``
                and ``1`` yields the untruncated ``w``.
            condition: Only its first dimension is read, as the batch size
                when ``z`` has to be drawn.
            z: Optional latent; drawn from a standard normal when missing.
            **kwargs: Accepted and ignored.
        """
        if z is None:
            z = self._random_z(condition.shape[0])
        w = self(z)
        truncation_value = min(max(0, truncation_value), 1)
        return self.w_avg.to(w.dtype).lerp(w, truncation_value)

    def multi_modal_truncate(self, truncation_value, condition, w_indices, z=None, **kwargs):
        """Return W vectors interpolated towards per-sample centers.

        Works like ``get_truncated`` but interpolates towards rows of
        ``self.w_centers`` instead of ``w_avg``.

        NOTE(review): ``w_centers`` is not created anywhere in this class;
        it presumably has to be assigned by the caller beforehand - confirm.

        Args:
            truncation_value: Clamped to ``[0, 1]``; ``0`` yields the chosen
                center and ``1`` yields the untruncated ``w``.
            condition: Only its first dimension is read, as the batch size
                when ``z`` has to be drawn.
            w_indices: Index into ``w_centers`` for every sample, or ``None``
                to pick a random center per sample.
            z: Optional latent; drawn from a standard normal when missing.
            **kwargs: Accepted and ignored.
        """
        truncation_value = min(max(0, truncation_value), 1)
        if z is None:
            z = self._random_z(condition.shape[0])
        w = self(z)
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.w_centers), size=(len(w)))
        w_centers = self.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return w
class BaseStyleGAN(BaseGenerator):
    """Generator base class that samples in W space through an ``SG2StyleNet``."""

    def __init__(self, z_channels: int, w_dim: int):
        """Create the style network mapping ``z_channels`` to ``w_dim``.

        Args:
            z_channels: Dimensionality of the input latent ``z``.
            w_dim: Dimensionality of the intermediate latent ``w``.
        """
        super().__init__(z_channels)
        self.style_net = SG2StyleNet(z_channels, w_dim)
        self.latent_space = "W"

    def get_w(self, z, update_emas):
        """Map ``z`` to ``w`` with the style network.

        Args:
            z: Latent passed to the style network.
            update_emas: Forwarded to the style network; when true it updates
                its moving average of ``w``.
        """
        return self.style_net(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        """Run ``forward`` with an optionally truncated ``w``.

        Args:
            truncation_value: ``None`` calls ``forward`` with ``kwargs``
                only. Otherwise ``w`` comes from the style network's
                ``get_truncated`` and is passed on as ``w``.
            **kwargs: Forwarded to both ``get_truncated`` and ``forward``.
        """
        if truncation_value is None:
            return self.forward(**kwargs)
        w = self.style_net.get_truncated(truncation_value, **kwargs)
        return self.forward(**kwargs, w=w)

    @torch.no_grad()
    def update_w(self, *args, **kwargs):
        """Delegate to the style network's ``update_w``.

        Runs without gradient tracking, like the other inference helpers of
        this class.
        """
        self.style_net.update_w(*args, **kwargs)

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        """Run ``forward`` with ``w`` from the style network's ``multi_modal_truncate``.

        Args:
            truncation_value: Forwarded to the style network.
            w_indices: Forwarded to the style network; may be ``None``.
            **kwargs: Forwarded to both the style network and ``forward``.
        """
        w = self.style_net.multi_modal_truncate(truncation_value, w_indices=w_indices, **kwargs)
        return self.forward(**kwargs, w=w)