# Spaces: Runtime error
import gradio as gr
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
# Run inference on the GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Maps the integer class index produced by the model to a human-readable
# traffic-sign name (43 classes, indices 0-42).
classes = {
    0: 'Speed limit (20km/h)',
    1: 'Speed limit (30km/h)',
    2: 'Speed limit (50km/h)',
    3: 'Speed limit (60km/h)',
    4: 'Speed limit (70km/h)',
    5: 'Speed limit (80km/h)',
    6: 'End of speed limit (80km/h)',
    7: 'Speed limit (100km/h)',
    8: 'Speed limit (120km/h)',
    9: 'No passing',
    10: 'No passing veh over 3.5 tons',
    11: 'Right-of-way at intersection',
    12: 'Priority road',
    13: 'Yield',
    14: 'Stop',
    15: 'No vehicles',
    16: 'Veh > 3.5 tons prohibited',
    17: 'No entry',
    18: 'General caution',
    19: 'Dangerous curve left',
    20: 'Dangerous curve right',
    21: 'Double curve',
    22: 'Bumpy road',
    23: 'Slippery road',
    24: 'Road narrows on the right',
    25: 'Road work',
    26: 'Traffic signals',
    27: 'Pedestrians',
    28: 'Children crossing',
    29: 'Bicycles crossing',
    30: 'Beware of ice/snow',
    31: 'Wild animals crossing',
    32: 'End speed + passing limits',
    33: 'Turn right ahead',
    34: 'Turn left ahead',
    35: 'Ahead only',
    36: 'Go straight or right',
    37: 'Go straight or left',
    38: 'Keep right',
    39: 'Keep left',
    40: 'Roundabout mandatory',
    41: 'End of no passing',
    42: 'End no passing veh > 3.5 tons',
}
def transform_images(img):
    """Convert an input image into the normalized tensor fed to the model.

    The pipeline converts the image to a tensor, resizes it to 30x30 and
    normalizes each of the three channels with mean 0.5 and std 0.5.

    Args:
        img: An image accepted by ``transforms.ToTensor`` (a PIL image or an
            ``H x W x C`` NumPy array). It must have three channels, since
            the normalization is defined per channel for exactly three.

    Returns:
        A ``3 x 30 x 30`` tensor. For 8-bit input, ``ToTensor`` scales values
        to [0, 1], so the normalized result lies in [-1, 1].
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((30, 30)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    return transform(img)
# Load the TorchScript model straight onto ``device``: inputs are moved to
# ``device`` before inference, so the weights must live there too, otherwise
# a CUDA host fails with a device-mismatch error.
model = torch.jit.load('model_scripted.pt', map_location=device)
# Inference only: put layers such as dropout / batch-norm in evaluation mode.
model.eval()
def classify_image(img):
    """Predict the traffic-sign class name for a single image.

    Args:
        img: The image to classify, in any form ``transform_images`` accepts.

    Returns:
        The human-readable label from ``classes`` for the highest-scoring
        class.
    """
    # The prediction is taken along dim 1 and read at index 0, i.e. the model
    # output is treated as (batch, num_classes). Add the batch dimension of
    # size 1 explicitly so a single image forms a proper batch.
    batch = transform_images(img).unsqueeze(0).to(device)
    # No gradients are needed for inference; skipping them saves memory/time.
    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1)
    return classes[int(predicted[0])]
# Use the top-level component classes: the ``gr.inputs`` / ``gr.outputs``
# namespaces were deprecated in Gradio 3 and removed in Gradio 4, where
# referencing them raises AttributeError at startup.
# No ``shape`` argument is passed: newer Gradio versions do not accept it, and
# ``transform_images`` resizes every input to 30x30 anyway.
image = gr.Image()
label = gr.Label()
# Example images offered in the UI; paths are relative to the working directory.
examples = [
    '002_0003_j.png',
    '054_0024_j.png',
    '056_1_0001_1_j.png',
    '003_1_0009_1_j.png',
    '055_1_0005_1_j.png',
    '056_1_0013_1_j.png',
]
intf = gr.Interface(fn=classify_image, inputs=image, outputs=label, examples=examples)
intf.launch(inline=False)