# whos_dat_doggo / app.py
# Samuel Diaz
# Fixed model
# 7b29154
import gradio as gr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn.functional as F
from torchvision import transforms
import torchvision
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset, random_split, DataLoader
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
class net50(torch.nn.Module):
    """Backbone network followed by a two-layer classification head.

    Data flow: ``x -> base_model -> ReLU -> Linear(base_out_features, 512)
    -> ReLU -> Linear(512, num_classes)``. The result is the raw class logits.

    Args:
        base_model: Backbone module whose output has ``base_out_features``
            features in its last dimension.
        base_out_features: Width of the backbone output.
        num_classes: Number of output classes (logits).
    """

    # NOTE: state_dict keys are derived from the attribute names
    # ``base_model``, ``linear1`` and ``output``. Keep them unchanged so that
    # previously saved checkpoints still load.

    def __init__(self, base_model, base_out_features, num_classes):
        super().__init__()
        hidden_units = 512
        self.base_model = base_model
        self.linear1 = torch.nn.Linear(base_out_features, hidden_units)
        self.output = torch.nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        backbone_out = F.relu(self.base_model(x))
        hidden = F.relu(self.linear1(backbone_out))
        return self.output(hidden)
def get_default_device():
    """Return the preferred torch device: CUDA when available, otherwise CPU."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
# Location of the checkpoint; its contents are passed to model.load_state_dict.
PATH = "./model/model.zip"
# Checkpoint tensors are always deserialized onto the CPU, so loading works
# even on machines without a GPU.
map_location = torch.device("cpu")
# Preferred inference device (CUDA when available, otherwise the CPU).
device = get_default_device()
def predict_single(img):
    """Classify a single image and return the predicted breed name.

    Args:
        img: Image accepted by ``transform_image`` (the PIL image produced by
            the ``gr.Image(type="pil")`` input).

    Returns:
        The entry of ``breeds`` with the highest logit. The name is also
        printed to stdout.
    """
    xb = transform_image(img)  # batch containing one preprocessed image
    # Put the input on the same device as the model's weights. The global
    # `device` may be CUDA while the weights were loaded with a CPU
    # map_location; mismatched devices make model(xb) raise.
    xb = xb.to(next(model.parameters()).device)
    # Inference only: skip recording operations for autograd.
    with torch.no_grad():
        preds = model(xb)
    # Index of the highest-scoring class for the single image in the batch,
    # as a plain Python int so it can index the breeds list.
    kls = int(preds.argmax(dim=1)[0])
    print('Predicted :', breeds[kls])
    return breeds[kls]
def image_mod(image):
    """Gradio callback: return the predicted breed for the submitted image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        The predicted breed name, or a short prompt asking for an image when
        ``image`` is ``None``.
    """
    if image is None:
        # Without this guard, None would reach the torchvision transforms in
        # predict_single and raise instead of giving the user feedback.
        return "Please upload an image of a dog."
    return predict_single(image)
# Preprocessing pipeline, built once at import time instead of on every call.
_IMAGE_TRANSFORM = transforms.Compose([
    # An int size resizes the *shorter* side to 500 px, keeping aspect ratio.
    transforms.Resize(500),
    transforms.ToTensor(),
    # Per-channel mean/std, the values commonly used with ImageNet-pretrained
    # torchvision models.
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Turn a PIL image into a normalized single-image batch tensor.

    Args:
        image_bytes: A PIL image (despite the name, not raw bytes). Any mode
            is accepted; it is converted to RGB first.

    Returns:
        A tensor with a leading batch dimension of size 1, ready for the model.
    """
    # ToTensor keeps every channel, so RGBA, grayscale or palette images would
    # not match the 3-channel Normalize statistics. Force 3-channel RGB.
    rgb_image = image_bytes.convert("RGB")
    return _IMAGE_TRANSFORM(rgb_image).unsqueeze(0)
# ResNet-50 backbone. No pretrained weights are requested: the checkpoint below
# is loaded with load_state_dict (strict by default), so every backbone
# parameter and buffer is overwritten anyway. Downloading the ImageNet weights
# would be wasted work, and the `pretrained=` flag is deprecated in torchvision.
res = torchvision.models.resnet50()
# Freeze all backbone parameters (same effect as the original per-parameter loop).
res.requires_grad_(False)
model = net50(base_model=res, base_out_features=res.fc.out_features, num_classes=120)
model.load_state_dict(torch.load(PATH, map_location=map_location))
# The weights were deserialized onto the CPU. Move the model to `device`
# (CUDA when available) so inference runs there and matches the inputs.
model.to(device)
model.eval()
# Human-readable class names. predict_single maps the index of the model's
# highest logit to an entry here, so position i names output class i
# (num_classes=120). Presumably this is the class order used during training;
# do not reorder.
breeds = [
    "Chihuahua",
    "Japanese spaniel",
    "Maltese dog",
    "Pekinese",
    "Shih Tzu",
    "Blenheim spaniel",
    "papillon",
    "toy terrier",
    "Rhodesian ridgeback",
    "Afghan hound",
    "basset",
    "beagle",
    "bloodhound",
    "bluetick",
    "black and tan coonhound",
    "Walker hound",
    "English foxhound",
    "redbone",
    "borzoi",
    "Irish wolfhound",
    "Italian greyhound",
    "whippet",
    "Ibizan hound",
    "Norwegian elkhound",
    "otterhound",
    "Saluki",
    "Scottish deerhound",
    "Weimaraner",
    "Staffordshire bullterrier",
    "American Staffordshire terrier",
    "Bedlington terrier",
    "Border terrier",
    "Kerry blue terrier",
    "Irish terrier",
    "Norfolk terrier",
    "Norwich terrier",
    "Yorkshire terrier",
    "wire haired fox terrier",
    "Lakeland terrier",
    "Sealyham terrier",
    "Airedale",
    "cairn",
    "Australian terrier",
    "Dandie Dinmont",
    "Boston bull",
    "miniature schnauzer",
    "giant schnauzer",
    "standard schnauzer",
    "Scotch terrier",
    "Tibetan terrier",
    "silky terrier",
    "soft coated wheaten terrier",
    "West Highland white terrier",
    "Lhasa",
    "flat coated retriever",
    "curly coated retriever",
    "golden retriever",
    "Labrador retriever",
    "Chesapeake Bay retriever",
    "German short haired pointer",
    "vizsla",
    "English setter",
    "Irish setter",
    "Gordon setter",
    "Brittany spaniel",
    "clumber",
    "English springer",
    "Welsh springer spaniel",
    "cocker spaniel",
    "Sussex spaniel",
    "Irish water spaniel",
    "kuvasz",
    "schipperke",
    "groenendael",
    "malinois",
    "briard",
    "kelpie",
    "komondor",
    "Old English sheepdog",
    "Shetland sheepdog",
    "collie",
    "Border collie",
    "Bouvier des Flandres",
    "Rottweiler",
    "German shepherd",
    "Doberman",
    "miniature pinscher",
    "Greater Swiss Mountain dog",
    "Bernese mountain dog",
    "Appenzeller",
    "EntleBucher",
    "boxer",
    "bull mastiff",
    "Tibetan mastiff",
    "French bulldog",
    "Great Dane",
    "Saint Bernard",
    "Eskimo dog",
    "malamute",
    "Siberian husky",
    "affenpinscher",
    "basenji",
    "pug",
    "Leonberg",
    "Newfoundland",
    "Great Pyrenees",
    "Samoyed",
    "Pomeranian",
    "chow",
    "keeshond",
    "Brabancon griffon",
    "Pembroke",
    "Cardigan",
    "toy poodle",
    "miniature poodle",
    "standard poodle",
    "Mexican hairless",
    "dingo",
    "dhole",
    "African hunting dog",
]
# Gradio UI: one PIL image in, the predicted breed name out as text. The
# example paths are relative and resolved from the working directory.
iface = gr.Interface(
    image_mod,
    gr.Image(type="pil"),
    "text",
    examples=["doggo1.png", "doggo2.jpg", "doggo3.png", "doggo4.png"],
)

# Start the (blocking) web server only when this file is run as a script,
# e.g. `python app.py`, so that importing the module does not launch it.
if __name__ == "__main__":
    iface.launch()