# Scraped page header residue (not code): "Spaces: Runtime error Runtime error".
import time

import click
import gin
import numpy as np
import torch
from scipy.stats import describe
from tqdm import trange

from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
def main(
    gin_file,
    num_iters,
    batch_size,
    device,
    length_in_seconds,
    sample_rate,
    control_hop,
    use_fast_newt,
):
    """Benchmark NeuralWaveshaping forward-pass speed on random inputs.

    Parses the gin config, builds a ``NeuralWaveshaping`` model (optionally
    wrapping its ``newt`` submodule in ``FastNEWT``), runs ``num_iters``
    forward passes under ``torch.no_grad()`` and prints:

    * ``scipy.stats.describe`` of the per-pass wall-clock times (seconds),
    * the mean and 90th-percentile real-time factor (RTF), computed as
      pass time divided by ``length_in_seconds``.

    Args:
        gin_file: Path of the gin config file passed to
            ``gin.parse_config_file``.
        num_iters: Number of timed forward passes; must be >= 1.
        batch_size: Leading dimension of the random input tensors.
        device: Torch device spec (e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``).
        length_in_seconds: Audio duration each pass is sized for; together
            with ``sample_rate`` and ``control_hop`` it sets the number of
            input frames.
        sample_rate: Sample rate used to size the inputs.
        control_hop: Divisor converting samples to input frames.
        use_fast_newt: If true, replace ``model.newt`` with
            ``FastNEWT(model.newt)`` before timing.

    Raises:
        ValueError: If ``num_iters`` is less than 1 (nothing to summarise).
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    gin.parse_config_file(gin_file)

    num_frames = sample_rate * length_in_seconds // control_hop
    # Random inputs shaped (batch, 2, frames) and (batch, 1, frames); the
    # model is called as model(f0, control) below.
    dummy_control = torch.rand(batch_size, 2, num_frames, device=device)
    dummy_f0 = torch.rand(batch_size, 1, num_frames, device=device)

    model = NeuralWaveshaping()
    if use_fast_newt:
        model.newt = FastNEWT(model.newt)
    model.eval()
    model = model.to(device)

    # CUDA kernels are launched asynchronously: without synchronising before
    # and after the call, the timer would mostly measure launch overhead.
    torch_device = torch.device(device)
    is_cuda = torch_device.type == "cuda"

    times = []
    with torch.no_grad():
        for _ in trange(num_iters):
            if is_cuda:
                torch.cuda.synchronize(torch_device)
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            model(dummy_f0, dummy_control)
            if is_cuda:
                torch.cuda.synchronize(torch_device)
            times.append(time.perf_counter() - start_time)

    print(describe(times))
    # NOTE(review): RTF is per pass of the whole batch; with batch_size > 1
    # each pass covers batch_size * length_in_seconds of audio.
    rtfs = np.array(times) / length_in_seconds
    print(f"Mean RTF: {np.mean(rtfs):.4f}")
    print(f"90th percentile RTF: {np.percentile(rtfs, 90):.4f}")
if __name__ == "__main__":
    # main() takes eight required parameters, so calling it bare raised a
    # TypeError. Parse them from the command line with click instead, keeping
    # main itself a plain function that other Python code can still call.
    # NOTE(review): the option defaults below are reviewer-chosen — adjust to
    # the intended benchmark configuration.

    @click.command()
    @click.option("--gin-file", prompt="Model config gin file",
                  help="Gin config file for NeuralWaveshaping.")
    @click.option("--num-iters", default=100, type=click.IntRange(min=1),
                  help="Number of timed forward passes.")
    @click.option("--batch-size", default=1, type=click.IntRange(min=1),
                  help="Batch size of the dummy inputs.")
    @click.option("--device", default="cpu",
                  help="Torch device, e.g. cpu or cuda:0.")
    @click.option("--length-in-seconds", default=4, type=click.IntRange(min=1),
                  help="Audio length each pass is sized for (integer seconds).")
    @click.option("--sample-rate", default=16000, type=click.IntRange(min=1),
                  help="Sample rate used to size the inputs.")
    @click.option("--control-hop", default=128, type=click.IntRange(min=1),
                  help="Samples per input frame.")
    @click.option("--use-fast-newt", is_flag=True,
                  help="Wrap model.newt in FastNEWT before timing.")
    def _cli(**kwargs):
        """Forward the parsed command-line options to main() by keyword."""
        main(**kwargs)

    _cli()