|
import gradio as gr |
|
import torch |
|
from torchvision import transforms, models |
|
from PIL import Image |
|
import torch.nn as nn |
|
|
|
|
|
class CustomEfficientNet(nn.Module):
    """EfficientNet-B0 backbone followed by a configurable MLP classification head.

    The backbone's own classifier is replaced by ``nn.Identity`` so that the
    backbone returns its feature vector, which is then fed through
    ``num_layers`` hidden blocks (Linear -> ReLU -> Dropout(0.5)) and a final
    Linear layer producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of hidden (Linear -> ReLU -> Dropout) blocks; may be 0,
            in which case the output layer connects directly to the backbone.
        neurons_per_layer: Width of each hidden layer.
        pretrained: Whether to initialise the backbone with torchvision's
            pretrained weights. Defaults to True (the original behaviour); pass
            False when a checkpoint will overwrite every weight anyway, to avoid
            an unnecessary download.
    """

    def __init__(self, num_classes, num_layers, neurons_per_layer, pretrained=True):
        super().__init__()

        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
        # favour of ``weights=``; kept here for compatibility with older versions.
        self.base_model = models.efficientnet_b0(pretrained=pretrained)
        in_features = self.base_model.classifier[1].in_features
        # Drop the stock classifier so the backbone outputs raw features.
        self.base_model.classifier = nn.Identity()

        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(in_features, neurons_per_layer))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.5))
            in_features = neurons_per_layer

        # Use the running ``in_features`` (not ``neurons_per_layer``) so the head
        # is also correct when ``num_layers == 0``. For num_layers > 0 the two
        # are equal, so existing checkpoints load unchanged.
        layers.append(nn.Linear(in_features, num_classes))
        self.custom_classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        x = self.base_model(x)
        # Flatten everything after the batch dimension before the MLP head;
        # unlike ``view`` this also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        return self.custom_classifier(x)
|
|
|
def create_model(num_classes, num_layers, neurons_per_layer):
    """Construct a new ``CustomEfficientNet`` with the given head configuration.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of hidden layers in the classification head.
        neurons_per_layer: Width of each hidden layer.

    Returns:
        A freshly constructed ``CustomEfficientNet`` instance.
    """
    return CustomEfficientNet(
        num_classes=num_classes,
        num_layers=num_layers,
        neurons_per_layer=neurons_per_layer,
    )
|
|
|
def load_model(path, num_classes, num_layers, neurons_per_layer):
    """Build the classifier, load a saved state dict from ``path``, and set eval mode.

    The checkpoint tensors are mapped onto the CPU, so the file can be loaded on a
    machine without a GPU. The architecture arguments must match the ones the
    checkpoint was saved with; otherwise ``load_state_dict`` raises.

    Args:
        path: Filesystem path to a ``state_dict`` checkpoint.
        num_classes: Number of output logits.
        num_layers: Number of hidden layers in the classification head.
        neurons_per_layer: Width of each hidden layer.

    Returns:
        The model with weights loaded, switched to ``eval()`` mode (dropout off).
    """
    net = create_model(num_classes, num_layers, neurons_per_layer)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources (consider weights_only=True on torch versions that support it).
    cpu = torch.device('cpu')
    state_dict = torch.load(path, map_location=cpu)
    net.load_state_dict(state_dict)

    net.eval()
    return net
|
|
|
|
|
# Human-readable labels, indexed by the model's output position. The list is in
# plain lexicographic (ASCII) order: the capitalised suits sort before the
# lowercase "carreau", and "10" sorts before "2". This presumably mirrors how
# labels were indexed at training time (e.g. torchvision ImageFolder); confirm
# before reordering or renaming any entry.
class_names = ['Coeur 1', 'Coeur 10', 'Coeur 2', 'Coeur 3', 'Coeur 4', 'Coeur 5', 'Coeur 6',
               'Coeur 7', 'Coeur 8', 'Coeur 9', 'Coeur Dame', 'Coeur Roi', 'Coeur Valet', 'Pique 1',
               'Pique 10', 'Pique 2', 'Pique 3', 'Pique 4', 'Pique 5', 'Pique 6', 'Pique 7', 'Pique 8',
               'Pique 9', 'Pique Dame', 'Pique Roi', 'Pique Valet', 'Trefle 1', 'Trefle 10', 'Trefle 2',
               'Trefle 3', 'Trefle 4', 'Trefle 5', 'Trefle 6', 'Trefle 7', 'Trefle 8', 'Trefle 9', 'Trefle Dame',
               'Trefle Roi', 'Trefle Valet', 'carreau 1', 'carreau 10', 'carreau 2', 'carreau 3', 'carreau 4', 'carreau 5',
               'carreau 6', 'carreau 7', 'carreau 8', 'carreau 9', 'carreau Dame', 'carreau Roi', 'carreau Valet']

# Head hyper-parameters. These must match the architecture the checkpoint was
# trained with; otherwise load_state_dict raises on a shape mismatch.
# num_classes is derived from class_names (52 entries) so the head size and the
# label list cannot silently drift apart.
num_classes = len(class_names)
num_layers = 3
neurons_per_layer = 1024

model = load_model('card_classification_model.pth', num_classes, num_layers, neurons_per_layer)

# Preprocessing: resize to 224x224, convert the PIL image to a float tensor in
# [0, 1], then normalise each channel with the standard ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
|
|
|
def predict(image):
    """Classify one uploaded card image and return its predicted label.

    Args:
        image: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``class_names`` whose logit is highest.

    Raises:
        gr.Error: if no image was provided (Gradio shows the message to the user).
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Depending on the Gradio version and ``image_mode``, uploads may arrive as
    # RGBA, palette or grayscale images. The 3-channel Normalize in ``transform``
    # fails on those, so force RGB first.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(batch)

    predicted_index = logits.argmax(dim=1)[0].item()
    return class_names[predicted_index]
|
|
|
|
|
# Single-image web UI: the upload is passed to ``predict`` as a PIL image and the
# returned class name is displayed in a Label component.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    description="Upload an image to classify",
)

# Only start the server when run as a script (e.g. ``python app.py``), so that
# importing this module to reuse ``predict`` or ``model`` does not block on
# launching the web app.
if __name__ == "__main__":
    iface.launch()
|
|