ShaoTengLiu commited on
Commit
44fa1db
1 Parent(s): ab4a868
Files changed (4) hide show
  1. Video-P2P/run.py +0 -993
  2. Video-P2P/run_tuning.py +30 -5
  3. Video-P2P/run_videop2p.py +106 -69
  4. trainer.py +3 -1
Video-P2P/run.py DELETED
@@ -1,993 +0,0 @@
1
- import argparse
2
- import datetime
3
- import logging
4
- import inspect
5
- import math
6
- import os
7
- from typing import Optional, Union, Tuple, List, Callable, Dict
8
- from omegaconf import OmegaConf
9
-
10
- import torch
11
- import torch.nn.functional as F
12
- import torch.utils.checkpoint
13
-
14
- import diffusers
15
- import transformers
16
- from accelerate import Accelerator
17
- from accelerate.logging import get_logger
18
- from accelerate.utils import set_seed
19
- from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
20
- from diffusers.optimization import get_scheduler
21
- from diffusers.utils import check_min_version
22
- from diffusers.utils.import_utils import is_xformers_available
23
- from tqdm.auto import tqdm
24
- from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer
25
-
26
- from tuneavideo.models.unet import UNet3DConditionModel
27
- from tuneavideo.data.dataset import TuneAVideoDataset
28
- from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
29
- from tuneavideo.util import save_videos_grid, ddim_inversion
30
- from einops import rearrange
31
-
32
- import cv2
33
- import abc
34
- import ptp_utils
35
- import seq_aligner
36
- import shutil
37
- from torch.optim.adam import Adam
38
- from PIL import Image
39
- import numpy as np
40
- import decord
41
- decord.bridge.set_bridge('torch')
42
-
43
- # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
44
- check_min_version("0.10.0.dev0")
45
-
46
- logger = get_logger(__name__, log_level="INFO")
47
-
48
-
49
- def main(
50
- pretrained_model_path: str,
51
- output_dir: str,
52
- train_data: Dict,
53
- validation_data: Dict,
54
- validation_steps: int = 100,
55
- trainable_modules: Tuple[str] = (
56
- "attn1.to_q",
57
- "attn2.to_q",
58
- "attn_temp",
59
- ),
60
- train_batch_size: int = 1,
61
- max_train_steps: int = 500,
62
- learning_rate: float = 3e-5,
63
- scale_lr: bool = False,
64
- lr_scheduler: str = "constant",
65
- lr_warmup_steps: int = 0,
66
- adam_beta1: float = 0.9,
67
- adam_beta2: float = 0.999,
68
- adam_weight_decay: float = 1e-2,
69
- adam_epsilon: float = 1e-08,
70
- max_grad_norm: float = 1.0,
71
- gradient_accumulation_steps: int = 1,
72
- gradient_checkpointing: bool = True,
73
- checkpointing_steps: int = 500,
74
- resume_from_checkpoint: Optional[str] = None,
75
- mixed_precision: Optional[str] = "fp16",
76
- use_8bit_adam: bool = False,
77
- enable_xformers_memory_efficient_attention: bool = True,
78
- seed: Optional[int] = None,
79
- # pretrained_model_path: str,
80
- # image_path: str = None,
81
- # prompt: str = None,
82
- prompts: Tuple[str] = None,
83
- eq_params: Dict = None,
84
- save_name: str = None,
85
- is_word_swap: bool = None,
86
- blend_word: Tuple[str] = None,
87
- cross_replace_steps: float = 0.2,
88
- self_replace_steps: float = 0.5,
89
- video_len: int = 8,
90
- fast: bool = False,
91
- mixed_precision_p2p: str = 'fp32',
92
- ):
93
- *_, config = inspect.getargvalues(inspect.currentframe())
94
-
95
- accelerator = Accelerator(
96
- gradient_accumulation_steps=gradient_accumulation_steps,
97
- mixed_precision=mixed_precision,
98
- )
99
-
100
- # Make one log on every process with the configuration for debugging.
101
- logging.basicConfig(
102
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
103
- datefmt="%m/%d/%Y %H:%M:%S",
104
- level=logging.INFO,
105
- )
106
- logger.info(accelerator.state, main_process_only=False)
107
- if accelerator.is_local_main_process:
108
- transformers.utils.logging.set_verbosity_warning()
109
- diffusers.utils.logging.set_verbosity_info()
110
- else:
111
- transformers.utils.logging.set_verbosity_error()
112
- diffusers.utils.logging.set_verbosity_error()
113
-
114
- # If passed along, set the training seed now.
115
- if seed is not None:
116
- set_seed(seed)
117
-
118
- # Handle the output folder creation
119
- if accelerator.is_main_process:
120
- # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
121
- # output_dir = os.path.join(output_dir, now)
122
- os.makedirs(output_dir, exist_ok=True)
123
- os.makedirs(f"{output_dir}/samples", exist_ok=True)
124
- os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
125
- OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
126
-
127
- # Load scheduler, tokenizer and models.
128
- noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
129
- tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
130
- text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
131
- vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
132
- unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet")
133
-
134
- # Freeze vae and text_encoder
135
- vae.requires_grad_(False)
136
- text_encoder.requires_grad_(False)
137
-
138
- unet.requires_grad_(False)
139
- for name, module in unet.named_modules():
140
- if name.endswith(tuple(trainable_modules)):
141
- for params in module.parameters():
142
- params.requires_grad = True
143
-
144
- if enable_xformers_memory_efficient_attention:
145
- if is_xformers_available():
146
- unet.enable_xformers_memory_efficient_attention()
147
- else:
148
- raise ValueError("xformers is not available. Make sure it is installed correctly")
149
-
150
- if gradient_checkpointing:
151
- unet.enable_gradient_checkpointing()
152
-
153
- if scale_lr:
154
- learning_rate = (
155
- learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
156
- )
157
-
158
- # Initialize the optimizer
159
- if use_8bit_adam:
160
- try:
161
- import bitsandbytes as bnb
162
- except ImportError:
163
- raise ImportError(
164
- "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
165
- )
166
-
167
- optimizer_cls = bnb.optim.AdamW8bit
168
- else:
169
- optimizer_cls = torch.optim.AdamW
170
-
171
- optimizer = optimizer_cls(
172
- unet.parameters(),
173
- lr=learning_rate,
174
- betas=(adam_beta1, adam_beta2),
175
- weight_decay=adam_weight_decay,
176
- eps=adam_epsilon,
177
- )
178
-
179
- # Get the training dataset
180
- train_dataset = TuneAVideoDataset(**train_data)
181
-
182
- # Preprocessing the dataset
183
- train_dataset.prompt_ids = tokenizer(
184
- train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
185
- ).input_ids[0]
186
-
187
- # DataLoaders creation:
188
- train_dataloader = torch.utils.data.DataLoader(
189
- train_dataset, batch_size=train_batch_size
190
- )
191
-
192
- # Get the validation pipeline
193
- validation_pipeline = TuneAVideoPipeline(
194
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
195
- scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
196
- )
197
- validation_pipeline.enable_vae_slicing()
198
- ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
199
- ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)
200
-
201
- # Scheduler
202
- lr_scheduler = get_scheduler(
203
- lr_scheduler,
204
- optimizer=optimizer,
205
- num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
206
- num_training_steps=max_train_steps * gradient_accumulation_steps,
207
- )
208
-
209
- # Prepare everything with our `accelerator`.
210
- unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
211
- unet, optimizer, train_dataloader, lr_scheduler
212
- )
213
-
214
- # For mixed precision training we cast the text_encoder and vae weights to half-precision
215
- # as these models are only used for inference, keeping weights in full precision is not required.
216
- weight_dtype = torch.float32
217
- if accelerator.mixed_precision == "fp16":
218
- weight_dtype = torch.float16
219
- elif accelerator.mixed_precision == "bf16":
220
- weight_dtype = torch.bfloat16
221
-
222
- # Move text_encode and vae to gpu and cast to weight_dtype
223
- text_encoder.to(accelerator.device, dtype=weight_dtype)
224
- vae.to(accelerator.device, dtype=weight_dtype)
225
-
226
- # We need to recalculate our total training steps as the size of the training dataloader may have changed.
227
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
228
- # Afterwards we recalculate our number of training epochs
229
- num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
230
-
231
- # We need to initialize the trackers we use, and also store our configuration.
232
- # The trackers initializes automatically on the main process.
233
- if accelerator.is_main_process:
234
- accelerator.init_trackers("text2video-fine-tune")
235
-
236
- # Train!
237
- total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
238
-
239
- logger.info("***** Running training *****")
240
- logger.info(f" Num examples = {len(train_dataset)}")
241
- logger.info(f" Num Epochs = {num_train_epochs}")
242
- logger.info(f" Instantaneous batch size per device = {train_batch_size}")
243
- logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
244
- logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
245
- logger.info(f" Total optimization steps = {max_train_steps}")
246
- global_step = 0
247
- first_epoch = 0
248
-
249
- # Potentially load in the weights and states from a previous save
250
- if resume_from_checkpoint:
251
- if resume_from_checkpoint != "latest":
252
- path = os.path.basename(resume_from_checkpoint)
253
- else:
254
- # Get the most recent checkpoint
255
- dirs = os.listdir(output_dir)
256
- dirs = [d for d in dirs if d.startswith("checkpoint")]
257
- dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
258
- path = dirs[-1]
259
- accelerator.print(f"Resuming from checkpoint {path}")
260
- accelerator.load_state(os.path.join(output_dir, path))
261
- global_step = int(path.split("-")[1])
262
-
263
- first_epoch = global_step // num_update_steps_per_epoch
264
- resume_step = global_step % num_update_steps_per_epoch
265
-
266
- # Only show the progress bar once on each machine.
267
- progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
268
- progress_bar.set_description("Steps")
269
-
270
- for epoch in range(first_epoch, num_train_epochs):
271
- unet.train()
272
- train_loss = 0.0
273
- for step, batch in enumerate(train_dataloader):
274
- # Skip steps until we reach the resumed step
275
- if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
276
- if step % gradient_accumulation_steps == 0:
277
- progress_bar.update(1)
278
- continue
279
-
280
- with accelerator.accumulate(unet):
281
- # Convert videos to latent space
282
- pixel_values = batch["pixel_values"].to(weight_dtype)
283
- video_length = pixel_values.shape[1]
284
- pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
285
- latents = vae.encode(pixel_values).latent_dist.sample()
286
- latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
287
- latents = latents * 0.18215
288
-
289
- # Sample noise that we'll add to the latents
290
- noise = torch.randn_like(latents)
291
- bsz = latents.shape[0]
292
- # Sample a random timestep for each video
293
- timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
294
- timesteps = timesteps.long()
295
-
296
- # Add noise to the latents according to the noise magnitude at each timestep
297
- # (this is the forward diffusion process)
298
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
299
-
300
- # Get the text embedding for conditioning
301
- encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]
302
-
303
- # Get the target for loss depending on the prediction type
304
- if noise_scheduler.prediction_type == "epsilon":
305
- target = noise
306
- elif noise_scheduler.prediction_type == "v_prediction":
307
- target = noise_scheduler.get_velocity(latents, noise, timesteps)
308
- else:
309
- raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")
310
-
311
- # Predict the noise residual and compute loss
312
- model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
313
- loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
314
-
315
- # Gather the losses across all processes for logging (if we use distributed training).
316
- avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
317
- train_loss += avg_loss.item() / gradient_accumulation_steps
318
-
319
- # Backpropagate
320
- accelerator.backward(loss)
321
- if accelerator.sync_gradients:
322
- accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
323
- optimizer.step()
324
- lr_scheduler.step()
325
- optimizer.zero_grad()
326
-
327
- # Checks if the accelerator has performed an optimization step behind the scenes
328
- if accelerator.sync_gradients:
329
- progress_bar.update(1)
330
- global_step += 1
331
- accelerator.log({"train_loss": train_loss}, step=global_step)
332
- train_loss = 0.0
333
-
334
- if global_step % checkpointing_steps == 0:
335
- if accelerator.is_main_process:
336
- save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
337
- accelerator.save_state(save_path)
338
- logger.info(f"Saved state to {save_path}")
339
-
340
- if global_step % validation_steps == 0:
341
- if accelerator.is_main_process:
342
- samples = []
343
- generator = torch.Generator(device=latents.device)
344
- generator.manual_seed(seed)
345
-
346
- ddim_inv_latent = None
347
- if validation_data.use_inv_latent:
348
- inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt")
349
- ddim_inv_latent = ddim_inversion(
350
- validation_pipeline, ddim_inv_scheduler, video_latent=latents,
351
- num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype)
352
- torch.save(ddim_inv_latent, inv_latents_path)
353
-
354
- for idx, prompt in enumerate(validation_data.prompts):
355
- sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent,
356
- **validation_data).videos
357
- save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{prompt}.gif")
358
- samples.append(sample)
359
- samples = torch.concat(samples)
360
- save_path = f"{output_dir}/samples/sample-{global_step}.gif"
361
- save_videos_grid(samples, save_path)
362
- logger.info(f"Saved samples to {save_path}")
363
-
364
- logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
365
- progress_bar.set_postfix(**logs)
366
-
367
- if global_step >= max_train_steps:
368
- break
369
-
370
- # Create the pipeline using the trained modules and save it.
371
- accelerator.wait_for_everyone()
372
- if accelerator.is_main_process:
373
- unet = accelerator.unwrap_model(unet)
374
- pipeline = TuneAVideoPipeline.from_pretrained(
375
- pretrained_model_path,
376
- text_encoder=text_encoder,
377
- vae=vae,
378
- unet=unet,
379
- )
380
- pipeline.save_pretrained(output_dir)
381
-
382
- accelerator.end_training()
383
-
384
- torch.cuda.empty_cache()
385
-
386
- # Video-P2P
387
- scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
388
- MY_TOKEN = ''
389
- LOW_RESOURCE = False
390
- NUM_DDIM_STEPS = 50
391
- GUIDANCE_SCALE = 7.5
392
- MAX_NUM_WORDS = 77
393
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
394
-
395
- # need to adjust sometimes
396
- mask_th = (.3, .3)
397
-
398
-
399
- pretrained_model_path = output_dir
400
- image_path = train_data['video_path']
401
- prompt = train_data['prompt']
402
- # prompts = [prompt, ]
403
- output_folder = os.path.join(pretrained_model_path, 'results')
404
- if fast:
405
- save_name_1 = os.path.join(output_folder, 'inversion_fast.gif')
406
- save_name_2 = os.path.join(output_folder, '{}_fast.gif'.format(save_name))
407
- else:
408
- save_name_1 = os.path.join(output_folder, 'inversion.gif')
409
- save_name_2 = os.path.join(output_folder, '{}.gif'.format(save_name))
410
- if blend_word:
411
- blend_word = (((blend_word[0],), (blend_word[1],)))
412
- eq_params = dict(eq_params)
413
- prompts = list(prompts)
414
- cross_replace_steps = {'default_': cross_replace_steps,}
415
-
416
- weight_dtype = torch.float32
417
- if mixed_precision_p2p == "fp16":
418
- weight_dtype = torch.float16
419
- elif mixed_precision_p2p == "bf16":
420
- weight_dtype = torch.bfloat16
421
-
422
- if not os.path.exists(output_folder):
423
- os.makedirs(output_folder)
424
-
425
- # Load the tokenizer
426
- tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
427
- # Load models and create wrapper for stable diffusion
428
- text_encoder = CLIPTextModel.from_pretrained(
429
- pretrained_model_path,
430
- subfolder="text_encoder",
431
- ).to(device, dtype=weight_dtype)
432
- vae = AutoencoderKL.from_pretrained(
433
- pretrained_model_path,
434
- subfolder="vae",
435
- ).to(device, dtype=weight_dtype)
436
- unet = UNet3DConditionModel.from_pretrained(
437
- pretrained_model_path, subfolder="unet"
438
- ).to(device)
439
- ldm_stable = TuneAVideoPipeline(
440
- vae=vae,
441
- text_encoder=text_encoder,
442
- tokenizer=tokenizer,
443
- unet=unet,
444
- scheduler=scheduler,
445
- ).to(device)
446
-
447
- try:
448
- ldm_stable.disable_xformers_memory_efficient_attention()
449
- except AttributeError:
450
- print("Attribute disable_xformers_memory_efficient_attention() is missing")
451
- tokenizer = ldm_stable.tokenizer # Tokenizer of class: [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
452
- # A tokenizer breaks a stream of text into tokens, usually by looking for whitespace (tabs, spaces, new lines).
453
-
454
- class LocalBlend:
455
-
456
- def get_mask(self, maps, alpha, use_pool):
457
- k = 1
458
- maps = (maps * alpha).sum(-1).mean(2)
459
- if use_pool:
460
- maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
461
- mask = F.interpolate(maps, size=(x_t.shape[3:]))
462
- mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
463
- mask = mask.gt(self.th[1-int(use_pool)])
464
- mask = mask[:1] + mask
465
- return mask
466
-
467
- def __call__(self, x_t, attention_store, step):
468
- self.counter += 1
469
- if self.counter > self.start_blend:
470
- maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
471
- maps = [item.reshape(self.alpha_layers.shape[0], -1, 8, 16, 16, MAX_NUM_WORDS) for item in maps]
472
- maps = torch.cat(maps, dim=2)
473
- mask = self.get_mask(maps, self.alpha_layers, True)
474
- if self.substruct_layers is not None:
475
- maps_sub = ~self.get_mask(maps, self.substruct_layers, False)
476
- mask = mask * maps_sub
477
- mask = mask.float()
478
- mask = mask.reshape(-1, 1, mask.shape[-3], mask.shape[-2], mask.shape[-1])
479
- x_t = x_t[:1] + mask * (x_t - x_t[:1])
480
- return x_t
481
-
482
- def __init__(self, prompts: List[str], words: [List[List[str]]], substruct_words=None, start_blend=0.2, th=(.3, .3)):
483
- alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
484
- for i, (prompt, words_) in enumerate(zip(prompts, words)):
485
- if type(words_) is str:
486
- words_ = [words_]
487
- for word in words_:
488
- ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
489
- alpha_layers[i, :, :, :, :, ind] = 1
490
-
491
- if substruct_words is not None:
492
- substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
493
- for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)):
494
- if type(words_) is str:
495
- words_ = [words_]
496
- for word in words_:
497
- ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
498
- substruct_layers[i, :, :, :, :, ind] = 1
499
- self.substruct_layers = substruct_layers.to(device)
500
- else:
501
- self.substruct_layers = None
502
- self.alpha_layers = alpha_layers.to(device)
503
- self.start_blend = int(start_blend * NUM_DDIM_STEPS)
504
- self.counter = 0
505
- self.th=th
506
-
507
-
508
- class EmptyControl:
509
-
510
-
511
- def step_callback(self, x_t):
512
- return x_t
513
-
514
- def between_steps(self):
515
- return
516
-
517
- def __call__(self, attn, is_cross: bool, place_in_unet: str):
518
- return attn
519
-
520
-
521
- class AttentionControl(abc.ABC):
522
-
523
- def step_callback(self, x_t):
524
- return x_t
525
-
526
- def between_steps(self):
527
- return
528
-
529
- @property
530
- def num_uncond_att_layers(self):
531
- return self.num_att_layers if LOW_RESOURCE else 0
532
-
533
- @abc.abstractmethod
534
- def forward (self, attn, is_cross: bool, place_in_unet: str):
535
- raise NotImplementedError
536
-
537
- def __call__(self, attn, is_cross: bool, place_in_unet: str):
538
- if self.cur_att_layer >= self.num_uncond_att_layers:
539
- if LOW_RESOURCE:
540
- attn = self.forward(attn, is_cross, place_in_unet)
541
- else:
542
- h = attn.shape[0]
543
- attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
544
- self.cur_att_layer += 1
545
- if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
546
- self.cur_att_layer = 0
547
- self.cur_step += 1
548
- self.between_steps()
549
- return attn
550
-
551
- def reset(self):
552
- self.cur_step = 0
553
- self.cur_att_layer = 0
554
-
555
- def __init__(self):
556
- self.cur_step = 0
557
- self.num_att_layers = -1
558
- self.cur_att_layer = 0
559
-
560
- class SpatialReplace(EmptyControl):
561
-
562
- def step_callback(self, x_t):
563
- if self.cur_step < self.stop_inject:
564
- b = x_t.shape[0]
565
- x_t = x_t[:1].expand(b, *x_t.shape[1:])
566
- return x_t
567
-
568
- def __init__(self, stop_inject: float):
569
- super(SpatialReplace, self).__init__()
570
- self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)
571
-
572
-
573
- class AttentionStore(AttentionControl):
574
-
575
- @staticmethod
576
- def get_empty_store():
577
- return {"down_cross": [], "mid_cross": [], "up_cross": [],
578
- "down_self": [], "mid_self": [], "up_self": []}
579
-
580
- def forward(self, attn, is_cross: bool, place_in_unet: str):
581
- key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
582
- if attn.shape[1] <= 32 ** 2:
583
- self.step_store[key].append(attn)
584
- return attn
585
-
586
- def between_steps(self):
587
- if len(self.attention_store) == 0:
588
- self.attention_store = self.step_store
589
- else:
590
- for key in self.attention_store:
591
- for i in range(len(self.attention_store[key])):
592
- self.attention_store[key][i] += self.step_store[key][i]
593
- self.step_store = self.get_empty_store()
594
-
595
- def get_average_attention(self):
596
- average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
597
- return average_attention
598
-
599
-
600
- def reset(self):
601
- super(AttentionStore, self).reset()
602
- self.step_store = self.get_empty_store()
603
- self.attention_store = {}
604
-
605
- def __init__(self):
606
- super(AttentionStore, self).__init__()
607
- self.step_store = self.get_empty_store()
608
- self.attention_store = {}
609
-
610
-
611
- class AttentionControlEdit(AttentionStore, abc.ABC):
612
-
613
- def step_callback(self, x_t):
614
- if self.local_blend is not None:
615
- x_t = self.local_blend(x_t, self.attention_store, self.cur_step)
616
- return x_t
617
-
618
- def replace_self_attention(self, attn_base, att_replace, place_in_unet):
619
- if att_replace.shape[2] <= 32 ** 2:
620
- attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
621
- return attn_base
622
- else:
623
- return att_replace
624
-
625
- @abc.abstractmethod
626
- def replace_cross_attention(self, attn_base, att_replace):
627
- raise NotImplementedError
628
-
629
- def forward(self, attn, is_cross: bool, place_in_unet: str):
630
- super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
631
- if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
632
- h = attn.shape[0] // (self.batch_size)
633
- attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
634
- attn_base, attn_repalce = attn[0], attn[1:]
635
- if is_cross:
636
- alpha_words = self.cross_replace_alpha[self.cur_step]
637
- attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
638
- attn[1:] = attn_repalce_new
639
- else:
640
- attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet)
641
- attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
642
- return attn
643
-
644
- def __init__(self, prompts, num_steps: int,
645
- cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
646
- self_replace_steps: Union[float, Tuple[float, float]],
647
- local_blend: Optional[LocalBlend]):
648
- super(AttentionControlEdit, self).__init__()
649
- self.batch_size = len(prompts)
650
- self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
651
- if type(self_replace_steps) is float:
652
- self_replace_steps = 0, self_replace_steps
653
- self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
654
- self.local_blend = local_blend
655
-
656
- class AttentionReplace(AttentionControlEdit):
657
-
658
- def replace_cross_attention(self, attn_base, att_replace):
659
- return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
660
-
661
- def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
662
- local_blend: Optional[LocalBlend] = None):
663
- super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
664
- self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
665
-
666
-
667
- class AttentionRefine(AttentionControlEdit):
668
-
669
- def replace_cross_attention(self, attn_base, att_replace):
670
- attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
671
- attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
672
- return attn_replace
673
-
674
- def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
675
- local_blend: Optional[LocalBlend] = None):
676
- super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
677
- self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
678
- self.mapper, alphas = self.mapper.to(device), alphas.to(device)
679
- self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
680
-
681
-
682
- class AttentionReweight(AttentionControlEdit):
683
-
684
- def replace_cross_attention(self, attn_base, att_replace):
685
- if self.prev_controller is not None:
686
- attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
687
- attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
688
- return attn_replace
689
-
690
- def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
691
- local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
692
- super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
693
- self.equalizer = equalizer.to(device)
694
- self.prev_controller = controller
695
-
696
-
697
- def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
698
- Tuple[float, ...]]):
699
- if type(word_select) is int or type(word_select) is str:
700
- word_select = (word_select,)
701
- equalizer = torch.ones(1, 77)
702
-
703
- for word, val in zip(word_select, values):
704
- inds = ptp_utils.get_word_inds(text, word, tokenizer)
705
- equalizer[:, inds] = val
706
- return equalizer
707
-
708
- def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
709
- out = []
710
- attention_maps = attention_store.get_average_attention()
711
- num_pixels = res ** 2
712
- for location in from_where:
713
- for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
714
- if item.shape[1] == num_pixels:
715
- cross_maps = item.reshape(8, 8, res, res, item.shape[-1])
716
- out.append(cross_maps)
717
- out = torch.cat(out, dim=1)
718
- out = out.sum(1) / out.shape[1]
719
- return out.cpu()
720
-
721
-
722
- def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float], self_replace_steps: float, blend_words=None, equilizer_params=None, mask_th=(.3,.3)) -> AttentionControlEdit:
723
- if blend_words is None:
724
- lb = None
725
- else:
726
- lb = LocalBlend(prompts, blend_word, th=mask_th)
727
- if is_replace_controller:
728
- controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb)
729
- else:
730
- controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb)
731
- if equilizer_params is not None:
732
- eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"])
733
- controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
734
- self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb, controller=controller)
735
- return controller
736
-
737
-
738
- def load_512_seq(image_path, left=0, right=0, top=0, bottom=0, n_sample_frame=video_len, sampling_rate=1):
739
- vr = decord.VideoReader(image_path, width=512, height=512)
740
- sample_index = list(range(0, len(vr), sampling_rate))[:n_sample_frame]
741
- video = vr.get_batch(sample_index)
742
- return video.numpy()
743
-
744
-
745
- class NullInversion:
746
-
747
- def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
748
- prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
749
- alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
750
- alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
751
- beta_prod_t = 1 - alpha_prod_t
752
- pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
753
- pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
754
- prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
755
- return prev_sample
756
-
757
- def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
758
- timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
759
- alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
760
- alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
761
- beta_prod_t = 1 - alpha_prod_t
762
- next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
763
- next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
764
- next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
765
- return next_sample
766
-
767
- def get_noise_pred_single(self, latents, t, context):
768
- noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
769
- return noise_pred
770
-
771
- def get_noise_pred(self, latents, t, is_forward=True, context=None):
772
- latents_input = torch.cat([latents] * 2)
773
- if context is None:
774
- context = self.context
775
- guidance_scale = 1 if is_forward else GUIDANCE_SCALE
776
- noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
777
- noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
778
- noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
779
- if is_forward:
780
- latents = self.next_step(noise_pred, t, latents)
781
- else:
782
- latents = self.prev_step(noise_pred, t, latents)
783
- return latents
784
-
785
- @torch.no_grad()
786
- def latent2image(self, latents, return_type='np'):
787
- latents = 1 / 0.18215 * latents.detach()
788
- image = self.model.vae.decode(latents)['sample']
789
- if return_type == 'np':
790
- image = (image / 2 + 0.5).clamp(0, 1)
791
- image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
792
- image = (image * 255).astype(np.uint8)
793
- return image
794
-
795
- @torch.no_grad()
796
- def latent2image_video(self, latents, return_type='np'):
797
- latents = 1 / 0.18215 * latents.detach()
798
- latents = latents[0].permute(1, 0, 2, 3)
799
- image = self.model.vae.decode(latents)['sample']
800
- if return_type == 'np':
801
- image = (image / 2 + 0.5).clamp(0, 1)
802
- image = image.cpu().permute(0, 2, 3, 1).numpy()
803
- image = (image * 255).astype(np.uint8)
804
- return image
805
-
806
- @torch.no_grad()
807
- def image2latent(self, image):
808
- with torch.no_grad():
809
- if type(image) is Image:
810
- image = np.array(image)
811
- if type(image) is torch.Tensor and image.dim() == 4:
812
- latents = image
813
- else:
814
- image = torch.from_numpy(image).float() / 127.5 - 1
815
- image = image.permute(2, 0, 1).unsqueeze(0).to(device, dtype=weight_dtype)
816
- latents = self.model.vae.encode(image)['latent_dist'].mean
817
- latents = latents * 0.18215
818
- return latents
819
-
820
- @torch.no_grad()
821
- def image2latent_video(self, image):
822
- with torch.no_grad():
823
- image = torch.from_numpy(image).float() / 127.5 - 1
824
- image = image.permute(0, 3, 1, 2).to(device).to(device, dtype=weight_dtype)
825
- latents = self.model.vae.encode(image)['latent_dist'].mean
826
- latents = rearrange(latents, "(b f) c h w -> b c f h w", b=1)
827
- latents = latents * 0.18215
828
- return latents
829
-
830
- @torch.no_grad()
831
- def init_prompt(self, prompt: str):
832
- uncond_input = self.model.tokenizer(
833
- [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
834
- return_tensors="pt"
835
- )
836
- uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
837
- text_input = self.model.tokenizer(
838
- [prompt],
839
- padding="max_length",
840
- max_length=self.model.tokenizer.model_max_length,
841
- truncation=True,
842
- return_tensors="pt",
843
- )
844
- text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
845
- self.context = torch.cat([uncond_embeddings, text_embeddings])
846
- self.prompt = prompt
847
-
848
- @torch.no_grad()
849
- def ddim_loop(self, latent):
850
- uncond_embeddings, cond_embeddings = self.context.chunk(2)
851
- all_latent = [latent]
852
- latent = latent.clone().detach()
853
- for i in range(NUM_DDIM_STEPS):
854
- t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
855
- noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
856
- latent = self.next_step(noise_pred, t, latent)
857
- all_latent.append(latent)
858
- return all_latent
859
-
860
- @property
861
- def scheduler(self):
862
- return self.model.scheduler
863
-
864
- @torch.no_grad()
865
- def ddim_inversion(self, image):
866
- latent = self.image2latent_video(image)
867
- image_rec = self.latent2image_video(latent)
868
- ddim_latents = self.ddim_loop(latent)
869
- return image_rec, ddim_latents
870
-
871
- def null_optimization(self, latents, num_inner_steps, epsilon):
872
- uncond_embeddings, cond_embeddings = self.context.chunk(2)
873
- uncond_embeddings_list = []
874
- latent_cur = latents[-1]
875
- # bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
876
- for i in range(NUM_DDIM_STEPS):
877
- uncond_embeddings = uncond_embeddings.clone().detach()
878
- uncond_embeddings.requires_grad = True
879
- optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
880
- latent_prev = latents[len(latents) - i - 2]
881
- t = self.model.scheduler.timesteps[i]
882
- with torch.no_grad():
883
- noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
884
- for j in range(num_inner_steps):
885
- noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
886
- noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
887
- latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
888
- loss = F.mse_loss(latents_prev_rec, latent_prev)
889
- optimizer.zero_grad()
890
- loss.backward()
891
- optimizer.step()
892
- loss_item = loss.item()
893
- # bar.update()
894
- if loss_item < epsilon + i * 2e-5:
895
- break
896
- # for j in range(j + 1, num_inner_steps):
897
- # bar.update()
898
- uncond_embeddings_list.append(uncond_embeddings[:1].detach())
899
- with torch.no_grad():
900
- context = torch.cat([uncond_embeddings, cond_embeddings])
901
- latent_cur = self.get_noise_pred(latent_cur, t, False, context)
902
- # bar.close()
903
- return uncond_embeddings_list
904
-
905
- def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
906
- self.init_prompt(prompt)
907
- ptp_utils.register_attention_control(self.model, None)
908
- image_gt = load_512_seq(image_path, *offsets)
909
- if verbose:
910
- print("DDIM inversion...")
911
- image_rec, ddim_latents = self.ddim_inversion(image_gt)
912
- if verbose:
913
- print("Null-text optimization...")
914
- uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
915
- return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
916
-
917
- def invert_(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
918
- self.init_prompt(prompt)
919
- ptp_utils.register_attention_control(self.model, None)
920
- image_gt = load_512_seq(image_path, *offsets)
921
- if verbose:
922
- print("DDIM inversion...")
923
- image_rec, ddim_latents = self.ddim_inversion(image_gt)
924
- if verbose:
925
- print("Null-text optimization...")
926
- return (image_gt, image_rec), ddim_latents[-1], None
927
-
928
- def __init__(self, model):
929
- scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
930
- set_alpha_to_one=False)
931
- self.model = model
932
- self.tokenizer = self.model.tokenizer
933
- self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
934
- self.prompt = None
935
- self.context = None
936
-
937
- null_inversion = NullInversion(ldm_stable)
938
-
939
- ###############
940
- # Custom APIs:
941
-
942
- ldm_stable.enable_xformers_memory_efficient_attention()
943
-
944
- if fast:
945
- (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert_(image_path, prompt, offsets=(0,0,0,0), verbose=True)
946
- else:
947
- (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, prompt, offsets=(0,0,0,0), verbose=True)
948
-
949
- ##### load uncond #####
950
- # uncond_embeddings_load = np.load(uncond_embeddings_path)
951
- # uncond_embeddings = []
952
- # for i in range(uncond_embeddings_load.shape[0]):
953
- # uncond_embeddings.append(torch.from_numpy(uncond_embeddings_load[i]).to(device))
954
- #######################
955
-
956
- ##### save uncond #####
957
- # uncond_embeddings = torch.cat(uncond_embeddings)
958
- # uncond_embeddings = uncond_embeddings.cpu().numpy()
959
- #######################
960
-
961
- print("Start Video-P2P!")
962
- controller = make_controller(prompts, is_word_swap, cross_replace_steps, self_replace_steps, blend_word, eq_params, mask_th=mask_th)
963
- ptp_utils.register_attention_control(ldm_stable, controller)
964
- generator = torch.Generator(device=device)
965
- with torch.no_grad():
966
- sequence = ldm_stable(
967
- prompts,
968
- generator=generator,
969
- latents=x_t,
970
- uncond_embeddings_pre=uncond_embeddings,
971
- controller = controller,
972
- video_length=video_len,
973
- fast=fast,
974
- ).videos
975
- sequence1 = rearrange(sequence[0], "c t h w -> t h w c")
976
- sequence2 = rearrange(sequence[1], "c t h w -> t h w c")
977
- inversion = []
978
- videop2p = []
979
- for i in range(sequence1.shape[0]):
980
- inversion.append( Image.fromarray((sequence1[i] * 255).numpy().astype(np.uint8)) )
981
- videop2p.append( Image.fromarray((sequence2[i] * 255).numpy().astype(np.uint8)) )
982
-
983
- # inversion[0].save(save_name_1, save_all=True, append_images=inversion[1:], optimize=False, loop=0, duration=250)
984
- videop2p[0].save(save_name_2, save_all=True, append_images=videop2p[1:], optimize=False, loop=0, duration=250)
985
-
986
-
987
- if __name__ == "__main__":
988
- parser = argparse.ArgumentParser()
989
- parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
990
- parser.add_argument("--fast", action='store_true')
991
- args = parser.parse_args()
992
-
993
- main(**OmegaConf.load(args.config), fast=args.fast)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
Video-P2P/run_tuning.py CHANGED
@@ -1,12 +1,10 @@
1
- # From https://github.com/showlab/Tune-A-Video/blob/main/train_tuneavideo.py
2
-
3
  import argparse
4
  import datetime
5
  import logging
6
  import inspect
7
  import math
8
  import os
9
- from typing import Dict, Optional, Tuple
10
  from omegaconf import OmegaConf
11
 
12
  import torch
@@ -23,7 +21,7 @@ from diffusers.optimization import get_scheduler
23
  from diffusers.utils import check_min_version
24
  from diffusers.utils.import_utils import is_xformers_available
25
  from tqdm.auto import tqdm
26
- from transformers import CLIPTextModel, CLIPTokenizer
27
 
28
  from tuneavideo.models.unet import UNet3DConditionModel
29
  from tuneavideo.data.dataset import TuneAVideoDataset
@@ -31,6 +29,16 @@ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
31
  from tuneavideo.util import save_videos_grid, ddim_inversion
32
  from einops import rearrange
33
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
36
  check_min_version("0.10.0.dev0")
@@ -68,6 +76,19 @@ def main(
68
  use_8bit_adam: bool = False,
69
  enable_xformers_memory_efficient_attention: bool = True,
70
  seed: Optional[int] = None,
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ):
72
  *_, config = inspect.getargvalues(inspect.currentframe())
73
 
@@ -96,6 +117,8 @@ def main(
96
 
97
  # Handle the output folder creation
98
  if accelerator.is_main_process:
 
 
99
  os.makedirs(output_dir, exist_ok=True)
100
  os.makedirs(f"{output_dir}/samples", exist_ok=True)
101
  os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
@@ -358,10 +381,12 @@ def main(
358
 
359
  accelerator.end_training()
360
 
 
361
 
362
  if __name__ == "__main__":
363
  parser = argparse.ArgumentParser()
364
  parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
 
365
  args = parser.parse_args()
366
 
367
- main(**OmegaConf.load(args.config))
 
 
 
1
  import argparse
2
  import datetime
3
  import logging
4
  import inspect
5
  import math
6
  import os
7
+ from typing import Optional, Union, Tuple, List, Callable, Dict
8
  from omegaconf import OmegaConf
9
 
10
  import torch
 
21
  from diffusers.utils import check_min_version
22
  from diffusers.utils.import_utils import is_xformers_available
23
  from tqdm.auto import tqdm
24
+ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer
25
 
26
  from tuneavideo.models.unet import UNet3DConditionModel
27
  from tuneavideo.data.dataset import TuneAVideoDataset
 
29
  from tuneavideo.util import save_videos_grid, ddim_inversion
30
  from einops import rearrange
31
 
32
+ import cv2
33
+ import abc
34
+ import ptp_utils
35
+ import seq_aligner
36
+ import shutil
37
+ from torch.optim.adam import Adam
38
+ from PIL import Image
39
+ import numpy as np
40
+ import decord
41
+ decord.bridge.set_bridge('torch')
42
 
43
  # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
44
  check_min_version("0.10.0.dev0")
 
76
  use_8bit_adam: bool = False,
77
  enable_xformers_memory_efficient_attention: bool = True,
78
  seed: Optional[int] = None,
79
+ # pretrained_model_path: str,
80
+ # image_path: str = None,
81
+ # prompt: str = None,
82
+ prompts: Tuple[str] = None,
83
+ eq_params: Dict = None,
84
+ save_name: str = None,
85
+ is_word_swap: bool = None,
86
+ blend_word: Tuple[str] = None,
87
+ cross_replace_steps: float = 0.2,
88
+ self_replace_steps: float = 0.5,
89
+ video_len: int = 8,
90
+ fast: bool = False,
91
+ mixed_precision_p2p: str = 'fp32',
92
  ):
93
  *_, config = inspect.getargvalues(inspect.currentframe())
94
 
 
117
 
118
  # Handle the output folder creation
119
  if accelerator.is_main_process:
120
+ # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
121
+ # output_dir = os.path.join(output_dir, now)
122
  os.makedirs(output_dir, exist_ok=True)
123
  os.makedirs(f"{output_dir}/samples", exist_ok=True)
124
  os.makedirs(f"{output_dir}/inv_latents", exist_ok=True)
 
381
 
382
  accelerator.end_training()
383
 
384
+ torch.cuda.empty_cache()
385
 
386
  if __name__ == "__main__":
387
  parser = argparse.ArgumentParser()
388
  parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
389
+ parser.add_argument("--fast", action='store_true')
390
  args = parser.parse_args()
391
 
392
+ main(**OmegaConf.load(args.config), fast=args.fast)
Video-P2P/run_videop2p.py CHANGED
@@ -1,54 +1,113 @@
1
- # Adapted from https://github.com/google/prompt-to-prompt/blob/main/null_text_w_ptp.ipynb
2
-
 
 
 
3
  import os
4
  from typing import Optional, Union, Tuple, List, Callable, Dict
5
- from tqdm.notebook import tqdm
 
6
  import torch
7
- from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
8
- import torch.nn.functional as nnf
9
- import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import abc
11
  import ptp_utils
12
  import seq_aligner
13
  import shutil
14
  from torch.optim.adam import Adam
15
  from PIL import Image
16
- from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer
17
- from einops import rearrange
 
18
 
19
- from tuneavideo.models.unet import UNet3DConditionModel
20
- from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
21
 
22
- import cv2
23
- import argparse
24
- from omegaconf import OmegaConf
25
 
26
- scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
27
- MY_TOKEN = ''
28
- LOW_RESOURCE = False
29
- NUM_DDIM_STEPS = 50
30
- GUIDANCE_SCALE = 7.5
31
- MAX_NUM_WORDS = 77
32
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
33
-
34
- # need to adjust sometimes
35
- mask_th = (.3, .3)
36
 
37
  def main(
38
  pretrained_model_path: str,
39
- image_path: str,
40
- prompt: str,
41
- prompts: Tuple[str],
42
- eq_params: Dict,
43
- save_name: str,
44
- is_word_swap: bool,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  blend_word: Tuple[str] = None,
46
  cross_replace_steps: float = 0.2,
47
  self_replace_steps: float = 0.5,
48
  video_len: int = 8,
49
  fast: bool = False,
50
- mixed_precision: str = 'fp32',
51
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  output_folder = os.path.join(pretrained_model_path, 'results')
53
  if fast:
54
  save_name_1 = os.path.join(output_folder, 'inversion_fast.gif')
@@ -63,9 +122,9 @@ def main(
63
  cross_replace_steps = {'default_': cross_replace_steps,}
64
 
65
  weight_dtype = torch.float32
66
- if mixed_precision == "fp16":
67
  weight_dtype = torch.float16
68
- elif mixed_precision == "bf16":
69
  weight_dtype = torch.bfloat16
70
 
71
  if not os.path.exists(output_folder):
@@ -106,8 +165,8 @@ def main(
106
  k = 1
107
  maps = (maps * alpha).sum(-1).mean(2)
108
  if use_pool:
109
- maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
110
- mask = nnf.interpolate(maps, size=(x_t.shape[3:]))
111
  mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
112
  mask = mask.gt(self.th[1-int(use_pool)])
113
  mask = mask[:1] + mask
@@ -385,33 +444,10 @@ def main(
385
 
386
 
387
  def load_512_seq(image_path, left=0, right=0, top=0, bottom=0, n_sample_frame=video_len, sampling_rate=1):
388
- images = []
389
- for file in sorted(os.listdir(image_path)):
390
- images.append(file)
391
- n_images = len(images)
392
- sequence_length = (n_sample_frame - 1) * sampling_rate + 1
393
- if n_images < sequence_length:
394
- raise ValueError
395
- frames = []
396
- for index in range(n_sample_frame):
397
- p = os.path.join(image_path, images[index])
398
- image = np.array(Image.open(p).convert("RGB"))
399
- h, w, c = image.shape
400
- left = min(left, w-1)
401
- right = min(right, w - left - 1)
402
- top = min(top, h - left - 1)
403
- bottom = min(bottom, h - top - 1)
404
- image = image[top:h-bottom, left:w-right]
405
- h, w, c = image.shape
406
- if h < w:
407
- offset = (w - h) // 2
408
- image = image[:, offset:offset + h]
409
- elif w < h:
410
- offset = (h - w) // 2
411
- image = image[offset:offset + w]
412
- image = np.array(Image.fromarray(image).resize((512, 512)))
413
- frames.append(image)
414
- return np.stack(frames)
415
 
416
 
417
  class NullInversion:
@@ -544,7 +580,7 @@ def main(
544
  uncond_embeddings, cond_embeddings = self.context.chunk(2)
545
  uncond_embeddings_list = []
546
  latent_cur = latents[-1]
547
- bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
548
  for i in range(NUM_DDIM_STEPS):
549
  uncond_embeddings = uncond_embeddings.clone().detach()
550
  uncond_embeddings.requires_grad = True
@@ -557,21 +593,21 @@ def main(
557
  noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
558
  noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
559
  latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
560
- loss = nnf.mse_loss(latents_prev_rec, latent_prev)
561
  optimizer.zero_grad()
562
  loss.backward()
563
  optimizer.step()
564
  loss_item = loss.item()
565
- bar.update()
566
  if loss_item < epsilon + i * 2e-5:
567
  break
568
- for j in range(j + 1, num_inner_steps):
569
- bar.update()
570
  uncond_embeddings_list.append(uncond_embeddings[:1].detach())
571
  with torch.no_grad():
572
  context = torch.cat([uncond_embeddings, cond_embeddings])
573
  latent_cur = self.get_noise_pred(latent_cur, t, False, context)
574
- bar.close()
575
  return uncond_embeddings_list
576
 
577
  def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
@@ -652,12 +688,13 @@ def main(
652
  inversion.append( Image.fromarray((sequence1[i] * 255).numpy().astype(np.uint8)) )
653
  videop2p.append( Image.fromarray((sequence2[i] * 255).numpy().astype(np.uint8)) )
654
 
655
- inversion[0].save(save_name_1, save_all=True, append_images=inversion[1:], optimize=False, loop=0, duration=250)
656
  videop2p[0].save(save_name_2, save_all=True, append_images=videop2p[1:], optimize=False, loop=0, duration=250)
657
 
 
658
  if __name__ == "__main__":
659
  parser = argparse.ArgumentParser()
660
- parser.add_argument("--config", type=str, default="./configs/videop2p.yaml")
661
  parser.add_argument("--fast", action='store_true')
662
  args = parser.parse_args()
663
 
 
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
  import os
7
  from typing import Optional, Union, Tuple, List, Callable, Dict
8
+ from omegaconf import OmegaConf
9
+
10
  import torch
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+
14
+ import diffusers
15
+ import transformers
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
20
+ from diffusers.optimization import get_scheduler
21
+ from diffusers.utils import check_min_version
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from tqdm.auto import tqdm
24
+ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer
25
+
26
+ from tuneavideo.models.unet import UNet3DConditionModel
27
+ from tuneavideo.data.dataset import TuneAVideoDataset
28
+ from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
29
+ from tuneavideo.util import save_videos_grid, ddim_inversion
30
+ from einops import rearrange
31
+
32
+ import cv2
33
  import abc
34
  import ptp_utils
35
  import seq_aligner
36
  import shutil
37
  from torch.optim.adam import Adam
38
  from PIL import Image
39
+ import numpy as np
40
+ import decord
41
+ decord.bridge.set_bridge('torch')
42
 
43
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
44
+ check_min_version("0.10.0.dev0")
45
 
46
+ logger = get_logger(__name__, log_level="INFO")
 
 
47
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def main(
50
  pretrained_model_path: str,
51
+ output_dir: str,
52
+ train_data: Dict,
53
+ validation_data: Dict,
54
+ validation_steps: int = 100,
55
+ trainable_modules: Tuple[str] = (
56
+ "attn1.to_q",
57
+ "attn2.to_q",
58
+ "attn_temp",
59
+ ),
60
+ train_batch_size: int = 1,
61
+ max_train_steps: int = 500,
62
+ learning_rate: float = 3e-5,
63
+ scale_lr: bool = False,
64
+ lr_scheduler: str = "constant",
65
+ lr_warmup_steps: int = 0,
66
+ adam_beta1: float = 0.9,
67
+ adam_beta2: float = 0.999,
68
+ adam_weight_decay: float = 1e-2,
69
+ adam_epsilon: float = 1e-08,
70
+ max_grad_norm: float = 1.0,
71
+ gradient_accumulation_steps: int = 1,
72
+ gradient_checkpointing: bool = True,
73
+ checkpointing_steps: int = 500,
74
+ resume_from_checkpoint: Optional[str] = None,
75
+ mixed_precision: Optional[str] = "fp16",
76
+ use_8bit_adam: bool = False,
77
+ enable_xformers_memory_efficient_attention: bool = True,
78
+ seed: Optional[int] = None,
79
+ # pretrained_model_path: str,
80
+ # image_path: str = None,
81
+ # prompt: str = None,
82
+ prompts: Tuple[str] = None,
83
+ eq_params: Dict = None,
84
+ save_name: str = None,
85
+ is_word_swap: bool = None,
86
  blend_word: Tuple[str] = None,
87
  cross_replace_steps: float = 0.2,
88
  self_replace_steps: float = 0.5,
89
  video_len: int = 8,
90
  fast: bool = False,
91
+ mixed_precision_p2p: str = 'fp32',
92
  ):
93
+
94
+ # Video-P2P
95
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
96
+ MY_TOKEN = ''
97
+ LOW_RESOURCE = False
98
+ NUM_DDIM_STEPS = 50
99
+ GUIDANCE_SCALE = 7.5
100
+ MAX_NUM_WORDS = 77
101
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
102
+
103
+ # need to adjust sometimes
104
+ mask_th = (.3, .3)
105
+
106
+
107
+ pretrained_model_path = output_dir
108
+ image_path = train_data['video_path']
109
+ prompt = train_data['prompt']
110
+ # prompts = [prompt, ]
111
  output_folder = os.path.join(pretrained_model_path, 'results')
112
  if fast:
113
  save_name_1 = os.path.join(output_folder, 'inversion_fast.gif')
 
122
  cross_replace_steps = {'default_': cross_replace_steps,}
123
 
124
  weight_dtype = torch.float32
125
+ if mixed_precision_p2p == "fp16":
126
  weight_dtype = torch.float16
127
+ elif mixed_precision_p2p == "bf16":
128
  weight_dtype = torch.bfloat16
129
 
130
  if not os.path.exists(output_folder):
 
165
  k = 1
166
  maps = (maps * alpha).sum(-1).mean(2)
167
  if use_pool:
168
+ maps = F.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
169
+ mask = F.interpolate(maps, size=(x_t.shape[3:]))
170
  mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
171
  mask = mask.gt(self.th[1-int(use_pool)])
172
  mask = mask[:1] + mask
 
444
 
445
 
446
  def load_512_seq(image_path, left=0, right=0, top=0, bottom=0, n_sample_frame=video_len, sampling_rate=1):
447
+ vr = decord.VideoReader(image_path, width=512, height=512)
448
+ sample_index = list(range(0, len(vr), sampling_rate))[:n_sample_frame]
449
+ video = vr.get_batch(sample_index)
450
+ return video.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
 
453
  class NullInversion:
 
580
  uncond_embeddings, cond_embeddings = self.context.chunk(2)
581
  uncond_embeddings_list = []
582
  latent_cur = latents[-1]
583
+ # bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
584
  for i in range(NUM_DDIM_STEPS):
585
  uncond_embeddings = uncond_embeddings.clone().detach()
586
  uncond_embeddings.requires_grad = True
 
593
  noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
594
  noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
595
  latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
596
+ loss = F.mse_loss(latents_prev_rec, latent_prev)
597
  optimizer.zero_grad()
598
  loss.backward()
599
  optimizer.step()
600
  loss_item = loss.item()
601
+ # bar.update()
602
  if loss_item < epsilon + i * 2e-5:
603
  break
604
+ # for j in range(j + 1, num_inner_steps):
605
+ # bar.update()
606
  uncond_embeddings_list.append(uncond_embeddings[:1].detach())
607
  with torch.no_grad():
608
  context = torch.cat([uncond_embeddings, cond_embeddings])
609
  latent_cur = self.get_noise_pred(latent_cur, t, False, context)
610
+ # bar.close()
611
  return uncond_embeddings_list
612
 
613
  def invert(self, image_path: str, prompt: str, offsets=(0,0,0,0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
 
688
  inversion.append( Image.fromarray((sequence1[i] * 255).numpy().astype(np.uint8)) )
689
  videop2p.append( Image.fromarray((sequence2[i] * 255).numpy().astype(np.uint8)) )
690
 
691
+ # inversion[0].save(save_name_1, save_all=True, append_images=inversion[1:], optimize=False, loop=0, duration=250)
692
  videop2p[0].save(save_name_2, save_all=True, append_images=videop2p[1:], optimize=False, loop=0, duration=250)
693
 
694
+
695
  if __name__ == "__main__":
696
  parser = argparse.ArgumentParser()
697
+ parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml")
698
  parser.add_argument("--fast", action='store_true')
699
  args = parser.parse_args()
700
 
trainer.py CHANGED
@@ -145,7 +145,9 @@ class Trainer:
145
  with open(config_path, 'w') as f:
146
  OmegaConf.save(config, f)
147
 
148
- command = f'accelerate launch Video-P2P/run.py --config {config_path} --fast'
 
 
149
  subprocess.run(shlex.split(command))
150
  save_model_card(save_dir=output_dir,
151
  base_model=base_model,
 
145
  with open(config_path, 'w') as f:
146
  OmegaConf.save(config, f)
147
 
148
+ command = f'accelerate launch Video-P2P/run_tuning.py --config {config_path}'
149
+ subprocess.run(shlex.split(command))
150
+ command = f'python Video-P2P/run_tuning.py --config {config_path} --fast'
151
  subprocess.run(shlex.split(command))
152
  save_model_card(save_dir=output_dir,
153
  base_model=base_model,