"""Streamlit front-end for ControlNet-based interior-design image editing.

The sidebar offers three generation modes:

* "Re-generate objects"       - choose segmented concepts and regenerate them.
* "Segmentation conditioning" - paint segmentation colours onto the canvas.
* "Inpainting"                - paint a mask onto the canvas and inpaint it.
"""
import os
import random
import time
from typing import Union

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from models import make_image_controlnet, make_inpainting
from segmentation import segment_image
from config import HEIGHT, WIDTH, POS_PROMPT, NEG_PROMPT, COLOR_MAPPING, map_colors, map_colors_rgb
from palette import COLOR_MAPPING_CATEGORY
from preprocessing import preprocess_seg_mask, get_image, get_mask

# Side length (px) of the square input canvas and of the displayed output image.
_DISPLAY_SIZE = 512

# wide layout
st.set_page_config(layout="wide")


def _clear_state(*keys: str) -> None:
    """Remove every key in *keys* from ``st.session_state`` if it is present."""
    for key in keys:
        if key in st.session_state:
            del st.session_state[key]


def _rerun() -> None:
    """Rerun the Streamlit script from the top.

    ``st.experimental_rerun`` is deprecated in favour of ``st.rerun``; prefer the
    new name and fall back to the experimental one only on old Streamlit versions.
    """
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _store_output_image(result_image) -> None:
    """Save a generated image as ``output_image``, converting numpy arrays to PIL."""
    if isinstance(result_image, np.ndarray):
        result_image = Image.fromarray(result_image)
    st.session_state['output_image'] = result_image


def on_upload() -> None:
    """Load the newly uploaded file as the canvas background image.

    Used as the ``on_change`` callback of the sidebar file uploader. Cached
    segmentation data and any previous output are discarded, and a canvas reset
    is requested (as ``move_image`` does) so strokes drawn on the previous image
    are not carried over to the new one.
    """
    if st.session_state.get('input_image') is None:
        return
    image = Image.open(st.session_state['input_image']).convert('RGB')
    st.session_state['initial_image'] = image
    st.session_state['reset_canvas'] = True
    _clear_state('seg', 'unique_colors', 'output_image')


def check_reset_state() -> bool:
    """Check whether the UI elements need to be reset.

    The ``reset_canvas`` flag is consumed: it is always cleared after reading.

    Returns:
        bool: True if the UI elements need to be reset, False otherwise.
    """
    needs_reset = bool(st.session_state.get('reset_canvas', False))
    st.session_state['reset_canvas'] = False
    return needs_reset


def move_image(source: Union[str, Image.Image], dest: str, rerun: bool = True, remove_state: bool = True) -> None:
    """Move image from source to destination.

    Args:
        source (Union[str, Image.Image]): source image, or the session-state key holding it.
        dest (str): session-state key to store the image under.
        rerun (bool, optional): rerun streamlit. Defaults to True.
        remove_state (bool, optional): reset the canvas and drop cached
            segmentation data. Defaults to True.
    """
    source_image = source if isinstance(source, Image.Image) else st.session_state[source]
    if remove_state:
        st.session_state['reset_canvas'] = True
        _clear_state('seg', 'unique_colors')
    st.session_state[dest] = source_image
    if rerun:
        _rerun()


def on_change_radio() -> None:
    """Reset the UI elements when the generation mode is changed."""
    st.session_state['reset_canvas'] = True


def make_canvas_dict(canvas_color, brush: int, paint_mode: str, _reset_state: bool) -> dict:
    """Build the keyword arguments passed to ``st_canvas``.

    Args:
        canvas_color: stroke and fill colour for the drawing tool.
        brush: stroke width.
        paint_mode: canvas drawing mode (e.g. "freedraw" or "polygon").
        _reset_state: when True, an empty drawing is passed as
            ``initial_drawing`` (intended to clear earlier strokes).

    Returns:
        dict: keyword arguments for ``st_canvas``.
    """
    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=st.session_state.get('initial_image'),
        stroke_width=brush,
        initial_drawing={'version': '4.4.0', 'objects': []} if _reset_state else None,
        update_streamlit=True,
        height=_DISPLAY_SIZE,
        width=_DISPLAY_SIZE,
        drawing_mode=paint_mode,
        key="canvas",
    )


def make_prompt_row() -> None:
    """Render the positive / negative prompt inputs side by side."""
    col_0_0, col_0_1 = st.columns(2)
    with col_0_0:
        st.text_input(label="Positive prompt",
                      value="a photograph of a room, interior design, 4k, high resolution",
                      key='positive_prompt')
    with col_0_1:
        st.text_input(label="Negative prompt", value="", key='negative_prompt')


def _choose_paint_mode_and_brush():
    """Render the painting-mode selector and, for freehand drawing, a stroke-width slider.

    Must be called inside the ``with st.sidebar:`` block so the slider lands in
    the sidebar.

    Returns:
        tuple: ``(paint_mode, brush)``; polygon mode uses a fixed width of 5.
    """
    paint_mode = st.sidebar.selectbox("Painting mode", ("freedraw", "polygon"))
    if paint_mode == "freedraw":
        brush = st.slider("Stroke width", 5, 140, 100, key='slider_seg')
    else:
        brush = 5
    return paint_mode, brush


def make_sidebar():
    """Render the sidebar controls.

    Returns:
        tuple: ``(input_image, generation_mode, brush, color_chooser, paint_mode)``
        where ``input_image`` is the file-uploader value, ``brush`` the stroke
        width and ``color_chooser`` the canvas colour for the selected mode.
    """
    with st.sidebar:
        input_image = st.file_uploader("", type=["png", "jpg"], key='input_image', on_change=on_upload)
        generation_mode = st.selectbox(
            "Generation mode",
            ["Re-generate objects", "Segmentation conditioning", "Inpainting"],
            on_change=on_change_radio,
        )
        if generation_mode == "Segmentation conditioning":
            paint_mode, brush = _choose_paint_mode_and_brush()
            category_chooser = st.sidebar.selectbox(
                "Filter on category", list(COLOR_MAPPING_CATEGORY.keys()), index=0, key='category_chooser'
            )
            chosen_colors = list(COLOR_MAPPING_CATEGORY[category_chooser].keys())
            color_chooser = st.sidebar.selectbox(
                "Choose a color", chosen_colors, index=0, format_func=map_colors, key='color_chooser'
            )
        elif generation_mode == "Re-generate objects":
            # Transparent, zero-width brush: this mode never reads canvas strokes.
            color_chooser = "rgba(0, 0, 0, 0.0)"
            paint_mode = 'freedraw'
            brush = 0
        else:
            paint_mode, brush = _choose_paint_mode_and_brush()
            color_chooser = "#000000"
    return input_image, generation_mode, brush, color_chooser, paint_mode


def make_output_image() -> None:
    """Show the latest generated image (or a white placeholder) and the move button."""
    has_output = 'output_image' in st.session_state
    if has_output:
        output_image = st.session_state['output_image']
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)
        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((_DISPLAY_SIZE, _DISPLAY_SIZE))
    else:
        output_image = Image.new('RGB', (_DISPLAY_SIZE, _DISPLAY_SIZE), (255, 255, 255))

    st.write("#### Output image")
    st.image(output_image, width=_DISPLAY_SIZE)
    if st.button("Move to input image"):
        # Without a generated image, move_image would raise KeyError on the
        # missing 'output_image' session-state entry.
        if has_output:
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
        else:
            st.warning("Generate an image first.")


def _segmentation_conditioning_ui(canvas_dict: dict) -> None:
    """Segmentation-conditioning mode: combine painted colours with the image segmentation."""
    canvas = st_canvas(**canvas_dict)
    if st.button("generate image", key='generate_button'):
        # get_image() presumably returns the input as a numpy array (it is fed to Image.fromarray).
        image = get_image()
        print("Preparing image segmentation")
        real_seg = segment_image(Image.fromarray(image))
        mask, seg = preprocess_seg_mask(canvas, real_seg)
        with st.spinner(text="Generating image"):
            print("Making image")
            result_image = make_image_controlnet(image=image,
                                                 mask_image=mask,
                                                 controlnet_conditioning_image=seg,
                                                 positive_prompt=st.session_state['positive_prompt'],
                                                 negative_prompt=st.session_state['negative_prompt'],
                                                 seed=random.randint(0, 100000)  # nosec
                                                 )[0]
        _store_output_image(result_image)


def _regenerate_objects_ui(canvas_dict: dict) -> None:
    """Re-generate mode: let the user pick segmented concepts and regenerate them."""
    st_canvas(**canvas_dict)

    # Segmentation and its colour palette are cached until a new image is loaded.
    if 'seg' not in st.session_state:
        with st.spinner(text="Preparing image segmentation"):
            image = get_image()
            real_seg = np.array(segment_image(Image.fromarray(image)))
            st.session_state['seg'] = real_seg
    if 'unique_colors' not in st.session_state:
        real_seg = st.session_state['seg']
        # Distinct pixel colours across the segmentation map (one row per pixel).
        unique_colors = np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)
        st.session_state['unique_colors'] = [tuple(color) for color in unique_colors]

    chosen_colors = st.multiselect(
        label="Choose which concepts you want to regenerate in the image",
        options=st.session_state['unique_colors'],
        key='chosen_colors',
        default=st.session_state['unique_colors'],
        format_func=map_colors_rgb,
    )

    with st.expander("Explanation", expanded=False):
        st.write("This mode allows you to choose which objects you want to re-generate in the image. "
                 "Use the selection dropdown to add or remove objects. If you are ready, press the generate button"
                 " to generate the image, which can take up to 30 seconds. If you want to improve the generated image, click"
                 " the 'Move to input image' button."
                 )

    if st.button("generate image", key='generate_button'):
        image = get_image()
        print(chosen_colors)
        segmentation = st.session_state['seg']
        mask = np.zeros_like(segmentation)
        for color in chosen_colors:
            # Mark every pixel whose segmentation colour matches (all channels set to 1).
            mask[(segmentation == color).all(axis=2)] = 1
        with st.spinner(text="Generating image"):
            result_image = make_image_controlnet(image=image,
                                                 mask_image=mask,
                                                 controlnet_conditioning_image=segmentation,
                                                 positive_prompt=st.session_state['positive_prompt'],
                                                 negative_prompt=st.session_state['negative_prompt'],
                                                 seed=random.randint(0, 100000)  # nosec
                                                 )[0]
        _store_output_image(result_image)


def _inpainting_ui(canvas_dict: dict) -> None:
    """Inpainting mode: turn the painted canvas into a mask and inpaint it."""
    image = get_image()
    canvas = st_canvas(**canvas_dict)
    if st.button("generate images", key='generate_button'):
        canvas_mask = canvas.image_data
        if not isinstance(canvas_mask, np.ndarray):
            canvas_mask = np.array(canvas_mask)
        mask = get_mask(canvas_mask)
        with st.spinner(text="Generating new images"):
            print("Making image")
            result_image = make_inpainting(positive_prompt=st.session_state['positive_prompt'],
                                           image=image,
                                           mask_image=mask,
                                           negative_prompt=st.session_state['negative_prompt'],
                                           )[0]
        _store_output_image(result_image)


def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode) -> None:
    """Render the input canvas and the controls for the selected generation mode.

    Unknown generation modes render only the heading.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )
    mode_handlers = {
        "Segmentation conditioning": _segmentation_conditioning_ui,
        "Re-generate objects": _regenerate_objects_ui,
        "Inpainting": _inpainting_ui,
    }
    handler = mode_handlers.get(generation_mode)
    if handler is not None:
        handler(canvas_dict)


def main() -> None:
    """Entry point: lay out the sidebar, prompts, input canvas and output image."""
    st.write("## Controlnet sprint - interior design", unsafe_allow_html=True)
    _input_image, generation_mode, brush, color_chooser, paint_mode = make_sidebar()

    # Nothing to edit until an image has been uploaded.
    if st.session_state.get('input_image') is None:
        print("Image not present")
        st.success("Upload an image to start")
        return

    make_prompt_row()
    _reset_state = check_reset_state()
    col1, col2 = st.columns(2)
    with col1:
        make_editing_canvas(canvas_color=color_chooser,
                            brush=brush,
                            _reset_state=_reset_state,
                            generation_mode=generation_mode,
                            paint_mode=paint_mode,
                            )
    with col2:
        make_output_image()


if __name__ == "__main__":
    main()