import torch
import numpy as np
import tqdm
import tops
from ..layers import Module
from ..layers.sg2_layers import FullyConnectedLayer


class BaseGenerator(Module):

    def __init__(self, z_channels: int):
        super().__init__()
        self.z_channels = z_channels
        self.latent_space = "Z"
    def get_z(
            self,
            x: torch.Tensor = None,
            z: torch.Tensor = None,
            truncation_value: float = None,
            batch_size: int = None,
            dtype=None, device=None) -> torch.Tensor:
        """Generates a latent variable for the generator.

        If neither x nor z is given, batch_size must be set.
        truncation_value=0 returns the zero (mean) latent; any other value
        rejection-samples until all entries lie within
        [-truncation_value, truncation_value].
        See the usage sketch after this class.
        """
        if z is not None:
            return z
        if x is not None:
            batch_size = x.shape[0]
            dtype = x.dtype
            device = x.device
        if device is None:
            device = tops.get_device()
        if truncation_value == 0:
            return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype)
        z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype)
        if truncation_value is None:
            return z
        # Rejection sampling: redraw out-of-range entries from a standard
        # normal until every entry lies within the truncation bound.
        while z.abs().max() > truncation_value:
            m = z.abs() > truncation_value
            z[m] = torch.randn_like(z)[m]
        return z

    def sample(self, truncation_value, z=None, **kwargs):
        """
        Samples by linearly interpolating the latent towards the mean (0).
        truncation_value is clamped to [0, 1]: 0 yields the mean latent,
        1 leaves z unchanged.
        """
        if truncation_value is None:
            return self.forward(**kwargs)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        if z is None:
            z = self.get_z(kwargs["condition"])
        z = z * truncation_value
        return self.forward(**kwargs, z=z)
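

# Illustrative usage sketch, not part of the original module: demonstrates the
# two Z-space truncation modes above (rejection sampling in get_z, linear
# scaling towards the zero mean in sample). DemoGenerator and its identity
# forward() are hypothetical stand-ins for a real generator; this assumes
# ..layers.Module behaves like torch.nn.Module.
def _demo_z_truncation():
    class DemoGenerator(BaseGenerator):
        def forward(self, z=None, **kwargs):
            return z  # a real generator would decode z into an image

    gen = DemoGenerator(z_channels=512)
    # Every entry of z is rejection-sampled into [-1.5, 1.5].
    z = gen.get_z(batch_size=4, truncation_value=1.5)
    # The latent is then scaled halfway towards the zero mean.
    return gen.sample(truncation_value=0.5, z=z)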


class SG2StyleNet(torch.nn.Module):

    def __init__(self,
                 z_dim,               # Input latent (Z) dimensionality.
                 w_dim,               # Intermediate latent (W) dimensionality.
                 num_layers=2,        # Number of mapping layers.
                 lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
                 w_avg_beta=0.998,    # Decay for tracking the moving average of W during training.
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        # Construct the mapping layers: z_dim -> w_dim -> ... -> w_dim.
        features = [self.z_dim] + [self.w_dim] * self.num_layers
        for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, update_emas=False, **kwargs):
        tops.assert_shape(z, [None, self.z_dim])
        # Normalize the input to unit RMS per sample.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        # Execute the mapping layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)
        # Update the moving average of W.
        if update_emas:
            self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'

    def update_w(self, n=int(1e4), batch_size=32):
        """
        Re-estimate w_avg over n randomly sampled latents.
        Useful in cases where w_avg was tracked incorrectly during training.
        """
        n_batches = n // batch_size
        for _ in tqdm.trange(n_batches, desc="Updating w"):
            z = torch.randn((batch_size, self.z_dim), device=tops.get_device())
            self(z, update_emas=True)

    def get_truncated(self, truncation_value, condition, z=None, **kwargs):
        """Truncates W by interpolating from the tracked average w_avg
        towards the mapped w; truncation_value is clamped to [0, 1]."""
        if z is None:
            z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device())
        w = self(z)
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        return self.w_avg.to(w.dtype).lerp(w, truncation_value)

    def multi_modal_truncate(self, truncation_value, condition, w_indices, z=None, **kwargs):
        """Truncates W towards one of several cluster centers instead of the
        single global average. Expects self.w_centers (a tensor of W-space
        cluster centers) to be assigned externally; it is not defined here."""
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        if z is None:
            z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device())
        w = self(z)
        if w_indices is None:
            # Pick a random center per sample when none are specified.
            w_indices = np.random.randint(0, len(self.w_centers), size=len(w))
        w_centers = self.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return w
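

# Illustrative usage sketch, not part of the original module: builds a small
# mapping network, re-estimates w_avg from random latents, and applies W-space
# truncation. The dimensions and truncation strength are arbitrary choices.
def _demo_w_truncation():
    net = SG2StyleNet(z_dim=64, w_dim=64).to(tops.get_device())
    net.update_w(n=1024, batch_size=64)  # populate w_avg from random samples
    condition = torch.zeros(4, 1)  # only condition.shape[0] is used below
    # truncation_value=0.7: move 70% of the way from w_avg towards each w.
    return net.get_truncated(0.7, condition)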


class BaseStyleGAN(BaseGenerator):

    def __init__(self, z_channels: int, w_dim: int):
        super().__init__(z_channels)
        self.style_net = SG2StyleNet(z_channels, w_dim)
        self.latent_space = "W"

    def get_w(self, z, update_emas):
        return self.style_net(z, update_emas=update_emas)

    def sample(self, truncation_value, **kwargs):
        """Samples via truncation in W space rather than Z space."""
        if truncation_value is None:
            return self.forward(**kwargs)
        w = self.style_net.get_truncated(truncation_value, **kwargs)
        return self.forward(**kwargs, w=w)

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        w = self.style_net.multi_modal_truncate(truncation_value, w_indices=w_indices, **kwargs)
        return self.forward(**kwargs, w=w)
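

# Illustrative usage sketch, not part of the original module: exercises the
# W-space truncation path on a minimal BaseStyleGAN subclass. DemoStyleGAN and
# its identity forward() are hypothetical stand-ins for a real synthesis
# network; this again assumes ..layers.Module behaves like torch.nn.Module.
def _demo_stylegan_sampling():
    class DemoStyleGAN(BaseStyleGAN):
        def forward(self, condition=None, w=None, **kwargs):
            return w  # a real generator would synthesize an image from w

    gen = DemoStyleGAN(z_channels=64, w_dim=64).to(tops.get_device())
    condition = torch.zeros(2, 1, device=tops.get_device())
    # Maps a fresh z to W, then interpolates halfway from w_avg towards w.
    return gen.sample(truncation_value=0.5, condition=condition)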