|
import gradio as gr |
|
import torch |
|
from transformers import pipeline |
|
|
|
# Account name and model name are joined into the "<username>/<model>"
# identifier that is handed to the transformers pipeline below.
username = "ardneebwar"
model_id = username + "/distilhubert-finetuned-gtzan"

# Use the first CUDA device when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Audio-classification pipeline shared by the request handler; built once at
# import time so it is not re-created per call.
pipe = pipeline(
    task="audio-classification",
    model=model_id,
    device=device,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_audio(filepath):
    """Classify the audio file at *filepath* and return label -> score.

    The pipeline yields a list of per-label records, e.g.

        [{'score': 0.8339303731918335, 'label': 'country'},
         {'score': 0.11914275586605072, 'label': 'rock'}]

    which is flattened into a single mapping:

        {"country": 0.8339303731918335, "rock": 0.11914275586605072}
    """
    return {entry["label"]: entry["score"] for entry in pipe(filepath)}
|
|
|
|
|
# Page heading shown above the interface.
title = "🎵 Music Genre Classifier"

# Markdown blurb rendered under the title.
description = """
Music Genre Classifier model (Fine-tuned "ntu-spml/distilhubert") Dataset: [GTZAN](https://huggingface.co/datasets/marsyas/gtzan)
"""


def _classify_with_timing(filepath):
    """Run ``classify_audio`` and also measure how long the call took.

    The interface below declares two output components (label scores and a
    prediction time), so the handler must return two values.
    ``classify_audio`` itself keeps returning only the label -> score dict.

    Returns:
        A ``(label_scores, elapsed_seconds)`` tuple, with the elapsed time
        rounded to milliseconds.
    """
    import time  # local import so this block does not rely on file-level imports

    started = time.perf_counter()
    label_scores = classify_audio(filepath)
    elapsed_seconds = round(time.perf_counter() - started, 3)
    return label_scores, elapsed_seconds


# Example inputs: Gradio expects one list per example row, one entry per
# input component, hence the nested list of relative paths.
filenames = ['rock-it-21275.mp3']
filenames = [[f"./{f}"] for f in filenames]

demo = gr.Interface(
    fn=_classify_with_timing,
    inputs=gr.Audio(type="filepath"),
    # gr.Label replaces the legacy gr.outputs.Label namespace.
    outputs=[gr.Label(), gr.Number(label="Prediction time (s)")],
    title=title,
    description=description,
    examples=filenames,
)

demo.launch()
|
|