Commit 8a943d8
Parent(s): d02956d
Upload 3 files

Files changed:
- fine_tune.py +987 -0
- setup.py +15 -0
- utils.py +228 -0
fine_tune.py
ADDED
@@ -0,0 +1,987 @@
# Copyright 2023 Natural Synthetics Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import math
import os
import traceback
from pathlib import Path
import time
import torch
import torch.utils.checkpoint
import torch.multiprocessing as mp
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers import DDPMScheduler
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
import torch.nn.functional as F
import gc
from typing import Callable
from PIL import Image
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from hotshot_xl.models.unet import UNet3DConditionModel
from hotshot_xl.pipelines.hotshot_xl_pipeline import HotshotXLPipeline
from hotshot_xl.utils import get_crop_coordinates, res_to_aspect_map, scale_aspect_fill
from einops import rearrange
from torch.utils.data import Dataset, DataLoader
from datetime import timedelta
from accelerate.utils.dataclasses import InitProcessGroupKwargs
from diffusers.utils import is_wandb_available

if is_wandb_available():
    import wandb

logger = get_logger(__file__)


class HotshotXLDataset(Dataset):

    def __init__(self, directory: str, make_sample_fn: Callable):
        """
        Training data folder needs to look like:

        + training_samples
        --- + sample_001
        ------- + frame_0.jpg
        ------- + frame_1.jpg
        ------- + ...
        ------- + frame_n.jpg
        ------- + prompt.txt
        --- + sample_002
        ------- + frame_0.jpg
        ------- + frame_1.jpg
        ------- + ...
        ------- + frame_n.jpg
        ------- + prompt.txt

        Args:
            directory: base directory of the training samples
            make_sample_fn: a delegate call to load the images and prep the sample for batching
        """
        samples_dir = [os.path.join(directory, p) for p in os.listdir(directory)]
        samples_dir = [p for p in samples_dir if os.path.isdir(p)]
        samples = []

        for d in samples_dir:
            file_paths = [os.path.join(d, p) for p in os.listdir(d)]
            image_fps = [f for f in file_paths if os.path.splitext(f)[1] in {".png", ".jpg"}]
            with open(os.path.join(d, "prompt.txt")) as f:
                prompt = f.read().strip()

            samples.append({
                "image_fps": image_fps,
                "prompt": prompt
            })

        self.samples = samples
        self.length = len(samples)
        self.make_sample_fn = make_sample_fn

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return self.make_sample_fn(
            self.samples[index]
        )


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="hotshotco/Hotshot-XL",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--unet_resume_path",
        type=str,
        default=None,
        help="Path to a previously saved checkpoint directory to resume the UNet (and optimizer) from.",
    )

    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Path to data to train.",
    )

    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )

    parser.add_argument("--run_validation_at_start", action="store_true")
    parser.add_argument("--max_vae_encode", type=int, default=None)
    parser.add_argument("--vae_b16", action="store_true")
    parser.add_argument("--disable_optimizer_restore", action="store_true")

    parser.add_argument(
        "--latent_nan_checking",
        action="store_true",
        help="Check if latents contain NaNs - important if the VAE runs in fp16.",
    )
    parser.add_argument(
        "--test_prompts",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--project_name",
        type=str,
        default="fine-tune-hotshot-xl",
        help="The name of the project.",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="run-01",
        help="The name of the run.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--noise_offset", type=float, default=0.05, help="The scale of noise offset.")
    parser.add_argument("--seed", type=int, default=111, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images: all the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--aspect_ratio",
        type=str,
        default="1.75",
        choices=list(res_to_aspect_map[512].keys()),
        help="Aspect ratio to train at.",
    )

    parser.add_argument("--xformers", action="store_true")

    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )

    parser.add_argument("--num_train_epochs", type=int, default=1)

    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,  # was 9999999, which made the epoch-based fallback in main() unreachable
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.",
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )

    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )

    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )

    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16)."
            " bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )

    parser.add_argument(
        "--validate_every_steps",
        type=int,
        default=100,
        help="Run validation every N global steps.",
    )

    parser.add_argument(
        "--save_n_steps",
        type=int,
        default=100,
        help="Save the model every n global_steps.",
    )

    parser.add_argument(
        "--save_starting_step",
        type=int,
        default=100,
        help="The step from which to start saving intermediary checkpoints.",
    )

    parser.add_argument(
        "--nccl_timeout",
        type=int,
        help="NCCL timeout in seconds.",
        default=3600
    )

    parser.add_argument(
        "--snr_gamma",
        type=float,
        default=None,  # was action="store_true", which reduced the Min-SNR gamma to a boolean
        help="SNR weighting gamma for Min-SNR loss rebalancing (the paper recommends 5.0); disabled when unset.",
    )

    args = parser.parse_args()

    return args
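
# A hypothetical launch command for reference (the flags match parse_args above; the
# dataset path and accelerate setup are assumptions, not part of this commit):
#
#   accelerate launch fine_tune.py \
#       --data_dir ./training_samples \
#       --train_batch_size 2 \
#       --gradient_checkpointing \
#       --mixed_precision bf16 \
#       --snr_gamma 5.0 \
#       --report_to wandb
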

def add_time_ids(
        unet_config,
        unet_add_embedding,
        text_encoder_2: CLIPTextModelWithProjection,
        original_size: tuple,
        crops_coords_top_left: tuple,
        target_size: tuple,
        dtype: torch.dtype):
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    passed_add_embed_dim = (
        unet_config.addition_time_embed_dim * len(add_time_ids) + text_encoder_2.config.projection_dim
    )
    expected_add_embed_dim = unet_add_embedding.linear_1.in_features

    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
            f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. "
            f"Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids
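
# Worked example (dimensions assumed from typical SDXL-family configs, for illustration):
# with addition_time_embed_dim=256 and a text_encoder_2 projection_dim of 1280, the six
# conditioning values original_size + crops_coords_top_left + target_size, e.g.
# (1920, 1080) + (0, 0) + (672, 384), embed into 256 * 6 + 1280 = 2816 features, which
# must equal unet_add_embedding.linear_1.in_features for the check above to pass.
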

def main():
    global_step = 0
    min_steps_before_validation = 0

    args = parse_args()

    next_save_iter = args.save_starting_step

    if args.save_starting_step < 1:
        next_save_iter = None

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        # timedelta's first positional argument is days, so pass the timeout explicitly in seconds
        kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=args.nccl_timeout))]
    )

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        nonlocal global_step

        for model in models:
            if isinstance(model, type(accelerator.unwrap_model(unet))):
                model.save_pretrained(os.path.join(output_dir, 'unet'))
            # make sure to pop the weight so that the corresponding model is not saved again
            weights.pop()

    accelerator.register_save_state_pre_hook(save_model_hook)

    set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_local_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizers
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")

    # Load models and create wrapper for stable diffusion
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path,
                                                                 subfolder="text_encoder_2")

    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")

    optimizer_resume_path = None

    if args.unet_resume_path:
        optimizer_fp = os.path.join(args.unet_resume_path, "optimizer.bin")

        if os.path.exists(optimizer_fp):
            optimizer_resume_path = optimizer_fp

        unet = UNet3DConditionModel.from_pretrained(args.unet_resume_path,
                                                    subfolder="unet",
                                                    low_cpu_mem_usage=False,
                                                    device_map=None)

    else:
        unet = UNet3DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

    if args.xformers:
        vae.set_use_memory_efficient_attention_xformers(True, None)
        unet.set_use_memory_efficient_attention_xformers(True, None)

    unet_config = unet.config
    unet_add_embedding = unet.add_embedding

    unet.requires_grad_(False)

    temporal_params = unet.temporal_parameters()

    for p in temporal_params:
        p.requires_grad_(True)

    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    text_encoder_2.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Use 8-bit Adam for lower memory usage
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    learning_rate = args.learning_rate

    params_to_optimize = [
        {'params': temporal_params, "lr": learning_rate},
    ]

    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    if optimizer_resume_path and not args.disable_optimizer_restore:
        logger.info("Restoring the optimizer.")
        try:
            old_optimizer_state_dict = torch.load(optimizer_resume_path)

            # Extract only the state
            old_state = old_optimizer_state_dict['state']

            # Set the state of the new optimizer
            optimizer.load_state_dict({'state': old_state, 'param_groups': optimizer.param_groups})

            del old_optimizer_state_dict
            del old_state

            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            gc.collect()

            logger.info("Restored the optimizer ok")

        except Exception:
            logger.error("Failed to restore the optimizer...", exc_info=True)
            traceback.print_exc()
            raise

    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    def compute_snr(timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod ** 0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr
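
    # Since alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]),
    # compute_snr returns snr(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]).
    # The Min-SNR weight applied later is min(snr(t), gamma) / snr(t): 1 for very
    # noisy (low-SNR) timesteps, and gamma / snr(t) < 1 for nearly clean ones.
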

    device = torch.device('cuda')

    image_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    def image_to_tensor(img):
        with torch.no_grad():

            if img.mode != "RGB":
                img = img.convert("RGB")

            image = image_transforms(img).to(accelerator.device)

            if image.shape[0] == 1:
                image = image.repeat(3, 1, 1)

            if image.shape[0] > 3:
                image = image[:3, :, :]

            return image

    def make_sample(sample):

        nonlocal unet_config
        nonlocal unet_add_embedding

        images = [Image.open(img) for img in sample['image_fps']]

        og_size = images[0].size

        for i, im in enumerate(images):
            if im.mode != "RGB":
                images[i] = im.convert("RGB")

        aspect_ratio_map = res_to_aspect_map[args.resolution]

        required_size = tuple(aspect_ratio_map[args.aspect_ratio])

        if required_size != og_size:

            def resize_image(x):
                img_size = x.size
                if img_size == required_size:
                    return x.resize(required_size, Image.LANCZOS)

                return scale_aspect_fill(x, required_size[0], required_size[1])

            with ThreadPoolExecutor(max_workers=len(images)) as executor:
                images = list(executor.map(resize_image, images))

        frames = torch.stack([image_to_tensor(x) for x in images])

        l, u, *_ = get_crop_coordinates(og_size, images[0].size)
        crop_coords = (l, u)

        additional_time_ids = add_time_ids(
            unet_config,
            unet_add_embedding,
            text_encoder_2,
            og_size,
            crop_coords,
            (required_size[0], required_size[1]),
            dtype=torch.float32
        ).to(device)

        input_ids_0 = tokenizer(
            sample['prompt'],
            padding="do_not_pad",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).input_ids

        input_ids_1 = tokenizer_2(
            sample['prompt'],
            padding="do_not_pad",
            truncation=True,
            max_length=tokenizer.model_max_length,
        ).input_ids

        return {
            "frames": frames,
            "input_ids_0": input_ids_0,
            "input_ids_1": input_ids_1,
            "additional_time_ids": additional_time_ids,
        }

    def collate_fn(examples: list) -> dict:

        # Two text encoders:
        # First text encoder  -> penultimate layer
        # Second text encoder -> pooled layer

        input_ids_0 = [example['input_ids_0'] for example in examples]
        input_ids_0 = tokenizer.pad({"input_ids": input_ids_0}, padding="max_length",
                                    max_length=tokenizer.model_max_length, return_tensors="pt").input_ids

        prompt_embeds_0 = text_encoder(
            input_ids_0.to(device),
            output_hidden_states=True,
        )

        # we take the penultimate embeddings from the first text encoder
        prompt_embeds_0 = prompt_embeds_0.hidden_states[-2]

        input_ids_1 = [example['input_ids_1'] for example in examples]
        input_ids_1 = tokenizer_2.pad({"input_ids": input_ids_1}, padding="max_length",
                                      max_length=tokenizer.model_max_length, return_tensors="pt").input_ids

        # we are always interested only in the pooled output of the final text encoder
        prompt_embeds = text_encoder_2(
            input_ids_1.to(device),
            output_hidden_states=True
        )

        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds_1 = prompt_embeds.hidden_states[-2]

        prompt_embeds = torch.concat([prompt_embeds_0, prompt_embeds_1], dim=-1)

        *_, h, w = examples[0]['frames'].shape

        return {
            "frames": torch.stack([x['frames'] for x in examples]).to(memory_format=torch.contiguous_format).float(),
            "prompt_embeds": prompt_embeds.to(memory_format=torch.contiguous_format).float(),
            "pooled_prompt_embeds": pooled_prompt_embeds,
            "additional_time_ids": torch.stack([x['additional_time_ids'] for x in examples]),
        }
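
    # Note: both text encoders are frozen, so their embeddings are computed here at
    # batch time rather than inside the training step. The concatenated penultimate
    # hidden states (768 + 1280 = 2048 features on SDXL-family encoders, an assumed
    # configuration) plus the pooled embedding are what the UNet consumes below.
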

    # Region - Dataloaders
    dataset = HotshotXLDataset(args.data_dir, make_sample)
    dataloader = DataLoader(dataset, args.train_batch_size, shuffle=True, collate_fn=collate_fn)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)

    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    unet, optimizer, lr_scheduler, dataloader = accelerator.prepare(
        unet, optimizer, lr_scheduler, dataloader
    )

    def to_images(video_frames: torch.Tensor):
        import torchvision.transforms as transforms
        to_pil = transforms.ToPILImage()
        video_frames = rearrange(video_frames, "b c f w h -> b f c w h")
        bsz = video_frames.shape[0]
        images = []
        for i in range(bsz):
            video = video_frames[i]
            for j in range(video.shape[0]):
                image = to_pil(video[j])
                images.append(image)
        return images

    def to_video_frames(images: list) -> np.ndarray:
        x = np.stack([np.asarray(img) for img in images])
        return np.transpose(x, (0, 3, 1, 2))
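
    # A minimal sketch (assumed usage; not called anywhere in this script) of turning
    # one validation clip into a gif with the helper from hotshot_xl.utils:
    #
    #   from hotshot_xl.utils import save_as_gif
    #   frames = to_images(video)  # list of PIL images for one pipeline output
    #   save_as_gif(frames, "sample.gif", duration=125)  # ~8 fps
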

    def run_validation(step=0, node_index=0):

        nonlocal global_step
        nonlocal accelerator

        if args.test_prompts:
            prompts = args.test_prompts.split("|")
        else:
            prompts = [
                "a woman is lifting weights in a gym",
                "a group of people are dancing at a party",
                "a teddy bear doing the front crawl"
            ]

        torch.cuda.empty_cache()
        gc.collect()

        logger.info(f"Running inference to test model at {step} steps")
        with torch.no_grad():

            pipe = HotshotXLPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                vae=vae,
            )

            videos = []

            aspect_ratio_map = res_to_aspect_map[args.resolution]
            w, h = aspect_ratio_map[args.aspect_ratio]

            for prompt in prompts:
                video = pipe(prompt,
                             width=w,
                             height=h,
                             original_size=(1920, 1080),  # todo - pass in as args?
                             target_size=(args.resolution, args.resolution),
                             num_inference_steps=30,
                             video_length=8,
                             output_type="tensor",
                             generator=torch.Generator().manual_seed(111)).videos

                videos.append(to_images(video))

            for tracker in accelerator.trackers:

                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "validation": [wandb.Video(to_video_frames(video), fps=8, format='mp4')
                                           for video in videos],
                        }, step=global_step
                    )

            del pipe

        return

    # Move the text encoders and vae to gpu.
    vae.to(accelerator.device, dtype=torch.bfloat16 if args.vae_b16 else torch.float32)
    text_encoder.to(accelerator.device)
    text_encoder_2.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers(args.project_name)

    def bar(prg):
        br = '|' + '█' * prg + ' ' * (25 - prg) + '|'
        return br

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    if accelerator.is_main_process:
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(dataset)}")
        logger.info(f"  Num Epochs = {args.num_train_epochs}")
        logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)

    latents_scaler = vae.config.scaling_factor

    def save_checkpoint():
        save_dir = Path(args.output_dir)
        save_dir = str(save_dir)
        save_dir = save_dir.replace(" ", "_")
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        accelerator.save_state(save_dir)

    def save_checkpoint_and_wait():
        if accelerator.is_main_process:
            save_checkpoint()
        accelerator.wait_for_everyone()

    def save_model_and_wait():
        if accelerator.is_main_process:
            HotshotXLPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                vae=vae,
            ).save_pretrained(args.output_dir, safe_serialization=True)
        accelerator.wait_for_everyone()

    def compute_loss_from_batch(batch: dict):
        frames = batch["frames"]
        bsz, number_of_frames, c, w, h = frames.shape

        # Convert images to latent space
        with torch.no_grad():

            if args.max_vae_encode:
                latents = []

                x = rearrange(frames, "bs nf c h w -> (bs nf) c h w")

                for latent_index in range(0, x.shape[0], args.max_vae_encode):
                    sample = x[latent_index: latent_index + args.max_vae_encode]

                    latent = vae.encode(sample.to(dtype=vae.dtype)).latent_dist.sample().float()
                    if len(latent.shape) == 3:
                        latent = latent.unsqueeze(0)

                    latents.append(latent)
                    torch.cuda.empty_cache()

                latents = torch.cat(latents, dim=0)
            else:

                # convert the latents from 5d -> 4d, so we can run them through the vae encoder
                x = rearrange(frames, "bs nf c h w -> (bs nf) c h w")

                del frames

                torch.cuda.empty_cache()

                latents = vae.encode(x.to(dtype=vae.dtype)).latent_dist.sample().float()

            if args.latent_nan_checking and torch.any(torch.isnan(latents)):
                accelerator.print("NaN found in latents, replacing with zeros")
                latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)

            latents = rearrange(latents, "(b f) c h w -> b c f h w", b=bsz)

        torch.cuda.empty_cache()

        noise = torch.randn_like(latents, device=latents.device)

        if args.noise_offset:
            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
            noise += args.noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1, 1), device=latents.device
            )

        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()  # .repeat_interleave(number_of_frames)
        latents = latents * latents_scaler

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)

        prompt_embeds = batch['prompt_embeds']
        add_text_embeds = batch['pooled_prompt_embeds']

        additional_time_ids = batch['additional_time_ids']  # .repeat_interleave(number_of_frames, dim=0)

        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": additional_time_ids}

        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        noisy_latents.requires_grad = True

        model_pred = unet(noisy_latents,
                          timesteps,
                          cross_attention_kwargs=None,
                          encoder_hidden_states=prompt_embeds,
                          added_cond_kwargs=added_cond_kwargs,
                          return_dict=False,
                          )[0]

        if args.snr_gamma:
            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
            # This is discussed in Section 4.2 of the same paper.
            snr = compute_snr(timesteps)
            mse_loss_weights = (
                torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
            )
            # We first calculate the original loss. Then we mean over the non-batch dimensions and
            # rebalance the sample-wise losses with their respective loss weights.
            # Finally, we take the mean of the rebalanced loss.
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            return loss.mean()
        else:
            return F.mse_loss(model_pred.float(), target.float(), reduction='mean')
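
    # In the "epsilon" case above this is the standard DDPM objective: with
    # x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps,
    # the model predicts eps from (x_t, t, conditioning) and the loss is
    # MSE(eps_hat, eps), optionally rebalanced per-sample by the Min-SNR weights.
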

    def process_batch(batch: dict):
        nonlocal global_step
        nonlocal next_save_iter

        now = time.time()

        with accelerator.accumulate(unet):

            logging_data = {}
            if global_step == 0:
                if accelerator.is_main_process and args.run_validation_at_start:
                    run_validation(step=global_step, node_index=accelerator.process_index // 8)
                accelerator.wait_for_everyone()

            loss = compute_loss_from_batch(batch)

            accelerator.backward(loss)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(temporal_params, args.max_grad_norm)

            optimizer.step()

            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

        fll = round((global_step * 100) / args.max_train_steps)
        fll = round(fll / 4)
        pr = bar(fll)

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "loss_time": (time.time() - now)}

        if args.validate_every_steps is not None and global_step > min_steps_before_validation \
                and global_step % args.validate_every_steps == 0:
            if accelerator.is_main_process:
                run_validation(step=global_step, node_index=accelerator.process_index // 8)

            accelerator.wait_for_everyone()

        for key, val in logging_data.items():
            logs[key] = val

        progress_bar.set_postfix(**logs)
        progress_bar.set_description_str("Progress:" + pr)
        accelerator.log(logs, step=global_step)

        if accelerator.is_main_process \
                and next_save_iter is not None \
                and global_step < args.max_train_steps \
                and global_step + 1 == next_save_iter:
            save_checkpoint()

            torch.cuda.empty_cache()
            gc.collect()

            next_save_iter += args.save_n_steps

    for epoch in range(args.num_train_epochs):
        unet.train()

        for step, batch in enumerate(dataloader):
            process_batch(batch)

            if global_step >= args.max_train_steps:
                break

        if global_step >= args.max_train_steps:
            logger.info("Max train steps reached. Breaking out of the training loop.")
            break

    accelerator.wait_for_everyone()

    save_model_and_wait()

    accelerator.end_training()


if __name__ == "__main__":
    mp.set_start_method('spawn')
    main()
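
# Checkpointing note (based on the code above): save_checkpoint() calls
# accelerator.save_state, which writes the UNet via save_model_hook alongside the
# optimizer state; pointing --unet_resume_path at such a directory reloads the UNet
# from its `unet` subfolder and, unless --disable_optimizer_restore is set, the
# optimizer.bin found next to it.
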
setup.py
ADDED
@@ -0,0 +1,15 @@
from setuptools import setup, find_packages

setup(
    name='hotshot_xl',
    version='1.0',
    packages=find_packages(include=['hotshot_xl*']),
    author="Natural Synthetics Inc",
    install_requires=[
        "torch>=2.0.1",
        "torchvision>=0.15.2",
        "diffusers>=0.21.4",
        "transformers>=4.33.3",
        "einops"
    ],
)
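
# Editable install for local development (an assumed workflow): pip install -e .
# Note that fine_tune.py additionally imports accelerate, tqdm and (optionally) wandb,
# and utils.py imports imageio, requests, numpy and Pillow; only torch, torchvision,
# diffusers, transformers and einops are declared in install_requires above.
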
utils.py
ADDED
@@ -0,0 +1,228 @@
# Copyright 2023 Natural Synthetics Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union
from io import BytesIO
import PIL
from PIL import ImageSequence, Image
import requests
import os
import numpy as np
import imageio


def get_image(img_path) -> PIL.Image.Image:
    if img_path.startswith("http"):
        return PIL.Image.open(requests.get(img_path, stream=True).raw)
    if os.path.exists(img_path):
        return Image.open(img_path)
    raise FileNotFoundError(f"File not found: {img_path}")


def images_to_gif_bytes(images: List, duration: int = 1000) -> bytes:
    with BytesIO() as output_buffer:
        # Save the first image, appending the rest as animation frames
        images[0].save(output_buffer,
                       format='GIF',
                       save_all=True,
                       append_images=images[1:],
                       duration=duration,
                       loop=0)  # 0 means the GIF will loop indefinitely

        # Get the byte array from the buffer
        gif_bytes = output_buffer.getvalue()

    return gif_bytes
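
# `duration` is the per-frame display time in milliseconds (PIL's GIF convention),
# so e.g. duration=125 plays back at roughly 8 frames per second.
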

def save_as_gif(images: List, file_path: str, duration: int = 1000):
    with open(file_path, "wb") as f:
        f.write(images_to_gif_bytes(images, duration))


def images_to_mp4_bytes(images: List[Image.Image], duration: int = 1000) -> bytes:
    with BytesIO() as output_buffer:
        with imageio.get_writer(output_buffer, format='mp4', fps=1 / (duration / 1000)) as writer:
            for img in images:
                writer.append_data(np.array(img))
        mp4_bytes = output_buffer.getvalue()

    return mp4_bytes


def save_as_mp4(images: List[Image.Image], file_path: str, duration: int = 1000):
    with open(file_path, "wb") as f:
        f.write(images_to_mp4_bytes(images, duration))


def scale_aspect_fill(img, new_width, new_height):
    new_width = int(new_width)
    new_height = int(new_height)

    original_width, original_height = img.size
    ratio_w = float(new_width) / original_width
    ratio_h = float(new_height) / original_height

    if ratio_w > ratio_h:
        # It must be fixed by width
        resize_width = new_width
        resize_height = round(original_height * ratio_w)
    else:
        # Fixed by height
        resize_width = round(original_width * ratio_h)
        resize_height = new_height

    img_resized = img.resize((resize_width, resize_height), Image.LANCZOS)

    # Calculate cropping boundaries and do the crop
    left = (resize_width - new_width) / 2
    top = (resize_height - new_height) / 2
    right = (resize_width + new_width) / 2
    bottom = (resize_height + new_height) / 2

    img_cropped = img_resized.crop((left, top, right, bottom))

    return img_cropped
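
# Worked example (numbers for illustration only): scale_aspect_fill(img, 672, 384) on
# a 1920x1080 image gives ratio_w = 0.35 < ratio_h ~ 0.356, so it resizes to 683x384
# and then center-crops 5.5 px from each side to return exactly 672x384.
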

def extract_gif_frames_from_midpoint(image: Union[str, PIL.Image.Image], fps: int = 8,
                                     target_duration: int = 1000) -> list:
    # Load the GIF
    image = get_image(image) if isinstance(image, str) else image

    frames = []

    estimated_frame_time = None

    # some gifs contain the duration - others don't
    # so if there is a duration we will grab it, otherwise we will fall back

    for frame in ImageSequence.Iterator(image):

        frames.append(frame.copy())
        if 'duration' in frame.info:
            frame_info_duration = frame.info['duration']
            if frame_info_duration > 0:
                estimated_frame_time = frame_info_duration

    if estimated_frame_time is None:
        if len(frames) <= 16:
            # assume it's 8 fps
            estimated_frame_time = 1000 // 8
        else:
            # assume it's 15 fps
            estimated_frame_time = 70

    if len(frames) < fps:
        raise ValueError(f"fps of {fps} is too small for this gif as it only has {len(frames)} frames.")

    skip = len(frames) // fps
    upper_bound_index = len(frames) - 1

    best_indices = [x for x in range(0, len(frames), skip)][:fps]
    offset = int(upper_bound_index - best_indices[-1]) // 2
    best_indices = [x + offset for x in best_indices]
    best_duration = (best_indices[-1] - best_indices[0]) * estimated_frame_time

    while True:

        skip -= 1

        if skip == 0:
            break

        indices = [x for x in range(0, len(frames), skip)][:fps]

        # center the indices, so we sample the middle of the gif...
        offset = int(upper_bound_index - indices[-1]) // 2
        if offset == 0:
            # can't shift
            break
        indices = [x + offset for x in indices]

        # is the new duration closer to the target than the last guess?
        duration = (indices[-1] - indices[0]) * estimated_frame_time
        if abs(duration - target_duration) > abs(best_duration - target_duration):
            break

        best_indices = indices
        best_duration = duration

    return [frames[index] for index in best_indices]

def get_crop_coordinates(old_size: tuple, new_size: tuple) -> tuple:
    """
    Calculate the crop coordinates after scaling an image to fit a new size.

    :param old_size: tuple of the form (width, height) representing the original size of the image.
    :param new_size: tuple of the form (width, height) representing the desired size after scaling.
    :return: tuple of the form (left, upper, right, lower) giving the crop coordinates
             in the resized image space.
    """
    # Check if the input tuples have the right form (width, height)
    if not (isinstance(old_size, tuple) and isinstance(new_size, tuple) and
            len(old_size) == 2 and len(new_size) == 2):
        raise ValueError("old_size and new_size should be tuples of the form (width, height)")

    # Extract the width and height from the old and new sizes
    old_width, old_height = old_size
    new_width, new_height = new_size

    # Calculate the ratios for width and height
    ratio_w = float(new_width) / old_width
    ratio_h = float(new_height) / old_height

    # Determine which dimension is fixed (width or height)
    if ratio_w > ratio_h:
        # It must be fixed by width
        resize_width = new_width
        resize_height = round(old_height * ratio_w)
    else:
        # Fixed by height
        resize_width = round(old_width * ratio_h)
        resize_height = new_height

    # Calculate cropping boundaries in the resized image space
    left = (resize_width - new_width) / 2
    upper = (resize_height - new_height) / 2
    right = (resize_width + new_width) / 2
    lower = (resize_height + new_height) / 2

    # Return the crop coordinates as a tuple
    return (left, upper, right, lower)

aspect_ratio_to_1024_map = {
    "0.42": [640, 1536],
    "0.57": [768, 1344],
    "0.68": [832, 1216],
    "1.00": [1024, 1024],
    "1.46": [1216, 832],
    "1.75": [1344, 768],
    "2.40": [1536, 640]
}

res_to_aspect_map = {
    1024: aspect_ratio_to_1024_map,
    512: {key: [value[0] // 2, value[1] // 2] for key, value in aspect_ratio_to_1024_map.items()},
}


def best_aspect_ratio(aspect_ratio: float, resolution: int):

    aspect_map = res_to_aspect_map[resolution]  # renamed from `map` to avoid shadowing the builtin

    d = 99999999
    res = None
    for key, value in aspect_map.items():
        ar = value[0] / value[1]
        diff = abs(aspect_ratio - ar)
        if diff < d:
            d = diff
            res = value

    ar = res[0] / res[1]
    return f"{ar:.2f}", res
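
# Example (assumed usage): best_aspect_ratio(16 / 9, 512) compares ~1.78 against the
# 512-resolution buckets above and returns ("1.75", [672, 384]).
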