import re
import pandas as pd
import numpy as np
from functools import reduce
from einops import rearrange
import sys
from skimage.exposure import match_histograms
import random
from pytorch_lightning import seed_everything
import time
import cv2
import torch
from torch import autocast
from PIL import Image
import os
import requests
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
# 1. Download the stable diffusion repositories and set paths
os.system("git clone https://github.com/kael558/stable-diffusion-cpu")
os.system("git clone https://github.com/shariqfarooq123/AdaBins.git")
os.system("git clone https://github.com/isl-org/MiDaS.git")
os.system("git clone https://github.com/MSFTserver/pytorch3d-lite.git")
os.system("git clone https://github.com/deforum/k-diffusion/")

# Blank out the k-diffusion package __init__ so its submodules can be imported
# without pulling in the package's optional dependencies
with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:
    f.write('')
sys.path.extend([
    './taming-transformers',
    './clip',
    'stable-diffusion/',
    'k-diffusion',
    'pytorch3d-lite',
    'AdaBins',
    'MiDaS',
])
from helpers import sampler_fn
from k_diffusion.external import CompVisDenoiser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
# 2. Model loading helper
def load_model_from_config(config, ckpt, verbose=False, half_precision=False):
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=map_location)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd_var = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd_var, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    if half_precision:
        model = model.half()
    model.eval()
    return model
print('Model loading...')

'''
# Alternative manual download, kept for reference.
# Use your own Hugging Face username and access token in the URL below.
models_path = "models"
model_checkpoint = "sd-v1-4.ckpt"
ckpt_path = os.path.join(models_path, model_checkpoint)
if os.path.exists(ckpt_path):
    print(f"{ckpt_path} exists")
else:
    print(f"Attempting to download {model_checkpoint}...this may take a while")
    url = 'https://<username>:<hf_token>@huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt'
    ckpt_request = requests.get(url)
    print('Model downloaded.')
    with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:
        model_file.write(ckpt_request.content)
'''

ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4.ckpt")
ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
local_config = OmegaConf.load(ckpt_config_path)
model = load_model_from_config(local_config, ckpt_path)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
print('Model loaded.')
class DeformAnimKeys():
    def __init__(self, anim_args):
        self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames)
        self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames)
        self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames)
        self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames)
        self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames)
        self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames)
        self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames)
        self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames)
        self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames)
        self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames)
        self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames)
        self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames)
        self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames)
        self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames)
        self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames)
def check_is_number(value):
    # Matches plain numeric literals like "1", "-2.5", ".5" (but not math expressions)
    float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$'
    return re.match(float_pattern, value)
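# Quick sanity check of the pattern (illustrative, not executed):
#   check_is_number("1.04")                 -> truthy match object
#   check_is_number("10*sin(2*3.14*t/10)")  -> None, so it falls through to numexpr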
def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'):
    import numexpr
    key_frame_series = pd.Series([np.nan for a in range(max_frames)])
    for i in range(0, max_frames):
        if i in key_frames:
            value = key_frames[i]
            if check_is_number(value):
                # plain number: store it and let interpolation fill the gaps
                key_frame_series[i] = value
            else:
                # math expression: evaluate it with t bound to the frame index
                t = i
                key_frame_series[i] = numexpr.evaluate(value)
    key_frame_series = key_frame_series.astype(float)

    # fall back to simpler interpolation when there are too few keyframes
    if interp_method == 'Cubic' and len(key_frames) <= 3:
        interp_method = 'Quadratic'
    if interp_method == 'Quadratic' and len(key_frames) <= 2:
        interp_method = 'Linear'

    key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
    key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]
    key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both')
    if integer:
        return key_frame_series.astype(int)
    return key_frame_series
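# Illustrative example (not executed at import time): a two-keyframe zoom
# schedule interpolated linearly over 10 frames.
#   series = get_inbetweens(parse_key_frames("0:(1.0), 9:(2.0)"), 10)
#   series[0] -> 1.0, series[9] -> 2.0, with evenly spaced values in between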
def parse_key_frames(string, prompt_parser=None):
    # Parses "frame: (value)" pairs into a {frame: value} dict. Because values
    # may be math expressions that themselves contain parentheses (e.g. sin(t)),
    # the regex captures everything between the outermost parentheses, where the
    # closing one is followed by a comma or the end of the string.
    pattern = r'((?P<frame>[0-9]+):[\s]*\((?P<param>[\S\s]*?)\)([,][\s]?|[\s]?$))'
    frames = dict()
    for match_object in re.finditer(pattern, string):
        frame = int(match_object.groupdict()['frame'])
        param = match_object.groupdict()['param']
        if prompt_parser:
            frames[frame] = prompt_parser(param)
        else:
            frames[frame] = param
    if frames == {} and len(string) != 0:
        raise RuntimeError('Key Frame string not correctly formatted')
    return frames
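# Example of the accepted keyframe syntax (illustrative):
#   parse_key_frames("0:(0), 50:(10*sin(2*3.14*t/10))")
#   -> {0: '0', 50: '10*sin(2*3.14*t/10)'}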
# https://en.wikipedia.org/wiki/Rotation_matrix
def getRotationMatrixManual(rotation_angles):
    rotation_angles = [np.deg2rad(x) for x in rotation_angles]
    phi = rotation_angles[0]    # around x
    gamma = rotation_angles[1]  # around y
    theta = rotation_angles[2]  # around z

    # X rotation
    Rphi = np.eye(4, 4)
    sp = np.sin(phi)
    cp = np.cos(phi)
    Rphi[1, 1] = cp
    Rphi[2, 2] = Rphi[1, 1]
    Rphi[1, 2] = -sp
    Rphi[2, 1] = sp

    # Y rotation
    Rgamma = np.eye(4, 4)
    sg = np.sin(gamma)
    cg = np.cos(gamma)
    Rgamma[0, 0] = cg
    Rgamma[2, 2] = Rgamma[0, 0]
    Rgamma[0, 2] = sg
    Rgamma[2, 0] = -sg

    # Z rotation (in-image-plane)
    Rtheta = np.eye(4, 4)
    st = np.sin(theta)
    ct = np.cos(theta)
    Rtheta[0, 0] = ct
    Rtheta[1, 1] = Rtheta[0, 0]
    Rtheta[0, 1] = -st
    Rtheta[1, 0] = st

    R = reduce(lambda x, y: np.matmul(x, y), [Rphi, Rgamma, Rtheta])
    return R
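# Sanity check (illustrative): a pure in-plane rotation leaves only the
# z-rotation block, so getRotationMatrixManual([0, 0, 90]) is approximately
#   [[0,-1,0,0],[1,0,0,0],[0,0,1,0],[0,0,0,1]]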
def getPoints_for_PerspectiveTransformEstimation(ptsIn, ptsOut, W, H, sidelength):
    ptsIn2D = ptsIn[0, :]
    ptsOut2D = ptsOut[0, :]
    ptsOut2Dlist = []
    ptsIn2Dlist = []
    for i in range(0, 4):
        ptsOut2Dlist.append([ptsOut2D[i, 0], ptsOut2D[i, 1]])
        ptsIn2Dlist.append([ptsIn2D[i, 0], ptsIn2D[i, 1]])
    pin = np.array(ptsIn2Dlist) + [W / 2., H / 2.]
    pout = (np.array(ptsOut2Dlist) + [1., 1.]) * (0.5 * sidelength)
    pin = pin.astype(np.float32)
    pout = pout.astype(np.float32)
    return pin, pout
def warpMatrix(W, H, theta, phi, gamma, scale, fV):
    fVhalf = np.deg2rad(fV / 2.)
    d = np.sqrt(W * W + H * H)
    sideLength = scale * d / np.cos(fVhalf)
    h = d / (2.0 * np.sin(fVhalf))
    n = h - (d / 2.0)
    f = h + (d / 2.0)

    # Translation along Z-axis by -h
    T = np.eye(4, 4)
    T[2, 3] = -h

    # Rotation matrices around x, y, z
    R = getRotationMatrixManual([phi, gamma, theta])

    # Projection matrix
    P = np.eye(4, 4)
    P[0, 0] = 1.0 / np.tan(fVhalf)
    P[1, 1] = P[0, 0]
    P[2, 2] = -(f + n) / (f - n)
    P[2, 3] = -(2.0 * f * n) / (f - n)
    P[3, 2] = -1.0

    # pythonic matrix multiplication
    F = reduce(lambda x, y: np.matmul(x, y), [P, T, R])

    # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data this way.
    # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3);
    ptsIn = np.array([[
        [-W / 2., H / 2., 0.], [W / 2., H / 2., 0.], [W / 2., -H / 2., 0.], [-W / 2., -H / 2., 0.]
    ]])
    ptsOut = cv2.perspectiveTransform(ptsIn, F)

    ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTransformEstimation(ptsIn, ptsOut, W, H, sideLength)
    # must be float32, otherwise OpenCV throws an error
    assert (ptsInPt2f.dtype == np.float32)
    assert (ptsOutPt2f.dtype == np.float32)
    M33 = cv2.getPerspectiveTransform(ptsInPt2f, ptsOutPt2f)
    return M33, sideLength
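# Illustrative usage (the frame variable and output size are assumptions, not
# how the app necessarily applies the matrix):
#   M33, side = warpMatrix(512, 512, theta=0, phi=5, gamma=0, scale=1.0, fV=53)
#   flipped = cv2.warpPerspective(frame, M33, (int(side), int(side)))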
def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
    angle = keys.angle_series[frame_idx]
    zoom = keys.zoom_series[frame_idx]
    translation_x = keys.translation_x_series[frame_idx]
    translation_y = keys.translation_y_series[frame_idx]

    center = (args.W // 2, args.H // 2)
    trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])
    rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
    trans_mat = np.vstack([trans_mat, [0, 0, 1]])
    rot_mat = np.vstack([rot_mat, [0, 0, 1]])
    xform = np.matmul(rot_mat, trans_mat)
    return cv2.warpPerspective(
        prev_img_cv2,
        xform,
        (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
        borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
    )
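# Illustrative usage, assuming args/anim_args namespaces built as in the
# sketch after DeforumAnimArgs below:
#   keys = DeformAnimKeys(anim_args)
#   next_frame = anim_frame_warp_2d(prev_frame, args, anim_args, keys, frame_idx=3)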
def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
    # uint8 HWC in [0, 255] -> float16 NCHW in [-1, 1]
    sample = ((sample.astype(float) / 255.0) * 2) - 1
    sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)
    sample = torch.from_numpy(sample)
    return sample

def sample_to_cv2(sample: torch.Tensor, dtype=np.uint8) -> np.ndarray:
    # float NCHW in [-1, 1] -> HWC in [0, 255]
    sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32)
    sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)
    sample_int8 = (sample_f32 * 255)
    return sample_int8.astype(dtype)
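# The two helpers above are inverses up to quantization: sample_from_cv2 maps
# uint8 HWC [0, 255] to float16 NCHW [-1, 1], and sample_to_cv2 maps back.
#   img = np.zeros((64, 64, 3), dtype=np.uint8)
#   sample_to_cv2(sample_from_cv2(img)).shape  # -> (64, 64, 3)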
def maintain_colors(prev_img, color_match_sample, mode):
    # Note: multichannel= was replaced by channel_axis= in scikit-image >= 0.19
    if mode == 'Match Frame 0 RGB':
        return match_histograms(prev_img, color_match_sample, multichannel=True)
    elif mode == 'Match Frame 0 HSV':
        prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)
        color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)
        matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)
        return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)
    else:  # Match Frame 0 LAB
        prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)
        color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)
        matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)
        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)
def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:
    return sample + torch.randn(sample.shape, device=sample.device) * noise_amt
def next_seed(args):
    if args.seed_behavior == 'iter':
        args.seed += 1
    elif args.seed_behavior == 'fixed':
        pass  # always keep seed the same
    else:
        args.seed = random.randint(0, 2**32 - 1)
    return args.seed
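# Seed behavior (illustrative): 'iter' increments the seed on each call
# (42, 43, 44, ...), 'fixed' keeps it constant, and any other value draws
# a fresh random 32-bit seed.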
def generate(args, return_c=False):
    seed_everything(args.seed)

    sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)
    model_wrap = CompVisDenoiser(model)
    batch_size = args.n_samples
    prompt = args.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]
    precision_scope = autocast

    mask = None
    t_enc = int((1.0 - args.strength) * args.steps)
    init_latent = None

    # Noise schedule for the k-diffusion samplers (used for masking)
    k_sigmas = model_wrap.get_sigmas(args.steps)
    k_sigmas = k_sigmas[len(k_sigmas) - t_enc - 1:]

    if args.sampler in ['plms', 'ddim']:
        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)

    callback = SamplerCallback(args=args,
                               mask=mask,
                               init_latent=init_latent,
                               sigmas=k_sigmas,
                               sampler=sampler,
                               verbose=False).callback

    results = []
    with torch.no_grad():
        # autocast takes the actual device type, so this also works on CPU-only hosts
        with precision_scope(device.type):
            with model.ema_scope():
                for prompts in data:
                    uc = model.get_learned_conditioning(batch_size * [""])
                    c = model.get_learned_conditioning(prompts)
                    if args.scale == 1.0:
                        uc = None
                    if args.init_c is not None:
                        c = args.init_c

                    if args.sampler in ["klms", "dpm2", "dpm2_ancestral", "heun", "euler", "euler_ancestral"]:
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=model_wrap,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=device,
                            cb=callback)
                    else:
                        # args.sampler is 'plms' or 'ddim'
                        if init_latent is not None and args.strength > 0:
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
                        else:
                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)
                        if args.sampler == 'ddim':
                            samples = sampler.decode(z_enc,
                                                     c,
                                                     t_enc,
                                                     unconditional_guidance_scale=args.scale,
                                                     unconditional_conditioning=uc,
                                                     img_callback=callback)
                        elif args.sampler == 'plms':  # no "decode" function in plms, so use "sample"
                            shape = [args.C, args.H // args.f, args.W // args.f]
                            samples, _ = sampler.sample(S=args.steps,
                                                        conditioning=c,
                                                        batch_size=args.n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=args.scale,
                                                        unconditional_conditioning=uc,
                                                        eta=args.ddim_eta,
                                                        x_T=z_enc,
                                                        img_callback=callback)
                        else:
                            raise Exception(f"Sampler {args.sampler} not recognised.")

                    x_samples = model.decode_first_stage(samples)
                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if return_c:
                        results.append(c.clone())

                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        image = Image.fromarray(x_sample.astype(np.uint8))
                        results.append(image)
    return results
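# Minimal usage sketch (hypothetical values; DeforumArgs is defined below and
# returns a plain dict, so wrap it for attribute access):
#   from types import SimpleNamespace
#   args = SimpleNamespace(**DeforumArgs())
#   args.prompt = "a watercolor fox"
#   args.seed = 42
#   images = generate(args)  # -> list of PIL.Image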
def get_output_folder(output_path, batch_folder):
    out_path = os.path.join(output_path, time.strftime('%Y-%m'))
    if batch_folder != "":
        out_path = os.path.join(out_path, batch_folder)
    os.makedirs(out_path, exist_ok=True)
    return out_path
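# Example: get_output_folder("outputs", "StableFun") creates and returns
# something like "outputs/2022-09/StableFun" (the year-month comes from the clock).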
def DeforumArgs():
    #@markdown **Image Settings**
    W = 512 #@param
    H = 512 #@param
    W, H = map(lambda x: x - x % 64, (W, H))  # round down to integer multiple of 64

    #@markdown **Sampling Settings**
    seed = -1 #@param
    sampler = 'dpm2_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"]
    steps = 50 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    dynamic_threshold = None
    static_threshold = None

    #@markdown **Save & Display Settings**
    save_samples = True #@param {type:"boolean"}
    save_settings = True #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}
    save_sample_per_step = False #@param {type:"boolean"}
    show_sample_per_step = False #@param {type:"boolean"}

    #@markdown **Prompt Settings**
    prompt_weighting = False #@param {type:"boolean"}
    normalize_prompt_weights = True #@param {type:"boolean"}
    log_weighted_subprompts = False #@param {type:"boolean"}

    #@markdown **Batch Settings**
    n_batch = 1 #@param
    batch_name = "StableFun" #@param {type:"string"}
    filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"]
    seed_behavior = "fixed" #@param ["iter","fixed","random"]
    make_grid = False #@param {type:"boolean"}
    grid_rows = 2 #@param

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.0 #@param {type:"number"}
    strength_0_no_init = True  # Set the strength to 0 automatically when no init image is used
    init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}
    # Whiter areas of the mask are areas that change more
    use_mask = False #@param {type:"boolean"}
    use_alpha_as_mask = False  # use the alpha channel of the init image as the mask
    mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"}
    invert_mask = False #@param {type:"boolean"}
    # Adjust mask image; 1.0 is no adjustment. Should be positive numbers.
    mask_brightness_adjust = 1.0 #@param {type:"number"}
    mask_contrast_adjust = 1.0 #@param {type:"number"}
    # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding
    overlay_mask = True  # {type:"boolean"}
    # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)
    mask_overlay_blur = 5  # {type:"number"}

    n_samples = 1  # doesn't do anything
    precision = 'autocast'
    C = 4
    f = 8

    prompt = ""
    timestring = ""
    init_latent = None
    init_sample = None
    init_c = None

    return locals()
def DeforumAnimArgs():
    #@markdown ####**Animation:**
    animation_mode = '2D' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
    max_frames = 100 #@param {type:"number"}
    border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}

    #@markdown ####**Motion Parameters:**
    angle = "0:(0)" #@param {type:"string"}
    zoom = "0:(1.04)" #@param {type:"string"}
    translation_x = "0:(10*sin(2*3.14*t/10))" #@param {type:"string"}
    translation_y = "0:(0)" #@param {type:"string"}
    translation_z = "0:(10)" #@param {type:"string"}
    rotation_3d_x = "0:(0)" #@param {type:"string"}
    rotation_3d_y = "0:(0)" #@param {type:"string"}
    rotation_3d_z = "0:(0)" #@param {type:"string"}
    flip_2d_perspective = False #@param {type:"boolean"}
    perspective_flip_theta = "0:(0)" #@param {type:"string"}
    perspective_flip_phi = "0:(t%15)" #@param {type:"string"}
    perspective_flip_gamma = "0:(0)" #@param {type:"string"}
    perspective_flip_fv = "0:(53)" #@param {type:"string"}
    noise_schedule = "0: (0.02)" #@param {type:"string"}
    strength_schedule = "0: (0.65)" #@param {type:"string"}
    contrast_schedule = "0: (1.0)" #@param {type:"string"}

    #@markdown ####**Coherence:**
    color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
    diffusion_cadence = '7' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}

    #@markdown ####**3D Depth Warping:**
    use_depth_warping = True #@param {type:"boolean"}
    midas_weight = 0.3 #@param {type:"number"}
    near_plane = 200
    far_plane = 10000
    fov = 40 #@param {type:"number"}
    padding_mode = 'border' #@param ['border', 'reflection', 'zeros'] {type:'string'}
    sampling_mode = 'bicubic' #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
    save_depth_maps = False #@param {type:"boolean"}

    #@markdown ####**Video Input:**
    video_init_path = '/content/video_in.mp4' #@param {type:"string"}
    extract_nth_frame = 1 #@param {type:"number"}
    overwrite_extracted_frames = True #@param {type:"boolean"}
    use_mask_video = False #@param {type:"boolean"}
    video_mask_path = '/content/video_in.mp4' #@param {type:"string"}

    #@markdown ####**Interpolation:**
    interpolate_key_frames = False #@param {type:"boolean"}
    interpolate_x_frames = 4 #@param {type:"number"}

    #@markdown ####**Resume Animation:**
    resume_from_timestring = False #@param {type:"boolean"}
    resume_timestring = "20220829210106" #@param {type:"string"}

    return locals()
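# Both settings functions return plain dicts via locals(); downstream code
# expects attribute access. A minimal sketch using the defaults above:
#   from types import SimpleNamespace
#   args = SimpleNamespace(**DeforumArgs())
#   anim_args = SimpleNamespace(**DeforumAnimArgs())
#   keys = DeformAnimKeys(anim_args)  # per-frame motion/noise schedules
#   keys.zoom_series[0]               # -> 1.04 (constant; only one keyframe)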
#
# Callback functions
#
class SamplerCallback(object):
    # Creates the callback function to be passed into the samplers for each step
    def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose
        self.batch_size = args.n_samples
        #self.save_sample_per_step = args.save_sample_per_step
        #self.show_sample_per_step = args.show_sample_per_step
        #self.paths_to_image_steps = [os.path.join(args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples)]
        #if self.save_sample_per_step:
        #    for path in self.paths_to_image_steps:
        #        os.makedirs(path, exist_ok=True)

        self.step_index = 0

        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=device)

        self.mask_schedule = None
        if sigmas is not None:
            if len(sigmas) > 0:
                self.mask_schedule, _ = torch.sort(sigmas / torch.max(sigmas))
            else:
                self.mask = None  # no mask needed if no steps (usually happens because strength==1.0)

        if self.sampler_name in ["plms", "ddim"]:
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
            # Callback function formatted for CompVis latent diffusion samplers
            self.callback = self.img_callback_
        else:
            # Default callback function uses k-diffusion sampler variables
            self.callback = self.k_callback_

        #self.verbose_print = print if verbose else lambda *args, **kwargs: None

    # The callback function is applied to the image at each step
    def dynamic_thresholding_(self, img, threshold):
        # Dynamic thresholding from the Imagen paper (May 2022)
        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim)))
        s = np.max(np.append(s, 1.0))
        img.clamp_(-1 * s, s)
        img.div_(s)
    # Callback for samplers in the k-diffusion repo, called thus:
    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
    def k_callback_(self, args_dict):
        self.step_index = args_dict['i']
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)
        if self.static_threshold is not None:
            args_dict['x'].clamp_(-1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + args_dict['x'] * torch.where(is_masked, 0, 1)
            args_dict['x'].copy_(new_img)
        #self.view_sample_step(args_dict['denoised'], "x0_pred")
    # Callback for CompVis samplers
    # Function that is called on the image (img) and step (i) at each step
    def img_callback_(self, img, i):
        self.step_index = i
        # Thresholding functions
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            img.clamp_(-1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            i_inv = len(self.sigmas) - i - 1
            init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv] * self.batch_size).to(device), noise=self.noise)
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + img * torch.where(is_masked, 0, 1)
            img.copy_(new_img)
        #self.view_sample_step(img, "x")