"""Gradio demo classifying a bird photo into one of the 200 CUB-200 species.

The classifier is NTS-Net, loaded pretrained via ``torch.hub`` from the
``nicolalandro/ntsnet-cub200`` repository and run on the CPU.
"""
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

# Pretrained NTS-Net from torch hub. .eval() puts any dropout / batch-norm
# layers into inference mode before serving predictions.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
).eval()

# Preprocessing applied to every uploaded image: bilinear resize to 600x600,
# central 448x448 crop, conversion to a float tensor, then per-channel
# normalisation with the standard ImageNet mean / std values.
transform_test = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def predict(inp):
    """Classify one image and return per-species probabilities.

    Args:
        inp: numpy array delivered by the Gradio ``Image`` input. Presumably
            H x W x 3 (RGB), though H x W and H x W x 4 arrays are also
            accepted -- confirm against the pinned Gradio version.

    Returns:
        dict mapping every name in ``model.bird_classes`` to its softmax
        probability (a Python float). The values sum to 1.
    """
    # Let Pillow infer the mode from the array shape, then convert. Forcing
    # mode 'RGB' on the raw buffer only decodes H x W x 3 arrays correctly;
    # convert('RGB') also handles grayscale and RGBA uploads.
    img = Image.fromarray(inp.astype('uint8')).convert('RGB')
    batch = transform_test(img).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        # The model returns a 7-tuple. Only the 4th element, concat_logits,
        # is used for the prediction (presumably the fused global + part
        # logits of NTS-Net -- confirm against the hub repo).
        _, _, _, concat_logits, _, _, _ = model(batch)

    # Normalise explicitly over the class (last) dimension. Calling softmax
    # without `dim` is deprecated and warns on every request.
    probs = torch.nn.functional.softmax(concat_logits, dim=-1).squeeze(0)
    return {model.bird_classes[i]: float(p) for i, p in enumerate(probs)}


# NOTE(review): gr.inputs / gr.outputs are the legacy (Gradio 2.x) component
# namespaces; newer Gradio releases use gr.Image / gr.Label instead. Kept
# as-is for whatever Gradio version this app pins -- confirm before upgrading.
inputs = gr.inputs.Image()
outputs = gr.outputs.Label(num_top_classes=10)  # display the 10 most probable species

gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title="200 Bird Species Classifications with NTS-NET (From CUB 200)",
    examples=["gabbiano.jpg"],
).launch()