Karlo Pintaric
Upload 25 files
fdc1efd
raw
history blame
No virus
3.12 kB
from pathlib import Path
import numpy as np
import torch
from torchvision import transforms
from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST
# Directory next to this module that holds the model checkpoints (*.pth)
# and the per-class threshold arrays (*.npy).
MODELS_FOLDER = Path(__file__).parent.joinpath("models")

# Class codes in model-output order: element i of the classifier output is
# scored against CLASSES[i].
# NOTE(review): these look like IRMAS instrument abbreviations — confirm.
CLASSES = [
    "tru", "sax", "vio", "gac", "org", "cla",
    "flu", "voi", "gel", "cel", "pia",
]
def load_model(model_type: str):
    """
    Loads a pre-trained AST model of the specified type.

    :param model_type: Which model to load. ``"accuracy"`` selects
        ``ASTPretrained`` with weights from ``acc_model_ast.pth``; any other
        value selects the smaller ``StudentAST`` with weights from
        ``speed_model_ast.pth``.
    :type model_type: str
    :return: The loaded model, already switched to evaluation mode.
    :rtype: ASTPretrained or StudentAST
    """
    # Derive the head size from the label list so the two cannot drift apart.
    n_classes = len(CLASSES)
    if model_type == "accuracy":
        model = ASTPretrained(n_classes=n_classes, download_weights=False)
        checkpoint_path = MODELS_FOLDER / "acc_model_ast.pth"
    else:
        model = StudentAST(n_classes=n_classes, hidden_size=192, num_heads=3)
        checkpoint_path = MODELS_FOLDER / "speed_model_ast.pth"
    # map_location forces every stored tensor onto the CPU, whatever device it
    # was saved from.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_labels():
    """
    Loads a dictionary of class labels for the AST model.

    A new dict is built from ``CLASSES`` on every call, so callers may mutate
    the result without affecting other callers.

    :return: A dictionary mapping each class index (its position in
        ``CLASSES``) to the corresponding class label.
    :rtype: Dict[int, str]
    """
    return dict(enumerate(CLASSES))
def load_thresholds(model_type: str):
    """
    Loads the per-class prediction thresholds for the AST model.

    :param model_type: Which model's thresholds to load. ``"accuracy"`` selects
        ``acc_model_thresh.npy``; any other value selects
        ``speed_model_thresh.npy``.
    :type model_type: str
    :return: The prediction thresholds for each class.
    :rtype: np.ndarray
    """
    if model_type == "accuracy":
        thresholds_path = MODELS_FOLDER / "acc_model_thresh.npy"
    else:
        thresholds_path = MODELS_FOLDER / "speed_model_thresh.npy"
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code if the file is tampered with. If these files
    # hold plain numeric arrays, allow_pickle=False would be safer — confirm
    # how they were saved before changing this.
    return np.load(thresholds_path, allow_pickle=True)
class ModelServiceAST:
    """Bundles an AST model, its class labels and per-class thresholds to produce binary predictions."""

    def __init__(self, model_type: str):
        """
        Initializes a ModelServiceAST instance with the specified model type.

        :param model_type: The type of model to load; it is passed to both
            ``load_model`` and ``load_thresholds`` so the weights and the
            thresholds always come from the same model variant.
        :type model_type: str
        """
        self.model = load_model(model_type)
        # Maps class index -> class label.
        self.labels = load_labels()
        self.thresholds = load_thresholds(model_type)
        # Both stages are configured for 16 kHz; keep the two rates in sync.
        self.transform = transforms.Compose([PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)])

    def get_prediction(self, audio):
        """
        Gets the binary per-class predictions for the given audio.

        :param audio: The input audio, in whatever form ``self.transform``
            accepts (its first stage is ``PreprocessPipeline``).
        :return: A dictionary where the keys are the class labels and the values
            are binary predictions: 1 when the sigmoid score for that class is
            at or above its threshold, otherwise 0.
        :rtype: Dict[str, int]
        """
        features = self.transform(audio)
        with torch.no_grad():
            # .mT swaps the last two dimensions of the *input* features so the
            # model receives seq_len x num_features.
            # NOTE(review): presumably FeatureExtractor emits
            # num_features x seq_len — confirm.
            scores = torch.sigmoid(self.model(features.mT))
        scores = scores.squeeze().numpy().astype(float)
        # Iterate self.labels (index -> label) so the attribute actually drives
        # which classes are reported.
        return {
            label: int(scores[index] >= self.thresholds[index])
            for index, label in self.labels.items()
        }