"""Sample images from pretrained score models and save them as a ``.npy`` artifact.

Two sampler configs are defined here: ``KarrasGen`` (drives
``adapt.Karras.inference``) and ``SMLDGen`` (annealed Langevin dynamics,
implemented in ``smld_inference``). Run as a script, the file seeds RNGs,
hands ``KarrasGen`` to ``my.config.dispatch`` and then writes a preview grid.
"""
from pathlib import Path

import numpy as np
import torch

from misc import torch_samps_to_imgs
from adapt import Karras, ScoreAdapter, power_schedule
from adapt_gddpm import GuidedDDPM
from adapt_ncsn import NCSN as _NCSN
# from adapt_vesde import VESDE  # not included to prevent import conflicts; imported lazily in SDE.make
from adapt_sd import StableDiffusion

from my.utils import tqdm, EventStorage, HeartBeat, EarlyLoopBreak
from my.config import BaseConf, dispatch
from my.utils.seed import seed_everything


class GDDPM(BaseConf):
    """Guided DDPM from OpenAI"""
    model: str = "m_lsun_256"
    lsun_cat: str = "bedroom"
    imgnet_cat: int = -1

    def make(self):
        """Build a ``GuidedDDPM`` with this config's fields as keyword arguments."""
        args = self.dict()
        model = GuidedDDPM(**args)
        return model


class SD(BaseConf):
    """Stable Diffusion"""
    variant: str = "v1"
    v2_highres: bool = False
    prompt: str = "a photograph of an astronaut riding a horse"
    scale: float = 3.0  # classifier free guidance scale
    precision: str = 'autocast'

    def make(self):
        """Build a ``StableDiffusion`` with this config's fields as keyword arguments."""
        args = self.dict()
        model = StableDiffusion(**args)
        return model


class SDE(BaseConf):
    """Config for the ``adapt_vesde.VESDE`` score model (no fields)."""

    def make(self):
        """Build a ``VESDE`` with this config's fields as keyword arguments.

        ``VESDE`` is imported here instead of at module level because the
        module-level import is disabled to prevent import conflicts; a lazy
        import means the name only resolves when an SDE model is requested.
        """
        from adapt_vesde import VESDE
        args = self.dict()
        model = VESDE(**args)
        return model


class NCSN(BaseConf):
    """Config for the ``adapt_ncsn.NCSN`` score model (no fields)."""

    def make(self):
        """Build an ``adapt_ncsn.NCSN`` with this config's fields as keyword arguments."""
        args = self.dict()
        model = _NCSN(**args)
        return model


class KarrasGen(BaseConf):
    """Config and driver for batched sampling with ``Karras.inference``."""
    family: str = "gddpm"  # attribute name of the sub-config used to build the model
    gddpm: GDDPM = GDDPM()
    sd: SD = SD()
    # sde: SDE = SDE()
    ncsn: NCSN = NCSN()

    batch_size: int = 10
    num_images: int = 1250
    num_t: int = 40
    σ_max: float = 80.0
    heun: bool = True
    langevin: bool = False
    cls_scaling: float = 1.0  # classifier guidance scaling

    def run(self):
        """Build the model selected by ``family`` and sample with ``karras_generate``."""
        args = self.dict()
        family = args.pop("family")
        model = getattr(self, family).make()
        # Every remaining field (including the unused sub-configs) is forwarded;
        # karras_generate discards the ones it does not name via **kwargs.
        self.karras_generate(model, **args)

    @staticmethod
    def karras_generate(
        model: ScoreAdapter,
        batch_size, num_images, σ_max, num_t, langevin, heun, cls_scaling,
        **kwargs
    ):
        """Sample ``num_images // batch_size`` batches and save them as one array.

        Each batch consumes the ``Karras.inference`` generator to exhaustion
        and keeps the last yielded sample. For ``StableDiffusion`` models that
        sample is first passed through ``model.decode``. It is then converted
        with ``torch_samps_to_imgs``. All batches are concatenated along axis 0
        and stored as the ``"imgs"`` ``.npy`` artifact of an ``EventStorage``.

        Raises:
            ValueError: if ``num_images < batch_size`` (no full batch to sample).
        """
        del kwargs  # removed extra args
        num_batches = num_images // batch_size  # a trailing partial batch is dropped
        if num_batches == 0:
            raise ValueError(
                f"num_images ({num_images}) must be >= batch_size ({batch_size})"
            )

        fuse = EarlyLoopBreak(5)
        with tqdm(total=num_batches) as pbar, \
                HeartBeat(pbar) as hbeat, \
                EventStorage() as metric:

            all_imgs = []

            for _ in range(num_batches):
                if fuse.on_break():
                    break

                pipeline = Karras.inference(
                    model, batch_size, num_t,
                    init_xs=None, heun=heun, σ_max=σ_max,
                    langevin=langevin, cls_scaling=cls_scaling
                )

                # total=num_t+1 presumes the generator yields the initial state
                # plus one sample per step (TODO confirm against adapt.Karras).
                for imgs in tqdm(pipeline, total=num_t+1, disable=False):
                    hbeat.beat()

                if isinstance(model, StableDiffusion):
                    imgs = model.decode(imgs)

                imgs = torch_samps_to_imgs(imgs, uncenter=model.samps_centered())
                all_imgs.append(imgs)

                pbar.update()

            all_imgs = np.concatenate(all_imgs, axis=0)
            metric.put_artifact("imgs", ".npy", lambda fn: np.save(fn, all_imgs))
            metric.step()
            hbeat.done()


class SMLDGen(BaseConf):
    """Config and driver for batched annealed-Langevin sampling (``smld_inference``)."""
    family: str = "ncsn"  # attribute name of the sub-config used to build the model
    gddpm: GDDPM = GDDPM()
    # sde: SDE = SDE()
    ncsn: NCSN = NCSN()

    batch_size: int = 16
    num_images: int = 16
    num_stages: int = 80
    num_steps: int = 15
    σ_max: float = 80.0
    ε: float = 1e-5

    def run(self):
        """Build the model selected by ``family`` and sample with ``smld_generate``."""
        args = self.dict()
        family = args.pop("family")
        model = getattr(self, family).make()
        self.smld_generate(model, **args)

    @staticmethod
    def smld_generate(
        model: ScoreAdapter,
        batch_size, num_images, num_stages, num_steps, σ_max, ε,
        **kwargs
    ):
        """Sample ``num_images // batch_size`` batches with annealed Langevin dynamics.

        The noise levels come from ``power_schedule(σ_max, model.σ_min,
        num_stages)``, each snapped with ``model.snap_t_to_nearest_tick``.
        Each batch starts from uniform noise in [0, 1), rescaled to [-1, 1) when
        ``model.samps_centered()`` is true. Per-step max/min/std are logged as
        scalars. The final samples are converted with ``torch_samps_to_imgs``,
        concatenated along axis 0 and stored as the ``"imgs"`` ``.npy`` artifact.

        Raises:
            ValueError: if ``num_images < batch_size`` (no full batch to sample).
        """
        del kwargs  # extra config fields (e.g. sub-configs) are not used
        num_batches = num_images // batch_size  # a trailing partial batch is dropped
        if num_batches == 0:
            raise ValueError(
                f"num_images ({num_images}) must be >= batch_size ({batch_size})"
            )
        σs = power_schedule(σ_max, model.σ_min, num_stages)
        σs = [model.snap_t_to_nearest_tick(σ)[0] for σ in σs]

        fuse = EarlyLoopBreak(5)
        with tqdm(total=num_batches) as pbar, \
                HeartBeat(pbar) as hbeat, \
                EventStorage() as metric:

            all_imgs = []
            for _ in range(num_batches):
                if fuse.on_break():
                    break

                init_xs = torch.rand(batch_size, *model.data_shape(), device=model.device)
                if model.samps_centered():
                    init_xs = init_xs * 2 - 1  # [0, 1] -> [-1, 1]

                pipeline = smld_inference(
                    model, σs, num_steps, ε, init_xs
                )

                # smld_inference yields len(σs) * num_steps + 1 tensors.
                for imgs in tqdm(pipeline, total=(num_stages * num_steps)+1, disable=False):
                    pbar.set_description(f"{imgs.max().item():.3f}")
                    metric.put_scalars(
                        max=imgs.max().item(), min=imgs.min().item(), std=imgs.std().item()
                    )
                    metric.step()
                    hbeat.beat()

                pbar.update()
                imgs = torch_samps_to_imgs(imgs, uncenter=model.samps_centered())
                all_imgs.append(imgs)

            all_imgs = np.concatenate(all_imgs, axis=0)
            metric.put_artifact("imgs", ".npy", lambda fn: np.save(fn, all_imgs))
            metric.step()
            hbeat.done()


def smld_inference(model, σs, num_steps, ε, init_xs):
    """Run annealed Langevin dynamics, yielding every intermediate sample.

    For each noise level ``σ`` in ``σs`` (in the given order), take
    ``num_steps`` steps with step size ``α = ε * (σ / σs[-1]) ** 2``::

        xs <- xs + α * model.score(xs, σ) + sqrt(2 α) * z,   z ~ N(0, I)

    Yields ``init_xs`` first, then the sample after every step:
    ``len(σs) * num_steps + 1`` tensors in total.
    """
    from math import sqrt
    # not doing conditioning or cls guidance; for gddpm only lsun works; fine.

    xs = init_xs
    yield xs

    for σ in σs:
        α = ε * ((σ / σs[-1]) ** 2)
        noise_scale = sqrt(2 * α)
        for _ in range(num_steps):
            grad = model.score(xs, σ)  # presumably the score estimate at noise level σ
            z = torch.randn_like(xs)
            xs = xs + α * grad + noise_scale * z
            yield xs


def load_np_imgs(fname):
    """Load an image array from a ``.npy`` file or the ``arr_0`` entry of an ``.npz``.

    For ``.npz`` archives, ``np.load`` returns a file-backed ``NpzFile``, so it
    is opened in a ``with`` block to close the handle once ``arr_0`` is read.
    """
    fname = Path(fname)
    if fname.suffix == ".npz":
        with np.load(fname) as data:
            return data['arr_0']
    return np.load(fname)


def visualize(max_n_imgs=16, fname="imgs/step_0.npy", out_fname="preview.jpg", nrow=4):
    """Write a grid preview of the first ``max_n_imgs`` images in ``fname``.

    Args:
        max_n_imgs: number of images, from the front of the array, to tile.
        fname: array to read. The default ``imgs/step_0.npy`` is presumably
            where ``EventStorage`` writes the ``"imgs"`` artifact at step 0
            (TODO confirm).
        out_fname: path of the written preview image.
        nrow: images per grid row, forwarded to ``make_grid``.
    """
    import torchvision.utils as vutils
    from imageio import imwrite
    from einops import rearrange

    all_imgs = load_np_imgs(fname)
    imgs = all_imgs[:max_n_imgs]
    # Stored channels-last; make_grid wants N C H W. The pattern asserts C == 3.
    imgs = rearrange(imgs, "N H W C -> N C H W", C=3)
    imgs = torch.from_numpy(imgs)
    pane = vutils.make_grid(imgs, padding=2, nrow=nrow)
    pane = rearrange(pane, "C H W -> H W C", C=3)
    pane = pane.numpy()
    imwrite(out_fname, pane)


if __name__ == "__main__":
    seed_everything(0)
    dispatch(KarrasGen)
    visualize(16)