Spaces:
Sleeping
Sleeping
import torch.nn.functional as F | |
import torch | |
from torch import nn | |
from einops import rearrange | |
from inspect import isfunction | |
import math | |
from tqdm import tqdm | |
def exists(x):
    """Report whether ``x`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    is_missing = x is None
    return not is_missing
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise the fallback ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    invoked with no arguments, so expensive fallbacks are only built on demand.
    Any other fallback object (including builtins and classes) is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
class Residual(nn.Module):
    """Skip connection that returns ``fn(x, *args, **kwargs) + x``.

    Args:
        fn: Callable (typically an ``nn.Module``) whose output must be
            broadcast-compatible with ``x``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
def Upsample(dim):
    """Build a channel-preserving 2x spatial upsampler.

    A ``ConvTranspose2d`` with kernel 4, stride 2 and padding 1 maps an
    ``H x W`` feature map to ``2H x 2W``.
    """
    kernel, stride, pad = 4, 2, 1
    return nn.ConvTranspose2d(dim, dim, kernel, stride, pad)
def Downsample(dim):
    """Build a channel-preserving 2x spatial downsampler.

    A ``Conv2d`` with kernel 4, stride 2 and padding 1 maps an ``H x W``
    feature map to ``floor(H/2) x floor(W/2)``.
    """
    kernel, stride, pad = 4, 2, 1
    return nn.Conv2d(dim, dim, kernel, stride, pad)
class SinusoidalPositionEmbeddings(nn.Module):
    """Return sinusoidal embedding for integer time step.

    With ``half_dim = dim // 2`` the output row for time ``t`` is
    ``[sin(t * f_0) .. sin(t * f_{half_dim-1}), cos(t * f_0) .. cos(t * f_{half_dim-1})]``
    where ``f_i = exp(-log(10000) * i / (half_dim - 1))``.

    Args:
        dim: Output width. Odd widths get one trailing zero column so the
            result always has exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a 1-D tensor of time steps into a ``(len(time), dim)`` tensor."""
        device = time.device
        half_dim = self.dim // 2
        # For dim 2 or 3, half_dim - 1 == 0 would divide by zero; a single
        # frequency of exp(0) == 1 is the natural value in that case.
        exponent = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -exponent)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover 2 * half_dim columns; pad the last one.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
class Block(nn.Module):
    """Conv3x3 -> GroupNorm -> optional ``x * (scale + 1) + shift`` -> SiLU.

    Args:
        dim: Input channels.
        dim_out: Output channels (``nn.GroupNorm`` requires it to be divisible
            by ``groups``).
        groups: Number of GroupNorm groups.
    """

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, kernel_size=3, padding=1)
        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """Apply the block; ``scale_shift`` is an optional ``(scale, shift)`` pair."""
        features = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            features = features * (scale + 1) + shift
        return self.act(features)
class ResnetBlock(nn.Module):
    """Two ``Block`` layers plus a residual path, with optional time conditioning.

    Residual-network reference: <https://arxiv.org/abs/1512.03385>.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        time_emb_dim: Width of the time embedding. When ``None`` no projection
            is built and any ``time_emb`` given to ``forward`` is ignored.
        groups: GroupNorm groups forwarded to both ``Block`` layers.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        if time_emb_dim is None:
            self.mlp = None
        else:
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # A 1x1 conv is only needed when the channel count changes.
        if dim == dim_out:
            self.res_conv = nn.Identity()
        else:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)

    def forward(self, x, time_emb=None):
        hidden = self.block1(x)
        if self.mlp is not None and time_emb is not None:
            # Project (b, c) time features to (b, c, 1, 1) and broadcast over H, W.
            hidden = rearrange(self.mlp(time_emb), "b c -> b c 1 1") + hidden
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
class ConvNextBlock(nn.Module):
    """Stack of [conv7x7 (+ condition(pos)) + norm + conv3x3 + act + norm + conv3x3 + res1x1], with positional embedding inserted.

    Args:
        dim: Input channels.
        dim_out: Output channels.
        time_emb_dim: Width of the time embedding. When ``None`` no projection
            is built and ``time_emb`` is ignored.
        mult: Channel expansion factor of the hidden 3x3 conv.
        norm: Whether to apply GroupNorm before the first 3x3 conv.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if time_emb_dim is not None
            else None
        )
        # groups=dim makes this a depthwise 7x7 convolution.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        """Apply the block to ``x``, optionally conditioned on ``time_emb``.

        Conditioning is optional even when the block was built with
        ``time_emb_dim``: a missing ``time_emb`` simply skips the addition.
        """
        h = self.ds_conv(x)
        if self.mlp is not None and time_emb is not None:
            condition = self.mlp(time_emb)
            # (b, c) -> (b, c, 1, 1) so the condition broadcasts over H and W.
            h = h + condition[:, :, None, None]
        h = self.net(h)
        return h + self.res_conv(x)
class PreNorm(nn.Module):
    """Normalize with a single-group ``GroupNorm`` over ``dim`` channels, then call ``fn``.

    Extra positional/keyword arguments to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(num_groups=1, num_channels=dim)

    def forward(self, x, *args, **kwargs):
        return self.fn(self.norm(x), *args, **kwargs)
class ConditionalEmbedding(nn.Module):
    """Return embedding for label and projection for text embedding.

    Args:
        num_labels: Label vocabulary size; only used for
            ``condition_type="instrument_family"``.
        embedding_dim: Output width. For ``"natural_language_prompt"`` it is
            also the required input width.
        condition_type: ``"instrument_family"`` (integer labels looked up in an
            ``nn.Embedding``) or ``"natural_language_prompt"`` (a learned
            ``nn.Linear`` projection; the input is presumably a precomputed
            text embedding — confirm against callers).

    Raises:
        NotImplementedError: For any other ``condition_type``.
    """

    def __init__(self, num_labels, embedding_dim, condition_type="instrument_family"):
        super().__init__()
        if condition_type == "instrument_family":
            self.embedding = nn.Embedding(num_labels, embedding_dim)
        elif condition_type == "natural_language_prompt":
            self.embedding = nn.Linear(embedding_dim, embedding_dim, bias=True)
        else:
            raise NotImplementedError(
                f"Unsupported condition_type {condition_type!r}; expected "
                "'instrument_family' or 'natural_language_prompt'"
            )

    def forward(self, labels):
        return self.embedding(labels)
class LinearCrossAttention(nn.Module):
    """Combination of efficient (linear) attention and cross attention.

    Queries, keys and values come from a 1x1 conv over the feature map. When a
    ``label_embedding`` is supplied it is projected into one extra key and one
    extra value per head, appended along the token axis, so the aggregated
    context mixes the spatial positions with the condition token.

    Args:
        dim: Channels of the input feature map.
        heads: Number of attention heads.
        label_emb_dim: Width of the conditioning vector.
        dim_head: Channels per head.
    """

    def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
        super().__init__()
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(inner, dim, 1), nn.GroupNorm(1, dim))
        # Projections of the condition vector into one extra key / value token.
        self.label_key = nn.Linear(label_emb_dim, inner)
        self.label_value = nn.Linear(label_emb_dim, inner)

    def forward(self, x, label_embedding=None):
        batch, _, height, width = x.shape

        def split_heads(t):
            return rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim=1))
        if label_embedding is not None:
            token_shape = (batch, self.heads, self.dim_head, 1)
            k = torch.cat([k, self.label_key(label_embedding).view(token_shape)], dim=-1)
            v = torch.cat([v, self.label_value(label_embedding).view(token_shape)], dim=-1)
        # Softmax over channels for queries and over tokens for keys.
        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width)
        return self.to_out(out)
def pad_to_match(encoder_tensor, decoder_tensor):
    """
    Pad (or crop) decoder_tensor so its spatial dimensions equal encoder_tensor's.

    Every dimension after the first two (batch, channel) is treated as spatial,
    so 1-D, 2-D and 3-D feature maps are all handled. For each spatial
    dimension the size difference ``delta`` is split as ``delta // 2`` on the
    leading side and the remainder on the trailing side. A negative ``delta``
    (decoder larger than encoder) crops, since ``F.pad`` treats negative
    padding as cropping.

    :param encoder_tensor: Reference feature map from the encoder, shape (N, C, *spatial).
    :param decoder_tensor: Feature map from the decoder to resize, shape (N, C', *spatial').
    :return: decoder_tensor zero-padded/cropped to encoder_tensor's spatial size.
    :raises ValueError: If the two tensors have different numbers of dimensions.
    """
    if encoder_tensor.dim() != decoder_tensor.dim():
        raise ValueError(
            "encoder_tensor and decoder_tensor must have the same number of "
            f"dimensions, got {encoder_tensor.dim()} and {decoder_tensor.dim()}"
        )
    padding = []
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    for enc_size, dec_size in zip(
        reversed(encoder_tensor.shape[2:]), reversed(decoder_tensor.shape[2:])
    ):
        delta = enc_size - dec_size
        before = delta // 2
        padding.extend((before, delta - before))
    return F.pad(decoder_tensor, tuple(padding))
def pad_and_concat(encoder_tensor, decoder_tensor):
    """
    Resize decoder_tensor to encoder_tensor's spatial size, then stack both along channels.

    :param encoder_tensor: Encoder (skip-connection) feature map; its channels come first.
    :param decoder_tensor: Decoder feature map, resized with ``pad_to_match`` before concatenation.
    :return: Tensor with ``C_enc + C_dec`` channels and encoder_tensor's spatial size.
    """
    return torch.cat([encoder_tensor, pad_to_match(encoder_tensor, decoder_tensor)], dim=1)
class LinearCrossAttentionAdd(nn.Module):
    """Linear attention whose queries and keys are shifted by a projected condition.

    Queries, keys and values come from a 1x1 conv over the feature map. When a
    ``condition`` vector is given, it is projected to one ``dim_head`` offset
    per head for the keys and one for the queries. These offsets are added at
    every spatial position before the softmax; values are left unconditioned.

    Args:
        dim: Channels of the input feature map.
        heads: Number of attention heads.
        label_emb_dim: Width of the conditioning vector.
        dim_head: Channels per head.
    """

    def __init__(self, dim, heads=4, label_emb_dim=128, dim_head=32):
        super().__init__()
        self.dim = dim
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.label_emb_dim = label_emb_dim
        self.hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(self.dim, self.hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(self.hidden_dim, dim, 1), nn.GroupNorm(1, dim))
        # Projections of the condition vector that are added to keys and queries.
        self.label_key = nn.Linear(label_emb_dim, self.hidden_dim)
        self.label_query = nn.Linear(label_emb_dim, self.hidden_dim)

    def forward(self, x, condition=None):
        b, _, h, w = x.shape
        # (b, heads * dim_head, h, w) -> (b, heads, dim_head, h * w)
        q, k, v = (
            t.reshape(b, self.heads, self.dim_head, h * w)
            for t in self.to_qkv(x).chunk(3, dim=1)
        )
        # If a condition is given, add its key/query offsets (broadcast over all positions).
        if condition is not None:
            label_k = self.label_key(condition).view(b, self.heads, self.dim_head, 1)
            label_q = self.label_query(condition).view(b, self.heads, self.dim_head, 1)
            k = k + label_k
            q = q + label_q
        # Softmax over channels for queries and over positions for keys.
        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)
        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        # (b, heads, dim_head, h * w) -> (b, heads * dim_head, h, w)
        out = out.reshape(b, self.hidden_dim, h, w)
        return self.to_out(out)
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return ``timesteps`` betas spaced linearly from ``beta_start`` to ``beta_end``.

    Args:
        timesteps: Number of diffusion steps (length of the returned tensor).
        beta_start: Beta at the first step (both endpoints are inclusive).
        beta_end: Beta at the last step.

    Returns:
        1-D tensor produced by ``torch.linspace``.
    """
    return torch.linspace(beta_start, beta_end, timesteps)
def get_beta_schedule(timesteps):
    """Precompute diffusion coefficients from a linear beta schedule.

    Args:
        timesteps: Number of diffusion steps.

    Returns:
        Tuple ``(sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
        posterior_variance, sqrt_recip_alphas)`` of 1-D tensors indexed by
        timestep.
    """
    betas = linear_beta_schedule(timesteps=timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shift right by one step; the cumulative product "before" step 0 is defined as 1.
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

    # Coefficients for q(x_t | x_0) and the reverse-step scale.
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

    # Variance of the posterior q(x_{t-1} | x_t, x_0); it is 0 at t = 0 because
    # alphas_cumprod_prev[0] == 1.
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    return (
        sqrt_alphas_cumprod,
        sqrt_one_minus_alphas_cumprod,
        posterior_variance,
        sqrt_recip_alphas,
    )
def extract(a, t, x_shape):
    """Look up per-sample coefficients ``a[t]`` and shape them for broadcasting.

    Args:
        a: 1-D tensor of per-timestep coefficients (``gather`` along the last
            dim with a 1-D index requires ``a`` to be 1-D).
        t: Integer (int64) tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1``
        trailing singleton dims, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # Gather on a's device: moving the index to CPU (as before) fails whenever
    # the schedule tensor lives on a GPU.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# forward diffusion
def q_sample(x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
    """Noise ``x_start`` to timestep ``t`` in a single step.

    Computes ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``,
    with the per-sample coefficients looked up via ``extract`` and broadcast
    over the non-batch dimensions of ``x_start``.

    Args:
        x_start: Clean input batch.
        t: Per-sample integer timesteps.
        sqrt_alphas_cumprod: Coefficient table for the signal term.
        sqrt_one_minus_alphas_cumprod: Coefficient table for the noise term.
        noise: Optional noise tensor; drawn with ``torch.randn_like(x_start)`` when omitted.
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    signal_coef = extract(sqrt_alphas_cumprod, t, x_start.shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_coef * x_start + noise_coef * noise