# GTSRB traffic-sign classifier: a PyTorch Lightning ResNet-18 served through Gradio.
import torch
from torch.nn import functional as F
from torchvision import transforms, models
import pytorch_lightning as pl
from PIL import Image
import gradio as gr
# The 43 GTSRB class names, index-aligned with the model's output logits.
classes = ['Speed limit (20km/h)',
           'Speed limit (30km/h)',
           'Speed limit (50km/h)',
           'Speed limit (60km/h)',
           'Speed limit (70km/h)',
           'Speed limit (80km/h)',
           'End of speed limit (80km/h)',
           'Speed limit (100km/h)',
           'Speed limit (120km/h)',
           'No passing',
           'No passing veh over 3.5 tons',
           'Right-of-way at intersection',
           'Priority road',
           'Yield',
           'Stop',
           'No vehicles',
           'Veh > 3.5 tons prohibited',
           'No entry',
           'General caution',
           'Dangerous curve left',
           'Dangerous curve right',
           'Double curve',
           'Bumpy road',
           'Slippery road',
           'Road narrows on the right',
           'Road work',
           'Traffic signals',
           'Pedestrians',
           'Children crossing',
           'Bicycles crossing',
           'Beware of ice/snow',
           'Wild animals crossing',
           'End speed + passing limits',
           'Turn right ahead',
           'Turn left ahead',
           'Ahead only',
           'Go straight or right',
           'Go straight or left',
           'Keep right',
           'Keep left',
           'Roundabout mandatory',
           'End of no passing',
           'End no passing veh > 3.5 tons']
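
# Sanity check: the label list must line up with the model's 43 output logits.
assert len(classes) == 43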

class LitGTSRB(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ResNet-18 with a 43-way head for the GTSRB classes; the weights are
        # loaded from the checkpoint below, so no pretrained weights are needed.
        self.model = models.resnet18(num_classes=43)

    def forward(self, x):
        # Return log-probabilities; predict_image() exponentiates them back.
        return F.log_softmax(self.model(x), dim=1)
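
# Note: only forward() is needed for inference here. The checkpoint
# 'resnet18.ckpt' was presumably produced by a fuller LightningModule that
# also defined training_step and configure_optimizers; a rough sketch of that
# training run (the names below are hypothetical) might look like:
#
#     trainer = pl.Trainer(max_epochs=10)
#     trainer.fit(training_module, gtsrb_train_loader)
#     trainer.save_checkpoint('resnet18.ckpt')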

# Load the checkpoint once at startup instead of on every request.
# load_from_checkpoint is a classmethod, so it is called on the class itself.
model = LitGTSRB.load_from_checkpoint('resnet18.ckpt')
model.eval()

# ImageNet-style preprocessing: resize to 224x224, tensorize, normalize.
test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image):
    image = image.convert('RGB')
    image_tensor = test_transforms(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        output = model(image_tensor)
    # forward() returns log-probabilities, so exp() recovers probabilities.
    probs = torch.exp(output.squeeze().cpu())
    prediction_score, pred_label_idx = torch.topk(probs, 5)
    class_top5 = [classes[idx] for idx in pred_label_idx.tolist()]
    return dict(zip(class_top5, prediction_score.tolist()))
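
# Quick local sanity check (assumes an example image such as '1.png' sits
# next to this script, as in the examples list below; the exact file used
# here is illustrative):
#
#     print(predict_image(Image.open('1.png')))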

# Gradio UI: a PIL image in, the top predicted classes with scores out.
image = gr.Image(type='pil')
label = gr.Label(num_top_classes=5)
examples = ['1.png', '2.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png']

intf = gr.Interface(fn=predict_image, inputs=image, outputs=label, examples=examples)
intf.launch(inline=True)