Mateo Fidabel commited on
Commit
e6915e1
1 Parent(s): 892096a

Changed Example Layout, Predefined Input

Browse files
Files changed (1) hide show
  1. app.py +70 -49
app.py CHANGED
@@ -7,6 +7,7 @@ from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
7
  from diffusers.utils import load_image
8
  import jax.numpy as jnp
9
  import numpy as np
 
10
 
11
 
12
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
@@ -31,67 +32,94 @@ description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anyt
31
  Test some of the examples below to give it a try ⬇️
32
  """
33
 
34
- examples = [["a modern main room of a house", "low quality", "examples/condition_image_1.png"],
35
  ["new york buildings, Vincent Van Gogh starry night ", "low quality, monochrome", "examples/condition_image_2.png"],
36
  ["contemporary living room, high quality, 4k, realistic", "low quality, monochrome, low res", "examples/condition_image_3.png"]]
37
 
38
 
39
  # Inference Function
40
  def infer(prompts, negative_prompts, image, num_inference_steps = 50, seed = 4, num_samples = 4):
41
- rng = jax.random.PRNGKey(int(seed))
42
- num_inference_steps = int(num_inference_steps)
43
- image = Image.fromarray(image, mode="RGB")
44
- num_samples = max(jax.device_count(), int(num_samples))
45
- p_rng = jax.random.split(rng, jax.device_count())
46
-
47
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
48
- negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
49
- processed_image = pipe.prepare_image_inputs([image] * num_samples)
50
-
51
- prompt_ids = shard(prompt_ids)
52
- negative_prompt_ids = shard(negative_prompt_ids)
53
- processed_image = shard(processed_image)
54
-
55
- output = pipe(
56
- prompt_ids=prompt_ids,
57
- image=processed_image,
58
- params=p_params,
59
- prng_seed=p_rng,
60
- num_inference_steps=num_inference_steps,
61
- neg_prompt_ids=negative_prompt_ids,
62
- jit=True,
63
- ).images
64
-
65
- output = output.reshape((num_samples,) + output.shape[-3:])
66
-
67
- print(output.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- final_image = [np.array(x*255, dtype=np.uint8) for x in output]
70
 
71
- del output
72
-
73
- return final_image
 
 
 
 
 
 
 
 
74
 
75
  with gr.Blocks(css="h1 { text-align: center }") as demo:
76
- # Title
77
- gr.Markdown(title)
78
- # Description
79
- gr.Markdown(description)
 
 
 
 
 
 
 
 
 
80
 
81
  # Images
82
  with gr.Row(variant="panel"):
83
  with gr.Column(scale=2):
84
- cond_img = gr.Image(label="Input")\
85
- .style(height=200)
86
  with gr.Column(scale=1):
87
- output = gr.Gallery(label="Generated images")\
88
- .style(height=200, rows=[2], columns=[1, 2], object_fit="contain")
89
 
90
  # Submit & Clear
91
  with gr.Row():
92
  with gr.Column():
93
- prompt = gr.Textbox(lines=1, label="Prompt")
94
- negative_prompt = gr.Textbox(lines=1, label="Negative Prompt")
95
 
96
  with gr.Column():
97
  with gr.Accordion("Advanced options", open=False):
@@ -102,13 +130,6 @@ with gr.Blocks(css="h1 { text-align: center }") as demo:
102
  submit = gr.Button("Generate")
103
  # TODO: Download Button
104
 
105
- # Examples
106
- gr.Examples(examples=examples,
107
- inputs=[prompt, negative_prompt, cond_img],
108
- outputs=output,
109
- fn=infer,
110
- cache_examples=True)
111
-
112
 
113
  submit.click(infer,
114
  inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],
 
7
  from diffusers.utils import load_image
8
  import jax.numpy as jnp
9
  import numpy as np
10
+ import gc
11
 
12
 
13
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
 
32
  Test some of the examples below to give it a try ⬇️
33
  """
34
 
35
+ examples = [["contemporary living room of a house", "low quality", "examples/condition_image_1.png"],
36
  ["new york buildings, Vincent Van Gogh starry night ", "low quality, monochrome", "examples/condition_image_2.png"],
37
  ["contemporary living room, high quality, 4k, realistic", "low quality, monochrome, low res", "examples/condition_image_3.png"]]
38
 
39
 
40
  # Inference Function
41
  def infer(prompts, negative_prompts, image, num_inference_steps = 50, seed = 4, num_samples = 4):
42
+ try:
43
+ rng = jax.random.PRNGKey(int(seed))
44
+ num_inference_steps = int(num_inference_steps)
45
+ image = Image.fromarray(image, mode="RGB")
46
+ num_samples = max(jax.device_count(), int(num_samples))
47
+ p_rng = jax.random.split(rng, jax.device_count())
48
+
49
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
50
+ negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
51
+ processed_image = pipe.prepare_image_inputs([image] * num_samples)
52
+
53
+ prompt_ids = shard(prompt_ids)
54
+ negative_prompt_ids = shard(negative_prompt_ids)
55
+ processed_image = shard(processed_image)
56
+
57
+ output = pipe(
58
+ prompt_ids=prompt_ids,
59
+ image=processed_image,
60
+ params=p_params,
61
+ prng_seed=p_rng,
62
+ num_inference_steps=num_inference_steps,
63
+ neg_prompt_ids=negative_prompt_ids,
64
+ jit=True,
65
+ ).images
66
+
67
+ del negative_prompt_ids
68
+ del processed_image
69
+ del prompt_ids
70
+
71
+ output = output.reshape((num_samples,) + output.shape[-3:])
72
+ final_image = [np.array(x*255, dtype=np.uint8) for x in output]
73
+ print(output.shape)
74
+ del output
75
+
76
+ except Exception as e:
77
+ print("Error: " + str(e))
78
+ final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
79
+ finally:
80
+ gc.collect()
81
+ return final_image
82
 
 
83
 
84
+ default_example = examples[2]
85
+
86
+ cond_img = gr.Image(label="Input", shape=(512, 512), value=default_example[2])\
87
+ .style(height=200)
88
+
89
+ output = gr.Gallery(label="Generated images")\
90
+ .style(height=200, rows=[2], columns=[1, 2], object_fit="contain")
91
+
92
+ prompt = gr.Textbox(lines=1, label="Prompt", value=default_example[0])
93
+ negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", value=default_example[1])
94
+
95
 
96
  with gr.Blocks(css="h1 { text-align: center }") as demo:
97
+ with gr.Row():
98
+ with gr.Column():
99
+ # Title
100
+ gr.Markdown(title)
101
+ # Description
102
+ gr.Markdown(description)
103
+
104
+ with gr.Column():
105
+ # Examples
106
+ gr.Examples(examples=examples,
107
+ inputs=[prompt, negative_prompt, cond_img],
108
+ outputs=output,
109
+ fn=infer)
110
 
111
  # Images
112
  with gr.Row(variant="panel"):
113
  with gr.Column(scale=2):
114
+ cond_img.render()
 
115
  with gr.Column(scale=1):
116
+ output.render()
 
117
 
118
  # Submit & Clear
119
  with gr.Row():
120
  with gr.Column():
121
+ prompt.render()
122
+ negative_prompt.render()
123
 
124
  with gr.Column():
125
  with gr.Accordion("Advanced options", open=False):
 
130
  submit = gr.Button("Generate")
131
  # TODO: Download Button
132
 
 
 
 
 
 
 
 
133
 
134
  submit.click(infer,
135
  inputs=[prompt, negative_prompt, cond_img, num_steps, seed, num_samples],