#!/usr/bin/python3
# -*- coding: utf-8 -*-
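"""
Speech enhancement inference with a pretrained DfNet2 model: load the model
from a directory or a .zip archive, then denoise audio either offline (the
whole utterance at once) or online (chunk by chunk).
"""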
import logging
from pathlib import Path
import shutil
import tempfile
import time
import zipfile
import librosa
import numpy as np
import torch
import torchaudio
torch.set_num_threads(1)
from project_settings import project_path
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2PretrainedModel, MODEL_FILE
logger = logging.getLogger("toolbox")


class InferenceDfNet(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(self.device)
        self.model.eval()
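
    # load_models accepts either an unpacked model directory or a .zip
    # archive; an archive is extracted to a temporary directory first, and
    # that directory is removed again once the weights are loaded.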
    def load_models(self, model_path: str):
        model_path = Path(model_path)
        extracted = False
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            # the archive is expected to contain a directory named after its stem.
            model_path = out_root / model_path.stem
            extracted = True

        config = DfNet2Config.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = DfNet2PretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        # only remove the directory we extracted ourselves; unconditionally
        # removing model_path would delete a model directory the caller
        # passed in directly.
        if extracted:
            shutil.rmtree(model_path)
        return config, model
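
    # Convenience wrapper for 1-D numpy waveforms: adds the batch dimension,
    # runs offline enhancement, and returns a 1-D numpy array.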
    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
        noisy_audio = noisy_audio.unsqueeze(dim=0)
        # noisy_audio shape: [batch_size, num_samples]
        enhanced_audio = self.denoise_offline(noisy_audio)
        # enhanced_audio shape: [channels, num_samples]
        enhanced_audio = enhanced_audio[0]
        # enhanced_audio shape: [num_samples]
        return enhanced_audio.cpu().numpy()

    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)
        with torch.no_grad():
            est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)
            # est_wav shape: [batch_size, 1, num_samples]
        denoise = est_wav[0]
        # denoise shape: [channels, num_samples]
        return denoise
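
    # "online" denoising runs the model's forward_chunk_by_chunk method,
    # which is presumably meant to mimic streaming inference; its output may
    # differ slightly from the offline whole-utterance path.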
    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)
        with torch.no_grad():
            denoise = self.model.forward_chunk_by_chunk(noisy_audios)
            # denoise shape: [batch_size, 1, num_samples]
        denoise = denoise[0]
        # denoise shape: [channels, num_samples]
        return denoise


def main():
    model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
    infer_model = InferenceDfNet(model_zip_file.as_posix())

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
    noisy_audio, sample_rate = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    begin = time.time()
    enhanced_audio = infer_model.denoise_offline(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")
    filename = "enhanced_audio_offline.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

    begin = time.time()
    enhanced_audio = infer_model.denoise_online(noisy_audio)
    time_cost = time.time() - begin
    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, rtf: {time_cost / duration:.4f}")
    filename = "enhanced_audio_online.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
    return
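

# A minimal usage sketch for the ndarray API (the noisy file path below is a
# placeholder, not a file shipped with this script):
#
#     noisy, sr = librosa.load("path/to/noisy.wav", sr=8000)
#     infer_model = InferenceDfNet("trained_models/dfnet2-nx-dns3.zip")
#     enhanced = infer_model.enhancement_by_ndarray(noisy)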


if __name__ == "__main__":
    main()