import torch
import torchaudio


class AudioPipeline:
    """Audio classification pipeline: load, resample, and classify a clip."""

    def __init__(self, feature_extractor, model, top_k=5):
        self.fe = feature_extractor
        self.model = model
        self.model.eval()  # inference only
        self.top_k = top_k

    def __call__(self, audio_file):
        # Accept either a file path or a file-like object with a .name attribute.
        if isinstance(audio_file, str):
            waveform, sample_rate = torchaudio.load(audio_file)
        else:
            waveform, sample_rate = torchaudio.load(audio_file.name)

        # Downmix multi-channel audio to mono by averaging across channels.
        waveform = waveform.mean(dim=0)

        # Resample to the rate the feature extractor expects.
        if sample_rate != self.fe.sampling_rate:
            transform = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.fe.sampling_rate
            )
            waveform = transform(waveform)

        # Hugging Face feature extractors document NumPy input, so convert.
        inputs = self.fe(
            waveform.numpy(),
            sampling_rate=self.fe.sampling_rate,
            return_tensors="pt",
            padding=True,
        )

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Convert logits to probabilities; clamp k in case the label set
        # is smaller than top_k.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        k = min(self.top_k, probs.numel())
        top_probs, top_ids = torch.topk(probs, k)
        top_labels = [self.model.config.id2label[idx.item()] for idx in top_ids]
        return {label: prob.item() for label, prob in zip(top_labels, top_probs)}
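
# Minimal usage sketch (assumptions): the checkpoint name below is
# illustrative; any Hugging Face audio-classification checkpoint with a
# matching feature extractor should work the same way.
if __name__ == "__main__":
    from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

    checkpoint = "superb/wav2vec2-base-superb-ks"  # example keyword-spotting model
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForAudioClassification.from_pretrained(checkpoint)

    pipeline = AudioPipeline(feature_extractor, model, top_k=3)
    # Prints a {label: probability} dict for the top-3 predictions.
    print(pipeline("sample.wav"))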