Contrebande Labs commited on
Commit
e0cb68e
1 Parent(s): 06f2eaf

sync with working jax inference code from main repo

Browse files
Files changed (1) hide show
  1. app.py +31 -53
app.py CHANGED
@@ -16,6 +16,7 @@ from diffusers import (
16
 
17
  from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
18
 
 
19
  def get_inference_lambda(seed):
20
 
21
  tokenizer = ByT5Tokenizer()
@@ -51,7 +52,7 @@ def get_inference_lambda(seed):
51
  "trained_betas": None,
52
  }
53
  )
54
- timesteps = 50
55
  guidance_scale = jnp.array([7.5], dtype=jnp.float32)
56
 
57
  unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
@@ -68,7 +69,13 @@ def get_inference_lambda(seed):
68
 
69
  image_width = image_height = 256
70
 
71
- print("all models setup")
 
 
 
 
 
 
72
 
73
  def __tokenize_prompt(prompt: str):
74
 
@@ -78,15 +85,11 @@ def get_inference_lambda(seed):
78
  padding="max_length",
79
  truncation=True,
80
  return_tensors="jax",
81
- ).input_ids.astype(jnp.float32)
82
 
83
- def __convert_image(vae_output):
84
- print("skipping image conversion...")
85
- return None
86
- # return [
87
- # Image.fromarray(image)
88
- # for image in (np.asarray(vae_output) * 255).round().astype(np.uint8)
89
- # ]
90
 
91
  def __predict_image(tokenized_prompt: jnp.array):
92
 
@@ -99,14 +102,6 @@ def get_inference_lambda(seed):
99
  context = jnp.concatenate(
100
  [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
101
  )
102
- jax.debug.print("got text encoding...")
103
-
104
- latent_shape = (
105
- tokenized_prompt.shape[0],
106
- unet.in_channels,
107
- image_width // vae_scale_factor,
108
- image_height // vae_scale_factor,
109
- )
110
 
111
  def ___timestep(step, step_args):
112
 
@@ -148,15 +143,12 @@ def get_inference_lambda(seed):
148
  scheduler_state, guided_unet_prediction_sample, t, latents
149
  ).to_tuple()
150
 
151
- jax.debug.print("did one step...")
152
-
153
  return latents, scheduler_state
154
 
155
  # initialize scheduler state
156
  initial_scheduler_state = scheduler.set_timesteps(
157
  scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
158
  )
159
- jax.debug.print("initialized scheduler state...")
160
 
161
  # initialize latents
162
  initial_latents = (
@@ -165,49 +157,33 @@ def get_inference_lambda(seed):
165
  )
166
  * initial_scheduler_state.init_noise_sigma
167
  )
168
- jax.debug.print("initialized latents...")
169
 
170
  final_latents, _ = jax.lax.fori_loop(
171
  0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
172
  )
173
- jax.debug.print("got final latents...")
174
-
175
- # scale and decode the image latents with vae
176
- image = (
177
- (
178
- vae.apply(
179
- {"params": vae_params},
180
- 1 / vae.config.scaling_factor * final_latents,
181
- method=vae.decode,
182
- ).sample
183
- / 2
184
- + 0.5
185
- )
186
- .clip(0, 1)
187
- .transpose(0, 2, 3, 1)
188
- )
189
- jax.debug.print("got vae processed image output...")
190
 
191
- # return reshaped vae outputs
192
- return image
 
 
 
 
 
 
 
 
 
 
193
 
194
- jax_pmap_predict_image = jax.jit(__predict_image)
195
 
196
  return lambda prompt: __convert_image(
197
- jax_pmap_predict_image(__tokenize_prompt(prompt))
198
  )
199
 
200
 
201
  generate_image_for_prompt = get_inference_lambda(87)
202
 
203
- print(f"JAX devices: {jax.devices()}")
204
- print(f"JAX device type: {jax.devices()[0].device_kind}")
205
-
206
- def infer_charred(prompt):
207
- # your inference function for charr stable difusion control
208
- generate_image_for_prompt(prompt)
209
- return None
210
-
211
 
212
  with gr.Blocks(theme="gradio/soft") as demo:
213
 
@@ -239,10 +215,12 @@ with gr.Blocks(theme="gradio/soft") as demo:
239
  submit_btn = gr.Button(value="Submit")
240
  charred_inputs = [prompt_input_charr]
241
  submit_btn.click(
242
- fn=infer_charred, inputs=charred_inputs, outputs=[charred_output]
 
 
243
  )
244
  # examples = [["postage stamp from california", "low quality", "charr_output.png", "charr_output.png" ]]
245
  # gr.Examples(fn = infer_sd, inputs = ["text", "text", "image", "image"], examples=examples, cache_examples=True)
246
 
247
  demo.queue(concurrency_count=1)
248
- demo.launch(debug=True, show_error=True, quiet=False)
 
16
 
17
  from transformers import ByT5Tokenizer, FlaxT5ForConditionalGeneration
18
 
19
+
20
  def get_inference_lambda(seed):
21
 
22
  tokenizer = ByT5Tokenizer()
 
52
  "trained_betas": None,
53
  }
54
  )
55
+ timesteps = 20
56
  guidance_scale = jnp.array([7.5], dtype=jnp.float32)
57
 
58
  unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
 
69
 
70
  image_width = image_height = 256
71
 
72
+ # Generating latent shape
73
+ latent_shape = (
74
+ negative_prompt_text_encoder_hidden_states.shape[0],
75
+ unet.in_channels,
76
+ image_width // vae_scale_factor,
77
+ image_height // vae_scale_factor,
78
+ )
79
 
80
  def __tokenize_prompt(prompt: str):
81
 
 
85
  padding="max_length",
86
  truncation=True,
87
  return_tensors="jax",
88
+ ).input_ids
89
 
90
+ def __convert_image(image):
91
+ # create PIL image from JAX tensor converted to numpy
92
+ return Image.fromarray(np.asarray(image), mode="RGB")
 
 
 
 
93
 
94
  def __predict_image(tokenized_prompt: jnp.array):
95
 
 
102
  context = jnp.concatenate(
103
  [negative_prompt_text_encoder_hidden_states, text_encoder_hidden_states]
104
  )
 
 
 
 
 
 
 
 
105
 
106
  def ___timestep(step, step_args):
107
 
 
143
  scheduler_state, guided_unet_prediction_sample, t, latents
144
  ).to_tuple()
145
 
 
 
146
  return latents, scheduler_state
147
 
148
  # initialize scheduler state
149
  initial_scheduler_state = scheduler.set_timesteps(
150
  scheduler.create_state(), num_inference_steps=timesteps, shape=latent_shape
151
  )
 
152
 
153
  # initialize latents
154
  initial_latents = (
 
157
  )
158
  * initial_scheduler_state.init_noise_sigma
159
  )
 
160
 
161
  final_latents, _ = jax.lax.fori_loop(
162
  0, timesteps, ___timestep, (initial_latents, initial_scheduler_state)
163
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ vae_output = vae.apply(
166
+ {"params": vae_params},
167
+ 1 / vae.config.scaling_factor * final_latents,
168
+ method=vae.decode,
169
+ ).sample
170
+
171
+ # return 8 bit RGB image (width, height, rgb)
172
+ return (
173
+ ((vae_output / 2 + 0.5).transpose(0, 2, 3, 1).clip(0, 1) * 255)
174
+ .round()
175
+ .astype(jnp.uint8)[0]
176
+ )
177
 
178
+ jax_jit_compiled_predict_image = jax.jit(__predict_image)
179
 
180
  return lambda prompt: __convert_image(
181
+ jax_jit_compiled_predict_image(__tokenize_prompt(prompt))
182
  )
183
 
184
 
185
  generate_image_for_prompt = get_inference_lambda(87)
186
 
 
 
 
 
 
 
 
 
187
 
188
  with gr.Blocks(theme="gradio/soft") as demo:
189
 
 
215
  submit_btn = gr.Button(value="Submit")
216
  charred_inputs = [prompt_input_charr]
217
  submit_btn.click(
218
+ fn=generate_image_for_prompt,
219
+ inputs=charred_inputs,
220
+ outputs=[charred_output],
221
  )
222
  # examples = [["postage stamp from california", "low quality", "charr_output.png", "charr_output.png" ]]
223
  # gr.Examples(fn = infer_sd, inputs = ["text", "text", "image", "image"], examples=examples, cache_examples=True)
224
 
225
  demo.queue(concurrency_count=1)
226
+ demo.launch(debug=True, show_error=True)