|
import time

import gradio as gr
from transformers import pipeline
|
|
|
# Audio-classification pipeline, built once at module load so that every
# call to the handler below reuses the same instance.
pipe = pipeline(
    task="audio-classification",
    model="juangtzi/wav2vec2-base-finetuned-gtzan",
)
|
|
|
def classify_audio(filepath):
    """Classify the genre of an audio file and time the prediction.

    Args:
        filepath: Path to the audio file to classify, as delivered by the
            ``gr.Audio(type="filepath")`` input. ``None`` (no audio
            provided) is rejected with a user-facing error.

    Returns:
        A ``(scores, elapsed)`` tuple. ``scores`` maps each label returned
        by the pipeline to its score; ``elapsed`` is the number of seconds
        spent running the pipeline and collecting its results.

    Raises:
        gr.Error: If ``filepath`` is ``None``.
    """
    if filepath is None:
        # Surface a readable message instead of an opaque pipeline failure.
        raise gr.Error("Please upload or record an audio clip first.")

    # perf_counter() is monotonic and high-resolution; time.time() follows
    # the wall clock and can jump if the system time is adjusted.
    started = time.perf_counter()
    scores = {pred["label"]: pred["score"] for pred in pipe(filepath)}
    elapsed = time.perf_counter() - started

    return scores, elapsed
|
|
|
|
|
# Page heading shown above the demo. The leading emoji is written as a
# named escape so the source stays pure ASCII and cannot be garbled by a
# wrong-encoding round trip.
title = "\N{MUSICAL NOTE} Music Genre Classifier"

# Markdown blurb rendered under the title.
description = """
Music Genre Classifier model (Fine-tuned "facebook/wav2vec2-base") Dataset: [GTZAN](https://huggingface.co/datasets/marsyas/gtzan)
"""
|
|
|
|
|
# Web UI: one audio input (handed to the handler as a file path) and two
# outputs matching the handler's return tuple -- label scores and timing.
demo = gr.Interface(
    fn=classify_audio,
    inputs=gr.Audio(type="filepath"),
    outputs=[gr.Label(), gr.Number(label="Prediction time (s)")],
    examples="./example",
    title=title,
    description=description,
    allow_flagging="never",
)

# Turn on request queuing, then start the server with sharing enabled.
demo.queue()
demo.launch(share=True)