Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| from torchvision import datasets, models, transforms | |
| from PIL import Image | |
# Class names; predict() pairs LABELS[i] with output logit i, so this order
# must match the order used when the checkpoint was trained.
LABELS = ['Fiat 500', 'VW Up!']

# Build an untrained ResNet-18 whose final fully-connected layer has one output
# per label. No ImageNet weights are requested: load_state_dict() below is
# strict by default and overwrites every parameter and buffer, so downloading
# pretrained weights at startup would be wasted network I/O (and the
# `pretrained=` keyword is deprecated in newer torchvision releases).
# `num_classes` yields the same `fc.weight`/`fc.bias` keys and shapes as
# replacing `model.fc` with a fresh Linear layer.
model = models.resnet18(num_classes=len(LABELS))

# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
state_dict = torch.load('up500Model.pt', map_location='cpu')
model.load_state_dict(state_dict)

# Switch to inference mode (BatchNorm layers use their stored running stats).
model.eval()
# Header text displayed by the Gradio interface.
title = "VW Up! or Fiat 500"
description = (
    "Demo for classification of automobiles. To use it, simply upload "
    "your image, or click one of the examples to load them."
)

# Preprocessing applied to every image before it reaches the model:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a float tensor, then normalize each channel with the
# standard ImageNet mean / standard deviation.
imgTransforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(inp):
    """Classify a car photo as one of LABELS.

    Args:
        inp: The uploaded image. With the Gradio ``'image'`` input this is
            presumably a NumPy array (H x W or H x W x C); a
            ``PIL.Image.Image`` is accepted as well.

    Returns:
        dict mapping each name in LABELS to its softmax probability as a
        Python float -- the shape Gradio's ``"label"`` output consumes.

    Raises:
        ValueError: If no image was supplied (``inp`` is None).
    """
    if inp is None:
        raise ValueError("No image provided; please upload an image first.")

    if isinstance(inp, Image.Image):
        img = inp
    else:
        img = Image.fromarray(inp.astype('uint8'))
    # convert() turns grayscale (L) or RGBA input into 3-channel RGB. Passing
    # mode='RGB' to fromarray() instead would misread the buffer of any
    # array that does not already have exactly three channels.
    img = img.convert('RGB')

    batch = imgTransforms(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch)[0]
    # Explicit dim: softmax's implicit-dimension form is deprecated.
    probs = torch.nn.functional.softmax(logits, dim=0)
    return {label: float(p) for label, p in zip(LABELS, probs)}
# Example inputs shown under the UI; each inner list holds the value for the
# interface's single input component. These are relative paths, resolved from
# the process's working directory (presumably the Space's repo root).
examples = [['fiat500.jpg'], ['VWUP.jpg']]

interface = gr.Interface(
    predict,
    inputs='image',
    outputs="label",
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)

# Start the web server only when executed as a script (as Spaces does), so
# importing this module elsewhere does not block inside launch().
if __name__ == "__main__":
    interface.launch()