# Spaces: Runtime error
import os

import click
import gin
import torch
from scipy.io import wavfile
from tqdm import tqdm

from neural_waveshaping_synthesis.data.urmp import URMPDataset
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
from neural_waveshaping_synthesis.utils import make_dir_if_not_exists
def main(
    model_gin,
    model_checkpoint,
    dataset_root,
    dataset_split,
    output_path,
    load_data_to_memory,
    device,
    batch_size,
    num_workers,
    use_fastnewt,
):
    """Resynthesise a URMP dataset split with a trained model and save WAV files.

    For every item in the split, two files are written to ``output_path``:
    ``<name>.target.wav`` (the batch's ``"audio"`` entry) and
    ``<name>.output.wav`` (the model output for the item's ``"f0"`` and
    ``"control"`` entries). Both are written at ``model.sample_rate``.

    Args:
        model_gin: Path to the gin config file parsed before the model loads.
        model_checkpoint: Checkpoint path passed to
            ``NeuralWaveshaping.load_from_checkpoint``.
        dataset_root: Root directory forwarded to ``URMPDataset``.
        dataset_split: Split name forwarded to ``URMPDataset``.
        output_path: Directory receiving the WAV files; created if missing.
        load_data_to_memory: Forwarded to ``URMPDataset``.
        device: Anything ``torch.device`` accepts, e.g. ``"cpu"``/``"cuda:0"``.
        batch_size: ``DataLoader`` batch size.
        num_workers: ``DataLoader`` worker-process count.
        use_fastnewt: If true, replace ``model.newt`` with ``FastNEWT(model.newt)``
            before inference.
    """
    gin.parse_config_file(model_gin)
    make_dir_if_not_exists(output_path)

    data = URMPDataset(dataset_root, dataset_split, load_data_to_memory)
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, num_workers=num_workers
    )

    device = torch.device(device)
    # map_location remaps stored tensors onto the requested device, so a
    # checkpoint saved on a GPU can still be loaded when running on a CPU-only
    # machine (assumes the Lightning-style load_from_checkpoint signature).
    model = NeuralWaveshaping.load_from_checkpoint(
        model_checkpoint, map_location=device
    )
    if use_fastnewt:
        model.newt = FastNEWT(model.newt)
    # eval() must run after any submodule swap: a newly constructed module
    # starts in training mode and would otherwise stay that way.
    model.eval()
    model = model.to(device)
    sample_rate = model.sample_rate

    # One no_grad context for the whole inference loop instead of per batch.
    with torch.no_grad():
        for batch in tqdm(data_loader):
            f0 = batch["f0"].float().to(device)
            control = batch["control"].float().to(device)
            output_audio = model(f0, control).cpu().numpy()
            target_audio = batch["audio"].float().numpy()

            for j in range(output_audio.shape[0]):
                name = batch["name"][j]
                wavfile.write(
                    os.path.join(output_path, f"{name}.target.wav"),
                    sample_rate,
                    target_audio[j],
                )
                wavfile.write(
                    os.path.join(output_path, f"{name}.output.wav"),
                    sample_rate,
                    output_audio[j],
                )
if __name__ == "__main__":
    # ``main`` has ten required parameters, so calling it bare raises
    # TypeError. ``click`` is imported at the top of the file but otherwise
    # unused; build the command-line interface here, local to the entry point,
    # so ``main`` stays a plain function for programmatic callers.
    # NOTE(review): the option defaults below are reviewer choices, not taken
    # from the project — confirm them against the intended usage.

    @click.command()
    @click.option(
        "--model-gin",
        required=True,
        type=click.Path(exists=True, dir_okay=False),
        help="Gin config file for the model.",
    )
    @click.option(
        "--model-checkpoint",
        required=True,
        type=click.Path(exists=True, dir_okay=False),
        help="Model checkpoint to load.",
    )
    @click.option(
        "--dataset-root",
        required=True,
        type=click.Path(exists=True, file_okay=False),
        help="Root directory of the dataset.",
    )
    @click.option("--dataset-split", required=True, help="Dataset split to resynthesise.")
    @click.option("--output-path", required=True, help="Directory for the WAV output.")
    @click.option(
        "--load-data-to-memory/--no-load-data-to-memory",
        default=False,
        show_default=True,
    )
    @click.option(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        show_default=True,
    )
    @click.option("--batch-size", type=int, default=8, show_default=True)
    @click.option("--num-workers", type=int, default=0, show_default=True)
    @click.option("--use-fastnewt/--no-use-fastnewt", default=False, show_default=True)
    def cli(**kwargs):
        """Resynthesise a dataset split and write target/output WAV pairs."""
        main(**kwargs)

    cli()