Use bfloat16, 1 image, no grid

Pull request #1
Opened by pcuenq (HF staff)
Files changed (1):
  1. app.py (+9 -25)
app.py (changed)
@@ -7,48 +7,30 @@ from flax.training.common_utils import shard
7
  from PIL import Image
8
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
  import cv2
10
- import os
11
-
12
-
13
-
14
-
15
- def image_grid(imgs, rows, cols):
16
- w, h = imgs[0].size
17
- grid = Image.new("RGB", size=(cols * w, rows * h))
18
- for i, img in enumerate(imgs):
19
- grid.paste(img, box=(i % cols * w, i // cols * h))
20
- return grid
21
-
22
 
23
  def create_key(seed=0):
24
  return jax.random.PRNGKey(seed)
25
 
26
-
27
-
28
-
29
  def canny_filter(image):
30
  gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
31
-
32
  blurred_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
33
  edges_image = cv2.Canny(blurred_image, 50, 150)
34
  return edges_image
35
 
36
  # load control net and stable diffusion v1-5
37
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
38
- "jax-diffusers-event/canny-coyo1m", dtype=jnp.float32
39
  )
40
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
41
- "runwayml/stable-diffusion-v1-5", from_pt=True, controlnet=controlnet, dtype=jnp.float32
42
  )
43
 
44
  def infer(prompts, negative_prompts, image):
45
-
46
-
47
  params["controlnet"] = controlnet_params
48
 
49
- num_samples = jax.device_count("gpu")
50
  rng = create_key(0)
51
- rng = jax.random.split(rng, jax.device_count("gpu"))
52
  im = canny_filter(image)
53
  canny_image = Image.fromarray(im)
54
 
@@ -56,12 +38,15 @@ def infer(prompts, negative_prompts, image):
56
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
57
  processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
58
 
59
-
 
 
 
60
 
61
  output = pipe(
62
  prompt_ids=prompt_ids,
63
  image=processed_image,
64
- params=params,
65
  prng_seed=rng,
66
  num_inference_steps=50,
67
  neg_prompt_ids=negative_prompt_ids,
@@ -69,7 +54,6 @@ def infer(prompts, negative_prompts, image):
69
  ).images
70
 
71
  output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
72
- output_images = image_grid(output_images, num_samples // 4, 4)
73
  return output_images
74
 
75
  gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()
 
7
  from PIL import Image
8
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Build a JAX PRNG key from an integer seed (defaults to 0).
  def create_key(seed=0):
12
  return jax.random.PRNGKey(seed)
13
 
 
 
 
14
  # Turn a colour image array into a Canny edge map:
  # grayscale conversion -> 5x5 Gaussian blur -> Canny with thresholds 50 and 150.
  # Returns the array produced by cv2.Canny.
  def canny_filter(image):
15
  # NOTE(review): the only caller visible here passes the Gradio "image" input,
  # which presumably arrives in RGB order; COLOR_BGR2GRAY would then swap the
  # red/blue weights. Confirm, and consider cv2.COLOR_RGB2GRAY.
  gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
 
16
  # Blur before edge detection; sigma is passed as 0.
  blurred_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
17
  edges_image = cv2.Canny(blurred_image, 50, 150)
18
  return edges_image
19
 
20
  # load control net and stable diffusion v1-5
  # Both from_pretrained calls run at import time and request dtype=jnp.bfloat16.
  # Each returns a (model-or-pipeline, params) pair; the names bound here are
  # read as module-level globals by infer().
21
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
22
+ "jax-diffusers-event/canny-coyo1m", dtype=jnp.bfloat16
23
  )
24
  # The Stable Diffusion weights are requested from the "flax" revision of the
  # repo (this PR drops the previous from_pt=True argument).
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
25
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
26
  )
27
 
28
  # Gradio callback: runs the ControlNet pipeline on the Canny edge map of
  # `image` for the given prompt / negative prompt text and returns the result
  # of pipe.numpy_to_pil (the generated PIL images).
  # NOTE(review): new-file lines 37 and 53 are not shown in this diff view (the
  # gutter jumps 36 -> 38 and 52 -> 54), so the assignment of `prompt_ids` and
  # one pipe(...) argument are not visible here.
  def infer(prompts, negative_prompts, image):
 
 
29
  # Writes the ControlNet weights into the module-level `params` dict on every call.
  params["controlnet"] = controlnet_params
30
 
31
  # One sample per request, while the PRNG key below is still split into one
  # key per device.
  # NOTE(review): with num_samples == 1, shard() presumably needs the batch to
  # divide evenly across devices, so this likely only works when
  # jax.device_count() == 1 -- confirm on multi-device hosts.
+ num_samples = 1 #jax.device_count()
32
  # Fixed seed 0: every request starts from the same key.
  rng = create_key(0)
33
+ rng = jax.random.split(rng, jax.device_count())
34
  im = canny_filter(image)
35
  canny_image = Image.fromarray(im)
36
 
 
38
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
39
  processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
40
 
41
  # `replicate` is imported outside the visible hunk (presumably
  # flax.jax_utils.replicate -- confirm).
  # NOTE(review): replicate(params) runs on every request; consider hoisting it
  # to module level if the parameters do not change between calls.
+ p_params = replicate(params)
42
+ prompt_ids = shard(prompt_ids)
43
+ negative_prompt_ids = shard(negative_prompt_ids)
44
+ processed_image = shard(processed_image)
45
 
46
  output = pipe(
47
  prompt_ids=prompt_ids,
48
  image=processed_image,
49
+ params=p_params,
50
  prng_seed=rng,
51
  num_inference_steps=50,
52
  neg_prompt_ids=negative_prompt_ids,
 
54
  ).images
55
 
56
  # Collapse all leading axes into a single axis of length num_samples, keeping
  # the last three dimensions of the output, then convert to PIL images.
  output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
 
57
  return output_images
58
 
59
  # Build the Gradio UI around infer (two text inputs and one image input,
  # gallery output) and launch it when the module is executed.
  gr.Interface(infer, inputs=["text", "text", "image"], outputs="gallery").launch()