from typing import List, Optional

import torch
import torch.distributed as dist

from utils.scheduler import SchedulerInterface
from utils.wan_wrapper import WanDiffusionWrapper


class SelfForcingTrainingPipeline:
    def __init__(self,
                 denoising_step_list: List[int],
                 scheduler: SchedulerInterface,
                 generator: WanDiffusionWrapper,
                 num_frame_per_block: int = 3,
                 independent_first_frame: bool = False,
                 same_step_across_blocks: bool = False,
                 last_step_only: bool = False,
                 num_max_frames: int = 21,
                 context_noise: int = 0,
                 **kwargs):
        super().__init__()
        self.scheduler = scheduler
        self.generator = generator

        # Keep the denoising schedule as a 1-D long tensor so individual entries
        # can later be indexed and moved to GPU when computing the denoised
        # timestep range at the end of inference_with_trajectory.
        if not torch.is_tensor(denoising_step_list):
            denoising_step_list = torch.tensor(denoising_step_list, dtype=torch.long)
        self.denoising_step_list = denoising_step_list
        if self.denoising_step_list[-1] == 0:
            # Drop the trailing zero timestep; the final denoising step already
            # yields the clean prediction.
            self.denoising_step_list = self.denoising_step_list[:-1]

        # Wan-specific constants: transformer depth and tokens per latent frame.
        self.num_transformer_blocks = 30
        self.frame_seq_length = 1560
        self.num_frame_per_block = num_frame_per_block
        self.context_noise = context_noise
        self.i2v = False

        # Per-GPU attention caches, allocated lazily at the start of each rollout.
        self.kv_cache1 = None
        self.kv_cache2 = None
        self.crossattn_cache = None
        self.independent_first_frame = independent_first_frame
        self.same_step_across_blocks = same_step_across_blocks
        self.last_step_only = last_step_only
        self.kv_cache_size = num_max_frames * self.frame_seq_length

    def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
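        """
        Sample one random denoising exit index per block on rank 0 and broadcast
        it so every rank follows the same per-block schedule. Returns a Python
        list of length num_blocks with values in [0, num_denoising_steps).
        """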
        rank = dist.get_rank() if dist.is_initialized() else 0

        if rank == 0:
            # Only rank 0 samples; the result is broadcast below so that all
            # ranks agree on when each block exits the denoising loop.
            indices = torch.randint(
                low=0,
                high=num_denoising_steps,
                size=(num_blocks,),
                device=device
            )
            if self.last_step_only:
                indices = torch.ones_like(indices) * (num_denoising_steps - 1)
        else:
            indices = torch.empty(num_blocks, dtype=torch.long, device=device)

        if dist.is_initialized():
            dist.broadcast(indices, src=0)
        return indices.tolist()

    def inference_with_trajectory(
            self,
            noise: torch.Tensor,
            initial_latent: Optional[torch.Tensor] = None,
            return_sim_step: bool = False,
            **conditional_dict
    ):
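        """
        Autoregressively roll out the generator block by block with a KV cache,
        exiting each block at a randomly chosen step of the few-step denoising
        schedule so that training sees the model's own generations.

        Returns the generated clean latents plus the (from, to) timestep range
        covered by the exit step (both None when blocks exit at different
        steps); if return_sim_step is True, the number of simulated denoising
        steps is appended as a fourth value.
        """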
        batch_size, num_frames, num_channels, height, width = noise.shape
        if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
            # Every generated block has the same number of frames; any
            # independent first frame is supplied via initial_latent.
            assert num_frames % self.num_frame_per_block == 0
            num_blocks = num_frames // self.num_frame_per_block
        else:
            # The first frame is generated on its own, followed by fixed-size
            # blocks for the remaining frames.
            assert (num_frames - 1) % self.num_frame_per_block == 0
            num_blocks = (num_frames - 1) // self.num_frame_per_block
        num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
        num_output_frames = num_frames + num_input_frames
        output = torch.zeros(
            [batch_size, num_output_frames, num_channels, height, width],
            device=noise.device,
            dtype=noise.dtype
        )

        # Allocate fresh per-layer KV and cross-attention caches for this rollout.
        self._initialize_kv_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )
        self._initialize_crossattn_cache(
            batch_size=batch_size, dtype=noise.dtype, device=noise.device
        )

        # Cache the context features of any provided initial latent (assumed to
        # be a single clean frame, e.g. the encoded first image for I2V) before
        # generating new frames.
        current_start_frame = 0
        if initial_latent is not None:
            timestep = torch.zeros([batch_size, 1], device=noise.device, dtype=torch.int64)
            output[:, :1] = initial_latent
            with torch.no_grad():
                self.generator(
                    noisy_image_or_video=initial_latent,
                    conditional_dict=conditional_dict,
                    timestep=timestep,
                    kv_cache=self.kv_cache1,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length
                )
            current_start_frame += 1

        # Per-block frame counts; prepend a single-frame block when the first
        # frame is generated independently and no initial latent was given.
        all_num_frames = [self.num_frame_per_block] * num_blocks
        if self.independent_first_frame and initial_latent is None:
            all_num_frames = [1] + all_num_frames

        # Pick one exit step per block and sync it across ranks. For example,
        # with 7 blocks and a 4-step schedule, exit_flags might be
        # [2, 0, 3, 1, 3, 0, 2]: block i runs steps 0..exit_flags[i] and only
        # the exit step can carry gradients.
        num_denoising_steps = len(self.denoising_step_list)
        exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)

        # Only the last 21 output frames receive gradients; earlier blocks are
        # generated under no_grad to bound activation memory.
        start_gradient_frame_index = num_output_frames - 21

        for block_index, current_num_frames in enumerate(all_num_frames):
            noisy_input = noise[
                :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]

            # Run the few-step denoising schedule for this block, stopping at the
            # randomly selected exit step.
            for index, current_timestep in enumerate(self.denoising_step_list):
                if self.same_step_across_blocks:
                    exit_flag = (index == exit_flags[0])
                else:
                    exit_flag = (index == exit_flags[block_index])
                timestep = torch.ones(
                    [batch_size, current_num_frames],
                    device=noise.device,
                    dtype=torch.int64) * current_timestep

                if not exit_flag:
                    # Intermediate step: denoise without gradients, then re-noise
                    # the prediction to the next timestep in the schedule.
                    with torch.no_grad():
                        _, denoised_pred = self.generator(
                            noisy_image_or_video=noisy_input,
                            conditional_dict=conditional_dict,
                            timestep=timestep,
                            kv_cache=self.kv_cache1,
                            crossattn_cache=self.crossattn_cache,
                            current_start=current_start_frame * self.frame_seq_length
                        )
                        next_timestep = self.denoising_step_list[index + 1]
                        noisy_input = self.scheduler.add_noise(
                            denoised_pred.flatten(0, 1),
                            torch.randn_like(denoised_pred.flatten(0, 1)),
                            next_timestep * torch.ones(
                                [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
                        ).unflatten(0, denoised_pred.shape[:2])
                else:
                    # Exit step: produce this block's final prediction. Gradients
                    # are kept only for blocks inside the last-21-frame window.
                    if current_start_frame < start_gradient_frame_index:
                        with torch.no_grad():
                            _, denoised_pred = self.generator(
                                noisy_image_or_video=noisy_input,
                                conditional_dict=conditional_dict,
                                timestep=timestep,
                                kv_cache=self.kv_cache1,
                                crossattn_cache=self.crossattn_cache,
                                current_start=current_start_frame * self.frame_seq_length
                            )
                    else:
                        _, denoised_pred = self.generator(
                            noisy_image_or_video=noisy_input,
                            conditional_dict=conditional_dict,
                            timestep=timestep,
                            kv_cache=self.kv_cache1,
                            crossattn_cache=self.crossattn_cache,
                            current_start=current_start_frame * self.frame_seq_length
                        )
                    break

            # Record the block's clean prediction in the output buffer.
            output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred

            # Re-noise the clean prediction to the context noise level and run the
            # generator once more without gradients, so the KV cache stores this
            # block as context for the blocks that follow.
            context_timestep = torch.ones_like(timestep) * self.context_noise
            denoised_pred = self.scheduler.add_noise(
                denoised_pred.flatten(0, 1),
                torch.randn_like(denoised_pred.flatten(0, 1)),
                # Use the scalar noise level here so the timestep tensor is flat,
                # matching the flattened frames passed to add_noise.
                self.context_noise * torch.ones(
                    [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
            ).unflatten(0, denoised_pred.shape[:2])
            with torch.no_grad():
                self.generator(
                    noisy_image_or_video=denoised_pred,
                    conditional_dict=conditional_dict,
                    timestep=context_timestep,
                    kv_cache=self.kv_cache1,
                    crossattn_cache=self.crossattn_cache,
                    current_start=current_start_frame * self.frame_seq_length
                )

            # Advance the write pointer to the next block.
            current_start_frame += current_num_frames

        # Map the chosen exit step (and the step after it) onto the scheduler's
        # 1000-step timestep grid so the caller knows which noise-level range the
        # returned latents cover. Only meaningful when all blocks share the same
        # exit step.
        if not self.same_step_across_blocks:
            denoised_timestep_from, denoised_timestep_to = None, None
        elif exit_flags[0] == len(self.denoising_step_list) - 1:
            denoised_timestep_to = 0
            denoised_timestep_from = 1000 - torch.argmin(
                (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
        else:
            denoised_timestep_to = 1000 - torch.argmin(
                (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
            denoised_timestep_from = 1000 - torch.argmin(
                (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()

        if return_sim_step:
            return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1

        return output, denoised_timestep_from, denoised_timestep_to

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU KV cache for the Wan model.
        """
        kv_cache1 = []

        for _ in range(self.num_transformer_blocks):
            kv_cache1.append({
                # Self-attention keys/values: [batch, tokens, heads, head_dim],
                # plus pointers tracking how much of the cache is filled.
                "k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
            })
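
        # Size note (derived from the defaults above): with num_max_frames = 21,
        # the cache spans 21 * 1560 = 32760 tokens per layer, i.e. two
        # [batch, 32760, 12, 128] tensors for each of the 30 transformer blocks.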
        self.kv_cache1 = kv_cache1

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU cross-attention cache for the Wan model.
        """
        crossattn_cache = []

        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append({
                # Cross-attention keys/values over the 512-token text context;
                # computed once per prompt and flagged via is_init.
                "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
                "is_init": False
            })
        self.crossattn_cache = crossattn_cache
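

# Illustrative usage sketch (not part of the original module). The pipeline is
# normally constructed by a trainer that already owns the generator and the
# scheduler; the names and shapes below are hypothetical placeholders for a
# Wan-style latent layout.
#
#   pipeline = SelfForcingTrainingPipeline(
#       denoising_step_list=[1000, 750, 500, 250, 0],
#       scheduler=scheduler,        # a SchedulerInterface implementation
#       generator=generator,        # a WanDiffusionWrapper instance
#       num_frame_per_block=3,
#       same_step_across_blocks=True,
#   )
#   noise = torch.randn(1, 21, 16, 60, 104, device="cuda", dtype=torch.bfloat16)
#   output, t_from, t_to = pipeline.inference_with_trajectory(
#       noise=noise, **text_conditioning)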