Baptlem committed on
Commit
cd1e8dc
1 Parent(s): 3f2581f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -11
app.py CHANGED
@@ -1,8 +1,145 @@
1
- # This file is adapted from https://github.com/lllyasviel/ControlNet/blob/f4748e3630d8141d7765e2bd9b1e348f47847707/gradio_canny2image.py
2
  # The original license file is LICENSE.ControlNet in this repo.
 
 
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
 
5
- def create_demo(process, max_images=12, default_num_images=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  with gr.Blocks() as demo:
7
  with gr.Row():
8
  gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
@@ -12,13 +149,14 @@ def create_demo(process, max_images=12, default_num_images=3):
12
  prompt = gr.Textbox(label='Prompt')
13
  run_button = gr.Button(label='Run')
14
  with gr.Accordion('Advanced options', open=False):
15
- is_segmentation_map = gr.Checkbox(
16
- label='Is segmentation map', value=False)
17
  num_samples = gr.Slider(label='Images',
18
  minimum=1,
19
  maximum=max_images,
20
  value=default_num_images,
21
  step=1)
 
22
  canny_low_threshold = gr.Slider(
23
  label='Canny low threshold',
24
  minimum=1,
@@ -31,6 +169,12 @@ def create_demo(process, max_images=12, default_num_images=3):
31
  maximum=255,
32
  value=200,
33
  step=1)
 
 
 
 
 
 
34
  num_steps = gr.Slider(label='Steps',
35
  minimum=1,
36
  maximum=100,
@@ -46,9 +190,6 @@ def create_demo(process, max_images=12, default_num_images=3):
46
  maximum=2147483647,
47
  step=1,
48
  randomize=True)
49
- a_prompt = gr.Textbox(
50
- label='Added Prompt',
51
- value='best quality, extremely detailed')
52
  n_prompt = gr.Textbox(
53
  label='Negative Prompt',
54
  value=
@@ -62,14 +203,15 @@ def create_demo(process, max_images=12, default_num_images=3):
62
  inputs = [
63
  input_image,
64
  prompt,
65
- a_prompt,
66
- n_prompt,
67
  num_samples,
 
 
 
68
  num_steps,
69
  guidance_scale,
70
  seed,
71
- canny_low_threshold,
72
- canny_high_threshold,
73
  ]
74
  prompt.submit(fn=process, inputs=inputs, outputs=result)
75
  run_button.click(fn=process,
 
1
+ # This file is adapted from https://huggingface.co/spaces/diffusers/controlnet-canny/blob/main/app.py
2
  # The original license file is LICENSE.ControlNet in this repo.
3
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
4
+ from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
5
+ from flax.training.common_utils import shard
6
+ from flax.jax_utils import replicate
7
+ from diffusers.utils import load_image
8
+ import jax.numpy as jnp
9
+ import jax
10
+ import cv2
11
+ from PIL import Image
12
+ import numpy as np
13
  import gradio as gr
14
 
15
def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed``."""
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
17
+
18
def load_controlnet(controlnet_version):
    """Load one ControlNet variant from the ``Baptlem/baptlem-controlnet`` repo.

    Args:
        controlnet_version: Subfolder of the repository holding the weights
            to load.

    Returns:
        The ``(model, params)`` pair returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    load_options = {
        "subfolder": controlnet_version,
        "from_flax": True,
        "dtype": jnp.float32,
    }
    model, model_params = FlaxControlNetModel.from_pretrained(
        "Baptlem/baptlem-controlnet",
        **load_options,
    )
    return model, model_params
26
+
27
+
28
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Flax Stable Diffusion + ControlNet pipeline and its parameters.

    Args:
        controlnet_version: ControlNet variant, forwarded to
            ``load_controlnet``.
        sb_path: Hub id or local path of the base Stable Diffusion model. The
            scheduler configuration is read from its ``scheduler`` subfolder.

    Returns:
        ``(pipe, params)`` where ``pipe`` uses the DPM-Solver multistep
        scheduler and ``params`` additionally carries the ``"controlnet"`` and
        ``"scheduler"`` entries.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    # The scheduler is loaded from the same base model as the pipeline. The
    # previous code referenced an undefined name here and raised NameError.
    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        dtype=jnp.float32,
        from_pt=True,
    )

    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params
47
+
48
+
49
+
50
controlnet_path = "Baptlem/baptlem-controlnet"
controlnet_version = "coyo-500k"

# Hysteresis thresholds passed to cv2.Canny in preprocess_canny.
low_threshold = 100
high_threshold = 200

pipe, params = load_sb_pipe(controlnet_version)

# These memory helpers belong to the PyTorch pipelines and may be missing on
# the Flax pipeline (NOTE(review): confirm against the installed diffusers
# version). Calling them unconditionally would abort the app at import time,
# so each one is invoked only when the pipeline actually provides it.
for _helper_name in (
    "enable_xformers_memory_efficient_attention",
    "enable_model_cpu_offload",
    "enable_attention_slicing",
):
    _helper = getattr(pipe, _helper_name, None)
    if callable(_helper):
        _helper()
62
+
63
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=0,
    negative_prompt="",
):
    """Generate images conditioned on a Canny edge map.

    Args:
        image: Input picture as a NumPy array or anything ``np.array`` accepts.
        prompt: Text prompt, repeated once per sample.
        is_canny: When true, ``image`` is already an edge map and is used as
            the conditioning image directly; otherwise the edge map is
            computed with ``preprocess_canny``.
        num_samples: Number of images to generate. The inputs are split with
            ``shard``, so this presumably has to be a multiple of the local
            device count -- TODO confirm.
        resolution: Target size handed to ``resize_image``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Integer seed for the PRNG key.
        negative_prompt: Negative text prompt, repeated once per sample.

    Returns:
        A list holding the input array, then the computed edge map (only when
        ``is_canny`` is false), then one PIL image per generated sample.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    num_samples = int(num_samples)

    resized_image = resize_image(image, resolution)

    if is_canny:
        # prepare_image_inputs works on PIL images, so wrap the raw array.
        control_image = Image.fromarray(resized_image)
    else:
        # preprocess_canny returns (resized, edges); only the edge map
        # conditions the model.
        _, control_image = preprocess_canny(resized_image, resolution)

    # With jit=True the pipeline runs one replica per device, so every
    # device needs its own PRNG key.
    rng = jax.random.split(create_key(seed), jax.device_count())

    prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
    processed_image = pipe.prepare_image_inputs([control_image] * num_samples)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    )

    all_outputs = [image]
    if not is_canny:
        all_outputs.append(control_image)

    # Merge any leading device/batch axes into a single sample axis so that
    # each entry is one (H, W, C) image, then convert to PIL for display.
    generated = np.asarray(output.images)
    generated = generated.reshape((-1,) + generated.shape[-3:])
    all_outputs.extend(pipe.numpy_to_pil(generated))
    return all_outputs
111
+
112
def resize_image(image, resolution):
    """Resize ``image`` so its shorter side equals ``resolution``.

    The aspect ratio is preserved (up to integer truncation of the longer
    side) and nearest-neighbour interpolation is used.

    Args:
        image: NumPy image array, either (H, W) or (H, W, C).
        resolution: Length in pixels of the shorter side after resizing.

    Returns:
        The resized NumPy array.
    """
    # Only the first two axes are spatial; unpacking the full shape would
    # raise ValueError for colour images that carry a channel axis.
    h, w = image.shape[:2]
    ratio = w / h

    # cv2.resize expects the target size as (width, height).
    if ratio > 1:
        target_size = (int(resolution * ratio), resolution)
    elif ratio < 1:
        target_size = (resolution, int(resolution / ratio))
    else:
        target_size = (resolution, resolution)

    return cv2.resize(image, target_size, interpolation=cv2.INTER_NEAREST)
122
+
123
+
124
def preprocess_canny(image, resolution=128):
    """Resize ``image`` and compute its three-channel Canny edge map.

    Args:
        image: NumPy image array, either (H, W) or (H, W, C).
        resolution: Length in pixels of the shorter side after resizing.

    Returns:
        ``(resized_image, processed_image)`` as PIL images: the resized input
        and the edge map replicated over three channels.
    """
    # Only the first two axes are spatial; unpacking the full shape would
    # raise ValueError for colour images that carry a channel axis.
    h, w = image.shape[:2]
    ratio = w / h

    # cv2.resize expects the target size as (width, height).
    if ratio > 1:
        target_size = (int(resolution * ratio), resolution)
    elif ratio < 1:
        target_size = (resolution, int(resolution / ratio))
    else:
        target_size = (resolution, resolution)
    resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_NEAREST)

    # cv2.Canny yields a single-channel map; stack it three times so the
    # result has the same channel layout as an RGB image.
    edges = cv2.Canny(resized_image, low_threshold, high_threshold)
    edges = edges[:, :, None]
    processed_image = np.concatenate([edges, edges, edges], axis=2)

    resized_image = Image.fromarray(resized_image)
    processed_image = Image.fromarray(processed_image)
    return resized_image, processed_image
141
+
142
+ def create_demo(process, max_images=12, default_num_images=4):
143
  with gr.Blocks() as demo:
144
  with gr.Row():
145
  gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
 
149
  prompt = gr.Textbox(label='Prompt')
150
  run_button = gr.Button(label='Run')
151
  with gr.Accordion('Advanced options', open=False):
152
+ is_canny = gr.Checkbox(
153
+ label='Is canny', value=False)
154
  num_samples = gr.Slider(label='Images',
155
  minimum=1,
156
  maximum=max_images,
157
  value=default_num_images,
158
  step=1)
159
+ """
160
  canny_low_threshold = gr.Slider(
161
  label='Canny low threshold',
162
  minimum=1,
 
169
  maximum=255,
170
  value=200,
171
  step=1)
172
+ """
173
+ resolution = gr.Slider(label='Resolution',
174
+ minimum=128,
175
+ maximum=128,
176
+ value=128,
177
+ step=1)
178
  num_steps = gr.Slider(label='Steps',
179
  minimum=1,
180
  maximum=100,
 
190
  maximum=2147483647,
191
  step=1,
192
  randomize=True)
 
 
 
193
  n_prompt = gr.Textbox(
194
  label='Negative Prompt',
195
  value=
 
203
  inputs = [
204
  input_image,
205
  prompt,
206
+ is_canny,
 
207
  num_samples,
208
+ resolution,
209
+ #canny_low_threshold,
210
+ #canny_high_threshold,
211
  num_steps,
212
  guidance_scale,
213
  seed,
214
+ n_prompt,
 
215
  ]
216
  prompt.submit(fn=process, inputs=inputs, outputs=result)
217
  run_button.click(fn=process,