# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
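
# Inference script for the restoration stage: loads a Pix2PixHDModel_Mapping
# network, walks every image under opt.test_input (optionally paired with
# scratch masks from opt.test_mask), and writes the input, restored, and
# original images under opt.outputs_dir.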
import os

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image

from models.mapping_model import Pix2PixHDModel_Mapping
from options.test_options import TestOptions
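

# data_transforms: resize helper. Rounds both sides to the nearest multiple
# of 4 (presumably so the sizes survive the network's downsampling stages);
# with scale=True the shorter side is first brought to 256 pixels while
# keeping the aspect ratio.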
def data_transforms(img, method=Image.BILINEAR, scale=False):
    ow, oh = img.size
    pw, ph = ow, oh
    if scale:
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256

    h = int(round(oh / 4) * 4)
    w = int(round(ow / 4) * 4)

    if (h == ph) and (w == pw):
        return img

    return img.resize((w, h), method)


def data_transforms_rgb_old(img):
    w, h = img.size
    A = img
    if w < 256 or h < 256:
        # transforms.Scale was removed from torchvision; Resize is its replacement.
        A = transforms.Resize(256, Image.BILINEAR)(img)

    return transforms.CenterCrop(256)(A)
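

# irregular_hole_synthesize: paint the masked (scratch) regions of the image
# white, producing the hole image that is fed to the model together with the
# mask.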
def irregular_hole_synthesize(img, mask):
    img_np = np.array(img).astype("uint8")
    mask_np = np.array(mask).astype("uint8")
    mask_np = mask_np / 255
    img_new = img_np * (1 - mask_np) + mask_np * 255
    hole_img = Image.fromarray(img_new.astype("uint8")).convert("RGB")

    return hole_img
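

# parameter_set: hard-code the network hyper-parameters and pick the
# checkpoint that matches the requested restoration mode (quality only,
# scratch + quality, or high-resolution patch attention).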
def parameter_set(opt):
    ## Default parameters
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.label_nc = 0
    opt.n_downsample_global = 3
    opt.mc = 64
    opt.k_size = 4
    opt.start_r = 1
    opt.mapping_n_block = 6
    opt.map_mc = 512
    opt.no_instance = True
    opt.checkpoints_dir = "./checkpoints/restoration"
    ##

    if opt.Quality_restore:
        opt.name = "mapping_quality"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_quality")
    if opt.Scratch_and_Quality_restore:
        opt.NL_res = True
        opt.use_SN = True
        opt.correlation_renormalize = True
        opt.NL_use_mask = True
        opt.NL_fusion_method = "combine"
        opt.non_local = "Setting_42"
        opt.name = "mapping_scratch"
        opt.load_pretrainA = os.path.join(opt.checkpoints_dir, "VAE_A_quality")
        opt.load_pretrainB = os.path.join(opt.checkpoints_dir, "VAE_B_scratch")
    if opt.HR:
        opt.mapping_exp = 1
        opt.inference_optimize = True
        opt.mask_dilation = 3
        opt.name = "mapping_Patch_Attention"
if __name__ == "__main__":
    opt = TestOptions().parse(save=False)
    parameter_set(opt)

    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)
    model.eval()

    for sub_dir in ("input_image", "restored_image", "origin"):
        os.makedirs(os.path.join(opt.outputs_dir, sub_dir), exist_ok=True)

    input_loader = sorted(os.listdir(opt.test_input))
    dataset_size = len(input_loader)

    if opt.test_mask != "":
        mask_loader = sorted(os.listdir(opt.test_mask))
        dataset_size = len(mask_loader)

    img_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    mask_transform = transforms.ToTensor()
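
    # Process every image: build the (input, mask) pair the model expects,
    # run inference, and write the input / restored / original versions.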
    for i in range(dataset_size):
        input_name = input_loader[i]
        input_file = os.path.join(opt.test_input, input_name)
        if not os.path.isfile(input_file):
            print("Skipping non-file %s" % input_name)
            continue
        input = Image.open(input_file).convert("RGB")
        print("Now you are processing %s" % (input_name))

        if opt.NL_use_mask:
            mask_name = mask_loader[i]
            mask = Image.open(os.path.join(opt.test_mask, mask_name)).convert("RGB")
            if opt.mask_dilation != 0:
                # Thicken the scratch mask so thin scratches are fully covered.
                kernel = np.ones((3, 3), np.uint8)
                mask = np.array(mask)
                mask = cv2.dilate(mask, kernel, iterations=opt.mask_dilation)
                mask = Image.fromarray(mask.astype("uint8"))
            origin = input
            input = irregular_hole_synthesize(input, mask)
            mask = mask_transform(mask)
            mask = mask[:1, :, :]  # convert the mask to a single channel
            mask = mask.unsqueeze(0)
            input = img_transform(input)
            input = input.unsqueeze(0)
        else:
            if opt.test_mode == "Scale":
                input = data_transforms(input, scale=True)
            if opt.test_mode == "Full":
                input = data_transforms(input, scale=False)
            if opt.test_mode == "Crop":
                input = data_transforms_rgb_old(input)
            origin = input
            input = img_transform(input)
            input = input.unsqueeze(0)
            mask = torch.zeros_like(input)

        # Run the restoration network on the (input, mask) pair.
        try:
            with torch.no_grad():
                generated = model.inference(input, mask)
        except Exception as ex:
            print("Skip %s due to an error:\n%s" % (input_name, str(ex)))
            continue
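
        # Save results as PNG. The transforms map images to [-1, 1], so
        # (x + 1) / 2 brings the tensors back to [0, 1] before saving.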
        if input_name.endswith(".jpg"):
            input_name = input_name[:-4] + ".png"

        vutils.save_image(
            (input + 1.0) / 2.0,
            opt.outputs_dir + "/input_image/" + input_name,
            nrow=1,
            padding=0,
            normalize=True,
        )
        vutils.save_image(
            (generated.data.cpu() + 1.0) / 2.0,
            opt.outputs_dir + "/restored_image/" + input_name,
            nrow=1,
            padding=0,
            normalize=True,
        )
        origin.save(opt.outputs_dir + "/origin/" + input_name)