forwarder1121's picture
Update handler.py
7b9e393 verified
from models import StudentForAudioClassification
import torch
import torchaudio
class EndpointHandler:
    """Callable that classifies raw audio bytes with a student model.

    A request is decoded with torchaudio, mixed down to a single channel,
    resampled when needed, encoded with the torchaudio ``WAV2VEC2_BASE``
    model, pooled into one vector and passed to the
    ``StudentForAudioClassification`` model loaded from ``model_dir``.
    """

    # Sample rate the waveform is converted to before encoding.
    _TARGET_SAMPLE_RATE = 16000
    # Number of leading feature channels handed to the classifier.
    # NOTE(review): presumably the input width the student model was trained
    # with -- confirm against the model definition.
    _NUM_FEATURES = 512

    def __init__(self, model_dir, *args, **kwargs):
        """Load the classifier and the wav2vec2 encoder in eval mode.

        Args:
            model_dir: Location passed to
                ``StudentForAudioClassification.from_pretrained``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        classifier = StudentForAudioClassification.from_pretrained(
            model_dir, trust_remote_code=True
        )
        classifier.eval()
        self.model = classifier

        encoder = torchaudio.pipelines.WAV2VEC2_BASE.get_model()
        encoder.eval()
        self.w2v_model = encoder

    def __call__(self, data):
        """Run inference on one encoded audio payload.

        Args:
            data: Mapping whose ``"inputs"`` entry holds the bytes of an
                audio file that ``torchaudio.load`` can decode.

        Returns:
            A dict with ``"probabilities"`` (softmax scores as a list of
            floats) and ``"label"`` (index of the highest score, as an int).
        """
        import io

        audio_buffer = io.BytesIO(data["inputs"])
        signal, sample_rate = torchaudio.load(audio_buffer)

        # Average over dim 0 (the channel axis in torchaudio.load's default
        # channels-first layout) so the encoder sees a single channel.
        signal = signal.mean(dim=0, keepdim=True)

        target_rate = self._TARGET_SAMPLE_RATE
        if sample_rate != target_rate:
            to_target_rate = torchaudio.transforms.Resample(sample_rate, target_rate)
            signal = to_target_rate(signal)

        with torch.no_grad():
            # The encoder returns a tuple; only its first element is used.
            encoded = self.w2v_model(signal)[0]
            # Pool over dim 1 (presumably the frame/time axis -- confirm),
            # then keep only the leading feature channels.
            pooled = encoded.mean(dim=1)
            pooled = pooled[:, : self._NUM_FEATURES]
            logits = self.model(pooled).logits
            probs = torch.softmax(logits, dim=-1)

        best_index = probs.argmax(dim=-1)[0]
        return {
            "probabilities": probs.squeeze(0).tolist(),
            "label": int(best_index),
        }