neural-waveshaping-synthesis / scripts /resynthesise_dataset.py
akhaliq3
spaces demo
607ecc1
import os
import click
import gin
from scipy.io import wavfile
from tqdm import tqdm
import torch
from neural_waveshaping_synthesis.data.urmp import URMPDataset
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
from neural_waveshaping_synthesis.utils import make_dir_if_not_exists
@click.command()
@click.option(
    "--model-gin",
    prompt="Model .gin file",
    help="Path to the .gin config file parsed before the model is loaded.",
)
@click.option(
    "--model-checkpoint",
    prompt="Model checkpoint",
    help="Path to the checkpoint the model is restored from.",
)
@click.option(
    "--dataset-root",
    prompt="Dataset root directory",
    help="Root directory handed to URMPDataset.",
)
@click.option(
    "--dataset-split",
    default="test",
    help="Name of the dataset split to resynthesise.",
)
@click.option(
    "--output-path",
    default="audio_output",
    help="Directory the .wav files are written to (created if missing).",
)
@click.option(
    "--load-data-to-memory",
    default=False,
    help="Boolean value forwarded to URMPDataset as its third argument.",
)
@click.option(
    "--device",
    default="cuda:0",
    help="torch device string the model and inputs are moved to.",
)
@click.option("--batch-size", default=8, help="DataLoader batch size.")
# "--num_workers" is the historical spelling; it is kept as an alias so existing
# command lines keep working, while "--num-workers" matches the other options.
@click.option(
    "--num-workers",
    "--num_workers",
    "num_workers",
    default=16,
    help="Number of DataLoader worker processes.",
)
@click.option(
    "--use-fastnewt",
    is_flag=True,
    help="Replace the model's NEWT module with FastNEWT before inference.",
)
def main(
    model_gin,
    model_checkpoint,
    dataset_root,
    dataset_split,
    output_path,
    load_data_to_memory,
    device,
    batch_size,
    num_workers,
    use_fastnewt,
):
    """Resynthesise a dataset split with a trained model.

    For every item in the split, two files are written to the output
    directory: "<name>.target.wav" holding the dataset audio and
    "<name>.output.wav" holding the model output.
    """
    gin.parse_config_file(model_gin)
    make_dir_if_not_exists(output_path)

    data = URMPDataset(dataset_root, dataset_split, load_data_to_memory)
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, num_workers=num_workers
    )

    device = torch.device(device)
    model = NeuralWaveshaping.load_from_checkpoint(model_checkpoint)
    model.eval()
    if use_fastnewt:
        model.newt = FastNEWT(model.newt)
        # eval() only reaches submodules attached at the time it is called and
        # a freshly constructed module starts in training mode, so switch again
        # now that the replacement is attached. This is a no-op unless FastNEWT
        # contains mode-dependent layers.
        model.eval()
    model = model.to(device)

    # Read once; the same rate is used for every file written below.
    sample_rate = model.sample_rate

    for batch in tqdm(data_loader):
        # Gradients are not needed for resynthesis.
        with torch.no_grad():
            f0 = batch["f0"].float().to(device)
            control = batch["control"].float().to(device)
            output = model(f0, control)

        target_audio = batch["audio"].float().numpy()
        output_audio = output.cpu().numpy()

        # One target/output pair of files per item in the batch.
        for j in range(output_audio.shape[0]):
            name = batch["name"][j]
            for kind, audio in (("target", target_audio), ("output", output_audio)):
                wavfile.write(
                    os.path.join(output_path, "%s.%s.wav" % (name, kind)),
                    sample_rate,
                    audio[j],
                )
# Run the click command only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()