anonymous9a7b committed
Commit d4c980e · 1 Parent(s): 8207a5c
app.py CHANGED
@@ -1,7 +1,217 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import spaces
+ import numpy as np
+ import torch
+ from fastgeco.model import ScoreModel
+ from geco.util.other import pad_spec
+ import os
+ import torchaudio
+ from speechbrain.lobes.models.dual_path import Encoder, SBTransformerBlock, Dual_Path_Model, Decoder
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ sample_rate = 8000
+ num_spks = 2
+
+
+ def load_sepformer(ckpt_path):
+     encoder = Encoder(
+         kernel_size=160,
+         out_channels=256,
+         in_channels=1,
+     )
+     SBtfintra = SBTransformerBlock(
+         num_layers=8,
+         d_model=256,
+         nhead=8,
+         d_ffn=1024,
+         dropout=0,
+         use_positional_encoding=True,
+         norm_before=True,
+     )
+     SBtfinter = SBTransformerBlock(
+         num_layers=8,
+         d_model=256,
+         nhead=8,
+         d_ffn=1024,
+         dropout=0,
+         use_positional_encoding=True,
+         norm_before=True,
+     )
+     masknet = Dual_Path_Model(
+         num_spks=num_spks,  # fixed: was args.num_spks, but no `args` is defined here
+         in_channels=256,
+         out_channels=256,
+         num_layers=2,
+         K=250,
+         intra_model=SBtfintra,
+         inter_model=SBtfinter,
+         norm='ln',
+         linear_layer_after_inter_intra=False,
+         skip_around_intra=True,
+     )
+     decoder = Decoder(
+         in_channels=256,
+         out_channels=1,
+         kernel_size=160,
+         stride=80,
+         bias=False,
+     )
+
+     encoder_weights = torch.load(os.path.join(ckpt_path, 'encoder.ckpt'))
+     encoder.load_state_dict(encoder_weights)
+     masknet_weights = torch.load(os.path.join(ckpt_path, 'masknet.ckpt'))
+     masknet.load_state_dict(masknet_weights)
+     decoder_weights = torch.load(os.path.join(ckpt_path, 'decoder.ckpt'))
+     decoder.load_state_dict(decoder_weights)
+     encoder = encoder.eval().to(device)
+     masknet = masknet.eval().to(device)
+     decoder = decoder.eval().to(device)
+     return encoder, masknet, decoder
+
+
+ def load_fastgeco(ckpt_path):
+     checkpoint_file = os.path.join(ckpt_path, 'fastgeco.ckpt')
+     model = ScoreModel.load_from_checkpoint(
+         checkpoint_file,
+         batch_size=1, num_workers=0, kwargs=dict(gpu=False)
+     )
+     model.eval(no_ema=False)
+     model.to(device)
+     return model
+
+
+ ckpt_path = 'ckpts/'
+ encoder, masknet, decoder = load_sepformer(ckpt_path)
+ fastgeco_model = load_fastgeco(ckpt_path)
+
+
+ @spaces.GPU
+ def separate(test_file, encoder, masknet, decoder):
+     with torch.no_grad():
+         print('Process SepFormer...')
+         mix, fs_file = torchaudio.load(test_file)
+         mix = mix.to(device)
+         fs_model = sample_rate
+
+         # Resample the data if needed
+         if fs_file != fs_model:
+             print("Resampling the audio from {} Hz to {} Hz".format(fs_file, fs_model))
+             tf = torchaudio.transforms.Resample(
+                 orig_freq=fs_file, new_freq=fs_model
+             ).to(device)
+             mix = mix.mean(dim=0, keepdim=True)
+             mix = tf(mix)
+
+         mix = mix.to(device)
+
+         # Separation
+         mix_w = encoder(mix)
+         est_mask = masknet(mix_w)
+         mix_w = torch.stack([mix_w] * num_spks)
+         sep_h = mix_w * est_mask
+
+         # Decoding
+         est_sources = torch.cat(
+             [decoder(sep_h[i]).unsqueeze(-1) for i in range(num_spks)],
+             dim=-1,
+         )
+         est_sources = (
+             est_sources / est_sources.abs().max(dim=1, keepdim=True)[0]
+         ).squeeze()
+
+         return est_sources, mix
+
+
+ @spaces.GPU
+ def correct(model, est_sources, mix):
+     with torch.no_grad():
+         print('Process Fast-Geco...')
+         N = 1  # a single reverse step
+         reverse_starting_point = 0.5
+         output = []
+         for idx in range(num_spks):
+             y = est_sources[:, idx].unsqueeze(0)  # separator output to be corrected
+             m = mix
+             min_leng = min(y.shape[-1], m.shape[-1])
+             y = y[..., :min_leng]
+             m = m[..., :min_leng]
+             T_orig = y.size(1)
+
+             norm_factor = y.abs().max()
+             y = y / norm_factor
+             m = m / norm_factor
+             Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(device))), 0)
+             Y = pad_spec(Y)
+             M = torch.unsqueeze(model._forward_transform(model._stft(m.to(device))), 0)
+             M = pad_spec(M)
+
+             timesteps = torch.linspace(reverse_starting_point, 0.03, N, device=Y.device)
+             std = model.sde._std(reverse_starting_point * torch.ones((Y.shape[0],), device=Y.device))
+             z = torch.randn_like(Y)
+             X_t = Y + z * std[:, None, None, None]
+
+             t = timesteps[0]
+             dt = timesteps[-1]
+             f, g = model.sde.sde(X_t, t, Y)
+             vec_t = torch.ones(Y.shape[0], device=Y.device) * t
+             # One Euler step of the reverse SDE: mu(x_{t-1}) = x_t - (f - g^2 * score) * dt
+             mean_x_tm1 = X_t - (f - g**2 * model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])) * dt
+             sample = mean_x_tm1
+             sample = sample.squeeze()
+             x_hat = model.to_audio(sample.squeeze(), T_orig)
+             x_hat = x_hat * norm_factor
+             new_norm_factor = x_hat.abs().max()
+             x_hat = x_hat / new_norm_factor
+             x_hat = x_hat.squeeze().cpu().numpy()
+             output.append(x_hat)
+         return output[0], output[1]
+
+
+ @spaces.GPU
+ def process_audio(test_file):
+     result, mix = separate(test_file, encoder, masknet, decoder)
+     audio1, audio2 = correct(fastgeco_model, result, mix)
+     # gr.Audio outputs with type="numpy" expect (sample_rate, data) tuples
+     return (sample_rate, audio1), (sample_rate, audio2)
+
+
+ # CSS styling (optional)
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 1280px;
+ }
+ """
+
+ # Gradio Blocks layout
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("""
+         # Fast-GeCo: Noise-robust Speech Separation with Fast Generative Correction
+         Separate a noisy speech mixture with a generative correction method. Only two speakers are supported for now.
+
+         Learn more about 🟣**Fast-GeCo** in the [Fast-GeCo repo](https://github.com/WangHelin1997/Fast-GeCo/).
+         """)
+
+         with gr.Tab("Speech Separation"):
+             # Input: upload the audio file to separate
+             with gr.Row():
+                 gt_file_input = gr.Audio(label="Upload Audio to Separate", type="filepath", value="demo/item0_mix.wav")
+                 button = gr.Button("Generate", scale=1)
+
+             # Output components for the separated audio
+             with gr.Row():
+                 result1 = gr.Audio(label="Separated Audio 1", type="numpy")
+                 result2 = gr.Audio(label="Separated Audio 2", type="numpy")
+
+             # Wire the trigger to its inputs and outputs
+             button.click(
+                 fn=process_audio,
+                 inputs=[gt_file_input],
+                 outputs=[result1, result2],
+             )
+
+ # Launch the Gradio demo
+ demo.launch()
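The correction step above is a single reverse-diffusion Euler step started from the separator output at t = 0.5 rather than from pure noise, which is what makes the correction fast. The pipeline can also be smoke-tested without the UI; a minimal sketch, assuming the checkpoints in ckpts/ are present, this file is importable as `app`, and soundfile is installed (soundfile is an assumed extra dependency, not part of this commit):

import soundfile as sf
from app import process_audio, sample_rate

# SepFormer separation followed by the one-step Fast-GeCo correction,
# using one of the 8 kHz demo mixtures added below.
(sr1, wav1), (sr2, wav2) = process_audio("demo/item0_mix.wav")
sf.write("speaker1.wav", wav1, sr1)  # sr1 == sample_rate == 8000
sf.write("speaker2.wav", wav2, sr2)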
demo/item0_mix.wav ADDED
Binary file (173 kB)

demo/item1_mix.wav ADDED
Binary file (164 kB)

demo/item2_mix.wav ADDED
Binary file (103 kB)

demo/item3_mix.wav ADDED
Binary file (105 kB)

demo/item4_mix.wav ADDED
Binary file (104 kB)

fastgeco/.DS_Store ADDED
Binary file (6.15 kB)

fastgeco/backbones/.DS_Store ADDED
Binary file (6.15 kB)
 
fastgeco/backbones/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .shared import BackboneRegistry
+ from .ncsnpp import NCSNpp
+
+ __all__ = ['BackboneRegistry', 'NCSNpp']
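`BackboneRegistry` lives in `fastgeco/backbones/shared.py`, which is not part of this commit; registries of this kind are usually a small name-to-class map plus a decorator, roughly along these lines (a sketch of the pattern, not the actual shared.py):

class BackboneRegistry:
    """Maps backbone names to classes so configs can select them by string."""
    _registry = {}

    @classmethod
    def register(cls, name):
        # Used below as @BackboneRegistry.register("ncsnpp") on NCSNpp.
        def decorator(klass):
            cls._registry[name] = klass
            return klass
        return decorator

    @classmethod
    def get_by_name(cls, name):
        return cls._registry[name]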
fastgeco/backbones/ncsnpp.py ADDED
@@ -0,0 +1,406 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ from score_models.layers import UpsampleLayer, DownsampleLayer
+ from .ncsnpp_utils import layers, layerspp, normalization
+ import torch.nn as nn
+ import functools
+ import torch
+ import numpy as np
+
+ from .shared import BackboneRegistry
+
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+ Combine = layerspp.Combine
+ conv3x3 = layerspp.conv3x3
+ conv1x1 = layerspp.conv1x1
+ get_act = layers.get_act
+ get_normalization = normalization.get_normalization
+ default_initializer = layers.default_init
+
+
+ @BackboneRegistry.register("ncsnpp")
+ class NCSNpp(nn.Module):
+     """NCSN++ model, adapted from the https://github.com/yang-song/score_sde repository."""
+
+     @staticmethod
+     def add_argparse_args(parser):
+         # TODO: add additional arguments of constructor, if you wish to modify them.
+         return parser
+
+     def __init__(self,
+                  scale_by_sigma=True,
+                  nonlinearity='swish',
+                  nf=128,
+                  ch_mult=(1, 1, 2, 2, 2, 2, 2),
+                  num_res_blocks=2,
+                  attn_resolutions=(16,),
+                  resamp_with_conv=True,
+                  conditional=True,
+                  fir=True,
+                  fir_kernel='song',
+                  skip_rescale=True,
+                  resblock_type='biggan',
+                  progressive='output_skip',
+                  progressive_input='input_skip',
+                  progressive_combine='sum',
+                  init_scale=0.,
+                  fourier_scale=16,
+                  image_size=256,
+                  embedding_type='fourier',
+                  dropout=.0,
+                  **unused_kwargs
+                  ):
+         super().__init__()
+         self.act = act = get_act(nonlinearity)
+
+         self.nf = nf
+         self.num_res_blocks = num_res_blocks
+         self.attn_resolutions = attn_resolutions
+         self.num_resolutions = num_resolutions = len(ch_mult)
+         self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+         self.conditional = conditional  # noise-conditional
+         self.scale_by_sigma = scale_by_sigma
+         fir_kernel = [1, 3, 3, 1]  # always the 'song' kernel, regardless of the argument
+         self.skip_rescale = skip_rescale
+         self.resblock_type = resblock_type = resblock_type.lower()
+         self.progressive = progressive = progressive.lower()
+         self.progressive_input = progressive_input = progressive_input.lower()
+         self.embedding_type = embedding_type = embedding_type.lower()
+         assert progressive in ['none', 'output_skip', 'residual']
+         assert progressive_input in ['none', 'input_skip', 'residual']
+         assert embedding_type in ['fourier', 'positional']
+         combine_method = progressive_combine.lower()
+         combiner = functools.partial(Combine, method=combine_method)
+
+         num_channels = 6  # x.real, x.imag, y.real, y.imag, m.real, m.imag
+         self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+         modules = []
+         # timestep/noise_level embedding
+         if embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             modules.append(layerspp.GaussianFourierProjection(
+                 embedding_size=nf, scale=fourier_scale
+             ))
+             embed_dim = 2 * nf
+         elif embedding_type == 'positional':
+             embed_dim = nf
+         else:
+             raise ValueError(f'embedding type {embedding_type} unknown.')
+
+         if conditional:
+             modules.append(nn.Linear(embed_dim, nf * 4))
+             modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+             nn.init.zeros_(modules[-1].bias)
+             modules.append(nn.Linear(nf * 4, nf * 4))
+             modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+             nn.init.zeros_(modules[-1].bias)
+
+         AttnBlock = functools.partial(layerspp.AttnBlockpp,
+                                       init_scale=init_scale, skip_rescale=skip_rescale)
+
+         Upsample = functools.partial(UpsampleLayer,
+                                      with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive == 'output_skip':
+             self.pyramid_upsample = UpsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive == 'residual':
+             pyramid_upsample = functools.partial(UpsampleLayer, fir=fir,
+                                                  fir_kernel=fir_kernel, with_conv=True)
+
+         Downsample = functools.partial(DownsampleLayer, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive_input == 'input_skip':
+             self.pyramid_downsample = DownsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive_input == 'residual':
+             pyramid_downsample = functools.partial(DownsampleLayer,
+                                                    fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+         if resblock_type == 'ddpm':
+             ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                                             dropout=dropout, init_scale=init_scale,
+                                             skip_rescale=skip_rescale, temb_dim=nf * 4)
+         elif resblock_type == 'biggan':
+             ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                                             dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                                             init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+         else:
+             raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+         # Downsampling block
+
+         channels = num_channels
+         if progressive_input != 'none':
+             input_pyramid_ch = channels
+
+         modules.append(conv3x3(channels, nf))
+         hs_c = [nf]
+
+         in_ch = nf
+         for i_level in range(num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(num_res_blocks):
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                 in_ch = out_ch
+
+                 if all_resolutions[i_level] in attn_resolutions:
+                     modules.append(AttnBlock(channels=in_ch))
+                 hs_c.append(in_ch)
+
+             if i_level != num_resolutions - 1:
+                 if resblock_type == 'ddpm':
+                     modules.append(Downsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                 if progressive_input == 'input_skip':
+                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                     if combine_method == 'cat':
+                         in_ch *= 2
+                 elif progressive_input == 'residual':
+                     modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                     input_pyramid_ch = in_ch
+
+                 hs_c.append(in_ch)
+
+         in_ch = hs_c[-1]
+         modules.append(ResnetBlock(in_ch=in_ch))
+         modules.append(AttnBlock(channels=in_ch))
+         modules.append(ResnetBlock(in_ch=in_ch))
+
+         pyramid_ch = 0
+         # Upsampling block
+         for i_level in reversed(range(num_resolutions)):
+             for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                 in_ch = out_ch
+
+             if all_resolutions[i_level] in attn_resolutions:
+                 modules.append(AttnBlock(channels=in_ch))
+
+             if progressive != 'none':
+                 if i_level == num_resolutions - 1:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                     num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, in_ch, bias=True))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name.')
+                 else:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                     num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name')
+
+             if i_level != 0:
+                 if resblock_type == 'ddpm':
+                     modules.append(Upsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+         assert not hs_c
+
+         if progressive != 'output_skip':
+             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                         num_channels=in_ch, eps=1e-6))
+             modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+         self.all_modules = nn.ModuleList(modules)
+
+     def forward(self, x, time_cond, scale_divide):
+         # timestep/noise_level embedding; only for continuous training
+         modules = self.all_modules
+         m_idx = 0
+
+         # Convert real and imaginary parts of (x, y, m) into six channel dimensions
+         x = torch.cat((x[:, [0], :, :].real, x[:, [0], :, :].imag,
+                        x[:, [1], :, :].real, x[:, [1], :, :].imag,
+                        x[:, [2], :, :].real, x[:, [2], :, :].imag), dim=1)
+
+         if self.embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             used_sigmas = time_cond
+             temb = modules[m_idx](torch.log(used_sigmas))
+             m_idx += 1
+         elif self.embedding_type == 'positional':
+             # Sinusoidal positional embeddings.
+             timesteps = time_cond
+             used_sigmas = self.sigmas[time_cond.long()]
+             temb = layers.get_timestep_embedding(timesteps, self.nf)
+         else:
+             raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+         if self.conditional:
+             temb = modules[m_idx](temb)
+             m_idx += 1
+             temb = modules[m_idx](self.act(temb))
+             m_idx += 1
+         else:
+             temb = None
+
+         # Downsampling block
+         input_pyramid = None
+         if self.progressive_input != 'none':
+             input_pyramid = x
+
+         # Input layer: Conv2d: 6ch -> 128ch
+         hs = [modules[m_idx](x)]
+         m_idx += 1
+
+         # Down path in U-Net
+         for i_level in range(self.num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(self.num_res_blocks):
+                 h = modules[m_idx](hs[-1], temb)
+                 m_idx += 1
+                 # Attention layer (optional)
+                 if h.shape[-2] in self.attn_resolutions:  # edit: check H dim (-2) not W dim (-1)
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 hs.append(h)
+
+             # Downsampling
+             if i_level != self.num_resolutions - 1:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](hs[-1])
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](hs[-1], temb)
+                     m_idx += 1
+
+                 if self.progressive_input == 'input_skip':  # Combine h with x
+                     input_pyramid = self.pyramid_downsample(input_pyramid)
+                     h = modules[m_idx](input_pyramid, h)
+                     m_idx += 1
+                 elif self.progressive_input == 'residual':
+                     input_pyramid = modules[m_idx](input_pyramid)
+                     m_idx += 1
+                     if self.skip_rescale:
+                         input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                     else:
+                         input_pyramid = input_pyramid + h
+                     h = input_pyramid
+                 hs.append(h)
+
+         h = hs[-1]  # actually equal to: h = h
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+         h = modules[m_idx](h)  # Attention block
+         m_idx += 1
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+
+         pyramid = None
+
+         # Upsampling block
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                 m_idx += 1
+
+             # edit: from -1 to -2
+             if h.shape[-2] in self.attn_resolutions:
+                 h = modules[m_idx](h)
+                 m_idx += 1
+
+             if self.progressive != 'none':
+                 if i_level == self.num_resolutions - 1:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)  # Conv2d: in_ch -> 6
+                         m_idx += 1
+                     elif self.progressive == 'residual':
+                         pyramid = self.act(modules[m_idx](h))
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name.')
+                 else:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                         pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid_h = modules[m_idx](pyramid_h)
+                         m_idx += 1
+                         pyramid = pyramid + pyramid_h
+                     elif self.progressive == 'residual':
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                         if self.skip_rescale:
+                             pyramid = (pyramid + h) / np.sqrt(2.)
+                         else:
+                             pyramid = pyramid + h
+                         h = pyramid
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name')
+
+             # Upsampling layer
+             if i_level != 0:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](h, temb)  # Upsampling
+                     m_idx += 1
+
+         assert not hs
+
+         if self.progressive == 'output_skip':
+             h = pyramid
+         else:
+             h = self.act(modules[m_idx](h))
+             m_idx += 1
+             h = modules[m_idx](h)
+             m_idx += 1
+
+         assert m_idx == len(modules), "Implementation error"
+         h = h / scale_divide
+         # h = h / used_sigmas[:, None, None, None]
+
+         # Convert back to a complex-valued spectrogram
+         h = self.output_layer(h)
+         h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+         h = torch.view_as_complex(h)[:, None, :, :]
+         return h
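The output head packs a complex spectrogram into interleaved real/imaginary channels and converts back with `torch.view_as_complex`; the round trip in isolation looks like this (a standalone sketch of just the conversion, with arbitrary shapes):

import torch

spec = torch.randn(1, 1, 256, 256, dtype=torch.cfloat)      # (B, C, F, T), complex
as_real = torch.cat([spec.real, spec.imag], dim=1)           # (B, 2, F, T), real
paired = torch.permute(as_real, (0, 2, 3, 1)).contiguous()   # (B, F, T, 2)
restored = torch.view_as_complex(paired)[:, None, :, :]      # (B, 1, F, T), complex
assert torch.equal(restored, spec)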
fastgeco/backbones/ncsnpp_utils/layers.py ADDED
@@ -0,0 +1,662 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ """Common layers for defining score networks."""
+ import math
+ import string
+ from functools import partial
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from .normalization import ConditionalInstanceNorm2dPlus
+
+
+ def get_act(config):
+     """Get activation functions from the config file."""
+     if config == 'elu':
+         return nn.ELU()
+     elif config == 'relu':
+         return nn.ReLU()
+     elif config == 'lrelu':
+         return nn.LeakyReLU(negative_slope=0.2)
+     elif config == 'swish':
+         return nn.SiLU()
+     else:
+         raise NotImplementedError('activation function does not exist!')
+
+
+ def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
+     """1x1 convolution. Same as NCSNv1/v2."""
+     conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
+                      padding=padding)
+     init_scale = 1e-10 if init_scale == 0 else init_scale
+     conv.weight.data *= init_scale
+     conv.bias.data *= init_scale
+     return conv
+
+
+ def variance_scaling(scale, mode, distribution,
+                      in_axis=1, out_axis=0,
+                      dtype=torch.float32,
+                      device='cpu'):
+     """Ported from JAX."""
+
+     def _compute_fans(shape, in_axis=1, out_axis=0):
+         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+         fan_in = shape[in_axis] * receptive_field_size
+         fan_out = shape[out_axis] * receptive_field_size
+         return fan_in, fan_out
+
+     def init(shape, dtype=dtype, device=device):
+         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+         if mode == "fan_in":
+             denominator = fan_in
+         elif mode == "fan_out":
+             denominator = fan_out
+         elif mode == "fan_avg":
+             denominator = (fan_in + fan_out) / 2
+         else:
+             raise ValueError(
+                 "invalid mode for variance scaling initializer: {}".format(mode))
+         variance = scale / denominator
+         if distribution == "normal":
+             return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
+         elif distribution == "uniform":
+             return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
+         else:
+             raise ValueError("invalid distribution for variance scaling initializer")
+
+     return init
+
+
+ def default_init(scale=1.):
+     """The same initialization used in DDPM."""
+     scale = 1e-10 if scale == 0 else scale
+     return variance_scaling(scale, 'fan_avg', 'uniform')
+
+
+ class Dense(nn.Module):
+     """Linear layer with `default_init`."""
+     def __init__(self):
+         super().__init__()
+
+
+ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
+     """1x1 convolution with DDPM initialization."""
+     conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
+     conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+     nn.init.zeros_(conv.bias)
+     return conv
+
+
+ def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+     """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
+     init_scale = 1e-10 if init_scale == 0 else init_scale
+     conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
+                      dilation=dilation, padding=padding, kernel_size=3)
+     conv.weight.data *= init_scale
+     conv.bias.data *= init_scale
+     return conv
+
+
+ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+     """3x3 convolution with DDPM initialization."""
+     conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
+                      dilation=dilation, bias=bias)
+     conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+     nn.init.zeros_(conv.bias)
+     return conv
+
+
+ ###########################################################################
+ # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
+ #     https://github.com/ermongroup/ncsn
+ #     https://github.com/ermongroup/ncsnv2
+ ###########################################################################
+
+
+ class CRPBlock(nn.Module):
+     def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
+         super().__init__()
+         self.convs = nn.ModuleList()
+         for i in range(n_stages):
+             self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+         self.n_stages = n_stages
+         if maxpool:
+             self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+         else:
+             self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+
+         self.act = act
+
+     def forward(self, x):
+         x = self.act(x)
+         path = x
+         for i in range(self.n_stages):
+             path = self.pool(path)
+             path = self.convs[i](path)
+             x = path + x
+         return x
+
+
+ class CondCRPBlock(nn.Module):
+     def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
+         super().__init__()
+         self.convs = nn.ModuleList()
+         self.norms = nn.ModuleList()
+         self.normalizer = normalizer
+         for i in range(n_stages):
+             self.norms.append(normalizer(features, num_classes, bias=True))
+             self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+
+         self.n_stages = n_stages
+         self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+         self.act = act
+
+     def forward(self, x, y):
+         x = self.act(x)
+         path = x
+         for i in range(self.n_stages):
+             path = self.norms[i](path, y)
+             path = self.pool(path)
+             path = self.convs[i](path)
+             x = path + x
+         return x
+
+
+ class RCUBlock(nn.Module):
+     def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
+         super().__init__()
+
+         for i in range(n_blocks):
+             for j in range(n_stages):
+                 setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+         self.stride = 1
+         self.n_blocks = n_blocks
+         self.n_stages = n_stages
+         self.act = act
+
+     def forward(self, x):
+         for i in range(self.n_blocks):
+             residual = x
+             for j in range(self.n_stages):
+                 x = self.act(x)
+                 x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+             x += residual
+         return x
+
+
+ class CondRCUBlock(nn.Module):
+     def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
+         super().__init__()
+
+         for i in range(n_blocks):
+             for j in range(n_stages):
+                 setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
+                 setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+         self.stride = 1
+         self.n_blocks = n_blocks
+         self.n_stages = n_stages
+         self.act = act
+         self.normalizer = normalizer
+
+     def forward(self, x, y):
+         for i in range(self.n_blocks):
+             residual = x
+             for j in range(self.n_stages):
+                 x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
+                 x = self.act(x)
+                 x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+             x += residual
+         return x
+
+
+ class MSFBlock(nn.Module):
+     def __init__(self, in_planes, features):
+         super().__init__()
+         assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+         self.convs = nn.ModuleList()
+         self.features = features
+
+         for i in range(len(in_planes)):
+             self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+
+     def forward(self, xs, shape):
+         sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+         for i in range(len(self.convs)):
+             h = self.convs[i](xs[i])
+             h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+             sums += h
+         return sums
+
+
+ class CondMSFBlock(nn.Module):
+     def __init__(self, in_planes, features, num_classes, normalizer):
+         super().__init__()
+         assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+
+         self.convs = nn.ModuleList()
+         self.norms = nn.ModuleList()
+         self.features = features
+         self.normalizer = normalizer
+
+         for i in range(len(in_planes)):
+             self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+             self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
+
+     def forward(self, xs, y, shape):
+         sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+         for i in range(len(self.convs)):
+             h = self.norms[i](xs[i], y)
+             h = self.convs[i](h)
+             h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+             sums += h
+         return sums
+
+
+ class RefineBlock(nn.Module):
+     def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
+         super().__init__()
+
+         assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+         self.n_blocks = n_blocks = len(in_planes)
+
+         self.adapt_convs = nn.ModuleList()
+         for i in range(n_blocks):
+             self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
+
+         self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
+
+         if not start:
+             self.msf = MSFBlock(in_planes, features)
+
+         self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
+
+     def forward(self, xs, output_shape):
+         assert isinstance(xs, tuple) or isinstance(xs, list)
+         hs = []
+         for i in range(len(xs)):
+             h = self.adapt_convs[i](xs[i])
+             hs.append(h)
+
+         if self.n_blocks > 1:
+             h = self.msf(hs, output_shape)
+         else:
+             h = hs[0]
+
+         h = self.crp(h)
+         h = self.output_convs(h)
+
+         return h
+
+
+ class CondRefineBlock(nn.Module):
+     def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
+         super().__init__()
+
+         assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+         self.n_blocks = n_blocks = len(in_planes)
+
+         self.adapt_convs = nn.ModuleList()
+         for i in range(n_blocks):
+             self.adapt_convs.append(
+                 CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
+             )
+
+         self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
+
+         if not start:
+             self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
+
+         self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
+
+     def forward(self, xs, y, output_shape):
+         assert isinstance(xs, tuple) or isinstance(xs, list)
+         hs = []
+         for i in range(len(xs)):
+             h = self.adapt_convs[i](xs[i], y)
+             hs.append(h)
+
+         if self.n_blocks > 1:
+             h = self.msf(hs, y, output_shape)
+         else:
+             h = hs[0]
+
+         h = self.crp(h, y)
+         h = self.output_convs(h, y)
+
+         return h
+
+
+ class ConvMeanPool(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
+         super().__init__()
+         if not adjust_padding:
+             conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+             self.conv = conv
+         else:
+             conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+             self.conv = nn.Sequential(
+                 nn.ZeroPad2d((1, 0, 1, 0)),
+                 conv
+             )
+
+     def forward(self, inputs):
+         output = self.conv(inputs)
+         output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
+                       output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
+         return output
+
+
+ class MeanPoolConv(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+         super().__init__()
+         self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+
+     def forward(self, inputs):
+         output = inputs
+         output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
+                       output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
+         return self.conv(output)
+
+
+ class UpsampleConv(nn.Module):
+     def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+         super().__init__()
+         self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+         self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
+
+     def forward(self, inputs):
+         output = inputs
+         output = torch.cat([output, output, output, output], dim=1)
+         output = self.pixelshuffle(output)
+         return self.conv(output)
+
+
+ class ConditionalResidualBlock(nn.Module):
+     def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
+                  normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
+         super().__init__()
+         self.non_linearity = act
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.resample = resample
+         self.normalization = normalization
+         if resample == 'down':
+             if dilation > 1:
+                 self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+                 self.normalize2 = normalization(input_dim, num_classes)
+                 self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+                 conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+             else:
+                 self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+                 self.normalize2 = normalization(input_dim, num_classes)
+                 self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
+                 conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
+
+         elif resample is None:
+             if dilation > 1:
+                 conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+                 self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+                 self.normalize2 = normalization(output_dim, num_classes)
+                 self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+             else:
+                 conv_shortcut = nn.Conv2d
+                 self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+                 self.normalize2 = normalization(output_dim, num_classes)
+                 self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+         else:
+             raise Exception('invalid resample value')
+
+         if output_dim != input_dim or resample is not None:
+             self.shortcut = conv_shortcut(input_dim, output_dim)
+
+         self.normalize1 = normalization(input_dim, num_classes)
+
+     def forward(self, x, y):
+         output = self.normalize1(x, y)
+         output = self.non_linearity(output)
+         output = self.conv1(output)
+         output = self.normalize2(output, y)
+         output = self.non_linearity(output)
+         output = self.conv2(output)
+
+         if self.output_dim == self.input_dim and self.resample is None:
+             shortcut = x
+         else:
+             shortcut = self.shortcut(x)
+
+         return shortcut + output
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
+                  normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
+         super().__init__()
+         self.non_linearity = act
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.resample = resample
+         self.normalization = normalization
+         if resample == 'down':
+             if dilation > 1:
+                 self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+                 self.normalize2 = normalization(input_dim)
+                 self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+                 conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+             else:
+                 self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+                 self.normalize2 = normalization(input_dim)
+                 self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
+                 conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
+
+         elif resample is None:
+             if dilation > 1:
+                 conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+                 self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+                 self.normalize2 = normalization(output_dim)
+                 self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+             else:
+                 # conv_shortcut = nn.Conv2d ### Something weird here.
+                 conv_shortcut = partial(ncsn_conv1x1)
+                 self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+                 self.normalize2 = normalization(output_dim)
+                 self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+         else:
+             raise Exception('invalid resample value')
+
+         if output_dim != input_dim or resample is not None:
+             self.shortcut = conv_shortcut(input_dim, output_dim)
+
+         self.normalize1 = normalization(input_dim)
+
+     def forward(self, x):
+         output = self.normalize1(x)
+         output = self.non_linearity(output)
+         output = self.conv1(output)
+         output = self.normalize2(output)
+         output = self.non_linearity(output)
+         output = self.conv2(output)
+
+         if self.output_dim == self.input_dim and self.resample is None:
+             shortcut = x
+         else:
+             shortcut = self.shortcut(x)
+
+         return shortcut + output
+
+
+ ###########################################################################
+ # Functions below are ported over from the DDPM codebase:
+ #     https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+ ###########################################################################
+
+ def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+     assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
+     half_dim = embedding_dim // 2
+     # magic number 10000 is from transformers
+     emb = math.log(max_positions) / (half_dim - 1)
+     # emb = math.log(2.) / (half_dim - 1)
+     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
+     # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
+     # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
+     emb = timesteps.float()[:, None] * emb[None, :]
+     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+     if embedding_dim % 2 == 1:  # zero pad
+         emb = F.pad(emb, (0, 1), mode='constant')
+     assert emb.shape == (timesteps.shape[0], embedding_dim)
+     return emb
+
+
+ def _einsum(a, b, c, x, y):
+     einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
+     return torch.einsum(einsum_str, x, y)
+
+
+ def contract_inner(x, y):
+     """tensordot(x, y, 1)."""
+     x_chars = list(string.ascii_lowercase[:len(x.shape)])
+     y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
+     y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
+     out_chars = x_chars[:-1] + y_chars[1:]
+     return _einsum(x_chars, y_chars, out_chars, x, y)
+
+
+ class NIN(nn.Module):
+     def __init__(self, in_dim, num_units, init_scale=0.1):
+         super().__init__()
+         self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
+         self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
+
+     def forward(self, x):
+         x = x.permute(0, 2, 3, 1)
+         y = contract_inner(x, self.W) + self.b
+         return y.permute(0, 3, 1, 2)
+
+
+ class AttnBlock(nn.Module):
+     """Channel-wise self-attention block."""
+     def __init__(self, channels):
+         super().__init__()
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+         self.NIN_0 = NIN(channels, channels)
+         self.NIN_1 = NIN(channels, channels)
+         self.NIN_2 = NIN(channels, channels)
+         self.NIN_3 = NIN(channels, channels, init_scale=0.)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         h = self.GroupNorm_0(x)
+         q = self.NIN_0(h)
+         k = self.NIN_1(h)
+         v = self.NIN_2(h)
+
+         w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+         w = torch.reshape(w, (B, H, W, H * W))
+         w = F.softmax(w, dim=-1)
+         w = torch.reshape(w, (B, H, W, H, W))
+         h = torch.einsum('bhwij,bcij->bchw', w, v)
+         h = self.NIN_3(h)
+         return x + h
+
+
+ class Upsample(nn.Module):
+     def __init__(self, channels, with_conv=False):
+         super().__init__()
+         if with_conv:
+             self.Conv_0 = ddpm_conv3x3(channels, channels)
+         self.with_conv = with_conv
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
+         if self.with_conv:
+             h = self.Conv_0(h)
+         return h
+
+
+ class Downsample(nn.Module):
+     def __init__(self, channels, with_conv=False):
+         super().__init__()
+         if with_conv:
+             self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
+         self.with_conv = with_conv
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         # Emulate 'SAME' padding
+         if self.with_conv:
+             x = F.pad(x, (0, 1, 0, 1))
+             x = self.Conv_0(x)
+         else:
+             x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
+
+         assert x.shape == (B, C, H // 2, W // 2)
+         return x
+
+
+ class ResnetBlockDDPM(nn.Module):
+     """The ResNet Blocks used in DDPM."""
+     def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
+         super().__init__()
+         if out_ch is None:
+             out_ch = in_ch
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
+         self.act = act
+         self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
+         if temb_dim is not None:
+             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+             nn.init.zeros_(self.Dense_0.bias)
+
+         self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
+         self.Dropout_0 = nn.Dropout(dropout)
+         self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
+         if in_ch != out_ch:
+             if conv_shortcut:
+                 self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
+             else:
+                 self.NIN_0 = NIN(in_ch, out_ch)
+         self.out_ch = out_ch
+         self.in_ch = in_ch
+         self.conv_shortcut = conv_shortcut
+
+     def forward(self, x, temb=None):
+         B, C, H, W = x.shape
+         assert C == self.in_ch
+         out_ch = self.out_ch if self.out_ch else self.in_ch
+         h = self.act(self.GroupNorm_0(x))
+         h = self.Conv_0(h)
+         # Add bias to each feature map conditioned on the time embedding
+         if temb is not None:
+             h += self.Dense_0(self.act(temb))[:, :, None, None]
+         h = self.act(self.GroupNorm_1(h))
+         h = self.Dropout_0(h)
+         h = self.Conv_1(h)
+         if C != out_ch:
+             if self.conv_shortcut:
+                 x = self.Conv_2(x)
+             else:
+                 x = self.NIN_0(x)
+         return x + h
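`get_timestep_embedding` above is the standard Transformer/DDPM sinusoidal embedding; its shape contract is easy to check in isolation (a sketch reusing the function defined above):

import torch
from fastgeco.backbones.ncsnpp_utils.layers import get_timestep_embedding

t = torch.tensor([0, 10, 500])        # one integer timestep per batch element
emb = get_timestep_embedding(t, 128)  # 64 sine + 64 cosine features per timestep
print(emb.shape)                      # torch.Size([3, 128])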
fastgeco/backbones/ncsnpp_utils/layerspp.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+ """Layers for defining NCSN++.
18
+ """
19
+ from . import layers
20
+ import score_models.layers.up_or_downsampling2d as up_or_down_sampling
21
+ import torch.nn as nn
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import numpy as np
25
+
26
+ conv1x1 = layers.ddpm_conv1x1
27
+ conv3x3 = layers.ddpm_conv3x3
28
+ NIN = layers.NIN
29
+ default_init = layers.default_init
30
+
31
+
32
+ class GaussianFourierProjection(nn.Module):
33
+ """Gaussian Fourier embeddings for noise levels."""
34
+
35
+ def __init__(self, embedding_size=256, scale=1.0):
36
+ super().__init__()
37
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
38
+
39
+ def forward(self, x):
40
+ x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
41
+ return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
42
+
43
+
44
+ class Combine(nn.Module):
45
+ """Combine information from skip connections."""
46
+
47
+ def __init__(self, dim1, dim2, method='cat'):
48
+ super().__init__()
49
+ self.Conv_0 = conv1x1(dim1, dim2)
50
+ self.method = method
51
+
52
+ def forward(self, x, y):
53
+ h = self.Conv_0(x)
54
+ if self.method == 'cat':
55
+ return torch.cat([h, y], dim=1)
56
+ elif self.method == 'sum':
57
+ return h + y
58
+ else:
59
+ raise ValueError(f'Method {self.method} not recognized.')
60
+
61
+
62
+ class AttnBlockpp(nn.Module):
63
+ """Channel-wise self-attention block. Modified from DDPM."""
64
+
65
+ def __init__(self, channels, skip_rescale=False, init_scale=0.):
66
+ super().__init__()
67
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
68
+ eps=1e-6)
69
+ self.NIN_0 = NIN(channels, channels)
70
+ self.NIN_1 = NIN(channels, channels)
71
+ self.NIN_2 = NIN(channels, channels)
72
+ self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
73
+ self.skip_rescale = skip_rescale
74
+
75
+ def forward(self, x):
76
+ B, C, H, W = x.shape
77
+ h = self.GroupNorm_0(x)
78
+ q = self.NIN_0(h)
79
+ k = self.NIN_1(h)
80
+ v = self.NIN_2(h)
81
+
82
+ w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
83
+ w = torch.reshape(w, (B, H, W, H * W))
84
+ w = F.softmax(w, dim=-1)
85
+ w = torch.reshape(w, (B, H, W, H, W))
86
+ h = torch.einsum('bhwij,bcij->bchw', w, v)
87
+ h = self.NIN_3(h)
88
+ if not self.skip_rescale:
89
+ return x + h
90
+ else:
91
+ return (x + h) / np.sqrt(2.)
92
+
93
+
94
+ # class Upsample(nn.Module):
95
+ # def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
96
+ # fir_kernel=(1, 3, 3, 1)):
97
+ # super().__init__()
98
+ # out_ch = out_ch if out_ch else in_ch
99
+ # if not fir:
100
+ # if with_conv:
101
+ # self.Conv_0 = conv3x3(in_ch, out_ch)
102
+ # else:
103
+ # if with_conv:
104
+ # self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
105
+ # kernel=3, up=True,
106
+ # resample_kernel=fir_kernel,
107
+ # use_bias=True,
108
+ # kernel_init=default_init())
109
+ # self.fir = fir
110
+ # self.with_conv = with_conv
111
+ # self.fir_kernel = fir_kernel
112
+ # self.out_ch = out_ch
113
+
114
+ # def forward(self, x):
115
+ # B, C, H, W = x.shape
116
+ # if not self.fir:
117
+ # h = F.interpolate(x, (H * 2, W * 2), 'nearest')
118
+ # if self.with_conv:
119
+ # h = self.Conv_0(h)
120
+ # else:
121
+ # if not self.with_conv:
122
+ # h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
123
+ # else:
124
+ # h = self.Conv2d_0(x)
125
+
126
+ # return h
127
+
128
+
129
+ # class Downsample(nn.Module):
130
+ #     def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
+ #                  fir_kernel=(1, 3, 3, 1)):
+ #         super().__init__()
+ #         out_ch = out_ch if out_ch else in_ch
+ #         if not fir:
+ #             if with_conv:
+ #                 self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
+ #         else:
+ #             if with_conv:
+ #                 self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
+ #                                                            kernel=3, down=True,
+ #                                                            resample_kernel=fir_kernel,
+ #                                                            use_bias=True,
+ #                                                            kernel_init=default_init())
+ #         self.fir = fir
+ #         self.fir_kernel = fir_kernel
+ #         self.with_conv = with_conv
+ #         self.out_ch = out_ch
+
+ #     def forward(self, x):
+ #         B, C, H, W = x.shape
+ #         if not self.fir:
+ #             if self.with_conv:
+ #                 x = F.pad(x, (0, 1, 0, 1))
+ #                 x = self.Conv_0(x)
+ #             else:
+ #                 x = F.avg_pool2d(x, 2, stride=2)
+ #         else:
+ #             if not self.with_conv:
+ #                 x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+ #             else:
+ #                 x = self.Conv2d_0(x)
+
+ #         return x
+
+
+ class ResnetBlockDDPMpp(nn.Module):
+     """ResBlock adapted from DDPM."""
+
+     def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
+                  dropout=0.1, skip_rescale=False, init_scale=0.):
+         super().__init__()
+         out_ch = out_ch if out_ch else in_ch
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+         self.Conv_0 = conv3x3(in_ch, out_ch)
+         if temb_dim is not None:
+             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+             nn.init.zeros_(self.Dense_0.bias)
+         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+         self.Dropout_0 = nn.Dropout(dropout)
+         self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+         if in_ch != out_ch:
+             if conv_shortcut:
+                 self.Conv_2 = conv3x3(in_ch, out_ch)
+             else:
+                 self.NIN_0 = NIN(in_ch, out_ch)
+
+         self.skip_rescale = skip_rescale
+         self.act = act
+         self.out_ch = out_ch
+         self.conv_shortcut = conv_shortcut
+
+     def forward(self, x, temb=None):
+         h = self.act(self.GroupNorm_0(x))
+         h = self.Conv_0(h)
+         if temb is not None:
+             h += self.Dense_0(self.act(temb))[:, :, None, None]
+         h = self.act(self.GroupNorm_1(h))
+         h = self.Dropout_0(h)
+         h = self.Conv_1(h)
+         if x.shape[1] != self.out_ch:
+             if self.conv_shortcut:
+                 x = self.Conv_2(x)
+             else:
+                 x = self.NIN_0(x)
+         if not self.skip_rescale:
+             return x + h
+         else:
+             return (x + h) / np.sqrt(2.)
+
+
+ class ResnetBlockBigGANpp(nn.Module):
+     def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
+                  dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
+                  skip_rescale=True, init_scale=0.):
+         super().__init__()
+
+         out_ch = out_ch if out_ch else in_ch
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+         self.up = up
+         self.down = down
+         self.fir = fir
+         self.fir_kernel = fir_kernel
+
+         self.Conv_0 = conv3x3(in_ch, out_ch)
+         if temb_dim is not None:
+             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+             nn.init.zeros_(self.Dense_0.bias)
+
+         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+         self.Dropout_0 = nn.Dropout(dropout)
+         self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+         if in_ch != out_ch or up or down:
+             self.Conv_2 = conv1x1(in_ch, out_ch)
+
+         self.skip_rescale = skip_rescale
+         self.act = act
+         self.in_ch = in_ch
+         self.out_ch = out_ch
+
+     def forward(self, x, temb=None):
+         h = self.act(self.GroupNorm_0(x))
+
+         if self.up:
+             if self.fir:
+                 h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
+                 x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+             else:
+                 h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
+                 x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
+         elif self.down:
+             if self.fir:
+                 h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
+                 x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+             else:
+                 h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
+                 x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
+
+         h = self.Conv_0(h)
+         # Add bias to each feature map conditioned on the time embedding
+         if temb is not None:
+             h += self.Dense_0(self.act(temb))[:, :, None, None]
+         h = self.act(self.GroupNorm_1(h))
+         h = self.Dropout_0(h)
+         h = self.Conv_1(h)
+
+         if self.in_ch != self.out_ch or self.up or self.down:
+             x = self.Conv_2(x)
+
+         if not self.skip_rescale:
+             return x + h
+         else:
+             return (x + h) / np.sqrt(2.)
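The two residual blocks above are the workhorses of the NCSN++ U-Net. A minimal smoke test of their shape behaviour, a sketch that assumes the file is importable as fastgeco.backbones.ncsnpp_utils.layerspp and that its own helpers (conv3x3, NIN, up_or_down_sampling) resolve:

import torch
import torch.nn as nn
from fastgeco.backbones.ncsnpp_utils.layerspp import ResnetBlockDDPMpp, ResnetBlockBigGANpp

act = nn.SiLU()
x = torch.randn(2, 64, 32, 32)    # (B, C, H, W)
temb = torch.randn(2, 512)        # time embedding, temb_dim=512

# DDPM block: spatial dims preserved, channels mapped 64 -> 128 via the NIN shortcut
block = ResnetBlockDDPMpp(act, in_ch=64, out_ch=128, temb_dim=512)
print(block(x, temb).shape)       # torch.Size([2, 128, 32, 32])

# BigGAN block with down=True, fir=False: naive average-pool downsampling halves H and W
down = ResnetBlockBigGANpp(act, in_ch=64, temb_dim=512, down=True, fir=False)
print(down(x, temb).shape)        # torch.Size([2, 64, 16, 16])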
fastgeco/backbones/ncsnpp_utils/normalization.py ADDED
@@ -0,0 +1,215 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Normalization layers."""
+ import torch.nn as nn
+ import torch
+ import functools
+
+
+ def get_normalization(config, conditional=False):
+     """Obtain normalization modules from the config file."""
+     norm = config.model.normalization
+     if conditional:
+         if norm == 'InstanceNorm++':
+             return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
+         else:
+             raise NotImplementedError(f'{norm} not implemented yet.')
+     else:
+         if norm == 'InstanceNorm':
+             return nn.InstanceNorm2d
+         elif norm == 'InstanceNorm++':
+             return InstanceNorm2dPlus
+         elif norm == 'VarianceNorm':
+             return VarianceNorm2d
+         elif norm == 'GroupNorm':
+             return nn.GroupNorm
+         else:
+             raise ValueError('Unknown normalization: %s' % norm)
+
+
+ class ConditionalBatchNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.bn = nn.BatchNorm2d(num_features, affine=False)
+         if self.bias:
+             self.embed = nn.Embedding(num_classes, num_features * 2)
+             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly
+             self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, num_features)
+             self.embed.weight.data.uniform_()
+
+     def forward(self, x, y):
+         out = self.bn(x)
+         if self.bias:
+             gamma, beta = self.embed(y).chunk(2, dim=1)
+             out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma = self.embed(y)
+             out = gamma.view(-1, self.num_features, 1, 1) * out
+         return out
+
+
+ class ConditionalInstanceNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+         if bias:
+             self.embed = nn.Embedding(num_classes, num_features * 2)
+             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly
+             self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, num_features)
+             self.embed.weight.data.uniform_()
+
+     def forward(self, x, y):
+         h = self.instance_norm(x)
+         if self.bias:
+             gamma, beta = self.embed(y).chunk(2, dim=-1)
+             out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma = self.embed(y)
+             out = gamma.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class ConditionalVarianceNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=False):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.embed = nn.Embedding(num_classes, num_features)
+         self.embed.weight.data.normal_(1, 0.02)
+
+     def forward(self, x, y):
+         vars = torch.var(x, dim=(2, 3), keepdim=True)
+         h = x / torch.sqrt(vars + 1e-5)
+
+         gamma = self.embed(y)
+         out = gamma.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class VarianceNorm2d(nn.Module):
+     def __init__(self, num_features, bias=False):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.alpha = nn.Parameter(torch.zeros(num_features))
+         self.alpha.data.normal_(1, 0.02)
+
+     def forward(self, x):
+         vars = torch.var(x, dim=(2, 3), keepdim=True)
+         h = x / torch.sqrt(vars + 1e-5)
+
+         out = self.alpha.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class ConditionalNoneNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         if bias:
+             self.embed = nn.Embedding(num_classes, num_features * 2)
+             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly
+             self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, num_features)
+             self.embed.weight.data.uniform_()
+
+     def forward(self, x, y):
+         if self.bias:
+             gamma, beta = self.embed(y).chunk(2, dim=-1)
+             out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma = self.embed(y)
+             out = gamma.view(-1, self.num_features, 1, 1) * x
+         return out
+
+
+ class NoneNorm2d(nn.Module):
+     def __init__(self, num_features, bias=True):
+         super().__init__()
+
+     def forward(self, x):
+         return x
+
+
+ class InstanceNorm2dPlus(nn.Module):
+     def __init__(self, num_features, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+         self.alpha = nn.Parameter(torch.zeros(num_features))
+         self.gamma = nn.Parameter(torch.zeros(num_features))
+         self.alpha.data.normal_(1, 0.02)
+         self.gamma.data.normal_(1, 0.02)
+         if bias:
+             self.beta = nn.Parameter(torch.zeros(num_features))
+
+     def forward(self, x):
+         means = torch.mean(x, dim=(2, 3))
+         m = torch.mean(means, dim=-1, keepdim=True)
+         v = torch.var(means, dim=-1, keepdim=True)
+         means = (means - m) / (torch.sqrt(v + 1e-5))
+         h = self.instance_norm(x)
+
+         if self.bias:
+             h = h + means[..., None, None] * self.alpha[..., None, None]
+             out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
+         else:
+             h = h + means[..., None, None] * self.alpha[..., None, None]
+             out = self.gamma.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class ConditionalInstanceNorm2dPlus(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+         if bias:
+             self.embed = nn.Embedding(num_classes, num_features * 3)
+             self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
+             self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, 2 * num_features)
+             self.embed.weight.data.normal_(1, 0.02)
+
+     def forward(self, x, y):
+         means = torch.mean(x, dim=(2, 3))
+         m = torch.mean(means, dim=-1, keepdim=True)
+         v = torch.var(means, dim=-1, keepdim=True)
+         means = (means - m) / (torch.sqrt(v + 1e-5))
+         h = self.instance_norm(x)
+
+         if self.bias:
+             gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
+             h = h + means[..., None, None] * alpha[..., None, None]
+             out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma, alpha = self.embed(y).chunk(2, dim=-1)
+             h = h + means[..., None, None] * alpha[..., None, None]
+             out = gamma.view(-1, self.num_features, 1, 1) * h
+         return out
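InstanceNorm2dPlus and its conditional variant are the norms this file exposes through get_normalization for NCSN++. A small usage sketch (shapes only; the import path is assumed from the file location):

import torch
from fastgeco.backbones.ncsnpp_utils.normalization import (
    InstanceNorm2dPlus, ConditionalInstanceNorm2dPlus)

x = torch.randn(4, 32, 16, 16)

norm = InstanceNorm2dPlus(num_features=32)
print(norm(x).shape)              # torch.Size([4, 32, 16, 16])

cond = ConditionalInstanceNorm2dPlus(num_features=32, num_classes=10)
y = torch.randint(0, 10, (4,))    # one class label per batch element
print(cond(x, y).shape)           # torch.Size([4, 32, 16, 16])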
fastgeco/backbones/ncsnpp_utils/utils.py ADDED
@@ -0,0 +1,189 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """All functions and modules related to model definition.
+ """
+
+ import torch
+
+ import numpy as np
+ from ...sdes import OUVESDE, OUVPSDE
+
+
+ _MODELS = {}
+
+
+ def register_model(cls=None, *, name=None):
+     """A decorator for registering model classes."""
+
+     def _register(cls):
+         if name is None:
+             local_name = cls.__name__
+         else:
+             local_name = name
+         if local_name in _MODELS:
+             raise ValueError(f'Already registered model with name: {local_name}')
+         _MODELS[local_name] = cls
+         return cls
+
+     if cls is None:
+         return _register
+     else:
+         return _register(cls)
+
+
+ def get_model(name):
+     return _MODELS[name]
+
+
+ def get_sigmas(sigma_min, sigma_max, num_scales):
+     """Get sigmas --- the set of noise levels for SMLD.
+     Args:
+         sigma_min: The smallest noise level.
+         sigma_max: The largest noise level.
+         num_scales: The number of noise levels.
+     Returns:
+         sigmas: a numpy array of noise levels, geometrically spaced from sigma_max down to sigma_min
+     """
+     sigmas = np.exp(
+         np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
+
+     return sigmas
+
+
+ def get_ddpm_params(config):
+     """Get betas and alphas --- parameters used in the original DDPM paper."""
+     num_diffusion_timesteps = 1000
+     # parameters need to be adapted if number of time steps differs from 1000
+     beta_start = config.model.beta_min / config.model.num_scales
+     beta_end = config.model.beta_max / config.model.num_scales
+     betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+
+     alphas = 1. - betas
+     alphas_cumprod = np.cumprod(alphas, axis=0)
+     sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+     sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
+
+     return {
+         'betas': betas,
+         'alphas': alphas,
+         'alphas_cumprod': alphas_cumprod,
+         'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
+         'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
+         'beta_min': beta_start * (num_diffusion_timesteps - 1),
+         'beta_max': beta_end * (num_diffusion_timesteps - 1),
+         'num_diffusion_timesteps': num_diffusion_timesteps
+     }
+
+
+ def create_model(config):
+     """Create the score model."""
+     model_name = config.model.name
+     score_model = get_model(model_name)(config)
+     score_model = score_model.to(config.device)
+     score_model = torch.nn.DataParallel(score_model)
+     return score_model
+
+
+ def get_model_fn(model, train=False):
+     """Create a function to give the output of the score-based model.
+
+     Args:
+         model: The score model.
+         train: `True` for training and `False` for evaluation.
+
+     Returns:
+         A model function.
+     """
+
+     def model_fn(x, labels):
+         """Compute the output of the score-based model.
+
+         Args:
+             x: A mini-batch of input data.
+             labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
+                 for different models.
+
+         Returns:
+             The model output.
+         """
+         if not train:
+             model.eval()
+             return model(x, labels)
+         else:
+             model.train()
+             return model(x, labels)
+
+     return model_fn
+
+
+ def get_score_fn(sde, model, train=False, continuous=False):
+     """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
+
+     Args:
+         sde: An `sde_lib.SDE` object that represents the forward SDE.
+         model: A score model.
+         train: `True` for training and `False` for evaluation.
+         continuous: If `True`, the score-based model is expected to directly take continuous time steps.
+
+     Returns:
+         A score function.
+     """
+     model_fn = get_model_fn(model, train=train)
+
+     if isinstance(sde, OUVPSDE):
+         def score_fn(x, t):
+             # Scale neural network output by standard deviation and flip sign
+             if continuous:
+                 # For VP-trained models, t=0 corresponds to the lowest noise level.
+                 # The maximum value of the time embedding is assumed to be 999 for
+                 # continuously-trained models.
+                 labels = t * 999
+                 score = model_fn(x, labels)
+                 std = sde.marginal_prob(torch.zeros_like(x), t)[1]
+             else:
+                 # For VP-trained models, t=0 corresponds to the lowest noise level
+                 labels = t * (sde.N - 1)
+                 score = model_fn(x, labels)
+                 std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
+
+             score = -score / std[:, None, None, None]
+             return score
+
+     elif isinstance(sde, OUVESDE):
+         def score_fn(x, t):
+             if continuous:
+                 labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+             else:
+                 # For VE-trained models, t=0 corresponds to the highest noise level
+                 labels = sde.T - t
+                 labels *= sde.N - 1
+                 labels = torch.round(labels).long()
+
+             score = model_fn(x, labels)
+             return score
+
+     else:
+         raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
+
+     return score_fn
+
+
+ def to_flattened_numpy(x):
+     """Flatten a torch tensor `x` and convert it to numpy."""
+     return x.detach().cpu().numpy().reshape((-1,))
+
+
+ def from_flattened_numpy(x, shape):
+     """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+     return torch.from_numpy(x.reshape(shape))
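get_sigmas above is just a geometric grid from sigma_max down to sigma_min; a quick standalone check of that claim:

import numpy as np

sigma_min, sigma_max, num_scales = 0.01, 50.0, 5
sigmas = np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
print(sigmas)                     # approx. [50.0, 5.95, 0.71, 0.084, 0.01]
# Constant ratio between consecutive levels: (sigma_min / sigma_max) ** (1 / (num_scales - 1))
print(sigmas[1:] / sigmas[:-1])   # approx. [0.119, 0.119, 0.119, 0.119]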
fastgeco/backbones/shared.py ADDED
@@ -0,0 +1,123 @@
+ import functools
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ from geco.util.registry import Registry
+
+
+ BackboneRegistry = Registry("Backbone")
+
+
+ class GaussianFourierProjection(nn.Module):
+     """Gaussian random features for encoding time steps."""
+
+     def __init__(self, embed_dim, scale=16, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
+
+     def forward(self, t):
+         t_proj = t[:, None] * self.W[None, :] * 2*np.pi
+         if self.complex_valued:
+             return torch.exp(1j * t_proj)
+         else:
+             return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
+
+
+ class DiffusionStepEmbedding(nn.Module):
+     """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""
+
+     def __init__(self, embed_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         self.embed_dim = embed_dim
+
+     def forward(self, t):
+         fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
+         inner = t[:, None] * fac[None, :]
+         if self.complex_valued:
+             return torch.exp(1j * inner)
+         else:
+             return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
+
+
+ class ComplexLinear(nn.Module):
+     """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
+     def __init__(self, input_dim, output_dim, complex_valued):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if self.complex_valued:
+             self.re = nn.Linear(input_dim, output_dim)
+             self.im = nn.Linear(input_dim, output_dim)
+         else:
+             self.lin = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         if self.complex_valued:
+             return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
+         else:
+             return self.lin(x)
+
+
+ class FeatureMapDense(nn.Module):
+     """A fully connected layer that reshapes outputs to feature maps."""
+
+     def __init__(self, input_dim, output_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
+
+     def forward(self, x):
+         return self.dense(x)[..., None, None]
+
+
+ def torch_complex_from_reim(re, im):
+     return torch.view_as_complex(torch.stack([re, im], dim=-1))
+
+
+ class ArgsComplexMultiplicationWrapper(nn.Module):
+     """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
+
+     Make a complex-valued module `F` from a real-valued module `f` by applying
+     complex multiplication rules:
+
+     F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
+
+     where `f1`, `f2` are instances of `f` that do *not* share weights.
+
+     Args:
+         module_cls (callable): A class or function that returns a Torch module/functional.
+             Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
+             to construct the real and imaginary component modules.
+     """
+
+     def __init__(self, module_cls, *args, **kwargs):
+         super().__init__()
+         self.re_module = module_cls(*args, **kwargs)
+         self.im_module = module_cls(*args, **kwargs)
+
+     def forward(self, x, *args, **kwargs):
+         return torch_complex_from_reim(
+             self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
+             self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
+         )
+
+
+ ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
+ ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
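GaussianFourierProjection is how the diffusion time t enters the network. A short sketch of its output shape (the import path is assumed from the file location):

import torch
from fastgeco.backbones.shared import GaussianFourierProjection

emb = GaussianFourierProjection(embed_dim=128, scale=16)
t = torch.rand(8)                 # one diffusion time per batch element
print(emb(t).shape)               # torch.Size([8, 128]) -- 64 sin + 64 cos features
print(emb.W.requires_grad)        # False: the random frequencies stay fixed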
fastgeco/model.py ADDED
@@ -0,0 +1,258 @@
+ import random
+ import time
+ from math import ceil
+ import warnings
+ import numpy as np
+ # from asteroid.losses.sdr import SingleSrcNegSDR
+ import torch
+ import pytorch_lightning as pl
+ from torch_ema import ExponentialMovingAverage
+ import torch.nn.functional as F
+ from geco import sampling
+ from geco.sdes import SDERegistry
+ from fastgeco.backbones import BackboneRegistry
+ from geco.util.inference import evaluate_model2
+ from geco.util.other import pad_spec
+ import matplotlib.pyplot as plt
+
+
+ class ScoreModel(pl.LightningModule):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--lr", type=float, default=1e-5, help="The learning rate (1e-5 by default)")
+         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
+         parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
+         parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
+         parser.add_argument("--loss_type", type=str, default="mse", help="The type of loss function to use.")
+         parser.add_argument("--loss_abs_exponent", type=float, default=0.5, help="magnitude transformation in the loss term")
+         parser.add_argument("--output_scale", type=str, choices=('sigma', 'time'), default='time', help="backbone model scale before last output layer")
+         return parser
+
+     def __init__(
+         self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, loss_abs_exponent=0.5,
+         num_eval_files=20, loss_type='mse', data_module_cls=None, output_scale='time', inference_N=1,
+         inference_start=0.5, **kwargs
+     ):
+         """
+         Create a new ScoreModel.
+
+         Args:
+             backbone: Backbone DNN that serves as a score-based model.
+             sde: The SDE that defines the diffusion process.
+             lr: The learning rate of the optimizer. (1e-4 by default).
+             ema_decay: The decay constant of the parameter EMA (0.999 by default).
+             t_eps: The minimum time to practically run for to avoid issues very close to zero (3e-2 by default).
+             loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'.
+         """
+         super().__init__()
+         # Initialize Backbone DNN
+         dnn_cls = BackboneRegistry.get_by_name(backbone)
+         self.dnn = dnn_cls(**kwargs)
+         # Initialize SDE
+         sde_cls = SDERegistry.get_by_name(sde)
+         self.sde = sde_cls(**kwargs)
+         # Store hyperparams and save them
+         self.lr = lr
+         self.ema_decay = ema_decay
+         self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
+         self._error_loading_ema = False
+         self.t_eps = t_eps
+         self.loss_type = loss_type
+         self.num_eval_files = num_eval_files
+         self.loss_abs_exponent = loss_abs_exponent
+         self.output_scale = output_scale
+         self.save_hyperparameters(ignore=['no_wandb'])
+         self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
+         self.inference_N = inference_N
+         self.inference_start = inference_start
+
+         # self.si_snr = SingleSrcNegSDR("sisdr", reduction='mean', zero_mean=False)
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+         return optimizer
+
+     def optimizer_step(self, *args, **kwargs):
+         # Method overridden so that the EMA params are updated after each optimizer step
+         super().optimizer_step(*args, **kwargs)
+         self.ema.update(self.parameters())
+
+     # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
+     def on_load_checkpoint(self, checkpoint):
+         ema = checkpoint.get('ema', None)
+         if ema is not None:
+             self.ema.load_state_dict(checkpoint['ema'])
+         else:
+             self._error_loading_ema = True
+             warnings.warn("EMA state_dict not found in checkpoint!")
+
+     def on_save_checkpoint(self, checkpoint):
+         checkpoint['ema'] = self.ema.state_dict()
+
+     def train(self, mode, no_ema=False):
+         res = super().train(mode)  # call the standard `train` method with the given mode
+         if not self._error_loading_ema:
+             if mode == False and not no_ema:
+                 # eval
+                 self.ema.store(self.parameters())  # store current params in EMA
+                 self.ema.copy_to(self.parameters())  # copy EMA parameters over current params for evaluation
+             else:
+                 # train
+                 if self.ema.collected_params is not None:
+                     self.ema.restore(self.parameters())  # restore the EMA weights (if stored)
+         return res
+
+     def eval(self, no_ema=False):
+         return self.train(False, no_ema=no_ema)
+
+     def sisnr(self, est, ref, eps=1e-8):
+         est = est - torch.mean(est, dim=-1, keepdim=True)
+         ref = ref - torch.mean(ref, dim=-1, keepdim=True)
+         est_p = (torch.sum(est * ref, dim=-1, keepdim=True) * ref) / torch.sum(ref * ref, dim=-1, keepdim=True)
+         est_v = est - est_p
+         est_sisnr = 10 * torch.log10((torch.sum(est_p * est_p, dim=-1, keepdim=True) + eps) / (torch.sum(est_v * est_v, dim=-1, keepdim=True) + eps))
+         return -est_sisnr
+
+     def _loss(self, wav_x_tm1, wav_gt):
+         if self.loss_type == 'default':
+             min_leng = min(wav_x_tm1.shape[-1], wav_gt.shape[-1])
+             wav_x_tm1 = wav_x_tm1.squeeze(1)[:, :min_leng]
+             wav_gt = wav_gt.squeeze(1)[:, :min_leng]
+             loss = torch.mean(self.sisnr(wav_x_tm1, wav_gt))
+         else:
+             raise RuntimeError(f'{self.loss_type} loss not defined')
+
+         return loss
+
+     def euler_step(self, X, X_t, Y, M, t, dt):
+         f, g = self.sde.sde(X_t, t, Y)
+         vec_t = torch.ones(Y.shape[0], device=Y.device) * t
+         mean_x_tm1 = X_t - (f - g**2*self.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None]))*dt
+         z = torch.randn_like(X)
+         X_t = mean_x_tm1 + z*g*torch.sqrt(dt)
+
+         return X_t
+
+     def training_step(self, batch, batch_idx):
+         X, Y, M = batch
+
+         reverse_start_time = random.uniform(self.t_rsp_min, self.t_rsp_max)
+         N_reverse = random.randint(self.N_min, self.N_max)
+
+         if self.stop_iteration_random == "random":
+             stop_iteration = random.randint(0, N_reverse-1)
+         elif self.stop_iteration_random == "last":
+             # Used in the publication. This means that only the last step is used for updating weights.
+             stop_iteration = N_reverse-1
+         else:
+             raise RuntimeError(f'{self.stop_iteration_random} not defined')
+
+         timesteps = torch.linspace(reverse_start_time, self.t_eps, N_reverse, device=Y.device)
+
+         # prior sampling starting from reverse_start_time
+         std = self.sde._std(reverse_start_time*torch.ones((Y.shape[0],), device=Y.device))
+         z = torch.randn_like(Y)
+         X_t = Y + z * std[:, None, None, None]
+
+         # reverse steps by Euler-Maruyama
+         for i in range(len(timesteps)):
+             t = timesteps[i]
+             if i != len(timesteps) - 1:
+                 dt = t - timesteps[i+1]
+             else:
+                 dt = timesteps[-1]
+
+             if i != stop_iteration:
+                 with torch.no_grad():
+                     # take Euler step here
+                     X_t = self.euler_step(X, X_t, Y, M, t, dt)
+             else:
+                 # take an Euler step and compute loss
+                 f, g = self.sde.sde(X_t, t, Y)
+                 vec_t = torch.ones(Y.shape[0], device=Y.device) * t
+                 score = self.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
+                 mean_x_tm1 = X_t - (f - g**2*score)*dt  # mean of x_{t-1} = mu(x_{t-1})
+                 mean_gt, _ = self.sde.marginal_prob(X, torch.ones(Y.shape[0], device=Y.device) * (t-dt), Y)
+
+                 wav_gt = self.to_audio(mean_gt.squeeze())
+                 wav_x_tm1 = self.to_audio(mean_x_tm1.squeeze())
+                 loss = self._loss(wav_x_tm1, wav_gt)
+                 break
+
+         self.log('train_loss', loss, on_step=True, on_epoch=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         # Evaluate speech enhancement performance, compute loss only for a few val data
+         if batch_idx == 0 and self.num_eval_files != 0:
+             pesq, si_sdr, estoi, loss = evaluate_model2(self, self.num_eval_files, self.inference_N, inference_start=self.inference_start)
+             self.log('pesq', pesq, on_step=False, on_epoch=True)
+             self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
+             self.log('estoi', estoi, on_step=False, on_epoch=True)
+             self.log('valid_loss', loss, on_step=False, on_epoch=True)
+         return loss
+
+     def forward(self, x, t, y, m, divide_scale):
+         # Concatenate y and m as extra channels
+         dnn_input = torch.cat([x, y, m], dim=1)
+
+         # the minus is most likely unimportant here - taken from Song's repo
+         score = -self.dnn(dnn_input, t, divide_scale)
+         return score
+
+     def to(self, *args, **kwargs):
+         """Override PyTorch .to() to also transfer the EMA of the model weights"""
+         self.ema.to(*args, **kwargs)
+         return super().to(*args, **kwargs)
+
+     def train_dataloader(self):
+         return self.data_module.train_dataloader()
+
+     def val_dataloader(self):
+         return self.data_module.val_dataloader()
+
+     def test_dataloader(self):
+         return self.data_module.test_dataloader()
+
+     def setup(self, stage=None):
+         return self.data_module.setup(stage=stage)
+
+     def to_audio(self, spec, length=None):
+         return self._istft(self._backward_transform(spec), length)
+
+     def _forward_transform(self, spec):
+         return self.data_module.spec_fwd(spec)
+
+     def _backward_transform(self, spec):
+         return self.data_module.spec_back(spec)
+
+     def _stft(self, sig):
+         return self.data_module.stft(sig)
+
+     def _istft(self, spec, length=None):
+         return self.data_module.istft(spec, length)
+
+     def add_para(self, N_min=1, N_max=1, t_rsp_min=0.5, t_rsp_max=0.5, batch_size=64, loss_type='default', lr=5e-5, stop_iteration_random='last', inference_N=1, inference_start=0.5):
+         self.t_rsp_min = t_rsp_min
+         self.t_rsp_max = t_rsp_max
+         self.N_min = N_min
+         self.N_max = N_max
+         self.data_module.batch_size = batch_size
+         self.data_module.num_workers = 4
+         self.data_module.gpu = True
+         self.loss_type = loss_type
+         self.lr = lr
+         self.stop_iteration_random = stop_iteration_random
+         self.inference_N = inference_N
+         self.inference_start = inference_start
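The 'default' loss above is a negated SI-SNR on the re-synthesized waveforms, so more negative is better. A standalone sanity check of the same formula (hypothetical helper, copied from ScoreModel.sisnr):

import torch

def sisnr(est, ref, eps=1e-8):
    est = est - est.mean(dim=-1, keepdim=True)
    ref = ref - ref.mean(dim=-1, keepdim=True)
    est_p = (torch.sum(est * ref, dim=-1, keepdim=True) * ref) / torch.sum(ref * ref, dim=-1, keepdim=True)
    est_v = est - est_p
    return -10 * torch.log10((est_p.pow(2).sum(-1, keepdim=True) + eps)
                             / (est_v.pow(2).sum(-1, keepdim=True) + eps))

ref = torch.randn(1, 8000)
print(sisnr(ref, ref))                                 # very negative: a perfect estimate maximizes SI-SNR
print(sisnr(ref + 0.1 * torch.randn_like(ref), ref))   # roughly -20 (dB)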
geco/.DS_Store ADDED
Binary file (6.15 kB).

geco/backbones/.DS_Store ADDED
Binary file (6.15 kB).
geco/backbones/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .shared import BackboneRegistry
+ from .ncsnpp import NCSNpp
+
+ __all__ = ['BackboneRegistry', 'NCSNpp']
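Importing this package registers the backbone as a side effect, so models are then constructed by name, as ScoreModel.__init__ does. A sketch, relying on the "ncsnpp" registration in geco/backbones/ncsnpp.py below:

from geco.backbones import BackboneRegistry

dnn_cls = BackboneRegistry.get_by_name("ncsnpp")   # the NCSNpp class registered below
model = dnn_cls()                                  # built with its default arguments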
geco/backbones/ncsnpp.py ADDED
@@ -0,0 +1,405 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ from score_models.layers import UpsampleLayer, DownsampleLayer
+ from .ncsnpp_utils import layers, layerspp, normalization
+ import torch.nn as nn
+ import functools
+ import torch
+ import numpy as np
+
+ from .shared import BackboneRegistry
+
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+ Combine = layerspp.Combine
+ conv3x3 = layerspp.conv3x3
+ conv1x1 = layerspp.conv1x1
+ get_act = layers.get_act
+ get_normalization = normalization.get_normalization
+ default_initializer = layers.default_init
+
+
+ @BackboneRegistry.register("ncsnpp")
+ class NCSNpp(nn.Module):
+     """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
+
+     @staticmethod
+     def add_argparse_args(parser):
+         # TODO: add additional arguments of constructor, if you wish to modify them.
+         return parser
+
+     def __init__(self,
+                  scale_by_sigma=True,
+                  nonlinearity='swish',
+                  nf=128,
+                  ch_mult=(1, 1, 2, 2, 2, 2, 2),
+                  num_res_blocks=2,
+                  attn_resolutions=(16,),
+                  resamp_with_conv=True,
+                  conditional=True,
+                  fir=True,
+                  fir_kernel='song',
+                  skip_rescale=True,
+                  resblock_type='biggan',
+                  progressive='output_skip',
+                  progressive_input='input_skip',
+                  progressive_combine='sum',
+                  init_scale=0.,
+                  fourier_scale=16,
+                  image_size=256,
+                  embedding_type='fourier',
+                  dropout=.0,
+                  **unused_kwargs
+                  ):
+         super().__init__()
+         self.act = act = get_act(nonlinearity)
+
+         self.nf = nf
+         self.num_res_blocks = num_res_blocks
+         self.attn_resolutions = attn_resolutions
+         self.num_resolutions = num_resolutions = len(ch_mult)
+         self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+         self.conditional = conditional  # noise-conditional
+         self.scale_by_sigma = scale_by_sigma
+         fir_kernel = [1, 3, 3, 1]
+         self.skip_rescale = skip_rescale
+         self.resblock_type = resblock_type = resblock_type.lower()
+         self.progressive = progressive = progressive.lower()
+         self.progressive_input = progressive_input = progressive_input.lower()
+         self.embedding_type = embedding_type = embedding_type.lower()
+         assert progressive in ['none', 'output_skip', 'residual']
+         assert progressive_input in ['none', 'input_skip', 'residual']
+         assert embedding_type in ['fourier', 'positional']
+         combine_method = progressive_combine.lower()
+         combiner = functools.partial(Combine, method=combine_method)
+
+         num_channels = 6  # x.real, x.imag, y.real, y.imag, m.real, m.imag
+         self.output_layer = nn.Conv2d(num_channels, 2, 1)
+
+         modules = []
+         # timestep/noise_level embedding
+         if embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             modules.append(layerspp.GaussianFourierProjection(
+                 embedding_size=nf, scale=fourier_scale
+             ))
+             embed_dim = 2 * nf
+         elif embedding_type == 'positional':
+             embed_dim = nf
+         else:
+             raise ValueError(f'embedding type {embedding_type} unknown.')
+
+         if conditional:
+             modules.append(nn.Linear(embed_dim, nf * 4))
+             modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+             nn.init.zeros_(modules[-1].bias)
+             modules.append(nn.Linear(nf * 4, nf * 4))
+             modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+             nn.init.zeros_(modules[-1].bias)
+
+         AttnBlock = functools.partial(layerspp.AttnBlockpp,
+                                       init_scale=init_scale, skip_rescale=skip_rescale)
+
+         Upsample = functools.partial(UpsampleLayer,
+                                      with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive == 'output_skip':
+             self.pyramid_upsample = UpsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive == 'residual':
+             pyramid_upsample = functools.partial(UpsampleLayer, fir=fir,
+                                                  fir_kernel=fir_kernel, with_conv=True)
+
+         Downsample = functools.partial(DownsampleLayer, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive_input == 'input_skip':
+             self.pyramid_downsample = DownsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive_input == 'residual':
+             pyramid_downsample = functools.partial(DownsampleLayer,
+                                                    fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+         if resblock_type == 'ddpm':
+             ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                                             dropout=dropout, init_scale=init_scale,
+                                             skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+         elif resblock_type == 'biggan':
+             ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                                             dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                                             init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+
+         else:
+             raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+         # Downsampling block
+
+         channels = num_channels
+         if progressive_input != 'none':
+             input_pyramid_ch = channels
+
+         modules.append(conv3x3(channels, nf))
+         hs_c = [nf]
+
+         in_ch = nf
+         for i_level in range(num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(num_res_blocks):
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                 in_ch = out_ch
+
+                 if all_resolutions[i_level] in attn_resolutions:
+                     modules.append(AttnBlock(channels=in_ch))
+                 hs_c.append(in_ch)
+
+             if i_level != num_resolutions - 1:
+                 if resblock_type == 'ddpm':
+                     modules.append(Downsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                 if progressive_input == 'input_skip':
+                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                     if combine_method == 'cat':
+                         in_ch *= 2
+
+                 elif progressive_input == 'residual':
+                     modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                     input_pyramid_ch = in_ch
+
+                 hs_c.append(in_ch)
+
+         in_ch = hs_c[-1]
+         modules.append(ResnetBlock(in_ch=in_ch))
+         modules.append(AttnBlock(channels=in_ch))
+         modules.append(ResnetBlock(in_ch=in_ch))
+
+         pyramid_ch = 0
+         # Upsampling block
+         for i_level in reversed(range(num_resolutions)):
+             for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                 in_ch = out_ch
+
+             if all_resolutions[i_level] in attn_resolutions:
+                 modules.append(AttnBlock(channels=in_ch))
+
+             if progressive != 'none':
+                 if i_level == num_resolutions - 1:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                     num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, in_ch, bias=True))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name.')
+                 else:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                                     num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name')
+
+             if i_level != 0:
+                 if resblock_type == 'ddpm':
+                     modules.append(Upsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+         assert not hs_c
+
+         if progressive != 'output_skip':
+             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                                         num_channels=in_ch, eps=1e-6))
+             modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+         self.all_modules = nn.ModuleList(modules)
+
+     def forward(self, x, time_cond):
+         # timestep/noise_level embedding; only for continuous training
+         modules = self.all_modules
+         m_idx = 0
+
+         # Convert real and imaginary parts of (x, y, m) into six channel dimensions
+         x = torch.cat((x[:, [0], :, :].real, x[:, [0], :, :].imag,
+                        x[:, [1], :, :].real, x[:, [1], :, :].imag,
+                        x[:, [2], :, :].real, x[:, [2], :, :].imag), dim=1)
+
+         if self.embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             used_sigmas = time_cond
+             temb = modules[m_idx](torch.log(used_sigmas))
+             m_idx += 1
+
+         elif self.embedding_type == 'positional':
+             # Sinusoidal positional embeddings.
+             timesteps = time_cond
+             used_sigmas = self.sigmas[time_cond.long()]
+             temb = layers.get_timestep_embedding(timesteps, self.nf)
+
+         else:
+             raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+         if self.conditional:
+             temb = modules[m_idx](temb)
+             m_idx += 1
+             temb = modules[m_idx](self.act(temb))
+             m_idx += 1
+         else:
+             temb = None
+
+         # Downsampling block
+         input_pyramid = None
+         if self.progressive_input != 'none':
+             input_pyramid = x
+
+         # Input layer: Conv2d: 6ch -> nf ch
+         hs = [modules[m_idx](x)]
+         m_idx += 1
+
+         # Down path in U-Net
+         for i_level in range(self.num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(self.num_res_blocks):
+                 h = modules[m_idx](hs[-1], temb)
+                 m_idx += 1
+                 # Attention layer (optional)
+                 if h.shape[-2] in self.attn_resolutions:  # edit: check H dim (-2) not W dim (-1)
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 hs.append(h)
+
+             # Downsampling
+             if i_level != self.num_resolutions - 1:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](hs[-1])
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](hs[-1], temb)
+                     m_idx += 1
+
+                 if self.progressive_input == 'input_skip':  # Combine h with x
+                     input_pyramid = self.pyramid_downsample(input_pyramid)
+                     h = modules[m_idx](input_pyramid, h)
+                     m_idx += 1
+
+                 elif self.progressive_input == 'residual':
+                     input_pyramid = modules[m_idx](input_pyramid)
+                     m_idx += 1
+                     if self.skip_rescale:
+                         input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                     else:
+                         input_pyramid = input_pyramid + h
+                     h = input_pyramid
+                 hs.append(h)
+
+         h = hs[-1]  # actually equal to: h = h
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+         h = modules[m_idx](h)  # Attention block
+         m_idx += 1
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+
+         pyramid = None
+
+         # Upsampling block
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                 m_idx += 1
+
+             # edit: from -1 to -2
+             if h.shape[-2] in self.attn_resolutions:
+                 h = modules[m_idx](h)
+                 m_idx += 1
+
+             if self.progressive != 'none':
+                 if i_level == self.num_resolutions - 1:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)  # Conv2d: in_ch -> num_channels
+                         m_idx += 1
+                     elif self.progressive == 'residual':
+                         pyramid = self.act(modules[m_idx](h))
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name.')
+                 else:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                         pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid_h = modules[m_idx](pyramid_h)
+                         m_idx += 1
+                         pyramid = pyramid + pyramid_h
+                     elif self.progressive == 'residual':
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                         if self.skip_rescale:
+                             pyramid = (pyramid + h) / np.sqrt(2.)
+                         else:
+                             pyramid = pyramid + h
+                         h = pyramid
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name')
+
+             # Upsampling Layer
+             if i_level != 0:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](h, temb)  # Upsampling
+                     m_idx += 1
+
+         assert not hs
+
+         if self.progressive == 'output_skip':
+             h = pyramid
+         else:
+             h = self.act(modules[m_idx](h))
+             m_idx += 1
+             h = modules[m_idx](h)
+             m_idx += 1
+
+         assert m_idx == len(modules), "Implementation error"
+         h = h / used_sigmas[:, None, None, None]
+
+         # Convert back to complex number
+         h = self.output_layer(h)
+         h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+         h = torch.view_as_complex(h)[:, None, :, :]
+         return h
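The geometry of the U-Net above follows directly from the constructor defaults; a small worked example:

image_size, nf = 256, 128
ch_mult = (1, 1, 2, 2, 2, 2, 2)
num_resolutions = len(ch_mult)                        # 7 levels
print([image_size // (2 ** i) for i in range(num_resolutions)])
# [256, 128, 64, 32, 16, 8, 4]  -- spatial size per level
print([nf * m for m in ch_mult])
# [128, 128, 256, 256, 256, 256, 256]  -- ResnetBlock out_ch per level
# attn_resolutions=(16,) means attention fires only at the 16x16 level.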
geco/backbones/ncsnpp_utils/.DS_Store ADDED
Binary file (6.15 kB).
geco/backbones/ncsnpp_utils/layers.py ADDED
@@ -0,0 +1,662 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ """Common layers for defining score networks.
+ """
+ import math
+ import string
+ from functools import partial
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from .normalization import ConditionalInstanceNorm2dPlus
+
+
+ def get_act(config):
+     """Get activation functions from the config file."""
+
+     if config == 'elu':
+         return nn.ELU()
+     elif config == 'relu':
+         return nn.ReLU()
+     elif config == 'lrelu':
+         return nn.LeakyReLU(negative_slope=0.2)
+     elif config == 'swish':
+         return nn.SiLU()
+     else:
+         raise NotImplementedError('activation function does not exist!')
+
+
+ def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
+     """1x1 convolution. Same as NCSNv1/v2."""
+     conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
+                      padding=padding)
+     init_scale = 1e-10 if init_scale == 0 else init_scale
+     conv.weight.data *= init_scale
+     conv.bias.data *= init_scale
+     return conv
+
+
+ def variance_scaling(scale, mode, distribution,
+                      in_axis=1, out_axis=0,
+                      dtype=torch.float32,
+                      device='cpu'):
+     """Ported from JAX."""
+
+     def _compute_fans(shape, in_axis=1, out_axis=0):
+         receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+         fan_in = shape[in_axis] * receptive_field_size
+         fan_out = shape[out_axis] * receptive_field_size
+         return fan_in, fan_out
+
+     def init(shape, dtype=dtype, device=device):
+         fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+         if mode == "fan_in":
+             denominator = fan_in
+         elif mode == "fan_out":
+             denominator = fan_out
+         elif mode == "fan_avg":
+             denominator = (fan_in + fan_out) / 2
+         else:
+             raise ValueError(
+                 "invalid mode for variance scaling initializer: {}".format(mode))
+         variance = scale / denominator
+         if distribution == "normal":
+             return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
+         elif distribution == "uniform":
+             return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
+         else:
+             raise ValueError("invalid distribution for variance scaling initializer")
+
+     return init
+
+
+ def default_init(scale=1.):
+     """The same initialization used in DDPM."""
+     scale = 1e-10 if scale == 0 else scale
+     return variance_scaling(scale, 'fan_avg', 'uniform')
+
+
+ class Dense(nn.Module):
+     """Linear layer with `default_init`."""
+     def __init__(self):
+         super().__init__()
+
+
+ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
+     """1x1 convolution with DDPM initialization."""
+     conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
+     conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+     nn.init.zeros_(conv.bias)
+     return conv
+
+
+ def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+     """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
+     init_scale = 1e-10 if init_scale == 0 else init_scale
+     conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
+                      dilation=dilation, padding=padding, kernel_size=3)
+     conv.weight.data *= init_scale
+     conv.bias.data *= init_scale
+     return conv
+
+
+ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+     """3x3 convolution with DDPM initialization."""
+     conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
+                      dilation=dilation, bias=bias)
+     conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+     nn.init.zeros_(conv.bias)
+     return conv
+
+ ###########################################################################
+ # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
+ # https://github.com/ermongroup/ncsn
+ # https://github.com/ermongroup/ncsnv2
+ ###########################################################################
+
+
+ class CRPBlock(nn.Module):
+     def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
+         super().__init__()
+         self.convs = nn.ModuleList()
+         for i in range(n_stages):
+             self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+         self.n_stages = n_stages
+         if maxpool:
+             self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+         else:
+             self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+
+         self.act = act
+
+     def forward(self, x):
+         x = self.act(x)
+         path = x
+         for i in range(self.n_stages):
+             path = self.pool(path)
+             path = self.convs[i](path)
+             x = path + x
+         return x
+
+
+ class CondCRPBlock(nn.Module):
+     def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
+         super().__init__()
+         self.convs = nn.ModuleList()
+         self.norms = nn.ModuleList()
+         self.normalizer = normalizer
+         for i in range(n_stages):
+             self.norms.append(normalizer(features, num_classes, bias=True))
+             self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+
+         self.n_stages = n_stages
+         self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+         self.act = act
+
+     def forward(self, x, y):
+         x = self.act(x)
+         path = x
+         for i in range(self.n_stages):
+             path = self.norms[i](path, y)
+             path = self.pool(path)
+             path = self.convs[i](path)
+
+             x = path + x
+         return x
+
+
+ class RCUBlock(nn.Module):
+     def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
+         super().__init__()
+
+         for i in range(n_blocks):
+             for j in range(n_stages):
+                 setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+         self.stride = 1
+         self.n_blocks = n_blocks
+         self.n_stages = n_stages
+         self.act = act
+
+     def forward(self, x):
+         for i in range(self.n_blocks):
+             residual = x
+             for j in range(self.n_stages):
+                 x = self.act(x)
+                 x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+
+             x += residual
+         return x
+
+
+ class CondRCUBlock(nn.Module):
+     def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
+         super().__init__()
+
+         for i in range(n_blocks):
+             for j in range(n_stages):
+                 setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
+                 setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+         self.stride = 1
+         self.n_blocks = n_blocks
+         self.n_stages = n_stages
+         self.act = act
+         self.normalizer = normalizer
+
+     def forward(self, x, y):
+         for i in range(self.n_blocks):
+             residual = x
+             for j in range(self.n_stages):
+                 x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
+                 x = self.act(x)
+                 x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+
+             x += residual
+         return x
+
+
+ class MSFBlock(nn.Module):
+     def __init__(self, in_planes, features):
+         super().__init__()
+         assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+         self.convs = nn.ModuleList()
+         self.features = features
+
+         for i in range(len(in_planes)):
+             self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+
+     def forward(self, xs, shape):
+         sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+         for i in range(len(self.convs)):
+             h = self.convs[i](xs[i])
+             h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+             sums += h
+         return sums
+
+
+ class CondMSFBlock(nn.Module):
+     def __init__(self, in_planes, features, num_classes, normalizer):
+         super().__init__()
+         assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+
+         self.convs = nn.ModuleList()
+         self.norms = nn.ModuleList()
+         self.features = features
+         self.normalizer = normalizer
+
+         for i in range(len(in_planes)):
+             self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+             self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
+
+     def forward(self, xs, y, shape):
+         sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+         for i in range(len(self.convs)):
+             h = self.norms[i](xs[i], y)
+             h = self.convs[i](h)
+             h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+             sums += h
+         return sums
+
+
+ class RefineBlock(nn.Module):
+     def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
+         super().__init__()
+
+         assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+         self.n_blocks = n_blocks = len(in_planes)
+
+         self.adapt_convs = nn.ModuleList()
+         for i in range(n_blocks):
+             self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
+
+         self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
+
+         if not start:
+             self.msf = MSFBlock(in_planes, features)
+
+         self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
+
+     def forward(self, xs, output_shape):
+         assert isinstance(xs, tuple) or isinstance(xs, list)
+         hs = []
+         for i in range(len(xs)):
+             h = self.adapt_convs[i](xs[i])
+             hs.append(h)
+
+         if self.n_blocks > 1:
+             h = self.msf(hs, output_shape)
+
157
+ class CondCRPBlock(nn.Module):
158
+ def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
159
+ super().__init__()
160
+ self.convs = nn.ModuleList()
161
+ self.norms = nn.ModuleList()
162
+ self.normalizer = normalizer
163
+ for i in range(n_stages):
164
+ self.norms.append(normalizer(features, num_classes, bias=True))
165
+ self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
166
+
167
+ self.n_stages = n_stages
168
+ self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
169
+ self.act = act
170
+
171
+ def forward(self, x, y):
172
+ x = self.act(x)
173
+ path = x
174
+ for i in range(self.n_stages):
175
+ path = self.norms[i](path, y)
176
+ path = self.pool(path)
177
+ path = self.convs[i](path)
178
+
179
+ x = path + x
180
+ return x
181
+
182
+
183
+ class RCUBlock(nn.Module):
184
+ def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
185
+ super().__init__()
186
+
187
+ for i in range(n_blocks):
188
+ for j in range(n_stages):
189
+ setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
190
+
191
+ self.stride = 1
192
+ self.n_blocks = n_blocks
193
+ self.n_stages = n_stages
194
+ self.act = act
195
+
196
+ def forward(self, x):
197
+ for i in range(self.n_blocks):
198
+ residual = x
199
+ for j in range(self.n_stages):
200
+ x = self.act(x)
201
+ x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
202
+
203
+ x += residual
204
+ return x
205
+
206
+
207
+ class CondRCUBlock(nn.Module):
208
+ def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
209
+ super().__init__()
210
+
211
+ for i in range(n_blocks):
212
+ for j in range(n_stages):
213
+ setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
214
+ setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
215
+
216
+ self.stride = 1
217
+ self.n_blocks = n_blocks
218
+ self.n_stages = n_stages
219
+ self.act = act
220
+ self.normalizer = normalizer
221
+
222
+ def forward(self, x, y):
223
+ for i in range(self.n_blocks):
224
+ residual = x
225
+ for j in range(self.n_stages):
226
+ x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
227
+ x = self.act(x)
228
+ x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
229
+
230
+ x += residual
231
+ return x
232
+
233
+
234
+ class MSFBlock(nn.Module):
235
+ def __init__(self, in_planes, features):
236
+ super().__init__()
237
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
238
+ self.convs = nn.ModuleList()
239
+ self.features = features
240
+
241
+ for i in range(len(in_planes)):
242
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
243
+
244
+ def forward(self, xs, shape):
245
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
246
+ for i in range(len(self.convs)):
247
+ h = self.convs[i](xs[i])
248
+ h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
249
+ sums += h
250
+ return sums
251
+
252
+
253
+ class CondMSFBlock(nn.Module):
254
+ def __init__(self, in_planes, features, num_classes, normalizer):
255
+ super().__init__()
256
+ assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
257
+
258
+ self.convs = nn.ModuleList()
259
+ self.norms = nn.ModuleList()
260
+ self.features = features
261
+ self.normalizer = normalizer
262
+
263
+ for i in range(len(in_planes)):
264
+ self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
265
+ self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
266
+
267
+ def forward(self, xs, y, shape):
268
+ sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
269
+ for i in range(len(self.convs)):
270
+ h = self.norms[i](xs[i], y)
271
+ h = self.convs[i](h)
272
+ h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
273
+ sums += h
274
+ return sums
275
+
276
+
277
+ class RefineBlock(nn.Module):
278
+ def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
279
+ super().__init__()
280
+
281
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
282
+ self.n_blocks = n_blocks = len(in_planes)
283
+
284
+ self.adapt_convs = nn.ModuleList()
285
+ for i in range(n_blocks):
286
+ self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
287
+
288
+ self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
289
+
290
+ if not start:
291
+ self.msf = MSFBlock(in_planes, features)
292
+
293
+ self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
294
+
295
+ def forward(self, xs, output_shape):
296
+ assert isinstance(xs, tuple) or isinstance(xs, list)
297
+ hs = []
298
+ for i in range(len(xs)):
299
+ h = self.adapt_convs[i](xs[i])
300
+ hs.append(h)
301
+
302
+ if self.n_blocks > 1:
303
+ h = self.msf(hs, output_shape)
304
+ else:
305
+ h = hs[0]
306
+
307
+ h = self.crp(h)
308
+ h = self.output_convs(h)
309
+
310
+ return h
311
+
312
+
313
+ class CondRefineBlock(nn.Module):
314
+ def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
315
+ super().__init__()
316
+
317
+ assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
318
+ self.n_blocks = n_blocks = len(in_planes)
319
+
320
+ self.adapt_convs = nn.ModuleList()
321
+ for i in range(n_blocks):
322
+ self.adapt_convs.append(
323
+ CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
324
+ )
325
+
326
+ self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
327
+
328
+ if not start:
329
+ self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
330
+
331
+ self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
332
+
333
+ def forward(self, xs, y, output_shape):
334
+ assert isinstance(xs, tuple) or isinstance(xs, list)
335
+ hs = []
336
+ for i in range(len(xs)):
337
+ h = self.adapt_convs[i](xs[i], y)
338
+ hs.append(h)
339
+
340
+ if self.n_blocks > 1:
341
+ h = self.msf(hs, y, output_shape)
342
+ else:
343
+ h = hs[0]
344
+
345
+ h = self.crp(h, y)
346
+ h = self.output_convs(h, y)
347
+
348
+ return h
349
+
350
+
351
+ class ConvMeanPool(nn.Module):
352
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
353
+ super().__init__()
354
+ if not adjust_padding:
355
+ conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
356
+ self.conv = conv
357
+ else:
358
+ conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
359
+
360
+ self.conv = nn.Sequential(
361
+ nn.ZeroPad2d((1, 0, 1, 0)),
362
+ conv
363
+ )
364
+
365
+ def forward(self, inputs):
366
+ output = self.conv(inputs)
367
+ output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
368
+ output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
369
+ return output
370
+
371
+
372
+ class MeanPoolConv(nn.Module):
373
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
374
+ super().__init__()
375
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
376
+
377
+ def forward(self, inputs):
378
+ output = inputs
379
+ output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
380
+ output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
381
+ return self.conv(output)
382
+
383
+
384
+ class UpsampleConv(nn.Module):
385
+ def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
386
+ super().__init__()
387
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
388
+ self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
389
+
390
+ def forward(self, inputs):
391
+ output = inputs
392
+ output = torch.cat([output, output, output, output], dim=1)
393
+ output = self.pixelshuffle(output)
394
+ return self.conv(output)
395
+
396
+
397
+ class ConditionalResidualBlock(nn.Module):
398
+ def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
399
+ normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
400
+ super().__init__()
401
+ self.non_linearity = act
402
+ self.input_dim = input_dim
403
+ self.output_dim = output_dim
404
+ self.resample = resample
405
+ self.normalization = normalization
406
+ if resample == 'down':
407
+ if dilation > 1:
408
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
409
+ self.normalize2 = normalization(input_dim, num_classes)
410
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
411
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
412
+ else:
413
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
414
+ self.normalize2 = normalization(input_dim, num_classes)
415
+ self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
416
+ conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
417
+
418
+ elif resample is None:
419
+ if dilation > 1:
420
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
421
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
422
+ self.normalize2 = normalization(output_dim, num_classes)
423
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
424
+ else:
425
+ conv_shortcut = nn.Conv2d
426
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
427
+ self.normalize2 = normalization(output_dim, num_classes)
428
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
429
+ else:
430
+ raise Exception('invalid resample value')
431
+
432
+ if output_dim != input_dim or resample is not None:
433
+ self.shortcut = conv_shortcut(input_dim, output_dim)
434
+
435
+ self.normalize1 = normalization(input_dim, num_classes)
436
+
437
+ def forward(self, x, y):
438
+ output = self.normalize1(x, y)
439
+ output = self.non_linearity(output)
440
+ output = self.conv1(output)
441
+ output = self.normalize2(output, y)
442
+ output = self.non_linearity(output)
443
+ output = self.conv2(output)
444
+
445
+ if self.output_dim == self.input_dim and self.resample is None:
446
+ shortcut = x
447
+ else:
448
+ shortcut = self.shortcut(x)
449
+
450
+ return shortcut + output
451
+
452
+
453
+ class ResidualBlock(nn.Module):
454
+ def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
455
+ normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
456
+ super().__init__()
457
+ self.non_linearity = act
458
+ self.input_dim = input_dim
459
+ self.output_dim = output_dim
460
+ self.resample = resample
461
+ self.normalization = normalization
462
+ if resample == 'down':
463
+ if dilation > 1:
464
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
465
+ self.normalize2 = normalization(input_dim)
466
+ self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
467
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
468
+ else:
469
+ self.conv1 = ncsn_conv3x3(input_dim, input_dim)
470
+ self.normalize2 = normalization(input_dim)
471
+ self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
472
+ conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
473
+
474
+ elif resample is None:
475
+ if dilation > 1:
476
+ conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
477
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
478
+ self.normalize2 = normalization(output_dim)
479
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
480
+ else:
481
+ # conv_shortcut = nn.Conv2d ### Something wierd here.
482
+ conv_shortcut = partial(ncsn_conv1x1)
483
+ self.conv1 = ncsn_conv3x3(input_dim, output_dim)
484
+ self.normalize2 = normalization(output_dim)
485
+ self.conv2 = ncsn_conv3x3(output_dim, output_dim)
486
+ else:
487
+ raise Exception('invalid resample value')
488
+
489
+ if output_dim != input_dim or resample is not None:
490
+ self.shortcut = conv_shortcut(input_dim, output_dim)
491
+
492
+ self.normalize1 = normalization(input_dim)
493
+
494
+ def forward(self, x):
495
+ output = self.normalize1(x)
496
+ output = self.non_linearity(output)
497
+ output = self.conv1(output)
498
+ output = self.normalize2(output)
499
+ output = self.non_linearity(output)
500
+ output = self.conv2(output)
501
+
502
+ if self.output_dim == self.input_dim and self.resample is None:
503
+ shortcut = x
504
+ else:
505
+ shortcut = self.shortcut(x)
506
+
507
+ return shortcut + output
508
+
509
+
510
+ ###########################################################################
511
+ # Functions below are ported over from the DDPM codebase:
512
+ # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
513
+ ###########################################################################
514
+
515
+ def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
516
+ assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
517
+ half_dim = embedding_dim // 2
518
+ # magic number 10000 is from transformers
519
+ emb = math.log(max_positions) / (half_dim - 1)
520
+ # emb = math.log(2.) / (half_dim - 1)
521
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
522
+ # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
523
+ # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
524
+ emb = timesteps.float()[:, None] * emb[None, :]
525
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
526
+ if embedding_dim % 2 == 1: # zero pad
527
+ emb = F.pad(emb, (0, 1), mode='constant')
528
+ assert emb.shape == (timesteps.shape[0], embedding_dim)
529
+ return emb
530
+
531
+
532
+ def _einsum(a, b, c, x, y):
533
+ einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
534
+ return torch.einsum(einsum_str, x, y)
535
+
536
+
537
+ def contract_inner(x, y):
538
+ """tensordot(x, y, 1)."""
539
+ x_chars = list(string.ascii_lowercase[:len(x.shape)])
540
+ y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
541
+ y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
542
+ out_chars = x_chars[:-1] + y_chars[1:]
543
+ return _einsum(x_chars, y_chars, out_chars, x, y)
544
+
545
+
546
+ class NIN(nn.Module):
547
+ def __init__(self, in_dim, num_units, init_scale=0.1):
548
+ super().__init__()
549
+ self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
550
+ self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
551
+
552
+ def forward(self, x):
553
+ x = x.permute(0, 2, 3, 1)
554
+ y = contract_inner(x, self.W) + self.b
555
+ return y.permute(0, 3, 1, 2)
556
+
557
+
558
+ class AttnBlock(nn.Module):
559
+ """Channel-wise self-attention block."""
560
+ def __init__(self, channels):
561
+ super().__init__()
562
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
563
+ self.NIN_0 = NIN(channels, channels)
564
+ self.NIN_1 = NIN(channels, channels)
565
+ self.NIN_2 = NIN(channels, channels)
566
+ self.NIN_3 = NIN(channels, channels, init_scale=0.)
567
+
568
+ def forward(self, x):
569
+ B, C, H, W = x.shape
570
+ h = self.GroupNorm_0(x)
571
+ q = self.NIN_0(h)
572
+ k = self.NIN_1(h)
573
+ v = self.NIN_2(h)
574
+
575
+ w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
576
+ w = torch.reshape(w, (B, H, W, H * W))
577
+ w = F.softmax(w, dim=-1)
578
+ w = torch.reshape(w, (B, H, W, H, W))
579
+ h = torch.einsum('bhwij,bcij->bchw', w, v)
580
+ h = self.NIN_3(h)
581
+ return x + h
582
+
583
+
584
+ class Upsample(nn.Module):
585
+ def __init__(self, channels, with_conv=False):
586
+ super().__init__()
587
+ if with_conv:
588
+ self.Conv_0 = ddpm_conv3x3(channels, channels)
589
+ self.with_conv = with_conv
590
+
591
+ def forward(self, x):
592
+ B, C, H, W = x.shape
593
+ h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
594
+ if self.with_conv:
595
+ h = self.Conv_0(h)
596
+ return h
597
+
598
+
599
+ class Downsample(nn.Module):
600
+ def __init__(self, channels, with_conv=False):
601
+ super().__init__()
602
+ if with_conv:
603
+ self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
604
+ self.with_conv = with_conv
605
+
606
+ def forward(self, x):
607
+ B, C, H, W = x.shape
608
+ # Emulate 'SAME' padding
609
+ if self.with_conv:
610
+ x = F.pad(x, (0, 1, 0, 1))
611
+ x = self.Conv_0(x)
612
+ else:
613
+ x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
614
+
615
+ assert x.shape == (B, C, H // 2, W // 2)
616
+ return x
617
+
618
+
619
+ class ResnetBlockDDPM(nn.Module):
620
+ """The ResNet Blocks used in DDPM."""
621
+ def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
622
+ super().__init__()
623
+ if out_ch is None:
624
+ out_ch = in_ch
625
+ self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
626
+ self.act = act
627
+ self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
628
+ if temb_dim is not None:
629
+ self.Dense_0 = nn.Linear(temb_dim, out_ch)
630
+ self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
631
+ nn.init.zeros_(self.Dense_0.bias)
632
+
633
+ self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
634
+ self.Dropout_0 = nn.Dropout(dropout)
635
+ self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
636
+ if in_ch != out_ch:
637
+ if conv_shortcut:
638
+ self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
639
+ else:
640
+ self.NIN_0 = NIN(in_ch, out_ch)
641
+ self.out_ch = out_ch
642
+ self.in_ch = in_ch
643
+ self.conv_shortcut = conv_shortcut
644
+
645
+ def forward(self, x, temb=None):
646
+ B, C, H, W = x.shape
647
+ assert C == self.in_ch
648
+ out_ch = self.out_ch if self.out_ch else self.in_ch
649
+ h = self.act(self.GroupNorm_0(x))
650
+ h = self.Conv_0(h)
651
+ # Add bias to each feature map conditioned on the time embedding
652
+ if temb is not None:
653
+ h += self.Dense_0(self.act(temb))[:, :, None, None]
654
+ h = self.act(self.GroupNorm_1(h))
655
+ h = self.Dropout_0(h)
656
+ h = self.Conv_1(h)
657
+ if C != out_ch:
658
+ if self.conv_shortcut:
659
+ x = self.Conv_2(x)
660
+ else:
661
+ x = self.NIN_0(x)
662
+ return x + h
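As a quick sanity check of the DDPM-style sinusoidal embedding above, a minimal sketch (hypothetical timestep values, assuming the module's functions are in scope; not part of this commit):

import torch

timesteps = torch.tensor([0, 10, 999])                      # shape (B,)
emb = get_timestep_embedding(timesteps, embedding_dim=128)  # sin/cos halves concatenated
assert emb.shape == (3, 128)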
geco/backbones/ncsnpp_utils/layerspp.py ADDED
@@ -0,0 +1,202 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ """Layers for defining NCSN++."""
+ from . import layers
+ import score_models.layers.up_or_downsampling2d as up_or_down_sampling
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+
+ conv1x1 = layers.ddpm_conv1x1
+ conv3x3 = layers.ddpm_conv3x3
+ NIN = layers.NIN
+ default_init = layers.default_init
+
+
+ class GaussianFourierProjection(nn.Module):
+     """Gaussian Fourier embeddings for noise levels."""
+
+     def __init__(self, embedding_size=256, scale=1.0):
+         super().__init__()
+         self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+     def forward(self, x):
+         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+
+
+ class Combine(nn.Module):
+     """Combine information from skip connections."""
+
+     def __init__(self, dim1, dim2, method='cat'):
+         super().__init__()
+         self.Conv_0 = conv1x1(dim1, dim2)
+         self.method = method
+
+     def forward(self, x, y):
+         h = self.Conv_0(x)
+         if self.method == 'cat':
+             return torch.cat([h, y], dim=1)
+         elif self.method == 'sum':
+             return h + y
+         else:
+             raise ValueError(f'Method {self.method} not recognized.')
+
+
+ class AttnBlockpp(nn.Module):
+     """Channel-wise self-attention block. Modified from DDPM."""
+
+     def __init__(self, channels, skip_rescale=False, init_scale=0.):
+         super().__init__()
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
+                                         eps=1e-6)
+         self.NIN_0 = NIN(channels, channels)
+         self.NIN_1 = NIN(channels, channels)
+         self.NIN_2 = NIN(channels, channels)
+         self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+         self.skip_rescale = skip_rescale
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         h = self.GroupNorm_0(x)
+         q = self.NIN_0(h)
+         k = self.NIN_1(h)
+         v = self.NIN_2(h)
+
+         w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+         w = torch.reshape(w, (B, H, W, H * W))
+         w = F.softmax(w, dim=-1)
+         w = torch.reshape(w, (B, H, W, H, W))
+         h = torch.einsum('bhwij,bcij->bchw', w, v)
+         h = self.NIN_3(h)
+         if not self.skip_rescale:
+             return x + h
+         else:
+             return (x + h) / np.sqrt(2.)
+
+
+ class ResnetBlockDDPMpp(nn.Module):
+     """ResBlock adapted from DDPM."""
+
+     def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
+                  dropout=0.1, skip_rescale=False, init_scale=0.):
+         super().__init__()
+         out_ch = out_ch if out_ch else in_ch
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+         self.Conv_0 = conv3x3(in_ch, out_ch)
+         if temb_dim is not None:
+             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+             nn.init.zeros_(self.Dense_0.bias)
+         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+         self.Dropout_0 = nn.Dropout(dropout)
+         self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+         if in_ch != out_ch:
+             if conv_shortcut:
+                 self.Conv_2 = conv3x3(in_ch, out_ch)
+             else:
+                 self.NIN_0 = NIN(in_ch, out_ch)
+
+         self.skip_rescale = skip_rescale
+         self.act = act
+         self.out_ch = out_ch
+         self.conv_shortcut = conv_shortcut
+
+     def forward(self, x, temb=None):
+         h = self.act(self.GroupNorm_0(x))
+         h = self.Conv_0(h)
+         if temb is not None:
+             h += self.Dense_0(self.act(temb))[:, :, None, None]
+         h = self.act(self.GroupNorm_1(h))
+         h = self.Dropout_0(h)
+         h = self.Conv_1(h)
+         if x.shape[1] != self.out_ch:
+             if self.conv_shortcut:
+                 x = self.Conv_2(x)
+             else:
+                 x = self.NIN_0(x)
+         if not self.skip_rescale:
+             return x + h
+         else:
+             return (x + h) / np.sqrt(2.)
+
+
+ class ResnetBlockBigGANpp(nn.Module):
+     def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
+                  dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
+                  skip_rescale=True, init_scale=0.):
+         super().__init__()
+
+         out_ch = out_ch if out_ch else in_ch
+         self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+         self.up = up
+         self.down = down
+         self.fir = fir
+         self.fir_kernel = fir_kernel
+
+         self.Conv_0 = conv3x3(in_ch, out_ch)
+         if temb_dim is not None:
+             self.Dense_0 = nn.Linear(temb_dim, out_ch)
+             self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+             nn.init.zeros_(self.Dense_0.bias)
+
+         self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+         self.Dropout_0 = nn.Dropout(dropout)
+         self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+         if in_ch != out_ch or up or down:
+             self.Conv_2 = conv1x1(in_ch, out_ch)
+
+         self.skip_rescale = skip_rescale
+         self.act = act
+         self.in_ch = in_ch
+         self.out_ch = out_ch
+
+     def forward(self, x, temb=None):
+         h = self.act(self.GroupNorm_0(x))
+
+         if self.up:
+             if self.fir:
+                 h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
+                 x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+             else:
+                 h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
+                 x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
+         elif self.down:
+             if self.fir:
+                 h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
+                 x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+             else:
+                 h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
+                 x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
+
+         h = self.Conv_0(h)
+         # Add bias to each feature map conditioned on the time embedding
+         if temb is not None:
+             h += self.Dense_0(self.act(temb))[:, :, None, None]
+         h = self.act(self.GroupNorm_1(h))
+         h = self.Dropout_0(h)
+         h = self.Conv_1(h)
+
+         if self.in_ch != self.out_ch or self.up or self.down:
+             x = self.Conv_2(x)
+
+         if not self.skip_rescale:
+             return x + h
+         else:
+             return (x + h) / np.sqrt(2.)
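The Fourier projection above concatenates sin and cos features, so its output width is twice embedding_size; a minimal usage sketch with hypothetical shapes (not part of this commit):

import torch

sigmas = torch.rand(4)  # hypothetical noise levels for a batch of 4
proj = GaussianFourierProjection(embedding_size=128, scale=16.0)
assert proj(sigmas).shape == (4, 256)  # 2 * embedding_size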
geco/backbones/ncsnpp_utils/normalization.py ADDED
@@ -0,0 +1,215 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Normalization layers."""
+ import torch.nn as nn
+ import torch
+ import functools
+
+
+ def get_normalization(config, conditional=False):
+     """Obtain normalization modules from the config file."""
+     norm = config.model.normalization
+     if conditional:
+         if norm == 'InstanceNorm++':
+             return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
+         else:
+             raise NotImplementedError(f'{norm} not implemented yet.')
+     else:
+         if norm == 'InstanceNorm':
+             return nn.InstanceNorm2d
+         elif norm == 'InstanceNorm++':
+             return InstanceNorm2dPlus
+         elif norm == 'VarianceNorm':
+             return VarianceNorm2d
+         elif norm == 'GroupNorm':
+             return nn.GroupNorm
+         else:
+             raise ValueError('Unknown normalization: %s' % norm)
+
+
+ class ConditionalBatchNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.bn = nn.BatchNorm2d(num_features, affine=False)
+         if self.bias:
+             self.embed = nn.Embedding(num_classes, num_features * 2)
+             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
+             self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, num_features)
+             self.embed.weight.data.uniform_()
+
+     def forward(self, x, y):
+         out = self.bn(x)
+         if self.bias:
+             gamma, beta = self.embed(y).chunk(2, dim=1)
+             out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma = self.embed(y)
+             out = gamma.view(-1, self.num_features, 1, 1) * out
+         return out
+
+
+ class ConditionalInstanceNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+         if bias:
+             self.embed = nn.Embedding(num_classes, num_features * 2)
+             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
+             self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, num_features)
+             self.embed.weight.data.uniform_()
+
+     def forward(self, x, y):
+         h = self.instance_norm(x)
+         if self.bias:
+             gamma, beta = self.embed(y).chunk(2, dim=-1)
+             out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma = self.embed(y)
+             out = gamma.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class ConditionalVarianceNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=False):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.embed = nn.Embedding(num_classes, num_features)
+         self.embed.weight.data.normal_(1, 0.02)
+
+     def forward(self, x, y):
+         vars = torch.var(x, dim=(2, 3), keepdim=True)
+         h = x / torch.sqrt(vars + 1e-5)
+
+         gamma = self.embed(y)
+         out = gamma.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class VarianceNorm2d(nn.Module):
+     def __init__(self, num_features, bias=False):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.alpha = nn.Parameter(torch.zeros(num_features))
+         self.alpha.data.normal_(1, 0.02)
+
+     def forward(self, x):
+         vars = torch.var(x, dim=(2, 3), keepdim=True)
+         h = x / torch.sqrt(vars + 1e-5)
+
+         out = self.alpha.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class ConditionalNoneNorm2d(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         if bias:
+             self.embed = nn.Embedding(num_classes, num_features * 2)
+             self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
+             self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, num_features)
+             self.embed.weight.data.uniform_()
+
+     def forward(self, x, y):
+         if self.bias:
+             gamma, beta = self.embed(y).chunk(2, dim=-1)
+             out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma = self.embed(y)
+             out = gamma.view(-1, self.num_features, 1, 1) * x
+         return out
+
+
+ class NoneNorm2d(nn.Module):
+     def __init__(self, num_features, bias=True):
+         super().__init__()
+
+     def forward(self, x):
+         return x
+
+
+ class InstanceNorm2dPlus(nn.Module):
+     def __init__(self, num_features, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+         self.alpha = nn.Parameter(torch.zeros(num_features))
+         self.gamma = nn.Parameter(torch.zeros(num_features))
+         self.alpha.data.normal_(1, 0.02)
+         self.gamma.data.normal_(1, 0.02)
+         if bias:
+             self.beta = nn.Parameter(torch.zeros(num_features))
+
+     def forward(self, x):
+         means = torch.mean(x, dim=(2, 3))
+         m = torch.mean(means, dim=-1, keepdim=True)
+         v = torch.var(means, dim=-1, keepdim=True)
+         means = (means - m) / (torch.sqrt(v + 1e-5))
+         h = self.instance_norm(x)
+
+         if self.bias:
+             h = h + means[..., None, None] * self.alpha[..., None, None]
+             out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
+         else:
+             h = h + means[..., None, None] * self.alpha[..., None, None]
+             out = self.gamma.view(-1, self.num_features, 1, 1) * h
+         return out
+
+
+ class ConditionalInstanceNorm2dPlus(nn.Module):
+     def __init__(self, num_features, num_classes, bias=True):
+         super().__init__()
+         self.num_features = num_features
+         self.bias = bias
+         self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+         if bias:
+             self.embed = nn.Embedding(num_classes, num_features * 3)
+             self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scales at N(1, 0.02)
+             self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
+         else:
+             self.embed = nn.Embedding(num_classes, 2 * num_features)
+             self.embed.weight.data.normal_(1, 0.02)
+
+     def forward(self, x, y):
+         means = torch.mean(x, dim=(2, 3))
+         m = torch.mean(means, dim=-1, keepdim=True)
+         v = torch.var(means, dim=-1, keepdim=True)
+         means = (means - m) / (torch.sqrt(v + 1e-5))
+         h = self.instance_norm(x)
+
+         if self.bias:
+             gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
+             h = h + means[..., None, None] * alpha[..., None, None]
+             out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+         else:
+             gamma, alpha = self.embed(y).chunk(2, dim=-1)
+             h = h + means[..., None, None] * alpha[..., None, None]
+             out = gamma.view(-1, self.num_features, 1, 1) * h
+         return out
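InstanceNorm2dPlus augments plain instance normalization with a path that re-injects the (normalized) per-channel means, so mean information is not discarded; a shape-preserving sketch with hypothetical sizes (not part of this commit):

import torch

norm = InstanceNorm2dPlus(num_features=32)
x = torch.randn(8, 32, 16, 16)  # hypothetical feature maps
assert norm(x).shape == x.shape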
geco/backbones/ncsnpp_utils/utils.py ADDED
@@ -0,0 +1,189 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """All functions and modules related to model definition."""
+
+ import torch
+ import numpy as np
+
+ from ...sdes import OUVESDE, OUVPSDE
+
+
+ _MODELS = {}
+
+
+ def register_model(cls=None, *, name=None):
+     """A decorator for registering model classes."""
+
+     def _register(cls):
+         if name is None:
+             local_name = cls.__name__
+         else:
+             local_name = name
+         if local_name in _MODELS:
+             raise ValueError(f'Already registered model with name: {local_name}')
+         _MODELS[local_name] = cls
+         return cls
+
+     if cls is None:
+         return _register
+     else:
+         return _register(cls)
+
+
+ def get_model(name):
+     return _MODELS[name]
+
+
+ def get_sigmas(sigma_min, sigma_max, num_scales):
+     """Get sigmas --- the set of noise levels for SMLD.
+
+     Args:
+         sigma_min: The smallest noise level.
+         sigma_max: The largest noise level.
+         num_scales: The number of noise levels.
+     Returns:
+         sigmas: a numpy array of noise levels, in decreasing order
+     """
+     sigmas = np.exp(
+         np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
+
+     return sigmas
+
+
+ def get_ddpm_params(config):
+     """Get betas and alphas --- parameters used in the original DDPM paper."""
+     num_diffusion_timesteps = 1000
+     # parameters need to be adapted if number of time steps differs from 1000
+     beta_start = config.model.beta_min / config.model.num_scales
+     beta_end = config.model.beta_max / config.model.num_scales
+     betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+
+     alphas = 1. - betas
+     alphas_cumprod = np.cumprod(alphas, axis=0)
+     sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+     sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
+
+     return {
+         'betas': betas,
+         'alphas': alphas,
+         'alphas_cumprod': alphas_cumprod,
+         'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
+         'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
+         'beta_min': beta_start * (num_diffusion_timesteps - 1),
+         'beta_max': beta_end * (num_diffusion_timesteps - 1),
+         'num_diffusion_timesteps': num_diffusion_timesteps
+     }
+
+
+ def create_model(config):
+     """Create the score model."""
+     model_name = config.model.name
+     score_model = get_model(model_name)(config)
+     score_model = score_model.to(config.device)
+     score_model = torch.nn.DataParallel(score_model)
+     return score_model
+
+
+ def get_model_fn(model, train=False):
+     """Create a function to give the output of the score-based model.
+
+     Args:
+         model: The score model.
+         train: `True` for training and `False` for evaluation.
+
+     Returns:
+         A model function.
+     """
+
+     def model_fn(x, labels):
+         """Compute the output of the score-based model.
+
+         Args:
+             x: A mini-batch of input data.
+             labels: A mini-batch of conditioning variables for time steps. Should be interpreted
+                 differently for different models.
+
+         Returns:
+             The model output.
+         """
+         if not train:
+             model.eval()
+             return model(x, labels)
+         else:
+             model.train()
+             return model(x, labels)
+
+     return model_fn
+
+
+ def get_score_fn(sde, model, train=False, continuous=False):
+     """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
+
+     Args:
+         sde: An `sde_lib.SDE` object that represents the forward SDE.
+         model: A score model.
+         train: `True` for training and `False` for evaluation.
+         continuous: If `True`, the score-based model is expected to directly take continuous time steps.
+
+     Returns:
+         A score function.
+     """
+     model_fn = get_model_fn(model, train=train)
+
+     if isinstance(sde, OUVPSDE):
+         def score_fn(x, t):
+             # Scale neural network output by standard deviation and flip sign
+             if continuous:
+                 # For VP-trained models, t=0 corresponds to the lowest noise level.
+                 # The maximum value of the time embedding is assumed to be 999 for
+                 # continuously-trained models.
+                 labels = t * 999
+                 score = model_fn(x, labels)
+                 std = sde.marginal_prob(torch.zeros_like(x), t)[1]
+             else:
+                 # For VP-trained models, t=0 corresponds to the lowest noise level
+                 labels = t * (sde.N - 1)
+                 score = model_fn(x, labels)
+                 std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
+
+             score = -score / std[:, None, None, None]
+             return score
+
+     elif isinstance(sde, OUVESDE):
+         def score_fn(x, t):
+             if continuous:
+                 labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
+             else:
+                 # For VE-trained models, t=0 corresponds to the highest noise level
+                 labels = sde.T - t
+                 labels *= sde.N - 1
+                 labels = torch.round(labels).long()
+
+             score = model_fn(x, labels)
+             return score
+
+     else:
+         raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
+
+     return score_fn
+
+
+ def to_flattened_numpy(x):
+     """Flatten a torch tensor `x` and convert it to numpy."""
+     return x.detach().cpu().numpy().reshape((-1,))
+
+
+ def from_flattened_numpy(x, shape):
+     """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+     return torch.from_numpy(x.reshape(shape))
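The registry above keys model classes by name, and get_sigmas builds a geometric noise schedule; a minimal sketch of both (the 'toy' name is hypothetical, not part of this commit):

import torch

@register_model(name='toy')
class ToyModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()

    def forward(self, x, labels):
        return x

assert get_model('toy') is ToyModel

sigmas = get_sigmas(sigma_min=0.01, sigma_max=50.0, num_scales=10)
assert sigmas[0] > sigmas[-1]  # descending from sigma_max to sigma_min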
geco/backbones/shared.py ADDED
@@ -0,0 +1,123 @@
+ import functools
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ from geco.util.registry import Registry
+
+
+ BackboneRegistry = Registry("Backbone")
+
+
+ class GaussianFourierProjection(nn.Module):
+     """Gaussian random features for encoding time steps."""
+
+     def __init__(self, embed_dim, scale=16, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
+
+     def forward(self, t):
+         t_proj = t[:, None] * self.W[None, :] * 2 * np.pi
+         if self.complex_valued:
+             return torch.exp(1j * t_proj)
+         else:
+             return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
+
+
+ class DiffusionStepEmbedding(nn.Module):
+     """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""
+
+     def __init__(self, embed_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         self.embed_dim = embed_dim
+
+     def forward(self, t):
+         fac = 10**(4 * torch.arange(self.embed_dim, device=t.device) / (self.embed_dim - 1))
+         inner = t[:, None] * fac[None, :]
+         if self.complex_valued:
+             return torch.exp(1j * inner)
+         else:
+             return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
+
+
+ class ComplexLinear(nn.Module):
+     """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
+     def __init__(self, input_dim, output_dim, complex_valued):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if self.complex_valued:
+             self.re = nn.Linear(input_dim, output_dim)
+             self.im = nn.Linear(input_dim, output_dim)
+         else:
+             self.lin = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         if self.complex_valued:
+             return (self.re(x.real) - self.im(x.imag)) + 1j * (self.re(x.imag) + self.im(x.real))
+         else:
+             return self.lin(x)
+
+
+ class FeatureMapDense(nn.Module):
+     """A fully connected layer that reshapes outputs to feature maps."""
+
+     def __init__(self, input_dim, output_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
+
+     def forward(self, x):
+         return self.dense(x)[..., None, None]
+
+
+ def torch_complex_from_reim(re, im):
+     return torch.view_as_complex(torch.stack([re, im], dim=-1))
+
+
+ class ArgsComplexMultiplicationWrapper(nn.Module):
+     """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
+
+     Make a complex-valued module `F` from a real-valued module `f` by applying
+     complex multiplication rules:
+
+     F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
+
+     where `f1`, `f2` are instances of `f` that do *not* share weights.
+
+     Args:
+         module_cls (callable): A class or function that returns a Torch module/functional.
+             Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
+             to construct the real and imaginary component modules.
+     """
+
+     def __init__(self, module_cls, *args, **kwargs):
+         super().__init__()
+         self.re_module = module_cls(*args, **kwargs)
+         self.im_module = module_cls(*args, **kwargs)
+
+     def forward(self, x, *args, **kwargs):
+         return torch_complex_from_reim(
+             self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
+             self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
+         )
+
+
+ ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
+ ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
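ComplexConv2d applies two weight-independent real convolutions and combines them by the complex multiplication rule above; a minimal sketch with hypothetical spectrogram shapes (not part of this commit):

import torch

x = torch.randn(2, 1, 128, 128, dtype=torch.cfloat)  # hypothetical complex spectrograms
conv = ComplexConv2d(1, 8, kernel_size=3, padding=1)
y = conv(x)
assert y.shape == (2, 8, 128, 128) and y.is_complex()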
geco/data_module.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ from os.path import join
3
+ import torch
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import Dataset
6
+ from torch.utils.data import DataLoader
7
+ from glob import glob
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+
12
+
13
+ def get_window(window_type, window_length):
14
+ if window_type == 'sqrthann':
15
+ return torch.sqrt(torch.hann_window(window_length, periodic=True))
16
+ elif window_type == 'hann':
17
+ return torch.hann_window(window_length, periodic=True)
18
+ else:
19
+ raise NotImplementedError(f"Window type {window_type} not implemented!")
20
+
21
+
22
+ class Specs(Dataset):
23
+ def __init__(self, data_dir, dummy, shuffle_spec, num_frames, sampling_rate=8000,
24
+ format='default', normalize="noisy", spec_transform=None,
25
+ stft_kwargs=None, **ignored_kwargs):
26
+
27
+ # Read file paths according to file naming format.
28
+ if format == "default":
29
+ noisy_files1 = sorted(glob(os.path.join(data_dir, '*_source1hatP.wav')))
30
+ clean_files1 = [item.replace('_source1hatP.wav', '_source1.wav') for item in noisy_files1]
31
+ mixture_files1 = [item.replace('_source1hatP.wav', '_mix.wav') for item in noisy_files1]
32
+ noisy_files2 = sorted(glob(os.path.join(data_dir, '*_source2hatP.wav')))
33
+ clean_files2 = [item.replace('_source2hatP.wav', '_source2.wav') for item in noisy_files2]
34
+ mixture_files2 = [item.replace('_source2hatP.wav', '_mix.wav') for item in noisy_files2]
35
+
36
+ self.mixture_files = [*mixture_files1,*mixture_files2]
37
+ self.noisy_files = [*noisy_files1,*noisy_files2]
38
+ self.clean_files = [*clean_files1,*clean_files2]
39
+ else:
40
+ # Feel free to add your own directory format
41
+ raise NotImplementedError(f"Directory format {format} unknown!")
42
+
43
+ self.dummy = dummy
44
+ self.num_frames = num_frames
45
+ self.shuffle_spec = shuffle_spec
46
+ self.normalize = normalize
47
+ self.spec_transform = spec_transform
48
+ self.sampling_rate = sampling_rate
49
+
50
+ assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
51
+ self.stft_kwargs = stft_kwargs
52
+ self.hop_length = self.stft_kwargs["hop_length"]
53
+ assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"
54
+
55
+ def __getitem__(self, i):
56
+ x, sr = torchaudio.load(self.clean_files[i])
57
+ if sr != self.sampling_rate:
58
+ x = torchaudio.transforms.Resample(sr, self.sampling_rate)(x)
59
+ y, sr = torchaudio.load(self.noisy_files[i])
60
+ if sr != self.sampling_rate:
61
+ y = torchaudio.transforms.Resample(sr, self.sampling_rate)(y)
62
+ m, sr = torchaudio.load(self.mixture_files[i])
63
+ if sr != self.sampling_rate:
64
+ m = torchaudio.transforms.Resample(sr, self.sampling_rate)(m)
65
+
66
+ min_leng = min(x.shape[-1],y.shape[-1],m.shape[-1])
67
+ x = x[...,:min_leng]
68
+ y = y[...,:min_leng]
69
+ m = m[...,:min_leng]
70
+
71
+ # formula applies for center=True
72
+ target_len = (self.num_frames - 1) * self.hop_length
73
+ current_len = x.size(-1)
74
+ pad = max(target_len - current_len, 0)
75
+ if pad == 0:
76
+ # extract random part of the audio file
77
+ if self.shuffle_spec:
78
+ start = int(np.random.uniform(0, current_len-target_len))
79
+ else:
80
+ start = int((current_len-target_len)/2)
81
+
82
+ if y[..., start:start+target_len].abs().max() < 0.05:
83
+ start = 0
84
+
85
+ x = x[..., start:start+target_len]
86
+ y = y[..., start:start+target_len]
87
+ m = m[..., start:start+target_len]
88
+ else:
89
+ # pad audio if the length T is smaller than num_frames
90
+ x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
91
+ y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
92
+ m = F.pad(m, (pad//2, pad//2+(pad%2)), mode='constant')
93
+
94
+ # normalize w.r.t to the noisy or the clean signal or not at all
95
+ # to ensure same clean signal power in x and y.
96
+ if self.normalize == "noisy":
97
+ normfac = y.abs().max()
98
+ elif self.normalize == "clean":
99
+ normfac = x.abs().max()
100
+ elif self.normalize == "not":
101
+ normfac = 1.0
102
+ x = x / normfac
103
+ y = y / normfac
104
+ m = m / normfac
105
+ X = torch.stft(x, **self.stft_kwargs)
106
+ Y = torch.stft(y, **self.stft_kwargs)
107
+ M = torch.stft(m, **self.stft_kwargs)
108
+ X, Y, M = self.spec_transform(X), self.spec_transform(Y), self.spec_transform(M)
109
+ return X, Y, M
110
+
111
+ def __len__(self):
112
+ if self.dummy:
113
+ # for debugging shrink the data set size
114
+ return int(len(self.clean_files)/200)
115
+ else:
116
+ return len(self.clean_files)
117
+
118
+
119
+ class SpecsDataModule(pl.LightningDataModule):
120
+ @staticmethod
121
+ def add_argparse_args(parser):
122
+ parser.add_argument("--train_dir", type=str, default='/export/corpora7/HW/speechbrain/recipes/LibriMix/separation/2025/save/libri2mix-train100')
123
+ parser.add_argument("--val_dir", type=str, default='/export/corpora7/HW/speechbrain/recipes/LibriMix/separation/2025/save/libri2mix-dev')
124
+ parser.add_argument("--test_dir", type=str, default='/export/corpora7/HW/speechbrain/recipes/LibriMix/separation/2025/save/libri2mix-test')
125
+ parser.add_argument("--format", type=str, default="default", help="Read file paths according to file naming format.")
126
+ parser.add_argument("--sampling_rate", type=int, default=8000, help="The sampling rate.")
127
+ parser.add_argument("--batch_size", type=int, default=16, help="The batch size. 8 by default.")
128
+ parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.") # to assure 128 freq bins
129
+ parser.add_argument("--hop_length", type=int, default=64, help="Window hop length. 128 by default.")
130
+ parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
131
+ parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
132
+ parser.add_argument("--num_workers", type=int, default=8, help="Number of workers to use for DataLoaders. 4 by default.")
133
+ parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
134
+ parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
135
+ parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
136
+ parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
137
+ parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectogram transformation for input representation.")
138
+ return parser
139
+
140
+ def __init__(
141
+ self, train_dir, val_dir, test_dir, format='default', sampling_rate=8000, batch_size=8,
142
+ n_fft=510, hop_length=64, num_frames=256, window='hann',
143
+ num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
144
+ gpu=True, normalize='noisy', transform_type="exponent", **kwargs
145
+ ):
146
+ super().__init__()
147
+ self.train_dir = train_dir
148
+ self.val_dir = val_dir
149
+ self.test_dir = test_dir
150
+ self.format = format
151
+ self.sampling_rate = sampling_rate
152
+ self.batch_size = batch_size
153
+ self.n_fft = n_fft
154
+ self.hop_length = hop_length
155
+ self.num_frames = num_frames
156
+ self.window = get_window(window, self.n_fft)
157
+ self.windows = {}
158
+ self.num_workers = num_workers
159
+ self.dummy = dummy
160
+ self.spec_factor = spec_factor
161
+ self.spec_abs_exponent = spec_abs_exponent
162
+ self.gpu = gpu
163
+ self.normalize = normalize
164
+ self.transform_type = transform_type
165
+         self.kwargs = kwargs
+
+     def setup(self, stage=None):
+         specs_kwargs = dict(
+             stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
+             spec_transform=self.spec_fwd, **self.kwargs
+         )
+         if stage == 'fit' or stage is None:
+             self.train_set = Specs(data_dir=self.train_dir,
+                 dummy=self.dummy, shuffle_spec=True, format=self.format,
+                 normalize=self.normalize, sampling_rate=self.sampling_rate, **specs_kwargs)
+             self.valid_set = Specs(data_dir=self.val_dir,
+                 dummy=self.dummy, shuffle_spec=False, format=self.format,
+                 normalize=self.normalize, sampling_rate=self.sampling_rate, **specs_kwargs)
+         if stage == 'test' or stage is None:
+             self.test_set = Specs(data_dir=self.test_dir,
+                 dummy=self.dummy, shuffle_spec=False, format=self.format,
+                 normalize=self.normalize, sampling_rate=self.sampling_rate, **specs_kwargs)
+
+     def spec_fwd(self, spec):
+         if self.transform_type == "exponent":
+             if self.spec_abs_exponent != 1:
+                 # Only compress the magnitude if spec_abs_exponent != 1; otherwise this
+                 # is wasted computation that only introduces numerical error.
+                 e = self.spec_abs_exponent
+                 spec = spec.abs()**e * torch.exp(1j * spec.angle())
+             spec = spec * self.spec_factor
+         elif self.transform_type == "log":
+             spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
+             spec = spec * self.spec_factor
+         elif self.transform_type == "none":
+             spec = spec
+         return spec
+
+     def spec_back(self, spec):
+         if self.transform_type == "exponent":
+             spec = spec / self.spec_factor
+             if self.spec_abs_exponent != 1:
+                 e = self.spec_abs_exponent
+                 spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
+         elif self.transform_type == "log":
+             spec = spec / self.spec_factor
+             spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
+         elif self.transform_type == "none":
+             spec = spec
+         return spec
+
+     @property
+     def stft_kwargs(self):
+         return {**self.istft_kwargs, "return_complex": True}
+
+     @property
+     def istft_kwargs(self):
+         return dict(
+             n_fft=self.n_fft, hop_length=self.hop_length,
+             window=self.window, center=True
+         )
+
+     def _get_window(self, x):
+         """
+         Retrieve an appropriate window for the given tensor x, matching its device.
+         Caches the retrieved windows so that only one window tensor will be allocated per device.
+         """
+         window = self.windows.get(x.device, None)
+         if window is None:
+             window = self.window.to(x.device)
+             self.windows[x.device] = window
+         return window
+
+     def stft(self, sig):
+         window = self._get_window(sig)
+         return torch.stft(sig, **{**self.stft_kwargs, "window": window})
+
+     def istft(self, spec, length=None):
+         window = self._get_window(spec)
+         return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
+
+     def train_dataloader(self):
+         return DataLoader(
+             self.train_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self.valid_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             self.test_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
+         )
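As a sanity check on the transform pair above, spec_back should invert spec_fwd up to float rounding. A minimal standalone sketch of the "exponent" branch, with hypothetical spec_factor and spec_abs_exponent values:

    import torch

    spec_factor, e = 0.15, 0.5
    spec = torch.randn(1, 256, 64, dtype=torch.cfloat)
    # forward: magnitude compression, phase kept, then scaling
    fwd = spec.abs()**e * torch.exp(1j * spec.angle()) * spec_factor
    # backward: unscale, then undo the magnitude compression
    back = fwd / spec_factor
    back = back.abs()**(1/e) * torch.exp(1j * back.angle())
    assert torch.allclose(back, spec, atol=1e-4)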
geco/model.py ADDED
@@ -0,0 +1,255 @@
+ import time
+ from math import ceil
+ import warnings
+
+ import numpy as np
+ import torch
+ import pytorch_lightning as pl
+ from torch_ema import ExponentialMovingAverage
+
+ from geco import sampling
+ from geco.sdes import SDERegistry
+ from geco.backbones import BackboneRegistry
+ from geco.util.inference import evaluate_model
+ from geco.util.other import pad_spec
+
+
+ class ScoreModel(pl.LightningModule):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
+         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
+         parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (3e-2 by default)")
+         parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
+         parser.add_argument("--loss_type", type=str, default="mse", help="The type of loss function to use.")
+         parser.add_argument("--loss_abs_exponent", type=float, default=0.5, help="Magnitude transformation in the loss term.")
+         return parser
+
+     def __init__(
+         self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, loss_abs_exponent=0.5,
+         num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
+     ):
+         """
+         Create a new ScoreModel.
+
+         Args:
+             backbone: Backbone DNN that serves as a score-based model.
+             sde: The SDE that defines the diffusion process.
+             lr: The learning rate of the optimizer (1e-4 by default).
+             ema_decay: The decay constant of the parameter EMA (0.999 by default).
+             t_eps: The minimum time to practically run for, to avoid issues very close to zero (3e-2 by default).
+             loss_type: The type of loss to use (w.r.t. noise z/std). Options are 'mse' (default) and 'mae'.
+         """
+         super().__init__()
+         # Initialize the backbone DNN
+         dnn_cls = BackboneRegistry.get_by_name(backbone)
+         self.dnn = dnn_cls(**kwargs)
+         # Initialize the SDE
+         if sde == 'bbve':
+             # Remap parameters if the old class name 'bbve' is used. Needed for loading
+             # the provided checkpoint, as it was trained with the old class.
+             sde = 'bbed'
+             kwargs['k'] = kwargs['sigma_max']
+             del kwargs['sigma_max']
+             del kwargs['sigma_min']
+         sde_cls = SDERegistry.get_by_name(sde)
+         self.sde = sde_cls(**kwargs)
+         # Store hyperparameters and save them
+         self.lr = lr
+         self.ema_decay = ema_decay
+         self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
+         self._error_loading_ema = False
+         self.t_eps = t_eps
+         self.loss_type = loss_type
+         self.num_eval_files = num_eval_files
+         self.loss_abs_exponent = loss_abs_exponent
+         self.save_hyperparameters(ignore=['no_wandb'])
+         self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+         return optimizer
+
+     def optimizer_step(self, *args, **kwargs):
+         # Overridden so that the EMA params are updated after each optimizer step
+         super().optimizer_step(*args, **kwargs)
+         self.ema.update(self.parameters())
+
+     # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
+     def on_load_checkpoint(self, checkpoint):
+         ema = checkpoint.get('ema', None)
+         if ema is not None:
+             self.ema.load_state_dict(checkpoint['ema'])
+         else:
+             self._error_loading_ema = True
+             warnings.warn("EMA state_dict not found in checkpoint!")
+
+     def on_save_checkpoint(self, checkpoint):
+         checkpoint['ema'] = self.ema.state_dict()
+
+     def train(self, mode, no_ema=False):
+         res = super().train(mode)  # call the standard `train` method with the given mode
+         if not self._error_loading_ema:
+             if not mode and not no_ema:
+                 # eval
+                 self.ema.store(self.parameters())        # store current params in EMA
+                 self.ema.copy_to(self.parameters())      # copy EMA parameters over current params for evaluation
+             else:
+                 # train
+                 if self.ema.collected_params is not None:
+                     self.ema.restore(self.parameters())  # restore the EMA weights (if stored)
+         return res
+
+     def eval(self, no_ema=False):
+         return self.train(False, no_ema=no_ema)
+
+     def _loss(self, score, sigmas, z):
+         err = sigmas * score + z
+         if self.loss_type == 'mse':
+             losses = torch.square(err.abs())
+         elif self.loss_type == 'mae':
+             losses = err.abs()
+         # Taken from the reduce_op function: sum over channels and positions, mean over the
+         # batch dim; presumably only important for the absolute loss value, not for gradients.
+         loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
+         return loss
+
+     def _step(self, batch, batch_idx):
+         x, y, m = batch
+         rdm = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
+         t = torch.min(rdm, torch.tensor(self.sde.T))
+         mean, std = self.sde.marginal_prob(x, t, y)
+         z = torch.randn_like(x)
+         sigmas = std[:, None, None, None]
+         perturbed_data = mean + sigmas * z
+         score = self(perturbed_data, t, y, m)
+         loss = self._loss(score, sigmas, z)
+         return loss
+
+     def training_step(self, batch, batch_idx):
+         loss = self._step(batch, batch_idx)
+         self.log('train_loss', loss, on_step=True, on_epoch=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self._step(batch, batch_idx)
+         self.log('valid_loss', loss, on_step=False, on_epoch=True)
+
+         # Evaluate speech enhancement performance
+         if batch_idx == 0 and self.num_eval_files != 0:
+             pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
+             self.log('pesq', pesq, on_step=False, on_epoch=True)
+             self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
+             self.log('estoi', estoi, on_step=False, on_epoch=True)
+
+         return loss
+
+     def forward(self, x, t, y, m):
+         # Concatenate y and m as extra channels
+         dnn_input = torch.cat([x, y, m], dim=1)
+         # The minus is most likely unimportant here; taken from Song's repo
+         score = -self.dnn(dnn_input, t)
+         return score
+
+     def to(self, *args, **kwargs):
+         """Override PyTorch .to() to also transfer the EMA of the model weights."""
+         self.ema.to(*args, **kwargs)
+         return super().to(*args, **kwargs)
+
+     def get_pc_sampler(self, predictor_name, corrector_name, y, m, Y_prior=None, N=None, minibatch=None, timestep_type=None, **kwargs):
+         N = self.sde.N if N is None else N
+         sde = self.sde.copy()
+         sde.N = N
+
+         kwargs = {"eps": self.t_eps, **kwargs}
+         if minibatch is None:
+             return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, Y=y, M=m, Y_prior=Y_prior, timestep_type=timestep_type, **kwargs)
+         else:
+             num_samples = y.shape[0]
+             def batched_sampling_fn():
+                 samples, ns = [], []
+                 for i in range(int(ceil(num_samples / minibatch))):
+                     y_mini = y[i*minibatch:(i+1)*minibatch]
+                     y_prior_mini = None if Y_prior is None else Y_prior[i*minibatch:(i+1)*minibatch]
+                     sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, Y=y_mini, M=m, Y_prior=y_prior_mini, timestep_type=timestep_type, **kwargs)
+                     sample, n = sampler()
+                     samples.append(sample)
+                     ns.append(n)
+                 samples = torch.cat(samples, dim=0)
+                 return samples, ns
+             return batched_sampling_fn
+
+     def train_dataloader(self):
+         return self.data_module.train_dataloader()
+
+     def val_dataloader(self):
+         return self.data_module.val_dataloader()
+
+     def test_dataloader(self):
+         return self.data_module.test_dataloader()
+
+     def setup(self, stage=None):
+         return self.data_module.setup(stage=stage)
+
+     def to_audio(self, spec, length=None):
+         return self._istft(self._backward_transform(spec), length)
+
+     def _forward_transform(self, spec):
+         return self.data_module.spec_fwd(spec)
+
+     def _backward_transform(self, spec):
+         return self.data_module.spec_back(spec)
+
+     def _stft(self, sig):
+         return self.data_module.stft(sig)
+
+     def _istft(self, spec, length=None):
+         return self.data_module.istft(spec, length)
+
+     def enhance(self, y, m, sampler_type="pc", predictor="reverse_diffusion",
+         corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
+         **kwargs
+     ):
+         """
+         One-call speech enhancement of noisy speech `y`, conditioned on the mixture `m`, for convenience.
+         """
+         sr = 8000
+         start = time.time()
+         T_orig = y.size(1)
+         norm_factor = y.abs().max().item()
+         y = y / norm_factor
+         m = m / norm_factor
+
+         Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
+         Y = pad_spec(Y)
+         M = torch.unsqueeze(self._forward_transform(self._stft(m.cuda())), 0)
+         M = pad_spec(M)
+
+         if sampler_type == "pc":
+             sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), M.cuda(), N=N,
+                 corrector_steps=corrector_steps, snr=snr, intermediate=False,
+                 **kwargs)
+         else:
+             raise ValueError("{} is not a valid sampler type!".format(sampler_type))
+         sample, nfe = sampler()
+
+         sample = sample.squeeze()
+         x_hat = self.to_audio(sample)
+         x_hat = x_hat * norm_factor
+         x_hat = x_hat.squeeze().cpu().numpy()
+         end = time.time()
+         if timeit:
+             rtf = (end - start) / (len(x_hat) / sr)
+             return x_hat, nfe, rtf
+         else:
+             return x_hat
+
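Not part of this diff, but for orientation, a minimal usage sketch of `enhance`. The checkpoint path and waveforms are placeholders, and a CUDA device is assumed since `enhance` moves its inputs to the GPU:

    import torch
    from geco.model import ScoreModel

    model = ScoreModel.load_from_checkpoint("geco.ckpt", batch_size=1, num_workers=0)
    model.eval(no_ema=False)           # swap in the EMA weights for inference
    model.cuda()
    y = torch.randn(1, 8000)           # 1 s of separator output at 8 kHz (placeholder)
    m = torch.randn(1, 8000)           # the corresponding mixture (placeholder)
    x_hat = model.enhance(y, m, N=30)  # numpy waveform with the enhanced estimate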
geco/sampling/__init__.py ADDED
@@ -0,0 +1,90 @@
+ """Various sampling methods."""
+ import torch
+
+ from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
+ from .correctors import Corrector, CorrectorRegistry
+
+
+ __all__ = [
+     'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
+     'get_pc_sampler'
+ ]
+
+
+ def to_flattened_numpy(x):
+     """Flatten a torch tensor `x` and convert it to numpy."""
+     return x.detach().cpu().numpy().reshape((-1,))
+
+
+ def from_flattened_numpy(x, shape):
+     """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
+     return torch.from_numpy(x.reshape(shape))
+
+
+ def get_pc_sampler(
+     predictor_name, corrector_name, sde, score_fn, Y, M, Y_prior=None,
+     denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
+     intermediate=False, timestep_type=None, **kwargs
+ ):
+     """Create a Predictor-Corrector (PC) sampler.
+
+     Args:
+         predictor_name: The name of a registered `sampling.Predictor`.
+         corrector_name: The name of a registered `sampling.Corrector`.
+         sde: An `sdes.SDE` object representing the forward SDE.
+         score_fn: A function (typically a learned model) that predicts the score.
+         Y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
+         M: A `torch.Tensor` with the corresponding mixture(s), used as an additional condition.
+         denoise: If `True`, add one-step denoising to the final samples.
+         eps: A `float` number. The reverse-time SDE is integrated down to `eps` to avoid numerical issues.
+         snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
+         corrector_steps: The number of corrector steps per predictor step.
+
+     Returns:
+         A sampling function that returns samples and the number of function evaluations during sampling.
+     """
+     predictor_cls = PredictorRegistry.get_by_name(predictor_name)
+     corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
+     predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
+     corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
+
+     def pc_sampler(Y_prior=Y_prior, timestep_type=timestep_type):
+         """The PC sampler function."""
+         with torch.no_grad():
+             if Y_prior is None:
+                 Y_prior = Y
+
+             xt, _ = sde.prior_sampling(Y_prior.shape, Y_prior)
+             timesteps = timesteps_space(sde.T, sde.N, eps, Y.device, type=timestep_type)
+             xt = xt.to(Y_prior.device)
+             for i in range(len(timesteps)):
+                 t = timesteps[i]
+                 if i != len(timesteps) - 1:
+                     stepsize = t - timesteps[i+1]
+                 else:
+                     stepsize = timesteps[-1]  # final step covers the remaining distance from eps down to 0
+                 vec_t = torch.ones(Y.shape[0], device=Y.device) * t
+                 xt, xt_mean = corrector.update_fn(xt, vec_t, Y, M)
+                 xt, xt_mean = predictor.update_fn(xt, vec_t, Y, M, stepsize)
+             x_result = xt_mean if denoise else xt
+             ns = len(timesteps) * (corrector.n_steps + 1)
+             return x_result, ns
+
+     if intermediate:
+         raise NotImplementedError("Returning intermediate samples is not implemented for this sampler.")
+     else:
+         return pc_sampler
+
+
+ def timesteps_space(sdeT, sdeN, eps, device, type='linear'):
+     # A linear grid from sdeT down to eps; other sampling schedules could be implemented here.
+     timesteps = torch.linspace(sdeT, eps, sdeN, device=device)
+     if type not in (None, 'linear'):
+         raise NotImplementedError(f"Unknown timestep schedule '{type}'.")
+     return timesteps
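The linear grid runs from sde.T down to eps, and the loop above derives each predictor stepsize from consecutive grid points; the final step uses eps itself, i.e. it integrates the remaining distance down to 0. A quick illustration with hypothetical T=0.999, N=5, eps=0.03:

    import torch
    print(torch.linspace(0.999, 0.03, 5))
    # ≈ [0.999, 0.757, 0.514, 0.272, 0.030]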
geco/sampling/correctors.py ADDED
@@ -0,0 +1,60 @@
+ import abc
+ import torch
+
+ from geco.util.registry import Registry
+
+
+ CorrectorRegistry = Registry("Corrector")
+
+
+ class Corrector(abc.ABC):
+     """The abstract class for a corrector algorithm."""
+
+     def __init__(self, sde, score_fn, snr, n_steps):
+         super().__init__()
+         self.rsde = sde.reverse(score_fn)
+         self.score_fn = score_fn
+         self.snr = snr
+         self.n_steps = n_steps
+
+     @abc.abstractmethod
+     def update_fn(self, x, t, *args):
+         """One update of the corrector.
+
+         Args:
+             x: A PyTorch tensor representing the current state.
+             t: A PyTorch tensor representing the current time step.
+             *args: Possibly additional arguments, in particular `y` for OU processes.
+
+         Returns:
+             x: A PyTorch tensor of the next state.
+             x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
+         """
+         pass
+
+
+ @CorrectorRegistry.register(name='ald')
+ class AnnealedLangevinDynamics(Corrector):
+     """The original annealed Langevin dynamics corrector from NCSN/NCSNv2."""
+     def __init__(self, sde, score_fn, snr, n_steps):
+         super().__init__(sde, score_fn, snr, n_steps)
+         self.sde = sde
+         self.score_fn = score_fn
+         self.snr = snr
+         self.n_steps = n_steps
+
+     def update_fn(self, x, t, y, m):
+         x_mean = 0
+         n_steps = self.n_steps
+         target_snr = self.snr
+         std = self.sde.marginal_prob(x, t, y)[1]
+         for _ in range(n_steps):
+             grad = self.score_fn(x, t, y, m)
+             noise = torch.randn_like(x)
+             step_size = (target_snr * std) ** 2 * 2
+             x_mean = x + step_size[:, None, None, None] * grad
+             x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
+
+         return x, x_mean
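Written out, one inner iteration of this corrector is the annealed Langevin step of Song & Ermon, with target SNR $r$ and the marginal standard deviation $\sigma_t$ from the SDE, and the score network $s_\theta$ conditioned on $y$ and $m$:

$$\epsilon_t = 2\,(r\,\sigma_t)^2, \qquad x \leftarrow x + \epsilon_t\, s_\theta(x, t, y, m) + \sqrt{2\,\epsilon_t}\; z, \qquad z \sim \mathcal{N}(0, I)$$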
geco/sampling/predictors.py ADDED
@@ -0,0 +1,55 @@
+ import abc
+
+ import torch
+
+ from geco.util.registry import Registry
+
+
+ PredictorRegistry = Registry("Predictor")
+
+
+ class Predictor(abc.ABC):
+     """The abstract class for a predictor algorithm."""
+
+     def __init__(self, sde, score_fn, probability_flow=False):
+         super().__init__()
+         self.sde = sde
+         self.rsde = sde.reverse(score_fn)
+         self.score_fn = score_fn
+         self.probability_flow = probability_flow
+
+     @abc.abstractmethod
+     def update_fn(self, x, t, *args):
+         """One update of the predictor.
+
+         Args:
+             x: A PyTorch tensor representing the current state.
+             t: A PyTorch tensor representing the current time step.
+             *args: Possibly additional arguments, in particular `y` for OU processes.
+
+         Returns:
+             x: A PyTorch tensor of the next state.
+             x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
+         """
+         pass
+
+     def debug_update_fn(self, x, t, *args):
+         raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
+
+
+ @PredictorRegistry.register('reverse_diffusion')
+ class ReverseDiffusionPredictor(Predictor):
+     def __init__(self, sde, score_fn, probability_flow=False):
+         super().__init__(sde, score_fn, probability_flow=probability_flow)
+
+     def update_fn(self, x, t, y, m, stepsize):
+         f, g = self.rsde.discretize(x, t, y, m, stepsize)
+         z = torch.randn_like(x)
+         x_mean = x - f
+         x = x_mean + g[:, None, None, None] * z
+         return x, x_mean
+
+     def update_fn_analyze(self, x, t, *args):
+         raise NotImplementedError("update_fn_analyze() has not been implemented yet for the ReverseDiffusionPredictor")
+
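In equations, `update_fn` applies the reverse-diffusion discretization from Song et al.: with $f_i$ and $G_i$ returned by the reverse SDE's `discretize` (which already folds in the score term),

$$x_{i-1} = x_i - \left[f_i - G_i^2\, s_\theta(x_i, t_i)\right] + G_i\, z_i, \qquad z_i \sim \mathcal{N}(0, I)$$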
geco/sdes.py ADDED
@@ -0,0 +1,205 @@
+ """
+ Abstract SDE classes, the reverse SDE, and the BBED SDE.
+
+ Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
+ """
+ import abc
+ import warnings
+
+ import scipy.special as sc
+ import numpy as np
+ import torch
+
+ from geco.util.registry import Registry
+
+
+ SDERegistry = Registry("SDE")
+
+
+ class SDE(abc.ABC):
+     """SDE abstract class. Functions are designed for a mini-batch of inputs."""
+
+     def __init__(self, N):
+         """Construct an SDE.
+
+         Args:
+             N: number of discretization time steps.
+         """
+         super().__init__()
+         self.N = N
+
+     @property
+     @abc.abstractmethod
+     def T(self):
+         """End time of the SDE."""
+         pass
+
+     @abc.abstractmethod
+     def sde(self, x, t, *args):
+         pass
+
+     @abc.abstractmethod
+     def marginal_prob(self, x, t, *args):
+         """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
+         pass
+
+     @abc.abstractmethod
+     def prior_sampling(self, shape, *args):
+         """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
+         pass
+
+     @abc.abstractmethod
+     def prior_logp(self, z):
+         """Compute log-density of the prior distribution.
+
+         Useful for computing the log-likelihood via the probability flow ODE.
+
+         Args:
+             z: latent code
+         Returns:
+             log probability density
+         """
+         pass
+
+     @staticmethod
+     @abc.abstractmethod
+     def add_argparse_args(parent_parser):
+         """
+         Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
+         """
+         pass
+
+     def discretize(self, x, t, y, stepsize):
+         """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
+
+         Useful for reverse diffusion sampling and probability flow sampling.
+         Defaults to Euler-Maruyama discretization.
+
+         Args:
+             x: a torch tensor
+             t: a torch float representing the time step (from 0 to `self.T`)
+             y: the condition tensor passed through to `self.sde`
+             stepsize: the discretization step dt
+
+         Returns:
+             f, G
+         """
+         dt = stepsize
+         drift, diffusion = self.sde(x, t, y)
+         f = drift * dt
+         G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
+         return f, G
+
+     def reverse(oself, score_model, probability_flow=False):
+         """Create the reverse-time SDE/ODE.
+
+         Note: `oself` is the outer SDE instance; the RSDE class below closes over it.
+
+         Args:
+             score_model: A function that takes x, t, y and m and returns the score.
+             probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
+         """
+         N = oself.N
+         T = oself.T
+         sde_fn = oself.sde
+         discretize_fn = oself.discretize
+
+         # Build the class for the reverse-time SDE.
+         class RSDE(oself.__class__):
+             def __init__(self):
+                 self.N = N
+                 self.probability_flow = probability_flow
+
+             @property
+             def T(self):
+                 return T
+
+             def sde(self, x, t, *args):
+                 """Create the drift and diffusion functions for the reverse SDE/ODE."""
+                 rsde_parts = self.rsde_parts(x, t, *args)
+                 total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
+                 return total_drift, diffusion
+
+             def discretize(self, x, t, y, m, stepsize):
+                 """Create discretized iteration rules for the reverse diffusion sampler."""
+                 f, G = discretize_fn(x, t, y, stepsize)
+                 if torch.is_complex(G):
+                     G = G.imag
+                 rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y, m) * (0.5 if self.probability_flow else 1.)
+                 rev_G = torch.zeros_like(G) if self.probability_flow else G
+                 return rev_f, rev_G
+
+         return RSDE()
+
+     @abc.abstractmethod
+     def copy(self):
+         pass
+
+
+ @SDERegistry.register("bbed")
+ class BBED(SDE):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--sde-n", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")
+         parser.add_argument("--T_sampling", type=float, default=0.999, help="The T such that t < T during sampling in the train step.")
+         parser.add_argument("--k", type=float, default=2.6, help="Base factor for the diffusion term.")
+         parser.add_argument("--theta", type=float, default=0.52, help="Root scale factor for the diffusion term.")
+         return parser
+
+     def __init__(self, T_sampling, k, theta, N=1000, **kwargs):
+         """Construct a Brownian Bridge with Exploding Diffusion coefficient (BBED) SDE
+         with the parameterization as in the paper:
+             dx = (y-x)/(Tc-t) dt + sqrt(theta) * k^t dw
+         """
+         super().__init__(N)
+         self.k = k
+         self.logk = np.log(self.k)
+         self.theta = theta
+         self.N = N
+         self.Eilog = sc.expi(-2*self.logk)
+         # T is used for sampling in the train step and at inference; Tc is the bridge
+         # end time used to construct the SDE; don't change it.
+         self._T = T_sampling
+         self._Tc = 1
+
+     def copy(self):
+         return BBED(self.T, self.k, self.theta, N=self.N)
+
+     @property
+     def T(self):
+         return self._T
+
+     @property
+     def Tc(self):
+         return self._Tc
+
+     def sde(self, x, t, y):
+         drift = (y - x) / (self.Tc - t)
+         sigma = self.k ** t
+         diffusion = sigma * np.sqrt(self.theta)
+         return drift, diffusion
+
+     def _mean(self, x0, t, y):
+         time = (t / self.Tc)[:, None, None, None]
+         mean = x0 * (1 - time) + y * time
+         return mean
+
+     def _std(self, t):
+         t_np = t.cpu().detach().numpy()
+         Eis = sc.expi(2*(t_np - 1)*self.logk) - self.Eilog
+         h = 2 * self.k**2 * self.logk
+         var = (self.k**(2*t_np) - 1 + t_np) + h*(1 - t_np)*Eis
+         var = torch.tensor(var).to(device=t.device) * (1 - t) * self.theta
+         return torch.sqrt(var)
+
+     def marginal_prob(self, x0, t, y):
+         return self._mean(x0, t, y), self._std(t)
+
+     def prior_sampling(self, shape, y):
+         if shape != y.shape:
+             warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
+         std = self._std(self.T * torch.ones((y.shape[0],), device=y.device))
+         z = torch.randn_like(y)
+         x_T = y + z * std[:, None, None, None]
+         return x_T, z
+
+     def prior_logp(self, z):
+         raise NotImplementedError("prior_logp for BBED not yet implemented!")
+
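The closed form in `_std` can be sanity-checked against a Monte Carlo simulation of the forward SDE; a sketch with hypothetical k and theta values, starting from x0 = y = 0 so that the bridge mean stays zero:

    import torch
    from geco.sdes import BBED

    sde = BBED(T_sampling=0.999, k=2.6, theta=0.52)
    n, steps, t_end = 20000, 1000, 0.5
    dt = t_end / steps
    x, y = torch.zeros(n), torch.zeros(n)
    for i in range(steps):
        t = torch.full((n,), i * dt)
        drift, diffusion = sde.sde(x, t, y)    # Euler-Maruyama forward step
        x = x + drift * dt + diffusion * (dt ** 0.5) * torch.randn(n)
    print(x.std().item(), sde._std(torch.tensor([t_end])).item())  # should roughly agree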
geco/util/inference.py ADDED
@@ -0,0 +1,211 @@
+ import torch
+ import torchaudio
+ from pesq import pesq
+ from pystoi import stoi
+
+ from .other import si_sdr, pad_spec
+
+ # Settings
+ sr = 8000
+ snr = 0.5
+ N = 30
+ corrector_steps = 1
+
+
+ def evaluate_model(model, num_eval_files):
+     clean_files = model.data_module.valid_set.clean_files
+     noisy_files = model.data_module.valid_set.noisy_files
+     mixture_files = model.data_module.valid_set.mixture_files
+
+     # Select test files uniformly across the validation files
+     total_num_files = len(clean_files)
+     indices = torch.linspace(0, total_num_files - 1, num_eval_files, dtype=torch.int)
+     clean_files = list(clean_files[i] for i in indices)
+     noisy_files = list(noisy_files[i] for i in indices)
+     mixture_files = list(mixture_files[i] for i in indices)
+
+     _pesq = 0
+     _si_sdr = 0
+     _estoi = 0
+     # iterate over files
+     for (clean_file, noisy_file, mixture_file) in zip(clean_files, noisy_files, mixture_files):
+         # Load wavs and resample if needed
+         x, sr_ = torchaudio.load(clean_file)
+         if sr_ != sr:
+             x = torchaudio.transforms.Resample(sr_, sr)(x)
+         y, sr_ = torchaudio.load(noisy_file)
+         if sr_ != sr:
+             y = torchaudio.transforms.Resample(sr_, sr)(y)
+         m, sr_ = torchaudio.load(mixture_file)
+         if sr_ != sr:
+             m = torchaudio.transforms.Resample(sr_, sr)(m)
+
+         # Crop all signals to the shortest common length
+         min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
+         x = x[..., :min_leng]
+         y = y[..., :min_leng]
+         m = m[..., :min_leng]
+
+         T_orig = x.size(1)
+
+         # Normalize per utterance
+         norm_factor = y.abs().max()
+         y = y / norm_factor
+         m = m / norm_factor
+
+         # Prepare DNN input
+         Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
+         Y = pad_spec(Y)
+         M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
+         M = pad_spec(M)
+
+         y = y * norm_factor
+
+         # Reverse sampling
+         sampler = model.get_pc_sampler(
+             'reverse_diffusion', 'ald', Y.cuda(), M.cuda(), N=N,
+             corrector_steps=corrector_steps, snr=snr)
+         sample, _ = sampler()
+
+         sample = sample.squeeze()
+
+         x_hat = model.to_audio(sample.squeeze(), T_orig)
+         x_hat = x_hat * norm_factor
+         x_hat = x_hat.squeeze().cpu().numpy()
+         x = x.squeeze().cpu().numpy()
+         y = y.squeeze().cpu().numpy()
+
+         _si_sdr += si_sdr(x, x_hat)
+         _pesq += pesq(sr, x, x_hat, 'nb')
+         _estoi += stoi(x, x_hat, sr, extended=True)
+
+     return _pesq / num_eval_files, _si_sdr / num_eval_files, _estoi / num_eval_files
+
+
+ def evaluate_model2(model, num_eval_files, inference_N, inference_start=0.5):
+     N = inference_N
+     reverse_start_time = inference_start
+
+     clean_files = model.data_module.valid_set.clean_files
+     noisy_files = model.data_module.valid_set.noisy_files
+     mixture_files = model.data_module.valid_set.mixture_files
+
+     # Select test files uniformly across the validation files
+     total_num_files = len(clean_files)
+     indices = torch.linspace(0, total_num_files - 1, num_eval_files, dtype=torch.int)
+     clean_files = list(clean_files[i] for i in indices)
+     noisy_files = list(noisy_files[i] for i in indices)
+     mixture_files = list(mixture_files[i] for i in indices)
+
+     _pesq = 0
+     _si_sdr = 0
+     _estoi = 0
+     # iterate over files
+     for (clean_file, noisy_file, mixture_file) in zip(clean_files, noisy_files, mixture_files):
+         # Load wavs and resample if needed
+         x, sr_ = torchaudio.load(clean_file)
+         if sr_ != sr:
+             x = torchaudio.transforms.Resample(sr_, sr)(x)
+         y, sr_ = torchaudio.load(noisy_file)
+         if sr_ != sr:
+             y = torchaudio.transforms.Resample(sr_, sr)(y)
+         m, sr_ = torchaudio.load(mixture_file)
+         if sr_ != sr:
+             m = torchaudio.transforms.Resample(sr_, sr)(m)
+
+         # Required only for BWE, as that dataset has different lengths of clean and noisy files
+         min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
+         x = x[..., :min_leng]
+         y = y[..., :min_leng]
+         m = m[..., :min_leng]
+
+         T_orig = x.size(1)
+
+         # Normalize per utterance
+         norm_factor = y.abs().max()
+         y = y / norm_factor
+         x = x / norm_factor
+         m = m / norm_factor
+
+         # Prepare DNN input
+         Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
+         Y = pad_spec(Y)
+         X = torch.unsqueeze(model._forward_transform(model._stft(x.cuda())), 0)
+         X = pad_spec(X)
+         M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
+         M = pad_spec(M)
+
+         y = y * norm_factor
+         x = x * norm_factor
+
+         x = x.squeeze().cpu().numpy()
+         y = y.squeeze().cpu().numpy()
+
+         total_loss = 0
+         timesteps = torch.linspace(reverse_start_time, 0.03, N, device=Y.device)
+         # prior sampling, starting from reverse_start_time
+         std = model.sde._std(reverse_start_time * torch.ones((Y.shape[0],), device=Y.device))
+         z = torch.randn_like(Y)
+         X_t = Y + z * std[:, None, None, None]
+
+         # reverse steps by Euler-Maruyama
+         for i in range(len(timesteps)):
+             t = timesteps[i]
+             if i != len(timesteps) - 1:
+                 dt = t - timesteps[i+1]
+             else:
+                 dt = timesteps[-1]
+             with torch.no_grad():
+                 # take an Euler step here
+                 f, g = model.sde.sde(X_t, t, Y)
+                 vec_t = torch.ones(Y.shape[0], device=Y.device) * t
+                 score = model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
+                 mean_x_tm1 = X_t - (f - g**2 * score) * dt  # mean of x_{t-1}, i.e. mu(x_{t-1})
+                 if i == len(timesteps) - 1:  # output
+                     X_t = mean_x_tm1
+                     break
+                 z = torch.randn_like(X)
+                 X_t = mean_x_tm1 + z * g * torch.sqrt(dt)
+
+         sample = X_t
+         sample = sample.squeeze()
+         x_hat = model.to_audio(sample.squeeze(), T_orig)
+         x_hat = x_hat * norm_factor
+         x_hat = x_hat.squeeze().cpu().numpy()
+         _si_sdr += si_sdr(x, x_hat)
+         _pesq += pesq(sr, x, x_hat, 'nb')
+         _estoi += stoi(x, x_hat, sr, extended=True)
+
+     return _pesq / num_eval_files, _si_sdr / num_eval_files, _estoi / num_eval_files, total_loss / num_eval_files
+
+
+ def convert_to_audio(X, deemp, T_orig, model, norm_factor):
+     sample = X
+     sample = sample.squeeze()
+     if len(sample.shape) == 4:
+         sample = sample * deemp[None, None, :, None].to(device=sample.device)
+     elif len(sample.shape) == 3:
+         sample = sample * deemp[None, :, None].to(device=sample.device)
+     else:
+         sample = sample * deemp[:, None].to(device=sample.device)
+
+     x_hat = model.to_audio(sample.squeeze(), T_orig)
+     x_hat = x_hat * norm_factor
+     x_hat = x_hat.squeeze().cpu().numpy()
+     return x_hat
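In equations, the inner loop of evaluate_model2 is a plain Euler-Maruyama integration of the reverse SDE, where the final step returns the mean without adding noise:

$$X_{t-\Delta t} = X_t - \left[f(X_t, t) - g(t)^2\, s_\theta(X_t, t)\right]\Delta t + g(t)\,\sqrt{\Delta t}\; z, \qquad z \sim \mathcal{N}(0, I)$$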
geco/util/other.py ADDED
@@ -0,0 +1,125 @@
+ import os
+
+ import numpy as np
+ import scipy.stats
+ from scipy.signal import butter, sosfilt
+ import torch
+
+ from pesq import pesq
+ from pystoi import stoi
+
+
+ def si_sdr_components(s_hat, s, n):
+     """Decompose the estimate s_hat into target, noise, and artifact components."""
+     # s_target
+     alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
+     s_target = alpha_s * s
+
+     # e_noise
+     alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
+     e_noise = alpha_n * n
+
+     # e_art
+     e_art = s_hat - s_target - e_noise
+
+     return s_target, e_noise, e_art
+
+
+ def energy_ratios(s_hat, s, n):
+     """Compute SI-SDR, SI-SIR, and SI-SAR (in dB) from the components above."""
+     s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
+
+     si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
+     si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
+     si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)
+
+     return si_sdr, si_sir, si_sar
+
+
+ def mean_conf_int(data, confidence=0.95):
+     a = 1.0 * np.array(data)
+     n = len(a)
+     m, se = np.mean(a), scipy.stats.sem(a)
+     h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
+     return m, h
+
+
+ class Method():
+     def __init__(self, name, base_dir, metrics):
+         self.name = name
+         self.base_dir = base_dir
+         self.metrics = {}
+
+         for metric in metrics:
+             self.metrics[metric] = []
+
+     def append(self, metric, value):
+         self.metrics[metric].append(value)
+
+     def get_mean_ci(self, metric):
+         return mean_conf_int(np.array(self.metrics[metric]))
+
+
+ def hp_filter(signal, cut_off=80, order=10, sr=16000):
+     factor = cut_off / sr * 2
+     sos = butter(order, factor, 'hp', output='sos')
+     filtered = sosfilt(sos, signal)
+     return filtered
+
+
+ def si_sdr(s, s_hat):
+     alpha = np.dot(s_hat, s) / np.linalg.norm(s)**2
+     sdr = 10*np.log10(np.linalg.norm(alpha*s)**2 / np.linalg.norm(alpha*s - s_hat)**2)
+     return sdr
+
+
+ def snr_dB(s, n):
+     s_power = 1/len(s) * np.sum(s**2)
+     n_power = 1/len(n) * np.sum(n**2)
+     snr_dB = 10*np.log10(s_power / n_power)
+     return snr_dB
+
+
+ def pad_spec(Y):
+     T = Y.size(3)
+     if T % 64 != 0:
+         num_pad = 64 - T % 64
+     else:
+         num_pad = 0
+     pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
+     return pad2d(Y)
+
+
+ def ensure_dir(file_path):
+     directory = file_path
+     if not os.path.exists(directory):
+         os.makedirs(directory)
+
+
+ def print_metrics(x, y, x_hat_list, labels, sr=16000):
+     _si_sdr_mix = si_sdr(x, y)
+     _pesq_mix = pesq(sr, x, y, 'wb')
+     _estoi_mix = stoi(x, y, sr, extended=True)
+     print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
+     for i, x_hat in enumerate(x_hat_list):
+         _si_sdr = si_sdr(x, x_hat)
+         _pesq = pesq(sr, x, x_hat, 'wb')
+         _estoi = stoi(x, x_hat, sr, extended=True)
+         print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
+
+
+ def mean_std(data):
+     data = data[~np.isnan(data)]
+     mean = np.mean(data)
+     std = np.std(data)
+     return mean, std
+
+
+ def print_mean_std(data, decimal=2):
+     data = np.array(data)
+     data = data[~np.isnan(data)]
+     mean = np.mean(data)
+     std = np.std(data)
+     return f'{mean:.{decimal}f} ± {std:.{decimal}f}'
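pad_spec zero-pads the spectrogram's time axis up to the next multiple of 64, presumably so the backbone's repeated down/upsampling divides evenly. For example:

    import torch
    from geco.util.other import pad_spec

    Y = torch.randn(1, 1, 256, 100)
    print(pad_spec(Y).shape)  # torch.Size([1, 1, 256, 128])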
geco/util/registry.py ADDED
@@ -0,0 +1,34 @@
+ import warnings
+ from typing import Callable
+
+
+ class Registry:
+     def __init__(self, managed_thing: str):
+         """
+         Create a new registry.
+
+         Args:
+             managed_thing: A string describing what type of thing is managed by this registry. Will be used for
+                 warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
+         """
+         self.managed_thing = managed_thing
+         self._registry = {}
+
+     def register(self, name: str) -> Callable:
+         def inner_wrapper(wrapped_class) -> Callable:
+             if name in self._registry:
+                 warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
+             self._registry[name] = wrapped_class
+             return wrapped_class
+         return inner_wrapper
+
+     def get_by_name(self, name: str):
+         """Get a managed thing by name."""
+         if name in self._registry:
+             return self._registry[name]
+         else:
+             raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")
+
+     def get_all_names(self):
+         """Get the list of things' names registered to this registry."""
+         return list(self._registry.keys())
@@ -0,0 +1,16 @@
 
+ def batch_broadcast(a, x):
+     """Broadcasts a over all dimensions of x, except the batch dimension, which must match."""
+     if len(a.shape) != 1:
+         a = a.squeeze()
+         if len(a.shape) != 1:
+             raise ValueError(
+                 f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
+             )
+
+     if a.shape[0] != x.shape[0] and a.shape[0] != 1:
+         raise ValueError(
+             f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")
+
+     out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1))))
+     return out
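A short usage example for batch_broadcast:

    import torch
    from geco.util.tensors import batch_broadcast

    a = torch.tensor([0.5, 2.0])     # one scalar per batch element
    x = torch.randn(2, 3, 4, 5)
    y = batch_broadcast(a, x) * x    # `a` is viewed as (2, 1, 1, 1) and broadcasts over x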
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ pytorch-lightning==1.5.10
+ torch==2.4.0
+ torch-ema==0.3
+ torch-optimizer==0.3.0
+ torch-stoi==0.1.2
+ torchaudio==2.4.0
+ torchinfo==1.6.3
+ torchmetrics==0.9.3
+ torchsde==0.2.5
+ torchvision==0.19.0
+ tornado==6.2
+ tqdm==4.63.0
+ ninja
+ matplotlib
+ pesq
+ wandb
+ PySoundFile
+ pandas
+ git+https://github.com/WangHelin1997/Fast-GeCo.git#subdirectory=score_models
+ speechbrain==1.0.0