Spaces:
Runtime error
Runtime error
import torch | |
import tops | |
def r1_regularization(
        real_img, real_score, mask, lambd: float, lazy_reg_interval: int,
        lazy_regularization: bool,
        scaler: torch.cuda.amp.GradScaler, mask_out: bool,
        mask_out_scale: bool,
        **kwargs
):
    """Compute the per-sample R1 gradient penalty on real images.

    The penalty is the squared L2 norm of d(real_score)/d(real_img), summed
    over dims 1-3. The gradient is taken with ``create_graph=True`` so the
    penalty itself can be back-propagated. ``real_img`` must require grad and
    ``real_score`` must have been computed from it.

    Args:
        real_img: Real input batch; dims 1-3 are reduced, so a 4-D
            (batch, ...) tensor is assumed.
        real_score: Discriminator output for ``real_img``.
        mask: Only read when ``mask_out`` is True. Gradients are multiplied
            by ``1 - mask``, so positions where ``mask == 1`` are excluded.
        lambd: Penalty weight (gamma). The effective weight is ``lambd / 2``,
            multiplied by ``lazy_reg_interval`` under lazy regularization.
        lazy_reg_interval: Steps between regularization steps. Only used
            when ``lazy_regularization`` is True.
        lazy_regularization: Whether the penalty is applied lazily (every
            ``lazy_reg_interval`` steps) and must be scaled up accordingly.
        scaler: AMP grad scaler. ``real_score`` is scaled before
            differentiation and the gradient is divided by
            ``scaler.get_scale()`` afterwards.
        mask_out: If True, zero the gradient where ``mask == 1``.
        mask_out_scale: With ``mask_out``, rescale each sample's penalty by
            ``total_pixels / sum(1 - mask)``.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(weighted, raw)``. ``weighted`` is the per-sample penalty
        multiplied by the effective weight. ``raw`` is the unweighted
        per-sample penalty, detached from the graph.
    """
    grad = torch.autograd.grad(
        outputs=scaler.scale(real_score),
        inputs=real_img,
        grad_outputs=torch.ones_like(real_score),
        create_graph=True,
        only_inputs=True,
    )[0]
    # The outputs were multiplied by the scaler's current scale factor. Undo it
    # to recover the true gradient before squaring.
    inv_scale = 1.0 / scaler.get_scale()
    grad = grad * inv_scale
    with torch.cuda.amp.autocast(tops.AMP()):
        if mask_out:
            grad = grad * (1 - mask)
        grad = grad.square().sum(dim=[1, 2, 3])
        if mask_out and mask_out_scale:
            # Compensate for the masked-out positions:
            # multiply by total elements / elements kept.
            # NOTE(review): total_pixels counts C*H*W of real_img, while
            # n_fake sums over mask's own dims 1-3. If mask has a single
            # channel (e.g. (B, 1, H, W)), the scaling is C times the inverse
            # kept fraction. Confirm this is intended.
            total_pixels = real_img.shape[1] * real_img.shape[2] * real_img.shape[3]
            n_fake = (1 - mask).sum(dim=[1, 2, 3])
            raw_scaling = total_pixels / n_fake
            # A sample with n_fake == 0 would get inf scaling, and its
            # (zero, assuming mask values lie in [0, 1]) masked penalty would
            # become 0 * inf = NaN, poisoning the whole batch loss. Use a
            # scaling of 0 for such samples instead.
            scaling = torch.where(
                n_fake > 0, raw_scaling, torch.zeros_like(raw_scaling))
            grad = grad * scaling
    # StyleGAN2 convention: R1 is weighted by gamma / 2. With lazy
    # regularization the penalty is applied only every `lazy_reg_interval`
    # steps, so it is multiplied by the interval to keep the same overall
    # strength. Without lazy regularization the interval is 1. This also
    # guarantees lambd_ is always bound.
    interval = lazy_reg_interval if lazy_regularization else 1
    lambd_ = lambd * interval / 2
    return grad * lambd_, grad.detach()