Spaces:
Runtime error
Runtime error
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
import os | |
from collections import OrderedDict | |
from torch.autograd import Variable | |
from options.test_options import TestOptions | |
from models.models import create_model | |
from models.mapping_model import Pix2PixHDModel_Mapping | |
import util.util as util | |
from PIL import Image | |
import torch | |
import torchvision.utils as vutils | |
import torchvision.transforms as transforms | |
import numpy as np | |
import cv2 | |
def data_transforms(img, method=Image.BILINEAR, scale=False, short_side=256, multiple=4):
    """Resize ``img`` so both sides are multiples of ``multiple``.

    Args:
        img: PIL image to transform.
        method: PIL resampling filter passed to ``img.resize``.
        scale: if truthy, first rescale so the shorter side equals
            ``short_side`` (aspect ratio preserved).
        short_side: target length of the shorter side when ``scale`` is set.
        multiple: each output side is rounded to the nearest multiple of this
            value (never below ``multiple`` itself, so tiny images cannot
            collapse to a zero-sized resize).

    Returns:
        ``img`` itself when no resize is needed, otherwise a resized copy.
    """
    orig_w, orig_h = img.size
    target_w, target_h = float(orig_w), float(orig_h)
    if scale:
        if orig_w < orig_h:
            target_w = short_side
            target_h = orig_h / orig_w * short_side
        else:
            target_h = short_side
            target_w = orig_w / orig_h * short_side

    # Snap to the nearest multiple; clamp so a side can never round to 0,
    # which would make PIL's resize fail.
    new_w = max(multiple, int(round(target_w / multiple) * multiple))
    new_h = max(multiple, int(round(target_h / multiple) * multiple))

    if new_w == orig_w and new_h == orig_h:
        return img
    return img.resize((new_w, new_h), method)
def data_transforms_rgb_old(img):
    """Return a 256x256 center crop of ``img``, upscaling small images first.

    If either side of ``img`` is below 256 px, the image is first resized so
    that its shorter side becomes 256 px (aspect ratio preserved), which
    guarantees the 256x256 crop fits.

    Args:
        img: PIL image.

    Returns:
        The 256x256 center-cropped PIL image.
    """
    w, h = img.size
    if w < 256 or h < 256:
        # ``transforms.Scale`` was only a deprecated alias of
        # ``transforms.Resize`` and is gone from current torchvision, so call
        # Resize directly. Prefer the InterpolationMode enum when the
        # installed torchvision provides it; otherwise use the PIL constant.
        if hasattr(transforms, "InterpolationMode"):
            interpolation = transforms.InterpolationMode.BILINEAR
        else:
            interpolation = Image.BILINEAR
        img = transforms.Resize(256, interpolation)(img)
    return transforms.CenterCrop(256)(img)
def irregular_hole_synthesize(img, mask):
    """Paint the masked region of ``img`` white.

    Each pixel is blended as ``img * (1 - m) + 255 * m`` with ``m = mask / 255``.
    A mask value of 255 therefore yields pure white, 0 keeps the original
    pixel, and intermediate values blend linearly.

    Args:
        img: image convertible with ``np.array`` (e.g. a PIL RGB image).
        mask: image/array with the same height and width as ``img``. It may
            have the same channel count as ``img`` or be single-channel
            (H, W); a single-channel mask is applied to every channel.

    Returns:
        A PIL image in RGB mode.
    """
    img_np = np.array(img).astype("uint8")
    mask_np = np.array(mask).astype("uint8") / 255
    # A grayscale (H, W) mask cannot broadcast against an (H, W, C) image;
    # give it a trailing channel axis so it applies to all channels.
    if mask_np.ndim == 2 and img_np.ndim == 3:
        mask_np = mask_np[:, :, np.newaxis]
    img_new = img_np * (1 - mask_np) + mask_np * 255
    return Image.fromarray(img_new.astype("uint8")).convert("RGB")
def parameter_set(opt):
    """Fill ``opt`` in place with the fixed settings of the restoration model.

    Shared network defaults are always applied. The mode flags read from
    ``opt`` (``Quality_restore``, ``Scratch_and_Quality_restore`` and, for
    scratch restoration only, ``HR``) then select the checkpoint name and
    pretrained VAE paths, along with any extra mapping-network options.
    """

    def assign(**settings):
        # Copy each keyword onto ``opt`` in the order given.
        for key, value in settings.items():
            setattr(opt, key, value)

    ## Default parameters
    assign(
        serial_batches=True,  # no shuffle
        no_flip=True,  # no flip
        label_nc=0,
        n_downsample_global=3,
        mc=64,
        k_size=4,
        start_r=1,
        mapping_n_block=6,
        map_mc=512,
        no_instance=True,
        checkpoints_dir="./checkpoints/restoration",
    )

    if opt.Quality_restore:
        assign(
            name="mapping_quality",
            load_pretrainA=os.path.join(opt.checkpoints_dir, "VAE_A_quality"),
            load_pretrainB=os.path.join(opt.checkpoints_dir, "VAE_B_quality"),
        )

    if opt.Scratch_and_Quality_restore:
        assign(
            NL_res=True,
            use_SN=True,
            correlation_renormalize=True,
            NL_use_mask=True,
            NL_fusion_method="combine",
            non_local="Setting_42",
            name="mapping_scratch",
            load_pretrainA=os.path.join(opt.checkpoints_dir, "VAE_A_quality"),
            load_pretrainB=os.path.join(opt.checkpoints_dir, "VAE_B_scratch"),
        )
        # High-resolution scratch restoration swaps in the patch-attention
        # mapping checkpoint and enables mask dilation.
        if opt.HR:
            assign(
                mapping_exp=1,
                inference_optimize=True,
                mask_dilation=3,
                name="mapping_Patch_Attention",
            )
if __name__ == "__main__":
    # Script entry point: restore every image in opt.test_input and write the
    # network input, the restored result and the unmodified original into
    # three sub-folders of opt.outputs_dir.
    opt = TestOptions().parse(save=False)
    parameter_set(opt)

    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)
    model.eval()

    for sub_dir in ("input_image", "restored_image", "origin"):
        os.makedirs(os.path.join(opt.outputs_dir, sub_dir), exist_ok=True)

    input_loader = sorted(os.listdir(opt.test_input))
    dataset_size = len(input_loader)

    mask_loader = []
    if opt.test_mask != "":
        mask_loader = sorted(os.listdir(opt.test_mask))
        # Inputs and masks are paired by sorted position. Stop at the shorter
        # listing instead of indexing past the end of the other one.
        if len(mask_loader) != len(input_loader):
            print(
                "Warning: %d input entries but %d mask entries; only the first %d pairs are processed"
                % (len(input_loader), len(mask_loader), min(len(input_loader), len(mask_loader)))
            )
        dataset_size = min(len(input_loader), len(mask_loader))

    if opt.NL_use_mask and opt.test_mask == "":
        # Fail fast with a clear message rather than a NameError mid-loop.
        raise ValueError(
            "Mask-guided restoration (NL_use_mask) needs a mask directory, but opt.test_mask is empty."
        )

    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()
    dilation_kernel = np.ones((3, 3), np.uint8)  # loop-invariant

    for i in range(dataset_size):
        input_name = input_loader[i]
        input_file = os.path.join(opt.test_input, input_name)
        if not os.path.isfile(input_file):
            print("Skipping non-file %s" % input_name)
            continue
        # convert() returns an independent image, so the file can be closed.
        with Image.open(input_file) as image_file:
            input_image = image_file.convert("RGB")

        print("Now you are processing %s" % (input_name))

        if opt.NL_use_mask:
            mask_name = mask_loader[i]
            with Image.open(os.path.join(opt.test_mask, mask_name)) as mask_file:
                mask = mask_file.convert("RGB")
            if opt.mask_dilation != 0:
                dilated = cv2.dilate(np.array(mask), dilation_kernel, iterations=opt.mask_dilation)
                mask = Image.fromarray(dilated.astype("uint8"))
            origin = input_image
            input_image = irregular_hole_synthesize(input_image, mask)
            # Keep a single channel and add a batch dimension.
            mask = mask_transform(mask)[:1, :, :].unsqueeze(0)
            input_tensor = img_transform(input_image).unsqueeze(0)
        else:
            if opt.test_mode == "Scale":
                input_image = data_transforms(input_image, scale=True)
            elif opt.test_mode == "Full":
                input_image = data_transforms(input_image, scale=False)
            elif opt.test_mode == "Crop":
                input_image = data_transforms_rgb_old(input_image)
            origin = input_image
            input_tensor = img_transform(input_image).unsqueeze(0)
            # No holes to fill: an all-zero mask shaped like the input.
            mask = torch.zeros_like(input_tensor)

        try:
            with torch.no_grad():
                generated = model.inference(input_tensor, mask)
        except Exception as ex:
            # Best-effort batch run: report and move on to the next image.
            print("Skip %s due to an error:\n%s" % (input_name, str(ex)))
            continue

        # Save JPEG inputs as PNG so the outputs are not re-compressed lossily.
        stem, ext = os.path.splitext(input_name)
        if ext.lower() in (".jpg", ".jpeg"):
            input_name = stem + ".png"

        # Tensors are in [-1, 1] after Normalize; map back to [0, 1].
        vutils.save_image(
            (input_tensor + 1.0) / 2.0,
            os.path.join(opt.outputs_dir, "input_image", input_name),
            nrow=1,
            padding=0,
            normalize=True,
        )
        vutils.save_image(
            (generated.data.cpu() + 1.0) / 2.0,
            os.path.join(opt.outputs_dir, "restored_image", input_name),
            nrow=1,
            padding=0,
            normalize=True,
        )

        origin.save(os.path.join(opt.outputs_dir, "origin", input_name))