|
from datasets import load_dataset, Audio
from transformers import pipeline
import torchaudio
|
# Zero-shot audio classifier built on the CLAP model; CLAP operates on
# 48 kHz audio, so all inputs below are resampled to that rate.
zero_shot_classifier = pipeline(
    task="zero-shot-audio-classification",
    model="laion/clap-htsat-unfused",
)
|
# Free-form candidate labels; the classifier scores each clip against these.
candidate_labels = [
    "Sound of a dog barking",
    "Sound of a car driving",
    "Sound of a person talking",
    "Sound of a bird singing",
    "Sound of a plane flying",
]
|
def audio_dataset_inference():
    """Run zero-shot classification on a sample from the ESC-50 dataset."""
    # Load a ten-clip slice of the ESC-50 environmental-sound dataset.
    dataset = load_dataset("ashraq/esc50", split="train[0:10]")

    # Resample the audio column to the 48 kHz rate CLAP expects.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=48_000))

    audio_sample = dataset[0]

    result = zero_shot_classifier(
        audio_sample["audio"]["array"],
        candidate_labels=candidate_labels,
    )
    print(result)
|
|
def classify_audio(audio_file):
    """
    Perform zero-shot classification on a single audio file.

    Args:
        audio_file (str): Path to the audio file to classify.

    Returns:
        dict: Classification labels and their corresponding scores.
    """
    try:
        waveform, sample_rate = torchaudio.load(audio_file)

        # CLAP expects 48 kHz input; resample if the file uses another rate.
        if sample_rate != 48_000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=48_000)
            waveform = resampler(waveform)

        # Downmix multi-channel audio to mono. Note that squeeze() alone
        # would leave a 2-D array for stereo files, which the pipeline
        # cannot handle.
        audio_array = waveform.mean(dim=0).numpy()

        result = zero_shot_classifier(
            audio_array,
            candidate_labels=candidate_labels,
        )
        return {label["label"]: label["score"] for label in result}
    except Exception as e:
        print(f"Error in classify_audio: {e}")
        return {"Error": str(e)}
|
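# Minimal usage sketch (an assumption, not part of the original script):
# runs the dataset demo, then classifies a local file. "dog_bark.wav" is a
# hypothetical path used purely for illustration.
if __name__ == "__main__":
    # Classify a sample straight from the ESC-50 dataset.
    audio_dataset_inference()

    # Classify a local file; replace the path with a real audio file.
    scores = classify_audio("dog_bark.wav")  # hypothetical path
    print(scores)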