diff --git "a/disco_streamlit_run.py" "b/disco_streamlit_run.py"
new file mode 100644
--- /dev/null
+++ "b/disco_streamlit_run.py"
@@ -0,0 +1,2522 @@
+# Disco Diffusion v5 [w/ 3D animation] (modified by @softology to work on Visions of Chaos and further modified by @multimodalart to run on MindsEye)
+# Adapted from the Visions of Chaos software (https://softology.pro/voc.htm), which in turn adapted it from the
+# original notebook at https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb
+
+# required models
+# https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt
+# https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt
+# git clone https://github.com/isl-org/MiDaS.git
+# git clone https://github.com/alembics/disco-diffusion.git
+
+
+"""#Tutorial
+
+**Diffusion settings (Defaults are heavily outdated)**
+---
+
+This section is outdated as of v2.
+
+Setting | Description | Default
+--- | --- | ---
+**Your vision:**
+`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A
+`image_prompts` | Think of these images more as a description of their contents. | N/A
+**Image quality:**
+`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000
+`tv_scale` | Controls the smoothness of the final output. | 150
+`range_scale` | Controls how far out of range RGB values are allowed to be. | 150
+`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0
+`cutn` | Controls how many crops to take from the image. | 16
+`cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts. | 2
+**Init settings:**
+`init_image` | URL or local path. | None
+`init_scale` | This enhances the effect of the init image; a good value is 1000. | 0
+`skip_steps` | Controls the starting point along the diffusion timesteps. | 0
+`perlin_init` | Option to start with random perlin noise. | False
+`perlin_mode` | ('gray', 'color') | 'mixed'
+**Advanced:**
+`skip_augs` | Controls whether to skip torchvision augmentations. | False
+`randomize_class` | Controls whether the imagenet class is randomly changed each iteration. | True
+`clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image. | False
+`clamp_grad` | Experimental: use adaptive gradient clamping in the cond_fn. | True
+`seed` | Choose a random seed and print it at the end of the run for reproduction. | random_seed
+`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses. | False
+`rand_mag` | Controls the magnitude of the random noise. | 0.1
+`eta` | DDIM hyperparameter. | 0.5
+
+**Model settings**
+---
+
+Setting | Description | Default
+--- | --- | ---
+**Diffusion:**
+`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100
+`diffusion_steps` | | 1000
+**CLIP:**
+`clip_models` | Models of CLIP to load. Typically the more, the better, but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4
+
+# 1. 
Set Up +""" + + +is_colab = False +google_drive = False +save_models_to_google_drive = False + +import sys + +sys.stdout.write("Imports ...\n") +sys.stdout.flush() + +sys.path.append("./ResizeRight") +sys.path.append("./MiDaS") +sys.path.append("./CLIP") +sys.path.append("./guided-diffusion") +sys.path.append("./latent-diffusion") +sys.path.append(".") +sys.path.append("./taming-transformers") +sys.path.append("./disco-diffusion") +sys.path.append("./AdaBins") +sys.path.append('./pytorch3d-lite') +# sys.path.append('./pytorch3d') + +import os +import streamlit as st +from os import path +from os.path import exists as path_exists +import sys +import torch + +# sys.path.append('./SLIP') +from dataclasses import dataclass +from functools import partial +import cv2 +import pandas as pd +import gc +import io +import math +import timm +from IPython import display +import lpips +from PIL import Image, ImageOps +import requests +from glob import glob +import json +from types import SimpleNamespace +from torch import nn +from torch.nn import functional as F +import torchvision.transforms as T +import torchvision.transforms.functional as TF +import shutil +from pathvalidate import sanitize_filename + +# from tqdm.notebook import tqdm +# from stqdm_local import stqdm +import clip +from resize_right import resize + +# from models import SLIP_VITB16, SLIP, SLIP_VITL16 +from guided_diffusion.script_util import ( + create_model_and_diffusion, + model_and_diffusion_defaults, +) +from datetime import datetime +import numpy as np +import matplotlib.pyplot as plt +import random +from ipywidgets import Output +import hashlib +import ipywidgets as widgets +import os + +# from taming.models import vqgan # checking correct import from taming +from torchvision.datasets.utils import download_url +from functools import partial +from ldm.util import instantiate_from_config +from ldm.modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) + +# from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import ismap +from IPython.display import Image as ipyimg +from numpy import asarray +from einops import rearrange, repeat +import torch, torchvision +import time +from omegaconf import OmegaConf +from midas.dpt_depth import DPTDepthModel +from midas.midas_net import MidasNet +from midas.midas_net_custom import MidasNet_small +from midas.transforms import Resize, NormalizeImage, PrepareForNet +import torch +import py3d_tools as p3dT +import disco_xform_utils as dxf +import argparse + +sys.stdout.write("Parsing arguments ...\n") +sys.stdout.flush() + + +def run_model(args2, status, stoutput, DefaultPaths): + if args2.seed is not None: + sys.stdout.write(f"Setting seed to {args2.seed} ...\n") + sys.stdout.flush() + status.write(f"Setting seed to {args2.seed} ...\n") + import numpy as np + + np.random.seed(args2.seed) + import random + + random.seed(args2.seed) + # next line forces deterministic random values, but causes other issues with resampling (uncomment to see) + # torch.use_deterministic_algorithms(True) + torch.manual_seed(args2.seed) + torch.cuda.manual_seed(args2.seed) + torch.cuda.manual_seed_all(args2.seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print("Using device:", DEVICE) + device = DEVICE # At least one of the modules expects this name.. 
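+    # Rough sketch of the interface this function expects (inferred from the attribute
+    # accesses below; not an authoritative spec): `args2` is a namespace-like object
+    # holding the UI settings, e.g. args2.prompt, args2.seed, args2.iterations,
+    # args2.sizex / args2.sizey, args2.guidancescale, args2.cutn, args2.cutnbatches,
+    # args2.animation_mode, args2.sampling_mode, args2.image_file (progress image path),
+    # args2.frame_dir and args2.update. `status` and `stoutput` are Streamlit containers
+    # used for progress reporting, and `DefaultPaths` supplies model_path / output_path.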
+ + # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations. + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + + PROJECT_DIR = os.path.abspath(os.getcwd()) + + # AdaBins stuff + USE_ADABINS = True + if USE_ADABINS: + sys.path.append("./AdaBins") + from infer import InferenceHelper + + MAX_ADABINS_AREA = 500000 + + model_256_downloaded = False + model_512_downloaded = False + model_secondary_downloaded = False + + # Initialize MiDaS depth model. + # It remains resident in VRAM and likely takes around 2GB VRAM. + # You could instead initialize it for each frame (and free it after each frame) to save VRAM.. but initializing it is slow. + default_models = { + "midas_v21_small": f"{DefaultPaths.model_path}/midas_v21_small-70d6b9c8.pt", + "midas_v21": f"{DefaultPaths.model_path}/midas_v21-f6b98070.pt", + "dpt_large": f"{DefaultPaths.model_path}/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": f"{DefaultPaths.model_path}/dpt_hybrid-midas-501f0c75.pt", + "dpt_hybrid_nyu": f"{DefaultPaths.model_path}/dpt_hybrid_nyu-2ce69ec7.pt", + } + + def init_midas_depth_model(midas_model_type="dpt_large", optimize=True): + midas_model = None + net_w = None + net_h = None + resize_mode = None + normalization = None + + print(f"Initializing MiDaS '{midas_model_type}' depth model...") + # load network + midas_model_path = default_models[midas_model_type] + + if midas_model_type == "dpt_large": # DPT-Large + midas_model = DPTDepthModel( + path=midas_model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + elif midas_model_type == "dpt_hybrid": # DPT-Hybrid + midas_model = DPTDepthModel( + path=midas_model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + elif midas_model_type == "dpt_hybrid_nyu": # DPT-Hybrid-NYU + midas_model = DPTDepthModel( + path=midas_model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + elif midas_model_type == "midas_v21": + midas_model = MidasNet(midas_model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + elif midas_model_type == "midas_v21_small": + midas_model = MidasNet_small( + midas_model_path, + features=64, + backbone="efficientnet_lite3", + exportable=True, + non_negative=True, + blocks={"expand": True}, + ) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + else: + print(f"midas_model_type '{midas_model_type}' not implemented") + assert False + + midas_transform = T.Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + midas_model.eval() + + if optimize == True: + if DEVICE == torch.device("cuda"): + midas_model = midas_model.to(memory_format=torch.channels_last) + midas_model = midas_model.half() + + midas_model.to(DEVICE) + + print(f"MiDaS '{midas_model_type}' depth model initialized.") + return 
midas_model, midas_transform, net_w, net_h, resize_mode, normalization + + # @title 1.5 Define necessary functions + + # https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869 + + def interp(t): + return 3 * t**2 - 2 * t**3 + + def perlin(width, height, scale=10, device=None): + gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device) + xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device) + ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device) + wx = 1 - interp(xs) + wy = 1 - interp(ys) + dots = 0 + dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys) + dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys) + dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys)) + dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys)) + return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale) + + def perlin_ms(octaves, width, height, grayscale, device=device): + out_array = [0.5] if grayscale else [0.5, 0.5, 0.5] + # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0] + for i in range(1 if grayscale else 3): + scale = 2 ** len(octaves) + oct_width = width + oct_height = height + for oct in octaves: + p = perlin(oct_width, oct_height, scale, device) + out_array[i] += p * oct + scale //= 2 + oct_width *= 2 + oct_height *= 2 + return torch.cat(out_array) + + def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True): + out = perlin_ms(octaves, width, height, grayscale) + if grayscale: + out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0)) + out = TF.to_pil_image(out.clamp(0, 1)).convert("RGB") + else: + out = out.reshape(-1, 3, out.shape[0] // 3, out.shape[1]) + out = TF.resize(size=(side_y, side_x), img=out) + out = TF.to_pil_image(out.clamp(0, 1).squeeze()) + + out = ImageOps.autocontrast(out) + return out + + def regen_perlin(): + if perlin_mode == "color": + init = create_perlin_noise( + [1.5**-i * 0.5 for i in range(12)], 1, 1, False + ) + init2 = create_perlin_noise( + [1.5**-i * 0.5 for i in range(8)], 4, 4, False + ) + elif perlin_mode == "gray": + init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, True) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True) + else: + init = create_perlin_noise( + [1.5**-i * 0.5 for i in range(12)], 1, 1, False + ) + init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True) + + init = ( + TF.to_tensor(init) + .add(TF.to_tensor(init2)) + .div(2) + .to(device) + .unsqueeze(0) + .mul(2) + .sub(1) + ) + del init2 + return init.expand(batch_size, -1, -1, -1) + + def fetch(url_or_path): + if str(url_or_path).startswith("http://") or str(url_or_path).startswith( + "https://" + ): + r = requests.get(url_or_path) + r.raise_for_status() + fd = io.BytesIO() + fd.write(r.content) + fd.seek(0) + return fd + return open(url_or_path, "rb") + + def read_image_workaround(path): + """OpenCV reads images as BGR, Pillow saves them as RGB. 
Work around + this incompatibility to avoid colour inversions.""" + im_tmp = cv2.imread(path) + return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB) + + def parse_prompt(prompt): + if prompt.startswith("http://") or prompt.startswith("https://"): + vals = prompt.rsplit(":", 2) + vals = [vals[0] + ":" + vals[1], *vals[2:]] + else: + vals = prompt.rsplit(":", 1) + vals = vals + ["", "1"][len(vals) :] + return vals[0], float(vals[1]) + + def sinc(x): + return torch.where( + x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]) + ) + + def lanczos(x, a): + cond = torch.logical_and(-a < x, x < a) + out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([])) + return out / out.sum() + + def ramp(ratio, width): + n = math.ceil(width / ratio + 1) + out = torch.empty([n]) + cur = 0 + for i in range(out.shape[0]): + out[i] = cur + cur += ratio + return torch.cat([-out[1:].flip([0]), out])[1:-1] + + def resample(input, size, align_corners=True): + n, c, h, w = input.shape + dh, dw = size + + input = input.reshape([n * c, 1, h, w]) + + if dh < h: + kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype) + pad_h = (kernel_h.shape[0] - 1) // 2 + input = F.pad(input, (0, 0, pad_h, pad_h), "reflect") + input = F.conv2d(input, kernel_h[None, None, :, None]) + + if dw < w: + kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype) + pad_w = (kernel_w.shape[0] - 1) // 2 + input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect") + input = F.conv2d(input, kernel_w[None, None, None, :]) + + input = input.reshape([n, c, h, w]) + return F.interpolate(input, size, mode="bicubic", align_corners=align_corners) + + class MakeCutouts(nn.Module): + def __init__(self, cut_size, cutn, skip_augs=False): + super().__init__() + self.cut_size = cut_size + self.cutn = cutn + self.skip_augs = skip_augs + self.augs = T.Compose( + [ + T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine(degrees=15, translate=(0.1, 0.1)), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomPerspective(distortion_scale=0.4, p=0.7), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.15), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + ] + ) + + def forward(self, input): + input = T.Pad(input.shape[2] // 4, fill=0)(input) + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + + cutouts = [] + for ch in range(self.cutn): + if ch > self.cutn - self.cutn // 4: + cutout = input.clone() + else: + size = int( + max_size + * torch.zeros( + 1, + ) + .normal_(mean=0.8, std=0.3) + .clip(float(self.cut_size / max_size), 1.0) + ) + offsetx = torch.randint(0, abs(sideX - size + 1), ()) + offsety = torch.randint(0, abs(sideY - size + 1), ()) + cutout = input[ + :, :, offsety : offsety + size, offsetx : offsetx + size + ] + + if not self.skip_augs: + cutout = self.augs(cutout) + cutouts.append(resample(cutout, (self.cut_size, self.cut_size))) + del cutout + + cutouts = torch.cat(cutouts, dim=0) + return cutouts + + cutout_debug = False + padargs = {} + + class MakeCutoutsDango(nn.Module): + def __init__( + self, cut_size, Overview=4, InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2 + ): + super().__init__() + self.cut_size = cut_size + self.Overview = Overview + self.InnerCrop = InnerCrop + self.IC_Size_Pow = IC_Size_Pow + self.IC_Grey_P = IC_Grey_P + if args.animation_mode == "None": + self.augs = T.Compose( + [ + T.RandomHorizontalFlip(p=0.5), + 
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine( + degrees=10, + translate=(0.05, 0.05), + interpolation=T.InterpolationMode.BILINEAR, + ), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.ColorJitter( + brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1 + ), + ] + ) + elif args.animation_mode == "Video Input": + self.augs = T.Compose( + [ + T.RandomHorizontalFlip(p=0.5), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine(degrees=15, translate=(0.1, 0.1)), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomPerspective(distortion_scale=0.4, p=0.7), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.15), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), + ] + ) + elif args.animation_mode == "2D" or args.animation_mode == "3D": + self.augs = T.Compose( + [ + T.RandomHorizontalFlip(p=0.4), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomAffine( + degrees=10, + translate=(0.05, 0.05), + interpolation=T.InterpolationMode.BILINEAR, + ), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.RandomGrayscale(p=0.1), + T.Lambda(lambda x: x + torch.randn_like(x) * 0.01), + T.ColorJitter( + brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3 + ), + ] + ) + + def forward(self, input): + cutouts = [] + gray = T.Grayscale(3) + sideY, sideX = input.shape[2:4] + max_size = min(sideX, sideY) + min_size = min(sideX, sideY, self.cut_size) + l_size = max(sideX, sideY) + output_shape = [1, 3, self.cut_size, self.cut_size] + output_shape_2 = [1, 3, self.cut_size + 2, self.cut_size + 2] + pad_input = F.pad( + input, + ( + (sideY - max_size) // 2, + (sideY - max_size) // 2, + (sideX - max_size) // 2, + (sideX - max_size) // 2, + ), + **padargs, + ) + cutout = resize(pad_input, out_shape=output_shape) + + if self.Overview > 0: + if self.Overview <= 4: + if self.Overview >= 1: + cutouts.append(cutout) + if self.Overview >= 2: + cutouts.append(gray(cutout)) + if self.Overview >= 3: + cutouts.append(TF.hflip(cutout)) + if self.Overview == 4: + cutouts.append(gray(TF.hflip(cutout))) + else: + cutout = resize(pad_input, out_shape=output_shape) + for _ in range(self.Overview): + cutouts.append(cutout) + + if cutout_debug: + if is_colab: + TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save( + "/content/cutout_overview0.jpg", quality=99 + ) + else: + TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save( + "cutout_overview0.jpg", quality=99 + ) + + if self.InnerCrop > 0: + for i in range(self.InnerCrop): + size = int( + torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size) + + min_size + ) + offsetx = torch.randint(0, sideX - size + 1, ()) + offsety = torch.randint(0, sideY - size + 1, ()) + cutout = input[ + :, :, offsety : offsety + size, offsetx : offsetx + size + ] + if i <= int(self.IC_Grey_P * self.InnerCrop): + cutout = gray(cutout) + cutout = resize(cutout, out_shape=output_shape) + cutouts.append(cutout) + if cutout_debug: + if is_colab: + TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save( + "/content/cutout_InnerCrop.jpg", quality=99 + ) + else: + TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save( + "cutout_InnerCrop.jpg", quality=99 + ) + cutouts = torch.cat(cutouts) + if skip_augs is not True: + cutouts = self.augs(cutouts) + return cutouts + + def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = 
F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + def tv_loss(input): + """L2 total variation loss, as in Mahendran et al.""" + input = F.pad(input, (0, 1, 0, 1), "replicate") + x_diff = input[..., :-1, 1:] - input[..., :-1, :-1] + y_diff = input[..., 1:, :-1] - input[..., :-1, :-1] + return (x_diff**2 + y_diff**2).mean([1, 2, 3]) + + def range_loss(input): + return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3]) + + stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete + + def nsToStr(d): + h = 3.6e12 + m = h / 60 + s = m / 60 + return ( + str(int(d / h)) + + ":" + + str(int((d % h) / m)) + + ":" + + str(int((d % h) % m / s)) + + "." + + str(int((d % h) % m % s)) + ) + + def do_run(): + seed = args.seed + # print(range(args.start_frame, args.max_frames)) + + if (args.animation_mode == "3D") and (args.midas_weight > 0.0): + ( + midas_model, + midas_transform, + midas_net_w, + midas_net_h, + midas_resize_mode, + midas_normalization, + ) = init_midas_depth_model(args.midas_depth_model) + for frame_num in range(args.start_frame, args.max_frames): + if stop_on_next_loop: + break + + display.clear_output(wait=True) + + # Print Frame progress if animation mode is on + + """ + if args.animation_mode != "None": + batchBar = tqdm(range(args.max_frames), desc ="Frames") + batchBar.n = frame_num + batchBar.refresh() + """ + + # Inits if not video frames + if args.animation_mode != "Video Input": + if args.init_image == "": + init_image = None + else: + init_image = args.init_image + init_scale = args.init_scale + skip_steps = args.skip_steps + + if args.animation_mode == "2D": + if args.key_frames: + angle = args.angle_series[frame_num] + zoom = args.zoom_series[frame_num] + translation_x = args.translation_x_series[frame_num] + translation_y = args.translation_y_series[frame_num] + print( + f"angle: {angle}", + f"zoom: {zoom}", + f"translation_x: {translation_x}", + f"translation_y: {translation_y}", + ) + + if frame_num > 0: + seed = seed + 1 + if resume_run and frame_num == start_frame: + img_0 = cv2.imread( + batchFolder + + f"/{batch_name}({batchNum})_{start_frame-1:04}.png" + ) + else: + img_0 = cv2.imread("prevFrame.png") + center = (1 * img_0.shape[1] // 2, 1 * img_0.shape[0] // 2) + trans_mat = np.float32( + [[1, 0, translation_x], [0, 1, translation_y]] + ) + rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) + trans_mat = np.vstack([trans_mat, [0, 0, 1]]) + rot_mat = np.vstack([rot_mat, [0, 0, 1]]) + transformation_matrix = np.matmul(rot_mat, trans_mat) + img_0 = cv2.warpPerspective( + img_0, + transformation_matrix, + (img_0.shape[1], img_0.shape[0]), + borderMode=cv2.BORDER_WRAP, + ) + + cv2.imwrite("prevFrameScaled.png", img_0) + init_image = "prevFrameScaled.png" + init_scale = args.frames_scale + skip_steps = args.calc_frames_skip_steps + + if args.animation_mode == "3D": + if args.key_frames: + angle = args.angle_series[frame_num] + # zoom = args.zoom_series[frame_num] + translation_x = args.translation_x_series[frame_num] + translation_y = args.translation_y_series[frame_num] + translation_z = args.translation_z_series[frame_num] + rotation_3d_x = args.rotation_3d_x_series[frame_num] + rotation_3d_y = args.rotation_3d_y_series[frame_num] + rotation_3d_z = args.rotation_3d_z_series[frame_num] + print( + f"angle: {angle}", + # f'zoom: {zoom}', + f"translation_x: {translation_x}", + f"translation_y: {translation_y}", + f"translation_z: {translation_z}", 
+ f"rotation_3d_x: {rotation_3d_x}", + f"rotation_3d_y: {rotation_3d_y}", + f"rotation_3d_z: {rotation_3d_z}", + ) + + sys.stdout.flush() + # sys.stdout.write(f'FRAME_NUM = {frame_num} ...\n') + sys.stdout.flush() + + if frame_num > 0: + seed = seed + 1 + img_filepath = "prevFrame.png" + trans_scale = 1.0 / 200.0 + translate_xyz = [ + -translation_x * trans_scale, + translation_y * trans_scale, + -translation_z * trans_scale, + ] + rotate_xyz = [ + math.radians(rotation_3d_x), + math.radians(rotation_3d_y), + math.radians(rotation_3d_z), + ] + print("translation:", translate_xyz) + print("rotation:", rotate_xyz) + rot_mat = p3dT.euler_angles_to_matrix( + torch.tensor(rotate_xyz, device=device), "XYZ" + ).unsqueeze(0) + print("rot_mat: " + str(rot_mat)) + next_step_pil = dxf.transform_image_3d( + img_filepath, + midas_model, + midas_transform, + DEVICE, + rot_mat, + translate_xyz, + args.near_plane, + args.far_plane, + args.fov, + padding_mode=args.padding_mode, + sampling_mode=args.sampling_mode, + midas_weight=args.midas_weight, + ) + next_step_pil.save("prevFrameScaled.png") + + """ + ### Turbo mode - skip some diffusions to save time + if turbo_mode == True and frame_num > 10 and frame_num % int(turbo_steps) != 0: + #turbo_steps + print('turbo mode is on this frame: skipping clip diffusion steps') + #this is an even frame. copy warped prior frame w/ war + #filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png' + #next_step_pil.save(f'{batchFolder}/{filename}') #save it as this frame + #next_step_pil.save(f'{img_filepath}') # save it also as prev_frame for next iteration + filename = f'progress.png' + next_step_pil.save(f'{filename}') #save it as this frame + next_step_pil.save(f'{img_filepath}') # save it also as prev_frame for next iteration + continue + elif turbo_mode == True: + print('turbo mode is OFF this frame') + #else: no turbo + """ + + init_image = "prevFrameScaled.png" + init_scale = args.frames_scale + skip_steps = args.calc_frames_skip_steps + + if args.animation_mode == "Video Input": + seed = seed + 1 + init_image = f"{videoFramesFolder}/{frame_num+1:04}.jpg" + init_scale = args.frames_scale + skip_steps = args.calc_frames_skip_steps + + loss_values = [] + + if seed is not None: + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + target_embeds, weights = [], [] + + if args.prompts_series is not None and frame_num >= len( + args.prompts_series + ): + frame_prompt = args.prompts_series[-1] + elif args.prompts_series is not None: + frame_prompt = args.prompts_series[frame_num] + else: + frame_prompt = [] + + print(args.image_prompts_series) + if args.image_prompts_series is not None and frame_num >= len( + args.image_prompts_series + ): + image_prompt = args.image_prompts_series[-1] + elif args.image_prompts_series is not None: + image_prompt = args.image_prompts_series[frame_num] + else: + image_prompt = [] + + print(f"Frame Prompt: {frame_prompt}") + + model_stats = [] + for clip_model in clip_models: + cutn = args2.cutn + model_stat = { + "clip_model": None, + "target_embeds": [], + "make_cutouts": None, + "weights": [], + } + model_stat["clip_model"] = clip_model + + for prompt in frame_prompt: + txt, weight = parse_prompt(prompt) + txt = clip_model.encode_text( + clip.tokenize(prompt).to(device) + ).float() + + if args.fuzzy_prompt: + for i in range(25): + model_stat["target_embeds"].append( + ( + txt + torch.randn(txt.shape).cuda() * args.rand_mag + 
).clamp(0, 1) + ) + model_stat["weights"].append(weight) + else: + model_stat["target_embeds"].append(txt) + model_stat["weights"].append(weight) + + if image_prompt: + model_stat["make_cutouts"] = MakeCutouts( + clip_model.visual.input_resolution, cutn, skip_augs=skip_augs + ) + for prompt in image_prompt: + path, weight = parse_prompt(prompt) + img = Image.open(fetch(path)).convert("RGB") + img = TF.resize( + img, + min(side_x, side_y, *img.size), + T.InterpolationMode.LANCZOS, + ) + batch = model_stat["make_cutouts"]( + TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1) + ) + embed = clip_model.encode_image(normalize(batch)).float() + if fuzzy_prompt: + for i in range(25): + model_stat["target_embeds"].append( + ( + embed + + torch.randn(embed.shape).cuda() * rand_mag + ).clamp(0, 1) + ) + weights.extend([weight / cutn] * cutn) + else: + model_stat["target_embeds"].append(embed) + model_stat["weights"].extend([weight / cutn] * cutn) + + model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"]) + model_stat["weights"] = torch.tensor( + model_stat["weights"], device=device + ) + if model_stat["weights"].sum().abs() < 1e-3: + raise RuntimeError("The weights must not sum to 0.") + model_stat["weights"] /= model_stat["weights"].sum().abs() + model_stats.append(model_stat) + + init = None + if init_image is not None: + init = Image.open(fetch(init_image)).convert("RGB") + init = init.resize((args.side_x, args.side_y), Image.LANCZOS) + init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1) + + if args.perlin_init: + if args.perlin_mode == "color": + init = create_perlin_noise( + [1.5**-i * 0.5 for i in range(12)], 1, 1, False + ) + init2 = create_perlin_noise( + [1.5**-i * 0.5 for i in range(8)], 4, 4, False + ) + elif args.perlin_mode == "gray": + init = create_perlin_noise( + [1.5**-i * 0.5 for i in range(12)], 1, 1, True + ) + init2 = create_perlin_noise( + [1.5**-i * 0.5 for i in range(8)], 4, 4, True + ) + else: + init = create_perlin_noise( + [1.5**-i * 0.5 for i in range(12)], 1, 1, False + ) + init2 = create_perlin_noise( + [1.5**-i * 0.5 for i in range(8)], 4, 4, True + ) + # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device) + init = ( + TF.to_tensor(init) + .add(TF.to_tensor(init2)) + .div(2) + .to(device) + .unsqueeze(0) + .mul(2) + .sub(1) + ) + del init2 + + cur_t = None + + def cond_fn(x, t, y=None): + with torch.enable_grad(): + x_is_NaN = False + x = x.detach().requires_grad_() + n = x.shape[0] + if use_secondary_model is True: + alpha = torch.tensor( + diffusion.sqrt_alphas_cumprod[cur_t], + device=device, + dtype=torch.float32, + ) + sigma = torch.tensor( + diffusion.sqrt_one_minus_alphas_cumprod[cur_t], + device=device, + dtype=torch.float32, + ) + cosine_t = alpha_sigma_to_t(alpha, sigma) + out = secondary_model(x, cosine_t[None].repeat([n])).pred + fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t] + x_in = out * fac + x * (1 - fac) + x_in_grad = torch.zeros_like(x_in) + else: + my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t + out = diffusion.p_mean_variance( + model, x, my_t, clip_denoised=False, model_kwargs={"y": y} + ) + fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t] + x_in = out["pred_xstart"] * fac + x * (1 - fac) + x_in_grad = torch.zeros_like(x_in) + for model_stat in model_stats: + for i in range(int(args.cutn_batches)): + t_int = ( + int(t.item()) + 1 + ) # errors on last step without +1, need to find source + # when using SLIP Base model the dimensions need to be hard coded to avoid 
AttributeError: 'VisionTransformer' object has no attribute 'input_resolution' + try: + input_resolution = model_stat[ + "clip_model" + ].visual.input_resolution + except: + input_resolution = 224 + + cuts = MakeCutoutsDango( + input_resolution, + Overview=args.cut_overview[1000 - t_int], + InnerCrop=args.cut_innercut[1000 - t_int], + IC_Size_Pow=args.cut_ic_pow, + IC_Grey_P=args.cut_icgray_p[1000 - t_int], + ) + clip_in = normalize(cuts(x_in.add(1).div(2))) + image_embeds = ( + model_stat["clip_model"].encode_image(clip_in).float() + ) + dists = spherical_dist_loss( + image_embeds.unsqueeze(1), + model_stat["target_embeds"].unsqueeze(0), + ) + dists = dists.view( + [ + args.cut_overview[1000 - t_int] + + args.cut_innercut[1000 - t_int], + n, + -1, + ] + ) + losses = dists.mul(model_stat["weights"]).sum(2).mean(0) + loss_values.append( + losses.sum().item() + ) # log loss, probably shouldn't do per cutn_batch + x_in_grad += ( + torch.autograd.grad( + losses.sum() * clip_guidance_scale, x_in + )[0] + / cutn_batches + ) + tv_losses = tv_loss(x_in) + if use_secondary_model is True: + range_losses = range_loss(out) + else: + range_losses = range_loss(out["pred_xstart"]) + sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean() + loss = ( + tv_losses.sum() * tv_scale + + range_losses.sum() * range_scale + + sat_losses.sum() * sat_scale + ) + if init is not None and args.init_scale: + init_losses = lpips_model(x_in, init) + loss = loss + init_losses.sum() * args.init_scale + x_in_grad += torch.autograd.grad(loss, x_in)[0] + if torch.isnan(x_in_grad).any() == False: + grad = -torch.autograd.grad(x_in, x, x_in_grad)[0] + else: + # print("NaN'd") + x_is_NaN = True + grad = torch.zeros_like(x) + if args.clamp_grad and x_is_NaN == False: + magnitude = grad.square().mean().sqrt() + return ( + grad * magnitude.clamp(max=args.clamp_max) / magnitude + ) # min=-0.02, min=-clamp_max, + return grad + + if args.sampling_mode == "ddim": + sample_fn = diffusion.ddim_sample_loop_progressive + elif args.sampling_mode == "bicubic": + sample_fn = diffusion.p_sample_loop_progressive + elif args.sampling_mode == "plms": + sample_fn = diffusion.plms_sample_loop_progressive + # if model_config["timestep_respacing"].startswith("ddim"): + # sample_fn = diffusion.ddim_sample_loop_progressive + # else: + # sample_fn = diffusion.p_sample_loop_progressive + + image_display = Output() + for i in range(args.n_batches): + """ + if args.animation_mode == 'None': + display.clear_output(wait=True) + batchBar = tqdm(range(args.n_batches), desc ="Batches") + batchBar.n = i + batchBar.refresh() + print('') + display.display(image_display) + gc.collect() + torch.cuda.empty_cache() + """ + cur_t = diffusion.num_timesteps - skip_steps - 1 + total_steps = cur_t + + if perlin_init: + init = regen_perlin() + + if args.sampling_mode == "ddim": + samples = sample_fn( + model, + (batch_size, 3, args.side_y, args.side_x), + clip_denoised=clip_denoised, + model_kwargs={}, + cond_fn=cond_fn, + progress=True, + skip_timesteps=skip_steps, + init_image=init, + randomize_class=randomize_class, + eta=eta, + ) + elif args.sampling_mode == "plms": + samples = sample_fn( + model, + (batch_size, 3, args.side_y, args.side_x), + clip_denoised=clip_denoised, + model_kwargs={}, + cond_fn=cond_fn, + progress=True, + skip_timesteps=skip_steps, + init_image=init, + randomize_class=randomize_class, + order=2, + ) + elif args.sampling_mode == "bicubic": + samples = sample_fn( + model, + (batch_size, 3, args.side_y, args.side_x), + 
clip_denoised=clip_denoised, + model_kwargs={}, + cond_fn=cond_fn, + progress=True, + skip_timesteps=skip_steps, + init_image=init, + randomize_class=randomize_class, + ) + + # with run_display: + # display.clear_output(wait=True) + itt = 1 + imgToSharpen = None + status.write("Starting the execution...") + gc.collect() + torch.cuda.empty_cache() + # from tqdm.auto import tqdm + # from stqdm_local import stqdm + + # total_iterables = stqdm( + # samples, total=total_steps + 1, st_container=stoutput + # ) + total_iterables = samples + try: + j = 0 + before_start_time = time.perf_counter() + bar_container = status.container() + iteration_counter = bar_container.empty() + progress_bar = bar_container.progress(0) + for sample in total_iterables: + if itt == 1: + iteration_counter.empty() + imageLocation = stoutput.empty() + sys.stdout.write(f"Iteration {itt}\n") + sys.stdout.flush() + cur_t -= 1 + intermediateStep = False + if args.steps_per_checkpoint is not None: + if j % steps_per_checkpoint == 0 and j > 0: + intermediateStep = True + elif j in args.intermediate_saves: + intermediateStep = True + with image_display: + """ + if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True: + for k, image in enumerate(sample['pred_xstart']): + # tqdm.write(f'Batch {i}, step {j}, output {k}:') + current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f') + percent = math.ceil(j/total_steps*100) + if args.n_batches > 0: + #if intermediates are saved to the subfolder, don't append a step or percentage to the name + if cur_t == -1 and args.intermediates_in_subfolder is True: + save_num = f'{frame_num:04}' if animation_mode != "None" else i + filename = f'{args.batch_name}({args.batchNum})_{save_num}.png' + else: + #If we're working with percentages, append it + if args.steps_per_checkpoint is not None: + filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png' + # Or else, iIf we're working with specific steps, append those + else: + filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png' + image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1)) + if j % args.display_rate == 0 or cur_t == -1: + image.save('progress.png') + #display.clear_output(wait=True) + #display.display(display.Image('progress.png')) + if args.steps_per_checkpoint is not None: + if j % args.steps_per_checkpoint == 0 and j > 0: + if args.intermediates_in_subfolder is True: + image.save(f'{partialFolder}/{filename}') + else: + image.save(f'{batchFolder}/{filename}') + else: + if j in args.intermediate_saves: + if args.intermediates_in_subfolder is True: + image.save(f'{partialFolder}/{filename}') + else: + image.save(f'{batchFolder}/{filename}') + if cur_t == -1: + if frame_num == 0: + save_settings() + if args.animation_mode != "None": + image.save('prevFrame.png') + if args.sharpen_preset != "Off" and animation_mode == "None": + imgToSharpen = image + if args.keep_unsharp is True: + image.save(f'{unsharpenFolder}/{filename}') + else: + image.save(f'{batchFolder}/{filename}') + # if frame_num != args.max_frames-1: + # display.clear_output() + """ + if itt % args2.update == 0 or cur_t == -1 or itt == 1: + for k, image in enumerate(sample["pred_xstart"]): + sys.stdout.flush() + sys.stdout.write("Saving progress ...\n") + sys.stdout.flush() + + image = TF.to_pil_image( + image.add(1).div(2).clamp(0, 1) + ) + + if args.animation_mode != "None": + image.save("prevFrame.png") + + image.save(args2.image_file) + if (args2.frame_dir is not None) and ( + args.animation_mode == "None" + ): + import os + 
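+                                        # Frames are saved as FRA#####.PNG (fixed 12-character
+                                        # names): scan frame_dir for the highest existing index
+                                        # and continue numbering from there (or start at 00001).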
+ file_list = [] + for file in sorted(os.listdir(args2.frame_dir)): + if file.startswith("FRA"): + if file.endswith("PNG"): + if len(file) == 12: + file_list.append(file) + if file_list: + last_name = file_list[-1] + count_value = int(last_name[3:8]) + 1 + count_string = f"{count_value:05d}" + else: + count_string = "00001" + save_name = ( + args2.frame_dir + + "/FRA" + + count_string + + ".PNG" + ) + image.save(save_name) + + # sys.stdout.flush() + # sys.stdout.write(f'{itt}/{args2.iterations} {skip_steps} {args.animation_mode} {args2.frame_dir}\n') + # sys.stdout.flush() + if ( + (args2.frame_dir is not None) + and (args.animation_mode == "3D") + and (itt == args2.iterations - skip_steps) + ): + sys.stdout.flush() + sys.stdout.write("Saving 3D frame...\n") + sys.stdout.flush() + import os + + file_list = [] + for file in os.listdir(args2.frame_dir): + if file.startswith("FRA"): + if file.endswith("PNG"): + if len(file) == 12: + file_list.append(file) + if file_list: + last_name = file_list[-1] + count_value = int(last_name[3:8]) + 1 + count_string = f"{count_value:05d}" + else: + count_string = "00001" + save_name = ( + args2.frame_dir + + "/FRA" + + count_string + + ".PNG" + ) + image.save(save_name) + + imageLocation.image(Image.open(args2.image_file)) + sys.stdout.flush() + sys.stdout.write("Progress saved\n") + sys.stdout.flush() + itt += 1 + j += 1 + time_past_seconds = time.perf_counter() - before_start_time + iterations_per_second = j / time_past_seconds + time_left = (total_steps - j) / iterations_per_second + percentage = round((j / (total_steps + 1)) * 100) + + iteration_counter.write( + f"{percentage}% {j}/{total_steps+1} [{time.strftime('%M:%S', time.gmtime(time_past_seconds))}<{time.strftime('%M:%S', time.gmtime(time_left))}, {round(iterations_per_second,2)} it/s]" + ) + progress_bar.progress(int(percentage)) + + # if path_exists(drive_path): + + except KeyboardInterrupt: + pass + # except st.script_runner.StopException as e: + # imageLocation.image(args2.image_file) + # gc.collect() + # torch.cuda.empty_cache() + # status.write("Done!") + # pass + imageLocation.empty() + with image_display: + if args.sharpen_preset != "Off" and animation_mode == "None": + print("Starting Diffusion Sharpening...") + do_superres(imgToSharpen, f"{batchFolder}/{filename}") + display.clear_output() + + import shutil + from pathvalidate import sanitize_filename + import os + + if not path_exists(DefaultPaths.output_path): + os.makedirs(DefaultPaths.output_path) + save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}.png" + print(save_filename) + file_list = [] + if path_exists(save_filename): + for file in sorted(os.listdir(f"{DefaultPaths.output_path}/")): + if file.startswith( + f"{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}" + ): + print(file) + file_list.append(file) + print(file_list) + last_name = file_list[-1] + print(last_name) + if last_name[-15:-10] == "batch": + count_value = int(last_name[-10:-4]) + 1 + count_string = f"{count_value:05d}" + save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}_batch {count_string}.png" + else: + save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}_batch 00001.png" + shutil.copyfile( + args2.image_file, + save_filename, + ) + imageLocation.empty() + status.write("Done!") + plt.plot(np.array(loss_values), "r") + + def save_settings(): + setting_list = { + 
"text_prompts": text_prompts, + "image_prompts": image_prompts, + "clip_guidance_scale": clip_guidance_scale, + "tv_scale": tv_scale, + "range_scale": range_scale, + "sat_scale": sat_scale, + # 'cutn': cutn, + "cutn_batches": cutn_batches, + "max_frames": max_frames, + "interp_spline": interp_spline, + # 'rotation_per_frame': rotation_per_frame, + "init_image": init_image, + "init_scale": init_scale, + "skip_steps": skip_steps, + # 'zoom_per_frame': zoom_per_frame, + "frames_scale": frames_scale, + "frames_skip_steps": frames_skip_steps, + "perlin_init": perlin_init, + "perlin_mode": perlin_mode, + "skip_augs": skip_augs, + "randomize_class": randomize_class, + "clip_denoised": clip_denoised, + "clamp_grad": clamp_grad, + "clamp_max": clamp_max, + "seed": seed, + "fuzzy_prompt": fuzzy_prompt, + "rand_mag": rand_mag, + "eta": eta, + "width": width_height[0], + "height": width_height[1], + "diffusion_model": diffusion_model, + "use_secondary_model": use_secondary_model, + "steps": steps, + "diffusion_steps": diffusion_steps, + "ViTB32": ViTB32, + "ViTB16": ViTB16, + "ViTL14": ViTL14, + "RN101": RN101, + "RN50": RN50, + "RN50x4": RN50x4, + "RN50x16": RN50x16, + "RN50x64": RN50x64, + "cut_overview": str(cut_overview), + "cut_innercut": str(cut_innercut), + "cut_ic_pow": cut_ic_pow, + "cut_icgray_p": str(cut_icgray_p), + "key_frames": key_frames, + "max_frames": max_frames, + "angle": angle, + "zoom": zoom, + "translation_x": translation_x, + "translation_y": translation_y, + "translation_z": translation_z, + "rotation_3d_x": rotation_3d_x, + "rotation_3d_y": rotation_3d_y, + "rotation_3d_z": rotation_3d_z, + "midas_depth_model": midas_depth_model, + "midas_weight": midas_weight, + "near_plane": near_plane, + "far_plane": far_plane, + "fov": fov, + "padding_mode": padding_mode, + "sampling_mode": sampling_mode, + "video_init_path": video_init_path, + "extract_nth_frame": extract_nth_frame, + "turbo_mode": turbo_mode, + "turbo_steps": turbo_steps, + } + # print('Settings:', setting_list) + with open( + f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+" + ) as f: # save settings + json.dump(setting_list, f, ensure_ascii=False, indent=4) + + # @title 1.6 Define the secondary diffusion model + + def append_dims(x, n): + return x[(Ellipsis, *(None,) * (n - x.ndim))] + + def expand_to_planes(x, shape): + return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]]) + + def alpha_sigma_to_t(alpha, sigma): + return torch.atan2(sigma, alpha) * 2 / math.pi + + def t_to_alpha_sigma(t): + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + @dataclass + class DiffusionOutput: + v: torch.Tensor + pred: torch.Tensor + eps: torch.Tensor + + class ConvBlock(nn.Sequential): + def __init__(self, c_in, c_out): + super().__init__( + nn.Conv2d(c_in, c_out, 3, padding=1), + nn.ReLU(inplace=True), + ) + + class SkipBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return torch.cat([self.main(input), self.skip(input)], dim=1) + + class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.0): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter( + torch.randn([out_features // 2, in_features]) * std + ) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + class SecondaryDiffusionImageNet(nn.Module): + def __init__(self): + 
super().__init__() + c = 64 # The base channel count + + self.timestep_embed = FourierFeatures(1, 16) + + self.net = nn.Sequential( + ConvBlock(3 + 16, c), + ConvBlock(c, c), + SkipBlock( + [ + nn.AvgPool2d(2), + ConvBlock(c, c * 2), + ConvBlock(c * 2, c * 2), + SkipBlock( + [ + nn.AvgPool2d(2), + ConvBlock(c * 2, c * 4), + ConvBlock(c * 4, c * 4), + SkipBlock( + [ + nn.AvgPool2d(2), + ConvBlock(c * 4, c * 8), + ConvBlock(c * 8, c * 4), + nn.Upsample( + scale_factor=2, + mode="bilinear", + align_corners=False, + ), + ] + ), + ConvBlock(c * 8, c * 4), + ConvBlock(c * 4, c * 2), + nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ), + ] + ), + ConvBlock(c * 4, c * 2), + ConvBlock(c * 2, c), + nn.Upsample( + scale_factor=2, mode="bilinear", align_corners=False + ), + ] + ), + ConvBlock(c * 2, c), + nn.Conv2d(c, 3, 3, padding=1), + ) + + def forward(self, input, t): + timestep_embed = expand_to_planes( + self.timestep_embed(t[:, None]), input.shape + ) + v = self.net(torch.cat([input, timestep_embed], dim=1)) + alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) + pred = input * alphas - v * sigmas + eps = input * sigmas + v * alphas + return DiffusionOutput(v, pred, eps) + + class SecondaryDiffusionImageNet2(nn.Module): + def __init__(self): + super().__init__() + c = 64 # The base channel count + cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8] + + self.timestep_embed = FourierFeatures(1, 16) + self.down = nn.AvgPool2d(2) + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + + self.net = nn.Sequential( + ConvBlock(3 + 16, cs[0]), + ConvBlock(cs[0], cs[0]), + SkipBlock( + [ + self.down, + ConvBlock(cs[0], cs[1]), + ConvBlock(cs[1], cs[1]), + SkipBlock( + [ + self.down, + ConvBlock(cs[1], cs[2]), + ConvBlock(cs[2], cs[2]), + SkipBlock( + [ + self.down, + ConvBlock(cs[2], cs[3]), + ConvBlock(cs[3], cs[3]), + SkipBlock( + [ + self.down, + ConvBlock(cs[3], cs[4]), + ConvBlock(cs[4], cs[4]), + SkipBlock( + [ + self.down, + ConvBlock(cs[4], cs[5]), + ConvBlock(cs[5], cs[5]), + ConvBlock(cs[5], cs[5]), + ConvBlock(cs[5], cs[4]), + self.up, + ] + ), + ConvBlock(cs[4] * 2, cs[4]), + ConvBlock(cs[4], cs[3]), + self.up, + ] + ), + ConvBlock(cs[3] * 2, cs[3]), + ConvBlock(cs[3], cs[2]), + self.up, + ] + ), + ConvBlock(cs[2] * 2, cs[2]), + ConvBlock(cs[2], cs[1]), + self.up, + ] + ), + ConvBlock(cs[1] * 2, cs[1]), + ConvBlock(cs[1], cs[0]), + self.up, + ] + ), + ConvBlock(cs[0] * 2, cs[0]), + nn.Conv2d(cs[0], 3, 3, padding=1), + ) + + def forward(self, input, t): + timestep_embed = expand_to_planes( + self.timestep_embed(t[:, None]), input.shape + ) + v = self.net(torch.cat([input, timestep_embed], dim=1)) + alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t)) + pred = input * alphas - v * sigmas + eps = input * sigmas + v * alphas + return DiffusionOutput(v, pred, eps) + + # 2. 
Diffusion and CLIP model settings""" + + if args2.use256 == 0: + sys.stdout.write("Loading 512x512_diffusion_uncond_finetune_008100 ...\n") + sys.stdout.flush() + status.write("Loading 512x512_diffusion_uncond_finetune_008100 ...\n") + diffusion_model = "512x512_diffusion_uncond_finetune_008100" # @param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"] + else: + sys.stdout.write("Loading 256x256_diffusion_uncond ...\n") + sys.stdout.flush() + status.write("Loading 256x256_diffusion_uncond ...\n") + diffusion_model = "256x256_diffusion_uncond" + + if args2.secondarymodel == 1: + use_secondary_model = True # @param {type: 'boolean'} + else: + use_secondary_model = False # @param {type: 'boolean'} + + # timestep_respacing = '50' # param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] + if args2.sampling_mode == "ddim" or args2.sampling_mode == "plms": + timestep_respacing = "ddim" + str( + args2.iterations + ) #'ddim100' # Modify this value to decrease the number of timesteps. + else: + timestep_respacing = str( + args2.iterations + ) #'ddim100' # Modify this value to decrease the number of timesteps. + + diffusion_steps = 1000 # param {type: 'number'} + + use_checkpoint = True # @param {type: 'boolean'} + + # @markdown If you're having issues with model downloads, check this to compare SHA's: + check_model_SHA = False # @param{type:"boolean"} + + model_256_SHA = "983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a" + model_512_SHA = "9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648" + model_secondary_SHA = ( + "983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a" + ) + + model_256_link = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt" + model_512_link = "https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt" + model_secondary_link = ( + "https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth" + ) + + model_256_path = f"{DefaultPaths.model_path}/256x256_diffusion_uncond.pt" + model_512_path = ( + f"{DefaultPaths.model_path}/512x512_diffusion_uncond_finetune_008100.pt" + ) + model_secondary_path = f"{DefaultPaths.model_path}/secondary_model_imagenet_2.pth" + + model_256_downloaded = True + model_512_downloaded = True + model_secondary_downloaded = True + + model_config = model_and_diffusion_defaults() + if diffusion_model == "512x512_diffusion_uncond_finetune_008100": + model_config.update( + { + "attention_resolutions": "32, 16, 8", + "class_cond": False, + "diffusion_steps": diffusion_steps, + "rescale_timesteps": True, + "timestep_respacing": timestep_respacing, + "image_size": 512, + "learn_sigma": True, + "noise_schedule": "linear", + "num_channels": 256, + "num_head_channels": 64, + "num_res_blocks": 2, + "resblock_updown": True, + "use_checkpoint": use_checkpoint, + "use_fp16": True, + "use_scale_shift_norm": True, + } + ) + elif diffusion_model == "256x256_diffusion_uncond": + model_config.update( + { + "attention_resolutions": "32, 16, 8", + "class_cond": False, + "diffusion_steps": diffusion_steps, + "rescale_timesteps": True, + "timestep_respacing": timestep_respacing, + "image_size": 256, + "learn_sigma": True, + "noise_schedule": "linear", + "num_channels": 256, + "num_head_channels": 64, + "num_res_blocks": 2, + "resblock_updown": True, + "use_checkpoint": use_checkpoint, + "use_fp16": True, + "use_scale_shift_norm": True, + } + ) + + 
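+    # Note on respacing (an informal gloss, matching the tutorial above): the guided-diffusion
+    # checkpoints are trained with diffusion_steps=1000, and "timestep_respacing" (e.g. "ddim100")
+    # re-spaces those 1000 steps down to the number of sampling steps actually run, which is why
+    # it is derived from args2.iterations earlier in this function.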
secondary_model_ver = 2 + model_default = model_config["image_size"] + + if secondary_model_ver == 2: + secondary_model = SecondaryDiffusionImageNet2() + secondary_model.load_state_dict( + torch.load( + f"{DefaultPaths.model_path}/secondary_model_imagenet_2.pth", + map_location="cpu", + ) + ) + secondary_model.eval().requires_grad_(False).to(device) + + clip_models = [] + if args2.usevit32 == 1: + sys.stdout.write("Loading ViT-B/32 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading ViT-B/32 CLIP model ...\n") + clip_models.append( + clip.load("ViT-B/32", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usevit16 == 1: + sys.stdout.write("Loading ViT-B/16 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading ViT-B/16 CLIP model ...\n") + clip_models.append( + clip.load("ViT-B/16", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usevit14 == 1: + sys.stdout.write("Loading ViT-L/14 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading ViT-L/14 CLIP model ...\n") + clip_models.append( + clip.load("ViT-L/14", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usern50x4 == 1: + sys.stdout.write("Loading RN50x4 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading RN50x4 CLIP model ...\n") + clip_models.append( + clip.load("RN50x4", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usern50x16 == 1: + sys.stdout.write("Loading RN50x16 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading RN50x16 CLIP model ...\n") + clip_models.append( + clip.load("RN50x16", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usern50x64 == 1: + sys.stdout.write("Loading RN50x64 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading RN50x64 CLIP model ...\n") + clip_models.append( + clip.load("RN50x64", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usern50 == 1: + sys.stdout.write("Loading RN50 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading RN50 CLIP model ...\n") + clip_models.append( + clip.load("RN50", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.usern101 == 1: + sys.stdout.write("Loading RN101 CLIP model ...\n") + sys.stdout.flush() + status.write("Loading RN101 CLIP model ...\n") + clip_models.append( + clip.load("RN101", jit=False)[0].eval().requires_grad_(False).to(device) + ) + if args2.useslipbase == 1: + sys.stdout.write("Loading SLIP Base model ...\n") + sys.stdout.flush() + SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256) + # next 2 lines needed so torch.load handles posix paths on Windows + import pathlib + + pathlib.PosixPath = pathlib.WindowsPath + sd = torch.load("slip_base_100ep.pt") + real_sd = {} + for k, v in sd["state_dict"].items(): + real_sd[".".join(k.split(".")[1:])] = v + del sd + SLIPB16model.load_state_dict(real_sd) + SLIPB16model.requires_grad_(False).eval().to(device) + clip_models.append(SLIPB16model) + if args2.usesliplarge == 1: + sys.stdout.write("Loading SLIP Large model ...\n") + sys.stdout.flush() + SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256) + # next 2 lines needed so torch.load handles posix paths on Windows + import pathlib + + pathlib.PosixPath = pathlib.WindowsPath + sd = torch.load("slip_large_100ep.pt") + real_sd = {} + for k, v in sd["state_dict"].items(): + real_sd[".".join(k.split(".")[1:])] = v + del sd + SLIPL16model.load_state_dict(real_sd) + SLIPL16model.requires_grad_(False).eval().to(device) + 
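+        # NOTE: SLIP_VITB16 / SLIP_VITL16 come from the SLIP repo's `models` module; the
+        # corresponding import (and the ./SLIP sys.path entry) is commented out near the
+        # top of this file, so these SLIP branches only work once it is restored.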
clip_models.append(SLIPL16model) + + normalize = T.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ) + status.write("Loading lpips model...\n") + lpips_model = lpips.LPIPS(net="vgg").to(device) + + """# 3. Settings""" + + # sys.stdout.write("DEBUG0 ...\n") + # sys.stdout.flush() + + # @markdown ####**Basic Settings:** + batch_name = "TimeToDisco" # @param{type: 'string'} + steps = ( + args2.iterations + ) # @param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true} + width_height = [args2.sizex, args2.sizey] # @param{type: 'raw'} + clip_guidance_scale = args2.guidancescale # @param{type: 'number'} + tv_scale = args2.tvscale # @param{type: 'number'} + range_scale = args2.rangescale # @param{type: 'number'} + sat_scale = args2.saturationscale # @param{type: 'number'} + cutn_batches = args2.cutnbatches # @param{type: 'number'} + + if args2.useaugs == 1: + skip_augs = False # False - Controls whether to skip torchvision augmentations + else: + skip_augs = True # False - Controls whether to skip torchvision augmentations + + # @markdown ####**Init Settings:** + if args2.seed_image is not None: + init_image = ( + args2.seed_image + ) # This can be an URL or Colab local path and must be in quotes. + skip_steps = ( + args2.skipseedtimesteps + ) # 12 Skip unstable steps # Higher values make the output look more like the init. + init_scale = ( + args2.initscale + ) # This enhances the effect of the init image, a good value is 1000. + else: + init_image = "" # This can be an URL or Colab local path and must be in quotes. + skip_steps = 0 # 12 Skip unstable steps # Higher values make the output look more like the init. + init_scale = ( + 0 # This enhances the effect of the init image, a good value is 1000. 
+ ) + + if init_image == "": + init_image = None + + side_x = args2.sizex + side_y = args2.sizey + + # Update Model Settings + # timestep_respacing = f'ddim{steps}' + diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps + model_config.update( + { + "timestep_respacing": timestep_respacing, + "diffusion_steps": diffusion_steps, + } + ) + + # Make folder for batch + batchFolder = f"./" + # createPath(batchFolder) + + # sys.stdout.write("DEBUG1 ...\n") + # sys.stdout.flush() + + """###Animation Settings""" + + # @markdown ####**Animation Mode:** + animation_mode = ( + args2.animation_mode + ) #'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'} + # @markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.* + + # @markdown --- + + # @markdown ####**Video Input Settings:** + video_init_path = "training.mp4" # "D:\\sample_cat.mp4" #@param {type: 'string'} + extract_nth_frame = 2 # @param {type:"number"} + + # sys.stdout.write("DEBUG1a ...\n") + # sys.stdout.flush() + + if animation_mode == "Video Input": + videoFramesFolder = "./videoFrames" + # createPath(videoFramesFolder) + # print(f"Exporting Video Frames (1 every {extract_nth_frame})...") + sys.stdout.write(f"Exporting Video Frames (1 every {extract_nth_frame})...\n") + sys.stdout.flush() + + """ + try: + !rm {videoFramesFolder}/*.jpg + except: + print('') + """ + # sys.stdout.write("DEBUG1a1 ...\n") + # sys.stdout.flush() + vf = f'"select=not(mod(n\,{extract_nth_frame}))"' + # sys.stdout.write("DEBUG1a2 ...\n") + # sys.stdout.flush() + os.system( + f"ffmpeg.exe -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg" + ) + # sys.stdout.write("DEBUG1a3 ...\n") + # sys.stdout.flush() + + # sys.stdout.write("DEBUG1b ...\n") + # sys.stdout.flush() + + # @markdown --- + + # @markdown ####**2D Animation Settings:** + # @markdown `zoom` is a multiplier of dimensions, 1 is no zoom. + + key_frames = True # @param {type:"boolean"} + max_frames = args2.max_frames # 10000#@param {type:"number"} + + # sys.stdout.write("DEBUG1c ...\n") + # sys.stdout.flush() + + if animation_mode == "Video Input": + max_frames = len(glob(f"{videoFramesFolder}/*.jpg")) + + # sys.stdout.write("DEBUG1d ...\n") + # sys.stdout.flush() + + interp_spline = "Linear" # Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:"string"} + angle = args2.angle # "0:(0)"#@param {type:"string"} + zoom = args2.zoom # "0: (1), 10: (1.05)"#@param {type:"string"} + translation_x = args2.translation_x # "0: (0)"#@param {type:"string"} + translation_y = args2.translation_y # "0: (0)"#@param {type:"string"} + translation_z = args2.translation_z # "0: (10.0)"#@param {type:"string"} + rotation_3d_x = args2.rotation_3d_x # "0: (0)"#@param {type:"string"} + rotation_3d_y = args2.rotation_3d_y # "0: (0)"#@param {type:"string"} + rotation_3d_z = args2.rotation_3d_z # "0: (0)"#@param {type:"string"} + midas_depth_model = "dpt_large" # @param {type:"string"} + midas_weight = args2.midas_weight # 0.3#@param {type:"number"} + near_plane = args2.near_plane # 200#@param {type:"number"} + far_plane = args2.far_plane # 10000#@param {type:"number"} + fov = args2.fov # 40#@param {type:"number"} + padding_mode = "border" # @param {type:"string"} + sampling_mode = args2.sampling_mode # @param {type:"string"} + # @markdown ####**Coherency Settings:** + # @markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500. 
+    # @markdown ####**Coherency Settings:**
+    # @markdown `frames_scale` tries to guide the new frame to look like the old one. A good default is 1500.
+    frames_scale = args2.frames_scale  # 1500 #@param{type: 'integer'}
+    # @markdown `frames_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.
+    frames_skip_steps = args2.frames_skip_steps  # '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}
+
+    turbo_mode = args2.turbo_mode == 1  # @param {type:"boolean"}
+    turbo_steps = args2.turbo_steps  # "3" #@param ["2","3","4"] {type:'string'}
+    # @markdown ---
+
+    def parse_key_frames(string, prompt_parser=None):
+        """Given a string representing frame numbers paired with parameter values at that frame,
+        return a dictionary with the frame numbers as keys and the parameter values as the values.
+
+        Parameters
+        ----------
+        string: string
+            Frame numbers paired with parameter values at that frame number, in the format
+            'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'
+        prompt_parser: function or None, optional
+            If provided, prompt_parser will be applied to each string of parameter values.
+
+        Returns
+        -------
+        dict
+            Frame numbers as keys, parameter values at that frame number as values
+
+        Raises
+        ------
+        RuntimeError
+            If the input string does not match the expected format.
+
+        Examples
+        --------
+        >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
+        {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}
+
+        >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())
+        {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
+        """
+        import re
+
+        # The named groups <frame> and <param> are consumed via groupdict() below.
+        pattern = r"((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])"
+        frames = dict()
+        for match_object in re.finditer(pattern, string):
+            frame = int(match_object.groupdict()["frame"])
+            param = match_object.groupdict()["param"]
+            if prompt_parser:
+                frames[frame] = prompt_parser(param)
+            else:
+                frames[frame] = param
+
+        if frames == {} and len(string) != 0:
+            raise RuntimeError("Key Frame string not correctly formatted")
+        return frames
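+    # For instance (per the docstring above), parse_key_frames("0: (0), 10: (5)")
+    # returns {0: '0', 10: '5'}; get_inbetweens below then interpolates those
+    # values across all max_frames frames.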
+    def get_inbetweens(key_frames, integer=False):
+        """Given a dict with frame numbers as keys and a parameter value as values,
+        return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
+        Any values not provided in the input dict are calculated by linear interpolation between
+        the values of the previous and next provided frames. If there is no previous provided frame, then
+        the value is equal to the value of the next provided frame, or if there is no next provided frame,
+        then the value is equal to the value of the previous provided frame. If no frames are provided,
+        all frame values are NaN.
+
+        Parameters
+        ----------
+        key_frames: dict
+            A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
+        integer: Bool, optional
+            If True, the values of the output series are converted to integers.
+            Otherwise, the values are floats.
+
+        Returns
+        -------
+        pd.Series
+            A Series with length max_frames representing the parameter values for each frame.
+
+        Examples
+        --------
+        >>> max_frames = 5
+        >>> get_inbetweens({1: 5, 3: 6})
+        0    5.0
+        1    5.0
+        2    5.5
+        3    6.0
+        4    6.0
+        dtype: float64
+
+        >>> get_inbetweens({1: 5, 3: 6}, integer=True)
+        0    5
+        1    5
+        2    5
+        3    6
+        4    6
+        dtype: int64
+        """
+        key_frame_series = pd.Series([np.nan for a in range(max_frames)])
+
+        for i, value in key_frames.items():
+            key_frame_series[i] = value
+        key_frame_series = key_frame_series.astype(float)
+
+        interp_method = interp_spline
+
+        # Fall back to a simpler interpolation when there are too few key frames.
+        if interp_method == "Cubic" and len(key_frames.items()) <= 3:
+            interp_method = "Quadratic"
+
+        if interp_method == "Quadratic" and len(key_frames.items()) <= 2:
+            interp_method = "Linear"
+
+        key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
+        key_frame_series[max_frames - 1] = key_frame_series[
+            key_frame_series.last_valid_index()
+        ]
+        key_frame_series = key_frame_series.interpolate(
+            method=interp_method.lower(), limit_direction="both"
+        )
+        if integer:
+            return key_frame_series.astype(int)
+        return key_frame_series
+
+    def split_prompts(prompts):
+        prompt_series = pd.Series([np.nan for a in range(max_frames)])
+        for i, prompt in prompts.items():
+            prompt_series[i] = prompt
+        prompt_series = prompt_series.ffill().bfill()
+        return prompt_series
+
+    if key_frames:
+
+        def get_series(name, value):
+            """Parse a key-frame schedule string into a per-frame series, falling
+            back to '0: (value)' with a warning when the string is not in
+            key-frame format. Returns the (possibly rewritten) string and its
+            interpolated series; behavior matches the original per-parameter
+            try/except blocks, which were identical apart from the name."""
+            try:
+                return value, get_inbetweens(parse_key_frames(value))
+            except RuntimeError:
+                print(
+                    "WARNING: You have selected to use key frames, but you have not "
+                    f"formatted `{name}` correctly for key frames.\n"
+                    f'Attempting to interpret `{name}` as "0: ({value})"\n'
+                    "Please read the instructions to find out how to use key frames "
+                    "correctly.\n"
+                )
+                value = f"0: ({value})"
+                return value, get_inbetweens(parse_key_frames(value))
+
+        angle, angle_series = get_series("angle", angle)
+        zoom, zoom_series = get_series("zoom", zoom)
+        translation_x, translation_x_series = get_series("translation_x", translation_x)
+        translation_y, translation_y_series = get_series("translation_y", translation_y)
+        translation_z, translation_z_series = get_series("translation_z", translation_z)
+        rotation_3d_x, rotation_3d_x_series = get_series("rotation_3d_x", rotation_3d_x)
+        rotation_3d_y, rotation_3d_y_series = get_series("rotation_3d_y", rotation_3d_y)
+        rotation_3d_z, rotation_3d_z_series = get_series("rotation_3d_z", rotation_3d_z)
+
+    else:
+        # NOTE: the *_series variables are only defined when key_frames is True
+        # (it is hardcoded True above); the args dict below relies on them.
+        angle = float(angle)
+        zoom = float(zoom)
+        translation_x = float(translation_x)
+        translation_y = float(translation_y)
+        translation_z = float(translation_z)
+        rotation_3d_x = float(rotation_3d_x)
+        rotation_3d_y = float(rotation_3d_y)
+        rotation_3d_z = float(rotation_3d_z)
+
+    """### Extra Settings
+    Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling
+    """
+
+    # @markdown ####**Saving:**
+
+    intermediate_saves = 0  # @param{type: 'raw'}
+    intermediates_in_subfolder = True  # @param{type: 'boolean'}
+    # @markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps.
+
+    # @markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.
+
+    # @markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets.)
+
+    if type(intermediate_saves) is not list:
+        if intermediate_saves:
+            steps_per_checkpoint = (steps - skip_steps - 1) // (intermediate_saves + 1)
+            steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1
+            print(f"Will save every {steps_per_checkpoint} steps")
+        else:
+            # Push the checkpoint interval past the end of the run so nothing is saved.
+            steps_per_checkpoint = steps + 10
+    else:
+        steps_per_checkpoint = None
+
+    if intermediate_saves and intermediates_in_subfolder is True:
+        partialFolder = f"{batchFolder}/partials"
+        createPath(partialFolder)
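+    # Worked example: with steps=250, skip_steps=0 and intermediate_saves=2,
+    # steps_per_checkpoint = (250 - 0 - 1) // 3 = 83, i.e. partials land at
+    # roughly 33% and 66% of the run, matching the description above.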
+    # @markdown ---
+
+    # @markdown ####**SuperRes Sharpening:**
+    # @markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*
+    sharpen_preset = "Off"  # @param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']
+    keep_unsharp = True  # @param{type: 'boolean'}
+
+    if sharpen_preset != "Off" and keep_unsharp is True:
+        unsharpenFolder = f"{batchFolder}/unsharpened"
+        createPath(unsharpenFolder)
+
+    # @markdown ---
+
+    # @markdown ####**Advanced Settings:**
+    # @markdown *There are a few extra advanced settings available if you double click this cell.*
+
+    # @markdown *Perlin init will replace your init, so uncheck if using one.*
+
+    perlin_init = args2.perlin_init == 1  # @param{type: 'boolean'}
+    perlin_mode = args2.perlin_mode  # 'mixed' #@param ['mixed', 'color', 'gray']
+
+    set_seed = "random_seed"  # @param{type: 'string'}
+    eta = args2.eta  # @param{type: 'number'}
+    clamp_grad = True  # @param{type: 'boolean'}
+    clamp_max = args2.clampmax  # @param{type: 'number'}
+
+    ### EXTRA ADVANCED SETTINGS:
+    randomize_class = True
+    clip_denoised = args2.denoised == 1
+    fuzzy_prompt = False
+    rand_mag = 0.05
+
+    # @markdown ---
+
+    # @markdown ####**Cutn Scheduling:**
+    # @markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400/1000 steps, then 20 for the last 600/1000.
+
+    # @markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure; innercuts are your standard cutn.
+
+    # These schedule strings are eval'ed into per-step lists when args is built below.
+    cut_overview = "[12]*400+[4]*600"  # @param {type: 'string'}
+    cut_innercut = "[4]*400+[12]*600"  # @param {type: 'string'}
+    cut_ic_pow = 1  # @param {type: 'number'}
+    cut_icgray_p = "[0.2]*400+[0]*600"  # @param {type: 'string'}
+
+    """###Prompts
+    `animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one.
+    """
+
+    """
+    text_prompts = {
+        0: ["A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.", "yellow color scheme"],
+        100: ["This set of prompts start at frame 100", "This prompt has weight five:5"],
+    }
+    """
+
+    text_prompts = {0: [phrase.strip() for phrase in args2.prompt.split("|")]}
+
+    image_prompts = {
+        # 0: ['ImagePromptsWorkButArentVeryGood.png:2',],
+    }
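+    # Example (hypothetical prompt string): args2.prompt =
+    # "a misty forest at dawn | golden color scheme:5" becomes
+    # text_prompts = {0: ["a misty forest at dawn", "golden color scheme:5"]},
+    # where the ":5" suffix is the per-phrase weight shown in the template above.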
+    """# 4. Diffuse!"""
+
+    # @title Do the Run!
+    # @markdown `n_batches` ignored with animation modes.
+    display_rate = args2.update  # @param{type: 'number'}
+    n_batches = 1  # @param{type: 'number'}
+
+    batch_size = 1
+
+    def move_files(start_num, end_num, old_folder, new_folder):
+        for i in range(start_num, end_num):
+            old_file = old_folder + f"/{batch_name}({batchNum})_{i:04}.png"
+            new_file = new_folder + f"/{batch_name}({batchNum})_{i:04}.png"
+            os.rename(old_file, new_file)
+
+    # @markdown ---
+
+    resume_run = False  # @param{type: 'boolean'}
+    run_to_resume = "latest"  # @param{type: 'string'}
+    resume_from_frame = "latest"  # @param{type: 'string'}
+    retain_overwritten_frames = False  # @param{type: 'boolean'}
+    if retain_overwritten_frames is True:
+        retainFolder = f"{batchFolder}/retained"
+        createPath(retainFolder)
+
+    # E.g. frames_skip_steps="60%" gives skip_step_ratio=0.6, so with steps=150
+    # each animation frame skips floor(150 * 0.6) = 90 steps and diffuses the rest.
+    skip_step_ratio = int(frames_skip_steps.rstrip("%")) / 100
+    calc_frames_skip_steps = math.floor(steps * skip_step_ratio)
+
+    if steps <= calc_frames_skip_steps:
+        sys.exit("ERROR: You can't skip more steps than your total steps")
+
+    """
+    if resume_run:
+        if run_to_resume == 'latest':
+            try:
+                batchNum
+            except:
+                batchNum = len(glob(f"{batchFolder}/{batch_name}(*)_settings.txt")) - 1
+        else:
+            batchNum = int(run_to_resume)
+        if resume_from_frame == 'latest':
+            start_frame = len(glob(batchFolder + f"/{batch_name}({batchNum})_*.png"))
+        else:
+            start_frame = int(resume_from_frame) + 1
+        if retain_overwritten_frames is True:
+            existing_frames = len(glob(batchFolder + f"/{batch_name}({batchNum})_*.png"))
+            frames_to_save = existing_frames - start_frame
+            print(f'Moving {frames_to_save} frames to the Retained folder')
+            move_files(start_frame, existing_frames, batchFolder, retainFolder)
+    else:
+    """
+    start_frame = 0
+    batchNum = 1
+    """
+    batchNum = len(glob(batchFolder + "/*.txt"))
+    while path.isfile(f"{batchFolder}/{batch_name}({batchNum})_settings.txt") is True or path.isfile(f"{batchFolder}/{batch_name}-{batchNum}_settings.txt") is True:
+        batchNum += 1
+    """
+    # print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')
+
+    if set_seed == "random_seed":
+        random.seed()
+        seed = random.randint(0, 2**32)
+        # print(f'Using seed: {seed}')
+    else:
+        seed = int(set_seed)
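+    # Example: set_seed = "2587594382" (any integer string) reproduces a prior
+    # run, while the default "random_seed" draws a fresh seed in [0, 2**32].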
+    args = {
+        "batchNum": batchNum,
+        "prompts_series": split_prompts(text_prompts) if text_prompts else None,
+        "image_prompts_series": split_prompts(image_prompts) if image_prompts else None,
+        "seed": seed,
+        "display_rate": display_rate,
+        "n_batches": n_batches if animation_mode == "None" else 1,
+        "batch_size": batch_size,
+        "batch_name": batch_name,
+        "steps": steps,
+        "width_height": width_height,
+        "clip_guidance_scale": clip_guidance_scale,
+        "tv_scale": tv_scale,
+        "range_scale": range_scale,
+        "sat_scale": sat_scale,
+        "cutn_batches": cutn_batches,
+        "init_image": init_image,
+        "init_scale": init_scale,
+        "skip_steps": skip_steps,
+        "sharpen_preset": sharpen_preset,
+        "keep_unsharp": keep_unsharp,
+        "side_x": side_x,
+        "side_y": side_y,
+        "timestep_respacing": timestep_respacing,
+        "diffusion_steps": diffusion_steps,
+        "animation_mode": animation_mode,
+        "video_init_path": video_init_path,
+        "extract_nth_frame": extract_nth_frame,
+        "key_frames": key_frames,
+        "max_frames": max_frames if animation_mode != "None" else 1,
+        "interp_spline": interp_spline,
+        "start_frame": start_frame,
+        "angle": angle,
+        "zoom": zoom,
+        "translation_x": translation_x,
+        "translation_y": translation_y,
+        "translation_z": translation_z,
+        "rotation_3d_x": rotation_3d_x,
+        "rotation_3d_y": rotation_3d_y,
+        "rotation_3d_z": rotation_3d_z,
+        "midas_depth_model": midas_depth_model,
+        "midas_weight": midas_weight,
+        "near_plane": near_plane,
+        "far_plane": far_plane,
+        "fov": fov,
+        "padding_mode": padding_mode,
+        "sampling_mode": sampling_mode,
+        "angle_series": angle_series,
+        "zoom_series": zoom_series,
+        "translation_x_series": translation_x_series,
+        "translation_y_series": translation_y_series,
+        "translation_z_series": translation_z_series,
+        "rotation_3d_x_series": rotation_3d_x_series,
+        "rotation_3d_y_series": rotation_3d_y_series,
+        "rotation_3d_z_series": rotation_3d_z_series,
+        "frames_scale": frames_scale,
+        "skip_step_ratio": skip_step_ratio,
+        "calc_frames_skip_steps": calc_frames_skip_steps,
+        "text_prompts": text_prompts,
+        "image_prompts": image_prompts,
+        "cut_overview": eval(cut_overview),
+        "cut_innercut": eval(cut_innercut),
+        "cut_ic_pow": cut_ic_pow,
+        "cut_icgray_p": eval(cut_icgray_p),
+        "intermediate_saves": intermediate_saves,
+        "intermediates_in_subfolder": intermediates_in_subfolder,
+        "steps_per_checkpoint": steps_per_checkpoint,
+        "perlin_init": perlin_init,
+        "perlin_mode": perlin_mode,
+        "set_seed": set_seed,
+        "eta": eta,
+        "clamp_grad": clamp_grad,
+        "clamp_max": clamp_max,
+        "skip_augs": skip_augs,
+        "randomize_class": randomize_class,
+        "clip_denoised": clip_denoised,
+        "fuzzy_prompt": fuzzy_prompt,
+        "rand_mag": rand_mag,
+    }
+
+    args = SimpleNamespace(**args)
+
+    print("Prepping model...")
+    model, diffusion = create_model_and_diffusion(**model_config)
+    model.load_state_dict(
+        torch.load(
+            f"{DefaultPaths.model_path}/{diffusion_model}.pt", map_location="cpu"
+        )
+    )
+    model.requires_grad_(False).eval().to(device)
+    # Re-enable gradients only for the attention (qkv), normalization, and
+    # projection parameters that CLIP guidance backpropagates through.
+    for name, param in model.named_parameters():
+        if "qkv" in name or "norm" in name or "proj" in name:
+            param.requires_grad_()
+    if model_config["use_fp16"]:
+        model.convert_to_fp16()
+
+    sys.stdout.write("Starting ...\n")
+    sys.stdout.flush()
+    status.write("Starting ...\n")
+
+    gc.collect()
+    torch.cuda.empty_cache()
+    try:
+        do_run()
+    # except st.script_runner.StopException as e:
+    #     print("stopped here (a bit out)")
+    #     pass
+    except KeyboardInterrupt:
+        pass
+    finally:
+        gc.collect()
+        torch.cuda.empty_cache()
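+
+# Hedged usage sketch (not in the original file): run_model expects an
+# argparse/SimpleNamespace-like `args2` carrying the fields read above
+# (iterations, sizex, sizey, guidancescale, prompt, animation_mode, ...),
+# a Streamlit status container, a stdout-like object, and a DefaultPaths
+# object exposing model_path, roughly:
+#   args2 = SimpleNamespace(seed=42, iterations=250, sizex=640, sizey=512, ...)
+#   run_model(args2, st.empty(), sys.stdout, DefaultPaths)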