Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import io | |
import random | |
import subprocess as sp | |
import tempfile | |
import numpy as np | |
import torch | |
from scipy.io import wavfile | |
def i16_pcm(wav):
    """Convert audio to 16-bit integer PCM.

    Args:
        wav: audio samples, expected to be a `torch.Tensor`. Non-int16 input
            is presumably floating point normalized to [-1, 1) -- values
            outside that range are saturated.

    Returns:
        `wav` itself when it is already int16, otherwise a new int16 tensor
        holding the samples scaled by 2**15 and clamped to the int16 range.
    """
    # A torch dtype never compares equal to `np.int16`, so the torch dtype has
    # to be tested explicitly. The NumPy comparison is kept (and evaluated
    # first) so that int16 ndarrays are still passed through unchanged.
    if wav.dtype == np.int16 or wav.dtype == torch.int16:
        return wav
    # The product is a fresh tensor, so the in-place clamp never mutates `wav`.
    return (wav * 2**15).clamp_(-2**15, 2**15 - 1).short()
def f32_pcm(wav):
    """Convert audio to floating point PCM.

    Args:
        wav: audio samples, expected to be a `torch.Tensor`. Non-floating
            input is presumably int16 PCM, since it is divided by 2**15.

    Returns:
        `wav` itself when it is already floating point (its dtype is left
        untouched, it is not cast to float32), otherwise a new float32 tensor
        holding the samples divided by 2**15.
    """
    dtype = wav.dtype
    # `np.float` no longer exists in recent NumPy releases and never matched
    # torch dtypes anyway; `is_floating_point` is the torch-native test. The
    # float64 comparison keeps the pass-through for float64 ndarrays, whose
    # dtype has no `is_floating_point` attribute.
    if getattr(dtype, "is_floating_point", False) or dtype == np.float64:
        return wav
    return wav.float() / 2**15
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    Every item is cropped along its last dimension to
    `(1 - max_tempo / 100)` times its original length, whether or not it was
    augmented, so that all returned items share the same length.

    Args:
        dataset: indexable dataset; `dataset[index]` must expose `.shape`,
            support slicing and be iterable over its first dimension
            (presumably a tensor of shape (sources, channels, time) --
            TODO confirm against the wrapped dataset).
        proba: probability of augmenting a given item.
        max_pitch: maximum absolute pitch shift in semitones; the shift is
            drawn uniformly from the integers in [-max_pitch, max_pitch].
        max_tempo: maximum absolute tempo change, in percent.
        tempo_std: standard deviation of the Gaussian the tempo change is
            drawn from, before clipping to [-max_tempo, max_tempo].
        vocals: indexes (along the first dimension of an item) of the streams
            processed with the `voice` option of `repitch`.
            Defaults to `[3]`.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, tempo_std=5, vocals=None):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        # Build the default per instance rather than sharing one mutable list
        # between every wrapper created without an explicit `vocals`.
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        """Return the number of items of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return item `index`, randomly repitched, cropped to a fixed length."""
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Shortest length any tempo change allowed here can produce
        # (speeding up by `max_tempo` percent keeps more than this).
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            # The same pitch / tempo change is applied to every stream of the item.
            delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
            delta_tempo = random.gauss(0, self.tempo_std)
            delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
            outs = []
            for idx, stream in enumerate(streams):
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: audio tensor; it is transposed before being written as a wav
            file, so presumably shaped (channels, time) -- TODO confirm.
        pitch: pitch change in semitones.
        tempo: relative tempo change in percent.
        voice: if True, pass `-speech` to `soundstretch`.
        quick: if True, pass `-quick` to `soundstretch`.
        samplerate: sample rate used to write the input; the output is
            asserted to have the same rate.

    Returns:
        The processed audio as a floating point tensor, transposed back to
        the layout of the input.

    Raises:
        RuntimeError: if `soundstretch` exits with a non-zero status.
    """
    # The input is piped through stdin, only the output needs a real file.
    in_ = io.BytesIO()
    wavfile.write(in_, samplerate, i16_pcm(wav).t().numpy())

    # The context manager closes (and thereby deletes) the temporary file
    # deterministically, even when `soundstretch` fails.
    with tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
        command = [
            "soundstretch",
            "stdin",
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, input=in_.getvalue(), check=True)
        except sp.CalledProcessError as error:
            raise RuntimeError(
                f"Could not change bpm because {error.stderr.decode('utf-8')}") from error
        # Must be read before leaving the `with` block deletes the file.
        sr, wav = wavfile.read(outfile.name)

    # Copy into a regular array owned by us before handing it to torch.
    wav = wav.copy()
    wav = f32_pcm(torch.from_numpy(wav).t())
    assert sr == samplerate
    return wav