Pie31415 committed
Commit
527b597
1 Parent(s): ac67678
app.py CHANGED
@@ -39,7 +39,7 @@ Our code uses <a href="https://www.humphreyshi.com/home">Text2Video-Zero</a> and
 notice = """
 <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
 <br/>
-<a href="https://github.com/Pie31415/control-animation">
+<a href="https://huggingface.co/spaces/Pie31415/control-animation?duplicate=true">
 <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
 </p>
 """
@@ -51,13 +51,8 @@ with gr.Blocks(css="style.css") as demo:
     if on_huggingspace:
         gr.HTML(notice)
 
-    # NOTE: In our final demo we should consider removing zero-shot t2v and pose conditional
     with gr.Tab("Control Animation"):
         create_demo_animation(model)
-    # with gr.Tab("Zero-Shot Text2Video"):
-    #     create_demo_text_to_video(model)
-    # with gr.Tab("Pose Conditional"):
-    #     create_demo_pose(model)
 
 if on_huggingspace:
     demo.queue(max_size=20)
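For reference, these two hunks leave app.py with a single "Control Animation" tab and a Duplicate-Space banner that now points at the Hugging Face Space instead of GitHub. A minimal sketch of the resulting layout, assuming notice, model, create_demo_animation and on_huggingspace are defined earlier in app.py as in the surrounding context:

    import gradio as gr

    with gr.Blocks(css="style.css") as demo:
        if on_huggingspace:
            gr.HTML(notice)  # "Duplicate Space" banner defined above
        with gr.Tab("Control Animation"):
            create_demo_animation(model)

    if on_huggingspace:
        demo.queue(max_size=20)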
text_to_animation/model.py CHANGED
@@ -3,7 +3,6 @@ from enum import Enum
 import gc
 import numpy as np
 import jax.numpy as jnp
-import tomesd
 import jax
 
 from PIL import Image
@@ -20,9 +19,12 @@ from diffusers import (
     FlaxAutoencoderKL,
     FlaxStableDiffusionControlNetPipeline,
     StableDiffusionPipeline,
 )
-from text_to_animation.models.unet_2d_condition_flax import FlaxUNet2DConditionModel
-from text_to_animation.models.controlnet_flax import FlaxControlNetModel
 
 from text_to_animation.pipelines.text_to_video_pipeline_flax import (
     FlaxTextToVideoPipeline,
@@ -48,37 +50,31 @@ def replicate_devices(array):
 
 
 class ControlAnimationModel:
-    def __init__(self, device, dtype, **kwargs):
-        self.device = device
         self.dtype = dtype
         self.rng = jax.random.PRNGKey(0)
-        self.pipe_dict = {
-            ModelType.Text2Video: FlaxTextToVideoPipeline,  # TODO: Replace with our TextToVideo JAX Pipeline
-            ModelType.ControlNetPose: FlaxStableDiffusionControlNetPipeline,
-        }
         self.pipe = None
         self.model_type = None
 
         self.states = {}
         self.model_name = ""
 
-        self.from_local = True  # if the attn model is available in local (after adaptation by adapt_attn.py)
-
     def set_model(
         self,
-        model_type: ModelType,
         model_id: str,
-        controlnet,
-        controlnet_params,
-        tokenizer,
-        scheduler,
-        scheduler_state,
         **kwargs,
     ):
         if hasattr(self, "pipe") and self.pipe is not None:
             del self.pipe
             self.pipe = None
             gc.collect()
         scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
             model_id, subfolder="scheduler", from_pt=True
         )
@@ -86,17 +82,12 @@ class ControlAnimationModel:
         feature_extractor = CLIPFeatureExtractor.from_pretrained(
             model_id, subfolder="feature_extractor"
         )
-        if self.from_local:
-            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-                f'./{model_id.split("/")[-1]}',
-                subfolder="unet",
-                from_pt=True,
-                dtype=self.dtype,
-            )
-        else:
-            unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-                model_id, subfolder="unet", from_pt=True, dtype=self.dtype
-            )
         vae, vae_params = FlaxAutoencoderKL.from_pretrained(
             model_id, subfolder="vae", from_pt=True, dtype=self.dtype
         )
@@ -108,6 +99,7 @@ class ControlAnimationModel:
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
             safety_checker=None,
@@ -121,313 +113,52 @@ class ControlAnimationModel:
121
  "text_encoder": text_encoder.params,
122
  }
123
  self.p_params = jax_utils.replicate(self.params)
124
-
125
- self.model_type = model_type
126
  self.model_name = model_id
127
 
128
- # def inference_chunk(self, image, frame_ids, prompt, negative_prompt, **kwargs):
129
-
130
- # prompt_ids = self.pipe.prepare_text_inputs(prompt)
131
- # n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
132
- # latents = kwargs.pop('latents')
133
- # # rng = jax.random.split(self.rng, jax.device_count())
134
- # prng, self.rng = jax.random.split(self.rng)
135
- # #prng = jax.numpy.stack([prng] * jax.device_count())#same prng seed on every device
136
- # prng_seed = jax.random.split(prng, jax.device_count())
137
- # image = replicate_devices(image[frame_ids])
138
- # latents = replicate_devices(latents)
139
- # prompt_ids = replicate_devices(prompt_ids)
140
- # n_prompt_ids = replicate_devices(n_prompt_ids)
141
- # return (self.pipe(image=image,
142
- # latents=latents,
143
- # prompt_ids=prompt_ids,
144
- # neg_prompt_ids=n_prompt_ids,
145
- # params=self.p_params,
146
- # prng_seed=prng_seed, jit = True,
147
- # ).images)[0]
148
-
149
- def inference(self, image, split_to_chunks=False, chunk_size=8, **kwargs):
150
- if not hasattr(self, "pipe") or self.pipe is None:
151
- return
152
-
153
- if "merging_ratio" in kwargs:
154
- merging_ratio = kwargs.pop("merging_ratio")
155
-
156
- # if merging_ratio > 0:
157
- tomesd.apply_patch(self.pipe, ratio=merging_ratio)
158
-
159
- # f = image.shape[0]
160
-
161
- assert "prompt" in kwargs
162
- prompt = [kwargs.pop("prompt")]
163
- negative_prompt = [kwargs.pop("negative_prompt", "")]
164
-
165
- frames_counter = 0
166
-
167
- # Processing chunk-by-chunk
168
- if split_to_chunks:
169
- pass
170
- # # not tested
171
- # f = image.shape[0]
172
- # chunk_ids = np.arange(0, f, chunk_size - 1)
173
- # result = []
174
- # for i in range(len(chunk_ids)):
175
- # ch_start = chunk_ids[i]
176
- # ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
177
- # frame_ids = [0] + list(range(ch_start, ch_end))
178
- # print(f'Processing chunk {i + 1} / {len(chunk_ids)}')
179
- # result.append(self.inference_chunk(image=image,
180
- # frame_ids=frame_ids,
181
- # prompt=prompt,
182
- # negative_prompt=negative_prompt,
183
- # **kwargs).images[1:])
184
- # frames_counter += len(chunk_ids)-1
185
- # if on_huggingspace and frames_counter >= 80:
186
- # break
187
- # result = np.concatenate(result)
188
- # return result
189
- else:
190
- if "jit" in kwargs and kwargs.pop("jit"):
191
- prompt_ids = self.pipe.prepare_text_inputs(prompt)
192
- n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
193
- latents = kwargs.pop("latents")
194
- prng, self.rng = jax.random.split(self.rng)
195
- prng_seed = jax.random.split(prng, jax.device_count())
196
- image = replicate_devices(image)
197
- latents = replicate_devices(latents)
198
- prompt_ids = replicate_devices(prompt_ids)
199
- n_prompt_ids = replicate_devices(n_prompt_ids)
200
- return (
201
- self.pipe(
202
- image=image,
203
- latents=latents,
204
- prompt_ids=prompt_ids,
205
- neg_prompt_ids=n_prompt_ids,
206
- params=self.p_params,
207
- prng_seed=prng_seed,
208
- jit=True,
209
- ).images
210
- )[0]
211
- else:
212
- prompt_ids = self.pipe.prepare_text_inputs(prompt)
213
- n_prompt_ids = self.pipe.prepare_text_inputs(negative_prompt)
214
- latents = kwargs.pop("latents")
215
- prng_seed, self.rng = jax.random.split(self.rng)
216
- return self.pipe(
217
- image=image,
218
- latents=latents,
219
- prompt_ids=prompt_ids,
220
- neg_prompt_ids=n_prompt_ids,
221
- params=self.params,
222
- prng_seed=prng_seed,
223
- jit=False,
224
- ).images
225
-
226
- def process_controlnet_pose(
227
  self,
228
- video_path,
229
- prompt,
230
- chunk_size=8,
231
- watermark="Picsart AI Research",
232
- merging_ratio=0.0,
233
- num_inference_steps=20,
234
- controlnet_conditioning_scale=1.0,
235
- guidance_scale=9.0,
236
- seed=42,
237
- eta=0.0,
238
- resolution=512,
239
- use_cf_attn=True,
240
- save_path=None,
241
- ):
242
- print("Module Pose")
243
  video_path = gradio_utils.motion_to_video_path(video_path)
244
- if self.model_type != ModelType.ControlNetPose:
245
- controlnet = FlaxControlNetModel.from_pretrained(
246
- "fusing/stable-diffusion-v1-5-controlnet-openpose"
247
- )
248
- self.set_model(
249
- ModelType.ControlNetPose,
250
- model_id="runwayml/stable-diffusion-v1-5",
251
- controlnet=controlnet,
252
- )
253
- self.pipe.scheduler = FlaxDDIMScheduler.from_config(
254
- self.pipe.scheduler.config
255
- )
256
- if use_cf_attn:
257
- self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
258
- self.pipe.controlnet.set_attn_processor(
259
- processor=self.controlnet_attn_proc
260
- )
261
 
262
- video_path = (
263
- gradio_utils.motion_to_video_path(video_path)
264
- if "Motion" in video_path
265
- else video_path
266
- )
267
 
268
- added_prompt = "best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth"
269
- negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
270
 
271
  video, fps = utils.prepare_video(
272
- video_path, resolution, self.device, self.dtype, False, output_fps=4
273
- )
274
- control = (
275
- utils.pre_process_pose(video, apply_pose_detect=False)
276
- .to(self.device)
277
- .to(self.dtype)
278
  )
279
- f, _, h, w = video.shape
280
- self.generator.manual_seed(seed)
281
- latents = torch.randn(
282
- (1, 4, h // 8, w // 8),
283
- dtype=self.dtype,
284
- device=self.device,
285
- generator=self.generator,
286
- )
287
- latents = latents.repeat(f, 1, 1, 1)
288
- result = self.inference(
289
- image=control,
290
- prompt=prompt + ", " + added_prompt,
291
- height=h,
292
- width=w,
293
- negative_prompt=negative_prompts,
294
- num_inference_steps=num_inference_steps,
295
- guidance_scale=guidance_scale,
296
- controlnet_conditioning_scale=controlnet_conditioning_scale,
297
- eta=eta,
298
- latents=latents,
299
- seed=seed,
300
- output_type="numpy",
301
- split_to_chunks=True,
302
- chunk_size=chunk_size,
303
- merging_ratio=merging_ratio,
304
  )
305
- return utils.create_gif(
306
- result,
307
- fps,
308
- path=save_path,
309
- watermark=gradio_utils.logo_name_to_path(watermark),
310
- )
311
-
312
- def process_text2video(
313
- self,
314
- prompt,
315
- model_name="dreamlike-art/dreamlike-photoreal-2.0",
316
- motion_field_strength_x=12,
317
- motion_field_strength_y=12,
318
- t0=44,
319
- t1=47,
320
- n_prompt="",
321
- chunk_size=8,
322
- video_length=8,
323
- watermark="Picsart AI Research",
324
- merging_ratio=0.0,
325
- seed=0,
326
- resolution=512,
327
- fps=2,
328
- use_cf_attn=True,
329
- use_motion_field=True,
330
- smooth_bg=False,
331
- smooth_bg_strength=0.4,
332
- path=None,
333
- ):
334
- print("Module Text2Video")
335
- if self.model_type != ModelType.Text2Video or model_name != self.model_name:
336
- print("Model update")
337
- unet = FlaxUNet2DConditionModel.from_pretrained(
338
- model_name, subfolder="unet"
339
- )
340
- self.set_model(ModelType.Text2Video, model_id=model_name, unet=unet)
341
- self.pipe.scheduler = FlaxDDIMScheduler.from_config(
342
- self.pipe.scheduler.config
343
- )
344
- if use_cf_attn:
345
- self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
346
- self.generator.manual_seed(seed)
347
 
348
- added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
349
- negative_prompts = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
350
-
351
- prompt = prompt.rstrip()
352
- if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."):
353
- prompt = prompt.rstrip()[:-1]
354
- prompt = prompt.rstrip()
355
- prompt = prompt + ", " + added_prompt
356
- if len(n_prompt) > 0:
357
- negative_prompt = n_prompt
358
- else:
359
- negative_prompt = None
360
-
361
- result = self.inference(
362
- prompt=prompt,
363
- video_length=video_length,
364
- height=resolution,
365
- width=resolution,
366
- num_inference_steps=50,
367
- guidance_scale=7.5,
368
- guidance_stop_step=1.0,
369
- t0=t0,
370
- t1=t1,
371
- motion_field_strength_x=motion_field_strength_x,
372
- motion_field_strength_y=motion_field_strength_y,
373
- use_motion_field=use_motion_field,
374
- smooth_bg=smooth_bg,
375
- smooth_bg_strength=smooth_bg_strength,
376
- seed=seed,
377
- output_type="numpy",
378
- negative_prompt=negative_prompt,
379
- merging_ratio=merging_ratio,
380
- split_to_chunks=True,
381
- chunk_size=chunk_size,
382
- )
383
- return utils.create_video(
384
- result, fps, path=path, watermark=gradio_utils.logo_name_to_path(watermark)
385
- )
386
 
387
- @staticmethod
388
- def to_pil_images(images: torch.Tensor) -> List[Image.Image]:
389
- images = (images / 2 + 0.5).clamp(0, 1)
390
- images = images.cpu().permute(0, 2, 3, 1).float().numpy()
391
- images = np.round(images * 255).astype(np.uint8)
392
- return [Image.fromarray(image) for image in images]
393
-
394
- def generate_initial_frames(
395
- self,
396
- prompt: str,
397
- model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
398
- is_safetensor: bool = False,
399
- n_prompt: str = "",
400
- width: int = 512,
401
- height: int = 512,
402
- # batch_count: int = 4,
403
- # batch_size: int = 1,
404
- cfg_scale: float = 7.0,
405
- seed: int = 0,
406
- ) -> List[Image.Image]:
407
- generator = torch.Generator(device=self.device).manual_seed(seed)
408
- pipe = StableDiffusionPipeline.from_pretrained(model_link)
409
-
410
- batch_size = 4
411
- prompt = [prompt] * batch_size
412
- negative_prompt = [n_prompt] * batch_size
413
-
414
- images = pipe(
415
- prompt,
416
- negative_prompt=negative_prompt,
417
- width=width,
418
- height=height,
419
- guidance_scale=cfg_scale,
420
- generator=generator,
421
- ).images
422
- pil_images = self.to_pil_images(images)
423
-
424
- return pil_images
425
 
426
  def generate_animation(
427
  self,
428
  prompt: str,
 
 
429
  model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
430
- is_safetensor: bool = False,
431
  motion_field_strength_x: int = 12,
432
  motion_field_strength_y: int = 12,
433
  t0: int = 44,
@@ -445,6 +176,29 @@ class ControlAnimationModel:
445
  smooth_bg_strength: float = 0.4,
446
  path: str = None,
447
  ):
448
- if is_safetensor and model_link[-len(".safetensors") :] == ".safetensors":
449
- pipe = utils.load_safetensors_model(model_link)
450
- return
 import gc
 import numpy as np
 import jax.numpy as jnp
 import jax
 
 from PIL import Image
     FlaxAutoencoderKL,
     FlaxStableDiffusionControlNetPipeline,
     StableDiffusionPipeline,
+    FlaxUNet2DConditionModel,
 )
+from text_to_animation.models.unet_2d_condition_flax import (
+    FlaxUNet2DConditionModel as CustomFlaxUNet2DConditionModel,
+)
+from diffusers import FlaxControlNetModel
 
 from text_to_animation.pipelines.text_to_video_pipeline_flax import (
     FlaxTextToVideoPipeline,
 
 
 class ControlAnimationModel:
+    def __init__(self, dtype, **kwargs):
         self.dtype = dtype
         self.rng = jax.random.PRNGKey(0)
         self.pipe = None
         self.model_type = None
 
         self.states = {}
         self.model_name = ""
 
     def set_model(
         self,
         model_id: str,
         **kwargs,
     ):
         if hasattr(self, "pipe") and self.pipe is not None:
             del self.pipe
             self.pipe = None
             gc.collect()
+
+        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+            "fusing/stable-diffusion-v1-5-controlnet-openpose",
+            from_pt=True,
+            dtype=jnp.float16,
+        )
+
         scheduler, scheduler_state = FlaxDDIMScheduler.from_pretrained(
             model_id, subfolder="scheduler", from_pt=True
         )
         feature_extractor = CLIPFeatureExtractor.from_pretrained(
             model_id, subfolder="feature_extractor"
         )
+        unet, unet_params = CustomFlaxUNet2DConditionModel.from_pretrained(
+            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
+        )
+        unet_vanilla, _ = FlaxUNet2DConditionModel.from_pretrained(
+            model_id, subfolder="unet", from_pt=True, dtype=self.dtype
+        )
         vae, vae_params = FlaxAutoencoderKL.from_pretrained(
             model_id, subfolder="vae", from_pt=True, dtype=self.dtype
         )
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             unet=unet,
+            unet_vanilla=unet_vanilla,
             controlnet=controlnet,
             scheduler=scheduler,
             safety_checker=None,
             "text_encoder": text_encoder.params,
         }
         self.p_params = jax_utils.replicate(self.params)
         self.model_name = model_id
 
+    def generate_initial_frames(
         self,
+        prompt: str,
+        video_path: str,
+        n_prompt: str = "",
+        num_imgs: int = 4,
+        resolution: int = 512,
+        model_id: str = "runwayml/stable-diffusion-v1-5",
+    ) -> List[Image.Image]:
+        self.set_model(model_id=model_id)
+
         video_path = gradio_utils.motion_to_video_path(video_path)
 
+        added_prompt = "high quality, best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth"
+        prompts = added_prompt + ", " + prompt
 
+        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly"
+        negative_prompts = added_n_prompt + ", " + n_prompt
 
         video, fps = utils.prepare_video(
+            video_path, resolution, None, self.dtype, False, output_fps=4
         )
+        control = utils.pre_process_pose(video, apply_pose_detect=False)
+
+        seeds = [seed for seed in jax.random.randint(self.rng, [num_imgs], 0, 65536)]
+        prngs = [jax.random.PRNGKey(seed) for seed in seeds]
+        images = self.pipe.generate_starting_frames(
+            params=self.params,
+            prngs=prngs,
+            controlnet_image=control,
+            prompt=prompts,
+            neg_prompt=negative_prompts,
         )
 
+        images = [np.array(images[i]) for i in range(images.shape[0])]
 
+        return images
 
     def generate_animation(
         self,
         prompt: str,
+        initial_frame_index: int,
+        input_video_path: str,
         model_link: str = "dreamlike-art/dreamlike-photoreal-2.0",
         motion_field_strength_x: int = 12,
         motion_field_strength_y: int = 12,
         t0: int = 44,
         smooth_bg_strength: float = 0.4,
         path: str = None,
     ):
+        video_path = gradio_utils.motion_to_video_path(video_path)
+
+        # added_prompt = 'best quality, HD, clay stop-motion, claymation, HQ, masterpiece, art, smooth'
+        # added_prompt = 'high quality, anatomically correct, clay stop-motion, aardman, claymation, smooth'
+        added_n_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly"
+        negative_prompts = added_n_prompt + ", " + n_prompt
+
+        video, fps = utils.prepare_video(
+            video_path, resolution, None, self.dtype, False, output_fps=4
+        )
+        control = utils.pre_process_pose(video, apply_pose_detect=False)
+        f, _, h, w = video.shape
+
+        prng_seed = jax.random.PRNGKey(seed)
+        vid = self.pipe.generate_video(
+            prompt,
+            image=control,
+            params=self.params,
+            prng_seed=prng_seed,
+            neg_prompt="",
+            controlnet_conditioning_scale=1.0,
+            motion_field_strength_x=3,
+            motion_field_strength_y=4,
+            jit=True,
+        ).image
+        return utils.create_gif(np.array(vid), 4, path=None, watermark=None)
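Taken together, the model.py changes drop the torch/tomesd code paths and leave a slim JAX/Flax interface: set_model assembles the ControlNet, UNet and DDIM scheduler, generate_initial_frames proposes candidate first frames from a pose video, and generate_animation renders the clip. A hypothetical driver for the new interface, sketched only for illustration (prompt and motion-preset strings are placeholders, and it assumes the project's Flax pipeline methods generate_starting_frames/generate_video behave as they are used above):

    import jax.numpy as jnp
    from text_to_animation.model import ControlAnimationModel

    # The constructor now takes only a dtype; device placement is left to JAX.
    model = ControlAnimationModel(dtype=jnp.float16)

    # Propose candidate first frames conditioned on an OpenPose motion preset.
    frames = model.generate_initial_frames(
        prompt="an astronaut dancing on the moon",  # placeholder prompt
        video_path="Motion 1",                      # placeholder motion preset
        num_imgs=4,
    )

    # Animate from one of the proposed frames. Note that, as committed,
    # generate_animation's body still reads the old `video_path` name rather
    # than `input_video_path`, so this call is illustrative only.
    gif = model.generate_animation(
        prompt="an astronaut dancing on the moon",
        initial_frame_index=0,
        input_video_path="Motion 1",
    )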
text_to_animation/models/cross_frame_attention_flax.py CHANGED
@@ -50,7 +50,6 @@ class FlaxCrossFrameAttention(nn.Module):
         batch_size: The number that represents actual batch size, other than the frames.
             For example, using calling unet with a single prompt and num_images_per_prompt=1, batch_size should be
             equal to 2, due to classifier-free guidance.
-
     """
     query_dim: int
     heads: int = 8
text_to_animation/models/unet_3d_blocks_flax.py ADDED
@@ -0,0 +1,717 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ from torch import nn
17
+
18
+ # from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
19
+ # from diffusers.models.transformer_2d import Transformer2DModel
20
+ # from .transformer_temporal import TransformerTemporalModel
21
+
22
+ from diffusers.models.resnet_flax import (
23
+ FlaxDownsample2D,
24
+ FlaxResnetBlock2D,
25
+ FlaxUpsample2D,
26
+ )
27
+ from diffusers.models.attention_flax import FlaxTransformer2DModel
28
+ from diffusers.models.transformer_temporal import (
29
+ TransformerTemporalModel,
30
+ ) # TODO: convert to flax
31
+
32
+
33
+ def get_down_block(
34
+ down_block_type,
35
+ num_layers,
36
+ in_channels,
37
+ out_channels,
38
+ temb_channels,
39
+ add_downsample,
40
+ resnet_eps,
41
+ resnet_act_fn,
42
+ attn_num_head_channels,
43
+ resnet_groups=None,
44
+ cross_attention_dim=None,
45
+ downsample_padding=None,
46
+ dual_cross_attention=False,
47
+ use_linear_projection=True,
48
+ only_cross_attention=False,
49
+ upcast_attention=False,
50
+ resnet_time_scale_shift="default",
51
+ ):
52
+ if down_block_type == "DownBlock3D":
53
+ return DownBlock3D(
54
+ num_layers=num_layers,
55
+ in_channels=in_channels,
56
+ out_channels=out_channels,
57
+ temb_channels=temb_channels,
58
+ add_downsample=add_downsample,
59
+ resnet_eps=resnet_eps,
60
+ resnet_act_fn=resnet_act_fn,
61
+ resnet_groups=resnet_groups,
62
+ downsample_padding=downsample_padding,
63
+ resnet_time_scale_shift=resnet_time_scale_shift,
64
+ )
65
+ elif down_block_type == "CrossAttnDownBlock3D":
66
+ if cross_attention_dim is None:
67
+ raise ValueError(
68
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
69
+ )
70
+ return CrossAttnDownBlock3D(
71
+ num_layers=num_layers,
72
+ in_channels=in_channels,
73
+ out_channels=out_channels,
74
+ temb_channels=temb_channels,
75
+ add_downsample=add_downsample,
76
+ resnet_eps=resnet_eps,
77
+ resnet_act_fn=resnet_act_fn,
78
+ resnet_groups=resnet_groups,
79
+ downsample_padding=downsample_padding,
80
+ cross_attention_dim=cross_attention_dim,
81
+ attn_num_head_channels=attn_num_head_channels,
82
+ dual_cross_attention=dual_cross_attention,
83
+ use_linear_projection=use_linear_projection,
84
+ only_cross_attention=only_cross_attention,
85
+ upcast_attention=upcast_attention,
86
+ resnet_time_scale_shift=resnet_time_scale_shift,
87
+ )
88
+ raise ValueError(f"{down_block_type} does not exist.")
89
+
90
+
91
+ def get_up_block(
92
+ up_block_type,
93
+ num_layers,
94
+ in_channels,
95
+ out_channels,
96
+ prev_output_channel,
97
+ temb_channels,
98
+ add_upsample,
99
+ resnet_eps,
100
+ resnet_act_fn,
101
+ attn_num_head_channels,
102
+ resnet_groups=None,
103
+ cross_attention_dim=None,
104
+ dual_cross_attention=False,
105
+ use_linear_projection=True,
106
+ only_cross_attention=False,
107
+ upcast_attention=False,
108
+ resnet_time_scale_shift="default",
109
+ ):
110
+ if up_block_type == "UpBlock3D":
111
+ return UpBlock3D(
112
+ num_layers=num_layers,
113
+ in_channels=in_channels,
114
+ out_channels=out_channels,
115
+ prev_output_channel=prev_output_channel,
116
+ temb_channels=temb_channels,
117
+ add_upsample=add_upsample,
118
+ resnet_eps=resnet_eps,
119
+ resnet_act_fn=resnet_act_fn,
120
+ resnet_groups=resnet_groups,
121
+ resnet_time_scale_shift=resnet_time_scale_shift,
122
+ )
123
+ elif up_block_type == "CrossAttnUpBlock3D":
124
+ if cross_attention_dim is None:
125
+ raise ValueError(
126
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
127
+ )
128
+ return CrossAttnUpBlock3D(
129
+ num_layers=num_layers,
130
+ in_channels=in_channels,
131
+ out_channels=out_channels,
132
+ prev_output_channel=prev_output_channel,
133
+ temb_channels=temb_channels,
134
+ add_upsample=add_upsample,
135
+ resnet_eps=resnet_eps,
136
+ resnet_act_fn=resnet_act_fn,
137
+ resnet_groups=resnet_groups,
138
+ cross_attention_dim=cross_attention_dim,
139
+ attn_num_head_channels=attn_num_head_channels,
140
+ dual_cross_attention=dual_cross_attention,
141
+ use_linear_projection=use_linear_projection,
142
+ only_cross_attention=only_cross_attention,
143
+ upcast_attention=upcast_attention,
144
+ resnet_time_scale_shift=resnet_time_scale_shift,
145
+ )
146
+ raise ValueError(f"{up_block_type} does not exist.")
147
+
148
+
149
+ class FlaxUNetMidBlock3DCrossAttn(nn.Module):
150
+ def __init__(
151
+ self,
152
+ in_channels: int,
153
+ temb_channels: int,
154
+ dropout: float = 0.0,
155
+ num_layers: int = 1,
156
+ resnet_eps: float = 1e-6,
157
+ resnet_time_scale_shift: str = "default",
158
+ resnet_act_fn: str = "swish",
159
+ resnet_groups: int = 32,
160
+ resnet_pre_norm: bool = True,
161
+ attn_num_head_channels=1,
162
+ output_scale_factor=1.0,
163
+ cross_attention_dim=1280,
164
+ dual_cross_attention=False,
165
+ use_linear_projection=True,
166
+ upcast_attention=False,
167
+ ):
168
+ super().__init__()
169
+
170
+ self.has_cross_attention = True
171
+ self.attn_num_head_channels = attn_num_head_channels
172
+ resnet_groups = (
173
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
174
+ )
175
+
176
+ # there is always at least one resnet
177
+ resnets = [
178
+ FlaxResnetBlock2D(
179
+ in_channels=in_channels,
180
+ out_channels=in_channels,
181
+ temb_channels=temb_channels,
182
+ eps=resnet_eps,
183
+ groups=resnet_groups,
184
+ dropout=dropout,
185
+ time_embedding_norm=resnet_time_scale_shift,
186
+ non_linearity=resnet_act_fn,
187
+ output_scale_factor=output_scale_factor,
188
+ pre_norm=resnet_pre_norm,
189
+ )
190
+ ]
191
+ temp_convs = [
192
+ TemporalConvLayer(
193
+ in_channels,
194
+ in_channels,
195
+ dropout=0.1,
196
+ )
197
+ ]
198
+ attentions = []
199
+ temp_attentions = []
200
+
201
+ for _ in range(num_layers):
202
+ attentions.append(
203
+ Transformer2DModel(
204
+ in_channels // attn_num_head_channels,
205
+ attn_num_head_channels,
206
+ in_channels=in_channels,
207
+ num_layers=1,
208
+ cross_attention_dim=cross_attention_dim,
209
+ norm_num_groups=resnet_groups,
210
+ use_linear_projection=use_linear_projection,
211
+ upcast_attention=upcast_attention,
212
+ )
213
+ )
214
+ temp_attentions.append(
215
+ TransformerTemporalModel(
216
+ in_channels // attn_num_head_channels,
217
+ attn_num_head_channels,
218
+ in_channels=in_channels,
219
+ num_layers=1,
220
+ cross_attention_dim=cross_attention_dim,
221
+ norm_num_groups=resnet_groups,
222
+ )
223
+ )
224
+ resnets.append(
225
+ ResnetBlock2D(
226
+ in_channels=in_channels,
227
+ out_channels=in_channels,
228
+ temb_channels=temb_channels,
229
+ eps=resnet_eps,
230
+ groups=resnet_groups,
231
+ dropout=dropout,
232
+ time_embedding_norm=resnet_time_scale_shift,
233
+ non_linearity=resnet_act_fn,
234
+ output_scale_factor=output_scale_factor,
235
+ pre_norm=resnet_pre_norm,
236
+ )
237
+ )
238
+ temp_convs.append(
239
+ TemporalConvLayer(
240
+ in_channels,
241
+ in_channels,
242
+ dropout=0.1,
243
+ )
244
+ )
245
+
246
+ self.resnets = nn.ModuleList(resnets)
247
+ self.temp_convs = nn.ModuleList(temp_convs)
248
+ self.attentions = nn.ModuleList(attentions)
249
+ self.temp_attentions = nn.ModuleList(temp_attentions)
250
+
251
+ def forward(
252
+ self,
253
+ hidden_states,
254
+ temb=None,
255
+ encoder_hidden_states=None,
256
+ attention_mask=None,
257
+ num_frames=1,
258
+ cross_attention_kwargs=None,
259
+ ):
260
+ hidden_states = self.resnets[0](hidden_states, temb)
261
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
262
+ for attn, temp_attn, resnet, temp_conv in zip(
263
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
264
+ ):
265
+ hidden_states = attn(
266
+ hidden_states,
267
+ encoder_hidden_states=encoder_hidden_states,
268
+ cross_attention_kwargs=cross_attention_kwargs,
269
+ ).sample
270
+ hidden_states = temp_attn(
271
+ hidden_states,
272
+ num_frames=num_frames,
273
+ cross_attention_kwargs=cross_attention_kwargs,
274
+ ).sample
275
+ hidden_states = resnet(hidden_states, temb)
276
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
277
+
278
+ return hidden_states
279
+
280
+
281
+ class CrossAttnDownBlock3D(nn.Module):
282
+ def __init__(
283
+ self,
284
+ in_channels: int,
285
+ out_channels: int,
286
+ temb_channels: int,
287
+ dropout: float = 0.0,
288
+ num_layers: int = 1,
289
+ resnet_eps: float = 1e-6,
290
+ resnet_time_scale_shift: str = "default",
291
+ resnet_act_fn: str = "swish",
292
+ resnet_groups: int = 32,
293
+ resnet_pre_norm: bool = True,
294
+ attn_num_head_channels=1,
295
+ cross_attention_dim=1280,
296
+ output_scale_factor=1.0,
297
+ downsample_padding=1,
298
+ add_downsample=True,
299
+ dual_cross_attention=False,
300
+ use_linear_projection=False,
301
+ only_cross_attention=False,
302
+ upcast_attention=False,
303
+ ):
304
+ super().__init__()
305
+ resnets = []
306
+ attentions = []
307
+ temp_attentions = []
308
+ temp_convs = []
309
+
310
+ self.has_cross_attention = True
311
+ self.attn_num_head_channels = attn_num_head_channels
312
+
313
+ for i in range(num_layers):
314
+ in_channels = in_channels if i == 0 else out_channels
315
+ resnets.append(
316
+ ResnetBlock2D(
317
+ in_channels=in_channels,
318
+ out_channels=out_channels,
319
+ temb_channels=temb_channels,
320
+ eps=resnet_eps,
321
+ groups=resnet_groups,
322
+ dropout=dropout,
323
+ time_embedding_norm=resnet_time_scale_shift,
324
+ non_linearity=resnet_act_fn,
325
+ output_scale_factor=output_scale_factor,
326
+ pre_norm=resnet_pre_norm,
327
+ )
328
+ )
329
+ temp_convs.append(
330
+ TemporalConvLayer(
331
+ out_channels,
332
+ out_channels,
333
+ dropout=0.1,
334
+ )
335
+ )
336
+ attentions.append(
337
+ Transformer2DModel(
338
+ out_channels // attn_num_head_channels,
339
+ attn_num_head_channels,
340
+ in_channels=out_channels,
341
+ num_layers=1,
342
+ cross_attention_dim=cross_attention_dim,
343
+ norm_num_groups=resnet_groups,
344
+ use_linear_projection=use_linear_projection,
345
+ only_cross_attention=only_cross_attention,
346
+ upcast_attention=upcast_attention,
347
+ )
348
+ )
349
+ temp_attentions.append(
350
+ TransformerTemporalModel(
351
+ out_channels // attn_num_head_channels,
352
+ attn_num_head_channels,
353
+ in_channels=out_channels,
354
+ num_layers=1,
355
+ cross_attention_dim=cross_attention_dim,
356
+ norm_num_groups=resnet_groups,
357
+ )
358
+ )
359
+ self.resnets = nn.ModuleList(resnets)
360
+ self.temp_convs = nn.ModuleList(temp_convs)
361
+ self.attentions = nn.ModuleList(attentions)
362
+ self.temp_attentions = nn.ModuleList(temp_attentions)
363
+
364
+ if add_downsample:
365
+ self.downsamplers = nn.ModuleList(
366
+ [
367
+ Downsample2D(
368
+ out_channels,
369
+ use_conv=True,
370
+ out_channels=out_channels,
371
+ padding=downsample_padding,
372
+ name="op",
373
+ )
374
+ ]
375
+ )
376
+ else:
377
+ self.downsamplers = None
378
+
379
+ self.gradient_checkpointing = False
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states,
384
+ temb=None,
385
+ encoder_hidden_states=None,
386
+ attention_mask=None,
387
+ num_frames=1,
388
+ cross_attention_kwargs=None,
389
+ ):
390
+ # TODO(Patrick, William) - attention mask is not used
391
+ output_states = ()
392
+
393
+ for resnet, temp_conv, attn, temp_attn in zip(
394
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
395
+ ):
396
+ hidden_states = resnet(hidden_states, temb)
397
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
398
+ hidden_states = attn(
399
+ hidden_states,
400
+ encoder_hidden_states=encoder_hidden_states,
401
+ cross_attention_kwargs=cross_attention_kwargs,
402
+ ).sample
403
+ hidden_states = temp_attn(
404
+ hidden_states,
405
+ num_frames=num_frames,
406
+ cross_attention_kwargs=cross_attention_kwargs,
407
+ ).sample
408
+
409
+ output_states += (hidden_states,)
410
+
411
+ if self.downsamplers is not None:
412
+ for downsampler in self.downsamplers:
413
+ hidden_states = downsampler(hidden_states)
414
+
415
+ output_states += (hidden_states,)
416
+
417
+ return hidden_states, output_states
418
+
419
+
420
+ class DownBlock3D(nn.Module):
421
+ def __init__(
422
+ self,
423
+ in_channels: int,
424
+ out_channels: int,
425
+ temb_channels: int,
426
+ dropout: float = 0.0,
427
+ num_layers: int = 1,
428
+ resnet_eps: float = 1e-6,
429
+ resnet_time_scale_shift: str = "default",
430
+ resnet_act_fn: str = "swish",
431
+ resnet_groups: int = 32,
432
+ resnet_pre_norm: bool = True,
433
+ output_scale_factor=1.0,
434
+ add_downsample=True,
435
+ downsample_padding=1,
436
+ ):
437
+ super().__init__()
438
+ resnets = []
439
+ temp_convs = []
440
+
441
+ for i in range(num_layers):
442
+ in_channels = in_channels if i == 0 else out_channels
443
+ resnets.append(
444
+ ResnetBlock2D(
445
+ in_channels=in_channels,
446
+ out_channels=out_channels,
447
+ temb_channels=temb_channels,
448
+ eps=resnet_eps,
449
+ groups=resnet_groups,
450
+ dropout=dropout,
451
+ time_embedding_norm=resnet_time_scale_shift,
452
+ non_linearity=resnet_act_fn,
453
+ output_scale_factor=output_scale_factor,
454
+ pre_norm=resnet_pre_norm,
455
+ )
456
+ )
457
+ temp_convs.append(
458
+ TemporalConvLayer(
459
+ out_channels,
460
+ out_channels,
461
+ dropout=0.1,
462
+ )
463
+ )
464
+
465
+ self.resnets = nn.ModuleList(resnets)
466
+ self.temp_convs = nn.ModuleList(temp_convs)
467
+
468
+ if add_downsample:
469
+ self.downsamplers = nn.ModuleList(
470
+ [
471
+ Downsample2D(
472
+ out_channels,
473
+ use_conv=True,
474
+ out_channels=out_channels,
475
+ padding=downsample_padding,
476
+ name="op",
477
+ )
478
+ ]
479
+ )
480
+ else:
481
+ self.downsamplers = None
482
+
483
+ self.gradient_checkpointing = False
484
+
485
+ def forward(self, hidden_states, temb=None, num_frames=1):
486
+ output_states = ()
487
+
488
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
489
+ hidden_states = resnet(hidden_states, temb)
490
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
491
+
492
+ output_states += (hidden_states,)
493
+
494
+ if self.downsamplers is not None:
495
+ for downsampler in self.downsamplers:
496
+ hidden_states = downsampler(hidden_states)
497
+
498
+ output_states += (hidden_states,)
499
+
500
+ return hidden_states, output_states
501
+
502
+
503
+ class CrossAttnUpBlock3D(nn.Module):
504
+ def __init__(
505
+ self,
506
+ in_channels: int,
507
+ out_channels: int,
508
+ prev_output_channel: int,
509
+ temb_channels: int,
510
+ dropout: float = 0.0,
511
+ num_layers: int = 1,
512
+ resnet_eps: float = 1e-6,
513
+ resnet_time_scale_shift: str = "default",
514
+ resnet_act_fn: str = "swish",
515
+ resnet_groups: int = 32,
516
+ resnet_pre_norm: bool = True,
517
+ attn_num_head_channels=1,
518
+ cross_attention_dim=1280,
519
+ output_scale_factor=1.0,
520
+ add_upsample=True,
521
+ dual_cross_attention=False,
522
+ use_linear_projection=False,
523
+ only_cross_attention=False,
524
+ upcast_attention=False,
525
+ ):
526
+ super().__init__()
527
+ resnets = []
528
+ temp_convs = []
529
+ attentions = []
530
+ temp_attentions = []
531
+
532
+ self.has_cross_attention = True
533
+ self.attn_num_head_channels = attn_num_head_channels
534
+
535
+ for i in range(num_layers):
536
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
537
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
538
+
539
+ resnets.append(
540
+ ResnetBlock2D(
541
+ in_channels=resnet_in_channels + res_skip_channels,
542
+ out_channels=out_channels,
543
+ temb_channels=temb_channels,
544
+ eps=resnet_eps,
545
+ groups=resnet_groups,
546
+ dropout=dropout,
547
+ time_embedding_norm=resnet_time_scale_shift,
548
+ non_linearity=resnet_act_fn,
549
+ output_scale_factor=output_scale_factor,
550
+ pre_norm=resnet_pre_norm,
551
+ )
552
+ )
553
+ temp_convs.append(
554
+ TemporalConvLayer(
555
+ out_channels,
556
+ out_channels,
557
+ dropout=0.1,
558
+ )
559
+ )
560
+ attentions.append(
561
+ Transformer2DModel(
562
+ out_channels // attn_num_head_channels,
563
+ attn_num_head_channels,
564
+ in_channels=out_channels,
565
+ num_layers=1,
566
+ cross_attention_dim=cross_attention_dim,
567
+ norm_num_groups=resnet_groups,
568
+ use_linear_projection=use_linear_projection,
569
+ only_cross_attention=only_cross_attention,
570
+ upcast_attention=upcast_attention,
571
+ )
572
+ )
573
+ temp_attentions.append(
574
+ TransformerTemporalModel(
575
+ out_channels // attn_num_head_channels,
576
+ attn_num_head_channels,
577
+ in_channels=out_channels,
578
+ num_layers=1,
579
+ cross_attention_dim=cross_attention_dim,
580
+ norm_num_groups=resnet_groups,
581
+ )
582
+ )
583
+ self.resnets = nn.ModuleList(resnets)
584
+ self.temp_convs = nn.ModuleList(temp_convs)
585
+ self.attentions = nn.ModuleList(attentions)
586
+ self.temp_attentions = nn.ModuleList(temp_attentions)
587
+
588
+ if add_upsample:
589
+ self.upsamplers = nn.ModuleList(
590
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
591
+ )
592
+ else:
593
+ self.upsamplers = None
594
+
595
+ self.gradient_checkpointing = False
596
+
597
+ def forward(
598
+ self,
599
+ hidden_states,
600
+ res_hidden_states_tuple,
601
+ temb=None,
602
+ encoder_hidden_states=None,
603
+ upsample_size=None,
604
+ attention_mask=None,
605
+ num_frames=1,
606
+ cross_attention_kwargs=None,
607
+ ):
608
+ # TODO(Patrick, William) - attention mask is not used
609
+ for resnet, temp_conv, attn, temp_attn in zip(
610
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
611
+ ):
612
+ # pop res hidden states
613
+ res_hidden_states = res_hidden_states_tuple[-1]
614
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
615
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
616
+
617
+ hidden_states = resnet(hidden_states, temb)
618
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
619
+ hidden_states = attn(
620
+ hidden_states,
621
+ encoder_hidden_states=encoder_hidden_states,
622
+ cross_attention_kwargs=cross_attention_kwargs,
623
+ ).sample
624
+ hidden_states = temp_attn(
625
+ hidden_states,
626
+ num_frames=num_frames,
627
+ cross_attention_kwargs=cross_attention_kwargs,
628
+ ).sample
629
+
630
+ if self.upsamplers is not None:
631
+ for upsampler in self.upsamplers:
632
+ hidden_states = upsampler(hidden_states, upsample_size)
633
+
634
+ return hidden_states
635
+
636
+
637
+ class UpBlock3D(nn.Module):
638
+ def __init__(
639
+ self,
640
+ in_channels: int,
641
+ prev_output_channel: int,
642
+ out_channels: int,
643
+ temb_channels: int,
644
+ dropout: float = 0.0,
645
+ num_layers: int = 1,
646
+ resnet_eps: float = 1e-6,
647
+ resnet_time_scale_shift: str = "default",
648
+ resnet_act_fn: str = "swish",
649
+ resnet_groups: int = 32,
650
+ resnet_pre_norm: bool = True,
651
+ output_scale_factor=1.0,
652
+ add_upsample=True,
653
+ ):
654
+ super().__init__()
655
+ resnets = []
656
+ temp_convs = []
657
+
658
+ for i in range(num_layers):
659
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
660
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
661
+
662
+ resnets.append(
663
+ ResnetBlock2D(
664
+ in_channels=resnet_in_channels + res_skip_channels,
665
+ out_channels=out_channels,
666
+ temb_channels=temb_channels,
667
+ eps=resnet_eps,
668
+ groups=resnet_groups,
669
+ dropout=dropout,
670
+ time_embedding_norm=resnet_time_scale_shift,
671
+ non_linearity=resnet_act_fn,
672
+ output_scale_factor=output_scale_factor,
673
+ pre_norm=resnet_pre_norm,
674
+ )
675
+ )
676
+ temp_convs.append(
677
+ TemporalConvLayer(
678
+ out_channels,
679
+ out_channels,
680
+ dropout=0.1,
681
+ )
682
+ )
683
+
684
+ self.resnets = nn.ModuleList(resnets)
685
+ self.temp_convs = nn.ModuleList(temp_convs)
686
+
687
+ if add_upsample:
688
+ self.upsamplers = nn.ModuleList(
689
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
690
+ )
691
+ else:
692
+ self.upsamplers = None
693
+
694
+ self.gradient_checkpointing = False
695
+
696
+ def forward(
697
+ self,
698
+ hidden_states,
699
+ res_hidden_states_tuple,
700
+ temb=None,
701
+ upsample_size=None,
702
+ num_frames=1,
703
+ ):
704
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
705
+ # pop res hidden states
706
+ res_hidden_states = res_hidden_states_tuple[-1]
707
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
708
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
709
+
710
+ hidden_states = resnet(hidden_states, temb)
711
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
712
+
713
+ if self.upsamplers is not None:
714
+ for upsampler in self.upsamplers:
715
+ hidden_states = upsampler(hidden_states, upsample_size)
716
+
717
+ return hidden_states
text_to_animation/models/unet_3d_condition_flax.py ADDED
@@ -0,0 +1,611 @@
1
+ # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2023 The ModelScope Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from ..configuration_utils import ConfigMixin, register_to_config
23
+ from ..loaders import UNet2DConditionLoadersMixin
24
+ from ..utils import BaseOutput, logging
25
+ from .attention_processor import AttentionProcessor, AttnProcessor
26
+ from .embeddings import TimestepEmbedding, Timesteps
27
+ from .modeling_utils import ModelMixin
28
+ from .transformer_temporal import TransformerTemporalModel
29
+ from .unet_3d_blocks import (
30
+ CrossAttnDownBlock3D,
31
+ CrossAttnUpBlock3D,
32
+ DownBlock3D,
33
+ UNetMidBlock3DCrossAttn,
34
+ UpBlock3D,
35
+ get_down_block,
36
+ get_up_block,
37
+ )
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+
43
+ @dataclass
44
+ class UNet3DConditionOutput(BaseOutput):
45
+ """
46
+ Args:
47
+ sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
48
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
49
+ """
50
+
51
+ sample: torch.FloatTensor
52
+
53
+
54
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
55
+ r"""
56
+ UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
57
+ and returns sample shaped output.
58
+
59
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
60
+ implements for all the models (such as downloading or saving, etc.)
61
+
62
+ Parameters:
63
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
64
+ Height and width of input/output sample.
65
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
66
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
67
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
68
+ The tuple of downsample blocks to use.
69
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
70
+ The tuple of upsample blocks to use.
71
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
72
+ The tuple of output channels for each block.
73
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
74
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
75
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
76
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
77
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
78
+ If `None`, it will skip the normalization and activation layers in post-processing
79
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
80
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
81
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
82
+ """
83
+
84
+ _supports_gradient_checkpointing = False
85
+
86
+ @register_to_config
87
+ def __init__(
88
+ self,
89
+ sample_size: Optional[int] = None,
90
+ in_channels: int = 4,
91
+ out_channels: int = 4,
92
+ down_block_types: Tuple[str] = (
93
+ "CrossAttnDownBlock3D",
94
+ "CrossAttnDownBlock3D",
95
+ "CrossAttnDownBlock3D",
96
+ "DownBlock3D",
97
+ ),
98
+ up_block_types: Tuple[str] = (
99
+ "UpBlock3D",
100
+ "CrossAttnUpBlock3D",
101
+ "CrossAttnUpBlock3D",
102
+ "CrossAttnUpBlock3D",
103
+ ),
104
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
105
+ layers_per_block: int = 2,
106
+ downsample_padding: int = 1,
107
+ mid_block_scale_factor: float = 1,
108
+ act_fn: str = "silu",
109
+ norm_num_groups: Optional[int] = 32,
110
+ norm_eps: float = 1e-5,
111
+ cross_attention_dim: int = 1024,
112
+ attention_head_dim: Union[int, Tuple[int]] = 64,
113
+ ):
114
+ super().__init__()
115
+
116
+ self.sample_size = sample_size
117
+
118
+ # Check inputs
119
+ if len(down_block_types) != len(up_block_types):
120
+ raise ValueError(
121
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
122
+ )
123
+
124
+ if len(block_out_channels) != len(down_block_types):
125
+ raise ValueError(
126
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
127
+ )
128
+
129
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
130
+ down_block_types
131
+ ):
132
+ raise ValueError(
133
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
134
+ )
135
+
136
+ # input
137
+ conv_in_kernel = 3
138
+ conv_out_kernel = 3
139
+ conv_in_padding = (conv_in_kernel - 1) // 2
140
+ self.conv_in = nn.Conv2d(
141
+ in_channels,
142
+ block_out_channels[0],
143
+ kernel_size=conv_in_kernel,
144
+ padding=conv_in_padding,
145
+ )
146
+
147
+ # time
148
+ time_embed_dim = block_out_channels[0] * 4
149
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
150
+ timestep_input_dim = block_out_channels[0]
151
+
152
+ self.time_embedding = TimestepEmbedding(
153
+ timestep_input_dim,
154
+ time_embed_dim,
155
+ act_fn=act_fn,
156
+ )
157
+
158
+ self.transformer_in = TransformerTemporalModel(
159
+ num_attention_heads=8,
160
+ attention_head_dim=attention_head_dim,
161
+ in_channels=block_out_channels[0],
162
+ num_layers=1,
163
+ )
164
+
165
+ # class embedding
166
+ self.down_blocks = nn.ModuleList([])
167
+ self.up_blocks = nn.ModuleList([])
168
+
169
+ if isinstance(attention_head_dim, int):
170
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
171
+
172
+ # down
173
+ output_channel = block_out_channels[0]
174
+ for i, down_block_type in enumerate(down_block_types):
175
+ input_channel = output_channel
176
+ output_channel = block_out_channels[i]
177
+ is_final_block = i == len(block_out_channels) - 1
178
+
179
+ down_block = get_down_block(
180
+ down_block_type,
181
+ num_layers=layers_per_block,
182
+ in_channels=input_channel,
183
+ out_channels=output_channel,
184
+ temb_channels=time_embed_dim,
185
+ add_downsample=not is_final_block,
186
+ resnet_eps=norm_eps,
187
+ resnet_act_fn=act_fn,
188
+ resnet_groups=norm_num_groups,
189
+ cross_attention_dim=cross_attention_dim,
190
+ attn_num_head_channels=attention_head_dim[i],
191
+ downsample_padding=downsample_padding,
192
+ dual_cross_attention=False,
193
+ )
194
+ self.down_blocks.append(down_block)
195
+
196
+ # mid
197
+ self.mid_block = UNetMidBlock3DCrossAttn(
198
+ in_channels=block_out_channels[-1],
199
+ temb_channels=time_embed_dim,
200
+ resnet_eps=norm_eps,
201
+ resnet_act_fn=act_fn,
202
+ output_scale_factor=mid_block_scale_factor,
203
+ cross_attention_dim=cross_attention_dim,
204
+ attn_num_head_channels=attention_head_dim[-1],
205
+ resnet_groups=norm_num_groups,
206
+ dual_cross_attention=False,
207
+ )
208
+
209
+ # count how many layers upsample the images
210
+ self.num_upsamplers = 0
211
+
212
+ # up
213
+ reversed_block_out_channels = list(reversed(block_out_channels))
214
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
215
+
216
+ output_channel = reversed_block_out_channels[0]
217
+ for i, up_block_type in enumerate(up_block_types):
218
+ is_final_block = i == len(block_out_channels) - 1
219
+
220
+ prev_output_channel = output_channel
221
+ output_channel = reversed_block_out_channels[i]
222
+ input_channel = reversed_block_out_channels[
223
+ min(i + 1, len(block_out_channels) - 1)
224
+ ]
225
+
226
+ # add upsample block for all BUT final layer
227
+ if not is_final_block:
228
+ add_upsample = True
229
+ self.num_upsamplers += 1
230
+ else:
231
+ add_upsample = False
232
+
233
+ up_block = get_up_block(
234
+ up_block_type,
235
+ num_layers=layers_per_block + 1,
236
+ in_channels=input_channel,
237
+ out_channels=output_channel,
238
+ prev_output_channel=prev_output_channel,
239
+ temb_channels=time_embed_dim,
240
+ add_upsample=add_upsample,
241
+ resnet_eps=norm_eps,
242
+ resnet_act_fn=act_fn,
243
+ resnet_groups=norm_num_groups,
244
+ cross_attention_dim=cross_attention_dim,
245
+ attn_num_head_channels=reversed_attention_head_dim[i],
246
+ dual_cross_attention=False,
247
+ )
248
+ self.up_blocks.append(up_block)
249
+ prev_output_channel = output_channel
250
+
251
+ # out
252
+ if norm_num_groups is not None:
253
+ self.conv_norm_out = nn.GroupNorm(
254
+ num_channels=block_out_channels[0],
255
+ num_groups=norm_num_groups,
256
+ eps=norm_eps,
257
+ )
258
+ self.conv_act = nn.SiLU()
259
+ else:
260
+ self.conv_norm_out = None
261
+ self.conv_act = None
262
+
263
+ conv_out_padding = (conv_out_kernel - 1) // 2
264
+ self.conv_out = nn.Conv2d(
265
+ block_out_channels[0],
266
+ out_channels,
267
+ kernel_size=conv_out_kernel,
268
+ padding=conv_out_padding,
269
+ )
270
+
271
+ @property
272
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
273
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
274
+ r"""
275
+ Returns:
276
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
277
+ indexed by its weight name.
278
+ """
279
+ # set recursively
280
+ processors = {}
281
+
282
+ def fn_recursive_add_processors(
283
+ name: str,
284
+ module: torch.nn.Module,
285
+ processors: Dict[str, AttentionProcessor],
286
+ ):
287
+ if hasattr(module, "set_processor"):
288
+ processors[f"{name}.processor"] = module.processor
289
+
290
+ for sub_name, child in module.named_children():
291
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
292
+
293
+ return processors
294
+
295
+ for name, module in self.named_children():
296
+ fn_recursive_add_processors(name, module, processors)
297
+
298
+ return processors
299
+
300
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
301
+ def set_attention_slice(self, slice_size):
302
+ r"""
303
+ Enable sliced attention computation.
304
+
305
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
306
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
307
+
308
+ Args:
309
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
310
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
311
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
312
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
313
+ must be a multiple of `slice_size`.
314
+ """
315
+ sliceable_head_dims = []
316
+
317
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
318
+ if hasattr(module, "set_attention_slice"):
319
+ sliceable_head_dims.append(module.sliceable_head_dim)
320
+
321
+ for child in module.children():
322
+ fn_recursive_retrieve_sliceable_dims(child)
323
+
324
+ # retrieve number of attention layers
325
+ for module in self.children():
326
+ fn_recursive_retrieve_sliceable_dims(module)
327
+
328
+ num_sliceable_layers = len(sliceable_head_dims)
329
+
330
+ if slice_size == "auto":
331
+ # half the attention head size is usually a good trade-off between
332
+ # speed and memory
333
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
334
+ elif slice_size == "max":
335
+ # make smallest slice possible
336
+ slice_size = num_sliceable_layers * [1]
337
+
338
+ slice_size = (
339
+ num_sliceable_layers * [slice_size]
340
+ if not isinstance(slice_size, list)
341
+ else slice_size
342
+ )
343
+
344
+ if len(slice_size) != len(sliceable_head_dims):
345
+ raise ValueError(
346
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
347
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
348
+ )
349
+
350
+ for i in range(len(slice_size)):
351
+ size = slice_size[i]
352
+ dim = sliceable_head_dims[i]
353
+ if size is not None and size > dim:
354
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
355
+
356
+ # Recursively walk through all the children.
357
+ # Any children which exposes the set_attention_slice method
358
+ # gets the message
359
+ def fn_recursive_set_attention_slice(
360
+ module: torch.nn.Module, slice_size: List[int]
361
+ ):
362
+ if hasattr(module, "set_attention_slice"):
363
+ module.set_attention_slice(slice_size.pop())
364
+
365
+ for child in module.children():
366
+ fn_recursive_set_attention_slice(child, slice_size)
367
+
368
+ reversed_slice_size = list(reversed(slice_size))
369
+ for module in self.children():
370
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
371
+
372
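# A minimal usage sketch (hypothetical `unet` instance): trading a little speed for
# lower peak memory once the model is constructed, per the docstring above.
#
#   unet.set_attention_slice("auto")    # halve each sliceable head dimension
#   unet.set_attention_slice("max")     # most frugal: one slice at a time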
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
373
+ def set_attn_processor(
374
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
375
+ ):
376
+ r"""
377
+ Parameters:
378
+ `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
379
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
380
+ of **all** `Attention` layers.
381
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
382
+
383
+ """
384
+ count = len(self.attn_processors.keys())
385
+
386
+ if isinstance(processor, dict) and len(processor) != count:
387
+ raise ValueError(
388
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
389
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
390
+ )
391
+
392
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
393
+ if hasattr(module, "set_processor"):
394
+ if not isinstance(processor, dict):
395
+ module.set_processor(processor)
396
+ else:
397
+ module.set_processor(processor.pop(f"{name}.processor"))
398
+
399
+ for sub_name, child in module.named_children():
400
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
401
+
402
+ for name, module in self.named_children():
403
+ fn_recursive_attn_processor(name, module, processor)
404
+
405
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
406
+ def set_default_attn_processor(self):
407
+ """
408
+ Disables custom attention processors and sets the default attention implementation.
409
+ """
410
+ self.set_attn_processor(AttnProcessor())
411
+
412
+ def _set_gradient_checkpointing(self, module, value=False):
413
+ if isinstance(
414
+ module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
415
+ ):
416
+ module.gradient_checkpointing = value
417
+
418
+ def forward(
419
+ self,
420
+ sample: torch.FloatTensor,
421
+ timestep: Union[torch.Tensor, float, int],
422
+ encoder_hidden_states: torch.Tensor,
423
+ class_labels: Optional[torch.Tensor] = None,
424
+ timestep_cond: Optional[torch.Tensor] = None,
425
+ attention_mask: Optional[torch.Tensor] = None,
426
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
427
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
428
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
429
+ return_dict: bool = True,
430
+ ) -> Union[UNet3DConditionOutput, Tuple]:
431
+ r"""
432
+ Args:
433
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
434
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
435
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
436
+ return_dict (`bool`, *optional*, defaults to `True`):
437
+ Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple.
438
+ cross_attention_kwargs (`dict`, *optional*):
439
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
440
+ `self.processor` in
441
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
442
+
443
+ Returns:
444
+ [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`:
445
+ [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
446
+ returning a tuple, the first element is the sample tensor.
447
+ """
448
+ # By default samples have to be at least a multiple of the overall upsampling factor.
449
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
450
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
451
+ # on the fly if necessary.
452
+ default_overall_up_factor = 2**self.num_upsamplers
453
+
454
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
455
+ forward_upsample_size = False
456
+ upsample_size = None
457
+
458
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
459
+ logger.info("Forward upsample size to force interpolation output size.")
460
+ forward_upsample_size = True
461
+
462
+ # prepare attention_mask
463
+ if attention_mask is not None:
464
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
465
+ attention_mask = attention_mask.unsqueeze(1)
466
+
467
+ # 1. time
468
+ timesteps = timestep
469
+ if not torch.is_tensor(timesteps):
470
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
471
+ # This would be a good case for the `match` statement (Python 3.10+)
472
+ is_mps = sample.device.type == "mps"
473
+ if isinstance(timestep, float):
474
+ dtype = torch.float32 if is_mps else torch.float64
475
+ else:
476
+ dtype = torch.int32 if is_mps else torch.int64
477
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
478
+ elif len(timesteps.shape) == 0:
479
+ timesteps = timesteps[None].to(sample.device)
480
+
481
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
482
+ num_frames = sample.shape[2]
483
+ timesteps = timesteps.expand(sample.shape[0])
484
+
485
+ t_emb = self.time_proj(timesteps)
486
+
487
+ # timesteps does not contain any weights and will always return f32 tensors
488
+ # but time_embedding might actually be running in fp16. so we need to cast here.
489
+ # there might be better ways to encapsulate this.
490
+ t_emb = t_emb.to(dtype=self.dtype)
491
+
492
+ emb = self.time_embedding(t_emb, timestep_cond)
493
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
494
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
495
+ repeats=num_frames, dim=0
496
+ )
497
+
498
+ # 2. pre-process
499
+ sample = sample.permute(0, 2, 1, 3, 4).reshape(
500
+ (sample.shape[0] * num_frames, -1) + sample.shape[3:]
501
+ )
502
+ sample = self.conv_in(sample)
503
+
504
+ sample = self.transformer_in(
505
+ sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
506
+ ).sample
507
+
508
+ # 3. down
509
+ down_block_res_samples = (sample,)
510
+ for downsample_block in self.down_blocks:
511
+ if (
512
+ hasattr(downsample_block, "has_cross_attention")
513
+ and downsample_block.has_cross_attention
514
+ ):
515
+ sample, res_samples = downsample_block(
516
+ hidden_states=sample,
517
+ temb=emb,
518
+ encoder_hidden_states=encoder_hidden_states,
519
+ attention_mask=attention_mask,
520
+ num_frames=num_frames,
521
+ cross_attention_kwargs=cross_attention_kwargs,
522
+ )
523
+ else:
524
+ sample, res_samples = downsample_block(
525
+ hidden_states=sample, temb=emb, num_frames=num_frames
526
+ )
527
+
528
+ down_block_res_samples += res_samples
529
+
530
+ if down_block_additional_residuals is not None:
531
+ new_down_block_res_samples = ()
532
+
533
+ for down_block_res_sample, down_block_additional_residual in zip(
534
+ down_block_res_samples, down_block_additional_residuals
535
+ ):
536
+ down_block_res_sample = (
537
+ down_block_res_sample + down_block_additional_residual
538
+ )
539
+ new_down_block_res_samples += (down_block_res_sample,)
540
+
541
+ down_block_res_samples = new_down_block_res_samples
542
+
543
+ # 4. mid
544
+ if self.mid_block is not None:
545
+ sample = self.mid_block(
546
+ sample,
547
+ emb,
548
+ encoder_hidden_states=encoder_hidden_states,
549
+ attention_mask=attention_mask,
550
+ num_frames=num_frames,
551
+ cross_attention_kwargs=cross_attention_kwargs,
552
+ )
553
+
554
+ if mid_block_additional_residual is not None:
555
+ sample = sample + mid_block_additional_residual
556
+
557
+ # 5. up
558
+ for i, upsample_block in enumerate(self.up_blocks):
559
+ is_final_block = i == len(self.up_blocks) - 1
560
+
561
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
562
+ down_block_res_samples = down_block_res_samples[
563
+ : -len(upsample_block.resnets)
564
+ ]
565
+
566
+ # if we have not reached the final block and need to forward the
567
+ # upsample size, we do it here
568
+ if not is_final_block and forward_upsample_size:
569
+ upsample_size = down_block_res_samples[-1].shape[2:]
570
+
571
+ if (
572
+ hasattr(upsample_block, "has_cross_attention")
573
+ and upsample_block.has_cross_attention
574
+ ):
575
+ sample = upsample_block(
576
+ hidden_states=sample,
577
+ temb=emb,
578
+ res_hidden_states_tuple=res_samples,
579
+ encoder_hidden_states=encoder_hidden_states,
580
+ upsample_size=upsample_size,
581
+ attention_mask=attention_mask,
582
+ num_frames=num_frames,
583
+ cross_attention_kwargs=cross_attention_kwargs,
584
+ )
585
+ else:
586
+ sample = upsample_block(
587
+ hidden_states=sample,
588
+ temb=emb,
589
+ res_hidden_states_tuple=res_samples,
590
+ upsample_size=upsample_size,
591
+ num_frames=num_frames,
592
+ )
593
+
594
+ # 6. post-process
595
+ if self.conv_norm_out:
596
+ sample = self.conv_norm_out(sample)
597
+ sample = self.conv_act(sample)
598
+
599
+ sample = self.conv_out(sample)
600
+
601
+ # reshape back to (batch, channel, num_frames, height, width)
602
+ sample = (
603
+ sample[None, :]
604
+ .reshape((-1, num_frames) + sample.shape[1:])
605
+ .permute(0, 2, 1, 3, 4)
606
+ )
607
+
608
+ if not return_dict:
609
+ return (sample,)
610
+
611
+ return UNet3DConditionOutput(sample=sample)
text_to_animation/pipelines/text_to_video_pipeline_flax.py CHANGED
@@ -6,11 +6,16 @@ import jax.numpy as jnp
6
  import numpy as np
7
  from flax.core.frozen_dict import FrozenDict
8
  from flax.jax_utils import unreplicate
 
9
  from flax.training.common_utils import shard
10
  from PIL import Image
11
  from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
12
  from einops import rearrange, repeat
13
- from diffusers.models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel
14
  from diffusers.schedulers import (
15
  FlaxDDIMScheduler,
16
  FlaxDPMSolverMultistepScheduler,
@@ -20,17 +25,24 @@ from diffusers.schedulers import (
20
  from diffusers.utils import PIL_INTERPOLATION, logging, replace_example_docstring
21
  from diffusers.pipelines.pipeline_flax_utils import FlaxDiffusionPipeline
22
  from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionPipelineOutput
23
- from diffusers.pipelines.stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker
 
 
 
24
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
  """
26
  Text2Video-Zero:
27
  - Inputs: Prompt, Pose Control via mp4/gif, First Frame (?)
28
  - JAX implementation
29
  - 3DUnet to replace 2DUnetConditional
30
-
31
  """
32
 
33
- DEBUG = False # Set to True to use python for loop instead of jax.fori_loop for easier debugging
34
 
35
  EXAMPLE_DOC_STRING = """
36
  Examples:
@@ -89,16 +101,22 @@ EXAMPLE_DOC_STRING = """
89
  >>> output_images.save("generated_image.png")
90
  ```
91
  """
 
 
92
  class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
93
  def __init__(
94
  self,
95
- vae: FlaxAutoencoderKL,
96
- text_encoder: FlaxCLIPTextModel,
97
- tokenizer: CLIPTokenizer,
98
- unet: FlaxUNet2DConditionModel,
99
- controlnet: FlaxControlNetModel,
 
100
  scheduler: Union[
101
- FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
 
 
 
102
  ],
103
  safety_checker: FlaxStableDiffusionSafetyChecker,
104
  feature_extractor: CLIPFeatureExtractor,
@@ -122,6 +140,7 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
122
  text_encoder=text_encoder,
123
  tokenizer=tokenizer,
124
  unet=unet,
 
125
  controlnet=controlnet,
126
  scheduler=scheduler,
127
  safety_checker=safety_checker,
@@ -135,30 +154,50 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
135
  else:
136
  eps = jax.random.normal(prng, x0.shape, dtype=text_embeddings.dtype)
137
  alpha_vec = jnp.prod(params["scheduler"].common.alphas[t0:tMax])
138
- xt = jnp.sqrt(alpha_vec) * x0 + \
139
- jnp.sqrt(1-alpha_vec) * eps
140
  return xt
141
-
142
- def DDIM_backward(self, params, num_inference_steps, timesteps, skip_t, t0, t1, do_classifier_free_guidance, text_embeddings, latents_local,
143
- guidance_scale, controlnet_image=None, controlnet_conditioning_scale=None):
144
- scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps)
145
  f = latents_local.shape[2]
146
- latents_local = rearrange(latents_local, "b c f w h -> (b f) c w h")
147
  latents = latents_local.copy()
148
  x_t0_1 = None
149
  x_t1_1 = None
150
- max_timestep = len(timesteps)-1
151
  timesteps = jnp.array(timesteps)
 
152
  def while_body(args):
153
  step, latents, x_t0_1, x_t1_1, scheduler_state = args
154
  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
155
- latent_model_input = jnp.concatenate(
156
- [latents] * 2) if do_classifier_free_guidance else latents
 
 
 
157
  latent_model_input = self.scheduler.scale_model_input(
158
  scheduler_state, latent_model_input, timestep=t
159
  )
160
  f = latents.shape[0]
161
- te = jnp.stack([text_embeddings[0, :, :]]*f + [text_embeddings[-1,:,:]]*f)
 
 
162
  timestep = jnp.broadcast_to(t, latent_model_input.shape[0])
163
  if controlnet_image is not None:
164
  down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
@@ -185,41 +224,53 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
185
  jnp.array(latent_model_input),
186
  jnp.array(timestep, dtype=jnp.int32),
187
  encoder_hidden_states=te,
188
- ).sample
189
  # perform guidance
190
  if do_classifier_free_guidance:
191
  noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
192
- noise_pred = noise_pred_uncond + guidance_scale * \
193
- (noise_pred_text - noise_pred_uncond)
 
194
  # compute the previous noisy sample x_t -> x_t-1
195
- latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
196
- x_t0_1 = jax.lax.select((step < max_timestep-1) & (timesteps[step+1] == t0), latents, x_t0_1)
197
- x_t1_1 = jax.lax.select((step < max_timestep-1) & (timesteps[step+1] == t1), latents, x_t1_1)
198
  return (step + 1, latents, x_t0_1, x_t1_1, scheduler_state)
 
199
  latents_shape = latents.shape
200
  x_t0_1, x_t1_1 = jnp.zeros(latents_shape), jnp.zeros(latents_shape)
201
 
202
  def cond_fun(arg):
203
  step, latents, x_t0_1, x_t1_1, scheduler_state = arg
204
  return (step < skip_t) & (step < num_inference_steps)
205
-
206
  if DEBUG:
207
  step = 0
208
  while cond_fun((step, latents, x_t0_1, x_t1_1)):
209
- step, latents, x_t0_1, x_t1_1, scheduler_state = while_body((step, latents, x_t0_1, x_t1_1, scheduler_state))
 
 
210
  step = step + 1
211
  else:
212
- _, latents, x_t0_1, x_t1_1, scheduler_state = jax.lax.while_loop(cond_fun, while_body, (0, latents, x_t0_1, x_t1_1, scheduler_state))
213
- latents = rearrange(latents, "(b f) c w h -> b c f w h", f=f)
 
 
214
  res = {"x0": latents.copy()}
215
  if x_t0_1 is not None:
216
- x_t0_1 = rearrange(x_t0_1, "(b f) c w h -> b c f w h", f=f)
217
  res["x_t0_1"] = x_t0_1.copy()
218
  if x_t1_1 is not None:
219
- x_t1_1 = rearrange(x_t1_1, "(b f) c w h -> b c f w h", f=f)
220
  res["x_t1_1"] = x_t1_1.copy()
221
  return res
222
-
223
  def warp_latents_independently(self, latents, reference_flow):
224
  _, _, H, W = reference_flow.shape
225
  b, _, f, h, w = latents.shape
@@ -230,10 +281,10 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
230
  coords_t0 = coords_t0.at[:, 1].set(coords_t0[:, 1] * h / H)
231
  f, c, _, _ = coords_t0.shape
232
  coords_t0 = jax.image.resize(coords_t0, (f, c, h, w), "linear")
233
- coords_t0 = rearrange(coords_t0, 'f c h w -> f h w c')
234
- latents_0 = rearrange(latents[0], 'c f h w -> f c h w')
235
  warped = grid_sample(latents_0, coords_t0, "mirror")
236
- warped = rearrange(warped, '(b f) c h w -> b c f h w', f=f)
237
  return warped
238
 
239
  def warp_vid_independently(self, vid, reference_flow):
@@ -245,74 +296,173 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
245
  coords_t0 = coords_t0.at[:, 1].set(coords_t0[:, 1] * h / H)
246
  f, c, _, _ = coords_t0.shape
247
  coords_t0 = jax.image.resize(coords_t0, (f, c, h, w), "linear")
248
- coords_t0 = rearrange(coords_t0, 'f c h w -> f h w c')
249
  # latents_0 = rearrange(vid, 'c f h w -> f c h w')
250
  warped = grid_sample(vid, coords_t0, "zeropad")
251
  # warped = rearrange(warped, 'f c h w -> b c f h w', f=f)
252
  return warped
253
-
254
- def create_motion_field(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
255
- reference_flow = jnp.zeros(
256
- (video_length-1, 2, 512, 512), dtype=latents.dtype)
257
  for fr_idx, frame_id in enumerate(frame_ids):
258
- reference_flow = reference_flow.at[fr_idx, 0, :,
259
- :].set(motion_field_strength_x*(frame_id))
260
- reference_flow = reference_flow.at[fr_idx, 1, :,
261
- :].set(motion_field_strength_y*(frame_id))
 
 
262
  return reference_flow
263
-
264
- def create_motion_field_and_warp_latents(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):
265
- motion_field = self.create_motion_field(motion_field_strength_x=motion_field_strength_x,
266
- motion_field_strength_y=motion_field_strength_y, latents=latents, video_length=video_length, frame_ids=frame_ids)
267
  for idx, latent in enumerate(latents):
268
- latents = latents.at[idx].set(self.warp_latents_independently(
269
- latent[None], motion_field)[0])
 
270
  return motion_field, latents
271
-
272
- def text_to_video_zero(self, params,
273
- prng,
274
- text_embeddings,
275
- video_length: Optional[int],
276
- do_classifier_free_guidance = True,
277
- height: Optional[int] = None,
278
- width: Optional[int] = None,
279
- num_inference_steps: int = 50,
280
- guidance_scale: float = 7.5,
281
- num_videos_per_prompt: Optional[int] = 1,
282
- xT = None,
283
- motion_field_strength_x: float = 12,
284
- motion_field_strength_y: float = 12,
285
- t0: int = 44,
286
- t1: int = 47,
287
- controlnet_image=None,
288
- controlnet_conditioning_scale=0,
289
- ):
 
 
 
290
  frame_ids = list(range(video_length))
291
  # Prepare timesteps
292
- params["scheduler"] = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps)
 
 
293
  timesteps = params["scheduler"].timesteps
294
  # Prepare latent variables
295
  num_channels_latents = self.unet.in_channels
296
  batch_size = 1
297
- xT = prepare_latents(params, prng, batch_size * num_videos_per_prompt, num_channels_latents, 1, height, width, self.vae_scale_factor, xT)
298
- xT = xT[:, :, :1]
299
- timesteps_ddpm = [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721,
300
- 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441,
301
- 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161,
302
- 141, 121, 101, 81, 61, 41, 21, 1]
303
  timesteps_ddpm.reverse()
304
  t0 = timesteps_ddpm[t0]
305
  t1 = timesteps_ddpm[t1]
306
  x_t1_1 = None
307
 
308
  # Denoising loop
309
- shape = (batch_size, num_channels_latents, 1, height //
310
- self.vae.scaling_factor, width // self.vae.scaling_factor)
 
312
  # perform ∆t backward steps by stable diffusion
313
- ddim_res = self.DDIM_backward(params, num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=1000, t0=t0, t1=t1, do_classifier_free_guidance=do_classifier_free_guidance,
314
- text_embeddings=text_embeddings, latents_local=xT, guidance_scale=guidance_scale,
315
- controlnet_image=jnp.stack([controlnet_image[0]] * 2), controlnet_conditioning_scale=controlnet_conditioning_scale)
316
  x0 = ddim_res["x0"]
317
 
318
  # apply warping functions
@@ -320,37 +470,524 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
320
  x_t0_1 = ddim_res["x_t0_1"]
321
  if "x_t1_1" in ddim_res:
322
  x_t1_1 = ddim_res["x_t1_1"]
323
- x_t0_k = x_t0_1[:, :, :1, :, :].repeat(video_length-1, 2)
324
  reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(
325
- motion_field_strength_x=motion_field_strength_x, motion_field_strength_y=motion_field_strength_y, latents=x_t0_k, video_length=video_length, frame_ids=frame_ids[1:])
326
  # assuming t0=t1=1000, if t0 = 1000
327
 
328
  # DDPM forward for more motion freedom
329
- ddpm_fwd = partial(self.DDPM_forward, params=params, prng=prng, x0=x_t0_k, t0=t0,
330
- tMax=t1, shape=shape, text_embeddings=text_embeddings)
331
- x_t1_k = jax.lax.cond(t1 > t0,
332
- ddpm_fwd,
333
- lambda:x_t0_k
 
 
 
 
334
  )
335
- x_t1 = jnp.concatenate([x_t1_1, x_t1_k], axis=2).copy()
 
336
 
337
  # backward steps by stable diffusion
338
 
339
- #warp the controlnet image following the same flow defined for latent
340
  controlnet_video = controlnet_image[:video_length]
341
- controlnet_video = controlnet_video.at[1:].set(self.warp_vid_independently(controlnet_video[1:], reference_flow))
342
- controlnet_image = jnp.concatenate([controlnet_video]*2)
 
343
344
 
345
- ddim_res = self.DDIM_backward(params, num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=t1, t0=-1, t1=-1, do_classifier_free_guidance=do_classifier_free_guidance,
346
- text_embeddings=text_embeddings, latents_local=x_t1, guidance_scale=guidance_scale,
347
- controlnet_image=controlnet_image, controlnet_conditioning_scale=controlnet_conditioning_scale)
348
  x0 = ddim_res["x0"]
349
  return x0
350
351
  def prepare_text_inputs(self, prompt: Union[str, List[str]]):
352
  if not isinstance(prompt, (str, list)):
353
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
 
354
  text_input = self.tokenizer(
355
  prompt,
356
  padding="max_length",
@@ -359,27 +996,38 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
359
  return_tensors="np",
360
  )
361
  return text_input.input_ids
 
362
  def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
363
  if not isinstance(image, (Image.Image, list)):
364
- raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}")
 
 
365
  if isinstance(image, Image.Image):
366
  image = [image]
367
- processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image])
 
 
368
  return processed_images
 
369
  def _get_has_nsfw_concepts(self, features, params):
370
  has_nsfw_concepts = self.safety_checker(features, params)
371
  return has_nsfw_concepts
 
372
  def _run_safety_checker(self, images, safety_model_params, jit=False):
373
  # safety_model_params should already be replicated when jit is True
374
  pil_images = [Image.fromarray(image) for image in images]
375
  features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
376
  if jit:
377
  features = shard(features)
378
- has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
 
 
379
  has_nsfw_concepts = unshard(has_nsfw_concepts)
380
  safety_model_params = unreplicate(safety_model_params)
381
  else:
382
- has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
 
 
383
  images_was_copied = False
384
  for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
385
  if has_nsfw_concept:
@@ -393,6 +1041,7 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
393
  " instead. Try again with a different prompt and/or seed."
394
  )
395
  return images, has_nsfw_concepts
 
396
  def _generate(
397
  self,
398
  prompt_ids: jnp.array,
@@ -404,7 +1053,8 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
404
  latents: Optional[jnp.array] = None,
405
  neg_prompt_ids: Optional[jnp.array] = None,
406
  controlnet_conditioning_scale: float = 1.0,
407
- xT = None,
 
408
  motion_field_strength_x: float = 12,
409
  motion_field_strength_y: float = 12,
410
  t0: int = 44,
@@ -413,7 +1063,9 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
413
  height, width = image.shape[-2:]
414
  video_length = image.shape[0]
415
  if height % 64 != 0 or width % 64 != 0:
416
- raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.")
 
 
417
  # get prompt text embeddings
418
  prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
419
  # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
@@ -422,30 +1074,47 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
422
  max_length = prompt_ids.shape[-1]
423
  if neg_prompt_ids is None:
424
  uncond_input = self.tokenizer(
425
- [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
 
 
 
426
  ).input_ids
427
  else:
428
  uncond_input = neg_prompt_ids
429
- negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
 
 
430
  context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
431
  image = jnp.concatenate([image] * 2)
432
  seed_t2vz, prng_seed = jax.random.split(prng_seed)
433
- #get the latent following text to video zero
434
- latents = self.text_to_video_zero(params, seed_t2vz, text_embeddings=context, video_length=video_length,
435
- height=height, width = width, num_inference_steps=num_inference_steps,
436
- guidance_scale=guidance_scale, controlnet_image=image,
437
- xT=xT, t0=t0, t1=t1,
438
- motion_field_strength_x=motion_field_strength_x,
439
- motion_field_strength_y=motion_field_strength_y,
440
- controlnet_conditioning_scale=controlnet_conditioning_scale
441
- )
442
  # scale and decode the image latents with vae
443
  latents = 1 / self.vae.config.scaling_factor * latents
444
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
445
- video = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
 
446
  video = (video / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
447
  return video
448
-
449
  @replace_example_docstring(EXAMPLE_DOC_STRING)
450
  def __call__(
451
  self,
@@ -460,7 +1129,8 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
460
  controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
461
  return_dict: bool = True,
462
  jit: bool = False,
463
- xT = None,
 
464
  motion_field_strength_x: float = 3,
465
  motion_field_strength_y: float = 4,
466
  t0: int = 44,
@@ -517,7 +1187,9 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
517
  if isinstance(controlnet_conditioning_scale, float):
518
  # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
519
  # shape information, as they may be sharded (when `jit` is `True`), or not.
520
- controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0])
 
 
521
  if len(prompt_ids.shape) > 2:
522
  # Assume sharded
523
  controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
@@ -534,6 +1206,7 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
534
  neg_prompt_ids,
535
  controlnet_conditioning_scale,
536
  xT,
 
537
  motion_field_strength_x,
538
  motion_field_strength_y,
539
  t0,
@@ -551,6 +1224,7 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
551
  neg_prompt_ids,
552
  controlnet_conditioning_scale,
553
  xT,
 
554
  motion_field_strength_x,
555
  motion_field_strength_y,
556
  t0,
@@ -560,8 +1234,12 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
560
  safety_params = params["safety_checker"]
561
  images_uint8_casted = (images * 255).round().astype("uint8")
562
  num_devices, batch_size = images.shape[:2]
563
- images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
564
- images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
565
  images = np.asarray(images)
566
  # block images
567
  if any(has_nsfw_concept):
@@ -574,17 +1252,21 @@ class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
574
  has_nsfw_concept = False
575
  if not return_dict:
576
  return (images, has_nsfw_concept)
577
- return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
578
  # Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
579
  # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
580
  @partial(
581
  jax.pmap,
582
- in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0, 0, None, None, None, None),
583
- static_broadcasted_argnums=(0, 5, 11, 12, 13, 14),
584
  )
585
  def _p_generate(
586
  pipe,
587
- prompt_ids,
588
  image,
589
  params,
590
  prng_seed,
@@ -594,6 +1276,7 @@ def _p_generate(
594
  neg_prompt_ids,
595
  controlnet_conditioning_scale,
596
  xT,
 
597
  motion_field_strength_x,
598
  motion_field_strength_y,
599
  t0,
@@ -610,19 +1293,26 @@ def _p_generate(
610
  neg_prompt_ids,
611
  controlnet_conditioning_scale,
612
  xT,
 
613
  motion_field_strength_x,
614
  motion_field_strength_y,
615
  t0,
616
  t1,
617
  )
 
 
618
  @partial(jax.pmap, static_broadcasted_argnums=(0,))
619
  def _p_get_has_nsfw_concepts(pipe, features, params):
620
  return pipe._get_has_nsfw_concepts(features, params)
 
 
621
  def unshard(x: jnp.ndarray):
622
  # einops.rearrange(x, 'd b ... -> (d b) ...')
623
  num_devices, batch_size = x.shape[:2]
624
  rest = x.shape[2:]
625
  return x.reshape(num_devices * batch_size, *rest)
 
 
626
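# A small shape check for unshard (illustrative values only): the sharded
# (num_devices, per_device_batch, ...) layout is flattened back to one batch axis.
#
#   x = jnp.zeros((8, 1, 512, 512, 3))   # 8 devices, 1 image each
#   unshard(x).shape                     # -> (8, 512, 512, 3)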
  def preprocess(image, dtype):
627
  image = image.convert("RGB")
628
  w, h = image.size
@@ -632,43 +1322,98 @@ def preprocess(image, dtype):
632
  image = image[None].transpose(0, 3, 1, 2)
633
  return image
634
 
635
- def prepare_latents(params, prng, batch_size, num_channels_latents, video_length, height, width, vae_scale_factor, latents=None):
636
- shape = (batch_size, num_channels_latents, video_length, height //
637
- vae_scale_factor, width // vae_scale_factor)
638
  # scale the initial noise by the standard deviation required by the scheduler
639
  if latents is None:
640
  latents = jax.random.normal(prng, shape)
641
  latents = latents * params["scheduler"].init_noise_sigma
642
  return latents
643
 
 
644
  def coords_grid(batch, ht, wd):
645
  coords = jnp.meshgrid(jnp.arange(ht), jnp.arange(wd), indexing="ij")
646
  coords = jnp.stack(coords[::-1], axis=0)
647
  return coords[None].repeat(batch, 0)
648
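# A shape check for coords_grid (illustrative only): channel 0 carries x (column)
# indices and channel 1 carries y (row) indices, repeated over the batch axis.
#
#   grid = coords_grid(2, 4, 5)
#   grid.shape          # -> (2, 2, 4, 5)
#   grid[0, 0, 0]       # -> [0 1 2 3 4]   x along the width axis
#   grid[0, 1, :, 0]    # -> [0 1 2 3]     y along the height axis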
 
 
649
  def adapt_pos_mirror(x, y, W, H):
650
- #adapt the position, with mirror padding
651
- x_w_mirror = ((x + W - 1) % (2*(W - 1))) - W + 1
652
- x_adapted = jnp.where(x_w_mirror > 0, x_w_mirror, - (x_w_mirror))
653
- y_w_mirror = ((y + H - 1) % (2*(H - 1))) - H + 1
654
- y_adapted = jnp.where(y_w_mirror > 0, y_w_mirror, - (y_w_mirror))
655
- return y_adapted, x_adapted
656
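# A worked example of the mirror-padding index math in adapt_pos_mirror above, using
# the pre-change formula shown in this hunk: with W = 5 (valid x in 0..4), an
# out-of-range x = 6 reflects back inside the image to 2.
#
#   x_w_mirror = ((6 + 5 - 1) % (2 * (5 - 1))) - 5 + 1   # 10 % 8 - 4 = -2
#   x_adapted  = abs(-2)                                  # -> 2
#
# Coordinates past an edge are folded back rather than clamped, which is how
# grid_sample's "mirror" mode avoids padding with zeros at the borders.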
 
657
- def safe_get_zeropad(img, x,y,W,H):
658
- return jnp.where((x < W) & (x > 0) & (y < H) & (y > 0), img[y,x], 0.)
659
 
660
- def safe_get_mirror(img, x,y,W,H):
661
- return img[adapt_pos_mirror(x,y,W,H)]
662
 
663
  @partial(jax.vmap, in_axes=(0, 0, None))
664
  @partial(jax.vmap, in_axes=(0, None, None))
665
- @partial(jax.vmap, in_axes=(None,0, None))
666
  @partial(jax.vmap, in_axes=(None, 0, None))
667
  def grid_sample(latents, grid, method):
668
  # this is an alternative to torch.functional.nn.grid_sample in jax
669
  # this implementation is following the algorithm described @ https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
670
  # but with coordinates scaled to the size of the image
671
  if method == "mirror":
672
- return safe_get_mirror(latents, jnp.array(grid[0], dtype=jnp.int16), jnp.array(grid[1], dtype=jnp.int16), latents.shape[0], latents.shape[1])
673
- else: #default is zero padding
674
- return safe_get_zeropad(latents, jnp.array(grid[0], dtype=jnp.int16), jnp.array(grid[1], dtype=jnp.int16), latents.shape[0], latents.shape[1])
6
  import numpy as np
7
  from flax.core.frozen_dict import FrozenDict
8
  from flax.jax_utils import unreplicate
9
+ from flax import jax_utils
10
  from flax.training.common_utils import shard
11
  from PIL import Image
12
  from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
13
  from einops import rearrange, repeat
14
+ from diffusers.models import (
15
+ FlaxAutoencoderKL,
16
+ FlaxControlNetModel,
17
+ FlaxUNet2DConditionModel,
18
+ )
19
  from diffusers.schedulers import (
20
  FlaxDDIMScheduler,
21
  FlaxDPMSolverMultistepScheduler,
 
25
  from diffusers.utils import PIL_INTERPOLATION, logging, replace_example_docstring
26
  from diffusers.pipelines.pipeline_flax_utils import FlaxDiffusionPipeline
27
  from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionPipelineOutput
28
+ from diffusers.pipelines.stable_diffusion.safety_checker_flax import (
29
+ FlaxStableDiffusionSafetyChecker,
30
+ )
31
+
32
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
  """
34
  Text2Video-Zero:
35
  - Inputs: Prompt, Pose Control via mp4/gif, First Frame (?)
36
  - JAX implementation
37
  - 3DUnet to replace 2DUnetConditional
 
38
  """
39
 
40
+
41
+ def replicate_devices(array):
42
+ return jnp.expand_dims(array, 0).repeat(jax.device_count(), 0)
43
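# A hedged usage sketch (the `pipe` and prompt values are assumptions): host arrays
# are tiled across a leading device axis before being handed to a pmapped call.
#
#   prompt_ids = pipe.prepare_text_inputs(["an astronaut dancing"])  # (1, seq_len)
#   prompt_ids = replicate_devices(prompt_ids)                       # (n_devices, 1, seq_len)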
+
44
+
45
+ DEBUG = False  # Set to True to use a Python for loop instead of jax.lax.while_loop for easier debugging
46
 
47
  EXAMPLE_DOC_STRING = """
48
  Examples:
 
101
  >>> output_images.save("generated_image.png")
102
  ```
103
  """
104
+
105
+
106
  class FlaxTextToVideoPipeline(FlaxDiffusionPipeline):
107
  def __init__(
108
  self,
109
+ vae,
110
+ text_encoder,
111
+ tokenizer,
112
+ unet,
113
+ unet_vanilla,
114
+ controlnet,
115
  scheduler: Union[
116
+ FlaxDDIMScheduler,
117
+ FlaxPNDMScheduler,
118
+ FlaxLMSDiscreteScheduler,
119
+ FlaxDPMSolverMultistepScheduler,
120
  ],
121
  safety_checker: FlaxStableDiffusionSafetyChecker,
122
  feature_extractor: CLIPFeatureExtractor,
 
140
  text_encoder=text_encoder,
141
  tokenizer=tokenizer,
142
  unet=unet,
143
+ unet_vanilla=unet_vanilla,
144
  controlnet=controlnet,
145
  scheduler=scheduler,
146
  safety_checker=safety_checker,
 
154
  else:
155
  eps = jax.random.normal(prng, x0.shape, dtype=text_embeddings.dtype)
156
  alpha_vec = jnp.prod(params["scheduler"].common.alphas[t0:tMax])
157
+ xt = jnp.sqrt(alpha_vec) * x0 + jnp.sqrt(1 - alpha_vec) * eps
 
158
  return xt
159
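# The `xt` line above is the closed-form DDPM forward q(x_tMax | x_t0): the product of
# the per-step alphas between t0 and tMax acts as a cumulative alpha-bar, so a single
# reparameterised draw replaces tMax - t0 sequential noising steps. A standalone sketch
# of the same identity with a dummy schedule (not the real scheduler values):
#
#   alphas = jnp.linspace(0.999, 0.98, 1000)
#   alpha_vec = jnp.prod(alphas[t0:tMax])
#   xt = jnp.sqrt(alpha_vec) * x0 + jnp.sqrt(1 - alpha_vec) * eps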
+
160
+ def DDIM_backward(
161
+ self,
162
+ params,
163
+ num_inference_steps,
164
+ timesteps,
165
+ skip_t,
166
+ t0,
167
+ t1,
168
+ do_classifier_free_guidance,
169
+ text_embeddings,
170
+ latents_local,
171
+ guidance_scale,
172
+ controlnet_image=None,
173
+ controlnet_conditioning_scale=None,
174
+ ):
175
+ scheduler_state = self.scheduler.set_timesteps(
176
+ params["scheduler"], num_inference_steps
177
+ )
178
  f = latents_local.shape[2]
179
+ latents_local = rearrange(latents_local, "b c f h w -> (b f) c h w")
180
  latents = latents_local.copy()
181
  x_t0_1 = None
182
  x_t1_1 = None
183
+ max_timestep = len(timesteps) - 1
184
  timesteps = jnp.array(timesteps)
185
+
186
  def while_body(args):
187
  step, latents, x_t0_1, x_t1_1, scheduler_state = args
188
  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
189
+ latent_model_input = (
190
+ jnp.concatenate([latents] * 2)
191
+ if do_classifier_free_guidance
192
+ else latents
193
+ )
194
  latent_model_input = self.scheduler.scale_model_input(
195
  scheduler_state, latent_model_input, timestep=t
196
  )
197
  f = latents.shape[0]
198
+ te = jnp.stack(
199
+ [text_embeddings[0, :, :]] * f + [text_embeddings[-1, :, :]] * f
200
+ )
201
  timestep = jnp.broadcast_to(t, latent_model_input.shape[0])
202
  if controlnet_image is not None:
203
  down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
 
224
  jnp.array(latent_model_input),
225
  jnp.array(timestep, dtype=jnp.int32),
226
  encoder_hidden_states=te,
227
+ ).sample
228
  # perform guidance
229
  if do_classifier_free_guidance:
230
  noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
231
+ noise_pred = noise_pred_uncond + guidance_scale * (
232
+ noise_pred_text - noise_pred_uncond
233
+ )
234
  # compute the previous noisy sample x_t -> x_t-1
235
+ latents, scheduler_state = self.scheduler.step(
236
+ scheduler_state, noise_pred, t, latents
237
+ ).to_tuple()
238
+ x_t0_1 = jax.lax.select(
239
+ (step < max_timestep - 1) & (timesteps[step + 1] == t0), latents, x_t0_1
240
+ )
241
+ x_t1_1 = jax.lax.select(
242
+ (step < max_timestep - 1) & (timesteps[step + 1] == t1), latents, x_t1_1
243
+ )
244
  return (step + 1, latents, x_t0_1, x_t1_1, scheduler_state)
245
+
246
  latents_shape = latents.shape
247
  x_t0_1, x_t1_1 = jnp.zeros(latents_shape), jnp.zeros(latents_shape)
248
 
249
  def cond_fun(arg):
250
  step, latents, x_t0_1, x_t1_1, scheduler_state = arg
251
  return (step < skip_t) & (step < num_inference_steps)
252
+
253
  if DEBUG:
254
  step = 0
255
  while cond_fun((step, latents, x_t0_1, x_t1_1)):
256
+ step, latents, x_t0_1, x_t1_1, scheduler_state = while_body(
257
+ (step, latents, x_t0_1, x_t1_1, scheduler_state)
258
+ )
259
  step = step + 1
260
  else:
261
+ _, latents, x_t0_1, x_t1_1, scheduler_state = jax.lax.while_loop(
262
+ cond_fun, while_body, (0, latents, x_t0_1, x_t1_1, scheduler_state)
263
+ )
264
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=f)
265
  res = {"x0": latents.copy()}
266
  if x_t0_1 is not None:
267
+ x_t0_1 = rearrange(x_t0_1, "(b f) c h w -> b c f h w", f=f)
268
  res["x_t0_1"] = x_t0_1.copy()
269
  if x_t1_1 is not None:
270
+ x_t1_1 = rearrange(x_t1_1, "(b f) c h w -> b c f h w", f=f)
271
  res["x_t1_1"] = x_t1_1.copy()
272
  return res
273
+
274
  def warp_latents_independently(self, latents, reference_flow):
275
  _, _, H, W = reference_flow.shape
276
  b, _, f, h, w = latents.shape
 
281
  coords_t0 = coords_t0.at[:, 1].set(coords_t0[:, 1] * h / H)
282
  f, c, _, _ = coords_t0.shape
283
  coords_t0 = jax.image.resize(coords_t0, (f, c, h, w), "linear")
284
+ coords_t0 = rearrange(coords_t0, "f c h w -> f h w c")
285
+ latents_0 = rearrange(latents[0], "c f h w -> f c h w")
286
  warped = grid_sample(latents_0, coords_t0, "mirror")
287
+ warped = rearrange(warped, "(b f) c h w -> b c f h w", f=f)
288
  return warped
289
 
290
  def warp_vid_independently(self, vid, reference_flow):
 
296
  coords_t0 = coords_t0.at[:, 1].set(coords_t0[:, 1] * h / H)
297
  f, c, _, _ = coords_t0.shape
298
  coords_t0 = jax.image.resize(coords_t0, (f, c, h, w), "linear")
299
+ coords_t0 = rearrange(coords_t0, "f c h w -> f h w c")
300
  # latents_0 = rearrange(vid, 'c f h w -> f c h w')
301
  warped = grid_sample(vid, coords_t0, "zeropad")
302
  # warped = rearrange(warped, 'f c h w -> b c f h w', f=f)
303
  return warped
304
+
305
+ def create_motion_field(
306
+ self,
307
+ motion_field_strength_x,
308
+ motion_field_strength_y,
309
+ frame_ids,
310
+ video_length,
311
+ latents,
312
+ ):
313
+ reference_flow = jnp.zeros((video_length - 1, 2, 512, 512), dtype=latents.dtype)
314
  for fr_idx, frame_id in enumerate(frame_ids):
315
+ reference_flow = reference_flow.at[fr_idx, 0, :, :].set(
316
+ motion_field_strength_x * (frame_id)
317
+ )
318
+ reference_flow = reference_flow.at[fr_idx, 1, :, :].set(
319
+ motion_field_strength_y * (frame_id)
320
+ )
321
  return reference_flow
322
+
323
+ def create_motion_field_and_warp_latents(
324
+ self,
325
+ motion_field_strength_x,
326
+ motion_field_strength_y,
327
+ frame_ids,
328
+ video_length,
329
+ latents,
330
+ ):
331
+ motion_field = self.create_motion_field(
332
+ motion_field_strength_x=motion_field_strength_x,
333
+ motion_field_strength_y=motion_field_strength_y,
334
+ latents=latents,
335
+ video_length=video_length,
336
+ frame_ids=frame_ids,
337
+ )
338
  for idx, latent in enumerate(latents):
339
+ latents = latents.at[idx].set(
340
+ self.warp_latents_independently(latent[None], motion_field)[0]
341
+ )
342
  return motion_field, latents
343
+
344
+ def text_to_video_zero(
345
+ self,
346
+ params,
347
+ prng,
348
+ text_embeddings,
349
+ video_length: Optional[int],
350
+ do_classifier_free_guidance=True,
351
+ height: Optional[int] = None,
352
+ width: Optional[int] = None,
353
+ num_inference_steps: int = 50,
354
+ guidance_scale: float = 7.5,
355
+ num_videos_per_prompt: Optional[int] = 1,
356
+ xT=None,
357
+ smooth_bg_strength: float = 0.0,
358
+ motion_field_strength_x: float = 12,
359
+ motion_field_strength_y: float = 12,
360
+ t0: int = 44,
361
+ t1: int = 47,
362
+ controlnet_image=None,
363
+ controlnet_conditioning_scale=0,
364
+ ):
365
  frame_ids = list(range(video_length))
366
  # Prepare timesteps
367
+ params["scheduler"] = self.scheduler.set_timesteps(
368
+ params["scheduler"], num_inference_steps
369
+ )
370
  timesteps = params["scheduler"].timesteps
371
  # Prepare latent variables
372
  num_channels_latents = self.unet.in_channels
373
  batch_size = 1
374
+ xT = prepare_latents(
375
+ params,
376
+ prng,
377
+ batch_size * num_videos_per_prompt,
378
+ num_channels_latents,
379
+ height,
380
+ width,
381
+ self.vae_scale_factor,
382
+ xT,
383
+ )
384
+
385
+ timesteps_ddpm = [
386
+ 981,
387
+ 961,
388
+ 941,
389
+ 921,
390
+ 901,
391
+ 881,
392
+ 861,
393
+ 841,
394
+ 821,
395
+ 801,
396
+ 781,
397
+ 761,
398
+ 741,
399
+ 721,
400
+ 701,
401
+ 681,
402
+ 661,
403
+ 641,
404
+ 621,
405
+ 601,
406
+ 581,
407
+ 561,
408
+ 541,
409
+ 521,
410
+ 501,
411
+ 481,
412
+ 461,
413
+ 441,
414
+ 421,
415
+ 401,
416
+ 381,
417
+ 361,
418
+ 341,
419
+ 321,
420
+ 301,
421
+ 281,
422
+ 261,
423
+ 241,
424
+ 221,
425
+ 201,
426
+ 181,
427
+ 161,
428
+ 141,
429
+ 121,
430
+ 101,
431
+ 81,
432
+ 61,
433
+ 41,
434
+ 21,
435
+ 1,
436
+ ]
437
  timesteps_ddpm.reverse()
438
  t0 = timesteps_ddpm[t0]
439
  t1 = timesteps_ddpm[t1]
440
  x_t1_1 = None
441
 
442
  # Denoising loop
443
+ shape = (
444
+ batch_size,
445
+ num_channels_latents,
446
+ 1,
447
+ height // self.vae.scaling_factor,
448
+ width // self.vae.scaling_factor,
449
+ )
450
 
451
  # perform ∆t backward steps by stable diffusion
452
+ ddim_res = self.DDIM_backward(
453
+ params,
454
+ num_inference_steps=num_inference_steps,
455
+ timesteps=timesteps,
456
+ skip_t=1000,
457
+ t0=t0,
458
+ t1=t1,
459
+ do_classifier_free_guidance=do_classifier_free_guidance,
460
+ text_embeddings=text_embeddings,
461
+ latents_local=xT,
462
+ guidance_scale=guidance_scale,
463
+ controlnet_image=jnp.stack([controlnet_image[0]] * 2),
464
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
465
+ )
466
  x0 = ddim_res["x0"]
467
 
468
  # apply warping functions
 
470
  x_t0_1 = ddim_res["x_t0_1"]
471
  if "x_t1_1" in ddim_res:
472
  x_t1_1 = ddim_res["x_t1_1"]
473
+ x_t0_k = x_t0_1[:, :, :1, :, :].repeat(video_length - 1, 2)
474
  reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(
475
+ motion_field_strength_x=motion_field_strength_x,
476
+ motion_field_strength_y=motion_field_strength_y,
477
+ latents=x_t0_k,
478
+ video_length=video_length,
479
+ frame_ids=frame_ids[1:],
480
+ )
481
  # assuming t0=t1=1000, if t0 = 1000
482
 
483
  # DDPM forward for more motion freedom
484
+ ddpm_fwd = partial(
485
+ self.DDPM_forward,
486
+ params=params,
487
+ prng=prng,
488
+ x0=x_t0_k,
489
+ t0=t0,
490
+ tMax=t1,
491
+ shape=shape,
492
+ text_embeddings=text_embeddings,
493
  )
494
+ x_t1_k = jax.lax.cond(t1 > t0, ddpm_fwd, lambda: x_t0_k)
495
+ x_t1 = jnp.concatenate([x_t1_1, x_t1_k], axis=2)
496
 
497
  # backward stepts by stable diffusion
498
 
499
+ # warp the controlnet image following the same flow defined for latent
500
  controlnet_video = controlnet_image[:video_length]
501
+ controlnet_video = controlnet_video.at[1:].set(
502
+ self.warp_vid_independently(controlnet_video[1:], reference_flow)
503
+ )
504
+ controlnet_image = jnp.concatenate([controlnet_video] * 2)
505
+ smooth_bg = True
506
+
507
+ if smooth_bg:
508
+ # latent shape: "b c f h w"
509
+ M_FG = repeat(
510
+ get_mask_pose(controlnet_video),
511
+ "f h w -> b c f h w",
512
+ c=x_t1.shape[1],
513
+ b=batch_size,
514
+ )
515
+ initial_bg = repeat(
516
+ x_t1[:, :, 0] * (1 - M_FG[:, :, 0]),
517
+ "b c h w -> b c f h w",
518
+ f=video_length - 1,
519
+ )
520
+ # warp the controlnet image following the same flow defined for latent #f c h w
521
+ initial_bg_warped = self.warp_latents_independently(
522
+ initial_bg, reference_flow
523
+ )
524
+ bgs = x_t1[:, :, 1:] * (1 - M_FG[:, :, 1:]) # initial background
525
+ initial_mask_warped = 1 - self.warp_latents_independently(
526
+ repeat(M_FG[:, :, 0], "b c h w -> b c f h w", f=video_length - 1),
527
+ reference_flow,
528
+ )
529
+ # initial_mask_warped = 1 - warp_vid_independently(repeat(M_FG[:,:,0], "b c h w -> (b f) c h w", f = video_length-1), reference_flow)
530
+ # initial_mask_warped = rearrange(initial_mask_warped, "(b f) c h w -> b c f h w", b=batch_size)
531
+ mask = (1 - M_FG[:, :, 1:]) * initial_mask_warped
532
+ x_t1 = x_t1.at[:, :, 1:].set(
533
+ (1 - mask) * x_t1[:, :, 1:]
534
+ + mask
535
+ * (
536
+ initial_bg_warped * smooth_bg_strength
537
+ + (1 - smooth_bg_strength) * bgs
538
+ )
539
+ )
540
 
541
+ ddim_res = self.DDIM_backward(
542
+ params,
543
+ num_inference_steps=num_inference_steps,
544
+ timesteps=timesteps,
545
+ skip_t=t1,
546
+ t0=-1,
547
+ t1=-1,
548
+ do_classifier_free_guidance=do_classifier_free_guidance,
549
+ text_embeddings=text_embeddings,
550
+ latents_local=x_t1,
551
+ guidance_scale=guidance_scale,
552
+ controlnet_image=controlnet_image,
553
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
554
+ )
555
556
  x0 = ddim_res["x0"]
557
+ del ddim_res
558
+ del x_t1
559
+ del x_t1_1
560
+ del x_t1_k
561
  return x0
562
 
563
+ def denoise_latent(
564
+ self,
565
+ params,
566
+ num_inference_steps,
567
+ timesteps,
568
+ do_classifier_free_guidance,
569
+ text_embeddings,
570
+ latents,
571
+ guidance_scale,
572
+ controlnet_image=None,
573
+ controlnet_conditioning_scale=None,
574
+ ):
575
+ scheduler_state = self.scheduler.set_timesteps(
576
+ params["scheduler"], num_inference_steps
577
+ )
578
+ # f = latents_local.shape[2]
579
+ # latents_local = rearrange(latents_local, "b c f h w -> (b f) c h w")
580
+
581
+ max_timestep = len(timesteps) - 1
582
+ timesteps = jnp.array(timesteps)
583
+
584
+ def while_body(args):
585
+ step, latents, scheduler_state = args
586
+ t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
587
+ latent_model_input = (
588
+ jnp.concatenate([latents] * 2)
589
+ if do_classifier_free_guidance
590
+ else latents
591
+ )
592
+ latent_model_input = self.scheduler.scale_model_input(
593
+ scheduler_state, latent_model_input, timestep=t
594
+ )
595
+ f = latents.shape[0]
596
+ te = jnp.stack(
597
+ [text_embeddings[0, :, :]] * f + [text_embeddings[-1, :, :]] * f
598
+ )
599
+ timestep = jnp.broadcast_to(t, latent_model_input.shape[0])
600
+ if controlnet_image is not None:
601
+ down_block_res_samples, mid_block_res_sample = self.controlnet.apply(
602
+ {"params": params["controlnet"]},
603
+ jnp.array(latent_model_input),
604
+ jnp.array(timestep, dtype=jnp.int32),
605
+ encoder_hidden_states=te,
606
+ controlnet_cond=controlnet_image,
607
+ conditioning_scale=controlnet_conditioning_scale,
608
+ return_dict=False,
609
+ )
610
+ # predict the noise residual
611
+ noise_pred = self.unet_vanilla.apply(
612
+ {"params": params["unet"]},
613
+ jnp.array(latent_model_input),
614
+ jnp.array(timestep, dtype=jnp.int32),
615
+ encoder_hidden_states=te,
616
+ down_block_additional_residuals=down_block_res_samples,
617
+ mid_block_additional_residual=mid_block_res_sample,
618
+ ).sample
619
+ else:
620
+ noise_pred = self.unet_vanilla.apply(
621
+ {"params": params["unet"]},
622
+ jnp.array(latent_model_input),
623
+ jnp.array(timestep, dtype=jnp.int32),
624
+ encoder_hidden_states=te,
625
+ ).sample
626
+ # perform guidance
627
+ if do_classifier_free_guidance:
628
+ noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0)
629
+ noise_pred = noise_pred_uncond + guidance_scale * (
630
+ noise_pred_text - noise_pred_uncond
631
+ )
632
+ # compute the previous noisy sample x_t -> x_t-1
633
+ latents, scheduler_state = self.scheduler.step(
634
+ scheduler_state, noise_pred, t, latents
635
+ ).to_tuple()
636
+ return (step + 1, latents, scheduler_state)
637
+
638
+ def cond_fun(arg):
639
+ step, latents, scheduler_state = arg
640
+ return step < num_inference_steps
641
+
642
+ if DEBUG:
643
+ step = 0
644
+ while cond_fun((step, latents, scheduler_state)):
645
+ step, latents, scheduler_state = while_body(
646
+ (step, latents, scheduler_state)
647
+ )
648
+ step = step + 1
649
+ else:
650
+ _, latents, scheduler_state = jax.lax.while_loop(
651
+ cond_fun, while_body, (0, latents, scheduler_state)
652
+ )
653
+ # latents = rearrange(latents, "(b f) c h w -> b c f h w", f=f)
654
+ return latents
655
+
656
+ @partial(jax.jit, static_argnums=(0, 1))
657
+ def _generate_starting_frames(
658
+ self,
659
+ num_inference_steps,
660
+ params,
661
+ timesteps,
662
+ text_embeddings,
663
+ latents,
664
+ guidance_scale,
665
+ controlnet_image,
666
+ controlnet_conditioning_scale,
667
+ ):
668
+ # perform ∆t backward steps by stable diffusion
669
+ # delta_t_diffusion = jax.vmap(lambda latent : self.DDIM_backward(params, num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=1000, t0=t0, t1=t1, do_classifier_free_guidance=do_classifier_free_guidance,
670
+ # text_embeddings=text_embeddings, latents_local=latent, guidance_scale=guidance_scale,
671
+ # controlnet_image=controlnet_image, controlnet_conditioning_scale=controlnet_conditioning_scale))
672
+ # ddim_res = delta_t_diffusion(latents)
673
+ # latents = ddim_res["x0"] #output is i b c f h w
674
+
675
+ # DDPM forward for more motion freedom
676
+ # ddpm_fwd = jax.vmap(lambda prng, latent: self.DDPM_forward(params=params, prng=prng, x0=latent, t0=t0,
677
+ # tMax=t1, shape=shape, text_embeddings=text_embeddings))
678
+ # latents = ddpm_fwd(stacked_prngs, latents)
679
+ # main backward diffusion
680
+ # denoise_first_frame = lambda latent : self.DDIM_backward(params, num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=100000, t0=-1, t1=-1, do_classifier_free_guidance=do_classifier_free_guidance,
681
+ # text_embeddings=text_embeddings, latents_local=latent, guidance_scale=guidance_scale,
682
+ # controlnet_image=controlnet_image, controlnet_conditioning_scale=controlnet_conditioning_scale, use_vanilla=True)
683
+ # latents = rearrange(latents, 'i b c f h w -> (i b) c f h w')
684
+ # ddim_res = denoise_first_frame(latents)
685
+ latents = self.denoise_latent(
686
+ params,
687
+ num_inference_steps=num_inference_steps,
688
+ timesteps=timesteps,
689
+ do_classifier_free_guidance=True,
690
+ text_embeddings=text_embeddings,
691
+ latents=latents,
692
+ guidance_scale=guidance_scale,
693
+ controlnet_image=controlnet_image,
694
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
695
+ )
696
+ # latents = rearrange(ddim_res["x0"], 'i b c f h w -> (i b) c f h w') #output is i b c f h w
697
+
698
+ # scale and decode the image latents with vae
699
+ latents = 1 / self.vae.config.scaling_factor * latents
700
+ # latents = rearrange(latents, "b c h w -> (b f) c h w")
701
+ imgs = self.vae.apply(
702
+ {"params": params["vae"]}, latents, method=self.vae.decode
703
+ ).sample
704
+ imgs = (imgs / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
705
+ return imgs
706
+
707
+ def generate_starting_frames(
708
+ self,
709
+ params,
710
+ prngs: list, # list of prngs for each img
711
+ prompt,
712
+ neg_prompt,
713
+ controlnet_image,
714
+ do_classifier_free_guidance=True,
715
+ num_inference_steps: int = 50,
716
+ guidance_scale: float = 7.5,
717
+ t0: int = 44,
718
+ t1: int = 47,
719
+ controlnet_conditioning_scale=1.0,
720
+ ):
721
+ height, width = controlnet_image.shape[-2:]
722
+ if height % 64 != 0 or width % 64 != 0:
723
+ raise ValueError(
724
+ f"`height` and `width` have to be divisible by 64 but are {height} and {width}."
725
+ )
726
+
727
+ shape = (
728
+ self.unet.in_channels,
729
+ height // self.vae_scale_factor,
730
+ width // self.vae_scale_factor,
731
+ ) # b c h w
732
+ # scale the initial noise by the standard deviation required by the scheduler
733
+
734
+ print(
735
+ f"Generating {len(prngs)} first frames with prompt {prompt}, for {num_inference_steps} steps. PRNG seeds are: {prngs}"
736
+ )
737
+
738
+ latents = jnp.stack(
739
+ [jax.random.normal(prng, shape) for prng in prngs]
740
+ ) # b c h w
741
+ latents = latents * params["scheduler"].init_noise_sigma
742
+
743
+ timesteps = params["scheduler"].timesteps
744
+ timesteps_ddpm = [
745
+ 981,
746
+ 961,
747
+ 941,
748
+ 921,
749
+ 901,
750
+ 881,
751
+ 861,
752
+ 841,
753
+ 821,
754
+ 801,
755
+ 781,
756
+ 761,
757
+ 741,
758
+ 721,
759
+ 701,
760
+ 681,
761
+ 661,
762
+ 641,
763
+ 621,
764
+ 601,
765
+ 581,
766
+ 561,
767
+ 541,
768
+ 521,
769
+ 501,
770
+ 481,
771
+ 461,
772
+ 441,
773
+ 421,
774
+ 401,
775
+ 381,
776
+ 361,
777
+ 341,
778
+ 321,
779
+ 301,
780
+ 281,
781
+ 261,
782
+ 241,
783
+ 221,
784
+ 201,
785
+ 181,
786
+ 161,
787
+ 141,
788
+ 121,
789
+ 101,
790
+ 81,
791
+ 61,
792
+ 41,
793
+ 21,
794
+ 1,
795
+ ]
796
+ timesteps_ddpm.reverse()
797
+ t0 = timesteps_ddpm[t0]
798
+ t1 = timesteps_ddpm[t1]
799
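The 50-entry table above, once the reverse() below has been applied, is simply the ascending DDPM timesteps 1, 21, 41, ..., 981; an equivalent, more compact form would be:

# Equivalent to the hard-coded table above after it is reversed into
# ascending order: 50 values from 1 to 981 with a stride of 20.
timesteps_ddpm = list(range(1, 1000, 20))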
+
800
+ # get prompt text embeddings
801
+ prompt_ids = self.prepare_text_inputs(prompt)
802
+ prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
803
+
804
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
805
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
806
+ batch_size = 1
807
+ max_length = prompt_ids.shape[-1]
808
+ if neg_prompt is None:
809
+ uncond_input = self.tokenizer(
810
+ [""] * batch_size,
811
+ padding="max_length",
812
+ max_length=max_length,
813
+ return_tensors="np",
814
+ ).input_ids
815
+ else:
816
+ neg_prompt_ids = self.prepare_text_inputs(neg_prompt)
817
+ uncond_input = neg_prompt_ids
818
+
819
+ negative_prompt_embeds = self.text_encoder(
820
+ uncond_input, params=params["text_encoder"]
821
+ )[0]
822
+ text_embeddings = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
823
+ controlnet_image = jnp.stack([controlnet_image[0]] * 2 * len(prngs))
824
+ return self._generate_starting_frames(
825
+ num_inference_steps,
826
+ params,
827
+ timesteps,
828
+ text_embeddings,
829
+ latents,
830
+ guidance_scale,
831
+ controlnet_image,
832
+ controlnet_conditioning_scale,
833
+ )
834
+
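The denoise_latent helper called above is not part of this hunk. Assuming it follows the standard classifier-free-guidance recipe implied by the [negative, positive] embedding concatenation and the duplicated ControlNet image, the core of each step would look roughly like the sketch below (the helper name and signature here are illustrative only, not the pipeline's actual code):

import jax.numpy as jnp

def apply_classifier_free_guidance(noise_pred, guidance_scale):
    # noise_pred comes from a UNet call on latents duplicated along the batch
    # axis, in the same [unconditional, text-conditioned] order used for
    # text_embeddings above.
    noise_uncond, noise_text = jnp.split(noise_pred, 2, axis=0)
    # Standard classifier-free guidance: push the prediction away from the
    # unconditional branch and toward the text-conditioned one.
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)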
835
+ def generate_video(
836
+ self,
837
+ prompt: str,
838
+ image: jnp.array,
839
+ params: Union[Dict, FrozenDict],
840
+ prng_seed: jax.random.KeyArray,
841
+ num_inference_steps: int = 50,
842
+ guidance_scale: Union[float, jnp.array] = 7.5,
843
+ latents: jnp.array = None,
844
+ neg_prompt: str = "",
845
+ controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
846
+ return_dict: bool = True,
847
+ jit: bool = False,
848
+ xT=None,
849
+ smooth_bg_strength: float = 0.0,
850
+ motion_field_strength_x: float = 3,
851
+ motion_field_strength_y: float = 4,
852
+ t0: int = 44,
853
+ t1: int = 47,
854
+ ):
855
+ r"""
856
+ Function invoked when calling the pipeline for generation.
857
+ Args:
858
+ prompt (`str`):
860
+ The prompt to guide the image generation.
860
+ image (`jnp.array`):
861
+ Array representing the ControlNet input condition. ControlNet uses this input condition to generate
862
+ guidance for the UNet.
863
+ params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights
864
+ prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key
865
+ num_inference_steps (`int`, *optional*, defaults to 50):
866
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
867
+ expense of slower inference.
868
+ guidance_scale (`float`, *optional*, defaults to 7.5):
869
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
870
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
871
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
872
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
873
+ usually at the expense of lower image quality.
874
+ latents (`jnp.array`, *optional*):
875
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
876
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
877
+ tensor will be generated by sampling with the supplied random `prng_seed`.
878
+ controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0):
879
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
880
+ to the residual in the original unet.
881
+ return_dict (`bool`, *optional*, defaults to `True`):
882
+ Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
883
+ a plain tuple.
884
+ jit (`bool`, defaults to `False`):
885
+ Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
886
+ exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
887
+ Examples:
888
+ Returns:
889
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
890
+ [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
891
+ `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
892
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
893
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
894
+ """
895
+ height, width = image.shape[-2:]
896
+ vid_length = image.shape[0]
897
+ # get prompt text embeddings
898
+ prompt_ids = self.prepare_text_inputs([prompt] * vid_length)
899
+ neg_prompt_ids = self.prepare_text_inputs([neg_prompt] * vid_length)
900
+
901
+ # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
902
+ # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
903
+ batch_size = 1
904
+
905
+ if isinstance(guidance_scale, float):
906
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
907
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
908
+ guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
909
+ if len(prompt_ids.shape) > 2:
910
+ # Assume sharded
911
+ guidance_scale = guidance_scale[:, None]
912
+ if isinstance(controlnet_conditioning_scale, float):
913
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
914
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
915
+ controlnet_conditioning_scale = jnp.array(
916
+ [controlnet_conditioning_scale] * prompt_ids.shape[0]
917
+ )
918
+ if len(prompt_ids.shape) > 2:
919
+ # Assume sharded
920
+ controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
921
+ if jit:
922
+ images = _p_generate(
923
+ self,
924
+ replicate_devices(prompt_ids),
925
+ replicate_devices(image),
926
+ jax_utils.replicate(params),
927
+ replicate_devices(prng_seed),
928
+ num_inference_steps,
929
+ replicate_devices(guidance_scale),
930
+ replicate_devices(latents) if latents is not None else None,
931
+ replicate_devices(neg_prompt_ids)
932
+ if neg_prompt_ids is not None
933
+ else None,
934
+ replicate_devices(controlnet_conditioning_scale),
935
+ replicate_devices(xT) if xT is not None else None,
936
+ replicate_devices(smooth_bg_strength),
937
+ replicate_devices(motion_field_strength_x),
938
+ replicate_devices(motion_field_strength_y),
939
+ t0,
940
+ t1,
941
+ )
942
+ else:
943
+ images = self._generate(
944
+ prompt_ids,
945
+ image,
946
+ params,
947
+ prng_seed,
948
+ num_inference_steps,
949
+ guidance_scale,
950
+ latents,
951
+ neg_prompt_ids,
952
+ controlnet_conditioning_scale,
953
+ xT,
954
+ smooth_bg_strength,
955
+ motion_field_strength_x,
956
+ motion_field_strength_y,
957
+ t0,
958
+ t1,
959
+ )
960
+ if self.safety_checker is not None:
961
+ safety_params = params["safety_checker"]
962
+ images_uint8_casted = (images * 255).round().astype("uint8")
963
+ num_devices, batch_size = images.shape[:2]
964
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(
965
+ num_devices * batch_size, height, width, 3
966
+ )
967
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(
968
+ images_uint8_casted, safety_params, jit
969
+ )
970
+ images = np.asarray(images)
971
+ # block images
972
+ if any(has_nsfw_concept):
973
+ for i, is_nsfw in enumerate(has_nsfw_concept):
974
+ if is_nsfw:
975
+ images[i] = np.asarray(images_uint8_casted[i])
976
+ images = images.reshape(num_devices, batch_size, height, width, 3)
977
+ else:
978
+ images = np.asarray(images)
979
+ has_nsfw_concept = False
980
+ if not return_dict:
981
+ return (images, has_nsfw_concept)
982
+ return FlaxStableDiffusionPipelineOutput(
983
+ images=images, nsfw_content_detected=has_nsfw_concept
984
+ )
985
+
986
  def prepare_text_inputs(self, prompt: Union[str, List[str]]):
987
  if not isinstance(prompt, (str, list)):
988
+ raise ValueError(
989
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
990
+ )
991
  text_input = self.tokenizer(
992
  prompt,
993
  padding="max_length",
 
996
  return_tensors="np",
997
  )
998
  return text_input.input_ids
999
+
1000
  def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]):
1001
  if not isinstance(image, (Image.Image, list)):
1002
+ raise ValueError(
1003
+ f"image has to be of type `PIL.Image.Image` or list but is {type(image)}"
1004
+ )
1005
  if isinstance(image, Image.Image):
1006
  image = [image]
1007
+ processed_images = jnp.concatenate(
1008
+ [preprocess(img, jnp.float32) for img in image]
1009
+ )
1010
  return processed_images
1011
+
1012
  def _get_has_nsfw_concepts(self, features, params):
1013
  has_nsfw_concepts = self.safety_checker(features, params)
1014
  return has_nsfw_concepts
1015
+
1016
  def _run_safety_checker(self, images, safety_model_params, jit=False):
1017
  # safety_model_params should already be replicated when jit is True
1018
  pil_images = [Image.fromarray(image) for image in images]
1019
  features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
1020
  if jit:
1021
  features = shard(features)
1022
+ has_nsfw_concepts = _p_get_has_nsfw_concepts(
1023
+ self, features, safety_model_params
1024
+ )
1025
  has_nsfw_concepts = unshard(has_nsfw_concepts)
1026
  safety_model_params = unreplicate(safety_model_params)
1027
  else:
1028
+ has_nsfw_concepts = self._get_has_nsfw_concepts(
1029
+ features, safety_model_params
1030
+ )
1031
  images_was_copied = False
1032
  for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
1033
  if has_nsfw_concept:
 
1041
  " instead. Try again with a different prompt and/or seed."
1042
  )
1043
  return images, has_nsfw_concepts
1044
+
1045
  def _generate(
1046
  self,
1047
  prompt_ids: jnp.array,
 
1053
  latents: Optional[jnp.array] = None,
1054
  neg_prompt_ids: Optional[jnp.array] = None,
1055
  controlnet_conditioning_scale: float = 1.0,
1056
+ xT=None,
1057
+ smooth_bg_strength: float = 0.0,
1058
  motion_field_strength_x: float = 12,
1059
  motion_field_strength_y: float = 12,
1060
  t0: int = 44,
 
1063
  height, width = image.shape[-2:]
1064
  video_length = image.shape[0]
1065
  if height % 64 != 0 or width % 64 != 0:
1066
+ raise ValueError(
1067
+ f"`height` and `width` have to be divisible by 64 but are {height} and {width}."
1068
+ )
1069
  # get prompt text embeddings
1070
  prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]
1071
  # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
 
1074
  max_length = prompt_ids.shape[-1]
1075
  if neg_prompt_ids is None:
1076
  uncond_input = self.tokenizer(
1077
+ [""] * batch_size,
1078
+ padding="max_length",
1079
+ max_length=max_length,
1080
+ return_tensors="np",
1081
  ).input_ids
1082
  else:
1083
  uncond_input = neg_prompt_ids
1084
+ negative_prompt_embeds = self.text_encoder(
1085
+ uncond_input, params=params["text_encoder"]
1086
+ )[0]
1087
  context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
1088
  image = jnp.concatenate([image] * 2)
1089
  seed_t2vz, prng_seed = jax.random.split(prng_seed)
1090
+ # get the latents following the Text2Video-Zero procedure
1091
+ latents = self.text_to_video_zero(
1092
+ params,
1093
+ seed_t2vz,
1094
+ text_embeddings=context,
1095
+ video_length=video_length,
1096
+ height=height,
1097
+ width=width,
1098
+ num_inference_steps=num_inference_steps,
1099
+ guidance_scale=guidance_scale,
1100
+ controlnet_image=image,
1101
+ xT=xT,
1102
+ smooth_bg_strength=smooth_bg_strength,
1103
+ t0=t0,
1104
+ t1=t1,
1105
+ motion_field_strength_x=motion_field_strength_x,
1106
+ motion_field_strength_y=motion_field_strength_y,
1107
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
1108
+ )
1109
  # scale and decode the image latents with vae
1110
  latents = 1 / self.vae.config.scaling_factor * latents
1111
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
1112
+ video = self.vae.apply(
1113
+ {"params": params["vae"]}, latents, method=self.vae.decode
1114
+ ).sample
1115
  video = (video / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
1116
  return video
1117
+
1118
  @replace_example_docstring(EXAMPLE_DOC_STRING)
1119
  def __call__(
1120
  self,
 
1129
  controlnet_conditioning_scale: Union[float, jnp.array] = 1.0,
1130
  return_dict: bool = True,
1131
  jit: bool = False,
1132
+ xT=None,
1133
+ smooth_bg_strength: float = 0.0,
1134
  motion_field_strength_x: float = 3,
1135
  motion_field_strength_y: float = 4,
1136
  t0: int = 44,
 
1187
  if isinstance(controlnet_conditioning_scale, float):
1188
  # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
1189
  # shape information, as they may be sharded (when `jit` is `True`), or not.
1190
+ controlnet_conditioning_scale = jnp.array(
1191
+ [controlnet_conditioning_scale] * prompt_ids.shape[0]
1192
+ )
1193
  if len(prompt_ids.shape) > 2:
1194
  # Assume sharded
1195
  controlnet_conditioning_scale = controlnet_conditioning_scale[:, None]
 
1206
  neg_prompt_ids,
1207
  controlnet_conditioning_scale,
1208
  xT,
1209
+ smooth_bg_strength,
1210
  motion_field_strength_x,
1211
  motion_field_strength_y,
1212
  t0,
 
1224
  neg_prompt_ids,
1225
  controlnet_conditioning_scale,
1226
  xT,
1227
+ smooth_bg_strength,
1228
  motion_field_strength_x,
1229
  motion_field_strength_y,
1230
  t0,
 
1234
  safety_params = params["safety_checker"]
1235
  images_uint8_casted = (images * 255).round().astype("uint8")
1236
  num_devices, batch_size = images.shape[:2]
1237
+ images_uint8_casted = np.asarray(images_uint8_casted).reshape(
1238
+ num_devices * batch_size, height, width, 3
1239
+ )
1240
+ images_uint8_casted, has_nsfw_concept = self._run_safety_checker(
1241
+ images_uint8_casted, safety_params, jit
1242
+ )
1243
  images = np.asarray(images)
1244
  # block images
1245
  if any(has_nsfw_concept):
 
1252
  has_nsfw_concept = False
1253
  if not return_dict:
1254
  return (images, has_nsfw_concept)
1255
+ return FlaxStableDiffusionPipelineOutput(
1256
+ images=images, nsfw_content_detected=has_nsfw_concept
1257
+ )
1258
+
1259
+
1260
  # Static argnums are pipe, num_inference_steps. A change would trigger recompilation.
1261
  # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
1262
  @partial(
1263
  jax.pmap,
1264
+ in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0, 0, 0, 0, 0, None, None),
1265
+ static_broadcasted_argnums=(0, 5, 14, 15),
1266
  )
1267
  def _p_generate(
1268
  pipe,
1269
+ prompt_ids,
1270
  image,
1271
  params,
1272
  prng_seed,
 
1276
  neg_prompt_ids,
1277
  controlnet_conditioning_scale,
1278
  xT,
1279
+ smooth_bg_strength,
1280
  motion_field_strength_x,
1281
  motion_field_strength_y,
1282
  t0,
 
1293
  neg_prompt_ids,
1294
  controlnet_conditioning_scale,
1295
  xT,
1296
+ smooth_bg_strength,
1297
  motion_field_strength_x,
1298
  motion_field_strength_y,
1299
  t0,
1300
  t1,
1301
  )
1302
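For context, the jit=True path above follows the usual Flax multi-device pattern: jax_utils.replicate copies the parameters to every local device, per-example inputs are stacked or shard-ed along a leading device axis, and the pmap-ed function maps over that axis while the static arguments trigger recompilation when they change. A minimal, self-contained sketch of that pattern (toy shapes, unrelated to this pipeline's actual signatures):

import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard

n_dev = jax.local_device_count()
params = {"w": jnp.ones((4,))}                      # toy parameters
batch = jnp.arange(float(n_dev * 2 * 4)).reshape(n_dev * 2, 4)

replicated_params = jax_utils.replicate(params)     # adds a leading device axis
sharded_batch = shard(batch)                        # (n_dev, per_device_batch, 4)

@jax.pmap
def apply(p, x):
    return x * p["w"]                               # runs once per device

out = apply(replicated_params, sharded_batch)       # (n_dev, per_device_batch, 4)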
+
1303
+
1304
  @partial(jax.pmap, static_broadcasted_argnums=(0,))
1305
  def _p_get_has_nsfw_concepts(pipe, features, params):
1306
  return pipe._get_has_nsfw_concepts(features, params)
1307
+
1308
+
1309
  def unshard(x: jnp.ndarray):
1310
  # einops.rearrange(x, 'd b ... -> (d b) ...')
1311
  num_devices, batch_size = x.shape[:2]
1312
  rest = x.shape[2:]
1313
  return x.reshape(num_devices * batch_size, *rest)
1314
+
1315
+
1316
  def preprocess(image, dtype):
1317
  image = image.convert("RGB")
1318
  w, h = image.size
 
1322
  image = image[None].transpose(0, 3, 1, 2)
1323
  return image
1324
 
1325
+
1326
+ def prepare_latents(
1327
+ params,
1328
+ prng,
1329
+ batch_size,
1330
+ num_channels_latents,
1331
+ height,
1332
+ width,
1333
+ vae_scale_factor,
1334
+ latents=None,
1335
+ ):
1336
+ shape = (
1337
+ batch_size,
1338
+ num_channels_latents,
1339
+ 1,
1340
+ height // vae_scale_factor,
1341
+ width // vae_scale_factor,
1342
+ ) # b c f h w
1343
  # scale the initial noise by the standard deviation required by the scheduler
1344
  if latents is None:
1345
  latents = jax.random.normal(prng, shape)
1346
  latents = latents * params["scheduler"].init_noise_sigma
1347
  return latents
1348
 
1349
+
1350
  def coords_grid(batch, ht, wd):
1351
  coords = jnp.meshgrid(jnp.arange(ht), jnp.arange(wd), indexing="ij")
1352
  coords = jnp.stack(coords[::-1], axis=0)
1353
  return coords[None].repeat(batch, 0)
1354
 
1355
+
1356
  def adapt_pos_mirror(x, y, W, H):
1357
+ # adapt the position, with mirror padding
1358
+ x_w_mirror = ((x + W - 1) % (2 * (W - 1))) - W + 1
1359
+ x_adapted = jnp.where(x_w_mirror > 0, x_w_mirror, -(x_w_mirror))
1360
+ y_w_mirror = ((y + H - 1) % (2 * (H - 1))) - H + 1
1361
+ y_adapted = jnp.where(y_w_mirror > 0, y_w_mirror, -(y_w_mirror))
1362
+ return y_adapted, x_adapted
1363
+
1364
+
1365
+ def safe_get_zeropad(img, x, y, W, H):
1366
+ return jnp.where((x < W) & (x > 0) & (y < H) & (y > 0), img[y, x], 0.0)
1367
+
1368
 
1369
+ def safe_get_mirror(img, x, y, W, H):
1370
+ return img[adapt_pos_mirror(x, y, W, H)]
1371
 
 
 
1372
 
1373
  @partial(jax.vmap, in_axes=(0, 0, None))
1374
  @partial(jax.vmap, in_axes=(0, None, None))
1375
+ @partial(jax.vmap, in_axes=(None, 0, None))
1376
  @partial(jax.vmap, in_axes=(None, 0, None))
1377
  def grid_sample(latents, grid, method):
1378
  # this is an alternative to torch.functional.nn.grid_sample in jax
1379
  # this implementation is following the algorithm described @ https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
1380
  # but with coordinates scaled to the size of the image
1381
  if method == "mirror":
1382
+ return safe_get_mirror(
1383
+ latents,
1384
+ jnp.array(grid[0], dtype=jnp.int16),
1385
+ jnp.array(grid[1], dtype=jnp.int16),
1386
+ latents.shape[0],
1387
+ latents.shape[1],
1388
+ )
1389
+ else: # default is zero padding
1390
+ return safe_get_zeropad(
1391
+ latents,
1392
+ jnp.array(grid[0], dtype=jnp.int16),
1393
+ jnp.array(grid[1], dtype=jnp.int16),
1394
+ latents.shape[0],
1395
+ latents.shape[1],
1396
+ )
1397
+
1398
+
1399
+ def bandw_vid(vid, threshold):
1400
+ vid = jnp.max(vid, axis=1)
1401
+ return jnp.where(vid > threshold, 1, 0)
1402
+
1403
+
1404
+ def mean_blur(vid, k):
1405
+ window = jnp.ones((vid.shape[0], k, k)) / (k * k)
1406
+ convolve = jax.vmap(
1407
+ lambda img, kernel: jax.scipy.signal.convolve(img, kernel, mode="same")
1408
+ )
1409
+ smooth_vid = convolve(vid, window)
1410
+ return smooth_vid
1411
+
1412
+
1413
+ def get_mask_pose(vid):
1414
+ vid = bandw_vid(vid, 0.4)
1415
+ l, h, w = vid.shape
1416
+ vid = jax.image.resize(vid, (l, h // 8, w // 8), "nearest")
1417
+ vid = bandw_vid(mean_blur(vid, 7)[:, None], threshold=0.01)
1418
+ return vid / (jnp.max(vid) + 1e-4)
1419
+ # return jax.image.resize(vid/(jnp.max(vid) + 1e-4), (l, h, w), "nearest")
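The mirror-padding helpers above (adapt_pos_mirror / safe_get_mirror, used by grid_sample when method == "mirror") reflect out-of-range grid coordinates back into the image rather than clamping or zero-filling them. A quick standalone check of that reflection rule (width 5, so valid x indices are 0-4):

import jax.numpy as jnp

def reflect(i, n):
    # Same arithmetic as adapt_pos_mirror, applied to one axis of size n.
    m = ((i + n - 1) % (2 * (n - 1))) - n + 1
    return jnp.where(m > 0, m, -m)

print(reflect(jnp.array([0, 4, 5, 6]), 5))  # [0 4 3 2]: 5 and 6 reflect back to 3 and 2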
webui/app_control_animation.py CHANGED
@@ -19,112 +19,46 @@ examples = [
19
  ]
20
 
21
 
22
- images = [] # str path of generated images
23
- initial_frame = None
24
- animation_model = None
25
 
26
 
27
- def generate_initial_frames(
28
- frames_prompt,
29
- model_link,
30
- is_safetensor,
31
- frames_n_prompt,
32
- width,
33
- height,
34
- cfg_scale,
35
- seed,
36
- ):
37
- global images
38
 
39
- if not model_link:
40
- model_link = "dreamlike-art/dreamlike-photoreal-2.0"
41
 
42
- images = animation_model.generate_initial_frames(
43
- frames_prompt,
44
- model_link,
45
- is_safetensor,
46
- frames_n_prompt,
47
- width,
48
- height,
49
- cfg_scale,
50
- seed,
51
- )
52
-
53
- return images
54
-
55
-
56
- def select_initial_frame(evt: gr.SelectData):
57
- global initial_frame
58
-
59
- if evt.index < len(images):
60
- initial_frame = images[evt.index]
61
- print(initial_frame)
62
 
63
 
64
  def create_demo(model: ControlAnimationModel):
65
- global animation_model
66
- animation_model = model
67
-
68
  with gr.Blocks() as demo:
69
- with gr.Column(visible=True) as frame_selection_col:
70
  with gr.Row():
71
  with gr.Column():
72
- frames_prompt = gr.Textbox(
73
- placeholder="Prompt", show_label=False, lines=4
74
  )
75
- frames_n_prompt = gr.Textbox(
76
  placeholder="Negative Prompt (optional)",
77
  show_label=False,
78
  lines=2,
79
  )
80
 
81
- with gr.Column():
82
- model_link = gr.Textbox(
83
- label="Model Link",
84
- placeholder="dreamlike-art/dreamlike-photoreal-2.0",
85
- info="Give the hugging face model name or URL link to safetensor.",
86
- )
87
- is_safetensor = gr.Checkbox(label="Safetensors")
88
  gen_frames_button = gr.Button(
89
  value="Generate Initial Frames", variant="primary"
90
  )
91
 
92
- with gr.Row():
93
- with gr.Column(scale=2):
94
- width = gr.Slider(32, 2048, value=512, label="Width")
95
- height = gr.Slider(32, 2048, value=512, label="Height")
96
- cfg_scale = gr.Slider(1, 20, value=7.0, step=0.1, label="CFG scale")
97
- seed = gr.Slider(
98
- label="Seed",
99
- info="-1 for random seed on each run. Otherwise, the seed will be fixed.",
100
- minimum=-1,
101
- maximum=65536,
102
- value=0,
103
- step=1,
104
- )
105
-
106
- with gr.Column(scale=3):
107
- initial_frames = gr.Gallery(
108
- label="Initial Frames", show_label=False
109
- ).style(columns=4, object_fit="contain")
110
- initial_frames.select(select_initial_frame)
111
- select_frame_button = gr.Button(
112
- value="Select Initial Frame", variant="secondary"
113
- )
114
-
115
- with gr.Column(visible=False) as gen_animation_col:
116
- with gr.Row():
117
- with gr.Column():
118
- prompt = gr.Textbox(label="Prompt")
119
- gen_animation_button = gr.Button(
120
- value="Generate Animation", variant="primary"
121
- )
122
-
123
  with gr.Accordion("Advanced options", open=False):
124
- n_prompt = gr.Textbox(
125
- label="Negative Prompt (optional)", value=""
126
- )
127
-
128
  if on_huggingspace:
129
  video_length = gr.Slider(
130
  label="Video length", minimum=8, maximum=16, step=1
@@ -197,68 +131,101 @@ def create_demo(model: ControlAnimationModel):
197
  )
198
 
199
  with gr.Column():
200
  result = gr.Video(label="Generated Video")
201
 
202
- inputs = [
203
  prompt,
 
 
204
  model_link,
205
- is_safetensor,
206
  motion_field_strength_x,
207
  motion_field_strength_y,
208
  t0,
209
  t1,
210
- n_prompt,
211
  chunk_size,
212
  video_length,
213
  merging_ratio,
214
  seed,
215
  ]
216
 
217
- # gr.Examples(examples=examples,
218
- # inputs=inputs,
219
- # outputs=result,
220
- # fn=None,
221
- # run_on_click=False,
222
- # cache_examples=on_huggingspace,
223
- # )
224
-
225
- frame_inputs = [
226
- frames_prompt,
227
- model_link,
228
- is_safetensor,
229
- frames_n_prompt,
230
- width,
231
- height,
232
- cfg_scale,
233
- seed,
234
- ]
235
-
236
- def submit_select():
237
- show = True
238
- if initial_frame is not None: # More to next step
239
  return {
240
- frame_selection_col: gr.update(visible=not show),
241
- gen_animation_col: gr.update(visible=show),
242
  }
243
 
244
  return {
245
- frame_selection_col: gr.update(visible=show),
246
- gen_animation_col: gr.update(visible=not show),
247
  }
248
 
249
  gen_frames_button.click(
250
- generate_initial_frames,
251
  inputs=frame_inputs,
252
  outputs=initial_frames,
253
  )
254
- select_frame_button.click(
255
- submit_select, inputs=None, outputs=[frame_selection_col, gen_animation_col]
256
- )
257
 
258
  gen_animation_button.click(
259
- fn=model.process_text2video,
260
- inputs=inputs,
261
  outputs=result,
262
  )
263
 
264
  return demo
 
19
  ]
20
 
21
 
22
+ def on_video_path_update(evt: gr.EventData):
23
+ return f"Selection: **{evt._data}**"
 
24
 
25
 
26
+ def pose_gallery_callback(evt: gr.SelectData):
27
+ return f"Motion {evt.index+1}"
 
28
 
 
 
29
 
30
+ def get_frame_index(evt: gr.SelectData):
31
+ return evt.index
32
 
33
 
34
  def create_demo(model: ControlAnimationModel):
35
  with gr.Blocks() as demo:
36
+ with gr.Column():
37
  with gr.Row():
38
  with gr.Column():
39
+ # TODO: update so that model_link is customizable
40
+ model_link = gr.Dropdown(
41
+ label="Model Link",
42
+ choices=["runwayml/stable-diffusion-v1-5"],
43
+ value="runwayml/stable-diffusion-v1-5",
44
+ )
45
+ prompt = gr.Textbox(
46
+ placeholder="Prompt",
47
+ show_label=False,
48
+ lines=2,
49
+ info="Give a prompt for an animation you would like to generate. The prompt will be used to create the first initial frame and then the animation.",
50
  )
51
+ negative_prompt = gr.Textbox(
52
  placeholder="Negative Prompt (optional)",
53
  show_label=False,
54
  lines=2,
55
  )
56
57
  gen_frames_button = gr.Button(
58
  value="Generate Initial Frames", variant="primary"
59
  )
60
 
61
  with gr.Accordion("Advanced options", open=False):
62
  if on_huggingspace:
63
  video_length = gr.Slider(
64
  label="Video length", minimum=8, maximum=16, step=1
 
131
  )
132
 
133
  with gr.Column():
134
+ gallery_pose_sequence = gr.Gallery(
135
+ label="Pose Sequence",
136
+ value=[
137
+ ("__assets__/dance1.gif", "Motion 1"),
138
+ ("__assets__/dance2.gif", "Motion 2"),
139
+ ("__assets__/dance3.gif", "Motion 3"),
140
+ ("__assets__/dance4.gif", "Motion 4"),
141
+ ("__assets__/dance5.gif", "Motion 5"),
142
+ ],
143
+ ).style(columns=3)
144
+ input_video_path = gr.Textbox(
145
+ label="Pose Sequence", visible=False, value="Motion 1"
146
+ )
147
+ pose_sequence_selector = gr.Markdown("Pose Sequence: **Motion 1**")
148
+
149
+ with gr.Row():
150
+ with gr.Column(visible=True) as frame_selection_view:
151
+ initial_frames = gr.Gallery(
152
+ label="Initial Frames", show_label=False
153
+ ).style(columns=4, rows=1, object_fit="contain", preview=True)
154
+
155
+ gr.Markdown("Select an initial frame to start your animation with.")
156
+ gen_animation_button = gr.Button(
157
+ value="Select Initial Frame & Generate Animation",
158
+ variant="secondary",
159
+ )
160
+
161
+ with gr.Column(visible=False) as animation_view:
162
  result = gr.Video(label="Generated Video")
163
 
164
+ with gr.Box(visible=False):
165
+ initial_frame_index = gr.Number(
166
+ label="Selected Initial Frame Index", value=-1, precision=0
167
+ )
168
+
169
+ input_video_path.change(on_video_path_update, None, pose_sequence_selector)
170
+ gallery_pose_sequence.select(pose_gallery_callback, None, input_video_path)
171
+ initial_frames.select(fn=get_frame_index, outputs=initial_frame_index)
172
+
173
+ frame_inputs = [
174
+ prompt,
175
+ input_video_path,
176
+ negative_prompt,
177
+ ]
178
+
179
+ animation_inputs = [
180
  prompt,
181
+ initial_frame_index,
182
+ input_video_path,
183
  model_link,
 
184
  motion_field_strength_x,
185
  motion_field_strength_y,
186
  t0,
187
  t1,
188
+ negative_prompt,
189
  chunk_size,
190
  video_length,
191
  merging_ratio,
192
  seed,
193
  ]
194
 
195
+ def submit_select(initial_frame_index: int):
196
+ if initial_frame_index != -1: # More to next step
197
  return {
198
+ frame_selection_view: gr.update(visible=False),
199
+ animation_view: gr.update(visible=True),
200
  }
201
 
202
  return {
203
+ frame_selection_view: gr.update(visible=True),
204
+ animation_view: gr.update(visible=False),
205
  }
206
 
207
  gen_frames_button.click(
208
+ fn=model.generate_initial_frames,
209
  inputs=frame_inputs,
210
  outputs=initial_frames,
211
  )
212
 
213
  gen_animation_button.click(
214
+ fn=submit_select,
215
+ inputs=initial_frame_index,
216
+ outputs=[frame_selection_view, animation_view],
217
+ ).then(
218
+ fn=None,
219
+ inputs=animation_inputs,
220
  outputs=result,
221
  )
222
 
223
+ # gr.Examples(examples=examples,
224
+ # inputs=inputs,
225
+ # outputs=result,
226
+ # fn=None,
227
+ # run_on_click=False,
228
+ # cache_examples=on_huggingspace,
229
+ # )
230
+
231
  return demo
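The reworked UI above leans on two Gradio mechanisms: switching between the frame-selection and animation views by returning gr.update(visible=...) for the two columns, and chaining the follow-up step (still a fn=None placeholder in this commit) after the view switch with .then(). Stripped down to just that pattern, with placeholder components rather than this app's actual ones, the flow looks like this:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(visible=True) as step_one:
        go = gr.Button("Select & Continue")
    with gr.Column(visible=False) as step_two:
        out = gr.Textbox(label="Result")

    def switch_views():
        # Hide the first view and reveal the second one.
        return {
            step_one: gr.update(visible=False),
            step_two: gr.update(visible=True),
        }

    # First swap the views, then run the longer follow-up step.
    go.click(fn=switch_views, outputs=[step_one, step_two]).then(
        fn=lambda: "done", outputs=out
    )

demo.launch()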