# Last commit 4308ef9 by z-uo: "fake fix"
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
# Pretrained NTS-Net fetched from the `nicolalandro/ntsnet-cub200` torch.hub
# repo, configured for CPU with 200 output classes (CUB-200) and topN=6.
# `.eval()` switches the network to evaluation mode and returns it.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
).eval()

# Test-time preprocessing applied to every image before it reaches `model`.
transform_test = transforms.Compose(
    [
        # Bilinear resize to 600x600, then keep the central 448x448 region.
        transforms.Resize(size=(600, 600), interpolation=Image.BILINEAR),
        transforms.CenterCrop(size=(448, 448)),
        # PIL image -> float tensor (channels first, values in [0, 1]).
        transforms.ToTensor(),
        # Per-channel normalisation with the common ImageNet mean/std
        # (presumably what the pretrained weights expect — confirm upstream).
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
def predict(inp):
    """Classify a bird image and return a probability for every class.

    Args:
        inp: image array with an ``.astype`` method — presumably the NumPy
            array delivered by the Gradio image input. Grayscale and RGBA
            arrays are converted to RGB before preprocessing.

    Returns:
        dict mapping each name in ``model.bird_classes`` to its softmax
        probability as a Python float (the format the Label output expects).
    """
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode='RGB' in fromarray misreads grayscale/RGBA input.
    image = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = transform_test(image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        # The model returns a 7-tuple; only concat_logits is used here.
        _, _, _, concat_logits, _, _, _ = model(batch)
    # Explicit dim: the implicit-dim form is deprecated. For 1-D/2-D logits,
    # -1 is the same axis the implicit choice picked.
    pred = torch.nn.functional.softmax(concat_logits, dim=-1)
    return {model.bird_classes[i]: float(p) for i, p in enumerate(pred.squeeze(0))}
# Web UI: an image upload is fed to `predict`, and the result is rendered as
# the 10 highest-probability labels. `.launch()` starts the server at import.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and
# removed in 4.x; migrate to gr.Image / gr.Label when upgrading Gradio.
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=10)
gr.Interface(
    title="200 Bird Species Classifications with NTS-NET (From CUB 200)",
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    examples=["gabbiano.jpg"],
).launch()