import cv2
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation


class SegmentationTool:
    """Semantic segmentation helper: predicts per-pixel ADE20K classes, builds
    binary and transparent masks for selected classes, and guesses the room type."""

    def __init__(self,
                 segmentation_version='nvidia/segformer-b5-finetuned-ade-640-640'):
        self.segmentation_version = segmentation_version
        if segmentation_version == "openmmlab/upernet-convnext-tiny":
            self.feature_extractor = AutoImageProcessor.from_pretrained(self.segmentation_version)
            self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(self.segmentation_version)
        elif segmentation_version == "nvidia/segformer-b5-finetuned-ade-640-640":
            self.feature_extractor = SegformerFeatureExtractor.from_pretrained(self.segmentation_version)
            self.segmentation_model = SegformerForSemanticSegmentation.from_pretrained(self.segmentation_version)
        else:
            # Fail fast instead of hitting an AttributeError on first use
            raise ValueError(f"unsupported segmentation_version: {segmentation_version}")

    def _predict(self, image):
        """Run the segmentation model and return the per-pixel class map as a numpy array."""
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            outputs = self.segmentation_model(**inputs)
        prediction = self.feature_extractor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]])[0]
        # Convert to numpy for the np.isin / np.unique operations downstream
        return prediction.cpu().numpy()

    def _save_mask(self, prediction_array, mask_items=None):
        """Build a binary mask: 0 for the kept classes, 255 (plus a dilation buffer) elsewhere."""
        if mask_items is None:
            mask_items = []
        mask = np.zeros_like(prediction_array, dtype=np.uint8)
        mask[~np.isin(prediction_array, mask_items)] = 255
        buffer_size = 10
        # Dilate the binary mask so the 255 region grows by a small buffer
        kernel = np.ones((buffer_size, buffer_size), np.uint8)
        dilated_image = cv2.dilate(mask, kernel, iterations=1)
        # The buffer area is the dilated ring around the original mask
        buffer_area = dilated_image - mask
        # Merge the buffer ring back into the mask
        mask = cv2.bitwise_or(mask, buffer_area)
        # Wrap the array as an 8-bit grayscale PIL image
        mask_image = Image.fromarray(mask, mode='L')
        return mask_image

    def _save_transparent_mask(self, img, prediction_array, mask_items=None):
        """Return an RGBA copy of img where every pixel outside mask_items is transparent."""
        if mask_items is None:
            mask_items = []
        mask = np.array(img)
        # Paint pixels that are not in mask_items white; 255 acts as the transparency sentinel
        mask[~np.isin(prediction_array, mask_items), :] = 255
        mask_image = Image.fromarray(mask).convert('RGBA')
        # Make the white sentinel pixels fully transparent and keep the rest opaque.
        # Caveat: genuinely white pixels inside the kept region are keyed out as well.
        mask_data = mask_image.getdata()
        mask_data = [(r, g, b, 0) if r == 255 else (r, g, b, 255) for (r, g, b, a) in mask_data]
        mask_image.putdata(mask_data)
        return mask_image

    def get_mask(self, image_path=None, image=None):
        """Segment an image and return (mask, transparent mask, image, item labels, room guess)."""
        if image_path:
            # Force RGB so the transparent-mask path always sees a 3-channel array
            image = Image.open(image_path).convert('RGB')
        elif image is None:
            raise ValueError("no image provided")
        prediction = self._predict(image)
        label_ids = np.unique(prediction)
        mask_items = [8]  # ADE20K class 8: windowpane
        # Guess the room type from characteristic ADE20K class ids
        # (e.g. 50 refrigerator, 73 kitchen island, 37 bathtub, 65 toilet,
        #  7 bed, 23 sofa, 49 fireplace, 15 table, 19 chair)
        if 73 in label_ids or 50 in label_ids or 61 in label_ids:
            room = 'kitchen'
        elif 37 in label_ids or 65 in label_ids or (27 in label_ids and 47 in label_ids and 70 in label_ids):
            room = 'bathroom'
        elif 7 in label_ids:
            room = 'bedroom'
        elif 23 in label_ids or 49 in label_ids:
            room = 'living room'
        elif 15 in label_ids and 19 in label_ids:
            room = 'dining room'
        else:
            room = 'room'
        label_ids_without_mask = [i for i in label_ids if i not in mask_items]
        items = [self.segmentation_model.config.id2label[i] for i in label_ids_without_mask]
        mask_image = self._save_mask(prediction, mask_items)
        transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)
        return mask_image, transparent_mask_image, image, items, room
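

# Minimal usage sketch. Assumptions not in the original file: a local photo at
# "room.jpg" (hypothetical path) and network access to download the default
# SegFormer checkpoint from the Hub on first run.
if __name__ == "__main__":
    tool = SegmentationTool()
    mask, transparent_mask, img, items, room = tool.get_mask(image_path="room.jpg")
    print(f"detected room type: {room}")
    print(f"detected items: {items}")
    mask.save("mask.png")                          # binary mask with dilation buffer
    transparent_mask.save("transparent_mask.png")  # RGBA cut-out of the kept classes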