# manhkhanhUIT - "Add code" (7fab858)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
import argparse
import dill as pickle
def save_obj(obj, name):
    """Serialize ``obj`` into the file at path ``name``.

    The file is opened in binary write mode (any existing content is
    replaced) and the object is dumped with the module-level ``pickle``
    using its highest available protocol.
    """
    with open(name, "wb") as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)
def load_obj(name):
    """Read back an object previously serialized to the file ``name``.

    NOTE: unpickling executes arbitrary code embedded in the file, so this
    must only be used on files from a trusted source.
    """
    with open(name, "rb") as handle:
        loaded = pickle.load(handle)
    return loaded
def copyconf(default_opt, **kwargs):
    """Return a copy of ``default_opt`` with ``kwargs`` applied as overrides.

    A new ``argparse.Namespace`` is built from the attributes of
    ``default_opt`` (which is left untouched); each keyword argument is then
    printed and set as an attribute on the copy.
    """
    overridden = argparse.Namespace(**vars(default_opt))
    for option, value in kwargs.items():
        print(option, value)
        setattr(overridden, option, value)
    return overridden
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    """Convert an image tensor (or a list / batch of them) to a numpy array.

    Args:
        image_tensor: a tensor with 2, 3 or 4 dimensions, or a list of such
            tensors. A list is converted element by element; a 4-D tensor is
            treated as a batch along dimension 0.
        imtype: numpy dtype of the returned array(s).
        normalize: when true, values are mapped with ``(x + 1) / 2 * 255``;
            otherwise they are only scaled by 255.
        tile: unused; kept so existing callers that pass it keep working.

    Returns:
        For a list input, a list of arrays. For a 4-D input, the per-image
        arrays concatenated along a new leading axis. Otherwise one array
        with the channel axis moved last; a single-channel image is returned
        without the channel axis.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(item, imtype, normalize) for item in image_tensor]

    if image_tensor.dim() == 4:
        # Forward imtype/normalize so batches are converted exactly like
        # single images (they used to fall back to the defaults).
        per_image = [
            tensor2im(image_tensor[b], imtype, normalize)[np.newaxis]
            for b in range(image_tensor.size(0))
        ]
        return np.concatenate(per_image, axis=0)

    if image_tensor.dim() == 2:
        # Give a bare 2-D map a channel axis so the transpose below applies.
        image_tensor = image_tensor.unsqueeze(0)

    image_numpy = image_tensor.detach().cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    if normalize:
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
    else:
        image_numpy = image_numpy * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    """Turn a label tensor into a colour-coded numpy label map.

    Args:
        label_tensor: tensor with 1, 3 or 4 dimensions. A 4-D tensor is
            handled as a batch along dimension 0.
        n_label: number of labels handed to ``Colorize``; ``0`` means the
            tensor is converted with ``tensor2im`` instead.
        imtype: numpy dtype of the returned array.
        tile: unused; kept for interface compatibility.

    Returns:
        A numpy array with channels last. A 1-D input yields an all-zero
        ``(64, 64, 3)`` uint8 placeholder.
    """
    if label_tensor.dim() == 4:
        # Colourize every image of the batch and stack along a new axis 0.
        colored_batch = [
            tensor2label(label_tensor[index], n_label, imtype)[np.newaxis]
            for index in range(label_tensor.size(0))
        ]
        return np.concatenate(colored_batch, axis=0)

    if label_tensor.dim() == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)

    if n_label == 0:
        return tensor2im(label_tensor, imtype)

    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # Several channels: keep the index of the strongest one per pixel.
        labels = labels.max(0, keepdim=True)[1]
    colored = Colorize(n_label)(labels)
    return np.transpose(colored.numpy(), (1, 2, 0)).astype(imtype)
def save_image(image_numpy, image_path, create_dir=False):
    """Write an image array to disk via PIL.

    Args:
        image_numpy: image array; a 2-D array or an array with a single
            channel is repeated to three channels before saving.
        image_path: destination path. Every ``".jpg"`` in the file name is
            replaced with ``".png"`` before saving.
        create_dir: when true, the parent directory of ``image_path`` is
            created if it does not exist yet.
    """
    directory, filename = os.path.split(image_path)
    # A bare file name has no parent directory; os.makedirs("") would raise.
    if create_dir and directory:
        os.makedirs(directory, exist_ok=True)

    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)

    # Swap the extension in the file name only, so the image lands in the
    # directory that was (optionally) created above even when a directory
    # name happens to contain ".jpg".
    image_pil.save(os.path.join(directory, filename.replace(".jpg", ".png")))
def mkdirs(paths):
    """Create one directory, or every directory in a list of paths.

    A list is expanded and each entry is passed to ``mkdir``; any other
    value is passed to ``mkdir`` as a single path.
    """
    if not isinstance(paths, list) or isinstance(paths, str):
        mkdir(paths)
        return
    for entry in paths:
        mkdir(entry)
def mkdir(path):
    """Create directory ``path`` (and missing parents) if it does not exist.

    Nothing happens when ``path`` already exists. The creation is attempted
    directly instead of checking for existence first, so a concurrent
    process creating the same directory cannot make this call fail.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present (possibly created concurrently): nothing to do.
        pass
def atoi(text):
    """Return ``int(text)`` when ``text.isdigit()`` holds, else ``text`` itself."""
    if text.isdigit():
        return int(text)
    return text
def natural_keys(text):
    """Split ``text`` into a key that sorts in human ("natural") order.

    ``alist.sort(key=natural_keys)`` sorts in human order: runs of digits
    are compared as integers, everything else as plain strings.
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Returns:
        A list alternating between non-digit string chunks and ``int``
        values for the digit runs found in ``text``.
    """
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    # The capturing group makes re.split keep the digit runs in the result.
    return [
        int(chunk) if chunk.isdigit() else chunk
        for chunk in re.split(r"(\d+)", text)
    ]
def natural_sort(items):
    """Sort ``items`` in place using ``natural_keys`` as the sort key.

    Nothing is returned; the argument itself is reordered.
    """
    items.sort(key=natural_keys)
def str2bool(v):
    """Parse a textual boolean, for use as an ``argparse`` ``type=``.

    Args:
        v: a string such as ``"yes"``/``"no"``, ``"true"``/``"false"``,
            ``"t"``/``"f"``, ``"y"``/``"n"``, ``"1"``/``"0"`` (case is
            ignored), or an actual ``bool``, which is returned unchanged.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised value.
    """
    # A real bool (e.g. a default value fed back in) has no .lower().
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def find_class_in_module(target_cls_name, module):
    """Look up an attribute of ``module`` by case-insensitive name.

    Args:
        target_cls_name: name to look for; underscores are removed and the
            result is lower-cased before matching.
        module: dotted module name, imported with ``importlib``.

    Returns:
        The module attribute whose lower-cased name equals the normalised
        target. If several attributes match, the last one encountered wins.

    Raises:
        SystemExit: with exit status 1, after printing a diagnostic, when no
            attribute matches.
    """
    target_cls_name = target_cls_name.replace("_", "").lower()
    clslib = importlib.import_module(module)

    cls = None
    for name, clsobj in vars(clslib).items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print(
            "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
            % (module, target_cls_name)
        )
        # This is a failure: terminate with a non-zero status, and do not
        # depend on the `exit` helper that only the `site` module injects.
        raise SystemExit(1)
    return cls
def save_network(net, label, epoch, opt):
    """Save the state dict of ``net`` as ``<epoch>_net_<label>.pth``.

    The file goes into ``opt.checkpoints_dir/opt.name``. The network is
    moved to the CPU before its state dict is written; afterwards it is
    moved back with ``net.cuda()`` when ``opt.gpu_ids`` is non-empty and
    CUDA is available.
    """
    checkpoint_name = "%s_net_%s.pth" % (epoch, label)
    destination = os.path.join(opt.checkpoints_dir, opt.name, checkpoint_name)
    cpu_state = net.cpu().state_dict()
    torch.save(cpu_state, destination)
    wants_gpu = len(opt.gpu_ids) and torch.cuda.is_available()
    if wants_gpu:
        net.cuda()
def load_network(net, label, epoch, opt):
    """Load ``<epoch>_net_<label>.pth`` into ``net`` when that file exists.

    The checkpoint is looked up in ``opt.checkpoints_dir/opt.name``. If the
    file is missing, ``net`` is returned unchanged and no error is raised.

    Returns:
        The same ``net`` object that was passed in.
    """
    checkpoint_path = os.path.join(
        opt.checkpoints_dir, opt.name, "%s_net_%s.pth" % (epoch, label)
    )
    if not os.path.exists(checkpoint_path):
        return net
    net.load_state_dict(torch.load(checkpoint_path))
    return net
###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
    """Return the lowest ``count`` bits of integer ``n`` as a string.

    The most significant of those bits comes first.
    """
    bits = (str((n >> shift) & 1) for shift in reversed(range(count)))
    return "".join(bits)
class Colorize(object):
    """Callable that paints a label map with a per-label colour table."""

    def __init__(self, n=35):
        """Build the colour table for ``n`` labels from ``labelcolormap``."""
        palette = labelcolormap(n)
        self.cmap = torch.from_numpy(palette[:n])

    def __call__(self, gray_image):
        """Return a ``3 x H x W`` uint8 tensor coloured by label.

        ``gray_image`` is indexed as ``(1, H, W)``: its first slice holds
        the label of every pixel. Pixels whose label has no row in the
        colour table stay zero.
        """
        dims = gray_image.size()
        height, width = dims[1], dims[2]
        color_image = torch.zeros(
            (3, height, width), dtype=torch.uint8, device="cpu"
        )
        for label, color in enumerate(self.cmap):
            mask = (gray_image[0] == label).cpu()
            for channel in range(3):
                color_image[channel][mask] = color[channel]
        return color_image