merve HF staff commited on
Commit
9aa593c
1 Parent(s): d42ed77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -6
app.py CHANGED
@@ -1,11 +1,71 @@
1
  import gradio as gr
2
- from diffusers import DiffusionPipeline
 
 
 
 
 
 
 
3
 
4
- pipeline = DiffusionPipeline.from_pretrained("jax-diffusers-event/canny-coyo1m")
5
 
6
- def infer(prompt):
7
- image = pipeline(prompt).images[0]
8
- return image
 
 
 
9
 
10
 
11
- gr.Interface(pipeline, inputs="text", outputs="image").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import jax
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from diffusers.utils import load_image
8
+ from PIL import Image
9
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
10
 
 
11
 
12
+ def image_grid(imgs, rows, cols):
13
+ w, h = imgs[0].size
14
+ grid = Image.new("RGB", size=(cols * w, rows * h))
15
+ for i, img in enumerate(imgs):
16
+ grid.paste(img, box=(i % cols * w, i // cols * h))
17
+ return grid
18
 
19
 
20
+ def create_key(seed=0):
21
+ return jax.random.PRNGKey(seed)
22
+
23
+
24
+ rng = create_key(0)
25
+
26
+
27
+ def canny_filter(image):
28
+
29
+ ## TODO: Implement canny filter here
30
+ return canny_image
31
+
32
+ def infer(prompts, negative_prompts, image):
33
+
34
+
35
+ # load control net and stable diffusion v1-5
36
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
37
+ "jax-diffusers-event/canny-coyo1m", from_pt=True, dtype=jnp.float32
38
+ )
39
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
40
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.float32
41
+ )
42
+ params["controlnet"] = controlnet_params
43
+
44
+ num_samples = jax.device_count()
45
+ rng = jax.random.split(rng, jax.device_count())
46
+ canny_image = canny_filter(image)
47
+
48
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
49
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
50
+ processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
51
+
52
+ p_params = replicate(params)
53
+ prompt_ids = shard(prompt_ids)
54
+ negative_prompt_ids = shard(negative_prompt_ids)
55
+ processed_image = shard(processed_image)
56
+
57
+ output = pipe(
58
+ prompt_ids=prompt_ids,
59
+ image=processed_image,
60
+ params=p_params,
61
+ prng_seed=rng,
62
+ num_inference_steps=50,
63
+ neg_prompt_ids=negative_prompt_ids,
64
+ jit=True,
65
+ ).images
66
+
67
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
68
+ output_images = image_grid(output_images, num_samples // 4, 4)
69
+ return output_images
70
+
71
+ gr.Interface(pipeline, inputs=["text", "text", "image"], outputs="gallery").launch()