ibraheemmoosa committed on
Commit c096304
1 Parent(s): 5966102

Add Score-SDE training script.

Files changed (1)
  1. Score-SDE/train-score-sde.py +257 -0
Score-SDE/train-score-sde.py ADDED
@@ -0,0 +1,257 @@
+ import jax
+ import jax.numpy as jnp
+ from jax import random
+ import flax
+ import flax.linen as nn
+ from typing import Any, Tuple
+ import functools
+ import numpy as np
+ import torch
+ from torch.utils.data import TensorDataset
+
+ key = random.PRNGKey(0)
+
+ dataset = []
+ with np.load('spectograms.npz') as data:
+     for file in data.files:
+         dataset.append(data[file])
+
+ dataset = np.stack(dataset)
+ dataset = np.expand_dims(dataset, axis=3)
+ dataset = TensorDataset(torch.from_numpy(dataset))
+
+
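The layout of `spectograms.npz` is not shown in this commit; judging from the `(batch, 28, 313, 1)` inputs used further down, each entry presumably holds one 28×313 spectrogram frame. A hypothetical stand-in for a dry run (the file name is kept from the script, the shapes are an assumption) could be generated like this:

import numpy as np

# Hypothetical stand-in data; the 28x313 frame size is an assumption inferred
# from the fake_input shape used later in the script.
rng = np.random.default_rng(0)
fake = {f'spec_{i}': rng.standard_normal((28, 313)).astype(np.float32)
        for i in range(1024)}
np.savez('spectograms.npz', **fake)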
+ # The following code is copied with minor modifications from https://colab.research.google.com/drive/1SeXMpILhkJPjXUaesvzEhc3Ke6Zl_zxJ?usp=sharing
+
+ class GaussianFourierProjection(nn.Module):
+     """Gaussian random features for encoding time steps."""
+     embed_dim: int
+     scale: float = 30.
+     @nn.compact
+     def __call__(self, x):
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         W = self.param('W', jax.nn.initializers.normal(stddev=self.scale),
+                        (self.embed_dim // 2,))
+         W = jax.lax.stop_gradient(W)
+         x_proj = x[:, None] * W[None, :] * 2 * jnp.pi
+         return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)
+
+
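A minimal usage sketch of the module above, assuming the class as defined in the diff: a batch of scalar time steps of shape (B,) maps to (B, embed_dim) Fourier features, half sines and half cosines of embed_dim // 2 fixed random frequencies.

import jax
import jax.numpy as jnp

# A batch of 8 scalar time steps...
t = jnp.linspace(0., 1., 8)
gfp = GaussianFourierProjection(embed_dim=256)
# ...maps to (8, 256): sin and cos of 128 fixed random frequencies each.
emb, _ = gfp.init_with_output(jax.random.PRNGKey(0), t)
print(emb.shape)  # (8, 256)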
+ class Dense(nn.Module):
+     """A fully connected layer that reshapes outputs to feature maps."""
+     output_dim: int
+
+     @nn.compact
+     def __call__(self, x):
+         return nn.Dense(self.output_dim)(x)[:, None, None, :]
+
+
+ class ScoreNet(nn.Module):
+     """A time-dependent score-based model built upon the U-Net architecture.
+
+     Args:
+         marginal_prob_std: A function that takes time t and gives the standard
+             deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
+         channels: The number of channels for feature maps of each resolution.
+         embed_dim: The dimensionality of Gaussian random feature embeddings.
+     """
+     marginal_prob_std: Any
+     channels: Tuple[int, ...] = (32, 64, 128, 256)
+     embed_dim: int = 256
+
+     @nn.compact
+     def __call__(self, x, t):
+         # The swish activation function
+         act = nn.swish
+         # Obtain the Gaussian random feature embedding for t
+         embed = act(nn.Dense(self.embed_dim)(
+             GaussianFourierProjection(embed_dim=self.embed_dim)(t)))
+
+         # Encoding path
+         h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
+                      use_bias=False)(x)
+         # print('h1', h1.shape)  # 26x311
+         ## Incorporate information from t
+         h1 += Dense(self.channels[0])(embed)
+         ## Group normalization
+         h1 = nn.GroupNorm(4)(h1)
+         h1 = act(h1)
+         h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
+                      use_bias=False)(h1)
+         # print('h2', h2.shape)  # 12x155
+         h2 += Dense(self.channels[1])(embed)
+         h2 = nn.GroupNorm()(h2)
+         h2 = act(h2)
+         h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
+                      use_bias=False)(h2)
+         # print('h3', h3.shape)  # 5x77
+         h3 += Dense(self.channels[2])(embed)
+         h3 = nn.GroupNorm()(h3)
+         h3 = act(h3)
+         h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
+                      use_bias=False)(h3)
+         # print('h4', h4.shape)  # 2x38
+         h4 += Dense(self.channels[3])(embed)
+         h4 = nn.GroupNorm()(h4)
+         h4 = act(h4)
+
+         # Decoding path
+         h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
+                     input_dilation=(2, 2), use_bias=False)(h4)
+         # print('h', h.shape)  # 5x77
+         ## Skip connection from the encoding path
+         h += Dense(self.channels[2])(embed)
+         h = nn.GroupNorm()(h)
+         h = act(h)
+         h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
+                     input_dilation=(2, 2), use_bias=False)(
+             jnp.concatenate([h, h3], axis=-1)
+         )
+         # print('h', h.shape)  # 12x155
+         h += Dense(self.channels[1])(embed)
+         h = nn.GroupNorm()(h)
+         h = act(h)
+         h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
+                     input_dilation=(2, 2), use_bias=False)(
+             jnp.concatenate([h, h2], axis=-1)
+         )
+         # print('h', h.shape)  # 26x311
+         h += Dense(self.channels[0])(embed)
+         h = nn.GroupNorm()(h)
+         h = act(h)
+         h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
+             jnp.concatenate([h, h1], axis=-1)
+         )
+         # print('h', h.shape)  # 28x313
+         # Normalize output
+         h = h / self.marginal_prob_std(t)[:, None, None, None]
+         return h
+
+
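The decoder convolutions above upsample via `input_dilation=(2, 2)`, i.e. they act like fractionally-strided (transposed-style) convolutions. A small shape-check sketch against the commented sizes (the 5×77 → 12×155 step); the channel counts here are illustrative only:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Input of spatial size 5x77, as in the commented shapes above.
h = jnp.ones((1, 5, 77, 128))
up = nn.Conv(64, (3, 3), (1, 1), padding=((2, 3), (2, 2)),
             input_dilation=(2, 2), use_bias=False)
out, _ = up.init_with_output(jax.random.PRNGKey(0), h)
# Dilated input: 2*5-1 = 9 and 2*77-1 = 153; plus padding (5, 4); kernel 3, stride 1.
print(out.shape)  # (1, 12, 155, 64)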
+ def marginal_prob_std(t, sigma):
+     r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.
+
+     Args:
+         t: A vector of time steps.
+         sigma: The $\sigma$ in our SDE.
+
+     Returns:
+         The standard deviation.
+     """
+     return jnp.sqrt((sigma**(2 * t) - 1.) / 2. / jnp.log(sigma))
+
+ def diffusion_coeff(t, sigma):
+     r"""Compute the diffusion coefficient of our SDE.
+
+     Args:
+         t: A vector of time steps.
+         sigma: The $\sigma$ in our SDE.
+
+     Returns:
+         The vector of diffusion coefficients.
+     """
+     return sigma**t
+
+ sigma = 25.0  #@param {'type':'number'}
+ marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
+ diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
+
+
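As a quick numerical sanity check, derived only from the formulas above with sigma = 25: the perturbation-kernel standard deviation at t = 1 is sqrt((25^2 - 1) / (2 ln 25)) ≈ 9.85, and the diffusion coefficient is g(t) = sigma^t, so g(1) = 25.

import jax.numpy as jnp

t = jnp.array([0.5, 1.0])
print(marginal_prob_std_fn(t))   # ≈ [1.93, 9.85]
print(diffusion_coeff_fn(t))     # [ 5., 25.]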
+ def loss_fn(rng, model, params, x, marginal_prob_std, eps=1e-5):
+     """The loss function for training score-based generative models.
+
+     Args:
+         rng: A JAX random state.
+         model: A `flax.linen.Module` object that represents the structure of
+             the score-based model.
+         params: A dictionary that contains all trainable parameters.
+         x: A mini-batch of training data.
+         marginal_prob_std: A function that gives the standard deviation of
+             the perturbation kernel.
+         eps: A tolerance value for numerical stability.
+     """
+     rng, step_rng = jax.random.split(rng)
+     random_t = jax.random.uniform(step_rng, (x.shape[0],), minval=eps, maxval=1.)
+     rng, step_rng = jax.random.split(rng)
+     z = jax.random.normal(step_rng, x.shape)
+     std = marginal_prob_std(random_t)
+     perturbed_x = x + z * std[:, None, None, None]
+     score = model.apply(params, perturbed_x, random_t)
+     loss = jnp.mean(jnp.sum((score * std[:, None, None, None] + z)**2,
+                             axis=(1, 2, 3)))
+     return loss
+
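A single-device sketch of calling `loss_fn` directly, without `pmap`, assuming the `ScoreNet` and helpers defined above; the (28, 313, 1) frame shape mirrors the `fake_input` used in the training setup further down:

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
model = ScoreNet(marginal_prob_std_fn)
dummy_x = jnp.ones((4, 28, 313, 1))
params = model.init({'params': rng}, dummy_x, jnp.ones(4))
# Returns the scalar denoising score matching loss for this mini-batch.
print(loss_fn(rng, model, params, dummy_x, marginal_prob_std_fn))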
+ def get_train_step_fn(model, marginal_prob_std):
+     """Create a one-step training function.
+
+     Args:
+         model: A `flax.linen.Module` object that represents the structure of
+             the score-based model.
+         marginal_prob_std: A function that gives the standard deviation of
+             the perturbation kernel.
+
+     Returns:
+         A function that runs one step of training.
+     """
+     val_and_grad_fn = jax.value_and_grad(loss_fn, argnums=2)
+
+     def step_fn(rng, x, optimizer):
+         params = optimizer.target
+         loss, grad = val_and_grad_fn(rng, model, params, x, marginal_prob_std)
+         # Average gradients and loss across devices.
+         mean_grad = jax.lax.pmean(grad, axis_name='device')
+         mean_loss = jax.lax.pmean(loss, axis_name='device')
+         new_optimizer = optimizer.apply_gradient(mean_grad)
+
+         return mean_loss, new_optimizer
+
+     return jax.pmap(step_fn, axis_name='device')
+
+
+ #@title Training (double click to expand or collapse)
+ import torch
+ import functools
+ import flax
+ from flax.serialization import to_bytes, from_bytes
+ import tensorflow as tf
+ from torch.utils.data import DataLoader
+ import torchvision.transforms as transforms
+ from torchvision.datasets import MNIST
+ import tqdm.notebook
+
+ ## number of training epochs
+ n_epochs = 500  #@param {'type':'integer'}
+ ## size of a mini-batch
+ batch_size = 512  #@param {'type':'integer'}
+ ## learning rate
+ lr = 1e-3  #@param {'type':'number'}
+
+ rng = jax.random.PRNGKey(0)
+ fake_input = jnp.ones((batch_size, 28, 313, 1))
+ fake_time = jnp.ones(batch_size)
+ score_model = ScoreNet(marginal_prob_std_fn)
+ params = score_model.init({'params': rng}, fake_input, fake_time)
+
+ # dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
+ optimizer = flax.optim.Adam(learning_rate=lr).create(params)
+ train_step_fn = get_train_step_fn(score_model, marginal_prob_std_fn)
+ tqdm_epoch = tqdm.notebook.trange(n_epochs)
+
+ # Shard each batch across the local devices for pmap.
+ assert batch_size % jax.local_device_count() == 0
+ data_shape = (jax.local_device_count(), -1, 28, 313, 1)
+
+ optimizer = flax.jax_utils.replicate(optimizer)
+ for epoch in tqdm_epoch:
+     avg_loss = 0.
+     num_items = 0
+     for x in data_loader:
+         x = x[0]
+         x = x.numpy().reshape(data_shape)
+         rng, *step_rng = jax.random.split(rng, jax.local_device_count() + 1)
+         step_rng = jnp.asarray(step_rng)
+         loss, optimizer = train_step_fn(step_rng, x, optimizer)
+         loss = flax.jax_utils.unreplicate(loss)
+         avg_loss += loss.item() * x.shape[0]
+         num_items += x.shape[0]
+     # Print the averaged training loss so far.
+     tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
+     # Update the checkpoint after each epoch of training.
+     with tf.io.gfile.GFile('ckpt.flax', 'wb') as fout:
+         fout.write(to_bytes(flax.jax_utils.unreplicate(optimizer)))
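`from_bytes` is imported above but never used in this script; a hedged sketch of restoring `ckpt.flax` later, mirroring how the optimizer is constructed above (it assumes `lr`, `params`, and the old `flax.optim` API used by the script), might look like:

import flax
import tensorflow as tf
from flax.serialization import from_bytes

# Rebuild an optimizer with the same structure as the saved one, then
# overwrite its contents with the serialized checkpoint.
template = flax.optim.Adam(learning_rate=lr).create(params)
with tf.io.gfile.GFile('ckpt.flax', 'rb') as fin:
    restored_optimizer = from_bytes(template, fin.read())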