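# Spaces: Runtime error
# (status banner captured from the hosting page together with this source;
#  kept as a comment so the file parses as Python)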
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
# Fetch NTS-Net from torch.hub with pretrained weights (pretrained=True), built
# for the 200 CUB classes and placed on the CPU. The extra keyword arguments
# are forwarded to the hub entrypoint; topN=6 is presumably the number of part
# proposals NTS-Net keeps -- confirm in the repo's hubconf.
# NOTE: this runs at import time and needs network access on first use.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
)
# Inference only: switch layers such as dropout/batch-norm to eval behavior.
model.eval()
# Preprocessing applied to every input image before inference:
#   1. resize to 600x600 with bilinear interpolation,
#   2. center-crop to 448x448,
#   3. convert to a float tensor,
#   4. normalize each channel with the mean/std below. These match the
#      commonly used ImageNet statistics, presumably what the pretrained
#      weights expect -- verify against the model repo.
# Interpolation is given as torchvision's InterpolationMode enum. Passing the
# Pillow integer constant (Image.BILINEAR) is the deprecated form.
transform_test = transforms.Compose([
    transforms.Resize((600, 600), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def predict(inp):
    """Classify a bird image into one of the species known to ``model``.

    Args:
        inp: Image pixels as an array supporting ``.astype``. This is presumably
            the ``numpy.ndarray`` delivered by the Gradio image input (H x W or
            H x W x C); TODO confirm against the Gradio version in use.

    Returns:
        dict mapping ``model.bird_classes[i]`` to the softmax probability
        (a Python ``float``) of class ``i``. This is the mapping format
        consumed by the Label output.
    """
    # Let Pillow infer the mode from the array, then convert to RGB. This also
    # accepts grayscale (2-D) and RGBA (4-channel) arrays. Forcing mode='RGB'
    # on a non-3-channel array misinterprets the pixel buffer.
    img = Image.fromarray(inp.astype('uint8')).convert('RGB')
    # transform_test yields one image tensor; unsqueeze(0) adds the batch
    # dimension the model is called with.
    batch = transform_test(img).unsqueeze(0)
    with torch.no_grad():
        # Only the concatenated logits (4th output) feed the prediction; the
        # other outputs are discarded.
        _, _, _, concat_logits, _, _, _ = model(batch)
    # Explicit dim over the class axis. Implicit-dim softmax is deprecated
    # and emits a warning on every call.
    probs = torch.nn.functional.softmax(concat_logits, dim=-1).squeeze(0)
    return {model.bird_classes[i]: p for i, p in enumerate(probs.tolist())}
# Web UI: an image input feeds predict(), and the 10 most probable species
# are shown as labels.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4
# (AttributeError at startup). The top-level components below replace them.
# type='numpy' is stated explicitly because predict() calls .astype() on it.
inputs = gr.Image(type='numpy')
outputs = gr.Label(num_top_classes=10)
demo = gr.Interface(fn=predict, inputs=inputs, outputs=outputs,
                    title="200 Bird Species Classifications with NTS-NET (From CUB 200)",
                    examples=["gabbiano.jpg"])
demo.launch()