""" |
More info on flavors [here](https://i.ibb.co/hCdm3W4/flavors.png). |
More info on prompt experiments [here](https://i.ibb.co/0FF7vNn/prompt-experiments.png). |
Styles of fictional (made-up) artists can be found [here](https://docs.google.com/spreadsheets/d/1nMq-TjBj3t6us-npLRoLFq0VtgpVwdXCKTcQgnxKgTQ/edit?usp=sharing).
A keyword cheat sheet (made by kingdomakrillic) can be found [here](https://imgur.com/a/SnSIQRu).
A short guide to prompt engineering can be found [here](https://docs.google.com/document/d/1qy5fdeThu7pIikulQuWpmKvYBiv9wMshIHcrBr-VldA/edit?usp=sharing). |
""" |
""" |
Main_Libraries = True #@param {type:"boolean"} |
Import_Libraries = True |
Download_Video = False #@param {type:"boolean"} |
Download_Super_Res = False #@param {type:"boolean"} |
Download_Super_Slomo = False #@param {type:"boolean"} |
if Main_Libraries == True: |
print('GPU:') |
!nvidia-smi --query-gpu=name,memory.total --format=cs |
print("\nDownloading CLIP...") |
!git clone https://github.com/openai/CLIP &> /dev/null |
print("Installing AI Python libraries...") |
!git clone https://github.com/CompVis/taming-transformers &> /dev/null |
!pip install ftfy regex tqdm omegaconf pytorch-lightning &> /dev/null |
!pip install kornia &> /dev/null |
!pip install einops &> /dev/null |
!pip install transformers &> /dev/null |
!pip install torch_optimizer &> /dev/null |
!pip install noise &> /dev/null |
!pip install gputil &> /dev/null |
!pip install taming-transformers &> /dev/null |
#!git clone https://github.com/lessw2020/Ranger21.git &> /dev/null |
#!cd Ranger21 &> /dev/null |
#!pip install -e . &> /dev/null |
#!cd .. &> /dev/null |
!mkdir steps |
# %mkdir Init_Img |
print("Installing libraries for handling metadata...") |
!pip install stegano &> /dev/null |
!apt install exempi &> /dev/null |
!pip install python-xmp-toolkit &> /dev/null |
!pip install imgtag &> /dev/null |
if Download_Video: |
print("Installing Python libraries for video creation...") |
!pip install imageio-ffmpeg &> /dev/null |
!pip install timm &> /dev/null |
if Download_Super_Res: |
print("Installing Python libraries for super resolution...") |
# %cd /content/ |
!git clone https://github.com/sberbank-ai/Real-ESRGAN /content/RealESRGAN &> /dev/null |
# %cd RealESRGAN |
!pip install -r requirements.txt &> /dev/null |
# download model weights |
# x2 |
#!gdown https://drive.google.com/uc?id=1pG2S3sYvSaO0V0B8QPOl1RapPHpUGOaV -O weights/RealESRGAN_x2.pth |
# x4 |
!gdown https://drive.google.com/uc?id=1SGHdZAln4en65_NQeQY9UjchtkEF9f5F -O weights/RealESRGAN_x4.pth &> /dev/null |
# x8 |
#!gdown https://drive.google.com/uc?id=1mT9ewx86PSrc43b-ax47l1E2UzR7Ln4j -O weights/RealESRGAN_x8.pth |
# %cd /content/ |
if Download_Super_Slomo: |
!git clone -q --depth 1 https://github.com/avinashpaliwal/Super-SloMo.git &> /dev/null |
from os.path import exists |
def download_from_google_drive(file_id, file_name): |
# download a file from the Google Drive link |
!rm -f ./cookie |
!curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id={file_id}" > /dev/null |
confirm_text = !awk '/download/ {print $NF}' ./cookie |
confirm_text = confirm_text[0] |
!curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm={confirm_text}&id={file_id}" -o {file_name} &> /dev/null |
pretrained_model = 'SuperSloMo.ckpt' |
if not exists(pretrained_model): |
download_from_google_drive('1IvobLDbRiBgZr3ryCRrWL8xDbMZ-KnpF', pretrained_model) |
# %mkdir png_processing |
# %mkdir templates |
!curl https://i.ibb.co/3kn9Qrv/flag.png -o templates/flag.png &> /dev/null |
!curl https://i.ibb.co/0BHqVyg/14135136623-3973d3f03c-z.jpg -o templates/planet.png &> /dev/null |
!curl https://i.ibb.co/52WMK2M/j7oocvu80qe11.png -o templates/map.png &> /dev/null |
!curl https://i.ibb.co/3fg9Zkx/creature.png -o templates/creature.png &> /dev/null |
!curl https://i.ibb.co/X3Mh2pP/human.jpg -o templates/human.png &> /dev/null |
""" |
import sys |
import streamlit as st |
import argparse |
import math |
from pathlib import Path |
import sys |
import pandas as pd |
from IPython import display |
from base64 import b64encode |
from omegaconf import OmegaConf |
from PIL import Image |
from taming.models import cond_transformer, vqgan |
import torch |
from os.path import exists as path_exists |
torch.cuda.empty_cache() |
from torch import nn |
import torch.optim as optim |
from torch import optim |
from torch.nn import functional as F |
from torchvision import transforms |
from torchvision.transforms import functional as TF |
import torchvision.transforms as T |
from CLIP import clip |
import kornia.augmentation as K |
import numpy as np |
import subprocess |
import imageio |
from PIL import ImageFile, Image |
import time |
import hashlib |
from PIL.PngImagePlugin import PngImageFile, PngInfo |
import json |
import IPython |
from IPython.display import Markdown, display, Image, clear_output |
import urllib.request |
import random |
from random import randint |
from pathvalidate import sanitize_filename |
# Progress messages go straight to stdout and are flushed immediately so they
# show up in a log/console even while heavy imports are still running.
sys.stdout.write("Imports ...\n")
sys.stdout.flush()
# Make the locally cloned repositories importable.
# NOTE(review): these entries are added *after* `from CLIP import clip` and
# `from taming.models import ...` above, so they cannot help those imports;
# confirm the packages are installed or move this before the imports.
sys.path.extend(["./CLIP", "./taming-transformers"])
sys.stdout.write("Parsing arguments ...\n")
sys.stdout.flush()
def run_model(args2, status, stoutput, DefaultPaths): |
if args2.seed is not None: |
import torch |
sys.stdout.write(f"Setting seed to {args2.seed} ...\n") |
sys.stdout.flush() |
status.write(f"Setting seed to {args2.seed} ...\n") |
import numpy as np |
np.random.seed(args2.seed) |
import random |
random.seed(args2.seed) |
torch.manual_seed(args2.seed) |
torch.cuda.manual_seed(args2.seed) |
torch.cuda.manual_seed_all(args2.seed) |
torch.backends.cudnn.deterministic = True |
torch.backends.cudnn.benchmark = False |
""" |
from imgtag import ImgTag # metadata |
from libxmp import * # metadata |
import libxmp # metadata |
from stegano import lsb |
import gc |
import GPUtil as GPU |
""" |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
print("Using device:", device) |
def noise_gen(shape, octaves=5):
    """Return multi-octave Gaussian noise with the given ``(n, c, h, w)`` shape.

    Starts from a ``(n, c, 1, 1)`` zero tensor and, going from the coarsest octave
    to the finest, bicubically upsamples the accumulator and adds ``randn / 5`` at
    that resolution. The number of octaves is capped by ``log2`` of the height and
    width. If that cap truncates to 0 (e.g. ``h == 1``), the result stays
    ``(n, c, 1, 1)``.
    """
    n, c, h, w = shape
    noise = torch.zeros([n, c, 1, 1])
    # range() needs an int, but the log2 bounds are floats. Truncate explicitly
    # (previously a float reached range() and raised TypeError). math.log2 is
    # exact for powers of two, unlike log(h) / log(2).
    max_octaves = int(min(octaves, math.log2(h), math.log2(w)))
    for i in reversed(range(max_octaves)):
        h_cur, w_cur = h // 2**i, w // 2**i
        noise = F.interpolate(
            noise, (h_cur, w_cur), mode="bicubic", align_corners=False
        )
        noise += torch.randn([n, c, h_cur, w_cur]) / 5
    return noise
def sinc(x):
    """Normalized sinc, ``sin(pi * x) / (pi * x)``, with the removable point at 0 set to 1."""
    scaled = math.pi * x
    # torch.where evaluates both branches; the 0/0 at x == 0 is discarded here.
    return torch.where(x == 0, x.new_ones([]), torch.sin(scaled) / scaled)
def lanczos(x, a):
    """Lanczos window ``sinc(x) * sinc(x / a)`` on ``(-a, a)``, zero outside, normalized to sum to 1."""
    inside = (x > -a) & (x < a)
    window = torch.where(inside, sinc(x) * sinc(x / a), x.new_zeros([]))
    return window / window.sum()
def ramp(ratio, width):
    """Return sample positions in steps of ``ratio``, symmetric around 0.

    Builds ``0, ratio, 2*ratio, ...`` (``ceil(width / ratio + 1)`` values),
    mirrors it to negative offsets, and drops the two outermost samples.
    """
    count = math.ceil(width / ratio + 1)
    half = torch.empty([count])
    offsets = []
    position = 0
    for _ in range(count):
        offsets.append(position)
        position += ratio  # accumulate, exactly like a running cursor
    half.copy_(torch.tensor(offsets, dtype=torch.float64))
    mirrored = torch.cat([half[1:].flip(0).neg(), half])
    return mirrored[1:-1]
def resample(input, size, align_corners=True):
    """Resize a 4-D ``(n, c, h, w)`` tensor to ``size`` with Lanczos pre-filtering.

    Each axis that is being shrunk is first low-pass filtered with a separable
    Lanczos kernel (reflect-padded so the plane keeps its h x w extent). The
    result is then bicubically interpolated to ``size``.
    """
    n, c, h, w = input.shape
    dh, dw = size
    # Fold batch and channels together so every plane is filtered independently.
    planes = input.view([n * c, 1, h, w])
    if dh < h:
        taps = lanczos(ramp(dh / h, 2), 2).to(planes.device, planes.dtype)
        half = (taps.shape[0] - 1) // 2
        planes = F.pad(planes, (0, 0, half, half), "reflect")
        planes = F.conv2d(planes, taps[None, None, :, None])
    if dw < w:
        taps = lanczos(ramp(dw / w, 2), 2).to(planes.device, planes.dtype)
        half = (taps.shape[0] - 1) // 2
        planes = F.pad(planes, (half, half, 0, 0), "reflect")
        planes = F.conv2d(planes, taps[None, None, None, :])
    planes = planes.view([n, c, h, w])
    return F.interpolate(planes, size, mode="bicubic", align_corners=align_corners)
def lerp(a, b, f):
    """Linearly interpolate from ``a`` (at ``f == 0``) to ``b`` (at ``f == 1``)."""
    weight_a = 1.0 - f
    return a * weight_a + b * f
class ReplaceGrad(torch.autograd.Function):
    """Straight-through helper: forward yields ``x_forward``, gradients go to ``x_backward``.

    The incoming gradient is summed down to ``x_backward``'s shape (undoing any
    broadcasting). ``x_forward`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.backward_shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        grad_backward = grad_in.sum_to_size(ctx.backward_shape)
        return None, grad_backward


replace_grad = ReplaceGrad.apply
class ClampWithGrad(torch.autograd.Function):
    """Clamp whose backward still passes gradients that push values back into range.

    Forward is ``input.clamp(min, max)``. In backward the incoming gradient is
    kept wherever ``grad * (input - clamped) >= 0`` and zeroed elsewhere. An
    in-range element always keeps its gradient. An out-of-range element keeps
    it only if a gradient-descent step would move it toward the interval.
    """

    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min, ctx.max = min, max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        overshoot = input - input.clamp(ctx.min, ctx.max)
        keep = grad_in * overshoot >= 0
        return grad_in * keep, None, None


clamp_with_grad = ClampWithGrad.apply
def vector_quantize(x, codebook):
    """Snap each vector in ``x`` (last dim) to its nearest ``codebook`` row.

    The distance is the expanded squared Euclidean form
    ``|x|^2 + |e|^2 - 2 x.e``. The forward value is the quantized tensor, and
    gradients pass straight through to ``x`` via ``replace_grad``.
    """
    x_sq = x.pow(2).sum(dim=-1, keepdim=True)
    code_sq = codebook.pow(2).sum(dim=1)
    cross = 2 * x @ codebook.T
    d = x_sq + code_sq - cross
    nearest = d.argmin(-1)
    # One-hot @ codebook selects the chosen rows while staying a matmul.
    selector = F.one_hot(nearest, codebook.shape[0]).to(d.dtype)
    quantized = selector @ codebook
    return replace_grad(quantized, x)
class Prompt(nn.Module):
    """Weighted spherical-distance loss between input embeddings and a target embedding.

    Args:
        embed: target embedding(s); normalized along the feature axis in forward.
        weight: loss multiplier. Its sign flips the loss (negative pushes away).
        stop: gradient floor. Where the distance is below ``stop`` the gradient
            is taken through ``max(dists, stop)`` and therefore vanishes.
    """

    def __init__(self, embed, weight=1.0, stop=float("-inf")):
        super().__init__()
        self.register_buffer("embed", embed)
        self.register_buffer("weight", torch.as_tensor(weight))
        self.register_buffer("stop", torch.as_tensor(stop))

    def forward(self, input):
        # assumes input is (N, D) and embed is (M, D) -- TODO confirm; the
        # unsqueezes below pair every input row with every embed row.
        image_dirs = F.normalize(input.unsqueeze(1), dim=2)
        target_dirs = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = (image_dirs - target_dirs).norm(dim=2)
        # For unit vectors, 2 * arcsin(chord / 2) ** 2 equals half the squared angle.
        dists = 2 * (chord / 2).arcsin() ** 2
        dists = dists * self.weight.sign()
        floored = replace_grad(dists, torch.maximum(dists, self.stop))
        return self.weight.abs() * floored.mean()
def parse_prompt(prompt):
    """Split a prompt of the form ``text[:weight]`` into ``(text, weight, stop)``.

    Only one trailing ``:field`` is parsed, so ``stop`` is always ``-inf`` and
    ``weight`` defaults to 1. For ``http(s)://`` prompts the scheme's colon is
    skipped. A trailing ``:<number>`` after it is taken as the weight, while any
    other colon (for example a port in ``host:8080/...``) stays part of the URL.

    Raises:
        ValueError: if a non-URL prompt's text after the last colon is not a number.
    """
    if prompt.startswith(("http://", "https://")):
        scheme, _, rest = prompt.partition("://")
        location, colon, tail = rest.rpartition(":")
        vals = [prompt]
        if colon:
            try:
                float(tail)
            except ValueError:
                pass  # the colon belongs to the URL itself, not a weight
            else:
                vals = [f"{scheme}://{location}", tail]
    else:
        vals = prompt.rsplit(":", 1)
    vals = vals + ["", "1", "-inf"][len(vals) :]
    return vals[0], float(vals[1]), float(vals[2])
def one_sided_clip_loss(input, target, labels=None, logit_scale=100):
    """Cross-entropy over scaled cosine similarities, scored from ``input`` to ``target`` only.

    Row ``k`` of ``input`` is compared against every row of ``target``. By
    default the correct class for row ``k`` is ``k``.
    """
    similarities = F.normalize(input, dim=-1) @ F.normalize(target, dim=-1).T
    logits = similarities * logit_scale
    if labels is None:
        labels = torch.arange(len(input), device=logits.device)
    return F.cross_entropy(logits, labels)
class EMATensor(nn.Module):
    """A trainable tensor plus a bias-corrected exponential moving average of it.

    Implemented by Katherine Crowson. In training mode ``forward()`` returns the
    live parameter. In eval mode it returns the averaged value. ``update()``
    folds the current parameter into the average and may only run in training
    mode.
    """

    def __init__(self, tensor, decay):
        super().__init__()
        self.tensor = nn.Parameter(tensor)
        self.register_buffer("biased", torch.zeros_like(tensor))
        self.register_buffer("average", torch.zeros_like(tensor))
        self.decay = decay
        # Running product of decay factors, used for bias correction.
        self.register_buffer("accum", torch.tensor(1.0))
        self.update()

    @torch.no_grad()
    def update(self):
        if not self.training:
            raise RuntimeError("update() should only be called during training")
        decay = self.decay
        self.accum.mul_(decay)
        self.biased.mul_(decay).add_((1 - decay) * self.tensor)
        # Dividing by (1 - decay**k) removes the bias toward the zero initial value.
        self.average.copy_(self.biased).div_(1 - self.accum)

    def forward(self):
        return self.tensor if self.training else self.average
class MakeCutoutsCustom(nn.Module):
    """Cutout flavor whose Kornia augmentation stack comes from user settings.

    The augmentation probabilities and magnitudes are read from names such as
    ``Random_Horizontal_Flip`` and ``Color_Jitter_P``, which are not defined in
    this block. Presumably they are set from the custom-flavor options in an
    enclosing scope -- TODO confirm.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops produced per call.
        cut_pow: exponent applied to the sampled size fraction.
        augs: accepted for interface compatibility with other flavors; ignored.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=Random_Horizontal_Flip),
            K.RandomSharpness(Random_Sharpness, p=Random_Sharpness_P),
            K.RandomGaussianBlur(
                (Random_Gaussian_Blur),
                (Random_Gaussian_Blur_W, Random_Gaussian_Blur_W),
                p=Random_Gaussian_Blur_P,
            ),
            K.RandomGaussianNoise(p=Random_Gaussian_Noise_P),
            K.RandomElasticTransform(
                kernel_size=(
                    Random_Elastic_Transform_Kernel_Size_W,
                    Random_Elastic_Transform_Kernel_Size_H,
                ),
                sigma=(Random_Elastic_Transform_Sigma),
                p=Random_Elastic_Transform_P,
            ),
            K.RandomAffine(
                degrees=Random_Affine_Degrees,
                translate=Random_Affine_Translate,
                p=Random_Affine_P,
                padding_mode="border",
            ),
            K.RandomPerspective(Random_Perspective, p=Random_Perspective_P),
            K.ColorJitter(
                hue=Color_Jitter_Hue,
                saturation=Color_Jitter_Saturation,
                p=Color_Jitter_P,
            ),
        )

    def set_cut_pow(self, cut_pow):
        """Change the crop-size exponent used by later forward() calls."""
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` augmented, noised square crops of ``input`` stacked on dim 0.

        A fraction ``f`` is drawn from N(0.8, 0.3) and clipped to
        ``[cut_size / short_side, 1]``. ``f ** cut_pow`` is clipped to the same
        range and scaled by the short side to give the crop size. Each crop is
        resampled to ``cut_size`` x ``cut_size`` and clamped to [0, 1], then the
        batch is augmented. Finally per-crop Gaussian noise with a level in
        ``[0, noise_fac)`` is added.
        """
        sideY, sideX = input.shape[2:4]
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        cutouts = []
        for _ in range(self.cutn):
            randsize = torch.zeros(1).normal_(mean=0.8, std=0.3).clip(lower_bound, 1.0)
            size_mult = randsize**self.cut_pow
            size = int(min_size_width * (size_mult.clip(lower_bound, 1.0)))
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)
        cutouts = self.augs(cutouts)
        if self.noise_fac:
            facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
                0, self.noise_fac
            )
            cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts
class MakeCutoutsJuu(nn.Module):
    """Square random crops with power-law sizes, a fixed Kornia stack and additive noise.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops taken from each call's input.
        cut_pow: exponent on the uniform size sample; larger values favor crops
            closer to the minimum size.
        augs: accepted for interface compatibility with other flavors; ignored.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            K.RandomGrayscale(p=0.1),
        )
        self.noise_fac = 0.1

    def forward(self, input):
        """Return the augmented, noised crops stacked along dim 0."""
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(
                torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
            )
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # One noise level per crop. Size by the real batch (cutn * N for an
            # N-image input) rather than cutn, which fails to broadcast when N > 1.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
class MakeCutoutsMoth(nn.Module):
    """Crops of a zero-padded image, with torchvision augmentations applied per crop.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: stored for interface compatibility; not used by forward().
        augs: accepted for interface compatibility; ignored (a fixed stack is built).
        skip_augs: when True, crops are resampled without augmentation.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs, skip_augs=False):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.skip_augs = skip_augs
        # Each geometric/photometric step is followed by a small Gaussian jitter.
        self.augs = T.Compose(
            [
                T.RandomHorizontalFlip(p=0.5),
                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
                T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
                T.RandomPerspective(distortion_scale=0.4, p=0.7),
                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
                T.RandomGrayscale(p=0.15),
                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
            ]
        )

    def forward(self, input):
        """Return ``self.cutn`` crops resampled to ``cut_size``, stacked along dim 0."""
        # Pad every side by a quarter of the height so crops can include border.
        input = T.Pad(input.shape[2] // 4, fill=0)(input)
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        cutouts = []
        # Use the instance's cutn, not an outer-scope `cutn`, so the crop count
        # always matches the constructor argument.
        cutn = self.cutn
        for ch in range(cutn):
            if ch > cutn - cutn // 4:
                # The final crops (index above cutn - cutn // 4) are the whole padded image.
                cutout = input.clone()
            else:
                size = int(
                    max_size
                    * torch.zeros(1)
                    .normal_(mean=0.8, std=0.3)
                    .clip(float(self.cut_size / max_size), 1.0)
                )
                offsetx = torch.randint(0, abs(sideX - size + 1), ())
                offsety = torch.randint(0, abs(sideY - size + 1), ())
                cutout = input[
                    :, :, offsety : offsety + size, offsetx : offsetx + size
                ]
            if not self.skip_augs:
                cutout = self.augs(cutout)
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
            del cutout
        cutouts = torch.cat(cutouts, dim=0)
        return cutouts
class MakeCutoutsAaron(nn.Module):
    """Square random crops, resampled and clamped to [0, 1], with no augmentation or noise.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: stored (and settable) but not used by forward().
        augs: stored as ``self.augs`` but not applied by forward().
    NOTE(review): ``cut_pow`` and ``augs`` are unused here, unlike most flavors --
    confirm that is intended.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = augs
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def set_cut_pow(self, cut_pow):
        """Change the stored crop-size exponent."""
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` crops stacked along dim 0.

        The crop side is the short side times N(0.8, 0.3) clipped to
        ``[cut_size / short_side, 1]``.
        """
        sideY, sideX = input.shape[2:4]
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        cutouts = []
        for _ in range(self.cutn):
            size = int(
                min_size_width
                * torch.zeros(1).normal_(mean=0.8, std=0.3).clip(lower_bound, 1.0)
            )
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        return clamp_with_grad(cutouts, 0, 1)
class MakeCutoutsCumin(nn.Module):
    """Random square crops with affine/perspective/color/erasing augmentation and noise.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: exponent applied to the sampled size fraction.
        augs: accepted for interface compatibility with other flavors; ignored.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode="border"),
            K.RandomPerspective(0.7, p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((0.1, 0.4), (0.3, 1 / 0.3), same_on_batch=True, p=0.7),
        )

    def set_cut_pow(self, cut_pow):
        """Change the crop-size exponent used by later forward() calls."""
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` augmented, noised crops stacked along dim 0.

        The crop side is the short side times ``clip(f ** cut_pow)``, where ``f``
        is N(0.8, 0.3) and both clips use ``[cut_size / short_side, 1]``.
        """
        sideY, sideX = input.shape[2:4]
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        cutouts = []
        for _ in range(self.cutn):
            randsize = torch.zeros(1).normal_(mean=0.8, std=0.3).clip(lower_bound, 1.0)
            size_mult = randsize**self.cut_pow
            size = int(min_size_width * (size_mult.clip(lower_bound, 1.0)))
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)
        cutouts = self.augs(cutouts)
        if self.noise_fac:
            facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
                0, self.noise_fac
            )
            cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts
class MakeCutoutsHolywater(nn.Module):
    """Random (possibly non-square) crops, resampled, augmented and always noised.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: exponent used when sampling crop sizes.
        augs: accepted for interface compatibility; ignored (a fixed stack is built).
    """
    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        # Not used by forward().
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            K.RandomGrayscale(p=0.1),
        )
    def set_cut_pow(self, cut_pow):
        """Change the crop-size exponent used by later forward() calls."""
        self.cut_pow = cut_pow
    def forward(self, input):
        """Return ``cutn`` crops of ``input`` (resampled to cut_size) stacked along dim 0."""
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        cutouts_full = []  # unused
        noise_fac = 0.1  # unused; self.noise_fac is what is applied below
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        for ii in range(self.cutn):
            # Power-law size in [min_size, max_size].
            size = int(
                torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
            )
            randsize = (
                torch.zeros(
                    1,
                )
                .normal_(mean=0.8, std=0.3)
                .clip(lower_bound, 1.0)
            )
            # NOTE(review): `size` is at least min_size (>= 1 for a non-empty image),
            # so size_mult is >= 1 and the clip below appears to always yield 1.0,
            # making size1 == min_size_width. Confirm this is intended.
            size_mult = randsize**self.cut_pow * ii + size
            size1 = int(
                (min_size_width) * (size_mult.clip(lower_bound, 1.0))
            )
            # Height fraction drawn from N(0.9, 0.3), clipped to [lower_bound, 0.95].
            size2 = int(
                (min_size_width)
                * torch.zeros(
                    1,
                )
                .normal_(mean=0.9, std=0.3)
                .clip(lower_bound, 0.95)
            )
            offsetx = torch.randint(0, sideX - size1 + 1, ())
            offsety = torch.randint(0, sideY - size2 + 1, ())
            # The slice grows by `ii` on each axis (truncated at the image border),
            # so crops can be non-square before resampling.
            cutout = input[
                :, :, offsety : offsety + size2 + ii, offsetx : offsetx + size1 + ii
            ]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)
        cutouts = self.augs(cutouts)
        # Unlike sibling flavors, noise is added unconditionally (no noise_fac check).
        facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
            0, self.noise_fac
        )
        cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts
class MakeCutoutsOldHolywater(nn.Module):
    """Random square crops with strong affine/perspective/erasing augmentation and noise.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: exponent applied to the sampled size fraction.
        augs: accepted for interface compatibility with other flavors; ignored.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
            K.RandomAffine(
                degrees=180, translate=0.5, p=0.2, padding_mode="border"
            ),
            K.RandomPerspective(0.6, p=0.9),
            K.ColorJitter(hue=0.03, saturation=0.01, p=0.1),
            K.RandomErasing((0.1, 0.7), (0.3, 1 / 0.4), same_on_batch=True, p=0.2),
        )

    def set_cut_pow(self, cut_pow):
        """Change the crop-size exponent used by later forward() calls."""
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` augmented, noised crops stacked along dim 0.

        The crop side is the short side times ``clip(f ** cut_pow)``, where ``f``
        is N(0.8, 0.3) and both clips use ``[cut_size / short_side, 1]``.
        """
        sideY, sideX = input.shape[2:4]
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        cutouts = []
        for _ in range(self.cutn):
            randsize = torch.zeros(1).normal_(mean=0.8, std=0.3).clip(lower_bound, 1.0)
            size_mult = randsize**self.cut_pow
            size = int(min_size_width * (size_mult.clip(lower_bound, 1.0)))
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)
        cutouts = self.augs(cutouts)
        if self.noise_fac:
            facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
                0, self.noise_fac
            )
            cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts
class MakeCutoutsGinger(nn.Module):
    """Random square crops augmented by a caller-supplied module, plus additive noise.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: exponent applied to the sampled size fraction.
        augs: callable applied to the stacked crops (e.g. a Kornia ``nn.Sequential``).
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        # The augmentation pipeline is supplied by the caller rather than built here.
        self.augs = augs

    def set_cut_pow(self, cut_pow):
        """Change the crop-size exponent used by later forward() calls."""
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` augmented, noised crops stacked along dim 0.

        The crop side is the short side times ``clip(f ** cut_pow)``, where ``f``
        is N(0.8, 0.3) and both clips use ``[cut_size / short_side, 1]``.
        """
        sideY, sideX = input.shape[2:4]
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        cutouts = []
        for _ in range(self.cutn):
            randsize = torch.zeros(1).normal_(mean=0.8, std=0.3).clip(lower_bound, 1.0)
            size_mult = randsize**self.cut_pow
            size = int(min_size_width * (size_mult.clip(lower_bound, 1.0)))
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)
        cutouts = self.augs(cutouts)
        if self.noise_fac:
            facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
                0, self.noise_fac
            )
            cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts
class MakeCutoutsZynth(nn.Module):
    """Random square crops with flip/sharpness/affine/perspective/color augmentation and noise.

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: exponent applied to the sampled size fraction.
        augs: accepted for interface compatibility with other flavors; ignored.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        # Not used by forward(); kept so the module's attributes are unchanged.
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        )

    def set_cut_pow(self, cut_pow):
        """Change the crop-size exponent used by later forward() calls."""
        self.cut_pow = cut_pow

    def forward(self, input):
        """Return ``cutn`` augmented, noised crops stacked along dim 0.

        The crop side is the short side times ``clip(f ** cut_pow)``, where ``f``
        is N(0.8, 0.3) and both clips use ``[cut_size / short_side, 1]``.
        """
        sideY, sideX = input.shape[2:4]
        min_size_width = min(sideX, sideY)
        lower_bound = float(self.cut_size / min_size_width)
        cutouts = []
        for _ in range(self.cutn):
            randsize = torch.zeros(1).normal_(mean=0.8, std=0.3).clip(lower_bound, 1.0)
            size_mult = randsize**self.cut_pow
            size = int(min_size_width * (size_mult.clip(lower_bound, 1.0)))
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        cutouts = torch.cat(cutouts, dim=0)
        cutouts = clamp_with_grad(cutouts, 0, 1)
        cutouts = self.augs(cutouts)
        if self.noise_fac:
            facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
                0, self.noise_fac
            )
            cutouts = cutouts + facs * torch.randn_like(cutouts)
        return cutouts
class MakeCutoutsWyvern(nn.Module):
    """Power-law-sized square crops, resampled and clamped to [0, 1].

    Args:
        cut_size: side length every crop is resampled to.
        cutn: number of crops per call.
        cut_pow: exponent on the uniform size sample.
        augs: stored as ``self.augs``; forward() does not apply it.
    """

    def __init__(self, cut_size, cutn, cut_pow, augs):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
        self.augs = augs

    def forward(self, input):
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        target = (self.cut_size, self.cut_size)
        crops = []
        for _ in range(self.cutn):
            fraction = torch.rand([]) ** self.cut_pow
            side = int(fraction * (largest - smallest) + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            patch = input[:, :, top : top + side, left : left + side]
            crops.append(resample(patch, target))
        stacked = torch.cat(crops, dim=0)
        return clamp_with_grad(stacked, 0, 1)
def load_vqgan_model(config_path, checkpoint_path):
    """Build a frozen taming-transformers VQGAN from an OmegaConf config and checkpoint.

    Supports ``VQModel``, ``GumbelVQ`` and ``Net2NetTransformer`` targets. For
    the transformer, only its ``first_stage_model`` is returned. The model's
    ``loss`` attribute is deleted before returning.

    Raises:
        ValueError: if ``config.model.target`` is not one of the supported classes.
    """
    config = OmegaConf.load(config_path)
    target = config.model.target
    if target == "taming.models.cond_transformer.Net2NetTransformer":
        transformer = cond_transformer.Net2NetTransformer(**config.model.params)
        transformer.eval().requires_grad_(False)
        transformer.init_from_ckpt(checkpoint_path)
        model = transformer.first_stage_model
    elif target in ("taming.models.vqgan.VQModel", "taming.models.vqgan.GumbelVQ"):
        if target == "taming.models.vqgan.VQModel":
            model_cls = vqgan.VQModel
        else:
            model_cls = vqgan.GumbelVQ
        model = model_cls(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    else:
        raise ValueError(f"unknown model type: {config.model.target}")
    del model.loss
    return model
import PIL |
def resize_image(image, out_size):
    """Resize a PIL image to at most ``out_size``'s pixel area, keeping its aspect ratio.

    The image is never enlarged beyond its own area. Uses Lanczos resampling.
    """
    width, height = image.size
    aspect = width / height
    target_area = min(width * height, out_size[0] * out_size[1])
    new_width = round((target_area * aspect) ** 0.5)
    new_height = round((target_area / aspect) ** 0.5)
    # PIL.Image is spelled out because `Image` in this file is rebound by IPython.display.
    return image.resize((new_width, new_height), PIL.Image.LANCZOS)
class GaussianBlur2d(nn.Module):
    """Separable Gaussian blur applied independently to each channel of an (n, c, h, w) tensor.

    Args:
        sigma: standard deviation in pixels. A falsy sigma yields an identity
            1-tap kernel.
        window: kernel length. When falsy it is derived from sigma (an odd
            value, at least 3).
        mode: ``F.pad`` padding mode.
        value: fill value passed to ``F.pad``.
    """

    def __init__(self, sigma, window=0, mode="reflect", value=0):
        super().__init__()
        self.mode = mode
        self.value = value
        if not window:
            window = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
        if not sigma:
            kernel = torch.ones([1])
        else:
            offsets = torch.arange(window) - window // 2
            kernel = torch.exp(-(offsets**2) / 2 / sigma**2)
            kernel /= kernel.sum()
        self.register_buffer("kernel", kernel)

    def forward(self, input):
        n, c, h, w = input.shape
        taps = self.kernel.shape[0]
        before, after = (taps - 1) // 2, taps // 2
        planes = input.view([n * c, 1, h, w])
        planes = F.pad(planes, (before, after, before, after), self.mode, self.value)
        # Horizontal pass, then vertical pass.
        planes = F.conv2d(planes, self.kernel[None, None, None, :])
        planes = F.conv2d(planes, self.kernel[None, None, :, None])
        return planes.view([n, c, h, w])
BUF_SIZE = 65536


def get_digest(path, alg=hashlib.sha256):
    """Return the base64-encoded digest of the file at ``path``.

    The file is read in ``BUF_SIZE`` chunks so large files are never fully
    loaded. ``alg`` is any zero-argument hashlib-style constructor.
    """
    hasher = alg()
    with open(path, "rb") as stream:
        for chunk in iter(lambda: stream.read(BUF_SIZE), b""):
            hasher.update(chunk)
    return b64encode(hasher.digest()).decode("utf-8")
# Maps a user-selectable "flavor" name to the cutout class that implements it.
flavordict = dict(
    cumin=MakeCutoutsCumin,
    holywater=MakeCutoutsHolywater,
    old_holywater=MakeCutoutsOldHolywater,
    ginger=MakeCutoutsGinger,
    zynth=MakeCutoutsZynth,
    wyvern=MakeCutoutsWyvern,
    aaron=MakeCutoutsAaron,
    moth=MakeCutoutsMoth,
    juu=MakeCutoutsJuu,
    custom=MakeCutoutsCustom,
)
@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu: tanh approximation, 0.7978845608028654 ~= sqrt(2 / pi)."""
    inner = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(inner))


def gelu(x):
    """Apply the TorchScript-compiled tanh-approximate GELU to ``x``."""
    return gelu_impl(x)
class MSEDecayLoss(nn.Module):
    """Mean-squared-error pull toward a fixed latent, with a stepwise-decaying weight.

    Args:
        init_weight: starting weight. A falsy value disables the loss entirely.
        mse_decay_rate: the weight is decremented every ``mse_decay_rate`` iterations.
        mse_epoches: number of decrements over which ``init_weight`` is spread.
        mse_quantize: when true, ``set_target`` snaps the latent to the model codebook.
    """

    def __init__(self, init_weight, mse_decay_rate, mse_epoches, mse_quantize):
        super().__init__()
        self.init_weight = init_weight
        self.mse_weight = init_weight
        # Size of each decrement; 0 when the loss is disabled.
        self.mse_decay = init_weight / mse_epoches if init_weight else 0
        self.mse_decay_rate = mse_decay_rate
        self.mse_epoches = mse_epoches
        self.mse_quantize = mse_quantize
        # Set by the caller. While False, the loss is held off for the first
        # mse_decay_rate iterations.
        self.has_init_image = False

    @torch.no_grad()
    def set_target(self, z_tensor, model):
        """Store a detached copy of ``z_tensor`` (optionally quantized) as the target."""
        target = z_tensor.detach().clone()
        if self.mse_quantize:
            # Quantize with the channel axis moved last, then move it back.
            codebook = model.quantize.embedding.weight
            target = vector_quantize(target.movedim(1, 3), codebook).movedim(3, 1)
        self.z_orig = target

    def forward(self, i, z):
        """Return ``mse(z, target) * weight / 2`` while active, else the int 0."""
        if not self.is_active(i):
            return 0
        return F.mse_loss(z, self.z_orig) * self.mse_weight / 2

    def is_active(self, i):
        """Whether the loss applies at iteration ``i``."""
        if not self.init_weight:
            return False
        return bool(i > self.mse_decay_rate or self.has_init_image)

    @torch.no_grad()
    def step(self, i):
        """Decay the weight on scheduled iterations; return True if iteration ``i`` was one."""
        scheduled = (
            i % self.mse_decay_rate == 0
            and i != 0
            and i < self.mse_decay_rate * self.mse_epoches
        )
        if not scheduled:
            return False
        remaining = self.mse_weight - self.mse_decay
        # Drop straight to 0 instead of leaving less than one decrement.
        if remaining > 0 and remaining >= self.mse_decay:
            self.mse_weight = remaining
        else:
            self.mse_weight = 0
        return True
class TVLoss(nn.Module):
    """Isotropic total-variation loss.

    Per pixel, the squared right- and down-neighbor differences are summed (with
    replicate padding at the far edges) and averaged over channels. A square
    root (with a 1e-8 epsilon) is then taken and the result is averaged.
    """

    def forward(self, input):
        padded = F.pad(input, (0, 1, 0, 1), "replicate")
        base = padded[..., :-1, :-1]
        dx = padded[..., :-1, 1:] - base
        dy = padded[..., 1:, :-1] - base
        per_pixel = dx**2 + dy**2 + 1e-8
        return per_pixel.mean(dim=1).sqrt().mean()
class MultiClipLoss(nn.Module):
    """Prompt losses evaluated by one or more CLIP models on a shared set of cutouts.

    Args:
        clip_models: iterable of ``(model_name, weight)`` pairs. ``model_name`` is
            passed to ``clip.load`` and each model's losses are multiplied by ``weight``.
        text_prompt: iterable of prompt strings in ``parse_prompt`` format.
        cutn, cut_pow: NOTE(review): accepted but not used anywhere in this class.
        clip_weight: stored as ``self.clip_weight``; not applied in ``forward``.

    NOTE(review): ``prepare_cuts`` calls ``self.make_cuts``, which this class never
    assigns, and ``forward`` calls ``checkpoint``, which is not imported in the
    visible file (presumably ``torch.utils.checkpoint.checkpoint``). Both must be
    provided elsewhere or these methods raise. ``device`` is also read from an
    outer scope. TODO confirm.
    """
    def __init__(
        self, clip_models, text_prompt, cutn, cut_pow=1.0, clip_weight=1.0
    ):
        super().__init__()
        # Plain list of dicts (not registered submodules), one per CLIP model.
        self.perceptors = []
        for cm in clip_models:
            sys.stdout.write(f"Loading {cm[0]} ...\n")
            sys.stdout.flush()
            c = (
                clip.load(cm[0], jit=False)[0]
                .eval()
                .requires_grad_(False)
                .to(device)
            )
            self.perceptors.append(
                {
                    "res": c.visual.input_resolution,
                    "perceptor": c,
                    "weight": cm[1],
                    "prompts": [],
                }
            )
        # Largest input resolution first: cutouts are made at that size and
        # resampled down for the lower-resolution models in forward().
        self.perceptors.sort(key=lambda e: e["res"], reverse=True)
        self.max_cut_size = self.perceptors[0]["res"]
        texts = text_prompt
        self.pMs = []
        for prompt in texts:
            txt, weight, stop = parse_prompt(prompt)
            clip_token = clip.tokenize(txt).to(device)
            # Encode the prompt once per CLIP model and cache the normalized embedding.
            for p in self.perceptors:
                embed = p["perceptor"].encode_text(clip_token).float()
                embed_normed = F.normalize(embed.unsqueeze(0), dim=2)
                p["prompts"].append(
                    {
                        "embed_normed": embed_normed,
                        "weight": torch.as_tensor(weight, device=device),
                        "stop": torch.as_tensor(stop, device=device),
                    }
                )
        # Presumably CLIP's image preprocessing mean/std -- confirm against clip.load's transform.
        self.normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.1),
            K.RandomAffine(
                degrees=30, translate=0.1, p=0.8, padding_mode="border"
            ),
            K.RandomPerspective(
                0.2,
                p=0.4,
            ),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            K.RandomGrayscale(p=0.15),
        )
        self.noise_fac = 0.1
        self.clip_weight = clip_weight
    def prepare_cuts(self, img):
        """Cut, augment, add per-cutout noise and normalize ``img`` for CLIP."""
        cutouts = self.make_cuts(img)
        cutouts = self.augs(cutouts)
        if self.noise_fac:
            facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
                0, self.noise_fac
            )
            cutouts = cutouts + facs * torch.randn_like(cutouts)
        cutouts = self.normalize(cutouts)
        return cutouts
    def forward(self, i, img):
        """Return a list of weighted losses, one per (CLIP model, prompt) pair.

        ``i`` is accepted but unused.
        """
        cutouts = checkpoint(self.prepare_cuts, img)
        loss = []
        current_cuts = cutouts
        currentres = self.max_cut_size
        for p in self.perceptors:
            # Perceptors are sorted by resolution, so resample only when it changes.
            if currentres != p["res"]:
                current_cuts = resample(cutouts, (p["res"], p["res"]))
                currentres = p["res"]
            iii = p["perceptor"].encode_image(current_cuts).float()
            input_normed = F.normalize(iii.unsqueeze(1), dim=2)
            for prompt in p["prompts"]:
                # Spherical distance: 2 * arcsin(|a - b| / 2) ** 2 for unit vectors.
                dists = (
                    input_normed.sub(prompt["embed_normed"])
                    .norm(dim=2)
                    .div(2)
                    .arcsin()
                    .pow(2)
                    .mul(2)
                )
                dists = dists * prompt["weight"].sign()
                # Forward value is dists; the gradient flows through max(dists, stop).
                l = (
                    prompt["weight"].abs()
                    * replace_grad(
                        dists, torch.maximum(dists, prompt["stop"])
                    ).mean()
                )
                loss.append(l * p["weight"])
        return loss
class ModelHost:
    """Hold the VQGAN/CLIP state for one generation run and drive its optimisation.

    Expected lifecycle (the call sites are outside this chunk):
    ``host = ModelHost(args)``, ``host.setup_model(x)``, ``host.run(x)``.

    Hyper-parameters are read from ``self.args``.  UI-facing settings come from
    the module-level ``args2`` namespace (``image_file``, ``update``,
    ``frame_dir``, ``prompt``, ``sub_model``, ``seed``).  The cutout factory is
    ``flavordict[flavor]`` (module globals), and progress is reported through
    the Streamlit containers ``status``/``stoutput``.
    """

    def __init__(self, args):
        self.args = args
        self.model, self.perceptor = None, None
        self.make_cutouts = None
        self.alt_make_cutouts = None
        self.imageSize = None
        self.prompts = None
        self.altprompts = None
        self.opt = None
        self.normalize = None
        self.z, self.z_orig, self.z_min, self.z_max = None, None, None, None
        self.metadata = None
        self.mse_weight = 0
        # Nothing in this class sets this; presumably the caller assigns it
        # before setup_model.  When truthy, setup_model picks AdamW or DiffGrad
        # at random.
        self.normal_flip_optim = None
        self.usealtprompts = False
        # Multi-CLIP loss object built by setup_model; ascend_txt does not use it.
        self.clip_loss = None

    def _codebook(self, model=None):
        """Return the quantizer codebook embedding of ``model`` (default: ``self.model``).

        Gumbel VQGANs name it ``quantize.embed``; the others ``quantize.embedding``.
        """
        quantize = (self.model if model is None else model).quantize
        return quantize.embed if self.args.is_gumbel else quantize.embedding

    @staticmethod
    def _adamw(z, lr):
        """Return an AdamW optimiser over ``z.parameters()`` with weight decay off."""
        return torch.optim.AdamW(z.parameters(), lr=lr, weight_decay=0.0)

    @staticmethod
    def _next_save_filename(output_dir, stem):
        """Return a not-yet-used ``.png`` path for ``stem`` inside ``output_dir``.

        The first result is ``{stem}.png``.  Once that exists, the result is
        ``{stem}_batch NNNNN.png`` with NNNNN one past the highest existing batch
        number for exactly this stem.  Only exact matches count, so files for
        another seed that merely shares the prefix (seed 7 vs 77) are ignored.
        """
        import os

        first = f"{output_dir}/{stem}.png"
        if not os.path.exists(first):
            return first
        prefix = f"{stem}_batch "
        used = []
        for name in os.listdir(output_dir):
            number = name[len(prefix) : -len(".png")]
            if name.startswith(prefix) and name.endswith(".png") and number.isdigit():
                used.append(int(number))
        return f"{output_dir}/{stem}_batch {max(used, default=0) + 1:05d}.png"

    def setup_metadata(self, seed):
        """Snapshot ``self.args`` plus ``seed`` into ``self.metadata``.

        ``max_iterations`` and ``display_freq`` are dropped (KeyError if they are
        missing).  The init image and the image prompts are stored as
        ``(path, get_digest(path))`` pairs.
        """
        metadata = dict(vars(self.args))
        del metadata["max_iterations"]
        del metadata["display_freq"]
        metadata["seed"] = seed
        if metadata["init_image"]:
            path = metadata["init_image"]
            metadata["init_image"] = (path, get_digest(path))
        if metadata["image_prompts"]:
            # NOTE(review): the raw prompt string is used as the path here, while
            # setup_model strips any ``:weight:stop`` suffix via parse_prompt.
            metadata["image_prompts"] = [
                (prompt, get_digest(prompt)) for prompt in metadata["image_prompts"]
            ]
        self.metadata = metadata

    def setup_model(self, x):
        """Load VQGAN and CLIP, then build the latent, optimiser, cutouts and prompts.

        ``x`` is kept for interface compatibility.  It only fed a zoom-animation
        branch whose latent was always overwritten immediately afterwards, so that
        branch has been removed.
        """
        args = self.args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model = load_vqgan_model(
            f"{DefaultPaths.model_path}/{args.vqgan_model}.yaml",
            f"{DefaultPaths.model_path}/{args.vqgan_model}.ckpt",
        ).to(device)

        # Extra CLIP models share the weight equally with the primary one.
        extra_clips = [
            name
            for name in (getattr(args, f"clip_model{k}") for k in range(2, 9))
            if name
        ]
        if extra_clips:
            clip_weight = round(1 / (len(extra_clips) + 1), 2)
            clip_models = [[args.clip_model, clip_weight]]
            print(clip_models)
            clip_models.extend([name, clip_weight] for name in extra_clips)
        else:
            clip_models = [[args.clip_model, 1.0]]
        # NOTE(review): this loads every selected CLIP model, but ascend_txt scores
        # only against `self.perceptor`.  The extra models do not affect the loss yet.
        self.clip_loss = MultiClipLoss(clip_models, args.prompts, cutn=args.cutn)

        perceptor = (
            clip.load(args.clip_model, jit=False)[0]
            .eval()
            .requires_grad_(False)
            .to(device)
        )
        cut_size = perceptor.visual.input_resolution
        quantize = model.quantize
        codebook = self._codebook(model)
        if args.is_gumbel:
            e_dim, n_toks = quantize.embedding_dim, quantize.n_embed
        else:
            e_dim, n_toks = quantize.e_dim, quantize.n_e
        # Spatial downsampling factor between image pixels and latent tokens.
        f = 2 ** (model.decoder.num_resolutions - 1)
        make_cutouts = flavordict[flavor](
            cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow, augs=args.augs
        )
        if args.altprompts:
            self.usealtprompts = True
            self.alt_make_cutouts = flavordict[flavor](
                cut_size,
                args.mse_cutn,
                cut_pow=args.alt_mse_cut_pow,
                augs=args.altaugs,
            )
        toksX, toksY = args.size[0] // f, args.size[1] // f
        sideX, sideY = toksX * f, toksY * f
        z_min = codebook.weight.min(dim=0).values[None, :, None, None]
        z_max = codebook.weight.max(dim=0).values[None, :, None, None]

        from PIL import Image

        if args.init_image:
            pil_image = Image.open(args.init_image).convert("RGB")
            pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
            z, *_ = model.encode(
                TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1
            )
        else:
            # Random latent: one uniformly chosen codebook vector per token.
            one_hot = F.one_hot(
                torch.randint(n_toks, [toksY * toksX], device=device), n_toks
            ).float()
            z = one_hot @ codebook.weight
            z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
        z = EMATensor(z, args.ema_val)
        if args.mse_with_zeros and not args.init_image:
            z_orig = torch.zeros_like(z.tensor)
        else:
            z_orig = z.tensor.clone()
        z.requires_grad_(True)
        if self.normal_flip_optim and randint(1, 2) == 2:
            opt = optim.DiffGrad(z.parameters(), lr=args.step_size, weight_decay=0.0)
        else:
            opt = self._adamw(z, args.step_size)
        self.cur_step_size = args.mse_step_size
        # CLIP's published image-normalisation statistics.
        normalize = transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        )
        pMs = []
        altpMs = []
        for prompt in args.prompts:
            txt, weight, stop = parse_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))
        for prompt in args.altprompts:
            txt, weight, stop = parse_prompt(prompt)
            embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
            altpMs.append(Prompt(embed, weight, stop).to(device))
        for prompt in args.image_prompts:
            path, weight, stop = parse_prompt(prompt)
            img = resize_image(Image.open(path).convert("RGB"), (sideX, sideY))
            batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
            embed = perceptor.encode_image(normalize(batch)).float()
            pMs.append(Prompt(embed, weight, stop).to(device))
        for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
            # Seeded random embedding acting as a "noise prompt".
            gen = torch.Generator().manual_seed(seed)
            embed = torch.empty([1, perceptor.visual.output_dim]).normal_(
                generator=gen
            )
            pMs.append(Prompt(embed, weight).to(device))
            if self.usealtprompts:
                altpMs.append(Prompt(embed, weight).to(device))
        self.model, self.perceptor = model, perceptor
        self.make_cutouts = make_cutouts
        self.imageSize = (sideX, sideY)
        self.prompts = pMs
        self.altprompts = altpMs
        self.opt = opt
        self.normalize = normalize
        self.z, self.z_orig, self.z_min, self.z_max = z, z_orig, z_min, z_max
        self.setup_metadata(args2.seed)
        self.mse_weight = args.init_weight

    def synth(self, z):
        """Quantise latent ``z`` against the codebook and decode it to an image in [0, 1]."""
        z_q = vector_quantize(z.movedim(1, 3), self._codebook().weight).movedim(3, 1)
        return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1)

    def add_metadata(self, path, i):
        """Re-save the PNG at ``path`` with a (currently empty) ``PngInfo`` block.

        NOTE(review): ``step_meta`` is assembled but never attached to ``meta``,
        so no generation metadata is actually written.
        """
        with PngImageFile(path) as imfile:
            meta = PngInfo()
            step_meta = {"iterations": i}
            step_meta.update(self.metadata)
            imfile.save(path, pnginfo=meta)

    @torch.no_grad()
    def checkin(self, i, losses, x):
        """Decode the EMA-averaged latent and save it as the current preview.

        The image goes to ``args2.image_file``.  When ``args2.frame_dir`` is set,
        it is also saved there as the next ``FRAnnnnn.PNG`` frame.  ``i``,
        ``losses`` and ``x`` are accepted for interface compatibility only.
        """
        import os

        out = self.synth(self.z.average)
        sys.stdout.flush()
        sys.stdout.write("Saving progress ...\n")
        sys.stdout.flush()
        frame = TF.to_pil_image(out[0].cpu())
        frame.save(args2.image_file)
        if args2.frame_dir is not None:
            # Existing frames look like FRA00001.PNG (exactly 12 characters).
            frames = sorted(
                name
                for name in os.listdir(args2.frame_dir)
                if name.startswith("FRA") and name.endswith("PNG") and len(name) == 12
            )
            count_value = int(frames[-1][3:8]) + 1 if frames else 1
            frame.save(f"{args2.frame_dir}/FRA{count_value:05d}.PNG")
        sys.stdout.flush()
        sys.stdout.write("Progress saved\n")
        sys.stdout.flush()

    def unique_index(self, batchpath):
        """Return the first ``{batchpath}/{n}.png`` that does not exist (n < 10000), else None."""
        import os

        for n in range(10000):
            candidate = batchpath + "/" + str(n) + ".png"
            if not os.path.isfile(candidate):
                return candidate
        return None

    def ascend_txt(self, i):
        """Return the list of loss terms for the current latent.

        The first term, present only while MSE regularisation is active, is the
        MSE to ``z_orig`` scaled by ``mse_weight / 2``.  Then comes one term per
        prompt, followed by one per alternate prompt scored on the alternate
        cutouts.  ``i`` is unused.
        """
        out = self.synth(self.z.tensor)
        image_embeds = self.perceptor.encode_image(
            self.normalize(self.make_cutouts(out))
        ).float()
        result = []
        if self.args.init_weight and self.mse_weight > 0:
            result.append(
                F.mse_loss(self.z.tensor, self.z_orig) * self.mse_weight / 2
            )
        result.extend(prompt(image_embeds) for prompt in self.prompts)
        if self.usealtprompts:
            alt_embeds = self.perceptor.encode_image(
                self.normalize(self.alt_make_cutouts(out))
            ).float()
            result.extend(prompt(alt_embeds) for prompt in self.altprompts)
        return result

    def train(self, i, x):
        """Run optimisation step ``i``.

        A preview is saved every ``args2.update`` iterations.  While MSE
        regularisation is active, every ``mse_decay_rate`` iterations
        ``z_orig`` is re-anchored to the quantised average latent and
        ``mse_weight`` is decayed.  When the weight hits zero, the final cutout
        settings and a fresh AdamW optimiser take over.  After ``mse_end`` the
        learning rate goes from ``step_size`` towards ``final_step_size`` via
        ``lerp``.
        """
        args = self.args
        self.opt.zero_grad()
        mse_decay = args.mse_decay
        mse_decay_rate = args.mse_decay_rate
        lossAll = self.ascend_txt(i)
        sys.stdout.write(f"Iteration {i}\n")
        sys.stdout.flush()
        if i % args2.update == 0:
            self.checkin(i, lossAll, x)
        loss = sum(lossAll)
        loss.backward()
        self.opt.step()
        with torch.no_grad():
            if (
                self.mse_weight > 0
                and args.init_weight
                and i > 0
                and i % mse_decay_rate == 0
            ):
                self.z_orig = vector_quantize(
                    self.z.average.movedim(1, 3), self._codebook().weight
                ).movedim(3, 1)
                if self.mse_weight - mse_decay > 0:
                    self.mse_weight = self.mse_weight - mse_decay
                else:
                    # MSE phase finished: final cutouts, and restart the
                    # optimiser from the averaged latent.
                    self.mse_weight = 0
                    self.make_cutouts = flavordict[flavor](
                        self.perceptor.visual.input_resolution,
                        args.cutn,
                        cut_pow=args.cut_pow,
                        augs=args.augs,
                    )
                    if self.usealtprompts:
                        self.alt_make_cutouts = flavordict[flavor](
                            self.perceptor.visual.input_resolution,
                            args.cutn,
                            cut_pow=args.alt_cut_pow,
                            augs=args.altaugs,
                        )
                    self.z = EMATensor(self.z.average, args.ema_val)
                    self.new_step_size = args.step_size
                    self.opt = self._adamw(self.z, args.step_size)
        if (
            i > args.mse_end
            and args.step_size != args.final_step_size
            and args.max_iterations > 0
        ):
            progress = (i - args.mse_end) / args.max_iterations
            self.cur_step_size = lerp(args.step_size, args.final_step_size, progress)
            for group in self.opt.param_groups:
                group["lr"] = self.cur_step_size

    def run(self, x):
        """Run the full optimisation loop with Streamlit progress, then archive the result.

        Returns the number of completed iterations.  The final image
        (``args2.image_file``) is copied into ``DefaultPaths.output_path``
        under a name that never overwrites an earlier run.  KeyboardInterrupt
        and Streamlit's StopException end the run early without raising.
        """
        import os
        import shutil

        args = self.args
        j = 0
        image_location = None
        status.write("Starting the execution...")
        try:
            before_start_time = time.perf_counter()
            bar_container = status.container()
            iteration_counter = bar_container.empty()
            progress_bar = bar_container.progress(0)
            total_steps = int(args.max_iterations + args.mse_end) - 1
            for _ in range(total_steps):
                if j == 0:
                    iteration_counter.empty()
                    image_location = stoutput.empty()
                self.train(j, x)
                image_location.image(Image.open(args2.image_file))
                if j > 0 and j % args.mse_decay_rate == 0 and self.mse_weight > 0:
                    self.z = EMATensor(self.z.average, args.ema_val)
                    self.opt = self._adamw(self.z, args.mse_step_size)
                self.z.update()
                j += 1
                time_past_seconds = time.perf_counter() - before_start_time
                iterations_per_second = j / time_past_seconds
                time_left = (total_steps - j) / iterations_per_second
                percentage = round((j / (total_steps + 1)) * 100)
                iteration_counter.write(
                    f"{percentage}% {j}/{total_steps+1} [{time.strftime('%M:%S', time.gmtime(time_past_seconds))}<{time.strftime('%M:%S', time.gmtime(time_left))}, {round(iterations_per_second,2)} it/s]"
                )
                progress_bar.progress(int(percentage))
            os.makedirs(DefaultPaths.output_path, exist_ok=True)
            stem = f"{sanitize_filename(args2.prompt)} [{args2.sub_model}] {args2.seed}"
            save_filename = self._next_save_filename(DefaultPaths.output_path, stem)
            shutil.copyfile(args2.image_file, save_filename)
            status.write("Done!")
        except KeyboardInterrupt:
            pass
        except st.script_runner.StopException:
            if image_location is not None:
                image_location.image(args2.image_file)
            torch.cuda.empty_cache()
            status.write("Done!")
        # Guard: the placeholder only exists once the first iteration started.
        if image_location is not None:
            image_location.empty()
        return j
def add_noise(img):
    """Apply salt-and-pepper noise to a 2-D image in place and return it.

    First a random count in [300, 10000] of random pixels is set to 255 (salt).
    Then another random count of random pixels is set to 0 (pepper).  All
    randomness comes from the ``random`` module, drawn as
    count, (row, col) * count, count, (row, col) * count.
    """
    height, width = img.shape
    for fill_value in (255, 0):
        for _ in range(random.randint(300, 10000)):
            row = random.randint(0, height - 1)
            col = random.randint(0, width - 1)
            img[row][col] = fill_value
    return img
import io |
import base64 |
def image_to_data_url(img, ext):
    """Encode ``img`` as a base64 ``data:image/<ext>`` URL.

    ``ext`` is used both as the ``format=`` argument to ``img.save`` and as
    the MIME subtype in the URL.
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format=ext)
        payload = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{ext};base64,{payload}"
import torch |
import math |
# Module-level compute device used by the noise/latent helpers below:
# the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def rand_perlin_2d(
    shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
    """Return one octave of 2-D Perlin (gradient) noise of size ``shape``.

    ``res`` is the number of lattice cells along each axis.  ``shape`` must be
    divisible by ``res`` because gradients are tiled with ``repeat_interleave``.
    A random unit gradient is drawn for every lattice corner (one
    ``torch.rand`` call).  Each pixel's four corner dot products are blended
    with the ``fade`` curve, and the result is scaled by sqrt(2).
    """
    height, width = shape[0], shape[1]
    step = (res[0] / shape[0], res[1] / shape[1])
    cell = (shape[0] // res[0], shape[1] // res[1])

    axes = torch.meshgrid(
        torch.arange(0, res[0], step[0]), torch.arange(0, res[1], step[1])
    )
    # Fractional position of every pixel inside its lattice cell.
    frac = torch.stack(axes, dim=-1) % 1
    local = frac[:height, :width]

    angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def tiled(row_span, col_span):
        # Gradients of one corner, repeated over every pixel of each cell.
        corner = gradients[row_span[0] : row_span[1], col_span[0] : col_span[1]]
        return corner.repeat_interleave(cell[0], 0).repeat_interleave(cell[1], 1)

    def corner_dot(grad, shift):
        offset = torch.stack(
            (local[..., 0] + shift[0], local[..., 1] + shift[1]), dim=-1
        )
        return (offset * grad[:height, :width]).sum(dim=-1)

    n00 = corner_dot(tiled([0, -1], [0, -1]), [0, 0])
    n10 = corner_dot(tiled([1, None], [0, -1]), [-1, 0])
    n01 = corner_dot(tiled([0, -1], [1, None]), [0, -1])
    n11 = corner_dot(tiled([1, None], [1, None]), [-1, -1])

    weights = fade(local)
    blend_top = torch.lerp(n00, n10, weights[..., 0])
    blend_bottom = torch.lerp(n01, n11, weights[..., 0])
    return math.sqrt(2) * torch.lerp(blend_top, blend_bottom, weights[..., 1])
def rand_perlin_2d_octaves(desired_shape, octaves=1, persistence=0.5):
    """Return multi-octave (fractal) Perlin noise cropped to ``desired_shape``.

    Noise is generated on a canvas padded up to the next power of two per axis.
    Up to ``octaves`` layers are summed, capped by log2 of the padded sides.
    Each layer doubles the lattice frequency and multiplies the amplitude by
    ``persistence``.  The canvas is then cropped back to ``desired_shape``.
    """
    padded = 2 ** torch.ceil(torch.log2(torch.tensor(desired_shape)))
    padded = padded.type(torch.int)
    log2_rows = math.log(padded[0]) / math.log(2)
    log2_cols = math.log(padded[1]) / math.log(2)
    n_octaves = int(min(octaves, log2_rows, log2_cols))
    base_res = torch.floor(padded / 2**n_octaves).type(torch.int)

    total = torch.zeros(list(padded))
    frequency, amplitude = 1, 1
    for _ in range(n_octaves):
        layer = rand_perlin_2d(
            padded, (frequency * base_res[0], frequency * base_res[1])
        )
        total += amplitude * layer
        frequency *= 2
        amplitude *= persistence
    return total[: desired_shape[0], : desired_shape[1]]
def rand_perlin_rgb(desired_shape, amp=0.1, octaves=6):
    """Return a (1, 3, H, W) RGB image of independent Perlin noise per channel.

    Each channel is ``rand_perlin_2d_octaves`` noise (drawn R, G, B in that
    order).  It is scaled by ``amp``, mapped via ``(v + 1) / 2`` towards
    [0, 1], clipped, and moved to the module-level ``device``.
    """
    channels = [rand_perlin_2d_octaves(desired_shape, octaves) for _ in range(3)]
    rgb = (torch.stack(channels) * amp + 1) * 0.5
    return rgb.unsqueeze(0).clip(0, 1).to(device)
def pyramid_noise_gen(shape, octaves=5, decay=1.0):
    """Return (n, c, h, w) multi-scale Gaussian "pyramid" noise.

    Starting from a 1x1 zero map, each level upsamples the running sum
    bicubically to the next size (h // 2**i, w // 2**i), coarsest first.  It
    then adds fresh standard-normal noise divided by the number of levels and
    scaled by ``decay ** level``.  The level count is floor(log2(min(h, w))),
    optionally capped by a positive ``octaves``.
    """
    n, c, h, w = shape
    levels = int(min(math.log(h) / math.log(2), math.log(w) / math.log(2)))
    if octaves is not None and octaves > 0:
        levels = min(octaves, levels)
    noise = torch.zeros([n, c, 1, 1])
    for i in range(levels - 1, -1, -1):
        size = (h // 2**i, w // 2**i)
        noise = F.interpolate(noise, size, mode="bicubic", align_corners=False)
        noise += torch.randn([n, c, *size]) / levels * decay ** (levels - i - 1)
    return noise
def rand_z(model, toksX, toksY, device=None):
    """Return a random VQGAN latent made of uniformly chosen codebook vectors.

    Each of the ``toksY * toksX`` token positions receives one codebook entry
    picked uniformly at random.  The result has shape
    ``(1, e_dim, toksY, toksX)``.

    Both quantizer flavours used in this file are supported: the regular one
    (``embedding`` / ``n_e`` / ``e_dim``) and the Gumbel one
    (``embed`` / ``n_embed`` / ``embedding_dim``).

    Args:
        model: VQGAN model exposing ``model.quantize``.
        toksX: latent grid width in tokens.
        toksY: latent grid height in tokens.
        device: device on which to sample token indices.  Defaults to the
            codebook's own device, so the one-hot matmul never mixes devices.
    """
    quantize = model.quantize
    if hasattr(quantize, "embedding"):
        codebook, n_toks, e_dim = quantize.embedding, quantize.n_e, quantize.e_dim
    else:  # Gumbel quantizer
        codebook, n_toks, e_dim = (
            quantize.embed,
            quantize.n_embed,
            quantize.embedding_dim,
        )
    weight = codebook.weight
    if device is None:
        device = weight.device
    tokens = torch.randint(n_toks, [toksY * toksX], device=device)
    one_hot = F.one_hot(tokens, n_toks).float()
    z = one_hot @ weight
    return z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
def make_rand_init(
    mode,
    model,
    perlin_octaves,
    perlin_weight,
    pyramid_octaves,
    pyramid_decay,
    toksX,
    toksY,
    f,
):
    """Build a random initial latent for ``model`` according to ``mode``.

    Modes:
      * ``"VQGAN ZRand"``   - random codebook tokens (``rand_z``).
      * ``"Perlin Noise"``  - RGB Perlin image encoded with ``model.encode``.
      * ``"Pyramid Noise"`` - pyramid Gaussian noise squashed to [0, 1] and encoded.

    Image-based modes build a ``(toksY * f, toksX * f)`` image, map it from
    [0, 1] to [-1, 1] and return the first output of ``model.encode``.  Any
    other mode returns ``None``.
    """
    if mode == "VQGAN ZRand":
        return rand_z(model, toksX, toksY)
    if mode == "Perlin Noise":
        image = rand_perlin_rgb(
            (toksY * f, toksX * f), perlin_weight, perlin_octaves
        )
    elif mode == "Pyramid Noise":
        raw = pyramid_noise_gen(
            (1, 3, toksY * f, toksX * f), pyramid_octaves, pyramid_decay
        ).to(device)
        image = (raw * 0.5 + 0.5).clip(0, 1)
    else:
        return None
    z, *_ = model.encode(image * 2 - 1)
    return z
import os |
# Which VQGAN checkpoints the (original notebook) download script should fetch.
# Within this chunk these flags are only referenced by the download script
# below.  That script is wrapped in a bare triple-quoted string, so it never
# executes (its IPython `!curl` lines would not be valid Python anyway).
# Checkpoints are presumably provided by other means; verify before relying
# on these flags.
imagenet_1024 = False
imagenet_16384 = True
gumbel_8192 = False
sber_gumbel = False
coco = False
coco_1stage = False
faceshq = False
wikiart_1024 = False
wikiart_16384 = False
wikiart_7mil = False
sflckr = False
# Disabled download script, kept verbatim for reference (string literal, no effect).
"""
if imagenet_1024:
!curl -L -o vqgan_imagenet_f16_1024.yaml -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 1024
!curl -L -o vqgan_imagenet_f16_1024.ckpt -C - 'https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 1024
if imagenet_16384:
!curl -L -o vqgan_imagenet_f16_16384.yaml -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #ImageNet 16384
!curl -L -o vqgan_imagenet_f16_16384.ckpt -C - 'https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #ImageNet 16384
if gumbel_8192:
!curl -L -o gumbel_8192.yaml -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fconfigs%2Fmodel.yaml&dl=1' #Gumbel 8192
!curl -L -o gumbel_8192.ckpt -C - 'https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/files/?p=%2Fckpts%2Flast.ckpt&dl=1' #Gumbel 8192
#if imagenet_cin:
# !curl -L -o imagenet_cin.yaml -C - 'https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml' #ImageNet (cIN)
# !curl -L -o imagenet_cin.ckpt -C - 'https://app.koofr.net/content/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e/files/get/last.ckpt?path=%2F2021-04-03T19-39-50_cin_transformer%2Fcheckpoints%2Flast.ckpt' #ImageNet (cIN)
if sber_gumbel:
models_folder = './'
configs_folder = './'
os.makedirs(models_folder, exist_ok=True)
os.makedirs(configs_folder, exist_ok=True)
models_storage = [
{
'id': '1WP6Li2Po8xYcQPGMpmaxIlI1yPB5lF5m',
'name': 'sber_gumbel.ckpt',
},
]
configs_storage = [{
'id': '1M7RvSoiuKBwpF-98sScKng0lsZnwFebR',
'name': 'sber_gumbel.yaml',
}]
url_template = 'https://drive.google.com/uc?id={}'
for item in models_storage:
out_name = os.path.join(models_folder, item['name'])
url = url_template.format(item['id'])
gdown.download(url, out_name, quiet=True)
for item in configs_storage:
out_name = os.path.join(configs_folder, item['name'])
url = url_template.format(item['id'])
gdown.download(url, out_name, quiet=True)
if coco:
!curl -L -o coco.yaml -C - 'https://dl.nmkd.de/ai/clip/coco/coco.yaml' #COCO
!curl -L -o coco.ckpt -C - 'https://dl.nmkd.de/ai/clip/coco/coco.ckpt' #COCO
if faceshq:
!curl -L -o faceshq.yaml -C - 'https://drive.google.com/uc?export=download&id=1fHwGx_hnBtC8nsq7hesJvs-Klv-P0gzT' #FacesHQ
!curl -L -o faceshq.ckpt -C - 'https://app.koofr.net/content/links/a04deec9-0c59-4673-8b37-3d696fe63a5d/files/get/last.ckpt?path=%2F2020-11-13T21-41-45_faceshq_transformer%2Fcheckpoints%2Flast.ckpt' #FacesHQ
if wikiart_1024:
#I'm so sorry, I know this is exploiting, but there is no other way.
!curl -L -o wikiart_1024.yaml -C - 'https://github.com/Eleiber/VQGAN-Mirrors/releases/download/0.0.1/wikiart_1024.yaml' #WikiArt 1024
!curl -L -o wikiart_1024.ckpt -C - 'https://github.com/Eleiber/VQGAN-Mirrors/releases/download/0.0.1/wikiart_1024.ckpt' #WikiArt 1024
if wikiart_16384:
!curl -L -o wikiart_16384.yaml -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.yaml' #WikiArt 16384
!curl -L -o wikiart_16384.ckpt -C - 'http://eaidata.bmk.sh/data/Wikiart_16384/wikiart_f16_16384_8145600.ckpt' #WikiArt 16384
if sflckr:
!curl -L -o sflckr.yaml -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fconfigs%2F2020-11-09T13-31-51-project.yaml&dl=1' #S-FLCKR
!curl -L -o sflckr.ckpt -C - 'https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/files/?p=%2Fcheckpoints%2Flast.ckpt&dl=1' #S-FLCKR
if wikiart_7mil:
!curl -L -o wikiart_7mil.yaml -C - 'http://batbot.tv/ai/models/VQGAN/WikiArt_augmented_Steps_7mil_finetuned_1mil.yaml' #S-FLCKR
!curl -L -o wikiart_7mil.ckpt -C - 'http://batbot.tv/ai/models/VQGAN/WikiArt_augmented_Steps_7mil_finetuned_1mil.ckpt' #S-FLCKR
if coco_1stage:
!curl -L -o coco_1stage.yaml -C - 'http://batbot.tv/ai/models/VQGAN/coco_first_stage.yaml' #S-FLCKR
!curl -L -o coco_1stage.ckpt -C - 'http://batbot.tv/ai/models/VQGAN/coco_first_stage.ckpt' #S-FLCKR
#None of these work, if you know how to make them work, go ahead. - Philipuss
#if celebahq:
# !curl -L -o celebahq.yaml -C - 'https://app.koofr.net/content/links/6dddf083-40c8-470a-9360-a9dab2a94e96/files/get/2021-04-23T18-11-19-project.yaml?path=%2F2021-04-23T18-11-19_celebahq_transformer%2Fconfigs%2F2021-04-23T18-11-19-project.yaml&force' #celebahq
# !curl -L -o celebahq.ckpt -C - 'https://app.koofr.net/content/links/6dddf083-40c8-470a-9360-a9dab2a94e96/files/get/last.ckpt?path=%2F2021-04-23T18-11-19_celebahq_transformer%2Fcheckpoints%2Flast.ckpt' #celebahq
#if ade20k:
# !curl -L -o ade20k.yaml -C - 'https://app.koofr.net/content/links/0f65c2cd-7102-4550-a2bd-07fd383aac9e/files/get/2020-11-20T21-45-44-project.yaml?path=%2F2020-11-20T21-45-44_ade20k_transformer%2Fconfigs%2F2020-11-20T21-45-44-project.yaml&force' #celebahq
# !curl -L -o ade20k.ckpt -C - 'https://app.koofr.net/content/links/0f65c2cd-7102-4550-a2bd-07fd383aac9e/files/get/last.ckpt?path=%2F2020-11-20T21-45-44_ade20k_transformer%2Fcheckpoints%2Flast.ckpt' #celebahq
#if drin:
# !curl -L -o drin.yaml -C - 'https://app.koofr.net/content/links/028f1ba8-404d-42c4-a866-9a8a4eebb40c/files/get/2020-11-20T12-54-32-project.yaml?path=%2F2020-11-20T12-54-32_drin_transformer%2Fconfigs%2F2020-11-20T12-54-32-project.yaml&force' #celebahq
# !curl -L -o drin.ckpt -C - 'https://app.koofr.net/content/links/028f1ba8-404d-42c4-a866-9a8a4eebb40c/files/get/last.ckpt?path=%2F2020-11-20T12-54-32_drin_transformer%2Fcheckpoints%2Flast.ckpt' #celebahq
"""
import os |
import random |
import cv2 |
from PIL import Image |
from importlib import reload |
# Refresh PIL's TIFF tag tables; presumably works around a stale module state
# in the hosting environment — TODO confirm why this is needed.
reload(PIL.TiffTags)

# Basic run parameters from the UI namespace `args2`.
prompts = args2.prompt
width, height = args2.sizex, args2.sizey

sys.stdout.write(f"Loading {args2.vqgan_model} ...\n")
sys.stdout.flush()
status.write(f"Loading {args2.vqgan_model} ...\n")

model = args2.vqgan_model
# Gumbel-quantised checkpoints use a different codebook attribute layout.
is_gumbel = model in ("Gumbel 8192", "Sber Gumbel")

flavor = args2.flavor
template = args2.template

# Init strategy; anything other than "default noise" is handled just below.
init = "default noise"
init_image = "" if args2.seed_image is None else args2.seed_image
# Optional generated or downloaded init images.  `init` is hard-coded to
# "default noise" above, so none of these branches run unless it is changed.
# Every branch writes Init_Img/Image.png; that directory is created on
# demand, because nothing visible here guarantees it exists.
if init != "default noise":
    os.makedirs("Init_Img", exist_ok=True)
if init == "random image":
    # Blurred random photo of the requested size.
    url = f"https://picsum.photos/{width}/{height}?blur={random.randrange(5, 10)}"
    urllib.request.urlretrieve(url, "Init_Img/Image.png")
    init_image = "Init_Img/Image.png"
elif init == "random image clear":
    url = f"https://source.unsplash.com/random/{width}x{height}"
    urllib.request.urlretrieve(url, "Init_Img/Image.png")
    init_image = "Init_Img/Image.png"
elif init == "random image clear 2":
    url = f"https://loremflickr.com/{width}/{height}"
    urllib.request.urlretrieve(url, "Init_Img/Image.png")
    init_image = "Init_Img/Image.png"
elif init == "salt and pepper noise":
    urllib.request.urlretrieve(
        "https://i.stack.imgur.com/olrL8.png", "Init_Img/Image.png"
    )
    img = cv2.imread("Init_Img/Image.png", 0)
    cv2.imwrite("Init_Img/Image.png", add_noise(img))
    init_image = "Init_Img/Image.png"
elif init == "salt and pepper noise on init image":
    img = cv2.imread(init_image, 0)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read init image: {init_image!r}")
    cv2.imwrite("Init_Img/Image.png", add_noise(img))
    init_image = "Init_Img/Image.png"
elif init == "perlin noise":
    import noise
    import numpy as np

    # Rows are image y (height) and columns image x (width), which is the
    # layout Image.fromarray expects; the old (width, height) array came out
    # transposed for non-square sizes.
    shape = (height, width)
    scale = 100
    octaves = 6
    persistence = 0.5
    lacunarity = 2.0
    # Separate name so the run's global `seed` is not clobbered.
    noise_seed = np.random.randint(0, 100000)
    world = np.zeros(shape)
    for i in range(shape[0]):
        for j in range(shape[1]):
            world[i][j] = noise.pnoise2(
                i / scale,
                j / scale,
                octaves=octaves,
                persistence=persistence,
                lacunarity=lacunarity,
                repeatx=1024,
                repeaty=1024,
                base=noise_seed,
            )
    Image.fromarray(prep_world(world)).convert("L").save("Init_Img/Image.png")
    init_image = "Init_Img/Image.png"
elif init == "black and white":
    url = "https://www.random.org/bitmaps/?format=png&width=300&height=300&zoom=1"
    urllib.request.urlretrieve(url, "Init_Img/Image.png")
    init_image = "Init_Img/Image.png"
# Run-level settings: seed/iteration count from the UI, the rest fixed defaults.
seed = args2.seed
iterations = args2.iterations
transparent_png = False
multiple_prompt_batches = False
multiple_prompt_batches_iter = 300
folder_name = ""
save_to_drive = False

# Optional prompt-mangling experiment; "None" leaves `prompts` untouched.
prompt_experiment = "None"
if prompt_experiment == "Fever Dream":
    prompts = "<|startoftext|>" + prompts + "<|endoftext|>"
elif prompt_experiment == "Vivid Turmoil":
    prompts = "¬" + prompts.replace(" ", "¡") + "®"
elif prompt_experiment == "Mad Dad":
    prompts = prompts.replace(" ", "\\s+")
elif prompt_experiment == "Platinum":
    prompts = ("~!" + prompts + "!~").replace(" ", "</w>")
elif prompt_experiment == "Philipuss’s Basement":
    prompts = ("<|startoftext|>" + prompts).replace(
        " ", "<|endoftext|><|startoftext|>"
    )
elif prompt_experiment == "Lowercase":
    prompts = prompts.lower()
# Primary CLIP model plus up to seven optional extras chosen in the UI.
clip_model = args2.clip_model_1
clip_model2 = args2.clip_model_2
clip_model3 = args2.clip_model_3
clip_model4 = args2.clip_model_4
clip_model5 = args2.clip_model_5
clip_model6 = args2.clip_model_6
clip_model7 = args2.clip_model_7
clip_model8 = args2.clip_model_8

# The UI uses the literal string "None" for "no extra model"; normalise it to
# Python's None so the extras can be tested for truthiness downstream.
clip_model2 = None if clip_model2 == "None" else clip_model2
clip_model3 = None if clip_model3 == "None" else clip_model3
clip_model4 = None if clip_model4 == "None" else clip_model4
clip_model5 = None if clip_model5 == "None" else clip_model5
clip_model6 = None if clip_model6 == "None" else clip_model6
clip_model7 = None if clip_model7 == "None" else clip_model7
clip_model8 = None if clip_model8 == "None" else clip_model8
# Core optimisation hyper-parameters; the template branches further down
# override several of these.
target_images = ""
cutn, cut_pow = 130, 1
step_size = 0.1
start_step_size = 0
final_step_size = 0
# A non-positive start/final step size means "same as step_size".
start_step_size = start_step_size if start_step_size > 0 else step_size
final_step_size = final_step_size if final_step_size > 0 else step_size
ema_val = 0.98
gen_seed = -1

# Optionally resolve the init image relative to the Google Drive output folder.
init_image_in_drive = False
if init_image_in_drive and init_image:
    init_image = "/content/drive/MyDrive/VQGAN_Output/" + init_image

images_interval = args2.update
batch_size = 1

# MSE-regularised warm-up phase.
use_mse = args2.mse
mse_images_interval = images_interval
mse_init_weight, mse_decay_rate, mse_epoches = 0.2, 160, 10
mse_with_zeros = True
mse_step_size = 0.87
mse_cutn, mse_cut_pow = 42, 0.75
normal_flip_optim = True

# Alternate prompt set (unused while empty).
altprompts = ""
altprompt_mode = "flipped"
alt_cut_pow = alt_mse_cut_pow = 0

zoom = False
zoom_speed, zoom_frequency = 100, 20

# UI model label -> checkpoint/config file stem.
model_names = {
    "vqgan_imagenet_f16_16384": "vqgan_imagenet_f16_16384",
    "ImageNet 1024": "vqgan_imagenet_f16_1024",
    "Gumbel 8192": "gumbel_8192",
    "Sber Gumbel": "sber_gumbel",
    "imagenet_cin": "imagenet_cin",
    "WikiArt 1024": "wikiart_1024",
    "WikiArt 16384": "wikiart_16384",
    "COCO-Stuff": "coco",
    "FacesHQ": "faceshq",
    "S-FLCKR": "sflckr",
    "WikiArt 7mil": "wikiart_7mil",
    "COCO 1 Stage": "coco_1stage",
}
# ---- Templates -----------------------------------------------------------------
# Exactly one template (selected by `template`) may rewrite the prompt text
# and/or override the default settings defined above.
# NOTE(review): `flavor` and `tv_weight` are set here but are not among the
# fields of the run-argument namespace built later; presumably they are read
# as globals elsewhere — confirm.
if template == "Better - Fast":
    prompts = f"{prompts}. Detailed artwork. ArtStationHQ. unreal engine. 4K HD."
elif template == "Better - Slow":
    prompts = (
        f"{prompts}. Detailed artwork. Trending on ArtStation. unreal engine. "
        f"| Rendered in Maya. {prompts}. 4K HD."
    )
elif template == "Movie Poster":
    prompts = f"{prompts}. Movie poster. Rendered in unreal engine. ArtStationHQ."
    width = 400
    height = 592
elif template == "flag":
    prompts = (
        f"A photo of a flag of the country {prompts} | Flag of {prompts}. "
        "White background."
    )
    init_image = "templates/flag.png"
    transparent_png = True
elif template == "planet":
    import cv2

    # Flag 0 loads the template as a single-channel greyscale image.
    # cv2.imread returns None (it does not raise) for a missing or unreadable
    # file, so fail here with a clear message instead of inside add_noise.
    img = cv2.imread("templates/planet.png", 0)
    if img is None:
        raise FileNotFoundError(
            "Could not read template image 'templates/planet.png'"
        )
    cv2.imwrite("templates/final_planet.png", add_noise(img))
    prompts = (
        f"A photo of the planet {prompts}. Planet in the middle with black background. "
        f"| The planet of {prompts}. Photo of a planet. Black background. "
        "Trending on ArtStation. | Colorful."
    )
    init_image = "templates/final_planet.png"
elif template == "creature":
    prompts = (
        f"A photo of a creature with {prompts}. Animal in the middle with white background. "
        f"| The creature has {prompts}. Photo of a creature/animal. White background. "
        "Detailed image of a creature. | White background."
    )
    init_image = "templates/creature.png"
elif template == "Detailed":
    prompts = (
        f"{prompts}, by Puer Udger. Detailed artwork, trending on artstation. "
        "4K HD, realism."
    )
    flavor = "cumin"
elif template == "human":
    # Relative path, like every other template image, so it resolves against
    # the working directory instead of assuming a Colab /content layout.
    init_image = "templates/human.png"
elif template == "Realistic":
    cutn = 200
    step_size = 0.03
    cut_pow = 0.2
    flavor = "holywater"
elif template == "Consistent Creativity":
    flavor = "cumin"
    cut_pow = 0.01
    cutn = 136
    step_size = 0.08
    mse_step_size = 0.41
    mse_cut_pow = 0.3
    ema_val = 0.99
    normal_flip_optim = False
elif template == "Smooth":
    flavor = "wyvern"
    step_size = 0.10
    cutn = 120
    normal_flip_optim = False
    tv_weight = 10
elif template == "Subtle MSE":
    mse_init_weight = 0.07
    mse_decay_rate = 130
    mse_step_size = 0.2
    mse_cutn = 100
    mse_cut_pow = 0.6
elif template == "Balanced":
    cutn = 130
    cut_pow = 1
    step_size = 0.16
    final_step_size = 0  # 0 is meant as "use step_size"
    ema_val = 0.98
    mse_init_weight = 0.2
    mse_decay_rate = 130
    mse_with_zeros = True
    mse_step_size = 0.9
    mse_cutn = 50
    mse_cut_pow = 0.8
    normal_flip_optim = True
elif template == "Size: Square":
    width = 450
    height = 450
elif template == "Size: Landscape":
    width = 480
    height = 336
elif template == "Size: Poster":
    width = 336
    height = 480
elif template == "Negative Prompt":
    # Flip the sign of every ":value" suffix: "a:1" -> "a:-1" and
    # "a:-1" -> "a:--1" -> "a:1". Every colon is affected, so any second
    # ":value" after a weight is negated as well.
    prompts = prompts.replace(":", ":-")
    prompts = prompts.replace(":--", ":")
elif template == "Hyper Fast Results":
    step_size = 1
    ema_val = 0.3
    cutn = 30
elif template == "Better Quality":
    # Keep the user prompt at weight 1 and push away from common artefacts.
    prompts = f"{prompts}:1 | Watermark, blurry, cropped, confusing, cut, incoherent:-1"
# ---- Resolve derived settings ------------------------------------------------
# A template may reset start/final_step_size to 0 ("Balanced" sets
# final_step_size = 0) after the initial "<= 0 means use step_size" fallback
# has already run. Re-apply the fallback so a literal 0 is never passed on.
if start_step_size <= 0:
    start_step_size = step_size
if final_step_size <= 0:
    final_step_size = step_size

# MSE regularisation. When disabled, the initial weight is zero and there is
# no decay. When enabled, mse_decay = mse_init_weight / mse_epoches
# (presumably the per-epoch decrement applied by ModelHost; TODO confirm).
mse_decay = 0
if not use_mse:
    mse_init_weight = 0.0
else:
    mse_decay = mse_init_weight / mse_epoches

# Google Drive is only mounted when it is needed and not already mounted.
if not os.path.isdir("/content/drive") and (save_to_drive or init_image_in_drive):
    drive.mount("/content/drive")

# -1 requests a random seed.
if seed == -1:
    seed = None

# "None" or an empty string both mean "no init image".
if init_image == "None" or not init_image:
    init_image = None

# "|"-separated lists. Surrounding whitespace is stripped and empty entries
# (e.g. from "a | | b" or a trailing "|") are dropped.
if target_images == "None" or not target_images:
    target_images = []
else:
    target_images = [image.strip() for image in target_images.split("|") if image.strip()]
prompts = [phrase.strip() for phrase in prompts.split("|") if phrase.strip()]
altprompts = [phrase.strip() for phrase in altprompts.split("|") if phrase.strip()]

# A value of 0 means "reuse the corresponding main setting".
if mse_images_interval == 0:
    mse_images_interval = images_interval
if mse_step_size == 0:
    mse_step_size = step_size
if mse_cutn == 0:
    mse_cutn = cutn
if mse_cut_pow == 0:
    mse_cut_pow = cut_pow
if alt_cut_pow == 0:
    alt_cut_pow = cut_pow
if alt_mse_cut_pow == 0:
    alt_mse_cut_pow = mse_cut_pow
# ---- Augmentation pipelines --------------------------------------------------
def _make_augmentations(*leading):
    """Build a fresh kornia augmentation pipeline.

    The modules in *leading* run first. They are followed by the random
    sharpness, blur, affine, perspective, colour-jitter and greyscale steps
    that every pipeline in this script shares. New module instances are
    created on each call, so pipelines never share module objects.
    """
    return nn.Sequential(
        *leading,
        K.RandomSharpness(0.3, p=0.4),
        K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
        K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
        K.RandomPerspective(0.2, p=0.4),
        K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        K.RandomGrayscale(p=0.1),
    )


# Augmentations applied for the main prompts.
augs = _make_augmentations(K.RandomHorizontalFlip(p=0.5))

# Augmentations applied for the alternate prompts (altprompts).
if altprompt_mode == "normal":
    # NOTE(review): in kornia versions that accept it, return_transform=True
    # makes the module return an (image, transform) tuple, unlike the other
    # pipelines. The argument is deprecated/removed in newer kornia releases.
    # Confirm against the pinned kornia version and the consumer of altaugs.
    altaugs = _make_augmentations(
        K.RandomRotation(degrees=90.0, return_transform=True),
        K.RandomHorizontalFlip(p=0.5),
    )
elif altprompt_mode in ("flipped", "sideways"):
    # NOTE(review): "sideways" builds exactly the same pipeline as "flipped"
    # (a forced vertical flip). A 90-degree rotation may have been intended.
    altaugs = _make_augmentations(
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=1),
    )
else:
    # Previously altaugs was silently left undefined (or stale from an earlier
    # notebook run) for an unrecognised mode.
    raise ValueError(
        f"Unknown altprompt_mode {altprompt_mode!r}; "
        "expected 'normal', 'flipped' or 'sideways'."
    )
import shutil
from pathlib import Path


def _renumber_zoom_frames(directory, first_number, skip):
    """Rename this segment's numbered PNG frames to a continuous range.

    Regular files named ``<number>.png`` in *directory* that are not listed
    in *skip* are renamed, in ascending numeric order, to ``first_number.png``,
    ``first_number + 1.png`` and so on. All other files (scripts, template
    folders, progress.png, zoom.png, ...) are left untouched.

    NOTE(review): this assumes ModelHost writes one ``<number>.png`` per saved
    step into *directory*, with numbers below *first_number*. Confirm against
    ModelHost.

    Returns:
        The list of new file names.
    """
    frames = sorted(
        (
            name
            for name in os.listdir(directory)
            if name not in skip
            and Path(name).suffix == ".png"
            and Path(name).stem.isdecimal()
            and os.path.isfile(os.path.join(directory, name))
        ),
        key=lambda name: int(Path(name).stem),
    )
    new_names = []
    for number, name in enumerate(frames, start=first_number):
        new_name = f"{number}.png"
        os.rename(os.path.join(directory, name), os.path.join(directory, new_name))
        new_names.append(new_name)
    return new_names


# ---- Prompt batches ----------------------------------------------------------------
# With multiple_prompt_batches, "~" separates prompt batches that are rendered
# one after another, e.g. ["a", "b ~ c"] -> [["a", "b"], ["c"]]. Every batch is
# a list of phrases, the same shape `prompts` has when batching is off.
if multiple_prompt_batches:
    prompts_all = [
        [phrase.strip() for phrase in batch.split("|") if phrase.strip()]
        for batch in " | ".join(prompts).split("~")
    ]
else:
    prompts_all = [prompts]
mtpl_prmpts_btchs = len(prompts_all)
multiple_prompt_batches_iter = iterations
steps_path = "./"
zoom_path = "./"
path = "./"
# Zoom frames that were already renumbered and must not be renamed again.
renumbered_frames = set()

for pr, batch_prompts in enumerate(prompts_all):
    prompts = batch_prompts
    if zoom:
        # Split the iteration budget into segments of zoom_frequency
        # iterations each; always run at least one segment.
        mdf_iter = max(1, round(iterations / zoom_frequency))
    else:
        # No zoom: a single segment that runs the whole iteration budget.
        mdf_iter = 1
        zoom_frequency = iterations
    for segment in range(mdf_iter):
        if zoom and segment > 0:
            # Crop the previous segment's result and scale it back to full
            # size so the next segment starts from a zoomed-in image.
            # NOTE(review): the crop keeps the top-left corner, so the zoom
            # drifts towards that corner rather than the centre. Confirm intent.
            with Image.open("progress.png") as image:
                area = (0, 0, width - zoom_speed, height - zoom_speed)
                new_image = image.crop(area).resize((width, height))
            new_image.save("zoom.png")
            init_image = "zoom.png"
        args = argparse.Namespace(
            prompts=prompts,
            altprompts=altprompts,
            image_prompts=target_images,
            noise_prompt_seeds=[],
            noise_prompt_weights=[],
            size=[width, height],
            init_image=init_image,
            png=transparent_png,
            init_weight=mse_init_weight,
            vqgan_model=model_names[model],
            step_size=step_size,
            start_step_size=start_step_size,
            final_step_size=final_step_size,
            cutn=cutn,
            cut_pow=cut_pow,
            mse_cutn=mse_cutn,
            mse_cut_pow=mse_cut_pow,
            mse_step_size=mse_step_size,
            display_freq=images_interval,
            mse_display_freq=mse_images_interval,
            max_iterations=zoom_frequency,
            mse_end=0,
            seed=seed,
            folder_name=folder_name,
            save_to_drive=save_to_drive,
            mse_decay_rate=mse_decay_rate,
            mse_decay=mse_decay,
            mse_with_zeros=mse_with_zeros,
            normal_flip_optim=normal_flip_optim,
            ema_val=ema_val,
            augs=augs,
            altaugs=altaugs,
            alt_cut_pow=alt_cut_pow,
            alt_mse_cut_pow=alt_mse_cut_pow,
            is_gumbel=is_gumbel,
            clip_model=clip_model,
            clip_model2=clip_model2,
            clip_model3=clip_model3,
            clip_model4=clip_model4,
            clip_model5=clip_model5,
            clip_model6=clip_model6,
            clip_model7=clip_model7,
            clip_model8=clip_model8,
            gen_seed=gen_seed,
        )
        mh = ModelHost(args)
        for x in range(batch_size):
            mh.setup_model(x)
            last_iter = mh.run(x)
        if batch_size != 1:
            # Show every image of the batch. Image is used as PIL's module
            # here, matching the Image.open call above.
            for q in range(batch_size):
                display(Image.open(f"/content/{folder_name}/{q}.png"))
        if zoom:
            # Number this segment's frames so they follow those of earlier
            # segments and earlier prompt batches.
            segment_number = pr * mdf_iter + segment + 1
            renumbered_frames.update(
                _renumber_zoom_frames(
                    steps_path,
                    zoom_frequency * segment_number + 1,
                    renumbered_frames,
                )
            )

# Collect the renumbered zoom frames in zoom_path when it differs from
# steps_path. Both are "./" here, in which case nothing needs moving.
if zoom and os.path.abspath(steps_path) != os.path.abspath(zoom_path):
    for frame_name in sorted(renumbered_frames, key=lambda name: int(Path(name).stem)):
        shutil.move(os.path.join(steps_path, frame_name), zoom_path)