Commit 0792c6b
csyxwei committed
1 Parent(s): e942d32

elite code init

.gitignore ADDED
@@ -0,0 +1,21 @@
1
+ _debug*
2
+ .env
3
+ __pycache__
4
+ _sc.py
5
+ *.ckpt
6
+ *.bin
7
+
8
+ checkpoints
9
+ .idea
10
+ .idea/workspace.xml
11
+ .DS_Store
12
+ */__pycache__git
13
+ .pyc
14
+ .iml
15
+ __pycache__/
16
+ */__pycache__/
17
+ */*/__pycache__/
18
+ */*/*/__pycache__/
19
+ */*/*/*/__pycache__/
20
+ */*/*/*/*/__pycache__/
21
+ */*/*/*/*/*/__pycache__/
2_gpu.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "distributed_type": "MULTI_GPU",
4
+ "fp16": false,
5
+ "machine_rank": 0,
6
+ "main_process_ip": null,
7
+ "main_process_port": null,
8
+ "main_training_function": "main",
9
+ "num_machines": 1,
10
+ "num_processes": 2
11
+ }
3_gpu.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "distributed_type": "MULTI_GPU",
4
+ "fp16": false,
5
+ "machine_rank": 0,
6
+ "main_process_ip": null,
7
+ "main_process_port": null,
8
+ "main_training_function": "main",
9
+ "num_machines": 1,
10
+ "num_processes": 3
11
+ }
4_gpu.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "distributed_type": "MULTI_GPU",
4
+ "fp16": false,
5
+ "machine_rank": 0,
6
+ "main_process_ip": null,
7
+ "main_process_port": null,
8
+ "main_training_function": "main",
9
+ "num_machines": 1,
10
+ "num_processes": 4
11
+ }
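Note: 2_gpu.json, 3_gpu.json, and 4_gpu.json look like Hugging Face accelerate launch configurations that differ only in num_processes. A minimal launch sketch under that assumption (the training script's own CLI flags are not shown in this commit, so none are filled in here):

# Minimal sketch (assumption): pick the config that matches the available GPUs.
# train_global.py is the training script added in this commit; append its real
# training arguments after the script name.
accelerate launch --config_file 2_gpu.json train_global.py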
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,3 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ # ELITE
2
+
3
+ The detailed README is coming soon.
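Since the new README is only a stub, here is a minimal inference quick-start assembled from the other files in this commit (the commands mirror inference_global.sh and inference_local.sh; the checkpoint locations are the paths those scripts expect, not something this commit documents):

# Assumes the 'elite' conda environment from elite.yaml is active (see the note
# after that file below) and that the pretrained mappers have been placed at:
#   ./checkpoints/global_mapper.pt
#   ./checkpoints/local_mapper.pt

# Global-mapping inference on ./test_datasets (writes to ./outputs/global_mapping):
bash inference_global.sh

# Global + local mapping inference (writes to ./outputs/local_mapping):
bash inference_local.sh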
datasets.py ADDED
@@ -0,0 +1,541 @@
1
+ from packaging import version
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ import os
5
+ import PIL
6
+ from torch.utils.data import Dataset
7
+ import torchvision
8
+ import numpy as np
9
+ import torch
10
+ import random
11
+ import albumentations as A
12
+ import copy
13
+ import cv2
14
+ import pandas as pd
15
+
16
+
17
+ imagenet_templates_small = [
18
+ "a photo of a {}",
19
+ "a rendering of a {}",
20
+ "a cropped photo of the {}",
21
+ "the photo of a {}",
22
+ "a photo of a clean {}",
23
+ "a photo of a dirty {}",
24
+ "a dark photo of the {}",
25
+ "a photo of my {}",
26
+ "a photo of the cool {}",
27
+ "a close-up photo of a {}",
28
+ "a bright photo of the {}",
29
+ "a cropped photo of a {}",
30
+ "a photo of the {}",
31
+ "a good photo of the {}",
32
+ "a photo of one {}",
33
+ "a close-up photo of the {}",
34
+ "a rendition of the {}",
35
+ "a photo of the clean {}",
36
+ "a rendition of a {}",
37
+ "a photo of a nice {}",
38
+ "a good photo of a {}",
39
+ "a photo of the nice {}",
40
+ "a photo of the small {}",
41
+ "a photo of the weird {}",
42
+ "a photo of the large {}",
43
+ "a photo of a cool {}",
44
+ "a photo of a small {}",
45
+ ]
46
+
47
+
48
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
49
+ PIL_INTERPOLATION = {
50
+ "linear": PIL.Image.Resampling.BILINEAR,
51
+ "bilinear": PIL.Image.Resampling.BILINEAR,
52
+ "bicubic": PIL.Image.Resampling.BICUBIC,
53
+ "lanczos": PIL.Image.Resampling.LANCZOS,
54
+ "nearest": PIL.Image.Resampling.NEAREST,
55
+ }
56
+ else:
57
+ PIL_INTERPOLATION = {
58
+ "linear": PIL.Image.LINEAR,
59
+ "bilinear": PIL.Image.BILINEAR,
60
+ "bicubic": PIL.Image.BICUBIC,
61
+ "lanczos": PIL.Image.LANCZOS,
62
+ "nearest": PIL.Image.NEAREST,
63
+ }
64
+
65
+ def is_image(file):
66
+ return 'jpg' in file.lower() or 'png' in file.lower() or 'jpeg' in file.lower()
67
+
68
+ class CustomDatasetWithBG(Dataset):
69
+ def __init__(
70
+ self,
71
+ data_root,
72
+ tokenizer,
73
+ size=512,
74
+ interpolation="bicubic",
75
+ placeholder_token="*",
76
+ template="a photo of a {}",
77
+ ):
78
+ self.data_root = data_root
79
+ self.tokenizer = tokenizer
80
+ self.size = size
81
+ self.placeholder_token = placeholder_token
82
+
83
+ self.image_paths = []
84
+ self.image_paths += [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if is_image(file_path) and not 'bg' in file_path]
85
+
86
+ self.image_paths = sorted(self.image_paths)
87
+
88
+ self.num_images = len(self.image_paths)
89
+ self._length = self.num_images
90
+
91
+ self.interpolation = {
92
+ "linear": PIL_INTERPOLATION["linear"],
93
+ "bilinear": PIL_INTERPOLATION["bilinear"],
94
+ "bicubic": PIL_INTERPOLATION["bicubic"],
95
+ "lanczos": PIL_INTERPOLATION["lanczos"],
96
+ }[interpolation]
97
+
98
+ self.template = template
99
+
100
+ def __len__(self):
101
+ return self._length
102
+
103
+ def get_tensor_clip(self, normalize=True, toTensor=True):
104
+ transform_list = []
105
+ if toTensor:
106
+ transform_list += [torchvision.transforms.ToTensor()]
107
+ if normalize:
108
+ transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
109
+ (0.26862954, 0.26130258, 0.27577711))]
110
+ return torchvision.transforms.Compose(transform_list)
111
+
112
+ def process(self, image):
113
+ img = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
114
+ img = np.array(img).astype(np.float32)
115
+ img = img / 127.5 - 1.0
116
+ return torch.from_numpy(img).permute(2, 0, 1)
117
+
118
+ def __getitem__(self, i):
119
+ example = {}
120
+
121
+ placeholder_string = self.placeholder_token
122
+ text = self.template.format(placeholder_string)
123
+ example["text"] = text
124
+
125
+ placeholder_index = 0
126
+ words = text.strip().split(' ')
127
+ for idx, word in enumerate(words):
128
+ if word == placeholder_string:
129
+ placeholder_index = idx + 1
130
+
131
+ example["index"] = torch.tensor(placeholder_index)
132
+
133
+ example["input_ids"] = self.tokenizer(
134
+ text,
135
+ padding="max_length",
136
+ truncation=True,
137
+ max_length=self.tokenizer.model_max_length,
138
+ return_tensors="pt",
139
+ ).input_ids[0]
140
+
141
+ image = Image.open(self.image_paths[i % self.num_images])
142
+
143
+ mask_path = self.image_paths[i % self.num_images].replace('.jpeg', '.png').replace('.jpg', '.png').replace('.JPEG', '.png')[:-4] + '_bg.png'
144
+ mask = np.array(Image.open(mask_path)) / 255.0
145
+
146
+ if not image.mode == "RGB":
147
+ image = image.convert("RGB")
148
+
149
+ image_np = np.array(image)
150
+ object_tensor = image_np * mask
151
+ example["pixel_values"] = self.process(image_np)
152
+
153
+
154
+ ref_object_tensor = Image.fromarray(object_tensor.astype('uint8')).resize((224, 224), resample=self.interpolation)
155
+ ref_image_tenser = Image.fromarray(image_np.astype('uint8')).resize((224, 224), resample=self.interpolation)
156
+ example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
157
+ example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tenser)
158
+
159
+ ref_seg_tensor = Image.fromarray(mask.astype('uint8') * 255)
160
+ ref_seg_tensor = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
161
+ example["pixel_values_seg"] = torch.nn.functional.interpolate(ref_seg_tensor.unsqueeze(0), size=(128, 128), mode='nearest').squeeze(0)
162
+
163
+ return example
164
+
165
+
166
+ class OpenImagesDataset(Dataset):
167
+ def __init__(
168
+ self,
169
+ data_root,
170
+ tokenizer,
171
+ size=512,
172
+ interpolation="bicubic",
173
+ set="train",
174
+ placeholder_token="*",
175
+ ):
176
+ self.data_root = data_root
177
+ self.tokenizer = tokenizer
178
+ self.size = size
179
+ self.placeholder_token = placeholder_token
180
+ self.set_type = set
181
+
182
+ self.random_trans = A.Compose([
183
+ A.Resize(height=224, width=224),
184
+ A.HorizontalFlip(p=0.5),
185
+ A.Rotate(limit=20),
186
+ A.Blur(p=0.3),
187
+ A.ElasticTransform(p=0.3)
188
+ ])
189
+
190
+ self.bbox_path_list = []
191
+ if set == "train":
192
+ bboxs_path = os.path.join(data_root, 'annotations', f'oidv6-train-annotations-bbox.csv')
193
+ elif set == "validation":
194
+ bboxs_path = os.path.join(data_root, 'annotations', f'validation-annotations-bbox.csv')
195
+ else:
196
+ bboxs_path = os.path.join(data_root, 'annotations', f'test-annotations-bbox.csv')
197
+
198
+ df_val_bbox = pd.read_csv(bboxs_path)
199
+ bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)
200
+
201
+ bbox_full = []
202
+ for label_name in df_val_bbox['LabelName'].unique():
203
+ bboxs = bbox_groups.get_group(label_name)[
204
+ ['XMin', 'XMax', 'YMin', 'YMax', 'LabelName', 'ImageID',
205
+ 'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsInside']].values.tolist()
206
+ bboxs_new = []
207
+ for bbox in bboxs:
208
+ if not ((bbox[1] - bbox[0]) * (bbox[3] - bbox[2]) > 0.8 or (bbox[1] - bbox[0]) * (
209
+ bbox[3] - bbox[2]) < 0.02):
210
+ bboxs_new.append([bbox[0], bbox[1], bbox[2], bbox[3], bbox[4], bbox[5]])
211
+ bbox_full.extend(bboxs_new)
212
+
213
+ self.bboxs_full = bbox_full
214
+
215
+ self.num_images = len(bbox_full)
216
+
217
+ print('{}: total {} images ...'.format(set, self.num_images))
218
+
219
+ self._length = self.num_images
220
+
221
+ self.interpolation = {
222
+ "linear": PIL_INTERPOLATION["linear"],
223
+ "bilinear": PIL_INTERPOLATION["bilinear"],
224
+ "bicubic": PIL_INTERPOLATION["bicubic"],
225
+ "lanczos": PIL_INTERPOLATION["lanczos"],
226
+ }[interpolation]
227
+
228
+ self.templates = imagenet_templates_small
229
+
230
+
231
+ def __len__(self):
232
+ return self._length
233
+
234
+ def get_tensor_clip(self, normalize=True, toTensor=True):
235
+ transform_list = []
236
+ if toTensor:
237
+ transform_list += [torchvision.transforms.ToTensor()]
238
+ if normalize:
239
+ transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
240
+ (0.26862954, 0.26130258, 0.27577711))]
241
+ return torchvision.transforms.Compose(transform_list)
242
+
243
+ def process(self, image):
244
+ img = np.array(image)
245
+ img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
246
+ img = np.array(img).astype(np.float32)
247
+ img = img / 127.5 - 1.0
248
+ return torch.from_numpy(img).permute(2, 0, 1)
249
+
250
+ def obtain_text(self, add_caption, object_category=None):
251
+
252
+ if object_category is None:
253
+ placeholder_string = self.placeholder_token
254
+ else:
255
+ placeholder_string = object_category
256
+
257
+ text = random.choice(self.templates).format(placeholder_string)
258
+ text = add_caption + text[1:]
259
+
260
+ placeholder_index = 0
261
+ words = text.strip().split(' ')
262
+ for idx, word in enumerate(words):
263
+ if word == placeholder_string:
264
+ placeholder_index = idx + 1
265
+
266
+ index = torch.tensor(placeholder_index)
267
+
268
+ input_ids = self.tokenizer(
269
+ text,
270
+ padding="max_length",
271
+ truncation=True,
272
+ max_length=self.tokenizer.model_max_length,
273
+ return_tensors="pt",
274
+ ).input_ids[0]
275
+ return input_ids, index, text
276
+
277
+ def __getitem__(self, i):
278
+ example = {}
279
+
280
+ input_ids, index, text = self.obtain_text('a')
281
+ example["input_ids"] = input_ids
282
+ example["index"] = index
283
+ example["text"] = text
284
+
285
+ bbox_sample = self.bboxs_full[i % self.num_images]
286
+ bbox_sample = copy.copy(bbox_sample)
287
+
288
+ file_name = bbox_sample[-1] + '.jpg'
289
+ img_path = os.path.join(self.data_root, 'images', self.set_type, file_name)
290
+
291
+ try:
292
+ img_p = Image.open(img_path).convert("RGB")
293
+ img_p_np = np.array(img_p)
294
+ bbox_sample[0] *= int(img_p_np.shape[1])
295
+ bbox_sample[1] *= int(img_p_np.shape[1])
296
+ bbox_sample[2] *= int(img_p_np.shape[0])
297
+ bbox_sample[3] *= int(img_p_np.shape[0])
298
+
299
+ bbox_pad = copy.copy(bbox_sample)
300
+ bbox_pad[0] = int(bbox_sample[0] - min(10, bbox_sample[0] - 0))
301
+ bbox_pad[1] = int(bbox_sample[1] + min(10, img_p.size[0] - bbox_sample[1]))
302
+ bbox_pad[2] = int(bbox_sample[2] - min(10, bbox_sample[2] - 0))
303
+ bbox_pad[3] = int(bbox_sample[3] + min(10, img_p.size[1] - bbox_sample[3]))
304
+
305
+ image_tensor = img_p_np[bbox_pad[2]:bbox_pad[3], bbox_pad[0]:bbox_pad[1], :]
306
+ example["pixel_values"] = self.process(image_tensor)
307
+
308
+ ref_image_tensor = self.random_trans(image=image_tensor)
309
+ ref_image_tensor = Image.fromarray(ref_image_tensor["image"])
310
+ example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tensor)
311
+
312
+ except Exception as e:
313
+ example["pixel_values"] = torch.zeros((3, 512, 512))
314
+ example["pixel_values_clip"] = torch.zeros((3, 224, 224))
315
+ with open('error.txt', 'a+') as f:
316
+ f.write(str(e) + '\n')
317
+
318
+ return example
319
+
320
+
321
+ class OpenImagesDatasetWithMask(OpenImagesDataset):
322
+ def __init__(self,
323
+ data_root,
324
+ tokenizer,
325
+ size=512,
326
+ interpolation="bicubic",
327
+ set="train",
328
+ placeholder_token="*"):
329
+
330
+ # super().__init__(data_root, tokenizer, size, interpolation, set, placeholder_token)
331
+ self.data_root = data_root
332
+ self.tokenizer = tokenizer
333
+ self.size = size
334
+ self.placeholder_token = placeholder_token
335
+ self.set = set
336
+
337
+ class_anno_path = os.path.join(data_root, 'annotations', f'oidv6-class-descriptions.csv')
338
+ anno_files = pd.read_csv(class_anno_path)
339
+ class_groups = anno_files.groupby(anno_files.LabelName)
340
+
341
+ if set == "train":
342
+ bboxs_path = os.path.join(data_root, 'annotations', f'train-annotations-object-segmentation.csv')
343
+ dict_path = os.path.join(data_root, 'segs', f'train_bbox_dict.npy')
344
+ elif set == "validation":
345
+ bboxs_path = os.path.join(data_root, 'annotations', f'validation-annotations-object-segmentation.csv')
346
+ dict_path = os.path.join(data_root, 'segs', f'validation_bbox_dict.npy')
347
+ else:
348
+ bboxs_path = os.path.join(data_root, 'annotations', f'test-annotations-object-segmentation.csv')
349
+ dict_path = os.path.join(data_root, 'segs', f'test_bbox_dict.npy')
350
+
351
+ bbox_dict = np.load(dict_path, allow_pickle=True).item()
352
+
353
+ df_val_bbox = pd.read_csv(bboxs_path)
354
+ bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)
355
+ bboxes_full = []
356
+ for label_name in df_val_bbox['LabelName'].unique():
357
+ bboxs = bbox_groups.get_group(label_name)[
358
+ ['BoxXMin', 'BoxXMax', 'BoxYMin', 'BoxYMax', 'LabelName', 'MaskPath']].values.tolist()
359
+ bboxes_new = []
360
+ for box in bboxs:
361
+ if not box[-1] in bbox_dict:
362
+ continue
363
+ bbox_data = bbox_dict[box[-1]]
364
+
365
+ if (bbox_data[2] - bbox_data[1]) < 100 or (bbox_data[4] - bbox_data[3]) < 100:
366
+ continue
367
+ if not ((bbox_data[2] - bbox_data[1]) / (bbox_data[4] - bbox_data[3]) < 0.5 or (
368
+ bbox_data[4] - bbox_data[3]) / ( bbox_data[2] - bbox_data[1]) < 0.5):
369
+ class_name = class_groups.get_group(box[4])[['DisplayName']].values.tolist()[0][0]
370
+ bboxes_new.append([box[-1], bbox_data[1], bbox_data[2], bbox_data[3], bbox_data[4], class_name])
371
+
372
+ bboxes_full.extend(bboxes_new)
373
+
374
+ self.bboxes_full = bboxes_full
375
+ self.num_images = len(bboxes_full)
376
+
377
+ print('{}: total {} images ...'.format(set, self.num_images))
378
+
379
+ self._length = self.num_images
380
+ self.interpolation = {
381
+ "linear": PIL_INTERPOLATION["linear"],
382
+ "bilinear": PIL_INTERPOLATION["bilinear"],
383
+ "bicubic": PIL_INTERPOLATION["bicubic"],
384
+ "lanczos": PIL_INTERPOLATION["lanczos"],
385
+ }[interpolation]
386
+
387
+ self.templates = imagenet_templates_small
388
+
389
+
390
+ def __len__(self):
391
+ return self._length
392
+
393
+ ## borrowed from custom diffusion
394
+ def custom_aug(self, instance_image):
395
+ instance_image = Image.fromarray(instance_image)
396
+ #### apply augmentation and create a valid image regions mask ####
397
+ if np.random.randint(0, 3) < 2:
398
+ random_scale = np.random.randint(self.size // 3, self.size + 1)
399
+ else:
400
+ random_scale = np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
401
+
402
+ if random_scale % 2 == 1:
403
+ random_scale += 1
404
+
405
+ if random_scale < 0.6 * self.size:
406
+ add_to_caption = np.random.choice(["a far away", "very small"])
407
+ cx = np.random.randint(random_scale // 2, self.size - random_scale // 2 + 1)
408
+ cy = np.random.randint(random_scale // 2, self.size - random_scale // 2 + 1)
409
+
410
+ instance_image1 = instance_image.resize((random_scale, random_scale), resample=self.interpolation)
411
+ instance_image1 = np.array(instance_image1).astype(np.uint8)
412
+ instance_image1 = (instance_image1 / 127.5 - 1.0).astype(np.float32)
413
+
414
+ instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
415
+ instance_image[cx - random_scale // 2: cx + random_scale // 2,
416
+ cy - random_scale // 2: cy + random_scale // 2, :] = instance_image1
417
+
418
+ mask = np.zeros((self.size // 8, self.size // 8))
419
+ mask[(cx - random_scale // 2) // 8 + 1: (cx + random_scale // 2) // 8 - 1,
420
+ (cy - random_scale // 2) // 8 + 1: (cy + random_scale // 2) // 8 - 1] = 1.
421
+
422
+ elif random_scale > self.size:
423
+ add_to_caption = np.random.choice(["zoomed in", "close up"])
424
+ cx = np.random.randint(self.size // 2, random_scale - self.size // 2 + 1)
425
+ cy = np.random.randint(self.size // 2, random_scale - self.size // 2 + 1)
426
+
427
+ instance_image = instance_image.resize((random_scale, random_scale), resample=self.interpolation)
428
+ instance_image = np.array(instance_image).astype(np.uint8)
429
+ instance_image = (instance_image / 127.5 - 1.0).astype(np.float32)
430
+ instance_image = instance_image[cx - self.size // 2: cx + self.size // 2,
431
+ cy - self.size // 2: cy + self.size // 2, :]
432
+ mask = np.ones((self.size // 8, self.size // 8))
433
+ else:
434
+ add_to_caption = "a"
435
+ if self.size is not None:
436
+ instance_image = instance_image.resize((self.size, self.size), resample=self.interpolation)
437
+ instance_image = np.array(instance_image).astype(np.uint8)
438
+ instance_image = (instance_image / 127.5 - 1.0).astype(np.float32)
439
+ mask = np.ones((self.size // 8, self.size // 8))
440
+
441
+ return torch.from_numpy(instance_image).permute(2, 0, 1), torch.from_numpy(mask[:, :, None]).permute(2, 0, 1), add_to_caption
442
+
443
+ def aug_cv2(self, img, seg):
444
+
445
+ img_auged = np.array(img).copy()
446
+ seg_auged = np.array(seg).copy()
447
+ # resize and crop
448
+ if random.choice([0, 1]) == 0:
449
+ new_size = random.randint(224, 256)
450
+ img_auged = cv2.resize(img_auged, (new_size, new_size), interpolation=cv2.INTER_CUBIC)
451
+ seg_auged = cv2.resize(seg_auged, (new_size, new_size), interpolation=cv2.INTER_NEAREST)
452
+
453
+ start_x, start_y = random.randint(0, new_size - 224), random.randint(0, new_size - 224)
454
+ img_auged = img_auged[start_x:start_x + 224, start_y:start_y + 224, :]
455
+ seg_auged = seg_auged[start_x:start_x + 224, start_y:start_y + 224, :]
456
+
457
+ h, w = img_auged.shape[:2]
458
+ # rotate
459
+ if random.choice([0, 1]) == 0:
460
+ # print('rotate')
461
+ angle = random.randint(-30, 30)
462
+ M = cv2.getRotationMatrix2D((112, 112), angle, 1)
463
+ img_auged = cv2.warpAffine(img_auged, M, (w, h), flags=cv2.INTER_CUBIC)
464
+ seg_auged = cv2.warpAffine(seg_auged, M, (w, h), flags=cv2.INTER_NEAREST)
465
+
466
+ # translation
467
+ if random.choice([0, 1]) == 0:
468
+ trans_x = random.randint(-60, 60)
469
+ trans_y = random.randint(-60, 60)
470
+ H = np.float32([[1, 0, trans_x],
471
+ [0, 1, trans_y]])
472
+ img_auged = cv2.warpAffine(img_auged, H, (w, h), flags=cv2.INTER_CUBIC)
473
+ seg_auged = cv2.warpAffine(seg_auged, H, (w, h), flags=cv2.INTER_NEAREST)
474
+
475
+ img_auged = Image.fromarray(img_auged)
476
+ seg_auged = Image.fromarray(seg_auged)
477
+
478
+ return img_auged, seg_auged
479
+
480
+
481
+ def __getitem__(self, i):
482
+ example = {}
483
+
484
+ seg_name = self.bboxes_full[i % self.num_images][0]
485
+ file_name = seg_name.split('_')[0] + '.jpg'
486
+ img_path = os.path.join(self.data_root, 'images', self.set, file_name)
487
+ seg_path = os.path.join(self.data_root, 'segs', self.set, seg_name)
488
+
489
+ try:
490
+ # crop image and mask
491
+ bbox_sample = self.bboxes_full[i % self.num_images][1:]
492
+ img_p_np = cv2.imread(img_path)
493
+ img_p_np = cv2.cvtColor(img_p_np, cv2.COLOR_BGR2RGB)
494
+ seg_p_np = cv2.imread(seg_path).astype('float')
495
+ seg_p_np = cv2.resize(seg_p_np, img_p_np.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
496
+
497
+ bbox_pad = copy.copy(bbox_sample)
498
+ pad_size = random.choice(list(range(10, 20)))
499
+ bbox_pad[0] = int(bbox_pad[0] - min(pad_size, bbox_pad[0] - 0))
500
+ bbox_pad[1] = int(bbox_pad[1] + pad_size)
501
+ bbox_pad[2] = int(bbox_pad[2] - min(pad_size, bbox_pad[2] - 0))
502
+ bbox_pad[3] = int(bbox_pad[3] + pad_size)
503
+
504
+ image_tensor = img_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
505
+ seg_tensor = seg_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
506
+
507
+ # augmentation for input image
508
+ augged_image, augged_mask, add_caption = self.custom_aug(image_tensor)
509
+ input_ids, index, text = self.obtain_text(add_caption)
510
+
511
+ example["pixel_values"] = augged_image
512
+ example["mask_values"] = augged_mask
513
+ example["input_ids"] = input_ids
514
+ example["index"] = index
515
+ example["text"] = text
516
+
517
+ object_tensor = image_tensor * (seg_tensor / 255)
518
+ ref_object_tensor = cv2.resize(object_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
519
+ ref_image_tenser = cv2.resize(image_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
520
+ ref_seg_tensor = cv2.resize(seg_tensor, (224, 224), interpolation=cv2.INTER_NEAREST)
521
+
522
+ ref_object_tensor, ref_seg_tensor = self.aug_cv2(ref_object_tensor.astype('uint8'), ref_seg_tensor.astype('uint8'))
523
+ example["pixel_values_clip"] = self.get_tensor_clip()(Image.fromarray(ref_image_tenser))
524
+ example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
525
+ example["pixel_values_seg"] = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
526
+
527
+ except Exception as e:
528
+ example["pixel_values"] = torch.zeros((3, 512, 512))
529
+ example["pixel_values_obj"] = torch.zeros((3, 224, 224))
530
+ example["pixel_values_clip"] = torch.zeros((3, 224, 224))
531
+ example["pixel_values_seg"] = torch.zeros((3, 224, 224))
532
+
533
+ input_ids, index, text = self.obtain_text("a")
534
+ example["input_ids"] = input_ids
535
+ example["index"] = index
536
+ example["text"] = text
537
+
538
+ with open('error.txt', 'a+') as f:
539
+ f.write(str(e) + '\n')
540
+
541
+ return example
elite.yaml ADDED
@@ -0,0 +1,147 @@
1
+ name: elite
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1=main
6
+ - ca-certificates=2022.10.11=h06a4308_0
7
+ - certifi=2022.9.24=py39h06a4308_0
8
+ - ld_impl_linux-64=2.38=h1181459_1
9
+ - libffi=3.3=he6710b0_2
10
+ - libgcc-ng=9.1.0=hdf63c60_0
11
+ - libstdcxx-ng=9.1.0=hdf63c60_0
12
+ - ncurses=6.3=h7f8727e_2
13
+ - openssl=1.1.1s=h7f8727e_0
14
+ - pip=22.2.2=py39h06a4308_0
15
+ - python=3.9.12=h12debd9_1
16
+ - readline=8.1.2=h7f8727e_1
17
+ - sqlite=3.38.5=hc218d9a_0
18
+ - tk=8.6.12=h1ccaba5_0
19
+ - wheel=0.37.1=pyhd3eb1b0_0
20
+ - xz=5.2.5=h7f8727e_1
21
+ - zlib=1.2.12=h7f8727e_2
22
+ - pip:
23
+ - absl-py==1.3.0
24
+ - accelerate==0.15.0
25
+ - aiohttp==3.8.3
26
+ - aiosignal==1.3.1
27
+ - albumentations==1.1.0
28
+ - altair==4.2.0
29
+ - antlr4-python3-runtime==4.8
30
+ - async-timeout==4.0.2
31
+ - attrs==22.1.0
32
+ - blinker==1.5
33
+ - cachetools==5.2.0
34
+ - charset-normalizer==2.1.1
35
+ - click==8.1.3
36
+ - commonmark==0.9.1
37
+ - contourpy==1.0.6
38
+ - cycler==0.11.0
39
+ - cython==0.29.33
40
+ - decorator==5.1.1
41
+ - diffusers==0.11.1
42
+ - einops==0.4.1
43
+ - emoji==2.2.0
44
+ - entrypoints==0.4
45
+ - faiss-gpu==1.7.2
46
+ - filelock==3.8.0
47
+ - fonttools==4.38.0
48
+ - frozenlist==1.3.3
49
+ - fsspec==2022.11.0
50
+ - ftfy==6.1.1
51
+ - future==0.18.2
52
+ - gitdb==4.0.9
53
+ - gitpython==3.1.29
54
+ - google-auth==2.14.1
55
+ - google-auth-oauthlib==0.4.6
56
+ - grpcio==1.50.0
57
+ - huggingface-hub==0.11.0
58
+ - idna==3.4
59
+ - imageio==2.14.1
60
+ - imageio-ffmpeg==0.4.7
61
+ - importlib-metadata==5.0.0
62
+ - jinja2==3.1.2
63
+ - joblib==1.2.0
64
+ - jsonschema==4.17.0
65
+ - kiwisolver==1.4.4
66
+ - kornia==0.6.0
67
+ - markdown==3.4.1
68
+ - markupsafe==2.1.1
69
+ - matplotlib==3.6.2
70
+ - multidict==6.0.2
71
+ - networkx==2.8.8
72
+ - nltk==3.7
73
+ - numpy==1.23.4
74
+ - oauthlib==3.2.2
75
+ - omegaconf==2.1.1
76
+ - opencv-python==4.6.0.66
77
+ - opencv-python-headless==4.6.0.66
78
+ - packaging==21.3
79
+ - pandas==1.5.1
80
+ - pillow==9.0.1
81
+ - protobuf==3.20.1
82
+ - psutil==5.9.4
83
+ - pudb==2019.2
84
+ - pyarrow==10.0.0
85
+ - pyasn1==0.4.8
86
+ - pyasn1-modules==0.2.8
87
+ - pycocotools==2.0.6
88
+ - pydeck==0.8.0
89
+ - pydensecrf==1.0rc2
90
+ - pydeprecate==0.3.2
91
+ - pygments==2.13.0
92
+ - pympler==1.0.1
93
+ - pyparsing==3.0.9
94
+ - pyrsistent==0.19.2
95
+ - python-dateutil==2.8.2
96
+ - python-dotenv==0.21.0
97
+ - pytorch-lightning==1.6.5
98
+ - pytz==2022.6
99
+ - pytz-deprecation-shim==0.1.0.post0
100
+ - pywavelets==1.4.1
101
+ - pyyaml==6.0
102
+ - qudida==0.0.4
103
+ - regex==2022.10.31
104
+ - requests==2.28.1
105
+ - requests-oauthlib==1.3.1
106
+ - rich==12.6.0
107
+ - rsa==4.9
108
+ - sacremoses==0.0.53
109
+ - scikit-image==0.19.3
110
+ - scikit-learn==1.1.3
111
+ - scipy==1.9.3
112
+ - semver==2.13.0
113
+ - setuptools==59.5.0
114
+ - six==1.16.0
115
+ - smmap==5.0.0
116
+ - stanza==1.4.2
117
+ - streamlit==1.15.0
118
+ - tensorboard==2.11.0
119
+ - tensorboard-data-server==0.6.1
120
+ - tensorboard-plugin-wit==1.8.1
121
+ - test-tube==0.7.5
122
+ - threadpoolctl==3.1.0
123
+ - tifffile==2022.10.10
124
+ - timm==0.6.12
125
+ - tokenizers==0.12.1
126
+ - toml==0.10.2
127
+ - toolz==0.12.0
128
+ - torch==1.12.1+cu116
129
+ - torch-fidelity==0.3.0
130
+ - torchaudio==0.12.1+cu116
131
+ - torchmetrics==0.6.0
132
+ - torchvision==0.13.1+cu116
133
+ - tornado==6.2
134
+ - tqdm==4.64.1
135
+ - transformers==4.25.1
136
+ - typing-extensions==4.4.0
137
+ - tzdata==2022.6
138
+ - tzlocal==4.2
139
+ - urllib3==1.26.12
140
+ - urwid==2.1.2
141
+ - validators==0.20.0
142
+ - watchdog==2.1.9
143
+ - wcwidth==0.2.5
144
+ - werkzeug==2.2.2
145
+ - yarl==1.8.1
146
+ - zipp==3.10.0
147
+
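A sketch of materializing this environment with conda; note that the pip section pins CUDA 11.6 wheels (torch==1.12.1+cu116 and friends), which pip resolves only from the PyTorch wheel index, so the fallback below is an assumption rather than part of this commit:

# Create and activate the environment defined in elite.yaml.
conda env create -f elite.yaml
conda activate elite

# If the create step fails to find the +cu116 wheels, installing them from the
# PyTorch wheel index afterwards is a common workaround (index URL assumed):
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1+cu116 \
    --extra-index-url https://download.pytorch.org/whl/cu116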
inference_global.py ADDED
@@ -0,0 +1,214 @@
1
+ import os
2
+ from typing import Optional, Tuple
3
+ import numpy as np
4
+ import torch
5
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
6
+ from PIL import Image
7
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
8
+ from train_global import Mapper, th2image
9
+ from train_global import inj_forward_text, inj_forward_crossattention, validation
10
+ import torch.nn as nn
11
+ from datasets import CustomDatasetWithBG
12
+
13
+ def _pil_from_latents(vae, latents):
14
+ _latents = 1 / 0.18215 * latents.clone()
15
+ image = vae.decode(_latents).sample
16
+
17
+ image = (image / 2 + 0.5).clamp(0, 1)
18
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
19
+ images = (image * 255).round().astype("uint8")
20
+ ret_pil_images = [Image.fromarray(image) for image in images]
21
+
22
+ return ret_pil_images
23
+
24
+
25
+ def pww_load_tools(
26
+ device: str = "cuda:0",
27
+ scheduler_type=LMSDiscreteScheduler,
28
+ mapper_model_path: Optional[str] = None,
29
+ diffusion_model_path: Optional[str] = None,
30
+ model_token: Optional[str] = None,
31
+ ) -> Tuple[
32
+ UNet2DConditionModel,
33
+ CLIPTextModel,
34
+ CLIPTokenizer,
35
+ AutoencoderKL,
36
+ CLIPVisionModel,
37
+ Mapper,
38
+ LMSDiscreteScheduler,
39
+ ]:
40
+
41
+ # 'CompVis/stable-diffusion-v1-4'
42
+ local_path_only = diffusion_model_path is not None
43
+ vae = AutoencoderKL.from_pretrained(
44
+ diffusion_model_path,
45
+ subfolder="vae",
46
+ use_auth_token=model_token,
47
+ torch_dtype=torch.float16,
48
+ local_files_only=local_path_only,
49
+ )
50
+
51
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
52
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
53
+ image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
54
+
55
+
56
+ # Load models and create wrapper for stable diffusion
57
+ for _module in text_encoder.modules():
58
+ if _module.__class__.__name__ == "CLIPTextTransformer":
59
+ _module.__class__.__call__ = inj_forward_text
60
+
61
+ unet = UNet2DConditionModel.from_pretrained(
62
+ diffusion_model_path,
63
+ subfolder="unet",
64
+ use_auth_token=model_token,
65
+ torch_dtype=torch.float16,
66
+ local_files_only=local_path_only,
67
+ )
68
+
69
+ mapper = Mapper(input_dim=1024, output_dim=768)
70
+
71
+ for _name, _module in unet.named_modules():
72
+ if _module.__class__.__name__ == "CrossAttention":
73
+ if 'attn1' in _name: continue
74
+ _module.__class__.__call__ = inj_forward_crossattention
75
+
76
+ shape = _module.to_k.weight.shape
77
+ to_k_global = nn.Linear(shape[1], shape[0], bias=False)
78
+ mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)
79
+
80
+ shape = _module.to_v.weight.shape
81
+ to_v_global = nn.Linear(shape[1], shape[0], bias=False)
82
+ mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)
83
+
84
+ mapper.load_state_dict(torch.load(mapper_model_path, map_location='cpu'))
85
+ mapper.half()
86
+
87
+ for _name, _module in unet.named_modules():
88
+ if 'attn1' in _name: continue
89
+ if _module.__class__.__name__ == "CrossAttention":
90
+ _module.add_module('to_k_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_k'))
91
+ _module.add_module('to_v_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_v'))
92
+
93
+ vae.to(device), unet.to(device), text_encoder.to(device), image_encoder.to(device), mapper.to(device)
94
+
95
+ scheduler = scheduler_type(
96
+ beta_start=0.00085,
97
+ beta_end=0.012,
98
+ beta_schedule="scaled_linear",
99
+ num_train_timesteps=1000,
100
+ )
101
+ vae.eval()
102
+ unet.eval()
103
+ image_encoder.eval()
104
+ text_encoder.eval()
105
+ mapper.eval()
106
+ return vae, unet, text_encoder, tokenizer, image_encoder, mapper, scheduler
107
+
108
+
109
+ def parse_args():
110
+ import argparse
111
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
112
+ parser.add_argument(
113
+ "--token_index",
114
+ type=str,
115
+ default="full",
116
+ help="Selected index for word embedding.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--global_mapper_path",
121
+ type=str,
122
+ required=True,
123
+ help="Path to pretrained global mapping network.",
124
+ )
125
+
126
+ parser.add_argument(
127
+ "--output_dir",
128
+ type=str,
129
+ default='outputs',
130
+ help="The output directory where the model predictions will be written.",
131
+ )
132
+
133
+ parser.add_argument(
134
+ "--placeholder_token",
135
+ type=str,
136
+ default="S",
137
+ help="A token to use as a placeholder for the concept.",
138
+ )
139
+
140
+ parser.add_argument(
141
+ "--template",
142
+ type=str,
143
+ default="a photo of a {}",
144
+ help="Text template for customized genetation.",
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--test_data_dir", type=str, default=None, required=True, help="A folder containing the testing data."
149
+ )
150
+
151
+ parser.add_argument(
152
+ "--pretrained_model_name_or_path",
153
+ type=str,
154
+ default=None,
155
+ required=True,
156
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
157
+ )
158
+
159
+ parser.add_argument(
160
+ "--suffix",
161
+ type=str,
162
+ default="object",
163
+ help="Suffix of save directory.",
164
+ )
165
+
166
+ parser.add_argument(
167
+ "--selected_data",
168
+ type=int,
169
+ default=-1,
170
+ help="Data index. -1 for all.",
171
+ )
172
+
173
+ args = parser.parse_args()
174
+ return args
175
+
176
+
177
+ if __name__ == "__main__":
178
+ args = parse_args()
179
+
180
+ save_dir = os.path.join(args.output_dir, f'{args.suffix}_token{args.token_index}')
181
+ os.makedirs(save_dir, exist_ok=True)
182
+
183
+ vae, unet, text_encoder, tokenizer, image_encoder, mapper, scheduler = pww_load_tools(
184
+ "cuda:0",
185
+ LMSDiscreteScheduler,
186
+ diffusion_model_path=args.pretrained_model_name_or_path,
187
+ mapper_model_path=args.global_mapper_path,
188
+ )
189
+
190
+ train_dataset = CustomDatasetWithBG(
191
+ data_root=args.test_data_dir,
192
+ tokenizer=tokenizer,
193
+ size=512,
194
+ placeholder_token=args.placeholder_token,
195
+ template=args.template,
196
+ )
197
+
198
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
199
+ for step, batch in enumerate(train_dataloader):
200
+ if args.selected_data > -1 and step != args.selected_data:
201
+ continue
202
+ batch["pixel_values"] = batch["pixel_values"].to("cuda:0")
203
+ batch["pixel_values_clip"] = batch["pixel_values_clip"].to("cuda:0").half()
204
+ batch["input_ids"] = batch["input_ids"].to("cuda:0")
205
+ batch["index"] = batch["index"].to("cuda:0").long()
206
+ print(step, batch['text'])
207
+ seeds = [0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
208
+ 8269, 87367, 29379, 4658, 39, 598]
209
+ seeds = sorted(seeds)
210
+ for seed in seeds:
211
+ syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5,
212
+ token_index=args.token_index, seed=seed)
213
+ concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
214
+ Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
inference_global.sh ADDED
@@ -0,0 +1,12 @@
1
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export DATA_DIR='./test_datasets/'
3
+
4
+ CUDA_VISIBLE_DEVICES=6 python inference_global.py \
5
+ --pretrained_model_name_or_path=$MODEL_NAME \
6
+ --test_data_dir=$DATA_DIR \
7
+ --output_dir="./outputs/global_mapping" \
8
+ --suffix="object" \
9
+ --token_index="0" \
10
+ --template="a photo of a {}" \
11
+ --global_mapper_path="./checkpoints/global_mapper.pt"
12
+
inference_local.py ADDED
@@ -0,0 +1,247 @@
1
+ import os
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
8
+ from PIL import Image
9
+ from tqdm.auto import tqdm
10
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
11
+ from train_local import Mapper, th2image, MapperLocal
12
+ from train_local import inj_forward_text, inj_forward_crossattention, validation
13
+ import torch.nn as nn
14
+ from datasets import CustomDatasetWithBG
15
+
16
+ def _pil_from_latents(vae, latents):
17
+ _latents = 1 / 0.18215 * latents.clone()
18
+ image = vae.decode(_latents).sample
19
+
20
+ image = (image / 2 + 0.5).clamp(0, 1)
21
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
22
+ images = (image * 255).round().astype("uint8")
23
+ ret_pil_images = [Image.fromarray(image) for image in images]
24
+
25
+ return ret_pil_images
26
+
27
+
28
+ def pww_load_tools(
29
+ device: str = "cuda:0",
30
+ scheduler_type=LMSDiscreteScheduler,
31
+ mapper_model_path: Optional[str] = None,
32
+ mapper_local_model_path: Optional[str] = None,
33
+ diffusion_model_path: Optional[str] = None,
34
+ model_token: Optional[str] = None,
35
+ ) -> Tuple[
36
+ UNet2DConditionModel,
37
+ CLIPTextModel,
38
+ CLIPTokenizer,
39
+ AutoencoderKL,
40
+ CLIPVisionModel,
41
+ Mapper,
42
+ MapperLocal,
43
+ LMSDiscreteScheduler,
44
+ ]:
45
+
46
+ # 'CompVis/stable-diffusion-v1-4'
47
+ local_path_only = diffusion_model_path is not None
48
+ vae = AutoencoderKL.from_pretrained(
49
+ diffusion_model_path,
50
+ subfolder="vae",
51
+ use_auth_token=model_token,
52
+ torch_dtype=torch.float16,
53
+ local_files_only=local_path_only,
54
+ )
55
+
56
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
57
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
58
+ image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16,)
59
+
60
+
61
+ # Load models and create wrapper for stable diffusion
62
+ for _module in text_encoder.modules():
63
+ if _module.__class__.__name__ == "CLIPTextTransformer":
64
+ _module.__class__.__call__ = inj_forward_text
65
+
66
+ unet = UNet2DConditionModel.from_pretrained(
67
+ diffusion_model_path,
68
+ subfolder="unet",
69
+ use_auth_token=model_token,
70
+ torch_dtype=torch.float16,
71
+ local_files_only=local_path_only,
72
+ )
73
+ inj_forward_crossattention
74
+ mapper = Mapper(input_dim=1024, output_dim=768)
75
+
76
+ mapper_local = MapperLocal(input_dim=1024, output_dim=768)
77
+
78
+ for _name, _module in unet.named_modules():
79
+ if _module.__class__.__name__ == "CrossAttention":
80
+ if 'attn1' in _name: continue
81
+ _module.__class__.__call__ = inj_forward_crossattention
82
+
83
+ shape = _module.to_k.weight.shape
84
+ to_k_global = nn.Linear(shape[1], shape[0], bias=False)
85
+ mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)
86
+
87
+ shape = _module.to_v.weight.shape
88
+ to_v_global = nn.Linear(shape[1], shape[0], bias=False)
89
+ mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)
90
+
91
+ to_v_local = nn.Linear(shape[1], shape[0], bias=False)
92
+ mapper_local.add_module(f'{_name.replace(".", "_")}_to_v', to_v_local)
93
+
94
+ to_k_local = nn.Linear(shape[1], shape[0], bias=False)
95
+ mapper_local.add_module(f'{_name.replace(".", "_")}_to_k', to_k_local)
96
+
97
+ mapper.load_state_dict(torch.load(mapper_model_path, map_location='cpu'))
98
+ mapper.half()
99
+
100
+ mapper_local.load_state_dict(torch.load(mapper_local_model_path, map_location='cpu'))
101
+ mapper_local.half()
102
+
103
+ for _name, _module in unet.named_modules():
104
+ if 'attn1' in _name: continue
105
+ if _module.__class__.__name__ == "CrossAttention":
106
+ _module.add_module('to_k_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_k'))
107
+ _module.add_module('to_v_global', mapper.__getattr__(f'{_name.replace(".", "_")}_to_v'))
108
+ _module.add_module('to_v_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_v'))
109
+ _module.add_module('to_k_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_k'))
110
+
111
+ vae.to(device), unet.to(device), text_encoder.to(device), image_encoder.to(device), mapper.to(device), mapper_local.to(device)
112
+
113
+ scheduler = scheduler_type(
114
+ beta_start=0.00085,
115
+ beta_end=0.012,
116
+ beta_schedule="scaled_linear",
117
+ num_train_timesteps=1000,
118
+ )
119
+ vae.eval()
120
+ unet.eval()
121
+ image_encoder.eval()
122
+ text_encoder.eval()
123
+ mapper.eval()
124
+ mapper_local.eval()
125
+ return vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler
126
+
127
+
128
+
129
+ def parse_args():
130
+
131
+ import argparse
132
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
133
+
134
+ parser.add_argument(
135
+ "--global_mapper_path",
136
+ type=str,
137
+ required=True,
138
+ help="Path to pretrained global mapping network.",
139
+ )
140
+
141
+ parser.add_argument(
142
+ "--local_mapper_path",
143
+ type=str,
144
+ required=True,
145
+ help="Path to pretrained local mapping network.",
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--output_dir",
150
+ type=str,
151
+ default='outputs',
152
+ help="The output directory where the model predictions will be written.",
153
+ )
154
+
155
+ parser.add_argument(
156
+ "--placeholder_token",
157
+ type=str,
158
+ default="S",
159
+ help="A token to use as a placeholder for the concept.",
160
+ )
161
+
162
+ parser.add_argument(
163
+ "--template",
164
+ type=str,
165
+ default="a photo of a {}",
166
+ help="Text template for customized generation.",
167
+ )
168
+
169
+ parser.add_argument(
170
+ "--test_data_dir", type=str, default=None, required=True, help="A folder containing the testing data."
171
+ )
172
+
173
+ parser.add_argument(
174
+ "--pretrained_model_name_or_path",
175
+ type=str,
176
+ default=None,
177
+ required=True,
178
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--suffix",
183
+ type=str,
184
+ default="object",
185
+ help="Suffix of save directory.",
186
+ )
187
+
188
+ parser.add_argument(
189
+ "--selected_data",
190
+ type=int,
191
+ default=-1,
192
+ help="Data index. -1 for all.",
193
+ )
194
+
195
+ parser.add_argument(
196
+ "--llambda",
197
+ type=str,
198
+ default="0.8",
199
+ help="Lambda for fusing the global and local features.",
200
+ )
201
+
202
+ args = parser.parse_args()
203
+ return args
204
+
205
+
206
+ if __name__ == "__main__":
207
+ args = parse_args()
208
+
209
+ save_dir = os.path.join(args.output_dir, f'{args.suffix}_l{args.llambda.replace(".", "p")}')
210
+ os.makedirs(save_dir, exist_ok=True)
211
+
212
+ vae, unet, text_encoder, tokenizer, image_encoder, mapper, mapper_local, scheduler = pww_load_tools(
213
+ "cuda:0",
214
+ LMSDiscreteScheduler,
215
+ diffusion_model_path=args.pretrained_model_name_or_path,
216
+ mapper_model_path=args.global_mapper_path,
217
+ mapper_local_model_path=args.local_mapper_path,
218
+ )
219
+
220
+ train_dataset = CustomDatasetWithBG(
221
+ data_root=args.test_data_dir,
222
+ tokenizer=tokenizer,
223
+ size=512,
224
+ placeholder_token=args.placeholder_token,
225
+ template=args.template,
226
+ )
227
+
228
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=False)
229
+ for step, batch in enumerate(train_dataloader):
230
+ if args.selected_data > -1 and step != args.selected_data:
231
+ continue
232
+ batch["pixel_values"] = batch["pixel_values"].to("cuda:0")
233
+ batch["pixel_values_clip"] = batch["pixel_values_clip"].to("cuda:0").half()
234
+ batch["pixel_values_obj"] = batch["pixel_values_obj"].to("cuda:0").half()
235
+ batch["pixel_values_seg"] = batch["pixel_values_seg"].to("cuda:0").half()
236
+ batch["input_ids"] = batch["input_ids"].to("cuda:0")
237
+ batch["index"] = batch["index"].to("cuda:0").long()
238
+ print(step, batch['text'])
239
+ seeds = [0, 42, 10086, 777, 555, 222, 111, 999, 327, 283, 190, 218, 2371, 9329, 2938, 2073, 27367, 293,
240
+ 8269, 87367, 29379, 4658, 39, 598]
241
+ seeds = sorted(seeds)
242
+ for seed in seeds:
243
+ syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae,
244
+ batch["pixel_values_clip"].device, 5,
245
+ seed=seed, llambda=float(args.llambda))
246
+ concat = np.concatenate((np.array(syn_images[0]), th2image(batch["pixel_values"][0])), axis=1)
247
+ Image.fromarray(concat).save(os.path.join(save_dir, f'{str(step).zfill(5)}_{str(seed).zfill(5)}.jpg'))
inference_local.sh ADDED
@@ -0,0 +1,12 @@
1
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export DATA_DIR='./test_datasets/'
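+ # The two mapper checkpoints referenced below are expected under ./checkpoints/ before running inference.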
3
+ CUDA_VISIBLE_DEVICES=7 python inference_local.py \
4
+ --pretrained_model_name_or_path=$MODEL_NAME \
5
+ --test_data_dir=$DATA_DIR \
6
+ --output_dir="./outputs/local_mapping" \
7
+ --suffix="object" \
8
+ --template="a photo of a {}" \
9
+ --llambda="0.8" \
10
+ --global_mapper_path="./checkpoints/global_mapper.pt" \
11
+ --local_mapper_path="./checkpoints/local_mapper.pt"
12
+
test_datasets/1.jpg ADDED
test_datasets/10.jpg ADDED
test_datasets/10_bg.png ADDED
test_datasets/11.jpg ADDED
test_datasets/11_bg.png ADDED
test_datasets/15.jpg ADDED
test_datasets/15_bg.png ADDED
test_datasets/16.jpg ADDED
test_datasets/16_bg.png ADDED
test_datasets/17.jpg ADDED
test_datasets/17_bg.png ADDED
test_datasets/1_bg.png ADDED
test_datasets/2.jpg ADDED
test_datasets/20.jpg ADDED
test_datasets/20_bg.png ADDED
test_datasets/2_bg.png ADDED
test_datasets/3.jpg ADDED
test_datasets/3_bg.png ADDED
test_datasets/4.png ADDED
test_datasets/4_bg.png ADDED
test_datasets/7.jpg ADDED
test_datasets/7_bg.png ADDED
train_global.py ADDED
@@ -0,0 +1,715 @@
1
+ import argparse
2
+ import itertools
3
+ import math
4
+ import os
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint
12
+ from torch.utils.data import Dataset
13
+
14
+ import PIL
15
+ from accelerate import Accelerator
16
+ from accelerate.logging import get_logger
17
+ from accelerate.utils import set_seed
18
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, LMSDiscreteScheduler
19
+ from diffusers.optimization import get_scheduler
20
+ from huggingface_hub import HfFolder, Repository, whoami
21
+
22
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
23
+ from transformers.utils import (
24
+ add_start_docstrings_to_model_forward,
25
+ replace_return_docstrings,
26
+ )
27
+ from transformers.models.clip.configuration_clip import CLIPTextConfig
28
+ from transformers.models.clip.modeling_clip import CLIP_TEXT_INPUTS_DOCSTRING, _expand_mask
29
+
30
+ from PIL import Image
31
+ from tqdm.auto import tqdm
32
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
33
+
34
+ from typing import Any, Optional, Tuple, Union
35
+ from datasets import OpenImagesDataset
36
+
37
+
38
+
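+ # Global mapper: turns the CLS and patch features of five CLIP layers into five word embeddings (patch features are mean-pooled).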
39
+ class Mapper(nn.Module):
40
+ def __init__(self,
41
+ input_dim: int,
42
+ output_dim: int,
43
+ ):
44
+ super(Mapper, self).__init__()
45
+
46
+ for i in range(5):
47
+ setattr(self, f'mapping_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
48
+ nn.LayerNorm(1024),
49
+ nn.LeakyReLU(),
50
+ nn.Linear(1024, 1024),
51
+ nn.LayerNorm(1024),
52
+ nn.LeakyReLU(),
53
+ nn.Linear(1024, output_dim)))
54
+
55
+ setattr(self, f'mapping_patch_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
56
+ nn.LayerNorm(1024),
57
+ nn.LeakyReLU(),
58
+ nn.Linear(1024, 1024),
59
+ nn.LayerNorm(1024),
60
+ nn.LeakyReLU(),
61
+ nn.Linear(1024, output_dim)))
62
+
63
+ def forward(self, embs):
64
+ hidden_states = ()
65
+ for i, emb in enumerate(embs):
66
+ hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(emb[:, 1:]).mean(dim=1, keepdim=True)
67
+ hidden_states += (hidden_state, )
68
+ hidden_states = torch.cat(hidden_states, dim=1)
69
+ return hidden_states
70
+
71
+
72
+ def _build_causal_attention_mask(bsz, seq_len, dtype):
73
+ # lazily create causal attention mask, with full attention between the vision tokens
74
+ # pytorch uses additive attention mask; fill with -inf
75
+ mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
76
+ mask.fill_(torch.tensor(torch.finfo(dtype).min))
77
+ mask.triu_(1) # zero out the lower diagonal
78
+ mask = mask.unsqueeze(1) # expand mask
79
+ return mask
80
+
81
+
82
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
83
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
84
+ def inj_forward_text(
85
+ self,
86
+ input_ids: Optional[torch.Tensor] = None,
87
+ attention_mask: Optional[torch.Tensor] = None,
88
+ position_ids: Optional[torch.Tensor] = None,
89
+ output_attentions: Optional[bool] = None,
90
+ output_hidden_states: Optional[bool] = None,
91
+ return_dict: Optional[bool] = None,
92
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
93
+ r"""
94
+ Returns:
95
+ """
96
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
97
+ output_hidden_states = (
98
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
99
+ )
100
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
101
+
102
+ if input_ids is None:
103
+ raise ValueError("You have to specify either input_ids")
104
+
105
+ r_input_ids = input_ids['input_ids']
106
+ if 'inj_embedding' in input_ids:
107
+ inj_embedding = input_ids['inj_embedding']
108
+ inj_index = input_ids['inj_index']
109
+ else:
110
+ inj_embedding = None
111
+ inj_index = None
112
+
113
+ input_shape = r_input_ids.size()
114
+ r_input_ids = r_input_ids.view(-1, input_shape[-1])
115
+
116
+
117
+ inputs_embeds = self.embeddings.token_embedding(r_input_ids)
118
+ new_inputs_embeds = inputs_embeds.clone()
119
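+ # Insert the injected image embedding at the placeholder position, shifting the following token embeddings to make room.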
+ if inj_embedding is not None:
120
+ emb_length = inj_embedding.shape[1]
121
+ for bsz, idx in enumerate(inj_index):
122
+ lll = new_inputs_embeds[bsz, idx+emb_length:].shape[0]
123
+ new_inputs_embeds[bsz, idx+emb_length:] = inputs_embeds[bsz, idx+1:idx+1+lll]
124
+ new_inputs_embeds[bsz, idx:idx+emb_length] = inj_embedding[bsz]
125
+
126
+ hidden_states = self.embeddings(input_ids=r_input_ids, position_ids=position_ids, inputs_embeds=new_inputs_embeds)
127
+
128
+ bsz, seq_len = input_shape
129
+ # CLIP's text model uses causal mask, prepare it here.
130
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
131
+ causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
132
+ hidden_states.device
133
+ )
134
+ # expand attention_mask
135
+ if attention_mask is not None:
136
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
137
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
138
+
139
+ encoder_outputs = self.encoder(
140
+ inputs_embeds=hidden_states,
141
+ attention_mask=attention_mask,
142
+ causal_attention_mask=causal_attention_mask,
143
+ output_attentions=output_attentions,
144
+ output_hidden_states=output_hidden_states,
145
+ return_dict=return_dict,
146
+ )
147
+
148
+ last_hidden_state = encoder_outputs[0]
149
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
150
+
151
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
152
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
153
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
154
+ pooled_output = last_hidden_state[
155
+ torch.arange(last_hidden_state.shape[0], device=r_input_ids.device), r_input_ids.to(torch.int).argmax(dim=-1)
156
+ ]
157
+
158
+ if not return_dict:
159
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
160
+
161
+ return BaseModelOutputWithPooling(
162
+ last_hidden_state=last_hidden_state,
163
+ pooler_output=pooled_output,
164
+ hidden_states=encoder_outputs.hidden_states,
165
+ attentions=encoder_outputs.attentions,
166
+ )
167
+
168
+ def inj_forward_crossattention(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
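+ # For the injected text context, keys and values come from the trainable to_k_global/to_v_global projections instead of the frozen to_k/to_v.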
169
+ context = encoder_hidden_states
170
+ if context is not None:
171
+ context_tensor = context["CONTEXT_TENSOR"]
172
+ else:
173
+ context_tensor = hidden_states
174
+
175
+ batch_size, sequence_length, _ = hidden_states.shape
176
+
177
+ query = self.to_q(hidden_states)
178
+ if context is not None:
179
+ key = self.to_k_global(context_tensor)
180
+ value = self.to_v_global(context_tensor)
181
+ else:
182
+ key = self.to_k(context_tensor)
183
+ value = self.to_v(context_tensor)
184
+
185
+ dim = query.shape[-1]
186
+
187
+ query = self.reshape_heads_to_batch_dim(query)
188
+ key = self.reshape_heads_to_batch_dim(key)
189
+ value = self.reshape_heads_to_batch_dim(value)
190
+
191
+ attention_scores = torch.matmul(query, key.transpose(-1, -2))
192
+ attention_scores = attention_scores * self.scale
193
+
194
+ attention_probs = attention_scores.softmax(dim=-1)
195
+
196
+ hidden_states = torch.matmul(attention_probs, value)
197
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
198
+
199
+ # linear proj
200
+ hidden_states = self.to_out[0](hidden_states)
201
+ # dropout
202
+ hidden_states = self.to_out[1](hidden_states)
203
+
204
+ return hidden_states
205
+
206
+
207
+
208
+ logger = get_logger(__name__)
209
+
210
+
211
+ def save_progress(mapper, accelerator, args, step=None):
212
+ logger.info("Saving embeddings")
213
+
214
+ state_dict = accelerator.unwrap_model(mapper).state_dict()
215
+
216
+ if step is not None:
217
+ torch.save(state_dict, os.path.join(args.output_dir, f"mapper_{str(step).zfill(6)}.pt"))
218
+ else:
219
+ torch.save(state_dict, os.path.join(args.output_dir, "mapper.pt"))
220
+
221
+
222
+ def parse_args():
223
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
224
+ parser.add_argument(
225
+ "--save_steps",
226
+ type=int,
227
+ default=500,
228
+ help="Save learned_embeds.bin every X updates steps.",
229
+ )
230
+ parser.add_argument(
231
+ "--pretrained_model_name_or_path",
232
+ type=str,
233
+ default=None,
234
+ required=True,
235
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
236
+ )
237
+ parser.add_argument(
238
+ "--tokenizer_name",
239
+ type=str,
240
+ default=None,
241
+ help="Pretrained tokenizer name or path if not the same as model_name",
242
+ )
243
+ parser.add_argument(
244
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
245
+ )
246
+ parser.add_argument(
247
+ "--global_mapper_path", type=str, default=None, help="If not none, the training will start from the given checkpoints."
248
+ )
249
+ parser.add_argument(
250
+ "--placeholder_token",
251
+ type=str,
252
+ default=None,
253
+ required=True,
254
+ help="A token to use as a placeholder for the concept.",
255
+ )
256
+ parser.add_argument(
257
+ "--output_dir",
258
+ type=str,
259
+ default="text-inversion-model",
260
+ help="The output directory where the model predictions and checkpoints will be written.",
261
+ )
262
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
263
+ parser.add_argument(
264
+ "--resolution",
265
+ type=int,
266
+ default=512,
267
+ help=(
268
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
269
+ " resolution"
270
+ ),
271
+ )
272
+ parser.add_argument(
273
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
274
+ )
275
+ parser.add_argument("--num_train_epochs", type=int, default=100)
276
+ parser.add_argument(
277
+ "--max_train_steps",
278
+ type=int,
279
+ default=5000,
280
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
281
+ )
282
+ parser.add_argument(
283
+ "--gradient_accumulation_steps",
284
+ type=int,
285
+ default=1,
286
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
287
+ )
288
+ parser.add_argument(
289
+ "--learning_rate",
290
+ type=float,
291
+ default=1e-4,
292
+ help="Initial learning rate (after the potential warmup period) to use.",
293
+ )
294
+ parser.add_argument(
295
+ "--scale_lr",
296
+ action="store_true",
297
+ default=True,
298
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
299
+ )
300
+ parser.add_argument(
301
+ "--lr_scheduler",
302
+ type=str,
303
+ default="constant",
304
+ help=(
305
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
306
+ ' "constant", "constant_with_warmup"]'
307
+ ),
308
+ )
309
+ parser.add_argument(
310
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
311
+ )
312
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
313
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
314
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
315
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
316
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
317
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
318
+ parser.add_argument(
319
+ "--hub_model_id",
320
+ type=str,
321
+ default=None,
322
+ help="The name of the repository to keep in sync with the local `output_dir`.",
323
+ )
324
+ parser.add_argument(
325
+ "--logging_dir",
326
+ type=str,
327
+ default="logs",
328
+ help=(
329
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
330
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
331
+ ),
332
+ )
333
+ parser.add_argument(
334
+ "--mixed_precision",
335
+ type=str,
336
+ default="no",
337
+ choices=["no", "fp16", "bf16"],
338
+ help=(
339
+ "Whether to use mixed precision. Choose"
340
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
341
+ "and an Nvidia Ampere GPU."
342
+ ),
343
+ )
344
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
345
+
346
+ args = parser.parse_args()
347
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
348
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
349
+ args.local_rank = env_local_rank
350
+
351
+ if args.train_data_dir is None:
352
+ raise ValueError("You must specify a train data directory.")
353
+
354
+ return args
355
+
356
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
357
+ if token is None:
358
+ token = HfFolder.get_token()
359
+ if organization is None:
360
+ username = whoami(token)["name"]
361
+ return f"{username}/{model_id}"
362
+ else:
363
+ return f"{organization}/{model_id}"
364
+
365
+
366
+ def freeze_params(params):
367
+ for param in params:
368
+ param.requires_grad = False
369
+
370
+ def unfreeze_params(params):
371
+ for param in params:
372
+ param.requires_grad = True
373
+
374
+ def th2image(image):
375
+ image = (image / 2 + 0.5).clamp(0, 1)
376
+ image = image.detach().cpu().permute(1, 2, 0).numpy()
377
+ image = (image * 255).round().astype("uint8")
378
+ return Image.fromarray(image)
379
+
380
+
381
+ @torch.no_grad()
382
+ def validation(example, tokenizer, image_encoder, text_encoder, unet, mapper, vae, device, guidance_scale, token_index='full', seed=None):
383
+ scheduler = LMSDiscreteScheduler(
384
+ beta_start=0.00085,
385
+ beta_end=0.012,
386
+ beta_schedule="scaled_linear",
387
+ num_train_timesteps=1000,
388
+ )
389
+
390
+ uncond_input = tokenizer(
391
+ [''] * example["pixel_values"].shape[0],
392
+ padding="max_length",
393
+ max_length=tokenizer.model_max_length,
394
+ return_tensors="pt",
395
+ )
396
+ uncond_embeddings = text_encoder({'input_ids':uncond_input.input_ids.to(device)})[0]
397
+
398
+ if seed is None:
399
+ latents = torch.randn(
400
+ (example["pixel_values"].shape[0], unet.in_channels, 64, 64)
401
+ )
402
+ else:
403
+ generator = torch.manual_seed(seed)
404
+ latents = torch.randn(
405
+ (example["pixel_values"].shape[0], unet.in_channels, 64, 64), generator=generator,
406
+ )
407
+
408
+ latents = latents.to(example["pixel_values_clip"])
409
+ scheduler.set_timesteps(100)
410
+ latents = latents * scheduler.init_noise_sigma
411
+
412
+ placeholder_idx = example["index"]
413
+ image = F.interpolate(example["pixel_values_clip"], (224, 224), mode='bilinear')
414
+
415
+ image_features = image_encoder(image, output_hidden_states=True)
416
+ image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12],
417
+ image_features[2][16]]
418
+ image_embeddings = [emb.detach() for emb in image_embeddings]
419
+ inj_embedding = mapper(image_embeddings)
420
+
421
+ if token_index != 'full':
422
+ token_index = int(token_index)
423
+ inj_embedding = inj_embedding[:, token_index:token_index + 1, :]
424
+
425
+ encoder_hidden_states = text_encoder({'input_ids': example["input_ids"],
426
+ "inj_embedding": inj_embedding,
427
+ "inj_index": placeholder_idx})[0]
428
+
429
+ for t in tqdm(scheduler.timesteps):
430
+ latent_model_input = scheduler.scale_model_input(latents, t)
431
+ noise_pred_text = unet(
432
+ latent_model_input,
433
+ t,
434
+ encoder_hidden_states={
435
+ "CONTEXT_TENSOR": encoder_hidden_states,
436
+ }
437
+ ).sample
438
+
439
+ latent_model_input = scheduler.scale_model_input(latents, t)
440
+
441
+ noise_pred_uncond = unet(
442
+ latent_model_input,
443
+ t,
444
+ encoder_hidden_states={
445
+ "CONTEXT_TENSOR": uncond_embeddings,
446
+ }
447
+ ).sample
448
+
449
+ noise_pred = noise_pred_uncond + guidance_scale * (
450
+ noise_pred_text - noise_pred_uncond
451
+ )
452
+
453
+ # compute the previous noisy sample x_t -> x_t-1
454
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
455
+
456
+ _latents = 1 / 0.18215 * latents.clone()
457
+ images = vae.decode(_latents).sample
458
+ ret_pil_images = [th2image(image) for image in images]
459
+
460
+ return ret_pil_images
461
+
462
+ def main():
463
+ args = parse_args()
464
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
465
+
466
+ accelerator = Accelerator(
467
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
468
+ mixed_precision=args.mixed_precision,
469
+ log_with="tensorboard",
470
+ logging_dir=logging_dir,
471
+ )
472
+
473
+ # If passed along, set the training seed now.
474
+ if args.seed is not None:
475
+ set_seed(args.seed)
476
+
477
+ # Handle the repository creation
478
+ if accelerator.is_main_process:
479
+ if args.push_to_hub:
480
+ if args.hub_model_id is None:
481
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
482
+ else:
483
+ repo_name = args.hub_model_id
484
+ repo = Repository(args.output_dir, clone_from=repo_name)
485
+
486
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
487
+ if "step_*" not in gitignore:
488
+ gitignore.write("step_*\n")
489
+ if "epoch_*" not in gitignore:
490
+ gitignore.write("epoch_*\n")
491
+ elif args.output_dir is not None:
492
+ os.makedirs(args.output_dir, exist_ok=True)
493
+
494
+ # Load the tokenizer and add the placeholder token as an additional special token
495
+ if args.tokenizer_name:
496
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
497
+ elif args.pretrained_model_name_or_path:
498
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
499
+
500
+ # Load models and create wrapper for stable diffusion
501
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
502
+
503
+ # replace the forward method of the text encoder to inject the word embedding
504
+ for _module in text_encoder.modules():
505
+ if _module.__class__.__name__ == "CLIPTextTransformer":
506
+ _module.__class__.__call__ = inj_forward_text
507
+
508
+ image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
509
+
510
+ mapper = Mapper(input_dim=1024, output_dim=768)
511
+
512
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
513
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
514
+
515
+ # replace the forward method of the crossattention to finetune the to_k and to_v layers
516
+ for _name, _module in unet.named_modules():
517
+ if _module.__class__.__name__ == "CrossAttention":
518
+ if 'attn1' in _name: continue
519
+ _module.__class__.__call__ = inj_forward_crossattention
520
+
521
+ shape = _module.to_k.weight.shape
522
+ to_k_global = nn.Linear(shape[1], shape[0], bias=False)
523
+ to_k_global.weight.data = _module.to_k.weight.data.clone()
524
+ mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)
525
+
526
+ shape = _module.to_v.weight.shape
527
+ to_v_global = nn.Linear(shape[1], shape[0], bias=False)
528
+ to_v_global.weight.data = _module.to_v.weight.data.clone()
529
+ mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)
530
+
531
+ if args.global_mapper_path is None:
532
+ _module.add_module('to_k_global', to_k_global)
533
+ _module.add_module('to_v_global', to_v_global)
534
+
535
+ if args.global_mapper_path is not None:
536
+ mapper.load_state_dict(torch.load(args.global_mapper_path, map_location='cpu'))
537
+ for _name, _module in unet.named_modules():
538
+ if _module.__class__.__name__ == "CrossAttention":
539
+ if 'attn1' in _name: continue
540
+ _module.add_module('to_k_global', getattr(mapper, f'{_name.replace(".", "_")}_to_k'))
541
+ _module.add_module('to_v_global', getattr(mapper, f'{_name.replace(".", "_")}_to_v'))
542
+
543
+ # Freeze vae, unet, and the text and image encoders
544
+ freeze_params(vae.parameters())
545
+ freeze_params(unet.parameters())
546
+ freeze_params(text_encoder.parameters())
547
+ freeze_params(image_encoder.parameters())
548
+
549
+ # Unfreeze the mapper
550
+ unfreeze_params(mapper.parameters())
551
+
552
+ if args.scale_lr:
553
+ args.learning_rate = (
554
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
555
+ )
556
+
557
+ # Initialize the optimizer
558
+ optimizer = torch.optim.AdamW(
559
+ itertools.chain(mapper.parameters()), # only optimize the mapper parameters
560
+ lr=args.learning_rate,
561
+ betas=(args.adam_beta1, args.adam_beta2),
562
+ weight_decay=args.adam_weight_decay,
563
+ eps=args.adam_epsilon,
564
+ )
565
+
566
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
567
+
568
+ train_dataset = OpenImagesDataset(
569
+ data_root=args.train_data_dir,
570
+ tokenizer=tokenizer,
571
+ size=args.resolution,
572
+ placeholder_token=args.placeholder_token,
573
+ set="test",
574
+ )
575
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
576
+
577
+ # Scheduler and math around the number of training steps.
578
+ overrode_max_train_steps = False
579
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
580
+ if args.max_train_steps is None:
581
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
582
+ overrode_max_train_steps = True
583
+
584
+ lr_scheduler = get_scheduler(
585
+ args.lr_scheduler,
586
+ optimizer=optimizer,
587
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
588
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
589
+ )
590
+
591
+ mapper, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
592
+ mapper, optimizer, train_dataloader, lr_scheduler
593
+ )
594
+
595
+ # Move vae, unet, and encoders to device
596
+ vae.to(accelerator.device)
597
+ unet.to(accelerator.device)
598
+ image_encoder.to(accelerator.device)
599
+ text_encoder.to(accelerator.device)
600
+ # Keep vae, unet and image_encoder in eval mode as we don't train these
601
+ vae.eval()
602
+ unet.eval()
603
+ image_encoder.eval()
604
+
605
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
606
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
607
+ if overrode_max_train_steps:
608
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
609
+ # Afterwards we recalculate our number of training epochs
610
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
611
+
612
+ # We need to initialize the trackers we use, and also store our configuration.
613
+ # The trackers initialize automatically on the main process.
614
+ if accelerator.is_main_process:
615
+ accelerator.init_trackers("elite", config=vars(args))
616
+
617
+ # Train!
618
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
619
+
620
+ logger.info("***** Running training *****")
621
+ logger.info(f" Num examples = {len(train_dataset)}")
622
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
623
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
624
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
625
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
626
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
627
+ # Only show the progress bar once on each machine.
628
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
629
+ progress_bar.set_description("Steps")
630
+ global_step = 0
631
+
632
+ for epoch in range(args.num_train_epochs):
633
+ mapper.train()
634
+ for step, batch in enumerate(train_dataloader):
635
+ with accelerator.accumulate(mapper):
636
+ # Convert images to latent space
637
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
638
+ latents = latents * 0.18215
639
+
640
+ # Sample noise that we'll add to the latents
641
+ noise = torch.randn(latents.shape).to(latents.device)
642
+ bsz = latents.shape[0]
643
+ # Sample a random timestep for each image
644
+ timesteps = torch.randint(
645
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
646
+ ).long()
647
+
648
+ # Add noise to the latents according to the noise magnitude at each timestep
649
+ # (this is the forward diffusion process)
650
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
651
+
652
+ placeholder_idx = batch["index"]
653
+ image = F.interpolate(batch["pixel_values_clip"], (224, 224), mode='bilinear')
654
+
655
+ image_features = image_encoder(image, output_hidden_states=True)
656
+ image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16]]
657
+ image_embeddings = [emb.detach() for emb in image_embeddings]
658
+ inj_embedding = mapper(image_embeddings)
659
+
660
+ # Get the text embedding for conditioning
661
+ encoder_hidden_states = text_encoder({'input_ids': batch["input_ids"],
662
+ "inj_embedding": inj_embedding,
663
+ "inj_index": placeholder_idx.detach()})[0]
664
+
665
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states={
666
+ "CONTEXT_TENSOR": encoder_hidden_states,
667
+ }).sample
668
+
669
+ loss_mle = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
670
+
671
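+ # Small L1 penalty on the injected embeddings to keep their magnitude bounded.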
+ loss_reg = torch.mean(torch.abs(inj_embedding)) * 0.01
672
+
673
+ loss = loss_mle + loss_reg
674
+
675
+ accelerator.backward(loss)
676
+
677
+ if accelerator.sync_gradients:
678
+ accelerator.clip_grad_norm_(mapper.parameters(), 1)
679
+
680
+ optimizer.step()
681
+ lr_scheduler.step()
682
+ optimizer.zero_grad()
683
+
684
+
685
+ # Checks if the accelerator has performed an optimization step behind the scenes
686
+ if accelerator.sync_gradients:
687
+ progress_bar.update(1)
688
+ global_step += 1
689
+ if global_step % args.save_steps == 0:
690
+ save_progress(mapper, accelerator, args, global_step)
691
+ syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, vae, batch["pixel_values_clip"].device, 5)
692
+ gt_images = [th2image(img) for img in batch["pixel_values"]]
693
+ img_list = []
694
+ for syn, gt in zip(syn_images, gt_images):
695
+ img_list.append(np.concatenate((np.array(syn), np.array(gt)), axis=1))
696
+ img_list = np.concatenate(img_list, axis=0)
697
+ Image.fromarray(img_list).save(os.path.join(args.output_dir, f"{str(global_step).zfill(5)}.jpg"))
698
+
699
+ logs = {"loss_mle": loss_mle.detach().item(), "loss_reg": loss_reg.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
700
+ progress_bar.set_postfix(**logs)
701
+ accelerator.log(logs, step=global_step)
702
+
703
+ if global_step >= args.max_train_steps:
704
+ break
705
+
706
+ accelerator.wait_for_everyone()
707
+
708
+ if accelerator.is_main_process:
709
+ save_progress(mapper, accelerator, args)
710
+
711
+ accelerator.end_training()
712
+
713
+
714
+ if __name__ == "__main__":
715
+ main()
train_global.sh ADDED
@@ -0,0 +1,15 @@
1
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
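+ # 4 GPUs x per-device batch 4 x gradient accumulation 4 = effective batch size 64.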
3
+ CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --config_file 4_gpu.json --main_process_port 25656 train_global.py \
4
+ --pretrained_model_name_or_path=$MODEL_NAME \
5
+ --train_data_dir=$DATA_DIR \
6
+ --placeholder_token="S" \
7
+ --resolution=512 \
8
+ --train_batch_size=4 \
9
+ --gradient_accumulation_steps=4 \
10
+ --max_train_steps=200000 \
11
+ --learning_rate=1e-06 --scale_lr \
12
+ --lr_scheduler="constant" \
13
+ --lr_warmup_steps=0 \
14
+ --output_dir="./elite_experiments/global_mapping" \
15
+ --save_steps 200
train_local.py ADDED
@@ -0,0 +1,709 @@
1
+
2
+ import argparse
3
+ import itertools
4
+ import math
5
+ import os
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils.checkpoint
13
+ from torch.utils.data import Dataset
14
+
15
+ import PIL
16
+ from accelerate import Accelerator
17
+ from accelerate.logging import get_logger
18
+ from accelerate.utils import set_seed
19
+ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
20
+ from diffusers.optimization import get_scheduler
21
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
22
+ from huggingface_hub import HfFolder, Repository, whoami
23
+
24
+ # TODO: remove and import from diffusers.utils when the new version of diffusers is released
25
+ from PIL import Image
26
+ from tqdm.auto import tqdm
27
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
28
+
29
+
30
+ from typing import Optional
31
+ from train_global import inj_forward_text, th2image, Mapper
32
+ from datasets import OpenImagesDatasetWithMask
33
+
34
+
35
+ class MapperLocal(nn.Module):
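+ # Local mapper: like Mapper, but keeps one embedding per CLIP patch (adding the mapped CLS feature to each) and averages the five feature levels.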
36
+ def __init__(self,
37
+ input_dim: int,
38
+ output_dim: int,
39
+ ):
40
+ super(MapperLocal, self).__init__()
41
+
42
+ for i in range(5):
43
+ setattr(self, f'mapping_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
44
+ nn.LayerNorm(1024),
45
+ nn.LeakyReLU(),
46
+ nn.Linear(1024, 1024),
47
+ nn.LayerNorm(1024),
48
+ nn.LeakyReLU(),
49
+ nn.Linear(1024, output_dim)))
50
+
51
+ setattr(self, f'mapping_patch_{i}', nn.Sequential(nn.Linear(input_dim, 1024),
52
+ nn.LayerNorm(1024),
53
+ nn.LeakyReLU(),
54
+ nn.Linear(1024, 1024),
55
+ nn.LayerNorm(1024),
56
+ nn.LeakyReLU(),
57
+ nn.Linear(1024, output_dim)))
58
+
59
+ def forward(self, embs):
60
+ hidden_states = ()
61
+ for i, emb in enumerate(embs):
62
+ hidden_state = getattr(self, f'mapping_{i}')(emb[:, :1]) + getattr(self, f'mapping_patch_{i}')(emb[:, 1:])
63
+ hidden_states += (hidden_state.unsqueeze(0),)
64
+ hidden_states = torch.cat(hidden_states, dim=0).mean(dim=0)
65
+ return hidden_states
66
+
67
+ value_local_list = []
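+ # Collects the local value projections of each cross-attention layer; used for the regularization loss in train_local and cleared after every step.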
68
+
69
+ def inj_forward_crossattention(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
70
+
71
+ context = encoder_hidden_states
72
+ hidden_states_local = hidden_states.clone()
73
+ if context is not None:
74
+ context_tensor = context["CONTEXT_TENSOR"]
75
+ else:
76
+ context_tensor = hidden_states
77
+
78
+ batch_size, sequence_length, _ = hidden_states.shape
79
+
80
+ query = self.to_q(hidden_states)
81
+
82
+ if context is not None:
83
+ key = self.to_k_global(context_tensor)
84
+ value = self.to_v_global(context_tensor)
85
+ else:
86
+ key = self.to_k(context_tensor)
87
+ value = self.to_v(context_tensor)
88
+
89
+ dim = query.shape[-1]
90
+
91
+ query = self.reshape_heads_to_batch_dim(query)
92
+ key = self.reshape_heads_to_batch_dim(key)
93
+ value = self.reshape_heads_to_batch_dim(value)
94
+
95
+
96
+ attention_scores = torch.matmul(query, key.transpose(-1, -2))
97
+ attention_scores = attention_scores * self.scale
98
+
99
+ attention_probs = attention_scores.softmax(dim=-1)
100
+
101
+ hidden_states = torch.matmul(attention_probs, value)
102
+
103
+ if context is not None and "LOCAL" in context:
104
+ # Perform cross attention with the local context
105
+ query_local = self.to_q(hidden_states_local)
106
+ key_local = self.to_k_local(context["LOCAL"])
107
+ value_local = self.to_v_local(context["LOCAL"])
108
+
109
+ query_local = self.reshape_heads_to_batch_dim(query_local)
110
+ key_local = self.reshape_heads_to_batch_dim(key_local)
111
+ value_local = self.reshape_heads_to_batch_dim(value_local)
112
+
113
+ attention_scores_local = torch.matmul(query_local, key_local.transpose(-1, -2))
114
+ attention_scores_local = attention_scores_local * self.scale
115
+ attention_probs_local = attention_scores_local.softmax(dim=-1)
116
+
117
+ # Extract the attention map of the learned placeholder token [w]
118
+ index_local = context["LOCAL_INDEX"]
119
+ index_local = index_local.reshape(index_local.shape[0], 1).repeat((1, self.heads)).reshape(-1)
120
+ attention_probs_clone = attention_probs.clone().permute((0, 2, 1))
121
+ attention_probs_mask = attention_probs_clone[torch.arange(index_local.shape[0]), index_local]
122
+ # Normalize the attention map
123
+ attention_probs_mask = attention_probs_mask.unsqueeze(2) / attention_probs_mask.max()
124
+
125
+ if "LAMBDA" in context:
126
+ _lambda = context["LAMBDA"]
127
+ else:
128
+ _lambda = 1
129
+
130
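+ # Gate the local attention with the placeholder token's normalized global attention map, scaled by lambda.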
+ attention_probs_local = attention_probs_local * attention_probs_mask * _lambda
131
+ hidden_states += torch.matmul(attention_probs_local, value_local)
132
+ value_local_list.append(value_local)
133
+
134
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
135
+
136
+ # linear proj
137
+ hidden_states = self.to_out[0](hidden_states)
138
+ # dropout
139
+ hidden_states = self.to_out[1](hidden_states)
140
+
141
+ return hidden_states
142
+
143
+ # ------------------------------------------------------------------------------
144
+
145
+ logger = get_logger(__name__)
146
+
147
+
148
+ def save_progress(mapper, accelerator, args, step=None):
149
+ logger.info("Saving embeddings")
150
+
151
+ state_dict = accelerator.unwrap_model(mapper).state_dict()
152
+
153
+ if step is not None:
154
+ torch.save(state_dict, os.path.join(args.output_dir, f"local_mapper_{str(step).zfill(6)}.pt"))
155
+ else:
156
+ torch.save(state_dict, os.path.join(args.output_dir, "local_mapper.pt"))
157
+
158
+
159
+ def parse_args():
160
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
161
+ parser.add_argument(
162
+ "--save_steps",
163
+ type=int,
164
+ default=500,
165
+ help="Save learned_embeds.bin every X updates steps.",
166
+ )
167
+ parser.add_argument(
168
+ "--pretrained_model_name_or_path",
169
+ type=str,
170
+ default=None,
171
+ required=True,
172
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
173
+ )
174
+ parser.add_argument(
175
+ "--tokenizer_name",
176
+ type=str,
177
+ default=None,
178
+ help="Pretrained tokenizer name or path if not the same as model_name",
179
+ )
180
+ parser.add_argument(
181
+ "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
182
+ )
183
+ parser.add_argument(
184
+ "--global_mapper_path", type=str, default=None,
185
+ help="If not none, the training will start from the given checkpoints."
186
+ )
187
+ parser.add_argument(
188
+ "--local_mapper_path", type=str, default=None,
189
+ help="If not none, the training will start from the given checkpoints."
190
+ )
191
+ parser.add_argument(
192
+ "--placeholder_token",
193
+ type=str,
194
+ default=None,
195
+ required=True,
196
+ help="A token to use as a placeholder for the concept.",
197
+ )
198
+ parser.add_argument(
199
+ "--output_dir",
200
+ type=str,
201
+ default="text-inversion-model",
202
+ help="The output directory where the model predictions and checkpoints will be written.",
203
+ )
204
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
205
+ parser.add_argument(
206
+ "--resolution",
207
+ type=int,
208
+ default=512,
209
+ help=(
210
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
211
+ " resolution"
212
+ ),
213
+ )
214
+ parser.add_argument(
215
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
216
+ )
217
+ parser.add_argument("--num_train_epochs", type=int, default=100)
218
+ parser.add_argument(
219
+ "--max_train_steps",
220
+ type=int,
221
+ default=5000,
222
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
223
+ )
224
+ parser.add_argument(
225
+ "--gradient_accumulation_steps",
226
+ type=int,
227
+ default=1,
228
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
229
+ )
230
+ parser.add_argument(
231
+ "--learning_rate",
232
+ type=float,
233
+ default=1e-4,
234
+ help="Initial learning rate (after the potential warmup period) to use.",
235
+ )
236
+ parser.add_argument(
237
+ "--scale_lr",
238
+ action="store_true",
239
+ default=True,
240
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
241
+ )
242
+ parser.add_argument(
243
+ "--lr_scheduler",
244
+ type=str,
245
+ default="constant",
246
+ help=(
247
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
248
+ ' "constant", "constant_with_warmup"]'
249
+ ),
250
+ )
251
+ parser.add_argument(
252
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
253
+ )
254
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
255
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
256
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
257
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
258
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
259
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
260
+ parser.add_argument(
261
+ "--hub_model_id",
262
+ type=str,
263
+ default=None,
264
+ help="The name of the repository to keep in sync with the local `output_dir`.",
265
+ )
266
+ parser.add_argument(
267
+ "--logging_dir",
268
+ type=str,
269
+ default="logs",
270
+ help=(
271
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
272
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
273
+ ),
274
+ )
275
+ parser.add_argument(
276
+ "--mixed_precision",
277
+ type=str,
278
+ default="no",
279
+ choices=["no", "fp16", "bf16"],
280
+ help=(
281
+ "Whether to use mixed precision. Choose"
282
+ "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
283
+ "and an Nvidia Ampere GPU."
284
+ ),
285
+ )
286
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
287
+
288
+ args = parser.parse_args()
289
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
290
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
291
+ args.local_rank = env_local_rank
292
+
293
+ if args.train_data_dir is None:
294
+ raise ValueError("You must specify a train data directory.")
295
+
296
+ return args
297
+
298
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
299
+ if token is None:
300
+ token = HfFolder.get_token()
301
+ if organization is None:
302
+ username = whoami(token)["name"]
303
+ return f"{username}/{model_id}"
304
+ else:
305
+ return f"{organization}/{model_id}"
306
+
307
+
308
+ def freeze_params(params):
309
+ for param in params:
310
+ param.requires_grad = False
311
+
312
+ def unfreeze_params(params):
313
+ for param in params:
314
+ param.requires_grad = True
315
+
316
+
317
+ @torch.no_grad()
318
+ def validation(example, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae, device, guidance_scale, seed=None, llambda=1):
319
+ scheduler = LMSDiscreteScheduler(
320
+ beta_start=0.00085,
321
+ beta_end=0.012,
322
+ beta_schedule="scaled_linear",
323
+ num_train_timesteps=1000,
324
+ )
325
+
326
+ uncond_input = tokenizer(
327
+ [''] * example["pixel_values"].shape[0],
328
+ padding="max_length",
329
+ max_length=tokenizer.model_max_length,
330
+ return_tensors="pt",
331
+ )
332
+ uncond_embeddings = text_encoder({'input_ids':uncond_input.input_ids.to(device)})[0]
333
+
334
+ if seed is None:
335
+ latents = torch.randn(
336
+ (example["pixel_values"].shape[0], unet.in_channels, 64, 64)
337
+ )
338
+ else:
339
+ generator = torch.manual_seed(seed)
340
+ latents = torch.randn(
341
+ (example["pixel_values"].shape[0], unet.in_channels, 64, 64), generator=generator,
342
+ )
343
+
344
+ latents = latents.to(example["pixel_values_clip"])
345
+ scheduler.set_timesteps(100)
346
+ latents = latents * scheduler.init_noise_sigma
347
+
348
+ placeholder_idx = example["index"]
349
+
350
+ image = F.interpolate(example["pixel_values_clip"], (224, 224), mode='bilinear')
351
+ image_features = image_encoder(image, output_hidden_states=True)
352
+ image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16]]
353
+ image_embeddings = [emb.detach() for emb in image_embeddings]
354
+ inj_embedding = mapper(image_embeddings)
355
+
356
+ inj_embedding = inj_embedding[:, 0:1, :]
357
+ encoder_hidden_states = text_encoder({'input_ids': example["input_ids"],
358
+ "inj_embedding": inj_embedding,
359
+ "inj_index": placeholder_idx})[0]
360
+
361
+ image_obj = F.interpolate(example["pixel_values_obj"], (224, 224), mode='bilinear')
362
+ image_features_obj = image_encoder(image_obj, output_hidden_states=True)
363
+ image_embeddings_obj = [image_features_obj[0], image_features_obj[2][4], image_features_obj[2][8],
364
+ image_features_obj[2][12], image_features_obj[2][16]]
365
+ image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj]
366
+
367
+ inj_embedding_local = mapper_local(image_embeddings_obj)
368
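+ # Downsample the segmentation mask to the 16x16 CLIP patch grid so it aligns with the local embeddings.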
+ mask = F.interpolate(example["pixel_values_seg"], (16, 16), mode='nearest')
369
+ mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
370
+ inj_embedding_local = inj_embedding_local * mask
371
+
372
+
373
+ for t in tqdm(scheduler.timesteps):
374
+ latent_model_input = scheduler.scale_model_input(latents, t)
375
+ noise_pred_text = unet(
376
+ latent_model_input,
377
+ t,
378
+ encoder_hidden_states={
379
+ "CONTEXT_TENSOR": encoder_hidden_states,
380
+ "LOCAL": inj_embedding_local,
381
+ "LOCAL_INDEX": placeholder_idx.detach(),
382
+ "LAMBDA": llambda
383
+ }
384
+ ).sample
385
+ value_local_list.clear()
386
+ latent_model_input = scheduler.scale_model_input(latents, t)
387
+
388
+ noise_pred_uncond = unet(
389
+ latent_model_input,
390
+ t,
391
+ encoder_hidden_states={
392
+ "CONTEXT_TENSOR": uncond_embeddings,
393
+ }
394
+ ).sample
395
+ value_local_list.clear()
396
+ noise_pred = noise_pred_uncond + guidance_scale * (
397
+ noise_pred_text - noise_pred_uncond
398
+ )
399
+
400
+ # compute the previous noisy sample x_t -> x_t-1
401
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
402
+
403
+ _latents = 1 / 0.18215 * latents.clone()
404
+ images = vae.decode(_latents).sample
405
+ ret_pil_images = [th2image(image) for image in images]
406
+
407
+ return ret_pil_images
408
+
409
+ def main():
410
+ args = parse_args()
411
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
412
+
413
+ accelerator = Accelerator(
414
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
415
+ mixed_precision=args.mixed_precision,
416
+ log_with="tensorboard",
417
+ logging_dir=logging_dir,
418
+ )
419
+
420
+ # If passed along, set the training seed now.
421
+ if args.seed is not None:
422
+ set_seed(args.seed)
423
+
424
+ # Handle the repository creation
425
+ if accelerator.is_main_process:
426
+ if args.push_to_hub:
427
+ if args.hub_model_id is None:
428
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
429
+ else:
430
+ repo_name = args.hub_model_id
431
+ repo = Repository(args.output_dir, clone_from=repo_name)
432
+
433
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
434
+ if "step_*" not in gitignore:
435
+ gitignore.write("step_*\n")
436
+ if "epoch_*" not in gitignore:
437
+ gitignore.write("epoch_*\n")
438
+ elif args.output_dir is not None:
439
+ os.makedirs(args.output_dir, exist_ok=True)
440
+
441
+ # Load the tokenizer and add the placeholder token as an additional special token
442
+
443
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
444
+ # Load models and create wrapper for stable diffusion
445
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
446
+
447
+ for _module in text_encoder.modules():
448
+ if _module.__class__.__name__ == "CLIPTextTransformer":
449
+ _module.__class__.__call__ = inj_forward_text
450
+
451
+ image_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
452
+
453
+ mapper = Mapper(input_dim=1024, output_dim=768)
454
+ mapper_local = MapperLocal(input_dim=1024, output_dim=768)
455
+
456
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
457
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
458
+
459
+ # replace the forward method of the crossattention to finetune the to_k and to_v layers
460
+ for _name, _module in unet.named_modules():
461
+ if _module.__class__.__name__ == "CrossAttention":
462
+ if 'attn1' in _name: continue
463
+ _module.__class__.__call__ = inj_forward_crossattention
464
+
465
+ shape = _module.to_k.weight.shape
466
+ to_k_global = nn.Linear(shape[1], shape[0], bias=False)
467
+ to_k_global.weight.data = _module.to_k.weight.data.clone()
468
+ mapper.add_module(f'{_name.replace(".", "_")}_to_k', to_k_global)
469
+
470
+ shape = _module.to_v.weight.shape
471
+ to_v_global = nn.Linear(shape[1], shape[0], bias=False)
472
+ to_v_global.weight.data = _module.to_v.weight.data.clone()
473
+ mapper.add_module(f'{_name.replace(".", "_")}_to_v', to_v_global)
474
+
475
+ to_k_local = nn.Linear(shape[1], shape[0], bias=False)
476
+ to_k_local.weight.data = _module.to_k.weight.data.clone()
477
+ mapper_local.add_module(f'{_name.replace(".", "_")}_to_k', to_k_local)
478
+ _module.add_module('to_k_local', to_k_local)
479
+
480
+ to_v_local = nn.Linear(shape[1], shape[0], bias=False)
481
+ to_v_local.weight.data = _module.to_v.weight.data.clone()
482
+ mapper_local.add_module(f'{_name.replace(".", "_")}_to_v', to_v_local)
483
+ _module.add_module('to_v_local', to_v_local)
484
+
485
+ if args.global_mapper_path is None:
486
+ _module.add_module('to_k_global', to_k_global)
487
+ _module.add_module('to_v_global', to_v_global)
488
+
489
+ if args.local_mapper_path is None:
490
+ _module.add_module('to_k_local', to_k_local)
491
+ _module.add_module('to_v_local', to_v_local)
492
+
493
+ if args.global_mapper_path is not None:
494
+ mapper.load_state_dict(torch.load(args.global_mapper_path, map_location='cpu'))
495
+ for _name, _module in unet.named_modules():
496
+ if _module.__class__.__name__ == "CrossAttention":
497
+ if 'attn1' in _name: continue
498
+ _module.add_module('to_k_global', getattr(mapper, f'{_name.replace(".", "_")}_to_k'))
499
+ _module.add_module('to_v_global', getattr(mapper, f'{_name.replace(".", "_")}_to_v'))
500
+
501
+ if args.local_mapper_path is not None:
502
+ mapper_local.load_state_dict(torch.load(args.local_mapper_path, map_location='cpu'))
503
+ for _name, _module in unet.named_modules():
504
+ if _module.__class__.__name__ == "CrossAttention":
505
+ if 'attn1' in _name: continue
506
+ _module.add_module('to_k_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_k'))
507
+ _module.add_module('to_v_local', getattr(mapper_local, f'{_name.replace(".", "_")}_to_v'))
508
+
509
+ # Freeze vae, unet, and the text and image encoders
510
+ freeze_params(vae.parameters())
511
+ freeze_params(unet.parameters())
512
+ freeze_params(text_encoder.parameters())
513
+ freeze_params(image_encoder.parameters())
514
+ unfreeze_params(mapper_local.parameters())
515
+
516
+ if args.scale_lr:
517
+ args.learning_rate = (
518
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
519
+ )
520
+
521
+ # Initialize the optimizer
522
+ optimizer = torch.optim.AdamW(
523
+ itertools.chain(mapper_local.parameters()), # only optimize the local mapper parameters
524
+ lr=args.learning_rate,
525
+ betas=(args.adam_beta1, args.adam_beta2),
526
+ weight_decay=args.adam_weight_decay,
527
+ eps=args.adam_epsilon,
528
+ )
529
+
530
+ noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler")
531
+
532
+ train_dataset = OpenImagesDatasetWithMask(
533
+ data_root=args.train_data_dir,
534
+ tokenizer=tokenizer,
535
+ size=args.resolution,
536
+ placeholder_token=args.placeholder_token,
537
+ set="test"
538
+ )
539
+ train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
540
+
541
+ # Scheduler and math around the number of training steps.
542
+ overrode_max_train_steps = False
543
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
544
+ if args.max_train_steps is None:
545
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
546
+ overrode_max_train_steps = True
547
+
548
+ lr_scheduler = get_scheduler(
549
+ args.lr_scheduler,
550
+ optimizer=optimizer,
551
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
552
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
553
+ )
554
+
555
+ mapper_local, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
556
+ mapper_local, optimizer, train_dataloader, lr_scheduler
557
+ )
558
+
559
+ # Move vae, unet, the encoders and the global mapper to device
560
+ vae.to(accelerator.device)
561
+ unet.to(accelerator.device)
562
+ image_encoder.to(accelerator.device)
563
+ text_encoder.to(accelerator.device)
564
+ mapper.to(accelerator.device)
565
+ # Keep vae, unet, image_encoder and the global mapper in eval mode as we don't train these
566
+ vae.eval()
567
+ unet.eval()
568
+ image_encoder.eval()
569
+ mapper.eval()
570
+
571
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
572
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
573
+ if overrode_max_train_steps:
574
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
575
+ # Afterwards we recalculate our number of training epochs
576
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
577
+
578
+ # We need to initialize the trackers we use, and also store our configuration.
579
+ # The trackers initialize automatically on the main process.
580
+ if accelerator.is_main_process:
581
+ accelerator.init_trackers("elite", config=vars(args))
582
+
583
+ # Train!
584
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
585
+
586
+ logger.info("***** Running training *****")
587
+ logger.info(f" Num examples = {len(train_dataset)}")
588
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
589
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
590
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
591
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
592
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
593
+ # Only show the progress bar once on each machine.
594
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
595
+ progress_bar.set_description("Steps")
596
+ global_step = 0
597
+
598
+ for epoch in range(args.num_train_epochs):
599
+ mapper_local.train()
600
+ for step, batch in enumerate(train_dataloader):
601
+ with accelerator.accumulate(mapper_local):
602
+ # Convert images to latent space
603
+ latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
604
+ latents = latents * 0.18215
605
+
606
+ # Sample noise that we'll add to the latents
607
+ noise = torch.randn(latents.shape).to(latents.device)
608
+ bsz = latents.shape[0]
609
+ # Sample a random timestep for each image
610
+ timesteps = torch.randint(
611
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
612
+ ).long()
613
+
614
+ # Add noise to the latents according to the noise magnitude at each timestep
615
+ # (this is the forward diffusion process)
616
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
617
+
618
+ placeholder_idx = batch["index"]
619
+ image = F.interpolate(batch["pixel_values_clip"], (224, 224), mode='bilinear')
620
+ image_obj = F.interpolate(batch["pixel_values_obj"], (224, 224), mode='bilinear')
621
+
622
+ mask = F.interpolate(batch["pixel_values_seg"], (16, 16), mode='nearest')
623
+ mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
624
+
625
+ image_features = image_encoder(image, output_hidden_states=True)
626
+ image_embeddings = [image_features[0], image_features[2][4], image_features[2][8], image_features[2][12], image_features[2][16]]
627
+ image_embeddings = [emb.detach() for emb in image_embeddings]
628
+ inj_embedding = mapper(image_embeddings)
629
+
630
+ # Only use the embedding for the first word
631
+ inj_embedding = inj_embedding[:, 0:1, :]
632
+
633
+ # Get the text embedding for conditioning
634
+ encoder_hidden_states = text_encoder({'input_ids': batch["input_ids"],
635
+ "inj_embedding": inj_embedding,
636
+ "inj_index": placeholder_idx.detach()})[0]
637
+
638
+ image_features_obj = image_encoder(image_obj, output_hidden_states=True)
639
+ image_embeddings_obj = [image_features_obj[0], image_features_obj[2][4], image_features_obj[2][8], image_features_obj[2][12], image_features_obj[2][16]]
640
+ image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj]
641
+
642
+ inj_embedding_local = mapper_local(image_embeddings_obj)
643
+ inj_embedding_local = inj_embedding_local * mask
644
+
645
+
646
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states={
647
+ "CONTEXT_TENSOR": encoder_hidden_states,
648
+ "LOCAL": inj_embedding_local,
649
+ "LOCAL_INDEX": placeholder_idx.detach()
650
+ }).sample
651
+
652
+ mask_values = batch["mask_values"]
653
+ loss_mle = F.mse_loss(noise_pred, noise, reduction="none")
654
+ loss_mle = ((loss_mle*mask_values).sum([1, 2, 3])/mask_values.sum([1, 2, 3])).mean()
655
+
656
+ loss_reg = 0
657
+ for vvv in value_local_list:
658
+ loss_reg += torch.mean(torch.abs(vvv))
659
+ loss_reg = loss_reg / len(value_local_list) * 0.0001
660
+
661
+ loss = loss_mle + loss_reg
662
+
663
+ accelerator.backward(loss)
664
+
665
+ if accelerator.sync_gradients:
666
+ accelerator.clip_grad_norm_(mapper_local.parameters(), 1)
667
+
668
+ optimizer.step()
669
+ lr_scheduler.step()
670
+ optimizer.zero_grad()
671
+ value_local_list.clear()
672
+
673
+
674
+ # Checks if the accelerator has performed an optimization step behind the scenes
675
+ if accelerator.sync_gradients:
676
+ progress_bar.update(1)
677
+ global_step += 1
678
+ if global_step % args.save_steps == 0:
679
+ save_progress(mapper_local, accelerator, args, global_step)
680
+ syn_images = validation(batch, tokenizer, image_encoder, text_encoder, unet, mapper, mapper_local, vae, batch["pixel_values_clip"].device, 5)
681
+ input_images = [th2image(img) for img in batch["pixel_values"]]
682
+ clip_images = [th2image(img).resize((512, 512)) for img in batch["pixel_values_clip"]]
683
+ obj_images = [th2image(img).resize((512, 512)) for img in batch["pixel_values_obj"]]
684
+ input_masks = torch.cat([mask_values, mask_values, mask_values], dim=1)
685
+ input_masks = [th2image(img).resize((512, 512)) for img in input_masks]
686
+ obj_masks = [th2image(img).resize((512, 512)) for img in batch["pixel_values_seg"]]
687
+ img_list = []
688
+ for syn, input_img, input_mask, clip_image, obj_image, obj_mask in zip(syn_images, input_images, input_masks, clip_images, obj_images, obj_masks):
689
+ img_list.append(np.concatenate((np.array(syn), np.array(input_img), np.array(input_mask), np.array(clip_image), np.array(obj_image), np.array(obj_mask)), axis=1))
690
+ img_list = np.concatenate(img_list, axis=0)
691
+ Image.fromarray(img_list).save(os.path.join(args.output_dir, f"{str(global_step).zfill(5)}.jpg"))
692
+
693
+ logs = {"loss_mle": loss_mle.detach().item(), "loss_reg": loss_reg.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
694
+ progress_bar.set_postfix(**logs)
695
+ accelerator.log(logs, step=global_step)
696
+
697
+ if global_step >= args.max_train_steps:
698
+ break
699
+
700
+ accelerator.wait_for_everyone()
701
+
702
+ if accelerator.is_main_process:
703
+ save_progress(mapper_local, accelerator, args)
704
+
705
+ accelerator.end_training()
706
+
707
+
708
+ if __name__ == "__main__":
709
+ main()
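
Note on the loss computed above: the denoising objective is a standard noise-prediction MSE, but it is averaged only over the foreground region given by mask_values, and a small L1 penalty on the attention values collected in value_local_list is added on top. Below is a minimal standalone sketch of the masked MSE term only; the tensor shapes are illustrative assumptions, not taken from the dataset code.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, channels, height, width) latents and a
# single-channel spatial mask that broadcasts over the channel dimension.
noise_pred = torch.randn(2, 4, 64, 64)
noise = torch.randn(2, 4, 64, 64)
mask_values = (torch.rand(2, 1, 64, 64) > 0.5).float()

# Per-pixel squared error, masked, then normalised by the mask area of each
# sample before taking the batch mean, mirroring the loss_mle lines above.
loss_mle = F.mse_loss(noise_pred, noise, reduction="none")
loss_mle = ((loss_mle * mask_values).sum([1, 2, 3]) / mask_values.sum([1, 2, 3])).mean()
print(loss_mle.item())
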
train_local.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2
+ export DATA_DIR='/home/weiyuxiang/datasets/Open_Images/'
3
+ CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch --config_file 4_gpu.json --main_process_port 25657 train_local.py \
4
+ --pretrained_model_name_or_path=$MODEL_NAME \
5
+ --train_data_dir=$DATA_DIR \
6
+ --placeholder_token="S" \
7
+ --resolution=512 \
8
+ --train_batch_size=2 \
9
+ --gradient_accumulation_steps=4 \
10
+ --max_train_steps=200000 \
11
+ --learning_rate=1e-5 --scale_lr \
12
+ --lr_scheduler="constant" \
13
+ --lr_warmup_steps=0 \
14
+ --global_mapper_path "./elite_experiments/global_mapping/mapper_070000.pt" \
15
+ --output_dir="./elite_experiments/local_mapping" \
16
+ --save_steps 200
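
A quick sanity check on how these launch flags combine: the total train batch size logged by the script is train_batch_size x num_processes x gradient_accumulation_steps, i.e. 2 x 4 x 4 = 32 here, and with --scale_lr the base learning rate is typically multiplied by that same factor. The exact scaling is defined in the argument parsing earlier in train_local.py, so treat the sketch below as the usual diffusers convention rather than a guarantee.

# Values copied from train_local.sh / 4_gpu.json; the --scale_lr behaviour is
# an assumption based on the common convention in diffusers training scripts.
train_batch_size = 2               # per device
num_processes = 4                  # 4_gpu.json, CUDA_VISIBLE_DEVICES=0,1,2,3
gradient_accumulation_steps = 4
base_lr = 1e-5

effective_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
scaled_lr = base_lr * effective_batch_size  # only applies if --scale_lr is set

print(effective_batch_size)  # 32
print(scaled_lr)             # 0.00032
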