| |
| """ |
| RMVPE 模型 - 用于高质量 F0 提取 |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Optional |
|
|
|
|
class BiGRU(nn.Module):
    """Bidirectional, batch-first GRU that returns only the output sequence."""

    def __init__(self, input_features: int, hidden_features: int, num_layers: int):
        super().__init__()
        self.gru = nn.GRU(
            input_size=input_features,
            hidden_size=hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Run the GRU over ``x`` and return its per-step outputs.

        The GRU is built with ``batch_first=True``, so ``x`` is laid out as
        (batch, time, input_features). Because it is bidirectional, the last
        dimension of the result is ``2 * hidden_features``. The final hidden
        state is discarded.
        """
        output, _final_state = self.gru(x)
        return output
|
|
|
|
class ConvBlockRes(nn.Module):
    """Residual block: two (3x3 conv -> BatchNorm -> ReLU) layers plus a skip.

    The skip is the identity when the channel count is unchanged. Otherwise,
    or when ``force_shortcut`` is set, it is a 1x1 convolution stored as
    ``self.shortcut``.
    """

    def __init__(self, in_channels: int, out_channels: int, momentum: float = 0.01,
                 force_shortcut: bool = False):
        super().__init__()

        def conv_bn_relu(channels_in):
            # 3x3 kernel, stride 1, padding 1: spatial size is preserved.
            return [
                nn.Conv2d(channels_in, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=momentum),
                nn.ReLU(),
            ]

        self.conv = nn.Sequential(*conv_bn_relu(in_channels), *conv_bn_relu(out_channels))

        # A projection is needed whenever the identity cannot be added directly
        # (channel mismatch), or when the caller explicitly asks for one.
        self.has_shortcut = bool(in_channels != out_channels or force_shortcut)
        if self.has_shortcut:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        out = self.conv(x)
        skip = self.shortcut(x) if self.has_shortcut else x
        return out + skip
|
|
|
|
class EncoderBlock(nn.Module):
    """One encoder stage: a stack of ConvBlockRes followed by average pooling.

    ``forward`` returns ``(pooled, features)``. ``features`` is the activation
    before pooling, which the U-Net uses as its skip connection.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # The first block changes the width; the remaining n_blocks - 1 keep it.
        channel_pairs = [(in_channels, out_channels)]
        channel_pairs += [(out_channels, out_channels)] * (n_blocks - 1)
        self.conv = nn.ModuleList(
            ConvBlockRes(c_in, c_out, momentum) for c_in, c_out in channel_pairs
        )
        self.pool = nn.AvgPool2d(kernel_size)

    def forward(self, x):
        features = x
        for res_block in self.conv:
            features = res_block(features)
        return self.pool(features), features
|
|
|
|
class Encoder(nn.Module):
    """RMVPE encoder: an input BatchNorm followed by ``n_encoders`` EncoderBlocks.

    Stage ``i`` outputs ``out_channels * 2**i`` channels. ``in_size`` is
    accepted for interface compatibility but is not used by this class.
    """

    def __init__(self, in_channels: int, in_size: int, n_encoders: int,
                 kernel_size: int, n_blocks: int, out_channels: int = 16,
                 momentum: float = 0.01):
        super().__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        # Output channel count of each stage, shallow to deep.
        self.latent_channels = []

        stage_in = in_channels
        for stage in range(n_encoders):
            stage_out = out_channels * 2 ** stage
            self.layers.append(
                EncoderBlock(stage_in, stage_out, kernel_size, n_blocks, momentum)
            )
            self.latent_channels.append(stage_out)
            stage_in = stage_out

    def forward(self, x):
        """Return the deepest pooled tensor and the list of per-stage skips.

        The skips are ordered shallow to deep.
        """
        skips = []
        h = self.bn(x)
        for stage in self.layers:
            h, stage_skip = stage(h)
            skips.append(stage_skip)
        return h, skips
|
|
|
|
class Intermediate(nn.Module):
    """Bottleneck between encoder and decoder: ``n_inters`` IntermediateBlocks.

    Only the first block maps ``in_channels -> out_channels``, and it is built
    with a forced shortcut projection. Later blocks keep ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, n_inters: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList(
            IntermediateBlock(
                in_channels if idx == 0 else out_channels,
                out_channels,
                n_blocks,
                momentum,
                first_block_shortcut=(idx == 0),
            )
            for idx in range(n_inters)
        )

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
|
|
|
|
class IntermediateBlock(nn.Module):
    """Stack of ``n_blocks`` ConvBlockRes used inside Intermediate.

    The first block maps ``in_channels -> out_channels``. When
    ``first_block_shortcut`` is set, it is forced to use a 1x1 shortcut conv.
    The remaining blocks keep ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, n_blocks: int,
                 momentum: float = 0.01, first_block_shortcut: bool = False):
        super().__init__()
        blocks = [
            ConvBlockRes(in_channels, out_channels, momentum,
                         force_shortcut=first_block_shortcut)
        ]
        blocks += [
            ConvBlockRes(out_channels, out_channels, momentum)
            for _ in range(n_blocks - 1)
        ]
        self.conv = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for res_block in self.conv:
            out = res_block(out)
        return out
|
|
|
|
class DecoderBlock(nn.Module):
    """One U-Net decoder stage.

    The stage runs in three steps:

    1. Upsample ``x`` with a 3x3 transposed convolution, then BatchNorm and ReLU.
    2. Concatenate the matching encoder skip tensor on the channel axis.
    3. Apply ``n_blocks`` ConvBlockRes.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int,
                 n_blocks: int, momentum: float = 0.01):
        super().__init__()
        # With kernel 3 and padding 1, the transposed conv output size is
        # (n - 1) * stride + 1 + output_padding. Choosing
        # output_padding = stride - 1 upsamples each axis by exactly `stride`.
        # For stride 2 that is the previous hard-coded value of 1. Tuple
        # strides such as (1, 2) are handled per axis.
        if isinstance(stride, tuple):
            output_padding = tuple(s - 1 for s in stride)
        else:
            output_padding = stride - 1
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride, padding=1,
                               output_padding=output_padding, bias=False),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            # The reference RVC/RMVPE decoder applies ReLU after this BatchNorm.
            # ReLU has no parameters, so checkpoint keys (conv1.0.*, conv1.1.*)
            # are unchanged.
            nn.ReLU(),
        )

        # The first refinement block sees the upsampled tensor concatenated
        # with the skip tensor, hence 2 * out_channels input channels.
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        """Upsample ``x``, fuse it with the skip ``concat_tensor``, and refine it."""
        x = self.conv1(x)
        # Align spatial sizes with the skip tensor so torch.cat gets equal
        # shapes. F.pad pads at the bottom/right; negative amounts crop.
        diff_h = concat_tensor.size(2) - x.size(2)
        diff_w = concat_tensor.size(3) - x.size(3)
        if diff_h != 0 or diff_w != 0:
            x = F.pad(x, [0, diff_w, 0, diff_h])
        x = torch.cat([x, concat_tensor], dim=1)
        for block in self.conv2:
            x = block(x)
        return x
|
|
|
|
class Decoder(nn.Module):
    """RMVPE decoder: a stack of DecoderBlocks that consumes encoder skips.

    The skips are consumed deepest first. Channel width halves at every stage
    and ends at ``out_channels``.
    """

    def __init__(self, in_channels: int, n_decoders: int, stride: int,
                 n_blocks: int, out_channels: int = 16, momentum: float = 0.01):
        super().__init__()
        self.layers = nn.ModuleList()
        channels_in = in_channels
        for depth in range(n_decoders - 1, -1, -1):
            channels_out = out_channels * 2 ** depth
            self.layers.append(
                DecoderBlock(channels_in, channels_out, stride, n_blocks, momentum)
            )
            channels_in = channels_out

    def forward(self, x, concat_tensors):
        # Layer k pairs with concat_tensors[~k] (== [-1 - k]), so the
        # deepest skip is used first.
        for k, block in enumerate(self.layers):
            x = block(x, concat_tensors[~k])
        return x
|
|
|
|
class DeepUnet(nn.Module):
    """Residual U-Net: Encoder -> Intermediate -> Decoder, with skip connections."""

    def __init__(self, kernel_size: int, n_blocks: int, en_de_layers: int = 5,
                 inter_layers: int = 4, in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        # Width of the deepest encoder stage (it doubles at every stage).
        deepest = en_out_channels * 2 ** (en_de_layers - 1)
        # The intermediate stack doubles it once more before decoding.
        bottleneck = 2 * deepest

        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(deepest, bottleneck, inter_layers, n_blocks)
        # kernel_size doubles as the decoder stride, so upsampling mirrors pooling.
        self.decoder = Decoder(
            bottleneck, en_de_layers, kernel_size, n_blocks, en_out_channels
        )

    def forward(self, x):
        bottom, skips = self.encoder(x)
        return self.decoder(self.intermediate(bottom), skips)
|
|
|
|
class E2E(nn.Module):
    """End-to-end RMVPE network: mel spectrogram -> per-frame 360-bin salience.

    Pipeline: DeepUnet -> 3-channel conv -> (optional BiGRU) -> Linear(360)
    -> Dropout -> Sigmoid.
    """

    def __init__(self, n_blocks: int, n_gru: int, kernel_size: int,
                 en_de_layers: int = 5, inter_layers: int = 4,
                 in_channels: int = 1, en_out_channels: int = 16):
        super().__init__()
        self.unet = DeepUnet(
            kernel_size, n_blocks, en_de_layers, inter_layers,
            in_channels, en_out_channels
        )
        # Collapse the U-Net output to 3 feature maps. Flattened together with
        # the 128 mel bins, this gives the 3 * 128 features fed to self.fc.
        self.cnn = nn.Conv2d(en_out_channels, 3, 3, 1, 1)

        if n_gru:
            # A bidirectional GRU with hidden size 256 outputs 2 * 256 = 512.
            head = [BiGRU(3 * 128, 256, n_gru), nn.Linear(512, 360)]
        else:
            head = [nn.Linear(3 * 128, 360)]
        self.fc = nn.Sequential(*head, nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, mel):
        """Map a mel spectrogram to per-frame salience.

        Inputs are reshaped as follows:

        - 3-D: (batch, n_mels, n_frames), as produced by MelSpectrogram. It
          becomes (batch, 1, n_frames, n_mels).
        - 4-D with a singleton channel axis: the last two axes are swapped.
        - Anything else: passed through unchanged.
        """
        if mel.dim() == 3:
            # Add the channel axis first. The 4-D branch below then performs
            # the single time/mel transpose.
            mel = mel.unsqueeze(1)
        if mel.dim() == 4 and mel.shape[1] == 1:
            mel = mel.transpose(-1, -2)

        features = self.cnn(self.unet(mel))
        # Move time to axis 1 and merge (channels, mels) into one feature axis.
        features = features.transpose(1, 2).flatten(-2)
        return self.fc(features)
|
|
|
|
class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front-end for RMVPE.

    Computes the STFT *magnitude* (not power) and projects it onto an HTK mel
    filterbank built with librosa. It then returns
    ``log(clamp(mel, min=clamp))``.
    """

    def __init__(self, n_mel: int = 128, n_fft: int = 1024, win_size: int = 1024,
                 hop_length: int = 160, sample_rate: int = 16000,
                 fmin: int = 30, fmax: int = 8000, clamp: float = 1e-5):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_size = win_size
        self.sample_rate = sample_rate
        self.n_mel = n_mel
        # Lower bound applied before the log to avoid log(0).
        self.clamp = clamp

        # Registered as buffers so .to(device) moves them with the module.
        mel_basis = self._mel_filterbank(sample_rate, n_fft, n_mel, fmin, fmax)
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", torch.hann_window(win_size))

    def _mel_filterbank(self, sr, n_fft, n_mels, fmin, fmax):
        """Build an HTK-style mel filterbank with librosa, as a float32 tensor."""
        import librosa  # local import: only needed at construction time

        mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels,
                                  fmin=fmin, fmax=fmax, htk=True)
        return torch.from_numpy(mel).float()

    def forward(self, audio):
        """Compute the natural-log mel magnitude spectrogram of ``audio``.

        Args:
            audio: waveform tensor accepted by ``torch.stft``, either 1-D or
                (batch, n_samples).

        Returns:
            torch.Tensor: log-mel features with shape (..., n_mel, n_frames).
        """
        spec = torch.stft(
            audio,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_size,
            window=self.window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True
        )
        # The reference RVC RMVPE front-end feeds the model log-mel
        # *magnitude*, sqrt(re^2 + im^2). Squaring here (a power spectrum)
        # would give the pretrained network features on the wrong scale.
        magnitude = spec.abs()

        mel = torch.matmul(self.mel_basis, magnitude)
        return torch.log(torch.clamp(mel, min=self.clamp))
|
|
|
|
class RMVPE:
    """Wrapper around an RMVPE network for F0 (pitch) extraction.

    Builds the E2E network and loads a state dict from ``model_path``. It
    converts 16 kHz mono audio into a per-frame F0 contour in Hz, where 0
    means unvoiced.
    """

    def __init__(self, model_path: str, device: str = "cuda"):
        self.device = device

        # Architecture hyper-parameters; these must match the checkpoint being
        # loaded. There are 4 residual blocks per stage, one BiGRU layer, and
        # 2x average pooling per encoder stage.
        self.model = E2E(n_blocks=4, n_gru=1, kernel_size=2)
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code embedded in the file. Only
        # load checkpoints from trusted sources. If the checkpoint is a plain
        # state dict, weights_only=True is the safer choice.
        ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
        self.model.load_state_dict(ckpt)
        self.model = self.model.to(device).eval()

        # Mel front-end with its default settings.
        self.mel_extractor = MelSpectrogram().to(device)

        # Cents value of each of the 360 output bins, 20 cents apart. It is
        # padded with 4 zeros on each side so a 9-bin window around any bin
        # stays in range.
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    @torch.no_grad()
    def infer_from_audio(self, audio: np.ndarray, thred: float = 0.03) -> np.ndarray:
        """Extract an F0 contour from audio.

        Args:
            audio: mono audio sampled at 16 kHz as a 1-D array. A 2-D
                ``(1, n_samples)`` array is also accepted.
            thred: confidence threshold. Frames whose peak salience is
                ``<= thred`` are reported as unvoiced (0 Hz).

        Returns:
            np.ndarray: F0 in Hz, one value per mel frame.
        """
        audio = torch.from_numpy(audio).float().to(self.device)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)

        mel = self.mel_extractor(audio)
        n_frames = mel.shape[-1]

        # With the defaults used above (5 encoder stages, pooling 2), the
        # U-Net halves the time axis 5 times. So pad it up to a multiple of
        # 32 frames; the padding is cropped off again below.
        n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
        if n_pad > 0:
            mel = F.pad(mel, (0, n_pad), mode='constant', value=0)

        hidden = self.model(mel)
        hidden = hidden[:, :n_frames, :].squeeze(0).cpu().numpy()
        return self._decode(hidden, thred)

    def _decode(self, hidden: np.ndarray, thred: float) -> np.ndarray:
        """Convert per-frame salience of shape (n_frames, 360) to F0 in Hz."""
        cents = self._to_local_average_cents(hidden, thred)
        f0 = 10 * (2 ** (cents / 1200))
        # 0 cents marks an unvoiced frame and maps to exactly 10 Hz; report 0.
        f0[f0 == 10] = 0
        return f0

    def _to_local_average_cents(self, salience: np.ndarray, thred: float) -> np.ndarray:
        """Compute a salience-weighted mean of cents over a 9-bin window at each peak.

        This follows RVC's ``to_local_average_cents``. All windows are
        gathered with a single vectorized index instead of a per-frame
        Python loop, which also makes zero-frame input return an empty
        result.

        Args:
            salience: array of shape (n_frames, 360).
            thred: frames whose maximum salience is ``<= thred`` get 0 cents.

        Returns:
            np.ndarray: shape (n_frames,), cents per frame (0 = unvoiced).
        """
        # Pad 4 bins on each side so the window [peak - 4, peak + 4] never
        # leaves the array. Indices below are in padded coordinates (peak + 4).
        padded = np.pad(salience, ((0, 0), (4, 4)))
        center = np.argmax(salience, axis=1) + 4
        window_idx = center[:, None] + np.arange(-4, 5)  # (n_frames, 9)

        window_salience = np.take_along_axis(padded, window_idx, axis=1)
        window_cents = self.cents_mapping[window_idx]

        product_sum = np.sum(window_salience * window_cents, axis=1)
        # The epsilon keeps all-zero windows from dividing by zero.
        weight_sum = np.sum(window_salience, axis=1) + 1e-9
        cents = product_sum / weight_sum

        cents[np.max(padded, axis=1) <= thred] = 0
        return cents
|
|