# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import argparse
import gc
import json
import os
import time
import warnings
import numpy as np
import torch
import torch.nn.functional as F
import torchvision as tv
from PIL import Image, ImageFile
from detection_models import networks
from detection_util.util import *
warnings.filterwarnings("ignore", category=UserWarning)
ImageFile.LOAD_TRUNCATED_IMAGES = True
def data_transforms(img, full_size, method=Image.BICUBIC):
    """Resize a PIL image according to the requested sizing policy.

    Args:
        img: input PIL.Image.
        full_size: sizing policy. "full_size" snaps both sides to the
            nearest multiple of 16; "scale_256" scales the shorter side
            to 256 (preserving aspect ratio) and then snaps to multiples
            of 16.
        method: PIL resampling filter (default bicubic).

    Returns:
        The resized PIL.Image, or `img` unchanged when the snapped size
        already matches.

    Raises:
        ValueError: for an unrecognized `full_size` value. (The previous
            code silently returned None here, which crashed downstream.)
    """
    if full_size == "full_size":
        ow, oh = img.size
        # Snap both sides to a multiple of 16 so the UNet's repeated
        # 2x downsampling divides evenly.
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == oh) and (w == ow):
            return img
        return img.resize((w, h), method)

    if full_size == "scale_256":
        ow, oh = img.size
        pw, ph = ow, oh
        # Scale the shorter side to 256 while preserving aspect ratio.
        if ow < oh:
            ow = 256
            oh = ph / pw * 256
        else:
            oh = 256
            ow = pw / ph * 256
        h = int(round(oh / 16) * 16)
        w = int(round(ow / 16) * 16)
        if (h == ph) and (w == pw):
            return img
        return img.resize((w, h), method)

    raise ValueError("unsupported input_size mode: %r" % (full_size,))
def scale_tensor(img_tensor, default_scale=256):
    """Rescale a 4-D image tensor so its shorter spatial side equals
    `default_scale`, with both output sides snapped to multiples of 16.

    Args:
        img_tensor: tensor of shape (N, C, dim2, dim3).
        default_scale: target length for the shorter spatial side.

    Returns:
        The bilinearly resized tensor.
    """
    _, _, dim_a, dim_b = img_tensor.shape

    # Scale the shorter of the two spatial dims to `default_scale`,
    # stretching the other proportionally.
    if dim_a < dim_b:
        new_a, new_b = default_scale, dim_b / dim_a * default_scale
    else:
        new_a, new_b = dim_a / dim_b * default_scale, default_scale

    # Snap both sides to the nearest multiple of 16.
    target = [int(round(new_a / 16) * 16), int(round(new_b / 16) * 16)]
    return F.interpolate(img_tensor, target, mode="bilinear")
def blend_mask(img, mask):
    """Whiten the masked region of `img`.

    Computes img * (1 - mask) + 255 * mask per pixel and returns the
    result as an RGB PIL.Image.
    """
    base = np.array(img).astype("float")
    blended = base * (1.0 - mask) + mask * 255.0
    return Image.fromarray(blended.astype("uint8")).convert("RGB")
def main(config):
    """Run scratch detection over every file in `config.test_path`.

    Loads the pretrained UNet detection checkpoint, then for each image
    writes two files under `config.output_dir`:
      * input/<stem>.png - the resized source image fed to the model
      * mask/<stem>.png  - the binarized (P >= 0.4) scratch mask

    Args:
        config: argparse namespace with attributes GPU (int, < 0 means
            CPU), test_path, output_dir and input_size.
    """
    print("initializing the dataloader")

    model = networks.UNet(
        in_channels=1,
        out_channels=1,
        depth=4,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    )

    ## load model
    checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints/detection/FT_Epoch_latest.pt")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    print("model weights loaded")

    if config.GPU >= 0:
        model.to(config.GPU)
    else:
        model.cpu()
    model.eval()

    ## dataloader and transformation
    print("directory of testing image: " + config.test_path)
    imagelist = sorted(os.listdir(config.test_path))

    save_url = config.output_dir
    mkdir_if_not(save_url)

    input_dir = os.path.join(save_url, "input")
    output_dir = os.path.join(save_url, "mask")
    mkdir_if_not(input_dir)
    mkdir_if_not(output_dir)

    for image_name in imagelist:
        print("processing", image_name)

        scratch_file = os.path.join(config.test_path, image_name)
        if not os.path.isfile(scratch_file):
            print("Skipping non-file %s" % image_name)
            continue

        # splitext rather than image_name[:-4]: the old slice assumed a
        # 3-character extension and mangled names like "photo.jpeg".
        stem = os.path.splitext(image_name)[0]

        scratch_image = Image.open(scratch_file).convert("RGB")
        transformed_image_PIL = data_transforms(scratch_image, config.input_size)

        # Model consumes a normalized single-channel tensor.
        scratch_image = transformed_image_PIL.convert("L")
        scratch_image = tv.transforms.ToTensor()(scratch_image)
        scratch_image = tv.transforms.Normalize([0.5], [0.5])(scratch_image)
        scratch_image = torch.unsqueeze(scratch_image, 0)
        _, _, ow, oh = scratch_image.shape

        # Detect on a fixed-scale copy, then resize the probability map
        # back to the transformed image's resolution.
        scratch_image_scale = scale_tensor(scratch_image)
        if config.GPU >= 0:
            scratch_image_scale = scratch_image_scale.to(config.GPU)
        else:
            scratch_image_scale = scratch_image_scale.cpu()

        with torch.no_grad():
            P = torch.sigmoid(model(scratch_image_scale))

        P = P.data.cpu()
        P = F.interpolate(P, [ow, oh], mode="nearest")

        # Threshold at 0.4 to binarize the mask before saving.
        tv.utils.save_image(
            (P >= 0.4).float(),
            os.path.join(output_dir, stem + ".png"),
            nrow=1,
            padding=0,
            normalize=True,
        )
        transformed_image_PIL.save(os.path.join(input_dir, stem + ".png"))

        # Keep peak memory down across large batches of images.
        gc.collect()
        torch.cuda.empty_cache()
if __name__ == "__main__":
    # Command-line entry point: parse options and hand off to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--GPU", type=int, default=0)
    arg_parser.add_argument("--test_path", type=str, default=".")
    arg_parser.add_argument("--output_dir", type=str, default=".")
    arg_parser.add_argument("--input_size", type=str, default="scale_256", help="resize_256|full_size|scale_256")
    main(arg_parser.parse_args())