# コードは Stable Cascade からコピーし、一部修正しています。元ライセンスは MIT です。 # The code is copied from Stable Cascade and modified. The original license is MIT. # https://github.com/Stability-AI/StableCascade import math from types import SimpleNamespace from typing import List, Optional import numpy as np import torch import torch.nn as nn import torch.utils.checkpoint import torchvision def check_scale(tensor): return torch.mean(torch.abs(tensor)) # region VectorQuantize # from torchtools https://github.com/pabloppp/pytorch-tools # 依存ライブラリを増やしたくないのでここにコピペ class vector_quantize(torch.autograd.Function): @staticmethod def forward(ctx, x, codebook): with torch.no_grad(): codebook_sqr = torch.sum(codebook**2, dim=1) x_sqr = torch.sum(x**2, dim=1, keepdim=True) dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) _, indices = dist.min(dim=1) ctx.save_for_backward(indices, codebook) ctx.mark_non_differentiable(indices) nn = torch.index_select(codebook, 0, indices) return nn, indices @staticmethod def backward(ctx, grad_output, grad_indices): grad_inputs, grad_codebook = None, None if ctx.needs_input_grad[0]: grad_inputs = grad_output.clone() if ctx.needs_input_grad[1]: # Gradient wrt. the codebook indices, codebook = ctx.saved_tensors grad_codebook = torch.zeros_like(codebook) grad_codebook.index_add_(0, indices, grad_output) return (grad_inputs, grad_codebook) class VectorQuantize(nn.Module): def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): """ Takes an input of variable size (as long as the last dimension matches the embedding size). 
Returns one tensor containing the nearest neighbour embeddings to each of the inputs, with the same size as the input, vq and commitment components for the loss as a tuple in the second output and the indices of the quantized vectors in the third: quantized, (vq_loss, commit_loss), indices """ super(VectorQuantize, self).__init__() self.codebook = nn.Embedding(k, embedding_size) self.codebook.weight.data.uniform_(-1.0 / k, 1.0 / k) self.vq = vector_quantize.apply self.ema_decay = ema_decay self.ema_loss = ema_loss if ema_loss: self.register_buffer("ema_element_count", torch.ones(k)) self.register_buffer("ema_weight_sum", torch.zeros_like(self.codebook.weight)) def _laplace_smoothing(self, x, epsilon): n = torch.sum(x) return (x + epsilon) / (n + x.size(0) * epsilon) * n def _updateEMA(self, z_e_x, indices): mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() elem_count = mask.sum(dim=0) weight_sum = torch.mm(mask.t(), z_e_x) self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1 - self.ema_decay) * elem_count) self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1 - self.ema_decay) * weight_sum) self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) def idx2vq(self, idx, dim=-1): q_idx = self.codebook(idx) if dim != -1: q_idx = q_idx.movedim(-1, dim) return q_idx def forward(self, x, get_losses=True, dim=-1): if dim != -1: x = x.movedim(dim, -1) z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) vq_loss, commit_loss = None, None if self.ema_loss and self.training: self._updateEMA(z_e_x.detach(), indices.detach()) # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) if get_losses: vq_loss = 
(z_q_x_grd - z_e_x.detach()).pow(2).mean() commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() z_q_x = z_q_x.view(x.shape) if dim != -1: z_q_x = z_q_x.movedim(-1, dim) return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) # endregion class EfficientNetEncoder(nn.Module): def __init__(self, c_latent=16): super().__init__() self.backbone = torchvision.models.efficientnet_v2_s(weights="DEFAULT").features.eval() self.mapper = nn.Sequential( nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 ) def forward(self, x): return self.mapper(self.backbone(x)) @property def dtype(self) -> torch.dtype: return next(self.parameters()).dtype @property def device(self) -> torch.device: return next(self.parameters()).device def encode(self, x): """ VAE と同じように使えるようにするためのメソッド。正しくはちゃんと呼び出し側で分けるべきだが、暫定的な対応。 The method to make it usable like VAE. It should be separated properly, but it is a temporary response. 
""" # latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") x = self(x) return SimpleNamespace(latent_dist=SimpleNamespace(sample=lambda: x)) # なんかわりと乱暴な実装(;'∀') # 一から学習することもないだろうから、無効化しておく # class Linear(torch.nn.Linear): # def reset_parameters(self): # return None # class Conv2d(torch.nn.Conv2d): # def reset_parameters(self): # return None from torch.nn import Conv2d from torch.nn import Linear class Attention2D(nn.Module): def __init__(self, c, nhead, dropout=0.0): super().__init__() self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True) def forward(self, x, kv, self_attn=False): orig_shape = x.shape x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 if self_attn: kv = torch.cat([x, kv], dim=1) x = self.attn(x, kv, kv, need_weights=False)[0] x = x.permute(0, 2, 1).view(*orig_shape) return x class LayerNorm2d(nn.LayerNorm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def forward(self, x): return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) class GlobalResponseNorm(nn.Module): "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) def forward(self, x): Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) return self.gamma * (x * Nx) + self.beta + x class ResBlock(nn.Module): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): # , num_heads=4, expansion=2): super().__init__() self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) # self.depthwise = SAMBlock(c, num_heads, expansion) self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( Linear(c + c_skip, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), 
nn.Dropout(dropout), Linear(c * 4, c) ) self.gradient_checkpointing = False self.factor = 1 def set_factor(self, k): if self.factor!=1: return self.factor = k self.depthwise.bias.data /= k self.channelwise[4].weight.data /= k self.channelwise[4].bias.data /= k def set_gradient_checkpointing(self, value): self.gradient_checkpointing = value def forward_body(self, x, x_skip=None): x_res = x #x = x /self.factor x = self.depthwise(x) x = self.norm(x) # if torch.any(torch.isnan(x)): #print("nan in first norm") if x_skip is not None: x = torch.cat([x, x_skip], dim=1) x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)# * self.factor # if torch.any(torch.isnan(x)): #print("nan in second norm") # result = x + x_res # if check_scale(x) > 5: # self.scale = 0.1 return x+ x_res def forward(self, x, x_skip=None): # if self.factor > 1: #print("ResBlock: factor > 1") if self.training and self.gradient_checkpointing: # logger.info("ResnetBlock2D: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): return func(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, x_skip) else: x = self.forward_body(x, x_skip) return x class AttnBlock(nn.Module): def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): super().__init__() self.self_attn = self_attn self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) self.attention = Attention2D(c, nhead, dropout) self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c)) self.gradient_checkpointing = False self.factor = 1 def set_factor(self, k): if self.factor!=1: return self.factor = k self.attention.attn.out_proj.weight.data /= k if self.attention.attn.out_proj.bias is not None: self.attention.attn.out_proj.bias.data /= k def set_gradient_checkpointing(self, value): self.gradient_checkpointing = value def forward_body(self, x, kv): kv = self.kv_mapper(kv) x = x + self.attention(self.norm(x), kv, 
self_attn=self.self_attn) #* self.factor return x def forward(self, x, kv): # if self.factor > 1: #print("AttnBlock: factor > 1") if self.training and self.gradient_checkpointing: # logger.info("AttnBlock: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): return func(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, kv) else: x = self.forward_body(x, kv) return x class FeedForwardBlock(nn.Module): def __init__(self, c, dropout=0.0): super().__init__() self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), Linear(c * 4, c) ) self.gradient_checkpointing = False def set_gradient_checkpointing(self, value): self.gradient_checkpointing = value def forward_body(self, x): x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x def forward(self, x): if self.training and self.gradient_checkpointing: # logger.info("FeedForwardBlock: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): return func(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x) else: x = self.forward_body(x) return x class TimestepBlock(nn.Module): def __init__(self, c, c_timestep, conds=["sca"]): super().__init__() self.mapper = Linear(c_timestep, c * 2) self.conds = conds for cname in conds: setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2)) self.factor = 1 def set_factor(self, k, ext_k): if self.factor!=1: return #print(f"TimestepBlock: factor = {k}, ext_k = {ext_k}") self.factor = k k_factor = k/ext_k a_weight_factor = 1/k_factor b_weight_factor = 1/k a_bias_offset = - ((k_factor - 1)/(k_factor))/(len(self.conds) + 1) for module in [self.mapper, *(getattr(self, f"mapper_{cname}") for cname in self.conds)]: a_bias, b_bias = module.bias.data.chunk(2, 
dim=0) a_weight, b_weight = module.weight.data.chunk(2, dim=0) module.weight.data.copy_( torch.concat([ a_weight * a_weight_factor, b_weight * b_weight_factor ]) ) module.bias.data.copy_( torch.concat([ a_bias * a_weight_factor + a_bias_offset, b_bias * b_weight_factor ]) ) def forward(self, x, t): # if self.factor > 1: #print("TimestepBlock: factor > 1") t = t.chunk(len(self.conds) + 1, dim=1) a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) for i, c in enumerate(self.conds): ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) a, b = a + ac, b + bc return (x * (1 + a) + b) # * self.factor class UpDownBlock2d(nn.Module): def __init__(self, c_in, c_out, mode, enabled=True): super().__init__() assert mode in ["up", "down"] interpolation = ( nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) if enabled else nn.Identity() ) mapping = nn.Conv2d(c_in, c_out, kernel_size=1) self.blocks = nn.ModuleList([interpolation, mapping] if mode == "up" else [mapping, interpolation]) self.mode = mode self.gradient_checkpointing = False def set_gradient_checkpointing(self, value): self.gradient_checkpointing = value def forward_body(self, x): org_dtype = x.dtype for i, block in enumerate(self.blocks): # 公式の実装では、常に float で計算しているが、すこしでもメモリを節約するために bfloat16 + Upsample のみ float に変換する # In the official implementation, it always calculates in float, but for the sake of saving memory, it converts to float only for bfloat16 + Upsample if x.dtype == torch.bfloat16 and (self.mode == "up" and i == 0 or self.mode != "up" and i == 1): x = x.float() x = block(x) x = x.to(org_dtype) return x def forward(self, x): if self.training and self.gradient_checkpointing: # logger.info("UpDownBlock2d: gradient_checkpointing") def create_custom_forward(func): def custom_forward(*inputs): return func(*inputs) return custom_forward x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x) else: x = 
self.forward_body(x) return x class StageAResBlock(nn.Module): def __init__(self, c, c_hidden): super().__init__() # depthwise/attention self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.depthwise = nn.Sequential(nn.ReplicationPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c)) # channelwise self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) self.channelwise = nn.Sequential( nn.Linear(c, c_hidden), nn.GELU(), nn.Linear(c_hidden, c), ) self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) # Init weights def _basic_init(module): if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) self.apply(_basic_init) def _norm(self, x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) def forward(self, x): mods = self.gammas x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] x = x + self.depthwise(x_temp) * mods[2] x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] return x class StageA(nn.Module): def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.43): # 0.3764 super().__init__() self.c_latent = c_latent self.scale_factor = scale_factor c_levels = [c_hidden // (2**i) for i in reversed(range(levels))] # Encoder blocks self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)) down_blocks = [] for i in range(levels): if i > 0: down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) block = StageAResBlock(c_levels[i], c_levels[i] * 4) down_blocks.append(block) down_blocks.append( nn.Sequential( nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 ) ) self.down_blocks = 
nn.Sequential(*down_blocks) self.down_blocks[0] self.codebook_size = codebook_size self.vquantizer = VectorQuantize(c_latent, k=codebook_size) # Decoder blocks up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))] for i in range(levels): for j in range(bottleneck_blocks if i == 0 else 1): block = StageAResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) up_blocks.append(block) if i < levels - 1: up_blocks.append( nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1) ) self.up_blocks = nn.Sequential(*up_blocks) self.out_block = nn.Sequential( nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), nn.PixelShuffle(2), ) def encode(self, x, quantize=False): x = self.in_block(x) x = self.down_blocks(x) if quantize: qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 else: return x / self.scale_factor, None, None, None def decode(self, x): x = x * self.scale_factor x = self.up_blocks(x) x = self.out_block(x) return x def forward(self, x, quantize=False): qe, x, _, vq_loss = self.encode(x, quantize) x = self.decode(qe) return x, vq_loss r""" https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_b_3b.yaml # GLOBAL STUFF model_version: 3B dtype: bfloat16 # For demonstration purposes in reconstruct_images.ipynb webdataset_path: file:inference/imagenet_1024.tar batch_size: 4 image_size: 1024 grad_accum_steps: 1 effnet_checkpoint_path: models/effnet_encoder.safetensors stage_a_checkpoint_path: models/stage_a.safetensors generator_checkpoint_path: models/stage_b_bf16.safetensors """ class StageB(nn.Module): def __init__( self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=["CT", "CT", "CTA", "CTA"], 
c_clip=1280, c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.1, 0.1], self_attn=True, t_conds=["sca"], ): super().__init__() self.c_r = c_r self.t_conds = t_conds self.c_clip_seq = c_clip_seq if not isinstance(dropout, list): dropout = [dropout] * len(c_hidden) if not isinstance(self_attn, list): self_attn = [self_attn] * len(c_hidden) # CONDITIONING self.effnet_mapper = nn.Sequential( nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1), nn.GELU(), nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), ) self.pixels_mapper = nn.Sequential( nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1), nn.GELU(), nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1), LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), ) self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq) self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) self.embedding = nn.Sequential( nn.PixelUnshuffle(patch_size), nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), ) def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): if block_type == "C": return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) elif block_type == "A": return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) elif block_type == "F": return FeedForwardBlock(c_hidden, dropout=dropout) elif block_type == "T": return TimestepBlock(c_hidden, c_r, conds=t_conds) else: raise Exception(f"Block type {block_type} not supported") # BLOCKS # -- down blocks self.down_blocks = nn.ModuleList() self.down_downscalers = nn.ModuleList() self.down_repeat_mappers = nn.ModuleList() for i in range(len(c_hidden)): if i > 0: self.down_downscalers.append( nn.Sequential( LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2), ) 
) else: self.down_downscalers.append(nn.Identity()) down_block = nn.ModuleList() for _ in range(blocks[0][i]): for block_type in level_config[i]: block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) down_block.append(block) self.down_blocks.append(down_block) if block_repeat is not None: block_repeat_mappers = nn.ModuleList() for _ in range(block_repeat[0][i] - 1): block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) self.down_repeat_mappers.append(block_repeat_mappers) # -- up blocks self.up_blocks = nn.ModuleList() self.up_upscalers = nn.ModuleList() self.up_repeat_mappers = nn.ModuleList() for i in reversed(range(len(c_hidden))): if i > 0: self.up_upscalers.append( nn.Sequential( LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2), ) ) else: self.up_upscalers.append(nn.Identity()) up_block = nn.ModuleList() for j in range(blocks[1][::-1][i]): for k, block_type in enumerate(level_config[i]): c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]) up_block.append(block) self.up_blocks.append(up_block) if block_repeat is not None: block_repeat_mappers = nn.ModuleList() for _ in range(block_repeat[1][::-1][i] - 1): block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) self.up_repeat_mappers.append(block_repeat_mappers) # OUTPUT self.clf = nn.Sequential( LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1), nn.PixelShuffle(patch_size), ) # --- WEIGHT INIT --- self.apply(self._init_weights) # General init nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) 
# conditionings nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs nn.init.constant_(self.clf[1].weight, 0) # outputs # blocks for level_block in self.down_blocks + self.up_blocks: for block in level_block: if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) elif isinstance(block, TimestepBlock): for layer in block.modules(): if isinstance(layer, nn.Linear): nn.init.constant_(layer.weight, 0) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 emb = math.log(max_positions) / (half_dim - 1) emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() emb = r[:, None] * emb[None, :] emb = torch.cat([emb.sin(), emb.cos()], dim=1) if self.c_r % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), mode="constant") return emb def gen_c_embeddings(self, clip): if len(clip.shape) == 2: clip = clip.unsqueeze(1) clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) clip = self.clip_norm(clip) return clip def _down_encode(self, x, r_embed, clip): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: x = downscaler(x) for i in range(len(repmap) + 1): for block in down_block: if isinstance(block, ResBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) ): x = block(x) elif isinstance(block, AttnBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) ): x = block(x, clip) elif 
isinstance(block, TimestepBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) ): x = block(x, r_embed) else: x = block(x) if i < len(repmap): x = repmap[i](x) level_outputs.insert(0, x) return level_outputs def _up_decode(self, level_outputs, r_embed, clip): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) for i, (up_block, upscaler, repmap) in enumerate(block_group): for j in range(len(repmap) + 1): for k, block in enumerate(up_block): if isinstance(block, ResBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) ): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True) x = block(x, skip) elif isinstance(block, AttnBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) ): x = block(x, clip) elif isinstance(block, TimestepBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) ): x = block(x, r_embed) else: x = block(x) if j < len(repmap): x = repmap[j](x) x = upscaler(x) return x def forward(self, x, r, effnet, clip, pixels=None, **kwargs): if pixels is None: pixels = x.new_zeros(x.size(0), 3, 8, 8) # Process the conditioning embeddings r_embed = self.gen_r_embedding(r) for c in self.t_conds: t_cond = kwargs.get(c, torch.zeros_like(r)) r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) clip = self.gen_c_embeddings(clip) # Model Blocks x = self.embedding(x) x = x + self.effnet_mapper( nn.functional.interpolate(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True) ) x = x + nn.functional.interpolate( self.pixels_mapper(pixels).float(), size=x.shape[-2:], mode="bilinear", align_corners=True ) 
level_outputs = self._down_encode(x, r_embed, clip) x = self._up_decode(level_outputs, r_embed, clip) return self.clf(x) def update_weights_ema(self, src_model, beta=0.999): for self_params, src_params in zip(self.parameters(), src_model.parameters()): self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) r""" https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_c_3b.yaml # GLOBAL STUFF model_version: 3.6B dtype: bfloat16 effnet_checkpoint_path: models/effnet_encoder.safetensors previewer_checkpoint_path: models/previewer.safetensors generator_checkpoint_path: models/stage_c_bf16.safetensors """ class StageC(nn.Module): def __init__( self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=["CTA", "CTA"], c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, dropout=[0.1, 0.1], self_attn=True, t_conds=["sca", "crp"], switch_level=[False], ): super().__init__() self.c_r = c_r self.t_conds = t_conds self.c_clip_seq = c_clip_seq if not isinstance(dropout, list): dropout = [dropout] * len(c_hidden) if not isinstance(self_attn, list): self_attn = [self_attn] * len(c_hidden) # CONDITIONING self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond) self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq) self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq) self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6) self.embedding = nn.Sequential( nn.PixelUnshuffle(patch_size), nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1), LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), ) def 
get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): if block_type == "C": return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout) elif block_type == "A": return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout) elif block_type == "F": return FeedForwardBlock(c_hidden, dropout=dropout) elif block_type == "T": return TimestepBlock(c_hidden, c_r, conds=t_conds) else: raise Exception(f"Block type {block_type} not supported") # BLOCKS # -- down blocks self.down_blocks = nn.ModuleList() self.down_downscalers = nn.ModuleList() self.down_repeat_mappers = nn.ModuleList() for i in range(len(c_hidden)): if i > 0: self.down_downscalers.append( nn.Sequential( LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode="down", enabled=switch_level[i - 1]), ) ) else: self.down_downscalers.append(nn.Identity()) down_block = nn.ModuleList() for _ in range(blocks[0][i]): for block_type in level_config[i]: block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) down_block.append(block) self.down_blocks.append(down_block) if block_repeat is not None: block_repeat_mappers = nn.ModuleList() for _ in range(block_repeat[0][i] - 1): block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) self.down_repeat_mappers.append(block_repeat_mappers) # -- up blocks self.up_blocks = nn.ModuleList() self.up_upscalers = nn.ModuleList() self.up_repeat_mappers = nn.ModuleList() for i in reversed(range(len(c_hidden))): if i > 0: self.up_upscalers.append( nn.Sequential( LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6), UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode="up", enabled=switch_level[i - 1]), ) ) else: self.up_upscalers.append(nn.Identity()) up_block = nn.ModuleList() for j in range(blocks[1][::-1][i]): for k, block_type in enumerate(level_config[i]): c_skip = c_hidden[i] if i < len(c_hidden) - 1 and 
j == k == 0 else 0 block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]) up_block.append(block) self.up_blocks.append(up_block) if block_repeat is not None: block_repeat_mappers = nn.ModuleList() for _ in range(block_repeat[1][::-1][i] - 1): block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1)) self.up_repeat_mappers.append(block_repeat_mappers) # OUTPUT self.clf = nn.Sequential( LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6), nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1), nn.PixelShuffle(patch_size), ) # --- WEIGHT INIT --- self.apply(self._init_weights) # General init nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs nn.init.constant_(self.clf[1].weight, 0) # outputs # blocks for level_block in self.down_blocks + self.up_blocks: for block in level_block: if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) elif isinstance(block, TimestepBlock): for layer in block.modules(): if isinstance(layer, nn.Linear): nn.init.constant_(layer.weight, 0) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): torch.nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) def set_gradient_checkpointing(self, value): for block in self.down_blocks + self.up_blocks: for layer in block: if hasattr(layer, "set_gradient_checkpointing"): layer.set_gradient_checkpointing(value) def gen_r_embedding(self, r, max_positions=10000): r = r * max_positions half_dim = self.c_r // 2 emb = math.log(max_positions) / (half_dim - 1) emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() emb = r[:, None] * 
emb[None, :] emb = torch.cat([emb.sin(), emb.cos()], dim=1) if self.c_r % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), mode="constant") return emb def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): clip_txt = self.clip_txt_mapper(clip_txt) if len(clip_txt_pooled.shape) == 2: clip_txt_pool = clip_txt_pooled.unsqueeze(1) if len(clip_img.shape) == 2: clip_img = clip_img.unsqueeze(1) clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1 ) clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) clip = self.clip_norm(clip) return clip def _down_encode(self, x, r_embed, clip, cnet=None): level_outputs = [] block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) for down_block, downscaler, repmap in block_group: x = downscaler(x) for i in range(len(repmap) + 1): for block in down_block: if isinstance(block, ResBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) ): if cnet is not None: next_cnet = cnet() if next_cnet is not None: x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True) x = block(x) elif isinstance(block, AttnBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) ): x = block(x, clip) elif isinstance(block, TimestepBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) ): x = block(x, r_embed) else: x = block(x) if i < len(repmap): x = repmap[i](x) level_outputs.insert(0, x) return level_outputs def _up_decode(self, level_outputs, r_embed, clip, cnet=None): x = level_outputs[0] block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) now_factor = 1 for i, (up_block, upscaler, repmap) in 
enumerate(block_group): for j in range(len(repmap) + 1): for k, block in enumerate(up_block): # if getattr(block, "factor", 1) > 1: # now_factor = -getattr(block, "factor", 1) # scale = check_scale(x) # if scale > 5 or (now_factor < 0 and scale > (5/-now_factor)): #print('='*55) #print(f"in: {i} {j} {k}") #print("up", scale) if isinstance(block, ResBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock) ): skip = level_outputs[i] if k == 0 and i > 0 else None if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True) if cnet is not None: next_cnet = cnet() if next_cnet is not None: x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True) x = block(x, skip) # if now_factor > 1 and block.factor == 1: # block.set_factor(now_factor) elif isinstance(block, AttnBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock) ): x = block(x, clip) # if now_factor > 1 and block.factor == 1: # block.set_factor(now_factor) elif isinstance(block, TimestepBlock) or ( hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock) ): x = block(x, r_embed) # scale = check_scale(x) # if now_factor > 1 and block.factor == 1: # block.set_factor(now_factor, now_factor) # pass # elif i==1: # now_factor = 5 # block.set_factor(now_factor, 1) else: x = block(x) # scale = check_scale(x) # if scale > 5 or (now_factor < 0 and scale > (5/-now_factor)): #print(f"out: {i} {j} {k}", '='*50) #print("up", scale) #print(block.__class__.__name__, torch.sum(torch.isnan(x))) if j < len(repmap): x = repmap[j](x) #print('-- pre upscaler ---') #print(check_scale(x)) x = upscaler(x) #print('-- post upscaler ---') #print(check_scale(x)) # if now_factor > 1: # if isinstance(upscaler, UpDownBlock2d): # 
upscaler.blocks[1].weight.data /= now_factor # upscaler.blocks[1].bias.data /= now_factor # scale = check_scale(x) # if scale > 5: #print('='*50) #print("upscaler", check_scale(x)) return x def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): # Process the conditioning embeddings r_embed = self.gen_r_embedding(r) for c in self.t_conds: t_cond = kwargs.get(c, torch.zeros_like(r)) r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1) clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) # Model Blocks x = self.embedding(x) #print(check_scale(x)) # ControlNet is not supported yet # if cnet is not None: # cnet = ControlNetDeliverer(cnet) level_outputs = self._down_encode(x, r_embed, clip, cnet) x1 = self._up_decode(level_outputs, r_embed, clip, cnet) result1 = self.clf(x1) return result1 # self.half() sd = self.state_dict() # x2 = self._up_decode(level_outputs, r_embed, clip, cnet) # result2 = self.clf(x2) #print(torch.nn.functional.mse_loss(result1, result2)) from safetensors.torch import save_file save_file(sd, 'factor5_pass4.safetensors') raise Exception("Early Stop") def update_weights_ema(self, src_model, beta=0.999): for self_params, src_params in zip(self.parameters(), src_model.parameters()): self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) @property def device(self): return next(self.parameters()).device @property def dtype(self): return next(self.parameters()).dtype # Fast Decoder for Stage C latents. E.g. 
# 16 x 24 x 24 -> 3 x 192 x 192
class Previewer(nn.Module):
    """Fast decoder that turns Stage C latents into a small RGB preview.

    Three stride-2 transposed convolutions upsample the spatial size by 8x,
    e.g. 16 x 24 x 24 -> 3 x 192 x 192 with the default channel sizes.

    Args:
        c_in: number of input latent channels.
        c_hidden: width of the first hidden stage; later stages use
            ``c_hidden // 2`` and ``c_hidden // 4``.
        c_out: number of output image channels.
    """

    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        # NOTE: the layer order determines the state_dict keys ("blocks.<idx>.*"); do not reorder.
        self.blocks = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),
            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),
            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),
            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),
            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),
            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
        )

    def forward(self, x):
        return self.blocks(x)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype


def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
    """Encode text with a CLIP-style text model (deprecated).

    Args:
        captions: caption strings to tokenize; if None, ``input_ids`` is used instead.
        input_ids: pre-tokenized ids, used only when ``captions`` is None.
        tokenizer: used only when ``captions`` is given; called with max-length padding to
            ``tokenizer.model_max_length`` and truncation.
        text_model: called with ``output_hidden_states=True``; its output must expose
            ``hidden_states`` and ``text_embeds``.

    Returns:
        ``(text_embeddings, text_pooled_embeddings)``: the last hidden state, and
        ``text_embeds`` with a new axis inserted at dim 1.
    """
    # Signature of the original implementation, kept for reference:
    # self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
    # Handling is_eval here is awkward, so it is done elsewhere.
    # Handling is_unconditional here is also awkward, so it is done elsewhere.
    # clip_image is not supported for now.
    if captions is not None:
        clip_tokens_unpooled = tokenizer(
            captions, truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
        ).to(text_model.device)
        text_encoder_output = text_model(**clip_tokens_unpooled, output_hidden_states=True)
    else:
        text_encoder_output = text_model(input_ids, output_hidden_states=True)

    text_embeddings = text_encoder_output.hidden_states[-1]
    text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)

    return text_embeddings, text_pooled_embeddings
    # return {"clip_text": text_embeddings, "clip_text_pooled": text_pooled_embeddings}  # , "clip_img": image_embeddings}


# region gdf


class SimpleSampler:
    """Base class for GDF samplers.

    Calling the sampler increments ``current_step`` and forwards all arguments to :meth:`step`,
    which subclasses must override.
    """

    def __init__(self, gdf):
        self.gdf = gdf
        self.current_step = -1

    def __call__(self, *args, **kwargs):
        self.current_step += 1
        return self.step(*args, **kwargs)

    def init_x(self, shape):
        return torch.randn(*shape)

    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
        # The message used to name an 'apply' method that does not exist; 'step' is the override point.
        raise NotImplementedError("You should override the 'step' function.")


class DDIMSampler(SimpleSampler):
    """DDIM update; ``eta`` > 0 injects fresh noise (``eta=1`` corresponds to DDPM)."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
        a, b = self.gdf.input_scaler(logSNR)
        if len(a.shape) == 1:
            # per-sample scalers: reshape to (B, 1, ..., 1) so they broadcast over x0
            a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1))

        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
        if len(a_prev.shape) == 1:
            a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1))

        sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt() if eta > 0 else 0
        # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
        x = a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
        return x


class DDPMSampler(DDIMSampler):
    """DDIM sampler with ``eta=1`` by default."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1):
        return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta)


class LCMSampler(SimpleSampler):
    """Re-noise the predicted x0 with fresh Gaussian noise at the previous noise level."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
        if len(a_prev.shape) == 1:
            a_prev, b_prev = a_prev.view(-1, *[1] * (len(x0.shape) - 1)), b_prev.view(-1, *[1] * (len(x0.shape) - 1))
        return x0 * a_prev + torch.randn_like(epsilon) * b_prev


class GDF:
    """Generic diffusion framework.

    Combines five pluggable parts:
    - a logSNR ``schedule``;
    - an ``input_scaler`` mapping logSNR to ``(a, b)`` with ``noised = a * x0 + b * noise``;
    - a prediction ``target``;
    - a ``noise_cond`` mapping logSNR to the model's timestep conditioning;
    - a ``loss_weight``.
    """

    def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
        self.schedule = schedule
        self.input_scaler = input_scaler
        self.target = target
        self.noise_cond = noise_cond
        self.loss_weight = loss_weight
        self.offset_noise = offset_noise

    def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
        stretched_limits = self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)
        return stretched_limits

    def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
        """Noise ``x0`` for training.

        Returns:
            ``(noised, epsilon, target, logSNR, noise_cond(logSNR), loss_weight)``.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x0)
        if self.offset_noise > 0:
            if offset is None:
                # one offset per (sample, channel), broadcast over the remaining axes
                offset = torch.randn([x0.size(0), x0.size(1)] + [1] * (len(x0.shape) - 2)).to(x0.device)
            epsilon = epsilon + offset * self.offset_noise
        # an int asks the schedule to sample that many random timesteps
        logSNR = self.schedule(x0.size(0) if t is None else t, shift=shift).to(x0.device)
        a, b = self.input_scaler(logSNR)  # B
        if len(a.shape) == 1:
            a, b = a.view(-1, *[1] * (len(x0.shape) - 1)), b.view(-1, *[1] * (len(x0.shape) - 1))  # BxCxHxW
        target = self.target(x0, epsilon, logSNR, a, b)

        # noised, noise, logSNR, t_cond
        return x0 * a + epsilon * b, epsilon, target, logSNR, self.noise_cond(logSNR), self.loss_weight(logSNR, shift=loss_shift)

    def undiffuse(self, x, logSNR, pred):
        """Recover ``(x0, epsilon)`` from a noised sample and the model prediction."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            a, b = a.view(-1, *[1] * (len(x.shape) - 1)), b.view(-1, *[1] * (len(x.shape) - 1))
        return self.target.x0(x, pred, logSNR, a, b), self.target.epsilon(x, pred, logSNR, a, b)

    @staticmethod
    def _concat_with_unconditional(v, v_u):
        """Concatenate a conditional model input with its unconditional counterpart along dim 0.

        - Tensors are concatenated directly.
        - Lists are concatenated element-wise; pairs that are not both tensors become None.
        - Dicts are concatenated key-wise; keys missing from the unconditional dict are zero-filled.
        - Any other value becomes None.
        """
        if isinstance(v, torch.Tensor):
            return torch.cat([v, v_u], dim=0)
        if isinstance(v, list):
            return [
                torch.cat([vi, vi_u], dim=0) if isinstance(vi, torch.Tensor) and isinstance(vi_u, torch.Tensor) else None
                for vi, vi_u in zip(v, v_u)
            ]
        if isinstance(v, dict):
            return {vk: torch.cat([v[vk], v_u.get(vk, torch.zeros_like(v[vk]))], dim=0) for vk in v}
        return None

    def sample(
        self,
        model,
        model_inputs,
        shape,
        unconditional_inputs=None,
        sampler=None,
        schedule=None,
        t_start=1.0,
        t_end=0.0,
        timesteps=20,
        x_init=None,
        cfg=3.0,
        cfg_t_stop=None,
        cfg_t_start=None,
        cfg_rho=0.7,
        sampler_params=None,
        shift=1,
        device="cpu",
    ):
        """Generator running the reverse diffusion loop; yields ``(x0, x, pred)`` once per step.

        A dict sent into the generator may override ``cfg``, ``cfg_rho``, ``sampler``,
        ``model_inputs``, ``x`` and ``x_init`` for the following steps.
        """
        sampler_params = {} if sampler_params is None else sampler_params
        if sampler is None:
            sampler = DDPMSampler(self)
        r_range = torch.linspace(t_start, t_end, timesteps + 1)
        schedule = self.schedule if schedule is None else schedule
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()

        # With CFG, model_inputs are concatenated with the unconditional inputs along dim 0,
        # so every later model call must be fed a doubled batch.
        inputs_doubled = cfg is not None
        if cfg is not None:
            if unconditional_inputs is None:
                unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
            model_inputs = {
                k: self._concat_with_unconditional(v, v_u)
                for (k, v), (_, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
            }

        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            in_cfg_window = (
                cfg is not None
                and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop)
                and (cfg_t_start is None or r_range[i].item() <= cfg_t_start)
            )
            if in_cfg_window:
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    # linear interpolation between the two cfg values along the trajectory
                    cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1 - r_range[i].item())
                pred, pred_unconditional = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)
                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    # rescale the guided prediction towards the std of the conditional one
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos / (std_cfg + 1e-9)) + pred_cfg * (1 - cfg_rho)
                else:
                    pred = pred_cfg
            elif inputs_doubled:
                # Outside the CFG window the inputs are still doubled: calling model(x, ...) with a
                # single batch would mismatch them, so run the doubled batch and keep the conditional half.
                pred = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs).chunk(2)[0]
            else:
                pred = model(x, noise_cond, **model_inputs)

            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i + 1], **sampler_params)
            altered_vars = yield (x0, x, pred)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get("cfg", cfg)
                cfg_rho = altered_vars.get("cfg_rho", cfg_rho)
                sampler = altered_vars.get("sampler", sampler)
                model_inputs = altered_vars.get("model_inputs", model_inputs)
                x = altered_vars.get("x", x)
                x_init = altered_vars.get("x_init", x_init)
class BaseSchedule:
    """Base class for logSNR schedules.

    Subclasses implement ``setup`` (receives extra constructor args) and ``schedule(t, batch_size)``.
    Calling an instance with a tensor ``t`` evaluates the schedule at those timesteps. Calling it with
    an int passes ``t=None`` and ``batch_size=<int>`` to ``schedule`` so that it samples timesteps.
    """

    def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs):
        self.setup(*args, **kwargs)
        self.limits = None
        self.discrete_steps = discrete_steps
        self.shift = shift
        if force_limits:
            self.reset_limits()

    def reset_limits(self, shift=1, disable=False):
        """Recompute ``self.limits`` as ``[logSNR(t=1), logSNR(t=0)]`` (min, max).

        Returns the new limits, or None when disabled or when the schedule cannot be evaluated
        (the schedule is then left unbounded).
        """
        # Clear first: __call__ clamps to self.limits, so stale limits (e.g. from another shift)
        # would otherwise clamp the recomputed values.
        self.limits = None
        if disable:
            return None
        try:
            self.limits = self(torch.tensor([1.0, 0.0]), shift=shift).tolist()  # min, max
        except Exception:
            # Deliberately best-effort: this schedule doesn't support t and will be unbounded.
            return None
        return self.limits

    def setup(self, *args, **kwargs):
        raise NotImplementedError("this method needs to be overridden")

    def schedule(self, *args, **kwargs):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, t, *args, shift=1, **kwargs):
        if isinstance(t, torch.Tensor):
            batch_size = None
            if self.discrete_steps is not None:
                # snap t onto the discrete grid {0, 1/(n-1), ..., 1}
                if t.dtype != torch.long:
                    t = (t * (self.discrete_steps - 1)).round().long()
                t = t / (self.discrete_steps - 1)
            t = t.clamp(0, 1)
        else:
            batch_size = t
            t = None
        logSNR = self.schedule(t, batch_size, *args, **kwargs)
        if shift * self.shift != 1:
            logSNR += 2 * np.log(1 / (shift * self.shift))
        if self.limits is not None:
            logSNR = logSNR.clamp(*self.limits)
        return logSNR


class CosineSchedule(BaseSchedule):
    """Cosine schedule: var(t) = cos((s + t) / (1 + s) * pi / 2)^2 / var(0), logSNR = log(var / (1 - var))."""

    def setup(self, s=0.008, clamp_range=(0.0001, 0.9999), norm_instead=False):
        self.s = torch.tensor([s])
        self.clamp_range = clamp_range
        self.norm_instead = norm_instead
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def schedule(self, t, batch_size):
        if t is None:
            t = (1 - torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0)
        s, min_var = self.s.to(t.device), self.min_var.to(t.device)
        var = torch.cos((s + t) / (1 + s) * torch.pi * 0.5).clamp(0, 1) ** 2 / min_var
        if self.norm_instead:
            # rescale var linearly into clamp_range instead of clipping
            var = var * (self.clamp_range[1] - self.clamp_range[0]) + self.clamp_range[0]
        else:
            var = var.clamp(*self.clamp_range)
        logSNR = (var / (1 - var)).log()
        return logSNR


class BaseScaler:
    """Maps logSNR to the ``(a, b)`` coefficients of ``noised = a * x0 + b * noise``.

    After ``setup_limits``, ``a`` and ``b`` are linearly stretched using the values measured at
    the ends of a schedule.
    """

    def __init__(self):
        self.stretched_limits = None

    def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
        # Drop any previous stretching: otherwise a repeated call measures already-stretched
        # scalers and the limits degenerate to about [0, 1, 0, 1].
        self.stretched_limits = None
        min_logSNR = schedule(torch.ones(1), shift=shift)
        max_logSNR = schedule(torch.zeros(1), shift=shift)

        min_a, max_b = [v.item() for v in input_scaler(min_logSNR)] if stretch_max else [0, 1]
        max_a, min_b = [v.item() for v in input_scaler(max_logSNR)] if stretch_min else [1, 0]
        self.stretched_limits = [min_a, max_a, min_b, max_b]
        return self.stretched_limits

    def stretch_limits(self, a, b):
        min_a, max_a, min_b, max_b = self.stretched_limits
        return (a - min_a) / (max_a - min_a), (b - min_b) / (max_b - min_b)

    def scalers(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        a, b = self.scalers(logSNR)
        if self.stretched_limits is not None:
            a, b = self.stretch_limits(a, b)
        return a, b


class VPScaler(BaseScaler):
    """Variance-preserving scalers: a^2 = sigmoid(logSNR), b^2 = 1 - a^2."""

    def scalers(self, logSNR):
        a_squared = logSNR.sigmoid()
        a = a_squared.sqrt()
        b = (1 - a_squared).sqrt()
        return a, b


class EpsilonTarget:
    """The model predicts the noise epsilon."""

    def __call__(self, x0, epsilon, logSNR, a, b):
        return epsilon

    def x0(self, noised, pred, logSNR, a, b):
        return (noised - pred * b) / a

    def epsilon(self, noised, pred, logSNR, a, b):
        return pred


class BaseNoiseCond:
    """Maps logSNR to the model's noise conditioning, optionally shifted, then clamped to ``clamp_range``."""

    def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
        clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
        self.shift = shift
        self.clamp_range = clamp_range
        self.setup(*args, **kwargs)

    def setup(self, *args, **kwargs):
        pass  # this method is optional, override it if required

    def cond(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        if self.shift != 1:
            logSNR = logSNR.clone() + 2 * np.log(self.shift)
        return self.cond(logSNR).clamp(*self.clamp_range)


class CosineTNoiseCond(BaseNoiseCond):
    """Inverts the cosine schedule: maps logSNR back to the timestep t."""

    def setup(self, s=0.008, clamp_range=(0, 1)):  # [0.0001, 0.9999]
        self.s = torch.tensor([s])
        # NOTE(review): this overwrites the clamp_range set by BaseNoiseCond.__init__, so the
        # constructor's clamp_range argument has no effect here — confirm that is intended.
        self.clamp_range = clamp_range
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def cond(self, logSNR):
        var = logSNR.sigmoid()
        var = var.clamp(*self.clamp_range)
        s, min_var = self.s.to(var.device), self.min_var.to(var.device)
        t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
        return t


# --- Loss Weighting
class BaseLossWeight:
    """Maps logSNR to a per-sample loss weight, optionally shifted, then clamped to ``clamp_range``."""

    def weight(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        clamp_range = [-1e9, 1e9] if clamp_range is None else clamp_range
        if shift != 1:
            logSNR = logSNR.clone() + 2 * np.log(shift)
        return self.weight(logSNR, *args, **kwargs).clamp(*clamp_range)


# class ComposedLossWeight(BaseLossWeight):
#     def __init__(self, div, mul):
#         self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
#         self.div = [div] if isinstance(div, BaseLossWeight) else div

#     def weight(self, logSNR):
#         prod, div = 1, 1
#         for m in self.mul:
#             prod *= m.weight(logSNR)
#         for d in self.div:
#             div *= d.weight(logSNR)
#         return prod/div

# class ConstantLossWeight(BaseLossWeight):
#     def __init__(self, v=1):
#         self.v = v

#     def weight(self, logSNR):
#         return torch.ones_like(logSNR) * self.v

# class SNRLossWeight(BaseLossWeight):
#     def weight(self, logSNR):
#         return logSNR.exp()


class P2LossWeight(BaseLossWeight):
    """P2 weighting: (k + exp(s * logSNR)) ** -gamma."""

    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k, self.gamma, self.s = k, gamma, s

    def weight(self, logSNR):
        return (self.k + (logSNR * self.s).exp()) ** -self.gamma


# class SNRPlusOneLossWeight(BaseLossWeight):
#     def weight(self, logSNR):
#         return logSNR.exp() + 1

# class MinSNRLossWeight(BaseLossWeight):
#     def __init__(self, max_snr=5):
#         self.max_snr = max_snr

#     def weight(self, logSNR):
#         return logSNR.exp().clamp(max=self.max_snr)

# class MinSNRPlusOneLossWeight(BaseLossWeight):
#     def __init__(self, max_snr=5):
#         self.max_snr = max_snr

#     def weight(self, logSNR):
#         return (logSNR.exp() + 1).clamp(max=self.max_snr)

# class TruncatedSNRLossWeight(BaseLossWeight):
#     def __init__(self, min_snr=1):
#         self.min_snr = min_snr

#     def weight(self, logSNR):
#         return logSNR.exp().clamp(min=self.min_snr)

# class SechLossWeight(BaseLossWeight):
#     def __init__(self, div=2):
#         self.div = div

#     def weight(self, logSNR):
#         return 1/(logSNR/self.div).cosh()

# class DebiasedLossWeight(BaseLossWeight):
#     def weight(self, logSNR):
#         return 1/logSNR.exp().sqrt()

# class SigmoidLossWeight(BaseLossWeight):
#     def __init__(self, s=1):
#         self.s = s

#     def weight(self, logSNR):
#         return (logSNR * self.s).sigmoid()


class AdaptiveLossWeight(BaseLossWeight):
    """Weights each logSNR bucket by the inverse of its running (EMA) loss.

    ``buckets - 1`` edges evenly span ``logsnr_range``. ``torch.searchsorted`` maps a logSNR to
    one of ``buckets`` indices, and values outside the range fall into the first or last bucket.
    """

    def __init__(self, logsnr_range=(-10, 10), buckets=300, weight_range=(1e-7, 1e7)):
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, logSNR):
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR)
        return (1 / self.bucket_losses.to(logSNR.device)[indices]).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        indices = torch.searchsorted(self.bucket_ranges.to(logSNR.device), logSNR).cpu()
        # NOTE: with duplicate indices in one batch, only one of the writes takes effect.
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)


# endregion gdf