anhnv125 committed
Commit 871344c
Parent: 0d48cb7

update code

Files changed (3):
  1. main.py +2 -2
  2. models/blocks.py +1 -1
  3. models/frn.py +2 -2
main.py CHANGED
@@ -4,7 +4,7 @@ import os
 import pytorch_lightning as pl
 import soundfile as sf
 import torch
- from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
+ from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.model_summary import summarize
 from torch.utils.data import DataLoader

@@ -66,7 +66,7 @@ def train():
 gpus=len(gpus),
 max_epochs=CONFIG.TRAIN.epochs,
 accelerator="gpu" if len(gpus) > 1 else None,
- callbacks=[checkpoint_callback, StochasticWeightAveraging(swa_lrs=1e-2)]
+ callbacks=[checkpoint_callback]
 )

 print(model.hparams)
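This hunk drops the StochasticWeightAveraging callback, so training now runs with checkpointing only. For reference, StochasticWeightAveraging is a standard PyTorch Lightning callback; a minimal, self-contained sketch of a trainer that still uses it (the monitor and epoch settings below are illustrative stand-ins, not the repo's CONFIG values, and swa_lrs=1e-2 simply mirrors the removed call):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

    # Illustrative stand-ins for the values main.py builds from CONFIG.
    checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[checkpoint_callback, StochasticWeightAveraging(swa_lrs=1e-2)],
    )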
models/blocks.py CHANGED
@@ -117,7 +117,7 @@ class Predictor(pl.LightningModule):  # mel
 fb = librosa.filters.mel(sr=sr, n_fft=self.window_size, n_mels=self.n_mels)[:, 1:]
 self.fb = torch.from_numpy(fb).unsqueeze(0).unsqueeze(0)
 self.lstm = nn.LSTM(input_size=self.n_mels, hidden_size=self.lstm_dim, bidirectional=False,
- num_layers=self.lstm_layers)
+ num_layers=self.lstm_layers, batch_first=True)
 self.expand_dim = nn.Linear(self.lstm_dim, self.n_mels)
 self.inv_mel = nn.Linear(self.n_mels, self.hop_size)

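With batch_first=True, nn.LSTM expects input of shape (batch, time, features) and returns output as (batch, time, hidden), while the hidden and cell states keep the shape (num_layers, batch, hidden_size) either way. A small standalone sketch of that convention, using made-up sizes rather than the repo's n_mels/lstm_dim config:

    import torch
    import torch.nn as nn

    n_mels, lstm_dim, lstm_layers = 80, 512, 1          # illustrative sizes only
    lstm = nn.LSTM(input_size=n_mels, hidden_size=lstm_dim,
                   num_layers=lstm_layers, batch_first=True)

    x = torch.randn(4, 100, n_mels)                     # (batch, time, n_mels) with batch_first=True
    out, (h, c) = lstm(x)
    print(out.shape)                                    # torch.Size([4, 100, 512])
    print(h.shape, c.shape)                             # each (num_layers, batch, hidden_size) = (1, 4, 512)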
models/frn.py CHANGED
@@ -66,7 +66,7 @@ class PLCModel(pl.LightningModule):

 x = x.permute(3, 0, 1, 2).unsqueeze(-1)
 prev_mag = torch.zeros((B, 1, F, 1), device=x.device)
- predictor_state = torch.zeros((2, self.predictor.lstm_layers, 1, self.predictor.lstm_dim), device=x.device)
+ predictor_state = torch.zeros((2, self.predictor.lstm_layers, B, self.predictor.lstm_dim), device=x.device)
 mlp_state = torch.zeros((self.encoder.depth, 2, 1, B, self.encoder.dim), device=x.device)
 result = []
 for step in x:
@@ -201,7 +201,7 @@ class OnnxWrapper(pl.LightningModule):
 super().__init__(*args, **kwargs)
 self.model = model
 batch_size = 1
- pred_states = torch.zeros((2, 1, 1, model.predictor.lstm_dim))
+ pred_states = torch.zeros((2, 1, batch_size, model.predictor.lstm_dim))
 mlp_states = torch.zeros((model.encoder.depth, 2, 1, batch_size, model.encoder.dim))
 mag = torch.zeros((batch_size, 1, model.hop_size, 1))
 x = torch.randn(batch_size, model.hop_size + 1, 2)
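Both hunks fix the batch dimension of the zero-initialized LSTM state: the leading 2 stacks the hidden and cell tensors, and each of those must have shape (num_layers, batch, hidden_size), so the third dimension needs to be B (or batch_size) rather than a hard-coded 1. A self-contained sketch of this stacked-state pattern, with illustrative sizes and names that only approximate what frn.py does:

    import torch
    import torch.nn as nn

    B, lstm_layers, lstm_dim, n_mels = 4, 1, 512, 80     # illustrative sizes only
    lstm = nn.LSTM(input_size=n_mels, hidden_size=lstm_dim,
                   num_layers=lstm_layers, batch_first=True)

    # Stack h and c along a leading axis of size 2 so a single tensor can be carried
    # between steps (and exported through ONNX); each half is (num_layers, batch, hidden).
    state = torch.zeros((2, lstm_layers, B, lstm_dim))

    x = torch.randn(B, 1, n_mels)                        # one frame per step
    out, (h, c) = lstm(x, (state[0], state[1]))
    state = torch.stack([h, c])                          # updated state for the next step
    print(state.shape)                                   # torch.Size([2, 1, 4, 512])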