# AnyV2V / i2vgen-xl / run_group_pnp_edit.py
import os
import sys
from pathlib import Path
import torch
import argparse
import logging
from omegaconf import OmegaConf
from PIL import Image
import json
# HF imports
from diffusers import (
    DDIMInverseScheduler,
    DDIMScheduler,
)
from diffusers.utils import load_image, export_to_video, export_to_gif
# Project imports
from utils import (
    seed_everything,
    load_video_frames,
    convert_video_to_frames,
    load_ddim_latents_at_T,
    load_ddim_latents_at_t,
)
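# Project-local pipeline; it is expected to expose sample_with_pnp() (used below)
# in addition to the stock diffusers I2VGenXL interface.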
from pipelines.pipeline_i2vgen_xl import I2VGenXLPipeline
from pnp_utils import (
    register_time,
    register_conv_injection,
    register_spatial_attention_pnp,
    register_temp_attention_pnp,
)


def init_pnp(pipe, scheduler, config):
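    """Register the PnP feature/attention injection hooks on the pipeline.

    The pnp_f_t / pnp_spatial_attn_t / pnp_temp_attn_t ratios in the config decide
    how many of the earliest denoising timesteps receive convolution-feature,
    spatial-attention, and temporal-attention injection, respectively.
    """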
    logger = logging.getLogger(__name__)
    conv_injection_t = int(config.n_steps * config.pnp_f_t)
    spatial_attn_qk_injection_t = int(config.n_steps * config.pnp_spatial_attn_t)
    temp_attn_qk_injection_t = int(config.n_steps * config.pnp_temp_attn_t)
    conv_injection_timesteps = scheduler.timesteps[:conv_injection_t] if conv_injection_t >= 0 else []
    spatial_attn_qk_injection_timesteps = (
        scheduler.timesteps[:spatial_attn_qk_injection_t] if spatial_attn_qk_injection_t >= 0 else []
    )
    temp_attn_qk_injection_timesteps = (
        scheduler.timesteps[:temp_attn_qk_injection_t] if temp_attn_qk_injection_t >= 0 else []
    )
    register_conv_injection(pipe, conv_injection_timesteps)
    register_spatial_attention_pnp(pipe, spatial_attn_qk_injection_timesteps)
    register_temp_attention_pnp(pipe, temp_attn_qk_injection_timesteps)
    logger.debug(f"conv_injection_t: {conv_injection_t}")
    logger.debug(f"spatial_attn_qk_injection_t: {spatial_attn_qk_injection_t}")
    logger.debug(f"temp_attn_qk_injection_t: {temp_attn_qk_injection_t}")
    logger.debug(f"conv_injection_timesteps: {conv_injection_timesteps}")
    logger.debug(f"spatial_attn_qk_injection_timesteps: {spatial_attn_qk_injection_timesteps}")
    logger.debug(f"temp_attn_qk_injection_timesteps: {temp_attn_qk_injection_timesteps}")


def main(template_config, configs_list):
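    """Run PnP-guided editing for every active entry in configs_list.

    Each entry is merged on top of template_config; the cached DDIM-inversion
    latents of the source video are loaded, and the edited first frame plus the
    editing prompt drive I2VGen-XL to re-render the full clip.
    """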
    # Initialize the pipeline
    pipe = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to(device)

    # Initialize the DDIM scheduler
    ddim_scheduler = DDIMScheduler.from_pretrained(
        "ali-vilab/i2vgen-xl",
        subfolder="scheduler",
    )
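    # The same scheduler instance is reused for every config: its timesteps are
    # reset per entry via set_timesteps(), and it also drives the PnP injection
    # schedules in init_pnp().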
    for config_entry in configs_list:
        if not config_entry["active"]:
            logger.info(f"Skipping config_entry: {config_entry}")
            continue
        logger.info(f"Processing config_entry: {config_entry}")

        # Override the template config with this entry's values
        config = OmegaConf.merge(template_config, OmegaConf.create(config_entry))

        # Build the full paths from the configured directories
        config.video_path = os.path.join(config.video_dir, config.video_name + ".mp4")
        config.video_frames_path = os.path.join(config.video_dir, config.video_name)
        config.edited_first_frame_path = os.path.join(config.data_dir, config.edited_first_frame_path)
        logger.info(f"config: {OmegaConf.to_yaml(config)}")

        # Skip entries that still contain the "ReplaceMe" placeholder
        has_placeholder = False
        for k, v in config.items():
            if "ReplaceMe" in str(v):
                logger.error(f"Field {k} contains 'ReplaceMe'")
                has_placeholder = True
        if has_placeholder:
            continue
        # From here on, this mirrors run_pnp_edit.py
        # Load first frame and source frames
        try:
            logger.info(f"Loading frames from: {config.video_frames_path}")
            _, frame_list = load_video_frames(config.video_frames_path, config.n_frames, config.image_size)
        except Exception as e:
            logger.error(f"Failed to load frames from {config.video_frames_path}: {e}")
            logger.info(f"Converting mp4 video to frames: {config.video_path}")
            frame_list = convert_video_to_frames(config.video_path, config.image_size, save_frames=True)
            frame_list = frame_list[: config.n_frames]  # e.g. 16 frames for img2vid
        logger.debug(f"len(frame_list): {len(frame_list)}")
        src_frame_list = frame_list
        src_1st_frame = src_frame_list[0]  # A PIL image

        # Load the edited first frame
        edited_1st_frame = load_image(config.edited_first_frame_path)
        edited_1st_frame = edited_1st_frame.resize(config.image_size, resample=Image.Resampling.LANCZOS)
        # Load the initial latents at t
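        # ddim_init_latents_t_idx selects which step of the sampling schedule to
        # start from; the inversion latents cached for that timestep seed the
        # sampler instead of pure noise.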
        ddim_init_latents_t_idx = config.ddim_init_latents_t_idx
        ddim_scheduler.set_timesteps(config.n_steps)
        logger.info(f"ddim_scheduler.timesteps: {ddim_scheduler.timesteps}")
        ddim_latents_at_t = load_ddim_latents_at_t(
            ddim_scheduler.timesteps[ddim_init_latents_t_idx], ddim_latents_path=config.ddim_latents_path
        )
        logger.debug(f"ddim_scheduler.timesteps[t_idx]: {ddim_scheduler.timesteps[ddim_init_latents_t_idx]}")
        logger.debug(f"ddim_latents_at_t.shape: {ddim_latents_at_t.shape}")
        # Blend the latents
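        # random_ratio linearly interpolates between fresh Gaussian noise (1.0)
        # and the DDIM-inverted latents of the source video (0.0); higher values
        # weaken the tie to the original video's structure.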
        random_latents = torch.randn_like(ddim_latents_at_t)
        logger.info(f"Blending random_ratio (1 means random latent): {config.random_ratio}")
        mixed_latents = random_latents * config.random_ratio + ddim_latents_at_t * (1 - config.random_ratio)

        # Init PnP
        init_pnp(pipe, ddim_scheduler, config)
        # Edit video
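        # Swap in the DDIM scheduler configured above so sampling follows the same
        # timestep schedule as the cached inversion latents.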
        pipe.register_modules(scheduler=ddim_scheduler)
        edited_video = pipe.sample_with_pnp(
            prompt=config.editing_prompt,
            image=edited_1st_frame,
            height=config.image_size[1],
            width=config.image_size[0],
            num_frames=config.n_frames,
            num_inference_steps=config.n_steps,
            guidance_scale=config.cfg,
            negative_prompt=config.editing_negative_prompt,
            target_fps=config.target_fps,
            latents=mixed_latents,
            generator=torch.manual_seed(config.seed),
            return_dict=True,
            ddim_init_latents_t_idx=ddim_init_latents_t_idx,
            ddim_inv_latents_path=config.ddim_latents_path,
            ddim_inv_prompt=config.ddim_inv_prompt,
            ddim_inv_1st_frame=src_1st_frame,
        ).frames[0]
        # Save video
        # Encode the key config values in the output_dir name, TODO: make this more elegant
        config_suffix = (
            f"ddim_init_latents_t_idx_{ddim_init_latents_t_idx}"
            f"_nsteps_{config.n_steps}"
            f"_cfg_{config.cfg}"
            f"_pnpf{config.pnp_f_t}"
            f"_pnps{config.pnp_spatial_attn_t}"
            f"_pnpt{config.pnp_temp_attn_t}"
        )
        output_dir = os.path.join(config.output_dir, config_suffix)
        os.makedirs(output_dir, exist_ok=True)
        edited_video = [frame.resize(config.image_size, resample=Image.Resampling.LANCZOS) for frame in edited_video]
        # Optionally downsample the video to save space
        # edited_video = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in edited_video]
        # if config.pnp_f_t == 0.0 and config.pnp_spatial_attn_t == 0.0 and config.pnp_temp_attn_t == 0.0:
        #     edited_video_file_name = "ddim_edit"
        # else:
        #     edited_video_file_name = "pnp_edit"
        edited_video_file_name = "video"
        export_to_video(edited_video, os.path.join(output_dir, f"{edited_video_file_name}.mp4"), fps=config.target_fps)
        export_to_gif(edited_video, os.path.join(output_dir, f"{edited_video_file_name}.gif"))
        logger.info(f"Saved video to: {os.path.join(output_dir, f'{edited_video_file_name}.mp4')}")
        logger.info(f"Saved gif to: {os.path.join(output_dir, f'{edited_video_file_name}.gif')}")
        for i, frame in enumerate(edited_video):
            frame.save(os.path.join(output_dir, f"{edited_video_file_name}_{i:05d}.png"))
            logger.info(f"Saved frame to: {os.path.join(output_dir, f'{edited_video_file_name}_{i:05d}.png')}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--template_config", type=str, default="./configs/group_pnp_edit/template.yaml")
    parser.add_argument(
        "--configs_json", type=str, default="./configs/group_config.json"
    )  # Entries in this JSON override the template_config
    args = parser.parse_args()
    template_config = OmegaConf.load(args.template_config)

    # Set up logging
    logging_level = logging.DEBUG if template_config.debug else logging.INFO
    logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - [%(funcName)s] - %(message)s")
    logger = logging.getLogger(__name__)
    logger.info(f"template_config: {OmegaConf.to_yaml(template_config)}")

    # Load the list of per-video configs from the JSON file
    configs_json = args.configs_json
    assert Path(configs_json).exists()
    with open(configs_json, "r") as file:
        configs_list = json.load(file)
    logger.info(f"Loaded {len(configs_list)} configs from {configs_json}")

    # Set up device and seed
    device = torch.device(template_config.device)
    torch.set_grad_enabled(False)
    seed_everything(template_config.seed)

    main(template_config, configs_list)