# NOTE: removed extraction artifacts ("Spaces:" / "Runtime error") that were not part of the source.
import os | |
from PIL import Image | |
import torchvision.transforms as transforms | |
try: | |
from torchvision.transforms import InterpolationMode | |
bic = InterpolationMode.BICUBIC | |
except ImportError: | |
bic = Image.BICUBIC | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import functools | |
# Recognized image file extensions (lowercase, dot-prefixed). Not referenced
# by the code visible in this file — presumably used by callers to filter
# input files; verify against call sites.
IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".webp"]
class UnetGenerator(nn.Module):
    """U-Net generator assembled recursively from skip-connection blocks.

    The network is built from the bottleneck outward: the innermost block is
    created first, then wrapped by successively shallower blocks until the
    outermost block (which carries the real input/output channel counts)
    encloses the whole stack.
    """

    def __init__(
        self,
        input_nc,
        output_nc,
        num_downs,
        ngf=64,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a U-Net generator.

        Parameters:
            input_nc (int)    -- number of channels in input images
            output_nc (int)   -- number of channels in output images
            num_downs (int)   -- number of downsamplings; with num_downs == 7
                                 a 128x128 image becomes 1x1 at the bottleneck
            ngf (int)         -- number of filters in the last conv layer
            norm_layer        -- normalization layer
            use_dropout (bool) -- apply dropout in the intermediate blocks
        """
        super(UnetGenerator, self).__init__()

        # Bottleneck: the deepest block has no submodule beneath it.
        block = UnetSkipConnectionBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True,
        )
        # Intermediate blocks at ngf*8 filters (optionally with dropout).
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # Taper filters back down: ngf*8 -> ngf*4 -> ngf*2 -> ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(
                ngf * mult,
                ngf * mult * 2,
                input_nc=None,
                submodule=block,
                norm_layer=norm_layer,
            )
        # Outermost block maps the true input/output channel counts.
        self.model = UnetSkipConnectionBlock(
            output_nc,
            ngf,
            input_nc=input_nc,
            submodule=block,
            outermost=True,
            norm_layer=norm_layer,
        )

    def forward(self, input):
        """Run the full U-Net on ``input``."""
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net layer with a skip connection.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a U-Net submodule with a skip connection.

        Parameters:
            outer_nc (int)  -- number of filters in the outer conv layer
            inner_nc (int)  -- number of filters in the inner conv layer
            input_nc (int)  -- number of channels in input images/features
                               (defaults to ``outer_nc`` when None)
            submodule (UnetSkipConnectionBlock) -- previously built inner blocks
            outermost (bool) -- whether this is the outermost module
            innermost (bool) -- whether this is the innermost module
            norm_layer       -- normalization layer
            use_dropout (bool) -- add a dropout layer (intermediate blocks only)
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        # InstanceNorm carries no learnable affine shift, so the convs need
        # their own bias terms; BatchNorm makes a conv bias redundant.
        norm_fn = (
            norm_layer.func
            if isinstance(norm_layer, functools.partial)
            else norm_layer
        )
        use_bias = norm_fn == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc

        down_conv = nn.Conv2d(
            input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
        )
        down_act = nn.LeakyReLU(0.2, True)
        up_act = nn.ReLU(True)

        if outermost:
            # Outermost: no norm, Tanh squashes the output to [-1, 1].
            up_conv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
            )
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            # Innermost: no submodule, so the up-conv sees only inner_nc channels.
            up_conv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            layers = [down_act, down_conv, up_act, up_conv, norm_layer(outer_nc)]
        else:
            # Intermediate: the submodule doubles channels via its skip concat,
            # so the up-conv consumes inner_nc * 2 channels.
            up_conv = nn.ConvTranspose2d(
                inner_nc * 2,
                outer_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            layers = [
                down_act,
                down_conv,
                norm_layer(inner_nc),
                submodule,
                up_act,
                up_conv,
                norm_layer(outer_nc),
            ]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # The outermost block returns its output directly; every inner block
        # concatenates its input onto its output — the skip connection.
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
class Anime2Sketch:
    """Wraps a pretrained U-Net generator that converts anime images to sketches."""

    def __init__(
        self, model_path: str = "./models/netG.pth", device: str = "cpu"
    ) -> None:
        """Load generator weights from ``model_path`` and move the model to
        ``device`` ("cuda" only honored when CUDA is actually available;
        otherwise falls back to CPU)."""
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
        net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
        # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
        ckpt = torch.load(model_path, map_location="cpu")
        for key in list(ckpt.keys()):
            # Strip DataParallel's "module." prefix so keys match the bare model.
            if "module." in key:
                # NOTE(review): .half() downcasts weights to fp16 before they
                # are copied back into fp32 parameters, losing precision —
                # kept as-is to preserve existing behavior; confirm intent.
                ckpt[key.replace("module.", "")] = ckpt[key].half()
                del ckpt[key]
        net.load_state_dict(ckpt)
        net.eval()  # inference-only: freeze dropout/norm behavior
        self.model = net
        if torch.cuda.is_available() and device == "cuda":
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.model.to(self.device)

    def predict(self, image: Image.Image, load_size: int = 512) -> Image.Image:
        """Run the sketch generator on ``image``.

        Parameters:
            image (PIL.Image.Image) -- input RGB image
            load_size (int)         -- square size the image is resized to for
                                       inference; if > 0 the result is resized
                                       back to the original dimensions
        Returns:
            PIL.Image.Image -- the generated sketch
        Raises:
            Exception -- if the image cannot be preprocessed
        """
        try:
            original_size = None
            if load_size > 0:
                original_size = image.size  # remember so we can restore it
            transform = self.get_transform(load_size=load_size)
            tensor = transform(image).unsqueeze(0)  # add batch dimension
        except Exception as exc:
            # In-memory PIL images have no .filename attribute — don't crash
            # while building the error message.
            name = getattr(image, "filename", "<in-memory image>")
            raise Exception("Error in reading image {}".format(name)) from exc
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            aus_tensor = self.model(tensor.to(self.device))
        aus_img = self.tensor_to_img(aus_tensor)
        image_pil = Image.fromarray(aus_img)
        if original_size:
            # Image.BICUBIC was removed in Pillow 10; use Image.Resampling
            # when present, falling back to the legacy constant otherwise.
            resample = getattr(Image, "Resampling", Image).BICUBIC
            image_pil = image_pil.resize(original_size, resample)
        return image_pil

    def get_transform(self, load_size=0, grayscale=False, method=bic, convert=True):
        """Build the preprocessing pipeline: optional grayscale, optional
        square resize to ``load_size``, then ToTensor + Normalize to [-1, 1]."""
        transform_list = []
        if grayscale:
            transform_list.append(transforms.Grayscale(1))
        if load_size > 0:
            osize = [load_size, load_size]
            transform_list.append(transforms.Resize(osize, method))
        if convert:
            transform_list += [transforms.ToTensor()]
            if grayscale:
                transform_list += [transforms.Normalize((0.5,), (0.5,))]
            else:
                transform_list += [
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]
        return transforms.Compose(transform_list)

    def tensor_to_img(self, input_image, imtype=np.uint8):
        """Convert a Tensor array into a numpy image array.

        Parameters:
            input_image (tensor) -- the input image tensor array
            imtype (type)        -- the desired type of the converted numpy array
        """
        if not isinstance(input_image, np.ndarray):
            if isinstance(input_image, torch.Tensor):  # get the data from a variable
                image_tensor = input_image.data
            else:
                return input_image
            # first image of the batch, as a float numpy array (C, H, W)
            image_numpy = image_tensor[0].cpu().float().numpy()
            if image_numpy.shape[0] == 1:  # grayscale to RGB
                image_numpy = np.tile(image_numpy, (3, 1, 1))
            # (C, H, W) -> (H, W, C); rescale from [-1, 1] to [0, 255]
            image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        else:  # already a numpy array, do nothing
            image_numpy = input_image
        return image_numpy.astype(imtype)