r3gm's picture
Upload 288 files
7bc29af
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import io
import random
import subprocess as sp
import tempfile
import numpy as np
import torch
from scipy.io import wavfile
def i16_pcm(wav):
    """Convert a waveform to 16-bit integer PCM.

    Input that is already int16 is returned unchanged. Any other input is
    treated as floating-point audio nominally in [-1, 1]. It is scaled by
    2**15, clamped to the int16 range and cast with ``.short()``.

    Args:
        wav: torch tensor of audio samples.

    Returns:
        A ``torch.int16`` tensor, or ``wav`` itself if it is already int16.
    """
    # A torch tensor's dtype never compares equal to the numpy scalar type
    # ``np.int16``. Also check ``torch.int16`` so int16 tensors are passed
    # through instead of being rescaled, which would overflow.
    if wav.dtype == np.int16 or wav.dtype == torch.int16:
        return wav
    # Clamp before casting so +1.0 (-> 32768) saturates at 32767 instead of
    # wrapping around.
    return (wav * 2**15).clamp_(-2**15, 2**15 - 1).short()
def f32_pcm(wav):
    """Convert a PCM waveform tensor to float32 samples.

    Floating-point input is assumed to be normalized already and is only
    cast to float32. Any other input is treated as 16-bit integer PCM and
    divided by 2**15, mapping the int16 range onto [-1, 1).

    Args:
        wav: torch tensor of audio samples.

    Returns:
        A ``torch.float32`` tensor.
    """
    # ``np.float`` was removed in NumPy 1.24, and it never matched a torch
    # dtype anyway. Ask the tensor directly whether it is floating point.
    if wav.is_floating_point():
        return wav.float()
    return wav.float() / 2**15
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    With probability ``proba``, every stream of an item is passed through
    `repitch` using the same random pitch and tempo deltas. Whether or not
    an item is augmented, its last axis is cropped to
    ``int((1 - 0.01 * max_tempo) * length)`` samples.

    NOTE(review): items presumably have shape (sources, channels, time). The
    per-stream crop ``stream[:, :out_length]`` implies each stream is 2-D.
    Confirm against the wrapped dataset.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=None):
        """
        Args:
            dataset: indexable dataset whose items are stacked stream tensors.
            proba: probability of augmenting a given item.
            max_pitch: maximum absolute pitch shift, in semitones (integer
                shifts are drawn uniformly from [-max_pitch, max_pitch]).
            max_tempo: maximum absolute tempo change, in percent.
            tempo_std: standard deviation of the Gaussian tempo change, in
                percent, before clipping to [-max_tempo, max_tempo].
            vocals: indices of streams repitched with ``voice=True``.
                Defaults to ``[3]``.
        """
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        # Build a fresh list per instance. A list literal as the default
        # argument would be shared by every wrapper created without `vocals`.
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # The output length is fixed regardless of augmentation. It is
        # presumably chosen to be no longer than the shortest stretch at
        # +max_tempo% (in_length / (1 + max_tempo / 100)), so the crop below
        # never runs past the end of a repitched stream. Confirm.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)
        if random.random() < self.proba:
            delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
            delta_tempo = random.gauss(0, self.tempo_std)
            # Clip the Gaussian draw to [-max_tempo, max_tempo].
            delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
            outs = [
                repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)[:, :out_length]
                for idx, stream in enumerate(streams)
            ]
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: torch tensor of audio. It is transposed before being written as
            a WAV, so presumably shaped (channels, time). Confirm.
        pitch: pitch shift in semitones.
        tempo: relative tempo change in percent.
        voice: if True, pass ``-speech`` to soundstretch.
        quick: if True, pass ``-quick`` to soundstretch.
        samplerate: sample rate written to, and expected back from,
            soundstretch.

    Returns:
        A float32 tensor read back from soundstretch's output and transposed
        the same way as the input.

    Raises:
        RuntimeError: if soundstretch exits with a non-zero status, or its
            output has a sample rate other than ``samplerate``.
    """
    in_ = io.BytesIO()
    # scipy writes arrays laid out as (time, channels), hence the transpose.
    wavfile.write(in_, samplerate, i16_pcm(wav).t().numpy())
    # The context manager closes and deletes the temporary output file even
    # if soundstretch fails or reading the result raises.
    with tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
        command = [
            "soundstretch",
            "stdin",
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, input=in_.getvalue(), check=True)
        except sp.CalledProcessError as error:
            stderr = error.stderr.decode('utf-8', errors='replace')
            raise RuntimeError(f"Could not change bpm because {stderr}") from error
        sr, out = wavfile.read(outfile.name)
    # Raise explicitly instead of using `assert`, which `python -O` strips.
    if sr != samplerate:
        raise RuntimeError(
            f"soundstretch output has sample rate {sr}, expected {samplerate}")
    # Copy so the tensor owns its own buffer rather than aliasing the array
    # returned by wavfile.read.
    return f32_pcm(torch.from_numpy(out.copy()).t())