Spaces:
Runtime error
Runtime error
import random | |
from typing import Any, Optional | |
import numpy as np | |
import os | |
import cv2 | |
from glob import glob | |
from PIL import Image, ImageDraw | |
from tqdm import tqdm | |
import kornia | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import albumentations as albu | |
import functools | |
import math | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
import torchvision as tv | |
import torchvision.models as models | |
from torchvision import transforms | |
from torchvision.transforms import functional as F | |
from losses import TempCombLoss | |
########### DeblurGAN function | |
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds the requested 2D normalization layer.

    Args:
        norm_type (str): ``'batch'`` for an affine ``nn.BatchNorm2d`` or
            ``'instance'`` for a non-affine ``nn.InstanceNorm2d`` that tracks
            running statistics.

    Returns:
        A ``functools.partial``; call it with the channel count to get a layer.

    Raises:
        NotImplementedError: For any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def _array_to_batch(x): | |
x = np.transpose(x, (2, 0, 1)) | |
x = np.expand_dims(x, 0) | |
return torch.from_numpy(x) | |
def get_normalize():
    """Build a paired normalizer applying the same transform to two images.

    Uses ``albu.Normalize`` with per-channel mean 0.5 and std 0.5 (presumably
    mapping [0, 255] input to [-1, 1] via albumentations' default
    ``max_pixel_value`` — confirm against the installed version).

    Returns:
        ``process(a, b) -> (a_norm, b_norm)``.
    """
    pipeline = albu.Compose(
        [albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])],
        additional_targets={'target': 'image'},
    )

    def process(a, b):
        out = pipeline(image=a, target=b)
        return out['image'], out['target']

    return process
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
    """Normalize an image and pad it (and its mask) for the DeblurGAN generator.

    Args:
        x: HxWxC image array.
        mask: Optional mask in [0, 255]; rounded to {0, 1}. When ``None``, an
            all-ones float32 mask shaped like the normalized image is used.

    Returns:
        ``(iterator of (x_batch, mask_batch), h, w)`` where each batch is a
        1xCxHxW tensor and ``h``/``w`` are the unpadded sizes. Height and width
        are always padded up to the next multiple of 32 strictly above the
        current size (so 1..32 extra rows/columns of zeros).
    """
    x, _ = get_normalize()(x, x)
    if mask is not None:
        mask = np.round(mask.astype('float32') / 255)
    else:
        mask = np.ones_like(x, dtype=np.float32)

    h, w, _ = x.shape
    block = 32
    pad_h = (h // block + 1) * block - h
    pad_w = (w // block + 1) * block - w
    padding = ((0, pad_h), (0, pad_w), (0, 0))

    x = np.pad(x, pad_width=padding, mode='constant', constant_values=0)
    mask = np.pad(mask, pad_width=padding, mode='constant', constant_values=0)
    return map(_array_to_batch, (x, mask)), h, w
def postprocess(x: torch.Tensor) -> np.ndarray:
    """Convert a single-image generator output batch to an 8-bit image.

    Args:
        x: Tensor laid out as 1xCxHxW with values nominally in [-1, 1].

    Returns:
        HxWxC ``uint8`` array in [0, 255].

    Raises:
        ValueError: If the leading (batch) dimension is not exactly one.
    """
    x, = x  # unpack the single batch element; raises for batch size != 1
    x = x.detach().cpu().float().numpy()
    x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Clip before the cast: out-of-range floats would otherwise wrap around
    # (or be undefined) when converted to uint8.
    return np.clip(x, 0.0, 255.0).astype('uint8')
def sorted_glob(pattern):
    """Return all paths matching the shell-style ``pattern``, sorted lexicographically."""
    matches = glob(pattern)
    matches.sort()
    return matches
########### | |
def normalize(image: np.ndarray) -> np.ndarray:
    """Scale 8-bit image data (e.g. from ``cv2.imread``) from [0, 255] to [0, 1].

    Args:
        image (np.ndarray): Image array with values in [0, 255].

    Returns:
        ``float64`` array with values in [0, 1].
    """
    as_float = image.astype(np.float64)
    return np.divide(as_float, 255.0)
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Scale [0, 1] image data back to the [0, 255] range.

    Args:
        image (np.ndarray): Image array with values in [0, 1].

    Returns:
        ``float64`` array with values in [0, 255].
    """
    as_float = image.astype(np.float64)
    return np.multiply(as_float, 255.0)
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert an image (PIL image or HxWxC array) into a CxHxW tensor.

    ``F.to_tensor`` scales 8-bit input to [0, 1]; the result can optionally be
    remapped to [-1, 1] and/or cast to half precision.

    Args:
        image: Anything accepted by ``torchvision.transforms.functional.to_tensor``.
        range_norm (bool): Remap [0, 1] to [-1, 1].
        half (bool): Cast the result to ``torch.half``.

    Returns:
        The converted tensor.

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    result = F.to_tensor(image)
    if range_norm:
        result = result.mul_(2.0).sub_(1.0)
    return result.half() if half else result
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert an image tensor into an HxWxC ``uint8`` numpy array.

    All arithmetic is out-of-place, so the caller's tensor is left untouched,
    and the tensor is detached so tensors that require grad are accepted.

    Args:
        tensor (torch.Tensor): Image shaped 1xCxHxW or CxHxW with values in
            [0, 1] (or [-1, 1] when ``range_norm`` is set).
        range_norm (bool): Map [-1, 1] data to [0, 1] before scaling.
        half (bool): Cast to ``torch.half`` before scaling.

    Returns:
        ``uint8`` array shaped HxWxC, clamped to [0, 255].

    Examples:
        >>> tensor = torch.rand([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    if range_norm:
        tensor = (tensor + 1.0) / 2.0
    if half:
        tensor = tensor.half()
    image = tensor.detach().squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255)
    return image.cpu().numpy().astype("uint8")
def convert_rgb_to_y(image: Any) -> Any:
    """Compute the luma (Y) channel of an RGB image.

    Uses the same coefficients as ``convert_rgb_to_ycbcr``; the +16 offset
    suggests RGB values in [0, 255] are expected.

    Args:
        image: HxWx3 ``numpy.ndarray``, or a ``torch.Tensor`` shaped 3xHxW or
            1x3xHxW. The input is never modified.

    Returns:
        HxW array or tensor (matching the input type) of Y values.

    Raises:
        Exception: If ``image`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(image, np.ndarray):
        return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
    if isinstance(image, torch.Tensor):
        if len(image.shape) == 4:
            # Out-of-place squeeze: the caller's tensor must keep its shape.
            image = image.squeeze(0)
        return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
    raise Exception("Unknown Type", type(image))
def convert_rgb_to_ycbcr(image: Any) -> Any:
    """Convert RGB image data to YCbCr.

    Args:
        image: HxWx3 ``numpy.ndarray``, or a ``torch.Tensor`` shaped 3xHxW or
            1x3xHxW (presumably RGB values in [0, 255]).

    Returns:
        HxWx3 array or tensor (matching the input type) holding Y, Cb, Cr.

    Raises:
        Exception: If ``image`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(image, np.ndarray):
        y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
        cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
        cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    if isinstance(image, torch.Tensor):
        if len(image.shape) == 4:
            image = image.squeeze(0)
        y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
        cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
        cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
        # y/cb/cr are HxW: stack them into 3xHxW (cat would give 3HxW and the
        # 3-axis permute would fail), then move channels last like numpy.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    raise Exception("Unknown Type", type(image))
def convert_ycbcr_to_rgb(image: Any) -> Any:
    """Convert YCbCr image data back to RGB.

    Args:
        image: HxWx3 ``numpy.ndarray``, or a ``torch.Tensor`` shaped 3xHxW or
            1x3xHxW, with channels ordered Y, Cb, Cr.

    Returns:
        HxWx3 array or tensor (matching the input type) holding R, G, B.

    Raises:
        Exception: If ``image`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(image, np.ndarray):
        r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
        g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
        b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    if isinstance(image, torch.Tensor):
        if len(image.shape) == 4:
            image = image.squeeze(0)
        r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
        g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
        b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
        # r/g/b are HxW: stack (not cat) so the result is 3xHxW before permuting.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    raise Exception("Unknown Type", type(image))
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop the central ``image_size`` square of an aligned LR/HR ``PIL.Image`` pair.

    Args:
        lr: Low-resolution image (anything with ``.crop(box)``).
        hr: High-resolution image (anything with ``.size`` and ``.crop(box)``).
        image_size (int): Side length of the HR crop.
        upscale_factor (int): HR/LR scale; the LR box is the HR box floor-divided by it.

    Returns:
        The centre-cropped ``(lr, hr)`` pair.
    """
    hr_w, hr_h = hr.size
    x0 = (hr_w - image_size) // 2
    y0 = (hr_h - image_size) // 2
    hr_box = (x0, y0, x0 + image_size, y0 + image_size)
    lr_box = tuple(coord // upscale_factor for coord in hr_box)
    return lr.crop(lr_box), hr.crop(hr_box)
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop a random ``image_size`` square from an aligned LR/HR ``PIL.Image`` pair.

    The top-left corner is drawn with ``torch.randint`` (x first, then y) so
    results follow the global torch RNG.

    Args:
        lr: Low-resolution image (anything with ``.crop(box)``).
        hr: High-resolution image (anything with ``.size`` and ``.crop(box)``).
        image_size (int): Side length of the HR crop.
        upscale_factor (int): HR/LR scale; the LR box is the HR box floor-divided by it.

    Returns:
        The randomly cropped ``(lr, hr)`` pair.
    """
    hr_w, hr_h = hr.size
    x0 = torch.randint(0, hr_w - image_size + 1, size=(1,)).item()
    y0 = torch.randint(0, hr_h - image_size + 1, size=(1,)).item()
    hr_box = (x0, y0, x0 + image_size, y0 + image_size)
    lr_box = tuple(coord // upscale_factor for coord in hr_box)
    return lr.crop(lr_box), hr.crop(hr_box)
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
    """Rotate an LR/HR pair by the same angle, randomly ``+angle`` or ``-angle``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.
        angle (int): Rotation magnitude in degrees; the sign is picked with ``random.choice``.

    Returns:
        The rotated ``(lr, hr)`` pair.
    """
    chosen = random.choice((+angle, -angle))
    rotated_lr = F.rotate(lr, chosen)
    rotated_hr = F.rotate(hr, chosen)
    return rotated_lr, rotated_hr
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Horizontally flip an LR/HR pair together with probability ``p``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.
        p (float, optional): Flip probability. (Default: 0.5)

    Returns:
        The ``(lr, hr)`` pair, both flipped or both unchanged.
    """
    # One shared draw keeps LR and HR aligned. `< p` makes ``p`` the flip
    # probability, matching torchvision's RandomHorizontalFlip (`> p` would
    # flip with probability 1 - p).
    if torch.rand(1).item() < p:
        lr = F.hflip(lr)
        hr = F.hflip(hr)
    return lr, hr
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Vertically flip an LR/HR pair together with probability ``p``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.
        p (float, optional): Flip probability. (Default: 0.5)

    Returns:
        The ``(lr, hr)`` pair, both flipped or both unchanged.
    """
    # One shared draw keeps LR and HR aligned; `< p` makes ``p`` the flip
    # probability, matching torchvision's RandomVerticalFlip.
    if torch.rand(1).item() < p:
        lr = F.vflip(lr)
        hr = F.vflip(hr)
    return lr, hr
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
    """Scale the brightness of an LR/HR pair by one shared random factor in [0.5, 2].

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.

    Returns:
        The brightness-adjusted ``(lr, hr)`` pair.
    """
    gain = random.uniform(0.5, 2)  # shared so the pair stays consistent
    brighter_lr = F.adjust_brightness(lr, gain)
    brighter_hr = F.adjust_brightness(hr, gain)
    return brighter_lr, brighter_hr
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
    """Scale the contrast of an LR/HR pair by one shared random factor in [0.5, 2].

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.

    Returns:
        The contrast-adjusted ``(lr, hr)`` pair.
    """
    contrast = random.uniform(0.5, 2)  # shared so the pair stays consistent
    adjusted_lr = F.adjust_contrast(lr, contrast)
    adjusted_hr = F.adjust_contrast(hr, contrast)
    return adjusted_lr, adjusted_hr
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims | |
def img_mae(x1, x2):
    """Mean absolute error over all elements of two same-shaped tensors."""
    return (x1 - x2).abs().mean()
def img_mse(x1, x2):
    """Mean squared error over all elements of two same-shaped tensors."""
    abs_diff = (x1 - x2).abs()
    return abs_diff.pow(2).mean()
def img_psnr(x1, x2):
    """PSNR (dB) between two images via kornia, with a peak signal value of 1."""
    max_val = 1
    return kornia.metrics.psnr(x1, x2, max_val)
def img_ssim(x1, x2):
    """Mean SSIM between two single images (e.g. CxHxW), using a 5x5 kornia window."""
    batch1 = x1.unsqueeze(0)
    batch2 = x2.unsqueeze(0)
    ssim_map = kornia.metrics.ssim(batch1, batch2, 5)
    return ssim_map.mean()
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
    '''Plot LR, HR, SR, the squared-error map and the uncertainty map in one row.

    xLR/SR/HR: 3xHxW
    xSRvar: 1xHxW

    Args:
        elim: colour limits for the error map.
        ulim: colour limits for the uncertainty map.
    '''
    def _hwc(t, clip):
        # CxHxW -> HxWxC on the CPU, optionally clipped to [0, 1] for display.
        t = t.to('cpu').data
        if clip:
            t = t.clip(0, 1)
        return t.transpose(0, 2).transpose(0, 1)

    plt.figure(figsize=(30, 10))
    for slot, img in enumerate((xLR, xHR, xSR), start=1):
        plt.subplot(1, 5, slot)
        plt.imshow(_hwc(img, clip=True))
        plt.axis('off')

    plt.subplot(1, 5, 4)
    sq_err = torch.pow(torch.abs(xSR - xHR), 2)
    error_map = torch.mean(sq_err, dim=0).to('cpu').data.unsqueeze(0)
    print('error', error_map.min(), error_map.max())
    plt.imshow(_hwc(error_map, clip=False), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    plt.subplot(1, 5, 5)
    print('uncer', xSRvar.min(), xSRvar.max())
    plt.imshow(_hwc(xSRvar, clip=False), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
    '''Plot LR, HR, SR and the per-pixel squared-error map in one row.

    xLR/SR/HR: 3xHxW

    Args:
        elim: colour limits for the error map.
        task: ``'m'`` renders the three images in grayscale with limits (0, 0.9);
            ``'inpainting'`` multiplies the error map by ``xMask``.
        xMask: mask used only when ``task == 'inpainting'``.
    '''
    grayscale = task == 'm'
    plt.figure(figsize=(30, 10))
    for slot, img in enumerate((xLR, xHR, xSR), start=1):
        plt.subplot(1, 4, slot)
        hwc = img.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1)
        if grayscale:
            plt.imshow(hwc, cmap='gray')
            plt.clim(0, 0.9)
        else:
            plt.imshow(hwc)
        plt.axis('off')

    plt.subplot(1, 4, 4)
    error_map = torch.mean(torch.pow(torch.abs(xSR - xHR), 2), dim=0).to('cpu').data.unsqueeze(0)
    if task == 'inpainting':
        error_map = error_map * xMask.to('cpu').data
    print('error', error_map.min(), error_map.max())
    plt.imshow(error_map.transpose(0, 2).transpose(0, 1), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
    '''Plot four uncertainty maps side by side with shared colour limits.

    xSRvar: 1xHxW

    Args:
        ulim: colour limits applied to every panel.
    '''
    plt.figure(figsize=(30, 10))
    maps = (xSRvar1, xSRvar2, xSRvar3, xSRvar4)
    for slot, var_map in enumerate(maps, start=1):
        plt.subplot(1, 4, slot)
        print('uncer', var_map.min(), var_map.max())
        plt.imshow(var_map.to('cpu').data.transpose(0, 2).transpose(0, 1), cmap='hot')
        plt.clim(ulim[0], ulim[1])
        plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def get_UCE(list_err, list_yout_var, num_bins=100):
    """Uncertainty Calibration Error between per-point errors and predicted variances.

    Points are binned by error into ``num_bins`` equal-width bins spanning
    ``[min(err), max(err)]``. For each bin the mean error is compared with the
    mean predicted variance; the absolute gaps are summed, each weighted by the
    fraction of points in that bin.

    Args:
        list_err: 1-D sequence of per-point errors.
        list_yout_var: 1-D sequence of predicted variances (same length).
        num_bins (int): Number of error bins.

    Returns:
        ``(bin_stats, uce)``: ``bin_stats`` maps bin index to a dict with keys
        ``start_idx``/``end_idx`` (bin edges), ``num_points``, ``mean_err``,
        ``mean_var`` and ``uce_bin``; ``uce`` is the total calibration error.

    Raises:
        ValueError: If the inputs are empty or their lengths differ.
    """
    errs = np.asarray(list_err, dtype=np.float64).reshape(-1)
    variances = np.asarray(list_yout_var, dtype=np.float64).reshape(-1)
    if errs.size == 0:
        raise ValueError('get_UCE needs at least one error value')
    if errs.size != variances.size:
        raise ValueError('list_err and list_yout_var must have the same length')

    err_min = np.min(errs)
    err_max = np.max(errs)
    err_len = (err_max - err_min) / num_bins
    num_points = errs.size
    edges = err_min + np.arange(num_bins + 1) * err_len

    # Bin i holds edges[i] <= e < edges[i + 1]. The maximum error (which lies
    # exactly on the last edge), any rounding spill past it, and the
    # all-equal-errors case (err_len == 0) are folded into the last bin
    # instead of being silently dropped. Vectorised instead of O(n * bins).
    bin_idx = np.clip(np.searchsorted(edges, errs, side='right') - 1, 0, num_bins - 1)
    counts = np.bincount(bin_idx, minlength=num_bins)
    err_sums = np.bincount(bin_idx, weights=errs, minlength=num_bins)
    var_sums = np.bincount(bin_idx, weights=variances, minlength=num_bins)

    eps = 1e-8  # keeps empty bins at 0 instead of dividing by zero
    bin_stats = {}
    uce = 0
    for i in range(num_bins):
        n_i = int(counts[i])
        mean_err = err_sums[i] / (n_i + eps)
        mean_var = var_sums[i] / (n_i + eps)
        uce_bin = (n_i / num_points) * np.abs(mean_err - mean_var)
        bin_stats[i] = {
            'start_idx': edges[i],
            'end_idx': edges[i + 1],
            'num_points': n_i,
            'mean_err': mean_err,
            'mean_var': mean_var,
            'uce_bin': uce_bin,
        }
        uce += uce_bin
    return bin_stats, uce
##################### training BayesCap | |
def train_BayesCap(
    NetC,
    NetG,
    train_loader,
    eval_loader,
    Cri=None,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    init_lr=1e-4,
    num_epochs=100,
    eval_every=1,
    ckpt_path='../ckpt/BayesCap',
    T1=1e0,
    T2=5e-2,
    task=None,
):
    """Train the BayesCap head ``NetC`` on top of a frozen generator ``NetG``.

    Each epoch runs at most 2001 batches (indices 0..2000). NetC maps NetG's
    output to ``(mu, alpha, beta)`` and is optimised with Adam plus cosine
    annealing. The last checkpoint is saved every epoch; the checkpoint with
    the best ``eval_BayesCap`` score (mean SSIM) is saved as ``*_best.pth``.

    Args:
        NetC: BayesCap model returning ``(mu, alpha, beta)``.
        NetG: Frozen generator; for ``task='inpainting'`` it takes
            ``(x, mask)`` and returns a 2-tuple, for ``task='depth'`` its output
            is indexed with ``("disp", 0)``.
        train_loader / eval_loader: Loaders yielding ``(input, target, ...)``.
        Cri: Loss called as ``Cri(mu, alpha, beta, target, T1=..., T2=...)``;
            a fresh ``TempCombLoss()`` when ``None``.
        device: Device for both networks and the batches.
        dtype: Tensor type passed to ``.type()``. The class is used as default
            so importing this module does not construct a CUDA tensor.
        init_lr (float): Adam learning rate.
        num_epochs (int): Number of epochs (also the cosine-annealing period).
        eval_every (int): Evaluate every ``eval_every`` epochs.
        ckpt_path (str): Prefix for ``_last.pth`` / ``_best.pth`` checkpoints.
        T1, T2: Forwarded unchanged to ``Cri``.
        task: ``None``, ``'inpainting'`` or ``'depth'``.

    Returns:
        list[float]: Mean training loss of each epoch.
    """
    if Cri is None:
        Cri = TempCombLoss()
    NetC.to(device)
    NetG.to(device)
    NetG.eval()
    optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
    optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
    score = -1e8
    all_loss = []
    for eph in range(num_epochs):
        # eval_BayesCap switches NetC to eval mode; re-enable training mode
        # at the start of every epoch.
        NetC.train()
        eph_loss = 0
        num_batches = 0
        with tqdm(train_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                if idx > 2000:
                    break
                tepoch.set_description('Epoch {}'.format(eph))
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                if task == 'inpainting':
                    xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                    xMask = xMask.to(device).type(dtype)
                # NetG is frozen: no gradients flow through it.
                with torch.no_grad():
                    if task == 'inpainting':
                        _, xSR1 = NetG(xLR, xMask)
                    elif task == 'depth':
                        xSR1 = NetG(xLR)[("disp", 0)]
                    else:
                        xSR1 = NetG(xLR)
                xSR = xSR1.clone()
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
                optimizer.zero_grad()
                # For depth the reconstruction target is NetG's own output.
                target = xSR if task == 'depth' else xHR
                loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, target, T1=T1, T2=T2)
                loss.backward()
                optimizer.step()
                eph_loss += loss.item()
                num_batches += 1
                tepoch.set_postfix(loss=loss.item())
        # Average over the batches actually used (the loop may stop early).
        eph_loss /= max(num_batches, 1)
        all_loss.append(eph_loss)
        print('Avg. loss: {}'.format(eph_loss))
        # evaluate and save the models
        torch.save(NetC.state_dict(), ckpt_path + '_last.pth')
        if eph % eval_every == 0:
            curr_score = eval_BayesCap(
                NetC,
                NetG,
                eval_loader,
                device=device,
                dtype=dtype,
                task=task,
            )
            print('current score: {} | Last best score: {}'.format(curr_score, score))
            if curr_score >= score:
                score = curr_score
                torch.save(NetC.state_dict(), ckpt_path + '_best.pth')
        optim_scheduler.step()
    return all_loss
#### get different uncertainty maps | |
def get_uncer_BayesCap(
    NetC,
    NetG,
    xin,
    task=None,
    xMask=None,
):
    """Per-pixel predictive variance from BayesCap's generalized-Gaussian head.

    NetC maps NetG's output to ``(mu, alpha, beta)``; the variance is
    ``(1/alpha)^2 * Gamma(3/beta) / Gamma(1/beta)`` (with small epsilons added
    to alpha and beta), evaluated in log space for numerical stability.

    Args:
        NetC: BayesCap model returning ``(mu, alpha, beta)``.
        NetG: Generator; takes ``(x, mask)`` and returns a 2-tuple for
            ``task='inpainting'``; its output is indexed with ``("disp", 0)``
            for ``task='depth'``.
        xin: Generator input.
        task: ``None``, ``'inpainting'`` or ``'depth'``.
        xMask: Mask forwarded to NetG for inpainting.

    Returns:
        Variance tensor on the CPU (broadcast shape of alpha and beta).
    """
    with torch.no_grad():
        if task == 'inpainting':
            _, xSR = NetG(xin, xMask)
        elif task == 'depth':
            xSR = NetG(xin)[("disp", 0)]
        else:
            xSR = NetG(xin)
        _, xSRC_alpha, xSRC_beta = NetC(xSR)
    a_map = (1 / (xSRC_alpha + 1e-5)).to('cpu').data
    b_map = xSRC_beta.to('cpu').data
    # Gamma(3/b) / Gamma(1/b) as exp of the lgamma difference: exponentiating
    # each term separately overflows for small beta and yields inf/inf = nan.
    gamma_ratio = torch.exp(torch.lgamma(3 / (b_map + 1e-2)) - torch.lgamma(1 / (b_map + 1e-2)))
    return (a_map ** 2) * gamma_ratio
def get_uncer_TTDAp( | |
NetG, | |
xin, | |
p_mag=0.05, | |
num_runs=50, | |
task=None, | |
xMask=None, | |
): | |
list_xSR = [] | |
with torch.no_grad(): | |
for z in range(num_runs): | |
if task == 'inpainting': | |
_, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask) | |
else: | |
xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin)) | |
list_xSR.append(xSRz) | |
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0) | |
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1) | |
return xSRvar | |
def get_uncer_DO( | |
NetG, | |
xin, | |
dop=0.2, | |
num_runs=50, | |
task=None, | |
xMask=None, | |
): | |
list_xSR = [] | |
with torch.no_grad(): | |
for z in range(num_runs): | |
if task == 'inpainting': | |
_, xSRz = NetG(xin, xMask, dop=dop) | |
else: | |
xSRz = NetG(xin, dop=dop) | |
list_xSR.append(xSRz) | |
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0) | |
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1) | |
return xSRvar | |
################### Different eval functions | |
def eval_BayesCap(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
    xMask=None,
):
    """Evaluate BayesCap's reconstruction and display its uncertainty maps.

    Prints the average SSIM / PSNR / MSE / MAE between BayesCap's mean and
    the target (NetG's own output for ``task='depth'``), and shows a
    ``show_SR_w_uncer`` figure for every evaluated image.

    Args:
        NetC: BayesCap model returning ``(mu, alpha, beta)``.
        NetG: Generator (see ``train_BayesCap`` for task-specific calling).
        eval_loader: Loader yielding ``(input, target, ...)``.
        device: Device for the networks and batches.
        dtype: Tensor type passed to ``.type()``.
        task: ``None``, ``'inpainting'`` or ``'depth'``.
        xMask: Fixed inpainting mask; when ``None`` a fresh random mask is
            drawn for every batch.

    Returns:
        Mean SSIM over all evaluated images.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()
    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            if task == 'inpainting':
                # Use a local so a drawn mask is not reused for later batches
                # (which also broke on a smaller final batch).
                if xMask is None:
                    mask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                else:
                    mask = xMask
                mask = mask.to(device).type(dtype)
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, mask)
                elif task == 'depth':
                    xSR = NetG(xLR)[("disp", 0)]
                else:
                    xSR = NetG(xLR)
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
            a_map = (1 / (xSRC_alpha + 1e-5)).to('cpu').data
            b_map = xSRC_beta.to('cpu').data
            # Gamma(3/b) / Gamma(1/b) via one exp of the lgamma difference
            # (separate exps overflow to inf/inf = nan for small beta).
            xSRvar = (a_map ** 2) * torch.exp(
                torch.lgamma(3 / (b_map + 1e-2)) - torch.lgamma(1 / (b_map + 1e-2)))
            n_batch = xSRC_mu.shape[0]
            if task == 'depth':
                xHR = xSR
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
                mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
                mean_mse += img_mse(xSRC_mu[j], xHR[j])
                mean_mae += img_mae(xSRC_mu[j], xHR[j])
                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
def eval_TTDA_p(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    p_mag=0.05,
    num_runs=50,
    task = None,
    xMask = None,
):
    """Evaluate NetG and display test-time-augmentation uncertainty maps.

    Prints the average SSIM / PSNR / MSE / MAE between NetG's output and the
    target. The uncertainty map is the per-image variance of NetG over
    ``num_runs`` inputs perturbed by Gaussian noise scaled by
    ``p_mag * xLR.max()``, shown with ``show_SR_w_uncer`` for every image.

    Args:
        NetG: Generator; takes ``(x, mask)`` and returns a 2-tuple for
            ``task='inpainting'``.
        eval_loader: Loader yielding ``(input, target, ...)``.
        device: Device for NetG and the batches.
        dtype: Tensor type passed to ``.type()``.
        p_mag (float): Relative noise magnitude.
        num_runs (int): Number of perturbed forward passes per batch.
        task: ``'inpainting'`` to pass ``xMask`` to every NetG call.
        xMask: Mask forwarded to NetG for inpainting.

    Returns:
        Mean SSIM over all evaluated images.
    """
    NetG.to(device)
    NetG.eval()
    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            list_xSR = []
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                for _ in range(num_runs):
                    noisy = xLR + p_mag * xLR.max() * torch.randn_like(xLR)
                    # Perturbed runs must call NetG the same way as the clean one.
                    if task == 'inpainting':
                        _, xSRz = NetG(noisy, xMask)
                    else:
                        xSRz = NetG(noisy)
                    list_xSR.append(xSRz)
            # Per-image variance across runs, averaged over channels -> (B, 1, H, W).
            xSRvar = torch.var(torch.stack(list_xSR, dim=0), dim=0).mean(dim=1, keepdim=True)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])
                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
def eval_DO(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Evaluate NetG and display dropout-based uncertainty maps.

    Prints the average SSIM / PSNR / MSE / MAE between NetG's output and the
    target. The uncertainty map is the per-image variance of
    ``NetG(..., dop=dop)`` over ``num_runs`` passes, shown with
    ``show_SR_w_uncer`` for every image.

    Args:
        NetG: Generator accepting a ``dop`` keyword; takes ``(x, mask)`` and
            returns a 2-tuple for ``task='inpainting'``.
        eval_loader: Loader yielding ``(input, target, ...)``.
        device: Device for NetG and the batches.
        dtype: Tensor type passed to ``.type()``.
        dop (float): Value forwarded as ``NetG(..., dop=dop)``.
        num_runs (int): Number of forward passes per batch.
        task: ``'inpainting'`` to pass ``xMask`` to every NetG call.
        xMask: Mask forwarded to NetG for inpainting.

    Returns:
        Mean SSIM over all evaluated images.
    """
    NetG.to(device)
    NetG.eval()
    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            list_xSR = []
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                for _ in range(num_runs):
                    # Stochastic runs must call NetG the same way as the clean one.
                    if task == 'inpainting':
                        _, xSRz = NetG(xLR, xMask, dop=dop)
                    else:
                        xSRz = NetG(xLR, dop=dop)
                    list_xSR.append(xSRz)
            # Per-image variance across runs, averaged over channels -> (B, 1, H, W).
            xSRvar = torch.var(torch.stack(list_xSR, dim=0), dim=0).mean(dim=1, keepdim=True)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])
                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
############### compare all function | |
def compare_all(
    NetC,
    NetG,
    eval_loader,
    p_mag = 0.05,
    dop = 0.2,
    num_runs = 100,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
):
    """Visually compare TTDA, dropout and BayesCap uncertainty maps per image.

    For every batch the three estimates are computed with
    ``get_uncer_TTDAp`` (1), ``get_uncer_DO`` (2) and ``get_uncer_BayesCap``
    (3), then plotted per image. The power transforms applied before plotting
    are display-only tweaks. The ``task`` codes select the layout:
    ``'s'``, ``'d'``, ``'inpainting'`` and ``'m'`` (grayscale images);
    presumably super-resolution / deblurring / inpainting / medical — confirm.
    Any other value plots nothing.

    Args:
        NetC: BayesCap model returning ``(mu, alpha, beta)``.
        NetG: Generator (also used with a ``dop`` keyword for dropout runs).
        eval_loader: Loader yielding ``(input, target, ...)``.
        p_mag (float): Noise magnitude for TTDA.
        dop (float): ``dop`` value for dropout runs.
        num_runs (int): Forward passes for TTDA and dropout.
        device: Device for the networks and batches.
        dtype: Tensor type passed to ``.type()``.
        task: Layout/task code (see above).
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()
    with tqdm(eval_loader, unit='batch') as tepoch:
        for (idx, batch) in enumerate(tepoch):
            tepoch.set_description('Comparing ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            uncer_kwargs = {}
            if task == 'inpainting':
                xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                xMask = xMask.to(device).type(dtype)
                uncer_kwargs = {'task': 'inpainting', 'xMask': xMask}
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
            # get_uncer_BayesCap runs NetC itself, so no separate NetC pass here.
            xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, **uncer_kwargs)
            xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, **uncer_kwargs)
            xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, **uncer_kwargs)
            print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                if task == 's':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
                elif task == 'd':
                    show_SR_w_err(xLR[j], xHR[j], 0.5 * xSR[j] + 0.5 * xHR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                elif task == 'inpainting':
                    show_SR_w_err(xLR[j] * (1 - xMask[j]), xHR[j], xSR[j], elim=(0, 0.25), task='inpainting', xMask=xMask[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                elif task == 'm':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0, 0.04), task='m')
                    show_uncer4(0.4 * xSRvar1[j] + 0.6 * xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02, 0.15))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02, 0.15))
################# Degrading Identity | |
def degrage_BayesCap_p(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    num_runs=50,
):
    """Degrading-identity experiment for BayesCap.

    For each noise level ``p_mag`` in ``[0, 0.05, 0.1, 0.15, 0.2]``, Gaussian
    noise scaled by ``p_mag * xSR.max()`` is added to NetG's output before it
    is fed to NetC. For each level, the function:

    * prints SSIM/PSNR/MSE/MAE between BayesCap's mean and NetG's output;
    * prints UCE of BayesCap's variance against (1) NetG's squared error and
      (2) BayesCap's own squared error w.r.t. the ground truth, using every
      100th pixel and 200 bins.

    It then plots how these quantities change with the noise level.

    Args:
        NetC: BayesCap model returning ``(mu, alpha, beta)``.
        NetG: Generator called as ``NetG(x)``.
        eval_loader: Loader yielding ``(input, target, ...)``.
        device: Device for the networks and batches.
        dtype: Tensor type passed to ``.type()``.
        num_runs: Unused; kept for signature compatibility.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()
    p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
    list_s = []
    list_p = []
    list_u1 = []
    list_u2 = []
    for p_mag in p_mag_list:
        mean_ssim = 0
        mean_psnr = 0
        mean_mse = 0
        mean_mae = 0
        num_imgs = 0
        list_error = []
        list_error2 = []
        list_var = []
        with tqdm(eval_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                tepoch.set_description('Validating ...')
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                with torch.no_grad():
                    xSR = NetG(xLR)
                    xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag * xSR.max() * torch.randn_like(xSR))
                a_map = (1 / (xSRC_alpha + 1e-5)).to('cpu').data
                b_map = xSRC_beta.to('cpu').data
                # Gamma(3/b) / Gamma(1/b) via one exp of the lgamma difference
                # (separate exps overflow to inf/inf = nan for small beta).
                xSRvar = (a_map ** 2) * torch.exp(
                    torch.lgamma(3 / (b_map + 1e-2)) - torch.lgamma(1 / (b_map + 1e-2)))
                n_batch = xSRC_mu.shape[0]
                for j in range(n_batch):
                    num_imgs += 1
                    mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
                    mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
                    mean_mse += img_mse(xSRC_mu[j], xSR[j])
                    mean_mae += img_mae(xSRC_mu[j], xSR[j])
                    error_map = torch.mean(torch.pow(torch.abs(xSR[j] - xHR[j]), 2), dim=0).to('cpu').data.reshape(-1)
                    error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j] - xHR[j]), 2), dim=0).to('cpu').data.reshape(-1)
                    var_map = xSRvar[j].to('cpu').data.reshape(-1)
                    list_error.extend(list(error_map.numpy()))
                    list_error2.extend(list(error_map2.numpy()))
                    list_var.extend(list(var_map.numpy()))
        mean_ssim /= num_imgs
        mean_psnr /= num_imgs
        mean_mse /= num_imgs
        mean_mae /= num_imgs
        print(
            'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
                mean_ssim, mean_psnr, mean_mse, mean_mae
            )
        )
        uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
        uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
        print('UCE1: ', uce1)
        print('UCE2: ', uce2)
        list_s.append(mean_ssim.item())
        list_p.append(mean_psnr.item())
        list_u1.append(uce1)
        list_u2.append(uce2)

    plt.plot(list_s)
    plt.show()
    plt.plot(list_p)
    plt.show()
    plt.plot(list_u1, label='wrt SR output')
    plt.plot(list_u2, label='wrt BayesCap output')
    plt.legend()
    plt.show()

    sns.set_style('darkgrid')
    fig, ax = plt.subplots()
    # SSIM on the left axis, both UCE curves on a twin right axis.
    ax.plot(p_mag_list, list_s, color="red", marker="o")
    ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction", fontsize=10)
    ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red", fontsize=10)
    ax2 = ax.twinx()
    ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
    ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
    ax2.set_ylabel("UCE", color="green", fontsize=10)
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.show()
################# DeepFill_v2 | |
# ---------------------------------------- | |
# PATH processing | |
# ---------------------------------------- | |
def text_readlines(filename):
    """Read a text file and return its lines with the trailing newline removed.

    Args:
        filename: path of the text file to read.

    Returns:
        A list of strings, one per line. Returns ``[]`` if the file cannot be
        opened (``IOError``), matching the previous best-effort behaviour.
    """
    try:
        file = open(filename, 'r')
    except IOError:
        return []
    with file:
        # Strip only an actual newline. The old ``line[:len(line) - 1]`` slice
        # also chopped the last real character of a final line that had no
        # terminating newline.
        return [line.rstrip('\n') for line in file]
def savetxt(name, loss_log):
    """Dump ``loss_log`` to the text file ``name`` using ``np.savetxt``.

    ``loss_log`` is converted with ``np.array`` first, so any sequence of
    numbers (or nested sequence for a 2-D table) is accepted.
    """
    np.savetxt(name, np.array(loss_log))
def get_files(path):
    """Return the full path of every file found recursively under ``path``.

    Paths are produced in ``os.walk`` order; a missing directory yields ``[]``.
    """
    return [
        os.path.join(folder, file_name)
        for folder, _subdirs, file_names in os.walk(path)
        for file_name in file_names
    ]
def get_names(path):
    """Return the bare file name of every file found recursively under ``path``.

    Directory components are dropped; names are produced in ``os.walk`` order.
    """
    return [
        file_name
        for _folder, _subdirs, file_names in os.walk(path)
        for file_name in file_names
    ]
def text_save(content, filename, mode = 'a'):
    """Write each item of ``content`` to ``filename`` as its own line.

    Args:
        content: sequence of items; each is written as ``str(item)`` followed
            by a newline.
        filename: path of the output text file.
        mode: mode passed to ``open``; the default ``'a'`` appends to an
            existing file, ``'w'`` overwrites it.
    """
    # Context manager guarantees the file is closed even if str()/write raises.
    with open(filename, mode) as file:
        for item in content:
            file.write(str(item) + '\n')
def check_path(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    An already-existing ``path`` is left untouched. The ``FileExistsError`` is
    caught instead of pre-checking ``os.path.exists``, so a concurrent creation
    of the same directory no longer makes this call fail.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        pass
# ---------------------------------------- | |
# Validation and Sample at training | |
# ---------------------------------------- | |
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
    """Save the first batch element of each tensor in ``img_list`` as a JPEG.

    Each image is written to ``<sample_folder>/<sample_name>_<name_list[i]>.jpg``.

    Args:
        sample_folder: output directory.
        sample_name: file-name prefix shared by all images.
        img_list: 4-D tensors laid out (N, C, H, W); only index 0 of the batch
            is saved. Values are scaled by 255, so inputs are presumably in
            [0, 1] (sigmoid output).
        name_list: per-image suffix, indexed in parallel with ``img_list``.
        pixel_max_cnt: upper clip bound applied after scaling.
    """
    for idx, img in enumerate(img_list):
        # Scale [0, 1] -> [0, 255], then take an HWC numpy copy of sample 0
        # without touching the caller's tensor.
        scaled = img * 255
        pixels = scaled.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
        pixels = np.clip(pixels, 0, pixel_max_cnt).astype(np.uint8)
        # OpenCV writes BGR; the tensor is assumed to be RGB (3 channels).
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        file_name = sample_name + '_' + name_list[idx] + '.jpg'
        cv2.imwrite(os.path.join(sample_folder, file_name), bgr)
def psnr(pred, target, pixel_max_cnt = 255):
    """Peak signal-to-noise ratio, in dB, between two tensors.

    PSNR = 20 * log10(pixel_max_cnt / RMSE), with the RMSE taken over all
    elements.

    Args:
        pred: predicted tensor.
        target: reference tensor, broadcast-compatible with ``pred``.
        pixel_max_cnt: maximum possible pixel value (255 for 8-bit scale).

    Returns:
        The PSNR as a float. Returns ``inf`` when the tensors are identical;
        previously the zero RMSE raised ``ZeroDivisionError``.
    """
    diff = target - pred
    rmse_avg = torch.mean(diff * diff).item() ** 0.5
    if rmse_avg == 0:
        return float('inf')
    return 20 * np.log10(pixel_max_cnt / rmse_avg)
def grey_psnr(pred, target, pixel_max_cnt = 255):
    """PSNR, in dB, after collapsing dimension 0 of both tensors by summation.

    Dimension 0 is presumably the channel axis of a (C, H, W) image (TODO
    confirm against callers). The peak value is ``pixel_max_cnt * 3``, which
    assumes exactly three summed channels.

    Returns:
        The PSNR as a float. Returns ``inf`` when the summed tensors are
        identical; previously the zero RMSE raised ``ZeroDivisionError``.
    """
    diff = torch.sum(target, dim = 0) - torch.sum(pred, dim = 0)
    rmse_avg = torch.mean(diff * diff).item() ** 0.5
    if rmse_avg == 0:
        return float('inf')
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
def ssim(pred, target, data_range=2.0):
    """Structural similarity between the first images of two batches.

    Both inputs are 4-D tensors laid out (N, C, H, W); only batch index 0 is
    compared. SSIM is computed per channel and averaged. This is exactly what
    scikit-image's old ``multichannel=True`` mode did, and it works across
    scikit-image versions.

    Args:
        pred: predicted image batch.
        target: reference image batch.
        data_range: dynamic range passed to scikit-image. The default of 2.0
            reproduces the old implicit default for float inputs (dtype range
            -1..1). Pass 1.0 for images in [0, 1].

    Returns:
        Mean SSIM over channels.
    """
    # ``skimage`` was previously used without ever being imported (NameError),
    # and ``skimage.measure.compare_ssim`` no longer exists in scikit-image >= 0.18.
    try:
        from skimage.metrics import structural_similarity
    except ImportError:  # scikit-image < 0.16 only provides the old name
        from skimage.measure import compare_ssim as structural_similarity
    pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    # Recent scikit-image raises for float images without an explicit data_range.
    per_channel = [
        structural_similarity(target[..., c], pred[..., c], data_range=data_range)
        for c in range(target.shape[-1])
    ]
    return np.mean(per_channel)
## for contextual attention | |
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """Slide a (dilated) window over ``images`` and stack every patch as a column.

    Args:
        images: 4-D tensor laid out [batch, channels, rows, cols].
        ksizes: [ksize_rows, ksize_cols], window size per spatial dimension.
        strides: [stride_rows, stride_cols].
        rates: [dilation_rows, dilation_cols].
        padding: ``'same'`` zero-pads first via ``same_padding``, so the number
            of windows matches TensorFlow SAME padding; ``'valid'`` uses the
            input as is.

    Returns:
        Tensor of shape [N, C*k_rows*k_cols, L], where L is the number of windows.
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    if padding == 'valid':
        source = images
    elif padding == 'same':
        source = same_padding(images, ksizes, strides, rates)
    else:
        raise NotImplementedError('Unsupported padding type: {}.\
Only "same" or "valid" are supported.'.format(padding))
    # Padding has already been applied above, so Unfold itself pads nothing.
    patch_extractor = torch.nn.Unfold(
        kernel_size=ksizes,
        dilation=rates,
        padding=0,
        stride=strides,
    )
    return patch_extractor(source)
def same_padding(images, ksizes, strides, rates):
    """Zero-pad a 4-D [N, C, rows, cols] tensor like TensorFlow ``padding='SAME'``.

    The total padding per spatial dimension is the amount needed for
    ceil(size / stride) windows of the dilated kernel to fit. It is split
    as evenly as possible, with any odd pixel going to the bottom/right side.
    """
    assert len(images.size()) == 4
    rows, cols = images.size()[2], images.size()[3]

    def _split(size, ksize, stride, rate):
        # Return (before, after) padding for one spatial dimension.
        n_out = (size + stride - 1) // stride
        span = (ksize - 1) * rate + 1  # extent of the dilated kernel
        total = max(0, (n_out - 1) * stride + span - size)
        before = total // 2
        return before, total - before

    top, bottom = _split(rows, ksizes[0], strides[0], rates[0])
    left, right = _split(cols, ksizes[1], strides[1], rates[1])
    # ZeroPad2d expects (left, right, top, bottom).
    return torch.nn.ZeroPad2d((left, right, top, bottom))(images)
def _reduce_dims(x, axis):
    """Normalise ``axis`` into a list of non-negative dims of ``x``, highest first.

    ``axis`` may be ``None`` or an empty sequence (meaning all dims), a single
    int, or a sequence of ints (negative values count from the end). An int
    ``0`` used to be treated as "all dims", because ``not 0`` is true, and any
    other int crashed inside ``sorted``.
    """
    ndim = x.dim()
    if isinstance(axis, int):
        axis = [axis]
    if not axis:
        axis = range(ndim)
    # Highest first, so reducing one dim (keepdim=False) never shifts the
    # indices of the dims still to be reduced.
    return sorted((d % ndim if ndim else d for d in axis), reverse=True)


def reduce_mean(x, axis=None, keepdim=False):
    """Mean of ``x`` over ``axis`` (all dims if ``None``), like ``tf.reduce_mean``."""
    for dim in _reduce_dims(x, axis):
        x = torch.mean(x, dim=dim, keepdim=keepdim)
    return x


def reduce_std(x, axis=None, keepdim=False):
    """Standard deviation of ``x`` jointly over ``axis`` (all dims if ``None``).

    Uses ``torch.std``'s default (Bessel-corrected) estimator. The deviation is
    taken over all selected dims at once; the previous implementation chained
    per-dim ``torch.std`` calls, which computes a std of stds rather than the
    std over those axes.
    """
    dims = _reduce_dims(x, axis)
    if not dims:
        return x
    return torch.std(x, dim=tuple(dims), keepdim=keepdim)


def reduce_sum(x, axis=None, keepdim=False):
    """Sum of ``x`` over ``axis`` (all dims if ``None``), like ``tf.reduce_sum``."""
    for dim in _reduce_dims(x, axis):
        x = torch.sum(x, dim=dim, keepdim=keepdim)
    return x
def _random_rect_mask(image_height, image_width):
    """Return a (1, 1, H, W) float mask with one randomly placed box set to 1."""
    max_delta_height = image_height // 8
    max_delta_width = image_width // 8
    height = image_height // 4
    width = image_width // 4
    t = random.randint(0, image_height - height)
    l = random.randint(0, image_width - width)
    # Shrink the box symmetrically by a random margin on each axis.
    h = random.randint(0, max_delta_height // 2)
    w = random.randint(0, max_delta_width // 2)
    mask = torch.zeros((1, 1, image_height, image_width))
    mask[:, :, t + h:t + height - h, l + w:l + width - w] = 1
    return mask


def _random_brush_mask(H, W):
    """Return a (1, 1, H, W) float mask of 1-3 random free-form brush strokes."""
    min_num_vertex = 4
    max_num_vertex = 12
    mean_angle = 2 * math.pi / 5
    angle_range = 2 * math.pi / 15
    min_width = 12
    max_width = 40
    average_radius = math.sqrt(H * H + W * W) / 8
    mask = Image.new('L', (W, H), 0)
    draw = ImageDraw.Draw(mask)
    for _ in range(np.random.randint(1, 4)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        for i in range(num_vertex):
            if i % 2 == 0:
                angles.append(2 * math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))
        # Vertices are PIL (x, y) points: x is bounded by W, y by H. The old code
        # unpacked ``mask.size`` (which is (width, height)) as ``h, w``, swapping
        # these bounds for non-square masks.
        vertex = [(int(np.random.randint(0, W)), int(np.random.randint(0, H)))]
        for i in range(num_vertex):
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius // 2),
                0, 2 * average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, W)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, H)
            vertex.append((int(new_x), int(new_y)))
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=255, width=width)
        # Round the joints/ends of the stroke with filled circles.
        for v in vertex:
            draw.ellipse((v[0] - width // 2,
                          v[1] - width // 2,
                          v[0] + width // 2,
                          v[1] + width // 2),
                         fill=255)
    # Image.transpose returns a new image. The old code discarded the result,
    # so these random flips were silently no-ops.
    if np.random.normal() > 0:
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    if np.random.normal() > 0:
        mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    # ToTensor maps the 0/255 'L' image to a 0.0/1.0 float tensor.
    return transforms.ToTensor()(mask).reshape((1, 1, H, W))


def random_mask(num_batch=1, mask_shape=(256,256)):
    """Generate random inpainting masks combining a box and brush strokes.

    Args:
        num_batch: number of masks to generate.
        mask_shape: (height, width) of each mask.

    Returns:
        Float tensor of shape (num_batch, 1, height, width). A pixel is 1 where
        either the random rectangle or a brush stroke covers it, and 0 elsewhere.
    """
    image_height, image_width = mask_shape[0], mask_shape[1]
    list_mask = []
    for _ in range(num_batch):
        rect_mask = _random_rect_mask(image_height, image_width)
        brush_mask = _random_brush_mask(image_height, image_width)
        # Union of both masks: per-pixel max over the stacked channel axis.
        list_mask.append(
            torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0])
    if not list_mask:
        # torch.cat([]) would raise; return an empty batch of the right shape.
        return torch.zeros((0, 1, image_height, image_width))
    return torch.cat(list_mask, dim=0)