File size: 1,326 Bytes
29a525e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse

import hparams
import utils
import multiresunet_model

import tensorflow as tf
import numpy as np

if __name__ == '__main__':
  args = argparse.ArgumentParser()

  args.add_argument('Path',metavar='path',type=str,help='Path to DSD100 pickled spectrograms. See preprocess_data.py for more details')
  args.add_argument('Source',metavar='source',type=str,help='Desired source to separate')
  args.add_argument('Spectrum',metavar='spectrum',type=str,help='Low (lf) or High (hf) frequencies training')
  args.add_argument('Outpath',metavar='model_out_path',type=str,help='Path to save the model to')
  
  ### Parse Args ###
  args = args.parse_args()
  path = args.Path
  source = args.Source
  spectrum = args.Spectrum
  output_path = args.Outpath
  
  ### Load Data ###
  x = np.load(path + 'mixture_' + spectrum + '.npy')
  y = np.load(path + source + '_' + spectrum + '.npy')

  ### Construct model ###
  model = multiresunet_model.Steminator((hparams.frequency_bins,hparams.chunk_size,hparams.n_channels))
  optimizer = tf.keras.optimizers.Adam(lr = hparams.learning_rate)
  model.compile(optimizer, loss='mean_absolute_error')

  ### Training ###
  model.fit(x,y,epochs = hparams.epochs, batch_size = hparams.batch_size)

  ### Save model ###
  model.save(output_path + source + '_' + spectrum + '.h5')