|  | import logging | 
					
						
						|  | import math | 
					
						
						|  | from typing import Union | 
					
						
						|  | import torch | 
					
						
						|  | import torchaudio | 
					
						
						|  | from torch import nn | 
					
						
						|  | from audio_denoiser.helpers.torch_helper import batched_apply | 
					
						
						|  | from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model | 
					
						
						|  | from audio_denoiser.helpers.audio_helper import ( | 
					
						
						|  | create_spectrogram, | 
					
						
						|  | reconstruct_from_spectrogram, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | _expected_t_std = 0.23 | 
					
						
						|  | _recommended_backend = "soundfile" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AudioDenoiser: | 
					
						
						|  | def __init__( | 
					
						
						|  | self, | 
					
						
						|  | local_dir: str, | 
					
						
						|  | device: Union[str, torch.device] = None, | 
					
						
						|  | num_iterations: int = 100, | 
					
						
						|  | ): | 
					
						
						|  | super().__init__() | 
					
						
						|  | if device is None: | 
					
						
						|  | is_cuda = torch.cuda.is_available() | 
					
						
						|  | if not is_cuda: | 
					
						
						|  | logging.warning("CUDA not available. Will use CPU.") | 
					
						
						|  | device = torch.device("cuda:0") if is_cuda else torch.device("cpu") | 
					
						
						|  | self.device = device | 
					
						
						|  | self.model = load_audio_denosier_model(dir_path=local_dir, device=device) | 
					
						
						|  | self.model.eval() | 
					
						
						|  | self.model_sample_rate = self.model.sample_rate | 
					
						
						|  | self.scaler = self.model.scaler | 
					
						
						|  | self.n_fft = self.model.n_fft | 
					
						
						|  | self.segment_num_frames = self.model.num_frames | 
					
						
						|  | self.num_iterations = num_iterations | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _sp_log(spectrogram: torch.Tensor, eps=0.01): | 
					
						
						|  | return torch.log(spectrogram + eps) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _sp_exp(log_spectrogram: torch.Tensor, eps=0.01): | 
					
						
						|  | return torch.clamp(torch.exp(log_spectrogram) - eps, min=0) | 
					
						
						|  |  | 
					
						
						|  | @staticmethod | 
					
						
						|  | def _trimmed_dev(waveform: torch.Tensor, q: float = 0.90) -> float: | 
					
						
						|  |  | 
					
						
						|  | abs_waveform = torch.abs(waveform) | 
					
						
						|  | quantile_value = torch.quantile(abs_waveform, q).item() | 
					
						
						|  | trimmed_values = waveform[abs_waveform >= quantile_value] | 
					
						
						|  | return torch.std(trimmed_values).item() | 
					
						
						|  |  | 
					
						
						|  | def process_waveform( | 
					
						
						|  | self, | 
					
						
						|  | waveform: torch.Tensor, | 
					
						
						|  | sample_rate: int, | 
					
						
						|  | return_cpu_tensor: bool = False, | 
					
						
						|  | auto_scale: bool = False, | 
					
						
						|  | ) -> torch.Tensor: | 
					
						
						|  | """ | 
					
						
						|  | Denoises a waveform. | 
					
						
						|  | @param waveform: A waveform tensor. Use torchaudio structure. | 
					
						
						|  | @param sample_rate: The sample rate of the waveform in Hz. | 
					
						
						|  | @param return_cpu_tensor: Whether the returned tensor must be a CPU tensor. | 
					
						
						|  | @param auto_scale: Normalize the scale of the waveform before processing. Recommended for low-volume audio. | 
					
						
						|  | @return: A denoised waveform. | 
					
						
						|  | """ | 
					
						
						|  | waveform = waveform.cpu() | 
					
						
						|  | if auto_scale: | 
					
						
						|  | w_t_std = self._trimmed_dev(waveform) | 
					
						
						|  | waveform = waveform * _expected_t_std / w_t_std | 
					
						
						|  | if sample_rate != self.model_sample_rate: | 
					
						
						|  | transform = torchaudio.transforms.Resample( | 
					
						
						|  | orig_freq=sample_rate, new_freq=self.model_sample_rate | 
					
						
						|  | ) | 
					
						
						|  | waveform = transform(waveform) | 
					
						
						|  | hop_len = self.n_fft // 2 | 
					
						
						|  | spectrogram = create_spectrogram(waveform, n_fft=self.n_fft, hop_length=hop_len) | 
					
						
						|  | spectrogram = spectrogram.to(self.device) | 
					
						
						|  | num_a_channels = spectrogram.size(0) | 
					
						
						|  | with torch.no_grad(): | 
					
						
						|  | results = [] | 
					
						
						|  | for c in range(num_a_channels): | 
					
						
						|  | c_spectrogram = spectrogram[c] | 
					
						
						|  |  | 
					
						
						|  | fft_size, num_frames = c_spectrogram.shape | 
					
						
						|  | num_segments = math.ceil(num_frames / self.segment_num_frames) | 
					
						
						|  | adj_num_frames = num_segments * self.segment_num_frames | 
					
						
						|  | if adj_num_frames > num_frames: | 
					
						
						|  | c_spectrogram = nn.functional.pad( | 
					
						
						|  | c_spectrogram, (0, adj_num_frames - num_frames) | 
					
						
						|  | ) | 
					
						
						|  | c_spectrogram = c_spectrogram.view( | 
					
						
						|  | fft_size, num_segments, self.segment_num_frames | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | c_spectrogram = torch.permute(c_spectrogram, (1, 0, 2)) | 
					
						
						|  |  | 
					
						
						|  | log_c_spectrogram = self._sp_log(c_spectrogram) | 
					
						
						|  | scaled_log_c_sp = self.scaler(log_c_spectrogram) | 
					
						
						|  | pred_noise_log_sp = batched_apply( | 
					
						
						|  | self.model, scaled_log_c_sp, detached=True | 
					
						
						|  | ) | 
					
						
						|  | log_denoised_sp = log_c_spectrogram - pred_noise_log_sp | 
					
						
						|  | denoised_sp = self._sp_exp(log_denoised_sp) | 
					
						
						|  |  | 
					
						
						|  | denoised_sp = torch.permute(denoised_sp, (1, 0, 2)) | 
					
						
						|  |  | 
					
						
						|  | denoised_sp = denoised_sp.contiguous().view(1, fft_size, adj_num_frames) | 
					
						
						|  |  | 
					
						
						|  | denoised_sp = denoised_sp[:, :, :num_frames] | 
					
						
						|  | denoised_sp = denoised_sp.cpu() | 
					
						
						|  | denoised_waveform = reconstruct_from_spectrogram( | 
					
						
						|  | denoised_sp, num_iterations=self.num_iterations | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | results.append(denoised_waveform) | 
					
						
						|  | cpu_results = torch.cat(results) | 
					
						
						|  | return cpu_results if return_cpu_tensor else cpu_results.to(self.device) | 
					
						
						|  |  | 
					
						
						|  | def process_audio_file( | 
					
						
						|  | self, in_audio_file: str, out_audio_file: str, auto_scale: bool = False | 
					
						
						|  | ): | 
					
						
						|  | """ | 
					
						
						|  | Denoises an audio file. | 
					
						
						|  | @param in_audio_file: An input audio file with a format supported by torchaudio. | 
					
						
						|  | @param out_audio_file: Am output audio file with a format supported by torchaudio. | 
					
						
						|  | @param auto_scale: Whether the input waveform scale should be normalized before processing. Recommended for low-volume audio. | 
					
						
						|  | """ | 
					
						
						|  | waveform, sample_rate = torchaudio.load(in_audio_file) | 
					
						
						|  | denoised_waveform = self.process_waveform( | 
					
						
						|  | waveform, sample_rate, return_cpu_tensor=True, auto_scale=auto_scale | 
					
						
						|  | ) | 
					
						
						|  | torchaudio.save( | 
					
						
						|  | out_audio_file, denoised_waveform, sample_rate=self.model_sample_rate | 
					
						
						|  | ) | 
					
						
						|  |  |