Files changed (1) hide show
  1. main.py +252 -0
main.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import os
4
+ from pathlib import Path
5
+ import uuid
6
+ from src.pipelines.pipeline_animatediff_pix2pix import StableDiffusionInstructPix2PixPipeline
7
+ from diffusers import EulerAncestralDiscreteScheduler
8
+ import torch
9
+ from src.models.unet import UNet3DConditionModel
10
+ import numpy as np
11
+ from PIL import Image
12
+ import imageio
13
+
14
+ def convert_frames_to_mp4(frames, filename, fps=30):
15
+ """Converts a list of PIL Image frames to an MP4 file.
16
+
17
+ Args:
18
+ frames: A list of PIL Image frames.
19
+ filename: The name of the MP4 file to save.
20
+ fps: Frames per second for the video.
21
+
22
+ Returns:
23
+ None
24
+ """
25
+ # Convert PIL Images to numpy arrays
26
+ numpy_frames = [np.array(frame) for frame in frames]
27
+ # Write frames to mp4
28
+ imageio.mimwrite(filename, numpy_frames, fps=fps)
29
+
30
+ def convert_frames_to_gif(frames, filename, duration=100):
31
+ """Converts a list of PIL Image frames to a GIF file.
32
+
33
+ Args:
34
+ frames: A list of PIL Image frames.
35
+ filename: The name of the GIF file to save.
36
+ duration: Duration of each frame in milliseconds.
37
+
38
+ Returns:
39
+ None
40
+ """
41
+ frames[0].save(
42
+ filename,
43
+ save_all=True,
44
+ append_images=frames[1:],
45
+ loop=0,
46
+ duration=duration
47
+ )
48
+
49
+
50
+ def convert_frames_to_gif_with_fps(frames, filename, fps=30):
51
+ """Converts a list of PIL Image frames to a GIF file using fps.
52
+
53
+ Args:
54
+ frames: A list of PIL Image frames.
55
+ filename: The name of the GIF file to save.
56
+ fps: Frames per second for the gif.
57
+
58
+ Returns:
59
+ None
60
+ """
61
+ duration = 1000 // fps
62
+ frames[0].save(
63
+ filename,
64
+ save_all=True,
65
+ append_images=frames[1:],
66
+ loop=0,
67
+ duration=duration
68
+ )
69
+
70
+
71
+ def run(t2i_model,
72
+ prompt="",
73
+ negative_prompt="",
74
+ frame_count=16,
75
+ num_inference_steps=20,
76
+ guidance_scale=7.5,
77
+ image_guidance_scale=1.5,
78
+ width=512,
79
+ height=512,
80
+ dtype="float16",
81
+ output_frames_directory="output_frames",
82
+ output_video_directory="output_video",
83
+ output_gif_directory="output_gif",
84
+ motion_module="viddle/viddle-pix2pix-animatediff-v1.ckpt",
85
+ init_image=None,
86
+ init_folder=None,
87
+ seed=42,
88
+ fps=15,
89
+ no_save_frames=False,
90
+ no_save_video=False,
91
+ no_save_gif=False,
92
+ ):
93
+ scheduler_kwargs = {
94
+ "num_train_timesteps": 1000,
95
+ "beta_start": 0.00085,
96
+ "beta_end": 0.012,
97
+ "beta_schedule": "linear",
98
+ }
99
+
100
+ device = "cuda" if torch.cuda.is_available() else "cpu"
101
+ if dtype == "float16":
102
+ dtype = torch.float16
103
+ variant = "fp16"
104
+ elif dtype == "float32":
105
+ dtype = torch.float32
106
+ variant = "fp32"
107
+
108
+ unet_additional_kwargs = {
109
+ "in_channels": 8,
110
+ "unet_use_cross_frame_attention": False,
111
+ "unet_use_temporal_attention": False,
112
+ "use_motion_module": True,
113
+ "motion_module_resolutions": [1, 2, 4, 8],
114
+ "motion_module_mid_block": False,
115
+ "motion_module_decoder_only": False,
116
+ "motion_module_type": "Vanilla",
117
+ "motion_module_kwargs": {
118
+ "num_attention_heads": 8,
119
+ "num_transformer_block": 1,
120
+ "attention_block_types": ["Temporal_Self", "Temporal_Self"],
121
+ "temporal_position_encoding": True,
122
+ "temporal_position_encoding_max_len": 32,
123
+ "temporal_attention_dim_div": 1,
124
+ },
125
+ }
126
+
127
+ pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
128
+ t2i_model,
129
+ scheduler=EulerAncestralDiscreteScheduler(**scheduler_kwargs),
130
+ safety_checker=None,
131
+ feature_extractor=None,
132
+ requires_safety_checker=False,
133
+ torch_dtype=dtype,
134
+ variant=variant,
135
+ ).to(device)
136
+
137
+ pipeline.unet = UNet3DConditionModel.from_pretrained_unet(pipeline.unet,
138
+ unet_additional_kwargs=unet_additional_kwargs,
139
+ ).to(device=device, dtype=dtype)
140
+
141
+ pipeline.enable_vae_slicing()
142
+
143
+ motion_module_state_dict = torch.load(motion_module, map_location="cpu")
144
+ _, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
145
+ assert len(unexpected) == 0
146
+
147
+ if init_image is not None and init_folder is None:
148
+ image = Image.open(init_image)
149
+ image = image.resize((width, height))
150
+ elif init_folder is not None and init_image is None:
151
+ image_paths = glob.glob(init_folder + "/*.png")
152
+ # add the jpgs
153
+ image_paths += glob.glob(init_folder + "/*.jpg")
154
+ image_paths.sort()
155
+ image_paths = image_paths[:frame_count]
156
+
157
+ image = []
158
+
159
+ for image_path in image_paths:
160
+ image.append(Image.open(image_path).resize((width, height)))
161
+ else:
162
+ raise ValueError("Must provide either init_image or init_folder but not both")
163
+
164
+ generator = torch.Generator(device=device).manual_seed(seed)
165
+
166
+ frames = pipeline(prompt=prompt,
167
+ negative_prompt=negative_prompt,
168
+ num_inference_steps=num_inference_steps,
169
+ guidance_scale=guidance_scale,
170
+ image_guidance_scale=image_guidance_scale,
171
+ image=image,
172
+ video_length=frame_count,
173
+ generator=generator,
174
+ )[0]
175
+
176
+ # create a uuid prefix for the output files
177
+ uuid_prefix = str(uuid.uuid4())
178
+
179
+ if not no_save_frames:
180
+ # Create output directory
181
+ Path(output_frames_directory).mkdir(parents=True, exist_ok=True)
182
+
183
+ # make the specific directory for this run
184
+ output_frames_directory = os.path.join(output_frames_directory, uuid_prefix)
185
+ Path(output_frames_directory).mkdir(parents=True, exist_ok=True)
186
+ # Save frames
187
+ for i, frame in enumerate(frames):
188
+ frame.save(os.path.join(output_frames_directory, f"{str(i).zfill(4)}.png"))
189
+
190
+ if not no_save_video:
191
+ # Create output directory
192
+ Path(output_video_directory).mkdir(parents=True, exist_ok=True)
193
+
194
+ convert_frames_to_mp4(frames, os.path.join(output_video_directory, f"{uuid_prefix}.mp4"), fps=fps)
195
+
196
+ if not no_save_gif:
197
+ # Create output directory
198
+ Path(output_gif_directory).mkdir(parents=True, exist_ok=True)
199
+
200
+ # Convert frames to GIF
201
+ convert_frames_to_gif(frames, os.path.join(output_gif_directory, f"{uuid_prefix}.gif"), duration=1000 // fps)
202
+
203
+
204
+ if __name__ == "__main__":
205
+ argsparser = argparse.ArgumentParser()
206
+ argsparser.add_argument("--prompt", type=str, default="")
207
+ argsparser.add_argument("--negative_prompt", type=str, default="")
208
+ argsparser.add_argument("--frame_count", type=int, default=16)
209
+ argsparser.add_argument("--num_inference_steps", type=int, default=20)
210
+ argsparser.add_argument("--guidance_scale", type=float, default=7.5)
211
+ argsparser.add_argument("--image_guidance_scale", type=float, default=1.5)
212
+ argsparser.add_argument("--width", type=int, default=512)
213
+ argsparser.add_argument("--height", type=int, default=512)
214
+ argsparser.add_argument("--dtype", type=str, default="float16")
215
+ argsparser.add_argument("--output_frames_directory", type=str, default="output_frames")
216
+ argsparser.add_argument("--output_video_directory", type=str, default="output_videos")
217
+ argsparser.add_argument("--output_gif_directory", type=str, default="output_gifs")
218
+ argsparser.add_argument("--init_image", type=str, default=None)
219
+ argsparser.add_argument("--init_folder", type=str, default=None)
220
+ argsparser.add_argument("--motion_module", type=str, default="checkpoints/viddle-pix2pix-animatediff-v1.ckpt")
221
+ argsparser.add_argument("--t2i_model", type=str, default="timbrooks/instruct-pix2pix")
222
+ argsparser.add_argument("--seed", type=int, default=42)
223
+ argsparser.add_argument("--fps", type=int, default=15)
224
+ argsparser.add_argument("--no_save_frames", action="store_true", default=False)
225
+ argsparser.add_argument("--no_save_video", action="store_true", default=False)
226
+ argsparser.add_argument("--no_save_gif", action="store_true", default=False)
227
+ args = argsparser.parse_args()
228
+
229
+ run(t2i_model=args.t2i_model,
230
+ prompt=args.prompt,
231
+ negative_prompt=args.negative_prompt,
232
+ frame_count=args.frame_count,
233
+ num_inference_steps=args.num_inference_steps,
234
+ guidance_scale=args.guidance_scale,
235
+ width=args.width,
236
+ height=args.height,
237
+ dtype=args.dtype,
238
+ output_frames_directory=args.output_frames_directory,
239
+ output_video_directory=args.output_video_directory,
240
+ output_gif_directory=args.output_gif_directory,
241
+ motion_module=args.motion_module,
242
+ init_image=args.init_image,
243
+ init_folder=args.init_folder,
244
+ seed=args.seed,
245
+ fps=args.fps,
246
+ no_save_frames=args.no_save_frames,
247
+ no_save_video=args.no_save_video,
248
+ no_save_gif=args.no_save_gif,
249
+ )
250
+
251
+
252
+