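# Gradio demo: classify German traffic signs (GTSRB, 43 classes) with a
# ResNet-18 loaded from a PyTorch Lightning checkpoint.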
import torch
from torch.nn import functional as F
from torchvision import transforms, models
import pytorch_lightning as pl
import gradio as gr
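
# The 43 GTSRB class labels, index-aligned with the model's 43 output units.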
classes = ['Speed limit (20km/h)',
           'Speed limit (30km/h)',
           'Speed limit (50km/h)',
           'Speed limit (60km/h)',
           'Speed limit (70km/h)',
           'Speed limit (80km/h)',
           'End of speed limit (80km/h)',
           'Speed limit (100km/h)',
           'Speed limit (120km/h)',
           'No passing',
           'No passing veh over 3.5 tons',
           'Right-of-way at intersection',
           'Priority road',
           'Yield',
           'Stop',
           'No vehicles',
           'Veh > 3.5 tons prohibited',
           'No entry',
           'General caution',
           'Dangerous curve left',
           'Dangerous curve right',
           'Double curve',
           'Bumpy road',
           'Slippery road',
           'Road narrows on the right',
           'Road work',
           'Traffic signals',
           'Pedestrians',
           'Children crossing',
           'Bicycles crossing',
           'Beware of ice/snow',
           'Wild animals crossing',
           'End speed + passing limits',
           'Turn right ahead',
           'Turn left ahead',
           'Ahead only',
           'Go straight or right',
           'Go straight or left',
           'Keep right',
           'Keep left',
           'Roundabout mandatory',
           'End of no passing',
           'End no passing veh > 3.5 tons']
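
# LightningModule wrapper: a ResNet-18 backbone with a 43-way output head;
# forward() returns log-probabilities via log_softmax.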
class LitGTSRB(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # weights=None builds the architecture untrained (torchvision >= 0.13
        # replaced the deprecated pretrained=False flag with this argument);
        # the learned weights come entirely from the checkpoint loaded below.
        self.model = models.resnet18(weights=None, num_classes=43)

    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)

# Load the checkpoint once at startup instead of on every prediction.
# load_from_checkpoint is a classmethod, so it is called on the class itself;
# map_location='cpu' keeps the app runnable on machines without a GPU.
model = LitGTSRB.load_from_checkpoint('resnet18.ckpt', map_location='cpu')
model.eval()

# Preprocessing must match training: resize to 224x224, convert to a tensor,
# and normalize with the standard ImageNet statistics.
test_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image):
    image = image.convert('RGB')
    image_tensor = test_transforms(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(image_tensor)
    # forward() returns log-probabilities, so exponentiate to recover probabilities.
    probs = torch.exp(output.squeeze())
    prediction_score, pred_label_idx = torch.topk(probs, 5)
    class_top5 = [classes[idx] for idx in pred_label_idx.numpy()]
    # Gradio's Label component expects a {class_name: confidence} mapping.
    return dict(zip(class_top5, map(float, prediction_score.numpy())))
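
# Wire the model into a Gradio interface: a PIL image in, the top-5 class
# confidences out, with a few bundled example images.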
image = gr.Image(type='pil')
label = gr.Label()
examples = ['1.png', '2.png', '3.png', '4.png', '5.png', '6.png', '7.png', '8.png']
intf = gr.Interface(fn=predict_image, inputs=image, outputs=label, examples=examples)
intf.launch(inline=True)
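
# inline=True renders the UI inline when launched from a notebook; on a plain
# server the interface is simply served on the default port. Quick local
# smoke test, assuming the bundled example '1.png' sits next to this file:
#     from PIL import Image
#     print(predict_image(Image.open('1.png')))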