|
import torch |
|
from torch import Tensor as T |
|
import torchvision.models as models |
|
import torchvision.transforms as transforms |
|
import torchvision.datasets as datasets |
|
from torchvision.transforms import Compose |
|
from torch.nn import Module |
|
from torch.nn.functional import softmax |
|
import requests |
|
from PIL import Image |
|
import random |
|
from gradio import Blocks, Tabs, TabItem, Row, Column, Image, Dropdown, Button, Label |
|
|
|
|
|
# Registry of the classification architectures offered by the app.
# Keys are lowercase lookup names; values are the torchvision factory
# callables that are invoked to build the corresponding network.
IMAGE_PREDICTION_MODELS = {
    "resnet": models.resnet50,
    "alexnet": models.alexnet,
    "vgg": models.vgg16,
    "squeezenet": models.squeezenet1_0,
    "densenet": models.densenet161,
    "inception": models.inception_v3,
    "googlenet": models.googlenet,
    "shufflenet": models.shufflenet_v2_x1_0,
    "mobilenet": models.mobilenet_v2,
    "resnext": models.resnext50_32x4d,
    "wide_resnet": models.wide_resnet50_2,
    "mnasnet": models.mnasnet1_0,
    "efficientnet": models.efficientnet_b0,
    "regnet": models.regnet_y_400mf,
    "vit": models.vit_b_16,
    "convnext": models.convnext_tiny,
}
|
|
|
|
|
class ModelLoader:
    """Turn a model name into an instantiated model.

    The loader wraps a mapping from lowercase model names to factory
    callables; lookups are case-insensitive.
    """

    def __init__(self, model_dict: dict):
        # Mapping of lowercase name -> factory callable.
        self.model_dict = model_dict

    def load_model(self, model_name: str) -> Module:
        """Instantiate the model registered under *model_name*.

        The name is lowercased before lookup and the factory is called with
        ``pretrained=True``.

        Raises:
            ValueError: if no factory is registered under the name.
        """
        key = model_name.lower()
        if key not in self.model_dict:
            raise ValueError(f"Model {model_name} is not available for image prediction in torchvision.models")

        factory = self.model_dict[key]
        return factory(pretrained=True)

    def get_model_names(self) -> list:
        """Return the registered names, each capitalized, in registry order."""
        return [registered.capitalize() for registered in self.model_dict]
|
|
|
|
|
class Preprocessor:
    """Build the image preprocessing pipeline appropriate for a model."""

    def __init__(self):
        # Per-channel normalization applied as the last pipeline step
        # (the standard ImageNet mean/std used by torchvision's pretrained models).
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def preprocess(self, model_name : str) -> Compose:
        """Return the resize/crop/tensor/normalize pipeline for *model_name*.

        The name is matched case-insensitively, so both ``'inception'`` and
        the capitalized ``'Inception'`` shown in the UI select the Inception
        settings.

        Args:
            model_name: Name of the model the image will be fed to.

        Returns:
            A ``Compose`` that maps a PIL image to a normalized tensor.
        """
        if model_name.lower() == 'inception':
            # Inception v3 takes 299x299 crops. The image must be resized to
            # at least that size first, otherwise CenterCrop pads with zeros;
            # 342 is the resize torchvision documents for its Inception weights.
            resize_size, crop_size = 342, 299
        else:
            resize_size, crop_size = 256, 224

        return transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            self.normalize,
        ])
|
|
|
|
|
class Postprocessor:
    """Convert raw model outputs into a ``{label: probability}`` mapping."""

    def __init__(self, labels : list):
        # labels[i] is the human-readable name of class index i.
        self.labels = labels

    def _top_confidences(self, logits, k : int = 5) -> dict:
        """Softmax a 1-D logits vector and return its top-*k* classes.

        *k* is clamped to the number of classes so that ``torch.topk`` does
        not fail on small outputs. Entries are inserted in descending
        probability order; classes that share the same label text collapse
        into a single entry.
        """
        probabilities : T = softmax(logits, dim=0)
        k = min(k, probabilities.size(0))
        top_prob, top_catid = torch.topk(probabilities, k)
        return {
            self.labels[class_idx]: prob
            for prob, class_idx in zip(top_prob.tolist(), top_catid.tolist())
        }

    def postprocess_default(self, output) -> dict:
        """Return the top-5 confidences for the first sample of a batch.

        Args:
            output: Batched logits; ``output[0]`` is the first sample.
        """
        return self._top_confidences(output[0])

    def postprocess_inception(self, output) -> dict:
        """Return the top-5 confidences for an Inception output.

        Inception v3 returns a ``(logits, aux_logits)`` tuple in training
        mode and a plain logits tensor in eval mode. Both forms are accepted
        and the main logits of the first sample are used.
        """
        # Indexing a plain (1, C) tensor with [1] would be out of bounds, and
        # on a tuple it would select the auxiliary head, so pick the main
        # logits explicitly before taking the first sample.
        logits = output[0] if isinstance(output, tuple) else output
        return self._top_confidences(logits[0])
|
|
|
|
|
class ImageClassifier:
    """Run the preprocess -> model -> postprocess pipeline on one image."""

    def __init__(self, model_loader : ModelLoader, preprocessor: Preprocessor, postprocessor : Postprocessor):
        self.model_loader = model_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor

    def classify(self, input_image : Image, selected_model : str) -> dict:
        """Classify *input_image* with the model named *selected_model*.

        Args:
            input_image: The image to classify, as accepted by the
                preprocessing pipeline.
            selected_model: Model name; matched case-insensitively.

        Returns:
            The mapping produced by the postprocessor for the model output.

        Raises:
            ValueError: if no image or no model name is supplied, or if the
                model loader does not know the requested model.
        """
        # The UI can trigger this callback before an image or a model has
        # been chosen; fail with a clear message instead of an obscure
        # AttributeError/TypeError further down.
        if input_image is None:
            raise ValueError("No image provided for classification")
        if not selected_model:
            raise ValueError("No model selected for classification")

        # Normalize once so the preprocessor and the output dispatch below
        # see the same lowercase name regardless of how the UI spells it.
        model_key = selected_model.lower()

        preprocess_input : Compose = self.preprocessor.preprocess(model_name=model_key)
        input_tensor : T = preprocess_input(input_image)
        input_batch = input_tensor.unsqueeze(0)  # add the batch dimension
        model = self.model_loader.load_model(selected_model)

        if torch.cuda.is_available():
            input_batch = input_batch.to('cuda')
            model.to('cuda')

        model.eval()
        with torch.no_grad():
            output : T = model(input_batch)

        if model_key == 'inception':
            return self.postprocessor.postprocess_inception(output)
        return self.postprocessor.postprocess_default(output)
|
|
|
|
|
class CIFAR10ImageProvider:
    """Serve random images from the CIFAR-10 test split."""

    def __init__(self, dataset_root='./data', transform = transforms.ToTensor()):
        self.dataset_root = dataset_root
        self.transform = transform
        # The dataset is built lazily on first use and then reused, so each
        # request does not construct the dataset object again.
        self._dataset = None
        self._dataset_key = None

    def _get_dataset(self):
        """Return the cached CIFAR-10 test dataset, building it if needed.

        The cache is keyed on the current root and transform so that changing
        either attribute after construction still takes effect.
        """
        key = (self.dataset_root, self.transform)
        if self._dataset is None or self._dataset_key != key:
            self._dataset = datasets.CIFAR10(
                root=self.dataset_root,
                train=False,
                download=True,
                transform=self.transform,
            )
            self._dataset_key = key
        return self._dataset

    def get_random_image(self, resize_dim=(256, 256)) -> Image:
        """Return one uniformly chosen test image, resized to *resize_dim*.

        Args:
            resize_dim: ``(width, height)`` passed to ``resize``.

        Returns:
            The selected sample converted to a PIL image and resized.
        """
        cifar10 = self._get_dataset()
        random_idx = random.randrange(len(cifar10))
        image, _ = cifar10[random_idx]  # the class label is not needed
        image = transforms.ToPILImage()(image)
        return image.resize(resize_dim)
|
|
|
|
|
class GradioApp:
    """Two-tab Gradio front end: classify an upload or a random CIFAR-10 image."""

    def __init__(self, image_classifier : ImageClassifier, image_provider : CIFAR10ImageProvider, model_list : list):
        self.image_classifier = image_classifier
        self.image_provider = image_provider
        self.model_list = model_list

    def _build_upload_tab(self):
        """Create the tab where the user uploads an image and classifies it."""
        with TabItem("Upload Image"):
            with Row():
                with Column():
                    image_input = Image(type='pil', label="Upload Image")
                    model_choice = Dropdown(self.model_list, label="Select Model")
                    run_button = Button("Classify")
                with Column():
                    prediction = Label(num_top_classes=5)
            run_button.click(
                self.image_classifier.classify,
                inputs=[image_input, model_choice],
                outputs=prediction,
            )

    def _build_random_tab(self):
        """Create the tab that draws a random dataset image and classifies it."""
        with TabItem("Generate Random Image"):
            with Row():
                with Column():
                    draw_button = Button("Generate Random Image")
                    drawn_image = Image(type='pil', label="Random CIFAR-10 Image")
                with Column():
                    model_choice = Dropdown(self.model_list, label="Select Model")
                    run_button = Button("Classify")
                    prediction = Label(num_top_classes=5)
            draw_button.click(
                self.image_provider.get_random_image,
                inputs=[],
                outputs=drawn_image,
            )
            run_button.click(
                self.image_classifier.classify,
                inputs=[drawn_image, model_choice],
                outputs=prediction,
            )

    def launch(self):
        """Assemble the interface and start the Gradio server."""
        with Blocks() as demo:
            with Tabs():
                self._build_upload_tab()
                self._build_random_tab()
        demo.launch()
|
|
|
|
|
def main():
    """Wire up the classifier components and launch the Gradio app."""
    model_loader = ModelLoader(IMAGE_PREDICTION_MODELS)
    preprocessor = Preprocessor()

    # Download the class-label list. A timeout keeps startup from hanging on
    # a stalled connection, and raise_for_status prevents an HTTP error page
    # from being silently used as the label list.
    response = requests.get("https://git.io/JJkYN", timeout=30)
    response.raise_for_status()
    # Assumes one label per line in class-index order -- TODO confirm against
    # the downloaded file. splitlines() avoids the trailing empty entry that
    # split("\n") leaves behind, and strip() removes stray "\r"/whitespace.
    # Lines are not filtered, so indices stay aligned with class ids.
    labels = [line.strip() for line in response.text.splitlines()]

    postprocessor = Postprocessor(labels)
    image_classifier = ImageClassifier(model_loader, preprocessor, postprocessor)
    image_provider = CIFAR10ImageProvider()
    model_list = model_loader.get_model_names()

    app = GradioApp(image_classifier, image_provider, model_list)
    app.launch()


if __name__ == "__main__":
    main()