# File-viewer page header captured along with the source (not program code):
# yadav's picture
# Add application file
# 6b0f7b2
# raw
# history blame
# 2.65 kB
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Display names of the traffic-sign classes, listed in class-index order
# (the first entry is class 0, the last is class 42).
_SIGN_NAMES = (
    # 0-8: speed limits
    'Speed limit (20km/h)',
    'Speed limit (30km/h)',
    'Speed limit (50km/h)',
    'Speed limit (60km/h)',
    'Speed limit (70km/h)',
    'Speed limit (80km/h)',
    'End of speed limit (80km/h)',
    'Speed limit (100km/h)',
    'Speed limit (120km/h)',
    # 9-17: prohibitions and priority
    'No passing',
    'No passing veh over 3.5 tons',
    'Right-of-way at intersection',
    'Priority road',
    'Yield',
    'Stop',
    'No vehicles',
    'Veh > 3.5 tons prohibited',
    'No entry',
    # 18-31: warnings
    'General caution',
    'Dangerous curve left',
    'Dangerous curve right',
    'Double curve',
    'Bumpy road',
    'Slippery road',
    'Road narrows on the right',
    'Road work',
    'Traffic signals',
    'Pedestrians',
    'Children crossing',
    'Bicycles crossing',
    'Beware of ice/snow',
    'Wild animals crossing',
    # 32-42: end-of-restriction and mandatory directions
    'End speed + passing limits',
    'Turn right ahead',
    'Turn left ahead',
    'Ahead only',
    'Go straight or right',
    'Go straight or left',
    'Keep right',
    'Keep left',
    'Roundabout mandatory',
    'End of no passing',
    'End no passing veh > 3.5 tons',
)

# Maps the integer class index predicted by the model to its display name.
classes = dict(enumerate(_SIGN_NAMES))
# Preprocessing pipeline, built once at import time rather than on every
# request: convert the image to a tensor, resize it to 30x30, then normalize
# each of the three channels with mean 0.5 and std 0.5.
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((30, 30)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def transform_images(img):
    """Convert one input image into the normalized tensor fed to the model.

    Args:
        img: An image accepted by ``transforms.ToTensor``. The normalization
            step is configured for three channels.

    Returns:
        The transformed tensor, resized to 30x30.
    """
    return _PREPROCESS(img)
# Load the TorchScript model straight onto `device`. Inputs are moved to
# `device` before the forward pass, so the weights must live there too;
# without map_location they stay on whatever device they were saved from.
model = torch.jit.load('model_scripted.pt', map_location=device)
# Switch to evaluation mode for inference.
model.eval()
def classify_image(img):
    """Predict the traffic-sign class name for a single image.

    Args:
        img: The image handed over by the Gradio input component; it is
            forwarded unchanged to ``transform_images``.

    Returns:
        The display name from ``classes`` for the highest-scoring class.
    """
    batch = transform_images(img).to(device)
    # A single transformed image has no batch axis; add one so the argmax
    # over dim 1 below runs across the class scores.
    # NOTE(review): assumes the scripted model takes (N, C, H, W) input, as
    # it would during batched training -- confirm against the model source.
    if batch.dim() == 3:
        batch = batch.unsqueeze(0)
    # Inference only: no need to record the autograd graph.
    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1)
    return classes[int(predicted[0])]
# Input and output widgets for the web UI.
image = gr.inputs.Image(shape=(30, 30))
label = gr.outputs.Label()

# Sample image files offered as one-click examples in the UI.
examples = [
    '002_0003_j.png',
    '054_0024_j.png',
    '056_1_0001_1_j.png',
    '003_1_0009_1_j.png',
    '055_1_0005_1_j.png',
    '056_1_0013_1_j.png',
]

# Wire the classifier to the widgets and start the app.
intf = gr.Interface(
    fn=classify_image,
    inputs=image,
    outputs=label,
    examples=examples,
)
intf.launch(inline=False)