import logging
import math
from typing import Optional, Union
import torch
import torchaudio
from torch import nn
from audio_denoiser.helpers.torch_helper import batched_apply
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
from audio_denoiser.helpers.audio_helper import (
create_spectrogram,
reconstruct_from_spectrogram,
)
_expected_t_std = 0.23  # trimmed std of the training data; target scale for auto_scale
_recommended_backend = "soundfile"  # suggested torchaudio I/O backend
# ref: https://github.com/jose-solorzano/audio-denoiser
class AudioDenoiser:
def __init__(
self,
local_dir: str,
        device: Optional[Union[str, torch.device]] = None,
num_iterations: int = 100,
):
super().__init__()
if device is None:
is_cuda = torch.cuda.is_available()
if not is_cuda:
logging.warning("CUDA not available. Will use CPU.")
device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
self.device = device
self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
self.model.eval()
self.model_sample_rate = self.model.sample_rate
self.scaler = self.model.scaler
self.n_fft = self.model.n_fft
self.segment_num_frames = self.model.num_frames
self.num_iterations = num_iterations
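    # The model operates on log-magnitude spectrograms: _sp_log maps magnitudes into
    # log space with a small epsilon for numerical stability, and _sp_exp inverts it.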
@staticmethod
def _sp_log(spectrogram: torch.Tensor, eps=0.01):
return torch.log(spectrogram + eps)
@staticmethod
def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)
@staticmethod
def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
        # Std of the largest-magnitude samples (those at or above the q-th quantile
        # of |waveform|). For the training data this is ~0.23 (see _expected_t_std).
abs_waveform = torch.abs(waveform)
quantile_value = torch.quantile(abs_waveform, q).item()
trimmed_values = waveform[abs_waveform >= quantile_value]
return torch.std(trimmed_values).item()
def process_waveform(
self,
waveform: torch.Tensor,
sample_rate: int,
return_cpu_tensor: bool = False,
auto_scale: bool = False,
) -> torch.Tensor:
"""
Denoises a waveform.
        @param waveform: A waveform tensor in torchaudio layout (channels, num_samples).
@param sample_rate: The sample rate of the waveform in Hz.
        @param return_cpu_tensor: If True, return a CPU tensor; otherwise the result is moved to the model's device.
@param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio.
@return: A denoised waveform.
"""
waveform = waveform.cpu()
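        # Optionally rescale so the waveform's trimmed std matches the training data's (~0.23).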
if auto_scale:
w_t_std = self._trimmed_dev(waveform)
waveform = waveform * _expected_t_std / w_t_std
if sample_rate != self.model_sample_rate:
transform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=self.model_sample_rate
)
waveform = transform(waveform)
        hop_len = self.n_fft // 2  # 50% overlap between STFT frames
spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
spectrogram = spectrogram.to(self.device)
num_a_channels = spectrogram.size(0)
with torch.no_grad():
results = []
for c in range(num_a_channels):
c_spectrogram = spectrogram[c]
# c_spectrogram: (257, num_frames)
fft_size, num_frames = c_spectrogram.shape
num_segments = math.ceil(num_frames / self.segment_num_frames)
adj_num_frames = num_segments * self.segment_num_frames
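                # Zero-pad along the time axis so the frames split evenly into model-sized segments.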
if adj_num_frames > num_frames:
c_spectrogram = nn.functional.pad(
c_spectrogram, (0, adj_num_frames - num_frames)
)
c_spectrogram = c_spectrogram.view(
fft_size, num_segments, self.segment_num_frames
)
# c_spectrogram: (257, num_segments, 32)
c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2))
# c_spectrogram: (num_segments, 257, 32)
log_c_spectrogram = self._sp_log(c_spectrogram)
scaled_log_c_sp = self.scaler(log_c_spectrogram)
pred_noise_log_sp = batched_apply(
self.model, scaled_log_c_sp, detached=True
)
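                # Remove the predicted noise in log space, then map back to a magnitude spectrogram.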
log_denoised_sp = log_c_spectrogram - pred_noise_log_sp
denoised_sp = self._sp_exp(log_denoised_sp)
# denoised_sp: (num_segments, 257, 32)
denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
# denoised_sp: (257, num_segments, 32)
denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
# denoised_sp: (1, 257, adj_num_frames)
denoised_sp = denoised_sp[:, :, :num_frames]
denoised_sp = denoised_sp.cpu()
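                # Invert the magnitude spectrogram back to a waveform; phase is estimated
                # iteratively (Griffin-Lim-style), controlled by num_iterations.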
denoised_waveform = reconstruct_from_spectrogram(
denoised_sp, num_iterations=self.num_iterations
)
# denoised_waveform: (1, num_samples)
results.append(denoised_waveform)
cpu_results = torch.cat(results)
return cpu_results if return_cpu_tensor else cpu_results.to(self.device)
def process_audio_file(
self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
):
"""
Denoises an audio file.
@param in_audio_file: An input audio file with a format supported by torchaudio.
        @param out_audio_file: An output audio file with a format supported by torchaudio.
@param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio.
"""
waveform, sample_rate = torchaudio.load(in_audio_file)
denoised_waveform = self.process_waveform(
waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
)
torchaudio.save(
out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
)
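
# A minimal usage sketch (the model directory and file paths below are placeholders,
# not files shipped with this module):
if __name__ == "__main__":
    denoiser = AudioDenoiser(local_dir="./models/audio-denoiser")
    denoiser.process_audio_file(
        in_audio_file="noisy_input.wav",
        out_audio_file="denoised_output.wav",
        auto_scale=True,
    )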