# Aravindan's picture
# Update app.py
# 5b7bd21
import cv2, torch
import gradio as gr
import numpy as np
from PIL import Image
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms as T
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ResNet-18 built without requesting pretrained weights; its final fully
# connected layer is replaced so it emits one logit per class in `we_are` (75).
new_model = models.resnet18()
num_ftrs = new_model.fc.in_features
new_model.fc = nn.Linear(in_features=num_ftrs, out_features=75)

# Deserialise the checkpoint onto the CPU first so it also loads on machines
# without a GPU, then move the populated model to the chosen device.
checkpoint = torch.load('model_best_checkpoint.pth.tar', map_location='cpu')
new_model.load_state_dict(checkpoint['model'])
new_model = new_model.to(device)
# Human-readable class names, indexed by the model's output position:
# `classify` returns we_are[i] for the arg-max logit i, so the ORDER here must
# match the class-index mapping used at training time (presumably the
# dataset's sorted/enumerated class folders -- TODO confirm).
# Spellings such as 'ULYSES', 'CRECENT', 'STRAITED QUEEN', 'APPOLLO' and
# 'CHECQUERED SKIPPER' are kept verbatim; presumably they mirror the training
# dataset's label names, so do not "fix" them without checking.
we_are = [
    # 0-9
    'INDRA SWALLOW',
    'MALACHITE',
    'COMMON BANDED AWL',
    'DANAID EGGFLY',
    'EASTERN PINE ELFIN',
    'YELLOW SWALLOW TAIL',
    'WOOD SATYR',
    'ULYSES',
    'MESTRA',
    'MANGROVE SKIPPER',
    # 10-19
    'BECKERS WHITE',
    'CRECENT',
    'RED SPOTTED PURPLE',
    'SOOTYWING',
    'BLACK HAIRSTREAK',
    'STRAITED QUEEN',
    'ELBOWED PIERROT',
    'ORANGE OAKLEAF',
    'CHESTNUT',
    'POPINJAY',
    # 20-29
    'COMMON WOOD-NYMPH',
    'BROWN SIPROETA',
    'QUESTION MARK',
    'ADONIS',
    'CLOUDED SULPHUR',
    'TWO BARRED FLASHER',
    'GOLD BANDED',
    'BANDED ORANGE HELICONIAN',
    'PURPLISH COPPER',
    'VICEROY',
    # 30-39
    'RED CRACKER',
    'SILVER SPOT SKIPPER',
    'ZEBRA LONG WING',
    'ORCHARD SWALLOW',
    'RED POSTMAN',
    'SOUTHERN DOGFACE',
    'SCARCE SWALLOW',
    'EASTERN COMA',
    'CAIRNS BIRDWING',
    'GREEN CELLED CATTLEHEART',
    # 40-49
    'METALMARK',
    'LARGE MARBLE',
    'AMERICAN SNOOT',
    'COPPER TAIL',
    'AN 88',
    'AFRICAN GIANT SWALLOWTAIL',
    'PAPER KITE',
    'EASTERN DAPPLE WHITE',
    'PEACOCK',
    'ATALA',
    # 50-59
    'JULIA',
    'RED ADMIRAL',
    'GREAT JAY',
    'GREAT EGGFLY',
    'GREY HAIRSTREAK',
    'PIPEVINE SWALLOW',
    'PURPLE HAIRSTREAK',
    'ORANGE TIP',
    'BLUE SPOTTED CROW',
    'TROPICAL LEAFWING',
    # 60-69
    'CLEOPATRA',
    'APPOLLO',
    'IPHICLUS SISTER',
    'CABBAGE WHITE',
    'BANDED PEACOCK',
    'MONARCH',
    'CRIMSON PATCH',
    'BLUE MORPHO',
    'MOURNING CLOAK',
    'SLEEPY ORANGE',
    # 70-74
    'CLODIUS PARNASSIAN',
    'MILBERTS TORTOISESHELL',
    'PINE WHITE',
    'CHECQUERED SKIPPER',
    'PAINTED LADY',
]
# NOTE(review): `classify` referenced `image_transforms`, but nothing in this
# file defined it, so every request failed with NameError. This is the usual
# torchvision ImageNet-style eval preprocessing for a ResNet -- TODO confirm it
# matches the transforms used when model_best_checkpoint.pth.tar was trained.
image_transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify(image_):
    """Return the predicted butterfly species name for a single image.

    Args:
        image_: The image to classify. Gradio's ``'image'`` input delivers a
            NumPy array; a ``PIL.Image.Image``, or anything ``PIL.Image.open``
            accepts (a file path or binary file object), is also handled.

    Returns:
        The entry of ``we_are`` at the index of the model's highest logit.
    """
    # Normalise every supported input to a 3-channel RGB PIL image: RGBA and
    # greyscale inputs would otherwise break the 3-channel Normalize above.
    if isinstance(image_, np.ndarray):
        image = Image.fromarray(image_).convert('RGB')
    elif isinstance(image_, Image.Image):
        image = image_.convert('RGB')
    else:
        # Path / file object: convert() loads the pixels, after which the
        # file handle PIL opened can be released.
        with Image.open(image_) as opened:
            image = opened.convert('RGB')

    # Add a batch dimension of 1 and move to the model's device.
    batch = image_transforms(image).float().unsqueeze(0).to(device)

    model = new_model.eval()
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(batch)
    _, predicted = torch.max(output, 1)
    return we_are[predicted.item()]
# Output widget for the predicted species.
# NOTE(review): num_top_classes presumably only has an effect when the function
# returns a {label: confidence} mapping; classify returns a single label
# string -- confirm which display is intended.
# NOTE(review): `gr.outputs.*` and `interpretation=` are legacy Gradio APIs
# (presumably requiring a pre-4.0 Gradio) -- confirm the pinned version.
label = gr.outputs.Label(num_top_classes=75)

demo = gr.Interface(
    fn=classify,
    inputs='image',
    outputs=label,
    interpretation='default',
    title='Butterfly Classification detection ',
    description='It will classify 75 different species ',
)
demo.launch()