import os
import sys
import time
import math
import random
import logging
from collections import OrderedDict
from datetime import datetime
from shutil import copyfile, get_terminal_size

import numpy as np
import cv2
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torchvision.utils import make_grid

import yaml
try:
    # prefer the faster C-based loader/dumper when libyaml is available
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper
def OrderedYaml():
    '''yaml OrderedDict support: load mappings into OrderedDict and dump
    OrderedDict as a regular mapping.'''
    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper
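
# Usage sketch: parse a config file while preserving key order, so sections
# stay in the order they were written. The file name 'options.yml' is a
# hypothetical example, not a path from this repo.
#
#   Loader, Dumper = OrderedYaml()
#   with open('options.yml', 'r') as f:
#       opt = yaml.load(f, Loader=Loader)    # opt is an OrderedDict
#   print(yaml.dump(opt, Dumper=Dumper))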
####################
# miscellaneous
####################
def get_timestamp():
    return datetime.now().strftime('%y%m%d-%H%M%S')


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def mkdirs(paths):
    if isinstance(paths, str):
        print('path is : {}'.format(paths))
        mkdir(paths)
    else:
        for path in paths:
            print('path is : {}'.format(path))
            mkdir(path)
def mkdir_and_rename(path):
    new_name = None
    if os.path.exists(path):
        new_name = path + '_archived_' + get_timestamp()
        logger = logging.getLogger('base')
        logger.info('Path already exists. Renaming it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)
    return new_name


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
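
# Note: seeding alone does not make cuDNN kernels deterministic. A minimal
# sketch for fully reproducible runs (at some speed cost); the two flags are
# standard PyTorch settings, not part of this module:
#
#   set_random_seed(0)
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False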
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    '''Set up a logger that writes to a timestamped file and/or the screen.'''
    lg = logging.getLogger(logger_name)
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    lg.setLevel(level)
    if tofile:
        log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
        fh = logging.FileHandler(log_file, mode='w')
        fh.setFormatter(formatter)
        lg.addHandler(fh)
    if screen:
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        lg.addHandler(sh)
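
# Usage sketch: create the 'base' logger that mkdir_and_rename (above) logs
# to. The directory name is a hypothetical example.
#
#   setup_logger('base', root='./experiments/log', phase='train',
#                level=logging.INFO, screen=True, tofile=True)
#   logging.getLogger('base').info('training started')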
####################
# image convert
####################
def crop_border(img_list, crop_border):
    """Crop borders of images.
    Args:
        img_list (list [Numpy]): HWC
        crop_border (int): crop border for each end of height and width
    Returns:
        (list [Numpy]): cropped image list
    """
    if crop_border == 0:
        return img_list
    else:
        return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list]
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array.
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default), BGR channel order
    '''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important: unlike MATLAB, numpy casting to uint8 will NOT round by default.
    return img_np.astype(out_type)
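
# Usage sketch: convert a network output (assumed RGB, values in [0, 1]) to
# a BGR uint8 array and write it with OpenCV. 'sr_tensor' is a hypothetical
# stand-in for a real model output.
#
#   sr_tensor = torch.rand(1, 3, 64, 64)
#   img = tensor2img(sr_tensor)            # HWC, BGR, uint8
#   cv2.imwrite('result.png', img)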
def save_img(img, img_path, mode='RGB'):
    # cv2.imwrite expects a BGR (or single-channel) array; tensor2img above
    # already returns BGR. The 'mode' argument is currently unused.
    cv2.imwrite(img_path, img)
def DUF_downsample(x, scale=4):
    """Downsampling with the Gaussian kernel used in the official DUF code
    Args:
        x (Tensor, [B, T, C, H, W]): frames to be downsampled.
        scale (int): downsampling factor: 2 | 3 | 4.
    """
    assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)

    def gkern(kernlen=13, nsig=1.6):
        # scipy.ndimage.filters is deprecated; use scipy.ndimage directly
        import scipy.ndimage as ndimage
        inp = np.zeros((kernlen, kernlen))
        # set element at the middle to one, a dirac delta
        inp[kernlen // 2, kernlen // 2] = 1
        # gaussian-smooth the dirac, resulting in a gaussian filter mask
        return ndimage.gaussian_filter(inp, nsig)

    B, T, C, H, W = x.size()
    x = x.view(-1, 1, H, W)
    pad_w, pad_h = 6 + scale * 2, 6 + scale * 2  # 6 is the pad of the gaussian filter
    r_h, r_w = 0, 0
    if scale == 3:
        r_h = 3 - (H % 3)
        r_w = 3 - (W % 3)
    x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
    gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(B, T, C, x.size(2), x.size(3))
    return x
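
# Quick shape check (sketch, made-up sizes): downsampling a clip by 4x in
# both spatial dimensions.
#
#   clip = torch.rand(2, 7, 3, 64, 64)     # B, T, C, H, W
#   lr = DUF_downsample(clip, scale=4)
#   print(lr.shape)                        # torch.Size([2, 7, 3, 16, 16])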
def single_forward(model, inp):
    """PyTorch model forward (single test); just a simple wrapper
    Args:
        model (PyTorch model)
        inp (Tensor): inputs defined by the model
    Returns:
        output (Tensor): outputs of the model. float, in CPU
    """
    with torch.no_grad():
        model_output = model(inp)
        if isinstance(model_output, (list, tuple)):
            output = model_output[0]
        else:
            output = model_output
    output = output.data.float().cpu()
    return output
def flipx4_forward(model, inp):
    """Flip testing with x4 self-ensemble, i.e., normal, flip H, flip W, flip H and W
    Args:
        model (PyTorch model)
        inp (Tensor): inputs defined by the model
    Returns:
        output (Tensor): outputs of the model. float, in CPU
    """
    # normal
    output_f = single_forward(model, inp)
    # flip W
    output = single_forward(model, torch.flip(inp, (-1,)))
    output_f = output_f + torch.flip(output, (-1,))
    # flip H
    output = single_forward(model, torch.flip(inp, (-2,)))
    output_f = output_f + torch.flip(output, (-2,))
    # flip both H and W
    output = single_forward(model, torch.flip(inp, (-2, -1)))
    output_f = output_f + torch.flip(output, (-2, -1))
    return output_f / 4
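
# Usage sketch: self-ensembled inference with any image-to-image network.
# 'net' is a hypothetical stand-in; averaging the four flip variants
# typically gives a small quality boost at 4x the inference cost.
#
#   net = torch.nn.Conv2d(3, 3, 3, padding=1)
#   out = flipx4_forward(net, torch.rand(1, 3, 32, 32))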
####################
# progress bar
####################
class ProgressBar(object):
    '''A progress bar which can print the progress
    modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
    '''

    def __init__(self, task_num=0, bar_width=50, start=True):
        self.task_num = task_num
        max_bar_width = self._get_max_bar_width()
        self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
        self.completed = 0
        if start:
            self.start()

    def _get_max_bar_width(self):
        terminal_width, _ = get_terminal_size()
        max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
        if max_bar_width < 10:
            print('terminal width is too small ({}), please consider widening the terminal '
                  'for better progress bar visualization'.format(terminal_width))
            max_bar_width = 10
        return max_bar_width

    def start(self):
        if self.task_num > 0:
            sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
                ' ' * self.bar_width, self.task_num, 'Start...'))
        else:
            sys.stdout.write('completed: 0, elapsed: 0s')
        sys.stdout.flush()
        self.start_time = time.time()

    def update(self, msg='In progress...'):
        self.completed += 1
        elapsed = time.time() - self.start_time
        fps = self.completed / elapsed
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            mark_width = int(self.bar_width * percentage)
            bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
            sys.stdout.write('\033[2F')  # cursor up 2 lines
            sys.stdout.write('\033[J')  # clear the output since the last display
            sys.stdout.write('[{}] {}/{}, {:.1f} tasks/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
                bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
        else:
            sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
                self.completed, int(elapsed + 0.5), fps))
        sys.stdout.flush()
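
# Usage sketch: track a fixed-size job; each update() redraws the bar in
# place. The sleep is a stand-in for real work.
#
#   pbar = ProgressBar(task_num=100)
#   for i in range(100):
#       time.sleep(0.01)
#       pbar.update('processing item {}'.format(i))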
### communication
def find_free_port():
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))  # bind to an OS-assigned free port
    port = sock.getsockname()[1]
    sock.close()
    return port
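
# Usage sketch: a free port is typically used to build the init_method URL
# for single-node distributed training (an assumption about the caller; this
# module does not itself initialize a process group):
#
#   port = find_free_port()
#   init_method = 'tcp://127.0.0.1:{}'.format(port)
#   # torch.distributed.init_process_group('nccl', init_method=init_method,
#   #                                      rank=0, world_size=1)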
# for debug
def visualize_image(result, outputDir, epoch, mode, video_name, minData=0):
    ### Only visualize one frame
    targetDir = os.path.join(outputDir, str(epoch), video_name)
    if not os.path.exists(targetDir):
        os.makedirs(targetDir)
    if minData == -1:
        # data in [-1, 1]: rescale to [0, 1] before saving
        result = (result + 1) / 2
        vutils.save_image(result, os.path.join(targetDir, '{}.png'.format(mode)))
    elif minData == 0:
        # data already in [0, 1]
        vutils.save_image(result, os.path.join(targetDir, '{}.png'.format(mode)))
    else:
        raise ValueError('minData {} is not supported'.format(minData))
def get_learning_rate(optimizer):
    return [param_group['lr'] for param_group in optimizer.param_groups]


def adjust_learning_rate(optimizer, target_lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = target_lr
def save_checkpoint(epoch, model, discriminator, current_step, schedulers, dist_scheduler, optimizers, dist_optimizer,
                    save_path, is_best, monitor, monitor_value, config):
    # To entirely resume training state, save the state dicts of the model,
    # the optimizers, and the learning-rate schedulers.
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model_state = model.module.state_dict()
        discriminator_state = discriminator.module.state_dict()
    else:
        model_state = model.state_dict()
        discriminator_state = discriminator.state_dict()
    state = {
        'epoch': epoch,
        'iteration': current_step,
        'model_state_dict': model_state,
        'discriminator_state_dict': discriminator_state,
        'optimizer_state_dict': optimizers.state_dict(),
        'dist_optim_state_dict': dist_optimizer.state_dict(),
        'scheduler_state_dict': schedulers.state_dict(),
        'dist_scheduler_state_dict': dist_scheduler.state_dict(),
        'is_best': is_best,
        'config': config,
    }
    best_str = '-best-so-far' if is_best else ''
    monitor_str = '-{}:{}'.format(monitor, monitor_value) if monitor_value else ''
    if not os.path.exists(os.path.join(save_path, 'best')):
        os.makedirs(os.path.join(save_path, 'best'))
    file_name = os.path.join(save_path, 'checkpoint-epoch:{}{}{}.pth.tar'.format(epoch, monitor_str, best_str))
    torch.save(state, file_name)
    if is_best:
        copyfile(src=file_name,
                 dst=os.path.join(save_path, 'best',
                                  'checkpoint-epoch:{}{}{}.pth.tar'.format(epoch, monitor_str, best_str)))
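
# Resume sketch: the keys mirror the 'state' dict saved above; the checkpoint
# path and the objects being restored are assumptions about the surrounding
# training script.
#
#   ckpt = torch.load('checkpoint-epoch:10.pth.tar', map_location='cpu')
#   model.load_state_dict(ckpt['model_state_dict'])
#   discriminator.load_state_dict(ckpt['discriminator_state_dict'])
#   optimizers.load_state_dict(ckpt['optimizer_state_dict'])
#   schedulers.load_state_dict(ckpt['scheduler_state_dict'])
#   start_epoch, current_step = ckpt['epoch'] + 1, ckpt['iteration']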
def save_dist_checkpoint(epoch, model, dist, current_step, schedulers, schedulersD, optimizers, optimizersD,
                         save_path, is_best, monitor, monitor_value, config):
    # To entirely resume training state, save the state dicts of the model,
    # the optimizers, and the learning-rate schedulers.
    if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
        model_state = model.module.state_dict()
        dist_state = dist.module.state_dict()
    else:
        model_state = model.state_dict()
        dist_state = dist.state_dict()
    state = {
        'epoch': epoch,
        'iteration': current_step,
        'model_state_dict': model_state,
        'dist_state_dict': dist_state,
        'optimizer_state_dict': optimizers.state_dict(),
        'optimizerD_state_dict': optimizersD.state_dict(),
        'scheduler_state_dict': schedulers.state_dict(),
        'schedulerD_state_dict': schedulersD.state_dict(),
        'is_best': is_best,
        'config': config,
    }
    best_str = '-best-so-far' if is_best else ''
    monitor_str = '-{}:{}'.format(monitor, monitor_value) if monitor_value else ''
    if not os.path.exists(os.path.join(save_path, 'best')):
        os.makedirs(os.path.join(save_path, 'best'))
    file_name = os.path.join(save_path, 'checkpoint-epoch:{}{}{}.pth.tar'.format(epoch, monitor_str, best_str))
    torch.save(state, file_name)
    if is_best:
        copyfile(src=file_name,
                 dst=os.path.join(save_path, 'best',
                                  'checkpoint-epoch:{}{}{}.pth.tar'.format(epoch, monitor_str, best_str)))
def poisson_blend(input, output, mask):
    """
    * inputs:
        - input (torch.Tensor, required)
                Input tensor of Completion Network, whose shape = (N, 3, H, W).
        - output (torch.Tensor, required)
                Output tensor of Completion Network, whose shape = (N, 3, H, W).
        - mask (torch.Tensor, required)
                Input mask tensor of Completion Network, whose shape = (N, 1, H, W).
    * returns:
                Output image tensor of shape (N, 3, H, W) inpainted with the Poisson image editing method.
    from Iizuka et al.: https://github.com/otenim/GLCIC-PyTorch/blob/caf9bebe667fba0aebbd041918f2d8128f59ec62/utils.py
    """
    input = input.clone().cpu()
    output = output.clone().cpu()
    mask = mask.clone().cpu()
    mask = torch.cat((mask, mask, mask), dim=1)  # convert to 3-channel format
    num_samples = input.shape[0]
    ret = []
    for i in range(num_samples):
        dstimg = transforms.functional.to_pil_image(input[i])
        dstimg = np.array(dstimg)[:, :, [2, 1, 0]]
        srcimg = transforms.functional.to_pil_image(output[i])
        srcimg = np.array(srcimg)[:, :, [2, 1, 0]]
        msk = transforms.functional.to_pil_image(mask[i])
        msk = np.array(msk)[:, :, [2, 1, 0]]
        # compute the bounding box of the mask and its center
        ys, xs = np.where(msk[:, :, 0] == 255)
        xmin, xmax = xs.min(), xs.max()
        ymin, ymax = ys.min(), ys.max()
        center = (int((xmax + xmin) // 2), int((ymax + ymin) // 2))
        dstimg = cv2.inpaint(dstimg, msk[:, :, 0], 1, cv2.INPAINT_TELEA)
        out = cv2.seamlessClone(srcimg, dstimg, msk, center, cv2.NORMAL_CLONE)
        out = out[:, :, [2, 1, 0]]
        out = transforms.functional.to_tensor(out)
        out = torch.unsqueeze(out, dim=0)
        ret.append(out)
    ret = torch.cat(ret, dim=0)
    return ret
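
# Usage sketch: blend a completion network's output back into the masked
# input. All tensors are hypothetical stand-ins with values in [0, 1] and a
# binary {0, 1} mask marking the hole region.
#
#   x = torch.rand(1, 3, 128, 128)         # network input
#   y = torch.rand(1, 3, 128, 128)         # network output
#   m = torch.zeros(1, 1, 128, 128)
#   m[:, :, 32:96, 32:96] = 1.0            # hole region
#   blended = poisson_blend(x, y, m)       # (1, 3, 128, 128)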