tsungtao committed on
Commit
4967dbc
1 Parent(s): bc912c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py CHANGED
@@ -1,3 +1,68 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  gr.Interface.load("models/tsungtao/controlnet-mlsd-202305011046").launch()
 
1
  import gradio as gr
2
+ import jax
3
+ from datasets import load_dataset
4
+ import numpy as np
5
+ import jax.numpy as jnp
6
+ from flax.jax_utils import replicate
7
+ from flax.training.common_utils import shard
8
+ from diffusers.utils import load_image
9
+ from PIL import Image
10
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
11
+
12
def image_grid(imgs, rows, cols):
    """Tile equally sized images into a single ``rows`` x ``cols`` grid.

    Images are placed left to right, top to bottom, in the order given. The
    cell size is taken from the first image's ``.size``.

    Args:
        imgs: Non-empty sequence of images exposing ``.size`` as ``(w, h)``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty, or if it holds more images than the
            grid has cells (the extras would land outside the canvas and be
            lost without any warning).
    """
    count = len(imgs)
    if count == 0:
        raise ValueError("image_grid() needs at least one image")
    if count > rows * cols:
        raise ValueError(
            f"{count} images do not fit into a {rows}x{cols} grid"
        )

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Row-major placement: index -> (row, column) of the target cell.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
18
+
19
def create_key(seed=0):
    """Return a JAX PRNG key built from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key
21
+
22
# Local import: needed only to make sure the output directory exists.
import os

# Base PRNG key; it is split per device further down.
rng = create_key(0)

# Conditioning image passed to the ControlNet pipeline.
# NOTE(review): the variable is called "canny" but the checkpoint loaded below
# is named "mlsd" -- confirm the image matches what the model was trained on.
canny_image = load_image(
    "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg"
)

prompts = "a living room with tv, sea, window"
negative_prompts = "fan "

# Load the ControlNet weights and the Stable Diffusion pipeline that wraps them.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "tsungtao/controlnet-mlsd-202305011046", from_flax=True, dtype=jnp.float32
)

pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
)

# The pipeline's parameter dict needs the ControlNet parameters attached.
params["controlnet"] = controlnet_params

# One sample (and one PRNG key) per available JAX device.
num_samples = jax.device_count()
rng = jax.random.split(rng, num_samples)

prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

# Replicate the parameters and shard the per-sample inputs across devices.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

output = pipe(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    num_inference_steps=50,
    neg_prompt_ids=negative_prompt_ids,
    jit=True,
).images

# Collapse the leading axes into a single axis of ``num_samples`` images
# before converting to PIL.
output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))

# Up to 4 columns; round the row count UP so that fewer than 4 devices (or a
# count that is not a multiple of 4) still yields a grid holding every image.
# The previous ``num_samples // 4`` produced 0 rows on 1-3 devices.
grid_cols = min(4, num_samples)
grid_rows = -(-num_samples // grid_cols)
output_images = image_grid(output_images, grid_rows, grid_cols)

# ``Image.save`` does not create missing directories.
os.makedirs("tao", exist_ok=True)
output_images.save("tao/image.png")

gr.Interface.load("models/tsungtao/controlnet-mlsd-202305011046").launch()