# Interpolation / sd.py
import re
import pandas as pd
import numpy as np
from functools import reduce
from einops import rearrange
import sys
import cv2
from skimage.exposure import match_histograms
import random
from pytorch_lightning import seed_everything
import time
import torch
from torch import autocast
from PIL import Image
import os
import requests
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
# 1. Clone the required repositories and add them to the import path
os.system("git clone https://github.com/kael558/stable-diffusion-cpu stable-diffusion")  # clone into ./stable-diffusion, which the paths below expect
os.system("git clone https://github.com/shariqfarooq123/AdaBins.git")
os.system("git clone https://github.com/isl-org/MiDaS.git")
os.system("git clone https://github.com/MSFTserver/pytorch3d-lite.git")
os.system("git clone https://github.com/deforum/k-diffusion/")
# Blank out k_diffusion/__init__.py, presumably so importing its submodules skips the package's top-level imports
with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:
    f.write('')
sys.path.extend([
    './taming-transformers',
    './clip',
    'stable-diffusion/',
    'k-diffusion',
    'pytorch3d-lite',
    'AdaBins',
    'MiDaS',
])
from helpers import sampler_fn
from k_diffusion.external import CompVisDenoiser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
# 2. Download the checkpoint and load the model
def load_model_from_config(config, ckpt, verbose=False, half_precision=False):
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=map_location)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd_var = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd_var, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    if half_precision:
        model = model.half()
    model.eval()
    return model
print('Model loading...')
'''
# Earlier manual download path, superseded by hf_hub_download below (kept for reference; token redacted):
models_path = "models"
model_checkpoint = "sd-v1-4.ckpt"
ckpt_path = os.path.join(models_path, model_checkpoint)
if os.path.exists(ckpt_path):
    print(f"{ckpt_path} exists")
else:
    print(f"Attempting to download {model_checkpoint}...this may take a while")
    url = 'https://<USER>:<HF_TOKEN>@huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt'
    ckpt_request = requests.get(url)
    print('Model downloaded.')
    with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:
        model_file.write(ckpt_request.content)
'''
ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4.ckpt")
ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
local_config = OmegaConf.load(ckpt_config_path)
model = load_model_from_config(local_config, ckpt_path)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
print('Model loaded.')
class DeformAnimKeys:
    def __init__(self, anim_args):
        self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames)
        self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames)
        self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames)
        self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames)
        self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames)
        self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames)
        self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames)
        self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames)
        self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames)
        self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames)
        self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames)
        self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames)
        self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames)
        self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames)
        self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames)
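# Illustrative example (values are placeholders, not from this file): with
# anim_args.zoom = "0:(1.0), 50:(1.5)" and max_frames = 100, zoom_series is a
# pandas Series of length 100 that interpolates from 1.0 at frame 0 to 1.5 at
# frame 50 and holds 1.5 through frame 99.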
def check_is_number(value):
    # Returns a re.Match (truthy) if value parses as a plain float literal, else None
    float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$'
    return re.match(float_pattern, value)
def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'):
    import numexpr
    key_frame_series = pd.Series([np.nan for a in range(max_frames)])
    for i in range(0, max_frames):
        if i in key_frames:
            value = key_frames[i]
            value_is_number = check_is_number(value)
            # if it's only a number, leave the rest for the default interpolation
            if value_is_number:
                t = i
                key_frame_series[i] = value
        if not value_is_number:
            # math expression: re-evaluate the latest keyframe expression at
            # every frame, with the frame index bound to `t` for numexpr
            t = i
            key_frame_series[i] = numexpr.evaluate(value)
    key_frame_series = key_frame_series.astype(float)
    if interp_method == 'Cubic' and len(key_frames) <= 3:
        interp_method = 'Quadratic'
    if interp_method == 'Quadratic' and len(key_frames) <= 2:
        interp_method = 'Linear'
    key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
    key_frame_series[max_frames - 1] = key_frame_series[key_frame_series.last_valid_index()]
    key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both')
    if integer:
        return key_frame_series.astype(int)
    return key_frame_series
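# Illustrative example: get_inbetweens({0: '0', 10: '5'}, 20) gives 0.0 at frame 0,
# a linear ramp to 5.0 at frame 10, then 5.0 through frame 19. An expression value
# such as {0: 'sin(t/10)'} is instead re-evaluated at every frame index t, which is
# how the default sinusoidal translation schedule below produces per-frame motion.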
def parse_key_frames(string, prompt_parser=None):
    # Matches entries of the form "10: (value)". The parentheses are required
    # delimiters so that values may themselves contain math functions with
    # brackets, e.g. sin(t); each entry ends at a comma or at end of line.
    pattern = r'((?P<frame>[0-9]+):[\s]*\((?P<param>[\S\s]*?)\)([,][\s]?|[\s]?$))'
    frames = dict()
    for match_object in re.finditer(pattern, string):
        frame = int(match_object.groupdict()['frame'])
        param = match_object.groupdict()['param']
        if prompt_parser:
            frames[frame] = prompt_parser(param)
        else:
            frames[frame] = param
    if frames == {} and len(string) != 0:
        raise RuntimeError('Key Frame string not correctly formatted')
    return frames
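# Illustrative round trip (values stay strings until get_inbetweens evaluates them):
#   parse_key_frames("0:(1.0), 25:(5*sin(t/10))") -> {0: '1.0', 25: '5*sin(t/10)'}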
# https://en.wikipedia.org/wiki/Rotation_matrix
def getRotationMatrixManual(rotation_angles):
    rotation_angles = [np.deg2rad(x) for x in rotation_angles]
    phi = rotation_angles[0]  # around x
    gamma = rotation_angles[1]  # around y
    theta = rotation_angles[2]  # around z
    # X rotation
    Rphi = np.eye(4, 4)
    sp = np.sin(phi)
    cp = np.cos(phi)
    Rphi[1, 1] = cp
    Rphi[2, 2] = Rphi[1, 1]
    Rphi[1, 2] = -sp
    Rphi[2, 1] = sp
    # Y rotation
    Rgamma = np.eye(4, 4)
    sg = np.sin(gamma)
    cg = np.cos(gamma)
    Rgamma[0, 0] = cg
    Rgamma[2, 2] = Rgamma[0, 0]
    Rgamma[0, 2] = sg
    Rgamma[2, 0] = -sg
    # Z rotation (in-image-plane)
    Rtheta = np.eye(4, 4)
    st = np.sin(theta)
    ct = np.cos(theta)
    Rtheta[0, 0] = ct
    Rtheta[1, 1] = Rtheta[0, 0]
    Rtheta[0, 1] = -st
    Rtheta[1, 0] = st
    # Compose the three rotations: R = Rphi @ Rgamma @ Rtheta
    R = reduce(lambda x, y: np.matmul(x, y), [Rphi, Rgamma, Rtheta])
    return R
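# Sanity check (illustrative): zero angles give the identity, and a 90 degree
# in-plane (z) rotation maps the x axis onto y:
#   np.allclose(getRotationMatrixManual([0, 0, 0]), np.eye(4))        # True
#   getRotationMatrixManual([0, 0, 90])[:3, :3] @ [1, 0, 0]           # ~[0, 1, 0]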
def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength):
    ptsIn2D = ptsIn[0, :]
    ptsOut2D = ptsOut[0, :]
    ptsOut2Dlist = []
    ptsIn2Dlist = []
    for i in range(0, 4):
        ptsOut2Dlist.append([ptsOut2D[i, 0], ptsOut2D[i, 1]])
        ptsIn2Dlist.append([ptsIn2D[i, 0], ptsIn2D[i, 1]])
    pin = np.array(ptsIn2Dlist) + [W / 2., H / 2.]
    pout = (np.array(ptsOut2Dlist) + [1., 1.]) * (0.5 * sidelength)
    pin = pin.astype(np.float32)
    pout = pout.astype(np.float32)
    return pin, pout
def warpMatrix(W, H, theta, phi, gamma, scale, fV):
    # M is to be estimated
    M = np.eye(4, 4)
    fVhalf = np.deg2rad(fV / 2.)
    d = np.sqrt(W * W + H * H)
    sideLength = scale * d / np.cos(fVhalf)
    h = d / (2.0 * np.sin(fVhalf))
    n = h - (d / 2.0)
    f = h + (d / 2.0)
    # Translation along Z-axis by -h
    T = np.eye(4, 4)
    T[2, 3] = -h
    # Rotation matrices around x, y, z
    R = getRotationMatrixManual([phi, gamma, theta])
    # Projection matrix
    P = np.eye(4, 4)
    P[0, 0] = 1.0 / np.tan(fVhalf)
    P[1, 1] = P[0, 0]
    P[2, 2] = -(f + n) / (f - n)
    P[2, 3] = -(2.0 * f * n) / (f - n)
    P[3, 2] = -1.0
    # Compose the full transform: project * translate * rotate
    F = reduce(lambda x, y: np.matmul(x, y), [P, T, R])
    # Shape must be (1, 4, 3) for ptsIn and ptsOut since perspectiveTransform() expects data this way.
    # In C++ this would be Mat ptsIn(1, 4, CV_64FC3).
    ptsIn = np.array([[
        [-W / 2., H / 2., 0.], [W / 2., H / 2., 0.], [W / 2., -H / 2., 0.], [-W / 2., -H / 2., 0.]
    ]])
    ptsOut = cv2.perspectiveTransform(ptsIn, F)
    ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength)
    # Must be float32, otherwise OpenCV throws an error
    assert ptsInPt2f.dtype == np.float32
    assert ptsOutPt2f.dtype == np.float32
    M33 = cv2.getPerspectiveTransform(ptsInPt2f, ptsOutPt2f)
    return M33, sideLength
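# Hedged usage sketch: the perspective-flip path (it lives in the animation loop,
# not in this file) would apply the estimated homography roughly like this:
#   M33, side = warpMatrix(W, H, theta, phi, gamma, 1.0, fv)
#   flipped = cv2.warpPerspective(img, M33, (img.shape[1], img.shape[0]))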
def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):
    angle = keys.angle_series[frame_idx]
    zoom = keys.zoom_series[frame_idx]
    translation_x = keys.translation_x_series[frame_idx]
    translation_y = keys.translation_y_series[frame_idx]
    center = (args.W // 2, args.H // 2)
    trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])
    rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)
    trans_mat = np.vstack([trans_mat, [0, 0, 1]])
    rot_mat = np.vstack([rot_mat, [0, 0, 1]])
    xform = np.matmul(rot_mat, trans_mat)
    return cv2.warpPerspective(
        prev_img_cv2,
        xform,
        (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),
        borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE
    )
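# Illustrative call (prev_frame_bgr is a hypothetical name): with the default
# schedules defined below (translation_x = "0:(10*sin(2*3.14*t/10))", zoom =
# "0:(1.04)"), frame 1 of a 512x512 image is shifted by roughly 5.9 px in x and
# zoomed 4% about the center:
#   warped = anim_frame_warp_2d(prev_frame_bgr, args, anim_args, keys, frame_idx=1)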
def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
    # Rescale uint8 HWC [0, 255] to NCHW [-1, 1]; the float16 cast assumes a
    # half-precision (GPU) pipeline downstream
    sample = ((sample.astype(float) / 255.0) * 2) - 1
    sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)
    sample = torch.from_numpy(sample)
    return sample
def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:
    sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32)
    sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)
    sample_int8 = (sample_f32 * 255)
    return sample_int8.astype(type)
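# The two helpers above are inverses up to quantization: a uint8 frame pushed
# through sample_from_cv2 and back through sample_to_cv2 should match to within
# 1/255 per channel.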
def maintain_colors(prev_img, color_match_sample, mode):
    # NOTE: multichannel= is deprecated in scikit-image >= 0.19; newer versions
    # take channel_axis=-1 instead
    if mode == 'Match Frame 0 RGB':
        return match_histograms(prev_img, color_match_sample, multichannel=True)
    elif mode == 'Match Frame 0 HSV':
        prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)
        color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)
        matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)
        return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)
    else:  # Match Frame 0 LAB
        prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)
        color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)
        matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)
        return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)
def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:
    return sample + torch.randn(sample.shape, device=sample.device) * noise_amt
def next_seed(args):
    if args.seed_behavior == 'iter':
        args.seed += 1
    elif args.seed_behavior == 'fixed':
        pass  # always keep seed the same
    else:  # 'random'
        args.seed = random.randint(0, 2**32 - 1)
    return args.seed
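# Illustrative behavior: with seed_behavior == 'iter' and args.seed == 41,
# successive calls mutate args.seed in place and return 42, 43, ...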
def generate(args, return_c=False):
    seed_everything(args.seed)
    sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)
    model_wrap = CompVisDenoiser(model)
    batch_size = args.n_samples
    prompt = args.prompt
    assert prompt is not None
    data = [batch_size * [prompt]]
    precision_scope = autocast
    mask = None
    t_enc = int((1.0 - args.strength) * args.steps)
    init_latent = None
    # Noise schedule for the k-diffusion samplers (used for masking)
    k_sigmas = model_wrap.get_sigmas(args.steps)
    k_sigmas = k_sigmas[len(k_sigmas) - t_enc - 1:]
    if args.sampler in ['plms', 'ddim']:
        sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)
    callback = SamplerCallback(args=args,
                               mask=mask,
                               init_latent=init_latent,
                               sigmas=k_sigmas,
                               sampler=sampler,
                               verbose=False).callback
    results = []
    with torch.no_grad():
        with precision_scope(device.type):  # autocast on the active device rather than hard-coded "cuda", so CPU runs work too
            with model.ema_scope():
                for prompts in data:
                    uc = model.get_learned_conditioning(batch_size * [""])
                    c = model.get_learned_conditioning(prompts)
                    if args.scale == 1.0:
                        uc = None
                    if args.init_c is not None:
                        c = args.init_c
                    if args.sampler in ["klms", "dpm2", "dpm2_ancestral", "heun", "euler", "euler_ancestral"]:
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=model_wrap,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=device,
                            cb=callback)
                    else:  # args.sampler is 'plms' or 'ddim'
                        if init_latent is not None and args.strength > 0:
                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
                        else:
                            z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)
                        if args.sampler == 'ddim':
                            samples = sampler.decode(z_enc,
                                                     c,
                                                     t_enc,
                                                     unconditional_guidance_scale=args.scale,
                                                     unconditional_conditioning=uc,
                                                     img_callback=callback)
                        elif args.sampler == 'plms':  # no "decode" function in plms, so use "sample"
                            shape = [args.C, args.H // args.f, args.W // args.f]
                            samples, _ = sampler.sample(S=args.steps,
                                                        conditioning=c,
                                                        batch_size=args.n_samples,
                                                        shape=shape,
                                                        verbose=False,
                                                        unconditional_guidance_scale=args.scale,
                                                        unconditional_conditioning=uc,
                                                        eta=args.ddim_eta,
                                                        x_T=z_enc,
                                                        img_callback=callback)
                        else:
                            raise Exception(f"Sampler {args.sampler} not recognised.")
                    x_samples = model.decode_first_stage(samples)
                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    if return_c:
                        results.append(c.clone())
                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        image = Image.fromarray(x_sample.astype(np.uint8))
                        results.append(image)
    return results
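# Hedged usage sketch (SimpleNamespace stands in for the notebook's attribute
# container; the prompt and seed values are placeholders):
#   from types import SimpleNamespace
#   args = SimpleNamespace(**DeforumArgs())
#   args.prompt, args.seed = "a watercolor fox", 42
#   images = generate(args)  # -> list of PIL.Image
#   images[0].save("out.png")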
def get_output_folder(output_path, batch_folder):
    out_path = os.path.join(output_path, time.strftime('%Y-%m'))
    if batch_folder != "":
        out_path = os.path.join(out_path, batch_folder)
    os.makedirs(out_path, exist_ok=True)
    return out_path
def DeforumArgs():
    #@markdown **Image Settings**
    W = 512 #@param
    H = 512 #@param
    W, H = map(lambda x: x - x % 64, (W, H))  # resize to integer multiple of 64
    #@markdown **Sampling Settings**
    seed = -1 #@param
    sampler = 'dpm2_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"]
    steps = 50 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    dynamic_threshold = None
    static_threshold = None
    #@markdown **Save & Display Settings**
    save_samples = True #@param {type:"boolean"}
    save_settings = True #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}
    save_sample_per_step = False #@param {type:"boolean"}
    show_sample_per_step = False #@param {type:"boolean"}
    #@markdown **Prompt Settings**
    prompt_weighting = False #@param {type:"boolean"}
    normalize_prompt_weights = True #@param {type:"boolean"}
    log_weighted_subprompts = False #@param {type:"boolean"}
    #@markdown **Batch Settings**
    n_batch = 1 #@param
    batch_name = "StableFun" #@param {type:"string"}
    filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"]
    seed_behavior = "fixed" #@param ["iter","fixed","random"]
    make_grid = False #@param {type:"boolean"}
    grid_rows = 2 #@param
    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.0 #@param {type:"number"}
    strength_0_no_init = True  # Set the strength to 0 automatically when no init image is used
    init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}
    # Whiter areas of the mask are areas that change more
    use_mask = False #@param {type:"boolean"}
    use_alpha_as_mask = False  # use the alpha channel of the init image as the mask
    mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"}
    invert_mask = False #@param {type:"boolean"}
    # Adjust mask image; 1.0 is no adjustment. Should be positive numbers.
    mask_brightness_adjust = 1.0 #@param {type:"number"}
    mask_contrast_adjust = 1.0 #@param {type:"number"}
    # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding
    overlay_mask = True  # {type:"boolean"}
    # Blur edges of final overlay mask, if used. Minimum = 0 (no blur)
    mask_overlay_blur = 5  # {type:"number"}
    n_samples = 1  # doesn't do anything
    precision = 'autocast'
    C = 4
    f = 8
    prompt = ""
    timestring = ""
    init_latent = None
    init_sample = None
    init_c = None
    return locals()
def DeforumAnimArgs():
    #@markdown ####**Animation:**
    animation_mode = '2D' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}
    max_frames = 100 #@param {type:"number"}
    border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}
    #@markdown ####**Motion Parameters:**
    angle = "0:(0)" #@param {type:"string"}
    zoom = "0:(1.04)" #@param {type:"string"}
    translation_x = "0:(10*sin(2*3.14*t/10))" #@param {type:"string"}
    translation_y = "0:(0)" #@param {type:"string"}
    translation_z = "0:(10)" #@param {type:"string"}
    rotation_3d_x = "0:(0)" #@param {type:"string"}
    rotation_3d_y = "0:(0)" #@param {type:"string"}
    rotation_3d_z = "0:(0)" #@param {type:"string"}
    flip_2d_perspective = False #@param {type:"boolean"}
    perspective_flip_theta = "0:(0)" #@param {type:"string"}
    perspective_flip_phi = "0:(t%15)" #@param {type:"string"}
    perspective_flip_gamma = "0:(0)" #@param {type:"string"}
    perspective_flip_fv = "0:(53)" #@param {type:"string"}
    noise_schedule = "0: (0.02)" #@param {type:"string"}
    strength_schedule = "0: (0.65)" #@param {type:"string"}
    contrast_schedule = "0: (1.0)" #@param {type:"string"}
    #@markdown ####**Coherence:**
    color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}
    diffusion_cadence = '7' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}
    #@markdown ####**3D Depth Warping:**
    use_depth_warping = True #@param {type:"boolean"}
    midas_weight = 0.3 #@param {type:"number"}
    near_plane = 200
    far_plane = 10000
    fov = 40 #@param {type:"number"}
    padding_mode = 'border' #@param ['border', 'reflection', 'zeros'] {type:'string'}
    sampling_mode = 'bicubic' #@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}
    save_depth_maps = False #@param {type:"boolean"}
    #@markdown ####**Video Input:**
    video_init_path = '/content/video_in.mp4' #@param {type:"string"}
    extract_nth_frame = 1 #@param {type:"number"}
    overwrite_extracted_frames = True #@param {type:"boolean"}
    use_mask_video = False #@param {type:"boolean"}
    video_mask_path = '/content/video_in.mp4' #@param {type:"string"}
    #@markdown ####**Interpolation:**
    interpolate_key_frames = False #@param {type:"boolean"}
    interpolate_x_frames = 4 #@param {type:"number"}
    #@markdown ####**Resume Animation:**
    resume_from_timestring = False #@param {type:"boolean"}
    resume_timestring = "20220829210106" #@param {type:"string"}
    return locals()
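# Hedged sketch of how these settings dicts are typically consumed (SimpleNamespace
# is an assumption; the Deforum notebooks use an equivalent attribute container):
#   from types import SimpleNamespace
#   anim_args = SimpleNamespace(**DeforumAnimArgs())
#   keys = DeformAnimKeys(anim_args)
#   keys.zoom_series[10]  # per-frame zoom value for frame 10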
#
# Callback functions
#
class SamplerCallback(object):
    # Creates the callback function to be passed into the samplers for each step
    def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None,
                 verbose=False):
        self.sampler_name = args.sampler
        self.dynamic_threshold = args.dynamic_threshold
        self.static_threshold = args.static_threshold
        self.mask = mask
        self.init_latent = init_latent
        self.sigmas = sigmas
        self.sampler = sampler
        self.verbose = verbose
        self.batch_size = args.n_samples
        #self.save_sample_per_step = args.save_sample_per_step
        #self.show_sample_per_step = args.show_sample_per_step
        #self.paths_to_image_steps = [os.path.join( args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples) ]
        #if self.save_sample_per_step:
        #    for path in self.paths_to_image_steps:
        #        os.makedirs(path, exist_ok=True)
        self.step_index = 0
        self.noise = None
        if init_latent is not None:
            self.noise = torch.randn_like(init_latent, device=device)
        self.mask_schedule = None
        if sigmas is not None and len(sigmas) > 0:
            self.mask_schedule, _ = torch.sort(sigmas / torch.max(sigmas))
        elif sigmas is not None:  # len(sigmas) == 0; guard against sigmas=None to avoid a TypeError
            self.mask = None  # no mask needed if no steps (usually happens because strength==1.0)
        if self.sampler_name in ["plms", "ddim"]:
            if mask is not None:
                assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable"
            # Callback function formatted for CompVis latent diffusion samplers
            self.callback = self.img_callback_
        else:
            # Default callback function uses k-diffusion sampler variables
            self.callback = self.k_callback_
        #self.verbose_print = print if verbose else lambda *args, **kwargs: None
    # The callback function is applied to the image at each step
    def dynamic_thresholding_(self, img, threshold):
        # Dynamic thresholding from the Imagen paper (May 2022): clip to the
        # threshold-th percentile of |img| (at least 1.0) and rescale into range
        s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim)))
        s = np.max(np.append(s, 1.0))
        torch.clamp_(img, -1 * s, s)
        img.div_(s)  # in-place divide; equivalent to the original torch.FloatTensor.div_(img, s)
    # Callback for samplers in the k-diffusion repo, called thus:
    # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
    def k_callback_(self, args_dict):
        self.step_index = args_dict['i']
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(args_dict['x'], -1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            init_noise = self.init_latent + self.noise * args_dict['sigma']
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + args_dict['x'] * torch.where(is_masked, 0, 1)
            args_dict['x'].copy_(new_img)
        #self.view_sample_step(args_dict['denoised'], "x0_pred")
    # Callback for CompVis samplers
    # Function that is called on the image (img) and step (i) at each step
    def img_callback_(self, img, i):
        self.step_index = i
        # Thresholding functions
        if self.dynamic_threshold is not None:
            self.dynamic_thresholding_(img, self.dynamic_threshold)
        if self.static_threshold is not None:
            torch.clamp_(img, -1 * self.static_threshold, self.static_threshold)
        if self.mask is not None:
            i_inv = len(self.sigmas) - i - 1
            init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv] * self.batch_size).to(device), noise=self.noise)
            is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0)
            new_img = init_noise * torch.where(is_masked, 1, 0) + img * torch.where(is_masked, 0, 1)
            img.copy_(new_img)
        #self.view_sample_step(img, "x")