BlockDetail commited on
Commit
907070c
1 Parent(s): c971f08

Add application file

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from diffusers import ControlNetModel, EulerAncestralDiscreteScheduler
3
+ import torch
4
+ import numpy as np
5
+ import cv2
6
+ from PIL import Image, ImageFilter
7
+ from interface.extension import CustomStableDiffusionControlNetPipeline
8
+
9
+ negative_prompt = ""
10
+ device = torch.device('cuda')
11
+ controlnet = ControlNetModel.from_pretrained("partialsketchcontrolnet", torch_dtype=torch.float16).to(device)
12
+ pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
13
+ "runwayml/stable-diffusion-v1-5",
14
+ controlnet=controlnet, torch_dtype=torch.float16
15
+ ).to(device)
16
+ pipe.safety_checker = None
17
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
18
+ threshold = 250
19
+ curr_num_samples = 2
20
+
21
+ all_gens = []
22
+
23
+ num_images = 5
24
+
25
+ with gr.Blocks() as demo:
26
+ start_state = []
27
+ with gr.Row():
28
+ with gr.Column():
29
+ with gr.Row():
30
+ stroke_type = gr.Radio(["Blocking", "Detail"], value="Detail", label="Stroke Type"),
31
+ dilation_strength = gr.Slider(7, 117, value=65, step=2, label="Dilation Strength"),
32
+ canvas = gr.Image(source="canvas", shape=(512, 512), tool="color-sketch",
33
+ min_width=512, brush_radius = 2).style(width=512, height=512)
34
+ prompt_box = gr.Textbox(width="50vw", label="Prompt")
35
+ with gr.Row():
36
+ btn = gr.Button("Generate").style(width=100, height=80)
37
+ btn2 = gr.Button("Reset").style(width=100, height=80)
38
+ with gr.Column():
39
+ num_samples = gr.Slider(1, 5, value=2, step=1, label="Num Samples to Generate"),
40
+ with gr.Tab("Renoised Images"):
41
+ gallery0 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
42
+ with gr.Tab("Renoised Overlay"):
43
+ gallery1 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
44
+ with gr.Tab("Pre-Renoise Images"):
45
+ gallery2 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
46
+ with gr.Tab("Pre-Renoise Overlay"):
47
+ gallery3 = gr.Gallery(show_label=False, columns=[num_samples[0].value], rows=[2], object_fit="contain", height="auto", preview=True, interactive=False).style(width=512, height=512)
48
+ for k in range(num_images):
49
+ start_state.append([None, None])
50
+ sketch_states = gr.State(start_state)
51
+ checkbox_state = gr.State(True)
52
+
53
+ def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
54
+ global curr_num_samples
55
+ generator = torch.Generator(device="cuda:0")
56
+ generator.manual_seed(seed)
57
+
58
+ negative_prompt = ""
59
+ guidance_scale = 7
60
+ controlnet_conditioning_scale = 1.0
61
+ images = pipe([prompt]*curr_num_samples, [curr_sketch_image.convert("RGB").point( lambda p: 256 if p > 128 else 0)]*curr_num_samples, guidance_scale=guidance_scale, controlnet_conditioning_scale = controlnet_conditioning_scale, negative_prompt = [negative_prompt] * curr_num_samples, num_inference_steps=num_steps, generator=generator, key_image=None, neg_mask=None).images
62
+
63
+ # run blended renoising if blocking strokes are provided
64
+ if dilation_mask is not None:
65
+ new_images = pipe.collage([prompt] * curr_num_samples, images, [dilation_mask] * curr_num_samples, num_inference_steps=50, strength=0.8)["images"]
66
+ else:
67
+ new_images = images
68
+ return images, new_images
69
+
70
+ def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
71
+ seed = sketch_states[k][1]
72
+ if seed is None:
73
+ seed = np.random.randint(1000)
74
+ sketch_states[k][1] = seed
75
+
76
+ curr_sketch_image = Image.fromarray(curr_sketch[:, :, 0]).resize((512, 512))
77
+
78
+ curr_construction_image = Image.fromarray(255 - curr_sketch[:, :, 2] + curr_sketch[:, :, 0])
79
+ if np.sum(255 - np.array(curr_construction_image)) == 0:
80
+ curr_construction_image = None
81
+
82
+ curr_detail_image = Image.fromarray(curr_sketch[:, :, 2]).resize((512, 512))
83
+
84
+ if curr_construction_image is not None:
85
+ dilation_mask = Image.fromarray(255 - np.array(curr_construction_image)).filter(ImageFilter.MaxFilter(dilation))
86
+ dilation_mask = dilation_mask.point( lambda p: 256 if p > 0 else 25).filter(ImageFilter.GaussianBlur(radius = 5))
87
+
88
+ neg_dilation_mask = Image.fromarray(255 - np.array(curr_detail_image)).filter(ImageFilter.MaxFilter(contour_dilation))
89
+ neg_dilation_mask = np.array(neg_dilation_mask.point( lambda p: 256 if p > 0 else 0))
90
+ dilation_mask = np.array(dilation_mask)
91
+ dilation_mask[neg_dilation_mask > 0] = 25
92
+ dilation_mask = Image.fromarray(dilation_mask).filter(ImageFilter.GaussianBlur(radius = 5))
93
+ else:
94
+ dilation_mask = None
95
+
96
+ images, new_images = sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps = 40, dilation = dilation)
97
+
98
+ save_sketch = np.array(Image.fromarray(curr_sketch).convert("RGBA"))
99
+ save_sketch[:, :, 3][save_sketch[:, :, 0] > 128] = 0
100
+
101
+ overlays = []
102
+ for i in images:
103
+ background = i.copy()
104
+ background.putalpha(80)
105
+ background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
106
+ overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).convert("RGBA"))
107
+ overlays.append(overlay.convert("RGB"))
108
+
109
+ new_overlays = []
110
+ for i in new_images:
111
+ background = i.copy()
112
+ background.putalpha(80)
113
+ background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
114
+ overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).convert("RGBA"))
115
+ new_overlays.append(overlay.convert("RGB"))
116
+
117
+ global all_gens
118
+ all_gens = new_images
119
+
120
+ return new_images, new_overlays, images, overlays
121
+
122
+ def reset(sketch_states):
123
+ for k in range(len(sketch_states)):
124
+ sketch_states[k] = [None, None]
125
+ return None, sketch_states
126
+
127
+ def change_color(stroke_type):
128
+ if stroke_type == "Blocking":
129
+ color = "#0000FF"
130
+ else:
131
+ color = "#000000"
132
+ return gr.Image(source="canvas", shape=(512, 512), tool="color-sketch",
133
+ min_width=512, brush_radius = 2, brush_color=color).style(width=400, height=400)
134
+
135
+ def change_background(option):
136
+ global all_gens
137
+ if option == "None" or len(all_gens) == 0:
138
+ return None
139
+ elif option == "Sample 0":
140
+ image_overlay = all_gens[0].copy()
141
+ elif option == "Sample 1":
142
+ image_overlay = all_gens[0].copy()
143
+ else:
144
+ return None
145
+ image_overlay.putalpha(80)
146
+ return image_overlay
147
+
148
+ def change_num_samples(change):
149
+ global curr_num_samples
150
+ curr_num_samples = change
151
+ return None
152
+
153
+ btn.click(run_sketching, [prompt_box, canvas, sketch_states, dilation_strength[0]], [gallery0, gallery1, gallery2, gallery3])
154
+ btn2.click(reset, sketch_states, [canvas, sketch_states])
155
+ stroke_type[0].change(change_color, [stroke_type[0]], canvas)
156
+ num_samples[0].change(change_num_samples, [num_samples[0]], None)
157
+
158
+
159
+ demo.launch(share = True, debug = True)