# sjc / adapt_gddpm.py
from pathlib import Path
from math import sin, pi, sqrt
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict
from guided_diffusion.script_util import (
create_model_and_diffusion,
model_and_diffusion_defaults,
NUM_CLASSES,
create_classifier,
classifier_defaults,
sr_create_model_and_diffusion,
sr_model_and_diffusion_defaults,
)
from adapt import ScoreAdapter
from my.registry import Registry
PRETRAINED_REGISTRY = Registry("pretrained")
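# Each decorated function below registers a pretrained-model config under its
# own name, e.g. PRETRAINED_REGISTRY.get("m_imgnet_64")() returns the kwargs
# for the 64x64 class-conditional ImageNet model.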
device = torch.device("cuda")
def load_ckpt(path, **kwargs):
    # read the checkpoint straight from local disk
    return torch.load(path, **kwargs)
def pick_out_cfgs(src, target_ks):
return {k: src[k] for k in target_ks}
@PRETRAINED_REGISTRY.register()
def m_imgnet_64():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
dropout=0.1,
image_size=64,
learn_sigma=True,
noise_schedule="cosine",
num_channels=192,
num_head_channels=64,
num_res_blocks=3,
resblock_updown=True,
use_new_attention_order=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_depth=4,
classifier_scale=1.0,
model_path="models/64x64_diffusion.pt",
classifier_path="models/64x64_classifier.pt",
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_128():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
image_size=128,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_heads=4,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_scale=0.5,
model_path="models/128x128_diffusion.pt",
classifier_path="models/128x128_classifier.pt",
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_256():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
image_size=256,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_scale=1.0,
model_path="models/256x256_diffusion.pt",
classifier_path="models/256x256_classifier.pt"
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_256_uncond():
return dict(
attention_resolutions="32,16,8",
class_cond=False,
diffusion_steps=1000,
image_size=256,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_scale=10.0,
model_path="models/256x256_diffusion_uncond.pt",
classifier_path="models/256x256_classifier.pt",
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_512():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
image_size=512,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=False,
use_scale_shift_norm=True,
classifier_scale=4.0,
model_path="models/512x512_diffusion.pt",
classifier_path="models/512x512_classifier.pt"
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_64_256(base_samples="64_samples.npz"):
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
large_size=256,
small_size=64,
learn_sigma=True,
noise_schedule="linear",
num_channels=192,
num_heads=4,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
model_path="models/64_256_upsampler.pt",
base_samples=base_samples,
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_128_512(base_samples="128_samples.npz",):
return dict(
attention_resolutions="32,16",
class_cond=True,
diffusion_steps=1000,
large_size=512,
small_size=128,
learn_sigma=True,
noise_schedule="linear",
num_channels=192,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
model_path="models/128_512_upsampler.pt",
base_samples=base_samples,
)
@PRETRAINED_REGISTRY.register()
def m_lsun_256(category="bedroom"):
return dict(
attention_resolutions="32,16,8",
class_cond=False,
diffusion_steps=1000,
dropout=0.1,
image_size=256,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
model_path=f"models/lsun_{category}.pt"
)
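# The configs above mirror the flag settings published with the released
# openai/guided-diffusion checkpoints; model_path and classifier_path are
# resolved relative to the ckpt_root handed to img_gen below.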
def img_gen(specific_cfgs, num_samples=16, batch_size=16, load_only=False, ckpt_root=Path("")):
cfgs = EasyDict(
clip_denoised=True,
num_samples=num_samples,
batch_size=batch_size,
use_ddim=False,
model_path="",
classifier_path="",
classifier_scale=1.0,
)
cfgs.update(model_and_diffusion_defaults())
cfgs.update(classifier_defaults())
cfgs.update(specific_cfgs)
use_classifier_guidance = bool(cfgs.classifier_path)
class_aware = cfgs.class_cond or use_classifier_guidance
model, diffusion = create_model_and_diffusion(
**pick_out_cfgs(cfgs, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
load_ckpt(str(ckpt_root / cfgs.model_path), map_location="cpu")
)
model.to(device)
if cfgs.use_fp16:
model.convert_to_fp16()
model.eval()
    def model_fn(x, t, y=None):
        # forward labels only if the UNet itself is class-conditional; a
        # classifier-guided unconditional model ignores y here
        return model(x, t, y if cfgs.class_cond else None)
classifier = None
cond_fn = None
if use_classifier_guidance:
classifier = create_classifier(
**pick_out_cfgs(cfgs, classifier_defaults().keys())
)
classifier.load_state_dict(
load_ckpt(str(ckpt_root / cfgs.classifier_path), map_location="cpu")
)
classifier.to(device)
if cfgs.classifier_use_fp16:
classifier.convert_to_fp16()
classifier.eval()
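        # cond_fn returns classifier_scale * ∇_x log p(y | x_t), the
        # classifier-guidance gradient of Dhariwal & Nichol (2021) that the
        # sampler adds when it is passed to p_sample_loop below.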
def cond_fn(x, t, y=None):
assert y is not None
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
return torch.autograd.grad(selected.sum(), x_in)[0] * cfgs.classifier_scale
if load_only:
return model, classifier
all_images = []
all_labels = []
while len(all_images) * cfgs.batch_size < cfgs.num_samples:
model_kwargs = {}
if class_aware:
classes = torch.randint(
low=0, high=NUM_CLASSES, size=(cfgs.batch_size,), device=device
)
model_kwargs["y"] = classes
sample_fn = (
diffusion.p_sample_loop if not cfgs.use_ddim else diffusion.ddim_sample_loop
)
sample = sample_fn(
model_fn,
(cfgs.batch_size, 3, cfgs.image_size, cfgs.image_size),
clip_denoised=cfgs.clip_denoised,
model_kwargs=model_kwargs,
cond_fn=cond_fn,
device=device,
progress=True
)
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1)
sample = sample.contiguous()
all_images.append(sample.cpu().numpy())
if class_aware:
all_labels.append(classes.cpu().numpy())
    arr = np.concatenate(all_images, axis=0)
    arr = arr[:cfgs.num_samples]
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = Path("./out") / f"samples_{shape_str}.npz"
    if class_aware:
        all_labels = np.concatenate(all_labels, axis=0)
        all_labels = all_labels[:cfgs.num_samples]
        np.savez(out_path, arr, all_labels)
    else:
        np.savez(out_path, arr)
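# A minimal usage sketch (the ckpt_root below is hypothetical; ./out must
# already exist since np.savez does not create directories):
#
#   cfgs = PRETRAINED_REGISTRY.get("m_imgnet_64")()
#   img_gen(cfgs, num_samples=4, batch_size=4, ckpt_root=Path("ckpts"))
#   # -> writes ./out/samples_4x64x64x3.npz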
def img_upsamp(specific_cfgs, num_samples=16, batch_size=16, load_only=False):
"""note that here the ckpt root is not configured properly; will break but easy fix"""
cfgs = EasyDict(
clip_denoised=True,
num_samples=num_samples,
batch_size=batch_size,
use_ddim=False,
base_samples="",
model_path="",
)
cfgs.update(sr_model_and_diffusion_defaults())
cfgs.update(specific_cfgs)
model, diffusion = sr_create_model_and_diffusion(
**pick_out_cfgs(cfgs, sr_model_and_diffusion_defaults().keys())
)
model.load_state_dict(load_ckpt(cfgs.model_path, map_location="cpu"))
model.to(device)
if cfgs.use_fp16:
model.convert_to_fp16()
model.eval()
if load_only:
return model
data = load_low_res_samples(
cfgs.base_samples, cfgs.batch_size, cfgs.class_cond
)
all_images = []
while len(all_images) * cfgs.batch_size < cfgs.num_samples:
model_kwargs = next(data)
model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
samples = diffusion.p_sample_loop(
model,
(cfgs.batch_size, 3, cfgs.large_size, cfgs.large_size),
clip_denoised=cfgs.clip_denoised,
model_kwargs=model_kwargs,
progress=True
)
samples = ((samples + 1) * 127.5).clamp(0, 255).to(torch.uint8)
samples = samples.permute(0, 2, 3, 1)
samples = samples.contiguous()
all_images.append(samples.cpu().numpy())
arr = np.concatenate(all_images, axis=0)
arr = arr[: cfgs.num_samples]
shape_str = "x".join([str(x) for x in arr.shape])
out_path = Path("./out") / f"samples_{shape_str}.npz"
np.savez(out_path, arr)
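# Hedged usage sketch: the upsampler consumes an .npz written by img_gen, e.g.
#
#   cfgs = m_imgnet_64_256(base_samples="out/samples_16x64x64x3.npz")
#   img_upsamp(cfgs, num_samples=16, batch_size=16)
#
# (subject to the model_path caveat in the docstring above).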
def load_low_res_samples(base_samples, batch_size, class_cond):
obj = np.load(base_samples)
image_arr = obj["arr_0"]
if class_cond:
label_arr = obj["arr_1"]
buffer = []
label_buffer = []
while True:
for i in range(len(image_arr)):
buffer.append(image_arr[i])
if class_cond:
label_buffer.append(label_arr[i])
if len(buffer) == batch_size:
batch = torch.from_numpy(np.stack(buffer)).float()
batch = batch / 127.5 - 1.0
batch = batch.permute(0, 3, 1, 2)
res = {}
res["low_res"] = batch
if class_cond:
res["y"] = torch.from_numpy(np.stack(label_buffer))
yield res
buffer, label_buffer = [], []
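# The generator above cycles through the .npz indefinitely, yielding dicts of
# the form {"low_res": (B, 3, h, w) float tensor in [-1, 1], "y": (B,) int
# labels when class_cond} as conditioning batches for the upsampler.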
def class_cond_info(imgnet_cat):
def rand_cond_fn(batch_size):
cats = torch.randint(
low=0, high=NUM_CLASSES, size=(batch_size,), device=device
)
return {"y": cats}
def class_specific_cond(batch_size):
cats = torch.tensor([imgnet_cat, ] * batch_size, device=device)
return {"y": cats}
if imgnet_cat == -1:
return rand_cond_fn
else:
return class_specific_cond
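# e.g. class_cond_info(-1)(4) draws 4 random ImageNet class ids, while
# class_cond_info(207)(4) pins all 4 samples to class 207 (an arbitrary id).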
def _sqrt(x):
if isinstance(x, float):
return sqrt(x)
else:
assert isinstance(x, torch.Tensor)
return torch.sqrt(x)
class GuidedDDPM(ScoreAdapter):
def __init__(self, model, lsun_cat, imgnet_cat):
        print(PRETRAINED_REGISTRY)  # log the registered pretrained configs
cfgs = PRETRAINED_REGISTRY.get(model)(
**({"category": lsun_cat} if model.startswith("m_lsun") else {})
)
self.unet, self.classifier = img_gen(
cfgs, load_only=True, ckpt_root=self.checkpoint_root() / "guided_ddpm"
)
H, W = cfgs['image_size'], cfgs['image_size']
self._data_shape = (3, H, W)
if cfgs['class_cond'] or (self.classifier is not None):
cond_func = class_cond_info(imgnet_cat)
else:
cond_func = lambda *args, **kwargs: {}
self.cond_func = cond_func
self._unet_is_cond = bool(cfgs['class_cond'])
noise_schedule = cfgs['noise_schedule']
assert noise_schedule in ("linear", "cosine")
self.M = 1000
if noise_schedule == "linear":
self.us = self.linear_us(self.M)
self._σ_min = 0.01
else:
self.us = self.cosine_us(self.M)
self._σ_min = 0.0064
self.noise_schedule = noise_schedule
self._device = next(self.unet.parameters()).device
def data_shape(self):
return self._data_shape
@property
def σ_max(self):
return self.us[0]
@property
def σ_min(self):
return self.us[-1]
@torch.no_grad()
def denoise(self, xs, σ, **model_kwargs):
N = xs.shape[0]
cond_t, σ = self.time_cond_vec(N, σ)
output = self.unet(
xs / _sqrt(1 + σ**2), cond_t, **model_kwargs
)
        # learn_sigma models stack (ε̂, Σ) along channels; keep only the noise
        # prediction and discard the variance head
        n_hat = torch.split(output, xs.shape[1], dim=1)[0]
Ds = xs - σ * n_hat
return Ds
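    # (Annotation, not original commentary.) With ᾱ = 1/(1 + σ²), the rescaled
    # input xs/√(1+σ²) equals √ᾱ·x0 + √(1-ᾱ)·ε for a VE sample xs = x0 + σ·ε,
    # i.e. exactly the VP-form input the DDPM UNet was trained on, and
    # Ds = xs - σ·ε̂ is the implied estimate of x0.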
def cond_info(self, batch_size):
return self.cond_func(batch_size)
def unet_is_cond(self):
return self._unet_is_cond
def use_cls_guidance(self):
return (self.classifier is not None)
@torch.no_grad()
def classifier_grad(self, xs, σ, ys):
N = xs.shape[0]
cond_t, σ = self.time_cond_vec(N, σ)
with torch.enable_grad():
x_in = xs.detach().requires_grad_(True)
logits = self.classifier(x_in, cond_t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), ys.view(-1)]
grad = torch.autograd.grad(selected.sum(), x_in)[0]
grad = grad * (1 / sqrt(1 + σ**2))
return grad
def snap_t_to_nearest_tick(self, t):
j = np.abs(t - self.us).argmin()
return self.us[j], j
def time_cond_vec(self, N, σ):
if isinstance(σ, float):
σ, j = self.snap_t_to_nearest_tick(σ) # σ might change due to snapping
cond_t = (self.M - 1) - j
cond_t = torch.tensor([cond_t] * N, device=self.device)
return cond_t, σ
else:
assert isinstance(σ, torch.Tensor)
σ = σ.reshape(-1).cpu().numpy()
σs = []
js = []
for elem in σ:
_σ, _j = self.snap_t_to_nearest_tick(elem)
σs.append(_σ)
js.append((self.M - 1) - _j)
cond_t = torch.tensor(js, device=self.device)
σs = torch.tensor(σs, device=self.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
return cond_t, σs
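    # e.g. a float σ near σ_max snaps to j = 0 (us is sorted in descending σ),
    # giving cond_t = M-1, the most-noised DDPM timestep; σ near σ_min snaps
    # to j = M-1 and cond_t = 0.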
@staticmethod
def cosine_us(M=1000):
assert M == 1000
def α_bar(j):
return sin(pi / 2 * j / (M * (0.008 + 1))) ** 2
us = [0, ]
for j in reversed(range(0, M)): # [M-1, 0], inclusive
u_j = sqrt(((us[-1] ** 2) + 1) / (max(α_bar(j) / α_bar(j+1), 0.001)) - 1)
us.append(u_j)
us = np.array(us)
us = us[1:]
us = us[::-1]
return us
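    # (Annotation.) Ignoring the clamp, the recurrence telescopes to
    # 1 + σ_j² = ᾱ(M)/ᾱ(j), the usual σ = √((1-ᾱ)/ᾱ) conversion; since
    # sin(π/2·x) = cos(π/2·(1-x)), α_bar here is the IDDPM cosine schedule
    # with time reversed, and the 0.001 floor mirrors guided-diffusion's cap
    # on the per-step β.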
@staticmethod
def linear_us(M=1000):
assert M == 1000
β_start = 0.0001
β_end = 0.02
βs = np.linspace(β_start, β_end, M, dtype=np.float64)
αs = np.cumprod(1 - βs)
us = np.sqrt((1 - αs) / αs)
us = us[::-1]
return us
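    # (Annotation.) Same conversion for the linear β schedule: ᾱ = ∏(1-β) and
    # σ = √((1-ᾱ)/ᾱ); the final reversal puts σ_max first, so us[0] = σ_max
    # and us[-1] = σ_min as the properties above expect.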