diff --git "a/disco_streamlit_run.py" "b/disco_streamlit_run.py"
new file mode 100644
--- /dev/null
+++ "b/disco_streamlit_run.py"
@@ -0,0 +1,2522 @@
+# Disco Diffusion v5 [w/ 3D animation] (modified by @softology to work on Visions of Chaos and further modified by @multimodalart to run on MindsEye)
+# Adapted from the Visions of Chaos software (https://softology.pro/voc.htm), which in turn adapted it from the original notebook.
+# The original file is located at https://colab.research.google.com/github/alembics/disco-diffusion/blob/main/Disco_Diffusion.ipynb
+
+# required models
+# https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt
+# https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt
+# git clone https://github.com/isl-org/MiDaS.git
+# git clone https://github.com/alembics/disco-diffusion.git
+
+
+"""#Tutorial
+
+**Diffusion settings (Defaults are heavily outdated)**
+---
+
+This section is outdated as of v2
+
+Setting | Description | Default
+--- | --- | ---
+**Your vision:**
+`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A
+`image_prompts` | Think of these images more as a description of their contents. | N/A
+**Image quality:**
+`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000
+`tv_scale` | Controls the smoothness of the final output. | 150
+`range_scale` | Controls how far out of range RGB values are allowed to be. | 150
+`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0
+`cutn` | Controls how many crops to take from the image. | 16
+`cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2
+**Init settings:**
+`init_image` | URL or local path | None
+`init_scale` | This enhances the effect of the init image, a good value is 1000 | 0
+`skip_steps` | Controls the starting point along the diffusion timesteps | 0
+`perlin_init` | Option to start with random perlin noise | False
+`perlin_mode` | Perlin noise mode: 'gray', 'color', or 'mixed' | 'mixed'
+**Advanced:**
+`skip_augs` | Controls whether to skip torchvision augmentations | False
+`randomize_class` | Controls whether the imagenet class is randomly changed each iteration | True
+`clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image | False
+`clamp_grad` | Experimental: use adaptive gradient clipping in the cond_fn | True
+`seed` | Choose a random seed and print it at the end of the run for reproduction | random_seed
+`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False
+`rand_mag` | Controls the magnitude of the random noise | 0.1
+`eta` | DDIM hyperparameter | 0.5
+
+
+**Model settings**
+---
+
+Setting | Description | Default
+--- | --- | ---
+**Diffusion:**
+`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100
+`diffusion_steps` | Base number of diffusion timesteps | 1000
+**CLIP:**
+`clip_models` | Models of CLIP to load. Typically the more, the better, but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4
+
+# 1. Set Up
+"""
+
+
+is_colab = False
+google_drive = False
+save_models_to_google_drive = False
+
+import sys
+
+sys.stdout.write("Imports ...\n")
+sys.stdout.flush()
+
+sys.path.append("./ResizeRight")
+sys.path.append("./MiDaS")
+sys.path.append("./CLIP")
+sys.path.append("./guided-diffusion")
+sys.path.append("./latent-diffusion")
+sys.path.append(".")
+sys.path.append("./taming-transformers")
+sys.path.append("./disco-diffusion")
+sys.path.append("./AdaBins")
+sys.path.append('./pytorch3d-lite')
+# sys.path.append('./pytorch3d')
+
+import os
+import streamlit as st
+from os import path
+from os.path import exists as path_exists
+import torch
+
+# sys.path.append('./SLIP')
+from dataclasses import dataclass
+from functools import partial
+import cv2
+import pandas as pd
+import gc
+import io
+import math
+import timm
+from IPython import display
+import lpips
+from PIL import Image, ImageOps
+import requests
+from glob import glob
+import json
+from types import SimpleNamespace
+from torch import nn
+from torch.nn import functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+import shutil
+from pathvalidate import sanitize_filename
+
+# from tqdm.notebook import tqdm
+# from stqdm_local import stqdm
+import clip
+from resize_right import resize
+
+# from models import SLIP_VITB16, SLIP, SLIP_VITL16
+from guided_diffusion.script_util import (
+ create_model_and_diffusion,
+ model_and_diffusion_defaults,
+)
+from datetime import datetime
+import numpy as np
+import matplotlib.pyplot as plt
+import random
+from ipywidgets import Output
+import hashlib
+import ipywidgets as widgets
+
+# from taming.models import vqgan # checking correct import from taming
+from torchvision.datasets.utils import download_url
+from ldm.util import instantiate_from_config
+from ldm.modules.diffusionmodules.util import (
+ make_ddim_sampling_parameters,
+ make_ddim_timesteps,
+ noise_like,
+)
+
+# from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import ismap
+from IPython.display import Image as ipyimg
+from numpy import asarray
+from einops import rearrange, repeat
+import torchvision
+import time
+from omegaconf import OmegaConf
+from midas.dpt_depth import DPTDepthModel
+from midas.midas_net import MidasNet
+from midas.midas_net_custom import MidasNet_small
+from midas.transforms import Resize, NormalizeImage, PrepareForNet
+import py3d_tools as p3dT
+import disco_xform_utils as dxf
+import argparse
+
+sys.stdout.write("Parsing arguments ...\n")
+sys.stdout.flush()
+
+
+def run_model(args2, status, stoutput, DefaultPaths):
+ if args2.seed is not None:
+ sys.stdout.write(f"Setting seed to {args2.seed} ...\n")
+ sys.stdout.flush()
+ status.write(f"Setting seed to {args2.seed} ...\n")
+        np.random.seed(args2.seed)
+        random.seed(args2.seed)
+ # next line forces deterministic random values, but causes other issues with resampling (uncomment to see)
+ # torch.use_deterministic_algorithms(True)
+ torch.manual_seed(args2.seed)
+ torch.cuda.manual_seed(args2.seed)
+ torch.cuda.manual_seed_all(args2.seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ print("Using device:", DEVICE)
+    device = DEVICE  # At least one of the modules expects this name.
+
+ # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+ PROJECT_DIR = os.path.abspath(os.getcwd())
+
+ # AdaBins stuff
+ USE_ADABINS = True
+ if USE_ADABINS:
+ sys.path.append("./AdaBins")
+ from infer import InferenceHelper
+
+ MAX_ADABINS_AREA = 500000
+
+ model_256_downloaded = False
+ model_512_downloaded = False
+ model_secondary_downloaded = False
+
+ # Initialize MiDaS depth model.
+    # It remains resident in VRAM and likely takes around 2GB of VRAM.
+    # You could instead initialize it for each frame (and free it afterwards) to save VRAM, but initializing it is slow.
+ default_models = {
+ "midas_v21_small": f"{DefaultPaths.model_path}/midas_v21_small-70d6b9c8.pt",
+ "midas_v21": f"{DefaultPaths.model_path}/midas_v21-f6b98070.pt",
+ "dpt_large": f"{DefaultPaths.model_path}/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": f"{DefaultPaths.model_path}/dpt_hybrid-midas-501f0c75.pt",
+ "dpt_hybrid_nyu": f"{DefaultPaths.model_path}/dpt_hybrid_nyu-2ce69ec7.pt",
+ }
+
+ def init_midas_depth_model(midas_model_type="dpt_large", optimize=True):
+ midas_model = None
+ net_w = None
+ net_h = None
+ resize_mode = None
+ normalization = None
+
+ print(f"Initializing MiDaS '{midas_model_type}' depth model...")
+ # load network
+ midas_model_path = default_models[midas_model_type]
+
+ if midas_model_type == "dpt_large": # DPT-Large
+ midas_model = DPTDepthModel(
+ path=midas_model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif midas_model_type == "dpt_hybrid": # DPT-Hybrid
+ midas_model = DPTDepthModel(
+ path=midas_model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif midas_model_type == "dpt_hybrid_nyu": # DPT-Hybrid-NYU
+ midas_model = DPTDepthModel(
+ path=midas_model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ elif midas_model_type == "midas_v21":
+ midas_model = MidasNet(midas_model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+ elif midas_model_type == "midas_v21_small":
+ midas_model = MidasNet_small(
+ midas_model_path,
+ features=64,
+ backbone="efficientnet_lite3",
+ exportable=True,
+ non_negative=True,
+ blocks={"expand": True},
+ )
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+        else:
+            raise ValueError(f"midas_model_type '{midas_model_type}' not implemented")
+
+ midas_transform = T.Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ midas_model.eval()
+
+        if optimize:
+            # DEVICE is "cuda:0", so compare the device type; the full device
+            # objects differ (torch.device("cuda:0") != torch.device("cuda")).
+            if DEVICE.type == "cuda":
+ midas_model = midas_model.to(memory_format=torch.channels_last)
+ midas_model = midas_model.half()
+
+ midas_model.to(DEVICE)
+
+ print(f"MiDaS '{midas_model_type}' depth model initialized.")
+ return midas_model, midas_transform, net_w, net_h, resize_mode, normalization
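+
+    # Usage sketch (assumes the dpt_large checkpoint exists under
+    # DefaultPaths.model_path):
+    #   midas_model, midas_transform, net_w, net_h, resize_mode, normalization = \
+    #       init_midas_depth_model("dpt_large")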
+
+ # @title 1.5 Define necessary functions
+
+ # https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
+
+ def interp(t):
+ return 3 * t**2 - 2 * t**3
+
+ def perlin(width, height, scale=10, device=None):
+ gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
+ xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
+ ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
+ wx = 1 - interp(xs)
+ wy = 1 - interp(ys)
+ dots = 0
+ dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
+ dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
+ dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
+ dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))
+ return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)
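+
+    # perlin() is 2D gradient (Perlin) noise: random gradients on a lattice are
+    # blended with the smoothstep weight interp(t) = 3t^2 - 2t^3, producing a
+    # (width*scale, height*scale) noise map; perlin_ms() sums octaves of it.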
+
+ def perlin_ms(octaves, width, height, grayscale, device=device):
+ out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]
+ # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
+ for i in range(1 if grayscale else 3):
+ scale = 2 ** len(octaves)
+ oct_width = width
+ oct_height = height
+ for oct in octaves:
+ p = perlin(oct_width, oct_height, scale, device)
+ out_array[i] += p * oct
+ scale //= 2
+ oct_width *= 2
+ oct_height *= 2
+ return torch.cat(out_array)
+
+ def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):
+ out = perlin_ms(octaves, width, height, grayscale)
+ if grayscale:
+ out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))
+ out = TF.to_pil_image(out.clamp(0, 1)).convert("RGB")
+ else:
+ out = out.reshape(-1, 3, out.shape[0] // 3, out.shape[1])
+ out = TF.resize(size=(side_y, side_x), img=out)
+ out = TF.to_pil_image(out.clamp(0, 1).squeeze())
+
+ out = ImageOps.autocontrast(out)
+ return out
+
+ def regen_perlin():
+ if perlin_mode == "color":
+ init = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(12)], 1, 1, False
+ )
+ init2 = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(8)], 4, 4, False
+ )
+ elif perlin_mode == "gray":
+ init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, True)
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True)
+ else:
+ init = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(12)], 1, 1, False
+ )
+ init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True)
+
+ init = (
+ TF.to_tensor(init)
+ .add(TF.to_tensor(init2))
+ .div(2)
+ .to(device)
+ .unsqueeze(0)
+ .mul(2)
+ .sub(1)
+ )
+ del init2
+ return init.expand(batch_size, -1, -1, -1)
+
+ def fetch(url_or_path):
+ if str(url_or_path).startswith("http://") or str(url_or_path).startswith(
+ "https://"
+ ):
+ r = requests.get(url_or_path)
+ r.raise_for_status()
+ fd = io.BytesIO()
+ fd.write(r.content)
+ fd.seek(0)
+ return fd
+ return open(url_or_path, "rb")
+
+ def read_image_workaround(path):
+ """OpenCV reads images as BGR, Pillow saves them as RGB. Work around
+ this incompatibility to avoid colour inversions."""
+ im_tmp = cv2.imread(path)
+ return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)
+
+ def parse_prompt(prompt):
+ if prompt.startswith("http://") or prompt.startswith("https://"):
+ vals = prompt.rsplit(":", 2)
+ vals = [vals[0] + ":" + vals[1], *vals[2:]]
+ else:
+ vals = prompt.rsplit(":", 1)
+ vals = vals + ["", "1"][len(vals) :]
+ return vals[0], float(vals[1])
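+
+    # e.g. parse_prompt("a lighthouse:2") -> ("a lighthouse", 2.0); URL prompts
+    # keep their scheme prefix, so "https://x/y.png:0.5" -> ("https://x/y.png", 0.5),
+    # and a prompt without a suffix defaults to weight 1.0.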
+
+ def sinc(x):
+ return torch.where(
+ x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])
+ )
+
+ def lanczos(x, a):
+ cond = torch.logical_and(-a < x, x < a)
+ out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
+ return out / out.sum()
+
+ def ramp(ratio, width):
+ n = math.ceil(width / ratio + 1)
+ out = torch.empty([n])
+ cur = 0
+ for i in range(out.shape[0]):
+ out[i] = cur
+ cur += ratio
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
+ def resample(input, size, align_corners=True):
+ n, c, h, w = input.shape
+ dh, dw = size
+
+ input = input.reshape([n * c, 1, h, w])
+
+ if dh < h:
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+ pad_h = (kernel_h.shape[0] - 1) // 2
+ input = F.pad(input, (0, 0, pad_h, pad_h), "reflect")
+ input = F.conv2d(input, kernel_h[None, None, :, None])
+
+ if dw < w:
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+ pad_w = (kernel_w.shape[0] - 1) // 2
+ input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect")
+ input = F.conv2d(input, kernel_w[None, None, None, :])
+
+ input = input.reshape([n, c, h, w])
+ return F.interpolate(input, size, mode="bicubic", align_corners=align_corners)
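+
+    # resample() low-pass filters with a separable Lanczos kernel before any
+    # downscaling (over height, then width) to avoid aliasing, then hands the
+    # final resize to bicubic F.interpolate.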
+
+ class MakeCutouts(nn.Module):
+ def __init__(self, cut_size, cutn, skip_augs=False):
+ super().__init__()
+ self.cut_size = cut_size
+ self.cutn = cutn
+ self.skip_augs = skip_augs
+ self.augs = T.Compose(
+ [
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.15),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ]
+ )
+
+ def forward(self, input):
+ input = T.Pad(input.shape[2] // 4, fill=0)(input)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+
+ cutouts = []
+ for ch in range(self.cutn):
+ if ch > self.cutn - self.cutn // 4:
+ cutout = input.clone()
+ else:
+ size = int(
+ max_size
+ * torch.zeros(
+ 1,
+ )
+ .normal_(mean=0.8, std=0.3)
+ .clip(float(self.cut_size / max_size), 1.0)
+ )
+ offsetx = torch.randint(0, abs(sideX - size + 1), ())
+ offsety = torch.randint(0, abs(sideY - size + 1), ())
+ cutout = input[
+ :, :, offsety : offsety + size, offsetx : offsetx + size
+ ]
+
+ if not self.skip_augs:
+ cutout = self.augs(cutout)
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+ del cutout
+
+ cutouts = torch.cat(cutouts, dim=0)
+ return cutouts
+
+ cutout_debug = False
+ padargs = {}
+
+ class MakeCutoutsDango(nn.Module):
+ def __init__(
+ self, cut_size, Overview=4, InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2
+ ):
+ super().__init__()
+ self.cut_size = cut_size
+ self.Overview = Overview
+ self.InnerCrop = InnerCrop
+ self.IC_Size_Pow = IC_Size_Pow
+ self.IC_Grey_P = IC_Grey_P
+ if args.animation_mode == "None":
+ self.augs = T.Compose(
+ [
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(
+ degrees=10,
+ translate=(0.05, 0.05),
+ interpolation=T.InterpolationMode.BILINEAR,
+ ),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.ColorJitter(
+ brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
+ ),
+ ]
+ )
+ elif args.animation_mode == "Video Input":
+ self.augs = T.Compose(
+ [
+ T.RandomHorizontalFlip(p=0.5),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.15),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
+ ]
+ )
+ elif args.animation_mode == "2D" or args.animation_mode == "3D":
+ self.augs = T.Compose(
+ [
+ T.RandomHorizontalFlip(p=0.4),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomAffine(
+ degrees=10,
+ translate=(0.05, 0.05),
+ interpolation=T.InterpolationMode.BILINEAR,
+ ),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.RandomGrayscale(p=0.1),
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+ T.ColorJitter(
+ brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3
+ ),
+ ]
+ )
+
+ def forward(self, input):
+ cutouts = []
+ gray = T.Grayscale(3)
+ sideY, sideX = input.shape[2:4]
+ max_size = min(sideX, sideY)
+ min_size = min(sideX, sideY, self.cut_size)
+ l_size = max(sideX, sideY)
+ output_shape = [1, 3, self.cut_size, self.cut_size]
+ output_shape_2 = [1, 3, self.cut_size + 2, self.cut_size + 2]
+ pad_input = F.pad(
+ input,
+ (
+ (sideY - max_size) // 2,
+ (sideY - max_size) // 2,
+ (sideX - max_size) // 2,
+ (sideX - max_size) // 2,
+ ),
+ **padargs,
+ )
+ cutout = resize(pad_input, out_shape=output_shape)
+
+ if self.Overview > 0:
+ if self.Overview <= 4:
+ if self.Overview >= 1:
+ cutouts.append(cutout)
+ if self.Overview >= 2:
+ cutouts.append(gray(cutout))
+ if self.Overview >= 3:
+ cutouts.append(TF.hflip(cutout))
+ if self.Overview == 4:
+ cutouts.append(gray(TF.hflip(cutout)))
+ else:
+ cutout = resize(pad_input, out_shape=output_shape)
+ for _ in range(self.Overview):
+ cutouts.append(cutout)
+
+ if cutout_debug:
+ if is_colab:
+ TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(
+ "/content/cutout_overview0.jpg", quality=99
+ )
+ else:
+ TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(
+ "cutout_overview0.jpg", quality=99
+ )
+
+ if self.InnerCrop > 0:
+ for i in range(self.InnerCrop):
+ size = int(
+ torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size)
+ + min_size
+ )
+ offsetx = torch.randint(0, sideX - size + 1, ())
+ offsety = torch.randint(0, sideY - size + 1, ())
+ cutout = input[
+ :, :, offsety : offsety + size, offsetx : offsetx + size
+ ]
+ if i <= int(self.IC_Grey_P * self.InnerCrop):
+ cutout = gray(cutout)
+ cutout = resize(cutout, out_shape=output_shape)
+ cutouts.append(cutout)
+ if cutout_debug:
+ if is_colab:
+ TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(
+ "/content/cutout_InnerCrop.jpg", quality=99
+ )
+ else:
+ TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(
+ "cutout_InnerCrop.jpg", quality=99
+ )
+ cutouts = torch.cat(cutouts)
+ if skip_augs is not True:
+ cutouts = self.augs(cutouts)
+ return cutouts
+
+ def spherical_dist_loss(x, y):
+ x = F.normalize(x, dim=-1)
+ y = F.normalize(y, dim=-1)
+ return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
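+
+    # spherical_dist_loss is proportional to the squared great-circle distance
+    # between the unit-normalized embeddings: with ||x - y|| = 2*sin(theta/2),
+    # the returned value is 2*(theta/2)^2 = theta^2 / 2.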
+
+ def tv_loss(input):
+ """L2 total variation loss, as in Mahendran et al."""
+ input = F.pad(input, (0, 1, 0, 1), "replicate")
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
+ return (x_diff**2 + y_diff**2).mean([1, 2, 3])
+
+ def range_loss(input):
+ return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
+
+ stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete
+
+ def nsToStr(d):
+ h = 3.6e12
+ m = h / 60
+ s = m / 60
+ return (
+ str(int(d / h))
+ + ":"
+ + str(int((d % h) / m))
+ + ":"
+ + str(int((d % h) % m / s))
+ + "."
+ + str(int((d % h) % m % s))
+ )
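+
+    # nsToStr formats a duration given in nanoseconds as "H:M:S.<remainder>"
+    # (3.6e12 ns per hour); the field after the dot is the raw sub-second
+    # remainder in nanoseconds, not a decimal fraction.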
+
+ def do_run():
+ seed = args.seed
+ # print(range(args.start_frame, args.max_frames))
+
+ if (args.animation_mode == "3D") and (args.midas_weight > 0.0):
+ (
+ midas_model,
+ midas_transform,
+ midas_net_w,
+ midas_net_h,
+ midas_resize_mode,
+ midas_normalization,
+ ) = init_midas_depth_model(args.midas_depth_model)
+ for frame_num in range(args.start_frame, args.max_frames):
+ if stop_on_next_loop:
+ break
+
+ display.clear_output(wait=True)
+
+ # Print Frame progress if animation mode is on
+
+ """
+ if args.animation_mode != "None":
+ batchBar = tqdm(range(args.max_frames), desc ="Frames")
+ batchBar.n = frame_num
+ batchBar.refresh()
+ """
+
+ # Inits if not video frames
+ if args.animation_mode != "Video Input":
+ if args.init_image == "":
+ init_image = None
+ else:
+ init_image = args.init_image
+ init_scale = args.init_scale
+ skip_steps = args.skip_steps
+
+ if args.animation_mode == "2D":
+ if args.key_frames:
+ angle = args.angle_series[frame_num]
+ zoom = args.zoom_series[frame_num]
+ translation_x = args.translation_x_series[frame_num]
+ translation_y = args.translation_y_series[frame_num]
+ print(
+ f"angle: {angle}",
+ f"zoom: {zoom}",
+ f"translation_x: {translation_x}",
+ f"translation_y: {translation_y}",
+ )
+
+ if frame_num > 0:
+ seed = seed + 1
+ if resume_run and frame_num == start_frame:
+ img_0 = cv2.imread(
+ batchFolder
+ + f"/{batch_name}({batchNum})_{start_frame-1:04}.png"
+ )
+ else:
+ img_0 = cv2.imread("prevFrame.png")
+ center = (1 * img_0.shape[1] // 2, 1 * img_0.shape[0] // 2)
+ trans_mat = np.float32(
+ [[1, 0, translation_x], [0, 1, translation_y]]
+ )
+ rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
+ trans_mat = np.vstack([trans_mat, [0, 0, 1]])
+ rot_mat = np.vstack([rot_mat, [0, 0, 1]])
+ transformation_matrix = np.matmul(rot_mat, trans_mat)
+ img_0 = cv2.warpPerspective(
+ img_0,
+ transformation_matrix,
+ (img_0.shape[1], img_0.shape[0]),
+ borderMode=cv2.BORDER_WRAP,
+ )
+
+ cv2.imwrite("prevFrameScaled.png", img_0)
+ init_image = "prevFrameScaled.png"
+ init_scale = args.frames_scale
+ skip_steps = args.calc_frames_skip_steps
+
+ if args.animation_mode == "3D":
+ if args.key_frames:
+ angle = args.angle_series[frame_num]
+ # zoom = args.zoom_series[frame_num]
+ translation_x = args.translation_x_series[frame_num]
+ translation_y = args.translation_y_series[frame_num]
+ translation_z = args.translation_z_series[frame_num]
+ rotation_3d_x = args.rotation_3d_x_series[frame_num]
+ rotation_3d_y = args.rotation_3d_y_series[frame_num]
+ rotation_3d_z = args.rotation_3d_z_series[frame_num]
+ print(
+ f"angle: {angle}",
+ # f'zoom: {zoom}',
+ f"translation_x: {translation_x}",
+ f"translation_y: {translation_y}",
+ f"translation_z: {translation_z}",
+ f"rotation_3d_x: {rotation_3d_x}",
+ f"rotation_3d_y: {rotation_3d_y}",
+ f"rotation_3d_z: {rotation_3d_z}",
+ )
+
+ sys.stdout.flush()
+ # sys.stdout.write(f'FRAME_NUM = {frame_num} ...\n')
+ sys.stdout.flush()
+
+ if frame_num > 0:
+ seed = seed + 1
+ img_filepath = "prevFrame.png"
+ trans_scale = 1.0 / 200.0
+ translate_xyz = [
+ -translation_x * trans_scale,
+ translation_y * trans_scale,
+ -translation_z * trans_scale,
+ ]
+ rotate_xyz = [
+ math.radians(rotation_3d_x),
+ math.radians(rotation_3d_y),
+ math.radians(rotation_3d_z),
+ ]
+ print("translation:", translate_xyz)
+ print("rotation:", rotate_xyz)
+ rot_mat = p3dT.euler_angles_to_matrix(
+ torch.tensor(rotate_xyz, device=device), "XYZ"
+ ).unsqueeze(0)
+ print("rot_mat: " + str(rot_mat))
+ next_step_pil = dxf.transform_image_3d(
+ img_filepath,
+ midas_model,
+ midas_transform,
+ DEVICE,
+ rot_mat,
+ translate_xyz,
+ args.near_plane,
+ args.far_plane,
+ args.fov,
+ padding_mode=args.padding_mode,
+ sampling_mode=args.sampling_mode,
+ midas_weight=args.midas_weight,
+ )
+ next_step_pil.save("prevFrameScaled.png")
+
+ """
+ ### Turbo mode - skip some diffusions to save time
+ if turbo_mode == True and frame_num > 10 and frame_num % int(turbo_steps) != 0:
+ #turbo_steps
+ print('turbo mode is on this frame: skipping clip diffusion steps')
+ #this is an even frame. copy warped prior frame w/ war
+ #filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png'
+ #next_step_pil.save(f'{batchFolder}/{filename}') #save it as this frame
+ #next_step_pil.save(f'{img_filepath}') # save it also as prev_frame for next iteration
+ filename = f'progress.png'
+ next_step_pil.save(f'{filename}') #save it as this frame
+ next_step_pil.save(f'{img_filepath}') # save it also as prev_frame for next iteration
+ continue
+ elif turbo_mode == True:
+ print('turbo mode is OFF this frame')
+ #else: no turbo
+ """
+
+ init_image = "prevFrameScaled.png"
+ init_scale = args.frames_scale
+ skip_steps = args.calc_frames_skip_steps
+
+ if args.animation_mode == "Video Input":
+ seed = seed + 1
+ init_image = f"{videoFramesFolder}/{frame_num+1:04}.jpg"
+ init_scale = args.frames_scale
+ skip_steps = args.calc_frames_skip_steps
+
+ loss_values = []
+
+ if seed is not None:
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+
+ target_embeds, weights = [], []
+
+ if args.prompts_series is not None and frame_num >= len(
+ args.prompts_series
+ ):
+ frame_prompt = args.prompts_series[-1]
+ elif args.prompts_series is not None:
+ frame_prompt = args.prompts_series[frame_num]
+ else:
+ frame_prompt = []
+
+ print(args.image_prompts_series)
+ if args.image_prompts_series is not None and frame_num >= len(
+ args.image_prompts_series
+ ):
+ image_prompt = args.image_prompts_series[-1]
+ elif args.image_prompts_series is not None:
+ image_prompt = args.image_prompts_series[frame_num]
+ else:
+ image_prompt = []
+
+ print(f"Frame Prompt: {frame_prompt}")
+
+ model_stats = []
+ for clip_model in clip_models:
+ cutn = args2.cutn
+ model_stat = {
+ "clip_model": None,
+ "target_embeds": [],
+ "make_cutouts": None,
+ "weights": [],
+ }
+ model_stat["clip_model"] = clip_model
+
+ for prompt in frame_prompt:
+                    txt, weight = parse_prompt(prompt)
+                    # tokenize the parsed text, without the ":<weight>" suffix
+                    txt = clip_model.encode_text(
+                        clip.tokenize(txt).to(device)
+                    ).float()
+
+ if args.fuzzy_prompt:
+ for i in range(25):
+ model_stat["target_embeds"].append(
+ (
+ txt + torch.randn(txt.shape).cuda() * args.rand_mag
+ ).clamp(0, 1)
+ )
+ model_stat["weights"].append(weight)
+ else:
+ model_stat["target_embeds"].append(txt)
+ model_stat["weights"].append(weight)
+
+ if image_prompt:
+ model_stat["make_cutouts"] = MakeCutouts(
+ clip_model.visual.input_resolution, cutn, skip_augs=skip_augs
+ )
+ for prompt in image_prompt:
+ path, weight = parse_prompt(prompt)
+ img = Image.open(fetch(path)).convert("RGB")
+ img = TF.resize(
+ img,
+ min(side_x, side_y, *img.size),
+ T.InterpolationMode.LANCZOS,
+ )
+ batch = model_stat["make_cutouts"](
+ TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1)
+ )
+ embed = clip_model.encode_image(normalize(batch)).float()
+ if fuzzy_prompt:
+ for i in range(25):
+ model_stat["target_embeds"].append(
+ (
+ embed
+ + torch.randn(embed.shape).cuda() * rand_mag
+ ).clamp(0, 1)
+ )
+                                model_stat["weights"].extend([weight / cutn] * cutn)
+ else:
+ model_stat["target_embeds"].append(embed)
+ model_stat["weights"].extend([weight / cutn] * cutn)
+
+ model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"])
+ model_stat["weights"] = torch.tensor(
+ model_stat["weights"], device=device
+ )
+ if model_stat["weights"].sum().abs() < 1e-3:
+ raise RuntimeError("The weights must not sum to 0.")
+ model_stat["weights"] /= model_stat["weights"].sum().abs()
+ model_stats.append(model_stat)
+
+ init = None
+ if init_image is not None:
+ init = Image.open(fetch(init_image)).convert("RGB")
+ init = init.resize((args.side_x, args.side_y), Image.LANCZOS)
+ init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
+
+ if args.perlin_init:
+ if args.perlin_mode == "color":
+ init = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(12)], 1, 1, False
+ )
+ init2 = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(8)], 4, 4, False
+ )
+ elif args.perlin_mode == "gray":
+ init = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(12)], 1, 1, True
+ )
+ init2 = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(8)], 4, 4, True
+ )
+ else:
+ init = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(12)], 1, 1, False
+ )
+ init2 = create_perlin_noise(
+ [1.5**-i * 0.5 for i in range(8)], 4, 4, True
+ )
+ # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)
+ init = (
+ TF.to_tensor(init)
+ .add(TF.to_tensor(init2))
+ .div(2)
+ .to(device)
+ .unsqueeze(0)
+ .mul(2)
+ .sub(1)
+ )
+ del init2
+
+ cur_t = None
+
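+        # cond_fn implements the CLIP guidance: estimate the denoised image x_in
+        # from the current noisy sample, take cutouts, embed them with each CLIP
+        # model, score them against the prompt embeddings via spherical_dist_loss
+        # (plus TV/range/saturation/init losses), and return the gradient that
+        # steers the next diffusion step toward the prompts.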
+ def cond_fn(x, t, y=None):
+ with torch.enable_grad():
+ x_is_NaN = False
+ x = x.detach().requires_grad_()
+ n = x.shape[0]
+ if use_secondary_model is True:
+ alpha = torch.tensor(
+ diffusion.sqrt_alphas_cumprod[cur_t],
+ device=device,
+ dtype=torch.float32,
+ )
+ sigma = torch.tensor(
+ diffusion.sqrt_one_minus_alphas_cumprod[cur_t],
+ device=device,
+ dtype=torch.float32,
+ )
+ cosine_t = alpha_sigma_to_t(alpha, sigma)
+ out = secondary_model(x, cosine_t[None].repeat([n])).pred
+ fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+ x_in = out * fac + x * (1 - fac)
+ x_in_grad = torch.zeros_like(x_in)
+ else:
+ my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
+ out = diffusion.p_mean_variance(
+ model, x, my_t, clip_denoised=False, model_kwargs={"y": y}
+ )
+ fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
+ x_in = out["pred_xstart"] * fac + x * (1 - fac)
+ x_in_grad = torch.zeros_like(x_in)
+ for model_stat in model_stats:
+ for i in range(int(args.cutn_batches)):
+ t_int = (
+ int(t.item()) + 1
+ ) # errors on last step without +1, need to find source
+ # when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'
+                    try:
+                        input_resolution = model_stat[
+                            "clip_model"
+                        ].visual.input_resolution
+                    except AttributeError:
+                        input_resolution = 224
+
+ cuts = MakeCutoutsDango(
+ input_resolution,
+ Overview=args.cut_overview[1000 - t_int],
+ InnerCrop=args.cut_innercut[1000 - t_int],
+ IC_Size_Pow=args.cut_ic_pow,
+ IC_Grey_P=args.cut_icgray_p[1000 - t_int],
+ )
+ clip_in = normalize(cuts(x_in.add(1).div(2)))
+ image_embeds = (
+ model_stat["clip_model"].encode_image(clip_in).float()
+ )
+ dists = spherical_dist_loss(
+ image_embeds.unsqueeze(1),
+ model_stat["target_embeds"].unsqueeze(0),
+ )
+ dists = dists.view(
+ [
+ args.cut_overview[1000 - t_int]
+ + args.cut_innercut[1000 - t_int],
+ n,
+ -1,
+ ]
+ )
+ losses = dists.mul(model_stat["weights"]).sum(2).mean(0)
+ loss_values.append(
+ losses.sum().item()
+ ) # log loss, probably shouldn't do per cutn_batch
+ x_in_grad += (
+ torch.autograd.grad(
+ losses.sum() * clip_guidance_scale, x_in
+ )[0]
+ / cutn_batches
+ )
+ tv_losses = tv_loss(x_in)
+ if use_secondary_model is True:
+ range_losses = range_loss(out)
+ else:
+ range_losses = range_loss(out["pred_xstart"])
+ sat_losses = torch.abs(x_in - x_in.clamp(min=-1, max=1)).mean()
+ loss = (
+ tv_losses.sum() * tv_scale
+ + range_losses.sum() * range_scale
+ + sat_losses.sum() * sat_scale
+ )
+ if init is not None and args.init_scale:
+ init_losses = lpips_model(x_in, init)
+ loss = loss + init_losses.sum() * args.init_scale
+ x_in_grad += torch.autograd.grad(loss, x_in)[0]
+                if not torch.isnan(x_in_grad).any():
+                    grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
+                else:
+                    # print("NaN'd")
+                    x_is_NaN = True
+                    grad = torch.zeros_like(x)
+            if args.clamp_grad and not x_is_NaN:
+                magnitude = grad.square().mean().sqrt()
+                return (
+                    grad * magnitude.clamp(max=args.clamp_max) / magnitude
+                )  # min=-0.02, min=-clamp_max,
+            return grad
+
+ if args.sampling_mode == "ddim":
+ sample_fn = diffusion.ddim_sample_loop_progressive
+ elif args.sampling_mode == "bicubic":
+ sample_fn = diffusion.p_sample_loop_progressive
+ elif args.sampling_mode == "plms":
+ sample_fn = diffusion.plms_sample_loop_progressive
+ # if model_config["timestep_respacing"].startswith("ddim"):
+ # sample_fn = diffusion.ddim_sample_loop_progressive
+ # else:
+ # sample_fn = diffusion.p_sample_loop_progressive
+
+ image_display = Output()
+ for i in range(args.n_batches):
+ """
+ if args.animation_mode == 'None':
+ display.clear_output(wait=True)
+ batchBar = tqdm(range(args.n_batches), desc ="Batches")
+ batchBar.n = i
+ batchBar.refresh()
+ print('')
+ display.display(image_display)
+ gc.collect()
+ torch.cuda.empty_cache()
+ """
+ cur_t = diffusion.num_timesteps - skip_steps - 1
+ total_steps = cur_t
+
+ if perlin_init:
+ init = regen_perlin()
+
+ if args.sampling_mode == "ddim":
+ samples = sample_fn(
+ model,
+ (batch_size, 3, args.side_y, args.side_x),
+ clip_denoised=clip_denoised,
+ model_kwargs={},
+ cond_fn=cond_fn,
+ progress=True,
+ skip_timesteps=skip_steps,
+ init_image=init,
+ randomize_class=randomize_class,
+ eta=eta,
+ )
+ elif args.sampling_mode == "plms":
+ samples = sample_fn(
+ model,
+ (batch_size, 3, args.side_y, args.side_x),
+ clip_denoised=clip_denoised,
+ model_kwargs={},
+ cond_fn=cond_fn,
+ progress=True,
+ skip_timesteps=skip_steps,
+ init_image=init,
+ randomize_class=randomize_class,
+ order=2,
+ )
+ elif args.sampling_mode == "bicubic":
+ samples = sample_fn(
+ model,
+ (batch_size, 3, args.side_y, args.side_x),
+ clip_denoised=clip_denoised,
+ model_kwargs={},
+ cond_fn=cond_fn,
+ progress=True,
+ skip_timesteps=skip_steps,
+ init_image=init,
+ randomize_class=randomize_class,
+ )
+
+ # with run_display:
+ # display.clear_output(wait=True)
+ itt = 1
+ imgToSharpen = None
+ status.write("Starting the execution...")
+ gc.collect()
+ torch.cuda.empty_cache()
+ # from tqdm.auto import tqdm
+ # from stqdm_local import stqdm
+
+ # total_iterables = stqdm(
+ # samples, total=total_steps + 1, st_container=stoutput
+ # )
+ total_iterables = samples
+ try:
+ j = 0
+ before_start_time = time.perf_counter()
+ bar_container = status.container()
+ iteration_counter = bar_container.empty()
+ progress_bar = bar_container.progress(0)
+ for sample in total_iterables:
+ if itt == 1:
+ iteration_counter.empty()
+ imageLocation = stoutput.empty()
+ sys.stdout.write(f"Iteration {itt}\n")
+ sys.stdout.flush()
+ cur_t -= 1
+ intermediateStep = False
+ if args.steps_per_checkpoint is not None:
+ if j % steps_per_checkpoint == 0 and j > 0:
+ intermediateStep = True
+ elif j in args.intermediate_saves:
+ intermediateStep = True
+ with image_display:
+ """
+ if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:
+ for k, image in enumerate(sample['pred_xstart']):
+ # tqdm.write(f'Batch {i}, step {j}, output {k}:')
+ current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
+ percent = math.ceil(j/total_steps*100)
+ if args.n_batches > 0:
+ #if intermediates are saved to the subfolder, don't append a step or percentage to the name
+ if cur_t == -1 and args.intermediates_in_subfolder is True:
+ save_num = f'{frame_num:04}' if animation_mode != "None" else i
+ filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'
+ else:
+ #If we're working with percentages, append it
+ if args.steps_per_checkpoint is not None:
+ filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'
+ # Or else, iIf we're working with specific steps, append those
+ else:
+ filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'
+ image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
+ if j % args.display_rate == 0 or cur_t == -1:
+ image.save('progress.png')
+ #display.clear_output(wait=True)
+ #display.display(display.Image('progress.png'))
+ if args.steps_per_checkpoint is not None:
+ if j % args.steps_per_checkpoint == 0 and j > 0:
+ if args.intermediates_in_subfolder is True:
+ image.save(f'{partialFolder}/{filename}')
+ else:
+ image.save(f'{batchFolder}/{filename}')
+ else:
+ if j in args.intermediate_saves:
+ if args.intermediates_in_subfolder is True:
+ image.save(f'{partialFolder}/{filename}')
+ else:
+ image.save(f'{batchFolder}/{filename}')
+ if cur_t == -1:
+ if frame_num == 0:
+ save_settings()
+ if args.animation_mode != "None":
+ image.save('prevFrame.png')
+ if args.sharpen_preset != "Off" and animation_mode == "None":
+ imgToSharpen = image
+ if args.keep_unsharp is True:
+ image.save(f'{unsharpenFolder}/{filename}')
+ else:
+ image.save(f'{batchFolder}/{filename}')
+ # if frame_num != args.max_frames-1:
+ # display.clear_output()
+ """
+ if itt % args2.update == 0 or cur_t == -1 or itt == 1:
+ for k, image in enumerate(sample["pred_xstart"]):
+ sys.stdout.flush()
+ sys.stdout.write("Saving progress ...\n")
+ sys.stdout.flush()
+
+ image = TF.to_pil_image(
+ image.add(1).div(2).clamp(0, 1)
+ )
+
+ if args.animation_mode != "None":
+ image.save("prevFrame.png")
+
+ image.save(args2.image_file)
+ if (args2.frame_dir is not None) and (
+ args.animation_mode == "None"
+ ):
+ file_list = []
+ for file in sorted(os.listdir(args2.frame_dir)):
+ if file.startswith("FRA"):
+ if file.endswith("PNG"):
+ if len(file) == 12:
+ file_list.append(file)
+ if file_list:
+ last_name = file_list[-1]
+ count_value = int(last_name[3:8]) + 1
+ count_string = f"{count_value:05d}"
+ else:
+ count_string = "00001"
+ save_name = (
+ args2.frame_dir
+ + "/FRA"
+ + count_string
+ + ".PNG"
+ )
+ image.save(save_name)
+
+ # sys.stdout.flush()
+ # sys.stdout.write(f'{itt}/{args2.iterations} {skip_steps} {args.animation_mode} {args2.frame_dir}\n')
+ # sys.stdout.flush()
+ if (
+ (args2.frame_dir is not None)
+ and (args.animation_mode == "3D")
+ and (itt == args2.iterations - skip_steps)
+ ):
+ sys.stdout.flush()
+ sys.stdout.write("Saving 3D frame...\n")
+ sys.stdout.flush()
+ file_list = []
+ for file in os.listdir(args2.frame_dir):
+ if file.startswith("FRA"):
+ if file.endswith("PNG"):
+ if len(file) == 12:
+ file_list.append(file)
+ if file_list:
+ last_name = file_list[-1]
+ count_value = int(last_name[3:8]) + 1
+ count_string = f"{count_value:05d}"
+ else:
+ count_string = "00001"
+ save_name = (
+ args2.frame_dir
+ + "/FRA"
+ + count_string
+ + ".PNG"
+ )
+ image.save(save_name)
+
+ imageLocation.image(Image.open(args2.image_file))
+ sys.stdout.flush()
+ sys.stdout.write("Progress saved\n")
+ sys.stdout.flush()
+ itt += 1
+ j += 1
+ time_past_seconds = time.perf_counter() - before_start_time
+ iterations_per_second = j / time_past_seconds
+ time_left = (total_steps - j) / iterations_per_second
+ percentage = round((j / (total_steps + 1)) * 100)
+
+ iteration_counter.write(
+ f"{percentage}% {j}/{total_steps+1} [{time.strftime('%M:%S', time.gmtime(time_past_seconds))}<{time.strftime('%M:%S', time.gmtime(time_left))}, {round(iterations_per_second,2)} it/s]"
+ )
+ progress_bar.progress(int(percentage))
+
+ # if path_exists(drive_path):
+
+ except KeyboardInterrupt:
+ pass
+ # except st.script_runner.StopException as e:
+ # imageLocation.image(args2.image_file)
+ # gc.collect()
+ # torch.cuda.empty_cache()
+ # status.write("Done!")
+ # pass
+ imageLocation.empty()
+ with image_display:
+ if args.sharpen_preset != "Off" and animation_mode == "None":
+ print("Starting Diffusion Sharpening...")
+ do_superres(imgToSharpen, f"{batchFolder}/{filename}")
+ display.clear_output()
+
+
+ if not path_exists(DefaultPaths.output_path):
+ os.makedirs(DefaultPaths.output_path)
+ save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}.png"
+ print(save_filename)
+ file_list = []
+ if path_exists(save_filename):
+ for file in sorted(os.listdir(f"{DefaultPaths.output_path}/")):
+ if file.startswith(
+ f"{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}"
+ ):
+ print(file)
+ file_list.append(file)
+ print(file_list)
+ last_name = file_list[-1]
+ print(last_name)
+ if last_name[-15:-10] == "batch":
+ count_value = int(last_name[-10:-4]) + 1
+ count_string = f"{count_value:05d}"
+ save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}_batch {count_string}.png"
+ else:
+ save_filename = f"{DefaultPaths.output_path}/{sanitize_filename(args2.prompt)} [Disco Diffusion v5] {args2.seed}_batch 00001.png"
+ shutil.copyfile(
+ args2.image_file,
+ save_filename,
+ )
+ imageLocation.empty()
+ status.write("Done!")
+ plt.plot(np.array(loss_values), "r")
+
+ def save_settings():
+ setting_list = {
+ "text_prompts": text_prompts,
+ "image_prompts": image_prompts,
+ "clip_guidance_scale": clip_guidance_scale,
+ "tv_scale": tv_scale,
+ "range_scale": range_scale,
+ "sat_scale": sat_scale,
+ # 'cutn': cutn,
+ "cutn_batches": cutn_batches,
+ "max_frames": max_frames,
+ "interp_spline": interp_spline,
+ # 'rotation_per_frame': rotation_per_frame,
+ "init_image": init_image,
+ "init_scale": init_scale,
+ "skip_steps": skip_steps,
+ # 'zoom_per_frame': zoom_per_frame,
+ "frames_scale": frames_scale,
+ "frames_skip_steps": frames_skip_steps,
+ "perlin_init": perlin_init,
+ "perlin_mode": perlin_mode,
+ "skip_augs": skip_augs,
+ "randomize_class": randomize_class,
+ "clip_denoised": clip_denoised,
+ "clamp_grad": clamp_grad,
+ "clamp_max": clamp_max,
+ "seed": seed,
+ "fuzzy_prompt": fuzzy_prompt,
+ "rand_mag": rand_mag,
+ "eta": eta,
+ "width": width_height[0],
+ "height": width_height[1],
+ "diffusion_model": diffusion_model,
+ "use_secondary_model": use_secondary_model,
+ "steps": steps,
+ "diffusion_steps": diffusion_steps,
+ "ViTB32": ViTB32,
+ "ViTB16": ViTB16,
+ "ViTL14": ViTL14,
+ "RN101": RN101,
+ "RN50": RN50,
+ "RN50x4": RN50x4,
+ "RN50x16": RN50x16,
+ "RN50x64": RN50x64,
+ "cut_overview": str(cut_overview),
+ "cut_innercut": str(cut_innercut),
+ "cut_ic_pow": cut_ic_pow,
+ "cut_icgray_p": str(cut_icgray_p),
+ "key_frames": key_frames,
+ "max_frames": max_frames,
+ "angle": angle,
+ "zoom": zoom,
+ "translation_x": translation_x,
+ "translation_y": translation_y,
+ "translation_z": translation_z,
+ "rotation_3d_x": rotation_3d_x,
+ "rotation_3d_y": rotation_3d_y,
+ "rotation_3d_z": rotation_3d_z,
+ "midas_depth_model": midas_depth_model,
+ "midas_weight": midas_weight,
+ "near_plane": near_plane,
+ "far_plane": far_plane,
+ "fov": fov,
+ "padding_mode": padding_mode,
+ "sampling_mode": sampling_mode,
+ "video_init_path": video_init_path,
+ "extract_nth_frame": extract_nth_frame,
+ "turbo_mode": turbo_mode,
+ "turbo_steps": turbo_steps,
+ }
+ # print('Settings:', setting_list)
+ with open(
+ f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+"
+ ) as f: # save settings
+ json.dump(setting_list, f, ensure_ascii=False, indent=4)
+
+ # @title 1.6 Define the secondary diffusion model
+
+ def append_dims(x, n):
+ return x[(Ellipsis, *(None,) * (n - x.ndim))]
+
+ def expand_to_planes(x, shape):
+ return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])
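+
+    # append_dims adds trailing singleton dims to x until it has n dims;
+    # expand_to_planes then tiles that over the spatial extent, broadcasting a
+    # per-sample vector (e.g. the timestep embedding) to full feature maps.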
+
+ def alpha_sigma_to_t(alpha, sigma):
+ return torch.atan2(sigma, alpha) * 2 / math.pi
+
+ def t_to_alpha_sigma(t):
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
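+
+    # alpha_sigma_to_t and t_to_alpha_sigma are inverses on the cosine schedule:
+    # alpha = cos(t*pi/2) and sigma = sin(t*pi/2), so t = atan2(sigma, alpha) * 2/pi.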
+
+ @dataclass
+ class DiffusionOutput:
+ v: torch.Tensor
+ pred: torch.Tensor
+ eps: torch.Tensor
+
+ class ConvBlock(nn.Sequential):
+ def __init__(self, c_in, c_out):
+ super().__init__(
+ nn.Conv2d(c_in, c_out, 3, padding=1),
+ nn.ReLU(inplace=True),
+ )
+
+ class SkipBlock(nn.Module):
+ def __init__(self, main, skip=None):
+ super().__init__()
+ self.main = nn.Sequential(*main)
+ self.skip = skip if skip else nn.Identity()
+
+ def forward(self, input):
+ return torch.cat([self.main(input), self.skip(input)], dim=1)
+
+ class FourierFeatures(nn.Module):
+ def __init__(self, in_features, out_features, std=1.0):
+ super().__init__()
+ assert out_features % 2 == 0
+ self.weight = nn.Parameter(
+ torch.randn([out_features // 2, in_features]) * std
+ )
+
+ def forward(self, input):
+ f = 2 * math.pi * input @ self.weight.T
+ return torch.cat([f.cos(), f.sin()], dim=-1)
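+
+    # FourierFeatures projects a scalar input (here the timestep) through a
+    # fixed random Gaussian matrix and returns [cos, sin] pairs, i.e. a
+    # random-Fourier-features embedding (hence out_features must be even).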
+
+ class SecondaryDiffusionImageNet(nn.Module):
+ def __init__(self):
+ super().__init__()
+ c = 64 # The base channel count
+
+ self.timestep_embed = FourierFeatures(1, 16)
+
+ self.net = nn.Sequential(
+ ConvBlock(3 + 16, c),
+ ConvBlock(c, c),
+ SkipBlock(
+ [
+ nn.AvgPool2d(2),
+ ConvBlock(c, c * 2),
+ ConvBlock(c * 2, c * 2),
+ SkipBlock(
+ [
+ nn.AvgPool2d(2),
+ ConvBlock(c * 2, c * 4),
+ ConvBlock(c * 4, c * 4),
+ SkipBlock(
+ [
+ nn.AvgPool2d(2),
+ ConvBlock(c * 4, c * 8),
+ ConvBlock(c * 8, c * 4),
+ nn.Upsample(
+ scale_factor=2,
+ mode="bilinear",
+ align_corners=False,
+ ),
+ ]
+ ),
+ ConvBlock(c * 8, c * 4),
+ ConvBlock(c * 4, c * 2),
+ nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=False
+ ),
+ ]
+ ),
+ ConvBlock(c * 4, c * 2),
+ ConvBlock(c * 2, c),
+ nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=False
+ ),
+ ]
+ ),
+ ConvBlock(c * 2, c),
+ nn.Conv2d(c, 3, 3, padding=1),
+ )
+
+ def forward(self, input, t):
+ timestep_embed = expand_to_planes(
+ self.timestep_embed(t[:, None]), input.shape
+ )
+ v = self.net(torch.cat([input, timestep_embed], dim=1))
+ alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
+ pred = input * alphas - v * sigmas
+ eps = input * sigmas + v * alphas
+ return DiffusionOutput(v, pred, eps)
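+
+    # The secondary models use the v-objective parameterization: from the
+    # predicted v, the denoised image is pred = x*alpha - v*sigma and the
+    # noise estimate is eps = x*sigma + v*alpha.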
+
+ class SecondaryDiffusionImageNet2(nn.Module):
+ def __init__(self):
+ super().__init__()
+ c = 64 # The base channel count
+ cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
+
+ self.timestep_embed = FourierFeatures(1, 16)
+ self.down = nn.AvgPool2d(2)
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
+
+ self.net = nn.Sequential(
+ ConvBlock(3 + 16, cs[0]),
+ ConvBlock(cs[0], cs[0]),
+ SkipBlock(
+ [
+ self.down,
+ ConvBlock(cs[0], cs[1]),
+ ConvBlock(cs[1], cs[1]),
+ SkipBlock(
+ [
+ self.down,
+ ConvBlock(cs[1], cs[2]),
+ ConvBlock(cs[2], cs[2]),
+ SkipBlock(
+ [
+ self.down,
+ ConvBlock(cs[2], cs[3]),
+ ConvBlock(cs[3], cs[3]),
+ SkipBlock(
+ [
+ self.down,
+ ConvBlock(cs[3], cs[4]),
+ ConvBlock(cs[4], cs[4]),
+ SkipBlock(
+ [
+ self.down,
+ ConvBlock(cs[4], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[5]),
+ ConvBlock(cs[5], cs[4]),
+ self.up,
+ ]
+ ),
+ ConvBlock(cs[4] * 2, cs[4]),
+ ConvBlock(cs[4], cs[3]),
+ self.up,
+ ]
+ ),
+ ConvBlock(cs[3] * 2, cs[3]),
+ ConvBlock(cs[3], cs[2]),
+ self.up,
+ ]
+ ),
+ ConvBlock(cs[2] * 2, cs[2]),
+ ConvBlock(cs[2], cs[1]),
+ self.up,
+ ]
+ ),
+ ConvBlock(cs[1] * 2, cs[1]),
+ ConvBlock(cs[1], cs[0]),
+ self.up,
+ ]
+ ),
+ ConvBlock(cs[0] * 2, cs[0]),
+ nn.Conv2d(cs[0], 3, 3, padding=1),
+ )
+
+ def forward(self, input, t):
+ timestep_embed = expand_to_planes(
+ self.timestep_embed(t[:, None]), input.shape
+ )
+ v = self.net(torch.cat([input, timestep_embed], dim=1))
+ alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
+ pred = input * alphas - v * sigmas
+ eps = input * sigmas + v * alphas
+ return DiffusionOutput(v, pred, eps)
+
+    """# 2. Diffusion and CLIP model settings"""
+
+ if args2.use256 == 0:
+ sys.stdout.write("Loading 512x512_diffusion_uncond_finetune_008100 ...\n")
+ sys.stdout.flush()
+ status.write("Loading 512x512_diffusion_uncond_finetune_008100 ...\n")
+ diffusion_model = "512x512_diffusion_uncond_finetune_008100" # @param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"]
+ else:
+ sys.stdout.write("Loading 256x256_diffusion_uncond ...\n")
+ sys.stdout.flush()
+ status.write("Loading 256x256_diffusion_uncond ...\n")
+ diffusion_model = "256x256_diffusion_uncond"
+
+ if args2.secondarymodel == 1:
+ use_secondary_model = True # @param {type: 'boolean'}
+ else:
+ use_secondary_model = False # @param {type: 'boolean'}
+
+ # timestep_respacing = '50' # param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000']
+ if args2.sampling_mode == "ddim" or args2.sampling_mode == "plms":
+ timestep_respacing = "ddim" + str(
+ args2.iterations
+ ) #'ddim100' # Modify this value to decrease the number of timesteps.
+ else:
+ timestep_respacing = str(
+ args2.iterations
+ ) #'ddim100' # Modify this value to decrease the number of timesteps.
+
+ diffusion_steps = 1000 # param {type: 'number'}
+
+ use_checkpoint = True # @param {type: 'boolean'}
+
+ # @markdown If you're having issues with model downloads, check this to compare SHA's:
+ check_model_SHA = False # @param{type:"boolean"}
+
+ model_256_SHA = "983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a"
+ model_512_SHA = "9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648"
+ model_secondary_SHA = (
+ "983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a"
+ )
+
+ model_256_link = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt"
+ model_512_link = "https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt"
+ model_secondary_link = (
+ "https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth"
+ )
+
+ model_256_path = f"{DefaultPaths.model_path}/256x256_diffusion_uncond.pt"
+ model_512_path = (
+ f"{DefaultPaths.model_path}/512x512_diffusion_uncond_finetune_008100.pt"
+ )
+ model_secondary_path = f"{DefaultPaths.model_path}/secondary_model_imagenet_2.pth"
+
+ model_256_downloaded = True
+ model_512_downloaded = True
+ model_secondary_downloaded = True
+
+ model_config = model_and_diffusion_defaults()
+ if diffusion_model == "512x512_diffusion_uncond_finetune_008100":
+ model_config.update(
+ {
+ "attention_resolutions": "32, 16, 8",
+ "class_cond": False,
+ "diffusion_steps": diffusion_steps,
+ "rescale_timesteps": True,
+ "timestep_respacing": timestep_respacing,
+ "image_size": 512,
+ "learn_sigma": True,
+ "noise_schedule": "linear",
+ "num_channels": 256,
+ "num_head_channels": 64,
+ "num_res_blocks": 2,
+ "resblock_updown": True,
+ "use_checkpoint": use_checkpoint,
+ "use_fp16": True,
+ "use_scale_shift_norm": True,
+ }
+ )
+ elif diffusion_model == "256x256_diffusion_uncond":
+ model_config.update(
+ {
+ "attention_resolutions": "32, 16, 8",
+ "class_cond": False,
+ "diffusion_steps": diffusion_steps,
+ "rescale_timesteps": True,
+ "timestep_respacing": timestep_respacing,
+ "image_size": 256,
+ "learn_sigma": True,
+ "noise_schedule": "linear",
+ "num_channels": 256,
+ "num_head_channels": 64,
+ "num_res_blocks": 2,
+ "resblock_updown": True,
+ "use_checkpoint": use_checkpoint,
+ "use_fp16": True,
+ "use_scale_shift_norm": True,
+ }
+ )
+
+ secondary_model_ver = 2
+ model_default = model_config["image_size"]
+
+    # Only load the secondary model when it will actually be used.
+    if secondary_model_ver == 2 and use_secondary_model:
+ secondary_model = SecondaryDiffusionImageNet2()
+ secondary_model.load_state_dict(
+ torch.load(
+ f"{DefaultPaths.model_path}/secondary_model_imagenet_2.pth",
+ map_location="cpu",
+ )
+ )
+ secondary_model.eval().requires_grad_(False).to(device)
+
+ clip_models = []
+ if args2.usevit32 == 1:
+ sys.stdout.write("Loading ViT-B/32 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading ViT-B/32 CLIP model ...\n")
+ clip_models.append(
+ clip.load("ViT-B/32", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usevit16 == 1:
+ sys.stdout.write("Loading ViT-B/16 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading ViT-B/16 CLIP model ...\n")
+ clip_models.append(
+ clip.load("ViT-B/16", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usevit14 == 1:
+ sys.stdout.write("Loading ViT-L/14 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading ViT-L/14 CLIP model ...\n")
+ clip_models.append(
+ clip.load("ViT-L/14", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usern50x4 == 1:
+ sys.stdout.write("Loading RN50x4 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading RN50x4 CLIP model ...\n")
+ clip_models.append(
+ clip.load("RN50x4", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usern50x16 == 1:
+ sys.stdout.write("Loading RN50x16 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading RN50x16 CLIP model ...\n")
+ clip_models.append(
+ clip.load("RN50x16", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usern50x64 == 1:
+ sys.stdout.write("Loading RN50x64 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading RN50x64 CLIP model ...\n")
+ clip_models.append(
+ clip.load("RN50x64", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usern50 == 1:
+ sys.stdout.write("Loading RN50 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading RN50 CLIP model ...\n")
+ clip_models.append(
+ clip.load("RN50", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.usern101 == 1:
+ sys.stdout.write("Loading RN101 CLIP model ...\n")
+ sys.stdout.flush()
+ status.write("Loading RN101 CLIP model ...\n")
+ clip_models.append(
+ clip.load("RN101", jit=False)[0].eval().requires_grad_(False).to(device)
+ )
+ if args2.useslipbase == 1:
+ sys.stdout.write("Loading SLIP Base model ...\n")
+ sys.stdout.flush()
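+        # NOTE: SLIP_VITB16/SLIP_VITL16 come from the SLIP repo's `models`
+        # module; the corresponding import near the top of this file is
+        # commented out and must be restored for this branch to run.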
+ SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
+ # next 2 lines needed so torch.load handles posix paths on Windows
+ import pathlib
+
+ pathlib.PosixPath = pathlib.WindowsPath
+ sd = torch.load("slip_base_100ep.pt")
+ real_sd = {}
+ for k, v in sd["state_dict"].items():
+ real_sd[".".join(k.split(".")[1:])] = v
+ del sd
+ SLIPB16model.load_state_dict(real_sd)
+ SLIPB16model.requires_grad_(False).eval().to(device)
+ clip_models.append(SLIPB16model)
+ if args2.usesliplarge == 1:
+ sys.stdout.write("Loading SLIP Large model ...\n")
+ sys.stdout.flush()
+ SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)
+ # next 2 lines needed so torch.load handles posix paths on Windows
+ import pathlib
+
+ pathlib.PosixPath = pathlib.WindowsPath
+ sd = torch.load("slip_large_100ep.pt")
+ real_sd = {}
+ for k, v in sd["state_dict"].items():
+ real_sd[".".join(k.split(".")[1:])] = v
+ del sd
+ SLIPL16model.load_state_dict(real_sd)
+ SLIPL16model.requires_grad_(False).eval().to(device)
+ clip_models.append(SLIPL16model)
+
+ normalize = T.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073],
+ std=[0.26862954, 0.26130258, 0.27577711],
+ )
+ status.write("Loading lpips model...\n")
+ lpips_model = lpips.LPIPS(net="vgg").to(device)
+
+ """# 3. Settings"""
+
+ # sys.stdout.write("DEBUG0 ...\n")
+ # sys.stdout.flush()
+
+ # @markdown ####**Basic Settings:**
+ batch_name = "TimeToDisco" # @param{type: 'string'}
+ steps = (
+ args2.iterations
+ ) # @param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}
+ width_height = [args2.sizex, args2.sizey] # @param{type: 'raw'}
+ clip_guidance_scale = args2.guidancescale # @param{type: 'number'}
+ tv_scale = args2.tvscale # @param{type: 'number'}
+ range_scale = args2.rangescale # @param{type: 'number'}
+ sat_scale = args2.saturationscale # @param{type: 'number'}
+ cutn_batches = args2.cutnbatches # @param{type: 'number'}
+
+ if args2.useaugs == 1:
+ skip_augs = False # False - Controls whether to skip torchvision augmentations
+ else:
+ skip_augs = True # False - Controls whether to skip torchvision augmentations
+
+ # @markdown ####**Init Settings:**
+ if args2.seed_image is not None:
+ init_image = (
+ args2.seed_image
+        )  # This can be a URL or a local path and must be in quotes.
+ skip_steps = (
+ args2.skipseedtimesteps
+ ) # 12 Skip unstable steps # Higher values make the output look more like the init.
+ init_scale = (
+ args2.initscale
+ ) # This enhances the effect of the init image, a good value is 1000.
+ else:
+ init_image = "" # This can be an URL or Colab local path and must be in quotes.
+ skip_steps = 0 # 12 Skip unstable steps # Higher values make the output look more like the init.
+ init_scale = (
+ 0 # This enhances the effect of the init image, a good value is 1000.
+ )
+
+ if init_image == "":
+ init_image = None
+
+ side_x = args2.sizex
+ side_y = args2.sizey
+
+ # Update Model Settings
+ # timestep_respacing = f'ddim{steps}'
+ diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps
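+ # Illustrative arithmetic: diffusion_steps is rounded to the largest multiple
+ # of `steps` that fits in 1000, e.g. steps=150 -> (1000 // 150) * 150 = 900,
+ # while steps=250 -> 1000, and steps >= 1000 is passed through unchanged.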
+ model_config.update(
+ {
+ "timestep_respacing": timestep_respacing,
+ "diffusion_steps": diffusion_steps,
+ }
+ )
+
+ # Make folder for batch
+ batchFolder = f"./"
+ # createPath(batchFolder)
+
+ # sys.stdout.write("DEBUG1 ...\n")
+ # sys.stdout.flush()
+
+ """###Animation Settings"""
+
+ # @markdown ####**Animation Mode:**
+ animation_mode = (
+ args2.animation_mode
+ ) #'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}
+ # @markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*
+
+ # @markdown ---
+
+ # @markdown ####**Video Input Settings:**
+ video_init_path = "training.mp4" # "D:\\sample_cat.mp4" #@param {type: 'string'}
+ extract_nth_frame = 2 # @param {type:"number"}
+
+ # sys.stdout.write("DEBUG1a ...\n")
+ # sys.stdout.flush()
+
+ if animation_mode == "Video Input":
+ videoFramesFolder = "./videoFrames"
+ # createPath(videoFramesFolder)
+ # print(f"Exporting Video Frames (1 every {extract_nth_frame})...")
+ sys.stdout.write(f"Exporting Video Frames (1 every {extract_nth_frame})...\n")
+ sys.stdout.flush()
+
+ """
+ try:
+ !rm {videoFramesFolder}/*.jpg
+ except:
+ print('')
+ """
+ # sys.stdout.write("DEBUG1a1 ...\n")
+ # sys.stdout.flush()
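+ # The select filter keeps frames whose index n satisfies
+ # n % extract_nth_frame == 0, i.e. one frame out of every extract_nth_frame.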
+ vf = f'"select=not(mod(n\,{extract_nth_frame}))"'
+ # sys.stdout.write("DEBUG1a2 ...\n")
+ # sys.stdout.flush()
+ ffmpeg_bin = "ffmpeg.exe" if sys.platform == "win32" else "ffmpeg"
+ os.system(
+ f"{ffmpeg_bin} -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg"
+ )
+ # sys.stdout.write("DEBUG1a3 ...\n")
+ # sys.stdout.flush()
+
+ # sys.stdout.write("DEBUG1b ...\n")
+ # sys.stdout.flush()
+
+ # @markdown ---
+
+ # @markdown ####**2D Animation Settings:**
+ # @markdown `zoom` is a multiplier of dimensions, 1 is no zoom.
+
+ key_frames = True # @param {type:"boolean"}
+ max_frames = args2.max_frames # 10000#@param {type:"number"}
+
+ # sys.stdout.write("DEBUG1c ...\n")
+ # sys.stdout.flush()
+
+ if animation_mode == "Video Input":
+ max_frames = len(glob(f"{videoFramesFolder}/*.jpg"))
+
+ # sys.stdout.write("DEBUG1d ...\n")
+ # sys.stdout.flush()
+
+ interp_spline = "Linear" # Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:"string"}
+ angle = args2.angle # "0:(0)"#@param {type:"string"}
+ zoom = args2.zoom # "0: (1), 10: (1.05)"#@param {type:"string"}
+ translation_x = args2.translation_x # "0: (0)"#@param {type:"string"}
+ translation_y = args2.translation_y # "0: (0)"#@param {type:"string"}
+ translation_z = args2.translation_z # "0: (10.0)"#@param {type:"string"}
+ rotation_3d_x = args2.rotation_3d_x # "0: (0)"#@param {type:"string"}
+ rotation_3d_y = args2.rotation_3d_y # "0: (0)"#@param {type:"string"}
+ rotation_3d_z = args2.rotation_3d_z # "0: (0)"#@param {type:"string"}
+ midas_depth_model = "dpt_large" # @param {type:"string"}
+ midas_weight = args2.midas_weight # 0.3#@param {type:"number"}
+ near_plane = args2.near_plane # 200#@param {type:"number"}
+ far_plane = args2.far_plane # 10000#@param {type:"number"}
+ fov = args2.fov # 40#@param {type:"number"}
+ padding_mode = "border" # @param {type:"string"}
+ sampling_mode = args2.sampling_mode # @param {type:"string"}
+ # @markdown ####**Coherency Settings:**
+ # @markdown `frame_scale` tries to guide the new frame to look like the old one. A good default is 1500.
+ frames_scale = args2.frames_scale # 1500 #@param{type: 'integer'}
+ # @markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.
+ frames_skip_steps = (
+ args2.frames_skip_steps
+ ) #'60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}
+
+ if args2.turbo_mode == 1:
+ turbo_mode = True # @param {type:"boolean"}
+ else:
+ turbo_mode = False # @param {type:"boolean"}
+ turbo_steps = args2.turbo_steps # "3" #@param ["2","3","4"] {type:'string'}
+ # @markdown ---
+
+ def parse_key_frames(string, prompt_parser=None):
+ """Given a string representing frame numbers paired with parameter values at that frame,
+ return a dictionary with the frame numbers as keys and the parameter values as the values.
+
+ Parameters
+ ----------
+ string: string
+ Frame numbers paired with parameter values at that frame number, in the format
+ 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'
+ prompt_parser: function or None, optional
+ If provided, prompt_parser will be applied to each string of parameter values.
+
+ Returns
+ -------
+ dict
+ Frame numbers as keys, parameter values at that frame number as values
+
+ Raises
+ ------
+ RuntimeError
+ If the input string does not match the expected format.
+
+ Examples
+ --------
+ >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
+ {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}
+
+ >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower()))
+ {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
+ """
+ import re
+
+ pattern = r"((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])"
+ frames = dict()
+ for match_object in re.finditer(pattern, string):
+ frame = int(match_object.groupdict()["frame"])
+ param = match_object.groupdict()["param"]
+ if prompt_parser:
+ frames[frame] = prompt_parser(param)
+ else:
+ frames[frame] = param
+
+ if frames == {} and len(string) != 0:
+ raise RuntimeError("Key Frame string not correctly formatted")
+ return frames
+
+ def get_inbetweens(key_frames, integer=False):
+ """Given a dict with frame numbers as keys and a parameter value as values,
+ return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
+ Any values not provided in the input dict are calculated by linear interpolation between
+ the values of the previous and next provided frames. If there is no previous provided frame, then
+ the value is equal to the value of the next provided frame, or if there is no next provided frame,
+ then the value is equal to the value of the previous provided frame. If no frames are provided,
+ all frame values are NaN.
+
+ Parameters
+ ----------
+ key_frames: dict
+ A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
+ integer: Bool, optional
+ If True, the values of the output series are converted to integers.
+ Otherwise, the values are floats.
+
+ Returns
+ -------
+ pd.Series
+ A Series with length max_frames representing the parameter values for each frame.
+
+ Examples
+ --------
+ >>> max_frames = 5
+ >>> get_inbetweens({1: 5, 3: 6})
+ 0 5.0
+ 1 5.0
+ 2 5.5
+ 3 6.0
+ 4 6.0
+ dtype: float64
+
+ >>> get_inbetweens({1: 5, 3: 6}, integer=True)
+ 0 5
+ 1 5
+ 2 5
+ 3 6
+ 4 6
+ dtype: int64
+ """
+ key_frame_series = pd.Series([np.nan for a in range(max_frames)])
+
+ for i, value in key_frames.items():
+ key_frame_series[i] = value
+ key_frame_series = key_frame_series.astype(float)
+
+ interp_method = interp_spline
+
+ if interp_method == "Cubic" and len(key_frames.items()) <= 3:
+ interp_method = "Quadratic"
+
+ if interp_method == "Quadratic" and len(key_frames.items()) <= 2:
+ interp_method = "Linear"
+
+ key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
+ key_frame_series[max_frames - 1] = key_frame_series[
+ key_frame_series.last_valid_index()
+ ]
+ # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')
+ key_frame_series = key_frame_series.interpolate(
+ method=interp_method.lower(), limit_direction="both"
+ )
+ if integer:
+ return key_frame_series.astype(int)
+ return key_frame_series
+
+ def split_prompts(prompts):
+ prompt_series = pd.Series([np.nan for a in range(max_frames)])
+ for i, prompt in prompts.items():
+ prompt_series[i] = prompt
+ # prompt_series = prompt_series.astype(str)
+ prompt_series = prompt_series.ffill().bfill()
+ return prompt_series
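+ # Illustrative example: with max_frames = 300, split_prompts({0: "a castle",
+ # 100: "a forest"}) yields "a castle" for frames 0-99 and "a forest" from
+ # frame 100 on (forward-fill, with back-fill covering any leading gap).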
+
+ if key_frames:
+ # All eight schedules share the same parse-with-fallback logic, so one
+ # helper replaces eight copies of the try/except block.
+ def keyframe_series(name, value):
+ """Parse one key-frame schedule; if the string is not in key-frame
+ format, fall back to interpreting it as a single frame-0 value."""
+ try:
+ return value, get_inbetweens(parse_key_frames(value))
+ except RuntimeError:
+ print(
+ "WARNING: You have selected to use key frames, but you have not "
+ f"formatted `{name}` correctly for key frames.\n"
+ f'Attempting to interpret `{name}` as "0: ({value})"\n'
+ "Please read the instructions to find out how to use key frames "
+ "correctly.\n"
+ )
+ value = f"0: ({value})"
+ return value, get_inbetweens(parse_key_frames(value))
+
+ angle, angle_series = keyframe_series("angle", angle)
+ zoom, zoom_series = keyframe_series("zoom", zoom)
+ translation_x, translation_x_series = keyframe_series("translation_x", translation_x)
+ translation_y, translation_y_series = keyframe_series("translation_y", translation_y)
+ translation_z, translation_z_series = keyframe_series("translation_z", translation_z)
+ rotation_3d_x, rotation_3d_x_series = keyframe_series("rotation_3d_x", rotation_3d_x)
+ rotation_3d_y, rotation_3d_y_series = keyframe_series("rotation_3d_y", rotation_3d_y)
+ rotation_3d_z, rotation_3d_z_series = keyframe_series("rotation_3d_z", rotation_3d_z)
+
+ else:
+ angle = float(angle)
+ zoom = float(zoom)
+ translation_x = float(translation_x)
+ translation_y = float(translation_y)
+ translation_z = float(translation_z)
+ rotation_3d_x = float(rotation_3d_x)
+ rotation_3d_y = float(rotation_3d_y)
+ rotation_3d_z = float(rotation_3d_z)
+
+ """### Extra Settings
+ Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling
+ """
+
+ # @markdown ####**Saving:**
+
+ intermediate_saves = 0 # @param{type: 'raw'}
+ intermediates_in_subfolder = True # @param{type: 'boolean'}
+ # @markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps
+
+ # @markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.
+
+ # @markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)
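+ # Worked example (illustrative): steps = 250, skip_steps = 0 and
+ # intermediate_saves = 2 give steps_per_checkpoint = (250 - 0 - 1) // 3 = 83,
+ # landing the partial saves near the 33% and 66% marks described above.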
+
+ if type(intermediate_saves) is not list:
+ if intermediate_saves:
+ steps_per_checkpoint = math.floor(
+ (steps - skip_steps - 1) // (intermediate_saves + 1)
+ )
+ steps_per_checkpoint = (
+ steps_per_checkpoint if steps_per_checkpoint > 0 else 1
+ )
+ print(f"Will save every {steps_per_checkpoint} steps")
+ else:
+ steps_per_checkpoint = steps + 10
+ else:
+ steps_per_checkpoint = None
+
+ if intermediate_saves and intermediates_in_subfolder is True:
+ partialFolder = f"{batchFolder}/partials"
+ createPath(partialFolder)
+
+ # @markdown ---
+
+ # @markdown ####**SuperRes Sharpening:**
+ # @markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*
+ sharpen_preset = "Off" # @param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']
+ keep_unsharp = True # @param{type: 'boolean'}
+
+ if sharpen_preset != "Off" and keep_unsharp is True:
+ unsharpenFolder = f"{batchFolder}/unsharpened"
+ createPath(unsharpenFolder)
+
+ # @markdown ---
+
+ # @markdown ####**Advanced Settings:**
+ # @markdown *There are a few extra advanced settings available if you double click this cell.*
+
+ # @markdown *Perlin init will replace your init, so uncheck if using one.*
+
+ if args2.perlin_init == 1:
+ perlin_init = True # @param{type: 'boolean'}
+ else:
+ perlin_init = False # @param{type: 'boolean'}
+ perlin_mode = args2.perlin_mode #'mixed' #@param ['mixed', 'color', 'gray']
+
+ set_seed = "random_seed" # @param{type: 'string'}
+ eta = args2.eta # @param{type: 'number'}
+ clamp_grad = True # @param{type: 'boolean'}
+ clamp_max = args2.clampmax # @param{type: 'number'}
+
+ ### EXTRA ADVANCED SETTINGS:
+ randomize_class = True
+ if args2.denoised == 1:
+ clip_denoised = True
+ else:
+ clip_denoised = False
+ fuzzy_prompt = False
+ rand_mag = 0.05
+
+ # @markdown ---
+
+ # @markdown ####**Cutn Scheduling:**
+ # @markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400/1000 steps, then 20 for the last 600/1000
+
+ # @markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn.
+
+ cut_overview = "[12]*400+[4]*600" # @param {type: 'string'}
+ cut_innercut = "[4]*400+[12]*600" # @param {type: 'string'}
+ cut_ic_pow = 1 # @param {type: 'number'}
+ cut_icgray_p = "[0.2]*400+[0]*600" # @param {type: 'string'}
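+ # Illustrative reading of the defaults: eval("[12]*400+[4]*600") builds a
+ # 1000-entry per-step schedule, so early steps take 12 overview + 4 inner
+ # cuts and late steps 4 + 12, i.e. a constant 16 cuts per step here.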
+
+ """###Prompts
+ `animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one.
+ """
+
+ """
+ text_prompts = {
+ 0: ["A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.", "yellow color scheme"],
+ 100: ["This set of prompts start at frame 100","This prompt has weight five:5"],
+ }
+ """
+
+ text_prompts = {0: [phrase.strip() for phrase in args2.prompt.split("|")]}
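+ # Illustrative example: args2.prompt = "a lighthouse | yellow color scheme:5"
+ # becomes {0: ["a lighthouse", "yellow color scheme:5"]}, one frame-0 prompt
+ # list; a ":5" suffix carries a prompt weight, as in the commented example above.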
+
+ image_prompts = {
+ # 0:['ImagePromptsWorkButArentVeryGood.png:2',],
+ }
+
+ """# 4. Diffuse!"""
+
+ # @title Do the Run!
+ # @markdown `n_batches` is ignored with animation modes.
+ display_rate = args2.update # @param{type: 'number'}
+ n_batches = 1 # @param{type: 'number'}
+
+ batch_size = 1
+
+ def move_files(start_num, end_num, old_folder, new_folder):
+ for i in range(start_num, end_num):
+ old_file = old_folder + f"/{batch_name}({batchNum})_{i:04}.png"
+ new_file = new_folder + f"/{batch_name}({batchNum})_{i:04}.png"
+ os.rename(old_file, new_file)
+
+ # @markdown ---
+
+ resume_run = False # @param{type: 'boolean'}
+ run_to_resume = "latest" # @param{type: 'string'}
+ resume_from_frame = "latest" # @param{type: 'string'}
+ retain_overwritten_frames = False # @param{type: 'boolean'}
+ if retain_overwritten_frames is True:
+ retainFolder = f"{batchFolder}/retained"
+ createPath(retainFolder)
+
+ skip_step_ratio = int(frames_skip_steps.rstrip("%")) / 100
+ calc_frames_skip_steps = math.floor(steps * skip_step_ratio)
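+ # Worked example (illustrative): frames_skip_steps = "60%" and steps = 250
+ # give skip_step_ratio = 0.6 and calc_frames_skip_steps = 150, so each
+ # animation frame after the first diffuses only the remaining 100 steps.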
+
+ if steps <= calc_frames_skip_steps:
+ sys.exit("ERROR: You can't skip more steps than your total steps")
+
+ """
+ if resume_run:
+ if run_to_resume == 'latest':
+ try:
+ batchNum
+ except:
+ batchNum = len(glob(f"{batchFolder}/{batch_name}(*)_settings.txt"))-1
+ else:
+ batchNum = int(run_to_resume)
+ if resume_from_frame == 'latest':
+ start_frame = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png"))
+ else:
+ start_frame = int(resume_from_frame)+1
+ if retain_overwritten_frames is True:
+ existing_frames = len(glob(batchFolder+f"/{batch_name}({batchNum})_*.png"))
+ frames_to_save = existing_frames - start_frame
+ print(f'Moving {frames_to_save} frames to the Retained folder')
+ move_files(start_frame, existing_frames, batchFolder, retainFolder)
+ else:
+ """
+ start_frame = 0
+ batchNum = 1
+ """
+ batchNum = len(glob(batchFolder+"/*.txt"))
+ while path.isfile(f"{batchFolder}/{batch_name}({batchNum})_settings.txt") is True or path.isfile(f"{batchFolder}/{batch_name}-{batchNum}_settings.txt") is True:
+ batchNum += 1
+ """
+ # print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')
+
+ if set_seed == "random_seed":
+ random.seed()
+ seed = random.randint(0, 2**32)
+ # print(f'Using seed: {seed}')
+ else:
+ seed = int(set_seed)
+
+ args = {
+ "batchNum": batchNum,
+ "prompts_series": split_prompts(text_prompts) if text_prompts else None,
+ "image_prompts_series": split_prompts(image_prompts) if image_prompts else None,
+ "seed": seed,
+ "display_rate": display_rate,
+ "n_batches": n_batches if animation_mode == "None" else 1,
+ "batch_size": batch_size,
+ "batch_name": batch_name,
+ "steps": steps,
+ "width_height": width_height,
+ "clip_guidance_scale": clip_guidance_scale,
+ "tv_scale": tv_scale,
+ "range_scale": range_scale,
+ "sat_scale": sat_scale,
+ "cutn_batches": cutn_batches,
+ "init_image": init_image,
+ "init_scale": init_scale,
+ "skip_steps": skip_steps,
+ "sharpen_preset": sharpen_preset,
+ "keep_unsharp": keep_unsharp,
+ "side_x": side_x,
+ "side_y": side_y,
+ "timestep_respacing": timestep_respacing,
+ "diffusion_steps": diffusion_steps,
+ "animation_mode": animation_mode,
+ "video_init_path": video_init_path,
+ "extract_nth_frame": extract_nth_frame,
+ "key_frames": key_frames,
+ "max_frames": max_frames if animation_mode != "None" else 1,
+ "interp_spline": interp_spline,
+ "start_frame": start_frame,
+ "angle": angle,
+ "zoom": zoom,
+ "translation_x": translation_x,
+ "translation_y": translation_y,
+ "translation_z": translation_z,
+ "rotation_3d_x": rotation_3d_x,
+ "rotation_3d_y": rotation_3d_y,
+ "rotation_3d_z": rotation_3d_z,
+ "midas_depth_model": midas_depth_model,
+ "midas_weight": midas_weight,
+ "near_plane": near_plane,
+ "far_plane": far_plane,
+ "fov": fov,
+ "padding_mode": padding_mode,
+ "sampling_mode": sampling_mode,
+ "angle_series": angle_series,
+ "zoom_series": zoom_series,
+ "translation_x_series": translation_x_series,
+ "translation_y_series": translation_y_series,
+ "translation_z_series": translation_z_series,
+ "rotation_3d_x_series": rotation_3d_x_series,
+ "rotation_3d_y_series": rotation_3d_y_series,
+ "rotation_3d_z_series": rotation_3d_z_series,
+ "frames_scale": frames_scale,
+ "calc_frames_skip_steps": calc_frames_skip_steps,
+ "skip_step_ratio": skip_step_ratio,
+ "calc_frames_skip_steps": calc_frames_skip_steps,
+ "text_prompts": text_prompts,
+ "image_prompts": image_prompts,
+ "cut_overview": eval(cut_overview),
+ "cut_innercut": eval(cut_innercut),
+ "cut_ic_pow": cut_ic_pow,
+ "cut_icgray_p": eval(cut_icgray_p),
+ "intermediate_saves": intermediate_saves,
+ "intermediates_in_subfolder": intermediates_in_subfolder,
+ "steps_per_checkpoint": steps_per_checkpoint,
+ "perlin_init": perlin_init,
+ "perlin_mode": perlin_mode,
+ "set_seed": set_seed,
+ "eta": eta,
+ "clamp_grad": clamp_grad,
+ "clamp_max": clamp_max,
+ "skip_augs": skip_augs,
+ "randomize_class": randomize_class,
+ "clip_denoised": clip_denoised,
+ "fuzzy_prompt": fuzzy_prompt,
+ "rand_mag": rand_mag,
+ }
+
+ args = SimpleNamespace(**args)
+
+ print("Prepping model...")
+ model, diffusion = create_model_and_diffusion(**model_config)
+ model.load_state_dict(
+ torch.load(
+ f"{DefaultPaths.model_path}/{diffusion_model}.pt", map_location="cpu"
+ )
+ )
+ model.requires_grad_(False).eval().to(device)
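+ # Gradients are re-enabled only on the attention (qkv), normalization and
+ # projection parameters; presumably the minimum needed for the CLIP-guided
+ # cond_fn to backpropagate through the UNet while keeping memory use down.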
+ for name, param in model.named_parameters():
+ if "qkv" in name or "norm" in name or "proj" in name:
+ param.requires_grad_()
+ if model_config["use_fp16"]:
+ model.convert_to_fp16()
+
+ sys.stdout.write("Starting ...\n")
+ sys.stdout.flush()
+ status.write(f"Starting ...\n")
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ try:
+ do_run()
+ # except st.script_runner.StopException as e:
+ # print("stopped here (a bit out)")
+ # pass
+ except KeyboardInterrupt:
+ pass
+ finally:
+ gc.collect()
+ torch.cuda.empty_cache()