# NOTE: web-page header residue captured with this file ("Spaces: Runtime error"); not part of the module.
from typing import Dict | |
from src.models import ( | |
lcnn, | |
specrnet, | |
whisper_specrnet, | |
rawnet3, | |
whisper_lcnn, | |
meso_net, | |
whisper_meso_net | |
) | |
def get_model(model_name: str, config: Dict, device: str):
    """Instantiate the model registered under ``model_name``.

    Args:
        model_name: Exact, case-sensitive key selecting the architecture.
        config: Model options. ``lcnn`` and ``specrnet`` receive the whole
            mapping expanded as keyword arguments; ``rawnet3`` ignores it;
            every other model reads individual keys with a fallback value.
        device: Passed as ``device=`` to every constructor except
            ``rawnet3.prepare_model``, which is called with no arguments.

    Returns:
        The object produced by the selected constructor or factory.

    Raises:
        ValueError: If ``model_name`` matches none of the known keys.
    """

    def option(key, fallback):
        # Single place for "read from config, else use the fallback".
        return config.get(key, fallback)

    # Each entry is a zero-argument builder so that nothing is looked up
    # or constructed until the requested name has actually matched.
    builders = {
        "rawnet3": lambda: rawnet3.prepare_model(),
        "lcnn": lambda: lcnn.FrontendLCNN(device=device, **config),
        "specrnet": lambda: specrnet.FrontendSpecRNet(device=device, **config),
        "mesonet": lambda: meso_net.FrontendMesoInception4(
            input_channels=option("input_channels", 1),
            fc1_dim=option("fc1_dim", 1024),
            frontend_algorithm=option("frontend_algorithm", "lfcc"),
            device=device,
        ),
        "whisper_lcnn": lambda: whisper_lcnn.WhisperLCNN(
            input_channels=option("input_channels", 1),
            freeze_encoder=option("freeze_encoder", False),
            device=device,
        ),
        "whisper_specrnet": lambda: whisper_specrnet.WhisperSpecRNet(
            input_channels=option("input_channels", 1),
            freeze_encoder=option("freeze_encoder", False),
            device=device,
        ),
        "whisper_mesonet": lambda: whisper_meso_net.WhisperMesoNet(
            input_channels=option("input_channels", 1),
            freeze_encoder=option("freeze_encoder", True),
            fc1_dim=option("fc1_dim", 1024),
            device=device,
        ),
        "whisper_frontend_lcnn": lambda: whisper_lcnn.WhisperMultiFrontLCNN(
            input_channels=option("input_channels", 2),
            freeze_encoder=option("freeze_encoder", False),
            frontend_algorithm=option("frontend_algorithm", "lfcc"),
            device=device,
        ),
        "whisper_frontend_specrnet": lambda: whisper_specrnet.WhisperMultiFrontSpecRNet(
            input_channels=option("input_channels", 2),
            freeze_encoder=option("freeze_encoder", False),
            frontend_algorithm=option("frontend_algorithm", "lfcc"),
            device=device,
        ),
        "whisper_frontend_mesonet": lambda: whisper_meso_net.WhisperMultiFrontMesoNet(
            input_channels=option("input_channels", 2),
            fc1_dim=option("fc1_dim", 1024),
            freeze_encoder=option("freeze_encoder", True),
            frontend_algorithm=option("frontend_algorithm", "lfcc"),
            device=device,
        ),
    }

    build = builders.get(model_name)
    if build is None:
        raise ValueError(f"Model '{model_name}' not supported")
    return build()