tsungtao commited on
Commit
b74d6b0
1 Parent(s): 1bd1511

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -54
app.py CHANGED
@@ -9,60 +9,6 @@ from diffusers.utils import load_image
9
  from PIL import Image
10
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
11
 
12
- def image_grid(imgs, rows, cols):
13
- w, h = imgs[0].size
14
- grid = Image.new("RGB", size=(cols * w, rows * h))
15
- for i, img in enumerate(imgs):
16
- grid.paste(img, box=(i % cols * w, i // cols * h))
17
- return grid
18
-
19
- def create_key(seed=0):
20
- return jax.random.PRNGKey(seed)
21
-
22
- rng = create_key(0)
23
-
24
- canny_image = load_image(
25
- "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
26
- )
27
-
28
- prompts = "a living room with tv, sea, window"
29
- negative_prompts = "fan "
30
-
31
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
32
- "tsungtao/controlnet-mlsd-202305011046", from_flax=True, dtype=jnp.float32
33
- )
34
-
35
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
36
- "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
37
- )
38
-
39
- params["controlnet"] = controlnet_params
40
-
41
- num_samples = jax.device_count()
42
- rng = jax.random.split(rng, jax.device_count())
43
-
44
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
45
- negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
46
- processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
47
-
48
- p_params = replicate(params)
49
- prompt_ids = shard(prompt_ids)
50
- negative_prompt_ids = shard(negative_prompt_ids)
51
- processed_image = shard(processed_image)
52
-
53
- output = pipe(
54
- prompt_ids=prompt_ids,
55
- image=processed_image,
56
- params=p_params,
57
- prng_seed=rng,
58
- num_inference_steps=50,
59
- neg_prompt_ids=negative_prompt_ids,
60
- jit=True,
61
- ).images
62
-
63
- output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
64
- output_images = image_grid(output_images, num_samples // 4, 4)
65
- output_images.save("tao/image.png")
66
 
67
 
68
  #gr.Interface.load("models/tsungtao/controlnet-mlsd-202305011046").launch()
 
9
  from PIL import Image
10
  from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  #gr.Interface.load("models/tsungtao/controlnet-mlsd-202305011046").launch()