|
import os |
|
from typing import Union |
|
|
|
import torch |
|
import torchaudio |
|
from modules.Denoiser.AudioDenoiser import AudioDenoiser |
|
|
|
from modules.utils.constants import MODELS_DIR |
|
|
|
from modules.devices import devices |
|
|
|
import soundfile as sf |
|
|
|
ad: Union[AudioDenoiser, None] = None |
|
|
|
|
|
class TTSAudioDenoiser: |
|
|
|
def load_ad(self): |
|
global ad |
|
if ad is None: |
|
ad = AudioDenoiser( |
|
os.path.join( |
|
MODELS_DIR, |
|
"Denoise", |
|
"audio-denoiser-512-32-v1", |
|
), |
|
device=devices.device, |
|
) |
|
ad.model.to(devices.device) |
|
return ad |
|
|
|
def denoise(self, audio_data, sample_rate, auto_scale=False): |
|
ad = self.load_ad() |
|
sr = ad.model_sample_rate |
|
return sr, ad.process_waveform(audio_data, sample_rate, auto_scale) |
|
|
|
|
|
if __name__ == "__main__": |
|
tts_deno = TTSAudioDenoiser() |
|
data, sr = sf.read("test.wav") |
|
audio_tensor = torch.from_numpy(data).unsqueeze(0).float() |
|
print(audio_tensor) |
|
|
|
|
|
|
|
|
|
|
|
sr, denoised = tts_deno.denoise(audio_data=audio_tensor, sample_rate=sr) |
|
denoised = denoised.cpu() |
|
torchaudio.save("denoised.wav", denoised, sample_rate=sr) |
|
|