import torch.nn.functional as F import torch from torch import nn from einops import rearrange from inspect import isfunction import math from tqdm import tqdm def exists(x): """Return true for x is not None.""" return x is not None def default(val, d): """Helper function""" if exists(val): return val return d() if isfunction(d) else d class Residual(nn.Module): """Skip connection""" def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x def Upsample(dim): """Upsample layer, a transposed convolution layer with stride=2""" return nn.ConvTranspose2d(dim, dim, 4, 2, 1) def Downsample(dim): """Downsample layer, a convolution layer with stride=2""" return nn.Conv2d(dim, dim, 4, 2, 1) class SinusoidalPositionEmbeddings(nn.Module): """Return sinusoidal embedding for integer time step.""" def __init__(self, dim): super().__init__() self.dim = dim def forward(self, time): device = time.device half_dim = self.dim // 2 embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = time[:, None] * embeddings[None, :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) return embeddings class Block(nn.Module): """Stack of convolution, normalization, and non-linear activation""" def __init__(self, dim, dim_out, groups=8): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift=None): x = self.proj(x) x = self.norm(x) if exists(scale_shift): scale, shift = scale_shift x = x * (scale + 1) + shift x = self.act(x) return x class ResnetBlock(nn.Module): """Stack of [conv + norm + act (+ scale&shift)], with positional embedding inserted """ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): super().__init__() self.mlp = ( nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if exists(time_emb_dim) else None ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None): h = self.block1(x) if exists(self.mlp) and exists(time_emb): time_emb = self.mlp(time_emb) # Adding positional embedding to intermediate layer (by broadcasting along spatial dimension) h = rearrange(time_emb, "b c -> b c 1 1") + h h = self.block2(h) return h + self.res_conv(x) class ConvNextBlock(nn.Module): """Stack of [conv7x7 (+ condition(pos)) + norm + conv3x3 + act + norm + conv3x3 + res1x1],with positional embedding inserted""" def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True): super().__init__() self.mlp = ( nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim)) if exists(time_emb_dim) else None ) self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim) self.net = nn.Sequential( nn.GroupNorm(1, dim) if norm else nn.Identity(), nn.Conv2d(dim, dim_out * mult, 3, padding=1), nn.GELU(), nn.GroupNorm(1, dim_out * mult), nn.Conv2d(dim_out * mult, dim_out, 3, padding=1), ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None): h = self.ds_conv(x) if exists(self.mlp) and exists(time_emb): assert exists(time_emb), "time embedding must be passed in" condition = self.mlp(time_emb) h = h + rearrange(condition, "b c -> b c 1 1") h = self.net(h) return h + self.res_conv(x) class PreNorm(nn.Module): """Apply normalization before 'fn'""" def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = nn.GroupNorm(1, dim) def forward(self, x, *args, **kwargs): x = self.norm(x) return self.fn(x, *args, **kwargs) class ConditionalEmbedding(nn.Module): """Return embedding for label and projection for text embedding""" def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"): super(ConditionalEmbedding, self).__init__() if condition_type == "instrument_family": self.embedding = nn.Embedding(num_labels, embedding_dim) elif condition_type == "natural_language_prompt": self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True) else: raise NotImplementedError() def forward(self, labels): return self.embedding(labels) class LinearCrossAttention(nn.Module): """Combination of efficient attention and cross attention.""" def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32): super().__init__() self.dim_head = dim_head self.scale = dim_head ** -0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim)) # embedding for key and value self.label_key = nn.Linear(label_emb_dim, hidden_dim) self.label_value = nn.Linear(label_emb_dim, hidden_dim) def forward(self, x, label_embedding=None): b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv ) if label_embedding is not None: label_k = self.label_key(label_embedding).view(b, self.heads, self.dim_head, 1) label_v = self.label_value(label_embedding).view(b, self.heads, self.dim_head, 1) k = torch.cat([k, label_k], dim=-1) v = torch.cat([v, label_v], dim=-1) q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum("b h d n, b h e n -> b h d e", k, v) out = torch.einsum("b h d e, b h d n -> b h e n", context, q) out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) def pad_to_match(encoder_tensor, decoder_tensor): """ Pads the decoder_tensor to match the spatial dimensions of encoder_tensor. :param encoder_tensor: The feature map from the encoder. :param decoder_tensor: The feature map from the decoder that needs to be upsampled. :return: Padded decoder_tensor with the same spatial dimensions as encoder_tensor. """ enc_shape = encoder_tensor.shape[2:] # spatial dimensions are at index 2 and 3 dec_shape = decoder_tensor.shape[2:] # assume enc_shape >= dec_shape delta_w = enc_shape[1] - dec_shape[1] delta_h = enc_shape[0] - dec_shape[0] # padding padding_left = delta_w // 2 padding_right = delta_w - padding_left padding_top = delta_h // 2 padding_bottom = delta_h - padding_top decoder_tensor_padded = F.pad(decoder_tensor, (padding_left, padding_right, padding_top, padding_bottom)) return decoder_tensor_padded def pad_and_concat(encoder_tensor, decoder_tensor): """ Pads the decoder_tensor and concatenates it with the encoder_tensor along the channel dimension. :param encoder_tensor: The feature map from the encoder. :param decoder_tensor: The feature map from the decoder that needs to be concatenated with encoder_tensor. :return: Concatenated tensor. """ # pad decoder_tensor decoder_tensor_padded = pad_to_match(encoder_tensor, decoder_tensor) # concat encoder_tensor and decoder_tensor_padded concatenated_tensor = torch.cat((encoder_tensor, decoder_tensor_padded), dim=1) return concatenated_tensor class LinearCrossAttentionAdd(nn.Module): def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32): super().__init__() self.dim = dim self.dim_head = dim_head self.scale = dim_head ** -0.5 self.heads = heads self.label_emb_dim = label_emb_dim self.dim_head = dim_head self.hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False) self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim)) # embedding for key and value self.label_key = nn.Linear(label_emb_dim, self.hidden_dim) self.label_query = nn.Linear(label_emb_dim, self.hidden_dim) def forward(self, x, condition=None): b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv ) # if condition exists,concat its key and value with origin if condition is not None: label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1) label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1) k = k + label_k q = q + label_q q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum("b h d n, b h e n -> b h d e", k, v) out = torch.einsum("b h d e, b h d n -> b h e n", context, q) out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) def linear_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start, beta_end, timesteps) def get_beta_schedule(timesteps): betas = linear_beta_schedule(timesteps=timesteps) # define alphas alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # calculations for diffusion q(x_t | x_{t-1}) and others sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) return sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, posterior_variance, sqrt_recip_alphas def extract(a, t, x_shape): batch_size = t.shape[0] out = a.gather(-1, t.cpu()) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) # forward diffusion def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None): if noise is None: noise = torch.randn_like(x_start) sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = extract( sqrt_one_minus_alphas_cumprod, t, x_start.shape ) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise