"""Fine-tuning script for Stable Video Diffusion for image2video with support for LoRA.""" | |
import logging
import math
import os
import shutil
from glob import glob
from pathlib import Path

from PIL import Image
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
import transformers
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
import copy

import diffusers
from diffusers import AutoencoderKLTemporalDecoder
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import _resize_with_antialiasing

from custom_diffusers.pipelines.pipeline_stable_video_diffusion_with_ref_attnmap import StableVideoDiffusionWithRefAttnMapPipeline
from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
from attn_ctrl.attention_control import (AttentionStore,
                                         register_temporal_self_attention_control,
                                         register_temporal_self_attention_flip_control,
                                         )
from utils.parse_args import parse_args
from dataset.stable_video_dataset import StableVideoDataset

logger = get_logger(__name__, log_level="INFO")

def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draws samples from a lognormal distribution."""
    u = torch.rand(shape, dtype=dtype, device=device) * (1 - 2e-7) + 1e-7
    return torch.distributions.Normal(loc, scale).icdf(u).exp()
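
# Note: rand_log_normal draws sigma = exp(N(loc, scale)). With loc=0.7 and scale=1.6 (used in the
# training loop below) this matches the EDM-style noise-level distribution P_mean=0.7, P_std=1.6.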

def main():
    args = parse_args()

    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, feature extractor and models.
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    feature_extractor = CLIPImageProcessor.from_pretrained(args.pretrained_model_name_or_path, subfolder="feature_extractor")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="image_encoder", variant=args.variant
    )
    vae = AutoencoderKLTemporalDecoder.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", variant=args.variant
    )
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", low_cpu_mem_usage=True, variant=args.variant
    )
    ref_unet = copy.deepcopy(unet)

    # register customized attn processors
    controller_ref = AttentionStore()
    register_temporal_self_attention_control(ref_unet, controller_ref)

    controller = AttentionStore()
    register_temporal_self_attention_flip_control(unet, controller, controller_ref)
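
    # `ref_unet` is a frozen copy of the base UNet. Inferred from the function names:
    # `register_temporal_self_attention_control` makes `ref_unet` store its temporal self-attention
    # maps in `controller_ref`, and `register_temporal_self_attention_flip_control` lets the
    # trainable `unet` reuse those maps (temporally flipped) in its own temporal self-attention.
    # The exact behaviour is defined in attn_ctrl/attention_control.py.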

    # freeze parameters of models to save more memory
    ref_unet.requires_grad_(False)
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, image_encoder and ref_unet) to half-precision,
    # as these weights are only used for inference and keeping them in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae, image_encoder and ref_unet to device and cast to weight_dtype
    # (the trainable unet intentionally stays in fp32):
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    ref_unet.to(accelerator.device, dtype=weight_dtype)

    unet_train_params_list = []
    # Select the parameters that will be trained; adjust this filter if you want to fine-tune a different subset.
    for name, para in unet.named_parameters():
        if 'temporal_transformer_blocks.0.attn1.to_v.weight' in name or 'temporal_transformer_blocks.0.attn1.to_out.0.weight' in name:
            unet_train_params_list.append(para)
            para.requires_grad = True
        else:
            para.requires_grad = False

    if args.mixed_precision == "fp16":
        # only upcast trainable parameters into fp32
        cast_training_params(unet, dtype=torch.float32)
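
    # With the filter above, only the value (to_v) and output (to_out.0) projections of attn1 in the
    # first temporal transformer block of each UNet module are trainable; every other UNet parameter
    # remains frozen.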

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "unet"))

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        def load_model_hook(models, input_dir):
            for _ in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNetSpatioTemporalConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
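
    # These hooks make `accelerator.save_state(...)` write the UNet in diffusers format under
    # "<checkpoint-dir>/unet" and reload it from there when resuming, instead of relying on the
    # default accelerate serialization.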

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if accelerator.is_main_process:
        rec_txt1 = open('frozen_param.txt', 'w')
        rec_txt2 = open('train_param.txt', 'w')
        for name, para in unet.named_parameters():
            if para.requires_grad is False:
                rec_txt1.write(f'{name}\n')
            else:
                rec_txt2.write(f'{name}\n')
        rec_txt1.close()
        rec_txt2.close()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        unet_train_params_list,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    train_dataset = StableVideoDataset(video_data_dir=args.train_data_dir,
                                       max_num_videos=args.max_train_samples,
                                       num_frames=args.num_frames,
                                       is_reverse_video=True,
                                       double_sampling_rate=args.double_sampling_rate)
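
    # Assumed interface (inferred from how the batch is used below): each example provides
    #   "pixel_values": (num_frames, C, H, W) video frames in [-1, 1], and
    #   "conditions":   (C, H, W) conditioning frame in [-1, 1].
    # Reversal/double-sampling handling lives in dataset/stable_video_dataset.py.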

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        conditions = torch.stack([example["conditions"] for example in examples])
        conditions = conditions.to(memory_format=torch.contiguous_format).float()
        return {"pixel_values": pixel_values, "conditions": conditions}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Validation data
    if args.validation_data_dir is not None:
        validation_image_paths = sorted(glob(os.path.join(args.validation_data_dir, '*.png')))
        num_validation_images = min(args.num_validation_images, len(validation_image_paths))
        validation_image_paths = validation_image_paths[:num_validation_images]
        validation_images = [Image.open(image_path).convert('RGB').resize((1024, 576)) for image_path in validation_image_paths]

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("image2video-reverse-fine-tune", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    # default motion param setting
    def _get_add_time_ids(
        dtype,
        batch_size,
        fps=6,
        motion_bucket_id=127,
        noise_aug_strength=0.02,
    ):
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
        # `unet` may be wrapped by accelerate at this point, so unwrap it before reading attributes.
        passed_add_embed_dim = unwrap_model(unet).config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = unwrap_model(unet).add_embedding.linear_1.in_features
        assert (expected_add_embed_dim == passed_add_embed_dim)

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size, 1)
        return add_time_ids
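
    # SVD micro-conditioning: the additional time embedding encodes (fps, motion_bucket_id,
    # noise_aug_strength). The defaults above match the standard SVD image-to-video settings
    # (the pipeline conditions on fps - 1, so fps=6 here corresponds to 7 fps playback).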

    def compute_image_embeddings(image):
        image = _resize_with_antialiasing(image, (224, 224))
        image = (image + 1.0) / 2.0

        # Normalize the image for CLIP input
        image = feature_extractor(
            images=image,
            do_normalize=True,
            do_center_crop=False,
            do_resize=False,
            do_rescale=False,
            return_tensors="pt",
        ).pixel_values

        image = image.to(accelerator.device).to(dtype=weight_dtype)
        image_embeddings = image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)
        return image_embeddings
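
    # CLIP image conditioning: frames arrive in [-1, 1], are antialias-resized to 224x224,
    # rescaled to [0, 1], normalized by the CLIP feature extractor, then encoded into image
    # embeddings of shape (batch, 1, projection_dim) used as cross-attention conditioning.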

    noise_aug_strength = 0.02
    fps = 7

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Get the image embedding for conditioning
                encoder_hidden_states = compute_image_embeddings(batch["conditions"])
                encoder_hidden_states_ref = compute_image_embeddings(batch["pixel_values"][:, -1])

                batch["conditions"] = batch["conditions"].to(accelerator.device).to(dtype=weight_dtype)
                batch["pixel_values"] = batch["pixel_values"].to(accelerator.device).to(dtype=weight_dtype)

                # Get the image latent for input conditioning
                noise = torch.randn_like(batch["conditions"])
                conditions = batch["conditions"] + noise_aug_strength * noise
                conditions_latent = vae.encode(conditions).latent_dist.mode()
                conditions_latent = conditions_latent.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)

                conditions_ref = batch["pixel_values"][:, -1] + noise_aug_strength * noise
                conditions_latent_ref = vae.encode(conditions_ref).latent_dist.mode()
                conditions_latent_ref = conditions_latent_ref.unsqueeze(1).repeat(1, args.num_frames, 1, 1, 1)
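
                # As in SVD training, the noise-augmented conditioning image is VAE-encoded, repeated
                # across all frames, and later concatenated to the noisy latents along the channel
                # dimension; `conditions_latent_ref` does the same for the reference branch, which is
                # conditioned on the clip's last frame.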

                # Convert frames to latent space
                pixel_values = rearrange(batch["pixel_values"], "b f c h w -> (b f) c h w")
                latents = vae.encode(pixel_values).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                latents = rearrange(latents, "(b f) c h w -> b f c h w", f=args.num_frames)
                latents_ref = torch.flip(latents, dims=(1,))
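
                # `latents_ref` is the temporally flipped clip. The intent (inferred from the
                # flip-control registration above) is that the frozen reference UNet processes the
                # reversed video while the trainable UNet processes the forward video, sharing
                # (flipped) temporal self-attention maps between the two branches.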

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], latents.shape[2], 1, 1), device=latents.device
                    )
                bsz = latents.shape[0]

                # Sample a random noise level (sigma) for each video
                # P_mean=0.7 P_std=1.6
                sigmas = rand_log_normal(shape=[bsz,], loc=0.7, scale=1.6).to(latents.device)
                sigmas = sigmas[:, None, None, None, None]
                timesteps = torch.Tensor(
                    [0.25 * sigma.log() for sigma in sigmas]).to(accelerator.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = latents + noise * sigmas
                noisy_latents_inp = noisy_latents / ((sigmas**2 + 1) ** 0.5)
                noisy_latents_inp = torch.cat([noisy_latents_inp, conditions_latent], dim=2)

                noisy_latents_ref = latents_ref + torch.flip(noise, dims=(1,)) * sigmas
                noisy_latents_ref_inp = noisy_latents_ref / ((sigmas**2 + 1) ** 0.5)
                noisy_latents_ref_inp = torch.cat([noisy_latents_ref_inp, conditions_latent_ref], dim=2)
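
                # EDM-style preconditioning (as in SVD): the network input is scaled by
                # c_in = 1 / sqrt(sigma^2 + 1) and the timestep/noise conditioning uses
                # c_noise = 0.25 * ln(sigma).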

                # Get the target for loss depending on the prediction type
                target = latents

                # Predict the noise residual and compute loss
                added_time_ids = _get_add_time_ids(encoder_hidden_states.dtype, bsz).to(accelerator.device)
                # `ref_model_pred` itself is not used in the loss; this forward pass presumably
                # populates `controller_ref` with temporal attention maps that the trainable UNet
                # reuses (flipped) in its own forward pass below.
                ref_model_pred = ref_unet(noisy_latents_ref_inp.to(weight_dtype), timesteps.to(weight_dtype),
                                          encoder_hidden_states=encoder_hidden_states_ref,
                                          added_time_ids=added_time_ids,
                                          return_dict=False)[0]

                model_pred = unet(noisy_latents_inp, timesteps,
                                  encoder_hidden_states=encoder_hidden_states,
                                  added_time_ids=added_time_ids,
                                  return_dict=False)[0]  # v-prediction

                # Denoise the latents
                c_out = -sigmas / ((sigmas**2 + 1) ** 0.5)
                c_skip = 1 / (sigmas**2 + 1)
                denoised_latents = model_pred * c_out + c_skip * noisy_latents
                weighing = (1 + sigmas ** 2) * (sigmas ** -2.0)

                # MSE loss
                loss = torch.mean(
                    (weighing.float() * (denoised_latents.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                    dim=1,
                )
                loss = loss.mean()
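
                # EDM denoiser parameterization: D(x; sigma) = c_skip * x + c_out * F(x) with
                # c_skip = 1 / (sigma^2 + 1) and c_out = -sigma / sqrt(sigma^2 + 1); the squared error
                # against the clean latents is weighted by (sigma^2 + 1) / sigma^2, matching the code above.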

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = unet_train_params_list
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_data_dir is not None and epoch % args.validation_epochs == 0:
                logger.info(
                    f"Running validation... \n Generating videos for {num_validation_images} validation images"
                    f" from {args.validation_data_dir}."
                )
                # create pipeline
                pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    scheduler=noise_scheduler,
                    unet=unwrap_model(unet),
                    variant=args.variant,
                    torch_dtype=weight_dtype,
                )
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                generator = torch.Generator(device=accelerator.device)
                if args.seed is not None:
                    generator = generator.manual_seed(args.seed)
                videos = []
                with torch.cuda.amp.autocast():
                    for val_idx in range(num_validation_images):
                        val_img = validation_images[val_idx]
                        videos.append(
                            pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
                        )

                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        videos = torch.stack(videos)
                        tracker.writer.add_video("validation", videos, epoch, fps=fps)

                del pipeline
                torch.cuda.empty_cache()

    # Save the fine-tuned pipeline
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)

        unwrapped_unet = unwrap_model(unet)
        pipeline = StableVideoDiffusionWithRefAttnMapPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            scheduler=noise_scheduler,
            unet=unwrapped_unet,
            variant=args.variant,
        )
        pipeline.save_pretrained(args.output_dir)

        # Final inference
        # Load previous pipeline
        if args.validation_data_dir is not None:
            pipeline = pipeline.to(accelerator.device)
            pipeline.torch_dtype = weight_dtype

            # run inference
            generator = torch.Generator(device=accelerator.device)
            if args.seed is not None:
                generator = generator.manual_seed(args.seed)
            videos = []
            with torch.cuda.amp.autocast():
                for val_idx in range(num_validation_images):
                    val_img = validation_images[val_idx]
                    videos.append(
                        pipeline(ref_unet=ref_unet, image=val_img, ref_image=val_img, num_inference_steps=50, generator=generator, output_type='pt').frames[0]
                    )

            for tracker in accelerator.trackers:
                if len(videos) != 0:
                    if tracker.name == "tensorboard":
                        videos = torch.stack(videos)
                        tracker.writer.add_video("validation", videos, epoch, fps=fps)

    accelerator.end_training()


if __name__ == "__main__":
    main()