Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import io | |
| import random | |
| import subprocess as sp | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| from scipy.io import wavfile | |
def i16_pcm(wav):
    """Convert a waveform to 16-bit integer PCM.

    Input that is already int16 (a torch tensor or a NumPy array) is returned
    unchanged. Any other input is scaled by 2**15, clamped to the int16 range
    and cast with ``.short()``; that path uses tensor-only methods, so it
    expects a torch tensor (presumably float audio in [-1, 1) -- confirm
    against callers).
    """
    # Tensors carry torch dtypes, which never compare equal to `np.int16`;
    # checking `torch.int16` as well keeps int16 tensors from being rescaled
    # (and overflowing) a second time. The NumPy check is tested first so
    # int16 arrays keep passing through exactly as before.
    if wav.dtype == np.int16 or wav.dtype == torch.int16:
        return wav
    return (wav * 2**15).clamp_(-2**15, 2**15 - 1).short()
def f32_pcm(wav):
    """Convert a waveform to floating-point PCM.

    Input that is already floating point (a torch tensor or a NumPy array) is
    returned unchanged. Any other input is cast with ``.float()`` and divided
    by 2**15; that path uses a tensor-only method, so it expects a torch
    tensor (presumably int16 PCM -- confirm against callers).
    """
    # `np.float` no longer exists in NumPy >= 1.24, and it never matched a
    # torch dtype anyway, so the dtype is inspected explicitly per container.
    if torch.is_tensor(wav):
        if wav.is_floating_point():
            return wav
    elif np.issubdtype(wav.dtype, np.floating):
        return wav
    return wav.float() / 2**15
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    Every item is cropped along its last axis to
    ``int((1 - 0.01 * max_tempo) * in_length)`` samples, whether or not the
    augmentation was applied, so all returned items share the same length.

    Args:
        dataset: indexable dataset; each item is iterated stream by stream and
            every stream is sliced as a 2-D tensor (presumably
            ``(channels, time)`` -- confirm against the dataset).
        proba: probability of applying the pitch / tempo change to an item.
        max_pitch: the pitch shift is drawn uniformly from the integers in
            ``[-max_pitch, max_pitch]`` (semitones, per `repitch`).
        max_tempo: the tempo change (in percent) is clamped to
            ``[-max_tempo, max_tempo]``; also sets the output crop length.
        tempo_std: standard deviation of the Gaussian tempo change.
        vocals: indices of the streams passed to `repitch` with
            ``voice=True``. Defaults to ``[3]``.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12, tempo_std=5, vocals=None):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        # A `None` sentinel plus a per-instance copy avoids sharing one
        # mutable default list between every wrapper instance.
        self.vocals = [3] if vocals is None else list(vocals)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return item `index`, randomly repitched, cropped to the common length."""
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Shortest length worth keeping once the tempo may have been sped up
        # by up to `max_tempo` percent.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            # One pitch / tempo delta is drawn per item and shared by all
            # of its streams.
            delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
            delta_tempo = random.gauss(0, self.tempo_std)
            delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
            outs = []
            for idx, stream in enumerate(streams):
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    Change the pitch and tempo of `wav` using the `soundstretch` tool.

    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: 2-D tensor; it is transposed before being written as a WAV file
            (presumably ``(channels, time)`` -- confirm against callers).
        pitch: pitch change, forwarded as ``-pitch=``.
        tempo: tempo change, forwarded as ``-tempo=`` with 6 decimals.
        voice: if True, add the ``-speech`` flag.
        quick: if True, add the ``-quick`` flag.
        samplerate: sample rate written to the input WAV and expected back.

    Returns:
        The processed audio as a float tensor, transposed back to the input
        layout.

    Raises:
        RuntimeError: if `soundstretch` exits with a non-zero status.
    """
    in_ = io.BytesIO()
    wavfile.write(in_, samplerate, i16_pcm(wav).t().numpy())
    # The context manager guarantees the temporary output file is closed and
    # deleted even if `soundstretch` or the read-back fails.
    with tempfile.NamedTemporaryFile(suffix=".wav") as outfile:
        command = [
            "soundstretch",
            "stdin",
            outfile.name,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command += ["-quick"]
        if voice:
            command += ["-speech"]
        try:
            sp.run(command, capture_output=True, input=in_.getvalue(), check=True)
        except sp.CalledProcessError as error:
            raise RuntimeError(
                f"Could not change bpm because {error.stderr.decode('utf-8')}") from error
        sr, wav = wavfile.read(outfile.name)
        # Copy while the file still exists so the data is an independent
        # array before it is handed to torch.
        wav = wav.copy()
    wav = f32_pcm(torch.from_numpy(wav).t())
    assert sr == samplerate
    return wav