# MSc_02_PDL_A4 / app.py
# Author: maxjmohr
# Commit 32d5b6e: Change to opencv installation
import importlib
import os
import subprocess
import sys


def _ensure_installed(module_name, pip_requirement):
    """Make *module_name* importable, pip-installing *pip_requirement* if it is missing.

    The install runs as ``sys.executable -m pip`` so the package lands in the
    interpreter that is running this app (a bare ``pip`` on PATH may belong to
    another environment). No shell is involved, so the requirement string is
    passed verbatim. A failed install raises ``subprocess.CalledProcessError``
    right here instead of surfacing later as an unrelated ImportError.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_requirement])
        # The import system caches directory listings; invalidate them so the
        # freshly installed package is found by the imports further down.
        importlib.invalidate_caches()


# Runtime bootstrap: these packages may be absent from the hosting environment,
# so they are installed on first start-up.
_ensure_installed("detectron2", "git+https://github.com/facebookresearch/detectron2.git")
_ensure_installed("cv2", "opencv-python")
import gradio as gr
import torch
from torch import nn
from torchvision import models, transforms
from PIL import Image
import cv2
import numpy as np
from modernity import get_year_modernity_score
from typicality import get_typicality_score
from viewpoint import get_viewpoint
# Inference device: the CUDA GPU when one is available, otherwise the CPU.
# (Apple-silicon MPS support was tried and is intentionally not enabled.)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load model for instance segmentation using Detectron2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2 import model_zoo

# Detectron2 model-zoo entry (COCO instance segmentation, Mask R-CNN R50-FPN 3x);
# used for both the config file and the matching pretrained checkpoint.
SEGMENTATION_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(SEGMENTATION_CONFIG))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(SEGMENTATION_CONFIG)
# Drop detections whose score is below 0.5.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
# Set the device only after merging the YAML so the config file cannot override it.
cfg.MODEL.DEVICE = DEVICE
predictor = DefaultPredictor(cfg)
def _load_resnet18_classifier(num_classes, checkpoint_path):
    """Build a ResNet-18 with a *num_classes*-way head and load fine-tuned weights.

    The backbone is created with ``weights=None``. ``load_state_dict`` (strict
    by default) overwrites every parameter and buffer anyway, so fetching the
    ImageNet checkpoint would only slow start-up. ``weights=True`` is also a
    deprecated legacy value in torchvision >= 0.13.

    Args:
        num_classes: number of outputs of the replacement ``fc`` layer; must
            match the checkpoint.
        checkpoint_path: path to a saved ``state_dict`` for this architecture.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(DEVICE)
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    # Inference only: eval mode makes BatchNorm use its stored running
    # statistics instead of statistics of the (single-image) input batch.
    model.eval()
    return model


###---MODEL: VIEWPOINT---###
model_viewpoint = _load_resnet18_classifier(2, 'models_morphs/pdl-a4-viewpoint-best-model.pt')
###---MODEL: MODERNITY---###
model_modernity = _load_resnet18_classifier(5, 'models_morphs/pdl-a3-modernity-best-model.pt')
###---MODEL: TYPICALITY---###
model_typicality = _load_resnet18_classifier(5, 'models_morphs/pdl-a3-typicality-best-model.pt')
# Keys of the score dictionary shown in the Gradio label output.
LABELS = [
    'Design Modernity score (scaled)',
    'Design Typicality score',
]

# Per-channel ImageNet mean / standard deviation used for normalisation.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]


def _first_three_channels(image_tensor):
    """Return only the first three (RGB) channels, dropping e.g. an alpha channel."""
    return image_tensor[:3, :, :]


# Preprocessing applied to the cropped car before it is fed to the classifiers.
imgTransforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(_first_three_channels),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
# Predict classes

# COCO class id the segmentation model uses for automobiles.
_CAR_CLASS_ID = 2

# Divisor applied to the absolute modernity score for the "scaled" label.
# NOTE(review): presumably the largest possible modernity score (the modernity
# head has 5 outputs) - confirm against modernity.get_year_modernity_score.
_MODERNITY_SCALE = 4


def _largest_box_index(boxes):
    """Return the index of the largest-area box in an (N, 4) tensor of (x0, y0, x1, y1) boxes.

    Bounding-box area is used as a cheap proxy for instance size; it is not the
    mask's pixel count.
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return int(torch.argmax(widths * heights))


def _cutout_on_white(rgb, mask, box):
    """Keep the pixels of *rgb* inside *mask*, paint the rest white, and crop to *box*.

    Args:
        rgb: (H, W, 3) image array.
        mask: (H, W) boolean array marking the object's pixels.
        box: (x0, y0, x1, y1) crop rectangle in pixel coordinates; each value
            is truncated with ``int``.

    Returns:
        A new (h, w, 3) uint8 array. The background is exactly 255; the previous
        alpha-sum approach wrapped around in uint8 and produced 254.
    """
    rgb = np.asarray(rgb).astype(np.uint8)
    keep = np.asarray(mask, dtype=bool)[:, :, None]
    cutout = np.where(keep, rgb, np.uint8(255)).astype(np.uint8)
    x0, y0, x1, y1 = (int(v) for v in box)
    return cutout[y0:y1, x0:x1].copy()


def predict(inp):
    """Segment the largest car in *inp* and score its design.

    Args:
        inp: RGB image as an (H, W, 3) numpy array from the Gradio ``'image'``
            input, or None if no image was supplied.

    Returns:
        A 5-tuple matching the interface outputs: (status message, cropped car
        image or None, viewpoint message or None, {label: score} dict or None,
        absolute-score text or None).
    """
    if inp is None:
        # Guard against an empty submission; the image input may yield None.
        return "No image was provided.", None, None, None, None

    ###---------------INSTANCE SEGMENTATION---------------###
    # The input is RGB; the predictor is fed BGR (Detectron2's default INPUT.FORMAT).
    outputs = predictor(inp[:, :, ::-1])
    all_instances = outputs["instances"]
    instances = all_instances[all_instances.pred_classes == _CAR_CLASS_ID]
    if len(instances) == 0:
        return "Found no automobile instances in the image.", None, None, None, None
    found_message = "Found " + str(len(instances)) + " automobile instances in the image."

    # Keep only the car with the largest bounding box.
    selected = instances[_largest_box_index(instances.pred_boxes.tensor)]

    ###---------------BACKGROUND MASK---------------###
    # Copy to host memory first: converting a CUDA tensor with np.array() raises.
    car_mask = selected.pred_masks[0].cpu().numpy()
    car_box = selected.pred_boxes.tensor[0].tolist()
    car_image = _cutout_on_white(inp, car_mask, car_box)

    ###---------------TRANSFORM IMAGE---------------###
    model_input = imgTransforms(Image.fromarray(car_image)).unsqueeze(0).to(DEVICE)

    ###---------------VIEWPOINT CHECK---------------###
    viewpoint = get_viewpoint(model_viewpoint, model_input, DEVICE)
    if viewpoint == 0:
        return (
            found_message,
            car_image,
            "The model couldn't detect a frontal viewpoint of the automobile. It cannot proceed to calculate the design scores.",
            None,
            None,
        )

    ###---------------CLASSIFICATIONS---------------###
    modernity_score, year_group = get_year_modernity_score(model_modernity, model_input, DEVICE)
    typicality_score = get_typicality_score(model_typicality, model_input, year_group, DEVICE)

    # Prepare the labels and scores
    labels_with_scores = {
        LABELS[0]: float(modernity_score) / _MODERNITY_SCALE,
        LABELS[1]: float(typicality_score),
    }
    return (
        found_message,
        car_image,
        "The model has detected a frontal viewpoint of the automobile. It proceeded to calculate the design scores.",
        labels_with_scores,
        "Design Modernity Score (absolute): " + str(modernity_score) + "\n"
        "Design Typicality Score: " + str(typicality_score),
    )
# Create Gradio Interface
TITLE = "Segmentation and Classification of Automobiles"
DESCRIPTION = (
    "Demo for the segmentation and classification of automobiles along with their "
    "design modernity and typicality scores. To use it, simply upload your image, "
    "or click one of the examples to load them."
)

# Example images shipped with the app; one per row because the interface has a single input.
_EXAMPLE_DIR = "sample_imgs/"
EXAMPLES = [
    [_EXAMPLE_DIR + file_name]
    for file_name in (
        "BMW_6_front.jpg",
        "BMW_6_back.jpg",
        "BMW_multiple.jpg",
        "BMW_multiple_back.jpeg",
        "pony.jpeg",
    )
]

# One component per element of the tuple returned by predict(): status message,
# cropped car image, viewpoint message, scaled-score labels, absolute-score text.
OUTPUTS = ["label", "image", "label", "label", "text"]

iface = gr.Interface(
    predict,
    inputs="image",
    outputs=OUTPUTS,
    title=TITLE,
    description=DESCRIPTION,
    examples=EXAMPLES,
    cache_examples=False,
)
iface.launch()