import os
import sys
import time
import math

import numpy as np
import scipy
from scipy.ndimage import gaussian_filter
from scipy.interpolate import CubicSpline as CubSpline
from scipy.special import comb

try:
    from imageio import imread
except ImportError:  # imageio is only needed by img_read(); keep the rest importable
    imread = None

import torch
import torch.nn.functional as F

# from perlin import PerlinNoiseFactory as Perlin
# noise = Perlin(1)

# def latent_noise(t, dim, noise_step=78564.543):
#     latent = np.zeros((1, dim))
#     for i in range(dim):
#         latent[0][i] = noise(t + i * noise_step)
#     return latent


def load_latents(npy_file):
    """Load key latents from a .npy/.npz file, optionally reordered by an index file.

    If the loaded object is an archive (it exposes ``.files``), its first
    entry is used. If a text file with the same base name and a ``.txt``
    extension exists, its first line is parsed as comma-separated integer
    indices and only those latents are returned, in that order.

    Returns:
        np.ndarray with the selected latents.
    """
    loaded = np.load(npy_file)
    if hasattr(loaded, 'files'):
        # .npz archive: take the first stored array and close the file handle
        with loaded as archive:
            key_latents = archive[archive.files[0]]
    else:
        key_latents = loaded

    idx_file = os.path.splitext(npy_file)[0] + '.txt'
    if os.path.exists(idx_file):
        with open(idx_file) as f:
            first_line = f.readline()
        # strip every token, so a trailing newline never hides the last index
        lat_idx = [int(tok) for tok in (t.strip() for t in first_line.split(',')) if tok]
        key_latents = [key_latents[i] for i in lat_idx]
    return np.asarray(key_latents)

# = = = = = = = = = = = = = = = = = = = = = = = = = = =


def get_z(shape, seed=None, uniform=False):
    """Draw a random array of the given shape.

    Args:
        shape: tuple/list of dimensions.
        seed: seed for a private ``RandomState``; ``None`` draws fresh entropy.
        uniform: if True sample uniformly from [0, 1), otherwise from a
            standard normal distribution.

    The global NumPy random state is not modified.
    """
    rnd = np.random.RandomState(seed)
    if uniform:
        return rnd.uniform(0., 1., shape)
    return rnd.randn(*shape)  # *shape unpacks tuple/list to a sequence of ints


def smoothstep(x, NN=1., xmin=0., xmax=1.):
    """Polynomial smoothstep of order ceil(NN), evaluated on x clipped to [xmin, xmax].

    For a non-integer ``NN`` the result is the average of the clipped linear
    ramp and the smoothstep of order ceil(NN).
    """
    N = math.ceil(NN)
    x = np.clip((x - xmin) / (xmax - xmin), 0, 1)
    result = 0
    for n in range(0, N + 1):
        result += comb(N + n, n) * comb(2 * N + 1, N - n) * (-x) ** n
    result *= x ** (N + 1)
    if NN != N:
        result = (x + result) / 2
    return result


def _interp_positions(num_steps, smooth=0.):
    """Return num_steps positions from 0 to 1 inclusive, optionally eased by smoothstep."""
    if num_steps == 1:
        xs = [0.]  # a single step has no span to divide; stay at the start point
    else:
        xs = [step / (num_steps - 1) for step in range(num_steps)]
    if smooth > 0:
        xs = [smoothstep(x, smooth) for x in xs]
    return xs


def lerp(z1, z2, num_steps, smooth=0.):
    """Linearly interpolate from z1 to z2 in num_steps steps (both ends included).

    Returns:
        np.ndarray stacking the interpolated vectors along a new first axis.
    """
    return np.array([z1 + (z2 - z1) * x for x in _interp_positions(num_steps, smooth)])


# interpolate on hypersphere
def slerp(z1, z2, num_steps, smooth=0.):
    """Interpolate from z1 to z2, rescaling each step towards the norm of z1.

    Each step is the plain linear blend of z1 and z2, scaled by
    ``|z1| / |blend of z1 and norm-matched z2|``.

    Returns:
        np.ndarray stacking the interpolated vectors along a new first axis.
    """
    z1_norm = np.linalg.norm(z1)
    z2_norm = np.linalg.norm(z2)
    z2_normal = z2 * (z1_norm / z2_norm)  # z2 rescaled to the length of z1
    vectors = []
    for x in _interp_positions(num_steps, smooth):
        interplain = z1 + (z2 - z1) * x
        interp = z1 + (z2_normal - z1) * x
        interp_norm = np.linalg.norm(interp)
        interpol_normal = interplain * (z1_norm / interp_norm)
        # interpol_normal = interp * (z1_norm / interp_norm)
        vectors.append(interpol_normal)
    return np.array(vectors)


def cublerp(points, steps, fstep):
    """Cubic-spline interpolation through ``points``, closing back on the first point.

    Knots are placed every ``fstep`` frames; the spline is evaluated at the
    integer positions 0 .. steps*fstep inclusive.
    """
    keys = np.array([i * fstep for i in range(steps)] + [steps * fstep])
    points = np.concatenate((points, np.expand_dims(points[0], 0)))
    cspline = CubSpline(keys, points)
    return cspline(range(steps * fstep + 1))

# = = = = = = = = = = = = = = = = = = = = = = = = = = =


def latent_anima(shape, frames, transit, key_latents=None, smooth=0.5, cubic=False, gauss=False, seed=None, verbose=True):
    """Build a looping sequence of latents that travels through key latents.

    Args:
        shape: shape of one latent (passed to the random generator).
        frames: requested number of frames.
        transit: frames per transition between two key latents; clamped to
            [1, frames // 4] when key latents are generated here.
        key_latents: optional precomputed key latents; random ones are drawn
            when omitted.
        smooth: smoothstep order used to ease each transition (0 = linear).
        cubic: use cubic-spline interpolation instead of slerp.
        gauss: additionally blur the sequence along the frame axis and
            renormalize it (expects a 3-dimensional latent array).
        seed: seed for generated key latents; with a seed, keys are distinct
            draws from one reproducible stream.
        verbose: print a short summary line.

    Returns:
        np.ndarray of latents, frames stacked along the first axis.
    """
    if key_latents is None:
        transit = int(max(1, min(frames // 4, transit)))
    steps = max(1, int(frames // transit))
    log = ' timeline: %d steps by %d' % (steps, transit)

    # make key points
    if key_latents is None:
        if seed is None:
            key_latents = np.array([get_z(shape) for _ in range(steps)])
        else:
            # one seeded stream, so keys differ from each other yet stay reproducible;
            # the first key equals get_z(shape, seed=seed)
            rnd = np.random.RandomState(seed)
            key_latents = np.array([rnd.randn(*shape) for _ in range(steps)])

    latents = np.expand_dims(key_latents[0], 0)

    # populate lerp between key points
    if transit == 1:
        latents = key_latents
    elif cubic:
        latents = cublerp(key_latents, steps, transit)
        log += ', cubic'
    else:
        for i in range(steps):
            zA = key_latents[i]
            zB = key_latents[(i + 1) % steps]  # wrap around to close the loop
            interps_z = slerp(zA, zB, transit, smooth=smooth)
            latents = np.concatenate((latents, interps_z))
    latents = np.array(latents)

    if gauss:
        lats_post = gaussian_filter(latents, [transit, 0, 0], mode="wrap")
        lats_post = (lats_post / np.linalg.norm(lats_post, axis=-1, keepdims=True)) * math.sqrt(np.prod(shape))
        log += ', gauss'
        latents = lats_post

    if verbose:
        print(log)
    if latents.shape[0] > frames:  # extra frame
        latents = latents[1:]
    return latents

# = = = = = = = = = = = = = = = = = = = = = = = = = = =


def _stripe_masks(length, count, delta, along_width, other_len):
    """Stack ``count`` rolled peak profiles, each expanded to a [1,h,w] plane."""
    planes = []
    for i in range(count):
        profile = peak_roll(length, count, i, delta)
        if along_width:
            plane = profile.unsqueeze(0).unsqueeze(0).repeat(1, other_len, 1)  # [1,1,w] -> [1,h,w]
        else:
            plane = profile.unsqueeze(1).unsqueeze(0).repeat(1, 1, other_len)  # [1,h,1] -> [1,h,w]
        planes.append(plane)
    return torch.cat(planes, 0).unsqueeze(1)  # [count,1,h,w]


def multimask(x, size, latmask=None, countHW=(1, 1), delta=0.):
    """Blend the leading batch entries of x into one, by a grid mask or a given mask.

    Args:
        x: 4-dimensional tensor, indexed here as [b,f,h,w].
        size: target (h, w) that ``latmask`` is resized to when it differs.
        latmask: optional mask tensor; a 3-dimensional mask gets a channel
            axis inserted at position 1.
        countHW: (rows, columns) of the blending grid; a grid is used when
            either count exceeds 1.
        delta: overshoot of the peak profile (see ``peak``).

    Returns:
        x unchanged when neither a grid nor a mask applies; otherwise the
        masked sum repeated to the original batch size.
    """
    Hx, Wx = countHW
    bcount = x.shape[0]

    if max(countHW) > 1:
        W = x.shape[3]  # width
        H = x.shape[2]  # height
        maskW = _stripe_masks(W, Wx, delta, True, H) if Wx > 1 else [1]   # [x,1,h,w]
        maskH = _stripe_masks(H, Hx, delta, False, W) if Hx > 1 else [1]  # [y,1,h,w]

        mask = [maskW[i] * maskH[j] for i in range(Wx) for j in range(Hx)]
        mask = torch.cat(mask, 0).unsqueeze(1)  # [xy,1,h,w]
        mask = mask.to(x.device)
        x = torch.sum(x[:Hx * Wx] * mask, 0, keepdim=True)

    elif latmask is not None:
        if len(latmask.shape) < 4:
            latmask = latmask.unsqueeze(1)  # [b,1,h,w]
        lms = latmask.shape
        if list(lms[2:]) != list(size) and np.prod(lms) > 1:
            latmask = F.interpolate(latmask, size)  # , mode='nearest'
        latmask = latmask.type(x.dtype)
        x = torch.sum(x[:lms[0]] * latmask, 0, keepdim=True)

    else:
        return x

    x = x.repeat(bcount, 1, 1, 1)
    return x  # [b,f,h,w]


def peak_roll(width, count, num, delta):
    """Return a 1-D peak profile of length ``width``, rolled to slot ``num`` of ``count``."""
    step = width // count
    if width > step * 2:
        fill_range = torch.zeros([width - step * 2])
        full_ax = torch.cat((peak(step, delta), fill_range), 0)
    else:
        full_ax = peak(step, delta)[:width]
    if num == 0:
        shift = max(width - (step // 2), 0)  # must be a non-negative int
    else:
        shift = step * num - (step // 2)
    full_ax = torch.roll(full_ax, shift, 0)
    return full_ax  # [width,]


def peak(steps, delta):
    """Return a symmetric ramp up and down, clipped to [0, 1].

    The ramp runs from ``-delta`` to ``1 + delta`` before clipping, so a
    positive ``delta`` flattens both the base and the top.
    """
    x = torch.linspace(0. - delta, 1. + delta, steps)
    x_rev = torch.flip(x, [0])
    x = torch.cat((x, x_rev), 0)
    x = torch.clip(x, 0., 1.)
    return x  # [steps*2,]

# = = = = = = = = = = = = = = = = = = = = = = = = = = =


def ups2d(x, factor=2):
    """Upsample the last two axes of a 4-D tensor by repeating each element ``factor`` times."""
    assert isinstance(factor, int) and factor >= 1
    if factor == 1:
        return x
    s = x.shape
    x = x.reshape(-1, s[1], s[2], 1, s[3], 1)
    x = x.repeat(1, 1, 1, factor, 1, factor)
    x = x.reshape(-1, s[1], s[2] * factor, s[3] * factor)
    return x


# Tiles an array around two points, allowing for pad lengths greater than the input length
# NB: if symm=True, every second tile is mirrored = messed up in GAN
# adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3
def tile_pad(xt, padding, symm=True):
    """Pad the last two axes of xt by tiling (or mirrored tiling) its content.

    Args:
        xt: array/tensor indexed as [..., h, w].
        padding: (left, right, top, bottom) pad widths.
        symm: True mirrors every second tile; False repeats tiles as-is.
    """
    h, w = xt.shape[-2:]
    left, right, top, bottom = padding

    def tile(x, minx, maxx, symm=True):
        # map arbitrary indices back into [minx, maxx)
        rng = maxx - minx
        if symm is True:  # triangular reflection
            double_rng = 2 * rng
            mod = np.fmod(x - minx, double_rng)
            normed_mod = np.where(mod < 0, mod + double_rng, mod)
            out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx
        else:  # repeating tiles
            mod = np.remainder(x - minx, rng)
            out = mod + minx
        return np.array(out, dtype=x.dtype)

    x_idx = np.arange(-left, w + right)
    y_idx = np.arange(-top, h + bottom)
    x_pad = tile(x_idx, -0.5, w - 0.5, symm)
    y_pad = tile(y_idx, -0.5, h - 0.5, symm)
    xx, yy = np.meshgrid(x_pad, y_pad)
    return xt[..., yy, xx]


def pad_up_to(x, size, type='centr'):
    """Pad the last two axes of x up to ``size`` by tiling.

    ``type`` may contain 'side' (pad only after the content instead of
    centering) and/or 'symm' (mirrored tiles).
    """
    sh = x.shape[2:][::-1]
    if list(x.shape[2:]) == list(size):
        return x
    padding = []
    for i, s in enumerate(size[::-1]):  # width first, then height
        if 'side' in type.lower():
            padding = padding + [0, s - sh[i]]
        else:  # centr
            p0 = (s - sh[i]) // 2
            p1 = s - sh[i] - p0
            padding = padding + [p0, p1]
    y = tile_pad(x, padding, symm='symm' in type.lower())
    # if 'symm' in type.lower():
    #     y = tile_pad(x, padding, symm=True)
    # else:
    #     y = F.pad(x, padding, 'circular')
    return y


# scale_type may include pad, side, symm
def fix_size(x, size, scale_type='centr'):
    """Bring a 4-D tensor to spatial ``size`` by scaling and/or tile-padding.

    Args:
        x: tensor of rank 4; the last two axes are resized.
        size: target (h, w), as a list or tuple.
        scale_type: 'fit' stretches to size; a value containing 'pad' only
            pads; anything else scales proportionally by the ratio of the
            smaller sides and then pads.

    Raises:
        Exception: if x is not of rank 4.
    """
    if not len(x.shape) == 4:
        raise Exception(" Wrong data rank, shape:", x.shape)
    # compare as tuples so a list-typed size also hits the shortcuts
    target = tuple(size)
    if tuple(x.shape[2:]) == target:
        return x
    if (x.shape[2] * 2, x.shape[3] * 2) == target:
        return ups2d(x)

    if scale_type.lower() == 'fit':
        return F.interpolate(x, size, mode='nearest')  # , align_corners=True
    elif 'pad' in scale_type.lower():
        pass
    else:  # proportional scale to smaller side, then pad to bigger side
        sh0 = x.shape[2:]
        upsc = np.min(size) / np.min(sh0)
        new_size = [int(sh0[i] * upsc) for i in [0, 1]]
        x = F.interpolate(x, new_size, mode='nearest')  # , align_corners=True

    x = pad_up_to(x, size, scale_type)
    return x


# Make list of odd sizes for upsampling to arbitrary resolution
def hw_scales(size, base, n, keep_first_layers=None, verbose=False):
    """List per-layer (h, w) sizes growing from a start resolution up to ``size``.

    The first ``keep_first_layers`` entries double each time; the remaining
    ones grow geometrically so that the final entry is ``size``.
    """
    if isinstance(base, int):
        base = (base, base)
    start_res = [int(b * 2 ** (-n)) for b in base]

    start_res[0] = int(start_res[0] * size[0] // base[0])
    start_res[1] = int(start_res[1] * size[1] // base[1])

    hw_list = []

    if base[0] != base[1] and verbose is True:
        print(' size', size, 'base', base, 'start_res', start_res, 'n', n)
    if keep_first_layers is not None and keep_first_layers > 0:
        for i in range(keep_first_layers):
            hw_list.append(start_res)
            start_res = [x * 2 for x in start_res]
            n -= 1

    if n > 0:  # nothing left to interpolate when all layers were kept
        ch = (size[0] / start_res[0]) ** (1 / n)
        cw = (size[1] / start_res[1]) ** (1 / n)
        for i in range(n):
            h = math.floor(start_res[0] * ch ** i)
            w = math.floor(start_res[1] * cw ** i)
            hw_list.append((h, w))

    hw_list.append(size)
    return hw_list


def calc_res(shape):
    """Pick a power-of-two base resolution for a (h, w) shape."""
    base0 = 2 ** int(np.log2(shape[0]))
    base1 = 2 ** int(np.log2(shape[1]))
    base = min(base0, base1)
    min_res = min(shape[0], shape[1])

    def int_log2(xs, base):
        # True for each side that scales to a whole number at this base
        return [x * 2 ** (2 - int(np.log2(base))) % 1 == 0 for x in xs]

    if min_res != base or max(*shape) / min(*shape) >= 2:
        if np.log2(base) < 10 and all(int_log2(shape, base * 2)):
            base = base * 2

    return base  # , [shape[0]/base, shape[1]/base]


def calc_init_res(shape, resolution=None):
    """Compute the initial (coarsest) resolution for a data shape.

    Args:
        shape: 1, 2 or 3 dimensions; a single value is treated as a square,
            and a missing channel count defaults to 1.
        resolution: base resolution; derived with ``calc_res`` when omitted.

    Returns:
        (init_res, resolution, res_log2)
    """
    if len(shape) == 1:
        shape = [shape[0], shape[0], 1]
    elif len(shape) == 2:
        shape = [*shape, 1]
    size = shape[:2] if shape[2] < min(*shape[:2]) else shape[1:]  # fewer colors than pixels
    if resolution is None:
        resolution = calc_res(size)
    res_log2 = int(np.log2(resolution))
    init_res = [int(s * 2 ** (2 - res_log2)) for s in size]
    return init_res, resolution, res_log2


def basename(file):
    """Return the file name without directory and extension."""
    return os.path.splitext(os.path.basename(file))[0]


def _list_paths(path, subdir=None):
    """List entries of ``path``; walk the whole tree only when subdir is True."""
    if subdir is True:
        return [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn]
    return [os.path.join(path, f) for f in os.listdir(path)]


def file_list(path, ext=None, subdir=None):
    """Return the sorted files in ``path``, optionally filtered by extension.

    Args:
        path: directory to list.
        ext: a list/tuple/set of lowercase extensions without the dot
            (matched case-insensitively), or a string suffix matched with
            ``str.endswith``.
        subdir: True to include files from subdirectories.
    """
    files = _list_paths(path, subdir)
    if ext is not None:
        if isinstance(ext, (list, tuple, set)):
            files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext]
        elif isinstance(ext, str):
            files = [f for f in files if f.endswith(ext)]
        else:
            print(' Unknown extension/type for file list!')
    return sorted([f for f in files if os.path.isfile(f)])


def dir_list(in_dir):
    """Return the sorted subdirectories of ``in_dir``."""
    dirs = [os.path.join(in_dir, x) for x in os.listdir(in_dir)]
    return sorted([f for f in dirs if os.path.isdir(f)])


def img_list(path, subdir=None):
    """Return the sorted image files (jpg, jpeg, png, ppm, tif) in ``path``."""
    return file_list(path, ['jpg', 'jpeg', 'png', 'ppm', 'tif'], subdir)


def img_read(path):
    """Read an image and return it with three channels.

    Raises:
        ImportError: if imageio is not installed.
    """
    if imread is None:
        raise ImportError('imageio is required for img_read')
    img = imread(path)
    # grayscale to rgb
    if (img.ndim == 2) or (img.shape[2] == 1):
        img = np.dstack((img, img, img))
    # rgba to rgb
    if img.shape[2] == 4:
        img = img[:, :, :3]
    return img