import logging
import math
from typing import Union

import torch
import torchaudio
from torch import nn

from audio_denoiser.helpers.torch_helper import batched_apply
from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
from audio_denoiser.helpers.audio_helper import (
    create_spectrogram,
    reconstruct_from_spectrogram,
)

# Target trimmed standard deviation used by ``auto_scale``: the input waveform is
# rescaled so that ``AudioDenoiser._trimmed_dev`` of it becomes this value
# (expected for training data is ~0.23).
_expected_t_std = 0.23
# Audio backend recommended by the upstream project.
# NOTE(review): not referenced anywhere in this module — confirm whether other
# modules import it before removing.
_recommended_backend = "soundfile"


# ref: https://github.com/jose-solorzano/audio-denoiser
class AudioDenoiser:
    """Denoise audio with a pretrained model operating on log spectrograms.

    Each channel's spectrogram is cut into fixed-size frame segments, the model
    predicts a log-domain noise estimate per segment, the estimate is subtracted
    in the log domain, and a waveform is rebuilt from the resulting spectrogram
    with ``reconstruct_from_spectrogram``.
    """

    def __init__(
        self,
        local_dir: str,
        device: Union[str, torch.device, None] = None,
        num_iterations: int = 100,
    ):
        """
        Loads the denoising model.
        @param local_dir: Directory passed to ``load_audio_denosier_model``.
        @param device: Device the model runs on. Defaults to ``cuda:0`` when CUDA is
            available, otherwise the CPU (a warning is logged in that case).
        @param num_iterations: Forwarded to ``reconstruct_from_spectrogram`` when
            rebuilding waveforms from denoised spectrograms.
        """
        super().__init__()
        if device is None:
            is_cuda = torch.cuda.is_available()
            if not is_cuda:
                logging.warning("CUDA not available. Will use CPU.")
            device = torch.device("cuda:0") if is_cuda else torch.device("cpu")
        self.device = device
        self.model = load_audio_denosier_model(dir_path=local_dir, device=device)
        self.model.eval()
        # Hyper-parameters the model was built with; inputs must match them.
        self.model_sample_rate = self.model.sample_rate
        self.scaler = self.model.scaler
        self.n_fft = self.model.n_fft
        self.segment_num_frames = self.model.num_frames
        self.num_iterations = num_iterations

    @staticmethod
    def _sp_log(spectrogram: torch.Tensor, eps=0.01):
        """Log of the spectrogram, offset by ``eps`` so zero bins stay finite."""
        return torch.log(spectrogram + eps)

    @staticmethod
    def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01):
        """Inverse of ``_sp_log``; clamps at 0 so the result is never negative."""
        return torch.clamp(torch.exp(log_spectrogram) - eps, min=0)

    @staticmethod
    def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float:
        """
        Standard deviation of the samples whose magnitude is at or above the
        ``q``-quantile of ``|waveform|`` (a loudness estimate for ``auto_scale``).
        Expected for training data is ~0.23.
        @param waveform: Floating-point waveform tensor of any shape.
        @param q: Quantile in [0, 1] selecting the loud samples kept.
        @return: The trimmed standard deviation. 0.0 for an empty tensor; NaN when
            fewer than two samples survive the trim.
        @raise ValueError: If ``q`` is outside [0, 1].
        """
        if not 0.0 <= q <= 1.0:
            raise ValueError(f"q must be in [0, 1], got {q}")
        flat = waveform.flatten()
        abs_values = torch.abs(flat)
        count = abs_values.numel()
        if count == 0:
            return 0.0
        # Same "linear" interpolation as torch.quantile's default, but computed via
        # sort: torch.quantile rejects very large inputs, which long recordings
        # easily exceed.
        sorted_values, _ = torch.sort(abs_values)
        pos = q * (count - 1)
        lo = int(math.floor(pos))
        hi = min(lo + 1, count - 1)
        quantile_value = torch.lerp(
            sorted_values[lo], sorted_values[hi], pos - lo
        ).item()
        trimmed_values = flat[abs_values >= quantile_value]
        return torch.std(trimmed_values).item()

    def _denoise_channel(self, c_spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Denoises one channel's spectrogram and rebuilds its waveform.
        @param c_spectrogram: (freq_bins, num_frames) spectrogram on ``self.device``.
        @return: The reconstructed waveform as a CPU tensor (commented upstream as
            (1, num_samples)).
        """
        fft_size, num_frames = c_spectrogram.shape
        seg_len = self.segment_num_frames
        # Pad frames up to a whole number of fixed-size model segments.
        num_segments = math.ceil(num_frames / seg_len)
        adj_num_frames = num_segments * seg_len
        if adj_num_frames > num_frames:
            c_spectrogram = nn.functional.pad(
                c_spectrogram, (0, adj_num_frames - num_frames)
            )
        # (freq, adj_frames) -> (freq, num_segments, seg_len) -> (num_segments, freq, seg_len).
        # reshape (not view): the slice handed in is not guaranteed contiguous.
        segments = c_spectrogram.reshape(fft_size, num_segments, seg_len)
        segments = torch.permute(segments, (1, 0, 2))
        log_sp = self._sp_log(segments)
        scaled_log_sp = self.scaler(log_sp)
        pred_noise_log_sp = batched_apply(self.model, scaled_log_sp, detached=True)
        # Noise is removed by subtraction in the log domain.
        denoised_sp = self._sp_exp(log_sp - pred_noise_log_sp)
        # (num_segments, freq, seg_len) -> (1, freq, adj_frames), then drop the padding.
        denoised_sp = torch.permute(denoised_sp, (1, 0, 2))
        denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames)
        denoised_sp = denoised_sp[:, :, :num_frames].cpu()
        return reconstruct_from_spectrogram(
            denoised_sp, num_iterations=self.num_iterations
        )

    def process_waveform(
        self,
        waveform: torch.Tensor,
        sample_rate: int,
        return_cpu_tensor: bool = False,
        auto_scale: bool = False,
    ) -> torch.Tensor:
        """
        Denoises a waveform.
        @param waveform: A waveform tensor. Use torchaudio structure
            (channels, samples); a 1-D tensor is treated as a single channel.
        @param sample_rate: The sample rate of the waveform in Hz.
        @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor.
        @param auto_scale: Normalize the scale of the waveform before processing.
            Recommended for low-volume audio. Skipped (with a warning) when the
            input's trimmed deviation is zero or undefined, e.g. silent audio.
        @return: A denoised waveform, one row per input channel, at
            ``self.model_sample_rate``.
        """
        waveform = waveform.cpu()
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        if auto_scale:
            w_t_std = self._trimmed_dev(waveform)
            # Guard the division: silent input gives 0, tiny input gives NaN.
            if math.isfinite(w_t_std) and w_t_std > 0:
                waveform = waveform * _expected_t_std / w_t_std
            else:
                logging.warning(
                    "auto_scale skipped: trimmed deviation is %s.", w_t_std
                )
        if sample_rate != self.model_sample_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.model_sample_rate
            )
            waveform = transform(waveform)
        hop_len = self.n_fft // 2
        spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len)
        spectrogram = spectrogram.to(self.device)
        with torch.no_grad():
            results = [
                self._denoise_channel(spectrogram[c])
                for c in range(spectrogram.size(0))
            ]
        cpu_results = torch.cat(results)
        return cpu_results if return_cpu_tensor else cpu_results.to(self.device)

    def process_audio_file(
        self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False
    ):
        """
        Denoises an audio file.
        @param in_audio_file: An input audio file with a format supported by torchaudio.
        @param out_audio_file: An output audio file with a format supported by torchaudio.
            It is written at the model's sample rate (``self.model_sample_rate``).
        @param auto_scale: Whether the input waveform scale should be normalized before processing.
            Recommended for low-volume audio.
        """
        waveform, sample_rate = torchaudio.load(in_audio_file)
        denoised_waveform = self.process_waveform(
            waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale
        )
        torchaudio.save(
            out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate
        )