from pathlib import Path
from math import sin, pi, sqrt
import numpy as np
import torch
import torch.nn.functional as F
from easydict import EasyDict
from guided_diffusion.script_util import (
create_model_and_diffusion,
model_and_diffusion_defaults,
NUM_CLASSES,
create_classifier,
classifier_defaults,
sr_create_model_and_diffusion,
sr_model_and_diffusion_defaults,
)
from adapt import ScoreAdapter
from my.registry import Registry
PRETRAINED_REGISTRY = Registry("pretrained")
device = torch.device("cuda")  # everything below assumes a CUDA device
def load_ckpt(path, **kwargs):
    # upstream guided-diffusion reads checkpoints through blobfile; a plain
    # torch.load suffices for local files
    return torch.load(path, **kwargs)
def pick_out_cfgs(src, target_ks):
    """Select the subset of keys target_ks from the config dict src."""
    return {k: src[k] for k in target_ks}
@PRETRAINED_REGISTRY.register()
def m_imgnet_64():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
dropout=0.1,
image_size=64,
learn_sigma=True,
noise_schedule="cosine",
num_channels=192,
num_head_channels=64,
num_res_blocks=3,
resblock_updown=True,
use_new_attention_order=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_depth=4,
classifier_scale=1.0,
model_path="models/64x64_diffusion.pt",
classifier_path="models/64x64_classifier.pt",
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_128():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
image_size=128,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_heads=4,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_scale=0.5,
model_path="models/128x128_diffusion.pt",
classifier_path="models/128x128_classifier.pt",
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_256():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
image_size=256,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_scale=1.0,
model_path="models/256x256_diffusion.pt",
classifier_path="models/256x256_classifier.pt"
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_256_uncond():
return dict(
attention_resolutions="32,16,8",
class_cond=False,
diffusion_steps=1000,
image_size=256,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
classifier_scale=10.0,
model_path="models/256x256_diffusion_uncond.pt",
classifier_path="models/256x256_classifier.pt",
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_512():
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
image_size=512,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=False,
use_scale_shift_norm=True,
classifier_scale=4.0,
model_path="models/512x512_diffusion.pt",
classifier_path="models/512x512_classifier.pt"
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_64_256(base_samples="64_samples.npz"):
return dict(
attention_resolutions="32,16,8",
class_cond=True,
diffusion_steps=1000,
large_size=256,
small_size=64,
learn_sigma=True,
noise_schedule="linear",
num_channels=192,
num_heads=4,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
model_path="models/64_256_upsampler.pt",
base_samples=base_samples,
)
@PRETRAINED_REGISTRY.register()
def m_imgnet_128_512(base_samples="128_samples.npz",):
return dict(
attention_resolutions="32,16",
class_cond=True,
diffusion_steps=1000,
large_size=512,
small_size=128,
learn_sigma=True,
noise_schedule="linear",
num_channels=192,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
model_path="models/128_512_upsampler.pt",
base_samples=base_samples,
)
@PRETRAINED_REGISTRY.register()
def m_lsun_256(category="bedroom"):
return dict(
attention_resolutions="32,16,8",
class_cond=False,
diffusion_steps=1000,
dropout=0.1,
image_size=256,
learn_sigma=True,
noise_schedule="linear",
num_channels=256,
num_head_channels=64,
num_res_blocks=2,
resblock_updown=True,
use_fp16=True,
use_scale_shift_norm=True,
model_path=f"models/lsun_{category}.pt"
)
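

# Usage sketch: each registered entry is a config factory keyed by name; the
# LSUN entry additionally takes a category. The model_path values assume the
# guided-diffusion checkpoints have been downloaded under ./models.
#
#   cfgs = PRETRAINED_REGISTRY.get("m_imgnet_64")()
#   lsun_cfgs = PRETRAINED_REGISTRY.get("m_lsun_256")(category="bedroom")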
def img_gen(specific_cfgs, num_samples=16, batch_size=16, load_only=False, ckpt_root=Path("")):
cfgs = EasyDict(
clip_denoised=True,
num_samples=num_samples,
batch_size=batch_size,
use_ddim=False,
model_path="",
classifier_path="",
classifier_scale=1.0,
)
cfgs.update(model_and_diffusion_defaults())
cfgs.update(classifier_defaults())
cfgs.update(specific_cfgs)
use_classifier_guidance = bool(cfgs.classifier_path)
class_aware = cfgs.class_cond or use_classifier_guidance
model, diffusion = create_model_and_diffusion(
**pick_out_cfgs(cfgs, model_and_diffusion_defaults().keys())
)
model.load_state_dict(
load_ckpt(str(ckpt_root / cfgs.model_path), map_location="cpu")
)
model.to(device)
if cfgs.use_fp16:
model.convert_to_fp16()
model.eval()
def model_fn(x, t, y=None):
return model(x, t, y if cfgs.class_cond else None)
classifier = None
cond_fn = None
if use_classifier_guidance:
classifier = create_classifier(
**pick_out_cfgs(cfgs, classifier_defaults().keys())
)
classifier.load_state_dict(
load_ckpt(str(ckpt_root / cfgs.classifier_path), map_location="cpu")
)
classifier.to(device)
if cfgs.classifier_use_fp16:
classifier.convert_to_fp16()
classifier.eval()
        def cond_fn(x, t, y=None):
            # classifier guidance (Dhariwal & Nichol, 2021): gradient of
            # log p(y | x_t) w.r.t. x_t, scaled by classifier_scale
            assert y is not None
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
logits = classifier(x_in, t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), y.view(-1)]
return torch.autograd.grad(selected.sum(), x_in)[0] * cfgs.classifier_scale
if load_only:
return model, classifier
all_images = []
all_labels = []
while len(all_images) * cfgs.batch_size < cfgs.num_samples:
model_kwargs = {}
if class_aware:
classes = torch.randint(
low=0, high=NUM_CLASSES, size=(cfgs.batch_size,), device=device
)
model_kwargs["y"] = classes
sample_fn = (
diffusion.p_sample_loop if not cfgs.use_ddim else diffusion.ddim_sample_loop
)
sample = sample_fn(
model_fn,
(cfgs.batch_size, 3, cfgs.image_size, cfgs.image_size),
clip_denoised=cfgs.clip_denoised,
model_kwargs=model_kwargs,
cond_fn=cond_fn,
device=device,
progress=True
)
sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
sample = sample.permute(0, 2, 3, 1)
sample = sample.contiguous()
all_images.append(sample.cpu().numpy())
if class_aware:
all_labels.append(classes.cpu().numpy())
    arr = np.concatenate(all_images, axis=0)
    arr = arr[:cfgs.num_samples]
    if class_aware:
        all_labels = np.concatenate(all_labels, axis=0)
        all_labels = all_labels[:cfgs.num_samples]
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = Path("./out") / f"samples_{shape_str}.npz"
    # write the label array alongside the samples only for class-aware runs
    if class_aware:
        np.savez(out_path, arr, all_labels)
    else:
        np.savez(out_path, arr)
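

def _demo_img_gen():
    # Example sketch: draw 4 class-conditional 64x64 ImageNet samples.
    # Assumes the checkpoints named in m_imgnet_64() exist under ./models
    # relative to the working directory, and that ./out exists for output.
    img_gen(m_imgnet_64(), num_samples=4, batch_size=4)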
def img_upsamp(specific_cfgs, num_samples=16, batch_size=16, load_only=False):
"""note that here the ckpt root is not configured properly; will break but easy fix"""
cfgs = EasyDict(
clip_denoised=True,
num_samples=num_samples,
batch_size=batch_size,
use_ddim=False,
base_samples="",
model_path="",
)
cfgs.update(sr_model_and_diffusion_defaults())
cfgs.update(specific_cfgs)
model, diffusion = sr_create_model_and_diffusion(
**pick_out_cfgs(cfgs, sr_model_and_diffusion_defaults().keys())
)
model.load_state_dict(load_ckpt(cfgs.model_path, map_location="cpu"))
model.to(device)
if cfgs.use_fp16:
model.convert_to_fp16()
model.eval()
if load_only:
return model
data = load_low_res_samples(
cfgs.base_samples, cfgs.batch_size, cfgs.class_cond
)
all_images = []
while len(all_images) * cfgs.batch_size < cfgs.num_samples:
model_kwargs = next(data)
model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
samples = diffusion.p_sample_loop(
model,
(cfgs.batch_size, 3, cfgs.large_size, cfgs.large_size),
clip_denoised=cfgs.clip_denoised,
model_kwargs=model_kwargs,
progress=True
)
samples = ((samples + 1) * 127.5).clamp(0, 255).to(torch.uint8)
samples = samples.permute(0, 2, 3, 1)
samples = samples.contiguous()
all_images.append(samples.cpu().numpy())
arr = np.concatenate(all_images, axis=0)
arr = arr[: cfgs.num_samples]
shape_str = "x".join([str(x) for x in arr.shape])
out_path = Path("./out") / f"samples_{shape_str}.npz"
np.savez(out_path, arr)
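

def _demo_img_upsamp():
    # Example sketch: upsample previously generated 64x64 samples to 256x256.
    # Assumes "64_samples.npz" was written by img_gen and that
    # models/64_256_upsampler.pt resolves from the working directory
    # (see the docstring caveat about the checkpoint root).
    img_upsamp(m_imgnet_64_256(base_samples="64_samples.npz"),
               num_samples=4, batch_size=4)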
def load_low_res_samples(base_samples, batch_size, class_cond):
    """Yield batches of low-res conditioning images (plus labels when
    class_cond) from an .npz written by img_gen, cycling indefinitely."""
    obj = np.load(base_samples)
    image_arr = obj["arr_0"]
if class_cond:
label_arr = obj["arr_1"]
buffer = []
label_buffer = []
while True:
for i in range(len(image_arr)):
buffer.append(image_arr[i])
if class_cond:
label_buffer.append(label_arr[i])
if len(buffer) == batch_size:
batch = torch.from_numpy(np.stack(buffer)).float()
batch = batch / 127.5 - 1.0
batch = batch.permute(0, 3, 1, 2)
res = {}
res["low_res"] = batch
if class_cond:
res["y"] = torch.from_numpy(np.stack(label_buffer))
yield res
buffer, label_buffer = [], []
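
# Usage sketch for the generator above (assumes "64_samples.npz" exists and
# was written by img_gen with class_cond=True):
#
#   data = load_low_res_samples("64_samples.npz", batch_size=4, class_cond=True)
#   model_kwargs = next(data)  # {"low_res": (4, 3, 64, 64) tensor, "y": (4,)}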
def class_cond_info(imgnet_cat):
def rand_cond_fn(batch_size):
cats = torch.randint(
low=0, high=NUM_CLASSES, size=(batch_size,), device=device
)
return {"y": cats}
def class_specific_cond(batch_size):
cats = torch.tensor([imgnet_cat, ] * batch_size, device=device)
return {"y": cats}
if imgnet_cat == -1:
return rand_cond_fn
else:
return class_specific_cond
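
# Example: imgnet_cat == -1 yields fresh random labels per batch, while a
# non-negative ImageNet class id pins every sample to that class, e.g.
#   class_cond_info(207)(batch_size=4)  # -> {"y": tensor([207, 207, 207, 207])}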
def _sqrt(x):
if isinstance(x, float):
return sqrt(x)
else:
assert isinstance(x, torch.Tensor)
return torch.sqrt(x)
class GuidedDDPM(ScoreAdapter):
def __init__(self, model, lsun_cat, imgnet_cat):
print(PRETRAINED_REGISTRY)
cfgs = PRETRAINED_REGISTRY.get(model)(
**({"category": lsun_cat} if model.startswith("m_lsun") else {})
)
self.unet, self.classifier = img_gen(
cfgs, load_only=True, ckpt_root=self.checkpoint_root() / "guided_ddpm"
)
H, W = cfgs['image_size'], cfgs['image_size']
self._data_shape = (3, H, W)
if cfgs['class_cond'] or (self.classifier is not None):
cond_func = class_cond_info(imgnet_cat)
else:
cond_func = lambda *args, **kwargs: {}
self.cond_func = cond_func
self._unet_is_cond = bool(cfgs['class_cond'])
noise_schedule = cfgs['noise_schedule']
assert noise_schedule in ("linear", "cosine")
self.M = 1000
if noise_schedule == "linear":
self.us = self.linear_us(self.M)
self._σ_min = 0.01
else:
self.us = self.cosine_us(self.M)
self._σ_min = 0.0064
self.noise_schedule = noise_schedule
self._device = next(self.unet.parameters()).device
def data_shape(self):
return self._data_shape
@property
def σ_max(self):
return self.us[0]
@property
def σ_min(self):
return self.us[-1]
    @torch.no_grad()
    def denoise(self, xs, σ, **model_kwargs):
        N = xs.shape[0]
        cond_t, σ = self.time_cond_vec(N, σ)
        # the guided-diffusion unet expects VP-scaled inputs; with
        # σ² = (1 - ᾱ) / ᾱ, the scale factor is sqrt(ᾱ) = 1 / sqrt(1 + σ²)
        output = self.unet(
            xs / _sqrt(1 + σ**2), cond_t, **model_kwargs
        )
        # with learn_sigma=True the unet stacks [noise, variance] along the
        # channel dim; only the noise prediction is used
        n_hat = torch.split(output, xs.shape[1], dim=1)[0]
        Ds = xs - σ * n_hat
        return Ds
def cond_info(self, batch_size):
return self.cond_func(batch_size)
def unet_is_cond(self):
return self._unet_is_cond
def use_cls_guidance(self):
return (self.classifier is not None)
@torch.no_grad()
def classifier_grad(self, xs, σ, ys):
N = xs.shape[0]
cond_t, σ = self.time_cond_vec(N, σ)
with torch.enable_grad():
x_in = xs.detach().requires_grad_(True)
logits = self.classifier(x_in, cond_t)
log_probs = F.log_softmax(logits, dim=-1)
selected = log_probs[range(len(logits)), ys.view(-1)]
grad = torch.autograd.grad(selected.sum(), x_in)[0]
        # rescale to match the VP input scaling used in denoise();
        # _sqrt handles both float and tensor-valued σ
        grad = grad * (1 / _sqrt(1 + σ**2))
        return grad
def snap_t_to_nearest_tick(self, t):
j = np.abs(t - self.us).argmin()
return self.us[j], j
    def time_cond_vec(self, N, σ):
        if isinstance(σ, float):
            σ, j = self.snap_t_to_nearest_tick(σ)  # σ might change due to snapping
            # us stores σ in decreasing order (index 0 = σ_max), while the
            # unet's timestep grows with noise level, hence the index flip
            cond_t = (self.M - 1) - j
            cond_t = torch.tensor([cond_t] * N, device=self.device)
            return cond_t, σ
else:
assert isinstance(σ, torch.Tensor)
σ = σ.reshape(-1).cpu().numpy()
σs = []
js = []
for elem in σ:
_σ, _j = self.snap_t_to_nearest_tick(elem)
σs.append(_σ)
js.append((self.M - 1) - _j)
cond_t = torch.tensor(js, device=self.device)
σs = torch.tensor(σs, device=self.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
return cond_t, σs
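
    # Example: with the linear schedule, time_cond_vec(N=1, σ=0.5) snaps 0.5
    # to the nearest tick in self.us and returns the corresponding unet
    # timestep; callers must use the snapped σ, as denoise() does.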
    @staticmethod
    def cosine_us(M=1000):
        assert M == 1000
        # cosine ᾱ schedule in the style of improved DDPM, unrolled through
        # the identity 1 + σ_j² = 1 / ᾱ_j; the per-step ratio is clipped from
        # below (cf. the β_t ≤ 0.999 clipping in Nichol & Dhariwal, 2021)
        def α_bar(j):
            return sin(pi / 2 * j / (M * (0.008 + 1))) ** 2
        us = [0, ]
        for j in reversed(range(0, M)):  # j runs from M-1 down to 0, inclusive
            u_j = sqrt(((us[-1] ** 2) + 1) / (max(α_bar(j) / α_bar(j+1), 0.001)) - 1)
            us.append(u_j)
us = np.array(us)
us = us[1:]
us = us[::-1]
return us
    @staticmethod
    def linear_us(M=1000):
        assert M == 1000
        β_start = 0.0001
        β_end = 0.02
        βs = np.linspace(β_start, β_end, M, dtype=np.float64)
        αs = np.cumprod(1 - βs)
        # noise-to-signal ratio of the VP schedule: σ_t = sqrt((1 - ᾱ_t) / ᾱ_t),
        # reversed so that us[0] = σ_max and us[-1] = σ_min
        us = np.sqrt((1 - αs) / αs)
        us = us[::-1]
        return us
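

# End-to-end sketch (assumes the guided_ddpm checkpoints are present under
# the adapter's checkpoint root):
#
#   adapter = GuidedDDPM(model="m_imgnet_64", lsun_cat=None, imgnet_cat=-1)
#   σ = float(adapter.σ_max)
#   xs = σ * torch.randn(2, *adapter.data_shape(), device=adapter.device)
#   Ds = adapter.denoise(xs, σ, **adapter.cond_info(batch_size=2))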