from collections import defaultdict

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib import cm
import cv2
from PIL import Image
import numpy as np
import torch
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from diffusers import StableDiffusionInpaintPipeline


class VirtualStagingToolV2():
    def __init__(self,
                 segmentation_version='openmmlab/upernet-convnext-tiny',
                 diffusion_version="stabilityai/stable-diffusion-2-inpainting"):
        self.segmentation_version = segmentation_version
        self.diffusion_version = diffusion_version

        # Load the semantic segmentation model (both supported variants are trained on ADE20K).
        if segmentation_version == "openmmlab/upernet-convnext-tiny":
            self.feature_extractor = AutoImageProcessor.from_pretrained(self.segmentation_version)
            self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(self.segmentation_version)
        elif segmentation_version == "nvidia/segformer-b5-finetuned-ade-640-640":
            self.feature_extractor = SegformerFeatureExtractor.from_pretrained(self.segmentation_version)
            self.segmentation_model = SegformerForSemanticSegmentation.from_pretrained(self.segmentation_version)

        # Load the Stable Diffusion inpainting pipeline in half precision on the GPU.
        self.diffusion_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            self.diffusion_version,
            torch_dtype=torch.float16,
        )
        self.diffusion_pipeline = self.diffusion_pipeline.to("cuda")

    def _predict(self, image):
        # Run semantic segmentation and resize the label map back to the input image size.
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        outputs = self.segmentation_model(**inputs)
        prediction = self.feature_extractor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]])[0]
        return prediction

    def _save_mask(self, img, prediction_array, mask_items=[]):
        # Build a binary mask: pixels belonging to `mask_items` (structural elements to
        # preserve) stay black (0); everything else is white (255) and will be inpainted.
        mask = np.zeros_like(prediction_array, dtype=np.uint8)
        mask[np.isin(prediction_array, mask_items)] = 0
        mask[~np.isin(prediction_array, mask_items)] = 255

        # Dilate the binary mask to add a small buffer ring around the masked regions.
        buffer_size = 10
        kernel = np.ones((buffer_size, buffer_size), np.uint8)
        dilated_image = cv2.dilate(mask, kernel, iterations=1)

        # Isolate the buffer ring and merge it back into the mask.
        buffer_area = dilated_image - mask
        mask = cv2.bitwise_or(mask, buffer_area)

        # Create a PIL Image object from the mask.
        mask_image = Image.fromarray(mask, mode='L')
        return mask_image

    def _save_transparent_mask(self, img, prediction_array, mask_items=[]):
        # Copy the original image and whiten every pixel that is not a preserved item.
        mask = np.array(img)
        mask[~np.isin(prediction_array, mask_items), :] = 255
        mask_image = Image.fromarray(mask).convert('RGBA')

        # Make the whitened pixels fully transparent so only the preserved items stay visible.
        mask_data = mask_image.getdata()
        mask_data = [(r, g, b, 0) if r == 255 else (r, g, b, 255) for (r, g, b, a) in mask_data]
        mask_image.putdata(mask_data)
        return mask_image

    def get_mask(self, image_path=None, image=None):
        if image_path:
            image = Image.open(image_path)
        elif image is None:
            raise ValueError("no image provided")

        prediction = self._predict(image)
        label_ids = np.unique(prediction)

        # Infer the room type from the ADE20K labels present in the image, and pick the
        # label ids to preserve (by default wall, floor, ceiling, window, door).
        mask_items = [0, 3, 5, 8, 14]
        if 1 in label_ids or 25 in label_ids:
            mask_items = [1, 2, 4, 25, 32]
            room = 'backyard'
        elif 73 in label_ids or 50 in label_ids or 61 in label_ids:
            mask_items = [0, 3, 5, 8, 14, 50, 61, 71, 73, 118, 124, 129]
            room = 'kitchen'
        elif 37 in label_ids or 65 in label_ids or (27 in label_ids and 47 in label_ids and 70 in label_ids):
            mask_items = [0, 3, 5, 8, 14, 27, 65]
            room = 'bathroom'
        elif 7 in label_ids:
            room = 'bedroom'
        elif 23 in label_ids or 49 in label_ids:
            mask_items = [0, 3, 5, 8, 14, 49]
            room = 'living room'
        elif 15 in label_ids and 19 in label_ids:
            room = 'dining room'
        else:
            room = 'room'

        # Names of the detected objects that fall outside the preserved labels; these are
        # the items the inpainting step will restyle.
        label_ids_without_mask = [i for i in label_ids if i not in mask_items]
        items = [self.segmentation_model.config.id2label[i] for i in label_ids_without_mask]

        mask_image = self._save_mask(image, prediction, mask_items)
        transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)

        return mask_image, transparent_mask_image, image, items, room

    def _edit_image(self, init_image, mask_image, prompt, number_images=1):
        # Stable Diffusion 2 inpainting expects 512x512 RGB inputs.
        init_image = init_image.resize((512, 512)).convert("RGB")
        mask_image = mask_image.resize((512, 512)).convert("RGB")
        output_images = self.diffusion_pipeline(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            num_images_per_prompt=number_images).images
        return output_images

    def virtual_stage(self, image_path=None, image=None, style=None, color_preference=None,
                      additional_info=None, number_images=1):
        mask_image, transparent_mask_image, init_image, items, room = self.get_mask(image_path, image)
        if not style:
            raise ValueError('style not provided.')

        # Keep only the furniture relevant to the detected room when building the prompt.
        if room == 'kitchen':
            items = [i for i in items if i in ['cabinet', 'shelf', 'counter', 'countertop', 'stool']]
        elif room == 'bedroom':
            items = [i for i in items if i in ['bed', 'table', 'chest of drawers', 'desk', 'armchair', 'wardrobe']]
        elif room == 'bathroom':
            items = [i for i in items if i in ['shower', 'bathtub', 'screen door', 'cabinet']]
        elif room == 'living room':
            items = [i for i in items if i in ['table', 'sofa', 'chest of drawers', 'armchair', 'cabinet', 'coffee table']]
        elif room == 'dining room':
            items = [i for i in items if i in ['table', 'chair', 'cabinet']]
        items = ', '.join(items)

        # Compose the inpainting prompt from the room type, detected items, and user preferences.
        if room == 'backyard':
            prompt = f'Realistic, high resolution, {room} with {style}'
        else:
            prompt = f'Realistic {items}, high resolution, in the {style} style {room}'
        if color_preference:
            prompt = f"{prompt} in {color_preference}"
        if additional_info:
            prompt = f'{prompt}. {additional_info}'
        print(prompt)

        output_images = self._edit_image(init_image, mask_image, prompt, number_images)

        # Resize the 512x512 outputs back to the original image dimensions.
        final_output_images = []
        for output_image in output_images:
            output_image = output_image.resize(init_image.size)
            final_output_images.append(output_image)
        return final_output_images, transparent_mask_image
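

# Minimal usage sketch. Assumptions not taken from the class above: a CUDA GPU is
# available (the pipeline is moved to "cuda" in __init__), "empty_room.jpg" is a
# hypothetical input photo, and the style/color strings are example values only.
if __name__ == "__main__":
    tool = VirtualStagingToolV2()
    staged_images, transparent_mask = tool.virtual_stage(
        image_path="empty_room.jpg",
        style="mid-century modern",
        color_preference="warm neutral tones",
        number_images=2,
    )
    # Save each staged variant alongside the transparent overlay of preserved structure.
    for i, staged in enumerate(staged_images):
        staged.save(f"staged_{i}.png")
    transparent_mask.save("preserved_structure.png")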