Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							·
						
						0c0c966
	
1
								Parent(s):
							
							f8ded7c
								
Upload with huggingface_hub
Browse files
- cldm/cldm.py +417 -0
- cldm/model.py +21 -0
- ldm/models/autoencoder.py +219 -0
- ldm/models/diffusion/ddim.py +336 -0
- ldm/modules/attention.py +341 -0
- ldm/modules/diffusionmodules/model.py +852 -0
- ldm/modules/diffusionmodules/openaimodel.py +786 -0
- ldm/modules/diffusionmodules/upscaling.py +81 -0
- ldm/modules/diffusionmodules/util.py +270 -0
- ldm/modules/ema.py +80 -0
- ldm/util.py +197 -0
- models/cldm_v15.yaml +79 -0
    	
        cldm/cldm.py
    ADDED
    
    | @@ -0,0 +1,417 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import einops
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import torch as th
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from ldm.modules.diffusionmodules.util import (
         | 
| 7 | 
            +
                conv_nd,
         | 
| 8 | 
            +
                linear,
         | 
| 9 | 
            +
                zero_module,
         | 
| 10 | 
            +
                timestep_embedding,
         | 
| 11 | 
            +
            )
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from einops import rearrange, repeat
         | 
| 14 | 
            +
            from torchvision.utils import make_grid
         | 
| 15 | 
            +
            from ldm.modules.attention import SpatialTransformer
         | 
| 16 | 
            +
            from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
         | 
| 17 | 
            +
            from ldm.models.diffusion.ddpm import LatentDiffusion
         | 
| 18 | 
            +
            from ldm.util import log_txt_as_img, exists, instantiate_from_config
         | 
| 19 | 
            +
            from ldm.models.diffusion.ddim import DDIMSampler
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            class ControlledUnetModel(UNetModel):
                """UNet variant whose decoder path accepts additive control residuals.

                The input blocks and the middle block are evaluated under
                ``torch.no_grad()``; only the output blocks run with gradient
                tracking enabled.
                """

                def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
                    """Run the UNet, optionally adding control residuals.

                    Args:
                        x: Input tensor; cast to ``self.dtype`` before the input blocks
                            and the result is cast back to ``x.dtype`` at the end.
                        timesteps: Passed to ``timestep_embedding`` to build the time
                            embedding.
                        context: Forwarded unchanged to every block.
                        control: Optional sequence of residual tensors. It is consumed
                            from the END: the last element is added to the middle
                            block output, the remaining ones are added (in reverse
                            order) to the skip connections of the output blocks.
                            ``None`` disables control entirely.
                        only_mid_control: When true, only the middle-block residual is
                            applied and the skip connections are left untouched.
                        **kwargs: Accepted for call compatibility; not used here.

                    Returns:
                        The result of ``self.out`` applied to the final hidden state.
                    """
                    # Shallow copy: pop() below must not empty the caller's list, and a
                    # missing control (the declared default) must not crash.
                    residuals = list(control) if control is not None else None

                    hs = []
                    with torch.no_grad():
                        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
                        emb = self.time_embed(t_emb)
                        h = x.type(self.dtype)
                        for module in self.input_blocks:
                            h = module(h, emb, context)
                            hs.append(h)
                        h = self.middle_block(h, emb, context)

                    if residuals is not None:
                        h += residuals.pop()

                    # Skip connections only receive a residual when control is present
                    # and not restricted to the middle block.
                    add_to_skips = residuals is not None and not only_mid_control
                    for module in self.output_blocks:
                        skip = hs.pop()
                        if add_to_skips:
                            skip = skip + residuals.pop()
                        h = torch.cat([h, skip], dim=1)
                        h = module(h, emb, context)

                    h = h.type(x.dtype)
                    return self.out(h)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            class ControlNet(nn.Module):
                """Encoder network that turns a hint image into per-block control tensors.

                ``forward`` returns a list with one tensor per entry of
                ``self.input_blocks`` (each passed through its paired entry of
                ``self.zero_convs``) followed by one tensor for the middle block
                (passed through ``self.middle_block_out``).
                """

                def __init__(
                    self,
                    image_size,
                    in_channels,
                    model_channels,
                    hint_channels,
                    num_res_blocks,
                    attention_resolutions,
                    dropout=0,
                    channel_mult=(1, 2, 4, 8),
                    conv_resample=True,
                    dims=2,
                    use_checkpoint=False,
                    use_fp16=False,
                    num_heads=-1,
                    num_head_channels=-1,
                    num_heads_upsample=-1,
                    use_scale_shift_norm=False,
                    resblock_updown=False,
                    use_new_attention_order=False,
                    use_spatial_transformer=False,    # custom transformer support
                    transformer_depth=1,              # custom transformer support
                    context_dim=None,                 # custom transformer support
                    n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
                    legacy=True,
                    disable_self_attentions=None,
                    num_attention_blocks=None,
                    disable_middle_self_attn=False,
                    use_linear_in_transformer=False,
                ):
                    """Build the hint encoder, the input blocks, the middle block and the zero convs.

                    Args:
                        image_size: Stored on the instance; not otherwise used here.
                        in_channels: Channel count of ``x`` fed to the first input block.
                        model_channels: Base channel width; also the time-embedding input size.
                        hint_channels: Channel count of the hint fed to ``input_hint_block``.
                        num_res_blocks: Either an int (same count on every level) or a
                            per-level sequence with the same length as ``channel_mult``.
                        attention_resolutions: Downsample factors (``ds``) at which an
                            attention layer is appended after a ResBlock.
                        channel_mult: Per-level multiplier applied to ``model_channels``.
                        num_heads / num_head_channels: At least one must differ from -1.
                        use_spatial_transformer / context_dim: Must be set together.
                        disable_self_attentions: Optional per-level booleans.
                        num_attention_blocks: Optional per-level cap on attention layers.
                        Remaining arguments are forwarded to the layer constructors.

                    Raises:
                        AssertionError: On inconsistent attention/transformer arguments.
                        ValueError: If ``num_res_blocks`` is a sequence of the wrong length.
                    """
                    super().__init__()
                    if use_spatial_transformer:
                        assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

                    if context_dim is not None:
                        assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
                        from omegaconf.listconfig import ListConfig
                        # isinstance (rather than an exact type comparison) is the
                        # idiomatic check and also accepts ListConfig subclasses.
                        if isinstance(context_dim, ListConfig):
                            context_dim = list(context_dim)

                    if num_heads_upsample == -1:
                        num_heads_upsample = num_heads

                    if num_heads == -1:
                        assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

                    if num_head_channels == -1:
                        assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

                    self.dims = dims
                    self.image_size = image_size
                    self.in_channels = in_channels
                    self.model_channels = model_channels
                    if isinstance(num_res_blocks, int):
                        self.num_res_blocks = len(channel_mult) * [num_res_blocks]
                    else:
                        if len(num_res_blocks) != len(channel_mult):
                            raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                             "as a list/tuple (per-level) with the same length as channel_mult")
                        self.num_res_blocks = num_res_blocks
                    if disable_self_attentions is not None:
                        # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
                        assert len(disable_self_attentions) == len(channel_mult)
                    if num_attention_blocks is not None:
                        assert len(num_attention_blocks) == len(self.num_res_blocks)
                        assert all(
                            n_res >= n_attn
                            for n_res, n_attn in zip(self.num_res_blocks, num_attention_blocks)
                        )
                        # The message previously named UNetModel (copy-paste leftover);
                        # it is emitted by this class.
                        print(f"Constructor of ControlNet received num_attention_blocks={num_attention_blocks}. "
                              f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                              f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                              f"attention will still not be set.")

                    self.attention_resolutions = attention_resolutions
                    self.dropout = dropout
                    self.channel_mult = channel_mult
                    self.conv_resample = conv_resample
                    self.use_checkpoint = use_checkpoint
                    self.dtype = th.float16 if use_fp16 else th.float32
                    self.num_heads = num_heads
                    self.num_head_channels = num_head_channels
                    self.num_heads_upsample = num_heads_upsample
                    self.predict_codebook_ids = n_embed is not None

                    time_embed_dim = model_channels * 4
                    self.time_embed = nn.Sequential(
                        linear(model_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )

                    self.input_blocks = nn.ModuleList(
                        [
                            TimestepEmbedSequential(
                                conv_nd(dims, in_channels, model_channels, 3, padding=1)
                            )
                        ]
                    )
                    # zero_convs is kept index-aligned with input_blocks: every append
                    # to one is matched by an append to the other.
                    self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

                    # Hint encoder: eight convolutions, three of them with stride=2,
                    # ending in a zero_module-wrapped conv to model_channels.
                    self.input_hint_block = TimestepEmbedSequential(
                        conv_nd(dims, hint_channels, 16, 3, padding=1),
                        nn.SiLU(),
                        conv_nd(dims, 16, 16, 3, padding=1),
                        nn.SiLU(),
                        conv_nd(dims, 16, 32, 3, padding=1, stride=2),
                        nn.SiLU(),
                        conv_nd(dims, 32, 32, 3, padding=1),
                        nn.SiLU(),
                        conv_nd(dims, 32, 96, 3, padding=1, stride=2),
                        nn.SiLU(),
                        conv_nd(dims, 96, 96, 3, padding=1),
                        nn.SiLU(),
                        conv_nd(dims, 96, 256, 3, padding=1, stride=2),
                        nn.SiLU(),
                        zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
                    )

                    self._feature_size = model_channels
                    input_block_chans = [model_channels]
                    ch = model_channels
                    ds = 1  # current downsample factor, doubled after each downsampling block
                    for level, mult in enumerate(channel_mult):
                        for nr in range(self.num_res_blocks[level]):
                            layers = [
                                ResBlock(
                                    ch,
                                    time_embed_dim,
                                    dropout,
                                    out_channels=mult * model_channels,
                                    dims=dims,
                                    use_checkpoint=use_checkpoint,
                                    use_scale_shift_norm=use_scale_shift_norm,
                                )
                            ]
                            ch = mult * model_channels
                            if ds in attention_resolutions:
                                # NOTE: num_heads is rebound here when num_head_channels is
                                # set, and the new value carries over to later iterations.
                                if num_head_channels == -1:
                                    dim_head = ch // num_heads
                                else:
                                    num_heads = ch // num_head_channels
                                    dim_head = num_head_channels
                                if legacy:
                                    #num_heads = 1
                                    dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                                if exists(disable_self_attentions):
                                    disabled_sa = disable_self_attentions[level]
                                else:
                                    disabled_sa = False

                                if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                                    layers.append(
                                        AttentionBlock(
                                            ch,
                                            use_checkpoint=use_checkpoint,
                                            num_heads=num_heads,
                                            num_head_channels=dim_head,
                                            use_new_attention_order=use_new_attention_order,
                                        ) if not use_spatial_transformer else SpatialTransformer(
                                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                            use_checkpoint=use_checkpoint
                                        )
                                    )
                            self.input_blocks.append(TimestepEmbedSequential(*layers))
                            self.zero_convs.append(self.make_zero_conv(ch))
                            self._feature_size += ch
                            input_block_chans.append(ch)
                        if level != len(channel_mult) - 1:
                            # Every level except the last ends with a downsampling block.
                            out_ch = ch
                            self.input_blocks.append(
                                TimestepEmbedSequential(
                                    ResBlock(
                                        ch,
                                        time_embed_dim,
                                        dropout,
                                        out_channels=out_ch,
                                        dims=dims,
                                        use_checkpoint=use_checkpoint,
                                        use_scale_shift_norm=use_scale_shift_norm,
                                        down=True,
                                    )
                                    if resblock_updown
                                    else Downsample(
                                        ch, conv_resample, dims=dims, out_channels=out_ch
                                    )
                                )
                            )
                            ch = out_ch
                            input_block_chans.append(ch)
                            self.zero_convs.append(self.make_zero_conv(ch))
                            ds *= 2
                            self._feature_size += ch

                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    self.middle_block = TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        ),
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                            disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                            use_checkpoint=use_checkpoint
                        ),
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                        ),
                    )
                    self.middle_block_out = self.make_zero_conv(ch)
                    self._feature_size += ch

                def make_zero_conv(self, channels):
                    """Return a kernel-size-1 conv (channels -> channels) wrapped in ``zero_module``.

                    NOTE(review): ``zero_module`` presumably zero-initialises the conv's
                    parameters; confirm against ldm.modules.diffusionmodules.util.
                    """
                    return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

                def forward(self, x, hint, timesteps, context, **kwargs):
                    """Compute the control tensors for one step.

                    Args:
                        x: Input tensor, cast to ``self.dtype`` before the input blocks.
                        hint: Conditioning input passed through ``self.input_hint_block``.
                        timesteps: Passed to ``timestep_embedding``.
                        context: Forwarded unchanged to every block.
                        **kwargs: Accepted for call compatibility; not used here.

                    Returns:
                        list: one zero-conv output per input block, then the zero-conv
                        output of the middle block as the last element.
                    """
                    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
                    emb = self.time_embed(t_emb)

                    guided_hint = self.input_hint_block(hint, emb, context)

                    outs = []

                    h = x.type(self.dtype)
                    for module, zero_conv in zip(self.input_blocks, self.zero_convs):
                        h = module(h, emb, context)
                        if guided_hint is not None:
                            # The encoded hint is injected once, right after the first
                            # input block, then cleared so later blocks skip this branch.
                            h += guided_hint
                            guided_hint = None
                        outs.append(zero_conv(h, emb, context))

                    h = self.middle_block(h, emb, context)
                    outs.append(self.middle_block_out(h, emb, context))

                    return outs
         | 
| 305 | 
            +
             | 
| 306 | 
            +
             | 
| 307 | 
            +
            class ControlLDM(LatentDiffusion):
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
                    """Initialise the base model, then attach the control network.

                    Args:
                        control_stage_config: Passed to ``instantiate_from_config`` to
                            build ``self.control_model``.
                        control_key: Key used by ``get_input`` to read the control
                            tensor out of a batch.
                        only_mid_control: Forwarded to the diffusion model on every
                            ``apply_model`` call.
                        *args, **kwargs: Forwarded to the base-class constructor.
                    """
                    super().__init__(*args, **kwargs)
                    # Plain configuration values first, then the control network built
                    # from its config.
                    self.only_mid_control = only_mid_control
                    self.control_key = control_key
                    self.control_model = instantiate_from_config(control_stage_config)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                @torch.no_grad()
                def get_input(self, batch, k, bs=None, *args, **kwargs):
                    """Fetch the base model's inputs and add the control hint.

                    Args:
                        batch: Mapping that must contain ``self.control_key``.
                        k: Accepted for signature compatibility; the base call always
                            uses ``self.first_stage_key`` instead.
                        bs: If given, the control tensor is truncated to its first
                            ``bs`` entries.
                        *args, **kwargs: Forwarded to the base-class ``get_input``.

                    Returns:
                        tuple: ``(x, cond)`` where ``cond`` is a dict with
                        ``c_crossattn`` (one-element list holding the base conditioning)
                        and ``c_concat`` (one-element list holding the control tensor
                        rearranged from ``b h w c`` to ``b c h w`` as float).
                    """
                    # NOTE(review): bs is not forwarded to the base call, so only the
                    # hint is truncated here - confirm callers slice x / c themselves.
                    latents, text_cond = super().get_input(batch, self.first_stage_key, *args, **kwargs)

                    hint = batch[self.control_key]
                    if bs is not None:
                        hint = hint[:bs]
                    hint = einops.rearrange(hint.to(self.device), 'b h w c -> b c h w')
                    hint = hint.to(memory_format=torch.contiguous_format).float()

                    cond = {"c_crossattn": [text_cond], "c_concat": [hint]}
                    return latents, cond
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                def apply_model(self, x_noisy, t, cond, *args, **kwargs):
                    """Run the control model and feed its output into the diffusion model.

                    Args:
                        x_noisy: Input passed as ``x`` to both models.
                        t: Passed as ``timesteps`` to both models.
                        cond: Dict with list-valued ``'c_crossattn'`` and ``'c_concat'``
                            entries; each list is concatenated along dim 1.
                        *args, **kwargs: Accepted but not used.

                    Returns:
                        The output of ``self.model.diffusion_model``.
                    """
                    assert isinstance(cond, dict)
                    unet = self.model.diffusion_model
                    text_context = torch.cat(cond['c_crossattn'], 1)
                    hint = torch.cat(cond['c_concat'], 1)

                    control = self.control_model(x=x_noisy, hint=hint, timesteps=t, context=text_context)
                    return unet(
                        x=x_noisy,
                        timesteps=t,
                        context=text_context,
                        control=control,
                        only_mid_control=self.only_mid_control,
                    )
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                @torch.no_grad()
                def get_unconditional_conditioning(self, N):
                    """Return ``get_learned_conditioning`` applied to ``N`` empty strings."""
                    empty_prompts = ["" for _ in range(N)]
                    return self.get_learned_conditioning(empty_prompts)
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                @torch.no_grad()
                def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                               quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                               plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                               use_ema_scope=True,
                               **kwargs):
                    """Collect images for logging and return them in a dict.

                    Always logged: ``"reconstruction"``, ``"control"`` and
                    ``"conditioning"``. Optional entries: ``"diffusion_row"`` (when
                    ``plot_diffusion_rows``), ``"samples"`` and ``"denoise_row"`` (when
                    ``sample`` / ``plot_denoise_rows``), and
                    ``"samples_cfg_scale_<scale>"`` (when
                    ``unconditional_guidance_scale > 1.0``).

                    ``return_keys``, ``quantize_denoised``, ``inpaint``,
                    ``plot_progressive_rows``, ``unconditional_guidance_label``,
                    ``use_ema_scope`` and ``**kwargs`` are accepted but not used here.
                    """
                    use_ddim = ddim_steps is not None

                    z, cond = self.get_input(batch, self.first_stage_key, bs=N)
                    hint = cond["c_concat"][0][:N]
                    text_cond = cond["c_crossattn"][0][:N]
                    N = min(z.shape[0], N)
                    n_row = min(z.shape[0], n_row)

                    log = {
                        "reconstruction": self.decode_first_stage(z),
                        "control": hint * 2.0 - 1.0,
                        "conditioning": log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16),
                    }

                    if plot_diffusion_rows:
                        # Noise the first n_row latents at selected timesteps and decode them.
                        z_start = z[:n_row]
                        noised_frames = []
                        for step in range(self.num_timesteps):
                            if step % self.log_every_t != 0 and step != self.num_timesteps - 1:
                                continue
                            t = repeat(torch.tensor([step]), '1 -> b', b=n_row)
                            t = t.to(self.device).long()
                            noise = torch.randn_like(z_start)
                            z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                            noised_frames.append(self.decode_first_stage(z_noisy))

                        stacked = torch.stack(noised_frames)  # n_log_step, n_row, C, H, W
                        flat = rearrange(stacked, 'n b c h w -> (b n) c h w')
                        log["diffusion_row"] = make_grid(flat, nrow=stacked.shape[0])

                    if sample:
                        samples, z_denoise_row = self.sample_log(
                            cond={"c_concat": [hint], "c_crossattn": [text_cond]},
                            batch_size=N, ddim=use_ddim,
                            ddim_steps=ddim_steps, eta=ddim_eta)
                        log["samples"] = self.decode_first_stage(samples)
                        if plot_denoise_rows:
                            log["denoise_row"] = self._get_denoise_row_from_list(z_denoise_row)

                    if unconditional_guidance_scale > 1.0:
                        uc_cross = self.get_unconditional_conditioning(N)
                        # The unconditional branch reuses the same hint as the conditional one.
                        uc_full = {"c_concat": [hint], "c_crossattn": [uc_cross]}
                        samples_cfg, _ = self.sample_log(
                            cond={"c_concat": [hint], "c_crossattn": [text_cond]},
                            batch_size=N, ddim=use_ddim,
                            ddim_steps=ddim_steps, eta=ddim_eta,
                            unconditional_guidance_scale=unconditional_guidance_scale,
                            unconditional_conditioning=uc_full,
                        )
                        cfg_key = f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                        log[cfg_key] = self.decode_first_stage(samples_cfg)

                    return log
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                @torch.no_grad()
                def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
                    """Sample with a fresh ``DDIMSampler`` and return its two outputs.

                    The sample shape is ``(self.channels, h // 8, w // 8)`` where ``h``
                    and ``w`` are the last two dims of ``cond["c_concat"][0]``.

                    Args:
                        cond: Conditioning dict; passed through to the sampler.
                        batch_size: Passed through to the sampler.
                        ddim: Accepted but not used; a DDIM sampler is always built.
                        ddim_steps: Passed as the sampler's first argument.
                        **kwargs: Forwarded to ``DDIMSampler.sample``.

                    Returns:
                        ``(samples, intermediates)`` as produced by the sampler.
                    """
                    sampler = DDIMSampler(self)
                    _, _, hint_h, hint_w = cond["c_concat"][0].shape
                    latent_shape = (self.channels, hint_h // 8, hint_w // 8)
                    samples, intermediates = sampler.sample(
                        ddim_steps, batch_size, latent_shape, cond, verbose=False, **kwargs
                    )
                    return samples, intermediates
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                def configure_optimizers(self):
                    """Return an AdamW optimizer over the trainable parameters.

                    The control model is always optimised. When ``self.sd_locked`` is
                    falsy, the diffusion model's ``output_blocks`` and ``out`` parameters
                    are added as well. The learning rate is ``self.learning_rate``.
                    """
                    lr = self.learning_rate
                    trainable = list(self.control_model.parameters())
                    if not self.sd_locked:
                        unet = self.model.diffusion_model
                        trainable.extend(unet.output_blocks.parameters())
                        trainable.extend(unet.out.parameters())
                    return torch.optim.AdamW(trainable, lr=lr)
         | 
    	
        cldm/model.py
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from omegaconf import OmegaConf
         | 
| 4 | 
            +
            from ldm.util import instantiate_from_config
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            def get_state_dict(d):
                """Return ``d['state_dict']`` when that key exists, otherwise ``d`` itself."""
                if 'state_dict' in d:
                    return d['state_dict']
                return d
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def load_state_dict(ckpt_path, location='cpu'):
                """Load a checkpoint file and return its state dict.

                Args:
                    ckpt_path: Path handed to ``torch.load``.
                    location: Device string used as ``map_location``.

                Returns:
                    The result of ``get_state_dict`` applied to the loaded object.
                """
                target_device = torch.device(location)
                # NOTE: torch.load unpickles the file; only load checkpoints you trust.
                checkpoint = torch.load(ckpt_path, map_location=target_device)
                state_dict = get_state_dict(checkpoint)
                print(f'Loaded state_dict from [{ckpt_path}]')
                return state_dict
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def create_model(config_path):
                """Build the model described by the ``model`` section of a config file.

                The config is read with ``OmegaConf.load``, instantiated with
                ``instantiate_from_config`` and moved to the CPU.
                """
                model_config = OmegaConf.load(config_path).model
                model = instantiate_from_config(model_config).cpu()
                print(f'Loaded model config from [{config_path}]')
                return model
         | 
    	
        ldm/models/autoencoder.py
    ADDED
    
    | @@ -0,0 +1,219 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import pytorch_lightning as pl
         | 
| 3 | 
            +
            import torch.nn.functional as F
         | 
| 4 | 
            +
            from contextlib import contextmanager
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from ldm.modules.diffusionmodules.model import Encoder, Decoder
         | 
| 7 | 
            +
            from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ldm.util import instantiate_from_config
         | 
| 10 | 
            +
            from ldm.modules.ema import LitEma
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            class AutoencoderKL(pl.LightningModule):
         | 
| 14 | 
            +
                def __init__(self,
                             ddconfig,
                             lossconfig,
                             embed_dim,
                             ckpt_path=None,
                             ignore_keys=None,
                             image_key="image",
                             colorize_nlabels=None,
                             monitor=None,
                             ema_decay=None,
                             learn_logvar=False
                             ):
                    """Build the encoder, decoder, loss and the 1x1 convs around the latent.

                    Args:
                        ddconfig: Keyword config for ``Encoder`` and ``Decoder``; must have
                            a truthy ``"double_z"`` and provide ``"z_channels"``.
                        lossconfig: Config handed to ``instantiate_from_config``.
                        embed_dim: Channel count of the latent embedding.
                        ckpt_path: Optional checkpoint to restore via ``init_from_ckpt``.
                        ignore_keys: Key prefixes to drop when restoring from
                            ``ckpt_path``. ``None`` (the default) means no prefixes.
                        image_key: Batch key used to read input images.
                        colorize_nlabels: If given, registers a random ``colorize`` buffer
                            of shape ``(3, colorize_nlabels, 1, 1)``.
                        monitor: If given, stored as ``self.monitor``.
                        ema_decay: If given (strictly between 0 and 1), enables EMA
                            tracking through ``LitEma``.
                        learn_logvar: Stored as ``self.learn_logvar``; read by
                            ``configure_optimizers``.
                    """
                    super().__init__()
                    # Avoid a shared mutable default argument.
                    ignore_keys = [] if ignore_keys is None else ignore_keys
                    self.learn_logvar = learn_logvar
                    self.image_key = image_key
                    self.encoder = Encoder(**ddconfig)
                    self.decoder = Decoder(**ddconfig)
                    self.loss = instantiate_from_config(lossconfig)
                    assert ddconfig["double_z"]
                    self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
                    self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
                    self.embed_dim = embed_dim
                    if colorize_nlabels is not None:
                        assert isinstance(colorize_nlabels, int)
                        self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
                    if monitor is not None:
                        self.monitor = monitor

                    self.use_ema = ema_decay is not None
                    if self.use_ema:
                        self.ema_decay = ema_decay
                        assert 0. < ema_decay < 1.
                        self.model_ema = LitEma(self, decay=ema_decay)
                        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

                    if ckpt_path is not None:
                        self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def init_from_ckpt(self, path, ignore_keys=()):
                    """Restore weights from a checkpoint, skipping keys with ignored prefixes.

                    Args:
                        path: Checkpoint file; its ``"state_dict"`` entry is loaded.
                        ignore_keys: Iterable of key prefixes. Any state-dict key that
                            starts with one of them is removed before loading.

                    The state dict is loaded with ``strict=False``.
                    """
                    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
                    sd = torch.load(path, map_location="cpu")["state_dict"]
                    prefixes = tuple(ignore_keys)
                    # Iterate over a snapshot of the keys because entries are deleted.
                    for k in list(sd.keys()):
                        # One tuple-based check deletes each key at most once, so a key
                        # matching several prefixes no longer raises KeyError.
                        if k.startswith(prefixes):
                            print("Deleting key {} from state_dict.".format(k))
                            del sd[k]
                    self.load_state_dict(sd, strict=False)
                    print(f"Restored from {path}")
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                @contextmanager
                def ema_scope(self, context=None):
                    """Context manager that swaps in EMA weights when ``self.use_ema`` is set.

                    On entry the current parameters are stored and the EMA weights are
                    copied in; on exit the stored parameters are restored. When
                    ``context`` is given, a message is printed at each switch. With
                    ``use_ema`` false this is a no-op.
                    """
                    announce = context is not None
                    if self.use_ema:
                        self.model_ema.store(self.parameters())
                        self.model_ema.copy_to(self)
                        if announce:
                            print(f"{context}: Switched to EMA weights")
                    try:
                        yield None
                    finally:
                        # Restore even if the body of the ``with`` block raised.
                        if self.use_ema:
                            self.model_ema.restore(self.parameters())
                            if announce:
                                print(f"{context}: Restored training weights")
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def on_train_batch_end(self, *args, **kwargs):
                    """Call ``self.model_ema`` on this module when EMA is enabled."""
                    if not self.use_ema:
                        return
                    self.model_ema(self)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def encode(self, x):
                    """Encode ``x`` and wrap the result in a ``DiagonalGaussianDistribution``."""
                    moments = self.quant_conv(self.encoder(x))
                    return DiagonalGaussianDistribution(moments)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def decode(self, z):
                    """Apply ``post_quant_conv`` to ``z`` and run the decoder on the result."""
                    return self.decoder(self.post_quant_conv(z))
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def forward(self, input, sample_posterior=True):
                    """Encode ``input``, pick a latent, decode it.

                    Args:
                        input: Passed to ``encode``.
                        sample_posterior: When truthy the latent is ``posterior.sample()``,
                            otherwise ``posterior.mode()``.

                    Returns:
                        ``(decoded, posterior)``.
                    """
                    posterior = self.encode(input)
                    z = posterior.sample() if sample_posterior else posterior.mode()
                    return self.decode(z), posterior
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def get_input(self, batch, k):
                    """Fetch ``batch[k]`` and return it as a contiguous float tensor.

                    A 3-D input gets a trailing singleton axis first. The last axis is
                    then moved to position 1 via ``permute(0, 3, 1, 2)``.
                    """
                    images = batch[k]
                    if images.ndim == 3:
                        images = images.unsqueeze(-1)
                    images = images.permute(0, 3, 1, 2)
                    return images.to(memory_format=torch.contiguous_format).float()
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                def training_step(self, batch, batch_idx, optimizer_idx):
                    """Compute and log the loss for the optimizer selected by ``optimizer_idx``.

                    Index 0 is logged as ``"aeloss"`` (encoder+decoder+logvar), index 1 as
                    ``"discloss"`` (discriminator). Any other index returns ``None``
                    after the forward pass.
                    """
                    inputs = self.get_input(batch, self.image_key)
                    reconstructions, posterior = self(inputs)

                    if optimizer_idx not in (0, 1):
                        return None

                    loss_name = "aeloss" if optimizer_idx == 0 else "discloss"
                    loss_value, loss_metrics = self.loss(
                        inputs, reconstructions, posterior, optimizer_idx, self.global_step,
                        last_layer=self.get_last_layer(), split="train",
                    )
                    self.log(loss_name, loss_value, prog_bar=True, logger=True, on_step=True, on_epoch=True)
                    self.log_dict(loss_metrics, prog_bar=False, logger=True, on_step=True, on_epoch=False)
                    return loss_value
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def validation_step(self, batch, batch_idx):
                    """Run validation with the current weights, then again inside ``ema_scope``.

                    The EMA pass logs with the ``"_ema"`` postfix; only the result of the
                    first pass is returned.
                    """
                    result = self._validation_step(batch, batch_idx)
                    with self.ema_scope():
                        self._validation_step(batch, batch_idx, postfix="_ema")
                    return result
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def _validation_step(self, batch, batch_idx, postfix=""):
                    """Compute and log autoencoder and discriminator validation metrics.

                    Args:
                        batch: Batch from which ``self.image_key`` is read.
                        batch_idx: Unused.
                        postfix: Appended to ``"val"`` to form the loss split name and the
                            ``rec_loss`` log key.

                    Returns:
                        A dict merging the autoencoder and discriminator metric dicts.
                        (Previously the bound ``self.log_dict`` method was returned by
                        mistake.)
                    """
                    split = "val" + postfix
                    inputs = self.get_input(batch, self.image_key)
                    reconstructions, posterior = self(inputs)
                    # Optimizer index 0 selects the autoencoder loss, 1 the discriminator loss.
                    _, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
                                               last_layer=self.get_last_layer(), split=split)

                    _, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
                                                 last_layer=self.get_last_layer(), split=split)

                    self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
                    self.log_dict(log_dict_ae)
                    self.log_dict(log_dict_disc)
                    return {**log_dict_ae, **log_dict_disc}
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                def configure_optimizers(self):
                    """Return two Adam optimizers: autoencoder first, discriminator second.

                    The autoencoder optimizer covers the encoder, decoder and both 1x1
                    convs, plus ``self.loss.logvar`` when ``self.learn_logvar`` is set.
                    Both use ``self.learning_rate`` and ``betas=(0.5, 0.9)``. The second
                    element of the returned tuple is an empty list.
                    """
                    lr = self.learning_rate
                    ae_params_list = [
                        *self.encoder.parameters(),
                        *self.decoder.parameters(),
                        *self.quant_conv.parameters(),
                        *self.post_quant_conv.parameters(),
                    ]
                    if self.learn_logvar:
                        print(f"{self.__class__.__name__}: Learning logvar")
                        ae_params_list.append(self.loss.logvar)
                    adam_settings = dict(lr=lr, betas=(0.5, 0.9))
                    opt_ae = torch.optim.Adam(ae_params_list, **adam_settings)
                    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), **adam_settings)
                    return [opt_ae, opt_disc], []
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                def get_last_layer(self):
                    """Return the weight of the decoder's ``conv_out`` layer."""
                    final_conv = self.decoder.conv_out
                    return final_conv.weight
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                @torch.no_grad()
                def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
                    """Return a dict of images for logging.

                    ``"inputs"`` is always present. Unless ``only_inputs`` is set, the
                    dict also holds ``"samples"`` and ``"reconstructions"``, plus the
                    ``"_ema"`` variants when ``log_ema`` or ``self.use_ema`` is truthy.
                    Inputs with more than 3 channels are mapped through ``to_rgb``.
                    """
                    images = {}
                    x = self.get_input(batch, self.image_key).to(self.device)
                    if only_inputs:
                        images["inputs"] = x
                        return images

                    xrec, posterior = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec.shape[1] > 3
                        x = self.to_rgb(x)
                        xrec = self.to_rgb(xrec)
                    images["samples"] = self.decode(torch.randn_like(posterior.sample()))
                    images["reconstructions"] = xrec

                    if log_ema or self.use_ema:
                        with self.ema_scope():
                            # NOTE(review): ``x`` may already have been colorized above
                            # when it had more than 3 channels -- confirm that is intended.
                            xrec_ema, posterior_ema = self(x)
                            if x.shape[1] > 3:
                                # colorize with random projection
                                assert xrec_ema.shape[1] > 3
                                xrec_ema = self.to_rgb(xrec_ema)
                            images["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
                            images["reconstructions_ema"] = xrec_ema

                    images["inputs"] = x
                    return images
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                def to_rgb(self, x):
                    """Project ``x`` to 3 channels with the ``colorize`` kernel and rescale.

                    Only valid when ``self.image_key == "segmentation"``. The
                    ``colorize`` buffer is created with random values on first use if it
                    does not exist. The output is rescaled with its global min and max
                    to the range [-1, 1].
                    """
                    assert self.image_key == "segmentation"
                    if not hasattr(self, "colorize"):
                        self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
                    projected = F.conv2d(x, weight=self.colorize)
                    lowest = projected.min()
                    highest = projected.max()
                    return 2.*(projected-lowest)/(highest-lowest) - 1.
         | 
| 199 | 
            +
             | 
| 200 | 
            +
             | 
| 201 | 
            +
            class IdentityFirstStage(torch.nn.Module):
                """First-stage stand-in whose encode/decode/forward return the input as is."""

                def __init__(self, *args, vq_interface=False, **kwargs):
                    """Store ``vq_interface``; extra positional and keyword args are ignored."""
                    self.vq_interface = vq_interface
                    super().__init__()

                def encode(self, x, *args, **kwargs):
                    """Return ``x`` unchanged."""
                    return x

                def decode(self, x, *args, **kwargs):
                    """Return ``x`` unchanged."""
                    return x

                def quantize(self, x, *args, **kwargs):
                    """Return ``x``, or ``(x, None, [None, None, None])`` when ``vq_interface`` is set."""
                    if not self.vq_interface:
                        return x
                    return x, None, [None, None, None]

                def forward(self, x, *args, **kwargs):
                    """Return ``x`` unchanged."""
                    return x
         | 
| 219 | 
            +
             | 
    	
        ldm/models/diffusion/ddim.py
    ADDED
    
    | @@ -0,0 +1,336 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """SAMPLING ONLY."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from tqdm import tqdm
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class DDIMSampler(object):
         | 
| 11 | 
            +
                def __init__(self, model, schedule="linear", **kwargs):
         | 
| 12 | 
            +
                    super().__init__()
         | 
| 13 | 
            +
                    self.model = model
         | 
| 14 | 
            +
                    self.ddpm_num_timesteps = model.num_timesteps
         | 
| 15 | 
            +
                    self.schedule = schedule
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def register_buffer(self, name, attr):
         | 
| 18 | 
            +
                    if type(attr) == torch.Tensor:
         | 
| 19 | 
            +
                        if attr.device != torch.device("cuda"):
         | 
| 20 | 
            +
                            attr = attr.to(torch.device("cuda"))
         | 
| 21 | 
            +
                    setattr(self, name, attr)
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
                    """Compute the DDIM timesteps and register all schedule buffers.

                    Sets ``self.ddim_timesteps`` from ``make_ddim_timesteps`` and registers,
                    via ``register_buffer``: float32 copies of the model's ``betas``,
                    ``alphas_cumprod`` and ``alphas_cumprod_prev``; several quantities derived
                    from ``alphas_cumprod``; the values returned by
                    ``make_ddim_sampling_parameters``; and
                    ``ddim_sigmas_for_original_num_steps``.

                    Raises:
                        AssertionError: if the model's ``alphas_cumprod`` does not have one
                            entry per DDPM timestep.
                    """
                    self.ddim_timesteps = make_ddim_timesteps(
                        ddim_discr_method=ddim_discretize,
                        num_ddim_timesteps=ddim_num_steps,
                        num_ddpm_timesteps=self.ddpm_num_timesteps,
                        verbose=verbose,
                    )
                    alphas_cumprod = self.model.alphas_cumprod
                    assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

                    def to_torch(x):
                        # Detached float32 copy placed on the model's device.
                        return x.clone().detach().to(torch.float32).to(self.model.device)

                    self.register_buffer('betas', to_torch(self.model.betas))
                    self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
                    self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

                    # calculations for diffusion q(x_t | x_{t-1}) and others
                    acp_cpu = alphas_cumprod.cpu()
                    derived = (
                        ('sqrt_alphas_cumprod', np.sqrt(acp_cpu)),
                        ('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - acp_cpu)),
                        ('log_one_minus_alphas_cumprod', np.log(1. - acp_cpu)),
                        ('sqrt_recip_alphas_cumprod', np.sqrt(1. / acp_cpu)),
                        ('sqrt_recipm1_alphas_cumprod', np.sqrt(1. / acp_cpu - 1)),
                    )
                    for buffer_name, value in derived:
                        self.register_buffer(buffer_name, to_torch(value))

                    # ddim sampling parameters
                    ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
                        alphacums=alphas_cumprod.cpu(),
                        ddim_timesteps=self.ddim_timesteps,
                        eta=ddim_eta,
                        verbose=verbose,
                    )
                    self.register_buffer('ddim_sigmas', ddim_sigmas)
                    self.register_buffer('ddim_alphas', ddim_alphas)
                    self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
                    self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))

                    # Sigmas for sampling over the full (original) number of steps.
                    acp, acp_prev = self.alphas_cumprod, self.alphas_cumprod_prev
                    variance_term = (1 - acp_prev) / (1 - acp) * (1 - acp / acp_prev)
                    self.register_buffer('ddim_sigmas_for_original_num_steps', ddim_eta * torch.sqrt(variance_term))
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                @torch.no_grad()
         | 
| 55 | 
            +
                def sample(self,
         | 
| 56 | 
            +
                           S,
         | 
| 57 | 
            +
                           batch_size,
         | 
| 58 | 
            +
                           shape,
         | 
| 59 | 
            +
                           conditioning=None,
         | 
| 60 | 
            +
                           callback=None,
         | 
| 61 | 
            +
                           normals_sequence=None,
         | 
| 62 | 
            +
                           img_callback=None,
         | 
| 63 | 
            +
                           quantize_x0=False,
         | 
| 64 | 
            +
                           eta=0.,
         | 
| 65 | 
            +
                           mask=None,
         | 
| 66 | 
            +
                           x0=None,
         | 
| 67 | 
            +
                           temperature=1.,
         | 
| 68 | 
            +
                           noise_dropout=0.,
         | 
| 69 | 
            +
                           score_corrector=None,
         | 
| 70 | 
            +
                           corrector_kwargs=None,
         | 
| 71 | 
            +
                           verbose=True,
         | 
| 72 | 
            +
                           x_T=None,
         | 
| 73 | 
            +
                           log_every_t=100,
         | 
| 74 | 
            +
                           unconditional_guidance_scale=1.,
         | 
| 75 | 
            +
                           unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
         | 
| 76 | 
            +
                           dynamic_threshold=None,
         | 
| 77 | 
            +
                           ucg_schedule=None,
         | 
| 78 | 
            +
                           **kwargs
         | 
| 79 | 
            +
                           ):
         | 
| 80 | 
            +
                    if conditioning is not None:
         | 
| 81 | 
            +
                        if isinstance(conditioning, dict):
         | 
| 82 | 
            +
                            ctmp = conditioning[list(conditioning.keys())[0]]
         | 
| 83 | 
            +
                            while isinstance(ctmp, list): ctmp = ctmp[0]
         | 
| 84 | 
            +
                            cbs = ctmp.shape[0]
         | 
| 85 | 
            +
                            if cbs != batch_size:
         | 
| 86 | 
            +
                                print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                        elif isinstance(conditioning, list):
         | 
| 89 | 
            +
                            for ctmp in conditioning:
         | 
| 90 | 
            +
                                if ctmp.shape[0] != batch_size:
         | 
| 91 | 
            +
                                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                        else:
         | 
| 94 | 
            +
                            if conditioning.shape[0] != batch_size:
         | 
| 95 | 
            +
                                print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
         | 
| 98 | 
            +
                    # sampling
         | 
| 99 | 
            +
                    C, H, W = shape
         | 
| 100 | 
            +
                    size = (batch_size, C, H, W)
         | 
| 101 | 
            +
                    print(f'Data shape for DDIM sampling is {size}, eta {eta}')
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    samples, intermediates = self.ddim_sampling(conditioning, size,
         | 
| 104 | 
            +
                                                                callback=callback,
         | 
| 105 | 
            +
                                                                img_callback=img_callback,
         | 
| 106 | 
            +
                                                                quantize_denoised=quantize_x0,
         | 
| 107 | 
            +
                                                                mask=mask, x0=x0,
         | 
| 108 | 
            +
                                                                ddim_use_original_steps=False,
         | 
| 109 | 
            +
                                                                noise_dropout=noise_dropout,
         | 
| 110 | 
            +
                                                                temperature=temperature,
         | 
| 111 | 
            +
                                                                score_corrector=score_corrector,
         | 
| 112 | 
            +
                                                                corrector_kwargs=corrector_kwargs,
         | 
| 113 | 
            +
                                                                x_T=x_T,
         | 
| 114 | 
            +
                                                                log_every_t=log_every_t,
         | 
| 115 | 
            +
                                                                unconditional_guidance_scale=unconditional_guidance_scale,
         | 
| 116 | 
            +
                                                                unconditional_conditioning=unconditional_conditioning,
         | 
| 117 | 
            +
                                                                dynamic_threshold=dynamic_threshold,
         | 
| 118 | 
            +
                                                                ucg_schedule=ucg_schedule
         | 
| 119 | 
            +
                                                                )
         | 
| 120 | 
            +
                    return samples, intermediates
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                @torch.no_grad()
                def ddim_sampling(self, cond, shape,
                                  x_T=None, ddim_use_original_steps=False,
                                  callback=None, timesteps=None, quantize_denoised=False,
                                  mask=None, x0=None, img_callback=None, log_every_t=100,
                                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                                  unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
                                  ucg_schedule=None):
                    """Run the reverse sampling loop by repeatedly calling ``p_sample_ddim``.

                    Args:
                        cond: conditioning forwarded to ``p_sample_ddim``.
                        shape: full shape of the sample; ``shape[0]`` is the batch size.
                        x_T: optional starting tensor; standard normal noise is drawn on
                            the device of ``self.model.betas`` when omitted.
                        ddim_use_original_steps: iterate over all ``ddpm_num_timesteps``
                            steps instead of ``self.ddim_timesteps``.
                        timesteps: optional step count used to truncate
                            ``self.ddim_timesteps`` when not using the original steps.
                        mask, x0: when ``mask`` is given, ``x0`` is required; at each step
                            the current image is blended with ``self.model.q_sample(x0, ts)``
                            as ``img_orig * mask + (1 - mask) * img``.
                        callback: called with the loop index after every step.
                        img_callback: called with ``(pred_x0, loop index)`` after every step.
                        log_every_t: an intermediate is recorded whenever the schedule
                            index is a multiple of this value, and for the first step taken.
                        ucg_schedule: optional per-step guidance scales; must contain one
                            entry per step.
                        Remaining arguments are forwarded to ``p_sample_ddim``.

                    Returns:
                        ``(img, intermediates)`` where ``intermediates`` is a dict with the
                        lists ``'x_inter'`` and ``'pred_x0'`` (both start with the initial image).
                    """
                    device = self.model.betas.device
                    b = shape[0]
                    img = torch.randn(shape, device=device) if x_T is None else x_T

                    if timesteps is None:
                        timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
                    elif not ddim_use_original_steps:
                        subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
                        timesteps = self.ddim_timesteps[:subset_end]

                    intermediates = {'x_inter': [img], 'pred_x0': [img]}
                    if ddim_use_original_steps:
                        time_range = reversed(range(0, timesteps))
                        total_steps = timesteps
                    else:
                        time_range = np.flip(timesteps)
                        total_steps = timesteps.shape[0]
                    print(f"Running DDIM Sampling with {total_steps} timesteps")

                    if ucg_schedule is not None:
                        # Compare against total_steps: in original-steps mode time_range is a
                        # reversed(range(...)) iterator, which does not support len().
                        assert len(ucg_schedule) == total_steps

                    iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

                    for i, step in enumerate(iterator):
                        # i counts up while the schedule index counts down.
                        index = total_steps - i - 1
                        ts = torch.full((b,), step, device=device, dtype=torch.long)

                        if mask is not None:
                            assert x0 is not None
                            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                            img = img_orig * mask + (1. - mask) * img

                        if ucg_schedule is not None:
                            unconditional_guidance_scale = ucg_schedule[i]

                        outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                                  quantize_denoised=quantize_denoised, temperature=temperature,
                                                  noise_dropout=noise_dropout, score_corrector=score_corrector,
                                                  corrector_kwargs=corrector_kwargs,
                                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                                  unconditional_conditioning=unconditional_conditioning,
                                                  dynamic_threshold=dynamic_threshold)
                        img, pred_x0 = outs
                        if callback:
                            callback(i)
                        if img_callback:
                            img_callback(pred_x0, i)

                        if index % log_every_t == 0 or index == total_steps - 1:
                            intermediates['x_inter'].append(img)
                            intermediates['pred_x0'].append(pred_x0)

                    return img, intermediates
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                @torch.no_grad()
                def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                                  temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                                  unconditional_guidance_scale=1., unconditional_conditioning=None,
                                  dynamic_threshold=None):
                    """Perform a single DDIM update step.

                    Args:
                        x: current sample; ``x.shape[0]`` is used as the batch size.
                        c: conditioning (tensor, list or dict) passed to ``self.model.apply_model``.
                        t: timestep tensor passed to the model.
                        index: position in the schedule buffers used to look up
                            alpha / sigma values for this step.
                        repeat_noise: forwarded to ``noise_like``.
                        use_original_steps: read alphas from the model's full schedule
                            instead of the sampler's ``ddim_*`` buffers.
                        quantize_denoised: pass the predicted x_0 through
                            ``self.model.first_stage_model.quantize``.
                        temperature: multiplier applied to the injected noise.
                        noise_dropout: dropout probability applied to the noise when > 0.
                        score_corrector, corrector_kwargs: optional object whose
                            ``modify_score`` adjusts the noise prediction (only with the
                            ``"eps"`` parameterization); ``corrector_kwargs`` may be ``None``.
                        unconditional_guidance_scale, unconditional_conditioning: when a
                            conditioning is given and the scale differs from 1, the model is
                            evaluated on the unconditional and conditional inputs in one
                            batch and combined as ``uncond + scale * (cond - uncond)``.
                        dynamic_threshold: not supported.

                    Returns:
                        ``(x_prev, pred_x0)``.

                    Raises:
                        NotImplementedError: if ``dynamic_threshold`` is not ``None``.
                    """
                    b = x.shape[0]
                    device = x.device

                    if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                        model_output = self.model.apply_model(x, t, c)
                    else:
                        # Batch the unconditional and conditional passes together.
                        x_in = torch.cat([x] * 2)
                        t_in = torch.cat([t] * 2)
                        if isinstance(c, dict):
                            assert isinstance(unconditional_conditioning, dict)
                            c_in = dict()
                            for k in c:
                                if isinstance(c[k], list):
                                    c_in[k] = [torch.cat([
                                        unconditional_conditioning[k][i],
                                        c[k][i]]) for i in range(len(c[k]))]
                                else:
                                    c_in[k] = torch.cat([
                                        unconditional_conditioning[k],
                                        c[k]])
                        elif isinstance(c, list):
                            c_in = list()
                            assert isinstance(unconditional_conditioning, list)
                            for i in range(len(c)):
                                c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
                        else:
                            c_in = torch.cat([unconditional_conditioning, c])
                        model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                        model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)

                    if self.model.parameterization == "v":
                        e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
                    else:
                        e_t = model_output

                    if score_corrector is not None:
                        assert self.model.parameterization == "eps", 'not implemented'
                        # corrector_kwargs defaults to None, which cannot be **-unpacked.
                        e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

                    alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
                    alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
                    sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
                    # make_schedule registers ddim_sigmas_for_original_num_steps on the
                    # sampler itself, so it is read from self rather than self.model.
                    sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
                    # select parameters corresponding to the currently considered timestep
                    a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
                    a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
                    sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
                    sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

                    # current prediction for x_0
                    if self.model.parameterization != "v":
                        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
                    else:
                        pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

                    if quantize_denoised:
                        pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

                    if dynamic_threshold is not None:
                        raise NotImplementedError()

                    # direction pointing to x_t
                    dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
                    noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
                    if noise_dropout > 0.:
                        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
                    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
                    return x_prev, pred_x0
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                @torch.no_grad()
                def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
                           unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
                    """Step ``x0`` forward through ``t_enc`` steps using the model's noise prediction.

                    Args:
                        x0: starting tensor; ``x0.shape[0]`` is the batch size.
                        c: conditioning passed to ``self.model.apply_model``.
                        t_enc: number of steps to take; must not exceed the number of
                            available steps in the chosen schedule.
                        use_original_steps: use the full ``alphas_cumprod`` schedule
                            instead of the ``ddim_alphas`` buffers.
                        return_intermediates: if truthy, roughly this many intermediate
                            states are collected (plus the last two steps).
                        unconditional_guidance_scale, unconditional_conditioning: when the
                            scale differs from 1, ``unconditional_conditioning`` is required
                            and the two predictions are combined as
                            ``uncond + scale * (cond - uncond)``.
                        callback: called with the loop index after every step.

                    Returns:
                        ``(x_next, out)`` where ``out`` holds ``'x_encoded'``,
                        ``'intermediate_steps'`` and, when requested, ``'intermediates'``.
                    """
                    num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]

                    assert t_enc <= num_reference_steps
                    num_steps = t_enc

                    if use_original_steps:
                        alphas_next = self.alphas_cumprod[:num_steps]
                        alphas = self.alphas_cumprod_prev[:num_steps]
                    else:
                        alphas_next = self.ddim_alphas[:num_steps]
                        alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

                    log_interval = None
                    if return_intermediates:
                        log_interval = num_steps // return_intermediates
                        if log_interval == 0:
                            # More intermediates requested than steps taken: record every
                            # step instead of dividing by zero in the modulo below.
                            log_interval = 1

                    x_next = x0
                    intermediates = []
                    inter_steps = []
                    for i in tqdm(range(num_steps), desc='Encoding Image'):
                        t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
                        if unconditional_guidance_scale == 1.:
                            noise_pred = self.model.apply_model(x_next, t, c)
                        else:
                            assert unconditional_conditioning is not None
                            e_t_uncond, noise_pred = torch.chunk(
                                self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
                                                       torch.cat((unconditional_conditioning, c))), 2)
                            noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)

                        xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
                        weighted_noise_pred = alphas_next[i].sqrt() * (
                                (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
                        x_next = xt_weighted + weighted_noise_pred
                        if return_intermediates:
                            on_interval = i % log_interval == 0 and i < num_steps - 1
                            near_end = i >= num_steps - 2
                            if on_interval or near_end:
                                intermediates.append(x_next)
                                inter_steps.append(i)
                        if callback:
                            callback(i)

                    out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
                    if return_intermediates:
                        out.update({'intermediates': intermediates})
                    return x_next, out
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                @torch.no_grad()
                def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
                    """Noise ``x0`` to step ``t`` in a single shot.

                    Fast, but does not allow for exact reconstruction. ``t`` serves as an
                    index to gather the matching schedule coefficients, which are taken
                    from the full-schedule buffers when ``use_original_steps`` is set and
                    from the ``ddim_*`` buffers otherwise. ``noise`` defaults to
                    ``torch.randn_like(x0)``.

                    Returns:
                        ``coef_signal * x0 + coef_noise * noise`` with both coefficients
                        gathered via ``extract_into_tensor``.
                    """
                    if use_original_steps:
                        signal_coefs = self.sqrt_alphas_cumprod
                        noise_coefs = self.sqrt_one_minus_alphas_cumprod
                    else:
                        signal_coefs = torch.sqrt(self.ddim_alphas)
                        noise_coefs = self.ddim_sqrt_one_minus_alphas

                    eps = torch.randn_like(x0) if noise is None else noise
                    signal_part = extract_into_tensor(signal_coefs, t, x0.shape) * x0
                    noise_part = extract_into_tensor(noise_coefs, t, x0.shape) * eps
                    return signal_part + noise_part
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                @torch.no_grad()
                def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
                           use_original_steps=False, callback=None):
                    """Denoise ``x_latent`` over the first ``t_start`` schedule steps, in reverse order.

                    The schedule is ``np.arange(self.ddpm_num_timesteps)`` when
                    ``use_original_steps`` is set and ``self.ddim_timesteps`` otherwise.
                    Each step calls ``p_sample_ddim``; ``callback`` (if given) receives the
                    loop index after every step.

                    Returns:
                        The tensor produced by the final ``p_sample_ddim`` call, or
                        ``x_latent`` itself if no steps are taken.
                    """
                    if use_original_steps:
                        schedule = np.arange(self.ddpm_num_timesteps)
                    else:
                        schedule = self.ddim_timesteps
                    schedule = schedule[:t_start]

                    descending = np.flip(schedule)
                    n_steps = schedule.shape[0]
                    print(f"Running DDIM Sampling with {n_steps} timesteps")

                    progress = tqdm(descending, desc='Decoding image', total=n_steps)
                    x_dec = x_latent
                    for i, step in enumerate(progress):
                        # The schedule index counts down while the loop index counts up.
                        schedule_index = n_steps - i - 1
                        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
                        x_dec, _ = self.p_sample_ddim(
                            x_dec, cond, ts,
                            index=schedule_index,
                            use_original_steps=use_original_steps,
                            unconditional_guidance_scale=unconditional_guidance_scale,
                            unconditional_conditioning=unconditional_conditioning,
                        )
                        if callback:
                            callback(i)
                    return x_dec
         | 
    	
        ldm/modules/attention.py
    ADDED
    
    | @@ -0,0 +1,341 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from inspect import isfunction
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn.functional as F
         | 
| 5 | 
            +
            from torch import nn, einsum
         | 
| 6 | 
            +
            from einops import rearrange, repeat
         | 
| 7 | 
            +
            from typing import Optional, Any
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ldm.modules.diffusionmodules.util import checkpoint
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Probe for the optional xformers dependency. The flag name (including its
            # spelling) is read by BasicTransformerBlock, so it must not be renamed here.
            try:
                import xformers
                import xformers.ops
                XFORMERS_IS_AVAILBLE = True
            except Exception:
                # Any import-time failure means "not available". Unlike a bare ``except:``,
                # this does not swallow KeyboardInterrupt or SystemExit.
                XFORMERS_IS_AVAILBLE = False
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # CrossAttn precision handling
         | 
| 20 | 
            +
            import os
         | 
| 21 | 
            +
            # Read once at import time from the ATTN_PRECISION environment variable.
            # CrossAttention.forward casts q and k to float32 only when this equals
            # "fp32" (the default); any other value leaves them in their incoming dtype.
            _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            def exists(val):
                """Return True for every value except ``None`` (falsy values such as 0 count as existing)."""
                if val is None:
                    return False
                return True
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def uniq(arr):
                """Return a dict-keys view of the distinct elements of ``arr`` in first-seen order."""
                first_seen = dict.fromkeys(arr, True)
                return first_seen.keys()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def default(val, d):
                """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

                If ``d`` is a plain Python function (as judged by ``inspect.isfunction``) it is
                called with no arguments and its result is returned; any other ``d`` is
                returned unchanged.
                """
                if not exists(val):
                    return d() if isfunction(d) else d
                return val
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def max_neg_value(t):
                """Return the negated largest value representable in ``t``'s floating-point dtype."""
                dtype_limits = torch.finfo(t.dtype)
                return -dtype_limits.max
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def init_(tensor):
                """Fill ``tensor`` in place with uniform samples in [-1/sqrt(d), 1/sqrt(d)] and return it.

                ``d`` is the size of the tensor's last dimension.
                """
                bound = 1 / math.sqrt(tensor.shape[-1])
                tensor.uniform_(-bound, bound)
                return tensor
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            # feedforward
         | 
| 49 | 
            +
            class GEGLU(nn.Module):
                """Gated-GELU projection: a single linear layer yields a value half and a gate half."""

                def __init__(self, dim_in, dim_out):
                    super().__init__()
                    # One matmul produces both halves; forward() splits them along the last dim.
                    self.proj = nn.Linear(dim_in, 2 * dim_out)

                def forward(self, x):
                    """Return ``value * gelu(gate)`` where both halves come from ``self.proj(x)``."""
                    value, gate = torch.chunk(self.proj(x), 2, dim=-1)
                    return value * F.gelu(gate)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            class FeedForward(nn.Module):
                """Feed-forward sub-block: input projection with activation, dropout, output projection."""

                def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
                    """
                    :param dim: input feature size.
                    :param dim_out: output feature size; falls back to ``dim`` when ``None``.
                    :param mult: hidden width multiplier (hidden size is ``int(dim * mult)``).
                    :param glu: use the gated ``GEGLU`` input projection instead of Linear + GELU.
                    :param dropout: dropout probability applied between the two projections.
                    """
                    super().__init__()
                    hidden_dim = int(dim * mult)
                    dim_out = default(dim_out, dim)

                    if glu:
                        project_in = GEGLU(dim, hidden_dim)
                    else:
                        project_in = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

                    self.net = nn.Sequential(
                        project_in,
                        nn.Dropout(dropout),
                        nn.Linear(hidden_dim, dim_out),
                    )

                def forward(self, x):
                    """Apply the three-stage network to ``x``."""
                    return self.net(x)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            def zero_module(module):
                """
                Overwrite every parameter of ``module`` with zeros, in place, and return the same module.
                """
                for param in module.parameters():
                    # Zeroing through the detached alias writes into the parameter's own
                    # storage without the in-place op being tracked by autograd.
                    param.detach().zero_()
                return module
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            def Normalize(in_channels):
                """Build the affine GroupNorm (32 groups, eps=1e-6) used ahead of the attention layers."""
                group_norm_kwargs = dict(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
                return torch.nn.GroupNorm(**group_norm_kwargs)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            class SpatialSelfAttention(nn.Module):
                """Single-head self-attention over the spatial positions of a ``(b, c, h, w)`` feature map."""

                def __init__(self, in_channels):
                    super().__init__()
                    self.in_channels = in_channels

                    self.norm = Normalize(in_channels)
                    # q, k, v and the output projection are all channel-preserving 1x1 convolutions.
                    # They are registered in this order so the parameter layout is unchanged.
                    for layer_name in ("q", "k", "v", "proj_out"):
                        setattr(
                            self,
                            layer_name,
                            torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0),
                        )

                def forward(self, x):
                    """Return ``x`` plus the projected attention output (residual connection)."""
                    normed = self.norm(x)
                    query = self.q(normed)
                    key = self.k(normed)
                    value = self.v(normed)

                    # compute attention
                    b, c, h, w = query.shape
                    query = rearrange(query, 'b c h w -> b (h w) c')
                    key = rearrange(key, 'b c h w -> b c (h w)')
                    # (b, hw, c) x (b, c, hw) -> (b, hw, hw), scaled by 1/sqrt(c)
                    attn = torch.einsum('bij,bjk->bik', query, key) * (int(c) ** (-0.5))
                    attn = torch.nn.functional.softmax(attn, dim=2)

                    # attend to values
                    value = rearrange(value, 'b c h w -> b c (h w)')
                    attn = rearrange(attn, 'b i j -> b j i')
                    attended = torch.einsum('bij,bjk->bik', value, attn)
                    attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)

                    return x + self.proj_out(attended)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            class CrossAttention(nn.Module):
                """Multi-head attention of ``x`` over ``context`` (over ``x`` itself when no context is given)."""

                def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
                    """
                    :param query_dim: feature size of the query input (and of the output).
                    :param context_dim: feature size of the context; falls back to ``query_dim``.
                    :param heads: number of attention heads.
                    :param dim_head: feature size per head.
                    :param dropout: dropout probability applied after the output projection.
                    """
                    super().__init__()
                    context_dim = default(context_dim, query_dim)
                    inner_dim = dim_head * heads

                    self.scale = dim_head ** -0.5
                    self.heads = heads

                    self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
                    self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
                    self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

                    self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

                def forward(self, x, context=None, mask=None):
                    """Attend from ``x`` to ``context`` and project back to ``query_dim`` features.

                    ``mask``, if given, is flattened per batch element and repeated per head;
                    positions where it is false are filled with the most negative value of the
                    similarity dtype before the softmax.
                    """
                    n_heads = self.heads

                    q = self.to_q(x)
                    context = default(context, x)
                    k = self.to_k(context)
                    v = self.to_v(context)

                    def split_heads(t):
                        # Fold the head axis into the batch axis.
                        return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

                    q, k, v = split_heads(q), split_heads(k), split_heads(v)

                    # force cast to fp32 to avoid overflowing
                    if _ATTN_PRECISION == "fp32":
                        with torch.autocast(enabled=False, device_type='cuda'):
                            q, k = q.float(), k.float()
                            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
                    else:
                        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

                    # Drop the projected queries/keys before the softmax to release memory early.
                    del q, k

                    if exists(mask):
                        mask = rearrange(mask, 'b ... -> b (...)')
                        fill_value = -torch.finfo(sim.dtype).max
                        mask = repeat(mask, 'b j -> (b h) () j', h=n_heads)
                        sim.masked_fill_(~mask, fill_value)

                    # attention, what we cannot get enough of
                    sim = sim.softmax(dim=-1)

                    attended = einsum('b i j, b j d -> b i d', sim, v)
                    attended = rearrange(attended, '(b h) n d -> b n (h d)', h=n_heads)
                    return self.to_out(attended)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
             | 
| 197 | 
            +
            class MemoryEfficientCrossAttention(nn.Module):
                """Cross-attention that delegates the attention product to ``xformers.ops.memory_efficient_attention``.

                Takes the same constructor arguments as ``CrossAttention`` and has the same
                ``forward`` signature, but attention masks are not supported.
                """
                # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
                def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
                    """
                    :param query_dim: feature size of the query input (and of the output).
                    :param context_dim: feature size of the context; falls back to ``query_dim``.
                    :param heads: number of attention heads.
                    :param dim_head: feature size per head.
                    :param dropout: dropout probability applied after the output projection.
                    """
                    super().__init__()
                    print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
                          f"{heads} heads.")
                    inner_dim = dim_head * heads
                    context_dim = default(context_dim, query_dim)

                    self.heads = heads
                    self.dim_head = dim_head

                    self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
                    self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
                    self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

                    self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
                    # Passed straight through as ``op=`` to xformers; None leaves the choice to xformers.
                    self.attention_op: Optional[Any] = None

                def forward(self, x, context=None, mask=None):
                    """Attend from ``x`` to ``context`` (or to ``x`` itself) and project back.

                    :raises NotImplementedError: if ``mask`` is given.
                    """
                    # Reject unsupported input up front instead of after the full attention pass.
                    if mask is not None:
                        raise NotImplementedError(
                            "MemoryEfficientCrossAttention does not support attention masks")

                    q = self.to_q(x)
                    context = default(context, x)
                    k = self.to_k(context)
                    v = self.to_v(context)

                    b, _, _ = q.shape
                    # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
                    q, k, v = map(
                        lambda t: t.reshape(b, t.shape[1], self.heads, self.dim_head)
                        .permute(0, 2, 1, 3)
                        .reshape(b * self.heads, t.shape[1], self.dim_head)
                        .contiguous(),
                        (q, k, v),
                    )

                    # actually compute the attention, what we cannot get enough of
                    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

                    # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
                    out = (
                        out.reshape(b, self.heads, out.shape[1], self.dim_head)
                        .permute(0, 2, 1, 3)
                        .reshape(b, out.shape[1], self.heads * self.dim_head)
                    )
                    return self.to_out(out)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
             | 
| 246 | 
            +
            class BasicTransformerBlock(nn.Module):
                """Transformer block: attention (attn1), attention over context (attn2), feed-forward.

                Each stage is applied to a LayerNorm'ed input and added back residually.
                """
                ATTENTION_MODES = {
                    "softmax": CrossAttention,  # vanilla attention
                    "softmax-xformers": MemoryEfficientCrossAttention
                }

                def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                             disable_self_attn=False):
                    """
                    :param dim: feature size of the block's input and output.
                    :param n_heads: number of attention heads.
                    :param d_head: feature size per head.
                    :param dropout: dropout probability for the attention and feed-forward layers.
                    :param context_dim: feature size of the context used by ``attn2``
                        (and by ``attn1`` when ``disable_self_attn`` is set).
                    :param gated_ff: use the gated (GLU) variant of the feed-forward layer.
                    :param checkpoint: flag stored and passed to the ``checkpoint`` helper in ``forward``.
                    :param disable_self_attn: make ``attn1`` attend to the context instead of to ``x``.
                    """
                    super().__init__()
                    attn_mode = "softmax"
                    if XFORMERS_IS_AVAILBLE:
                        attn_mode = "softmax-xformers"
                    assert attn_mode in self.ATTENTION_MODES
                    attn_cls = self.ATTENTION_MODES[attn_mode]

                    self.disable_self_attn = disable_self_attn
                    shared_attn_kwargs = dict(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)

                    # is a self-attention if not self.disable_self_attn
                    attn1_context_dim = context_dim if self.disable_self_attn else None
                    self.attn1 = attn_cls(context_dim=attn1_context_dim, **shared_attn_kwargs)
                    self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
                    # is self-attn if context is none
                    self.attn2 = attn_cls(context_dim=context_dim, **shared_attn_kwargs)

                    self.norm1 = nn.LayerNorm(dim)
                    self.norm2 = nn.LayerNorm(dim)
                    self.norm3 = nn.LayerNorm(dim)
                    self.checkpoint = checkpoint

                def forward(self, x, context=None):
                    """Run ``_forward`` through the ``checkpoint`` helper, governed by ``self.checkpoint``."""
                    inputs = (x, context)
                    return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)

                def _forward(self, x, context=None):
                    """Apply the three residual stages in order: attn1, attn2, feed-forward."""
                    attn1_context = context if self.disable_self_attn else None
                    x = self.attn1(self.norm1(x), context=attn1_context) + x
                    x = self.attn2(self.norm2(x), context=context) + x
                    return self.ff(self.norm3(x)) + x
         | 
| 276 | 
            +
             | 
| 277 | 
            +
             | 
| 278 | 
            +
            class SpatialTransformer(nn.Module):
                """
                Transformer block for image-like data.
                First, project the input (aka embedding)
                and reshape to b, t, d.
                Then apply standard transformer action.
                Finally, reshape to image
                NEW: use_linear for more efficiency instead of the 1x1 convs
                """
                def __init__(self, in_channels, n_heads, d_head,
                             depth=1, dropout=0., context_dim=None,
                             disable_self_attn=False, use_linear=False,
                             use_checkpoint=True):
                    """
                    :param in_channels: channel count of the input (and output) feature map.
                    :param n_heads: number of attention heads per block.
                    :param d_head: feature size per head; the inner width is ``n_heads * d_head``.
                    :param depth: number of ``BasicTransformerBlock`` layers.
                    :param dropout: dropout probability forwarded to each block.
                    :param context_dim: context feature size. ``None`` (no context), a single
                        value shared by every block, or a list with one entry per block.
                    :param disable_self_attn: forwarded to each block.
                    :param use_linear: use ``nn.Linear`` projections on the flattened tokens
                        instead of 1x1 convolutions on the feature map.
                    :param use_checkpoint: forwarded to each block as ``checkpoint``.
                    """
                    super().__init__()
                    # Normalise context_dim to one entry per block. Previously ``None`` crashed
                    # on indexing, and a single value crashed with IndexError when depth > 1.
                    if not exists(context_dim):
                        context_dim = [None] * depth
                    elif not isinstance(context_dim, list):
                        context_dim = [context_dim] * depth
                    elif len(context_dim) == 1 and depth > 1:
                        context_dim = context_dim * depth
                    self.in_channels = in_channels
                    inner_dim = n_heads * d_head
                    self.norm = Normalize(in_channels)
                    if not use_linear:
                        self.proj_in = nn.Conv2d(in_channels,
                                                 inner_dim,
                                                 kernel_size=1,
                                                 stride=1,
                                                 padding=0)
                    else:
                        self.proj_in = nn.Linear(in_channels, inner_dim)

                    self.transformer_blocks = nn.ModuleList(
                        [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                               disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
                            for d in range(depth)]
                    )
                    if not use_linear:
                        self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                              in_channels,
                                                              kernel_size=1,
                                                              stride=1,
                                                              padding=0))
                    else:
                        # Maps inner_dim back to in_channels. The arguments used to be swapped,
                        # which only worked when the two sizes happened to be equal (the weight
                        # shape is identical in that case, so existing checkpoints still load).
                        self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
                    self.use_linear = use_linear

                def forward(self, x, context=None):
                    """Apply the transformer blocks to a ``(b, c, h, w)`` input and add the residual.

                    :param x: feature map of shape ``(b, c, h, w)``.
                    :param context: ``None``, a single context shared by every block, or a list
                        with one context per block.
                    :return: tensor with the same shape as ``x``.
                    """
                    # note: if no context is given, cross-attention defaults to self-attention
                    if not isinstance(context, list):
                        context = [context]
                    _, _, h, w = x.shape
                    x_in = x
                    x = self.norm(x)
                    if not self.use_linear:
                        x = self.proj_in(x)
                    x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
                    if self.use_linear:
                        x = self.proj_in(x)
                    for i, block in enumerate(self.transformer_blocks):
                        # A lone context is shared by all blocks instead of raising IndexError
                        # on the second block.
                        block_context = context[0] if len(context) == 1 else context[i]
                        x = block(x, context=block_context)
                    if self.use_linear:
                        x = self.proj_out(x)
                    x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
                    if not self.use_linear:
                        x = self.proj_out(x)
                    return x + x_in
         | 
| 341 | 
            +
             | 
    	
        ldm/modules/diffusionmodules/model.py
    ADDED
    
    | @@ -0,0 +1,852 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # pytorch_diffusion + derived encoder decoder
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from einops import rearrange
         | 
| 7 | 
            +
            from typing import Optional, Any
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ldm.modules.attention import MemoryEfficientCrossAttention
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # xformers is an optional dependency: record whether it imported so that
            # make_attn() can switch "vanilla" attention to the xformers-backed block.
            try:
                import xformers
                import xformers.ops
                XFORMERS_IS_AVAILBLE = True
            except Exception:
                # Kept deliberately broad so any import-time failure still degrades
                # gracefully, but unlike a bare `except:` this no longer swallows
                # KeyboardInterrupt / SystemExit.
                XFORMERS_IS_AVAILBLE = False
                print("No module 'xformers'. Proceeding without it.")
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def get_timestep_embedding(timesteps, embedding_dim):
                """
                Build sinusoidal timestep embeddings.

                This matches the implementation in Denoising Diffusion Probabilistic
                Models (from Fairseq) and in tensor2tensor, but differs slightly from
                the description in Section 3.5 of "Attention Is All You Need".

                Args:
                    timesteps: 1-D tensor of timesteps (one entry per batch element).
                    embedding_dim: width of the returned embedding.

                Returns:
                    Float tensor of shape (len(timesteps), embedding_dim): the sine
                    half followed by the cosine half, plus one trailing zero column
                    when embedding_dim is odd.

                Raises:
                    AssertionError: if timesteps is not 1-D.
                """
                assert len(timesteps.shape) == 1

                half_dim = embedding_dim // 2
                # Log-spaced frequencies from 1 down to 1/10000.  With a single
                # frequency (embedding_dim of 2 or 3) the divisor half_dim - 1 is
                # zero and used to raise ZeroDivisionError; clamp it so that case
                # simply uses frequency 1.
                log_step = math.log(10000) / max(half_dim - 1, 1)
                freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -log_step)
                freqs = freqs.to(device=timesteps.device)
                # Outer product: one row per timestep, one column per frequency.
                angles = timesteps.float()[:, None] * freqs[None, :]
                emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
                if embedding_dim % 2 == 1:  # zero pad
                    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
                return emb
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def nonlinearity(x):
                """Swish activation: the input scaled by its own sigmoid."""
                gate = torch.sigmoid(x)
                return x * gate
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def Normalize(in_channels, num_groups=32):
                """Create an affine GroupNorm layer (eps=1e-6) over `in_channels` channels."""
                norm_layer = torch.nn.GroupNorm(
                    num_groups=num_groups,
                    num_channels=in_channels,
                    eps=1e-6,
                    affine=True,
                )
                return norm_layer
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             | 
| 50 | 
            +
            class Upsample(nn.Module):
                """Nearest-neighbour 2x spatial upsampling, optionally followed by a 3x3 conv."""

                def __init__(self, in_channels, with_conv):
                    super().__init__()
                    self.with_conv = with_conv
                    if not self.with_conv:
                        return
                    # Channel-preserving conv; padding=1 with a 3x3 kernel keeps the
                    # spatial size unchanged.
                    self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                                kernel_size=3, stride=1, padding=1)

                def forward(self, x):
                    upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
                    return self.conv(upsampled) if self.with_conv else upsampled
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            class Downsample(nn.Module):
                """Halve the spatial resolution with a stride-2 3x3 conv or 2x2 average pooling."""

                def __init__(self, in_channels, with_conv):
                    super().__init__()
                    self.with_conv = with_conv
                    if not self.with_conv:
                        return
                    # No asymmetric padding in torch conv, so the conv is unpadded
                    # and forward() pads the input by hand.
                    self.conv = torch.nn.Conv2d(in_channels, in_channels,
                                                kernel_size=3, stride=2, padding=0)

                def forward(self, x):
                    if not self.with_conv:
                        return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
                    # One zero column on the right and one zero row at the bottom.
                    padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
                    return self.conv(padded)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            class ResnetBlock(nn.Module):
                """
                Residual block made of two (GroupNorm -> swish -> 3x3 conv) stages.

                A projected timestep embedding is added between the two stages when one
                is supplied.  When the input and output channel counts differ, the skip
                path is projected with either a 3x3 conv (`conv_shortcut=True`) or a
                1x1 conv so it can be summed with the residual branch.
                """

                def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                             dropout, temb_channels=512):
                    super().__init__()
                    if out_channels is None:
                        out_channels = in_channels
                    self.in_channels = in_channels
                    self.out_channels = out_channels
                    self.use_conv_shortcut = conv_shortcut

                    # First stage.
                    self.norm1 = Normalize(in_channels)
                    self.conv1 = torch.nn.Conv2d(in_channels, out_channels,
                                                 kernel_size=3, stride=1, padding=1)
                    # The embedding projection only exists when embeddings are in use.
                    if temb_channels > 0:
                        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)

                    # Second stage.
                    self.norm2 = Normalize(out_channels)
                    self.dropout = torch.nn.Dropout(dropout)
                    self.conv2 = torch.nn.Conv2d(out_channels, out_channels,
                                                 kernel_size=3, stride=1, padding=1)

                    # Skip-path projection is only needed when the channel count changes.
                    if self.in_channels == self.out_channels:
                        return
                    if self.use_conv_shortcut:
                        self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                             kernel_size=3, stride=1, padding=1)
                    else:
                        self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                            kernel_size=1, stride=1, padding=0)

                def forward(self, x, temb):
                    residual = self.conv1(nonlinearity(self.norm1(x)))

                    if temb is not None:
                        # Broadcast the per-sample, per-channel bias over both spatial dims.
                        residual = residual + self.temb_proj(nonlinearity(temb))[:, :, None, None]

                    residual = self.conv2(self.dropout(nonlinearity(self.norm2(residual))))

                    if self.in_channels == self.out_channels:
                        skip = x
                    elif self.use_conv_shortcut:
                        skip = self.conv_shortcut(x)
                    else:
                        skip = self.nin_shortcut(x)

                    return skip + residual
         | 
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
            class AttnBlock(nn.Module):
                """Self-attention over the spatial positions of a feature map, with a residual connection."""

                def __init__(self, in_channels):
                    super().__init__()
                    self.in_channels = in_channels

                    self.norm = Normalize(in_channels)
                    # q, k, v and the output projection are all channel-preserving 1x1
                    # convs; they are created in this fixed order.
                    for layer_name in ("q", "k", "v", "proj_out"):
                        setattr(self, layer_name, torch.nn.Conv2d(in_channels,
                                                                  in_channels,
                                                                  kernel_size=1,
                                                                  stride=1,
                                                                  padding=0))

                def forward(self, x):
                    normed = self.norm(x)
                    q = self.q(normed)
                    k = self.k(normed)
                    v = self.v(normed)

                    # compute attention
                    b, c, h, w = q.shape
                    n_pos = h * w
                    q = q.reshape(b, c, n_pos).permute(0, 2, 1)  # b,hw,c
                    k = k.reshape(b, c, n_pos)                   # b,c,hw
                    # scores[b,i,j] = sum_c q[b,i,c] k[b,c,j]
                    scores = torch.bmm(q, k)                     # b,hw,hw
                    scores = scores * (int(c)**(-0.5))
                    weights = torch.nn.functional.softmax(scores, dim=2)

                    # attend to values
                    v = v.reshape(b, c, n_pos)
                    weights = weights.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
                    # attended[b,c,j] = sum_i v[b,c,i] weights[b,i,j]
                    attended = torch.bmm(v, weights)    # b,c,hw (hw of q)
                    attended = attended.reshape(b, c, h, w)

                    return x + self.proj_out(attended)
         | 
| 204 | 
            +
             | 
| 205 | 
            +
            class MemoryEfficientAttnBlock(nn.Module):
                """
                    Spatial self-attention block backed by xformers' memory-efficient
                    attention kernel, see
                    https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
                    Note: this is a single-head self-attention operation
                """

                def __init__(self, in_channels):
                    super().__init__()
                    self.in_channels = in_channels

                    self.norm = Normalize(in_channels)
                    # q, k, v and the output projection are all channel-preserving 1x1
                    # convs; they are created in this fixed order.
                    for layer_name in ("q", "k", "v", "proj_out"):
                        setattr(self, layer_name, torch.nn.Conv2d(in_channels,
                                                                  in_channels,
                                                                  kernel_size=1,
                                                                  stride=1,
                                                                  padding=0))
                    # Passed straight through as `op=` to the xformers call in forward().
                    self.attention_op: Optional[Any] = None

                def forward(self, x):
                    normed = self.norm(x)
                    q = self.q(normed)
                    k = self.k(normed)
                    v = self.v(normed)

                    # compute attention
                    B, C, H, W = q.shape

                    def to_tokens(t):
                        # Flatten the spatial grid into a token axis, then run the
                        # (head count == 1) split/merge reshapes and make it contiguous.
                        t = rearrange(t, 'b c h w -> b (h w) c')
                        n_tokens = t.shape[1]
                        return (
                            t.unsqueeze(3)
                            .reshape(B, n_tokens, 1, C)
                            .permute(0, 2, 1, 3)
                            .reshape(B * 1, n_tokens, C)
                            .contiguous()
                        )

                    q, k, v = to_tokens(q), to_tokens(k), to_tokens(v)
                    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)

                    n_out = out.shape[1]
                    out = (
                        out.unsqueeze(0)
                        .reshape(B, 1, n_out, C)
                        .permute(0, 2, 1, 3)
                        .reshape(B, n_out, C)
                    )
                    # Back from the token axis to the (H, W) grid.
                    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
                    return x + self.proj_out(out)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
             | 
| 271 | 
            +
            class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
                """
                Adapter that lets MemoryEfficientCrossAttention run on 4-D feature maps.

                The input is flattened from (b, c, h, w) to (b, h*w, c) for the parent
                attention, the result is folded back to (b, c, h, w), and the original
                input is added as a residual.
                """

                def forward(self, x, context=None, mask=None):
                    """
                    Args:
                        x: feature map of shape (b, c, h, w).
                        context: forwarded unchanged to the parent attention.
                        mask: forwarded unchanged to the parent attention.

                    Returns:
                        Tensor of shape (b, c, h, w): `x` plus the attention output.
                    """
                    b, c, h, w = x.shape
                    # Keep `x` itself untouched: the previous code rebound it to the
                    # flattened (b, h*w, c) layout and then added that to the
                    # (b, c, h, w) output, so the residual sum mixed two layouts.
                    tokens = rearrange(x, 'b c h w -> b (h w) c')
                    out = super().forward(tokens, context=context, mask=mask)
                    # Passing c= makes rearrange verify the parent kept the channel count.
                    out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
                    return x + out
         | 
| 278 | 
            +
             | 
| 279 | 
            +
             | 
| 280 | 
            +
            def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
                """
                Build the attention module selected by `attn_type`.

                Args:
                    in_channels: channel count of the feature map the module receives.
                    attn_type: one of "vanilla", "vanilla-xformers",
                        "memory-efficient-cross-attn", "linear" or "none".  "vanilla" is
                        upgraded to "vanilla-xformers" when xformers imported
                        successfully.
                    attn_kwargs: extra constructor arguments, used only for
                        "memory-efficient-cross-attn"; must be None for "vanilla".

                Returns:
                    The constructed nn.Module (nn.Identity for "none").

                Raises:
                    AssertionError: for an unknown `attn_type`, or when `attn_kwargs`
                        is given together with plain "vanilla" attention.
                    NotImplementedError: for "linear", which is accepted by the name
                        check but has no implementation in this function.
                """
                assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
                if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
                    attn_type = "vanilla-xformers"
                print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
                if attn_type == "vanilla":
                    assert attn_kwargs is None
                    return AttnBlock(in_channels)
                elif attn_type == "vanilla-xformers":
                    print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
                    return MemoryEfficientAttnBlock(in_channels)
                elif attn_type == "memory-efficient-cross-attn":
                    # This branch used to compare the builtin `type` with the string,
                    # which is never true, so it was unreachable.  Work on a copy so
                    # the caller's dict is not mutated and a missing attn_kwargs does
                    # not fail on item assignment.
                    cross_attn_kwargs = dict(attn_kwargs) if attn_kwargs is not None else {}
                    cross_attn_kwargs["query_dim"] = in_channels
                    return MemoryEfficientCrossAttentionWrapper(**cross_attn_kwargs)
                elif attn_type == "none":
                    return nn.Identity(in_channels)
                else:
                    raise NotImplementedError()
         | 
| 298 | 
            +
             | 
| 299 | 
            +
             | 
| 300 | 
            +
            class Model(nn.Module):
         | 
| 301 | 
            +
                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                             attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                             resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
                    """Build the timestep MLP, down path, middle stage, up path and output head.

                    Args:
                        ch: base channel count; level ``i`` works with ``ch * ch_mult[i]`` channels.
                        out_ch: number of channels produced by ``conv_out``.
                        ch_mult: per-level channel multipliers; its length sets the number
                            of resolution levels.
                        num_res_blocks: ResnetBlocks per down level (each up level gets one more).
                        attn_resolutions: resolutions at which every ResnetBlock is followed
                            by an attention layer.
                        dropout: forwarded to every ResnetBlock.
                        resamp_with_conv: forwarded to ``Downsample`` / ``Upsample``.
                        in_channels: channel count expected by ``conv_in``.
                        resolution: stored, and used here to track the current resolution
                            while deciding where attention layers go.
                        use_timestep: when true, the ``temb.dense`` layers are created.
                        use_linear_attn: when true, ``attn_type`` is overridden with ``"linear"``.
                        attn_type: forwarded to ``make_attn``.
                    """
                    super().__init__()
                    if use_linear_attn:
                        attn_type = "linear"

                    self.ch = ch
                    self.temb_ch = self.ch * 4
                    self.num_resolutions = len(ch_mult)
                    self.num_res_blocks = num_res_blocks
                    self.resolution = resolution
                    self.in_channels = in_channels
                    self.use_timestep = use_timestep

                    def resnet(c_in, c_out):
                        # Every residual block shares the same embedding width and dropout.
                        return ResnetBlock(in_channels=c_in,
                                           out_channels=c_out,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)

                    # Timestep embedding: two linear layers, ch -> temb_ch -> temb_ch.
                    if self.use_timestep:
                        self.temb = nn.Module()
                        dense_first = torch.nn.Linear(self.ch, self.temb_ch)
                        dense_second = torch.nn.Linear(self.temb_ch, self.temb_ch)
                        self.temb.dense = nn.ModuleList([dense_first, dense_second])

                    # Input convolution.
                    self.conv_in = torch.nn.Conv2d(in_channels, self.ch,
                                                   kernel_size=3, stride=1, padding=1)

                    # Down path. in_ch_mult is ch_mult shifted by one level, so pairing the
                    # two yields (input multiplier, output multiplier) for each level.
                    res = resolution
                    in_ch_mult = (1,) + tuple(ch_mult)
                    final_level = self.num_resolutions - 1
                    self.down = nn.ModuleList()
                    for level, (mult_in, mult_out) in enumerate(zip(in_ch_mult, ch_mult)):
                        stage = nn.Module()
                        stage.block = nn.ModuleList()
                        stage.attn = nn.ModuleList()
                        block_in = ch * mult_in
                        block_out = ch * mult_out
                        for _ in range(self.num_res_blocks):
                            stage.block.append(resnet(block_in, block_out))
                            block_in = block_out
                            if res in attn_resolutions:
                                stage.attn.append(make_attn(block_in, attn_type=attn_type))
                        if level != final_level:
                            stage.downsample = Downsample(block_in, resamp_with_conv)
                            res = res // 2
                        self.down.append(stage)

                    # Middle stage: resnet -> attention -> resnet at the deepest width.
                    self.mid = nn.Module()
                    self.mid.block_1 = resnet(block_in, block_in)
                    self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
                    self.mid.block_2 = resnet(block_in, block_in)

                    # Up path, built from the deepest level outwards.
                    self.up = nn.ModuleList()
                    for level in reversed(range(self.num_resolutions)):
                        stage = nn.Module()
                        stage.block = nn.ModuleList()
                        stage.attn = nn.ModuleList()
                        block_out = ch * ch_mult[level]
                        for idx in range(self.num_res_blocks + 1):
                            # The extra, final block of a level is sized for a skip with the
                            # level's input width; the others for its output width.
                            if idx == self.num_res_blocks:
                                skip_in = ch * in_ch_mult[level]
                            else:
                                skip_in = ch * ch_mult[level]
                            stage.block.append(resnet(block_in + skip_in, block_out))
                            block_in = block_out
                            if res in attn_resolutions:
                                stage.attn.append(make_attn(block_in, attn_type=attn_type))
                        if level != 0:
                            stage.upsample = Upsample(block_in, resamp_with_conv)
                            res = res * 2
                        # Prepend so that self.up[i] lines up with self.down[i].
                        self.up.insert(0, stage)

                    # Output head.
                    self.norm_out = Normalize(block_in)
                    self.conv_out = torch.nn.Conv2d(block_in, out_ch,
                                                    kernel_size=3, stride=1, padding=1)
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                def forward(self, x, t=None, context=None):
                    """Run the down path, middle stage and up path on ``x``.

                    Args:
                        x: input tensor fed to ``conv_in``.
                        t: timesteps; must not be None when ``self.use_timestep`` is set.
                        context: optional tensor concatenated to ``x`` along dim 1
                            (assumed to be spatially aligned with ``x``).

                    Returns:
                        The output of ``conv_out``.
                    """
                    # Note: the input size is not checked against self.resolution.
                    if context is not None:
                        # Assume aligned context, concatenate along the channel axis.
                        x = torch.cat((x, context), dim=1)

                    temb = None
                    if self.use_timestep:
                        assert t is not None
                        dense_first = self.temb.dense[0]
                        dense_second = self.temb.dense[1]
                        temb = dense_first(get_timestep_embedding(t, self.ch))
                        temb = dense_second(nonlinearity(temb))

                    # Down path: every intermediate activation is kept as a skip.
                    skips = [self.conv_in(x)]
                    final_level = self.num_resolutions - 1
                    for level in range(self.num_resolutions):
                        stage = self.down[level]
                        has_attn = len(stage.attn) > 0
                        for idx in range(self.num_res_blocks):
                            h = stage.block[idx](skips[-1], temb)
                            if has_attn:
                                h = stage.attn[idx](h)
                            skips.append(h)
                        if level != final_level:
                            skips.append(stage.downsample(skips[-1]))

                    # Middle stage.
                    h = self.mid.block_1(skips[-1], temb)
                    h = self.mid.attn_1(h)
                    h = self.mid.block_2(h, temb)

                    # Up path: consume the skips in reverse order of creation.
                    for level in reversed(range(self.num_resolutions)):
                        stage = self.up[level]
                        has_attn = len(stage.attn) > 0
                        for idx in range(self.num_res_blocks + 1):
                            merged = torch.cat([h, skips.pop()], dim=1)
                            h = stage.block[idx](merged, temb)
                            if has_attn:
                                h = stage.attn[idx](h)
                        if level != 0:
                            h = stage.upsample(h)

                    # Output head.
                    return self.conv_out(nonlinearity(self.norm_out(h)))
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                def get_last_layer(self):
                    """Return the weight of the final output convolution (``conv_out``)."""
                    output_conv = self.conv_out
                    return output_conv.weight
         | 
| 450 | 
            +
             | 
| 451 | 
            +
             | 
| 452 | 
            +
            class Encoder(nn.Module):
                """Convolutional encoder: input conv, down levels, middle stage, output head.

                No timestep embedding is used (``temb_ch`` is 0 and ``temb`` is None).
                """

                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                             attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                             resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                             **ignore_kwargs):
                    """Create the encoder layers.

                    Args:
                        ch: base channel count; level ``i`` works with ``ch * ch_mult[i]`` channels.
                        out_ch: accepted but not used by this class.
                        ch_mult: per-level channel multipliers; its length sets the number
                            of resolution levels.
                        num_res_blocks: ResnetBlocks per level.
                        attn_resolutions: resolutions at which every ResnetBlock is followed
                            by an attention layer.
                        dropout: forwarded to every ResnetBlock.
                        resamp_with_conv: forwarded to ``Downsample``.
                        in_channels: channel count expected by ``conv_in``.
                        resolution: stored, and used to track the current resolution while
                            deciding where attention layers go.
                        z_channels: ``conv_out`` emits ``2 * z_channels`` channels when
                            ``double_z`` is true, otherwise ``z_channels``.
                        double_z: see ``z_channels``.
                        use_linear_attn: when true, ``attn_type`` is overridden with ``"linear"``.
                        attn_type: forwarded to ``make_attn``.
                        **ignore_kwargs: extra keyword arguments are accepted and ignored.
                    """
                    super().__init__()
                    if use_linear_attn:
                        attn_type = "linear"

                    self.ch = ch
                    self.temb_ch = 0
                    self.num_resolutions = len(ch_mult)
                    self.num_res_blocks = num_res_blocks
                    self.resolution = resolution
                    self.in_channels = in_channels

                    def resnet(c_in, c_out):
                        # Every residual block shares the same embedding width and dropout.
                        return ResnetBlock(in_channels=c_in,
                                           out_channels=c_out,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)

                    # Input convolution.
                    self.conv_in = torch.nn.Conv2d(in_channels, self.ch,
                                                   kernel_size=3, stride=1, padding=1)

                    # Down path. in_ch_mult is ch_mult shifted by one level, so pairing the
                    # two yields (input multiplier, output multiplier) for each level.
                    res = resolution
                    in_ch_mult = (1,) + tuple(ch_mult)
                    self.in_ch_mult = in_ch_mult
                    final_level = self.num_resolutions - 1
                    self.down = nn.ModuleList()
                    for level, (mult_in, mult_out) in enumerate(zip(in_ch_mult, ch_mult)):
                        stage = nn.Module()
                        stage.block = nn.ModuleList()
                        stage.attn = nn.ModuleList()
                        block_in = ch * mult_in
                        block_out = ch * mult_out
                        for _ in range(self.num_res_blocks):
                            stage.block.append(resnet(block_in, block_out))
                            block_in = block_out
                            if res in attn_resolutions:
                                stage.attn.append(make_attn(block_in, attn_type=attn_type))
                        if level != final_level:
                            stage.downsample = Downsample(block_in, resamp_with_conv)
                            res = res // 2
                        self.down.append(stage)

                    # Middle stage: resnet -> attention -> resnet at the deepest width.
                    self.mid = nn.Module()
                    self.mid.block_1 = resnet(block_in, block_in)
                    self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
                    self.mid.block_2 = resnet(block_in, block_in)

                    # Output head.
                    latent_channels = 2 * z_channels if double_z else z_channels
                    self.norm_out = Normalize(block_in)
                    self.conv_out = torch.nn.Conv2d(block_in, latent_channels,
                                                    kernel_size=3, stride=1, padding=1)

                def forward(self, x):
                    """Encode ``x`` and return the output of ``conv_out``."""
                    # No timestep embedding in the encoder.
                    temb = None

                    # Down path: activations are collected in a list and each step
                    # reads the most recent entry.
                    feats = [self.conv_in(x)]
                    final_level = self.num_resolutions - 1
                    for level in range(self.num_resolutions):
                        stage = self.down[level]
                        has_attn = len(stage.attn) > 0
                        for idx in range(self.num_res_blocks):
                            h = stage.block[idx](feats[-1], temb)
                            if has_attn:
                                h = stage.attn[idx](h)
                            feats.append(h)
                        if level != final_level:
                            feats.append(stage.downsample(feats[-1]))

                    # Middle stage.
                    h = self.mid.block_1(feats[-1], temb)
                    h = self.mid.attn_1(h)
                    h = self.mid.block_2(h, temb)

                    # Output head.
                    return self.conv_out(nonlinearity(self.norm_out(h)))
         | 
| 544 | 
            +
             | 
| 545 | 
            +
             | 
| 546 | 
            +
            class Decoder(nn.Module):
                """Convolutional decoder: input conv, middle stage, up levels, output head.

                No timestep embedding is used (``temb_ch`` is 0 and ``temb`` is None).
                """

                def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                             attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                             resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                             attn_type="vanilla", **ignorekwargs):
                    """Create the decoder layers.

                    Args:
                        ch: base channel count; level ``i`` works with ``ch * ch_mult[i]`` channels.
                        out_ch: number of channels produced by ``conv_out``.
                        ch_mult: per-level channel multipliers; its length sets the number
                            of resolution levels.
                        num_res_blocks: each level gets ``num_res_blocks + 1`` ResnetBlocks.
                        attn_resolutions: resolutions at which every ResnetBlock is followed
                            by an attention layer.
                        dropout: forwarded to every ResnetBlock.
                        resamp_with_conv: forwarded to ``Upsample``.
                        in_channels: stored on the instance only.
                        resolution: output resolution; the lowest resolution is derived
                            from it by halving once per level above the first.
                        z_channels: channel count expected by ``conv_in``.
                        give_pre_end: when true, ``forward`` returns before the output head.
                        tanh_out: when true, ``forward`` applies ``tanh`` to the output.
                        use_linear_attn: when true, ``attn_type`` is overridden with ``"linear"``.
                        attn_type: forwarded to ``make_attn``.
                        **ignorekwargs: extra keyword arguments are accepted and ignored.
                    """
                    super().__init__()
                    if use_linear_attn:
                        attn_type = "linear"

                    self.ch = ch
                    self.temb_ch = 0
                    self.num_resolutions = len(ch_mult)
                    self.num_res_blocks = num_res_blocks
                    self.resolution = resolution
                    self.in_channels = in_channels
                    self.give_pre_end = give_pre_end
                    self.tanh_out = tanh_out

                    def resnet(c_in, c_out):
                        # Every residual block shares the same embedding width and dropout.
                        return ResnetBlock(in_channels=c_in,
                                           out_channels=c_out,
                                           temb_channels=self.temb_ch,
                                           dropout=dropout)

                    # Channel width and resolution at the lowest level.
                    deepest = self.num_resolutions - 1
                    block_in = ch * ch_mult[deepest]
                    res = resolution // 2 ** deepest
                    self.z_shape = (1, z_channels, res, res)
                    print("Working with z of shape {} = {} dimensions.".format(
                        self.z_shape, np.prod(self.z_shape)))

                    # Latent to block_in channels.
                    self.conv_in = torch.nn.Conv2d(z_channels, block_in,
                                                   kernel_size=3, stride=1, padding=1)

                    # Middle stage: resnet -> attention -> resnet.
                    self.mid = nn.Module()
                    self.mid.block_1 = resnet(block_in, block_in)
                    self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
                    self.mid.block_2 = resnet(block_in, block_in)

                    # Up path, built from the deepest level outwards.
                    self.up = nn.ModuleList()
                    for level in reversed(range(self.num_resolutions)):
                        stage = nn.Module()
                        stage.block = nn.ModuleList()
                        stage.attn = nn.ModuleList()
                        block_out = ch * ch_mult[level]
                        for _ in range(self.num_res_blocks + 1):
                            stage.block.append(resnet(block_in, block_out))
                            block_in = block_out
                            if res in attn_resolutions:
                                stage.attn.append(make_attn(block_in, attn_type=attn_type))
                        if level != 0:
                            stage.upsample = Upsample(block_in, resamp_with_conv)
                            res = res * 2
                        # Prepend so that self.up is indexed by level.
                        self.up.insert(0, stage)

                    # Output head.
                    self.norm_out = Normalize(block_in)
                    self.conv_out = torch.nn.Conv2d(block_in, out_ch,
                                                    kernel_size=3, stride=1, padding=1)

                def forward(self, z):
                    """Decode ``z``.

                    Records ``z.shape`` in ``self.last_z_shape``. Returns the activations
                    before the output head when ``give_pre_end`` is set; otherwise the
                    output of ``conv_out``, passed through ``tanh`` when ``tanh_out`` is set.
                    """
                    # Note: z is not checked against self.z_shape.
                    self.last_z_shape = z.shape

                    # No timestep embedding in the decoder.
                    temb = None

                    # Latent to block_in channels, then the middle stage.
                    h = self.conv_in(z)
                    h = self.mid.block_1(h, temb)
                    h = self.mid.attn_1(h)
                    h = self.mid.block_2(h, temb)

                    # Up path, deepest level first.
                    for level in reversed(range(self.num_resolutions)):
                        stage = self.up[level]
                        has_attn = len(stage.attn) > 0
                        for idx in range(self.num_res_blocks + 1):
                            h = stage.block[idx](h, temb)
                            if has_attn:
                                h = stage.attn[idx](h)
                        if level != 0:
                            h = stage.upsample(h)

                    if self.give_pre_end:
                        return h

                    # Output head.
                    out = self.conv_out(nonlinearity(self.norm_out(h)))
                    return torch.tanh(out) if self.tanh_out else out
         | 
| 653 | 
            +
             | 
| 654 | 
            +
             | 
| 655 | 
            +
            class SimpleDecoder(nn.Module):
                """Small decoder: 1x1 conv, three ResnetBlocks, 1x1 conv, one Upsample, output head."""

                def __init__(self, in_channels, out_channels, *args, **kwargs):
                    """Create the layers.

                    Args:
                        in_channels: channel count of the input; the ResnetBlocks widen
                            it to 2x, then 4x, then back to 2x before a 1x1 conv restores it.
                        out_channels: number of channels produced by ``conv_out``.
                        *args, **kwargs: accepted and ignored.
                    """
                    super().__init__()
                    # Layers are created in the order in which forward() applies them.
                    proj_in = nn.Conv2d(in_channels, in_channels, 1)
                    widen = ResnetBlock(in_channels=in_channels,
                                        out_channels=2 * in_channels,
                                        temb_channels=0, dropout=0.0)
                    widen_more = ResnetBlock(in_channels=2 * in_channels,
                                             out_channels=4 * in_channels,
                                             temb_channels=0, dropout=0.0)
                    narrow = ResnetBlock(in_channels=4 * in_channels,
                                         out_channels=2 * in_channels,
                                         temb_channels=0, dropout=0.0)
                    proj_out = nn.Conv2d(2 * in_channels, in_channels, 1)
                    upsample = Upsample(in_channels, with_conv=True)
                    self.model = nn.ModuleList(
                        [proj_in, widen, widen_more, narrow, proj_out, upsample])

                    # Output head.
                    self.norm_out = Normalize(in_channels)
                    self.conv_out = torch.nn.Conv2d(in_channels, out_channels,
                                                    kernel_size=3, stride=1, padding=1)

                def forward(self, x):
                    """Apply the layer stack and the output head to ``x``."""
                    for position, layer in enumerate(self.model):
                        # Positions 1-3 hold the ResnetBlocks, which take a second
                        # (embedding) argument; None is passed for it.
                        if 1 <= position <= 3:
                            x = layer(x, None)
                        else:
                            x = layer(x)

                    return self.conv_out(nonlinearity(self.norm_out(x)))
         | 
| 689 | 
            +
             | 
| 690 | 
            +
             | 
| 691 | 
            +
            class UpsampleDecoder(nn.Module):
                """
                Decoder built from groups of ResnetBlocks with an Upsample between groups.

                One group of `num_res_blocks + 1` ResnetBlocks is created per entry of
                `ch_mult`; an `Upsample` follows every group except the last one. The
                result is normalized, passed through `nonlinearity` and projected to
                `out_channels` with a 3x3 convolution.

                :param in_channels: channels of the input tensor.
                :param out_channels: channels produced by the final convolution.
                :param ch: base channel count, multiplied by each entry of `ch_mult`.
                :param num_res_blocks: each level holds `num_res_blocks + 1` ResnetBlocks.
                :param resolution: accepted for interface compatibility; it does not
                                   influence the constructed modules.
                :param ch_mult: per-level channel multipliers.
                :param dropout: dropout rate forwarded to every ResnetBlock.
                """
                def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                             ch_mult=(2,2), dropout=0.0):
                    super().__init__()
                    # upsampling
                    self.temb_ch = 0
                    self.num_resolutions = len(ch_mult)
                    self.num_res_blocks = num_res_blocks
                    self.res_blocks = nn.ModuleList()
                    self.upsample_blocks = nn.ModuleList()
                    block_in = in_channels
                    last_level = self.num_resolutions - 1
                    for i_level, mult in enumerate(ch_mult):
                        block_out = ch * mult
                        level_blocks = []
                        for _ in range(self.num_res_blocks + 1):
                            level_blocks.append(ResnetBlock(in_channels=block_in,
                                                            out_channels=block_out,
                                                            temb_channels=self.temb_ch,
                                                            dropout=dropout))
                            # only the first block of a level changes the channel count
                            block_in = block_out
                        self.res_blocks.append(nn.ModuleList(level_blocks))
                        if i_level != last_level:
                            self.upsample_blocks.append(Upsample(block_in, True))

                    # end
                    self.norm_out = Normalize(block_in)
                    self.conv_out = torch.nn.Conv2d(block_in,
                                                    out_channels,
                                                    kernel_size=3,
                                                    stride=1,
                                                    padding=1)

                def forward(self, x):
                    """
                    Apply every level's ResnetBlocks (called with `None` as second
                    argument), upsampling between levels, then the output head.
                    """
                    # upsampling
                    h = x
                    last_level = self.num_resolutions - 1
                    for i_level, level_blocks in enumerate(self.res_blocks):
                        for block in level_blocks:
                            h = block(h, None)
                        # upsample_blocks has one entry per level except the last
                        if i_level != last_level:
                            h = self.upsample_blocks[i_level](h)
                    h = self.norm_out(h)
                    h = nonlinearity(h)
                    h = self.conv_out(h)
                    return h
         | 
| 737 | 
            +
             | 
| 738 | 
            +
             | 
| 739 | 
            +
            class LatentRescaler(nn.Module):
                """
                Spatially rescale a feature map between two stacks of ResnetBlocks.

                Pipeline: 3x3 input conv -> `depth` ResnetBlocks -> interpolation by
                `factor` -> AttnBlock -> `depth` ResnetBlocks -> 1x1 output conv.

                :param factor: multiplier applied to the height and width.
                :param in_channels: channels of the input tensor.
                :param mid_channels: channels used by all inner blocks.
                :param out_channels: channels produced by the 1x1 output conv.
                :param depth: number of ResnetBlocks on each side of the interpolation.
                """
                def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
                    super().__init__()
                    # residual block, interpolate, residual block
                    self.factor = factor
                    self.conv_in = nn.Conv2d(in_channels, mid_channels,
                                             kernel_size=3, stride=1, padding=1)

                    first_stack = []
                    for _ in range(depth):
                        first_stack.append(ResnetBlock(in_channels=mid_channels,
                                                       out_channels=mid_channels,
                                                       temb_channels=0,
                                                       dropout=0.0))
                    self.res_block1 = nn.ModuleList(first_stack)

                    self.attn = AttnBlock(mid_channels)

                    second_stack = []
                    for _ in range(depth):
                        second_stack.append(ResnetBlock(in_channels=mid_channels,
                                                        out_channels=mid_channels,
                                                        temb_channels=0,
                                                        dropout=0.0))
                    self.res_block2 = nn.ModuleList(second_stack)

                    self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1)

                def forward(self, x):
                    """Return `x` rescaled by `self.factor` along its last two axes."""
                    h = self.conv_in(x)
                    for pre_block in self.res_block1:
                        h = pre_block(h, None)

                    # target size is the rounded product of each spatial extent and the factor
                    new_height = int(round(h.shape[2] * self.factor))
                    new_width = int(round(h.shape[3] * self.factor))
                    h = torch.nn.functional.interpolate(h, size=(new_height, new_width))

                    h = self.attn(h)
                    for post_block in self.res_block2:
                        h = post_block(h, None)
                    return self.conv_out(h)
         | 
| 774 | 
            +
             | 
| 775 | 
            +
             | 
| 776 | 
            +
            class MergedRescaleEncoder(nn.Module):
                """
                An `Encoder` followed by a `LatentRescaler`.

                The encoder emits `ch * ch_mult[-1]` channels (no doubled z); the
                rescaler resizes that output by `rescale_factor` and maps it to `out_ch`
                channels.
                """
                def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
                             attn_resolutions, dropout=0.0, resamp_with_conv=True,
                             ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
                    super().__init__()
                    # channel width shared by the encoder output and the rescaler interior
                    intermediate_chn = ch * ch_mult[-1]
                    encoder_kwargs = dict(
                        in_channels=in_channels,
                        num_res_blocks=num_res_blocks,
                        ch=ch,
                        ch_mult=ch_mult,
                        z_channels=intermediate_chn,
                        double_z=False,
                        resolution=resolution,
                        attn_resolutions=attn_resolutions,
                        dropout=dropout,
                        resamp_with_conv=resamp_with_conv,
                        out_ch=None,
                    )
                    self.encoder = Encoder(**encoder_kwargs)
                    self.rescaler = LatentRescaler(
                        factor=rescale_factor,
                        in_channels=intermediate_chn,
                        mid_channels=intermediate_chn,
                        out_channels=out_ch,
                        depth=rescale_module_depth,
                    )

                def forward(self, x):
                    """Encode `x`, then rescale the encoded features."""
                    return self.rescaler(self.encoder(x))
         | 
| 793 | 
            +
             | 
| 794 | 
            +
             | 
| 795 | 
            +
            class MergedRescaleDecoder(nn.Module):
                """
                A `LatentRescaler` followed by a `Decoder`.

                The rescaler resizes the latent by `rescale_factor` and widens it from
                `z_channels` to `z_channels * ch_mult[-1]` channels, which is the
                `z_channels` value the decoder is built with.
                """
                def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
                             dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
                    super().__init__()
                    # channel count handed from the rescaler to the decoder
                    tmp_chn = z_channels*ch_mult[-1]
                    decoder_kwargs = dict(
                        out_ch=out_ch,
                        z_channels=tmp_chn,
                        attn_resolutions=attn_resolutions,
                        dropout=dropout,
                        resamp_with_conv=resamp_with_conv,
                        in_channels=None,
                        num_res_blocks=num_res_blocks,
                        ch_mult=ch_mult,
                        resolution=resolution,
                        ch=ch,
                    )
                    self.decoder = Decoder(**decoder_kwargs)
                    self.rescaler = LatentRescaler(
                        factor=rescale_factor,
                        in_channels=z_channels,
                        mid_channels=tmp_chn,
                        out_channels=tmp_chn,
                        depth=rescale_module_depth,
                    )

                def forward(self, x):
                    """Rescale the latent `x`, then decode it."""
                    return self.decoder(self.rescaler(x))
         | 
| 810 | 
            +
             | 
| 811 | 
            +
             | 
| 812 | 
            +
            class Upsampler(nn.Module):
                """
                Upsample from `in_size` to `out_size` with a `LatentRescaler` and a `Decoder`.

                The decoder gets `int(log2(out_size // in_size)) + 1` levels, each using
                `ch_mult` as its channel multiplier, and no attention resolutions.
                """
                def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
                    super().__init__()
                    assert out_size >= in_size
                    whole_ratio = out_size//in_size
                    num_blocks = 1 + int(np.log2(whole_ratio))
                    # NOTE(review): this adds the integer remainder to 1.0 rather than a
                    # fractional ratio; kept as-is -- confirm the intended rescale factor.
                    factor_up = 1.+ (out_size % in_size)
                    print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
                    self.rescaler = LatentRescaler(
                        factor=factor_up,
                        in_channels=in_channels,
                        mid_channels=2*in_channels,
                        out_channels=in_channels,
                    )
                    level_multipliers = [ch_mult] * num_blocks
                    self.decoder = Decoder(
                        out_ch=out_channels,
                        resolution=out_size,
                        z_channels=in_channels,
                        num_res_blocks=2,
                        attn_resolutions=[],
                        in_channels=None,
                        ch=in_channels,
                        ch_mult=level_multipliers,
                    )

                def forward(self, x):
                    """Rescale `x`, then decode it."""
                    return self.decoder(self.rescaler(x))
         | 
| 829 | 
            +
             | 
| 830 | 
            +
             | 
| 831 | 
            +
            class Resize(nn.Module):
                """
                Resize a tensor by a scale factor using fixed-mode interpolation.

                :param in_channels: only meaningful for the learned variant, which is not
                                    implemented.
                :param learned: if True, construction prints a note and raises
                                NotImplementedError.
                :param mode: interpolation mode passed to
                             `torch.nn.functional.interpolate`.
                :raises NotImplementedError: when `learned` is True.
                """
                def __init__(self, in_channels=None, learned=False, mode="bilinear"):
                    super().__init__()
                    self.with_conv = learned
                    self.mode = mode
                    if self.with_conv:
                        # `__name__` (with trailing underscores) is required here: the
                        # previous `__name` spelling raised AttributeError before the
                        # NotImplementedError below could be reached.
                        print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
                        # The learned path was sketched as a Conv2d(in_channels, in_channels,
                        # kernel_size=4, stride=2, padding=1) but was never enabled.
                        raise NotImplementedError()

                def forward(self, x, scale_factor=1.0):
                    """
                    Return `x` unchanged when `scale_factor` equals 1.0, otherwise the
                    interpolated tensor (`align_corners=False`).
                    """
                    if scale_factor == 1.0:
                        return x
                    return torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
         | 
    	
        ldm/modules/diffusionmodules/openaimodel.py
    ADDED
    
    | @@ -0,0 +1,786 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from abc import abstractmethod
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import torch as th
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            import torch.nn.functional as F
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from ldm.modules.diffusionmodules.util import (
         | 
| 10 | 
            +
                checkpoint,
         | 
| 11 | 
            +
                conv_nd,
         | 
| 12 | 
            +
                linear,
         | 
| 13 | 
            +
                avg_pool_nd,
         | 
| 14 | 
            +
                zero_module,
         | 
| 15 | 
            +
                normalization,
         | 
| 16 | 
            +
                timestep_embedding,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
            from ldm.modules.attention import SpatialTransformer
         | 
| 19 | 
            +
            from ldm.util import exists
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            # dummy replace
         | 
| 23 | 
            +
            def convert_module_to_f16(x):
                """No-op placeholder: accepts `x` and leaves it untouched."""
                return None
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            def convert_module_to_f32(x):
                """No-op placeholder: accepts `x` and leaves it untouched."""
                return None
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            ## go
         | 
| 31 | 
            +
            class AttentionPool2d(nn.Module):
                """
                Attention pooling over flattened positions.

                Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
                """

                def __init__(
                    self,
                    spacial_dim: int,
                    embed_dim: int,
                    num_heads_channels: int,
                    output_dim: int = None,
                ):
                    super().__init__()
                    # one slot per flattened position plus one for the prepended mean token
                    token_count = spacial_dim ** 2 + 1
                    init_scale = embed_dim ** 0.5
                    self.positional_embedding = nn.Parameter(th.randn(embed_dim, token_count) / init_scale)
                    self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
                    self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
                    self.num_heads = embed_dim // num_heads_channels
                    self.attention = QKVAttention(self.num_heads)

                def forward(self, x):
                    """
                    Flatten the trailing axes of `x`, prepend their mean, add the
                    positional embedding, attend, and return the slice at index 0 of
                    the last axis.
                    """
                    b, c, *_spatial = x.shape
                    tokens = x.reshape(b, c, -1)  # NC(HW)
                    mean_token = tokens.mean(dim=-1, keepdim=True)
                    tokens = th.cat([mean_token, tokens], dim=-1)  # NC(HW+1)
                    tokens = tokens + self.positional_embedding[None, :, :].to(tokens.dtype)  # NC(HW+1)
                    attended = self.attention(self.qkv_proj(tokens))
                    projected = self.c_proj(attended)
                    return projected[:, :, 0]
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            class TimestepBlock(nn.Module):
                """
                Base class for modules whose forward() accepts timestep embeddings as
                its second argument.
                """

                @abstractmethod
                def forward(self, x, emb):
                    """
                    Apply the module to `x`, conditioned on the timestep embeddings `emb`.
                    """
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
                """
                Sequential container that hands each child the extra input it supports:
                timestep embeddings for TimestepBlock children, `context` for
                SpatialTransformer children, and nothing extra for all others.
                """

                def forward(self, x, emb, context=None):
                    for child in self:
                        # TimestepBlock is tested first, matching the dispatch priority
                        if isinstance(child, TimestepBlock):
                            extra_args = (emb,)
                        elif isinstance(child, SpatialTransformer):
                            extra_args = (context,)
                        else:
                            extra_args = ()
                        x = child(x, *extra_args)
                    return x
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            class Upsample(nn.Module):
                """
                2x nearest-neighbour upsampling, optionally followed by a convolution.
                :param channels: channels in the inputs and outputs.
                :param use_conv: a bool determining if a convolution is applied.
                :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                             upsampling occurs in the inner-two dimensions.
                :param out_channels: channels produced by the convolution; falls back to
                             `channels` when not given.
                :param padding: padding passed to the convolution.
                """

                def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
                    super().__init__()
                    self.channels = channels
                    self.out_channels = out_channels or channels
                    self.use_conv = use_conv
                    self.dims = dims
                    if not use_conv:
                        return
                    self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

                def forward(self, x):
                    assert x.shape[1] == self.channels
                    if self.dims != 3:
                        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
                    else:
                        # keep axis 2 as-is and double only the last two axes
                        target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
                        upsampled = F.interpolate(x, target_size, mode="nearest")
                    return self.conv(upsampled) if self.use_conv else upsampled
         | 
| 119 | 
            +
             | 
| 120 | 
            +
            class TransposedUpsample(nn.Module):
                """Learned 2x upsampling without padding (stride-2 transposed convolution)."""
                def __init__(self, channels, out_channels=None, ks=5):
                    super().__init__()
                    self.channels = channels
                    # fall back to the input channel count when no output width is given
                    self.out_channels = out_channels or channels

                    self.up = nn.ConvTranspose2d(
                        self.channels,
                        self.out_channels,
                        kernel_size=ks,
                        stride=2,
                    )

                def forward(self,x):
                    upsampled = self.up(x)
                    return upsampled
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            class Downsample(nn.Module):
                """
                Stride-2 downsampling via a convolution or average pooling.
                :param channels: channels in the inputs and outputs.
                :param use_conv: a bool determining if a convolution is applied.
                :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                             downsampling occurs in the inner-two dimensions.
                :param out_channels: channels produced by the convolution; falls back to
                             `channels` when not given.
                :param padding: padding passed to the convolution.
                """

                def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
                    super().__init__()
                    self.channels = channels
                    self.out_channels = out_channels or channels
                    self.use_conv = use_conv
                    self.dims = dims
                    # for 3D signals the first non-channel axis keeps stride 1
                    stride = (1, 2, 2) if dims == 3 else 2
                    if not use_conv:
                        # pooling cannot change the channel count
                        assert self.channels == self.out_channels
                        self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
                    else:
                        self.op = conv_nd(
                            dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
                        )

                def forward(self, x):
                    assert x.shape[1] == self.channels
                    downsampled = self.op(x)
                    return downsampled
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            class ResBlock(TimestepBlock):
                """
                Residual block conditioned on a timestep embedding; it may change the
                channel count and may resample the signal.

                :param channels: the number of input channels.
                :param emb_channels: the number of timestep embedding channels.
                :param dropout: the rate of dropout.
                :param out_channels: if specified, the number of out channels
                    (defaults to ``channels``).
                :param use_conv: if True and the channel count changes, the skip
                    connection uses a kernel-size-3 convolution instead of a
                    kernel-size-1 convolution.
                :param use_scale_shift_norm: if True, the embedding is projected to
                    ``2 * out_channels`` and split into a scale and a shift that
                    modulate the normalized activations; otherwise the embedding is
                    simply added.
                :param dims: determines if the signal is 1D, 2D, or 3D.
                :param use_checkpoint: if True, use gradient checkpointing on this module.
                :param up: if True, use this block for upsampling.
                :param down: if True, use this block for downsampling (only consulted
                    when ``up`` is falsy).
                """

                def __init__(
                    self,
                    channels,
                    emb_channels,
                    dropout,
                    out_channels=None,
                    use_conv=False,
                    use_scale_shift_norm=False,
                    dims=2,
                    use_checkpoint=False,
                    up=False,
                    down=False,
                ):
                    super().__init__()
                    self.channels = channels
                    self.emb_channels = emb_channels
                    self.dropout = dropout
                    self.out_channels = out_channels or channels
                    self.use_conv = use_conv
                    self.use_checkpoint = use_checkpoint
                    self.use_scale_shift_norm = use_scale_shift_norm
                    width = self.out_channels

                    # NOTE: submodules are created in a fixed order (in_layers,
                    # resamplers, emb_layers, out_layers, skip_connection) so that
                    # parameter ordering and random initialisation stay reproducible.
                    self.in_layers = nn.Sequential(
                        normalization(channels),
                        nn.SiLU(),
                        conv_nd(dims, channels, width, 3, padding=1),
                    )

                    self.updown = up or down

                    # `up` takes precedence over `down`; both branches of the block
                    # (hidden state and skip input) get their own resampler instance.
                    if up:
                        resampler = Upsample
                    elif down:
                        resampler = Downsample
                    else:
                        resampler = None

                    if resampler is None:
                        self.h_upd = self.x_upd = nn.Identity()
                    else:
                        self.h_upd = resampler(channels, False, dims)
                        self.x_upd = resampler(channels, False, dims)

                    emb_width = 2 * width if use_scale_shift_norm else width
                    self.emb_layers = nn.Sequential(
                        nn.SiLU(),
                        linear(emb_channels, emb_width),
                    )
                    # The final convolution is wrapped in zero_module (presumably
                    # zero-initialised; confirm against the util module).
                    self.out_layers = nn.Sequential(
                        normalization(width),
                        nn.SiLU(),
                        nn.Dropout(p=dropout),
                        zero_module(conv_nd(dims, width, width, 3, padding=1)),
                    )

                    if width == channels:
                        self.skip_connection = nn.Identity()
                    elif use_conv:
                        self.skip_connection = conv_nd(dims, channels, width, 3, padding=1)
                    else:
                        self.skip_connection = conv_nd(dims, channels, width, 1)

                def forward(self, x, emb):
                    """
                    Run the block on a Tensor, conditioned on a timestep embedding.

                    :param x: an [N x C x ...] Tensor of features.
                    :param emb: an [N x emb_channels] Tensor of timestep embeddings.
                    :return: an [N x C x ...] Tensor of outputs.
                    """
                    inputs = (x, emb)
                    return checkpoint(
                        self._forward, inputs, self.parameters(), self.use_checkpoint
                    )

                def _forward(self, x, emb):
                    """Actual computation; invoked through ``checkpoint`` by ``forward``."""
                    if not self.updown:
                        h = self.in_layers(x)
                    else:
                        # Resample between the norm/activation and the convolution,
                        # and resample the skip input the same way.
                        pre_conv = self.in_layers[:-1]
                        conv = self.in_layers[-1]
                        h = self.h_upd(pre_conv(x))
                        x = self.x_upd(x)
                        h = conv(h)

                    emb_out = self.emb_layers(emb).type(h.dtype)
                    # Append trailing singleton axes until the embedding has as many
                    # dimensions as the feature map, so it broadcasts over them.
                    for _ in range(len(h.shape) - len(emb_out.shape)):
                        emb_out = emb_out[..., None]

                    if not self.use_scale_shift_norm:
                        h = self.out_layers(h + emb_out)
                    else:
                        norm = self.out_layers[0]
                        tail = self.out_layers[1:]
                        scale, shift = th.chunk(emb_out, 2, dim=1)
                        h = tail(norm(h) * (1 + scale) + shift)

                    return self.skip_connection(x) + h
         | 
| 275 | 
            +
             | 
| 276 | 
            +
             | 
| 277 | 
            +
            class AttentionBlock(nn.Module):
                """
                Self-attention over all spatial positions of a feature map, with a
                residual connection.

                Originally ported from here, but adapted to the N-d case.
                https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

                :param channels: number of channels of the input (and output).
                :param num_heads: number of attention heads; only used when
                    ``num_head_channels`` is -1.
                :param num_head_channels: if not -1, the channel width of each head; the
                    head count is then ``channels // num_head_channels``.
                :param use_checkpoint: stored on the module. NOTE: ``forward`` currently
                    always enables checkpointing regardless of this value.
                :param use_new_attention_order: if True use ``QKVAttention`` (split qkv
                    before heads), otherwise ``QKVAttentionLegacy`` (split heads before
                    qkv).
                """

                def __init__(
                    self,
                    channels,
                    num_heads=1,
                    num_head_channels=-1,
                    use_checkpoint=False,
                    use_new_attention_order=False,
                ):
                    super().__init__()
                    self.channels = channels
                    if num_head_channels != -1:
                        assert (
                            channels % num_head_channels == 0
                        ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
                        self.num_heads = channels // num_head_channels
                    else:
                        self.num_heads = num_heads
                    self.use_checkpoint = use_checkpoint
                    self.norm = normalization(channels)
                    # One projection yields q, k and v stacked along the channel axis.
                    self.qkv = conv_nd(1, channels, channels * 3, 1)
                    attention_cls = (
                        QKVAttention if use_new_attention_order else QKVAttentionLegacy
                    )
                    self.attention = attention_cls(self.num_heads)

                    self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

                def forward(self, x):
                    """Run ``_forward`` through ``checkpoint`` (flag hard-coded to True)."""
                    # TODO: check checkpoint usage, is True
                    # TODO: fix the .half call!!!
                    always_checkpoint = True
                    return checkpoint(
                        self._forward, (x,), self.parameters(), always_checkpoint
                    )

                def _forward(self, x):
                    """Flatten the spatial axes, attend, and restore the input shape."""
                    batch, chans, *spatial_shape = x.shape
                    flat = x.reshape(batch, chans, -1)
                    projected = self.qkv(self.norm(flat))
                    attended = self.proj_out(self.attention(projected))
                    return (flat + attended).reshape(batch, chans, *spatial_shape)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
             | 
| 326 | 
            +
            def count_flops_attn(model, _x, y):
                """
                Operation counter for the `thop` package covering one attention call.

                Intended usage:
                    macs, params = thop.profile(
                        model,
                        inputs=(inputs, timestamps),
                        custom_ops={QKVAttention: QKVAttention.count_flops},
                    )

                :param model: object whose ``total_ops`` attribute is incremented.
                :param _x: unused (the module inputs supplied by the profiler).
                :param y: sequence whose first element is an [N x C x ...] tensor.
                """
                batch, channels, *spatial_dims = y[0].shape
                positions = int(np.prod(spatial_dims))
                # Attention consists of two matmuls of identical cost: one builds the
                # weight matrix, the other mixes the value vectors with it.
                ops_for_both_matmuls = 2 * batch * (positions ** 2) * channels
                model.total_ops += th.DoubleTensor([ops_for_both_matmuls])
         | 
| 344 | 
            +
             | 
| 345 | 
            +
             | 
| 346 | 
            +
            class QKVAttentionLegacy(nn.Module):
                """
                QKV attention that splits heads first and q/k/v second. Matches legacy
                QKVAttention + input/output heads shaping.
                """

                def __init__(self, n_heads):
                    super().__init__()
                    self.n_heads = n_heads

                def forward(self, qkv):
                    """
                    Apply QKV attention.

                    :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
                    :return: an [N x (H * C) x T] tensor after attention.
                    """
                    bs, width, length = qkv.shape
                    heads = self.n_heads
                    assert width % (3 * heads) == 0
                    ch = width // (3 * heads)
                    # Fold heads into the batch axis, then cut each head into q, k, v.
                    per_head = qkv.reshape(bs * heads, ch * 3, length)
                    q, k, v = per_head.split(ch, dim=1)
                    # Scaling q and k separately (each by ch ** -1/4) is more stable
                    # with f16 than dividing the product afterwards.
                    scale = 1 / math.sqrt(math.sqrt(ch))
                    scores = th.einsum("bct,bcs->bts", q * scale, k * scale)
                    # Softmax is evaluated in float32, then cast back.
                    probs = th.softmax(scores.float(), dim=-1).type(scores.dtype)
                    mixed = th.einsum("bts,bcs->bct", probs, v)
                    return mixed.reshape(bs, -1, length)

                @staticmethod
                def count_flops(model, _x, y):
                    """Delegate to ``count_flops_attn`` (hook for the `thop` profiler)."""
                    return count_flops_attn(model, _x, y)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
             | 
| 378 | 
            +
            class QKVAttention(nn.Module):
                """
                QKV attention that splits q/k/v first and heads second (the opposite
                order to ``QKVAttentionLegacy``).
                """

                def __init__(self, n_heads):
                    super().__init__()
                    self.n_heads = n_heads

                def forward(self, qkv):
                    """
                    Apply QKV attention.

                    :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
                    :return: an [N x (H * C) x T] tensor after attention.
                    """
                    bs, width, length = qkv.shape
                    heads = self.n_heads
                    assert width % (3 * heads) == 0
                    ch = width // (3 * heads)
                    q, k, v = qkv.chunk(3, dim=1)
                    # Heads are folded into the batch axis for the matmuls.
                    head_shape = (bs * heads, ch, length)
                    # Scaling q and k separately (each by ch ** -1/4) is more stable
                    # with f16 than dividing the product afterwards.
                    scale = 1 / math.sqrt(math.sqrt(ch))
                    scores = th.einsum(
                        "bct,bcs->bts",
                        (q * scale).view(*head_shape),
                        (k * scale).view(*head_shape),
                    )
                    # Softmax is evaluated in float32, then cast back.
                    probs = th.softmax(scores.float(), dim=-1).type(scores.dtype)
                    mixed = th.einsum("bts,bcs->bct", probs, v.reshape(*head_shape))
                    return mixed.reshape(bs, -1, length)

                @staticmethod
                def count_flops(model, _x, y):
                    """Delegate to ``count_flops_attn`` (hook for the `thop` profiler)."""
                    return count_flops_attn(model, _x, y)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
             | 
| 412 | 
            +
            class UNetModel(nn.Module):
         | 
| 413 | 
            +
                """
         | 
| 414 | 
            +
                The full UNet model with attention and timestep embedding.
         | 
| 415 | 
            +
                :param in_channels: channels in the input Tensor.
         | 
| 416 | 
            +
                :param model_channels: base channel count for the model.
         | 
| 417 | 
            +
                :param out_channels: channels in the output Tensor.
         | 
| 418 | 
            +
                :param num_res_blocks: number of residual blocks per downsample.
         | 
| 419 | 
            +
                :param attention_resolutions: a collection of downsample rates at which
         | 
| 420 | 
            +
                    attention will take place. May be a set, list, or tuple.
         | 
| 421 | 
            +
                    For example, if this contains 4, then at 4x downsampling, attention
         | 
| 422 | 
            +
                    will be used.
         | 
| 423 | 
            +
                :param dropout: the dropout probability.
         | 
| 424 | 
            +
                :param channel_mult: channel multiplier for each level of the UNet.
         | 
| 425 | 
            +
                :param conv_resample: if True, use learned convolutions for upsampling and
         | 
| 426 | 
            +
                    downsampling.
         | 
| 427 | 
            +
                :param dims: determines if the signal is 1D, 2D, or 3D.
         | 
| 428 | 
            +
                :param num_classes: if specified (as an int), then this model will be
         | 
| 429 | 
            +
                    class-conditional with `num_classes` classes.
         | 
| 430 | 
            +
                :param use_checkpoint: use gradient checkpointing to reduce memory usage.
         | 
| 431 | 
            +
                :param num_heads: the number of attention heads in each attention layer.
         | 
| 432 | 
            +
                :param num_heads_channels: if specified, ignore num_heads and instead use
         | 
| 433 | 
            +
                                           a fixed channel width per attention head.
         | 
| 434 | 
            +
                :param num_heads_upsample: works with num_heads to set a different number
         | 
| 435 | 
            +
                                           of heads for upsampling. Deprecated.
         | 
| 436 | 
            +
                :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
         | 
| 437 | 
            +
                :param resblock_updown: use residual blocks for up/downsampling.
         | 
| 438 | 
            +
                :param use_new_attention_order: use a different attention pattern for potentially
         | 
| 439 | 
            +
                                                increased efficiency.
         | 
| 440 | 
            +
                """
         | 
| 441 | 
            +
             | 
| 442 | 
            +
                def __init__(
         | 
| 443 | 
            +
                    self,
         | 
| 444 | 
            +
                    image_size,
         | 
| 445 | 
            +
                    in_channels,
         | 
| 446 | 
            +
                    model_channels,
         | 
| 447 | 
            +
                    out_channels,
         | 
| 448 | 
            +
                    num_res_blocks,
         | 
| 449 | 
            +
                    attention_resolutions,
         | 
| 450 | 
            +
                    dropout=0,
         | 
| 451 | 
            +
                    channel_mult=(1, 2, 4, 8),
         | 
| 452 | 
            +
                    conv_resample=True,
         | 
| 453 | 
            +
                    dims=2,
         | 
| 454 | 
            +
                    num_classes=None,
         | 
| 455 | 
            +
                    use_checkpoint=False,
         | 
| 456 | 
            +
                    use_fp16=False,
         | 
| 457 | 
            +
                    num_heads=-1,
         | 
| 458 | 
            +
                    num_head_channels=-1,
         | 
| 459 | 
            +
                    num_heads_upsample=-1,
         | 
| 460 | 
            +
                    use_scale_shift_norm=False,
         | 
| 461 | 
            +
                    resblock_updown=False,
         | 
| 462 | 
            +
                    use_new_attention_order=False,
         | 
| 463 | 
            +
                    use_spatial_transformer=False,    # custom transformer support
         | 
| 464 | 
            +
                    transformer_depth=1,              # custom transformer support
         | 
| 465 | 
            +
                    context_dim=None,                 # custom transformer support
         | 
| 466 | 
            +
                    n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
         | 
| 467 | 
            +
                    legacy=True,
         | 
| 468 | 
            +
                    disable_self_attentions=None,
         | 
| 469 | 
            +
                    num_attention_blocks=None,
         | 
| 470 | 
            +
                    disable_middle_self_attn=False,
         | 
| 471 | 
            +
                    use_linear_in_transformer=False,
         | 
| 472 | 
            +
                ):
         | 
| 473 | 
            +
                    super().__init__()
         | 
| 474 | 
            +
                    if use_spatial_transformer:
         | 
| 475 | 
            +
                        assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    if context_dim is not None:
         | 
| 478 | 
            +
                        assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
         | 
| 479 | 
            +
                        from omegaconf.listconfig import ListConfig
         | 
| 480 | 
            +
                        if type(context_dim) == ListConfig:
         | 
| 481 | 
            +
                            context_dim = list(context_dim)
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                    if num_heads_upsample == -1:
         | 
| 484 | 
            +
                        num_heads_upsample = num_heads
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                    if num_heads == -1:
         | 
| 487 | 
            +
                        assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    if num_head_channels == -1:
         | 
| 490 | 
            +
                        assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                    self.image_size = image_size
         | 
| 493 | 
            +
                    self.in_channels = in_channels
         | 
| 494 | 
            +
                    self.model_channels = model_channels
         | 
| 495 | 
            +
                    self.out_channels = out_channels
         | 
| 496 | 
            +
                    if isinstance(num_res_blocks, int):
         | 
| 497 | 
            +
                        self.num_res_blocks = len(channel_mult) * [num_res_blocks]
         | 
| 498 | 
            +
                    else:
         | 
| 499 | 
            +
                        if len(num_res_blocks) != len(channel_mult):
         | 
| 500 | 
            +
                            raise ValueError("provide num_res_blocks either as an int (globally constant) or "
         | 
| 501 | 
            +
                                             "as a list/tuple (per-level) with the same length as channel_mult")
         | 
| 502 | 
            +
                        self.num_res_blocks = num_res_blocks
         | 
| 503 | 
            +
                    if disable_self_attentions is not None:
         | 
| 504 | 
            +
                        # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
         | 
| 505 | 
            +
                        assert len(disable_self_attentions) == len(channel_mult)
         | 
| 506 | 
            +
                    if num_attention_blocks is not None:
         | 
| 507 | 
            +
                        assert len(num_attention_blocks) == len(self.num_res_blocks)
         | 
| 508 | 
            +
                        assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
         | 
| 509 | 
            +
                        print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
         | 
| 510 | 
            +
                              f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
         | 
| 511 | 
            +
                              f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
         | 
| 512 | 
            +
                              f"attention will still not be set.")
         | 
| 513 | 
            +
             | 
| 514 | 
            +
                    self.attention_resolutions = attention_resolutions
         | 
| 515 | 
            +
                    self.dropout = dropout
         | 
| 516 | 
            +
                    self.channel_mult = channel_mult
         | 
| 517 | 
            +
                    self.conv_resample = conv_resample
         | 
| 518 | 
            +
                    self.num_classes = num_classes
         | 
| 519 | 
            +
                    self.use_checkpoint = use_checkpoint
         | 
| 520 | 
            +
                    self.dtype = th.float16 if use_fp16 else th.float32
         | 
| 521 | 
            +
                    self.num_heads = num_heads
         | 
| 522 | 
            +
                    self.num_head_channels = num_head_channels
         | 
| 523 | 
            +
                    self.num_heads_upsample = num_heads_upsample
         | 
| 524 | 
            +
                    self.predict_codebook_ids = n_embed is not None
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                    time_embed_dim = model_channels * 4
         | 
| 527 | 
            +
                    self.time_embed = nn.Sequential(
         | 
| 528 | 
            +
                        linear(model_channels, time_embed_dim),
         | 
| 529 | 
            +
                        nn.SiLU(),
         | 
| 530 | 
            +
                        linear(time_embed_dim, time_embed_dim),
         | 
| 531 | 
            +
                    )
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                    if self.num_classes is not None:
         | 
| 534 | 
            +
                        if isinstance(self.num_classes, int):
         | 
| 535 | 
            +
                            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
         | 
| 536 | 
            +
                        elif self.num_classes == "continuous":
         | 
| 537 | 
            +
                            print("setting up linear c_adm embedding layer")
         | 
| 538 | 
            +
                            self.label_emb = nn.Linear(1, time_embed_dim)
         | 
| 539 | 
            +
                        else:
         | 
| 540 | 
            +
                            raise ValueError()
         | 
| 541 | 
            +
             | 
| 542 | 
            +
                    self.input_blocks = nn.ModuleList(
         | 
| 543 | 
            +
                        [
         | 
| 544 | 
            +
                            TimestepEmbedSequential(
         | 
| 545 | 
            +
                                conv_nd(dims, in_channels, model_channels, 3, padding=1)
         | 
| 546 | 
            +
                            )
         | 
| 547 | 
            +
                        ]
         | 
| 548 | 
            +
                    )
         | 
| 549 | 
            +
                    self._feature_size = model_channels
         | 
| 550 | 
            +
                    input_block_chans = [model_channels]
         | 
| 551 | 
            +
                    ch = model_channels
         | 
| 552 | 
            +
                    ds = 1
         | 
| 553 | 
            +
                    for level, mult in enumerate(channel_mult):
         | 
| 554 | 
            +
                        for nr in range(self.num_res_blocks[level]):
         | 
| 555 | 
            +
                            layers = [
         | 
| 556 | 
            +
                                ResBlock(
         | 
| 557 | 
            +
                                    ch,
         | 
| 558 | 
            +
                                    time_embed_dim,
         | 
| 559 | 
            +
                                    dropout,
         | 
| 560 | 
            +
                                    out_channels=mult * model_channels,
         | 
| 561 | 
            +
                                    dims=dims,
         | 
| 562 | 
            +
                                    use_checkpoint=use_checkpoint,
         | 
| 563 | 
            +
                                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 564 | 
            +
                                )
         | 
| 565 | 
            +
                            ]
         | 
| 566 | 
            +
                            ch = mult * model_channels
         | 
| 567 | 
            +
                            if ds in attention_resolutions:
         | 
| 568 | 
            +
                                if num_head_channels == -1:
         | 
| 569 | 
            +
                                    dim_head = ch // num_heads
         | 
| 570 | 
            +
                                else:
         | 
| 571 | 
            +
                                    num_heads = ch // num_head_channels
         | 
| 572 | 
            +
                                    dim_head = num_head_channels
         | 
| 573 | 
            +
                                if legacy:
         | 
| 574 | 
            +
                                    #num_heads = 1
         | 
| 575 | 
            +
                                    dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         | 
| 576 | 
            +
                                if exists(disable_self_attentions):
         | 
| 577 | 
            +
                                    disabled_sa = disable_self_attentions[level]
         | 
| 578 | 
            +
                                else:
         | 
| 579 | 
            +
                                    disabled_sa = False
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                                if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
         | 
| 582 | 
            +
                                    layers.append(
         | 
| 583 | 
            +
                                        AttentionBlock(
         | 
| 584 | 
            +
                                            ch,
         | 
| 585 | 
            +
                                            use_checkpoint=use_checkpoint,
         | 
| 586 | 
            +
                                            num_heads=num_heads,
         | 
| 587 | 
            +
                                            num_head_channels=dim_head,
         | 
| 588 | 
            +
                                            use_new_attention_order=use_new_attention_order,
         | 
| 589 | 
            +
                                        ) if not use_spatial_transformer else SpatialTransformer(
         | 
| 590 | 
            +
                                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
         | 
| 591 | 
            +
                                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
         | 
| 592 | 
            +
                                            use_checkpoint=use_checkpoint
         | 
| 593 | 
            +
                                        )
         | 
| 594 | 
            +
                                    )
         | 
| 595 | 
            +
                            self.input_blocks.append(TimestepEmbedSequential(*layers))
         | 
| 596 | 
            +
                            self._feature_size += ch
         | 
| 597 | 
            +
                            input_block_chans.append(ch)
         | 
| 598 | 
            +
                        if level != len(channel_mult) - 1:
         | 
| 599 | 
            +
                            out_ch = ch
         | 
| 600 | 
            +
                            self.input_blocks.append(
         | 
| 601 | 
            +
                                TimestepEmbedSequential(
         | 
| 602 | 
            +
                                    ResBlock(
         | 
| 603 | 
            +
                                        ch,
         | 
| 604 | 
            +
                                        time_embed_dim,
         | 
| 605 | 
            +
                                        dropout,
         | 
| 606 | 
            +
                                        out_channels=out_ch,
         | 
| 607 | 
            +
                                        dims=dims,
         | 
| 608 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 609 | 
            +
                                        use_scale_shift_norm=use_scale_shift_norm,
         | 
| 610 | 
            +
                                        down=True,
         | 
| 611 | 
            +
                                    )
         | 
| 612 | 
            +
                                    if resblock_updown
         | 
| 613 | 
            +
                                    else Downsample(
         | 
| 614 | 
            +
                                        ch, conv_resample, dims=dims, out_channels=out_ch
         | 
| 615 | 
            +
                                    )
         | 
| 616 | 
            +
                                )
         | 
| 617 | 
            +
                            )
         | 
| 618 | 
            +
                            ch = out_ch
         | 
| 619 | 
            +
                            input_block_chans.append(ch)
         | 
| 620 | 
            +
                            ds *= 2
         | 
| 621 | 
            +
                            self._feature_size += ch
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                    if num_head_channels == -1:
         | 
| 624 | 
            +
                        dim_head = ch // num_heads
         | 
| 625 | 
            +
                    else:
         | 
| 626 | 
            +
                        num_heads = ch // num_head_channels
         | 
| 627 | 
            +
                        dim_head = num_head_channels
         | 
| 628 | 
            +
                    if legacy:
         | 
| 629 | 
            +
                        #num_heads = 1
         | 
| 630 | 
            +
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         | 
| 631 | 
            +
                    self.middle_block = TimestepEmbedSequential(
         | 
| 632 | 
            +
                        ResBlock(
         | 
| 633 | 
            +
                            ch,
         | 
| 634 | 
            +
                            time_embed_dim,
         | 
| 635 | 
            +
                            dropout,
         | 
| 636 | 
            +
                            dims=dims,
         | 
| 637 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 638 | 
            +
                            use_scale_shift_norm=use_scale_shift_norm,
         | 
| 639 | 
            +
                        ),
         | 
| 640 | 
            +
                        AttentionBlock(
         | 
| 641 | 
            +
                            ch,
         | 
| 642 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 643 | 
            +
                            num_heads=num_heads,
         | 
| 644 | 
            +
                            num_head_channels=dim_head,
         | 
| 645 | 
            +
                            use_new_attention_order=use_new_attention_order,
         | 
| 646 | 
            +
                        ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
         | 
| 647 | 
            +
                                        ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
         | 
| 648 | 
            +
                                        disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
         | 
| 649 | 
            +
                                        use_checkpoint=use_checkpoint
         | 
| 650 | 
            +
                                    ),
         | 
| 651 | 
            +
                        ResBlock(
         | 
| 652 | 
            +
                            ch,
         | 
| 653 | 
            +
                            time_embed_dim,
         | 
| 654 | 
            +
                            dropout,
         | 
| 655 | 
            +
                            dims=dims,
         | 
| 656 | 
            +
                            use_checkpoint=use_checkpoint,
         | 
| 657 | 
            +
                            use_scale_shift_norm=use_scale_shift_norm,
         | 
| 658 | 
            +
                        ),
         | 
| 659 | 
            +
                    )
         | 
| 660 | 
            +
                    self._feature_size += ch
         | 
| 661 | 
            +
             | 
| 662 | 
            +
                    self.output_blocks = nn.ModuleList([])
         | 
| 663 | 
            +
                    for level, mult in list(enumerate(channel_mult))[::-1]:
         | 
| 664 | 
            +
                        for i in range(self.num_res_blocks[level] + 1):
         | 
| 665 | 
            +
                            ich = input_block_chans.pop()
         | 
| 666 | 
            +
                            layers = [
         | 
| 667 | 
            +
                                ResBlock(
         | 
| 668 | 
            +
                                    ch + ich,
         | 
| 669 | 
            +
                                    time_embed_dim,
         | 
| 670 | 
            +
                                    dropout,
         | 
| 671 | 
            +
                                    out_channels=model_channels * mult,
         | 
| 672 | 
            +
                                    dims=dims,
         | 
| 673 | 
            +
                                    use_checkpoint=use_checkpoint,
         | 
| 674 | 
            +
                                    use_scale_shift_norm=use_scale_shift_norm,
         | 
| 675 | 
            +
                                )
         | 
| 676 | 
            +
                            ]
         | 
| 677 | 
            +
                            ch = model_channels * mult
         | 
| 678 | 
            +
                            if ds in attention_resolutions:
         | 
| 679 | 
            +
                                if num_head_channels == -1:
         | 
| 680 | 
            +
                                    dim_head = ch // num_heads
         | 
| 681 | 
            +
                                else:
         | 
| 682 | 
            +
                                    num_heads = ch // num_head_channels
         | 
| 683 | 
            +
                                    dim_head = num_head_channels
         | 
| 684 | 
            +
                                if legacy:
         | 
| 685 | 
            +
                                    #num_heads = 1
         | 
| 686 | 
            +
                                    dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
         | 
| 687 | 
            +
                                if exists(disable_self_attentions):
         | 
| 688 | 
            +
                                    disabled_sa = disable_self_attentions[level]
         | 
| 689 | 
            +
                                else:
         | 
| 690 | 
            +
                                    disabled_sa = False
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                                if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
         | 
| 693 | 
            +
                                    layers.append(
         | 
| 694 | 
            +
                                        AttentionBlock(
         | 
| 695 | 
            +
                                            ch,
         | 
| 696 | 
            +
                                            use_checkpoint=use_checkpoint,
         | 
| 697 | 
            +
                                            num_heads=num_heads_upsample,
         | 
| 698 | 
            +
                                            num_head_channels=dim_head,
         | 
| 699 | 
            +
                                            use_new_attention_order=use_new_attention_order,
         | 
| 700 | 
            +
                                        ) if not use_spatial_transformer else SpatialTransformer(
         | 
| 701 | 
            +
                                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
         | 
| 702 | 
            +
                                            disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
         | 
| 703 | 
            +
                                            use_checkpoint=use_checkpoint
         | 
| 704 | 
            +
                                        )
         | 
| 705 | 
            +
                                    )
         | 
| 706 | 
            +
                            if level and i == self.num_res_blocks[level]:
         | 
| 707 | 
            +
                                out_ch = ch
         | 
| 708 | 
            +
                                layers.append(
         | 
| 709 | 
            +
                                    ResBlock(
         | 
| 710 | 
            +
                                        ch,
         | 
| 711 | 
            +
                                        time_embed_dim,
         | 
| 712 | 
            +
                                        dropout,
         | 
| 713 | 
            +
                                        out_channels=out_ch,
         | 
| 714 | 
            +
                                        dims=dims,
         | 
| 715 | 
            +
                                        use_checkpoint=use_checkpoint,
         | 
| 716 | 
            +
                                        use_scale_shift_norm=use_scale_shift_norm,
         | 
| 717 | 
            +
                                        up=True,
         | 
| 718 | 
            +
                                    )
         | 
| 719 | 
            +
                                    if resblock_updown
         | 
| 720 | 
            +
                                    else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
         | 
| 721 | 
            +
                                )
         | 
| 722 | 
            +
                                ds //= 2
         | 
| 723 | 
            +
                            self.output_blocks.append(TimestepEmbedSequential(*layers))
         | 
| 724 | 
            +
                            self._feature_size += ch
         | 
| 725 | 
            +
             | 
| 726 | 
            +
                    self.out = nn.Sequential(
         | 
| 727 | 
            +
                        normalization(ch),
         | 
| 728 | 
            +
                        nn.SiLU(),
         | 
| 729 | 
            +
                        zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         | 
| 730 | 
            +
                    )
         | 
| 731 | 
            +
                    if self.predict_codebook_ids:
         | 
| 732 | 
            +
                        self.id_predictor = nn.Sequential(
         | 
| 733 | 
            +
                        normalization(ch),
         | 
| 734 | 
            +
                        conv_nd(dims, model_channels, n_embed, 1),
         | 
| 735 | 
            +
                        #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
         | 
| 736 | 
            +
                    )
         | 
| 737 | 
            +
             | 
| 738 | 
            +
                def convert_to_fp16(self):
                    """
                    Convert the torso of the model to float16.

                    The torso is the input blocks, the middle block and the output
                    blocks; each is visited with ``convert_module_to_f16`` via
                    ``Module.apply``, in that order.
                    """
                    for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
                        torso_part.apply(convert_module_to_f16)
         | 
| 745 | 
            +
             | 
| 746 | 
            +
                def convert_to_fp32(self):
                    """
                    Convert the torso of the model to float32.

                    The torso is the input blocks, the middle block and the output
                    blocks; each is visited with ``convert_module_to_f32`` via
                    ``Module.apply``, in that order.
                    """
                    for torso_part in (self.input_blocks, self.middle_block, self.output_blocks):
                        torso_part.apply(convert_module_to_f32)
         | 
| 753 | 
            +
             | 
| 754 | 
            +
                def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
                    """
                    Apply the model to an input batch.

                    :param x: an [N x C x ...] Tensor of inputs.
                    :param timesteps: a 1-D batch of timesteps.
                    :param context: conditioning plugged in via crossattn
                    :param y: an [N] Tensor of labels, if class-conditional.
                    :param kwargs: accepted and ignored.
                    :return: an [N x C x ...] Tensor of outputs.
                    """
                    class_conditional = self.num_classes is not None
                    assert (y is not None) == class_conditional, \
                        "must specify y if and only if the model is class-conditional"

                    # Timestep embedding, optionally shifted by the label embedding.
                    emb = self.time_embed(
                        timestep_embedding(timesteps, self.model_channels, repeat_only=False)
                    )
                    if class_conditional:
                        assert y.shape[0] == x.shape[0]
                        emb = emb + self.label_emb(y)

                    # Down path: keep every intermediate activation for the skip connections.
                    skip_stack = []
                    h = x.type(self.dtype)
                    for down_block in self.input_blocks:
                        h = down_block(h, emb, context)
                        skip_stack.append(h)

                    h = self.middle_block(h, emb, context)

                    # Up path: consume the skips in reverse order, concatenated on dim 1.
                    for up_block in self.output_blocks:
                        h = up_block(th.cat([h, skip_stack.pop()], dim=1), emb, context)

                    # Cast back to the caller's dtype before the output head.
                    h = h.type(x.dtype)
                    head = self.id_predictor if self.predict_codebook_ids else self.out
                    return head(h)
         | 
    	
        ldm/modules/diffusionmodules/upscaling.py
    ADDED
    
    | @@ -0,0 +1,81 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from functools import partial
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
         | 
| 7 | 
            +
            from ldm.util import default
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class AbstractLowScaleModel(nn.Module):
                """
                Base module for concatenating a downsampled image to the latent representation.

                When a noise schedule config is supplied, the schedule quantities are
                registered as buffers so that ``q_sample`` can be used.
                """

                def __init__(self, noise_schedule_config=None):
                    super().__init__()
                    # Without a config no schedule buffers exist; q_sample must not be called then.
                    if noise_schedule_config is not None:
                        self.register_schedule(**noise_schedule_config)

                def register_schedule(self, beta_schedule="linear", timesteps=1000,
                                      linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
                    """
                    Build the beta schedule with ``make_beta_schedule`` and register the
                    derived per-timestep quantities as float32 buffers.
                    """
                    betas = make_beta_schedule(
                        beta_schedule,
                        timesteps,
                        linear_start=linear_start,
                        linear_end=linear_end,
                        cosine_s=cosine_s,
                    )
                    alphas_cumprod = np.cumprod(1. - betas, axis=0)
                    alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

                    # betas is expected to be 1-D; the unpacking fails otherwise.
                    (n_steps,) = betas.shape
                    self.num_timesteps = int(n_steps)
                    self.linear_start = linear_start
                    self.linear_end = linear_end
                    assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

                    # Insertion order is the registration order.
                    schedule_buffers = {
                        'betas': betas,
                        'alphas_cumprod': alphas_cumprod,
                        'alphas_cumprod_prev': alphas_cumprod_prev,
                        # calculations for diffusion q(x_t | x_{t-1}) and others
                        'sqrt_alphas_cumprod': np.sqrt(alphas_cumprod),
                        'sqrt_one_minus_alphas_cumprod': np.sqrt(1. - alphas_cumprod),
                        'log_one_minus_alphas_cumprod': np.log(1. - alphas_cumprod),
                        'sqrt_recip_alphas_cumprod': np.sqrt(1. / alphas_cumprod),
                        'sqrt_recipm1_alphas_cumprod': np.sqrt(1. / alphas_cumprod - 1),
                    }
                    for buffer_name, values in schedule_buffers.items():
                        self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))

                def q_sample(self, x_start, t, noise=None):
                    """
                    Combine ``x_start`` and ``noise`` using the schedule coefficients that
                    ``extract_into_tensor`` picks for ``t``.
                    """
                    # presumably `default` only invokes the lambda when noise is None -- verify in ldm.util
                    noise = default(noise, lambda: torch.randn_like(x_start))
                    signal_coeff = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
                    noise_coeff = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
                    return signal_coeff * x_start + noise_coeff * noise

                def forward(self, x):
                    """Return ``x`` unchanged together with ``None`` in the noise-level slot."""
                    return x, None

                def decode(self, x):
                    """Return ``x`` unchanged."""
                    return x
         | 
| 54 | 
            +
             | 
| 55 | 
            +
             | 
| 56 | 
            +
            class SimpleImageConcat(AbstractLowScaleModel):
                """Low-scale model with no noise level conditioning."""

                def __init__(self):
                    super().__init__(noise_schedule_config=None)
                    self.max_noise_level = 0

                def forward(self, x):
                    """Return ``x`` unchanged with an all-zero long noise level, one entry per batch item."""
                    # fix to constant noise level
                    batch_size = x.shape[0]
                    noise_level = torch.zeros(batch_size, device=x.device).long()
                    return x, noise_level
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
                """Low-scale model that noises its input with ``q_sample`` at a given or random level."""

                def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
                    # `to_cuda` is accepted but not used anywhere in this class.
                    super(ImageConcatWithNoiseAugmentation, self).__init__(
                        noise_schedule_config=noise_schedule_config
                    )
                    self.max_noise_level = max_noise_level

                def forward(self, x, noise_level=None):
                    """
                    Noise ``x`` and return ``(noised, noise_level)``.

                    When ``noise_level`` is omitted, one level per batch item is drawn
                    with ``torch.randint`` from ``[0, max_noise_level)``; otherwise it
                    must be a ``torch.Tensor``.
                    """
                    if noise_level is not None:
                        assert isinstance(noise_level, torch.Tensor)
                    else:
                        batch_size = x.shape[0]
                        noise_level = torch.randint(
                            0, self.max_noise_level, (batch_size,), device=x.device
                        ).long()
                    noised = self.q_sample(x, noise_level)
                    return noised, noise_level
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
    	
        ldm/modules/diffusionmodules/util.py
    ADDED
    
    | @@ -0,0 +1,270 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # adopted from
         | 
| 2 | 
            +
            # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
         | 
| 3 | 
            +
            # and
         | 
| 4 | 
            +
            # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
         | 
| 5 | 
            +
            # and
         | 
| 6 | 
            +
            # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            # thanks!
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            import os
         | 
| 12 | 
            +
            import math
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            import torch.nn as nn
         | 
| 15 | 
            +
            import numpy as np
         | 
| 16 | 
            +
            from einops import repeat
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from ldm.util import instantiate_from_config
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
                """
                Build the per-timestep beta schedule of a diffusion process.
                :param schedule: one of "linear", "cosine", "sqrt_linear" or "sqrt".
                :param n_timestep: the number of betas to produce.
                :param linear_start: first value of the interpolation range
                                     (not used by the "cosine" schedule).
                :param linear_end: last value of the interpolation range
                                   (not used by the "cosine" schedule).
                :param cosine_s: offset added to the normalized timesteps of the
                                 "cosine" schedule.
                :return: a float64 numpy array holding n_timestep betas.
                :raises ValueError: if `schedule` is not one of the known names.
                """
                if schedule == "linear":
                    # interpolate between the square roots, then square the result
                    betas = (
                            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
                    )

                elif schedule == "cosine":
                    timesteps = (
                            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
                    )
                    alphas = timesteps / (1 + cosine_s) * np.pi / 2
                    alphas = torch.cos(alphas).pow(2)
                    alphas = alphas / alphas[0]
                    betas = 1 - alphas[1:] / alphas[:-1]
                    # Clamp with torch so `betas` is guaranteed to still be a tensor
                    # for the `.numpy()` call below; np.clip on a tensor only works
                    # through numpy's object-dispatch fallback.
                    betas = torch.clamp(betas, min=0, max=0.999)

                elif schedule == "sqrt_linear":
                    betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
                elif schedule == "sqrt":
                    betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
                else:
                    raise ValueError(f"schedule '{schedule}' unknown.")
                return betas.numpy()
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
                """
                Pick the subset of DDPM timesteps visited by the DDIM sampler.
                :param ddim_discr_method: 'uniform' (constant stride) or 'quad'
                                          (quadratically spaced).
                :param num_ddim_timesteps: requested number of DDIM steps.
                :param num_ddpm_timesteps: number of steps of the full DDPM schedule.
                :param verbose: if True, print the selected timesteps.
                :return: a numpy array with the selected timesteps, each shifted by one.
                :raises NotImplementedError: for any other discretization method.
                """
                if ddim_discr_method == 'uniform':
                    stride = num_ddpm_timesteps // num_ddim_timesteps
                    picked = np.asarray([*range(0, num_ddpm_timesteps, stride)])
                elif ddim_discr_method == 'quad':
                    upper = np.sqrt(num_ddpm_timesteps * .8)
                    picked = (np.linspace(0, upper, num_ddim_timesteps) ** 2).astype(int)
                else:
                    raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

                # shift by one to get the final alpha values right
                # (the ones from first scale to data during sampling)
                steps_out = picked + 1
                if verbose:
                    print(f'Selected timesteps for ddim sampler: {steps_out}')
                return steps_out
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
                """
                Compute the sigma / alpha values used by the DDIM sampler.
                :param alphacums: indexable sequence of cumulative alpha products; it
                                  must support fancy indexing and `.tolist()`.
                :param ddim_timesteps: integer index array of the selected timesteps.
                :param eta: scale factor applied to the sigmas (0 gives all-zero sigmas).
                :param verbose: if True, print the selected alphas and sigmas.
                :return: a tuple (sigmas, alphas, alphas_prev).
                """
                # alphas at the selected steps, and at the step before each of them
                alphas = alphacums[ddim_timesteps]
                earlier = alphacums[ddim_timesteps[:-1]].tolist()
                alphas_prev = np.asarray([alphacums[0], *earlier])

                # variance schedule following the formula in https://arxiv.org/abs/2010.02502
                variance_ratio = (1 - alphas_prev) / (1 - alphas)
                sigmas = eta * np.sqrt(variance_ratio * (1 - alphas / alphas_prev))
                if verbose:
                    print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
                    print(f'For the chosen value of eta, which is {eta}, '
                          f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
                return sigmas, alphas, alphas_prev
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
                """
                Discretize a continuous alpha_bar(t) function into a beta schedule.
                alpha_bar describes the cumulative product of (1-beta) over t in [0,1].
                :param num_diffusion_timesteps: how many betas to generate.
                :param alpha_bar: callable taking t in [0, 1] and returning the
                                  cumulative product of (1-beta) up to that point
                                  of the diffusion process.
                :param max_beta: upper cap for every beta; keep it below 1 to
                                 prevent singularities.
                :return: a numpy array with num_diffusion_timesteps betas.
                """
                steps = num_diffusion_timesteps
                return np.array([
                    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta
                    min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
                    for i in range(steps)
                ])
         | 
| 94 | 
            +
             | 
| 95 | 
            +
             | 
| 96 | 
            +
            def extract_into_tensor(a, t, x_shape):
                """
                Gather entries of `a` at the indices `t` and shape them for broadcasting.
                :param a: tensor to read from (gathered along its last dimension).
                :param t: index tensor; its first dimension is the batch size.
                :param x_shape: shape of the tensor the result is broadcast against;
                                only its length is used.
                :return: a tensor of shape (batch, 1, ..., 1) with len(x_shape) dims.
                """
                batch, *_rest = t.shape
                target_shape = (batch,) + (1,) * (len(x_shape) - 1)
                return a.gather(-1, t).reshape(target_shape)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
             | 
| 102 | 
            +
            def checkpoint(func, inputs, params, flag):
                """
                Run `func`, optionally through gradient checkpointing, which trades
                extra compute in the backward pass for not caching intermediate
                activations.
                :param func: the callable to run.
                :param inputs: sequence of positional arguments handed to `func`.
                :param params: sequence of parameters `func` depends on without
                               receiving them as arguments.
                :param flag: truthy to checkpoint; otherwise `func` is called directly.
                :return: whatever `func` returns.
                """
                if not flag:
                    return func(*inputs)
                # inputs first, then params; the count tells CheckpointFunction where to split
                packed = tuple(inputs) + tuple(params)
                return CheckpointFunction.apply(func, len(inputs), *packed)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
             | 
| 119 | 
            +
            class CheckpointFunction(torch.autograd.Function):
                """
                Autograd function that runs `run_function` without recording a graph
                in the forward pass and re-runs it with gradients enabled in the
                backward pass.
                """

                @staticmethod
                def forward(ctx, run_function, length, *args):
                    # the first `length` entries of args are the inputs, the rest are params
                    ctx.run_function = run_function
                    ctx.input_tensors = list(args[:length])
                    ctx.input_params = list(args[length:])
                    # remember the autocast state so backward can recompute under it
                    ctx.gpu_autocast_kwargs = dict(
                        enabled=torch.is_autocast_enabled(),
                        dtype=torch.get_autocast_gpu_dtype(),
                        cache_enabled=torch.is_autocast_cache_enabled(),
                    )
                    with torch.no_grad():
                        return ctx.run_function(*ctx.input_tensors)

                @staticmethod
                def backward(ctx, *output_grads):
                    ctx.input_tensors = [t.detach().requires_grad_(True) for t in ctx.input_tensors]
                    with torch.enable_grad():
                        with torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
                            # Pass views rather than the detached tensors themselves:
                            # if the first op in run_function modifies its input
                            # storage in place, that is not allowed on detach()'d
                            # tensors.
                            views = [t.view_as(t) for t in ctx.input_tensors]
                            recomputed = ctx.run_function(*views)
                    grads = torch.autograd.grad(
                        recomputed,
                        ctx.input_tensors + ctx.input_params,
                        output_grads,
                        allow_unused=True,
                    )
                    del ctx.input_tensors
                    del ctx.input_params
                    del recomputed
                    # no gradients for run_function and length
                    return (None, None) + grads
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
                """
                Build sinusoidal embeddings for a batch of timesteps.
                :param timesteps: a 1-D Tensor of N indices, one per batch element.
                                  Fractional values are allowed.
                :param dim: the size of each embedding vector.
                :param max_period: controls the minimum frequency of the embeddings.
                :param repeat_only: if True, skip the sinusoids and tile each raw
                                    timestep `dim` times instead.
                :return: an [N x dim] Tensor of positional embeddings.
                """
                if repeat_only:
                    return repeat(timesteps, 'b -> b d', d=dim)

                half = dim // 2
                exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
                freqs = torch.exp(exponents).to(device=timesteps.device)
                phases = timesteps[:, None].float() * freqs[None]
                embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
                if dim % 2:
                    # odd dim: pad one zero column so the width equals dim
                    padding = torch.zeros_like(embedding[:, :1])
                    embedding = torch.cat([embedding, padding], dim=-1)
                return embedding
         | 
| 175 | 
            +
             | 
| 176 | 
            +
             | 
| 177 | 
            +
            def zero_module(module):
                """
                Set every parameter of `module` to zero in place and return the module.
                """
                # in-place write outside of autograd tracking
                with torch.no_grad():
                    for param in module.parameters():
                        param.zero_()
                return module
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            def scale_module(module, scale):
                """
                Multiply every parameter of `module` by `scale` in place and return
                the module.
                """
                # in-place write outside of autograd tracking
                with torch.no_grad():
                    for param in module.parameters():
                        param.mul_(scale)
                return module
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            def mean_flat(tensor):
                """
                Average a tensor over every dimension except the first (batch) one.
                """
                non_batch_dims = list(range(1, len(tensor.shape)))
                return tensor.mean(dim=non_batch_dims)
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
            def normalization(channels):
                """
                Build the standard normalization layer: a GroupNorm32 with 32 groups.
                :param channels: number of input channels.
                :return: an nn.Module for normalization.
                """
                layer = GroupNorm32(32, channels)
                return layer
         | 
| 209 | 
            +
             | 
| 210 | 
            +
             | 
| 211 | 
            +
            # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
         | 
| 212 | 
            +
            class SiLU(nn.Module):
                """
                SiLU activation: the input multiplied by its own sigmoid.
                """

                def forward(self, x):
                    gate = torch.sigmoid(x)
                    return x * gate
         | 
| 215 | 
            +
             | 
| 216 | 
            +
             | 
| 217 | 
            +
            class GroupNorm32(nn.GroupNorm):
                """
                GroupNorm that computes in float32 and returns the input's dtype.
                """

                def forward(self, x):
                    # normalize a float32 copy, then cast back to the caller's dtype
                    normed = super().forward(x.float())
                    return normed.type(x.dtype)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            def conv_nd(dims, *args, **kwargs):
                """
                Build a convolution module for 1, 2 or 3 spatial dimensions.
                :param dims: 1, 2 or 3; selects nn.Conv1d, nn.Conv2d or nn.Conv3d.
                :return: the module, built from the remaining arguments.
                :raises ValueError: for any other value of dims.
                """
                for rank, conv_cls in enumerate((nn.Conv1d, nn.Conv2d, nn.Conv3d), start=1):
                    if dims == rank:
                        return conv_cls(*args, **kwargs)
                raise ValueError(f"unsupported dimensions: {dims}")
         | 
| 232 | 
            +
             | 
| 233 | 
            +
             | 
| 234 | 
            +
            def linear(*args, **kwargs):
                """
                Build an nn.Linear module from the given arguments.
                """
                layer = nn.Linear(*args, **kwargs)
                return layer
         | 
| 239 | 
            +
             | 
| 240 | 
            +
             | 
| 241 | 
            +
            def avg_pool_nd(dims, *args, **kwargs):
                """
                Build an average pooling module for 1, 2 or 3 spatial dimensions.
                :param dims: 1, 2 or 3; selects nn.AvgPool1d, nn.AvgPool2d or nn.AvgPool3d.
                :return: the module, built from the remaining arguments.
                :raises ValueError: for any other value of dims.
                """
                for rank, pool_cls in enumerate((nn.AvgPool1d, nn.AvgPool2d, nn.AvgPool3d), start=1):
                    if dims == rank:
                        return pool_cls(*args, **kwargs)
                raise ValueError(f"unsupported dimensions: {dims}")
         | 
| 252 | 
            +
             | 
| 253 | 
            +
             | 
| 254 | 
            +
            class HybridConditioner(nn.Module):
                """
                Holds two conditioners, one for concat conditioning and one for
                cross-attention conditioning, and applies each to its own input.
                """

                def __init__(self, c_concat_config, c_crossattn_config):
                    """
                    :param c_concat_config: config passed to instantiate_from_config
                                            to build the concat conditioner.
                    :param c_crossattn_config: config passed to instantiate_from_config
                                               to build the cross-attention conditioner.
                    """
                    super().__init__()
                    self.concat_conditioner = instantiate_from_config(c_concat_config)
                    self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

                def forward(self, c_concat, c_crossattn):
                    """
                    :return: a dict with keys 'c_concat' and 'c_crossattn', each
                             mapping to a one-element list with the conditioned input.
                    """
                    concat_out = self.concat_conditioner(c_concat)
                    crossattn_out = self.crossattn_conditioner(c_crossattn)
                    return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}
         | 
| 265 | 
            +
             | 
| 266 | 
            +
             | 
| 267 | 
            +
            def noise_like(shape, device, repeat=False):
                """
                Draw standard-normal noise of the given shape on `device`.
                :param shape: full output shape; the first entry is the batch size.
                :param device: device the noise is created on.
                :param repeat: if truthy, draw one sample of shape (1, *shape[1:])
                               and tile it along the batch dimension.
                :return: a noise tensor of shape `shape`.
                """
                if repeat:
                    single = torch.randn((1, *shape[1:]), device=device)
                    tiling = (shape[0],) + (1,) * (len(shape) - 1)
                    return single.repeat(*tiling)
                return torch.randn(shape, device=device)
         | 
    	
        ldm/modules/ema.py
    ADDED
    
    | @@ -0,0 +1,80 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            from torch import nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            class LitEma(nn.Module):
                """
                Exponential moving average (EMA) of a model's trainable parameters.

                Each parameter with requires_grad=True gets a shadow copy registered
                as a buffer of this module. Calling the module with the model moves
                the shadows towards the model's current parameters.
                """

                def __init__(self, model, decay=0.9999, use_num_upates=True):
                    """
                    Args:
                      model: module whose trainable parameters are tracked.
                      decay: EMA decay; must lie in [0, 1].
                      use_num_upates: if True, count updates and let the count cap
                        the decay in `forward`; if False the counter is set to -1,
                        which disables that. (The spelling of the name is kept so
                        existing callers keep working.)

                    Raises:
                      ValueError: if `decay` is outside [0, 1].
                    """
                    super().__init__()
                    if decay < 0.0 or decay > 1.0:
                        raise ValueError('Decay must be between 0 and 1')

                    # maps model parameter name -> name of its shadow buffer
                    self.m_name2s_name = {}
                    self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
                    self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                    else torch.tensor(-1, dtype=torch.int))

                    for name, p in model.named_parameters():
                        if p.requires_grad:
                            # remove as '.'-character is not allowed in buffers
                            s_name = name.replace('.', '')
                            self.m_name2s_name.update({name: s_name})
                            self.register_buffer(s_name, p.clone().detach().data)

                    self.collected_params = []

                def reset_num_updates(self):
                    """Re-register the update counter with the value 0."""
                    del self.num_updates
                    self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

                def forward(self, model):
                    """
                    Perform one EMA step: shadow -= (1 - decay) * (shadow - param).

                    Args:
                      model: the module passed to the constructor (its trainable
                        parameter names must match the ones recorded there).
                    """
                    decay = self.decay

                    if self.num_updates >= 0:
                        self.num_updates += 1
                        # cap the decay by (1 + n) / (10 + n) for update count n
                        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

                    one_minus_decay = 1.0 - decay

                    with torch.no_grad():
                        m_param = dict(model.named_parameters())
                        shadow_params = dict(self.named_buffers())

                        for key, param in m_param.items():
                            if param.requires_grad:
                                shadow = shadow_params[self.m_name2s_name[key]]
                                # Bring the parameter to the shadow's dtype/device and
                                # update the registered buffer itself. Casting the
                                # shadow instead (type_as) yields a copy whenever the
                                # types differ, so the in-place update would be lost.
                                shadow.sub_(one_minus_decay * (shadow - param.to(shadow)))
                            else:
                                assert key not in self.m_name2s_name

                def copy_to(self, model):
                    """
                    Copy the shadow values into the model's trainable parameters.

                    Args:
                      model: module whose parameters are overwritten in place.
                    """
                    m_param = dict(model.named_parameters())
                    shadow_params = dict(self.named_buffers())
                    for key, param in m_param.items():
                        if param.requires_grad:
                            param.data.copy_(shadow_params[self.m_name2s_name[key]].data)
                        else:
                            assert key not in self.m_name2s_name

                def store(self, parameters):
                    """
                    Save the current parameters for restoring later.
                    Args:
                      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                        temporarily stored.
                    """
                    self.collected_params = [param.clone() for param in parameters]

                def restore(self, parameters):
                    """
                    Restore the parameters stored with the `store` method.
                    Useful to validate the model with EMA parameters without affecting the
                    original optimization process. Store the parameters before the
                    `copy_to` method. After validation (or model saving), use this to
                    restore the former parameters.
                    Args:
                      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                        updated with the stored parameters.
                    """
                    for c_param, param in zip(self.collected_params, parameters):
                        param.data.copy_(c_param.data)
         | 
    	
        ldm/util.py
    ADDED
    
    | @@ -0,0 +1,197 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import importlib
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from torch import optim
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            from inspect import isfunction
         | 
| 8 | 
            +
            from PIL import Image, ImageDraw, ImageFont
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def log_txt_as_img(wh, xc, size=10):
                """
                Render each caption as black text on a white RGB image.
                :param wh: a tuple of (width, height) for every image.
                :param xc: a list of captions to plot.
                :param size: font size used with 'data/DejaVuSans.ttf'.
                :return: a tensor stacking one channel-first image per caption,
                         with pixel values rescaled via x / 127.5 - 1.0.
                """
                rendered = []
                for caption in xc:
                    canvas = Image.new("RGB", wh, color="white")
                    pen = ImageDraw.Draw(canvas)
                    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
                    # hard-wrap the caption: 40 characters per 256 pixels of width
                    chars_per_line = int(40 * (wh[0] / 256))
                    chunks = [caption[start:start + chars_per_line]
                              for start in range(0, len(caption), chars_per_line)]
                    wrapped = "\n".join(chunks)

                    try:
                        pen.text((0, 0), wrapped, fill="black", font=font)
                    except UnicodeEncodeError:
                        # the image is still appended, just without the text
                        print("Cant encode string for logging. Skipping.")

                    # HWC -> CHW, then rescale
                    pixels = np.array(canvas).transpose(2, 0, 1) / 127.5 - 1.0
                    rendered.append(pixels)
                return torch.tensor(np.stack(rendered))
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def ismap(x):
                """Return True when ``x`` is a 4-D torch tensor whose second dimension exceeds 3."""
                if isinstance(x, torch.Tensor):
                    shape = x.shape
                    return len(shape) == 4 and shape[1] > 3
                # anything that is not a tensor is never a map
                return False
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            def isimage(x):
                """Return True when ``x`` is a 4-D torch tensor whose second dimension is 3 or 1."""
                if isinstance(x, torch.Tensor):
                    shape = x.shape
                    return len(shape) == 4 and shape[1] in (3, 1)
                # anything that is not a tensor is never an image
                return False
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
| 47 | 
            +
            def exists(x):
                """Tell whether ``x`` holds a value, i.e. is anything other than ``None``."""
                if x is None:
                    return False
                return True
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def default(val, d):
                """Return ``val`` unless it is missing, otherwise fall back to ``d``.

                When the fallback ``d`` is a plain Python function it is called and its
                result is returned; any other object is returned as-is.
                """
                if not exists(val):
                    if isfunction(d):
                        return d()
                    return d
                return val
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def mean_flat(tensor):
                """
                Average ``tensor`` over every dimension except the first (batch) one.

                Adapted from
                https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
                """
                n_dims = len(tensor.shape)
                non_batch_dims = list(range(1, n_dims))
                return tensor.mean(dim=non_batch_dims)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
             | 
| 65 | 
            +
            def count_params(model, verbose=False):
                """Count the elements of all parameters of ``model``.

                Args:
                    model: object exposing ``parameters()`` whose items have ``numel()``.
                    verbose: when true, also print the count in millions.

                Returns:
                    The total number of parameter elements.
                """
                total_params = 0
                for param in model.parameters():
                    total_params += param.numel()
                if verbose:
                    print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
                return total_params
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            def instantiate_from_config(config):
                """Build the object described by ``config``.

                ``config["target"]`` is resolved through ``get_obj_from_str`` and called
                with ``config["params"]`` (or no arguments when absent) as keyword arguments.
                The sentinel values ``'__is_first_stage__'`` and ``"__is_unconditional__"``
                yield ``None``.

                Raises:
                    KeyError: if ``config`` has no ``target`` entry and is not a sentinel.
                """
                if "target" in config:
                    kwargs = config.get("params", dict())
                    factory = get_obj_from_str(config["target"])
                    return factory(**kwargs)

                # no target: only the two sentinel markers are acceptable
                if config == '__is_first_stage__':
                    return None
                if config == "__is_unconditional__":
                    return None
                raise KeyError("Expected key `target` to instantiate.")
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            def get_obj_from_str(string, reload=False):
                """Resolve a dotted path such as ``"package.module.Name"`` to the named attribute.

                Args:
                    string: dotted path; everything before the last dot is the module.
                    reload: when true, the module is reloaded before the lookup.

                Returns:
                    The attribute fetched from the imported module.
                """
                module_path, attr_name = string.rsplit(".", 1)
                if reload:
                    importlib.reload(importlib.import_module(module_path))
                imported = importlib.import_module(module_path, package=None)
                return getattr(imported, attr_name)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            class AdamWwithEMAandWings(optim.Optimizer):
                """AdamW optimizer that also keeps an exponential moving average of each parameter.

                The averaged copy of a parameter lives in the optimizer state under the
                key ``'param_exp_avg'`` and is updated after every AdamW step.
                """
                # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
                def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
                             weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code
                             ema_power=1., param_names=()):
                    """AdamW that saves EMA versions of the parameters.

                    Args:
                        params: parameters or parameter groups to optimize.
                        lr: learning rate; must be non-negative.
                        betas: the two AdamW coefficients; each must lie in [0, 1).
                        eps: AdamW epsilon; must be non-negative.
                        weight_decay: weight decay coefficient; must be non-negative.
                        amsgrad: whether to keep the running maximum of the squared-gradient average.
                        ema_decay: upper bound of the EMA decay; must lie in [0, 1].
                        ema_power: exponent of the step-dependent EMA warm-up (see ``step``).
                        param_names: stored in the group defaults; not read by this class.

                    Raises:
                        ValueError: if any hyperparameter is outside its valid range.
                    """
                    if not 0.0 <= lr:
                        raise ValueError("Invalid learning rate: {}".format(lr))
                    if not 0.0 <= eps:
                        raise ValueError("Invalid epsilon value: {}".format(eps))
                    if not 0.0 <= betas[0] < 1.0:
                        raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
                    if not 0.0 <= betas[1] < 1.0:
                        raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
                    if not 0.0 <= weight_decay:
                        raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
                    if not 0.0 <= ema_decay <= 1.0:
                        raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
                    defaults = dict(lr=lr, betas=betas, eps=eps,
                                    weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
                                    ema_power=ema_power, param_names=param_names)
                    super().__init__(params, defaults)

                def __setstate__(self, state):
                    """Restore optimizer state, defaulting ``amsgrad`` to False for groups that lack it."""
                    super().__setstate__(state)
                    for group in self.param_groups:
                        group.setdefault('amsgrad', False)

                @torch.no_grad()
                def step(self, closure=None):
                    """Performs a single optimization step.
                    Args:
                        closure (callable, optional): A closure that reevaluates the model
                            and returns the loss.
                    Returns:
                        The value returned by ``closure``, or ``None`` when no closure is given.
                    """
                    loss = None
                    if closure is not None:
                        with torch.enable_grad():
                            loss = closure()

                    for group in self.param_groups:
                        params_with_grad = []
                        grads = []
                        exp_avgs = []
                        exp_avg_sqs = []
                        ema_params_with_grad = []
                        max_exp_avg_sqs = []
                        state_steps = []
                        amsgrad = group['amsgrad']
                        beta1, beta2 = group['betas']
                        ema_decay = group['ema_decay']
                        ema_power = group['ema_power']

                        for p in group['params']:
                            if p.grad is None:
                                continue
                            params_with_grad.append(p)
                            if p.grad.is_sparse:
                                raise RuntimeError('AdamW does not support sparse gradients')
                            grads.append(p.grad)

                            state = self.state[p]

                            # State initialization
                            if len(state) == 0:
                                state['step'] = 0
                                # Exponential moving average of gradient values
                                state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                                # Exponential moving average of squared gradient values
                                state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                                if amsgrad:
                                    # Maintains max of all exp. moving avg. of sq. grad. values
                                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                                # Exponential moving average of parameter values
                                state['param_exp_avg'] = p.detach().float().clone()

                            exp_avgs.append(state['exp_avg'])
                            exp_avg_sqs.append(state['exp_avg_sq'])
                            ema_params_with_grad.append(state['param_exp_avg'])

                            if amsgrad:
                                max_exp_avg_sqs.append(state['max_exp_avg_sq'])

                            # update the steps for each param group update
                            state['step'] += 1
                            # record the step after step update
                            state_steps.append(state['step'])

                        if not params_with_grad:
                            # No parameter of this group received a gradient: there is
                            # nothing to update, and no per-parameter state from which
                            # the EMA step count could be read.
                            continue

                        # NOTE(review): private torch API; its signature may differ
                        # between torch releases - confirm against the pinned version.
                        optim._functional.adamw(params_with_grad,
                                grads,
                                exp_avgs,
                                exp_avg_sqs,
                                max_exp_avg_sqs,
                                state_steps,
                                amsgrad=amsgrad,
                                beta1=beta1,
                                beta2=beta2,
                                lr=group['lr'],
                                weight_decay=group['weight_decay'],
                                eps=group['eps'],
                                maximize=False)

                        # EMA warm-up: the decay is capped by 1 - step ** -ema_power, using
                        # the step count of the last updated parameter in this group.
                        cur_ema_decay = min(ema_decay, 1 - state_steps[-1] ** -ema_power)
                        for param, ema_param in zip(params_with_grad, ema_params_with_grad):
                            ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)

                    return loss
         | 
    	
        models/cldm_v15.yaml
    ADDED
    
    | @@ -0,0 +1,79 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            model:
         | 
| 2 | 
            +
              target: cldm.cldm.ControlLDM
         | 
| 3 | 
            +
              params:
         | 
| 4 | 
            +
                linear_start: 0.00085
         | 
| 5 | 
            +
                linear_end: 0.0120
         | 
| 6 | 
            +
                num_timesteps_cond: 1
         | 
| 7 | 
            +
                log_every_t: 200
         | 
| 8 | 
            +
                timesteps: 1000
         | 
| 9 | 
            +
                first_stage_key: "jpg"
         | 
| 10 | 
            +
                cond_stage_key: "txt"
         | 
| 11 | 
            +
                control_key: "hint"
         | 
| 12 | 
            +
                image_size: 64
         | 
| 13 | 
            +
                channels: 4
         | 
| 14 | 
            +
                cond_stage_trainable: false
         | 
| 15 | 
            +
                conditioning_key: crossattn
         | 
| 16 | 
            +
                monitor: val/loss_simple_ema
         | 
| 17 | 
            +
                scale_factor: 0.18215
         | 
| 18 | 
            +
                use_ema: False
         | 
| 19 | 
            +
                only_mid_control: False
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                control_stage_config:
         | 
| 22 | 
            +
                  target: cldm.cldm.ControlNet
         | 
| 23 | 
            +
                  params:
         | 
| 24 | 
            +
                    image_size: 32 # unused
         | 
| 25 | 
            +
                    in_channels: 4
         | 
| 26 | 
            +
                    hint_channels: 3
         | 
| 27 | 
            +
                    model_channels: 320
         | 
| 28 | 
            +
                    attention_resolutions: [ 4, 2, 1 ]
         | 
| 29 | 
            +
                    num_res_blocks: 2
         | 
| 30 | 
            +
                    channel_mult: [ 1, 2, 4, 4 ]
         | 
| 31 | 
            +
                    num_heads: 8
         | 
| 32 | 
            +
                    use_spatial_transformer: True
         | 
| 33 | 
            +
                    transformer_depth: 1
         | 
| 34 | 
            +
                    context_dim: 768
         | 
| 35 | 
            +
                    use_checkpoint: True
         | 
| 36 | 
            +
                    legacy: False
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                unet_config:
         | 
| 39 | 
            +
                  target: cldm.cldm.ControlledUnetModel
         | 
| 40 | 
            +
                  params:
         | 
| 41 | 
            +
                    image_size: 32 # unused
         | 
| 42 | 
            +
                    in_channels: 4
         | 
| 43 | 
            +
                    out_channels: 4
         | 
| 44 | 
            +
                    model_channels: 320
         | 
| 45 | 
            +
                    attention_resolutions: [ 4, 2, 1 ]
         | 
| 46 | 
            +
                    num_res_blocks: 2
         | 
| 47 | 
            +
                    channel_mult: [ 1, 2, 4, 4 ]
         | 
| 48 | 
            +
                    num_heads: 8
         | 
| 49 | 
            +
                    use_spatial_transformer: True
         | 
| 50 | 
            +
                    transformer_depth: 1
         | 
| 51 | 
            +
                    context_dim: 768
         | 
| 52 | 
            +
                    use_checkpoint: True
         | 
| 53 | 
            +
                    legacy: False
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                first_stage_config:
         | 
| 56 | 
            +
                  target: ldm.models.autoencoder.AutoencoderKL
         | 
| 57 | 
            +
                  params:
         | 
| 58 | 
            +
                    embed_dim: 4
         | 
| 59 | 
            +
                    monitor: val/rec_loss
         | 
| 60 | 
            +
                    ddconfig:
         | 
| 61 | 
            +
                      double_z: true
         | 
| 62 | 
            +
                      z_channels: 4
         | 
| 63 | 
            +
                      resolution: 256
         | 
| 64 | 
            +
                      in_channels: 3
         | 
| 65 | 
            +
                      out_ch: 3
         | 
| 66 | 
            +
                      ch: 128
         | 
| 67 | 
            +
                      ch_mult:
         | 
| 68 | 
            +
                      - 1
         | 
| 69 | 
            +
                      - 2
         | 
| 70 | 
            +
                      - 4
         | 
| 71 | 
            +
                      - 4
         | 
| 72 | 
            +
                      num_res_blocks: 2
         | 
| 73 | 
            +
                      attn_resolutions: []
         | 
| 74 | 
            +
                      dropout: 0.0
         | 
| 75 | 
            +
                    lossconfig:
         | 
| 76 | 
            +
                      target: torch.nn.Identity
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                cond_stage_config:
         | 
| 79 | 
            +
                  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
         | 
