import math
import os
import urllib.request
import warnings
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.utils import insecure_hashlib
from safetensors.torch import load_file as stl
from tqdm import tqdm

from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
from diffusers.models.autoencoders.vae import Encoder
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unets.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D


args = ArgumentParser()
args.add_argument("--save_pretrained", required=False, default=None, type=str)
args.add_argument("--test_image", required=True, type=str)
args = args.parse_args()


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    # Index a 1-D schedule tensor at `timesteps` and append trailing singleton
    # dimensions so the result broadcasts against `broadcast_shape`.
    res = arr[timesteps].float()
    dims_to_append = len(broadcast_shape) - len(res.shape)
    return res[(...,) + (None,) * dims_to_append]
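# Worked example (shapes only):
#   arr: (1024,), timesteps: (B,), broadcast_shape: (B, 3, H, W)
#   -> arr[timesteps] has shape (B,); two None axes are appended -> (B, 1, 1, 1),
#   which then broadcasts over the image dimensions.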
					
						
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Create a beta schedule that discretizes the given alpha_bar function, which
    # defines the cumulative product of (1 - beta) over time on t in [0, 1].
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)
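# With the squared-cosine alpha_bar used below (Nichol & Dhariwal, 2021),
#   alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2,
# each beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), clipped at max_beta.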
					
						
def _download(url: str, root: str):
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # The checkpoint URL embeds the expected SHA256 as its second-to-last path segment.
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


class ConsistencyDecoder:
    def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
        self.n_distilled_steps = 64
        download_target = _download(
            "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
            download_root,
        )
        self.ckpt = torch.jit.load(download_target).to(device)
        self.device = device
        sigma_data = 0.5
        betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
        self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
        self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
        self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
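        # These are the consistency-model boundary conditions in the EDM
        # parameterization (with sigma_data = 0.5):
        #   c_skip = sigma_data^2 / (sigma^2 + sigma_data^2)
        #   c_out  = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
        #   c_in   = 1 / sqrt(sigma^2 + sigma_data^2)
        # with an extra 1 / sqrt(alphas_cumprod) factor on c_skip and c_in because
        # x_t is kept in variance-preserving (DDPM) scaling here.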
					
						
    @staticmethod
    def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
        with torch.no_grad():
            space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
            rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
            # Both branches pull a rounded value of total_timesteps back by one
            # spacing; truncate_start=False additionally keeps the result above 0.
            if truncate_start:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
            else:
                rounded_timesteps[rounded_timesteps == total_timesteps] -= space
                rounded_timesteps[rounded_timesteps == 0] += space
            return rounded_timesteps
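    # Worked example: with total_timesteps=1024 and n_distilled_steps=64 the
    # spacing is 16, so t in [0, 15] -> 16, t in [16, 31] -> 32, ..., and a
    # rounded value of 1024 is pulled back to 1008.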
					
						
    @staticmethod
    def ldm_transform_latent(z, extra_scale_factor=1):
        channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
        channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]

        if len(z.shape) != 4:
            raise ValueError(f"Expected a 4D latent tensor, got shape {tuple(z.shape)}")

        z = z * 0.18215
        channels = [z[:, i] for i in range(z.shape[1])]

        channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
        return torch.stack(channels, dim=1)
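    # The latent is first multiplied by the standard SD scaling factor 0.18215 and
    # then standardized channel-wise, presumably using statistics of the scaled
    # latent distribution.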
					
						
    @torch.no_grad()
    def __call__(
        self,
        features: torch.Tensor,
        schedule=[1.0, 0.5],
        generator=None,
    ):
        features = self.ldm_transform_latent(features)
        ts = self.round_timesteps(
            torch.arange(0, 1024),
            1024,
            self.n_distilled_steps,
            truncate_start=False,
        )
        shape = (
            features.size(0),
            3,
            8 * features.size(2),
            8 * features.size(3),
        )
        x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
        schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
        for i in schedule_timesteps:
            t = ts[i].item()
            t_ = torch.tensor([t] * features.shape[0]).to(self.device)

            noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
            x_start = (
                _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
                + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
            )
            c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)

            if isinstance(self.ckpt, UNet2DModel):
                input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
                model_output = self.ckpt(input, t_).sample
            else:
                model_output = self.ckpt(c_in * x_start, t_, features=features)

            B, C = x_start.shape[:2]
            model_output, _ = torch.split(model_output, C, dim=1)
            pred_xstart = (
                _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
                + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
            ).clamp(-1, 1)
            x_start = pred_xstart
        return x_start
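# Usage sketch: decode an SD latent of shape (B, 4, h, w) into a (B, 3, 8h, 8w) image,
#
#   decoder = ConsistencyDecoder(device="cuda:0")
#   image = decoder(latent, generator=torch.Generator("cpu").manual_seed(0))
#
# The default schedule [1.0, 0.5] takes two consistency steps: each iteration
# re-noises the current estimate at a (rounded) timestep and maps it back to a
# clean image prediction.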
					
						
def save_image(image, name):
    import numpy as np
    from PIL import Image

    image = image[0].cpu().numpy()
    image = (image + 1.0) * 127.5
    image = image.clip(0, 255).astype(np.uint8)
    image = Image.fromarray(image.transpose(1, 2, 0))
    image.save(name)


def load_image(uri, size=None, center_crop=False):
    import numpy as np
    from PIL import Image

    image = Image.open(uri)
    if center_crop:
        image = image.crop(
            (
                (image.width - min(image.width, image.height)) // 2,
                (image.height - min(image.width, image.height)) // 2,
                (image.width + min(image.width, image.height)) // 2,
                (image.height + min(image.width, image.height)) // 2,
            )
        )
    if size is not None:
        image = image.resize(size)
    image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
    image = image / 127.5 - 1.0
    return image


class TimestepEmbedding_(nn.Module):
    def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
        super().__init__()
        self.emb = nn.Embedding(n_time, n_emb)
        self.f_1 = nn.Linear(n_emb, n_out)
        self.f_2 = nn.Linear(n_out, n_out)

    def forward(self, x) -> torch.Tensor:
        x = self.emb(x)
        x = self.f_1(x)
        x = F.silu(x)
        return self.f_2(x)
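# Unlike the sinusoidal time_proj in most diffusers UNets, the original decoder
# looks the integer timestep up in a learned nn.Embedding table; the conversion at
# the bottom of this script therefore uses time_embedding_type="learned".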
					
						
class ImageEmbedding(nn.Module):
    def __init__(self, in_channels=7, out_channels=320) -> None:
        super().__init__()
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(x)


class ImageUnembedding(nn.Module):
    def __init__(self, in_channels=320, out_channels=6) -> None:
        super().__init__()
        self.gn = nn.GroupNorm(32, in_channels)
        self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x) -> torch.Tensor:
        return self.f(F.silu(self.gn(x)))


class ConvResblock(nn.Module):
    def __init__(self, in_features=320, out_features=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, out_features * 2)

        self.gn_1 = nn.GroupNorm(32, in_features)
        self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)

        self.gn_2 = nn.GroupNorm(32, out_features)
        self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)

        skip_conv = in_features != out_features
        self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity()

    def forward(self, x, t):
        x_skip = x
        t = self.f_t(F.silu(t))
        t = t.chunk(2, dim=1)
        t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
        t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)

        gn_1 = F.silu(self.gn_1(x))
        f_1 = self.f_1(gn_1)

        gn_2 = self.gn_2(f_1)

        return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
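# The residual blocks use GroupNorm "scale-shift" (AdaGN) conditioning: the 1280-d
# time embedding is projected to 2 * out_features and split into scale and shift,
#   out = skip(x) + f_2(silu(gn_2(f_1(silu(gn_1(x)))) * (1 + scale) + shift)),
# which is what resnet_time_scale_shift="scale_shift" selects in the diffusers
# blocks used for the conversion below.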
					
						
class Downsample(nn.Module):
    def __init__(self, in_channels=320) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)

        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)

        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x

        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)

        gn_1 = F.silu(self.gn_1(x))
        avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)

        f_1 = self.f_1(avg_pool2d)
        gn_2 = self.gn_2(f_1)

        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))

        return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)


class Upsample(nn.Module):
    def __init__(self, in_channels=1024) -> None:
        super().__init__()
        self.f_t = nn.Linear(1280, in_channels * 2)

        self.gn_1 = nn.GroupNorm(32, in_channels)
        self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.gn_2 = nn.GroupNorm(32, in_channels)

        self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        x_skip = x

        t = self.f_t(F.silu(t))
        t_1, t_2 = t.chunk(2, dim=1)
        t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
        t_2 = t_2.unsqueeze(2).unsqueeze(3)

        gn_1 = F.silu(self.gn_1(x))
        upsample = F.upsample_nearest(gn_1, scale_factor=2)
        f_1 = self.f_1(upsample)
        gn_2 = self.gn_2(f_1)

        f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))

        return f_2 + F.upsample_nearest(x_skip, scale_factor=2)
					
						
class ConvUNetVAE(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_image = ImageEmbedding()
        self.embed_time = TimestepEmbedding_()

        down_0 = nn.ModuleList(
            [
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                ConvResblock(320, 320),
                Downsample(320),
            ]
        )
        down_1 = nn.ModuleList(
            [
                ConvResblock(320, 640),
                ConvResblock(640, 640),
                ConvResblock(640, 640),
                Downsample(640),
            ]
        )
        down_2 = nn.ModuleList(
            [
                ConvResblock(640, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                Downsample(1024),
            ]
        )
        down_3 = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )
        self.down = nn.ModuleList(
            [
                down_0,
                down_1,
                down_2,
                down_3,
            ]
        )

        self.mid = nn.ModuleList(
            [
                ConvResblock(1024, 1024),
                ConvResblock(1024, 1024),
            ]
        )

        up_3 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                Upsample(1024),
            ]
        )
        up_2 = nn.ModuleList(
            [
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 * 2, 1024),
                ConvResblock(1024 + 640, 1024),
                Upsample(1024),
            ]
        )
        up_1 = nn.ModuleList(
            [
                ConvResblock(1024 + 640, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(640 * 2, 640),
                ConvResblock(320 + 640, 640),
                Upsample(640),
            ]
        )
        up_0 = nn.ModuleList(
            [
                ConvResblock(320 + 640, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
                ConvResblock(320 * 2, 320),
            ]
        )
        self.up = nn.ModuleList(
            [
                up_0,
                up_1,
                up_2,
                up_3,
            ]
        )

        self.output = ImageUnembedding()

    def forward(self, x, t, features) -> torch.Tensor:
        # `converted` is set externally once the hand-rolled submodules have been
        # swapped for their diffusers equivalents.
        converted = hasattr(self, "converted") and self.converted

        x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)

        if converted:
            t = self.time_embedding(self.time_proj(t))
        else:
            t = self.embed_time(t)

        x = self.embed_image(x)

        skips = [x]
        for i, down in enumerate(self.down):
            if converted and i in [0, 1, 2, 3]:
                x, skips_ = down(x, t)
                for skip in skips_:
                    skips.append(skip)
            else:
                for block in down:
                    x = block(x, t)
                    skips.append(x)
                print(x.float().abs().sum())  # activation checksum for comparing original vs. converted

        if converted:
            x = self.mid(x, t)
        else:
            for i in range(2):
                x = self.mid[i](x, t)
            print(x.float().abs().sum())  # activation checksum for comparing original vs. converted

        for i, up in enumerate(self.up[::-1]):
            if converted and i in [0, 1, 2, 3]:
                skip_4 = skips.pop()
                skip_3 = skips.pop()
                skip_2 = skips.pop()
                skip_1 = skips.pop()
                skips_ = (skip_1, skip_2, skip_3, skip_4)
                x = up(x, skips_, t)
            else:
                for block in up:
                    if isinstance(block, ConvResblock):
                        x = torch.concat([x, skips.pop()], dim=1)
                    x = block(x, t)

        return self.output(x)
					
						
def rename_state_dict_key(k):
    k = k.replace("blocks.", "")
    for i in range(5):
        k = k.replace(f"down_{i}_", f"down.{i}.")
        k = k.replace(f"conv_{i}.", f"{i}.")
        k = k.replace(f"up_{i}_", f"up.{i}.")
        k = k.replace(f"mid_{i}", f"mid.{i}")
    k = k.replace("upsamp.", "4.")
    k = k.replace("downsamp.", "3.")
    k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
    k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
    k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
    k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
    k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
    k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
    k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
    k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
    return k


def rename_state_dict(sd, embedding):
    sd = {rename_state_dict_key(k): v for k, v in sd.items()}
    sd["embed_time.emb.weight"] = embedding["weight"]
    return sd
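# Illustrative renames implied by the rules above (the exact source key names are
# an assumption inferred from the replacement patterns):
#   "blocks.down_0_conv_2.f_1.w"    -> "down.0.2.f_1.weight"
#   "blocks.down_0_downsamp.gn_1.g" -> "down.0.3.gn_1.weight"
#   "blocks.mid_1.f_t.b"            -> "mid.1.f_t.bias"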
					
						
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.vae.cuda()


decoder_consistency = ConsistencyDecoder(device="cuda:0")


model = ConvUNetVAE()
model.load_state_dict(
    rename_state_dict(
        stl("consistency_decoder.safetensors"),
        stl("embedding.safetensors"),
    )
)
model = model.cuda()

decoder_consistency.ckpt = model

image = load_image(args.test_image, size=(256, 256), center_crop=True)
latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()


sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")


sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_orig, "con_orig.png")


print("CONVERSION")

print("DOWN BLOCK ONE")

block_one_sd_orig = model.down[0].state_dict()
block_one_sd_new = {}

for i in range(3):
    block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
    block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
    block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
    block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
    block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
    block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
    block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
    block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
    block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
    block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")

block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")

assert len(block_one_sd_orig) == 0

block_one = ResnetDownsampleBlock2D(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    num_layers=3,
    add_downsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

block_one.load_state_dict(block_one_sd_new)
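# Optional sanity-check sketch (not part of the original script; names and shapes
# are assumptions): the converted block should match the hand-rolled modules, e.g.
#
#   _x = torch.randn(1, 320, 32, 32)
#   _temb = torch.randn(1, 1280)
#   _out, _ = block_one(_x, _temb)  # ResnetDownsampleBlock2D returns (hidden, skips)
#
# and _out can be compared against running the blocks of model.down[0] sequentially.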
					
						
print("DOWN BLOCK TWO")

block_two_sd_orig = model.down[1].state_dict()
block_two_sd_new = {}

for i in range(3):
    block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
    block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
    block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
    block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
    block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
    block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
    block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
    block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
    block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
    block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")

    if i == 0:
        block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
        block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")

block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")

assert len(block_two_sd_orig) == 0

block_two = ResnetDownsampleBlock2D(
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    num_layers=3,
    add_downsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

block_two.load_state_dict(block_two_sd_new)
					
						
print("DOWN BLOCK THREE")

block_three_sd_orig = model.down[2].state_dict()
block_three_sd_new = {}

for i in range(3):
    block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
    block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
    block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
    block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
    block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
    block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
    block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
    block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
    block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
    block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")

    if i == 0:
        block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
        block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")

block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")

assert len(block_three_sd_orig) == 0

block_three = ResnetDownsampleBlock2D(
    in_channels=640,
    out_channels=1024,
    temb_channels=1280,
    num_layers=3,
    add_downsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

block_three.load_state_dict(block_three_sd_new)
					
						
print("DOWN BLOCK FOUR")

block_four_sd_orig = model.down[3].state_dict()
block_four_sd_new = {}

for i in range(3):
    block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
    block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
    block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
    block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
    block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
    block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
    block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
    block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
    block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
    block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")

assert len(block_four_sd_orig) == 0

block_four = ResnetDownsampleBlock2D(
    in_channels=1024,
    out_channels=1024,
    temb_channels=1280,
    num_layers=3,
    add_downsample=False,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

block_four.load_state_dict(block_four_sd_new)
					
						
print("MID BLOCK 1")

mid_block_one_sd_orig = model.mid.state_dict()
mid_block_one_sd_new = {}

for i in range(2):
    mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
    mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
    mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
    mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
    mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
    mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
    mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
    mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
    mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
    mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")

assert len(mid_block_one_sd_orig) == 0

mid_block_one = UNetMidBlock2D(
    in_channels=1024,
    temb_channels=1280,
    num_layers=1,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
    add_attention=False,
)

mid_block_one.load_state_dict(mid_block_one_sd_new)
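# Note: in diffusers, UNetMidBlock2D with num_layers=1 contains two resnets (an
# initial resnet plus num_layers more), which is why the two ConvResblocks in
# model.mid map onto resnets.0 and resnets.1 above.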
					
						
print("UP BLOCK ONE")

up_block_one_sd_orig = model.up[-1].state_dict()
up_block_one_sd_new = {}

for i in range(4):
    up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
    up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
    up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
    up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
    up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
    up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
    up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
    up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
    up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
    up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
    up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
    up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")

up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")

assert len(up_block_one_sd_orig) == 0

up_block_one = ResnetUpsampleBlock2D(
    in_channels=1024,
    prev_output_channel=1024,
    out_channels=1024,
    temb_channels=1280,
    num_layers=4,
    add_upsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

up_block_one.load_state_dict(up_block_one_sd_new)
					
						
print("UP BLOCK TWO")

up_block_two_sd_orig = model.up[-2].state_dict()
up_block_two_sd_new = {}

for i in range(4):
    up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
    up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
    up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
    up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
    up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
    up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
    up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
    up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
    up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
    up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
    up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
    up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")

up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")

assert len(up_block_two_sd_orig) == 0

up_block_two = ResnetUpsampleBlock2D(
    in_channels=640,
    prev_output_channel=1024,
    out_channels=1024,
    temb_channels=1280,
    num_layers=4,
    add_upsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

up_block_two.load_state_dict(up_block_two_sd_new)
					
						
print("UP BLOCK THREE")

up_block_three_sd_orig = model.up[-3].state_dict()
up_block_three_sd_new = {}

for i in range(4):
    up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
    up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
    up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
    up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
    up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
    up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
    up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
    up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
    up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
    up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
    up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
    up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")

up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")

assert len(up_block_three_sd_orig) == 0

up_block_three = ResnetUpsampleBlock2D(
    in_channels=320,
    prev_output_channel=1024,
    out_channels=640,
    temb_channels=1280,
    num_layers=4,
    add_upsample=True,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

up_block_three.load_state_dict(up_block_three_sd_new)
					
						
print("UP BLOCK FOUR")

up_block_four_sd_orig = model.up[-4].state_dict()
up_block_four_sd_new = {}

for i in range(4):
    up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
    up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
    up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
    up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
    up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
    up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
    up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
    up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
    up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
    up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
    up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
    up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")

assert len(up_block_four_sd_orig) == 0

up_block_four = ResnetUpsampleBlock2D(
    in_channels=320,
    prev_output_channel=640,
    out_channels=320,
    temb_channels=1280,
    num_layers=4,
    add_upsample=False,
    resnet_time_scale_shift="scale_shift",
    resnet_eps=1e-5,
)

up_block_four.load_state_dict(up_block_four_sd_new)
					
						
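
# `embed_image` is a single conv (only "f.weight"/"f.bias" in its state
# dict); it maps directly onto the diffusers `conv_in`.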
print("initial projection (conv_in)")

conv_in_sd_orig = model.embed_image.state_dict()
conv_in_sd_new = {}

conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")

assert len(conv_in_sd_orig) == 0

block_out_channels = [320, 640, 1024, 1024]
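
# 7 input channels: the 3-channel noisy image is concatenated with the
# 4-channel SD latent before entering the decoder.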
in_channels = 7
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)

conv_in.load_state_dict(conv_in_sd_new)
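
# The output head predicts 6 channels; as far as I can tell, sampling keeps
# only the first 3 as the decoded RGB image, learned-variance style -- an
# observation about the sampler, not something enforced here.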
print("out projection (conv_out) (conv_norm_out)")
out_channels = 6
norm_num_groups = 32
norm_eps = 1e-5
act_fn = "silu"
conv_out_kernel = 3
conv_out_padding = (conv_out_kernel - 1) // 2
conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)

conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)

conv_norm_out.load_state_dict(model.output.gn.state_dict())
conv_out.load_state_dict(model.output.f.state_dict())
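
# The consistency decoder uses a learned per-timestep embedding table rather
# than sinusoidal features, so `time_proj` is a plain `nn.Embedding` loaded
# straight from the original `embed_time.emb` lookup.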
print("timestep projection (time_proj) (time_embedding)")

f1_sd = model.embed_time.f_1.state_dict()
f2_sd = model.embed_time.f_2.state_dict()

time_embedding_sd = {
    "linear_1.weight": f1_sd.pop("weight"),
    "linear_1.bias": f1_sd.pop("bias"),
    "linear_2.weight": f2_sd.pop("weight"),
    "linear_2.bias": f2_sd.pop("bias"),
}

assert len(f1_sd) == 0
assert len(f2_sd) == 0

time_embedding_type = "learned"
num_train_timesteps = 1024
time_embedding_dim = 1280

time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]

time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)

time_proj.load_state_dict(model.embed_time.emb.state_dict())
time_embedding.load_state_dict(time_embedding_sd)
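
# Swap the converted modules into the original model in place, so the original
# sampling loop now runs through the diffusers blocks; the exact-equality
# assert below checks the conversion is bit-for-bit faithful.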
print("CONVERT")

time_embedding.to("cuda")
time_proj.to("cuda")
conv_in.to("cuda")

block_one.to("cuda")
block_two.to("cuda")
block_three.to("cuda")
block_four.to("cuda")

mid_block_one.to("cuda")

up_block_one.to("cuda")
up_block_two.to("cuda")
up_block_three.to("cuda")
up_block_four.to("cuda")

conv_norm_out.to("cuda")
conv_out.to("cuda")

model.time_proj = time_proj
model.time_embedding = time_embedding
model.embed_image = conv_in

model.down[0] = block_one
model.down[1] = block_two
model.down[2] = block_three
model.down[3] = block_four

model.mid = mid_block_one

model.up[-1] = up_block_one
model.up[-2] = up_block_two
model.up[-3] = up_block_three
model.up[-4] = up_block_four

model.output.gn = conv_norm_out
model.output.f = conv_out

model.converted = True

sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_new, "con_new.png")

assert (sample_consistency_orig == sample_consistency_new).all()
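
# Re-assemble the same weights into one UNet2DModel whose config mirrors the
# hand-built blocks (learned timesteps, scale-shift norm, no attention);
# layers_per_block=3 lines up with the num_layers=4 up blocks above because
# diffusers gives each up block one extra resnet.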
print("making unet")

unet = UNet2DModel(
    in_channels=in_channels,
    out_channels=out_channels,
    down_block_types=(
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
    ),
    up_block_types=(
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
    ),
    block_out_channels=block_out_channels,
    layers_per_block=3,
    norm_num_groups=norm_num_groups,
    norm_eps=norm_eps,
    resnet_time_scale_shift="scale_shift",
    time_embedding_type="learned",
    num_train_timesteps=num_train_timesteps,
    add_attention=False,
)

unet_state_dict = {}
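

# Flatten each converted module's state dict under the prefix UNet2DModel
# expects for that submodule.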
def add_state_dict(prefix, mod):
    for k, v in mod.state_dict().items():
        unet_state_dict[f"{prefix}.{k}"] = v


add_state_dict("conv_in", conv_in)
add_state_dict("time_proj", time_proj)
add_state_dict("time_embedding", time_embedding)
add_state_dict("down_blocks.0", block_one)
add_state_dict("down_blocks.1", block_two)
add_state_dict("down_blocks.2", block_three)
add_state_dict("down_blocks.3", block_four)
add_state_dict("mid_block", mid_block_one)
add_state_dict("up_blocks.0", up_block_one)
add_state_dict("up_blocks.1", up_block_two)
add_state_dict("up_blocks.2", up_block_three)
add_state_dict("up_blocks.3", up_block_four)
add_state_dict("conv_norm_out", conv_norm_out)
add_state_dict("conv_out", conv_out)

unet.load_state_dict(unet_state_dict)
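
# The original wrapper calls whatever module is stored on `.ckpt`, so dropping
# the diffusers unet in runs it inside the original consistency sampler.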
print("running with diffusers unet")

unet.to("cuda")

decoder_consistency.ckpt = unet

sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
save_image(sample_consistency_new_2, "con_new_2.png")

assert (sample_consistency_orig == sample_consistency_new_2).all()
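
# Monkeypatch Encoder.__init__ to record the kwargs it was constructed with,
# so the SD VAE encoder's config can be replayed as `encoder_args` below.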
print("running with diffusers model")

Encoder.old_constructor = Encoder.__init__


def new_constructor(self, **kwargs):
    self.old_constructor(**kwargs)
    self.constructor_arguments = kwargs


Encoder.__init__ = new_constructor
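

# Assemble the final ConsistencyDecoderVAE: the encoder and quant_conv come
# from the SD v1-5 VAE, while the decoder is the converted unet.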
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
consistency_vae = ConsistencyDecoderVAE(
    encoder_args=vae.encoder.constructor_arguments,
    decoder_args=unet.config,
    scaling_factor=vae.config.scaling_factor,
    block_out_channels=vae.config.block_out_channels,
    latent_channels=vae.config.latent_channels,
)
consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
consistency_vae.decoder_unet.load_state_dict(unet.state_dict())

consistency_vae.to(dtype=torch.float16, device="cuda")
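
# `decode` expects latents in SD's scaled space, hence the 0.18215 factor.
# In float16, exact equality no longer holds, so the drift is reported rather
# than asserted.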
sample_consistency_new_3 = consistency_vae.decode(
    0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
).sample

print("max difference")
print((sample_consistency_orig - sample_consistency_new_3).abs().max())
print("total difference")
print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
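

# End-to-end smoke test: a full text-to-image run with the new VAE plugged
# into the standard SD pipeline.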
print("running with diffusers pipeline")

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
)
pipe.to("cuda")

pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")


if args.save_pretrained is not None:
    consistency_vae.save_pretrained(args.save_pretrained)
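    # A checkpoint saved this way can be reloaded later -- a minimal sketch,
    # assuming float16 is wanted at load time:
    #   reloaded = ConsistencyDecoderVAE.from_pretrained(args.save_pretrained, torch_dtype=torch.float16)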