shaokun committed on
Commit
a4d7b31
1 Parent(s): bfc9705
.gitignore ADDED
@@ -0,0 +1,2 @@
+ */__pycache__/
+ */*.pyc
app.py CHANGED
@@ -1,16 +1,118 @@
 import gradio as gr
 import pandas as pd

- def load_csv(file):
-     df = pd.read_csv(file)
-     return df

- iface = gr.Interface(
-     fn=load_csv,
-     inputs="file",
-     outputs="dataframe",
-     title="CSV Loader",
-     description="Load a CSV file and display its contents.",
- )

- iface.launch()
 import gradio as gr
+ import shutil
+ import zipfile
+ import tensorflow as tf
 import pandas as pd
+ import pathlib
+ import PIL.Image
+ import os
+ import subprocess

+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
+     w, h = image.size
+     if w == h:
+         return image
+     elif w > h:
+         new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
+         new_image.paste(image, (0, (w - h) // 2))
+         return new_image
+     else:
+         new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
+         new_image.paste(image, ((h - w) // 2, 0))
+         return new_image


+ class ModelTrainer:
+     def __init__(self):
+         self.training_pictures = []
+         self.training_model = None
+
+     def unzip_file(self, zip_file_path):
+         with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+             extracted_path = zip_file_path.replace('.zip', '')
+             zip_ref.extractall(extracted_path)
+             file_names = zip_ref.namelist()
+             for file_name in file_names:
+                 if file_name.endswith(('.jpeg', '.jpg', '.png')):
+                     self.training_pictures.append(f'{extracted_path}/{file_name}')
+
+     def train(self, pretrained_model_name_or_path: str, instance_images: list | None):
+         output_model_name = 'a-xyz-model'
+         resolution = 512
+         repo_dir = pathlib.Path(__file__).parent
+         subdirs = ['train-instance', 'train-class', 'experiments']
+         dir_paths = []
+
+         for subdir in subdirs:
+             dir_path = repo_dir / subdir / output_model_name
+             dir_paths.append(dir_path)
+             shutil.rmtree(dir_path, ignore_errors=True)
+             os.makedirs(dir_path, exist_ok=True)
+
+         instance_data_dir, class_data_dir, output_dir = dir_paths
+
+         for i, temp_path in enumerate(instance_images):
+             image = PIL.Image.open(temp_path.name)
+             image = pad_image(image)
+             image = image.resize((resolution, resolution))
+             image = image.convert('RGB')
+             out_path = instance_data_dir / f'{i:03d}.jpg'
+             image.save(out_path, format='JPEG', quality=100)
+
+         command = [
+             'python', '-u',
+             'train_dreambooth_cloneofsimo_lora.py',
+             '--pretrained_model_name_or_path', pretrained_model_name_or_path,
+             '--instance_data_dir', instance_data_dir,
+             '--class_data_dir', class_data_dir,
+             '--resolution', '768',
+             '--output_dir', output_dir,
+             '--instance_prompt', 'a photo of a pwsm dog',
+             '--with_prior_preservation',
+             '--class_prompt', 'a dog',
+             '--prior_loss_weight', '1.0',
+             '--num_class_images', '100',
+             '--learning_rate', '0.0004',
+             '--train_batch_size', '1',
+             '--sample_batch_size', '1',
+             '--max_train_steps', '400',
+             '--gradient_accumulation_steps', '1',
+             '--gradient_checkpointing',
+             '--train_text_encoder',
+             '--learning_rate_text', '5e-6',
+             '--save_steps', '100',
+             '--seed', '1337',
+             '--lr_scheduler', 'constant',
+             '--lr_warmup_steps', '0'
+         ]
+
+         result = subprocess.run(command)
+         return result
+
+     def generate_picture(self, row):
+         num_of_training_steps, learning_rate, checkpoint_steps, abc = row
+         return f'Picture generated for num_of_training_steps: {num_of_training_steps}, learning_rate: {learning_rate}, checkpoint_steps: {checkpoint_steps}'
+
+     def generate_pictures(self, csv_input):
+         csv = pd.read_csv(csv_input.name)
+         result = []
+         for index, row in csv.iterrows():
+             result.append(self.generate_picture(row))
+         return "\n".join(str(item) for item in result)
+
+ loader = ModelTrainer()
+
+ with gr.Blocks() as demo:
+     with gr.Box():
+         instance_images = gr.Files(label='Instance images')
+         pretrained_model_name_or_path = gr.inputs.Textbox(lines=1, label='pretrained_model_name_or_path', default='stabilityai/stable-diffusion-2-1')
+         output_message = gr.Markdown()
+         train_button = gr.Button('Train')
+         train_button.click(fn=loader.train, inputs=[pretrained_model_name_or_path, instance_images], outputs=[output_message])
+     with gr.Box():
+         csv_input = gr.inputs.File(label='CSV File')
+         output_message2 = gr.Markdown()
+         generate_button = gr.Button('Generate Pictures from CSV')
+         generate_button.click(fn=loader.generate_pictures, inputs=[csv_input], outputs=[output_message2])
+
+ demo.launch()
lora_diffusion/FOR-cloneofsimo-LoRA ADDED
@@ -0,0 +1,6 @@
+ The 'lora_diffusion' library in this subdirectory is required by the
+ 'train_dreambooth_cloneofsimo_lora.py' script and is the underlying library of the
+ https://github.com/cloneofsimo/lora project.
+
+ The 'train_dreambooth_cloneofsimo_lora.py' script, in turn, is merely a renamed copy
+ of 'training_scripts/train_lora_dreambooth.py' from that same project.
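For context, here is a minimal, hypothetical sketch (not part of this commit) of how a LoRA file produced by 'train_dreambooth_cloneofsimo_lora.py' could be applied at inference time using the `patch_pipe` and `tune_lora_scale` helpers that this copy of 'lora_diffusion' exposes. The checkpoint path, base model, scale values, and prompt below are illustrative assumptions, not values taken from this commit.

```
# Hypothetical usage sketch; the LoRA path, scales, and prompt are assumptions.
import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import patch_pipe, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

# Load the trained LoRA weights into the pipeline's UNet and text encoder.
patch_pipe(pipe, "experiments/a-xyz-model/final_lora.safetensors")

# Blend the LoRA residuals into the base weights at a chosen strength.
tune_lora_scale(pipe.unet, 0.8)
tune_lora_scale(pipe.text_encoder, 0.8)

image = pipe("a photo of a pwsm dog").images[0]
image.save("sample.png")
```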
lora_diffusion/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .lora import *
+ from .dataset import *
+ from .utils import *
+ from .preprocess_files import *
+ from .lora_manager import *
lora_diffusion/cli_lora_add.py ADDED
@@ -0,0 +1,187 @@
1
+ from typing import Literal, Union, Dict
2
+ import os
3
+ import shutil
4
+ import fire
5
+ from diffusers import StableDiffusionPipeline
6
+ from safetensors.torch import safe_open, save_file
7
+
8
+ import torch
9
+ from .lora import (
10
+ tune_lora_scale,
11
+ patch_pipe,
12
+ collapse_lora,
13
+ monkeypatch_remove_lora,
14
+ )
15
+ from .lora_manager import lora_join
16
+ from .to_ckpt_v2 import convert_to_ckpt
17
+
18
+
19
+ def _text_lora_path(path: str) -> str:
20
+ assert path.endswith(".pt"), "Only .pt files are supported"
21
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
22
+
23
+
24
+ def add(
25
+ path_1: str,
26
+ path_2: str,
27
+ output_path: str,
28
+ alpha_1: float = 0.5,
29
+ alpha_2: float = 0.5,
30
+ mode: Literal[
31
+ "lpl",
32
+ "upl",
33
+ "upl-ckpt-v2",
34
+ ] = "lpl",
35
+ with_text_lora: bool = False,
36
+ ):
37
+ print("Lora Add, mode " + mode)
38
+ if mode == "lpl":
39
+ if path_1.endswith(".pt") and path_2.endswith(".pt"):
40
+ for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
41
+ [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
42
+ if with_text_lora
43
+ else []
44
+ ):
45
+ print("Loading", _path_1, _path_2)
46
+ out_list = []
47
+ if opt == "text_encoder":
48
+ if not os.path.exists(_path_1):
49
+ print(f"No text encoder found in {_path_1}, skipping...")
50
+ continue
51
+ if not os.path.exists(_path_2):
52
+ print(f"No text encoder found in {_path_1}, skipping...")
53
+ continue
54
+
55
+ l1 = torch.load(_path_1)
56
+ l2 = torch.load(_path_2)
57
+
58
+ l1pairs = zip(l1[::2], l1[1::2])
59
+ l2pairs = zip(l2[::2], l2[1::2])
60
+
61
+ for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
62
+ # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
63
+ x1.data = alpha_1 * x1.data + alpha_2 * x2.data
64
+ y1.data = alpha_1 * y1.data + alpha_2 * y2.data
65
+
66
+ out_list.append(x1)
67
+ out_list.append(y1)
68
+
69
+ if opt == "unet":
70
+
71
+ print("Saving merged UNET to", output_path)
72
+ torch.save(out_list, output_path)
73
+
74
+ elif opt == "text_encoder":
75
+ print("Saving merged text encoder to", _text_lora_path(output_path))
76
+ torch.save(
77
+ out_list,
78
+ _text_lora_path(output_path),
79
+ )
80
+
81
+ elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
82
+ safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
83
+ safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
84
+
85
+ metadata = dict(safeloras_1.metadata())
86
+ metadata.update(dict(safeloras_2.metadata()))
87
+
88
+ ret_tensor = {}
89
+
90
+ for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
91
+ if keys.startswith("text_encoder") or keys.startswith("unet"):
92
+
93
+ tens1 = safeloras_1.get_tensor(keys)
94
+ tens2 = safeloras_2.get_tensor(keys)
95
+
96
+ tens = alpha_1 * tens1 + alpha_2 * tens2
97
+ ret_tensor[keys] = tens
98
+ else:
99
+ if keys in safeloras_1.keys():
100
+
101
+ tens1 = safeloras_1.get_tensor(keys)
102
+ else:
103
+ tens1 = safeloras_2.get_tensor(keys)
104
+
105
+ ret_tensor[keys] = tens1
106
+
107
+ save_file(ret_tensor, output_path, metadata)
108
+
109
+ elif mode == "upl":
110
+
111
+ print(
112
+ f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
113
+ )
114
+
115
+ loaded_pipeline = StableDiffusionPipeline.from_pretrained(
116
+ path_1,
117
+ ).to("cpu")
118
+
119
+ patch_pipe(loaded_pipeline, path_2)
120
+
121
+ collapse_lora(loaded_pipeline.unet, alpha_1)
122
+ collapse_lora(loaded_pipeline.text_encoder, alpha_1)
123
+
124
+ monkeypatch_remove_lora(loaded_pipeline.unet)
125
+ monkeypatch_remove_lora(loaded_pipeline.text_encoder)
126
+
127
+ loaded_pipeline.save_pretrained(output_path)
128
+
129
+ elif mode == "upl-ckpt-v2":
130
+
131
+ assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
132
+ name = os.path.basename(output_path)[0:-5]
133
+
134
+ print(
135
+ f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token."
136
+ )
137
+
138
+ loaded_pipeline = StableDiffusionPipeline.from_pretrained(
139
+ path_1,
140
+ ).to("cpu")
141
+
142
+ tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)
143
+
144
+ collapse_lora(loaded_pipeline.unet, alpha_1)
145
+ collapse_lora(loaded_pipeline.text_encoder, alpha_1)
146
+
147
+ monkeypatch_remove_lora(loaded_pipeline.unet)
148
+ monkeypatch_remove_lora(loaded_pipeline.text_encoder)
149
+
150
+ _tmp_output = output_path + ".tmp"
151
+
152
+ loaded_pipeline.save_pretrained(_tmp_output)
153
+ convert_to_ckpt(_tmp_output, output_path, as_half=True)
154
+ # remove the tmp_output folder
155
+ shutil.rmtree(_tmp_output)
156
+
157
+ keys = sorted(tok_dict.keys())
158
+ tok_catted = torch.stack([tok_dict[k] for k in keys])
159
+ ret = {
160
+ "string_to_token": {"*": torch.tensor(265)},
161
+ "string_to_param": {"*": tok_catted},
162
+ "name": name,
163
+ }
164
+
165
+ torch.save(ret, output_path[:-5] + ".pt")
166
+ print(
167
+ f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, "
168
+ )
169
+ elif mode == "ljl":
170
+ print("Using Join mode : alpha will not have an effect here.")
171
+ assert path_1.endswith(".safetensors") and path_2.endswith(
172
+ ".safetensors"
173
+ ), "Only .safetensors files are supported"
174
+
175
+ safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
176
+ safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
177
+
178
+ total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
179
+ save_file(total_tensor, output_path, total_metadata)
180
+
181
+ else:
182
+ print("Unknown mode", mode)
183
+ raise ValueError(f"Unknown mode {mode}")
184
+
185
+
186
+ def main():
187
+ fire.Fire(add)
lora_diffusion/cli_lora_pti.py ADDED
@@ -0,0 +1,1040 @@
1
+ # Bootstrapped from:
2
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
3
+
4
+ import argparse
5
+ import hashlib
6
+ import inspect
7
+ import itertools
8
+ import math
9
+ import os
10
+ import random
11
+ import re
12
+ from pathlib import Path
13
+ from typing import Optional, List, Literal
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torch.optim as optim
18
+ import torch.utils.checkpoint
19
+ from diffusers import (
20
+ AutoencoderKL,
21
+ DDPMScheduler,
22
+ StableDiffusionPipeline,
23
+ UNet2DConditionModel,
24
+ )
25
+ from diffusers.optimization import get_scheduler
26
+ from huggingface_hub import HfFolder, Repository, whoami
27
+ from PIL import Image
28
+ from torch.utils.data import Dataset
29
+ from torchvision import transforms
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+ import wandb
33
+ import fire
34
+
35
+ from lora_diffusion import (
36
+ PivotalTuningDatasetCapation,
37
+ extract_lora_ups_down,
38
+ inject_trainable_lora,
39
+ inject_trainable_lora_extended,
40
+ inspect_lora,
41
+ save_lora_weight,
42
+ save_all,
43
+ prepare_clip_model_sets,
44
+ evaluate_pipe,
45
+ UNET_EXTENDED_TARGET_REPLACE,
46
+ )
47
+
48
+
49
+ def get_models(
50
+ pretrained_model_name_or_path,
51
+ pretrained_vae_name_or_path,
52
+ revision,
53
+ placeholder_tokens: List[str],
54
+ initializer_tokens: List[str],
55
+ device="cuda:0",
56
+ ):
57
+
58
+ tokenizer = CLIPTokenizer.from_pretrained(
59
+ pretrained_model_name_or_path,
60
+ subfolder="tokenizer",
61
+ revision=revision,
62
+ )
63
+
64
+ text_encoder = CLIPTextModel.from_pretrained(
65
+ pretrained_model_name_or_path,
66
+ subfolder="text_encoder",
67
+ revision=revision,
68
+ )
69
+
70
+ placeholder_token_ids = []
71
+
72
+ for token, init_tok in zip(placeholder_tokens, initializer_tokens):
73
+ num_added_tokens = tokenizer.add_tokens(token)
74
+ if num_added_tokens == 0:
75
+ raise ValueError(
76
+ f"The tokenizer already contains the token {token}. Please pass a different"
77
+ " `placeholder_token` that is not already in the tokenizer."
78
+ )
79
+
80
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(token)
81
+
82
+ placeholder_token_ids.append(placeholder_token_id)
83
+
84
+ # Load models and create wrapper for stable diffusion
85
+
86
+ text_encoder.resize_token_embeddings(len(tokenizer))
87
+ token_embeds = text_encoder.get_input_embeddings().weight.data
88
+ if init_tok.startswith("<rand"):
89
+ # <rand-"sigma">, e.g. <rand-0.5>
90
+ sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])
91
+
92
+ token_embeds[placeholder_token_id] = (
93
+ torch.randn_like(token_embeds[0]) * sigma_val
94
+ )
95
+ print(
96
+ f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}"
97
+ )
98
+ print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")
99
+
100
+ elif init_tok == "<zero>":
101
+ token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
102
+ else:
103
+ token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
104
+ # Check if initializer_token is a single token or a sequence of tokens
105
+ if len(token_ids) > 1:
106
+ raise ValueError("The initializer token must be a single token.")
107
+
108
+ initializer_token_id = token_ids[0]
109
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
110
+
111
+ vae = AutoencoderKL.from_pretrained(
112
+ pretrained_vae_name_or_path or pretrained_model_name_or_path,
113
+ subfolder=None if pretrained_vae_name_or_path else "vae",
114
+ revision=None if pretrained_vae_name_or_path else revision,
115
+ )
116
+ unet = UNet2DConditionModel.from_pretrained(
117
+ pretrained_model_name_or_path,
118
+ subfolder="unet",
119
+ revision=revision,
120
+ )
121
+
122
+ return (
123
+ text_encoder.to(device),
124
+ vae.to(device),
125
+ unet.to(device),
126
+ tokenizer,
127
+ placeholder_token_ids,
128
+ )
129
+
130
+
131
+ @torch.no_grad()
132
+ def text2img_dataloader(
133
+ train_dataset,
134
+ train_batch_size,
135
+ tokenizer,
136
+ vae,
137
+ text_encoder,
138
+ cached_latents: bool = False,
139
+ ):
140
+
141
+ if cached_latents:
142
+ cached_latents_dataset = []
143
+ for idx in tqdm(range(len(train_dataset))):
144
+ batch = train_dataset[idx]
145
+ # rint(batch)
146
+ latents = vae.encode(
147
+ batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
148
+ ).latent_dist.sample()
149
+ latents = latents * 0.18215
150
+ batch["instance_images"] = latents.squeeze(0)
151
+ cached_latents_dataset.append(batch)
152
+
153
+ def collate_fn(examples):
154
+ input_ids = [example["instance_prompt_ids"] for example in examples]
155
+ pixel_values = [example["instance_images"] for example in examples]
156
+ pixel_values = torch.stack(pixel_values)
157
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
158
+
159
+ input_ids = tokenizer.pad(
160
+ {"input_ids": input_ids},
161
+ padding="max_length",
162
+ max_length=tokenizer.model_max_length,
163
+ return_tensors="pt",
164
+ ).input_ids
165
+
166
+ batch = {
167
+ "input_ids": input_ids,
168
+ "pixel_values": pixel_values,
169
+ }
170
+
171
+ if examples[0].get("mask", None) is not None:
172
+ batch["mask"] = torch.stack([example["mask"] for example in examples])
173
+
174
+ return batch
175
+
176
+ if cached_latents:
177
+
178
+ train_dataloader = torch.utils.data.DataLoader(
179
+ cached_latents_dataset,
180
+ batch_size=train_batch_size,
181
+ shuffle=True,
182
+ collate_fn=collate_fn,
183
+ )
184
+
185
+ print("PTI : Using cached latent.")
186
+
187
+ else:
188
+ train_dataloader = torch.utils.data.DataLoader(
189
+ train_dataset,
190
+ batch_size=train_batch_size,
191
+ shuffle=True,
192
+ collate_fn=collate_fn,
193
+ )
194
+
195
+ return train_dataloader
196
+
197
+
198
+ def inpainting_dataloader(
199
+ train_dataset, train_batch_size, tokenizer, vae, text_encoder
200
+ ):
201
+ def collate_fn(examples):
202
+ input_ids = [example["instance_prompt_ids"] for example in examples]
203
+ pixel_values = [example["instance_images"] for example in examples]
204
+ mask_values = [example["instance_masks"] for example in examples]
205
+ masked_image_values = [
206
+ example["instance_masked_images"] for example in examples
207
+ ]
208
+
209
+ # Concat class and instance examples for prior preservation.
210
+ # We do this to avoid doing two forward passes.
211
+ if examples[0].get("class_prompt_ids", None) is not None:
212
+ input_ids += [example["class_prompt_ids"] for example in examples]
213
+ pixel_values += [example["class_images"] for example in examples]
214
+ mask_values += [example["class_masks"] for example in examples]
215
+ masked_image_values += [
216
+ example["class_masked_images"] for example in examples
217
+ ]
218
+
219
+ pixel_values = (
220
+ torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
221
+ )
222
+ mask_values = (
223
+ torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
224
+ )
225
+ masked_image_values = (
226
+ torch.stack(masked_image_values)
227
+ .to(memory_format=torch.contiguous_format)
228
+ .float()
229
+ )
230
+
231
+ input_ids = tokenizer.pad(
232
+ {"input_ids": input_ids},
233
+ padding="max_length",
234
+ max_length=tokenizer.model_max_length,
235
+ return_tensors="pt",
236
+ ).input_ids
237
+
238
+ batch = {
239
+ "input_ids": input_ids,
240
+ "pixel_values": pixel_values,
241
+ "mask_values": mask_values,
242
+ "masked_image_values": masked_image_values,
243
+ }
244
+
245
+ if examples[0].get("mask", None) is not None:
246
+ batch["mask"] = torch.stack([example["mask"] for example in examples])
247
+
248
+ return batch
249
+
250
+ train_dataloader = torch.utils.data.DataLoader(
251
+ train_dataset,
252
+ batch_size=train_batch_size,
253
+ shuffle=True,
254
+ collate_fn=collate_fn,
255
+ )
256
+
257
+ return train_dataloader
258
+
259
+
260
+ def loss_step(
261
+ batch,
262
+ unet,
263
+ vae,
264
+ text_encoder,
265
+ scheduler,
266
+ train_inpainting=False,
267
+ t_mutliplier=1.0,
268
+ mixed_precision=False,
269
+ mask_temperature=1.0,
270
+ cached_latents: bool = False,
271
+ ):
272
+ weight_dtype = torch.float32
273
+ if not cached_latents:
274
+ latents = vae.encode(
275
+ batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
276
+ ).latent_dist.sample()
277
+ latents = latents * 0.18215
278
+
279
+ if train_inpainting:
280
+ masked_image_latents = vae.encode(
281
+ batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
282
+ ).latent_dist.sample()
283
+ masked_image_latents = masked_image_latents * 0.18215
284
+ mask = F.interpolate(
285
+ batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
286
+ scale_factor=1 / 8,
287
+ )
288
+ else:
289
+ latents = batch["pixel_values"]
290
+
291
+ if train_inpainting:
292
+ masked_image_latents = batch["masked_image_latents"]
293
+ mask = batch["mask_values"]
294
+
295
+ noise = torch.randn_like(latents)
296
+ bsz = latents.shape[0]
297
+
298
+ timesteps = torch.randint(
299
+ 0,
300
+ int(scheduler.config.num_train_timesteps * t_mutliplier),
301
+ (bsz,),
302
+ device=latents.device,
303
+ )
304
+ timesteps = timesteps.long()
305
+
306
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
307
+
308
+ if train_inpainting:
309
+ latent_model_input = torch.cat(
310
+ [noisy_latents, mask, masked_image_latents], dim=1
311
+ )
312
+ else:
313
+ latent_model_input = noisy_latents
314
+
315
+ if mixed_precision:
316
+ with torch.cuda.amp.autocast():
317
+
318
+ encoder_hidden_states = text_encoder(
319
+ batch["input_ids"].to(text_encoder.device)
320
+ )[0]
321
+
322
+ model_pred = unet(
323
+ latent_model_input, timesteps, encoder_hidden_states
324
+ ).sample
325
+ else:
326
+
327
+ encoder_hidden_states = text_encoder(
328
+ batch["input_ids"].to(text_encoder.device)
329
+ )[0]
330
+
331
+ model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
332
+
333
+ if scheduler.config.prediction_type == "epsilon":
334
+ target = noise
335
+ elif scheduler.config.prediction_type == "v_prediction":
336
+ target = scheduler.get_velocity(latents, noise, timesteps)
337
+ else:
338
+ raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
339
+
340
+ if batch.get("mask", None) is not None:
341
+
342
+ mask = (
343
+ batch["mask"]
344
+ .to(model_pred.device)
345
+ .reshape(
346
+ model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
347
+ )
348
+ )
349
+ # resize to match model_pred
350
+ mask = F.interpolate(
351
+ mask.float(),
352
+ size=model_pred.shape[-2:],
353
+ mode="nearest",
354
+ )
355
+
356
+ mask = (mask + 0.01).pow(mask_temperature)
357
+
358
+ mask = mask / mask.max()
359
+
360
+ model_pred = model_pred * mask
361
+
362
+ target = target * mask
363
+
364
+ loss = (
365
+ F.mse_loss(model_pred.float(), target.float(), reduction="none")
366
+ .mean([1, 2, 3])
367
+ .mean()
368
+ )
369
+
370
+ return loss
371
+
372
+
373
+ def train_inversion(
374
+ unet,
375
+ vae,
376
+ text_encoder,
377
+ dataloader,
378
+ num_steps: int,
379
+ scheduler,
380
+ index_no_updates,
381
+ optimizer,
382
+ save_steps: int,
383
+ placeholder_token_ids,
384
+ placeholder_tokens,
385
+ save_path: str,
386
+ tokenizer,
387
+ lr_scheduler,
388
+ test_image_path: str,
389
+ cached_latents: bool,
390
+ accum_iter: int = 1,
391
+ log_wandb: bool = False,
392
+ wandb_log_prompt_cnt: int = 10,
393
+ class_token: str = "person",
394
+ train_inpainting: bool = False,
395
+ mixed_precision: bool = False,
396
+ clip_ti_decay: bool = True,
397
+ ):
398
+
399
+ progress_bar = tqdm(range(num_steps))
400
+ progress_bar.set_description("Steps")
401
+ global_step = 0
402
+
403
+ # Original Emb for TI
404
+ orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
405
+
406
+ if log_wandb:
407
+ preped_clip = prepare_clip_model_sets()
408
+
409
+ index_updates = ~index_no_updates
410
+ loss_sum = 0.0
411
+
412
+ for epoch in range(math.ceil(num_steps / len(dataloader))):
413
+ unet.eval()
414
+ text_encoder.train()
415
+ for batch in dataloader:
416
+
417
+ lr_scheduler.step()
418
+
419
+ with torch.set_grad_enabled(True):
420
+ loss = (
421
+ loss_step(
422
+ batch,
423
+ unet,
424
+ vae,
425
+ text_encoder,
426
+ scheduler,
427
+ train_inpainting=train_inpainting,
428
+ mixed_precision=mixed_precision,
429
+ cached_latents=cached_latents,
430
+ )
431
+ / accum_iter
432
+ )
433
+
434
+ loss.backward()
435
+ loss_sum += loss.detach().item()
436
+
437
+ if global_step % accum_iter == 0:
438
+ # print gradient of text encoder embedding
439
+ print(
440
+ text_encoder.get_input_embeddings()
441
+ .weight.grad[index_updates, :]
442
+ .norm(dim=-1)
443
+ .mean()
444
+ )
445
+ optimizer.step()
446
+ optimizer.zero_grad()
447
+
448
+ with torch.no_grad():
449
+
450
+ # normalize embeddings
451
+ if clip_ti_decay:
452
+ pre_norm = (
453
+ text_encoder.get_input_embeddings()
454
+ .weight[index_updates, :]
455
+ .norm(dim=-1, keepdim=True)
456
+ )
457
+
458
+ lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
459
+ text_encoder.get_input_embeddings().weight[
460
+ index_updates
461
+ ] = F.normalize(
462
+ text_encoder.get_input_embeddings().weight[
463
+ index_updates, :
464
+ ],
465
+ dim=-1,
466
+ ) * (
467
+ pre_norm + lambda_ * (0.4 - pre_norm)
468
+ )
469
+ print(pre_norm)
470
+
471
+ current_norm = (
472
+ text_encoder.get_input_embeddings()
473
+ .weight[index_updates, :]
474
+ .norm(dim=-1)
475
+ )
476
+
477
+ text_encoder.get_input_embeddings().weight[
478
+ index_no_updates
479
+ ] = orig_embeds_params[index_no_updates]
480
+
481
+ print(f"Current Norm : {current_norm}")
482
+
483
+ global_step += 1
484
+ progress_bar.update(1)
485
+
486
+ logs = {
487
+ "loss": loss.detach().item(),
488
+ "lr": lr_scheduler.get_last_lr()[0],
489
+ }
490
+ progress_bar.set_postfix(**logs)
491
+
492
+ if global_step % save_steps == 0:
493
+ save_all(
494
+ unet=unet,
495
+ text_encoder=text_encoder,
496
+ placeholder_token_ids=placeholder_token_ids,
497
+ placeholder_tokens=placeholder_tokens,
498
+ save_path=os.path.join(
499
+ save_path, f"step_inv_{global_step}.safetensors"
500
+ ),
501
+ save_lora=False,
502
+ )
503
+ if log_wandb:
504
+ with torch.no_grad():
505
+ pipe = StableDiffusionPipeline(
506
+ vae=vae,
507
+ text_encoder=text_encoder,
508
+ tokenizer=tokenizer,
509
+ unet=unet,
510
+ scheduler=scheduler,
511
+ safety_checker=None,
512
+ feature_extractor=None,
513
+ )
514
+
515
+ # open all images in test_image_path
516
+ images = []
517
+ for file in os.listdir(test_image_path):
518
+ if (
519
+ file.lower().endswith(".png")
520
+ or file.lower().endswith(".jpg")
521
+ or file.lower().endswith(".jpeg")
522
+ ):
523
+ images.append(
524
+ Image.open(os.path.join(test_image_path, file))
525
+ )
526
+
527
+ wandb.log({"loss": loss_sum / save_steps})
528
+ loss_sum = 0.0
529
+ wandb.log(
530
+ evaluate_pipe(
531
+ pipe,
532
+ target_images=images,
533
+ class_token=class_token,
534
+ learnt_token="".join(placeholder_tokens),
535
+ n_test=wandb_log_prompt_cnt,
536
+ n_step=50,
537
+ clip_model_sets=preped_clip,
538
+ )
539
+ )
540
+
541
+ if global_step >= num_steps:
542
+ return
543
+
544
+
545
+ def perform_tuning(
546
+ unet,
547
+ vae,
548
+ text_encoder,
549
+ dataloader,
550
+ num_steps,
551
+ scheduler,
552
+ optimizer,
553
+ save_steps: int,
554
+ placeholder_token_ids,
555
+ placeholder_tokens,
556
+ save_path,
557
+ lr_scheduler_lora,
558
+ lora_unet_target_modules,
559
+ lora_clip_target_modules,
560
+ mask_temperature,
561
+ out_name: str,
562
+ tokenizer,
563
+ test_image_path: str,
564
+ cached_latents: bool,
565
+ log_wandb: bool = False,
566
+ wandb_log_prompt_cnt: int = 10,
567
+ class_token: str = "person",
568
+ train_inpainting: bool = False,
569
+ ):
570
+
571
+ progress_bar = tqdm(range(num_steps))
572
+ progress_bar.set_description("Steps")
573
+ global_step = 0
574
+
575
+ weight_dtype = torch.float16
576
+
577
+ unet.train()
578
+ text_encoder.train()
579
+
580
+ if log_wandb:
581
+ preped_clip = prepare_clip_model_sets()
582
+
583
+ loss_sum = 0.0
584
+
585
+ for epoch in range(math.ceil(num_steps / len(dataloader))):
586
+ for batch in dataloader:
587
+ lr_scheduler_lora.step()
588
+
589
+ optimizer.zero_grad()
590
+
591
+ loss = loss_step(
592
+ batch,
593
+ unet,
594
+ vae,
595
+ text_encoder,
596
+ scheduler,
597
+ train_inpainting=train_inpainting,
598
+ t_mutliplier=0.8,
599
+ mixed_precision=True,
600
+ mask_temperature=mask_temperature,
601
+ cached_latents=cached_latents,
602
+ )
603
+ loss_sum += loss.detach().item()
604
+
605
+ loss.backward()
606
+ torch.nn.utils.clip_grad_norm_(
607
+ itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
608
+ )
609
+ optimizer.step()
610
+ progress_bar.update(1)
611
+ logs = {
612
+ "loss": loss.detach().item(),
613
+ "lr": lr_scheduler_lora.get_last_lr()[0],
614
+ }
615
+ progress_bar.set_postfix(**logs)
616
+
617
+ global_step += 1
618
+
619
+ if global_step % save_steps == 0:
620
+ save_all(
621
+ unet,
622
+ text_encoder,
623
+ placeholder_token_ids=placeholder_token_ids,
624
+ placeholder_tokens=placeholder_tokens,
625
+ save_path=os.path.join(
626
+ save_path, f"step_{global_step}.safetensors"
627
+ ),
628
+ target_replace_module_text=lora_clip_target_modules,
629
+ target_replace_module_unet=lora_unet_target_modules,
630
+ )
631
+ moved = (
632
+ torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
633
+ .mean()
634
+ .item()
635
+ )
636
+
637
+ print("LORA Unet Moved", moved)
638
+ moved = (
639
+ torch.tensor(
640
+ list(itertools.chain(*inspect_lora(text_encoder).values()))
641
+ )
642
+ .mean()
643
+ .item()
644
+ )
645
+
646
+ print("LORA CLIP Moved", moved)
647
+
648
+ if log_wandb:
649
+ with torch.no_grad():
650
+ pipe = StableDiffusionPipeline(
651
+ vae=vae,
652
+ text_encoder=text_encoder,
653
+ tokenizer=tokenizer,
654
+ unet=unet,
655
+ scheduler=scheduler,
656
+ safety_checker=None,
657
+ feature_extractor=None,
658
+ )
659
+
660
+ # open all images in test_image_path
661
+ images = []
662
+ for file in os.listdir(test_image_path):
663
+ if file.endswith(".png") or file.endswith(".jpg"):
664
+ images.append(
665
+ Image.open(os.path.join(test_image_path, file))
666
+ )
667
+
668
+ wandb.log({"loss": loss_sum / save_steps})
669
+ loss_sum = 0.0
670
+ wandb.log(
671
+ evaluate_pipe(
672
+ pipe,
673
+ target_images=images,
674
+ class_token=class_token,
675
+ learnt_token="".join(placeholder_tokens),
676
+ n_test=wandb_log_prompt_cnt,
677
+ n_step=50,
678
+ clip_model_sets=preped_clip,
679
+ )
680
+ )
681
+
682
+ if global_step >= num_steps:
683
+ break
684
+
685
+ save_all(
686
+ unet,
687
+ text_encoder,
688
+ placeholder_token_ids=placeholder_token_ids,
689
+ placeholder_tokens=placeholder_tokens,
690
+ save_path=os.path.join(save_path, f"{out_name}.safetensors"),
691
+ target_replace_module_text=lora_clip_target_modules,
692
+ target_replace_module_unet=lora_unet_target_modules,
693
+ )
694
+
695
+
696
+ def train(
697
+ instance_data_dir: str,
698
+ pretrained_model_name_or_path: str,
699
+ output_dir: str,
700
+ train_text_encoder: bool = True,
701
+ pretrained_vae_name_or_path: str = None,
702
+ revision: Optional[str] = None,
703
+ perform_inversion: bool = True,
704
+ use_template: Literal[None, "object", "style"] = None,
705
+ train_inpainting: bool = False,
706
+ placeholder_tokens: str = "",
707
+ placeholder_token_at_data: Optional[str] = None,
708
+ initializer_tokens: Optional[str] = None,
709
+ seed: int = 42,
710
+ resolution: int = 512,
711
+ color_jitter: bool = True,
712
+ train_batch_size: int = 1,
713
+ sample_batch_size: int = 1,
714
+ max_train_steps_tuning: int = 1000,
715
+ max_train_steps_ti: int = 1000,
716
+ save_steps: int = 100,
717
+ gradient_accumulation_steps: int = 4,
718
+ gradient_checkpointing: bool = False,
719
+ lora_rank: int = 4,
720
+ lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
721
+ lora_clip_target_modules={"CLIPAttention"},
722
+ lora_dropout_p: float = 0.0,
723
+ lora_scale: float = 1.0,
724
+ use_extended_lora: bool = False,
725
+ clip_ti_decay: bool = True,
726
+ learning_rate_unet: float = 1e-4,
727
+ learning_rate_text: float = 1e-5,
728
+ learning_rate_ti: float = 5e-4,
729
+ continue_inversion: bool = False,
730
+ continue_inversion_lr: Optional[float] = None,
731
+ use_face_segmentation_condition: bool = False,
732
+ cached_latents: bool = True,
733
+ use_mask_captioned_data: bool = False,
734
+ mask_temperature: float = 1.0,
735
+ scale_lr: bool = False,
736
+ lr_scheduler: str = "linear",
737
+ lr_warmup_steps: int = 0,
738
+ lr_scheduler_lora: str = "linear",
739
+ lr_warmup_steps_lora: int = 0,
740
+ weight_decay_ti: float = 0.00,
741
+ weight_decay_lora: float = 0.001,
742
+ use_8bit_adam: bool = False,
743
+ device="cuda:0",
744
+ extra_args: Optional[dict] = None,
745
+ log_wandb: bool = False,
746
+ wandb_log_prompt_cnt: int = 10,
747
+ wandb_project_name: str = "new_pti_project",
748
+ wandb_entity: str = "new_pti_entity",
749
+ proxy_token: str = "person",
750
+ enable_xformers_memory_efficient_attention: bool = False,
751
+ out_name: str = "final_lora",
752
+ ):
753
+ torch.manual_seed(seed)
754
+
755
+ if log_wandb:
756
+ wandb.init(
757
+ project=wandb_project_name,
758
+ entity=wandb_entity,
759
+ name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
760
+ reinit=True,
761
+ config={
762
+ **(extra_args if extra_args is not None else {}),
763
+ },
764
+ )
765
+
766
+ if output_dir is not None:
767
+ os.makedirs(output_dir, exist_ok=True)
768
+ # print(placeholder_tokens, initializer_tokens)
769
+ if len(placeholder_tokens) == 0:
770
+ placeholder_tokens = []
771
+ print("PTI : Placeholder Tokens not given, using null token")
772
+ else:
773
+ placeholder_tokens = placeholder_tokens.split("|")
774
+
775
+ assert (
776
+ sorted(placeholder_tokens) == placeholder_tokens
777
+ ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"
778
+
779
+ if initializer_tokens is None:
780
+ print("PTI : Initializer Tokens not given, doing random inits")
781
+ initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
782
+ else:
783
+ initializer_tokens = initializer_tokens.split("|")
784
+
785
+ assert len(initializer_tokens) == len(
786
+ placeholder_tokens
787
+ ), "Unequal Initializer token for Placeholder tokens."
788
+
789
+ if proxy_token is not None:
790
+ class_token = proxy_token
791
+ class_token = "".join(initializer_tokens)
792
+
793
+ if placeholder_token_at_data is not None:
794
+ tok, pat = placeholder_token_at_data.split("|")
795
+ token_map = {tok: pat}
796
+
797
+ else:
798
+ token_map = {"DUMMY": "".join(placeholder_tokens)}
799
+
800
+ print("PTI : Placeholder Tokens", placeholder_tokens)
801
+ print("PTI : Initializer Tokens", initializer_tokens)
802
+
803
+ # get the models
804
+ text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
805
+ pretrained_model_name_or_path,
806
+ pretrained_vae_name_or_path,
807
+ revision,
808
+ placeholder_tokens,
809
+ initializer_tokens,
810
+ device=device,
811
+ )
812
+
813
+ noise_scheduler = DDPMScheduler.from_config(
814
+ pretrained_model_name_or_path, subfolder="scheduler"
815
+ )
816
+
817
+ if gradient_checkpointing:
818
+ unet.enable_gradient_checkpointing()
819
+
820
+ if enable_xformers_memory_efficient_attention:
821
+ from diffusers.utils.import_utils import is_xformers_available
822
+
823
+ if is_xformers_available():
824
+ unet.enable_xformers_memory_efficient_attention()
825
+ else:
826
+ raise ValueError(
827
+ "xformers is not available. Make sure it is installed correctly"
828
+ )
829
+
830
+ if scale_lr:
831
+ unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
832
+ text_encoder_lr = (
833
+ learning_rate_text * gradient_accumulation_steps * train_batch_size
834
+ )
835
+ ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
836
+ else:
837
+ unet_lr = learning_rate_unet
838
+ text_encoder_lr = learning_rate_text
839
+ ti_lr = learning_rate_ti
840
+
841
+ train_dataset = PivotalTuningDatasetCapation(
842
+ instance_data_root=instance_data_dir,
843
+ token_map=token_map,
844
+ use_template=use_template,
845
+ tokenizer=tokenizer,
846
+ size=resolution,
847
+ color_jitter=color_jitter,
848
+ use_face_segmentation_condition=use_face_segmentation_condition,
849
+ use_mask_captioned_data=use_mask_captioned_data,
850
+ train_inpainting=train_inpainting,
851
+ )
852
+
853
+ train_dataset.blur_amount = 200
854
+
855
+ if train_inpainting:
856
+ assert not cached_latents, "Cached latents not supported for inpainting"
857
+
858
+ train_dataloader = inpainting_dataloader(
859
+ train_dataset, train_batch_size, tokenizer, vae, text_encoder
860
+ )
861
+ else:
862
+ train_dataloader = text2img_dataloader(
863
+ train_dataset,
864
+ train_batch_size,
865
+ tokenizer,
866
+ vae,
867
+ text_encoder,
868
+ cached_latents=cached_latents,
869
+ )
870
+
871
+ index_no_updates = torch.arange(len(tokenizer)) != -1
872
+
873
+ for tok_id in placeholder_token_ids:
874
+ index_no_updates[tok_id] = False
875
+
876
+ unet.requires_grad_(False)
877
+ vae.requires_grad_(False)
878
+
879
+ params_to_freeze = itertools.chain(
880
+ text_encoder.text_model.encoder.parameters(),
881
+ text_encoder.text_model.final_layer_norm.parameters(),
882
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
883
+ )
884
+ for param in params_to_freeze:
885
+ param.requires_grad = False
886
+
887
+ if cached_latents:
888
+ vae = None
889
+ # STEP 1 : Perform Inversion
890
+ if perform_inversion:
891
+ ti_optimizer = optim.AdamW(
892
+ text_encoder.get_input_embeddings().parameters(),
893
+ lr=ti_lr,
894
+ betas=(0.9, 0.999),
895
+ eps=1e-08,
896
+ weight_decay=weight_decay_ti,
897
+ )
898
+
899
+ lr_scheduler = get_scheduler(
900
+ lr_scheduler,
901
+ optimizer=ti_optimizer,
902
+ num_warmup_steps=lr_warmup_steps,
903
+ num_training_steps=max_train_steps_ti,
904
+ )
905
+
906
+ train_inversion(
907
+ unet,
908
+ vae,
909
+ text_encoder,
910
+ train_dataloader,
911
+ max_train_steps_ti,
912
+ cached_latents=cached_latents,
913
+ accum_iter=gradient_accumulation_steps,
914
+ scheduler=noise_scheduler,
915
+ index_no_updates=index_no_updates,
916
+ optimizer=ti_optimizer,
917
+ lr_scheduler=lr_scheduler,
918
+ save_steps=save_steps,
919
+ placeholder_tokens=placeholder_tokens,
920
+ placeholder_token_ids=placeholder_token_ids,
921
+ save_path=output_dir,
922
+ test_image_path=instance_data_dir,
923
+ log_wandb=log_wandb,
924
+ wandb_log_prompt_cnt=wandb_log_prompt_cnt,
925
+ class_token=class_token,
926
+ train_inpainting=train_inpainting,
927
+ mixed_precision=False,
928
+ tokenizer=tokenizer,
929
+ clip_ti_decay=clip_ti_decay,
930
+ )
931
+
932
+ del ti_optimizer
933
+
934
+ # Next perform Tuning with LoRA:
935
+ if not use_extended_lora:
936
+ unet_lora_params, _ = inject_trainable_lora(
937
+ unet,
938
+ r=lora_rank,
939
+ target_replace_module=lora_unet_target_modules,
940
+ dropout_p=lora_dropout_p,
941
+ scale=lora_scale,
942
+ )
943
+ else:
944
+ print("PTI : USING EXTENDED UNET!!!")
945
+ lora_unet_target_modules = (
946
+ lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
947
+ )
948
+ print("PTI : Will replace modules: ", lora_unet_target_modules)
949
+
950
+ unet_lora_params, _ = inject_trainable_lora_extended(
951
+ unet, r=lora_rank, target_replace_module=lora_unet_target_modules
952
+ )
953
+ print(f"PTI : has {len(unet_lora_params)} lora")
954
+
955
+ print("PTI : Before training:")
956
+ inspect_lora(unet)
957
+
958
+ params_to_optimize = [
959
+ {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
960
+ ]
961
+
962
+ text_encoder.requires_grad_(False)
963
+
964
+ if continue_inversion:
965
+ params_to_optimize += [
966
+ {
967
+ "params": text_encoder.get_input_embeddings().parameters(),
968
+ "lr": continue_inversion_lr
969
+ if continue_inversion_lr is not None
970
+ else ti_lr,
971
+ }
972
+ ]
973
+ text_encoder.requires_grad_(True)
974
+ params_to_freeze = itertools.chain(
975
+ text_encoder.text_model.encoder.parameters(),
976
+ text_encoder.text_model.final_layer_norm.parameters(),
977
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
978
+ )
979
+ for param in params_to_freeze:
980
+ param.requires_grad = False
981
+ else:
982
+ text_encoder.requires_grad_(False)
983
+ if train_text_encoder:
984
+ text_encoder_lora_params, _ = inject_trainable_lora(
985
+ text_encoder,
986
+ target_replace_module=lora_clip_target_modules,
987
+ r=lora_rank,
988
+ )
989
+ params_to_optimize += [
990
+ {
991
+ "params": itertools.chain(*text_encoder_lora_params),
992
+ "lr": text_encoder_lr,
993
+ }
994
+ ]
995
+ inspect_lora(text_encoder)
996
+
997
+ lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)
998
+
999
+ unet.train()
1000
+ if train_text_encoder:
1001
+ text_encoder.train()
1002
+
1003
+ train_dataset.blur_amount = 70
1004
+
1005
+ lr_scheduler_lora = get_scheduler(
1006
+ lr_scheduler_lora,
1007
+ optimizer=lora_optimizers,
1008
+ num_warmup_steps=lr_warmup_steps_lora,
1009
+ num_training_steps=max_train_steps_tuning,
1010
+ )
1011
+
1012
+ perform_tuning(
1013
+ unet,
1014
+ vae,
1015
+ text_encoder,
1016
+ train_dataloader,
1017
+ max_train_steps_tuning,
1018
+ cached_latents=cached_latents,
1019
+ scheduler=noise_scheduler,
1020
+ optimizer=lora_optimizers,
1021
+ save_steps=save_steps,
1022
+ placeholder_tokens=placeholder_tokens,
1023
+ placeholder_token_ids=placeholder_token_ids,
1024
+ save_path=output_dir,
1025
+ lr_scheduler_lora=lr_scheduler_lora,
1026
+ lora_unet_target_modules=lora_unet_target_modules,
1027
+ lora_clip_target_modules=lora_clip_target_modules,
1028
+ mask_temperature=mask_temperature,
1029
+ tokenizer=tokenizer,
1030
+ out_name=out_name,
1031
+ test_image_path=instance_data_dir,
1032
+ log_wandb=log_wandb,
1033
+ wandb_log_prompt_cnt=wandb_log_prompt_cnt,
1034
+ class_token=class_token,
1035
+ train_inpainting=train_inpainting,
1036
+ )
1037
+
1038
+
1039
+ def main():
1040
+ fire.Fire(train)
lora_diffusion/cli_pt_to_safetensors.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+
3
+ import fire
4
+ import torch
5
+ from lora_diffusion import (
6
+ DEFAULT_TARGET_REPLACE,
7
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
8
+ UNET_DEFAULT_TARGET_REPLACE,
9
+ convert_loras_to_safeloras_with_embeds,
10
+ safetensors_available,
11
+ )
12
+
13
+ _target_by_name = {
14
+ "unet": UNET_DEFAULT_TARGET_REPLACE,
15
+ "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
16
+ }
17
+
18
+
19
+ def convert(*paths, outpath, overwrite=False, **settings):
20
+ """
21
+ Converts one or more pytorch Lora and/or Textual Embedding pytorch files
22
+ into a safetensor file.
23
+
24
+ Pass all the input paths as arguments. Whether they are Textual Embedding
25
+ or Lora models will be auto-detected.
26
+
27
+ For Lora models, their name will be taken from the path, i.e.
28
+ "lora_weight.pt" => unet
29
+ "lora_weight.text_encoder.pt" => text_encoder
30
+
31
+ You can also set target_modules and/or rank by providing an argument prefixed
32
+ by the name.
33
+
34
+ So a complete example might be something like:
35
+
36
+ ```
37
+ python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
38
+ ```
39
+ """
40
+ modelmap = {}
41
+ embeds = {}
42
+
43
+ if os.path.exists(outpath) and not overwrite:
44
+ raise ValueError(
45
+ f"Output path {outpath} already exists, and overwrite is not True"
46
+ )
47
+
48
+ for path in paths:
49
+ data = torch.load(path)
50
+
51
+ if isinstance(data, dict):
52
+ print(f"Loading textual inversion embeds {data.keys()} from {path}")
53
+ embeds.update(data)
54
+
55
+ else:
56
+ name_parts = os.path.split(path)[1].split(".")
57
+ name = name_parts[-2] if len(name_parts) > 2 else "unet"
58
+
59
+ model_settings = {
60
+ "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
61
+ "rank": 4,
62
+ }
63
+
64
+ prefix = f"{name}."
65
+
66
+ arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) }
67
+ model_settings = { **model_settings, **arg_settings }
68
+
69
+ print(f"Loading Lora for {name} from {path} with settings {model_settings}")
70
+
71
+ modelmap[name] = (
72
+ path,
73
+ model_settings["target_modules"],
74
+ model_settings["rank"],
75
+ )
76
+
77
+ convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
78
+
79
+
80
+ def main():
81
+ fire.Fire(convert)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
lora_diffusion/cli_svd.py ADDED
@@ -0,0 +1,146 @@
1
+ import fire
2
+ from diffusers import StableDiffusionPipeline
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .lora import (
7
+ save_all,
8
+ _find_modules,
9
+ LoraInjectedConv2d,
10
+ LoraInjectedLinear,
11
+ inject_trainable_lora,
12
+ inject_trainable_lora_extended,
13
+ )
14
+
15
+
16
+ def _iter_lora(model):
17
+ for module in model.modules():
18
+ if isinstance(module, LoraInjectedConv2d) or isinstance(
19
+ module, LoraInjectedLinear
20
+ ):
21
+ yield module
22
+
23
+
24
+ def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
25
+ device = base_model.device
26
+ dtype = base_model.dtype
27
+
28
+ for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
29
+
30
+ if isinstance(lor_base, LoraInjectedLinear):
31
+ residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
32
+ # SVD on residual
33
+ print("Distill Linear shape ", residual.shape)
34
+ residual = residual.float()
35
+ U, S, Vh = torch.linalg.svd(residual)
36
+ U = U[:, :rank]
37
+ S = S[:rank]
38
+ U = U @ torch.diag(S)
39
+
40
+ Vh = Vh[:rank, :]
41
+
42
+ dist = torch.cat([U.flatten(), Vh.flatten()])
43
+ hi_val = torch.quantile(dist, clamp_quantile)
44
+ low_val = -hi_val
45
+
46
+ U = U.clamp(low_val, hi_val)
47
+ Vh = Vh.clamp(low_val, hi_val)
48
+
49
+ assert lor_base.lora_up.weight.shape == U.shape
50
+ assert lor_base.lora_down.weight.shape == Vh.shape
51
+
52
+ lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
53
+ lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
54
+
55
+ if isinstance(lor_base, LoraInjectedConv2d):
56
+ residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
57
+ print("Distill Conv shape ", residual.shape)
58
+
59
+ residual = residual.float()
60
+ residual = residual.flatten(start_dim=1)
61
+
62
+ # SVD on residual
63
+ U, S, Vh = torch.linalg.svd(residual)
64
+ U = U[:, :rank]
65
+ S = S[:rank]
66
+ U = U @ torch.diag(S)
67
+
68
+ Vh = Vh[:rank, :]
69
+
70
+ dist = torch.cat([U.flatten(), Vh.flatten()])
71
+ hi_val = torch.quantile(dist, clamp_quantile)
72
+ low_val = -hi_val
73
+
74
+ U = U.clamp(low_val, hi_val)
75
+ Vh = Vh.clamp(low_val, hi_val)
76
+
77
+ # U is (out_channels, rank) with 1x1 conv. So,
78
+ U = U.reshape(U.shape[0], U.shape[1], 1, 1)
79
+ # V is (rank, in_channels * kernel_size1 * kernel_size2)
80
+ # now reshape:
81
+ Vh = Vh.reshape(
82
+ Vh.shape[0],
83
+ lor_base.conv.in_channels,
84
+ lor_base.conv.kernel_size[0],
85
+ lor_base.conv.kernel_size[1],
86
+ )
87
+
88
+ assert lor_base.lora_up.weight.shape == U.shape
89
+ assert lor_base.lora_down.weight.shape == Vh.shape
90
+
91
+ lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
92
+ lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
93
+
94
+
95
+ def svd_distill(
96
+ target_model: str,
97
+ base_model: str,
98
+ rank: int = 4,
99
+ clamp_quantile: float = 0.99,
100
+ device: str = "cuda:0",
101
+ save_path: str = "svd_distill.safetensors",
102
+ ):
103
+ pipe_base = StableDiffusionPipeline.from_pretrained(
104
+ base_model, torch_dtype=torch.float16
105
+ ).to(device)
106
+
107
+ pipe_tuned = StableDiffusionPipeline.from_pretrained(
108
+ target_model, torch_dtype=torch.float16
109
+ ).to(device)
110
+
111
+ # Inject unet
112
+ _ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
113
+ _ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)
114
+
115
+ overwrite_base(
116
+ pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
117
+ )
118
+
119
+ # Inject text encoder
120
+ _ = inject_trainable_lora(
121
+ pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
122
+ )
123
+ _ = inject_trainable_lora(
124
+ pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
125
+ )
126
+
127
+ overwrite_base(
128
+ pipe_base.text_encoder,
129
+ pipe_tuned.text_encoder,
130
+ rank=rank,
131
+ clamp_quantile=clamp_quantile,
132
+ )
133
+
134
+ save_all(
135
+ unet=pipe_base.unet,
136
+ text_encoder=pipe_base.text_encoder,
137
+ placeholder_token_ids=None,
138
+ placeholder_tokens=None,
139
+ save_path=save_path,
140
+ save_lora=True,
141
+ save_ti=False,
142
+ )
143
+
144
+
145
+ def main():
146
+ fire.Fire(svd_distill)
lora_diffusion/dataset.py ADDED
@@ -0,0 +1,311 @@
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ from PIL import Image
6
+ from torch import zeros_like
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+ import glob
10
+ from .preprocess_files import face_mask_google_mediapipe
11
+
12
+ OBJECT_TEMPLATE = [
13
+ "a photo of a {}",
14
+ "a rendering of a {}",
15
+ "a cropped photo of the {}",
16
+ "the photo of a {}",
17
+ "a photo of a clean {}",
18
+ "a photo of a dirty {}",
19
+ "a dark photo of the {}",
20
+ "a photo of my {}",
21
+ "a photo of the cool {}",
22
+ "a close-up photo of a {}",
23
+ "a bright photo of the {}",
24
+ "a cropped photo of a {}",
25
+ "a photo of the {}",
26
+ "a good photo of the {}",
27
+ "a photo of one {}",
28
+ "a close-up photo of the {}",
29
+ "a rendition of the {}",
30
+ "a photo of the clean {}",
31
+ "a rendition of a {}",
32
+ "a photo of a nice {}",
33
+ "a good photo of a {}",
34
+ "a photo of the nice {}",
35
+ "a photo of the small {}",
36
+ "a photo of the weird {}",
37
+ "a photo of the large {}",
38
+ "a photo of a cool {}",
39
+ "a photo of a small {}",
40
+ ]
41
+
42
+ STYLE_TEMPLATE = [
43
+ "a painting in the style of {}",
44
+ "a rendering in the style of {}",
45
+ "a cropped painting in the style of {}",
46
+ "the painting in the style of {}",
47
+ "a clean painting in the style of {}",
48
+ "a dirty painting in the style of {}",
49
+ "a dark painting in the style of {}",
50
+ "a picture in the style of {}",
51
+ "a cool painting in the style of {}",
52
+ "a close-up painting in the style of {}",
53
+ "a bright painting in the style of {}",
54
+ "a cropped painting in the style of {}",
55
+ "a good painting in the style of {}",
56
+ "a close-up painting in the style of {}",
57
+ "a rendition in the style of {}",
58
+ "a nice painting in the style of {}",
59
+ "a small painting in the style of {}",
60
+ "a weird painting in the style of {}",
61
+ "a large painting in the style of {}",
62
+ ]
63
+
64
+ NULL_TEMPLATE = ["{}"]
65
+
66
+ TEMPLATE_MAP = {
67
+ "object": OBJECT_TEMPLATE,
68
+ "style": STYLE_TEMPLATE,
69
+ "null": NULL_TEMPLATE,
70
+ }
71
+
72
+
73
+ def _randomset(lis):
74
+ ret = []
75
+ for i in range(len(lis)):
76
+ if random.random() < 0.5:
77
+ ret.append(lis[i])
78
+ return ret
79
+
80
+
81
+ def _shuffle(lis):
82
+
83
+ return random.sample(lis, len(lis))
84
+
85
+
86
+ def _get_cutout_holes(
87
+ height,
88
+ width,
89
+ min_holes=8,
90
+ max_holes=32,
91
+ min_height=16,
92
+ max_height=128,
93
+ min_width=16,
94
+ max_width=128,
95
+ ):
96
+ holes = []
97
+ for _n in range(random.randint(min_holes, max_holes)):
98
+ hole_height = random.randint(min_height, max_height)
99
+ hole_width = random.randint(min_width, max_width)
100
+ y1 = random.randint(0, height - hole_height)
101
+ x1 = random.randint(0, width - hole_width)
102
+ y2 = y1 + hole_height
103
+ x2 = x1 + hole_width
104
+ holes.append((x1, y1, x2, y2))
105
+ return holes
106
+
107
+
108
+ def _generate_random_mask(image):
109
+ mask = zeros_like(image[:1])
110
+ holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
111
+ for (x1, y1, x2, y2) in holes:
112
+ mask[:, y1:y2, x1:x2] = 1.0
113
+ if random.uniform(0, 1) < 0.25:
114
+ mask.fill_(1.0)
115
+ masked_image = image * (mask < 0.5)
116
+ return mask, masked_image
117
+
118
+
119
+ class PivotalTuningDatasetCapation(Dataset):
120
+ """
121
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
122
+ It pre-processes the images and tokenizes the prompts.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ instance_data_root,
128
+ tokenizer,
129
+ token_map: Optional[dict] = None,
130
+ use_template: Optional[str] = None,
131
+ size=512,
132
+ h_flip=True,
133
+ color_jitter=False,
134
+ resize=True,
135
+ use_mask_captioned_data=False,
136
+ use_face_segmentation_condition=False,
137
+ train_inpainting=False,
138
+ blur_amount: int = 70,
139
+ ):
140
+ self.size = size
141
+ self.tokenizer = tokenizer
142
+ self.resize = resize
143
+ self.train_inpainting = train_inpainting
144
+
145
+ instance_data_root = Path(instance_data_root)
146
+ if not instance_data_root.exists():
147
+ raise ValueError("Instance images root doesn't exists.")
148
+
149
+ self.instance_images_path = []
150
+ self.mask_path = []
151
+
152
+ assert not (
153
+ use_mask_captioned_data and use_template
154
+ ), "Can't use both mask caption data and template."
155
+
156
+ # Prepare the instance images
157
+ if use_mask_captioned_data:
158
+ src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
159
+ for f in src_imgs:
160
+ idx = int(str(Path(f).stem).split(".")[0])
161
+ mask_path = f"{instance_data_root}/{idx}.mask.png"
162
+
163
+ if Path(mask_path).exists():
164
+ self.instance_images_path.append(f)
165
+ self.mask_path.append(mask_path)
166
+ else:
167
+ print(f"Mask not found for {f}")
168
+
169
+ self.captions = open(f"{instance_data_root}/caption.txt").readlines()
170
+
171
+ else:
172
+ possibily_src_images = (
173
+ glob.glob(str(instance_data_root) + "/*.jpg")
174
+ + glob.glob(str(instance_data_root) + "/*.png")
175
+ + glob.glob(str(instance_data_root) + "/*.jpeg")
176
+ )
177
+ possibily_src_images = (
178
+ set(possibily_src_images)
179
+ - set(glob.glob(str(instance_data_root) + "/*mask.png"))
180
+ - set([str(instance_data_root) + "/caption.txt"])
181
+ )
182
+
183
+ self.instance_images_path = list(set(possibily_src_images))
184
+ self.captions = [
185
+ x.split("/")[-1].split(".")[0] for x in self.instance_images_path
186
+ ]
187
+
188
+ assert (
189
+ len(self.instance_images_path) > 0
190
+ ), "No images found in the instance data root."
191
+
192
+ self.instance_images_path = sorted(self.instance_images_path)
193
+
194
+ self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
195
+ self.use_mask_captioned_data = use_mask_captioned_data
196
+
197
+ if use_face_segmentation_condition:
198
+
199
+ for idx in range(len(self.instance_images_path)):
200
+ targ = f"{instance_data_root}/{idx}.mask.png"
201
+ # see if the mask exists
202
+ if not Path(targ).exists():
203
+ print(f"Mask not found for {targ}")
204
+
205
+ print(
206
+ "Warning : this will pre-process all the images in the instance data root."
207
+ )
208
+
209
+ if len(self.mask_path) > 0:
210
+ print(
211
+ "Warning : masks already exists, but will be overwritten."
212
+ )
213
+
214
+ masks = face_mask_google_mediapipe(
215
+ [
216
+ Image.open(f).convert("RGB")
217
+ for f in self.instance_images_path
218
+ ]
219
+ )
220
+ for idx, mask in enumerate(masks):
221
+ mask.save(f"{instance_data_root}/{idx}.mask.png")
222
+
223
+ break
224
+
225
+ for idx in range(len(self.instance_images_path)):
226
+ self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")
227
+
228
+ self.num_instance_images = len(self.instance_images_path)
229
+ self.token_map = token_map
230
+
231
+ self.use_template = use_template
232
+ if use_template is not None:
233
+ self.templates = TEMPLATE_MAP[use_template]
234
+
235
+ self._length = self.num_instance_images
236
+
237
+ self.h_flip = h_flip
238
+ self.image_transforms = transforms.Compose(
239
+ [
240
+ transforms.Resize(
241
+ size, interpolation=transforms.InterpolationMode.BILINEAR
242
+ )
243
+ if resize
244
+ else transforms.Lambda(lambda x: x),
245
+ transforms.ColorJitter(0.1, 0.1)
246
+ if color_jitter
247
+ else transforms.Lambda(lambda x: x),
248
+ transforms.CenterCrop(size),
249
+ transforms.ToTensor(),
250
+ transforms.Normalize([0.5], [0.5]),
251
+ ]
252
+ )
253
+
254
+ self.blur_amount = blur_amount
255
+
256
+ def __len__(self):
257
+ return self._length
258
+
259
+ def __getitem__(self, index):
260
+ example = {}
261
+ instance_image = Image.open(
262
+ self.instance_images_path[index % self.num_instance_images]
263
+ )
264
+ if not instance_image.mode == "RGB":
265
+ instance_image = instance_image.convert("RGB")
266
+ example["instance_images"] = self.image_transforms(instance_image)
267
+
268
+ if self.train_inpainting:
269
+ (
270
+ example["instance_masks"],
271
+ example["instance_masked_images"],
272
+ ) = _generate_random_mask(example["instance_images"])
273
+
274
+ if self.use_template:
275
+ assert self.token_map is not None
276
+ input_tok = list(self.token_map.values())[0]
277
+
278
+ text = random.choice(self.templates).format(input_tok)
279
+ else:
280
+ text = self.captions[index % self.num_instance_images].strip()
281
+
282
+ if self.token_map is not None:
283
+ for token, value in self.token_map.items():
284
+ text = text.replace(token, value)
285
+
286
+ print(text)
287
+
288
+ if self.use_mask:
289
+ example["mask"] = (
290
+ self.image_transforms(
291
+ Image.open(self.mask_path[index % self.num_instance_images])
292
+ )
293
+ * 0.5
294
+ + 1.0
295
+ )
296
+
297
+ if self.h_flip and random.random() > 0.5:
298
+ hflip = transforms.RandomHorizontalFlip(p=1)
299
+
300
+ example["instance_images"] = hflip(example["instance_images"])
301
+ if self.use_mask:
302
+ example["mask"] = hflip(example["mask"])
303
+
304
+ example["instance_prompt_ids"] = self.tokenizer(
305
+ text,
306
+ padding="do_not_pad",
307
+ truncation=True,
308
+ max_length=self.tokenizer.model_max_length,
309
+ ).input_ids
310
+
311
+ return example
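
For orientation, a minimal usage sketch of the dataset above, assuming a CLIP tokenizer is available; the folder path, placeholder token mapping, and token names are hypothetical, not part of this commit:

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = PivotalTuningDatasetCapation(
        instance_data_root="./instance-images",   # hypothetical folder of .jpg/.png files
        tokenizer=tokenizer,
        token_map={"<s>": "<s1>"},                # template token -> trained token (placeholder names)
        use_template="object",                    # draws prompts from OBJECT_TEMPLATE
        size=512,
    )
    example = dataset[0]  # dict with "instance_images" and "instance_prompt_ids"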
lora_diffusion/lora.py ADDED
@@ -0,0 +1,1110 @@
1
+ import json
2
+ import math
3
+ from itertools import groupby
4
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ try:
13
+ from safetensors.torch import safe_open
14
+ from safetensors.torch import save_file as safe_save
15
+
16
+ safetensors_available = True
17
+ except ImportError:
18
+ from .safe_open import safe_open
19
+
20
+ def safe_save(
21
+ tensors: Dict[str, torch.Tensor],
22
+ filename: str,
23
+ metadata: Optional[Dict[str, str]] = None,
24
+ ) -> None:
25
+ raise EnvironmentError(
26
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
27
+ )
28
+
29
+ safetensors_available = False
30
+
31
+
32
+ class LoraInjectedLinear(nn.Module):
33
+ def __init__(
34
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
35
+ ):
36
+ super().__init__()
37
+
38
+ if r > min(in_features, out_features):
39
+ raise ValueError(
40
+ f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
41
+ )
42
+ self.r = r
43
+ self.linear = nn.Linear(in_features, out_features, bias)
44
+ self.lora_down = nn.Linear(in_features, r, bias=False)
45
+ self.dropout = nn.Dropout(dropout_p)
46
+ self.lora_up = nn.Linear(r, out_features, bias=False)
47
+ self.scale = scale
48
+ self.selector = nn.Identity()
49
+
50
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
51
+ nn.init.zeros_(self.lora_up.weight)
52
+
53
+ def forward(self, input):
54
+ return (
55
+ self.linear(input)
56
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
57
+ * self.scale
58
+ )
59
+
60
+ def realize_as_lora(self):
61
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
62
+
63
+ def set_selector_from_diag(self, diag: torch.Tensor):
64
+ # diag is a 1D tensor of size (r,)
65
+ assert diag.shape == (self.r,)
66
+ self.selector = nn.Linear(self.r, self.r, bias=False)
67
+ self.selector.weight.data = torch.diag(diag)
68
+ self.selector.weight.data = self.selector.weight.data.to(
69
+ self.lora_up.weight.device
70
+ ).to(self.lora_up.weight.dtype)
71
+
72
+
73
+ class LoraInjectedConv2d(nn.Module):
74
+ def __init__(
75
+ self,
76
+ in_channels: int,
77
+ out_channels: int,
78
+ kernel_size,
79
+ stride=1,
80
+ padding=0,
81
+ dilation=1,
82
+ groups: int = 1,
83
+ bias: bool = True,
84
+ r: int = 4,
85
+ dropout_p: float = 0.1,
86
+ scale: float = 1.0,
87
+ ):
88
+ super().__init__()
89
+ if r > min(in_channels, out_channels):
90
+ raise ValueError(
91
+ f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
92
+ )
93
+ self.r = r
94
+ self.conv = nn.Conv2d(
95
+ in_channels=in_channels,
96
+ out_channels=out_channels,
97
+ kernel_size=kernel_size,
98
+ stride=stride,
99
+ padding=padding,
100
+ dilation=dilation,
101
+ groups=groups,
102
+ bias=bias,
103
+ )
104
+
105
+ self.lora_down = nn.Conv2d(
106
+ in_channels=in_channels,
107
+ out_channels=r,
108
+ kernel_size=kernel_size,
109
+ stride=stride,
110
+ padding=padding,
111
+ dilation=dilation,
112
+ groups=groups,
113
+ bias=False,
114
+ )
115
+ self.dropout = nn.Dropout(dropout_p)
116
+ self.lora_up = nn.Conv2d(
117
+ in_channels=r,
118
+ out_channels=out_channels,
119
+ kernel_size=1,
120
+ stride=1,
121
+ padding=0,
122
+ bias=False,
123
+ )
124
+ self.selector = nn.Identity()
125
+ self.scale = scale
126
+
127
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
128
+ nn.init.zeros_(self.lora_up.weight)
129
+
130
+ def forward(self, input):
131
+ return (
132
+ self.conv(input)
133
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
134
+ * self.scale
135
+ )
136
+
137
+ def realize_as_lora(self):
138
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
139
+
140
+ def set_selector_from_diag(self, diag: torch.Tensor):
141
+ # diag is a 1D tensor of size (r,)
142
+ assert diag.shape == (self.r,)
143
+ self.selector = nn.Conv2d(
144
+ in_channels=self.r,
145
+ out_channels=self.r,
146
+ kernel_size=1,
147
+ stride=1,
148
+ padding=0,
149
+ bias=False,
150
+ )
151
+ self.selector.weight.data = torch.diag(diag)
152
+
153
+ # same device + dtype as lora_up
154
+ self.selector.weight.data = self.selector.weight.data.to(
155
+ self.lora_up.weight.device
156
+ ).to(self.lora_up.weight.dtype)
157
+
158
+
159
+ UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
160
+
161
+ UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
162
+
163
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
164
+
165
+ TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
166
+
167
+ DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
168
+
169
+ EMBED_FLAG = "<embed>"
170
+
171
+
172
+ def _find_children(
173
+ model,
174
+ search_class: List[Type[nn.Module]] = [nn.Linear],
175
+ ):
176
+ """
177
+ Find all modules of a certain class (or union of classes).
178
+
179
+ Returns all matching modules, along with the parent of those moduless and the
180
+ names they are referenced by.
181
+ """
182
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
183
+ for parent in model.modules():
184
+ for name, module in parent.named_children():
185
+ if any([isinstance(module, _class) for _class in search_class]):
186
+ yield parent, name, module
187
+
188
+
189
+ def _find_modules_v2(
190
+ model,
191
+ ancestor_class: Optional[Set[str]] = None,
192
+ search_class: List[Type[nn.Module]] = [nn.Linear],
193
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [
194
+ LoraInjectedLinear,
195
+ LoraInjectedConv2d,
196
+ ],
197
+ ):
198
+ """
199
+ Find all modules of a certain class (or union of classes) that are direct or
200
+ indirect descendants of other modules of a certain class (or union of classes).
201
+
202
+ Returns all matching modules, along with the parent of those moduless and the
203
+ names they are referenced by.
204
+ """
205
+
206
+ # Get the targets we should replace all linears under
207
+ if ancestor_class is not None:
208
+ ancestors = (
209
+ module
210
+ for module in model.modules()
211
+ if module.__class__.__name__ in ancestor_class
212
+ )
213
+ else:
214
+ # this, incase you want to naively iterate over all modules.
215
+ ancestors = [module for module in model.modules()]
216
+
217
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
218
+ for ancestor in ancestors:
219
+ for fullname, module in ancestor.named_modules():
220
+ if any([isinstance(module, _class) for _class in search_class]):
221
+ # Find the direct parent if this is a descendant, not a child, of target
222
+ *path, name = fullname.split(".")
223
+ parent = ancestor
224
+ while path:
225
+ parent = parent.get_submodule(path.pop(0))
226
+ # Skip this linear if it's a child of a LoraInjectedLinear
227
+ if exclude_children_of and any(
228
+ [isinstance(parent, _class) for _class in exclude_children_of]
229
+ ):
230
+ continue
231
+ # Otherwise, yield it
232
+ yield parent, name, module
233
+
234
+
235
+ def _find_modules_old(
236
+ model,
237
+ ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
238
+ search_class: List[Type[nn.Module]] = [nn.Linear],
239
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
240
+ ):
241
+ ret = []
242
+ for _module in model.modules():
243
+ if _module.__class__.__name__ in ancestor_class:
244
+
245
+ for name, _child_module in _module.named_modules():
246
+ if _child_module.__class__ in search_class:
247
+ ret.append((_module, name, _child_module))
248
+ print(ret)
249
+ return ret
250
+
251
+
252
+ _find_modules = _find_modules_v2
253
+
254
+
255
+ def inject_trainable_lora(
256
+ model: nn.Module,
257
+ target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
258
+ r: int = 4,
259
+ loras=None, # path to lora .pt
260
+ verbose: bool = False,
261
+ dropout_p: float = 0.0,
262
+ scale: float = 1.0,
263
+ ):
264
+ """
265
+ Inject LoRA into the model and return the LoRA parameter groups.
266
+ """
267
+
268
+ require_grad_params = []
269
+ names = []
270
+
271
+ if loras is not None:
272
+ loras = torch.load(loras)
273
+
274
+ for _module, name, _child_module in _find_modules(
275
+ model, target_replace_module, search_class=[nn.Linear]
276
+ ):
277
+ weight = _child_module.weight
278
+ bias = _child_module.bias
279
+ if verbose:
280
+ print("LoRA Injection : injecting lora into ", name)
281
+ print("LoRA Injection : weight shape", weight.shape)
282
+ _tmp = LoraInjectedLinear(
283
+ _child_module.in_features,
284
+ _child_module.out_features,
285
+ _child_module.bias is not None,
286
+ r=r,
287
+ dropout_p=dropout_p,
288
+ scale=scale,
289
+ )
290
+ _tmp.linear.weight = weight
291
+ if bias is not None:
292
+ _tmp.linear.bias = bias
293
+
294
+ # switch the module
295
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
296
+ _module._modules[name] = _tmp
297
+
298
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
299
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
300
+
301
+ if loras is not None:
302
+ _module._modules[name].lora_up.weight = loras.pop(0)
303
+ _module._modules[name].lora_down.weight = loras.pop(0)
304
+
305
+ _module._modules[name].lora_up.weight.requires_grad = True
306
+ _module._modules[name].lora_down.weight.requires_grad = True
307
+ names.append(name)
308
+
309
+ return require_grad_params, names
310
+
311
+
312
+ def inject_trainable_lora_extended(
313
+ model: nn.Module,
314
+ target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
315
+ r: int = 4,
316
+ loras=None, # path to lora .pt
317
+ ):
318
+ """
319
+ Inject LoRA into the model and return the LoRA parameter groups.
320
+ """
321
+
322
+ require_grad_params = []
323
+ names = []
324
+
325
+ if loras is not None:
326
+ loras = torch.load(loras)
327
+
328
+ for _module, name, _child_module in _find_modules(
329
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
330
+ ):
331
+ if _child_module.__class__ == nn.Linear:
332
+ weight = _child_module.weight
333
+ bias = _child_module.bias
334
+ _tmp = LoraInjectedLinear(
335
+ _child_module.in_features,
336
+ _child_module.out_features,
337
+ _child_module.bias is not None,
338
+ r=r,
339
+ )
340
+ _tmp.linear.weight = weight
341
+ if bias is not None:
342
+ _tmp.linear.bias = bias
343
+ elif _child_module.__class__ == nn.Conv2d:
344
+ weight = _child_module.weight
345
+ bias = _child_module.bias
346
+ _tmp = LoraInjectedConv2d(
347
+ _child_module.in_channels,
348
+ _child_module.out_channels,
349
+ _child_module.kernel_size,
350
+ _child_module.stride,
351
+ _child_module.padding,
352
+ _child_module.dilation,
353
+ _child_module.groups,
354
+ _child_module.bias is not None,
355
+ r=r,
356
+ )
357
+
358
+ _tmp.conv.weight = weight
359
+ if bias is not None:
360
+ _tmp.conv.bias = bias
361
+
362
+ # switch the module
363
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
364
+ if bias is not None:
365
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
366
+
367
+ _module._modules[name] = _tmp
368
+
369
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
370
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
371
+
372
+ if loras is not None:
373
+ _module._modules[name].lora_up.weight = loras.pop(0)
374
+ _module._modules[name].lora_down.weight = loras.pop(0)
375
+
376
+ _module._modules[name].lora_up.weight.requires_grad = True
377
+ _module._modules[name].lora_down.weight.requires_grad = True
378
+ names.append(name)
379
+
380
+ return require_grad_params, names
381
+
382
+
383
+ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
384
+
385
+ loras = []
386
+
387
+ for _m, _n, _child_module in _find_modules(
388
+ model,
389
+ target_replace_module,
390
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d],
391
+ ):
392
+ loras.append((_child_module.lora_up, _child_module.lora_down))
393
+
394
+ if len(loras) == 0:
395
+ raise ValueError("No lora injected.")
396
+
397
+ return loras
398
+
399
+
400
+ def extract_lora_as_tensor(
401
+ model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
402
+ ):
403
+
404
+ loras = []
405
+
406
+ for _m, _n, _child_module in _find_modules(
407
+ model,
408
+ target_replace_module,
409
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d],
410
+ ):
411
+ up, down = _child_module.realize_as_lora()
412
+ if as_fp16:
413
+ up = up.to(torch.float16)
414
+ down = down.to(torch.float16)
415
+
416
+ loras.append((up, down))
417
+
418
+ if len(loras) == 0:
419
+ raise ValueError("No lora injected.")
420
+
421
+ return loras
422
+
423
+
424
+ def save_lora_weight(
425
+ model,
426
+ path="./lora.pt",
427
+ target_replace_module=DEFAULT_TARGET_REPLACE,
428
+ ):
429
+ weights = []
430
+ for _up, _down in extract_lora_ups_down(
431
+ model, target_replace_module=target_replace_module
432
+ ):
433
+ weights.append(_up.weight.to("cpu").to(torch.float16))
434
+ weights.append(_down.weight.to("cpu").to(torch.float16))
435
+
436
+ torch.save(weights, path)
437
+
438
+
439
+ def save_lora_as_json(model, path="./lora.json"):
440
+ weights = []
441
+ for _up, _down in extract_lora_ups_down(model):
442
+ weights.append(_up.weight.detach().cpu().numpy().tolist())
443
+ weights.append(_down.weight.detach().cpu().numpy().tolist())
444
+
445
+ import json
446
+
447
+ with open(path, "w") as f:
448
+ json.dump(weights, f)
449
+
450
+
451
+ def save_safeloras_with_embeds(
452
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
453
+ embeds: Dict[str, torch.Tensor] = {},
454
+ outpath="./lora.safetensors",
455
+ ):
456
+ """
457
+ Saves the Lora from multiple modules in a single safetensor file.
458
+
459
+ modelmap is a dictionary of {
460
+ "module name": (module, target_replace_module)
461
+ }
462
+ """
463
+ weights = {}
464
+ metadata = {}
465
+
466
+ for name, (model, target_replace_module) in modelmap.items():
467
+ metadata[name] = json.dumps(list(target_replace_module))
468
+
469
+ for i, (_up, _down) in enumerate(
470
+ extract_lora_as_tensor(model, target_replace_module)
471
+ ):
472
+ rank = _down.shape[0]
473
+
474
+ metadata[f"{name}:{i}:rank"] = str(rank)
475
+ weights[f"{name}:{i}:up"] = _up
476
+ weights[f"{name}:{i}:down"] = _down
477
+
478
+ for token, tensor in embeds.items():
479
+ metadata[token] = EMBED_FLAG
480
+ weights[token] = tensor
481
+
482
+ print(f"Saving weights to {outpath}")
483
+ safe_save(weights, outpath, metadata)
484
+
485
+
486
+ def save_safeloras(
487
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
488
+ outpath="./lora.safetensors",
489
+ ):
490
+ return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
491
+
492
+
493
+ def convert_loras_to_safeloras_with_embeds(
494
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
495
+ embeds: Dict[str, torch.Tensor] = {},
496
+ outpath="./lora.safetensors",
497
+ ):
498
+ """
499
+ Converts the Lora from multiple pytorch .pt files into a single safetensor file.
500
+
501
+ modelmap is a dictionary of {
502
+ "module name": (pytorch_model_path, target_replace_module, rank)
503
+ }
504
+ """
505
+
506
+ weights = {}
507
+ metadata = {}
508
+
509
+ for name, (path, target_replace_module, r) in modelmap.items():
510
+ metadata[name] = json.dumps(list(target_replace_module))
511
+
512
+ lora = torch.load(path)
513
+ for i, weight in enumerate(lora):
514
+ is_up = i % 2 == 0
515
+ i = i // 2
516
+
517
+ if is_up:
518
+ metadata[f"{name}:{i}:rank"] = str(r)
519
+ weights[f"{name}:{i}:up"] = weight
520
+ else:
521
+ weights[f"{name}:{i}:down"] = weight
522
+
523
+ for token, tensor in embeds.items():
524
+ metadata[token] = EMBED_FLAG
525
+ weights[token] = tensor
526
+
527
+ print(f"Saving weights to {outpath}")
528
+ safe_save(weights, outpath, metadata)
529
+
530
+
531
+ def convert_loras_to_safeloras(
532
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
533
+ outpath="./lora.safetensors",
534
+ ):
535
+ convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
536
+
537
+
538
+ def parse_safeloras(
539
+ safeloras,
540
+ ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
541
+ """
542
+ Converts a loaded safetensor file that contains a set of module Loras
543
+ into Parameters and other information
544
+
545
+ Output is a dictionary of {
546
+ "module name": (
547
+ [list of weights],
548
+ [list of ranks],
549
+ target_replacement_modules
550
+ )
551
+ }
552
+ """
553
+ loras = {}
554
+ metadata = safeloras.metadata()
555
+
556
+ get_name = lambda k: k.split(":")[0]
557
+
558
+ keys = list(safeloras.keys())
559
+ keys.sort(key=get_name)
560
+
561
+ for name, module_keys in groupby(keys, get_name):
562
+ info = metadata.get(name)
563
+
564
+ if not info:
565
+ raise ValueError(
566
+ f"Tensor {name} has no metadata - is this a Lora safetensor?"
567
+ )
568
+
569
+ # Skip Textual Inversion embeds
570
+ if info == EMBED_FLAG:
571
+ continue
572
+
573
+ # Handle Loras
574
+ # Extract the targets
575
+ target = json.loads(info)
576
+
577
+ # Build the result lists - Python needs us to preallocate lists to insert into them
578
+ module_keys = list(module_keys)
579
+ ranks = [4] * (len(module_keys) // 2)
580
+ weights = [None] * len(module_keys)
581
+
582
+ for key in module_keys:
583
+ # Split the model name and index out of the key
584
+ _, idx, direction = key.split(":")
585
+ idx = int(idx)
586
+
587
+ # Add the rank
588
+ ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
589
+
590
+ # Insert the weight into the list
591
+ idx = idx * 2 + (1 if direction == "down" else 0)
592
+ weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
593
+
594
+ loras[name] = (weights, ranks, target)
595
+
596
+ return loras
597
+
598
+
599
+ def parse_safeloras_embeds(
600
+ safeloras,
601
+ ) -> Dict[str, torch.Tensor]:
602
+ """
603
+ Converts a loaded safetensor file that contains Textual Inversion embeds into
604
+ a dictionary of embed_token: Tensor
605
+ """
606
+ embeds = {}
607
+ metadata = safeloras.metadata()
608
+
609
+ for key in safeloras.keys():
610
+ # Only handle Textual Inversion embeds
611
+ meta = metadata.get(key)
612
+ if not meta or meta != EMBED_FLAG:
613
+ continue
614
+
615
+ embeds[key] = safeloras.get_tensor(key)
616
+
617
+ return embeds
618
+
619
+
620
+ def load_safeloras(path, device="cpu"):
621
+ safeloras = safe_open(path, framework="pt", device=device)
622
+ return parse_safeloras(safeloras)
623
+
624
+
625
+ def load_safeloras_embeds(path, device="cpu"):
626
+ safeloras = safe_open(path, framework="pt", device=device)
627
+ return parse_safeloras_embeds(safeloras)
628
+
629
+
630
+ def load_safeloras_both(path, device="cpu"):
631
+ safeloras = safe_open(path, framework="pt", device=device)
632
+ return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
633
+
634
+
635
+ def collapse_lora(model, alpha=1.0):
636
+
637
+ for _module, name, _child_module in _find_modules(
638
+ model,
639
+ UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
640
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d],
641
+ ):
642
+
643
+ if isinstance(_child_module, LoraInjectedLinear):
644
+ print("Collapsing Lin Lora in", name)
645
+
646
+ _child_module.linear.weight = nn.Parameter(
647
+ _child_module.linear.weight.data
648
+ + alpha
649
+ * (
650
+ _child_module.lora_up.weight.data
651
+ @ _child_module.lora_down.weight.data
652
+ )
653
+ .type(_child_module.linear.weight.dtype)
654
+ .to(_child_module.linear.weight.device)
655
+ )
656
+
657
+ else:
658
+ print("Collapsing Conv Lora in", name)
659
+ _child_module.conv.weight = nn.Parameter(
660
+ _child_module.conv.weight.data
661
+ + alpha
662
+ * (
663
+ _child_module.lora_up.weight.data.flatten(start_dim=1)
664
+ @ _child_module.lora_down.weight.data.flatten(start_dim=1)
665
+ )
666
+ .reshape(_child_module.conv.weight.data.shape)
667
+ .type(_child_module.conv.weight.dtype)
668
+ .to(_child_module.conv.weight.device)
669
+ )
670
+
671
+
672
+ def monkeypatch_or_replace_lora(
673
+ model,
674
+ loras,
675
+ target_replace_module=DEFAULT_TARGET_REPLACE,
676
+ r: Union[int, List[int]] = 4,
677
+ ):
678
+ for _module, name, _child_module in _find_modules(
679
+ model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
680
+ ):
681
+ _source = (
682
+ _child_module.linear
683
+ if isinstance(_child_module, LoraInjectedLinear)
684
+ else _child_module
685
+ )
686
+
687
+ weight = _source.weight
688
+ bias = _source.bias
689
+ _tmp = LoraInjectedLinear(
690
+ _source.in_features,
691
+ _source.out_features,
692
+ _source.bias is not None,
693
+ r=r.pop(0) if isinstance(r, list) else r,
694
+ )
695
+ _tmp.linear.weight = weight
696
+
697
+ if bias is not None:
698
+ _tmp.linear.bias = bias
699
+
700
+ # switch the module
701
+ _module._modules[name] = _tmp
702
+
703
+ up_weight = loras.pop(0)
704
+ down_weight = loras.pop(0)
705
+
706
+ _module._modules[name].lora_up.weight = nn.Parameter(
707
+ up_weight.type(weight.dtype)
708
+ )
709
+ _module._modules[name].lora_down.weight = nn.Parameter(
710
+ down_weight.type(weight.dtype)
711
+ )
712
+
713
+ _module._modules[name].to(weight.device)
714
+
715
+
716
+ def monkeypatch_or_replace_lora_extended(
717
+ model,
718
+ loras,
719
+ target_replace_module=DEFAULT_TARGET_REPLACE,
720
+ r: Union[int, List[int]] = 4,
721
+ ):
722
+ for _module, name, _child_module in _find_modules(
723
+ model,
724
+ target_replace_module,
725
+ search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
726
+ ):
727
+
728
+ if (_child_module.__class__ == nn.Linear) or (
729
+ _child_module.__class__ == LoraInjectedLinear
730
+ ):
731
+ if len(loras[0].shape) != 2:
732
+ continue
733
+
734
+ _source = (
735
+ _child_module.linear
736
+ if isinstance(_child_module, LoraInjectedLinear)
737
+ else _child_module
738
+ )
739
+
740
+ weight = _source.weight
741
+ bias = _source.bias
742
+ _tmp = LoraInjectedLinear(
743
+ _source.in_features,
744
+ _source.out_features,
745
+ _source.bias is not None,
746
+ r=r.pop(0) if isinstance(r, list) else r,
747
+ )
748
+ _tmp.linear.weight = weight
749
+
750
+ if bias is not None:
751
+ _tmp.linear.bias = bias
752
+
753
+ elif (_child_module.__class__ == nn.Conv2d) or (
754
+ _child_module.__class__ == LoraInjectedConv2d
755
+ ):
756
+ if len(loras[0].shape) != 4:
757
+ continue
758
+ _source = (
759
+ _child_module.conv
760
+ if isinstance(_child_module, LoraInjectedConv2d)
761
+ else _child_module
762
+ )
763
+
764
+ weight = _source.weight
765
+ bias = _source.bias
766
+ _tmp = LoraInjectedConv2d(
767
+ _source.in_channels,
768
+ _source.out_channels,
769
+ _source.kernel_size,
770
+ _source.stride,
771
+ _source.padding,
772
+ _source.dilation,
773
+ _source.groups,
774
+ _source.bias is not None,
775
+ r=r.pop(0) if isinstance(r, list) else r,
776
+ )
777
+
778
+ _tmp.conv.weight = weight
779
+
780
+ if bias is not None:
781
+ _tmp.conv.bias = bias
782
+
783
+ # switch the module
784
+ _module._modules[name] = _tmp
785
+
786
+ up_weight = loras.pop(0)
787
+ down_weight = loras.pop(0)
788
+
789
+ _module._modules[name].lora_up.weight = nn.Parameter(
790
+ up_weight.type(weight.dtype)
791
+ )
792
+ _module._modules[name].lora_down.weight = nn.Parameter(
793
+ down_weight.type(weight.dtype)
794
+ )
795
+
796
+ _module._modules[name].to(weight.device)
797
+
798
+
799
+ def monkeypatch_or_replace_safeloras(models, safeloras):
800
+ loras = parse_safeloras(safeloras)
801
+
802
+ for name, (lora, ranks, target) in loras.items():
803
+ model = getattr(models, name, None)
804
+
805
+ if not model:
806
+ print(f"No model provided for {name}, contained in Lora")
807
+ continue
808
+
809
+ monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
810
+
811
+
812
+ def monkeypatch_remove_lora(model):
813
+ for _module, name, _child_module in _find_modules(
814
+ model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
815
+ ):
816
+ if isinstance(_child_module, LoraInjectedLinear):
817
+ _source = _child_module.linear
818
+ weight, bias = _source.weight, _source.bias
819
+
820
+ _tmp = nn.Linear(
821
+ _source.in_features, _source.out_features, bias is not None
822
+ )
823
+
824
+ _tmp.weight = weight
825
+ if bias is not None:
826
+ _tmp.bias = bias
827
+
828
+ else:
829
+ _source = _child_module.conv
830
+ weight, bias = _source.weight, _source.bias
831
+
832
+ _tmp = nn.Conv2d(
833
+ in_channels=_source.in_channels,
834
+ out_channels=_source.out_channels,
835
+ kernel_size=_source.kernel_size,
836
+ stride=_source.stride,
837
+ padding=_source.padding,
838
+ dilation=_source.dilation,
839
+ groups=_source.groups,
840
+ bias=bias is not None,
841
+ )
842
+
843
+ _tmp.weight = weight
844
+ if bias is not None:
845
+ _tmp.bias = bias
846
+
847
+ _module._modules[name] = _tmp
848
+
849
+
850
+ def monkeypatch_add_lora(
851
+ model,
852
+ loras,
853
+ target_replace_module=DEFAULT_TARGET_REPLACE,
854
+ alpha: float = 1.0,
855
+ beta: float = 1.0,
856
+ ):
857
+ for _module, name, _child_module in _find_modules(
858
+ model, target_replace_module, search_class=[LoraInjectedLinear]
859
+ ):
860
+ weight = _child_module.linear.weight
861
+
862
+ up_weight = loras.pop(0)
863
+ down_weight = loras.pop(0)
864
+
865
+ _module._modules[name].lora_up.weight = nn.Parameter(
866
+ up_weight.type(weight.dtype).to(weight.device) * alpha
867
+ + _module._modules[name].lora_up.weight.to(weight.device) * beta
868
+ )
869
+ _module._modules[name].lora_down.weight = nn.Parameter(
870
+ down_weight.type(weight.dtype).to(weight.device) * alpha
871
+ + _module._modules[name].lora_down.weight.to(weight.device) * beta
872
+ )
873
+
874
+ _module._modules[name].to(weight.device)
875
+
876
+
877
+ def tune_lora_scale(model, alpha: float = 1.0):
878
+ for _module in model.modules():
879
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
880
+ _module.scale = alpha
881
+
882
+
883
+ def set_lora_diag(model, diag: torch.Tensor):
884
+ for _module in model.modules():
885
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
886
+ _module.set_selector_from_diag(diag)
887
+
888
+
889
+ def _text_lora_path(path: str) -> str:
890
+ assert path.endswith(".pt"), "Only .pt files are supported"
891
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
892
+
893
+
894
+ def _ti_lora_path(path: str) -> str:
895
+ assert path.endswith(".pt"), "Only .pt files are supported"
896
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
897
+
898
+
899
+ def apply_learned_embed_in_clip(
900
+ learned_embeds,
901
+ text_encoder,
902
+ tokenizer,
903
+ token: Optional[Union[str, List[str]]] = None,
904
+ idempotent=False,
905
+ ):
906
+ if isinstance(token, str):
907
+ trained_tokens = [token]
908
+ elif isinstance(token, list):
909
+ assert len(learned_embeds.keys()) == len(
910
+ token
911
+ ), "The number of tokens and the number of embeds should be the same"
912
+ trained_tokens = token
913
+ else:
914
+ trained_tokens = list(learned_embeds.keys())
915
+
916
+ for token in trained_tokens:
917
+ print(token)
918
+ embeds = learned_embeds[token]
919
+
920
+ # cast to dtype of text_encoder
921
+ dtype = text_encoder.get_input_embeddings().weight.dtype
922
+ num_added_tokens = tokenizer.add_tokens(token)
923
+
924
+ i = 1
925
+ if not idempotent:
926
+ while num_added_tokens == 0:
927
+ print(f"The tokenizer already contains the token {token}.")
928
+ token = f"{token[:-1]}-{i}>"
929
+ print(f"Attempting to add the token {token}.")
930
+ num_added_tokens = tokenizer.add_tokens(token)
931
+ i += 1
932
+ elif num_added_tokens == 0 and idempotent:
933
+ print(f"The tokenizer already contains the token {token}.")
934
+ print(f"Replacing {token} embedding.")
935
+
936
+ # resize the token embeddings
937
+ text_encoder.resize_token_embeddings(len(tokenizer))
938
+
939
+ # get the id for the token and assign the embeds
940
+ token_id = tokenizer.convert_tokens_to_ids(token)
941
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
942
+ return token
943
+
944
+
945
+ def load_learned_embed_in_clip(
946
+ learned_embeds_path,
947
+ text_encoder,
948
+ tokenizer,
949
+ token: Optional[Union[str, List[str]]] = None,
950
+ idempotent=False,
951
+ ):
952
+ learned_embeds = torch.load(learned_embeds_path)
953
+ apply_learned_embed_in_clip(
954
+ learned_embeds, text_encoder, tokenizer, token, idempotent
955
+ )
956
+
957
+
958
+ def patch_pipe(
959
+ pipe,
960
+ maybe_unet_path,
961
+ token: Optional[str] = None,
962
+ r: int = 4,
963
+ patch_unet=True,
964
+ patch_text=True,
965
+ patch_ti=True,
966
+ idempotent_token=True,
967
+ unet_target_replace_module=DEFAULT_TARGET_REPLACE,
968
+ text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
969
+ ):
970
+ if maybe_unet_path.endswith(".pt"):
971
+ # torch format
972
+
973
+ if maybe_unet_path.endswith(".ti.pt"):
974
+ unet_path = maybe_unet_path[:-6] + ".pt"
975
+ elif maybe_unet_path.endswith(".text_encoder.pt"):
976
+ unet_path = maybe_unet_path[:-16] + ".pt"
977
+ else:
978
+ unet_path = maybe_unet_path
979
+
980
+ ti_path = _ti_lora_path(unet_path)
981
+ text_path = _text_lora_path(unet_path)
982
+
983
+ if patch_unet:
984
+ print("LoRA : Patching Unet")
985
+ monkeypatch_or_replace_lora(
986
+ pipe.unet,
987
+ torch.load(unet_path),
988
+ r=r,
989
+ target_replace_module=unet_target_replace_module,
990
+ )
991
+
992
+ if patch_text:
993
+ print("LoRA : Patching text encoder")
994
+ monkeypatch_or_replace_lora(
995
+ pipe.text_encoder,
996
+ torch.load(text_path),
997
+ target_replace_module=text_target_replace_module,
998
+ r=r,
999
+ )
1000
+ if patch_ti:
1001
+ print("LoRA : Patching token input")
1002
+ token = load_learned_embed_in_clip(
1003
+ ti_path,
1004
+ pipe.text_encoder,
1005
+ pipe.tokenizer,
1006
+ token=token,
1007
+ idempotent=idempotent_token,
1008
+ )
1009
+
1010
+ elif maybe_unet_path.endswith(".safetensors"):
1011
+ safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
1012
+ monkeypatch_or_replace_safeloras(pipe, safeloras)
1013
+ tok_dict = parse_safeloras_embeds(safeloras)
1014
+ if patch_ti:
1015
+ apply_learned_embed_in_clip(
1016
+ tok_dict,
1017
+ pipe.text_encoder,
1018
+ pipe.tokenizer,
1019
+ token=token,
1020
+ idempotent=idempotent_token,
1021
+ )
1022
+ return tok_dict
1023
+
1024
+
1025
+ @torch.no_grad()
1026
+ def inspect_lora(model):
1027
+ moved = {}
1028
+
1029
+ for name, _module in model.named_modules():
1030
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
1031
+ ups = _module.lora_up.weight.data.clone()
1032
+ downs = _module.lora_down.weight.data.clone()
1033
+
1034
+ wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
1035
+
1036
+ dist = wght.flatten().abs().mean().item()
1037
+ if name in moved:
1038
+ moved[name].append(dist)
1039
+ else:
1040
+ moved[name] = [dist]
1041
+
1042
+ return moved
1043
+
1044
+
1045
+ def save_all(
1046
+ unet,
1047
+ text_encoder,
1048
+ save_path,
1049
+ placeholder_token_ids=None,
1050
+ placeholder_tokens=None,
1051
+ save_lora=True,
1052
+ save_ti=True,
1053
+ target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1054
+ target_replace_module_unet=DEFAULT_TARGET_REPLACE,
1055
+ safe_form=True,
1056
+ ):
1057
+ if not safe_form:
1058
+ # save ti
1059
+ if save_ti:
1060
+ ti_path = _ti_lora_path(save_path)
1061
+ learned_embeds_dict = {}
1062
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1063
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1064
+ print(
1065
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1066
+ learned_embeds[:4],
1067
+ )
1068
+ learned_embeds_dict[tok] = learned_embeds.detach().cpu()
1069
+
1070
+ torch.save(learned_embeds_dict, ti_path)
1071
+ print("Ti saved to ", ti_path)
1072
+
1073
+ # save text encoder
1074
+ if save_lora:
1075
+
1076
+ save_lora_weight(
1077
+ unet, save_path, target_replace_module=target_replace_module_unet
1078
+ )
1079
+ print("Unet saved to ", save_path)
1080
+
1081
+ save_lora_weight(
1082
+ text_encoder,
1083
+ _text_lora_path(save_path),
1084
+ target_replace_module=target_replace_module_text,
1085
+ )
1086
+ print("Text Encoder saved to ", _text_lora_path(save_path))
1087
+
1088
+ else:
1089
+ assert save_path.endswith(
1090
+ ".safetensors"
1091
+ ), f"Save path : {save_path} should end with .safetensors"
1092
+
1093
+ loras = {}
1094
+ embeds = {}
1095
+
1096
+ if save_lora:
1097
+
1098
+ loras["unet"] = (unet, target_replace_module_unet)
1099
+ loras["text_encoder"] = (text_encoder, target_replace_module_text)
1100
+
1101
+ if save_ti:
1102
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1103
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1104
+ print(
1105
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1106
+ learned_embeds[:4],
1107
+ )
1108
+ embeds[tok] = learned_embeds.detach().cpu()
1109
+
1110
+ save_safeloras_with_embeds(loras, embeds, save_path)
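
As a rough sketch (not part of this commit), these helpers are typically combined with a diffusers pipeline along these lines; the model id, the LoRA path, and the prompt token are placeholders:

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Patch the UNet, text encoder, and token embeddings from a saved LoRA file.
    patch_pipe(pipe, "./lora.safetensors", patch_unet=True, patch_text=True, patch_ti=True)

    # Scale the LoRA contribution before sampling.
    tune_lora_scale(pipe.unet, 1.0)
    tune_lora_scale(pipe.text_encoder, 1.0)

    image = pipe("a photo of <s1>", num_inference_steps=30).images[0]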
lora_diffusion/lora_manager.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import List
2
+ import torch
3
+ from safetensors import safe_open
4
+ from diffusers import StableDiffusionPipeline
5
+ from .lora import (
6
+ monkeypatch_or_replace_safeloras,
7
+ apply_learned_embed_in_clip,
8
+ set_lora_diag,
9
+ parse_safeloras_embeds,
10
+ )
11
+
12
+
13
+ def lora_join(lora_safetenors: list):
14
+ metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
15
+ _total_metadata = {}
16
+ total_metadata = {}
17
+ total_tensor = {}
18
+ total_rank = 0
19
+ ranklist = []
20
+ for _metadata in metadatas:
21
+ rankset = []
22
+ for k, v in _metadata.items():
23
+ if k.endswith("rank"):
24
+ rankset.append(int(v))
25
+
26
+ assert len(set(rankset)) <= 1, "Rank should be the same per model"
27
+ if len(rankset) == 0:
28
+ rankset = [0]
29
+
30
+ total_rank += rankset[0]
31
+ _total_metadata.update(_metadata)
32
+ ranklist.append(rankset[0])
33
+
34
+ # remove metadata about tokens
35
+ for k, v in _total_metadata.items():
36
+ if v != "<embed>":
37
+ total_metadata[k] = v
38
+
39
+ tensorkeys = set()
40
+ for safelora in lora_safetenors:
41
+ tensorkeys.update(safelora.keys())
42
+
43
+ for keys in tensorkeys:
44
+ if keys.startswith("text_encoder") or keys.startswith("unet"):
45
+ tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
46
+
47
+ is_down = keys.endswith("down")
48
+
49
+ if is_down:
50
+ _tensor = torch.cat(tensorset, dim=0)
51
+ assert _tensor.shape[0] == total_rank
52
+ else:
53
+ _tensor = torch.cat(tensorset, dim=1)
54
+ assert _tensor.shape[1] == total_rank
55
+
56
+ total_tensor[keys] = _tensor
57
+ keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
58
+ total_metadata[keys_rank] = str(total_rank)
59
+ token_size_list = []
60
+ for idx, safelora in enumerate(lora_safetenors):
61
+ tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
62
+ for jdx, token in enumerate(sorted(tokens)):
63
+
64
+ total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
65
+ total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
66
+
67
+ print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
68
+
69
+ token_size_list.append(len(tokens))
70
+
71
+ return total_tensor, total_metadata, ranklist, token_size_list
72
+
73
+
74
+ class DummySafeTensorObject:
75
+ def __init__(self, tensor: dict, metadata):
76
+ self.tensor = tensor
77
+ self._metadata = metadata
78
+
79
+ def keys(self):
80
+ return self.tensor.keys()
81
+
82
+ def metadata(self):
83
+ return self._metadata
84
+
85
+ def get_tensor(self, key):
86
+ return self.tensor[key]
87
+
88
+
89
+ class LoRAManager:
90
+ def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
91
+
92
+ self.lora_paths_list = lora_paths_list
93
+ self.pipe = pipe
94
+ self._setup()
95
+
96
+ def _setup(self):
97
+
98
+ self._lora_safetenors = [
99
+ safe_open(path, framework="pt", device="cpu")
100
+ for path in self.lora_paths_list
101
+ ]
102
+
103
+ (
104
+ total_tensor,
105
+ total_metadata,
106
+ self.ranklist,
107
+ self.token_size_list,
108
+ ) = lora_join(self._lora_safetenors)
109
+
110
+ self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)
111
+
112
+ monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
113
+ tok_dict = parse_safeloras_embeds(self.total_safelora)
114
+
115
+ apply_learned_embed_in_clip(
116
+ tok_dict,
117
+ self.pipe.text_encoder,
118
+ self.pipe.tokenizer,
119
+ token=None,
120
+ idempotent=True,
121
+ )
122
+
123
+ def tune(self, scales):
124
+
125
+ assert len(scales) == len(
126
+ self.ranklist
127
+ ), "Scale list should be the same length as ranklist"
128
+
129
+ diags = []
130
+ for scale, rank in zip(scales, self.ranklist):
131
+ diags = diags + [scale] * rank
132
+
133
+ set_lora_diag(self.pipe.unet, torch.tensor(diags))
134
+
135
+ def prompt(self, prompt):
136
+ if prompt is not None:
137
+ for idx, tok_size in enumerate(self.token_size_list):
138
+ prompt = prompt.replace(
139
+ f"<{idx + 1}>",
140
+ "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
141
+ )
142
+ # TODO : Rescale LoRA + Text inputs based on prompt scale params
143
+
144
+ return prompt
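
A short, hedged example of how LoRAManager could be driven; both .safetensors paths and the prompt are placeholders:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
    manager = LoRAManager(["./style.safetensors", "./subject.safetensors"], pipe)

    manager.tune([0.6, 0.9])  # one scale per joined LoRA, in list order
    prompt = manager.prompt("a photo of <2> in the style of <1>")
    image = pipe(prompt).images[0]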
lora_diffusion/preprocess_files.py ADDED
@@ -0,0 +1,327 @@
1
+ # Have SwinIR upsample
2
+ # Have BLIP auto caption
3
+ # Have CLIPSeg auto mask concept
4
+
5
+ from typing import List, Literal, Union, Optional, Tuple
6
+ import os
7
+ from PIL import Image, ImageFilter
8
+ import torch
9
+ import numpy as np
10
+ import fire
11
+ from tqdm import tqdm
12
+ import glob
13
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
14
+
15
+
16
+ @torch.no_grad()
17
+ def swin_ir_sr(
18
+ images: List[Image.Image],
19
+ model_id: Literal[
20
+ "caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48"
21
+ ] = "caidas/swin2SR-classical-sr-x2-64",
22
+ target_size: Optional[Tuple[int, int]] = None,
23
+ device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
24
+ **kwargs,
25
+ ) -> List[Image.Image]:
26
+ """
27
+ Upscales images using SwinIR. Returns a list of PIL images.
28
+ """
29
+ # Swin2SR currently lives on the main branch of transformers, so import it lazily here.
30
+ from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
31
+
32
+ model = Swin2SRForImageSuperResolution.from_pretrained(
33
+ model_id,
34
+ ).to(device)
35
+ processor = Swin2SRImageProcessor()
36
+
37
+ out_images = []
38
+
39
+ for image in tqdm(images):
40
+
41
+ ori_w, ori_h = image.size
42
+ if target_size is not None:
43
+ if ori_w >= target_size[0] and ori_h >= target_size[1]:
44
+ out_images.append(image)
45
+ continue
46
+
47
+ inputs = processor(image, return_tensors="pt").to(device)
48
+ with torch.no_grad():
49
+ outputs = model(**inputs)
50
+
51
+ output = (
52
+ outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
53
+ )
54
+ output = np.moveaxis(output, source=0, destination=-1)
55
+ output = (output * 255.0).round().astype(np.uint8)
56
+ output = Image.fromarray(output)
57
+
58
+ out_images.append(output)
59
+
60
+ return out_images
61
+
62
+
63
+ @torch.no_grad()
64
+ def clipseg_mask_generator(
65
+ images: List[Image.Image],
66
+ target_prompts: Union[List[str], str],
67
+ model_id: Literal[
68
+ "CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16"
69
+ ] = "CIDAS/clipseg-rd64-refined",
70
+ device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
71
+ bias: float = 0.01,
72
+ temp: float = 1.0,
73
+ **kwargs,
74
+ ) -> List[Image.Image]:
75
+ """
76
+ Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
77
+ """
78
+
79
+ if isinstance(target_prompts, str):
80
+ print(
81
+ f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
82
+ )
83
+
84
+ target_prompts = [target_prompts] * len(images)
85
+
86
+ processor = CLIPSegProcessor.from_pretrained(model_id)
87
+ model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device)
88
+
89
+ masks = []
90
+
91
+ for image, prompt in tqdm(zip(images, target_prompts)):
92
+
93
+ original_size = image.size
94
+
95
+ inputs = processor(
96
+ text=[prompt, ""],
97
+ images=[image] * 2,
98
+ padding="max_length",
99
+ truncation=True,
100
+ return_tensors="pt",
101
+ ).to(device)
102
+
103
+ outputs = model(**inputs)
104
+
105
+ logits = outputs.logits
106
+ probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
107
+ probs = (probs + bias).clamp_(0, 1)
108
+ probs = 255 * probs / probs.max()
109
+
110
+ # make mask greyscale
111
+ mask = Image.fromarray(probs.cpu().numpy()).convert("L")
112
+
113
+ # resize mask to original size
114
+ mask = mask.resize(original_size)
115
+
116
+ masks.append(mask)
117
+
118
+ return masks
119
+
120
+
121
+ @torch.no_grad()
122
+ def blip_captioning_dataset(
123
+ images: List[Image.Image],
124
+ text: Optional[str] = None,
125
+ model_id: Literal[
126
+ "Salesforce/blip-image-captioning-large",
127
+ "Salesforce/blip-image-captioning-base",
128
+ ] = "Salesforce/blip-image-captioning-large",
129
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
130
+ **kwargs,
131
+ ) -> List[str]:
132
+ """
133
+ Returns a list of captions for the given images
134
+ """
135
+
136
+ from transformers import BlipProcessor, BlipForConditionalGeneration
137
+
138
+ processor = BlipProcessor.from_pretrained(model_id)
139
+ model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
140
+ captions = []
141
+
142
+ for image in tqdm(images):
143
+ inputs = processor(image, text=text, return_tensors="pt").to("cuda")
144
+ out = model.generate(
145
+ **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
146
+ )
147
+ caption = processor.decode(out[0], skip_special_tokens=True)
148
+
149
+ captions.append(caption)
150
+
151
+ return captions
152
+
153
+
154
+ def face_mask_google_mediapipe(
155
+ images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05
156
+ ) -> List[Image.Image]:
157
+ """
158
+ Returns a list of images with mask on the face parts.
159
+ """
160
+ import mediapipe as mp
161
+
162
+ mp_face_detection = mp.solutions.face_detection
163
+
164
+ face_detection = mp_face_detection.FaceDetection(
165
+ model_selection=1, min_detection_confidence=0.5
166
+ )
167
+
168
+ masks = []
169
+ for image in tqdm(images):
170
+
171
+ image = np.array(image)
172
+
173
+ results = face_detection.process(image)
174
+ black_image = np.ones((image.shape[0], image.shape[1]), dtype=np.uint8)
175
+
176
+ if results.detections:
177
+
178
+ for detection in results.detections:
179
+
180
+ x_min = int(
181
+ detection.location_data.relative_bounding_box.xmin * image.shape[1]
182
+ )
183
+ y_min = int(
184
+ detection.location_data.relative_bounding_box.ymin * image.shape[0]
185
+ )
186
+ width = int(
187
+ detection.location_data.relative_bounding_box.width * image.shape[1]
188
+ )
189
+ height = int(
190
+ detection.location_data.relative_bounding_box.height
191
+ * image.shape[0]
192
+ )
193
+
194
+ # mark the detected face region as white in the mask
195
+ black_image[y_min : y_min + height, x_min : x_min + width] = 255
196
+
197
+ black_image = Image.fromarray(black_image)
198
+ masks.append(black_image)
199
+
200
+ return masks
201
+
202
+
203
+ def _crop_to_square(
204
+ image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
205
+ ):
206
+ cx, cy = com
207
+ width, height = image.size
208
+ if width > height:
209
+ left_possible = max(cx - height / 2, 0)
210
+ left = min(left_possible, width - height)
211
+ right = left + height
212
+ top = 0
213
+ bottom = height
214
+ else:
215
+ left = 0
216
+ right = width
217
+ top_possible = max(cy - width / 2, 0)
218
+ top = min(top_possible, height - width)
219
+ bottom = top + width
220
+
221
+ image = image.crop((left, top, right, bottom))
222
+
223
+ if resize_to:
224
+ image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)
225
+
226
+ return image
227
+
228
+
229
+ def _center_of_mass(mask: Image.Image):
230
+ """
231
+ Returns the center of mass of the mask
232
+ """
233
+ x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
234
+
235
+ x_ = x * np.array(mask)
236
+ y_ = y * np.array(mask)
237
+
238
+ x = np.sum(x_) / np.sum(mask)
239
+ y = np.sum(y_) / np.sum(mask)
240
+
241
+ return x, y
242
+
243
+
244
+ def load_and_save_masks_and_captions(
245
+ files: Union[str, List[str]],
246
+ output_dir: str,
247
+ caption_text: Optional[str] = None,
248
+ target_prompts: Optional[Union[List[str], str]] = None,
249
+ target_size: int = 512,
250
+ crop_based_on_salience: bool = True,
251
+ use_face_detection_instead: bool = False,
252
+ temp: float = 1.0,
253
+ n_length: int = -1,
254
+ ):
255
+ """
256
+ Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
257
+ to output dir.
258
+ """
259
+ os.makedirs(output_dir, exist_ok=True)
260
+
261
+ # load images
262
+ if isinstance(files, str):
263
+ # check if it is a directory
264
+ if os.path.isdir(files):
265
+ # get all the .png .jpg in the directory
266
+ files = glob.glob(os.path.join(files, "*.png")) + glob.glob(
267
+ os.path.join(files, "*.jpg")
268
+ )
269
+
270
+ if len(files) == 0:
271
+ raise Exception(
272
+ f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg files."
273
+ )
274
+ if n_length == -1:
275
+ n_length = len(files)
276
+ files = sorted(files)[:n_length]
277
+
278
+ images = [Image.open(file) for file in files]
279
+
280
+ # captions
281
+ print(f"Generating {len(images)} captions...")
282
+ captions = blip_captioning_dataset(images, text=caption_text)
283
+
284
+ if target_prompts is None:
285
+ target_prompts = captions
286
+
287
+ print(f"Generating {len(images)} masks...")
288
+ if not use_face_detection_instead:
289
+ seg_masks = clipseg_mask_generator(
290
+ images=images, target_prompts=target_prompts, temp=temp
291
+ )
292
+ else:
293
+ seg_masks = face_mask_google_mediapipe(images=images)
294
+
295
+ # find the center of mass of the mask
296
+ if crop_based_on_salience:
297
+ coms = [_center_of_mass(mask) for mask in seg_masks]
298
+ else:
299
+ coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
300
+ # based on the center of mass, crop the image to a square
301
+ images = [
302
+ _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
303
+ ]
304
+
305
+ print(f"Upscaling {len(images)} images...")
306
+ # upscale images anyways
307
+ images = swin_ir_sr(images, target_size=(target_size, target_size))
308
+ images = [
309
+ image.resize((target_size, target_size), Image.Resampling.LANCZOS)
310
+ for image in images
311
+ ]
312
+
313
+ seg_masks = [
314
+ _crop_to_square(mask, com, resize_to=target_size)
315
+ for mask, com in zip(seg_masks, coms)
316
+ ]
317
+ with open(os.path.join(output_dir, "caption.txt"), "w") as f:
318
+ # save images and masks
319
+ for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
320
+ image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99)
321
+ mask.save(os.path.join(output_dir, f"{idx}.mask.png"))
322
+
323
+ f.write(caption + "\n")
324
+
325
+
326
+ def main():
327
+ fire.Fire(load_and_save_masks_and_captions)
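
Because main() wraps the function with fire, the preprocessing step can also be called directly from Python; the folders below are placeholders, and the keyword arguments mirror the signature above:

    load_and_save_masks_and_captions(
        files="./raw-photos",              # directory of .png/.jpg images (placeholder)
        output_dir="./preprocessed",       # receives {idx}.src.jpg, {idx}.mask.png, caption.txt
        target_size=512,
        use_face_detection_instead=True,   # mediapipe face masks instead of CLIPSeg
    )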
lora_diffusion/safe_open.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ Pure python version of Safetensors safe_open
3
+ From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
4
+ """
5
+
6
+ import json
7
+ import mmap
8
+ import os
9
+
10
+ import torch
11
+
12
+
13
+ class SafetensorsWrapper:
14
+ def __init__(self, metadata, tensors):
15
+ self._metadata = metadata
16
+ self._tensors = tensors
17
+
18
+ def metadata(self):
19
+ return self._metadata
20
+
21
+ def keys(self):
22
+ return self._tensors.keys()
23
+
24
+ def get_tensor(self, k):
25
+ return self._tensors[k]
26
+
27
+
28
+ DTYPES = {
29
+ "F32": torch.float32,
30
+ "F16": torch.float16,
31
+ "BF16": torch.bfloat16,
32
+ }
33
+
34
+
35
+ def create_tensor(storage, info, offset):
36
+ dtype = DTYPES[info["dtype"]]
37
+ shape = info["shape"]
38
+ start, stop = info["data_offsets"]
39
+ return (
40
+ torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
41
+ .view(dtype=dtype)
42
+ .reshape(shape)
43
+ )
44
+
45
+
46
+ def safe_open(filename, framework="pt", device="cpu"):
47
+ if framework != "pt":
48
+ raise ValueError("`framework` must be 'pt'")
49
+
50
+ with open(filename, mode="r", encoding="utf8") as file_obj:
51
+ with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
52
+ header = m.read(8)
53
+ n = int.from_bytes(header, "little")
54
+ metadata_bytes = m.read(n)
55
+ metadata = json.loads(metadata_bytes)
56
+
57
+ size = os.stat(filename).st_size
58
+ storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
59
+ offset = n + 8
60
+
61
+ return SafetensorsWrapper(
62
+ metadata=metadata.get("__metadata__", {}),
63
+ tensors={
64
+ name: create_tensor(storage, info, offset).to(device)
65
+ for name, info in metadata.items()
66
+ if name != "__metadata__"
67
+ },
68
+ )
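A minimal usage sketch (not part of the commit) of the pure-python wrapper above; "lora_weight.safetensors" is a placeholder path:

    st = safe_open("lora_weight.safetensors", framework="pt", device="cpu")
    print(st.metadata())                      # the file's __metadata__ entry, if any
    for name in st.keys():
        t = st.get_tensor(name)               # tensor viewed out of the file's byte storage
        print(name, tuple(t.shape), t.dtype)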
lora_diffusion/to_ckpt_v2.py ADDED
@@ -0,0 +1,232 @@
1
+ # from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
2
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
3
+ # *Only* converts the UNet, VAE, and Text Encoder.
4
+ # Does not convert optimizer state or any other thing.
5
+ # Written by jachiam
6
+ import argparse
7
+ import os.path as osp
8
+
9
+ import torch
10
+
11
+
12
+ # =================#
13
+ # UNet Conversion #
14
+ # =================#
15
+
16
+ unet_conversion_map = [
17
+ # (stable-diffusion, HF Diffusers)
18
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
19
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
20
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
21
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
22
+ ("input_blocks.0.0.weight", "conv_in.weight"),
23
+ ("input_blocks.0.0.bias", "conv_in.bias"),
24
+ ("out.0.weight", "conv_norm_out.weight"),
25
+ ("out.0.bias", "conv_norm_out.bias"),
26
+ ("out.2.weight", "conv_out.weight"),
27
+ ("out.2.bias", "conv_out.bias"),
28
+ ]
29
+
30
+ unet_conversion_map_resnet = [
31
+ # (stable-diffusion, HF Diffusers)
32
+ ("in_layers.0", "norm1"),
33
+ ("in_layers.2", "conv1"),
34
+ ("out_layers.0", "norm2"),
35
+ ("out_layers.3", "conv2"),
36
+ ("emb_layers.1", "time_emb_proj"),
37
+ ("skip_connection", "conv_shortcut"),
38
+ ]
39
+
40
+ unet_conversion_map_layer = []
41
+ # hardcoded number of downblocks and resnets/attentions...
42
+ # would need smarter logic for other networks.
43
+ for i in range(4):
44
+ # loop over downblocks/upblocks
45
+
46
+ for j in range(2):
47
+ # loop over resnets/attentions for downblocks
48
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
49
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
50
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
51
+
52
+ if i < 3:
53
+ # no attention layers in down_blocks.3
54
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
55
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
56
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
57
+
58
+ for j in range(3):
59
+ # loop over resnets/attentions for upblocks
60
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
61
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
62
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
63
+
64
+ if i > 0:
65
+ # no attention layers in up_blocks.0
66
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
67
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
68
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
69
+
70
+ if i < 3:
71
+ # no downsample in down_blocks.3
72
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
73
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
74
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
75
+
76
+ # no upsample in up_blocks.3
77
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
78
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
79
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
80
+
81
+ hf_mid_atn_prefix = "mid_block.attentions.0."
82
+ sd_mid_atn_prefix = "middle_block.1."
83
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
84
+
85
+ for j in range(2):
86
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
87
+ sd_mid_res_prefix = f"middle_block.{2*j}."
88
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
89
+
90
+
91
+ def convert_unet_state_dict(unet_state_dict):
92
+ # buyer beware: this is a *brittle* function,
93
+ # and correct output requires that all of these pieces interact in
94
+ # the exact order in which I have arranged them.
95
+ mapping = {k: k for k in unet_state_dict.keys()}
96
+ for sd_name, hf_name in unet_conversion_map:
97
+ mapping[hf_name] = sd_name
98
+ for k, v in mapping.items():
99
+ if "resnets" in k:
100
+ for sd_part, hf_part in unet_conversion_map_resnet:
101
+ v = v.replace(hf_part, sd_part)
102
+ mapping[k] = v
103
+ for k, v in mapping.items():
104
+ for sd_part, hf_part in unet_conversion_map_layer:
105
+ v = v.replace(hf_part, sd_part)
106
+ mapping[k] = v
107
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
108
+ return new_state_dict
109
+
110
+
111
+ # ================#
112
+ # VAE Conversion #
113
+ # ================#
114
+
115
+ vae_conversion_map = [
116
+ # (stable-diffusion, HF Diffusers)
117
+ ("nin_shortcut", "conv_shortcut"),
118
+ ("norm_out", "conv_norm_out"),
119
+ ("mid.attn_1.", "mid_block.attentions.0."),
120
+ ]
121
+
122
+ for i in range(4):
123
+ # down_blocks have two resnets
124
+ for j in range(2):
125
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
126
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
127
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
128
+
129
+ if i < 3:
130
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
131
+ sd_downsample_prefix = f"down.{i}.downsample."
132
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
133
+
134
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
135
+ sd_upsample_prefix = f"up.{3-i}.upsample."
136
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
137
+
138
+ # up_blocks have three resnets
139
+ # also, up blocks in hf are numbered in reverse from sd
140
+ for j in range(3):
141
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
142
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
143
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
144
+
145
+ # this part accounts for mid blocks in both the encoder and the decoder
146
+ for i in range(2):
147
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
148
+ sd_mid_res_prefix = f"mid.block_{i+1}."
149
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
150
+
151
+
152
+ vae_conversion_map_attn = [
153
+ # (stable-diffusion, HF Diffusers)
154
+ ("norm.", "group_norm."),
155
+ ("q.", "query."),
156
+ ("k.", "key."),
157
+ ("v.", "value."),
158
+ ("proj_out.", "proj_attn."),
159
+ ]
160
+
161
+
162
+ def reshape_weight_for_sd(w):
163
+ # convert HF linear weights to SD conv2d weights
164
+ return w.reshape(*w.shape, 1, 1)
165
+
166
+
167
+ def convert_vae_state_dict(vae_state_dict):
168
+ mapping = {k: k for k in vae_state_dict.keys()}
169
+ for k, v in mapping.items():
170
+ for sd_part, hf_part in vae_conversion_map:
171
+ v = v.replace(hf_part, sd_part)
172
+ mapping[k] = v
173
+ for k, v in mapping.items():
174
+ if "attentions" in k:
175
+ for sd_part, hf_part in vae_conversion_map_attn:
176
+ v = v.replace(hf_part, sd_part)
177
+ mapping[k] = v
178
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
179
+ weights_to_convert = ["q", "k", "v", "proj_out"]
180
+ for k, v in new_state_dict.items():
181
+ for weight_name in weights_to_convert:
182
+ if f"mid.attn_1.{weight_name}.weight" in k:
183
+ print(f"Reshaping {k} for SD format")
184
+ new_state_dict[k] = reshape_weight_for_sd(v)
185
+ return new_state_dict
186
+
187
+
188
+ # =========================#
189
+ # Text Encoder Conversion #
190
+ # =========================#
191
+ # pretty much a no-op
192
+
193
+
194
+ def convert_text_enc_state_dict(text_enc_dict):
195
+ return text_enc_dict
196
+
197
+
198
+ def convert_to_ckpt(model_path, checkpoint_path, as_half):
199
+
200
+ assert model_path is not None, "Must provide a model path!"
201
+
202
+ assert checkpoint_path is not None, "Must provide a checkpoint path!"
203
+
204
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
205
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
206
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
207
+
208
+ # Convert the UNet model
209
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
210
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
211
+ unet_state_dict = {
212
+ "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
213
+ }
214
+
215
+ # Convert the VAE model
216
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
217
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
218
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
219
+
220
+ # Convert the text encoder model
221
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
222
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
223
+ text_enc_dict = {
224
+ "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
225
+ }
226
+
227
+ # Put together new checkpoint
228
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
229
+ if as_half:
230
+ state_dict = {k: v.half() for k, v in state_dict.items()}
231
+ state_dict = {"state_dict": state_dict}
232
+ torch.save(state_dict, checkpoint_path)
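A minimal usage sketch (not part of the commit): merging a Diffusers-format pipeline directory into a single Stable Diffusion checkpoint with the function above; both paths are placeholders:

    convert_to_ckpt(
        model_path="./dreambooth-output",   # must contain unet/, vae/ and text_encoder/ subfolders
        checkpoint_path="./model.ckpt",     # destination for the merged state_dict
        as_half=True,                       # cast all weights to fp16 before saving
    )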
lora_diffusion/utils.py ADDED
@@ -0,0 +1,214 @@
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import (
6
+ CLIPProcessor,
7
+ CLIPTextModelWithProjection,
8
+ CLIPTokenizer,
9
+ CLIPVisionModelWithProjection,
10
+ )
11
+
12
+ from diffusers import StableDiffusionPipeline
13
+ from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path
14
+ import os
15
+ import glob
16
+ import math
17
+
18
+ EXAMPLE_PROMPTS = [
19
+ "<obj> swimming in a pool",
20
+ "<obj> at a beach with a view of seashore",
21
+ "<obj> in times square",
22
+ "<obj> wearing sunglasses",
23
+ "<obj> in a construction outfit",
24
+ "<obj> playing with a ball",
25
+ "<obj> wearing headphones",
26
+ "<obj> oil painting ghibli inspired",
27
+ "<obj> working on the laptop",
28
+ "<obj> with mountains and sunset in background",
29
+ "Painting of <obj> at a beach by artist claude monet",
30
+ "<obj> digital painting 3d render geometric style",
31
+ "A screaming <obj>",
32
+ "A depressed <obj>",
33
+ "A sleeping <obj>",
34
+ "A sad <obj>",
35
+ "A joyous <obj>",
36
+ "A frowning <obj>",
37
+ "A sculpture of <obj>",
38
+ "<obj> near a pool",
39
+ "<obj> at a beach with a view of seashore",
40
+ "<obj> in a garden",
41
+ "<obj> in grand canyon",
42
+ "<obj> floating in ocean",
43
+ "<obj> and an armchair",
44
+ "A maple tree on the side of <obj>",
45
+ "<obj> and an orange sofa",
46
+ "<obj> with chocolate cake on it",
47
+ "<obj> with a vase of rose flowers on it",
48
+ "A digital illustration of <obj>",
49
+ "Georgia O'Keeffe style <obj> painting",
50
+ "A watercolor painting of <obj> on a beach",
51
+ ]
52
+
53
+
54
+ def image_grid(_imgs, rows=None, cols=None):
55
+
56
+ if rows is None and cols is None:
57
+ rows = cols = math.ceil(len(_imgs) ** 0.5)
58
+
59
+ if rows is None:
60
+ rows = math.ceil(len(_imgs) / cols)
61
+ if cols is None:
62
+ cols = math.ceil(len(_imgs) / rows)
63
+
64
+ w, h = _imgs[0].size
65
+ grid = Image.new("RGB", size=(cols * w, rows * h))
66
+ grid_w, grid_h = grid.size
67
+
68
+ for i, img in enumerate(_imgs):
69
+ grid.paste(img, box=(i % cols * w, i // cols * h))
70
+ return grid
71
+
72
+
73
+ def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
74
+ # evaluation inspired from textual inversion paper
75
+ # https://arxiv.org/abs/2208.01618
76
+
77
+ # text alignment
78
+ assert img_embeds.shape[0] == text_embeds.shape[0]
79
+ text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
80
+ img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
81
+ )
82
+
83
+ # image alignment
84
+ img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)
85
+
86
+ avg_target_img_embed = (
87
+ (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
88
+ .mean(dim=0)
89
+ .unsqueeze(0)
90
+ .repeat(img_embeds.shape[0], 1)
91
+ )
92
+
93
+ img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)
94
+
95
+ return {
96
+ "text_alignment_avg": text_img_sim.mean().item(),
97
+ "image_alignment_avg": img_img_sim.mean().item(),
98
+ "text_alignment_all": text_img_sim.tolist(),
99
+ "image_alignment_all": img_img_sim.tolist(),
100
+ }
101
+
102
+
103
+ def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
104
+ text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
105
+ tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
106
+ vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
107
+ processor = CLIPProcessor.from_pretrained(eval_clip_id)
108
+
109
+ return text_model, tokenizer, vis_model, processor
110
+
111
+
112
+ def evaluate_pipe(
113
+ pipe,
114
+ target_images: List[Image.Image],
115
+ class_token: str = "",
116
+ learnt_token: str = "",
117
+ guidance_scale: float = 5.0,
118
+ seed=0,
119
+ clip_model_sets=None,
120
+ eval_clip_id: str = "openai/clip-vit-large-patch14",
121
+ n_test: int = 10,
122
+ n_step: int = 50,
123
+ ):
124
+
125
+ if clip_model_sets is not None:
126
+ text_model, tokenizer, vis_model, processor = clip_model_sets
127
+ else:
128
+ text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
129
+ eval_clip_id
130
+ )
131
+
132
+ images = []
133
+ img_embeds = []
134
+ text_embeds = []
135
+ for prompt in EXAMPLE_PROMPTS[:n_test]:
136
+ prompt = prompt.replace("<obj>", learnt_token)
137
+ torch.manual_seed(seed)
138
+ with torch.autocast("cuda"):
139
+ img = pipe(
140
+ prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
141
+ ).images[0]
142
+ images.append(img)
143
+
144
+ # image
145
+ inputs = processor(images=img, return_tensors="pt")
146
+ img_embed = vis_model(**inputs).image_embeds
147
+ img_embeds.append(img_embed)
148
+
149
+ prompt = prompt.replace(learnt_token, class_token)
150
+ # prompts
151
+ inputs = tokenizer([prompt], padding=True, return_tensors="pt")
152
+ outputs = text_model(**inputs)
153
+ text_embed = outputs.text_embeds
154
+ text_embeds.append(text_embed)
155
+
156
+ # target images
157
+ inputs = processor(images=target_images, return_tensors="pt")
158
+ target_img_embeds = vis_model(**inputs).image_embeds
159
+
160
+ img_embeds = torch.cat(img_embeds, dim=0)
161
+ text_embeds = torch.cat(text_embeds, dim=0)
162
+
163
+ return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
164
+
165
+
166
+ def visualize_progress(
167
+ path_alls: Union[str, List[str]],
168
+ prompt: str,
169
+ model_id: str = "runwayml/stable-diffusion-v1-5",
170
+ device="cuda:0",
171
+ patch_unet=True,
172
+ patch_text=True,
173
+ patch_ti=True,
174
+ unet_scale=1.0,
175
+ text_scale=1.0,
176
+ num_inference_steps=50,
177
+ guidance_scale=5.0,
178
+ offset: int = 0,
179
+ limit: int = 10,
180
+ seed: int = 0,
181
+ ):
182
+
183
+ imgs = []
184
+ if isinstance(path_alls, str):
185
+ alls = list(set(glob.glob(path_alls)))
186
+
187
+ alls.sort(key=os.path.getmtime)
188
+ else:
189
+ alls = path_alls
190
+
191
+ pipe = StableDiffusionPipeline.from_pretrained(
192
+ model_id, torch_dtype=torch.float16
193
+ ).to(device)
194
+
195
+ print(f"Found {len(alls)} checkpoints")
196
+ for path in alls[offset:limit]:
197
+ print(path)
198
+
199
+ patch_pipe(
200
+ pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
201
+ )
202
+
203
+ tune_lora_scale(pipe.unet, unet_scale)
204
+ tune_lora_scale(pipe.text_encoder, text_scale)
205
+
206
+ torch.manual_seed(seed)
207
+ image = pipe(
208
+ prompt,
209
+ num_inference_steps=num_inference_steps,
210
+ guidance_scale=guidance_scale,
211
+ ).images[0]
212
+ imgs.append(image)
213
+
214
+ return imgs
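A minimal usage sketch (not part of the commit): rendering saved LoRA checkpoints into a single contact sheet with the helpers above. The glob pattern, prompt, and token are placeholders, and a CUDA device is assumed since the pipeline is loaded in fp16:

    imgs = visualize_progress(
        path_alls="./output/lora_weight_e*.pt",       # checkpoints, sorted by modification time
        prompt="a photo of a sks dog at the beach",   # placeholder instance token and prompt
        model_id="runwayml/stable-diffusion-v1-5",
        patch_ti=False,                               # this trainer saves no textual-inversion embedding
        num_inference_steps=50,
        guidance_scale=5.0,
    )
    image_grid(imgs).save("progress.jpg")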
lora_diffusion/xformers_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import functools
2
+
3
+ import torch
4
+ from diffusers.models.attention import BasicTransformerBlock
5
+ from diffusers.utils.import_utils import is_xformers_available
6
+
7
+ from .lora import LoraInjectedLinear
8
+
9
+ if is_xformers_available():
10
+ import xformers
11
+ import xformers.ops
12
+ else:
13
+ xformers = None
14
+
15
+
16
+ @functools.cache
17
+ def test_xformers_backwards(size):
18
+ @torch.enable_grad()
19
+ def _grad(size):
20
+ q = torch.randn((1, 4, size), device="cuda")
21
+ k = torch.randn((1, 4, size), device="cuda")
22
+ v = torch.randn((1, 4, size), device="cuda")
23
+
24
+ q = q.detach().requires_grad_()
25
+ k = k.detach().requires_grad_()
26
+ v = v.detach().requires_grad_()
27
+
28
+ out = xformers.ops.memory_efficient_attention(q, k, v)
29
+ loss = out.sum(2).mean(0).sum()
30
+
31
+ return torch.autograd.grad(loss, v)
32
+
33
+ try:
34
+ _grad(size)
35
+ print(size, "pass")
36
+ return True
37
+ except Exception as e:
38
+ print(size, "fail")
39
+ return False
40
+
41
+
42
+ def set_use_memory_efficient_attention_xformers(
43
+ module: torch.nn.Module, valid: bool
44
+ ) -> None:
45
+ def fn_test_dim_head(module: torch.nn.Module):
46
+ if isinstance(module, BasicTransformerBlock):
47
+ # dim_head isn't stored anywhere, so back-calculate
48
+ source = module.attn1.to_v
49
+ if isinstance(source, LoraInjectedLinear):
50
+ source = source.linear
51
+
52
+ dim_head = source.out_features // module.attn1.heads
53
+
54
+ result = test_xformers_backwards(dim_head)
55
+
56
+ # If dim_head > dim_head_max, turn xformers off
57
+ if not result:
58
+ module.set_use_memory_efficient_attention_xformers(False)
59
+
60
+ for child in module.children():
61
+ fn_test_dim_head(child)
62
+
63
+ if not is_xformers_available() and valid:
64
+ print("XFormers is not available. Skipping.")
65
+ return
66
+
67
+ module.set_use_memory_efficient_attention_xformers(valid)
68
+
69
+ if valid:
70
+ fn_test_dim_head(module)
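A minimal usage sketch (not part of the commit): the training script below applies this helper to the UNet and VAE, and it can be used the same way on any diffusers module. The model id is a placeholder, and a CUDA device is assumed because the dim_head probe allocates its test tensors on "cuda":

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    ).to("cuda")
    set_use_memory_efficient_attention_xformers(unet, True)  # prints a notice and skips if xformers is absent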
requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ diffusers
2
+ accelerate
3
+ transformers>=4.25.1
train_dreambooth_cloneofsimo_lora.py ADDED
@@ -0,0 +1,1008 @@
1
+ # Bootstrapped from:
2
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
3
+
4
+ import argparse
5
+ import hashlib
6
+ import itertools
7
+ import math
8
+ import os
9
+ import inspect
10
+ from pathlib import Path
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint
16
+
17
+
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import set_seed
21
+ from diffusers import (
22
+ AutoencoderKL,
23
+ DDPMScheduler,
24
+ StableDiffusionPipeline,
25
+ UNet2DConditionModel,
26
+ )
27
+ from diffusers.optimization import get_scheduler
28
+ from huggingface_hub import HfFolder, Repository, whoami
29
+
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+
33
+ from lora_diffusion import (
34
+ extract_lora_ups_down,
35
+ inject_trainable_lora,
36
+ safetensors_available,
37
+ save_lora_weight,
38
+ save_safeloras,
39
+ )
40
+ from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
41
+ from PIL import Image
42
+ from torch.utils.data import Dataset
43
+ from torchvision import transforms
44
+
45
+ from pathlib import Path
46
+
47
+ import random
48
+ import re
49
+
50
+
51
+ class DreamBoothDataset(Dataset):
52
+ """
53
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
54
+ It pre-processes the images and tokenizes the prompts.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ instance_data_root,
60
+ instance_prompt,
61
+ tokenizer,
62
+ class_data_root=None,
63
+ class_prompt=None,
64
+ size=512,
65
+ center_crop=False,
66
+ color_jitter=False,
67
+ h_flip=False,
68
+ resize=False,
69
+ ):
70
+ self.size = size
71
+ self.center_crop = center_crop
72
+ self.tokenizer = tokenizer
73
+ self.resize = resize
74
+
75
+ self.instance_data_root = Path(instance_data_root)
76
+ if not self.instance_data_root.exists():
77
+ raise ValueError("Instance images root doesn't exist.")
78
+
79
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
80
+ self.num_instance_images = len(self.instance_images_path)
81
+ self.instance_prompt = instance_prompt
82
+ self._length = self.num_instance_images
83
+
84
+ if class_data_root is not None:
85
+ self.class_data_root = Path(class_data_root)
86
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
87
+ self.class_images_path = list(self.class_data_root.iterdir())
88
+ self.num_class_images = len(self.class_images_path)
89
+ self._length = max(self.num_class_images, self.num_instance_images)
90
+ self.class_prompt = class_prompt
91
+ else:
92
+ self.class_data_root = None
93
+
94
+ img_transforms = []
95
+
96
+ if resize:
97
+ img_transforms.append(
98
+ transforms.Resize(
99
+ size, interpolation=transforms.InterpolationMode.BILINEAR
100
+ )
101
+ )
102
+ if center_crop:
103
+ img_transforms.append(transforms.CenterCrop(size))
104
+ if color_jitter:
105
+ img_transforms.append(transforms.ColorJitter(0.2, 0.1))
106
+ if h_flip:
107
+ img_transforms.append(transforms.RandomHorizontalFlip())
108
+
109
+ self.image_transforms = transforms.Compose(
110
+ [*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
111
+ )
112
+
113
+ def __len__(self):
114
+ return self._length
115
+
116
+ def __getitem__(self, index):
117
+ example = {}
118
+ instance_image = Image.open(
119
+ self.instance_images_path[index % self.num_instance_images]
120
+ )
121
+ if not instance_image.mode == "RGB":
122
+ instance_image = instance_image.convert("RGB")
123
+ example["instance_images"] = self.image_transforms(instance_image)
124
+ example["instance_prompt_ids"] = self.tokenizer(
125
+ self.instance_prompt,
126
+ padding="do_not_pad",
127
+ truncation=True,
128
+ max_length=self.tokenizer.model_max_length,
129
+ ).input_ids
130
+
131
+ if self.class_data_root:
132
+ class_image = Image.open(
133
+ self.class_images_path[index % self.num_class_images]
134
+ )
135
+ if not class_image.mode == "RGB":
136
+ class_image = class_image.convert("RGB")
137
+ example["class_images"] = self.image_transforms(class_image)
138
+ example["class_prompt_ids"] = self.tokenizer(
139
+ self.class_prompt,
140
+ padding="do_not_pad",
141
+ truncation=True,
142
+ max_length=self.tokenizer.model_max_length,
143
+ ).input_ids
144
+
145
+ return example
146
+
147
+
148
+ class PromptDataset(Dataset):
149
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
150
+
151
+ def __init__(self, prompt, num_samples):
152
+ self.prompt = prompt
153
+ self.num_samples = num_samples
154
+
155
+ def __len__(self):
156
+ return self.num_samples
157
+
158
+ def __getitem__(self, index):
159
+ example = {}
160
+ example["prompt"] = self.prompt
161
+ example["index"] = index
162
+ return example
163
+
164
+
165
+ logger = get_logger(__name__)
166
+
167
+
168
+ def parse_args(input_args=None):
169
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
170
+ parser.add_argument(
171
+ "--pretrained_model_name_or_path",
172
+ type=str,
173
+ default=None,
174
+ required=True,
175
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
176
+ )
177
+ parser.add_argument(
178
+ "--pretrained_vae_name_or_path",
179
+ type=str,
180
+ default=None,
181
+ help="Path to pretrained vae or vae identifier from huggingface.co/models.",
182
+ )
183
+ parser.add_argument(
184
+ "--revision",
185
+ type=str,
186
+ default=None,
187
+ required=False,
188
+ help="Revision of pretrained model identifier from huggingface.co/models.",
189
+ )
190
+ parser.add_argument(
191
+ "--tokenizer_name",
192
+ type=str,
193
+ default=None,
194
+ help="Pretrained tokenizer name or path if not the same as model_name",
195
+ )
196
+ parser.add_argument(
197
+ "--instance_data_dir",
198
+ type=str,
199
+ default=None,
200
+ required=True,
201
+ help="A folder containing the training data of instance images.",
202
+ )
203
+ parser.add_argument(
204
+ "--class_data_dir",
205
+ type=str,
206
+ default=None,
207
+ required=False,
208
+ help="A folder containing the training data of class images.",
209
+ )
210
+ parser.add_argument(
211
+ "--instance_prompt",
212
+ type=str,
213
+ default=None,
214
+ required=True,
215
+ help="The prompt with identifier specifying the instance",
216
+ )
217
+ parser.add_argument(
218
+ "--class_prompt",
219
+ type=str,
220
+ default=None,
221
+ help="The prompt to specify images in the same class as provided instance images.",
222
+ )
223
+ parser.add_argument(
224
+ "--with_prior_preservation",
225
+ default=False,
226
+ action="store_true",
227
+ help="Flag to add prior preservation loss.",
228
+ )
229
+ parser.add_argument(
230
+ "--prior_loss_weight",
231
+ type=float,
232
+ default=1.0,
233
+ help="The weight of prior preservation loss.",
234
+ )
235
+ parser.add_argument(
236
+ "--num_class_images",
237
+ type=int,
238
+ default=100,
239
+ help=(
240
+ "Minimal class images for prior preservation loss. If there are not enough images, additional images will be"
241
+ " sampled with class_prompt."
242
+ ),
243
+ )
244
+ parser.add_argument(
245
+ "--output_dir",
246
+ type=str,
247
+ default="text-inversion-model",
248
+ help="The output directory where the model predictions and checkpoints will be written.",
249
+ )
250
+ parser.add_argument(
251
+ "--output_format",
252
+ type=str,
253
+ choices=["pt", "safe", "both"],
254
+ default="both",
255
+ help="The output format of the model predictions and checkpoints.",
256
+ )
257
+ parser.add_argument(
258
+ "--seed", type=int, default=None, help="A seed for reproducible training."
259
+ )
260
+ parser.add_argument(
261
+ "--resolution",
262
+ type=int,
263
+ default=512,
264
+ help=(
265
+ "The resolution for input images; all the images in the train/validation dataset will be resized to this"
266
+ " resolution"
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--center_crop",
271
+ action="store_true",
272
+ help="Whether to center crop images before resizing to resolution",
273
+ )
274
+ parser.add_argument(
275
+ "--color_jitter",
276
+ action="store_true",
277
+ help="Whether to apply color jitter to images",
278
+ )
279
+ parser.add_argument(
280
+ "--train_text_encoder",
281
+ action="store_true",
282
+ help="Whether to train the text encoder",
283
+ )
284
+ parser.add_argument(
285
+ "--train_batch_size",
286
+ type=int,
287
+ default=4,
288
+ help="Batch size (per device) for the training dataloader.",
289
+ )
290
+ parser.add_argument(
291
+ "--sample_batch_size",
292
+ type=int,
293
+ default=4,
294
+ help="Batch size (per device) for sampling images.",
295
+ )
296
+ parser.add_argument("--num_train_epochs", type=int, default=1)
297
+ parser.add_argument(
298
+ "--max_train_steps",
299
+ type=int,
300
+ default=None,
301
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
302
+ )
303
+ parser.add_argument(
304
+ "--save_steps",
305
+ type=int,
306
+ default=500,
307
+ help="Save checkpoint every X updates steps.",
308
+ )
309
+ parser.add_argument(
310
+ "--gradient_accumulation_steps",
311
+ type=int,
312
+ default=1,
313
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
314
+ )
315
+ parser.add_argument(
316
+ "--gradient_checkpointing",
317
+ action="store_true",
318
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
319
+ )
320
+ parser.add_argument(
321
+ "--lora_rank",
322
+ type=int,
323
+ default=4,
324
+ help="Rank of LoRA approximation.",
325
+ )
326
+ parser.add_argument(
327
+ "--learning_rate",
328
+ type=float,
329
+ default=None,
330
+ help="Initial learning rate (after the potential warmup period) to use.",
331
+ )
332
+ parser.add_argument(
333
+ "--learning_rate_text",
334
+ type=float,
335
+ default=5e-6,
336
+ help="Initial learning rate for text encoder (after the potential warmup period) to use.",
337
+ )
338
+ parser.add_argument(
339
+ "--scale_lr",
340
+ action="store_true",
341
+ default=False,
342
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
343
+ )
344
+ parser.add_argument(
345
+ "--lr_scheduler",
346
+ type=str,
347
+ default="constant",
348
+ help=(
349
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
350
+ ' "constant", "constant_with_warmup"]'
351
+ ),
352
+ )
353
+ parser.add_argument(
354
+ "--lr_warmup_steps",
355
+ type=int,
356
+ default=500,
357
+ help="Number of steps for the warmup in the lr scheduler.",
358
+ )
359
+ parser.add_argument(
360
+ "--use_8bit_adam",
361
+ action="store_true",
362
+ help="Whether or not to use 8-bit Adam from bitsandbytes.",
363
+ )
364
+ parser.add_argument(
365
+ "--adam_beta1",
366
+ type=float,
367
+ default=0.9,
368
+ help="The beta1 parameter for the Adam optimizer.",
369
+ )
370
+ parser.add_argument(
371
+ "--adam_beta2",
372
+ type=float,
373
+ default=0.999,
374
+ help="The beta2 parameter for the Adam optimizer.",
375
+ )
376
+ parser.add_argument(
377
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
378
+ )
379
+ parser.add_argument(
380
+ "--adam_epsilon",
381
+ type=float,
382
+ default=1e-08,
383
+ help="Epsilon value for the Adam optimizer",
384
+ )
385
+ parser.add_argument(
386
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
387
+ )
388
+ parser.add_argument(
389
+ "--push_to_hub",
390
+ action="store_true",
391
+ help="Whether or not to push the model to the Hub.",
392
+ )
393
+ parser.add_argument(
394
+ "--hub_token",
395
+ type=str,
396
+ default=None,
397
+ help="The token to use to push to the Model Hub.",
398
+ )
399
+ parser.add_argument(
400
+ "--logging_dir",
401
+ type=str,
402
+ default="logs",
403
+ help=(
404
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
405
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
406
+ ),
407
+ )
408
+ parser.add_argument(
409
+ "--mixed_precision",
410
+ type=str,
411
+ default=None,
412
+ choices=["no", "fp16", "bf16"],
413
+ help=(
414
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
415
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the"
416
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
417
+ ),
418
+ )
419
+ parser.add_argument(
420
+ "--local_rank",
421
+ type=int,
422
+ default=-1,
423
+ help="For distributed training: local_rank",
424
+ )
425
+ parser.add_argument(
426
+ "--resume_unet",
427
+ type=str,
428
+ default=None,
429
+ help=("File path for unet lora to resume training."),
430
+ )
431
+ parser.add_argument(
432
+ "--resume_text_encoder",
433
+ type=str,
434
+ default=None,
435
+ help=("File path for text encoder lora to resume training."),
436
+ )
437
+ parser.add_argument(
438
+ "--resize",
439
+ type=bool,
440
+ default=True,
441
+ required=False,
442
+ help="Should images be resized to --resolution before training?",
443
+ )
444
+ parser.add_argument(
445
+ "--use_xformers", action="store_true", help="Whether or not to use xformers"
446
+ )
447
+
448
+ if input_args is not None:
449
+ args = parser.parse_args(input_args)
450
+ else:
451
+ args = parser.parse_args()
452
+
453
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
454
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
455
+ args.local_rank = env_local_rank
456
+
457
+ if args.with_prior_preservation:
458
+ if args.class_data_dir is None:
459
+ raise ValueError("You must specify a data directory for class images.")
460
+ if args.class_prompt is None:
461
+ raise ValueError("You must specify prompt for class images.")
462
+ else:
463
+ if args.class_data_dir is not None:
464
+ logger.warning(
465
+ "You need not use --class_data_dir without --with_prior_preservation."
466
+ )
467
+ if args.class_prompt is not None:
468
+ logger.warning(
469
+ "You need not use --class_prompt without --with_prior_preservation."
470
+ )
471
+
472
+ if not safetensors_available:
473
+ if args.output_format == "both":
474
+ print(
475
+ "Safetensors is not available - changing output format to just output PyTorch files"
476
+ )
477
+ args.output_format = "pt"
478
+ elif args.output_format == "safe":
479
+ raise ValueError(
480
+ "Safetensors is not available - either install it, or change output_format."
481
+ )
482
+
483
+ return args
484
+
485
+
486
+ def main(args):
487
+ logging_dir = Path(args.output_dir, args.logging_dir)
488
+
489
+ accelerator = Accelerator(
490
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
491
+ mixed_precision=args.mixed_precision,
492
+ log_with="tensorboard",
493
+ logging_dir=logging_dir,
494
+ )
495
+
496
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
497
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
498
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
499
+ if (
500
+ args.train_text_encoder
501
+ and args.gradient_accumulation_steps > 1
502
+ and accelerator.num_processes > 1
503
+ ):
504
+ raise ValueError(
505
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
506
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
507
+ )
508
+
509
+ if args.seed is not None:
510
+ set_seed(args.seed)
511
+
512
+ if args.with_prior_preservation:
513
+ class_images_dir = Path(args.class_data_dir)
514
+ if not class_images_dir.exists():
515
+ class_images_dir.mkdir(parents=True)
516
+ cur_class_images = len(list(class_images_dir.iterdir()))
517
+
518
+ if cur_class_images < args.num_class_images:
519
+ torch_dtype = (
520
+ torch.float16 if accelerator.device.type == "cuda" else torch.float32
521
+ )
522
+ pipeline = StableDiffusionPipeline.from_pretrained(
523
+ args.pretrained_model_name_or_path,
524
+ torch_dtype=torch_dtype,
525
+ safety_checker=None,
526
+ revision=args.revision,
527
+ )
528
+ pipeline.set_progress_bar_config(disable=True)
529
+
530
+ num_new_images = args.num_class_images - cur_class_images
531
+ logger.info(f"Number of class images to sample: {num_new_images}.")
532
+
533
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
534
+ sample_dataloader = torch.utils.data.DataLoader(
535
+ sample_dataset, batch_size=args.sample_batch_size
536
+ )
537
+
538
+ sample_dataloader = accelerator.prepare(sample_dataloader)
539
+ pipeline.to(accelerator.device)
540
+
541
+ for example in tqdm(
542
+ sample_dataloader,
543
+ desc="Generating class images",
544
+ disable=not accelerator.is_local_main_process,
545
+ ):
546
+ images = pipeline(example["prompt"]).images
547
+
548
+ for i, image in enumerate(images):
549
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
550
+ image_filename = (
551
+ class_images_dir
552
+ / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
553
+ )
554
+ image.save(image_filename)
555
+
556
+ del pipeline
557
+ if torch.cuda.is_available():
558
+ torch.cuda.empty_cache()
559
+
560
+ # Handle the repository creation
561
+ if accelerator.is_main_process:
562
+
563
+ if args.output_dir is not None:
564
+ os.makedirs(args.output_dir, exist_ok=True)
565
+
566
+ # Load the tokenizer
567
+ if args.tokenizer_name:
568
+ tokenizer = CLIPTokenizer.from_pretrained(
569
+ args.tokenizer_name,
570
+ revision=args.revision,
571
+ )
572
+ elif args.pretrained_model_name_or_path:
573
+ tokenizer = CLIPTokenizer.from_pretrained(
574
+ args.pretrained_model_name_or_path,
575
+ subfolder="tokenizer",
576
+ revision=args.revision,
577
+ )
578
+
579
+ # Load models and create wrapper for stable diffusion
580
+ text_encoder = CLIPTextModel.from_pretrained(
581
+ args.pretrained_model_name_or_path,
582
+ subfolder="text_encoder",
583
+ revision=args.revision,
584
+ )
585
+ vae = AutoencoderKL.from_pretrained(
586
+ args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
587
+ subfolder=None if args.pretrained_vae_name_or_path else "vae",
588
+ revision=None if args.pretrained_vae_name_or_path else args.revision,
589
+ )
590
+ unet = UNet2DConditionModel.from_pretrained(
591
+ args.pretrained_model_name_or_path,
592
+ subfolder="unet",
593
+ revision=args.revision,
594
+ )
595
+ unet.requires_grad_(False)
596
+ unet_lora_params, _ = inject_trainable_lora(
597
+ unet, r=args.lora_rank, loras=args.resume_unet
598
+ )
599
+
600
+ for _up, _down in extract_lora_ups_down(unet):
601
+ print("Before training: Unet First Layer lora up", _up.weight.data)
602
+ print("Before training: Unet First Layer lora down", _down.weight.data)
603
+ break
604
+
605
+ vae.requires_grad_(False)
606
+ text_encoder.requires_grad_(False)
607
+
608
+ if args.train_text_encoder:
609
+ text_encoder_lora_params, _ = inject_trainable_lora(
610
+ text_encoder,
611
+ target_replace_module=["CLIPAttention"],
612
+ r=args.lora_rank,
613
+ )
614
+ for _up, _down in extract_lora_ups_down(
615
+ text_encoder, target_replace_module=["CLIPAttention"]
616
+ ):
617
+ print("Before training: text encoder First Layer lora up", _up.weight.data)
618
+ print(
619
+ "Before training: text encoder First Layer lora down", _down.weight.data
620
+ )
621
+ break
622
+
623
+ if args.use_xformers:
624
+ set_use_memory_efficient_attention_xformers(unet, True)
625
+ set_use_memory_efficient_attention_xformers(vae, True)
626
+
627
+ if args.gradient_checkpointing:
628
+ unet.enable_gradient_checkpointing()
629
+ if args.train_text_encoder:
630
+ text_encoder.gradient_checkpointing_enable()
631
+
632
+ if args.scale_lr:
633
+ args.learning_rate = (
634
+ args.learning_rate
635
+ * args.gradient_accumulation_steps
636
+ * args.train_batch_size
637
+ * accelerator.num_processes
638
+ )
639
+
640
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
641
+ if args.use_8bit_adam:
642
+ try:
643
+ import bitsandbytes as bnb
644
+ except ImportError:
645
+ raise ImportError(
646
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
647
+ )
648
+
649
+ optimizer_class = bnb.optim.AdamW8bit
650
+ else:
651
+ optimizer_class = torch.optim.AdamW
652
+
653
+ text_lr = (
654
+ args.learning_rate
655
+ if args.learning_rate_text is None
656
+ else args.learning_rate_text
657
+ )
658
+
659
+ params_to_optimize = (
660
+ [
661
+ {"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
662
+ {
663
+ "params": itertools.chain(*text_encoder_lora_params),
664
+ "lr": text_lr,
665
+ },
666
+ ]
667
+ if args.train_text_encoder
668
+ else itertools.chain(*unet_lora_params)
669
+ )
670
+ optimizer = optimizer_class(
671
+ params_to_optimize,
672
+ lr=args.learning_rate,
673
+ betas=(args.adam_beta1, args.adam_beta2),
674
+ weight_decay=args.adam_weight_decay,
675
+ eps=args.adam_epsilon,
676
+ )
677
+
678
+ noise_scheduler = DDPMScheduler.from_config(
679
+ args.pretrained_model_name_or_path, subfolder="scheduler"
680
+ )
681
+
682
+ train_dataset = DreamBoothDataset(
683
+ instance_data_root=args.instance_data_dir,
684
+ instance_prompt=args.instance_prompt,
685
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
686
+ class_prompt=args.class_prompt,
687
+ tokenizer=tokenizer,
688
+ size=args.resolution,
689
+ center_crop=args.center_crop,
690
+ color_jitter=args.color_jitter,
691
+ resize=args.resize,
692
+ )
693
+
694
+ def collate_fn(examples):
695
+ input_ids = [example["instance_prompt_ids"] for example in examples]
696
+ pixel_values = [example["instance_images"] for example in examples]
697
+
698
+ # Concat class and instance examples for prior preservation.
699
+ # We do this to avoid doing two forward passes.
700
+ if args.with_prior_preservation:
701
+ input_ids += [example["class_prompt_ids"] for example in examples]
702
+ pixel_values += [example["class_images"] for example in examples]
703
+
704
+ pixel_values = torch.stack(pixel_values)
705
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
706
+
707
+ input_ids = tokenizer.pad(
708
+ {"input_ids": input_ids},
709
+ padding="max_length",
710
+ max_length=tokenizer.model_max_length,
711
+ return_tensors="pt",
712
+ ).input_ids
713
+
714
+ batch = {
715
+ "input_ids": input_ids,
716
+ "pixel_values": pixel_values,
717
+ }
718
+ return batch
719
+
720
+ train_dataloader = torch.utils.data.DataLoader(
721
+ train_dataset,
722
+ batch_size=args.train_batch_size,
723
+ shuffle=True,
724
+ collate_fn=collate_fn,
725
+ num_workers=1,
726
+ )
727
+
728
+ # Scheduler and math around the number of training steps.
729
+ overrode_max_train_steps = False
730
+ num_update_steps_per_epoch = math.ceil(
731
+ len(train_dataloader) / args.gradient_accumulation_steps
732
+ )
733
+ if args.max_train_steps is None:
734
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
735
+ overrode_max_train_steps = True
736
+
737
+ lr_scheduler = get_scheduler(
738
+ args.lr_scheduler,
739
+ optimizer=optimizer,
740
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
741
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
742
+ )
743
+
744
+ if args.train_text_encoder:
745
+ (
746
+ unet,
747
+ text_encoder,
748
+ optimizer,
749
+ train_dataloader,
750
+ lr_scheduler,
751
+ ) = accelerator.prepare(
752
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
753
+ )
754
+ else:
755
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
756
+ unet, optimizer, train_dataloader, lr_scheduler
757
+ )
758
+
759
+ weight_dtype = torch.float32
760
+ if accelerator.mixed_precision == "fp16":
761
+ weight_dtype = torch.float16
762
+ elif accelerator.mixed_precision == "bf16":
763
+ weight_dtype = torch.bfloat16
764
+
765
+ # Move text_encoder and vae to the GPU.
766
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
767
+ # as these models are only used for inference, keeping weights in full precision is not required.
768
+ vae.to(accelerator.device, dtype=weight_dtype)
769
+ if not args.train_text_encoder:
770
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
771
+
772
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
773
+ num_update_steps_per_epoch = math.ceil(
774
+ len(train_dataloader) / args.gradient_accumulation_steps
775
+ )
776
+ if overrode_max_train_steps:
777
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
778
+ # Afterwards we recalculate our number of training epochs
779
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
780
+
781
+ # We need to initialize the trackers we use, and also store our configuration.
782
+ # The trackers initialize automatically on the main process.
783
+ if accelerator.is_main_process:
784
+ accelerator.init_trackers("dreambooth", config=vars(args))
785
+
786
+ # Train!
787
+ total_batch_size = (
788
+ args.train_batch_size
789
+ * accelerator.num_processes
790
+ * args.gradient_accumulation_steps
791
+ )
792
+
793
+ print("***** Running training *****")
794
+ print(f" Num examples = {len(train_dataset)}")
795
+ print(f" Num batches each epoch = {len(train_dataloader)}")
796
+ print(f" Num Epochs = {args.num_train_epochs}")
797
+ print(f" Instantaneous batch size per device = {args.train_batch_size}")
798
+ print(
799
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
800
+ )
801
+ print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
802
+ print(f" Total optimization steps = {args.max_train_steps}")
803
+ # Only show the progress bar once on each machine.
804
+ progress_bar = tqdm(
805
+ range(args.max_train_steps), disable=not accelerator.is_local_main_process
806
+ )
807
+ progress_bar.set_description("Steps")
808
+ global_step = 0
809
+ last_save = 0
810
+
811
+ for epoch in range(args.num_train_epochs):
812
+ unet.train()
813
+ if args.train_text_encoder:
814
+ text_encoder.train()
815
+
816
+ for step, batch in enumerate(train_dataloader):
817
+ # Convert images to latent space
818
+ latents = vae.encode(
819
+ batch["pixel_values"].to(dtype=weight_dtype)
820
+ ).latent_dist.sample()
821
+ latents = latents * 0.18215
822
+
823
+ # Sample noise that we'll add to the latents
824
+ noise = torch.randn_like(latents)
825
+ bsz = latents.shape[0]
826
+ # Sample a random timestep for each image
827
+ timesteps = torch.randint(
828
+ 0,
829
+ noise_scheduler.config.num_train_timesteps,
830
+ (bsz,),
831
+ device=latents.device,
832
+ )
833
+ timesteps = timesteps.long()
834
+
835
+ # Add noise to the latents according to the noise magnitude at each timestep
836
+ # (this is the forward diffusion process)
837
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
838
+
839
+ # Get the text embedding for conditioning
840
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
841
+
842
+ # Predict the noise residual
843
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
844
+
845
+ # Get the target for loss depending on the prediction type
846
+ if noise_scheduler.config.prediction_type == "epsilon":
847
+ target = noise
848
+ elif noise_scheduler.config.prediction_type == "v_prediction":
849
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
850
+ else:
851
+ raise ValueError(
852
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
853
+ )
854
+
855
+ if args.with_prior_preservation:
856
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
857
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
858
+ target, target_prior = torch.chunk(target, 2, dim=0)
859
+
860
+ # Compute instance loss
861
+ loss = (
862
+ F.mse_loss(model_pred.float(), target.float(), reduction="none")
863
+ .mean([1, 2, 3])
864
+ .mean()
865
+ )
866
+
867
+ # Compute prior loss
868
+ prior_loss = F.mse_loss(
869
+ model_pred_prior.float(), target_prior.float(), reduction="mean"
870
+ )
871
+
872
+ # Add the prior loss to the instance loss.
873
+ loss = loss + args.prior_loss_weight * prior_loss
874
+ else:
875
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
876
+
877
+ accelerator.backward(loss)
878
+ if accelerator.sync_gradients:
879
+ params_to_clip = (
880
+ itertools.chain(unet.parameters(), text_encoder.parameters())
881
+ if args.train_text_encoder
882
+ else unet.parameters()
883
+ )
884
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
885
+ optimizer.step()
886
+ lr_scheduler.step()
887
+ progress_bar.update(1)
888
+ optimizer.zero_grad()
889
+
890
+ global_step += 1
891
+
892
+ # Checks if the accelerator has performed an optimization step behind the scenes
893
+ if accelerator.sync_gradients:
894
+ if args.save_steps and global_step - last_save >= args.save_steps:
895
+ if accelerator.is_main_process:
896
+ # newer versions of accelerate allow the 'keep_fp32_wrapper' arg. without passing
897
+ # it, the models will be unwrapped, and when they are then used for further training,
898
+ # we will crash. pass this, but only to newer versions of accelerate. fixes
899
+ # https://github.com/huggingface/diffusers/issues/1566
900
+ accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
901
+ inspect.signature(
902
+ accelerator.unwrap_model
903
+ ).parameters.keys()
904
+ )
905
+ extra_args = (
906
+ {"keep_fp32_wrapper": True}
907
+ if accepts_keep_fp32_wrapper
908
+ else {}
909
+ )
910
+ pipeline = StableDiffusionPipeline.from_pretrained(
911
+ args.pretrained_model_name_or_path,
912
+ unet=accelerator.unwrap_model(unet, **extra_args),
913
+ text_encoder=accelerator.unwrap_model(
914
+ text_encoder, **extra_args
915
+ ),
916
+ revision=args.revision,
917
+ )
918
+
919
+ filename_unet = (
920
+ f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
921
+ )
922
+ filename_text_encoder = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
923
+ print(f"save weights {filename_unet}, {filename_text_encoder}")
924
+ save_lora_weight(pipeline.unet, filename_unet)
925
+ if args.train_text_encoder:
926
+ save_lora_weight(
927
+ pipeline.text_encoder,
928
+ filename_text_encoder,
929
+ target_replace_module=["CLIPAttention"],
930
+ )
931
+
932
+ for _up, _down in extract_lora_ups_down(pipeline.unet):
933
+ print(
934
+ "First Unet Layer's Up Weight is now : ",
935
+ _up.weight.data,
936
+ )
937
+ print(
938
+ "First Unet Layer's Down Weight is now : ",
939
+ _down.weight.data,
940
+ )
941
+ break
942
+ if args.train_text_encoder:
943
+ for _up, _down in extract_lora_ups_down(
944
+ pipeline.text_encoder,
945
+ target_replace_module=["CLIPAttention"],
946
+ ):
947
+ print(
948
+ "First Text Encoder Layer's Up Weight is now : ",
949
+ _up.weight.data,
950
+ )
951
+ print(
952
+ "First Text Encoder Layer's Down Weight is now : ",
953
+ _down.weight.data,
954
+ )
955
+ break
956
+
957
+ last_save = global_step
958
+
959
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
960
+ progress_bar.set_postfix(**logs)
961
+ accelerator.log(logs, step=global_step)
962
+
963
+ if global_step >= args.max_train_steps:
964
+ break
965
+
966
+ accelerator.wait_for_everyone()
967
+
968
+ # Create the pipeline using the trained modules and save it.
969
+ if accelerator.is_main_process:
970
+ pipeline = StableDiffusionPipeline.from_pretrained(
971
+ args.pretrained_model_name_or_path,
972
+ unet=accelerator.unwrap_model(unet),
973
+ text_encoder=accelerator.unwrap_model(text_encoder),
974
+ revision=args.revision,
975
+ )
976
+
977
+ print("\n\nLora TRAINING DONE!\n\n")
978
+
979
+ if args.output_format == "pt" or args.output_format == "both":
980
+ save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
981
+ if args.train_text_encoder:
982
+ save_lora_weight(
983
+ pipeline.text_encoder,
984
+ args.output_dir + "/lora_weight.text_encoder.pt",
985
+ target_replace_module=["CLIPAttention"],
986
+ )
987
+
988
+ if args.output_format == "safe" or args.output_format == "both":
989
+ loras = {}
990
+ loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"})
991
+ if args.train_text_encoder:
992
+ loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"})
993
+
994
+ save_safeloras(loras, args.output_dir + "/lora_weight.safetensors")
995
+
996
+ if args.push_to_hub:
997
+ repo.push_to_hub(
998
+ commit_message="End of training",
999
+ blocking=False,
1000
+ auto_lfs_prune=True,
1001
+ )
1002
+
1003
+ accelerator.end_training()
1004
+
1005
+
1006
+ if __name__ == "__main__":
1007
+ args = parse_args()
1008
+ main(args)
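A minimal launch sketch (not part of the commit): parse_args accepts an explicit argument list, so the trainer can be driven from Python as well as from the command line. The model id, paths, and prompt are placeholders:

    args = parse_args([
        "--pretrained_model_name_or_path", "runwayml/stable-diffusion-v1-5",
        "--instance_data_dir", "./instance-images",
        "--instance_prompt", "a photo of a sks dog",
        "--output_dir", "./lora-output",
        "--resolution", "512",
        "--train_batch_size", "1",
        "--gradient_accumulation_steps", "1",
        "--learning_rate", "1e-4",
        "--lr_scheduler", "constant",
        "--lr_warmup_steps", "0",
        "--max_train_steps", "400",
        "--lora_rank", "4",
    ])
    main(args)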