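"""Gradio app: classify an uploaded photo as one of ten musical
instruments using a ResNet-18 with a custom classification head
(weights loaded from music_model.pth)."""
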
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
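
# The ten instrument classes. NOTE: this order is assumed to match the
# index order the checkpoint was trained with.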
class_names = [
    "accordion",
    "banjo",
    "drum",
    "flute",
    "guitar",
    "harmonica",
    "saxophone",
    "sitar",
    "tabla",
    "violin",
]
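
# Standard ImageNet evaluation preprocessing: resize, center-crop to
# 224x224, then normalize with the ImageNet channel means and stds.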
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
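

# Rebuild the architecture and load the trained weights from disk.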
def load_model(model_path="music_model.pth"):
    # weights=None builds an untrained ResNet-18; all weights come from the
    # checkpoint loaded below.
    model = models.resnet18(weights=None)
    # Swap the stock fc layer for the custom head the checkpoint expects.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(256, 10),
    )
    # map_location='cpu' lets the checkpoint load on machines without a GPU.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()  # inference mode: disables dropout
    return model
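

# Load the model once at import time so every prediction reuses it.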
model = load_model("music_model.pth")
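

# Classify one image and return (top-1 label, {label: confidence}).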
def predict(image):
    # Gradio delivers the upload as a NumPy array; convert to an RGB PIL image.
    image = Image.fromarray(image).convert("RGB")
    img_tensor = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs, 1)
    prediction = class_names[predicted.item()]
    # Softmax over the logits gives per-class confidence scores.
    confidences = torch.nn.functional.softmax(outputs[0], dim=0)
    confidences_dict = {class_names[i]: float(confidences[i]) for i in range(len(class_names))}
    return prediction, confidences_dict
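

# Wire predict() into a simple Gradio UI: one image in, a predicted label
# plus per-class confidence scores out.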
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy", label="Upload Instrument Image"),
    outputs=[
        gr.Label(label="Predicted Instrument"),
        gr.Label(label="Confidence Scores"),
    ],
    title="🎵 Musical Instrument Classifier",
    description="Upload an image of a musical instrument and get the predicted class (accordion, guitar, etc.).",
)
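

# Start a local web server when the script is run directly. Passing
# share=True to launch() would also create a temporary public URL.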
if __name__ == "__main__":
    interface.launch()