# Hugging Face Spaces status: Running on Zero
import gradio as gr | |
from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler | |
import torch | |
import numpy as np | |
from PIL import Image, ImageFilter | |
from extension import CustomStableDiffusionControlNetPipeline | |
import spaces | |
# Module-level negative prompt; empty, so nothing is explicitly discouraged.
negative_prompt: str = ""

# Device that load() moves the ControlNet and the pipeline onto.
device: torch.device = torch.device("cuda")

# Diffusion pipeline shared by the handlers; stays None until load() runs.
pipe = None
def load():
    """Create the sketch-conditioned diffusion pipeline and store it in ``pipe``.

    Loads the "BlockDetail/PartialSketchControlNet" ControlNet and the
    "runwayml/stable-diffusion-v1-5" base weights in float16, moves both to
    the module-level ``device``, disables the safety checker and replaces the
    scheduler with an Euler-ancestral one built from the existing scheduler
    configuration.
    """
    global pipe

    half_precision = torch.float16
    sketch_controlnet = ControlNetModel.from_pretrained(
        "BlockDetail/PartialSketchControlNet",
        torch_dtype=half_precision,
    ).to(device)

    pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=sketch_controlnet,
        torch_dtype=half_precision,
    ).to(device)

    # No safety checker is run on generated images.
    pipe.safety_checker = None

    # Reuse the loaded scheduler's configuration for the replacement sampler.
    scheduler_config = pipe.scheduler.config
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
# NOTE(review): `threshold` is not referenced anywhere in the visible code.
threshold: int = 250

# Number of images generated per request; updated by change_num_samples().
curr_num_samples: int = 2

# Images from the most recent generation; written by run_sketching().
all_gens: list = []

# Number of [None, None] entries placed in the initial sketch state.
num_images: int = 5
# NOTE(review): the indentation of this layout was lost in the source this was
# recovered from; the nesting below is a reconstruction -- confirm it against
# the intended UI.
with gr.Blocks() as demo:
    start_state = []

    def _result_gallery():
        """Create one results gallery; all four tabs use the same options."""
        return gr.Gallery(
            show_label=False,
            columns=[num_samples[0].value],
            rows=[2],
            object_fit="contain",
            height=512,
            preview=True,
            interactive=False,
            min_width=512,
        )

    with gr.Row():
        # Left column: instructions, sketch canvas, prompt and action buttons.
        with gr.Column():
            with gr.Row():
                gr.Textbox(
                    label="Stroke Type",
                    value="To sketch Blocking strokes, change brush color to green. To sketch Detail strokes, change brush color to black.",
                )
                # Deliberately kept as a one-element tuple: the event wiring
                # refers to this slider as dilation_strength[0].
                dilation_strength = (
                    gr.Slider(7, 117, value=65, step=2, label="Dilation Strength"),
                )
            canvas = gr.Sketchpad(
                image_mode="RGBA",
                crop_size="1:1",
                label="Sketch",
                sources=(),
                brush=gr.Brush(
                    colors=["#00FF00", "#000000"],
                    default_size=2,
                    color_mode="fixed",
                ),
            )
            prompt_box = gr.Textbox(label="Prompt")
            with gr.Row():
                btn = gr.Button("Generate")
                btn2 = gr.Button("Reset")

        # Right column: sample count and one tab per kind of output.
        with gr.Column():
            # One-element tuple as well; referenced as num_samples[0].
            num_samples = (
                gr.Slider(1, 5, value=2, step=1, label="Num Samples to Generate"),
            )
            with gr.Tab("Renoised Images"):
                gallery0 = _result_gallery()
            with gr.Tab("Renoised Overlay"):
                gallery1 = _result_gallery()
            with gr.Tab("Pre-Renoise Images"):
                gallery2 = _result_gallery()
            with gr.Tab("Pre-Renoise Overlay"):
                gallery3 = _result_gallery()

    # Keep this as a plain `for k` loop: `k` stays bound at module scope after
    # the loop ends, and run_sketching() reads it to pick a state entry.
    for k in range(num_images):
        start_state.append([None, None])

    sketch_states = gr.State(start_state)
    # NOTE(review): checkbox_state is not used by any handler in the visible code.
    checkbox_state = gr.State(True)
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
    """Generate images from a sketch, then optionally run blended renoising.

    Args:
        curr_sketch_image: Image of the strokes. It is converted to RGB and
            binarized at 128 before being passed to the pipeline as the
            conditioning image.
        dilation_mask: Mask handed to ``pipe.collage`` for the renoising pass,
            or None to skip that pass.
        prompt: Text prompt; replicated once per generated sample.
        seed: Seed for the CUDA random generator.
        num_steps: Number of inference steps for the first pass.
        dilation: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A tuple ``(images, new_images)``: the first-pass images and the
        renoised images. When ``dilation_mask`` is None both entries are the
        same list.

    Raises:
        RuntimeError: If the pipeline has not been created by ``load()`` yet.
    """
    # NOTE(review): this file imports `spaces` but no function is decorated
    # with it; if the app runs on ZeroGPU hardware this GPU entry point
    # probably needs `@spaces.GPU` -- confirm against the deployment.
    if pipe is None:
        raise RuntimeError("Pipeline is not loaded; call load() before generating images.")

    batch_size = curr_num_samples
    generator = torch.Generator(device="cuda:0")
    # int() so NumPy integer seeds are accepted as plain Python ints.
    generator.manual_seed(int(seed))

    negative_prompt = ""
    guidance_scale = 7
    controlnet_conditioning_scale = 1.0

    # 255 is the largest valid entry of an 8-bit lookup table (the value used
    # here before, 256, was out of range).
    conditioning = curr_sketch_image.convert("RGB").point(
        lambda p: 255 if p > 128 else 0
    )

    images = pipe(
        [prompt] * batch_size,
        [conditioning] * batch_size,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        negative_prompt=[negative_prompt] * batch_size,
        num_inference_steps=num_steps,
        generator=generator,
        key_image=None,
        neg_mask=None,
    ).images

    # Run blended renoising only if blocking strokes are provided.
    if dilation_mask is None:
        return images, images

    new_images = pipe.collage(
        [prompt] * batch_size,
        images,
        [dilation_mask] * batch_size,
        num_inference_steps=50,
        strength=0.8,
    )["images"]
    return images, new_images
def _overlay_sketch(images, sketch_rgba):
    """Return RGB previews: each image faded over white with the sketch on top."""
    white = Image.fromarray(np.full((512, 512), 255, dtype=np.uint8)).convert("RGBA")
    # The sketch layer is identical for every image, so build it once.
    sketch_layer = Image.fromarray(sketch_rgba).resize((512, 512)).convert("RGBA")

    previews = []
    for image in images:
        faded = image.copy()
        faded.putalpha(80)
        faded = Image.alpha_composite(white, faded)
        preview = Image.alpha_composite(faded.resize((512, 512)), sketch_layer)
        previews.append(preview.convert("RGB"))
    return previews


def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
    """Turn the canvas content into generated images and sketch overlays.

    Args:
        prompt: Text prompt forwarded to ``sketch``.
        curr_sketch: Sketchpad value; its "composite" entry is read as the
            drawn image (indexed as an RGBA array below).
        sketch_states: List of ``[None, seed]``-style entries. The seed is
            cached in the last entry and reused on later calls until
            ``reset`` clears it.
        dilation: Size of the max filter that grows the blocking strokes.
        contour_dilation: Size of the max filter that grows the detail strokes.

    Returns:
        ``(new_images, new_overlays, images, overlays)``: renoised images,
        their overlays, first-pass images and their overlays.
    """
    global all_gens

    # The seed lives in the last state entry. This used to be addressed via a
    # loop variable leaked from module scope, which pointed at that same entry.
    seed_slot = sketch_states[-1]
    if seed_slot[1] is None:
        seed_slot[1] = int(np.random.randint(1000))
    seed = seed_slot[1]

    canvas_image = Image.fromarray(curr_sketch["composite"])
    rgba = np.array(canvas_image.resize((512, 512), resample=0))
    # Fully transparent pixels become white in all three colour channels.
    rgba[rgba[:, :, -1] == 0, :3] = 255

    # Red channel: 0 for both brush colours (#00FF00 and #000000), so it
    # carries every stroke.
    curr_sketch_image = Image.fromarray(rgba[:, :, 0]).resize((512, 512))

    # uint8 arithmetic (wraps modulo 256): pure green maps to 0, while white
    # and black both map to 255, leaving only the green strokes dark.
    construction = 255 - rgba[:, :, 1] + rgba[:, :, 0]
    has_blocking_strokes = bool(np.any(construction != 255))

    # Green channel: 0 only where black strokes were drawn.
    detail_image = Image.fromarray(rgba[:, :, 1]).resize((512, 512))

    if has_blocking_strokes:
        # 255 is the largest valid 8-bit lookup value (previously 256).
        blocking = Image.fromarray(255 - construction).filter(
            ImageFilter.MaxFilter(int(dilation))
        )
        blocking = blocking.point(lambda p: 255 if p > 0 else 25).filter(
            ImageFilter.GaussianBlur(radius=5)
        )

        detail = Image.fromarray(255 - np.array(detail_image)).filter(
            ImageFilter.MaxFilter(int(contour_dilation))
        )
        detail = np.array(detail.point(lambda p: 255 if p > 0 else 0))

        # Around detail strokes the mask falls back to the low value (25).
        mask = np.array(blocking)
        mask[detail > 0] = 25
        dilation_mask = Image.fromarray(mask).filter(ImageFilter.GaussianBlur(radius=5))
    else:
        dilation_mask = None

    images, new_images = sketch(
        curr_sketch_image, dilation_mask, prompt, seed, num_steps=40, dilation=dilation
    )

    # Make the light background of the sketch transparent so only the strokes
    # are drawn over the previews.
    save_sketch = np.array(Image.fromarray(rgba).convert("RGBA"))
    save_sketch[save_sketch[:, :, 0] > 128, 3] = 0

    overlays = _overlay_sketch(images, save_sketch)
    new_overlays = _overlay_sketch(new_images, save_sketch)

    all_gens = new_images
    return new_images, new_overlays, images, overlays
def reset(sketch_states):
    """Clear every state entry in place and blank the canvas.

    Returns:
        ``(None, sketch_states)``: None for the canvas output, followed by the
        same list object with each entry replaced by a fresh ``[None, None]``.
    """
    # Slice assignment keeps the caller's list object and swaps its contents.
    sketch_states[:] = [[None, None] for _ in sketch_states]
    cleared_canvas = None
    return cleared_canvas, sketch_states
# def change_color(stroke_type): | |
# if stroke_type == "Blocking": | |
# color = "#00FF00" | |
# else: | |
# color = "#000000" | |
# return gr.Sketchpad(sources = (), width=512, brush = gr.Brush(colors=[color], default_size = 2, color_mode="fixed"), height=512) | |
def change_background(option):
    """Return a translucent copy of the generated sample named by ``option``.

    Args:
        option: "None", "Sample 0" or "Sample 1".

    Returns:
        A copy of the selected image from ``all_gens`` with its alpha set to
        80, or None when ``option`` is "None" or unrecognised, when nothing
        has been generated yet, or when the requested sample does not exist.
    """
    if option == "None" or not all_gens:
        return None

    # "Sample 1" used to return the first image as well; each label now maps
    # to its own index.
    sample_index = {"Sample 0": 0, "Sample 1": 1}.get(option)
    if sample_index is None or sample_index >= len(all_gens):
        # Unknown label, or fewer samples were generated than requested.
        return None

    image_overlay = all_gens[sample_index].copy()
    image_overlay.putalpha(80)
    return image_overlay
def change_num_samples(change):
    """Record the number of samples to generate per request.

    Args:
        change: New slider value.

    Returns:
        None; the handler has no UI output.

    The count is stored in the module-level ``curr_num_samples``, so it is a
    process-wide setting rather than a per-session one.
    """
    global curr_num_samples
    # The count is used to replicate lists (``[prompt] * curr_num_samples``),
    # which fails for floats, so store it as an int.
    curr_num_samples = int(change)
    return None
# Event listeners have to be registered inside the Blocks context, so re-enter
# `demo` before attaching them.
with demo:
    # Generate: fill the four galleries from the prompt, canvas and state.
    btn.click(
        fn=run_sketching,
        inputs=[prompt_box, canvas, sketch_states, dilation_strength[0]],
        outputs=[gallery0, gallery1, gallery2, gallery3],
    )
    # Reset: blank the canvas and clear the per-session state.
    btn2.click(fn=reset, inputs=sketch_states, outputs=[canvas, sketch_states])
    # stroke_type[0].change(change_color, [stroke_type[0]], canvas)
    num_samples[0].change(
        fn=change_num_samples,
        inputs=[num_samples[0]],
        outputs=None,
    )

# Load the models before the app starts serving requests.
load()
demo.launch(share=True, debug=True)