manhkhanhUIT's picture
Init
e78c13e
raw history blame
No virus
5.03 kB
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import gc
import json
import os
import time
import warnings
import numpy as np
import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image, ImageFile
from detection_models import networks
from detection_util.util import *
# Suppress noisy UserWarnings (e.g. from torchvision transforms / interpolation)
# during batch inference.
warnings.filterwarnings("ignore", category=UserWarning)
# Let PIL load partially corrupted / truncated scans instead of raising —
# old-photo inputs are frequently damaged files.
ImageFile.LOAD_TRUNCATED_IMAGES = True
def data_transforms(img, full_size, method=Image.BICUBIC):
    """Resize a PIL image so both sides are multiples of 16.

    The detection UNet downsamples 4 times, so input sides must be
    divisible by 16.

    Args:
        img: input ``PIL.Image``.
        full_size: resize mode —
            ``"full_size"``: keep the original resolution, only snapping
            each side to the nearest multiple of 16;
            ``"scale_256"``: scale the shorter side to 256 preserving
            aspect ratio, then snap to multiples of 16.
        method: PIL resampling filter (default bicubic).

    Returns:
        The resized image, or *img* itself if no resize is needed.

    Raises:
        ValueError: for an unrecognized *full_size* mode.  (The original
        code fell through and returned ``None`` here, which crashed
        later with an opaque ``AttributeError`` on ``.convert()``.)
    """
    if full_size == "full_size":
        ow, oh = img.size
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == oh) and (w == ow):
            return img
        return img.resize((w, h), method)

    if full_size == "scale_256":
        ow, oh = img.size
        pw, ph = ow, oh  # remember original size for the no-op check below
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == ph) and (w == pw):
            return img
        return img.resize((w, h), method)

    raise ValueError("unsupported input_size mode: %r" % (full_size,))
def scale_tensor(img_tensor, default_scale=256):
    """Rescale a 4-D image tensor so its smaller spatial side becomes
    ``default_scale``.

    The aspect ratio is preserved, both output sides are snapped to the
    nearest multiple of 16 (required by the UNet's 4-level
    downsampling), and resampling is bilinear.

    Args:
        img_tensor: tensor of shape ``(N, C, d0, d1)``.
        default_scale: target length for the smaller spatial side.

    Returns:
        The bilinearly resized tensor.
    """
    _, _, dim0, dim1 = img_tensor.shape
    if dim0 < dim1:
        target0 = default_scale
        target1 = dim1 / dim0 * default_scale
    else:
        target1 = default_scale
        target0 = dim0 / dim1 * default_scale
    # Snap each side to a multiple of 16.
    out_size = [int(round(side / 16) * 16) for side in (target0, target1)]
    return F.interpolate(img_tensor, out_size, mode="bilinear")
def blend_mask(img, mask):
    """Overlay a scratch mask on *img*: masked pixels are pushed toward
    white in proportion to the mask value.

    Args:
        img: PIL image (or array-like) to visualize.
        mask: array broadcastable against the image, with values in
            [0, 1] — presumably the sigmoid scratch probability; verify
            against the caller.

    Returns:
        An RGB ``PIL.Image`` of the blended result.
    """
    base = np.asarray(img, dtype="float")
    blended = base * (1 - mask) + 255.0 * mask
    return Image.fromarray(blended.astype("uint8")).convert("RGB")
def main(config):
    """Run scratch detection over every image in ``config.test_path``.

    Loads the pretrained detection UNet, then for each image writes:
      * ``<output_dir>/input/<name>.png`` — the resized grayscale input;
      * ``<output_dir>/mask/<name>.png``  — the binarized (P >= 0.4)
        scratch mask at the input's resolution.

    Args:
        config: argparse namespace with ``GPU`` (int, <0 means CPU),
            ``test_path``, ``output_dir`` and ``input_size``
            ("full_size" | "scale_256").
    """
    print("initializing the dataloader")

    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    print("model weights loaded")

    if config.GPU >= 0:
        model.to(config.GPU)
    else:
        model.cpu()
    model.eval()

    ## dataloader and transformation
    print("directory of testing image: " + config.test_path)
    imagelist = sorted(os.listdir(config.test_path))

    save_url = config.output_dir  # os.path.join with one arg was a no-op
    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    mkdir_if_not(save_url)
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)

    for image_name in imagelist:
        print("processing", image_name)
        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            print("Skipping non-file %s" % image_name)
            continue

        scratch_image = Image.open(scratch_file).convert("RGB")
        transformed_image_PIL = data_transforms(scratch_image, config.input_size)

        # The UNet is single-channel: grayscale, normalize to [-1, 1],
        # add a batch dimension.
        scratch_image = transformed_image_PIL.convert("L")
        scratch_image = tv.transforms.ToTensor()(scratch_image)
        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
        scratch_image = torch.unsqueeze(scratch_image, 0)

        # NOTE(review): shape is (N, C, H, W), so these are really H, W
        # despite the names — kept for parity with scale_tensor.
        _, _, ow, oh = scratch_image.shape

        # Inference always runs at the 256-short-side scale; the
        # predicted mask is upsampled back to the input resolution.
        scratch_image_scale = scale_tensor(scratch_image)
        if config.GPU >= 0:
            scratch_image_scale = scratch_image_scale.to(config.GPU)
        else:
            scratch_image_scale = scratch_image_scale.cpu()

        with torch.no_grad():
            P = torch.sigmoid(model(scratch_image_scale))

        P = P.data.cpu()
        P = F.interpolate(P, [ow, oh], mode="nearest")

        # splitext handles any extension length (".jpeg", ".tiff", ...);
        # the old image_name[:-4] silently mangled such names.
        stem = os.path.splitext(image_name)[0]
        tv.utils.save_image(
            (P >= 0.4).float(),  # binarize at the 0.4 probability threshold
            os.path.join(output_dir, stem + ".png"),
            nrow=1,
            padding=0,
            normalize=True,
        )
        transformed_image_PIL.save(os.path.join(input_dir, stem + ".png"))

        # Free per-image tensors promptly to keep GPU/host memory flat.
        gc.collect()
        torch.cuda.empty_cache()
if __name__ == "__main__":
    # Command-line entry point: parse options and run detection.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--GPU", type=int, default=0)
    arg_parser.add_argument("--test_path", type=str, default=".")
    arg_parser.add_argument("--output_dir", type=str, default=".")
    arg_parser.add_argument(
        "--input_size",
        type=str,
        default="scale_256",
        help="resize_256|full_size|scale_256",
    )
    main(arg_parser.parse_args())