# Page header captured along with this file (not part of the program):
#   author: adrianzarbock
#   commit: b3a841c ("Update app.py")
#   size:   5.37 kB
# general setup
import os
import subprocess
import sys

# Install runtime dependencies before importing them.
# `sys.executable -m pip` makes sure the packages are installed into the very
# interpreter running this app; a bare `pip` found on PATH may belong to a
# different Python installation.
# NOTE(review): the detectron2 wheel index is the cu102 / torch1.9 build;
# confirm it matches the torch version installed in the deployment.
_RUNTIME_INSTALLS = [
    ['detectron2', '-f', 'https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html'],
    ['opencv-python'],
]
for _args in _RUNTIME_INSTALLS:
    # List form: no shell is involved in building the command.
    _result = subprocess.run([sys.executable, '-m', 'pip', 'install', *_args])
    if _result.returncode != 0:
        # Stay best-effort (the packages may already be present), but make the
        # failure visible instead of silently ignoring the exit status.
        print(f"warning: pip install {' '.join(_args)} exited with code {_result.returncode}",
              file=sys.stderr)
# setup detectron2 logger
import torch
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common libraries
import functools
import json
import os

import cv2
import numpy as np
import pandas as pd
from PIL import Image
from torch import nn
from torchvision import models, transforms
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
# import gradio
import gradio as gr
# set device
DEVICE = 'cpu'
# Build the classifier: a ResNet-18 whose final layer is replaced by a 5-way
# linear head. No ImageNet weights are requested: load_state_dict below is
# strict by default, so the checkpoint overwrites every parameter and buffer,
# and downloading pretrained weights would only slow down start-up.
model = models.resnet18()
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 5)
# insert trained parameters, mapped onto DEVICE so the checkpoint loads wherever
# the model is going to run
model.load_state_dict(torch.load('model_modernity.pth', map_location=DEVICE))
# keep the model on the same device that the input tensors are moved to
model.to(DEVICE)
# inference mode: BatchNorm layers use their running statistics
model.eval()
# Per-channel normalization statistics (the standard ImageNet values used for
# ResNet); presumably also what model_modernity.pth was trained with -- confirm.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Preprocessing applied to an image before it is fed to the classifier:
# resize to 224x224, convert to a float tensor, then normalize per channel.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# define input and outputs
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
# namespaces; confirm the deployed Gradio version still provides them.
# Uploaded photo, handed to the callback as a numpy array.
i1 = gr.inputs.Image(label="Input image", type="numpy")
# First value returned by the callback: the processed image.
o1 = gr.outputs.Image(label="Cropped image", type="pil")
# Second value returned by the callback: the score (or a message), shown as text.
o2 = gr.outputs.Textbox(label="Modernity score")
@functools.lru_cache(maxsize=None)
def _get_detector():
    """Create the COCO Mask R-CNN predictor and its dataset metadata, once.

    Building a DefaultPredictor loads the model checkpoint, so the result is
    cached instead of being rebuilt for every request.

    Returns:
        tuple: ``(predictor, metadata)`` -- a detectron2 ``DefaultPredictor`` and
        the ``MetadataCatalog`` entry of the config's first training dataset,
        which supplies ``thing_classes`` and the visualizer metadata.
    """
    config_name = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # detection score threshold
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    cfg.MODEL.DEVICE = DEVICE
    # NOTE(review): the predictor is now shared between requests; confirm this
    # is acceptable if Gradio is configured to run calls concurrently.
    return DefaultPredictor(cfg), MetadataCatalog.get(cfg.DATASETS.TRAIN[0])


def _crop_on_white(im, item_mask):
    """Cut the masked object out of *im* and place it on a white background.

    Args:
        im: image array indexed as ``im[y, x, channel]`` with RGB channels.
        item_mask: boolean mask with the same height/width as *im*; must
            contain at least one True pixel.

    Returns:
        PIL.Image.Image: RGB image the size of the mask's bounding box, keeping
        the object's pixels and white everywhere else.
    """
    rows, cols = np.nonzero(item_mask)
    # +1 because slice ends are exclusive: without it the last row/column of
    # the object is cut off, and a 1-pixel-wide mask gives an empty crop.
    y_min, y_max = int(rows.min()), int(rows.max()) + 1
    x_min, x_max = int(cols.min()), int(cols.max()) + 1
    cropped = Image.fromarray(im[y_min:y_max, x_min:x_max, :], mode='RGB')
    cropped_mask = Image.fromarray(
        (item_mask[y_min:y_max, x_min:x_max] * 255).astype('uint8'))
    background = Image.new(mode='RGB', size=cropped.size, color='white')
    # Take pixels from `cropped` where the mask is 255, from `background` where 0.
    return Image.composite(cropped, background, cropped_mask)


def _modernity_score(img):
    """Return the classifier's expected label for *img*, rounded to one decimal.

    The softmax output is treated as a distribution over the labels
    0 .. C-1 (C = number of classifier outputs), and the score is its mean.
    """
    img_t = test_transform(img).to(DEVICE)
    with torch.no_grad():  # inference only: no autograd graph needed
        probs = torch.softmax(model(img_t[None, :]), dim=1)
    label_classes = torch.arange(probs.shape[1], dtype=probs.dtype, device=probs.device)
    return round((label_classes * probs).sum(dim=1).item(), 1)


def modernity(im):
    """Gradio callback: find the largest car in *im* and score its design modernity.

    Args:
        im: uploaded image as a numpy array. This file treats it as RGB (the
            crop is handed to PIL as mode 'RGB'), whereas detectron2's
            DefaultPredictor expects BGR input, so the channels are reversed
            before detection.

    Returns:
        tuple: ``(image, result)``:

        * no image uploaded: ``(None, 'Please upload an image.')``;
        * no usable car detected: the input with every detection drawn on it
          (numpy array) and ``'No automobiles were found in the image.'``;
        * otherwise: the largest car (by mask area) cut out on white (PIL
          image) and its modernity score, a float rounded to one decimal.
    """
    if im is None:
        return None, 'Please upload an image.'
    predictor, metadata = _get_detector()
    instances = predictor(im[:, :, ::-1])['instances'].to('cpu')
    masks = instances.pred_masks.numpy()
    names = [metadata.thing_classes[c] for c in instances.pred_classes.tolist()]
    areas = masks.sum(axis=(1, 2))  # mask size in pixels, one per detection
    car_ids = [i for i, name in enumerate(names) if name == 'car']
    # Largest car mask wins; max() keeps the first detection on ties.
    best = max(car_ids, key=lambda i: areas[i], default=None)
    if best is None or areas[best] == 0:
        # Nothing to crop: show all detections instead. `im` is already RGB,
        # which is what Visualizer expects, so no channel flipping is needed.
        visualizer = Visualizer(im, metadata, scale=1.2)
        img = visualizer.draw_instance_predictions(instances).get_image()
        return img, 'No automobiles were found in the image.'
    img = _crop_on_white(im, masks[best])
    return img, _modernity_score(img)
# Text shown above the demo.
title = 'Design Modernity of Automobiles'
description = (
    "Demo for design modernity of automobiles. To use it, simply upload your image, "
    "or click one of the examples to load them."
)
# Bundled sample images; each inner list is one example row for the single input.
examples = [['input.jpg'], ['input1.jpg']]
# Wire the callback to its input/output components; examples are not precomputed.
interface = gr.Interface(
    fn=modernity,
    inputs=i1,
    outputs=[o1, o2],
    title=title,
    description=description,
    examples=examples,
    cache_examples=False,
)
# Start the web server.
interface.launch()