Spaces:
Running
Running
import numpy as np | |
# import keras | |
from kapre.time_frequency import Spectrogram | |
from tensorflow import keras | |
from generators.generator import * | |
from models.common.architectures import layers_map | |
""" | |
The STFT spectrogram of the input signal is fed | |
into a 2D CNN that predicts the synthesizer parameter | |
configuration. This configuration is then used to produce | |
a sound that is similar to the input sound. | |
""" | |
def assemble_model(
    src: np.ndarray,
    n_outputs: int,
    arch_layers: list,
    n_dft: int = 512,  # Orig: 128
    n_hop: int = 256,  # Orig: 64
    data_format: str = "channels_first",
    fc_units: int = 512,
    fc_activation=None,
) -> keras.Model:
    """Build the spectrogram CNN that predicts a synthesizer parameter configuration.

    Model architecture:
    1. The raw input signal goes through a kapre ``Spectrogram`` layer, which
       returns a decibel-scaled magnitude spectrogram.
    2. Its frequency and time axes are swapped.
    3. The result is fed through the 2D convolution stack described by
       ``arch_layers``.
    4. It is then flattened and passed through one hidden fully connected
       layer (FC-512 in the paper).
    5. A sigmoid output layer follows.

    @paper: the convolutions are written as C(F, K1, K2, S1, S2), i.e. a
    ReLU-activated 2D strided convolutional layer with F filters of size
    (K1, K2) and strides (S1, S2). For example, a single layer
    C(38, 13, 26, 13, 26). Here the activation is taken from each
    ``arch_layer`` rather than hard-coded.

    Args:
        src: Example input. Only ``src.shape`` is used, as the per-example
            input shape (no batch axis). ``get_model`` passes an array of shape
            (1, n_samples); presumably kapre reads this as
            (n_channel, n_samples). TODO confirm.
        n_outputs: Number of sigmoid units in the "predictions" layer.
        arch_layers: Iterable of layer specs. Each must expose ``filters``,
            ``window_size``, ``strides`` and ``activation``, which are passed
            to ``keras.layers.Conv2D``.
        n_dft: FFT size of the spectrogram layer.
        n_hop: Hop length of the spectrogram layer.
        data_format: ``"channels_first"`` or ``"channels_last"``. It is used
            both as the spectrogram's ``image_data_format`` and as every
            Conv2D's ``data_format``.
        fc_units: Width of the hidden fully connected layer (paper: 512).
        fc_activation: Activation of the hidden fully connected layer. The
            default ``None`` is linear, which matches the previous hard-coded
            behavior.

    Returns:
        An uncompiled ``keras.Model`` mapping the "stft" input to the
        "predictions" output.
    """
    inputs = keras.Input(shape=src.shape, name="stft")

    # @paper: spectrogram-based CNN that receives the (log) spectrogram matrix as input.
    # @kapre: abs(Spectrogram) in a shape of 2D data, i.e.
    #   (None, n_channel, n_freq, n_time) if 'channels_first',
    #   (None, n_freq, n_time, n_channel) if 'channels_last'.
    x = Spectrogram(
        n_dft=n_dft,
        n_hop=n_hop,
        input_shape=src.shape,
        trainable_kernel=True,
        name="static_stft",
        image_data_format=data_format,
        return_decibel_spectrogram=True,
    )(inputs)

    # Swap the frequency and time axes while keeping the channel axis in place,
    # presumably to match the paper's layout.
    # TODO: dig into this. TensorFlow's Conv2D may only support
    # channels_first on GPU.
    if data_format == "channels_first":
        # (n_channel, n_freq, n_time) -> (n_channel, n_time, n_freq)
        x = keras.layers.Permute((1, 3, 2))(x)
    else:
        # (n_freq, n_time, n_channel) -> (n_time, n_freq, n_channel)
        x = keras.layers.Permute((2, 1, 3))(x)

    # Convolution stack as configured by the selected architecture.
    for arch_layer in arch_layers:
        x = keras.layers.Conv2D(
            arch_layer.filters,
            arch_layer.window_size,
            strides=arch_layer.strides,
            activation=arch_layer.activation,
            data_format=data_format,
            padding="same",
        )(x)

    # Flatten down to a single dimension.
    x = keras.layers.Flatten()(x)

    # @paper: FC-512.
    # NOTE(review): with the default linear activation, this layer and the
    # output layer's pre-activation compose into a single linear map. Confirm
    # whether the paper intends a nonlinearity (e.g. ReLU) here.
    x = keras.layers.Dense(fc_units, activation=fc_activation)(x)

    # @paper: FC-368 (sigmoid) output, trained with binary cross-entropy loss.
    outputs = keras.layers.Dense(n_outputs, activation="sigmoid", name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs)
def get_model(
    model_name: str, inputs: int, outputs: int, data_format: str = "channels_last"
) -> keras.Model:
    """Standard callback to get a model ready to train.

    Selects the convolution stack registered under ``model_name`` in
    ``layers_map`` and builds the spectrogram CNN with ``assemble_model``.
    Names not present in ``layers_map`` fall back to the "C1" architecture,
    and a warning is printed.

    Args:
        model_name: Key into ``layers_map`` that selects the architecture.
        inputs: Number of input samples. The model input shape is (1, inputs).
        outputs: Number of model outputs.
        data_format: Forwarded unchanged to ``assemble_model``.

    Returns:
        The ``keras.Model`` built by ``assemble_model``.
    """
    fallback_layers = layers_map.get("C1")
    if model_name not in layers_map:
        print(
            f"Warning: {model_name} is not compatible with the spectrogram model. C1 Architecture will be used instead."
        )
        chosen_layers = fallback_layers
    else:
        chosen_layers = layers_map.get(model_name)

    # A zero-filled placeholder; assemble_model only reads its shape.
    placeholder = np.zeros([1, inputs])
    return assemble_model(
        placeholder,
        n_outputs=outputs,
        arch_layers=chosen_layers,
        data_format=data_format,
    )
if __name__ == "__main__":
    from models.launch import train_model, inference
    from models.runner import standard_run_parser

    # Parse the standard CLI options into a plain dict of keyword arguments.
    setup = vars(standard_run_parser().parse_args())
    print(setup)

    # Tag the run so the launcher can distinguish this model type for reshaping.
    setup["model_type"] = "STFT"

    # Train the model, then run inference with it and its parameters file.
    model, parameters_file = train_model(model_callback=get_model, **setup)
    file_path, csv_path = inference(model, parameters_file)
    for produced_path in (file_path, csv_path):
        print(produced_path)