import argparse
import functools
import random
from typing import Any, Tuple

import IPython.display as ipd
import flax
import flax.jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.tools.colab_tpu
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile
import tensorflow as tf
import torch
import torchvision.transforms as transforms
import tqdm
import tqdm.auto
from flax.serialization import from_bytes, to_bytes
from scipy import integrate
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST
from torchvision.utils import make_grid

parser = argparse.ArgumentParser()
parser.add_argument('--sigma', type=float, default=25.0)
parser.add_argument('--n_epochs', type=int, default=500)
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--lr', type=float, default=1e-2)
parser.add_argument('--num_steps', type=int, default=500)
parser.add_argument('--pc_num_steps', type=int, default=500)
parser.add_argument('--signal_to_noise_ratio', type=float, default=0.16)
parser.add_argument('--etol', type=float, default=1e-5)
parser.add_argument('--sample_batch_size', type=int, default=64)
parser.add_argument('--sample_no', type=int, default=25)
args = parser.parse_args(args=[])  # required for colab


class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps."""
  embed_dim: int
  scale: float = 30.

  @nn.compact
  def __call__(self, x):
    # Weights are sampled once at initialization. stop_gradient keeps them
    # fixed during optimization, so they are effectively not trainable.
    W = self.param('W', jax.nn.initializers.normal(stddev=self.scale),
                   (self.embed_dim // 2, ))
    W = jax.lax.stop_gradient(W)
    x_proj = x[:, None] * W[None, :] * 2 * jnp.pi
    return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)


class Dense(nn.Module):
  """A fully connected layer whose output is reshaped to (batch, 1, 1, features)
  so it broadcasts over the spatial dimensions of a feature map."""
  output_dim: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.output_dim)(x)[:, None, None, :]


class ScoreNet(nn.Module):
  """A time-dependent score-based model built upon U-Net architecture.

  The paddings in the decoding path are chosen so that a (28, 313, 1) input
  yields a (28, 313, 1) output.

  Args:
    marginal_prob_std: A function that takes time t and gives the standard
      deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
    channels: The number of channels for feature maps of each resolution.
    embed_dim: The dimensionality of Gaussian random feature embeddings.
  """
  marginal_prob_std: Any
  channels: Tuple[int] = (32, 64, 128, 256)
  embed_dim: int = 256

  # NOTE: Flax auto-names submodules in construction order (Conv_0, Dense_0,
  # ...). Do not reorder layer construction, or saved checkpoints will no
  # longer match the parameter tree.
  @nn.compact
  def __call__(self, x, t):
    # The swish activation function
    act = nn.swish
    # Obtain the Gaussian random feature embedding for t
    embed = act(nn.Dense(self.embed_dim)(
        GaussianFourierProjection(embed_dim=self.embed_dim)(t)))

    # Encoding path
    h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
                 use_bias=False)(x)
    ## Incorporate information from t
    h1 += Dense(self.channels[0])(embed)
    ## Group normalization
    h1 = nn.GroupNorm(4)(h1)
    h1 = act(h1)
    h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
                 use_bias=False)(h1)
    h2 += Dense(self.channels[1])(embed)
    h2 = nn.GroupNorm()(h2)
    h2 = act(h2)
    h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
                 use_bias=False)(h2)
    h3 += Dense(self.channels[2])(embed)
    h3 = nn.GroupNorm()(h3)
    h3 = act(h3)
    h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
                 use_bias=False)(h3)
    h4 += Dense(self.channels[3])(embed)
    h4 = nn.GroupNorm()(h4)
    h4 = act(h4)

    # Decoding path: transposed convolutions expressed as input-dilated convs.
    h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                input_dilation=(2, 2), use_bias=False)(h4)
    h += Dense(self.channels[2])(embed)
    h = nn.GroupNorm()(h)
    h = act(h)
    ## Skip connections from the encoding path are concatenated on channels.
    h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                input_dilation=(2, 2), use_bias=False)(
                    jnp.concatenate([h, h3], axis=-1)
                )
    h += Dense(self.channels[1])(embed)
    h = nn.GroupNorm()(h)
    h = act(h)
    h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                input_dilation=(2, 2), use_bias=False)(
                    jnp.concatenate([h, h2], axis=-1)
                )
    h += Dense(self.channels[0])(embed)
    h = nn.GroupNorm()(h)
    h = act(h)
    h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
        jnp.concatenate([h, h1], axis=-1)
    )

    # Normalize output
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h


def marginal_prob_std(t, sigma):
  r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The standard deviation.
  """
  return jnp.sqrt((sigma**(2 * t) - 1.) / 2. / jnp.log(sigma))


def diffusion_coeff(t, sigma):
  r"""Compute the diffusion coefficient of our SDE.

  Args:
    t: A vector of time steps.
    sigma: The $\sigma$ in our SDE.

  Returns:
    The vector of diffusion coefficients.
  """
  return sigma**t


def loss_fn(rng, model, params, x, marginal_prob_std, eps=1e-5):
  """The denoising score-matching loss for score-based generative models.

  Args:
    rng: A JAX random state used to sample time steps and noise.
    model: A `flax.linen.Module` object that represents the structure of
      the score-based model.
    params: A dictionary that contains all trainable parameters.
    x: A mini-batch of training data.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    eps: A tolerance value for numerical stability.

  Returns:
    The scalar loss, averaged over the batch.
  """
  rng, step_rng = jax.random.split(rng)
  random_t = jax.random.uniform(step_rng, (x.shape[0],), minval=eps, maxval=1.)
  rng, step_rng = jax.random.split(rng)
  z = jax.random.normal(step_rng, x.shape)
  std = marginal_prob_std(random_t)
  perturbed_x = x + z * std[:, None, None, None]
  score = model.apply(params, perturbed_x, random_t)
  loss = jnp.mean(jnp.sum((score * std[:, None, None, None] + z)**2,
                          axis=(1, 2, 3)))
  return loss


def get_train_step_fn(model, marginal_prob_std):
  """Create a one-step, data-parallel (pmapped) training function.

  Args:
    model: A `flax.linen.Module` object that represents the structure of
      the score-based model.
    marginal_prob_std: A function that gives the standard deviation of the
      perturbation kernel.

  Returns:
    A function `(rng, x, optimizer) -> (mean_loss, new_optimizer)` that runs
    one step of training; all arguments carry a leading device axis.
  """
  val_and_grad_fn = jax.value_and_grad(loss_fn, argnums=2)

  def step_fn(rng, x, optimizer):
    params = optimizer.target
    loss, grad = val_and_grad_fn(rng, model, params, x, marginal_prob_std)
    # Average gradients and loss across devices before the update.
    mean_grad = jax.lax.pmean(grad, axis_name='device')
    mean_loss = jax.lax.pmean(loss, axis_name='device')
    new_optimizer = optimizer.apply_gradient(mean_grad)
    return mean_loss, new_optimizer

  return jax.pmap(step_fn, axis_name='device')


def score_fn(score_model, params, x, t):
  """Evaluate the score model on samples `x` at times `t`."""
  return score_model.apply(params, x, t)


# Device-parallel score evaluation used by all samplers. The model is a static
# (hashable Flax module) argument; params are broadcast to every device via
# in_axes=None rather than passed as a static argument, which would require
# hashing the array-valued parameter tree. x and t are sharded on axis 0.
pmap_score_fn = jax.pmap(score_fn, in_axes=(None, None, 0, 0),
                         static_broadcasted_argnums=(0,))


def Euler_Maruyama_sampler(rng,
                           score_model,
                           params,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           num_steps=args.num_steps,
                           eps=1e-3):
  """Generate samples from score-based models with the Euler-Maruyama solver.

  Args:
    rng: A JAX random state.
    score_model: A `flax.linen.Module` object that represents the architecture
      of a score-based model.
    params: A dictionary that contains the model parameters.
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    batch_size: The number of samples to generate by calling this function
      once. Split evenly across local devices.
    num_steps: The number of sampling steps.
      Equivalent to the number of discretized time steps.
    eps: The smallest time step for numerical stability.

  Returns:
    Samples of shape (devices, batch_size // devices, 28, 313, 1).
  """
  rng, step_rng = jax.random.split(rng)
  time_shape = (jax.local_device_count(), batch_size // jax.local_device_count())
  sample_shape = time_shape + (28, 313, 1)
  init_x = jax.random.normal(step_rng, sample_shape) * marginal_prob_std(1.)
  time_steps = jnp.linspace(1., eps, num_steps)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  for time_step in tqdm.auto.tqdm(time_steps):
    batch_time_step = jnp.ones(time_shape) * time_step
    g = diffusion_coeff(time_step)
    mean_x = x + (g**2) * pmap_score_fn(score_model, params, x,
                                        batch_time_step) * step_size
    rng, step_rng = jax.random.split(rng)
    x = mean_x + jnp.sqrt(step_size) * g * jax.random.normal(step_rng, x.shape)
  # Do not include any noise in the last sampling step.
  return mean_x


def pc_sampler(rng,
               score_model,
               params,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=args.num_steps,
               snr=args.signal_to_noise_ratio,
               eps=1e-3):
  """Generate samples from score-based models with Predictor-Corrector method.

  Args:
    rng: A JAX random state.
    score_model: A `flax.linen.Module` that represents the
      architecture of the score-based model.
    params: A dictionary that contains the parameters of the score-based model.
    marginal_prob_std: A function that gives the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that gives the diffusion coefficient
      of the SDE.
    batch_size: The number of samples to generate by calling this function
      once. Split evenly across local devices.
    num_steps: The number of sampling steps.
      Equivalent to the number of discretized time steps.
    snr: Target signal-to-noise ratio for the Langevin corrector step.
    eps: The smallest time step for numerical stability.

  Returns:
    Samples of shape (devices, batch_size // devices, 28, 313, 1).
  """
  time_shape = (jax.local_device_count(), batch_size // jax.local_device_count())
  sample_shape = time_shape + (28, 313, 1)
  rng, step_rng = jax.random.split(rng)
  init_x = jax.random.normal(step_rng, sample_shape) * marginal_prob_std(1.)
  time_steps = jnp.linspace(1., eps, num_steps)
  step_size = time_steps[0] - time_steps[1]
  x = init_x
  for time_step in tqdm.auto.tqdm(time_steps):
    batch_time_step = jnp.ones(time_shape) * time_step
    # Corrector step (Langevin MCMC)
    grad = pmap_score_fn(score_model, params, x, batch_time_step)
    # Per-sample gradient norm, averaged over devices and samples.
    grad_norm = jnp.linalg.norm(grad.reshape(sample_shape[0], sample_shape[1], -1),
                                axis=-1).mean()
    # Per-sample noise norm: skip both the device axis and the per-device
    # batch axis so it is on the same scale as grad_norm.
    noise_norm = np.sqrt(np.prod(x.shape[2:]))
    langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
    rng, step_rng = jax.random.split(rng)
    z = jax.random.normal(step_rng, x.shape)
    x = x + langevin_step_size * grad + jnp.sqrt(2 * langevin_step_size) * z

    # Predictor step (Euler-Maruyama)
    g = diffusion_coeff(time_step)
    score = pmap_score_fn(score_model, params, x, batch_time_step)
    x_mean = x + (g**2) * score * step_size
    rng, step_rng = jax.random.split(rng)
    z = jax.random.normal(step_rng, x.shape)
    x = x_mean + jnp.sqrt(g**2 * step_size) * z

  # The last step does not include any noise
  return x_mean


def ode_sampler(rng,
                score_model,
                params,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=args.etol,
                rtol=args.etol,
                z=None,
                eps=1e-3):
  """Generate samples from score-based models with black-box ODE solvers.

  Args:
    rng: A JAX random state.
    score_model: A `flax.linen.Module` object that represents architecture
      of the score-based model.
    params: A dictionary that contains model parameters.
    marginal_prob_std: A function that returns the standard deviation
      of the perturbation kernel.
    diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of samples to generate by calling this function
      once. Split evenly across local devices.
    atol: Tolerance of absolute errors.
    rtol: Tolerance of relative errors.
    z: The latent code that governs the final sample. If None, we start from p_1;
      otherwise, we start from the given z.
    eps: The smallest time step for numerical stability.

  Returns:
    Samples with the same shape as the initial state.
  """
  time_shape = (jax.local_device_count(), batch_size // jax.local_device_count())
  sample_shape = time_shape + (28, 313, 1)
  # Create the latent code
  if z is None:
    rng, step_rng = jax.random.split(rng)
    z = jax.random.normal(step_rng, sample_shape)
    init_x = z * marginal_prob_std(1.)
  else:
    init_x = z

  shape = init_x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper of the score-based model for use by the ODE solver."""
    sample = jnp.asarray(sample, dtype=jnp.float32).reshape(sample_shape)
    time_steps = jnp.asarray(time_steps).reshape(time_shape)
    score = pmap_score_fn(score_model, params, sample, time_steps)
    # scipy's solver works on flat float64 vectors.
    return np.asarray(score).reshape((-1,)).astype(np.float64)

  def ode_func(t, x):
    """The probability-flow ODE function for use by the ODE solver."""
    time_steps = np.ones(time_shape) * t
    g = diffusion_coeff(t)
    return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)

  # Run the black-box ODE solver backwards in time, from t=1 to t=eps.
  res = integrate.solve_ivp(ode_func, (1., eps), np.asarray(init_x).reshape(-1),
                            rtol=rtol, atol=atol, method='RK45')
  print(f"Number of function evaluations: {res.nfev}")
  x = jnp.asarray(res.y[:, -1]).reshape(shape)

  return x


def noise_removal(sample, threshold=-35.0):
  """Convert an amplitude spectrogram to dB and floor quiet bins.

  Args:
    sample: An amplitude spectrogram (array-like).
    threshold: Bins at or below this dB level (relative to the maximum) are
      set to -80 dB.

  Returns:
    A tuple `(DB, DB_noise_removed)` of the full and thresholded dB arrays.
  """
  p = np.array(sample)
  DB = librosa.amplitude_to_db(p, ref=np.max)
  DB_noise_removed = np.where(DB > threshold, DB, -80)
  return DB, DB_noise_removed


def audio(sample, noise_threshold=-35.0):
  """Invert a (mel) spectrogram sample to a 16 kHz waveform after denoising.

  Args:
    sample: An amplitude mel spectrogram.
    noise_threshold: dB threshold passed to `noise_removal`.

  Returns:
    The reconstructed audio signal as a 1-D numpy array.
  """
  sampling_rate = 16000
  _, call_wo_noise = noise_removal(sample, threshold=noise_threshold)
  call_wo_noise = librosa.db_to_amplitude(call_wo_noise)
  back_audio = librosa.feature.inverse.mel_to_audio(call_wo_noise, sr=sampling_rate)
  return back_audio


def viz(sample, sampling_rate=16000):
  """Display a (mel) spectrogram sample on a dB scale.

  Args:
    sample: An amplitude mel spectrogram, as consumed by `audio`.
    sampling_rate: Sampling rate used to label the axes.
  """
  spec_db = librosa.amplitude_to_db(np.asarray(sample), ref=np.max)
  plt.figure(figsize=(10, 3))
  librosa.display.specshow(spec_db, sr=sampling_rate, x_axis='time', y_axis='mel')
  plt.colorbar(format='%+2.0f dB')
  plt.show()


def train(rng, dataset, score_model, marginal_prob_std_fn,
          n_epochs=args.n_epochs, batch_size=args.batch_size, lr=args.lr,
          ckpt_path='ckpt.flax'):
  """Train `score_model` with data parallelism, checkpointing every epoch.

  Args:
    rng: A JAX random state.
    dataset: A torch dataset (e.g. `TensorDataset`) whose DataLoader batches
      are sequences with the data tensor first. Each example is assumed to
      reshape to (28, 313, 1) -- TODO confirm against the actual data.
    score_model: The `ScoreNet` to train.
    marginal_prob_std_fn: The perturbation-kernel standard deviation function.
    n_epochs: Number of passes over the dataset.
    batch_size: Global batch size; must be divisible by the device count.
    lr: Adam learning rate.
    ckpt_path: Where the serialized optimizer is written after each epoch.

  Returns:
    The trained, unreplicated optimizer.

  Raises:
    ValueError: If `batch_size` is not divisible by the local device count.
  """
  n_devices = jax.local_device_count()
  if batch_size % n_devices != 0:
    raise ValueError(
        f'batch_size ({batch_size}) must be divisible by the number of '
        f'local devices ({n_devices}).')

  rng, init_rng = jax.random.split(rng)
  fake_input = jnp.ones((batch_size, 28, 313, 1))
  fake_time = jnp.ones(batch_size)
  params = score_model.init({'params': init_rng}, fake_input, fake_time)

  data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                           drop_last=True)
  optimizer = flax.optim.Adam(learning_rate=lr).create(params)
  train_step_fn = get_train_step_fn(score_model, marginal_prob_std_fn)
  data_shape = (n_devices, -1, 28, 313, 1)
  optimizer = flax.jax_utils.replicate(optimizer)

  tqdm_epoch = tqdm.auto.trange(n_epochs)
  for _ in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for batch in data_loader:
      x = batch[0].numpy().reshape(data_shape)
      rng, *step_rng = jax.random.split(rng, n_devices + 1)
      step_rng = jnp.asarray(step_rng)
      loss, optimizer = train_step_fn(step_rng, x, optimizer)
      loss = flax.jax_utils.unreplicate(loss)
      n_items = x.shape[0] * x.shape[1]
      avg_loss += loss.item() * n_items
      num_items += n_items
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / max(num_items, 1)))
    # Update the checkpoint after each epoch of training.
    with tf.io.gfile.GFile(ckpt_path, 'wb') as fout:
      fout.write(to_bytes(flax.jax_utils.unreplicate(optimizer)))

  return flax.jax_utils.unreplicate(optimizer)


if __name__ == '__main__':
  sigma = args.sigma
  marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
  diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

  # Training is not run here because no dataset is defined in this script.
  # To train, build a dataset and call
  #   train(jax.random.PRNGKey(0), dataset, ScoreNet(marginal_prob_std_fn),
  #         marginal_prob_std_fn)

  sample_batch_size = args.sample_batch_size
  sampler = ode_sampler

  ## Load the pre-trained checkpoint from disk.
  score_model = ScoreNet(marginal_prob_std_fn)
  fake_input = jnp.ones((sample_batch_size, 28, 313, 1))
  fake_time = jnp.ones((sample_batch_size, ))
  rng = jax.random.PRNGKey(0)
  params = score_model.init({'params': rng}, fake_input, fake_time)
  optimizer = flax.optim.Adam().create(params)
  with tf.io.gfile.GFile('ckpt.flax', 'rb') as fin:
    optimizer = from_bytes(optimizer, fin.read())

  ## Generate samples using the specified sampler.
  rng, step_rng = jax.random.split(rng)
  samples = sampler(step_rng,
                    score_model,
                    optimizer.target,
                    marginal_prob_std_fn,
                    diffusion_coeff_fn,
                    sample_batch_size)

  ## Sample visualization (run `%matplotlib inline` first when in a notebook).
  samples = jnp.transpose(samples.reshape((-1, 28, 313, 1)), (0, 3, 1, 2))
  sample_grid = make_grid(torch.tensor(np.asarray(samples)),
                          nrow=int(np.sqrt(sample_batch_size)))
  plt.figure(figsize=(6, 6))
  plt.axis('off')
  plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
  plt.show()

  j = 7
  chosen = jnp.mean(samples[j], 0)
  viz(chosen)
  # Explicit display so the player also renders when not the last cell value.
  ipd.display(ipd.Audio(audio(chosen, noise_threshold=-25.0), rate=16000))