Spaces:
Runtime error
Runtime error
import torch | |
from .base import BaseGenerator | |
from torchvision.transforms.functional import gaussian_blur | |
import torch.nn.functional as F | |
class PixelationGenerator(BaseGenerator): | |
def __init__(self, pixelation_size, **kwargs): | |
super().__init__(z_channels=0) | |
self.pixelation_size = pixelation_size | |
self.z_channels = 0 | |
self.latent_space = None | |
def forward(self, img, condition, mask, **kwargs): | |
old_shape = img.shape[-2:] | |
img = F.interpolate(img, size=( | |
self.pixelation_size, self.pixelation_size), mode="bilinear", align_corners=True) | |
img = F.interpolate(img, size=old_shape, mode="bilinear", align_corners=True) | |
out = img*(1-mask) + condition*mask | |
return {"img": out} | |
class MaskOutGenerator(BaseGenerator): | |
def __init__(self, noise: str, **kwargs): | |
super().__init__(z_channels=0) | |
self.noise = noise | |
self.z_channels = 0 | |
assert self.noise in ["rand", "constant"] | |
self.latent_space = None | |
def forward(self, img, condition, mask, **kwargs): | |
if self.noise == "constant": | |
img = torch.zeros_like(img) | |
elif self.noise == "rand": | |
img = torch.rand_like(img) | |
out = img*(1-mask) + condition*mask | |
return {"img": out} | |
class IdentityGenerator(BaseGenerator): | |
def __init__(self): | |
super().__init__(z_channels=0) | |
def forward(self, img, condition, mask, **kwargs): | |
return dict(img=img) | |
class GaussianBlurGenerator(BaseGenerator): | |
def __init__(self): | |
super().__init__(z_channels=0) | |
self.sigma = 7 | |
def forward(self, img, condition, mask, **kwargs): | |
img_blur = gaussian_blur(img, kernel_size=min(self.sigma*3, img.shape[-1]), sigma=self.sigma) | |
return dict(img=img * mask + (1-mask) * img_blur) | |