cartoonize / utils /image_processing.py
YANGYYYY's picture
Upload 8 files
922e494 verified
raw
history blame
No virus
3.08 kB
import torch
import cv2
import os
import numpy as np
from tqdm import tqdm
def gram(input):
    """
    Calculate the Gram matrix of a batch of feature maps.

    The 4-D input is flattened to a 2-D matrix with one row per
    (sample, channel) pair, so the result has shape (b*c, b*c); for b > 1
    it therefore also contains cross-sample products. Entries are clamped
    to +/-64990 and then divided by the total number of input elements.

    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss
    """
    batch, channels, height, width = input.size()
    features = input.contiguous().view(batch * channels, height * width)

    gram_matrix = torch.mm(features, features.t())
    # torch.mm can overflow to inf in mixed precision; the bound sits just
    # below the float16 maximum (65504). See
    # https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
    gram_matrix = gram_matrix.clamp(min=-64990.0, max=64990.0)

    # Normalize by the total number of input elements.
    return gram_matrix / (batch * channels * height * width)
def divisible(dim, multiple=32):
    '''
    Round a (width, height) pair down to the nearest multiple of `multiple`.

    @Arguments:
        - dim: (width, height) pair.
        - multiple: positive step both sides are rounded down to
          (defaults to 32, the value this function always used).
    @Returns
        - (width, height) tuple; a side smaller than `multiple` becomes 0.
    @Raises
        - ValueError: if `multiple` is not positive (a negative step would
          round *up*, and zero would divide by zero).
    '''
    if multiple <= 0:
        raise ValueError(f'multiple must be positive, got {multiple}')
    width, height = dim
    return width - width % multiple, height - height % multiple
def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
    '''
    Resize `image` with cv2, snapping the output size down to multiples of 32.

    @Arguments:
        - image: array whose first two axes are (height, width).
        - width, height: requested output size. If both are truthy they are
          used directly; if both are None the original size is kept; if only
          one is given the other is derived to preserve the aspect ratio.
        - inter: cv2 interpolation flag.
    @Returns
        - the resized image (final size passed through `divisible`).
    '''
    src_h, src_w = image.shape[:2]

    if width and height:
        target = (width, height)
    elif width is None and height is None:
        target = (src_w, src_h)
    elif width is None:
        # Only height given: scale width by the same factor.
        scale = height / float(src_h)
        target = (int(src_w * scale), height)
    else:
        # Only width given: scale height by the same factor.
        scale = width / float(src_w)
        target = (width, int(src_h * scale))

    return cv2.resize(image, divisible(target), interpolation=inter)
def normalize_input(images):
    '''
    Map pixel values from [0, 255] to [-1, 1].

    Works on anything supporting `/` and `-` with Python floats
    (NumPy arrays, torch tensors, plain numbers) and returns a new value.
    '''
    scaled = images / 127.5
    return scaled - 1.0
def denormalize_input(images, dtype=None):
    '''
    Map values from [-1, 1] back to [0, 255].

    @Arguments:
        - images: torch.Tensor or numpy.ndarray (or anything supporting
          `*` and `+` with floats).
        - dtype: optional target dtype; applied with `Tensor.type` for
          tensors and `ndarray.astype` otherwise. No rounding or clipping
          is done before the cast.
    @Returns
        - the rescaled values, cast to `dtype` if one was given.
    '''
    restored = images * 127.5 + 127.5
    if dtype is None:
        return restored
    if isinstance(restored, torch.Tensor):
        return restored.type(dtype)
    # numpy.ndarray
    return restored.astype(dtype)
def preprocess_images(images):
    '''
    Convert channel-last image data into a normalized channel-first tensor.

    @Arguments:
        - images: np.ndarray of rank 3 (single image) or rank 4 (batch),
          channel-last; values expected in [0, 255].
    @Returns
        - images: float32 torch.Tensor of shape (N, C, H, W) in [-1, 1].
    '''
    tensor = torch.from_numpy(normalize_input(images.astype(np.float32)))

    # A single image gets a leading batch axis.
    if tensor.dim() == 3:
        tensor = tensor[None]

    # Channel-last -> channel-first.
    return tensor.permute(0, 3, 1, 2)
def compute_data_mean(data_folder):
    '''
    Compute per-channel mean offsets over all readable images in a folder.

    Every entry of `data_folder` is passed to cv2.imread. Entries cv2 cannot
    decode (hidden files, subdirectories, non-images) are skipped rather
    than crashing, and the average is taken over the images actually read.

    @Arguments:
        - data_folder: path of a directory containing images.
    @Returns
        - np.ndarray of 3 values, in (R, G, B) order: the mean of the three
          channel means minus each individual channel mean.
    @Raises
        - FileNotFoundError: if `data_folder` does not exist.
        - ValueError: if the folder is empty or holds no readable images
          (previously this silently produced NaN).
    '''
    if not os.path.exists(data_folder):
        raise FileNotFoundError(f'Folder {data_folder} does not exist')

    image_files = os.listdir(data_folder)
    if not image_files:
        raise ValueError(f'Folder {data_folder} contains no files')

    total = np.zeros(3)
    n_read = 0
    print(f"Compute mean (R, G, B) from {len(image_files)} images")
    for img_file in tqdm(image_files):
        path = os.path.join(data_folder, img_file)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None instead of raising for unreadable files.
            continue
        total += image.mean(axis=(0, 1))
        n_read += 1

    if n_read == 0:
        raise ValueError(f'No readable images found in {data_folder}')
    skipped = len(image_files) - n_read
    if skipped:
        print(f"Skipped {skipped} unreadable file(s)")

    channel_mean = total / n_read
    mean = np.mean(channel_mean)
    # cv2.imread loads BGR, so channel_mean is (B, G, R); reversing it
    # returns the offsets in (R, G, B) order.
    return mean - channel_mean[..., ::-1]
if __name__ == '__main__':
    # Smoke test: compute a Gram matrix under CPU autocast (mixed precision).
    sample = torch.rand(2, 14, 32, 32)
    with torch.autocast("cpu"):
        print(gram(sample))