# Ohayou_Face / util / utilgan.py
# Reevee's picture
# first
# f39e999
import os
import sys
import time
import math
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.interpolate import CubicSpline as CubSpline
from scipy.special import comb
import scipy
from imageio import imread
import torch
import torch.nn.functional as F
# from perlin import PerlinNoiseFactory as Perlin
# noise = Perlin(1)
# def latent_noise(t, dim, noise_step=78564.543):
# latent = np.zeros((1, dim))
# for i in range(dim):
# latent[0][i] = noise(t + i * noise_step)
# return latent
def load_latents(npy_file):
    """Load key latents from a .npy/.npz file, optionally selected by an index file.

    If ``npy_file`` is an .npz archive, its first stored array is used.
    If a sibling file with the same stem and a ``.txt`` extension exists, its
    first line is parsed as comma-separated integer indices (e.g. ``0,3,5``),
    and only those latents are returned, in that order.

    Args:
        npy_file: path to a .npy or .npz file.

    Returns:
        np.ndarray of the (optionally selected) latents.
    """
    data = np.load(npy_file)
    if hasattr(data, 'files'):
        # .npz archive: read the first stored array, then close the archive handle.
        with data:
            key_latents = data[data.files[0]]
    else:
        key_latents = data

    idx_file = os.path.splitext(npy_file)[0] + '.txt'
    if os.path.exists(idx_file):
        with open(idx_file) as f:
            line = f.readline()
        # Accept "0,3,5", "0,3,5\n" and "0,3,5,\n": skip only blank tokens
        # (the trailing newline is whitespace and must not drop the last index).
        lat_idx = [int(tok.strip()) for tok in line.split(',') if tok.strip()]
        key_latents = [key_latents[i] for i in lat_idx]
    return np.asarray(key_latents)
# = = = = = = = = = = = = = = = = = = = = = = = = = = =
def get_z(shape, seed=None, uniform=False):
    """Draw a random latent array from a private RandomState.

    Args:
        shape: output shape, a sequence of ints (a bare int is treated as (int,)).
        seed: seed for ``np.random.RandomState``; None seeds from fresh OS entropy.
        uniform: if True sample uniformly from [0, 1), otherwise standard normal.

    Returns:
        np.ndarray of the requested shape.

    The global NumPy RNG is left untouched.
    """
    # Do not call np.random.seed() here: it reseeds the *global* RNG and returns
    # None, so assigning its result never produced a seed anyway.
    rnd = np.random.RandomState(seed)
    if np.ndim(shape) == 0:
        shape = (int(shape),)
    if uniform:
        return rnd.uniform(0., 1., shape)
    return rnd.randn(*shape)  # *shape unpacks the sequence into positional dims
def smoothstep(x, NN=1., xmin=0., xmax=1.):
    """Generalised smoothstep easing of x over [xmin, xmax].

    The input is normalised to [0, 1] and clipped. The polynomial order is
    ceil(NN); NN=1 gives the classic 3t^2 - 2t^3. If NN is fractional, the
    result is averaged with the linear ramp t, which softens the easing.
    """
    order = math.ceil(NN)
    t = np.clip((x - xmin) / (xmax - xmin), 0, 1)
    poly = sum(
        comb(order + k, k) * comb(2 * order + 1, order - k) * (-t) ** k
        for k in range(order + 1)
    )
    poly = poly * t ** (order + 1)
    if NN != order:
        poly = (t + poly) / 2
    return poly
def lerp(z1, z2, num_steps, smooth=0.):
    """Linearly interpolate from z1 to z2 in num_steps frames, both ends included.

    Args:
        z1, z2: start and end arrays of the same shape.
        num_steps: number of output frames; 1 returns just z1, 0 returns an empty array.
        smooth: if > 0, the interpolation weights are eased with smoothstep(x, smooth).

    Returns:
        np.ndarray stacking one interpolated array per frame.
    """
    if num_steps == 1:
        xs = [0.]  # a single frame is the start point; avoids 1 / (num_steps - 1) == 1 / 0
    else:
        xs = [step / (num_steps - 1) for step in range(num_steps)]
    if smooth > 0:
        xs = [smoothstep(x, smooth) for x in xs]
    return np.array([z1 + (z2 - z1) * x for x in xs])
# Interpolate with norm compensation, approximating travel on the hypersphere (not an exact geodesic slerp)
def slerp(z1, z2, num_steps, smooth=0.):
    """Norm-compensated interpolation between z1 and z2 (called "slerp" here).

    This is not a geodesic slerp. Each frame is the plain lerp between z1 and
    z2, rescaled by ``|z1| / |lerp(z1, z2')|``, where ``z2'`` is z2 rescaled to
    the norm of z1. This presumably counteracts the norm dip in the middle of a
    straight-line interpolation between latents. The first frame is z1 and the
    last is z2.

    Args:
        z1, z2: start and end arrays of the same shape (z2 must have non-zero norm).
        num_steps: number of output frames including both ends; 1 returns just z1.
        smooth: if > 0, the interpolation weights are eased with smoothstep(x, smooth).

    Returns:
        np.ndarray stacking one interpolated array per frame.
    """
    z1_norm = np.linalg.norm(z1)
    z2_norm = np.linalg.norm(z2)
    z2_normal = z2 * (z1_norm / z2_norm)  # z2 rescaled to z1's norm

    if num_steps == 1:
        xs = [0.]  # a single frame is the start point; avoids 1 / (num_steps - 1) == 1 / 0
    else:
        xs = [step / (num_steps - 1) for step in range(num_steps)]
    if smooth > 0:
        xs = [smoothstep(x, smooth) for x in xs]

    vectors = []
    for x in xs:
        interplain = z1 + (z2 - z1) * x
        interp = z1 + (z2_normal - z1) * x
        interp_norm = np.linalg.norm(interp)
        # Scale the plain lerp (not the equal-norm one) so the endpoints stay exactly z1 and z2.
        vectors.append(interplain * (z1_norm / interp_norm))
    return np.array(vectors)
def cublerp(points, steps, fstep):
    """Closed-loop cubic-spline interpolation through key points.

    The first key point is appended after the last, so the path returns to its
    start. Key point k sits at frame k * fstep. The spline is sampled at every
    integer frame from 0 to steps * fstep inclusive.

    Args:
        points: key points stacked along the first axis (presumably exactly
            ``steps`` of them -- TODO confirm against callers).
        steps: number of segments between key points.
        fstep: frames per segment.

    Returns:
        np.ndarray with steps * fstep + 1 frames.
    """
    closed = np.concatenate((points, points[:1]))
    knots = np.arange(steps + 1) * fstep
    spline = CubSpline(knots, closed)
    return spline(range(steps * fstep + 1))
# = = = = = = = = = = = = = = = = = = = = = = = = = = =
def latent_anima(shape, frames, transit, key_latents=None, smooth=0.5, cubic=False, gauss=False, seed=None, verbose=True):
    """Build a looping sequence of latents that moves between key points.

    Args:
        shape: shape of a single latent.
        frames: requested number of output frames.
        transit: frames per transition between two key points. When key points
            are generated here, it is clamped to [1, frames // 4].
        key_latents: optional key points stacked along axis 0; random normal ones
            are drawn when None.
        smooth: easing passed to slerp in the non-cubic mode.
        cubic: use a closed cubic spline through the key points instead of slerp.
        gauss: additionally Gaussian-blur along the frame axis (wrapping) and
            rescale each latent to norm sqrt(prod(shape)) over its last axis.
        seed: seed for generating key points; the same seed gives the same sequence.
        verbose: print a one-line summary of the timeline.

    Returns:
        np.ndarray with one latent per frame along axis 0.
    """
    if key_latents is None:
        transit = int(max(1, min(frames // 4, transit)))
    steps = max(1, int(frames // transit))
    log = ' timeline: %d steps by %d' % (steps, transit)

    # Make key points. Use ONE RandomState for the whole sequence. Calling
    # get_z(shape, seed=seed) per key re-seeds every draw, which made all key
    # points identical (a static "animation") whenever seed was set.
    if key_latents is None:
        rnd = np.random.RandomState(seed)
        key_latents = np.array([rnd.randn(*shape) for _ in range(steps)])

    # Populate frames between key points.
    if transit == 1:
        latents = key_latents
    elif cubic:
        latents = cublerp(key_latents, steps, transit)
        log += ', cubic'
    else:
        segments = [np.expand_dims(key_latents[0], 0)]
        for i in range(steps):
            zA = key_latents[i]
            zB = key_latents[(i + 1) % steps]  # wrap to the first key so the sequence loops
            segments.append(slerp(zA, zB, transit, smooth=smooth))
        latents = np.concatenate(segments)  # single concatenate instead of one per step
    latents = np.array(latents)

    if gauss:
        # Blur along the frame axis only; gaussian_filter needs one sigma per array axis.
        sigma = [transit] + [0] * (latents.ndim - 1)
        lats_post = gaussian_filter(latents, sigma, mode="wrap")
        lats_post = (lats_post / np.linalg.norm(lats_post, axis=-1, keepdims=True)) * math.sqrt(np.prod(shape))
        log += ', gauss'
        latents = lats_post

    if verbose:
        print(log)
    if latents.shape[0] > frames:
        # Extra frame: drop the leading one (key point 0 also appears later in the sequence).
        latents = latents[1:]
    return latents
# = = = = = = = = = = = = = = = = = = = = = = = = = = =
def multimask(x, size, latmask=None, countHW=(1, 1), delta=0.):
    """Blend the leading batch items of x into one map using spatial masks.

    Modes:
      * ``max(countHW) > 1``: build ``Hx * Wx`` soft stripe masks (via peak_roll)
        across height and width. Weight ``x[:Hx*Wx]`` with them and sum over the batch.
      * otherwise, if ``latmask`` is given: weight ``x[:latmask.shape[0]]`` with
        latmask, resized to ``size`` when it does not match, and sum over the batch.
      * otherwise: return x unchanged.
    In the two blending modes, the single blended map is repeated back to the
    original batch size.

    Args:
        x: tensor indexed as [b, f, h, w].
        size: target (h, w) used to resize latmask.
        latmask: optional mask tensor, [b, h, w] or [b, 1, h, w].
        countHW: (Hx, Wx) = number of stripe masks along height and width.
            A tuple default avoids a shared mutable list; any 2-sequence works.
        delta: passed through to peak_roll to shape the stripe masks.

    Returns:
        tensor [b, f, h, w].
    """
    Hx, Wx = countHW
    bcount = x.shape[0]

    if max(countHW) > 1:
        W = x.shape[3]  # width
        H = x.shape[2]  # height
        if Wx > 1:
            stripe_mask = []
            for i in range(Wx):
                ch_mask = peak_roll(W, Wx, i, delta).unsqueeze(0).unsqueeze(0)  # [1,1,w] th
                ch_mask = ch_mask.repeat(1, H, 1)  # [1,h,w]
                stripe_mask.append(ch_mask)
            maskW = torch.cat(stripe_mask, 0).unsqueeze(1)  # [x,1,h,w]
        else:
            maskW = [1]  # neutral factor: no split along width
        if Hx > 1:
            stripe_mask = []
            for i in range(Hx):
                ch_mask = peak_roll(H, Hx, i, delta).unsqueeze(1).unsqueeze(0)  # [1,h,1] th
                ch_mask = ch_mask.repeat(1, 1, W)  # [1,h,w]
                stripe_mask.append(ch_mask)
            maskH = torch.cat(stripe_mask, 0).unsqueeze(1)  # [y,1,h,w]
        else:
            maskH = [1]  # neutral factor: no split along height

        # Combined mask k = i * Hx + j pairs width stripe i with height stripe j.
        mask = []
        for i in range(Wx):
            for j in range(Hx):
                mask.append(maskW[i] * maskH[j])
        mask = torch.cat(mask, 0).unsqueeze(1)  # [xy,1,h,w]
        mask = mask.to(x.device)
        x = torch.sum(x[:Hx * Wx] * mask, 0, keepdim=True)

    elif latmask is not None:
        if len(latmask.shape) < 4:
            latmask = latmask.unsqueeze(1)  # [b,1,h,w]
        lms = latmask.shape
        if list(lms[2:]) != list(size) and np.prod(lms) > 1:
            latmask = F.interpolate(latmask, size)  # , mode='nearest'
        latmask = latmask.type(x.dtype)
        x = torch.sum(x[:lms[0]] * latmask, 0, keepdim=True)
    else:
        return x

    x = x.repeat(bcount, 1, 1, 1)
    return x  # [b,f,h,w]
def peak_roll(width, count, num, delta):
    """Return a 1D soft peak mask of length ``width`` positioned at slot ``num`` of ``count``.

    The peak from ``peak(step, delta)`` (length 2 * step, step = width // count)
    is zero-padded or truncated to ``width``. It is then rolled so the peak
    sits over slot ``num``; slot 0 wraps around the end of the axis.
    """
    step = width // count
    bump = peak(step, delta)
    if width > step * 2:
        full_ax = torch.cat((bump, torch.zeros([width - step * 2])), 0)
    else:
        full_ax = bump[:width]
    # the shift for slot 0 must be positive
    shift = max(width - (step // 2), 0.) if num == 0 else step * num - (step // 2)
    return torch.roll(full_ax, shift, 0)  # [width,]
def peak(steps, delta):
    """Symmetric up-then-down ramp of length 2 * steps, clipped to [0, 1].

    The ramp runs from -delta to 1 + delta before clipping. A positive delta
    therefore gives flat 0/1 shoulders and a steeper slope.
    """
    ramp = torch.linspace(0. - delta, 1. + delta, steps)
    both_sides = torch.cat((ramp, ramp.flip(0)), 0)
    return torch.clamp(both_sides, 0., 1.)  # [steps*2,]
# = = = = = = = = = = = = = = = = = = = = = = = = = = =
def ups2d(x, factor=2):
    """Nearest-neighbour upsample of axes 2 and 3 by an integer factor.

    Expects x indexed as [b, c, h, w]. Each pixel becomes a factor x factor block.
    factor == 1 returns x itself.
    """
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return x
    shp = x.shape
    # Insert singleton axes after h and w, repeat along them, then fold back.
    blocks = x.reshape(-1, shp[1], shp[2], 1, shp[3], 1).repeat(1, 1, 1, factor, 1, factor)
    return blocks.reshape(-1, shp[1], shp[2] * factor, shp[3] * factor)
# Tiles an array around two points, allowing for pad lengths greater than the input length
# NB: with symm=True every second tile is mirrored, which can look wrong in GAN outputs
# adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3
def tile_pad(xt, padding, symm=True):
    """Pad the last two axes of xt by repeating (or mirroring) its content.

    Unlike reflect padding, pad sizes may exceed the input size.

    Args:
        xt: array/tensor whose last two axes are (h, w).
        padding: (left, right, top, bottom), the same order as F.pad; negative values crop.
        symm: exactly True -> mirrored tiles (edge pixels repeat at each fold);
            anything else -> plain periodic repeat.

    Returns:
        xt gathered on the padded (h + top + bottom, w + left + right) index grid.
    """
    h, w = xt.shape[-2:]
    left, right, top, bottom = padding

    def fold(idx, lo, hi, symm=True):
        # Map integer positions into [lo, hi); pixel-centre bounds (-0.5, n - 0.5) make the
        # result land on valid integer indices.
        span = hi - lo
        if symm is not True:  # repeating tiles
            wrapped = np.remainder(idx - lo, span) + lo
        else:  # triangular reflection
            period = 2 * span
            m = np.fmod(idx - lo, period)
            m = np.where(m < 0, m + period, m)
            wrapped = np.where(m >= span, period - m, m) + lo
        return np.array(wrapped, dtype=idx.dtype)

    cols = fold(np.arange(-left, w + right), -0.5, w - 0.5, symm)
    rows = fold(np.arange(-top, h + bottom), -0.5, h - 0.5, symm)
    grid_x, grid_y = np.meshgrid(cols, rows)
    return xt[..., grid_y, grid_x]
def pad_up_to(x, size, type='centr'):
    """Pad (or crop) the spatial axes of x to ``size`` = (h, w) using tile_pad.

    ``type`` (case-insensitive; the name shadows the builtin but is part of the API):
      contains 'side' -> all padding goes after the content (right/bottom);
                         otherwise it is split evenly around the content;
      contains 'symm' -> mirrored tiles instead of plain repeats.
    Returns x itself when it already has the requested size.
    """
    if list(x.shape[2:]) == list(size):
        return x
    mode = type.lower()
    cur_wh = x.shape[2:][::-1]
    pads = []
    # Iterate width first, then height: tile_pad expects (left, right, top, bottom).
    for i, target in enumerate(size[::-1]):
        extra = target - cur_wh[i]
        if 'side' in mode:
            pads = pads + [0, extra]
        else:  # centr
            before = extra // 2
            pads = pads + [before, extra - before]
    return tile_pad(x, pads, symm='symm' in mode)
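# scale_type: 'fit' (plain resize), or any combination of 'pad' (skip proportional rescaling), 'side' (pad after the content instead of centering) and 'symm' (mirrored instead of repeated tiles)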
def fix_size(x, size, scale_type='centr'):
    """Bring a [b, c, h, w] tensor to spatial ``size`` = (h, w).

    scale_type (case-insensitive):
      'fit'          -> nearest-neighbour resize straight to ``size`` (aspect not kept);
      contains 'pad' -> no resizing, only pad/crop via pad_up_to;
      otherwise      -> nearest resize so the smaller side matches min(size), then pad_up_to.
    Any 'side'/'symm' flags in scale_type are forwarded to pad_up_to. An exact
    2x upscale uses ups2d, and a matching size returns x itself.

    Raises:
        ValueError: if x is not 4-dimensional (a subclass of the former bare Exception).
    """
    if len(x.shape) != 4:
        raise ValueError(" Wrong data rank, shape:", x.shape)
    # torch.Size is a tuple, and tuple == list is always False, so compare as tuples
    # to make list sizes such as [256, 256] hit the fast paths too.
    target = tuple(size) if isinstance(size, (list, tuple)) else size
    if tuple(x.shape[2:]) == target:
        return x
    if (x.shape[2] * 2, x.shape[3] * 2) == target:
        return ups2d(x)

    stype = scale_type.lower()
    if stype == 'fit':
        return F.interpolate(x, size, mode='nearest')  # , align_corners=True
    if 'pad' not in stype:
        # proportional scale to the smaller side, then pad (or crop) the bigger side
        sh0 = x.shape[2:]
        upsc = np.min(size) / np.min(sh0)
        new_size = [int(sh0[i] * upsc) for i in [0, 1]]
        x = F.interpolate(x, new_size, mode='nearest')  # , align_corners=True

    return pad_up_to(x, size, scale_type)
# Make a list of intermediate (not necessarily power-of-2) layer sizes for upsampling to an arbitrary resolution
def hw_scales(size, base, n, keep_first_layers=None, verbose=False):
    """List per-layer (h, w) sizes for growing from a small start to ``size``.

    The start resolution is base / 2**n, scaled per axis by size / base. The
    first ``keep_first_layers`` entries double exactly at each layer. The
    remaining layers grow geometrically (floored) so that ``size`` follows.

    Returns:
        list of sizes. Entries from keep_first_layers are lists, computed entries
        are (h, w) tuples, and the last entry is ``size`` as passed in.
    """
    if isinstance(base, int):
        base = (base, base)
    res0 = [int(b * 2 ** (-n)) for b in base]
    res0[0] = int(res0[0] * size[0] // base[0])
    res0[1] = int(res0[1] * size[1] // base[1])

    sizes = []
    if base[0] != base[1] and verbose is True:
        print(' size', size, 'base', base, 'start_res', res0, 'n', n)

    for _ in range(keep_first_layers or 0):
        sizes.append(res0)
        res0 = [v * 2 for v in res0]
        n -= 1

    growth_h = (size[0] / res0[0]) ** (1 / n)
    growth_w = (size[1] / res0[1]) ** (1 / n)
    sizes.extend(
        (math.floor(res0[0] * growth_h ** i), math.floor(res0[1] * growth_w ** i))
        for i in range(n)
    )
    sizes.append(size)
    return sizes
def calc_res(shape):
    """Pick a power-of-2 base resolution for an (h, w) shape.

    Start from the largest power of two not exceeding the smaller side. Double
    it once (only while its log2 is below 10) if either of these holds:
      * the smaller side is not itself that power of two, or
      * the aspect ratio is >= 2.
    The doubling also requires both sides to be multiples of the current base / 2.
    """
    h, w = shape[0], shape[1]
    base = min(2 ** int(np.log2(h)), 2 ** int(np.log2(w)))

    def divisible(values, b):
        # True when every value * 4 / b is an integer
        return all(v * 2 ** (2 - int(np.log2(b))) % 1 == 0 for v in values)

    irregular = min(h, w) != base or max(*shape) / min(*shape) >= 2
    if irregular and np.log2(base) < 10 and divisible(shape, base * 2):
        base *= 2
    return base
def calc_init_res(shape, resolution=None):
    """Compute the initial (lowest-layer) spatial size for a target shape.

    ``shape`` may be (s,), (h, w) or a 3-tuple. A 3-tuple whose third entry is
    smaller than both of the first two is read as (h, w, c); otherwise it is
    assumed to be (c, h, w). If ``resolution`` is None it is derived with calc_res.

    Returns:
        (init_res, resolution, res_log2), where init_res = size * 4 / resolution per axis.
    """
    if len(shape) == 1:
        shape = [shape[0]] * 2 + [1]
    elif len(shape) == 2:
        shape = [*shape, 1]

    channels_last = shape[2] < min(*shape[:2])  # fewer colors than pixels
    size = shape[:2] if channels_last else shape[1:]
    if resolution is None:
        resolution = calc_res(size)
    res_log2 = int(np.log2(resolution))
    scale = 2 ** (2 - res_log2)
    init_res = [int(s * scale) for s in size]
    return init_res, resolution, res_log2
def basename(file):
    """Return the file name without its directory and last extension."""
    name = os.path.basename(file)
    stem, _ext = os.path.splitext(name)
    return stem
def file_list(path, ext=None, subdir=None):
    """Return the sorted paths of regular files in ``path``.

    Args:
        path: directory to list.
        ext: None for no filtering. A str keeps names ending with it
            (case-sensitive). A list/tuple/set keeps files whose extension
            matches one of the entries, case-insensitively, with or without a
            leading dot (e.g. ['jpg', '.PNG']). Any other type prints a warning
            and applies no filter.
        subdir: if True, recurse into subdirectories.
    """
    if subdir is True:
        files = [os.path.join(dp, f) for dp, _dn, fn in os.walk(path) for f in fn]
    else:
        files = [os.path.join(path, f) for f in os.listdir(path)]
    if ext is not None:
        if isinstance(ext, (list, tuple, set, frozenset)):
            # Normalise the wanted extensions to match the lower-cased, dot-less file extension.
            wanted = {e.lower().lstrip('.') for e in ext}
            files = [f for f in files if os.path.splitext(f.lower())[1][1:] in wanted]
        elif isinstance(ext, str):
            files = [f for f in files if f.endswith(ext)]
        else:
            print(' Unknown extension/type for file list!')
    return sorted(f for f in files if os.path.isfile(f))
def dir_list(in_dir):
    """Return the sorted paths of the immediate subdirectories of ``in_dir``."""
    candidates = (os.path.join(in_dir, entry) for entry in os.listdir(in_dir))
    return sorted(p for p in candidates if os.path.isdir(p))
def img_list(path, subdir=None):
    """Return the sorted paths of image files in ``path``.

    Recognised extensions (case-insensitive): jpg, jpeg, png, ppm, tif.
    With ``subdir=True`` the search recurses into subdirectories.
    """
    if subdir is True:
        found = [os.path.join(root, name) for root, _dirs, names in os.walk(path) for name in names]
    else:
        found = [os.path.join(path, name) for name in os.listdir(path)]
    image_exts = ('jpg', 'jpeg', 'png', 'ppm', 'tif')
    images = [p for p in found if os.path.splitext(p.lower())[1][1:] in image_exts]
    return sorted(p for p in images if os.path.isfile(p))
def img_read(path):
    """Read an image file and normalise its channel layout.

    Grayscale images (2D, or a single channel) are replicated into 3 identical
    channels, and 4-channel (RGBA) images lose their last (alpha) channel.
    Other channel counts are returned as read.
    """
    img = imread(path)
    is_gray = img.ndim == 2 or img.shape[2] == 1  # short-circuit: a 2D image has no shape[2]
    if is_gray:
        img = np.dstack([img] * 3)  # gray -> 3 identical channels
    if img.shape[2] == 4:
        img = img[..., :3]  # RGBA -> RGB
    return img