voidDescriptor committed on
Commit 8a943d8
1 Parent(s): d02956d

Upload 3 files

Files changed (3)
  1. fine_tune.py +987 -0
  2. setup.py +15 -0
  3. utils.py +228 -0
fine_tune.py ADDED
@@ -0,0 +1,987 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+ import math
17
+ import os
18
+ import traceback
19
+ from pathlib import Path
20
+ import time
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ import torch.multiprocessing as mp
24
+ from accelerate import Accelerator
25
+ from accelerate.logging import get_logger
26
+ from accelerate.utils import set_seed
27
+ from diffusers import AutoencoderKL
28
+ from diffusers.optimization import get_scheduler
29
+ from diffusers import DDPMScheduler
30
+ from torchvision import transforms
31
+ from tqdm.auto import tqdm
32
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
33
+ import torch.nn.functional as F
34
+ import gc
35
+ from typing import Callable
36
+ from PIL import Image
37
+ import numpy as np
38
+ from concurrent.futures import ThreadPoolExecutor
39
+ from hotshot_xl.models.unet import UNet3DConditionModel
40
+ from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
41
+ from hotshot_xl.utils import get_crop_coordinates, res_to_aspect_map, scale_aspect_fill
42
+ from einops import rearrange
43
+ from torch.utils.data import Dataset, DataLoader
44
+ from datetime import timedelta
45
+ from accelerate.utils.dataclasses import InitProcessGroupKwargs
46
+ from diffusers.utils import is_wandb_available
47
+
48
+ if is_wandb_available():
49
+ import wandb
50
+
51
+ logger = get_logger(__file__)
52
+
53
+
54
+ class HotshotXLDataset(Dataset):
55
+
56
+ def __init__(self, directory: str, make_sample_fn: Callable):
57
+ """
58
+
59
+ Training data folder needs to look like:
60
+ + training_samples
61
+ --- + sample_001
62
+ ------- + frame_0.jpg
63
+ ------- + frame_1.jpg
64
+ ------- + ...
65
+ ------- + frame_n.jpg
66
+ ------- + prompt.txt
67
+ --- + sample_002
68
+ ------- + frame_0.jpg
69
+ ------- + frame_1.jpg
70
+ ------- + ...
71
+ ------- + frame_n.jpg
72
+ ------- + prompt.txt
73
+
74
+ Args:
75
+ directory: base directory of the training samples
76
+ make_sample_fn: a delegate call to load the images and prep the sample for batching
77
+ """
78
+ samples_dir = [os.path.join(directory, p) for p in os.listdir(directory)]
79
+ samples_dir = [p for p in samples_dir if os.path.isdir(p)]
80
+ samples = []
81
+
82
+ for d in samples_dir:
83
+ file_paths = [os.path.join(d, p) for p in os.listdir(d)]
84
+ image_fps = [f for f in file_paths if os.path.splitext(f)[1] in {".png", ".jpg"}]
85
+ with open(os.path.join(d, "prompt.txt")) as f:
86
+ prompt = f.read().strip()
87
+
88
+ samples.append({
89
+ "image_fps": image_fps,
90
+ "prompt": prompt
91
+ })
92
+
93
+ self.samples = samples
94
+ self.length = len(samples)
95
+ self.make_sample_fn = make_sample_fn
96
+
97
+ def __len__(self):
98
+ return self.length
99
+
100
+ def __getitem__(self, index):
101
+ return self.make_sample_fn(
102
+ self.samples[index]
103
+ )
104
+
105
+
106
+ def parse_args():
107
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
108
+ parser.add_argument(
109
+ "--pretrained_model_name_or_path",
110
+ type=str,
111
+ default="hotshotco/Hotshot-XL",
112
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
113
+ )
114
+ parser.add_argument(
115
+ "--unet_resume_path",
116
+ type=str,
117
+ default=None,
118
+ help="Path to a previously saved UNet checkpoint directory to resume training from.",
119
+ )
120
+
121
+ parser.add_argument(
122
+ "--data_dir",
123
+ type=str,
124
+ required=True,
125
+ help="Path to data to train.",
126
+ )
127
+
128
+ parser.add_argument(
129
+ "--report_to",
130
+ type=str,
131
+ default="wandb",
132
+ help=(
133
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
134
+ ', `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
135
+ ),
136
+ )
137
+
138
+ parser.add_argument("--run_validation_at_start", action="store_true")
139
+ parser.add_argument("--max_vae_encode", type=int, default=None)
140
+ parser.add_argument("--vae_b16", action="store_true")
141
+ parser.add_argument("--disable_optimizer_restore", action="store_true")
142
+
143
+ parser.add_argument(
144
+ "--latent_nan_checking",
145
+ action="store_true",
146
+ help="Check if latents contain nans - important if vae is f16",
147
+ )
148
+ parser.add_argument(
149
+ "--test_prompts",
150
+ type=str,
151
+ default=None,
152
+ )
153
+ parser.add_argument(
154
+ "--project_name",
155
+ type=str,
156
+ default="fine-tune-hotshot-xl",
157
+ help="the name of the project",
158
+ )
159
+ parser.add_argument(
160
+ "--run_name",
161
+ type=str,
162
+ default="run-01",
163
+ help="the name of the run",
164
+ )
165
+ parser.add_argument(
166
+ "--output_dir",
167
+ type=str,
168
+ default="output",
169
+ help="The output directory where the model predictions and checkpoints will be written.",
170
+ )
171
+ parser.add_argument("--noise_offset", type=float, default=0.05, help="The scale of noise offset.")
172
+ parser.add_argument("--seed", type=int, default=111, help="A seed for reproducible training.")
173
+ parser.add_argument(
174
+ "--resolution",
175
+ type=int,
176
+ default=512,
177
+ help=(
178
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
179
+ " resolution"
180
+ ),
181
+ )
182
+ parser.add_argument(
183
+ "--aspect_ratio",
184
+ type=str,
185
+ default="1.75",
186
+ choices=list(res_to_aspect_map[512].keys()),
187
+ help="Aspect ratio to train at",
188
+ )
189
+
190
+ parser.add_argument("--xformers", action="store_true")
191
+
192
+ parser.add_argument(
193
+ "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
194
+ )
195
+
196
+ parser.add_argument("--num_train_epochs", type=int, default=1)
197
+
198
+ parser.add_argument(
199
+ "--max_train_steps",
200
+ type=int,
201
+ default=9999999,
202
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
203
+ )
204
+ parser.add_argument(
205
+ "--gradient_accumulation_steps",
206
+ type=int,
207
+ default=1,
208
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
209
+ )
210
+ parser.add_argument(
211
+ "--gradient_checkpointing",
212
+ action="store_true",
213
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
214
+ )
215
+
216
+ parser.add_argument(
217
+ "--learning_rate",
218
+ type=float,
219
+ default=5e-6,
220
+ help="Initial learning rate (after the potential warmup period) to use.",
221
+ )
222
+
223
+ parser.add_argument(
224
+ "--scale_lr",
225
+ action="store_true",
226
+ default=False,
227
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
228
+ )
229
+ parser.add_argument(
230
+ "--lr_scheduler",
231
+ type=str,
232
+ default="constant",
233
+ help=(
234
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
235
+ ' "constant", "constant_with_warmup"]'
236
+ ),
237
+ )
238
+ parser.add_argument(
239
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
240
+ )
241
+ parser.add_argument(
242
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
243
+ )
244
+
245
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
246
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
247
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
248
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
249
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
250
+
251
+ parser.add_argument(
252
+ "--logging_dir",
253
+ type=str,
254
+ default="logs",
255
+ help=(
256
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
257
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
258
+ ),
259
+ )
260
+
261
+ parser.add_argument(
262
+ "--mixed_precision",
263
+ type=str,
264
+ default="no",
265
+ choices=["no", "fp16", "bf16"],
266
+ help=(
267
+ "Whether to use mixed precision. Choose"
268
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
269
+ "and an Nvidia Ampere GPU."
270
+ ),
271
+ )
272
+
273
+ parser.add_argument(
274
+ "--validate_every_steps",
275
+ type=int,
276
+ default=100,
277
+ help="Run validation inference every this many global steps.",
278
+ )
279
+
280
+ parser.add_argument(
281
+ "--save_n_steps",
282
+ type=int,
283
+ default=100,
284
+ help="Save the model every n global_steps",
285
+ )
286
+
287
+ parser.add_argument(
288
+ "--save_starting_step",
289
+ type=int,
290
+ default=100,
291
+ help="The step from which it starts saving intermediary checkpoints",
292
+ )
293
+
294
+ parser.add_argument(
295
+ "--nccl_timeout",
296
+ type=int,
297
+ help="NCCL process group timeout in seconds.",
298
+ default=3600
299
+ )
300
+
301
+ parser.add_argument("--snr_gamma", type=float, default=None, help="SNR weighting gamma for Min-SNR loss rebalancing; 5.0 is a common value. Disabled when unset.")
302
+
303
+ args = parser.parse_args()
304
+
305
+ return args
306
+
307
+
308
+ def add_time_ids(
309
+ unet_config,
310
+ unet_add_embedding,
311
+ text_encoder_2: CLIPTextModelWithProjection,
312
+ original_size: tuple,
313
+ crops_coords_top_left: tuple,
314
+ target_size: tuple,
315
+ dtype: torch.dtype):
316
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
317
+
318
+ passed_add_embed_dim = (
319
+ unet_config.addition_time_embed_dim * len(add_time_ids) + text_encoder_2.config.projection_dim
320
+ )
321
+ expected_add_embed_dim = unet_add_embedding.linear_1.in_features
322
+
323
+ if expected_add_embed_dim != passed_add_embed_dim:
324
+ raise ValueError(
325
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
326
+ )
327
+
328
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
329
+ return add_time_ids
330
+
331
+
332
+ def main():
333
+ global_step = 0
334
+ min_steps_before_validation = 0
335
+
336
+ args = parse_args()
337
+
338
+ next_save_iter = args.save_starting_step
339
+
340
+ if args.save_starting_step < 1:
341
+ next_save_iter = None
342
+
343
+ if args.report_to == "wandb":
344
+ if not is_wandb_available():
345
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
346
+
347
+ accelerator = Accelerator(
348
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
349
+ mixed_precision=args.mixed_precision,
350
+ log_with=args.report_to,
351
+ kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=args.nccl_timeout))]
352
+ )
353
+
354
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
355
+ def save_model_hook(models, weights, output_dir):
356
+ nonlocal global_step
357
+
358
+ for model in models:
359
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
360
+ model.save_pretrained(os.path.join(output_dir, 'unet'))
361
+ # make sure to pop weight so that corresponding model is not saved again
362
+ weights.pop()
363
+
364
+ accelerator.register_save_state_pre_hook(save_model_hook)
365
+
366
+ set_seed(args.seed)
367
+
368
+ # Handle the repository creation
369
+ if accelerator.is_local_main_process:
370
+ if args.output_dir is not None:
371
+ os.makedirs(args.output_dir, exist_ok=True)
372
+
373
+ # Load the tokenizer
374
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
375
+ tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
376
+
377
+ # Load models and create wrapper for stable diffusion
378
+ text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
379
+ text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path,
380
+ subfolder="text_encoder_2")
381
+
382
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
383
+
384
+ optimizer_resume_path = None
385
+
386
+ if args.unet_resume_path:
387
+ optimizer_fp = os.path.join(args.unet_resume_path, "optimizer.bin")
388
+
389
+ if os.path.exists(optimizer_fp):
390
+ optimizer_resume_path = optimizer_fp
391
+
392
+ unet = UNet3DConditionModel.from_pretrained(args.unet_resume_path,
393
+ subfolder="unet",
394
+ low_cpu_mem_usage=False,
395
+ device_map=None)
396
+
397
+ else:
398
+ unet = UNet3DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
399
+
400
+ if args.xformers:
401
+ vae.set_use_memory_efficient_attention_xformers(True, None)
402
+ unet.set_use_memory_efficient_attention_xformers(True, None)
403
+
404
+ unet_config = unet.config
405
+ unet_add_embedding = unet.add_embedding
406
+
407
+ unet.requires_grad_(False)
408
+
409
+ temporal_params = unet.temporal_parameters()
410
+
411
+ for p in temporal_params:
412
+ p.requires_grad_(True)
413
+
414
+ vae.requires_grad_(False)
415
+ text_encoder.requires_grad_(False)
416
+ text_encoder_2.requires_grad_(False)
417
+
418
+ if args.gradient_checkpointing:
419
+ unet.enable_gradient_checkpointing()
420
+
421
+ if args.scale_lr:
422
+ args.learning_rate = (
423
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
424
+ )
425
+
426
+ # Use 8-bit Adam for lower memory usage
427
+ if args.use_8bit_adam:
428
+ try:
429
+ import bitsandbytes as bnb
430
+ except ImportError:
431
+ raise ImportError(
432
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
433
+ )
434
+
435
+ optimizer_class = bnb.optim.AdamW8bit
436
+ else:
437
+ optimizer_class = torch.optim.AdamW
438
+
439
+ learning_rate = args.learning_rate
440
+
441
+ params_to_optimize = [
442
+ {'params': temporal_params, "lr": learning_rate},
443
+ ]
444
+
445
+ optimizer = optimizer_class(
446
+ params_to_optimize,
447
+ lr=args.learning_rate,
448
+ betas=(args.adam_beta1, args.adam_beta2),
449
+ weight_decay=args.adam_weight_decay,
450
+ eps=args.adam_epsilon,
451
+ )
452
+
453
+ if optimizer_resume_path and not args.disable_optimizer_restore:
454
+ logger.info("Restoring the optimizer.")
455
+ try:
456
+
457
+ old_optimizer_state_dict = torch.load(optimizer_resume_path)
458
+
459
+ # Extract only the state
460
+ old_state = old_optimizer_state_dict['state']
461
+
462
+ # Set the state of the new optimizer
463
+ optimizer.load_state_dict({'state': old_state, 'param_groups': optimizer.param_groups})
464
+
465
+ del old_optimizer_state_dict
466
+ del old_state
467
+
468
+ torch.cuda.empty_cache()
469
+ torch.cuda.synchronize()
470
+ gc.collect()
471
+
472
+ logger.info("Restored the optimizer state.")
473
+
474
+ except Exception:
475
+ logger.error("Failed to restore the optimizer...", exc_info=True)
476
+ traceback.print_exc()
477
+ raise
478
+
479
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
480
+
481
+ def compute_snr(timesteps):
482
+ """
483
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
484
+ """
485
+ alphas_cumprod = noise_scheduler.alphas_cumprod
486
+ sqrt_alphas_cumprod = alphas_cumprod ** 0.5
487
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
488
+
489
+ # Expand the tensors.
490
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
491
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
492
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
493
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
494
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
495
+
496
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
497
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
498
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
499
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
500
+
501
+ # Compute SNR.
502
+ snr = (alpha / sigma) ** 2
503
+ return snr
504
+
505
+ device = torch.device('cuda')
506
+
507
+ image_transforms = transforms.Compose(
508
+ [
509
+ transforms.ToTensor(),
510
+ transforms.Normalize([0.5], [0.5]),
511
+ ]
512
+ )
513
+
514
+ def image_to_tensor(img):
515
+ with torch.no_grad():
516
+
517
+ if img.mode != "RGB":
518
+ img = img.convert("RGB")
519
+
520
+ image = image_transforms(img).to(accelerator.device)
521
+
522
+ if image.shape[0] == 1:
523
+ image = image.repeat(3, 1, 1)
524
+
525
+ if image.shape[0] > 3:
526
+ image = image[:3, :, :]
527
+
528
+ return image
529
+
530
+ def make_sample(sample):
531
+
532
+ nonlocal unet_config
533
+ nonlocal unet_add_embedding
534
+
535
+ images = [Image.open(img) for img in sample['image_fps']]
536
+
537
+ og_size = images[0].size
538
+
539
+ for i, im in enumerate(images):
540
+ if im.mode != "RGB":
541
+ images[i] = im.convert("RGB")
542
+
543
+ aspect_ratio_map = res_to_aspect_map[args.resolution]
544
+
545
+ required_size = tuple(aspect_ratio_map[args.aspect_ratio])
546
+
547
+ if required_size != og_size:
548
+
549
+ def resize_image(x):
550
+ img_size = x.size
551
+ if img_size == required_size:
552
+ return x.resize(required_size, Image.LANCZOS)
553
+
554
+ return scale_aspect_fill(x, required_size[0], required_size[1])
555
+
556
+ with ThreadPoolExecutor(max_workers=len(images)) as executor:
557
+ images = list(executor.map(resize_image, images))
558
+
559
+ frames = torch.stack([image_to_tensor(x) for x in images])
560
+
561
+ l, u, *_ = get_crop_coordinates(og_size, images[0].size)
562
+ crop_coords = (l, u)
563
+
564
+ additional_time_ids = add_time_ids(
565
+ unet_config,
566
+ unet_add_embedding,
567
+ text_encoder_2,
568
+ og_size,
569
+ crop_coords,
570
+ (required_size[0], required_size[1]),
571
+ dtype=torch.float32
572
+ ).to(device)
573
+
574
+ input_ids_0 = tokenizer(
575
+ sample['prompt'],
576
+ padding="do_not_pad",
577
+ truncation=True,
578
+ max_length=tokenizer.model_max_length,
579
+ ).input_ids
580
+
581
+ input_ids_1 = tokenizer_2(
582
+ sample['prompt'],
583
+ padding="do_not_pad",
584
+ truncation=True,
585
+ max_length=tokenizer.model_max_length,
586
+ ).input_ids
587
+
588
+ return {
589
+ "frames": frames,
590
+ "input_ids_0": input_ids_0,
591
+ "input_ids_1": input_ids_1,
592
+ "additional_time_ids": additional_time_ids,
593
+ }
594
+
595
+ def collate_fn(examples: list) -> dict:
596
+
597
+ # Two Text encoders
598
+ # First Text Encoder -> Penultimate Layer
599
+ # Second Text Encoder -> Pooled Layer
600
+
601
+ input_ids_0 = [example['input_ids_0'] for example in examples]
602
+ input_ids_0 = tokenizer.pad({"input_ids": input_ids_0}, padding="max_length",
603
+ max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
604
+
605
+ prompt_embeds_0 = text_encoder(
606
+ input_ids_0.to(device),
607
+ output_hidden_states=True,
608
+ )
609
+
610
+ # we take penultimate embeddings from the first text encoder
611
+ prompt_embeds_0 = prompt_embeds_0.hidden_states[-2]
612
+
613
+ input_ids_1 = [example['input_ids_1'] for example in examples]
614
+ input_ids_1 = tokenizer_2.pad({"input_ids": input_ids_1}, padding="max_length",
615
+ max_length=tokenizer.model_max_length, return_tensors="pt").input_ids
616
+
617
+ # We are only ALWAYS interested in the pooled output of the final text encoder
618
+ prompt_embeds = text_encoder_2(
619
+ input_ids_1.to(device),
620
+ output_hidden_states=True
621
+ )
622
+
623
+ pooled_prompt_embeds = prompt_embeds[0]
624
+ prompt_embeds_1 = prompt_embeds.hidden_states[-2]
625
+
626
+ prompt_embeds = torch.concat([prompt_embeds_0, prompt_embeds_1], dim=-1)
627
+
628
+ *_, h, w = examples[0]['frames'].shape
629
+
630
+ return {
631
+ "frames": torch.stack([x['frames'] for x in examples]).to(memory_format=torch.contiguous_format).float(),
632
+ "prompt_embeds": prompt_embeds.to(memory_format=torch.contiguous_format).float(),
633
+ "pooled_prompt_embeds": pooled_prompt_embeds,
634
+ "additional_time_ids": torch.stack([x['additional_time_ids'] for x in examples]),
635
+ }
636
+
637
+ # Region - Dataloaders
638
+ dataset = HotshotXLDataset(args.data_dir, make_sample)
639
+ dataloader = DataLoader(dataset, args.train_batch_size, shuffle=True, collate_fn=collate_fn)
640
+
641
+ # Scheduler and math around the number of training steps.
642
+ overrode_max_train_steps = False
643
+ num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
644
+
645
+ if args.max_train_steps is None:
646
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
647
+ overrode_max_train_steps = True
648
+
649
+ lr_scheduler = get_scheduler(
650
+ args.lr_scheduler,
651
+ optimizer=optimizer,
652
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
653
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
654
+ )
655
+
656
+ unet, optimizer, lr_scheduler, dataloader = accelerator.prepare(
657
+ unet, optimizer, lr_scheduler, dataloader
658
+ )
659
+
660
+ def to_images(video_frames: torch.Tensor):
661
+ import torchvision.transforms as transforms
662
+ to_pil = transforms.ToPILImage()
663
+ video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
664
+ bsz = video_frames.shape[0]
665
+ images = []
666
+ for i in range(bsz):
667
+ video = video_frames[i]
668
+ for j in range(video.shape[0]):
669
+ image = to_pil(video[j])
670
+ images.append(image)
671
+ return images
672
+
673
+ def to_video_frames(images: list) -> np.ndarray:
674
+ x = np.stack([np.asarray(img) for img in images])
675
+ return np.transpose(x, (0, 3, 1, 2))
676
+
677
+ def run_validation(step=0, node_index=0):
678
+
679
+ nonlocal global_step
680
+ nonlocal accelerator
681
+
682
+ if args.test_prompts:
683
+ prompts = args.test_prompts.split("|")
684
+ else:
685
+ prompts = [
686
+ "a woman is lifting weights in a gym",
687
+ "a group of people are dancing at a party",
688
+ "a teddy bear doing the front crawl"
689
+ ]
690
+
691
+ torch.cuda.empty_cache()
692
+ gc.collect()
693
+
694
+ logger.info(f"Running inference to test model at {step} steps")
695
+ with torch.no_grad():
696
+
697
+ pipe = HotshotXLPipeline.from_pretrained(
698
+ args.pretrained_model_name_or_path,
699
+ unet=accelerator.unwrap_model(unet),
700
+ text_encoder=text_encoder,
701
+ text_encoder_2=text_encoder_2,
702
+ vae=vae,
703
+ )
704
+
705
+ videos = []
706
+
707
+ aspect_ratio_map = res_to_aspect_map[args.resolution]
708
+ w, h = aspect_ratio_map[args.aspect_ratio]
709
+
710
+ for prompt in prompts:
711
+ video = pipe(prompt,
712
+ width=w,
713
+ height=h,
714
+ original_size=(1920, 1080), # todo - pass in as args?
715
+ target_size=(args.resolution, args.resolution),
716
+ num_inference_steps=30,
717
+ video_length=8,
718
+ output_type="tensor",
719
+ generator=torch.Generator().manual_seed(111)).videos
720
+
721
+ videos.append(to_images(video))
722
+
723
+ for tracker in accelerator.trackers:
724
+
725
+ if tracker.name == "wandb":
726
+ tracker.log(
727
+ {
728
+ "validation": [wandb.Video(to_video_frames(video), fps=8, format='mp4') for video in
729
+ videos],
730
+ }, step=global_step
731
+ )
732
+
733
+ del pipe
734
+
735
+ return
736
+
737
+ # Move text_encode and vae to gpu.
738
+ vae.to(accelerator.device, dtype=torch.bfloat16 if args.vae_b16 else torch.float32)
739
+ text_encoder.to(accelerator.device)
740
+ text_encoder_2.to(accelerator.device)
741
+
742
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
743
+
744
+ num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
745
+ if overrode_max_train_steps:
746
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
747
+ # Afterward we recalculate our number of training epochs
748
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
749
+
750
+ # We need to initialize the trackers we use, and also store our configuration.
751
+ # The trackers initialize automatically on the main process.
752
+
753
+ if accelerator.is_main_process:
754
+ accelerator.init_trackers(args.project_name)
755
+
756
+ def bar(prg):
757
+ br = '|' + '█' * prg + ' ' * (25 - prg) + '|'
758
+ return br
759
+
760
+ # Train!
761
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
762
+
763
+ if accelerator.is_main_process:
764
+ logger.info("***** Running training *****")
765
+ logger.info(f" Num examples = {len(dataset)}")
766
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
767
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
768
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
769
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
770
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
771
+
772
+ # Only show the progress bar once on each machine.
773
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
774
+
775
+ latents_scaler = vae.config.scaling_factor
776
+
777
+ def save_checkpoint():
778
+ save_dir = Path(args.output_dir)
779
+ save_dir = str(save_dir)
780
+ save_dir = save_dir.replace(" ", "_")
781
+ if not os.path.exists(save_dir):
782
+ os.makedirs(save_dir, exist_ok=True)
783
+ accelerator.save_state(save_dir)
784
+
785
+ def save_checkpoint_and_wait():
786
+ if accelerator.is_main_process:
787
+ save_checkpoint()
788
+ accelerator.wait_for_everyone()
789
+
790
+ def save_model_and_wait():
791
+ if accelerator.is_main_process:
792
+ HotshotXLPipeline.from_pretrained(
793
+ args.pretrained_model_name_or_path,
794
+ unet=accelerator.unwrap_model(unet),
795
+ text_encoder=text_encoder,
796
+ text_encoder_2=text_encoder_2,
797
+ vae=vae,
798
+ ).save_pretrained(args.output_dir, safe_serialization=True)
799
+ accelerator.wait_for_everyone()
800
+
801
+ def compute_loss_from_batch(batch: dict):
802
+ frames = batch["frames"]
803
+ bsz, number_of_frames, c, w, h = frames.shape
804
+
805
+ # Convert images to latent space
806
+ with torch.no_grad():
807
+
808
+ if args.max_vae_encode:
809
+ latents = []
810
+
811
+ x = rearrange(frames, "bs nf c h w -> (bs nf) c h w")
812
+
813
+ for latent_index in range(0, x.shape[0], args.max_vae_encode):
814
+ sample = x[latent_index: latent_index + args.max_vae_encode]
815
+
816
+ latent = vae.encode(sample.to(dtype=vae.dtype)).latent_dist.sample().float()
817
+ if len(latent.shape) == 3:
818
+ latent = latent.unsqueeze(0)
819
+
820
+ latents.append(latent)
821
+ torch.cuda.empty_cache()
822
+
823
+ latents = torch.cat(latents, dim=0)
824
+ else:
825
+
826
+ # convert the latents from 5d -> 4d, so we can run it though the vae encoder
827
+ x = rearrange(frames, "bs nf c h w -> (bs nf) c h w")
828
+
829
+ del frames
830
+
831
+ torch.cuda.empty_cache()
832
+
833
+ latents = vae.encode(x.to(dtype=vae.dtype)).latent_dist.sample().float()
834
+
835
+ if args.latent_nan_checking and torch.any(torch.isnan(latents)):
836
+ accelerator.print("NaN found in latents, replacing with zeros")
837
+ latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
838
+
839
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", b=bsz)
840
+
841
+ torch.cuda.empty_cache()
842
+
843
+ noise = torch.randn_like(latents, device=latents.device)
844
+
845
+ if args.noise_offset:
846
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
847
+ noise += args.noise_offset * torch.randn(
848
+ (latents.shape[0], latents.shape[1], 1, 1, 1), device=latents.device
849
+ )
850
+
851
+ # Sample a random timestep for each image
852
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
853
+ timesteps = timesteps.long() # .repeat_interleave(number_of_frames)
854
+ latents = latents * latents_scaler
855
+
856
+ # Add noise to the latents according to the noise magnitude at each timestep
857
+ # (this is the forward diffusion process)
858
+
859
+ prompt_embeds = batch['prompt_embeds']
860
+ add_text_embeds = batch['pooled_prompt_embeds']
861
+
862
+ additional_time_ids = batch['additional_time_ids'] # .repeat_interleave(number_of_frames, dim=0)
863
+
864
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": additional_time_ids}
865
+
866
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
867
+
868
+ if noise_scheduler.config.prediction_type == "epsilon":
869
+ target = noise
870
+ elif noise_scheduler.config.prediction_type == "v_prediction":
871
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
872
+ else:
873
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
874
+
875
+ noisy_latents.requires_grad = True
876
+
877
+ model_pred = unet(noisy_latents,
878
+ timesteps,
879
+ cross_attention_kwargs=None,
880
+ encoder_hidden_states=prompt_embeds,
881
+ added_cond_kwargs=added_cond_kwargs,
882
+ return_dict=False,
883
+ )[0]
884
+
885
+ if args.snr_gamma:
886
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
887
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
888
+ # This is discussed in Section 4.2 of the same paper.
889
+ snr = compute_snr(timesteps)
890
+ mse_loss_weights = (
891
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
892
+ )
893
+ # We first calculate the original loss. Then we mean over the non-batch dimensions and
894
+ # rebalance the sample-wise losses with their respective loss weights.
895
+ # Finally, we take the mean of the rebalanced loss.
896
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
897
+
898
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
899
+ return loss.mean()
900
+ else:
901
+ return F.mse_loss(model_pred.float(), target.float(), reduction='mean')
902
+
903
+ def process_batch(batch: dict):
904
+ nonlocal global_step
905
+ nonlocal next_save_iter
906
+
907
+ now = time.time()
908
+
909
+ with accelerator.accumulate(unet):
910
+
911
+ logging_data = {}
912
+ if global_step == 0:
913
+ # print(f"Running initial validation at step")
914
+ if accelerator.is_main_process and args.run_validation_at_start:
915
+ run_validation(step=global_step, node_index=accelerator.process_index // 8)
916
+ accelerator.wait_for_everyone()
917
+
918
+ loss = compute_loss_from_batch(batch)
919
+
920
+ accelerator.backward(loss)
921
+
922
+ if accelerator.sync_gradients:
923
+ accelerator.clip_grad_norm_(temporal_params, args.max_grad_norm)
924
+
925
+ optimizer.step()
926
+
927
+ lr_scheduler.step()
928
+ optimizer.zero_grad()
929
+
930
+ # Checks if the accelerator has performed an optimization step behind the scenes
931
+ if accelerator.sync_gradients:
932
+ progress_bar.update(1)
933
+ global_step += 1
934
+
935
+ fll = round((global_step * 100) / args.max_train_steps)
936
+ fll = round(fll / 4)
937
+ pr = bar(fll)
938
+
939
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "loss_time": (time.time() - now)}
940
+
941
+ if args.validate_every_steps is not None and global_step > min_steps_before_validation and global_step % args.validate_every_steps == 0:
942
+ if accelerator.is_main_process:
943
+ run_validation(step=global_step, node_index=accelerator.process_index // 8)
944
+
945
+ accelerator.wait_for_everyone()
946
+
947
+ for key, val in logging_data.items():
948
+ logs[key] = val
949
+
950
+ progress_bar.set_postfix(**logs)
951
+ progress_bar.set_description_str("Progress:" + pr)
952
+ accelerator.log(logs, step=global_step)
953
+
954
+ if accelerator.is_main_process \
955
+ and next_save_iter is not None \
956
+ and global_step < args.max_train_steps \
957
+ and global_step + 1 == next_save_iter:
958
+ save_checkpoint()
959
+
960
+ torch.cuda.empty_cache()
961
+ gc.collect()
962
+
963
+ next_save_iter += args.save_n_steps
964
+
965
+ for epoch in range(args.num_train_epochs):
966
+ unet.train()
967
+
968
+ for step, batch in enumerate(dataloader):
969
+ process_batch(batch)
970
+
971
+ if global_step >= args.max_train_steps:
972
+ break
973
+
974
+ if global_step >= args.max_train_steps:
975
+ logger.info("Max train steps reached. Breaking while loop")
976
+ break
977
+
978
+ accelerator.wait_for_everyone()
979
+
980
+ save_model_and_wait()
981
+
982
+ accelerator.end_training()
983
+
984
+
985
+ if __name__ == "__main__":
986
+ mp.set_start_method('spawn')
987
+ main()
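Note (not part of the commit): HotshotXLDataset above expects --data_dir to contain one folder per sample, each holding frame_*.jpg files and a prompt.txt. A minimal sketch of preparing one such sample follows; the directory name, frame size and prompt are placeholders chosen to match the script defaults (resolution 512, aspect ratio "1.75", 8 frames).

import os
from PIL import Image

# Build one sample in the layout HotshotXLDataset expects (placeholder data).
sample_dir = os.path.join("training_samples", "sample_001")
os.makedirs(sample_dir, exist_ok=True)

# 672x384 is the 512-resolution bucket for aspect ratio "1.75" in res_to_aspect_map.
for i in range(8):
    Image.new("RGB", (672, 384), color=(128, 128, 128)).save(
        os.path.join(sample_dir, f"frame_{i}.jpg"))

with open(os.path.join(sample_dir, "prompt.txt"), "w") as f:
    f.write("a placeholder prompt describing the clip")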
setup.py ADDED
@@ -0,0 +1,15 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name='hotshot_xl',
5
+ version='1.0',
6
+ packages=find_packages(include=['hotshot_xl*',]),
7
+ author="Natural Synthetics Inc",
8
+ install_requires=[
9
+ "torch>=2.0.1",
10
+ "torchvision>=0.15.2",
11
+ "diffusers>=0.21.4",
12
+ "transformers>=4.33.3",
13
+ "einops"
14
+ ],
15
+ )
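Note (not part of the commit): with this setup.py the package is typically installed in editable mode (pip install -e .) so that fine_tune.py can resolve the hotshot_xl imports. A quick sanity check, assuming the distribution name above:

from importlib.metadata import version
print(version("hotshot_xl"))  # expected to print "1.0" per the setup() call above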
utils.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Union
16
+ from io import BytesIO
17
+ import PIL
18
+ from PIL import ImageSequence, Image
19
+ import requests
20
+ import os
21
+ import numpy as np
22
+ import imageio
23
+
24
+
25
+ def get_image(img_path) -> PIL.Image.Image:
26
+ if img_path.startswith("http"):
27
+ return PIL.Image.open(requests.get(img_path, stream=True).raw)
28
+ if os.path.exists(img_path):
29
+ return Image.open(img_path)
30
+ raise FileNotFoundError(f"File not found: {img_path}")
31
+
32
+ def images_to_gif_bytes(images: List, duration: int = 1000) -> bytes:
33
+ with BytesIO() as output_buffer:
34
+ # Save the first image
35
+ images[0].save(output_buffer,
36
+ format='GIF',
37
+ save_all=True,
38
+ append_images=images[1:],
39
+ duration=duration,
40
+ loop=0) # 0 means the GIF will loop indefinitely
41
+
42
+ # Get the byte array from the buffer
43
+ gif_bytes = output_buffer.getvalue()
44
+
45
+ return gif_bytes
46
+
47
+ def save_as_gif(images: List, file_path: str, duration: int = 1000):
48
+ with open(file_path, "wb") as f:
49
+ f.write(images_to_gif_bytes(images, duration))
50
+
51
+ def images_to_mp4_bytes(images: List[Image.Image], duration: int = 1000) -> bytes:
52
+ with BytesIO() as output_buffer:
53
+ with imageio.get_writer(output_buffer, format='mp4', fps=1/(duration/1000)) as writer:
54
+ for img in images:
55
+ writer.append_data(np.array(img))
56
+ mp4_bytes = output_buffer.getvalue()
57
+
58
+ return mp4_bytes
59
+
60
+ def save_as_mp4(images: List[Image.Image], file_path: str, duration: int = 1000):
61
+ with open(file_path, "wb") as f:
62
+ f.write(images_to_mp4_bytes(images, duration))
63
+
64
+ def scale_aspect_fill(img, new_width, new_height):
65
+ new_width = int(new_width)
66
+ new_height = int(new_height)
67
+
68
+ original_width, original_height = img.size
69
+ ratio_w = float(new_width) / original_width
70
+ ratio_h = float(new_height) / original_height
71
+
72
+ if ratio_w > ratio_h:
73
+ # It must be fixed by width
74
+ resize_width = new_width
75
+ resize_height = round(original_height * ratio_w)
76
+ else:
77
+ # Fixed by height
78
+ resize_width = round(original_width * ratio_h)
79
+ resize_height = new_height
80
+
81
+ img_resized = img.resize((resize_width, resize_height), Image.LANCZOS)
82
+
83
+ # Calculate cropping boundaries and do crop
84
+ left = (resize_width - new_width) / 2
85
+ top = (resize_height - new_height) / 2
86
+ right = (resize_width + new_width) / 2
87
+ bottom = (resize_height + new_height) / 2
88
+
89
+ img_cropped = img_resized.crop((left, top, right, bottom))
90
+
91
+ return img_cropped
92
+
93
+ def extract_gif_frames_from_midpoint(image: Union[str, PIL.Image.Image], fps: int=8, target_duration: int=1000) -> list:
94
+ # Load the GIF
95
+ image = get_image(image) if type(image) is str else image
96
+
97
+ frames = []
98
+
99
+ estimated_frame_time = None
100
+
101
+ # some gifs contain the duration - others don't
102
+ # so if there is a duration we will grab it otherwise we will fall back
103
+
104
+ for frame in ImageSequence.Iterator(image):
105
+
106
+ frames.append(frame.copy())
107
+ if 'duration' in frame.info:
108
+ frame_info_duration = frame.info['duration']
109
+ if frame_info_duration > 0:
110
+ estimated_frame_time = frame_info_duration
111
+
112
+ if estimated_frame_time is None:
113
+ if len(frames) <= 16:
114
+ # assume it's 8fps
115
+ estimated_frame_time = 1000 // 8
116
+ else:
117
+ # assume it's 15 fps
118
+ estimated_frame_time = 70
119
+
120
+ if len(frames) < fps:
121
+ raise ValueError(f"fps of {fps} is too small for this gif as it only has {len(frames)} frames.")
122
+
123
+ skip = len(frames) // fps
124
+ upper_bound_index = len(frames) - 1
125
+
126
+ best_indices = [x for x in range(0, len(frames), skip)][:fps]
127
+ offset = int(upper_bound_index - best_indices[-1]) // 2
128
+ best_indices = [x + offset for x in best_indices]
129
+ best_duration = (best_indices[-1] - best_indices[0]) * estimated_frame_time
130
+
131
+ while True:
132
+
133
+ skip -= 1
134
+
135
+ if skip == 0:
136
+ break
137
+
138
+ indices = [x for x in range(0, len(frames), skip)][:fps]
139
+
140
+ # center the indices, so we sample the middle of the gif...
141
+ offset = int(upper_bound_index - indices[-1]) // 2
142
+ if offset == 0:
143
+ # can't shift
144
+ break
145
+ indices = [x + offset for x in indices]
146
+
147
+ # is the new duration closer to the target than last guess?
148
+ duration = (indices[-1] - indices[0]) * estimated_frame_time
149
+ if abs(duration - target_duration) > abs(best_duration - target_duration):
150
+ break
151
+
152
+ best_indices = indices
153
+ best_duration = duration
154
+
155
+ return [frames[index] for index in best_indices]
156
+
157
+ def get_crop_coordinates(old_size: tuple, new_size: tuple) -> tuple:
158
+ """
159
+ Calculate the crop coordinates after scaling an image to fit a new size.
160
+
161
+ :param old_size: tuple of the form (width, height) representing the original size of the image.
162
+ :param new_size: tuple of the form (width, height) representing the desired size after scaling.
163
+ :return: tuple of the form (left, upper, right, lower) giving the crop coordinates in the resized image's pixel space.
164
+ """
165
+ # Check if the input tuples have the right form (width, height)
166
+ if not (isinstance(old_size, tuple) and isinstance(new_size, tuple) and
167
+ len(old_size) == 2 and len(new_size) == 2):
168
+ raise ValueError("old_size and new_size should be tuples of the form (width, height)")
169
+
170
+ # Extract the width and height from the old and new sizes
171
+ old_width, old_height = old_size
172
+ new_width, new_height = new_size
173
+
174
+ # Calculate the ratios for width and height
175
+ ratio_w = float(new_width) / old_width
176
+ ratio_h = float(new_height) / old_height
177
+
178
+ # Determine which dimension is fixed (width or height)
179
+ if ratio_w > ratio_h:
180
+ # It must be fixed by width
181
+ resize_width = new_width
182
+ resize_height = round(old_height * ratio_w)
183
+ else:
184
+ # Fixed by height
185
+ resize_width = round(old_width * ratio_h)
186
+ resize_height = new_height
187
+
188
+ # Calculate cropping boundaries in the resized image space
189
+ left = (resize_width - new_width) / 2
190
+ upper = (resize_height - new_height) / 2
191
+ right = (resize_width + new_width) / 2
192
+ lower = (resize_height + new_height) / 2
193
+
194
+ # Normalize the cropping coordinates
195
+
196
+ # Return the normalized coordinates as a tuple
197
+ return (left, upper, right, lower)
198
+
199
+ aspect_ratio_to_1024_map = {
200
+ "0.42": [640, 1536],
201
+ "0.57": [768, 1344],
202
+ "0.68": [832, 1216],
203
+ "1.00": [1024, 1024],
204
+ "1.46": [1216, 832],
205
+ "1.75": [1344, 768],
206
+ "2.40": [1536, 640]
207
+ }
208
+
209
+ res_to_aspect_map = {
210
+ 1024: aspect_ratio_to_1024_map,
211
+ 512: {key: [value[0] // 2, value[1] // 2] for key, value in aspect_ratio_to_1024_map.items()},
212
+ }
213
+
214
+ def best_aspect_ratio(aspect_ratio: float, resolution: int):
215
+
216
+ map = res_to_aspect_map[resolution]
217
+
218
+ d = 99999999
219
+ res = None
220
+ for key, value in map.items():
221
+ ar = value[0] / value[1]
222
+ diff = abs(aspect_ratio - ar)
223
+ if diff < d:
224
+ d = diff
225
+ res = value
226
+
227
+ ar = res[0] / res[1]
228
+ return f"{ar:.2f}", res
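Note (not part of the commit): a minimal usage sketch for the aspect-ratio helpers above, assuming this module is importable as hotshot_xl.utils (the path fine_tune.py imports it from); the 1280x720 input image is a placeholder.

from PIL import Image
from hotshot_xl.utils import best_aspect_ratio, scale_aspect_fill

img = Image.new("RGB", (1280, 720))               # stand-in for a real frame
ratio_key, (w, h) = best_aspect_ratio(1280 / 720, resolution=512)
print(ratio_key, w, h)                            # -> 1.75 672 384
resized = scale_aspect_fill(img, w, h)            # scale, then center-crop to exactly (w, h)
assert resized.size == (w, h)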