import os
import time

import cv2
import numpy as np
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation


def dilate_image_mask(image_mask: Image.Image, dilate_siz: int = 50) -> Image.Image:
    """Dilates a binary PIL mask with a square kernel of dilate_siz pixels."""
    # Convert the PIL image to a NumPy array
    image_np = np.array(image_mask)
    kernel = np.ones((dilate_siz, dilate_siz), np.uint8)
    dilated_image_np = cv2.dilate(image_np, kernel, iterations=1)
    # Convert the dilated NumPy array back to PIL format
    return Image.fromarray(dilated_image_np)


def get_foreground_image(image: Image.Image, mask_array: np.ndarray) -> Image.Image:
    """Returns a PIL RGBA image with the mask applied to the original image."""
    # Resize the mask to the original image size (NEAREST keeps it binary)
    resized_mask = Image.fromarray(mask_array.astype(np.uint8)).resize(image.size, resample=Image.NEAREST)
    resized_mask = np.array(resized_mask)
    image_array = np.array(image)
    # Apply the binary mask element-wise to each color channel
    fg_array = image_array * resized_mask[:, :, np.newaxis]
    # Create a new ndarray with 4 channels (R, G, B, A)
    result_array = np.zeros((*fg_array.shape[:2], 4), dtype=np.uint8)
    # Assign RGB values from the masked image
    result_array[:, :, :3] = fg_array
    # Assign alpha values from the resized mask
    result_array[:, :, 3] = resized_mask * 255
    return Image.fromarray(result_array, mode='RGBA')


def overlay_mask_on_image(image: Image.Image, mask_array: np.ndarray, alpha: float = 0.5) -> Image.Image:
    """Overlays the mask on the image as a semi-transparent green layer."""
    overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 0))
    # Resize the overlay mask to the original image size
    overlay_mask = Image.fromarray(mask_array.astype(np.uint8) * 255).resize(image.size, resample=Image.LANCZOS)
    # Dilate the mask a bit to cover the edges of the objects
    overlay_mask = dilate_image_mask(overlay_mask, dilate_siz=50)
    # Apply the overlay color wherever the mask is non-zero
    overlay_color = (0, 240, 0, int(255 * alpha))  # RGBA
    draw = ImageDraw.Draw(overlay_image)
    draw.bitmap((0, 0), overlay_mask, fill=overlay_color)
    return Image.alpha_composite(image.convert('RGBA'), overlay_image)
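
# Illustrative sketch (not part of the original module): exercises the mask
# helpers above on a tiny synthetic image so their contracts are easy to check
# without running the segmentation model. Sizes and mask region are arbitrary.
def _demo_mask_helpers():
    rgb = Image.new('RGB', (64, 64), (200, 50, 50))  # solid red test image
    mask = np.zeros((64, 64), dtype=bool)
    mask[16:48, 16:48] = True                        # square foreground region
    fg = get_foreground_image(rgb, mask)             # RGBA, transparent outside the mask
    overlay = overlay_mask_on_image(rgb, mask, alpha=0.5)
    dilated = dilate_image_mask(Image.fromarray(mask.astype(np.uint8) * 255), dilate_siz=5)
    assert fg.mode == 'RGBA' and overlay.mode == 'RGBA' and dilated.size == (64, 64)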

def filter_segment_classes(segmentation: np.ndarray, filter_classes: list, mode: str = 'filt_out') -> np.ndarray:
    """Returns a boolean mask selecting classes of the segmentation array.

    mode: 'filt_out' - filters out the classes in filter_classes
          'filt_in'  - keeps only the classes in filter_classes
    """
    if mode == 'filt_out':
        overlay_mask = ~np.isin(segmentation, filter_classes)
    elif mode == 'filt_in':
        overlay_mask = np.isin(segmentation, filter_classes)
    else:
        raise ValueError(f'Invalid mode: {mode}')
    return overlay_mask


class Mask2FormerSegmenter:
    def __init__(self):
        self.processor = None
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # TODO - train a classifier to learn this from the dataset
        #      - classes that appear much less frequently are good candidates
        self.filter_classes = [0, 1, 2, 3, 5, 6, 10, 11, 12, 13, 14, 15, 18, 19,
                               22, 24, 36, 38, 40, 45, 46, 47, 69, 105, 128]

    def load_models(self, checkpoint_name: str):
        self.processor = AutoImageProcessor.from_pretrained(checkpoint_name)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint_name)
        self.model.to(self.device)

    @torch.no_grad()
    def run_semantic_inference(self, image, model, processor) -> torch.Tensor:
        """Runs semantic segmentation inference on a single image."""
        if (model is None) or (processor is None):
            raise ValueError('Model or Processor not loaded.')
        funcstart_time = time.time()
        inputs = processor(image, return_tensors="pt")
        inputs = inputs.to(self.device)
        # Forward pass - segments the image
        outputs = model(**inputs)
        # Measures the time taken for the preprocessing and forward pass
        model_time = time.time() - funcstart_time
        print(f'Model time: {model_time:.2f}s')
        # Post-processing - per-pixel semantic segmentation map
        semantic_segmentation = processor.post_process_semantic_segmentation(outputs)[0]
        return semantic_segmentation

    def batch_inference_demo(self, dirpath):
        # List image files in the input directory
        image_files = [file for file in os.listdir(dirpath)
                       if file.lower().endswith(('.jpg', '.jpeg', '.png'))]
        for file in tqdm(image_files, desc="Processing images"):
            filepath = os.path.join(dirpath, file)
            image = Image.open(filepath)
            semantic_segmentation = self.run_semantic_inference(image, self.model, self.processor)
            semantic_segmentation = semantic_segmentation.cpu()
            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'{os.path.basename(file)}: {valid_ids}')
            # Filter out the classes in filter_classes
            binary_mask = filter_segment_classes(semantic_segmentation.numpy(), self.filter_classes)
            overlaid_img = overlay_mask_on_image(image, binary_mask)
            foreground_img = get_foreground_image(image, binary_mask)
            mask_img = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize(image.size)
            # Dilate the mask a bit
            mask_img = dilate_image_mask(mask_img, dilate_siz=50)
            # Save the images in the results folder
            outp_folder = 'results/mask2former_masked'
            os.makedirs(outp_folder, exist_ok=True)
            basename = os.path.basename(file).split('.')[0]
            overlaid_img.save(f"{outp_folder}/{basename}_overlay.png")
            foreground_img.save(f"{outp_folder}/{basename}_foreground.png")
            mask_img.save(f"{outp_folder}/{basename}_mask.png")
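
    # Illustrative helper sketch (not part of the original API): resolves the
    # numeric ids in self.filter_classes to label names via the checkpoint's
    # id2label mapping, which Hugging Face Mask2Former configs expose. Useful
    # when hand-tuning the filter list mentioned in the TODO above.
    def print_filter_class_names(self):
        if self.model is None:
            raise ValueError('Model not loaded. Call load_models() first.')
        id2label = self.model.config.id2label
        for class_id in self.filter_classes:
            print(f'{class_id}: {id2label.get(class_id, "<unknown>")}')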

    def retrieve_fg_image_and_mask(self, input_image: Image.Image, dilate_siz: int = 50,
                                   verbose: bool = False) -> tuple:
        """Generates an RGBA image with the foreground objects of the input image and a binary mask.

        input_image: PIL image
        dilate_siz: size in pixels of the dilation kernel to apply on the objects' mask
        verbose: if True, prints the list of classes in the image that have not been filtered
        returns: foreground_img (RGBA), mask_img (L)
        """
        # Run the semantic segmentation model
        semantic_segmentation = self.run_semantic_inference(input_image, self.model, self.processor)
        semantic_segmentation = semantic_segmentation.cpu()
        if verbose:
            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'valid classes detected: {valid_ids}')
        # Filter out the classes in filter_classes
        binary_mask = filter_segment_classes(semantic_segmentation.numpy(), self.filter_classes)
        foreground_img = get_foreground_image(input_image, binary_mask)
        mask_img = Image.fromarray(binary_mask.astype(np.uint8) * 255).resize(input_image.size, resample=Image.LANCZOS)
        # Dilate the mask a bit to cover the edges of the objects; this helps the inpainting model
        mask_img = dilate_image_mask(mask_img, dilate_siz=dilate_siz)
        return foreground_img, mask_img
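
# Minimal usage sketch. Assumptions not taken from the code above: the ADE20K
# checkpoint name ('facebook/mask2former-swin-tiny-ade-semantic' is one of the
# published Mask2Former checkpoints) and the local test file 'input.jpg'.
if __name__ == '__main__':
    segmenter = Mask2FormerSegmenter()
    segmenter.load_models('facebook/mask2former-swin-tiny-ade-semantic')
    img = Image.open('input.jpg').convert('RGB')
    fg, mask = segmenter.retrieve_fg_image_and_mask(img, dilate_siz=50, verbose=True)
    fg.save('foreground.png')
    mask.save('mask.png')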