Applio / rvc /lib /predictors /F0Extractor.py
blaise-tk's picture
Upload 190 files
8ef89ca verified
raw
history blame
3.21 kB
import dataclasses
import pathlib
import librosa
import numpy as np
import resampy
import torch
import torchcrepe
import torchfcpe
import os
# from tools.anyf0.rmvpe import RMVPE
from rvc.lib.predictors.RMVPE import RMVPE0Predictor
from rvc.configs.config import Config
# Module-level configuration instance, created once at import time; its
# `device` attribute is what the predictors in this module are run on.
config = Config()
@dataclasses.dataclass
class F0Extractor:
    """Extract a fundamental-frequency (F0) contour from an audio file.

    The audio at ``wav_path`` is loaded once, on construction.
    ``extract_f0`` then runs the predictor selected by ``method`` and returns
    the contour converted to cents.

    Attributes:
        wav_path: Path of the audio file handed to ``librosa.load``.
        sample_rate: Value passed to ``librosa.load`` as ``sr``; after
            construction it holds the rate that ``librosa.load`` returned.
        hop_length: Hop between frames, in samples of ``x``. Used by
            ``hop_size`` and to size the FCPE output.
        f0_min: Lower frequency bound passed to the crepe and fcpe predictors.
        f0_max: Upper frequency bound passed to the crepe and fcpe predictors.
        method: Predictor name: ``"crepe"``, ``"fcpe"`` or ``"rmvpe"``.
        x: Samples returned by ``librosa.load``; set in ``__post_init__``.
    """

    wav_path: pathlib.Path
    sample_rate: int = 44100
    hop_length: int = 512
    f0_min: int = 50
    f0_max: int = 1600
    method: str = "rmvpe"
    x: np.ndarray = dataclasses.field(init=False)

    def __post_init__(self):
        """Load the audio file and record the sample rate librosa reports."""
        self.x, self.sample_rate = librosa.load(self.wav_path, sr=self.sample_rate)

    @property
    def hop_size(self):
        """Return ``hop_length / sample_rate`` (the hop expressed in seconds)."""
        return self.hop_length / self.sample_rate

    @property
    def wav16k(self):
        """Return ``x`` resampled from ``sample_rate`` to 16000 Hz.

        The resampling is recomputed on every access, so read it once into a
        local when it is needed more than once.
        """
        return resampy.resample(self.x, self.sample_rate, 16000)

    def extract_f0(self):
        """Run the configured predictor and return the F0 contour in cents.

        Returns:
            The predictor output converted with ``hz_to_cents`` relative to
            ``librosa.midi_to_hz(0)``. Frames whose frequency is not positive
            come back as NaN.

        Raises:
            ValueError: If ``method`` is not "crepe", "fcpe" or "rmvpe".
        """
        method = self.method
        if method == "crepe":
            f0 = self._extract_crepe()
        elif method == "fcpe":
            f0 = self._extract_fcpe()
        elif method == "rmvpe":
            f0 = self._extract_rmvpe()
        else:
            raise ValueError(f"Unknown method: {self.method}")
        return self.hz_to_cents(f0, librosa.midi_to_hz(0))

    def _extract_crepe(self):
        """Predict F0 in Hz with torchcrepe on the 16 kHz signal."""
        # Add a leading batch axis and move the samples to the configured device.
        wav16k_torch = torch.FloatTensor(self.wav16k).unsqueeze(0).to(config.device)
        f0 = torchcrepe.predict(
            wav16k_torch,
            sample_rate=16000,
            hop_length=160,
            batch_size=512,
            fmin=self.f0_min,
            fmax=self.f0_max,
            device=config.device,
        )
        # Drop the batch axis again and hand back a NumPy array.
        return f0[0].cpu().numpy()

    def _extract_fcpe(self):
        """Predict F0 in Hz with the bundled torchfcpe model on ``x``."""
        audio = librosa.to_mono(self.x)
        audio_length = len(audio)
        # One output frame per hop, plus one for the frame starting at sample 0.
        f0_target_length = (audio_length // self.hop_length) + 1
        # Add a leading and a trailing singleton axis before inference.
        audio = (
            torch.from_numpy(audio)
            .float()
            .unsqueeze(0)
            .unsqueeze(-1)
            .to(config.device)
        )
        model = torchfcpe.spawn_bundled_infer_model(device=config.device)
        f0 = model.infer(
            audio,
            sr=self.sample_rate,
            decoder_mode="local_argmax",
            threshold=0.006,
            f0_min=self.f0_min,
            f0_max=self.f0_max,
            interp_uv=False,
            output_interp_target_length=f0_target_length,
        )
        return f0.squeeze().cpu().numpy()

    def _extract_rmvpe(self):
        """Predict F0 with RMVPE on the 16 kHz signal."""
        # NOTE(review): the checkpoint path is relative, so it resolves against
        # the current working directory - confirm callers run from the repo root.
        model_rmvpe = RMVPE0Predictor(
            os.path.join("rvc", "models", "predictors", "rmvpe.pt"),
            device=config.device,
            # hop_length=80
        )
        return model_rmvpe.infer_from_audio(self.wav16k, thred=0.03)

    def plot_f0(self, f0):
        """Plot an F0 contour (in cents) with matplotlib and show the figure.

        Args:
            f0: Sequence of F0 values, one per frame, as returned by
                ``extract_f0``.
        """
        # Imported lazily so matplotlib is only needed when plotting.
        from matplotlib import pyplot as plt

        plt.figure(figsize=(10, 4))
        plt.plot(f0)
        plt.title(self.method)
        plt.xlabel("Time (frames)")
        plt.ylabel("F0 (cents)")
        plt.show()

    @staticmethod
    def hz_to_cents(F, F_ref=55.0):
        """Convert frequencies in Hz to cents relative to ``F_ref``.

        Args:
            F: Frequency value(s) in Hz; anything ``np.array`` accepts.
            F_ref: Reference frequency in Hz that maps to 0 cents.

        Returns:
            ``1200 * log2(F / F_ref)`` as floats. Non-positive frequencies are
            treated as unvoiced and returned as NaN.
        """
        # np.array(...).astype(float) yields a fresh copy, so the caller's
        # data is never modified by the masking below.
        freqs = np.array(F).astype(float)
        # Mask every non-positive value (not just exact zeros) before the
        # logarithm: the result for those frames is NaN either way, but this
        # keeps np.log2 from emitting a RuntimeWarning on negative input.
        freqs[freqs <= 0] = np.nan
        return 1200 * np.log2(freqs / F_ref)