import copy
import os
import random
import urllib.request

import numpy as np
import torch
import torch.nn.functional as FF
import torch.optim
from torchvision import utils
from tqdm import tqdm

from stylegan2.model import Generator


class DownloadProgressBar(tqdm):
    """tqdm progress bar usable as a ``urllib.request.urlretrieve`` reporthook."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Advance the bar to ``b * bsize`` transferred bytes.

        Args:
            b: number of blocks transferred so far.
            bsize: size of each block in bytes.
            tsize: total size in bytes, when the server reports one.
        """
        if tsize is not None:
            self.total = tsize
        # urlretrieve reports a cumulative count; tqdm.update expects a delta.
        self.update(b * bsize - self.n)


def get_path(base_path):
    """Return the local path of checkpoint ``base_path``, downloading it if missing.

    Checkpoints are stored under ``checkpoints/``. Missing files are fetched
    from the ``aaronb/StyleGAN2`` repository on Hugging Face.
    """
    base_dir = 'checkpoints'
    save_path = os.path.join(base_dir, base_path)
    if not os.path.exists(save_path):
        url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
        print(f'{base_path} not found')
        print('Try to download from huggingface: ', url)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        download_url(url, save_path)
        print('Downloaded to ', save_path)
    return save_path


def download_url(url, output_path):
    """Download ``url`` to ``output_path``, showing a progress bar.

    The data is written to ``output_path + '.part'`` first and renamed into
    place only after the transfer completes. An interrupted download therefore
    never leaves a truncated file that ``get_path`` would later accept as a
    valid checkpoint.
    """
    tmp_path = output_path + '.part'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1,
                                 desc=url.split('/')[-1]) as t:
            urllib.request.urlretrieve(url, filename=tmp_path, reporthook=t.update_to)
        os.replace(tmp_path, output_path)
    finally:
        # Only left behind if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


class CustomGenerator(Generator):
    """StyleGAN2 generator with latent preparation split from synthesis.

    ``generate`` also returns an intermediate 256-px feature map, resized to
    the output resolution. ``drag_gan`` uses it for motion supervision and
    point tracking.
    """

    def prepare(
        self,
        styles,
        inject_index=None,
        truncation=1,
        truncation_latent=None,
        input_is_latent=False,
        noise=None,
        randomize_noise=True,
    ):
        """Turn input codes into a per-layer latent tensor plus a noise list.

        Args:
            styles: list of codes. They are passed through ``self.style``
                (presumably the parent's mapping network — confirm in
                ``stylegan2.model``) unless ``input_is_latent`` is true.
            inject_index: crossover layer when two styles are mixed. It is
                drawn at random when ``None``.
            truncation: truncation factor; values < 1 pull styles towards
                ``truncation_latent``.
            truncation_latent: latent used as the truncation centre.
            input_is_latent: skip ``self.style`` when the codes are already latents.
            noise: explicit per-layer noise list. When ``None``, each entry is
                ``None`` (``randomize_noise``) or the generator's stored
                ``noise_{i}`` buffers.
            randomize_noise: see ``noise``.

        Returns:
            ``(latent, noise)``: ``latent`` has one entry per layer along
            dim 1 (``self.n_latent`` entries).
        """
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)]

        if truncation < 1:
            styles = [
                truncation_latent + truncation * (style - truncation_latent)
                for style in styles
            ]

        if len(styles) < 2:
            inject_index = self.n_latent
            if styles[0].ndim < 3:
                # One code per sample: broadcast it to every layer.
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:
                # Already a per-layer latent.
                latent = styles[0]
        else:
            # Style mixing: styles[0] for layers [0, inject_index), styles[1] after.
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)
            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
            latent = torch.cat([latent, latent2], 1)

        return latent, noise

    def generate(
        self,
        latent,
        noise,
    ):
        """Synthesize images and return them with a resized feature map.

        Args:
            latent: per-layer latent as produced by ``prepare``.
            noise: per-layer noise list as produced by ``prepare``.

        Returns:
            ``(image, F)``. ``F`` is the activation of the block whose spatial
            width is 256, bilinearly resized to ``image``'s height and width.

        Raises:
            RuntimeError: if no block produces a 256-px-wide feature map
                (e.g. a generator with resolution below 256).
        """
        out = self.input(latent)
        out = self.conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        feature = None
        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            if out.shape[-1] == 256:
                feature = out
            i += 2

        if feature is None:
            raise RuntimeError(
                "generate() found no 256-px feature map; "
                "the generator resolution must be at least 256"
            )
        image = skip
        # align_corners=False is the implicit default; spelled out to avoid the warning.
        feature = FF.interpolate(feature, image.shape[-2:], mode='bilinear', align_corners=False)
        return image, feature


def stylegan2(
    size=1024,
    channel_multiplier=2,
    latent=512,
    n_mlp=8,
    ckpt='stylegan2-ffhq-config-f.pt'
):
    """Build a frozen, eval-mode ``CustomGenerator`` from the ``g_ema`` checkpoint weights."""
    g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts. The
    # weights are copied into g_ema's (CPU) parameters either way.
    checkpoint = torch.load(get_path(ckpt), map_location='cpu')
    g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
    g_ema.requires_grad_(False)
    g_ema.eval()
    return g_ema


def bilinear_interpolate_torch(im, y, x):
    """Bilinearly sample ``im`` at fractional pixel locations.

    Args:
        im: tensor of shape (B, C, H, W).
        y: float tensor of row coordinates (0-dim or any shape).
        x: float tensor of column coordinates, same shape as ``y``.

    Returns:
        Tensor of shape ``(B, C, *y.shape)``.

    Neighbour indices are clamped to ``[0, H-1]`` x ``[0, W-1]``. Locations at
    or beyond the border therefore read edge pixels, rather than wrapping
    around (negative indices) or raising IndexError.
    """
    x0 = torch.floor(x).long()
    x1 = x0 + 1
    y0 = torch.floor(y).long()
    y1 = y0 + 1

    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    height, width = im.shape[-2:]
    x0 = x0.clamp(0, width - 1)
    x1 = x1.clamp(0, width - 1)
    y0 = y0.clamp(0, height - 1)
    y1 = y1.clamp(0, height - 1)

    Ia = im[:, :, y0, x0]
    Ib = im[:, :, y1, x0]
    Ic = im[:, :, y0, x1]
    Id = im[:, :, y1, x1]

    return Ia * wa + Ib * wb + Ic * wc + Id * wd


def _square_offsets(radius, device):
    """(row, col) offsets covering the closed square [-radius, radius]^2, shape (K, 2), row-major."""
    r = torch.arange(-radius, radius + 1, dtype=torch.float32, device=device)
    return torch.cartesian_prod(r, r)


def _motion_supervision_loss(feat, handle_points, target_points, offsets):
    """Motion-supervision loss summed over handles, or None if every handle is on its target.

    For each handle, the features at the patch points ``q`` are detached. The
    features at ``q + d`` are pulled towards them, where ``d`` is the unit
    vector from the handle to its target.
    """
    terms = []
    for p, t in zip(handle_points, target_points):
        p = torch.as_tensor(p, dtype=torch.float32, device=feat.device)
        t = torch.as_tensor(t, dtype=torch.float32, device=feat.device)
        delta = t - p
        dist = torch.norm(delta)
        if dist == 0:
            continue  # already at the target: no direction to move in
        step = delta / dist
        rows = int(p[0]) + offsets[:, 0]
        cols = int(p[1]) + offsets[:, 1]
        f1 = bilinear_interpolate_torch(feat, rows, cols).detach()
        f2 = bilinear_interpolate_torch(feat, rows + step[0], cols + step[1])
        # Same as summing FF.l1_loss(f2_q, f1_q) over each patch point q.
        terms.append((f2 - f1).abs().mean(dim=(0, 1)).sum())
    if not terms:
        return None
    return torch.stack(terms).sum()


def _track_point(feat, f0, point, offsets):
    """Return the in-image (row, col) near ``point`` whose feature is L1-closest to ``f0``."""
    height, width = feat.shape[-2:]
    rows = (int(point[0]) + offsets[:, 0]).clamp(0, height - 1)
    cols = (int(point[1]) + offsets[:, 1]).clamp(0, width - 1)
    f = bilinear_interpolate_torch(feat, rows, cols)  # (B, C, K)
    dist = (f - f0.unsqueeze(-1)).abs().sum(dim=(0, 1))
    best = int(torch.argmin(dist))  # first minimum, i.e. row-major tie-breaking
    return int(rows[best]), int(cols[best])


def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask,
             max_iters=1000, r1=3, r2=12, lr=2e-3):
    """Drag handle points towards targets by optimising the latent; yields after each step.

    Args:
        g_ema: ``CustomGenerator`` used for synthesis.
        latent: per-layer latent indexed as ``latent[:, layer, :]``. Only the
            first 6 layer entries are optimised; the rest are frozen.
        noise: per-layer noise list passed to ``g_ema.generate``.
        F: feature map for the starting latent, used as the tracking reference.
        handle_points: sequence of mutable ``(row, col)`` points. They are
            updated IN PLACE by point tracking after every step.
        target_points: sequence of ``(row, col)`` targets, one per handle.
        mask: currently unused. NOTE(review): the original masked term
            ``((F - F0) * (1 - mask)).abs().mean() * lam`` is disabled.
        max_iters: number of optimisation steps (and yields).
        r1: half-size of the motion-supervision patch.
        r2: half-size of the point-tracking search window.
        lr: Adam learning rate.

    Yields:
        ``(image, latent, features)`` after each step, all from the updated latent.
    """
    handle_points0 = copy.deepcopy(handle_points)
    device = F.device
    F0 = F.detach().clone()
    offsets_r1 = _square_offsets(r1, device)
    offsets_r2 = _square_offsets(r2, device)

    # Tracking always matches against the handles' features in the initial image.
    initial_features = []
    for p0 in handle_points0:
        p0 = torch.as_tensor(p0, dtype=torch.float32, device=device)
        initial_features.append(bilinear_interpolate_torch(F0, p0[0], p0[1]))

    latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
    latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
    optimizer = torch.optim.Adam([latent_trainable], lr=lr)

    for step in range(max_iters):
        # Motion supervision.
        optimizer.zero_grad()
        latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
        _, feat = g_ema.generate(latent, noise)
        loss = _motion_supervision_loss(feat, handle_points, target_points, offsets_r1)
        if loss is not None:
            loss.backward()
            optimizer.step()
        loss_value = 0.0 if loss is None else loss.item()

        # Point tracking on the image generated from the *updated* latent.
        with torch.no_grad():
            latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
            sample2, F2 = g_ema.generate(latent, noise)
            for point, f0 in zip(handle_points, initial_features):
                row, col = _track_point(F2, f0, point, offsets_r2)
                point[0] = row
                point[1] = col

        print(step, loss_value, handle_points, target_points)
        # `value_range` replaces the `range` kwarg that current torchvision rejects.
        utils.save_image(sample2, "test2.png", normalize=True, value_range=(-1, 1))
        yield sample2, latent, F2