img_backswapper / maskformer.py
jgurzoni's picture
creating gradio app
d7713d2
import os
import time
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import torch
import cv2
def dilate_image_mask(image_mask: Image, dilate_siz=50):
    """Return a dilated copy of a PIL mask image.

    image_mask: PIL image holding the mask to grow
    dilate_siz: side length, in pixels, of the square all-ones dilation kernel
    returns: a new PIL image built from the dilated pixel array
    """
    structuring_element = np.ones((dilate_siz, dilate_siz), np.uint8)
    # cv2 operates on arrays, so the mask makes a PIL -> NumPy -> PIL round trip
    grown_pixels = cv2.dilate(np.array(image_mask), structuring_element, iterations=1)
    return Image.fromarray(grown_pixels)
def get_foreground_image(image: "Image.Image", mask_array: np.ndarray):
    """Returns a PIL RGBA image with the mask applied to the original image.

    image: PIL image in any mode; it is converted to RGB before masking so
           that grayscale, palette and RGBA inputs are handled as well
    mask_array: 2-D mask, expected to hold boolean / 0-1 values (1 = keep the
                pixel); it is resized to the size of ``image``
    returns: RGBA image whose colour channels are the input pixels multiplied
             by the mask and whose alpha channel is the mask scaled to 0/255
    """
    # resize the overlay mask to the original image size
    resized_mask = Image.fromarray(mask_array.astype(np.uint8)).resize(image.size)
    resized_mask = np.array(resized_mask)

    # Force exactly three colour channels: without this an RGBA input yields a
    # (H, W, 4) array that cannot be written into the RGB slice below, and a
    # grayscale input yields a 2-D array that cannot be broadcast with the mask.
    image_array = np.array(image.convert('RGB'))

    # Apply binary mask element-wise using NumPy for each color channel
    fg_array = image_array * resized_mask[:, :, np.newaxis]

    # Create a new ndarray with 4 channels (R, G, B, A)
    result_array = np.zeros((*fg_array.shape[:2], 4), dtype=np.uint8)
    # Assign RGB values from the original image
    result_array[:, :, :3] = fg_array
    # Assign alpha values from the resized mask
    result_array[:, :, 3] = resized_mask * 255

    # A (H, W, 4) uint8 array is interpreted as RGBA
    return Image.fromarray(result_array)
def overlay_mask_on_image(image: "Image.Image", mask_array: np.ndarray, alpha=0.5, dilate_siz=50):
    """Returns an RGBA copy of the image with the mask drawn on top in green.

    image: PIL image to draw on (it is not modified)
    mask_array: 2-D mask, expected to hold boolean / 0-1 values; it is scaled
                to 0/255 and resized to the size of ``image``
    alpha: opacity of the overlay colour, 0.0 (invisible) to 1.0 (opaque)
    dilate_siz: size in pixels of the dilation kernel applied to the mask
    returns: RGBA image, the alpha-composite of the input and the overlay
    """
    overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 0))

    # resize the overlay mask to the original image size
    overlay_mask = Image.fromarray(mask_array.astype(np.uint8) * 255).resize(
        image.size, resample=Image.LANCZOS)
    # dilates the mask a bit to cover the edges of the objects.
    # dilate_image_mask returns a new image, so the result has to be kept;
    # calling it without assigning leaves the mask undilated.
    overlay_mask = dilate_image_mask(overlay_mask, dilate_siz=dilate_siz)

    # Paint the overlay colour through the mask onto the transparent layer
    overlay_color = (0, 240, 0, int(255 * alpha))  # RGBA
    draw = ImageDraw.Draw(overlay_image)
    draw.bitmap((0, 0), overlay_mask, fill=overlay_color)

    return Image.alpha_composite(image.convert('RGBA'), overlay_image)
def filter_segment_classes(segmentation, filter_classes, mode='filt_out') -> np.ndarray:
    """Build a boolean mask from a label map and a list of class ids.

    segmentation: array of class ids
    filter_classes: class ids to test each element of ``segmentation`` against
    mode: 'filt_out' - True where the class id is NOT in filter_classes
          'filt_in'  - True where the class id IS in filter_classes
    returns: boolean array with the same shape as ``segmentation``
    raises: ValueError for any other ``mode``
    """
    if mode not in ('filt_out', 'filt_in'):
        raise ValueError(f'Invalid mode: {mode}')

    is_listed = np.isin(segmentation, filter_classes)
    return is_listed if mode == 'filt_in' else ~is_listed
class Mask2FormerSegmenter:
    """Semantic segmentation with a Mask2Former checkpoint.

    Holds the image processor and the model, runs semantic inference and turns
    the resulting label map into a binary mask by dropping the class ids
    listed in ``filter_classes``.
    """

    def __init__(self):
        # Both are populated by load_models()
        self.processor = None
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Class ids that are excluded from the generated masks.
        # TODO - train a classifier to learn this from the dataset
        #      - classes that appear much less frequently are good candidates
        self.filter_classes = [0,1,2,3,5,6,10,11,12,13,14,15,18,19,22,24,36,38,40,45,46,47,69,105,128]

    def load_models(self, checkpoint_name):
        """Load processor and model for ``checkpoint_name`` and move the model to ``self.device``."""
        self.processor = AutoImageProcessor.from_pretrained(checkpoint_name)
        self.model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint_name)
        self.model.to(self.device)

    @torch.no_grad()
    def run_semantic_inference(self, image, model, processor) -> torch.Tensor:
        """Runs semantic segmentation inference on a single image.

        image: input image, handed to ``processor`` unchanged
        model: segmentation model to run
        processor: image processor used for pre- and post-processing
        returns: first entry of ``processor.post_process_semantic_segmentation``.
                 The tensor is not moved here, so it may still live on
                 ``self.device``; move it to the CPU before calling ``.numpy()``.
        raises: ValueError if ``model`` or ``processor`` is None
        """
        if model is None or processor is None:
            raise ValueError('Model or Processor not loaded.')

        funcstart_time = time.time()
        inputs = processor(image, return_tensors="pt").to(self.device)
        # Forward pass - to segment the image
        outputs = model(**inputs)
        # measures the time taken for the pre-processing and forward pass
        model_time = time.time() - funcstart_time
        print(f'Model time: {model_time:.2f}')

        # Post Processing - Semantic Segmentation
        return processor.post_process_semantic_segmentation(outputs)[0]

    def batch_inference_demo(self, dirpath):
        """Segments every .jpg/.jpeg/.png file in ``dirpath`` and saves the results.

        For each input three PNG files are written to
        ``results/mask2former_masked`` (created if missing): the image with the
        mask overlaid, the RGBA foreground and the dilated binary mask.

        dirpath: directory containing the input images
        """
        # List image files in the input directory
        image_files = sorted(
            file for file in os.listdir(dirpath)
            if file.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

        # Create the results folder up front; Image.save does not create it
        outp_folder = 'results/mask2former_masked'
        os.makedirs(outp_folder, exist_ok=True)

        for file in tqdm(image_files, desc="Processing images"):
            filepath = os.path.join(dirpath, file)
            # convert() loads the pixel data, so the file handle can be closed
            # right away; RGB also normalises RGBA / palette / grayscale files
            with Image.open(filepath) as opened:
                image = opened.convert('RGB')

            semantic_segmentation = self.run_semantic_inference(image, self.model, self.processor)
            # .numpy() below only works on CPU tensors
            semantic_segmentation = semantic_segmentation.cpu()

            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'{os.path.basename(file)}: {valid_ids}')

            # filter out the classes in filter_classes
            binary_mask = filter_segment_classes(semantic_segmentation.numpy(), self.filter_classes)
            overlaid_img = overlay_mask_on_image(image, binary_mask)
            foreground_img = get_foreground_image(image, binary_mask)
            mask_img = Image.fromarray(binary_mask.astype(np.uint8)*255).resize(image.size)
            # dilates the mask a bit
            mask_img = dilate_image_mask(mask_img, dilate_siz=50)

            # saves the images in the results folder; splitext only strips the
            # extension, so names containing dots ("a.b.jpg") stay distinct
            stem = os.path.splitext(os.path.basename(file))[0]
            overlaid_img.save(f"{outp_folder}/{stem}_overlay.png")
            foreground_img.save(f"{outp_folder}/{stem}_foreground.png")
            mask_img.save(f"{outp_folder}/{stem}_mask.png")

    def retrieve_fg_image_and_mask(self, input_image: "Image.Image",
                                   dilate_siz=50,
                                   verbose=False
                                   ) -> "tuple[Image.Image, Image.Image]":
        """Generates an RGBA image with the foreground objects of the input image
        and a binary mask for the given image.

        input_image: PIL image
        dilate_siz: size in pixels of the dilation kernel to apply on the objects' mask
        verbose: if True, prints the list of classes in the image that have not been filtered
        returns: foreground_img (RGBA), mask_img (L)
        raises: ValueError if the model or processor has not been loaded
        """
        # runs the semantic segmentation model
        semantic_segmentation = self.run_semantic_inference(input_image,
                                                            self.model,
                                                            self.processor)
        # NumPy conversion below requires a CPU tensor
        semantic_segmentation = semantic_segmentation.cpu()

        if verbose:
            labels_ids = torch.unique(semantic_segmentation).tolist()
            valid_ids = [label_id for label_id in labels_ids if label_id not in self.filter_classes]
            print(f'valid classes detected: {valid_ids}')

        # filter out the classes in filter_classes
        binary_mask = filter_segment_classes(semantic_segmentation.numpy(),
                                             self.filter_classes)
        foreground_img = get_foreground_image(input_image, binary_mask)
        mask_img = Image.fromarray(binary_mask.astype(np.uint8)*255).resize(input_image.size, resample=Image.LANCZOS)
        # dilates the mask a bit to cover the edges of the objects. This helps the inpainting model
        mask_img = dilate_image_mask(mask_img, dilate_siz=dilate_siz)
        return foreground_img, mask_img