interior-design / preprocessing.py
BertChristiaens's picture
Duplicate from BertChristiaens/controlnet-seg-backup
3d4d894
raw
history blame
4.53 kB
"""Preprocessing methods"""
import logging
from typing import List, Tuple
import numpy as np
from PIL import Image, ImageFilter
import streamlit as st
from config import COLOR_RGB, WIDTH, HEIGHT
# from enhance_config import ENHANCE_SETTINGS
# Module-level logger, named after this module via ``__name__``.
LOGGING = logging.getLogger(__name__)
def preprocess_seg_mask(canvas_seg, real_seg: Image.Image = None) -> Tuple[np.ndarray, np.ndarray]:
    """Preprocess the segmentation mask drawn on the canvas.

    Args:
        canvas_seg: segmentation canvas; its ``image_data`` attribute is read as a
            3-D array whose first three channels are used as the RGB drawing.
        real_seg (Image.Image, optional): segmentation mask onto which the drawn
            colors are painted. Defaults to None.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - integer mask of shape (H, W): 1 where the mean of the pixel's RGB
              channels is non-zero, 0 elsewhere;
            - the segmentation image: the canvas RGB data when ``real_seg`` is
              None, otherwise ``real_seg`` with the drawn colors painted over it.
    """
    # Keep only the RGB channels of the canvas drawing (copy so the canvas is untouched).
    image_seg = canvas_seg.image_data.copy()[:, :, :3]

    # A pixel counts as drawn when the mean over its channels is non-zero.
    # Multiply unconditionally so the mask is an integer array even when
    # nothing has been drawn.
    mask = (np.mean(image_seg, axis=2) > 0) * 1

    # Colors covering more than 100 pixels. return_counts yields the per-color
    # pixel counts in one pass instead of rescanning the image for every color.
    colors, counts = np.unique(
        image_seg.reshape(-1, image_seg.shape[-1]), axis=0, return_counts=True)
    unique_colors = [tuple(color) for color, count in zip(colors, counts) if count > 100]

    # Restrict to colors that are known segmentation classes.
    unique_colors_exact = [color for color in unique_colors if color in COLOR_RGB]

    if real_seg is not None:
        overlay_seg = np.array(real_seg)
        for color in unique_colors_exact:
            # White and black are never painted onto the real segmentation.
            if color != (255, 255, 255) and color != (0, 0, 0):
                overlay_seg[np.all(image_seg == color, axis=-1)] = color
        image_seg = overlay_seg

    return mask, image_seg
def get_mask(image_mask: np.ndarray) -> np.ndarray:
    """Get the binary mask from a segmentation image.

    A pixel is marked when the mean over its channels (axis 2) is non-zero.

    Args:
        image_mask (np.ndarray): segmentation image indexed as (H, W, C).

    Returns:
        np.ndarray: integer mask of shape (H, W) with 1 for marked pixels and
        0 elsewhere.
    """
    # Multiply unconditionally so the returned dtype is integer even when no
    # pixel is marked (a bare comparison would yield a boolean array).
    return (np.mean(image_mask, axis=2) > 0) * 1
def get_image() -> np.ndarray:
    """Return the session's initial image, resized to (WIDTH, HEIGHT), as an array.

    The value stored under ``st.session_state['initial_image']`` may be a PIL
    image or anything ``Image.fromarray`` accepts.

    Returns:
        np.ndarray: the resized image, or None when no initial image is stored.
    """
    stored = st.session_state['initial_image'] if 'initial_image' in st.session_state else None
    if stored is None:
        return None

    # Normalise to a PIL image so both input kinds share one resize path.
    if not isinstance(stored, Image.Image):
        stored = Image.fromarray(stored)
    return np.array(stored.resize((WIDTH, HEIGHT)))
def make_enhance_config(segmentation, objects=None):
    """Make the enhance config for the segmentation image.

    Looks up ``ENHANCE_SETTINGS[objects]`` and builds the conditioning kwargs.
    NOTE: the ``from enhance_config import ENHANCE_SETTINGS`` line at the top of
    this module is commented out, so calling this raises NameError until that
    import is restored.

    Args:
        segmentation: segmentation image (anything ``np.array`` accepts); assumed
            to be (H, W, C) with pixel colors comparable to ``info['colors']`` —
            TODO confirm against callers.
        objects: key into ``ENHANCE_SETTINGS``.

    Returns:
        tuple: ``(conditioning, info['inpainting'])`` where ``conditioning`` holds
        ``mask_image``, the prompts and, for non-inpainting settings,
        ``controlnet_conditioning_image`` and ``strength``.

    Side effects:
        Writes ``positive_prompt`` and ``negative_prompt`` into ``st.session_state``.
    """
    info = ENHANCE_SETTINGS[objects]
    segmentation = np.array(segmentation)

    # Select the configured colors BEFORE any replacement: the 'replace' step
    # overwrites exactly these pixels, so matching afterwards would always give
    # an empty (or, when inverted, a full) mask.
    selected = np.zeros(segmentation.shape[:2], dtype=bool)
    for color in info['colors']:
        selected |= np.all(segmentation == color, axis=-1)

    if info['inverse'] is False:
        mask = np.zeros(segmentation.shape)
        mask[selected] = 1
    else:
        mask = np.ones(segmentation.shape)
        mask[selected] = 0

    if 'replace' in info:
        segmentation[selected] = info['replace']

    st.session_state['positive_prompt'] = info['positive_prompt']
    st.session_state['negative_prompt'] = info['negative_prompt']

    if info['inpainting'] is True:
        # NOTE(review): mask holds 0/1 values, so after the uint8 conversion the
        # blurred image should stay within 0/1 and the 0.1 threshold acts like
        # ``mask >= 1``. If a soft dilation was intended, scaling to 0..255 first
        # may have been meant — confirm before changing.
        mask = mask.astype(np.uint8)
        mask = Image.fromarray(mask)
        mask = mask.filter(ImageFilter.GaussianBlur(radius=13))
        mask = mask.filter(ImageFilter.MaxFilter(size=9))
        mask = np.array(mask)
        mask[mask < 0.1] = 0
        mask[mask >= 0.1] = 1
        mask = mask.astype(np.uint8)
        conditioning = dict(
            mask_image=mask,
            positive_prompt=info['positive_prompt'],
            negative_prompt=info['negative_prompt'],
        )
    else:
        conditioning = dict(
            mask_image=mask,
            controlnet_conditioning_image=segmentation,
            positive_prompt=info['positive_prompt'],
            negative_prompt=info['negative_prompt'],
            strength=info['strength']
        )
    return conditioning, info['inpainting']