Spaces:
Sleeping
Sleeping
import os
import importlib
import subprocess
import sys


def _ensure_installed(module_name, pip_requirement):
    """Make sure ``module_name`` is importable, pip-installing it if it is not.

    Args:
        module_name: Name passed to ``importlib.import_module`` to probe for
            the package.
        pip_requirement: Requirement specifier handed to ``pip install`` when
            the probe raises ``ImportError``.

    A failed installation is reported on stderr but not raised here; the
    regular ``import`` statements further down the file will raise
    ``ImportError`` if the package is still missing.
    """
    try:
        importlib.import_module(module_name)
    except ImportError:
        # Use "<this interpreter> -m pip" with an argument list: this installs
        # into the environment that is actually running the app and avoids
        # building a shell command string.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", pip_requirement],
            check=False,
        )
        if result.returncode != 0:
            print(
                f"pip install {pip_requirement} failed with exit code "
                f"{result.returncode}",
                file=sys.stderr,
            )
        # Let the import system notice packages installed while running.
        importlib.invalidate_caches()


# Install the heavy optional dependencies on first start if they are missing.
_ensure_installed("detectron2", "git+https://github.com/facebookresearch/detectron2.git")
_ensure_installed("cv2", "opencv-python")
import gradio as gr | |
import torch | |
from torch import nn | |
from torchvision import models, transforms | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from modernity import get_year_modernity_score | |
from typicality import get_typicality_score | |
from viewpoint import get_viewpoint | |
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE: an Apple MPS branch (torch.backends.mps.is_available() -> 'mps')
# existed in the original code but was commented out, so MPS is never selected.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Instance segmentation: Mask R-CNN (R50-FPN, 3x schedule) taken from the
# Detectron2 model zoo. These imports live here, not at the top of the
# file, so they run after the detectron2 install check.
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

# Model-zoo entry used for both the config file and the checkpoint URL.
_SEGMENTATION_CONFIG = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"

cfg = get_cfg()
cfg.MODEL.DEVICE = DEVICE
cfg.merge_from_file(model_zoo.get_config_file(_SEGMENTATION_CONFIG))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(_SEGMENTATION_CONFIG)
# Score threshold applied to detections at test time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
predictor = DefaultPredictor(cfg)
###---MODELS: VIEWPOINT / MODERNITY / TYPICALITY---###
def _load_resnet18_classifier(checkpoint_path, num_classes):
    """Build a ResNet-18 classifier and load fine-tuned weights into it.

    Args:
        checkpoint_path: Path of the state-dict file passed to ``torch.load``.
        num_classes: Output size of the replacement fully connected head.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.
    """
    # weights=None: the strict load_state_dict below replaces every parameter
    # and buffer, so fetching ImageNet weights first (weights=True, which is
    # also a deprecated spelling) was only startup overhead.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(DEVICE)
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    # The app only runs inference, so batch-norm layers should use their
    # stored statistics. NOTE(review): the scoring helpers imported from
    # modernity/typicality/viewpoint may already call eval(); harmless if so.
    return model.eval()


# Viewpoint classifier: 2 output classes.
model_viewpoint = _load_resnet18_classifier('models_morphs/pdl-a4-viewpoint-best-model.pt', num_classes=2)
# Modernity classifier: 5 output classes.
model_modernity = _load_resnet18_classifier('models_morphs/pdl-a3-modernity-best-model.pt', num_classes=5)
# Typicality classifier: 5 output classes.
model_typicality = _load_resnet18_classifier('models_morphs/pdl-a3-typicality-best-model.pt', num_classes=5)
# Keys of the score dictionary returned by predict()
LABELS = [
    'Design Modernity score (scaled)',
    'Design Typicality score',
]

# Per-channel mean / std used by the final normalisation step
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _keep_rgb(tensor):
    """Return only the first three channels, discarding e.g. an alpha channel."""
    return tensor[:3, :, :]


# Preprocessing applied to the cropped image before it is fed to the classifiers
imgTransforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_keep_rgb),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
# Class id that the segmentation output uses for automobiles.
_CAR_CLASS_ID = 2

# Value painted over every pixel outside the car mask. The earlier code added
# a 255 "alpha" plane onto a 255 background in uint8, which wraps around to
# 254; the value is kept so the classifiers receive the same pixels as before.
# NOTE(review): presumably the classifiers were trained on this background -
# confirm before changing it to pure white (255).
_BACKGROUND_VALUE = np.uint8(254)


def _largest_instance(instances):
    """Return the one detection whose bounding box has the largest area."""
    boxes = instances.pred_boxes.tensor
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return instances[int(torch.argmax(areas))]


def _cut_out_instance(image, instance):
    """Crop ``image`` to the instance's box, blanking pixels outside its mask."""
    # The mask lives on the model's device; it must be moved to the CPU before
    # NumPy can use it (np.array() on a CUDA tensor raises).
    car_mask = instance.pred_masks[0].cpu().numpy().astype(bool)
    rgb = image[:, :, :3].astype("uint8")
    isolated = np.where(car_mask[:, :, None], rgb, _BACKGROUND_VALUE).astype(np.uint8)
    x0, y0, x1, y1 = (int(v) for v in instance.pred_boxes.tensor[0].tolist())
    return isolated[y0:y1, x0:x1].copy()


# Predict classes
def predict(inp):
    """Segment the largest automobile in an image and score its design.

    Args:
        inp: Image as an (H, W, C) array in RGB channel order, or ``None``
            when no image was submitted.

    Returns:
        A 5-tuple matching the interface outputs: a status message about the
        detected automobiles, the cropped automobile image (or ``None``), a
        message about the viewpoint check (or ``None``), a dict mapping the
        entries of ``LABELS`` to scores (or ``None``), and a text summary of
        the scores (or ``None``).
    """
    if inp is None:
        return "No image was provided.", None, None, None, None

    ###---------------INSTANCE SEGMENTATION---------------###
    # The predictor is fed the channel-reversed (BGR) view of the RGB input.
    detections = predictor(inp[:, :, ::-1])["instances"]
    cars = detections[detections.pred_classes == _CAR_CLASS_ID]
    if len(cars) == 0:
        return "Found no automobile instances in the image.", None, None, None, None
    found_message = "Found " + str(len(cars)) + " automobile instances in the image."

    ###---------------BACKGROUND MASK + CROP---------------###
    # Only the detection with the largest bounding box is analysed.
    car_image = _cut_out_instance(inp, _largest_instance(cars))

    ###---------------TRANSFORM IMAGE---------------###
    batch = imgTransforms(Image.fromarray(car_image)).unsqueeze(0).to(DEVICE)

    ###---------------VIEWPOINT CHECK---------------###
    # A viewpoint of 0 means "not frontal"; the design scores are skipped then.
    viewpoint = get_viewpoint(model_viewpoint, batch, DEVICE)
    if viewpoint == 0:
        return (
            found_message,
            car_image,
            "The model couldn't detect a frontal viewpoint of the automobile. It cannot proceed to calculate the design scores.",
            None, None
        )

    ###---------------CLASSIFICATIONS---------------###
    modernity_score, year_group = get_year_modernity_score(model_modernity, batch, DEVICE)
    typicality_score = get_typicality_score(model_typicality, batch, year_group, DEVICE)

    # The modernity score is divided by 4 for the "scaled" label.
    labels_with_scores = {
        LABELS[0]: float(modernity_score) / 4,
        LABELS[1]: float(typicality_score),
    }
    score_text = (
        "Design Modernity Score (absolute): " + str(modernity_score) + "\n"
        + "Design Typicality Score: " + str(typicality_score)
    )
    return (
        found_message,
        car_image,
        "The model has detected a frontal viewpoint of the automobile. It proceeded to calculate the design scores.",
        labels_with_scores,
        score_text,
    )
# Gradio front end: texts, example images and the interface itself
TITLE = "Segmentation and Classification of Automobiles"
DESCRIPTION = (
    "Demo for the segmentation and classification of automobiles along with their design modernity and typicality scores. "
    "To use it, simply upload your image, or click one of the examples to load them."
)
EXAMPLES = [
    ['sample_imgs/BMW_6_front.jpg'],
    ['sample_imgs/BMW_6_back.jpg'],
    ['sample_imgs/BMW_multiple.jpg'],
    ['sample_imgs/BMW_multiple_back.jpeg'],
    ['sample_imgs/pony.jpeg'],
]

# One output component per element of the tuple returned by predict()
_OUTPUT_COMPONENTS = ["label", "image", "label", "label", "text"]

iface = gr.Interface(
    fn=predict,
    inputs='image',
    outputs=_OUTPUT_COMPONENTS,
    title=TITLE,
    description=DESCRIPTION,
    examples=EXAMPLES,
    cache_examples=False,
)
iface.launch()