import os
import time
import warnings
warnings.filterwarnings("ignore")

import gin
import numpy as np
from scipy.io import wavfile
import torch

from neural_waveshaping_synthesis.data.utils.loudness_extraction import extract_perceptual_loudness
from neural_waveshaping_synthesis.data.utils.mfcc_extraction import extract_mfcc
from neural_waveshaping_synthesis.data.utils.f0_extraction import extract_f0_with_crepe
from neural_waveshaping_synthesis.data.utils.preprocess_audio import preprocess_audio, convert_to_float32_audio, make_monophonic, resample_audio
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
import gradio as gr

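# Fetch example audio clips for the demo UI.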
torch.hub.download_url_to_file('https://benhayes.net/assets/audio/nws_examples/tt/tt1_in.wav', 'test1.wav')
torch.hub.download_url_to_file('https://benhayes.net/assets/audio/nws_examples/tt/tt2_in.wav', 'test2.wav')
torch.hub.download_url_to_file('https://benhayes.net/assets/audio/nws_examples/tt/tt3_in.wav', 'test3.wav')


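# Register the compute device as a gin constant; re-registration raises ValueError.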
try:
    gin.constant("device", "cuda" if torch.cuda.is_available() else "cpu")
except ValueError:
    # The constant was already registered; safe to ignore.
    pass


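# Load model and data-pipeline hyperparameters from gin configs.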
gin.parse_config_file("gin/models/newt.gin")
gin.parse_config_file("gin/data/urmp_4second_crepe.gin")

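# Map UI instrument names to checkpoint directory names.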
checkpoints = dict(Violin="vn", Flute="fl", Trumpet="tpt")

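# Inference runs on the CPU by default; set use_gpu = True to use CUDA.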
use_gpu = False
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")



def inference(wav, instrument):
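    """Timbre transfer: re-synthesize the input audio as the selected instrument."""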
    checkpoint_path = os.path.join("checkpoints/nws", checkpoints[instrument])

    # Load the pretrained model and the training-set normalization statistics.
    model = NeuralWaveshaping.load_from_checkpoint(
        os.path.join(checkpoint_path, "last.ckpt")).to(device)
    original_newt = model.newt
    model.eval()
    data_mean = np.load(os.path.join(checkpoint_path, "data_mean.npy"))
    data_std = np.load(os.path.join(checkpoint_path, "data_std.npy"))

    # Read the upload, convert to mono float32, and resample to the model's rate.
    rate, audio = wavfile.read(wav.name)
    audio = convert_to_float32_audio(make_monophonic(audio))
    audio = resample_audio(audio, rate, model.sample_rate)

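    # Extract control signals: F0 via CREPE and perceptual loudness.
    # full_model=False selects the smaller, faster CREPE variant.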
    use_full_crepe_model = False 
    with torch.no_grad():
        f0, confidence = extract_f0_with_crepe(
            audio,
            full_model=use_full_crepe_model,
            maximum_frequency=1000)
        loudness = extract_perceptual_loudness(audio)



    # Control-signal transformation parameters (defaults from the original demo).
    octave_shift = 1          # shift pitch up one octave
    loudness_scale = 0.5      # attenuate the loudness control
    loudness_floor = 0        # gate loudness below this threshold
    loudness_conf_filter = 0  # zero loudness where pitch confidence is below this
    pitch_conf_filter = 0     # zero F0 where pitch confidence is below this
    pitch_smoothing = 0       # moving-average half-width in frames (0 = off)
    loudness_smoothing = 0    # moving-average half-width in frames (0 = off)

    with torch.no_grad():
        # Gate low-confidence frames, then apply octave shift, floor, and scale.
        f0_filtered = f0 * (confidence > pitch_conf_filter)
        loudness_filtered = loudness * (confidence > loudness_conf_filter)
        f0_shifted = f0_filtered * (2 ** octave_shift)
        loudness_floored = (loudness_filtered - loudness_floor) * (loudness_filtered > loudness_floor)
        loudness_scaled = loudness_floored * loudness_scale

        # Normalize loudness with the training-set statistics.
        loud_norm = (loudness_scaled - data_mean[1]) / data_std[1]

        f0_t = torch.tensor(f0_shifted, device=device).float()
        loud_norm_t = torch.tensor(loud_norm, device=device).float()

        # Optional moving-average smoothing of the control signals.
        if pitch_smoothing != 0:
            f0_t = torch.nn.functional.conv1d(
                f0_t.expand(1, 1, -1),
                torch.ones(1, 1, pitch_smoothing * 2 + 1, device=device)
                / (pitch_smoothing * 2 + 1),
                padding=pitch_smoothing,
            ).squeeze()

        if loudness_smoothing != 0:
            loud_norm_t = torch.nn.functional.conv1d(
                loud_norm_t.expand(1, 1, -1),
                torch.ones(1, 1, loudness_smoothing * 2 + 1, device=device)
                / (loudness_smoothing * 2 + 1),
                padding=loudness_smoothing,
            ).squeeze()

        # Normalize F0 with the training-set statistics (after any smoothing).
        f0_norm_t = ((f0_t - data_mean[0]) / data_std[0]).float()

        control = torch.stack((f0_norm_t, loud_norm_t), dim=0)

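    # Swap the NEWT shapers for their lookup-table approximation for fast inference.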
    model.newt = FastNEWT(original_newt)

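    # Synthesize audio from the control signals and time the forward pass.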
    with torch.no_grad():
        start_time = time.time()
        out = model(f0_t.expand(1, 1, -1), control.unsqueeze(0))
        run_time = time.time() - start_time

    # Real-time factor: seconds of audio produced per second of compute.
    rtf = (audio.shape[-1] / model.sample_rate) / run_time
    print("real-time factor: %.2fx" % rtf)

    wavfile.write("test.wav", model.sample_rate, out.detach().cpu().numpy().T)
    return "test.wav"

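# Gradio interface (legacy gr.inputs / gr.outputs API).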
inputs = [
    gr.inputs.Audio(label="input audio", type="file"),
    gr.inputs.Dropdown(["Violin", "Flute", "Trumpet"], type="value", default="Violin", label="Instrument"),
]
outputs = gr.outputs.Audio(label="output audio", type="file")


title = "neural waveshaping synthesis"
description = "Demo for neural waveshaping synthesis: efficient neural audio synthesis in the waveform domain, applied to timbre transfer. Upload your own audio or click one of the examples to load it. Input audio should be a WAV file similar to the examples below. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2107.05050'>neural waveshaping synthesis</a> | <a href='https://github.com/ben-hayes/neural-waveshaping-synthesis'>Github Repo</a></p>"

examples = [
 ['test1.wav'],
 ['test2.wav'],
 ['test3.wav']
]

gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()
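
# Example (hypothetical): calling inference() directly, bypassing the UI.
# Gradio passes a tempfile-like object, so any object with a .name attribute works:
#
#   from types import SimpleNamespace
#   output_path = inference(SimpleNamespace(name="test1.wav"), "Violin")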