yadav commited on
Commit
6b0f7b2
1 Parent(s): ead8105

Add application file

Browse files
002_0003_j.png ADDED
003_1_0009_1_j.png ADDED
054_0024_j.png ADDED
055_1_0005_1_j.png ADDED
056_1_0001_1_j.png ADDED
056_1_0013_1_j.png ADDED
app.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ import gradio as gr
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.image as mpimg
9
+
10
+
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+
13
+ classes = { 0:'Speed limit (20km/h)',
14
+ 1:'Speed limit (30km/h)',
15
+ 2:'Speed limit (50km/h)',
16
+ 3:'Speed limit (60km/h)',
17
+ 4:'Speed limit (70km/h)',
18
+ 5:'Speed limit (80km/h)',
19
+ 6:'End of speed limit (80km/h)',
20
+ 7:'Speed limit (100km/h)',
21
+ 8:'Speed limit (120km/h)',
22
+ 9:'No passing',
23
+ 10:'No passing veh over 3.5 tons',
24
+ 11:'Right-of-way at intersection',
25
+ 12:'Priority road',
26
+ 13:'Yield',
27
+ 14:'Stop',
28
+ 15:'No vehicles',
29
+ 16:'Veh > 3.5 tons prohibited',
30
+ 17:'No entry',
31
+ 18:'General caution',
32
+ 19:'Dangerous curve left',
33
+ 20:'Dangerous curve right',
34
+ 21:'Double curve',
35
+ 22:'Bumpy road',
36
+ 23:'Slippery road',
37
+ 24:'Road narrows on the right',
38
+ 25:'Road work',
39
+ 26:'Traffic signals',
40
+ 27:'Pedestrians',
41
+ 28:'Children crossing',
42
+ 29:'Bicycles crossing',
43
+ 30:'Beware of ice/snow',
44
+ 31:'Wild animals crossing',
45
+ 32:'End speed + passing limits',
46
+ 33:'Turn right ahead',
47
+ 34:'Turn left ahead',
48
+ 35:'Ahead only',
49
+ 36:'Go straight or right',
50
+ 37:'Go straight or left',
51
+ 38:'Keep right',
52
+ 39:'Keep left',
53
+ 40:'Roundabout mandatory',
54
+ 41:'End of no passing',
55
+ 42:'End no passing veh > 3.5 tons' }
56
+
57
+
58
+ def transform_images(img):
59
+ transform = transforms.Compose(
60
+ [transforms.ToTensor(),
61
+ transforms.Resize((30, 30)),
62
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
63
+ )
64
+ return transform(img)
65
+
66
+
67
+ model = torch.jit.load('model_scripted.pt')
68
+ model.eval()
69
+
70
+ def classify_image(img):
71
+ image = transform_images(img).to(device)
72
+ outputs = model(image)
73
+ _, predicted = torch.max(outputs.data, 1)
74
+ return classes[int(predicted[0])]
75
+
76
+
77
+ image = gr.inputs.Image(shape=(30,30))
78
+ label = gr.outputs.Label()
79
+ examples = ['002_0003_j.png', '054_0024_j.png', '056_1_0001_1_j.png', '003_1_0009_1_j.png', '055_1_0005_1_j.png', '056_1_0013_1_j.png']
80
+
81
+ intf = gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=examples)
82
+ intf.launch(inline=False)
model_scripted.pt ADDED
Binary file (396 kB). View file