XDHDD commited on
Commit
02e1d16
·
1 Parent(s): e34c0af

Upload 8 files

Browse files
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ lightning_logs/predictor/checkpoints/predictor.ckpt filter=lfs diff=lfs merge=lfs -text
2
+ lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt filter=lfs diff=lfs merge=lfs -text
3
+ lightning_logs/version_0/checkpoints/frn.onnx filter=lfs diff=lfs merge=lfs -text
lightning_logs/predictor/checkpoints/predictor.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f3679c9431666575eb7899e556d040073aa74956c48f122b16b30b9efa2e93b
3
+ size 14985163
lightning_logs/predictor/hparams.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ batch_size: 90
2
+ input: mag
3
+ lstm_dim: 512
4
+ lstm_layers: 1
5
+ output: mag
6
+ window_size: 960
lightning_logs/version_0/checkpoints/frn-epoch=65-val_loss=0.2290.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4061bb0f6e669315e00878009440dab749f60f823d5bf863bfa4b8172d96d073
3
+ size 109184745
lightning_logs/version_0/checkpoints/frn.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fdf07d992ff655e5ab32074d4d7b747986cd79fed16b499ed11b120c7042a666
3
+ size 36527867
lightning_logs/version_0/hparams.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ batch_size: 90
2
+ cnn_dim: 64
3
+ cnn_layers: 5
4
+ lstm_dim: 512
5
+ lstm_layers: 1
6
+ window_size: 960
models/__init__.py ADDED
File without changes
models/blocks.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import pytorch_lightning as pl
3
+ import torch
4
+ from einops.layers.torch import Rearrange
5
+ from torch import nn
6
+
7
+
8
+ class Aff(nn.Module):
9
+ def __init__(self, dim):
10
+ super().__init__()
11
+
12
+ self.alpha = nn.Parameter(torch.ones([1, 1, dim]))
13
+ self.beta = nn.Parameter(torch.zeros([1, 1, dim]))
14
+
15
+ def forward(self, x):
16
+ x = x * self.alpha + self.beta
17
+ return x
18
+
19
+
20
+ class FeedForward(nn.Module):
21
+ def __init__(self, dim, hidden_dim, dropout=0.):
22
+ super().__init__()
23
+ self.net = nn.Sequential(
24
+ nn.Linear(dim, hidden_dim),
25
+ nn.GELU(),
26
+ nn.Dropout(dropout),
27
+ nn.Linear(hidden_dim, dim),
28
+ nn.Dropout(dropout)
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.net(x)
33
+
34
+
35
+ class MLPBlock(nn.Module):
36
+
37
+ def __init__(self, dim, mlp_dim, dropout=0., init_values=1e-4):
38
+ super().__init__()
39
+
40
+ self.pre_affine = Aff(dim)
41
+ self.inter = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=1,
42
+ bidirectional=False, batch_first=True)
43
+ self.ff = nn.Sequential(
44
+ FeedForward(dim, mlp_dim, dropout),
45
+ )
46
+ self.post_affine = Aff(dim)
47
+ self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
48
+ self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
49
+
50
+ def forward(self, x, state=None):
51
+ x = self.pre_affine(x)
52
+ if state is None:
53
+ inter, _ = self.inter(x)
54
+ else:
55
+ inter, state = self.inter(x, (state[0], state[1]))
56
+ x = x + self.gamma_1 * inter
57
+ x = self.post_affine(x)
58
+ x = x + self.gamma_2 * self.ff(x)
59
+ if state is None:
60
+ return x
61
+ state = torch.stack(state, 0)
62
+ return x, state
63
+
64
+
65
+ class Encoder(nn.Module):
66
+
67
+ def __init__(self, in_dim, dim, depth, mlp_dim):
68
+ super().__init__()
69
+ self.in_dim = in_dim
70
+ self.dim = dim
71
+ self.depth = depth
72
+ self.mlp_dim = mlp_dim
73
+ self.to_patch_embedding = nn.Sequential(
74
+ Rearrange('b c f t -> b t (c f)'),
75
+ nn.Linear(in_dim, dim),
76
+ nn.GELU()
77
+ )
78
+
79
+ self.mlp_blocks = nn.ModuleList([])
80
+
81
+ for _ in range(depth):
82
+ self.mlp_blocks.append(MLPBlock(self.dim, mlp_dim, dropout=0.15))
83
+
84
+ self.affine = nn.Sequential(
85
+ Aff(self.dim),
86
+ nn.Linear(dim, in_dim),
87
+ Rearrange('b t (c f) -> b c f t', c=2),
88
+ )
89
+
90
+ def forward(self, x_in, states=None):
91
+ x = self.to_patch_embedding(x_in)
92
+ if states is not None:
93
+ out_states = []
94
+ for i, mlp_block in enumerate(self.mlp_blocks):
95
+ if states is None:
96
+ x = mlp_block(x)
97
+ else:
98
+ x, state = mlp_block(x, states[i])
99
+ out_states.append(state)
100
+ x = self.affine(x)
101
+ x = x + x_in
102
+ if states is None:
103
+ return x
104
+ else:
105
+ return x, torch.stack(out_states, 0)
106
+
107
+
108
+ class Predictor(pl.LightningModule): # mel
109
+ def __init__(self, window_size=1536, sr=48000, lstm_dim=256, lstm_layers=3, n_mels=64):
110
+ super(Predictor, self).__init__()
111
+ self.window_size = window_size
112
+ self.hop_size = window_size // 2
113
+ self.lstm_dim = lstm_dim
114
+ self.n_mels = n_mels
115
+ self.lstm_layers = lstm_layers
116
+
117
+ fb = librosa.filters.mel(sr=sr, n_fft=self.window_size, n_mels=self.n_mels)[:, 1:]
118
+ self.fb = torch.from_numpy(fb).unsqueeze(0).unsqueeze(0)
119
+ self.lstm = nn.LSTM(input_size=self.n_mels, hidden_size=self.lstm_dim, bidirectional=False,
120
+ num_layers=self.lstm_layers, batch_first=True)
121
+ self.expand_dim = nn.Linear(self.lstm_dim, self.n_mels)
122
+ self.inv_mel = nn.Linear(self.n_mels, self.hop_size)
123
+
124
+ def forward(self, x, state=None): # B, 2, F, T
125
+
126
+ self.fb = self.fb.to(x.device)
127
+ x = torch.log(torch.matmul(self.fb, x) + 1e-8)
128
+ B, C, F, T = x.shape
129
+ x = x.reshape(B, F * C, T)
130
+ x = x.permute(0, 2, 1)
131
+ if state is None:
132
+ x, _ = self.lstm(x)
133
+ else:
134
+ x, state = self.lstm(x, (state[0], state[1]))
135
+ x = self.expand_dim(x)
136
+ x = torch.abs(self.inv_mel(torch.exp(x)))
137
+ x = x.permute(0, 2, 1)
138
+ x = x.reshape(B, C, -1, T)
139
+ if state is None:
140
+ return x
141
+ else:
142
+ return x, torch.stack(state, 0)
models/frn.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import librosa
4
+ import pytorch_lightning as pl
5
+ import soundfile as sf
6
+ import torch
7
+ from torch import nn
8
+ from torch.utils.data import DataLoader
9
+ from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality as PESQ
10
+ from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility as STOI
11
+
12
+ from PLCMOS.plc_mos import PLCMOSEstimator
13
+ from config import CONFIG
14
+ from loss import Loss
15
+ from models.blocks import Encoder, Predictor
16
+ from utils.utils import visualize, LSD
17
+
18
+ plcmos = PLCMOSEstimator()
19
+
20
+
21
+ class PLCModel(pl.LightningModule):
22
+ def __init__(self, train_dataset=None, val_dataset=None, window_size=960, enc_layers=4, enc_in_dim=384, enc_dim=768,
23
+ pred_dim=512, pred_layers=1, pred_ckpt_path='lightning_logs/predictor/checkpoints/predictor.ckpt'):
24
+ super(PLCModel, self).__init__()
25
+ self.window_size = window_size
26
+ self.hop_size = window_size // 2
27
+ self.learning_rate = CONFIG.TRAIN.lr
28
+ self.hparams.batch_size = CONFIG.TRAIN.batch_size
29
+
30
+ self.enc_layers = enc_layers
31
+ self.enc_in_dim = enc_in_dim
32
+ self.enc_dim = enc_dim
33
+ self.pred_dim = pred_dim
34
+ self.pred_layers = pred_layers
35
+ self.train_dataset = train_dataset
36
+ self.val_dataset = val_dataset
37
+ self.stoi = STOI(48000)
38
+ self.pesq = PESQ(16000, 'wb')
39
+
40
+ if pred_ckpt_path is not None:
41
+ self.predictor = Predictor.load_from_checkpoint(pred_ckpt_path)
42
+ else:
43
+ self.predictor = Predictor(window_size=self.window_size, lstm_dim=self.pred_dim,
44
+ lstm_layers=self.pred_layers)
45
+ self.joiner = nn.Sequential(
46
+ nn.Conv2d(3, 48, kernel_size=(9, 1), stride=1, padding=(4, 0), padding_mode='reflect',
47
+ groups=3),
48
+ nn.LeakyReLU(0.2),
49
+ nn.Conv2d(48, 2, kernel_size=1, stride=1, padding=0, groups=2),
50
+ )
51
+
52
+ self.encoder = Encoder(in_dim=self.window_size, dim=self.enc_in_dim, depth=self.enc_layers,
53
+ mlp_dim=self.enc_dim)
54
+
55
+ self.loss = Loss()
56
+ self.window = torch.sqrt(torch.hann_window(self.window_size))
57
+ self.save_hyperparameters('window_size', 'enc_layers', 'enc_in_dim', 'enc_dim', 'pred_dim', 'pred_layers')
58
+
59
+ def forward(self, x):
60
+ """
61
+ Input: real-imaginary; shape (B, F, T, 2); F = hop_size + 1
62
+ Output: real-imaginary
63
+ """
64
+
65
+ B, C, F, T = x.shape
66
+
67
+ x = x.permute(3, 0, 1, 2).unsqueeze(-1)
68
+ prev_mag = torch.zeros((B, 1, F, 1), device=x.device)
69
+ predictor_state = torch.zeros((2, self.predictor.lstm_layers, B, self.predictor.lstm_dim), device=x.device)
70
+ mlp_state = torch.zeros((self.encoder.depth, 2, 1, B, self.encoder.dim), device=x.device)
71
+ result = []
72
+ for step in x:
73
+ feat, mlp_state = self.encoder(step, mlp_state)
74
+ prev_mag, predictor_state = self.predictor(prev_mag, predictor_state)
75
+ feat = torch.cat((feat, prev_mag), 1)
76
+ feat = self.joiner(feat)
77
+ feat = feat + step
78
+ result.append(feat)
79
+ prev_mag = torch.linalg.norm(feat, dim=1, ord=1, keepdims=True) # compute magnitude
80
+ output = torch.cat(result, -1)
81
+ return output
82
+
83
+ def forward_onnx(self, x, prev_mag, predictor_state=None, mlp_state=None):
84
+ prev_mag, predictor_state = self.predictor(prev_mag, predictor_state)
85
+ feat, mlp_state = self.encoder(x, mlp_state)
86
+
87
+ feat = torch.cat((feat, prev_mag), 1)
88
+ feat = self.joiner(feat)
89
+ prev_mag = torch.linalg.norm(feat, dim=1, ord=1, keepdims=True)
90
+ feat = feat + x
91
+ return feat, prev_mag, predictor_state, mlp_state
92
+
93
+ def train_dataloader(self):
94
+ return DataLoader(self.train_dataset, shuffle=False, batch_size=self.hparams.batch_size,
95
+ num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
96
+
97
+ def val_dataloader(self):
98
+ return DataLoader(self.val_dataset, shuffle=False, batch_size=self.hparams.batch_size,
99
+ num_workers=CONFIG.TRAIN.workers, persistent_workers=True)
100
+
101
+ def training_step(self, batch, batch_idx):
102
+ x_in, y = batch
103
+ f_0 = x_in[:, :, 0:1, :]
104
+ x = x_in[:, :, 1:, :]
105
+
106
+ x = self(x)
107
+ x = torch.cat([f_0, x], dim=2)
108
+
109
+ loss = self.loss(x, y)
110
+ self.log('train_loss', loss, logger=True)
111
+ return loss
112
+
113
+ def validation_step(self, val_batch, batch_idx):
114
+ x, y = val_batch
115
+ f_0 = x[:, :, 0:1, :]
116
+ x_in = x[:, :, 1:, :]
117
+
118
+ pred = self(x_in)
119
+ pred = torch.cat([f_0, pred], dim=2)
120
+
121
+ loss = self.loss(pred, y)
122
+ self.window = self.window.to(pred.device)
123
+ pred = torch.view_as_complex(pred.permute(0, 2, 3, 1).contiguous())
124
+ pred = torch.istft(pred, self.window_size, self.hop_size, window=self.window)
125
+ y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())
126
+ y = torch.istft(y, self.window_size, self.hop_size, window=self.window)
127
+
128
+ self.log('val_loss', loss, on_step=False, on_epoch=True, logger=True, prog_bar=True, sync_dist=True)
129
+
130
+ if batch_idx == 0:
131
+ i = torch.randint(0, x.shape[0], (1,)).item()
132
+ x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())
133
+ x = torch.istft(x[i], self.window_size, self.hop_size, window=self.window)
134
+
135
+ self.trainer.logger.log_spectrogram(y[i], x, pred[i], self.current_epoch)
136
+ self.trainer.logger.log_audio(y[i], x, pred[i], self.current_epoch)
137
+
138
+ def test_step(self, test_batch, batch_idx):
139
+ inp, tar, inp_wav, tar_wav = test_batch
140
+ inp_wav = inp_wav.squeeze()
141
+ tar_wav = tar_wav.squeeze()
142
+ f_0 = inp[:, :, 0:1, :]
143
+ x = inp[:, :, 1:, :]
144
+ pred = self(x)
145
+ pred = torch.cat([f_0, pred], dim=2)
146
+ pred = torch.istft(pred.squeeze(0).permute(1, 2, 0), self.window_size, self.hop_size,
147
+ window=self.window.to(pred.device))
148
+ stoi = self.stoi(pred, tar_wav)
149
+
150
+ tar_wav = tar_wav.cpu().numpy()
151
+ inp_wav = inp_wav.cpu().numpy()
152
+ pred = pred.detach().cpu().numpy()
153
+ lsd, _ = LSD(tar_wav, pred)
154
+
155
+ if batch_idx in [5, 7, 9]:
156
+ sample_path = os.path.join(CONFIG.LOG.sample_path)
157
+ path = os.path.join(sample_path, 'sample_' + str(batch_idx))
158
+ visualize(tar_wav, inp_wav, pred, path)
159
+ sf.write(os.path.join(path, 'enhanced_output.wav'), pred, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
160
+ sf.write(os.path.join(path, 'lossy_input.wav'), inp_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
161
+ sf.write(os.path.join(path, 'target.wav'), tar_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
162
+ if CONFIG.DATA.sr != 16000:
163
+ pred = librosa.resample(pred, orig_sr=48000, target_sr=16000)
164
+ tar_wav = librosa.resample(tar_wav, orig_sr=48000, target_sr=16000, res_type='kaiser_fast')
165
+ ret = plcmos.run(pred, tar_wav)
166
+ pesq = self.pesq(torch.tensor(pred), torch.tensor(tar_wav))
167
+ metrics = {
168
+ "Intrusive": ret[0],
169
+ "Non-intrusive": ret[1],
170
+ 'LSD': lsd,
171
+ 'STOI': stoi,
172
+ 'PESQ': pesq,
173
+ }
174
+ self.log_dict(metrics)
175
+ return metrics
176
+
177
+ def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
178
+ f_0 = batch[:, :, 0:1, :]
179
+ x = batch[:, :, 1:, :]
180
+ pred = self(x)
181
+ pred = torch.cat([f_0, pred], dim=2)
182
+ pred = torch.istft(pred.squeeze(0).permute(1, 2, 0), self.window_size, self.hop_size,
183
+ window=self.window.to(pred.device))
184
+ return pred
185
+
186
+ def configure_optimizers(self):
187
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
188
+ lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=CONFIG.TRAIN.patience,
189
+ factor=CONFIG.TRAIN.factor, verbose=True)
190
+
191
+ scheduler = {
192
+ 'scheduler': lr_scheduler,
193
+ 'reduce_on_plateau': True,
194
+ 'monitor': 'val_loss'
195
+ }
196
+ return [optimizer], [scheduler]
197
+
198
+
199
+ class OnnxWrapper(pl.LightningModule):
200
+ def __init__(self, model, *args, **kwargs):
201
+ super().__init__(*args, **kwargs)
202
+ self.model = model
203
+ batch_size = 1
204
+ pred_states = torch.zeros((2, 1, batch_size, model.predictor.lstm_dim))
205
+ mlp_states = torch.zeros((model.encoder.depth, 2, 1, batch_size, model.encoder.dim))
206
+ mag = torch.zeros((batch_size, 1, model.hop_size, 1))
207
+ x = torch.randn(batch_size, model.hop_size + 1, 2)
208
+ self.sample = (x, mag, pred_states, mlp_states)
209
+ self.input_names = ['input', 'mag_in_cached_', 'pred_state_in_cached_', 'mlp_state_in_cached_']
210
+ self.output_names = ['output', 'mag_out_cached_', 'pred_state_out_cached_', 'mlp_state_out_cached_']
211
+
212
+ def forward(self, x, prev_mag, predictor_state=None, mlp_state=None):
213
+ x = x.permute(0, 2, 1).unsqueeze(-1)
214
+ f_0 = x[:, :, 0:1, :]
215
+ x = x[:, :, 1:, :]
216
+
217
+ output, prev_mag, predictor_state, mlp_state = self.model.forward_onnx(x, prev_mag, predictor_state, mlp_state)
218
+ output = torch.cat([f_0, output], dim=2)
219
+ output = output.squeeze(-1).permute(0, 2, 1)
220
+ return output, prev_mag, predictor_state, mlp_state