File size: 2,537 Bytes
4af5909
0349d8b
4af5909
0349d8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d879f2
4af5909
0349d8b
 
4af5909
 
 
0349d8b
 
4af5909
 
 
0349d8b
 
 
 
 
 
4af5909
 
ea9241e
4af5909
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import functools

import torch
from torch.nn import functional as F
import torchvision
from torchvision import transforms, models
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from PIL import Image
import gradio as gr

# Human-readable label for each model output: position i in this list is the
# name reported for output index i (43 entries, matching num_classes=43).
classes = [
    'Speed limit (20km/h)',           # 0
    'Speed limit (30km/h)',           # 1
    'Speed limit (50km/h)',           # 2
    'Speed limit (60km/h)',           # 3
    'Speed limit (70km/h)',           # 4
    'Speed limit (80km/h)',           # 5
    'End of speed limit (80km/h)',    # 6
    'Speed limit (100km/h)',          # 7
    'Speed limit (120km/h)',          # 8
    'No passing',                     # 9
    'No passing veh over 3.5 tons',   # 10
    'Right-of-way at intersection',   # 11
    'Priority road',                  # 12
    'Yield',                          # 13
    'Stop',                           # 14
    'No vehicles',                    # 15
    'Veh > 3.5 tons prohibited',      # 16
    'No entry',                       # 17
    'General caution',                # 18
    'Dangerous curve left',           # 19
    'Dangerous curve right',          # 20
    'Double curve',                   # 21
    'Bumpy road',                     # 22
    'Slippery road',                  # 23
    'Road narrows on the right',      # 24
    'Road work',                      # 25
    'Traffic signals',                # 26
    'Pedestrians',                    # 27
    'Children crossing',              # 28
    'Bicycles crossing',              # 29
    'Beware of ice/snow',             # 30
    'Wild animals crossing',          # 31
    'End speed + passing limits',     # 32
    'Turn right ahead',               # 33
    'Turn left ahead',                # 34
    'Ahead only',                     # 35
    'Go straight or right',           # 36
    'Go straight or left',            # 37
    'Keep right',                     # 38
    'Keep left',                      # 39
    'Roundabout mandatory',           # 40
    'End of no passing',              # 41
    'End no passing veh > 3.5 tons',  # 42
]

class LitGTSRB(pl.LightningModule):
    """LightningModule wrapping a ResNet-18 with a 43-way output layer.

    The backbone is created without pretrained weights; trained parameters
    are meant to be restored from a Lightning checkpoint.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour
        # of `weights=None`; the old keyword is kept for compatibility with
        # older torchvision releases.
        self.model = models.resnet18(num_classes=43, pretrained=False)

    def forward(self, x):
        """Return log-probabilities over the 43 classes.

        Args:
            x: a batch of images — presumably a float tensor shaped
                (N, 3, H, W); TODO confirm against the training pipeline.

        Returns:
            The ResNet logits passed through log-softmax along dim 1.
        """
        logits = self.model(x)
        return logits.log_softmax(dim=1)

@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the trained classifier from 'resnet18.ckpt' once and reuse it."""
    # load_from_checkpoint is a classmethod: calling it on the class avoids
    # constructing (and discarding) an extra ResNet-18 instance.
    model = LitGTSRB.load_from_checkpoint('resnet18.ckpt')
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _get_test_transforms():
    """Build, once, the preprocessing pipeline applied to every input image."""
    return transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def predict_image(image):
    """Classify a traffic-sign image and return its most likely labels.

    The model and transform pipeline are built on the first call and cached,
    so later requests skip re-reading the checkpoint from disk.

    Args:
        image: a PIL image (the UI uses ``gr.Image(type='pil')``); it is
            converted to RGB before preprocessing.

    Returns:
        dict mapping class name -> probability for the top-5 classes
        (fewer if the model has fewer than 5 outputs), in descending order.
    """
    model = _load_model()
    image_tensor = _get_test_transforms()(image.convert('RGB')).float()
    batch = image_tensor.unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        log_probs = model(batch)
    # forward() returns log-softmax, so exp() recovers probabilities.
    probs = torch.exp(log_probs.cpu().squeeze(0))
    k = min(5, probs.numel())
    scores, indices = torch.topk(probs, k)
    return {
        classes[idx]: score
        for idx, score in zip(indices.tolist(), scores.tolist())
    }
# Gradio UI: a PIL image input feeding predict_image, whose class->probability
# dict is shown by a Label output, plus example images '1.png' .. '8.png'.
# NOTE(review): launch() runs at module level (no __main__ guard), so merely
# importing this module starts the app.
image = gr.Image(type='pil')
label = gr.Label()
examples = [f'{n}.png' for n in range(1, 9)]
intf = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    examples=examples,
)
intf.launch(inline=True)