Spaces:
Runtime error
| import torch | |
| import gradio as gr | |
| from model import AlexNet | |
| from torchvision import transforms | |
| #More Libraries ... | |
# Location of the trained weights; loaded below as a state dict via
# model.load_state_dict(torch.load(...)).
model_path = './alexnet_model_v1.pth'

# Instantiate the network, then copy the saved parameters into it.
# map_location forces every tensor onto the CPU, so weights saved on a GPU
# still load on a CPU-only host.
model = AlexNet()
model.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)
# Evaluation mode disables training-only layer behaviour (e.g. dropout,
# if the architecture uses it) for inference.
model.eval()

# Class names indexed by output position. These are the CIFAR-10 classes in
# their usual order; presumably the order used at training time — confirm.
labels = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]
def predict(inp):
    """Classify one image and return per-class confidences.

    Args:
        inp: The image to classify. Normally a PIL image (the Gradio input
            is declared with ``type="pil"``); ``None`` when the user submits
            without providing an image.

    Returns:
        A dict mapping each class name in ``labels`` to its softmax
        probability (a Python float), in the form ``gr.Label`` expects.
        Returns ``None`` when no image was given, which clears the output
        instead of raising.
    """
    if inp is None:
        # Gradio calls the function with None for an empty input; ToTensor
        # would raise a TypeError on it.
        return None
    if getattr(inp, "mode", "RGB") != "RGB":
        # ToTensor keeps the image's own channels (4 for RGBA, 1 for L/P),
        # which would not match a 3-channel network; normalise to RGB.
        inp = inp.convert("RGB")
    # ToTensor yields a (C, H, W) float tensor scaled to [0, 1]; unsqueeze
    # adds the batch dimension. NOTE(review): no resize or normalisation is
    # applied here — confirm this matches the model's training preprocessing
    # and that it accepts arbitrary image sizes.
    batch = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        # Assumes the model returns logits shaped (batch, num_classes);
        # take the single sample and turn its logits into probabilities.
        probs = torch.nn.functional.softmax(model(batch)[0], dim=0)
    # Pair class names with probabilities by position instead of relying on
    # a hard-coded class count.
    return {label: float(p) for label, p in zip(labels, probs)}
# Web UI: one PIL image in, the five most confident classes out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.components.Image(type="pil"),
    outputs=gr.components.Label(num_top_classes=5),
    # Clickable sample inputs; assumed to be files next to this script —
    # TODO confirm they are present in the deployment.
    examples=["frog.jpeg", "car.jpeg", "cat.jpeg", "ship.jpeg", "dog.jpeg"],
    theme="default",
    # Hides any element with class "footer" (presumably Gradio's page footer).
    css=".footer{display:none !important}",
)
# Start the server at import/run time, as the original script does.
demo.launch()