"""Gradio demo that classifies a butterfly photo into one of 75 species.

A ResNet-18 whose final layer is replaced by one output per entry of
``we_are`` is loaded from ``model_best_checkpoint.pth.tar`` and served
through a Gradio ``Interface``.
"""
import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms as T

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Species names; position i is the name for model output unit i.
we_are = ['INDRA SWALLOW', 'MALACHITE', 'COMMON BANDED AWL', 'DANAID EGGFLY', 'EASTERN PINE ELFIN',
          'YELLOW SWALLOW TAIL', 'WOOD SATYR', 'ULYSES', 'MESTRA', 'MANGROVE SKIPPER', 'BECKERS WHITE',
          'CRECENT', 'RED SPOTTED PURPLE', 'SOOTYWING', 'BLACK HAIRSTREAK', 'STRAITED QUEEN',
          'ELBOWED PIERROT', 'ORANGE OAKLEAF', 'CHESTNUT', 'POPINJAY', 'COMMON WOOD-NYMPH',
          'BROWN SIPROETA', 'QUESTION MARK', 'ADONIS', 'CLOUDED SULPHUR', 'TWO BARRED FLASHER',
          'GOLD BANDED', 'BANDED ORANGE HELICONIAN', 'PURPLISH COPPER', 'VICEROY', 'RED CRACKER',
          'SILVER SPOT SKIPPER', 'ZEBRA LONG WING', 'ORCHARD SWALLOW', 'RED POSTMAN',
          'SOUTHERN DOGFACE', 'SCARCE SWALLOW', 'EASTERN COMA', 'CAIRNS BIRDWING',
          'GREEN CELLED CATTLEHEART', 'METALMARK', 'LARGE MARBLE', 'AMERICAN SNOOT', 'COPPER TAIL',
          'AN 88', 'AFRICAN GIANT SWALLOWTAIL', 'PAPER KITE', 'EASTERN DAPPLE WHITE', 'PEACOCK',
          'ATALA', 'JULIA', 'RED ADMIRAL', 'GREAT JAY', 'GREAT EGGFLY', 'GREY HAIRSTREAK',
          'PIPEVINE SWALLOW', 'PURPLE HAIRSTREAK', 'ORANGE TIP', 'BLUE SPOTTED CROW',
          'TROPICAL LEAFWING', 'CLEOPATRA', 'APPOLLO', 'IPHICLUS SISTER', 'CABBAGE WHITE',
          'BANDED PEACOCK', 'MONARCH', 'CRIMSON PATCH', 'BLUE MORPHO', 'MOURNING CLOAK',
          'SLEEPY ORANGE', 'CLODIUS PARNASSIAN', 'MILBERTS TORTOISESHELL', 'PINE WHITE',
          'CHECQUERED SKIPPER', 'PAINTED LADY']

new_model = models.resnet18()
num_ftrs = new_model.fc.in_features
# Replace the stock classification head with one output per species.
new_model.fc = nn.Linear(num_ftrs, len(we_are))
# Load on CPU first so the checkpoint opens on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load('model_best_checkpoint.pth.tar', map_location=torch.device('cpu'))
new_model.load_state_dict(checkpoint['model'])
new_model.to(device)
new_model.eval()  # inference only: put dropout/batch-norm in eval mode once

# TODO(review): the original script referenced an undefined `image_transforms`.
# The 224x224 input size and ImageNet mean/std below are assumptions and MUST
# be checked against the preprocessing used when the checkpoint was trained.
INPUT_SIZE = (224, 224)  # (width, height) as expected by cv2.resize
image_transforms = T.Compose([
    T.ToTensor(),  # HWC uint8 array -> CHW float tensor scaled to [0, 1]
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _to_rgb_array(image_):
    """Return *image_* as a contiguous (H, W, 3) RGB numpy array.

    Accepts a file path (read with OpenCV), or any array-like image; a
    grayscale image is replicated to three channels and an alpha channel
    is dropped.

    Raises:
        ValueError: if no image is given or the file cannot be read.
    """
    if image_ is None:
        raise ValueError('no image provided')
    if isinstance(image_, str):
        bgr = cv2.imread(image_, cv2.IMREAD_COLOR)
        if bgr is None:
            raise ValueError(f'could not read image: {image_}')
        # OpenCV decodes to BGR; the model input is built as RGB.
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    array = np.asarray(image_)
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3 and array.shape[2] == 1:
        array = np.repeat(array, 3, axis=2)
    elif array.ndim == 3 and array.shape[2] == 4:
        array = array[:, :, :3]
    if array.ndim != 3 or array.shape[2] != 3:
        raise ValueError(f'unsupported image shape: {array.shape}')
    return np.ascontiguousarray(array)


def classify(image_):
    """Predict the butterfly species shown in *image_*.

    Args:
        image_: the image from the Gradio ``'image'`` input (an (H, W, C)
            numpy array by default); a file path is also accepted.

    Returns:
        The predicted species name, an element of ``we_are``.
    """
    rgb = cv2.resize(_to_rgb_array(image_), INPUT_SIZE, interpolation=cv2.INTER_LINEAR)
    batch = image_transforms(rgb).float().unsqueeze(0).to(device)
    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        output = new_model(batch)
    predicted = int(output.argmax(dim=1).item())
    return we_are[predicted]


label = gr.Label(num_top_classes=len(we_are))
demo = gr.Interface(fn=classify, inputs='image', outputs=label, interpretation='default',
                    title='Butterfly Classification detection ',
                    description='It will classify 75 different species ')

if __name__ == '__main__':
    demo.launch()