# ImageMotion / main.py

import argparse
import glob
import os
import uuid
from pathlib import Path

import imageio
import numpy as np
import torch
from diffusers import EulerAncestralDiscreteScheduler
from PIL import Image

from src.models.unet import UNet3DConditionModel
from src.pipelines.pipeline_animatediff_pix2pix import StableDiffusionInstructPix2PixPipeline


def convert_frames_to_mp4(frames, filename, fps=30):
"""Converts a list of PIL Image frames to an MP4 file.
Args:
frames: A list of PIL Image frames.
filename: The name of the MP4 file to save.
fps: Frames per second for the video.
Returns:
None
"""
    # Convert PIL Images to numpy arrays
    numpy_frames = [np.array(frame) for frame in frames]
    # Write frames to mp4
    imageio.mimwrite(filename, numpy_frames, fps=fps)


def convert_frames_to_gif(frames, filename, duration=100):
"""Converts a list of PIL Image frames to a GIF file.
Args:
frames: A list of PIL Image frames.
filename: The name of the GIF file to save.
duration: Duration of each frame in milliseconds.
Returns:
None
"""
frames[0].save(
filename,
save_all=True,
append_images=frames[1:],
loop=0,
duration=duration
)


def convert_frames_to_gif_with_fps(frames, filename, fps=30):
"""Converts a list of PIL Image frames to a GIF file using fps.
Args:
frames: A list of PIL Image frames.
filename: The name of the GIF file to save.
fps: Frames per second for the gif.
Returns:
None
"""
duration = 1000 // fps
frames[0].save(
filename,
save_all=True,
append_images=frames[1:],
loop=0,
duration=duration
)


def run(t2i_model,
prompt="",
negative_prompt="",
frame_count=16,
num_inference_steps=20,
guidance_scale=7.5,
image_guidance_scale=1.5,
width=512,
height=512,
dtype="float16",
output_frames_directory="output_frames",
output_video_directory="output_video",
output_gif_directory="output_gif",
motion_module="viddle/viddle-pix2pix-animatediff-v1.ckpt",
init_image=None,
init_folder=None,
seed=42,
fps=15,
no_save_frames=False,
no_save_video=False,
no_save_gif=False,
):
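    # Euler-ancestral scheduler settings (1000 training timesteps, linear beta
    # schedule) as used by the AnimateDiff reference configs.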
scheduler_kwargs = {
"num_train_timesteps": 1000,
"beta_start": 0.00085,
"beta_end": 0.012,
"beta_schedule": "linear",
}
device = "cuda" if torch.cuda.is_available() else "cpu"
if dtype == "float16":
dtype = torch.float16
variant = "fp16"
elif dtype == "float32":
dtype = torch.float32
variant = "fp32"
unet_additional_kwargs = {
"in_channels": 8,
"unet_use_cross_frame_attention": False,
"unet_use_temporal_attention": False,
"use_motion_module": True,
"motion_module_resolutions": [1, 2, 4, 8],
"motion_module_mid_block": False,
"motion_module_decoder_only": False,
"motion_module_type": "Vanilla",
"motion_module_kwargs": {
"num_attention_heads": 8,
"num_transformer_block": 1,
"attention_block_types": ["Temporal_Self", "Temporal_Self"],
"temporal_position_encoding": True,
"temporal_position_encoding_max_len": 32,
"temporal_attention_dim_div": 1,
},
}
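    # Build the 2D InstructPix2Pix pipeline first; its UNet is then swapped for
    # the 3D variant so the temporal (motion) layers can be loaded on top.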
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
t2i_model,
scheduler=EulerAncestralDiscreteScheduler(**scheduler_kwargs),
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
torch_dtype=dtype,
variant=variant,
).to(device)
    pipeline.unet = UNet3DConditionModel.from_pretrained_unet(
        pipeline.unet,
        unet_additional_kwargs=unet_additional_kwargs,
    ).to(device=device, dtype=dtype)
pipeline.enable_vae_slicing()
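    # strict=False: the checkpoint contains only the temporal layers, so the
    # base UNet weights are reported as "missing"; unexpected keys would
    # indicate a mismatched checkpoint.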
    motion_module_state_dict = torch.load(motion_module, map_location="cpu")
    _, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
    assert len(unexpected) == 0, f"Unexpected keys in motion module checkpoint: {unexpected}"
if init_image is not None and init_folder is None:
image = Image.open(init_image)
image = image.resize((width, height))
elif init_folder is not None and init_image is None:
image_paths = glob.glob(init_folder + "/*.png")
# add the jpgs
image_paths += glob.glob(init_folder + "/*.jpg")
image_paths.sort()
image_paths = image_paths[:frame_count]
image = []
for image_path in image_paths:
image.append(Image.open(image_path).resize((width, height)))
else:
raise ValueError("Must provide either init_image or init_folder but not both")
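    # Seed the generator on the target device so runs are reproducible.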
generator = torch.Generator(device=device).manual_seed(seed)
    frames = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        image_guidance_scale=image_guidance_scale,
        image=image,
        video_length=frame_count,
        generator=generator,
    )[0]
# create a uuid prefix for the output files
uuid_prefix = str(uuid.uuid4())
if not no_save_frames:
# Create output directory
Path(output_frames_directory).mkdir(parents=True, exist_ok=True)
# make the specific directory for this run
output_frames_directory = os.path.join(output_frames_directory, uuid_prefix)
Path(output_frames_directory).mkdir(parents=True, exist_ok=True)
# Save frames
for i, frame in enumerate(frames):
frame.save(os.path.join(output_frames_directory, f"{str(i).zfill(4)}.png"))
if not no_save_video:
# Create output directory
Path(output_video_directory).mkdir(parents=True, exist_ok=True)
convert_frames_to_mp4(frames, os.path.join(output_video_directory, f"{uuid_prefix}.mp4"), fps=fps)
if not no_save_gif:
# Create output directory
Path(output_gif_directory).mkdir(parents=True, exist_ok=True)
# Convert frames to GIF
convert_frames_to_gif(frames, os.path.join(output_gif_directory, f"{uuid_prefix}.gif"), duration=1000 // fps)


if __name__ == "__main__":
argsparser = argparse.ArgumentParser()
argsparser.add_argument("--prompt", type=str, default="")
argsparser.add_argument("--negative_prompt", type=str, default="")
argsparser.add_argument("--frame_count", type=int, default=16)
argsparser.add_argument("--num_inference_steps", type=int, default=20)
argsparser.add_argument("--guidance_scale", type=float, default=7.5)
argsparser.add_argument("--image_guidance_scale", type=float, default=1.5)
argsparser.add_argument("--width", type=int, default=512)
argsparser.add_argument("--height", type=int, default=512)
argsparser.add_argument("--dtype", type=str, default="float16")
argsparser.add_argument("--output_frames_directory", type=str, default="output_frames")
argsparser.add_argument("--output_video_directory", type=str, default="output_videos")
argsparser.add_argument("--output_gif_directory", type=str, default="output_gifs")
argsparser.add_argument("--init_image", type=str, default=None)
argsparser.add_argument("--init_folder", type=str, default=None)
argsparser.add_argument("--motion_module", type=str, default="checkpoints/viddle-pix2pix-animatediff-v1.ckpt")
argsparser.add_argument("--t2i_model", type=str, default="timbrooks/instruct-pix2pix")
argsparser.add_argument("--seed", type=int, default=42)
argsparser.add_argument("--fps", type=int, default=15)
argsparser.add_argument("--no_save_frames", action="store_true", default=False)
argsparser.add_argument("--no_save_video", action="store_true", default=False)
argsparser.add_argument("--no_save_gif", action="store_true", default=False)
args = argsparser.parse_args()
    run(
        t2i_model=args.t2i_model,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        frame_count=args.frame_count,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        image_guidance_scale=args.image_guidance_scale,
        width=args.width,
        height=args.height,
        dtype=args.dtype,
        output_frames_directory=args.output_frames_directory,
        output_video_directory=args.output_video_directory,
        output_gif_directory=args.output_gif_directory,
        motion_module=args.motion_module,
        init_image=args.init_image,
        init_folder=args.init_folder,
        seed=args.seed,
        fps=args.fps,
        no_save_frames=args.no_save_frames,
        no_save_video=args.no_save_video,
        no_save_gif=args.no_save_gif,
    )
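
# Example invocation (illustrative; the prompt and input path are placeholders):
#   python main.py \
#       --prompt "make it look like a watercolor painting" \
#       --init_image input.png \
#       --frame_count 16 --fps 15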