import jax import jax.numpy as jnp from jax import random import flax import flax.linen as nn from typing import Any, Tuple import functools import numpy as np import torch from torch.utils.data import TensorDataset key = random.PRNGKey(0) dataset = [] with np.load('spectograms.npz') as data: for file in data.files: dataset.append(data[file]) dataset = np.stack(dataset) dataset = np.expand_dims(dataset, axis=3) dataset = TensorDataset(torch.from_numpy(dataset)) # The following code is copied with minor modifications from https://colab.research.google.com/drive/1SeXMpILhkJPjXUaesvzEhc3Ke6Zl_zxJ?usp=sharing class GaussianFourierProjection(nn.Module): """Gaussian random features for encoding time steps.""" embed_dim: int scale: float = 30. @nn.compact def __call__(self, x): # Randomly sample weights during initialization. These weights are fixed # during optimization and are not trainable. W = self.param('W', jax.nn.initializers.normal(stddev=self.scale), (self.embed_dim // 2, )) W = jax.lax.stop_gradient(W) x_proj = x[:, None] * W[None, :] * 2 * jnp.pi return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1) class Dense(nn.Module): """A fully connected layer that reshapes outputs to feature maps.""" output_dim: int @nn.compact def __call__(self, x): return nn.Dense(self.output_dim)(x)[:, None, None, :] class ScoreNet(nn.Module): """A time-dependent score-based model built upon U-Net architecture. Args: marginal_prob_std: A function that takes time t and gives the standard deviation of the perturbation kernel p_{0t}(x(t) | x(0)). channels: The number of channels for feature maps of each resolution. embed_dim: The dimensionality of Gaussian random feature embeddings. """ marginal_prob_std: Any channels: Tuple[int] = (32, 64, 128, 256) embed_dim: int = 256 @nn.compact def __call__(self, x, t): # The swish activation function act = nn.swish # Obtain the Gaussian random feature embedding for t embed = act(nn.Dense(self.embed_dim)( GaussianFourierProjection(embed_dim=self.embed_dim)(t))) # Encoding path h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID', use_bias=False)(x) # print('h1', h1.shape)#26x311 ## Incorporate information from t h1 += Dense(self.channels[0])(embed) ## Group normalization h1 = nn.GroupNorm(4)(h1) h1 = act(h1) h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID', use_bias=False)(h1) # print('h2', h2.shape)#12x155 h2 += Dense(self.channels[1])(embed) h2 = nn.GroupNorm()(h2) h2 = act(h2) h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID', use_bias=False)(h2) # print('h3', h3.shape)#5x77 h3 += Dense(self.channels[2])(embed) h3 = nn.GroupNorm()(h3) h3 = act(h3) h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID', use_bias=False)(h3) # print('h4', h4.shape)#2x38 h4 += Dense(self.channels[3])(embed) h4 = nn.GroupNorm()(h4) h4 = act(h4) # Decoding path h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)), input_dilation=(2, 2), use_bias=False)(h4) # print('h', h.shape)#5x77 ## Skip connection from the encoding path h += Dense(self.channels[2])(embed) h = nn.GroupNorm()(h) h = act(h) h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 2)), input_dilation=(2, 2), use_bias=False)( jnp.concatenate([h, h3], axis=-1) ) # print('h', h.shape)#12x155 h += Dense(self.channels[1])(embed) h = nn.GroupNorm()(h) h = act(h) h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 2)), input_dilation=(2, 2), use_bias=False)( jnp.concatenate([h, h2], axis=-1) ) # print('h', h.shape)#26x311 h += Dense(self.channels[0])(embed) h = nn.GroupNorm()(h) h = act(h) h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))( jnp.concatenate([h, h1], axis=-1) ) # print('h', h.shape)#28x313 # Normalize output h = h / self.marginal_prob_std(t)[:, None, None, None] return h def marginal_prob_std(t, sigma): """Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$. Args: t: A vector of time steps. sigma: The $\sigma$ in our SDE. Returns: The standard deviation. """ return jnp.sqrt((sigma**(2 * t) - 1.) / 2. / jnp.log(sigma)) def diffusion_coeff(t, sigma): """Compute the diffusion coefficient of our SDE. Args: t: A vector of time steps. sigma: The $\sigma$ in our SDE. Returns: The vector of diffusion coefficients. """ return sigma**t sigma = 25.0#@param {'type':'number'} marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma) diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma) def loss_fn(rng, model, params, x, marginal_prob_std, eps=1e-5): """The loss function for training score-based generative models. Args: model: A `flax.linen.Module` object that represents the structure of the score-based model. params: A dictionary that contains all trainable parameters. x: A mini-batch of training data. marginal_prob_std: A function that gives the standard deviation of the perturbation kernel. eps: A tolerance value for numerical stability. """ rng, step_rng = jax.random.split(rng) random_t = jax.random.uniform(step_rng, (x.shape[0],), minval=eps, maxval=1.) rng, step_rng = jax.random.split(rng) z = jax.random.normal(step_rng, x.shape) std = marginal_prob_std(random_t) perturbed_x = x + z * std[:, None, None, None] score = model.apply(params, perturbed_x, random_t) loss = jnp.mean(jnp.sum((score * std[:, None, None, None] + z)**2, axis=(1,2,3))) return loss def get_train_step_fn(model, marginal_prob_std): """Create a one-step training function. Args: model: A `flax.linen.Module` object that represents the structure of the score-based model. marginal_prob_std: A function that gives the standard deviation of the perturbation kernel. Returns: A function that runs one step of training. """ val_and_grad_fn = jax.value_and_grad(loss_fn, argnums=2) def step_fn(rng, x, optimizer): params = optimizer.target loss, grad = val_and_grad_fn(rng, model, params, x, marginal_prob_std) mean_grad = jax.lax.pmean(grad, axis_name='device') mean_loss = jax.lax.pmean(loss, axis_name='device') new_optimizer = optimizer.apply_gradient(mean_grad) return mean_loss, new_optimizer return jax.pmap(step_fn, axis_name='device') #@title Training (double click to expand or collapse) import torch import functools import flax from flax.serialization import to_bytes, from_bytes import tensorflow as tf from torch.utils.data import DataLoader import torchvision.transforms as transforms from torchvision.datasets import MNIST import tqdm n_epochs = 500#@param {'type':'integer'} ## size of a mini-batch batch_size = 512#@param {'type':'integer'} ## learning rate lr=1e-3 #@param {'type':'number'} rng = jax.random.PRNGKey(0) fake_input = jnp.ones((batch_size, 28, 313, 1)) fake_time = jnp.ones(batch_size) score_model = ScoreNet(marginal_prob_std_fn) params = score_model.init({'params': rng}, fake_input, fake_time) # dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) optimizer = flax.optim.Adam(learning_rate=lr).create(params) train_step_fn = get_train_step_fn(score_model, marginal_prob_std_fn) tqdm_epoch = tqdm.notebook.trange(n_epochs) assert batch_size % jax.local_device_count() == 0 data_shape = (jax.local_device_count(), -1, 28, 313, 1) optimizer = flax.jax_utils.replicate(optimizer) for epoch in tqdm_epoch: avg_loss = 0. num_items = 0 for x in data_loader: x = x[0] x = x.numpy().reshape(data_shape) rng, *step_rng = jax.random.split(rng, jax.local_device_count() + 1) step_rng = jnp.asarray(step_rng) loss, optimizer = train_step_fn(step_rng, x, optimizer) loss = flax.jax_utils.unreplicate(loss) avg_loss += loss.item() * x.shape[0] num_items += x.shape[0] # Print the averaged training loss so far. tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items)) # Update the checkpoint after each epoch of training. with tf.io.gfile.GFile('ckpt.flax', 'wb') as fout: fout.write(to_bytes(flax.jax_utils.unreplicate(optimizer)))