Update app.py
app.py
CHANGED
@@ -35,9 +35,9 @@ def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
 
     pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
         sb_path,
-        controlnet=controlnet,
-
-
+        controlnet=controlnet,
+        revision="flax",
+        dtype=jnp.bfloat16
     )
 
     pipe.scheduler = scheduler
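For context, a rough sketch of how this Flax pipeline load is typically assembled (the ControlNet checkpoint id, the from_pt conversion, and the returned (pipe, params) shape are assumptions, not part of this commit; the real load_sb_pipe also swaps in a scheduler via pipe.scheduler = scheduler, which is omitted here):

import jax.numpy as jnp
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline

def load_sb_pipe_sketch(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    # Hypothetical ControlNet load; converting PyTorch weights to Flax is an assumption.
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_version, from_pt=True, dtype=jnp.bfloat16
    )
    # Mirrors the change above: pull the "flax" revision of the SD weights in bfloat16.
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16,
    )
    params["controlnet"] = controlnet_params
    return pipe, params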
@@ -56,9 +56,9 @@ high_threshold = 200
 
 pipe, params = load_sb_pipe(controlnet_version)
 
-pipe.enable_xformers_memory_efficient_attention()
-pipe.enable_model_cpu_offload()
-pipe.enable_attention_slicing()
+# pipe.enable_xformers_memory_efficient_attention()
+# pipe.enable_model_cpu_offload()
+# pipe.enable_attention_slicing()
 
 def pipe_inference(
     image,
@@ -78,18 +78,20 @@ def pipe_inference(
     resized_image = resize_image(image, resolution)
 
     if not is_canny:
-        resized_image = preprocess_canny(resized_image)
+        resized_image = preprocess_canny(resized_image, resolution)
 
     rng = create_key(seed)
-
-
+    rng = jax.random.split(rng, jax.device_count())
+
     prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
     negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
     processed_image = pipe.prepare_image_inputs([resized_image] * num_samples)
+
     p_params = replicate(params)
     prompt_ids = shard(prompt_ids)
     negative_prompt_ids = shard(negative_prompt_ids)
     processed_image = shard(processed_image)
+
     output = pipe(
         prompt_ids=prompt_ids,
         image=processed_image,
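A minimal sketch of the multi-device pattern this hunk moves to: one PRNG key per device, replicated weights, and sharded inputs feeding a pmapped call. create_key is presumed to wrap jax.random.PRNGKey, and the pipe(...) keyword arguments (params, prng_seed, neg_prompt_ids, num_inference_steps, jit) are assumptions about the Flax pipeline call, not lines from this commit:

import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

# pipe, params are assumed to come from load_sb_pipe(controlnet_version) above.
num_samples = jax.device_count()            # shard() needs the batch to split evenly per device
canny_image = Image.new("RGB", (512, 512))  # stand-in for the real canny control image

def create_key(seed=0):
    # Presumed helper: a plain JAX PRNG key from an integer seed.
    return jax.random.PRNGKey(seed)

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())  # one key per device

prompt_ids = pipe.prepare_text_inputs(["a photo of a cat"] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([""] * num_samples)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

p_params = replicate(params)        # copy weights onto every device
prompt_ids = shard(prompt_ids)      # [batch, ...] -> [devices, batch/devices, ...]
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    neg_prompt_ids=negative_prompt_ids,
    num_inference_steps=20,
    jit=True,                       # run the pmapped sampling loop
)
images = output.images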
@@ -122,15 +124,6 @@ def resize_image(image, resolution):
 
 
 def preprocess_canny(image, resolution=128):
-    h, w = image.shape
-    ratio = w/h
-    if ratio > 1 :
-        resized_image = cv2.resize(image, (int(resolution*ratio), resolution), interpolation=cv2.INTER_NEAREST)
-    elif ratio < 1 :
-        resized_image = cv2.resize(image, (resolution, int(resolution/ratio)), interpolation=cv2.INTER_NEAREST)
-    else:
-        resized_image = cv2.resize(image, (resolution, resolution), interpolation=cv2.INTER_NEAREST)
-
     processed_image = cv2.Canny(resized_image, low_threshold, high_threshold)
     processed_image = processed_image[:, :, None]
     processed_image = np.concatenate([processed_image, processed_image, processed_image], axis=2)
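With the in-function resize removed, preprocess_canny as committed still reads resized_image, which is no longer assigned inside the function (the caller now resizes via resize_image and passes resolution). A minimal self-contained version of the canny step, assuming the input is an already-resized NumPy image and that low_threshold is 100 (only high_threshold = 200 is visible in this diff), might look like:

import cv2
import numpy as np
from PIL import Image

low_threshold = 100    # assumed value; not shown in this diff
high_threshold = 200   # from the hunk header above

def preprocess_canny(image, resolution=128):
    # `image` is assumed to be the already-resized array produced by resize_image;
    # alias it so the original return signature (resized_image, processed_image) holds.
    resized_image = image
    edges = cv2.Canny(resized_image, low_threshold, high_threshold)
    edges = edges[:, :, None]
    edges = np.concatenate([edges, edges, edges], axis=2)  # single channel -> RGB
    processed_image = Image.fromarray(edges)
    return resized_image, processed_image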
@@ -139,6 +132,7 @@ def preprocess_canny(image, resolution=128):
     processed_image = Image.fromarray(processed_image)
     return resized_image, processed_image
 
+
 def create_demo(process, max_images=12, default_num_images=4):
     with gr.Blocks() as demo:
         with gr.Row():
@@ -218,14 +212,12 @@ def create_demo(process, max_images=12, default_num_images=4):
                inputs=inputs,
                outputs=result,
                api_name='canny')
-    return demo
 
 
 if __name__ == '__main__':
-
-
-
-    demo = create_demo(model.process_canny)
+
+    pipe_inference
+    demo = create_demo(pipe_inference)
     demo.queue().launch()
-
-
+    # gr.Interface(create_demo).launch()
+
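The new entry point hands pipe_inference straight to create_demo. Note that the same hunk drops return demo from create_demo, so as committed create_demo(pipe_inference) evaluates to None and demo.queue() would fail; the sketch below (with a stub in place of the real inference function and the Blocks body elided) keeps the return that the launcher pattern relies on:

import gradio as gr

def pipe_inference(*args, **kwargs):
    # Stand-in for the real pipe_inference defined earlier in app.py.
    raise NotImplementedError

def create_demo(process, max_images=12, default_num_images=4):
    with gr.Blocks() as demo:
        # UI layout elided; inside, a button's .click(...) wires `process`
        # to the inputs/outputs shown above with api_name='canny'.
        pass
    return demo

if __name__ == '__main__':
    demo = create_demo(pipe_inference)
    demo.queue().launch()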