# deep_privacy2_face / dp2/generator/dummy_generators.py
# Author: haakohu — commit 5d756f1 ("initial")
import torch
from .base import BaseGenerator
from torchvision.transforms.functional import gaussian_blur
import torch.nn.functional as F
class PixelationGenerator(BaseGenerator):
    """Replace the region to anonymize with a low-resolution copy of ``img``.

    ``img`` is resized to ``pixelation_size`` x ``pixelation_size`` and then
    back to its original spatial size. Both resizes use bilinear
    interpolation with ``align_corners=True``. Because the upsampling is
    bilinear, the result is a smoothed low-resolution image rather than
    hard square blocks.

    Args:
        pixelation_size: Side length of the intermediate square resolution.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, pixelation_size, **kwargs):
        super().__init__(z_channels=0)
        self.latent_space = None
        self.z_channels = 0
        self.pixelation_size = pixelation_size

    def forward(self, img, condition, mask, **kwargs):
        """Composite the pixelated ``img`` with ``condition``.

        Returns:
            A dict with key ``"img"``: a per-pixel blend weighted by ``mask``.
            Where ``mask`` is 1 the value comes from ``condition``; where it
            is 0 the value comes from the pixelated ``img``.
        """
        resize_kwargs = dict(mode="bilinear", align_corners=True)
        original_hw = img.shape[-2:]
        side = self.pixelation_size
        low_res = F.interpolate(img, size=(side, side), **resize_kwargs)
        pixelated = F.interpolate(low_res, size=original_hw, **resize_kwargs)
        keep = condition * mask
        replaced = pixelated * (1 - mask)
        return {"img": replaced + keep}
class MaskOutGenerator(BaseGenerator):
    """Fill the region to anonymize with zeros or uniform random noise.

    Where ``mask`` is 1 the output takes ``condition``. Where ``mask`` is 0
    it takes a fill tensor shaped like ``img``:

    * ``noise="constant"`` fills with ``torch.zeros_like(img)``.
    * ``noise="rand"`` fills with ``torch.rand_like(img)``, which is uniform
      on [0, 1).

    Args:
        noise: Either ``"rand"`` or ``"constant"``.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``noise`` is not one of the supported values.
    """

    def __init__(self, noise: str, **kwargs):
        super().__init__(z_channels=0)
        # Validate with an explicit raise rather than ``assert``. Assertions
        # are stripped under ``python -O``. An unknown mode would then make
        # ``forward`` skip the fill and composite the caller's ``img``
        # unmodified.
        if noise not in ("rand", "constant"):
            raise ValueError(
                f"noise must be 'rand' or 'constant', got {noise!r}")
        self.noise = noise
        self.z_channels = 0
        self.latent_space = None

    def forward(self, img, condition, mask, **kwargs):
        """Composite the noise/constant fill with ``condition`` using ``mask``.

        Returns:
            A dict with key ``"img"`` holding ``fill * (1 - mask) + condition * mask``.

        Raises:
            ValueError: If ``self.noise`` was changed to an unsupported value
                after construction.
        """
        if self.noise == "constant":
            fill = torch.zeros_like(img)
        elif self.noise == "rand":
            fill = torch.rand_like(img)
        else:
            # Never fall back to the unmodified input image.
            raise ValueError(f"Unsupported noise type: {self.noise!r}")
        out = fill * (1 - mask) + condition * mask
        return {"img": out}
class IdentityGenerator(BaseGenerator):
    """Generator that returns the input image unchanged.

    ``condition`` and ``mask`` are accepted for interface compatibility but
    are not used.
    """

    def __init__(self, **kwargs):
        # Accept and ignore extra keyword arguments, as the other generators
        # in this module do, so all of them can be built from the same config.
        super().__init__(z_channels=0)

    def forward(self, img, condition, mask, **kwargs):
        """Return ``{"img": img}`` without modification."""
        return dict(img=img)
class GaussianBlurGenerator(BaseGenerator):
    """Blur the region to anonymize with a Gaussian filter.

    Where ``mask`` is 1 the original ``img`` is kept. Where ``mask`` is 0 the
    Gaussian-blurred ``img`` is used.

    Args:
        sigma: Standard deviation of the Gaussian. Defaults to 7, the value
            previously hard-coded.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, sigma=7, **kwargs):
        super().__init__(z_channels=0)
        self.sigma = sigma

    def forward(self, img, condition, mask, **kwargs):
        """Return ``{"img": img * mask + (1 - mask) * blur(img)}``.

        The kernel size is about ``3 * sigma`` (21 for the default sigma).
        It is clamped to the image's height and width and forced to be odd.
        """
        height, width = img.shape[-2], img.shape[-1]
        # Clamp to BOTH spatial dims. The blur pads the image by
        # kernel_size // 2 on each side, and that padding must fit the height
        # as well as the width. Clamping only to the width failed on images
        # shorter than the kernel.
        kernel_size = min(int(self.sigma * 3), height, width)
        # torchvision only accepts odd, positive kernel sizes. Clamping to an
        # even image size (e.g. 16) used to produce an invalid kernel.
        if kernel_size % 2 == 0:
            kernel_size -= 1
        kernel_size = max(kernel_size, 1)
        img_blur = gaussian_blur(img, kernel_size=kernel_size, sigma=self.sigma)
        return dict(img=img * mask + (1 - mask) * img_blur)