"""Loader for the Stable Diffusion 2 x4 upscaler (super-resolution) model.

Importing this module triggers ``download_file(DOWNLOAD_URL, MODEL_PATH)`` so
that the checkpoint is fetched ahead of the first ``load_model()`` call.
"""
import importlib
from functools import partial
from inspect import isfunction

import cv2
import numpy as np
import safetensors
import safetensors.torch
import torch
import torch.nn as nn
from omegaconf import OmegaConf

from lib.smplfusion import DDIM, share, scheduler
from .common import *

DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.safetensors?download=true'
MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-upsample/x4-upscaler-ema.safetensors'

# pre-download
download_file(DOWNLOAD_URL, MODEL_PATH)


def exists(x):
    """Return True when ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` unless it is None, otherwise the fallback ``d``.

    When ``d`` is a plain Python function (as judged by ``inspect.isfunction``)
    it is called with no arguments and its result is returned, so expensive
    fallbacks are only evaluated when needed.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def extract_into_tensor(a, t, x_shape):
    """Gather ``a`` at indices ``t`` and reshape for broadcasting.

    Args:
        a: tensor of per-timestep values, indexed along its last dimension.
        t: index tensor; its first dimension is used as the batch size.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(b, 1, ..., 1)`` with ``len(x_shape)`` dimensions.
    """
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def predict_eps_from_z_and_v(schedule, x_t, t, v):
    """Combine ``v`` and ``x_t`` into a noise estimate.

    Computes ``sqrt_alphas[t] * v + sqrt_one_minus_alphas[t] * x_t`` using the
    coefficient tensors stored on ``schedule``.
    """
    # The gathered table must live on the same device as the index tensor;
    # following ``t`` instead of hard-coding CUDA keeps this usable on CPU and
    # on non-default GPUs.
    sqrt_alphas = schedule.sqrt_alphas.to(t.device)
    sqrt_one_minus_alphas = schedule.sqrt_one_minus_alphas.to(t.device)
    return (
        extract_into_tensor(sqrt_alphas, t, x_t.shape) * v
        + extract_into_tensor(sqrt_one_minus_alphas, t, x_t.shape) * x_t
    )


def predict_start_from_z_and_v(schedule, x_t, t, v):
    """Combine ``x_t`` and ``v`` into an estimate of the clean sample.

    Computes ``sqrt_alphas[t] * x_t - sqrt_one_minus_alphas[t] * v`` using the
    coefficient tensors stored on ``schedule``.
    """
    # See predict_eps_from_z_and_v: the tables follow the device of ``t``.
    sqrt_alphas = schedule.sqrt_alphas.to(t.device)
    sqrt_one_minus_alphas = schedule.sqrt_one_minus_alphas.to(t.device)
    return (
        extract_into_tensor(sqrt_alphas, t, x_t.shape) * x_t
        - extract_into_tensor(sqrt_one_minus_alphas, t, x_t.shape) * v
    )


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Build a beta schedule and return it as a float64 NumPy array.

    Args:
        schedule: one of ``"linear"``, ``"cosine"``, ``"sqrt_linear"``, ``"sqrt"``.
        n_timestep: number of betas to produce.
        linear_start: first value of the range (unused by ``"cosine"``).
        linear_end: last value of the range (unused by ``"cosine"``).
        cosine_s: offset used by the ``"cosine"`` schedule only.

    Raises:
        ValueError: if ``schedule`` is not one of the supported names.
    """
    if schedule == "linear":
        # Linear in sqrt(beta) space, then squared.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch so ``betas`` is guaranteed to still be a tensor for
        # the ``.numpy()`` call below (np.clip on a tensor relies on
        # NumPy/torch interop details).
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class AbstractLowScaleModel(nn.Module):
    """Base module that holds a diffusion noise schedule as buffers.

    Used for concatenating a downsampled image to the latent representation.
    The base ``forward`` and ``decode`` are identity operations.
    """

    def __init__(self, noise_schedule_config=None):
        """Register the schedule when ``noise_schedule_config`` is given.

        Args:
            noise_schedule_config: optional dict of keyword arguments for
                ``register_schedule``; when None no buffers are registered.
        """
        super(AbstractLowScaleModel, self).__init__()
        if noise_schedule_config is not None:
            self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta schedule and register derived float32 buffers."""
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end, cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.

        When ``noise`` is None, Gaussian noise shaped like ``x_start`` is drawn.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def forward(self, x):
        """Identity: return ``(x, None)``."""
        return x, None

    def decode(self, x):
        """Identity: return ``x`` unchanged."""
        return x


class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Low-scale model that noises its input via ``q_sample``."""

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        """Store ``max_noise_level`` and register the noise schedule.

        Args:
            noise_schedule_config: keyword arguments for ``register_schedule``.
            max_noise_level: exclusive upper bound for randomly drawn levels.
            to_cuda: accepted for interface compatibility; not used here.
        """
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        """Noise ``x`` at ``noise_level`` and return ``(z, noise_level)``.

        When ``noise_level`` is None, one level per batch element is drawn
        uniformly from ``[0, max_noise_level)`` on ``x``'s device; otherwise it
        must be a ``torch.Tensor``.
        """
        if noise_level is None:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        else:
            assert isinstance(noise_level, torch.Tensor)
        z = self.q_sample(x, noise_level)
        return z, noise_level


def get_obj_from_str(string):
    """Resolve a dotted ``"package.module.Name"`` path to the named attribute.

    The module is imported as given first; if that import fails or the
    attribute is missing, the lookup is retried with a ``lib.`` prefix.
    """
    module, cls = string.rsplit(".", 1)
    try:
        return getattr(importlib.import_module(module, package=None), cls)
    except (ImportError, AttributeError):
        # Only fall back on lookup failures; a bare ``except`` would also
        # swallow KeyboardInterrupt / SystemExit.
        return getattr(importlib.import_module('lib.' + module, package=None), cls)


def load_obj(path):
    """Instantiate the object described by the YAML file at ``path``.

    The YAML must contain a ``__class__`` dotted path and may contain an
    ``__init__`` mapping of keyword arguments for the constructor.
    """
    objyaml = OmegaConf.load(path)
    return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))


def load_model(dtype=torch.bfloat16):
    """Build the SD2 super-resolution ``DDIM`` pipeline on CUDA.

    Downloads the checkpoint if needed, instantiates the UNet, VAE and text
    encoder from their YAML configs, loads their weights, freezes them, casts
    them to ``dtype`` and attaches a frozen low-scale noise-augmentation model.

    Args:
        dtype: torch dtype the UNet, VAE and encoder are cast to.

    Returns:
        The ``DDIM`` object with ``low_scale_model`` set.
    """
    print ("Loading model: SD2 superresolution...")

    download_file(DOWNLOAD_URL, MODEL_PATH)

    state_dict = safetensors.torch.load_file(MODEL_PATH)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml')
    unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()

    def extract(weights, model):
        # Select entries whose key contains ``model`` and drop the first
        # len(model) + 1 characters (the prefix and its trailing dot).
        # NOTE(review): this is a substring match, not a prefix match; it
        # assumes the marker only ever appears at the start of a key.
        return {key[len(model) + 1:]: value for key, value in weights.items() if model in key}

    unet_state = extract(state_dict, 'model.diffusion_model')
    encoder_state = extract(state_dict, 'cond_stage_model')
    vae_state = extract(state_dict, 'first_stage_model')

    unet.load_state_dict(unet_state)
    encoder.load_state_dict(encoder_state)
    vae.load_state_dict(vae_state)

    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    unet.to(dtype)
    vae.to(dtype)
    encoder.to(dtype)

    # Build the pipeline once, after weights are loaded and cast.
    ddim = DDIM(config, vae, encoder, unet)

    params = {
        'noise_schedule_config': {
            'linear_start': 0.0001,
            'linear_end': 0.02
        },
        'max_noise_level': 350
    }

    low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to('cuda')
    # Bind the instance explicitly: a plain function stored on an instance is
    # not a bound method, so ``.train()`` would raise TypeError and ``.eval()``
    # would return its ``mode`` argument instead of the module.
    low_scale_model.train = partial(disabled_train, low_scale_model)
    low_scale_model.requires_grad_(False)

    ddim.low_scale_model = low_scale_model
    print('SD2 superresolution loaded')
    return ddim