Spaces:
Runtime error
Runtime error
import streamlit as st | |
from streamlit_drawable_canvas import st_canvas | |
from PIL import Image | |
import time | |
import io | |
import cv2 | |
import numpy as np | |
from camera_input_live import camera_input_live | |
from inference import inpainting | |
# Use the wide page layout so the two side-by-side columns created below
# (webcam canvas and generated image) can use the full browser width.
# Streamlit expects set_page_config to run before other st.* calls
# (confirm for the installed Streamlit version), hence module level.
st.set_page_config(layout="wide")
def make_canvas(image):
    """Render a 512x512 free-draw canvas on top of ``image``.

    The user paints the inpainting mask directly over the webcam frame with
    a 40 px black brush. The widget uses the fixed Streamlit key "canvas",
    so only one such canvas may exist per page.

    Args:
        image: background image shown underneath the drawing layer.

    Returns:
        The result object produced by ``st_canvas``.
    """
    return st_canvas(
        background_image=image,
        background_color="#FFFFFF",
        drawing_mode="freedraw",
        stroke_color="#000000",
        stroke_width=40,
        fill_color="#F00000",
        width=512,
        height=512,
        update_streamlit=True,
        key="canvas",
    )
def get_mask(image_mask: np.ndarray) -> Image.Image:
    """Convert a drawn canvas image into a black/white RGB inpainting mask.

    A pixel is part of the mask when the mean of its channel values is
    positive. For non-negative data such as uint8 this is the same as "any
    channel is non-zero". All channels, including alpha, take part in the
    average. That matters because strokes painted with a black stroke colour
    have zero RGB values and are only detectable through their alpha.

    Args:
        image_mask: (H, W, C) array, e.g. the canvas ``image_data``
            (presumably RGBA; confirm against the canvas component). A 2-D
            (H, W) array is accepted and treated as single-channel.

    Returns:
        PIL.Image.Image: RGB image of size (W, H) that is 255 on masked
        pixels and 0 everywhere else. An untouched canvas yields an
        all-black image.

    Raises:
        ValueError: if ``image_mask`` is neither 2-D nor 3-D.
    """
    image_mask = np.asarray(image_mask)
    if image_mask.ndim == 2:
        # A single channel: its "mean over channels" is the value itself.
        average_color = image_mask
    elif image_mask.ndim == 3:
        # Average the colour channels of every pixel.
        average_color = np.mean(image_mask, axis=2)
    else:
        raise ValueError(
            f"image_mask must be a 2-D or 3-D array, got shape {image_mask.shape}"
        )

    mask = np.where(average_color > 0, 255, 0).astype(np.uint8)
    # Replicate to 3 channels; fromarray infers mode "RGB" for (H, W, 3) uint8.
    return Image.fromarray(np.stack([mask, mask, mask], axis=2))
def make_prompt_fields():
    """Render the "Prompting" section with positive and negative prompt inputs.

    Returns:
        tuple[str, str]: the current (prompt, negative_prompt) text values.
    """
    st.write("### Prompting")
    # (label, default text, widget key), rendered top to bottom.
    specs = (
        ("Prompt", "A person in a room with colored hair", "prompt"),
        ("Negative Prompt", "Facial hair", "negative_prompt"),
    )
    positive, negative = [
        st.text_input(label, value=default, key=widget_key)
        for label, default, widget_key in specs
    ]
    return positive, negative
def make_input_fields():
    """Render the sampler "Parameters" and "Latent walk" controls.

    Returns:
        tuple: (guidance_scale, inference_steps, generator_seed,
        static_latents, latent_walk) with the current widget values.
    """
    st.write("### Parameters")
    guidance = st.slider(
        "Guidance Scale",
        min_value=0.0, max_value=50.0, value=7.5, step=0.25,
        key="guidance_scale",
    )
    steps = st.slider(
        "Inference Steps",
        min_value=1, max_value=50, value=20, step=1,
        key="inference_steps",
    )
    seed = st.slider(
        "Generator Seed",
        min_value=0, max_value=10_000, value=0, step=1,
        key="generator_seed",
    )

    st.write("### Latent walk")
    keep_latents_fixed = st.checkbox("Static Latents", value=False, key="static_latents")
    walk_amount = st.slider(
        "Latent Walk",
        min_value=0.0, max_value=1.0, value=0.0, step=0.01,
        key="latent_walk",
    )
    return guidance, steps, seed, keep_latents_fixed, walk_amount
def decode_image(image, size=(512, 512)):
    """Decode encoded image bytes (e.g. a webcam JPEG/PNG frame) to an RGB PIL image.

    Args:
        image: encoded image data; any bytes-like object accepted by
            ``np.frombuffer``.
        size: (width, height) of the returned image. Defaults to 512x512,
            which matches the drawing canvas size.

    Returns:
        PIL.Image.Image: the decoded RGB image resized to ``size``.

    Raises:
        ValueError: if ``image`` is empty or cannot be decoded as an image.
    """
    buffer = np.frombuffer(image, np.uint8)
    if buffer.size == 0:
        raise ValueError("Cannot decode an empty image buffer")
    bgr = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if bgr is None:
        # cv2.imdecode reports undecodable data by returning None, not raising.
        raise ValueError("Could not decode image bytes")
    # OpenCV decodes to BGR channel order; PIL expects RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # fromarray of an (H, W, 3) uint8 array is already mode "RGB".
    return Image.fromarray(rgb).resize(tuple(size))
if __name__ == "__main__":
    st.sidebar.title("Sidebar")
    with st.sidebar:
        # Live webcam frame. debounce=1000 presumably throttles refreshes to
        # roughly once per second (TODO confirm in camera_input_live docs).
        webcam = camera_input_live(debounce=1000, key="webcam", width=512, height=512)
        prompt, negative_prompt = make_prompt_fields()
        # NOTE(review): these five controls are rendered but never passed to
        # inpainting() below, so changing them has no effect on the result.
        # Wire them through once inpainting()'s signature is confirmed.
        guidance_scale, inference_steps, generator_seed, static_latents, latent_walk = make_input_fields()

    colA, colB = st.columns(2)
    with colA:
        st.write("## Webcam image")
        st.write("You can draw the mask on the image below.")
        if webcam is None:
            # No frame has been captured yet. Calling .getvalue() on None would
            # abort the whole script, so show a placeholder instead.
            st.info("Waiting for the first webcam frame...")
        else:
            image = decode_image(webcam.getvalue())
            canvas = make_canvas(image)
            if st.button("Inpaint"):
                if canvas.image_data is None:
                    # The canvas component has not reported its drawing yet.
                    st.warning("The drawing canvas is not ready yet; please try again.")
                else:
                    st.write("Start inpainting process")
                    mask_image = get_mask(np.array(canvas.image_data))
                    # Keep the result in session state so it survives the reruns
                    # triggered by widget interactions and webcam refreshes.
                    st.session_state["result"] = inpainting(image, mask_image, prompt, negative_prompt)

    with colB:
        st.write("## Generated image")
        st.write("The generated image will appear here.")
        if webcam:
            # NOTE(review): this shows the raw webcam frame under the
            # "Generated image" heading; possibly a debugging leftover.
            st.image(webcam)
        if "result" in st.session_state:
            st.image(st.session_state["result"])