|
from typing import Optional |
|
|
|
import numpy as np |
|
import cv2 |
|
import streamlit as st |
|
from PIL import Image |
|
|
|
from sdfile import PIPELINES, generate |
|
|
|
# Initial value of the "Prompt" text area in every tab.
DEFAULT_PROMPT = "belted shirt black belted portrait-collar wrap blouse with black prints"

# Default image size in pixels. NOTE(review): the width/height sliders
# currently hard-code 512 instead of reading these values.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
# Backward-compatible alias for the original misspelled name; prefer DEFAULT_WIDTH.
DEAFULT_WIDTH = DEFAULT_WIDTH

# st.session_state key holding the most recently generated image
# (written after generation, shown in the sidebar).
OUTPUT_IMAGE_KEY = "output_img"

# st.session_state key read by the image uploader as a fallback input image.
LOADED_IMAGE_KEY = "loaded_img"
|
|
|
def get_image(key: str) -> Optional[Image.Image]:
    """Look up an image in Streamlit session state.

    Args:
        key: Session-state key to read.

    Returns:
        The value stored under ``key``, or ``None`` when the key is absent.
    """
    return st.session_state[key] if key in st.session_state else None
|
|
|
def set_image(key: str, img: Image.Image) -> None:
    """Store ``img`` in Streamlit session state under ``key``.

    Any value already stored under ``key`` is overwritten. Session state
    survives Streamlit script reruns, so the image stays available on later
    interactions.
    """
    st.session_state[key] = img
|
|
|
def prompt_and_generate_button(prefix, pipeline_name: PIPELINES, **kwargs):
    """Render the shared prompt/sampling controls and generate an image on demand.

    Every widget key is namespaced with ``prefix`` so several tabs can embed
    this control group side by side without Streamlit key collisions.

    Args:
        prefix: String prepended to each widget key.
        pipeline_name: Pipeline identifier passed through to ``generate``.
        **kwargs: Additional pipeline arguments forwarded unchanged to
            ``generate`` (the tabs in this file pass ``width``/``height`` or
            ``sketch_pil``/``controlnet_conditioning_scale``).

    When the "Generate Image" button is pressed, a copy of the result is
    stored under ``OUTPUT_IMAGE_KEY`` in session state and the image is
    displayed below the controls.
    """
    prompt = st.text_area("Prompt", value=DEFAULT_PROMPT, key=f"{prefix}-prompt")
    negative_prompt = st.text_area(
        "Negative prompt", value="", key=f"{prefix}-negative_prompt"
    )

    # Sampling knobs sit side by side.
    steps_column, guidance_column = st.columns(2)
    with steps_column:
        steps = st.slider(
            "Number of inference steps",
            min_value=1,
            max_value=200,
            value=30,
            key=f"{prefix}-inference-steps",
        )
    with guidance_column:
        guidance_scale = st.slider(
            "Guidance scale",
            min_value=0.0,
            max_value=20.0,
            value=7.5,
            step=0.5,
            key=f"{prefix}-guidance-scale",
        )

    enable_cpu_offload = st.checkbox(
        "Enable CPU offload if you run out of memory",
        value=False,
        key=f"{prefix}-cpu-offload",
    )

    # Streamlit reruns the script on every interaction; only generate on the
    # run in which the button was actually clicked.
    if not st.button("Generate Image", key=f"{prefix}-btn"):
        return

    with st.spinner("Generating image ..."):
        image = generate(
            prompt,
            pipeline_name,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            enable_cpu_offload=enable_cpu_offload,
            **kwargs,
        )
        # Store a copy so the sidebar's "most recent output" is independent
        # of the object rendered here.
        set_image(OUTPUT_IMAGE_KEY, image.copy())
    st.image(image)
|
def width_and_height_sliders(prefix):
    """Render side-by-side width and height sliders.

    Both sliders span 64-1600 px in steps of 16 and start at 512.

    Args:
        prefix: String prepended to each slider's widget key.

    Returns:
        A ``(width, height)`` tuple with the selected pixel sizes.
    """
    slider_specs = (
        ("Width", f"{prefix}-width"),
        ("Height", f"{prefix}-height"),
    )
    chosen = []
    for column, (label, widget_key) in zip(st.columns(2), slider_specs):
        with column:
            chosen.append(
                st.slider(
                    label,
                    min_value=64,
                    max_value=1600,
                    step=16,
                    value=512,
                    key=widget_key,
                )
            )
    width, height = chosen
    return width, height
|
|
|
def image_uploader(prefix):
    """Render a JPG/PNG file uploader and return the chosen image.

    Args:
        prefix: String prepended to the uploader's widget key.

    Returns:
        The uploaded file opened as a PIL image. If nothing is uploaded, it
        falls back to whatever is stored under ``LOADED_IMAGE_KEY`` in session
        state (``None`` when that key is absent).
    """
    upload = st.file_uploader("Image", ["jpg", "png"], key=f"{prefix}-uploader")
    if not upload:
        return get_image(LOADED_IMAGE_KEY)

    opened = Image.open(upload)
    print(f"loaded input image of size ({opened.width}, {opened.height})")
    return opened
|
|
|
def sketching():
    """Let the user upload an image and convert it into a binary line sketch.

    The image is converted to grayscale and smoothed with a 5x5 Gaussian blur.
    It is then binarized with mean adaptive thresholding (11x11 neighbourhood,
    C=2). Pixels darker than their local mean by more than C become 0; all
    others become 255.

    Returns:
        A single-channel PIL image (built from a 2-D uint8 array), or ``None``
        when no input image is available.
    """
    image = image_uploader("Controlnet")
    if image is None:
        return None

    # OpenCV works on numpy arrays, not PIL images (cv2.imread expects a file
    # path). PNG uploads may be RGBA, palette or grayscale, so normalize to
    # 3-channel RGB first. PIL channel order is RGB, hence COLOR_RGB2GRAY.
    # NOTE(review): assumes the session-state fallback also holds a PIL image.
    rgb = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

    # Blurring first suppresses pixel noise that would otherwise turn into
    # speckle after thresholding.
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    sketch = cv2.adaptiveThreshold(
        blurred,
        255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY,
        11,
        2,
    )
    return Image.fromarray(sketch)
|
|
|
def txt2img_tab():
    """Render the text-to-image tab: size sliders followed by the shared
    prompt/generate controls, targeting the "txt2img" pipeline."""
    tab_prefix = "txt2img"
    chosen_width, chosen_height = width_and_height_sliders(tab_prefix)
    prompt_and_generate_button(
        tab_prefix,
        "txt2img",
        width=chosen_width,
        height=chosen_height,
    )
|
|
|
def sketching_tab():
    """Render the sketch-to-image tab.

    Left column: upload an image and derive a sketch from it. Right column:
    shown only once a sketch exists. It holds a ControlNet
    conditioning-strength slider and the shared prompt/generate controls,
    which forward the sketch to the "sketch2img" pipeline.
    """
    prefix = "sketch2img"
    col1, col2 = st.columns(2)

    with col1:
        sketch_pil = sketching()

    with col2:
        # Only a real PIL image can condition the pipeline. A plain truthiness
        # check is not enough: a placeholder such as a ``(None, None)`` tuple
        # is truthy and would be forwarded to generate().
        if not isinstance(sketch_pil, Image.Image):
            return

        controlnet_conditioning_scale = st.slider(
            "Strength or dependence on the input sketch",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            key=f"{prefix}-controlnet_conditioning_scale",
        )
        # Kept inside the guard so controlnet_conditioning_scale is always bound.
        prompt_and_generate_button(
            prefix,
            "sketch2img",
            sketch_pil=sketch_pil,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        )
|
|
|
def main():
    """Build the playground page: a wide layout with two generation tabs and a
    sidebar that shows the most recently generated image."""
    # set_page_config must be the first Streamlit call of the script run.
    st.set_page_config(layout="wide")
    st.title("Fashion-SDX: Playground")

    text_tab, sketch_tab = st.tabs(["Text to image", "Sketch to image"])
    for tab, render_tab in ((text_tab, txt2img_tab), (sketch_tab, sketching_tab)):
        with tab:
            render_tab()

    with st.sidebar:
        st.header("Most Recent Output Image")
        latest_output = get_image(OUTPUT_IMAGE_KEY)
        if not latest_output:
            st.markdown("no output generated yet")
        else:
            st.image(latest_output)
|
# Render the app when this module is executed as the main script.
if __name__ == "__main__":
    main()