Aitrepreneur committed
Commit 96fc080
1 parent: e13c198

Upload 128 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +4 -0
  2. README.md +14 -0
  3. app.py +675 -0
  4. checkpoints/ControlNetModel/config.json +57 -0
  5. checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors +3 -0
  6. checkpoints/ip-adapter.bin +3 -0
  7. depth_anything/__pycache__/blocks.cpython-310.pyc +0 -0
  8. depth_anything/__pycache__/dpt.cpython-310.pyc +0 -0
  9. depth_anything/blocks.py +153 -0
  10. depth_anything/dpt.py +187 -0
  11. depth_anything/util/__pycache__/transform.cpython-310.pyc +0 -0
  12. depth_anything/util/transform.py +248 -0
  13. examples/.DS_Store +0 -0
  14. examples/kaifu_resize.png +3 -0
  15. examples/musk_resize.jpeg +0 -0
  16. examples/poses/pose.jpg +0 -0
  17. examples/poses/pose2.jpg +0 -0
  18. examples/poses/pose3.jpg +0 -0
  19. examples/poses/pose4.jpg +0 -0
  20. examples/sam_resize.png +3 -0
  21. examples/schmidhuber_resize.png +3 -0
  22. examples/yann-lecun_resize.jpg +0 -0
  23. ip_adapter/__pycache__/attention_processor.cpython-310.pyc +0 -0
  24. ip_adapter/__pycache__/resampler.cpython-310.pyc +0 -0
  25. ip_adapter/__pycache__/utils.cpython-310.pyc +0 -0
  26. ip_adapter/attention_processor.py +446 -0
  27. ip_adapter/resampler.py +121 -0
  28. ip_adapter/utils.py +5 -0
  29. models/antelopev2/1k3d68.onnx +3 -0
  30. models/antelopev2/2d106det.onnx +3 -0
  31. models/antelopev2/genderage.onnx +3 -0
  32. models/antelopev2/glintr100.onnx +3 -0
  33. models/antelopev2/scrfd_10g_bnkps.onnx +3 -0
  34. pipeline_stable_diffusion_xl_instantid_full.py +1204 -0
  35. requirements.txt +16 -0
  36. style_template.py +155 -0
  37. torchhub/README.md +3 -0
  38. torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md +80 -0
  39. torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md +31 -0
  40. torchhub/facebookresearch_dinov2_main/LICENSE +400 -0
  41. torchhub/facebookresearch_dinov2_main/MODEL_CARD.md +201 -0
  42. torchhub/facebookresearch_dinov2_main/README.md +277 -0
  43. torchhub/facebookresearch_dinov2_main/__pycache__/hubconf.cpython-310.pyc +0 -0
  44. torchhub/facebookresearch_dinov2_main/__pycache__/vision_transformer.cpython-310.pyc +0 -0
  45. torchhub/facebookresearch_dinov2_main/conda.yaml +22 -0
  46. torchhub/facebookresearch_dinov2_main/dinov2/__init__.py +7 -0
  47. torchhub/facebookresearch_dinov2_main/dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
  48. torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py +23 -0
  49. torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml +6 -0
  50. torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml +7 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/kaifu_resize.png filter=lfs diff=lfs merge=lfs -text
+ examples/sam_resize.png filter=lfs diff=lfs merge=lfs -text
+ examples/schmidhuber_resize.png filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: InstantID
+ emoji: 😻
+ colorFrom: gray
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 4.15.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ disable_embedding: true
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,675 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ #import spaces
5
+ import torch
6
+ import random
7
+ import numpy as np
8
+ import argparse
9
+
10
+ import PIL
11
+ from PIL import Image
12
+ from typing import Tuple
13
+
14
+ import diffusers
15
+ from diffusers.utils import load_image
16
+ from diffusers.models import ControlNetModel
17
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
18
+
19
+ from huggingface_hub import hf_hub_download
20
+
21
+ import insightface
22
+ from insightface.app import FaceAnalysis
23
+
24
+ from style_template import styles
25
+ from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps
26
+
27
+ from controlnet_aux import OpenposeDetector
28
+
29
+ import gradio as gr
30
+
31
+ from depth_anything.dpt import DepthAnything
32
+ from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
33
+
34
+ import torch.nn.functional as F
35
+ from torchvision.transforms import Compose
36
+
37
+ # global variable
38
+ MAX_SEED = np.iinfo(np.int32).max
39
+ device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
41
+ STYLE_NAMES = list(styles.keys())
42
+ DEFAULT_STYLE_NAME = "Spring Festival"
43
+ enable_lcm_arg = False
44
+
45
+ # download checkpoints
46
+ from huggingface_hub import hf_hub_download
47
+
48
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
49
+ hf_hub_download(
50
+ repo_id="InstantX/InstantID",
51
+ filename="ControlNetModel/diffusion_pytorch_model.safetensors",
52
+ local_dir="./checkpoints",
53
+ )
54
+ hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
55
+
56
+ # Load face encoder
57
+ app = FaceAnalysis(
58
+ name="antelopev2",
59
+ root="./",
60
+ providers=["CPUExecutionProvider"],
61
+ )
62
+ app.prepare(ctx_id=0, det_size=(640, 640))
63
+
64
+ openpose = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
65
+
66
+ depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_vitl14').to(device).eval()
67
+
68
+ transform = Compose([
69
+ Resize(
70
+ width=518,
71
+ height=518,
72
+ resize_target=False,
73
+ keep_aspect_ratio=True,
74
+ ensure_multiple_of=14,
75
+ resize_method='lower_bound',
76
+ image_interpolation_method=cv2.INTER_CUBIC,
77
+ ),
78
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
79
+ PrepareForNet(),
80
+ ])
81
+
82
+ # Path to InstantID models
83
+ face_adapter = f"./checkpoints/ip-adapter.bin"
84
+ controlnet_path = f"./checkpoints/ControlNetModel"
85
+
86
+ # Load pipeline face ControlNetModel
87
+ controlnet_identitynet = ControlNetModel.from_pretrained(
88
+ controlnet_path, torch_dtype=dtype
89
+ )
90
+
91
+ # controlnet-pose/canny/depth
92
+ controlnet_pose_model = "thibaud/controlnet-openpose-sdxl-1.0"
93
+ controlnet_canny_model = "diffusers/controlnet-canny-sdxl-1.0"
94
+ controlnet_depth_model = "diffusers/controlnet-depth-sdxl-1.0-small"
95
+
96
+ controlnet_pose = ControlNetModel.from_pretrained(
97
+ controlnet_pose_model, torch_dtype=dtype
98
+ ).to(device)
99
+ controlnet_canny = ControlNetModel.from_pretrained(
100
+ controlnet_canny_model, torch_dtype=dtype
101
+ ).to(device)
102
+ controlnet_depth = ControlNetModel.from_pretrained(
103
+ controlnet_depth_model, torch_dtype=dtype
104
+ ).to(device)
105
+
106
+ def get_depth_map(image):
107
+
108
+ image = np.array(image) / 255.0
109
+
110
+ h, w = image.shape[:2]
111
+
112
+ image = transform({'image': image})['image']
113
+ image = torch.from_numpy(image).unsqueeze(0).to("cuda")
114
+
115
+ with torch.no_grad():
116
+ depth = depth_anything(image)
117
+
118
+ depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
119
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
120
+
121
+ depth = depth.cpu().numpy().astype(np.uint8)
122
+
123
+ depth_image = Image.fromarray(depth)
124
+
125
+ return depth_image
126
+
127
+ def get_canny_image(image, t1=100, t2=200):
128
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
129
+ edges = cv2.Canny(image, t1, t2)
130
+ return Image.fromarray(edges, "L")
131
+
132
+ controlnet_map = {
133
+ "pose": controlnet_pose,
134
+ "canny": controlnet_canny,
135
+ "depth": controlnet_depth,
136
+ }
137
+ controlnet_map_fn = {
138
+ "pose": openpose,
139
+ "canny": get_canny_image,
140
+ "depth": get_depth_map,
141
+ }
142
+
143
+ #base_model_path = "wangqixun/YamerMIX_v8"
144
+
145
+ if __name__ == '__main__':
146
+ parser = argparse.ArgumentParser()
147
+ parser.add_argument('--inbrowser', action='store_true', help='Open in browser')
148
+ parser.add_argument('--server_port', type=int, default=7860, help='Server port')
149
+ parser.add_argument('--share', action='store_true', help='Share the Gradio UI')
150
+ parser.add_argument('--model_path', type=str, default='stablediffusionapi/juggernaut-xl-v8', help='Base model path')
151
+ parser.add_argument('--medvram', action='store_true', help='Medium VRAM settings')
152
+ parser.add_argument('--lowvram', action='store_true', help='Low VRAM settings')
153
+
154
+ args = parser.parse_args()
155
+
156
+ base_model_path = args.model_path
157
+
158
+ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
159
+ base_model_path,
160
+ controlnet=[controlnet_identitynet],
161
+ torch_dtype=dtype,
162
+ safety_checker=None,
163
+ feature_extractor=None,
164
+ ).to(device)
165
+
166
+ pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(
167
+ pipe.scheduler.config
168
+ )
169
+
170
+ # load and disable LCM
171
+ pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
172
+ pipe.disable_lora()
173
+
174
+ pipe.cuda()
175
+ pipe.load_ip_adapter_instantid(face_adapter)
176
+ pipe.image_proj_model.to("cuda")
177
+ pipe.unet.to("cuda")
178
+
179
+ def toggle_lcm_ui(value):
180
+ if value:
181
+ return (
182
+ gr.update(minimum=0, maximum=100, step=1, value=5),
183
+ gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5),
184
+ )
185
+ else:
186
+ return (
187
+ gr.update(minimum=5, maximum=100, step=1, value=30),
188
+ gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5),
189
+ )
190
+
191
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
192
+ if randomize_seed:
193
+ seed = random.randint(0, MAX_SEED)
194
+ return seed
195
+
196
+ def remove_tips():
197
+ return gr.update(visible=False)
198
+
199
+ def get_example():
200
+ case = [
201
+ [
202
+ "./examples/yann-lecun_resize.jpg",
203
+ None,
204
+ "a man",
205
+ "Spring Festival",
206
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
207
+ ],
208
+ [
209
+ "./examples/musk_resize.jpeg",
210
+ "./examples/poses/pose2.jpg",
211
+ "a man flying in the sky in Mars",
212
+ "Mars",
213
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
214
+ ],
215
+ [
216
+ "./examples/sam_resize.png",
217
+ "./examples/poses/pose4.jpg",
218
+ "a man doing a silly pose wearing a suite",
219
+ "Jungle",
220
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, gree",
221
+ ],
222
+ [
223
+ "./examples/schmidhuber_resize.png",
224
+ "./examples/poses/pose3.jpg",
225
+ "a man sit on a chair",
226
+ "Neon",
227
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
228
+ ],
229
+ [
230
+ "./examples/kaifu_resize.png",
231
+ "./examples/poses/pose.jpg",
232
+ "a man",
233
+ "Vibrant Color",
234
+ "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
235
+ ],
236
+ ]
237
+ return case
238
+
239
+ def run_for_examples(face_file, pose_file, prompt, style, negative_prompt):
240
+ return generate_image(
241
+ face_file,
242
+ pose_file,
243
+ prompt,
244
+ negative_prompt,
245
+ style,
246
+ 20, # num_steps
247
+ 0.8, # identitynet_strength_ratio
248
+ 0.8, # adapter_strength_ratio
249
+ 0.4, # pose_strength
250
+ 0.3, # canny_strength
251
+ 0.5, # depth_strength
252
+ ["pose", "canny"], # controlnet_selection
253
+ 5.0, # guidance_scale
254
+ 42, # seed
255
+ "EulerDiscreteScheduler", # scheduler
256
+ False, # enable_LCM
257
+ True, # enable_Face_Region
258
+ )
259
+
260
+ def convert_from_cv2_to_image(img: np.ndarray) -> Image:
261
+ return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
262
+
263
+ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
264
+ return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
265
+
266
+ def resize_img(
267
+ input_image,
268
+ max_side=1280,
269
+ min_side=1024,
270
+ size=None,
271
+ pad_to_max_side=False,
272
+ mode=PIL.Image.BILINEAR,
273
+ base_pixel_number=64,
274
+ ):
275
+ w, h = input_image.size
276
+ if size is not None:
277
+ w_resize_new, h_resize_new = size
278
+ else:
279
+ ratio = min_side / min(h, w)
280
+ w, h = round(ratio * w), round(ratio * h)
281
+ ratio = max_side / max(h, w)
282
+ input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
283
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
284
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
285
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
286
+
287
+ if pad_to_max_side:
288
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
289
+ offset_x = (max_side - w_resize_new) // 2
290
+ offset_y = (max_side - h_resize_new) // 2
291
+ res[
292
+ offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
293
+ ] = np.array(input_image)
294
+ input_image = Image.fromarray(res)
295
+ return input_image
296
+
297
+ def apply_style(
298
+ style_name: str, positive: str, negative: str = ""
299
+ ) -> Tuple[str, str]:
300
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
301
+ return p.replace("{prompt}", positive), n + " " + negative
302
+
303
+ #@spaces.GPU
304
+ def generate_image(
305
+ face_image_path,
306
+ pose_image_path,
307
+ prompt,
308
+ negative_prompt,
309
+ style_name,
310
+ num_steps,
311
+ identitynet_strength_ratio,
312
+ adapter_strength_ratio,
313
+ pose_strength,
314
+ canny_strength,
315
+ depth_strength,
316
+ controlnet_selection,
317
+ guidance_scale,
318
+ seed,
319
+ scheduler,
320
+ enable_LCM,
321
+ enhance_face_region,
322
+ progress=gr.Progress(track_tqdm=True),
323
+ ):
324
+
325
+ if enable_LCM:
326
+ pipe.scheduler = diffusers.LCMScheduler.from_config(pipe.scheduler.config)
327
+ pipe.enable_lora()
328
+ else:
329
+ pipe.disable_lora()
330
+ scheduler_class_name = scheduler.split("-")[0]
331
+
332
+ add_kwargs = {}
333
+ if len(scheduler.split("-")) > 1:
334
+ add_kwargs["use_karras_sigmas"] = True
335
+ if len(scheduler.split("-")) > 2:
336
+ add_kwargs["algorithm_type"] = "sde-dpmsolver++"
337
+ scheduler = getattr(diffusers, scheduler_class_name)
338
+ pipe.scheduler = scheduler.from_config(pipe.scheduler.config, **add_kwargs)
339
+
340
+ if face_image_path is None:
341
+ raise gr.Error(
342
+ f"Cannot find any input face image! Please upload the face image"
343
+ )
344
+
345
+ if prompt is None:
346
+ prompt = "a person"
347
+
348
+ # apply the style template
349
+ prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
350
+
351
+ face_image = load_image(face_image_path)
352
+ face_image = resize_img(face_image, max_side=1024)
353
+ face_image_cv2 = convert_from_image_to_cv2(face_image)
354
+ height, width, _ = face_image_cv2.shape
355
+
356
+ # Extract face features
357
+ face_info = app.get(face_image_cv2)
358
+
359
+ if len(face_info) == 0:
360
+ raise gr.Error(
361
+ f"Unable to detect a face in the image. Please upload a different photo with a clear face."
362
+ )
363
+
364
+ face_info = sorted(
365
+ face_info,
366
+ key=lambda x: (x["bbox"][2] - x["bbox"][0]) * x["bbox"][3] - x["bbox"][1],
367
+ )[
368
+ -1
369
+ ] # only use the maximum face
370
+ face_emb = face_info["embedding"]
371
+ face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
372
+ img_controlnet = face_image
373
+ if pose_image_path is not None:
374
+ pose_image = load_image(pose_image_path)
375
+ pose_image = resize_img(pose_image, max_side=1024)
376
+ img_controlnet = pose_image
377
+ pose_image_cv2 = convert_from_image_to_cv2(pose_image)
378
+
379
+ face_info = app.get(pose_image_cv2)
380
+
381
+ if len(face_info) == 0:
382
+ raise gr.Error(
383
+ f"Cannot find any face in the reference image! Please upload another person image"
384
+ )
385
+
386
+ face_info = face_info[-1]
387
+ face_kps = draw_kps(pose_image, face_info["kps"])
388
+
389
+ width, height = face_kps.size
390
+
391
+ if enhance_face_region:
392
+ control_mask = np.zeros([height, width, 3])
393
+ x1, y1, x2, y2 = face_info["bbox"]
394
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
395
+ control_mask[y1:y2, x1:x2] = 255
396
+ control_mask = Image.fromarray(control_mask.astype(np.uint8))
397
+ else:
398
+ control_mask = None
399
+
400
+ if len(controlnet_selection) > 0:
401
+ controlnet_scales = {
402
+ "pose": pose_strength,
403
+ "canny": canny_strength,
404
+ "depth": depth_strength,
405
+ }
406
+ pipe.controlnet = MultiControlNetModel(
407
+ [controlnet_identitynet]
408
+ + [controlnet_map[s] for s in controlnet_selection]
409
+ )
410
+ control_scales = [float(identitynet_strength_ratio)] + [
411
+ controlnet_scales[s] for s in controlnet_selection
412
+ ]
413
+ control_images = [face_kps] + [
414
+ controlnet_map_fn[s](img_controlnet).resize((width, height))
415
+ for s in controlnet_selection
416
+ ]
417
+ else:
418
+ pipe.controlnet = controlnet_identitynet
419
+ control_scales = float(identitynet_strength_ratio)
420
+ control_images = face_kps
421
+
422
+ generator = torch.Generator(device=device).manual_seed(seed)
423
+
424
+ print("Start inference...")
425
+ print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
426
+
427
+ pipe.set_ip_adapter_scale(adapter_strength_ratio)
428
+ images = pipe(
429
+ prompt=prompt,
430
+ negative_prompt=negative_prompt,
431
+ image_embeds=face_emb,
432
+ image=control_images,
433
+ control_mask=control_mask,
434
+ controlnet_conditioning_scale=control_scales,
435
+ num_inference_steps=num_steps,
436
+ guidance_scale=guidance_scale,
437
+ height=height,
438
+ width=width,
439
+ generator=generator,
440
+ ).images
441
+
442
+ return images[0], gr.update(visible=True)
443
+
444
+ def clear_cuda_cache():
445
+ if torch.cuda.is_available():
446
+ torch.cuda.empty_cache()
447
+
448
+ # Description
449
+ title = r"""
450
+ <h1 align="center">InstantID: Zero-shot Identity-Preserving Generation in Seconds</h1>
451
+ """
452
+
453
+ description = r"""
454
+ <b>Official 🤗 Gradio demo</b> for <a href='https://github.com/InstantID/InstantID' target='_blank'><b>InstantID: Zero-shot Identity-Preserving Generation in Seconds</b></a>.<br>
455
+
456
+ We are organizing a Spring Festival event with HuggingFace from 2.7 to 2.25, and you can now generate pictures of Spring Festival costumes. Happy Dragon Year 🐲 ! Share the joy with your family.<br>
457
+
458
+ How to use:<br>
459
+ 1. Upload an image with a face. For images with multiple faces, we will only detect the largest face. Ensure the face is not too small and is clearly visible without significant obstructions or blurring.
460
+ 2. (Optional) You can upload another image as a reference for the face pose. If you don't, we will use the first detected face image to extract facial landmarks. If you use a cropped face at step 1, it is recommended to upload it to define a new face pose.
461
+ 3. (Optional) You can select multiple ControlNet models to control the generation process. The default is to use the IdentityNet only. The ControlNet models include pose skeleton, canny, and depth. You can adjust the strength of each ControlNet model to control the generation process.
462
+ 4. Enter a text prompt, as done in normal text-to-image models.
463
+ 5. Click the <b>Submit</b> button to begin customization.
464
+ 6. Share your customized photo with your friends and enjoy! 😊"""
465
+
466
+ article = r"""
467
+ ---
468
+ 📝 **Citation**
469
+ <br>
470
+ If our work is helpful for your research or applications, please cite us via:
471
+ ```bibtex
472
+ @article{wang2024instantid,
473
+ title={InstantID: Zero-shot Identity-Preserving Generation in Seconds},
474
+ author={Wang, Qixun and Bai, Xu and Wang, Haofan and Qin, Zekui and Chen, Anthony},
475
+ journal={arXiv preprint arXiv:2401.07519},
476
+ year={2024}
477
+ }
478
+ ```
479
+ 📧 **Contact**
480
+ <br>
481
+ If you have any questions, please feel free to open an issue or directly reach us out at <b>haofanwang.ai@gmail.com</b>.
482
+ """
483
+
484
+ tips = r"""
485
+ ### Usage tips of InstantID
486
+ 1. If you're not satisfied with the similarity, try increasing the weight of "IdentityNet Strength" and "Adapter Strength."
487
+ 2. If you feel that the saturation is too high, first decrease the Adapter strength. If it remains too high, then decrease the IdentityNet strength.
488
+ 3. If you find that text control is not as expected, decrease Adapter strength.
489
+ 4. If you find that realistic style is not good enough, go for our Github repo and use a more realistic base model.
490
+ """
491
+
492
+ css = """
493
+ .gradio-container {width: 85% !important}
494
+ """
495
+ with gr.Blocks(css=css) as demo:
496
+ # description
497
+ gr.Markdown(title)
498
+ gr.Markdown(description)
499
+
500
+ with gr.Row():
501
+ with gr.Column():
502
+ with gr.Row(equal_height=True):
503
+ # upload face image
504
+ face_file = gr.Image(
505
+ label="Upload a photo of your face", type="filepath"
506
+ )
507
+ # optional: upload a reference pose image
508
+ pose_file = gr.Image(
509
+ label="Upload a reference pose image (Optional)",
510
+ type="filepath",
511
+ )
512
+
513
+ # prompt
514
+ prompt = gr.Textbox(
515
+ label="Prompt",
516
+ info="Give simple prompt is enough to achieve good face fidelity",
517
+ placeholder="A photo of a person",
518
+ value="",
519
+ )
520
+
521
+ submit = gr.Button("Submit", variant="primary")
522
+ enable_LCM = gr.Checkbox(
523
+ label="Enable Fast Inference with LCM", value=enable_lcm_arg,
524
+ info="LCM speeds up the inference step, the trade-off is the quality of the generated image. It performs better with portrait face images rather than distant faces",
525
+ )
526
+ style = gr.Dropdown(
527
+ label="Style template",
528
+ choices=STYLE_NAMES,
529
+ value=DEFAULT_STYLE_NAME,
530
+ )
531
+
532
+ # strength
533
+ identitynet_strength_ratio = gr.Slider(
534
+ label="IdentityNet strength (for fidelity)",
535
+ minimum=0,
536
+ maximum=1.5,
537
+ step=0.05,
538
+ value=0.80,
539
+ )
540
+ adapter_strength_ratio = gr.Slider(
541
+ label="Image adapter strength (for detail)",
542
+ minimum=0,
543
+ maximum=1.5,
544
+ step=0.05,
545
+ value=0.80,
546
+ )
547
+ with gr.Accordion("Controlnet"):
548
+ controlnet_selection = gr.CheckboxGroup(
549
+ ["pose", "canny", "depth"], label="Controlnet", value=["pose"],
550
+ info="Use pose for skeleton inference, canny for edge detection, and depth for depth map estimation. You can try all three to control the generation process"
551
+ )
552
+ pose_strength = gr.Slider(
553
+ label="Pose strength",
554
+ minimum=0,
555
+ maximum=1.5,
556
+ step=0.05,
557
+ value=0.40,
558
+ )
559
+ canny_strength = gr.Slider(
560
+ label="Canny strength",
561
+ minimum=0,
562
+ maximum=1.5,
563
+ step=0.05,
564
+ value=0.40,
565
+ )
566
+ depth_strength = gr.Slider(
567
+ label="Depth strength",
568
+ minimum=0,
569
+ maximum=1.5,
570
+ step=0.05,
571
+ value=0.40,
572
+ )
573
+ with gr.Accordion(open=False, label="Advanced Options"):
574
+ negative_prompt = gr.Textbox(
575
+ label="Negative Prompt",
576
+ placeholder="low quality",
577
+ value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
578
+ )
579
+ num_steps = gr.Slider(
580
+ label="Number of sample steps",
581
+ minimum=1,
582
+ maximum=100,
583
+ step=1,
584
+ value=5 if enable_lcm_arg else 30,
585
+ )
586
+ guidance_scale = gr.Slider(
587
+ label="Guidance scale",
588
+ minimum=0.1,
589
+ maximum=20.0,
590
+ step=0.1,
591
+ value=0.0 if enable_lcm_arg else 5.0,
592
+ )
593
+ seed = gr.Slider(
594
+ label="Seed",
595
+ minimum=0,
596
+ maximum=MAX_SEED,
597
+ step=1,
598
+ value=42,
599
+ )
600
+ schedulers = [
601
+ "DEISMultistepScheduler",
602
+ "HeunDiscreteScheduler",
603
+ "EulerDiscreteScheduler",
604
+ "DPMSolverMultistepScheduler",
605
+ "DPMSolverMultistepScheduler-Karras",
606
+ "DPMSolverMultistepScheduler-Karras-SDE",
607
+ ]
608
+ scheduler = gr.Dropdown(
609
+ label="Schedulers",
610
+ choices=schedulers,
611
+ value="EulerDiscreteScheduler",
612
+ )
613
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
614
+ enhance_face_region = gr.Checkbox(label="Enhance non-face region", value=True)
615
+
616
+ with gr.Column(scale=1):
617
+ gallery = gr.Image(label="Generated Images")
618
+ usage_tips = gr.Markdown(
619
+ label="InstantID Usage Tips", value=tips, visible=False
620
+ )
621
+
622
+ submit.click(
623
+ fn=remove_tips,
624
+ outputs=usage_tips,
625
+ ).then(
626
+ fn=randomize_seed_fn,
627
+ inputs=[seed, randomize_seed],
628
+ outputs=seed,
629
+ queue=False,
630
+ api_name=False,
631
+ ).then(
632
+ fn=generate_image,
633
+ inputs=[
634
+ face_file,
635
+ pose_file,
636
+ prompt,
637
+ negative_prompt,
638
+ style,
639
+ num_steps,
640
+ identitynet_strength_ratio,
641
+ adapter_strength_ratio,
642
+ pose_strength,
643
+ canny_strength,
644
+ depth_strength,
645
+ controlnet_selection,
646
+ guidance_scale,
647
+ seed,
648
+ scheduler,
649
+ enable_LCM,
650
+ enhance_face_region,
651
+ ],
652
+ outputs=[gallery, usage_tips],
653
+ ).then(
654
+ fn=clear_cuda_cache
655
+ )
656
+
657
+ enable_LCM.input(
658
+ fn=toggle_lcm_ui,
659
+ inputs=[enable_LCM],
660
+ outputs=[num_steps, guidance_scale],
661
+ queue=False,
662
+ )
663
+
664
+ gr.Examples(
665
+ examples=get_example(),
666
+ inputs=[face_file, pose_file, prompt, style, negative_prompt],
667
+ fn=run_for_examples,
668
+ outputs=[gallery, usage_tips],
669
+ cache_examples=False,
670
+ )
671
+
672
+ gr.Markdown(article)
673
+
674
+
675
+ demo.launch(inbrowser=args.inbrowser, server_port=args.server_port, share=args.share)
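
For reference, the core of what `generate_image` does above can be exercised without the Gradio UI. The following is a minimal, hedged sketch of the identity-only path (no extra pose/canny/depth ControlNets), assuming the checkpoints and the antelopev2 models have already been downloaded as in app.py; it skips the `resize_img` step app.py applies, and the output filename `out.png` is just a placeholder.

```python
# Minimal sketch (not part of this commit): the identity-only path of generate_image,
# assuming ./checkpoints and ./models/antelopev2 are already in place.
import cv2
import numpy as np
import torch
from diffusers.models import ControlNetModel
from diffusers.utils import load_image
from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_full import StableDiffusionXLInstantIDPipeline, draw_kps

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# 1) Face embedding + keypoints, as in app.py
face_app = FaceAnalysis(name="antelopev2", root="./", providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(640, 640))
face_image = load_image("./examples/musk_resize.jpeg")
face_cv2 = cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)
faces = face_app.get(face_cv2)
face_info = max(faces, key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1]))
face_kps = draw_kps(face_image, face_info["kps"])

# 2) IdentityNet ControlNet + InstantID pipeline, as in app.py
controlnet = ControlNetModel.from_pretrained("./checkpoints/ControlNetModel", torch_dtype=dtype)
pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
    "stablediffusionapi/juggernaut-xl-v8",  # default --model_path in app.py
    controlnet=[controlnet],
    torch_dtype=dtype,
    safety_checker=None,
    feature_extractor=None,
).to(device)
pipe.load_ip_adapter_instantid("./checkpoints/ip-adapter.bin")
pipe.set_ip_adapter_scale(0.8)  # "Image adapter strength" default in the UI

# 3) Single-ControlNet call: image=face_kps, scale=IdentityNet strength
image = pipe(
    prompt="a man",
    negative_prompt="(lowres, low quality, worst quality:1.2), watermark, blurry",
    image_embeds=face_info["embedding"],
    image=face_kps,
    controlnet_conditioning_scale=0.8,
    num_inference_steps=30,
    guidance_scale=5.0,
    generator=torch.Generator(device=device).manual_seed(42),
).images[0]
image.save("out.png")
```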
checkpoints/ControlNetModel/config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "_class_name": "ControlNetModel",
+   "_diffusers_version": "0.21.2",
+   "_name_or_path": "/mnt/nj-aigc/usr/guiwan/workspace/diffusion_output/face_xl_ipc_v4_2_XiezhenAnimeForeigner/checkpoint-150000/ControlNetModel",
+   "act_fn": "silu",
+   "addition_embed_type": "text_time",
+   "addition_embed_type_num_heads": 64,
+   "addition_time_embed_dim": 256,
+   "attention_head_dim": [
+     5,
+     10,
+     20
+   ],
+   "block_out_channels": [
+     320,
+     640,
+     1280
+   ],
+   "class_embed_type": null,
+   "conditioning_channels": 3,
+   "conditioning_embedding_out_channels": [
+     16,
+     32,
+     96,
+     256
+   ],
+   "controlnet_conditioning_channel_order": "rgb",
+   "cross_attention_dim": 2048,
+   "down_block_types": [
+     "DownBlock2D",
+     "CrossAttnDownBlock2D",
+     "CrossAttnDownBlock2D"
+   ],
+   "downsample_padding": 1,
+   "encoder_hid_dim": null,
+   "encoder_hid_dim_type": null,
+   "flip_sin_to_cos": true,
+   "freq_shift": 0,
+   "global_pool_conditions": false,
+   "in_channels": 4,
+   "layers_per_block": 2,
+   "mid_block_scale_factor": 1,
+   "norm_eps": 1e-05,
+   "norm_num_groups": 32,
+   "num_attention_heads": null,
+   "num_class_embeds": null,
+   "only_cross_attention": false,
+   "projection_class_embeddings_input_dim": 2816,
+   "resnet_time_scale_shift": "default",
+   "transformer_layers_per_block": [
+     1,
+     2,
+     10
+   ],
+   "upcast_attention": null,
+   "use_linear_projection": true
+ }
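
This is a standard diffusers ControlNet config for SDXL (cross-attention dim 2048, three down blocks); app.py loads it straight from the local checkpoint folder. A small hedged sketch of that load, with the path being the one created by the `hf_hub_download` calls in app.py:

```python
# Sketch: loading the IdentityNet ControlNet from this local checkpoint folder.
import torch
from diffusers.models import ControlNetModel

identitynet = ControlNetModel.from_pretrained(
    "./checkpoints/ControlNetModel",   # config.json + diffusion_pytorch_model.safetensors
    torch_dtype=torch.float16,         # app.py uses fp16 on CUDA, fp32 on CPU
)
print(identitynet.config.cross_attention_dim)  # 2048 (SDXL text-encoder width)
```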
checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c8127be9f174101ebdafee9964d856b49b634435cf6daa396d3f593cf0bbbb05
+ size 2502139136
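
Only this Git LFS pointer is stored in the repository; the ~2.5 GB weight file itself is fetched by the `hf_hub_download` call in app.py. If needed, the download can be checked against the pointer's oid and size. A hedged sketch (the path is the `local_dir` app.py downloads into):

```python
# Sketch: verify the downloaded safetensors file against the LFS pointer above.
import hashlib
import os

path = "./checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors"
expected_oid = "c8127be9f174101ebdafee9964d856b49b634435cf6daa396d3f593cf0bbbb05"
expected_size = 2502139136

assert os.path.getsize(path) == expected_size, "size mismatch"
digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        digest.update(chunk)
assert digest.hexdigest() == expected_oid, "sha256 mismatch"
print("LFS object verified")
```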
checkpoints/ip-adapter.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:02b3618e36d803784166660520098089a81388e61a93ef8002aa79a5b1c546e1
+ size 1691134141
depth_anything/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (3.2 kB)
depth_anything/__pycache__/dpt.cpython-310.pyc ADDED
Binary file (5.02 kB)
depth_anything/blocks.py ADDED
@@ -0,0 +1,153 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape*2
16
+ out_shape3 = out_shape*4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape*8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(
21
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
22
+ )
23
+ scratch.layer2_rn = nn.Conv2d(
24
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
25
+ )
26
+ scratch.layer3_rn = nn.Conv2d(
27
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
28
+ )
29
+ if len(in_shape) >= 4:
30
+ scratch.layer4_rn = nn.Conv2d(
31
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
32
+ )
33
+
34
+ return scratch
35
+
36
+
37
+ class ResidualConvUnit(nn.Module):
38
+ """Residual convolution module.
39
+ """
40
+
41
+ def __init__(self, features, activation, bn):
42
+ """Init.
43
+
44
+ Args:
45
+ features (int): number of features
46
+ """
47
+ super().__init__()
48
+
49
+ self.bn = bn
50
+
51
+ self.groups=1
52
+
53
+ self.conv1 = nn.Conv2d(
54
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
55
+ )
56
+
57
+ self.conv2 = nn.Conv2d(
58
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
59
+ )
60
+
61
+ if self.bn==True:
62
+ self.bn1 = nn.BatchNorm2d(features)
63
+ self.bn2 = nn.BatchNorm2d(features)
64
+
65
+ self.activation = activation
66
+
67
+ self.skip_add = nn.quantized.FloatFunctional()
68
+
69
+ def forward(self, x):
70
+ """Forward pass.
71
+
72
+ Args:
73
+ x (tensor): input
74
+
75
+ Returns:
76
+ tensor: output
77
+ """
78
+
79
+ out = self.activation(x)
80
+ out = self.conv1(out)
81
+ if self.bn==True:
82
+ out = self.bn1(out)
83
+
84
+ out = self.activation(out)
85
+ out = self.conv2(out)
86
+ if self.bn==True:
87
+ out = self.bn2(out)
88
+
89
+ if self.groups > 1:
90
+ out = self.conv_merge(out)
91
+
92
+ return self.skip_add.add(out, x)
93
+
94
+
95
+ class FeatureFusionBlock(nn.Module):
96
+ """Feature fusion block.
97
+ """
98
+
99
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
100
+ """Init.
101
+
102
+ Args:
103
+ features (int): number of features
104
+ """
105
+ super(FeatureFusionBlock, self).__init__()
106
+
107
+ self.deconv = deconv
108
+ self.align_corners = align_corners
109
+
110
+ self.groups=1
111
+
112
+ self.expand = expand
113
+ out_features = features
114
+ if self.expand==True:
115
+ out_features = features//2
116
+
117
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
118
+
119
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
120
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
121
+
122
+ self.skip_add = nn.quantized.FloatFunctional()
123
+
124
+ self.size=size
125
+
126
+ def forward(self, *xs, size=None):
127
+ """Forward pass.
128
+
129
+ Returns:
130
+ tensor: output
131
+ """
132
+ output = xs[0]
133
+
134
+ if len(xs) == 2:
135
+ res = self.resConfUnit1(xs[1])
136
+ output = self.skip_add.add(output, res)
137
+
138
+ output = self.resConfUnit2(output)
139
+
140
+ if (size is None) and (self.size is None):
141
+ modifier = {"scale_factor": 2}
142
+ elif size is None:
143
+ modifier = {"size": self.size}
144
+ else:
145
+ modifier = {"size": size}
146
+
147
+ output = nn.functional.interpolate(
148
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
149
+ )
150
+
151
+ output = self.out_conv(output)
152
+
153
+ return output
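
`FeatureFusionBlock` is the DPT-style refinement unit consumed by dpt.py below: it adds an optional skip branch through a `ResidualConvUnit`, upsamples (x2 by default, or to an explicit `size`), then applies a 1x1 output conv. A small hedged sketch exercising it on random tensors:

```python
# Sketch (not part of the commit): exercising the blocks on random feature maps.
import torch
import torch.nn as nn
from depth_anything.blocks import FeatureFusionBlock, _make_scratch

features = 64
block = FeatureFusionBlock(features, nn.ReLU(False), bn=False)

coarse = torch.randn(1, features, 16, 16)   # deeper decoder feature
skip = torch.randn(1, features, 16, 16)     # lateral feature at the same resolution

out = block(coarse, skip, size=(32, 32))    # fuse, interpolate to (32, 32), 1x1 conv
print(out.shape)                            # torch.Size([1, 64, 32, 32])

scratch = _make_scratch([256, 512, 1024, 1024], features)
print(scratch.layer1_rn)                    # Conv2d(256, 64, kernel_size=(3, 3), ...)
```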
depth_anything/dpt.py ADDED
@@ -0,0 +1,187 @@
1
+ import argparse
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
+
7
+ from depth_anything.blocks import FeatureFusionBlock, _make_scratch
8
+
9
+
10
+ def _make_fusion_block(features, use_bn, size = None):
11
+ return FeatureFusionBlock(
12
+ features,
13
+ nn.ReLU(False),
14
+ deconv=False,
15
+ bn=use_bn,
16
+ expand=False,
17
+ align_corners=True,
18
+ size=size,
19
+ )
20
+
21
+
22
+ class DPTHead(nn.Module):
23
+ def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False):
24
+ super(DPTHead, self).__init__()
25
+
26
+ self.nclass = nclass
27
+ self.use_clstoken = use_clstoken
28
+
29
+ self.projects = nn.ModuleList([
30
+ nn.Conv2d(
31
+ in_channels=in_channels,
32
+ out_channels=out_channel,
33
+ kernel_size=1,
34
+ stride=1,
35
+ padding=0,
36
+ ) for out_channel in out_channels
37
+ ])
38
+
39
+ self.resize_layers = nn.ModuleList([
40
+ nn.ConvTranspose2d(
41
+ in_channels=out_channels[0],
42
+ out_channels=out_channels[0],
43
+ kernel_size=4,
44
+ stride=4,
45
+ padding=0),
46
+ nn.ConvTranspose2d(
47
+ in_channels=out_channels[1],
48
+ out_channels=out_channels[1],
49
+ kernel_size=2,
50
+ stride=2,
51
+ padding=0),
52
+ nn.Identity(),
53
+ nn.Conv2d(
54
+ in_channels=out_channels[3],
55
+ out_channels=out_channels[3],
56
+ kernel_size=3,
57
+ stride=2,
58
+ padding=1)
59
+ ])
60
+
61
+ if use_clstoken:
62
+ self.readout_projects = nn.ModuleList()
63
+ for _ in range(len(self.projects)):
64
+ self.readout_projects.append(
65
+ nn.Sequential(
66
+ nn.Linear(2 * in_channels, in_channels),
67
+ nn.GELU()))
68
+
69
+ self.scratch = _make_scratch(
70
+ out_channels,
71
+ features,
72
+ groups=1,
73
+ expand=False,
74
+ )
75
+
76
+ self.scratch.stem_transpose = None
77
+
78
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
79
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
80
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
81
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
82
+
83
+ head_features_1 = features
84
+ head_features_2 = 32
85
+
86
+ if nclass > 1:
87
+ self.scratch.output_conv = nn.Sequential(
88
+ nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
89
+ nn.ReLU(True),
90
+ nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
91
+ )
92
+ else:
93
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
94
+
95
+ self.scratch.output_conv2 = nn.Sequential(
96
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
97
+ nn.ReLU(True),
98
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
99
+ nn.ReLU(True),
100
+ nn.Identity(),
101
+ )
102
+
103
+ def forward(self, out_features, patch_h, patch_w):
104
+ out = []
105
+ for i, x in enumerate(out_features):
106
+ if self.use_clstoken:
107
+ x, cls_token = x[0], x[1]
108
+ readout = cls_token.unsqueeze(1).expand_as(x)
109
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
110
+ else:
111
+ x = x[0]
112
+
113
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
114
+
115
+ x = self.projects[i](x)
116
+ x = self.resize_layers[i](x)
117
+
118
+ out.append(x)
119
+
120
+ layer_1, layer_2, layer_3, layer_4 = out
121
+
122
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
123
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
124
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
125
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
126
+
127
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
128
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
129
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
130
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
131
+
132
+ out = self.scratch.output_conv1(path_1)
133
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
134
+ out = self.scratch.output_conv2(out)
135
+
136
+ return out
137
+
138
+
139
+ class DPT_DINOv2(nn.Module):
140
+ def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True):
141
+ super(DPT_DINOv2, self).__init__()
142
+
143
+ assert encoder in ['vits', 'vitb', 'vitl']
144
+
145
+ # in case the Internet connection is not stable, please load the DINOv2 locally
146
+ if localhub:
147
+ self.pretrained = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False)
148
+ else:
149
+ self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder))
150
+
151
+ dim = self.pretrained.blocks[0].attn.qkv.in_features
152
+
153
+ self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
154
+
155
+ def forward(self, x):
156
+ h, w = x.shape[-2:]
157
+
158
+ features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
159
+
160
+ patch_h, patch_w = h // 14, w // 14
161
+
162
+ depth = self.depth_head(features, patch_h, patch_w)
163
+ depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
164
+ depth = F.relu(depth)
165
+
166
+ return depth.squeeze(1)
167
+
168
+
169
+ class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
170
+ def __init__(self, config):
171
+ super().__init__(**config)
172
+
173
+
174
+ if __name__ == '__main__':
175
+ parser = argparse.ArgumentParser()
176
+ parser.add_argument(
177
+ "--encoder",
178
+ default="vits",
179
+ type=str,
180
+ choices=["vits", "vitb", "vitl"],
181
+ )
182
+ args = parser.parse_args()
183
+
184
+ model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
185
+
186
+ print(model)
187
+
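
As the `__main__` block above suggests, the model is normally obtained via `DepthAnything.from_pretrained`; app.py uses the ViT-L variant. A hedged sketch of a bare forward pass follows; the input height and width must be multiples of 14 (the DINOv2 patch size), which is what the `Resize(..., ensure_multiple_of=14)` transform in app.py guarantees.

```python
# Sketch: relative-depth inference with the ViT-L checkpoint used in app.py.
import torch
from depth_anything.dpt import DepthAnything

model = DepthAnything.from_pretrained("LiheYoung/depth_anything_vitl14").eval()

x = torch.randn(1, 3, 518, 518)     # H and W must be multiples of 14
with torch.no_grad():
    depth = model(x)                # shape (1, 518, 518), relative depth after F.relu
print(depth.shape)
```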
depth_anything/util/__pycache__/transform.cpython-310.pyc ADDED
Binary file (6.03 kB)
depth_anything/util/transform.py ADDED
@@ -0,0 +1,248 @@
1
+ import random
2
+ from PIL import Image, ImageOps, ImageFilter
3
+ import torch
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+
7
+ import numpy as np
8
+ import cv2
9
+ import math
10
+
11
+
12
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
13
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
14
+
15
+ Args:
16
+ sample (dict): sample
17
+ size (tuple): image size
18
+
19
+ Returns:
20
+ tuple: new size
21
+ """
22
+ shape = list(sample["disparity"].shape)
23
+
24
+ if shape[0] >= size[0] and shape[1] >= size[1]:
25
+ return sample
26
+
27
+ scale = [0, 0]
28
+ scale[0] = size[0] / shape[0]
29
+ scale[1] = size[1] / shape[1]
30
+
31
+ scale = max(scale)
32
+
33
+ shape[0] = math.ceil(scale * shape[0])
34
+ shape[1] = math.ceil(scale * shape[1])
35
+
36
+ # resize
37
+ sample["image"] = cv2.resize(
38
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
39
+ )
40
+
41
+ sample["disparity"] = cv2.resize(
42
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
43
+ )
44
+ sample["mask"] = cv2.resize(
45
+ sample["mask"].astype(np.float32),
46
+ tuple(shape[::-1]),
47
+ interpolation=cv2.INTER_NEAREST,
48
+ )
49
+ sample["mask"] = sample["mask"].astype(bool)
50
+
51
+ return tuple(shape)
52
+
53
+
54
+ class Resize(object):
55
+ """Resize sample to given size (width, height).
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ width,
61
+ height,
62
+ resize_target=True,
63
+ keep_aspect_ratio=False,
64
+ ensure_multiple_of=1,
65
+ resize_method="lower_bound",
66
+ image_interpolation_method=cv2.INTER_AREA,
67
+ ):
68
+ """Init.
69
+
70
+ Args:
71
+ width (int): desired output width
72
+ height (int): desired output height
73
+ resize_target (bool, optional):
74
+ True: Resize the full sample (image, mask, target).
75
+ False: Resize image only.
76
+ Defaults to True.
77
+ keep_aspect_ratio (bool, optional):
78
+ True: Keep the aspect ratio of the input sample.
79
+ Output sample might not have the given width and height, and
80
+ resize behaviour depends on the parameter 'resize_method'.
81
+ Defaults to False.
82
+ ensure_multiple_of (int, optional):
83
+ Output width and height is constrained to be multiple of this parameter.
84
+ Defaults to 1.
85
+ resize_method (str, optional):
86
+ "lower_bound": Output will be at least as large as the given size.
87
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
88
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
89
+ Defaults to "lower_bound".
90
+ """
91
+ self.__width = width
92
+ self.__height = height
93
+
94
+ self.__resize_target = resize_target
95
+ self.__keep_aspect_ratio = keep_aspect_ratio
96
+ self.__multiple_of = ensure_multiple_of
97
+ self.__resize_method = resize_method
98
+ self.__image_interpolation_method = image_interpolation_method
99
+
100
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
101
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ if max_val is not None and y > max_val:
104
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
105
+
106
+ if y < min_val:
107
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
108
+
109
+ return y
110
+
111
+ def get_size(self, width, height):
112
+ # determine new height and width
113
+ scale_height = self.__height / height
114
+ scale_width = self.__width / width
115
+
116
+ if self.__keep_aspect_ratio:
117
+ if self.__resize_method == "lower_bound":
118
+ # scale such that output size is lower bound
119
+ if scale_width > scale_height:
120
+ # fit width
121
+ scale_height = scale_width
122
+ else:
123
+ # fit height
124
+ scale_width = scale_height
125
+ elif self.__resize_method == "upper_bound":
126
+ # scale such that output size is upper bound
127
+ if scale_width < scale_height:
128
+ # fit width
129
+ scale_height = scale_width
130
+ else:
131
+ # fit height
132
+ scale_width = scale_height
133
+ elif self.__resize_method == "minimal":
134
+ # scale as little as possible
135
+ if abs(1 - scale_width) < abs(1 - scale_height):
136
+ # fit width
137
+ scale_height = scale_width
138
+ else:
139
+ # fit height
140
+ scale_width = scale_height
141
+ else:
142
+ raise ValueError(
143
+ f"resize_method {self.__resize_method} not implemented"
144
+ )
145
+
146
+ if self.__resize_method == "lower_bound":
147
+ new_height = self.constrain_to_multiple_of(
148
+ scale_height * height, min_val=self.__height
149
+ )
150
+ new_width = self.constrain_to_multiple_of(
151
+ scale_width * width, min_val=self.__width
152
+ )
153
+ elif self.__resize_method == "upper_bound":
154
+ new_height = self.constrain_to_multiple_of(
155
+ scale_height * height, max_val=self.__height
156
+ )
157
+ new_width = self.constrain_to_multiple_of(
158
+ scale_width * width, max_val=self.__width
159
+ )
160
+ elif self.__resize_method == "minimal":
161
+ new_height = self.constrain_to_multiple_of(scale_height * height)
162
+ new_width = self.constrain_to_multiple_of(scale_width * width)
163
+ else:
164
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
165
+
166
+ return (new_width, new_height)
167
+
168
+ def __call__(self, sample):
169
+ width, height = self.get_size(
170
+ sample["image"].shape[1], sample["image"].shape[0]
171
+ )
172
+
173
+ # resize sample
174
+ sample["image"] = cv2.resize(
175
+ sample["image"],
176
+ (width, height),
177
+ interpolation=self.__image_interpolation_method,
178
+ )
179
+
180
+ if self.__resize_target:
181
+ if "disparity" in sample:
182
+ sample["disparity"] = cv2.resize(
183
+ sample["disparity"],
184
+ (width, height),
185
+ interpolation=cv2.INTER_NEAREST,
186
+ )
187
+
188
+ if "depth" in sample:
189
+ sample["depth"] = cv2.resize(
190
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
191
+ )
192
+
193
+ if "semseg_mask" in sample:
194
+ # sample["semseg_mask"] = cv2.resize(
195
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
196
+ # )
197
+ sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
198
+
199
+ if "mask" in sample:
200
+ sample["mask"] = cv2.resize(
201
+ sample["mask"].astype(np.float32),
202
+ (width, height),
203
+ interpolation=cv2.INTER_NEAREST,
204
+ )
205
+ # sample["mask"] = sample["mask"].astype(bool)
206
+
207
+ # print(sample['image'].shape, sample['depth'].shape)
208
+ return sample
209
+
210
+
211
+ class NormalizeImage(object):
212
+ """Normlize image by given mean and std.
213
+ """
214
+
215
+ def __init__(self, mean, std):
216
+ self.__mean = mean
217
+ self.__std = std
218
+
219
+ def __call__(self, sample):
220
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
221
+
222
+ return sample
223
+
224
+
225
+ class PrepareForNet(object):
226
+ """Prepare sample for usage as network input.
227
+ """
228
+
229
+ def __init__(self):
230
+ pass
231
+
232
+ def __call__(self, sample):
233
+ image = np.transpose(sample["image"], (2, 0, 1))
234
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
235
+
236
+ if "mask" in sample:
237
+ sample["mask"] = sample["mask"].astype(np.float32)
238
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
239
+
240
+ if "depth" in sample:
241
+ depth = sample["depth"].astype(np.float32)
242
+ sample["depth"] = np.ascontiguousarray(depth)
243
+
244
+ if "semseg_mask" in sample:
245
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
246
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
247
+
248
+ return sample
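
These transforms operate on dict samples whose "image" entry is an H x W x 3 float array in [0, 1]; app.py composes them exactly as below before calling Depth Anything. A small hedged sketch on a random image:

```python
# Sketch: the preprocessing chain app.py builds for Depth Anything.
import cv2
import numpy as np
import torch
from torchvision.transforms import Compose
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(
        width=518,
        height=518,
        resize_target=False,
        keep_aspect_ratio=True,
        ensure_multiple_of=14,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

image = np.random.rand(720, 1280, 3).astype(np.float32)   # H x W x 3 in [0, 1]
out = transform({"image": image})["image"]                # C x H' x W', sides multiples of 14
x = torch.from_numpy(out).unsqueeze(0)                    # 1 x 3 x H' x W'
print(x.shape)
```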
examples/.DS_Store ADDED
Binary file (6.15 kB)
examples/kaifu_resize.png ADDED

Git LFS Details

  • SHA256: b7302f0f7d0ff61be67bf13d172ad2393b6cb2bc985f048089f4e901145324d7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
examples/musk_resize.jpeg ADDED
examples/poses/pose.jpg ADDED
examples/poses/pose2.jpg ADDED
examples/poses/pose3.jpg ADDED
examples/poses/pose4.jpg ADDED
examples/sam_resize.png ADDED

Git LFS Details

  • SHA256: 1390d8a9a1be7b8f5388c0bc8483b2d5cca6c0f0adeb6eecd970a4413b1f1deb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
examples/schmidhuber_resize.png ADDED

Git LFS Details

  • SHA256: 51beaa72d1eb9f56118118fae8775bda818bcb56b03220f3cd39daa425f57a9a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.23 MB
examples/yann-lecun_resize.jpg ADDED
ip_adapter/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (9.05 kB)
ip_adapter/__pycache__/resampler.cpython-310.pyc ADDED
Binary file (3.18 kB)
ip_adapter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (365 Bytes)
ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,446 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ try:
7
+ import xformers
8
+ import xformers.ops
9
+ xformers_available = True
10
+ except Exception as e:
11
+ xformers_available = False
12
+
13
+ class RegionControler(object):
14
+ def __init__(self) -> None:
15
+ self.prompt_image_conditioning = []
16
+ region_control = RegionControler()
17
+
18
+ class AttnProcessor(nn.Module):
19
+ r"""
20
+ Default processor for performing attention-related computations.
21
+ """
22
+ def __init__(
23
+ self,
24
+ hidden_size=None,
25
+ cross_attention_dim=None,
26
+ ):
27
+ super().__init__()
28
+
29
+ def forward(
30
+ self,
31
+ attn,
32
+ hidden_states,
33
+ encoder_hidden_states=None,
34
+ attention_mask=None,
35
+ temb=None,
36
+ ):
37
+ residual = hidden_states
38
+
39
+ if attn.spatial_norm is not None:
40
+ hidden_states = attn.spatial_norm(hidden_states, temb)
41
+
42
+ input_ndim = hidden_states.ndim
43
+
44
+ if input_ndim == 4:
45
+ batch_size, channel, height, width = hidden_states.shape
46
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
47
+
48
+ batch_size, sequence_length, _ = (
49
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
50
+ )
51
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
52
+
53
+ if attn.group_norm is not None:
54
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
55
+
56
+ query = attn.to_q(hidden_states)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ key = attn.to_k(encoder_hidden_states)
64
+ value = attn.to_v(encoder_hidden_states)
65
+
66
+ query = attn.head_to_batch_dim(query)
67
+ key = attn.head_to_batch_dim(key)
68
+ value = attn.head_to_batch_dim(value)
69
+
70
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
71
+ hidden_states = torch.bmm(attention_probs, value)
72
+ hidden_states = attn.batch_to_head_dim(hidden_states)
73
+
74
+ # linear proj
75
+ hidden_states = attn.to_out[0](hidden_states)
76
+ # dropout
77
+ hidden_states = attn.to_out[1](hidden_states)
78
+
79
+ if input_ndim == 4:
80
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
81
+
82
+ if attn.residual_connection:
83
+ hidden_states = hidden_states + residual
84
+
85
+ hidden_states = hidden_states / attn.rescale_output_factor
86
+
87
+ return hidden_states
88
+
89
+
90
+ class IPAttnProcessor(nn.Module):
91
+ r"""
92
+ Attention processor for IP-Adapter.
93
+ Args:
94
+ hidden_size (`int`):
95
+ The hidden size of the attention layer.
96
+ cross_attention_dim (`int`):
97
+ The number of channels in the `encoder_hidden_states`.
98
+ scale (`float`, defaults to 1.0):
99
+ the weight scale of image prompt.
100
+ num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
101
+ The context length of the image features.
102
+ """
103
+
104
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
105
+ super().__init__()
106
+
107
+ self.hidden_size = hidden_size
108
+ self.cross_attention_dim = cross_attention_dim
109
+ self.scale = scale
110
+ self.num_tokens = num_tokens
111
+
112
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
113
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
114
+
115
+ def forward(
116
+ self,
117
+ attn,
118
+ hidden_states,
119
+ encoder_hidden_states=None,
120
+ attention_mask=None,
121
+ temb=None,
122
+ ):
123
+ residual = hidden_states
124
+
125
+ if attn.spatial_norm is not None:
126
+ hidden_states = attn.spatial_norm(hidden_states, temb)
127
+
128
+ input_ndim = hidden_states.ndim
129
+
130
+ if input_ndim == 4:
131
+ batch_size, channel, height, width = hidden_states.shape
132
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
133
+
134
+ batch_size, sequence_length, _ = (
135
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
136
+ )
137
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
138
+
139
+ if attn.group_norm is not None:
140
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
141
+
142
+ query = attn.to_q(hidden_states)
143
+
144
+ if encoder_hidden_states is None:
145
+ encoder_hidden_states = hidden_states
146
+ else:
147
+ # get encoder_hidden_states, ip_hidden_states
148
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
149
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
150
+ if attn.norm_cross:
151
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
152
+
153
+ key = attn.to_k(encoder_hidden_states)
154
+ value = attn.to_v(encoder_hidden_states)
155
+
156
+ query = attn.head_to_batch_dim(query)
157
+ key = attn.head_to_batch_dim(key)
158
+ value = attn.head_to_batch_dim(value)
159
+
160
+ if xformers_available:
161
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
162
+ else:
163
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
164
+ hidden_states = torch.bmm(attention_probs, value)
165
+ hidden_states = attn.batch_to_head_dim(hidden_states)
166
+
167
+ # for ip-adapter
168
+ ip_key = self.to_k_ip(ip_hidden_states)
169
+ ip_value = self.to_v_ip(ip_hidden_states)
170
+
171
+ ip_key = attn.head_to_batch_dim(ip_key)
172
+ ip_value = attn.head_to_batch_dim(ip_value)
173
+
174
+ if xformers_available:
175
+ ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
176
+ else:
177
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
178
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
179
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
180
+
181
+ # region control
182
+ if len(region_control.prompt_image_conditioning) == 1:
183
+ region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
184
+ if region_mask is not None:
185
+ h, w = region_mask.shape[:2]
186
+ ratio = (h * w / query.shape[1]) ** 0.5
187
+ mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
188
+ else:
189
+ mask = torch.ones_like(ip_hidden_states)
190
+ ip_hidden_states = ip_hidden_states * mask
191
+
192
+ hidden_states = hidden_states + self.scale * ip_hidden_states
193
+
194
+ # linear proj
195
+ hidden_states = attn.to_out[0](hidden_states)
196
+ # dropout
197
+ hidden_states = attn.to_out[1](hidden_states)
198
+
199
+ if input_ndim == 4:
200
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
201
+
202
+ if attn.residual_connection:
203
+ hidden_states = hidden_states + residual
204
+
205
+ hidden_states = hidden_states / attn.rescale_output_factor
206
+
207
+ return hidden_states
208
+
209
+
210
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
211
+ # TODO attention_mask
212
+ query = query.contiguous()
213
+ key = key.contiguous()
214
+ value = value.contiguous()
215
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
216
+ # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
217
+ return hidden_states
218
+
219
+
220
+ class AttnProcessor2_0(torch.nn.Module):
221
+ r"""
222
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
223
+ """
224
+ def __init__(
225
+ self,
226
+ hidden_size=None,
227
+ cross_attention_dim=None,
228
+ ):
229
+ super().__init__()
230
+ if not hasattr(F, "scaled_dot_product_attention"):
231
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
232
+
233
+ def forward(
234
+ self,
235
+ attn,
236
+ hidden_states,
237
+ encoder_hidden_states=None,
238
+ attention_mask=None,
239
+ temb=None,
240
+ ):
241
+ residual = hidden_states
242
+
243
+ if attn.spatial_norm is not None:
244
+ hidden_states = attn.spatial_norm(hidden_states, temb)
245
+
246
+ input_ndim = hidden_states.ndim
247
+
248
+ if input_ndim == 4:
249
+ batch_size, channel, height, width = hidden_states.shape
250
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
251
+
252
+ batch_size, sequence_length, _ = (
253
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
254
+ )
255
+
256
+ if attention_mask is not None:
257
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
258
+ # scaled_dot_product_attention expects attention_mask shape to be
259
+ # (batch, heads, source_length, target_length)
260
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
261
+
262
+ if attn.group_norm is not None:
263
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
264
+
265
+ query = attn.to_q(hidden_states)
266
+
267
+ if encoder_hidden_states is None:
268
+ encoder_hidden_states = hidden_states
269
+ elif attn.norm_cross:
270
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
271
+
272
+ key = attn.to_k(encoder_hidden_states)
273
+ value = attn.to_v(encoder_hidden_states)
274
+
275
+ inner_dim = key.shape[-1]
276
+ head_dim = inner_dim // attn.heads
277
+
278
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
279
+
280
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
281
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
282
+
283
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
284
+ # TODO: add support for attn.scale when we move to Torch 2.1
285
+ hidden_states = F.scaled_dot_product_attention(
286
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
287
+ )
288
+
289
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
290
+ hidden_states = hidden_states.to(query.dtype)
291
+
292
+ # linear proj
293
+ hidden_states = attn.to_out[0](hidden_states)
294
+ # dropout
295
+ hidden_states = attn.to_out[1](hidden_states)
296
+
297
+ if input_ndim == 4:
298
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
299
+
300
+ if attn.residual_connection:
301
+ hidden_states = hidden_states + residual
302
+
303
+ hidden_states = hidden_states / attn.rescale_output_factor
304
+
305
+ return hidden_states
306
+
307
+ class IPAttnProcessor2_0(torch.nn.Module):
308
+ r"""
309
+ Attention processor for IP-Adapter for PyTorch 2.0.
310
+ Args:
311
+ hidden_size (`int`):
312
+ The hidden size of the attention layer.
313
+ cross_attention_dim (`int`):
314
+ The number of channels in the `encoder_hidden_states`.
315
+ scale (`float`, defaults to 1.0):
316
+ The weight scale of the image prompt.
317
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
318
+ The context length of the image features.
319
+ """
320
+
321
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
322
+ super().__init__()
323
+
324
+ if not hasattr(F, "scaled_dot_product_attention"):
325
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
326
+
327
+ self.hidden_size = hidden_size
328
+ self.cross_attention_dim = cross_attention_dim
329
+ self.scale = scale
330
+ self.num_tokens = num_tokens
331
+
332
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
333
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
334
+
335
+ def forward(
336
+ self,
337
+ attn,
338
+ hidden_states,
339
+ encoder_hidden_states=None,
340
+ attention_mask=None,
341
+ temb=None,
342
+ ):
343
+ residual = hidden_states
344
+
345
+ if attn.spatial_norm is not None:
346
+ hidden_states = attn.spatial_norm(hidden_states, temb)
347
+
348
+ input_ndim = hidden_states.ndim
349
+
350
+ if input_ndim == 4:
351
+ batch_size, channel, height, width = hidden_states.shape
352
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
353
+
354
+ batch_size, sequence_length, _ = (
355
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
356
+ )
357
+
358
+ if attention_mask is not None:
359
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
360
+ # scaled_dot_product_attention expects attention_mask shape to be
361
+ # (batch, heads, source_length, target_length)
362
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
363
+
364
+ if attn.group_norm is not None:
365
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
366
+
367
+ query = attn.to_q(hidden_states)
368
+
369
+ if encoder_hidden_states is None:
370
+ encoder_hidden_states = hidden_states
371
+ else:
372
+ # get encoder_hidden_states, ip_hidden_states
373
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
374
+ encoder_hidden_states, ip_hidden_states = (
375
+ encoder_hidden_states[:, :end_pos, :],
376
+ encoder_hidden_states[:, end_pos:, :],
377
+ )
378
+ if attn.norm_cross:
379
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
380
+
381
+ key = attn.to_k(encoder_hidden_states)
382
+ value = attn.to_v(encoder_hidden_states)
383
+
384
+ inner_dim = key.shape[-1]
385
+ head_dim = inner_dim // attn.heads
386
+
387
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
388
+
389
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
390
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
391
+
392
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
393
+ # TODO: add support for attn.scale when we move to Torch 2.1
394
+ hidden_states = F.scaled_dot_product_attention(
395
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
396
+ )
397
+
398
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
399
+ hidden_states = hidden_states.to(query.dtype)
400
+
401
+ # for ip-adapter
402
+ ip_key = self.to_k_ip(ip_hidden_states)
403
+ ip_value = self.to_v_ip(ip_hidden_states)
404
+
405
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
406
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
407
+
408
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
409
+ # TODO: add support for attn.scale when we move to Torch 2.1
410
+ ip_hidden_states = F.scaled_dot_product_attention(
411
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
412
+ )
413
+ with torch.no_grad():
414
+ self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
415
+ #print(self.attn_map.shape)
416
+
417
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
418
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
419
+
420
+ # region control
421
+ if len(region_control.prompt_image_conditioning) == 1:
422
+ region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
423
+ if region_mask is not None:
424
+ h, w = region_mask.shape[:2]
425
+ ratio = (h * w / query.shape[1]) ** 0.5
426
+ mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
427
+ else:
428
+ mask = torch.ones_like(ip_hidden_states)
429
+ ip_hidden_states = ip_hidden_states * mask
430
+
431
+ hidden_states = hidden_states + self.scale * ip_hidden_states
432
+
433
+ # linear proj
434
+ hidden_states = attn.to_out[0](hidden_states)
435
+ # dropout
436
+ hidden_states = attn.to_out[1](hidden_states)
437
+
438
+ if input_ndim == 4:
439
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
440
+
441
+ if attn.residual_connection:
442
+ hidden_states = hidden_states + residual
443
+
444
+ hidden_states = hidden_states / attn.rescale_output_factor
445
+
446
+ return hidden_states
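
The processors above implement decoupled cross-attention: the text tokens go through the UNet's own `to_k`/`to_v`, the trailing image-prompt tokens go through the extra `to_k_ip`/`to_v_ip`, and the two attention outputs are summed with the `scale` weight. Below is a minimal sketch of that combination on dummy tensors (projections and multi-head reshaping are omitted, and all shapes are made up for illustration):

```py
import torch
import torch.nn.functional as F

batch, query_len, text_len, num_tokens, dim, scale = 2, 4096, 77, 16, 64, 0.8

query = torch.randn(batch, query_len, dim)                        # UNet hidden states (queries)
encoder_hidden_states = torch.randn(batch, text_len + num_tokens, dim)

# split off the trailing image-prompt tokens, as IPAttnProcessor does
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_tokens = encoder_hidden_states[:, :end_pos, :]
ip_tokens = encoder_hidden_states[:, end_pos:, :]

# two attention passes sharing the same query, then a weighted sum
text_out = F.scaled_dot_product_attention(query, text_tokens, text_tokens)
ip_out = F.scaled_dot_product_attention(query, ip_tokens, ip_tokens)
hidden_states = text_out + scale * ip_out
print(hidden_states.shape)  # torch.Size([2, 4096, 64])
```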
ip_adapter/resampler.py ADDED
@@ -0,0 +1,121 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ # FFN
9
+ def FeedForward(dim, mult=4):
10
+ inner_dim = int(dim * mult)
11
+ return nn.Sequential(
12
+ nn.LayerNorm(dim),
13
+ nn.Linear(dim, inner_dim, bias=False),
14
+ nn.GELU(),
15
+ nn.Linear(inner_dim, dim, bias=False),
16
+ )
17
+
18
+
19
+ def reshape_tensor(x, heads):
20
+ bs, length, width = x.shape
21
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
22
+ x = x.view(bs, length, heads, -1)
23
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
24
+ x = x.transpose(1, 2)
25
+ # shape stays (bs, n_heads, length, dim_per_head)
26
+ x = x.reshape(bs, heads, length, -1)
27
+ return x
28
+
29
+
30
+ class PerceiverAttention(nn.Module):
31
+ def __init__(self, *, dim, dim_head=64, heads=8):
32
+ super().__init__()
33
+ self.scale = dim_head**-0.5
34
+ self.dim_head = dim_head
35
+ self.heads = heads
36
+ inner_dim = dim_head * heads
37
+
38
+ self.norm1 = nn.LayerNorm(dim)
39
+ self.norm2 = nn.LayerNorm(dim)
40
+
41
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
42
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
43
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
44
+
45
+
46
+ def forward(self, x, latents):
47
+ """
48
+ Args:
49
+ x (torch.Tensor): image features
50
+ shape (b, n1, D)
51
+ latent (torch.Tensor): latent features
52
+ shape (b, n2, D)
53
+ """
54
+ x = self.norm1(x)
55
+ latents = self.norm2(latents)
56
+
57
+ b, l, _ = latents.shape
58
+
59
+ q = self.to_q(latents)
60
+ kv_input = torch.cat((x, latents), dim=-2)
61
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
62
+
63
+ q = reshape_tensor(q, self.heads)
64
+ k = reshape_tensor(k, self.heads)
65
+ v = reshape_tensor(v, self.heads)
66
+
67
+ # attention
68
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
69
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
70
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71
+ out = weight @ v
72
+
73
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
74
+
75
+ return self.to_out(out)
76
+
77
+
78
+ class Resampler(nn.Module):
79
+ def __init__(
80
+ self,
81
+ dim=1024,
82
+ depth=8,
83
+ dim_head=64,
84
+ heads=16,
85
+ num_queries=8,
86
+ embedding_dim=768,
87
+ output_dim=1024,
88
+ ff_mult=4,
89
+ ):
90
+ super().__init__()
91
+
92
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
93
+
94
+ self.proj_in = nn.Linear(embedding_dim, dim)
95
+
96
+ self.proj_out = nn.Linear(dim, output_dim)
97
+ self.norm_out = nn.LayerNorm(output_dim)
98
+
99
+ self.layers = nn.ModuleList([])
100
+ for _ in range(depth):
101
+ self.layers.append(
102
+ nn.ModuleList(
103
+ [
104
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
105
+ FeedForward(dim=dim, mult=ff_mult),
106
+ ]
107
+ )
108
+ )
109
+
110
+ def forward(self, x):
111
+
112
+ latents = self.latents.repeat(x.size(0), 1, 1)
113
+
114
+ x = self.proj_in(x)
115
+
116
+ for attn, ff in self.layers:
117
+ latents = attn(x, latents) + latents
118
+ latents = ff(latents) + latents
119
+
120
+ latents = self.proj_out(latents)
121
+ return self.norm_out(latents)
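
The Resampler maps a fixed-size image embedding onto a small set of learned query tokens via Perceiver-style cross-attention. A minimal sketch with the dimensions this repo uses later in `set_image_proj_model` (512-d face embedding, 16 tokens, SDXL cross-attention dim 2048); the weights here are randomly initialized, so the output only demonstrates shapes:

```py
import torch
from ip_adapter.resampler import Resampler

image_proj_model = Resampler(
    dim=1280, depth=4, dim_head=64, heads=20,
    num_queries=16, embedding_dim=512, output_dim=2048, ff_mult=4,
)

face_emb = torch.randn(1, 1, 512)   # one face embedding, reshaped to (batch, seq, dim)
tokens = image_proj_model(face_emb)
print(tokens.shape)                 # torch.Size([1, 16, 2048])
```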
ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
1
+ import torch.nn.functional as F
2
+
3
+
4
+ def is_torch2_available():
5
+ return hasattr(F, "scaled_dot_product_attention")
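
This helper only checks for `torch.nn.functional.scaled_dot_product_attention`. A typical (hypothetical) use is to select the PyTorch 2.0 processor variants defined above when they are available:

```py
from ip_adapter.utils import is_torch2_available

if is_torch2_available():
    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
```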
models/antelopev2/1k3d68.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
3
+ size 143607619
models/antelopev2/2d106det.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
3
+ size 5030888
models/antelopev2/genderage.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
3
+ size 1322532
models/antelopev2/glintr100.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf
3
+ size 260665334
models/antelopev2/scrfd_10g_bnkps.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
3
+ size 16923827
pipeline_stable_diffusion_xl_instantid_full.py ADDED
@@ -0,0 +1,1204 @@
1
+ # Copyright 2024 The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import cv2
19
+ import math
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+
26
+ from diffusers.image_processor import PipelineImageInput
27
+
28
+ from diffusers.models import ControlNetModel
29
+
30
+ from diffusers.utils import (
31
+ deprecate,
32
+ logging,
33
+ replace_example_docstring,
34
+ )
35
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
36
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
37
+
38
+ from diffusers import StableDiffusionXLControlNetPipeline
39
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+
42
+ from ip_adapter.resampler import Resampler
43
+ from ip_adapter.utils import is_torch2_available
44
+
45
+ from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
46
+ from ip_adapter.attention_processor import region_control
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ EXAMPLE_DOC_STRING = """
52
+ Examples:
53
+ ```py
54
+ >>> # !pip install opencv-python transformers accelerate insightface
55
+ >>> import diffusers
56
+ >>> from diffusers.utils import load_image
57
+ >>> from diffusers.models import ControlNetModel
58
+
59
+ >>> import cv2
60
+ >>> import torch
61
+ >>> import numpy as np
62
+ >>> from PIL import Image
63
+
64
+ >>> from insightface.app import FaceAnalysis
65
+ >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
66
+
67
+ >>> # download 'antelopev2' under ./models
68
+ >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
69
+ >>> app.prepare(ctx_id=0, det_size=(640, 640))
70
+
71
+ >>> # download models under ./checkpoints
72
+ >>> face_adapter = f'./checkpoints/ip-adapter.bin'
73
+ >>> controlnet_path = f'./checkpoints/ControlNetModel'
74
+
75
+ >>> # load IdentityNet
76
+ >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
77
+
78
+ >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
79
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
80
+ ... )
81
+ >>> pipe.cuda()
82
+
83
+ >>> # load adapter
84
+ >>> pipe.load_ip_adapter_instantid(face_adapter)
85
+
86
+ >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
87
+ >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
88
+
89
+ >>> # load an image
90
+ >>> face_image = load_image("your-example.jpg")
91
+
92
+ >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
93
+ >>> face_emb = face_info['embedding']
94
+ >>> face_kps = draw_kps(face_image, face_info['kps'])
95
+
96
+ >>> pipe.set_ip_adapter_scale(0.8)
97
+
98
+ >>> # generate image
99
+ >>> image = pipe(
100
+ ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
101
+ ... ).images[0]
102
+ ```
103
+ """
104
+
105
+ from transformers import CLIPTokenizer
106
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline
107
+ class LongPromptWeight(object):
108
+
109
+ """
110
+ Copied from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py
111
+ """
112
+
113
+ def __init__(self) -> None:
114
+ pass
115
+
116
+ def parse_prompt_attention(self, text):
117
+ """
118
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
119
+ Accepted tokens are:
120
+ (abc) - increases attention to abc by a multiplier of 1.1
121
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
122
+ [abc] - decreases attention to abc by a multiplier of 1.1
123
+ \( - literal character '('
124
+ \[ - literal character '['
125
+ \) - literal character ')'
126
+ \] - literal character ']'
127
+ \\ - literal character '\'
128
+ anything else - just text
129
+
130
+ >>> parse_prompt_attention('normal text')
131
+ [['normal text', 1.0]]
132
+ >>> parse_prompt_attention('an (important) word')
133
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
134
+ >>> parse_prompt_attention('(unbalanced')
135
+ [['unbalanced', 1.1]]
136
+ >>> parse_prompt_attention('\(literal\]')
137
+ [['(literal]', 1.0]]
138
+ >>> parse_prompt_attention('(unnecessary)(parens)')
139
+ [['unnecessaryparens', 1.1]]
140
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
141
+ [['a ', 1.0],
142
+ ['house', 1.5730000000000004],
143
+ [' ', 1.1],
144
+ ['on', 1.0],
145
+ [' a ', 1.1],
146
+ ['hill', 0.55],
147
+ [', sun, ', 1.1],
148
+ ['sky', 1.4641000000000006],
149
+ ['.', 1.1]]
150
+ """
151
+ import re
152
+
153
+ re_attention = re.compile(
154
+ r"""
155
+ \\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|
156
+ \)|]|[^\\()\[\]:]+|:
157
+ """,
158
+ re.X,
159
+ )
160
+
161
+ re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)
162
+
163
+ res = []
164
+ round_brackets = []
165
+ square_brackets = []
166
+
167
+ round_bracket_multiplier = 1.1
168
+ square_bracket_multiplier = 1 / 1.1
169
+
170
+ def multiply_range(start_position, multiplier):
171
+ for p in range(start_position, len(res)):
172
+ res[p][1] *= multiplier
173
+
174
+ for m in re_attention.finditer(text):
175
+ text = m.group(0)
176
+ weight = m.group(1)
177
+
178
+ if text.startswith("\\"):
179
+ res.append([text[1:], 1.0])
180
+ elif text == "(":
181
+ round_brackets.append(len(res))
182
+ elif text == "[":
183
+ square_brackets.append(len(res))
184
+ elif weight is not None and len(round_brackets) > 0:
185
+ multiply_range(round_brackets.pop(), float(weight))
186
+ elif text == ")" and len(round_brackets) > 0:
187
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
188
+ elif text == "]" and len(square_brackets) > 0:
189
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
190
+ else:
191
+ parts = re.split(re_break, text)
192
+ for i, part in enumerate(parts):
193
+ if i > 0:
194
+ res.append(["BREAK", -1])
195
+ res.append([part, 1.0])
196
+
197
+ for pos in round_brackets:
198
+ multiply_range(pos, round_bracket_multiplier)
199
+
200
+ for pos in square_brackets:
201
+ multiply_range(pos, square_bracket_multiplier)
202
+
203
+ if len(res) == 0:
204
+ res = [["", 1.0]]
205
+
206
+ # merge runs of identical weights
207
+ i = 0
208
+ while i + 1 < len(res):
209
+ if res[i][1] == res[i + 1][1]:
210
+ res[i][0] += res[i + 1][0]
211
+ res.pop(i + 1)
212
+ else:
213
+ i += 1
214
+
215
+ return res
216
+
217
+ def get_prompts_tokens_with_weights(self, clip_tokenizer: CLIPTokenizer, prompt: str):
218
+ """
219
+ Get prompt token ids and weights, this function works for both prompt and negative prompt
220
+
221
+ Args:
222
+ clip_tokenizer (CLIPTokenizer)
223
+ A CLIPTokenizer
224
+ prompt (str)
225
+ A prompt string with weights
226
+
227
+ Returns:
228
+ text_tokens (list)
229
+ A list contains token ids
230
+ text_weight (list)
231
+ A list contains the correspodent weight of token ids
232
+
233
+ Example:
234
+ import torch
235
+ from transformers import CLIPTokenizer
236
+
237
+ clip_tokenizer = CLIPTokenizer.from_pretrained(
238
+ "stablediffusionapi/deliberate-v2"
239
+ , subfolder = "tokenizer"
240
+ , dtype = torch.float16
241
+ )
242
+
243
+ token_id_list, token_weight_list = get_prompts_tokens_with_weights(
244
+ clip_tokenizer = clip_tokenizer
245
+ ,prompt = "a (red:1.5) cat"*70
246
+ )
247
+ """
248
+ texts_and_weights = self.parse_prompt_attention(prompt)
249
+ text_tokens, text_weights = [], []
250
+ for word, weight in texts_and_weights:
251
+ # tokenize and discard the starting and the ending token
252
+ token = clip_tokenizer(word, truncation=False).input_ids[1:-1]  # no truncation, so prompts of any length can be tokenized
253
+ # the returned token is a 1d list: [320, 1125, 539, 320]
254
+
255
+ # merge the new tokens to the all tokens holder: text_tokens
256
+ text_tokens = [*text_tokens, *token]
257
+
258
+ # each token chunk will come with one weight, like ['red cat', 2.0]
259
+ # need to expand weight for each token.
260
+ chunk_weights = [weight] * len(token)
261
+
262
+ # append the weight back to the weight holder: text_weights
263
+ text_weights = [*text_weights, *chunk_weights]
264
+ return text_tokens, text_weights
265
+
266
+ def group_tokens_and_weights(self, token_ids: list, weights: list, pad_last_block=False):
267
+ """
268
+ Produce tokens and weights in groups and pad the missing tokens
269
+
270
+ Args:
271
+ token_ids (list)
272
+ The token ids from tokenizer
273
+ weights (list)
274
+ The weights list from function get_prompts_tokens_with_weights
275
+ pad_last_block (bool)
276
+ Controls whether to pad the last token block to 75 tokens with eos
277
+ Returns:
278
+ new_token_ids (2d list)
279
+ new_weights (2d list)
280
+
281
+ Example:
282
+ token_groups,weight_groups = group_tokens_and_weights(
283
+ token_ids = token_id_list
284
+ , weights = token_weight_list
285
+ )
286
+ """
287
+ bos, eos = 49406, 49407
288
+
289
+ # this will be a 2d list
290
+ new_token_ids = []
291
+ new_weights = []
292
+ while len(token_ids) >= 75:
293
+ # get the first 75 tokens
294
+ head_75_tokens = [token_ids.pop(0) for _ in range(75)]
295
+ head_75_weights = [weights.pop(0) for _ in range(75)]
296
+
297
+ # extract token ids and weights
298
+ temp_77_token_ids = [bos] + head_75_tokens + [eos]
299
+ temp_77_weights = [1.0] + head_75_weights + [1.0]
300
+
301
+ # add 77 token and weights chunk to the holder list
302
+ new_token_ids.append(temp_77_token_ids)
303
+ new_weights.append(temp_77_weights)
304
+
305
+ # padding the left
306
+ if len(token_ids) >= 0:
307
+ padding_len = 75 - len(token_ids) if pad_last_block else 0
308
+
309
+ temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
310
+ new_token_ids.append(temp_77_token_ids)
311
+
312
+ temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
313
+ new_weights.append(temp_77_weights)
314
+
315
+ return new_token_ids, new_weights
316
+
317
+ def get_weighted_text_embeddings_sdxl(
318
+ self,
319
+ pipe: StableDiffusionXLPipeline,
320
+ prompt: str = "",
321
+ prompt_2: str = None,
322
+ neg_prompt: str = "",
323
+ neg_prompt_2: str = None,
324
+ prompt_embeds=None,
325
+ negative_prompt_embeds=None,
326
+ pooled_prompt_embeds=None,
327
+ negative_pooled_prompt_embeds=None,
328
+ extra_emb=None,
329
+ extra_emb_alpha=0.6,
330
+ ):
331
+ """
332
+ This function can process long weighted prompts with no length limitation
333
+ for Stable Diffusion XL
334
+
335
+ Args:
336
+ pipe (StableDiffusionPipeline)
337
+ prompt (str)
338
+ prompt_2 (str)
339
+ neg_prompt (str)
340
+ neg_prompt_2 (str)
341
+ Returns:
342
+ prompt_embeds (torch.Tensor)
343
+ neg_prompt_embeds (torch.Tensor)
344
+ """
345
+ #
346
+ if prompt_embeds is not None and \
347
+ negative_prompt_embeds is not None and \
348
+ pooled_prompt_embeds is not None and \
349
+ negative_pooled_prompt_embeds is not None:
350
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
351
+
352
+ if prompt_2:
353
+ prompt = f"{prompt} {prompt_2}"
354
+
355
+ if neg_prompt_2:
356
+ neg_prompt = f"{neg_prompt} {neg_prompt_2}"
357
+
358
+ eos = pipe.tokenizer.eos_token_id
359
+
360
+ # tokenizer 1
361
+ prompt_tokens, prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
362
+ neg_prompt_tokens, neg_prompt_weights = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
363
+
364
+ # tokenizer 2
365
+ # prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt)
366
+ # neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt)
367
+ # tokenizer 2 handles runs of exclamation marks (!!, !!!!, etc.) differently from tokenizer 1, so tokenizer 1 is reused here
368
+ prompt_tokens_2, prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, prompt)
369
+ neg_prompt_tokens_2, neg_prompt_weights_2 = self.get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt)
370
+
371
+ # padding the shorter one for prompt set 1
372
+ prompt_token_len = len(prompt_tokens)
373
+ neg_prompt_token_len = len(neg_prompt_tokens)
374
+
375
+ if prompt_token_len > neg_prompt_token_len:
376
+ # padding the neg_prompt with eos token
377
+ neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
378
+ neg_prompt_weights = neg_prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
379
+ else:
380
+ # padding the prompt
381
+ prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - neg_prompt_token_len)
382
+ prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
383
+
384
+ # padding the shorter one for token set 2
385
+ prompt_token_len_2 = len(prompt_tokens_2)
386
+ neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
387
+
388
+ if prompt_token_len_2 > neg_prompt_token_len_2:
389
+ # padding the neg_prompt with eos token
390
+ neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
391
+ neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
392
+ else:
393
+ # padding the prompt
394
+ prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
395
+ prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
396
+
397
+ embeds = []
398
+ neg_embeds = []
399
+
400
+ prompt_token_groups, prompt_weight_groups = self.group_tokens_and_weights(prompt_tokens.copy(), prompt_weights.copy())
401
+
402
+ neg_prompt_token_groups, neg_prompt_weight_groups = self.group_tokens_and_weights(
403
+ neg_prompt_tokens.copy(), neg_prompt_weights.copy()
404
+ )
405
+
406
+ prompt_token_groups_2, prompt_weight_groups_2 = self.group_tokens_and_weights(
407
+ prompt_tokens_2.copy(), prompt_weights_2.copy()
408
+ )
409
+
410
+ neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = self.group_tokens_and_weights(
411
+ neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
412
+ )
413
+
414
+ # process the prompt chunk by chunk; getting prompt embeddings one token at a time does not work
415
+ for i in range(len(prompt_token_groups)):
416
+ # get positive prompt embeddings with weights
417
+ token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
418
+ weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
419
+
420
+ token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
421
+
422
+ # use first text encoder
423
+ prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
424
+ prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
425
+
426
+ # use second text encoder
427
+ prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
428
+ prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
429
+ pooled_prompt_embeds = prompt_embeds_2[0]
430
+
431
+ prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
432
+ token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
433
+
434
+ for j in range(len(weight_tensor)):
435
+ if weight_tensor[j] != 1.0:
436
+ token_embedding[j] = (
437
+ token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
438
+ )
439
+
440
+ token_embedding = token_embedding.unsqueeze(0)
441
+ embeds.append(token_embedding)
442
+
443
+ # get negative prompt embeddings with weights
444
+ neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
445
+ neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
446
+ neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
447
+
448
+ # use first text encoder
449
+ neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
450
+ neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
451
+
452
+ # use second text encoder
453
+ neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
454
+ neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
455
+ negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
456
+
457
+ neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
458
+ neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
459
+
460
+ for z in range(len(neg_weight_tensor)):
461
+ if neg_weight_tensor[z] != 1.0:
462
+ neg_token_embedding[z] = (
463
+ neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
464
+ )
465
+
466
+ neg_token_embedding = neg_token_embedding.unsqueeze(0)
467
+ neg_embeds.append(neg_token_embedding)
468
+
469
+ prompt_embeds = torch.cat(embeds, dim=1)
470
+ negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
471
+
472
+ if extra_emb is not None:
473
+ extra_emb = extra_emb.to(prompt_embeds.device, dtype=prompt_embeds.dtype) * extra_emb_alpha
474
+ prompt_embeds = torch.cat([prompt_embeds, extra_emb], 1)
475
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, torch.zeros_like(extra_emb)], 1)
476
+ print(f'fix prompt_embeds, extra_emb_alpha={extra_emb_alpha}')
477
+
478
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
479
+
480
+ def get_prompt_embeds(self, *args, **kwargs):
481
+ prompt_embeds, negative_prompt_embeds, _, _ = self.get_weighted_text_embeddings_sdxl(*args, **kwargs)
482
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
483
+ return prompt_embeds
484
+
485
+ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
486
+
487
+ stickwidth = 4
488
+ limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
489
+ kps = np.array(kps)
490
+
491
+ w, h = image_pil.size
492
+ out_img = np.zeros([h, w, 3])
493
+
494
+ for i in range(len(limbSeq)):
495
+ index = limbSeq[i]
496
+ color = color_list[index[0]]
497
+
498
+ x = kps[index][:, 0]
499
+ y = kps[index][:, 1]
500
+ length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
501
+ angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
502
+ polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
503
+ out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
504
+ out_img = (out_img * 0.6).astype(np.uint8)
505
+
506
+ for idx_kp, kp in enumerate(kps):
507
+ color = color_list[idx_kp]
508
+ x, y = kp
509
+ out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
510
+
511
+ out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
512
+ return out_img_pil
513
+
514
+ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
515
+
516
+ def cuda(self, dtype=torch.float16, use_xformers=False):
517
+ self.to('cuda', dtype)
518
+
519
+ if hasattr(self, 'image_proj_model'):
520
+ self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
521
+
522
+ if use_xformers:
523
+ if is_xformers_available():
524
+ import xformers
525
+ from packaging import version
526
+
527
+ xformers_version = version.parse(xformers.__version__)
528
+ if xformers_version == version.parse("0.0.16"):
529
+ logger.warn(
530
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
531
+ )
532
+ self.enable_xformers_memory_efficient_attention()
533
+ else:
534
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
535
+
536
+ def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
537
+ self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
538
+ self.set_ip_adapter(model_ckpt, num_tokens, scale)
539
+
540
+ def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
541
+
542
+ image_proj_model = Resampler(
543
+ dim=1280,
544
+ depth=4,
545
+ dim_head=64,
546
+ heads=20,
547
+ num_queries=num_tokens,
548
+ embedding_dim=image_emb_dim,
549
+ output_dim=self.unet.config.cross_attention_dim,
550
+ ff_mult=4,
551
+ )
552
+
553
+ image_proj_model.eval()
554
+
555
+ self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
556
+ state_dict = torch.load(model_ckpt, map_location="cpu")
557
+ if 'image_proj' in state_dict:
558
+ state_dict = state_dict["image_proj"]
559
+ self.image_proj_model.load_state_dict(state_dict)
560
+
561
+ self.image_proj_model_in_features = image_emb_dim
562
+
563
+ def set_ip_adapter(self, model_ckpt, num_tokens, scale):
564
+
565
+ unet = self.unet
566
+ attn_procs = {}
567
+ for name in unet.attn_processors.keys():
568
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
569
+ if name.startswith("mid_block"):
570
+ hidden_size = unet.config.block_out_channels[-1]
571
+ elif name.startswith("up_blocks"):
572
+ block_id = int(name[len("up_blocks.")])
573
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
574
+ elif name.startswith("down_blocks"):
575
+ block_id = int(name[len("down_blocks.")])
576
+ hidden_size = unet.config.block_out_channels[block_id]
577
+ if cross_attention_dim is None:
578
+ attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
579
+ else:
580
+ attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
581
+ cross_attention_dim=cross_attention_dim,
582
+ scale=scale,
583
+ num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
584
+ unet.set_attn_processor(attn_procs)
585
+
586
+ state_dict = torch.load(model_ckpt, map_location="cpu")
587
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
588
+ if 'ip_adapter' in state_dict:
589
+ state_dict = state_dict['ip_adapter']
590
+ ip_layers.load_state_dict(state_dict)
591
+
592
+ def set_ip_adapter_scale(self, scale):
593
+ unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
594
+ for attn_processor in unet.attn_processors.values():
595
+ if isinstance(attn_processor, IPAttnProcessor):
596
+ attn_processor.scale = scale
597
+
598
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
599
+
600
+ if isinstance(prompt_image_emb, torch.Tensor):
601
+ prompt_image_emb = prompt_image_emb.clone().detach()
602
+ else:
603
+ prompt_image_emb = torch.tensor(prompt_image_emb)
604
+
605
+ prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
606
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
607
+
608
+ if do_classifier_free_guidance:
609
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
610
+ else:
611
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
612
+
613
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
614
+
615
+ bs_embed, seq_len, _ = prompt_image_emb.shape
616
+ prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
617
+ prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
618
+
619
+ return prompt_image_emb
620
+
621
+ @torch.no_grad()
622
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
623
+ def __call__(
624
+ self,
625
+ prompt: Union[str, List[str]] = None,
626
+ prompt_2: Optional[Union[str, List[str]]] = None,
627
+ image: PipelineImageInput = None,
628
+ height: Optional[int] = None,
629
+ width: Optional[int] = None,
630
+ num_inference_steps: int = 50,
631
+ guidance_scale: float = 5.0,
632
+ negative_prompt: Optional[Union[str, List[str]]] = None,
633
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
634
+ num_images_per_prompt: Optional[int] = 1,
635
+ eta: float = 0.0,
636
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
637
+ latents: Optional[torch.FloatTensor] = None,
638
+ prompt_embeds: Optional[torch.FloatTensor] = None,
639
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
640
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
641
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
642
+ image_embeds: Optional[torch.FloatTensor] = None,
643
+ output_type: Optional[str] = "pil",
644
+ return_dict: bool = True,
645
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
646
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
647
+ guess_mode: bool = False,
648
+ control_guidance_start: Union[float, List[float]] = 0.0,
649
+ control_guidance_end: Union[float, List[float]] = 1.0,
650
+ original_size: Tuple[int, int] = None,
651
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
652
+ target_size: Tuple[int, int] = None,
653
+ negative_original_size: Optional[Tuple[int, int]] = None,
654
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
655
+ negative_target_size: Optional[Tuple[int, int]] = None,
656
+ clip_skip: Optional[int] = None,
657
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
658
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
659
+
660
+ # IP adapter
661
+ ip_adapter_scale=None,
662
+
663
+ # Enhance Face Region
664
+ control_mask = None,
665
+
666
+ **kwargs,
667
+ ):
668
+ r"""
669
+ The call function to the pipeline for generation.
670
+
671
+ Args:
672
+ prompt (`str` or `List[str]`, *optional*):
673
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
674
+ prompt_2 (`str` or `List[str]`, *optional*):
675
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
676
+ used in both text-encoders.
677
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
678
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
679
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
680
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
681
+ accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
682
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
683
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
684
+ input to a single ControlNet.
685
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
686
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
687
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
688
+ and checkpoints that are not specifically fine-tuned on low resolutions.
689
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
690
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
691
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
692
+ and checkpoints that are not specifically fine-tuned on low resolutions.
693
+ num_inference_steps (`int`, *optional*, defaults to 50):
694
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
695
+ expense of slower inference.
696
+ guidance_scale (`float`, *optional*, defaults to 5.0):
697
+ A higher guidance scale value encourages the model to generate images closely linked to the text
698
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
699
+ negative_prompt (`str` or `List[str]`, *optional*):
700
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
701
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
702
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
703
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
704
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
705
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
706
+ The number of images to generate per prompt.
707
+ eta (`float`, *optional*, defaults to 0.0):
708
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
709
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
710
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
711
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
712
+ generation deterministic.
713
+ latents (`torch.FloatTensor`, *optional*):
714
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
715
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
716
+ tensor is generated by sampling using the supplied random `generator`.
717
+ prompt_embeds (`torch.FloatTensor`, *optional*):
718
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
719
+ provided, text embeddings are generated from the `prompt` input argument.
720
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
721
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
722
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
723
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
724
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
725
+ not provided, pooled text embeddings are generated from `prompt` input argument.
726
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
727
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
728
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
729
+ argument.
730
+ image_embeds (`torch.FloatTensor`, *optional*):
731
+ Pre-generated image embeddings.
732
+ output_type (`str`, *optional*, defaults to `"pil"`):
733
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
734
+ return_dict (`bool`, *optional*, defaults to `True`):
735
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
736
+ plain tuple.
737
+ cross_attention_kwargs (`dict`, *optional*):
738
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
739
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
740
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
741
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
742
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
743
+ the corresponding scale as a list.
744
+ guess_mode (`bool`, *optional*, defaults to `False`):
745
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
746
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
747
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
748
+ The percentage of total steps at which the ControlNet starts applying.
749
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
750
+ The percentage of total steps at which the ControlNet stops applying.
751
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
752
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
753
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
754
+ explained in section 2.2 of
755
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
756
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
757
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
758
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
759
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
760
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
761
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
762
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
763
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
764
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
765
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
766
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
767
+ micro-conditioning as explained in section 2.2 of
768
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
769
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
770
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
771
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
772
+ micro-conditioning as explained in section 2.2 of
773
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
774
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
775
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
776
+ To negatively condition the generation process based on a target image resolution. It should be the same
777
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
778
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
779
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
780
+ clip_skip (`int`, *optional*):
781
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
782
+ the output of the pre-final layer will be used for computing the prompt embeddings.
783
+ callback_on_step_end (`Callable`, *optional*):
784
+ A function that is called at the end of each denoising step during inference. The function is called
785
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
786
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
787
+ `callback_on_step_end_tensor_inputs`.
788
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
789
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
790
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
791
+ `._callback_tensor_inputs` attribute of your pipeline class.
792
+
793
+ Examples:
794
+
795
+ Returns:
796
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
797
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
798
+ otherwise a `tuple` is returned containing the output images.
799
+ """
800
+
801
+ lpw = LongPromptWeight()
802
+
803
+ callback = kwargs.pop("callback", None)
804
+ callback_steps = kwargs.pop("callback_steps", None)
805
+
806
+ if callback is not None:
807
+ deprecate(
808
+ "callback",
809
+ "1.0.0",
810
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
811
+ )
812
+ if callback_steps is not None:
813
+ deprecate(
814
+ "callback_steps",
815
+ "1.0.0",
816
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
817
+ )
818
+
819
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
820
+
821
+ # align format for control guidance
822
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
823
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
824
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
825
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
826
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
827
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
828
+ control_guidance_start, control_guidance_end = (
829
+ mult * [control_guidance_start],
830
+ mult * [control_guidance_end],
831
+ )
832
+
833
+ # 0. set ip_adapter_scale
834
+ if ip_adapter_scale is not None:
835
+ self.set_ip_adapter_scale(ip_adapter_scale)
836
+
837
+ # 1. Check inputs. Raise error if not correct
838
+ self.check_inputs(
839
+ prompt,
840
+ prompt_2,
841
+ image,
842
+ callback_steps,
843
+ negative_prompt,
844
+ negative_prompt_2,
845
+ prompt_embeds,
846
+ negative_prompt_embeds,
847
+ pooled_prompt_embeds,
848
+ negative_pooled_prompt_embeds,
849
+ controlnet_conditioning_scale,
850
+ control_guidance_start,
851
+ control_guidance_end,
852
+ callback_on_step_end_tensor_inputs,
853
+ )
854
+
855
+ self._guidance_scale = guidance_scale
856
+ self._clip_skip = clip_skip
857
+ self._cross_attention_kwargs = cross_attention_kwargs
858
+
859
+ # 2. Define call parameters
860
+ if prompt is not None and isinstance(prompt, str):
861
+ batch_size = 1
862
+ elif prompt is not None and isinstance(prompt, list):
863
+ batch_size = len(prompt)
864
+ else:
865
+ batch_size = prompt_embeds.shape[0]
866
+
867
+ device = self._execution_device
868
+
869
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
870
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
871
+
872
+ global_pool_conditions = (
873
+ controlnet.config.global_pool_conditions
874
+ if isinstance(controlnet, ControlNetModel)
875
+ else controlnet.nets[0].config.global_pool_conditions
876
+ )
877
+ guess_mode = guess_mode or global_pool_conditions
878
+
879
+ # 3.1 Encode input prompt
880
+ (
881
+ prompt_embeds,
882
+ negative_prompt_embeds,
883
+ pooled_prompt_embeds,
884
+ negative_pooled_prompt_embeds,
885
+ ) = lpw.get_weighted_text_embeddings_sdxl(
886
+ pipe=self,
887
+ prompt=prompt,
888
+ neg_prompt=negative_prompt,
889
+ prompt_embeds=prompt_embeds,
890
+ negative_prompt_embeds=negative_prompt_embeds,
891
+ pooled_prompt_embeds=pooled_prompt_embeds,
892
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
893
+ )
894
+
895
+ # 3.2 Encode image prompt
896
+ prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
897
+ device,
898
+ num_images_per_prompt,
899
+ self.unet.dtype,
900
+ self.do_classifier_free_guidance)
901
+
902
+ # 4. Prepare image
903
+ if isinstance(controlnet, ControlNetModel):
904
+ image = self.prepare_image(
905
+ image=image,
906
+ width=width,
907
+ height=height,
908
+ batch_size=batch_size * num_images_per_prompt,
909
+ num_images_per_prompt=num_images_per_prompt,
910
+ device=device,
911
+ dtype=controlnet.dtype,
912
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
913
+ guess_mode=guess_mode,
914
+ )
915
+ height, width = image.shape[-2:]
916
+ elif isinstance(controlnet, MultiControlNetModel):
917
+ images = []
918
+
919
+ for image_ in image:
920
+ image_ = self.prepare_image(
921
+ image=image_,
922
+ width=width,
923
+ height=height,
924
+ batch_size=batch_size * num_images_per_prompt,
925
+ num_images_per_prompt=num_images_per_prompt,
926
+ device=device,
927
+ dtype=controlnet.dtype,
928
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
929
+ guess_mode=guess_mode,
930
+ )
931
+
932
+ images.append(image_)
933
+
934
+ image = images
935
+ height, width = image[0].shape[-2:]
936
+ else:
937
+ assert False
938
+
939
+ # 4.1 Region control
940
+ if control_mask is not None:
941
+ mask_weight_image = control_mask
942
+ mask_weight_image = np.array(mask_weight_image)
943
+ mask_weight_image_tensor = torch.from_numpy(mask_weight_image).to(device=device, dtype=prompt_embeds.dtype)
944
+ mask_weight_image_tensor = mask_weight_image_tensor[:, :, 0] / 255.
945
+ mask_weight_image_tensor = mask_weight_image_tensor[None, None]
946
+ h, w = mask_weight_image_tensor.shape[-2:]
947
+ control_mask_wight_image_list = []
948
+ for scale in [8, 8, 8, 16, 16, 16, 32, 32, 32]:
949
+ scale_mask_weight_image_tensor = F.interpolate(
950
+ mask_weight_image_tensor,(h // scale, w // scale), mode='bilinear')
951
+ control_mask_wight_image_list.append(scale_mask_weight_image_tensor)
952
+ region_mask = torch.from_numpy(np.array(control_mask)[:, :, 0]).to(self.unet.device, dtype=self.unet.dtype) / 255.
953
+ region_control.prompt_image_conditioning = [dict(region_mask=region_mask)]
954
+ else:
955
+ control_mask_wight_image_list = None
956
+ region_control.prompt_image_conditioning = [dict(region_mask=None)]
957
+
958
+ # 5. Prepare timesteps
959
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
960
+ timesteps = self.scheduler.timesteps
961
+ self._num_timesteps = len(timesteps)
962
+
963
+ # 6. Prepare latent variables
964
+ num_channels_latents = self.unet.config.in_channels
965
+ latents = self.prepare_latents(
966
+ batch_size * num_images_per_prompt,
967
+ num_channels_latents,
968
+ height,
969
+ width,
970
+ prompt_embeds.dtype,
971
+ device,
972
+ generator,
973
+ latents,
974
+ )
975
+
976
+ # 6.5 Optionally get Guidance Scale Embedding
977
+ timestep_cond = None
978
+ if self.unet.config.time_cond_proj_dim is not None:
979
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
980
+ timestep_cond = self.get_guidance_scale_embedding(
981
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
982
+ ).to(device=device, dtype=latents.dtype)
983
+
984
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
985
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
986
+
987
+ # 7.1 Create tensor stating which controlnets to keep
988
+ controlnet_keep = []
989
+ for i in range(len(timesteps)):
990
+ keeps = [
991
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
992
+ for s, e in zip(control_guidance_start, control_guidance_end)
993
+ ]
994
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
995
+
996
+ # 7.2 Prepare added time ids & embeddings
997
+ if isinstance(image, list):
998
+ original_size = original_size or image[0].shape[-2:]
999
+ else:
1000
+ original_size = original_size or image.shape[-2:]
1001
+ target_size = target_size or (height, width)
1002
+
1003
+ add_text_embeds = pooled_prompt_embeds
1004
+ if self.text_encoder_2 is None:
1005
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1006
+ else:
1007
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1008
+
1009
+ add_time_ids = self._get_add_time_ids(
1010
+ original_size,
1011
+ crops_coords_top_left,
1012
+ target_size,
1013
+ dtype=prompt_embeds.dtype,
1014
+ text_encoder_projection_dim=text_encoder_projection_dim,
1015
+ )
1016
+
1017
+ if negative_original_size is not None and negative_target_size is not None:
1018
+ negative_add_time_ids = self._get_add_time_ids(
1019
+ negative_original_size,
1020
+ negative_crops_coords_top_left,
1021
+ negative_target_size,
1022
+ dtype=prompt_embeds.dtype,
1023
+ text_encoder_projection_dim=text_encoder_projection_dim,
1024
+ )
1025
+ else:
1026
+ negative_add_time_ids = add_time_ids
1027
+
1028
+ if self.do_classifier_free_guidance:
1029
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1030
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1031
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1032
+
1033
+ prompt_embeds = prompt_embeds.to(device)
1034
+ add_text_embeds = add_text_embeds.to(device)
1035
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1036
+ encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
1037
+
1038
+ # 8. Denoising loop
1039
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1040
+ is_unet_compiled = is_compiled_module(self.unet)
1041
+ is_controlnet_compiled = is_compiled_module(self.controlnet)
1042
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1043
+
1044
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1045
+ for i, t in enumerate(timesteps):
1046
+ # Relevant thread:
1047
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1048
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
1049
+ torch._inductor.cudagraph_mark_step_begin()
1050
+ # expand the latents if we are doing classifier free guidance
1051
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1052
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1053
+
1054
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1055
+
1056
+ # controlnet(s) inference
1057
+ if guess_mode and self.do_classifier_free_guidance:
1058
+ # Infer ControlNet only for the conditional batch.
1059
+ control_model_input = latents
1060
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1061
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1062
+ controlnet_added_cond_kwargs = {
1063
+ "text_embeds": add_text_embeds.chunk(2)[1],
1064
+ "time_ids": add_time_ids.chunk(2)[1],
1065
+ }
1066
+ else:
1067
+ control_model_input = latent_model_input
1068
+ controlnet_prompt_embeds = prompt_embeds
1069
+ controlnet_added_cond_kwargs = added_cond_kwargs
1070
+
1071
+ if isinstance(controlnet_keep[i], list):
1072
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1073
+ else:
1074
+ controlnet_cond_scale = controlnet_conditioning_scale
1075
+ if isinstance(controlnet_cond_scale, list):
1076
+ controlnet_cond_scale = controlnet_cond_scale[0]
1077
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1078
+
1079
+ if isinstance(self.controlnet, MultiControlNetModel):
1080
+ down_block_res_samples_list, mid_block_res_sample_list = [], []
1081
+ for control_index in range(len(self.controlnet.nets)):
1082
+ controlnet = self.controlnet.nets[control_index]
1083
+ if control_index == 0:
1084
+ # assume the first controlnet is IdentityNet
1085
+ controlnet_prompt_embeds = prompt_image_emb
1086
+ else:
1087
+ controlnet_prompt_embeds = prompt_embeds
1088
+ down_block_res_samples, mid_block_res_sample = controlnet(control_model_input,
1089
+ t,
1090
+ encoder_hidden_states=controlnet_prompt_embeds,
1091
+ controlnet_cond=image[control_index],
1092
+ conditioning_scale=cond_scale[control_index],
1093
+ guess_mode=guess_mode,
1094
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1095
+ return_dict=False)
1096
+
1097
+ # controlnet mask
1098
+ if control_index == 0 and control_mask_wight_image_list is not None:
1099
+ down_block_res_samples = [
1100
+ down_block_res_sample * mask_weight
1101
+ for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
1102
+ ]
1103
+ mid_block_res_sample *= control_mask_wight_image_list[-1]
1104
+
1105
+ down_block_res_samples_list.append(down_block_res_samples)
1106
+ mid_block_res_sample_list.append(mid_block_res_sample)
1107
+
1108
+ mid_block_res_sample = torch.stack(mid_block_res_sample_list).sum(dim=0)
1109
+ down_block_res_samples = [torch.stack(down_block_res_samples).sum(dim=0) for down_block_res_samples in
1110
+ zip(*down_block_res_samples_list)]
1111
+ else:
1112
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1113
+ control_model_input,
1114
+ t,
1115
+ encoder_hidden_states=prompt_image_emb,
1116
+ controlnet_cond=image,
1117
+ conditioning_scale=cond_scale,
1118
+ guess_mode=guess_mode,
1119
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1120
+ return_dict=False,
1121
+ )
1122
+
1123
+ # controlnet mask
1124
+ if control_mask_wight_image_list is not None:
1125
+ down_block_res_samples = [
1126
+ down_block_res_sample * mask_weight
1127
+ for down_block_res_sample, mask_weight in zip(down_block_res_samples, control_mask_wight_image_list)
1128
+ ]
1129
+ mid_block_res_sample *= control_mask_wight_image_list[-1]
1130
+
1131
+ if guess_mode and self.do_classifier_free_guidance:
1132
+ # Inferred ControlNet only for the conditional batch.
1133
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1134
+ # add 0 to the unconditional batch to keep it unchanged.
1135
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1136
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1137
+
1138
+ # predict the noise residual
1139
+ noise_pred = self.unet(
1140
+ latent_model_input,
1141
+ t,
1142
+ encoder_hidden_states=encoder_hidden_states,
1143
+ timestep_cond=timestep_cond,
1144
+ cross_attention_kwargs=self.cross_attention_kwargs,
1145
+ down_block_additional_residuals=down_block_res_samples,
1146
+ mid_block_additional_residual=mid_block_res_sample,
1147
+ added_cond_kwargs=added_cond_kwargs,
1148
+ return_dict=False,
1149
+ )[0]
1150
+
1151
+ # perform guidance
1152
+ if self.do_classifier_free_guidance:
1153
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1154
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1155
+
1156
+ # compute the previous noisy sample x_t -> x_t-1
1157
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1158
+
1159
+ if callback_on_step_end is not None:
1160
+ callback_kwargs = {}
1161
+ for k in callback_on_step_end_tensor_inputs:
1162
+ callback_kwargs[k] = locals()[k]
1163
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1164
+
1165
+ latents = callback_outputs.pop("latents", latents)
1166
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1167
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1168
+
1169
+ # call the callback, if provided
1170
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1171
+ progress_bar.update()
1172
+ if callback is not None and i % callback_steps == 0:
1173
+ step_idx = i // getattr(self.scheduler, "order", 1)
1174
+ callback(step_idx, t, latents)
1175
+
1176
+ if not output_type == "latent":
1177
+ # make sure the VAE is in float32 mode, as it overflows in float16
1178
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1179
+ if needs_upcasting:
1180
+ self.upcast_vae()
1181
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1182
+
1183
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1184
+
1185
+ # cast back to fp16 if needed
1186
+ if needs_upcasting:
1187
+ self.vae.to(dtype=torch.float16)
1188
+ else:
1189
+ image = latents
1190
+
1191
+ if not output_type == "latent":
1192
+ # apply watermark if available
1193
+ if self.watermark is not None:
1194
+ image = self.watermark.apply_watermark(image)
1195
+
1196
+ image = self.image_processor.postprocess(image, output_type=output_type)
1197
+
1198
+ # Offload all models
1199
+ self.maybe_free_model_hooks()
1200
+
1201
+ if not return_dict:
1202
+ return (image,)
1203
+
1204
+ return StableDiffusionXLPipelineOutput(images=image)
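
The `__call__` docstring above documents the `callback_on_step_end` / `callback_on_step_end_tensor_inputs` arguments. As a minimal sketch of how they might be used (not part of this commit — `pipe`, `face_emb`, and `face_kps_image` are assumed placeholders for an already-built InstantID SDXL pipeline, a face embedding, and a keypoint control image):

```python
# Minimal sketch, assuming `pipe` is an already-constructed InstantID SDXL pipeline.
def log_latents(pipe, step, timestep, callback_kwargs):
    # Only tensors named in `callback_on_step_end_tensor_inputs` are present here.
    latents = callback_kwargs["latents"]
    print(f"step {step:03d} | t={int(timestep)} | latents std {latents.std().item():.4f}")
    return callback_kwargs  # return the (possibly modified) tensor dict


# images = pipe(
#     prompt="analog film photo of a man",
#     image_embeds=face_emb,        # assumed: InsightFace face embedding
#     image=face_kps_image,         # assumed: keypoint control image
#     callback_on_step_end=log_latents,
#     callback_on_step_end_tensor_inputs=["latents"],
# ).images
```

The callback must return a dict, since the denoising loop pops `latents`, `prompt_embeds`, and `negative_prompt_embeds` back out of `callback_outputs`.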
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ diffusers==0.25.1
2
+ transformers==4.37.1
3
+ accelerate
4
+ safetensors
5
+ einops
6
+ onnxruntime
7
+ spaces==0.19.4
8
+ omegaconf
9
+ peft
10
+ huggingface-hub==0.20.2
11
+ opencv-python
12
+ insightface
13
+ gradio
14
+ controlnet_aux
15
+ gdown
16
+ peft
style_template.py ADDED
@@ -0,0 +1,155 @@
1
+ style_list = [
2
+ {
3
+ "name": "(No style)",
4
+ "prompt": "{prompt}",
5
+ "negative_prompt": "",
6
+
7
+ },
8
+ {
9
+ "name": "Spring Festival",
10
+ "prompt": "Flat illustration, a Chinese {prompt}, ancient style, wearing a red cloth, smile face, white skin, clean background, fireworks blooming, red lanterns",
11
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast, realistic, cropped, worst quality, missing fingers, extra digit, jpeg artifacts, signature, multiple, (lowres, low quality, worst quality:1.2)",
12
+ },
13
+ {
14
+ "name": "Watercolor",
15
+ "prompt": "watercolor painting, {prompt}. vibrant, beautiful, painterly, detailed, textural, artistic",
16
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
17
+ },
18
+ {
19
+ "name": "Film Noir",
20
+ "prompt": "film noir style, ink sketch|vector, {prompt} highly detailed, sharp focus, ultra sharpness, monochrome, high contrast, dramatic shadows, 1940s style, mysterious, cinematic",
21
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
22
+ },
23
+ {
24
+ "name": "Neon",
25
+ "prompt": "masterpiece painting, buildings in the backdrop, kaleidoscope, lilac orange blue cream fuchsia bright vivid gradient colors, the scene is cinematic, {prompt}, emotional realism, double exposure, watercolor ink pencil, graded wash, color layering, magic realism, figurative painting, intricate motifs, organic tracery, polished",
26
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
27
+ },
28
+ {
29
+ "name": "Jungle",
30
+ "prompt": 'waist-up "{prompt} in a Jungle" by Syd Mead, tangerine cold color palette, muted colors, detailed, 8k,photo r3al,dripping paint,3d toon style,3d style,Movie Still',
31
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
32
+ },
33
+ {
34
+ "name": "Mars",
35
+ "prompt": "{prompt}, Post-apocalyptic. Mars Colony, Scavengers roam the wastelands searching for valuable resources, rovers, bright morning sunlight shining, (detailed) (intricate) (8k) (HDR) (cinematic lighting) (sharp focus)",
36
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
37
+ },
38
+ {
39
+ "name": "Vibrant Color",
40
+ "prompt": "vibrant colorful, ink sketch|vector|2d colors, at nightfall, sharp focus, {prompt}, highly detailed, sharp focus, the clouds,colorful,ultra sharpness",
41
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
42
+ },
43
+ {
44
+ "name": "Snow",
45
+ "prompt": "cinema 4d render, {prompt}, high contrast, vibrant and saturated, sico style, surrounded by magical glow,floating ice shards, snow crystals, cold, windy background, frozen natural landscape in background cinematic atmosphere,highly detailed, sharp focus, intricate design, 3d, unreal engine, octane render, CG best quality, highres, photorealistic, dramatic lighting, artstation, concept art, cinematic, epic Steven Spielberg movie still, sharp focus, smoke, sparks, art by pascal blanche and greg rutkowski and repin, trending on artstation, hyperrealism painting, matte painting, 4k resolution",
46
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
47
+ },
48
+ {
49
+ "name": "Line art",
50
+ "prompt": "line art drawing {prompt} . professional, sleek, modern, minimalist, graphic, line art, vector graphics",
51
+ "negative_prompt": "anime, photorealistic, 35mm film, deformed, glitch, blurry, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, mutated, realism, realistic, impressionism, expressionism, oil, acrylic",
52
+ },
53
+ {
54
+ "name": "Art Nouveau",
55
+ "prompt": "Art Nouveau style, {prompt}, organic forms, curvilinear lines, elegant, nature-inspired, intricate patterns, flowing designs, soft colors",
56
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, brutalism, geometric, harsh lines, noisy",
57
+ },
58
+ {
59
+ "name": "Cubism",
60
+ "prompt": "cubist painting, {prompt}, abstract, geometric shapes, fragmented objects, multiple viewpoints, bold lines, reduced colors, Pablo Picasso inspired",
61
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, smooth, flowing, noisy",
62
+ },
63
+ {
64
+ "name": "Minimalism",
65
+ "prompt": "minimalist art, {prompt}, simple, clean lines, monochrome, negative space, minimal detail, geometric shapes, modern, sophisticated",
66
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, baroque, cluttered, over-detailed, noisy",
67
+ },
68
+ {
69
+ "name": "Baroque",
70
+ "prompt": "baroque style, {prompt}, grandeur, drama, contrast, rich colors, intense lighting, ornate, Caravaggio inspired, emotional, detailed",
71
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, minimalist, flat, dull, noisy",
72
+ },
73
+ {
74
+ "name": "Expressionism",
75
+ "prompt": "expressionist art, {prompt}, emotional, distorted, exaggerated, vivid colors, dynamic brushstrokes, subjective perspective, impactful",
76
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealism, bland, underwhelming, noisy",
77
+ },
78
+ {
79
+ "name": "Digital Glitch",
80
+ "prompt": "digital glitch art, {prompt}, corrupted image, tech-inspired, vibrant, surreal, abstract, pixelated, data moshing, futuristic",
81
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, classical, realistic, smooth, noiseless",
82
+ },
83
+ {
84
+ "name": "Psychedelic",
85
+ "prompt": "psychedelic art, {prompt}, vivid colors, hallucinatory patterns, surreal, fluid shapes, trippy, 1960s style, kaleidoscopic, vibrant",
86
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, monochrome, simplistic, noiseless",
87
+ },
88
+ {
89
+ "name": "Victorian",
90
+ "prompt": "Victorian style, {prompt}, elegant, ornate, romantic, historical, detailed patterns, rich textures, sophisticated, vintage",
91
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, modern, minimalist, plain, noisy",
92
+ },
93
+ {
94
+ "name": "Graffiti",
95
+ "prompt": "graffiti art, {prompt}, urban, street style, bold colors, spray paint, tagging, hip-hop culture, dynamic, expressive, contemporary",
96
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, classical, fine art, noiseless, smooth",
97
+ },
98
+ {
99
+ "name": "Ukiyo-e",
100
+ "prompt": "Ukiyo-e style, {prompt}, Japanese woodblock print, flat color areas, bold outlines, historical scenes, nature, Edo period, traditional",
101
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, 3D, photorealistic, modern, noisy",
102
+ },
103
+ {
104
+ "name": "Retro Futurism",
105
+ "prompt": "retro futurism, {prompt}, vibrant colors, geometric shapes, streamline design, 1950s and 1960s style, optimistic, space-age, visionary, neon lighting, bold typography",
106
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, medieval, deformed, glitch, blurry, noisy",
107
+ },
108
+ {
109
+ "name": "Steampunk",
110
+ "prompt": "steampunk style, {prompt}, Victorian era, industrial, mechanical gears, brass and copper, steam engines, intricate details, fantastical machines, sepia tones",
111
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, neon, futuristic, blurry, noisy",
112
+ },
113
+ {
114
+ "name": "Cyberpunk",
115
+ "prompt": "cyberpunk aesthetic, {prompt}, neon lights, urban dystopia, futuristic, high-tech, low-life, cybernetics, dark and gritty, vivid colors, digital world",
116
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, medieval, steampunk, blurry, noisy",
117
+ },
118
+ {
119
+ "name": "Impressionist",
120
+ "prompt": "impressionist style painting, {prompt}, vibrant, light brushstrokes, open composition, emphasis on light in its changing qualities, movement, ordinary subject matter",
121
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
122
+ },
123
+ {
124
+ "name": "Art Deco",
125
+ "prompt": "art deco style, {prompt}, glamorous, elegant, functional, geometric patterns, bold colors, sleek lines, chrome, glass, shiny fabrics",
126
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, rustic, medieval, deformed, glitch, noisy",
127
+ },
128
+ {
129
+ "name": "Fantasy",
130
+ "prompt": "fantasy style, {prompt}, mythical creatures, enchanted forests, magic elements, dreamlike landscapes, vibrant colors, detailed, imaginative",
131
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
132
+ },
133
+ {
134
+ "name": "Gothic",
135
+ "prompt": "gothic style, {prompt}, dark and moody, gothic architecture, medieval elements, dramatic lighting, somber tones, intricate details, mysterious atmosphere",
136
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, neon, futuristic, blurry, noisy",
137
+ },
138
+ {
139
+ "name": "Pop Art",
140
+ "prompt": "pop art style, {prompt}, bold colors, mass culture, comic style, ironic, whimsical, repetition of images, bright, high contrast",
141
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, medieval, steampunk, blurry, noisy",
142
+ },
143
+ {
144
+ "name": "Surrealism",
145
+ "prompt": "surrealist style, {prompt}, dreamlike, bizarre, irrational, abstract, imaginative, distorted reality, vivid, unexpected juxtapositions",
146
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
147
+ },
148
+ {
149
+ "name": "Abstract",
150
+ "prompt": "abstract style, {prompt}, non-representational, shapes, forms, colors, lines, dynamic, modern, expressive, non-figurative, bold",
151
+ "negative_prompt": "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, anime, photorealistic, 35mm film, deformed, glitch, low contrast, noisy",
152
+ },
153
+ ]
154
+
155
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
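
A short usage sketch for the `styles` dictionary built above. The `apply_style` helper is illustrative only (the app's own wiring may differ); it substitutes the user prompt into the selected template and falls back to "(No style)":

```python
# Illustrative only: apply a named style template to a user prompt.
from style_template import styles

DEFAULT_STYLE = "(No style)"

def apply_style(style_name: str, prompt: str, negative: str = ""):
    positive_tmpl, negative_tmpl = styles.get(style_name, styles[DEFAULT_STYLE])
    return positive_tmpl.replace("{prompt}", prompt), (negative_tmpl + " " + negative).strip()

pos, neg = apply_style("Watercolor", "a portrait of an astronaut")
print(pos)
print(neg)
```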
torchhub/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Local PyTorch Hub
2
+
3
+ This directory is for loading the DINOv2 encoder locally in case of no Internet connection.
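
As a sketch of what loading from this local copy can look like with `torch.hub` (the entry-point name and directory path are taken from the files in this commit; the assumption that pretrained weights are handled separately when offline is illustrative):

```python
# Sketch: load the DINOv2 ViT-S/14 definition from the local hub copy instead of GitHub.
import torch

dinov2_vits14 = torch.hub.load(
    "torchhub/facebookresearch_dinov2_main",  # local directory added in this commit
    "dinov2_vits14",
    source="local",
    pretrained=False,  # assumption: weights are loaded separately when there is no Internet
)
```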
torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
1
+ # Contributing to DINOv2
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `main`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Meta's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26
+ disclosure of security bugs. In those cases, please go through the process
27
+ outlined on that page and do not file a public issue.
28
+
29
+ ## License
30
+ By contributing to DINOv2, you agree that your contributions will be licensed
31
+ under the LICENSE file in the root directory of this source tree.
torchhub/facebookresearch_dinov2_main/LICENSE ADDED
@@ -0,0 +1,400 @@
1
+
2
+ Attribution-NonCommercial 4.0 International
3
+
4
+ =======================================================================
5
+
6
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
7
+ does not provide legal services or legal advice. Distribution of
8
+ Creative Commons public licenses does not create a lawyer-client or
9
+ other relationship. Creative Commons makes its licenses and related
10
+ information available on an "as-is" basis. Creative Commons gives no
11
+ warranties regarding its licenses, any material licensed under their
12
+ terms and conditions, or any related information. Creative Commons
13
+ disclaims all liability for damages resulting from their use to the
14
+ fullest extent possible.
15
+
16
+ Using Creative Commons Public Licenses
17
+
18
+ Creative Commons public licenses provide a standard set of terms and
19
+ conditions that creators and other rights holders may use to share
20
+ original works of authorship and other material subject to copyright
21
+ and certain other rights specified in the public license below. The
22
+ following considerations are for informational purposes only, are not
23
+ exhaustive, and do not form part of our licenses.
24
+
25
+ Considerations for licensors: Our public licenses are
26
+ intended for use by those authorized to give the public
27
+ permission to use material in ways otherwise restricted by
28
+ copyright and certain other rights. Our licenses are
29
+ irrevocable. Licensors should read and understand the terms
30
+ and conditions of the license they choose before applying it.
31
+ Licensors should also secure all rights necessary before
32
+ applying our licenses so that the public can reuse the
33
+ material as expected. Licensors should clearly mark any
34
+ material not subject to the license. This includes other CC-
35
+ licensed material, or material used under an exception or
36
+ limitation to copyright. More considerations for licensors:
37
+ wiki.creativecommons.org/Considerations_for_licensors
38
+
39
+ Considerations for the public: By using one of our public
40
+ licenses, a licensor grants the public permission to use the
41
+ licensed material under specified terms and conditions. If
42
+ the licensor's permission is not necessary for any reason--for
43
+ example, because of any applicable exception or limitation to
44
+ copyright--then that use is not regulated by the license. Our
45
+ licenses grant only permissions under copyright and certain
46
+ other rights that a licensor has authority to grant. Use of
47
+ the licensed material may still be restricted for other
48
+ reasons, including because others have copyright or other
49
+ rights in the material. A licensor may make special requests,
50
+ such as asking that all changes be marked or described.
51
+ Although not required by our licenses, you are encouraged to
52
+ respect those requests where reasonable. More_considerations
53
+ for the public:
54
+ wiki.creativecommons.org/Considerations_for_licensees
55
+
56
+ =======================================================================
57
+
58
+ Creative Commons Attribution-NonCommercial 4.0 International Public
59
+ License
60
+
61
+ By exercising the Licensed Rights (defined below), You accept and agree
62
+ to be bound by the terms and conditions of this Creative Commons
63
+ Attribution-NonCommercial 4.0 International Public License ("Public
64
+ License"). To the extent this Public License may be interpreted as a
65
+ contract, You are granted the Licensed Rights in consideration of Your
66
+ acceptance of these terms and conditions, and the Licensor grants You
67
+ such rights in consideration of benefits the Licensor receives from
68
+ making the Licensed Material available under these terms and
69
+ conditions.
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+ Section 2 -- Scope.
142
+
143
+ a. License grant.
144
+
145
+ 1. Subject to the terms and conditions of this Public License,
146
+ the Licensor hereby grants You a worldwide, royalty-free,
147
+ non-sublicensable, non-exclusive, irrevocable license to
148
+ exercise the Licensed Rights in the Licensed Material to:
149
+
150
+ a. reproduce and Share the Licensed Material, in whole or
151
+ in part, for NonCommercial purposes only; and
152
+
153
+ b. produce, reproduce, and Share Adapted Material for
154
+ NonCommercial purposes only.
155
+
156
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
157
+ Exceptions and Limitations apply to Your use, this Public
158
+ License does not apply, and You do not need to comply with
159
+ its terms and conditions.
160
+
161
+ 3. Term. The term of this Public License is specified in Section
162
+ 6(a).
163
+
164
+ 4. Media and formats; technical modifications allowed. The
165
+ Licensor authorizes You to exercise the Licensed Rights in
166
+ all media and formats whether now known or hereafter created,
167
+ and to make technical modifications necessary to do so. The
168
+ Licensor waives and/or agrees not to assert any right or
169
+ authority to forbid You from making technical modifications
170
+ necessary to exercise the Licensed Rights, including
171
+ technical modifications necessary to circumvent Effective
172
+ Technological Measures. For purposes of this Public License,
173
+ simply making modifications authorized by this Section 2(a)
174
+ (4) never produces Adapted Material.
175
+
176
+ 5. Downstream recipients.
177
+
178
+ a. Offer from the Licensor -- Licensed Material. Every
179
+ recipient of the Licensed Material automatically
180
+ receives an offer from the Licensor to exercise the
181
+ Licensed Rights under the terms and conditions of this
182
+ Public License.
183
+
184
+ b. No downstream restrictions. You may not offer or impose
185
+ any additional or different terms or conditions on, or
186
+ apply any Effective Technological Measures to, the
187
+ Licensed Material if doing so restricts exercise of the
188
+ Licensed Rights by any recipient of the Licensed
189
+ Material.
190
+
191
+ 6. No endorsement. Nothing in this Public License constitutes or
192
+ may be construed as permission to assert or imply that You
193
+ are, or that Your use of the Licensed Material is, connected
194
+ with, or sponsored, endorsed, or granted official status by,
195
+ the Licensor or others designated to receive attribution as
196
+ provided in Section 3(a)(1)(A)(i).
197
+
198
+ b. Other rights.
199
+
200
+ 1. Moral rights, such as the right of integrity, are not
201
+ licensed under this Public License, nor are publicity,
202
+ privacy, and/or other similar personality rights; however, to
203
+ the extent possible, the Licensor waives and/or agrees not to
204
+ assert any such rights held by the Licensor to the limited
205
+ extent necessary to allow You to exercise the Licensed
206
+ Rights, but not otherwise.
207
+
208
+ 2. Patent and trademark rights are not licensed under this
209
+ Public License.
210
+
211
+ 3. To the extent possible, the Licensor waives any right to
212
+ collect royalties from You for the exercise of the Licensed
213
+ Rights, whether directly or through a collecting society
214
+ under any voluntary or waivable statutory or compulsory
215
+ licensing scheme. In all other cases the Licensor expressly
216
+ reserves any right to collect such royalties, including when
217
+ the Licensed Material is used other than for NonCommercial
218
+ purposes.
219
+
220
+ Section 3 -- License Conditions.
221
+
222
+ Your exercise of the Licensed Rights is expressly made subject to the
223
+ following conditions.
224
+
225
+ a. Attribution.
226
+
227
+ 1. If You Share the Licensed Material (including in modified
228
+ form), You must:
229
+
230
+ a. retain the following if it is supplied by the Licensor
231
+ with the Licensed Material:
232
+
233
+ i. identification of the creator(s) of the Licensed
234
+ Material and any others designated to receive
235
+ attribution, in any reasonable manner requested by
236
+ the Licensor (including by pseudonym if
237
+ designated);
238
+
239
+ ii. a copyright notice;
240
+
241
+ iii. a notice that refers to this Public License;
242
+
243
+ iv. a notice that refers to the disclaimer of
244
+ warranties;
245
+
246
+ v. a URI or hyperlink to the Licensed Material to the
247
+ extent reasonably practicable;
248
+
249
+ b. indicate if You modified the Licensed Material and
250
+ retain an indication of any previous modifications; and
251
+
252
+ c. indicate the Licensed Material is licensed under this
253
+ Public License, and include the text of, or the URI or
254
+ hyperlink to, this Public License.
255
+
256
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
257
+ reasonable manner based on the medium, means, and context in
258
+ which You Share the Licensed Material. For example, it may be
259
+ reasonable to satisfy the conditions by providing a URI or
260
+ hyperlink to a resource that includes the required
261
+ information.
262
+
263
+ 3. If requested by the Licensor, You must remove any of the
264
+ information required by Section 3(a)(1)(A) to the extent
265
+ reasonably practicable.
266
+
267
+ 4. If You Share Adapted Material You produce, the Adapter's
268
+ License You apply must not prevent recipients of the Adapted
269
+ Material from complying with this Public License.
270
+
271
+ Section 4 -- Sui Generis Database Rights.
272
+
273
+ Where the Licensed Rights include Sui Generis Database Rights that
274
+ apply to Your use of the Licensed Material:
275
+
276
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277
+ to extract, reuse, reproduce, and Share all or a substantial
278
+ portion of the contents of the database for NonCommercial purposes
279
+ only;
280
+
281
+ b. if You include all or a substantial portion of the database
282
+ contents in a database in which You have Sui Generis Database
283
+ Rights, then the database in which You have Sui Generis Database
284
+ Rights (but not its individual contents) is Adapted Material; and
285
+
286
+ c. You must comply with the conditions in Section 3(a) if You Share
287
+ all or a substantial portion of the contents of the database.
288
+
289
+ For the avoidance of doubt, this Section 4 supplements and does not
290
+ replace Your obligations under this Public License where the Licensed
291
+ Rights include other Copyright and Similar Rights.
292
+
293
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294
+
295
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305
+
306
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315
+
316
+ c. The disclaimer of warranties and limitation of liability provided
317
+ above shall be interpreted in a manner that, to the extent
318
+ possible, most closely approximates an absolute disclaimer and
319
+ waiver of all liability.
320
+
321
+ Section 6 -- Term and Termination.
322
+
323
+ a. This Public License applies for the term of the Copyright and
324
+ Similar Rights licensed here. However, if You fail to comply with
325
+ this Public License, then Your rights under this Public License
326
+ terminate automatically.
327
+
328
+ b. Where Your right to use the Licensed Material has terminated under
329
+ Section 6(a), it reinstates:
330
+
331
+ 1. automatically as of the date the violation is cured, provided
332
+ it is cured within 30 days of Your discovery of the
333
+ violation; or
334
+
335
+ 2. upon express reinstatement by the Licensor.
336
+
337
+ For the avoidance of doubt, this Section 6(b) does not affect any
338
+ right the Licensor may have to seek remedies for Your violations
339
+ of this Public License.
340
+
341
+ c. For the avoidance of doubt, the Licensor may also offer the
342
+ Licensed Material under separate terms or conditions or stop
343
+ distributing the Licensed Material at any time; however, doing so
344
+ will not terminate this Public License.
345
+
346
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347
+ License.
348
+
349
+ Section 7 -- Other Terms and Conditions.
350
+
351
+ a. The Licensor shall not be bound by any additional or different
352
+ terms or conditions communicated by You unless expressly agreed.
353
+
354
+ b. Any arrangements, understandings, or agreements regarding the
355
+ Licensed Material not stated herein are separate from and
356
+ independent of the terms and conditions of this Public License.
357
+
358
+ Section 8 -- Interpretation.
359
+
360
+ a. For the avoidance of doubt, this Public License does not, and
361
+ shall not be interpreted to, reduce, limit, restrict, or impose
362
+ conditions on any use of the Licensed Material that could lawfully
363
+ be made without permission under this Public License.
364
+
365
+ b. To the extent possible, if any provision of this Public License is
366
+ deemed unenforceable, it shall be automatically reformed to the
367
+ minimum extent necessary to make it enforceable. If the provision
368
+ cannot be reformed, it shall be severed from this Public License
369
+ without affecting the enforceability of the remaining terms and
370
+ conditions.
371
+
372
+ c. No term or condition of this Public License will be waived and no
373
+ failure to comply consented to unless expressly agreed to by the
374
+ Licensor.
375
+
376
+ d. Nothing in this Public License constitutes or may be interpreted
377
+ as a limitation upon, or waiver of, any privileges and immunities
378
+ that apply to the Licensor or You, including from the legal
379
+ processes of any jurisdiction or authority.
380
+
381
+ =======================================================================
382
+
383
+ Creative Commons is not a party to its public
384
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
385
+ its public licenses to material it publishes and in those instances
386
+ will be considered the “Licensor.” The text of the Creative Commons
387
+ public licenses is dedicated to the public domain under the CC0 Public
388
+ Domain Dedication. Except for the limited purpose of indicating that
389
+ material is shared under a Creative Commons public license or as
390
+ otherwise permitted by the Creative Commons policies published at
391
+ creativecommons.org/policies, Creative Commons does not authorize the
392
+ use of the trademark "Creative Commons" or any other trademark or logo
393
+ of Creative Commons without its prior written consent including,
394
+ without limitation, in connection with any unauthorized modifications
395
+ to any of its public licenses or any other arrangements,
396
+ understandings, or agreements concerning use of licensed material. For
397
+ the avoidance of doubt, this paragraph does not form part of the
398
+ public licenses.
399
+
400
+ Creative Commons may be contacted at creativecommons.org.
torchhub/facebookresearch_dinov2_main/MODEL_CARD.md ADDED
@@ -0,0 +1,201 @@
1
+ # Model Card for DINOv2-S/B/L/g
2
+
3
+ These are Vision Transformer models trained following the method described in the paper:
4
+ "DINOv2: Learning Robust Visual Features without Supervision"
5
+
6
+ We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g.
7
+
8
+ ## Model Details
9
+ The model takes an image as input and returns a class token and patch tokens.
10
+
11
+ The embedding dimension is:
12
+ - 384 for ViT-S.
13
+ - 768 for ViT-B.
14
+ - 1024 for ViT-L.
15
+ - 1536 for ViT-g.
16
+
17
+ The models follow a Transformer architecture, with a patch size of 14.
18
+
19
+ For a 224x224 image, this results in 1 class token + 256 patch tokens.
20
+
21
+ The models can accept larger images provided the image shapes are multiples of the patch size (14).
22
+ If this condition is not met, the model crops the input to the closest smaller multiple of the patch size.
23
+
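As a quick, editorial illustration of the arithmetic above (not part of the original model card), the token count is just the number of 14×14 patches that fit in the image, and the cropping rule rounds each side down to a multiple of 14:

```python
# Editorial sketch of the token-count arithmetic for a patch size of 14.
PATCH_SIZE = 14

def num_patch_tokens(height: int, width: int) -> int:
    # Assumes height and width are already multiples of the patch size.
    return (height // PATCH_SIZE) * (width // PATCH_SIZE)

def crop_to_multiple(side: int) -> int:
    # Closest smaller multiple of the patch size, mirroring the cropping rule above.
    return (side // PATCH_SIZE) * PATCH_SIZE

print(num_patch_tokens(224, 224))  # 256 patch tokens (+ 1 class token)
print(num_patch_tokens(518, 518))  # 1369 patch tokens (518 is the global crop size in the eval configs below)
print(crop_to_multiple(250))       # 238: a 250-pixel side is cropped down to 238
```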
24
+ ### Model Description
25
+
26
+ - **Developed by:** Meta AI
27
+ - **Model type:** Vision Transformer
28
+ - **License:** CC-BY-NC
29
+
30
+ - **Repository:** https://github.com/facebookresearch/dinov2
31
+ - **Paper:** https://arxiv.org/abs/2304.07193
32
+ - **Demo:** https://dinov2.metademolab.com/
33
+
34
+ ## Uses
35
+
36
+ The models are vision backbones providing multi-purpose features for downstream tasks.
37
+
38
+ ### Direct Use
39
+
40
+ The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
41
+ - on depth estimation and semantic segmentation, using linear layers.
42
+ - on image classification, using k-NN classifiers on the class token.
43
+ - on image classification, with logistic regression classifiers applied on the class token.
44
+ - on image classification, with a linear layer applied on the class token and the average of the patch tokens.
45
+ - on image retrieval using nearest neighbors.
46
+
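To make the linear and logistic-regression probes listed above concrete, here is a minimal editorial sketch of training a linear head on frozen class-token features; the 1000-class head, the SGD settings, and the assumption that the dataloader already yields normalized images are illustrative choices rather than anything specified by this card:

```python
import torch
import torch.nn as nn

# Frozen DINOv2 backbone plus a trainable linear head on the class token (ViT-S: 384-dim).
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval()
head = nn.Linear(384, 1000)  # e.g. ImageNet-1k; adjust the class count for your dataset

optimizer = torch.optim.SGD(head.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def training_step(images, labels):
    with torch.no_grad():           # the backbone stays frozen
        feats = backbone(images)    # class-token features, shape (batch, 384)
    loss = criterion(head(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

A k-NN or logistic-regression classifier can be fit on the same cached features in place of the linear head.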
47
+ ### Downstream Use
48
+
49
+ It is technically possible to perform fine-tuning on the models, for small gains (we measured +2% on ImageNet-1k classification).
50
+ We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out-of-the-box.
51
+
52
+ ## Bias, Risks, and Limitations
53
+
54
+ Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries.
55
+
56
+ ### Recommendations
57
+
58
+ We expect fine-tuning will increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
59
+
60
+ ## How to Get Started with the Model
61
+
62
+ Use the code below to get started with the model.
63
+
64
+ ```python
65
+ import torch
66
+ dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
67
+ dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
68
+ dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
69
+ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
70
+ ```
71
+
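The hub calls above only return the backbones and do not include preprocessing. The transform below is an editorial sketch of a typical setup: the 224-pixel center crop (a multiple of 14) and the ImageNet normalization statistics are assumptions, and `photo.jpg` is a placeholder path.

```python
import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # 224 is a multiple of the patch size (14)
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # common ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

image = preprocess(Image.open("photo.jpg")).unsqueeze(0)  # shape (1, 3, 224, 224)
with torch.no_grad():
    features = dinov2_vits14(image)  # class-token features, shape (1, 384)
```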
72
+ ## Training Details
73
+
74
+ ### Training Data
75
+
76
+ - **Training data:** LVD-142M (see paper)
77
+ - **Training regime:** fp16 using PyTorch-FSDP mixed-precision.
78
+
79
+ ### Training Procedure
80
+
81
+ - **Training objective:**
82
+ - DINO self-distillation loss with multi-crop
83
+ - iBOT masked-image modeling loss
84
+ - KoLeo regularization on [CLS] tokens
85
+ - **Architectures:**
86
+ - ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN
87
+ - ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN
88
+ - ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN
89
+ - ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN
90
+ - **Distillation:**
91
+ - Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained ViT-g, frozen.
92
+
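For readers unfamiliar with the KoLeo term listed above, the sketch below shows one way such a differential-entropy regularizer can be written; it is an editorial paraphrase of the description in the paper, not the repository's implementation.

```python
import torch
import torch.nn.functional as F

def koleo_regularizer(cls_tokens: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Penalize small nearest-neighbor distances so the [CLS] features in a
    batch spread out over the unit sphere."""
    z = F.normalize(cls_tokens, dim=-1)       # (batch, dim), L2-normalized
    dists = torch.cdist(z, z)                 # pairwise Euclidean distances
    dists.fill_diagonal_(float("inf"))        # exclude self-distances
    nn_dists = dists.min(dim=1).values        # distance to each sample's nearest neighbor
    return -torch.log(nn_dists + eps).mean()  # minimized when features are well spread
```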
93
+ ## Evaluation
94
+
95
+ We refer users to the associated paper for the evaluation protocols.
96
+
97
+ <table>
98
+ <tr>
99
+ <th>model</th>
100
+ <th colspan="3">ImageNet-1k</th>
101
+ <th>NYU-Depth v2</th>
102
+ <th>SUN-RGBD</th>
103
+ <th>ADE20k</th>
104
+ <th>iNaturalist 2018</th>
105
+ <th>Oxford-H</th>
106
+ </tr>
107
+ <tr>
108
+ <th rowspan="2">task</th>
109
+ <th>classif. (acc)</th>
110
+ <th>classif. (acc)</th>
111
+ <th>classif. V2 (acc)</th>
112
+ <th>depth (RMSE)</th>
113
+ <th>depth (RMSE)</th>
114
+ <th>segm. (mIoU)</th>
115
+ <th>classif. (acc)</th>
116
+ <th>retrieval (mAP)</th>
117
+ </tr>
118
+ <tr>
119
+ <!-- <th>^</th> -->
120
+ <th>k-NN</th>
121
+ <th>linear</th>
122
+ <th>linear</th>
123
+ <th>linear<br />4 layers</th>
124
+ <th>NYU-D transfer</th>
125
+ <th>multiscale</th>
126
+ <th>linear</th>
127
+ <th>nearest neighbor</th>
128
+ </tr>
129
+ <tr>
130
+ <td>ViT-S/14</td>
131
+ <td align="right">79.0%</td>
132
+ <td align="right">81.1%</td>
133
+ <td align="right">70.8%</td>
134
+ <td align="right">0.417</td>
135
+ <td align="right">0.431</td>
136
+ <td align="right">47.2</td>
137
+ <td align="right">69.5%</td>
138
+ <td align="right">43.2</td>
139
+ </tr>
140
+ <tr>
141
+ <td>ViT-B/14</td>
142
+ <td align="right">82.1%</td>
143
+ <td align="right">84.5%</td>
144
+ <td align="right">74.9%</td>
145
+ <td align="right">0.362</td>
146
+ <td align="right">0.400</td>
147
+ <td align="right">51.3</td>
148
+ <td align="right">76.3%</td>
149
+ <td align="right">49.5</td>
150
+ </tr>
151
+ <tr>
152
+ <td>ViT-L/14</td>
153
+ <td align="right">83.5%</td>
154
+ <td align="right">86.3%</td>
155
+ <td align="right">77.6%</td>
156
+ <td align="right">0.333</td>
157
+ <td align="right">0.396</td>
158
+ <td align="right">53.1</td>
159
+ <td align="right">79.8%</td>
160
+ <td align="right">54.0</td>
161
+ </tr>
162
+ <tr>
163
+ <td>ViT-g/14</td>
164
+ <td align="right">83.5%</td>
165
+ <td align="right">86.5%</td>
166
+ <td align="right">78.4%</td>
167
+ <td align="right">0.298</td>
168
+ <td align="right">0.362</td>
169
+ <td align="right">53.0</td>
170
+ <td align="right">81.6%</td>
171
+ <td align="right">52.3</td>
172
+ </tr>
173
+ </table>
174
+
175
+ ## Environmental Impact
176
+
177
+ - **Hardware Type:** Nvidia A100
178
+ - **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation
179
+ - **Cloud Provider:** Private infra
180
+ - **Compute Region:** USA
181
+ - **Carbon Emitted:** 7t CO2eq
182
+
183
+ #### Hardware
184
+
185
+ Nvidia A100 GPUs
186
+
187
+ #### Software
188
+
189
+ PyTorch 2.0,
190
+ xFormers 0.0.18
191
+
192
+ **BibTeX**
193
+
194
+ ```
195
+ @misc{oquab2023dinov2,
196
+ title={DINOv2: Learning Robust Visual Features without Supervision},
197
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
198
+ journal={arXiv:2304.07193},
199
+ year={2023}
200
+ }
201
+ ```
torchhub/facebookresearch_dinov2_main/README.md ADDED
@@ -0,0 +1,277 @@
1
+ # DINOv2: Learning Robust Visual Features without Supervision
2
+
3
+ **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
4
+
5
+ Maxime Oquab,
6
+ Timothée Darcet,
7
+ Théo Moutakanni,
8
+ Huy V. Vo,
9
+ Marc Szafraniec,
10
+ Vasil Khalidov,
11
+ Patrick Labatut,
12
+ Armand Joulin,
13
+ Piotr Bojanowski
14
+
15
+ [[`Paper`](https://arxiv.org/abs/2304.07193)] [[`Blog`](https://ai.facebook.com/blog/dino-v2-computer-vision-self-supervised-learning/)] [[`Demo`](https://dinov2.metademolab.com)] [[`BibTeX`](#citing-dinov2)]
16
+
17
+ PyTorch implementation and pretrained models for DINOv2. For details, see the paper: **[DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)**.
18
+
19
+ DINOv2 models produce high-performance visual features that can be directly employed with classifiers as simple as linear layers on a variety of computer vision tasks; these visual features are robust and perform well across domains without any requirement for fine-tuning. The models were pretrained on a dataset of 142 M images without using any labels or annotations.
20
+
21
+ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b429-578badf5c356
22
+
23
+ <div align="center">
24
+ Visualization of the first three principal components of the patch features of all frames, mapped to RGB values.
25
+ </div>
26
+
27
+ ## Pretrained models
28
+
29
+ <table style="margin: auto">
30
+ <tr>
31
+ <th>model</th>
32
+ <th># of<br />params</th>
33
+ <th>ImageNet<br />k-NN</th>
34
+ <th>ImageNet<br />linear</th>
35
+ <th>download</th>
36
+ </tr>
37
+ <tr>
38
+ <td>ViT-S/14 distilled</td>
39
+ <td align="right">21 M</td>
40
+ <td align="right">79.0%</td>
41
+ <td align="right">81.1%</td>
42
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td>
43
+ </tr>
44
+ <tr>
45
+ <td>ViT-B/14 distilled</td>
46
+ <td align="right">86 M</td>
47
+ <td align="right">82.1%</td>
48
+ <td align="right">84.5%</td>
49
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td>
50
+ </tr>
51
+ <tr>
52
+ <td>ViT-L/14 distilled</td>
53
+ <td align="right">300 M</td>
54
+ <td align="right">83.5%</td>
55
+ <td align="right">86.3%</td>
56
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td>
57
+ </tr>
58
+ <tr>
59
+ <td>ViT-g/14</td>
60
+ <td align="right">1,100 M</td>
61
+ <td align="right">83.5%</td>
62
+ <td align="right">86.5%</td>
63
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td>
64
+ </tr>
65
+ </table>
66
+
67
+ ### Pretrained models via PyTorch Hub
68
+
69
+ Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.
70
+
71
+ A corresponding [model card](MODEL_CARD.md) is included in the repository.
72
+
73
+ ```python
74
+ import torch
75
+
76
+ dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
77
+ dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
78
+ dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
79
+ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
80
+ ```
81
+
82
+ ## Installation
83
+
84
+ The training and evaluation code requires PyTorch 2.0 and [xFormers](https://github.com/facebookresearch/xformers) 0.0.18 as well as a number of other third-party packages. Note that the code has only been tested with the specified versions and expects a Linux environment. To set up all the required dependencies for training and evaluation, please follow the instructions below:
85
+
86
+ *[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov2` conda environment using the provided environment definition:
87
+
88
+ ```shell
89
+ conda env create -f conda.yaml
90
+ conda activate dinov2
91
+ ```
92
+
93
+ *[pip](https://pip.pypa.io/en/stable/getting-started/)* - Clone the repository and then use the provided `requirements.txt` to install the dependencies:
94
+
95
+ ```shell
96
+ pip install -r requirements.txt
97
+ ```
98
+
99
+ ## Data preparation
100
+
101
+ ### ImageNet-1k
102
+
103
+ The root directory of the dataset should hold the following contents:
104
+
105
+ - `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
106
+ - `<ROOT>/test/[..]`
107
+ - `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
108
+ - `<ROOT>/train/n01440764/n01440764_10026.JPEG`
109
+ - `<ROOT>/train/[...]`
110
+ - `<ROOT>/train/n15075141/n15075141_9993.JPEG`
111
+ - `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
112
+ - `<ROOT>/val/[...]`
113
+ - `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
114
+ - `<ROOT>/labels.txt`
115
+
116
+ The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
117
+
118
+ - `<EXTRA>/class-ids-TRAIN.npy`
119
+ - `<EXTRA>/class-ids-VAL.npy`
120
+ - `<EXTRA>/class-names-TRAIN.npy`
121
+ - `<EXTRA>/class-names-VAL.npy`
122
+ - `<EXTRA>/entries-TEST.npy`
123
+ - `<EXTRA>/entries-TRAIN.npy`
124
+ - `<EXTRA>/entries-VAL.npy`
125
+
126
+ These metadata files can be generated (once) with the following lines of Python code:
127
+
128
+ ```python
129
+ from dinov2.data.datasets import ImageNet
130
+
131
+ for split in ImageNet.Split:
132
+ dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
133
+ dataset.dump_extra()
134
+ ```
135
+
136
+ Note that the root and extra directories do not have to be distinct directories.
137
+
138
+ ### ImageNet-22k
139
+
140
+ Please adapt the [dataset class](dinov2/data/datasets/image_net_22k.py) to match your local setup.
141
+
142
+ <br />
143
+
144
+ :warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
145
+
146
+ ## Training
147
+
148
+ ### Fast setup: training DINOv2 ViT-L/16 on ImageNet-1k
149
+
150
+ Run DINOv2 training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
151
+
152
+ ```shell
153
+ python dinov2/run/train/train.py \
154
+ --nodes 4 \
155
+ --config-file dinov2/configs/train/vitl16_short.yaml \
156
+ --output-dir <PATH/TO/OUTPUT/DIR> \
157
+ train.dataset_path=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
158
+ ```
159
+
160
+ Training time is approximately 1 day and the resulting checkpoint should reach 81.6% on k-NN eval and 82.9% on linear eval.
161
+
162
+ The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
163
+
164
+ ### Long setup: training DINOv2 ViT-L/14 on ImageNet-22k
165
+
166
+ Run DINOv2 training on 12 A100-80GB nodes (96 GPUs) in a SLURM cluster environment with submitit:
167
+
168
+ ```shell
169
+ python dinov2/run/train/train.py \
170
+ --nodes 12 \
171
+ --config-file dinov2/configs/train/vitl14.yaml \
172
+ --output-dir <PATH/TO/OUTPUT/DIR> \
173
+ train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
174
+ ```
175
+
176
+ Training time is approximately 3.3 days and the resulting checkpoint should reach 82.0% on k-NN eval and 84.5% on linear eval.
177
+
178
+ The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
179
+
180
+
181
+ ## Evaluation
182
+
183
+ The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
184
+
185
+ ### k-NN classification on ImageNet-1k
186
+
187
+ ```shell
188
+ python dinov2/run/eval/knn.py \
189
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
190
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
191
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/knn \
192
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
193
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
194
+ ```
195
+
196
+ ### Logistic regression classification on ImageNet-1k
197
+
198
+ ```shell
199
+ python dinov2/run/eval/log_regression.py \
200
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
201
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
202
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/logreg \
203
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
204
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
205
+ ```
206
+
207
+ ### Linear classification with data augmentation on ImageNet-1k
208
+
209
+ ```shell
210
+ python dinov2/run/eval/linear.py \
211
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
212
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
213
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/linear \
214
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
215
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
216
+ ```
217
+
218
+ We also release the weights of the linear heads obtained from these evaluations:
219
+
220
+ <table style="margin: auto">
221
+ <tr>
222
+ <th>model</th>
223
+ <th>ImageNet<br />top-1</th>
224
+ <th>linear evaluation</th>
225
+ </tr>
226
+ <tr>
227
+ <td>ViT-S/14 distilled</td>
228
+ <td align="right">81.1%</td>
229
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td>
230
+ </tr>
231
+ <tr>
232
+ <td>ViT-B/14 distilled</td>
233
+ <td align="right">84.5%</td>
234
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td>
235
+ </tr>
236
+ <tr>
237
+ <td>ViT-L/14 distilled</td>
238
+ <td align="right">86.3%</td>
239
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td>
240
+ </tr>
241
+ <tr>
242
+ <td>ViT-g/14</td>
243
+ <td align="right">86.5%</td>
244
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td>
245
+ </tr>
246
+ </table>
247
+
248
+ The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k:
249
+
250
+ ```shell
251
+ python dinov2/run/eval/linear.py \
252
+ --config-file dinov2/configs/eval/vitg14_pretrain.yaml \
253
+ --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth \
254
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
255
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
256
+ ```
257
+
258
+ ## License
259
+
260
+ DINOv2 code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details.
261
+
262
+ ## Contributing
263
+
264
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
265
+
266
+ ## Citing DINOv2
267
+
268
+ If you find this repository useful, please consider giving a star :star: and citation :t-rex::
269
+
270
+ ```
271
+ @misc{oquab2023dinov2,
272
+ title={DINOv2: Learning Robust Visual Features without Supervision},
273
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
274
+ journal={arXiv:2304.07193},
275
+ year={2023}
276
+ }
277
+ ```
torchhub/facebookresearch_dinov2_main/__pycache__/hubconf.cpython-310.pyc ADDED
Binary file (4.25 kB).
 
torchhub/facebookresearch_dinov2_main/__pycache__/vision_transformer.cpython-310.pyc ADDED
Binary file (11.9 kB).
 
torchhub/facebookresearch_dinov2_main/conda.yaml ADDED
@@ -0,0 +1,22 @@
1
+ name: dinov2
2
+ channels:
3
+ - defaults
4
+ - pytorch
5
+ - nvidia
6
+ - xformers
7
+ - conda-forge
8
+ dependencies:
9
+ - python=3.9
10
+ - pytorch::pytorch=2.0.0
11
+ - pytorch::pytorch-cuda=11.7.0
12
+ - pytorch::torchvision=0.15.0
13
+ - omegaconf
14
+ - torchmetrics=0.10.3
15
+ - fvcore
16
+ - iopath
17
+ - xformers::xformers=0.0.18
18
+ - pip
19
+ - pip:
20
+ - git+https://github.com/facebookincubator/submitit
21
+ - --extra-index-url https://pypi.nvidia.com
22
+ - cuml-cu11
torchhub/facebookresearch_dinov2_main/dinov2/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ __version__ = "0.0.1"
torchhub/facebookresearch_dinov2_main/dinov2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (202 Bytes).
 
torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import pathlib
8
+
9
+ from omegaconf import OmegaConf
10
+
11
+
12
+ def load_config(config_name: str):
13
+ config_filename = config_name + ".yaml"
14
+ return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
15
+
16
+
17
+ dinov2_default_config = load_config("ssl_default_config")
18
+
19
+
20
+ def load_and_merge_config(config_name: str):
21
+ default_config = OmegaConf.create(dinov2_default_config)
22
+ loaded_config = load_config(config_name)
23
+ return OmegaConf.merge(default_config, loaded_config)
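A hypothetical usage sketch for the helpers above; the config name and the printed fields are inferred from the eval YAML files added below rather than from documented API:

```python
from dinov2.configs import load_and_merge_config

# Config names are given relative to this configs/ directory, without the ".yaml" suffix.
cfg = load_and_merge_config("eval/vitg14_pretrain")
print(cfg.student.arch)             # vit_giant2
print(cfg.crops.global_crops_size)  # 518
```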
torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
1
+ student:
2
+ arch: vit_base
3
+ patch_size: 14
4
+ crops:
5
+ global_crops_size: 518 # this is to set up the position embeddings properly
6
+ local_crops_size: 98
torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml ADDED
@@ -0,0 +1,7 @@
1
+ student:
2
+ arch: vit_giant2
3
+ patch_size: 14
4
+ ffn_layer: swiglufused
5
+ crops:
6
+ global_crops_size: 518 # this is to set up the position embeddings properly
7
+ local_crops_size: 98