File size: 3,724 Bytes
09a5e50
 
 
 
 
 
 
 
 
e7953d7
09a5e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7953d7
 
 
09a5e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7953d7
09a5e50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3c39785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7953d7
3c39785
 
 
09a5e50
3c39785
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import time
import io
import cv2
import numpy as np

from camera_input_live import camera_input_live
from inference import inpainting
# Streamlit page setup: request the "wide" layout. This runs at module level,
# ahead of every other Streamlit call made in this file.
st.set_page_config(layout="wide")


def make_canvas(image):
    """Render the drawable canvas over ``image`` and return its result.

    Args:
        image: picture handed to ``st_canvas`` as ``background_image``.
    Returns:
        The value returned by ``st_canvas`` for a 512x512 free-draw canvas
        registered under the widget key "canvas".
    """
    # Keyword arguments are passed straight through to the canvas component.
    return st_canvas(
        fill_color='#F00000',
        stroke_color='#000000',
        background_color="#FFFFFF",
        background_image=image,
        stroke_width=40,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode='freedraw',
        key="canvas",
    )


def get_mask(image_mask: np.ndarray) -> "Image.Image":
    """Turn a drawn layer into a black/white three-channel mask image.

    Args:
        image_mask (np.ndarray): 3-D array indexed as (row, column, channel).
            Presumably the ``image_data`` of the drawing canvas -- TODO
            confirm against the caller.
    Returns:
        PIL.Image.Image: RGB image with the same height and width as the
        input, white (255) wherever the per-pixel channel average is
        positive and black (0) everywhere else.
    """
    # A pixel counts as "drawn" when the mean over its channels is above zero.
    drawn = np.mean(image_mask, axis=2) > 0
    # Scale the boolean mask to 0/255 and replicate it over three channels.
    levels = drawn.astype(np.uint8) * 255
    rgb = np.repeat(levels[:, :, np.newaxis], 3, axis=2)
    return Image.fromarray(rgb).convert("RGB")


def make_prompt_fields():
    """Draw the prompting section and return the two text inputs.

    Returns:
        tuple: ``(prompt, negative_prompt)`` -- the current contents of the
        "Prompt" and "Negative Prompt" text boxes.
    """
    st.write("### Prompting")
    # Positive prompt first, then the negative one, so the widgets keep their order.
    positive_text = st.text_input(
        "Prompt",
        value="A person in a room with colored hair",
        key="prompt",
    )
    negative_text = st.text_input(
        "Negative Prompt",
        value="Facial hair",
        key="negative_prompt",
    )
    return positive_text, negative_text

def make_input_fields():
    """Draw the parameter and latent-walk widgets and return their values.

    Returns:
        tuple: ``(guidance_scale, inference_steps, generator_seed,
        static_latents, latent_walk)`` in that order, as read from the
        three sliders, the checkbox and the final slider.
    """
    st.write("### Parameters")
    scale = st.slider(
        "Guidance Scale",
        min_value=0.0, max_value=50.0, value=7.5, step=0.25,
        key="guidance_scale",
    )
    steps = st.slider(
        "Inference Steps",
        min_value=1, max_value=50, value=20, step=1,
        key="inference_steps",
    )
    seed = st.slider(
        "Generator Seed",
        min_value=0, max_value=10_000, value=0, step=1,
        key="generator_seed",
    )

    st.write("### Latent walk")
    keep_latents = st.checkbox("Static Latents", value=False, key="static_latents")
    walk = st.slider(
        "Latent Walk",
        min_value=0.0, max_value=1.0, value=0.0, step=0.01,
        key="latent_walk",
    )

    return scale, steps, seed, keep_latents, walk

def decode_image(image):
    """Decode encoded image bytes into a 512x512 RGB PIL image.

    Args:
        image: bytes-like object holding an encoded image; it is read through
            ``np.frombuffer`` and decoded with ``cv2.imdecode``.
    Returns:
        PIL.Image.Image: the decoded picture in RGB mode, resized to 512x512.
    Raises:
        ValueError: if the buffer is empty or the data cannot be decoded.
    """
    buffer = np.frombuffer(image, np.uint8)
    if buffer.size == 0:
        raise ValueError("Cannot decode an empty image buffer")
    # cv2.imdecode reports failure by returning None rather than raising,
    # so check explicitly instead of letting cvtColor fail obscurely.
    bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError("Could not decode image data")
    # The decoded array is converted from BGR to RGB channel order for PIL.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb).convert("RGB").resize((512, 512))

if __name__ == "__main__":
    # Page flow: sidebar inputs -> webcam frame with drawable mask (left
    # column) -> optional inpainting run -> result display (right column).

    st.sidebar.title("Sidebar")

    with st.sidebar:
        webcam = camera_input_live(debounce=1000, key="webcam", width=512, height=512)
        prompt, negative_prompt = make_prompt_fields()

        # NOTE(review): these five values are collected but never forwarded
        # to inpainting() below -- confirm whether its signature accepts them.
        guidance_scale, inference_steps, generator_seed, static_latents, latent_walk = make_input_fields()

    colA, colB = st.columns(2)

    # Nothing is rendered in the columns until the webcam delivers a frame.
    if webcam:
        with colA:
            st.write("## Webcam image")
            st.write("You can draw the mask on the image below.")
            image = decode_image(webcam.getvalue())

            canvas = make_canvas(image)

        if st.button("Inpaint"):
            if canvas.image_data is None:
                # The canvas component may not have reported any pixel data
                # yet; building a mask from None would crash in get_mask().
                st.warning("The canvas has not returned any drawing yet; draw a mask and try again.")
            else:
                st.write("Start inpainting process")
                mask_image = get_mask(np.array(canvas.image_data))
                # Keep the result in session state so it survives the reruns
                # triggered by later widget interactions.
                st.session_state["result"] = inpainting(image, mask_image, prompt, negative_prompt)

        with colB:
            st.write("## Generated image")
            st.write("The generated image will appear here.")
            # NOTE(review): this shows the raw webcam frame under the
            # "Generated image" heading -- confirm that is intended.
            st.image(webcam)
            if 'result' in st.session_state:
                print("Showing result")
                st.image(st.session_state["result"])