1lint commited on
Commit
7833185
1 Parent(s): f94c315

add textual inversion functionality

Browse files
Files changed (4) hide show
  1. .gitignore +3 -1
  2. app.py +126 -7
  3. inpaint_noise_issue.ipynb +0 -0
  4. textual_inversion.py +778 -0
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  __pycache__/
2
- WIP/
 
 
 
1
  __pycache__/
2
+ WIP/
3
+ concept_images/
4
+ output_model/
app.py CHANGED
@@ -5,7 +5,6 @@ from inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipeline
5
  from diffusers import (
6
  StableDiffusionPipeline,
7
  StableDiffusionImg2ImgPipeline,
8
- #StableDiffusionInpaintPipelineLegacy # uncomment this line to use original inpaint pipeline
9
  )
10
 
11
  import diffusers.schedulers
@@ -14,9 +13,29 @@ import torch
14
  import random
15
  from multiprocessing import cpu_count
16
  import json
 
 
 
 
 
17
 
18
  import importlib
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  _xformers_available = importlib.util.find_spec("xformers") is not None
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  low_vram_mode = False
@@ -56,8 +75,7 @@ def load_pipe(
56
  if model_id != loaded_model_id:
57
  pipe = pipe_class.from_pretrained(
58
  model_id,
59
- # torch_dtype=torch.float16,
60
- # revision='fp16',
61
  safety_checker=None,
62
  requires_safety_checker=False,
63
  scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
@@ -87,7 +105,8 @@ def load_pipe(
87
  pipe = None
88
  pipe = load_pipe(model_ids[0], default_scheduler)
89
 
90
-
 
91
  def generate(
92
  model_name,
93
  scheduler_name,
@@ -177,9 +196,75 @@ def generate(
177
  )
178
 
179
  else:
180
- None, f"Unhandled pipeline class: {pipe_class}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
- return result.images, status_message
183
 
184
 
185
  default_img_size = 512
@@ -260,6 +345,32 @@ with gr.Blocks(css="style.css") as demo:
260
  inpaint_options = ["preserve non-masked portions of image", "output entire inpainted image"]
261
  inpaint_radio = gr.Radio(inpaint_options, value=inpaint_options[0], show_label=False, interactive=True)
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  with gr.Row():
264
  batch_size = gr.Slider(
265
  label="Batch Size", value=1, minimum=1, maximum=8, step=1
@@ -325,10 +436,18 @@ with gr.Blocks(css="style.css") as demo:
325
  pipe_state,
326
  pipe_kwargs,
327
  ]
328
- outputs = [gallery, generation_details]
329
 
330
  prompt.submit(generate, inputs=inputs, outputs=outputs)
331
  generate_button.click(generate, inputs=inputs, outputs=outputs)
332
 
 
 
 
 
 
 
 
333
  demo.queue(concurrency_count=cpu_count())
 
334
  demo.launch()
 
5
  from diffusers import (
6
  StableDiffusionPipeline,
7
  StableDiffusionImg2ImgPipeline,
 
8
  )
9
 
10
  import diffusers.schedulers
 
13
  import random
14
  from multiprocessing import cpu_count
15
  import json
16
+ from PIL import Image
17
+ import os
18
+ import argparse
19
+ import shutil
20
+ import gc
21
 
22
  import importlib
23
 
24
+ from textual_inversion import main as run_textual_inversion
25
+
26
+ def pad_image(image):
27
+ w, h = image.size
28
+ if w == h:
29
+ return image
30
+ elif w > h:
31
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
32
+ new_image.paste(image, (0, (w - h) // 2))
33
+ return new_image
34
+ else:
35
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
36
+ new_image.paste(image, ((h - w) // 2, 0))
37
+ return new_image
38
+
39
  _xformers_available = importlib.util.find_spec("xformers") is not None
40
  device = "cuda" if torch.cuda.is_available() else "cpu"
41
  low_vram_mode = False
 
75
  if model_id != loaded_model_id:
76
  pipe = pipe_class.from_pretrained(
77
  model_id,
78
+ torch_dtype=torch.float16,
 
79
  safety_checker=None,
80
  requires_safety_checker=False,
81
  scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
 
105
  pipe = None
106
  pipe = load_pipe(model_ids[0], default_scheduler)
107
 
108
+ @torch.autocast(device)
109
+ @torch.no_grad()
110
  def generate(
111
  model_name,
112
  scheduler_name,
 
196
  )
197
 
198
  else:
199
+ return None, f"Unhandled pipeline class: {pipe_class}", -1
200
+
201
+ return result.images, status_message, seed
202
+
203
+
204
+ # based on lvkaokao/textual-inversion-training
205
+ def train_textual_inversion(model_name, scheduler_name, type_of_thing, files, concept_word, init_word, text_train_steps, text_train_bsz, text_learning_rate, progress=gr.Progress(track_tqdm=True)):
206
+
207
+ pipe = load_pipe(
208
+ model_id=model_name,
209
+ scheduler_name=scheduler_name,
210
+ pipe_class=StableDiffusionPipeline,
211
+ )
212
+
213
+ pipe.disable_xformers_memory_efficient_attention() # xformers handled by textual inversion script
214
+
215
+ concept_dir = 'concept_images'
216
+ output_dir = 'output_model'
217
+ training_resolution = 512
218
+
219
+ if os.path.exists(output_dir): shutil.rmtree('output_model')
220
+ if os.path.exists(concept_dir): shutil.rmtree('concept_images')
221
+
222
+ os.makedirs(concept_dir, exist_ok=True)
223
+ os.makedirs(output_dir, exist_ok=True)
224
+
225
+ gc.collect()
226
+ torch.cuda.empty_cache()
227
+
228
+ if(prompt == "" or prompt == None):
229
+ raise gr.Error("You forgot to define your concept prompt")
230
+
231
+ for j, file_temp in enumerate(files):
232
+ file = Image.open(file_temp.name)
233
+ image = pad_image(file)
234
+ image = image.resize((training_resolution, training_resolution))
235
+ extension = file_temp.name.split(".")[1]
236
+ image = image.convert('RGB')
237
+ image.save(f'{concept_dir}/{j+1}.{extension}', quality=100)
238
+
239
+
240
+ args_general = argparse.Namespace(
241
+ train_data_dir=concept_dir,
242
+ learnable_property=type_of_thing,
243
+ placeholder_token=concept_word,
244
+ initializer_token=init_word,
245
+ resolution=training_resolution,
246
+ train_batch_size=text_train_bsz,
247
+ gradient_accumulation_steps=1,
248
+ gradient_checkpointing=True,
249
+ mixed_precision='fp16',
250
+ use_bf16=False,
251
+ max_train_steps=int(text_train_steps),
252
+ learning_rate=text_learning_rate,
253
+ scale_lr=True,
254
+ lr_scheduler="constant",
255
+ lr_warmup_steps=0,
256
+ output_dir=output_dir,
257
+ )
258
+
259
+ try:
260
+ final_result = run_textual_inversion(pipe, args_general)
261
+ except Exception as e:
262
+ raise gr.Error(e)
263
+
264
+ gc.collect()
265
+ torch.cuda.empty_cache()
266
 
267
+ return f'Finished training! Check the {output_dir} directory for saved model weights'
268
 
269
 
270
  default_img_size = 512
 
345
  inpaint_options = ["preserve non-masked portions of image", "output entire inpainted image"]
346
  inpaint_radio = gr.Radio(inpaint_options, value=inpaint_options[0], show_label=False, interactive=True)
347
 
348
+ with gr.Tab("Textual Inversion") as tab:
349
+ tab.select(lambda: StableDiffusionPipeline, [], pipe_state)
350
+
351
+ type_of_thing = gr.Dropdown(label="What would you like to train?", choices=["object", "person", "style"], value="object", interactive=True)
352
+
353
+ text_train_bsz = gr.Slider(
354
+ label="Training Batch Size",
355
+ minimum=1,
356
+ maximum=8,
357
+ step=1,
358
+ value=1,
359
+ )
360
+
361
+ files = gr.File(label=f'''Upload the images for your concept''', file_count="multiple", interactive=True, visible=True)
362
+
363
+ text_train_steps = gr.Number(label="How many steps", value=1000)
364
+
365
+ text_learning_rate = gr.Number(label="Learning Rate", value=5.e-4)
366
+
367
+ concept_word = gr.Textbox(label=f'''concept word - use a unique, made up word to avoid collisions''')
368
+ init_word = gr.Textbox(label=f'''initial word - to init the concept embedding''')
369
+
370
+ textual_inversion_button = gr.Button(value="Train Textual Inversion")
371
+
372
+ training_status = gr.Text(label="Training Status")
373
+
374
  with gr.Row():
375
  batch_size = gr.Slider(
376
  label="Batch Size", value=1, minimum=1, maximum=8, step=1
 
436
  pipe_state,
437
  pipe_kwargs,
438
  ]
439
+ outputs = [gallery, generation_details, seed]
440
 
441
  prompt.submit(generate, inputs=inputs, outputs=outputs)
442
  generate_button.click(generate, inputs=inputs, outputs=outputs)
443
 
444
+ textual_inversion_inputs = [model_name, scheduler_name, type_of_thing, files, concept_word, init_word, text_train_steps, text_train_bsz, text_learning_rate]
445
+
446
+ textual_inversion_button.click(train_textual_inversion, inputs=textual_inversion_inputs, outputs=[training_status])
447
+
448
+
449
+ #demo = gr.TabbedInterface([demo, dreambooth_tab], ["Main", "Dreambooth"])
450
+
451
  demo.queue(concurrency_count=cpu_count())
452
+
453
  demo.launch()
inpaint_noise_issue.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
textual_inversion.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import logging
18
+ import math
19
+ import os
20
+ import random
21
+ from pathlib import Path
22
+ from typing import Optional
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ from torch.utils.data import Dataset
29
+
30
+ import datasets
31
+ import diffusers
32
+ import PIL
33
+ import transformers
34
+ from accelerate import Accelerator
35
+ from accelerate.logging import get_logger
36
+ from accelerate.utils import set_seed
37
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.utils import check_min_version
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+ from huggingface_hub import HfFolder, Repository, create_repo, whoami
42
+
43
+ # TODO: remove and import from diffusers.utils when the new version of diffusers is released
44
+ from packaging import version
45
+ from PIL import Image
46
+ from torchvision import transforms
47
+ from tqdm.auto import tqdm
48
+ from transformers import CLIPTextModel, CLIPTokenizer
49
+
50
+
51
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
52
+ PIL_INTERPOLATION = {
53
+ "linear": PIL.Image.Resampling.BILINEAR,
54
+ "bilinear": PIL.Image.Resampling.BILINEAR,
55
+ "bicubic": PIL.Image.Resampling.BICUBIC,
56
+ "lanczos": PIL.Image.Resampling.LANCZOS,
57
+ "nearest": PIL.Image.Resampling.NEAREST,
58
+ }
59
+ else:
60
+ PIL_INTERPOLATION = {
61
+ "linear": PIL.Image.LINEAR,
62
+ "bilinear": PIL.Image.BILINEAR,
63
+ "bicubic": PIL.Image.BICUBIC,
64
+ "lanczos": PIL.Image.LANCZOS,
65
+ "nearest": PIL.Image.NEAREST,
66
+ }
67
+ # ------------------------------------------------------------------------------
68
+
69
+
70
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
71
+ check_min_version("0.10.0.dev0")
72
+
73
+
74
+ logger = get_logger(__name__)
75
+
76
+
77
+ def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
78
+ logger.info("Saving embeddings")
79
+ learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
80
+ learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
81
+ torch.save(learned_embeds_dict, save_path)
82
+
83
+
84
+ def parse_args():
85
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
86
+ parser.add_argument(
87
+ "--save_steps",
88
+ type=int,
89
+ default=500,
90
+ help="Save learned_embeds.bin every X updates steps.",
91
+ )
92
+ parser.add_argument(
93
+ "--only_save_embeds",
94
+ action="store_true",
95
+ default=False,
96
+ help="Save only the embeddings for the new concept.",
97
+ )
98
+ parser.add_argument(
99
+ "--pretrained_model_name_or_path",
100
+ type=str,
101
+ default=None,
102
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
103
+ )
104
+ parser.add_argument(
105
+ "--revision",
106
+ type=str,
107
+ default=None,
108
+ help="Revision of pretrained model identifier from huggingface.co/models.",
109
+ )
110
+ parser.add_argument(
111
+ "--tokenizer_name",
112
+ type=str,
113
+ default=None,
114
+ help="Pretrained tokenizer name or path if not the same as model_name",
115
+ )
116
+ parser.add_argument(
117
+ "--train_data_dir", type=str, default=None, help="A folder containing the training data."
118
+ )
119
+ parser.add_argument(
120
+ "--placeholder_token",
121
+ type=str,
122
+ default=None,
123
+ help="A token to use as a placeholder for the concept.",
124
+ )
125
+ parser.add_argument(
126
+ "--initializer_token", type=str, default=None, help="A token to use as initializer word."
127
+ )
128
+
129
+ parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
130
+ parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
131
+ parser.add_argument(
132
+ "--output_dir",
133
+ type=str,
134
+ default="text-inversion-model",
135
+ help="The output directory where the model predictions and checkpoints will be written.",
136
+ )
137
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
138
+ parser.add_argument(
139
+ "--resolution",
140
+ type=int,
141
+ default=512,
142
+ help=(
143
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
144
+ " resolution"
145
+ ),
146
+ )
147
+ parser.add_argument(
148
+ "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
149
+ )
150
+ parser.add_argument(
151
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
152
+ )
153
+ parser.add_argument("--num_train_epochs", type=int, default=100)
154
+ parser.add_argument(
155
+ "--max_train_steps",
156
+ type=int,
157
+ default=5000,
158
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
159
+ )
160
+ parser.add_argument(
161
+ "--gradient_accumulation_steps",
162
+ type=int,
163
+ default=1,
164
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
165
+ )
166
+ parser.add_argument(
167
+ "--gradient_checkpointing",
168
+ action="store_true",
169
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
170
+ )
171
+ parser.add_argument(
172
+ "--learning_rate",
173
+ type=float,
174
+ default=1e-4,
175
+ help="Initial learning rate (after the potential warmup period) to use.",
176
+ )
177
+ parser.add_argument(
178
+ "--scale_lr",
179
+ action="store_true",
180
+ default=False,
181
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
182
+ )
183
+ parser.add_argument(
184
+ "--lr_scheduler",
185
+ type=str,
186
+ default="constant",
187
+ help=(
188
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
189
+ ' "constant", "constant_with_warmup"]'
190
+ ),
191
+ )
192
+ parser.add_argument(
193
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
194
+ )
195
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
196
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
197
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
198
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
199
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
200
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
201
+ parser.add_argument(
202
+ "--hub_model_id",
203
+ type=str,
204
+ default=None,
205
+ help="The name of the repository to keep in sync with the local `output_dir`.",
206
+ )
207
+ parser.add_argument(
208
+ "--logging_dir",
209
+ type=str,
210
+ default="logs",
211
+ help=(
212
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
213
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
214
+ ),
215
+ )
216
+ parser.add_argument(
217
+ "--mixed_precision",
218
+ type=str,
219
+ default="no",
220
+ choices=["no", "fp16", "bf16"],
221
+ help=(
222
+ "Whether to use mixed precision. Choose"
223
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
224
+ "and an Nvidia Ampere GPU."
225
+ ),
226
+ )
227
+ parser.add_argument(
228
+ "--allow_tf32",
229
+ action="store_true",
230
+ help=(
231
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
232
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
233
+ ),
234
+ )
235
+ parser.add_argument(
236
+ "--report_to",
237
+ type=str,
238
+ default="tensorboard",
239
+ help=(
240
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
241
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
242
+ ),
243
+ )
244
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
245
+ parser.add_argument(
246
+ "--checkpointing_steps",
247
+ type=int,
248
+ default=500,
249
+ help=(
250
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
251
+ " training using `--resume_from_checkpoint`."
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--resume_from_checkpoint",
256
+ type=str,
257
+ default=None,
258
+ help=(
259
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
260
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
261
+ ),
262
+ )
263
+ parser.add_argument(
264
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
265
+ )
266
+
267
+ args = parser.parse_args()
268
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
269
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
270
+ args.local_rank = env_local_rank
271
+
272
+ #if args.train_data_dir is None:
273
+ # raise ValueError("You must specify a train data directory.")
274
+
275
+ return args
276
+
277
+
278
+ imagenet_templates_small = [
279
+ "a photo of a {}",
280
+ "a rendering of a {}",
281
+ "a cropped photo of the {}",
282
+ "the photo of a {}",
283
+ "a photo of a clean {}",
284
+ "a photo of a dirty {}",
285
+ "a dark photo of the {}",
286
+ "a photo of my {}",
287
+ "a photo of the cool {}",
288
+ "a close-up photo of a {}",
289
+ "a bright photo of the {}",
290
+ "a cropped photo of a {}",
291
+ "a photo of the {}",
292
+ "a good photo of the {}",
293
+ "a photo of one {}",
294
+ "a close-up photo of the {}",
295
+ "a rendition of the {}",
296
+ "a photo of the clean {}",
297
+ "a rendition of a {}",
298
+ "a photo of a nice {}",
299
+ "a good photo of a {}",
300
+ "a photo of the nice {}",
301
+ "a photo of the small {}",
302
+ "a photo of the weird {}",
303
+ "a photo of the large {}",
304
+ "a photo of a cool {}",
305
+ "a photo of a small {}",
306
+ ]
307
+
308
+ imagenet_style_templates_small = [
309
+ "a painting in the style of {}",
310
+ "a rendering in the style of {}",
311
+ "a cropped painting in the style of {}",
312
+ "the painting in the style of {}",
313
+ "a clean painting in the style of {}",
314
+ "a dirty painting in the style of {}",
315
+ "a dark painting in the style of {}",
316
+ "a picture in the style of {}",
317
+ "a cool painting in the style of {}",
318
+ "a close-up painting in the style of {}",
319
+ "a bright painting in the style of {}",
320
+ "a cropped painting in the style of {}",
321
+ "a good painting in the style of {}",
322
+ "a close-up painting in the style of {}",
323
+ "a rendition in the style of {}",
324
+ "a nice painting in the style of {}",
325
+ "a small painting in the style of {}",
326
+ "a weird painting in the style of {}",
327
+ "a large painting in the style of {}",
328
+ ]
329
+
330
+
331
+ class TextualInversionDataset(Dataset):
332
+ def __init__(
333
+ self,
334
+ data_root,
335
+ tokenizer,
336
+ learnable_property="object", # [object, style]
337
+ size=512,
338
+ repeats=100,
339
+ interpolation="bicubic",
340
+ flip_p=0.5,
341
+ set="train",
342
+ placeholder_token="*",
343
+ center_crop=False,
344
+ ):
345
+ self.data_root = data_root
346
+ self.tokenizer = tokenizer
347
+ self.learnable_property = learnable_property
348
+ self.size = size
349
+ self.placeholder_token = placeholder_token
350
+ self.center_crop = center_crop
351
+ self.flip_p = flip_p
352
+
353
+ self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
354
+
355
+ self.num_images = len(self.image_paths)
356
+ self._length = self.num_images
357
+
358
+ if set == "train":
359
+ self._length = self.num_images * repeats
360
+
361
+ self.interpolation = {
362
+ "linear": PIL_INTERPOLATION["linear"],
363
+ "bilinear": PIL_INTERPOLATION["bilinear"],
364
+ "bicubic": PIL_INTERPOLATION["bicubic"],
365
+ "lanczos": PIL_INTERPOLATION["lanczos"],
366
+ }[interpolation]
367
+
368
+ self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
369
+ self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
370
+
371
+ def __len__(self):
372
+ return self._length
373
+
374
+ def __getitem__(self, i):
375
+ example = {}
376
+ image = Image.open(self.image_paths[i % self.num_images])
377
+
378
+ if not image.mode == "RGB":
379
+ image = image.convert("RGB")
380
+
381
+ placeholder_string = self.placeholder_token
382
+ text = random.choice(self.templates).format(placeholder_string)
383
+
384
+ example["input_ids"] = self.tokenizer(
385
+ text,
386
+ padding="max_length",
387
+ truncation=True,
388
+ max_length=self.tokenizer.model_max_length,
389
+ return_tensors="pt",
390
+ ).input_ids[0]
391
+
392
+ # default to score-sde preprocessing
393
+ img = np.array(image).astype(np.uint8)
394
+
395
+ if self.center_crop:
396
+ crop = min(img.shape[0], img.shape[1])
397
+ (
398
+ h,
399
+ w,
400
+ ) = (
401
+ img.shape[0],
402
+ img.shape[1],
403
+ )
404
+ img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
405
+
406
+ image = Image.fromarray(img)
407
+ image = image.resize((self.size, self.size), resample=self.interpolation)
408
+
409
+ image = self.flip_transform(image)
410
+ image = np.array(image).astype(np.uint8)
411
+ image = (image / 127.5 - 1.0).astype(np.float32)
412
+
413
+ example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
414
+ return example
415
+
416
+
417
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
418
+ if token is None:
419
+ token = HfFolder.get_token()
420
+ if organization is None:
421
+ username = whoami(token)["name"]
422
+ return f"{username}/{model_id}"
423
+ else:
424
+ return f"{organization}/{model_id}"
425
+
426
+
427
+
428
+ def main(pipe, args_imported):
429
+
430
+ args = parse_args()
431
+ vars(args).update(vars(args_imported))
432
+
433
+ print(args)
434
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
435
+
436
+ accelerator = Accelerator(
437
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
438
+ mixed_precision=args.mixed_precision,
439
+ log_with=args.report_to,
440
+ logging_dir=logging_dir,
441
+ )
442
+
443
+ # Make one log on every process with the configuration for debugging.
444
+ logging.basicConfig(
445
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
446
+ datefmt="%m/%d/%Y %H:%M:%S",
447
+ level=logging.INFO,
448
+ )
449
+ logger.info(accelerator.state, main_process_only=False)
450
+ if accelerator.is_local_main_process:
451
+ datasets.utils.logging.set_verbosity_warning()
452
+ transformers.utils.logging.set_verbosity_warning()
453
+ diffusers.utils.logging.set_verbosity_info()
454
+ else:
455
+ datasets.utils.logging.set_verbosity_error()
456
+ transformers.utils.logging.set_verbosity_error()
457
+ diffusers.utils.logging.set_verbosity_error()
458
+
459
+ # If passed along, set the training seed now.
460
+ if args.seed is not None:
461
+ set_seed(args.seed)
462
+
463
+ # Handle the repository creation
464
+ if accelerator.is_main_process:
465
+ if args.push_to_hub:
466
+ if args.hub_model_id is None:
467
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
468
+ else:
469
+ repo_name = args.hub_model_id
470
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
471
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
472
+
473
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
474
+ if "step_*" not in gitignore:
475
+ gitignore.write("step_*\n")
476
+ if "epoch_*" not in gitignore:
477
+ gitignore.write("epoch_*\n")
478
+ elif args.output_dir is not None:
479
+ os.makedirs(args.output_dir, exist_ok=True)
480
+
481
+ # Load tokenizer
482
+ tokenizer = pipe.tokenizer
483
+
484
+ # Load scheduler and models
485
+ noise_scheduler = pipe.scheduler
486
+ text_encoder = pipe.text_encoder
487
+ vae = pipe.vae
488
+ unet = pipe.unet
489
+
490
+ # Add the placeholder token in tokenizer
491
+ num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
492
+ if num_added_tokens == 0:
493
+ raise ValueError(
494
+ f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
495
+ " `placeholder_token` that is not already in the tokenizer."
496
+ )
497
+
498
+ # Convert the initializer_token, placeholder_token to ids
499
+ token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
500
+ # Check if initializer_token is a single token or a sequence of tokens
501
+ if len(token_ids) > 1:
502
+ raise ValueError("The initializer token must be a single token.")
503
+
504
+ initializer_token_id = token_ids[0]
505
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
506
+
507
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
508
+ text_encoder.resize_token_embeddings(len(tokenizer))
509
+
510
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
511
+ token_embeds = text_encoder.get_input_embeddings().weight.data
512
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
513
+
514
+ # Freeze vae and unet
515
+ vae.requires_grad_(False)
516
+ unet.requires_grad_(False)
517
+ # Freeze all parameters except for the token embeddings in text encoder
518
+ text_encoder.text_model.encoder.requires_grad_(False)
519
+ text_encoder.text_model.final_layer_norm.requires_grad_(False)
520
+ text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
521
+
522
+ if args.gradient_checkpointing:
523
+ # Keep unet in train mode if we are using gradient checkpointing to save memory.
524
+ # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode.
525
+ unet.train()
526
+ text_encoder.gradient_checkpointing_enable()
527
+ unet.enable_gradient_checkpointing()
528
+
529
+ if args.enable_xformers_memory_efficient_attention:
530
+ if is_xformers_available():
531
+ unet.enable_xformers_memory_efficient_attention()
532
+ else:
533
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
534
+
535
+ # Enable TF32 for faster training on Ampere GPUs,
536
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
537
+ if args.allow_tf32:
538
+ torch.backends.cuda.matmul.allow_tf32 = True
539
+
540
+ if args.scale_lr:
541
+ args.learning_rate = (
542
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
543
+ )
544
+
545
+ # Initialize the optimizer
546
+ optimizer = torch.optim.AdamW(
547
+ text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
548
+ lr=args.learning_rate,
549
+ betas=(args.adam_beta1, args.adam_beta2),
550
+ weight_decay=args.adam_weight_decay,
551
+ eps=args.adam_epsilon,
552
+ )
553
+
554
+ # Dataset and DataLoaders creation:
555
+ train_dataset = TextualInversionDataset(
556
+ data_root=args.train_data_dir,
557
+ tokenizer=tokenizer,
558
+ size=args.resolution,
559
+ placeholder_token=args.placeholder_token,
560
+ repeats=args.repeats,
561
+ learnable_property=args.learnable_property,
562
+ center_crop=args.center_crop,
563
+ set="train",
564
+ )
565
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
566
+
567
+ # Scheduler and math around the number of training steps.
568
+ overrode_max_train_steps = False
569
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
570
+ if args.max_train_steps is None:
571
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
572
+ overrode_max_train_steps = True
573
+
574
+ lr_scheduler = get_scheduler(
575
+ args.lr_scheduler,
576
+ optimizer=optimizer,
577
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
578
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
579
+ )
580
+
581
+ # Prepare everything with our `accelerator`.
582
+ text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
583
+ text_encoder, optimizer, train_dataloader, lr_scheduler
584
+ )
585
+
586
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
587
+ # as these models are only used for inference, keeping weights in full precision is not required.
588
+ weight_dtype = torch.float32
589
+ if accelerator.mixed_precision == "fp16":
590
+ weight_dtype = torch.float16
591
+ elif accelerator.mixed_precision == "bf16":
592
+ weight_dtype = torch.bfloat16
593
+
594
+ # Move vae and unet to device and cast to weight_dtype
595
+ unet.to(accelerator.device, dtype=weight_dtype)
596
+ vae.to(accelerator.device, dtype=weight_dtype)
597
+ text_encoder.to(accelerator.device, dtype=torch.float32)
598
+
599
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
600
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
601
+ if overrode_max_train_steps:
602
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
603
+ # Afterwards we recalculate our number of training epochs
604
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
605
+
606
+ # We need to initialize the trackers we use, and also store our configuration.
607
+ # The trackers initializes automatically on the main process.
608
+ if accelerator.is_main_process:
609
+ accelerator.init_trackers("textual_inversion", config=vars(args))
610
+
611
+ # Train!
612
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
613
+
614
+ logger.info("***** Running training *****")
615
+ logger.info(f" Num examples = {len(train_dataset)}")
616
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
617
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
618
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
619
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
620
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
621
+ global_step = 0
622
+ first_epoch = 0
623
+
624
+ # Potentially load in the weights and states from a previous save
625
+ if args.resume_from_checkpoint:
626
+ if args.resume_from_checkpoint != "latest":
627
+ path = os.path.basename(args.resume_from_checkpoint)
628
+ else:
629
+ # Get the most recent checkpoint
630
+ dirs = os.listdir(args.output_dir)
631
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
632
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
633
+ path = dirs[-1]
634
+ accelerator.print(f"Resuming from checkpoint {path}")
635
+ accelerator.load_state(os.path.join(args.output_dir, path))
636
+ global_step = int(path.split("-")[1])
637
+
638
+ resume_global_step = global_step * args.gradient_accumulation_steps
639
+ first_epoch = resume_global_step // num_update_steps_per_epoch
640
+ resume_step = resume_global_step % num_update_steps_per_epoch
641
+
642
+ # Only show the progress bar once on each machine.
643
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
644
+ progress_bar.set_description("Steps")
645
+
646
+ # keep original embeddings as reference
647
+ orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
648
+
649
+ for epoch in (range(first_epoch, args.num_train_epochs)):
650
+ text_encoder.train()
651
+ for step, batch in enumerate(train_dataloader):
652
+ # Skip steps until we reach the resumed step
653
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
654
+ if step % args.gradient_accumulation_steps == 0:
655
+ progress_bar.update(1)
656
+ continue
657
+
658
+ with accelerator.accumulate(text_encoder):
659
+ # Convert images to latent space
660
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
661
+ latents = latents * 0.18215
662
+
663
+ # Sample noise that we'll add to the latents
664
+ noise = torch.randn_like(latents)
665
+ bsz = latents.shape[0]
666
+ # Sample a random timestep for each image
667
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
668
+ timesteps = timesteps.long()
669
+
670
+ # Add noise to the latents according to the noise magnitude at each timestep
671
+ # (this is the forward diffusion process)
672
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
673
+
674
+ # Get the text embedding for conditioning
675
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
676
+
677
+ # Predict the noise residual
678
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
679
+
680
+ # Get the target for loss depending on the prediction type
681
+ if noise_scheduler.config.prediction_type == "epsilon":
682
+ target = noise
683
+ elif noise_scheduler.config.prediction_type == "v_prediction":
684
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
685
+ else:
686
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
687
+
688
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
689
+
690
+ accelerator.backward(loss)
691
+
692
+ if accelerator.num_processes > 1:
693
+ grads = text_encoder.module.get_input_embeddings().weight.grad
694
+ else:
695
+ grads = text_encoder.get_input_embeddings().weight.grad
696
+ # Get the index for tokens that we want to zero the grads for
697
+ index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
698
+ grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
699
+
700
+ optimizer.step()
701
+ lr_scheduler.step()
702
+ optimizer.zero_grad()
703
+
704
+ # Let's make sure we don't update any embedding weights besides the newly added token
705
+ index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
706
+ with torch.no_grad():
707
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
708
+ index_no_updates
709
+ ] = orig_embeds_params[index_no_updates]
710
+
711
+ # Checks if the accelerator has performed an optimization step behind the scenes
712
+ if accelerator.sync_gradients:
713
+ progress_bar.update(1)
714
+ global_step += 1
715
+ if global_step % args.save_steps == 0:
716
+ save_path = os.path.join(args.output_dir, f"{args.placeholder_token}-{global_step}.bin")
717
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
718
+
719
+ if global_step % args.checkpointing_steps == 0:
720
+ if accelerator.is_main_process:
721
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
722
+ accelerator.save_state(save_path)
723
+ logger.info(f"Saved state to {save_path}")
724
+
725
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
726
+ progress_bar.set_postfix(**logs)
727
+ accelerator.log(logs, step=global_step)
728
+
729
+ if global_step >= args.max_train_steps:
730
+ break
731
+
732
+ # Create the pipeline using using the trained modules and save it.
733
+ accelerator.wait_for_everyone()
734
+ if accelerator.is_main_process:
735
+ if args.push_to_hub and args.only_save_embeds:
736
+ logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
737
+ save_full_model = True
738
+ else:
739
+ save_full_model = not args.only_save_embeds
740
+ if save_full_model:
741
+ pipe.save_pretrained(args.output_dir)
742
+ # Save the newly trained embeddings
743
+ save_path = os.path.join(args.output_dir, "learned_embeds.bin")
744
+ save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
745
+
746
+ if args.push_to_hub:
747
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
748
+
749
+ accelerator.end_training()
750
+ text_encoder.eval()
751
+ unet.eval()
752
+ vae.eval()
753
+
754
+
755
+ if __name__ == "__main__":
756
+ pipeline = StableDiffusionPipeline.from_pretrained('andite/anything-v4.0', torch_dtype=torch.float16)
757
+
758
+ imported_args = argparse.Namespace(
759
+ train_data_dir="concept_images",
760
+ learnable_property='object',
761
+ placeholder_token='redeyegirl',
762
+ initializer_token='girl',
763
+ resolution=512,
764
+ train_batch_size=1,
765
+ gradient_accumulation_steps=1,
766
+ gradient_checkpointing=True,
767
+ mixed_precision='fp16',
768
+ use_bf16=False,
769
+ max_train_steps=1000,
770
+ learning_rate=5.0e-4,
771
+ scale_lr=False,
772
+ lr_scheduler="constant",
773
+ lr_warmup_steps=0,
774
+ output_dir="output_model",
775
+ )
776
+
777
+
778
+ main(pipeline, imported_args)