|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
from torchvision import transforms |
|
|
|
from src.modeling import ASTPretrained, FeatureExtractor, PreprocessPipeline, StudentAST |
|
|
|
# Directory containing the serialized model weights (*.pth) and threshold (*.npy) files,
# located in a "models" folder next to this module.
MODELS_FOLDER = Path(__file__).parent.joinpath("models")

# Class label codes. A label's position in this list is the index of its score in the
# model output and of its entry in the thresholds array.
CLASSES = [
    "tru",
    "sax",
    "vio",
    "gac",
    "org",
    "cla",
    "flu",
    "voi",
    "gel",
    "cel",
    "pia",
]
|
|
|
|
|
def load_model(model_type: str):
    """
    Loads a pre-trained AST model of the specified type.

    :param model_type: ``"accuracy"`` loads the full :class:`ASTPretrained` model from
        ``acc_model_ast.pth``; any other value loads the smaller :class:`StudentAST`
        model from ``speed_model_ast.pth``.
    :type model_type: str
    :return: The loaded model with its weights restored, switched to evaluation mode.
    :rtype: ASTPretrained or StudentAST
    """
    # One output per known class; derived from CLASSES so the two cannot drift apart.
    n_classes = len(CLASSES)

    if model_type == "accuracy":
        model = ASTPretrained(n_classes=n_classes, download_weights=False)
        weights_path = MODELS_FOLDER / "acc_model_ast.pth"
    else:
        # Any value other than "accuracy" falls back to the speed (student) model.
        model = StudentAST(n_classes=n_classes, hidden_size=192, num_heads=3)
        weights_path = MODELS_FOLDER / "speed_model_ast.pth"

    # map_location remaps stored tensors onto the CPU so loading works on CPU-only hosts.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    # Inference mode: disables training-only layer behavior (e.g. dropout).
    model.eval()
    return model
|
|
|
|
|
def load_labels():
    """
    Loads a dictionary of class labels for the AST model.

    :return: A dictionary mapping each class index (its position in ``CLASSES``, which is
        also its position in the model output) to the class label.
    :rtype: Dict[int, str]
    """
    return dict(enumerate(CLASSES))
|
|
|
|
|
def load_thresholds(model_type: str):
    """
    Loads the per-class prediction thresholds matching a model type.

    :param model_type: ``"accuracy"`` loads ``acc_model_thresh.npy``; any other value
        loads ``speed_model_thresh.npy`` (the same selection rule as ``load_model``).
    :type model_type: str
    :return: The prediction thresholds for each class, in ``CLASSES`` order.
    :rtype: np.ndarray
    """
    filename = "acc_model_thresh.npy" if model_type == "accuracy" else "speed_model_thresh.npy"

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
    # arbitrary code. These files are bundled with the service, but if the stored array is
    # purely numeric, allow_pickle=False would be safer — confirm the saved dtype first.
    return np.load(MODELS_FOLDER / filename, allow_pickle=True)
|
|
|
|
|
class ModelServiceAST:
    """
    Inference wrapper bundling a model, its class labels, per-class thresholds and
    the audio preprocessing pipeline.
    """

    def __init__(self, model_type: str):
        """
        Initializes a ModelServiceAST instance with the specified model type.

        :param model_type: The type of model to load; ``"accuracy"`` selects the accuracy
            model and its thresholds, any other value selects the speed model and its thresholds.
        :type model_type: str
        """
        self.model = load_model(model_type)
        self.labels = load_labels()
        self.thresholds = load_thresholds(model_type)
        # Both stages are configured for 16000 Hz; keep the two sample rates in sync.
        self.transform = transforms.Compose([PreprocessPipeline(target_sr=16000), FeatureExtractor(sr=16000)])

    def get_prediction(self, audio):
        """
        Gets the binary per-class predictions for the given audio input.

        :param audio: The audio input to classify; it is passed unchanged to
            ``self.transform`` (preprocessing followed by feature extraction).
        :return: A dictionary where the keys are the class labels and the values are binary
            predictions: 1 if the class probability is >= that class's threshold, else 0.
        :rtype: Dict[str, int]
        """
        features = self.transform(audio)

        with torch.no_grad():
            # .mT transposes the last two dimensions of the features before the model call.
            # NOTE(review): presumably this converts the extractor's layout into the axis
            # order the model expects — confirm against FeatureExtractor.
            probabilities = torch.sigmoid(self.model(features.mT))

        # Assumes a single input example, so squeezing leaves one probability per class.
        scores = probabilities.squeeze().numpy().astype(float)

        return {
            label: int(scores[index] >= self.thresholds[index])
            for index, label in self.labels.items()
        }
|
|