# Author: HugoHE
# Last change: "Update app.py" (commit ea9241e)
import functools

import gradio as gr
import pytorch_lightning as pl
import torch
import torchvision
from PIL import Image
from pytorch_lightning import LightningModule, Trainer
from torch.nn import functional as F
from torchvision import models, transforms
# Human-readable names for the classifier's outputs: classes[i] is the name
# reported for output index i. The order appears to follow the GTSRB (German
# Traffic Sign Recognition Benchmark) label IDs 0-42; it must stay aligned
# with the label order the loaded checkpoint was trained with.
classes = [
    'Speed limit (20km/h)',           # 0
    'Speed limit (30km/h)',           # 1
    'Speed limit (50km/h)',           # 2
    'Speed limit (60km/h)',           # 3
    'Speed limit (70km/h)',           # 4
    'Speed limit (80km/h)',           # 5
    'End of speed limit (80km/h)',    # 6
    'Speed limit (100km/h)',          # 7
    'Speed limit (120km/h)',          # 8
    'No passing',                     # 9
    'No passing veh over 3.5 tons',   # 10
    'Right-of-way at intersection',   # 11
    'Priority road',                  # 12
    'Yield',                          # 13
    'Stop',                           # 14
    'No vehicles',                    # 15
    'Veh > 3.5 tons prohibited',      # 16
    'No entry',                       # 17
    'General caution',                # 18
    'Dangerous curve left',           # 19
    'Dangerous curve right',          # 20
    'Double curve',                   # 21
    'Bumpy road',                     # 22
    'Slippery road',                  # 23
    'Road narrows on the right',      # 24
    'Road work',                      # 25
    'Traffic signals',                # 26
    'Pedestrians',                    # 27
    'Children crossing',              # 28
    'Bicycles crossing',              # 29
    'Beware of ice/snow',             # 30
    'Wild animals crossing',          # 31
    'End speed + passing limits',     # 32
    'Turn right ahead',               # 33
    'Turn left ahead',                # 34
    'Ahead only',                     # 35
    'Go straight or right',           # 36
    'Go straight or left',            # 37
    'Keep right',                     # 38
    'Keep left',                      # 39
    'Roundabout mandatory',           # 40
    'End of no passing',              # 41
    'End no passing veh > 3.5 tons',  # 42
]
class LitGTSRB(pl.LightningModule):
    """Lightning wrapper around a ResNet-18 traffic-sign classifier.

    The backbone is created without pretrained weights; the app obtains the
    trained parameters through ``LitGTSRB.load_from_checkpoint``.

    Args:
        num_classes: Number of output classes of the final layer. Defaults to
            43, matching the 43 entries of the module-level ``classes`` list.
            Lightning's ``load_from_checkpoint`` instantiates the class with
            its defaults when the checkpoint stores no hyper-parameters, so
            existing checkpoints keep loading unchanged.
    """

    def __init__(self, num_classes: int = 43):
        super().__init__()
        # pretrained=False: weights come from the checkpoint, not ImageNet.
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour
        # of ``weights=None``; switch once the minimum supported version allows.
        self.model = models.resnet18(pretrained=False, num_classes=num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Batch of image tensors; presumably (N, 3, H, W) as produced by
                the app's preprocessing — TODO confirm for other callers.

        Returns:
            Log-softmax scores computed over dim 1 (the class dimension).
        """
        logits = self.model(x)
        return F.log_softmax(logits, dim=1)
# Preprocessing for every uploaded image: resize to 224x224, convert to a
# tensor and normalise with the per-channel mean/std below. Built once at
# import time instead of on every request.
_TEST_TRANSFORMS = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the classifier from ``resnet18.ckpt`` once and switch it to eval mode."""
    # load_from_checkpoint is a classmethod: calling it on the class avoids
    # constructing an extra, immediately discarded ResNet-18 instance.
    model = LitGTSRB.load_from_checkpoint('resnet18.ckpt')
    model.eval()
    return model


def predict_image(image):
    """Classify a traffic-sign image and return the top-5 predictions.

    The checkpoint is loaded on the first call and cached for later calls,
    rather than re-read from disk for every request.

    Args:
        image: PIL image (the Gradio input is declared with ``type='pil'``),
            or ``None`` when the user submits without an image.

    Returns:
        Dict mapping up to five class names to their softmax probabilities,
        in descending order of probability. Empty dict when ``image`` is None.
    """
    if image is None:
        # Nothing to classify; an empty dict clears the label output instead
        # of raising AttributeError on ``None.convert``.
        return {}

    model = _load_model()

    # convert('RGB') guarantees 3 channels (e.g. for RGBA or grayscale uploads),
    # matching the 3-element mean/std used by Normalize.
    batch = _TEST_TRANSFORMS(image.convert('RGB')).unsqueeze(0)

    with torch.no_grad():
        log_probs = model(batch)

    # The model returns log-softmax scores; exp() turns them into probabilities.
    probs = torch.exp(log_probs.cpu().squeeze(0))
    top_k = min(5, probs.numel())
    scores, indices = torch.topk(probs, top_k)
    return {classes[idx]: float(score)
            for idx, score in zip(indices.tolist(), scores.tolist())}
# Gradio UI wiring: one image input (delivered to the handler as a PIL image)
# and a label panel that shows the class -> probability dict it returns.
image = gr.Image(type='pil')
label = gr.Label()

# Clickable example inputs: files 1.png ... 8.png, given as relative paths.
examples = [f'{n}.png' for n in range(1, 9)]

intf = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    examples=examples,
)
# Start the app; inline display only applies in notebook environments.
intf.launch(inline=True)