# DragGAN-unofficial / drag_gan.py
import copy
import os
import random
import urllib.request
import numpy as np
import torch
import torch.nn.functional as FF
import torch.optim
from torchvision import utils
from tqdm import tqdm
from stylegan2.model import Generator
class DownloadProgressBar(tqdm):
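    """tqdm progress bar that plugs into urllib.request.urlretrieve as a reporthook."""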
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
def get_path(base_path):
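    """Return the local path of a checkpoint, downloading it from Hugging Face if missing."""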
BASE_DIR = os.path.join('checkpoints')
save_path = os.path.join(BASE_DIR, base_path)
if not os.path.exists(save_path):
url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
print(f'{base_path} not found')
print('Try to download from huggingface: ', url)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
download_url(url, save_path)
print('Downloaded to ', save_path)
return save_path
def download_url(url, output_path):
with DownloadProgressBar(unit='B', unit_scale=True,
miniters=1, desc=url.split('/')[-1]) as t:
urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
class CustomGenerator(Generator):
def prepare(
self,
styles,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
noise=None,
randomize_noise=True,
):
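        """Map input codes to a (batch, n_latent, style_dim) W+ tensor plus a per-layer noise list.

        Mirrors the first half of the stock StyleGAN2 forward pass (style mapping,
        optional truncation, style mixing) but stops before synthesis so the
        resulting latent can be optimized externally.
        """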
if not input_is_latent:
styles = [self.style(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
]
if truncation < 1:
style_t = []
for style in styles:
style_t.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_t
if len(styles) < 2:
inject_index = self.n_latent
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
return latent, noise
def generate(
self,
latent,
noise,
):
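        """Synthesize an image from a prepared W+ latent and noise list.

        Returns the generated image together with an intermediate feature map F:
        the 256-resolution activations, bilinearly resized to the image size.
        """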
out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
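            # Cache the 256-resolution activations as the dense feature map F
            # used later for motion supervision and point tracking.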
if out.shape[-1] == 256: F = out
i += 2
image = skip
F = FF.interpolate(F, image.shape[-2:], mode='bilinear')
return image, F
def stylegan2(
size=1024,
channel_multiplier=2,
latent=512,
n_mlp=8,
ckpt='stylegan2-ffhq-config-f.pt'
):
g_ema = CustomGenerator(size, latent, n_mlp, channel_multiplier=channel_multiplier)
checkpoint = torch.load(get_path(ckpt))
g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
g_ema.requires_grad_(False)
g_ema.eval()
return g_ema
def bilinear_interpolate_torch(im, y, x):
"""
im : B,C,H,W
y : 1,numPoints -- pixel location y float
x : 1,numPOints -- pixel location y float
"""
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1
wa = (x1.float() - x) * (y1.float() - y)
wb = (x1.float() - x) * (y - y0.float())
wc = (x - x0.float()) * (y1.float() - y)
wd = (x - x0.float()) * (y - y0.float())
    # Instead of clamping: indices that land exactly on the far border are
    # mapped back to the last valid row/column.
x1 = x1 - torch.floor(x1 / im.shape[3]).int()
y1 = y1 - torch.floor(y1 / im.shape[2]).int()
Ia = im[:, :, y0, x0]
Ib = im[:, :, y1, x0]
Ic = im[:, :, y0, x1]
Id = im[:, :, y1, x1]
return Ia * wa + Ib * wb + Ic * wc + Id * wd
def drag_gan(g_ema, latent: torch.Tensor, noise, F, handle_points, target_points, mask, max_iters=1000):
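    """Iteratively drag handle_points toward target_points by optimizing the latent.

    Each outer iteration takes one motion-supervision gradient step on the first
    six W+ vectors, then re-locates every handle point via nearest-neighbor
    feature matching (point tracking). The mask argument is currently unused
    (its loss term is commented out). Yields (image, latent, feature map) after
    every iteration so callers can stream intermediate results.
    """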
handle_points0 = copy.deepcopy(handle_points)
n = len(handle_points)
    r1, r2, lam, d = 3, 12, 20, 1  # r1: motion-supervision radius, r2: tracking radius; lam and d are unused below
def neighbor(x, y, d):
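        """Return the 2d x 2d grid of float pixel coordinates around (x, y)."""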
points = []
for i in range(x - d, x + d):
for j in range(y - d, y + d):
points.append(torch.tensor([i, j]).float().cuda())
return points
F0 = F.detach().clone()
# latent = latent.detach().clone().requires_grad_(True)
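    # Only the first 6 W+ vectors are optimized (following DragGAN, which edits
    # the early layers that control spatial attributes); the rest stay frozen.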
latent_trainable = latent[:, :6, :].detach().clone().requires_grad_(True)
latent_untrainable = latent[:, 6:, :].detach().clone().requires_grad_(False)
optimizer = torch.optim.Adam([latent_trainable], lr=2e-3)
for iter in range(max_iters):
        for s in range(1):  # a single optimization step per outer iteration
optimizer.zero_grad()
latent = torch.cat([latent_trainable, latent_untrainable], dim=1)
sample2, F2 = g_ema.generate(latent, noise)
# motion supervision
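            # For a small patch around each handle point, pull the current features
            # F2(q + di) toward a detached copy of F2(q), where di points from the
            # handle toward its target. Gradients flow only through the moved
            # sample, so the latent is nudged to shift the patch content along di.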
loss = 0
for i in range(n):
pi, ti = handle_points[i], target_points[i]
di = (ti - pi) / torch.sum((ti - pi)**2)
for qi in neighbor(int(pi[0]), int(pi[1]), r1):
# f1 = F[..., int(qi[0]), int(qi[1])]
# f2 = F2[..., int(qi[0] + di[0]), int(qi[1] + di[1])]
f1 = bilinear_interpolate_torch(F2, qi[0], qi[1]).detach()
f2 = bilinear_interpolate_torch(F2, qi[0] + di[0], qi[1] + di[1])
loss += FF.l1_loss(f2, f1)
# loss += ((F-F0) * (1-mask)).abs().mean() * lam
loss.backward()
optimizer.step()
            # print(latent_trainable[0, 0, :10])  # debug: inspect the optimized latent slice
# if s % 10 ==0:
# utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
# point tracking
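        # After the latent update, re-locate each handle point: search a radius-r2
        # neighborhood of its current position for the pixel whose new features F2
        # are closest (L1 distance) to the original features F0 at the initial
        # handle location.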
with torch.no_grad():
sample2, F2 = g_ema.generate(latent, noise)
for i in range(n):
pi = handle_points0[i]
# f = F0[..., int(pi[0]), int(pi[1])]
f0 = bilinear_interpolate_torch(F0, pi[0], pi[1])
minv = 1e9
minx = 1e9
miny = 1e9
for qi in neighbor(int(handle_points[i][0]), int(handle_points[i][1]), r2):
# f2 = F2[..., int(qi[0]), int(qi[1])]
                    try:
                        f2 = bilinear_interpolate_torch(F2, qi[0], qi[1])
                    except IndexError:
                        # Candidate q falls outside the feature map; skip it
                        # instead of dropping into a debugger.
                        continue
v = torch.norm(f2 - f0, p=1)
if v < minv:
minv = v
minx = int(qi[0])
miny = int(qi[1])
handle_points[i][0] = minx
handle_points[i][1] = miny
F = F2.detach().clone()
if iter % 1 == 0:
print(iter, loss.item(), handle_points, target_points)
# p = handle_points[0].int()
# sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] = sample2[0, :, p[0] - 5:p[0] + 5, p[1] - 5:p[1] + 5] * 0
# t = target_points[0].int()
# sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] = sample2[0, :, t[0] - 5:t[0] + 5, t[1] - 5:t[1] + 5] * 255
# sample2[0, :, 210, 134] = sample2[0, :, 210, 134] * 0
utils.save_image(sample2, "test2.png", normalize=True, range=(-1, 1))
yield sample2, latent, F2
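
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative addition, not part of the upstream
# script). The handle/target coordinates, the iteration count, and the use of
# fixed per-layer noise are assumptions; a real driver (e.g. a Gradio app)
# would supply user-selected points and a mask.
if __name__ == "__main__":
    g_ema = stylegan2().cuda()
    z = torch.randn(1, 512).cuda()
    # Fixed noise keeps the feature maps comparable across optimization steps.
    latent, noise = g_ema.prepare([z], randomize_noise=False)
    image, F = g_ema.generate(latent, noise)

    handle = [torch.tensor([300.0, 500.0]).cuda()]  # (row, col) on the image grid
    target = [torch.tensor([250.0, 500.0]).cuda()]  # drag the handle 50 px upward

    for step, (image, latent, F) in enumerate(
        drag_gan(g_ema, latent, noise, F, handle, target, mask=None, max_iters=20)
    ):
        print(f"step {step}: handle now at {handle[0].tolist()}")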