Karlo Pintaric
Upload 25 files
fdc1efd
from pathlib import Path
import numpy as np
import torch
from torchvision import transforms
from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST
# Directory holding the serialized model weights and threshold arrays; it is
# resolved relative to this source file rather than the working directory.
MODELS_FOLDER = Path(__file__).parent.joinpath("models")

# Class labels in model-output order: index i of the model output corresponds
# to CLASSES[i]. (Presumably abbreviated instrument names -- TODO confirm
# against the training dataset.)
CLASSES = [
    "tru",
    "sax",
    "vio",
    "gac",
    "org",
    "cla",
    "flu",
    "voi",
    "gel",
    "cel",
    "pia",
]
def load_model(model_type: str):
    """
    Builds an AST model, restores its weights from ``MODELS_FOLDER`` and puts it in eval mode.

    ``"accuracy"`` selects ``ASTPretrained`` with ``acc_model_ast.pth``; any other
    value selects ``StudentAST`` with ``speed_model_ast.pth``.

    :param model_type: The type of model to load
    :type model_type: str
    :return: The model with loaded weights, switched to evaluation mode.
    :rtype: ASTPretrained or StudentAST
    """
    if model_type == "accuracy":
        net = ASTPretrained(n_classes=11, download_weights=False)
        weights_path = f"{MODELS_FOLDER}/acc_model_ast.pth"
    else:
        net = StudentAST(n_classes=11, hidden_size=192, num_heads=3)
        weights_path = f"{MODELS_FOLDER}/speed_model_ast.pth"

    # Weights are always mapped onto the CPU, regardless of the device they were saved from.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    net.load_state_dict(state_dict)
    net.eval()
    return net
def load_labels():
    """
    Builds the index-to-label lookup for the AST model from ``CLASSES``.

    :return: A dictionary mapping each class index to its class label.
    :rtype: Dict[int, str]
    """
    return dict(enumerate(CLASSES))
def load_thresholds(model_type: str):
    """
    Loads the prediction thresholds stored alongside the selected AST model.

    ``"accuracy"`` reads ``acc_model_thresh.npy``; any other value reads
    ``speed_model_thresh.npy``.

    :param model_type: The type of model whose thresholds should be loaded
    :type model_type: str
    :return: The prediction thresholds as stored in the ``.npy`` file.
    :rtype: np.ndarray
    """
    file_name = "acc_model_thresh.npy" if model_type == "accuracy" else "speed_model_thresh.npy"
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code -- only safe while these files are trusted.
    return np.load(f"{MODELS_FOLDER}/{file_name}", allow_pickle=True)
class ModelServiceAST:
    """
    Bundles a loaded AST model with its labels, thresholds and preprocessing
    transform, and turns an audio input into per-class binary predictions.
    """

    def __init__(self, model_type: str):
        """
        Sets up the service: loads the model, labels and thresholds, and builds the transform.

        :param model_type: The type of model to load
        :type model_type: str
        """
        self.model = load_model(model_type)
        self.labels = load_labels()
        self.thresholds = load_thresholds(model_type)
        # Preprocessing and feature extraction are both configured for 16 kHz.
        preprocessing_steps = [PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)]
        self.transform = transforms.Compose(preprocessing_steps)

    def get_prediction(self, audio):
        """
        Computes thresholded per-class predictions for the given audio input.

        :param audio: The audio input, passed as-is to the preprocessing transform.
        :type audio: whatever ``PreprocessPipeline`` accepts (presumably a file object -- TODO confirm)
        :return: A dictionary where the keys are the class labels and the values are binary predictions (0 or 1).
        :rtype: Dict[str, int]
        """
        features = self.transform(audio)
        with torch.no_grad():
            # The model is fed seq_len x num_features, hence the transpose of
            # the last two dimensions before the forward pass.
            logits = self.model(features.mT)
            probabilities = torch.sigmoid(logits)
        scores = probabilities.squeeze().numpy().astype(float)
        # A class is predicted (1) when its score reaches that class's own threshold.
        return {label: int(scores[idx] >= self.thresholds[idx]) for idx, label in enumerate(CLASSES)}