Spaces:
Running
on
Zero
Running
on
Zero
File size: 8,456 Bytes
907070c 2ad3800 3214d99 907070c 4f7d543 259a646 907070c 3aacad6 907070c 487acfb 907070c 3229e46 dc045f6 907070c 448fd77 907070c 13a1ef5 907070c 13a1ef5 907070c 13a1ef5 907070c 13a1ef5 907070c 7021612 259a646 907070c 487acfb 907070c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
import gradio as gr
from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler
import torch
import numpy as np
from PIL import Image, ImageFilter
from extension import CustomStableDiffusionControlNetPipeline
import spaces
# --- Module-level defaults read by the callbacks below ----------------------
negative_prompt = ""
# NOTE(review): `threshold` is not referenced anywhere in the visible code.
threshold = 250
curr_num_samples = 2   # samples per click; overwritten by change_num_samples
all_gens = []          # last list of renoised images produced by run_sketching
num_images = 5         # number of per-slot entries placed in the sketch state list

# --- Model loading: fp16 ControlNet + SD-1.5 pipeline on CUDA ----------------
device = torch.device('cuda')
controlnet = ControlNetModel.from_pretrained(
    "BlockDetail/PartialSketchControlNet",
    torch_dtype=torch.float16,
).to(device)
pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to(device)
# Drop the safety checker and swap in an Euler-ancestral scheduler built from
# the pipeline's existing scheduler configuration.
pipe.safety_checker = None
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
# Build the Gradio UI: sketch canvas and controls on the left, result galleries
# on the right, plus the per-session state objects.
# NOTE(review): indentation below was reconstructed from a whitespace-stripped
# copy of this file; the exact nesting of components is a best guess — confirm
# against the original source.
with gr.Blocks() as demo:
    # List of [None, None] entries, filled by the loop at the end of the layout.
    start_state = []
    with gr.Row():
        # Left column: description, stroke instructions, canvas, prompt, buttons.
        with gr.Column():
            gr.Textbox(label=None, value="We introduce a novel sketch-to-image tool that aligns with the iterative refinement process of artists. Our tool lets users sketch blocking strokes to coarsely represent the placement and form of objects and detail strokes to refine their shape and silhouettes.")
            with gr.Row():
                gr.Textbox(label="Stroke Type", value="To sketch Blocking strokes, change brush color to green. To sketch Detail strokes, change brush color to black."),
                # NOTE: the trailing comma makes `dilation_strength` a 1-tuple;
                # the slider component itself is `dilation_strength[0]`.
                dilation_strength = gr.Slider(7, 117, value=65, step=2, label="Dilation Strength"),
            # RGBA canvas whose brush is restricted to green (#00FF00) and black (#000000).
            canvas = gr.Sketchpad(image_mode="RGBA", crop_size="1:1", label="Canvas", sources=(), brush = gr.Brush(colors=["#00FF00", "#000000"], default_size = 2, color_mode="fixed"))
            prompt_box = gr.Textbox(label="Prompt")
            with gr.Row():
                btn = gr.Button("Generate")
                btn2 = gr.Button("Reset")
        # Right column: sample-count slider and four result galleries.
        with gr.Column():
            # NOTE: trailing comma again -> 1-tuple; the slider is `num_samples[0]`.
            num_samples = gr.Slider(1, 5, value=2, step=1, label="Num Samples to Generate"),
            # Gallery column counts are taken once from the slider's initial
            # value (2); nothing in this file updates them when the slider moves.
            with gr.Tab("Renoised Images"):
                gallery0 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
            with gr.Tab("Renoised Overlay"):
                gallery1 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
            with gr.Tab("Pre-Renoise Images"):
                gallery2 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
            with gr.Tab("Pre-Renoise Overlay"):
                gallery3 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height=512, preview=True, interactive=False, min_width=512)
    # One [None, None] slot per index; index 1 of a slot is used as a seed
    # elsewhere in this file. NOTE: `k` stays bound at module level after this
    # loop (value num_images - 1).
    for k in range(num_images):
        start_state.append([None, None])
    sketch_states = gr.State(start_state)
    # NOTE(review): `checkbox_state` is not referenced anywhere in the visible code.
    checkbox_state = gr.State(True)
@spaces.GPU
def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation,
           renoise_steps=50, renoise_strength=0.8):
    """Generate samples from a sketch, then optionally renoise them with a mask.

    Args:
        curr_sketch_image: PIL image of the sketch; each RGB band is binarized
            at 128 (above -> 255, else 0) before being used as the control image.
        dilation_mask: PIL mask passed to ``pipe.collage`` for blended
            renoising, or None to skip the renoising pass.
        prompt: text prompt shared by every sample.
        seed: seed for the CUDA ``torch.Generator``.
        num_steps: inference steps for the first ControlNet pass.
        dilation: unused here; kept so existing ``dilation=...`` calls still work.
        renoise_steps: ``num_inference_steps`` for the ``pipe.collage`` pass.
        renoise_strength: ``strength`` for the ``pipe.collage`` pass.

    Returns:
        ``(images, new_images)``: first-pass images and renoised images.
        ``new_images`` is the same list as ``images`` when ``dilation_mask`` is None.
    """
    # Snapshot the module-wide sample count once so every list below has the
    # same length even if the slider callback changes it mid-call.
    n_samples = curr_num_samples
    generator = torch.Generator(device="cuda:0")
    generator.manual_seed(seed)
    negative_prompt = ""
    guidance_scale = 7
    controlnet_conditioning_scale = 1.0
    control_image = curr_sketch_image.convert("RGB").point(lambda p: 255 if p > 128 else 0)
    images = pipe(
        [prompt] * n_samples,
        [control_image] * n_samples,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        negative_prompt=[negative_prompt] * n_samples,
        num_inference_steps=num_steps,
        generator=generator,
        key_image=None,
        neg_mask=None,
    ).images
    # Blended renoising only runs when blocking strokes produced a mask.
    if dilation_mask is None:
        return images, images
    new_images = pipe.collage(
        [prompt] * n_samples,
        images,
        [dilation_mask] * n_samples,
        num_inference_steps=renoise_steps,
        strength=renoise_strength,
    )["images"]
    return images, new_images
def _build_dilation_mask(construction, detail, dilation, contour_dilation):
    """Build the blurred blending mask that limits renoising to blocking strokes.

    Args:
        construction: 2-D uint8 array, 0 on blocking (green) stroke pixels and
            255 elsewhere for the fixed green/black brush palette.
        detail: 2-D uint8 array (green channel of the sketch), 0 on black
            detail-stroke pixels for that palette.
        dilation: MaxFilter size applied to the blocking strokes (must be odd).
        contour_dilation: MaxFilter size applied to the detail strokes (must be odd).

    Returns:
        Single-channel PIL image: 255 around the dilated blocking strokes and 25
        elsewhere or on dilated detail strokes, Gaussian-blurred (radius 5).
    """
    blocking = Image.fromarray(255 - construction).filter(ImageFilter.MaxFilter(dilation))
    mask = blocking.point(lambda p: 255 if p > 0 else 25).filter(ImageFilter.GaussianBlur(radius=5))
    contours = Image.fromarray(255 - detail).filter(ImageFilter.MaxFilter(contour_dilation))
    contours = np.array(contours.point(lambda p: 255 if p > 0 else 0))
    mask = np.array(mask)
    # Keep detail strokes (plus a margin) out of the renoised region.
    mask[contours > 0] = 25
    return Image.fromarray(mask).filter(ImageFilter.GaussianBlur(radius=5))


def _overlay_sketch(images, save_sketch):
    """Fade each image (alpha 80) over white and draw the sketch strokes on top.

    Args:
        images: iterable of 512x512 PIL images.
        save_sketch: RGBA uint8 array of the sketch with background made transparent.

    Returns:
        List of RGB PIL overlay images, one per input image.
    """
    # Both layers are loop-invariant, so build them once.
    white = Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA")
    strokes = Image.fromarray(save_sketch).resize((512, 512)).convert("RGBA")
    overlays = []
    for image in images:
        faded = image.copy()  # putalpha mutates in place; keep the input intact
        faded.putalpha(80)
        background = Image.alpha_composite(white, faded)
        overlay = Image.alpha_composite(background.resize((512, 512)), strokes)
        overlays.append(overlay.convert("RGB"))
    return overlays


def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
    """Turn the canvas into generations and overlays for the four galleries.

    Args:
        prompt: text prompt.
        curr_sketch: Sketchpad value; its ``"composite"`` entry is the RGBA drawing.
        sketch_states: list of ``[_, seed]`` slots; the last slot's seed is
            drawn once and reused until ``reset`` clears it.
        dilation: MaxFilter size for blocking strokes (odd slider values).
        contour_dilation: MaxFilter size for the detail-stroke exclusion margin.

    Returns:
        ``(new_images, new_overlays, images, overlays)`` — renoised images,
        their overlays, pre-renoise images, their overlays.

    Raises:
        gr.Error: when no sketch has been drawn.

    Side effects:
        Stores the seed in ``sketch_states[-1][1]`` and sets the global ``all_gens``.
    """
    global all_gens
    if curr_sketch is None or curr_sketch.get("composite") is None:
        raise gr.Error("Please draw a sketch before generating.")

    # The last slot was previously selected through the module-level loop
    # variable `k`, which ends at num_images - 1; index it explicitly instead.
    slot = sketch_states[-1]
    if slot[1] is None:
        slot[1] = np.random.randint(1000)
    seed = slot[1]

    # resample=0 is nearest-neighbour, so stroke colours are not blended.
    sketch_rgba = np.array(Image.fromarray(curr_sketch["composite"]).resize((512, 512), resample=0))
    # Untouched (fully transparent) canvas pixels become white background.
    sketch_rgba[:, :, :3][sketch_rgba[:, :, -1] == 0] = 255
    red = sketch_rgba[:, :, 0]
    green = sketch_rgba[:, :, 1]

    # Both brush colours have R == 0, so the red channel shows every stroke.
    curr_sketch_image = Image.fromarray(red).resize((512, 512))
    # 0 on green (blocking) strokes, 255 on black strokes and background.
    construction = 255 - green + red
    curr_detail_image = Image.fromarray(green).resize((512, 512))
    if np.all(construction == 255):
        dilation_mask = None  # no blocking strokes -> renoising is skipped
    else:
        dilation_mask = _build_dilation_mask(
            construction, np.array(curr_detail_image), dilation, contour_dilation
        )

    images, new_images = sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps=40, dilation=dilation)

    # Keep only the strokes: pixels with red > 128 (background) become transparent.
    save_sketch = np.array(Image.fromarray(sketch_rgba).convert("RGBA"))
    save_sketch[:, :, 3][save_sketch[:, :, 0] > 128] = 0
    overlays = _overlay_sketch(images, save_sketch)
    new_overlays = _overlay_sketch(new_images, save_sketch)

    all_gens = new_images
    return new_images, new_overlays, images, overlays
def reset(sketch_states):
    """Clear every state slot in place and blank the canvas.

    Args:
        sketch_states: list of slots; each is replaced by a fresh ``[None, None]``.

    Returns:
        ``(None, sketch_states)`` — None for the canvas output and the same
        (now cleared) list object for the state output.
    """
    sketch_states[:] = [[None, None] for _ in sketch_states]
    return None, sketch_states
# def change_color(stroke_type):
# if stroke_type == "Blocking":
# color = "#00FF00"
# else:
# color = "#000000"
# return gr.Sketchpad(sources = (), width=512, brush = gr.Brush(colors=[color], default_size = 2, color_mode="fixed"), height=512)
def change_background(option):
    """Return a faded copy of a previously generated image to use as a background.

    Args:
        option: ``"None"`` or ``"Sample <i>"``, selecting ``all_gens[i]``.

    Returns:
        A copy of the selected image with alpha set to 80. Returns None when
        ``option`` is ``"None"`` or unrecognised, when nothing has been
        generated yet, or when the sample index is out of range.
    """
    gens = all_gens
    if not isinstance(option, str) or option == "None" or len(gens) == 0:
        return None
    prefix = "Sample "
    index_text = option[len(prefix):]
    # Previously "Sample 1" returned all_gens[0]; parse the index instead.
    if not option.startswith(prefix) or not index_text.isdecimal():
        return None
    index = int(index_text)
    if index >= len(gens):
        return None
    # putalpha mutates in place, so work on a copy and leave all_gens intact.
    image_overlay = gens[index].copy()
    image_overlay.putalpha(80)
    return image_overlay
def change_num_samples(change):
    """Set how many samples each Generate click produces.

    Args:
        change: new slider value. It is coerced to an int of at least 1 because
            it is used as a list-repetition count (``[prompt] * curr_num_samples``),
            which fails for floats.

    Returns:
        None (the listener has no outputs).
    """
    global curr_num_samples
    # NOTE: this is a module-level global, so it is shared by every caller in
    # this process rather than stored per user session.
    curr_num_samples = max(1, int(round(change)))
    return None
# Wire the UI events, then start the server.
# NOTE(review): indentation was lost in this copy; presumably the listener
# registrations run inside the `with gr.Blocks() as demo:` context and only
# `demo.launch` runs after it — confirm against the original file.
btn.click(
    fn=run_sketching,
    inputs=[prompt_box, canvas, sketch_states, dilation_strength[0]],
    outputs=[gallery0, gallery1, gallery2, gallery3],
)
btn2.click(fn=reset, inputs=sketch_states, outputs=[canvas, sketch_states])
# stroke_type[0].change(change_color, [stroke_type[0]], canvas)
num_samples[0].change(fn=change_num_samples, inputs=[num_samples[0]], outputs=None)
demo.launch(debug=True, share=True)