from torchvision.datasets.utils import download_url from ldm.util import instantiate_from_config import torch import os # todo ? from google.colab import files from IPython.display import Image as ipyimg import ipywidgets as widgets from PIL import Image from numpy import asarray from einops import rearrange, repeat import torch, torchvision from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import ismap import time from omegaconf import OmegaConf def download_models(mode): if mode == "superresolution": # this is the small bsr light model url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml' path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt' download_url(url_conf, path_conf) download_url(url_ckpt, path_ckpt) path_conf = path_conf + '/?dl=1' # fix it path_ckpt = path_ckpt + '/?dl=1' # fix it return path_conf, path_ckpt else: raise NotImplementedError def load_model_from_config(config, ckpt): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") global_step = pl_sd["global_step"] sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) model.cuda() model.eval() return {"model": model}, global_step def get_model(mode): path_conf, path_ckpt = download_models(mode) config = OmegaConf.load(path_conf) model, step = load_model_from_config(config, path_ckpt) return model def get_custom_cond(mode): dest = "data/example_conditioning" if mode == "superresolution": uploaded_img = files.upload() filename = next(iter(uploaded_img)) name, filetype = filename.split(".") # todo assumes just one dot in name ! os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}") elif mode == "text_conditional": w = widgets.Text(value='A cake with cream!', disabled=True) display(w) with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f: f.write(w.value) elif mode == "class_conditional": w = widgets.IntSlider(min=0, max=1000) display(w) with open(f"{dest}/{mode}/custom.txt", 'w') as f: f.write(w.value) else: raise NotImplementedError(f"cond not implemented for mode{mode}") def get_cond_options(mode): path = "data/example_conditioning" path = os.path.join(path, mode) onlyfiles = [f for f in sorted(os.listdir(path))] return path, onlyfiles def select_cond_path(mode): path = "data/example_conditioning" # todo path = os.path.join(path, mode) onlyfiles = [f for f in sorted(os.listdir(path))] selected = widgets.RadioButtons( options=onlyfiles, description='Select conditioning:', disabled=False ) display(selected) selected_path = os.path.join(path, selected.value) return selected_path def get_cond(mode, selected_path): example = dict() if mode == "superresolution": up_f = 4 visualize_cond_img(selected_path) c = Image.open(selected_path) c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True) c_up = rearrange(c_up, '1 c h w -> 1 h w c') c = rearrange(c, '1 c h w -> 1 h w c') c = 2. * c - 1. c = c.to(torch.device("cuda")) example["LR_image"] = c example["image"] = c_up return example def visualize_cond_img(path): display(ipyimg(filename=path)) def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None): example = get_cond(task, selected_path) save_intermediate_vid = False n_runs = 1 masked = False guider = None ckwargs = None mode = 'ddim' ddim_use_x0_pred = False temperature = 1. eta = 1. make_progrow = True custom_shape = None height, width = example["image"].shape[1:3] split_input = height >= 128 and width >= 128 if split_input: ks = 128 stride = 64 vqf = 4 # model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), "vqf": vqf, "patch_distributed_vq": True, "tie_braker": False, "clip_max_weight": 0.5, "clip_min_weight": 0.01, "clip_max_tie_weight": 0.5, "clip_min_tie_weight": 0.01} else: if hasattr(model, "split_input_params"): delattr(model, "split_input_params") invert_mask = False x_T = None for n in range(n_runs): if custom_shape is not None: x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0]) logs = make_convolutional_sample(example, model, mode=mode, custom_steps=custom_steps, eta=eta, swap_mode=False , masked=masked, invert_mask=invert_mask, quantize_x0=False, custom_schedule=None, decode_interval=10, resize_enabled=resize_enabled, custom_shape=custom_shape, temperature=temperature, noise_dropout=0., corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid, make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred ) return logs @torch.no_grad() def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, mask=None, x0=None, quantize_x0=False, img_callback=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, x_T=None, log_every_t=None ): ddim = DDIMSampler(model) bs = shape[0] # dont know where this comes from but wayne shape = shape[1:] # cut batch dim print(f"Sampling with eta = {eta}; steps: {steps}") samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, mask=mask, x0=x0, temperature=temperature, verbose=False, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, x_T=x_T) return samples, intermediates @torch.no_grad() def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False, invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000, resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False): log = dict() z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, return_first_stage_outputs=True, force_c_encode=not (hasattr(model, 'split_input_params') and model.cond_stage_key == 'coordinates_bbox'), return_original_cond=True) log_every_t = 1 if save_intermediate_vid else None if custom_shape is not None: z = torch.randn(custom_shape) print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") z0 = None log["input"] = x log["reconstruction"] = xrec if ismap(xc): log["original_conditioning"] = model.to_rgb(xc) if hasattr(model, 'cond_stage_key'): log[model.cond_stage_key] = model.to_rgb(xc) else: log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) if model.cond_stage_model: log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) if model.cond_stage_key =='class_label': log[model.cond_stage_key] = xc[model.cond_stage_key] with model.ema_scope("Plotting"): t0 = time.time() img_cb = None sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, eta=eta, quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0, temperature=temperature, noise_dropout=noise_dropout, score_corrector=corrector, corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t) t1 = time.time() if ddim_use_x0_pred: sample = intermediates['pred_x0'][-1] x_sample = model.decode_first_stage(sample) try: x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) log["sample_noquant"] = x_sample_noquant log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) except: pass log["sample"] = x_sample log["time"] = t1 - t0 return log