# Spaces:
# Runtime error
# Runtime error
import json
from pathlib import Path

import torch
import torch.nn as nn

from audio_denoiser.modules.Permute import Permute
from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
class AudioNoiseModel(nn.Module):
    """Two-branch model whose output is the (scaled) sum of both branches.

    ``model1`` is a stack of three kernel-size-1 ``Conv1d`` layers (hidden
    width 1024, ELU between them) mapping ``in_channels`` back to
    ``in_channels``. ``model2`` wraps a ``SimpleRoberta`` between two
    ``Permute(0, 2, 1)`` layers and a pair of ``Linear`` projections
    (``in_channels`` -> ``roberta_hidden_size`` -> ``in_channels``).
    Both branches receive the same input.

    NOTE(review): input is presumably ``(batch, in_channels, frames)`` --
    Conv1d treats dim 1 as channels and the Permute(0, 2, 1) + Linear pair
    suggests the same layout; TODO confirm against callers.

    Config keys read in ``__init__``: ``"scaler"`` (required, passed to
    ``SpectrogramScaler.from_dict``), ``"in_channels"`` (default 257) and
    ``"roberta_hidden_size"`` (default 768). ``"sample_rate"``, ``"n_fft"``
    and ``"num_frames"`` are read on demand by the accessor methods.
    """

    def __init__(self, config: dict):
        super().__init__()
        # The raw config is kept so the accessor methods can read it later.
        self.config = config
        self.scaler = SpectrogramScaler.from_dict(config["scaler"])
        self.in_channels = config.get("in_channels", 257)
        self.roberta_hidden_size = config.get("roberta_hidden_size", 768)

        channels = self.in_channels
        hidden = self.roberta_hidden_size
        conv_width = 1024

        # NOTE: nn.Sequential names its children by index, so the layer order
        # below fixes the state_dict keys (e.g. "model1.0.weight"); reordering
        # would break loading of existing checkpoints.
        self.model1 = nn.Sequential(
            nn.Conv1d(channels, conv_width, kernel_size=1),
            nn.ELU(),
            nn.Conv1d(conv_width, conv_width, kernel_size=1),
            nn.ELU(),
            nn.Conv1d(conv_width, channels, kernel_size=1),
        )
        self.model2 = nn.Sequential(
            Permute(0, 2, 1),
            nn.Linear(channels, hidden),
            SimpleRoberta(num_hidden_layers=5, hidden_size=hidden),
            nn.Linear(hidden, channels),
            Permute(0, 2, 1),
        )

    def sample_rate(self) -> int:
        """Return ``config["sample_rate"]``, or 16000 when the key is absent."""
        return self.config.get("sample_rate", 16000)

    def n_fft(self) -> int:
        """Return ``config["n_fft"]``, or 512 when the key is absent."""
        return self.config.get("n_fft", 512)

    def num_frames(self) -> int:
        """Return ``config["num_frames"]``, or 32 when the key is absent."""
        return self.config.get("num_frames", 32)

    def forward(self, x, use_scaler: bool = False, out_scale: float = 1.0):
        """Run both branches on ``x`` and return their sum times ``out_scale``.

        Args:
            x: Input tensor fed to both branches.
            use_scaler: If True, ``x`` is first passed through ``self.scaler``.
            out_scale: Multiplier applied to the summed branch outputs.

        Returns:
            ``(model1(x') + model2(x')) * out_scale`` where ``x'`` is ``x``,
            optionally scaled.
        """
        inputs = self.scaler(x) if use_scaler else x
        combined = self.model1(inputs) + self.model2(inputs)
        return combined * out_scale
def load_audio_denosier_model(dir_path: str, device) -> AudioNoiseModel:
    """Build an ``AudioNoiseModel`` from a checkpoint directory and move it to ``device``.

    The directory must contain ``config.json`` (passed to ``AudioNoiseModel``)
    and ``pytorch_model.bin`` (a state dict applied with ``load_state_dict``).

    Args:
        dir_path: Directory holding the checkpoint files (a path-like object
            also works).
        device: Target device, anything accepted by ``nn.Module.to``
            (e.g. ``"cpu"``, ``"cuda"``, ``torch.device(...)``).

    Returns:
        The model with weights loaded, on ``device``. Nothing here calls
        ``.eval()``, so the module is in PyTorch's default training mode;
        call ``.eval()`` before inference.

    Raises:
        FileNotFoundError: If either checkpoint file is missing.
        KeyError: If ``config.json`` lacks the ``"scaler"`` entry.
    """
    checkpoint_dir = Path(dir_path)
    # Use a context manager so the config file handle is closed promptly.
    with open(checkpoint_dir / "config.json", "r", encoding="utf-8") as config_file:
        config = json.load(config_file)
    model = AudioNoiseModel(config)
    # Deserialize onto the CPU first: without map_location, a checkpoint saved
    # from a GPU fails to load on a CPU-only host. The model is moved to the
    # requested device afterwards.
    state_dict = torch.load(checkpoint_dir / "pytorch_model.bin", map_location="cpu")
    model.load_state_dict(state_dict)
    # Module.to() is recursive, so model1/model2/etc. move along with it.
    model.to(device)
    return model


# Correctly spelled alias; the original (misspelled) name is kept for callers.
load_audio_denoiser_model = load_audio_denosier_model