|
|
|
|
|
|
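# Generate images from text prompts (and/or image prompts) using VQGAN+CLIP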
|
import argparse |
|
import math |
|
import random |
|
|
|
from urllib.request import urlopen |
|
from tqdm import tqdm |
|
import sys |
|
import os |
|
|
|
|
|
|
|
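# Make the local 'taming-transformers' checkout importable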
sys.path.append('taming-transformers') |
|
|
|
from omegaconf import OmegaConf |
|
from taming.models import cond_transformer, vqgan |
|
|
|
|
|
import torch |
|
from torch import nn, optim |
|
from torch.nn import functional as F |
|
from torchvision import transforms |
|
from torchvision.transforms import functional as TF |
|
from torch.cuda import get_device_properties |
|
torch.backends.cudnn.benchmark = False  # True can be faster, but may cause OOM and is less reproducible
|
|
|
|
|
from torch_optimizer import DiffGrad, AdamP |
|
|
|
from CLIP import clip |
|
import kornia.augmentation as K |
|
import numpy as np |
|
import imageio |
|
|
|
from PIL import ImageFile, Image, PngImagePlugin, ImageChops |
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
from subprocess import Popen, PIPE |
|
import re |
|
|
|
|
|
import warnings |
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
|
|
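# Reduce the default output size when there is no GPU (256) or when VRAM is 8 GiB or less (304)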
default_image_size = 512 |
|
if not torch.cuda.is_available(): |
|
default_image_size = 256 |
|
elif get_device_properties(0).total_memory <= 2 ** 33: |
|
default_image_size = 304 |
|
|
|
|
|
vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP') |
|
|
|
|
|
vq_parser.add_argument("-p", "--prompts", type=str, help="Text prompts", default=None, dest='prompts') |
|
vq_parser.add_argument("-ip", "--image_prompts", type=str, help="Image prompts / target image", default=[], dest='image_prompts') |
|
vq_parser.add_argument("-i", "--iterations", type=int, help="Number of iterations", default=500, dest='max_iterations') |
|
vq_parser.add_argument("-se", "--save_every", type=int, help="Save image every N iterations", default=50, dest='display_freq')
|
vq_parser.add_argument("-s", "--size", nargs=2, type=int, help="Image size (width height) (default: %(default)s)", default=[default_image_size,default_image_size], dest='size') |
|
vq_parser.add_argument("-ii", "--init_image", type=str, help="Initial image", default=None, dest='init_image') |
|
vq_parser.add_argument("-in", "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default=None, dest='init_noise') |
|
vq_parser.add_argument("-iw", "--init_weight", type=float, help="Initial weight", default=0., dest='init_weight') |
|
vq_parser.add_argument("-m", "--clip_model", type=str, help="CLIP model (e.g. ViT-B/32, ViT-B/16)", default='ViT-B/32', dest='clip_model') |
|
vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default='checkpoints/vqgan_imagenet_f16_16384.yaml', dest='vqgan_config')
|
vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default='checkpoints/vqgan_imagenet_f16_16384.ckpt', dest='vqgan_checkpoint')
|
vq_parser.add_argument("-nps", "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds') |
|
vq_parser.add_argument("-npw", "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights') |
|
vq_parser.add_argument("-lr", "--learning_rate", type=float, help="Learning rate", default=0.1, dest='step_size') |
|
vq_parser.add_argument("-cutm", "--cut_method", type=str, help="Cut method", choices=['original','updated','nrupdated','updatedpooling','latest'], default='latest', dest='cut_method') |
|
vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest='cutn') |
|
vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow') |
|
vq_parser.add_argument("-sd", "--seed", type=int, help="Seed", default=None, dest='seed') |
|
vq_parser.add_argument("-opt", "--optimiser", type=str, help="Optimiser", choices=['Adam','AdamW','Adagrad','Adamax','DiffGrad','AdamP','RAdam','RMSprop'], default='Adam', dest='optimiser') |
|
vq_parser.add_argument("-o", "--output", type=str, help="Output image filename", default="output.png", dest='output') |
|
vq_parser.add_argument("-vid", "--video", action='store_true', help="Create video frames?", dest='make_video') |
|
vq_parser.add_argument("-zvid", "--zoom_video", action='store_true', help="Create zoom video?", dest='make_zoom_video') |
|
vq_parser.add_argument("-zs", "--zoom_start", type=int, help="Zoom start iteration", default=0, dest='zoom_start') |
|
vq_parser.add_argument("-zse", "--zoom_save_every", type=int, help="Save zoom frame every N iterations", default=10, dest='zoom_frequency')
|
vq_parser.add_argument("-zsc", "--zoom_scale", type=float, help="Zoom scale %%", default=0.99, dest='zoom_scale') |
|
vq_parser.add_argument("-zsx", "--zoom_shift_x", type=int, help="Zoom shift x (left/right) amount in pixels", default=0, dest='zoom_shift_x') |
|
vq_parser.add_argument("-zsy", "--zoom_shift_y", type=int, help="Zoom shift y (up/down) amount in pixels", default=0, dest='zoom_shift_y') |
|
vq_parser.add_argument("-cpe", "--change_prompt_every", type=int, help="Prompt change frequency", default=0, dest='prompt_frequency') |
|
vq_parser.add_argument("-vl", "--video_length", type=float, help="Video length in seconds (not interpolated)", default=10, dest='video_length') |
|
vq_parser.add_argument("-ofps", "--output_video_fps", type=float, help="Create an interpolated video (Nvidia GPU only) with this fps (min 10; best set to 30 or 60)", default=0, dest='output_video_fps')
|
vq_parser.add_argument("-ifps", "--input_video_fps", type=float, help="When creating an interpolated video, use this as the input fps to interpolate from (>0 & <ofps)", default=15, dest='input_video_fps') |
|
vq_parser.add_argument("-d", "--deterministic", action='store_true', help="Enable cudnn.deterministic?", dest='cudnn_determinism') |
|
vq_parser.add_argument("-aug", "--augments", nargs='+', action='append', type=str, choices=['Ji','Sh','Gn','Pe','Ro','Af','Et','Ts','Cr','Er','Re'], help="Enabled augments (latest cut method only)", default=[], dest='augments')
|
vq_parser.add_argument("-vsd", "--video_style_dir", type=str, help="Directory with video frames to style", default=None, dest='video_style_dir') |
|
vq_parser.add_argument("-cd", "--cuda_device", type=str, help="Cuda device to use", default="cuda:0", dest='cuda_device') |
|
|
|
|
|
|
|
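# Example invocation (hypothetical script name, paths and prompt):
#   python generate.py -p "A painting of an apple in a fruit bowl" -i 500 -o apple.png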
args = vq_parser.parse_args() |
|
|
|
if not args.prompts and not args.image_prompts: |
|
args.prompts = "A cute, smiling, Nerdy Rodent" |
|
|
|
if args.cudnn_determinism: |
|
torch.backends.cudnn.deterministic = True |
|
|
|
if not args.augments: |
|
args.augments = [['Af', 'Pe', 'Ji', 'Er']] |
|
|
|
|
|
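# Split text prompts: '^' separates story phrases (rotated with -cpe), '|' separates prompts within a phrase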
if args.prompts: |
|
|
|
story_phrases = [phrase.strip() for phrase in args.prompts.split("^")] |
|
|
|
|
|
all_phrases = [] |
|
for phrase in story_phrases: |
|
all_phrases.append(phrase.split("|")) |
|
|
|
|
|
args.prompts = all_phrases[0] |
|
|
|
|
|
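# Split image prompts on the pipe character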
if args.image_prompts: |
|
args.image_prompts = args.image_prompts.split("|") |
|
args.image_prompts = [image.strip() for image in args.image_prompts] |
|
|
|
if args.make_video and args.make_zoom_video: |
|
print("Warning: Make video and make zoom video are mutually exclusive.") |
|
args.make_video = False |
|
|
|
|
|
if args.make_video or args.make_zoom_video: |
|
if not os.path.exists('steps'): |
|
os.mkdir('steps') |
|
|
|
|
|
|
|
# Fall back to the CPU if CUDA is unavailable
if args.cuda_device != 'cpu' and not torch.cuda.is_available():

    args.cuda_device = 'cpu'

    args.output_video_fps = 0  # interpolated video output needs an Nvidia GPU (h264_nvenc)

    print("Warning: No GPU found! Using the CPU instead. The iterations will be slow.")

    print("Perhaps CUDA/ROCm or the right pytorch version is not properly installed?")
|
|
|
|
|
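# If styling an existing video, gather its frames and initialise from the first one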
if args.video_style_dir: |
|
print("Locating video frames...") |
|
video_frame_list = [] |
|
for entry in os.scandir(args.video_style_dir): |
|
if (entry.path.endswith(".jpg") |
|
or entry.path.endswith(".png")) and entry.is_file(): |
|
video_frame_list.append(entry.path) |
|
|
|
|
|
if not os.path.exists('steps'): |
|
os.mkdir('steps') |
|
|
|
args.init_image = video_frame_list[0] |
|
filename = os.path.basename(args.init_image) |
|
cwd = os.getcwd() |
|
args.output = os.path.join(cwd, "steps", filename) |
|
num_video_frames = len(video_frame_list) |
|
|
|
|
|
|
|
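# Lanczos resampling helpers used by resample() below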
def sinc(x): |
|
return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])) |
|
|
|
|
|
def lanczos(x, a): |
|
cond = torch.logical_and(-a < x, x < a) |
|
out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([])) |
|
return out / out.sum() |
|
|
|
|
|
def ramp(ratio, width): |
|
n = math.ceil(width / ratio + 1) |
|
out = torch.empty([n]) |
|
cur = 0 |
|
for i in range(out.shape[0]): |
|
out[i] = cur |
|
cur += ratio |
|
return torch.cat([-out[1:].flip([0]), out])[1:-1] |
|
|
|
|
|
|
|
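# Crop around (x, y) and resize back to the original size; used for the zoom video effect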
def zoom_at(img, x, y, zoom): |
|
w, h = img.size |
|
zoom2 = zoom * 2 |
|
img = img.crop((x - w / zoom2, y - h / zoom2, |
|
x + w / zoom2, y + h / zoom2)) |
|
return img.resize((w, h), Image.LANCZOS) |
|
|
|
|
|
|
|
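# Initial-image generators: random pixel noise and random gradients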
def random_noise_image(w,h):

    # NumPy arrays are (rows, cols, channels), so build an (h, w, 3) array for a w x h image
    random_image = Image.fromarray(np.random.randint(0, 255, (h, w, 3), dtype=np.dtype('uint8')))

    return random_image
|
|
|
|
|
|
|
def gradient_2d(start, stop, width, height, is_horizontal): |
|
if is_horizontal: |
|
return np.tile(np.linspace(start, stop, width), (height, 1)) |
|
else: |
|
return np.tile(np.linspace(start, stop, height), (width, 1)).T |
|
|
|
|
|
def gradient_3d(width, height, start_list, stop_list, is_horizontal_list): |
|
result = np.zeros((height, width, len(start_list)), dtype=float) |
|
|
|
for i, (start, stop, is_horizontal) in enumerate(zip(start_list, stop_list, is_horizontal_list)): |
|
result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal) |
|
|
|
return result |
|
|
|
|
|
def random_gradient_image(w,h): |
|
array = gradient_3d(w, h, (0, 0, np.random.randint(0,255)), (np.random.randint(1,255), np.random.randint(2,255), np.random.randint(3,128)), (True, False, False)) |
|
random_image = Image.fromarray(np.uint8(array)) |
|
return random_image |
|
|
|
|
|
|
|
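# Lanczos-prefiltered downsampling followed by bicubic interpolation to the target size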
def resample(input, size, align_corners=True): |
|
n, c, h, w = input.shape |
|
dh, dw = size |
|
|
|
input = input.view([n * c, 1, h, w]) |
|
|
|
if dh < h: |
|
kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) |
|
pad_h = (kernel_h.shape[0] - 1) // 2 |
|
input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect') |
|
input = F.conv2d(input, kernel_h[None, None, :, None]) |
|
|
|
if dw < w: |
|
kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) |
|
pad_w = (kernel_w.shape[0] - 1) // 2 |
|
input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect') |
|
input = F.conv2d(input, kernel_w[None, None, None, :]) |
|
|
|
input = input.view([n, c, h, w]) |
|
return F.interpolate(input, size, mode='bicubic', align_corners=align_corners) |
|
|
|
|
|
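# Custom autograd helpers: ReplaceGrad routes gradients to a different tensor, ClampWithGrad clamps values while keeping useful gradients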
class ReplaceGrad(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, x_forward, x_backward): |
|
ctx.shape = x_backward.shape |
|
return x_forward |
|
|
|
@staticmethod |
|
def backward(ctx, grad_in): |
|
return None, grad_in.sum_to_size(ctx.shape) |
|
|
|
replace_grad = ReplaceGrad.apply |
|
|
|
|
|
class ClampWithGrad(torch.autograd.Function): |
|
@staticmethod |
|
def forward(ctx, input, min, max): |
|
ctx.min = min |
|
ctx.max = max |
|
ctx.save_for_backward(input) |
|
return input.clamp(min, max) |
|
|
|
@staticmethod |
|
def backward(ctx, grad_in): |
|
input, = ctx.saved_tensors |
|
return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None |
|
|
|
clamp_with_grad = ClampWithGrad.apply |
|
|
|
|
|
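# Snap each latent vector to its nearest codebook entry, passing gradients straight through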
def vector_quantize(x, codebook): |
|
d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T |
|
indices = d.argmin(-1) |
|
x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook |
|
return replace_grad(x_q, x) |
|
|
|
|
|
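# Prompt loss: squared spherical distance between normalised CLIP embeddings, weighted and clamped at 'stop'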
class Prompt(nn.Module): |
|
def __init__(self, embed, weight=1., stop=float('-inf')): |
|
super().__init__() |
|
self.register_buffer('embed', embed) |
|
self.register_buffer('weight', torch.as_tensor(weight)) |
|
self.register_buffer('stop', torch.as_tensor(stop)) |
|
|
|
def forward(self, input): |
|
input_normed = F.normalize(input.unsqueeze(1), dim=2) |
|
embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) |
|
dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) |
|
dists = dists * self.weight.sign() |
|
return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean() |
|
|
|
|
|
|
|
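# Parse 'text:weight:stop' prompts; missing fields default to weight 1 and stop -inf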
def split_prompt(prompt): |
|
vals = prompt.rsplit(':', 2) |
|
vals = vals + ['', '1', '-inf'][len(vals):] |
|
return vals[0], float(vals[1]), float(vals[2]) |
|
|
|
|
|
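# Cutout generators: each variant crops and augments patches of the image for CLIP to score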
class MakeCutouts(nn.Module): |
|
def __init__(self, cut_size, cutn, cut_pow=1.): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.cut_pow = cut_pow |
|
|
|
|
|
augment_list = [] |
|
for item in args.augments[0]: |
|
if item == 'Ji': |
|
augment_list.append(K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7)) |
|
elif item == 'Sh': |
|
augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5)) |
|
elif item == 'Gn': |
|
augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1., p=0.5)) |
|
elif item == 'Pe': |
|
augment_list.append(K.RandomPerspective(distortion_scale=0.7, p=0.7)) |
|
elif item == 'Ro': |
|
augment_list.append(K.RandomRotation(degrees=15, p=0.7)) |
|
elif item == 'Af': |
|
augment_list.append(K.RandomAffine(degrees=15, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True)) |
|
elif item == 'Et': |
|
augment_list.append(K.RandomElasticTransform(p=0.7)) |
|
elif item == 'Ts': |
|
augment_list.append(K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7)) |
|
elif item == 'Cr': |
|
augment_list.append(K.RandomCrop(size=(self.cut_size,self.cut_size), pad_if_needed=True, padding_mode='reflect', p=0.5)) |
|
elif item == 'Er': |
|
augment_list.append(K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7)) |
|
elif item == 'Re': |
|
augment_list.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1), ratio=(0.75,1.333), cropping_mode='resample', p=0.5)) |
|
|
|
self.augs = nn.Sequential(*augment_list) |
|
self.noise_fac = 0.1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) |
|
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) |
|
|
|
def forward(self, input): |
|
cutouts = [] |
|
|
|
for _ in range(self.cutn): |
|
|
|
cutout = (self.av_pool(input) + self.max_pool(input))/2 |
|
cutouts.append(cutout) |
|
|
|
batch = self.augs(torch.cat(cutouts, dim=0)) |
|
|
|
if self.noise_fac: |
|
facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) |
|
batch = batch + facs * torch.randn_like(batch) |
|
return batch |
|
|
|
|
|
|
|
class MakeCutoutsPoolingUpdate(nn.Module): |
|
def __init__(self, cut_size, cutn, cut_pow=1.): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.cut_pow = cut_pow |
|
|
|
self.augs = nn.Sequential( |
|
K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'), |
|
K.RandomPerspective(0.7,p=0.7), |
|
K.ColorJitter(hue=0.1, saturation=0.1, p=0.7), |
|
K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7), |
|
) |
|
|
|
self.noise_fac = 0.1 |
|
self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) |
|
self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) |
|
|
|
def forward(self, input): |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
min_size = min(sideX, sideY, self.cut_size) |
|
cutouts = [] |
|
|
|
for _ in range(self.cutn): |
|
cutout = (self.av_pool(input) + self.max_pool(input))/2 |
|
cutouts.append(cutout) |
|
|
|
batch = self.augs(torch.cat(cutouts, dim=0)) |
|
|
|
if self.noise_fac: |
|
facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) |
|
batch = batch + facs * torch.randn_like(batch) |
|
return batch |
|
|
|
|
|
|
|
class MakeCutoutsNRUpdate(nn.Module): |
|
def __init__(self, cut_size, cutn, cut_pow=1.): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.cut_pow = cut_pow |
|
self.noise_fac = 0.1 |
|
|
|
|
|
augment_list = [] |
|
for item in args.augments[0]: |
|
if item == 'Ji': |
|
augment_list.append(K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7)) |
|
elif item == 'Sh': |
|
augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5)) |
|
elif item == 'Gn': |
|
augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1., p=0.5)) |
|
elif item == 'Pe': |
|
augment_list.append(K.RandomPerspective(distortion_scale=0.5, p=0.7)) |
|
elif item == 'Ro': |
|
augment_list.append(K.RandomRotation(degrees=15, p=0.7)) |
|
elif item == 'Af': |
|
augment_list.append(K.RandomAffine(degrees=30, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True)) |
|
elif item == 'Et': |
|
augment_list.append(K.RandomElasticTransform(p=0.7)) |
|
elif item == 'Ts': |
|
augment_list.append(K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7)) |
|
elif item == 'Cr': |
|
augment_list.append(K.RandomCrop(size=(self.cut_size,self.cut_size), pad_if_needed=True, padding_mode='reflect', p=0.5)) |
|
elif item == 'Er': |
|
augment_list.append(K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7)) |
|
elif item == 'Re': |
|
augment_list.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1), ratio=(0.75,1.333), cropping_mode='resample', p=0.5)) |
|
|
|
self.augs = nn.Sequential(*augment_list) |
|
|
|
|
|
def forward(self, input): |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
min_size = min(sideX, sideY, self.cut_size) |
|
cutouts = [] |
|
for _ in range(self.cutn): |
|
size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) |
|
offsetx = torch.randint(0, sideX - size + 1, ()) |
|
offsety = torch.randint(0, sideY - size + 1, ()) |
|
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] |
|
cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) |
|
batch = self.augs(torch.cat(cutouts, dim=0)) |
|
if self.noise_fac: |
|
facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) |
|
batch = batch + facs * torch.randn_like(batch) |
|
return batch |
|
|
|
|
|
|
|
class MakeCutoutsUpdate(nn.Module): |
|
def __init__(self, cut_size, cutn, cut_pow=1.): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.cut_pow = cut_pow |
|
self.augs = nn.Sequential( |
|
K.RandomHorizontalFlip(p=0.5), |
|
K.ColorJitter(hue=0.01, saturation=0.01, p=0.7), |
|
|
|
K.RandomSharpness(0.3,p=0.4), |
|
K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'), |
|
K.RandomPerspective(0.2,p=0.4),) |
|
self.noise_fac = 0.1 |
|
|
|
|
|
def forward(self, input): |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
min_size = min(sideX, sideY, self.cut_size) |
|
cutouts = [] |
|
for _ in range(self.cutn): |
|
size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) |
|
offsetx = torch.randint(0, sideX - size + 1, ()) |
|
offsety = torch.randint(0, sideY - size + 1, ()) |
|
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] |
|
cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) |
|
batch = self.augs(torch.cat(cutouts, dim=0)) |
|
if self.noise_fac: |
|
facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) |
|
batch = batch + facs * torch.randn_like(batch) |
|
return batch |
|
|
|
|
|
|
|
class MakeCutoutsOrig(nn.Module): |
|
def __init__(self, cut_size, cutn, cut_pow=1.): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.cut_pow = cut_pow |
|
|
|
def forward(self, input): |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
min_size = min(sideX, sideY, self.cut_size) |
|
cutouts = [] |
|
for _ in range(self.cutn): |
|
size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size) |
|
offsetx = torch.randint(0, sideX - size + 1, ()) |
|
offsety = torch.randint(0, sideY - size + 1, ()) |
|
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size] |
|
cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) |
|
return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1) |
|
|
|
|
|
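# Load a VQGAN (standard, Gumbel, or the first stage of a Net2Net transformer) from a taming-transformers config and checkpoint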
def load_vqgan_model(config_path, checkpoint_path): |
|
global gumbel |
|
gumbel = False |
|
config = OmegaConf.load(config_path) |
|
if config.model.target == 'taming.models.vqgan.VQModel': |
|
model = vqgan.VQModel(**config.model.params) |
|
model.eval().requires_grad_(False) |
|
model.init_from_ckpt(checkpoint_path) |
|
elif config.model.target == 'taming.models.vqgan.GumbelVQ': |
|
model = vqgan.GumbelVQ(**config.model.params) |
|
model.eval().requires_grad_(False) |
|
model.init_from_ckpt(checkpoint_path) |
|
gumbel = True |
|
elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer': |
|
parent_model = cond_transformer.Net2NetTransformer(**config.model.params) |
|
parent_model.eval().requires_grad_(False) |
|
parent_model.init_from_ckpt(checkpoint_path) |
|
model = parent_model.first_stage_model |
|
else: |
|
raise ValueError(f'unknown model type: {config.model.target}') |
|
del model.loss |
|
return model |
|
|
|
|
|
def resize_image(image, out_size): |
|
ratio = image.size[0] / image.size[1] |
|
area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) |
|
size = round((area * ratio)**0.5), round((area / ratio)**0.5) |
|
return image.resize(size, Image.LANCZOS) |
|
|
|
|
|
|
|
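# Load the VQGAN model and the CLIP perceptor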
device = torch.device(args.cuda_device) |
|
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device) |
|
jit = "1.7.1" in torch.__version__  # CLIP's JIT model is only reliable on torch 1.7.1; use jit=False otherwise
|
perceptor = clip.load(args.clip_model, jit=jit)[0].eval().requires_grad_(False).to(device) |
|
|
|
|
|
|
|
|
|
|
|
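# cut_size is CLIP's input resolution; f is the VQGAN downsampling factor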
cut_size = perceptor.visual.input_resolution |
|
f = 2**(model.decoder.num_resolutions - 1) |
|
|
|
|
|
|
|
if args.cut_method == 'latest': |
|
make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow) |
|
elif args.cut_method == 'original': |
|
make_cutouts = MakeCutoutsOrig(cut_size, args.cutn, cut_pow=args.cut_pow) |
|
elif args.cut_method == 'updated': |
|
make_cutouts = MakeCutoutsUpdate(cut_size, args.cutn, cut_pow=args.cut_pow) |
|
elif args.cut_method == 'nrupdated': |
|
make_cutouts = MakeCutoutsNRUpdate(cut_size, args.cutn, cut_pow=args.cut_pow) |
|
else: |
|
make_cutouts = MakeCutoutsPoolingUpdate(cut_size, args.cutn, cut_pow=args.cut_pow) |
|
|
|
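# The latent is toksX x toksY codebook tokens, decoding to a sideX x sideY pixel image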
toksX, toksY = args.size[0] // f, args.size[1] // f |
|
sideX, sideY = toksX * f, toksY * f |
|
|
|
|
|
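# Codebook attribute names and sizes differ between the Gumbel and standard VQGAN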
if gumbel: |
|
e_dim = 256 |
|
n_toks = model.quantize.n_embed |
|
z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None] |
|
z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None] |
|
else: |
|
e_dim = model.quantize.e_dim |
|
n_toks = model.quantize.n_e |
|
z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] |
|
z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] |
|
|
|
|
|
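# Initialise the latent z from an image, random noise, a gradient, or random codebook tokens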
if args.init_image: |
|
if 'http' in args.init_image: |
|
img = Image.open(urlopen(args.init_image)) |
|
else: |
|
img = Image.open(args.init_image) |
|
pil_image = img.convert('RGB') |
|
pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) |
|
pil_tensor = TF.to_tensor(pil_image) |
|
z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1) |
|
elif args.init_noise == 'pixels': |
|
img = random_noise_image(args.size[0], args.size[1]) |
|
pil_image = img.convert('RGB') |
|
pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) |
|
pil_tensor = TF.to_tensor(pil_image) |
|
z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1) |
|
elif args.init_noise == 'gradient': |
|
img = random_gradient_image(args.size[0], args.size[1]) |
|
pil_image = img.convert('RGB') |
|
pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) |
|
pil_tensor = TF.to_tensor(pil_image) |
|
z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1) |
|
else: |
|
one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float() |
|
|
|
if gumbel: |
|
z = one_hot @ model.quantize.embed.weight |
|
else: |
|
z = one_hot @ model.quantize.embedding.weight |
|
|
|
z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) |
|
|
|
|
|
z_orig = z.clone() |
|
z.requires_grad_(True) |
|
|
|
pMs = [] |
|
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], |
|
std=[0.26862954, 0.26130258, 0.27577711]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
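# Encode the text, image and noise prompts with CLIP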
if args.prompts: |
|
for prompt in args.prompts: |
|
txt, weight, stop = split_prompt(prompt) |
|
embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() |
|
pMs.append(Prompt(embed, weight, stop).to(device)) |
|
|
|
for prompt in args.image_prompts: |
|
path, weight, stop = split_prompt(prompt) |
|
img = Image.open(path) |
|
pil_image = img.convert('RGB') |
|
img = resize_image(pil_image, (sideX, sideY)) |
|
batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device)) |
|
embed = perceptor.encode_image(normalize(batch)).float() |
|
pMs.append(Prompt(embed, weight, stop).to(device)) |
|
|
|
for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights): |
|
gen = torch.Generator().manual_seed(seed) |
|
embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen) |
|
pMs.append(Prompt(embed, weight).to(device)) |
|
|
|
|
|
|
|
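# Build the optimiser for the latent z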
def get_opt(opt_name, opt_lr): |
|
if opt_name == "Adam": |
|
opt = optim.Adam([z], lr=opt_lr) |
|
elif opt_name == "AdamW": |
|
opt = optim.AdamW([z], lr=opt_lr) |
|
elif opt_name == "Adagrad": |
|
opt = optim.Adagrad([z], lr=opt_lr) |
|
elif opt_name == "Adamax": |
|
opt = optim.Adamax([z], lr=opt_lr) |
|
elif opt_name == "DiffGrad": |
|
opt = DiffGrad([z], lr=opt_lr, eps=1e-9, weight_decay=1e-9) |
|
elif opt_name == "AdamP": |
|
opt = AdamP([z], lr=opt_lr) |
|
elif opt_name == "RAdam": |
|
opt = optim.RAdam([z], lr=opt_lr) |
|
elif opt_name == "RMSprop": |
|
opt = optim.RMSprop([z], lr=opt_lr) |
|
else: |
|
        print("Unknown optimiser, falling back to Adam")
|
opt = optim.Adam([z], lr=opt_lr) |
|
return opt |
|
|
|
opt = get_opt(args.optimiser, args.step_size) |
|
|
|
|
|
|
|
print('Using device:', device) |
|
print('Optimising using:', args.optimiser) |
|
|
|
if args.prompts: |
|
print('Using text prompts:', args.prompts) |
|
if args.image_prompts: |
|
print('Using image prompts:', args.image_prompts) |
|
if args.init_image: |
|
print('Using initial image:', args.init_image) |
|
if args.noise_prompt_weights: |
|
print('Noise prompt weights:', args.noise_prompt_weights) |
|
|
|
|
|
if args.seed is None: |
|
seed = torch.seed() |
|
else: |
|
seed = args.seed |
|
torch.manual_seed(seed) |
|
print('Using seed:', seed) |
|
|
|
|
|
|
|
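# Quantize z against the codebook and decode it to an RGB image in [0, 1]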
def synth(z): |
|
if gumbel: |
|
z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1) |
|
else: |
|
z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1) |
|
return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1) |
|
|
|
|
|
|
|
@torch.inference_mode() |
|
def checkin(i, losses): |
|
losses_str = ', '.join(f'{loss.item():g}' for loss in losses) |
|
tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}') |
|
out = synth(z) |
|
info = PngImagePlugin.PngInfo() |
|
info.add_text('comment', f'{args.prompts}') |
|
TF.to_pil_image(out[0].cpu()).save(args.output, pnginfo=info) |
|
|
|
|
|
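# Compute the losses for the current image: one per prompt, plus an optional init-image MSE term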
def ascend_txt(): |
|
global i |
|
out = synth(z) |
|
iii = perceptor.encode_image(normalize(make_cutouts(out))).float() |
|
|
|
result = [] |
|
|
|
if args.init_weight: |
|
|
|
result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2) |
|
|
|
for prompt in pMs: |
|
result.append(prompt(iii)) |
|
|
|
if args.make_video: |
|
        img = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(np.uint8)

        img = np.transpose(img, (1, 2, 0))

        imageio.imwrite('./steps/' + str(i) + '.png', img)
|
|
|
return result |
|
|
|
|
|
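# One optimisation step; afterwards clamp z back into the codebook's value range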
def train(i): |
|
opt.zero_grad(set_to_none=True) |
|
lossAll = ascend_txt() |
|
|
|
if i % args.display_freq == 0: |
|
checkin(i, lossAll) |
|
|
|
loss = sum(lossAll) |
|
loss.backward() |
|
opt.step() |
|
|
|
|
|
with torch.inference_mode(): |
|
z.copy_(z.maximum(z_min).minimum(z_max)) |
|
|
|
|
|
|
|
i = 0 |
|
j = 0 |
|
p = 1 |
|
smoother = 0 |
|
this_video_frame = 0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
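# Main generation loop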
try: |
|
with tqdm() as pbar: |
|
while True: |
|
|
|
if args.make_zoom_video: |
|
if i % args.zoom_frequency == 0: |
|
out = synth(z) |
|
|
|
|
|
                    img = out.mul(255).clamp(0, 255)[0].detach().cpu().numpy().astype(np.uint8)

                    img = np.transpose(img, (1, 2, 0))

                    imageio.imwrite('./steps/' + str(j) + '.png', img)
|
|
|
|
|
if args.zoom_start <= i: |
|
|
|
|
|
|
|
|
|
pil_image = Image.fromarray(np.array(img).astype('uint8'), 'RGB') |
|
|
|
|
|
if args.zoom_scale != 1: |
|
pil_image_zoom = zoom_at(pil_image, sideX/2, sideY/2, args.zoom_scale) |
|
else: |
|
pil_image_zoom = pil_image |
|
|
|
|
|
if args.zoom_shift_x or args.zoom_shift_y: |
|
|
|
pil_image_zoom = ImageChops.offset(pil_image_zoom, args.zoom_shift_x, args.zoom_shift_y) |
|
|
|
|
|
pil_tensor = TF.to_tensor(pil_image_zoom) |
|
|
|
|
|
z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1) |
|
z_orig = z.clone() |
|
z.requires_grad_(True) |
|
|
|
|
|
opt = get_opt(args.optimiser, args.step_size) |
|
|
|
|
|
j += 1 |
|
|
|
|
|
if args.prompt_frequency > 0: |
|
if i % args.prompt_frequency == 0 and i > 0: |
|
|
|
if p >= len(all_phrases): |
|
p = 0 |
|
|
|
pMs = [] |
|
args.prompts = all_phrases[p] |
|
|
|
|
|
print(args.prompts) |
|
|
|
for prompt in args.prompts: |
|
txt, weight, stop = split_prompt(prompt) |
|
embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float() |
|
pMs.append(Prompt(embed, weight, stop).to(device)) |
|
|
|
''' |
|
# Smooth test |
|
smoother = args.zoom_frequency * 15 # smoothing over x frames |
|
variable_lr = args.step_size * 0.25 |
|
opt = get_opt(args.optimiser, variable_lr) |
|
''' |
|
|
|
p += 1 |
|
|
|
''' |
|
if smoother > 0: |
|
if smoother == 1: |
|
opt = get_opt(args.optimiser, args.step_size) |
|
smoother -= 1 |
|
''' |
|
|
|
''' |
|
# Messing with learning rate / optimisers |
|
if i % 225 == 0 and i > 0: |
|
variable_optimiser_item = random.choice(optimiser_list) |
|
variable_optimiser = variable_optimiser_item[0] |
|
variable_lr = variable_optimiser_item[1] |
|
|
|
opt = get_opt(variable_optimiser, variable_lr) |
|
print("New opt: %s, lr= %f" %(variable_optimiser,variable_lr)) |
|
''' |
|
|
|
|
|
|
|
train(i) |
|
|
|
|
|
if i == args.max_iterations: |
|
if not args.video_style_dir: |
|
|
|
break |
|
else: |
|
if this_video_frame == (num_video_frames - 1): |
|
|
|
make_styled_video = True |
|
break |
|
else: |
|
|
|
this_video_frame += 1 |
|
|
|
|
|
i = -1 |
|
pbar.reset() |
|
|
|
|
|
args.init_image = video_frame_list[this_video_frame] |
|
print("Next frame: ", args.init_image) |
|
|
|
if args.seed is None: |
|
seed = torch.seed() |
|
else: |
|
seed = args.seed |
|
torch.manual_seed(seed) |
|
print("Seed: ", seed) |
|
|
|
filename = os.path.basename(args.init_image) |
|
args.output = os.path.join(cwd, "steps", filename) |
|
|
|
|
|
img = Image.open(args.init_image) |
|
pil_image = img.convert('RGB') |
|
pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS) |
|
pil_tensor = TF.to_tensor(pil_image) |
|
|
|
|
|
z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1) |
|
z_orig = z.clone() |
|
z.requires_grad_(True) |
|
|
|
|
|
opt = get_opt(args.optimiser, args.step_size) |
|
|
|
i += 1 |
|
pbar.update() |
|
except KeyboardInterrupt: |
|
pass |
|
|
|
|
|
|
|
|
|
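# Assemble the saved frames into a video with ffmpeg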
if args.make_video or args.make_zoom_video: |
|
init_frame = 1 |
|
if args.make_zoom_video: |
|
last_frame = j |
|
else: |
|
last_frame = i |
|
|
|
length = args.video_length |
|
|
|
min_fps = 10 |
|
max_fps = 60 |
|
|
|
total_frames = last_frame-init_frame |
|
|
|
frames = [] |
|
tqdm.write('Generating video...') |
|
for i in range(init_frame,last_frame): |
|
temp = Image.open("./steps/"+ str(i) +'.png') |
|
keep = temp.copy() |
|
frames.append(keep) |
|
temp.close() |
|
|
|
if args.output_video_fps > 9: |
|
|
|
print("Creating interpolated frames...") |
|
ffmpeg_filter = f"minterpolate='mi_mode=mci:me=hexbs:me_mode=bidir:mc_mode=aobmc:vsbmc=1:mb_size=8:search_param=32:fps={args.output_video_fps}'" |
|
        output_file = re.compile(r'\.png$').sub('.mp4', args.output)
|
        # Pipe the PNG frames straight into ffmpeg's stdin
        try:

            p = Popen(['ffmpeg',

                       '-y',

                       '-f', 'image2pipe',

                       '-vcodec', 'png',

                       '-r', str(args.input_video_fps),

                       '-i',

                       '-',

                       '-b:v', '10M',

                       '-vcodec', 'h264_nvenc',

                       '-pix_fmt', 'yuv420p',

                       '-strict', '-2',

                       '-filter:v', f'{ffmpeg_filter}',

                       '-metadata', f'comment={args.prompts}',

                       output_file], stdin=PIPE)

            for im in tqdm(frames):

                im.save(p.stdin, 'PNG')

            p.stdin.close()

            p.wait()

        except FileNotFoundError:

            print("ffmpeg command failed - check your installation")
|
else: |
|
|
|
fps = np.clip(total_frames/length,min_fps,max_fps) |
|
        output_file = re.compile(r'\.png$').sub('.mp4', args.output)
|
        # Pipe the PNG frames straight into ffmpeg's stdin
        try:

            p = Popen(['ffmpeg',

                       '-y',

                       '-f', 'image2pipe',

                       '-vcodec', 'png',

                       '-r', str(fps),

                       '-i',

                       '-',

                       '-vcodec', 'libx264',

                       '-r', str(fps),

                       '-pix_fmt', 'yuv420p',

                       '-crf', '17',

                       '-preset', 'veryslow',

                       '-metadata', f'comment={args.prompts}',

                       output_file], stdin=PIPE)

            for im in tqdm(frames):

                im.save(p.stdin, 'PNG')

            p.stdin.close()

            p.wait()

        except FileNotFoundError:

            print("ffmpeg command failed - check your installation")
|
|