File size: 4,918 Bytes
68a69f9
 
 
 
 
04c9e40
ce088ab
 
04c9e40
 
68a69f9
 
 
0ddd61a
68a69f9
 
 
 
 
0ddd61a
68a69f9
 
 
 
ce088ab
68a69f9
ce088ab
68a69f9
ce088ab
 
68a69f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aeaceee
68a69f9
 
 
 
 
 
 
 
 
 
aeaceee
 
c7ab35b
68a69f9
aeaceee
 
 
 
 
 
 
 
 
 
68a69f9
 
 
 
ce088ab
 
68a69f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce088ab
68a69f9
 
aeaceee
 
 
ce088ab
68a69f9
 
 
 
 
 
ce088ab
68a69f9
ce088ab
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
import os
# Mentioning detectron2 as a dependency directly in requirements.txt makes pip try to
# install detectron2 before torch, which fails even when torch is listed earlier.
# Hence detectron2 is installed from source here instead (workaround for Gradio
# Hugging Face Spaces). This statement runs on every import of this module.
# NOTE(review): os.system returns the exit status rather than raising, so a failed
# install presumably only surfaces as an ImportError at the detectron2 imports below;
# it also uses whichever `pip` is on PATH, which may not belong to this interpreter.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former model and its matching image processor.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a GPU,
    otherwise on CPU, and is switched to evaluation mode.

    Args:
        model_ckpt: Checkpoint name or path accepted by ``from_pretrained``.

    Returns:
        A ``(model, image_processor)`` tuple.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt)
    network = network.to(target)
    network.eval()
    processor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return network, processor

def load_default_ckpt(segmentation_task: str):
    """Return the default Hugging Face checkpoint name for a segmentation task.

    "semantic" and "instance" have dedicated checkpoints; any other value
    (including "panoptic") falls back to the COCO panoptic checkpoint.
    """
    task_to_ckpt = (
        ("semantic", "facebook/mask2former-swin-tiny-ade-semantic"),
        ("instance", "facebook/mask2former-swin-small-coco-instance"),
    )
    for task, ckpt in task_to_ckpt:
        if segmentation_task == task:
            return ckpt
    return "facebook/mask2former-swin-tiny-coco-panoptic"

def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    """Render a panoptic prediction over ``image`` with detectron2's Visualizer.

    Args:
        predicted_segmentation_map: Tensor of per-pixel segment ids (the
            ``"segmentation"`` entry of the panoptic post-processing result);
            ``.cpu()`` is called on it before drawing.
        seg_info: List of segment dicts (the ``"segments_info"`` entry). Each
            dict must contain a ``"label_id"`` key. Neither the list nor its
            dicts are modified.
        image: PIL image the prediction was made on.

    Returns:
        A PIL image with the panoptic overlay drawn at alpha 0.5.
    """
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # Build the "thing" id set once instead of re-scanning dict values per segment.
    thing_ids = set(metadata.thing_dataset_id_to_contiguous_id.values())

    # detectron2 expects "category_id"/"isthing" keys. Build copies so the caller's
    # seg_info is left intact (popping "label_id" in place broke repeated calls).
    segments = []
    for res in seg_info:
        segment = {key: value for key, value in res.items() if key != "label_id"}
        segment["category_id"] = res["label_id"]
        segment["isthing"] = res["label_id"] in thing_ids
        segments.append(segment)

    # Visualizer requires an RGB array. A PIL image converted to RGB already is one,
    # so no channel flip is applied (the old [:, :, ::-1] handed it BGR data).
    rgb = np.array(image.convert("RGB"))
    visualizer = Visualizer(rgb, metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), segments, alpha=0.5
    )
    return Image.fromarray(out.get_image())

def draw_semantic_segmentation(segmentation_map, image, palette):
    """Blend a color-coded semantic segmentation map over ``image``.

    Args:
        segmentation_map: 2-D integer array (H, W) of predicted class ids.
        image: RGB image (PIL image or array) of shape (H, W, 3).
        palette: Sequence of RGB triples; class id ``k`` is drawn with
            ``palette[k]``. Ids outside ``[0, len(palette))`` stay uncolored.

    Returns:
        uint8 array (H, W, 3): a 50/50 blend of the image and the colored map.
    """
    seg = np.asarray(segmentation_map)
    colors = np.asarray(palette, dtype=np.uint8).reshape(-1, 3)
    color_segmentation_map = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)

    # Model predictions use 0-based class ids, so palette[k] colors class k directly.
    # The former `segmentation_map - 1 == label` looks like the convention for 1-based
    # ground-truth annotations; on predictions it left class 0 uncolored and shifted
    # every other class by one color.
    valid = (seg >= 0) & (seg < len(colors))
    color_segmentation_map[valid] = colors[seg[valid].astype(np.intp)]

    # The overlay stays in RGB to match the RGB input image (no BGR flip).
    img = np.array(image) * 0.5 + color_segmentation_map * 0.5
    return img.astype(np.uint8)

def visualize_instance_seg_mask(mask, input_image):
    """Blend randomly colored instance masks over ``input_image``.

    Args:
        mask: 2-D numpy array (H, W) of per-pixel instance ids. Negative ids are
            treated as "no instance" and left uncolored (presumably the -1 fill
            used by the instance post-processor; confirm against the transformers
            version in use).
        input_image: RGB image (PIL image or array) of shape (H, W, 3).

    Returns:
        uint8 array (H, W, 3): a 50/50 blend of the image and the colored masks.
    """
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)

    for label in np.unique(mask):
        if label < 0:
            continue
        # Every channel spans 0-255 (the red channel previously used randint(0, 1)).
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        # Color each id's own pixels. The former `mask - 1 == label` shift dropped
        # instance 0 when no background pixels existed, and dropped ids after gaps.
        color_segmentation_map[mask == label, :] = color

    img = np.array(input_image) * 0.5 + color_segmentation_map * 0.5
    return img.astype(np.uint8)

def predict_masks(input_img_path: str, segmentation_task: str):
    """Run Mask2Former on an image file and return a rendered visualization.

    Args:
        input_img_path: Path to the input image.
        segmentation_task: "semantic", "instance", or any other value for
            panoptic (the same fallback as ``load_default_ckpt``).

    Returns:
        ``(output_result, output_heading)``. ``output_result`` is what the matching
        drawing helper returns (a uint8 array for semantic/instance, a PIL image
        for panoptic); ``output_heading`` is a short title string.
    """
    # Load the task's default checkpoint and its image processor.
    default_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_ckpt)

    # Decode inside a context manager so the file handle is closed. Force RGB so
    # grayscale / RGBA / palette images work with the 3-channel drawing helpers.
    with Image.open(input_img_path) as img:
        image = img.convert("RGB")

    # The model may have been moved to CUDA. Inputs must live on the same device,
    # otherwise the forward pass fails with a device-mismatch error.
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing expects (height, width).
    target_sizes = [image.size[::-1]]

    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=target_sizes)[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)
        output_heading = "Semantic Segmentation Output"

    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=target_sizes)[0]
        predicted_instance_map = result["segmentation"].cpu().detach().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)
        output_heading = "Instance Segmentation Output"

    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=target_sizes)[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result['segments_info']
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
        output_heading = "Panoptic Segmentation Output"

    return output_result, output_heading