Spaces:
No application file
No application file
import argparse | |
import glob | |
import os | |
from pathlib import Path | |
import uuid | |
from src.pipelines.pipeline_animatediff_pix2pix import StableDiffusionInstructPix2PixPipeline | |
from diffusers import EulerAncestralDiscreteScheduler | |
import torch | |
from src.models.unet import UNet3DConditionModel | |
import numpy as np | |
from PIL import Image | |
import imageio | |
def convert_frames_to_mp4(frames, filename, fps=30):
    """Write a sequence of PIL Image frames out as an MP4 video.

    Args:
        frames: A list of PIL Image frames, in playback order.
        filename: Path of the MP4 file to create.
        fps: Playback rate of the video, in frames per second.

    Returns:
        None
    """
    # Turn every PIL frame into a numpy array before handing the batch to imageio.
    frame_arrays = list(map(np.array, frames))
    imageio.mimwrite(filename, frame_arrays, fps=fps)
def convert_frames_to_gif(frames, filename, duration=100):
    """Save a list of PIL Image frames as a looping animated GIF.

    Args:
        frames: A list of PIL Image frames; the first one is the image that
            gets saved, the rest are appended to it as further frames.
        filename: Path of the GIF file to create.
        duration: Display time of each frame, in milliseconds.

    Returns:
        None
    """
    first_frame = frames[0]
    remaining_frames = frames[1:]
    save_options = {
        "save_all": True,
        "append_images": remaining_frames,
        "loop": 0,
        "duration": duration,
    }
    first_frame.save(filename, **save_options)
def convert_frames_to_gif_with_fps(frames, filename, fps=30):
    """Save a list of PIL Image frames as a looping animated GIF at a given fps.

    The per-frame display time is derived from ``fps`` as ``1000 // fps``
    milliseconds.

    Args:
        frames: A list of PIL Image frames; the first one is the image that
            gets saved, the rest are appended to it as further frames.
        filename: Path of the GIF file to create.
        fps: Frames per second for the GIF. Must be greater than zero.

    Returns:
        None

    Raises:
        ValueError: If ``fps`` is zero or negative.
    """
    # Reject bad rates up front: 0 would otherwise surface as a bare
    # ZeroDivisionError and a negative value as a negative frame duration.
    if fps <= 0:
        raise ValueError(f"fps must be a positive number, got {fps!r}")
    duration = 1000 // fps
    frames[0].save(
        filename,
        save_all=True,
        append_images=frames[1:],
        loop=0,
        duration=duration,
    )
def run(t2i_model,
        prompt="",
        negative_prompt="",
        frame_count=16,
        num_inference_steps=20,
        guidance_scale=7.5,
        image_guidance_scale=1.5,
        width=512,
        height=512,
        dtype="float16",
        output_frames_directory="output_frames",
        output_video_directory="output_video",
        output_gif_directory="output_gif",
        motion_module="viddle/viddle-pix2pix-animatediff-v1.ckpt",
        init_image=None,
        init_folder=None,
        seed=42,
        fps=15,
        no_save_frames=False,
        no_save_video=False,
        no_save_gif=False,
        ):
    """Run the InstructPix2Pix pipeline with a motion module and save the result.

    Loads the pipeline from ``t2i_model``, replaces its UNet with a
    ``UNet3DConditionModel`` built from it, loads the motion-module checkpoint
    into that UNet, runs the pipeline on the initial image(s), and writes the
    resulting frames as PNG files, an MP4 and/or a GIF. Every output of one
    call shares the same freshly generated UUID in its name.

    Args:
        t2i_model: Model identifier or path passed to ``from_pretrained``.
        prompt: Prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        frame_count: Passed to the pipeline as ``video_length``; also the
            maximum number of images taken from ``init_folder``.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        image_guidance_scale: Image guidance scale passed to the pipeline.
        width: Width the initial image(s) are resized to.
        height: Height the initial image(s) are resized to.
        dtype: Either ``"float16"`` or ``"float32"``.
        output_frames_directory: Directory under which a per-run sub-directory
            of PNG frames is created.
        output_video_directory: Directory the MP4 file is written to.
        output_gif_directory: Directory the GIF file is written to.
        motion_module: Path of the motion-module checkpoint for ``torch.load``.
        init_image: Path of a single initial image. Mutually exclusive with
            ``init_folder``.
        init_folder: Folder of ``.png``/``.jpg`` initial images. Mutually
            exclusive with ``init_image``.
        seed: Seed for the torch random generator handed to the pipeline.
        fps: Frame rate of the MP4 and, as ``1000 // fps`` ms per frame, of
            the GIF.
        no_save_frames: Skip writing the individual PNG frames.
        no_save_video: Skip writing the MP4 file.
        no_save_gif: Skip writing the GIF file.

    Returns:
        None

    Raises:
        ValueError: If neither or both of ``init_image`` and ``init_folder``
            are given, if ``dtype`` is not supported, if ``fps`` is not
            positive while a video or GIF is requested, or if
            ``init_folder`` contains no matching images.
        RuntimeError: If the motion-module checkpoint contains keys the UNet
            does not expect.
    """
    # Validate the cheap things first so a bad argument fails before any model
    # is loaded or any frame is generated.
    if (init_image is None) == (init_folder is None):
        raise ValueError("Must provide either init_image or init_folder but not both")

    if dtype == "float16":
        torch_dtype = torch.float16
        variant = "fp16"
    elif dtype == "float32":
        torch_dtype = torch.float32
        # Full-precision weights are normally published without a variant
        # suffix; asking for "fp32" makes the loader look for *.fp32.* files.
        variant = None
    else:
        raise ValueError(f"Unsupported dtype {dtype!r}; expected 'float16' or 'float32'")

    writes_animation = not (no_save_video and no_save_gif)
    if writes_animation and fps <= 0:
        raise ValueError(f"fps must be a positive number, got {fps!r}")

    if init_image is not None:
        # Image.open is lazy; resize() returns an independent copy, so the
        # underlying file can be closed right away.
        with Image.open(init_image) as source_image:
            image = source_image.resize((width, height))
    else:
        image_paths = glob.glob(os.path.join(init_folder, "*.png"))
        image_paths += glob.glob(os.path.join(init_folder, "*.jpg"))
        image_paths.sort()
        image_paths = image_paths[:frame_count]
        if not image_paths:
            raise ValueError(f"No .png or .jpg images found in {init_folder!r}")
        # NOTE(review): when the folder holds fewer than frame_count images the
        # pipeline still receives video_length=frame_count - confirm that the
        # pipeline copes with the mismatch.
        image = []
        for image_path in image_paths:
            with Image.open(image_path) as source_image:
                image.append(source_image.resize((width, height)))

    scheduler_kwargs = {
        "num_train_timesteps": 1000,
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "linear",
    }
    unet_additional_kwargs = {
        "in_channels": 8,
        "unet_use_cross_frame_attention": False,
        "unet_use_temporal_attention": False,
        "use_motion_module": True,
        "motion_module_resolutions": [1, 2, 4, 8],
        "motion_module_mid_block": False,
        "motion_module_decoder_only": False,
        "motion_module_type": "Vanilla",
        "motion_module_kwargs": {
            "num_attention_heads": 8,
            "num_transformer_block": 1,
            "attention_block_types": ["Temporal_Self", "Temporal_Self"],
            "temporal_position_encoding": True,
            "temporal_position_encoding_max_len": 32,
            "temporal_attention_dim_div": 1,
        },
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"

    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        t2i_model,
        scheduler=EulerAncestralDiscreteScheduler(**scheduler_kwargs),
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
        torch_dtype=torch_dtype,
        variant=variant,
    ).to(device)
    pipeline.unet = UNet3DConditionModel.from_pretrained_unet(
        pipeline.unet,
        unet_additional_kwargs=unet_additional_kwargs,
    ).to(device=device, dtype=torch_dtype)
    pipeline.enable_vae_slicing()

    # NOTE(review): torch.load unpickles the checkpoint, so only load
    # motion-module files from a trusted source.
    motion_module_state_dict = torch.load(motion_module, map_location="cpu")
    # strict=False because the checkpoint is loaded on top of an already
    # initialised UNet; keys the UNet does not know are still an error.
    _, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
    if unexpected:
        # Explicit raise instead of assert so the check survives ``python -O``.
        raise RuntimeError(
            f"Motion module checkpoint {motion_module!r} contains unexpected keys: {unexpected}"
        )

    generator = torch.Generator(device=device).manual_seed(seed)
    frames = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        image_guidance_scale=image_guidance_scale,
        image=image,
        video_length=frame_count,
        generator=generator,
    )[0]

    # One UUID per run names the frame folder, the video and the GIF.
    uuid_prefix = str(uuid.uuid4())

    if not no_save_frames:
        # parents=True creates the base output directory as well.
        run_frames_directory = os.path.join(output_frames_directory, uuid_prefix)
        Path(run_frames_directory).mkdir(parents=True, exist_ok=True)
        for index, frame in enumerate(frames):
            frame.save(os.path.join(run_frames_directory, f"{index:04d}.png"))

    if not no_save_video:
        Path(output_video_directory).mkdir(parents=True, exist_ok=True)
        video_path = os.path.join(output_video_directory, f"{uuid_prefix}.mp4")
        convert_frames_to_mp4(frames, video_path, fps=fps)

    if not no_save_gif:
        Path(output_gif_directory).mkdir(parents=True, exist_ok=True)
        gif_path = os.path.join(output_gif_directory, f"{uuid_prefix}.gif")
        convert_frames_to_gif(frames, gif_path, duration=1000 // fps)
if __name__ == "__main__":
    # Command-line entry point: every option maps onto the run() parameter of
    # the same name.
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--frame_count", type=int, default=16)
    parser.add_argument("--num_inference_steps", type=int, default=20)
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--image_guidance_scale", type=float, default=1.5)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--height", type=int, default=512)
    # run() only handles these two dtype names, so reject anything else here.
    parser.add_argument("--dtype", type=str, default="float16", choices=["float16", "float32"])
    parser.add_argument("--output_frames_directory", type=str, default="output_frames")
    parser.add_argument("--output_video_directory", type=str, default="output_videos")
    parser.add_argument("--output_gif_directory", type=str, default="output_gifs")
    parser.add_argument("--init_image", type=str, default=None)
    parser.add_argument("--init_folder", type=str, default=None)
    parser.add_argument("--motion_module", type=str, default="checkpoints/viddle-pix2pix-animatediff-v1.ckpt")
    parser.add_argument("--t2i_model", type=str, default="timbrooks/instruct-pix2pix")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--fps", type=int, default=15)
    parser.add_argument("--no_save_frames", action="store_true", default=False)
    parser.add_argument("--no_save_video", action="store_true", default=False)
    parser.add_argument("--no_save_gif", action="store_true", default=False)
    args = parser.parse_args()

    run(
        t2i_model=args.t2i_model,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        frame_count=args.frame_count,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        # Previously parsed but never forwarded, so the flag had no effect.
        image_guidance_scale=args.image_guidance_scale,
        width=args.width,
        height=args.height,
        dtype=args.dtype,
        output_frames_directory=args.output_frames_directory,
        output_video_directory=args.output_video_directory,
        output_gif_directory=args.output_gif_directory,
        motion_module=args.motion_module,
        init_image=args.init_image,
        init_folder=args.init_folder,
        seed=args.seed,
        fps=args.fps,
        no_save_frames=args.no_save_frames,
        no_save_video=args.no_save_video,
        no_save_gif=args.no_save_gif,
    )