import io

import torch
import torchaudio

from models import StudentForAudioClassification


class EndpointHandler:
    """Classify an audio clip with a student model fed by wav2vec 2.0 features.

    A request carries encoded audio bytes under ``data["inputs"]``. The audio
    is decoded, downmixed to mono, resampled to the encoder's sample rate,
    embedded with the torchaudio ``WAV2VEC2_BASE`` model, averaged into a
    single vector and passed to the student classifier.
    """

    # Number of leading embedding dimensions handed to the student classifier.
    # NOTE(review): the pooled wav2vec 2.0 vector is sliced to this width, so
    # any dimensions beyond it are silently dropped -- confirm that this
    # matches how the student model was trained.
    EMBEDDING_DIM = 512

    def __init__(self, model_dir, *args, **kwargs):
        """Load the student classifier and the wav2vec 2.0 encoder.

        Args:
            model_dir: Location passed to
                ``StudentForAudioClassification.from_pretrained``.
            *args: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        self.model = StudentForAudioClassification.from_pretrained(
            model_dir, trust_remote_code=True
        )
        self.model.eval()

        bundle = torchaudio.pipelines.WAV2VEC2_BASE
        # Take the sample rate from the bundle instead of hard-coding it so
        # the resampling target always follows the encoder in use.
        self._target_sr = int(bundle.sample_rate)
        self.w2v_model = bundle.get_model()
        self.w2v_model.eval()

    def __call__(self, data):
        """Run inference on one request.

        Args:
            data: Mapping whose ``"inputs"`` entry holds the encoded audio as
                a bytes-like object.

        Returns:
            A dict with ``"probabilities"`` (the softmax output as a list)
            and ``"label"`` (the index of the highest probability, as int).

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
            TypeError: If ``data["inputs"]`` is not bytes-like.
        """
        try:
            payload = data["inputs"]
        except KeyError as err:
            raise KeyError('request payload has no "inputs" field') from err

        waveform = self._load_waveform(payload)

        # Pure inference: no autograd bookkeeping is needed for any step.
        with torch.no_grad():
            embedding = self._embed(waveform)
            logits = self.model(embedding).logits
            probs = torch.softmax(logits, dim=-1)

        return {
            "probabilities": probs.squeeze(0).tolist(),
            "label": int(probs.argmax(dim=-1)[0]),
        }

    def _load_waveform(self, payload):
        """Decode audio bytes into a mono waveform at the encoder sample rate."""
        try:
            stream = io.BytesIO(payload)
        except TypeError as err:
            raise TypeError(
                '"inputs" must be a bytes-like object, got '
                f"{type(payload).__name__}"
            ) from err

        waveform, orig_sr = torchaudio.load(stream)
        # Average over the first axis to downmix; keepdim retains that axis
        # with size one so the encoder still receives a 2-D tensor.
        waveform = waveform.mean(dim=0, keepdim=True)
        if orig_sr != self._target_sr:
            resampler = torchaudio.transforms.Resample(orig_sr, self._target_sr)
            waveform = resampler(waveform)
        return waveform

    def _embed(self, waveform):
        """Return the pooled, truncated wav2vec 2.0 embedding of ``waveform``."""
        # The encoder returns a tuple; its first element holds the features.
        features = self.w2v_model(waveform)[0]
        # Collapse axis 1 (presumably the frame axis -- see the torchaudio
        # Wav2Vec2Model docs) into one vector per clip.
        pooled = features.mean(dim=1)
        return pooled[:, : self.EMBEDDING_DIM]