# NTSNET / app.py
# Source: Hugging Face Space file, created by akhaliq (HF staff) in commit c35b372
# ("Create app.py"); 2.02 kB; platform virus scan: no virus found.
from torchvision import transforms
import torch
import urllib
from PIL import Image
import gradio as gr
import torch
# Example image for the Gradio examples gallery, saved locally as 'bird.png'.
torch.hub.download_url_to_file('https://static.scientificamerican.com/sciam/cache/file/7A715AD8-449D-4B5A-ABA2C5D92D9B5A21_source.png', 'bird.png')

# Pretrained NTS-Net bird classifier (200 classes) from torch.hub, run on CPU.
# Loaded exactly once and shared by every request.
# NOTE(review): 'topN' is presumably the number of part proposals the model keeps — confirm in the hub repo.
model = torch.hub.load('nicolalandro/ntsnet-cub200', 'ntsnet', pretrained=True,
                       **{'topN': 6, 'device': 'cpu', 'num_classes': 200})
# This app only runs inference: eval() disables training-mode layer behaviour
# (e.g. BatchNorm batch statistics / running-stat updates, Dropout) so predictions
# are deterministic and user uploads do not mutate the model's state.
model.eval()

# Test-time preprocessing: bilinear resize to 600x600, center-crop to 448x448,
# convert to a tensor, then normalise with the standard ImageNet channel mean/std.
# No random horizontal flip here — that augmentation is for training only.
transform_test = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def birds(img):
    """Classify a bird image with NTS-Net and return the predicted class name.

    Args:
        img: A PIL image (the Gradio input is configured with ``type='pil'``).

    Returns:
        The text after the first '.' of the predicted label in
        ``model.bird_classes`` (labels presumably look like ``"NNN.Name"`` —
        TODO confirm against the hub model's class list).
    """
    # ToTensor keeps the image's channel count, but transform_test's Normalize
    # has exactly 3 channel statistics, so RGBA (e.g. PNG with alpha) or
    # greyscale inputs would fail there. Force 3-channel RGB first.
    rgb_img = img.convert('RGB')
    batch = transform_test(rgb_img).unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():  # inference only: no autograd graph needed
        (_top_n_coordinates, _concat_out, _raw_logits, concat_logits,
         _part_logits, _top_n_index, _top_n_prob) = model(batch)
    # Only the concatenated head is used: pick the highest-scoring class along dim 1.
    pred_id = concat_logits.argmax(dim=1).item()
    label = model.bird_classes[pred_id]
    # maxsplit=1 keeps any further '.' inside the name instead of truncating it.
    return label.split('.', 1)[1]
# Gradio UI: a PIL image in, the predicted bird class (text) out.
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and removed in
# 4.x; switch to gr.Image / gr.Textbox if the pinned Gradio version is upgraded.
inputs = gr.inputs.Image(type='pil', label="Original Image")
outputs = gr.outputs.Textbox(label="bird class")
title = "ntsnet"
description = "demo for ntsnet to classify birds. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf'>Are These Birds Similar: Learning Branched Networks for Fine-grained Representations</a> | <a href='https://github.com/nicolalandro/ntsnet-cub200'>Github Repo</a></p>"
# One example row: the image this script downloads as 'bird.png'.
examples = [
    ['bird.png'],
]

demo = gr.Interface(birds, inputs, outputs, title=title, description=description,
                    article=article, examples=examples, analytics_enabled=False)

# Start the web server only when executed as a script (e.g. `python app.py`),
# not as a side effect of importing this module.
if __name__ == "__main__":
    demo.launch()