Baptlem commited on
Commit
a5e9129
1 Parent(s): cd1e8dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -25
app.py CHANGED
@@ -35,9 +35,9 @@ def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
35
 
36
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
37
  sb_path,
38
- controlnet=controlnet,
39
- dtype=jnp.float32,
40
- from_pt=True
41
  )
42
 
43
  pipe.scheduler = scheduler
@@ -56,9 +56,9 @@ high_threshold = 200
56
 
57
  pipe, params = load_sb_pipe(controlnet_version)
58
 
59
- pipe.enable_xformers_memory_efficient_attention()
60
- pipe.enable_model_cpu_offload()
61
- pipe.enable_attention_slicing()
62
 
63
  def pipe_inference(
64
  image,
@@ -78,18 +78,20 @@ def pipe_inference(
78
  resized_image = resize_image(image, resolution)
79
 
80
  if not is_canny:
81
- resized_image = preprocess_canny(resized_image)
82
 
83
  rng = create_key(seed)
84
- # rng = jax.random.split(rng,)
85
-
86
  prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
87
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
88
  processed_image = pipe.prepare_image_inputs([resized_image] * num_samples)
 
89
  p_params = replicate(params)
90
  prompt_ids = shard(prompt_ids)
91
  negative_prompt_ids = shard(negative_prompt_ids)
92
  processed_image = shard(processed_image)
 
93
  output = pipe(
94
  prompt_ids=prompt_ids,
95
  image=processed_image,
@@ -122,15 +124,6 @@ def resize_image(image, resolution):
122
 
123
 
124
  def preprocess_canny(image, resolution=128):
125
- h, w = image.shape
126
- ratio = w/h
127
- if ratio > 1 :
128
- resized_image = cv2.resize(image, (int(resolution*ratio), resolution), interpolation=cv2.INTER_NEAREST)
129
- elif ratio < 1 :
130
- resized_image = cv2.resize(image, (resolution, int(resolution/ratio)), interpolation=cv2.INTER_NEAREST)
131
- else:
132
- resized_image = cv2.resize(image, (resolution, resolution), interpolation=cv2.INTER_NEAREST)
133
-
134
  processed_image = cv2.Canny(resized_image, low_threshold, high_threshold)
135
  processed_image = processed_image[:, :, None]
136
  processed_image = np.concatenate([processed_image, processed_image, processed_image], axis=2)
@@ -139,6 +132,7 @@ def preprocess_canny(image, resolution=128):
139
  processed_image = Image.fromarray(processed_image)
140
  return resized_image, processed_image
141
 
 
142
  def create_demo(process, max_images=12, default_num_images=4):
143
  with gr.Blocks() as demo:
144
  with gr.Row():
@@ -218,14 +212,12 @@ def create_demo(process, max_images=12, default_num_images=4):
218
  inputs=inputs,
219
  outputs=result,
220
  api_name='canny')
221
- return demo
222
 
223
 
224
  if __name__ == '__main__':
225
- """
226
- from model import Model
227
- model = Model()
228
- demo = create_demo(model.process_canny)
229
  demo.queue().launch()
230
- """
231
- pass
 
35
 
36
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
37
  sb_path,
38
+ controlnet=controlnet,
39
+ revision="flax",
40
+ dtype=jnp.bfloat16
41
  )
42
 
43
  pipe.scheduler = scheduler
 
56
 
57
  pipe, params = load_sb_pipe(controlnet_version)
58
 
59
+ # pipe.enable_xformers_memory_efficient_attention()
60
+ # pipe.enable_model_cpu_offload()
61
+ # pipe.enable_attention_slicing()
62
 
63
  def pipe_inference(
64
  image,
 
78
  resized_image = resize_image(image, resolution)
79
 
80
  if not is_canny:
81
+ resized_image = preprocess_canny(resized_image, resolution)
82
 
83
  rng = create_key(seed)
84
+ rng = jax.random.split(rng, jax.device_count())
85
+
86
  prompt_ids = pipe.prepare_text_inputs([prompt] * num_samples)
87
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * num_samples)
88
  processed_image = pipe.prepare_image_inputs([resized_image] * num_samples)
89
+
90
  p_params = replicate(params)
91
  prompt_ids = shard(prompt_ids)
92
  negative_prompt_ids = shard(negative_prompt_ids)
93
  processed_image = shard(processed_image)
94
+
95
  output = pipe(
96
  prompt_ids=prompt_ids,
97
  image=processed_image,
 
124
 
125
 
126
  def preprocess_canny(image, resolution=128):
 
 
 
 
 
 
 
 
 
127
  processed_image = cv2.Canny(resized_image, low_threshold, high_threshold)
128
  processed_image = processed_image[:, :, None]
129
  processed_image = np.concatenate([processed_image, processed_image, processed_image], axis=2)
 
132
  processed_image = Image.fromarray(processed_image)
133
  return resized_image, processed_image
134
 
135
+
136
  def create_demo(process, max_images=12, default_num_images=4):
137
  with gr.Blocks() as demo:
138
  with gr.Row():
 
212
  inputs=inputs,
213
  outputs=result,
214
  api_name='canny')
 
215
 
216
 
217
  if __name__ == '__main__':
218
+
219
+ pipe_inference
220
+ demo = create_demo(pipe_inference)
 
221
  demo.queue().launch()
222
+ # gr.Interface(create_demo).launch()
223
+