anonymous9a7b
commited on
Commit
·
d4c980e
1
Parent(s):
8207a5c
- app.py +214 -4
- demo/item0_mix.wav +0 -0
- demo/item1_mix.wav +0 -0
- demo/item2_mix.wav +0 -0
- demo/item3_mix.wav +0 -0
- demo/item4_mix.wav +0 -0
- fastgeco/.DS_Store +0 -0
- fastgeco/backbones/.DS_Store +0 -0
- fastgeco/backbones/__init__.py +4 -0
- fastgeco/backbones/ncsnpp.py +406 -0
- fastgeco/backbones/ncsnpp_utils/layers.py +662 -0
- fastgeco/backbones/ncsnpp_utils/layerspp.py +274 -0
- fastgeco/backbones/ncsnpp_utils/normalization.py +215 -0
- fastgeco/backbones/ncsnpp_utils/utils.py +189 -0
- fastgeco/backbones/shared.py +123 -0
- fastgeco/model.py +258 -0
- geco/.DS_Store +0 -0
- geco/backbones/.DS_Store +0 -0
- geco/backbones/__init__.py +4 -0
- geco/backbones/ncsnpp.py +405 -0
- geco/backbones/ncsnpp_utils/.DS_Store +0 -0
- geco/backbones/ncsnpp_utils/layers.py +662 -0
- geco/backbones/ncsnpp_utils/layerspp.py +202 -0
- geco/backbones/ncsnpp_utils/normalization.py +215 -0
- geco/backbones/ncsnpp_utils/utils.py +189 -0
- geco/backbones/shared.py +123 -0
- geco/data_module.py +258 -0
- geco/model.py +255 -0
- geco/sampling/__init__.py +90 -0
- geco/sampling/correctors.py +60 -0
- geco/sampling/predictors.py +55 -0
- geco/sdes.py +205 -0
- geco/util/inference.py +211 -0
- geco/util/other.py +125 -0
- geco/util/registry.py +34 -0
- geco/util/tensors.py +16 -0
- requirements.txt +20 -0
app.py
CHANGED
@@ -1,7 +1,217 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
return "Hello " + name + "!!"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import spaces
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from fastgeco.model import ScoreModel
|
6 |
+
from geco.util.other import pad_spec
|
7 |
+
import os
|
8 |
+
import torchaudio
|
9 |
+
from speechbrain.lobes.models.dual_path import Encoder, SBTransformerBlock, SBTransformerBlock, Dual_Path_Model, Decoder
|
10 |
|
11 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
12 |
|
13 |
+
def load_sepformer(ckpt_path, num_spks=2):
    """Build the SepFormer encoder, mask network and decoder and load their weights.

    Args:
        ckpt_path: Directory containing ``encoder.ckpt``, ``masknet.ckpt`` and
            ``decoder.ckpt``.
        num_spks: Number of speakers passed to ``Dual_Path_Model``. The
            original code read this from an undefined ``args`` object, which
            raised ``NameError`` on the first call; it is now an explicit
            parameter defaulting to 2.

    Returns:
        Tuple ``(encoder, masknet, decoder)``, each in eval mode and moved to
        the module-level ``device``.
    """
    encoder = Encoder(
        kernel_size=160,
        out_channels=256,
        in_channels=1,
    )

    # The intra- and inter-chunk transformers use identical hyper-parameters.
    transformer_kwargs = dict(
        num_layers=8,
        d_model=256,
        nhead=8,
        d_ffn=1024,
        dropout=0,
        use_positional_encoding=True,
        norm_before=True,
    )
    SBtfintra = SBTransformerBlock(**transformer_kwargs)
    SBtfinter = SBTransformerBlock(**transformer_kwargs)

    masknet = Dual_Path_Model(
        num_spks=num_spks,
        in_channels=256,
        out_channels=256,
        num_layers=2,
        K=250,
        intra_model=SBtfintra,
        inter_model=SBtfinter,
        norm='ln',
        linear_layer_after_inter_intra=False,
        skip_around_intra=True,
    )
    decoder = Decoder(
        in_channels=256,
        out_channels=1,
        kernel_size=160,
        stride=80,
        bias=False,
    )

    # map_location lets checkpoints saved on a GPU be restored on a CPU-only
    # host; without it torch.load tries to recreate the tensors on their
    # original device.
    for module, filename in (
        (encoder, 'encoder.ckpt'),
        (masknet, 'masknet.ckpt'),
        (decoder, 'decoder.ckpt'),
    ):
        weights = torch.load(os.path.join(ckpt_path, filename), map_location=device)
        module.load_state_dict(weights)

    encoder = encoder.eval().to(device)
    masknet = masknet.eval().to(device)
    decoder = decoder.eval().to(device)
    return encoder, masknet, decoder
|
67 |
+
|
68 |
+
def load_fastgeco(ckpt_path):
    """Restore the Fast-GeCo ``ScoreModel`` stored as ``fastgeco.ckpt`` in *ckpt_path*.

    The model is switched to eval mode with ``no_ema=False`` and moved to the
    module-level ``device`` before being returned.
    """
    restored = ScoreModel.load_from_checkpoint(
        os.path.join(ckpt_path, 'fastgeco.ckpt'),
        batch_size=1,
        num_workers=0,
        kwargs={'gpu': False},
    )
    restored.eval(no_ema=False)
    restored.to(device)
    return restored
|
77 |
+
|
78 |
+
# Directory that holds encoder.ckpt, masknet.ckpt, decoder.ckpt and fastgeco.ckpt.
ckpt_path = 'ckpts/'
# Both models are loaded once, at import time, and reused by every request.
encoder, masknet, decoder = load_sepformer(ckpt_path)
fastgeco_model = load_fastgeco(ckpt_path)
# Sample rate (Hz) that separate() resamples input audio to.
sample_rate = 8000
# Number of sources produced by separate() and refined by correct().
num_spks = 2
|
83 |
+
|
84 |
+
@spaces.GPU
def separate(test_file, encoder, masknet, decoder):
    """Run SepFormer on an audio file and return the separated sources.

    Args:
        test_file: Path of the mixture, loaded with ``torchaudio.load``.
        encoder: SepFormer encoder module.
        masknet: SepFormer mask-estimation module.
        decoder: SepFormer decoder module.

    Returns:
        Tuple ``(est_sources, mix)``. ``est_sources`` holds the decoded
        sources concatenated along the last dimension, each peak-normalised
        along dim 1 and then squeezed. ``mix`` is the mono mixture at
        ``sample_rate`` on ``device``.
    """
    with torch.no_grad():
        print('Process SepFormer...')
        mix, fs_file = torchaudio.load(test_file)
        # Always downmix to a single channel. Previously this only happened
        # inside the resampling branch, so a multi-channel file that was
        # already at the model rate reached the encoder un-mixed. For mono
        # input the mean over one channel is a no-op.
        mix = mix.mean(dim=0, keepdim=True).to(device)
        fs_model = sample_rate

        # resample the data if needed
        if fs_file != fs_model:
            print(
                "Resampling the audio from {} Hz to {} Hz".format(
                    fs_file, fs_model
                )
            )
            tf = torchaudio.transforms.Resample(
                orig_freq=fs_file, new_freq=fs_model
            ).to(device)
            mix = tf(mix)

        # Separation: apply one estimated mask per speaker to the encoded mixture.
        mix_w = encoder(mix)
        est_mask = masknet(mix_w)
        mix_w = torch.stack([mix_w] * num_spks)
        sep_h = mix_w * est_mask

        # Decoding: one waveform per speaker, stacked on a new last dimension.
        est_sources = torch.cat(
            [
                decoder(sep_h[i]).unsqueeze(-1)
                for i in range(num_spks)
            ],
            dim=-1,
        )

        # Peak-normalise each source. The clamp keeps an all-zero (silent)
        # estimate at zero instead of turning it into NaN via 0 / 0.
        peak = est_sources.abs().max(dim=1, keepdim=True)[0]
        peak = peak.clamp(min=torch.finfo(est_sources.dtype).tiny)
        est_sources = (est_sources / peak).squeeze()

        return est_sources, mix
|
126 |
+
|
127 |
+
|
128 |
+
@spaces.GPU
def correct(model, est_sources, mix):
    """Refine each separated source with the Fast-GeCo score model.

    Args:
        model: Score model providing ``_stft``, ``_forward_transform``,
            ``sde``, ``forward`` and ``to_audio``.
        est_sources: Separated sources; column ``idx`` along dim 1 is taken
            as the estimate for speaker ``idx``.
        mix: Mixture waveform used as additional conditioning.

    Returns:
        The first two refined sources as numpy arrays, each divided by its
        own absolute peak. Only ``output[0]`` and ``output[1]`` are returned,
        so ``num_spks`` must be at least 2.
    """
    with torch.no_grad():
        print('Process Fast-Geco...')
        # Number of points in the time grid; with N = 1 the grid holds only
        # reverse_starting_point, so a single update step is taken below.
        N = 1
        reverse_starting_point = 0.5
        output = []
        for idx in range(num_spks):
            y = est_sources[:, idx].unsqueeze(0) # noisy
            m = mix
            # Trim both signals to their common length along the last axis.
            min_leng = min(y.shape[-1],m.shape[-1])
            y = y[...,:min_leng]
            m = m[...,:min_leng]
            T_orig = y.size(1)

            # Scale both signals by the peak of the current estimate.
            # NOTE(review): an all-zero estimate makes norm_factor 0 and the
            # divisions below produce NaN/inf.
            norm_factor = y.abs().max()
            y = y / norm_factor
            m = m / norm_factor
            # Transform to the model's spectrogram representation and add a
            # leading dimension before padding with pad_spec.
            Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(device))), 0)
            Y = pad_spec(Y)
            M = torch.unsqueeze(model._forward_transform(model._stft(m.to(device))), 0)
            M = pad_spec(M)

            timesteps = torch.linspace(reverse_starting_point, 0.03, N, device=Y.device)
            # Start from the estimate perturbed with Gaussian noise scaled by
            # the SDE's std at reverse_starting_point.
            std = model.sde._std(reverse_starting_point*torch.ones((Y.shape[0],), device=Y.device))
            z = torch.randn_like(Y)
            X_t = Y + z * std[:, None, None, None]

            # With N = 1, timesteps[0] and timesteps[-1] are the same element,
            # so t and dt both equal reverse_starting_point.
            t = timesteps[0]
            dt = timesteps[-1]
            f, g = model.sde.sde(X_t, t, Y)
            vec_t = torch.ones(Y.shape[0], device=Y.device) * t
            mean_x_tm1 = X_t - (f - g**2*model.forward(X_t, vec_t, Y, M, vec_t[:,None,None,None]))*dt #mean of x t minus 1 = mu(x_{t-1})
            sample = mean_x_tm1
            sample = sample.squeeze()
            # presumably inverts the transform back to a waveform of length
            # T_orig -- confirm against ScoreModel.to_audio
            x_hat = model.to_audio(sample.squeeze(), T_orig)
            # Undo the input scaling, then renormalise by the output's own peak.
            x_hat = x_hat * norm_factor
            new_norm_factor = x_hat.abs().max()
            x_hat = x_hat / new_norm_factor
            x_hat = x_hat.squeeze().cpu().numpy()
            output.append(x_hat)
        return output[0], output[1]
|
170 |
+
|
171 |
+
@spaces.GPU
def process_audio(test_file):
    """Gradio callback: separate *test_file* and refine both sources.

    Args:
        test_file: Path of the uploaded mixture (the input component uses
            ``type="filepath"``).

    Returns:
        Two ``(sample_rate, samples)`` tuples, one per separated speaker.
        ``gr.Audio`` output components accept a file path, bytes or a
        ``(sample_rate, ndarray)`` tuple; a bare array, as returned
        previously, is rejected when Gradio post-processes the value.
    """
    result, mix = separate(test_file, encoder, masknet, decoder)
    audio1, audio2 = correct(fastgeco_model, result, mix)
    return (sample_rate, audio1), (sample_rate, audio2)
|
176 |
+
|
177 |
+
|
178 |
+
# CSS styling (optional): centres the main column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# Fast-GeCo: Noise-robust Speech Separation with Fast Generative Correction
Separate the noisy mixture speech with a generative correction method, only support 2 speakers now.

Learn more about 🟣**Fast-GeCo** on the [Fast-GeCo Repo](https://github.com/WangHelin1997/Fast-GeCo/).
""")

        with gr.Tab("Speech Separation"):
            # Input: upload an audio file (pre-filled with a bundled demo mixture)
            with gr.Row():
                gt_file_input = gr.Audio(label="Upload Audio to Separate", type="filepath", value="demo/item0_mix.wav")
                button = gr.Button("Generate", scale=1)

            # Output components for the two separated sources
            with gr.Row():
                result1 = gr.Audio(label="Separated Audio 1", type="numpy")
                result2 = gr.Audio(label="Separated Audio 2", type="numpy")

            # Define the trigger and input-output linking
            button.click(
                fn=process_audio,
                inputs=[
                    gt_file_input,
                ],
                outputs=[result1, result2]
            )

# Launch the Gradio demo
demo.launch()
|
demo/item0_mix.wav
ADDED
Binary file (173 kB). View file
|
|
demo/item1_mix.wav
ADDED
Binary file (164 kB). View file
|
|
demo/item2_mix.wav
ADDED
Binary file (103 kB). View file
|
|
demo/item3_mix.wav
ADDED
Binary file (105 kB). View file
|
|
demo/item4_mix.wav
ADDED
Binary file (104 kB). View file
|
|
fastgeco/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
fastgeco/backbones/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
fastgeco/backbones/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .shared import BackboneRegistry
|
2 |
+
from .ncsnpp import NCSNpp
|
3 |
+
|
4 |
+
__all__ = ['BackboneRegistry', 'NCSNpp']
|
fastgeco/backbones/ncsnpp.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
from score_models.layers import UpsampleLayer, DownsampleLayer
|
18 |
+
from .ncsnpp_utils import layers, layerspp, normalization
|
19 |
+
import torch.nn as nn
|
20 |
+
import functools
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .shared import BackboneRegistry
|
25 |
+
|
26 |
+
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
|
27 |
+
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
|
28 |
+
Combine = layerspp.Combine
|
29 |
+
conv3x3 = layerspp.conv3x3
|
30 |
+
conv1x1 = layerspp.conv1x1
|
31 |
+
get_act = layers.get_act
|
32 |
+
get_normalization = normalization.get_normalization
|
33 |
+
default_initializer = layers.default_init
|
34 |
+
|
35 |
+
|
36 |
+
@BackboneRegistry.register("ncsnpp")
class NCSNpp(nn.Module):
    """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository

    ``__init__`` builds every layer into one flat ``nn.ModuleList``
    (``self.all_modules``) and ``forward`` consumes that list sequentially
    through a running index ``m_idx``. The branch structure of ``forward``
    therefore has to mirror ``__init__`` exactly; the final
    ``assert m_idx == len(modules)`` checks that every module was used.
    """

    @staticmethod
    def add_argparse_args(parser):
        """Return *parser* unchanged; no backbone-specific arguments are registered."""
        # TODO: add additional arguments of constructor, if you wish to modify them.
        return parser

    def __init__(self,
        scale_by_sigma = True,
        nonlinearity = 'swish',
        nf = 128,
        ch_mult = (1, 1, 2, 2, 2, 2, 2),
        num_res_blocks = 2,
        attn_resolutions = (16,),
        resamp_with_conv = True,
        conditional = True,
        fir = True,
        fir_kernel = 'song',
        skip_rescale = True,
        resblock_type = 'biggan',
        progressive = 'output_skip',
        progressive_input = 'input_skip',
        progressive_combine = 'sum',
        init_scale = 0.,
        fourier_scale = 16,
        image_size = 256,
        embedding_type = 'fourier',
        dropout = .0,
        **unused_kwargs
    ):
        """Assemble the U-Net layers into ``self.all_modules``.

        Args:
            scale_by_sigma: Stored on the instance; ``forward`` as written
                always divides by its ``scale_divide`` argument regardless.
            nonlinearity: Activation name resolved through ``get_act``.
            nf: Base channel count.
            ch_mult: Per-resolution channel multipliers; its length sets the
                number of resolution levels.
            num_res_blocks: Residual blocks per level on the down path (the
                up path uses one more).
            attn_resolutions: Resolutions at which an attention block is added.
            resamp_with_conv: Forwarded as ``with_conv`` to the
                up/down-sampling layers.
            conditional: If true, add the two linear layers that process the
                time embedding.
            fir: Forwarded to the resampling layers and BigGAN blocks.
            fir_kernel: Ignored -- overwritten with ``[1, 3, 3, 1]`` below.
            skip_rescale: Forwarded to the blocks; also selects the
                ``/ sqrt(2)`` rescaling in the 'residual' pyramid paths.
            resblock_type: 'ddpm' or 'biggan'.
            progressive: 'none', 'output_skip' or 'residual'.
            progressive_input: 'none', 'input_skip' or 'residual'.
            progressive_combine: Method handed to ``Combine``.
            init_scale: Init scale forwarded to attention/residual blocks
                and output convolutions.
            fourier_scale: Scale of the Gaussian Fourier projection.
            image_size: Nominal top resolution used to derive
                ``all_resolutions``.
            embedding_type: 'fourier' or 'positional'.
            dropout: Dropout rate forwarded to the residual blocks.
            **unused_kwargs: Accepted and ignored.

        Raises:
            ValueError: On an unknown embedding, residual block or
                progressive mode.
        """
        super().__init__()
        self.act = act = get_act(nonlinearity)

        self.nf = nf = nf
        ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions = attn_resolutions
        dropout = dropout
        resamp_with_conv = resamp_with_conv
        self.num_resolutions = num_resolutions = len(ch_mult)
        # Resolution assumed at each level: image_size halved once per level.
        self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]

        self.conditional = conditional = conditional # noise-conditional
        self.scale_by_sigma = scale_by_sigma
        fir = fir
        # NOTE: the fir_kernel argument is discarded here; this fixed kernel is always used.
        fir_kernel = [1, 3, 3, 1]
        self.skip_rescale = skip_rescale = skip_rescale
        self.resblock_type = resblock_type = resblock_type.lower()
        self.progressive = progressive = progressive.lower()
        self.progressive_input = progressive_input = progressive_input.lower()
        self.embedding_type = embedding_type = embedding_type.lower()
        init_scale = init_scale
        assert progressive in ['none', 'output_skip', 'residual']
        assert progressive_input in ['none', 'input_skip', 'residual']
        assert embedding_type in ['fourier', 'positional']
        combine_method = progressive_combine.lower()
        combiner = functools.partial(Combine, method=combine_method)

        num_channels = 6 # real and imaginary parts of the three complex input channels (see forward)
        # 1x1 convolution mapping the num_channels network output to 2 channels
        # (viewed as real/imag of one complex channel at the end of forward).
        self.output_layer = nn.Conv2d(num_channels, 2, 1)

        modules = []
        # timestep/noise_level embedding
        if embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            modules.append(layerspp.GaussianFourierProjection(
                embedding_size=nf, scale=fourier_scale
            ))
            embed_dim = 2 * nf
        elif embedding_type == 'positional':
            embed_dim = nf
        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        if conditional:
            # Two linear layers that process the time embedding (nf * 4 features).
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        AttnBlock = functools.partial(layerspp.AttnBlockpp,
            init_scale=init_scale, skip_rescale=skip_rescale)

        Upsample = functools.partial(UpsampleLayer,
            with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive == 'output_skip':
            self.pyramid_upsample = UpsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive == 'residual':
            pyramid_upsample = functools.partial(UpsampleLayer, fir=fir,
                fir_kernel=fir_kernel, with_conv=True)

        Downsample = functools.partial(DownsampleLayer, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive_input == 'input_skip':
            self.pyramid_downsample = DownsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive_input == 'residual':
            pyramid_downsample = functools.partial(DownsampleLayer,
                fir=fir, fir_kernel=fir_kernel, with_conv=True)

        if resblock_type == 'ddpm':
            ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
                dropout=dropout, init_scale=init_scale,
                skip_rescale=skip_rescale, temb_dim=nf * 4)

        elif resblock_type == 'biggan':
            ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
                dropout=dropout, fir=fir, fir_kernel=fir_kernel,
                init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)

        else:
            raise ValueError(f'resblock type {resblock_type} unrecognized.')

        # Downsampling block

        channels = num_channels
        if progressive_input != 'none':
            input_pyramid_ch = channels

        modules.append(conv3x3(channels, nf))
        # hs_c tracks the channel count of every skip connection pushed on the
        # down path; the up path pops it to size its concatenated inputs.
        hs_c = [nf]

        in_ch = nf
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch

                if all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)

            if i_level != num_resolutions - 1:
                if resblock_type == 'ddpm':
                    modules.append(Downsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(down=True, in_ch=in_ch))

                if progressive_input == 'input_skip':
                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
                    if combine_method == 'cat':
                        in_ch *= 2

                elif progressive_input == 'residual':
                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
                    input_pyramid_ch = in_ch

                hs_c.append(in_ch)

        # Bottleneck: ResNet block, attention block, ResNet block.
        in_ch = hs_c[-1]
        modules.append(ResnetBlock(in_ch=in_ch))
        modules.append(AttnBlock(channels=in_ch))
        modules.append(ResnetBlock(in_ch=in_ch))

        pyramid_ch = 0
        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch

            if all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))

            if progressive != 'none':
                if i_level == num_resolutions - 1:
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                            num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == 'residual':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, in_ch, bias=True))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f'{progressive} is not a valid name.')
                else:
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                            num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == 'residual':
                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f'{progressive} is not a valid name')

            if i_level != 0:
                if resblock_type == 'ddpm':
                    modules.append(Upsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(in_ch=in_ch, up=True))

        # Every skip-connection channel count must have been consumed.
        assert not hs_c

        if progressive != 'output_skip':
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, time_cond, scale_divide):
        """Run the network.

        Args:
            x: Tensor indexed as 4-D with entries 0, 1 and 2 along dim 1; the
                ``.real`` and ``.imag`` parts of each are read.
            time_cond: Conditioning value fed to the time embedding (its log
                is taken on the 'fourier' path).
            scale_divide: Divisor applied to the network output before the
                final 1x1 output convolution.

        Returns:
            Complex tensor built from the two output channels, with a
            singleton dimension inserted at dim 1.
        """
        # timestep/noise_level embedding; only for continuous training
        modules = self.all_modules
        # Running index into self.all_modules; must advance in the same order
        # the modules were appended in __init__.
        m_idx = 0

        # Convert real and imaginary parts of the three complex input channels into six real channels
        x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
            x[:,[1],:,:].real, x[:,[1],:,:].imag,
            x[:,[2],:,:].real, x[:,[2],:,:].imag), dim=1)

        if self.embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            used_sigmas = time_cond
            temb = modules[m_idx](torch.log(used_sigmas))
            m_idx += 1

        elif self.embedding_type == 'positional':
            # Sinusoidal positional embeddings.
            timesteps = time_cond
            # NOTE(review): self.sigmas is not assigned anywhere in this class
            # as shown, so this path looks unusable as written -- confirm.
            used_sigmas = self.sigmas[time_cond.long()]
            temb = layers.get_timestep_embedding(timesteps, self.nf)

        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')

        if self.conditional:
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            temb = None

        # Downsampling block
        input_pyramid = None
        if self.progressive_input != 'none':
            input_pyramid = x

        # Input layer: conv3x3 from the 6 input channels to nf channels
        hs = [modules[m_idx](x)]
        m_idx += 1

        # Down path in U-Net
        for i_level in range(self.num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                # Attention layer (optional)
                # NOTE(review): __init__ decides on attention from
                # all_resolutions (derived from image_size) while this checks
                # the actual tensor height; the two must agree or m_idx drifts.
                if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
                    h = modules[m_idx](h)
                    m_idx += 1
                hs.append(h)

            # Downsampling
            if i_level != self.num_resolutions - 1:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](hs[-1])
                    m_idx += 1
                else:
                    h = modules[m_idx](hs[-1], temb)
                    m_idx += 1

                if self.progressive_input == 'input_skip':   # Combine h with x
                    input_pyramid = self.pyramid_downsample(input_pyramid)
                    h = modules[m_idx](input_pyramid, h)
                    m_idx += 1

                elif self.progressive_input == 'residual':
                    input_pyramid = modules[m_idx](input_pyramid)
                    m_idx += 1
                    if self.skip_rescale:
                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
                    else:
                        input_pyramid = input_pyramid + h
                    h = input_pyramid
                hs.append(h)

        h = hs[-1] # the most recently appended feature map
        h = modules[m_idx](h, temb)  # ResNet block
        m_idx += 1
        h = modules[m_idx](h)  # Attention block
        m_idx += 1
        h = modules[m_idx](h, temb)  # ResNet block
        m_idx += 1

        pyramid = None

        # Upsampling block
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                # Concatenate the matching skip connection along the channel dim.
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

            # edit: from -1 to -2
            if h.shape[-2] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1

            if self.progressive != 'none':
                if i_level == self.num_resolutions - 1:
                    if self.progressive == 'output_skip':
                        pyramid = self.act(modules[m_idx](h))  # GroupNorm
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)  # conv3x3 back to the input channel count
                        m_idx += 1
                    elif self.progressive == 'residual':
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                    else:
                        raise ValueError(f'{self.progressive} is not a valid name.')
                else:
                    if self.progressive == 'output_skip':
                        pyramid = self.pyramid_upsample(pyramid)  # Upsample
                        pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
                        m_idx += 1
                        pyramid_h = modules[m_idx](pyramid_h)
                        m_idx += 1
                        pyramid = pyramid + pyramid_h
                    elif self.progressive == 'residual':
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                        if self.skip_rescale:
                            pyramid = (pyramid + h) / np.sqrt(2.)
                        else:
                            pyramid = pyramid + h
                        h = pyramid
                    else:
                        raise ValueError(f'{self.progressive} is not a valid name')

            # Upsampling Layer
            if i_level != 0:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](h)
                    m_idx += 1
                else:
                    h = modules[m_idx](h, temb)  # Upsampling
                    m_idx += 1

        # All skip connections must have been consumed.
        assert not hs

        if self.progressive == 'output_skip':
            h = pyramid
        else:
            h = self.act(modules[m_idx](h))
            m_idx += 1
            h = modules[m_idx](h)
            m_idx += 1

        assert m_idx == len(modules), "Implementation error"
        # Scale by the caller-supplied divisor (applied regardless of self.scale_by_sigma).
        h = h / scale_divide
        # h = h / used_sigmas[:, None, None, None]

        # Convert back to complex number
        h = self.output_layer(h)
        # Move the two channels last so they can be viewed as (real, imag).
        h = torch.permute(h, (0, 2, 3, 1)).contiguous()
        h = torch.view_as_complex(h)[:,None, :, :]
        return h
|
fastgeco/backbones/ncsnpp_utils/layers.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
"""Common layers for defining score networks.
|
18 |
+
"""
|
19 |
+
import math
|
20 |
+
import string
|
21 |
+
from functools import partial
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import numpy as np
|
26 |
+
from .normalization import ConditionalInstanceNorm2dPlus
|
27 |
+
|
28 |
+
|
29 |
+
def get_act(config):
    """Return a newly constructed activation module selected by name.

    Args:
        config: Activation name; one of 'elu', 'relu', 'lrelu' or 'swish'.

    Returns:
        A fresh ``nn.Module`` instance (a new object on every call).

    Raises:
        NotImplementedError: If ``config`` matches none of the known names.
    """
    # (name, factory) pairs. Factories are lazy so that only the requested
    # activation class is looked up, and names are matched with ``==``.
    factories = (
        ('elu', lambda: nn.ELU()),
        ('relu', lambda: nn.ReLU()),
        ('lrelu', lambda: nn.LeakyReLU(negative_slope=0.2)),
        ('swish', lambda: nn.SiLU()),
    )
    for name, build in factories:
        if config == name:
            return build()
    raise NotImplementedError('activation function does not exist!')
|
42 |
+
|
43 |
+
|
44 |
+
def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
    """1x1 convolution with PyTorch's default init, rescaled. Same as NCSNv1/v2.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        dilation: Convolution dilation.
        init_scale: Factor multiplied into the default-initialised weight
            (and bias, if present). A value of 0 is replaced by 1e-10.
        padding: Convolution padding.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
                     padding=padding)
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv.weight.data *= init_scale
    # With ``bias=False`` the module's ``bias`` attribute is None, so it can
    # only be rescaled when it actually exists.
    if conv.bias is not None:
        conv.bias.data *= init_scale
    return conv
|
52 |
+
|
53 |
+
|
54 |
+
def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Build a variance-scaling initializer (ported from JAX).

    Args:
        scale: Numerator of the target variance.
        mode: 'fan_in', 'fan_out' or 'fan_avg'; selects the denominator.
        distribution: 'normal' or 'uniform'.
        in_axis: Index of the input-feature axis in the shape.
        out_axis: Index of the output-feature axis in the shape.
        dtype: Default dtype of the generated tensor.
        device: Default device of the generated tensor.

    Returns:
        A function ``init(shape, dtype=..., device=...)`` returning a random
        tensor of that shape. ``mode`` and ``distribution`` are validated
        when ``init`` is called, raising ValueError if invalid.
    """

    def _fans(shape):
        # Every axis other than the in/out axes counts as receptive field.
        field = np.prod(shape) / shape[in_axis] / shape[out_axis]
        return shape[in_axis] * field, shape[out_axis] * field

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _fans(shape)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator

        if distribution == "normal":
            gaussian = torch.randn(*shape, dtype=dtype, device=device)
            return gaussian * np.sqrt(variance)
        if distribution == "uniform":
            # Samples lie in [-limit, limit) with limit = sqrt(3 * variance).
            centred = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
            return centred * np.sqrt(3 * variance)
        raise ValueError("invalid distribution for variance scaling initializer")

    return init
|
86 |
+
|
87 |
+
|
88 |
+
def default_init(scale=1.):
    """Return the DDPM-style initializer: fan-average, uniform variance scaling.

    A ``scale`` of exactly 0 is replaced by 1e-10 before the initializer is
    built.
    """
    if scale == 0:
        scale = 1e-10
    return variance_scaling(scale, 'fan_avg', 'uniform')
|
92 |
+
|
93 |
+
|
94 |
+
class Dense(nn.Module):
    """Linear layer with `default_init`.

    NOTE: as written, this class is an empty stub: it registers no
    parameters and defines no ``forward``.
    """

    def __init__(self):
        super(Dense, self).__init__()
|
98 |
+
|
99 |
+
|
100 |
+
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """1x1 convolution with DDPM initialization.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        init_scale: Scale passed to ``default_init`` for the weight.
        padding: Convolution padding.

    Returns:
        An ``nn.Conv2d`` whose weight comes from ``default_init`` and whose
        bias, if present, is zero.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # ``conv.bias`` is None when ``bias=False``; there is nothing to zero then.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv
|
106 |
+
|
107 |
+
|
108 |
+
def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        dilation: Convolution dilation.
        init_scale: Factor multiplied into the default-initialised weight
            (and bias, if present). A value of 0 is replaced by 1e-10.
        padding: Convolution padding.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
                     dilation=dilation, padding=padding, kernel_size=3)
    conv.weight.data *= init_scale
    # Callers in this file pass ``bias=False``; in that case ``conv.bias`` is
    # None and must not be dereferenced.
    if conv.bias is not None:
        conv.bias.data *= init_scale
    return conv
|
116 |
+
|
117 |
+
|
118 |
+
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with DDPM initialization.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        dilation: Convolution dilation.
        init_scale: Scale passed to ``default_init`` for the weight.
        padding: Convolution padding.

    Returns:
        An ``nn.Conv2d`` whose weight comes from ``default_init`` and whose
        bias, if present, is zero.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # ``conv.bias`` is None when ``bias=False``; there is nothing to zero then.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv
|
125 |
+
|
126 |
+
###########################################################################
|
127 |
+
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
|
128 |
+
# https://github.com/ermongroup/ncsn
|
129 |
+
# https://github.com/ermongroup/ncsnv2
|
130 |
+
###########################################################################
|
131 |
+
|
132 |
+
|
133 |
+
class CRPBlock(nn.Module):
    """Chain of pool-then-conv stages whose outputs are summed onto the input.

    Args:
        features: Channel count of the input and of every stage.
        n_stages: Number of pool + conv stages.
        act: Activation applied once to the input.
        maxpool: Use 5x5 max pooling if true, otherwise 5x5 average pooling.
    """

    def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
        super().__init__()
        self.convs = nn.ModuleList(
            ncsn_conv3x3(features, features, stride=1, bias=False)
            for _ in range(n_stages)
        )
        self.n_stages = n_stages
        # Stride 1 with padding 2 keeps the spatial size under a 5x5 window.
        pool_cls = nn.MaxPool2d if maxpool else nn.AvgPool2d
        self.pool = pool_cls(kernel_size=5, stride=1, padding=2)

        self.act = act

    def forward(self, x):
        """Return the activated input plus the output of every chained stage."""
        x = self.act(x)
        branch = x
        for stage in range(self.n_stages):
            branch = self.convs[stage](self.pool(branch))
            x = branch + x
        return x
|
155 |
+
|
156 |
+
|
157 |
+
class CondCRPBlock(nn.Module):
    """Conditional pooled residual chain.

    Each stage applies a conditional normaliser, a 5x5 average pool and a
    bias-free 3x3 convolution; every stage output is added to the running
    result.

    Args:
        features: Channel count of the input and of every stage.
        n_stages: Number of stages.
        num_classes: Forwarded to ``normalizer``.
        normalizer: Callable building a normaliser as
            ``normalizer(features, num_classes, bias=True)``.
        act: Activation applied once to the input.
    """

    def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.normalizer = normalizer
        # Norm and conv are built pairwise per stage so that parameters are
        # created in the same order as before.
        for _ in range(n_stages):
            self.norms.append(normalizer(features, num_classes, bias=True))
            self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))

        self.n_stages = n_stages
        self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x, y):
        """Run the chain on ``x``; ``y`` is passed to every normaliser."""
        x = self.act(x)
        branch = x
        for stage in range(self.n_stages):
            normed = self.norms[stage](branch, y)
            branch = self.convs[stage](self.pool(normed))

            x = branch + x
        return x
|
181 |
+
|
182 |
+
|
183 |
+
class RCUBlock(nn.Module):
    """Stack of residual units, each made of ``n_stages`` act + 3x3 conv steps.

    The convolutions are registered as attributes named
    ``'{block}_{stage}_conv'`` with 1-based indices.

    Args:
        features: Channel count kept by every convolution.
        n_blocks: Number of residual units.
        n_stages: Number of act + conv steps per unit.
        act: Activation applied before every convolution.
    """

    def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
        super().__init__()

        for block in range(1, n_blocks + 1):
            for stage in range(1, n_stages + 1):
                conv = ncsn_conv3x3(features, features, stride=1, bias=False)
                setattr(self, '{}_{}_conv'.format(block, stage), conv)

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act

    def forward(self, x):
        """Apply every residual unit in turn and return the result."""
        for block in range(1, self.n_blocks + 1):
            residual = x
            for stage in range(1, self.n_stages + 1):
                conv = getattr(self, '{}_{}_conv'.format(block, stage))
                x = conv(self.act(x))

            # In-place add, as in the reference implementation.
            x += residual
        return x
|
205 |
+
|
206 |
+
|
207 |
+
class CondRCUBlock(nn.Module):
    """Conditional residual units: norm(y) -> act -> 3x3 conv, ``n_stages`` times.

    Layers are registered as ``'{block}_{stage}_norm'`` and
    ``'{block}_{stage}_conv'`` with 1-based indices.

    Args:
        features: Channel count kept by every layer.
        n_blocks: Number of residual units.
        n_stages: Number of norm + act + conv steps per unit.
        num_classes: Forwarded to ``normalizer``.
        normalizer: Callable building a normaliser as
            ``normalizer(features, num_classes, bias=True)``.
        act: Activation applied between norm and conv.
    """

    def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()

        for block in range(1, n_blocks + 1):
            for stage in range(1, n_stages + 1):
                # Norm first, then conv, to keep the creation order stable.
                norm = normalizer(features, num_classes, bias=True)
                setattr(self, '{}_{}_norm'.format(block, stage), norm)
                conv = ncsn_conv3x3(features, features, stride=1, bias=False)
                setattr(self, '{}_{}_conv'.format(block, stage), conv)

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act
        self.normalizer = normalizer

    def forward(self, x, y):
        """Apply every residual unit to ``x``; ``y`` feeds each normaliser."""
        for block in range(1, self.n_blocks + 1):
            residual = x
            for stage in range(1, self.n_stages + 1):
                norm = getattr(self, '{}_{}_norm'.format(block, stage))
                conv = getattr(self, '{}_{}_conv'.format(block, stage))
                x = conv(self.act(norm(x, y)))

            # In-place add, as in the reference implementation.
            x += residual
        return x
|
232 |
+
|
233 |
+
|
234 |
+
class MSFBlock(nn.Module):
    """Fuse several feature maps: 3x3 conv each, resize to one size, then sum.

    Args:
        in_planes: List or tuple with the channel count of every input.
        features: Channel count of the fused output.
    """

    def __init__(self, in_planes, features):
        super().__init__()
        assert isinstance(in_planes, (list, tuple))
        self.convs = nn.ModuleList()
        self.features = features

        for planes in in_planes:
            self.convs.append(ncsn_conv3x3(planes, features, stride=1, bias=True))

    def forward(self, xs, shape):
        """Return the sum of all inputs after conv and bilinear resize to ``shape``."""
        first = xs[0]
        fused = torch.zeros(first.shape[0], self.features, *shape, device=first.device)
        for idx, conv in enumerate(self.convs):
            resized = F.interpolate(conv(xs[idx]), size=shape, mode='bilinear', align_corners=True)
            fused += resized
        return fused
|
251 |
+
|
252 |
+
|
253 |
+
class CondMSFBlock(nn.Module):
    """Conditional fusion: norm(y) and 3x3 conv per input, resize, then sum.

    Args:
        in_planes: List or tuple with the channel count of every input.
        features: Channel count of the fused output.
        num_classes: Forwarded to ``normalizer``.
        normalizer: Callable building a normaliser as
            ``normalizer(planes, num_classes, bias=True)``.
    """

    def __init__(self, in_planes, features, num_classes, normalizer):
        super().__init__()
        assert isinstance(in_planes, (list, tuple))

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.features = features
        self.normalizer = normalizer

        # Conv first, then norm, to keep the creation order stable.
        for planes in in_planes:
            self.convs.append(ncsn_conv3x3(planes, features, stride=1, bias=True))
            self.norms.append(normalizer(planes, num_classes, bias=True))

    def forward(self, xs, y, shape):
        """Return the sum of all inputs after norm, conv and resize to ``shape``."""
        first = xs[0]
        fused = torch.zeros(first.shape[0], self.features, *shape, device=first.device)
        for idx, conv in enumerate(self.convs):
            normed = self.norms[idx](xs[idx], y)
            resized = F.interpolate(conv(normed), size=shape, mode='bilinear', align_corners=True)
            fused += resized
        return fused
|
275 |
+
|
276 |
+
|
277 |
+
class RefineBlock(nn.Module):
    """Refinement stage: per-input RCU, optional fusion, CRP, then output RCU.

    Args:
        in_planes: List or tuple with the channel count of every input.
        features: Channel count after fusion.
        act: Activation shared by the sub-blocks.
        start: If true, no fusion (``msf``) module is created.
        end: If true, the output RCU uses 3 residual units instead of 1.
        maxpool: Forwarded to ``CRPBlock``.
    """

    def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
        super().__init__()

        assert isinstance(in_planes, (tuple, list))
        self.n_blocks = len(in_planes)

        self.adapt_convs = nn.ModuleList()
        for planes in in_planes:
            self.adapt_convs.append(RCUBlock(planes, 2, 2, act))

        n_output_units = 3 if end else 1
        self.output_convs = RCUBlock(features, n_output_units, 2, act)

        if not start:
            self.msf = MSFBlock(in_planes, features)

        self.crp = CRPBlock(features, 2, act, maxpool=maxpool)

    def forward(self, xs, output_shape):
        """Refine the list/tuple ``xs``; fusion resizes to ``output_shape``."""
        assert isinstance(xs, (tuple, list))
        adapted = [self.adapt_convs[idx](x) for idx, x in enumerate(xs)]

        # Fusion is only needed when there is more than one input branch.
        if self.n_blocks > 1:
            h = self.msf(adapted, output_shape)
        else:
            h = adapted[0]

        return self.output_convs(self.crp(h))
|
311 |
+
|
312 |
+
|
313 |
+
class CondRefineBlock(nn.Module):
    """Conditional refinement stage built from the ``Cond*`` sub-blocks.

    Args:
        in_planes: List or tuple with the channel count of every input.
        features: Channel count after fusion.
        num_classes: Forwarded to every conditional sub-block.
        normalizer: Normaliser factory forwarded to every sub-block.
        act: Activation shared by the sub-blocks.
        start: If true, no fusion (``msf``) module is created.
        end: If true, the output RCU uses 3 residual units instead of 1.
    """

    def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
        super().__init__()

        assert isinstance(in_planes, (tuple, list))
        self.n_blocks = len(in_planes)

        self.adapt_convs = nn.ModuleList()
        for planes in in_planes:
            self.adapt_convs.append(CondRCUBlock(planes, 2, 2, num_classes, normalizer, act))

        n_output_units = 3 if end else 1
        self.output_convs = CondRCUBlock(features, n_output_units, 2, num_classes, normalizer, act)

        if not start:
            self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)

        self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)

    def forward(self, xs, y, output_shape):
        """Refine the list/tuple ``xs`` conditioned on ``y``."""
        assert isinstance(xs, (tuple, list))
        adapted = [self.adapt_convs[idx](x, y) for idx, x in enumerate(xs)]

        # Fusion is only needed when there is more than one input branch.
        if self.n_blocks > 1:
            h = self.msf(adapted, y, output_shape)
        else:
            h = adapted[0]

        h = self.crp(h, y)
        return self.output_convs(h, y)
|
349 |
+
|
350 |
+
|
351 |
+
class ConvMeanPool(nn.Module):
    """Convolution followed by a 2x2 mean over strided slices.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        biases: Whether the convolution has a bias.
        adjust_padding: If true, one row/column of zeros is added on the
            top/left before the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
        super().__init__()
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
        if adjust_padding:
            self.conv = nn.Sequential(nn.ZeroPad2d((1, 0, 1, 0)), conv)
        else:
            self.conv = conv

    def forward(self, inputs):
        """Convolve ``inputs`` and average the four interleaved sub-grids."""
        feat = self.conv(inputs)
        quadrants = [
            feat[:, :, ::2, ::2],
            feat[:, :, 1::2, ::2],
            feat[:, :, ::2, 1::2],
            feat[:, :, 1::2, 1::2],
        ]
        return sum(quadrants) / 4.
|
370 |
+
|
371 |
+
|
372 |
+
class MeanPoolConv(nn.Module):
    """2x2 mean over strided slices followed by a convolution.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        biases: Whether the convolution has a bias.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)

    def forward(self, inputs):
        """Average the four interleaved sub-grids of ``inputs``, then convolve."""
        quadrants = [
            inputs[:, :, ::2, ::2],
            inputs[:, :, 1::2, ::2],
            inputs[:, :, ::2, 1::2],
            inputs[:, :, 1::2, 1::2],
        ]
        pooled = sum(quadrants) / 4.
        return self.conv(pooled)
|
382 |
+
|
383 |
+
|
384 |
+
class UpsampleConv(nn.Module):
    """2x upsampling via channel replication + pixel shuffle, then a convolution.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        biases: Whether the convolution has a bias.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
        self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, inputs):
        """Upsample ``inputs`` by a factor of two and convolve the result."""
        # Four channel-wise copies give PixelShuffle(2) the 4x channels it
        # folds into a 2x larger spatial grid.
        replicated = torch.cat([inputs] * 4, dim=1)
        upsampled = self.pixelshuffle(replicated)
        return self.conv(upsampled)
|
395 |
+
|
396 |
+
|
397 |
+
class ConditionalResidualBlock(nn.Module):
    """Residual block whose two normalisation layers are conditioned on ``y``.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        num_classes: Forwarded to ``normalization``.
        resample: ``'down'`` or ``None``. Any other value, including the
            declared default ``1``, raises ``Exception``.
        act: Activation applied after each normalisation.
        normalization: Callable building a normaliser as
            ``normalization(channels, num_classes)``.
        adjust_padding: Forwarded to ``ConvMeanPool`` when downsampling.
        dilation: Convolution dilation; ``None`` or a value <= 1 means the
            non-dilated layers are used.

    Raises:
        Exception: If ``resample`` is neither ``'down'`` nor ``None``.
    """

    def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
                 normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization
        # ``dilation`` defaults to None, and ``None > 1`` raises TypeError, so
        # None is treated explicitly as "no dilation".
        use_dilation = dilation is not None and dilation > 1
        if resample == 'down':
            if use_dilation:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
            else:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
                conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

        elif resample is None:
            if use_dilation:
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
                self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
            else:
                # ``nn.Conv2d(input_dim, output_dim)`` cannot be built because
                # kernel_size is required; use the 1x1 helper instead, as
                # ResidualBlock does.
                conv_shortcut = ncsn_conv1x1
                self.conv1 = ncsn_conv3x3(input_dim, output_dim)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim)
        else:
            raise Exception('invalid resample value')

        # A learned shortcut is needed whenever shape or channels change.
        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim, num_classes)

    def forward(self, x, y):
        """Return ``shortcut(x) + residual(x, y)``."""
        output = self.normalize1(x, y)
        output = self.non_linearity(output)
        output = self.conv1(output)
        output = self.normalize2(output, y)
        output = self.non_linearity(output)
        output = self.conv2(output)

        if self.output_dim == self.input_dim and self.resample is None:
            shortcut = x
        else:
            shortcut = self.shortcut(x)

        return shortcut + output
|
451 |
+
|
452 |
+
|
453 |
+
class ResidualBlock(nn.Module):
    """Unconditional residual block: (norm -> act -> conv) twice plus a shortcut.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        resample: ``'down'`` or ``None``; anything else raises ``Exception``.
        act: Activation applied after each normalisation.
        normalization: Callable building a normaliser as
            ``normalization(channels)``.
        adjust_padding: Forwarded to ``ConvMeanPool`` when downsampling.
        dilation: Values > 1 select dilated 3x3 convolutions throughout.

    Raises:
        Exception: If ``resample`` is neither ``'down'`` nor ``None``.
    """

    def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
                 normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization

        if resample == 'down':
            # When downsampling, the first conv keeps the input width.
            mid_dim = input_dim
        elif resample is None:
            mid_dim = output_dim
        else:
            raise Exception('invalid resample value')

        if dilation > 1:
            # Dilated variant: every conv, including the shortcut, is a
            # dilated 3x3 convolution.
            dilated_conv = partial(ncsn_conv3x3, dilation=dilation)
            self.conv1 = dilated_conv(input_dim, mid_dim)
            self.normalize2 = normalization(mid_dim)
            self.conv2 = dilated_conv(mid_dim, output_dim)
            conv_shortcut = dilated_conv
        elif resample == 'down':
            self.conv1 = ncsn_conv3x3(input_dim, input_dim)
            self.normalize2 = normalization(input_dim)
            self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
            conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
        else:
            # A plain nn.Conv2d cannot serve as the shortcut factory here
            # because it needs a kernel size; the 1x1 helper is used instead.
            conv_shortcut = ncsn_conv1x1
            self.conv1 = ncsn_conv3x3(input_dim, output_dim)
            self.normalize2 = normalization(output_dim)
            self.conv2 = ncsn_conv3x3(output_dim, output_dim)

        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        h = self.conv1(self.non_linearity(self.normalize1(x)))
        h = self.conv2(self.non_linearity(self.normalize2(h)))

        identity_skip = self.output_dim == self.input_dim and self.resample is None
        skip = x if identity_skip else self.shortcut(x)
        return skip + h
|
508 |
+
|
509 |
+
|
510 |
+
###########################################################################
|
511 |
+
# Functions below are ported over from the DDPM codebase:
|
512 |
+
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
|
513 |
+
###########################################################################
|
514 |
+
|
515 |
+
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal timestep embedding (ported from DDPM).

    Args:
        timesteps: 1-D tensor of timesteps.
        embedding_dim: Width of the returned embedding.
        max_positions: Controls the smallest frequency; 10000 follows the
            transformer convention.

    Returns:
        Tensor of shape ``(len(timesteps), embedding_dim)``: sines in the
        first half, cosines in the second, and one zero column appended when
        ``embedding_dim`` is odd.
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # With a single frequency (embedding_dim of 2 or 3) ``half_dim - 1`` is 0.
    # Clamping the divisor avoids a ZeroDivisionError; the only exponent is
    # then 0, so the single frequency is 1.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step)
    emb = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
|
530 |
+
|
531 |
+
|
532 |
+
def _einsum(a, b, c, x, y):
    """Run ``torch.einsum`` on ``x`` and ``y`` with subscripts given as sequences.

    ``a`` and ``b`` hold the subscript characters of the two operands and
    ``c`` those of the output.
    """
    lhs = ''.join(a)
    rhs = ''.join(b)
    out = ''.join(c)
    return torch.einsum('{},{}->{}'.format(lhs, rhs, out), x, y)
|
535 |
+
|
536 |
+
|
537 |
+
def contract_inner(x, y):
    """Contract the last axis of ``x`` with the first axis of ``y``.

    Equivalent to ``tensordot(x, y, 1)``.
    """
    letters = string.ascii_lowercase
    x_rank, y_rank = len(x.shape), len(y.shape)
    x_axes = list(letters[:x_rank])
    y_axes = list(letters[x_rank:x_rank + y_rank])
    # Sharing one letter makes einsum sum over that pair of axes.
    y_axes[0] = x_axes[-1]
    out_axes = x_axes[:-1] + y_axes[1:]
    spec = '{},{}->{}'.format(''.join(x_axes), ''.join(y_axes), ''.join(out_axes))
    return torch.einsum(spec, x, y)
|
544 |
+
|
545 |
+
|
546 |
+
class NIN(nn.Module):
    """Per-position linear map over the channel axis of a 4-D tensor.

    Args:
        in_dim: Number of input channels.
        num_units: Number of output channels.
        init_scale: Scale passed to ``default_init`` for the weight matrix.
    """

    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        weight = default_init(scale=init_scale)((in_dim, num_units))
        self.W = nn.Parameter(weight, requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        """Apply the map to ``x`` with channels on axis 1; the layout is kept."""
        # Move channels last so they are the axis contracted with ``W``.
        channels_last = x.permute(0, 2, 3, 1)
        projected = contract_inner(channels_last, self.W) + self.b
        return projected.permute(0, 3, 1, 2)
|
556 |
+
|
557 |
+
|
558 |
+
class AttnBlock(nn.Module):
    """Self-attention block with a residual connection.

    Query/key scores are formed by contracting over channels and are
    softmax-normalised over all spatial positions. Uses a 32-group
    ``GroupNorm``.

    Args:
        channels: Number of channels of the input and output.
    """

    def __init__(self, channels):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=0.)

    def forward(self, x):
        """Return ``x`` plus the attention output computed from it."""
        B, C, H, W = x.shape
        normed = self.GroupNorm_0(x)
        query = self.NIN_0(normed)
        key = self.NIN_1(normed)
        value = self.NIN_2(normed)

        # One score per (query position, key position), scaled by C ** -0.5.
        scores = torch.einsum('bchw,bcij->bhwij', query, key) * (int(C) ** (-0.5))
        # Flatten the key positions so the softmax runs over all of them.
        flat = torch.reshape(scores, (B, H, W, H * W))
        attn = torch.reshape(F.softmax(flat, dim=-1), (B, H, W, H, W))
        mixed = torch.einsum('bhwij,bcij->bchw', attn, value)
        return x + self.NIN_3(mixed)
|
582 |
+
|
583 |
+
|
584 |
+
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv.

    Args:
        channels: Channel count used for the optional convolution.
        with_conv: If true, a ``ddpm_conv3x3`` is applied after resizing.
    """

    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(channels, channels)
        self.with_conv = with_conv

    def forward(self, x):
        """Return the 4-D input ``x`` with both spatial sizes doubled."""
        _, _, height, width = x.shape
        upsampled = F.interpolate(x, (height * 2, width * 2), mode='nearest')
        if not self.with_conv:
            return upsampled
        return self.Conv_0(upsampled)
|
597 |
+
|
598 |
+
|
599 |
+
class Downsample(nn.Module):
    """2x downsampling by a strided 3x3 conv or by 2x2 average pooling.

    Args:
        channels: Channel count used for the optional convolution.
        with_conv: If true, use a stride-2 ``ddpm_conv3x3``; otherwise
            average pooling.
    """

    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
        self.with_conv = with_conv

    def forward(self, x):
        """Return the 4-D input ``x`` with both spatial sizes halved."""
        batch, chans, height, width = x.shape
        if self.with_conv:
            # Pad right/bottom by one pixel to emulate 'SAME' padding for the
            # unpadded stride-2 convolution.
            x = self.Conv_0(F.pad(x, (0, 1, 0, 1)))
        else:
            x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)

        assert x.shape == (batch, chans, height // 2, width // 2)
        return x
|
617 |
+
|
618 |
+
|
619 |
+
class ResnetBlockDDPM(nn.Module):
    """The ResNet block used in DDPM.

    Args:
        act: Activation module or callable.
        in_ch: Number of input channels.
        out_ch: Number of output channels; defaults to ``in_ch``.
        temb_dim: Width of the time embedding; if None no projection layer
            is created.
        conv_shortcut: When channels change, use a 3x3 conv for the skip
            path instead of ``NIN``.
        dropout: Dropout probability before the second convolution.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
        super().__init__()
        out_ch = in_ch if out_ch is None else out_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
        self.act = act
        self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.data.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        # init_scale=0. is mapped to a tiny positive scale by default_init.
        self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        """Return ``skip(x) + residual(x, temb)`` for a 4-D input ``x``."""
        B, C, H, W = x.shape
        assert C == self.in_ch
        out_ch = self.out_ch if self.out_ch else self.in_ch
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Dropout_0(self.act(self.GroupNorm_1(h)))
        h = self.Conv_1(h)
        if C != out_ch:
            skip_layer = self.Conv_2 if self.conv_shortcut else self.NIN_0
            x = skip_layer(x)
        return x + h
|
fastgeco/backbones/ncsnpp_utils/layerspp.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
"""Layers for defining NCSN++.
|
18 |
+
"""
|
19 |
+
from . import layers
|
20 |
+
import score_models.layers.up_or_downsampling2d as up_or_down_sampling
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
# Module-level shorthands for helpers defined in the sibling ``layers`` module.
conv1x1 = layers.ddpm_conv1x1
conv3x3 = layers.ddpm_conv3x3
NIN = layers.NIN
default_init = layers.default_init
|
30 |
+
|
31 |
+
|
32 |
+
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Args:
        embedding_size: Number of random frequencies; the output has twice
            as many features (sines followed by cosines).
        scale: Multiplier on the standard-normal frequencies.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_size) * scale
        # Stored as a parameter with requires_grad=False.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        """Return ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` concatenated on the last axis."""
        angles = x[:, None] * self.W[None, :] * 2 * np.pi
        features = [torch.sin(angles), torch.cos(angles)]
        return torch.cat(features, dim=-1)
|
42 |
+
|
43 |
+
|
44 |
+
class Combine(nn.Module):
    """Combine information from skip connections.

    Args:
        dim1: Channel count of the tensor that is projected.
        dim2: Channel count after the 1x1 projection.
        method: ``'cat'`` to concatenate along channels, ``'sum'`` to add.
    """

    def __init__(self, dim1, dim2, method='cat'):
        super().__init__()
        self.Conv_0 = conv1x1(dim1, dim2)
        self.method = method

    def forward(self, x, y):
        """Project ``x`` with the 1x1 conv and merge it with ``y``.

        Raises:
            ValueError: If ``self.method`` is neither ``'cat'`` nor ``'sum'``.
        """
        projected = self.Conv_0(x)
        if self.method == 'sum':
            return projected + y
        if self.method == 'cat':
            return torch.cat([projected, y], dim=1)
        raise ValueError(f'Method {self.method} not recognized.')
|
60 |
+
|
61 |
+
|
62 |
+
class AttnBlockpp(nn.Module):
    """Self-attention block with a residual connection. Modified from DDPM.

    Args:
        channels: Number of channels of the input and output.
        skip_rescale: If true, the residual sum is divided by sqrt(2).
        init_scale: Init scale of the output projection ``NIN_3``.
    """

    def __init__(self, channels, skip_rescale=False, init_scale=0.):
        super().__init__()
        n_groups = min(channels // 4, 32)
        self.GroupNorm_0 = nn.GroupNorm(num_groups=n_groups, num_channels=channels,
                                        eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        """Return ``x`` plus the attention output, optionally rescaled."""
        B, C, H, W = x.shape
        normed = self.GroupNorm_0(x)
        query = self.NIN_0(normed)
        key = self.NIN_1(normed)
        value = self.NIN_2(normed)

        # One score per (query position, key position), scaled by C ** -0.5.
        scores = torch.einsum('bchw,bcij->bhwij', query, key) * (int(C) ** (-0.5))
        # Flatten the key positions so the softmax runs over all of them.
        flat = torch.reshape(scores, (B, H, W, H * W))
        attn = torch.reshape(F.softmax(flat, dim=-1), (B, H, W, H, W))
        mixed = torch.einsum('bhwij,bcij->bchw', attn, value)
        h = self.NIN_3(mixed)
        if self.skip_rescale:
            return (x + h) / np.sqrt(2.)
        return x + h
|
92 |
+
|
93 |
+
|
94 |
+
# class Upsample(nn.Module):
|
95 |
+
# def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
|
96 |
+
# fir_kernel=(1, 3, 3, 1)):
|
97 |
+
# super().__init__()
|
98 |
+
# out_ch = out_ch if out_ch else in_ch
|
99 |
+
# if not fir:
|
100 |
+
# if with_conv:
|
101 |
+
# self.Conv_0 = conv3x3(in_ch, out_ch)
|
102 |
+
# else:
|
103 |
+
# if with_conv:
|
104 |
+
# self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
|
105 |
+
# kernel=3, up=True,
|
106 |
+
# resample_kernel=fir_kernel,
|
107 |
+
# use_bias=True,
|
108 |
+
# kernel_init=default_init())
|
109 |
+
# self.fir = fir
|
110 |
+
# self.with_conv = with_conv
|
111 |
+
# self.fir_kernel = fir_kernel
|
112 |
+
# self.out_ch = out_ch
|
113 |
+
|
114 |
+
# def forward(self, x):
|
115 |
+
# B, C, H, W = x.shape
|
116 |
+
# if not self.fir:
|
117 |
+
# h = F.interpolate(x, (H * 2, W * 2), 'nearest')
|
118 |
+
# if self.with_conv:
|
119 |
+
# h = self.Conv_0(h)
|
120 |
+
# else:
|
121 |
+
# if not self.with_conv:
|
122 |
+
# h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
|
123 |
+
# else:
|
124 |
+
# h = self.Conv2d_0(x)
|
125 |
+
|
126 |
+
# return h
|
127 |
+
|
128 |
+
|
129 |
+
# class Downsample(nn.Module):
|
130 |
+
# def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
|
131 |
+
# fir_kernel=(1, 3, 3, 1)):
|
132 |
+
# super().__init__()
|
133 |
+
# out_ch = out_ch if out_ch else in_ch
|
134 |
+
# if not fir:
|
135 |
+
# if with_conv:
|
136 |
+
# self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
|
137 |
+
# else:
|
138 |
+
# if with_conv:
|
139 |
+
# self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
|
140 |
+
# kernel=3, down=True,
|
141 |
+
# resample_kernel=fir_kernel,
|
142 |
+
# use_bias=True,
|
143 |
+
# kernel_init=default_init())
|
144 |
+
# self.fir = fir
|
145 |
+
# self.fir_kernel = fir_kernel
|
146 |
+
# self.with_conv = with_conv
|
147 |
+
# self.out_ch = out_ch
|
148 |
+
|
149 |
+
# def forward(self, x):
|
150 |
+
# B, C, H, W = x.shape
|
151 |
+
# if not self.fir:
|
152 |
+
# if self.with_conv:
|
153 |
+
# x = F.pad(x, (0, 1, 0, 1))
|
154 |
+
# x = self.Conv_0(x)
|
155 |
+
# else:
|
156 |
+
# x = F.avg_pool2d(x, 2, stride=2)
|
157 |
+
# else:
|
158 |
+
# if not self.with_conv:
|
159 |
+
# x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
|
160 |
+
# else:
|
161 |
+
# x = self.Conv2d_0(x)
|
162 |
+
|
163 |
+
# return x
|
164 |
+
|
165 |
+
|
166 |
+
class ResnetBlockDDPMpp(nn.Module):
    """Residual block adapted from DDPM.

    The residual branch is GroupNorm -> act -> 3x3 conv -> (optional
    time-embedding bias) -> GroupNorm -> act -> dropout -> 3x3 conv. The skip
    branch is projected only when the channel count changes.

    Args:
        act: Activation callable applied to feature maps and to `temb`.
        in_ch: Number of input channels.
        out_ch: Number of output channels; falls back to `in_ch` when falsy.
        temb_dim: Size of the time embedding; `Dense_0` is only created when
            this is not None.
        conv_shortcut: Project the skip path with a 3x3 conv instead of NIN.
        dropout: Dropout probability used before the second conv.
        skip_rescale: Divide the summed output by sqrt(2) when True.
        init_scale: Forwarded to the second conv's initializer.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
                 dropout=0.1, skip_rescale=False, init_scale=0.):
        super().__init__()
        out_ch = out_ch or in_ch

        # Submodule creation order is kept fixed: it determines both the
        # state_dict layout and the order in which the RNG is consumed.
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.data.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            # Only one of the two possible skip projections is instantiated.
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        """Apply the block to `x`, optionally conditioned on `temb`."""
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        if temb is not None:
            # Broadcast the projected embedding over the two trailing axes.
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Conv_1(self.Dropout_0(h))

        if x.shape[1] != self.out_ch:
            x = self.Conv_2(x) if self.conv_shortcut else self.NIN_0(x)

        summed = x + h
        if self.skip_rescale:
            return summed / np.sqrt(2.)
        return summed
|
210 |
+
|
211 |
+
|
212 |
+
class ResnetBlockBigGANpp(nn.Module):
    """BigGAN-style residual block with optional 2x up- or down-sampling.

    Both the residual branch and the skip branch are resampled the same way;
    the skip branch is additionally passed through a 1x1 conv whenever the
    channel count changes or any resampling is requested.

    Args:
        act: Activation callable applied to feature maps and to `temb`.
        in_ch: Number of input channels.
        out_ch: Number of output channels; falls back to `in_ch` when falsy.
        temb_dim: Size of the time embedding; `Dense_0` is only created when
            this is not None.
        up: Upsample by a factor of 2 (takes precedence over `down`).
        down: Downsample by a factor of 2.
        dropout: Dropout probability used before the second conv.
        fir: Use the FIR-kernel resamplers instead of the naive ones.
        fir_kernel: Kernel passed to the FIR resamplers.
        skip_rescale: Divide the summed output by sqrt(2) when True.
        init_scale: Forwarded to the second conv's initializer.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
                 dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
                 skip_rescale=True, init_scale=0.):
        super().__init__()
        out_ch = out_ch or in_ch

        # Submodule creation order is kept fixed: it determines both the
        # state_dict layout and the order in which the RNG is consumed.
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def _resample(self, feat):
        """Resample `feat` by 2x according to the `up`/`down`/`fir` flags."""
        if self.up:
            if self.fir:
                return up_or_down_sampling.upsample_2d(feat, self.fir_kernel, factor=2)
            return up_or_down_sampling.naive_upsample_2d(feat, factor=2)
        if self.down:
            if self.fir:
                return up_or_down_sampling.downsample_2d(feat, self.fir_kernel, factor=2)
            return up_or_down_sampling.naive_downsample_2d(feat, factor=2)
        return feat

    def forward(self, x, temb=None):
        """Apply the block to `x`, optionally conditioned on `temb`."""
        h = self.act(self.GroupNorm_0(x))

        # Residual branch first, then the skip branch, with the same resampler.
        h = self._resample(h)
        x = self._resample(x)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Conv_1(self.Dropout_0(h))

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        summed = x + h
        if self.skip_rescale:
            return summed / np.sqrt(2.)
        return summed
|
fastgeco/backbones/ncsnpp_utils/normalization.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Normalization layers."""
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch
|
19 |
+
import functools
|
20 |
+
|
21 |
+
|
22 |
+
def get_normalization(config, conditional=False):
    """Look up the normalization layer class named by `config.model.normalization`.

    Args:
        config: Object exposing `model.normalization` (and `model.num_classes`
            when `conditional` is True).
        conditional: Select from the class-conditional layers instead.

    Returns:
        A layer class, or a `functools.partial` with `num_classes` bound for
        the conditional 'InstanceNorm++' case.

    Raises:
        NotImplementedError: Conditional variant requested for any name other
            than 'InstanceNorm++'.
        ValueError: Unknown name in the unconditional case.
    """
    norm = config.model.normalization

    if conditional:
        if norm != 'InstanceNorm++':
            raise NotImplementedError(f'{norm} not implemented yet.')
        return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)

    if norm == 'InstanceNorm':
        return nn.InstanceNorm2d
    if norm == 'InstanceNorm++':
        return InstanceNorm2dPlus
    if norm == 'VarianceNorm':
        return VarianceNorm2d
    if norm == 'GroupNorm':
        return nn.GroupNorm
    raise ValueError('Unknown normalization: %s' % norm)
|
41 |
+
|
42 |
+
|
43 |
+
class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose scale (and optional shift) is looked up from a class label.

    Args:
        num_features: Number of channels normalized.
        num_classes: Number of rows in the label embedding table.
        bias: When True the embedding also carries a per-class shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        if not self.bias:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()
        else:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            table = self.embed.weight.data
            # Scale columns are drawn from `uniform_()` with default bounds.
            table[:, :num_features].uniform_()
            # Shift columns start at zero.
            table[:, num_features:].zero_()

    def forward(self, x, y):
        """Normalize `x` and apply the affine parameters selected by labels `y`."""
        normed = self.bn(x)
        bshape = (-1, self.num_features, 1, 1)
        if not self.bias:
            return self.embed(y).view(*bshape) * normed
        gamma, beta = self.embed(y).chunk(2, dim=1)
        return gamma.view(*bshape) * normed + beta.view(*bshape)
|
66 |
+
|
67 |
+
|
68 |
+
class ConditionalInstanceNorm2d(nn.Module):
    """Instance norm whose scale (and optional shift) is looked up from a class label.

    Args:
        num_features: Number of channels normalized.
        num_classes: Number of rows in the label embedding table.
        bias: When True the embedding also carries a per-class shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if not bias:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()
        else:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            table = self.embed.weight.data
            # Scale columns are drawn from `uniform_()` with default bounds.
            table[:, :num_features].uniform_()
            # Shift columns start at zero.
            table[:, num_features:].zero_()

    def forward(self, x, y):
        """Normalize `x` and apply the affine parameters selected by labels `y`."""
        normed = self.instance_norm(x)
        bshape = (-1, self.num_features, 1, 1)
        if not self.bias:
            return self.embed(y).view(*bshape) * normed
        gamma, beta = self.embed(y).chunk(2, dim=-1)
        return gamma.view(*bshape) * normed + beta.view(*bshape)
|
91 |
+
|
92 |
+
|
93 |
+
class ConditionalVarianceNorm2d(nn.Module):
    """Divide each channel by its spatial standard deviation, then apply a per-class scale.

    Args:
        num_features: Number of channels.
        num_classes: Number of rows in the label embedding table.
        bias: Stored on the module; not used by `forward`.
    """

    def __init__(self, num_features, num_classes, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.embed = nn.Embedding(num_classes, num_features)
        # Per-class scales start close to one.
        self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Scale `x` by 1/std over axes (2, 3) and by the scale selected by `y`."""
        spatial_var = torch.var(x, dim=(2, 3), keepdim=True)
        normed = x / torch.sqrt(spatial_var + 1e-5)
        scale = self.embed(y).view(-1, self.num_features, 1, 1)
        return scale * normed
|
108 |
+
|
109 |
+
|
110 |
+
class VarianceNorm2d(nn.Module):
    """Divide each channel by its spatial standard deviation, then apply a learned scale.

    Args:
        num_features: Number of channels.
        bias: Stored on the module; not used by `forward`.
    """

    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.alpha = nn.Parameter(torch.zeros(num_features))
        # Learned per-channel scale starts close to one.
        self.alpha.data.normal_(1, 0.02)

    def forward(self, x):
        """Scale `x` by 1/std over axes (2, 3) and by `alpha`."""
        spatial_var = torch.var(x, dim=(2, 3), keepdim=True)
        normed = x / torch.sqrt(spatial_var + 1e-5)
        return self.alpha.view(-1, self.num_features, 1, 1) * normed
|
124 |
+
|
125 |
+
|
126 |
+
class ConditionalNoneNorm2d(nn.Module):
    """Per-class affine transform with no normalization of the input.

    Args:
        num_features: Number of channels.
        num_classes: Number of rows in the label embedding table.
        bias: When True the embedding also carries a per-class shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        if not bias:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()
        else:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            table = self.embed.weight.data
            # Scale columns are drawn from `uniform_()` with default bounds.
            table[:, :num_features].uniform_()
            # Shift columns start at zero.
            table[:, num_features:].zero_()

    def forward(self, x, y):
        """Apply the scale (and shift, if enabled) selected by labels `y` to `x`."""
        bshape = (-1, self.num_features, 1, 1)
        if not self.bias:
            return self.embed(y).view(*bshape) * x
        gamma, beta = self.embed(y).chunk(2, dim=-1)
        return gamma.view(*bshape) * x + beta.view(*bshape)
|
147 |
+
|
148 |
+
|
149 |
+
class NoneNorm2d(nn.Module):
    """Identity layer exposing the same constructor shape as the other norm layers.

    Both constructor arguments are accepted and ignored.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()

    def forward(self, x):
        """Return `x` unchanged."""
        return x
|
155 |
+
|
156 |
+
|
157 |
+
class InstanceNorm2dPlus(nn.Module):
    """Instance norm that re-injects normalized per-channel mean information.

    Besides the usual instance normalization, the per-channel spatial means
    are standardized across the channel axis and added back, scaled by the
    learned `alpha`. The result is then scaled by `gamma` and, when `bias` is
    True, shifted by `beta`.

    Args:
        num_features: Number of channels.
        bias: Create and apply the learned shift `beta`.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        # Both scales start close to one.
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Normalize `x`; indexing below treats it as 4-D with channels on axis 1."""
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        # The unbiased variance of a single value is NaN, which would poison
        # the whole output for one-channel inputs; use the biased estimator
        # (exactly 0) in that case. Multi-channel behavior is unchanged.
        v = torch.var(means, dim=-1, keepdim=True, unbiased=means.shape[-1] > 1)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)

        h = h + means[..., None, None] * self.alpha[..., None, None]
        out = self.gamma.view(-1, self.num_features, 1, 1) * h
        if self.bias:
            out = out + self.beta.view(-1, self.num_features, 1, 1)
        return out
|
184 |
+
|
185 |
+
|
186 |
+
class ConditionalInstanceNorm2dPlus(nn.Module):
    """Class-conditional instance norm that re-injects per-channel mean information.

    The per-channel spatial means are standardized across the channel axis and
    added back to the instance-normalized input, scaled by a per-class
    `alpha`. The result is scaled by a per-class `gamma` and, when `bias` is
    True, shifted by a per-class `beta`. All three come from one embedding
    table indexed by the label.

    Args:
        num_features: Number of channels.
        num_classes: Number of rows in the label embedding table.
        bias: When True the embedding also carries a per-class shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 3)
            # gamma and alpha columns: initialise at N(1, 0.02)
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)
            # beta columns: initialise at 0
            self.embed.weight.data[:, 2 * num_features:].zero_()
        else:
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Normalize `x` using the parameters selected by labels `y`."""
        means = torch.mean(x, dim=(2, 3))
        m = torch.mean(means, dim=-1, keepdim=True)
        # The unbiased variance of a single value is NaN, which would poison
        # the whole output for one-channel inputs; use the biased estimator
        # (exactly 0) in that case. Multi-channel behavior is unchanged.
        v = torch.var(means, dim=-1, keepdim=True, unbiased=means.shape[-1] > 1)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)

        if self.bias:
            gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
        else:
            gamma, alpha = self.embed(y).chunk(2, dim=-1)
            beta = None

        h = h + means[..., None, None] * alpha[..., None, None]
        out = gamma.view(-1, self.num_features, 1, 1) * h
        if beta is not None:
            out = out + beta.view(-1, self.num_features, 1, 1)
        return out
|
fastgeco/backbones/ncsnpp_utils/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""All functions and modules related to model definition.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
from ...sdes import OUVESDE, OUVPSDE
|
23 |
+
|
24 |
+
|
25 |
+
# Registry filled by `register_model`: maps a name to the registered class.
_MODELS = {}


def register_model(cls=None, *, name=None):
    """Register a model class, usable as `@register_model` or `@register_model(name=...)`.

    Args:
        cls: The class to register when used as a bare decorator.
        name: Registry key; defaults to the class's `__name__`.

    Returns:
        The class itself (bare use) or a decorator that registers and
        returns the class (parameterized use).

    Raises:
        ValueError: If the key is already present in the registry.
    """

    def _register(model_cls):
        key = model_cls.__name__ if name is None else name
        if key in _MODELS:
            raise ValueError(f'Already registered model with name: {key}')
        _MODELS[key] = model_cls
        return model_cls

    return _register if cls is None else _register(cls)
|
45 |
+
|
46 |
+
|
47 |
+
def get_model(name):
    """Return the class registered under `name`; raises KeyError if absent."""
    model_cls = _MODELS[name]
    return model_cls
|
49 |
+
|
50 |
+
|
51 |
+
def get_sigmas(sigma_min, sigma_max, num_scales):
    """Get sigmas --- a decreasing, log-uniformly spaced set of noise levels.

    Args:
      sigma_min: Smallest noise level (last entry of the result).
      sigma_max: Largest noise level (first entry of the result).
      num_scales: Number of noise levels to generate.
    Returns:
      sigmas: a numpy array of `num_scales` noise levels
    """
    log_high = np.log(sigma_max)
    log_low = np.log(sigma_min)
    # Interpolate linearly in log-space, then map back.
    log_sigmas = np.linspace(log_high, log_low, num_scales)
    return np.exp(log_sigmas)
|
62 |
+
|
63 |
+
|
64 |
+
def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    Args:
        config: Object exposing `model.beta_min`, `model.beta_max` and
            `model.num_scales`.

    Returns:
        A dict with the beta/alpha schedules (float64 numpy arrays), the
        rescaled `beta_min`/`beta_max`, and the number of timesteps.
    """
    # parameters need to be adapted if number of time steps differs from 1000
    n_steps = 1000
    model_cfg = config.model
    beta_start = model_cfg.beta_min / model_cfg.num_scales
    beta_end = model_cfg.beta_max / model_cfg.num_scales

    betas = np.linspace(beta_start, beta_end, n_steps, dtype=np.float64)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    params = {}
    params['betas'] = betas
    params['alphas'] = alphas
    params['alphas_cumprod'] = alphas_cumprod
    params['sqrt_alphas_cumprod'] = np.sqrt(alphas_cumprod)
    params['sqrt_1m_alphas_cumprod'] = np.sqrt(1. - alphas_cumprod)
    params['beta_min'] = beta_start * (n_steps - 1)
    params['beta_max'] = beta_end * (n_steps - 1)
    params['num_diffusion_timesteps'] = n_steps
    return params
|
87 |
+
|
88 |
+
|
89 |
+
def create_model(config):
    """Create the score model.

    Looks up the class registered under `config.model.name`, instantiates it
    with `config`, moves it to `config.device` and wraps it in
    `torch.nn.DataParallel`.
    """
    model_cls = get_model(config.model.name)
    network = model_cls(config).to(config.device)
    return torch.nn.DataParallel(network)
|
96 |
+
|
97 |
+
|
98 |
+
def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
            for different models.

        Returns:
          The model output for `(x, labels)`.
        """
        # Switch the model into the requested mode on every call.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
|
128 |
+
|
129 |
+
|
130 |
+
def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
      A score function.

    Raises:
      NotImplementedError: If `sde` is neither an `OUVPSDE` nor an `OUVESDE`.
    """
    model_fn = get_model_fn(model, train=train)

    if not isinstance(sde, (OUVPSDE, OUVESDE)):
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    if isinstance(sde, OUVPSDE):
        def score_fn(x, t):
            # For VP-trained models, t=0 corresponds to the lowest noise level.
            # The maximum value of time embedding is assumed to 999 for
            # continuously-trained models.
            labels = (t * 999) if continuous else (t * (sde.N - 1))
            raw = model_fn(x, labels)
            if continuous:
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
            # Scale neural network output by standard deviation and flip sign
            return -raw / std[:, None, None, None]
    else:
        def score_fn(x, t):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level
                labels = sde.T - t
                labels *= sde.N - 1
                labels = torch.round(labels).long()
            return model_fn(x, labels)

    return score_fn
|
180 |
+
|
181 |
+
|
182 |
+
def to_flattened_numpy(x):
    """Detach tensor `x`, move it to the CPU and return it as a 1-D numpy array."""
    as_array = x.detach().cpu().numpy()
    return as_array.reshape((-1,))
|
185 |
+
|
186 |
+
|
187 |
+
def from_flattened_numpy(x, shape):
    """Reshape the flat numpy array `x` to `shape` and wrap it as a torch tensor."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
|
fastgeco/backbones/shared.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from geco.util.registry import Registry
|
8 |
+
|
9 |
+
|
10 |
+
BackboneRegistry = Registry("Backbone")
|
11 |
+
|
12 |
+
|
13 |
+
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps.

    Args:
        embed_dim: Size of the returned embedding along the last axis.
        scale: Multiplier applied to the randomly drawn frequencies.
        complex_valued: Return complex exponentials instead of concatenated
            sin/cos features.
    """

    def __init__(self, embed_dim, scale=16, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        # Real-valued output concatenates sin and cos of every frequency, so
        # only half as many frequencies are needed to reach `embed_dim`. A
        # complex number carries both parts at once, so no halving is needed.
        n_freqs = embed_dim if complex_valued else embed_dim // 2
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(n_freqs) * scale, requires_grad=False)

    def forward(self, t):
        """Embed the 1-D batch of time steps `t`."""
        phase = t[:, None] * self.W[None, :] * 2*np.pi
        if not self.complex_valued:
            return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return torch.exp(1j * phase)
|
35 |
+
|
36 |
+
|
37 |
+
class DiffusionStepEmbedding(nn.Module):
    """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017.

    Args:
        embed_dim: Size of the returned embedding along the last axis.
        complex_valued: Return complex exponentials instead of concatenated
            sin/cos features.
    """

    def __init__(self, embed_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
            # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
            # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
            # and this halving is not necessary.
            embed_dim = embed_dim // 2
        self.embed_dim = embed_dim

    def forward(self, t):
        """Embed the 1-D batch of time steps `t`."""
        # Frequencies are spaced geometrically from 10**0 up to 10**4. With a
        # single frequency the spacing denominator `embed_dim - 1` is zero and
        # 0/0 would turn the whole embedding into NaN, so clamp it to 1; the
        # lone frequency is then 10**0 == 1.
        denom = max(self.embed_dim - 1, 1)
        fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / denom)
        inner = t[:, None] * fac[None, :]
        if self.complex_valued:
            return torch.exp(1j * inner)
        else:
            return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
|
58 |
+
|
59 |
+
|
60 |
+
class ComplexLinear(nn.Module):
    """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`.

    In complex mode two real `nn.Linear` layers (`re`, `im`) are combined with
    the complex multiplication rule; note that each carries its own bias.
    """

    def __init__(self, input_dim, output_dim, complex_valued):
        super().__init__()
        self.complex_valued = complex_valued
        if not self.complex_valued:
            self.lin = nn.Linear(input_dim, output_dim)
        else:
            self.re = nn.Linear(input_dim, output_dim)
            self.im = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the layer to `x` (complex tensor in complex mode, real otherwise)."""
        if not self.complex_valued:
            return self.lin(x)
        real_part = self.re(x.real) - self.im(x.imag)
        imag_part = self.re(x.imag) + self.im(x.real)
        return real_part + 1j*imag_part
|
76 |
+
|
77 |
+
|
78 |
+
class FeatureMapDense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps.

    The dense output gets two trailing singleton axes appended so it can be
    broadcast against feature maps.
    """

    def __init__(self, input_dim, output_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

    def forward(self, x):
        """Project `x` and append two singleton axes."""
        projected = self.dense(x)
        return projected[..., None, None]
|
88 |
+
|
89 |
+
|
90 |
+
def torch_complex_from_reim(re, im):
    """Combine real part `re` and imaginary part `im` into one complex tensor."""
    # Stack the parts on a new trailing axis of size 2, then reinterpret it.
    stacked = torch.stack([re, im], dim=-1)
    return torch.view_as_complex(stacked)
|
92 |
+
|
93 |
+
|
94 |
+
class ArgsComplexMultiplicationWrapper(nn.Module):
    """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

    Make a complex-valued module `F` from a real-valued module `f` by applying
    complex multiplication rules:

    F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))

    where `f1`, `f2` are instances of `f` that do *not* share weights
    (`re_module` and `im_module` respectively).

    Args:
        module_cls (callable): A class or function that returns a Torch module/functional.
            Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
            to construct the real and imaginary component modules.
    """

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        """Apply the complex module to `x`; extra arguments go to every sub-call."""
        real_part = (
            self.re_module(x.real, *args, **kwargs)
            - self.im_module(x.imag, *args, **kwargs)
        )
        imag_part = (
            self.re_module(x.imag, *args, **kwargs)
            + self.im_module(x.real, *args, **kwargs)
        )
        # Stack the parts on a new trailing axis of size 2 and reinterpret it
        # as a complex tensor.
        return torch.view_as_complex(torch.stack([real_part, imag_part], dim=-1))


# Complex-valued counterparts of the 2-D (transposed) convolution layers.
ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
|
fastgeco/model.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import time
|
3 |
+
from math import ceil
|
4 |
+
import warnings
|
5 |
+
import numpy as np
|
6 |
+
# from asteroid.losses.sdr import SingleSrcNegSDR
|
7 |
+
import torch
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from torch_ema import ExponentialMovingAverage
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from geco import sampling
|
12 |
+
from geco.sdes import SDERegistry
|
13 |
+
from fastgeco.backbones import BackboneRegistry
|
14 |
+
from geco.util.inference import evaluate_model2
|
15 |
+
from geco.util.other import pad_spec
|
16 |
+
import numpy as np
|
17 |
+
import matplotlib.pyplot as plt
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
class ScoreModel(pl.LightningModule):
|
22 |
+
@staticmethod
|
23 |
+
def add_argparse_args(parser):
|
24 |
+
parser.add_argument("--lr", type=float, default=1e-5, help="The learning rate (1e-4 by default)")
|
25 |
+
parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
|
26 |
+
parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
|
27 |
+
parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
|
28 |
+
parser.add_argument("--loss_type", type=str, default="mse", help="The type of loss function to use.")
|
29 |
+
parser.add_argument("--loss_abs_exponent", type=float, default=0.5, help="magnitude transformation in the loss term")
|
30 |
+
parser.add_argument("--output_scale", type=str, choices=('sigma', 'time'), default= 'time', help="backbone model scale before last output layer")
|
31 |
+
return parser
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, loss_abs_exponent=0.5,
|
35 |
+
num_eval_files=20, loss_type='mse', data_module_cls=None, output_scale='time', inference_N=1,
|
36 |
+
inference_start=0.5, **kwargs
|
37 |
+
):
|
38 |
+
"""
|
39 |
+
Create a new ScoreModel.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
backbone: Backbone DNN that serves as a score-based model.
|
43 |
+
sde: The SDE that defines the diffusion process.
|
44 |
+
lr: The learning rate of the optimizer. (1e-4 by default).
|
45 |
+
ema_decay: The decay constant of the parameter EMA (0.999 by default).
|
46 |
+
t_eps: The minimum time to practically run for to avoid issues very close to zero (1e-5 by default).
|
47 |
+
loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'
|
48 |
+
"""
|
49 |
+
super().__init__()
|
50 |
+
# Initialize Backbone DNN
|
51 |
+
dnn_cls = BackboneRegistry.get_by_name(backbone)
|
52 |
+
self.dnn = dnn_cls(**kwargs)
|
53 |
+
# Initialize SDE
|
54 |
+
sde_cls = SDERegistry.get_by_name(sde)
|
55 |
+
self.sde = sde_cls(**kwargs)
|
56 |
+
# Store hyperparams and save them
|
57 |
+
self.lr = lr
|
58 |
+
self.ema_decay = ema_decay
|
59 |
+
self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
|
60 |
+
self._error_loading_ema = False
|
61 |
+
self.t_eps = t_eps
|
62 |
+
self.loss_type = loss_type
|
63 |
+
self.num_eval_files = num_eval_files
|
64 |
+
self.loss_abs_exponent = loss_abs_exponent
|
65 |
+
self.output_scale = output_scale
|
66 |
+
self.save_hyperparameters(ignore=['no_wandb'])
|
67 |
+
self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
|
68 |
+
self.inference_N = inference_N
|
69 |
+
self.inference_start = inference_start
|
70 |
+
|
71 |
+
# self.si_snr = SingleSrcNegSDR("sisdr", reduction='mean', zero_mean=False)
|
72 |
+
|
73 |
+
def configure_optimizers(self):
    """Return an Adam optimizer over all parameters, using the stored learning rate."""
    return torch.optim.Adam(self.parameters(), lr=self.lr)
|
76 |
+
|
77 |
+
def optimizer_step(self, *args, **kwargs):
    """Run the regular optimizer step, then refresh the EMA of the parameters."""
    super().optimizer_step(*args, **kwargs)
    # The EMA must see the freshly updated weights, so it is updated last.
    current_params = self.parameters()
    self.ema.update(current_params)
|
81 |
+
|
82 |
+
# on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
|
83 |
+
def on_load_checkpoint(self, checkpoint):
    """Restore the EMA state from `checkpoint`, or flag and warn when it is missing."""
    ema_state = checkpoint.get('ema')
    if ema_state is None:
        # Remember the failure so that train()/eval() skip the EMA weight swap.
        self._error_loading_ema = True
        warnings.warn("EMA state_dict not found in checkpoint!")
        return
    self.ema.load_state_dict(ema_state)
|
90 |
+
|
91 |
+
def on_save_checkpoint(self, checkpoint):
    """Store the EMA state under the 'ema' key of `checkpoint`."""
    ema_state = self.ema.state_dict()
    checkpoint['ema'] = ema_state
|
93 |
+
|
94 |
+
def train(self, mode=True, no_ema=False):
    """Set training/evaluation mode and swap EMA weights in or out accordingly.

    Args:
        mode: True for training mode, False for evaluation mode. Defaults to
            True so that the conventional no-argument call `module.train()`
            keeps working.
        no_ema: When True, entering evaluation mode does not copy the EMA
            weights over the current parameters.

    Returns:
        Whatever the parent class's `train` returns.
    """
    res = super().train(mode)  # call the standard `train` method with the given mode
    if self._error_loading_ema:
        # No usable EMA state was loaded; leave the parameters untouched.
        return res

    if not mode and not no_ema:
        # eval: keep a copy of the current params, then evaluate with the EMA weights
        self.ema.store(self.parameters())
        self.ema.copy_to(self.parameters())
    elif self.ema.collected_params is not None:
        # train: put back the parameters that were stored when entering eval
        self.ema.restore(self.parameters())
    return res
|
106 |
+
|
107 |
+
def eval(self, no_ema=False):
    """Switch to evaluation mode by delegating to `train` with mode False."""
    return self.train(False, no_ema=no_ema)
|
109 |
+
|
110 |
+
|
111 |
+
def sisnr(self, est, ref, eps = 1e-8):
    """Return the negative scale-invariant SNR (in dB) of `est` with respect to `ref`.

    Both signals are made zero-mean along the last dimension, `est` is
    projected onto `ref`, and the ratio of projected energy to residual energy
    is converted to dB. The result is negated so that it can be minimised.

    Args:
        est: Estimated signal; the last dimension is reduced.
        ref: Reference signal, broadcast-compatible with `est`.
        eps: Small constant guarding the divisions and the logarithm.

    Returns:
        Tensor of negative SI-SNR values with the last dimension kept at size 1.
    """
    est = est - torch.mean(est, dim = -1, keepdim = True)
    ref = ref - torch.mean(ref, dim = -1, keepdim = True)
    # eps in the denominator keeps the projection finite for an all-zero
    # (silent) reference; without it the result is NaN.
    ref_energy = torch.sum(ref * ref, dim = -1, keepdim = True) + eps
    est_p = (torch.sum(est * ref, dim = -1, keepdim = True) * ref) / ref_energy
    est_v = est - est_p
    est_sisnr = 10 * torch.log10((torch.sum(est_p * est_p, dim = -1, keepdim = True) + eps) / (torch.sum(est_v * est_v, dim = -1, keepdim = True) + eps))
    return -est_sisnr
|
118 |
+
|
119 |
+
|
120 |
+
def _loss(self, wav_x_tm1, wav_gt):
|
121 |
+
if self.loss_type == 'default':
|
122 |
+
min_leng = min(wav_x_tm1.shape[-1], wav_gt.shape[-1])
|
123 |
+
wav_x_tm1 = wav_x_tm1.squeeze(1)[:,:min_leng]
|
124 |
+
wav_gt = wav_gt.squeeze(1)[:,:min_leng]
|
125 |
+
loss = torch.mean(self.sisnr(wav_x_tm1, wav_gt))
|
126 |
+
else:
|
127 |
+
raise RuntimeError(f'{self.loss_type} loss not defined')
|
128 |
+
|
129 |
+
return loss
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
def euler_step(self, X, X_t, Y, M, t, dt):
    """Advance `X_t` by one reverse Euler-Maruyama step of size `dt` at time `t`.

    The drift and diffusion come from `self.sde.sde`, the score from
    `self.forward`. `X` is only used as the template for the Gaussian noise.
    `dt` must be a tensor because `torch.sqrt` is applied to it.
    """
    drift, diffusion = self.sde.sde(X_t, t, Y)
    batch_t = torch.ones(Y.shape[0], device=Y.device) * t
    # The score is evaluated before the noise is drawn, as in the original ordering.
    score = self.forward(X_t, batch_t, Y, M, batch_t[:, None, None, None])
    mean_prev = X_t - (drift - diffusion**2 * score) * dt
    noise = torch.randn_like(X)
    return mean_prev + noise * diffusion * torch.sqrt(dt)
|
141 |
+
|
142 |
+
|
143 |
+
def training_step(self, batch, batch_idx):
    """Run a short reverse diffusion and return the loss of one selected step.

    The reverse process starts from `Y` plus Gaussian noise at a randomly drawn
    start time and is integrated with Euler-Maruyama steps. All steps before
    the selected `stop_iteration` run without gradients; at `stop_iteration`
    the step mean is compared with the SDE marginal mean in the waveform
    domain, and the loop stops.

    Reads self.t_rsp_min, self.t_rsp_max, self.N_min, self.N_max and
    self.stop_iteration_random, which are assigned in `add_para`.

    Args:
        batch: Tuple (X, Y, M). NOTE(review): presumably target, mixture and an
            additional conditioning input -- confirm against the data module.
        batch_idx: Unused.

    Returns:
        The loss of the selected step, also logged as 'train_loss'.

    Raises:
        RuntimeError: If self.stop_iteration_random is neither "random" nor "last".
    """
    X, Y, M = batch

    # Draw the start time of the reverse process and the number of reverse steps.
    reverse_start_time = random.uniform(self.t_rsp_min, self.t_rsp_max)
    N_reverse = random.randint(self.N_min, self.N_max)

    if self.stop_iteration_random == "random":
        stop_iteration = random.randint(0, N_reverse-1)
    elif self.stop_iteration_random == "last":
        #Used in publication. This means that only the last step is used for updating weights.
        stop_iteration = N_reverse-1
    else:
        raise RuntimeError(f'{self.stop_iteration_random} not defined')

    # Time grid from the start time down to t_eps (inclusive), N_reverse points.
    timesteps = torch.linspace(reverse_start_time, self.t_eps, N_reverse, device=Y.device)

    #prior sampling starting from reverse_start_time
    std = self.sde._std(reverse_start_time*torch.ones((Y.shape[0],), device=Y.device))
    z = torch.randn_like(Y)
    X_t = Y + z * std[:, None, None, None]

    #reverse steps by Euler Maruyama
    for i in range(len(timesteps)):
        t = timesteps[i]
        if i != len(timesteps) - 1:
            dt = t - timesteps[i+1]
        else:
            # Last grid point: step by the remaining time, i.e. down to time 0.
            dt = timesteps[-1]

        if i != stop_iteration:
            with torch.no_grad():
                #take Euler step here
                X_t = self.euler_step(X, X_t, Y, M, t, dt)
        else:
            #take a Euler step and compute loss
            f, g = self.sde.sde(X_t, t, Y)
            vec_t = torch.ones(Y.shape[0], device=Y.device) * t
            score = self.forward(X_t, vec_t, Y, M, vec_t[:,None,None,None])
            mean_x_tm1 = X_t - (f - g**2*score)*dt #mean of x t minus 1 = mu(x_{t-1})
            # Reference: marginal mean of the SDE at the time reached by this step.
            mean_gt, _ = self.sde.marginal_prob(X, torch.ones(Y.shape[0], device=Y.device) * (t-dt), Y)

            # NOTE(review): squeeze() without a dim also removes a batch dimension of size 1 -- confirm intended.
            wav_gt = self.to_audio(mean_gt.squeeze())
            wav_x_tm1 = self.to_audio(mean_x_tm1.squeeze())
            loss = self._loss(wav_x_tm1, wav_gt)
            break

    self.log('train_loss', loss, on_step=True, on_epoch=True)
    return loss
|
191 |
+
|
192 |
+
|
193 |
+
def validation_step(self, batch, batch_idx):
    """Evaluate enhancement quality on the first validation batch only.

    When `batch_idx` is 0 and `self.num_eval_files` is non-zero, the evaluation
    routine is run and PESQ, SI-SDR, ESTOI and the validation loss are logged
    per epoch. For every other batch nothing is computed.

    Args:
        batch: Unused; the evaluation routine is given the model itself.
        batch_idx: Index of the validation batch.

    Returns:
        The validation loss when the evaluation ran, otherwise None.
    """
    # Defined up front so the return value is well-defined when evaluation is skipped.
    loss = None
    # Evaluate speech enhancement performance, compute loss only for a few val data
    if batch_idx == 0 and self.num_eval_files != 0:
        pesq, si_sdr, estoi, loss = evaluate_model2(self, self.num_eval_files, self.inference_N, inference_start=self.inference_start)
        self.log('pesq', pesq, on_step=False, on_epoch=True)
        self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
        self.log('estoi', estoi, on_step=False, on_epoch=True)
        self.log('valid_loss', loss, on_step=False, on_epoch=True)
    return loss
|
202 |
+
|
203 |
+
|
204 |
+
def forward(self, x, t, y, m, divide_scale):
    """Evaluate the backbone on `x`, `y` and `m` stacked along dim 1 and negate its output."""
    # x, y and m become adjacent channels of a single backbone input
    stacked = torch.cat([x, y, m], dim=1)

    # the minus is most likely unimportant here - taken from Song's repo
    raw = self.dnn(stacked, t, divide_scale)
    return -raw
|
211 |
+
|
212 |
+
def to(self, *args, **kwargs):
    """Override PyTorch .to() to also transfer the EMA of the model weights"""
    # Move the EMA first, then let the parent class move the module itself.
    self.ema.to(*args, **kwargs)
    moved = super().to(*args, **kwargs)
    return moved
|
216 |
+
|
217 |
+
|
218 |
+
def train_dataloader(self):
    """Return the training dataloader provided by the data module."""
    loader = self.data_module.train_dataloader()
    return loader
|
220 |
+
|
221 |
+
def val_dataloader(self):
    """Return the validation dataloader provided by the data module."""
    loader = self.data_module.val_dataloader()
    return loader
|
223 |
+
|
224 |
+
def test_dataloader(self):
|
225 |
+
return self.data_module.test_dataloader()
|
226 |
+
|
227 |
+
def setup(self, stage=None):
|
228 |
+
return self.data_module.setup(stage=stage)
|
229 |
+
|
230 |
+
def to_audio(self, spec, length=None):
    """Apply the backward spectrogram transform, then the inverse STFT with `length`."""
    untransformed = self._backward_transform(spec)
    return self._istft(untransformed, length)
|
232 |
+
|
233 |
+
def _forward_transform(self, spec):
|
234 |
+
return self.data_module.spec_fwd(spec)
|
235 |
+
|
236 |
+
def _backward_transform(self, spec):
|
237 |
+
return self.data_module.spec_back(spec)
|
238 |
+
|
239 |
+
def _stft(self, sig):
|
240 |
+
return self.data_module.stft(sig)
|
241 |
+
|
242 |
+
def _istft(self, spec, length=None):
|
243 |
+
return self.data_module.istft(spec, length)
|
244 |
+
|
245 |
+
|
246 |
+
def add_para(self, N_min=1, N_max=1, t_rsp_min=0.5, t_rsp_max=0.5, batch_size=64, loss_type='default', lr=5e-5, stop_iteration_random='last', inference_N=1, inference_start=0.5):
    """Attach the training and inference settings that `training_step` and evaluation read.

    Args:
        N_min, N_max: Inclusive range for the number of reverse steps.
        t_rsp_min, t_rsp_max: Range for the reverse start time.
        batch_size: Written to the data module.
        loss_type: Loss identifier stored on the model.
        lr: Learning rate stored on the model.
        stop_iteration_random: Step-selection strategy ('random' or 'last').
        inference_N, inference_start: Settings stored for evaluation.

    The data module's `num_workers` is always set to 4 and its `gpu` flag to True.
    """
    # Ranges sampled by training_step
    self.t_rsp_min, self.t_rsp_max = t_rsp_min, t_rsp_max
    self.N_min, self.N_max = N_min, N_max

    # Data module configuration
    dm = self.data_module
    dm.batch_size = batch_size
    dm.num_workers = 4
    dm.gpu = True

    # Optimisation and step-selection settings
    self.loss_type = loss_type
    self.lr = lr
    self.stop_iteration_random = stop_iteration_random

    # Inference settings
    self.inference_N = inference_N
    self.inference_start = inference_start
|
geco/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
geco/backbones/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
geco/backbones/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .shared import BackboneRegistry
|
2 |
+
from .ncsnpp import NCSNpp
|
3 |
+
|
4 |
+
__all__ = ['BackboneRegistry', 'NCSNpp']
|
geco/backbones/ncsnpp.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
from score_models.layers import UpsampleLayer, DownsampleLayer
|
18 |
+
from .ncsnpp_utils import layers, layerspp, normalization
|
19 |
+
import torch.nn as nn
|
20 |
+
import functools
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .shared import BackboneRegistry
|
25 |
+
|
26 |
+
ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
|
27 |
+
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
|
28 |
+
Combine = layerspp.Combine
|
29 |
+
conv3x3 = layerspp.conv3x3
|
30 |
+
conv1x1 = layerspp.conv1x1
|
31 |
+
get_act = layers.get_act
|
32 |
+
get_normalization = normalization.get_normalization
|
33 |
+
default_initializer = layers.default_init
|
34 |
+
|
35 |
+
|
36 |
+
@BackboneRegistry.register("ncsnpp")
class NCSNpp(nn.Module):
    """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository

    The network is a U-Net whose layers are stored in one flat ModuleList
    (`self.all_modules`). `__init__` appends the layers in exactly the order in
    which `forward` consumes them through the running index `m_idx`; any change
    to one method must be mirrored in the other.

    The input to `forward` is complex-valued with three channels; real and
    imaginary parts are split into six real channels. The output is a single
    complex channel.
    """

    @staticmethod
    def add_argparse_args(parser):
        """Return `parser` unchanged; no backbone-specific arguments are registered."""
        # TODO: add additional arguments of constructor, if you wish to modify them.
        return parser

    def __init__(self,
        scale_by_sigma = True,
        nonlinearity = 'swish',
        nf = 128,
        ch_mult = (1, 1, 2, 2, 2, 2, 2),
        num_res_blocks = 2,
        attn_resolutions = (16,),
        resamp_with_conv = True,
        conditional = True,
        fir = True,
        fir_kernel = 'song',
        skip_rescale = True,
        resblock_type = 'biggan',
        progressive = 'output_skip',
        progressive_input = 'input_skip',
        progressive_combine = 'sum',
        init_scale = 0.,
        fourier_scale = 16,
        image_size = 256,
        embedding_type = 'fourier',
        dropout = .0,
        **unused_kwargs
    ):
        """Build all layers of the network.

        Args:
            scale_by_sigma: Stored on the instance.
            nonlinearity: Activation name resolved through `get_act`.
            nf: Base number of feature channels.
            ch_mult: Channel multiplier per resolution level; its length sets
                the number of resolution levels.
            num_res_blocks: Residual blocks per level on the down path (the up
                path uses one more).
            attn_resolutions: Resolutions at which an attention block is added.
            resamp_with_conv: Passed as `with_conv` to the resampling layers.
            conditional: Whether the two time-embedding Linear layers are built.
            fir: Passed to the resampling and BigGAN residual layers.
            fir_kernel: NOTE(review): this argument is overwritten with
                [1, 3, 3, 1] below and therefore has no effect.
            skip_rescale: Passed to the residual/attention blocks and used for
                the 'residual' progressive modes.
            resblock_type: 'ddpm' or 'biggan'.
            progressive: 'none', 'output_skip' or 'residual'.
            progressive_input: 'none', 'input_skip' or 'residual'.
            progressive_combine: Combination method passed to `Combine`.
            init_scale: Initialisation scale for selected layers.
            fourier_scale: Scale of the Gaussian Fourier projection.
            image_size: Resolution of the first level; halved at every level.
            embedding_type: 'fourier' or 'positional'.
            dropout: Dropout rate of the residual blocks.
            **unused_kwargs: Accepted and ignored.
        """
        super().__init__()
        self.act = act = get_act(nonlinearity)

        self.nf = nf = nf
        ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions = attn_resolutions
        dropout = dropout
        resamp_with_conv = resamp_with_conv
        self.num_resolutions = num_resolutions = len(ch_mult)
        self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]

        self.conditional = conditional = conditional # noise-conditional
        self.scale_by_sigma = scale_by_sigma
        fir = fir
        # NOTE(review): the fir_kernel argument is discarded here.
        fir_kernel = [1, 3, 3, 1]
        self.skip_rescale = skip_rescale = skip_rescale
        self.resblock_type = resblock_type = resblock_type.lower()
        self.progressive = progressive = progressive.lower()
        self.progressive_input = progressive_input = progressive_input.lower()
        self.embedding_type = embedding_type = embedding_type.lower()
        init_scale = init_scale
        assert progressive in ['none', 'output_skip', 'residual']
        assert progressive_input in ['none', 'input_skip', 'residual']
        assert embedding_type in ['fourier', 'positional']
        combine_method = progressive_combine.lower()
        combiner = functools.partial(Combine, method=combine_method)

        num_channels = 6 # real and imaginary parts of the three complex input channels (see forward)
        # Final 1x1 convolution mapping the six real channels to real/imag of one complex channel.
        self.output_layer = nn.Conv2d(num_channels, 2, 1)

        modules = []
        # timestep/noise_level embedding
        if embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            modules.append(layerspp.GaussianFourierProjection(
                embedding_size=nf, scale=fourier_scale
            ))
            embed_dim = 2 * nf
        elif embedding_type == 'positional':
            # No module is appended here; forward computes this embedding functionally.
            embed_dim = nf
        else:
            raise ValueError(f'embedding type {embedding_type} unknown.')

        if conditional:
            modules.append(nn.Linear(embed_dim, nf * 4))
            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)
            modules.append(nn.Linear(nf * 4, nf * 4))
            modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
            nn.init.zeros_(modules[-1].bias)

        AttnBlock = functools.partial(layerspp.AttnBlockpp,
            init_scale=init_scale, skip_rescale=skip_rescale)

        Upsample = functools.partial(UpsampleLayer,
            with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive == 'output_skip':
            self.pyramid_upsample = UpsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive == 'residual':
            pyramid_upsample = functools.partial(UpsampleLayer, fir=fir,
                fir_kernel=fir_kernel, with_conv=True)

        Downsample = functools.partial(DownsampleLayer, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

        if progressive_input == 'input_skip':
            self.pyramid_downsample = DownsampleLayer(fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive_input == 'residual':
            pyramid_downsample = functools.partial(DownsampleLayer,
                fir=fir, fir_kernel=fir_kernel, with_conv=True)

        if resblock_type == 'ddpm':
            ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
                dropout=dropout, init_scale=init_scale,
                skip_rescale=skip_rescale, temb_dim=nf * 4)

        elif resblock_type == 'biggan':
            ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
                dropout=dropout, fir=fir, fir_kernel=fir_kernel,
                init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)

        else:
            raise ValueError(f'resblock type {resblock_type} unrecognized.')

        # Downsampling block

        channels = num_channels
        if progressive_input != 'none':
            input_pyramid_ch = channels

        modules.append(conv3x3(channels, nf))
        # hs_c tracks the channel count of every skip connection pushed on the down path.
        hs_c = [nf]

        in_ch = nf
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch

                if all_resolutions[i_level] in attn_resolutions:
                    modules.append(AttnBlock(channels=in_ch))
                hs_c.append(in_ch)

            if i_level != num_resolutions - 1:
                if resblock_type == 'ddpm':
                    modules.append(Downsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(down=True, in_ch=in_ch))

                if progressive_input == 'input_skip':
                    modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
                    if combine_method == 'cat':
                        in_ch *= 2

                elif progressive_input == 'residual':
                    modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
                    input_pyramid_ch = in_ch

                hs_c.append(in_ch)

        # Bottleneck: ResNet block, attention block, ResNet block
        in_ch = hs_c[-1]
        modules.append(ResnetBlock(in_ch=in_ch))
        modules.append(AttnBlock(channels=in_ch))
        modules.append(ResnetBlock(in_ch=in_ch))

        pyramid_ch = 0
        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
                out_ch = nf * ch_mult[i_level]
                modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch

            if all_resolutions[i_level] in attn_resolutions:
                modules.append(AttnBlock(channels=in_ch))

            if progressive != 'none':
                if i_level == num_resolutions - 1:
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                            num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == 'residual':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, in_ch, bias=True))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f'{progressive} is not a valid name.')
                else:
                    if progressive == 'output_skip':
                        modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                            num_channels=in_ch, eps=1e-6))
                        modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
                        pyramid_ch = channels
                    elif progressive == 'residual':
                        modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
                        pyramid_ch = in_ch
                    else:
                        raise ValueError(f'{progressive} is not a valid name')

            if i_level != 0:
                if resblock_type == 'ddpm':
                    modules.append(Upsample(in_ch=in_ch))
                else:
                    modules.append(ResnetBlock(in_ch=in_ch, up=True))

        # Every skip connection recorded on the down path must have been consumed.
        assert not hs_c

        if progressive != 'output_skip':
            modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                num_channels=in_ch, eps=1e-6))
            modules.append(conv3x3(in_ch, channels, init_scale=init_scale))

        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, time_cond):
        """Run the U-Net on `x` conditioned on `time_cond`.

        Args:
            x: Complex tensor whose dim 1 holds (at least) three channels;
                channels 0, 1 and 2 are used.
            time_cond: Per-example conditioning value. For the 'fourier'
                embedding its logarithm is embedded and the network output is
                divided by it.

        Returns:
            Complex tensor with a single channel in dim 1.
        """
        # timestep/noise_level embedding; only for continuous training
        modules = self.all_modules
        m_idx = 0

        # Convert real and imaginary parts of the three complex input channels into six real channels
        x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
            x[:,[1],:,:].real, x[:,[1],:,:].imag,
            x[:,[2],:,:].real, x[:,[2],:,:].imag), dim=1)

        if self.embedding_type == 'fourier':
            # Gaussian Fourier features embeddings.
            used_sigmas = time_cond
            temb = modules[m_idx](torch.log(used_sigmas))
            m_idx += 1

        elif self.embedding_type == 'positional':
            # Sinusoidal positional embeddings.
            timesteps = time_cond
            # NOTE(review): self.sigmas is not assigned anywhere in this class, so this
            # branch appears to raise AttributeError -- confirm before using 'positional'.
            used_sigmas = self.sigmas[time_cond.long()]
            temb = layers.get_timestep_embedding(timesteps, self.nf)

        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')

        if self.conditional:
            temb = modules[m_idx](temb)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
        else:
            temb = None

        # Downsampling block
        input_pyramid = None
        if self.progressive_input != 'none':
            input_pyramid = x

        # Input layer: Conv2d from the six input channels to nf channels
        hs = [modules[m_idx](x)]
        m_idx += 1

        # Down path in U-Net
        for i_level in range(self.num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(self.num_res_blocks):
                h = modules[m_idx](hs[-1], temb)
                m_idx += 1
                # Attention layer (optional)
                if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
                    h = modules[m_idx](h)
                    m_idx += 1
                hs.append(h)

            # Downsampling
            if i_level != self.num_resolutions - 1:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](hs[-1])
                    m_idx += 1
                else:
                    h = modules[m_idx](hs[-1], temb)
                    m_idx += 1

                if self.progressive_input == 'input_skip': # Combine h with x
                    input_pyramid = self.pyramid_downsample(input_pyramid)
                    h = modules[m_idx](input_pyramid, h)
                    m_idx += 1

                elif self.progressive_input == 'residual':
                    input_pyramid = modules[m_idx](input_pyramid)
                    m_idx += 1
                    if self.skip_rescale:
                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
                    else:
                        input_pyramid = input_pyramid + h
                    h = input_pyramid
                hs.append(h)

        h = hs[-1] # actually equal to: h = h
        h = modules[m_idx](h, temb) # ResNet block
        m_idx += 1
        h = modules[m_idx](h) # Attention block
        m_idx += 1
        h = modules[m_idx](h, temb) # ResNet block
        m_idx += 1

        pyramid = None

        # Upsampling block
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
                m_idx += 1

            # edit: from -1 to -2
            if h.shape[-2] in self.attn_resolutions:
                h = modules[m_idx](h)
                m_idx += 1

            if self.progressive != 'none':
                if i_level == self.num_resolutions - 1:
                    if self.progressive == 'output_skip':
                        pyramid = self.act(modules[m_idx](h)) # GroupNorm
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid) # Conv2D back to the six input channels
                        m_idx += 1
                    elif self.progressive == 'residual':
                        pyramid = self.act(modules[m_idx](h))
                        m_idx += 1
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                    else:
                        raise ValueError(f'{self.progressive} is not a valid name.')
                else:
                    if self.progressive == 'output_skip':
                        pyramid = self.pyramid_upsample(pyramid) # Upsample
                        pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
                        m_idx += 1
                        pyramid_h = modules[m_idx](pyramid_h)
                        m_idx += 1
                        pyramid = pyramid + pyramid_h
                    elif self.progressive == 'residual':
                        pyramid = modules[m_idx](pyramid)
                        m_idx += 1
                        if self.skip_rescale:
                            pyramid = (pyramid + h) / np.sqrt(2.)
                        else:
                            pyramid = pyramid + h
                        h = pyramid
                    else:
                        raise ValueError(f'{self.progressive} is not a valid name')

            # Upsampling Layer
            if i_level != 0:
                if self.resblock_type == 'ddpm':
                    h = modules[m_idx](h)
                    m_idx += 1
                else:
                    h = modules[m_idx](h, temb) # Upsampling
                    m_idx += 1

        assert not hs

        if self.progressive == 'output_skip':
            h = pyramid
        else:
            h = self.act(modules[m_idx](h))
            m_idx += 1
            h = modules[m_idx](h)
            m_idx += 1

        # Every registered module must have been consumed exactly once.
        assert m_idx == len(modules), "Implementation error"
        # NOTE(review): applied regardless of self.scale_by_sigma -- confirm intended.
        h = h / used_sigmas[:, None, None, None]

        # Convert back to complex number
        h = self.output_layer(h)
        h = torch.permute(h, (0, 2, 3, 1)).contiguous()
        h = torch.view_as_complex(h)[:,None, :, :]
        return h
|
geco/backbones/ncsnpp_utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
geco/backbones/ncsnpp_utils/layers.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
"""Common layers for defining score networks.
|
18 |
+
"""
|
19 |
+
import math
|
20 |
+
import string
|
21 |
+
from functools import partial
|
22 |
+
import torch.nn as nn
|
23 |
+
import torch
|
24 |
+
import torch.nn.functional as F
|
25 |
+
import numpy as np
|
26 |
+
from .normalization import ConditionalInstanceNorm2dPlus
|
27 |
+
|
28 |
+
|
29 |
+
def get_act(config):
    """Return a new activation module for the given name.

    Supported names: 'elu', 'relu', 'lrelu' (LeakyReLU with slope 0.2) and
    'swish' (SiLU). Any other value raises NotImplementedError.
    """
    factories = (
        ('elu', nn.ELU),
        ('relu', nn.ReLU),
        ('lrelu', lambda: nn.LeakyReLU(negative_slope=0.2)),
        ('swish', nn.SiLU),
    )
    # Linear scan with `==` keeps the comparison semantics of an if/elif chain.
    for name, make in factories:
        if config == name:
            return make()
    raise NotImplementedError('activation function does not exist!')
|
42 |
+
|
43 |
+
|
44 |
+
def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
    """1x1 convolution. Same as NCSNv1/v2.

    The layer keeps PyTorch's default initialisation, scaled by `init_scale`.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride, dilation, padding: Forwarded to nn.Conv2d.
        bias: Whether the convolution has a bias term.
        init_scale: Factor applied to the initial weights (and bias, if any);
            a value of 0 is replaced by 1e-10.

    Returns:
        The configured nn.Conv2d.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
                     padding=padding)
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv.weight.data *= init_scale
    # With bias=False the layer has no bias tensor (conv.bias is None).
    if conv.bias is not None:
        conv.bias.data *= init_scale
    return conv
|
52 |
+
|
53 |
+
|
54 |
+
def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Ported from JAX.

    Build an initializer `init(shape, dtype, device)` that samples a tensor
    with variance `scale / denominator`, where the denominator is the fan-in,
    the fan-out or their average depending on `mode` ('fan_in', 'fan_out',
    'fan_avg'). `distribution` is 'normal' or 'uniform'. Invalid `mode` or
    `distribution` values raise ValueError when the initializer is called.
    """

    def _fans(shape):
        # Everything outside the in/out axes counts as the receptive field.
        field = np.prod(shape) / shape[in_axis] / shape[out_axis]
        return shape[in_axis] * field, shape[out_axis] * field

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _fans(shape)

        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))

        variance = scale / denominator

        if distribution == "normal":
            sample = torch.randn(*shape, dtype=dtype, device=device)
            return sample * np.sqrt(variance)
        if distribution == "uniform":
            # Uniform on [-limit, limit] with limit = sqrt(3 * variance).
            sample = torch.rand(*shape, dtype=dtype, device=device)
            return (sample * 2. - 1.) * np.sqrt(3 * variance)
        raise ValueError("invalid distribution for variance scaling initializer")

    return init
|
86 |
+
|
87 |
+
|
88 |
+
def default_init(scale=1.):
    """Return the initializer used in DDPM: uniform variance scaling with fan averaging.

    A ``scale`` of exactly 0 is replaced by 1e-10 before building the initializer.
    """
    if scale == 0:
        scale = 1e-10
    return variance_scaling(scale, 'fan_avg', 'uniform')
|
92 |
+
|
93 |
+
|
94 |
+
class Dense(nn.Module):
    """Linear layer with `default_init`.

    NOTE(review): as written this class only calls the parent constructor; it
    defines no parameters and no ``forward``.
    """

    def __init__(self):
        super().__init__()
|
98 |
+
|
99 |
+
|
100 |
+
def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """Build a 1x1 convolution with DDPM initialization.

    The weight is replaced by a sample from ``default_init(init_scale)`` and
    the bias, when present, is zeroed.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        init_scale: Scale handed to ``default_init``.
        padding: Convolution padding.

    Returns:
        The configured ``nn.Conv2d``.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # With bias=False, ``conv.bias`` is None and cannot be passed to nn.init.zeros_.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv
|
106 |
+
|
107 |
+
|
108 |
+
def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """Build a 3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2.

    The freshly initialised weight (and bias, when one exists) is multiplied
    by ``init_scale``. A scale of exactly 0 is replaced by 1e-10.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        dilation: Convolution dilation.
        init_scale: Multiplier applied to the default-initialised parameters.
        padding: Convolution padding.

    Returns:
        The configured ``nn.Conv2d``.
    """
    init_scale = 1e-10 if init_scale == 0 else init_scale
    conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
                     dilation=dilation, padding=padding, kernel_size=3)
    conv.weight.data *= init_scale
    # Several blocks in this file request bias=False, which leaves ``conv.bias``
    # as None; scaling it unconditionally raised AttributeError.
    if conv.bias is not None:
        conv.bias.data *= init_scale
    return conv
|
116 |
+
|
117 |
+
|
118 |
+
def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """Build a 3x3 convolution with DDPM initialization.

    The weight is replaced by a sample from ``default_init(init_scale)`` and
    the bias, when present, is zeroed.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
        bias: Whether the convolution has a bias term.
        dilation: Convolution dilation.
        init_scale: Scale handed to ``default_init``.
        padding: Convolution padding.

    Returns:
        The configured ``nn.Conv2d``.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # With bias=False, ``conv.bias`` is None and cannot be passed to nn.init.zeros_.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv
|
125 |
+
|
126 |
+
###########################################################################
|
127 |
+
# Functions below are ported over from the NCSNv1/NCSNv2 codebase:
|
128 |
+
# https://github.com/ermongroup/ncsn
|
129 |
+
# https://github.com/ermongroup/ncsnv2
|
130 |
+
###########################################################################
|
131 |
+
|
132 |
+
|
133 |
+
class CRPBlock(nn.Module):
    """Chain of (5x5 pool -> 3x3 conv) stages, each added onto a running sum.

    The input is first passed through ``act``. Every stage pools and convolves
    the previous stage's output and the result is accumulated onto ``x``.
    """

    def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
        super().__init__()
        self.convs = nn.ModuleList(
            [ncsn_conv3x3(features, features, stride=1, bias=False) for _ in range(n_stages)]
        )
        self.n_stages = n_stages
        # Stride 1 with padding 2 keeps the spatial size for a 5x5 window.
        pool_cls = nn.MaxPool2d if maxpool else nn.AvgPool2d
        self.pool = pool_cls(kernel_size=5, stride=1, padding=2)

        self.act = act

    def forward(self, x):
        x = self.act(x)
        branch = x
        for stage in range(self.n_stages):
            branch = self.convs[stage](self.pool(branch))
            x = branch + x
        return x
|
155 |
+
|
156 |
+
|
157 |
+
class CondCRPBlock(nn.Module):
    """Conditional variant of the chained pooling block.

    Each stage applies ``norm(path, y)``, 5x5 average pooling and a 3x3
    convolution, then accumulates the result onto ``x``.
    """

    def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.normalizer = normalizer
        # Norm and conv are created alternately, one pair per stage.
        for _ in range(n_stages):
            self.norms.append(normalizer(features, num_classes, bias=True))
            self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))

        self.n_stages = n_stages
        self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
        self.act = act

    def forward(self, x, y):
        x = self.act(x)
        branch = x
        for stage in range(self.n_stages):
            branch = self.norms[stage](branch, y)
            branch = self.convs[stage](self.pool(branch))
            x = branch + x
        return x
|
181 |
+
|
182 |
+
|
183 |
+
class RCUBlock(nn.Module):
    """Stack of residual units, each made of ``n_stages`` (act -> 3x3 conv) steps.

    The convolutions are stored as attributes named ``'{block}_{stage}_conv'``
    with 1-based indices.
    """

    def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
        super().__init__()

        for block in range(1, n_blocks + 1):
            for stage in range(1, n_stages + 1):
                conv = ncsn_conv3x3(features, features, stride=1, bias=False)
                setattr(self, '{}_{}_conv'.format(block, stage), conv)

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act

    def forward(self, x):
        for block in range(1, self.n_blocks + 1):
            residual = x
            for stage in range(1, self.n_stages + 1):
                conv = getattr(self, '{}_{}_conv'.format(block, stage))
                x = conv(self.act(x))

            # In-place add, as in the original implementation.
            x += residual
        return x
|
205 |
+
|
206 |
+
|
207 |
+
class CondRCUBlock(nn.Module):
    """Conditional residual units: ``n_stages`` (norm -> act -> 3x3 conv) steps per block.

    Norms and convolutions are stored as attributes named
    ``'{block}_{stage}_norm'`` and ``'{block}_{stage}_conv'`` with 1-based indices.
    """

    def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
        super().__init__()

        for block in range(1, n_blocks + 1):
            for stage in range(1, n_stages + 1):
                norm = normalizer(features, num_classes, bias=True)
                setattr(self, '{}_{}_norm'.format(block, stage), norm)
                conv = ncsn_conv3x3(features, features, stride=1, bias=False)
                setattr(self, '{}_{}_conv'.format(block, stage), conv)

        self.stride = 1
        self.n_blocks = n_blocks
        self.n_stages = n_stages
        self.act = act
        self.normalizer = normalizer

    def forward(self, x, y):
        for block in range(1, self.n_blocks + 1):
            residual = x
            for stage in range(1, self.n_stages + 1):
                norm = getattr(self, '{}_{}_norm'.format(block, stage))
                x = self.act(norm(x, y))
                conv = getattr(self, '{}_{}_conv'.format(block, stage))
                x = conv(x)

            # In-place add, as in the original implementation.
            x += residual
        return x
|
232 |
+
|
233 |
+
|
234 |
+
class MSFBlock(nn.Module):
    """Fuse several feature maps: 3x3 conv each, bilinear resize to ``shape``, then sum."""

    def __init__(self, in_planes, features):
        super().__init__()
        assert isinstance(in_planes, (list, tuple))
        self.convs = nn.ModuleList()
        self.features = features

        for planes in in_planes:
            self.convs.append(ncsn_conv3x3(planes, features, stride=1, bias=True))

    def forward(self, xs, shape):
        first = xs[0]
        # The accumulator lives on the device of the first input.
        sums = torch.zeros(first.shape[0], self.features, *shape, device=first.device)
        for idx, conv in enumerate(self.convs):
            resized = F.interpolate(conv(xs[idx]), size=shape, mode='bilinear', align_corners=True)
            sums += resized
        return sums
|
251 |
+
|
252 |
+
|
253 |
+
class CondMSFBlock(nn.Module):
    """Conditional fusion: norm(x, y) -> 3x3 conv -> bilinear resize for each input, then sum."""

    def __init__(self, in_planes, features, num_classes, normalizer):
        super().__init__()
        assert isinstance(in_planes, (list, tuple))

        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.features = features
        self.normalizer = normalizer

        for planes in in_planes:
            self.convs.append(ncsn_conv3x3(planes, features, stride=1, bias=True))
            self.norms.append(normalizer(planes, num_classes, bias=True))

    def forward(self, xs, y, shape):
        first = xs[0]
        # The accumulator lives on the device of the first input.
        sums = torch.zeros(first.shape[0], self.features, *shape, device=first.device)
        for idx, conv in enumerate(self.convs):
            normed = self.norms[idx](xs[idx], y)
            resized = F.interpolate(conv(normed), size=shape, mode='bilinear', align_corners=True)
            sums += resized
        return sums
|
275 |
+
|
276 |
+
|
277 |
+
class RefineBlock(nn.Module):
    """Refinement block: per-input RCU, multi-scale fusion, chained pooling, output RCU.

    ``msf`` is only created when ``start`` is false, and only used in
    ``forward`` when more than one input is given.
    """

    def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
        super().__init__()

        assert isinstance(in_planes, (tuple, list))
        self.n_blocks = len(in_planes)

        self.adapt_convs = nn.ModuleList([RCUBlock(planes, 2, 2, act) for planes in in_planes])

        # The final block of the network gets a deeper output RCU.
        n_output_blocks = 3 if end else 1
        self.output_convs = RCUBlock(features, n_output_blocks, 2, act)

        if not start:
            self.msf = MSFBlock(in_planes, features)

        self.crp = CRPBlock(features, 2, act, maxpool=maxpool)

    def forward(self, xs, output_shape):
        assert isinstance(xs, (tuple, list))
        adapted = [self.adapt_convs[idx](x) for idx, x in enumerate(xs)]

        if self.n_blocks > 1:
            h = self.msf(adapted, output_shape)
        else:
            h = adapted[0]

        return self.output_convs(self.crp(h))
|
311 |
+
|
312 |
+
|
313 |
+
class CondRefineBlock(nn.Module):
    """Conditional refinement block: per-input RCU, fusion, chained pooling, output RCU.

    Every sub-block receives the conditioning input ``y``. ``msf`` is only
    created when ``start`` is false, and only used when more than one input
    is given.
    """

    def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
        super().__init__()

        assert isinstance(in_planes, (tuple, list))
        self.n_blocks = len(in_planes)

        self.adapt_convs = nn.ModuleList(
            [CondRCUBlock(planes, 2, 2, num_classes, normalizer, act) for planes in in_planes]
        )

        # The final block of the network gets a deeper output RCU.
        n_output_blocks = 3 if end else 1
        self.output_convs = CondRCUBlock(features, n_output_blocks, 2, num_classes, normalizer, act)

        if not start:
            self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)

        self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)

    def forward(self, xs, y, output_shape):
        assert isinstance(xs, (tuple, list))
        adapted = [self.adapt_convs[idx](x, y) for idx, x in enumerate(xs)]

        if self.n_blocks > 1:
            h = self.msf(adapted, y, output_shape)
        else:
            h = adapted[0]

        h = self.crp(h, y)
        return self.output_convs(h, y)
|
349 |
+
|
350 |
+
|
351 |
+
class ConvMeanPool(nn.Module):
    """Convolution followed by 2x2 mean pooling.

    With ``adjust_padding`` the input is zero-padded by one pixel on the left
    and top before the convolution.
    """

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
        super().__init__()
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
        if adjust_padding:
            self.conv = nn.Sequential(
                nn.ZeroPad2d((1, 0, 1, 0)),
                conv
            )
        else:
            self.conv = conv

    def forward(self, inputs):
        feats = self.conv(inputs)
        # The four interleaved sub-grids are the corners of each 2x2 window.
        corners = [
            feats[:, :, ::2, ::2],
            feats[:, :, 1::2, ::2],
            feats[:, :, ::2, 1::2],
            feats[:, :, 1::2, 1::2],
        ]
        return sum(corners) / 4.
|
370 |
+
|
371 |
+
|
372 |
+
class MeanPoolConv(nn.Module):
    """2x2 mean pooling followed by a convolution."""

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)

    def forward(self, inputs):
        # The four interleaved sub-grids are the corners of each 2x2 window.
        corners = [
            inputs[:, :, ::2, ::2],
            inputs[:, :, 1::2, ::2],
            inputs[:, :, ::2, 1::2],
            inputs[:, :, 1::2, 1::2],
        ]
        pooled = sum(corners) / 4.
        return self.conv(pooled)
|
382 |
+
|
383 |
+
|
384 |
+
class UpsampleConv(nn.Module):
    """2x upsampling (channel replication + pixel shuffle) followed by a convolution."""

    def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
        super().__init__()
        self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
        self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, inputs):
        # Four channel-wise copies give PixelShuffle the 4x channels it folds into a 2x2 grid.
        replicated = torch.cat([inputs] * 4, dim=1)
        upsampled = self.pixelshuffle(replicated)
        return self.conv(upsampled)
|
395 |
+
|
396 |
+
|
397 |
+
class ConditionalResidualBlock(nn.Module):
    """Residual block with conditional normalization (ported from NCSNv1/NCSNv2).

    ``forward`` applies normalize -> activation -> conv twice and adds a
    shortcut: the identity when the channel count is unchanged and
    ``resample`` is None, otherwise a learned convolution.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        num_classes: Second argument handed to ``normalization``.
        resample: ``'down'`` or ``None``. NOTE: the historical default ``1`` is
            neither and raises; it is kept so the signature stays unchanged.
        act: Activation module applied before each convolution.
        normalization: Factory called as ``normalization(dim, num_classes)``;
            the result is called as ``norm(x, y)``.
        adjust_padding: Forwarded to ``ConvMeanPool`` on the non-dilated
            ``'down'`` path.
        dilation: Values greater than 1 select dilated 3x3 convolutions.
            ``None`` is treated as 1.

    Raises:
        Exception: If ``resample`` is neither ``'down'`` nor ``None``.
    """

    def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
                 normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization
        # The default ``dilation=None`` used to fail on ``None > 1``; treat it as "no dilation".
        if dilation is None:
            dilation = 1
        if resample == 'down':
            if dilation > 1:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
            else:
                self.conv1 = ncsn_conv3x3(input_dim, input_dim)
                self.normalize2 = normalization(input_dim, num_classes)
                self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
                conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)

        elif resample is None:
            if dilation > 1:
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
                self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
            else:
                # ``nn.Conv2d(input_dim, output_dim)`` lacks the required kernel_size and
                # raised TypeError; use the 1x1 helper, as ResidualBlock does.
                conv_shortcut = ncsn_conv1x1
                self.conv1 = ncsn_conv3x3(input_dim, output_dim)
                self.normalize2 = normalization(output_dim, num_classes)
                self.conv2 = ncsn_conv3x3(output_dim, output_dim)
        else:
            raise Exception('invalid resample value')

        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim, num_classes)

    def forward(self, x, y):
        output = self.normalize1(x, y)
        output = self.non_linearity(output)
        output = self.conv1(output)
        output = self.normalize2(output, y)
        output = self.non_linearity(output)
        output = self.conv2(output)

        # Identity shortcut only when neither channels nor resolution change.
        if self.output_dim == self.input_dim and self.resample is None:
            shortcut = x
        else:
            shortcut = self.shortcut(x)

        return shortcut + output
|
451 |
+
|
452 |
+
|
453 |
+
class ResidualBlock(nn.Module):
    """Unconditional residual block (ported from NCSNv1/NCSNv2).

    ``forward`` applies normalize -> activation -> conv twice and adds a
    shortcut: the identity when the channel count is unchanged and
    ``resample`` is None, otherwise a learned convolution.
    """

    def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
                 normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
        super().__init__()
        self.non_linearity = act
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.resample = resample
        self.normalization = normalization

        downsample = resample == 'down'
        if not downsample and resample is not None:
            raise Exception('invalid resample value')

        dilated = dilation > 1
        # Only dilated layers pass ``dilation`` explicitly; the rest use the helper's default.
        conv_kwargs = {'dilation': dilation} if dilated else {}
        # The first conv keeps the channel count when downsampling, otherwise it changes it.
        hidden_dim = input_dim if downsample else output_dim

        self.conv1 = ncsn_conv3x3(input_dim, hidden_dim, **conv_kwargs)
        self.normalize2 = normalization(hidden_dim)
        if downsample and not dilated:
            self.conv2 = ConvMeanPool(hidden_dim, output_dim, 3, adjust_padding=adjust_padding)
            conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
        else:
            self.conv2 = ncsn_conv3x3(hidden_dim, output_dim, **conv_kwargs)
            if dilated:
                conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
            else:
                # A bare nn.Conv2d cannot be built from (input_dim, output_dim) alone,
                # so the 1x1 helper is used for the shortcut.
                conv_shortcut = partial(ncsn_conv1x1)

        if output_dim != input_dim or resample is not None:
            self.shortcut = conv_shortcut(input_dim, output_dim)

        self.normalize1 = normalization(input_dim)

    def forward(self, x):
        h = self.conv1(self.non_linearity(self.normalize1(x)))
        h = self.conv2(self.non_linearity(self.normalize2(h)))

        # Identity shortcut only when neither channels nor resolution change.
        needs_projection = self.output_dim != self.input_dim or self.resample is not None
        shortcut = self.shortcut(x) if needs_projection else x

        return shortcut + h
|
508 |
+
|
509 |
+
|
510 |
+
###########################################################################
|
511 |
+
# Functions below are ported over from the DDPM codebase:
|
512 |
+
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
|
513 |
+
###########################################################################
|
514 |
+
|
515 |
+
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Return sinusoidal embeddings of ``timesteps`` (ported from DDPM).

    The first ``embedding_dim // 2`` columns are sines and the next
    ``embedding_dim // 2`` are cosines, at geometrically spaced frequencies
    from 1 down to ``1 / max_positions``. An odd ``embedding_dim`` gets one
    trailing zero column.

    Args:
        timesteps: 1-D tensor of time steps.
        embedding_dim: Width of the returned embedding.
        max_positions: Controls the lowest frequency.

    Returns:
        A float tensor of shape ``(len(timesteps), embedding_dim)``.
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # Clamp the denominator: for embedding_dim 2 or 3, ``half_dim - 1`` is 0 and the
    # division used to raise ZeroDivisionError. Larger dims are unaffected.
    emb = math.log(max_positions) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
|
530 |
+
|
531 |
+
|
532 |
+
def _einsum(a, b, c, x, y):
|
533 |
+
einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
|
534 |
+
return torch.einsum(einsum_str, x, y)
|
535 |
+
|
536 |
+
|
537 |
+
def contract_inner(x, y):
    """tensordot(x, y, 1): contract the last axis of ``x`` with the first axis of ``y``."""
    x_rank = len(x.shape)
    y_rank = len(y.shape)
    letters = string.ascii_lowercase
    x_chars = list(letters[:x_rank])
    y_chars = list(letters[x_rank:x_rank + y_rank])
    # Sharing one letter makes einsum sum over that pair of axes.
    y_chars[0] = x_chars[-1]
    out_chars = x_chars[:-1] + y_chars[1:]
    einsum_str = '{},{}->{}'.format(''.join(x_chars), ''.join(y_chars), ''.join(out_chars))
    return torch.einsum(einsum_str, x, y)
|
544 |
+
|
545 |
+
|
546 |
+
class NIN(nn.Module):
    """Per-pixel linear map over channels with weight ``W`` and bias ``b``.

    ``W`` has shape ``(in_dim, num_units)`` and is drawn from ``default_init``;
    ``b`` starts at zero.
    """

    def __init__(self, in_dim, num_units, init_scale=0.1):
        super().__init__()
        weight = default_init(scale=init_scale)((in_dim, num_units))
        self.W = nn.Parameter(weight, requires_grad=True)
        self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

    def forward(self, x):
        # Move channels last so they are the axis contracted with W, then move them back.
        channels_last = x.permute(0, 2, 3, 1)
        projected = contract_inner(channels_last, self.W) + self.b
        return projected.permute(0, 3, 1, 2)
|
556 |
+
|
557 |
+
|
558 |
+
class AttnBlock(nn.Module):
    """Channel-wise self-attention block.

    Queries, keys and values are NIN projections of the group-normalised
    input. Attention weights are a softmax over all spatial positions and the
    result is added back to the input.
    """

    def __init__(self, channels):
        super().__init__()
        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=0.)

    def forward(self, x):
        B, C, H, W = x.shape
        normed = self.GroupNorm_0(x)
        q = self.NIN_0(normed)
        k = self.NIN_1(normed)
        v = self.NIN_2(normed)

        # Score for every (query position, key position) pair, scaled by 1/sqrt(C).
        scores = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
        # Flatten the key positions so the softmax runs over all of them at once.
        attn = F.softmax(torch.reshape(scores, (B, H, W, H * W)), dim=-1)
        attn = torch.reshape(attn, (B, H, W, H, W))
        mixed = torch.einsum('bhwij,bcij->bchw', attn, v)
        return x + self.NIN_3(mixed)
|
582 |
+
|
583 |
+
|
584 |
+
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 DDPM convolution."""

    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(channels, channels)
        self.with_conv = with_conv

    def forward(self, x):
        _, _, height, width = x.shape
        upsampled = F.interpolate(x, (height * 2, width * 2), mode='nearest')
        if not self.with_conv:
            return upsampled
        return self.Conv_0(upsampled)
|
597 |
+
|
598 |
+
|
599 |
+
class Downsample(nn.Module):
    """2x downsampling: a strided 3x3 DDPM convolution, or 2x2 average pooling."""

    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
        self.with_conv = with_conv

    def forward(self, x):
        B, C, H, W = x.shape
        if not self.with_conv:
            x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
        else:
            # Emulate 'SAME' padding: pad right and bottom by one before the unpadded conv.
            x = self.Conv_0(F.pad(x, (0, 1, 0, 1)))

        assert x.shape == (B, C, H // 2, W // 2)
        return x
|
617 |
+
|
618 |
+
|
619 |
+
class ResnetBlockDDPM(nn.Module):
    """The ResNet Blocks used in DDPM.

    Two (GroupNorm -> act -> 3x3 conv) stages with dropout before the second
    convolution and an optional additive time-embedding bias after the first.
    When the channel count changes, the skip path is projected by a 3x3 conv
    (``conv_shortcut=True``) or a NIN layer.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
        super().__init__()
        out_ch = in_ch if out_ch is None else out_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
        self.act = act
        self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.data.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        _, channels, _, _ = x.shape
        assert channels == self.in_ch
        out_ch = self.out_ch if self.out_ch else self.in_ch
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Dropout_0(self.act(self.GroupNorm_1(h)))
        h = self.Conv_1(h)
        if channels != out_ch:
            x = self.Conv_2(x) if self.conv_shortcut else self.NIN_0(x)
        return x + h
|
geco/backbones/ncsnpp_utils/layerspp.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
# pylint: skip-file
|
17 |
+
"""Layers for defining NCSN++.
|
18 |
+
"""
|
19 |
+
from . import layers
|
20 |
+
import score_models.layers.up_or_downsampling2d as up_or_down_sampling
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
conv1x1 = layers.ddpm_conv1x1
|
27 |
+
conv3x3 = layers.ddpm_conv3x3
|
28 |
+
NIN = layers.NIN
|
29 |
+
default_init = layers.default_init
|
30 |
+
|
31 |
+
|
32 |
+
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    ``W`` holds ``embedding_size`` frequencies drawn once from a normal
    distribution scaled by ``scale`` and is not trained. The output
    concatenates the sines and cosines of ``2 * pi * x * W``.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_size) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        phases = x[:, None] * self.W[None, :] * 2 * np.pi
        features = [torch.sin(phases), torch.cos(phases)]
        return torch.cat(features, dim=-1)
|
42 |
+
|
43 |
+
|
44 |
+
class Combine(nn.Module):
    """Combine information from skip connections.

    ``x`` is projected by a 1x1 convolution from ``dim1`` to ``dim2`` channels
    and then concatenated with (``'cat'``) or added to (``'sum'``) ``y``.
    """

    def __init__(self, dim1, dim2, method='cat'):
        super().__init__()
        self.Conv_0 = conv1x1(dim1, dim2)
        self.method = method

    def forward(self, x, y):
        projected = self.Conv_0(x)
        if self.method == 'cat':
            return torch.cat([projected, y], dim=1)
        if self.method == 'sum':
            return projected + y
        raise ValueError(f'Method {self.method} not recognized.')
|
60 |
+
|
61 |
+
|
62 |
+
class AttnBlockpp(nn.Module):
    """Channel-wise self-attention block. Modified from DDPM.

    Differs from the DDPM block in the GroupNorm group count
    (``min(channels // 4, 32)``), the configurable init scale of the output
    projection, and an optional ``1/sqrt(2)`` rescale of the residual sum.
    """

    def __init__(self, channels, skip_rescale=False, init_scale=0.):
        super().__init__()
        n_groups = min(channels // 4, 32)
        self.GroupNorm_0 = nn.GroupNorm(num_groups=n_groups, num_channels=channels,
                                        eps=1e-6)
        self.NIN_0 = NIN(channels, channels)
        self.NIN_1 = NIN(channels, channels)
        self.NIN_2 = NIN(channels, channels)
        self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
        self.skip_rescale = skip_rescale

    def forward(self, x):
        B, C, H, W = x.shape
        normed = self.GroupNorm_0(x)
        q = self.NIN_0(normed)
        k = self.NIN_1(normed)
        v = self.NIN_2(normed)

        # Score for every (query position, key position) pair, scaled by 1/sqrt(C).
        scores = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
        # Flatten the key positions so the softmax runs over all of them at once.
        attn = F.softmax(torch.reshape(scores, (B, H, W, H * W)), dim=-1)
        attn = torch.reshape(attn, (B, H, W, H, W))
        h = self.NIN_3(torch.einsum('bhwij,bcij->bchw', attn, v))
        if self.skip_rescale:
            return (x + h) / np.sqrt(2.)
        return x + h
|
92 |
+
|
93 |
+
|
94 |
+
class ResnetBlockDDPMpp(nn.Module):
    """ResBlock adapted from DDPM.

    Two (GroupNorm -> act -> 3x3 conv) stages with dropout before the second
    convolution and an optional additive time-embedding bias after the first.
    When the channel count changes, the skip path is projected by a 3x3 conv
    or a NIN layer. ``skip_rescale`` divides the residual sum by ``sqrt(2)``.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
                 dropout=0.1, skip_rescale=False, init_scale=0.):
        super().__init__()
        out_ch = out_ch or in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.data.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense
        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = NIN(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.out_ch = out_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Dropout_0(self.act(self.GroupNorm_1(h)))
        h = self.Conv_1(h)
        if x.shape[1] != self.out_ch:
            x = self.Conv_2(x) if self.conv_shortcut else self.NIN_0(x)
        if self.skip_rescale:
            return (x + h) / np.sqrt(2.)
        return x + h
|
138 |
+
|
139 |
+
|
140 |
+
class ResnetBlockBigGANpp(nn.Module):
    """BigGAN-style residual block with optional 2x up- or down-sampling.

    Resampling (FIR-filtered when ``fir`` is set, naive otherwise) is applied
    to both the main path and the skip path after the first normalisation and
    activation. The skip path goes through a 1x1 convolution whenever the
    channel count changes or resampling is enabled. ``skip_rescale`` divides
    the residual sum by ``sqrt(2)``.
    """

    def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
                 dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
                 skip_rescale=True, init_scale=0.):
        super().__init__()

        out_ch = out_ch or in_ch
        self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
        self.up = up
        self.down = down
        self.fir = fir
        self.fir_kernel = fir_kernel

        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            dense = nn.Linear(temb_dim, out_ch)
            dense.weight.data = default_init()(dense.weight.shape)
            nn.init.zeros_(dense.bias)
            self.Dense_0 = dense

        self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
        if in_ch != out_ch or up or down:
            self.Conv_2 = conv1x1(in_ch, out_ch)

        self.skip_rescale = skip_rescale
        self.act = act
        self.in_ch = in_ch
        self.out_ch = out_ch

    def forward(self, x, temb=None):
        h = self.act(self.GroupNorm_0(x))

        # Pick the resampling routine; ``up`` takes precedence over ``down``.
        if self.up:
            resample = up_or_down_sampling.upsample_2d if self.fir else up_or_down_sampling.naive_upsample_2d
        elif self.down:
            resample = up_or_down_sampling.downsample_2d if self.fir else up_or_down_sampling.naive_downsample_2d
        else:
            resample = None

        if resample is not None:
            # Only the FIR variants take the filter kernel as a positional argument.
            extra = (self.fir_kernel,) if self.fir else ()
            h = resample(h, *extra, factor=2)
            x = resample(x, *extra, factor=2)

        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None]
        h = self.Dropout_0(self.act(self.GroupNorm_1(h)))
        h = self.Conv_1(h)

        if self.in_ch != self.out_ch or self.up or self.down:
            x = self.Conv_2(x)

        if self.skip_rescale:
            return (x + h) / np.sqrt(2.)
        return x + h
|
geco/backbones/ncsnpp_utils/normalization.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Normalization layers."""
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch
|
19 |
+
import functools
|
20 |
+
|
21 |
+
|
22 |
+
def get_normalization(config, conditional=False):
    """Return the normalization layer constructor named by the config.

    Args:
        config: Object exposing ``config.model.normalization`` (the layer
            name) and, for the conditional case, ``config.model.num_classes``.
        conditional: When true, return a class-conditional layer.

    Returns:
        A callable that builds the normalization module.

    Raises:
        NotImplementedError: ``conditional`` is set and the name is not
            ``'InstanceNorm++'``.
        ValueError: The unconditional name is not recognised.
    """
    norm = config.model.normalization

    if conditional:
        if norm != 'InstanceNorm++':
            raise NotImplementedError(f'{norm} not implemented yet.')
        return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)

    # Unconditional layers, checked one by one.
    if norm == 'InstanceNorm':
        return nn.InstanceNorm2d
    if norm == 'InstanceNorm++':
        return InstanceNorm2dPlus
    if norm == 'VarianceNorm':
        return VarianceNorm2d
    if norm == 'GroupNorm':
        return nn.GroupNorm
    raise ValueError('Unknown normalization: %s' % norm)
|
41 |
+
|
42 |
+
|
43 |
+
class ConditionalBatchNorm2d(nn.Module):
    """Batch normalization with a per-class scale and optional per-class shift.

    The affine parameters live in an embedding table indexed by the class
    label ``y``: each row holds ``[gamma | beta]`` when ``bias`` is true and
    ``gamma`` only otherwise.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.bn = nn.BatchNorm2d(num_features, affine=False)

        row_width = num_features * 2 if self.bias else num_features
        self.embed = nn.Embedding(num_classes, row_width)
        table = self.embed.weight.data
        # The scale columns are filled by `uniform_()` with its default range.
        table[:, :num_features].uniform_()
        if self.bias:
            # Shift columns start at zero.
            table[:, num_features:].zero_()

    def forward(self, x, y):
        """Normalize ``x`` and apply the affine parameters selected by ``y``."""
        normed = self.bn(x)
        affine = self.embed(y)
        if not self.bias:
            return affine.view(-1, self.num_features, 1, 1) * normed
        gamma, beta = affine.chunk(2, dim=1)
        scaled = gamma.view(-1, self.num_features, 1, 1) * normed
        return scaled + beta.view(-1, self.num_features, 1, 1)
|
66 |
+
|
67 |
+
|
68 |
+
class ConditionalInstanceNorm2d(nn.Module):
    """Instance normalization with a per-class scale and optional per-class shift.

    Affine parameters are read from an embedding table indexed by the class
    label ``y``; rows hold ``[gamma | beta]`` with ``bias`` and ``gamma``
    alone without it.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)

        row_width = num_features * 2 if bias else num_features
        self.embed = nn.Embedding(num_classes, row_width)
        table = self.embed.weight.data
        # The scale columns are filled by `uniform_()` with its default range.
        table[:, :num_features].uniform_()
        if bias:
            # Shift columns start at zero.
            table[:, num_features:].zero_()

    def forward(self, x, y):
        """Normalize ``x`` per instance and apply the affine parameters selected by ``y``."""
        normed = self.instance_norm(x)
        affine = self.embed(y)
        if not self.bias:
            return affine.view(-1, self.num_features, 1, 1) * normed
        gamma, beta = affine.chunk(2, dim=-1)
        scaled = gamma.view(-1, self.num_features, 1, 1) * normed
        return scaled + beta.view(-1, self.num_features, 1, 1)
|
91 |
+
|
92 |
+
|
93 |
+
class ConditionalVarianceNorm2d(nn.Module):
    """Divide each channel by its spatial standard deviation, then scale per class.

    The mean is not removed. ``bias`` is stored but not used by ``forward``.
    """

    def __init__(self, num_features, num_classes, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.embed = nn.Embedding(num_classes, num_features)
        # Per-class scales start at N(1, 0.02).
        self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Return ``gamma[y] * x / sqrt(var(x) + 1e-5)`` with the variance over axes 2 and 3."""
        spatial_var = torch.var(x, dim=(2, 3), keepdim=True)
        unit_var = x / torch.sqrt(spatial_var + 1e-5)
        gamma = self.embed(y).view(-1, self.num_features, 1, 1)
        return gamma * unit_var
|
108 |
+
|
109 |
+
|
110 |
+
class VarianceNorm2d(nn.Module):
    """Divide each channel by its spatial standard deviation, then apply a learned scale.

    The mean is not removed. ``bias`` is stored but not used by ``forward``.
    """

    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.alpha = nn.Parameter(torch.zeros(num_features))
        # Learned per-channel scale, initialised at N(1, 0.02).
        self.alpha.data.normal_(1, 0.02)

    def forward(self, x):
        """Return ``alpha * x / sqrt(var(x) + 1e-5)`` with the variance over axes 2 and 3."""
        spatial_var = torch.var(x, dim=(2, 3), keepdim=True)
        unit_var = x / torch.sqrt(spatial_var + 1e-5)
        scale = self.alpha.view(-1, self.num_features, 1, 1)
        return scale * unit_var
|
124 |
+
|
125 |
+
|
126 |
+
class ConditionalNoneNorm2d(nn.Module):
    """Per-class affine transform with no normalization step.

    Rows of the embedding table hold ``[gamma | beta]`` when ``bias`` is true
    and ``gamma`` only otherwise.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias

        row_width = num_features * 2 if bias else num_features
        self.embed = nn.Embedding(num_classes, row_width)
        table = self.embed.weight.data
        # The scale columns are filled by `uniform_()` with its default range.
        table[:, :num_features].uniform_()
        if bias:
            # Shift columns start at zero.
            table[:, num_features:].zero_()

    def forward(self, x, y):
        """Scale (and optionally shift) ``x`` with the parameters selected by ``y``."""
        affine = self.embed(y)
        if not self.bias:
            return affine.view(-1, self.num_features, 1, 1) * x
        gamma, beta = affine.chunk(2, dim=-1)
        scaled = gamma.view(-1, self.num_features, 1, 1) * x
        return scaled + beta.view(-1, self.num_features, 1, 1)
|
147 |
+
|
148 |
+
|
149 |
+
class NoneNorm2d(nn.Module):
    """Identity layer that accepts the same constructor arguments as the other norms."""

    def __init__(self, num_features, bias=True):
        # Both arguments are accepted and ignored.
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
|
155 |
+
|
156 |
+
|
157 |
+
class InstanceNorm2dPlus(nn.Module):
    """Instance normalization that adds back a standardized per-channel mean term.

    The per-channel spatial means are standardized across the channel axis,
    scaled by ``alpha`` and added to the instance-normalized features before
    the ``gamma`` scale and optional ``beta`` shift.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        # alpha and gamma both start at N(1, 0.02); beta (if any) starts at zero.
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Apply the normalization to ``x`` (spatial axes are 2 and 3)."""
        channel_means = torch.mean(x, dim=(2, 3))
        center = torch.mean(channel_means, dim=-1, keepdim=True)
        spread = torch.var(channel_means, dim=-1, keepdim=True)
        channel_means = (channel_means - center) / (torch.sqrt(spread + 1e-5))

        # The mean correction is identical with and without bias.
        h = self.instance_norm(x)
        h = h + channel_means[..., None, None] * self.alpha[..., None, None]

        scaled = self.gamma.view(-1, self.num_features, 1, 1) * h
        if self.bias:
            return scaled + self.beta.view(-1, self.num_features, 1, 1)
        return scaled
|
184 |
+
|
185 |
+
|
186 |
+
class ConditionalInstanceNorm2dPlus(nn.Module):
    """Class-conditional instance normalization with a standardized mean term.

    Each embedding row holds ``[gamma | alpha | beta]`` when ``bias`` is true
    and ``[gamma | alpha]`` otherwise; the row is selected by the label ``y``.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 3)
            table = self.embed.weight.data
            # gamma and alpha columns start at N(1, 0.02); beta columns at zero.
            table[:, :2 * num_features].normal_(1, 0.02)
            table[:, 2 * num_features:].zero_()
        else:
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        """Normalize ``x`` using the parameters selected by ``y`` (spatial axes are 2 and 3)."""
        channel_means = torch.mean(x, dim=(2, 3))
        center = torch.mean(channel_means, dim=-1, keepdim=True)
        spread = torch.var(channel_means, dim=-1, keepdim=True)
        channel_means = (channel_means - center) / (torch.sqrt(spread + 1e-5))
        h = self.instance_norm(x)

        params = self.embed(y)
        if self.bias:
            gamma, alpha, beta = params.chunk(3, dim=-1)
        else:
            gamma, alpha = params.chunk(2, dim=-1)
            beta = None

        h = h + channel_means[..., None, None] * alpha[..., None, None]
        scaled = gamma.view(-1, self.num_features, 1, 1) * h
        if beta is None:
            return scaled
        return scaled + beta.view(-1, self.num_features, 1, 1)
|
geco/backbones/ncsnpp_utils/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2020 The Google Research Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""All functions and modules related to model definition.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
from ...sdes import OUVESDE, OUVPSDE
|
23 |
+
|
24 |
+
|
25 |
+
# Module-level registry mapping a model name to its class.
_MODELS = {}


def register_model(cls=None, *, name=None):
    """Register a model class, either directly or as a decorator factory.

    Args:
        cls: The class to register. When omitted, a decorator is returned so
            the function can be used as ``@register_model(name=...)``.
        name: Registry key; defaults to the class's ``__name__``.

    Returns:
        The class itself, or a decorator that registers and returns it.

    Raises:
        ValueError: The key is already present in the registry.
    """

    def _register(model_cls):
        key = model_cls.__name__ if name is None else name
        if key in _MODELS:
            raise ValueError(f'Already registered model with name: {key}')
        _MODELS[key] = model_cls
        return model_cls

    return _register if cls is None else _register(cls)


def get_model(name):
    """Return the class registered under ``name``; raises ``KeyError`` if absent."""
    return _MODELS[name]
|
49 |
+
|
50 |
+
|
51 |
+
def get_sigmas(sigma_min, sigma_max, num_scales):
    """Return the SMLD noise levels, log-spaced from ``sigma_max`` to ``sigma_min``.

    Args:
        sigma_min: Smallest noise level (last entry of the result).
        sigma_max: Largest noise level (first entry of the result).
        num_scales: Number of noise levels to generate.

    Returns:
        sigmas: a NumPy array of ``num_scales`` noise levels in decreasing order
            when ``sigma_max > sigma_min``.
    """
    log_levels = np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales)
    sigmas = np.exp(log_levels)
    return sigmas
|
62 |
+
|
63 |
+
|
64 |
+
def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    Args:
        config: Object exposing ``config.model.beta_min``,
            ``config.model.beta_max`` and ``config.model.num_scales``.

    Returns:
        A dict with the beta schedule, the derived alpha quantities, the
        rescaled ``beta_min``/``beta_max`` and the number of timesteps.
    """
    # parameters need to be adapted if number of time steps differs from 1000
    num_diffusion_timesteps = 1000
    model_cfg = config.model
    beta_start = model_cfg.beta_min / model_cfg.num_scales
    beta_end = model_cfg.beta_max / model_cfg.num_scales
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)

    return dict(
        betas=betas,
        alphas=alphas,
        alphas_cumprod=alphas_cumprod,
        sqrt_alphas_cumprod=np.sqrt(alphas_cumprod),
        sqrt_1m_alphas_cumprod=np.sqrt(1. - alphas_cumprod),
        beta_min=beta_start * (num_diffusion_timesteps - 1),
        beta_max=beta_end * (num_diffusion_timesteps - 1),
        num_diffusion_timesteps=num_diffusion_timesteps,
    )
|
87 |
+
|
88 |
+
|
89 |
+
def create_model(config):
    """Create the score model.

    Looks up the class registered under ``config.model.name``, instantiates it
    with ``config``, moves it to ``config.device`` and wraps it in
    ``torch.nn.DataParallel``.
    """
    model_cls = get_model(config.model.name)
    on_device = model_cls(config).to(config.device)
    return torch.nn.DataParallel(on_device)
|
96 |
+
|
97 |
+
|
98 |
+
def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
            x: A mini-batch of input data.
            labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
                for different models.

        Returns:
            The model output for ``(x, labels)``.
        """
        # The mode is set on every call, before the forward pass.
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
|
128 |
+
|
129 |
+
|
130 |
+
def get_score_fn(sde, model, train=False, continuous=False):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
        sde: The forward SDE; must be an ``OUVPSDE`` or ``OUVESDE`` instance.
        model: A score model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly take continuous time steps.

    Returns:
        A score function taking ``(x, t)``.

    Raises:
        NotImplementedError: ``sde`` is neither an ``OUVPSDE`` nor an ``OUVESDE``.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, OUVPSDE):
        def score_fn(x, t):
            # For VP-trained models, t=0 corresponds to the lowest noise level.
            # The maximum value of time embedding is assumed to 999 for
            # continuously-trained models; otherwise it is sde.N - 1.
            labels = t * (999 if continuous else sde.N - 1)
            raw = model_fn(x, labels)
            if continuous:
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

            # Scale neural network output by standard deviation and flip sign
            return -raw / std[:, None, None, None]

    elif isinstance(sde, OUVESDE):
        def score_fn(x, t):
            if continuous:
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # For VE-trained models, t=0 corresponds to the highest noise level
                labels = torch.round((sde.T - t) * (sde.N - 1)).long()

            return model_fn(x, labels)

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn
|
180 |
+
|
181 |
+
|
182 |
+
def to_flattened_numpy(x):
    """Detach the torch tensor `x`, move it to the CPU and return it as a 1-D numpy array."""
    as_array = x.detach().cpu().numpy()
    return as_array.reshape((-1,))
|
185 |
+
|
186 |
+
|
187 |
+
def from_flattened_numpy(x, shape):
    """Reshape the flat numpy array `x` to `shape` and wrap it as a torch tensor."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
|
geco/backbones/shared.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from geco.util.registry import Registry
|
8 |
+
|
9 |
+
|
10 |
+
# Module-level `Registry` instance (from geco.util.registry) created with the name "Backbone".
BackboneRegistry = Registry("Backbone")
|
11 |
+
|
12 |
+
|
13 |
+
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""

    def __init__(self, embed_dim, scale=16, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        # A real-valued output concatenates sin and cos of every feature to
        # avoid ambiguities, so only half as many frequencies are drawn. A
        # complex output carries both in one number and keeps the full count.
        num_freqs = embed_dim if complex_valued else embed_dim // 2
        # Frequencies are sampled once here and excluded from optimization.
        self.W = nn.Parameter(torch.randn(num_freqs) * scale, requires_grad=False)

    def forward(self, t):
        """Embed the 1-D batch of time steps ``t``.

        Returns ``exp(1j * phase)`` in the complex case, otherwise the
        concatenation ``[sin(phase), cos(phase)]`` along the last axis, where
        ``phase = t * W * 2 * pi``.
        """
        phase = t[:, None] * self.W[None, :] * 2*np.pi
        if self.complex_valued:
            return torch.exp(1j * phase)
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
|
35 |
+
|
36 |
+
|
37 |
+
class DiffusionStepEmbedding(nn.Module):
    """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017."""

    def __init__(self, embed_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        # A real-valued output concatenates sin and cos of every feature to
        # avoid ambiguities, so the effective dimension is halved. A complex
        # output carries both in one number and needs no halving.
        self.embed_dim = embed_dim if complex_valued else embed_dim // 2

    def forward(self, t):
        """Embed the 1-D batch of steps ``t`` with frequencies ``10**(4*k/(embed_dim-1))``."""
        exponents = 4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1)
        fac = 10**exponents
        inner = t[:, None] * fac[None, :]
        if self.complex_valued:
            return torch.exp(1j * inner)
        return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
|
58 |
+
|
59 |
+
|
60 |
+
class ComplexLinear(nn.Module):
    """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""

    def __init__(self, input_dim, output_dim, complex_valued):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            self.lin = nn.Linear(input_dim, output_dim)
            return
        # Two real layers act as the real and imaginary parts of the layer.
        self.re = nn.Linear(input_dim, output_dim)
        self.im = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the layer; in the complex case, combine ``re``/``im`` by the complex product rule."""
        if not self.complex_valued:
            return self.lin(x)
        real_part = self.re(x.real) - self.im(x.imag)
        imag_part = self.re(x.imag) + self.im(x.real)
        return real_part + 1j*imag_part
|
76 |
+
|
77 |
+
|
78 |
+
class FeatureMapDense(nn.Module):
    """A fully connected layer that reshapes outputs to feature maps."""

    def __init__(self, input_dim, output_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

    def forward(self, x):
        """Apply the dense layer and append two singleton axes to its output."""
        projected = self.dense(x)
        return projected[..., None, None]
|
88 |
+
|
89 |
+
|
90 |
+
def torch_complex_from_reim(re, im):
    """Build a complex tensor from separate real (`re`) and imaginary (`im`) tensors."""
    # Stack along a new trailing axis of size 2, which view_as_complex
    # reads as (real, imaginary) pairs.
    pairs = torch.stack([re, im], dim=-1)
    return torch.view_as_complex(pairs)
|
92 |
+
|
93 |
+
|
94 |
+
class ArgsComplexMultiplicationWrapper(nn.Module):
    """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

    Make a complex-valued module `F` from a real-valued module `f` by applying
    complex multiplication rules:

    F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))

    where `f1`, `f2` are instances of `f` that do *not* share weights.

    Args:
        module_cls (callable): A class or function that returns a Torch module/functional.
            Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
            to construct the real and imaginary component modules.
    """

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        """Apply the wrapped modules to complex ``x``; extra arguments are forwarded to every call."""
        real_part = self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs)
        imag_part = self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs)
        # Stack as trailing (real, imaginary) pairs and reinterpret as complex.
        return torch.view_as_complex(torch.stack([real_part, imag_part], dim=-1))
|
120 |
+
|
121 |
+
|
122 |
+
# Constructors for complex-valued convolutions: each call builds an
# ArgsComplexMultiplicationWrapper around two instances of the given nn layer.
ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
|
geco/data_module.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join
|
3 |
+
import torch
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
from glob import glob
|
8 |
+
import numpy as np
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
|
13 |
+
def get_window(window_type, window_length):
    """Return a periodic Hann window, or its square root, of ``window_length`` samples.

    Args:
        window_type: ``'hann'`` or ``'sqrthann'``.
        window_length: Number of samples in the window.

    Raises:
        NotImplementedError: ``window_type`` is not one of the two supported names.
    """
    if window_type not in ('sqrthann', 'hann'):
        raise NotImplementedError(f"Window type {window_type} not implemented!")
    hann = torch.hann_window(window_length, periodic=True)
    return torch.sqrt(hann) if window_type == 'sqrthann' else hann
|
20 |
+
|
21 |
+
|
22 |
+
class Specs(Dataset):
    """Dataset of (clean, noisy, mixture) spectrogram triplets read from one directory.

    With ``format="default"`` the directory is searched for ``*_source1hatP.wav``
    and ``*_source2hatP.wav`` (the "noisy" inputs). The matching clean and
    mixture paths are derived by replacing that suffix with
    ``_source1.wav`` / ``_source2.wav`` and ``_mix.wav``.

    Args:
        data_dir: Directory containing the wav files.
        dummy: When true, ``len()`` reports 1/200 of the files (debugging aid).
        shuffle_spec: When true, crop a random excerpt; otherwise the centred one.
        num_frames: Number of STFT frames per item; sets the excerpt length.
        sampling_rate: Files with another rate are resampled to this one.
        format: Directory naming scheme; only ``'default'`` is implemented.
        normalize: ``'noisy'``, ``'clean'`` or ``'not'`` -- which signal's peak
            is used to scale all three waveforms.
        spec_transform: Callable applied to each STFT.
        stft_kwargs: Keyword arguments for ``torch.stft``; must contain
            ``n_fft``, ``hop_length``, ``center`` and ``window``, with
            ``center`` true.

    Raises:
        NotImplementedError: ``format`` is not ``'default'``.
        ValueError: ``stft_kwargs`` or ``normalize`` is invalid.
    """

    _NORMALIZE_MODES = ("noisy", "clean", "not")
    _REQUIRED_STFT_KEYS = ("n_fft", "hop_length", "center", "window")

    def __init__(self, data_dir, dummy, shuffle_spec, num_frames, sampling_rate=8000,
                 format='default', normalize="noisy", spec_transform=None,
                 stft_kwargs=None, **ignored_kwargs):
        # Read file paths according to file naming format.
        if format != "default":
            # Feel free to add your own directory format
            raise NotImplementedError(f"Directory format {format} unknown!")

        mixture_files, noisy_files, clean_files = [], [], []
        for source in ("source1", "source2"):
            suffix = f'_{source}hatP.wav'
            found = sorted(glob(os.path.join(data_dir, '*' + suffix)))
            noisy_files.extend(found)
            clean_files.extend(item.replace(suffix, f'_{source}.wav') for item in found)
            mixture_files.extend(item.replace(suffix, '_mix.wav') for item in found)
        self.mixture_files = mixture_files
        self.noisy_files = noisy_files
        self.clean_files = clean_files

        # Explicit raises instead of asserts: asserts vanish under `python -O`,
        # and the default stft_kwargs=None must fail with a clear message.
        if stft_kwargs is None or not all(k in stft_kwargs for k in self._REQUIRED_STFT_KEYS):
            raise ValueError("misconfigured STFT kwargs")
        if stft_kwargs.get("center", None) != True:  # noqa: E712 - keeps the original `== True` semantics
            raise ValueError("'center' must be True for current implementation")
        # Reject unknown modes here rather than failing later in __getitem__.
        if normalize not in self._NORMALIZE_MODES:
            raise ValueError(
                f"normalize must be one of {self._NORMALIZE_MODES}, got {normalize!r}")

        self.dummy = dummy
        self.num_frames = num_frames
        self.shuffle_spec = shuffle_spec
        self.normalize = normalize
        self.spec_transform = spec_transform
        self.sampling_rate = sampling_rate
        self.stft_kwargs = stft_kwargs
        self.hop_length = self.stft_kwargs["hop_length"]

    def _load(self, path):
        """Load the audio at ``path`` and resample it to ``self.sampling_rate`` if needed."""
        audio, sr = torchaudio.load(path)
        if sr != self.sampling_rate:
            audio = torchaudio.transforms.Resample(sr, self.sampling_rate)(audio)
        return audio

    def _norm_factor(self, x, y):
        """Return the divisor applied to all three waveforms for the chosen ``normalize`` mode."""
        if self.normalize == "not":
            return 1.0
        if self.normalize == "noisy":
            reference = y
        elif self.normalize == "clean":
            reference = x
        else:
            raise ValueError(
                f"normalize must be one of {self._NORMALIZE_MODES}, got {self.normalize!r}")
        peak = reference.abs().max()
        # An all-zero reference would turn the division into 0/0 (NaN);
        # leave such clips unscaled instead.
        return peak if peak > 0 else 1.0

    def __getitem__(self, i):
        """Return the transformed STFTs ``(X, Y, M)`` of the clean, noisy and mixture signals."""
        x = self._load(self.clean_files[i])
        y = self._load(self.noisy_files[i])
        m = self._load(self.mixture_files[i])

        # Trim all three signals to their common length.
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x, y, m = x[..., :min_leng], y[..., :min_leng], m[..., :min_leng]

        # formula applies for center=True
        target_len = (self.num_frames - 1) * self.hop_length
        current_len = x.size(-1)
        pad = max(target_len - current_len, 0)
        if pad == 0:
            # extract random part of the audio file
            if self.shuffle_spec:
                start = int(np.random.uniform(0, current_len-target_len))
            else:
                start = int((current_len-target_len)/2)

            # If the chosen excerpt of the noisy signal is near-silent,
            # fall back to the start of the file.
            if y[..., start:start+target_len].abs().max() < 0.05:
                start = 0

            excerpt = slice(start, start + target_len)
            x, y, m = x[..., excerpt], y[..., excerpt], m[..., excerpt]
        else:
            # pad audio if the length T is smaller than num_frames
            padding = (pad//2, pad//2+(pad % 2))
            x = F.pad(x, padding, mode='constant')
            y = F.pad(y, padding, mode='constant')
            m = F.pad(m, padding, mode='constant')

        # normalize w.r.t to the noisy or the clean signal or not at all
        # to ensure same clean signal power in x and y.
        normfac = self._norm_factor(x, y)
        x = x / normfac
        y = y / normfac
        m = m / normfac

        X = torch.stft(x, **self.stft_kwargs)
        Y = torch.stft(y, **self.stft_kwargs)
        M = torch.stft(m, **self.stft_kwargs)
        X, Y, M = self.spec_transform(X), self.spec_transform(Y), self.spec_transform(M)
        return X, Y, M

    def __len__(self):
        """Return the number of items (1/200 of the files when ``dummy`` is set)."""
        if self.dummy:
            # for debugging shrink the data set size
            return int(len(self.clean_files)/200)
        return len(self.clean_files)
|
117 |
+
|
118 |
+
|
119 |
+
class SpecsDataModule(pl.LightningDataModule):
|
120 |
+
@staticmethod
def add_argparse_args(parser):
    """Register the data-module command-line options on ``parser`` and return it.

    Help texts interpolate ``%(default)s`` so that the documented default
    always matches the actual one.
    """
    parser.add_argument("--train_dir", type=str, default='/export/corpora7/HW/speechbrain/recipes/LibriMix/separation/2025/save/libri2mix-train100')
    parser.add_argument("--val_dir", type=str, default='/export/corpora7/HW/speechbrain/recipes/LibriMix/separation/2025/save/libri2mix-dev')
    parser.add_argument("--test_dir", type=str, default='/export/corpora7/HW/speechbrain/recipes/LibriMix/separation/2025/save/libri2mix-test')
    parser.add_argument("--format", type=str, default="default", help="Read file paths according to file naming format.")
    parser.add_argument("--sampling_rate", type=int, default=8000, help="The sampling rate.")
    parser.add_argument("--batch_size", type=int, default=16, help="The batch size. %(default)s by default.")
    # 510 // 2 + 1 = 256 frequency bins, assuming the STFT is one-sided.
    parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. %(default)s by default.")
    parser.add_argument("--hop_length", type=int, default=64, help="Window hop length. %(default)s by default.")
    parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. %(default)s by default.")
    parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. '%(default)s' by default.")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers to use for DataLoaders. %(default)s by default.")
    parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
    parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. %(default)s by default.")
    parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). %(default)s by default.")
    parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
    parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectogram transformation for input representation.")
    return parser
|
139 |
+
|
140 |
+
def __init__(
    self, train_dir, val_dir, test_dir, format='default', sampling_rate=8000, batch_size=8,
    n_fft=510, hop_length=64, num_frames=256, window='hann',
    num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
    gpu=True, normalize='noisy', transform_type="exponent", **kwargs
):
    """Record dataset locations, STFT settings and spectrogram-transform settings.

    Every argument is stored on the instance under its own name, except
    `window`, which is resolved through `get_window(window, n_fft)` first.
    Unrecognised keyword arguments are kept in `self.kwargs` and forwarded
    to the dataset constructor by `setup`.
    """
    super().__init__()

    # Where the data lives and how it is read.
    self.train_dir, self.val_dir, self.test_dir = train_dir, val_dir, test_dir
    self.format = format
    self.sampling_rate = sampling_rate
    self.dummy = dummy
    self.normalize = normalize

    # Loader configuration.
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.gpu = gpu

    # STFT configuration; `windows` is the per-device cache filled by `_get_window`.
    self.n_fft = n_fft
    self.hop_length = hop_length
    self.num_frames = num_frames
    self.window = get_window(window, self.n_fft)
    self.windows = {}

    # Spectrogram compression applied by `spec_fwd` / undone by `spec_back`.
    self.spec_factor = spec_factor
    self.spec_abs_exponent = spec_abs_exponent
    self.transform_type = transform_type

    self.kwargs = kwargs
|
166 |
+
|
167 |
+
def setup(self, stage=None):
    """Instantiate the `Specs` datasets required for `stage`.

    'fit' builds the train and validation sets, 'test' builds the test set,
    and `None` builds all three.
    """
    shared = dict(
        stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
        spec_transform=self.spec_fwd, **self.kwargs
    )

    def build(data_dir, shuffle_spec):
        # All three splits differ only in directory and in whether crops are shuffled.
        return Specs(
            data_dir=data_dir, dummy=self.dummy, shuffle_spec=shuffle_spec,
            format=self.format, normalize=self.normalize,
            sampling_rate=self.sampling_rate, **shared
        )

    if stage in ('fit', None):
        self.train_set = build(self.train_dir, True)
        self.valid_set = build(self.val_dir, False)
    if stage in ('test', None):
        self.test_set = build(self.test_dir, False)
|
183 |
+
|
184 |
+
def spec_fwd(self, spec):
    """Apply the configured forward transform to a complex spectrogram.

    "exponent": magnitude is raised to `spec_abs_exponent` (phase kept), then
    the result is scaled by `spec_factor`.
    "log": magnitude becomes log(1 + |z|) (phase kept), then scaled by
    `spec_factor`.
    Any other `transform_type` (including "none") returns `spec` unchanged.
    """
    kind = self.transform_type
    if kind == "exponent":
        exponent = self.spec_abs_exponent
        if exponent != 1:
            # Skipped for exponent 1: it would be wasted work and only add numerical error.
            magnitude = spec.abs() ** exponent
            spec = magnitude * torch.exp(1j * spec.angle())
        return spec * self.spec_factor
    if kind == "log":
        compressed = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
        return compressed * self.spec_factor
    return spec
|
198 |
+
|
199 |
+
def spec_back(self, spec):
    """Undo `spec_fwd` for the configured `transform_type`.

    "exponent": divide by `spec_factor`, then raise the magnitude to
    1 / `spec_abs_exponent` (phase kept).
    "log": divide by `spec_factor`, then map the magnitude through exp(|z|) - 1
    (phase kept).
    Any other `transform_type` (including "none") returns `spec` unchanged.
    """
    kind = self.transform_type
    if kind == "exponent":
        spec = spec / self.spec_factor
        exponent = self.spec_abs_exponent
        if exponent != 1:
            spec = spec.abs() ** (1 / exponent) * torch.exp(1j * spec.angle())
        return spec
    if kind == "log":
        spec = spec / self.spec_factor
        return (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
    return spec
|
211 |
+
|
212 |
+
@property
def stft_kwargs(self):
    """Keyword arguments for `torch.stft`: the iSTFT ones plus `return_complex=True`."""
    options = dict(self.istft_kwargs)
    options["return_complex"] = True
    return options
|
215 |
+
|
216 |
+
@property
def istft_kwargs(self):
    """Keyword arguments shared by the STFT and iSTFT calls (always `center=True`)."""
    return {
        "n_fft": self.n_fft,
        "hop_length": self.hop_length,
        "window": self.window,
        "center": True,
    }
|
222 |
+
|
223 |
+
def _get_window(self, x):
|
224 |
+
"""
|
225 |
+
Retrieve an appropriate window for the given tensor x, matching the device.
|
226 |
+
Caches the retrieved windows so that only one window tensor will be allocated per device.
|
227 |
+
"""
|
228 |
+
window = self.windows.get(x.device, None)
|
229 |
+
if window is None:
|
230 |
+
window = self.window.to(x.device)
|
231 |
+
self.windows[x.device] = window
|
232 |
+
return window
|
233 |
+
|
234 |
+
def stft(self, sig):
    """Compute the complex STFT of `sig` using a window on `sig`'s device."""
    device_window = self._get_window(sig)
    options = dict(self.stft_kwargs)
    # Replace the stored window with the device-matched one.
    options["window"] = device_window
    return torch.stft(sig, **options)
|
237 |
+
|
238 |
+
def istft(self, spec, length=None):
    """Invert a complex STFT; `length` is forwarded to `torch.istft` as-is."""
    device_window = self._get_window(spec)
    options = dict(self.istft_kwargs)
    # Replace the stored window with the device-matched one.
    options["window"] = device_window
    options["length"] = length
    return torch.istft(spec, **options)
|
241 |
+
|
242 |
+
def train_dataloader(self):
    """Build the shuffled `DataLoader` over `self.train_set`."""
    loader_options = {
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
        "pin_memory": self.gpu,
        "shuffle": True,
    }
    return DataLoader(self.train_set, **loader_options)
|
247 |
+
|
248 |
+
def val_dataloader(self):
    """Build the in-order `DataLoader` over `self.valid_set`."""
    loader_options = {
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
        "pin_memory": self.gpu,
        "shuffle": False,
    }
    return DataLoader(self.valid_set, **loader_options)
|
253 |
+
|
254 |
+
def test_dataloader(self):
|
255 |
+
return DataLoader(
|
256 |
+
self.test_set, batch_size=self.batch_size,
|
257 |
+
num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
|
258 |
+
)
|
geco/model.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from math import ceil
|
3 |
+
import warnings
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
from torch_ema import ExponentialMovingAverage
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from geco import sampling
|
10 |
+
from geco.sdes import SDERegistry
|
11 |
+
from geco.backbones import BackboneRegistry
|
12 |
+
from geco.util.inference import evaluate_model
|
13 |
+
from geco.util.other import pad_spec
|
14 |
+
import numpy as np
|
15 |
+
import matplotlib.pyplot as plt
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class ScoreModel(pl.LightningModule):
    """Lightning module wrapping a backbone DNN, an SDE and an EMA of the weights.

    The network receives the current state `x`, the conditioning spectrogram
    `y` and an additional conditioning tensor `m`, concatenated along dim 1.
    """

    @staticmethod
    def add_argparse_args(parser):
        """Register the model's command-line options on `parser` and return it."""
        parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
        parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
        parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
        parser.add_argument("--loss_type", type=str, default="mse", help="The type of loss function to use.")
        parser.add_argument("--loss_abs_exponent", type=float, default=0.5, help="magnitude transformation in the loss term")
        return parser

    def __init__(
        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, loss_abs_exponent=0.5,
        num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
    ):
        """
        Create a new ScoreModel.

        Args:
            backbone: Registry name of the backbone DNN that serves as a score-based model.
            sde: Registry name of the SDE that defines the diffusion process.
            lr: The learning rate of the optimizer (1e-4 by default).
            ema_decay: The decay constant of the parameter EMA (0.999 by default).
            t_eps: The minimum time to practically run for, to avoid issues very close to zero (3e-2 by default).
            loss_abs_exponent: Stored on the instance; not read elsewhere in this class.
            num_eval_files: Number of files passed to `evaluate_model` during validation; 0 disables it.
            loss_type: The type of loss to use (wrt. noise z/std). Options are 'mse' (default), 'mae'.
            data_module_cls: Class instantiated with `**kwargs` to provide data loading and STFT helpers.
            **kwargs: Forwarded to the backbone, the SDE and the data module constructors.
        """
        super().__init__()
        # Initialize Backbone DNN
        dnn_cls = BackboneRegistry.get_by_name(backbone)
        self.dnn = dnn_cls(**kwargs)
        # Initialize SDE
        if sde == 'bbve':
            # change parameters, if the old class bbve is used. Needed for loading the provided checkpoint
            # as that checkpoint was trained with the old class.
            sde = 'bbed'
            kwargs['k'] = kwargs.pop('sigma_max')
            kwargs.pop('sigma_min', None)

        sde_cls = SDERegistry.get_by_name(sde)
        self.sde = sde_cls(**kwargs)
        # Store hyperparams and save them
        self.lr = lr
        self.ema_decay = ema_decay
        self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
        self._error_loading_ema = False
        self.t_eps = t_eps
        self.loss_type = loss_type
        self.num_eval_files = num_eval_files
        self.loss_abs_exponent = loss_abs_exponent
        self.save_hyperparameters(ignore=['no_wandb'])
        self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate `self.lr`."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def optimizer_step(self, *args, **kwargs):
        # Method overridden so that the EMA params are updated after each optimizer step
        super().optimizer_step(*args, **kwargs)
        self.ema.update(self.parameters())

    # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state from `checkpoint`, or flag that it is missing."""
        ema = checkpoint.get('ema', None)
        if ema is not None:
            self.ema.load_state_dict(ema)
        else:
            self._error_loading_ema = True
            warnings.warn("EMA state_dict not found in checkpoint!")

    def on_save_checkpoint(self, checkpoint):
        """Add the EMA state to `checkpoint` under the 'ema' key."""
        checkpoint['ema'] = self.ema.state_dict()

    def train(self, mode=True, no_ema=False):
        """Switch train/eval mode and swap the EMA weights in or out.

        `mode` defaults to True so that a bare `model.train()` call works, as
        it does for `torch.nn.Module.train`. Entering eval mode (unless
        `no_ema` is set) stores the current weights and copies the EMA weights
        over them; any other call restores previously stored weights, if any.
        Nothing is swapped when the EMA state failed to load from a checkpoint.
        """
        res = super().train(mode)  # call the standard `train` method with the given mode
        if not self._error_loading_ema:
            if not mode and not no_ema:
                # eval
                self.ema.store(self.parameters())        # store current params in EMA
                self.ema.copy_to(self.parameters())      # copy EMA parameters over current params for evaluation
            else:
                # train
                if self.ema.collected_params is not None:
                    self.ema.restore(self.parameters())  # restore the EMA weights (if stored)
        return res

    def eval(self, no_ema=False):
        """Enter eval mode; pass `no_ema=True` to keep the raw (non-EMA) weights."""
        return self.train(False, no_ema=no_ema)

    def _loss(self, score, sigmas, z):
        """Denoising score-matching loss.

        The error `sigmas * score + z` is computed once for every loss type
        (previously it only existed in the 'mse' branch, so 'mae' failed).

        Raises:
            ValueError: if `self.loss_type` is neither 'mse' nor 'mae'.
        """
        err = sigmas * score + z
        if self.loss_type == 'mse':
            losses = torch.square(err.abs())
        elif self.loss_type == 'mae':
            losses = err.abs()
        else:
            raise ValueError("Unknown loss_type {!r}; expected 'mse' or 'mae'.".format(self.loss_type))
        # taken from reduce_op function: sum over channels and position and mean over batch dim
        # presumably only important for absolute loss number, not for gradients
        loss = torch.mean(0.5 * torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
        return loss

    def _step(self, batch, batch_idx):
        """Sample a random time per item, perturb `x` accordingly and return the loss."""
        x, y, m = batch
        rdm = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
        # Guard against t exceeding T; clamping avoids building a separate CPU tensor.
        t = torch.clamp(rdm, max=self.sde.T)
        mean, std = self.sde.marginal_prob(x, t, y)
        z = torch.randn_like(x)
        sigmas = std[:, None, None, None]
        perturbed_data = mean + sigmas * z
        score = self(perturbed_data, t, y, m)
        loss = self._loss(score, sigmas, z)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch, batch_idx)
        self.log('valid_loss', loss, on_step=False, on_epoch=True)

        # Evaluate speech enhancement performance (once per validation run, on the first batch)
        if batch_idx == 0 and self.num_eval_files != 0:
            pesq, si_sdr, estoi = evaluate_model(self, self.num_eval_files)
            self.log('pesq', pesq, on_step=False, on_epoch=True)
            self.log('si_sdr', si_sdr, on_step=False, on_epoch=True)
            self.log('estoi', estoi, on_step=False, on_epoch=True)

        return loss

    def forward(self, x, t, y, m):
        """Evaluate the score network on `x` conditioned on `y` and `m` at time `t`."""
        # Concatenate y and m as extra channels
        dnn_input = torch.cat([x, y, m], dim=1)
        # the minus is most likely unimportant here - taken from Song's repo
        score = -self.dnn(dnn_input, t)
        return score

    def to(self, *args, **kwargs):
        """Override PyTorch .to() to also transfer the EMA of the model weights"""
        self.ema.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def get_pc_sampler(self, predictor_name, corrector_name, y, m, Y_prior=None, N=None, minibatch=None, timestep_type=None, **kwargs):
        """Build a predictor-corrector sampling function.

        Without `minibatch`, the sampler from `sampling.get_pc_sampler` is
        returned directly. With `minibatch`, a function is returned that runs
        the sampler over consecutive slices of the batch and concatenates the
        samples; it returns `(samples, ns)` where `ns` lists the per-slice
        evaluation counts.
        """
        N = self.sde.N if N is None else N
        sde = self.sde.copy()
        sde.N = N

        kwargs = {"eps": self.t_eps, **kwargs}
        if minibatch is None:
            return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, Y=y, M=m, Y_prior=Y_prior, timestep_type=timestep_type, **kwargs)

        num_items = y.shape[0]

        def batched_sampling_fn():
            samples, ns = [], []
            for start in range(0, num_items, minibatch):
                chunk = slice(start, start + minibatch)
                # y, m and the optional prior are sliced together: forward()
                # concatenates them along dim 1, so their batch sizes must agree.
                y_prior_mini = None if Y_prior is None else Y_prior[chunk]
                sampler = sampling.get_pc_sampler(
                    predictor_name, corrector_name, sde=sde, score_fn=self,
                    Y=y[chunk], M=m[chunk], Y_prior=y_prior_mini,
                    timestep_type=timestep_type, **kwargs
                )
                sample, n = sampler()
                samples.append(sample)
                ns.append(n)
            samples = torch.cat(samples, dim=0)
            return samples, ns
        return batched_sampling_fn

    def train_dataloader(self):
        return self.data_module.train_dataloader()

    def val_dataloader(self):
        return self.data_module.val_dataloader()

    def test_dataloader(self):
        return self.data_module.test_dataloader()

    def setup(self, stage=None):
        return self.data_module.setup(stage=stage)

    def to_audio(self, spec, length=None):
        """Map a transformed spectrogram back to a waveform."""
        return self._istft(self._backward_transform(spec), length)

    def _forward_transform(self, spec):
        return self.data_module.spec_fwd(spec)

    def _backward_transform(self, spec):
        return self.data_module.spec_back(spec)

    def _stft(self, sig):
        return self.data_module.stft(sig)

    def _istft(self, spec, length=None):
        return self.data_module.istft(spec, length)

    def enhance(self, y, m, sampler_type="pc", predictor="reverse_diffusion",
        corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
        **kwargs
    ):
        """
        One-call speech enhancement of noisy speech `y`, for convenience.

        `y` and `m` are waveforms; both are scaled by the peak magnitude of
        `y`, transformed to spectrograms, sampled from, and the result is
        converted back to a NumPy waveform rescaled by the same factor.

        Returns:
            `x_hat`, or `(x_hat, nfe, rtf)` when `timeit` is true.

        Raises:
            ValueError: if `sampler_type` is not "pc".
        """
        if sampler_type != "pc":
            # Previously only printed, after which the unbound `sampler` crashed.
            raise ValueError("{} is not a valid sampler type!".format(sampler_type))

        sr = 8000  # sampling rate assumed for the real-time-factor computation
        start = time.time()
        # NOTE(review): unused - to_audio() below is called without a length, so the
        # output is not trimmed to the input length. Kept because it still requires
        # `y` to have a second dimension.
        T_orig = y.size(1)
        norm_factor = y.abs().max().item()
        if norm_factor == 0:
            # All-zero input: skip the scaling instead of dividing by zero.
            norm_factor = 1.0
        y = y / norm_factor
        m = m / norm_factor

        # Run on whatever device the model lives on rather than assuming CUDA.
        device = self.device
        Y = torch.unsqueeze(self._forward_transform(self._stft(y.to(device))), 0)
        Y = pad_spec(Y)
        M = torch.unsqueeze(self._forward_transform(self._stft(m.to(device))), 0)
        M = pad_spec(M)

        sampler = self.get_pc_sampler(predictor, corrector, Y, M, N=N,
            corrector_steps=corrector_steps, snr=snr, intermediate=False,
            **kwargs)
        sample, nfe = sampler()

        sample = sample.squeeze()

        x_hat = self.to_audio(sample)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()
        end = time.time()
        if timeit:
            rtf = (end - start) / (len(x_hat) / sr)
            return x_hat, nfe, rtf
        else:
            return x_hat
|
255 |
+
|
geco/sampling/__init__.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Various sampling methods."""
|
2 |
+
from scipy import integrate
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
|
6 |
+
from .correctors import Corrector, CorrectorRegistry
|
7 |
+
import numpy as np
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
|
11 |
+
# Public API of the sampling package. Every name listed here must exist in this
# module, otherwise `from geco.sampling import *` raises AttributeError; the
# sampler factory defined below is `get_pc_sampler` (there is no `get_sampler`).
__all__ = [
    'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
    'get_pc_sampler'
]
|
15 |
+
|
16 |
+
|
17 |
+
def to_flattened_numpy(x):
    """Return torch tensor `x` as a one-dimensional numpy array (detached, on CPU)."""
    as_numpy = x.detach().cpu().numpy()
    return as_numpy.reshape(-1)
|
20 |
+
|
21 |
+
|
22 |
+
def from_flattened_numpy(x, shape):
    """Reshape flattened numpy array `x` to `shape` and wrap it as a torch tensor."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
|
25 |
+
|
26 |
+
|
27 |
+
def get_pc_sampler(
    predictor_name, corrector_name, sde, score_fn, Y, M, Y_prior=None,
    denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
    intermediate=False, timestep_type=None, **kwargs
):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        predictor_name: The name of a registered `sampling.Predictor`.
        corrector_name: The name of a registered `sampling.Corrector`.
        sde: An `sdes.SDE` object representing the forward SDE.
        score_fn: A function (typically learned model) that predicts the score.
        Y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
        M: Additional conditioning tensor passed through to the predictor and corrector.
        Y_prior: Optional tensor to draw the initial state around; defaults to `Y`.
        denoise: If `True`, return the noise-free mean of the last step instead of the noisy state.
        eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
        snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
        corrector_steps: Number of corrector steps per time step.
        probability_flow: Passed to the predictor.
        intermediate: Not supported; must be false.
        timestep_type: Passed to `timesteps_space` as its `type` argument.
        **kwargs: Accepted and ignored.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.

    Raises:
        NotImplementedError: if `intermediate` is true (no intermediate sampler exists in this module).
    """
    if intermediate:
        # The original returned an undefined name here and failed with NameError.
        raise NotImplementedError("A PC sampler returning intermediate states is not implemented.")

    predictor_cls = PredictorRegistry.get_by_name(predictor_name)
    corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
    predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
    corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)

    def pc_sampler(Y_prior=Y_prior, timestep_type=timestep_type):
        """The PC sampler function."""
        with torch.no_grad():
            # Identity check: `== None` on a tensor is an elementwise-style comparison.
            if Y_prior is None:
                Y_prior = Y

            xt, _ = sde.prior_sampling(Y_prior.shape, Y_prior)
            timesteps = timesteps_space(sde.T, sde.N, eps, Y.device, type=timestep_type)
            xt = xt.to(Y_prior.device)
            last = len(timesteps) - 1
            for i, t in enumerate(timesteps):
                # Step to the next grid point; the final step uses the last time value itself.
                stepsize = t - timesteps[i + 1] if i != last else timesteps[-1]
                vec_t = torch.ones(Y.shape[0], device=Y.device) * t
                xt, xt_mean = corrector.update_fn(xt, vec_t, Y, M)
                xt, xt_mean = predictor.update_fn(xt, vec_t, Y, M, stepsize)
            x_result = xt_mean if denoise else xt
            ns = len(timesteps) * (corrector.n_steps + 1)
            return x_result, ns

    return pc_sampler
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
def timesteps_space(sdeT, sdeN, eps, device, type='linear'):
    """Return `sdeN` time points running from `sdeT` down to `eps` on `device`.

    Only a linear grid exists: every value of `type` currently yields the same
    result. The parameter is the hook for other sampling schedules.
    """
    # `type` is intentionally not inspected; see the docstring.
    return torch.linspace(sdeT, eps, sdeN, device=device)
|
geco/sampling/correctors.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from geco import sdes
|
5 |
+
from geco.util.registry import Registry
|
6 |
+
|
7 |
+
|
8 |
+
CorrectorRegistry = Registry("Corrector")
|
9 |
+
|
10 |
+
|
11 |
+
class Corrector(abc.ABC):
    """Base class every corrector algorithm derives from."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.score_fn = score_fn
        self.snr, self.n_steps = snr, n_steps
        # Reverse-time process derived from the forward SDE and the score function.
        self.rsde = sde.reverse(score_fn)

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """Perform a single corrector update.

        Args:
            x: A PyTorch tensor holding the current state.
            t: A PyTorch tensor holding the current time step.
            *args: Further arguments where needed, in particular `y` for OU processes.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
|
35 |
+
|
36 |
+
|
37 |
+
@CorrectorRegistry.register(name='ald')
class AnnealedLangevinDynamics(Corrector):
    """The original annealed Langevin dynamics predictor in NCSN/NCSNv2."""

    def __init__(self, sde, score_fn, snr, n_steps):
        # The base class already stores score_fn, snr and n_steps; only the
        # forward SDE needs to be kept in addition.
        super().__init__(sde, score_fn, snr, n_steps)
        self.sde = sde

    def update_fn(self, x, t, y, m):
        """Run `n_steps` Langevin updates on `x`.

        Returns `(x, x_mean)`; `x_mean` is the last state before noise was
        added, or 0 when `n_steps` is 0.
        """
        x_mean = 0
        _, std = self.sde.marginal_prob(x, t, y)
        for _ in range(self.n_steps):
            score = self.score_fn(x, t, y, m)
            noise = torch.randn_like(x)
            step_size = (self.snr * std) ** 2 * 2
            drift_scale = step_size[:, None, None, None]
            noise_scale = torch.sqrt(step_size * 2)[:, None, None, None]
            x_mean = x + drift_scale * score
            x = x_mean + noise * noise_scale

        return x, x_mean
|
geco/sampling/predictors.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from geco.util.registry import Registry
|
7 |
+
|
8 |
+
|
9 |
+
PredictorRegistry = Registry("Predictor")
|
10 |
+
|
11 |
+
|
12 |
+
class Predictor(abc.ABC):
    """Base class every predictor algorithm derives from."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde, self.score_fn = sde, score_fn
        self.probability_flow = probability_flow
        # Reverse-time process derived from the forward SDE and the score function.
        self.rsde = sde.reverse(score_fn)

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """Perform a single predictor update.

        Args:
            x: A PyTorch tensor holding the current state.
            t: A PyTorch tensor holding the current time step.
            *args: Further arguments where needed, in particular `y` for OU processes.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """

    def debug_update_fn(self, x, t, *args):
        """Optional debugging variant of `update_fn`; not provided by the base class."""
        raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
|
39 |
+
|
40 |
+
|
41 |
+
@PredictorRegistry.register('reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
    """Predictor that takes one discretized step of the reverse-time SDE."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow=probability_flow)

    def update_fn(self, x, t, y, m, stepsize):
        """Advance `x` by one reverse step of size `stepsize`.

        Returns `(x, x_mean)`, where `x_mean` is the step without the noise term.
        """
        drift, diffusion = self.rsde.discretize(x, t, y, m, stepsize)
        noise = torch.randn_like(x)
        x_mean = x - drift
        return x_mean + diffusion[:, None, None, None] * noise, x_mean

    def update_fn_analyze(self, x, t, *args):
        raise NotImplementedError("update_fn_analyze() has not been implemented yet for the ReverseDiffusionPredictor")
|
55 |
+
|
geco/sdes.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
|
3 |
+
|
4 |
+
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
|
5 |
+
"""
|
6 |
+
import abc
|
7 |
+
import warnings
|
8 |
+
import math
|
9 |
+
import scipy.special as sc
|
10 |
+
import numpy as np
|
11 |
+
from geco.util.tensors import batch_broadcast
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from geco.util.registry import Registry
|
15 |
+
|
16 |
+
|
17 |
+
SDERegistry = Registry("SDE")
|
18 |
+
|
19 |
+
|
20 |
+
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the `(drift, diffusion)` pair of the SDE at state `x` and time `t`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch tensor representing the time step (from 0 to `self.T`)
            y: conditioning tensor forwarded to `self.sde`
            stepsize: the step length dt, as a Python number or a tensor

        Returns:
            f, G
        """
        dt = stepsize
        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        # `as_tensor` accepts both numbers and tensors; `torch.tensor(dt)` on a
        # tensor step size copies it and emits a copy-construct UserWarning.
        G = diffusion * torch.sqrt(torch.as_tensor(dt, device=t.device))
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t, y and m and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                # Deliberately does not call the parent constructor: only N and
                # the probability-flow switch are needed on the reverse process.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def rsde_parts(self, x, t, y, m):
                """Return the components of the reverse drift and diffusion.

                `sde()` above relies on this method, which was missing. The
                argument split mirrors `discretize`: the forward SDE gets
                `(x, t, y)`, the score model `(x, t, y, m)`.
                """
                sde_drift, sde_diffusion = sde_fn(x, t, y)
                score = score_model(x, t, y, m)
                score_drift = -sde_diffusion[:, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                total_drift = sde_drift + score_drift
                return {
                    "total_drift": total_drift, "diffusion": diffusion,
                    "sde_drift": sde_drift, "sde_diffusion": sde_diffusion,
                    "score_drift": score_drift, "score": score,
                }

            def discretize(self, x, t, y, m, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t, y, stepsize)
                if torch.is_complex(G):
                    G = G.imag
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y, m) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new SDE object equivalent to this one."""
        pass
|
135 |
+
|
136 |
+
|
137 |
+
@SDERegistry.register("bbed")
|
138 |
+
class BBED(SDE):
|
139 |
+
@staticmethod
def add_argparse_args(parser):
    """Register the BBED command-line options on `parser` and return it."""
    options = (
        ("--sde-n", dict(type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")),
        ("--T_sampling", dict(type=float, default=0.999, help="The T so that t < T during sampling in the train step.")),
        ("--k", dict(type=float, default=2.6, help="base factor for diffusion term")),
        ("--theta", dict(type=float, default=0.52, help="root scale factor for diffusion term.")),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser
|
146 |
+
|
147 |
+
def __init__(self, T_sampling, k, theta, N=1000, **kwargs):
|
148 |
+
"""Construct an Brownian Bridge with Exploding Diffusion Coefficient SDE with parameterization as in the paper.
|
149 |
+
dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw
|
150 |
+
"""
|
151 |
+
super().__init__(N)
|
152 |
+
self.k = k
|
153 |
+
self.logk = np.log(self.k)
|
154 |
+
self.theta = theta
|
155 |
+
self.N = N
|
156 |
+
self.Eilog = sc.expi(-2*self.logk)
|
157 |
+
self.T = T_sampling #for sampling in train step and inference
|
158 |
+
self.Tc = 1 #for constructing the SDE, dont change this
|
159 |
+
|
160 |
+
|
161 |
+
def copy(self):
|
162 |
+
return BBED(self.T, self.k, self.theta, N=self.N)
|
163 |
+
|
164 |
+
|
165 |
+
def T(self):
|
166 |
+
return self.T
|
167 |
+
|
168 |
+
def Tc(self):
|
169 |
+
return self.Tc
|
170 |
+
|
171 |
+
|
172 |
+
def sde(self, x, t, y):
|
173 |
+
drift = (y - x)/(self.Tc - t)
|
174 |
+
sigma = (self.k) ** t
|
175 |
+
diffusion = sigma * np.sqrt(self.theta)
|
176 |
+
return drift, diffusion
|
177 |
+
|
178 |
+
|
179 |
+
def _mean(self, x0, t, y):
|
180 |
+
time = (t/self.Tc)[:, None, None, None]
|
181 |
+
mean = x0*(1-time) + y*time
|
182 |
+
return mean
|
183 |
+
|
184 |
+
def _std(self, t):
|
185 |
+
t_np = t.cpu().detach().numpy()
|
186 |
+
Eis = sc.expi(2*(t_np-1)*self.logk) - self.Eilog
|
187 |
+
h = 2*self.k**2*self.logk
|
188 |
+
var = (self.k**(2*t_np)-1+t_np) + h*(1-t_np)*Eis
|
189 |
+
var = torch.tensor(var).to(device=t.device)*(1-t)*self.theta
|
190 |
+
return torch.sqrt(var)
|
191 |
+
|
192 |
+
def marginal_prob(self, x0, t, y):
|
193 |
+
return self._mean(x0, t, y), self._std(t)
|
194 |
+
|
195 |
+
def prior_sampling(self, shape, y):
|
196 |
+
if shape != y.shape:
|
197 |
+
warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
|
198 |
+
std = self._std(self.T*torch.ones((y.shape[0],), device=y.device))
|
199 |
+
z = torch.randn_like(y)
|
200 |
+
x_T = y + z * std[:, None, None, None]
|
201 |
+
return x_T, z
|
202 |
+
|
203 |
+
def prior_logp(self, z):
|
204 |
+
raise NotImplementedError("prior_logp for BBED not yet implemented!")
|
205 |
+
|
geco/util/inference.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from pesq import pesq
|
5 |
+
from pystoi import stoi
|
6 |
+
|
7 |
+
from .other import si_sdr, pad_spec
|
8 |
+
|
9 |
+
# Settings
|
10 |
+
sr = 8000
|
11 |
+
snr = 0.5
|
12 |
+
N = 30
|
13 |
+
corrector_steps = 1
|
14 |
+
|
15 |
+
|
16 |
+
def evaluate_model(model, num_eval_files):
    """Enhance a subset of the validation set and average PESQ, SI-SDR and ESTOI.

    ``num_eval_files`` files are picked at evenly spaced indices of the
    validation file lists. Each noisy/mixture pair is scaled by the noisy
    signal's peak, transformed with the model's STFT pipeline, passed through
    the sampler returned by ``model.get_pc_sampler('reverse_diffusion', 'ald',
    ...)`` and converted back to audio, which is scored against the clean
    reference.

    Uses the module-level settings ``sr``, ``snr``, ``N`` and
    ``corrector_steps``. Inputs are moved to the GPU with ``.cuda()``.

    Args:
        model: object exposing ``data_module.valid_set`` (with ``clean_files``,
            ``noisy_files``, ``mixture_files``), ``_stft``,
            ``_forward_transform``, ``get_pc_sampler`` and ``to_audio``.
        num_eval_files: number of files to evaluate; also the divisor of the
            returned averages.

    Returns:
        Tuple ``(pesq, si_sdr, estoi)``, each the sum over the evaluated files
        divided by ``num_eval_files``. PESQ is computed in narrow-band mode.
    """
    valid_set = model.data_module.valid_set

    # Select test files uniformly across validation files
    total_num_files = len(valid_set.clean_files)
    indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
    clean_files = [valid_set.clean_files[i] for i in indices]
    noisy_files = [valid_set.noisy_files[i] for i in indices]
    mixture_files = [valid_set.mixture_files[i] for i in indices]

    def _load_at_rate(path):
        """Load a wav file and resample it to the module-level rate ``sr`` if needed."""
        wav, sr_ = torchaudio.load(path)
        if sr_ != sr:
            wav = torchaudio.transforms.Resample(sr_, sr)(wav)
        return wav

    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    # iterate over files
    for clean_file, noisy_file, mixture_file in zip(clean_files, noisy_files, mixture_files):
        # Load wavs
        x = _load_at_rate(clean_file)
        y = _load_at_rate(noisy_file)
        m = _load_at_rate(mixture_file)

        # Trim all three signals to their common length
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_leng]
        y = y[..., :min_leng]
        m = m[..., :min_leng]

        T_orig = x.size(1)

        # Normalize per utterance (by the peak of the noisy signal)
        norm_factor = y.abs().max()
        y = y / norm_factor
        m = m / norm_factor

        # Prepare DNN input
        Y = pad_spec(torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0))
        M = pad_spec(torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0))

        # Reverse sampling
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), M.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        # Back to the time domain, undoing the per-utterance normalization
        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor

        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'nb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
|
88 |
+
|
89 |
+
|
90 |
+
def evaluate_model2(model, num_eval_files, inference_N, inference_start=0.5):
    """Evaluate with a hand-written Euler-Maruyama reverse process.

    For each selected validation file the reverse process starts at time
    ``inference_start`` from ``Y + z * std(inference_start)`` and takes
    ``inference_N`` steps on a linear time grid ending at 0.03; the last step
    is noise-free and uses a step size equal to the final grid time. The
    result is converted to audio and scored against the clean reference.

    Uses the module-level sampling rate ``sr``. Inputs are moved to the GPU
    with ``.cuda()``.

    Args:
        model: object exposing ``data_module.valid_set``, ``_stft``,
            ``_forward_transform``, ``sde`` (with ``_std`` and ``sde``),
            ``forward`` and ``to_audio``.
        num_eval_files: number of files to evaluate; also the divisor of the
            returned averages.
        inference_N: number of reverse steps.
        inference_start: time at which the reverse process starts.

    Returns:
        Tuple ``(pesq, si_sdr, estoi, loss)``; the first three are sums over
        the evaluated files divided by ``num_eval_files``. ``loss`` is never
        accumulated, so it is always zero; it is kept for interface
        compatibility.
    """
    N = inference_N
    reverse_start_time = inference_start

    valid_set = model.data_module.valid_set

    # Select test files uniformly across validation files
    total_num_files = len(valid_set.clean_files)
    indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
    clean_files = [valid_set.clean_files[i] for i in indices]
    noisy_files = [valid_set.noisy_files[i] for i in indices]
    mixture_files = [valid_set.mixture_files[i] for i in indices]

    def _load_at_rate(path):
        """Load a wav file and resample it to the module-level rate ``sr`` if needed."""
        wav, sr_ = torchaudio.load(path)
        if sr_ != sr:
            wav = torchaudio.transforms.Resample(sr_, sr)(wav)
        return wav

    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    total_loss = 0  # placeholder: nothing is accumulated into it
    # iterate over files
    for clean_file, noisy_file, mixture_file in zip(clean_files, noisy_files, mixture_files):
        # Load wavs
        x = _load_at_rate(clean_file)
        y = _load_at_rate(noisy_file)
        m = _load_at_rate(mixture_file)

        # required only for BWE as the dataset has different lengths of clean and noisy files
        min_leng = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_leng]
        y = y[..., :min_leng]
        m = m[..., :min_leng]

        T_orig = x.size(1)

        # Normalize per utterance (by the peak of the noisy signal). The clean
        # reference is left untouched: it is only used for scoring.
        norm_factor = y.abs().max()
        y = y / norm_factor
        m = m / norm_factor

        # Prepare DNN input
        Y = pad_spec(torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0))
        M = pad_spec(torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0))

        x = x.squeeze().cpu().numpy()

        timesteps = torch.linspace(reverse_start_time, 0.03, N, device=Y.device)
        # prior sampling starting from reverse_start_time
        std = model.sde._std(reverse_start_time*torch.ones((Y.shape[0],), device=Y.device))
        z = torch.randn_like(Y)
        X_t = Y + z * std[:, None, None, None]

        # reverse steps by Euler-Maruyama
        last_step = len(timesteps) - 1
        with torch.no_grad():
            for i, t in enumerate(timesteps):
                # The final step goes from the last grid time down to 0.
                dt = t - timesteps[i+1] if i != last_step else timesteps[-1]

                # take Euler step here
                f, g = model.sde.sde(X_t, t, Y)
                vec_t = torch.ones(Y.shape[0], device=Y.device) * t
                score = model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
                mean_x_tm1 = X_t - (f - g**2*score)*dt  # mean of x_{t-1}
                if i == last_step:  # output: no noise is added on the last step
                    X_t = mean_x_tm1
                    break
                # Noise with Y's shape/dtype/device.
                # NOTE(review): assumes the clean and noisy files have the same
                # channel count, so Y matches the clean spectrogram's shape.
                z = torch.randn_like(Y)
                X_t = mean_x_tm1 + z*g*torch.sqrt(dt)

        # Back to the time domain, undoing the per-utterance normalization
        x_hat = model.to_audio(X_t.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()
        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'nb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files, total_loss/num_eval_files
|
193 |
+
|
194 |
+
|
195 |
+
def convert_to_audio(X, deemp, T_orig, model, norm_factor):
    """Apply a de-emphasis weighting to ``X`` and convert it to a NumPy waveform.

    ``X`` is squeezed, multiplied by ``deemp`` placed on its second-to-last
    axis, passed through ``model.to_audio(..., T_orig)``, scaled by
    ``norm_factor`` and returned as a squeezed CPU NumPy array.

    Args:
        X: tensor to convert (singleton axes are removed first).
        deemp: weighting tensor; indexed as a 1-D tensor.
        T_orig: length argument forwarded to ``model.to_audio``.
        model: object providing ``to_audio(spec, length)``.
        norm_factor: scale applied to the resulting audio.
    """
    spec = X.squeeze()

    # Pick the index expression that puts `deemp` on the second-to-last axis
    # for the rank that is left after squeezing.
    rank = len(spec.shape)
    if rank == 4:
        placement = (None, None, slice(None), None)
    elif rank == 3:
        placement = (None, slice(None), None)
    else:
        placement = (slice(None), None)
    spec = spec * deemp[placement].to(device=spec.device)

    audio = model.to_audio(spec.squeeze(), T_orig)
    audio = audio * norm_factor
    return audio.squeeze().cpu().numpy()
|
geco/util/other.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import scipy.stats
|
6 |
+
from scipy.signal import butter, sosfilt
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from pesq import pesq
|
11 |
+
from pystoi import stoi
|
12 |
+
|
13 |
+
|
14 |
+
def si_sdr_components(s_hat, s, n):
    """Split an estimate into target, noise and artifact components.

    Args:
        s_hat: estimated signal (1-D array).
        s: reference target signal.
        n: reference noise signal.

    Returns:
        Tuple ``(s_target, e_noise, e_art)``: the scaled projections of
        ``s_hat`` onto ``s`` and onto ``n``, and the residual of ``s_hat``
        after subtracting both.
    """
    def _scaled_projection(reference):
        # <s_hat, ref> / ||ref||^2 * ref
        return np.dot(s_hat, reference) / np.linalg.norm(reference)**2 * reference

    s_target = _scaled_projection(s)
    e_noise = _scaled_projection(n)
    e_art = s_hat - s_target - e_noise
    return s_target, e_noise, e_art
|
29 |
+
|
30 |
+
def energy_ratios(s_hat, s, n):
    """Compute SI-SDR, SI-SIR and SI-SAR (in dB) of an estimate.

    The estimate is decomposed with ``si_sdr_components``; each ratio is the
    target energy over the energy of, respectively, noise plus artifacts,
    noise only, and artifacts only.

    Returns:
        Tuple ``(si_sdr, si_sir, si_sar)``.
    """
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    def _target_to_residual_db(residual):
        return 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(residual)**2)

    sdr_db = _target_to_residual_db(e_noise + e_art)
    sir_db = _target_to_residual_db(e_noise)
    sar_db = _target_to_residual_db(e_art)
    return sdr_db, sir_db, sar_db
|
40 |
+
|
41 |
+
def mean_conf_int(data, confidence=0.95):
    """Return the mean of ``data`` and the half-width of its confidence interval.

    The half-width is the standard error of the mean times the Student-t
    quantile at ``(1 + confidence) / 2`` with ``len(data) - 1`` degrees of
    freedom.

    Returns:
        Tuple ``(mean, half_width)``.
    """
    samples = 1.0 * np.array(data)
    dof = len(samples) - 1
    center = np.mean(samples)
    std_err = scipy.stats.sem(samples)
    half_width = std_err * scipy.stats.t.ppf((1 + confidence) / 2., dof)
    return center, half_width
|
47 |
+
|
48 |
+
class Method():
    """Container collecting per-metric value lists for one named method."""

    def __init__(self, name, base_dir, metrics):
        """Store ``name`` and ``base_dir`` and start an empty list per metric.

        Args:
            name: label of the method.
            base_dir: directory associated with the method (only stored here).
            metrics: sequence of metric names used as dictionary keys.
        """
        self.name = name
        self.base_dir = base_dir
        # One independent list per metric name.
        self.metrics = {metric: [] for metric in metrics}

    def append(self, matric, value):
        """Append ``value`` to the list stored for metric ``matric``."""
        self.metrics[matric].append(value)

    def get_mean_ci(self, metric):
        """Return ``mean_conf_int`` of the values recorded for ``metric``."""
        recorded = np.array(self.metrics[metric])
        return mean_conf_int(recorded)
|
64 |
+
|
65 |
+
def hp_filter(signal, cut_off=80, order=10, sr=16000):
    """High-pass filter ``signal`` with a Butterworth filter.

    Args:
        signal: input samples.
        cut_off: cut-off frequency, in the same unit as ``sr``.
        order: filter order passed to ``scipy.signal.butter``.
        sr: sampling rate used to normalize the cut-off.

    Returns:
        The filtered signal as returned by ``scipy.signal.sosfilt``.
    """
    # Critical frequency expressed as a fraction of the Nyquist rate (sr / 2).
    normalized_cut_off = cut_off /sr * 2
    sections = butter(order, normalized_cut_off, 'hp', output='sos')
    return sosfilt(sections, signal)
|
70 |
+
|
71 |
+
def si_sdr(s, s_hat):
    """Scale-invariant signal-to-distortion ratio of ``s_hat`` w.r.t. ``s``, in dB.

    The reference ``s`` is rescaled by the projection coefficient of
    ``s_hat`` onto it; the result is the energy of that scaled reference over
    the energy of its difference to ``s_hat``.
    """
    scale = np.dot(s_hat, s)/np.linalg.norm(s)**2
    scaled_reference = scale*s
    target_energy = np.linalg.norm(scaled_reference)**2
    distortion_energy = np.linalg.norm(scaled_reference - s_hat)**2
    return 10*np.log10(target_energy/distortion_energy)
|
76 |
+
|
77 |
+
def snr_dB(s,n):
    """Signal-to-noise ratio in dB between signal ``s`` and noise ``n``.

    Each power is the mean of the squared samples; the result is
    ``10 * log10(signal_power / noise_power)``.
    """
    signal_power = 1/len(s)*np.sum(s**2)
    noise_power = 1/len(n)*np.sum(n**2)
    return 10*np.log10(signal_power/noise_power)
|
82 |
+
|
83 |
+
def pad_spec(Y):
    """Zero-pad the last axis of a 4-D tensor up to the next multiple of 64.

    Only the end of axis 3 is padded; a length that is already a multiple of
    64 is left unchanged.
    """
    frames = Y.size(3)
    # Number of frames missing to reach the next multiple of 64 (0 if aligned).
    num_pad = (-frames) % 64
    # Padding tuple is (left, right, top, bottom) over the last two axes.
    return torch.nn.functional.pad(Y, (0, num_pad, 0, 0), mode='constant', value=0.0)
|
91 |
+
|
92 |
+
|
93 |
+
def ensure_dir(file_path):
    """Create directory ``file_path`` (including parents) if nothing exists there.

    Despite the parameter name, the path itself is treated as the directory
    to create. If any filesystem entry already exists at the path, nothing is
    done.

    ``exist_ok=True`` closes the race where another process creates the
    directory between the existence check and the ``makedirs`` call.
    """
    if not os.path.exists(file_path):
        os.makedirs(file_path, exist_ok=True)
|
97 |
+
|
98 |
+
|
99 |
+
def print_metrics(x, y, x_hat_list, labels, sr=16000):
    """Print PESQ, ESTOI and SI-SDR for the mixture and for each estimate.

    Args:
        x: clean reference signal.
        y: unprocessed mixture, scored against ``x`` first.
        x_hat_list: estimates to score against ``x``.
        labels: one label per entry of ``x_hat_list``, used as line prefix.
        sr: sampling rate passed to PESQ (wide-band mode) and STOI.
    """
    _si_sdr_mix = si_sdr(x, y)
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        # Same layout as the mixture line: the PESQ value carries its label.
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')
|
109 |
+
|
110 |
+
def mean_std(data):
    """Return ``(mean, std)`` of ``data`` with NaN entries ignored.

    ``data`` may be a NumPy array or any sequence convertible to one; it is
    converted first so that boolean-mask indexing also works for plain lists.
    The standard deviation is the population one (NumPy's default).
    """
    data = np.asarray(data)
    data = data[~np.isnan(data)]
    return np.mean(data), np.std(data)
|
115 |
+
|
116 |
+
def print_mean_std(data, decimal=2):
    """Format the NaN-ignoring mean and standard deviation as ``'mean ± std'``.

    Args:
        data: values to summarize; converted to a NumPy array.
        decimal: number of digits after the decimal point. Any non-negative
            integer is accepted (previously only 1 and 2 produced a result).

    Returns:
        The formatted string, e.g. ``'2.00 ± 1.00'``.
    """
    data = np.array(data)
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    places = int(decimal)
    return f'{mean:.{places}f} ± {std:.{places}f}'
|
geco/util/registry.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
|
5 |
+
class Registry:
    """Lookup table mapping names to classes, filled through a decorator."""

    def __init__(self, managed_thing: str):
        """
        Create a new registry.

        Args:
            managed_thing: A string describing what type of thing is managed by this registry. Will be used for
                warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
        """
        self._registry = {}
        self.managed_thing = managed_thing

    def register(self, name: str) -> Callable:
        """Return a class decorator that stores the decorated class under ``name``.

        Registering a name twice emits a warning and replaces the old entry.
        The decorated class itself is returned unchanged.
        """
        def inner_wrapper(wrapped_class) -> Callable:
            already_registered = name in self._registry
            if already_registered:
                warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
            self._registry[name] = wrapped_class
            return wrapped_class

        return inner_wrapper

    def get_by_name(self, name: str):
        """Get a managed thing by name; raise ``ValueError`` if it is unknown."""
        if name not in self._registry:
            raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")
        return self._registry[name]

    def get_all_names(self):
        """Get the list of things' names registered to this registry."""
        return [registered_name for registered_name in self._registry]
|
geco/util/tensors.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def batch_broadcast(a, x):
    """Broadcasts a over all dimensions of x, except the batch dimension, which must match.

    Args:
        a: tensor with one effective dimension (singleton axes are squeezed
            away). Its length must equal ``x.shape[0]`` or be 1.
        x: tensor whose number of dimensions determines the output rank.

    Returns:
        A view of ``a`` with shape ``(len(a), 1, ..., 1)`` and as many
        dimensions as ``x``. A length-1 ``a`` keeps a leading size of 1, which
        broadcasts against any batch size of ``x``.

    Raises:
        ValueError: if ``a`` has more than one non-singleton dimension, or its
            length is neither 1 nor ``x.shape[0]``.
    """
    if len(a.shape) != 1:
        a = a.squeeze()
        if len(a.shape) == 0:
            # A single-element tensor squeezes down to a scalar; keep it 1-D.
            a = a.reshape(1)
        if len(a.shape) != 1:
            raise ValueError(
                f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
            )

    if a.shape[0] != x.shape[0] and a.shape[0] != 1:
        raise ValueError(
            f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")

    # Use a's own length as the leading size: viewing a length-1 tensor as
    # (x.shape[0], 1, ...) would fail whenever the batch is larger than 1.
    out = a.view((a.shape[0], *(1 for _ in range(len(x.shape)-1))))
    return out
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pytorch-lightning==1.5.10
|
2 |
+
torch==2.4.0
|
3 |
+
torch-ema==0.3
|
4 |
+
torch-optimizer==0.3.0
|
5 |
+
torch-stoi==0.1.2
|
6 |
+
torchaudio==2.4.0
|
7 |
+
torchinfo==1.6.3
|
8 |
+
torchmetrics==0.9.3
|
9 |
+
torchsde==0.2.5
|
10 |
+
torchvision==0.19.0
|
11 |
+
tornado==6.2
|
12 |
+
tqdm==4.63.0
|
13 |
+
ninja
|
14 |
+
matplotlib
|
15 |
+
pesq
|
16 |
+
wandb
|
17 |
+
PySoundFile
|
18 |
+
pandas
|
19 |
+
git+https://github.com/WangHelin1997/Fast-GeCo.git#subdirectory=score_models
|
20 |
+
speechbrain==1.0.0
|