import gradio as gr
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import DataLoader, Dataset  # Gives easier dataset management and creates mini batches
from torchvision.datasets import ImageFolder
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
from PIL import Image
from tqdm import tqdm
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use gpu or cpu

# Build the ResNet-50 architecture. ImageNet weights are deliberately NOT
# downloaded: every parameter and buffer is overwritten by the fine-tuned
# checkpoint loaded below (load_state_dict is strict by default), so
# pretrained=True would only add a large network download at startup.
model = models.resnet50()

# Freeze the backbone so that only the new fc head is trainable. This mirrors
# the training setup and has no effect on inference; remove this loop to train
# the entire network instead.
for param in model.parameters():
    param.requires_grad = False

# Replace the 1000-class ImageNet head with a 2-class (cat / dog) head.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model.to(device)

# Loss and optimizer. These are training artifacts kept so the optimizer state
# in the checkpoint can be restored; the inference path below does not use them.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# NOTE(review): "checpoint" looks misspelled, but it must match the file on
# disk — confirm before renaming either.
checkpoint = torch.load("checpoint_epoch_4.pt", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Inference only: switch BatchNorm/Dropout layers to evaluation mode once,
# instead of on every request.
model.eval()

# Preprocessing for each uploaded image, built once rather than per request.
# NOTE(review): must match the preprocessing used at training time (resize
# after ToTensor, per-channel mean/std of 0.5) — confirm against training code.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])


def image_classifier(inp):
    """Classify an uploaded image as a cat or a dog.

    Args:
        inp: The image from the Gradio ``"image"`` input, in any form that
            ``transforms.ToTensor`` accepts (PIL image or HxWxC array).
            NOTE(review): presumably an RGB uint8 numpy array (Gradio's
            default) — confirm; the normalization assumes 3 channels.

    Returns:
        ``"DOG"`` when the model predicts class index 1, otherwise ``"CAT"``.

    Raises:
        gr.Error: If no image was submitted.
    """
    if inp is None:
        # Gradio passes None when the user submits without an image; show a
        # readable message instead of an opaque TypeError from ToTensor.
        raise gr.Error("Please upload an image.")

    # Preprocess and add a leading batch dimension of size 1.
    img = data_transforms(inp).unsqueeze(dim=0).to(device)

    # No autograd graph is needed for inference (the fc head still has
    # requires_grad=True, so without this a graph would be built per call).
    with torch.no_grad():
        pred = model(img)

    _, preds = torch.max(pred, 1)
    print(f"class : {preds}")

    if preds[0] == 1:
        print("predicted ----> Dog")
        return "DOG"
    print("predicted ----> Cat")
    return "CAT"


demo = gr.Interface(fn=image_classifier, inputs="image", outputs="text")
demo.launch()