Matthijs Hollemans committed
Commit 0ecd9fb
1 Parent(s): e2a288e
.gitattributes CHANGED
@@ -28,6 +28,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
+*.pyc
+__pycache__/
+.DS_Store
+
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
app.py ADDED
@@ -0,0 +1,261 @@
+# Hacked together using the code from https://github.com/nikhilsinghmus/image2reverb
+
+import os, types
+import numpy as np
+import gradio as gr
+import soundfile as sf
+import scipy.signal  # needed so scipy.signal.oaconvolve/resample_poly resolve below
+import librosa.display
+from PIL import Image
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+import torch
+from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+from pytorch_lightning import Trainer
+
+from image2reverb.model import Image2Reverb
+from image2reverb.stft import STFT
+
+
+predicted_ir = None
+predicted_spectrogram = None
+predicted_depthmap = None
+
+
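+# The two functions below are patched onto the Image2Reverb instance at runtime
+# (via types.MethodType further down) so that the predicted IR, spectrogram,
+# and depthmap can be captured into the globals above during trainer.test().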
+def test_step(self, batch, batch_idx):
+    spec, label, paths = batch
+    examples = [os.path.splitext(os.path.basename(s))[0] for _, s in zip(*paths)]
+
+    f, img = self.enc.forward(label)
+
+    shape = (
+        f.shape[0],
+        (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1],
+        f.shape[2],
+        f.shape[3]
+    )
+    z = torch.cat((f, torch.randn(shape, device=self.device)), 1)
+
+    fake_spec = self.g(z)
+
+    stft = STFT()
+    y_f = [stft.inverse(s.squeeze()) for s in fake_spec]
+
+    # TODO: bit hacky
+    global predicted_ir, predicted_spectrogram, predicted_depthmap
+    predicted_ir = y_f[0]
+
+    # Undo the [-1, 1] normalization (same empirical constants as STFT.inverse)
+    s = fake_spec.squeeze().cpu().numpy()
+    predicted_spectrogram = np.exp((((s + 1) * 0.5) * 19.5) - 17.5) - 1e-8
+
+    img = (img + 1) * 0.5
+    predicted_depthmap = img.cpu().squeeze().permute(1, 2, 0)[:, :, -1].squeeze().numpy()
+
+    return {"test_audio": y_f, "test_examples": examples}
+
+
+def test_epoch_end(self, outputs):
+    if not self.test_callback:
+        return
+
+    examples = []
+    audio = []
+
+    for output in outputs:
+        for i in range(len(output["test_examples"])):
+            audio.append(output["test_audio"][i])
+            examples.append(output["test_examples"][i])
+
+    self.test_callback(examples, audio)
+
+
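+# Model configuration: the paths below point at the LFS-tracked checkpoints
+# added in this same commit.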
+checkpoint_path = "./checkpoints/image2reverb_f22.ckpt"
+encoder_path = None
+depthmodel_path = "./checkpoints/mono_odom_640x192"
+constant_depth = None
+latent_dimension = 512
+
+model = Image2Reverb(encoder_path, depthmodel_path)
+m = torch.load(checkpoint_path, map_location=model.device)
+model.load_state_dict(m["state_dict"])
+
+model.test_step = types.MethodType(test_step, model)
+model.test_epoch_end = types.MethodType(test_epoch_end, model)
+
+image_transforms = transforms.Compose([
+    transforms.Resize([224, 224], transforms.functional.InterpolationMode.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+])
+
+
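+# The dataset below wraps the single uploaded image; the zero tensor stands in
+# for the ground-truth impulse response (~5.94 s at 22050 Hz) that a training
+# example would normally carry.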
+class Image2ReverbDemoDataset(Dataset):
+    def __init__(self, image):
+        self.image = Image.fromarray(image)
+        self.stft = STFT()
+
+    def __getitem__(self, index):
+        img_tensor = image_transforms(self.image.convert("RGB"))
+        return torch.zeros(1, int(5.94 * 22050)), img_tensor, ("", "")
+
+    def __len__(self):
+        return 1
+
+    def name(self):
+        return "Image2ReverbDemo"
+
+
+def convolve(audio, reverb):
+    # convolve audio with reverb; zero-pad first so the reverb tail fits
+    wet_audio = np.concatenate((audio, np.zeros(reverb.shape)))
+    wet_audio = scipy.signal.oaconvolve(wet_audio, reverb, "full")[:len(wet_audio)]
+
+    # normalize audio to roughly -1 dB peak and remove DC offset
+    wet_audio /= np.max(np.abs(wet_audio))
+    wet_audio -= np.mean(wet_audio)
+    wet_audio *= 0.9
+    return wet_audio
+
+
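+# For example (hypothetical sizes): 1 s of audio convolved with a 2 s IR comes
+# back as len(audio) + len(reverb) samples, i.e. including the full tail.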
+def predict(image, audio):
+    # image = numpy (height, width, channels)
+    # audio = tuple (sample_rate, frames) or (sample_rate, (frames, channels))
+
+    test_set = Image2ReverbDemoDataset(image)
+    test_loader = torch.utils.data.DataLoader(test_set, num_workers=0, batch_size=1)
+    trainer = Trainer(limit_test_batches=1)
+    trainer.test(model, test_loader, verbose=True)
+
+    # depthmap output
+    depthmap_fig = plt.figure()
+    plt.imshow(predicted_depthmap)
+    plt.close()
+
+    # spectrogram output
+    spectrogram_fig = plt.figure()
+    librosa.display.specshow(predicted_spectrogram, sr=22050, x_axis="time", y_axis="hz")
+    plt.close()
+
+    # plot the IR as a waveform
+    waveform_fig = plt.figure()
+    librosa.display.waveshow(predicted_ir, sr=22050, alpha=0.5)
+    plt.close()
+
+    # output audio as 16-bit signed integer
+    ir = (22050, (predicted_ir * 32767).astype(np.int16))
+
+    sample_rate, original_audio = audio
+
+    # incoming audio is 16-bit signed integer, convert to float and normalize
+    original_audio = original_audio.astype(np.float32) / 32768.0
+    original_audio /= np.max(np.abs(original_audio))
+
+    # resample reverb to sample_rate first, also normalize
+    reverb = predicted_ir.copy()
+    reverb = scipy.signal.resample_poly(reverb, up=sample_rate, down=22050)
+    reverb /= np.max(np.abs(reverb))
+
+    # stereo?
+    if len(original_audio.shape) > 1:
+        wet_left = convolve(original_audio[:, 0], reverb)
+        wet_right = convolve(original_audio[:, 1], reverb)
+        wet_audio = np.concatenate([wet_left[:, None], wet_right[:, None]], axis=1)
+    else:
+        wet_audio = convolve(original_audio, reverb)
+
+    # 50% dry-wet mix
+    mixed_audio = wet_audio * 0.5
+    mixed_audio[:len(original_audio), ...] += original_audio * 0.9 * 0.5
+
+    # convert back to 16-bit signed integer
+    wet_audio = (wet_audio * 32767).astype(np.int16)
+    mixed_audio = (mixed_audio * 32767).astype(np.int16)
+
+    convolved_audio_100 = (sample_rate, wet_audio)
+    convolved_audio_50 = (sample_rate, mixed_audio)
+
+    return depthmap_fig, spectrogram_fig, waveform_fig, ir, convolved_audio_100, convolved_audio_50
+
+
+title = "Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis"
+
+description = """
+<b>Image2Reverb</b> predicts the acoustic reverberation of a given environment from a 2D image. <a href="https://arxiv.org/abs/2103.14201">Read the paper</a>
+
+How to use: Choose an image of a room or other environment and an audio file.
+The model predicts what the reverb of that room sounds like and applies it to the audio file.
+
+First, the image is resized to 224×224. The monodepth model is used to predict a depthmap, which is added as an
+additional channel to the image input. A ResNet-based encoder then converts the image into features, and
+finally a GAN predicts the spectrogram of the reverb's impulse response.
+
+<center><img src="file/model.jpg" width="870" height="297" alt="model architecture"></center>
+
+The predicted impulse response is mono at 22050 Hz. It is upsampled to the sampling rate of the audio
+file and applied to both channels if the audio is stereo.
+Generating the impulse response involves a certain amount of randomness, making it sound a little
+different every time you try it.
+"""
+
+article = """
+<div style='margin:20px auto;'>
+
+<p>Based on original work by Nikhil Singh, Jeff Mentch, Jerry Ng, Matthew Beveridge, Iddo Drori.
+<a href="https://web.media.mit.edu/~nsingh1/image2reverb/">Project Page</a> |
+<a href="https://arxiv.org/abs/2103.14201">Paper</a> |
+<a href="https://github.com/nikhilsinghmus/image2reverb">GitHub</a></p>
+
+<pre>
+@InProceedings{Singh_2021_ICCV,
+    author    = {Singh, Nikhil and Mentch, Jeff and Ng, Jerry and Beveridge, Matthew and Drori, Iddo},
+    title     = {Image2Reverb: Cross-Modal Reverb Impulse Response Synthesis},
+    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+    month     = {October},
+    year      = {2021},
+    pages     = {286-295}
+}
+</pre>
+
+<p>🌠 Example images from <a href="https://web.media.mit.edu/~nsingh1/image2reverb/">the original project page</a>.</p>
+
+<p>🎶 Example sound from <a href="https://freesound.org/people/ashesanddreams/sounds/610414/">Ashes and Dreams @ freesound.org</a> (CC BY 4.0 license). This is a mono 48 kHz recording that has no reverb on it.</p>
+
+</div>
+"""
+
+audio_example = "examples/ashesanddreams.wav"
+
+examples = [
+    ["examples/input.4e2f71f6.png", audio_example],
+    ["examples/input.321eef38.png", audio_example],
+    ["examples/input.2238dc21.png", audio_example],
+    ["examples/input.4d280b40.png", audio_example],
+    ["examples/input.0c3f5013.png", audio_example],
+    ["examples/input.98773b90.png", audio_example],
+    ["examples/input.ac61500f.png", audio_example],
+    ["examples/input.5416407f.png", audio_example],
+]
+
+gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.inputs.Image(label="Upload Image"),
+        gr.inputs.Audio(label="Upload Audio", source="upload"),
+    ],
+    outputs=[
+        gr.Plot(label="Depthmap"),
+        gr.Plot(label="Impulse Response Spectrogram"),
+        gr.Plot(label="Impulse Response Waveform"),
+        gr.outputs.Audio(label="Impulse Response"),
+        gr.outputs.Audio(label="Output Audio (100% Wet)"),
+        gr.outputs.Audio(label="Output Audio (50% Dry, 50% Wet)"),
+    ],
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
+).launch()
checkpoints/image2reverb_f22.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d61422e95dc963e258b68536dc8135633a999c3a85a5a80925878ff75ca092e3
+size 687498725
checkpoints/mono_odom_640x192/depth.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a2f542e274a5b0567e3118bc16aea4c2f44ba09df4a08a6c3a47d6d98285b72
+size 12617260
checkpoints/mono_odom_640x192/encoder.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:acbf2534608f06be40eecd5026c505ebd0c1d9442fe5864abba1b5d90bff2e3e
+size 46819013
checkpoints/mono_odom_640x192/pose.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4da0fe66fc1f781a05d8c4778f33ffa1851c219cb7fd561328479f5b439707e
+size 5259718
checkpoints/mono_odom_640x192/pose_encoder.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df8659ecf4363335c13ffc4510ff34556715c7f6435707622c3641a7fe055eb2
+size 46856589
checkpoints/mono_odom_640x192/poses.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71a413ff381d4a58345e9152e0ca8d0b45a71e550df7730633a8cf7693edcced
+size 76928
examples/input.0c3f5013.png ADDED
examples/input.2238dc21.png ADDED
examples/input.321eef38.png ADDED
examples/input.4d280b40.png ADDED
examples/input.4e2f71f6.png ADDED
examples/input.5416407f.png ADDED
examples/input.67bc502e.png ADDED
examples/input.98773b90.png ADDED
examples/input.ac61500f.png ADDED
examples/input.c9ee9d49.png ADDED
image2reverb/dataset.py ADDED
@@ -0,0 +1,96 @@
+import os
+import soundfile
+import torch
+import torchvision.transforms as transforms
+from torch.utils.data import Dataset
+from PIL import Image
+from .stft import STFT
+from .mel import LogMel
+
+
+F_EXTENSIONS = [
+    ".jpg", ".JPG", ".jpeg", ".JPEG",
+    ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ".tiff", ".wav", ".WAV", ".aif", ".aiff", ".AIF", ".AIFF"
+]
+
+
+def is_image_audio_file(filename):
+    return any(filename.endswith(extension) for extension in F_EXTENSIONS)
+
+
+def make_dataset(dir, extensions=F_EXTENSIONS):
+    images = []
+    assert os.path.isdir(dir), "%s is not a valid directory." % dir
+
+    for root, _, fnames in sorted(os.walk(dir)):
+        for fname in fnames:
+            # honor the extensions argument (callers pass image-only lists)
+            if any(fname.endswith(extension) for extension in extensions):
+                path = os.path.join(root, fname)
+                images.append(path)
+
+    return images
+
+
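+# Expected layout (inferred from the paths below): <dataroot>/<phase>_A holds
+# the input images and <dataroot>/<phase>_B the paired impulse responses.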
+class Image2ReverbDataset(Dataset):
+    def __init__(self, dataroot, phase="train", spec="stft"):
+        self.root = dataroot
+        self.stft = LogMel() if spec == "mel" else STFT()
+
+        ### input A (images)
+        dir_A = "_A"
+        self.dir_A = os.path.join(self.root, phase + dir_A)
+        self.A_paths = sorted(make_dataset(self.dir_A))
+
+        ### input B (audio)
+        dir_B = "_B"
+        self.dir_B = os.path.join(self.root, phase + dir_B)
+        self.B_paths = sorted(make_dataset(self.dir_B))
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            return None
+        ### input A (images)
+        A_path = self.A_paths[index]
+        A = Image.open(A_path)
+        t = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        A_tensor = t(A.convert("RGB"))
+
+        ### input B (audio)
+        B_path = self.B_paths[index]
+        B, _ = soundfile.read(B_path)
+        B_spec = self.stft.transform(B)
+
+        return B_spec, A_tensor, (B_path, A_path)
+
+    def __len__(self):
+        return len(self.A_paths)
+
+    def name(self):
+        return "Image2Reverb"
+
+
+class Image2ReverbDemoDataset(Dataset):
+    def __init__(self, image_paths):
+        if isinstance(image_paths, str) and os.path.isdir(image_paths):
+            self.paths = sorted(make_dataset(image_paths, [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ".tiff"]))
+        else:
+            self.paths = sorted(image_paths)
+
+        self.stft = STFT()
+
+    def __getitem__(self, index):
+        if index >= len(self):
+            return None
+        ### input A (images)
+        path = self.paths[index]
+        img = Image.open(path)
+        t = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+        img_tensor = t(img.convert("RGB"))
+
+        return torch.zeros(1, int(5.94 * 22050)), img_tensor, ("", path)
+
+    def __len__(self):
+        return len(self.paths)
+
+    def name(self):
+        return "Image2ReverbDemo"
image2reverb/layers.py ADDED
@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import kaiming_normal_, calculate_gain
+
+
+class PixelWiseNormLayer(nn.Module):
+    """PixelNorm layer. Implementation is from https://github.com/shanexn/pytorch-pggan."""
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+class MiniBatchAverageLayer(nn.Module):
+    """Minibatch stat concatenation layer. Implementation is from https://github.com/shanexn/pytorch-pggan."""
+    def __init__(self, offset=1e-8):
+        super().__init__()
+        self.offset = offset
+
+    def forward(self, x):
+        stddev = torch.sqrt(torch.mean((x - torch.mean(x, dim=0, keepdim=True)) ** 2, dim=0, keepdim=True) + self.offset)
+        inject_shape = list(x.size())[:]
+        inject_shape[1] = 1
+        inject = torch.mean(stddev, dim=1, keepdim=True)
+        inject = inject.expand(inject_shape)
+        return torch.cat((x, inject), dim=1)
+
+
+class EqualizedLearningRateLayer(nn.Module):
+    """Applies equalized learning rate to the preceding layer. Implementation is from https://github.com/shanexn/pytorch-pggan."""
+    def __init__(self, layer):
+        super().__init__()
+        self.layer_ = layer
+
+        kaiming_normal_(self.layer_.weight, a=calculate_gain("conv2d"))
+        self.layer_norm_constant_ = (torch.mean(self.layer_.weight.data ** 2)) ** 0.5
+        self.layer_.weight.data.copy_(self.layer_.weight.data / self.layer_norm_constant_)
+
+        # `if tensor:` raises for multi-element tensors, so compare against None
+        self.bias_ = self.layer_.bias if self.layer_.bias is not None else None
+        self.layer_.bias = None
+
+    def forward(self, x):
+        self.layer_norm_constant_ = self.layer_norm_constant_.type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor)
+        x = self.layer_norm_constant_ * x
+        if self.bias_ is not None:
+            x += self.bias_.view(1, self.bias_.size()[0], 1, 1)
+        return x
+
+
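+# With the weights rescaled to unit RMS at init and the constant reapplied in
+# forward, every layer sees an equalized effective learning rate (the PGGAN
+# trick the class above implements).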
+class ConvBlock(nn.Module):
+    """Layer to perform a convolution followed by ELU"""
+    def __init__(self, in_channels, out_channels):
+        super(ConvBlock, self).__init__()
+
+        self.conv = Conv3x3(in_channels, out_channels)
+        self.nonlin = nn.ELU(inplace=True)
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.nonlin(out)
+        return out
+
+
+class Conv3x3(nn.Module):
+    """Layer to pad and convolve input"""
+    def __init__(self, in_channels, out_channels, use_refl=True):
+        super(Conv3x3, self).__init__()
+
+        if use_refl:
+            self.pad = nn.ReflectionPad2d(1)
+        else:
+            self.pad = nn.ZeroPad2d(1)
+        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
+
+    def forward(self, x):
+        out = self.pad(x)
+        out = self.conv(out)
+        return out
+
+
+def upsample(x):
+    """Upsample input tensor by a factor of 2"""
+    return F.interpolate(x, scale_factor=2, mode="nearest")
image2reverb/mel.py ADDED
@@ -0,0 +1,20 @@
+import numpy
+import torch
+import librosa
+
+
+class LogMel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._eps = 1e-8
+
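+    # Note: transform normalizes each example to zero mean/unit std (scaled by
+    # 0.8), while inverse assumes fixed statistics measured over the test set.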
+    def transform(self, audio):
+        m = librosa.feature.melspectrogram(y=audio / numpy.abs(audio).max())
+        m = numpy.log(m + self._eps)
+        return torch.Tensor(((m - m.mean()) / m.std()) * 0.8).unsqueeze(0)
+
+    def inverse(self, spec):
+        s = spec.cpu().detach().numpy()
+        s = numpy.exp((s * 5) - 15.96) - self._eps  # Empirical mean and standard deviation over test set
+        y = librosa.feature.inverse.mel_to_audio(s)  # Reconstruct audio
+        return y / numpy.abs(y).max()
image2reverb/model.py ADDED
@@ -0,0 +1,207 @@
+import os
+import json
+import numpy
+import torch
+from torch import nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+import torchvision
+import pyroomacoustics
+from .networks import Encoder, Generator, Discriminator
+from .stft import STFT
+from .mel import LogMel
+from .util import compare_t60
+
+
+# Hyperparameters
+G_LR = 4e-4
+D_LR = 2e-4
+ENC_LR = 1e-5
+ADAM_BETA = (0.0, 0.99)
+ADAM_EPS = 1e-8
+LAMBDA = 100
+
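+# LAMBDA weights the L1 spectrogram reconstruction term against the
+# adversarial term in the generator loss below.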
+
+class Image2Reverb(pl.LightningModule):
+    def __init__(self, encoder_path, depthmodel_path, latent_dimension=512, spec="stft", d_threshold=0.2, t60p=True, constant_depth=None, test_callback=None):
+        super().__init__()
+        self._latent_dimension = latent_dimension
+        self._d_threshold = d_threshold
+        self.constant_depth = constant_depth
+        self.t60p = t60p
+        self.confidence = {}
+        self.tau = 50
+        self.test_callback = test_callback
+        self._opt = (d_threshold is not None) and (d_threshold > 0) and (d_threshold < 1)
+        self.enc = Encoder(encoder_path, depthmodel_path, constant_depth=self.constant_depth, device=self.device)
+        self.g = Generator(latent_dimension, spec == "mel")
+        self.d = Discriminator(365, spec == "mel")
+        self.validation_inputs = []
+        self.stft_type = spec
+
+    def forward(self, x):
+        f = self.enc.forward(x)[0]
+        noise_channels = (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1]
+        z = torch.cat((f, torch.randn((f.shape[0], noise_channels, f.shape[2], f.shape[3]), device=self.device)), 1)
+        return self.g(z)
+
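+    # The latent z is the encoder feature vector padded with Gaussian noise up
+    # to the latent dimension, which is why repeated runs on the same image
+    # give slightly different impulse responses.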
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        opts = None
+        if self._opt:
+            opts = self.optimizers()
+
+        spec, label, p = batch
+        spec.requires_grad = True  # For the backward pass, seems necessary for now
+
+        # Forward passes through models
+        f = self.enc.forward(label)[0]
+        z = torch.cat((f, torch.randn((f.shape[0], (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1], f.shape[2], f.shape[3]), device=self.device)), 1)
+        fake_spec = self.g(z)
+        d_fake = self.d(fake_spec.detach(), f)
+        d_real = self.d(spec, f)
+
+        # Train Generator or Encoder
+        if optimizer_idx == 0 or optimizer_idx == 1:
+            d_fake2 = self.d(fake_spec.detach(), f)
+            G_loss1 = F.mse_loss(d_fake2, torch.ones(d_fake2.shape, device=self.device))
+            G_loss2 = F.l1_loss(fake_spec, spec)
+
+            G_loss = G_loss1 + (LAMBDA * G_loss2)
+            if self.t60p:
+                t60_err = torch.Tensor([compare_t60(torch.exp(a).sum(-2).squeeze(), torch.exp(b).sum(-2).squeeze()) for a, b in zip(spec, fake_spec)]).to(self.device).mean()
+                G_loss += t60_err
+                self.log("t60", t60_err, on_step=True, on_epoch=True, prog_bar=True)
+
+            if self._opt:
+                self.manual_backward(G_loss, opts[optimizer_idx])
+                opts[optimizer_idx].step()
+                opts[optimizer_idx].zero_grad()
+
+            self.log("G", G_loss, on_step=True, on_epoch=True, prog_bar=True)
+
+            return G_loss
+        else:  # Train Discriminator
+            l_fakeD = F.mse_loss(d_fake, torch.zeros(d_fake.shape, device=self.device))
+            l_realD = F.mse_loss(d_real, torch.ones(d_real.shape, device=self.device))
+            D_loss = (l_realD + l_fakeD)
+
+            # Only update D while its loss is above the threshold, so it does
+            # not overpower the generator
+            if self._opt and (D_loss > self._d_threshold):
+                self.manual_backward(D_loss, opts[optimizer_idx])
+                opts[optimizer_idx].step()
+                opts[optimizer_idx].zero_grad()
+
+            self.log("D", D_loss, on_step=True, on_epoch=True, prog_bar=True)
+
+            return D_loss
+
+    def configure_optimizers(self):
+        g_optim = torch.optim.Adam(self.g.parameters(), lr=G_LR, betas=ADAM_BETA, eps=ADAM_EPS)
+        d_optim = torch.optim.Adam(self.d.parameters(), lr=D_LR, betas=ADAM_BETA, eps=ADAM_EPS)
+        enc_optim = torch.optim.Adam(self.enc.parameters(), lr=ENC_LR, betas=ADAM_BETA, eps=ADAM_EPS)
+        return [enc_optim, g_optim, d_optim], []
+
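+    # The optimizer order above maps optimizer_idx 0/1 in training_step to the
+    # encoder/generator branch and 2 to the discriminator branch.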
+    def validation_step(self, batch, batch_idx):
+        spec, label, paths = batch
+        examples = [os.path.basename(s[:s.rfind("_")]) for s, _ in zip(*paths)]
+
+        # Forward passes through models
+        f = self.enc.forward(label)[0]
+        z = torch.cat((f, torch.randn((f.shape[0], (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1], f.shape[2], f.shape[3]), device=self.device)), 1)
+        fake_spec = self.g(z)
+
+        # Get audio
+        stft = LogMel() if self.stft_type == "mel" else STFT()
+        y_r = [stft.inverse(s.squeeze()) for s in spec]
+        y_f = [stft.inverse(s.squeeze()) for s in fake_spec]
+
+        # RT60 error (in percentages)
+        val_pct = 1
+        try:
+            f = lambda x: pyroomacoustics.experimental.rt60.measure_rt60(x, 22050)
+            t60_r = [f(y) for y in y_r if len(y)]
+            t60_f = [f(y) for y in y_f if len(y)]
+            val_pct = numpy.mean([((t_b - t_a) / t_a) for t_a, t_b in zip(t60_r, t60_f)])
+        except Exception:
+            pass
+
+        return {"val_t60err": val_pct, "val_spec": fake_spec, "val_audio": torch.Tensor(y_f), "val_img": label, "val_examples": examples}
+
+    def validation_epoch_end(self, outputs):
+        if not len(outputs):
+            return
+        # Log mean T60 errors (in percentages)
+        val_t60errmean = torch.Tensor(numpy.array([output["val_t60err"] for output in outputs])).mean()
+        self.log("val_t60err", val_t60errmean, on_epoch=True, prog_bar=True)
+
+        # Log generated spectrogram images
+        grid = torchvision.utils.make_grid([torch.flip(x, [0]) for y in [output["val_spec"] for output in outputs] for x in y])
+        self.logger.experiment.add_image("generated_spectrograms", grid, self.current_epoch)
+
+        # Log model input images
+        grid = torchvision.utils.make_grid([x for y in [output["val_img"] for output in outputs] for x in y])
+        self.logger.experiment.add_image("input_images_with_depthmaps", grid, self.current_epoch)
+
+        # Log generated audio examples
+        for output in outputs:
+            for example, audio in zip(output["val_examples"], output["val_audio"]):
+                self.logger.experiment.add_audio("generated_audio_%s" % example, audio, self.current_epoch, sample_rate=22050)
+
+    def test_step(self, batch, batch_idx):
+        spec, label, paths = batch
+        examples = [os.path.basename(s[:s.rfind("_")]) for s, _ in zip(*paths)]
+
+        # Forward passes through models
+        f, img = self.enc.forward(label)
+        img = (img + 1) * 0.5
+        z = torch.cat((f, torch.randn((f.shape[0], (self._latent_dimension - f.shape[1]) if f.shape[1] < self._latent_dimension else f.shape[1], f.shape[2], f.shape[3]), device=self.device)), 1)
+        fake_spec = self.g(z)
+
+        # Get audio
+        stft = LogMel() if self.stft_type == "mel" else STFT()
+        y_r = [stft.inverse(s.squeeze()) for s in spec]
+        y_f = [stft.inverse(s.squeeze()) for s in fake_spec]
+
+        # RT60 error (in percentages)
+        f = lambda x: pyroomacoustics.experimental.rt60.measure_rt60(x, 22050)
+        val_pct = []
+        for y_real, y_fake in zip(y_r, y_f):
+            try:
+                t_a = f(y_real)
+                t_b = f(y_fake)
+                val_pct.append((t_b - t_a) / t_a)
+            except Exception:
+                val_pct.append(numpy.nan)
+
+        return {"test_t60err": val_pct, "test_spec": fake_spec, "test_audio": y_f, "test_img": img, "test_examples": examples}
+
+    def test_epoch_end(self, outputs):
+        if not self.test_callback:
+            return
+
+        examples = []
+        t60 = []
+        spec_images = []
+        audio = []
+        input_images = []
+        input_depthmaps = []
+
+        for output in outputs:
+            for i in range(len(output["test_examples"])):
+                img = output["test_img"][i]
+                if img.shape[0] == 3:
+                    rgb = img
+                    img = torch.cat((rgb, torch.zeros((1, rgb.shape[1], rgb.shape[2]), device=self.device)), 0)
+                t60.append(output["test_t60err"][i])
+                spec_images.append(output["test_spec"][i].cpu().squeeze().detach().numpy())
+                audio.append(output["test_audio"][i])
+                input_images.append(img.cpu().squeeze().permute(1, 2, 0)[:, :, :-1].detach().numpy())
+                input_depthmaps.append(img.cpu().squeeze().permute(1, 2, 0)[:, :, -1].squeeze().detach().numpy())
+                examples.append(output["test_examples"][i])
+
+        self.test_callback(examples, t60, spec_images, audio, input_images, input_depthmaps)
+
+    @property
+    def automatic_optimization(self) -> bool:
+        return not self._opt
image2reverb/networks.py ADDED
@@ -0,0 +1,344 @@
+import os
+import numpy
+import torch
+import torch.nn as nn
+import torchvision.models as models
+import torch.utils.model_zoo as model_zoo
+from collections import OrderedDict
+from .layers import PixelWiseNormLayer, MiniBatchAverageLayer, EqualizedLearningRateLayer, Conv3x3, ConvBlock, upsample
+
+
+class Encoder(nn.Module):
+    """Load encoder from pre-trained ResNet50 (places365 CNNs) model. Link: http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar"""
+    def __init__(self, model_weights, depth_model, constant_depth=None, device="cuda", train_enc=True):
+        super().__init__()
+        self.device = device
+        self._constant_depth = constant_depth
+        self.model = models.resnet50(num_classes=365)
+
+        if model_weights:
+            c = torch.load(model_weights, map_location=self.device)
+            state_dict = {k.replace("module.", ""): v for k, v in c["state_dict"].items()}
+            self.model.load_state_dict(state_dict)
+
+        self._has_depth = False
+        if depth_model:
+            # Widen conv1 with a randomly initialized 4th input channel for the depthmap
+            f = self.model.conv1.weight
+            self.model.conv1.weight = torch.nn.Parameter(torch.cat((f, torch.randn(64, 1, 7, 7)), 1))
+            self.model.to(self.device)
+
+            encoder_path = os.path.join(depth_model, "encoder.pth")
+            depth_decoder_path = os.path.join(depth_model, "depth.pth")
+            self.depth_encoder = ResnetEncoder(18, False)
+            loaded_dict_enc = torch.load(encoder_path, map_location=self.device)
+
+            self.feed_height = loaded_dict_enc["height"]
+            self.feed_width = loaded_dict_enc["width"]
+            filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in self.depth_encoder.state_dict()}
+            self.depth_encoder.load_state_dict(filtered_dict_enc)
+            self.depth_encoder.to(self.device)
+            self.depth_encoder.eval()
+
+            self.depth_decoder = DepthDecoder(num_ch_enc=self.depth_encoder.num_ch_enc, scales=range(4))
+            loaded_dict = torch.load(depth_decoder_path, map_location=self.device)
+            self.depth_decoder.load_state_dict(loaded_dict, strict=False)
+            self.depth_decoder.to(self.device)
+            self.depth_decoder.eval()
+
+            self._has_depth = True
+
+        if train_enc:
+            self.model.train()
+
+    def forward(self, x):
+        if self._has_depth:
+            # Use a fixed depth value if configured, otherwise the finest-scale monodepth prediction
+            d = torch.full((x.shape[0], 1, x.shape[2], x.shape[3]), self._constant_depth, device=x.device) if self._constant_depth is not None else list(self.depth_decoder(self.depth_encoder(x)).values())[-1]
+            x = torch.cat((x, d), 1)
+        return self.model.forward(x).unsqueeze(-1).unsqueeze(-1), x
+
+
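+# In the STFT configuration, the stack below upsamples the latent map (8x8
+# after the modified input block) to a 512x512 single-channel spectrogram,
+# GANSynth-style, via six 2x nearest-neighbor stages.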
+class Generator(nn.Module):
+    """Build non-progressive variant of GANSynth generator."""
+    def __init__(self, latent_size=512, mel_spec=False):  # Encoder output should contain 2048 values
+        super().__init__()
+        self.latent_size = latent_size
+        self._mel_spec = mel_spec
+        self.build_model()
+
+    def forward(self, x):
+        return self.model(x)
+
+    def build_model(self):
+        model = []
+        # Input block
+        if self._mel_spec:
+            model.append(nn.Conv2d(self.latent_size, 256, kernel_size=(4, 2), stride=1, padding=2, bias=False))
+        else:
+            model.append(nn.Conv2d(self.latent_size, 256, kernel_size=8, stride=1, padding=7, bias=False))  # Modified to k=8, p=7 for our image dimensions (i.e. 512x512)
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Upsample(scale_factor=2, mode="nearest"))
+
+        # Second 2x upsampling stage, 256 channels
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Upsample(scale_factor=2, mode="nearest"))
+
+        # Third 2x upsampling stage, 256 channels
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Upsample(scale_factor=2, mode="nearest"))
+
+        # Fourth 2x upsampling stage, 256 channels
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Upsample(scale_factor=2, mode="nearest"))
+
+        # Fifth 2x upsampling stage, 256 -> 128 channels
+        model.append(nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Upsample(scale_factor=2, mode="nearest"))
+
+        # Sixth 2x upsampling stage, 128 -> 64 channels
+        model.append(nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Upsample(scale_factor=2, mode="nearest"))
+
+        # Final stage at full resolution, 64 -> 32 channels
+        model.append(nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+        model.append(nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(PixelWiseNormLayer())
+
+        # Output: 1-channel spectrogram in [-1, 1]
+        model.append(nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.Tanh())
+        self.model = nn.Sequential(*model)
+
+
+class Discriminator(nn.Module):
+    def __init__(self, label_size=365, mel_spec=False):
+        super().__init__()
+        self._label_size = label_size
+        self._mel_spec = mel_spec
+        self.build_model()
+
+    def forward(self, x, l):
+        d = self.model(x)
+        if self._mel_spec:
+            s = list(l.squeeze().shape)
+            s[-1] = 19
+            z = torch.cat((l.squeeze(), torch.zeros(s).type_as(x)), -1).reshape(d.shape[0], -1, 2, 4)
+        else:
+            s = list(l.squeeze().shape)
+            s[-1] = 512 - s[-1]
+            z = torch.cat((l.squeeze(), torch.zeros(s).type_as(x)), -1).reshape(d.shape[0], -1, 8, 8)
+        k = torch.cat((d, z), 1)
+        return self.output(k)
+
+    def build_model(self):
+        model = []
+        model.append(nn.Conv2d(1, 32, kernel_size=1, stride=1, padding=0, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
+
+        model.append(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
+
+        model.append(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
+
+        model.append(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
+
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
+
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=False, count_include_pad=False))
+
+        model.append(MiniBatchAverageLayer())
+        model.append(nn.Conv2d(257, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+        model.append(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False))
+        model.append(EqualizedLearningRateLayer(model[-1]))
+        model.append(nn.LeakyReLU(negative_slope=0.2))
+
+        output = []  # After the label concatenation
+        if self._mel_spec:
+            output.append(nn.Conv2d(304, 256, kernel_size=1, stride=1, padding=0, bias=False))
+        else:
+            output.append(nn.Conv2d(264, 256, kernel_size=1, stride=1, padding=0, bias=False))
+
+        output.append(nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0, bias=False))
+
+        # model.append(nn.Sigmoid())  # Output probability (in [0, 1])
+        self.model = nn.Sequential(*model)
+        self.output = nn.Sequential(*output)
+
+
+class ResnetEncoder(nn.Module):
+    """Pytorch module for a resnet encoder"""
+    def __init__(self, num_layers, pretrained, num_input_images=1):
+        super(ResnetEncoder, self).__init__()
+
+        self.num_ch_enc = numpy.array([64, 64, 128, 256, 512])
+
+        resnets = {18: models.resnet18,
+                   34: models.resnet34,
+                   50: models.resnet50,
+                   101: models.resnet101,
+                   152: models.resnet152}
+
+        if num_layers not in resnets:
+            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
+
+        if num_input_images > 1:
+            # Note: resnet_multiimage_input comes from the original monodepth2
+            # code and is not defined in this module; only the single-image
+            # branch below is exercised here.
+            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
+        else:
+            self.encoder = resnets[num_layers](pretrained)
+
+        if num_layers > 34:
+            self.num_ch_enc[1:] *= 4
+
+    def forward(self, input_image):
+        self.features = []
+        x = (input_image - 0.45) / 0.225
+        x = self.encoder.conv1(x)
+        x = self.encoder.bn1(x)
+        self.features.append(self.encoder.relu(x))
+        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
+        self.features.append(self.encoder.layer2(self.features[-1]))
+        self.features.append(self.encoder.layer3(self.features[-1]))
+        self.features.append(self.encoder.layer4(self.features[-1]))
+
+        return self.features
+
+
+class DepthDecoder(nn.Module):
+    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
+        super(DepthDecoder, self).__init__()
+
+        self.num_output_channels = num_output_channels
+        self.use_skips = use_skips
+        self.upsample_mode = "nearest"
+        self.scales = scales
+
+        self.num_ch_enc = num_ch_enc
+        self.num_ch_dec = numpy.array([16, 32, 64, 128, 256])
+
+        # decoder
+        self.convs = OrderedDict()
+        for i in range(4, -1, -1):
+            # upconv_0
+            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
+            num_ch_out = self.num_ch_dec[i]
+            # self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
+            setattr(self, "upconv_{}_0".format(i), ConvBlock(num_ch_in, num_ch_out))
+
+            # upconv_1
+            num_ch_in = self.num_ch_dec[i]
+            if self.use_skips and i > 0:
+                num_ch_in += self.num_ch_enc[i - 1]
+            num_ch_out = self.num_ch_dec[i]
+            # self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
+            setattr(self, "upconv_{}_1".format(i), ConvBlock(num_ch_in, num_ch_out))
+
+        for s in self.scales:
+            # self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
+            setattr(self, "disp_{}".format(s), Conv3x3(self.num_ch_dec[s], self.num_output_channels))
+
+        self.decoder = nn.ModuleList(
+            [x for y in [[getattr(self, "upconv_{}_0".format(i)), getattr(self, "upconv_{}_1".format(i))] for i in range(4, -1, -1)] for x in y] +
+            [getattr(self, "disp_{}".format(s)) for s in self.scales]
+        )
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input_features):
+        outputs = {}
+
+        # decoder
+        x = input_features[-1]
+        for i in range(4, -1, -1):
+            # x = self.convs[("upconv", i, 0)](x)
+            x = getattr(self, "upconv_{}_0".format(i))(x)
+            x = [upsample(x)]
+            if self.use_skips and i > 0:
+                x += [input_features[i - 1]]
+            x = torch.cat(x, 1)
+            # x = self.convs[("upconv", i, 1)](x)
+            x = getattr(self, "upconv_{}_1".format(i))(x)
+            if i in self.scales:
+                outputs[("disp", i)] = self.sigmoid(getattr(self, "disp_{}".format(i))(x))
+                # setattr(self, "outputs_disp_{}".format(i), self.sigmoid(getattr(self, "disp_{}".format(i))(x)))
+
+        return outputs
image2reverb/stft.py ADDED
@@ -0,0 +1,23 @@
+import numpy
+import torch
+import librosa
+
+
+class STFT(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._eps = 1e-8
+
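+    # Note the asymmetry: transform min-max normalizes each example into
+    # [-0.8, 0.8], while inverse undoes it with fixed empirical test-set
+    # constants and resynthesizes the phase randomly.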
+    def transform(self, audio):
+        m = numpy.abs(librosa.stft(audio / numpy.abs(audio).max(), n_fft=1024, hop_length=256))[:-1, :]
+        m = numpy.log(m + self._eps)
+        m = (((m - m.min()) / (m.max() - m.min()) * 2) - 1)
+        return (torch.FloatTensor if torch.cuda.is_available() else torch.Tensor)(m * 0.8).unsqueeze(0)
+
+    def inverse(self, spec):
+        s = spec.cpu().detach().numpy()
+        s = numpy.exp((((s + 1) * 0.5) * 19.5) - 17.5) - self._eps  # Empirical (average) min and max over test set
+        rp = numpy.random.uniform(-numpy.pi, numpy.pi, s.shape)
+        f = s * (numpy.cos(rp) + (1.j * numpy.sin(rp)))
+        y = librosa.istft(f)  # Reconstruct audio
+        return y / numpy.abs(y).max()
image2reverb/util.py ADDED
@@ -0,0 +1,167 @@
+import os
+import math
+import numpy
+import torch
+import torch.fft
+from PIL import Image
+
+
+def compare_t60(a, b, sr=86):
+    try:
+        a = a.detach().clone().abs()
+        b = b.detach().clone().abs()
+        a = (a - a.min()) / (a.max() - a.min())
+        b = (b - b.min()) / (b.max() - b.min())
+        t_a = estimate_t60(a, sr)
+        t_b = estimate_t60(b, sr)
+        return abs((t_b - t_a) / t_a) * 100
+    except Exception:
+        return 100
+
+
+def estimate_t60(audio, sr):
+    fs = float(sr)
+    audio = audio.detach().clone()
+
+    decay_db = 20
+
+    # The power of the impulse response in dB
+    power = audio ** 2
+    energy = torch.flip(torch.cumsum(torch.flip(power, [0]), 0), [0])  # Integration according to Schroeder
+
+    # remove the possibly all zero tail
+    i_nz = torch.max(torch.where(energy > 0)[0])
+    n = energy[:i_nz]
+    db = 10 * torch.log10(n)
+    db = db - db[0]
+
+    # -5 dB headroom
+    i_5db = torch.min(torch.where(-5 - db > 0)[0])
+    e_5db = db[i_5db]
+    t_5db = i_5db / fs
+
+    # after decay
+    i_decay = torch.min(torch.where(-5 - decay_db - db > 0)[0])
+    t_decay = i_decay / fs
+
+    # compute the decay time
+    decay_time = t_decay - t_5db
+    est_rt60 = (60 / decay_db) * decay_time
+
+    return est_rt60
+
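+# i.e. the RT60 is extrapolated from the -5 dB to -25 dB portion of the decay
+# curve: est_rt60 = (60 / decay_db) * decay_time with decay_db = 20.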
+def hilbert(x):  # hilbert transform
+    N = x.shape[1]
+    Xf = torch.fft.fft(x, n=None, dim=-1)
+    h = torch.zeros(N)
+    if N % 2 == 0:
+        h[0] = h[N // 2] = 1
+        h[1:N // 2] = 2
+    else:
+        h[0] = 1
+        h[1:(N + 1) // 2] = 2
+    x = torch.fft.ifft(Xf * h)
+    return x
+
+
+def spectral_centroid(x):  # calculate the spectral centroid "brightness" of an audio input
+    Xf = torch.abs(torch.fft.fft(x, n=None, dim=-1))  # take fft and abs of x
+    norm_Xf = Xf / sum(sum(Xf))  # like probability mass function
+    norm_freqs = torch.linspace(0, 1, Xf.shape[1])
+    spectral_centroid = sum(sum(norm_freqs * norm_Xf))
+    return spectral_centroid
+
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=numpy.uint8, normalize=True):
+    if isinstance(image_tensor, list):
+        image_numpy = []
+        for i in range(len(image_tensor)):
+            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
+        return image_numpy
+    image_numpy = image_tensor.cpu().float().numpy()
+    if normalize:
+        image_numpy = (numpy.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+    else:
+        image_numpy = numpy.transpose(image_numpy, (1, 2, 0)) * 255.0
+    image_numpy = numpy.clip(image_numpy, 0, 255)
+    if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
+        image_numpy = image_numpy[:, :, 0]
+    return image_numpy.astype(imtype)
+
+
+# Converts a one-hot tensor into a colorful label map
+def tensor2label(label_tensor, n_label, imtype=numpy.uint8):
+    if n_label == 0:
+        return tensor2im(label_tensor, imtype)
+    label_tensor = label_tensor.cpu().float()
+    if label_tensor.size()[0] > 1:
+        label_tensor = label_tensor.max(0, keepdim=True)[1]
+    label_tensor = Colorize(n_label)(label_tensor)
+    label_numpy = numpy.transpose(label_tensor.numpy(), (1, 2, 0))
+    return label_numpy.astype(imtype)
+
+
+def save_image(image_numpy, image_path):
+    image_pil = Image.fromarray(image_numpy)
+    image_pil.save(image_path)
+
+
+def mkdirs(paths):
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+###############################################################################
+# Code from
+# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
+# Modified so it complies with the Cityscapes label map colors
+###############################################################################
+def uint82bin(n, count=8):
+    """returns the binary of integer n, count refers to amount of bits"""
+    return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])
+
+
+def labelcolormap(N):
+    if N == 35:  # cityscape
+        cmap = numpy.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
+                            (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
+                            (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
+                            (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
+                            (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)],
+                           dtype=numpy.uint8)
+    else:
+        cmap = numpy.zeros((N, 3), dtype=numpy.uint8)
+        for i in range(N):
+            r, g, b = 0, 0, 0
+            id = i
+            for j in range(7):
+                str_id = uint82bin(id)
+                r = r ^ (numpy.uint8(str_id[-1]) << (7 - j))
+                g = g ^ (numpy.uint8(str_id[-2]) << (7 - j))
+                b = b ^ (numpy.uint8(str_id[-3]) << (7 - j))
+                id = id >> 3
+            cmap[i, 0] = r
+            cmap[i, 1] = g
+            cmap[i, 2] = b
+    return cmap
+
+
+class Colorize(object):
+    def __init__(self, n=35):
+        self.cmap = labelcolormap(n)
+        self.cmap = torch.from_numpy(self.cmap[:n])
+
+    def __call__(self, gray_image):
+        size = gray_image.size()
+        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+
+        for label in range(0, len(self.cmap)):
+            mask = (label == gray_image[0]).cpu()
+            color_image[0][mask] = self.cmap[label][0]
+            color_image[1][mask] = self.cmap[label][1]
+            color_image[2][mask] = self.cmap[label][2]
+
+        return color_image
model.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
+torch
+torchvision
+pytorch_lightning
+pyroomacoustics
+soundfile
+librosa