# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from modules.encoder.position_encoder import PositionEncoder
from modules.general.utils import append_dims, ConvNd, normalization, zero_module
from .attention import AttentionBlock
from .resblock import Downsample, ResBlock, Upsample


class UNet(nn.Module):
    r"""The full UNet model with attention and timestep embedding.

    Args:
        dims: determines if the signal is 1D (temporal) or 2D (spatial).
        in_channels: channels in the input Tensor.
        model_channels: base channel count for the model.
        out_channels: channels in the output Tensor.
        num_res_blocks: number of residual blocks per downsample.
        channel_mult: channel multiplier for each level of the UNet.
        num_attn_blocks: number of attention blocks at each attention resolution.
        attention_resolutions: a collection of downsample rates at which attention
            will take place. May be a set, list, or tuple. For example, if this
            contains 4, then at 4x downsampling, attention will be used.
        num_heads: the number of attention heads in each attention layer.
        num_head_channels: if specified, ignore num_heads and instead use a fixed
            channel width per attention head.
        d_context: if specified, the number of context channels used for the
            cross-attention projection.
        p_dropout: the dropout probability.
        num_classes: if specified (as an int), then this model will be
            class-conditional with ``num_classes`` classes.
        use_extra_film: if specified, use an extra FiLM-like conditioning mechanism,
            either ``"add"`` or ``"concat"``.
        d_emb: if specified, the number of channels of the FiLM-like conditioning
            input.
        use_scale_shift_norm: use a FiLM-like conditioning mechanism.
        resblock_updown: use residual blocks for up/downsampling.
    """

    def __init__(
        self,
        dims: int = 1,
        in_channels: int = 100,
        model_channels: int = 128,
        out_channels: int = 100,
        h_dim: int = 128,
        num_res_blocks: int = 1,
        channel_mult: tuple = (1, 2, 4),
        num_attn_blocks: int = 1,
        attention_resolutions: tuple = (1, 2, 4),
        num_heads: int = 1,
        num_head_channels: int = -1,
        d_context: int = None,
        context_hdim: int = 128,
        p_dropout: float = 0.0,
        num_classes: int = -1,
        use_extra_film: str = None,
        d_emb: int = None,
        use_scale_shift_norm: bool = True,
        resblock_updown: bool = False,
    ):
        super().__init__()

        self.dims = dims
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_attn_blocks = num_attn_blocks
        self.attention_resolutions = attention_resolutions
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.d_context = d_context
        self.p_dropout = p_dropout
        self.num_classes = num_classes
        self.use_extra_film = use_extra_film
        self.d_emb = d_emb
        self.use_scale_shift_norm = use_scale_shift_norm
        self.resblock_updown = resblock_updown

        time_embed_dim = model_channels * 4
        self.pos_enc = PositionEncoder(model_channels, time_embed_dim)

        assert (
            num_classes == -1 or use_extra_film is None
        ), "You cannot set both num_classes and use_extra_film."

        if self.num_classes > 0:
            # TODO: if used for singer, norm should be 1, correct?
            self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0)
        elif use_extra_film is not None:
            assert (
                d_emb is not None
            ), "d_emb must be specified if use_extra_film is not None"
            assert use_extra_film in [
                "add",
                "concat",
            ], f"use_extra_film only supported by add or concat. Your input is {use_extra_film}"
            self.use_extra_film = use_extra_film
            self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1)
            if use_extra_film == "concat":
                time_embed_dim *= 2

        # Input blocks
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        p_dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    for _ in range(num_attn_blocks):
                        layers.append(
                            AttentionBlock(
                                ch,
                                num_heads=num_heads,
                                num_head_channels=num_head_channels,
                                encoder_channels=d_context,
                                dims=dims,
                                h_dim=h_dim // (level + 1),
                                encoder_hdim=context_hdim,
                                p_dropout=p_dropout,
                            )
                        )
                self.input_blocks.append(UNetSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    UNetSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            p_dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Middle blocks
        self.middle_block = UNetSequential(
            ResBlock(
                ch,
                time_embed_dim,
                p_dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                encoder_channels=d_context,
                dims=dims,
                h_dim=h_dim // (level + 1),
                encoder_hdim=context_hdim,
                p_dropout=p_dropout,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                p_dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Output blocks
        self.output_blocks = nn.ModuleList([])
        for level, mult in tuple(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        p_dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    for _ in range(num_attn_blocks):
                        layers.append(
                            AttentionBlock(
                                ch,
                                num_heads=num_heads,
                                num_head_channels=num_head_channels,
                                encoder_channels=d_context,
                                dims=dims,
                                h_dim=h_dim // (level + 1),
                                encoder_hdim=context_hdim,
                                p_dropout=p_dropout,
                            )
                        )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            p_dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(UNetSequential(*layers))
                self._feature_size += ch

        # Final proj out
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        r"""Apply the model to an input batch.

        Args:
            x: an [N x C x ...] Tensor of inputs.
            timesteps: a 1-D batch of timesteps, i.e. [N].
            context: conditioning Tensor with shape of [N x ``d_context`` x ...],
                plugged in via cross attention.
            y: an [N] Tensor of labels, if **class-conditional**;
                an [N x ``d_emb`` x ...] Tensor, if **film-embed conditional**.

        Returns:
            an [N x C x ...] Tensor of outputs.
""" assert (y is None) or ( (y is not None) and ((self.num_classes > 0) or (self.use_extra_film is not None)) ), f"y must be specified if num_classes or use_extra_film is not None. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n" hs = [] emb = self.pos_enc(timesteps) emb = append_dims(emb, x.dim()) if self.num_classes > 0: assert y.size() == (x.size(0),) emb = emb + self.label_emb(y) elif self.use_extra_film is not None: assert y.size() == (x.size(0), self.d_emb, *x.size()[2:]) y = self.film_emb(y) if self.use_extra_film == "add": emb = emb + y elif self.use_extra_film == "concat": emb = torch.cat([emb, y], dim=1) h = x for module in self.input_blocks: h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: h = torch.cat([h, hs.pop()], dim=1) h = module(h, emb, context) return self.out(h) class UNetSequential(nn.Sequential): r"""A sequential module that passes embeddings to the children that support it.""" def forward(self, x, emb=None, context=None): for layer in self: if isinstance(layer, ResBlock): x = layer(x, emb) elif isinstance(layer, AttentionBlock): x = layer(x, context) else: x = layer(x) return x