import os
import time
from typing import Tuple

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

def dilate_image_mask(image_mask: Image.Image, dilate_size=50):
    """Dilates a PIL mask image with a square kernel of the given size (in pixels)."""
    # Convert the PIL image to a NumPy array
    image_np = np.array(image_mask)
    kernel = np.ones((dilate_size, dilate_size), np.uint8)
    dilated_image_np = cv2.dilate(image_np, kernel, iterations=1)
    # Convert the dilated NumPy array back to PIL format
    dilated_image = Image.fromarray(dilated_image_np)

    return dilated_image

def get_foreground_image(image: Image.Image, mask_array: np.ndarray) -> Image.Image:
    """Returns a PIL RGBA image with the mask applied to the original image."""

    # Resize the mask to the original image size; NEAREST keeps the mask strictly binary
    resized_mask = Image.fromarray(mask_array.astype(np.uint8)).resize(image.size, resample=Image.NEAREST)
    resized_mask = np.array(resized_mask)

    image_array = np.array(image)
    # Apply the binary mask element-wise to each color channel
    fg_array = image_array * resized_mask[:, :, np.newaxis]
    # Create a new ndarray with 4 channels (R, G, B, A)
    result_array = np.zeros((*fg_array.shape[:2], 4), dtype=np.uint8)
    # Assign RGB values from the masked image
    result_array[:, :, :3] = fg_array
    # Assign alpha values from the resized mask
    result_array[:, :, 3] = resized_mask * 255
    result_image = Image.fromarray(result_array, mode='RGBA')

    return result_image


def overlay_mask_on_image(image: Image.Image, mask_array: np.ndarray, alpha=0.5) -> Image.Image:
    """Overlays the mask on the image as a semi-transparent green layer."""
    original_image = image
    overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 0))

    # Resize the overlay mask to the original image size
    overlay_mask = Image.fromarray(mask_array.astype(np.uint8) * 255).resize(original_image.size, resample=Image.LANCZOS)

    # Dilate the mask a bit to cover the edges of the objects
    overlay_mask = dilate_image_mask(overlay_mask, dilate_size=50)

    # Apply the overlay color where the mask is set
    overlay_color = (0, 240, 0, int(255 * alpha))  # RGBA
    draw = ImageDraw.Draw(overlay_image)
    draw.bitmap((0, 0), overlay_mask, fill=overlay_color)

    result_image = Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
    return result_image

def filter_segment_classes(segmentation, filter_classes, mode='filt_out') -> np.ndarray:
    """Returns a boolean mask selecting pixels of the segmentation array by class id.

    mode: 'filt_out' - keeps every class except those in filter_classes
          'filt_in'  - keeps only the classes in filter_classes
    """
    if mode == 'filt_out':
        overlay_mask = ~np.isin(segmentation, filter_classes)
    elif mode == 'filt_in':
        overlay_mask = np.isin(segmentation, filter_classes)
    else:
        raise ValueError(f'Invalid mode: {mode}')
    return overlay_mask
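
# Worked example for filter_segment_classes (hypothetical values): for a
# segmentation array [[0, 4], [7, 0]] and filter_classes=[0], mode='filt_out'
# returns [[False, True], [True, False]] - class 0 is masked out and all
# remaining classes are kept as foreground.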

class Mask2FormerSegmenter:
    def __init__(self):
        self.processor = None
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # TODO - train a classifier to learn this from the dataset 
        # - classes that appear much less frequently are good candidates
        self.filter_classes = [0,1,2,3,5,6,10,11,12,13,14,15,18,19,22,24,36,38,40,45,46,47,69,105,128] 
    
    def load_models(self, checkpoint_name):
        self.processor = AutoImageProcessor.from_pretrained(checkpoint_name)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint_name)
        self.model.to(self.device)
    
    @torch.no_grad()
    def run_semantic_inference(self, image, model, processor) -> torch.Tensor:
        """Runs semantic segmentation inference on a single PIL image."""

        if (model is None) or (processor is None):
            raise ValueError('Model or processor not loaded.')

        funcstart_time = time.time()

        inputs = processor(image, return_tensors="pt")
        inputs = inputs.to(self.device)
        # Forward pass - segments the image
        outputs = model(**inputs)
        # Measure the time taken by the preprocessing and the forward pass
        model_time = time.time() - funcstart_time
        print(f'Model time: {model_time:.2f} s')

        # Post-processing - per-pixel semantic class ids
        semantic_segmentation = processor.post_process_semantic_segmentation(outputs)[0]
        return semantic_segmentation
            
    def batch_inference_demo(self, dirpath):

        # List image files in the input directory
        image_files = [file for file in os.listdir(dirpath) if file.lower().endswith(('.jpg', '.jpeg', '.png'))]

        # Make sure the output folder exists before saving into it
        outp_folder = 'results/mask2former_masked'
        os.makedirs(outp_folder, exist_ok=True)

        for file in tqdm(image_files, desc="Processing images"):
            filepath = os.path.join(dirpath, file)
            image = Image.open(filepath)
            semantic_segmentation = self.run_semantic_inference(image, self.model, self.processor)

            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'{file}: {valid_ids}')

            # Filter out the classes in filter_classes (move to CPU before converting to NumPy)
            binary_mask = filter_segment_classes(semantic_segmentation.cpu().numpy(), self.filter_classes)

            overlaid_img = overlay_mask_on_image(image, binary_mask)
            foreground_img = get_foreground_image(image, binary_mask)
            mask_img = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize(image.size)
            # Dilate the mask a bit
            mask_img = dilate_image_mask(mask_img, dilate_size=50)

            # Save the images in the results folder
            stem = os.path.splitext(file)[0]
            overlaid_img.save(f"{outp_folder}/{stem}_overlay.png")
            foreground_img.save(f"{outp_folder}/{stem}_foreground.png")
            mask_img.save(f"{outp_folder}/{stem}_mask.png")

    def retrieve_fg_image_and_mask(self, input_image: Image.Image,
                                   dilate_size=50,
                                   verbose=False
                                   ) -> Tuple[Image.Image, Image.Image]:
        """Generates an RGBA image with the foreground objects of the input image
        and a binary mask for the given image.

        input_image: PIL image
        dilate_size: size in pixels of the dilation kernel to apply to the objects' mask
        verbose: if True, prints the list of classes in the image that have not been filtered

        returns: foreground_img (RGBA), mask_img (L)
        """

        # Run the semantic segmentation model
        semantic_segmentation = self.run_semantic_inference(input_image,
                                                            self.model,
                                                            self.processor)
        semantic_segmentation = semantic_segmentation.cpu()

        if verbose:
            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'valid classes detected: {valid_ids}')

        # Filter out the classes in filter_classes
        binary_mask = filter_segment_classes(semantic_segmentation.numpy(),
                                             self.filter_classes)
        foreground_img = get_foreground_image(input_image, binary_mask)
        mask_img = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize(input_image.size, resample=Image.LANCZOS)
        # Dilate the mask a bit to cover the edges of the objects; this helps the inpainting model
        mask_img = dilate_image_mask(mask_img, dilate_size=dilate_size)

        return foreground_img, mask_img
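
# Minimal usage sketch wiring the class end to end. The checkpoint name and the
# file paths below are illustrative assumptions, not values from this module;
# any Mask2Former semantic-segmentation checkpoint with an ADE20K-style label
# space should work with the default filter_classes list.
if __name__ == '__main__':
    segmenter = Mask2FormerSegmenter()
    segmenter.load_models('facebook/mask2former-swin-large-ade-semantic')  # assumed checkpoint

    # Single-image path: returns an RGBA foreground cutout and a dilated binary mask
    image = Image.open('sample.jpg')  # hypothetical input image
    foreground_img, mask_img = segmenter.retrieve_fg_image_and_mask(image, dilate_size=50, verbose=True)
    foreground_img.save('sample_foreground.png')
    mask_img.save('sample_mask.png')

    # Batch path: writes *_overlay / *_foreground / *_mask PNGs to results/mask2former_masked
    # segmenter.batch_inference_demo('images')  # hypothetical input folder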