|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""#Tutorial |
|
|
|
**Diffusion settings (Defaults are heavily outdated)** |
|
--- |
|
|
|
This section is outdated as of v2 |
|
|
|
Setting | Description | Default |
|
--- | --- | --- |
|
**Your vision:** |
|
`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A |
|
`image_prompts` | URLs or local paths of images to use as prompts. Think of each image as a description of its contents. | N/A
|
**Image quality:** |
|
`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000 |
|
`tv_scale` | Controls the smoothness of the final output. | 150 |
|
`range_scale` | Controls how far out of range RGB values are allowed to be. | 150 |
|
`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0 |
|
`cutn` | Controls how many crops to take from the image. | 16 |
|
`cutn_batches` | Accumulates CLIP gradients over multiple batches of cuts, trading runtime for lower peak VRAM. | 2
|
**Init settings:** |
|
`init_image` | URL or local path of an image to start the diffusion from. | None
|
`init_scale` | Enhances the effect of the init image; a good value is 1000. | 0
|
`skip_steps` | Controls the starting point along the diffusion timesteps. | 0
|
`perlin_init` | Option to start with random Perlin noise instead of an init image. | False

`perlin_mode` | Perlin noise type: 'gray', 'color', or 'mixed'. | 'mixed'
|
**Advanced:** |
|
`skip_augs` | Controls whether to skip torchvision augmentations. | False

`randomize_class` | Controls whether the ImageNet class is randomly changed each iteration. | True

`clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image. | False

`clamp_grad` | Experimental: use adaptive gradient clamping in the cond_fn. | True

`seed` | Chooses a random seed and prints it at the end of the run so results can be reproduced. | random_seed

`fuzzy_prompt` | Controls whether to add multiple noisy copies of the prompt to the prompt losses. | False

`rand_mag` | Controls the magnitude of the random noise. | 0.1

`eta` | DDIM hyperparameter; controls how much noise is re-injected at each step (0 = deterministic). | 0.5
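
For orientation, here is a minimal, standalone sketch of how the quality scales above combine into the guidance loss. It mirrors `cond_fn` further down in this script, with a simplified one-direction TV term and a random stand-in image; the values are the table defaults, not recommendations.

```python
import torch

# Table defaults (illustrative only -- the real values come from args2 below).
clip_guidance_scale, tv_scale, range_scale, sat_scale = 1000, 150, 150, 0

x_in = torch.randn(1, 3, 64, 64)  # stand-in for the predicted denoised image
tv_losses = ((x_in[..., :, 1:] - x_in[..., :, :-1]) ** 2).mean([1, 2, 3])
range_losses = (x_in - x_in.clamp(-1, 1)).pow(2).mean([1, 2, 3])
sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()

loss = (
    tv_losses.sum() * tv_scale          # smoothness
    + range_losses.sum() * range_scale  # out-of-range RGB penalty
    + sat_losses.sum() * sat_scale      # saturation penalty
)
# The CLIP prompt-matching gradient is additionally scaled by
# clip_guidance_scale and accumulated over cutn_batches batches of cuts.
```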
|
|
|
|
|
|
**Model settings** |
|
--- |
|
|
|
Setting | Description | Default |
|
--- | --- | --- |
|
**Diffusion:** |
|
`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100 |
|
`diffusion_steps` | Total number of timesteps in the underlying diffusion schedule (leave at 1000). | 1000
|
**CLIP:**
|
`clip_models` | Models of CLIP to load. Typically the more the better, but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4
|
|
|
# 1. Set Up |
|
""" |
|
|
|
|
|
is_colab = False |
|
google_drive = False |
|
save_models_to_google_drive = False |
|
|
|
import sys |
|
|
|
sys.stdout.write("Imports ...\n") |
|
sys.stdout.flush() |
|
|
|
sys.path.append("./ResizeRight") |
|
sys.path.append("./MiDaS") |
|
sys.path.append("./CLIP") |
|
sys.path.append("./guided-diffusion") |
|
sys.path.append("./latent-diffusion") |
|
sys.path.append(".") |
|
sys.path.append("./taming-transformers") |
|
sys.path.append("./disco-diffusion") |
|
sys.path.append("./AdaBins") |
|
sys.path.append('./pytorch3d-lite') |
|
|
|
|
|
import os |
|
import streamlit as st |
|
from os import path |
|
from os.path import exists as path_exists |
|
|
import torch |
|
|
|
|
|
from dataclasses import dataclass |
|
from functools import partial |
|
import cv2 |
|
import pandas as pd |
|
import gc |
|
import io |
|
import math |
|
import timm |
|
from IPython import display |
|
import lpips |
|
from PIL import Image, ImageOps |
|
import requests |
|
from glob import glob |
|
import json |
|
from types import SimpleNamespace |
|
from torch import nn |
|
from torch.nn import functional as F |
|
import torchvision.transforms as T |
|
import torchvision.transforms.functional as TF |
|
import shutil |
|
from pathvalidate import sanitize_filename |
|
|
|
|
|
|
|
import clip |
|
from resize_right import resize |
|
|
|
|
|
from guided_diffusion.script_util import ( |
|
create_model_and_diffusion, |
|
model_and_diffusion_defaults, |
|
) |
|
from datetime import datetime |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import random |
|
from ipywidgets import Output |
|
import hashlib |
|
import ipywidgets as widgets |
|
|
|
|
|
|
from torchvision.datasets.utils import download_url |
|
|
from ldm.util import instantiate_from_config |
|
from ldm.modules.diffusionmodules.util import ( |
|
make_ddim_sampling_parameters, |
|
make_ddim_timesteps, |
|
noise_like, |
|
) |
|
|
|
|
|
from ldm.util import ismap |
|
from IPython.display import Image as ipyimg |
|
from numpy import asarray |
|
from einops import rearrange, repeat |
|
import torchvision
|
import time |
|
from omegaconf import OmegaConf |
|
from midas.dpt_depth import DPTDepthModel |
|
from midas.midas_net import MidasNet |
|
from midas.midas_net_custom import MidasNet_small |
|
from midas.transforms import Resize, NormalizeImage, PrepareForNet |
|
|
import py3d_tools as p3dT |
|
import disco_xform_utils as dxf |
|
import argparse |
|
|
|
sys.stdout.write("Parsing arguments ...\n") |
|
sys.stdout.flush() |
|
|
|
|
|
def run_model(args2, status, stoutput, DefaultPaths): |
|
if args2.seed is not None: |
|
sys.stdout.write(f"Setting seed to {args2.seed} ...\n") |
|
sys.stdout.flush() |
|
status.write(f"Setting seed to {args2.seed} ...\n") |
|
np.random.seed(args2.seed)

random.seed(args2.seed)
|
|
|
|
|
torch.manual_seed(args2.seed) |
|
torch.cuda.manual_seed(args2.seed) |
|
torch.cuda.manual_seed_all(args2.seed) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
|
|
|
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
print("Using device:", DEVICE) |
|
device = DEVICE |
|
|
|
|
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
|
|
|
PROJECT_DIR = os.path.abspath(os.getcwd()) |
|
|
|
|
|
USE_ADABINS = True |
|
if USE_ADABINS: |
|
sys.path.append("./AdaBins") |
|
from infer import InferenceHelper |
|
|
|
MAX_ADABINS_AREA = 500000 |
|
|
|
model_256_downloaded = False |
|
model_512_downloaded = False |
|
model_secondary_downloaded = False |
|
|
|
|
|
|
|
|
|
default_models = { |
|
"midas_v21_small": f"{DefaultPaths.model_path}/midas_v21_small-70d6b9c8.pt", |
|
"midas_v21": f"{DefaultPaths.model_path}/midas_v21-f6b98070.pt", |
|
"dpt_large": f"{DefaultPaths.model_path}/dpt_large-midas-2f21e586.pt", |
|
"dpt_hybrid": f"{DefaultPaths.model_path}/dpt_hybrid-midas-501f0c75.pt", |
|
"dpt_hybrid_nyu": f"{DefaultPaths.model_path}/dpt_hybrid_nyu-2ce69ec7.pt", |
|
} |
|
|
|
def init_midas_depth_model(midas_model_type="dpt_large", optimize=True): |
|
midas_model = None |
|
net_w = None |
|
net_h = None |
|
resize_mode = None |
|
normalization = None |
|
|
|
print(f"Initializing MiDaS '{midas_model_type}' depth model...") |
|
|
|
midas_model_path = default_models[midas_model_type] |
|
|
|
if midas_model_type == "dpt_large": |
|
midas_model = DPTDepthModel( |
|
path=midas_model_path, |
|
backbone="vitl16_384", |
|
non_negative=True, |
|
) |
|
net_w, net_h = 384, 384 |
|
resize_mode = "minimal" |
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
|
elif midas_model_type == "dpt_hybrid": |
|
midas_model = DPTDepthModel( |
|
path=midas_model_path, |
|
backbone="vitb_rn50_384", |
|
non_negative=True, |
|
) |
|
net_w, net_h = 384, 384 |
|
resize_mode = "minimal" |
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
|
elif midas_model_type == "dpt_hybrid_nyu": |
|
midas_model = DPTDepthModel( |
|
path=midas_model_path, |
|
backbone="vitb_rn50_384", |
|
non_negative=True, |
|
) |
|
net_w, net_h = 384, 384 |
|
resize_mode = "minimal" |
|
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) |
|
elif midas_model_type == "midas_v21": |
|
midas_model = MidasNet(midas_model_path, non_negative=True) |
|
net_w, net_h = 384, 384 |
|
resize_mode = "upper_bound" |
|
normalization = NormalizeImage( |
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
) |
|
elif midas_model_type == "midas_v21_small": |
|
midas_model = MidasNet_small( |
|
midas_model_path, |
|
features=64, |
|
backbone="efficientnet_lite3", |
|
exportable=True, |
|
non_negative=True, |
|
blocks={"expand": True}, |
|
) |
|
net_w, net_h = 256, 256 |
|
resize_mode = "upper_bound" |
|
normalization = NormalizeImage( |
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
) |
|
else: |
|
print(f"midas_model_type '{midas_model_type}' not implemented") |
|
assert False |
|
|
|
midas_transform = T.Compose( |
|
[ |
|
Resize( |
|
net_w, |
|
net_h, |
|
resize_target=None, |
|
keep_aspect_ratio=True, |
|
ensure_multiple_of=32, |
|
resize_method=resize_mode, |
|
image_interpolation_method=cv2.INTER_CUBIC, |
|
), |
|
normalization, |
|
PrepareForNet(), |
|
] |
|
) |
|
|
|
midas_model.eval() |
|
|
|
if optimize == True: |
|
if DEVICE == torch.device("cuda"): |
|
midas_model = midas_model.to(memory_format=torch.channels_last) |
|
midas_model = midas_model.half() |
|
|
|
midas_model.to(DEVICE) |
|
|
|
print(f"MiDaS '{midas_model_type}' depth model initialized.") |
|
return midas_model, midas_transform, net_w, net_h, resize_mode, normalization |
|
|
|
|
|
|
|
|
|
|
|
def interp(t): |
|
return 3 * t**2 - 2 * t**3 |
|
|
|
def perlin(width, height, scale=10, device=None): |
|
gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device) |
|
xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device) |
|
ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device) |
|
wx = 1 - interp(xs) |
|
wy = 1 - interp(ys) |
|
dots = 0 |
|
dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys) |
|
dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys) |
|
dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)) |
|
dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys)) |
|
return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale) |
|
|
|
def perlin_ms(octaves, width, height, grayscale, device=device): |
|
out_array = [0.5] if grayscale else [0.5, 0.5, 0.5] |
|
|
|
for i in range(1 if grayscale else 3): |
|
scale = 2 ** len(octaves) |
|
oct_width = width |
|
oct_height = height |
|
for oct in octaves: |
|
p = perlin(oct_width, oct_height, scale, device) |
|
out_array[i] += p * oct |
|
scale //= 2 |
|
oct_width *= 2 |
|
oct_height *= 2 |
|
return torch.cat(out_array) |
|
|
|
def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True): |
|
out = perlin_ms(octaves, width, height, grayscale) |
|
if grayscale: |
|
out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0)) |
|
out = TF.to_pil_image(out.clamp(0, 1)).convert("RGB") |
|
else: |
|
out = out.reshape(-1, 3, out.shape[0] // 3, out.shape[1]) |
|
out = TF.resize(size=(side_y, side_x), img=out) |
|
out = TF.to_pil_image(out.clamp(0, 1).squeeze()) |
|
|
|
out = ImageOps.autocontrast(out) |
|
return out |
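
# Note: interp() is the smoothstep polynomial 3t**2 - 2t**3; perlin() samples
# one octave of gradient noise from a random (width+1, height+1) gradient
# grid; perlin_ms() sums the given octaves at doubling resolution; and
# create_perlin_noise() resizes the result to (side_y, side_x) and
# autocontrasts it for use as an init image.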
|
|
|
def regen_perlin(): |
|
if perlin_mode == "color": |
|
init = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(12)], 1, 1, False |
|
) |
|
init2 = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(8)], 4, 4, False |
|
) |
|
elif perlin_mode == "gray": |
|
init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, True) |
|
init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True) |
|
else: |
|
init = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(12)], 1, 1, False |
|
) |
|
init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True) |
|
|
|
init = ( |
|
TF.to_tensor(init) |
|
.add(TF.to_tensor(init2)) |
|
.div(2) |
|
.to(device) |
|
.unsqueeze(0) |
|
.mul(2) |
|
.sub(1) |
|
) |
|
del init2 |
|
return init.expand(batch_size, -1, -1, -1) |
|
|
|
def fetch(url_or_path): |
|
if str(url_or_path).startswith("http://") or str(url_or_path).startswith( |
|
"https://" |
|
): |
|
r = requests.get(url_or_path) |
|
r.raise_for_status() |
|
fd = io.BytesIO() |
|
fd.write(r.content) |
|
fd.seek(0) |
|
return fd |
|
return open(url_or_path, "rb") |
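
# Usage note: fetch() returns a binary file-like object in both cases, so
# callers can write Image.open(fetch(path_or_url)) without caring whether the
# source is an http(s) URL or a local file.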
|
|
|
def read_image_workaround(path): |
|
"""OpenCV reads images as BGR, Pillow saves them as RGB. Work around |
|
this incompatibility to avoid colour inversions.""" |
|
im_tmp = cv2.imread(path) |
|
return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB) |
|
|
|
def parse_prompt(prompt): |
|
if prompt.startswith("http://") or prompt.startswith("https://"): |
|
vals = prompt.rsplit(":", 2) |
|
vals = [vals[0] + ":" + vals[1], *vals[2:]] |
|
else: |
|
vals = prompt.rsplit(":", 1) |
|
vals = vals + ["", "1"][len(vals) :] |
|
return vals[0], float(vals[1]) |
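
# Examples (weights default to 1; URLs keep their scheme colon intact):
#   parse_prompt("a castle on a hill:3")        -> ("a castle on a hill", 3.0)
#   parse_prompt("a castle on a hill")          -> ("a castle on a hill", 1.0)
#   parse_prompt("https://example.com/a.png:2") -> ("https://example.com/a.png", 2.0)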
|
|
|
def sinc(x): |
|
return torch.where( |
|
x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]) |
|
) |
|
|
|
def lanczos(x, a): |
|
cond = torch.logical_and(-a < x, x < a) |
|
out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) |
|
return out / out.sum() |
|
|
|
def ramp(ratio, width): |
|
n = math.ceil(width / ratio + 1) |
|
out = torch.empty([n]) |
|
cur = 0 |
|
for i in range(out.shape[0]): |
|
out[i] = cur |
|
cur += ratio |
|
return torch.cat([-out[1:].flip([0]), out])[1:-1] |
|
|
|
def resample(input, size, align_corners=True): |
|
n, c, h, w = input.shape |
|
dh, dw = size |
|
|
|
input = input.reshape([n * c, 1, h, w]) |
|
|
|
if dh < h: |
|
kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) |
|
pad_h = (kernel_h.shape[0] - 1) // 2 |
|
input = F.pad(input, (0, 0, pad_h, pad_h), "reflect") |
|
input = F.conv2d(input, kernel_h[None, None, :, None]) |
|
|
|
if dw < w: |
|
kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) |
|
pad_w = (kernel_w.shape[0] - 1) // 2 |
|
input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect") |
|
input = F.conv2d(input, kernel_w[None, None, None, :]) |
|
|
|
input = input.reshape([n, c, h, w]) |
|
return F.interpolate(input, size, mode="bicubic", align_corners=align_corners) |
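
# resample() applies a separable Lanczos (a=2) low-pass prefilter along any
# axis that is being downscaled, then finishes with bicubic interpolation;
# this reduces the aliasing a plain F.interpolate would introduce.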
|
|
|
class MakeCutouts(nn.Module): |
|
def __init__(self, cut_size, cutn, skip_augs=False): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.cutn = cutn |
|
self.skip_augs = skip_augs |
|
self.augs = T.Compose( |
|
[ |
|
T.RandomHorizontalFlip(p=0.5), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomAffine(degrees=15, translate=(0.1, 0.1)), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomPerspective(distortion_scale=0.4, p=0.7), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomGrayscale(p=0.15), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
|
|
] |
|
) |
|
|
|
def forward(self, input): |
|
input = T.Pad(input.shape[2] // 4, fill=0)(input) |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
|
|
cutouts = [] |
|
for ch in range(self.cutn): |
|
if ch > self.cutn - self.cutn // 4: |
|
cutout = input.clone() |
|
else: |
|
size = int( |
|
max_size |
|
* torch.zeros( |
|
1, |
|
) |
|
.normal_(mean=0.8, std=0.3) |
|
.clip(float(self.cut_size / max_size), 1.0) |
|
) |
|
offsetx = torch.randint(0, abs(sideX - size + 1), ()) |
|
offsety = torch.randint(0, abs(sideY - size + 1), ()) |
|
cutout = input[ |
|
:, :, offsety : offsety + size, offsetx : offsetx + size |
|
] |
|
|
|
if not self.skip_augs: |
|
cutout = self.augs(cutout) |
|
cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) |
|
del cutout |
|
|
|
cutouts = torch.cat(cutouts, dim=0) |
|
return cutouts |
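
# MakeCutouts draws `cutn` square crops per forward pass: roughly the last
# quarter are full-image clones, the rest random crops whose side length is
# ~N(0.8, 0.3) of the short edge; each is optionally augmented and then
# resampled to cut_size for CLIP.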
|
|
|
cutout_debug = False |
|
padargs = {} |
|
|
|
class MakeCutoutsDango(nn.Module): |
|
def __init__( |
|
self, cut_size, Overview=4, InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2 |
|
): |
|
super().__init__() |
|
self.cut_size = cut_size |
|
self.Overview = Overview |
|
self.InnerCrop = InnerCrop |
|
self.IC_Size_Pow = IC_Size_Pow |
|
self.IC_Grey_P = IC_Grey_P |
|
if args.animation_mode == "None": |
|
self.augs = T.Compose( |
|
[ |
|
T.RandomHorizontalFlip(p=0.5), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomAffine( |
|
degrees=10, |
|
translate=(0.05, 0.05), |
|
interpolation=T.InterpolationMode.BILINEAR, |
|
), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomGrayscale(p=0.1), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.ColorJitter( |
|
brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1 |
|
), |
|
] |
|
) |
|
elif args.animation_mode == "Video Input": |
|
self.augs = T.Compose( |
|
[ |
|
T.RandomHorizontalFlip(p=0.5), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomAffine(degrees=15, translate=(0.1, 0.1)), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomPerspective(distortion_scale=0.4, p=0.7), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomGrayscale(p=0.15), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
|
|
] |
|
) |
|
elif args.animation_mode == "2D" or args.animation_mode == "3D": |
|
self.augs = T.Compose( |
|
[ |
|
T.RandomHorizontalFlip(p=0.4), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomAffine( |
|
degrees=10, |
|
translate=(0.05, 0.05), |
|
interpolation=T.InterpolationMode.BILINEAR, |
|
), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.RandomGrayscale(p=0.1), |
|
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), |
|
T.ColorJitter( |
|
brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3 |
|
), |
|
] |
|
) |
|
|
|
def forward(self, input): |
|
cutouts = [] |
|
gray = T.Grayscale(3) |
|
sideY, sideX = input.shape[2:4] |
|
max_size = min(sideX, sideY) |
|
min_size = min(sideX, sideY, self.cut_size) |
|
l_size = max(sideX, sideY) |
|
output_shape = [1, 3, self.cut_size, self.cut_size] |
|
output_shape_2 = [1, 3, self.cut_size + 2, self.cut_size + 2] |
|
pad_input = F.pad( |
|
input, |
|
( |
|
(sideY - max_size) // 2, |
|
(sideY - max_size) // 2, |
|
(sideX - max_size) // 2, |
|
(sideX - max_size) // 2, |
|
), |
|
**padargs, |
|
) |
|
cutout = resize(pad_input, out_shape=output_shape) |
|
|
|
if self.Overview > 0: |
|
if self.Overview <= 4: |
|
if self.Overview >= 1: |
|
cutouts.append(cutout) |
|
if self.Overview >= 2: |
|
cutouts.append(gray(cutout)) |
|
if self.Overview >= 3: |
|
cutouts.append(TF.hflip(cutout)) |
|
if self.Overview == 4: |
|
cutouts.append(gray(TF.hflip(cutout))) |
|
else: |
|
cutout = resize(pad_input, out_shape=output_shape) |
|
for _ in range(self.Overview): |
|
cutouts.append(cutout) |
|
|
|
if cutout_debug: |
|
if is_colab: |
|
TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save( |
|
"/content/cutout_overview0.jpg", quality=99 |
|
) |
|
else: |
|
TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save( |
|
"cutout_overview0.jpg", quality=99 |
|
) |
|
|
|
if self.InnerCrop > 0: |
|
for i in range(self.InnerCrop): |
|
size = int( |
|
torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size) |
|
+ min_size |
|
) |
|
offsetx = torch.randint(0, sideX - size + 1, ()) |
|
offsety = torch.randint(0, sideY - size + 1, ()) |
|
cutout = input[ |
|
:, :, offsety : offsety + size, offsetx : offsetx + size |
|
] |
|
if i <= int(self.IC_Grey_P * self.InnerCrop): |
|
cutout = gray(cutout) |
|
cutout = resize(cutout, out_shape=output_shape) |
|
cutouts.append(cutout) |
|
if cutout_debug: |
|
if is_colab: |
|
TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save( |
|
"/content/cutout_InnerCrop.jpg", quality=99 |
|
) |
|
else: |
|
TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save( |
|
"cutout_InnerCrop.jpg", quality=99 |
|
) |
|
cutouts = torch.cat(cutouts) |
|
if skip_augs is not True: |
|
cutouts = self.augs(cutouts) |
|
return cutouts |
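
# MakeCutoutsDango combines up to four "Overview" views of the whole padded
# image (plain, grayscale, h-flipped, grayscale+flipped) with `InnerCrop`
# random crops whose size distribution is biased by IC_Size_Pow; the first
# fraction IC_Grey_P of the inner crops is grayscaled.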
|
|
|
def spherical_dist_loss(x, y): |
|
x = F.normalize(x, dim=-1) |
|
y = F.normalize(y, dim=-1) |
|
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) |
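
# Geometric note: for unit vectors at angle theta, |x - y| = 2*sin(theta/2),
# so this returns 2 * arcsin(|x - y| / 2)**2 = theta**2 / 2 -- a squared
# great-circle distance between the normalized CLIP embeddings.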
|
|
|
def tv_loss(input): |
|
"""L2 total variation loss, as in Mahendran et al.""" |
|
input = F.pad(input, (0, 1, 0, 1), "replicate") |
|
x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] |
|
y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] |
|
return (x_diff**2 + y_diff**2).mean([1, 2, 3]) |
|
|
|
def range_loss(input): |
|
return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) |
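
# range_loss penalizes pixels outside the model's [-1, 1] range; in cond_fn
# it is weighted by range_scale, alongside tv_loss (tv_scale) and the
# saturation term (sat_scale).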
|
|
|
stop_on_next_loop = False |
|
|
|
def nsToStr(d): |
|
h = 3.6e12 |
|
m = h / 60 |
|
s = m / 60 |
|
return ( |
|
str(int(d / h)) |
|
+ ":" |
|
+ str(int((d % h) / m)) |
|
+ ":" |
|
+ str(int((d % h) % m / s)) |
|
+ "." |
|
+ str(int((d % h) % m % s)) |
|
) |
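
# nsToStr formats a duration given in nanoseconds as "H:M:S.remainder"
# (h = 3.6e12 ns per hour), e.g. nsToStr(5.4e12) == "1:30:0.0".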
|
|
|
def do_run(): |
|
seed = args.seed |
|
|
|
|
|
if (args.animation_mode == "3D") and (args.midas_weight > 0.0): |
|
( |
|
midas_model, |
|
midas_transform, |
|
midas_net_w, |
|
midas_net_h, |
|
midas_resize_mode, |
|
midas_normalization, |
|
) = init_midas_depth_model(args.midas_depth_model) |
|
for frame_num in range(args.start_frame, args.max_frames): |
|
if stop_on_next_loop: |
|
break |
|
|
|
display.clear_output(wait=True) |
|
|
|
|
|
|
|
""" |
|
if args.animation_mode != "None": |
|
batchBar = tqdm(range(args.max_frames), desc ="Frames") |
|
batchBar.n = frame_num |
|
batchBar.refresh() |
|
""" |
|
|
|
|
|
if args.animation_mode != "Video Input": |
|
if args.init_image == "": |
|
init_image = None |
|
else: |
|
init_image = args.init_image |
|
init_scale = args.init_scale |
|
skip_steps = args.skip_steps |
|
|
|
if args.animation_mode == "2D": |
|
if args.key_frames: |
|
angle = args.angle_series[frame_num] |
|
zoom = args.zoom_series[frame_num] |
|
translation_x = args.translation_x_series[frame_num] |
|
translation_y = args.translation_y_series[frame_num] |
|
print( |
|
f"angle: {angle}", |
|
f"zoom: {zoom}", |
|
f"translation_x: {translation_x}", |
|
f"translation_y: {translation_y}", |
|
) |
|
|
|
if frame_num > 0: |
|
seed = seed + 1 |
|
if resume_run and frame_num == start_frame: |
|
img_0 = cv2.imread( |
|
batchFolder |
|
+ f"/{batch_name}({batchNum})_{start_frame-1:04}.png" |
|
) |
|
else: |
|
img_0 = cv2.imread("prevFrame.png") |
|
center = (1 * img_0.shape[1] // 2, 1 * img_0.shape[0] // 2) |
|
trans_mat = np.float32( |
|
[[1, 0, translation_x], [0, 1, translation_y]] |
|
) |
|
rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) |
|
trans_mat = np.vstack([trans_mat, [0, 0, 1]]) |
|
rot_mat = np.vstack([rot_mat, [0, 0, 1]]) |
|
transformation_matrix = np.matmul(rot_mat, trans_mat) |
|
img_0 = cv2.warpPerspective( |
|
img_0, |
|
transformation_matrix, |
|
(img_0.shape[1], img_0.shape[0]), |
|
borderMode=cv2.BORDER_WRAP, |
|
) |
|
|
|
cv2.imwrite("prevFrameScaled.png", img_0) |
|
init_image = "prevFrameScaled.png" |
|
init_scale = args.frames_scale |
|
skip_steps = args.calc_frames_skip_steps |
|
|
|
if args.animation_mode == "3D": |
|
if args.key_frames: |
|
angle = args.angle_series[frame_num] |
|
|
|
translation_x = args.translation_x_series[frame_num] |
|
translation_y = args.translation_y_series[frame_num] |
|
translation_z = args.translation_z_series[frame_num] |
|
rotation_3d_x = args.rotation_3d_x_series[frame_num] |
|
rotation_3d_y = args.rotation_3d_y_series[frame_num] |
|
rotation_3d_z = args.rotation_3d_z_series[frame_num] |
|
print( |
|
f"angle: {angle}", |
|
|
|
f"translation_x: {translation_x}", |
|
f"translation_y: {translation_y}", |
|
f"translation_z: {translation_z}", |
|
f"rotation_3d_x: {rotation_3d_x}", |
|
f"rotation_3d_y: {rotation_3d_y}", |
|
f"rotation_3d_z: {rotation_3d_z}", |
|
) |
|
|
|
sys.stdout.flush()
|
|
|
if frame_num > 0: |
|
seed = seed + 1 |
|
img_filepath = "prevFrame.png" |
|
trans_scale = 1.0 / 200.0 |
|
translate_xyz = [ |
|
-translation_x * trans_scale, |
|
translation_y * trans_scale, |
|
-translation_z * trans_scale, |
|
] |
|
rotate_xyz = [ |
|
math.radians(rotation_3d_x), |
|
math.radians(rotation_3d_y), |
|
math.radians(rotation_3d_z), |
|
] |
|
print("translation:", translate_xyz) |
|
print("rotation:", rotate_xyz) |
|
rot_mat = p3dT.euler_angles_to_matrix( |
|
torch.tensor(rotate_xyz, device=device), "XYZ" |
|
).unsqueeze(0) |
|
print("rot_mat: " + str(rot_mat)) |
|
next_step_pil = dxf.transform_image_3d( |
|
img_filepath, |
|
midas_model, |
|
midas_transform, |
|
DEVICE, |
|
rot_mat, |
|
translate_xyz, |
|
args.near_plane, |
|
args.far_plane, |
|
args.fov, |
|
padding_mode=args.padding_mode, |
|
sampling_mode=args.sampling_mode, |
|
midas_weight=args.midas_weight, |
|
) |
|
next_step_pil.save("prevFrameScaled.png") |
|
|
|
""" |
|
### Turbo mode - skip some diffusions to save time |
|
if turbo_mode == True and frame_num > 10 and frame_num % int(turbo_steps) != 0: |
|
#turbo_steps |
|
print('turbo mode is on this frame: skipping clip diffusion steps') |
|
#this is an even frame. copy warped prior frame w/ war |
|
#filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png' |
|
#next_step_pil.save(f'{batchFolder}/{filename}') #save it as this frame |
|
#next_step_pil.save(f'{img_filepath}') # save it also as prev_frame for next iteration |
|
filename = f'progress.png' |
|
next_step_pil.save(f'{filename}') #save it as this frame |
|
next_step_pil.save(f'{img_filepath}') # save it also as prev_frame for next iteration |
|
continue |
|
elif turbo_mode == True: |
|
print('turbo mode is OFF this frame') |
|
#else: no turbo |
|
""" |
|
|
|
init_image = "prevFrameScaled.png" |
|
init_scale = args.frames_scale |
|
skip_steps = args.calc_frames_skip_steps |
|
|
|
if args.animation_mode == "Video Input": |
|
seed = seed + 1 |
|
init_image = f"{videoFramesFolder}/{frame_num+1:04}.jpg" |
|
init_scale = args.frames_scale |
|
skip_steps = args.calc_frames_skip_steps |
|
|
|
loss_values = [] |
|
|
|
if seed is not None: |
|
np.random.seed(seed) |
|
random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed_all(seed) |
|
torch.backends.cudnn.deterministic = True |
|
|
|
target_embeds, weights = [], [] |
|
|
|
if args.prompts_series is not None and frame_num >= len( |
|
args.prompts_series |
|
): |
|
frame_prompt = args.prompts_series[-1] |
|
elif args.prompts_series is not None: |
|
frame_prompt = args.prompts_series[frame_num] |
|
else: |
|
frame_prompt = [] |
|
|
|
print(args.image_prompts_series) |
|
if args.image_prompts_series is not None and frame_num >= len( |
|
args.image_prompts_series |
|
): |
|
image_prompt = args.image_prompts_series[-1] |
|
elif args.image_prompts_series is not None: |
|
image_prompt = args.image_prompts_series[frame_num] |
|
else: |
|
image_prompt = [] |
|
|
|
print(f"Frame Prompt: {frame_prompt}") |
|
|
|
model_stats = [] |
|
for clip_model in clip_models: |
|
cutn = args2.cutn |
|
model_stat = { |
|
"clip_model": None, |
|
"target_embeds": [], |
|
"make_cutouts": None, |
|
"weights": [], |
|
} |
|
model_stat["clip_model"] = clip_model |
|
|
|
for prompt in frame_prompt: |
|
txt, weight = parse_prompt(prompt) |
|
txt = clip_model.encode_text( |
|
clip.tokenize(prompt).to(device) |
|
).float() |
|
|
|
if args.fuzzy_prompt: |
|
for i in range(25): |
|
model_stat["target_embeds"].append( |
|
( |
|
txt + torch.randn(txt.shape).cuda() * args.rand_mag |
|
).clamp(0, 1) |
|
) |
|
model_stat["weights"].append(weight) |
|
else: |
|
model_stat["target_embeds"].append(txt) |
|
model_stat["weights"].append(weight) |
|
|
|
if image_prompt: |
|
model_stat["make_cutouts"] = MakeCutouts( |
|
clip_model.visual.input_resolution, cutn, skip_augs=skip_augs |
|
) |
|
for prompt in image_prompt: |
|
path, weight = parse_prompt(prompt) |
|
img = Image.open(fetch(path)).convert("RGB") |
|
img = TF.resize( |
|
img, |
|
min(side_x, side_y, *img.size), |
|
T.InterpolationMode.LANCZOS, |
|
) |
|
batch = model_stat["make_cutouts"]( |
|
TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1) |
|
) |
|
embed = clip_model.encode_image(normalize(batch)).float() |
|
if fuzzy_prompt: |
|
for i in range(25): |
|
model_stat["target_embeds"].append( |
|
( |
|
embed |
|
+ torch.randn(embed.shape).cuda() * rand_mag |
|
).clamp(0, 1) |
|
) |
|
weights.extend([weight / cutn] * cutn) |
|
else: |
|
model_stat["target_embeds"].append(embed) |
|
model_stat["weights"].extend([weight / cutn] * cutn) |
|
|
|
model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"]) |
|
model_stat["weights"] = torch.tensor( |
|
model_stat["weights"], device=device |
|
) |
|
if model_stat["weights"].sum().abs() < 1e-3: |
|
raise RuntimeError("The weights must not sum to 0.") |
|
model_stat["weights"] /= model_stat["weights"].sum().abs() |
|
model_stats.append(model_stat) |
|
|
|
init = None |
|
if init_image is not None: |
|
init = Image.open(fetch(init_image)).convert("RGB") |
|
init = init.resize((args.side_x, args.side_y), Image.LANCZOS) |
|
init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1) |
|
|
|
if args.perlin_init: |
|
if args.perlin_mode == "color": |
|
init = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(12)], 1, 1, False |
|
) |
|
init2 = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(8)], 4, 4, False |
|
) |
|
elif args.perlin_mode == "gray": |
|
init = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(12)], 1, 1, True |
|
) |
|
init2 = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(8)], 4, 4, True |
|
) |
|
else: |
|
init = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(12)], 1, 1, False |
|
) |
|
init2 = create_perlin_noise( |
|
[1.5**-i * 0.5 for i in range(8)], 4, 4, True |
|
) |
|
|
|
init = ( |
|
TF.to_tensor(init) |
|
.add(TF.to_tensor(init2)) |
|
.div(2) |
|
.to(device) |
|
.unsqueeze(0) |
|
.mul(2) |
|
.sub(1) |
|
) |
|
del init2 |
|
|
|
cur_t = None |
|
|
|
def cond_fn(x, t, y=None): |
|
with torch.enable_grad(): |
|
x_is_NaN = False |
|
x = x.detach().requires_grad_() |
|
n = x.shape[0] |
|
if use_secondary_model is True: |
|
alpha = torch.tensor( |
|
diffusion.sqrt_alphas_cumprod[cur_t], |
|
device=device, |
|
dtype=torch.float32, |
|
) |
|
sigma = torch.tensor( |
|
diffusion.sqrt_one_minus_alphas_cumprod[cur_t], |
|
device=device, |
|
dtype=torch.float32, |
|
) |
|
cosine_t = alpha_sigma_to_t(alpha, sigma) |
|
out = secondary_model(x, cosine_t[None].repeat([n])).pred |
|
fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t] |
|
x_in = out * fac + x * (1 - fac) |
|
x_in_grad = torch.zeros_like(x_in) |
|
else: |
|
my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t |
|
out = diffusion.p_mean_variance( |
|
model, x, my_t, clip_denoised=False, model_kwargs={"y": y} |
|
) |
|
fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t] |
|
x_in = out["pred_xstart"] * fac + x * (1 - fac) |
|
x_in_grad = torch.zeros_like(x_in) |
|
for model_stat in model_stats: |
|
for i in range(int(args.cutn_batches)): |
|
t_int = ( |
|
int(t.item()) + 1 |
|
) |
|
|
|
try: |
|
input_resolution = model_stat[ |
|
"clip_model" |
|
].visual.input_resolution |
|
except AttributeError:
    # SLIP-style models have no visual.input_resolution attribute;
    # fall back to CLIP's default 224.
    input_resolution = 224
|
|
|
cuts = MakeCutoutsDango( |
|
input_resolution, |
|
Overview=args.cut_overview[1000 - t_int], |
|
InnerCrop=args.cut_innercut[1000 - t_int], |
|
IC_Size_Pow=args.cut_ic_pow, |
|
IC_Grey_P=args.cut_icgray_p[1000 - t_int], |
|
) |
|
clip_in = normalize(cuts(x_in.add(1).div(2))) |
|
image_embeds = ( |
|
model_stat["clip_model"].encode_image(clip_in).float() |
|
) |
|
dists = spherical_dist_loss( |
|
image_embeds.unsqueeze(1), |
|
model_stat["target_embeds"].unsqueeze(0), |
|
) |
|
dists = dists.view( |
|
[ |
|
args.cut_overview[1000 - t_int] |
|
+ args.cut_innercut[1000 - t_int], |
|
n, |
|
-1, |
|
] |
|
) |
|
losses = dists.mul(model_stat["weights"]).sum(2).mean(0) |
|
loss_values.append( |
|
losses.sum().item() |
|
) |
|
x_in_grad += ( |
|
torch.autograd.grad( |
|
losses.sum() * clip_guidance_scale, x_in |
|
)[0] |
|
/ cutn_batches |
|
) |
|
tv_losses = tv_loss(x_in) |
|
if use_secondary_model is True: |
|
range_losses = range_loss(out) |
|
else: |
|
range_losses = range_loss(out["pred_xstart"]) |
|
sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean() |
|
loss = ( |
|
tv_losses.sum() * tv_scale |
|
+ range_losses.sum() * range_scale |
|
+ sat_losses.sum() * sat_scale |
|
) |
|
if init is not None and args.init_scale: |
|
init_losses = lpips_model(x_in, init) |
|
loss = loss + init_losses.sum() * args.init_scale |
|
x_in_grad += torch.autograd.grad(loss, x_in)[0] |
|
if not torch.isnan(x_in_grad).any():
|
grad = -torch.autograd.grad(x_in, x, x_in_grad)[0] |
|
else: |
|
|
|
x_is_NaN = True |
|
grad = torch.zeros_like(x) |
|
if args.clamp_grad and not x_is_NaN:
|
magnitude = grad.square().mean().sqrt() |
|
return ( |
|
grad * magnitude.clamp(max=args.clamp_max) / magnitude |
|
) |
|
return grad |
|
|
|
if args.sampling_mode == "ddim": |
|
sample_fn = diffusion.ddim_sample_loop_progressive |
|
elif args.sampling_mode == "bicubic": |
|
sample_fn = diffusion.p_sample_loop_progressive |
|
elif args.sampling_mode == "plms": |
|
sample_fn = diffusion.plms_sample_loop_progressive |
|
|
|
|
|
|
|
|
|
|
|
image_display = Output() |
|
for i in range(args.n_batches): |
|
""" |
|
if args.animation_mode == 'None': |
|
display.clear_output(wait=True) |
|
batchBar = tqdm(range(args.n_batches), desc ="Batches") |
|
batchBar.n = i |
|
batchBar.refresh() |
|
print('') |
|
display.display(image_display) |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
""" |
|
cur_t = diffusion.num_timesteps - skip_steps - 1 |
|
total_steps = cur_t |
|
|
|
if perlin_init: |
|
init = regen_perlin() |
|
|
|
if args.sampling_mode == "ddim": |
|
samples = sample_fn( |
|
model, |
|
(batch_size, 3, args.side_y, args.side_x), |
|
clip_denoised=clip_denoised, |
|
model_kwargs={}, |
|
cond_fn=cond_fn, |
|
progress=True, |
|
skip_timesteps=skip_steps, |
|
init_image=init, |
|
randomize_class=randomize_class, |
|
eta=eta, |
|
) |
|
elif args.sampling_mode == "plms": |
|
samples = sample_fn( |
|
model, |
|
(batch_size, 3, args.side_y, args.side_x), |
|
clip_denoised=clip_denoised, |
|
model_kwargs={}, |
|
cond_fn=cond_fn, |
|
progress=True, |
|
skip_timesteps=skip_steps, |
|
init_image=init, |
|
randomize_class=randomize_class, |
|
order=2, |
|
) |
|
elif args.sampling_mode == "bicubic": |
|
samples = sample_fn( |
|
model, |
|
(batch_size, 3, args.side_y, args.side_x), |
|
clip_denoised=clip_denoised, |
|
model_kwargs={}, |
|
cond_fn=cond_fn, |
|
progress=True, |
|
skip_timesteps=skip_steps, |
|
init_image=init, |
|
randomize_class=randomize_class, |
|
) |
|
|
|
|
|
|
|
itt = 1 |
|
imgToSharpen = None |
|
status.write("Starting the execution...") |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|
|
|
|
|
|
|
|
|
|
|
|
total_iterables = samples |
|
try: |
|
j = 0 |
|
before_start_time = time.perf_counter() |
|
bar_container = status.container() |
|
iteration_counter = bar_container.empty() |
|
progress_bar = bar_container.progress(0) |
|
for sample in total_iterables: |
|
if itt == 1: |
|
iteration_counter.empty() |
|
imageLocation = stoutput.empty() |
|
sys.stdout.write(f"Iteration {itt}\n") |
|
sys.stdout.flush() |
|
cur_t -= 1 |
|
intermediateStep = False |
|
if args.steps_per_checkpoint is not None: |
|
if j % steps_per_checkpoint == 0 and j > 0: |
|
intermediateStep = True |
|
elif j in args.intermediate_saves: |
|
intermediateStep = True |
|
with image_display: |
|
""" |
|
if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True: |
|
for k, image in enumerate(sample['pred_xstart']): |
|
# tqdm.write(f'Batch {i}, step {j}, output {k}:') |
|
current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f') |
|
percent = math.ceil(j/total_steps*100) |
|
if args.n_batches > 0: |
|
#if intermediates are saved to the subfolder, don't append a step or percentage to the name |
|
if cur_t == -1 and args.intermediates_in_subfolder is True: |
|
save_num = f'{frame_num:04}' if animation_mode != "None" else i |
|
filename = f'{args.batch_name}({args.batchNum})_{save_num}.png' |
|
else: |
|
#If we're working with percentages, append it |
|
if args.steps_per_checkpoint is not None: |
|
filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png' |
|
# Or else, if we're working with specific steps, append those
|
else: |
|
filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png' |
|
image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1)) |
|
if j % args.display_rate == 0 or cur_t == -1: |
|
image.save('progress.png') |
|
#display.clear_output(wait=True) |
|
#display.display(display.Image('progress.png')) |
|
if args.steps_per_checkpoint is not None: |
|
if j % args.steps_per_checkpoint == 0 and j > 0: |
|
if args.intermediates_in_subfolder is True: |
|
image.save(f'{partialFolder}/{filename}') |
|
else: |
|
image.save(f'{batchFolder}/{filename}') |
|
else: |
|
if j in args.intermediate_saves: |
|
if args.intermediates_in_subfolder is True: |
|
image.save(f'{partialFolder}/{filename}') |
|
else: |
|
image.save(f'{batchFolder}/{filename}') |
|
if cur_t == -1: |
|
if frame_num == 0: |
|
save_settings() |
|
if args.animation_mode != "None": |
|
image.save('prevFrame.png') |
|
if args.sharpen_preset != "Off" and animation_mode == "None": |
|
imgToSharpen = image |
|
if args.keep_unsharp is True: |
|
image.save(f'{unsharpenFolder}/{filename}') |
|
else: |
|
image.save(f'{batchFolder}/{filename}') |
|
# if frame_num != args.max_frames-1: |
|
# display.clear_output() |
|
""" |
|
if itt % args2.update == 0 or cur_t == -1 or itt == 1: |
|
for k, image in enumerate(sample["pred_xstart"]): |
|
sys.stdout.flush() |
|
sys.stdout.write("Saving progress ...\n") |
|
sys.stdout.flush() |
|
|
|
image = TF.to_pil_image( |
|
image.add(1).div(2).clamp(0, 1) |
|
) |
|
|
|
if args.animation_mode != "None": |
|
image.save("prevFrame.png") |
|
|
|
image.save(args2.image_file) |
|
if (args2.frame_dir is not None) and ( |
|
args.animation_mode == "None" |
|
): |
|
|
|
|
file_list = [] |
|
for file in sorted(os.listdir(args2.frame_dir)): |
|
if file.startswith("FRA"): |
|
if file.endswith("PNG"): |
|
if len(file) == 12: |
|
file_list.append(file) |
|
if file_list: |
|
last_name = file_list[-1] |
|
count_value = int(last_name[3:8]) + 1 |
|
count_string = f"{count_value:05d}" |
|
else: |
|
count_string = "00001" |
|
save_name = ( |
|
args2.frame_dir |
|
+ "/FRA" |
|
+ count_string |
|
+ ".PNG" |
|
) |
|
image.save(save_name) |
|
|
|
|
|
|
|
|
|
if ( |
|
(args2.frame_dir is not None) |
|
and (args.animation_mode == "3D") |
|
and (itt == args2.iterations - skip_steps) |
|
): |
|
sys.stdout.flush() |
|
sys.stdout.write("Saving 3D frame...\n") |
|
sys.stdout.flush() |
|
|
|
|
file_list = [] |
|
for file in os.listdir(args2.frame_dir): |
|
if file.startswith("FRA"): |
|
if file.endswith("PNG"): |
|
if len(file) == 12: |
|
file_list.append(file) |
|
if file_list: |
|
last_name = file_list[-1] |
|
count_value = int(last_name[3:8]) + 1 |
|
count_string = f"{count_value:05d}" |
|
else: |
|
count_string = "00001" |
|
save_name = ( |
|
args2.frame_dir |
|
+ "/FRA" |
|
+ count_string |
|
+ ".PNG" |
|
) |
|
image.save(save_name) |
|
|
|
imageLocation.image(Image.open(args2.image_file)) |
|
sys.stdout.flush() |
|
sys.stdout.write("Progress saved\n") |
|
sys.stdout.flush() |
|
itt += 1 |
|
j += 1 |
|
time_past_seconds = time.perf_counter() - before_start_time |
|
iterations_per_second = j / time_past_seconds |
|
time_left = (total_steps - j) / iterations_per_second |
|
percentage = round((j / (total_steps + 1)) * 100) |
|
|
|
iteration_counter.write( |
|
f"{percentage}% {j}/{total_steps+1} [{time.strftime('%M:%S', time.gmtime(time_past_seconds))}<{time.strftime('%M:%S', time.gmtime(time_left))}, {round(iterations_per_second,2)} it/s]" |
|
) |
|
progress_bar.progress(int(percentage)) |
|
|
|
|
|
|
|
except KeyboardInterrupt: |
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
imageLocation.empty() |
|
with image_display: |
|
if args.sharpen_preset != "Off" and animation_mode == "None": |
|
print("Starting Diffusion Sharpening...") |
|
do_superres(imgToSharpen, f"{batchFolder}/{filename}") |
|
display.clear_output() |
|
|
|
|
|
|
if not path_exists(DefaultPaths.output_path): |
|
os.makedirs(DefaultPaths.output_path) |
|
save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}.png" |
|
print(save_filename) |
|
file_list = [] |
|
if path_exists(save_filename): |
|
for file in sorted(os.listdir(f"{DefaultPaths.output_path}/")): |
|
if file.startswith( |
|
f"{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}" |
|
): |
|
print(file) |
|
file_list.append(file) |
|
print(file_list) |
|
last_name = file_list[-1] |
|
print(last_name) |
|
if last_name[-15:-10] == "batch": |
|
count_value = int(last_name[-10:-4]) + 1 |
|
count_string = f"{count_value:05d}" |
|
save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}_batch {count_string}.png" |
|
else: |
|
save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}_batch 00001.png" |
|
shutil.copyfile( |
|
args2.image_file, |
|
save_filename, |
|
) |
|
imageLocation.empty() |
|
status.write("Done!") |
|
plt.plot(np.array(loss_values), "r") |
|
|
|
def save_settings(): |
|
setting_list = { |
|
"text_prompts": text_prompts, |
|
"image_prompts": image_prompts, |
|
"clip_guidance_scale": clip_guidance_scale, |
|
"tv_scale": tv_scale, |
|
"range_scale": range_scale, |
|
"sat_scale": sat_scale, |
|
|
|
"cutn_batches": cutn_batches, |
|
"max_frames": max_frames, |
|
"interp_spline": interp_spline, |
|
|
|
"init_image": init_image, |
|
"init_scale": init_scale, |
|
"skip_steps": skip_steps, |
|
|
|
"frames_scale": frames_scale, |
|
"frames_skip_steps": frames_skip_steps, |
|
"perlin_init": perlin_init, |
|
"perlin_mode": perlin_mode, |
|
"skip_augs": skip_augs, |
|
"randomize_class": randomize_class, |
|
"clip_denoised": clip_denoised, |
|
"clamp_grad": clamp_grad, |
|
"clamp_max": clamp_max, |
|
"seed": seed, |
|
"fuzzy_prompt": fuzzy_prompt, |
|
"rand_mag": rand_mag, |
|
"eta": eta, |
|
"width": width_height[0], |
|
"height": width_height[1], |
|
"diffusion_model": diffusion_model, |
|
"use_secondary_model": use_secondary_model, |
|
"steps": steps, |
|
"diffusion_steps": diffusion_steps, |
|
"ViTB32": ViTB32, |
|
"ViTB16": ViTB16, |
|
"ViTL14": ViTL14, |
|
"RN101": RN101, |
|
"RN50": RN50, |
|
"RN50x4": RN50x4, |
|
"RN50x16": RN50x16, |
|
"RN50x64": RN50x64, |
|
"cut_overview": str(cut_overview), |
|
"cut_innercut": str(cut_innercut), |
|
"cut_ic_pow": cut_ic_pow, |
|
"cut_icgray_p": str(cut_icgray_p), |
|
"key_frames": key_frames, |
|
"max_frames": max_frames, |
|
"angle": angle, |
|
"zoom": zoom, |
|
"translation_x": translation_x, |
|
"translation_y": translation_y, |
|
"translation_z": translation_z, |
|
"rotation_3d_x": rotation_3d_x, |
|
"rotation_3d_y": rotation_3d_y, |
|
"rotation_3d_z": rotation_3d_z, |
|
"midas_depth_model": midas_depth_model, |
|
"midas_weight": midas_weight, |
|
"near_plane": near_plane, |
|
"far_plane": far_plane, |
|
"fov": fov, |
|
"padding_mode": padding_mode, |
|
"sampling_mode": sampling_mode, |
|
"video_init_path": video_init_path, |
|
"extract_nth_frame": extract_nth_frame, |
|
"turbo_mode": turbo_mode, |
|
"turbo_steps": turbo_steps, |
|
} |
|
|
|
with open( |
|
f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+" |
|
) as f: |
|
json.dump(setting_list, f, ensure_ascii=False, indent=4) |
|
|
|
|
|
|
|
def append_dims(x, n): |
|
return x[(Ellipsis, *(None,) * (n - x.ndim))] |
|
|
|
def expand_to_planes(x, shape): |
|
return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]]) |
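
# append_dims(x, n) adds trailing singleton axes until x has n dims (a [B]
# tensor becomes [B, 1, 1, 1] for n=4); expand_to_planes() then tiles the
# result across the spatial dims, broadcasting a per-sample timestep
# embedding over every pixel plane.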
|
|
|
def alpha_sigma_to_t(alpha, sigma): |
|
return torch.atan2(sigma, alpha) * 2 / math.pi |
|
|
|
def t_to_alpha_sigma(t): |
|
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) |
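
# Cosine-schedule helpers: alpha = cos(t*pi/2) and sigma = sin(t*pi/2), so
# alpha**2 + sigma**2 == 1 and alpha_sigma_to_t() is the exact inverse of
# t_to_alpha_sigma() for t in [0, 1].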
|
|
|
@dataclass |
|
class DiffusionOutput: |
|
v: torch.Tensor |
|
pred: torch.Tensor |
|
eps: torch.Tensor |
|
|
|
class ConvBlock(nn.Sequential): |
|
def __init__(self, c_in, c_out): |
|
super().__init__( |
|
nn.Conv2d(c_in, c_out, 3, padding=1), |
|
nn.ReLU(inplace=True), |
|
) |
|
|
|
class SkipBlock(nn.Module): |
|
def __init__(self, main, skip=None): |
|
super().__init__() |
|
self.main = nn.Sequential(*main) |
|
self.skip = skip if skip else nn.Identity() |
|
|
|
def forward(self, input): |
|
return torch.cat([self.main(input), self.skip(input)], dim=1) |
|
|
|
class FourierFeatures(nn.Module): |
|
def __init__(self, in_features, out_features, std=1.0): |
|
super().__init__() |
|
assert out_features % 2 == 0 |
|
self.weight = nn.Parameter( |
|
torch.randn([out_features // 2, in_features]) * std |
|
) |
|
|
|
def forward(self, input): |
|
f = 2 * math.pi * input @ self.weight.T |
|
return torch.cat([f.cos(), f.sin()], dim=-1) |
|
|
|
class SecondaryDiffusionImageNet(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
c = 64 |
|
|
|
self.timestep_embed = FourierFeatures(1, 16) |
|
|
|
self.net = nn.Sequential( |
|
ConvBlock(3 + 16, c), |
|
ConvBlock(c, c), |
|
SkipBlock( |
|
[ |
|
nn.AvgPool2d(2), |
|
ConvBlock(c, c * 2), |
|
ConvBlock(c * 2, c * 2), |
|
SkipBlock( |
|
[ |
|
nn.AvgPool2d(2), |
|
ConvBlock(c * 2, c * 4), |
|
ConvBlock(c * 4, c * 4), |
|
SkipBlock( |
|
[ |
|
nn.AvgPool2d(2), |
|
ConvBlock(c * 4, c * 8), |
|
ConvBlock(c * 8, c * 4), |
|
nn.Upsample( |
|
scale_factor=2, |
|
mode="bilinear", |
|
align_corners=False, |
|
), |
|
] |
|
), |
|
ConvBlock(c * 8, c * 4), |
|
ConvBlock(c * 4, c * 2), |
|
nn.Upsample( |
|
scale_factor=2, mode="bilinear", align_corners=False |
|
), |
|
] |
|
), |
|
ConvBlock(c * 4, c * 2), |
|
ConvBlock(c * 2, c), |
|
nn.Upsample( |
|
scale_factor=2, mode="bilinear", align_corners=False |
|
), |
|
] |
|
), |
|
ConvBlock(c * 2, c), |
|
nn.Conv2d(c, 3, 3, padding=1), |
|
) |
|
|
|
def forward(self, input, t): |
|
timestep_embed = expand_to_planes( |
|
self.timestep_embed(t[:, None]), input.shape |
|
) |
|
v = self.net(torch.cat([input, timestep_embed], dim=1)) |
|
alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) |
|
pred = input * alphas - v * sigmas |
|
eps = input * sigmas + v * alphas |
|
return DiffusionOutput(v, pred, eps) |
|
|
|
class SecondaryDiffusionImageNet2(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
c = 64 |
|
cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8] |
|
|
|
self.timestep_embed = FourierFeatures(1, 16) |
|
self.down = nn.AvgPool2d(2) |
|
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) |
|
|
|
self.net = nn.Sequential( |
|
ConvBlock(3 + 16, cs[0]), |
|
ConvBlock(cs[0], cs[0]), |
|
SkipBlock( |
|
[ |
|
self.down, |
|
ConvBlock(cs[0], cs[1]), |
|
ConvBlock(cs[1], cs[1]), |
|
SkipBlock( |
|
[ |
|
self.down, |
|
ConvBlock(cs[1], cs[2]), |
|
ConvBlock(cs[2], cs[2]), |
|
SkipBlock( |
|
[ |
|
self.down, |
|
ConvBlock(cs[2], cs[3]), |
|
ConvBlock(cs[3], cs[3]), |
|
SkipBlock( |
|
[ |
|
self.down, |
|
ConvBlock(cs[3], cs[4]), |
|
ConvBlock(cs[4], cs[4]), |
|
SkipBlock( |
|
[ |
|
self.down, |
|
ConvBlock(cs[4], cs[5]), |
|
ConvBlock(cs[5], cs[5]), |
|
ConvBlock(cs[5], cs[5]), |
|
ConvBlock(cs[5], cs[4]), |
|
self.up, |
|
] |
|
), |
|
ConvBlock(cs[4] * 2, cs[4]), |
|
ConvBlock(cs[4], cs[3]), |
|
self.up, |
|
] |
|
), |
|
ConvBlock(cs[3] * 2, cs[3]), |
|
ConvBlock(cs[3], cs[2]), |
|
self.up, |
|
] |
|
), |
|
ConvBlock(cs[2] * 2, cs[2]), |
|
ConvBlock(cs[2], cs[1]), |
|
self.up, |
|
] |
|
), |
|
ConvBlock(cs[1] * 2, cs[1]), |
|
ConvBlock(cs[1], cs[0]), |
|
self.up, |
|
] |
|
), |
|
ConvBlock(cs[0] * 2, cs[0]), |
|
nn.Conv2d(cs[0], 3, 3, padding=1), |
|
) |
|
|
|
def forward(self, input, t): |
|
timestep_embed = expand_to_planes( |
|
self.timestep_embed(t[:, None]), input.shape |
|
) |
|
v = self.net(torch.cat([input, timestep_embed], dim=1)) |
|
alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) |
|
pred = input * alphas - v * sigmas |
|
eps = input * sigmas + v * alphas |
|
return DiffusionOutput(v, pred, eps) |
|
|
|
|
|
|
|
if args2.use256 == 0: |
|
sys.stdout.write("Loading 512x512_diffusion_uncond_finetune_008100 ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading 512x512_diffusion_uncond_finetune_008100 ...\n") |
|
diffusion_model = "512x512_diffusion_uncond_finetune_008100" |
|
else: |
|
sys.stdout.write("Loading 256x256_diffusion_uncond ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading 256x256_diffusion_uncond ...\n") |
|
diffusion_model = "256x256_diffusion_uncond" |
|
|
|
if args2.secondarymodel == 1: |
|
use_secondary_model = True |
|
else: |
|
use_secondary_model = False |
|
|
|
|
|
if args2.sampling_mode == "ddim" or args2.sampling_mode == "plms": |
|
timestep_respacing = "ddim" + str( |
|
args2.iterations |
|
) |
|
else: |
|
timestep_respacing = str( |
|
args2.iterations |
|
) |
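
# e.g. args2.iterations == 100 gives "ddim100" (100 steps respaced from the
# full 1000-step schedule) for the ddim/plms samplers, or plain "100"
# otherwise.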
|
|
|
diffusion_steps = 1000 |
|
|
|
use_checkpoint = True |
|
|
|
|
|
check_model_SHA = False |
|
|
|
model_256_SHA = "983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a" |
|
model_512_SHA = "9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648" |
|
model_secondary_SHA = ( |
|
"983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a" |
|
) |
|
|
|
model_256_link = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt" |
|
model_512_link = "https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt" |
|
model_secondary_link = ( |
|
"https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth" |
|
) |
|
|
|
model_256_path = f"{DefaultPaths.model_path}/256x256_diffusion_uncond.pt" |
|
model_512_path = ( |
|
f"{DefaultPaths.model_path}/512x512_diffusion_uncond_finetune_008100.pt" |
|
) |
|
model_secondary_path = f"{DefaultPaths.model_path}/secondary_model_imagenet_2.pth" |
|
|
|
model_256_downloaded = True |
|
model_512_downloaded = True |
|
model_secondary_downloaded = True |
|
|
|
model_config = model_and_diffusion_defaults() |
|
if diffusion_model == "512x512_diffusion_uncond_finetune_008100": |
|
model_config.update( |
|
{ |
|
"attention_resolutions": "32, 16, 8", |
|
"class_cond": False, |
|
"diffusion_steps": diffusion_steps, |
|
"rescale_timesteps": True, |
|
"timestep_respacing": timestep_respacing, |
|
"image_size": 512, |
|
"learn_sigma": True, |
|
"noise_schedule": "linear", |
|
"num_channels": 256, |
|
"num_head_channels": 64, |
|
"num_res_blocks": 2, |
|
"resblock_updown": True, |
|
"use_checkpoint": use_checkpoint, |
|
"use_fp16": True, |
|
"use_scale_shift_norm": True, |
|
} |
|
) |
|
elif diffusion_model == "256x256_diffusion_uncond": |
|
model_config.update( |
|
{ |
|
"attention_resolutions": "32, 16, 8", |
|
"class_cond": False, |
|
"diffusion_steps": diffusion_steps, |
|
"rescale_timesteps": True, |
|
"timestep_respacing": timestep_respacing, |
|
"image_size": 256, |
|
"learn_sigma": True, |
|
"noise_schedule": "linear", |
|
"num_channels": 256, |
|
"num_head_channels": 64, |
|
"num_res_blocks": 2, |
|
"resblock_updown": True, |
|
"use_checkpoint": use_checkpoint, |
|
"use_fp16": True, |
|
"use_scale_shift_norm": True, |
|
} |
|
) |
|
|
|
secondary_model_ver = 2 |
|
model_default = model_config["image_size"] |
|
|
|
if secondary_model_ver == 2: |
|
secondary_model = SecondaryDiffusionImageNet2() |
|
secondary_model.load_state_dict( |
|
torch.load( |
|
f"{DefaultPaths.model_path}/secondary_model_imagenet_2.pth", |
|
map_location="cpu", |
|
) |
|
) |
|
secondary_model.eval().requires_grad_(False).to(device) |
|
|
|
clip_models = [] |
|
if args2.usevit32 == 1: |
|
sys.stdout.write("Loading ViT-B/32 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading ViT-B/32 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("ViT-B/32", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usevit16 == 1: |
|
sys.stdout.write("Loading ViT-B/16 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading ViT-B/16 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("ViT-B/16", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usevit14 == 1: |
|
sys.stdout.write("Loading ViT-L/14 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading ViT-L/14 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("ViT-L/14", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usern50x4 == 1: |
|
sys.stdout.write("Loading RN50x4 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading RN50x4 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("RN50x4", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usern50x16 == 1: |
|
sys.stdout.write("Loading RN50x16 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading RN50x16 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("RN50x16", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usern50x64 == 1: |
|
sys.stdout.write("Loading RN50x64 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading RN50x64 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("RN50x64", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usern50 == 1: |
|
sys.stdout.write("Loading RN50 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading RN50 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("RN50", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.usern101 == 1: |
|
sys.stdout.write("Loading RN101 CLIP model ...\n") |
|
sys.stdout.flush() |
|
status.write("Loading RN101 CLIP model ...\n") |
|
clip_models.append( |
|
clip.load("RN101", jit=False)[0].eval().requires_grad_(False).to(device) |
|
) |
|
if args2.useslipbase == 1: |
|
sys.stdout.write("Loading SLIP Base model ...\n") |
|
sys.stdout.flush() |
|
# Assumption: SLIP_VITB16 comes from the SLIP repo's models.py, which is not
# imported at the top of this script; adjust the import to your layout.
from models import SLIP_VITB16

SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
|
|
|
import pathlib |
|
|
|
# The SLIP checkpoints pickle PosixPath objects; remapping lets torch.load
# unpickle them on Windows.
pathlib.PosixPath = pathlib.WindowsPath
|
sd = torch.load("slip_base_100ep.pt") |
|
real_sd = {} |
|
for k, v in sd["state_dict"].items(): |
|
real_sd[".".join(k.split(".")[1:])] = v |
|
del sd |
|
SLIPB16model.load_state_dict(real_sd) |
|
SLIPB16model.requires_grad_(False).eval().to(device) |
|
clip_models.append(SLIPB16model) |
|
if args2.usesliplarge == 1: |
|
sys.stdout.write("Loading SLIP Large model ...\n") |
|
sys.stdout.flush() |
|
# Assumption: SLIP_VITL16 also comes from the SLIP repo's models.py.
from models import SLIP_VITL16

SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)
|
|
|
import pathlib |
|
|
|
pathlib.PosixPath = pathlib.WindowsPath |
|
sd = torch.load("slip_large_100ep.pt") |
|
real_sd = {} |
|
for k, v in sd["state_dict"].items(): |
|
real_sd[".".join(k.split(".")[1:])] = v |
|
del sd |
|
SLIPL16model.load_state_dict(real_sd) |
|
SLIPL16model.requires_grad_(False).eval().to(device) |
|
clip_models.append(SLIPL16model) |
|
|
|
normalize = T.Normalize( |
|
mean=[0.48145466, 0.4578275, 0.40821073], |
|
std=[0.26862954, 0.26130258, 0.27577711], |
|
) |
|
status.write("Loading lpips model...\n") |
|
lpips_model = lpips.LPIPS(net="vgg").to(device) |
|
|
|
"""# 3. Settings""" |
|
|
|
|
|
|
|
|
|
|
|
batch_name = "TimeToDisco" |
|
steps = ( |
|
args2.iterations |
|
) |
|
width_height = [args2.sizex, args2.sizey] |
|
clip_guidance_scale = args2.guidancescale |
|
tv_scale = args2.tvscale |
|
range_scale = args2.rangescale |
|
sat_scale = args2.saturationscale |
|
cutn_batches = args2.cutnbatches |
|
|
|
if args2.useaugs == 1: |
|
skip_augs = False |
|
else: |
|
skip_augs = True |
|
|
|
|
|
if args2.seed_image is not None: |
|
init_image = ( |
|
args2.seed_image |
|
) |
|
skip_steps = ( |
|
args2.skipseedtimesteps |
|
) |
|
init_scale = ( |
|
args2.initscale |
|
) |
|
else: |
|
init_image = "" |
|
skip_steps = 0 |
|
init_scale = ( |
|
0 |
|
) |
|
|
|
if init_image == "": |
|
init_image = None |
|
|
|
side_x = args2.sizex |
|
side_y = args2.sizey |
|
|
|
|
|
|
|
diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps |
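
# The line above rounds 1000 down to a multiple of `steps` (e.g. steps=150
# gives diffusion_steps=900), presumably so the respaced schedule divides the
# full schedule evenly.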
|
model_config.update( |
|
{ |
|
"timestep_respacing": timestep_respacing, |
|
"diffusion_steps": diffusion_steps, |
|
} |
|
) |
|
|
|
|
|
batchFolder = f"./" |
|
|
|
|
|
|
|
|
|
|
|
"""###Animation Settings""" |
|
|
|
|
|
animation_mode = ( |
|
args2.animation_mode |
|
) |
|
|
|
|
|
|
|
|
|
|
|
video_init_path = "training.mp4" |
|
extract_nth_frame = 2 |
|
|
|
|
|
|
|
|
|
if animation_mode == "Video Input": |
|
videoFramesFolder = "./videoFrames" |
|
|
|
|
|
sys.stdout.write(f"Exporting Video Frames (1 every {extract_nth_frame})...\n") |
|
sys.stdout.flush() |
|
|
|
""" |
|
try: |
|
!rm {videoFramesFolder}/*.jpg |
|
except: |
|
print('') |
|
""" |
|
|
|
|
|
vf = f'"select=not(mod(n\,{extract_nth_frame}))"' |
|
|
|
|
|
os.system( |
|
f"ffmpeg.exe -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
key_frames = True |
|
max_frames = args2.max_frames |
|
|
|
|
|
|
|
|
|
if animation_mode == "Video Input": |
|
max_frames = len(glob(f"{videoFramesFolder}/*.jpg")) |
|
|
|
|
|
|
|
|
|
interp_spline = "Linear" |
|
angle = args2.angle |
|
zoom = args2.zoom |
|
translation_x = args2.translation_x |
|
translation_y = args2.translation_y |
|
translation_z = args2.translation_z |
|
rotation_3d_x = args2.rotation_3d_x |
|
rotation_3d_y = args2.rotation_3d_y |
|
rotation_3d_z = args2.rotation_3d_z |
|
midas_depth_model = "dpt_large" |
|
midas_weight = args2.midas_weight |
|
near_plane = args2.near_plane |
|
far_plane = args2.far_plane |
|
fov = args2.fov |
|
padding_mode = "border" |
|
sampling_mode = args2.sampling_mode |
|
|
|
|
|
frames_scale = args2.frames_scale |
|
|
|
frames_skip_steps = args2.frames_skip_steps
|
|
|
turbo_mode = args2.turbo_mode == 1
turbo_steps = args2.turbo_steps
|
|
|
|
|
def parse_key_frames(string, prompt_parser=None): |
|
"""Given a string representing frame numbers paired with parameter values at that frame, |
|
return a dictionary with the frame numbers as keys and the parameter values as the values. |
|
|
|
Parameters |
|
---------- |
|
string: string |
|
Frame numbers paired with parameter values at that frame number, in the format |
|
'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...' |
|
prompt_parser: function or None, optional |
|
If provided, prompt_parser will be applied to each string of parameter values. |
|
|
|
Returns |
|
------- |
|
dict |
|
Frame numbers as keys, parameter values at that frame number as values |
|
|
|
Raises |
|
------ |
|
RuntimeError |
|
If the input string does not match the expected format. |
|
|
|
Examples |
|
-------- |
|
>>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)") |
|
{10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'} |
|
|
|
>>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())) |
|
{10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'} |
|
""" |
|
import re |
|
|
|
pattern = r"((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])" |
|
frames = dict() |
|
for match_object in re.finditer(pattern, string): |
|
frame = int(match_object.groupdict()["frame"]) |
|
param = match_object.groupdict()["param"] |
|
if prompt_parser: |
|
frames[frame] = prompt_parser(param) |
|
else: |
|
frames[frame] = param |
|
|
|
if frames == {} and len(string) != 0: |
|
raise RuntimeError("Key Frame string not correctly formatted") |
|
return frames |
|
|
|
def get_inbetweens(key_frames, integer=False): |
|
"""Given a dict with frame numbers as keys and a parameter value as values, |
|
return a pandas Series containing the value of the parameter at every frame from 0 to max_frames. |
|
Any values not provided in the input dict are calculated by linear interpolation between |
|
the values of the previous and next provided frames. If there is no previous provided frame, then |
|
the value is equal to the value of the next provided frame, or if there is no next provided frame, |
|
then the value is equal to the value of the previous provided frame. If no frames are provided, |
|
all frame values are NaN. |
|
|
|
Parameters |
|
---------- |
|
key_frames: dict |
|
A dict with integer frame numbers as keys and numerical values of a particular parameter as values. |
|
integer: Bool, optional |
|
If True, the values of the output series are converted to integers. |
|
Otherwise, the values are floats. |
|
|
|
Returns |
|
------- |
|
pd.Series |
|
A Series with length max_frames representing the parameter values for each frame. |
|
|
|
Examples |
|
-------- |
|
>>> max_frames = 5 |
|
>>> get_inbetweens({1: 5, 3: 6}) |
|
0 5.0 |
|
1 5.0 |
|
2 5.5 |
|
3 6.0 |
|
4 6.0 |
|
dtype: float64 |
|
|
|
>>> get_inbetweens({1: 5, 3: 6}, integer=True) |
|
0 5 |
|
1 5 |
|
2 5 |
|
3 6 |
|
4 6 |
|
dtype: int64 |
|
""" |
|
key_frame_series = pd.Series([np.nan for a in range(max_frames)]) |
|
|
|
for i, value in key_frames.items(): |
|
key_frame_series[i] = value |
|
key_frame_series = key_frame_series.astype(float) |
|
|
|
interp_method = interp_spline |
|
|
|
if interp_method == "Cubic" and len(key_frames.items()) <= 3: |
|
interp_method = "Quadratic" |
|
|
|
if interp_method == "Quadratic" and len(key_frames.items()) <= 2: |
|
interp_method = "Linear" |
|
|
|
key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] |
|
key_frame_series[max_frames - 1] = key_frame_series[ |
|
key_frame_series.last_valid_index() |
|
] |
|
|
|
key_frame_series = key_frame_series.interpolate( |
|
method=interp_method.lower(), limit_direction="both" |
|
) |
|
if integer: |
|
return key_frame_series.astype(int) |
|
return key_frame_series |
|
|
|
def split_prompts(prompts): |
|
prompt_series = pd.Series([np.nan for a in range(max_frames)]) |
|
for i, prompt in prompts.items(): |
|
prompt_series[i] = prompt |
|
|
|
prompt_series = prompt_series.ffill().bfill() |
|
return prompt_series |
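# Illustrative example: split_prompts({0: "a forest", 100: "a city"}) returns
# a Series holding "a forest" for frames 0-99 and "a city" from frame 100 on.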
|
|
|
if key_frames:

    def _keyframe_series(name, value):
        """Parse a key-frame string into an interpolated per-frame series.

        Falls back to treating the whole value as the frame-0 setting when
        the string is not in key-frame format, warning the user.
        """
        try:
            return value, get_inbetweens(parse_key_frames(value))
        except RuntimeError:
            print(
                "WARNING: You have selected to use key frames, but you have not "
                f"formatted `{name}` correctly for key frames.\n"
                f'Attempting to interpret `{name}` as "0: ({value})"\n'
                "Please read the instructions to find out how to use key frames "
                "correctly.\n"
            )
            value = f"0: ({value})"
            return value, get_inbetweens(parse_key_frames(value))

    angle, angle_series = _keyframe_series("angle", angle)
    zoom, zoom_series = _keyframe_series("zoom", zoom)
    translation_x, translation_x_series = _keyframe_series("translation_x", translation_x)
    translation_y, translation_y_series = _keyframe_series("translation_y", translation_y)
    translation_z, translation_z_series = _keyframe_series("translation_z", translation_z)
    rotation_3d_x, rotation_3d_x_series = _keyframe_series("rotation_3d_x", rotation_3d_x)
    rotation_3d_y, rotation_3d_y_series = _keyframe_series("rotation_3d_y", rotation_3d_y)
    rotation_3d_z, rotation_3d_z_series = _keyframe_series("rotation_3d_z", rotation_3d_z)
|
|
|
else: |
|
angle = float(angle) |
|
zoom = float(zoom) |
|
translation_x = float(translation_x) |
|
translation_y = float(translation_y) |
|
translation_z = float(translation_z) |
|
rotation_3d_x = float(rotation_3d_x) |
|
rotation_3d_y = float(rotation_3d_y) |
|
rotation_3d_z = float(rotation_3d_z) |
|
|
|
"""### Extra Settings |
|
Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling |
|
""" |
|
|
|
|
|
|
|
# Number of evenly spaced intermediate saves (0 disables them); a list of
# specific step numbers is also accepted.
intermediate_saves = 0
|
intermediates_in_subfolder = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
if not isinstance(intermediate_saves, list):
    if intermediate_saves:
        # Spread the requested number of intermediate saves evenly across the run.
        steps_per_checkpoint = (steps - skip_steps - 1) // (intermediate_saves + 1)
        steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1
        print(f"Will save every {steps_per_checkpoint} steps")
    else:
        steps_per_checkpoint = steps + 10
else:
    steps_per_checkpoint = None

if intermediate_saves and intermediates_in_subfolder:
    partialFolder = f"{batchFolder}/partials"
    createPath(partialFolder)
|
|
|
|
|
|
|
|
|
|
|
sharpen_preset = "Off" |
|
keep_unsharp = True |
|
|
|
if sharpen_preset != "Off" and keep_unsharp is True: |
|
unsharpenFolder = f"{batchFolder}/unsharpened" |
|
createPath(unsharpenFolder) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
perlin_init = args2.perlin_init == 1
perlin_mode = args2.perlin_mode
|
|
|
set_seed = "random_seed" |
|
eta = args2.eta |
|
clamp_grad = True |
|
clamp_max = args2.clampmax |
|
|
|
|
|
randomize_class = True |
|
clip_denoised = args2.denoised == 1
|
fuzzy_prompt = False |
|
rand_mag = 0.05 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cut_overview = "[12]*400+[4]*600" |
|
cut_innercut = "[4]*400+[12]*600" |
|
cut_ic_pow = 1 |
|
cut_icgray_p = "[0.2]*400+[0]*600" |
|
|
|
"""###Prompts |
|
`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one. |
|
""" |
|
|
|
""" |
|
text_prompts = { |
|
0: ["A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.", "yellow color scheme"], |
|
100: ["This set of prompts start at frame 100","This prompt has weight five:5"], |
|
} |
|
""" |
|
|
|
text_prompts = {0: [phrase.strip() for phrase in args2.prompt.split("|")]} |
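# e.g. a prompt of "a lighthouse | yellow color scheme" becomes
# {0: ["a lighthouse", "yellow color scheme"]}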
|
|
|
image_prompts = {}
|
|
|
"""# 4. Diffuse!""" |
|
|
|
|
|
|
|
display_rate = args2.update |
|
n_batches = 1 |
|
|
|
batch_size = 1 |
|
|
|
def move_files(start_num, end_num, old_folder, new_folder): |
|
for i in range(start_num, end_num): |
|
old_file = old_folder + f"/{batch_name}({batchNum})_{i:04}.png" |
|
new_file = new_folder + f"/{batch_name}({batchNum})_{i:04}.png" |
|
os.rename(old_file, new_file) |
|
|
|
|
|
|
|
resume_run = False |
|
run_to_resume = "latest" |
|
resume_from_frame = "latest" |
|
retain_overwritten_frames = False |
|
if retain_overwritten_frames:
|
retainFolder = f"{batchFolder}/retained" |
|
createPath(retainFolder) |
|
|
|
skip_step_ratio = int(frames_skip_steps.rstrip("%")) / 100 |
|
calc_frames_skip_steps = math.floor(steps * skip_step_ratio) |
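# e.g. frames_skip_steps = "60%" with steps = 250 skips the first 150
# timesteps of each video frame, diffusing only the remaining 100.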
|
|
|
if steps <= calc_frames_skip_steps: |
|
sys.exit("ERROR: You can't skip more steps than your total steps") |
|
|
|
""" |
|
if resume_run: |
|
if run_to_resume == 'latest': |
|
try: |
|
batchNum |
|
except: |
|
batchNum = len(glob(f"{batchFolder}/{batch_name}(*)_settings.txt"))-1 |
|
else: |
|
batchNum = int(run_to_resume) |
|
if resume_from_frame == 'latest': |
|
start_frame = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png")) |
|
else: |
|
start_frame = int(resume_from_frame)+1 |
|
if retain_overwritten_frames is True: |
|
existing_frames = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png")) |
|
frames_to_save = existing_frames - start_frame |
|
print(f'Moving {frames_to_save} frames to the Retained folder') |
|
move_files(start_frame, existing_frames, batchFolder, retainFolder) |
|
else: |
|
""" |
|
start_frame = 0 |
|
batchNum = 1 |
|
""" |
|
batchNum = len(glob(batchFolder+"/*.txt")) |
|
while path.isfile(f"{batchFolder}/{batch_name}({batchNum})_settings.txt") is True or path.isfile(f"{batchFolder}/{batch_name}-{batchNum}_settings.txt") is True: |
|
batchNum += 1 |
|
""" |
|
|
|
|
|
if set_seed == "random_seed": |
|
random.seed() |
|
seed = random.randint(0, 2**32) |
|
|
|
else: |
|
seed = int(set_seed) |
|
|
|
args = { |
|
"batchNum": batchNum, |
|
"prompts_series": split_prompts(text_prompts) if text_prompts else None, |
|
"image_prompts_series": split_prompts(image_prompts) if image_prompts else None, |
|
"seed": seed, |
|
"display_rate": display_rate, |
|
"n_batches": n_batches if animation_mode == "None" else 1, |
|
"batch_size": batch_size, |
|
"batch_name": batch_name, |
|
"steps": steps, |
|
"width_height": width_height, |
|
"clip_guidance_scale": clip_guidance_scale, |
|
"tv_scale": tv_scale, |
|
"range_scale": range_scale, |
|
"sat_scale": sat_scale, |
|
"cutn_batches": cutn_batches, |
|
"init_image": init_image, |
|
"init_scale": init_scale, |
|
"skip_steps": skip_steps, |
|
"sharpen_preset": sharpen_preset, |
|
"keep_unsharp": keep_unsharp, |
|
"side_x": side_x, |
|
"side_y": side_y, |
|
"timestep_respacing": timestep_respacing, |
|
"diffusion_steps": diffusion_steps, |
|
"animation_mode": animation_mode, |
|
"video_init_path": video_init_path, |
|
"extract_nth_frame": extract_nth_frame, |
|
"key_frames": key_frames, |
|
"max_frames": max_frames if animation_mode != "None" else 1, |
|
"interp_spline": interp_spline, |
|
"start_frame": start_frame, |
|
"angle": angle, |
|
"zoom": zoom, |
|
"translation_x": translation_x, |
|
"translation_y": translation_y, |
|
"translation_z": translation_z, |
|
"rotation_3d_x": rotation_3d_x, |
|
"rotation_3d_y": rotation_3d_y, |
|
"rotation_3d_z": rotation_3d_z, |
|
"midas_depth_model": midas_depth_model, |
|
"midas_weight": midas_weight, |
|
"near_plane": near_plane, |
|
"far_plane": far_plane, |
|
"fov": fov, |
|
"padding_mode": padding_mode, |
|
"sampling_mode": sampling_mode, |
|
"angle_series": angle_series, |
|
"zoom_series": zoom_series, |
|
"translation_x_series": translation_x_series, |
|
"translation_y_series": translation_y_series, |
|
"translation_z_series": translation_z_series, |
|
"rotation_3d_x_series": rotation_3d_x_series, |
|
"rotation_3d_y_series": rotation_3d_y_series, |
|
"rotation_3d_z_series": rotation_3d_z_series, |
|
"frames_scale": frames_scale, |
|
"calc_frames_skip_steps": calc_frames_skip_steps, |
|
"skip_step_ratio": skip_step_ratio, |
|
"calc_frames_skip_steps": calc_frames_skip_steps, |
|
"text_prompts": text_prompts, |
|
"image_prompts": image_prompts, |
|
"cut_overview": eval(cut_overview), |
|
"cut_innercut": eval(cut_innercut), |
|
"cut_ic_pow": cut_ic_pow, |
|
"cut_icgray_p": eval(cut_icgray_p), |
|
"intermediate_saves": intermediate_saves, |
|
"intermediates_in_subfolder": intermediates_in_subfolder, |
|
"steps_per_checkpoint": steps_per_checkpoint, |
|
"perlin_init": perlin_init, |
|
"perlin_mode": perlin_mode, |
|
"set_seed": set_seed, |
|
"eta": eta, |
|
"clamp_grad": clamp_grad, |
|
"clamp_max": clamp_max, |
|
"skip_augs": skip_augs, |
|
"randomize_class": randomize_class, |
|
"clip_denoised": clip_denoised, |
|
"fuzzy_prompt": fuzzy_prompt, |
|
"rand_mag": rand_mag, |
|
} |
|
|
|
args = SimpleNamespace(**args) |
|
|
|
print("Prepping model...") |
|
model, diffusion = create_model_and_diffusion(**model_config) |
|
model.load_state_dict( |
|
torch.load( |
|
f"{DefaultPaths.model_path}/{diffusion_model}.pt", map_location="cpu" |
|
) |
|
) |
|
model.requires_grad_(False).eval().to(device) |
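# Re-enable gradients only on the attention (qkv/proj) and normalization
# parameters so the CLIP-guidance backward pass can flow through the
# otherwise frozen model.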
|
for name, param in model.named_parameters(): |
|
if "qkv" in name or "norm" in name or "proj" in name: |
|
param.requires_grad_() |
|
if model_config["use_fp16"]: |
|
model.convert_to_fp16() |
|
|
|
sys.stdout.write("Starting ...\n") |
|
sys.stdout.flush() |
|
status.write(f"Starting ...\n") |
|
|
|
gc.collect() |
|
torch.cuda.empty_cache() |
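# Run the diffusion, making sure GPU memory is released even if the run
# is interrupted.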
|
try: |
|
do_run() |
|
|
|
|
|
|
|
except KeyboardInterrupt: |
|
pass |
|
finally: |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
|