# pytorch_diffusion + derived encoder decoder import logging import math from typing import Any, Callable, Optional import numpy as np import torch import torch.nn as nn from einops import rearrange from packaging import version logpy = logging.getLogger(__name__) try: import xformers import xformers.ops XFORMERS_IS_AVAILABLE = True except: XFORMERS_IS_AVAILABLE = False logpy.warning("no module 'xformers'. Processing without...") from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm( num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True ) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=1, padding=1 ) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=2, padding=0 ) def forward(self, x): if self.with_conv: pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class ResnetBlock(nn.Module): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: self.nin_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x, temb): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.k = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.v = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) def attention(self, h_: torch.Tensor) -> torch.Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) b, c, h, w = q.shape q, k, v = map( lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) ) h_ = torch.nn.functional.scaled_dot_product_attention( q, k, v ) # scale is dim ** -0.5 per default # compute attention return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) def forward(self, x, **kwargs): h_ = x h_ = self.attention(h_) h_ = self.proj_out(h_) return x + h_ class MemoryEfficientAttnBlock(nn.Module): """ Uses xformers efficient implementation, see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 Note: this is a single-head self-attention operation """ # def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.k = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.v = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.attention_op: Optional[Any] = None def attention(self, h_: torch.Tensor) -> torch.Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention B, C, H, W = q.shape q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) q, k, v = map( lambda t: t.unsqueeze(3) .reshape(B, t.shape[1], 1, C) .permute(0, 2, 1, 3) .reshape(B * 1, t.shape[1], C) .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention( q, k, v, attn_bias=None, op=self.attention_op ) out = ( out.unsqueeze(0) .reshape(B, 1, out.shape[1], C) .permute(0, 2, 1, 3) .reshape(B, out.shape[1], C) ) return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) def forward(self, x, **kwargs): h_ = x h_ = self.attention(h_) h_ = self.proj_out(h_) return x + h_ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): def forward(self, x, context=None, mask=None, **unused_kwargs): b, c, h, w = x.shape x = rearrange(x, "b c h w -> b (h w) c") out = super().forward(x, context=context, mask=mask) out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) return x + out def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in [ "vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none", ], f"attn_type {attn_type} unknown" if ( version.parse(torch.__version__) < version.parse("2.0.0") and attn_type != "none" ): assert XFORMERS_IS_AVAILABLE, ( f"We do not support vanilla attention in {torch.__version__} anymore, " f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" ) attn_type = "vanilla-xformers" logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": assert attn_kwargs is None return AttnBlock(in_channels) elif attn_type == "vanilla-xformers": logpy.info( f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." ) return MemoryEfficientAttnBlock(in_channels) elif type == "memory-efficient-cross-attn": attn_kwargs["query_dim"] = in_channels return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) elif attn_type == "none": return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Model(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla", ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = self.ch * 4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.use_timestep = use_timestep if self.use_timestep: # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList( [ torch.nn.Linear(self.ch, self.temb_ch), torch.nn.Linear(self.temb_ch, self.temb_ch), ] ) # downsampling self.conv_in = torch.nn.Conv2d( in_channels, self.ch, kernel_size=3, stride=1, padding=1 ) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] skip_in = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): if i_block == self.num_res_blocks: skip_in = ch * in_ch_mult[i_level] block.append( ResnetBlock( in_channels=block_in + skip_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d( block_in, out_ch, kernel_size=3, stride=1, padding=1 ) def forward(self, x, t=None, context=None): # assert x.shape[2] == x.shape[3] == self.resolution if context is not None: # assume aligned context, cat along channel axis x = torch.cat((x, context), dim=1) if self.use_timestep: # timestep embedding assert t is not None temb = get_timestep_embedding(t, self.ch) temb = self.temb.dense[0](temb) temb = nonlinearity(temb) temb = self.temb.dense[1](temb) else: temb = None # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block]( torch.cat([h, hs.pop()], dim=1), temb ) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h def get_last_layer(self): return self.conv_out.weight class Encoder(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", **ignore_kwargs, ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels # downsampling self.conv_in = torch.nn.Conv2d( in_channels, self.ch, kernel_size=3, stride=1, padding=1 ) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d( block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1, ) def forward(self, x): # timestep embedding temb = None # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, attn_type="vanilla", **ignorekwargs, ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.give_pre_end = give_pre_end self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) logpy.info( "Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape) ) ) make_attn_cls = self._make_attn() make_resblock_cls = self._make_resblock() make_conv_cls = self._make_conv() # z to block_in self.conv_in = torch.nn.Conv2d( z_channels, block_in, kernel_size=3, stride=1, padding=1 ) # middle self.mid = nn.Module() self.mid.block_1 = make_resblock_cls( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) self.mid.block_2 = make_resblock_cls( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( make_resblock_cls( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn_cls(block_in, attn_type=attn_type)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = make_conv_cls( block_in, out_ch, kernel_size=3, stride=1, padding=1 ) def _make_attn(self) -> Callable: return make_attn def _make_resblock(self) -> Callable: return ResnetBlock def _make_conv(self) -> Callable: return torch.nn.Conv2d def get_last_layer(self, **kwargs): return self.conv_out.weight def forward(self, z, **kwargs): # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h = self.conv_in(z) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb, **kwargs) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h, **kwargs) if i_level != 0: h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h, **kwargs) if self.tanh_out: h = torch.tanh(h) return h