mfidabel commited on
Commit
a324479
1 Parent(s): a9e6eb1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ from PIL import Image
4
+ from flax.jax_utils import replicate
5
+ from flax.training.common_utils import shard
6
+ from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+
10
+
11
+ title = "🧨 ControlNet on Segment Anything 🤗"
12
+ description = "This is a demo on ControlNet based on Segment Anything"
13
+
14
+ examples = [["a modern main room of a house", "low quality", "condition_image_1.png", 50, 4]]
15
+
16
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
17
+ "mfidabel/controlnet-segment-anything", dtype=jnp.float32
18
+ )
19
+
20
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
21
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
22
+ )
23
+
24
+ # Add ControlNet params and Replicate
25
+ params["controlnet"] = controlnet_params
26
+ p_params = replicate(params)
27
+
28
+
29
+ # Inference Function
30
+ def infer(prompts, negative_prompts, image, num_inference_steps, seed):
31
+ rng = jax.random.PRNGKey(int(seed))
32
+ num_inference_steps = int(num_inference_steps)
33
+ image = Image.fromarray(image, mode="RGB")
34
+ num_samples = jax.device_count()
35
+ p_rng = jax.random.split(rng, jax.device_count())
36
+
37
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
38
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
39
+ processed_image = pipe.prepare_image_inputs([image] * num_samples)
40
+
41
+ prompt_ids = shard(prompt_ids)
42
+ negative_prompt_ids = shard(negative_prompt_ids)
43
+ processed_image = shard(processed_image)
44
+
45
+ output = pipe(
46
+ prompt_ids=prompt_ids,
47
+ image=processed_image,
48
+ params=p_params,
49
+ prng_seed=p_rng,
50
+ num_inference_steps=num_inference_steps,
51
+ neg_prompt_ids=negative_prompt_ids,
52
+ jit=True,
53
+ ).images
54
+
55
+ print(output[0].shape)
56
+
57
+ final_image = [np.array(x[0]*255, dtype=np.uint8) for x in output]
58
+
59
+ del output
60
+
61
+ return final_image
62
+
63
+ gr.Interface(fn = infer,
64
+ inputs = ["text", "text", "image", "number", "number"],
65
+ outputs = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(columns=[2], rows=[2], object_fit="contain", height="auto", preview=True),
66
+ title = title,
67
+ description = description,
68
+ examples = examples).launch()