Noename committed on
Commit 37b43d6
1 Parent(s): b8db1c0
Files changed (4)
  1. app.py +201 -0
  2. ciff_dataset.py +214 -0
  3. train_controlnet.py +1239 -0
  4. train_multi_open.py +1192 -0
app.py ADDED
@@ -0,0 +1,201 @@
+ import os
+ from PIL import Image
+ import json
+ import random
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from torch.nn.functional import threshold, normalize, interpolate
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
+ from einops import rearrange, repeat
+
+ import argparse
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ parseargs = argparse.ArgumentParser()
+ parseargs.add_argument('--pretrained_model', type=str, default='runwayml/stable-diffusion-v1-5')
+ parseargs.add_argument('--controlnet', type=str, default='controlnet')
+ parseargs.add_argument('--precision', type=str, default='fp32')
+ args = parseargs.parse_args()
+ pretrained_model = args.pretrained_model
+
+ # Check for different hardware architectures
+ if torch.cuda.is_available():
+     device = "cuda"
+     # Check for xformers
+     try:
+         import xformers
+         enable_xformers = True
+     except ImportError:
+         enable_xformers = False
+ elif torch.backends.mps.is_available():
+     device = "mps"
+ else:
+     device = "cpu"
+
+ print(f"Using device: {device}")
+
+ # Load models
+ if args.precision == 'fp32':
+     torch_dtype = torch.float32
+ elif args.precision == 'fp16':
+     torch_dtype = torch.float16
+ elif args.precision == 'bf16':
+     torch_dtype = torch.bfloat16
+ else:
+     raise ValueError(f"Invalid precision: {args.precision}")
+
+ controlnet = ControlNetModel.from_pretrained(args.controlnet, torch_dtype=torch_dtype)
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
+     args.pretrained_model, controlnet=controlnet, torch_dtype=torch_dtype
+ )
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+ pipe = pipe.to(device)
+
+ # Apply optimizations based on hardware
+ if device == "cuda":
+     pipe = pipe.to(device)
+     if enable_xformers:
+         pipe.enable_xformers_memory_efficient_attention()
+         print("xformers optimization enabled")
+ elif device == "mps":
+     pipe = pipe.to(device)
+     pipe.enable_attention_slicing()
+     print("Attention slicing enabled for Apple Silicon")
+ else:
+     # CPU-specific optimizations
+     pipe = pipe.to(device)
+     # pipe.enable_sequential_cpu_offload()
+     # pipe.enable_attention_slicing()
+
+ feature_extractor = SegformerFeatureExtractor.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
+ segmodel = SegformerForSemanticSegmentation.from_pretrained("matei-dorian/segformer-b5-finetuned-human-parsing")
+
+
+ def LGB_TO_RGB(gray_image, rgb_image):
+     # gray_image [H, W, 3]
+     # rgb_image [H, W, 3]
+     print("gray_image shape: ", gray_image.shape)
+     print("rgb_image shape: ", rgb_image.shape)
+
+     gray_image = cv2.cvtColor(gray_image, cv2.COLOR_RGB2GRAY)
+     lab_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2LAB)
+     lab_image[:, :, 0] = gray_image[:, :]
+
+     return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
+
+
+ @torch.inference_mode()
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength,
+             guidance_scale, seed, eta, threshold, save_memory=False):
+     with torch.no_grad():
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+         print("img shape: ", img.shape)
+         if C == 3:
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+         control = torch.from_numpy(img).to(device).float()
+         control = control / 255.0
+         control = rearrange(control, 'h w c -> 1 c h w')
+         # control = repeat(control, 'b c h w -> b c h w', b=num_samples)
+         # control = rearrange(control, 'b h w c -> b c h w')
+
+         if a_prompt:
+             prompt = prompt + ', ' + a_prompt
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         generator = torch.Generator(device=device).manual_seed(seed)
+         # Generate images
+         output = pipe(
+             num_images_per_prompt=num_samples,
+             prompt=prompt,
+             image=control.to(device),
+             negative_prompt=n_prompt,
+             num_inference_steps=ddim_steps,
+             guidance_scale=guidance_scale,
+             generator=generator,
+             eta=eta,
+             controlnet_conditioning_scale=strength,  # ControlNet conditioning scale from the "Control Strength" slider
+             output_type='np',
+         ).images
+
+         # output = einops.rearrange(output, 'b c h w -> b h w c')
+         output = (output * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+         results = [output[i] for i in range(num_samples)]
+         results = [LGB_TO_RGB(img, result) for result in results]
+
+         # Convert each image in results into a segmentation mask
+         masks = []
+         for result in results:
+             inputs = feature_extractor(images=result, return_tensors="pt")
+             outputs = segmodel(**inputs)
+             logits = outputs.logits
+             logits = logits.squeeze(0)
+             thresholded = torch.zeros_like(logits)
+             thresholded[logits > threshold] = 1
+             mask = thresholded[1:, :, :].sum(dim=0)
+             mask = mask.unsqueeze(0).unsqueeze(0)
+             mask = interpolate(mask, size=(H, W), mode='bilinear')
+             mask = mask.detach().numpy()
+             mask = np.squeeze(mask)
+             mask = np.where(mask > threshold, 1, 0)
+             masks.append(mask)
+
+         # Using each mask, keep the grayscale input (img) where the mask is 0
+         # and the colorized result where it is 1; img is already a 3-channel RGB image.
+         final = [img * (1 - mask[:, :, None]) + result * mask[:, :, None] for result, mask in zip(results, masks)]
+
+         # Scale the masks to 0-255 images
+         mask_img = [mask * 255 for mask in masks]
+         return [img] + results + mask_img + final
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Gray Image")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(sources=['upload'], type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(value="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=1, value=1, step=1, visible=False)
+                 # num_samples = 1
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 # guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=20, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
+                 threshold = gr.Slider(label="Segmentation Threshold", minimum=0.1, maximum=0.9, value=0.5, step=0.05)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=-1, step=1)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery")
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, strength, scale, seed,
+            eta, threshold]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery], concurrency_limit=4)
+
+ block.queue(max_size=100)
+ block.launch(share=True)
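
The core of `process` is a luminance-preserving recolor: the ControlNet output only contributes the a/b chroma channels, while the L channel is written back from the grayscale input via `LGB_TO_RGB`. A minimal standalone sketch of that step follows; the file names are placeholders, and both images are assumed to be the same size, as they are inside the app.

```python
import cv2

# Placeholder inputs: a grayscale photo and a colorized version of it at the same resolution.
gray = cv2.cvtColor(cv2.imread("gray.png"), cv2.COLOR_BGR2RGB)
colorized = cv2.cvtColor(cv2.imread("colorized.png"), cv2.COLOR_BGR2RGB)

# Take chroma (a/b) from the colorized image, luminance (L) from the original grayscale.
lab = cv2.cvtColor(colorized, cv2.COLOR_RGB2LAB)
lab[:, :, 0] = cv2.cvtColor(gray, cv2.COLOR_RGB2GRAY)
recolored = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

cv2.imwrite("recolored.png", cv2.cvtColor(recolored, cv2.COLOR_RGB2BGR))
```

Because the brightness structure of the input is preserved exactly, the app can composite the result back onto the grayscale image with a segmentation mask without visible seams.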
ciff_dataset.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ import random
+ from concurrent.futures import ProcessPoolExecutor
+ from pathlib import Path
+ import json
+ from PIL import Image
+ import numpy as np
+ import argparse
+ from tqdm import tqdm
+
+ # Parse arguments
+ parser = argparse.ArgumentParser(description="Dataset creation for image colorization")
+ parser.add_argument("--source_dir", type=str, required=True, help="Source directory")
+ parser.add_argument(
+     "--target_dir", type=str, required=True, help="Target directory for the dataset"
+ )
+ parser.add_argument(
+     "--resolution", type=int, default=512, help="Resolution for the dataset"
+ )
+ args = parser.parse_args()
+
+ # Path setup
+ root_dir = Path("E:/datasets")
+ target_dir = root_dir / args.target_dir
+ source_dir = root_dir / args.source_dir
+ target_images_dir = target_dir / "images"
+ target_conditioning_dir = target_dir / "conditioning_images"
+ metadata_file = target_dir / "metadata.jsonl"
+
+ # Create the directories
+ target_dir.mkdir(parents=True, exist_ok=True)
+ target_images_dir.mkdir(exist_ok=True)
+ target_conditioning_dir.mkdir(exist_ok=True)
+
+ # Prompt list
+ prompts = [
+     "a color image, realistic style, photo",
+     "a color image, high resolution, realistic, painting",
+     "a color image, high resolution, realistic, photo",
+     "very good quality, absurd, photo, color, 4k image",
+     "high resolution, color, photo, realistic",
+     "high resolution, color, photo, realistic, 4k image",
+     "a color image, high resolution, realistic, 4k image",
+     "color, high resolution, photo, realistic",
+     "512x512, color, photo, realistic",
+ ]
+
+
+ def process_image(image_path):
+     try:
+         # Load and crop the image
+         with Image.open(image_path) as img:
+             # Check the image size
+             width, height = img.size
+             size = min(width, height)
+             left = (width - size) // 2
+             top = (height - size) // 2
+             right = left + size
+             bottom = top + size
+
+             # Center-crop and resize
+             img_cropped = img.crop((left, top, right, bottom)).resize(
+                 (args.resolution, args.resolution), Image.LANCZOS
+             )
+
+             # Convert to grayscale
+             img_gray = img_cropped.convert("L")
+
+             # Build the output filename
+             filename = image_path.stem + ".jpg"
+
+             # Save the images
+             img_cropped.save(target_images_dir / filename)
+             img_gray.save(target_conditioning_dir / filename)
+
+             # Build the metadata entry
+             metadata = {
+                 "image": str(filename),
+                 "text": random.choice(prompts),
+                 "conditioning_image": str(filename),
+             }
+
+             return metadata
+     except Exception as e:
+         print(f"Error processing {image_path}: {e}")
+         return None
+
+
+ def generate_dataset_loader(target_dir):
+     # Take the name of the target directory
+     dir_name = target_dir.name
+
+     # Build the class name (e.g. ciff_dataset -> CiffDataset)
+     class_name = ''.join(word.capitalize() for word in dir_name.split('_'))
+
+     # Build the file name
+     file_name = f"{dir_name}.py"
+
+     # Build the file path
+     file_path = target_dir / file_name
+
+     # Generate the dataset loader code
+     code = f'''
+ import pandas as pd
+ from pathlib import Path
+ import datasets
+ import os
+
+ _VERSION = datasets.Version("0.0.2")
+
+ _DESCRIPTION = "TODO"
+ _HOMEPAGE = "TODO"
+ _LICENSE = "TODO"
+ _CITATION = "TODO"
+
+ _FEATURES = datasets.Features(
+     {{
+         "image": datasets.Image(),
+         "conditioning_image": datasets.Image(),
+         "text": datasets.Value("string"),
+     }}
+ )
+
+ _DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)
+
+
+ class {class_name}(datasets.GeneratorBasedBuilder):
+     BUILDER_CONFIGS = [_DEFAULT_CONFIG]
+     DEFAULT_CONFIG_NAME = "default"
+
+     def _info(self):
+         return datasets.DatasetInfo(
+             description=_DESCRIPTION,
+             features=_FEATURES,
+             supervised_keys=None,
+             homepage=_HOMEPAGE,
+             license=_LICENSE,
+             citation=_CITATION,
+         )
+
+     def _split_generators(self, dl_manager):
+         base_path = Path(dl_manager._base_path)
+         metadata_path = base_path / "metadata.jsonl"
+         images_dir = base_path / "images"
+         conditioning_images_dir = base_path / "conditioning_images"
+
+         return [
+             datasets.SplitGenerator(
+                 name=datasets.Split.TRAIN,
+                 gen_kwargs={{
+                     "metadata_path": metadata_path,
+                     "images_dir": images_dir,
+                     "conditioning_images_dir": conditioning_images_dir,
+                 }},
+             ),
+         ]
+
+     def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
+         metadata = pd.read_json(metadata_path, lines=True)
+
+         for idx, row in metadata.iterrows():
+             text = row["text"]
+
+             image_path = os.path.join(images_dir, row["image"])
+             image = open(image_path, "rb").read()
+
+             conditioning_image_path = os.path.join(conditioning_images_dir, row["conditioning_image"])
+             conditioning_image = open(conditioning_image_path, "rb").read()
+
+             yield idx, {{
+                 "text": text,
+                 "image": {{
+                     "path": image_path,
+                     "bytes": image,
+                 }},
+                 "conditioning_image": {{
+                     "path": conditioning_image_path,
+                     "bytes": conditioning_image,
+                 }},
+             }}
+ '''
+
+     # Create the file and write the code
+     with open(file_path, 'w') as f:
+         f.write(code)
+
+     print(f"Dataset loader file created: {file_path}")
+
+
+ def main():
+     # Collect the list of image files
+     image_files = list(source_dir.glob("*"))
+
+     # Set the number of worker processes (3/4 of the CPU cores)
+     num_workers = (3 * os.cpu_count()) // 4
+
+     # Process the images in parallel
+     with ProcessPoolExecutor(max_workers=num_workers) as executor:
+         results = list(tqdm(executor.map(process_image, image_files), total=len(image_files), desc="Processing images"))
+
+     # Save the metadata
+     with open(metadata_file, "w") as f:
+         for metadata in results:
+             if metadata:
+                 json.dump(metadata, f)
+                 f.write("\n")
+
+     # Generate the dataset loader file
+     generate_dataset_loader(target_dir)
+
+
+ if __name__ == "__main__":
+     main()
+     print(f"Dataset creation completed. Output directory: {target_dir}")
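
The script leaves behind `images/`, `conditioning_images/`, `metadata.jsonl`, and a generated `<target_dir>.py` builder script, so the folder can be consumed directly by Hugging Face Datasets, which is how `train_controlnet.py` below reads it through `--train_data_dir`. A minimal loading sketch, assuming the placeholder path `E:/datasets/ciff_dataset`:

```python
from datasets import load_dataset

# The folder contains the generated <target_dir>.py builder script,
# so it is loaded as a script-based dataset.
dataset = load_dataset("E:/datasets/ciff_dataset", trust_remote_code=True)

sample = dataset["train"][0]
print(sample["text"])                # one of the prompts above
sample["image"].show()               # color target (PIL image)
sample["conditioning_image"].show()  # grayscale conditioning image
```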
train_controlnet.py ADDED
@@ -0,0 +1,1239 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import contextlib
18
+ import gc
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from functools import partial
25
+ from pathlib import Path
26
+
27
+ import accelerate
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from datasets import load_dataset, Features, Value, Image
37
+ from huggingface_hub import create_repo, upload_folder
38
+ from packaging import version
39
+ from PIL import Image
40
+ from torchvision import transforms
41
+ from tqdm.auto import tqdm
42
+ from transformers import AutoTokenizer, PretrainedConfig
43
+
44
+ import diffusers
45
+ from diffusers import (
46
+ AutoencoderKL,
47
+ ControlNetModel,
48
+ DDPMScheduler,
49
+ StableDiffusionControlNetPipeline,
50
+ UNet2DConditionModel,
51
+ UniPCMultistepScheduler,
52
+ )
53
+ from diffusers.optimization import get_scheduler
54
+ from diffusers.utils import check_min_version, is_wandb_available
55
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
56
+ from diffusers.utils.import_utils import is_xformers_available
57
+ from diffusers.utils.torch_utils import is_compiled_module
58
+
59
+
60
+ if is_wandb_available():
61
+ import wandb
62
+
63
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
64
+ check_min_version("0.30.0.dev0")
65
+
66
+ logger = get_logger(__name__)
67
+
68
+
69
+ def image_grid(imgs, rows, cols):
70
+ assert len(imgs) == rows * cols
71
+
72
+ w, h = imgs[0].size
73
+ grid = Image.new("RGB", size=(cols * w, rows * h))
74
+
75
+ for i, img in enumerate(imgs):
76
+ grid.paste(img, box=(i % cols * w, i // cols * h))
77
+ return grid
78
+
79
+
80
+ def log_validation(
81
+ vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
82
+ ):
83
+ logger.info("Running validation... ")
84
+
85
+ if not is_final_validation:
86
+ controlnet = accelerator.unwrap_model(controlnet)
87
+ else:
88
+ controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
89
+
90
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
91
+ args.pretrained_model_name_or_path,
92
+ vae=vae,
93
+ text_encoder=text_encoder,
94
+ tokenizer=tokenizer,
95
+ unet=unet,
96
+ controlnet=controlnet,
97
+ safety_checker=None,
98
+ revision=args.revision,
99
+ variant=args.variant,
100
+ torch_dtype=weight_dtype,
101
+ )
102
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
103
+ pipeline = pipeline.to(accelerator.device)
104
+ pipeline.set_progress_bar_config(disable=True)
105
+
106
+ if args.enable_xformers_memory_efficient_attention:
107
+ pipeline.enable_xformers_memory_efficient_attention()
108
+
109
+ if args.seed is None:
110
+ generator = None
111
+ else:
112
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
113
+
114
+ # if args.validation_image is folder, get all images in the folder
115
+ if len(args.validation_image) == 1 and os.path.isdir(args.validation_image[0]):
116
+ logger.info(f"Loading images from {args.validation_image[0]}")
117
+ dir_path = args.validation_image[0]
118
+ validation_images = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
119
+ logger.info(f"Found {len(validation_images)} images")
120
+ else:
121
+ validation_images = args.validation_image
122
+
123
+
124
+ if len(validation_images) == len(args.validation_prompt):
125
+ validation_prompts = args.validation_prompt
126
+ elif len(validation_images) == 1:
127
+ validation_images = validation_images * len(args.validation_prompt)
128
+ validation_prompts = args.validation_prompt
129
+ elif len(args.validation_prompt) == 1:
130
+ validation_prompts = args.validation_prompt * len(validation_images)
131
+ else:
132
+ raise ValueError(
133
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
134
+ )
135
+
136
+ image_logs = []
137
+ inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
138
+
139
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
140
+ validation_image = Image.open(validation_image).convert("RGB")
141
+ # Resize
142
+ validation_image = transforms.Resize(args.resolution)(validation_image)
143
+
144
+ images = []
145
+
146
+ for _ in range(args.num_validation_images):
147
+ with inference_ctx:
148
+ image = pipeline(
149
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
150
+ ).images[0]
151
+
152
+ images.append(image)
153
+
154
+ image_logs.append(
155
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
156
+ )
157
+
158
+ tracker_key = "test" if is_final_validation else "validation"
159
+ for tracker in accelerator.trackers:
160
+ if tracker.name == "tensorboard":
161
+ for log in image_logs:
162
+ images = log["images"]
163
+ validation_prompt = log["validation_prompt"]
164
+ validation_image = log["validation_image"]
165
+
166
+ formatted_images = []
167
+
168
+ formatted_images.append(np.asarray(validation_image))
169
+
170
+ for image in images:
171
+ formatted_images.append(np.asarray(image))
172
+
173
+ formatted_images = np.stack(formatted_images)
174
+
175
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
176
+ elif tracker.name == "wandb":
177
+ formatted_images = []
178
+
179
+ for log in image_logs:
180
+ images = log["images"]
181
+ validation_prompt = log["validation_prompt"]
182
+ validation_image = log["validation_image"]
183
+
184
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
185
+
186
+ for image in images:
187
+ image = wandb.Image(image, caption=validation_prompt)
188
+ formatted_images.append(image)
189
+
190
+ tracker.log({tracker_key: formatted_images})
191
+ else:
192
+ logger.warning(f"image logging not implemented for {tracker.name}")
193
+
194
+ del pipeline
195
+ gc.collect()
196
+ torch.cuda.empty_cache()
197
+
198
+ return image_logs
199
+
200
+
201
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
202
+ text_encoder_config = PretrainedConfig.from_pretrained(
203
+ pretrained_model_name_or_path,
204
+ subfolder="text_encoder",
205
+ revision=revision,
206
+ )
207
+ model_class = text_encoder_config.architectures[0]
208
+
209
+ if model_class == "CLIPTextModel":
210
+ from transformers import CLIPTextModel
211
+
212
+ return CLIPTextModel
213
+ elif model_class == "RobertaSeriesModelWithTransformation":
214
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
215
+
216
+ return RobertaSeriesModelWithTransformation
217
+ else:
218
+ raise ValueError(f"{model_class} is not supported.")
219
+
220
+
221
+ def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
222
+ img_str = ""
223
+ if image_logs is not None:
224
+ img_str = "You can find some example images below.\n\n"
225
+ for i, log in enumerate(image_logs):
226
+ images = log["images"]
227
+ validation_prompt = log["validation_prompt"]
228
+ validation_image = log["validation_image"]
229
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
230
+ img_str += f"prompt: {validation_prompt}\n"
231
+ images = [validation_image] + images
232
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
233
+ img_str += f"![images_{i}](./images_{i}.png)\n"
234
+
235
+ model_description = f"""
236
+ # controlnet-{repo_id}
237
+
238
+ These are controlnet weights trained on {base_model} with new type of conditioning.
239
+ {img_str}
240
+ """
241
+ model_card = load_or_create_model_card(
242
+ repo_id_or_path=repo_id,
243
+ from_training=True,
244
+ license="creativeml-openrail-m",
245
+ base_model=base_model,
246
+ model_description=model_description,
247
+ inference=True,
248
+ )
249
+
250
+ tags = [
251
+ "stable-diffusion",
252
+ "stable-diffusion-diffusers",
253
+ "text-to-image",
254
+ "diffusers",
255
+ "controlnet",
256
+ "diffusers-training",
257
+ ]
258
+ model_card = populate_model_card(model_card, tags=tags)
259
+
260
+ model_card.save(os.path.join(repo_folder, "README.md"))
261
+
262
+
263
+ def parse_args(input_args=None):
264
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
265
+ parser.add_argument(
266
+ "--pretrained_model_name_or_path",
267
+ type=str,
268
+ default=None,
269
+ required=True,
270
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
271
+ )
272
+ parser.add_argument(
273
+ "--controlnet_model_name_or_path",
274
+ type=str,
275
+ default=None,
276
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
277
+ " If not specified controlnet weights are initialized from unet.",
278
+ )
279
+ parser.add_argument(
280
+ "--revision",
281
+ type=str,
282
+ default=None,
283
+ required=False,
284
+ help="Revision of pretrained model identifier from huggingface.co/models.",
285
+ )
286
+ parser.add_argument(
287
+ "--variant",
288
+ type=str,
289
+ default=None,
290
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
291
+ )
292
+ parser.add_argument(
293
+ "--tokenizer_name",
294
+ type=str,
295
+ default=None,
296
+ help="Pretrained tokenizer name or path if not the same as model_name",
297
+ )
298
+ parser.add_argument(
299
+ "--output_dir",
300
+ type=str,
301
+ default="controlnet-model",
302
+ help="The output directory where the model predictions and checkpoints will be written.",
303
+ )
304
+ parser.add_argument(
305
+ "--cache_dir",
306
+ type=str,
307
+ default=None,
308
+ help="The directory where the downloaded models and datasets will be stored.",
309
+ )
310
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
311
+ parser.add_argument(
312
+ "--resolution",
313
+ type=int,
314
+ default=512,
315
+ help=(
316
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
317
+ " resolution"
318
+ ),
319
+ )
320
+ parser.add_argument(
321
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
322
+ )
323
+ parser.add_argument("--num_train_epochs", type=int, default=1)
324
+ parser.add_argument(
325
+ "--max_train_steps",
326
+ type=int,
327
+ default=None,
328
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
329
+ )
330
+ parser.add_argument(
331
+ "--checkpointing_steps",
332
+ type=int,
333
+ default=500,
334
+ help=(
335
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
336
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
337
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
338
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
339
+ "instructions."
340
+ ),
341
+ )
342
+ parser.add_argument(
343
+ "--checkpoints_total_limit",
344
+ type=int,
345
+ default=None,
346
+ help=("Max number of checkpoints to store."),
347
+ )
348
+ parser.add_argument(
349
+ "--resume_from_checkpoint",
350
+ type=str,
351
+ default=None,
352
+ help=(
353
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
354
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
355
+ ),
356
+ )
357
+ parser.add_argument(
358
+ "--gradient_accumulation_steps",
359
+ type=int,
360
+ default=1,
361
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
362
+ )
363
+ parser.add_argument(
364
+ "--gradient_checkpointing",
365
+ action="store_true",
366
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
367
+ )
368
+ parser.add_argument(
369
+ "--learning_rate",
370
+ type=float,
371
+ default=5e-6,
372
+ help="Initial learning rate (after the potential warmup period) to use.",
373
+ )
374
+ parser.add_argument(
375
+ "--scale_lr",
376
+ action="store_true",
377
+ default=False,
378
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
379
+ )
380
+ parser.add_argument(
381
+ "--lr_scheduler",
382
+ type=str,
383
+ default="constant",
384
+ help=(
385
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
386
+ ' "constant", "constant_with_warmup"]'
387
+ ),
388
+ )
389
+ parser.add_argument(
390
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
391
+ )
392
+ parser.add_argument(
393
+ "--lr_num_cycles",
394
+ type=int,
395
+ default=1,
396
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
397
+ )
398
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
399
+ parser.add_argument(
400
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
401
+ )
402
+ parser.add_argument(
403
+ "--dataset_num_workers",
404
+ type=int,
405
+ default=1,
406
+ help=(
407
+ "Number of subprocesses to use for data loading."
408
+ ),
409
+ )
410
+ parser.add_argument(
411
+ "--dataloader_num_workers",
412
+ type=int,
413
+ default=0,
414
+ help=(
415
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
416
+ ),
417
+ )
418
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
419
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
420
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
421
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
422
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
423
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
424
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
425
+ parser.add_argument(
426
+ "--hub_model_id",
427
+ type=str,
428
+ default=None,
429
+ help="The name of the repository to keep in sync with the local `output_dir`.",
430
+ )
431
+ parser.add_argument(
432
+ "--logging_dir",
433
+ type=str,
434
+ default="logs",
435
+ help=(
436
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
437
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
438
+ ),
439
+ )
440
+ parser.add_argument(
441
+ "--allow_tf32",
442
+ action="store_true",
443
+ help=(
444
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
445
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
446
+ ),
447
+ )
448
+ parser.add_argument(
449
+ "--report_to",
450
+ type=str,
451
+ default="tensorboard",
452
+ help=(
453
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
454
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
455
+ ),
456
+ )
457
+ parser.add_argument(
458
+ "--mixed_precision",
459
+ type=str,
460
+ default=None,
461
+ choices=["no", "fp16", "bf16"],
462
+ help=(
463
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
464
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
465
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
466
+ ),
467
+ )
468
+ parser.add_argument(
469
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
470
+ )
471
+ parser.add_argument(
472
+ "--set_grads_to_none",
473
+ action="store_true",
474
+ help=(
475
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
476
+ " behaviors, so disable this argument if it causes any problems. More info:"
477
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
478
+ ),
479
+ )
480
+ parser.add_argument(
481
+ "--dataset_name",
482
+ type=str,
483
+ default=None,
484
+ help=(
485
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
486
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
487
+ " or to a folder containing files that 🤗 Datasets can understand."
488
+ ),
489
+ )
490
+ parser.add_argument(
491
+ "--dataset_config_name",
492
+ type=str,
493
+ default=None,
494
+ help="The config of the Dataset, leave as None if there's only one config.",
495
+ )
496
+ parser.add_argument(
497
+ "--train_data_dir",
498
+ type=str,
499
+ default=None,
500
+ help=(
501
+ "A folder containing the training data. Folder contents must follow the structure described in"
502
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
503
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
504
+ ),
505
+ )
506
+ parser.add_argument(
507
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
508
+ )
509
+ parser.add_argument(
510
+ "--conditioning_image_column",
511
+ type=str,
512
+ default="conditioning_image",
513
+ help="The column of the dataset containing the controlnet conditioning image.",
514
+ )
515
+ parser.add_argument(
516
+ "--caption_column",
517
+ type=str,
518
+ default="text",
519
+ help="The column of the dataset containing a caption or a list of captions.",
520
+ )
521
+ parser.add_argument(
522
+ "--max_train_samples",
523
+ type=int,
524
+ default=None,
525
+ help=(
526
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
527
+ "value if set."
528
+ ),
529
+ )
530
+ parser.add_argument(
531
+ "--proportion_empty_prompts",
532
+ type=float,
533
+ default=0,
534
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
535
+ )
536
+ parser.add_argument(
537
+ "--validation_prompt",
538
+ type=str,
539
+ default=None,
540
+ nargs="+",
541
+ help=(
542
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
543
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
544
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
545
+ ),
546
+ )
547
+ parser.add_argument(
548
+ "--validation_image",
549
+ type=str,
550
+ default=None,
551
+ nargs="+",
552
+ help=(
553
+ "A set of paths to the controlnet conditioning image to be evaluated every `--validation_steps`"
554
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
555
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
556
+ " `--validation_image` that will be used with all `--validation_prompt`s."
557
+ ),
558
+ )
559
+ parser.add_argument(
560
+ "--num_validation_images",
561
+ type=int,
562
+ default=4,
563
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
564
+ )
565
+ parser.add_argument(
566
+ "--validation_steps",
567
+ type=int,
568
+ default=100,
569
+ help=(
570
+ "Run validation every X steps. Validation consists of running the prompt"
571
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
572
+ " and logging the images."
573
+ ),
574
+ )
575
+ parser.add_argument(
576
+ "--tracker_project_name",
577
+ type=str,
578
+ default="train_controlnet",
579
+ help=(
580
+ "The `project_name` argument passed to Accelerator.init_trackers for"
581
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
582
+ ),
583
+ )
584
+ parser.add_argument(
585
+ "--tracker_run_name",
586
+ type=str,
587
+ default=None,
588
+ help=(
589
+ "The `run_name` argument passed to Accelerator.init_trackers for"
590
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
591
+ ),
592
+ )
593
+
594
+ if input_args is not None:
595
+ args = parser.parse_args(input_args)
596
+ else:
597
+ args = parser.parse_args()
598
+
599
+ if args.dataset_name is None and args.train_data_dir is None:
600
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
601
+
602
+ if args.dataset_name is not None and args.train_data_dir is not None:
603
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
604
+
605
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
606
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
607
+
608
+ if args.validation_prompt is not None and args.validation_image is None:
609
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
610
+
611
+ if args.validation_prompt is None and args.validation_image is not None:
612
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
613
+
614
+ if (
615
+ args.validation_image is not None
616
+ and args.validation_prompt is not None
617
+ and len(args.validation_image) != 1
618
+ and len(args.validation_prompt) != 1
619
+ and len(args.validation_image) != len(args.validation_prompt)
620
+ ):
621
+ raise ValueError(
622
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
623
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
624
+ )
625
+
626
+ if args.resolution % 8 != 0:
627
+ raise ValueError(
628
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
629
+ )
630
+
631
+ return args
632
+
633
+ def preprocess_train(examples, image_column, conditioning_image_column, image_transforms, conditioning_image_transforms, tokenize_caption):
634
+ images = [image.convert("RGB") for image in examples[image_column]]
635
+ images = [image_transforms(image) for image in images]
636
+
637
+ conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
638
+ conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]
639
+
640
+ examples["pixel_values"] = images
641
+ examples["conditioning_pixel_values"] = conditioning_images
642
+ examples["input_ids"] = tokenize_caption
643
+
644
+ return examples
645
+
646
+ def make_train_dataset(args, tokenizer, accelerator):
647
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
648
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
649
+
650
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
651
+ # download the dataset.
652
+ if args.dataset_name is not None:
653
+ # Downloading and loading a dataset from the hub.
654
+ dataset = load_dataset(
655
+ args.dataset_name,
656
+ args.dataset_config_name,
657
+ cache_dir=args.cache_dir,
658
+ )
659
+ else:
660
+ # Get train_data_dir's last folder name
661
+ if args.train_data_dir is not None:
662
+ # For optimal performance
663
+ dataset = load_dataset(
664
+ args.train_data_dir,
665
+ cache_dir=args.cache_dir,
666
+ num_proc=args.dataset_num_workers,
667
+ # streaming=True,
668
+ trust_remote_code=True,
669
+ )
670
+
671
+ # See more about loading custom images at
672
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
673
+
674
+ # Preprocessing the datasets.
675
+ # We need to tokenize inputs and targets.
676
+ column_names = dataset["train"].column_names
677
+
678
+ # 6. Get the column names for input/target.
679
+ if args.image_column is None:
680
+ image_column = column_names[0]
681
+ logger.info(f"image column defaulting to {image_column}")
682
+ else:
683
+ image_column = args.image_column
684
+ if image_column not in column_names:
685
+ raise ValueError(
686
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
687
+ )
688
+
689
+ if args.caption_column is None:
690
+ caption_column = column_names[1]
691
+ logger.info(f"caption column defaulting to {caption_column}")
692
+ else:
693
+ caption_column = args.caption_column
694
+ if caption_column not in column_names:
695
+ raise ValueError(
696
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
697
+ )
698
+
699
+ if args.conditioning_image_column is None:
700
+ conditioning_image_column = column_names[2]
701
+ logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
702
+ else:
703
+ conditioning_image_column = args.conditioning_image_column
704
+ if conditioning_image_column not in column_names:
705
+ raise ValueError(
706
+ f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
707
+ )
708
+
709
+ def tokenize_captions(examples, is_train=True):
710
+ captions = []
711
+ for caption in examples[caption_column]:
712
+ if random.random() < args.proportion_empty_prompts:
713
+ captions.append("")
714
+ elif isinstance(caption, str):
715
+ captions.append(caption)
716
+ elif isinstance(caption, (list, np.ndarray)):
717
+ # take a random caption if there are multiple
718
+ captions.append(random.choice(caption) if is_train else caption[0])
719
+ else:
720
+ raise ValueError(
721
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
722
+ )
723
+ inputs = tokenizer(
724
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
725
+ )
726
+ return inputs.input_ids
727
+
728
+ # Tokenize captions
729
+ # tokenize_caption = tokenizer(
730
+ # [""], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
731
+ # ).input_ids
732
+
733
+ image_transforms = transforms.Compose(
734
+ [
735
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
736
+ transforms.CenterCrop(args.resolution),
737
+ transforms.ToTensor(),
738
+ transforms.Normalize([0.5], [0.5]),
739
+ ]
740
+ )
741
+
742
+ conditioning_image_transforms = transforms.Compose(
743
+ [
744
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
745
+ transforms.CenterCrop(args.resolution),
746
+ transforms.ToTensor(),
747
+ ]
748
+ )
749
+
750
+ batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
751
+ tokenize_caption = tokenizer(
752
+ [""] * batch_size, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
753
+ ).input_ids
754
+
755
+ with accelerator.main_process_first():
756
+ if args.max_train_samples is not None:
757
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
758
+ # Set the training transforms
759
+ logger.info("Applying preprocessing to the training dataset...")
760
+ train_dataset = dataset["train"].with_transform(partial(preprocess_train, image_column=image_column, conditioning_image_column=conditioning_image_column, image_transforms=image_transforms, conditioning_image_transforms=conditioning_image_transforms, tokenize_caption=tokenize_caption))
761
+ logger.info("Preprocessing applied to the training dataset.")
762
+
763
+ return train_dataset
764
+
765
+
766
+ def collate_fn(examples):
767
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
768
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
769
+
770
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
771
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
772
+
773
+ input_ids = torch.stack([example["input_ids"] for example in examples])
774
+
775
+ return {
776
+ "pixel_values": pixel_values,
777
+ "conditioning_pixel_values": conditioning_pixel_values,
778
+ "input_ids": input_ids,
779
+ }
780
+
781
+
782
+ def main(args):
783
+ if args.report_to == "wandb" and args.hub_token is not None:
784
+ raise ValueError(
785
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
786
+ " Please use `huggingface-cli login` to authenticate with the Hub."
787
+ )
788
+ os.environ["WANDB__SERVICE_WAIT"] = "300"
789
+
790
+ logging_dir = Path(args.output_dir, args.logging_dir)
791
+
792
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
793
+
794
+ accelerator = Accelerator(
795
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
796
+ mixed_precision=args.mixed_precision,
797
+ log_with=args.report_to,
798
+ project_config=accelerator_project_config,
799
+ )
800
+
801
+ # Disable AMP for MPS.
802
+ if torch.backends.mps.is_available():
803
+ accelerator.native_amp = False
804
+
805
+ # Make one log on every process with the configuration for debugging.
806
+ logging.basicConfig(
807
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
808
+ datefmt="%m/%d/%Y %H:%M:%S",
809
+ level=logging.INFO,
810
+ )
811
+ logger.info(accelerator.state, main_process_only=False)
812
+ logger.info(f"Training/evaluation parameters {args}")
813
+ if accelerator.is_local_main_process:
814
+ transformers.utils.logging.set_verbosity_warning()
815
+ diffusers.utils.logging.set_verbosity_info()
816
+ else:
817
+ transformers.utils.logging.set_verbosity_error()
818
+ diffusers.utils.logging.set_verbosity_error()
819
+
820
+ # If passed along, set the training seed now.
821
+ if args.seed is not None:
822
+ set_seed(args.seed)
823
+
824
+ # Handle the repository creation
825
+ if accelerator.is_main_process:
826
+ if args.output_dir is not None:
827
+ os.makedirs(args.output_dir, exist_ok=True)
828
+
829
+ if args.push_to_hub:
830
+ repo_id = create_repo(
831
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
832
+ ).repo_id
833
+
834
+ # Load the tokenizer
835
+ if args.tokenizer_name:
836
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
837
+ elif args.pretrained_model_name_or_path:
838
+ tokenizer = AutoTokenizer.from_pretrained(
839
+ args.pretrained_model_name_or_path,
840
+ subfolder="tokenizer",
841
+ revision=args.revision,
842
+ use_fast=False,
843
+ )
844
+
845
+ # import correct text encoder class
846
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
847
+
848
+ # Load scheduler and models
849
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
850
+ text_encoder = text_encoder_cls.from_pretrained(
851
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
852
+ )
853
+ vae = AutoencoderKL.from_pretrained(
854
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
855
+ )
856
+ unet = UNet2DConditionModel.from_pretrained(
857
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
858
+ )
859
+
860
+ if args.controlnet_model_name_or_path:
861
+ logger.info("Loading existing controlnet weights")
862
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
863
+ else:
864
+ logger.info("Initializing controlnet weights from unet")
865
+ controlnet = ControlNetModel.from_unet(unet)
866
+
867
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
868
+ def unwrap_model(model):
869
+ model = accelerator.unwrap_model(model)
870
+ model = model._orig_mod if is_compiled_module(model) else model
871
+ return model
872
+
873
+ # `accelerate` 0.16.0 will have better support for customized saving
874
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
875
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
876
+ def save_model_hook(models, weights, output_dir):
877
+ if accelerator.is_main_process:
878
+ i = len(weights) - 1
879
+
880
+ while len(weights) > 0:
881
+ weights.pop()
882
+ model = models[i]
883
+
884
+ sub_dir = "controlnet"
885
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
886
+
887
+ i -= 1
888
+
889
+ def load_model_hook(models, input_dir):
890
+ while len(models) > 0:
891
+ # pop models so that they are not loaded again
892
+ model = models.pop()
893
+
894
+ # load diffusers style into model
895
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
896
+ model.register_to_config(**load_model.config)
897
+
898
+ model.load_state_dict(load_model.state_dict())
899
+ del load_model
900
+
901
+ accelerator.register_save_state_pre_hook(save_model_hook)
902
+ accelerator.register_load_state_pre_hook(load_model_hook)
903
+
904
+ vae.requires_grad_(False)
905
+ unet.requires_grad_(False)
906
+ text_encoder.requires_grad_(False)
907
+ controlnet.train()
908
+
909
+ if args.enable_xformers_memory_efficient_attention:
910
+ if is_xformers_available():
911
+ import xformers
912
+
913
+ xformers_version = version.parse(xformers.__version__)
914
+ if xformers_version == version.parse("0.0.16"):
915
+ logger.warning(
916
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
917
+ )
918
+ unet.enable_xformers_memory_efficient_attention()
919
+ controlnet.enable_xformers_memory_efficient_attention()
920
+ else:
921
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
922
+
923
+ if args.gradient_checkpointing:
924
+ controlnet.enable_gradient_checkpointing()
925
+
926
+ # Check that all trainable models are in full precision
927
+ low_precision_error_string = (
928
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
929
+ " doing mixed precision training, copy of the weights should still be float32."
930
+ )
931
+
932
+ if unwrap_model(controlnet).dtype != torch.float32:
933
+ raise ValueError(
934
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
935
+ )
936
+
937
+ # Enable TF32 for faster training on Ampere GPUs,
938
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
939
+ if args.allow_tf32:
940
+ torch.backends.cuda.matmul.allow_tf32 = True
941
+
942
+ if args.scale_lr:
943
+ args.learning_rate = (
944
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
945
+ )
946
+
947
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
948
+ if args.use_8bit_adam:
949
+ try:
950
+ import bitsandbytes as bnb
951
+ except ImportError:
952
+ raise ImportError(
953
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
954
+ )
955
+
956
+ optimizer_class = bnb.optim.AdamW8bit
957
+ else:
958
+ optimizer_class = torch.optim.AdamW
959
+
960
+ # Optimizer creation
961
+ params_to_optimize = controlnet.parameters()
962
+ optimizer = optimizer_class(
963
+ params_to_optimize,
964
+ lr=args.learning_rate,
965
+ betas=(args.adam_beta1, args.adam_beta2),
966
+ weight_decay=args.adam_weight_decay,
967
+ eps=args.adam_epsilon,
968
+ )
969
+
970
+ logger.info("Loading the training dataset")
971
+ train_dataset = make_train_dataset(args, tokenizer, accelerator)
972
+
973
+ logger.info("Creating the training dataloader")
974
+ train_dataloader = torch.utils.data.DataLoader(
975
+ train_dataset,
976
+ shuffle=False,
977
+ collate_fn=collate_fn,
978
+ batch_size=args.train_batch_size,
979
+ num_workers=args.dataloader_num_workers,
980
+ )
981
+
982
+ # Scheduler and math around the number of training steps.
983
+ overrode_max_train_steps = False
984
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
985
+ if args.max_train_steps is None:
986
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
987
+ overrode_max_train_steps = True
988
+
989
+ lr_scheduler = get_scheduler(
990
+ args.lr_scheduler,
991
+ optimizer=optimizer,
992
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
993
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
994
+ num_cycles=args.lr_num_cycles,
995
+ power=args.lr_power,
996
+ )
997
+
998
+ # Prepare everything with our `accelerator`.
999
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1000
+ controlnet, optimizer, train_dataloader, lr_scheduler
1001
+ )
1002
+
1003
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
1004
+ # as these models are only used for inference, keeping weights in full precision is not required.
1005
+ weight_dtype = torch.float32
1006
+ if accelerator.mixed_precision == "fp16":
1007
+ weight_dtype = torch.float16
1008
+ elif accelerator.mixed_precision == "bf16":
1009
+ weight_dtype = torch.bfloat16
1010
+
1011
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
1012
+ vae.to(accelerator.device, dtype=weight_dtype)
1013
+ unet.to(accelerator.device, dtype=weight_dtype)
1014
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1015
+
1016
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1017
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1018
+ if overrode_max_train_steps:
1019
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1020
+ # Afterwards we recalculate our number of training epochs
1021
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1022
+
1023
+ # We need to initialize the trackers we use, and also store our configuration.
1024
+ # The trackers initialize automatically on the main process.
1025
+ if accelerator.is_main_process:
1026
+ tracker_config = dict(vars(args))
1027
+
1028
+ # tensorboard cannot handle list types for config
1029
+ tracker_config.pop("validation_prompt")
1030
+ tracker_config.pop("validation_image")
1031
+
1032
+ logger.info(f"Init trackers: {args.tracker_project_name}")
1033
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1034
+ # if args.tracker_run_name is not None:
1035
+ # accelerator.trackers[-1].run.name = args.tracker_run_name
1036
+
1037
+ # Train!
1038
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1039
+
1040
+ logger.info("***** Running training *****")
1041
+ logger.info(f" Num examples = {len(train_dataset)}")
1042
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1043
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1044
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1045
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1046
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1047
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1048
+ global_step = 0
1049
+ first_epoch = 0
1050
+
1051
+ # Potentially load in the weights and states from a previous save
1052
+ if args.resume_from_checkpoint:
1053
+ if args.resume_from_checkpoint != "latest":
1054
+ path = os.path.basename(args.resume_from_checkpoint)
1055
+ else:
1056
+ # Get the most recent checkpoint
1057
+ dirs = os.listdir(args.output_dir)
1058
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1059
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1060
+ path = dirs[-1] if len(dirs) > 0 else None
1061
+
1062
+ if path is None:
1063
+ accelerator.print(
1064
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1065
+ )
1066
+ args.resume_from_checkpoint = None
1067
+ initial_global_step = 0
1068
+ else:
1069
+ accelerator.print(f"Resuming from checkpoint {path}")
1070
+ accelerator.load_state(os.path.join(args.output_dir, path))
1071
+ global_step = int(path.split("-")[1])
1072
+
1073
+ initial_global_step = global_step
1074
+ first_epoch = global_step // num_update_steps_per_epoch
1075
+ else:
1076
+ initial_global_step = 0
1077
+
1078
+ progress_bar = tqdm(
1079
+ range(0, args.max_train_steps),
1080
+ initial=initial_global_step,
1081
+ desc="Steps",
1082
+ # Only show the progress bar once on each machine.
1083
+ disable=not accelerator.is_local_main_process,
1084
+ )
1085
+
1086
+ image_logs = None
1087
+ for epoch in range(first_epoch, args.num_train_epochs):
1088
+ for step, batch in enumerate(train_dataloader):
1089
+ with accelerator.accumulate(controlnet):
1090
+ # Convert images to latent space
1091
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1092
+ latents = latents * vae.config.scaling_factor
1093
+
1094
+ # Sample noise that we'll add to the latents
1095
+ noise = torch.randn_like(latents)
1096
+ bsz = latents.shape[0]
1097
+ # Sample a random timestep for each image
1098
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1099
+ timesteps = timesteps.long()
1100
+
1101
+ # Add noise to the latents according to the noise magnitude at each timestep
1102
+ # (this is the forward diffusion process)
1103
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
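+ # For DDPMScheduler, add_noise implements the standard forward process:
+ # noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise,
+ # where alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t.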
1104
+
1105
+ # Get the text embedding for conditioning
1106
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
1107
+
1108
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1109
+
1110
+ down_block_res_samples, mid_block_res_sample = controlnet(
1111
+ noisy_latents,
1112
+ timesteps,
1113
+ encoder_hidden_states=encoder_hidden_states,
1114
+ controlnet_cond=controlnet_image,
1115
+ return_dict=False,
1116
+ )
1117
+
1118
+ # Predict the noise residual
1119
+ model_pred = unet(
1120
+ noisy_latents,
1121
+ timesteps,
1122
+ encoder_hidden_states=encoder_hidden_states,
1123
+ down_block_additional_residuals=[
1124
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1125
+ ],
1126
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1127
+ return_dict=False,
1128
+ )[0]
1129
+
1130
+ # Get the target for loss depending on the prediction type
1131
+ if noise_scheduler.config.prediction_type == "epsilon":
1132
+ target = noise
1133
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1134
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1135
+ else:
1136
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1137
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1138
+
1139
+ accelerator.backward(loss)
1140
+ if accelerator.sync_gradients:
1141
+ params_to_clip = controlnet.parameters()
1142
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1143
+ optimizer.step()
1144
+ lr_scheduler.step()
1145
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1146
+
1147
+ # Checks if the accelerator has performed an optimization step behind the scenes
1148
+ if accelerator.sync_gradients:
1149
+ progress_bar.update(1)
1150
+ global_step += 1
1151
+
1152
+ if accelerator.is_main_process:
1153
+ if global_step % args.checkpointing_steps == 0:
1154
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1155
+ if args.checkpoints_total_limit is not None:
1156
+ checkpoints = os.listdir(args.output_dir)
1157
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1158
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1159
+
1160
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1161
+ if len(checkpoints) >= args.checkpoints_total_limit:
1162
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1163
+ removing_checkpoints = checkpoints[0:num_to_remove]
1164
+
1165
+ logger.info(
1166
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1167
+ )
1168
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1169
+
1170
+ for removing_checkpoint in removing_checkpoints:
1171
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1172
+ shutil.rmtree(removing_checkpoint)
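+ # Example: with checkpoints_total_limit=2 and existing checkpoint-500,
+ # checkpoint-1000 and checkpoint-1500, the two oldest (checkpoint-500 and
+ # checkpoint-1000) are removed here before the new state is saved below.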
1173
+
1174
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1175
+ accelerator.save_state(save_path)
1176
+ logger.info(f"Saved state to {save_path}")
1177
+
1178
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1179
+ image_logs = log_validation(
1180
+ vae,
1181
+ text_encoder,
1182
+ tokenizer,
1183
+ unet,
1184
+ controlnet,
1185
+ args,
1186
+ accelerator,
1187
+ weight_dtype,
1188
+ global_step,
1189
+ )
1190
+
1191
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1192
+ progress_bar.set_postfix(**logs)
1193
+ accelerator.log(logs, step=global_step)
1194
+
1195
+ if global_step >= args.max_train_steps:
1196
+ break
1197
+
1198
+ # Create the pipeline using the trained modules and save it.
1199
+ accelerator.wait_for_everyone()
1200
+ if accelerator.is_main_process:
1201
+ controlnet = unwrap_model(controlnet)
1202
+ controlnet.save_pretrained(args.output_dir)
1203
+
1204
+ # Run a final round of validation.
1205
+ image_logs = None
1206
+ if args.validation_prompt is not None:
1207
+ image_logs = log_validation(
1208
+ vae=vae,
1209
+ text_encoder=text_encoder,
1210
+ tokenizer=tokenizer,
1211
+ unet=unet,
1212
+ controlnet=None,
1213
+ args=args,
1214
+ accelerator=accelerator,
1215
+ weight_dtype=weight_dtype,
1216
+ step=global_step,
1217
+ is_final_validation=True,
1218
+ )
1219
+
1220
+ if args.push_to_hub:
1221
+ save_model_card(
1222
+ repo_id,
1223
+ image_logs=image_logs,
1224
+ base_model=args.pretrained_model_name_or_path,
1225
+ repo_folder=args.output_dir,
1226
+ )
1227
+ upload_folder(
1228
+ repo_id=repo_id,
1229
+ folder_path=args.output_dir,
1230
+ commit_message="End of training",
1231
+ ignore_patterns=["step_*", "epoch_*"],
1232
+ )
1233
+
1234
+ accelerator.end_training()
1235
+
1236
+
1237
+ if __name__ == "__main__":
1238
+ args = parse_args()
1239
+ main(args)
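The `controlnet` directory written by `save_pretrained(args.output_dir)` above can be reloaded for inference. Below is a minimal, illustrative sketch that mirrors how these training scripts build a `StableDiffusionControlNetPipeline` for validation; the output directory path, the base model id, the conditioning image path, and the prompt are placeholders rather than values taken from this commit.

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

# Load the ControlNet weights saved by the training script (placeholder path).
controlnet = ControlNetModel.from_pretrained("controlnet-model", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    controlnet=controlnet,
    safety_checker=None,
    torch_dtype=torch.float16,
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Placeholder conditioning image and prompt.
condition = load_image("conditioning.png")
image = pipe("a detailed color photograph", condition, num_inference_steps=20).images[0]
image.save("result.png")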
train_multi_open.py ADDED
@@ -0,0 +1,1192 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import argparse
17
+ import contextlib
18
+ import gc
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import accelerate
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ import transformers
32
+ from accelerate import Accelerator
33
+ from accelerate.logging import get_logger
34
+ from accelerate.utils import ProjectConfiguration, set_seed
35
+ from datasets import load_dataset, Features, Value, Image, IterableDataset
36
+ from datasets.iterable_dataset import ExamplesIterable
37
+ from torch.utils.data import get_worker_info
38
+ from multiprocessing import Pool, Queue, cpu_count
39
+ from threading import Thread
40
+ from itertools import cycle
41
+ from huggingface_hub import create_repo, upload_folder
42
+ from packaging import version
43
+ from PIL import Image
44
+ from torchvision import transforms
45
+ from tqdm.auto import tqdm
46
+ from transformers import AutoTokenizer, PretrainedConfig
47
+
48
+ import diffusers
49
+ from diffusers import (
50
+ AutoencoderKL,
51
+ ControlNetModel,
52
+ DDPMScheduler,
53
+ StableDiffusionControlNetPipeline,
54
+ UNet2DConditionModel,
55
+ UniPCMultistepScheduler,
56
+ )
57
+ from diffusers.optimization import get_scheduler
58
+ from diffusers.utils import check_min_version, is_wandb_available
59
+ from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
60
+ from diffusers.utils.import_utils import is_xformers_available
61
+ from diffusers.utils.torch_utils import is_compiled_module
62
+
63
+
64
+ if is_wandb_available():
65
+ import wandb
66
+
67
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
68
+ check_min_version("0.30.0.dev0")
69
+
70
+ logger = get_logger(__name__)
71
+
72
+
73
+ def image_grid(imgs, rows, cols):
74
+ assert len(imgs) == rows * cols
75
+
76
+ w, h = imgs[0].size
77
+ grid = Image.new("RGB", size=(cols * w, rows * h))
78
+
79
+ for i, img in enumerate(imgs):
80
+ grid.paste(img, box=(i % cols * w, i // cols * h))
81
+ return grid
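+ # Example: image_grid([img_a, img_b, img_c, img_d], rows=2, cols=2) pastes four
+ # equally sized PIL images into a single 2x2 composite.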
82
+
83
+
84
+ def log_validation(
85
+ vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
86
+ ):
87
+ logger.info("Running validation... ")
88
+
89
+ if not is_final_validation:
90
+ controlnet = accelerator.unwrap_model(controlnet)
91
+ else:
92
+ controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
93
+
94
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
95
+ args.pretrained_model_name_or_path,
96
+ vae=vae,
97
+ text_encoder=text_encoder,
98
+ tokenizer=tokenizer,
99
+ unet=unet,
100
+ controlnet=controlnet,
101
+ safety_checker=None,
102
+ revision=args.revision,
103
+ variant=args.variant,
104
+ torch_dtype=weight_dtype,
105
+ )
106
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
107
+ pipeline = pipeline.to(accelerator.device)
108
+ pipeline.set_progress_bar_config(disable=True)
109
+
110
+ if args.enable_xformers_memory_efficient_attention:
111
+ pipeline.enable_xformers_memory_efficient_attention()
112
+
113
+ if args.seed is None:
114
+ generator = None
115
+ else:
116
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
117
+
118
+ # if args.validation_image is folder, get all images in the folder
119
+ if len(args.validation_image) == 1 and os.path.isdir(args.validation_image[0]):
120
+ logger.info(f"Loading images from {args.validation_image[0]}")
121
+ dir_path = args.validation_image[0]
122
+ validation_images = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
123
+ logger.info(f"Found {len(validation_images)} images")
124
+ else:
125
+ validation_images = args.validation_image
126
+
127
+
128
+ if len(validation_images) == len(args.validation_prompt):
129
+ validation_prompts = args.validation_prompt
130
+ elif len(validation_images) == 1:
131
+ validation_images = validation_images * len(args.validation_prompt)
132
+ validation_prompts = args.validation_prompt
133
+ elif len(args.validation_prompt) == 1:
134
+ validation_prompts = args.validation_prompt * len(validation_images)
135
+ else:
136
+ raise ValueError(
137
+ "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
138
+ )
139
+
140
+ image_logs = []
141
+ inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
142
+
143
+ for validation_prompt, validation_image in zip(validation_prompts, validation_images):
144
+ validation_image = Image.open(validation_image).convert("RGB")
145
+ # Resize
146
+ validation_image = transforms.Resize(args.resolution)(validation_image)
147
+
148
+ images = []
149
+
150
+ for _ in range(args.num_validation_images):
151
+ with inference_ctx:
152
+ image = pipeline(
153
+ validation_prompt, validation_image, num_inference_steps=20, generator=generator
154
+ ).images[0]
155
+
156
+ images.append(image)
157
+
158
+ image_logs.append(
159
+ {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
160
+ )
161
+
162
+ tracker_key = "test" if is_final_validation else "validation"
163
+ for tracker in accelerator.trackers:
164
+ if tracker.name == "tensorboard":
165
+ for log in image_logs:
166
+ images = log["images"]
167
+ validation_prompt = log["validation_prompt"]
168
+ validation_image = log["validation_image"]
169
+
170
+ formatted_images = []
171
+
172
+ formatted_images.append(np.asarray(validation_image))
173
+
174
+ for image in images:
175
+ formatted_images.append(np.asarray(image))
176
+
177
+ formatted_images = np.stack(formatted_images)
178
+
179
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
180
+ elif tracker.name == "wandb":
181
+ formatted_images = []
182
+
183
+ for log in image_logs:
184
+ images = log["images"]
185
+ validation_prompt = log["validation_prompt"]
186
+ validation_image = log["validation_image"]
187
+
188
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
189
+
190
+ for image in images:
191
+ image = wandb.Image(image, caption=validation_prompt)
192
+ formatted_images.append(image)
193
+
194
+ tracker.log({tracker_key: formatted_images})
195
+ else:
196
+ logger.warning(f"image logging not implemented for {tracker.name}")
197
+
198
+ del pipeline
199
+ gc.collect()
200
+ torch.cuda.empty_cache()
201
+
202
+ return image_logs
203
+
204
+
205
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
206
+ text_encoder_config = PretrainedConfig.from_pretrained(
207
+ pretrained_model_name_or_path,
208
+ subfolder="text_encoder",
209
+ revision=revision,
210
+ )
211
+ model_class = text_encoder_config.architectures[0]
212
+
213
+ if model_class == "CLIPTextModel":
214
+ from transformers import CLIPTextModel
215
+
216
+ return CLIPTextModel
217
+ elif model_class == "RobertaSeriesModelWithTransformation":
218
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
219
+
220
+ return RobertaSeriesModelWithTransformation
221
+ else:
222
+ raise ValueError(f"{model_class} is not supported.")
223
+
224
+
225
+ def save_model_card(repo_id: str, image_logs=None, base_model: str = None, repo_folder=None):
226
+ img_str = ""
227
+ if image_logs is not None:
228
+ img_str = "You can find some example images below.\n\n"
229
+ for i, log in enumerate(image_logs):
230
+ images = log["images"]
231
+ validation_prompt = log["validation_prompt"]
232
+ validation_image = log["validation_image"]
233
+ validation_image.save(os.path.join(repo_folder, "image_control.png"))
234
+ img_str += f"prompt: {validation_prompt}\n"
235
+ images = [validation_image] + images
236
+ image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
237
+ img_str += f"![images_{i})](./images_{i}.png)\n"
238
+
239
+ model_description = f"""
240
+ # controlnet-{repo_id}
241
+
242
+ These are controlnet weights trained on {base_model} with a new type of conditioning.
243
+ {img_str}
244
+ """
245
+ model_card = load_or_create_model_card(
246
+ repo_id_or_path=repo_id,
247
+ from_training=True,
248
+ license="creativeml-openrail-m",
249
+ base_model=base_model,
250
+ model_description=model_description,
251
+ inference=True,
252
+ )
253
+
254
+ tags = [
255
+ "stable-diffusion",
256
+ "stable-diffusion-diffusers",
257
+ "text-to-image",
258
+ "diffusers",
259
+ "controlnet",
260
+ "diffusers-training",
261
+ ]
262
+ model_card = populate_model_card(model_card, tags=tags)
263
+
264
+ model_card.save(os.path.join(repo_folder, "README.md"))
265
+
266
+
267
+ def parse_args(input_args=None):
268
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
269
+ parser.add_argument(
270
+ "--pretrained_model_name_or_path",
271
+ type=str,
272
+ default=None,
273
+ required=True,
274
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
275
+ )
276
+ parser.add_argument(
277
+ "--controlnet_model_name_or_path",
278
+ type=str,
279
+ default=None,
280
+ help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
281
+ " If not specified controlnet weights are initialized from unet.",
282
+ )
283
+ parser.add_argument(
284
+ "--revision",
285
+ type=str,
286
+ default=None,
287
+ required=False,
288
+ help="Revision of pretrained model identifier from huggingface.co/models.",
289
+ )
290
+ parser.add_argument(
291
+ "--variant",
292
+ type=str,
293
+ default=None,
294
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
295
+ )
296
+ parser.add_argument(
297
+ "--tokenizer_name",
298
+ type=str,
299
+ default=None,
300
+ help="Pretrained tokenizer name or path if not the same as model_name",
301
+ )
302
+ parser.add_argument(
303
+ "--output_dir",
304
+ type=str,
305
+ default="controlnet-model",
306
+ help="The output directory where the model predictions and checkpoints will be written.",
307
+ )
308
+ parser.add_argument(
309
+ "--cache_dir",
310
+ type=str,
311
+ default=None,
312
+ help="The directory where the downloaded models and datasets will be stored.",
313
+ )
314
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
315
+ parser.add_argument(
316
+ "--resolution",
317
+ type=int,
318
+ default=512,
319
+ help=(
320
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
321
+ " resolution"
322
+ ),
323
+ )
324
+ parser.add_argument(
325
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
326
+ )
327
+ parser.add_argument(
328
+ "--max_train_steps",
329
+ type=int,
330
+ default=None,
331
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
332
+ )
333
+ parser.add_argument(
334
+ "--checkpointing_steps",
335
+ type=int,
336
+ default=500,
337
+ help=(
338
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
339
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
340
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
341
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
342
+ "instructions."
343
+ ),
344
+ )
345
+ parser.add_argument(
346
+ "--checkpoints_total_limit",
347
+ type=int,
348
+ default=None,
349
+ help=("Max number of checkpoints to store."),
350
+ )
351
+ parser.add_argument(
352
+ "--resume_from_checkpoint",
353
+ type=str,
354
+ default=None,
355
+ help=(
356
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
357
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
358
+ ),
359
+ )
360
+ parser.add_argument(
361
+ "--gradient_accumulation_steps",
362
+ type=int,
363
+ default=1,
364
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
365
+ )
366
+ parser.add_argument(
367
+ "--gradient_checkpointing",
368
+ action="store_true",
369
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
370
+ )
371
+ parser.add_argument(
372
+ "--learning_rate",
373
+ type=float,
374
+ default=5e-6,
375
+ help="Initial learning rate (after the potential warmup period) to use.",
376
+ )
377
+ parser.add_argument(
378
+ "--scale_lr",
379
+ action="store_true",
380
+ default=False,
381
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
382
+ )
383
+ parser.add_argument(
384
+ "--lr_scheduler",
385
+ type=str,
386
+ default="constant",
387
+ help=(
388
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
389
+ ' "constant", "constant_with_warmup"]'
390
+ ),
391
+ )
392
+ parser.add_argument(
393
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
394
+ )
395
+ parser.add_argument(
396
+ "--lr_num_cycles",
397
+ type=int,
398
+ default=1,
399
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
400
+ )
401
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
402
+ parser.add_argument(
403
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
404
+ )
405
+ parser.add_argument(
406
+ "--dataset_num_workers",
407
+ type=int,
408
+ default=1,
409
+ help=(
410
+ "Number of subprocesses to use for data loading."
411
+ ),
412
+ )
413
+ parser.add_argument(
414
+ "--dataloader_num_workers",
415
+ type=int,
416
+ default=0,
417
+ help=(
418
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
419
+ ),
420
+ )
421
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
422
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
423
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
424
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
425
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
426
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
427
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
428
+ parser.add_argument(
429
+ "--hub_model_id",
430
+ type=str,
431
+ default=None,
432
+ help="The name of the repository to keep in sync with the local `output_dir`.",
433
+ )
434
+ parser.add_argument(
435
+ "--logging_dir",
436
+ type=str,
437
+ default="logs",
438
+ help=(
439
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
440
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
441
+ ),
442
+ )
443
+ parser.add_argument(
444
+ "--allow_tf32",
445
+ action="store_true",
446
+ help=(
447
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
448
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
449
+ ),
450
+ )
451
+ parser.add_argument(
452
+ "--report_to",
453
+ type=str,
454
+ default="tensorboard",
455
+ help=(
456
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
457
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
458
+ ),
459
+ )
460
+ parser.add_argument(
461
+ "--mixed_precision",
462
+ type=str,
463
+ default=None,
464
+ choices=["no", "fp16", "bf16"],
465
+ help=(
466
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
467
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
468
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
469
+ ),
470
+ )
471
+ parser.add_argument(
472
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
473
+ )
474
+ parser.add_argument(
475
+ "--set_grads_to_none",
476
+ action="store_true",
477
+ help=(
478
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
479
+ " behaviors, so disable this argument if it causes any problems. More info:"
480
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--dataset_name",
485
+ type=str,
486
+ default=None,
487
+ help=(
488
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
489
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
490
+ " or to a folder containing files that 🤗 Datasets can understand."
491
+ ),
492
+ )
493
+ parser.add_argument(
494
+ "--dataset_config_name",
495
+ type=str,
496
+ default=None,
497
+ help="The config of the Dataset, leave as None if there's only one config.",
498
+ )
499
+ parser.add_argument(
500
+ "--train_data_dir",
501
+ type=str,
502
+ default=None,
503
+ help=(
504
+ "A folder containing the training data. Folder contents must follow the structure described in"
505
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
506
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
507
+ ),
508
+ )
509
+ parser.add_argument(
510
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
511
+ )
512
+ parser.add_argument(
513
+ "--conditioning_image_column",
514
+ type=str,
515
+ default="conditioning_image",
516
+ help="The column of the dataset containing the controlnet conditioning image.",
517
+ )
518
+ parser.add_argument(
519
+ "--caption_column",
520
+ type=str,
521
+ default="text",
522
+ help="The column of the dataset containing a caption or a list of captions.",
523
+ )
524
+ parser.add_argument(
525
+ "--proportion_empty_prompts",
526
+ type=float,
527
+ default=0,
528
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
529
+ )
530
+ parser.add_argument(
531
+ "--validation_prompt",
532
+ type=str,
533
+ default=None,
534
+ nargs="+",
535
+ help=(
536
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
537
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
538
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
539
+ ),
540
+ )
541
+ parser.add_argument(
542
+ "--validation_image",
543
+ type=str,
544
+ default=None,
545
+ nargs="+",
546
+ help=(
547
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
548
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
549
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
550
+ " `--validation_image` that will be used with all `--validation_prompt`s."
551
+ ),
552
+ )
553
+ parser.add_argument(
554
+ "--num_validation_images",
555
+ type=int,
556
+ default=4,
557
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
558
+ )
559
+ parser.add_argument(
560
+ "--validation_steps",
561
+ type=int,
562
+ default=100,
563
+ help=(
564
+ "Run validation every X steps. Validation consists of running the prompt"
565
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
566
+ " and logging the images."
567
+ ),
568
+ )
569
+ parser.add_argument(
570
+ "--tracker_project_name",
571
+ type=str,
572
+ default="train_controlnet",
573
+ help=(
574
+ "The `project_name` argument passed to Accelerator.init_trackers for"
575
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
576
+ ),
577
+ )
578
+ parser.add_argument(
579
+ "--tracker_run_name",
580
+ type=str,
581
+ default=None,
582
+ help=(
583
+ "The `run_name` argument passed to Accelerator.init_trackers for"
584
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
585
+ ),
586
+ )
587
+
588
+ if input_args is not None:
589
+ args = parser.parse_args(input_args)
590
+ else:
591
+ args = parser.parse_args()
592
+
593
+ if args.dataset_name is None and args.train_data_dir is None:
594
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
595
+
596
+ if args.dataset_name is not None and args.train_data_dir is not None:
597
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
598
+
599
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
600
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
601
+
602
+ if args.validation_prompt is not None and args.validation_image is None:
603
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
604
+
605
+ if args.validation_prompt is None and args.validation_image is not None:
606
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
607
+
608
+ if (
609
+ args.validation_image is not None
610
+ and args.validation_prompt is not None
611
+ and len(args.validation_image) != 1
612
+ and len(args.validation_prompt) != 1
613
+ and len(args.validation_image) != len(args.validation_prompt)
614
+ ):
615
+ raise ValueError(
616
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
617
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
618
+ )
619
+
620
+ if args.resolution % 8 != 0:
621
+ raise ValueError(
622
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
623
+ )
624
+
625
+ return args
626
+
627
+ class PrefetchIterableDataset(IterableDataset):
628
+ def __init__(self, image_files, resizer, transforms, conditioning_transforms, tokenizer, prefetch_factor=16):
629
+ ex_iterable = ExamplesIterable(self._generate_examples, kwargs={
630
+ "image_files": image_files,
631
+ "resizer": resizer,
632
+ "transforms": transforms,
633
+ "conditioning_transforms": conditioning_transforms,
634
+ "tokenizer": tokenizer
635
+ })
636
+
637
+ # Call the parent's __init__ with the created ex_iterable
638
+ super(PrefetchIterableDataset, self).__init__(ex_iterable=ex_iterable)
639
+ self.image_files = image_files
640
+ self.resizer = resizer
641
+ self.transforms = transforms
642
+ self.conditioning_transforms = conditioning_transforms
643
+ self.tokenizer = tokenizer
644
+ self.prefetch_factor = prefetch_factor
645
+ self.queue = Queue(maxsize=prefetch_factor)
646
+ # Pre-computed input_ids (tokenized empty prompt)
647
+ self.empty_input_ids = tokenizer(
648
+ "",
649
+ max_length=tokenizer.model_max_length,
650
+ padding="max_length",
651
+ truncation=True,
652
+ return_tensors="pt"
653
+ ).input_ids[0]
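+ # Every sample reuses this tokenized empty prompt, so the text conditioning is a
+ # constant (unconditional) embedding and only the conditioning image varies
+ # between training examples.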
654
+
655
+ def _generate_examples(self, image_files, resizer, transforms, conditioning_transforms, tokenizer):
656
+ for image_file in image_files:
657
+ yield self.preprocess_image(image_file)
658
+
659
+ def preprocess_image(self, image_file):
660
+ image_path = os.path.join(args.train_data_dir, image_file)
661
+ image = Image.open(image_path).convert("RGB")
662
+ image = self.resizer(image)
663
+ return {
664
+ "pixel_values": self.transforms(image),
665
+ "conditioning_pixel_values": self.conditioning_transforms(image),
666
+ "input_ids": self.empty_input_ids,
667
+ }
668
+
669
+ def producer(self):
670
+ # Note: this pool-based producer is not used by `__iter__`, which starts a
+ # per-worker prefetch thread instead; the pool size falls back to cpu_count()
+ # since the class defines no worker-count attribute.
+ with Pool(cpu_count()) as p:
671
+ for item in p.imap(self.preprocess_image, cycle(self.image_files)):
672
+ self.queue.put(item)
673
+
674
+ def __iter__(self):
675
+ worker_info = get_worker_info()
676
+ if worker_info is None: # single-process data loading
677
+ iter_start = 0
678
+ iter_end = len(self.image_files)
679
+ else: # in a worker process
680
+ per_worker = int(math.ceil(len(self.image_files) / float(worker_info.num_workers)))
681
+ worker_id = worker_info.id
682
+ iter_start = worker_id * per_worker
683
+ iter_end = min(iter_start + per_worker, len(self.image_files))
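+ # Example: with 10_000 image files and 4 dataloader workers, per_worker is
+ # ceil(10_000 / 4) = 2_500, so worker 0 cycles over files [0, 2500), worker 1
+ # over [2500, 5000), and so on.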
684
+
685
+ # Image files assigned to this worker
686
+ worker_image_files = self.image_files[iter_start:iter_end]
687
+
688
+ # Create a queue for prefetching
689
+ queue = Queue(maxsize=self.prefetch_factor)
690
+
691
+ # Define the producer function
692
+ def producer():
693
+ for image_file in cycle(worker_image_files):  # loop over the files indefinitely using cycle
694
+ item = self.preprocess_image(image_file)
695
+ queue.put(item)
696
+
697
+ # Start the producer thread
698
+ thread = Thread(target=producer)
699
+ thread.daemon = True
700
+ thread.start()
701
+
702
+ # Yield items from the queue
703
+ while True:
704
+ yield queue.get()
705
+
706
+
707
+ def collate_fn(examples):
708
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
709
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
710
+
711
+ conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
712
+ conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
713
+
714
+ input_ids = torch.stack([example["input_ids"] for example in examples])
715
+
716
+ return {
717
+ "pixel_values": pixel_values,
718
+ "conditioning_pixel_values": conditioning_pixel_values,
719
+ "input_ids": input_ids,
720
+ }
721
+
722
+ def main(args):
723
+ if args.report_to == "wandb" and args.hub_token is not None:
724
+ raise ValueError(
725
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
726
+ " Please use `huggingface-cli login` to authenticate with the Hub."
727
+ )
728
+ os.environ["WANDB__SERVICE_WAIT"] = "300"
729
+
730
+ logging_dir = Path(args.output_dir, args.logging_dir)
731
+
732
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
733
+
734
+ accelerator = Accelerator(
735
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
736
+ mixed_precision=args.mixed_precision,
737
+ log_with=args.report_to,
738
+ project_config=accelerator_project_config,
739
+ )
740
+
741
+ # Disable AMP for MPS.
742
+ if torch.backends.mps.is_available():
743
+ accelerator.native_amp = False
744
+
745
+ # Make one log on every process with the configuration for debugging.
746
+ logging.basicConfig(
747
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
748
+ datefmt="%m/%d/%Y %H:%M:%S",
749
+ level=logging.INFO,
750
+ )
751
+ logger.info(accelerator.state, main_process_only=False)
752
+ logger.info(f"Training/evaluation parameters {args}")
753
+ if accelerator.is_local_main_process:
754
+ transformers.utils.logging.set_verbosity_warning()
755
+ diffusers.utils.logging.set_verbosity_info()
756
+ else:
757
+ transformers.utils.logging.set_verbosity_error()
758
+ diffusers.utils.logging.set_verbosity_error()
759
+
760
+ # If passed along, set the training seed now.
761
+ if args.seed is not None:
762
+ set_seed(args.seed)
763
+
764
+ # Handle the repository creation
765
+ if accelerator.is_main_process:
766
+ if args.output_dir is not None:
767
+ os.makedirs(args.output_dir, exist_ok=True)
768
+
769
+ if args.push_to_hub:
770
+ repo_id = create_repo(
771
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
772
+ ).repo_id
773
+
774
+ # Load the tokenizer
775
+ if args.tokenizer_name:
776
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
777
+ elif args.pretrained_model_name_or_path:
778
+ tokenizer = AutoTokenizer.from_pretrained(
779
+ args.pretrained_model_name_or_path,
780
+ subfolder="tokenizer",
781
+ revision=args.revision,
782
+ use_fast=False,
783
+ )
784
+
785
+ # import correct text encoder class
786
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
787
+
788
+ # Load scheduler and models
789
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
790
+ text_encoder = text_encoder_cls.from_pretrained(
791
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
792
+ )
793
+ vae = AutoencoderKL.from_pretrained(
794
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
795
+ )
796
+ unet = UNet2DConditionModel.from_pretrained(
797
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
798
+ )
799
+
800
+ if args.controlnet_model_name_or_path:
801
+ logger.info("Loading existing controlnet weights")
802
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
803
+ else:
804
+ logger.info("Initializing controlnet weights from unet")
805
+ controlnet = ControlNetModel.from_unet(unet)
806
+
807
+ # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
808
+ def unwrap_model(model):
809
+ model = accelerator.unwrap_model(model)
810
+ model = model._orig_mod if is_compiled_module(model) else model
811
+ return model
812
+
813
+ # `accelerate` 0.16.0 will have better support for customized saving
814
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
815
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
816
+ def save_model_hook(models, weights, output_dir):
817
+ if accelerator.is_main_process:
818
+ i = len(weights) - 1
819
+
820
+ while len(weights) > 0:
821
+ weights.pop()
822
+ model = models[i]
823
+
824
+ sub_dir = "controlnet"
825
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
826
+
827
+ i -= 1
828
+
829
+ def load_model_hook(models, input_dir):
830
+ while len(models) > 0:
831
+ # pop models so that they are not loaded again
832
+ model = models.pop()
833
+
834
+ # load diffusers style into model
835
+ load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet")
836
+ model.register_to_config(**load_model.config)
837
+
838
+ model.load_state_dict(load_model.state_dict())
839
+ del load_model
840
+
841
+ accelerator.register_save_state_pre_hook(save_model_hook)
842
+ accelerator.register_load_state_pre_hook(load_model_hook)
843
+
844
+ vae.requires_grad_(False)
845
+ unet.requires_grad_(False)
846
+ text_encoder.requires_grad_(False)
847
+ controlnet.train()
848
+
849
+ if args.enable_xformers_memory_efficient_attention:
850
+ if is_xformers_available():
851
+ import xformers
852
+
853
+ xformers_version = version.parse(xformers.__version__)
854
+ if xformers_version == version.parse("0.0.16"):
855
+ logger.warning(
856
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
857
+ )
858
+ unet.enable_xformers_memory_efficient_attention()
859
+ controlnet.enable_xformers_memory_efficient_attention()
860
+ else:
861
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
862
+
863
+ if args.gradient_checkpointing:
864
+ controlnet.enable_gradient_checkpointing()
865
+
866
+ # Check that all trainable models are in full precision
867
+ low_precision_error_string = (
868
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
869
+ " doing mixed precision training, copy of the weights should still be float32."
870
+ )
871
+
872
+ if unwrap_model(controlnet).dtype != torch.float32:
873
+ raise ValueError(
874
+ f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
875
+ )
876
+
877
+ # Enable TF32 for faster training on Ampere GPUs,
878
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
879
+ if args.allow_tf32:
880
+ torch.backends.cuda.matmul.allow_tf32 = True
881
+
882
+ if args.scale_lr:
883
+ args.learning_rate = (
884
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
885
+ )
886
+
887
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
888
+ if args.use_8bit_adam:
889
+ try:
890
+ import bitsandbytes as bnb
891
+ except ImportError:
892
+ raise ImportError(
893
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
894
+ )
895
+
896
+ optimizer_class = bnb.optim.AdamW8bit
897
+ else:
898
+ optimizer_class = torch.optim.AdamW
899
+
900
+ # Optimizer creation
901
+ params_to_optimize = controlnet.parameters()
902
+ optimizer = optimizer_class(
903
+ params_to_optimize,
904
+ lr=args.learning_rate,
905
+ betas=(args.adam_beta1, args.adam_beta2),
906
+ weight_decay=args.adam_weight_decay,
907
+ eps=args.adam_epsilon,
908
+ )
909
+
910
+ logger.info("Loading the training dataset")
911
+ image_files = [f for f in os.listdir(args.train_data_dir) if
912
+ f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))]
913
+
914
+ # Define image transforms
915
+ image_resizers = transforms.Compose([
916
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
917
+ transforms.CenterCrop(args.resolution),
918
+ ])
919
+ image_transforms = transforms.Compose([
920
+ transforms.ToTensor(),
921
+ transforms.Normalize([0.5], [0.5]),
922
+ ])
923
+ conditioning_image_transforms = transforms.Compose([
924
+ transforms.Grayscale(num_output_channels=3),
925
+ transforms.ToTensor(),
926
+ ])
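+ # The conditioning image is simply a 3-channel grayscale copy of the target
+ # image. After these transforms, pixel_values are normalized to [-1, 1] while
+ # conditioning_pixel_values remain in [0, 1].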
927
+
928
+ train_dataset = PrefetchIterableDataset(
929
+ image_files=image_files,
930
+ resizer=image_resizers,
931
+ transforms=image_transforms,
932
+ conditioning_transforms=conditioning_image_transforms,
933
+ tokenizer=tokenizer,
934
+ prefetch_factor=128,
935
+ )
936
+
937
+ logger.info("Creating the training dataloader")
938
+ train_dataloader = torch.utils.data.DataLoader(
939
+ train_dataset,
940
+ batch_size=args.train_batch_size,
941
+ collate_fn=collate_fn,
942
+ num_workers=32,
943
+ pin_memory=True
944
+ )
945
+
946
+ lr_scheduler = get_scheduler(
947
+ args.lr_scheduler,
948
+ optimizer=optimizer,
949
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
950
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
951
+ num_cycles=args.lr_num_cycles,
952
+ power=args.lr_power,
953
+ )
954
+
955
+ # Prepare everything with our `accelerator`.
956
+ controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
957
+ controlnet, optimizer, train_dataloader, lr_scheduler
958
+ )
959
+
960
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
961
+ # as these models are only used for inference, keeping weights in full precision is not required.
962
+ weight_dtype = torch.float32
963
+ if accelerator.mixed_precision == "fp16":
964
+ weight_dtype = torch.float16
965
+ elif accelerator.mixed_precision == "bf16":
966
+ weight_dtype = torch.bfloat16
967
+
968
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
969
+ vae.to(accelerator.device, dtype=weight_dtype)
970
+ unet.to(accelerator.device, dtype=weight_dtype)
971
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
972
+
973
+ # We need to initialize the trackers we use, and also store our configuration.
974
+ # The trackers initialize automatically on the main process.
975
+ if accelerator.is_main_process:
976
+ tracker_config = dict(vars(args))
977
+
978
+ # tensorboard cannot handle list types for config
979
+ tracker_config.pop("validation_prompt")
980
+ tracker_config.pop("validation_image")
981
+
982
+ logger.info(f"Init trackers: {args.tracker_project_name}")
983
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
984
+ # if args.tracker_run_name is not None:
985
+ # accelerator.trackers[-1].run.name = args.tracker_run_name
986
+
987
+ # Train!
988
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
989
+
990
+ logger.info("***** Running training *****")
991
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
992
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
993
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
994
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
995
+ global_step = 0
996
+ first_epoch = 0
997
+
998
+ # Potentially load in the weights and states from a previous save
999
+ if args.resume_from_checkpoint:
1000
+ if args.resume_from_checkpoint != "latest":
1001
+ path = os.path.basename(args.resume_from_checkpoint)
1002
+ else:
1003
+ # Get the most recent checkpoint
1004
+ dirs = os.listdir(args.output_dir)
1005
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1006
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1007
+ path = dirs[-1] if len(dirs) > 0 else None
1008
+
1009
+ if path is None:
1010
+ accelerator.print(
1011
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1012
+ )
1013
+ args.resume_from_checkpoint = None
1014
+ initial_global_step = 0
1015
+ else:
1016
+ accelerator.print(f"Resuming from checkpoint {path}")
1017
+ accelerator.load_state(os.path.join(args.output_dir, path))
1018
+ global_step = int(path.split("-")[1])
1019
+
1020
+ initial_global_step = global_step
1021
+ else:
1022
+ initial_global_step = 0
1023
+
1024
+ progress_bar = tqdm(
1025
+ range(0, args.max_train_steps),
1026
+ initial=initial_global_step,
1027
+ desc="Steps",
1028
+ # Only show the progress bar once on each machine.
1029
+ disable=not accelerator.is_local_main_process,
1030
+ )
1031
+
1032
+ image_logs = None
1033
+ while True:
1034
+ for batch in train_dataloader:
1035
+ try:
1036
+ with accelerator.accumulate(controlnet):
1037
+ # Convert images to latent space
1038
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1039
+ latents = latents * vae.config.scaling_factor
1040
+
1041
+ # Sample noise that we'll add to the latents
1042
+ noise = torch.randn_like(latents)
1043
+ bsz = latents.shape[0]
1044
+ # Sample a random timestep for each image
1045
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1046
+ timesteps = timesteps.long()
1047
+
1048
+ # Add noise to the latents according to the noise magnitude at each timestep
1049
+ # (this is the forward diffusion process)
1050
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1051
+
1052
+ # Get the text embedding for conditioning
1053
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
1054
+
1055
+ controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
1056
+
1057
+ down_block_res_samples, mid_block_res_sample = controlnet(
1058
+ noisy_latents,
1059
+ timesteps,
1060
+ encoder_hidden_states=encoder_hidden_states,
1061
+ controlnet_cond=controlnet_image,
1062
+ return_dict=False,
1063
+ )
1064
+
1065
+ # Predict the noise residual
1066
+ model_pred = unet(
1067
+ noisy_latents,
1068
+ timesteps,
1069
+ encoder_hidden_states=encoder_hidden_states,
1070
+ down_block_additional_residuals=[
1071
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1072
+ ],
1073
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1074
+ return_dict=False,
1075
+ )[0]
1076
+
1077
+ # Get the target for loss depending on the prediction type
1078
+ if noise_scheduler.config.prediction_type == "epsilon":
1079
+ target = noise
1080
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1081
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
1082
+ else:
1083
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1084
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1085
+
1086
+ accelerator.backward(loss)
1087
+ if accelerator.sync_gradients:
1088
+ params_to_clip = controlnet.parameters()
1089
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1090
+ optimizer.step()
1091
+ lr_scheduler.step()
1092
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1093
+
1094
+ # Checks if the accelerator has performed an optimization step behind the scenes
1095
+ if accelerator.sync_gradients:
1096
+ progress_bar.update(1)
1097
+ global_step += 1
1098
+
1099
+ if accelerator.is_main_process:
1100
+ if global_step % args.checkpointing_steps == 0:
1101
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1102
+ if args.checkpoints_total_limit is not None:
1103
+ checkpoints = os.listdir(args.output_dir)
1104
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1105
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1106
+
1107
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1108
+ if len(checkpoints) >= args.checkpoints_total_limit:
1109
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1110
+ removing_checkpoints = checkpoints[0:num_to_remove]
1111
+
1112
+ logger.info(
1113
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1114
+ )
1115
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1116
+
1117
+ for removing_checkpoint in removing_checkpoints:
1118
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1119
+ shutil.rmtree(removing_checkpoint)
1120
+
1121
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1122
+ accelerator.save_state(save_path)
1123
+ logger.info(f"Saved state to {save_path}")
1124
+
1125
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1126
+ image_logs = log_validation(
1127
+ vae,
1128
+ text_encoder,
1129
+ tokenizer,
1130
+ unet,
1131
+ controlnet,
1132
+ args,
1133
+ accelerator,
1134
+ weight_dtype,
1135
+ global_step,
1136
+ )
1137
+
1138
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1139
+ progress_bar.set_postfix(**logs)
1140
+ accelerator.log(logs, step=global_step)
1141
+
1142
+ if global_step >= args.max_train_steps:
1143
+ break
1144
+ except Exception as e:
1145
+ logger.warning(f"Error processing batch: {str(e)}")
1146
+ continue
1147
+
1148
+ if global_step >= args.max_train_steps:
1149
+ break
1150
+
1151
+ # Create the pipeline using the trained modules and save it.
1152
+ accelerator.wait_for_everyone()
1153
+ if accelerator.is_main_process:
1154
+ controlnet = unwrap_model(controlnet)
1155
+ controlnet.save_pretrained(args.output_dir)
1156
+
1157
+ # Run a final round of validation.
1158
+ image_logs = None
1159
+ if args.validation_prompt is not None:
1160
+ image_logs = log_validation(
1161
+ vae=vae,
1162
+ text_encoder=text_encoder,
1163
+ tokenizer=tokenizer,
1164
+ unet=unet,
1165
+ controlnet=None,
1166
+ args=args,
1167
+ accelerator=accelerator,
1168
+ weight_dtype=weight_dtype,
1169
+ step=global_step,
1170
+ is_final_validation=True,
1171
+ )
1172
+
1173
+ if args.push_to_hub:
1174
+ save_model_card(
1175
+ repo_id,
1176
+ image_logs=image_logs,
1177
+ base_model=args.pretrained_model_name_or_path,
1178
+ repo_folder=args.output_dir,
1179
+ )
1180
+ upload_folder(
1181
+ repo_id=repo_id,
1182
+ folder_path=args.output_dir,
1183
+ commit_message="End of training",
1184
+ ignore_patterns=["step_*", "epoch_*"],
1185
+ )
1186
+
1187
+ accelerator.end_training()
1188
+
1189
+
1190
+ if __name__ == "__main__":
1191
+ args = parse_args()
1192
+ main(args)
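Note that `parse_args` above accepts an explicit argument list, and unlike `train_controlnet.py` this script never derives `max_train_steps` from an epoch count, so `--max_train_steps` must be supplied (otherwise the scheduler setup would multiply `None`). A minimal sketch of a programmatic invocation; the model id, data directory, output directory, and step counts below are placeholders.

# Hedged sketch: drive train_multi_open.py from Python instead of the CLI.
from train_multi_open import parse_args, main

args = parse_args([
    "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
    "--train_data_dir", "./images",       # placeholder folder of target images
    "--output_dir", "./controlnet-grayscale",
    "--resolution", "512",
    "--train_batch_size", "4",
    "--max_train_steps", "10000",         # required: no epoch-based fallback here
    "--checkpointing_steps", "500",
])
main(args)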