import os

import click
import gin
import torch
from scipy.io import wavfile
from tqdm import tqdm

from neural_waveshaping_synthesis.data.urmp import URMPDataset
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
from neural_waveshaping_synthesis.utils import make_dir_if_not_exists


@click.command()
@click.option("--model-gin", prompt="Model .gin file")
@click.option("--model-checkpoint", prompt="Model checkpoint")
@click.option("--dataset-root", prompt="Dataset root directory")
@click.option("--dataset-split", default="test")
@click.option("--output-path", default="audio_output")
@click.option("--load-data-to-memory", default=False)
@click.option("--device", default="cuda:0")
@click.option("--batch-size", default=8)
@click.option("--num-workers", default=16)
@click.option("--use-fastnewt", is_flag=True)
def main(
    model_gin,
    model_checkpoint,
    dataset_root,
    dataset_split,
    output_path,
    load_data_to_memory,
    device,
    batch_size,
    num_workers,
    use_fastnewt
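):
    """Resynthesise a URMP dataset split with a trained NeuralWaveshaping model.

    For every item in the split, the target audio and the model output are
    written to the output path as <name>.target.wav and <name>.output.wav.
    """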
    gin.parse_config_file(model_gin)
    make_dir_if_not_exists(output_path)

    data = URMPDataset(dataset_root, dataset_split, load_data_to_memory)
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, num_workers=num_workers
    )

    device = torch.device(device)
    model = NeuralWaveshaping.load_from_checkpoint(model_checkpoint)
    model.eval()

    # Optionally wrap the trained NEWT module in FastNEWT before inference.
    if use_fastnewt:
        model.newt = FastNEWT(model.newt)

    model = model.to(device)

    for batch in tqdm(data_loader):
        with torch.no_grad():
            f0 = batch["f0"].float().to(device)
            control = batch["control"].float().to(device)
            output = model(f0, control)

        target_audio = batch["audio"].float().numpy()
        output_audio = output.cpu().numpy()
        for j in range(output_audio.shape[0]):
            name = batch["name"][j]
            target_name = f"{name}.target.wav"
            output_name = f"{name}.output.wav"
            wavfile.write(
                os.path.join(output_path, target_name),
                model.sample_rate,
                target_audio[j],
            )
            wavfile.write(
                os.path.join(output_path, output_name),
                model.sample_rate,
                output_audio[j],
            )


if __name__ == "__main__":
    main()