ReubenSun committed
Commit 2ac1c2d · 1 Parent(s): 13b826f
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +11 -0
  3. app.py +135 -0
  4. examples/images/000.png +3 -0
  5. examples/images/001.png +3 -0
  6. examples/images/004.png +3 -0
  7. examples/images/008.png +3 -0
  8. examples/images/028.png +3 -0
  9. examples/images/032.png +3 -0
  10. examples/images/061.png +3 -0
  11. examples/images/107.png +3 -0
  12. requirements.txt +50 -0
  13. step1x3d_geometry/__init__.py +52 -0
  14. step1x3d_geometry/data/Objaverse.py +73 -0
  15. step1x3d_geometry/data/__init__.py +1 -0
  16. step1x3d_geometry/data/base.py +350 -0
  17. step1x3d_geometry/models/__init__.py +1 -0
  18. step1x3d_geometry/models/attention.py +776 -0
  19. step1x3d_geometry/models/attention_processor.py +482 -0
  20. step1x3d_geometry/models/autoencoders/__init__.py +3 -0
  21. step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py +765 -0
  22. step1x3d_geometry/models/autoencoders/surface_extractors.py +137 -0
  23. step1x3d_geometry/models/autoencoders/transformers/attention.py +286 -0
  24. step1x3d_geometry/models/autoencoders/transformers/perceiver_1d.py +50 -0
  25. step1x3d_geometry/models/autoencoders/transformers/utils.py +21 -0
  26. step1x3d_geometry/models/autoencoders/volume_decoders.py +327 -0
  27. step1x3d_geometry/models/conditional_encoders/__init__.py +6 -0
  28. step1x3d_geometry/models/conditional_encoders/base.py +202 -0
  29. step1x3d_geometry/models/conditional_encoders/clip/modeling_clip.py +1597 -0
  30. step1x3d_geometry/models/conditional_encoders/clip/modeling_conditional_clip.py +443 -0
  31. step1x3d_geometry/models/conditional_encoders/dinov2/modeling_conditional_dinov2.py +248 -0
  32. step1x3d_geometry/models/conditional_encoders/dinov2/modeling_dinov2.py +978 -0
  33. step1x3d_geometry/models/conditional_encoders/dinov2_clip_encoder.py +514 -0
  34. step1x3d_geometry/models/conditional_encoders/dinov2_encoder.py +296 -0
  35. step1x3d_geometry/models/conditional_encoders/dinov2_with_registers/modeling_dinov2_with_registers.py +1088 -0
  36. step1x3d_geometry/models/conditional_encoders/label_encoder.py +167 -0
  37. step1x3d_geometry/models/conditional_encoders/t5_encoder.py +271 -0
  38. step1x3d_geometry/models/pipelines/pipeline.py +513 -0
  39. step1x3d_geometry/models/pipelines/pipeline_utils.py +404 -0
  40. step1x3d_geometry/models/transformers/__init__.py +1 -0
  41. step1x3d_geometry/models/transformers/flux_transformer_1d.py +600 -0
  42. step1x3d_geometry/models/transformers/pixart_transformer_1d.py +574 -0
  43. step1x3d_geometry/systems/__init__.py +1 -0
  44. step1x3d_geometry/systems/base.py +210 -0
  45. step1x3d_geometry/systems/shape_autoencoder.py +151 -0
  46. step1x3d_geometry/systems/shape_diffusion.py +425 -0
  47. step1x3d_geometry/systems/shape_rectified_flow.py +474 -0
  48. step1x3d_geometry/systems/utils.py +391 -0
  49. step1x3d_geometry/utils/__init__.py +1 -0
  50. step1x3d_geometry/utils/base.py +215 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
+output
+outputs
+**__pycache__
+.DS_Store
+cache
+step1x3d_texture/custom_rasterizer/build
+step1x3d_texture/custom_rasterizer/dist
+step1x3d_texture/custom_rasterizer/custom_rasterizer.egg-info
+step1x3d_texture/differentiable_renderer/build
+step1x3d_texture/differentiable_renderer/dist
+step1x3d_texture/differentiable_renderer/mesh_processor.egg-info
app.py ADDED
@@ -0,0 +1,135 @@
+import os
+import time
+import uuid
+import torch
+import trimesh
+import argparse
+import numpy as np
+import gradio as gr
+from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline
+from step1x3d_texture.pipelines.step1x_3d_texture_synthesis_pipeline import (
+    Step1X3DTexturePipeline,
+)
+from step1x3d_texture.utils.shape_post_process import (
+    FaceReducer,
+    DegenerateFaceRemover,
+)
+
+
+def generate_func(
+    input_image_path, guidance_scale, inference_steps, max_facenum, symmetry, edge_type
+):
+    if "Label" in args.geometry_model:
+        out = geometry_model(
+            input_image_path,
+            label={"symmetry": symmetry, "edge_type": edge_type},
+            guidance_scale=float(guidance_scale),
+            octree_resolution=384,
+            max_facenum=int(max_facenum),
+            num_inference_steps=int(inference_steps),
+        )
+    else:
+        out = geometry_model(
+            input_image_path,
+            guidance_scale=float(guidance_scale),
+            num_inference_steps=int(inference_steps),
+            max_facenum=int(max_facenum),
+        )
+
+    save_name = str(uuid.uuid4())
+    print(save_name)
+    geometry_save_path = f"{args.cache_dir}/{save_name}.glb"
+    geometry_mesh = out.mesh[0]
+    geometry_mesh.export(geometry_save_path)
+
+    geometry_mesh = DegenerateFaceRemover()(geometry_mesh)
+    geometry_mesh = FaceReducer()(geometry_mesh)
+    textured_mesh = texture_model(input_image_path, geometry_mesh)
+    textured_save_path = f"{args.cache_dir}/{save_name}-textured.glb"
+    textured_mesh.export(textured_save_path)
+
+    torch.cuda.empty_cache()
+    print("Generate finished")
+    return geometry_save_path, textured_save_path
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--geometry_model", type=str, default="Step1X-3D-Geometry-Label-1300m"
+    )
+    parser.add_argument("--texture_model", type=str, default="Step1X-3D-Texture")
+    parser.add_argument("--cache_dir", type=str, default="cache")
+    parser.add_argument("--port", type=int, default=7861)
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    args = parser.parse_args()
+
+    os.makedirs(args.cache_dir, exist_ok=True)
+
+    geometry_model = Step1X3DGeometryPipeline.from_pretrained(
+        "stepfun-ai/Step1X-3D", subfolder=args.geometry_model
+    ).to("cuda")
+
+    texture_model = Step1X3DTexturePipeline.from_pretrained(
+        "stepfun-ai/Step1X-3D", subfolder=args.texture_model
+    )
+
+    with gr.Blocks(title="Step1X-3D demo") as demo:
+        gr.Markdown("# Step1X-3D")
+        with gr.Row():
+            with gr.Column(scale=2):
+                input_image = gr.Image(
+                    label="Image", type="filepath", image_mode="RGBA"
+                )
+                guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
+                inference_steps = gr.Slider(
+                    label="Inference Steps", minimum=1, maximum=100, value=50
+                )
+                max_facenum = gr.Number(label="Max Face Num", value=400000)
+                symmetry = gr.Radio(
+                    choices=["x", "asymmetry"],
+                    label="Symmetry Type",
+                    value="x",
+                    type="value",
+                )
+                edge_type = gr.Radio(
+                    choices=["sharp", "normal", "smooth"],
+                    label="Edge Type",
+                    value="sharp",
+                    type="value",
+                )
+                btn = gr.Button("Start")
+            with gr.Column(scale=4):
+                textured_preview = gr.Model3D(label="Textured", height=380)
+                geometry_preview = gr.Model3D(label="Geometry", height=380)
+            with gr.Column(scale=1):
+                gr.Examples(
+                    examples=[
+                        ["examples/images/000.png"],
+                        ["examples/images/001.png"],
+                        ["examples/images/004.png"],
+                        ["examples/images/008.png"],
+                        ["examples/images/028.png"],
+                        ["examples/images/032.png"],
+                        ["examples/images/061.png"],
+                        ["examples/images/107.png"],
+                    ],
+                    inputs=[input_image],
+                    cache_examples=False,
+                )
+
+        btn.click(
+            generate_func,
+            inputs=[
+                input_image,
+                guidance_scale,
+                inference_steps,
+                max_facenum,
+                symmetry,
+                edge_type,
+            ],
+            outputs=[geometry_preview, textured_preview],
+        )
+
+    # Configure the queue before launching; Gradio 4+ replaced the removed
+    # `concurrency_count` argument with `default_concurrency_limit`.
+    demo.queue(default_concurrency_limit=3)
+    demo.launch(server_name=args.host, server_port=args.port)
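For reference, the two-stage flow wired into the UI above can also be driven directly from Python. The sketch below mirrors `generate_func` using the default checkpoints from this file; the input image and output path are placeholders, not additional APIs.

import torch
from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline
from step1x3d_texture.pipelines.step1x_3d_texture_synthesis_pipeline import Step1X3DTexturePipeline
from step1x3d_texture.utils.shape_post_process import DegenerateFaceRemover, FaceReducer

geometry = Step1X3DGeometryPipeline.from_pretrained(
    "stepfun-ai/Step1X-3D", subfolder="Step1X-3D-Geometry-Label-1300m"
).to("cuda")
texture = Step1X3DTexturePipeline.from_pretrained("stepfun-ai/Step1X-3D", subfolder="Step1X-3D-Texture")

# Stage 1: image -> untextured mesh (label-conditioned variant, as in generate_func)
out = geometry(
    "examples/images/000.png",
    label={"symmetry": "x", "edge_type": "sharp"},
    guidance_scale=7.5,
    octree_resolution=384,
    max_facenum=400000,
    num_inference_steps=50,
)
# Stage 2: clean up the mesh, then synthesize texture
mesh = FaceReducer()(DegenerateFaceRemover()(out.mesh[0]))
textured = texture("examples/images/000.png", mesh)
textured.export("textured.glb")
torch.cuda.empty_cache()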
examples/images/000.png ADDED

Git LFS Details

  • SHA256: 62284b41c010dd81524c51d12da4369fc458abd955011f59ce395266a02efb5f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.54 MB
examples/images/001.png ADDED

Git LFS Details

  • SHA256: e93cc2c9850b6ea7cf233ae2f8d96246d86de7fc1d9bf079f2455a47938e946a
  • Pointer size: 131 Bytes
  • Size of remote file: 608 kB
examples/images/004.png ADDED

Git LFS Details

  • SHA256: 19aa7e05ca0cb1eb4e7809eeded332cce8c21daf9e5458338b6ad3bfbba85679
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
examples/images/008.png ADDED

Git LFS Details

  • SHA256: 67cf8e33b715641599c5489f06f6c5d1da312faf3c95196395d9d81a1aa112e1
  • Pointer size: 131 Bytes
  • Size of remote file: 367 kB
examples/images/028.png ADDED

Git LFS Details

  • SHA256: b12c3b18f615fb5c887bfbd946c69eff8934519182ee5ef13f3853ca64e0bc22
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
examples/images/032.png ADDED

Git LFS Details

  • SHA256: 7f655fc199fed98a8d663e6e39baa94307af3e9494efa6389ac5b90c81b45b18
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
examples/images/061.png ADDED

Git LFS Details

  • SHA256: e28ffd293ba94f8d92c7bef7db7125d6df5e05287f116d6f93617623aa5d7ecf
  • Pointer size: 131 Bytes
  • Size of remote file: 307 kB
examples/images/107.png ADDED

Git LFS Details

  • SHA256: 70c7d618bfd70125d0b61007e549f3369273b1de866b30c703a68045bceb8950
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
requirements.txt ADDED
@@ -0,0 +1,50 @@
+datasets==2.19.0
+diffusers==0.32.2
+einops==0.8.0
+huggingface-hub==0.26.2
+imageio==2.34.1
+jaxtyping==0.2.28
+joblib==1.4.0
+lightning-utilities==0.11.2
+matplotlib==3.8.4
+numpy==1.26.4
+omegaconf==2.3.0
+opencv-python-headless==4.10.0.84
+pandas==2.2.2
+pillow==10.3.0
+plyfile==1.0.3
+PyMCubes==0.1.4
+pyparsing==3.1.2
+pytorch-lightning==2.2.4
+PyYAML==6.0.1
+safetensors==0.4.3
+scikit-image==0.23.2
+scipy==1.13.0
+tensorboard==2.16.2
+tensorboardX==2.6.2.2
+timm==0.9.16
+tokenizers==0.21.0
+tqdm==4.66.2
+transformers==4.48.0
+trimesh==4.3.2
+spaces==0.28.3
+accelerate==1.5.2
+rembg==2.0.65
+gradio==5.5.0
+wandb==0.18.6
+deepspeed==0.16.4
+sageattention==1.0.6
+mosaicml-streaming==0.11.0
+easydict==1.13
+open3d==0.19.0
+prodigyopt==1.1.2
+peft==0.15.1
+sentencepiece==0.2.0
+pymeshlab==2023.12.post3
+onnxruntime==1.21.0
+bs4==0.0.2
+xatlas==0.0.10
+pybind11==2.13.6
+pygltflib==1.16.4
+kornia==0.8.0
+git+https://github.com/NVlabs/nvdiffrast.git
step1x3d_geometry/__init__.py ADDED
@@ -0,0 +1,52 @@
+import importlib
+
+__modules__ = {}
+
+
+def register(name):
+    def decorator(cls):
+        if name in __modules__:
+            raise ValueError(
+                f"Module {name} already exists! Names of extensions conflict!"
+            )
+        else:
+            __modules__[name] = cls
+        return cls
+
+    return decorator
+
+
+def find(name):
+    if name in __modules__:
+        return __modules__[name]
+    else:
+        try:
+            module_string = ".".join(name.split(".")[:-1])
+            cls_name = name.split(".")[-1]
+            module = importlib.import_module(module_string, package=None)
+            return getattr(module, cls_name)
+        except Exception as e:
+            raise ValueError(f"Module {name} not found!") from e
+
+
+### syntactic sugar for logging utilities ###
+import logging
+
+logger = logging.getLogger("pytorch_lightning")
+
+from pytorch_lightning.utilities.rank_zero import (
+    rank_zero_debug,
+    rank_zero_info,
+    rank_zero_only,
+)
+
+debug = rank_zero_debug
+info = rank_zero_info
+
+
+@rank_zero_only
+def warn(*args, **kwargs):
+    logger.warning(*args, **kwargs)
+
+
+from . import data, models, systems
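The `register`/`find` pair above implements a simple name-to-class registry with a dotted-path fallback. A minimal usage sketch (the registered name and class are hypothetical):

from step1x3d_geometry import register, find

@register("toy-datamodule")  # stored in __modules__ under this name
class ToyDataModule:
    pass

cls = find("toy-datamodule")                      # returns ToyDataModule from the registry
loader_cls = find("torch.utils.data.DataLoader")  # not registered, so resolved as a dotted import path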
step1x3d_geometry/data/Objaverse.py ADDED
@@ -0,0 +1,73 @@
+import math
+import os
+import json
+import re
+import cv2
+from dataclasses import dataclass, field
+
+import pytorch_lightning as pl
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from step1x3d_geometry import register
+from step1x3d_geometry.utils.typing import *
+from step1x3d_geometry.utils.config import parse_structured
+
+from streaming import StreamingDataLoader
+from .base import BaseDataModuleConfig, BaseDataset
+
+
+@dataclass
+class ObjaverseDataModuleConfig(BaseDataModuleConfig):
+    pass
+
+
+class ObjaverseDataset(BaseDataset):
+    pass
+
+
+@register("Objaverse-datamodule")
+class ObjaverseDataModule(pl.LightningDataModule):
+    cfg: ObjaverseDataModuleConfig
+
+    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
+        super().__init__()
+        self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg)
+
+    def setup(self, stage=None) -> None:
+        if stage in [None, "fit"]:
+            self.train_dataset = ObjaverseDataset(self.cfg, "train")
+        if stage in [None, "fit", "validate"]:
+            self.val_dataset = ObjaverseDataset(self.cfg, "val")
+        if stage in [None, "test", "predict"]:
+            self.test_dataset = ObjaverseDataset(self.cfg, "test")
+
+    def prepare_data(self):
+        pass
+
+    def general_loader(
+        self, dataset, batch_size, collate_fn=None, num_workers=0
+    ) -> DataLoader:
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        return self.general_loader(
+            self.train_dataset,
+            batch_size=self.cfg.batch_size,
+            collate_fn=self.train_dataset.collate,
+            num_workers=self.cfg.num_workers,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        return self.general_loader(self.val_dataset, batch_size=1)
+
+    def test_dataloader(self) -> DataLoader:
+        return self.general_loader(self.test_dataset, batch_size=1)
+
+    def predict_dataloader(self) -> DataLoader:
+        return self.general_loader(self.test_dataset, batch_size=1)
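A minimal sketch of how this datamodule is typically driven; the config keys come from `BaseDataModuleConfig` in `base.py`, and the `root_dir` path is a placeholder.

from step1x3d_geometry.data.Objaverse import ObjaverseDataModule

dm = ObjaverseDataModule(
    {
        "root_dir": "data/objaverse",  # expects {root_dir}/train.json, val.json, test.json
        "batch_size": 4,
        "num_workers": 8,
        "load_geometry": True,
        "geo_data_type": "sdf",
    }
)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))  # dict with "uid", "surface", ...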
step1x3d_geometry/data/__init__.py ADDED
@@ -0,0 +1 @@
+from . import Objaverse
step1x3d_geometry/data/base.py ADDED
@@ -0,0 +1,350 @@
+import math
+import os
+import json
+import re
+import cv2
+from dataclasses import dataclass, field
+
+import random
+import imageio
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader, Dataset
+from PIL import Image
+
+from step1x3d_geometry.utils.typing import *
+
+
+@dataclass
+class BaseDataModuleConfig:
+    root_dir: str = None
+    batch_size: int = 4
+    num_workers: int = 8
+
+    ################################# General augmentation #################################
+    random_flip: bool = (
+        False  # whether to randomly flip the input point cloud and the input images
+    )
+
+    ################################# Geometry part #################################
+    load_geometry: bool = True  # whether to load geometry data
+    with_sharp_data: bool = False
+    geo_data_type: str = "sdf"  # occupancy, sdf
+    # for occupancy or sdf supervision
+    n_samples: int = 4096  # number of points in input point cloud
+    upsample_ratio: int = 1  # upsample ratio for input point cloud
+    sampling_strategy: Optional[str] = (
+        "random"  # sampling strategy for input point cloud
+    )
+    scale: float = 1.0  # scale of the input point cloud and target supervision
+    noise_sigma: float = 0.0  # noise level of the input point cloud
+    rotate_points: bool = (
+        False  # whether to rotate the input point cloud and the supervision, for VAE aug.
+    )
+    load_geometry_supervision: bool = False  # whether to load supervision
+    supervision_type: str = "sdf"  # occupancy, sdf, tsdf, tsdf_w_surface
+    n_supervision: int = 10000  # number of points in supervision
+    tsdf_threshold: float = (
+        0.01  # threshold for truncating sdf values, used when input is sdf
+    )
+
+    ################################# Image part #################################
+    load_image: bool = False  # whether to load images
+    image_type: str = "rgb"  # rgb, normal, rgb_or_normal
+    image_file_type: str = "png"  # png, jpeg
+    image_type_ratio: float = (
+        1.0  # ratio of rgb for each dataset when image_type is "rgb_or_normal"
+    )
+    crop_image: bool = True  # whether to crop the input image
+    random_color_jitter: bool = (
+        False  # whether to randomly color jitter the input images
+    )
+    random_rotate: bool = (
+        False  # whether to randomly rotate the input images, default [-10 deg, 10 deg]
+    )
+    random_mask: bool = False  # whether to add random mask to the input image
+    background_color: Tuple[int, int, int] = field(
+        default_factory=lambda: (255, 255, 255)
+    )
+    idx: Optional[List[int]] = None  # index of the image to load
+    n_views: int = 1  # number of views
+    foreground_ratio: Optional[float] = 0.90
+
+    ################################# Caption part #################################
+    load_caption: bool = False  # whether to load captions
+    load_label: bool = False  # whether to load labels
+
+
+class BaseDataset(Dataset):
+    def __init__(self, cfg: Any, split: str) -> None:
+        super().__init__()
+        self.cfg: BaseDataModuleConfig = cfg
+        self.split = split
+
+        self.uids = json.load(open(f"{cfg.root_dir}/{split}.json"))
+        print(f"Loaded {len(self.uids)} {split} uids")
+
+        # add ColorJitter transforms for input images
+        if self.cfg.random_color_jitter:
+            self.color_jitter = transforms.ColorJitter(
+                brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
+            )
+
+        # add RandomRotation transforms for input images
+        if self.cfg.random_rotate:
+            self.rotate = transforms.RandomRotation(
+                degrees=10, fill=(*self.cfg.background_color, 0.0)
+            )  # by default 10 deg
+
+    def __len__(self):
+        return len(self.uids)
+
+    def _load_shape_from_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
+        if self.cfg.geo_data_type == "sdf":
+            data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz")
+            # for input point cloud
+            surface = data["surface"]
+            if self.cfg.with_sharp_data:
+                sharp_surface = data["sharp_surface"]
+        else:
+            raise NotImplementedError(
+                f"Data type {self.cfg.geo_data_type} not implemented"
+            )
+
+        # random sampling
+        if self.cfg.sampling_strategy == "random":
+            rng = np.random.default_rng()
+            ind = rng.choice(
+                surface.shape[0],
+                self.cfg.upsample_ratio * self.cfg.n_samples,
+                replace=True,
+            )
+            surface = surface[ind]
+            if self.cfg.with_sharp_data:
+                sharp_surface = sharp_surface[ind]
+        elif self.cfg.sampling_strategy == "fps":
+            import fpsample
+
+            kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
+                surface[:, :3], self.cfg.n_samples, h=5
+            )
+            surface = surface[kdline_fps_samples_idx]
+            if self.cfg.with_sharp_data:
+                kdline_fps_samples_idx = fpsample.bucket_fps_kdline_sampling(
+                    sharp_surface[:, :3], self.cfg.n_samples, h=5
+                )
+                sharp_surface = sharp_surface[kdline_fps_samples_idx]
+        else:
+            raise NotImplementedError(
+                f"sampling strategy {self.cfg.sampling_strategy} not implemented"
+            )
+
+        # rescale data
+        surface[:, :3] = surface[:, :3] * self.cfg.scale  # target scale
+        if self.cfg.with_sharp_data:
+            sharp_surface[:, :3] = sharp_surface[:, :3] * self.cfg.scale  # target scale
+            ret = {
+                "uid": self.uids[index].split("/")[-1],
+                "surface": surface.astype(np.float32),
+                "sharp_surface": sharp_surface.astype(np.float32),
+            }
+        else:
+            ret = {
+                "uid": self.uids[index].split("/")[-1],
+                "surface": surface.astype(np.float32),
+            }
+
+        return ret
+
+    def _load_shape_supervision_occupancy_or_sdf(self, index: int) -> Dict[str, Any]:
+        # for supervision
+        ret = {}
+        if self.cfg.geo_data_type == "sdf":
+            data = np.load(f"{self.cfg.root_dir}/surfaces/{self.uids[index]}.npz")
+            data = np.concatenate(
+                [data["volume_rand_points"], data["near_surface_points"]], axis=0
+            )
+            rand_points, sdfs = data[:, :3], data[:, 3:]
+        else:
+            raise NotImplementedError(
+                f"Data type {self.cfg.geo_data_type} not implemented"
+            )
+
+        # random sampling
+        rng = np.random.default_rng()
+        ind = rng.choice(rand_points.shape[0], self.cfg.n_supervision, replace=False)
+        rand_points = rand_points[ind]
+        rand_points = rand_points * self.cfg.scale
+        ret["rand_points"] = rand_points.astype(np.float32)
+
+        if self.cfg.geo_data_type == "sdf":
+            if self.cfg.supervision_type == "sdf":
+                ret["sdf"] = sdfs[ind].flatten().astype(np.float32)
+            elif self.cfg.supervision_type == "occupancy":
+                ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(
+                    np.float32
+                )
+            elif self.cfg.supervision_type == "tsdf":
+                ret["sdf"] = (
+                    sdfs[ind]
+                    .flatten()
+                    .astype(np.float32)
+                    .clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold)
+                    / self.cfg.tsdf_threshold
+                )
+            else:
+                raise NotImplementedError(
+                    f"Supervision type {self.cfg.supervision_type} not implemented"
+                )
+
+        return ret
+
+    def _load_image(self, index: int) -> Dict[str, Any]:
+        def _process_img(image, background_color=(255, 255, 255), foreground_ratio=0.9):
+            alpha = image.getchannel("A")
+            background = Image.new("RGBA", image.size, (*background_color, 255))
+            image = Image.alpha_composite(background, image)
+            image = image.crop(alpha.getbbox())
+
+            new_size = tuple(int(dim * foreground_ratio) for dim in image.size)
+            resized_image = image.resize(new_size)
+            padded_image = Image.new("RGBA", image.size, (*background_color, 255))
+            paste_position = (
+                (image.width - resized_image.width) // 2,
+                (image.height - resized_image.height) // 2,
+            )
+            padded_image.paste(resized_image, paste_position)
+
+            # Expand image to 1:1
+            max_dim = max(padded_image.size)
+            image = Image.new("RGBA", (max_dim, max_dim), (*background_color, 255))
+            paste_position = (
+                (max_dim - padded_image.width) // 2,
+                (max_dim - padded_image.height) // 2,
+            )
+            image.paste(padded_image, paste_position)
+            image = image.resize((512, 512))
+            return image.convert("RGB"), alpha
+
+        ret = {}
+        if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal":
+            assert (
+                self.cfg.n_views == 1
+            ), "Only single view is supported for single image"
+            sel_idx = random.choice(self.cfg.idx)
+            ret["sel_image_idx"] = sel_idx
+            if self.cfg.image_type == "rgb":
+                img_path = (
+                    f"{self.cfg.root_dir}/images/"
+                    + "/".join(self.uids[index].split("/")[-2:])
+                    + f"/{'{:04d}'.format(sel_idx)}_rgb.{self.cfg.image_file_type}"
+                )
+            elif self.cfg.image_type == "normal":
+                img_path = (
+                    f"{self.cfg.root_dir}/images/"
+                    + "/".join(self.uids[index].split("/")[-2:])
+                    + f"/{'{:04d}'.format(sel_idx)}_normal.{self.cfg.image_file_type}"
+                )
+            image = Image.open(img_path).copy()
+
+            # add random color jitter
+            if self.cfg.random_color_jitter:
+                rgb = self.color_jitter(image.convert("RGB"))
+                image = Image.merge("RGBA", (*rgb.split(), image.getchannel("A")))
+
+            # add random rotation
+            if self.cfg.random_rotate:
+                image = self.rotate(image)
+
+            # background color (defined here so both the crop and no-crop branches can use it)
+            background_color = (
+                torch.randint(0, 256, (3,))
+                if self.cfg.background_color is None
+                else torch.as_tensor(self.cfg.background_color)
+            )
+
+            # add crop
+            if self.cfg.crop_image:
+                image, alpha = _process_img(
+                    image, background_color, self.cfg.foreground_ratio
+                )
+            else:
+                alpha = image.getchannel("A")
+                background = Image.new("RGBA", image.size, background_color)
+                image = Image.alpha_composite(background, image).convert("RGB")
+
+            ret["image"] = torch.from_numpy(np.array(image) / 255.0)
+            ret["mask"] = torch.from_numpy(np.array(alpha) / 255.0).unsqueeze(0)
+        else:
+            raise NotImplementedError(
+                f"Image type {self.cfg.image_type} not implemented"
+            )
+
+        return ret
+
+    def _get_data(self, index):
+        ret = {"uid": self.uids[index]}
+
+        # random flip
+        flip = np.random.rand() < 0.5 if self.cfg.random_flip else False
+
+        # load geometry
+        if self.cfg.load_geometry:
+            if self.cfg.geo_data_type == "occupancy" or self.cfg.geo_data_type == "sdf":
+                # load shape
+                ret = self._load_shape_from_occupancy_or_sdf(index)
+                # load supervision for shape
+                if self.cfg.load_geometry_supervision:
+                    ret.update(self._load_shape_supervision_occupancy_or_sdf(index))
+            else:
+                raise NotImplementedError(
+                    f"Geo data type {self.cfg.geo_data_type} not implemented"
+                )
+
+            if flip:  # random flip the input point cloud and the supervision
+                for key in ret.keys():
+                    if key in ["surface", "sharp_surface"]:  # N x (xyz + normal)
+                        ret[key][:, 0] = -ret[key][:, 0]
+                        ret[key][:, 3] = -ret[key][:, 3]
+                    elif key in ["rand_points"]:
+                        ret[key][:, 0] = -ret[key][:, 0]
+
+        # load image
+        if self.cfg.load_image:
+            ret.update(self._load_image(index))
+            if flip:  # random flip the input image
+                for key in ret.keys():
+                    if key in ["image"]:  # random flip the input image
+                        ret[key] = torch.flip(ret[key], [2])
+                    if key in ["mask"]:  # random flip the input image
+                        ret[key] = torch.flip(ret[key], [2])
+
+        # load caption
+        meta = None
+        if self.cfg.load_caption:
+            with open(f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r") as f:
+                meta = json.load(f)
+            ret.update({"caption": meta["caption"]})
+
+        # load label
+        if self.cfg.load_label:
+            if meta is None:
+                with open(
+                    f"{self.cfg.root_dir}/metas/{self.uids[index]}.json", "r"
+                ) as f:
+                    meta = json.load(f)
+            ret.update({"label": [meta["label"]]})
+
+        return ret
+
+    def __getitem__(self, index):
+        try:
+            return self._get_data(index)
+        except Exception as e:
+            print(f"Error in {self.uids[index]}: {e}")
+            return self.__getitem__(np.random.randint(len(self)))
+
+    def collate(self, batch):
+        from torch.utils.data._utils.collate import default_collate_fn_map
+
+        return torch.utils.data.default_collate(batch)
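As a worked example of the `tsdf` supervision branch above: raw SDF values are clipped to ±`tsdf_threshold` and then divided by the threshold, normalizing them to [-1, 1].

import numpy as np

tsdf_threshold = 0.01
sdfs = np.array([-0.05, -0.005, 0.0, 0.02], dtype=np.float32)
tsdf = sdfs.clip(-tsdf_threshold, tsdf_threshold) / tsdf_threshold
# -> [-1.0, -0.5, 0.0, 1.0]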
step1x3d_geometry/models/__init__.py ADDED
@@ -0,0 +1 @@
+from . import autoencoders, conditional_encoders, transformers
step1x3d_geometry/models/attention.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple, Union
15
+ import collections.abc
16
+ from itertools import repeat
17
+
18
+ import torch
19
+ from torch import nn
20
+ import torch.nn.functional as F
21
+ import torch.distributed as dist
22
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
23
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
24
+ from diffusers.models.attention import FeedForward
25
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
26
+ from diffusers.models.normalization import (
27
+ AdaLayerNormContinuous,
28
+ AdaLayerNormZero,
29
+ AdaLayerNormZeroSingle,
30
+ FP32LayerNorm,
31
+ LayerNorm,
32
+ )
33
+
34
+ from .attention_processor import FluxAttnProcessor2_0, AttnProcessor2_0
35
+
36
+
37
+ @maybe_allow_in_graph
38
+ class MultiCondBasicTransformerBlock(nn.Module):
39
+ r"""
40
+ A basic Transformer block.
41
+
42
+ Parameters:
43
+ dim (`int`): The number of channels in the input and output.
44
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`): The number of channels in each head.
46
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
47
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
48
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
49
+ num_embeds_ada_norm (:
50
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
51
+ attention_bias (:
52
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
53
+ only_cross_attention (`bool`, *optional*):
54
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
55
+ double_self_attention (`bool`, *optional*):
56
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
57
+ upcast_attention (`bool`, *optional*):
58
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
59
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
60
+ Whether to use learnable elementwise affine parameters for normalization.
61
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
62
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
63
+ final_dropout (`bool` *optional*, defaults to False):
64
+ Whether to apply a final dropout after the last feed-forward layer.
65
+ attention_type (`str`, *optional*, defaults to `"default"`):
66
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
67
+ positional_embeddings (`str`, *optional*, defaults to `None`):
68
+ The type of positional embeddings to apply to.
69
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
70
+ The maximum number of positional embeddings to apply.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ dim: int,
76
+ num_attention_heads: int,
77
+ use_self_attention: bool = True,
78
+ use_cross_attention: bool = False,
79
+ self_attention_norm_type: Optional[
80
+ str
81
+ ] = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
82
+ cross_attention_dim: Optional[int] = None,
83
+ cross_attention_norm_type: Optional[str] = None,
84
+ # parallel second cross attention
85
+ use_cross_attention_2: bool = False,
86
+ cross_attention_2_dim: Optional[int] = None,
87
+ cross_attention_2_norm_type: Optional[str] = None,
88
+ # parallel third cross attention
89
+ use_cross_attention_3: bool = False,
90
+ cross_attention_3_dim: Optional[int] = None,
91
+ cross_attention_3_norm_type: Optional[str] = None,
92
+ dropout=0.0,
93
+ activation_fn: str = "geglu",
94
+ num_embeds_ada_norm: Optional[int] = None,
95
+ attention_bias: bool = False,
96
+ only_cross_attention: bool = False,
97
+ double_self_attention: bool = False,
98
+ upcast_attention: bool = False,
99
+ norm_elementwise_affine: bool = True,
100
+ norm_eps: float = 1e-5,
101
+ final_dropout: bool = False,
102
+ attention_type: str = "default",
103
+ positional_embeddings: Optional[str] = None,
104
+ num_positional_embeddings: Optional[int] = None,
105
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
106
+ ada_norm_bias: Optional[int] = None,
107
+ ff_inner_dim: Optional[int] = None,
108
+ ff_bias: bool = True,
109
+ attention_out_bias: bool = True,
110
+ ):
111
+ super().__init__()
112
+ self.dim = dim
113
+ self.num_attention_heads = num_attention_heads
114
+ self.use_self_attention = use_self_attention
115
+ self.use_cross_attention = use_cross_attention
116
+ self.self_attention_norm_type = self_attention_norm_type
117
+ self.cross_attention_dim = cross_attention_dim
118
+ self.cross_attention_norm_type = cross_attention_norm_type
119
+ self.use_cross_attention_2 = use_cross_attention_2
120
+ self.cross_attention_2_dim = cross_attention_2_dim
121
+ self.cross_attention_2_norm_type = cross_attention_2_norm_type
122
+ self.use_cross_attention_3 = use_cross_attention_3
123
+ self.cross_attention_3_dim = cross_attention_3_dim
124
+ self.cross_attention_3_norm_type = cross_attention_3_norm_type
125
+ self.dropout = dropout
126
+ self.cross_attention_dim = cross_attention_dim
127
+ self.activation_fn = activation_fn
128
+ self.attention_bias = attention_bias
129
+ self.double_self_attention = double_self_attention
130
+ self.norm_elementwise_affine = norm_elementwise_affine
131
+ self.positional_embeddings = positional_embeddings
132
+ self.num_positional_embeddings = num_positional_embeddings
133
+ self.only_cross_attention = only_cross_attention
134
+
135
+ # We keep these boolean flags for backward-compatibility.
136
+ self.use_ada_layer_norm_zero = (
137
+ num_embeds_ada_norm is not None
138
+ ) and self_attention_norm_type == "ada_norm_zero"
139
+ self.use_ada_layer_norm = (
140
+ num_embeds_ada_norm is not None
141
+ ) and self_attention_norm_type == "ada_norm"
142
+ self.use_ada_layer_norm_single = self_attention_norm_type == "ada_norm_single"
143
+ self.use_layer_norm = self_attention_norm_type == "layer_norm"
144
+ self.use_ada_layer_norm_continuous = (
145
+ self_attention_norm_type == "ada_norm_continuous"
146
+ )
147
+
148
+ if (
149
+ self_attention_norm_type in ("ada_norm", "ada_norm_zero")
150
+ and num_embeds_ada_norm is None
151
+ ):
152
+ raise ValueError(
153
+ f"`self_attention_norm_type` is set to {self_attention_norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
154
+ f" define `num_embeds_ada_norm` if setting `self_attention_norm_type` to {self_attention_norm_type}."
155
+ )
156
+
157
+ self.self_attention_norm_type = self_attention_norm_type
158
+ self.num_embeds_ada_norm = num_embeds_ada_norm
159
+
160
+ if positional_embeddings and (num_positional_embeddings is None):
161
+ raise ValueError(
162
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
163
+ )
164
+
165
+ if positional_embeddings == "sinusoidal":
166
+ self.pos_embed = SinusoidalPositionalEmbedding(
167
+ dim, max_seq_length=num_positional_embeddings
168
+ )
169
+ else:
170
+ self.pos_embed = None
171
+
172
+ # Define 3 blocks. Each block has its own normalization layer.
173
+ if use_self_attention:
174
+ # 1. Self-Attn
175
+ if self_attention_norm_type == "ada_norm":
176
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
177
+ elif self_attention_norm_type == "ada_norm_zero":
178
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
179
+ elif self_attention_norm_type == "ada_norm_continuous":
180
+ self.norm1 = AdaLayerNormContinuous(
181
+ dim,
182
+ ada_norm_continous_conditioning_embedding_dim,
183
+ norm_elementwise_affine,
184
+ norm_eps,
185
+ ada_norm_bias,
186
+ "rms_norm",
187
+ )
188
+ elif (
189
+ self_attention_norm_type == "fp32_layer_norm"
190
+ or self_attention_norm_type is None
191
+ ):
192
+ self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
193
+ else:
194
+ self.norm1 = nn.RMSNorm(
195
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
196
+ )
197
+
198
+ self.attn1 = Attention(
199
+ query_dim=dim,
200
+ heads=num_attention_heads,
201
+ dim_head=dim // num_attention_heads,
202
+ dropout=dropout,
203
+ bias=attention_bias,
204
+ cross_attention_dim=(
205
+ cross_attention_dim if only_cross_attention else None
206
+ ),
207
+ upcast_attention=upcast_attention,
208
+ out_bias=attention_out_bias,
209
+ processor=AttnProcessor2_0(),
210
+ )
211
+
212
+ # 2. Cross-Attn
213
+ if use_cross_attention or double_self_attention:
214
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
215
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
216
+ # the second cross attention block.
217
+ if cross_attention_norm_type == "ada_norm":
218
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
219
+ elif cross_attention_norm_type == "ada_norm_continuous":
220
+ self.norm2 = AdaLayerNormContinuous(
221
+ dim,
222
+ ada_norm_continous_conditioning_embedding_dim,
223
+ norm_elementwise_affine,
224
+ norm_eps,
225
+ ada_norm_bias,
226
+ "rms_norm",
227
+ )
228
+ elif (
229
+ cross_attention_norm_type == "fp32_layer_norm"
230
+ or cross_attention_norm_type is None
231
+ ):
232
+ self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
233
+ else:
234
+ self.norm2 = nn.RMSNorm(
235
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
236
+ )
237
+
238
+ self.attn2 = Attention(
239
+ query_dim=dim,
240
+ cross_attention_dim=(
241
+ cross_attention_dim if not double_self_attention else None
242
+ ),
243
+ heads=num_attention_heads,
244
+ dim_head=dim // num_attention_heads,
245
+ dropout=dropout,
246
+ bias=attention_bias,
247
+ upcast_attention=upcast_attention,
248
+ out_bias=attention_out_bias,
249
+ processor=AttnProcessor2_0(),
250
+ ) # is self-attn if encoder_hidden_states is none
251
+ else:
252
+ self.norm2 = None
253
+ self.attn2 = None
254
+
255
+ # 2'. Parallel Second Cross-Attn
256
+ if use_cross_attention_2:
257
+ assert cross_attention_2_dim is not None
258
+ if cross_attention_2_norm_type == "ada_norm":
259
+ self.norm2_2 = AdaLayerNorm(dim, num_embeds_ada_norm)
260
+ elif cross_attention_2_norm_type == "ada_norm_continuous":
261
+ self.norm2_2 = AdaLayerNormContinuous(
262
+ dim,
263
+ ada_norm_continous_conditioning_embedding_dim,
264
+ norm_elementwise_affine,
265
+ norm_eps,
266
+ ada_norm_bias,
267
+ "rms_norm",
268
+ )
269
+ elif (
270
+ cross_attention_2_norm_type == "fp32_layer_norm"
271
+ or cross_attention_2_norm_type is None
272
+ ):
273
+ self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
274
+ else:
275
+ self.norm2_2 = nn.RMSNorm(
276
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
277
+ )
278
+
279
+ self.attn2_2 = Attention(
280
+ query_dim=dim,
281
+ cross_attention_dim=cross_attention_2_dim,
282
+ heads=num_attention_heads,
283
+ dim_head=dim // num_attention_heads,
284
+ dropout=dropout,
285
+ bias=attention_bias,
286
+ upcast_attention=upcast_attention,
287
+ out_bias=attention_out_bias,
288
+ processor=AttnProcessor2_0(),
289
+ )
290
+
291
+ # self.attn2_2 = Attention(
292
+ # query_dim=dim,
293
+ # cross_attention_dim=cross_attention_2_dim,
294
+ # dim_head=dim // num_attention_heads,
295
+ # heads=num_attention_heads,
296
+ # qk_norm="rms_norm" if qk_norm else None,
297
+ # cross_attention_norm=cross_attention_2_norm_type,
298
+ # eps=1e-6,
299
+ # bias=qkv_bias,
300
+ # processor=AttnProcessor2_0(),
301
+ # )
302
+ else:
303
+ self.norm2_2 = None
304
+ self.attn2_2 = None
305
+
306
+ # 2'. Parallel Third Cross-Attn
307
+ if use_cross_attention_3:
308
+ assert cross_attention_3_dim is not None
309
+ if cross_attention_3_norm_type == "ada_norm":
310
+ self.norm2_3 = AdaLayerNorm(dim, num_embeds_ada_norm)
311
+ elif cross_attention_3_norm_type == "ada_norm_continuous":
312
+ self.norm2_3 = AdaLayerNormContinuous(
313
+ dim,
314
+ ada_norm_continous_conditioning_embedding_dim,
315
+ norm_elementwise_affine,
316
+ norm_eps,
317
+ ada_norm_bias,
318
+ "rms_norm",
319
+ )
320
+ elif (
321
+ cross_attention_3_norm_type == "fp32_layer_norm"
322
+ or cross_attention_3_norm_type is None
323
+ ):
324
+ self.norm2_3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
325
+ else:
326
+ self.norm2_3 = nn.RMSNorm(
327
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
328
+ )
329
+
330
+ self.attn2_3 = Attention(
331
+ query_dim=dim,
332
+ cross_attention_dim=cross_attention_3_dim,
333
+ heads=num_attention_heads,
334
+ dim_head=dim // num_attention_heads,
335
+ dropout=dropout,
336
+ bias=attention_bias,
337
+ upcast_attention=upcast_attention,
338
+ out_bias=attention_out_bias,
339
+ processor=AttnProcessor2_0(),
340
+ )
341
+ else:
342
+ self.norm2_3 = None
343
+ self.attn2_3 = None
344
+
345
+ # 3. Feed-forward
346
+ if self_attention_norm_type == "ada_norm_continuous":
347
+ self.norm3 = AdaLayerNormContinuous(
348
+ dim,
349
+ ada_norm_continous_conditioning_embedding_dim,
350
+ norm_elementwise_affine,
351
+ norm_eps,
352
+ ada_norm_bias,
353
+ "layer_norm",
354
+ )
355
+
356
+ elif self_attention_norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
357
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
358
+ elif self_attention_norm_type == "layer_norm_i2vgen":
359
+ self.norm3 = None
360
+
361
+ self.ff = FeedForward(
362
+ dim,
363
+ dropout=dropout,
364
+ activation_fn=activation_fn,
365
+ final_dropout=final_dropout,
366
+ inner_dim=ff_inner_dim,
367
+ bias=ff_bias,
368
+ )
369
+
370
+ # 4. Fuser
371
+ if attention_type == "gated" or attention_type == "gated-text-image":
372
+ self.fuser = GatedSelfAttentionDense(
373
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
374
+ )
375
+
376
+ # 5. Scale-shift for PixArt-Alpha.
377
+ if self_attention_norm_type == "ada_norm_single":
378
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
379
+
380
+ # let chunk size default to None
381
+ self._chunk_size = None
382
+ self._chunk_dim = 0
383
+
384
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
385
+ # Sets chunk feed-forward
386
+ self._chunk_size = chunk_size
387
+ self._chunk_dim = dim
388
+
389
+ def forward(
390
+ self,
391
+ hidden_states: torch.Tensor,
392
+ attention_mask: Optional[torch.Tensor] = None,
393
+ encoder_hidden_states: Optional[torch.Tensor] = None,
394
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
395
+ encoder_hidden_states_3: Optional[torch.Tensor] = None,
396
+ encoder_attention_mask: Optional[torch.Tensor] = None,
397
+ encoder_attention_mask_2: Optional[torch.Tensor] = None,
398
+ encoder_attention_mask_3: Optional[torch.Tensor] = None,
399
+ timestep: Optional[torch.LongTensor] = None,
400
+ cross_attention_kwargs: Dict[str, Any] = None,
401
+ class_labels: Optional[torch.LongTensor] = None,
402
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
403
+ ) -> torch.Tensor:
404
+ if cross_attention_kwargs is not None:
405
+ if cross_attention_kwargs.get("scale", None) is not None:
406
+ logger.warning(
407
+ "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
408
+ )
409
+
410
+ # Notice that normalization is always applied before the real computation in the following blocks.
411
+ # 0. Self-Attention
412
+ batch_size = hidden_states.shape[0]
413
+
414
+ if self.self_attention_norm_type == "ada_norm":
415
+ norm_hidden_states = self.norm1(hidden_states, timestep)
416
+ elif self.self_attention_norm_type == "ada_norm_zero":
417
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
418
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
419
+ )
420
+ elif self.self_attention_norm_type in ["layer_norm", "layer_norm_i2vgen"]:
421
+ norm_hidden_states = self.norm1(hidden_states)
422
+ elif self.self_attention_norm_type == "ada_norm_continuous":
423
+ norm_hidden_states = self.norm1(
424
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
425
+ )
426
+ elif self.self_attention_norm_type == "ada_norm_single":
427
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
428
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
429
+ ).chunk(6, dim=1)
430
+ norm_hidden_states = self.norm1(hidden_states)
431
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
432
+ else:
433
+ raise ValueError("Incorrect norm used")
434
+
435
+ if self.pos_embed is not None:
436
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
437
+
438
+ # 1. Prepare GLIGEN inputs
439
+ cross_attention_kwargs = (
440
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
441
+ )
442
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
443
+
444
+ attn_output = self.attn1(
445
+ norm_hidden_states,
446
+ encoder_hidden_states=(
447
+ encoder_hidden_states if self.only_cross_attention else None
448
+ ),
449
+ attention_mask=attention_mask,
450
+ **cross_attention_kwargs,
451
+ )
452
+
453
+ if self.self_attention_norm_type == "ada_norm_zero":
454
+ attn_output = gate_msa.unsqueeze(1) * attn_output
455
+ elif self.self_attention_norm_type == "ada_norm_single":
456
+ attn_output = gate_msa * attn_output
457
+
458
+ hidden_states = attn_output + hidden_states
459
+ if hidden_states.ndim == 4:
460
+ hidden_states = hidden_states.squeeze(1)
461
+
462
+ # 1.2 GLIGEN Control
463
+ if gligen_kwargs is not None:
464
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
465
+
466
+ # 3. Cross-Attention
467
+ if self.attn2 is not None:
468
+ if self.cross_attention_norm_type == "ada_norm":
469
+ norm_hidden_states = self.norm2(hidden_states, timestep)
470
+ elif self.cross_attention_norm_type in [
471
+ "ada_norm_zero",
472
+ "layer_norm",
473
+ "layer_norm_i2vgen",
474
+ ]:
475
+ norm_hidden_states = self.norm2(hidden_states)
476
+ elif self.cross_attention_norm_type == "ada_norm_single":
477
+ # For PixArt norm2 isn't applied here:
478
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
479
+ norm_hidden_states = hidden_states
480
+ elif self.cross_attention_norm_type == "ada_norm_continuous":
481
+ norm_hidden_states = self.norm2(
482
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
483
+ )
484
+ else:
485
+ raise ValueError("Incorrect norm")
486
+
487
+ if (
488
+ self.pos_embed is not None
489
+ and self.cross_attention_norm_type != "ada_norm_single"
490
+ ):
491
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
492
+
493
+ attn_output = self.attn2(
494
+ norm_hidden_states,
495
+ encoder_hidden_states=encoder_hidden_states,
496
+ attention_mask=encoder_attention_mask,
497
+ **cross_attention_kwargs,
498
+ )
499
+ hidden_states = attn_output + hidden_states
500
+
501
+ # 3.1 Parallel Second Cross-Attention
502
+ if self.attn2_2 is not None:
503
+ if self.cross_attention_2_norm_type == "ada_norm":
504
+ norm_hidden_states = self.norm2_2(hidden_states, timestep)
505
+ elif self.cross_attention_2_norm_type in [
506
+ "ada_norm_zero",
507
+ "layer_norm",
508
+ "layer_norm_i2vgen",
509
+ ]:
510
+ norm_hidden_states = self.norm2_2(hidden_states)
511
+ elif self.cross_attention_2_norm_type == "ada_norm_single":
512
+ # For PixArt norm2_2 isn't applied here:
513
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
514
+ norm_hidden_states = hidden_states
515
+ elif self.cross_attention_2_norm_type == "ada_norm_continuous":
516
+ norm_hidden_states = self.norm2_2(
517
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
518
+ )
519
+ else:
520
+ raise ValueError("Incorrect norm")
521
+
522
+ if (
523
+ self.pos_embed is not None
524
+ and self.cross_attention_2_norm_type != "ada_norm_single"
525
+ ):
526
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
527
+
528
+ attn_output_2 = self.attn2_2(
529
+ norm_hidden_states,
530
+ encoder_hidden_states=encoder_hidden_states_2,
531
+ attention_mask=encoder_attention_mask_2,
532
+ **cross_attention_kwargs,
533
+ )
534
+ hidden_states = attn_output_2 + hidden_states
535
+
536
+ # 3.2 Parallel Third Cross-Attention
537
+ if self.attn2_3 is not None:
538
+ if self.cross_attention_3_norm_type == "ada_norm":
539
+ norm_hidden_states = self.norm2_3(hidden_states, timestep)
540
+ elif self.cross_attention_3_norm_type in [
541
+ "ada_norm_zero",
542
+ "layer_norm",
543
+ "layer_norm_i2vgen",
544
+ ]:
545
+ norm_hidden_states = self.norm2_3(hidden_states)
546
+ elif self.cross_attention_3_norm_type == "ada_norm_single":
547
+ # For PixArt norm2_3 isn't applied here:
548
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
549
+ norm_hidden_states = hidden_states
550
+ elif self.cross_attention_3_norm_type == "ada_norm_continuous":
551
+ norm_hidden_states = self.norm2_3(
552
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
553
+ )
554
+ else:
555
+ raise ValueError("Incorrect norm")
556
+
557
+ if (
558
+ self.pos_embed is not None
559
+ and self.cross_attention_3_norm_type != "ada_norm_single"
560
+ ):
561
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
562
+
563
+ attn_output_3 = self.attn2_3(
564
+ norm_hidden_states,
565
+ encoder_hidden_states=encoder_hidden_states_3,
566
+ attention_mask=encoder_attention_mask_3,
567
+ **cross_attention_kwargs,
568
+ )
569
+ hidden_states = attn_output_3 + hidden_states
570
+
571
+ # 4. Feed-forward
572
+ # i2vgen doesn't have this norm 🤷‍♂️
573
+ if self.self_attention_norm_type == "ada_norm_continuous":
574
+ norm_hidden_states = self.norm3(
575
+ hidden_states, added_cond_kwargs["pooled_text_emb"]
576
+ )
577
+ elif not self.self_attention_norm_type == "ada_norm_single":
578
+ norm_hidden_states = self.norm3(hidden_states)
579
+
580
+ if self.self_attention_norm_type == "ada_norm_zero":
581
+ norm_hidden_states = (
582
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
583
+ )
584
+
585
+ if self.self_attention_norm_type == "ada_norm_single":
586
+ norm_hidden_states = self.norm2(hidden_states)
587
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
588
+
589
+ if self._chunk_size is not None:
590
+ # "feed_forward_chunk_size" can be used to save memory
591
+ ff_output = _chunked_feed_forward(
592
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
593
+ )
594
+ else:
595
+ ff_output = self.ff(norm_hidden_states)
596
+
597
+ if self.self_attention_norm_type == "ada_norm_zero":
598
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
599
+ elif self.self_attention_norm_type == "ada_norm_single":
600
+ ff_output = gate_mlp * ff_output
601
+
602
+ hidden_states = ff_output + hidden_states
603
+
604
+ return hidden_states
605
+
606
+
607
+ @maybe_allow_in_graph
608
+ class FluxSingleTransformerBlock(nn.Module):
609
+ def __init__(
610
+ self,
611
+ dim: int,
612
+ num_attention_heads: int,
613
+ attention_head_dim: int,
614
+ mlp_ratio: float = 4.0,
615
+ ):
616
+ super().__init__()
617
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
618
+
619
+ self.norm = AdaLayerNormZeroSingle(dim)
620
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
621
+ self.act_mlp = nn.GELU(approximate="tanh")
622
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
623
+
624
+ if is_torch_npu_available():
625
+ deprecation_message = (
626
+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
627
+ "should be set explicitly using the `set_attn_processor` method."
628
+ )
629
+ deprecate("npu_processor", "0.34.0", deprecation_message)
630
+ processor = FluxAttnProcessor2_0_NPU()
631
+ else:
632
+ processor = FluxAttnProcessor2_0()
633
+
634
+ self.attn = Attention(
635
+ query_dim=dim,
636
+ cross_attention_dim=None,
637
+ dim_head=attention_head_dim,
638
+ heads=num_attention_heads,
639
+ out_dim=dim,
640
+ bias=True,
641
+ processor=processor,
642
+ qk_norm="rms_norm",
643
+ eps=1e-6,
644
+ pre_only=True,
645
+ )
646
+
647
+ def forward(
648
+ self,
649
+ hidden_states: torch.Tensor,
650
+ temb: torch.Tensor,
651
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
652
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
653
+ ) -> torch.Tensor:
654
+ residual = hidden_states
655
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
656
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
657
+ joint_attention_kwargs = joint_attention_kwargs or {}
658
+ attn_output = self.attn(
659
+ hidden_states=norm_hidden_states,
660
+ image_rotary_emb=image_rotary_emb,
661
+ **joint_attention_kwargs,
662
+ )
663
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
664
+ gate = gate.unsqueeze(1)
665
+
666
+ hidden_states = gate * self.proj_out(hidden_states)
667
+ hidden_states = residual + hidden_states
668
+ if hidden_states.dtype == torch.float16:
669
+ hidden_states = hidden_states.clip(-65504, 65504)
670
+
671
+ return hidden_states
672
+
673
+
674
+ @maybe_allow_in_graph
675
+ class FluxTransformerBlock(nn.Module):
676
+ def __init__(
677
+ self,
678
+ dim: int,
679
+ num_attention_heads: int,
680
+ attention_head_dim: int,
681
+ qk_norm: str = "rms_norm",
682
+ eps: float = 1e-6,
683
+ ):
684
+ super().__init__()
685
+
686
+ self.norm1 = AdaLayerNormZero(dim)
687
+ self.norm1_context = AdaLayerNormZero(dim)
688
+
689
+ self.attn = Attention(
690
+ query_dim=dim,
691
+ cross_attention_dim=None,
692
+ added_kv_proj_dim=dim,
693
+ dim_head=attention_head_dim,
694
+ heads=num_attention_heads,
695
+ out_dim=dim,
696
+ context_pre_only=False,
697
+ bias=True,
698
+ processor=FluxAttnProcessor2_0(),
699
+ qk_norm=qk_norm,
700
+ eps=eps,
701
+ )
702
+
703
+ mlp_ratio = 4.0
704
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
705
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
706
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
707
+
708
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
709
+ self.ff_context = FeedForward(
710
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
711
+ )
712
+
713
+ def forward(
714
+ self,
715
+ hidden_states: torch.Tensor,
716
+ encoder_hidden_states: Optional[torch.Tensor] = None,
717
+ temb: Optional[torch.Tensor] = None,
718
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
719
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
720
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
721
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
722
+ hidden_states, emb=temb
723
+ )
724
+
725
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (
726
+ self.norm1_context(encoder_hidden_states, emb=temb)
727
+ )
728
+ joint_attention_kwargs = joint_attention_kwargs or {}
729
+ # Attention.
730
+ attention_outputs = self.attn(
731
+ hidden_states=norm_hidden_states,
732
+ encoder_hidden_states=norm_encoder_hidden_states,
733
+ image_rotary_emb=image_rotary_emb,
734
+ **joint_attention_kwargs,
735
+ )
736
+
737
+ if len(attention_outputs) == 2:
738
+ attn_output, context_attn_output = attention_outputs
739
+ elif len(attention_outputs) == 3:
740
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
741
+
742
+ # Process attention outputs for the `hidden_states`.
743
+ attn_output = gate_msa.unsqueeze(1) * attn_output
744
+ hidden_states = hidden_states + attn_output
745
+
746
+ norm_hidden_states = self.norm2(hidden_states)
747
+ norm_hidden_states = (
748
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
749
+ )
750
+
751
+ ff_output = self.ff(norm_hidden_states)
752
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
753
+
754
+ hidden_states = hidden_states + ff_output
755
+ if len(attention_outputs) == 3:
756
+ hidden_states = hidden_states + ip_attn_output
757
+
758
+ # Process attention outputs for the `encoder_hidden_states`.
759
+
760
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
761
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
762
+
763
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
764
+ norm_encoder_hidden_states = (
765
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
766
+ + c_shift_mlp[:, None]
767
+ )
768
+
769
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
770
+ encoder_hidden_states = (
771
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
772
+ )
773
+ if encoder_hidden_states.dtype == torch.float16:
774
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
775
+
776
+ return encoder_hidden_states, hidden_states
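Note (not part of the diff): a minimal sketch of how the two block types above compose in a Flux-style backbone. The module path and tensor sizes are assumptions based on this commit's layout; the dual-stream block updates latent and condition tokens jointly, and the single-stream block then runs on their concatenation.

import torch
from step1x3d_geometry.models.attention import (  # path assumed from this commit's file layout
    FluxTransformerBlock,
    FluxSingleTransformerBlock,
)

dim, heads, head_dim = 1024, 16, 64                      # illustrative sizes
double_block = FluxTransformerBlock(dim, heads, head_dim)
single_block = FluxSingleTransformerBlock(dim, heads, head_dim)

latents = torch.randn(2, 512, dim)                       # shape tokens
context = torch.randn(2, 77, dim)                        # conditioning tokens
temb = torch.randn(2, dim)                               # timestep embedding

context, latents = double_block(latents, encoder_hidden_states=context, temb=temb)
merged = single_block(torch.cat([context, latents], dim=1), temb=temb)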
step1x3d_geometry/models/attention_processor.py ADDED
@@ -0,0 +1,482 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, List, Optional, Tuple, Union
15
+
16
+ import os
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from diffusers.models.attention_processor import Attention
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
22
+ from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
23
+ from einops import rearrange
24
+ from torch import nn
25
+
26
+ # add sageattention support
27
+ scaled_dot_product_attention = F.scaled_dot_product_attention
28
+ if os.environ.get("USE_SAGEATTN", "0") == "1":
29
+ try:
30
+ from sageattention import sageattn
31
+ except ImportError:
32
+ raise ImportError(
33
+ 'Please install the "sageattention" package to use USE_SAGEATTN=1.'
34
+ )
35
+ scaled_dot_product_attention = sageattn
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ class AttnProcessor2_0:
41
+ r"""
42
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
43
+ """
44
+
45
+ def __init__(self):
46
+ if not hasattr(F, "scaled_dot_product_attention"):
47
+ raise ImportError(
48
+ "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
49
+ )
50
+
51
+ def __call__(
52
+ self,
53
+ attn: Attention,
54
+ hidden_states: torch.Tensor,
55
+ encoder_hidden_states: Optional[torch.Tensor] = None,
56
+ attention_mask: Optional[torch.Tensor] = None,
57
+ temb: Optional[torch.Tensor] = None,
58
+ *args,
59
+ **kwargs,
60
+ ) -> torch.Tensor:
61
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
62
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
63
+ deprecate("scale", "1.0.0", deprecation_message)
64
+
65
+ residual = hidden_states
66
+ if attn.spatial_norm is not None:
67
+ hidden_states = attn.spatial_norm(hidden_states, temb)
68
+
69
+ input_ndim = hidden_states.ndim
70
+
71
+ if input_ndim == 4:
72
+ batch_size, channel, height, width = hidden_states.shape
73
+ hidden_states = hidden_states.view(
74
+ batch_size, channel, height * width
75
+ ).transpose(1, 2)
76
+
77
+ batch_size, sequence_length, _ = (
78
+ hidden_states.shape
79
+ if encoder_hidden_states is None
80
+ else encoder_hidden_states.shape
81
+ )
82
+
83
+ if attention_mask is not None:
84
+ attention_mask = attn.prepare_attention_mask(
85
+ attention_mask, sequence_length, batch_size
86
+ )
87
+ # scaled_dot_product_attention expects attention_mask shape to be
88
+ # (batch, heads, source_length, target_length)
89
+ attention_mask = attention_mask.view(
90
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
91
+ )
92
+
93
+ if attn.group_norm is not None:
94
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
95
+ 1, 2
96
+ )
97
+
98
+ query = attn.to_q(hidden_states)
99
+
100
+ if encoder_hidden_states is None:
101
+ encoder_hidden_states = hidden_states
102
+ elif attn.norm_cross:
103
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
104
+ encoder_hidden_states
105
+ )
106
+
107
+ key = attn.to_k(encoder_hidden_states)
108
+ value = attn.to_v(encoder_hidden_states)
109
+
110
+ inner_dim = key.shape[-1]
111
+ head_dim = inner_dim // attn.heads
112
+
113
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
114
+
115
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
116
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
117
+
118
+ if attn.norm_q is not None:
119
+ query = attn.norm_q(query)
120
+ if attn.norm_k is not None:
121
+ key = attn.norm_k(key)
122
+
123
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
124
+ # TODO: add support for attn.scale when we move to Torch 2.1
125
+ hidden_states = scaled_dot_product_attention(
126
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
127
+ )
128
+
129
+ hidden_states = hidden_states.transpose(1, 2).reshape(
130
+ batch_size, -1, attn.heads * head_dim
131
+ )
132
+ hidden_states = hidden_states.to(query.dtype)
133
+
134
+ # linear proj
135
+ hidden_states = attn.to_out[0](hidden_states)
136
+ # dropout
137
+ hidden_states = attn.to_out[1](hidden_states)
138
+
139
+ if input_ndim == 4:
140
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
141
+ batch_size, channel, height, width
142
+ )
143
+
144
+ if attn.residual_connection:
145
+ hidden_states = hidden_states + residual
146
+
147
+ hidden_states = hidden_states / attn.rescale_output_factor
148
+
149
+ return hidden_states
150
+
151
+
152
+ class FusedAttnProcessor2_0:
153
+ r"""
154
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
155
+ fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
156
+ For cross-attention modules, key and value projection matrices are fused.
157
+
158
+ <Tip warning={true}>
159
+
160
+ This API is currently 🧪 experimental in nature and can change in future.
161
+
162
+ </Tip>
163
+ """
164
+
165
+ def __init__(self):
166
+ if not hasattr(F, "scaled_dot_product_attention"):
167
+ raise ImportError(
168
+ "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
169
+ )
170
+
171
+ def __call__(
172
+ self,
173
+ attn: Attention,
174
+ hidden_states: torch.Tensor,
175
+ encoder_hidden_states: Optional[torch.Tensor] = None,
176
+ attention_mask: Optional[torch.Tensor] = None,
177
+ temb: Optional[torch.Tensor] = None,
178
+ *args,
179
+ **kwargs,
180
+ ) -> torch.Tensor:
181
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
182
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
183
+ deprecate("scale", "1.0.0", deprecation_message)
184
+
185
+ residual = hidden_states
186
+ if attn.spatial_norm is not None:
187
+ hidden_states = attn.spatial_norm(hidden_states, temb)
188
+
189
+ input_ndim = hidden_states.ndim
190
+
191
+ if input_ndim == 4:
192
+ batch_size, channel, height, width = hidden_states.shape
193
+ hidden_states = hidden_states.view(
194
+ batch_size, channel, height * width
195
+ ).transpose(1, 2)
196
+
197
+ batch_size, sequence_length, _ = (
198
+ hidden_states.shape
199
+ if encoder_hidden_states is None
200
+ else encoder_hidden_states.shape
201
+ )
202
+
203
+ if attention_mask is not None:
204
+ attention_mask = attn.prepare_attention_mask(
205
+ attention_mask, sequence_length, batch_size
206
+ )
207
+ # scaled_dot_product_attention expects attention_mask shape to be
208
+ # (batch, heads, source_length, target_length)
209
+ attention_mask = attention_mask.view(
210
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
211
+ )
212
+
213
+ if attn.group_norm is not None:
214
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
215
+ 1, 2
216
+ )
217
+
218
+ if encoder_hidden_states is None:
219
+ qkv = attn.to_qkv(hidden_states)
220
+ split_size = qkv.shape[-1] // 3
221
+ query, key, value = torch.split(qkv, split_size, dim=-1)
222
+ else:
223
+ if attn.norm_cross:
224
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
225
+ encoder_hidden_states
226
+ )
227
+ query = attn.to_q(hidden_states)
228
+
229
+ kv = attn.to_kv(encoder_hidden_states)
230
+ split_size = kv.shape[-1] // 2
231
+ key, value = torch.split(kv, split_size, dim=-1)
232
+
233
+ inner_dim = key.shape[-1]
234
+ head_dim = inner_dim // attn.heads
235
+
236
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
237
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
238
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
239
+
240
+ if attn.norm_q is not None:
241
+ query = attn.norm_q(query)
242
+ if attn.norm_k is not None:
243
+ key = attn.norm_k(key)
244
+
245
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
246
+ # TODO: add support for attn.scale when we move to Torch 2.1
247
+ hidden_states = F.scaled_dot_product_attention(
248
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
249
+ )
250
+
251
+ hidden_states = hidden_states.transpose(1, 2).reshape(
252
+ batch_size, -1, attn.heads * head_dim
253
+ )
254
+ hidden_states = hidden_states.to(query.dtype)
255
+
256
+ # linear proj
257
+ hidden_states = attn.to_out[0](hidden_states)
258
+ # dropout
259
+ hidden_states = attn.to_out[1](hidden_states)
260
+
261
+ if input_ndim == 4:
262
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
263
+ batch_size, channel, height, width
264
+ )
265
+
266
+ if attn.residual_connection:
267
+ hidden_states = hidden_states + residual
268
+
269
+ hidden_states = hidden_states / attn.rescale_output_factor
270
+
271
+ return hidden_states
272
+
273
+
274
+ class FluxAttnProcessor2_0:
275
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
276
+
277
+ def __init__(self):
278
+ if not hasattr(F, "scaled_dot_product_attention"):
279
+ raise ImportError(
280
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
281
+ )
282
+
283
+ def __call__(
284
+ self,
285
+ attn: Attention,
286
+ hidden_states: torch.FloatTensor,
287
+ encoder_hidden_states: torch.FloatTensor = None,
288
+ attention_mask: Optional[torch.FloatTensor] = None,
289
+ image_rotary_emb: Optional[torch.Tensor] = None,
290
+ ) -> torch.FloatTensor:
291
+ batch_size, _, _ = (
292
+ hidden_states.shape
293
+ if encoder_hidden_states is None
294
+ else encoder_hidden_states.shape
295
+ )
296
+
297
+ # `sample` projections.
298
+ query = attn.to_q(hidden_states)
299
+ key = attn.to_k(hidden_states)
300
+ value = attn.to_v(hidden_states)
301
+
302
+ inner_dim = key.shape[-1]
303
+ head_dim = inner_dim // attn.heads
304
+
305
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
306
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
307
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
308
+
309
+ if attn.norm_q is not None:
310
+ query = attn.norm_q(query)
311
+ if attn.norm_k is not None:
312
+ key = attn.norm_k(key)
313
+
314
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
315
+ if encoder_hidden_states is not None:
316
+ # `context` projections.
317
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
318
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
319
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
320
+
321
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
322
+ batch_size, -1, attn.heads, head_dim
323
+ ).transpose(1, 2)
324
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
325
+ batch_size, -1, attn.heads, head_dim
326
+ ).transpose(1, 2)
327
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
328
+ batch_size, -1, attn.heads, head_dim
329
+ ).transpose(1, 2)
330
+
331
+ if attn.norm_added_q is not None:
332
+ encoder_hidden_states_query_proj = attn.norm_added_q(
333
+ encoder_hidden_states_query_proj
334
+ )
335
+ if attn.norm_added_k is not None:
336
+ encoder_hidden_states_key_proj = attn.norm_added_k(
337
+ encoder_hidden_states_key_proj
338
+ )
339
+
340
+ # attention
341
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
342
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
343
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
344
+
345
+ if image_rotary_emb is not None:
346
+ from .embeddings import apply_rotary_emb
347
+
348
+ query = apply_rotary_emb(query, image_rotary_emb)
349
+ key = apply_rotary_emb(key, image_rotary_emb)
350
+
351
+ hidden_states = scaled_dot_product_attention(
352
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
353
+ )
354
+
355
+ hidden_states = hidden_states.transpose(1, 2).reshape(
356
+ batch_size, -1, attn.heads * head_dim
357
+ )
358
+ hidden_states = hidden_states.to(query.dtype)
359
+
360
+ if encoder_hidden_states is not None:
361
+ encoder_hidden_states, hidden_states = (
362
+ hidden_states[:, : encoder_hidden_states.shape[1]],
363
+ hidden_states[:, encoder_hidden_states.shape[1] :],
364
+ )
365
+
366
+ # linear proj
367
+ hidden_states = attn.to_out[0](hidden_states)
368
+ # dropout
369
+ hidden_states = attn.to_out[1](hidden_states)
370
+
371
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
372
+
373
+ return hidden_states, encoder_hidden_states
374
+ else:
375
+ return hidden_states
376
+
377
+
378
+ class FusedFluxAttnProcessor2_0:
379
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
380
+
381
+ def __init__(self):
382
+ if not hasattr(F, "scaled_dot_product_attention"):
383
+ raise ImportError(
384
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
385
+ )
386
+
387
+ def __call__(
388
+ self,
389
+ attn: Attention,
390
+ hidden_states: torch.FloatTensor,
391
+ encoder_hidden_states: torch.FloatTensor = None,
392
+ attention_mask: Optional[torch.FloatTensor] = None,
393
+ image_rotary_emb: Optional[torch.Tensor] = None,
394
+ ) -> torch.FloatTensor:
395
+ batch_size, _, _ = (
396
+ hidden_states.shape
397
+ if encoder_hidden_states is None
398
+ else encoder_hidden_states.shape
399
+ )
400
+
401
+ # `sample` projections.
402
+ qkv = attn.to_qkv(hidden_states)
403
+ split_size = qkv.shape[-1] // 3
404
+ query, key, value = torch.split(qkv, split_size, dim=-1)
405
+
406
+ inner_dim = key.shape[-1]
407
+ head_dim = inner_dim // attn.heads
408
+
409
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
410
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
411
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
412
+
413
+ if attn.norm_q is not None:
414
+ query = attn.norm_q(query)
415
+ if attn.norm_k is not None:
416
+ key = attn.norm_k(key)
417
+
418
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
419
+ # `context` projections.
420
+ if encoder_hidden_states is not None:
421
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
422
+ split_size = encoder_qkv.shape[-1] // 3
423
+ (
424
+ encoder_hidden_states_query_proj,
425
+ encoder_hidden_states_key_proj,
426
+ encoder_hidden_states_value_proj,
427
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
428
+
429
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
430
+ batch_size, -1, attn.heads, head_dim
431
+ ).transpose(1, 2)
432
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
433
+ batch_size, -1, attn.heads, head_dim
434
+ ).transpose(1, 2)
435
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
436
+ batch_size, -1, attn.heads, head_dim
437
+ ).transpose(1, 2)
438
+
439
+ if attn.norm_added_q is not None:
440
+ encoder_hidden_states_query_proj = attn.norm_added_q(
441
+ encoder_hidden_states_query_proj
442
+ )
443
+ if attn.norm_added_k is not None:
444
+ encoder_hidden_states_key_proj = attn.norm_added_k(
445
+ encoder_hidden_states_key_proj
446
+ )
447
+
448
+ # attention
449
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
450
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
451
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
452
+
453
+ if image_rotary_emb is not None:
454
+ from .embeddings import apply_rotary_emb
455
+
456
+ query = apply_rotary_emb(query, image_rotary_emb)
457
+ key = apply_rotary_emb(key, image_rotary_emb)
458
+
459
+ hidden_states = scaled_dot_product_attention(
460
+ query, key, value, dropout_p=0.0, is_causal=False
461
+ )
462
+
463
+ hidden_states = hidden_states.transpose(1, 2).reshape(
464
+ batch_size, -1, attn.heads * head_dim
465
+ )
466
+ hidden_states = hidden_states.to(query.dtype)
467
+
468
+ if encoder_hidden_states is not None:
469
+ encoder_hidden_states, hidden_states = (
470
+ hidden_states[:, : encoder_hidden_states.shape[1]],
471
+ hidden_states[:, encoder_hidden_states.shape[1] :],
472
+ )
473
+
474
+ # linear proj
475
+ hidden_states = attn.to_out[0](hidden_states)
476
+ # dropout
477
+ hidden_states = attn.to_out[1](hidden_states)
478
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
479
+
480
+ return hidden_states, encoder_hidden_states
481
+ else:
482
+ return hidden_states
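Note (not part of the diff): the processors above only pick up SageAttention when the USE_SAGEATTN environment variable is "1" before this module is imported, since the swap of `scaled_dot_product_attention` happens at import time. A hedged usage sketch with illustrative sizes:

import os
os.environ["USE_SAGEATTN"] = "1"          # optional; requires `pip install sageattention`

import torch
from diffusers.models.attention_processor import Attention
from step1x3d_geometry.models.attention_processor import FluxAttnProcessor2_0

# pre_only=True mirrors how the single-stream block wires its attention (no output projection)
attn = Attention(query_dim=1024, heads=16, dim_head=64, bias=True, pre_only=True,
                 processor=FluxAttnProcessor2_0())
out = attn(hidden_states=torch.randn(2, 512, 1024))     # self-attention path (no encoder states)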
step1x3d_geometry/models/autoencoders/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from . import (
2
+ michelangelo_autoencoder,
3
+ )
step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py ADDED
@@ -0,0 +1,765 @@
1
+ from dataclasses import dataclass
2
+ import math
3
+
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ import time
8
+ import trimesh
9
+ import torch.nn as nn
+ import torch.nn.functional as F  # used by Siren.forward below
10
+ from einops import repeat, rearrange
11
+ from tqdm import trange
12
+ from itertools import product
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+
15
+ import step1x3d_geometry
16
+ from step1x3d_geometry.utils.checkpoint import checkpoint
17
+ from step1x3d_geometry.utils.base import BaseModule
18
+ from step1x3d_geometry.utils.typing import *
19
+ from step1x3d_geometry.utils.misc import get_world_size, get_device
20
+
21
+ from .transformers.perceiver_1d import Perceiver
22
+ from .transformers.attention import ResidualCrossAttentionBlock
23
+ from .volume_decoders import HierarchicalVolumeDecoder, VanillaVolumeDecoder
24
+ from .surface_extractors import MCSurfaceExtractor, DMCSurfaceExtractor
25
+
26
+ from ..pipelines.pipeline_utils import smart_load_model
27
+ from safetensors.torch import load_file
28
+
29
+ VALID_EMBED_TYPES = ["identity", "fourier", "learned_fourier", "siren"]
30
+
31
+
32
+ class FourierEmbedder(nn.Module):
33
+ def __init__(
34
+ self,
35
+ num_freqs: int = 6,
36
+ logspace: bool = True,
37
+ input_dim: int = 3,
38
+ include_input: bool = True,
39
+ include_pi: bool = True,
40
+ ) -> None:
41
+ super().__init__()
42
+
43
+ if logspace:
44
+ frequencies = 2.0 ** torch.arange(num_freqs, dtype=torch.float32)
45
+ else:
46
+ frequencies = torch.linspace(
47
+ 1.0, 2.0 ** (num_freqs - 1), num_freqs, dtype=torch.float32
48
+ )
49
+
50
+ if include_pi:
51
+ frequencies *= torch.pi
52
+
53
+ self.register_buffer("frequencies", frequencies, persistent=False)
54
+ self.include_input = include_input
55
+ self.num_freqs = num_freqs
56
+
57
+ self.out_dim = self.get_dims(input_dim)
58
+
59
+ def get_dims(self, input_dim):
60
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
61
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
62
+
63
+ return out_dim
64
+
65
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
66
+ if self.num_freqs > 0:
67
+ embed = (x[..., None].contiguous() * self.frequencies).view(
68
+ *x.shape[:-1], -1
69
+ )
70
+ if self.include_input:
71
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
72
+ else:
73
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
74
+ else:
75
+ return x
76
+
77
+
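Note (not part of the diff): a quick sanity sketch of the embedding width this class produces; with include_input=True the output has input_dim * (2 * num_freqs + 1) channels, e.g. 3 * (2 * 8 + 1) = 51 for the defaults used later in this file.

import torch

embedder = FourierEmbedder(num_freqs=8, input_dim=3, include_input=True, include_pi=True)
pts = torch.rand(2, 4096, 3) * 2 - 1        # points in [-1, 1]^3
feat = embedder(pts)                        # -> (2, 4096, 51)
assert feat.shape[-1] == embedder.out_dim == 51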
78
+ class LearnedFourierEmbedder(nn.Module):
79
+ def __init__(self, input_dim, dim):
80
+ super().__init__()
81
+ assert (dim % 2) == 0
82
+ half_dim = dim // 2
83
+ per_channel_dim = half_dim // input_dim
84
+ self.weights = nn.Parameter(torch.randn(per_channel_dim))
85
+
86
+ self.out_dim = self.get_dims(input_dim)
87
+
88
+ def forward(self, x):
89
+ # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
90
+ freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
91
+ fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
92
+ return fouriered
93
+
94
+ def get_dims(self, input_dim):
95
+ return input_dim * (self.weights.shape[0] * 2 + 1)
96
+
97
+
98
+ class Sine(nn.Module):
99
+ def __init__(self, w0=1.0):
100
+ super().__init__()
101
+ self.w0 = w0
102
+
103
+ def forward(self, x):
104
+ return torch.sin(self.w0 * x)
105
+
106
+
107
+ class Siren(nn.Module):
108
+ def __init__(
109
+ self,
110
+ in_dim,
111
+ out_dim,
112
+ w0=1.0,
113
+ c=6.0,
114
+ is_first=False,
115
+ use_bias=True,
116
+ activation=None,
117
+ dropout=0.0,
118
+ ):
119
+ super().__init__()
120
+ self.in_dim = in_dim
121
+ self.out_dim = out_dim
122
+ self.is_first = is_first
123
+
124
+ weight = torch.zeros(out_dim, in_dim)
125
+ bias = torch.zeros(out_dim) if use_bias else None
126
+ self.init_(weight, bias, c=c, w0=w0)
127
+
128
+ self.weight = nn.Parameter(weight)
129
+ self.bias = nn.Parameter(bias) if use_bias else None
130
+ self.activation = Sine(w0) if activation is None else activation
131
+ self.dropout = nn.Dropout(dropout)
132
+
133
+ def init_(self, weight, bias, c, w0):
134
+ dim = self.in_dim
135
+
136
+ w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
137
+ weight.uniform_(-w_std, w_std)
138
+
139
+ if bias is not None:
140
+ bias.uniform_(-w_std, w_std)
141
+
142
+ def forward(self, x):
143
+ out = F.linear(x, self.weight, self.bias)
144
+ out = self.activation(out)
145
+ out = self.dropout(out)
146
+ return out
147
+
148
+
149
+ def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True):
150
+ if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
151
+ return nn.Identity(), input_dim
152
+
153
+ elif embed_type == "fourier":
154
+ embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
155
+
156
+ elif embed_type == "learned_fourier":
157
+ embedder_obj = LearnedFourierEmbedder(input_dim=input_dim, dim=num_freqs)
158
+
159
+ elif embed_type == "siren":
160
+ embedder_obj = Siren(
161
+ in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim
162
+ )
163
+
164
+ else:
165
+ raise ValueError(
166
+ f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}"
167
+ )
168
+ return embedder_obj
169
+
170
+
171
+ ###################### AutoEncoder
172
+ class DiagonalGaussianDistribution(ModelMixin, object):
173
+ def __init__(
174
+ self,
175
+ parameters: Union[torch.Tensor, List[torch.Tensor]],
176
+ deterministic=False,
177
+ feat_dim=1,
178
+ ):
179
+ self.feat_dim = feat_dim
180
+ self.parameters = parameters
181
+
182
+ if isinstance(parameters, list):
183
+ self.mean = parameters[0]
184
+ self.logvar = parameters[1]
185
+ else:
186
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
187
+
188
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
189
+ self.deterministic = deterministic
190
+ self.std = torch.exp(0.5 * self.logvar)
191
+ self.var = torch.exp(self.logvar)
192
+ if self.deterministic:
193
+ self.var = self.std = torch.zeros_like(self.mean)
194
+
195
+ def sample(self):
196
+ x = self.mean + self.std * torch.randn_like(self.mean)
197
+ return x
198
+
199
+ def kl(self, other=None, dims=(1, 2)):
200
+ if self.deterministic:
201
+ return torch.Tensor([0.0])
202
+ else:
203
+ if other is None:
204
+ return 0.5 * torch.mean(
205
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=dims
206
+ )
207
+ else:
208
+ return 0.5 * torch.mean(
209
+ torch.pow(self.mean - other.mean, 2) / other.var
210
+ + self.var / other.var
211
+ - 1.0
212
+ - self.logvar
213
+ + other.logvar,
214
+ dim=dims,
215
+ )
216
+
217
+ def nll(self, sample, dims=(1, 2)):
218
+ if self.deterministic:
219
+ return torch.Tensor([0.0])
220
+ logtwopi = np.log(2.0 * np.pi)
221
+ return 0.5 * torch.sum(
222
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
223
+ dim=dims,
224
+ )
225
+
226
+ def mode(self):
227
+ return self.mean
228
+
229
+
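Note (not part of the diff): a hedged sketch of how this distribution is used by the VAE head further down; the `pre_kl` projection produces (mean, logvar) concatenated along the last axis, and feat_dim=-1 splits them. Shapes are illustrative.

import torch

moments = torch.randn(2, 256, 2 * 64)                  # [B, num_latents, 2 * embed_dim]
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
z = posterior.sample()                                  # reparameterized sample, [2, 256, 64]
kl = posterior.kl()                                     # mean KL to N(0, I), one value per batch item
print(z.shape, kl.shape)                                # torch.Size([2, 256, 64]) torch.Size([2])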
230
+ class PerceiverCrossAttentionEncoder(ModelMixin, nn.Module):
231
+ def __init__(
232
+ self,
233
+ use_downsample: bool,
234
+ num_latents: int,
235
+ embedder: FourierEmbedder,
236
+ point_feats: int,
237
+ embed_point_feats: bool,
238
+ width: int,
239
+ heads: int,
240
+ layers: int,
241
+ init_scale: float = 0.25,
242
+ qkv_bias: bool = True,
243
+ qk_norm: bool = True,
244
+ use_ln_post: bool = False,
245
+ use_flash: bool = False,
246
+ use_checkpoint: bool = False,
247
+ use_multi_reso: bool = False,
248
+ resolutions: list = [],
249
+ sampling_prob: list = [],
250
+ with_sharp_data: bool = False,
251
+ ):
252
+
253
+ super().__init__()
254
+
255
+ self.use_checkpoint = use_checkpoint
256
+ self.num_latents = num_latents
257
+ self.use_downsample = use_downsample
258
+ self.embed_point_feats = embed_point_feats
259
+ self.use_multi_reso = use_multi_reso
260
+ self.resolutions = resolutions
261
+ self.sampling_prob = sampling_prob
262
+
263
+ if not self.use_downsample:
264
+ self.query = nn.Parameter(torch.randn((num_latents, width)) * 0.02)
265
+
266
+ self.embedder = embedder
267
+ if self.embed_point_feats:
268
+ self.input_proj = nn.Linear(self.embedder.out_dim * 2, width)
269
+ else:
270
+ self.input_proj = nn.Linear(self.embedder.out_dim + point_feats, width)
271
+
272
+ self.cross_attn = ResidualCrossAttentionBlock(
273
+ width=width,
274
+ heads=heads,
275
+ init_scale=init_scale,
276
+ qkv_bias=qkv_bias,
277
+ qk_norm=qk_norm,
278
+ use_flash=use_flash,
279
+ )
280
+
281
+ self.with_sharp_data = with_sharp_data
282
+ if with_sharp_data:
283
+ self.downsmaple_num_latents = num_latents // 2
284
+ self.input_proj_sharp = nn.Linear(
285
+ self.embedder.out_dim + point_feats, width
286
+ )
287
+ self.cross_attn_sharp = ResidualCrossAttentionBlock(
288
+ width=width,
289
+ heads=heads,
290
+ init_scale=init_scale,
291
+ qkv_bias=qkv_bias,
292
+ qk_norm=qk_norm,
293
+ use_flash=use_flash,
294
+ )
295
+ else:
296
+ self.downsmaple_num_latents = num_latents
297
+
298
+ self.self_attn = Perceiver(
299
+ n_ctx=num_latents,
300
+ width=width,
301
+ layers=layers,
302
+ heads=heads,
303
+ init_scale=init_scale,
304
+ qkv_bias=qkv_bias,
305
+ qk_norm=qk_norm,
306
+ use_flash=use_flash,
307
+ use_checkpoint=use_checkpoint,
308
+ )
309
+
310
+ if use_ln_post:
311
+ self.ln_post = nn.LayerNorm(width)
312
+ else:
313
+ self.ln_post = None
314
+
315
+ def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None):
316
+ """
317
+
318
+ Args:
319
+ pc (torch.FloatTensor): [B, N, 3]
320
+ feats (torch.FloatTensor or None): [B, N, C]
321
+
322
+ Returns:
323
+
324
+ """
325
+
326
+ bs, N, D = pc.shape
327
+
328
+ data = self.embedder(pc)
329
+ if feats is not None:
330
+ if self.embed_point_feats:
331
+ feats = self.embedder(feats)
332
+ data = torch.cat([data, feats], dim=-1)
333
+ data = self.input_proj(data)
334
+
335
+ if self.with_sharp_data:
336
+ sharp_data = self.embedder(sharp_pc)
337
+ if sharp_feat is not None:
338
+ if self.embed_point_feats:
339
+ sharp_feat = self.embedder(sharp_feat)
340
+ sharp_data = torch.cat([sharp_data, sharp_feat], dim=-1)
341
+ sharp_data = self.input_proj_sharp(sharp_data)
342
+
343
+ if self.use_multi_reso:
344
+ resolution = np.random.choice(self.resolutions, size=1, p=self.sampling_prob)[0]
347
+
348
+ if resolution != N:
349
+ flattened = pc.view(bs * N, D)  # (bs*N, 3)
350
+ batch = torch.arange(bs).to(pc.device) # 103
351
+ batch = torch.repeat_interleave(batch, N) # bs*N. 421888
352
+ pos = flattened.to(torch.float16)
353
+ ratio = 1.0 * resolution / N # 0.0625
354
+ from torch_cluster import fps  # local import: this branch may run before the downsample import below
+ idx = fps(pos, batch, ratio=ratio)
355
+ pc = pc.view(bs * N, -1)[idx].view(bs, -1, D)
356
+ bs, N, D = feats.shape
357
+ flattened1 = feats.view(bs * N, D)
358
+ feats = flattened1.view(bs * N, -1)[idx].view(bs, -1, D)
359
+ bs, N, D = pc.shape
360
+
361
+ if self.use_downsample:
362
+ ###### fps
363
+ from torch_cluster import fps
364
+
365
+ flattened = pc.view(bs * N, D) # bs*N, 64
366
+
367
+ batch = torch.arange(bs).to(pc.device)
368
+ batch = torch.repeat_interleave(batch, N) # bs*N
369
+
370
+ pos = flattened.to(torch.float16)
371
+ ratio = 1.0 * self.downsmaple_num_latents / N
372
+ idx = fps(pos, batch, ratio=ratio).detach()
373
+ query = data.view(bs * N, -1)[idx].view(bs, -1, data.shape[-1])
374
+
375
+ if self.with_sharp_data:
376
+ bs, N, D = sharp_pc.shape
377
+ flattened = sharp_pc.view(bs * N, D) # bs*N, 64
378
+ pos = flattened.to(torch.float16)
379
+ ratio = 1.0 * self.downsmaple_num_latents / N
380
+ idx = fps(pos, batch, ratio=ratio).detach()
381
+ sharp_query = sharp_data.view(bs * N, -1)[idx].view(
382
+ bs, -1, sharp_data.shape[-1]
383
+ )
384
+ query = torch.cat([query, sharp_query], dim=1)
385
+ else:
386
+ query = self.query
387
+ query = repeat(query, "m c -> b m c", b=bs)
388
+
389
+ latents = self.cross_attn(query, data)
390
+ if self.with_sharp_data:
391
+ latents = latents + self.cross_attn_sharp(query, sharp_data)
392
+ latents = self.self_attn(latents)
393
+
394
+ if self.ln_post is not None:
395
+ latents = self.ln_post(latents)
396
+
397
+ return latents
398
+
399
+ def forward(
400
+ self,
401
+ pc: torch.FloatTensor,
402
+ feats: Optional[torch.FloatTensor] = None,
403
+ sharp_pc: Optional[torch.FloatTensor] = None,
404
+ sharp_feats: Optional[torch.FloatTensor] = None,
405
+ ):
406
+ """
407
+
408
+ Args:
409
+ pc (torch.FloatTensor): [B, N, 3]
410
+ feats (torch.FloatTensor or None): [B, N, C]
411
+
412
+ Returns:
413
+ dict
414
+ """
415
+
416
+ return checkpoint(
417
+ self._forward,
418
+ (pc, feats, sharp_pc, sharp_feats),
419
+ self.parameters(),
420
+ self.use_checkpoint,
421
+ )
422
+
423
+
424
+ class PerceiverCrossAttentionDecoder(ModelMixin, nn.Module):
425
+
426
+ def __init__(
427
+ self,
428
+ num_latents: int,
429
+ out_dim: int,
430
+ embedder: FourierEmbedder,
431
+ width: int,
432
+ heads: int,
433
+ init_scale: float = 0.25,
434
+ qkv_bias: bool = True,
435
+ qk_norm: bool = True,
436
+ use_flash: bool = False,
437
+ use_checkpoint: bool = False,
438
+ ):
439
+
440
+ super().__init__()
441
+
442
+ self.use_checkpoint = use_checkpoint
443
+ self.embedder = embedder
444
+
445
+ self.query_proj = nn.Linear(self.embedder.out_dim, width)
446
+
447
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
448
+ n_data=num_latents,
449
+ width=width,
450
+ heads=heads,
451
+ init_scale=init_scale,
452
+ qkv_bias=qkv_bias,
453
+ qk_norm=qk_norm,
454
+ use_flash=use_flash,
455
+ )
456
+
457
+ self.ln_post = nn.LayerNorm(width)
458
+ self.output_proj = nn.Linear(width, out_dim)
459
+
460
+ def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
461
+ queries = self.query_proj(self.embedder(queries))
462
+ x = self.cross_attn_decoder(queries, latents)
463
+ x = self.ln_post(x)
464
+ x = self.output_proj(x)
465
+ return x
466
+
467
+ def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
468
+ return checkpoint(
469
+ self._forward, (queries, latents), self.parameters(), self.use_checkpoint
470
+ )
471
+
472
+
473
+ @step1x3d_geometry.register("michelangelo-autoencoder")
474
+ class MichelangeloAutoencoder(BaseModule):
475
+ r"""
476
+ A VAE model for encoding shapes into latents and decoding latent representations into shapes.
477
+ """
478
+
479
+ @dataclass
480
+ class Config(BaseModule.Config):
481
+ pretrained_model_name_or_path: str = ""
482
+ subfolder: str = ""
483
+ n_samples: int = 4096
484
+ use_downsample: bool = False
485
+ downsample_ratio: float = 0.0625
486
+ num_latents: int = 256
487
+ point_feats: int = 0
488
+ embed_point_feats: bool = False
489
+ out_dim: int = 1
490
+ embed_dim: int = 64
491
+ embed_type: str = "fourier"
492
+ num_freqs: int = 8
493
+ include_pi: bool = True
494
+ width: int = 768
495
+ heads: int = 12
496
+ num_encoder_layers: int = 8
497
+ num_decoder_layers: int = 16
498
+ init_scale: float = 0.25
499
+ qkv_bias: bool = True
500
+ qk_norm: bool = False
501
+ use_ln_post: bool = False
502
+ use_flash: bool = False
503
+ use_checkpoint: bool = True
504
+ use_multi_reso: Optional[bool] = False
505
+ resolutions: Optional[List[int]] = None
506
+ sampling_prob: Optional[List[float]] = None
507
+ with_sharp_data: Optional[bool] = True
508
+ volume_decoder_type: str = "hierarchical"
509
+ surface_extractor_type: str = "mc"
510
+ z_scale_factor: float = 1.0
511
+
512
+ cfg: Config
513
+
514
+ def configure(self) -> None:
515
+ super().configure()
516
+
517
+ self.embedder = get_embedder(
518
+ embed_type=self.cfg.embed_type,
519
+ num_freqs=self.cfg.num_freqs,
520
+ include_pi=self.cfg.include_pi,
521
+ )
522
+
523
+ # encoder
524
+ self.cfg.init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
525
+ self.encoder = PerceiverCrossAttentionEncoder(
526
+ use_downsample=self.cfg.use_downsample,
527
+ embedder=self.embedder,
528
+ num_latents=self.cfg.num_latents,
529
+ point_feats=self.cfg.point_feats,
530
+ embed_point_feats=self.cfg.embed_point_feats,
531
+ width=self.cfg.width,
532
+ heads=self.cfg.heads,
533
+ layers=self.cfg.num_encoder_layers,
534
+ init_scale=self.cfg.init_scale,
535
+ qkv_bias=self.cfg.qkv_bias,
536
+ qk_norm=self.cfg.qk_norm,
537
+ use_ln_post=self.cfg.use_ln_post,
538
+ use_flash=self.cfg.use_flash,
539
+ use_checkpoint=self.cfg.use_checkpoint,
540
+ use_multi_reso=self.cfg.use_multi_reso,
541
+ resolutions=self.cfg.resolutions,
542
+ sampling_prob=self.cfg.sampling_prob,
543
+ with_sharp_data=self.cfg.with_sharp_data,
544
+ )
545
+
546
+ if self.cfg.embed_dim > 0:
547
+ # VAE embed
548
+ self.pre_kl = nn.Linear(self.cfg.width, self.cfg.embed_dim * 2)
549
+ self.post_kl = nn.Linear(self.cfg.embed_dim, self.cfg.width)
550
+ self.latent_shape = (self.cfg.num_latents, self.cfg.embed_dim)
551
+ else:
552
+ self.latent_shape = (self.cfg.num_latents, self.cfg.width)
553
+
554
+ self.transformer = Perceiver(
555
+ n_ctx=self.cfg.num_latents,
556
+ width=self.cfg.width,
557
+ layers=self.cfg.num_decoder_layers,
558
+ heads=self.cfg.heads,
559
+ init_scale=self.cfg.init_scale,
560
+ qkv_bias=self.cfg.qkv_bias,
561
+ qk_norm=self.cfg.qk_norm,
562
+ use_flash=self.cfg.use_flash,
563
+ use_checkpoint=self.cfg.use_checkpoint,
564
+ )
565
+
566
+ # decoder
567
+ self.decoder = PerceiverCrossAttentionDecoder(
568
+ embedder=self.embedder,
569
+ out_dim=self.cfg.out_dim,
570
+ num_latents=self.cfg.num_latents,
571
+ width=self.cfg.width,
572
+ heads=self.cfg.heads,
573
+ init_scale=self.cfg.init_scale,
574
+ qkv_bias=self.cfg.qkv_bias,
575
+ qk_norm=self.cfg.qk_norm,
576
+ use_flash=self.cfg.use_flash,
577
+ use_checkpoint=self.cfg.use_checkpoint,
578
+ )
579
+
580
+ # volume decoder
581
+ if self.cfg.volume_decoder_type == "hierarchical":
582
+ self.volume_decoder = HierarchicalVolumeDecoder()
583
+ else:
584
+ self.volume_decoder = VanillaVolumeDecoder()
585
+
586
+ if self.cfg.pretrained_model_name_or_path != "":
587
+ local_model_path = f"{smart_load_model(self.cfg.pretrained_model_name_or_path, self.cfg.subfolder)}/vae/diffusion_pytorch_model.safetensors"
588
+ pretrain_safetensors = load_file(local_model_path)
589
+ print(f"Loading pretrained VAE model from {local_model_path}")
590
+
591
+ if "state_dict" in pretrain_safetensors:
592
+ _pretrained_safetensors = {}
593
+ for k, v in pretrain_safetensors["state_dict"].items():
594
+ if k.startswith("shape_model."):
595
+ if "proj1" in k:
596
+ _pretrained_safetensors[
597
+ k.replace("shape_model.", "").replace(
598
+ "proj1", "proj_sharp"
599
+ )
600
+ ] = v
601
+ elif "attn1" in k:
602
+ _pretrained_safetensors[
603
+ k.replace("shape_model.", "").replace(
604
+ "attn1", "attn_sharp"
605
+ )
606
+ ] = v
607
+ else:
608
+ _pretrained_safetensors[k.replace("shape_model.", "")] = v
609
+
610
+ pretrain_safetensors = _pretrained_safetensors
611
+ self.load_state_dict(pretrain_safetensors, strict=True)
612
+ else:
613
+ _pretrained_safetensors = {}
614
+ for k, v in pretrain_safetensors.items():
615
+ if k.startswith("shape_model"):
616
+ final_module = self
617
+ for key in k.replace("shape_model.", "").split("."):
618
+ final_module = getattr(final_module, key)
619
+ data = final_module.data
620
+ data_zero = torch.zeros_like(data).to(v)
621
+
622
+ if data.shape != v.shape:
623
+ if data.ndim == 1:
624
+ data_zero[: v.shape[0]] = v
625
+ elif data.ndim == 2:
626
+ data_zero[: v.shape[0], : v.shape[1]] = v
627
+ v = data_zero
628
+
629
+ _pretrained_safetensors[k.replace("shape_model.", "")] = v
630
+ else:
631
+ _pretrained_safetensors[k] = v
632
+ pretrain_safetensors = _pretrained_safetensors
633
+ self.load_state_dict(pretrain_safetensors, strict=True)
634
+ print("Successfully loaded pretrained VAE model")
635
+
636
+ def encode(
637
+ self,
638
+ surface: torch.FloatTensor,
639
+ sample_posterior: bool = True,
640
+ sharp_surface: torch.FloatTensor = None,
641
+ ):
642
+ """
643
+ Args:
644
+ surface (torch.FloatTensor): [B, N, 3+C]
645
+ sample_posterior (bool):
646
+
647
+ Returns:
648
+ shape_latents (torch.FloatTensor): [B, num_latents, width]
649
+ kl_embed (torch.FloatTensor): [B, num_latents, embed_dim]
650
+ posterior (DiagonalGaussianDistribution or None):
651
+ """
652
+ assert (
653
+ surface.shape[-1] == 3 + self.cfg.point_feats
654
+ ), f"\
655
+ Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}"
656
+
657
+ pc, feats = surface[..., :3], surface[..., 3:] # B, n_samples, 3
658
+ if sharp_surface is not None:
659
+ sharp_pc, sharp_feats = (
660
+ sharp_surface[..., :3],
661
+ sharp_surface[..., 3:],
662
+ ) # B, n_samples, 3
663
+ else:
664
+ sharp_pc, sharp_feats = None, None
665
+
666
+ shape_embeds = self.encoder(
667
+ pc, feats, sharp_pc, sharp_feats
668
+ ) # B, num_latents, width
669
+ kl_embed, posterior = self.encode_kl_embed(
670
+ shape_embeds, sample_posterior
671
+ ) # B, num_latents, embed_dim
672
+
673
+ kl_embed = kl_embed * self.cfg.z_scale_factor # encode with scale
674
+
675
+ return shape_embeds, kl_embed, posterior
676
+
677
+ def decode(self, latents: torch.FloatTensor):
678
+ """
679
+ Args:
680
+ latents (torch.FloatTensor): [B, embed_dim]
681
+
682
+ Returns:
683
+ latents (torch.FloatTensor): [B, embed_dim]
684
+ """
685
+ latents = self.post_kl(
686
+ latents / self.cfg.z_scale_factor
687
+ ) # [B, num_latents, embed_dim] -> [B, num_latents, width]
688
+
689
+ return self.transformer(latents)
690
+
691
+ def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
692
+ """
693
+ Args:
694
+ queries (torch.FloatTensor): [B, N, 3]
695
+ latents (torch.FloatTensor): [B, embed_dim]
696
+
697
+ Returns:
698
+ features (torch.FloatTensor): [B, N, C], output features
699
+ """
700
+
701
+ features = self.decoder(queries, latents)
702
+
703
+ return features
704
+
705
+ def encode_kl_embed(
706
+ self, latents: torch.FloatTensor, sample_posterior: bool = True
707
+ ):
708
+ posterior = None
709
+ if self.cfg.embed_dim > 0:
710
+ moments = self.pre_kl(latents)
711
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
712
+ if sample_posterior:
713
+ kl_embed = posterior.sample()
714
+ else:
715
+ kl_embed = posterior.mode()
716
+ else:
717
+ kl_embed = latents
718
+ return kl_embed, posterior
719
+
720
+ def forward(
721
+ self,
722
+ surface: torch.FloatTensor,
723
+ sharp_surface: torch.FloatTensor = None,
724
+ rand_points: torch.FloatTensor = None,
725
+ sample_posterior: bool = True,
726
+ **kwargs,
727
+ ):
728
+ shape_latents, kl_embed, posterior = self.encode(
729
+ surface, sample_posterior=sample_posterior, sharp_surface=sharp_surface
730
+ )
731
+
732
+ latents = self.decode(kl_embed) # [B, num_latents, width]
733
+
734
+ meshes = self.extract_geometry(latents, **kwargs)
735
+
736
+ return shape_latents, latents, posterior, meshes
737
+
738
+ def extract_geometry(self, latents: torch.FloatTensor, **kwargs):
739
+
740
+ grid_logits_list = []
741
+ for i in range(latents.shape[0]):
742
+ grid_logits = self.volume_decoder(
743
+ latents[i].unsqueeze(0), self.query, **kwargs
744
+ )
745
+ grid_logits_list.append(grid_logits)
746
+ grid_logits = torch.cat(grid_logits_list, dim=0)
747
+
748
+ # extract mesh
749
+ surface_extractor_type = (
750
+ kwargs["surface_extractor_type"]
751
+ if "surface_extractor_type" in kwargs.keys()
752
+ and kwargs["surface_extractor_type"] is not None
753
+ else self.cfg.surface_extractor_type
754
+ )
755
+
756
+ if surface_extractor_type == "mc":
757
+ surface_extractor = MCSurfaceExtractor()
758
+ meshes = surface_extractor(grid_logits, **kwargs)
759
+ elif surface_extractor_type == "dmc":
760
+ surface_extractor = DMCSurfaceExtractor()
761
+ meshes = surface_extractor(grid_logits, **kwargs)
762
+ else:
763
+ raise NotImplementedError
764
+
765
+ return meshes
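Note (not part of the diff): a sketch of the round trip this class implements, written as a helper so no project-specific construction is assumed; the extraction kwargs below are illustrative, not shipped defaults.

import torch

def reconstruct(vae, surface: torch.Tensor):
    """surface: [B, N, 3 + point_feats]; vae: an already-configured MichelangeloAutoencoder."""
    shape_latents, kl_embed, _posterior = vae.encode(surface, sample_posterior=False)
    latents = vae.decode(kl_embed)                      # [B, num_latents, width]
    return vae.extract_geometry(
        latents,
        bounds=1.05,                                    # half-extent of the symmetric bounding box
        mc_level=0.0,                                   # iso-level for marching cubes
        octree_resolution=256,
        surface_extractor_type="mc",
    )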
step1x3d_geometry/models/autoencoders/surface_extractors.py ADDED
@@ -0,0 +1,137 @@
1
+ from typing import Union, Tuple, List
2
+
3
+ import numpy as np
4
+ import torch
5
+ from skimage import measure
6
+
7
+
8
+ class MeshExtractResult:
9
+ def __init__(self, verts, faces, vertex_attrs=None, res=64):
10
+ self.verts = verts
11
+ self.faces = faces.long()
12
+ self.vertex_attrs = vertex_attrs
13
+ self.face_normal = self.comput_face_normals()
14
+ self.vert_normal = self.comput_v_normals()
15
+ self.res = res
16
+ self.success = verts.shape[0] != 0 and faces.shape[0] != 0
17
+
18
+ # training only
19
+ self.tsdf_v = None
20
+ self.tsdf_s = None
21
+ self.reg_loss = None
22
+
23
+ def comput_face_normals(self):
24
+ i0 = self.faces[..., 0].long()
25
+ i1 = self.faces[..., 1].long()
26
+ i2 = self.faces[..., 2].long()
27
+
28
+ v0 = self.verts[i0, :]
29
+ v1 = self.verts[i1, :]
30
+ v2 = self.verts[i2, :]
31
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
32
+ face_normals = torch.nn.functional.normalize(face_normals, dim=1)
33
+ return face_normals[:, None, :].repeat(1, 3, 1)
34
+
35
+ def comput_v_normals(self):
36
+ i0 = self.faces[..., 0].long()
37
+ i1 = self.faces[..., 1].long()
38
+ i2 = self.faces[..., 2].long()
39
+
40
+ v0 = self.verts[i0, :]
41
+ v1 = self.verts[i1, :]
42
+ v2 = self.verts[i2, :]
43
+ face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
44
+ v_normals = torch.zeros_like(self.verts)
45
+ v_normals.scatter_add_(0, i0[..., None].repeat(1, 3), face_normals)
46
+ v_normals.scatter_add_(0, i1[..., None].repeat(1, 3), face_normals)
47
+ v_normals.scatter_add_(0, i2[..., None].repeat(1, 3), face_normals)
48
+
49
+ v_normals = torch.nn.functional.normalize(v_normals, dim=1)
50
+ return v_normals
51
+
52
+
53
+ def center_vertices(vertices):
54
+ """Translate the vertices so that bounding box is centered at zero."""
55
+ vert_min = vertices.min(dim=0)[0]
56
+ vert_max = vertices.max(dim=0)[0]
57
+ vert_center = 0.5 * (vert_min + vert_max)
58
+ return vertices - vert_center
59
+
60
+
61
+ class SurfaceExtractor:
62
+ def _compute_box_stat(
63
+ self, bounds: Union[Tuple[float], List[float], float], octree_resolution: int
64
+ ):
65
+ if isinstance(bounds, float):
66
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
67
+
68
+ bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
69
+ bbox_size = bbox_max - bbox_min
70
+ grid_size = [
71
+ int(octree_resolution) + 1,
72
+ int(octree_resolution) + 1,
73
+ int(octree_resolution) + 1,
74
+ ]
75
+ return grid_size, bbox_min, bbox_size
76
+
77
+ def run(self, *args, **kwargs):
78
+ raise NotImplementedError
79
+
80
+ def __call__(self, grid_logits, **kwargs):
81
+ outputs = []
82
+ for i in range(grid_logits.shape[0]):
83
+ try:
84
+ verts, faces = self.run(grid_logits[i], **kwargs)
85
+ outputs.append(
86
+ MeshExtractResult(
87
+ verts=verts.float(),
88
+ faces=faces,
89
+ res=kwargs["octree_resolution"],
90
+ )
91
+ )
92
+
93
+ except Exception:
94
+ import traceback
95
+
96
+ traceback.print_exc()
97
+ outputs.append(None)
98
+
99
+ return outputs
100
+
101
+
102
+ class MCSurfaceExtractor(SurfaceExtractor):
103
+ def run(self, grid_logit, *, mc_level, bounds, octree_resolution, **kwargs):
104
+ verts, faces, normals, _ = measure.marching_cubes(
105
+ grid_logit.float().cpu().numpy(), mc_level, method="lewiner"
106
+ )
107
+ grid_size, bbox_min, bbox_size = self._compute_box_stat(
108
+ bounds, octree_resolution
109
+ )
110
+ verts = verts / grid_size * bbox_size + bbox_min
111
+ verts = torch.tensor(verts, device=grid_logit.device, dtype=torch.float32)
112
+ faces = torch.tensor(
113
+ np.ascontiguousarray(faces), device=grid_logit.device, dtype=torch.long
114
+ )
115
+ faces = faces[:, [2, 1, 0]]
116
+ return verts, faces
117
+
118
+
119
+ class DMCSurfaceExtractor(SurfaceExtractor):
120
+ def run(self, grid_logit, *, octree_resolution, **kwargs):
121
+ device = grid_logit.device
122
+ if not hasattr(self, "dmc"):
123
+ try:
124
+ from diso import DiffDMC
125
+ except ImportError:
126
+ raise ImportError(
127
+ "Please install diso via `pip install diso`, or set surface_extractor_type to 'mc'"
128
+ )
129
+ self.dmc = DiffDMC(dtype=torch.float32).to(device)
130
+ sdf = -grid_logit / octree_resolution
131
+ sdf = sdf.to(torch.float32).contiguous()
132
+ verts, faces = self.dmc(sdf, deform=None, return_quads=False, normalize=True)
133
+ grid_size, bbox_min, bbox_size = self._compute_box_stat(
134
+ kwargs["bounds"], octree_resolution
135
+ )
136
+ verts = verts * kwargs["bounds"] * 2 - kwargs["bounds"]
137
+ return verts, faces
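Note (not part of the diff): the extractors above operate on a dense grid of occupancy/SDF-style logits; a synthetic sphere is enough to exercise the marching-cubes path. Resolution and bounds are arbitrary.

import torch

res = 64
axis = torch.linspace(-1.0, 1.0, res + 1)
x, y, z = torch.meshgrid(axis, axis, axis, indexing="ij")
grid_logits = 0.5 - (x**2 + y**2 + z**2).sqrt()        # positive inside a sphere of radius 0.5
grid_logits = grid_logits.unsqueeze(0)                  # [B, res+1, res+1, res+1]

meshes = MCSurfaceExtractor()(grid_logits, mc_level=0.0, bounds=1.0, octree_resolution=res)
print(meshes[0].verts.shape, meshes[0].faces.shape)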
step1x3d_geometry/models/autoencoders/transformers/attention.py ADDED
@@ -0,0 +1,286 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from step1x3d_geometry.utils.typing import *
7
+ from step1x3d_geometry.utils.checkpoint import checkpoint
8
+
9
+ from .utils import init_linear, MLP
10
+ from timm.models.vision_transformer import Attention
11
+
12
+
13
+ class MultiheadAttention(nn.Module):
14
+ def __init__(
15
+ self,
16
+ *,
17
+ n_ctx: int,
18
+ width: int,
19
+ heads: int,
20
+ init_scale: float,
21
+ qkv_bias: bool,
22
+ qk_norm: bool,
23
+ norm_layer=nn.LayerNorm,
24
+ use_flash: bool = False,
25
+ ):
26
+ super().__init__()
27
+         self.n_ctx = n_ctx
+         self.width = width
+         self.heads = heads
+         self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias)
+         self.c_proj = nn.Linear(width, width)
+         self.attention = QKVMultiheadAttention(
+             heads=heads,
+             n_ctx=n_ctx,
+             width=width,
+             norm_layer=norm_layer,
+             qk_norm=qk_norm,
+             use_flash=use_flash,
+         )
+         init_linear(self.c_qkv, init_scale)
+         init_linear(self.c_proj, init_scale)
+
+     def forward(self, x):
+         x = self.c_qkv(x)
+         x = checkpoint(self.attention, (x,), (), True)
+         x = self.c_proj(x)
+         return x
+
+
+ class QKVMultiheadAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         heads: int,
+         n_ctx: int,
+         width=None,
+         qk_norm: bool = False,
+         norm_layer=nn.LayerNorm,
+         use_flash: bool = False,
+     ):
+         super().__init__()
+         self.heads = heads
+         self.n_ctx = n_ctx
+         self.use_flash = use_flash
+
+         self.q_norm = (
+             norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
+             if qk_norm
+             else nn.Identity()
+         )
+         self.k_norm = (
+             norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
+             if qk_norm
+             else nn.Identity()
+         )
+
+     def forward(self, qkv):
+         bs, n_ctx, width = qkv.shape
+         attn_ch = width // self.heads // 3
+         scale = 1 / math.sqrt(math.sqrt(attn_ch))
+         qkv = qkv.view(bs, n_ctx, self.heads, -1)
+         q, k, v = torch.split(qkv, attn_ch, dim=-1)
+
+         q = self.q_norm(q)
+         k = self.k_norm(k)
+
+         if self.use_flash:
+             q = q.permute(0, 2, 1, 3)
+             k = k.permute(0, 2, 1, 3)
+             v = v.permute(0, 2, 1, 3)
+             out = (
+                 F.scaled_dot_product_attention(q, k, v)
+                 .permute(0, 2, 1, 3)
+                 .reshape(bs, n_ctx, -1)
+             )
+         else:
+             weight = torch.einsum(
+                 "bthc,bshc->bhts", q * scale, k * scale
+             )  # More stable with f16 than dividing afterwards
+             wdtype = weight.dtype
+             weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+             out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+         return out
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         n_ctx: int,
+         width: int,
+         heads: int,
+         init_scale: float = 1.0,
+         qkv_bias: bool = True,
+         norm_layer=nn.LayerNorm,
+         qk_norm: bool = True,
+         use_flash: bool = False,
+         use_checkpoint: bool = False,
+     ):
+         super().__init__()
+
+         self.use_checkpoint = use_checkpoint
+
+         self.attn = MultiheadAttention(
+             n_ctx=n_ctx,
+             width=width,
+             heads=heads,
+             init_scale=init_scale,
+             qkv_bias=qkv_bias,
+             norm_layer=norm_layer,
+             qk_norm=qk_norm,
+             use_flash=use_flash,
+         )
+         self.ln_1 = nn.LayerNorm(width)
+         self.mlp = MLP(width=width, init_scale=init_scale)
+         self.ln_2 = nn.LayerNorm(width)
+
+     def _forward(self, x: torch.Tensor):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+     def forward(self, x: torch.Tensor):
+         return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
+
+
+ class MultiheadCrossAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         width: int,
+         heads: int,
+         init_scale: float,
+         qkv_bias: bool = True,
+         norm_layer=nn.LayerNorm,
+         qk_norm: bool = True,
+         use_flash: bool = False,
+         n_data: Optional[int] = None,
+         data_width: Optional[int] = None,
+     ):
+         super().__init__()
+         self.n_data = n_data
+         self.width = width
+         self.heads = heads
+         self.data_width = width if data_width is None else data_width
+         self.c_q = nn.Linear(width, width, bias=qkv_bias)
+         self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias)
+         self.c_proj = nn.Linear(width, width)
+         self.attention = QKVMultiheadCrossAttention(
+             heads=heads,
+             n_data=n_data,
+             width=width,
+             norm_layer=norm_layer,
+             qk_norm=qk_norm,
+             use_flash=use_flash,
+         )
+         init_linear(self.c_q, init_scale)
+         init_linear(self.c_kv, init_scale)
+         init_linear(self.c_proj, init_scale)
+
+     def forward(self, x, data):
+         x = self.c_q(x)
+         data = self.c_kv(data)
+         x = checkpoint(self.attention, (x, data), (), True)
+         x = self.c_proj(x)
+         return x
+
+
+ class QKVMultiheadCrossAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         heads: int,
+         n_data: Optional[int] = None,
+         width=None,
+         norm_layer=nn.LayerNorm,
+         qk_norm: bool = False,
+         use_flash: bool = False,
+     ):
+         super().__init__()
+         self.heads = heads
+         self.n_data = n_data
+         self.use_flash = use_flash
+
+         self.q_norm = (
+             norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
+             if qk_norm
+             else nn.Identity()
+         )
+         self.k_norm = (
+             norm_layer(width // heads, elementwise_affine=True, eps=1e-6)
+             if qk_norm
+             else nn.Identity()
+         )
+
+     def forward(self, q, kv):
+         _, n_ctx, _ = q.shape
+         bs, n_data, width = kv.shape
+         attn_ch = width // self.heads // 2
+         scale = 1 / math.sqrt(math.sqrt(attn_ch))
+         q = q.view(bs, n_ctx, self.heads, -1)
+         kv = kv.view(bs, n_data, self.heads, -1)
+         k, v = torch.split(kv, attn_ch, dim=-1)
+
+         q = self.q_norm(q)
+         k = self.k_norm(k)
+
+         if self.use_flash:
+             q = q.permute(0, 2, 1, 3)
+             k = k.permute(0, 2, 1, 3)
+             v = v.permute(0, 2, 1, 3)
+             out = (
+                 F.scaled_dot_product_attention(q, k, v)
+                 .permute(0, 2, 1, 3)
+                 .reshape(bs, n_ctx, -1)
+             )
+         else:
+             weight = torch.einsum(
+                 "bthc,bshc->bhts", q * scale, k * scale
+             )  # More stable with f16 than dividing afterwards
+             wdtype = weight.dtype
+             weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+             out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+         return out
+
+
+ class ResidualCrossAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         n_data: Optional[int] = None,
+         width: int,
+         heads: int,
+         data_width: Optional[int] = None,
+         init_scale: float = 0.25,
+         qkv_bias: bool = True,
+         qk_norm: bool = True,
+         use_flash: bool = False,
+     ):
+         super().__init__()
+
+         if data_width is None:
+             data_width = width
+
+         self.attn = MultiheadCrossAttention(
+             n_data=n_data,
+             width=width,
+             heads=heads,
+             data_width=data_width,
+             init_scale=init_scale,
+             qkv_bias=qkv_bias,
+             qk_norm=qk_norm,
+             use_flash=use_flash,
+         )
+         self.ln_1 = nn.LayerNorm(width)
+         self.ln_2 = nn.LayerNorm(data_width)
+         self.mlp = MLP(width=width, init_scale=init_scale)
+         self.ln_3 = nn.LayerNorm(width)
+
+     def forward(self, x: torch.Tensor, data: torch.Tensor):
+         x = x + self.attn(self.ln_1(x), self.ln_2(data))
+         x = x + self.mlp(self.ln_3(x))
+         return x
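A minimal, illustrative sketch of how the blocks above fit together; the tensor sizes are arbitrary and the import path is assumed from this file's location:

import torch
from step1x3d_geometry.models.autoencoders.transformers.attention import (
    ResidualAttentionBlock,
    ResidualCrossAttentionBlock,
)

block = ResidualAttentionBlock(n_ctx=256, width=512, heads=8)            # pre-LN self-attention + MLP
cross = ResidualCrossAttentionBlock(width=512, heads=8, data_width=768)  # attends to external tokens
x = torch.randn(2, 256, 512)      # (batch, tokens, width)
data = torch.randn(2, 1024, 768)  # (batch, data tokens, data_width)
assert block(x).shape == x.shape
assert cross(x, data).shape == x.shape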
step1x3d_geometry/models/autoencoders/transformers/perceiver_1d.py ADDED
@@ -0,0 +1,50 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from step1x3d_geometry.utils.typing import *
+ from step1x3d_geometry.utils.checkpoint import checkpoint
+
+ from .utils import init_linear
+ from .attention import ResidualAttentionBlock
+
+
+ class Perceiver(nn.Module):
+     def __init__(
+         self,
+         *,
+         n_ctx: int,
+         width: int,
+         layers: int,
+         heads: int,
+         init_scale: float = 0.25,
+         qkv_bias: bool = True,
+         qk_norm: bool = True,
+         use_flash: bool = False,
+         use_checkpoint: bool = False,
+     ):
+         super().__init__()
+         self.n_ctx = n_ctx
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(
+                     n_ctx=n_ctx,
+                     width=width,
+                     heads=heads,
+                     init_scale=init_scale,
+                     qkv_bias=qkv_bias,
+                     qk_norm=qk_norm,
+                     use_flash=use_flash,
+                     use_checkpoint=use_checkpoint,
+                 )
+                 for _ in range(layers)
+             ]
+         )
+
+     def forward(self, x: torch.Tensor):
+         for block in self.resblocks:
+             x = block(x)
+         return x
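An illustrative instantiation of the Perceiver above (a plain stack of `layers` ResidualAttentionBlock modules over a fixed-length token sequence); the sizes and import path are assumptions:

import torch
from step1x3d_geometry.models.autoencoders.transformers.perceiver_1d import Perceiver

model = Perceiver(n_ctx=512, width=768, layers=4, heads=12)
tokens = torch.randn(1, 512, 768)   # (batch, n_ctx, width)
out = model(tokens)                 # shape preserved: (1, 512, 768)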
step1x3d_geometry/models/autoencoders/transformers/utils.py ADDED
@@ -0,0 +1,21 @@
+ import torch.nn as nn
+
+
+ def init_linear(l, stddev):
+     nn.init.normal_(l.weight, std=stddev)
+     if l.bias is not None:
+         nn.init.constant_(l.bias, 0.0)
+
+
+ class MLP(nn.Module):
+     def __init__(self, *, width: int, init_scale: float):
+         super().__init__()
+         self.width = width
+         self.c_fc = nn.Linear(width, width * 4)
+         self.c_proj = nn.Linear(width * 4, width)
+         self.gelu = nn.GELU()
+         init_linear(self.c_fc, init_scale)
+         init_linear(self.c_proj, init_scale)
+
+     def forward(self, x):
+         return self.c_proj(self.gelu(self.c_fc(x)))
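A small, assumed usage sketch: MLP is the usual 4x-expansion feed-forward block, and init_linear draws weights from N(0, stddev^2) with zero bias.

import torch
from step1x3d_geometry.models.autoencoders.transformers.utils import MLP

mlp = MLP(width=256, init_scale=0.25)   # Linear(256 -> 1024) -> GELU -> Linear(1024 -> 256)
x = torch.randn(4, 64, 256)
assert mlp(x).shape == x.shape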
step1x3d_geometry/models/autoencoders/volume_decoders.py ADDED
@@ -0,0 +1,327 @@
+ # Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
+ # except for the third-party components listed below.
+ # Hunyuan 3D does not impose any additional limitations beyond what is outlined
+ # in the repsective licenses of these third-party components.
+ # Users must comply with all terms and conditions of original licenses of these third-party
+ # components and must ensure that the usage of the third party components adheres to
+ # all relevant laws and regulations.
+
+ # For avoidance of doubts, Hunyuan 3D means the large language models and
+ # their software and algorithms, including trained model weights, parameters (including
+ # optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
+ # fine-tuning enabling code and other elements of the foregoing made publicly available
+ # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
+
+ from typing import Union, Tuple, List, Callable
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import repeat
+ from tqdm import tqdm
+
+ cube_corners = torch.tensor(
+     [
+         [0, 0, 0],
+         [1, 0, 0],
+         [0, 1, 0],
+         [1, 1, 0],
+         [0, 0, 1],
+         [1, 0, 1],
+         [0, 1, 1],
+         [1, 1, 1],
+     ],
+     dtype=torch.int,
+ )
+
+
+ def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
+     device = input_tensor.device
+     D = input_tensor.shape[0]
+     signed_val = 0.0
+
+     # Apply the offset and mask out invalid values
+     val = input_tensor + alpha
+     valid_mask = val > -9000  # -9000 is treated as the invalid-value sentinel
+
+     # Neighbor lookup that keeps the tensor dimensions unchanged
+     def get_neighbor(t, shift, axis):
+         """Shift along the given axis while keeping the shape unchanged."""
+         if shift == 0:
+             return t.clone()
+
+         # Determine the padding axis (the [D, D, D] input corresponds to the z, y, x axes)
+         pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_front, x_back, y_front, y_back, z_front, z_back]
+
+         # Set the padding according to the axis
+         if axis == 0:  # x axis (last dimension)
+             pad_idx = 0 if shift > 0 else 1
+             pad_dims[pad_idx] = abs(shift)
+         elif axis == 1:  # y axis (middle dimension)
+             pad_idx = 2 if shift > 0 else 3
+             pad_dims[pad_idx] = abs(shift)
+         elif axis == 2:  # z axis (first dimension)
+             pad_idx = 4 if shift > 0 else 5
+             pad_dims[pad_idx] = abs(shift)
+
+         # Apply the padding (add batch and channel dimensions so F.pad accepts the input)
+         padded = F.pad(
+             t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode="replicate"
+         )  # the order is reversed to match F.pad's convention
+
+         # Build the dynamic slicing indices
+         slice_dims = [slice(None)] * 3  # start from full slices
+         if axis == 0:  # x axis (dim=2)
+             if shift > 0:
+                 slice_dims[0] = slice(shift, None)
+             else:
+                 slice_dims[0] = slice(None, shift)
+         elif axis == 1:  # y axis (dim=1)
+             if shift > 0:
+                 slice_dims[1] = slice(shift, None)
+             else:
+                 slice_dims[1] = slice(None, shift)
+         elif axis == 2:  # z axis (dim=0)
+             if shift > 0:
+                 slice_dims[2] = slice(shift, None)
+             else:
+                 slice_dims[2] = slice(None, shift)
+
+         # Apply the slices and restore the original dimensions
+         padded = padded.squeeze(0).squeeze(0)
+         sliced = padded[slice_dims]
+         return sliced
+
+     # Gather neighbors in every direction (keeping the dimensions consistent)
+     left = get_neighbor(val, 1, axis=0)  # x direction
+     right = get_neighbor(val, -1, axis=0)
+     back = get_neighbor(val, 1, axis=1)  # y direction
+     front = get_neighbor(val, -1, axis=1)
+     down = get_neighbor(val, 1, axis=2)  # z direction
+     up = get_neighbor(val, -1, axis=2)
+
+     # Handle invalid boundary values (torch.where keeps the dimensions consistent)
+     def safe_where(neighbor):
+         return torch.where(neighbor > -9000, neighbor, val)
+
+     left = safe_where(left)
+     right = safe_where(right)
+     back = safe_where(back)
+     front = safe_where(front)
+     down = safe_where(down)
+     up = safe_where(up)
+
+     # Compute sign consistency (cast to float32 for numerical safety)
+     sign = torch.sign(val.to(torch.float32))
+     neighbors_sign = torch.stack(
+         [
+             torch.sign(left.to(torch.float32)),
+             torch.sign(right.to(torch.float32)),
+             torch.sign(back.to(torch.float32)),
+             torch.sign(front.to(torch.float32)),
+             torch.sign(down.to(torch.float32)),
+             torch.sign(up.to(torch.float32)),
+         ],
+         dim=0,
+     )
+
+     # Check whether all neighbor signs agree with the center voxel
+     same_sign = torch.all(neighbors_sign == sign, dim=0)
+
+     # Build the final mask
+     mask = (~same_sign).to(torch.int32)
+     return mask * valid_mask.to(torch.int32)
+
+
+ def generate_dense_grid_points(
+     bbox_min: np.ndarray,
+     bbox_max: np.ndarray,
+     octree_resolution: int,
+     indexing: str = "ij",
+ ):
+     length = bbox_max - bbox_min
+     num_cells = octree_resolution
+
+     x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
+     y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
+     z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
+     [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
+     xyz = np.stack((xs, ys, zs), axis=-1)
+     grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+
+     return xyz, grid_size, length
+
+
+ class VanillaVolumeDecoder:
+     @torch.no_grad()
+     def __call__(
+         self,
+         latents: torch.FloatTensor,
+         geo_decoder: Callable,
+         bounds: Union[Tuple[float], List[float], float] = 1.01,
+         num_chunks: int = 10000,
+         octree_resolution: int = 384,
+         enable_pbar: bool = True,
+         **kwargs,
+     ):
+         device = latents.device
+         dtype = latents.dtype
+         batch_size = latents.shape[0]
+
+         # 1. generate query points
+         if isinstance(bounds, float):
+             bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+
+         bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
+         xyz_samples, grid_size, length = generate_dense_grid_points(
+             bbox_min=bbox_min,
+             bbox_max=bbox_max,
+             octree_resolution=octree_resolution,
+             indexing="ij",
+         )
+         xyz_samples = (
+             torch.from_numpy(xyz_samples)
+             .to(device, dtype=dtype)
+             .contiguous()
+             .reshape(-1, 3)
+         )
+
+         # 2. latents to 3d volume
+         batch_features = []
+         for start in tqdm(
+             range(0, xyz_samples.shape[0], num_chunks),
+             desc="Volume Decoding",
+             disable=not enable_pbar,
+         ):
+             chunk_queries = xyz_samples[start : start + num_chunks, :]
+             chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
+             features = geo_decoder(queries=chunk_queries, latents=latents)
+             batch_features.append(features)
+
+         grid_features = torch.cat(batch_features, dim=1)
+         grid_logits, grid_features = grid_features[..., 0:1], grid_features[..., 1:]
+         grid_logits = grid_logits.view((batch_size, *grid_size)).float()
+
+         return grid_logits, xyz_samples, grid_features, None
+
+
+ class HierarchicalVolumeDecoder:
+     @torch.no_grad()
+     def __call__(
+         self,
+         latents: torch.FloatTensor,
+         geo_decoder: Callable,
+         bounds: Union[Tuple[float], List[float], float] = 1.01,
+         num_chunks: int = 65536,
+         mc_level: float = 0.0,
+         octree_resolution: int = 384,
+         min_resolution: int = 63,
+         enable_pbar: bool = True,
+         empty_value: float = float("nan"),
+         **kwargs,
+     ):
+         device = latents.device
+         dtype = latents.dtype
+
+         resolutions = []
+         if octree_resolution < min_resolution:
+             resolutions.append(octree_resolution)
+         while octree_resolution >= min_resolution:
+             resolutions.append(octree_resolution)
+             octree_resolution = octree_resolution // 2
+         resolutions.reverse()
+
+         # 1. generate query points
+         if isinstance(bounds, float):
+             bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+         bbox_min = np.array(bounds[0:3])
+         bbox_max = np.array(bounds[3:6])
+         bbox_size = bbox_max - bbox_min
+
+         xyz_samples, grid_size, length = generate_dense_grid_points(
+             bbox_min=bbox_min,
+             bbox_max=bbox_max,
+             octree_resolution=resolutions[0],
+             indexing="ij",
+         )
+
+         dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
+         dilate.weight = torch.nn.Parameter(
+             torch.ones(dilate.weight.shape, dtype=dtype, device=device)
+         )
+
+         grid_size = np.array(grid_size)
+         xyz_samples = (
+             torch.from_numpy(xyz_samples)
+             .to(device, dtype=dtype)
+             .contiguous()
+             .reshape(-1, 3)
+         )
+
+         # 2. latents to 3d volume
+         batch_features = []
+         batch_size = latents.shape[0]
+         for start in tqdm(
+             range(0, xyz_samples.shape[0], num_chunks),
+             desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]",
+             disable=not enable_pbar,
+         ):
+             queries = xyz_samples[start : start + num_chunks, :]
+             batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
+             features = geo_decoder(queries=batch_queries, latents=latents)
+             batch_features.append(features)
+
+         grid_features = torch.cat(batch_features, dim=1).view(
+             (batch_size, grid_size[0], grid_size[1], grid_size[2], -1)
+         )
+         grid_logits = grid_features[..., 0]  # assume the first element is the logits
+
+         for octree_depth_now in resolutions[1:]:
+             grid_size = np.array([octree_depth_now + 1] * 3)
+             resolution = bbox_size / octree_depth_now
+             next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
+             next_logits = torch.full(
+                 next_index.shape, -10000.0, dtype=dtype, device=device
+             )
+             curr_points = extract_near_surface_volume_fn(
+                 grid_logits.squeeze(0), mc_level
+             )
+             curr_points += grid_logits.squeeze(0).abs() < 0.95
+
+             if octree_depth_now == resolutions[-1]:
+                 expand_num = 0
+             else:
+                 expand_num = 1
+             for i in range(expand_num):
+                 curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+             (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
+             next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
+             for i in range(2 - expand_num):
+                 next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
+             nidx = torch.where(next_index > 0)
+
+             next_points = torch.stack(nidx, dim=1)
+             next_points = next_points * torch.tensor(
+                 resolution, dtype=latents.dtype, device=device
+             ) + torch.tensor(bbox_min, dtype=latents.dtype, device=device)
+
+             batch_features = []
+             for start in tqdm(
+                 range(0, next_points.shape[0], num_chunks),
+                 desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]",
+                 disable=not enable_pbar,
+             ):
+                 queries = next_points[start : start + num_chunks, :]
+                 batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
+                 features = geo_decoder(
+                     queries=batch_queries.to(latents.dtype), latents=latents
+                 )
+                 batch_features.append(features)
+             grid_features = torch.cat(batch_features, dim=1)
+             grid_logits = grid_features[..., 0:1]
+             next_logits[nidx] = grid_logits[0, ..., 0]
+             grid_logits = next_logits.unsqueeze(0)
+         grid_logits[grid_logits == -10000.0] = empty_value
+
+         return grid_logits
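An illustrative sketch of driving VanillaVolumeDecoder: `dummy_geo_decoder` is a hypothetical stand-in for the real geometry decoder, which maps (queries, latents) to per-point features whose first channel is the occupancy/SDF logit, and the latent width of 64 is arbitrary.

import torch
from step1x3d_geometry.models.autoencoders.volume_decoders import VanillaVolumeDecoder

def dummy_geo_decoder(queries, latents):
    # (B, P, 3) query points -> (B, P, 1) fake logits; a real decoder conditions on `latents`
    return torch.randn(*queries.shape[:2], 1, device=queries.device, dtype=queries.dtype)

latents = torch.randn(1, 512, 64)
decoder = VanillaVolumeDecoder()
grid_logits, xyz, feats, _ = decoder(
    latents, dummy_geo_decoder, bounds=1.01, octree_resolution=64, num_chunks=8192
)
# grid_logits: (1, 65, 65, 65) logits on the dense (octree_resolution + 1)^3 grid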
step1x3d_geometry/models/conditional_encoders/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from . import (
+     dinov2_encoder,
+     dinov2_clip_encoder,
+     t5_encoder,
+     label_encoder,
+ )
step1x3d_geometry/models/conditional_encoders/base.py ADDED
@@ -0,0 +1,202 @@
+ import random
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from PIL import Image
+ from dataclasses import dataclass
+ from torchvision.transforms import Normalize
+ from torchvision.transforms import InterpolationMode
+ from torchvision.transforms.transforms import _interpolation_modes_from_int
+
+ from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor
+ from transformers.utils import ModelOutput
+ from typing import Iterable, Optional, Union, List
+
+ import step1x3d_geometry
+ from step1x3d_geometry.utils.base import BaseModule
+ from step1x3d_geometry.utils.typing import *
+
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
+
+
+ class BaseVisualEncoder(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         pretrained_model_name_or_path: Optional[str] = (
+             None  # the pretrained model name or path
+         )
+
+         encode_camera: bool = False  # whether to encode camera
+         camera_embeds_type: str = "sincos"  # the type of camera embeds
+         camera_embeds_dim: Optional[int] = None  # the dimension of camera embeds
+         n_views: int = 1  # the number of views
+
+         empty_embeds_ratio: float = 0.1  # the ratio of empty embeds
+         normalize_embeds: bool = False  # whether to normalize the embeds
+         zero_uncond_embeds: bool = True
+
+     cfg: Config
+
+     def configure(self) -> None:
+         super().configure()
+
+         if self.cfg.encode_camera:
+             self.distance = 1.0
+             self.register_buffer(
+                 "cameras",
+                 torch.as_tensor(
+                     [
+                         [
+                             [1, 0, 0, 0],
+                             [0, 0, -1, -self.distance],
+                             [0, 1, 0, 0],
+                             [0, 0, 0, 1],
+                         ],  # front to back
+                         [
+                             [0, 0, 1, self.distance],
+                             [1, 0, 0, 0],
+                             [0, 1, 0, 0],
+                             [0, 0, 0, 1],
+                         ],  # right to left
+                         [
+                             [-1, 0, 0, 0],
+                             [0, 0, 1, self.distance],
+                             [0, 1, 0, 0],
+                             [0, 0, 0, 1],
+                         ],  # back to front
+                         [
+                             [0, 0, -1, -self.distance],
+                             [-1, 0, 0, 0],
+                             [0, 1, 0, 0],
+                             [0, 0, 0, 1],
+                         ],  # left to right
+                     ],
+                     dtype=torch.float32,
+                 ),
+             )
+
+     def encode_image(
+         self,
+         images: Iterable[Optional[ImageType]],
+         camera_embeds: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> torch.FloatTensor:
+         raise NotImplementedError
+
+     def encode_camera(self, c2ws: torch.Tensor):
+         if self.cfg.camera_embeds_type == "sincos":
+             assert (
+                 c2ws.shape[-1] == 4 and c2ws.shape[-2] == 4
+             ), f"Invalid c2ws shape: {c2ws.shape}"
+             c2ws = c2ws.view(-1, 16)
+             return torch.cat([torch.sin(c2ws), torch.cos(c2ws)], dim=-1)
+         else:
+             raise NotImplementedError(
+                 f"Unknown camera_embeds_type: {self.cfg.camera_embeds_type}"
+             )
+
+     def forward(self, batch):
+         assert (
+             "image" in batch or "mvimages" in batch
+         ), "image or mvimages is required for visual embeds"
+         if batch["image"].dim() == 5:
+             bs = batch["image"].shape[0] * batch["image"].shape[1]
+         else:
+             bs = batch["image"].shape[0]
+
+         if random.random() < self.cfg.empty_embeds_ratio:
+             if "image" in batch or "image_embeds" in batch:
+                 visual_embeds = self.empty_image_embeds.repeat(bs, 1, 1)
+             elif "mvimages" in batch or "mvimage_embeds" in batch:
+                 visual_embeds = self.empty_image_embeds.unsqueeze(1).repeat(bs, 1, 1, 1)
+         else:
+             # for visual inputs
+             if "image" in batch:
+                 if self.cfg.encode_camera:
+                     visual_embeds = self.encode_image(
+                         batch["image"], cameras=batch["c2w"]
+                     )
+                 else:
+                     visual_embeds = self.encode_image(batch["image"])
+             elif "mvimages" in batch:
+                 n_views = batch["mvimages"].shape[1]
+                 if self.cfg.encode_camera:
+                     visual_embeds = self.encode_image(
+                         batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:]),
+                         cameras=batch["c2ws"],
+                     ).view(bs, n_views, *self.empty_image_embeds.shape[-2:])
+                 else:
+                     visual_embeds = self.encode_image(
+                         batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:])
+                     ).view(bs, n_views, *self.empty_image_embeds.shape[-2:])
+
+         if self.cfg.normalize_embeds:  # post-process the visual embeds
+             visual_embeds = visual_embeds / visual_embeds.norm(dim=-1, keepdim=True)
+
+         return visual_embeds
+
+
+ class BaseCaptionEncoder(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         pretrained_model_name_or_path: Optional[str] = (
+             None  # the pretrained model name or path
+         )
+
+         text_max_length: int = 77
+
+         empty_embeds_ratio: float = 0.1  # the ratio of empty embeds
+         normalize_embeds: bool = False  # whether to normalize the embeds
+         zero_uncond_embeds: bool = True
+
+     cfg: Config
+
+     def configure(self) -> None:
+         super().configure()
+
+     def forward(self, batch, force_drop_ids=None):
+         assert "caption" in batch, "caption is required for caption embeds"
+
+         bs = len(batch["label"])
+         if random.random() < self.cfg.empty_embeds_ratio:
+             caption_embeds = self.empty_text_embeds.repeat(bs, 1, 1)
+         else:
+             caption_embeds = self.encode_text(batch["caption"])
+
+         if self.cfg.normalize_embeds:  # post-process the caption embeds
+             caption_embeds = caption_embeds / caption_embeds.norm(dim=-1, keepdim=True)
+
+         return caption_embeds
+
+
+ class BaseLabelEncoder(BaseModule):
+     @dataclass
+     class Config(BaseModule.Config):
+         pretrained_model_name_or_path: Optional[str] = (
+             None  # the pretrained model name or path
+         )
+
+         hidden_size: int = 1024
+
+         empty_embeds_ratio: float = 0.1  # the ratio of empty embeds
+         normalize_embeds: bool = False  # whether to normalize the embeds
+         zero_uncond_embeds: bool = True
+
+     cfg: Config
+
+     def configure(self) -> None:
+         super().configure()
+
+     def forward(self, batch, force_drop_ids=None):
+         assert "label" in batch, "label is required for label embeds"
+
+         bs = len(batch["label"])
+         if random.random() < self.cfg.empty_embeds_ratio:
+             label_embeds = self.empty_label_embeds.repeat(bs, 1, 1)
+         else:
+             label_embeds = self.encode_label(batch["label"])
+
+         if self.cfg.normalize_embeds:  # post-process the label embeds
+             label_embeds = label_embeds / label_embeds.norm(dim=-1, keepdim=True)
+
+         return label_embeds
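For reference, with camera_embeds_type == "sincos" each 4x4 camera-to-world matrix is flattened to 16 values and embedded as [sin, cos], giving 32 dimensions per view; a minimal sketch mirroring encode_camera above:

import torch

c2ws = torch.eye(4).unsqueeze(0)                      # (1, 4, 4) camera-to-world pose
flat = c2ws.view(-1, 16)
camera_embeds = torch.cat([torch.sin(flat), torch.cos(flat)], dim=-1)
assert camera_embeds.shape == (1, 32)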
step1x3d_geometry/models/conditional_encoders/clip/modeling_clip.py ADDED
@@ -0,0 +1,1597 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch CLIP model."""
16
+
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_attn_mask_utils import (
28
+ _create_4d_causal_attention_mask,
29
+ _prepare_4d_attention_mask,
30
+ )
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutput,
33
+ BaseModelOutputWithPooling,
34
+ ImageClassifierOutput,
35
+ )
36
+ from transformers.modeling_utils import PreTrainedModel
37
+ from transformers.utils import (
38
+ ModelOutput,
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ logging,
43
+ replace_return_docstrings,
44
+ )
45
+ from transformers.models.clip.configuration_clip import (
46
+ CLIPConfig,
47
+ CLIPTextConfig,
48
+ CLIPVisionConfig,
49
+ )
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ # General docstring
55
+ _CONFIG_FOR_DOC = "CLIPConfig"
56
+ _CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
57
+
58
+ # Image classification docstring
59
+ _IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
60
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
61
+
62
+
63
+ # contrastive loss function, adapted from
64
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
65
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
66
+ return nn.functional.cross_entropy(
67
+ logits, torch.arange(len(logits), device=logits.device)
68
+ )
69
+
70
+
71
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
72
+ caption_loss = contrastive_loss(similarity)
73
+ image_loss = contrastive_loss(similarity.t())
74
+ return (caption_loss + image_loss) / 2.0
75
+
76
+
77
+ @dataclass
78
+ class CLIPVisionModelOutput(ModelOutput):
79
+ """
80
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
81
+
82
+ Args:
83
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
84
+ The image embeddings obtained by applying the projection layer to the pooler_output.
85
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
86
+ Sequence of hidden-states at the output of the last layer of the model.
87
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90
+
91
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
92
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94
+ sequence_length)`.
95
+
96
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
97
+ heads.
98
+ """
99
+
100
+ image_embeds: Optional[torch.FloatTensor] = None
101
+ last_hidden_state: torch.FloatTensor = None
102
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
103
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
104
+
105
+
106
+ @dataclass
107
+ class CLIPTextModelOutput(ModelOutput):
108
+ """
109
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
110
+
111
+ Args:
112
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
113
+ The text embeddings obtained by applying the projection layer to the pooler_output.
114
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
115
+ Sequence of hidden-states at the output of the last layer of the model.
116
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
117
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
118
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
119
+
120
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
121
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
122
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
123
+ sequence_length)`.
124
+
125
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
126
+ heads.
127
+ """
128
+
129
+ text_embeds: Optional[torch.FloatTensor] = None
130
+ last_hidden_state: torch.FloatTensor = None
131
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
132
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
133
+
134
+
135
+ @dataclass
136
+ class CLIPOutput(ModelOutput):
137
+ """
138
+ Args:
139
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
140
+ Contrastive loss for image-text similarity.
141
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
142
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
143
+ similarity scores.
144
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
145
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
146
+ similarity scores.
147
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
148
+ The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPTextModel`].
149
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
150
+ The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPVisionModel`].
151
+ text_model_output(`BaseModelOutputWithPooling`):
152
+ The output of the [`CLIPTextModel`].
153
+ vision_model_output(`BaseModelOutputWithPooling`):
154
+ The output of the [`CLIPVisionModel`].
155
+ """
156
+
157
+ loss: Optional[torch.FloatTensor] = None
158
+ logits_per_image: torch.FloatTensor = None
159
+ logits_per_text: torch.FloatTensor = None
160
+ text_embeds: torch.FloatTensor = None
161
+ image_embeds: torch.FloatTensor = None
162
+ text_model_output: BaseModelOutputWithPooling = None
163
+ vision_model_output: BaseModelOutputWithPooling = None
164
+
165
+ def to_tuple(self) -> Tuple[Any]:
166
+ return tuple(
167
+ (
168
+ self[k]
169
+ if k not in ["text_model_output", "vision_model_output"]
170
+ else getattr(self, k).to_tuple()
171
+ )
172
+ for k in self.keys()
173
+ )
174
+
175
+
176
+ class CLIPVisionEmbeddings(nn.Module):
177
+ def __init__(self, config: CLIPVisionConfig):
178
+ super().__init__()
179
+ self.config = config
180
+ self.embed_dim = config.hidden_size
181
+ self.image_size = config.image_size
182
+ self.patch_size = config.patch_size
183
+
184
+ self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
185
+
186
+ self.patch_embedding = nn.Conv2d(
187
+ in_channels=config.num_channels,
188
+ out_channels=self.embed_dim,
189
+ kernel_size=self.patch_size,
190
+ stride=self.patch_size,
191
+ bias=False,
192
+ )
193
+
194
+ self.num_patches = (self.image_size // self.patch_size) ** 2
195
+ self.num_positions = self.num_patches + 1
196
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
197
+ self.register_buffer(
198
+ "position_ids",
199
+ torch.arange(self.num_positions).expand((1, -1)),
200
+ persistent=False,
201
+ )
202
+
203
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
204
+ batch_size = pixel_values.shape[0]
205
+ target_dtype = self.patch_embedding.weight.dtype
206
+ patch_embeds = self.patch_embedding(
207
+ pixel_values.to(dtype=target_dtype)
208
+ ) # shape = [*, width, grid, grid]
209
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
210
+
211
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
212
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
213
+ embeddings = embeddings + self.position_embedding(self.position_ids)
214
+ return embeddings
215
+
216
+
217
+ class CLIPTextEmbeddings(nn.Module):
218
+ def __init__(self, config: CLIPTextConfig):
219
+ super().__init__()
220
+ embed_dim = config.hidden_size
221
+
222
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
223
+ self.position_embedding = nn.Embedding(
224
+ config.max_position_embeddings, embed_dim
225
+ )
226
+
227
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
228
+ self.register_buffer(
229
+ "position_ids",
230
+ torch.arange(config.max_position_embeddings).expand((1, -1)),
231
+ persistent=False,
232
+ )
233
+
234
+ def forward(
235
+ self,
236
+ input_ids: Optional[torch.LongTensor] = None,
237
+ position_ids: Optional[torch.LongTensor] = None,
238
+ inputs_embeds: Optional[torch.FloatTensor] = None,
239
+ ) -> torch.Tensor:
240
+ seq_length = (
241
+ input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
242
+ )
243
+
244
+ if position_ids is None:
245
+ position_ids = self.position_ids[:, :seq_length]
246
+
247
+ if inputs_embeds is None:
248
+ inputs_embeds = self.token_embedding(input_ids)
249
+
250
+ position_embeddings = self.position_embedding(position_ids)
251
+ embeddings = inputs_embeds + position_embeddings
252
+
253
+ return embeddings
254
+
255
+
256
+ class CLIPAttention(nn.Module):
257
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
258
+
259
+ def __init__(self, config):
260
+ super().__init__()
261
+ self.config = config
262
+ self.embed_dim = config.hidden_size
263
+ self.num_heads = config.num_attention_heads
264
+ self.head_dim = self.embed_dim // self.num_heads
265
+ if self.head_dim * self.num_heads != self.embed_dim:
266
+ raise ValueError(
267
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
268
+ f" {self.num_heads})."
269
+ )
270
+ self.scale = self.head_dim**-0.5
271
+ self.dropout = config.attention_dropout
272
+
273
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
274
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
275
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
276
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
277
+
278
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
279
+ return (
280
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
281
+ .transpose(1, 2)
282
+ .contiguous()
283
+ )
284
+
285
+ def forward(
286
+ self,
287
+ hidden_states: torch.Tensor,
288
+ attention_mask: Optional[torch.Tensor] = None,
289
+ causal_attention_mask: Optional[torch.Tensor] = None,
290
+ output_attentions: Optional[bool] = False,
291
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
292
+ """Input shape: Batch x Time x Channel"""
293
+
294
+ bsz, tgt_len, embed_dim = hidden_states.size()
295
+
296
+ # get query proj
297
+ query_states = self.q_proj(hidden_states) * self.scale
298
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
299
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
300
+
301
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
302
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
303
+ key_states = key_states.view(*proj_shape)
304
+ value_states = value_states.view(*proj_shape)
305
+
306
+ src_len = key_states.size(1)
307
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
308
+
309
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
310
+ raise ValueError(
311
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
312
+ f" {attn_weights.size()}"
313
+ )
314
+
315
+ # apply the causal_attention_mask first
316
+ if causal_attention_mask is not None:
317
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
318
+ raise ValueError(
319
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
320
+ f" {causal_attention_mask.size()}"
321
+ )
322
+ attn_weights = (
323
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
324
+ + causal_attention_mask
325
+ )
326
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
327
+
328
+ if attention_mask is not None:
329
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
330
+ raise ValueError(
331
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
332
+ )
333
+ attn_weights = (
334
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
335
+ + attention_mask
336
+ )
337
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
338
+
339
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
340
+
341
+ if output_attentions:
342
+ # this operation is a bit akward, but it's required to
343
+ # make sure that attn_weights keeps its gradient.
344
+ # In order to do so, attn_weights have to reshaped
345
+ # twice and have to be reused in the following
346
+ attn_weights_reshaped = attn_weights.view(
347
+ bsz, self.num_heads, tgt_len, src_len
348
+ )
349
+ attn_weights = attn_weights_reshaped.view(
350
+ bsz * self.num_heads, tgt_len, src_len
351
+ )
352
+ else:
353
+ attn_weights_reshaped = None
354
+
355
+ attn_probs = nn.functional.dropout(
356
+ attn_weights, p=self.dropout, training=self.training
357
+ )
358
+
359
+ attn_output = torch.bmm(attn_probs, value_states)
360
+
361
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
362
+ raise ValueError(
363
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
364
+ f" {attn_output.size()}"
365
+ )
366
+
367
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
368
+ attn_output = attn_output.transpose(1, 2)
369
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
370
+
371
+ attn_output = self.out_proj(attn_output)
372
+
373
+ return attn_output, attn_weights_reshaped
374
+
375
+
376
+ class CLIPMLP(nn.Module):
377
+ def __init__(self, config):
378
+ super().__init__()
379
+ self.config = config
380
+ self.activation_fn = ACT2FN[config.hidden_act]
381
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
382
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
383
+
384
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
385
+ hidden_states = self.fc1(hidden_states)
386
+ hidden_states = self.activation_fn(hidden_states)
387
+ hidden_states = self.fc2(hidden_states)
388
+ return hidden_states
389
+
390
+
391
+ class CLIPEncoderLayer(nn.Module):
392
+ def __init__(self, config: CLIPConfig):
393
+ super().__init__()
394
+ self.embed_dim = config.hidden_size
395
+ self.self_attn = CLIPAttention(config)
396
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
397
+ self.mlp = CLIPMLP(config)
398
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states: torch.Tensor,
403
+ attention_mask: torch.Tensor,
404
+ causal_attention_mask: torch.Tensor,
405
+ output_attentions: Optional[bool] = False,
406
+ ) -> Tuple[torch.FloatTensor]:
407
+ """
408
+ Args:
409
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
410
+ attention_mask (`torch.FloatTensor`): attention mask of size
411
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
412
+ `(config.encoder_attention_heads,)`.
413
+ output_attentions (`bool`, *optional*):
414
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
415
+ returned tensors for more detail.
416
+ """
417
+ residual = hidden_states
418
+
419
+ hidden_states = self.layer_norm1(hidden_states)
420
+ hidden_states, attn_weights = self.self_attn(
421
+ hidden_states=hidden_states,
422
+ attention_mask=attention_mask,
423
+ causal_attention_mask=causal_attention_mask,
424
+ output_attentions=output_attentions,
425
+ )
426
+ hidden_states = residual + hidden_states
427
+
428
+ residual = hidden_states
429
+ hidden_states = self.layer_norm2(hidden_states)
430
+ hidden_states = self.mlp(hidden_states)
431
+ hidden_states = residual + hidden_states
432
+
433
+ outputs = (hidden_states,)
434
+
435
+ if output_attentions:
436
+ outputs += (attn_weights,)
437
+
438
+ return outputs
439
+
440
+
441
+ class CLIPPreTrainedModel(PreTrainedModel):
442
+ """
443
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
444
+ models.
445
+ """
446
+
447
+ config_class = CLIPConfig
448
+ base_model_prefix = "clip"
449
+ supports_gradient_checkpointing = True
450
+
451
+ def _init_weights(self, module):
452
+ """Initialize the weights"""
453
+ factor = self.config.initializer_factor
454
+ if isinstance(module, CLIPTextEmbeddings):
455
+ module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
456
+ module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
457
+ elif isinstance(module, CLIPVisionEmbeddings):
458
+ factor = self.config.initializer_factor
459
+ nn.init.normal_(
460
+ module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor
461
+ )
462
+ nn.init.normal_(
463
+ module.patch_embedding.weight,
464
+ std=module.config.initializer_range * factor,
465
+ )
466
+ nn.init.normal_(
467
+ module.position_embedding.weight,
468
+ std=module.config.initializer_range * factor,
469
+ )
470
+ elif isinstance(module, CLIPAttention):
471
+ factor = self.config.initializer_factor
472
+ in_proj_std = (
473
+ (module.embed_dim**-0.5)
474
+ * ((2 * module.config.num_hidden_layers) ** -0.5)
475
+ * factor
476
+ )
477
+ out_proj_std = (module.embed_dim**-0.5) * factor
478
+ nn.init.normal_(module.q_proj.weight, std=in_proj_std)
479
+ nn.init.normal_(module.k_proj.weight, std=in_proj_std)
480
+ nn.init.normal_(module.v_proj.weight, std=in_proj_std)
481
+ nn.init.normal_(module.out_proj.weight, std=out_proj_std)
482
+ elif isinstance(module, CLIPMLP):
483
+ factor = self.config.initializer_factor
484
+ in_proj_std = (
485
+ (module.config.hidden_size**-0.5)
486
+ * ((2 * module.config.num_hidden_layers) ** -0.5)
487
+ * factor
488
+ )
489
+ fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
490
+ nn.init.normal_(module.fc1.weight, std=fc_std)
491
+ nn.init.normal_(module.fc2.weight, std=in_proj_std)
492
+ elif isinstance(module, CLIPModel):
493
+ nn.init.normal_(
494
+ module.text_projection.weight,
495
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
496
+ )
497
+ nn.init.normal_(
498
+ module.visual_projection.weight,
499
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
500
+ )
501
+ elif isinstance(module, CLIPVisionModelWithProjection):
502
+ nn.init.normal_(
503
+ module.visual_projection.weight,
504
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
505
+ )
506
+ elif isinstance(module, CLIPTextModelWithProjection):
507
+ nn.init.normal_(
508
+ module.text_projection.weight,
509
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
510
+ )
511
+ elif isinstance(module, CLIPForImageClassification):
512
+ nn.init.normal_(
513
+ module.classifier.weight,
514
+ std=self.config.vision_config.hidden_size**-0.5
515
+ * self.config.initializer_factor,
516
+ )
517
+
518
+ if isinstance(module, nn.LayerNorm):
519
+ module.bias.data.zero_()
520
+ module.weight.data.fill_(1.0)
521
+ if isinstance(module, nn.Linear) and module.bias is not None:
522
+ module.bias.data.zero_()
523
+
524
+
525
+ CLIP_START_DOCSTRING = r"""
526
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
527
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
528
+ etc.)
529
+
530
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
531
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
532
+ and behavior.
533
+
534
+ Parameters:
535
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
536
+ Initializing with a config file does not load the weights associated with the model, only the
537
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
538
+ """
539
+
540
+ CLIP_TEXT_INPUTS_DOCSTRING = r"""
541
+ Args:
542
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
543
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
544
+ it.
545
+
546
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
547
+ [`PreTrainedTokenizer.__call__`] for details.
548
+
549
+ [What are input IDs?](../glossary#input-ids)
550
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
551
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
552
+
553
+ - 1 for tokens that are **not masked**,
554
+ - 0 for tokens that are **masked**.
555
+
556
+ [What are attention masks?](../glossary#attention-mask)
557
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
558
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
559
+ config.max_position_embeddings - 1]`.
560
+
561
+ [What are position IDs?](../glossary#position-ids)
562
+ output_attentions (`bool`, *optional*):
563
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
564
+ tensors for more detail.
565
+ output_hidden_states (`bool`, *optional*):
566
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
567
+ more detail.
568
+ return_dict (`bool`, *optional*):
569
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
570
+ """
571
+
572
+ CLIP_VISION_INPUTS_DOCSTRING = r"""
573
+ Args:
574
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
575
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
576
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
577
+ output_attentions (`bool`, *optional*):
578
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
579
+ tensors for more detail.
580
+ output_hidden_states (`bool`, *optional*):
581
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
582
+ more detail.
583
+ return_dict (`bool`, *optional*):
584
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
585
+ """
586
+
587
+ CLIP_INPUTS_DOCSTRING = r"""
588
+ Args:
589
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
590
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
591
+ it.
592
+
593
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
594
+ [`PreTrainedTokenizer.__call__`] for details.
595
+
596
+ [What are input IDs?](../glossary#input-ids)
597
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
598
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
599
+
600
+ - 1 for tokens that are **not masked**,
601
+ - 0 for tokens that are **masked**.
602
+
603
+ [What are attention masks?](../glossary#attention-mask)
604
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
605
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
606
+ config.max_position_embeddings - 1]`.
607
+
608
+ [What are position IDs?](../glossary#position-ids)
609
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
610
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
611
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
612
+ return_loss (`bool`, *optional*):
613
+ Whether or not to return the contrastive loss.
614
+ output_attentions (`bool`, *optional*):
615
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
616
+ tensors for more detail.
617
+ output_hidden_states (`bool`, *optional*):
618
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
619
+ more detail.
620
+ return_dict (`bool`, *optional*):
621
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
622
+ """
623
+
624
+
625
+ class CLIPEncoder(nn.Module):
626
+ """
627
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
628
+ [`CLIPEncoderLayer`].
629
+
630
+ Args:
631
+ config: CLIPConfig
632
+ """
633
+
634
+ def __init__(self, config: CLIPConfig):
635
+ super().__init__()
636
+ self.config = config
637
+ self.layers = nn.ModuleList(
638
+ [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
639
+ )
640
+ self.gradient_checkpointing = False
641
+
642
+ def forward(
643
+ self,
644
+ inputs_embeds,
645
+ attention_mask: Optional[torch.Tensor] = None,
646
+ causal_attention_mask: Optional[torch.Tensor] = None,
647
+ output_attentions: Optional[bool] = None,
648
+ output_hidden_states: Optional[bool] = None,
649
+ return_dict: Optional[bool] = None,
650
+ ) -> Union[Tuple, BaseModelOutput]:
651
+ r"""
652
+ Args:
653
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
654
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
655
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
656
+ than the model's internal embedding lookup matrix.
657
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
658
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
659
+
660
+ - 1 for tokens that are **not masked**,
661
+ - 0 for tokens that are **masked**.
662
+
663
+ [What are attention masks?](../glossary#attention-mask)
664
+ causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
665
+ Causal mask for the text model. Mask values selected in `[0, 1]`:
666
+
667
+ - 1 for tokens that are **not masked**,
668
+ - 0 for tokens that are **masked**.
669
+
670
+ [What are attention masks?](../glossary#attention-mask)
671
+ output_attentions (`bool`, *optional*):
672
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
673
+ returned tensors for more detail.
674
+ output_hidden_states (`bool`, *optional*):
675
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
676
+ for more detail.
677
+ return_dict (`bool`, *optional*):
678
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
679
+ """
680
+ output_attentions = (
681
+ output_attentions
682
+ if output_attentions is not None
683
+ else self.config.output_attentions
684
+ )
685
+ output_hidden_states = (
686
+ output_hidden_states
687
+ if output_hidden_states is not None
688
+ else self.config.output_hidden_states
689
+ )
690
+ return_dict = (
691
+ return_dict if return_dict is not None else self.config.use_return_dict
692
+ )
693
+
694
+ encoder_states = () if output_hidden_states else None
695
+ all_attentions = () if output_attentions else None
696
+
697
+ hidden_states = inputs_embeds
698
+ for idx, encoder_layer in enumerate(self.layers):
699
+ if output_hidden_states:
700
+ encoder_states = encoder_states + (hidden_states,)
701
+ if self.gradient_checkpointing and self.training:
702
+ layer_outputs = self._gradient_checkpointing_func(
703
+ encoder_layer.__call__,
704
+ hidden_states,
705
+ attention_mask,
706
+ causal_attention_mask,
707
+ output_attentions,
708
+ )
709
+ else:
710
+ layer_outputs = encoder_layer(
711
+ hidden_states,
712
+ attention_mask,
713
+ causal_attention_mask,
714
+ output_attentions=output_attentions,
715
+ )
716
+
717
+ hidden_states = layer_outputs[0]
718
+
719
+ if output_attentions:
720
+ all_attentions = all_attentions + (layer_outputs[1],)
721
+
722
+ if output_hidden_states:
723
+ encoder_states = encoder_states + (hidden_states,)
724
+
725
+ if not return_dict:
726
+ return tuple(
727
+ v
728
+ for v in [hidden_states, encoder_states, all_attentions]
729
+ if v is not None
730
+ )
731
+ return BaseModelOutput(
732
+ last_hidden_state=hidden_states,
733
+ hidden_states=encoder_states,
734
+ attentions=all_attentions,
735
+ )
736
+
737
+
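A note on the encoder above: `encoder_states` is appended before every layer and once more after the last one, so `output_hidden_states=True` yields `num_hidden_layers + 1` tensors (the embedding output plus one per layer). Below is a minimal sketch of reading an intermediate layer's features; the checkpoint name is only an illustrative example, not something this file prescribes.

```python
# Sketch: inspect the per-layer hidden states returned by the CLIP encoder.
# Assumes the public "openai/clip-vit-base-patch32" checkpoint; any CLIP
# checkpoint exposing the same API would behave identically.
import torch
from transformers import AutoTokenizer, CLIPTextModel

tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

inputs = tokenizer(["a photo of a cat"], return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# num_hidden_layers + 1 entries: the embedding output plus one per layer.
print(len(out.hidden_states), model.config.num_hidden_layers + 1)

penultimate = out.hidden_states[-2]  # often used as a conditioning feature
print(penultimate.shape)             # (batch, seq_len, hidden_size)
```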
738
+ class CLIPTextTransformer(nn.Module):
739
+ def __init__(self, config: CLIPTextConfig):
740
+ super().__init__()
741
+ self.config = config
742
+ embed_dim = config.hidden_size
743
+ self.embeddings = CLIPTextEmbeddings(config)
744
+ self.encoder = CLIPEncoder(config)
745
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
746
+
747
+ # For `pooled_output` computation
748
+ self.eos_token_id = config.eos_token_id
749
+
750
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
751
+ @replace_return_docstrings(
752
+ output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig
753
+ )
754
+ def forward(
755
+ self,
756
+ input_ids: Optional[torch.Tensor] = None,
757
+ attention_mask: Optional[torch.Tensor] = None,
758
+ position_ids: Optional[torch.Tensor] = None,
759
+ output_attentions: Optional[bool] = None,
760
+ output_hidden_states: Optional[bool] = None,
761
+ return_dict: Optional[bool] = None,
762
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
763
+ r"""
764
+ Returns:
765
+
766
+ """
767
+ output_attentions = (
768
+ output_attentions
769
+ if output_attentions is not None
770
+ else self.config.output_attentions
771
+ )
772
+ output_hidden_states = (
773
+ output_hidden_states
774
+ if output_hidden_states is not None
775
+ else self.config.output_hidden_states
776
+ )
777
+ return_dict = (
778
+ return_dict if return_dict is not None else self.config.use_return_dict
779
+ )
780
+
781
+ if input_ids is None:
782
+ raise ValueError("You have to specify input_ids")
783
+
784
+ input_shape = input_ids.size()
785
+ input_ids = input_ids.view(-1, input_shape[-1])
786
+
787
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
788
+
789
+ # CLIP's text model uses causal mask, prepare it here.
790
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
791
+ causal_attention_mask = _create_4d_causal_attention_mask(
792
+ input_shape, hidden_states.dtype, device=hidden_states.device
793
+ )
794
+ # expand attention_mask
795
+ if attention_mask is not None:
796
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
797
+ attention_mask = _prepare_4d_attention_mask(
798
+ attention_mask, hidden_states.dtype
799
+ )
800
+
801
+ encoder_outputs = self.encoder(
802
+ inputs_embeds=hidden_states,
803
+ attention_mask=attention_mask,
804
+ causal_attention_mask=causal_attention_mask,
805
+ output_attentions=output_attentions,
806
+ output_hidden_states=output_hidden_states,
807
+ return_dict=return_dict,
808
+ )
809
+
810
+ last_hidden_state = encoder_outputs[0]
811
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
812
+
813
+ if self.eos_token_id == 2:
814
+ # The `eos_token_id` was incorrect before PR #24773: Let's keep what has been done here.
815
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
816
+ # ------------------------------------------------------------
817
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
818
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
819
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
820
+ pooled_output = last_hidden_state[
821
+ torch.arange(
822
+ last_hidden_state.shape[0], device=last_hidden_state.device
823
+ ),
824
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
825
+ dim=-1
826
+ ),
827
+ ]
828
+ else:
829
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
830
+ pooled_output = last_hidden_state[
831
+ torch.arange(
832
+ last_hidden_state.shape[0], device=last_hidden_state.device
833
+ ),
834
+ # We need to get the first position of the `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
835
+ # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
836
+ (
837
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device)
838
+ == self.eos_token_id
839
+ )
840
+ .int()
841
+ .argmax(dim=-1),
842
+ ]
843
+
844
+ if not return_dict:
845
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
846
+
847
+ return BaseModelOutputWithPooling(
848
+ last_hidden_state=last_hidden_state,
849
+ pooler_output=pooled_output,
850
+ hidden_states=encoder_outputs.hidden_states,
851
+ attentions=encoder_outputs.attentions,
852
+ )
853
+
854
+
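The pooled text embedding above is taken at the end-of-sequence position: legacy configs with `eos_token_id == 2` fall back to a plain `argmax` over the token ids, while newer configs locate the first occurrence of `eos_token_id`. A tensor-only sketch of the two index computations, with made-up token ids:

```python
# Toy illustration of the two EOS-pooling index computations used above.
import torch

eos_token_id = 2
# made-up ids: 2 marks end of sequence; everything after it is padding with 2
input_ids = torch.tensor([[9, 7, 5, 2, 2],
                          [9, 4, 2, 2, 2]])

# current path: index of the *first* eos_token_id per sequence
pool_idx = (input_ids.to(torch.int) == eos_token_id).int().argmax(dim=-1)
print(pool_idx)    # tensor([3, 2])

# legacy path (eos_token_id == 2 configs): plain argmax over the ids, which
# only finds the EOS position if EOS happens to carry the largest token id
legacy_idx = input_ids.to(torch.int).argmax(dim=-1)
print(legacy_idx)  # tensor([0, 0]) -- wrong here, hence the first-match variant
```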
855
+ @add_start_docstrings(
856
+ """The text model from CLIP without any head or projection on top.""",
857
+ CLIP_START_DOCSTRING,
858
+ )
859
+ class CLIPTextModel(CLIPPreTrainedModel):
860
+ config_class = CLIPTextConfig
861
+
862
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
863
+
864
+ def __init__(self, config: CLIPTextConfig):
865
+ super().__init__(config)
866
+ self.text_model = CLIPTextTransformer(config)
867
+ # Initialize weights and apply final processing
868
+ self.post_init()
869
+
870
+ def get_input_embeddings(self) -> nn.Module:
871
+ return self.text_model.embeddings.token_embedding
872
+
873
+ def set_input_embeddings(self, value):
874
+ self.text_model.embeddings.token_embedding = value
875
+
876
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
877
+ @replace_return_docstrings(
878
+ output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig
879
+ )
880
+ def forward(
881
+ self,
882
+ input_ids: Optional[torch.Tensor] = None,
883
+ attention_mask: Optional[torch.Tensor] = None,
884
+ position_ids: Optional[torch.Tensor] = None,
885
+ output_attentions: Optional[bool] = None,
886
+ output_hidden_states: Optional[bool] = None,
887
+ return_dict: Optional[bool] = None,
888
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
889
+ r"""
890
+ Returns:
891
+
892
+ Examples:
893
+
894
+ ```python
895
+ >>> from transformers import AutoTokenizer, CLIPTextModel
896
+
897
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
898
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
899
+
900
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
901
+
902
+ >>> outputs = model(**inputs)
903
+ >>> last_hidden_state = outputs.last_hidden_state
904
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
905
+ ```"""
906
+ return_dict = (
907
+ return_dict if return_dict is not None else self.config.use_return_dict
908
+ )
909
+
910
+ return self.text_model(
911
+ input_ids=input_ids,
912
+ attention_mask=attention_mask,
913
+ position_ids=position_ids,
914
+ output_attentions=output_attentions,
915
+ output_hidden_states=output_hidden_states,
916
+ return_dict=return_dict,
917
+ )
918
+
919
+
920
+ class CLIPVisionTransformer(nn.Module):
921
+ def __init__(self, config: CLIPVisionConfig):
922
+ super().__init__()
923
+ self.config = config
924
+ embed_dim = config.hidden_size
925
+
926
+ self.embeddings = CLIPVisionEmbeddings(config)
927
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
928
+ self.encoder = CLIPEncoder(config)
929
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
930
+
931
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
932
+ @replace_return_docstrings(
933
+ output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig
934
+ )
935
+ def forward(
936
+ self,
937
+ pixel_values: Optional[torch.FloatTensor] = None,
938
+ output_attentions: Optional[bool] = None,
939
+ output_hidden_states: Optional[bool] = None,
940
+ return_dict: Optional[bool] = None,
941
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
942
+ r"""
943
+ Returns:
944
+
945
+ """
946
+ output_attentions = (
947
+ output_attentions
948
+ if output_attentions is not None
949
+ else self.config.output_attentions
950
+ )
951
+ output_hidden_states = (
952
+ output_hidden_states
953
+ if output_hidden_states is not None
954
+ else self.config.output_hidden_states
955
+ )
956
+ return_dict = (
957
+ return_dict if return_dict is not None else self.config.use_return_dict
958
+ )
959
+
960
+ if pixel_values is None:
961
+ raise ValueError("You have to specify pixel_values")
962
+
963
+ hidden_states = self.embeddings(pixel_values)
964
+ hidden_states = self.pre_layrnorm(hidden_states)
965
+
966
+ encoder_outputs = self.encoder(
967
+ inputs_embeds=hidden_states,
968
+ output_attentions=output_attentions,
969
+ output_hidden_states=output_hidden_states,
970
+ return_dict=return_dict,
971
+ )
972
+
973
+ last_hidden_state = encoder_outputs[0]
974
+ pooled_output = last_hidden_state[:, 0, :]
975
+ pooled_output = self.post_layernorm(pooled_output)
976
+
977
+ if not return_dict:
978
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
979
+
980
+ return BaseModelOutputWithPooling(
981
+ last_hidden_state=last_hidden_state,
982
+ pooler_output=pooled_output,
983
+ hidden_states=encoder_outputs.hidden_states,
984
+ attentions=encoder_outputs.attentions,
985
+ )
986
+
987
+
988
+ @add_start_docstrings(
989
+ """The vision model from CLIP without any head or projection on top.""",
990
+ CLIP_START_DOCSTRING,
991
+ )
992
+ class CLIPVisionModel(CLIPPreTrainedModel):
993
+ config_class = CLIPVisionConfig
994
+ main_input_name = "pixel_values"
995
+ _no_split_modules = ["CLIPEncoderLayer"]
996
+
997
+ def __init__(self, config: CLIPVisionConfig):
998
+ super().__init__(config)
999
+ self.vision_model = CLIPVisionTransformer(config)
1000
+ # Initialize weights and apply final processing
1001
+ self.post_init()
1002
+
1003
+ def get_input_embeddings(self) -> nn.Module:
1004
+ return self.vision_model.embeddings.patch_embedding
1005
+
1006
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1007
+ @replace_return_docstrings(
1008
+ output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig
1009
+ )
1010
+ def forward(
1011
+ self,
1012
+ pixel_values: Optional[torch.FloatTensor] = None,
1013
+ output_attentions: Optional[bool] = None,
1014
+ output_hidden_states: Optional[bool] = None,
1015
+ return_dict: Optional[bool] = None,
1016
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1017
+ r"""
1018
+ Returns:
1019
+
1020
+ Examples:
1021
+
1022
+ ```python
1023
+ >>> from PIL import Image
1024
+ >>> import requests
1025
+ >>> from transformers import AutoProcessor, CLIPVisionModel
1026
+
1027
+ >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1028
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1029
+
1030
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1031
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1032
+
1033
+ >>> inputs = processor(images=image, return_tensors="pt")
1034
+
1035
+ >>> outputs = model(**inputs)
1036
+ >>> last_hidden_state = outputs.last_hidden_state
1037
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1038
+ ```"""
1039
+ return_dict = (
1040
+ return_dict if return_dict is not None else self.config.use_return_dict
1041
+ )
1042
+
1043
+ return self.vision_model(
1044
+ pixel_values=pixel_values,
1045
+ output_attentions=output_attentions,
1046
+ output_hidden_states=output_hidden_states,
1047
+ return_dict=return_dict,
1048
+ )
1049
+
1050
+
1051
+ @add_start_docstrings(CLIP_START_DOCSTRING)
1052
+ class CLIPModel(CLIPPreTrainedModel):
1053
+ config_class = CLIPConfig
1054
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
1055
+
1056
+ def __init__(self, config: CLIPConfig):
1057
+ super().__init__(config)
1058
+
1059
+ if not isinstance(config.text_config, CLIPTextConfig):
1060
+ raise ValueError(
1061
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
1062
+ f" {type(config.text_config)}."
1063
+ )
1064
+
1065
+ if not isinstance(config.vision_config, CLIPVisionConfig):
1066
+ raise ValueError(
1067
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1068
+ f" {type(config.vision_config)}."
1069
+ )
1070
+
1071
+ text_config = config.text_config
1072
+ vision_config = config.vision_config
1073
+
1074
+ self.projection_dim = config.projection_dim
1075
+ self.text_embed_dim = text_config.hidden_size
1076
+ self.vision_embed_dim = vision_config.hidden_size
1077
+
1078
+ self.text_model = CLIPTextTransformer(text_config)
1079
+ self.vision_model = CLIPVisionTransformer(vision_config)
1080
+
1081
+ self.visual_projection = nn.Linear(
1082
+ self.vision_embed_dim, self.projection_dim, bias=False
1083
+ )
1084
+ self.text_projection = nn.Linear(
1085
+ self.text_embed_dim, self.projection_dim, bias=False
1086
+ )
1087
+ self.logit_scale = nn.Parameter(
1088
+ torch.tensor(self.config.logit_scale_init_value)
1089
+ )
1090
+
1091
+ # Initialize weights and apply final processing
1092
+ self.post_init()
1093
+
1094
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1095
+ def get_text_features(
1096
+ self,
1097
+ input_ids: Optional[torch.Tensor] = None,
1098
+ attention_mask: Optional[torch.Tensor] = None,
1099
+ position_ids: Optional[torch.Tensor] = None,
1100
+ output_attentions: Optional[bool] = None,
1101
+ output_hidden_states: Optional[bool] = None,
1102
+ return_dict: Optional[bool] = None,
1103
+ ) -> torch.FloatTensor:
1104
+ r"""
1105
+ Returns:
1106
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1107
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
1108
+
1109
+ Examples:
1110
+
1111
+ ```python
1112
+ >>> from transformers import AutoTokenizer, CLIPModel
1113
+
1114
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1115
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1116
+
1117
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1118
+ >>> text_features = model.get_text_features(**inputs)
1119
+ ```"""
1120
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1121
+ output_attentions = (
1122
+ output_attentions
1123
+ if output_attentions is not None
1124
+ else self.config.output_attentions
1125
+ )
1126
+ output_hidden_states = (
1127
+ output_hidden_states
1128
+ if output_hidden_states is not None
1129
+ else self.config.output_hidden_states
1130
+ )
1131
+ return_dict = (
1132
+ return_dict if return_dict is not None else self.config.use_return_dict
1133
+ )
1134
+
1135
+ text_outputs = self.text_model(
1136
+ input_ids=input_ids,
1137
+ attention_mask=attention_mask,
1138
+ position_ids=position_ids,
1139
+ output_attentions=output_attentions,
1140
+ output_hidden_states=output_hidden_states,
1141
+ return_dict=return_dict,
1142
+ )
1143
+
1144
+ pooled_output = text_outputs[1]
1145
+ text_features = self.text_projection(pooled_output)
1146
+
1147
+ return text_features
1148
+
1149
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1150
+ def get_image_features(
1151
+ self,
1152
+ pixel_values: Optional[torch.FloatTensor] = None,
1153
+ output_attentions: Optional[bool] = None,
1154
+ output_hidden_states: Optional[bool] = None,
1155
+ return_dict: Optional[bool] = None,
1156
+ ) -> torch.FloatTensor:
1157
+ r"""
1158
+ Returns:
1159
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1160
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
1161
+
1162
+ Examples:
1163
+
1164
+ ```python
1165
+ >>> from PIL import Image
1166
+ >>> import requests
1167
+ >>> from transformers import AutoProcessor, CLIPModel
1168
+
1169
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1170
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1171
+
1172
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1173
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1174
+
1175
+ >>> inputs = processor(images=image, return_tensors="pt")
1176
+
1177
+ >>> image_features = model.get_image_features(**inputs)
1178
+ ```"""
1179
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1180
+ output_attentions = (
1181
+ output_attentions
1182
+ if output_attentions is not None
1183
+ else self.config.output_attentions
1184
+ )
1185
+ output_hidden_states = (
1186
+ output_hidden_states
1187
+ if output_hidden_states is not None
1188
+ else self.config.output_hidden_states
1189
+ )
1190
+ return_dict = (
1191
+ return_dict if return_dict is not None else self.config.use_return_dict
1192
+ )
1193
+
1194
+ vision_outputs = self.vision_model(
1195
+ pixel_values=pixel_values,
1196
+ output_attentions=output_attentions,
1197
+ output_hidden_states=output_hidden_states,
1198
+ return_dict=return_dict,
1199
+ )
1200
+
1201
+ pooled_output = vision_outputs[1] # pooled_output
1202
+ image_features = self.visual_projection(pooled_output)
1203
+
1204
+ return image_features
1205
+
1206
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1207
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
1208
+ def forward(
1209
+ self,
1210
+ input_ids: Optional[torch.LongTensor] = None,
1211
+ pixel_values: Optional[torch.FloatTensor] = None,
1212
+ attention_mask: Optional[torch.Tensor] = None,
1213
+ position_ids: Optional[torch.LongTensor] = None,
1214
+ return_loss: Optional[bool] = None,
1215
+ output_attentions: Optional[bool] = None,
1216
+ output_hidden_states: Optional[bool] = None,
1217
+ return_dict: Optional[bool] = None,
1218
+ ) -> Union[Tuple, CLIPOutput]:
1219
+ r"""
1220
+ Returns:
1221
+
1222
+ Examples:
1223
+
1224
+ ```python
1225
+ >>> from PIL import Image
1226
+ >>> import requests
1227
+ >>> from transformers import AutoProcessor, CLIPModel
1228
+
1229
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1230
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1231
+
1232
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1233
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1234
+
1235
+ >>> inputs = processor(
1236
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
1237
+ ... )
1238
+
1239
+ >>> outputs = model(**inputs)
1240
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
1241
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
1242
+ ```"""
1243
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1244
+ output_attentions = (
1245
+ output_attentions
1246
+ if output_attentions is not None
1247
+ else self.config.output_attentions
1248
+ )
1249
+ output_hidden_states = (
1250
+ output_hidden_states
1251
+ if output_hidden_states is not None
1252
+ else self.config.output_hidden_states
1253
+ )
1254
+ return_dict = (
1255
+ return_dict if return_dict is not None else self.config.use_return_dict
1256
+ )
1257
+
1258
+ vision_outputs = self.vision_model(
1259
+ pixel_values=pixel_values,
1260
+ output_attentions=output_attentions,
1261
+ output_hidden_states=output_hidden_states,
1262
+ return_dict=return_dict,
1263
+ )
1264
+
1265
+ text_outputs = self.text_model(
1266
+ input_ids=input_ids,
1267
+ attention_mask=attention_mask,
1268
+ position_ids=position_ids,
1269
+ output_attentions=output_attentions,
1270
+ output_hidden_states=output_hidden_states,
1271
+ return_dict=return_dict,
1272
+ )
1273
+
1274
+ image_embeds = vision_outputs[1]
1275
+ image_embeds = self.visual_projection(image_embeds)
1276
+
1277
+ text_embeds = text_outputs[1]
1278
+ text_embeds = self.text_projection(text_embeds)
1279
+
1280
+ # normalized features
1281
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1282
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1283
+
1284
+ # cosine similarity as logits
1285
+ logit_scale = self.logit_scale.exp()
1286
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1287
+ logits_per_image = logits_per_text.t()
1288
+
1289
+ loss = None
1290
+ if return_loss:
1291
+ loss = clip_loss(logits_per_text)
1292
+
1293
+ if not return_dict:
1294
+ output = (
1295
+ logits_per_image,
1296
+ logits_per_text,
1297
+ text_embeds,
1298
+ image_embeds,
1299
+ text_outputs,
1300
+ vision_outputs,
1301
+ )
1302
+ return ((loss,) + output) if loss is not None else output
1303
+
1304
+ return CLIPOutput(
1305
+ loss=loss,
1306
+ logits_per_image=logits_per_image,
1307
+ logits_per_text=logits_per_text,
1308
+ text_embeds=text_embeds,
1309
+ image_embeds=image_embeds,
1310
+ text_model_output=text_outputs,
1311
+ vision_model_output=vision_outputs,
1312
+ )
1313
+
1314
+
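`clip_loss` is presumably defined earlier in this file, outside the excerpt shown here; in the reference CLIP formulation it is the symmetric cross-entropy over `logits_per_text` and its transpose. A hedged sketch of that loss (standard formulation, not copied from this diff):

```python
# Sketch of the symmetric contrastive loss that `clip_loss` is expected to
# compute: cross-entropy against the diagonal, averaged over both directions.
import torch
import torch.nn.functional as F


def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    # the i-th row should score highest at column i (matched text/image pair)
    return F.cross_entropy(logits, torch.arange(len(logits), device=logits.device))


def clip_loss_sketch(logits_per_text: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(logits_per_text)
    image_loss = contrastive_loss(logits_per_text.t())
    return (caption_loss + image_loss) / 2.0


# toy check with random, L2-normalized embeddings
text = F.normalize(torch.randn(4, 8), dim=-1)
image = F.normalize(torch.randn(4, 8), dim=-1)
print(clip_loss_sketch(text @ image.t() * 100.0))
```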
1315
+ @add_start_docstrings(
1316
+ """
1317
+ CLIP Text Model with a projection layer on top (a linear layer on top of the pooled output).
1318
+ """,
1319
+ CLIP_START_DOCSTRING,
1320
+ )
1321
+ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
1322
+ config_class = CLIPTextConfig
1323
+
1324
+ _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]
1325
+
1326
+ def __init__(self, config: CLIPTextConfig):
1327
+ super().__init__(config)
1328
+
1329
+ self.text_model = CLIPTextTransformer(config)
1330
+
1331
+ self.text_projection = nn.Linear(
1332
+ config.hidden_size, config.projection_dim, bias=False
1333
+ )
1334
+
1335
+ # Initialize weights and apply final processing
1336
+ self.post_init()
1337
+
1338
+ def get_input_embeddings(self) -> nn.Module:
1339
+ return self.text_model.embeddings.token_embedding
1340
+
1341
+ def set_input_embeddings(self, value):
1342
+ self.text_model.embeddings.token_embedding = value
1343
+
1344
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1345
+ @replace_return_docstrings(
1346
+ output_type=CLIPTextModelOutput, config_class=CLIPTextConfig
1347
+ )
1348
+ def forward(
1349
+ self,
1350
+ input_ids: Optional[torch.Tensor] = None,
1351
+ attention_mask: Optional[torch.Tensor] = None,
1352
+ position_ids: Optional[torch.Tensor] = None,
1353
+ output_attentions: Optional[bool] = None,
1354
+ output_hidden_states: Optional[bool] = None,
1355
+ return_dict: Optional[bool] = None,
1356
+ ) -> Union[Tuple, CLIPTextModelOutput]:
1357
+ r"""
1358
+ Returns:
1359
+
1360
+ Examples:
1361
+
1362
+ ```python
1363
+ >>> from transformers import AutoTokenizer, CLIPTextModelWithProjection
1364
+
1365
+ >>> model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1366
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1367
+
1368
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1369
+
1370
+ >>> outputs = model(**inputs)
1371
+ >>> text_embeds = outputs.text_embeds
1372
+ ```"""
1373
+ return_dict = (
1374
+ return_dict if return_dict is not None else self.config.use_return_dict
1375
+ )
1376
+
1377
+ text_outputs = self.text_model(
1378
+ input_ids=input_ids,
1379
+ attention_mask=attention_mask,
1380
+ position_ids=position_ids,
1381
+ output_attentions=output_attentions,
1382
+ output_hidden_states=output_hidden_states,
1383
+ return_dict=return_dict,
1384
+ )
1385
+
1386
+ pooled_output = text_outputs[1]
1387
+
1388
+ text_embeds = self.text_projection(pooled_output)
1389
+
1390
+ if not return_dict:
1391
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
1392
+ return tuple(output for output in outputs if output is not None)
1393
+
1394
+ return CLIPTextModelOutput(
1395
+ text_embeds=text_embeds,
1396
+ last_hidden_state=text_outputs.last_hidden_state,
1397
+ hidden_states=text_outputs.hidden_states,
1398
+ attentions=text_outputs.attentions,
1399
+ )
1400
+
1401
+
1402
+ @add_start_docstrings(
1403
+ """
1404
+ CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
1405
+ """,
1406
+ CLIP_START_DOCSTRING,
1407
+ )
1408
+ class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
1409
+ config_class = CLIPVisionConfig
1410
+ main_input_name = "pixel_values"
1411
+
1412
+ def __init__(self, config: CLIPVisionConfig):
1413
+ super().__init__(config)
1414
+
1415
+ self.vision_model = CLIPVisionTransformer(config)
1416
+
1417
+ self.visual_projection = nn.Linear(
1418
+ config.hidden_size, config.projection_dim, bias=False
1419
+ )
1420
+
1421
+ # Initialize weights and apply final processing
1422
+ self.post_init()
1423
+
1424
+ def get_input_embeddings(self) -> nn.Module:
1425
+ return self.vision_model.embeddings.patch_embedding
1426
+
1427
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1428
+ @replace_return_docstrings(
1429
+ output_type=CLIPVisionModelOutput, config_class=CLIPVisionConfig
1430
+ )
1431
+ def forward(
1432
+ self,
1433
+ pixel_values: Optional[torch.FloatTensor] = None,
1434
+ output_attentions: Optional[bool] = None,
1435
+ output_hidden_states: Optional[bool] = None,
1436
+ return_dict: Optional[bool] = None,
1437
+ ) -> Union[Tuple, CLIPVisionModelOutput]:
1438
+ r"""
1439
+ Returns:
1440
+
1441
+ Examples:
1442
+
1443
+ ```python
1444
+ >>> from PIL import Image
1445
+ >>> import requests
1446
+ >>> from transformers import AutoProcessor, CLIPVisionModelWithProjection
1447
+
1448
+ >>> model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
1449
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1450
+
1451
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1452
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1453
+
1454
+ >>> inputs = processor(images=image, return_tensors="pt")
1455
+
1456
+ >>> outputs = model(**inputs)
1457
+ >>> image_embeds = outputs.image_embeds
1458
+ ```"""
1459
+ return_dict = (
1460
+ return_dict if return_dict is not None else self.config.use_return_dict
1461
+ )
1462
+
1463
+ vision_outputs = self.vision_model(
1464
+ pixel_values=pixel_values,
1465
+ output_attentions=output_attentions,
1466
+ output_hidden_states=output_hidden_states,
1467
+ return_dict=return_dict,
1468
+ )
1469
+
1470
+ pooled_output = vision_outputs[1] # pooled_output
1471
+
1472
+ image_embeds = self.visual_projection(pooled_output)
1473
+
1474
+ if not return_dict:
1475
+ outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
1476
+ return tuple(output for output in outputs if output is not None)
1477
+
1478
+ return CLIPVisionModelOutput(
1479
+ image_embeds=image_embeds,
1480
+ last_hidden_state=vision_outputs.last_hidden_state,
1481
+ hidden_states=vision_outputs.hidden_states,
1482
+ attentions=vision_outputs.attentions,
1483
+ )
1484
+
1485
+
1486
+ @add_start_docstrings(
1487
+ """
1488
+ CLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
1489
+ the patch tokens) e.g. for ImageNet.
1490
+ """,
1491
+ CLIP_START_DOCSTRING,
1492
+ )
1493
+ class CLIPForImageClassification(CLIPPreTrainedModel):
1494
+ main_input_name = "pixel_values"
1495
+
1496
+ def __init__(self, config: CLIPConfig) -> None:
1497
+ super().__init__(config)
1498
+
1499
+ self.num_labels = config.num_labels
1500
+ self.vision_model = CLIPVisionTransformer(config.vision_config)
1501
+
1502
+ # Classifier head
1503
+ self.classifier = (
1504
+ nn.Linear(config.vision_config.hidden_size, config.num_labels)
1505
+ if config.num_labels > 0
1506
+ else nn.Identity()
1507
+ )
1508
+
1509
+ # Initialize weights and apply final processing
1510
+ self.post_init()
1511
+
1512
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
1513
+ @add_code_sample_docstrings(
1514
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1515
+ output_type=ImageClassifierOutput,
1516
+ config_class=_CONFIG_FOR_DOC,
1517
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1518
+ )
1519
+ def forward(
1520
+ self,
1521
+ pixel_values: Optional[torch.Tensor] = None,
1522
+ labels: Optional[torch.Tensor] = None,
1523
+ output_attentions: Optional[bool] = None,
1524
+ output_hidden_states: Optional[bool] = None,
1525
+ return_dict: Optional[bool] = None,
1526
+ ) -> Union[tuple, ImageClassifierOutput]:
1527
+ r"""
1528
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1529
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1530
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1531
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1532
+ """
1533
+ output_attentions = (
1534
+ output_attentions
1535
+ if output_attentions is not None
1536
+ else self.config.output_attentions
1537
+ )
1538
+ output_hidden_states = (
1539
+ output_hidden_states
1540
+ if output_hidden_states is not None
1541
+ else self.config.output_hidden_states
1542
+ )
1543
+ return_dict = (
1544
+ return_dict if return_dict is not None else self.config.use_return_dict
1545
+ )
1546
+
1547
+ outputs = self.vision_model(
1548
+ pixel_values,
1549
+ output_attentions=output_attentions,
1550
+ output_hidden_states=output_hidden_states,
1551
+ return_dict=return_dict,
1552
+ )
1553
+
1554
+ sequence_output = outputs[0]
1555
+
1556
+ # average pool the patch tokens
1557
+ sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
1558
+ # apply classifier
1559
+ logits = self.classifier(sequence_output)
1560
+
1561
+ loss = None
1562
+ if labels is not None:
1563
+ # move labels to correct device to enable model parallelism
1564
+ labels = labels.to(logits.device)
1565
+ if self.config.problem_type is None:
1566
+ if self.num_labels == 1:
1567
+ self.config.problem_type = "regression"
1568
+ elif self.num_labels > 1 and (
1569
+ labels.dtype == torch.long or labels.dtype == torch.int
1570
+ ):
1571
+ self.config.problem_type = "single_label_classification"
1572
+ else:
1573
+ self.config.problem_type = "multi_label_classification"
1574
+
1575
+ if self.config.problem_type == "regression":
1576
+ loss_fct = MSELoss()
1577
+ if self.num_labels == 1:
1578
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1579
+ else:
1580
+ loss = loss_fct(logits, labels)
1581
+ elif self.config.problem_type == "single_label_classification":
1582
+ loss_fct = CrossEntropyLoss()
1583
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1584
+ elif self.config.problem_type == "multi_label_classification":
1585
+ loss_fct = BCEWithLogitsLoss()
1586
+ loss = loss_fct(logits, labels)
1587
+
1588
+ if not return_dict:
1589
+ output = (logits,) + outputs[2:]
1590
+ return ((loss,) + output) if loss is not None else output
1591
+
1592
+ return ImageClassifierOutput(
1593
+ loss=loss,
1594
+ logits=logits,
1595
+ hidden_states=outputs.hidden_states,
1596
+ attentions=outputs.attentions,
1597
+ )
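`CLIPForImageClassification` above mean-pools the patch tokens (every position after the CLS token) and feeds the result to a linear classifier, with the loss chosen from `config.problem_type`. A small shape-check sketch of that pooling step, using toy sizes:

```python
# Sketch of the patch-token mean pooling used by CLIPForImageClassification.
import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 50, 768   # 1 CLS token + 49 patch tokens (toy sizes)
num_labels = 10                       # illustrative label count

sequence_output = torch.randn(batch, seq_len, hidden)
classifier = nn.Linear(hidden, num_labels)

pooled = sequence_output[:, 1:, :].mean(dim=1)  # drop CLS, average patch tokens
logits = classifier(pooled)
print(logits.shape)  # torch.Size([2, 10])
```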
step1x3d_geometry/models/conditional_encoders/clip/modeling_conditional_clip.py ADDED
@@ -0,0 +1,443 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Reference:
16
+ # * transformers/models/dinov2/modeling_dinov2.py
17
+ # * https://github.com/facebookresearch/DiT/blob/main/models.py#L101
18
+ # * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2
19
+ """PyTorch CLIP model."""
20
+
21
+ from typing import Dict, List, Optional, Set, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ from .modeling_clip import (
27
+ CLIPConfig,
28
+ CLIPTextConfig,
29
+ CLIPVisionConfig,
30
+ CLIPEncoderLayer,
31
+ CLIPTextTransformer,
32
+ CLIPVisionTransformer,
33
+ CLIPModel,
34
+ CLIPVisionEmbeddings,
35
+ CLIPVisionModel,
36
+ CLIPOutput,
37
+ BaseModelOutput,
38
+ BaseModelOutputWithPooling,
39
+ )
40
+
41
+
42
+ class ModLN(nn.Module):
43
+ def __init__(self, inner_dim: int, mod_dim: int = 32):
44
+ super().__init__()
45
+ self.mlp = nn.Sequential(
46
+ nn.SiLU(),
47
+ nn.Linear(mod_dim, inner_dim * 2),
48
+ )
49
+
50
+ for m in self.modules():
51
+ if isinstance(m, nn.Linear):
52
+ nn.init.zeros_(m.weight)
53
+ nn.init.zeros_(m.bias)
54
+
55
+ def forward(self, x: torch.Tensor, condition: torch.Tensor):
56
+ """
57
+ x: [N, M, C_in], M: num of tokens
58
+ condition: [N, C_mod]
59
+ """
60
+ shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
61
+ return x * (1 + scale) + shift
62
+
63
+
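`ModLN` predicts a per-sample `(shift, scale)` pair from the condition vector and broadcasts it over all tokens; because its linear layer is zero-initialized, the module starts out as an identity. A toy check (sizes invented for illustration):

```python
# Toy check of the ModLN behaviour: zero-init => identity at the start,
# and the (shift, scale) pair broadcasts over the token dimension.
import torch
import torch.nn as nn


class ModLN(nn.Module):
    def __init__(self, inner_dim: int, mod_dim: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(mod_dim, inner_dim * 2))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, condition):
        shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
        return x * (1 + scale) + shift


mod = ModLN(inner_dim=16, mod_dim=32)
x = torch.randn(2, 5, 16)       # [N, M, C_in]
cond = torch.randn(2, 32)       # [N, C_mod]
out = mod(x, cond)
print(torch.allclose(out, x))   # True: zero-initialized modulation is a no-op
```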
64
+ class ConditionalCLIPVisionConfig(CLIPVisionConfig):
65
+ def __init__(self, modulation_dim: int = 32, *args, **kwargs):
66
+ super().__init__(*args, **kwargs)
67
+ self.modulation_dim = modulation_dim
68
+
69
+
70
+ class ConditionalCLIPEncoderLayer(CLIPEncoderLayer):
71
+ """This corresponds to the Block class in the original implementation."""
72
+
73
+ def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
74
+ super().__init__(config)
75
+ self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
76
+ self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)
77
+
78
+ def forward(
79
+ self,
80
+ hidden_states: torch.Tensor,
81
+ attention_mask: torch.Tensor,
82
+ causal_attention_mask: torch.Tensor,
83
+ condition: Optional[torch.Tensor] = None,
84
+ output_attentions: bool = False,
85
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
86
+ residual = hidden_states
87
+
88
+ hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition)
89
+ hidden_states, attn_weights = self.self_attn(
90
+ hidden_states=hidden_states,
91
+ attention_mask=attention_mask,
92
+ causal_attention_mask=causal_attention_mask,
93
+ output_attentions=output_attentions,
94
+ )
95
+ hidden_states = residual + hidden_states
96
+
97
+ residual = hidden_states
98
+ hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition)
99
+ hidden_states = self.mlp(hidden_states)
100
+ hidden_states = residual + hidden_states
101
+
102
+ outputs = (hidden_states,)
103
+
104
+ if output_attentions:
105
+ outputs += (attn_weights,)
106
+
107
+ return outputs
108
+
109
+
110
+ class ConditionalCLIPEncoder(nn.Module):
111
+ def __init__(self, config: CLIPConfig) -> None:
112
+ super().__init__()
113
+ self.config = config
114
+ self.layers = nn.ModuleList(
115
+ [
116
+ ConditionalCLIPEncoderLayer(config)
117
+ for _ in range(config.num_hidden_layers)
118
+ ]
119
+ )
120
+ self.gradient_checkpointing = False
121
+
122
+ def forward(
123
+ self,
124
+ inputs_embeds,
125
+ attention_mask: Optional[torch.Tensor] = None,
126
+ causal_attention_mask: Optional[torch.Tensor] = None,
127
+ output_attentions: Optional[bool] = None,
128
+ output_hidden_states: Optional[bool] = None,
129
+ condition: Optional[torch.Tensor] = None,
130
+ return_dict: Optional[bool] = None,
131
+ ) -> Union[tuple, BaseModelOutput]:
132
+ output_attentions = (
133
+ output_attentions
134
+ if output_attentions is not None
135
+ else self.config.output_attentions
136
+ )
137
+ output_hidden_states = (
138
+ output_hidden_states
139
+ if output_hidden_states is not None
140
+ else self.config.output_hidden_states
141
+ )
142
+ return_dict = (
143
+ return_dict if return_dict is not None else self.config.use_return_dict
144
+ )
145
+
146
+ encoder_states = () if output_hidden_states else None
147
+ all_attentions = () if output_attentions else None
148
+
149
+ hidden_states = inputs_embeds
150
+ for idx, encoder_layer in enumerate(self.layers):
151
+ if output_hidden_states:
152
+ encoder_states = encoder_states + (hidden_states,)
153
+ if self.gradient_checkpointing and self.training:
154
+ layer_outputs = self._gradient_checkpointing_func(
155
+ encoder_layer.__call__,
156
+ hidden_states,
157
+ attention_mask,
158
+ causal_attention_mask,
159
+ condition=condition,
160
+ output_attentions=output_attentions,
161
+ )
162
+ else:
163
+ layer_outputs = encoder_layer(
164
+ hidden_states,
165
+ attention_mask,
166
+ causal_attention_mask,
167
+ condition=condition,
168
+ output_attentions=output_attentions,
169
+ )
170
+
171
+ hidden_states = layer_outputs[0]
172
+
173
+ if output_attentions:
174
+ all_attentions = all_attentions + (layer_outputs[1],)
175
+
176
+ if output_hidden_states:
177
+ encoder_states = encoder_states + (hidden_states,)
178
+
179
+ if not return_dict:
180
+ return tuple(
181
+ v
182
+ for v in [hidden_states, encoder_states, all_attentions]
183
+ if v is not None
184
+ )
185
+ return BaseModelOutput(
186
+ last_hidden_state=hidden_states,
187
+ hidden_states=encoder_states,
188
+ attentions=all_attentions,
189
+ )
190
+
191
+
192
+ class ConditionalCLIPVisionTransformer(CLIPVisionTransformer):
193
+ def __init__(self, config: ConditionalCLIPVisionConfig):
194
+ super().__init__(config)
195
+ self.config = config
196
+ embed_dim = config.hidden_size
197
+
198
+ self.embeddings = CLIPVisionEmbeddings(config)
199
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
200
+ self.encoder = ConditionalCLIPEncoder(config)
201
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
202
+
203
+ def forward(
204
+ self,
205
+ pixel_values: Optional[torch.FloatTensor] = None,
206
+ condition: Optional[torch.Tensor] = None,
207
+ output_attentions: Optional[bool] = None,
208
+ output_hidden_states: Optional[bool] = None,
209
+ return_dict: Optional[bool] = None,
210
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
211
+ output_attentions = (
212
+ output_attentions
213
+ if output_attentions is not None
214
+ else self.config.output_attentions
215
+ )
216
+ output_hidden_states = (
217
+ output_hidden_states
218
+ if output_hidden_states is not None
219
+ else self.config.output_hidden_states
220
+ )
221
+ return_dict = (
222
+ return_dict if return_dict is not None else self.config.use_return_dict
223
+ )
224
+
225
+ if pixel_values is None:
226
+ raise ValueError("You have to specify pixel_values")
227
+
228
+ hidden_states = self.embeddings(pixel_values)
229
+ hidden_states = self.pre_layrnorm(hidden_states)
230
+
231
+ encoder_outputs = self.encoder(
232
+ inputs_embeds=hidden_states,
233
+ output_attentions=output_attentions,
234
+ output_hidden_states=output_hidden_states,
235
+ condition=condition,
236
+ return_dict=return_dict,
237
+ )
238
+
239
+ last_hidden_state = encoder_outputs[0]
240
+ pooled_output = last_hidden_state[:, 0, :]
241
+ pooled_output = self.post_layernorm(pooled_output)
242
+
243
+ if not return_dict:
244
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
245
+
246
+ return BaseModelOutputWithPooling(
247
+ last_hidden_state=last_hidden_state,
248
+ pooler_output=pooled_output,
249
+ hidden_states=encoder_outputs.hidden_states,
250
+ attentions=encoder_outputs.attentions,
251
+ )
252
+
253
+
254
+ class ConditionalCLIPVisionModel(CLIPVisionModel):
255
+ config_class = ConditionalCLIPVisionConfig
256
+
257
+ def __init__(self, config: ConditionalCLIPVisionConfig):
258
+ super().__init__(config)
259
+ self.vision_model = ConditionalCLIPVisionTransformer(config)
260
+ # Initialize weights and apply final processing
261
+ self.post_init()
262
+
263
+ def forward(
264
+ self,
265
+ pixel_values: Optional[torch.FloatTensor] = None,
266
+ condition: Optional[torch.Tensor] = None,
267
+ output_attentions: Optional[bool] = None,
268
+ output_hidden_states: Optional[bool] = None,
269
+ return_dict: Optional[bool] = None,
270
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
271
+ return_dict = (
272
+ return_dict if return_dict is not None else self.config.use_return_dict
273
+ )
274
+
275
+ return self.vision_model(
276
+ pixel_values=pixel_values,
277
+ condition=condition,
278
+ output_attentions=output_attentions,
279
+ output_hidden_states=output_hidden_states,
280
+ return_dict=return_dict,
281
+ )
282
+
283
+
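A hedged usage sketch for the conditional vision tower defined above: build a tiny `ConditionalCLIPVisionConfig` and pass a per-sample `condition` vector alongside `pixel_values`. The import path and every size below are assumptions made for a quick shape check, not values used by this repository.

```python
# Shape-check sketch for ConditionalCLIPVisionModel (toy config, random inputs).
# Assumes this file is importable as `modeling_conditional_clip` (assumption).
import torch
from modeling_conditional_clip import (
    ConditionalCLIPVisionConfig,
    ConditionalCLIPVisionModel,
)

config = ConditionalCLIPVisionConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, image_size=32, patch_size=8,
    modulation_dim=32,                      # size of the condition vector
)
model = ConditionalCLIPVisionModel(config).eval()

pixel_values = torch.randn(2, 3, 32, 32)
condition = torch.randn(2, 32)              # e.g. a pooled conditioning embedding

with torch.no_grad():
    out = model(pixel_values=pixel_values, condition=condition)
print(out.last_hidden_state.shape)          # (2, 1 + (32 // 8) ** 2, 64)
```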
284
+ class ConditionalCLIPModel(CLIPModel):
285
+ config_class = CLIPConfig
286
+
287
+ def __init__(self, config: CLIPConfig):
288
+ super().__init__(config)
289
+
290
+ if not isinstance(config.text_config, CLIPTextConfig):
291
+ raise ValueError(
292
+ "config.text_config is expected to be of type CLIPTextConfig but is of type"
293
+ f" {type(config.text_config)}."
294
+ )
295
+
296
+ if not isinstance(config.vision_config, CLIPVisionConfig):
297
+ raise ValueError(
298
+ "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
299
+ f" {type(config.vision_config)}."
300
+ )
301
+
302
+ text_config = config.text_config
303
+ vision_config = config.vision_config
304
+
305
+ self.projection_dim = config.projection_dim
306
+ self.text_embed_dim = text_config.hidden_size
307
+ self.vision_embed_dim = vision_config.hidden_size
308
+
309
+ self.text_model = CLIPTextTransformer(text_config)
310
+ self.vision_model = ConditionalCLIPVisionTransformer(vision_config)
311
+
312
+ self.visual_projection = nn.Linear(
313
+ self.vision_embed_dim, self.projection_dim, bias=False
314
+ )
315
+ self.text_projection = nn.Linear(
316
+ self.text_embed_dim, self.projection_dim, bias=False
317
+ )
318
+ self.logit_scale = nn.Parameter(
319
+ torch.tensor(self.config.logit_scale_init_value)
320
+ )
321
+
322
+ # Initialize weights and apply final processing
323
+ self.post_init()
324
+
325
+ def get_image_features(
326
+ self,
327
+ pixel_values: Optional[torch.FloatTensor] = None,
328
+ condition: Optional[torch.Tensor] = None,
329
+ output_attentions: Optional[bool] = None,
330
+ output_hidden_states: Optional[bool] = None,
331
+ return_dict: Optional[bool] = None,
332
+ ) -> torch.FloatTensor:
333
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
334
+ output_attentions = (
335
+ output_attentions
336
+ if output_attentions is not None
337
+ else self.config.output_attentions
338
+ )
339
+ output_hidden_states = (
340
+ output_hidden_states
341
+ if output_hidden_states is not None
342
+ else self.config.output_hidden_states
343
+ )
344
+ return_dict = (
345
+ return_dict if return_dict is not None else self.config.use_return_dict
346
+ )
347
+
348
+ vision_outputs = self.vision_model(
349
+ pixel_values=pixel_values,
350
+ condition=condition,
351
+ output_attentions=output_attentions,
352
+ output_hidden_states=output_hidden_states,
353
+ return_dict=return_dict,
354
+ )
355
+
356
+ pooled_output = vision_outputs[1] # pooled_output
357
+ image_features = self.visual_projection(pooled_output)
358
+
359
+ return image_features
360
+
361
+ def forward(
362
+ self,
363
+ input_ids: Optional[torch.LongTensor] = None,
364
+ pixel_values: Optional[torch.FloatTensor] = None,
365
+ condition: Optional[torch.Tensor] = None,
366
+ attention_mask: Optional[torch.Tensor] = None,
367
+ position_ids: Optional[torch.LongTensor] = None,
368
+ return_loss: Optional[bool] = None,
369
+ output_attentions: Optional[bool] = None,
370
+ output_hidden_states: Optional[bool] = None,
371
+ return_dict: Optional[bool] = None,
372
+ ) -> Union[Tuple, CLIPOutput]:
373
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
374
+ output_attentions = (
375
+ output_attentions
376
+ if output_attentions is not None
377
+ else self.config.output_attentions
378
+ )
379
+ output_hidden_states = (
380
+ output_hidden_states
381
+ if output_hidden_states is not None
382
+ else self.config.output_hidden_states
383
+ )
384
+ return_dict = (
385
+ return_dict if return_dict is not None else self.config.use_return_dict
386
+ )
387
+
388
+ vision_outputs = self.vision_model(
389
+ pixel_values=pixel_values,
390
+ condition=condition,
391
+ output_attentions=output_attentions,
392
+ output_hidden_states=output_hidden_states,
393
+ return_dict=return_dict,
394
+ )
395
+
396
+ text_outputs = self.text_model(
397
+ input_ids=input_ids,
398
+ attention_mask=attention_mask,
399
+ position_ids=position_ids,
400
+ output_attentions=output_attentions,
401
+ output_hidden_states=output_hidden_states,
402
+ return_dict=return_dict,
403
+ )
404
+
405
+ image_embeds = vision_outputs[1]
406
+ image_embeds = self.visual_projection(image_embeds)
407
+
408
+ text_embeds = text_outputs[1]
409
+ text_embeds = self.text_projection(text_embeds)
410
+
411
+ # normalized features
412
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
413
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
414
+
415
+ # cosine similarity as logits
416
+ logit_scale = self.logit_scale.exp()
417
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
418
+ logits_per_image = logits_per_text.t()
419
+
420
+ loss = None
421
+ if return_loss:
422
+ loss = clip_loss(logits_per_text)
423
+
424
+ if not return_dict:
425
+ output = (
426
+ logits_per_image,
427
+ logits_per_text,
428
+ text_embeds,
429
+ image_embeds,
430
+ text_outputs,
431
+ vision_outputs,
432
+ )
433
+ return ((loss,) + output) if loss is not None else output
434
+
435
+ return CLIPOutput(
436
+ loss=loss,
437
+ logits_per_image=logits_per_image,
438
+ logits_per_text=logits_per_text,
439
+ text_embeds=text_embeds,
440
+ image_embeds=image_embeds,
441
+ text_model_output=text_outputs,
442
+ vision_model_output=vision_outputs,
443
+ )
step1x3d_geometry/models/conditional_encoders/dinov2/modeling_conditional_dinov2.py ADDED
@@ -0,0 +1,248 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Reference:
16
+ # * transformers/models/dinov2/modeling_dinov2.py
17
+ # * https://github.com/facebookresearch/DiT/blob/main/models.py#L101
18
+ # * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2
19
+ """PyTorch DINOv2 model."""
20
+
21
+ from typing import Dict, List, Optional, Set, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ from .modeling_dinov2 import (
27
+ Dinov2Config,
28
+ Dinov2Layer,
29
+ Dinov2Model,
30
+ Dinov2Embeddings,
31
+ BaseModelOutput,
32
+ BaseModelOutputWithPooling,
33
+ )
34
+
35
+
36
+ class ModLN(nn.Module):
37
+ def __init__(self, inner_dim: int, mod_dim: int = 1024):
38
+ super().__init__()
39
+ self.mlp = nn.Sequential(
40
+ nn.SiLU(),
41
+ nn.Linear(mod_dim, inner_dim * 2),
42
+ )
43
+
44
+ for m in self.modules():
45
+ if isinstance(m, nn.Linear):
46
+ nn.init.zeros_(m.weight)
47
+ nn.init.zeros_(m.bias)
48
+
49
+ def forward(self, x: torch.Tensor, condition: torch.Tensor):
50
+ """
51
+ x: [N, M, C_in], M: num of tokens
52
+ condition: [N, C_mod]
53
+ """
54
+ shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
55
+ return x * (1 + scale) + shift
56
+
57
+
58
+ class ConditionalDinov2Config(Dinov2Config):
59
+ def __init__(self, modulation_dim: int = 1024, *args, **kwargs):
60
+ super().__init__(*args, **kwargs)
61
+ self.modulation_dim = modulation_dim
62
+
63
+
64
+ class ConditionalDinov2Layer(Dinov2Layer):
65
+ """This corresponds to the Block class in the original implementation."""
66
+
67
+ def __init__(self, config: ConditionalDinov2Config) -> None:
68
+ super().__init__(config)
69
+ self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
70
+ self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)
71
+
72
+ def forward(
73
+ self,
74
+ hidden_states: torch.Tensor,
75
+ head_mask: Optional[torch.Tensor] = None,
76
+ condition: Optional[torch.Tensor] = None,
77
+ output_attentions: bool = False,
78
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
79
+ self_attention_outputs = self.attention(
80
+ self.mod_norm1(
81
+ self.norm1(hidden_states), condition
82
+ ), # in Dinov2, layernorm is applied before self-attention
83
+ head_mask,
84
+ output_attentions=output_attentions,
85
+ )
86
+ attention_output = self_attention_outputs[0]
87
+
88
+ attention_output = self.layer_scale1(attention_output)
89
+ outputs = self_attention_outputs[
90
+ 1:
91
+ ] # add self attentions if we output attention weights
92
+
93
+ # first residual connection
94
+ hidden_states = self.drop_path(attention_output) + hidden_states
95
+
96
+ # in Dinov2, layernorm is also applied after self-attention
97
+ layer_output = self.mod_norm2(self.norm2(hidden_states), condition)
98
+ layer_output = self.mlp(layer_output)
99
+ layer_output = self.layer_scale2(layer_output)
100
+
101
+ # second residual connection
102
+ layer_output = self.drop_path(layer_output) + hidden_states
103
+
104
+ outputs = (layer_output,) + outputs
105
+
106
+ return outputs
107
+
108
+
109
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
110
+ class ConditionalDinov2Encoder(nn.Module):
111
+ def __init__(self, config: ConditionalDinov2Config) -> None:
112
+ super().__init__()
113
+ self.config = config
114
+ self.layer = nn.ModuleList(
115
+ [ConditionalDinov2Layer(config) for _ in range(config.num_hidden_layers)]
116
+ )
117
+ self.gradient_checkpointing = False
118
+
119
+ def forward(
120
+ self,
121
+ hidden_states: torch.Tensor,
122
+ head_mask: Optional[torch.Tensor] = None,
123
+ output_attentions: bool = False,
124
+ output_hidden_states: bool = False,
125
+ condition: Optional[torch.Tensor] = None,
126
+ return_dict: bool = True,
127
+ ) -> Union[tuple, BaseModelOutput]:
128
+ all_hidden_states = () if output_hidden_states else None
129
+ all_self_attentions = () if output_attentions else None
130
+
131
+ for i, layer_module in enumerate(self.layer):
132
+ if output_hidden_states:
133
+ all_hidden_states = all_hidden_states + (hidden_states,)
134
+
135
+ layer_head_mask = head_mask[i] if head_mask is not None else None
136
+
137
+ if self.gradient_checkpointing and self.training:
138
+ layer_outputs = self._gradient_checkpointing_func(
139
+ layer_module.__call__,
140
+ hidden_states,
141
+ layer_head_mask,
142
+ condition,
143
+ output_attentions,
144
+ )
145
+ else:
146
+ layer_outputs = layer_module(
147
+ hidden_states,
148
+ layer_head_mask,
149
+ condition,
150
+ output_attentions,
151
+ )
152
+
153
+ hidden_states = layer_outputs[0]
154
+
155
+ if output_attentions:
156
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
157
+
158
+ if output_hidden_states:
159
+ all_hidden_states = all_hidden_states + (hidden_states,)
160
+
161
+ if not return_dict:
162
+ return tuple(
163
+ v
164
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
165
+ if v is not None
166
+ )
167
+ return BaseModelOutput(
168
+ last_hidden_state=hidden_states,
169
+ hidden_states=all_hidden_states,
170
+ attentions=all_self_attentions,
171
+ )
172
+
173
+
174
+ class ConditionalDinov2Model(Dinov2Model):
175
+ config_class = ConditionalDinov2Config
176
+
177
+ def __init__(self, config: ConditionalDinov2Config):
178
+ super().__init__(config)
179
+ self.config = config
180
+
181
+ self.embeddings = Dinov2Embeddings(config)
182
+ self.encoder = ConditionalDinov2Encoder(config)
183
+
184
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
185
+
186
+ # Initialize weights and apply final processing
187
+ self.post_init()
188
+
189
+ def forward(
190
+ self,
191
+ pixel_values: Optional[torch.Tensor] = None,
192
+ bool_masked_pos: Optional[torch.Tensor] = None,
193
+ head_mask: Optional[torch.Tensor] = None,
194
+ condition: Optional[torch.Tensor] = None,
195
+ output_attentions: Optional[bool] = None,
196
+ output_hidden_states: Optional[bool] = None,
197
+ return_dict: Optional[bool] = None,
198
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
199
+ output_attentions = (
200
+ output_attentions
201
+ if output_attentions is not None
202
+ else self.config.output_attentions
203
+ )
204
+ output_hidden_states = (
205
+ output_hidden_states
206
+ if output_hidden_states is not None
207
+ else self.config.output_hidden_states
208
+ )
209
+ return_dict = (
210
+ return_dict if return_dict is not None else self.config.use_return_dict
211
+ )
212
+
213
+ if pixel_values is None:
214
+ raise ValueError("You have to specify pixel_values")
215
+
216
+ # Prepare head mask if needed
217
+ # 1.0 in head_mask indicate we keep the head
218
+ # attention_probs has shape bsz x n_heads x N x N
219
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
220
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
221
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
222
+
223
+ embedding_output = self.embeddings(
224
+ pixel_values, bool_masked_pos=bool_masked_pos
225
+ )
226
+
227
+ encoder_outputs = self.encoder(
228
+ embedding_output,
229
+ head_mask=head_mask,
230
+ output_attentions=output_attentions,
231
+ output_hidden_states=output_hidden_states,
232
+ condition=condition,
233
+ return_dict=return_dict,
234
+ )
235
+ sequence_output = encoder_outputs[0]
236
+ sequence_output = self.layernorm(sequence_output)
237
+ pooled_output = sequence_output[:, 0, :]
238
+
239
+ if not return_dict:
240
+ head_outputs = (sequence_output, pooled_output)
241
+ return head_outputs + encoder_outputs[1:]
242
+
243
+ return BaseModelOutputWithPooling(
244
+ last_hidden_state=sequence_output,
245
+ pooler_output=pooled_output,
246
+ hidden_states=encoder_outputs.hidden_states,
247
+ attentions=encoder_outputs.attentions,
248
+ )
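
The `ConditionalDinov2Model` above mirrors the vanilla DINOv2 encoder but threads an extra `condition` tensor (camera embeddings in this repo) through every encoder layer. Below is a minimal usage sketch, not taken from the repo: the checkpoint name and modulation width are illustrative, and it assumes the conditional config can be populated from a standard DINOv2 checkpoint, which is how `dinov2_clip_encoder.py` builds it further down.

```python
import torch
from step1x3d_geometry.models.conditional_encoders.dinov2.modeling_conditional_dinov2 import (
    ConditionalDinov2Model,
)

# Load a plain DINOv2 config, attach the modulation width expected by the
# conditional layers, then build the model (names and sizes are illustrative).
config = ConditionalDinov2Model.config_class.from_pretrained("facebook/dinov2-base")
config.modulation_dim = 32  # hypothetical camera-embedding dimension
model = ConditionalDinov2Model.from_pretrained("facebook/dinov2-base", config=config)

pixel_values = torch.randn(2, 3, 224, 224)   # (batch, channels, height, width)
condition = torch.randn(2, 32)               # one modulation vector per image
outputs = model(pixel_values=pixel_values, condition=condition)
print(outputs.last_hidden_state.shape)       # (2, 257, 768) for a base-sized model
```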
step1x3d_geometry/models/conditional_encoders/dinov2/modeling_dinov2.py ADDED
@@ -0,0 +1,978 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DINOv2 model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BackboneOutput,
30
+ BaseModelOutput,
31
+ BaseModelOutputWithPooling,
32
+ ImageClassifierOutput,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.pytorch_utils import (
36
+ find_pruneable_heads_and_indices,
37
+ prune_linear_layer,
38
+ )
39
+ from transformers.utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ logging,
44
+ replace_return_docstrings,
45
+ )
46
+ from transformers.utils.backbone_utils import BackboneMixin
47
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ # General docstring
53
+ _CONFIG_FOR_DOC = "Dinov2Config"
54
+
55
+ # Base docstring
56
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
57
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
58
+
59
+ # Image classification docstring
60
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
61
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
62
+
63
+
64
+ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
65
+ "facebook/dinov2-base",
66
+ # See all DINOv2 models at https://huggingface.co/models?filter=dinov2
67
+ ]
68
+
69
+
70
+ class Dinov2Embeddings(nn.Module):
71
+ """
72
+ Construct the CLS token, mask token, position and patch embeddings.
73
+ """
74
+
75
+ def __init__(self, config: Dinov2Config) -> None:
76
+ super().__init__()
77
+
78
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
79
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
80
+ self.patch_embeddings = Dinov2PatchEmbeddings(config)
81
+ num_patches = self.patch_embeddings.num_patches
82
+ self.position_embeddings = nn.Parameter(
83
+ torch.randn(1, num_patches + 1, config.hidden_size)
84
+ )
85
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
86
+ self.config = config
87
+
88
+ def interpolate_pos_encoding(
89
+ self, embeddings: torch.Tensor, height: int, width: int
90
+ ) -> torch.Tensor:
91
+ """
92
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
93
+ resolution images.
94
+
95
+ Source:
96
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
97
+ """
98
+
99
+ num_patches = embeddings.shape[1] - 1
100
+ num_positions = self.position_embeddings.shape[1] - 1
101
+ if num_patches == num_positions and height == width:
102
+ return self.position_embeddings
103
+ class_pos_embed = self.position_embeddings[:, 0]
104
+ patch_pos_embed = self.position_embeddings[:, 1:]
105
+ dim = embeddings.shape[-1]
106
+ height = height // self.config.patch_size
107
+ width = width // self.config.patch_size
108
+ # we add a small number to avoid floating point error in the interpolation
109
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
110
+ height, width = height + 0.1, width + 0.1
111
+ patch_pos_embed = patch_pos_embed.reshape(
112
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
113
+ )
114
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
115
+ target_dtype = patch_pos_embed.dtype
116
+ patch_pos_embed = nn.functional.interpolate(
117
+ patch_pos_embed.to(dtype=torch.float32),
118
+ scale_factor=(
119
+ float(height / math.sqrt(num_positions)),
120
+ float(width / math.sqrt(num_positions)),
121
+ ),
122
+ mode="bicubic",
123
+ align_corners=False,
124
+ ).to(dtype=target_dtype)
125
+ if (
126
+ int(height) != patch_pos_embed.shape[-2]
127
+ or int(width) != patch_pos_embed.shape[-1]
128
+ ):
129
+ raise ValueError(
130
+ "Width or height does not match with the interpolated position embeddings"
131
+ )
132
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
133
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
134
+
135
+ def forward(
136
+ self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None
137
+ ) -> torch.Tensor:
138
+ batch_size, _, height, width = pixel_values.shape
139
+ target_dtype = self.patch_embeddings.projection.weight.dtype
140
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
141
+
142
+ if bool_masked_pos is not None:
143
+ embeddings = torch.where(
144
+ bool_masked_pos.unsqueeze(-1),
145
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
146
+ embeddings,
147
+ )
148
+
149
+ # add the [CLS] token to the embedded patch tokens
150
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
151
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
152
+
153
+ # add positional encoding to each token
154
+ embeddings = embeddings + self.interpolate_pos_encoding(
155
+ embeddings, height, width
156
+ )
157
+
158
+ embeddings = self.dropout(embeddings)
159
+
160
+ return embeddings
161
+
162
+
163
+ class Dinov2PatchEmbeddings(nn.Module):
164
+ """
165
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
166
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
167
+ Transformer.
168
+ """
169
+
170
+ def __init__(self, config):
171
+ super().__init__()
172
+ image_size, patch_size = config.image_size, config.patch_size
173
+ num_channels, hidden_size = config.num_channels, config.hidden_size
174
+
175
+ image_size = (
176
+ image_size
177
+ if isinstance(image_size, collections.abc.Iterable)
178
+ else (image_size, image_size)
179
+ )
180
+ patch_size = (
181
+ patch_size
182
+ if isinstance(patch_size, collections.abc.Iterable)
183
+ else (patch_size, patch_size)
184
+ )
185
+ num_patches = (image_size[1] // patch_size[1]) * (
186
+ image_size[0] // patch_size[0]
187
+ )
188
+ self.image_size = image_size
189
+ self.patch_size = patch_size
190
+ self.num_channels = num_channels
191
+ self.num_patches = num_patches
192
+
193
+ self.projection = nn.Conv2d(
194
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
195
+ )
196
+
197
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
198
+ num_channels = pixel_values.shape[1]
199
+ if num_channels != self.num_channels:
200
+ raise ValueError(
201
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
202
+ f" Expected {self.num_channels} but got {num_channels}."
203
+ )
204
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
205
+ return embeddings
206
+
207
+
208
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Dinov2
209
+ class Dinov2SelfAttention(nn.Module):
210
+ def __init__(self, config: Dinov2Config) -> None:
211
+ super().__init__()
212
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
213
+ config, "embedding_size"
214
+ ):
215
+ raise ValueError(
216
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
217
+ f"heads {config.num_attention_heads}."
218
+ )
219
+
220
+ self.num_attention_heads = config.num_attention_heads
221
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
222
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
223
+
224
+ self.query = nn.Linear(
225
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
226
+ )
227
+ self.key = nn.Linear(
228
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
229
+ )
230
+ self.value = nn.Linear(
231
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
232
+ )
233
+
234
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
235
+
236
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
237
+ new_x_shape = x.size()[:-1] + (
238
+ self.num_attention_heads,
239
+ self.attention_head_size,
240
+ )
241
+ x = x.view(new_x_shape)
242
+ return x.permute(0, 2, 1, 3)
243
+
244
+ def forward(
245
+ self,
246
+ hidden_states,
247
+ head_mask: Optional[torch.Tensor] = None,
248
+ output_attentions: bool = False,
249
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
250
+ mixed_query_layer = self.query(hidden_states)
251
+
252
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
253
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
254
+ query_layer = self.transpose_for_scores(mixed_query_layer)
255
+
256
+ # Take the dot product between "query" and "key" to get the raw attention scores.
257
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
258
+
259
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
260
+
261
+ # Normalize the attention scores to probabilities.
262
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
263
+
264
+ # This is actually dropping out entire tokens to attend to, which might
265
+ # seem a bit unusual, but is taken from the original Transformer paper.
266
+ attention_probs = self.dropout(attention_probs)
267
+
268
+ # Mask heads if we want to
269
+ if head_mask is not None:
270
+ attention_probs = attention_probs * head_mask
271
+
272
+ context_layer = torch.matmul(attention_probs, value_layer)
273
+
274
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
275
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
276
+ context_layer = context_layer.view(new_context_layer_shape)
277
+
278
+ outputs = (
279
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
280
+ )
281
+
282
+ return outputs
283
+
284
+
285
+ # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Dinov2
286
+ class Dinov2SelfOutput(nn.Module):
287
+ """
288
+ The residual connection is defined in Dinov2Layer instead of here (as is the case with other models), due to the
289
+ layernorm applied before each block.
290
+ """
291
+
292
+ def __init__(self, config: Dinov2Config) -> None:
293
+ super().__init__()
294
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
295
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
296
+
297
+ def forward(
298
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
299
+ ) -> torch.Tensor:
300
+ hidden_states = self.dense(hidden_states)
301
+ hidden_states = self.dropout(hidden_states)
302
+
303
+ return hidden_states
304
+
305
+
306
+ # Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Dinov2
307
+ class Dinov2Attention(nn.Module):
308
+ def __init__(self, config: Dinov2Config) -> None:
309
+ super().__init__()
310
+ self.attention = Dinov2SelfAttention(config)
311
+ self.output = Dinov2SelfOutput(config)
312
+ self.pruned_heads = set()
313
+
314
+ def prune_heads(self, heads: Set[int]) -> None:
315
+ if len(heads) == 0:
316
+ return
317
+ heads, index = find_pruneable_heads_and_indices(
318
+ heads,
319
+ self.attention.num_attention_heads,
320
+ self.attention.attention_head_size,
321
+ self.pruned_heads,
322
+ )
323
+
324
+ # Prune linear layers
325
+ self.attention.query = prune_linear_layer(self.attention.query, index)
326
+ self.attention.key = prune_linear_layer(self.attention.key, index)
327
+ self.attention.value = prune_linear_layer(self.attention.value, index)
328
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
329
+
330
+ # Update hyper params and store pruned heads
331
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
332
+ heads
333
+ )
334
+ self.attention.all_head_size = (
335
+ self.attention.attention_head_size * self.attention.num_attention_heads
336
+ )
337
+ self.pruned_heads = self.pruned_heads.union(heads)
338
+
339
+ def forward(
340
+ self,
341
+ hidden_states: torch.Tensor,
342
+ head_mask: Optional[torch.Tensor] = None,
343
+ output_attentions: bool = False,
344
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
345
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
346
+
347
+ attention_output = self.output(self_outputs[0], hidden_states)
348
+
349
+ outputs = (attention_output,) + self_outputs[
350
+ 1:
351
+ ] # add attentions if we output them
352
+ return outputs
353
+
354
+
355
+ class Dinov2LayerScale(nn.Module):
356
+ def __init__(self, config) -> None:
357
+ super().__init__()
358
+ self.lambda1 = nn.Parameter(
359
+ config.layerscale_value * torch.ones(config.hidden_size)
360
+ )
361
+
362
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
363
+ return hidden_state * self.lambda1
364
+
365
+
366
+ # Copied from transformers.models.beit.modeling_beit.drop_path
367
+ def drop_path(
368
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
369
+ ) -> torch.Tensor:
370
+ """
371
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
372
+
373
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
374
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
375
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
376
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
377
+ argument.
378
+ """
379
+ if drop_prob == 0.0 or not training:
380
+ return input
381
+ keep_prob = 1 - drop_prob
382
+ shape = (input.shape[0],) + (1,) * (
383
+ input.ndim - 1
384
+ ) # work with diff dim tensors, not just 2D ConvNets
385
+ random_tensor = keep_prob + torch.rand(
386
+ shape, dtype=input.dtype, device=input.device
387
+ )
388
+ random_tensor.floor_() # binarize
389
+ output = input.div(keep_prob) * random_tensor
390
+ return output
391
+
392
+
393
+ # Copied from transformers.models.beit.modeling_beit.BeitDropPath
394
+ class Dinov2DropPath(nn.Module):
395
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
396
+
397
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
398
+ super().__init__()
399
+ self.drop_prob = drop_prob
400
+
401
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
402
+ return drop_path(hidden_states, self.drop_prob, self.training)
403
+
404
+ def extra_repr(self) -> str:
405
+ return "p={}".format(self.drop_prob)
406
+
407
+
408
+ class Dinov2MLP(nn.Module):
409
+ def __init__(self, config) -> None:
410
+ super().__init__()
411
+ in_features = out_features = config.hidden_size
412
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
413
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
414
+ if isinstance(config.hidden_act, str):
415
+ self.activation = ACT2FN[config.hidden_act]
416
+ else:
417
+ self.activation = config.hidden_act
418
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
419
+
420
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
421
+ hidden_state = self.fc1(hidden_state)
422
+ hidden_state = self.activation(hidden_state)
423
+ hidden_state = self.fc2(hidden_state)
424
+ return hidden_state
425
+
426
+
427
+ class Dinov2SwiGLUFFN(nn.Module):
428
+ def __init__(self, config) -> None:
429
+ super().__init__()
430
+ in_features = out_features = config.hidden_size
431
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
432
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
433
+
434
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
435
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
436
+
437
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
438
+ hidden_state = self.weights_in(hidden_state)
439
+ x1, x2 = hidden_state.chunk(2, dim=-1)
440
+ hidden = nn.functional.silu(x1) * x2
441
+ return self.weights_out(hidden)
442
+
443
+
444
+ class Dinov2Layer(nn.Module):
445
+ """This corresponds to the Block class in the original implementation."""
446
+
447
+ def __init__(self, config: Dinov2Config) -> None:
448
+ super().__init__()
449
+
450
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
451
+ self.attention = Dinov2Attention(config)
452
+ self.layer_scale1 = Dinov2LayerScale(config)
453
+ self.drop_path = (
454
+ Dinov2DropPath(config.drop_path_rate)
455
+ if config.drop_path_rate > 0.0
456
+ else nn.Identity()
457
+ )
458
+
459
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
460
+
461
+ if config.use_swiglu_ffn:
462
+ self.mlp = Dinov2SwiGLUFFN(config)
463
+ else:
464
+ self.mlp = Dinov2MLP(config)
465
+ self.layer_scale2 = Dinov2LayerScale(config)
466
+
467
+ def forward(
468
+ self,
469
+ hidden_states: torch.Tensor,
470
+ head_mask: Optional[torch.Tensor] = None,
471
+ output_attentions: bool = False,
472
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
473
+ self_attention_outputs = self.attention(
474
+ self.norm1(
475
+ hidden_states
476
+ ), # in Dinov2, layernorm is applied before self-attention
477
+ head_mask,
478
+ output_attentions=output_attentions,
479
+ )
480
+ attention_output = self_attention_outputs[0]
481
+
482
+ attention_output = self.layer_scale1(attention_output)
483
+ outputs = self_attention_outputs[
484
+ 1:
485
+ ] # add self attentions if we output attention weights
486
+
487
+ # first residual connection
488
+ hidden_states = self.drop_path(attention_output) + hidden_states
489
+
490
+ # in Dinov2, layernorm is also applied after self-attention
491
+ layer_output = self.norm2(hidden_states)
492
+ layer_output = self.mlp(layer_output)
493
+ layer_output = self.layer_scale2(layer_output)
494
+
495
+ # second residual connection
496
+ layer_output = self.drop_path(layer_output) + hidden_states
497
+
498
+ outputs = (layer_output,) + outputs
499
+
500
+ return outputs
501
+
502
+
503
+ # Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->Dinov2
504
+ class Dinov2Encoder(nn.Module):
505
+ def __init__(self, config: Dinov2Config) -> None:
506
+ super().__init__()
507
+ self.config = config
508
+ self.layer = nn.ModuleList(
509
+ [Dinov2Layer(config) for _ in range(config.num_hidden_layers)]
510
+ )
511
+ self.gradient_checkpointing = False
512
+
513
+ def forward(
514
+ self,
515
+ hidden_states: torch.Tensor,
516
+ head_mask: Optional[torch.Tensor] = None,
517
+ output_attentions: bool = False,
518
+ output_hidden_states: bool = False,
519
+ return_dict: bool = True,
520
+ ) -> Union[tuple, BaseModelOutput]:
521
+ all_hidden_states = () if output_hidden_states else None
522
+ all_self_attentions = () if output_attentions else None
523
+
524
+ for i, layer_module in enumerate(self.layer):
525
+ if output_hidden_states:
526
+ all_hidden_states = all_hidden_states + (hidden_states,)
527
+
528
+ layer_head_mask = head_mask[i] if head_mask is not None else None
529
+
530
+ if self.gradient_checkpointing and self.training:
531
+ layer_outputs = self._gradient_checkpointing_func(
532
+ layer_module.__call__,
533
+ hidden_states,
534
+ layer_head_mask,
535
+ output_attentions,
536
+ )
537
+ else:
538
+ layer_outputs = layer_module(
539
+ hidden_states, layer_head_mask, output_attentions
540
+ )
541
+
542
+ hidden_states = layer_outputs[0]
543
+
544
+ if output_attentions:
545
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
546
+
547
+ if output_hidden_states:
548
+ all_hidden_states = all_hidden_states + (hidden_states,)
549
+
550
+ if not return_dict:
551
+ return tuple(
552
+ v
553
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
554
+ if v is not None
555
+ )
556
+ return BaseModelOutput(
557
+ last_hidden_state=hidden_states,
558
+ hidden_states=all_hidden_states,
559
+ attentions=all_self_attentions,
560
+ )
561
+
562
+
563
+ class Dinov2PreTrainedModel(PreTrainedModel):
564
+ """
565
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
566
+ models.
567
+ """
568
+
569
+ config_class = Dinov2Config
570
+ base_model_prefix = "dinov2"
571
+ main_input_name = "pixel_values"
572
+ supports_gradient_checkpointing = True
573
+
574
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
575
+ """Initialize the weights"""
576
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
577
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
578
+ # `trunc_normal_cpu` not implemented in `half` issues
579
+ module.weight.data = nn.init.trunc_normal_(
580
+ module.weight.data.to(torch.float32),
581
+ mean=0.0,
582
+ std=self.config.initializer_range,
583
+ ).to(module.weight.dtype)
584
+ if module.bias is not None:
585
+ module.bias.data.zero_()
586
+ elif isinstance(module, nn.LayerNorm):
587
+ module.bias.data.zero_()
588
+ module.weight.data.fill_(1.0)
589
+ elif isinstance(module, Dinov2Embeddings):
590
+ module.position_embeddings.data = nn.init.trunc_normal_(
591
+ module.position_embeddings.data.to(torch.float32),
592
+ mean=0.0,
593
+ std=self.config.initializer_range,
594
+ ).to(module.position_embeddings.dtype)
595
+
596
+ module.cls_token.data = nn.init.trunc_normal_(
597
+ module.cls_token.data.to(torch.float32),
598
+ mean=0.0,
599
+ std=self.config.initializer_range,
600
+ ).to(module.cls_token.dtype)
601
+
602
+
603
+ DINOV2_START_DOCSTRING = r"""
604
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
605
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
606
+ behavior.
607
+
608
+ Parameters:
609
+ config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
610
+ Initializing with a config file does not load the weights associated with the model, only the
611
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
612
+ """
613
+
614
+ DINOV2_BASE_INPUTS_DOCSTRING = r"""
615
+ Args:
616
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
617
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
618
+ [`BitImageProcessor.preprocess`] for details.
619
+
620
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
621
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
622
+ pre-training.
623
+
624
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
625
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
626
+
627
+ - 1 indicates the head is **not masked**,
628
+ - 0 indicates the head is **masked**.
629
+
630
+ output_attentions (`bool`, *optional*):
631
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
632
+ tensors for more detail.
633
+ output_hidden_states (`bool`, *optional*):
634
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
635
+ more detail.
636
+ return_dict (`bool`, *optional*):
637
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
638
+ """
639
+
640
+ DINOV2_INPUTS_DOCSTRING = r"""
641
+ Args:
642
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
643
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
644
+ [`BitImageProcessor.preprocess`] for details.
645
+
646
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
647
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
648
+
649
+ - 1 indicates the head is **not masked**,
650
+ - 0 indicates the head is **masked**.
651
+
652
+ output_attentions (`bool`, *optional*):
653
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
654
+ tensors for more detail.
655
+ output_hidden_states (`bool`, *optional*):
656
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
657
+ more detail.
658
+ return_dict (`bool`, *optional*):
659
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
660
+ """
661
+
662
+
663
+ @add_start_docstrings(
664
+ "The bare DINOv2 Model transformer outputting raw hidden-states without any specific head on top.",
665
+ DINOV2_START_DOCSTRING,
666
+ )
667
+ class Dinov2Model(Dinov2PreTrainedModel):
668
+ def __init__(self, config: Dinov2Config):
669
+ super().__init__(config)
670
+ self.config = config
671
+
672
+ self.embeddings = Dinov2Embeddings(config)
673
+ self.encoder = Dinov2Encoder(config)
674
+
675
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
676
+
677
+ # Initialize weights and apply final processing
678
+ self.post_init()
679
+
680
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
681
+ return self.embeddings.patch_embeddings
682
+
683
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
684
+ """
685
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
686
+ class PreTrainedModel
687
+ """
688
+ for layer, heads in heads_to_prune.items():
689
+ self.encoder.layer[layer].attention.prune_heads(heads)
690
+
691
+ @add_start_docstrings_to_model_forward(DINOV2_BASE_INPUTS_DOCSTRING)
692
+ @add_code_sample_docstrings(
693
+ checkpoint=_CHECKPOINT_FOR_DOC,
694
+ output_type=BaseModelOutputWithPooling,
695
+ config_class=_CONFIG_FOR_DOC,
696
+ modality="vision",
697
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
698
+ )
699
+ def forward(
700
+ self,
701
+ pixel_values: Optional[torch.Tensor] = None,
702
+ bool_masked_pos: Optional[torch.Tensor] = None,
703
+ head_mask: Optional[torch.Tensor] = None,
704
+ output_attentions: Optional[bool] = None,
705
+ output_hidden_states: Optional[bool] = None,
706
+ return_dict: Optional[bool] = None,
707
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
708
+ output_attentions = (
709
+ output_attentions
710
+ if output_attentions is not None
711
+ else self.config.output_attentions
712
+ )
713
+ output_hidden_states = (
714
+ output_hidden_states
715
+ if output_hidden_states is not None
716
+ else self.config.output_hidden_states
717
+ )
718
+ return_dict = (
719
+ return_dict if return_dict is not None else self.config.use_return_dict
720
+ )
721
+
722
+ if pixel_values is None:
723
+ raise ValueError("You have to specify pixel_values")
724
+
725
+ # Prepare head mask if needed
726
+ # 1.0 in head_mask indicate we keep the head
727
+ # attention_probs has shape bsz x n_heads x N x N
728
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
729
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
730
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
731
+
732
+ embedding_output = self.embeddings(
733
+ pixel_values, bool_masked_pos=bool_masked_pos
734
+ )
735
+
736
+ encoder_outputs = self.encoder(
737
+ embedding_output,
738
+ head_mask=head_mask,
739
+ output_attentions=output_attentions,
740
+ output_hidden_states=output_hidden_states,
741
+ return_dict=return_dict,
742
+ )
743
+ sequence_output = encoder_outputs[0]
744
+ sequence_output = self.layernorm(sequence_output)
745
+ pooled_output = sequence_output[:, 0, :]
746
+
747
+ if not return_dict:
748
+ head_outputs = (sequence_output, pooled_output)
749
+ return head_outputs + encoder_outputs[1:]
750
+
751
+ return BaseModelOutputWithPooling(
752
+ last_hidden_state=sequence_output,
753
+ pooler_output=pooled_output,
754
+ hidden_states=encoder_outputs.hidden_states,
755
+ attentions=encoder_outputs.attentions,
756
+ )
757
+
758
+
759
+ @add_start_docstrings(
760
+ """
761
+ Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
762
+ of the [CLS] token) e.g. for ImageNet.
763
+ """,
764
+ DINOV2_START_DOCSTRING,
765
+ )
766
+ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
767
+ def __init__(self, config: Dinov2Config) -> None:
768
+ super().__init__(config)
769
+
770
+ self.num_labels = config.num_labels
771
+ self.dinov2 = Dinov2Model(config)
772
+
773
+ # Classifier head
774
+ self.classifier = (
775
+ nn.Linear(config.hidden_size * 2, config.num_labels)
776
+ if config.num_labels > 0
777
+ else nn.Identity()
778
+ )
779
+
780
+ # Initialize weights and apply final processing
781
+ self.post_init()
782
+
783
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
784
+ @add_code_sample_docstrings(
785
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
786
+ output_type=ImageClassifierOutput,
787
+ config_class=_CONFIG_FOR_DOC,
788
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
789
+ )
790
+ def forward(
791
+ self,
792
+ pixel_values: Optional[torch.Tensor] = None,
793
+ head_mask: Optional[torch.Tensor] = None,
794
+ labels: Optional[torch.Tensor] = None,
795
+ output_attentions: Optional[bool] = None,
796
+ output_hidden_states: Optional[bool] = None,
797
+ return_dict: Optional[bool] = None,
798
+ ) -> Union[tuple, ImageClassifierOutput]:
799
+ r"""
800
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
801
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
802
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
803
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
804
+ """
805
+ return_dict = (
806
+ return_dict if return_dict is not None else self.config.use_return_dict
807
+ )
808
+
809
+ outputs = self.dinov2(
810
+ pixel_values,
811
+ head_mask=head_mask,
812
+ output_attentions=output_attentions,
813
+ output_hidden_states=output_hidden_states,
814
+ return_dict=return_dict,
815
+ )
816
+
817
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
818
+
819
+ cls_token = sequence_output[:, 0]
820
+ patch_tokens = sequence_output[:, 1:]
821
+
822
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
823
+
824
+ logits = self.classifier(linear_input)
825
+
826
+ loss = None
827
+ if labels is not None:
828
+ # move labels to correct device to enable model parallelism
829
+ labels = labels.to(logits.device)
830
+ if self.config.problem_type is None:
831
+ if self.num_labels == 1:
832
+ self.config.problem_type = "regression"
833
+ elif self.num_labels > 1 and (
834
+ labels.dtype == torch.long or labels.dtype == torch.int
835
+ ):
836
+ self.config.problem_type = "single_label_classification"
837
+ else:
838
+ self.config.problem_type = "multi_label_classification"
839
+
840
+ if self.config.problem_type == "regression":
841
+ loss_fct = MSELoss()
842
+ if self.num_labels == 1:
843
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
844
+ else:
845
+ loss = loss_fct(logits, labels)
846
+ elif self.config.problem_type == "single_label_classification":
847
+ loss_fct = CrossEntropyLoss()
848
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
849
+ elif self.config.problem_type == "multi_label_classification":
850
+ loss_fct = BCEWithLogitsLoss()
851
+ loss = loss_fct(logits, labels)
852
+
853
+ if not return_dict:
854
+ output = (logits,) + outputs[2:]
855
+ return ((loss,) + output) if loss is not None else output
856
+
857
+ return ImageClassifierOutput(
858
+ loss=loss,
859
+ logits=logits,
860
+ hidden_states=outputs.hidden_states,
861
+ attentions=outputs.attentions,
862
+ )
863
+
864
+
865
+ @add_start_docstrings(
866
+ """
867
+ Dinov2 backbone, to be used with frameworks like DETR and MaskFormer.
868
+ """,
869
+ DINOV2_START_DOCSTRING,
870
+ )
871
+ class Dinov2Backbone(Dinov2PreTrainedModel, BackboneMixin):
872
+ def __init__(self, config):
873
+ super().__init__(config)
874
+ super()._init_backbone(config)
875
+
876
+ self.num_features = [
877
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
878
+ ]
879
+ self.embeddings = Dinov2Embeddings(config)
880
+ self.encoder = Dinov2Encoder(config)
881
+
882
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
883
+
884
+ # Initialize weights and apply final processing
885
+ self.post_init()
886
+
887
+ def get_input_embeddings(self) -> Dinov2PatchEmbeddings:
888
+ return self.embeddings.patch_embeddings
889
+
890
+ @add_start_docstrings_to_model_forward(DINOV2_INPUTS_DOCSTRING)
891
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
892
+ def forward(
893
+ self,
894
+ pixel_values: torch.Tensor,
895
+ output_hidden_states: Optional[bool] = None,
896
+ output_attentions: Optional[bool] = None,
897
+ return_dict: Optional[bool] = None,
898
+ ) -> BackboneOutput:
899
+ """
900
+ Returns:
901
+
902
+ Examples:
903
+
904
+ ```python
905
+ >>> from transformers import AutoImageProcessor, AutoBackbone
906
+ >>> import torch
907
+ >>> from PIL import Image
908
+ >>> import requests
909
+
910
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
911
+ >>> image = Image.open(requests.get(url, stream=True).raw)
912
+
913
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
914
+ >>> model = AutoBackbone.from_pretrained(
915
+ ... "facebook/dinov2-base", out_features=["stage2", "stage5", "stage8", "stage11"]
916
+ ... )
917
+
918
+ >>> inputs = processor(image, return_tensors="pt")
919
+
920
+ >>> outputs = model(**inputs)
921
+ >>> feature_maps = outputs.feature_maps
922
+ >>> list(feature_maps[-1].shape)
923
+ [1, 768, 16, 16]
924
+ ```"""
925
+ return_dict = (
926
+ return_dict if return_dict is not None else self.config.use_return_dict
927
+ )
928
+ output_hidden_states = (
929
+ output_hidden_states
930
+ if output_hidden_states is not None
931
+ else self.config.output_hidden_states
932
+ )
933
+ output_attentions = (
934
+ output_attentions
935
+ if output_attentions is not None
936
+ else self.config.output_attentions
937
+ )
938
+
939
+ embedding_output = self.embeddings(pixel_values)
940
+
941
+ outputs = self.encoder(
942
+ embedding_output,
943
+ output_hidden_states=True,
944
+ output_attentions=output_attentions,
945
+ return_dict=return_dict,
946
+ )
947
+
948
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
949
+
950
+ feature_maps = ()
951
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
952
+ if stage in self.out_features:
953
+ if self.config.apply_layernorm:
954
+ hidden_state = self.layernorm(hidden_state)
955
+ if self.config.reshape_hidden_states:
956
+ hidden_state = hidden_state[:, 1:]
957
+ # this was actually a bug in the original implementation that we copied here,
958
+ # cause normally the order is height, width
959
+ batch_size, _, height, width = pixel_values.shape
960
+ patch_size = self.config.patch_size
961
+ hidden_state = hidden_state.reshape(
962
+ batch_size, height // patch_size, width // patch_size, -1
963
+ )
964
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
965
+ feature_maps += (hidden_state,)
966
+
967
+ if not return_dict:
968
+ if output_hidden_states:
969
+ output = (feature_maps,) + outputs[1:]
970
+ else:
971
+ output = (feature_maps,) + outputs[2:]
972
+ return output
973
+
974
+ return BackboneOutput(
975
+ feature_maps=feature_maps,
976
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
977
+ attentions=outputs.attentions if output_attentions else None,
978
+ )
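
This file is a vendored copy of the Hugging Face DINOv2 implementation; the encoders in this package consume its `last_hidden_state`, which carries one [CLS] token followed by `(H / patch_size) * (W / patch_size)` patch tokens. A minimal sketch of that token layout, assuming a base-sized config and a 224x224 input (the config download is illustrative; the downstream encoders instantiate the model the same way):

```python
import torch
from step1x3d_geometry.models.conditional_encoders.dinov2.modeling_dinov2 import Dinov2Model

# Build the model from a base-sized config (random weights; enough to check shapes).
config = Dinov2Model.config_class.from_pretrained("facebook/dinov2-base")
model = Dinov2Model(config)
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out = model(pixel_values=pixel_values)

cls_token = out.last_hidden_state[:, 0]      # (1, 768), also exposed as pooler_output
patch_tokens = out.last_hidden_state[:, 1:]  # (1, 256, 768): (224 // 14) ** 2 patches
```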
step1x3d_geometry/models/conditional_encoders/dinov2_clip_encoder.py ADDED
@@ -0,0 +1,514 @@
1
+ import random
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ import re
6
+ from einops import rearrange
7
+ from dataclasses import dataclass
8
+ from torchvision import transforms
9
+
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from transformers import CLIPTokenizer, CLIPImageProcessor
12
+ from transformers import AutoImageProcessor, AutoModel
13
+ from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer
14
+ from transformers.utils import ModelOutput
15
+ from typing import Iterable, Optional, Union, List
16
+
17
+ import step1x3d_geometry
18
+ from step1x3d_geometry.utils.typing import *
19
+ from .clip.modeling_clip import CLIPModel
20
+ from .clip.modeling_conditional_clip import ConditionalCLIPModel
21
+ from .base import BaseVisualEncoder, ImageType
22
+ from .dinov2.modeling_dinov2 import Dinov2Model
23
+ from .dinov2.modeling_conditional_dinov2 import ConditionalDinov2Model
24
+ from .dinov2_with_registers.modeling_dinov2_with_registers import (
25
+ Dinov2WithRegistersModel,
26
+ )
27
+
28
+ CLIP_IMAGE_SIZE = 224
29
+
30
+
31
+ @dataclass
32
+ class CLIPEmbedOutput(ModelOutput):
33
+ last_hidden_state: torch.FloatTensor = None
34
+ pooler_output: torch.FloatTensor = None
35
+ embeds: torch.FloatTensor = None
36
+
37
+
38
+ class DINOEmbedOutput(ModelOutput):
39
+ last_hidden_state: torch.FloatTensor = None
40
+ pooler_output: torch.FloatTensor = None
41
+
42
+
43
+ @step1x3d_geometry.register("dinov2-clip-encoder")
44
+ class Dinov2CLIPEncoder(BaseVisualEncoder, ModelMixin):
45
+
46
+ @dataclass
47
+ class Config(BaseVisualEncoder.Config):
48
+ pretrained_model_name_or_path: Optional[str] = (
49
+ None # the pretrained model name or path for condition model
50
+ )
51
+ pretrained_clip_name_or_path: Optional[str] = (
52
+ None # the pretrained model name or path for clip
53
+ )
54
+ pretrained_dino_name_or_path: Optional[str] = (
55
+ None # the pretrained model name or path for dino
56
+ )
57
+ pretrained_linear_proj: Optional[str] = None
58
+ freeze_modulation_clip: bool = False
59
+ freeze_modulation_dino: bool = False
60
+ enable_gradient_checkpointing: bool = False
61
+ image_size: int = CLIP_IMAGE_SIZE
62
+ fuse_type: str = "concat"
63
+
64
+ dino_type: Optional[str] = None
65
+ clip_type: Optional[str] = None
66
+ kwargs: Optional[dict] = None
67
+
68
+ cfg: Config
69
+
70
+ def configure(self) -> None:
71
+ super().configure()
72
+
73
+ # Load the CLIP model and processor
74
+ if not self.cfg.encode_camera:
75
+ if self.cfg.pretrained_clip_name_or_path is not None:
76
+ self.cfg.clip_type = f"openai/{self.cfg.pretrained_clip_name_or_path.split('openai--')[-1].split('/')[0]}"
77
+ self.clip_model: CLIPModel = CLIPModel.from_pretrained(
78
+ self.cfg.pretrained_clip_name_or_path
79
+ )
80
+ else:
81
+ print("Loading CLIP model from openai/clip-vit-large-patch14")
82
+ self.cfg.clip_type = "openai/clip-vit-large-patch14"
83
+ self.clip_model: CLIPModel = CLIPModel(
84
+ config=ConditionalCLIPModel.config_class.from_pretrained(
85
+ "openai/clip-vit-large-patch14",
86
+ )
87
+ )
88
+ if self.cfg.pretrained_dino_name_or_path is not None:
89
+ self.cfg.dino_type = f"facebook/{self.cfg.pretrained_dino_name_or_path.split('facebook--')[-1].split('/')[0]}"
90
+ self.dino_model: Dinov2Model = AutoModel.from_pretrained(
91
+ self.cfg.pretrained_dino_name_or_path
92
+ )
93
+ else:
94
+ if (
95
+ self.cfg.pretrained_model_name_or_path is None
96
+ ): # default to load Dinov2-base model
97
+ assert (
98
+ self.cfg.dino_type is not None
99
+ ), "The dino_type should be provided"
100
+ print(f"Loading Dinov2 model from {self.cfg.dino_type}")
101
+ if "reg" in self.cfg.dino_type:
102
+ self.dino_model: Dinov2WithRegistersModel = (
103
+ Dinov2WithRegistersModel(
104
+ config=Dinov2WithRegistersModel.config_class.from_pretrained(
105
+ self.cfg.dino_type,
106
+ )
107
+ )
108
+ )
109
+ else:
110
+ self.dino_model: Dinov2Model = Dinov2Model(
111
+ config=Dinov2Model.config_class.from_pretrained(
112
+ self.cfg.dino_type,
113
+ )
114
+ )
115
+ elif "dinov2base" in self.cfg.pretrained_model_name_or_path:
116
+ print("Loading Dinov2 model from facebook/dinov2-base")
117
+ self.cfg.dino_type = "facebook/dinov2-base"
118
+ self.dino_model: Dinov2Model = Dinov2Model(
119
+ config=Dinov2Model.config_class.from_pretrained(
120
+ "facebook/dinov2-base",
121
+ )
122
+ )
123
+ elif "dinov2regbase" in self.cfg.pretrained_model_name_or_path:
124
+ print(
125
+ "Loading Dinov2 model from facebook/dinov2-with-registers-base"
126
+ )
127
+ self.cfg.dino_type = "facebook/dinov2-with-registers-base"
128
+ self.dino_model: Dinov2WithRegistersModel = (
129
+ Dinov2WithRegistersModel(
130
+ config=Dinov2WithRegistersModel.config_class.from_pretrained(
131
+ "facebook/dinov2-with-registers-base",
132
+ )
133
+ )
134
+ )
135
+ elif "dinov2reglarge" in self.cfg.pretrained_model_name_or_path:
136
+ print(
137
+ "Loading Dinov2 model from facebook/dinov2-with-registers-large"
138
+ )
139
+ self.cfg.dino_type = "facebook/dinov2-with-registers-large"
140
+ self.dino_model: Dinov2WithRegistersModel = (
141
+ Dinov2WithRegistersModel(
142
+ config=Dinov2WithRegistersModel.config_class.from_pretrained(
143
+ "facebook/dinov2-with-registers-large",
144
+ )
145
+ )
146
+ )
147
+ else:
148
+ raise ValueError(
149
+ f"Unknown Dinov2 model: {self.cfg.pretrained_model_name_or_path}"
150
+ )
151
+ else:
152
+ # clip
153
+ conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained(
154
+ self.cfg.pretrained_clip_name_or_path,
155
+ )
156
+ conditional_clip_config.vision_config.modulation_dim = (
157
+ self.cfg.camera_embeds_dim
158
+ )
159
+ self.clip_model: CLIPModel = ConditionalCLIPModel.from_pretrained(
160
+ self.cfg.pretrained_clip_name_or_path,
161
+ vision_config=conditional_clip_config.vision_config,
162
+ )
163
+
164
+ # dino
165
+ conditional_vit_config = (
166
+ ConditionalDinov2Model.config_class.from_pretrained(
167
+ self.cfg.pretrained_dino_name_or_path,
168
+ )
169
+ )
170
+ conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim
171
+ self.dino_model: ConditionalDinov2Model = (
172
+ ConditionalDinov2Model.from_pretrained(
173
+ self.cfg.pretrained_dino_name_or_path, config=conditional_vit_config
174
+ )
175
+ )
176
+
177
+ self.image_preprocess_clip = CLIPImageProcessor()
178
+ self.image_preprocess_dino = AutoImageProcessor.from_pretrained(
179
+ self.cfg.dino_type
180
+ if self.cfg.pretrained_dino_name_or_path is None
181
+ else self.cfg.pretrained_dino_name_or_path
182
+ )
183
+ self.transform_clip = transforms.Compose(
184
+ [
185
+ transforms.Resize(
186
+ CLIP_IMAGE_SIZE,
187
+ transforms.InterpolationMode.BICUBIC,
188
+ antialias=True,
189
+ ), # clip is CLIP_IMAGE_SIZE
190
+ transforms.CenterCrop(CLIP_IMAGE_SIZE), # crop a square.
191
+ transforms.Normalize(
192
+ mean=[0.48145466, 0.4578275, 0.40821073],
193
+ std=[0.26862954, 0.26130258, 0.27577711],
194
+ ),
195
+ ]
196
+ )
197
+ self.transform_dino = transforms.Compose(
198
+ [
199
+ transforms.Resize(
200
+ self.cfg.image_size,
201
+ transforms.InterpolationMode.BICUBIC,
202
+ antialias=True,
203
+ ),
204
+ transforms.CenterCrop(self.cfg.image_size), # crop a square
205
+ transforms.Normalize(
206
+ mean=[0.485, 0.456, 0.406],
207
+ std=[0.229, 0.224, 0.225],
208
+ ),
209
+ ]
210
+ )
211
+
212
+ if self.cfg.enable_gradient_checkpointing:
213
+ self.dino_model.encoder.gradient_checkpointing = True
214
+
215
+ if self.cfg.zero_uncond_embeds:
216
+ image_size = self.cfg.image_size
217
+ self.empty_image_embeds_dino = torch.zeros(
218
+ (self.cfg.n_views, (image_size // 14) ** 2 + 1, 1024)
219
+ ).detach()
220
+ self.empty_image_embeds_clip = torch.zeros(
221
+ (self.cfg.n_views, (CLIP_IMAGE_SIZE // 14) ** 2 + 1, 1024)
222
+ ).detach()
223
+ if self.cfg.fuse_type == "concat":
224
+ self.empty_image_embeds = torch.cat(
225
+ [self.empty_image_embeds_dino, self.empty_image_embeds_clip], dim=1
226
+ )
227
+ else:
228
+ raise ValueError(f"Unsupported fuse_type: {self.cfg.fuse_type}")
229
+ else:
230
+ if self.cfg.encode_camera:
231
+ self.empty_image_embeds_dino = self.encode_image_dino(
232
+ torch.zeros(
233
+ self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
234
+ ),
235
+ self.cameras[: self.cfg.n_views],
236
+ ).detach()
237
+ self.empty_image_embeds_clip = self.encode_image_clip(
238
+ torch.zeros(
239
+ self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
240
+ ),
241
+ self.cameras[: self.cfg.n_views],
242
+ ).detach()
243
+ else:
244
+ self.empty_image_embeds_dino = self.encode_image_dino(
245
+ torch.zeros(
246
+ self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
247
+ )
248
+ ).detach()
249
+ self.empty_image_embeds_clip = self.encode_image_clip(
250
+ torch.zeros(
251
+ self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
252
+ )
253
+ ).detach()
254
+ self.empty_image_embeds_clip, self.empty_image_embeds_dino = (
255
+ self.align_clip_dino(
256
+ self.empty_image_embeds_clip, self.empty_image_embeds_dino
257
+ )
258
+ )
259
+ self.empty_image_embeds = torch.cat(
260
+ [self.empty_image_embeds_dino, self.empty_image_embeds_clip], dim=1
261
+ )
262
+
263
+ # Freeze the clip model parameters
264
+ self.clip_model.eval()
265
+ for k, p in self.clip_model.named_parameters():
266
+ ks = k.split(".")
267
+ if (
268
+ ("mod_norm1" in ks
269
+ or "mod_norm2" in ks)
270
+ and not self.cfg.freeze_modulation_clip
271
+ ):
272
+ p.requires_grad_(not self.cfg.freeze_modulation_clip)
273
+ else:
274
+ p.requires_grad_(False)
275
+
276
+ # freeze the dino model parameters
277
+ self.dino_model.eval()
278
+ for k, p in self.dino_model.named_parameters():
279
+ ks = k.split(".")
280
+ if (
281
+ ("mod_norm1" in ks
282
+ or "mod_norm2" in ks)
283
+ and not self.cfg.freeze_modulation_dino
284
+ ):
285
+ p.requires_grad_(not self.cfg.freeze_modulation_dino)
286
+ else:
287
+ p.requires_grad_(False)
288
+
289
+ # add a linear projection layer to project the dino embeddings to the same dimension as clip embeddings
290
+ if (
291
+ self.clip_model.config.vision_config.hidden_size
292
+ != self.dino_model.config.hidden_size
293
+ ):
294
+ self.linear_proj = nn.Linear(
295
+ self.clip_model.config.vision_config.hidden_size,
296
+ self.dino_model.config.hidden_size,
297
+ bias=False,
298
+ )
299
+ else:
300
+ self.linear_proj = nn.Identity()
301
+
302
+ if self.cfg.pretrained_model_name_or_path is not None:
303
+ print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
304
+ ckpt = torch.load(
305
+ self.cfg.pretrained_model_name_or_path, map_location="cpu"
306
+ )["state_dict"]
307
+ pretrained_model_ckpt = {}
308
+ for k, v in ckpt.items():
309
+ if k.startswith("condition."):
310
+ pretrained_model_ckpt[k.replace("condition.", "")] = v
311
+ self.load_state_dict(pretrained_model_ckpt, strict=True)
312
+
313
+ def encode_image_clip(
314
+ self,
315
+ images: Iterable[Optional[ImageType]],
316
+ cameras: Optional[torch.Tensor] = None,
317
+ force_none_camera_embeds: bool = False,
318
+ return_dict: bool = False,
319
+ **kwargs,
320
+ ) -> torch.FloatTensor:
321
+ camera_embeds = None
322
+ if isinstance(images, (np.ndarray, torch.Tensor)): # for training process
323
+ assert (
324
+ images.min() >= 0.0 and images.max() <= 1.0
325
+ ), "The pixel values should be in the range of [0, 1]"
326
+ if self.cfg.encode_camera:
327
+ assert cameras is not None, "The cameras should be provided"
328
+ camera_embeds = self.encode_camera(cameras)
329
+ pixel_values = self.transform_clip(images.permute(0, 3, 1, 2))
330
+ else: # for inference process
331
+ if self.cfg.encode_camera:
332
+ if cameras is None:
333
+ bs = len(images) // self.cfg.n_views
334
+ cameras = (
335
+ self.cameras[: self.cfg.n_views]
336
+ .repeat(bs, 1, 1)
337
+ .to(self.clip_model.device)
338
+ )
339
+ camera_embeds = self.encode_camera(cameras)
340
+ pixel_values = self.image_preprocess_clip.preprocess(
341
+ images,
342
+ return_tensors="pt",
343
+ do_rescale=True,
344
+ do_resize=True,
345
+ size=CLIP_IMAGE_SIZE,
346
+ crop_size=CLIP_IMAGE_SIZE,
347
+ ).pixel_values
348
+
349
+ if force_none_camera_embeds:
350
+ camera_embeds = None
351
+
352
+ if pixel_values.ndim == 4:
353
+ pixel_values = pixel_values.unsqueeze(1)
354
+ if camera_embeds is not None:
355
+ camera_embeds = camera_embeds.unsqueeze(1)
356
+
357
+ if self.cfg.encode_camera and camera_embeds is not None:
358
+ vision_outputs = self.clip_model.vision_model(
359
+ pixel_values=rearrange(
360
+ pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"
361
+ ),
362
+ condition=rearrange(camera_embeds, "B N C -> (B N) C"),
363
+ )
364
+
365
+ else:
366
+ vision_outputs = self.clip_model.vision_model(
367
+ pixel_values=rearrange(
368
+ pixel_values.to(self.clip_model.device), "B N C H W -> (B N) C H W"
369
+ ),
370
+ )
371
+
372
+ if return_dict:
373
+ # clip
374
+ pooler_output = vision_outputs[1] # pooled_output
375
+ image_features = self.clip_model.visual_projection(pooler_output)
376
+ clip_embeds = vision_outputs.last_hidden_state
377
+
378
+ clip_embeds_dict = CLIPEmbedOutput(
379
+ last_hidden_state=clip_embeds,
380
+ pooler_output=pooler_output,
381
+ embeds=image_features,
382
+ )
383
+
384
+ return clip_embeds_dict
385
+ else:
386
+ return vision_outputs.last_hidden_state
387
+
388
+ def encode_image_dino(
389
+ self,
390
+ images: Iterable[Optional[ImageType]],
391
+ cameras: Optional[torch.Tensor] = None,
392
+ force_none_camera_embeds: bool = False,
393
+ return_dict: bool = False,
394
+ **kwargs,
395
+ ) -> torch.FloatTensor:
396
+ camera_embeds = None
397
+ if isinstance(images, (np.ndarray, torch.Tensor)): # for training process
398
+ assert (
399
+ images.min() >= 0.0 and images.max() <= 1.0
400
+ ), "The pixel values should be in the range of [0, 1]"
401
+ if self.cfg.encode_camera:
402
+ assert cameras is not None, "The cameras should be provided"
403
+ camera_embeds = self.encode_camera(cameras)
404
+ pixel_values = self.transform_dino(images.permute(0, 3, 1, 2))
405
+ else: # for inference process
406
+ if self.cfg.encode_camera:
407
+ if cameras is None:
408
+ bs = len(images) // self.cfg.n_views
409
+ cameras = (
410
+ self.cameras[: self.cfg.n_views]
411
+ .repeat(bs, 1, 1)
412
+ .to(self.dino_model.device)
413
+ )
414
+ camera_embeds = self.encode_camera(cameras)
415
+ pixel_values = self.image_preprocess_dino.preprocess(
416
+ images,
417
+ return_tensors="pt",
418
+ do_rescale=True,
419
+ do_resize=True,
420
+ size=self.cfg.image_size,
421
+ crop_size=self.cfg.image_size,
422
+ ).pixel_values
423
+
424
+ if force_none_camera_embeds:
425
+ camera_embeds = None
426
+
427
+ if pixel_values.ndim == 4:
428
+ pixel_values = pixel_values.unsqueeze(1)
429
+ if camera_embeds is not None:
430
+ camera_embeds = camera_embeds.unsqueeze(1)
431
+
432
+ if self.cfg.encode_camera and camera_embeds is not None:
433
+ vision_outputs = self.dino_model(
434
+ rearrange(
435
+ pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
436
+ ),
437
+ condition=rearrange(camera_embeds, "B N C -> (B N) C"),
438
+ )
439
+ else:
440
+ vision_outputs = self.dino_model(
441
+ rearrange(
442
+ pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
443
+ ),
444
+ )
445
+
446
+ if return_dict:
447
+ # dino
448
+ dino_embeds_dict = DINOEmbedOutput(
449
+ last_hidden_state=vision_outputs.last_hidden_state,
450
+ pooler_output=vision_outputs.pooler_output,
451
+ )
452
+ return dino_embeds_dict
453
+ else:
454
+ return vision_outputs.last_hidden_state
455
+
456
+ def align_clip_dino(self, clip_embeds, dino_embeds):
457
+ if (
458
+ clip_embeds.shape[-2] != dino_embeds.shape[-2]
459
+ ): # different shape, interpolate the clip embeddings to the same shape as dino embeddings
460
+ assert (
461
+ clip_embeds.shape[-2] == (self.cfg.image_size // 14) ** 2 + 1
462
+ ), "The clip embeddings should have the shape of (n_views, (image_size // 14) ** 2 + 1, 1024)"
463
+ clip_embeds_patch_tokens = clip_embeds[:, 1:].view(
464
+ clip_embeds.shape[0],
465
+ self.cfg.image_size // 14,
466
+ self.cfg.image_size // 14,
467
+ 1024,
468
+ )
469
+ clip_embeds_patch_tokens = (
470
+ torch.nn.functional.interpolate(
471
+ clip_embeds_patch_tokens.permute(0, 3, 1, 2),
472
+ size=(self.cfg.image_size // 14, self.cfg.image_size // 14),
473
+ mode="bilinear",
474
+ align_corners=False,
475
+ )
476
+ .permute(0, 2, 3, 1)
477
+ .view(clip_embeds.shape[0], -1, 1024)
478
+ )
479
+ clip_embeds = torch.cat(
480
+ [clip_embeds[:, :1], clip_embeds_patch_tokens], dim=1
481
+ )
482
+ return clip_embeds, dino_embeds
483
+
484
+ def encode_image(
485
+ self,
486
+ images: Iterable[Optional[ImageType]],
487
+ cameras: Optional[torch.Tensor] = None,
488
+ force_none_camera_embeds: bool = False,
489
+ return_dict: bool = False,
490
+ **kwargs,
491
+ ) -> torch.FloatTensor:
492
+ clip_embeds = self.encode_image_clip(images, cameras)
493
+ dino_embeds = self.encode_image_dino(images, cameras)
494
+ if (
495
+ self.dino_model.__class__.__name__ == "Dinov2WithRegistersModel"
496
+ ): # x_norm_clstoken, x_norm_regtokens, x_norm_patchtokens
497
+ dino_embeds = torch.cat(
498
+ [
499
+ dino_embeds[:, :1],
500
+ dino_embeds[:, self.dino_model.config.num_register_tokens + 1 :],
501
+ ],
502
+ dim=1,
503
+ )
504
+
505
+ clip_embeds = self.linear_proj(clip_embeds) # bs, 257, 1024
506
+
507
+ if self.cfg.fuse_type == "concat":
508
+ visual_embeds = torch.cat([dino_embeds, clip_embeds], dim=1)
509
+ # elif self.cfg.fuse_type == 'add':
510
+ # clip_embeds, dino_embeds = self.align_clip_dino(clip_embeds, dino_embeds)
511
+ else:
512
+ raise ValueError(f"Unsupported fuse_type: {self.cfg.fuse_type}")
513
+
514
+ return visual_embeds
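As a rough illustration of the `concat` fusion in `encode_image` above, here is a minimal, self-contained sketch of the shape bookkeeping. The hidden widths (1280 for raw CLIP tokens, 1024 after projection and for DINOv2) and the 224/14 patching are assumptions for illustration only and depend on the model variants actually configured.

```python
import torch
from torch import nn

# 224x224 input with 14x14 patches -> 256 patch tokens + 1 CLS token
n_tokens = (224 // 14) ** 2 + 1                # 257

dino_embeds = torch.randn(1, n_tokens, 1024)   # DINOv2 last_hidden_state (placeholder width)
clip_embeds = torch.randn(1, n_tokens, 1280)   # CLIP vision last_hidden_state (placeholder width)

linear_proj = nn.Linear(1280, 1024)            # stands in for self.linear_proj
clip_embeds = linear_proj(clip_embeds)         # -> (1, 257, 1024), cf. "# bs, 257, 1024" above

# fuse_type == "concat": token sequences are concatenated along the sequence axis
visual_embeds = torch.cat([dino_embeds, clip_embeds], dim=1)
print(visual_embeds.shape)                     # torch.Size([1, 514, 1024])
```

Concatenation simply doubles the conditioning sequence length; no spatial alignment between the two token sets is needed for this fuse type (the `align_clip_dino` path above is only relevant to the commented-out `add` fusion).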
step1x3d_geometry/models/conditional_encoders/dinov2_encoder.py ADDED
@@ -0,0 +1,296 @@
1
+ import random
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ import re
6
+ from einops import rearrange
7
+ from dataclasses import dataclass
8
+ from torchvision import transforms
9
+
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from transformers import AutoImageProcessor, AutoModel
12
+ from transformers.utils import ModelOutput
13
+ from typing import Iterable, Optional, Union, List
14
+
15
+ import step1x3d_geometry
16
+ from step1x3d_geometry.utils.typing import *
17
+ from .base import BaseVisualEncoder, ImageType
18
+ from .dinov2.modeling_dinov2 import Dinov2Model
19
+ from .dinov2.modeling_conditional_dinov2 import ConditionalDinov2Model
20
+ from .dinov2_with_registers.modeling_dinov2_with_registers import (
21
+ Dinov2WithRegistersModel,
22
+ )
23
+
24
+
25
+ class DINOEmbedOutput(ModelOutput):
26
+ last_hidden_state: torch.FloatTensor = None
27
+ pooler_output: torch.FloatTensor = None
28
+
29
+
30
+ @step1x3d_geometry.register("dinov2-encoder")
31
+ class Dinov2Encoder(BaseVisualEncoder, ModelMixin):
32
+
33
+ @dataclass
34
+ class Config(BaseVisualEncoder.Config):
35
+ pretrained_model_name_or_path: Optional[str] = (
36
+ None # the pretrained model name or path for condition model
37
+ )
38
+ pretrained_dino_name_or_path: Optional[str] = (
39
+ None # the pretrained model name or path for dino
40
+ )
41
+ freeze_modulation_dino: bool = False
42
+ enable_gradient_checkpointing: bool = False
43
+ image_size: int = 224
44
+ dino_type: Optional[str] = None
45
+ kwargs: Optional[dict] = None
46
+
47
+ cfg: Config
48
+
49
+ def configure(self) -> None:
50
+ super().configure()
51
+
52
+ # Load the DINOV2 model and processor
53
+ if not self.cfg.encode_camera:
54
+ if self.cfg.pretrained_dino_name_or_path is not None:
55
+ self.cfg.dino_type = f"facebook/{self.cfg.pretrained_dino_name_or_path.split('facebook--')[-1].split('/')[0]}"
56
+ if self.cfg.kwargs is not None:
57
+ self.dino_model: Dinov2Model = AutoModel.from_pretrained(
58
+ self.cfg.pretrained_dino_name_or_path, **self.cfg.kwargs
59
+ )
60
+ else:
61
+ self.dino_model: Dinov2Model = AutoModel.from_pretrained(
62
+ self.cfg.pretrained_dino_name_or_path
63
+ )
64
+ else:
65
+ if (
66
+ self.cfg.pretrained_model_name_or_path is None
67
+ ): # default to load Dinov2-base model
68
+ assert (
69
+ self.cfg.dino_type is not None
70
+ ), "The dino_type should be provided"
71
+ print(f"Loading Dinov2 model from {self.cfg.dino_type}")
72
+ if "reg" in self.cfg.dino_type:
73
+ self.dino_model: Dinov2WithRegistersModel = (
74
+ Dinov2WithRegistersModel(
75
+ config=Dinov2WithRegistersModel.config_class.from_pretrained(
76
+ self.cfg.dino_type,
77
+ )
78
+ )
79
+ )
80
+ else:
81
+ self.dino_model: Dinov2Model = Dinov2Model(
82
+ config=Dinov2Model.config_class.from_pretrained(
83
+ self.dino_type,
84
+ )
85
+ )
86
+ elif "dinov2base" in self.cfg.pretrained_model_name_or_path:
87
+ print("Loading Dinov2 model from facebook/dinov2-base")
88
+ self.cfg.dino_type = "facebook/dinov2-base"
89
+ self.dino_model: Dinov2Model = Dinov2Model(
90
+ config=Dinov2Model.config_class.from_pretrained(
91
+ "facebook/dinov2-base",
92
+ )
93
+ )
94
+ elif "dinov2regbase" in self.cfg.pretrained_model_name_or_path:
95
+ print(
96
+ "Loading Dinov2 model from facebook/dinov2-with-registers-base"
97
+ )
98
+ self.cfg.dino_type = "facebook/dinov2-with-registers-base"
99
+ self.dino_model: Dinov2WithRegistersModel = (
100
+ Dinov2WithRegistersModel(
101
+ config=Dinov2WithRegistersModel.config_class.from_pretrained(
102
+ "facebook/dinov2-with-registers-base",
103
+ )
104
+ )
105
+ )
106
+ elif "dinov2reglarge" in self.cfg.pretrained_model_name_or_path:
107
+ print(
108
+ "Loading Dinov2 model from facebook/dinov2-with-registers-large"
109
+ )
110
+ self.cfg.dino_type = "facebook/dinov2-with-registers-large"
111
+ self.dino_model: Dinov2WithRegistersModel = (
112
+ Dinov2WithRegistersModel(
113
+ config=Dinov2WithRegistersModel.config_class.from_pretrained(
114
+ "facebook/dinov2-with-registers-large",
115
+ )
116
+ )
117
+ )
118
+ else:
119
+ raise ValueError(
120
+ f"Unknown Dinov2 model: {self.cfg.pretrained_model_name_or_path}"
121
+ )
122
+ else:
123
+ # dino
124
+ conditional_vit_config = (
125
+ ConditionalDinov2Model.config_class.from_pretrained(
126
+ self.cfg.pretrained_dino_name_or_path,
127
+ )
128
+ )
129
+ conditional_vit_config.modulation_dim = self.cfg.camera_embeds_dim
130
+ self.dino_model: ConditionalDinov2Model = (
131
+ ConditionalDinov2Model.from_pretrained(
132
+ self.cfg.pretrained_dino_name_or_path, config=conditional_vit_config
133
+ )
134
+ )
135
+
136
+ self.image_preprocess_dino = AutoImageProcessor.from_pretrained(
137
+ self.cfg.dino_type
138
+ if self.cfg.pretrained_dino_name_or_path is None
139
+ else self.cfg.pretrained_dino_name_or_path
140
+ )
141
+ self.transform_dino = transforms.Compose(
142
+ [
143
+ transforms.Resize(
144
+ self.cfg.image_size,
145
+ transforms.InterpolationMode.BICUBIC,
146
+ antialias=True,
147
+ ),
148
+ transforms.CenterCrop(
149
+ self.cfg.image_size
150
+ ), # crop a (image_size, image_size) square
151
+ transforms.Normalize(
152
+ mean=[0.485, 0.456, 0.406],
153
+ std=[0.229, 0.224, 0.225],
154
+ ),
155
+ ]
156
+ )
157
+
158
+ if self.cfg.enable_gradient_checkpointing:
159
+ self.dino_model.encoder.gradient_checkpointing = True
160
+
161
+ if self.cfg.zero_uncond_embeds:
162
+ self.empty_image_embeds = torch.zeros(
163
+ (
164
+ self.cfg.n_views,
165
+ (self.cfg.image_size // 14) ** 2 + 1,
166
+ self.dino_model.config.hidden_size,
167
+ )
168
+ ).detach()
169
+ else:
170
+ if self.cfg.encode_camera:
171
+ self.empty_image_embeds = self.encode_image_dino(
172
+ torch.zeros(
173
+ self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
174
+ ),
175
+ self.cameras[: self.cfg.n_views],
176
+ ).detach()
177
+ else:
178
+ self.empty_image_embeds = self.encode_image_dino(
179
+ torch.zeros(
180
+ self.cfg.n_views, self.cfg.image_size, self.cfg.image_size, 3
181
+ )
182
+ ).detach()
183
+
184
+ # freeze the dino model parameters
185
+ self.dino_model.eval()
186
+ for k, p in self.dino_model.named_parameters():
187
+ ks = k.split(".")
188
+ if (
189
+ "mod_norm1" in ks
190
+ or "mod_norm2" in ks
191
+ and not self.cfg.freeze_modulation_dino
192
+ ):
193
+ p.requires_grad_(not self.cfg.freeze_modulation_dino)
194
+ else:
195
+ p.requires_grad_(False)
196
+
197
+ # load pretrained_model_name_or_path
198
+ if self.cfg.pretrained_model_name_or_path is not None:
199
+ print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
200
+ ckpt = torch.load(
201
+ self.cfg.pretrained_model_name_or_path, map_location="cpu"
202
+ )["state_dict"]
203
+ pretrained_model_ckpt = {}
204
+ for k, v in ckpt.items():
205
+ if k.startswith("visual_condition."):
206
+ pretrained_model_ckpt[k.replace("visual_condition.", "")] = v
207
+ self.load_state_dict(pretrained_model_ckpt, strict=True)
208
+
209
+ def encode_image_dino(
210
+ self,
211
+ images: Iterable[Optional[ImageType]],
212
+ cameras: Optional[torch.Tensor] = None,
213
+ force_none_camera_embeds: bool = False,
214
+ return_dict: bool = False,
215
+ **kwargs,
216
+ ) -> torch.FloatTensor:
217
+ camera_embeds = None
218
+ if isinstance(images, (np.ndarray, torch.Tensor)): # for training process
219
+ assert (
220
+ images.min() >= 0.0 and images.max() <= 1.0
221
+ ), "The pixel values should be in the range of [0, 1]"
222
+ if self.cfg.encode_camera:
223
+ assert cameras is not None, "The cameras should be provided"
224
+ camera_embeds = self.encode_camera(cameras)
225
+ pixel_values = self.transform_dino(images.permute(0, 3, 1, 2))
226
+ else: # for inference process
227
+ if self.cfg.encode_camera:
228
+ if cameras is None:
229
+ bs = len(images) // self.cfg.n_views
230
+ cameras = (
231
+ self.cameras[: self.cfg.n_views]
232
+ .repeat(bs, 1, 1)
233
+ .to(self.dino_model.device)
234
+ )
235
+ camera_embeds = self.encode_camera(cameras)
236
+ pixel_values = self.image_preprocess_dino.preprocess(
237
+ images,
238
+ return_tensors="pt",
239
+ do_rescale=True,
240
+ do_resize=True,
241
+ size=self.cfg.image_size,
242
+ crop_size=self.cfg.image_size,
243
+ ).pixel_values
244
+
245
+ if force_none_camera_embeds:
246
+ camera_embeds = None
247
+
248
+ if pixel_values.ndim == 4:
249
+ pixel_values = pixel_values.unsqueeze(1)
250
+ if camera_embeds is not None:
251
+ camera_embeds = camera_embeds.unsqueeze(1)
252
+
253
+ if self.cfg.encode_camera and camera_embeds is not None:
254
+ vision_outputs = self.dino_model(
255
+ rearrange(
256
+ pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
257
+ ),
258
+ condition=rearrange(camera_embeds, "B N C -> (B N) C"),
259
+ )
260
+ else:
261
+ vision_outputs = self.dino_model(
262
+ rearrange(
263
+ pixel_values.to(self.dino_model.device), "B N C H W -> (B N) C H W"
264
+ ),
265
+ )
266
+
267
+ if return_dict:
268
+ # dino
269
+ dino_embeds_dict = DINOEmbedOutput(
270
+ last_hidden_state=vision_outputs.last_hidden_state,
271
+ pooler_output=vision_outputs.pooler_output,
272
+ )
273
+ return dino_embeds_dict
274
+ else:
275
+ return vision_outputs.last_hidden_state
276
+
277
+ def encode_image(
278
+ self,
279
+ images: Iterable[Optional[ImageType]],
280
+ cameras: Optional[torch.Tensor] = None,
281
+ force_none_camera_embeds: bool = False,
282
+ return_dict: bool = False,
283
+ **kwargs,
284
+ ) -> torch.FloatTensor:
285
+ dino_embeds = self.encode_image_dino(images, cameras)
286
+ if (
287
+ self.dino_model.__class__.__name__ == "Dinov2WithRegistersModel"
288
+ ): # x_norm_clstoken, x_norm_regtokens, x_norm_patchtokens
289
+ dino_embeds = torch.cat(
290
+ [
291
+ dino_embeds[:, :1],
292
+ dino_embeds[:, self.dino_model.config.num_register_tokens + 1 :],
293
+ ],
294
+ dim=1,
295
+ )
296
+ return dino_embeds
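The register-token stripping in `encode_image` above relies on the token layout `[CLS, register_1..register_R, patch_1..patch_P]` produced by the with-registers embeddings. A minimal sketch with placeholder sizes:

```python
import torch

num_register_tokens = 4                        # read from the model config in practice
batch, patches, hidden = 2, 256, 768           # placeholder sizes

# layout: [CLS, registers, patches]
tokens = torch.randn(batch, 1 + num_register_tokens + patches, hidden)

kept = torch.cat(
    [
        tokens[:, :1],                         # keep the CLS token
        tokens[:, num_register_tokens + 1:],   # keep the patch tokens, drop the registers
    ],
    dim=1,
)
print(kept.shape)                              # torch.Size([2, 257, 768])
```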
step1x3d_geometry/models/conditional_encoders/dinov2_with_registers/modeling_dinov2_with_registers.py ADDED
@@ -0,0 +1,1088 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_dinov2_with_registers.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2024 Meta Inc. and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+
23
+ import collections.abc
24
+ import math
25
+ from typing import Dict, List, Optional, Set, Tuple, Union
26
+
27
+ import torch
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.modeling_outputs import (
33
+ BackboneOutput,
34
+ BaseModelOutput,
35
+ BaseModelOutputWithPooling,
36
+ ImageClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import (
40
+ find_pruneable_heads_and_indices,
41
+ prune_linear_layer,
42
+ )
43
+ from transformers.utils import (
44
+ add_code_sample_docstrings,
45
+ add_start_docstrings,
46
+ add_start_docstrings_to_model_forward,
47
+ logging,
48
+ replace_return_docstrings,
49
+ torch_int,
50
+ )
51
+ from transformers.utils.backbone_utils import BackboneMixin
52
+ from transformers.models.dinov2_with_registers.configuration_dinov2_with_registers import (
53
+ Dinov2WithRegistersConfig,
54
+ )
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ # Base docstring
60
+ _CHECKPOINT_FOR_DOC = "facebook/dinov2_with_registers-base"
61
+
62
+ # General docstring
63
+ _CONFIG_FOR_DOC = "Dinov2WithRegistersConfig"
64
+
65
+
66
+ class Dinov2WithRegistersPatchEmbeddings(nn.Module):
67
+ """
68
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
69
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
70
+ Transformer.
71
+ """
72
+
73
+ def __init__(self, config):
74
+ super().__init__()
75
+ image_size, patch_size = config.image_size, config.patch_size
76
+ num_channels, hidden_size = config.num_channels, config.hidden_size
77
+
78
+ image_size = (
79
+ image_size
80
+ if isinstance(image_size, collections.abc.Iterable)
81
+ else (image_size, image_size)
82
+ )
83
+ patch_size = (
84
+ patch_size
85
+ if isinstance(patch_size, collections.abc.Iterable)
86
+ else (patch_size, patch_size)
87
+ )
88
+ num_patches = (image_size[1] // patch_size[1]) * (
89
+ image_size[0] // patch_size[0]
90
+ )
91
+ self.image_size = image_size
92
+ self.patch_size = patch_size
93
+ self.num_channels = num_channels
94
+ self.num_patches = num_patches
95
+
96
+ self.projection = nn.Conv2d(
97
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
98
+ )
99
+
100
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
101
+ num_channels = pixel_values.shape[1]
102
+ if num_channels != self.num_channels:
103
+ raise ValueError(
104
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
105
+ f" Expected {self.num_channels} but got {num_channels}."
106
+ )
107
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
108
+ return embeddings
109
+
110
+
111
+ class Dinov2WithRegistersEmbeddings(nn.Module):
112
+ """
113
+ Construct the CLS token, mask token, register tokens, position and patch embeddings.
114
+ """
115
+
116
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
117
+ super().__init__()
118
+
119
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
120
+ self.mask_token = nn.Parameter(torch.zeros(1, config.hidden_size))
121
+ self.register_tokens = nn.Parameter(
122
+ torch.zeros(1, config.num_register_tokens, config.hidden_size)
123
+ )
124
+ self.patch_embeddings = Dinov2WithRegistersPatchEmbeddings(config)
125
+ num_patches = self.patch_embeddings.num_patches
126
+ self.position_embeddings = nn.Parameter(
127
+ torch.randn(1, num_patches + 1, config.hidden_size)
128
+ )
129
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
130
+ self.patch_size = config.patch_size
131
+ self.config = config
132
+
133
+ def interpolate_pos_encoding(
134
+ self, embeddings: torch.Tensor, height: int, width: int
135
+ ) -> torch.Tensor:
136
+ """
137
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
138
+ resolution images. This implementation supports torch.jit tracing while maintaining backwards compatibility
139
+ with the original implementation.
140
+
141
+ Adapted from:
142
+ - https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
143
+ - https://github.com/facebookresearch/dinov2/blob/main/dinov2/models/vision_transformer.py
144
+ """
145
+ num_patches = embeddings.shape[1] - 1
146
+ num_positions = self.position_embeddings.shape[1] - 1
147
+
148
+ # Skip interpolation for matching dimensions (unless tracing)
149
+ if (
150
+ not torch.jit.is_tracing()
151
+ and num_patches == num_positions
152
+ and height == width
153
+ ):
154
+ return self.position_embeddings
155
+
156
+ # Handle class token and patch embeddings separately
157
+ class_pos_embed = self.position_embeddings[:, 0]
158
+ patch_pos_embed = self.position_embeddings[:, 1:]
159
+ dim = embeddings.shape[-1]
160
+
161
+ # Calculate new dimensions
162
+ height = height // self.config.patch_size
163
+ width = width // self.config.patch_size
164
+
165
+ # Reshape for interpolation
166
+ sqrt_num_positions = torch_int(num_positions**0.5)
167
+ patch_pos_embed = patch_pos_embed.reshape(
168
+ 1, sqrt_num_positions, sqrt_num_positions, dim
169
+ )
170
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
171
+
172
+ # Store original dtype for restoration after interpolation
173
+ target_dtype = patch_pos_embed.dtype
174
+
175
+ # Interpolate at float32 precision
176
+ patch_pos_embed = nn.functional.interpolate(
177
+ patch_pos_embed.to(dtype=torch.float32),
178
+ size=(
179
+ torch_int(height),
180
+ torch_int(width),
181
+ ), # Explicit size instead of scale_factor
182
+ mode="bicubic",
183
+ align_corners=False,
184
+ antialias=True,
185
+ ).to(dtype=target_dtype)
186
+
187
+ # Validate output dimensions if not tracing
188
+ if not torch.jit.is_tracing():
189
+ if (
190
+ int(height) != patch_pos_embed.shape[-2]
191
+ or int(width) != patch_pos_embed.shape[-1]
192
+ ):
193
+ raise ValueError(
194
+ "Width or height does not match with the interpolated position embeddings"
195
+ )
196
+
197
+ # Reshape back to original format
198
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
199
+
200
+ # Combine class and patch embeddings
201
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
202
+
203
+ def forward(
204
+ self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Tensor] = None
205
+ ) -> torch.Tensor:
206
+ batch_size, _, height, width = pixel_values.shape
207
+ target_dtype = self.patch_embeddings.projection.weight.dtype
208
+ embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
209
+
210
+ if bool_masked_pos is not None:
211
+ embeddings = torch.where(
212
+ bool_masked_pos.unsqueeze(-1),
213
+ self.mask_token.to(embeddings.dtype).unsqueeze(0),
214
+ embeddings,
215
+ )
216
+
217
+ # add the [CLS] token to the embedded patch tokens
218
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
219
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
220
+
221
+ # add positional encoding to each token
222
+ embeddings = embeddings + self.interpolate_pos_encoding(
223
+ embeddings, height, width
224
+ )
225
+
226
+ # add register tokens
227
+ embeddings = torch.cat(
228
+ (
229
+ embeddings[:, :1],
230
+ self.register_tokens.expand(embeddings.shape[0], -1, -1),
231
+ embeddings[:, 1:],
232
+ ),
233
+ dim=1,
234
+ )
235
+
236
+ embeddings = self.dropout(embeddings)
237
+
238
+ return embeddings
239
+
240
+
241
+ class Dinov2WithRegistersSelfAttention(nn.Module):
242
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
243
+ super().__init__()
244
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
245
+ config, "embedding_size"
246
+ ):
247
+ raise ValueError(
248
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
249
+ f"heads {config.num_attention_heads}."
250
+ )
251
+
252
+ self.num_attention_heads = config.num_attention_heads
253
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
254
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
255
+
256
+ self.query = nn.Linear(
257
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
258
+ )
259
+ self.key = nn.Linear(
260
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
261
+ )
262
+ self.value = nn.Linear(
263
+ config.hidden_size, self.all_head_size, bias=config.qkv_bias
264
+ )
265
+
266
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
267
+
268
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
269
+ new_x_shape = x.size()[:-1] + (
270
+ self.num_attention_heads,
271
+ self.attention_head_size,
272
+ )
273
+ x = x.view(new_x_shape)
274
+ return x.permute(0, 2, 1, 3)
275
+
276
+ def forward(
277
+ self,
278
+ hidden_states,
279
+ head_mask: Optional[torch.Tensor] = None,
280
+ output_attentions: bool = False,
281
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
282
+ mixed_query_layer = self.query(hidden_states)
283
+
284
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
285
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
286
+ query_layer = self.transpose_for_scores(mixed_query_layer)
287
+
288
+ # Take the dot product between "query" and "key" to get the raw attention scores.
289
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
290
+
291
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
292
+
293
+ # Normalize the attention scores to probabilities.
294
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
295
+
296
+ # This is actually dropping out entire tokens to attend to, which might
297
+ # seem a bit unusual, but is taken from the original Transformer paper.
298
+ attention_probs = self.dropout(attention_probs)
299
+
300
+ # Mask heads if we want to
301
+ if head_mask is not None:
302
+ attention_probs = attention_probs * head_mask
303
+
304
+ context_layer = torch.matmul(attention_probs, value_layer)
305
+
306
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
307
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
308
+ context_layer = context_layer.view(new_context_layer_shape)
309
+
310
+ outputs = (
311
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
312
+ )
313
+
314
+ return outputs
315
+
316
+
317
+ class Dinov2WithRegistersSdpaSelfAttention(Dinov2WithRegistersSelfAttention):
318
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
319
+ super().__init__(config)
320
+ self.attention_probs_dropout_prob = config.attention_probs_dropout_prob
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ head_mask: Optional[torch.Tensor] = None,
326
+ output_attentions: bool = False,
327
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
328
+ if output_attentions:
329
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
330
+ logger.warning_once(
331
+ "Dinov2WithRegistersModel is using Dinov2WithRegistersSdpaSelfAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
332
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
333
+ )
334
+ return super().forward(
335
+ hidden_states=hidden_states,
336
+ head_mask=head_mask,
337
+ output_attentions=output_attentions,
338
+ )
339
+
340
+ mixed_query_layer = self.query(hidden_states)
341
+
342
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
343
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
344
+ query_layer = self.transpose_for_scores(mixed_query_layer)
345
+
346
+ context_layer = torch.nn.functional.scaled_dot_product_attention(
347
+ query_layer,
348
+ key_layer,
349
+ value_layer,
350
+ head_mask,
351
+ self.attention_probs_dropout_prob if self.training else 0.0,
352
+ is_causal=False,
353
+ scale=None,
354
+ )
355
+
356
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
357
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
358
+ context_layer = context_layer.view(new_context_layer_shape)
359
+
360
+ return context_layer, None
361
+
362
+
363
+ class Dinov2WithRegistersSelfOutput(nn.Module):
364
+ """
365
+ The residual connection is defined in Dinov2WithRegistersLayer instead of here (as is the case with other models), due to the
366
+ layernorm applied before each block.
367
+ """
368
+
369
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
370
+ super().__init__()
371
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
372
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
373
+
374
+ def forward(
375
+ self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
376
+ ) -> torch.Tensor:
377
+ hidden_states = self.dense(hidden_states)
378
+ hidden_states = self.dropout(hidden_states)
379
+
380
+ return hidden_states
381
+
382
+
383
+ class Dinov2WithRegistersAttention(nn.Module):
384
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
385
+ super().__init__()
386
+ self.attention = Dinov2WithRegistersSelfAttention(config)
387
+ self.output = Dinov2WithRegistersSelfOutput(config)
388
+ self.pruned_heads = set()
389
+
390
+ def prune_heads(self, heads: Set[int]) -> None:
391
+ if len(heads) == 0:
392
+ return
393
+ heads, index = find_pruneable_heads_and_indices(
394
+ heads,
395
+ self.attention.num_attention_heads,
396
+ self.attention.attention_head_size,
397
+ self.pruned_heads,
398
+ )
399
+
400
+ # Prune linear layers
401
+ self.attention.query = prune_linear_layer(self.attention.query, index)
402
+ self.attention.key = prune_linear_layer(self.attention.key, index)
403
+ self.attention.value = prune_linear_layer(self.attention.value, index)
404
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
405
+
406
+ # Update hyper params and store pruned heads
407
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(
408
+ heads
409
+ )
410
+ self.attention.all_head_size = (
411
+ self.attention.attention_head_size * self.attention.num_attention_heads
412
+ )
413
+ self.pruned_heads = self.pruned_heads.union(heads)
414
+
415
+ def forward(
416
+ self,
417
+ hidden_states: torch.Tensor,
418
+ head_mask: Optional[torch.Tensor] = None,
419
+ output_attentions: bool = False,
420
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
421
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
422
+
423
+ attention_output = self.output(self_outputs[0], hidden_states)
424
+
425
+ outputs = (attention_output,) + self_outputs[
426
+ 1:
427
+ ] # add attentions if we output them
428
+ return outputs
429
+
430
+
431
+ class Dinov2WithRegistersSdpaAttention(Dinov2WithRegistersAttention):
432
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
433
+ super().__init__(config)
434
+ self.attention = Dinov2WithRegistersSdpaSelfAttention(config)
435
+
436
+
437
+ class Dinov2WithRegistersLayerScale(nn.Module):
438
+ def __init__(self, config) -> None:
439
+ super().__init__()
440
+ self.lambda1 = nn.Parameter(
441
+ config.layerscale_value * torch.ones(config.hidden_size)
442
+ )
443
+
444
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
445
+ return hidden_state * self.lambda1
446
+
447
+
448
+ def drop_path(
449
+ input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
450
+ ) -> torch.Tensor:
451
+ """
452
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
453
+
454
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
455
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
456
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
457
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
458
+ argument.
459
+ """
460
+ if drop_prob == 0.0 or not training:
461
+ return input
462
+ keep_prob = 1 - drop_prob
463
+ shape = (input.shape[0],) + (1,) * (
464
+ input.ndim - 1
465
+ ) # work with diff dim tensors, not just 2D ConvNets
466
+ random_tensor = keep_prob + torch.rand(
467
+ shape, dtype=input.dtype, device=input.device
468
+ )
469
+ random_tensor.floor_() # binarize
470
+ output = input.div(keep_prob) * random_tensor
471
+ return output
472
+
473
+
474
+ class Dinov2WithRegistersDropPath(nn.Module):
475
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
476
+
477
+ def __init__(self, drop_prob: Optional[float] = None) -> None:
478
+ super().__init__()
479
+ self.drop_prob = drop_prob
480
+
481
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
482
+ return drop_path(hidden_states, self.drop_prob, self.training)
483
+
484
+ def extra_repr(self) -> str:
485
+ return "p={}".format(self.drop_prob)
486
+
487
+
488
+ class Dinov2WithRegistersMLP(nn.Module):
489
+ def __init__(self, config) -> None:
490
+ super().__init__()
491
+ in_features = out_features = config.hidden_size
492
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
493
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
494
+ if isinstance(config.hidden_act, str):
495
+ self.activation = ACT2FN[config.hidden_act]
496
+ else:
497
+ self.activation = config.hidden_act
498
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
499
+
500
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
501
+ hidden_state = self.fc1(hidden_state)
502
+ hidden_state = self.activation(hidden_state)
503
+ hidden_state = self.fc2(hidden_state)
504
+ return hidden_state
505
+
506
+
507
+ class Dinov2WithRegistersSwiGLUFFN(nn.Module):
508
+ def __init__(self, config) -> None:
509
+ super().__init__()
510
+ in_features = out_features = config.hidden_size
511
+ hidden_features = int(config.hidden_size * config.mlp_ratio)
512
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
513
+
514
+ self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
515
+ self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
516
+
517
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
518
+ hidden_state = self.weights_in(hidden_state)
519
+ x1, x2 = hidden_state.chunk(2, dim=-1)
520
+ hidden = nn.functional.silu(x1) * x2
521
+ return self.weights_out(hidden)
522
+
523
+
524
+ DINOV2_WITH_REGISTERS_ATTENTION_CLASSES = {
525
+ "eager": Dinov2WithRegistersAttention,
526
+ "sdpa": Dinov2WithRegistersSdpaAttention,
527
+ }
528
+
529
+
530
+ class Dinov2WithRegistersLayer(nn.Module):
531
+ """This corresponds to the Block class in the original implementation."""
532
+
533
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
534
+ super().__init__()
535
+
536
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
537
+ self.attention = DINOV2_WITH_REGISTERS_ATTENTION_CLASSES[
538
+ config._attn_implementation
539
+ ](config)
540
+ self.layer_scale1 = Dinov2WithRegistersLayerScale(config)
541
+ self.drop_path = (
542
+ Dinov2WithRegistersDropPath(config.drop_path_rate)
543
+ if config.drop_path_rate > 0.0
544
+ else nn.Identity()
545
+ )
546
+
547
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
548
+
549
+ if config.use_swiglu_ffn:
550
+ self.mlp = Dinov2WithRegistersSwiGLUFFN(config)
551
+ else:
552
+ self.mlp = Dinov2WithRegistersMLP(config)
553
+ self.layer_scale2 = Dinov2WithRegistersLayerScale(config)
554
+
555
+ def forward(
556
+ self,
557
+ hidden_states: torch.Tensor,
558
+ head_mask: Optional[torch.Tensor] = None,
559
+ output_attentions: bool = False,
560
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
561
+ self_attention_outputs = self.attention(
562
+ self.norm1(
563
+ hidden_states
564
+ ), # in Dinov2WithRegisters, layernorm is applied before self-attention
565
+ head_mask,
566
+ output_attentions=output_attentions,
567
+ )
568
+ attention_output = self_attention_outputs[0]
569
+
570
+ attention_output = self.layer_scale1(attention_output)
571
+ outputs = self_attention_outputs[
572
+ 1:
573
+ ] # add self attentions if we output attention weights
574
+
575
+ # first residual connection
576
+ hidden_states = self.drop_path(attention_output) + hidden_states
577
+
578
+ # in Dinov2WithRegisters, layernorm is also applied after self-attention
579
+ layer_output = self.norm2(hidden_states)
580
+ layer_output = self.mlp(layer_output)
581
+ layer_output = self.layer_scale2(layer_output)
582
+
583
+ # second residual connection
584
+ layer_output = self.drop_path(layer_output) + hidden_states
585
+
586
+ outputs = (layer_output,) + outputs
587
+
588
+ return outputs
589
+
590
+
591
+ class Dinov2WithRegistersEncoder(nn.Module):
592
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
593
+ super().__init__()
594
+ self.config = config
595
+ self.layer = nn.ModuleList(
596
+ [Dinov2WithRegistersLayer(config) for _ in range(config.num_hidden_layers)]
597
+ )
598
+ self.gradient_checkpointing = False
599
+
600
+ def forward(
601
+ self,
602
+ hidden_states: torch.Tensor,
603
+ head_mask: Optional[torch.Tensor] = None,
604
+ output_attentions: bool = False,
605
+ output_hidden_states: bool = False,
606
+ return_dict: bool = True,
607
+ ) -> Union[tuple, BaseModelOutput]:
608
+ all_hidden_states = () if output_hidden_states else None
609
+ all_self_attentions = () if output_attentions else None
610
+
611
+ for i, layer_module in enumerate(self.layer):
612
+ if output_hidden_states:
613
+ all_hidden_states = all_hidden_states + (hidden_states,)
614
+
615
+ layer_head_mask = head_mask[i] if head_mask is not None else None
616
+
617
+ if self.gradient_checkpointing and self.training:
618
+ layer_outputs = self._gradient_checkpointing_func(
619
+ layer_module.__call__,
620
+ hidden_states,
621
+ layer_head_mask,
622
+ output_attentions,
623
+ )
624
+ else:
625
+ layer_outputs = layer_module(
626
+ hidden_states, layer_head_mask, output_attentions
627
+ )
628
+
629
+ hidden_states = layer_outputs[0]
630
+
631
+ if output_attentions:
632
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
633
+
634
+ if output_hidden_states:
635
+ all_hidden_states = all_hidden_states + (hidden_states,)
636
+
637
+ if not return_dict:
638
+ return tuple(
639
+ v
640
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
641
+ if v is not None
642
+ )
643
+ return BaseModelOutput(
644
+ last_hidden_state=hidden_states,
645
+ hidden_states=all_hidden_states,
646
+ attentions=all_self_attentions,
647
+ )
648
+
649
+
650
+ class Dinov2WithRegistersPreTrainedModel(PreTrainedModel):
651
+ """
652
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
653
+ models.
654
+ """
655
+
656
+ config_class = Dinov2WithRegistersConfig
657
+ base_model_prefix = "dinov2_with_registers"
658
+ main_input_name = "pixel_values"
659
+ supports_gradient_checkpointing = True
660
+ _no_split_modules = ["Dinov2WithRegistersSwiGLUFFN"]
661
+ _supports_sdpa = True
662
+
663
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
664
+ """Initialize the weights"""
665
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
666
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
667
+ # `trunc_normal_cpu` not implemented in `half` issues
668
+ module.weight.data = nn.init.trunc_normal_(
669
+ module.weight.data.to(torch.float32),
670
+ mean=0.0,
671
+ std=self.config.initializer_range,
672
+ ).to(module.weight.dtype)
673
+ if module.bias is not None:
674
+ module.bias.data.zero_()
675
+ elif isinstance(module, nn.LayerNorm):
676
+ module.bias.data.zero_()
677
+ module.weight.data.fill_(1.0)
678
+ elif isinstance(module, Dinov2WithRegistersEmbeddings):
679
+ module.position_embeddings.data = nn.init.trunc_normal_(
680
+ module.position_embeddings.data.to(torch.float32),
681
+ mean=0.0,
682
+ std=self.config.initializer_range,
683
+ ).to(module.position_embeddings.dtype)
684
+
685
+ module.cls_token.data = nn.init.trunc_normal_(
686
+ module.cls_token.data.to(torch.float32),
687
+ mean=0.0,
688
+ std=self.config.initializer_range,
689
+ ).to(module.cls_token.dtype)
690
+
691
+
692
+ _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
693
+
694
+
695
+ DINOV2_WITH_REGISTERS_START_DOCSTRING = r"""
696
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
697
+ as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
698
+ behavior.
699
+
700
+ Parameters:
701
+ config ([`Dinov2WithRegistersConfig`]): Model configuration class with all the parameters of the model.
702
+ Initializing with a config file does not load the weights associated with the model, only the
703
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
704
+ """
705
+
706
+ DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING = r"""
707
+ Args:
708
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
709
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
710
+ [`BitImageProcessor.preprocess`] for details.
711
+
712
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
713
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Only relevant for
714
+ pre-training.
715
+
716
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
717
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
718
+
719
+ - 1 indicates the head is **not masked**,
720
+ - 0 indicates the head is **masked**.
721
+
722
+ output_attentions (`bool`, *optional*):
723
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
724
+ tensors for more detail.
725
+ output_hidden_states (`bool`, *optional*):
726
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
727
+ more detail.
728
+ return_dict (`bool`, *optional*):
729
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
730
+ """
731
+
732
+
733
+ @add_start_docstrings(
734
+ "The bare Dinov2WithRegisters Model transformer outputting raw hidden-states without any specific head on top.",
735
+ DINOV2_WITH_REGISTERS_START_DOCSTRING,
736
+ )
737
+ class Dinov2WithRegistersModel(Dinov2WithRegistersPreTrainedModel):
738
+ def __init__(self, config: Dinov2WithRegistersConfig):
739
+ super().__init__(config)
740
+ self.config = config
741
+
742
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
743
+ self.encoder = Dinov2WithRegistersEncoder(config)
744
+
745
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
746
+
747
+ # Initialize weights and apply final processing
748
+ self.post_init()
749
+
750
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
751
+ return self.embeddings.patch_embeddings
752
+
753
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
754
+ """
755
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
756
+ class PreTrainedModel
757
+ """
758
+ for layer, heads in heads_to_prune.items():
759
+ self.encoder.layer[layer].attention.prune_heads(heads)
760
+
761
+ @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_BASE_INPUTS_DOCSTRING)
762
+ @add_code_sample_docstrings(
763
+ checkpoint=_CHECKPOINT_FOR_DOC,
764
+ output_type=BaseModelOutputWithPooling,
765
+ config_class=_CONFIG_FOR_DOC,
766
+ modality="vision",
767
+ expected_output=_EXPECTED_OUTPUT_SHAPE,
768
+ )
769
+ def forward(
770
+ self,
771
+ pixel_values: Optional[torch.Tensor] = None,
772
+ bool_masked_pos: Optional[torch.Tensor] = None,
773
+ head_mask: Optional[torch.Tensor] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
778
+ output_attentions = (
779
+ output_attentions
780
+ if output_attentions is not None
781
+ else self.config.output_attentions
782
+ )
783
+ output_hidden_states = (
784
+ output_hidden_states
785
+ if output_hidden_states is not None
786
+ else self.config.output_hidden_states
787
+ )
788
+ return_dict = (
789
+ return_dict if return_dict is not None else self.config.use_return_dict
790
+ )
791
+
792
+ if pixel_values is None:
793
+ raise ValueError("You have to specify pixel_values")
794
+
795
+ # Prepare head mask if needed
796
+ # 1.0 in head_mask indicate we keep the head
797
+ # attention_probs has shape bsz x n_heads x N x N
798
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
799
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
800
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
801
+
802
+ embedding_output = self.embeddings(
803
+ pixel_values, bool_masked_pos=bool_masked_pos
804
+ )
805
+
806
+ encoder_outputs = self.encoder(
807
+ embedding_output,
808
+ head_mask=head_mask,
809
+ output_attentions=output_attentions,
810
+ output_hidden_states=output_hidden_states,
811
+ return_dict=return_dict,
812
+ )
813
+ sequence_output = encoder_outputs[0]
814
+ sequence_output = self.layernorm(sequence_output)
815
+ pooled_output = sequence_output[:, 0, :]
816
+
817
+ if not return_dict:
818
+ head_outputs = (sequence_output, pooled_output)
819
+ return head_outputs + encoder_outputs[1:]
820
+
821
+ return BaseModelOutputWithPooling(
822
+ last_hidden_state=sequence_output,
823
+ pooler_output=pooled_output,
824
+ hidden_states=encoder_outputs.hidden_states,
825
+ attentions=encoder_outputs.attentions,
826
+ )
827
+
828
+
829
+ # Image classification docstring
830
+ _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2_with_registers-small-imagenet1k-1-layer"
831
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
832
+
833
+ DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING = r"""
834
+ Args:
835
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
836
+ Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
837
+ [`BitImageProcessor.preprocess`] for details.
838
+
839
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
840
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
841
+
842
+ - 1 indicates the head is **not masked**,
843
+ - 0 indicates the head is **masked**.
844
+
845
+ output_attentions (`bool`, *optional*):
846
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
847
+ tensors for more detail.
848
+ output_hidden_states (`bool`, *optional*):
849
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
850
+ more detail.
851
+ return_dict (`bool`, *optional*):
852
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
853
+ """
854
+
855
+
856
+ @add_start_docstrings(
857
+ """
858
+ Dinov2WithRegisters Model transformer with an image classification head on top (a linear layer on top of the final hidden state
859
+ of the [CLS] token) e.g. for ImageNet.
860
+ """,
861
+ DINOV2_WITH_REGISTERS_START_DOCSTRING,
862
+ )
863
+ class Dinov2WithRegistersForImageClassification(Dinov2WithRegistersPreTrainedModel):
864
+ def __init__(self, config: Dinov2WithRegistersConfig) -> None:
865
+ super().__init__(config)
866
+
867
+ self.num_labels = config.num_labels
868
+ self.dinov2_with_registers = Dinov2WithRegistersModel(config)
869
+
870
+ # Classifier head
871
+ self.classifier = (
872
+ nn.Linear(config.hidden_size * 2, config.num_labels)
873
+ if config.num_labels > 0
874
+ else nn.Identity()
875
+ )
876
+
877
+ # Initialize weights and apply final processing
878
+ self.post_init()
879
+
880
+ @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
881
+ @add_code_sample_docstrings(
882
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
883
+ output_type=ImageClassifierOutput,
884
+ config_class=_CONFIG_FOR_DOC,
885
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
886
+ )
887
+ def forward(
888
+ self,
889
+ pixel_values: Optional[torch.Tensor] = None,
890
+ head_mask: Optional[torch.Tensor] = None,
891
+ labels: Optional[torch.Tensor] = None,
892
+ output_attentions: Optional[bool] = None,
893
+ output_hidden_states: Optional[bool] = None,
894
+ return_dict: Optional[bool] = None,
895
+ ) -> Union[tuple, ImageClassifierOutput]:
896
+ r"""
897
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
898
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
899
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
900
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
901
+ """
902
+ return_dict = (
903
+ return_dict if return_dict is not None else self.config.use_return_dict
904
+ )
905
+
906
+ outputs = self.dinov2_with_registers(
907
+ pixel_values,
908
+ head_mask=head_mask,
909
+ output_attentions=output_attentions,
910
+ output_hidden_states=output_hidden_states,
911
+ return_dict=return_dict,
912
+ )
913
+
914
+ sequence_output = outputs[0] # batch_size, sequence_length, hidden_size
915
+
916
+ cls_token = sequence_output[:, 0]
917
+ patch_tokens = sequence_output[:, 1:]
918
+
919
+ linear_input = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
920
+
921
+ logits = self.classifier(linear_input)
922
+
923
+ loss = None
924
+ if labels is not None:
925
+ # move labels to correct device to enable model parallelism
926
+ labels = labels.to(logits.device)
927
+ if self.config.problem_type is None:
928
+ if self.num_labels == 1:
929
+ self.config.problem_type = "regression"
930
+ elif self.num_labels > 1 and (
931
+ labels.dtype == torch.long or labels.dtype == torch.int
932
+ ):
933
+ self.config.problem_type = "single_label_classification"
934
+ else:
935
+ self.config.problem_type = "multi_label_classification"
936
+
937
+ if self.config.problem_type == "regression":
938
+ loss_fct = MSELoss()
939
+ if self.num_labels == 1:
940
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
941
+ else:
942
+ loss = loss_fct(logits, labels)
943
+ elif self.config.problem_type == "single_label_classification":
944
+ loss_fct = CrossEntropyLoss()
945
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
946
+ elif self.config.problem_type == "multi_label_classification":
947
+ loss_fct = BCEWithLogitsLoss()
948
+ loss = loss_fct(logits, labels)
949
+
950
+ if not return_dict:
951
+ output = (logits,) + outputs[2:]
952
+ return ((loss,) + output) if loss is not None else output
953
+
954
+ return ImageClassifierOutput(
955
+ loss=loss,
956
+ logits=logits,
957
+ hidden_states=outputs.hidden_states,
958
+ attentions=outputs.attentions,
959
+ )
960
+
961
+
962
+ @add_start_docstrings(
963
+ """
964
+ Dinov2WithRegisters backbone, to be used with frameworks like DETR and MaskFormer.
965
+ """,
966
+ DINOV2_WITH_REGISTERS_START_DOCSTRING,
967
+ )
968
+ class Dinov2WithRegistersBackbone(Dinov2WithRegistersPreTrainedModel, BackboneMixin):
969
+ def __init__(self, config):
970
+ super().__init__(config)
971
+ super()._init_backbone(config)
972
+ self.num_features = [
973
+ config.hidden_size for _ in range(config.num_hidden_layers + 1)
974
+ ]
975
+ self.embeddings = Dinov2WithRegistersEmbeddings(config)
976
+ self.encoder = Dinov2WithRegistersEncoder(config)
977
+
978
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
979
+
980
+ self.num_register_tokens = config.num_register_tokens
981
+
982
+ # Initialize weights and apply final processing
983
+ self.post_init()
984
+
985
+ def get_input_embeddings(self) -> Dinov2WithRegistersPatchEmbeddings:
986
+ return self.embeddings.patch_embeddings
987
+
988
+ @add_start_docstrings_to_model_forward(DINOV2_WITH_REGISTERS_INPUTS_DOCSTRING)
989
+ @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
990
+ def forward(
991
+ self,
992
+ pixel_values: torch.Tensor,
993
+ output_hidden_states: Optional[bool] = None,
994
+ output_attentions: Optional[bool] = None,
995
+ return_dict: Optional[bool] = None,
996
+ ) -> BackboneOutput:
997
+ """
998
+ Returns:
999
+
1000
+ Examples:
+
1005
+
1006
+ ```python
1007
+ >>> from transformers import AutoImageProcessor, AutoBackbone
1008
+ >>> import torch
1009
+ >>> from PIL import Image
1010
+ >>> import requests
1011
+
1012
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1013
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1014
+
1015
+ >>> processor = AutoImageProcessor.from_pretrained("facebook/dinov2-with-registers-base")
1016
+ >>> model = AutoBackbone.from_pretrained(
1017
+ ... "facebook/dinov2-with-registers-base", out_features=["stage2", "stage5", "stage8", "stage11"]
1018
+ ... )
1019
+
1020
+ >>> inputs = processor(image, return_tensors="pt")
1021
+
1022
+ >>> outputs = model(**inputs)
1023
+ >>> feature_maps = outputs.feature_maps
1024
+ >>> list(feature_maps[-1].shape)
1025
+ [1, 768, 16, 16]
1026
+ ```"""
1027
+ return_dict = (
1028
+ return_dict if return_dict is not None else self.config.use_return_dict
1029
+ )
1030
+ output_hidden_states = (
1031
+ output_hidden_states
1032
+ if output_hidden_states is not None
1033
+ else self.config.output_hidden_states
1034
+ )
1035
+ output_attentions = (
1036
+ output_attentions
1037
+ if output_attentions is not None
1038
+ else self.config.output_attentions
1039
+ )
1040
+
1041
+ embedding_output = self.embeddings(pixel_values)
1042
+
1043
+ outputs = self.encoder(
1044
+ embedding_output,
1045
+ output_hidden_states=True,
1046
+ output_attentions=output_attentions,
1047
+ return_dict=return_dict,
1048
+ )
1049
+
1050
+ hidden_states = outputs.hidden_states if return_dict else outputs[1]
1051
+
1052
+ feature_maps = ()
1053
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
1054
+ if stage in self.out_features:
1055
+ if self.config.apply_layernorm:
1056
+ hidden_state = self.layernorm(hidden_state)
1057
+ if self.config.reshape_hidden_states:
1058
+ hidden_state = hidden_state[:, self.num_register_tokens + 1 :]
1059
+ # this was actually a bug in the original implementation that we copied here,
1060
+ # because normally the order is height, width
1061
+ batch_size, _, height, width = pixel_values.shape
1062
+ patch_size = self.config.patch_size
1063
+ hidden_state = hidden_state.reshape(
1064
+ batch_size, height // patch_size, width // patch_size, -1
1065
+ )
1066
+ hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
1067
+ feature_maps += (hidden_state,)
1068
+
1069
+ if not return_dict:
1070
+ if output_hidden_states:
1071
+ output = (feature_maps,) + outputs[1:]
1072
+ else:
1073
+ output = (feature_maps,) + outputs[2:]
1074
+ return output
1075
+
1076
+ return BackboneOutput(
1077
+ feature_maps=feature_maps,
1078
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1079
+ attentions=outputs.attentions if output_attentions else None,
1080
+ )
1081
+
1082
+
1083
+ __all__ = [
1084
+ "Dinov2WithRegistersPreTrainedModel",
1085
+ "Dinov2WithRegistersModel",
1086
+ "Dinov2WithRegistersForImageClassification",
1087
+ "Dinov2WithRegistersBackbone",
1088
+ ]
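For orientation, a small usage sketch of the vendored `Dinov2WithRegistersModel` defined above, built from a randomly initialised config so no weights are downloaded. The concrete sizes (224 input, patch size 14, 4 register tokens, hidden size 768) are assumptions chosen to match the encoder defaults used elsewhere in this repo, and the imports assume the repo root is on `PYTHONPATH` and a `transformers` version that ships the Dinov2-with-registers config.

```python
import torch
from transformers.models.dinov2_with_registers.configuration_dinov2_with_registers import (
    Dinov2WithRegistersConfig,
)
from step1x3d_geometry.models.conditional_encoders.dinov2_with_registers.modeling_dinov2_with_registers import (
    Dinov2WithRegistersModel,
)

config = Dinov2WithRegistersConfig(image_size=224, patch_size=14, num_register_tokens=4)
model = Dinov2WithRegistersModel(config).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

# sequence length = 1 CLS + num_register_tokens + (224 // 14) ** 2 patch tokens = 261
print(out.last_hidden_state.shape)  # torch.Size([1, 261, 768])
print(out.pooler_output.shape)      # torch.Size([1, 768])
```

With registers enabled the usual 257-token DINOv2 sequence grows by the register tokens, which is why the encoders above slice those tokens back out before fusing or returning the embeddings.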
step1x3d_geometry/models/conditional_encoders/label_encoder.py ADDED
@@ -0,0 +1,167 @@
1
+ import random
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ import re
6
+ from einops import rearrange
7
+ from dataclasses import dataclass
8
+ from torchvision import transforms
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+
11
+ from transformers.utils import ModelOutput
12
+ from typing import Iterable, Optional, Union, List
13
+
14
+ import step1x3d_geometry
15
+ from step1x3d_geometry.utils.typing import *
16
+ from step1x3d_geometry.utils.misc import get_device
17
+
18
+ from .base import BaseLabelEncoder
19
+
20
+ DEFAULT_POSE = 0 # "unknown", "t-pose", "a-pose", uncond
21
+ NUM_POSE_CLASSES = 3
22
+ POSE_MAPPING = {"unknown": 0, "t-pose": 1, "a-pose": 2, "uncond": 3}
23
+
24
+ DEFAULT_SYMMETRY_TYPE = 0 # "asymmetry", "x", uncond
25
+ NUM_SYMMETRY_TYPE_CLASSES = 2
26
+ SYMMETRY_TYPE_MAPPING = {"asymmetry": 0, "x": 1, "y": 0, "z": 0, "uncond": 2}
27
+
28
+ DEFAULT_GEOMETRY_QUALITY = 0 # "normal", "smooth", "sharp", uncond,
29
+ NUM_GEOMETRY_QUALITY_CLASSES = 3
30
+ GEOMETRY_QUALITY_MAPPING = {"normal": 0, "smooth": 1, "sharp": 2, "uncod": 3}
31
+
32
+
33
+ @step1x3d_geometry.register("label-encoder")
34
+ class LabelEncoder(BaseLabelEncoder, ModelMixin):
35
+ """
36
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
37
+
38
+ Args:
39
+ num_classes (`int`): The number of classes.
40
+ hidden_size (`int`): The size of the vector embeddings.
41
+ """
42
+
43
+ def configure(self) -> None:
44
+ super().configure()
45
+
46
+ if self.cfg.zero_uncond_embeds:
47
+ self.embedding_table_tpose = nn.Embedding(
48
+ NUM_POSE_CLASSES, self.cfg.hidden_size
49
+ )
50
+ self.embedding_table_symmetry_type = nn.Embedding(
51
+ NUM_SYMMETRY_TYPE_CLASSES, self.cfg.hidden_size
52
+ )
53
+ self.embedding_table_geometry_quality = nn.Embedding(
54
+ NUM_GEOMETRY_QUALITY_CLASSES, self.cfg.hidden_size
55
+ )
56
+ else:
57
+ self.embedding_table_tpose = nn.Embedding(
58
+ NUM_POSE_CLASSES + 1, self.cfg.hidden_size
59
+ )
60
+ self.embedding_table_symmetry_type = nn.Embedding(
61
+ NUM_SYMMETRY_TYPE_CLASSES + 1, self.cfg.hidden_size
62
+ )
63
+ self.embedding_table_geometry_quality = nn.Embedding(
64
+ NUM_GEOMETRY_QUALITY_CLASSES + 1, self.cfg.hidden_size
65
+ )
66
+
67
+ if self.cfg.zero_uncond_embeds:
68
+ self.empty_label_embeds = torch.zeros((1, 3, self.cfg.hidden_size)).detach()
69
+ else:
70
+ self.empty_label_embeds = (
71
+ self.encode_label( # the last class label is for the uncond
72
+ [{"pose": "", "symmetry": "", "geometry_type": ""}]
73
+ ).detach()
74
+ )
75
+
76
+ # load pretrained_model_name_or_path
77
+ if self.cfg.pretrained_model_name_or_path is not None:
78
+ print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
79
+ ckpt = torch.load(
80
+ self.cfg.pretrained_model_name_or_path, map_location="cpu"
81
+ )["state_dict"]
82
+ pretrained_model_ckpt = {}
83
+ for k, v in ckpt.items():
84
+ if k.startswith("label_condition."):
85
+ pretrained_model_ckpt[k.replace("label_condition.", "")] = v
86
+ self.load_state_dict(pretrained_model_ckpt, strict=True)
87
+
88
+ def encode_label(self, labels: List[dict]) -> torch.FloatTensor:
89
+ tpose_label_embeds = []
90
+ symmetry_type_label_embeds = []
91
+ geometry_quality_label_embeds = []
92
+
93
+ for label in labels:
94
+ if "pose" in label.keys():
95
+ if label["pose"] is None or label["pose"] == "":
96
+ tpose_label_embeds.append(
97
+ torch.zeros(self.cfg.hidden_size).detach().to(get_device())
98
+ )
99
+ else:
100
+ tpose_label_embeds.append(
101
+ self.embedding_table_tpose(
102
+ torch.tensor(POSE_MAPPING[label["pose"][0]]).to(
103
+ get_device()
104
+ )
105
+ )
106
+ )
107
+ else:
108
+ tpose_label_embeds.append(
109
+ self.embedding_table_tpose(
110
+ torch.tensor(DEFAULT_POSE).to(get_device())
111
+ )
112
+ )
113
+
114
+ if "symmetry" in label.keys():
115
+ if label["symmetry"] is None or label["symmetry"] == "":
116
+ symmetry_type_label_embeds.append(
117
+ torch.zeros(self.cfg.hidden_size).detach().to(get_device())
118
+ )
119
+ else:
120
+ symmetry_type_label_embeds.append(
121
+ self.embedding_table_symmetry_type(
122
+ torch.tensor(
123
+ SYMMETRY_TYPE_MAPPING[label["symmetry"][0]]
124
+ ).to(get_device())
125
+ )
126
+ )
127
+ else:
128
+ symmetry_type_label_embeds.append(
129
+ self.embedding_table_symmetry_type(
130
+ torch.tensor(DEFAULT_SYMMETRY_TYPE).to(get_device())
131
+ )
132
+ )
133
+
134
+ if "geometry_type" in label.keys():
135
+ if label["geometry_type"] is None or label["geometry_type"] == "":
136
+ geometry_quality_label_embeds.append(
137
+ torch.zeros(self.cfg.hidden_size).detach().to(get_device())
138
+ )
139
+ else:
140
+ geometry_quality_label_embeds.append(
141
+ self.embedding_table_geometry_quality(
142
+ torch.tensor(
143
+ GEOMETRY_QUALITY_MAPPING[label["geometry_type"][0]]
144
+ ).to(get_device())
145
+ )
146
+ )
147
+ else:
148
+ geometry_quality_label_embeds.append(
149
+ self.embedding_table_geometry_quality(
150
+ torch.tensor(DEFAULT_GEOMETRY_QUALITY).to(get_device())
151
+ )
152
+ )
153
+
154
+ tpose_label_embeds = torch.stack(tpose_label_embeds)
155
+ symmetry_type_label_embeds = torch.stack(symmetry_type_label_embeds)
156
+ geometry_quality_label_embeds = torch.stack(geometry_quality_label_embeds)
157
+
158
+ label_embeds = torch.stack(
159
+ [
160
+ tpose_label_embeds,
161
+ symmetry_type_label_embeds,
162
+ geometry_quality_label_embeds,
163
+ ],
164
+ dim=1,
165
+ ).to(self.dtype)
166
+
167
+ return label_embeds
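A brief usage sketch for LabelEncoder.encode_label above. Construction of the encoder goes through the step1x3d_geometry config system and is omitted here; `label_encoder` and the label values are assumptions. Note that encode_label indexes each attribute with [0], so values are supplied as single-element lists.

# assumed: `label_encoder` is an already-configured LabelEncoder instance
labels = [{"pose": ["t-pose"], "symmetry": ["x"], "geometry_type": ["sharp"]}]
label_embeds = label_encoder.encode_label(labels)
# one embedding per attribute (pose, symmetry, geometry quality): (batch, 3, hidden_size)
assert label_embeds.shape[:2] == (1, 3)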
step1x3d_geometry/models/conditional_encoders/t5_encoder.py ADDED
@@ -0,0 +1,271 @@
1
+ import random
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ import re
6
+ import urllib.parse as ul
7
+ from bs4 import BeautifulSoup
8
+ from einops import rearrange
9
+ from dataclasses import dataclass
10
+ from torchvision import transforms
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+
13
+ from transformers import AutoImageProcessor, AutoModel
14
+ from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer
15
+ from transformers.utils import ModelOutput
16
+ from typing import Iterable, Optional, Union, List
17
+
18
+ import step1x3d_geometry
19
+ from step1x3d_geometry.utils.typing import *
20
+
21
+ from .base import BaseCaptionEncoder
22
+
23
+ bad_punct_regex = re.compile(
24
+ r"["
25
+ + "#®•©™&@·º½¾¿¡§~"
26
+ + "\)"
27
+ + "\("
28
+ + "\]"
29
+ + "\["
30
+ + "\}"
31
+ + "\{"
32
+ + "\|"
33
+ + "\\"
34
+ + "\/"
35
+ + "\*"
36
+ + r"]{1,}"
37
+ ) # noqa
38
+
39
+
40
+ @step1x3d_geometry.register("t5-encoder")
41
+ class T5Encoder(BaseCaptionEncoder, ModelMixin):
42
+
43
+ @dataclass
44
+ class Config(BaseCaptionEncoder.Config):
45
+ pretrained_model_name_or_path: Optional[str] = (
46
+ None # the pretrained model name or path for condition model
47
+ )
48
+ pretrained_t5_name_or_path: Optional[str] = (
49
+ None # the pretrained model name or path for T5
50
+ )
51
+ preprocessing_text: bool = False
52
+ text_max_length: int = 77
53
+ t5_type: Optional[str] = None
54
+
55
+ cfg: Config
56
+
57
+ def configure(self) -> None:
58
+ super().configure()
59
+
60
+ # Load the T5 model and tokenizer
61
+ if self.cfg.pretrained_t5_name_or_path is not None:
62
+ self.cfg.t5_type = f"google-t5/{self.cfg.pretrained_t5_name_or_path.split('google-t5--')[-1].split('/')[0]}"
63
+ self.tokenizer = T5Tokenizer.from_pretrained(
64
+ self.cfg.pretrained_t5_name_or_path
65
+ )
66
+ self.text_model = T5EncoderModel.from_pretrained(
67
+ self.cfg.pretrained_t5_name_or_path, torch_dtype=torch.bfloat16
68
+ )
69
+ else:
70
+ if (
71
+ self.cfg.pretrained_model_name_or_path is None
72
+ ): # default to load t5-base model
73
+ assert self.cfg.t5_type is not None, "The t5_type should be provided"
74
+ print(f"Loading T5 model from {self.cfg.t5_type}")
75
+ self.text_model = T5EncoderModel(
76
+ config=T5EncoderModel.config_class.from_pretrained(
77
+ self.cfg.t5_type,
78
+ )
79
+ ).to(torch.bfloat16)
80
+ elif "t5small" in self.cfg.pretrained_model_name_or_path:
81
+ print("Loading T5 model from google-t5/t5-small")
82
+ self.cfg.t5_type = "google-t5/t5-small"
83
+ self.text_model = T5EncoderModel.from_pretrained(
84
+ self.cfg.t5_type, torch_dtype=torch.bfloat16
85
+ )
86
+ elif "t5base" in self.cfg.pretrained_model_name_or_path:
87
+ print("Loading T5 model from google-t5/t5-base")
88
+ self.cfg.t5_type = "google-t5/t5-base"
89
+ self.text_model = T5EncoderModel.from_pretrained(
90
+ self.cfg.t5_type, torch_dtype=torch.bfloat16
91
+ )
92
+ else:
93
+ raise ValueError(
94
+ f"Unknown T5 model: {self.cfg.pretrained_model_name_or_path}"
95
+ )
96
+ self.tokenizer = T5Tokenizer.from_pretrained(self.cfg.t5_type)
97
+
98
+ # Set the empty image/text embeds
99
+ if self.cfg.zero_uncond_embeds:
100
+ self.empty_text_embeds = torch.zeros(
101
+ (1, self.cfg.text_max_length, self.text_model.config.hidden_size)
102
+ ).detach()
103
+ else:
104
+ self.empty_text_embeds = self.encode_text([""]).detach()
105
+
106
+ # load pretrained_model_name_or_path
107
+ if self.cfg.pretrained_model_name_or_path is not None:
108
+ print(f"Loading ckpt from {self.cfg.pretrained_model_name_or_path}")
109
+ ckpt = torch.load(
110
+ self.cfg.pretrained_model_name_or_path, map_location="cpu"
111
+ )["state_dict"]
112
+ pretrained_model_ckpt = {}
113
+ for k, v in ckpt.items():
114
+ if k.startswith("caption_condition."):
115
+ pretrained_model_ckpt[k.replace("caption_condition.", "")] = v
116
+ self.load_state_dict(pretrained_model_ckpt, strict=True)
117
+
118
+ def clean_caption(self, caption):
119
+ caption = str(caption)
120
+ caption = ul.unquote_plus(caption)
121
+ caption = caption.strip().lower()
122
+ caption = re.sub("<person>", "person", caption)
123
+ # urls:
124
+ caption = re.sub(
125
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
126
+ "",
127
+ caption,
128
+ ) # regex for urls
129
+ caption = re.sub(
130
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
131
+ "",
132
+ caption,
133
+ ) # regex for urls
134
+ # html:
135
+ caption = BeautifulSoup(caption, features="html.parser").text
136
+
137
+ # @<nickname>
138
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
139
+
140
+ # 31C0—31EF CJK Strokes
141
+ # 31F0—31FF Katakana Phonetic Extensions
142
+ # 3200—32FF Enclosed CJK Letters and Months
143
+ # 3300—33FF CJK Compatibility
144
+ # 3400—4DBF CJK Unified Ideographs Extension A
145
+ # 4DC0—4DFF Yijing Hexagram Symbols
146
+ # 4E00—9FFF CJK Unified Ideographs
147
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
148
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
149
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
150
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
151
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
152
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
153
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
154
+ #######################################################
155
+
156
+ # все виды тире / all types of dash --> "-"
157
+ caption = re.sub(
158
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
159
+ "-",
160
+ caption,
161
+ )
162
+
163
+ # кавычки к одному стандарту
164
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
165
+ caption = re.sub(r"[‘’]", "'", caption)
166
+
167
+ # &quot;
168
+ caption = re.sub(r"&quot;?", "", caption)
169
+ # &amp
170
+ caption = re.sub(r"&amp", "", caption)
171
+
172
+ # ip adresses:
173
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
174
+
175
+ # article ids:
176
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
177
+
178
+ # \n
179
+ caption = re.sub(r"\\n", " ", caption)
180
+
181
+ # "#123"
182
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
183
+ # "#12345.."
184
+ caption = re.sub(r"#\d{5,}\b", "", caption)
185
+ # "123456.."
186
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
187
+ # filenames:
188
+ caption = re.sub(
189
+ r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
190
+ )
191
+
192
+ #
193
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
194
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
195
+
196
+ caption = re.sub(
197
+ bad_punct_regex, r" ", caption
198
+ ) # ***AUSVERKAUFT***, #AUSVERKAUFT
199
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
200
+
201
+ # this-is-my-cute-cat / this_is_my_cute_cat
202
+ regex2 = re.compile(r"(?:\-|\_)")
203
+ if len(re.findall(regex2, caption)) > 3:
204
+ caption = re.sub(regex2, " ", caption)
205
+
206
+ caption = self.basic_clean(caption)
207
+
208
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
209
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
210
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
211
+
212
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
213
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
214
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
215
+ caption = re.sub(
216
+ r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
217
+ )
218
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
219
+
220
+ caption = re.sub(
221
+ r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
222
+ ) # j2d1a2a...
223
+
224
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
225
+
226
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
227
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
228
+ caption = re.sub(r"\s+", " ", caption)
229
+
230
+ caption.strip()
231
+
232
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
233
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
234
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
235
+ caption = re.sub(r"^\.\S+$", "", caption)
236
+
237
+ return caption.strip()
238
+
239
+ def text_preprocessing(self, text):
240
+ if self.cfg.preprocessing_text:
241
+ # The exact text cleaning as was in the training stage:
242
+ text = self.clean_caption(text)
243
+ return text
244
+ else:
245
+ return text.lower().strip()
246
+
247
+ def encode_text(self, texts: List[str]) -> torch.FloatTensor:
248
+ texts = [self.text_preprocessing(text) for text in texts]
249
+
250
+ text_tokens_and_mask = self.tokenizer(
251
+ texts,
252
+ max_length=self.cfg.text_max_length,
253
+ padding="max_length",
254
+ truncation=True,
255
+ return_attention_mask=True,
256
+ add_special_tokens=True,
257
+ return_tensors="pt",
258
+ )
259
+
260
+ text_tokens_and_mask["input_ids"] = text_tokens_and_mask["input_ids"] # N x 77
261
+ text_tokens_and_mask["attention_mask"] = text_tokens_and_mask["attention_mask"]
262
+
263
+ with torch.no_grad():
264
+ label_embeds = self.text_model(
265
+ input_ids=text_tokens_and_mask["input_ids"].to(self.text_model.device),
266
+ attention_mask=text_tokens_and_mask["attention_mask"].to(
267
+ self.text_model.device
268
+ ),
269
+ )["last_hidden_state"].detach()
270
+
271
+ return label_embeds
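For reference, the tokenize-then-encode pattern that T5Encoder.encode_text wraps, written as a standalone sketch; the `google-t5/t5-base` checkpoint and the 77-token maximum mirror the class defaults and are assumptions here, not values pinned by this commit.

import torch
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base")
text_model = T5EncoderModel.from_pretrained("google-t5/t5-base")

tokens = tokenizer(
    ["a low-poly fox sitting on a rock"],
    max_length=77, padding="max_length", truncation=True,
    return_attention_mask=True, return_tensors="pt",
)
with torch.no_grad():
    embeds = text_model(
        input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"]
    ).last_hidden_state
print(embeds.shape)  # torch.Size([1, 77, 768]) for t5-base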
step1x3d_geometry/models/pipelines/pipeline.py ADDED
@@ -0,0 +1,513 @@
1
+ # Some parts of this file are adapted from the Hugging Face Diffusers library.
2
+ import os
3
+ import json
4
+ import warnings
5
+ from typing import Callable, List, Optional, Union, Dict, Any
6
+ import PIL.Image
7
+ import trimesh
8
+ import rembg
9
+ import torch
10
+ import numpy as np
+ import torchvision.transforms.functional as TF  # needed for TF.to_pil_image below
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
14
+ from diffusers.utils import BaseOutput
15
+ from diffusers.utils.torch_utils import randn_tensor
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
17
+ from diffusers.loaders import (
18
+ FluxIPAdapterMixin,
19
+ FluxLoraLoaderMixin,
20
+ FromSingleFileMixin,
21
+ TextualInversionLoaderMixin,
22
+ )
23
+ from .pipeline_utils import (
24
+ TransformerDiffusionMixin,
25
+ preprocess_image,
26
+ retrieve_timesteps,
27
+ remove_floater,
28
+ remove_degenerate_face,
29
+ reduce_face,
30
+ smart_load_model,
31
+ )
32
+ from transformers import (
33
+ BitImageProcessor,
34
+ )
35
+
36
+ import step1x3d_geometry
37
+ from step1x3d_geometry.models.autoencoders.surface_extractors import MeshExtractResult
38
+ from step1x3d_geometry.utils.config import ExperimentConfig, load_config
39
+ from ..autoencoders.michelangelo_autoencoder import MichelangeloAutoencoder
40
+ from ..conditional_encoders.dinov2_encoder import Dinov2Encoder
41
+ from ..conditional_encoders.t5_encoder import T5Encoder
42
+ from ..conditional_encoders.label_encoder import LabelEncoder
43
+ from ..transformers.flux_transformer_1d import FluxDenoiser
44
+
45
+
46
+ class Step1X3DGeometryPipelineOutput(BaseOutput):
47
+ """
48
+ Output class for the Step1X-3D geometry pipeline.
49
+
50
+ Args:
51
+ images (`List[PIL.Image.Image]` or `torch.Tensor`):
52
+ List of PIL images or a tensor representing the input images.
53
+ meshes (`List[trimesh.Trimesh]` or `np.ndarray`)
54
+ List of denoised trimesh meshes of length `batch_size`, or per-mesh tuples of NumPy arrays with shapes `(num_vertices, 3)` and `(num_faces, 3)`.
55
+ """
56
+
57
+ image: PIL.Image.Image
58
+ mesh: Union[trimesh.Trimesh, MeshExtractResult, np.ndarray]
59
+
60
+
61
+ class Step1X3DGeometryPipeline(
62
+ DiffusionPipeline, FromSingleFileMixin, TransformerDiffusionMixin
63
+ ):
64
+ """
65
+ Step1X-3D Geometry Pipeline, generate high-quality meshes conditioned on image/caption/label inputs
66
+
67
+ Args:
68
+ scheduler (FlowMatchEulerDiscreteScheduler):
69
+ The diffusion scheduler controlling the denoising process
70
+ vae (MichelangeloAutoencoder):
71
+ Variational Autoencoder for latent space compression/reconstruction
72
+ transformer (FluxDenoiser):
73
+ Transformer-based denoising model
74
+ visual_encoder (Dinov2Encoder):
75
+ Pretrained visual encoder for image feature extraction
76
+ caption_encoder (T5Encoder):
77
+ Text encoder for processing natural language captions
78
+ label_encoder (LabelEncoder):
79
+ Auxiliary text encoder for label conditioning
80
+ visual_eature_extractor (BitImageProcessor):
81
+ Preprocessor for input images
82
+
83
+ Note:
84
+ - CPU offloading sequence: visual_encoder → caption_encoder → label_encoder → transformer → vae
85
+ - Optional components: visual_encoder, visual_eature_extractor, caption_encoder, label_encoder
86
+ """
87
+
88
+ model_cpu_offload_seq = (
89
+ "visual_encoder->caption_encoder->label_encoder->transformer->vae"
90
+ )
91
+ _optional_components = [
92
+ "visual_encoder",
93
+ "visual_eature_extractor",
94
+ "caption_encoder",
95
+ "label_encoder",
96
+ ]
97
+
98
+ @classmethod
99
+ def from_pretrained(cls, model_path, subfolder='.', **kwargs):
100
+ local_model_path = smart_load_model(model_path, subfolder)
101
+ return super().from_pretrained(local_model_path, **kwargs)
102
+
103
+ def __init__(
104
+ self,
105
+ scheduler: FlowMatchEulerDiscreteScheduler,
106
+ vae: MichelangeloAutoencoder,
107
+ transformer: FluxDenoiser,
108
+ visual_encoder: Dinov2Encoder,
109
+ caption_encoder: T5Encoder,
110
+ label_encoder: LabelEncoder,
111
+ visual_eature_extractor: BitImageProcessor,
112
+ ):
113
+ super().__init__()
114
+
115
+ self.register_modules(
116
+ vae=vae,
117
+ transformer=transformer,
118
+ scheduler=scheduler,
119
+ visual_encoder=visual_encoder,
120
+ caption_encoder=caption_encoder,
121
+ label_encoder=label_encoder,
122
+ visual_eature_extractor=visual_eature_extractor,
123
+ )
124
+
125
+ @property
126
+ def guidance_scale(self):
127
+ return self._guidance_scale
128
+
129
+ @property
130
+ def do_classifier_free_guidance(self):
131
+ return self._guidance_scale > 1
132
+
133
+ @property
134
+ def num_timesteps(self):
135
+ return self._num_timesteps
136
+
137
+ def check_inputs(
138
+ self,
139
+ image,
140
+ ):
141
+ r"""
142
+ Check if the inputs are valid. Raise an error if not.
143
+ """
144
+ if isinstance(image, str):
145
+ assert os.path.isfile(image) or image.startswith(
146
+ "http"
147
+ ), "Input image must be a valid URL or a file path."
148
+ elif not isinstance(image, (torch.Tensor, PIL.Image.Image)):
149
+ raise ValueError(
150
+ "Input image must be a `torch.Tensor` or `PIL.Image.Image`."
151
+ )
152
+
153
+ def encode_image(self, image, device, num_meshes_per_prompt):
154
+ dtype = next(self.visual_encoder.parameters()).dtype
155
+
156
+ image_embeds = self.visual_encoder.encode_image(image)
157
+ image_embeds = image_embeds.repeat_interleave(num_meshes_per_prompt, dim=0)
158
+
159
+ uncond_image_embeds = self.visual_encoder.empty_image_embeds.repeat(
160
+ image_embeds.shape[0], 1, 1
161
+ ).to(image_embeds)
162
+
163
+ return image_embeds, uncond_image_embeds
164
+
165
+ def encode_caption(self, caption, device, num_meshes_per_prompt):
166
+ dtype = next(self.label_encoder.parameters()).dtype
167
+
168
+ caption_embeds = self.caption_encoder.encode_text([caption])
169
+ caption_embeds = caption_embeds.repeat_interleave(num_meshes_per_prompt, dim=0)
170
+
171
+ uncond_caption_embeds = self.caption_encoder.empty_text_embeds.repeat(
172
+ caption_embeds.shape[0], 1, 1
173
+ ).to(caption_embeds)
174
+
175
+ return caption_embeds, uncond_caption_embeds
176
+
177
+ def encode_label(self, label, device, num_meshes_per_prompt):
178
+ dtype = next(self.label_encoder.parameters()).dtype
179
+
180
+ label_embeds = self.label_encoder.encode_label([label])
181
+ label_embeds = label_embeds.repeat_interleave(num_meshes_per_prompt, dim=0)
182
+
183
+ uncond_label_embeds = self.label_encoder.empty_label_embeds.repeat(
184
+ label_embeds.shape[0], 1, 1
185
+ ).to(label_embeds)
186
+
187
+ return label_embeds, uncond_label_embeds
188
+
189
+ def prepare_latents(
190
+ self,
191
+ batch_size,
192
+ num_tokens,
193
+ num_channels_latents,
194
+ dtype,
195
+ device,
196
+ generator,
197
+ latents: Optional[torch.Tensor] = None,
198
+ ):
199
+ if latents is not None:
200
+ return latents.to(device=device, dtype=dtype)
201
+
202
+ shape = (batch_size, num_tokens, num_channels_latents)
203
+
204
+ if isinstance(generator, list) and len(generator) != batch_size:
205
+ raise ValueError(
206
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
207
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
208
+ )
209
+
210
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
211
+
212
+ return latents
213
+
214
+ @torch.no_grad()
215
+ def __call__(
216
+ self,
217
+ image: Union[torch.FloatTensor, PIL.Image.Image, str],
218
+ label: Optional[str] = None,
219
+ caption: Optional[str] = None,
220
+ num_inference_steps: int = 30,
221
+ timesteps: List[int] = None,
222
+ num_meshes_per_prompt: int = 1,
223
+ guidance_scale: float = 7.5,
224
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
225
+ latents: Optional[torch.FloatTensor] = None,
226
+ force_remove_background: bool = False,
227
+ background_color: List[int] = [255, 255, 255],
228
+ foreground_ratio: float = 0.95,
229
+ surface_extractor_type: Optional[str] = None,
230
+ bounds: float = 1.05,
231
+ mc_level: float = 0.0,
232
+ octree_resolution: int = 384,
233
+ output_type: str = "trimesh",
234
+ do_remove_floater: bool = True,
235
+ do_remove_degenerate_face: bool = False,
236
+ do_reduce_face: bool = True,
237
+ do_shade_smooth: bool = True,
238
+ max_facenum: int = 200000,
239
+ return_dict: bool = True,
240
+ use_zero_init: Optional[bool] = True,
241
+ zero_steps: Optional[int] = 0,
242
+ ):
243
+ r"""
244
+ Function invoked when calling the pipeline for generation.
245
+
246
+ Args:
247
+ image (`torch.FloatTensor` or `PIL.Image.Image` or `str`):
248
+ `Image`, or tensor representing an image batch, or path to an image file. The image will be encoded to
249
+ its CLIP/DINO-v2 embedding which the DiT will be conditioned on.
250
+ label (`str`):
251
+ The label of the generated mesh, like {"symmetry": "asymmetry", "edge_type": "smooth"}
252
+ num_inference_steps (`int`, *optional*, defaults to 30):
253
+ The number of denoising steps. More denoising steps usually lead to a higher quality mesh at the expense
254
+ of slower inference.
255
+ timesteps (`List[int]`, *optional*):
256
+ Custom timesteps to use for the denoising process. If not provided, will use equally spaced timesteps.
257
+ num_meshes_per_prompt (`int`, *optional*, defaults to 1):
258
+ The number of meshes to generate per input image.
259
+ guidance_scale (`float`, *optional*, defaults to 7.5):
260
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
261
+ Higher guidance scale encourages generation that closely matches the input image.
262
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
263
+ A torch generator (or a list of generators) to make the generation deterministic.
264
+ latents (`torch.FloatTensor`, *optional*):
265
+ Pre-generated noisy latents to use as inputs for mesh generation.
266
+ force_remove_background (`bool`, *optional*, defaults to `False`):
267
+ Whether to force remove the background from the input image before processing.
268
+ background_color (`List[int]`, *optional*, defaults to `[255, 255, 255]`):
269
+ RGB color values for the background if it needs to be removed or modified.
270
+ foreground_ratio (`float`, *optional*, defaults to 0.95):
271
+ Ratio of the image to consider as foreground when processing.
272
+ surface_extractor_type (`str`, *optional*, defaults to "mc"):
273
+ Type of surface extraction method to use ("mc" for Marching Cubes or other available methods).
274
+ bounds (`float`, *optional*, defaults to 1.05):
275
+ Bounding box size for the generated mesh.
276
+ mc_level (`float`, *optional*, defaults to 0.0):
277
+ Iso-surface level value for Marching Cubes extraction.
278
+ octree_resolution (`int`, *optional*, defaults to 384):
279
+ Resolution of the octree used for mesh generation.
280
+ output_type (`str`, *optional*, defaults to "trimesh"):
281
+ Type of output mesh format ("trimesh" or other supported formats).
282
+ return_dict (`bool`, *optional*, defaults to `True`):
283
+ Whether or not to return a `Step1X3DGeometryPipelineOutput` instead of a plain tuple.
284
+
285
+ Returns:
286
+ [`Step1X3DGeometryPipelineOutput`] or `tuple`:
287
+ If `return_dict` is `True`, [`Step1X3DGeometryPipelineOutput`] is returned, otherwise a `tuple` is returned where the
288
+ first element is the preprocessed input image and the second element is the list of generated meshes.
289
+ """
290
+ # 0. Check inputs. Raise error if not correct
291
+ self.check_inputs(
292
+ image=image,
293
+ )
294
+ device = self._execution_device
295
+ self._guidance_scale = guidance_scale
296
+
297
+ # 1. Define call parameters
298
+ if isinstance(image, torch.Tensor):
299
+ batch_size = image.shape[0]
300
+ elif isinstance(image, PIL.Image.Image) or isinstance(image, str):
301
+ batch_size = 1
302
+
303
+ # 2. Preprocess input image
304
+ if isinstance(image, torch.Tensor):
305
+ assert image.ndim == 3 # H, W, 3
306
+ image_pil = TF.to_pil_image(image)
307
+ elif isinstance(image, PIL.Image.Image):
308
+ image_pil = image
309
+ elif isinstance(image, str):
310
+ if image.startswith("http"):
311
+ import requests
312
+
313
+ image_pil = PIL.Image.open(requests.get(image, stream=True).raw)
314
+ else:
315
+ image_pil = PIL.Image.open(image)
316
+ image_pil = preprocess_image(image_pil, force=force_remove_background, background_color=background_color, foreground_ratio=foreground_ratio)  # remove the background from the input image
317
+
318
+ # 3. Encode condition
319
+ image_embeds, negative_image_embeds = self.encode_image(
320
+ image_pil, device, num_meshes_per_prompt
321
+ )
322
+ if self.do_classifier_free_guidance and image_embeds is not None:
323
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
324
+ # 3.1 Encode label condition
325
+ label_embeds = None
326
+ if self.transformer.cfg.use_label_condition:
327
+ if label is not None:
328
+ label_embeds, negative_label_embeds = self.encode_label(
329
+ label, device, num_meshes_per_prompt
330
+ )
331
+ if self.do_classifier_free_guidance:
332
+ label_embeds = torch.cat(
333
+ [negative_label_embeds, label_embeds], dim=0
334
+ )
335
+ else:
336
+ uncond_label_embeds = self.label_encoder.empty_label_embeds.repeat(
337
+ num_meshes_per_prompt, 1, 1
338
+ ).to(image_embeds)
339
+ if self.do_classifier_free_guidance:
340
+ label_embeds = torch.cat(
341
+ [uncond_label_embeds, uncond_label_embeds], dim=0
342
+ )
343
+ # 3.3 Encode caption condition
344
+ caption_embeds = None
345
+ if self.transformer.cfg.use_caption_condition:
346
+ if caption is not None:
347
+ caption_embeds, negative_caption_embeds = self.encode_caption(
348
+ caption, device, num_meshes_per_prompt
349
+ )
350
+ if self.do_classifier_free_guidance:
351
+ caption_embeds = torch.cat(
352
+ [negative_caption_embeds, caption_embeds], dim=0
353
+ )
354
+ else:
355
+ uncond_caption_embeds = self.caption_encoder.empty_text_embeds.repeat(
356
+ num_meshes_per_prompt, 1, 1
357
+ ).to(image_embeds)
358
+ if self.do_classifier_free_guidance:
359
+ caption_embeds = torch.cat(
360
+ [uncond_caption_embeds, uncond_caption_embeds], dim=0
361
+ )
362
+
363
+ # 4. Prepare timesteps
364
+ timesteps, num_inference_steps = retrieve_timesteps(
365
+ self.scheduler, num_inference_steps, device, timesteps
366
+ )
367
+ num_warmup_steps = max(
368
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
369
+ )
370
+ self._num_timesteps = len(timesteps)
371
+
372
+ # 5. Prepare latent variables
373
+ num_latents = self.vae.cfg.num_latents
374
+ num_channels_latents = self.transformer.cfg.input_channels
375
+ latents = self.prepare_latents(
376
+ batch_size * num_meshes_per_prompt,
377
+ num_latents,
378
+ num_channels_latents,
379
+ image_embeds.dtype,
380
+ device,
381
+ generator,
382
+ latents,
383
+ )
384
+
385
+ # 6. Denoising loop
386
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
387
+ for i, t in enumerate(timesteps):
388
+ # expand the latents if we are doing classifier free guidance
389
+ latent_model_input = (
390
+ torch.cat([latents] * 2)
391
+ if self.do_classifier_free_guidance
392
+ else latents
393
+ )
394
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
395
+ timestep = t.expand(latent_model_input.shape[0])
396
+
397
+ noise_pred = self.transformer(
398
+ latent_model_input,
399
+ timestep,
400
+ visual_condition=image_embeds,
401
+ label_condition=label_embeds,
402
+ caption_condition=caption_embeds,
403
+ return_dict=False,
404
+ )[0]
405
+
406
+ # perform guidance
407
+ if self.do_classifier_free_guidance:
408
+ noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
409
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
410
+ noise_pred_image - noise_pred_uncond
411
+ )
412
+
413
+ if (i <= zero_steps) and use_zero_init:
414
+ noise_pred = noise_pred * 0.0
415
+
416
+ # compute the previous noisy sample x_t -> x_t-1
417
+ latents_dtype = latents.dtype
418
+ latents = self.scheduler.step(
419
+ noise_pred, t, latents, return_dict=False
420
+ )[0]
421
+
422
+ if latents.dtype != latents_dtype:
423
+ if torch.backends.mps.is_available():
424
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
425
+ latents = latents.to(latents_dtype)
426
+
427
+ if i == len(timesteps) - 1 or (
428
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
429
+ ):
430
+ progress_bar.update()
431
+
432
+ # 4. Post-processing
433
+ if not output_type == "latent":
434
+ if latents.dtype == torch.bfloat16:
435
+ self.vae.to(torch.float16)
436
+ latents = latents.to(torch.float16)
437
+ mesh = self.vae.extract_geometry(
438
+ self.vae.decode(latents),
439
+ surface_extractor_type=surface_extractor_type,
440
+ bounds=bounds,
441
+ mc_level=mc_level,
442
+ octree_resolution=octree_resolution,
443
+ enable_pbar=False,
444
+ )
445
+ if output_type != "raw":
446
+ mesh_list = []
447
+ for i, cur_mesh in enumerate(mesh):
448
+ print(f"Generating mesh {i+1}/{num_meshes_per_prompt}")
449
+ if output_type == "trimesh":
450
+ import trimesh
451
+
452
+ cur_mesh = trimesh.Trimesh(
453
+ vertices=cur_mesh.verts.cpu().numpy(),
454
+ faces=cur_mesh.faces.cpu().numpy(),
455
+ )
456
+ cur_mesh.fix_normals()
457
+ cur_mesh.face_normals
458
+ cur_mesh.vertex_normals
459
+ cur_mesh.visual = trimesh.visual.TextureVisuals(
460
+ material=trimesh.visual.material.PBRMaterial(
461
+ baseColorFactor=(255, 255, 255),
462
+ main_color=(255, 255, 255),
463
+ metallicFactor=0.05,
464
+ roughnessFactor=1.0,
465
+ )
466
+ )
467
+ if do_remove_floater:
468
+ cur_mesh = remove_floater(cur_mesh)
469
+ if do_remove_degenerate_face:
470
+ cur_mesh = remove_degenerate_face(cur_mesh)
471
+ if do_reduce_face and max_facenum > 0:
472
+ cur_mesh = reduce_face(cur_mesh, max_facenum)
473
+ if do_shade_smooth:
474
+ cur_mesh = cur_mesh.smooth_shaded
475
+ mesh_list.append(cur_mesh)
476
+ elif output_type == "np":
477
+ if do_remove_floater:
478
+ print(
479
+ 'remove floater is NOT used when output_type is "np". '
480
+ )
481
+ if do_remove_degenerate_face:
482
+ print(
483
+ 'remove degenerate face is NOT used when output_type is "np". '
484
+ )
485
+ if do_reduce_face:
486
+ print(
487
+ 'reduce face is NOT used when output_type is "np". '
488
+ )
489
+ if do_shade_smooth:
490
+ print('shade smooth is NOT used when output_type is "np". ')
491
+ mesh_list.append(
492
+ [
493
+ cur_mesh[0].verts.cpu().numpy(),
494
+ cur_mesh[0].faces.cpu().numpy(),
495
+ ]
496
+ )
497
+ mesh = mesh_list
498
+ else:
499
+ if do_remove_floater:
500
+ print('remove floater is NOT used when output_type is "raw". ')
501
+ if do_remove_degenerate_face:
502
+ print(
503
+ 'remove degenerate face is NOT used when output_type is "raw". '
504
+ )
505
+ if do_reduce_face:
506
+ print('reduce face is NOT used when output_type is "raw". ')
507
+
508
+ else:
509
+ mesh = latents
510
+
511
+ if not return_dict:
512
+ return tuple(image_pil), tuple(mesh)
513
+ return Step1X3DGeometryPipelineOutput(image=image_pil, mesh=mesh)
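An end-to-end usage sketch for the pipeline defined above. The call signature follows __call__ as written in this file; the model path, image path, and label values are placeholders, not values taken from this commit.

from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline

# placeholders: point these at the actual geometry checkpoint and input image
pipeline = Step1X3DGeometryPipeline.from_pretrained(
    "<model-repo-or-local-path>", subfolder="<geometry-subfolder>"
).to("cuda")

out = pipeline(
    "input_image.png",                                        # path, URL, PIL.Image or tensor
    label={"symmetry": "asymmetry", "edge_type": "smooth"},   # optional, see the docstring above
    guidance_scale=7.5,
    num_inference_steps=30,
    octree_resolution=384,
    max_facenum=200000,
)
out.mesh[0].export("output.glb")  # output_type defaults to "trimesh"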
step1x3d_geometry/models/pipelines/pipeline_utils.py ADDED
@@ -0,0 +1,404 @@
1
+ from typing import Callable, List, Optional, Union, Dict, Any
2
+ import os
+ import inspect  # used by retrieve_timesteps below
3
+ from diffusers.utils import logging
4
+ import PIL.Image
5
+ import torch
6
+ import trimesh
7
+ import pymeshlab
8
+ import tempfile
9
+ from step1x3d_geometry.models.autoencoders.surface_extractors import MeshExtractResult
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ def preprocess_image(
15
+ images_pil: Union[List[PIL.Image.Image], PIL.Image.Image],
16
+ force: bool = False,
17
+ background_color: List[int] = [255, 255, 255],
18
+ foreground_ratio: float = 0.9,
19
+ rembg_backend: str = "bria",
20
+ **rembg_kwargs,
21
+ ):
22
+ r"""
23
+ Crop and remove the background of the input image.
24
+ Args:
25
+ image_pil (`List[PIL.Image.Image]`):
26
+ List of `PIL.Image.Image` objects representing the input image.
27
+ force (`bool`, *optional*, defaults to `False`):
28
+ Whether to force remove the background even if the image has an alpha channel.
29
+ Returns:
30
+ `List[PIL.Image.Image]`: List of `PIL.Image.Image` objects representing the preprocessed image.
31
+ """
32
+ is_single_image = False
33
+ if isinstance(images_pil, PIL.Image.Image):
34
+ images_pil = [images_pil]
35
+ is_single_image = True
36
+ preprocessed_images = []
37
+ for i in range(len(images_pil)):
38
+ image = images_pil[i]
39
+ width, height, size = image.width, image.height, image.size
40
+ do_remove = True
41
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
42
+ # the alpha channel already provides a mask, so skip background removal
43
+ print(
44
+ "alpha channel not empty, skipping background removal and using the alpha channel as the mask"
45
+ )
46
+ do_remove = False
47
+ do_remove = do_remove or force
48
+ if do_remove:
49
+ import rembg # lazy import
50
+
51
+ if rembg_backend == "default":
52
+ image = rembg.remove(image, **rembg_kwargs)
53
+ else:
54
+ image = rembg.remove(
55
+ image,
56
+ session=rembg.new_session(
57
+ model_name="bria",
58
+ providers=[
59
+ (
60
+ "CUDAExecutionProvider",
61
+ {
62
+ "device_id": 0,
63
+ "arena_extend_strategy": "kSameAsRequested",
64
+ "gpu_mem_limit": 6 * 1024 * 1024 * 1024,
65
+ "cudnn_conv_algo_search": "HEURISTIC",
66
+ },
67
+ ),
68
+ "CPUExecutionProvider",
69
+ ],
70
+ ),
71
+ **rembg_kwargs,
72
+ )
73
+
74
+ # calculate the min bbox of the image
75
+ alpha = image.split()[-1]
76
+ bboxs = alpha.getbbox()
77
+ x1, y1, x2, y2 = bboxs
78
+ dy, dx = y2 - y1, x2 - x1
79
+ s = min(height * foreground_ratio / dy, width * foreground_ratio / dx)
80
+ Ht, Wt = int(dy * s), int(dx * s)
81
+
82
+ background = PIL.Image.new("RGBA", image.size, (*background_color, 255))
83
+ image = PIL.Image.alpha_composite(background, image)
84
+ image = image.crop(alpha.getbbox())
85
+ alpha = alpha.crop(alpha.getbbox())
86
+
87
+ # Calculate the new size after rescaling
88
+ new_size = tuple(int(dim * foreground_ratio) for dim in size)
89
+ # Resize the image while maintaining the aspect ratio
90
+ resized_image = image.resize((Wt, Ht))
91
+ resized_alpha = alpha.resize((Wt, Ht))
92
+ # Create a new image with the original size and white background
93
+ padded_image = PIL.Image.new("RGB", size, tuple(background_color))
94
+ padded_alpha = PIL.Image.new("L", size, (0))
95
+ paste_position = (
96
+ (width - resized_image.width) // 2,
97
+ (height - resized_image.height) // 2,
98
+ )
99
+ padded_image.paste(resized_image, paste_position)
100
+ padded_alpha.paste(resized_alpha, paste_position)
101
+
102
+ # expand image to 1:1
103
+ width, height = padded_image.size
104
+ if width == height:
105
+ padded_image.putalpha(padded_alpha)
106
+ preprocessed_images.append(padded_image)
107
+ continue
108
+ new_size = (max(width, height), max(width, height))
109
+ new_image = PIL.Image.new("RGB", new_size, tuple(background_color))
110
+ new_alpha = PIL.Image.new("L", new_size, (0))
111
+ paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
112
+ new_image.paste(padded_image, paste_position)
113
+ new_alpha.paste(padded_alpha, paste_position)
114
+ new_image.putalpha(new_alpha)
115
+ preprocessed_images.append(new_image)
116
+
117
+ if is_single_image:
118
+ return preprocessed_images[0]
119
+ return preprocessed_images
120
+
121
+
122
+ def load_mesh(path):
123
+ if path.endswith(".glb"):
124
+ mesh = trimesh.load(path)
125
+ else:
126
+ mesh = pymeshlab.MeshSet()
127
+ mesh.load_new_mesh(path)
128
+ return mesh
129
+
130
+
131
+ def trimesh2pymeshlab(mesh: trimesh.Trimesh):
132
+ with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
133
+ if isinstance(mesh, trimesh.scene.Scene):
134
+ for idx, obj in enumerate(mesh.geometry.values()):
135
+ if idx == 0:
136
+ temp_mesh = obj
137
+ else:
138
+ temp_mesh = temp_mesh + obj
139
+ mesh = temp_mesh
140
+ mesh.export(temp_file.name)
141
+ mesh = pymeshlab.MeshSet()
142
+ mesh.load_new_mesh(temp_file.name)
143
+ return mesh
144
+
145
+
146
+ def pymeshlab2trimesh(mesh: pymeshlab.MeshSet):
147
+ with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
148
+ mesh.save_current_mesh(temp_file.name)
149
+ mesh = trimesh.load(temp_file.name)
150
+ if isinstance(mesh, trimesh.Scene):
151
+ combined_mesh = trimesh.Trimesh()
152
+ for geom in mesh.geometry.values():
153
+ combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
154
+ mesh = combined_mesh
155
+ return mesh
156
+
157
+
158
+ def import_mesh(mesh):
159
+ mesh_type = type(mesh)
160
+ if isinstance(mesh, str):
161
+ mesh = load_mesh(mesh)
162
+ elif isinstance(mesh, MeshExtractResult):
163
+ mesh = pymeshlab.MeshSet()
164
+ mesh_pymeshlab = pymeshlab.Mesh(
165
+ vertex_matrix=mesh.verts.cpu().numpy(), face_matrix=mesh.faces.cpu().numpy()
166
+ )
167
+ mesh.add_mesh(mesh_pymeshlab, "converted_mesh")
168
+
169
+ if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
170
+ mesh = trimesh2pymeshlab(mesh)
171
+
172
+ return mesh, mesh_type
173
+
174
+
175
+ def remove_floater(mesh):
176
+ mesh, mesh_type = import_mesh(mesh)
177
+
178
+ mesh.apply_filter(
179
+ "compute_selection_by_small_disconnected_components_per_face", nbfaceratio=0.001
180
+ )
181
+ mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False)
182
+ mesh.apply_filter("meshing_remove_selected_vertices_and_faces")
183
+
184
+ return pymeshlab2trimesh(mesh)
185
+
186
+
187
+ def remove_degenerate_face(mesh):
188
+ mesh, mesh_type = import_mesh(mesh)
189
+
190
+ with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
191
+ mesh.save_current_mesh(temp_file.name)
192
+ mesh = pymeshlab.MeshSet()
193
+ mesh.load_new_mesh(temp_file.name)
194
+
195
+ return pymeshlab2trimesh(mesh)
196
+
197
+
198
+ def reduce_face(mesh, max_facenum=50000):
199
+ mesh, mesh_type = import_mesh(mesh)
200
+
201
+ if max_facenum > mesh.current_mesh().face_number():
202
+ return pymeshlab2trimesh(mesh)
203
+
204
+ mesh.apply_filter(
205
+ "meshing_decimation_quadric_edge_collapse",
206
+ targetfacenum=max_facenum,
207
+ qualitythr=1.0,
208
+ preserveboundary=True,
209
+ boundaryweight=3,
210
+ preservenormal=True,
211
+ preservetopology=True,
212
+ autoclean=True,
213
+ )
214
+
215
+ return pymeshlab2trimesh(mesh)
216
+
217
+
218
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
219
+ def retrieve_timesteps(
220
+ scheduler,
221
+ num_inference_steps: Optional[int] = None,
222
+ device: Optional[Union[str, torch.device]] = None,
223
+ timesteps: Optional[List[int]] = None,
224
+ sigmas: Optional[List[float]] = None,
225
+ **kwargs,
226
+ ):
227
+ r"""
228
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
229
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
230
+
231
+ Args:
232
+ scheduler (`SchedulerMixin`):
233
+ The scheduler to get timesteps from.
234
+ num_inference_steps (`int`):
235
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
236
+ must be `None`.
237
+ device (`str` or `torch.device`, *optional*):
238
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
239
+ timesteps (`List[int]`, *optional*):
240
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
241
+ `num_inference_steps` and `sigmas` must be `None`.
242
+ sigmas (`List[float]`, *optional*):
243
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
244
+ `num_inference_steps` and `timesteps` must be `None`.
245
+
246
+ Returns:
247
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
248
+ second element is the number of inference steps.
249
+ """
250
+ if timesteps is not None and sigmas is not None:
251
+ raise ValueError(
252
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
253
+ )
254
+ if timesteps is not None:
255
+ accepts_timesteps = "timesteps" in set(
256
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
257
+ )
258
+ if not accepts_timesteps:
259
+ raise ValueError(
260
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
261
+ f" timestep schedules. Please check whether you are using the correct scheduler."
262
+ )
263
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
264
+ timesteps = scheduler.timesteps
265
+ num_inference_steps = len(timesteps)
266
+ elif sigmas is not None:
267
+ accept_sigmas = "sigmas" in set(
268
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
269
+ )
270
+ if not accept_sigmas:
271
+ raise ValueError(
272
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
273
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
274
+ )
275
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
276
+ timesteps = scheduler.timesteps
277
+ num_inference_steps = len(timesteps)
278
+ else:
279
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
280
+ timesteps = scheduler.timesteps
281
+ return timesteps, num_inference_steps
282
+
283
+
284
+ class TransformerDiffusionMixin:
285
+ r"""
286
+ Helper for DiffusionPipeline with vae and transformer.(mainly for DIT)
287
+ """
288
+
289
+ def enable_vae_slicing(self):
290
+ r"""
291
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
292
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
293
+ """
294
+ self.vae.enable_slicing()
295
+
296
+ def disable_vae_slicing(self):
297
+ r"""
298
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
299
+ computing decoding in one step.
300
+ """
301
+ self.vae.disable_slicing()
302
+
303
+ def enable_vae_tiling(self):
304
+ r"""
305
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
306
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
307
+ processing larger images.
308
+ """
309
+ self.vae.enable_tiling()
310
+
311
+ def disable_vae_tiling(self):
312
+ r"""
313
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
314
+ computing decoding in one step.
315
+ """
316
+ self.vae.disable_tiling()
317
+
318
+ def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
319
+ """
320
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
321
+ are fused. For cross-attention modules, key and value projection matrices are fused.
322
+
323
+ <Tip warning={true}>
324
+
325
+ This API is 🧪 experimental.
326
+
327
+ </Tip>
328
+
329
+ Args:
330
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
331
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
332
+ """
333
+ self.fusing_transformer = False
334
+ self.fusing_vae = False
335
+
336
+ if transformer:
337
+ self.fusing_transformer = True
338
+ self.transformer.fuse_qkv_projections()
339
+
340
+ if vae:
341
+ self.fusing_vae = True
342
+ self.vae.fuse_qkv_projections()
343
+
344
+ def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
345
+ """Disable QKV projection fusion if enabled.
346
+
347
+ <Tip warning={true}>
348
+
349
+ This API is 🧪 experimental.
350
+
351
+ </Tip>
352
+
353
+ Args:
354
+ transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
355
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
356
+
357
+ """
358
+ if transformer:
359
+ if not self.fusing_transformer:
360
+ logger.warning(
361
+ "The UNet was not initially fused for QKV projections. Doing nothing."
362
+ )
363
+ else:
364
+ self.transformer.unfuse_qkv_projections()
365
+ self.fusing_transformer = False
366
+
367
+ if vae:
368
+ if not self.fusing_vae:
369
+ logger.warning(
370
+ "The VAE was not initially fused for QKV projections. Doing nothing."
371
+ )
372
+ else:
373
+ self.vae.unfuse_qkv_projections()
374
+ self.fusing_vae = False
375
+
376
+ def try_download(model_id, subfolder):
377
+ try:
378
+ from huggingface_hub import snapshot_download
379
+
380
+ path = snapshot_download(
381
+ repo_id=model_id,
382
+ allow_patterns=[f"{subfolder}/*"],
383
+ )
384
+ print(path)
385
+ model_path = os.path.join(path, subfolder)
386
+ return model_path
387
+ except Exception as e:
388
+ raise e
389
+
390
+
391
+ def smart_load_model(model_path, subfolder = ""):
392
+ if subfolder == "":
393
+ if os.path.exists(model_path):
394
+ return model_path
395
+ else:
396
+ return try_download(model_path, '.')
397
+ else:
398
+ if os.path.exists(os.path.join(model_path, subfolder)):
399
+ return os.path.join(model_path, subfolder)
400
+ else:
401
+ return try_download(model_path, subfolder)
402
+
403
+
404
+
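A short sketch chaining the image and mesh helpers defined above; the file names are placeholders.

import PIL.Image
from step1x3d_geometry.models.pipelines.pipeline_utils import (
    preprocess_image, reduce_face, remove_degenerate_face, remove_floater,
)

image = preprocess_image(PIL.Image.open("input.png"), foreground_ratio=0.95)  # recenter on a flat background
mesh = remove_floater("raw_mesh.glb")          # drop small disconnected components
mesh = remove_degenerate_face(mesh)            # round-trip through pymeshlab to drop bad faces
mesh = reduce_face(mesh, max_facenum=200000)   # quadric edge-collapse decimation
mesh.export("clean_mesh.glb")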
step1x3d_geometry/models/transformers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import flux_transformer_1d, pixart_transformer_1d
step1x3d_geometry/models/transformers/flux_transformer_1d.py ADDED
@@ -0,0 +1,600 @@
1
+ # Some parts of this file are adapted from Hugging Face Diffusers library.
2
+ from typing import Any, Dict, Optional, Union, Tuple
3
+ from dataclasses import dataclass
4
+
5
+ import re
6
+ import torch
7
+ from torch import nn
8
+
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor,
14
+ AttnProcessor,
15
+ )
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.models.embeddings import (
18
+ GaussianFourierProjection,
19
+ TimestepEmbedding,
20
+ Timesteps,
21
+ )
22
+ from diffusers.utils import (
23
+ USE_PEFT_BACKEND,
24
+ is_torch_version,
25
+ logging,
26
+ scale_lora_layers,
27
+ unscale_lora_layers,
28
+ )
29
+ from diffusers.models.normalization import (
30
+ AdaLayerNormSingle,
31
+ AdaLayerNormContinuous,
32
+ FP32LayerNorm,
33
+ LayerNorm,
34
+ )
35
+
36
+ from ..attention_processor import FusedFluxAttnProcessor2_0, FluxAttnProcessor2_0
37
+ from ..attention import FluxTransformerBlock, FluxSingleTransformerBlock
38
+
39
+ import step1x3d_geometry
40
+ from step1x3d_geometry.utils.base import BaseModule
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+
45
+ @dataclass
46
+ class Transformer1DModelOutput:
47
+ sample: torch.FloatTensor
48
+
49
+
50
+ class FluxTransformer1DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
51
+ r"""
52
+ The Transformer model introduced in Flux.
53
+
54
+ Reference: https://blackforestlabs.ai/announcing-black-forest-la
55
+
56
+ Parameters:
57
+ num_attention_heads (`int`, *optional*, defaults to 16):
58
+ The number of heads to use for multi-head attention.
59
+ width (`int`, *optional*, defaults to 2048):
60
+ The inner (hidden) dimension of the transformer blocks; the attention head
61
+ dimension is `width // num_attention_heads`.
62
+ in_channels (`int`, *optional*, defaults to 4):
63
+ The number of channels in the input and output (specify if the input is **continuous**).
64
+ num_layers (`int`, *optional*, defaults to 19):
65
+ The number of layers of Transformer blocks to use.
66
+ cross_attention_dim (`int`, *optional*):
67
+ Dimensionality of conditional embeddings for cross-attention mechanisms
68
+ """
69
+
70
+ _supports_gradient_checkpointing = True
71
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
72
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ num_attention_heads: int = 16,
78
+ width: int = 2048,
79
+ in_channels: int = 4,
80
+ num_layers: int = 19,
81
+ num_single_layers: int = 38,
82
+ cross_attention_dim: int = 768,
83
+ ):
84
+ super().__init__()
85
+ # Set some common variables used across the board.
86
+ self.out_channels = in_channels
87
+ self.num_heads = num_attention_heads
88
+ self.inner_dim = width
89
+
90
+ # self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
91
+ # self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
92
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
93
+ "positional",
94
+ inner_dim=self.inner_dim,
95
+ flip_sin_to_cos=False,
96
+ freq_shift=0,
97
+ time_embedding_dim=None,
98
+ )
99
+ self.time_proj = TimestepEmbedding(
100
+ timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
101
+ )
102
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
103
+ self.proj_cross_attention = nn.Linear(
104
+ self.config.cross_attention_dim, self.inner_dim, bias=True
105
+ )
106
+
107
+ # 2. Initialize the transformer blocks.
108
+ self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ FluxTransformerBlock(
111
+ dim=self.inner_dim,
112
+ num_attention_heads=num_attention_heads,
113
+ attention_head_dim=width // num_attention_heads,
114
+ )
115
+ for _ in range(self.config.num_layers)
116
+ ]
117
+ )
118
+ self.single_transformer_blocks = nn.ModuleList(
119
+ [
120
+ FluxSingleTransformerBlock(
121
+ dim=self.inner_dim,
122
+ num_attention_heads=num_attention_heads,
123
+ attention_head_dim=width // num_attention_heads,
124
+ )
125
+ for _ in range(self.config.num_single_layers)
126
+ ]
127
+ )
128
+
129
+ # 3. Output blocks.
130
+ self.norm_out = AdaLayerNormContinuous(
131
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
132
+ )
133
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
134
+
135
+ self.gradient_checkpointing = False
136
+
137
+ def _set_time_proj(
138
+ self,
139
+ time_embedding_type: str,
140
+ inner_dim: int,
141
+ flip_sin_to_cos: bool,
142
+ freq_shift: float,
143
+ time_embedding_dim: int,
144
+ ) -> Tuple[int, int]:
145
+ if time_embedding_type == "fourier":
146
+ time_embed_dim = time_embedding_dim or inner_dim * 2
147
+ if time_embed_dim % 2 != 0:
148
+ raise ValueError(
149
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
150
+ )
151
+ self.time_embed = GaussianFourierProjection(
152
+ time_embed_dim // 2,
153
+ set_W_to_weight=False,
154
+ log=False,
155
+ flip_sin_to_cos=flip_sin_to_cos,
156
+ )
157
+ timestep_input_dim = time_embed_dim
158
+ elif time_embedding_type == "positional":
159
+ time_embed_dim = time_embedding_dim or inner_dim * 4
160
+
161
+ self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
162
+ timestep_input_dim = inner_dim
163
+ else:
164
+ raise ValueError(
165
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
166
+ )
167
+
168
+ return time_embed_dim, timestep_input_dim
169
+
170
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
171
+ def fuse_qkv_projections(self):
172
+ """
173
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
174
+ are fused. For cross-attention modules, key and value projection matrices are fused.
175
+
176
+ <Tip warning={true}>
177
+
178
+ This API is 🧪 experimental.
179
+
180
+ </Tip>
181
+ """
182
+ self.original_attn_processors = None
183
+
184
+ for _, attn_processor in self.attn_processors.items():
185
+ if "Added" in str(attn_processor.__class__.__name__):
186
+ raise ValueError(
187
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
188
+ )
189
+
190
+ self.original_attn_processors = self.attn_processors
191
+
192
+ for module in self.modules():
193
+ if isinstance(module, Attention):
194
+ module.fuse_projections(fuse=True)
195
+
196
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
197
+
198
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
199
+ def unfuse_qkv_projections(self):
200
+ """Disables the fused QKV projection if enabled.
201
+
202
+ <Tip warning={true}>
203
+
204
+ This API is 🧪 experimental.
205
+
206
+ </Tip>
207
+
208
+ """
209
+ if self.original_attn_processors is not None:
210
+ self.set_attn_processor(self.original_attn_processors)
211
+
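+ # Hypothetical usage sketch for the pair of methods above (assumes `model` is an
+ # instantiated `FluxTransformer1DModel`):
+ #   model.fuse_qkv_projections()     # fuse Q/K/V weights before inference
+ #   ...run the sampling loop...
+ #   model.unfuse_qkv_projections()   # restore the original attention processors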
212
+ @property
213
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
214
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
215
+ r"""
216
+ Returns:
217
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
218
+ indexed by their weight names.
219
+ """
220
+ # set recursively
221
+ processors = {}
222
+
223
+ def fn_recursive_add_processors(
224
+ name: str,
225
+ module: torch.nn.Module,
226
+ processors: Dict[str, AttentionProcessor],
227
+ ):
228
+ if hasattr(module, "get_processor"):
229
+ processors[f"{name}.processor"] = module.get_processor()
230
+
231
+ for sub_name, child in module.named_children():
232
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
233
+
234
+ return processors
235
+
236
+ for name, module in self.named_children():
237
+ fn_recursive_add_processors(name, module, processors)
238
+
239
+ return processors
240
+
241
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
242
+ def set_attn_processor(
243
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
244
+ ):
245
+ r"""
246
+ Sets the attention processor to use to compute attention.
247
+
248
+ Parameters:
249
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
250
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
251
+ for **all** `Attention` layers.
252
+
253
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
254
+ processor. This is strongly recommended when setting trainable attention processors.
255
+
256
+ """
257
+ count = len(self.attn_processors.keys())
258
+
259
+ if isinstance(processor, dict) and len(processor) != count:
260
+ raise ValueError(
261
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
262
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
263
+ )
264
+
265
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
266
+ if hasattr(module, "set_processor"):
267
+ if not isinstance(processor, dict):
268
+ module.set_processor(processor)
269
+ else:
270
+ module.set_processor(processor.pop(f"{name}.processor"))
271
+
272
+ for sub_name, child in module.named_children():
273
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
274
+
275
+ for name, module in self.named_children():
276
+ fn_recursive_attn_processor(name, module, processor)
277
+
278
+ def set_default_attn_processor(self):
279
+ """
280
+ Disables custom attention processors and sets the default attention implementation.
281
+ """
282
+ self.set_attn_processor(FluxAttnProcessor2_0())
283
+
284
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
285
+ def enable_forward_chunking(
286
+ self, chunk_size: Optional[int] = None, dim: int = 0
287
+ ) -> None:
288
+ """
289
+ Sets the attention processor to use [feed forward
290
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
291
+
292
+ Parameters:
293
+ chunk_size (`int`, *optional*):
294
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
295
+ over each tensor of dim=`dim`.
296
+ dim (`int`, *optional*, defaults to `0`):
297
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
298
+ or dim=1 (sequence length).
299
+ """
300
+ if dim not in [0, 1]:
301
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
302
+
303
+ # By default chunk size is 1
304
+ chunk_size = chunk_size or 1
305
+
306
+ def fn_recursive_feed_forward(
307
+ module: torch.nn.Module, chunk_size: int, dim: int
308
+ ):
309
+ if hasattr(module, "set_chunk_feed_forward"):
310
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
311
+
312
+ for child in module.children():
313
+ fn_recursive_feed_forward(child, chunk_size, dim)
314
+
315
+ for module in self.children():
316
+ fn_recursive_feed_forward(module, chunk_size, dim)
317
+
318
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
319
+ def disable_forward_chunking(self):
320
+ def fn_recursive_feed_forward(
321
+ module: torch.nn.Module, chunk_size: int, dim: int
322
+ ):
323
+ if hasattr(module, "set_chunk_feed_forward"):
324
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
325
+
326
+ for child in module.children():
327
+ fn_recursive_feed_forward(child, chunk_size, dim)
328
+
329
+ for module in self.children():
330
+ fn_recursive_feed_forward(module, None, 0)
331
+
332
+ def forward(
333
+ self,
334
+ hidden_states: Optional[torch.Tensor],
335
+ timestep: Union[int, float, torch.LongTensor],
336
+ encoder_hidden_states: Optional[torch.Tensor] = None,
337
+ attention_kwargs: Optional[Dict[str, Any]] = None,
338
+ return_dict: bool = True,
339
+ ):
340
+ """
341
+ The [`FluxTransformer1DModel`] forward method.
342
+
343
+ Args:
344
+ hidden_states (`torch.Tensor` of shape `(batch size, num_latents, channels)`):
345
+ The input latent tensor.
346
+ timestep (`torch.LongTensor`, *optional*):
347
+ Used to indicate the denoising step.
348
+ encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
349
+ Conditional embeddings for the cross-attention layers.
350
+ attention_kwargs (`Dict[str, Any]`, *optional*):
351
+ Additional kwargs forwarded to the attention processors (e.g. a LoRA `scale`).
352
+ return_dict (`bool`, *optional*, defaults to `True`):
353
+ Whether to return a [`Transformer1DModelOutput`] instead of a plain tuple.
354
+ """
355
+
356
+ if attention_kwargs is not None:
357
+ attention_kwargs = attention_kwargs.copy()
358
+ lora_scale = attention_kwargs.pop("scale", 1.0)
359
+ else:
360
+ lora_scale = 1.0
361
+
362
+ if USE_PEFT_BACKEND:
363
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
364
+ scale_lora_layers(self, lora_scale)
365
+ else:
366
+ if (
367
+ attention_kwargs is not None
368
+ and attention_kwargs.get("scale", None) is not None
369
+ ):
370
+ logger.warning(
371
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
372
+ )
373
+
374
+ _, N, _ = hidden_states.shape
375
+
379
+ temb = self.time_embed(timestep).to(hidden_states.dtype)  # sinusoidal embedding, (N, inner_dim)
380
+ temb = self.time_proj(temb)  # timestep MLP, (N, inner_dim)
381
+
382
+ hidden_states = self.proj_in(hidden_states)
383
+ encoder_hidden_states = self.proj_cross_attention(encoder_hidden_states)
384
+
385
+ for layer, block in enumerate(self.transformer_blocks):
386
+ if self.training and self.gradient_checkpointing:
387
+
388
+ def create_custom_forward(module):
389
+ def custom_forward(*inputs):
390
+ return module(*inputs)
391
+
392
+ return custom_forward
393
+
394
+ ckpt_kwargs: Dict[str, Any] = (
395
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
396
+ )
397
+ encoder_hidden_states, hidden_states = (
398
+ torch.utils.checkpoint.checkpoint(
399
+ create_custom_forward(block),
400
+ hidden_states,
401
+ encoder_hidden_states,
402
+ temb,
403
+ None, # image_rotary_emb
404
+ attention_kwargs,
405
+ **ckpt_kwargs,
+ )
406
+ )
407
+ else:
408
+ encoder_hidden_states, hidden_states = block(
409
+ hidden_states,
410
+ encoder_hidden_states=encoder_hidden_states,
411
+ temb=temb,
412
+ image_rotary_emb=None,
413
+ joint_attention_kwargs=attention_kwargs,
414
+ ) # (N, L, D)
415
+
416
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
417
+
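+ # From this point the model runs single-stream: condition tokens and latent
+ # tokens are processed as one concatenated sequence, and the condition part is
+ # sliced off again right after the single-stream blocks below.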
418
+ for layer, block in enumerate(self.single_transformer_blocks):
419
+ if self.training and self.gradient_checkpointing:
420
+
421
+ def create_custom_forward(module):
422
+ def custom_forward(*inputs):
423
+ return module(*inputs)
424
+
425
+ return custom_forward
426
+
427
+ ckpt_kwargs: Dict[str, Any] = (
428
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
429
+ )
430
+ hidden_states = torch.utils.checkpoint.checkpoint(
431
+ create_custom_forward(block),
432
+ hidden_states,
433
+ temb,
434
+ None, # image_rotary_emb
435
+ attention_kwargs,
436
+ **ckpt_kwargs,
+ )
437
+ else:
438
+ hidden_states = block(
439
+ hidden_states,
440
+ temb=temb,
441
+ image_rotary_emb=None,
442
+ joint_attention_kwargs=attention_kwargs,
443
+ ) # (N, L, D)
444
+
445
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
446
+
447
+ # final layer
448
+ hidden_states = self.norm_out(hidden_states, temb)
449
+ hidden_states = self.proj_out(hidden_states)
450
+
451
+ if USE_PEFT_BACKEND:
452
+ # remove `lora_scale` from each PEFT layer
453
+ unscale_lora_layers(self, lora_scale)
454
+
455
+ if not return_dict:
456
+ return (hidden_states,)
457
+
458
+ return Transformer1DModelOutput(sample=hidden_states)
459
+
460
+
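+ # Minimal usage sketch for the transformer above (latent/condition shapes are
+ # illustrative assumptions, not values from a training config):
+ #   model = FluxTransformer1DModel(num_attention_heads=16, width=2048, in_channels=4,
+ #                                  num_layers=19, num_single_layers=38, cross_attention_dim=768)
+ #   x = torch.randn(2, 512, 4)        # (batch, num_latents, in_channels)
+ #   cond = torch.randn(2, 257, 768)   # (batch, cond_tokens, cross_attention_dim)
+ #   t = torch.randint(0, 1000, (2,))
+ #   out = model(x, t, encoder_hidden_states=cond).sample   # (2, 512, 4)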
461
+ @step1x3d_geometry.register("flux-denoiser")
462
+ class FluxDenoiser(BaseModule):
463
+ @dataclass
464
+ class Config(BaseModule.Config):
465
+ pretrained_model_name_or_path: Optional[str] = None
466
+ input_channels: int = 32
467
+ width: int = 768
468
+ layers: int = 12
469
+ num_single_layers: int = 12
470
+ num_heads: int = 16
471
+ condition_dim: int = 1024
472
+ multi_condition_type: str = "in_context"
473
+ use_visual_condition: bool = False
474
+ visual_condition_dim: int = 1024
475
+ n_views: int = 1
476
+ use_caption_condition: bool = False
477
+ caption_condition_dim: int = 1024
478
+ use_label_condition: bool = False
479
+ label_condition_dim: int = 1024
480
+
481
+ identity_init: bool = False
482
+
483
+ cfg: Config
484
+
485
+ def configure(self) -> None:
486
+ assert (
487
+ self.cfg.multi_condition_type == "in_context"
488
+ ), "FluxDenoiser only supports the 'in_context' combination of multiple conditions"
489
+ self.dit_model = FluxTransformer1DModel(
490
+ num_attention_heads=self.cfg.num_heads,
491
+ width=self.cfg.width,
492
+ in_channels=self.cfg.input_channels,
493
+ num_layers=self.cfg.layers,
494
+ num_single_layers=self.cfg.num_single_layers,
495
+ cross_attention_dim=self.cfg.condition_dim,
496
+ )
497
+ if (
498
+ self.cfg.use_visual_condition
499
+ and self.cfg.visual_condition_dim != self.cfg.condition_dim
500
+ ):
501
+ self.proj_visual_condtion = nn.Sequential(
502
+ nn.RMSNorm(self.cfg.visual_condition_dim),
503
+ nn.Linear(self.cfg.visual_condition_dim, self.cfg.condition_dim),
504
+ )
505
+ if (
506
+ self.cfg.use_caption_condition
507
+ and self.cfg.caption_condition_dim != self.cfg.condition_dim
508
+ ):
509
+ self.proj_caption_condtion = nn.Sequential(
510
+ nn.RMSNorm(self.cfg.caption_condition_dim),
511
+ nn.Linear(self.cfg.caption_condition_dim, self.cfg.condition_dim),
512
+ )
513
+ if (
514
+ self.cfg.use_label_condition
515
+ and self.cfg.label_condition_dim != self.cfg.condition_dim
516
+ ):
517
+ self.proj_label_condtion = nn.Sequential(
518
+ nn.RMSNorm(self.cfg.label_condition_dim),
519
+ nn.Linear(self.cfg.label_condition_dim, self.cfg.condition_dim),
520
+ )
521
+
522
+ if self.cfg.identity_init:
523
+ self.identity_initialize()
524
+
525
+ if self.cfg.pretrained_model_name_or_path:
526
+ print(
527
+ f"Loading pretrained DiT model from {self.cfg.pretrained_model_name_or_path}"
528
+ )
529
+ ckpt = torch.load(
530
+ self.cfg.pretrained_model_name_or_path,
531
+ map_location="cpu",
532
+ weights_only=True,
533
+ )
534
+ if "state_dict" in ckpt.keys():
535
+ ckpt = ckpt["state_dict"]
536
+
537
+ self.load_state_dict(ckpt, strict=True)
538
+
539
+ def identity_initialize(self):
540
+ for block in self.dit_model.blocks:
541
+ nn.init.constant_(block.attn.c_proj.weight, 0)
542
+ nn.init.constant_(block.attn.c_proj.bias, 0)
543
+ nn.init.constant_(block.cross_attn.c_proj.weight, 0)
544
+ nn.init.constant_(block.cross_attn.c_proj.bias, 0)
545
+ nn.init.constant_(block.mlp.c_proj.weight, 0)
546
+ nn.init.constant_(block.mlp.c_proj.bias, 0)
547
+
548
+ def forward(
549
+ self,
550
+ model_input: torch.FloatTensor,
551
+ timestep: torch.LongTensor,
552
+ visual_condition: Optional[torch.FloatTensor] = None,
553
+ caption_condition: Optional[torch.FloatTensor] = None,
554
+ label_condition: Optional[torch.FloatTensor] = None,
555
+ attention_kwargs: Dict[str, torch.Tensor] = None,
556
+ return_dict: bool = True,
557
+ ):
558
+ r"""
559
+ Args:
560
+ model_input (torch.FloatTensor): [bs, n_data, c]
561
+ timestep (torch.LongTensor): [bs,]
562
+ visual_condition (torch.FloatTensor): [bs, visual_context_tokens, c]
563
+ caption_condition (torch.FloatTensor): [bs, text_context_tokens, c]
564
+ label_condition (torch.FloatTensor): [bs, c]
565
+
566
+ Returns:
567
+ sample (torch.FloatTensor): [bs, n_data, c]
568
+
569
+ """
570
+
571
+ B, n_data, _ = model_input.shape
572
+
573
+ # 0. conditions projector
574
+ condition = []
575
+ if self.cfg.use_visual_condition:
576
+ assert visual_condition.shape[-1] == self.cfg.visual_condition_dim
577
+ if self.cfg.visual_condition_dim != self.cfg.condition_dim:
578
+ visual_condition = self.proj_visual_condtion(visual_condition)
579
+ condition.append(visual_condition)
580
+ if self.cfg.use_caption_condition:
581
+ assert caption_condition.shape[-1] == self.cfg.caption_condition_dim
582
+ if self.cfg.caption_condition_dim != self.cfg.condition_dim:
583
+ caption_condition = self.proj_caption_condtion(caption_condition)
584
+ condition.append(caption_condition)
585
+ if self.cfg.use_label_condition:
586
+ assert label_condition.shape[-1] == self.cfg.label_condition_dim
587
+ if self.cfg.label_condition_dim != self.cfg.condition_dim:
588
+ label_condition = self.proj_label_condtion(label_condition)
589
+ condition.append(label_condition)
590
+
591
+ # 1. denoise
592
+ output = self.dit_model(
593
+ model_input,
594
+ timestep,
595
+ torch.cat(condition, dim=1),
596
+ attention_kwargs,
597
+ return_dict=return_dict,
598
+ )
599
+
600
+ return output
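+ # Rough sketch of how this denoiser is typically driven from a diffusion system
+ # (config values and tensor shapes are assumptions):
+ #   denoiser = step1x3d_geometry.find("flux-denoiser")(cfg.denoiser_model)
+ #   pred = denoiser(noisy_latents,             # (bs, n_data, input_channels)
+ #                   timesteps,                 # (bs,)
+ #                   visual_condition=visual,   # (bs, visual_tokens, visual_condition_dim)
+ #                   ).sample                   # (bs, n_data, input_channels)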
step1x3d_geometry/models/transformers/pixart_transformer_1d.py ADDED
@@ -0,0 +1,574 @@
1
+ # Some parts of this file are adapted from Hugging Face Diffusers library.
2
+ from dataclasses import dataclass
3
+
4
+ import re
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ from typing import Callable, List, Optional, Union, Dict, Any
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.utils import logging
11
+ from diffusers.models.attention_processor import (
12
+ Attention,
13
+ AttentionProcessor,
14
+ AttnProcessor,
15
+ )
16
+ from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+ from diffusers.models.normalization import AdaLayerNormSingle
19
+
20
+ from ..attention_processor import FusedAttnProcessor2_0, AttnProcessor2_0
21
+ from ..attention import MultiCondBasicTransformerBlock
22
+
23
+ import step1x3d_geometry
24
+ from step1x3d_geometry.utils.base import BaseModule
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ @dataclass
30
+ class Transformer1DModelOutput:
31
+ sample: torch.FloatTensor
32
+
33
+
34
+ class PixArtTransformer1DModel(ModelMixin, ConfigMixin):
35
+ r"""
36
+ A 1D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
37
+ https://arxiv.org/abs/2403.04692).
38
+
39
+ Parameters:
40
+ num_attention_heads (`int`, *optional*, defaults to 16):
41
+ The number of heads to use for multi-head attention.
42
+ width (`int`, *optional*, defaults to 2048):
43
+ Maximum sequence length in latent space (equivalent to max_seq_length in Transformers).
44
+ Determines the first dimension size of positional embedding matrices[1](@ref).
45
+ in_channels (`int`, *optional*, defaults to 64):
46
+ The number of channels in the input and output (specify if the input is **continuous**).
47
+ num_layers (`int`, *optional*, defaults to 1):
48
+ The number of layers of Transformer blocks to use.
49
+ cross_attention_dim (`int`, *optional*):
50
+ Dimensionality of conditional embeddings for cross-attention mechanisms
51
+ use_cross_attention_2 (`bool`, *optional*):
52
+ Flag to enable secondary cross-attention mechanism. Used for multi-modal conditioning
53
+ when processing hybrid inputs (e.g., text + image prompts)[1](@ref).
54
+ cross_attention_2_dim (`int`, *optional*, defaults to 1024):
55
+ Dimensionality of secondary cross-attention embeddings. Specifies encoding dimensions
56
+ for additional conditional modalities when use_cross_attention_2 is enabled[1](@ref).
57
+ """
58
+
59
+ _supports_gradient_checkpointing = True
60
+ _no_split_modules = ["MultiCondBasicTransformerBlock", "PatchEmbed"]
61
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ num_attention_heads: int = 16,
67
+ width: int = 2048,
68
+ in_channels: int = 4,
69
+ num_layers: int = 28,
70
+ cross_attention_dim: int = 768,
71
+ use_cross_attention_2: bool = True,
72
+ cross_attention_2_dim: int = 1024,
73
+ use_cross_attention_3: bool = True,
74
+ cross_attention_3_dim: int = 1024,
75
+ ):
76
+ super().__init__()
77
+ # Set some common variables used across the board.
78
+ self.out_channels = in_channels
79
+ self.num_heads = num_attention_heads
80
+ self.inner_dim = width
81
+
82
+ self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
83
+
84
+ # 2. Initialize the transformer blocks.
85
+ self.transformer_blocks = nn.ModuleList(
86
+ [
87
+ MultiCondBasicTransformerBlock(
88
+ self.inner_dim,
89
+ self.config.num_attention_heads,
90
+ use_self_attention=True,
91
+ use_cross_attention=True,
92
+ self_attention_norm_type="ada_norm_single",
93
+ cross_attention_dim=self.config.cross_attention_dim,
94
+ cross_attention_norm_type="ada_norm_single",
95
+ use_cross_attention_2=self.config.use_cross_attention_2,
96
+ cross_attention_2_dim=self.config.cross_attention_2_dim,
97
+ cross_attention_2_norm_type="ada_norm_single",
98
+ use_cross_attention_3=self.config.use_cross_attention_3,
99
+ cross_attention_3_dim=self.config.cross_attention_3_dim,
100
+ cross_attention_3_norm_type="ada_norm_single",
101
+ dropout=0.0,
102
+ attention_bias=False,
103
+ activation_fn="gelu-approximate",
104
+ num_embeds_ada_norm=1000,
105
+ norm_elementwise_affine=True,
106
+ upcast_attention=False,
107
+ norm_eps=1e-6,
108
+ attention_type="default",
109
+ )
110
+ for _ in range(self.config.num_layers)
111
+ ]
112
+ )
113
+
114
+ # 3. Output blocks.
115
+ self.norm_out = nn.RMSNorm(self.inner_dim, elementwise_affine=True, eps=1e-6)
116
+ self.scale_shift_table = nn.Parameter(
117
+ torch.randn(2, self.inner_dim) / self.inner_dim**0.5
118
+ )
119
+ self.proj_out = nn.Linear(self.inner_dim, self.out_channels)
120
+
121
+ self.adaln_single = AdaLayerNormSingle(
122
+ self.inner_dim, use_additional_conditions=None
123
+ )
124
+ self.gradient_checkpointing = False
125
+
126
+ @property
127
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
128
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
129
+ r"""
130
+ Returns:
131
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
132
+ indexed by its weight name.
133
+ """
134
+ # set recursively
135
+ processors = {}
136
+
137
+ def fn_recursive_add_processors(
138
+ name: str,
139
+ module: torch.nn.Module,
140
+ processors: Dict[str, AttentionProcessor],
141
+ ):
142
+ if hasattr(module, "get_processor"):
143
+ processors[f"{name}.processor"] = module.get_processor()
144
+
145
+ for sub_name, child in module.named_children():
146
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
147
+
148
+ return processors
149
+
150
+ for name, module in self.named_children():
151
+ fn_recursive_add_processors(name, module, processors)
152
+
153
+ return processors
154
+
155
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
156
+ def set_attn_processor(
157
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
158
+ ):
159
+ r"""
160
+ Sets the attention processor to use to compute attention.
161
+
162
+ Parameters:
163
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
164
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
165
+ for **all** `Attention` layers.
166
+
167
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
168
+ processor. This is strongly recommended when setting trainable attention processors.
169
+
170
+ """
171
+ count = len(self.attn_processors.keys())
172
+
173
+ if isinstance(processor, dict) and len(processor) != count:
174
+ raise ValueError(
175
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
176
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
177
+ )
178
+
179
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
180
+ if hasattr(module, "set_processor"):
181
+ if not isinstance(processor, dict):
182
+ module.set_processor(processor)
183
+ else:
184
+ module.set_processor(processor.pop(f"{name}.processor"))
185
+
186
+ for sub_name, child in module.named_children():
187
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
188
+
189
+ for name, module in self.named_children():
190
+ fn_recursive_attn_processor(name, module, processor)
191
+
192
+ def set_default_attn_processor(self):
193
+ """
194
+ Disables custom attention processors and sets the default attention implementation.
195
+ """
196
+ self.set_attn_processor(AttnProcessor2_0())
197
+
198
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
199
+ def fuse_qkv_projections(self):
200
+ """
201
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
202
+ are fused. For cross-attention modules, key and value projection matrices are fused.
203
+
204
+ <Tip warning={true}>
205
+
206
+ This API is 🧪 experimental.
207
+
208
+ </Tip>
209
+ """
210
+ self.original_attn_processors = None
211
+
212
+ for _, attn_processor in self.attn_processors.items():
213
+ if "Added" in str(attn_processor.__class__.__name__):
214
+ raise ValueError(
215
+ "`fuse_qkv_projections()` is not supported for models having added KV projections."
216
+ )
217
+
218
+ self.original_attn_processors = self.attn_processors
219
+
220
+ for module in self.modules():
221
+ if isinstance(module, Attention):
222
+ module.fuse_projections(fuse=True)
223
+
224
+ self.set_attn_processor(FusedAttnProcessor2_0())
225
+
226
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
227
+ def unfuse_qkv_projections(self):
228
+ """Disables the fused QKV projection if enabled.
229
+
230
+ <Tip warning={true}>
231
+
232
+ This API is 🧪 experimental.
233
+
234
+ </Tip>
235
+
236
+ """
237
+ if self.original_attn_processors is not None:
238
+ self.set_attn_processor(self.original_attn_processors)
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ timestep: Optional[torch.LongTensor] = None,
244
+ encoder_hidden_states: Optional[torch.Tensor] = None,
245
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
246
+ encoder_hidden_states_3: Optional[torch.Tensor] = None,
247
+ cross_attention_kwargs: Dict[str, Any] = None,
248
+ attention_mask: Optional[torch.Tensor] = None,
249
+ encoder_attention_mask: Optional[torch.Tensor] = None,
250
+ encoder_attention_mask_2: Optional[torch.Tensor] = None,
251
+ encoder_attention_mask_3: Optional[torch.Tensor] = None,
252
+ return_dict: bool = True,
253
+ ):
254
+ """
255
+ The [`PixArtTransformer2DModel`] forward method.
256
+
257
+ Args:
258
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, n_tokens)`):
259
+ Input `hidden_states`.
260
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
261
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
262
+ self-attention.
263
+ encoder_hidden_states_2 (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
264
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
265
+ self-attention.
266
+ encoder_hidden_states_3 (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
267
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
268
+ self-attention.
269
+ timestep (`torch.LongTensor`, *optional*):
270
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
271
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
272
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
273
+ `self.processor` in
274
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
275
+ attention_mask ( `torch.Tensor`, *optional*):
276
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
277
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
278
+ negative values to the attention scores corresponding to "discard" tokens.
279
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
280
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
281
+
282
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
283
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
284
+
285
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
286
+ above. This bias will be added to the cross-attention scores.
287
+ encoder_attention_mask_2 ( `torch.Tensor`, *optional*):
288
+ Cross-attention mask applied to `encoder_hidden_states_2`. Two formats supported:
289
+
290
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
291
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
292
+
293
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
294
+ above. This bias will be added to the cross-attention scores.
295
+ encoder_attention_mask_3 ( `torch.Tensor`, *optional*):
296
+ Cross-attention mask applied to `encoder_hidden_states_3`. Two formats supported:
297
+
298
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
299
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
300
+
301
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
302
+ above. This bias will be added to the cross-attention scores.
303
+ return_dict (`bool`, *optional*, defaults to `True`):
304
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
305
+ tuple.
306
+
307
+ Returns:
308
+ If `return_dict` is True, an [`~Transformer1DModelOutput`] is returned, otherwise a
309
+ `tuple` where the first element is the sample tensor.
310
+ """
311
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
312
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
313
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
314
+ # expects mask of shape:
315
+ # [batch, key_tokens]
316
+ # adds singleton query_tokens dimension:
317
+ # [batch, 1, key_tokens]
318
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
319
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
320
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
321
+ if attention_mask is not None and attention_mask.ndim == 2:
322
+ # assume that mask is expressed as:
323
+ # (1 = keep, 0 = discard)
324
+ # convert mask into a bias that can be added to attention scores:
325
+ # (keep = +0, discard = -10000.0)
326
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
327
+ attention_mask = attention_mask.unsqueeze(1)
328
+
329
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
330
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
331
+ encoder_attention_mask = (
332
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
333
+ ) * -10000.0
334
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
335
+
336
+ # convert encoder_attention_mask_2 to a bias the same way we do for attention_mask
337
+ if encoder_attention_mask_2 is not None and encoder_attention_mask_2.ndim == 2:
338
+ encoder_attention_mask_2 = (
339
+ 1 - encoder_attention_mask_2.to(hidden_states.dtype)
340
+ ) * -10000.0
341
+ encoder_attention_mask_2 = encoder_attention_mask_2.unsqueeze(1)
342
+
343
+ # convert encoder_attention_mask_2 to a bias the same way we do for attention_mask
344
+ if encoder_attention_mask_3 is not None and encoder_attention_mask_3.ndim == 2:
345
+ encoder_attention_mask_3 = (
346
+ 1 - encoder_attention_mask_3.to(hidden_states.dtype)
347
+ ) * -10000.0
348
+ encoder_attention_mask_3 = encoder_attention_mask_3.unsqueeze(1)
349
+
350
+ # 1. Input
351
+ batch_size = hidden_states.shape[0]
352
+ timestep, embedded_timestep = self.adaln_single(
353
+ timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
354
+ )
355
+
356
+ hidden_states = self.proj_in(hidden_states)
357
+
358
+ # 2. Blocks
359
+ for block in self.transformer_blocks:
360
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
361
+ hidden_states = self._gradient_checkpointing_func(
362
+ block,
363
+ hidden_states,
364
+ attention_mask,
365
+ encoder_hidden_states,
366
+ encoder_hidden_states_2,
367
+ encoder_hidden_states_3,
368
+ encoder_attention_mask,
369
+ encoder_attention_mask_2,
370
+ encoder_attention_mask_3,
371
+ timestep,
372
+ cross_attention_kwargs,
373
+ None,
374
+ )
375
+ else:
376
+ hidden_states = block(
377
+ hidden_states,
378
+ attention_mask=attention_mask,
379
+ encoder_hidden_states=encoder_hidden_states,
380
+ encoder_hidden_states_2=encoder_hidden_states_2,
381
+ encoder_hidden_states_3=encoder_hidden_states_3,
382
+ encoder_attention_mask=encoder_attention_mask,
383
+ encoder_attention_mask_2=encoder_attention_mask_2,
384
+ encoder_attention_mask_3=encoder_attention_mask_3,
385
+ timestep=timestep,
386
+ cross_attention_kwargs=cross_attention_kwargs,
387
+ class_labels=None,
388
+ )
389
+
390
+ # 3. Output
391
+ shift, scale = (
392
+ self.scale_shift_table[None]
393
+ + embedded_timestep[:, None].to(self.scale_shift_table.device)
394
+ ).chunk(2, dim=1)
395
+ hidden_states = self.norm_out(hidden_states)
396
+ # Modulation
397
+ hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(
398
+ hidden_states.device
399
+ )
400
+ hidden_states = self.proj_out(hidden_states)
401
+ hidden_states = hidden_states.squeeze(1)
402
+
403
+ if not return_dict:
404
+ return (hidden_states,)
405
+
406
+ return Transformer1DModelOutput(sample=hidden_states)
407
+
408
+
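+ # Minimal usage sketch for the transformer above (all shapes and sizes are
+ # illustrative assumptions): each condition stream feeds its own cross-attention.
+ #   model = PixArtTransformer1DModel(width=768, in_channels=32, cross_attention_dim=1024,
+ #                                    cross_attention_2_dim=1024, cross_attention_3_dim=1024)
+ #   x = torch.randn(2, 2048, 32); t = torch.randint(0, 1000, (2,))
+ #   img = torch.randn(2, 257, 1024)   # e.g. visual tokens
+ #   cap = torch.randn(2, 77, 1024)    # e.g. caption tokens
+ #   lab = torch.randn(2, 1, 1024)     # e.g. label embedding
+ #   out = model(x, t, img, cap, lab).sample   # (2, 2048, 32)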
409
+ @step1x3d_geometry.register("pixart-denoiser")
410
+ class PixArtDenoiser(BaseModule):
411
+ @dataclass
412
+ class Config(BaseModule.Config):
413
+ pretrained_model_name_or_path: Optional[str] = None
414
+ input_channels: int = 32
415
+ width: int = 768
416
+ layers: int = 28
417
+ num_heads: int = 16
418
+ condition_dim: int = 1024
419
+ multi_condition_type: str = "cross_attention"
420
+ use_visual_condition: bool = False
421
+ visual_condition_dim: int = 1024
422
+ n_views: int = 1 # for multi-view condition
423
+ use_caption_condition: bool = False
424
+ caption_condition_dim: int = 1024
425
+ use_label_condition: bool = False
426
+ label_condition_dim: int = 1024
427
+
428
+ identity_init: bool = False
429
+
430
+ cfg: Config
431
+
432
+ def configure(self) -> None:
433
+ self.dit_model = PixArtTransformer1DModel(
434
+ num_attention_heads=self.cfg.num_heads,
435
+ width=self.cfg.width,
436
+ in_channels=self.cfg.input_channels,
437
+ num_layers=self.cfg.layers,
438
+ cross_attention_dim=self.cfg.condition_dim,
439
+ use_cross_attention_2=self.cfg.use_caption_condition
440
+ and self.cfg.multi_condition_type == "cross_attention",
441
+ cross_attention_2_dim=self.cfg.condition_dim,
442
+ use_cross_attention_3=self.cfg.use_label_condition
443
+ and self.cfg.multi_condition_type == "cross_attention",
444
+ cross_attention_3_dim=self.cfg.condition_dim,
445
+ )
446
+ if (
447
+ self.cfg.use_visual_condition
448
+ and self.cfg.visual_condition_dim != self.cfg.condition_dim
449
+ ):
450
+ self.proj_visual_condtion = nn.Sequential(
451
+ nn.RMSNorm(self.cfg.visual_condition_dim),
452
+ nn.Linear(self.cfg.visual_condition_dim, self.cfg.condition_dim),
453
+ )
454
+ if (
455
+ self.cfg.use_caption_condition
456
+ and self.cfg.caption_condition_dim != self.cfg.condition_dim
457
+ ):
458
+ self.proj_caption_condtion = nn.Sequential(
459
+ nn.RMSNorm(self.cfg.caption_condition_dim),
460
+ nn.Linear(self.cfg.caption_condition_dim, self.cfg.condition_dim),
461
+ )
462
+ if (
463
+ self.cfg.use_label_condition
464
+ and self.cfg.label_condition_dim != self.cfg.condition_dim
465
+ ):
466
+ self.proj_label_condtion = nn.Sequential(
467
+ nn.RMSNorm(self.cfg.label_condition_dim),
468
+ nn.Linear(self.cfg.label_condition_dim, self.cfg.condition_dim),
469
+ )
470
+
471
+ if self.cfg.identity_init:
472
+ self.identity_initialize()
473
+
474
+ if self.cfg.pretrained_model_name_or_path:
475
+ print(
476
+ f"Loading pretrained DiT model from {self.cfg.pretrained_model_name_or_path}"
477
+ )
478
+ ckpt = torch.load(
479
+ self.cfg.pretrained_model_name_or_path,
480
+ map_location="cpu",
481
+ weights_only=False,
482
+ )
483
+ if "state_dict" in ckpt.keys():
484
+ ckpt = ckpt["state_dict"]
485
+ self.load_state_dict(ckpt, strict=True)
486
+
487
+ def identity_initialize(self):
488
+ for block in self.dit_model.blocks:
489
+ nn.init.constant_(block.attn.c_proj.weight, 0)
490
+ nn.init.constant_(block.attn.c_proj.bias, 0)
491
+ nn.init.constant_(block.cross_attn.c_proj.weight, 0)
492
+ nn.init.constant_(block.cross_attn.c_proj.bias, 0)
493
+ nn.init.constant_(block.mlp.c_proj.weight, 0)
494
+ nn.init.constant_(block.mlp.c_proj.bias, 0)
495
+
496
+ def forward(
497
+ self,
498
+ model_input: torch.FloatTensor,
499
+ timestep: torch.LongTensor,
500
+ visual_condition: Optional[torch.FloatTensor] = None,
501
+ caption_condition: Optional[torch.FloatTensor] = None,
502
+ label_condition: Optional[torch.FloatTensor] = None,
503
+ attention_kwargs: Dict[str, torch.Tensor] = None,
504
+ cross_attention_kwargs: Dict[str, Any] = None,
505
+ return_dict: bool = True,
506
+ ):
507
+ r"""
508
+ Args:
509
+ model_input (torch.FloatTensor): [bs, n_data, c]
510
+ timestep (torch.LongTensor): [bs,]
511
+ visual_condition (torch.FloatTensor): [bs, visual_context_tokens, c]
512
+ text_condition (torch.FloatTensor): [bs, text_context_tokens, c]
513
+
514
+ Returns:
515
+ sample (torch.FloatTensor): [bs, n_data, c]
516
+
517
+ """
518
+
519
+ B, n_data, _ = model_input.shape
520
+
521
+ # 0. conditions projector
522
+ condition = []
523
+ if self.cfg.use_visual_condition:
524
+ assert visual_condition.shape[-1] == self.cfg.visual_condition_dim
525
+ if self.cfg.visual_condition_dim != self.cfg.condition_dim:
526
+ visual_condition = self.proj_visual_condtion(visual_condition)
527
+ condition.append(visual_condition)
528
+ else:
529
+ visual_condition = None
530
+ if self.cfg.use_caption_condition:
531
+ assert caption_condition.shape[-1] == self.cfg.caption_condition_dim
532
+ if self.cfg.caption_condition_dim != self.cfg.condition_dim:
533
+ caption_condition = self.proj_caption_condtion(caption_condition)
534
+ condition.append(caption_condition)
535
+ else:
536
+ caption_condition = None
537
+ if self.cfg.use_label_condition:
538
+ assert label_condition.shape[-1] == self.cfg.label_condition_dim
539
+ if self.cfg.label_condition_dim != self.cfg.condition_dim:
540
+ label_condition = self.proj_label_condtion(label_condition)
541
+ condition.append(label_condition)
542
+ else:
543
+ label_condition = None
544
+ assert not (
545
+ visual_condition is None
546
+ and caption_condition is None
547
+ and label_condition is None
548
+ )
549
+
550
+ # 1. denoise
551
+ if self.cfg.multi_condition_type == "cross_attention":
552
+ output = self.dit_model(
553
+ model_input,
554
+ timestep,
555
+ visual_condition,
556
+ caption_condition,
557
+ label_condition,
558
+ cross_attention_kwargs,
559
+ return_dict=return_dict,
560
+ )
561
+ elif self.cfg.multi_condition_type == "in_context":
562
+ output = self.dit_model(
563
+ model_input,
564
+ timestep,
565
+ torch.cat(condition, dim=1),
566
+ None,
567
+ None,
568
+ cross_attention_kwargs,
569
+ return_dict=return_dict,
570
+ )
571
+ else:
572
+ raise ValueError
573
+
574
+ return output
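+ # The two `multi_condition_type` modes above differ only in routing: with
+ # "cross_attention" each condition keeps its own cross-attention stream
+ # (encoder_hidden_states, _2, _3), while "in_context" concatenates all projected
+ # conditions along the token axis and feeds them through the first stream only.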
step1x3d_geometry/systems/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import shape_autoencoder, shape_diffusion, shape_rectified_flow
step1x3d_geometry/systems/base.py ADDED
@@ -0,0 +1,210 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+
4
+ import pytorch_lightning as pl
5
+ import torch.nn.functional as F
6
+
7
+ import step1x3d_geometry
8
+ from step1x3d_geometry.utils.base import (
9
+ Updateable,
10
+ update_end_if_possible,
11
+ update_if_possible,
12
+ )
13
+ from step1x3d_geometry.utils.scheduler import parse_optimizer, parse_scheduler
14
+ from step1x3d_geometry.utils.config import parse_structured
15
+ from step1x3d_geometry.utils.misc import C, cleanup, get_device, load_module_weights
16
+ from step1x3d_geometry.utils.saving import SaverMixin
17
+ from step1x3d_geometry.utils.typing import *
18
+
19
+
20
+ class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
21
+ @dataclass
22
+ class Config:
23
+ loggers: dict = field(default_factory=dict)
24
+ loss: dict = field(default_factory=dict)
25
+ optimizer: dict = field(default_factory=dict)
26
+ scheduler: Optional[dict] = None
27
+ weights: Optional[str] = None
28
+ weights_ignore_modules: Optional[List[str]] = None
29
+ cleanup_after_validation_step: bool = False
30
+ cleanup_after_test_step: bool = False
31
+
32
+ pretrained_model_path: Optional[str] = None
33
+ strict_load: bool = True
34
+
35
+ cfg: Config
36
+
37
+ def __init__(self, cfg, resumed=False) -> None:
38
+ super().__init__()
39
+ self.cfg = parse_structured(self.Config, cfg)
40
+ self._save_dir: Optional[str] = None
41
+ self._resumed: bool = resumed
42
+ self._resumed_eval: bool = False
43
+ self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
44
+ if "loggers" in cfg:
45
+ self.create_loggers(cfg.loggers)
46
+
47
+ self.configure()
48
+ if self.cfg.weights is not None:
49
+ self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules)
50
+ self.post_configure()
51
+
52
+ def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
53
+ state_dict, epoch, global_step = load_module_weights(
54
+ weights, ignore_modules=ignore_modules, map_location="cpu"
55
+ )
56
+ self.load_state_dict(state_dict, strict=False)
57
+ # restore step-dependent states
58
+ self.do_update_step(epoch, global_step, on_load_weights=True)
59
+
60
+ def set_resume_status(self, current_epoch: int, global_step: int):
61
+ # restore correct epoch and global step in eval
62
+ self._resumed_eval = True
63
+ self._resumed_eval_status["current_epoch"] = current_epoch
64
+ self._resumed_eval_status["global_step"] = global_step
65
+
66
+ @property
67
+ def resumed(self):
68
+ # whether from resumed checkpoint
69
+ return self._resumed
70
+
71
+ @property
72
+ def true_global_step(self):
73
+ if self._resumed_eval:
74
+ return self._resumed_eval_status["global_step"]
75
+ else:
76
+ return self.global_step
77
+
78
+ @property
79
+ def true_current_epoch(self):
80
+ if self._resumed_eval:
81
+ return self._resumed_eval_status["current_epoch"]
82
+ else:
83
+ return self.current_epoch
84
+
85
+ def configure(self) -> None:
86
+ pass
87
+
88
+ def post_configure(self) -> None:
89
+ """
90
+ executed after weights are loaded
91
+ """
92
+ pass
93
+
94
+ def C(self, value: Any) -> float:
95
+ return C(value, self.true_current_epoch, self.true_global_step)
96
+
97
+ def configure_optimizers(self):
98
+ optim = parse_optimizer(self.cfg.optimizer, self)
99
+ ret = {
100
+ "optimizer": optim,
101
+ }
102
+ if self.cfg.scheduler is not None:
103
+ ret.update(
104
+ {
105
+ "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
106
+ }
107
+ )
108
+ return ret
109
+
110
+ def training_step(self, batch, batch_idx):
111
+ raise NotImplementedError
112
+
113
+ def validation_step(self, batch, batch_idx):
114
+ raise NotImplementedError
115
+
116
+ def on_train_batch_end(self, outputs, batch, batch_idx):
117
+ self.dataset = self.trainer.train_dataloader.dataset
118
+ update_end_if_possible(
119
+ self.dataset, self.true_current_epoch, self.true_global_step
120
+ )
121
+ self.do_update_step_end(self.true_current_epoch, self.true_global_step)
122
+
123
+ def on_validation_batch_end(self, outputs, batch, batch_idx):
124
+ self.dataset = self.trainer.val_dataloaders.dataset
125
+ update_end_if_possible(
126
+ self.dataset, self.true_current_epoch, self.true_global_step
127
+ )
128
+ self.do_update_step_end(self.true_current_epoch, self.true_global_step)
129
+ if self.cfg.cleanup_after_validation_step:
130
+ # cleanup to save vram
131
+ cleanup()
132
+
133
+ def on_validation_epoch_end(self):
134
+ raise NotImplementedError
135
+
136
+ def test_step(self, batch, batch_idx):
137
+ raise NotImplementedError
138
+
139
+ def on_test_batch_end(self, outputs, batch, batch_idx):
140
+ self.dataset = self.trainer.test_dataloaders.dataset
141
+ update_end_if_possible(
142
+ self.dataset, self.true_current_epoch, self.true_global_step
143
+ )
144
+ self.do_update_step_end(self.true_current_epoch, self.true_global_step)
145
+ if self.cfg.cleanup_after_test_step:
146
+ # cleanup to save vram
147
+ cleanup()
148
+
149
+ def on_test_epoch_end(self):
150
+ pass
151
+
152
+ def predict_step(self, batch, batch_idx):
153
+ raise NotImplementedError
154
+
155
+ def on_predict_batch_end(self, outputs, batch, batch_idx):
156
+ self.dataset = self.trainer.predict_dataloaders.dataset
157
+ update_end_if_possible(
158
+ self.dataset, self.true_current_epoch, self.true_global_step
159
+ )
160
+ self.do_update_step_end(self.true_current_epoch, self.true_global_step)
161
+ if self.cfg.cleanup_after_test_step:
162
+ # cleanup to save vram
163
+ cleanup()
164
+
165
+ def on_predict_epoch_end(self):
166
+ pass
167
+
168
+ def preprocess_data(self, batch, stage):
169
+ pass
170
+
171
+ """
172
+ Implementing on_after_batch_transfer of DataModule does the same.
173
+ But on_after_batch_transfer does not support DP.
174
+ """
175
+
176
+ def on_train_batch_start(self, batch, batch_idx, unused=0):
177
+ self.preprocess_data(batch, "train")
178
+ self.dataset = self.trainer.train_dataloader.dataset
179
+ update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
180
+ self.do_update_step(self.true_current_epoch, self.true_global_step)
181
+
182
+ def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
183
+ self.preprocess_data(batch, "validation")
184
+ self.dataset = self.trainer.val_dataloaders.dataset
185
+ update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
186
+ self.do_update_step(self.true_current_epoch, self.true_global_step)
187
+
188
+ def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
189
+ self.preprocess_data(batch, "test")
190
+ self.dataset = self.trainer.test_dataloaders.dataset
191
+ update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
192
+ self.do_update_step(self.true_current_epoch, self.true_global_step)
193
+
194
+ def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
195
+ self.preprocess_data(batch, "predict")
196
+ self.dataset = self.trainer.predict_dataloaders.dataset
197
+ update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
198
+ self.do_update_step(self.true_current_epoch, self.true_global_step)
199
+
200
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
201
+ pass
202
+
203
+ def on_before_optimizer_step(self, optimizer):
204
+ """
205
+ # some gradient-related debugging goes here, example:
206
+ from lightning.pytorch.utilities import grad_norm
207
+ norms = grad_norm(self.geometry, norm_type=2)
208
+ print(norms)
209
+ """
210
+ pass
step1x3d_geometry/systems/shape_autoencoder.py ADDED
@@ -0,0 +1,151 @@
1
+ from dataclasses import dataclass, field
2
+ import numpy as np
3
+ import torch
4
+ from skimage import measure
5
+ from einops import repeat, rearrange
6
+
7
+ import step1x3d_geometry
8
+ from step1x3d_geometry.systems.base import BaseSystem
9
+ from step1x3d_geometry.utils.ops import generate_dense_grid_points
10
+ from step1x3d_geometry.utils.typing import *
11
+ from step1x3d_geometry.utils.misc import get_rank
12
+
13
+
14
+ @step1x3d_geometry.register("shape-autoencoder-system")
15
+ class ShapeAutoEncoderSystem(BaseSystem):
16
+ @dataclass
17
+ class Config(BaseSystem.Config):
18
+ shape_model_type: str = None
19
+ shape_model: dict = field(default_factory=dict)
20
+
21
+ sample_posterior: bool = True
22
+
23
+ # for mesh extraction
24
+ bounds: float = 1.05
25
+ mc_level: float = 0.0
26
+ octree_resolution: int = 256
27
+
28
+ cfg: Config
29
+
30
+ def configure(self):
31
+ super().configure()
32
+
33
+ self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)(
34
+ self.cfg.shape_model
35
+ )
36
+
37
+ def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
38
+ rand_points = batch["rand_points"]
39
+ if "sdf" in batch:
40
+ target = batch["sdf"]
41
+ criteria = torch.nn.MSELoss()
42
+ elif "occupancies" in batch:
43
+ target = batch["occupancies"]
44
+ criteria = torch.nn.BCEWithLogitsLoss()
45
+ else:
46
+ raise NotImplementedError
47
+
48
+ # forward pass
49
+ num_point_feats = 3 + self.cfg.shape_model.point_feats
50
+ shape_latents, kl_embed, posterior = self.shape_model.encode(
51
+ batch["surface"][..., :num_point_feats],
52
+ sharp_surface=(
53
+ batch["sharp_surface"][..., :num_point_feats]
54
+ if "sharp_surface" in batch
55
+ else None
56
+ ),
57
+ sample_posterior=self.cfg.sample_posterior,
58
+ )
59
+ latents = self.shape_model.decode(kl_embed) # [B, num_latents, width]
60
+ logits = self.shape_model.query(rand_points, latents).squeeze(
61
+ -1
62
+ ) # [B, num_rand_points]
63
+
64
+ if self.cfg.sample_posterior:
65
+ loss_kl = posterior.kl()
66
+ loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
67
+
68
+ return {
69
+ "loss_logits": criteria(logits, target).mean(),
70
+ "loss_kl": loss_kl,
71
+ "logits": logits,
72
+ "target": target,
73
+ "latents": latents,
74
+ }
75
+ else:
76
+ return {
77
+ "loss_logits": criteria(logits, target).mean(),
78
+ "latents": latents,
79
+ "logits": logits,
80
+ }
81
+
82
+ def training_step(self, batch, batch_idx):
83
+ """
84
+ Run a single training step of the shape autoencoder.
85
+
86
+ Args:
87
+ batch: dict with surface samples, query points and occupancy/SDF targets.
88
+ batch_idx: index of the current batch.
89
+ Returns:
90
+ dict holding the weighted total loss under the key "loss".
91
+ """
92
+ out = self(batch)
93
+
94
+ loss = 0.0
95
+ for name, value in out.items():
96
+ if name.startswith("loss_"):
97
+ self.log(f"train/{name}", value)
98
+ loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
99
+
100
+ for name, value in self.cfg.loss.items():
101
+ self.log(f"train_params/{name}", self.C(value))
102
+
103
+ return {"loss": loss}
104
+
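+ # The weighting above assumes one schedule entry per loss term in `cfg.loss`,
+ # keyed with a `lambda_` prefix, e.g. `lambda_logits` and `lambda_kl` matching
+ # the `loss_logits` / `loss_kl` keys returned by `forward`.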
105
+ @torch.no_grad()
106
+ def validation_step(self, batch, batch_idx):
107
+ self.eval()
108
+ out = self(batch)
109
+
110
+ meshes = self.shape_model.extract_geometry(
111
+ out["latents"],
112
+ bounds=self.cfg.bounds,
113
+ mc_level=self.cfg.mc_level,
114
+ octree_resolution=self.cfg.octree_resolution,
115
+ enable_pbar=False,
116
+ )
117
+ for idx, name in enumerate(batch["uid"]):
118
+ self.save_mesh(
119
+ f"it{self.true_global_step}/{name}.obj",
120
+ meshes[idx].verts,
121
+ meshes[idx].faces,
122
+ )
123
+
124
+ threshold = 0
125
+ outputs = out["logits"]
126
+ labels = out["target"]
127
+ pred = torch.zeros_like(outputs)
128
+ pred[outputs >= threshold] = 1
129
+
130
+ accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1]
131
+ accuracy = accuracy.mean()
132
+ intersection = (pred * labels).sum(dim=1)
133
+ union = (pred + labels).gt(0).sum(dim=1)
134
+ iou = intersection * 1.0 / (union + 1e-5)
135
+ iou = iou.mean()
136
+ self.log("val/accuracy", accuracy)
137
+ self.log("val/iou", iou)
138
+
139
+ torch.cuda.empty_cache()
140
+
141
+ return {
142
+ "val/loss": out["loss_logits"],
143
+ "val/accuracy": accuracy,
144
+ "val/iou": iou,
145
+ }
146
+
147
+ def on_validation_epoch_end(self):
148
+ pass
149
+
150
+ def test_step(self, batch, batch_idx):
151
+ return
step1x3d_geometry/systems/shape_diffusion.py ADDED
@@ -0,0 +1,425 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline
4
+ import numpy as np
5
+ import json
6
+ import copy
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from skimage import measure
10
+ from einops import repeat
11
+ from tqdm import tqdm
12
+ from PIL import Image
13
+
14
+ from diffusers import (
15
+ DDPMScheduler,
16
+ DDIMScheduler,
17
+ UniPCMultistepScheduler,
18
+ KarrasVeScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ )
21
+ from diffusers.training_utils import (
22
+ compute_snr,
23
+ free_memory,
24
+ )
25
+ import step1x3d_geometry
26
+ from step1x3d_geometry.systems.base import BaseSystem
27
+ from step1x3d_geometry.utils.misc import get_rank
28
+ from step1x3d_geometry.utils.typing import *
30
+ from step1x3d_geometry.systems.utils import read_image, ddim_sample
31
+
32
+
34
+ @step1x3d_geometry.register("diffusion-system")
35
+ class DiffusionSystem(BaseSystem):
36
+ @dataclass
37
+ class Config(BaseSystem.Config):
38
+ val_samples_json: str = ""
39
+ bounds: float = 1.05
40
+ mc_level: float = 0.0
41
+ octree_resolution: int = 256
42
+ skip_validation: bool = True
43
+
44
+ # diffusion config
45
+ z_scale_factor: float = 1.0
46
+ guidance_scale: float = 7.5
47
+ num_inference_steps: int = 50
48
+ eta: float = 0.0
49
+ snr_gamma: float = 5.0
50
+
51
+ # shape vae model
52
+ shape_model_type: str = None
53
+ shape_model: dict = field(default_factory=dict)
54
+
55
+ # condition model
56
+ visual_condition_type: Optional[str] = None
57
+ visual_condition: dict = field(default_factory=dict)
58
+ caption_condition_type: Optional[str] = None
59
+ caption_condition: dict = field(default_factory=dict)
60
+ label_condition_type: Optional[str] = None
61
+ label_condition: dict = field(default_factory=dict)
62
+
63
+ # diffusion model
64
+ denoiser_model_type: str = None
65
+ denoiser_model: dict = field(default_factory=dict)
66
+
67
+ # noise scheduler
68
+ noise_scheduler_type: str = None
69
+ noise_scheduler: dict = field(default_factory=dict)
70
+
71
+ # denoise scheduler
72
+ denoise_scheduler_type: str = None
73
+ denoise_scheduler: dict = field(default_factory=dict)
74
+
75
+ cfg: Config
76
+
77
+ def configure(self):
78
+ super().configure()
79
+
80
+ self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)(
81
+ self.cfg.shape_model
82
+ )
83
+ self.shape_model.eval()
84
+ self.shape_model.requires_grad_(False)
85
+
86
+ if self.cfg.visual_condition_type is not None:
87
+ self.visual_condition = step1x3d_geometry.find(
88
+ self.cfg.visual_condition_type
89
+ )(self.cfg.visual_condition)
90
+
91
+ if self.cfg.caption_condition_type is not None:
92
+ self.caption_condition = step1x3d_geometry.find(
93
+ self.cfg.caption_condition_type
94
+ )(self.cfg.caption_condition)
95
+
96
+ if self.cfg.label_condition_type is not None:
97
+ self.label_condition = step1x3d_geometry.find(
98
+ self.cfg.label_condition_type
99
+ )(self.cfg.label_condition)
100
+
101
+ self.denoiser_model = step1x3d_geometry.find(self.cfg.denoiser_model_type)(
102
+ self.cfg.denoiser_model
103
+ )
104
+
105
+ self.noise_scheduler = step1x3d_geometry.find(self.cfg.noise_scheduler_type)(
106
+ **self.cfg.noise_scheduler
107
+ )
108
+
109
+ self.denoise_scheduler = step1x3d_geometry.find(
110
+ self.cfg.denoise_scheduler_type
111
+ )(**self.cfg.denoise_scheduler)
112
+
113
+ def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]:
114
+ # 1. encode shape latents
115
+ if "sharp_surface" in batch.keys():
116
+ sharp_surface = batch["sharp_surface"][
117
+ ..., : 3 + self.cfg.shape_model.point_feats
118
+ ]
119
+ else:
120
+ sharp_surface = None
121
+ shape_embeds, kl_embed, _ = self.shape_model.encode(
122
+ batch["surface"][..., : 3 + self.cfg.shape_model.point_feats],
123
+ sample_posterior=True,
124
+ sharp_surface=sharp_surface,
125
+ )
126
+
127
+ latents = kl_embed * self.cfg.z_scale_factor
128
+
129
+ # 2. get the visual condition
130
+ visual_cond_latents = None
131
+ if self.cfg.visual_condition_type is not None:
132
+ if "image" in batch and batch["image"].dim() == 5:
133
+ if self.training:
134
+ bs, n_images = batch["image"].shape[:2]
135
+ batch["image"] = batch["image"].view(
136
+ bs * n_images, *batch["image"].shape[-3:]
137
+ )
138
+ else:
139
+ batch["image"] = batch["image"][:, 0, ...]
140
+ n_images = 1
141
+ bs = batch["image"].shape[0]
142
+ visual_cond_latents = self.visual_condition(batch).to(latents)
143
+ latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1)
144
+ latents = latents.view(bs * n_images, *latents.shape[-2:])
145
+ else:
146
+ visual_cond_latents = self.visual_condition(batch).to(latents)
147
+
148
+ ## 2.1 text condition if provided
149
+ caption_cond_latents = None
150
+ if self.cfg.caption_condition_type is not None:
151
+ assert "caption" in batch.keys(), "caption is required for caption encoder"
152
+ assert bs == len(
153
+ batch["caption"]
154
+ ), "Batch size must be the same as the caption length."
155
+ caption_cond_latents = (
156
+ self.caption_condition(batch)
157
+ .repeat_interleave(n_images, dim=0)
158
+ .to(latents)
159
+ )
160
+
161
+ ## 2.2 label condition if provided
162
+ label_cond_latents = None
163
+ if self.cfg.label_condition_type is not None:
164
+ assert "label" in batch.keys(), "label is required for label encoder"
165
+ assert bs == len(
166
+ batch["label"]
167
+ ), "Batch size must be the same as the label length."
168
+ label_cond_latents = (
169
+ self.label_condition(batch)
170
+ .repeat_interleave(n_images, dim=0)
171
+ .to(latents)
172
+ )
173
+
174
+ # 3. sample noise that we'll add to the latents
175
+ noise = torch.randn_like(latents).to(
176
+ latents
177
+ ) # [batch_size, n_token, latent_dim]
178
+ bs = latents.shape[0]
179
+
180
+ # 4. sample a random timestep for each latent
181
+ timesteps = torch.randint(
182
+ 0,
183
+ self.cfg.noise_scheduler.num_train_timesteps,
184
+ (bs,),
185
+ device=latents.device,
186
+ )
187
+ timesteps = timesteps.long()
188
+
189
+ # 5. add noise
190
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
191
+
192
+ # 6. diffusion model forward
193
+ output = self.denoiser_model(
194
+ noisy_z,
195
+ timesteps.long(),
196
+ visual_cond_latents,
197
+ caption_cond_latents,
198
+ label_cond_latents,
199
+ ).sample
200
+
201
+ # 7. compute loss
202
+ if self.noise_scheduler.config.prediction_type == "epsilon":
203
+ target = noise
204
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
205
+ target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
206
+ else:
207
+ raise ValueError(
208
+ f"Prediction Type: {self.noise_scheduler.prediction_type} not supported."
209
+ )
210
+ if self.cfg.snr_gamma == 0:
211
+ if self.cfg.loss.loss_type == "l1":
212
+ loss = F.l1_loss(output, target, reduction="mean")
213
+ elif self.cfg.loss.loss_type in ["mse", "l2"]:
214
+ loss = F.mse_loss(output, target, reduction="mean")
215
+ else:
216
+ raise ValueError(f"Loss Type: {self.cfg.loss.loss_type} not supported.")
217
+ else:
218
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
219
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
220
+ # This is discussed in Section 4.2 of the same paper.
221
+ snr = compute_snr(self.noise_scheduler, timesteps)
222
+ mse_loss_weights = torch.stack(
223
+ [snr, self.cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
224
+ ).min(dim=1)[0]
225
+ if self.noise_scheduler.config.prediction_type == "epsilon":
226
+ mse_loss_weights = mse_loss_weights / snr
227
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
228
+ mse_loss_weights = mse_loss_weights / (snr + 1)
229
+
230
+ if self.cfg.loss.loss_type == "l1":
231
+ loss = F.l1_loss(output, target, reduction="none")
232
+ elif self.cfg.loss.loss_type in ["mse", "l2"]:
233
+ loss = F.mse_loss(output, target, reduction="none")
234
+ else:
235
+ raise ValueError(f"Loss Type: {self.cfg.loss.loss_type} not supported.")
236
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
237
+ loss = loss.mean()
238
+
239
+ return {
240
+ "loss_diffusion": loss,
241
+ "latents": latents,
242
+ "x_t": noisy_z,
243
+ "noise": noise,
244
+ "noise_pred": output,
245
+ "timesteps": timesteps,
246
+ }
247
+
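The min-SNR-gamma weighting used in the loss branch above can be exercised in isolation. The sketch below is illustrative only: it assumes a stock diffusers DDPMScheduler rather than the scheduler configured via noise_scheduler_type, and the per-sample losses are random stand-ins.

import torch
from diffusers import DDPMScheduler
from diffusers.training_utils import compute_snr

scheduler = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.randint(0, 1000, (4,))
snr = compute_snr(scheduler, timesteps)
gamma = 5.0
# min(SNR, gamma) / SNR, the epsilon-prediction weighting applied in forward()
weights = torch.stack([snr, gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
per_sample_mse = torch.rand(4)  # stand-in for the per-sample reduced MSE
loss = (per_sample_mse * weights).mean()
print(loss)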
248
+ def training_step(self, batch, batch_idx):
249
+ out = self(batch)
250
+
251
+ loss = 0.0
252
+ for name, value in out.items():
253
+ if name.startswith("loss_"):
254
+ self.log(f"train/{name}", value)
255
+ loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
256
+
257
+ for name, value in self.cfg.loss.items():
258
+ if name.startswith("lambda_"):
259
+ self.log(f"train_params/{name}", self.C(value))
260
+
261
+ return {"loss": loss}
262
+
263
+ @torch.no_grad()
264
+ def validation_step(self, batch, batch_idx):
265
+ if self.cfg.skip_validation:
266
+ return {}
267
+ self.eval()
268
+
269
+ if get_rank() == 0:
270
+ sample_inputs = json.loads(
271
+ open(self.cfg.val_samples_json).read()
272
+ ) # condition
273
+ sample_inputs_ = copy.deepcopy(sample_inputs)
274
+ sample_outputs = self.sample(sample_inputs) # list
275
+ for i, latents in enumerate(sample_outputs["latents"]):
276
+ meshes = self.shape_model.extract_geometry(
277
+ latents,
278
+ bounds=self.cfg.bounds,
279
+ mc_level=self.cfg.mc_level,
280
+ octree_resolution=self.cfg.octree_resolution,
281
+ enable_pbar=False,
282
+ )
283
+
284
+ for j in range(len(meshes)):
285
+ name = ""
286
+ if "image" in sample_inputs_:
287
+ name += (
288
+ sample_inputs_["image"][j]
289
+ .split("/")[-1]
290
+ .replace(".png", "")
291
+ )
292
+ elif "mvimages" in sample_inputs_:
293
+ name += (
294
+ sample_inputs_["mvimages"][j][0]
295
+ .split("/")[-2]
296
+ .replace(".png", "")
297
+ )
298
+
299
+ if "caption" in sample_inputs_:
300
+ name += "_" + sample_inputs_["caption"][j].replace(" ", "_")
301
+
302
+ if "label" in sample_inputs_:
303
+ name += (
304
+ "_"
305
+ + sample_inputs_["label"][j]["symmetry"]
306
+ + sample_inputs_["label"][j]["edge_type"]
307
+ )
308
+
309
+ if (
310
+ meshes[j].verts is not None
311
+ and meshes[j].verts.shape[0] > 0
312
+ and meshes[j].faces is not None
313
+ and meshes[j].faces.shape[0] > 0
314
+ ):
315
+ self.save_mesh(
316
+ f"it{self.true_global_step}/{name}_{i}.obj",
317
+ meshes[j].verts,
318
+ meshes[j].faces,
319
+ )
320
+ torch.cuda.empty_cache()
321
+
322
+ out = self(batch)
323
+ if self.global_step == 0:
324
+ latents = self.shape_model.decode(out["latents"])
325
+ meshes = self.shape_model.extract_geometry(
326
+ latents,
327
+ bounds=self.cfg.bounds,
328
+ mc_level=self.cfg.mc_level,
329
+ octree_resolution=self.cfg.octree_resolution,
330
+ enable_pbar=False,
331
+ )
332
+
333
+ for i, mesh in enumerate(meshes):
334
+ self.save_mesh(
335
+ f"it{self.true_global_step}/{batch['uid'][i]}.obj",
336
+ mesh.verts,
337
+ mesh.faces,
338
+ )
339
+
340
+ return {"val/loss": out["loss_diffusion"]}
341
+
342
+ @torch.no_grad()
343
+ def sample(
344
+ self,
345
+ sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]],
346
+ sample_times: int = 1,
347
+ steps: Optional[int] = None,
348
+ guidance_scale: Optional[float] = None,
349
+ eta: float = 0.0,
350
+ seed: Optional[int] = None,
351
+ **kwargs,
352
+ ):
353
+
354
+ if steps is None:
355
+ steps = self.cfg.num_inference_steps
356
+ if guidance_scale is None:
357
+ guidance_scale = self.cfg.guidance_scale
358
+ do_classifier_free_guidance = guidance_scale != 1.0
359
+
360
+ # conditional encode
361
+ visual_cond = None
362
+ if "image" in sample_inputs:
363
+ sample_inputs["image"] = [
364
+ Image.open(img) if type(img) == str else img
365
+ for img in sample_inputs["image"]
366
+ ]
367
+ sample_inputs["image"] = Step1X3DGeometryPipeline.preprocess_image(
368
+ sample_inputs["image"], **kwargs
369
+ )
370
+ cond = self.visual_condition.encode_image(sample_inputs["image"])
371
+ if do_classifier_free_guidance:
372
+ un_cond = self.visual_condition.empty_image_embeds.repeat(
373
+ len(sample_inputs["image"]), 1, 1
374
+ ).to(cond)
375
+ visual_cond = torch.cat([un_cond, cond], dim=0)
376
+ caption_cond = None
377
+ if "caption" in sample_inputs:
378
+ cond = self.label_condition.encode_label(sample_inputs["caption"])
379
+ if do_classifier_free_guidance:
380
+ un_cond = self.caption_condition.empty_caption_embeds.repeat(
381
+ len(sample_inputs["caption"]), 1, 1
382
+ ).to(cond)
383
+ caption_cond = torch.cat([un_cond, cond], dim=0)
384
+ label_cond = None
385
+ if "label" in sample_inputs:
386
+ cond = self.label_condition.encode_label(sample_inputs["label"])
387
+ if do_classifier_free_guidance:
388
+ un_cond = self.label_condition.empty_label_embeds.repeat(
389
+ len(sample_inputs["label"]), 1
390
+ ).to(cond)
391
+ label_cond = torch.cat([un_cond, cond], dim=0)
392
+
393
+ latents_list = []
394
+ if seed is not None:
395
+ generator = torch.Generator(device="cuda").manual_seed(seed)
396
+ else:
397
+ generator = None
398
+
399
+ for _ in range(sample_times):
400
+ sample_loop = ddim_sample(
401
+ self.denoise_scheduler,
402
+ self.denoiser_model.eval(),
403
+ shape=self.shape_model.latent_shape,
404
+ visual_cond=visual_cond,
405
+ caption_cond=caption_cond,
406
+ label_cond=label_cond,
407
+ steps=steps,
408
+ guidance_scale=guidance_scale,
409
+ do_classifier_free_guidance=do_classifier_free_guidance,
410
+ device=self.device,
411
+ eta=eta,
412
+ disable_prog=False,
413
+ generator=generator,
414
+ )
415
+ for sample, t in sample_loop:
416
+ latents = sample
417
+ latents_list.append(self.shape_model.decode(latents))
418
+
419
+ return {"latents": latents_list, "inputs": sample_inputs}
420
+
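A hypothetical call sketch for the sample() method above. Here `system` stands for an already-configured, GPU-resident DiffusionSystem (an assumption, not something the diff defines), and the image path is one of the bundled examples.

# `system` is assumed to be an instantiated DiffusionSystem on GPU.
inputs = {"image": ["examples/images/000.png"]}
outputs = system.sample(inputs, sample_times=1, steps=50, guidance_scale=7.5, seed=0)
for decoded in outputs["latents"]:
    meshes = system.shape_model.extract_geometry(
        decoded, bounds=1.05, mc_level=0.0, octree_resolution=256
    )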
421
+ def on_validation_epoch_end(self):
422
+ pass
423
+
424
+ def test_step(self, batch, batch_idx):
425
+ return
step1x3d_geometry/systems/shape_rectified_flow.py ADDED
@@ -0,0 +1,474 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ import numpy as np
4
+ import json
5
+ import copy
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from skimage import measure
10
+ from einops import repeat
11
+ from tqdm import tqdm
12
+ from PIL import Image
13
+
14
+ from diffusers import (
15
+ DDPMScheduler,
16
+ DDIMScheduler,
17
+ UniPCMultistepScheduler,
18
+ KarrasVeScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ )
21
+ from diffusers.training_utils import (
22
+ compute_density_for_timestep_sampling,
23
+ compute_loss_weighting_for_sd3,
24
+ free_memory,
25
+ )
26
+ import step1x3d_geometry
27
+ from step1x3d_geometry.systems.base import BaseSystem
28
+ from step1x3d_geometry.utils.misc import get_rank
29
+ from step1x3d_geometry.utils.typing import *
30
+ from step1x3d_geometry.systems.utils import read_image, preprocess_image, flow_sample
31
+
32
+
33
+ def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32):
34
+ sigmas = noise_scheduler.sigmas.to(device=timesteps.device, dtype=dtype)
35
+ schedule_timesteps = noise_scheduler.timesteps.to(timesteps.device)
36
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
37
+
38
+ sigma = sigmas[step_indices].flatten()
39
+ while len(sigma.shape) < n_dim:
40
+ sigma = sigma.unsqueeze(-1)
41
+ return sigma
42
+
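get_sigmas just looks up the sigma for each sampled timestep and reshapes it for broadcasting. The sketch below assumes the configured noise scheduler is diffusers' FlowMatchEulerDiscreteScheduler (an assumption; the real class comes from noise_scheduler_type) and that get_sigmas defined above is in scope.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler  # requires a recent diffusers

sched = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
timesteps = sched.timesteps[:2]            # two training timesteps
sigmas = get_sigmas(sched, timesteps, n_dim=3)
assert sigmas.shape == (2, 1, 1)           # broadcastable against [B, n_token, dim] latents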
43
+
44
+ @step1x3d_geometry.register("rectified-flow-system")
45
+ class RectifiedFlowSystem(BaseSystem):
46
+ @dataclass
47
+ class Config(BaseSystem.Config):
48
+ skip_validation: bool = True
49
+ val_samples_json: str = ""
50
+ bounds: float = 1.05
51
+ mc_level: float = 0.0
52
+ octree_resolution: int = 256
53
+
54
+ # diffusion config
55
+ guidance_scale: float = 7.5
56
+ num_inference_steps: int = 30
57
+ eta: float = 0.0
58
+ snr_gamma: float = 5.0
59
+
60
+ # flow
61
+ weighting_scheme: str = "logit_normal"
62
+ logit_mean: float = 0
63
+ logit_std: float = 1.0
64
+ mode_scale: float = 1.29
65
+ precondition_outputs: bool = True
66
+ precondition_t: int = 1000
67
+
68
+ # shape vae model
69
+ shape_model_type: str = None
70
+ shape_model: dict = field(default_factory=dict)
71
+
72
+ # condition model
73
+ visual_condition_type: Optional[str] = None
74
+ visual_condition: dict = field(default_factory=dict)
75
+ caption_condition_type: Optional[str] = None
76
+ caption_condition: dict = field(default_factory=dict)
77
+ label_condition_type: Optional[str] = None
78
+ label_condition: dict = field(default_factory=dict)
79
+
80
+ # diffusion model
81
+ denoiser_model_type: str = None
82
+ denoiser_model: dict = field(default_factory=dict)
83
+
84
+ # noise scheduler
85
+ noise_scheduler_type: str = None
86
+ noise_scheduler: dict = field(default_factory=dict)
87
+
88
+ # denoise scheduler
89
+ denoise_scheduler_type: str = None
90
+ denoise_scheduler: dict = field(default_factory=dict)
91
+
92
+ # lora
93
+ use_lora: bool = False
94
+ lora_layers: Optional[str] = None
95
+ rank: int = 128 # The dimension of the LoRA update matrices.
96
+ alpha: int = 128
97
+
98
+ cfg: Config
99
+
100
+ def configure(self):
101
+ super().configure()
102
+
103
+ self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)(
104
+ self.cfg.shape_model
105
+ )
106
+ self.shape_model.eval()
107
+ self.shape_model.requires_grad_(False)
108
+
109
+ if self.cfg.visual_condition_type is not None:
110
+ self.visual_condition = step1x3d_geometry.find(
111
+ self.cfg.visual_condition_type
112
+ )(self.cfg.visual_condition)
113
+ self.visual_condition.requires_grad_(False)
114
+
115
+ if self.cfg.caption_condition_type is not None:
116
+ self.caption_condition = step1x3d_geometry.find(
117
+ self.cfg.caption_condition_type
118
+ )(self.cfg.caption_condition)
119
+ self.caption_condition.requires_grad_(False)
120
+
121
+ if self.cfg.label_condition_type is not None:
122
+ self.label_condition = step1x3d_geometry.find(
123
+ self.cfg.label_condition_type
124
+ )(self.cfg.label_condition)
125
+
126
+ self.denoiser_model = step1x3d_geometry.find(self.cfg.denoiser_model_type)(
127
+ self.cfg.denoiser_model
128
+ )
129
+ if self.cfg.use_lora: # We only train the additional adapter LoRA layers
130
+ self.denoiser_model.requires_grad_(False)
131
+
132
+ self.noise_scheduler = step1x3d_geometry.find(self.cfg.noise_scheduler_type)(
133
+ **self.cfg.noise_scheduler
134
+ )
135
+ self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
136
+
137
+ self.denoise_scheduler = step1x3d_geometry.find(
138
+ self.cfg.denoise_scheduler_type
139
+ )(**self.cfg.denoise_scheduler)
140
+
141
+ if self.cfg.use_lora:
142
+ from peft import LoraConfig, set_peft_model_state_dict
143
+
144
+ if self.cfg.lora_layers is not None:
145
+ self.target_modules = [
146
+ layer.strip() for layer in self.cfg.lora_layers.split(",")
147
+ ]
148
+ else:
149
+ self.target_modules = [
150
+ "attn.to_k",
151
+ "attn.to_q",
152
+ "attn.to_v",
153
+ "attn.to_out.0",
154
+ "attn.add_k_proj",
155
+ "attn.add_q_proj",
156
+ "attn.add_v_proj",
157
+ "attn.to_add_out",
158
+ "ff.net.0.proj",
159
+ "ff.net.2",
160
+ "ff_context.net.0.proj",
161
+ "ff_context.net.2",
162
+ ]
163
+ self.transformer_lora_config = LoraConfig(
164
+ r=self.cfg.rank,
165
+ lora_alpha=self.cfg.alpha,
166
+ init_lora_weights="gaussian",
167
+ target_modules=self.target_modules,
168
+ )
169
+ self.denoiser_model.dit_model.add_adapter(self.transformer_lora_config)
170
+
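With use_lora enabled, only the injected adapter matrices should remain trainable after add_adapter. A small helper like the one below (illustrative, works on any nn.Module) can be used to confirm that.

import torch.nn as nn

def trainable_lora_parameters(model: nn.Module):
    # names and sizes of parameters that will actually receive gradients
    return [(n, p.numel()) for n, p in model.named_parameters()
            if p.requires_grad and "lora_" in n]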
171
+ def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]:
172
+ # 1. encode shape latents
173
+ if "sharp_surface" in batch.keys():
174
+ sharp_surface = batch["sharp_surface"][
175
+ ..., : 3 + self.cfg.shape_model.point_feats
176
+ ]
177
+ else:
178
+ sharp_surface = None
179
+ shape_embeds, latents, _ = self.shape_model.encode(
180
+ batch["surface"][..., : 3 + self.cfg.shape_model.point_feats],
181
+ sample_posterior=True,
182
+ sharp_surface=sharp_surface,
183
+ )
184
+
185
+ # 2. get the visual condition
186
+ visual_cond = None
187
+ if self.cfg.visual_condition_type is not None:
188
+ assert "image" in batch.keys(), "image is required for label encoder"
189
+ if "image" in batch and batch["image"].dim() == 5:
190
+ if self.training:
191
+ bs, n_images = batch["image"].shape[:2]
192
+ batch["image"] = batch["image"].view(
193
+ bs * n_images, *batch["image"].shape[-3:]
194
+ )
195
+ else:
196
+ batch["image"] = batch["image"][:, 0, ...]
197
+ n_images = 1
198
+ bs = batch["image"].shape[0]
199
+ visual_cond = self.visual_condition(batch).to(latents)
200
+ latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1)
201
+ latents = latents.view(bs * n_images, *latents.shape[-2:])
202
+ else:
203
+ visual_cond = self.visual_condition(batch).to(latents)
204
+ bs = visual_cond.shape[0]
205
+ n_images = 1
206
+
207
+ ## 2.1 text condition if provided
208
+ caption_cond = None
209
+ if self.cfg.caption_condition_type is not None:
210
+ assert "caption" in batch.keys(), "caption is required for caption encoder"
211
+ assert bs == len(
212
+ batch["caption"]
213
+ ), "Batch size must be the same as the caption length."
214
+ caption_cond = (
215
+ self.caption_condition(batch)
216
+ .repeat_interleave(n_images, dim=0)
217
+ .to(latents)
218
+ )
219
+
220
+ ## 2.2 label condition if provided
221
+ label_cond = None
222
+ if self.cfg.label_condition_type is not None:
223
+ assert "label" in batch.keys(), "label is required for label encoder"
224
+ assert bs == len(
225
+ batch["label"]
226
+ ), "Batch size must be the same as the label length."
227
+ label_cond = (
228
+ self.label_condition(batch)
229
+ .repeat_interleave(n_images, dim=0)
230
+ .to(latents)
231
+ )
232
+
233
+ # 3. sample noise that we'll add to the latents
234
+ noise = torch.randn_like(latents).to(
235
+ latents
236
+ ) # [batch_size, n_token, latent_dim]
237
+
238
+ # 4. Sample a random timestep
239
+ u = compute_density_for_timestep_sampling(
240
+ weighting_scheme=self.cfg.weighting_scheme,
241
+ batch_size=bs * n_images,
242
+ logit_mean=self.cfg.logit_mean,
243
+ logit_std=self.cfg.logit_std,
244
+ mode_scale=self.cfg.mode_scale,
245
+ )
246
+ indices = (u * self.cfg.noise_scheduler.num_train_timesteps).long()
247
+ timesteps = self.noise_scheduler_copy.timesteps[indices].to(
248
+ device=latents.device
249
+ )
250
+
251
+ # 5. add noise
252
+ sigmas = get_sigmas(
253
+ self.noise_scheduler_copy, timesteps, n_dim=3, dtype=latents.dtype
254
+ )
255
+ noisy_z = (1.0 - sigmas) * latents + sigmas * noise
256
+
257
+ # 6. diffusion model forward
258
+ output = self.denoiser_model(
259
+ noisy_z, timesteps.long(), visual_cond, caption_cond, label_cond
260
+ ).sample
261
+
262
+ # 7. compute loss
263
+ if self.cfg.precondition_outputs:
264
+ output = output * (-sigmas) + noisy_z
265
+ # these weighting schemes use a uniform timestep sampling
266
+ # and instead post-weight the loss
267
+ weighting = compute_loss_weighting_for_sd3(
268
+ weighting_scheme=self.cfg.weighting_scheme, sigmas=sigmas
269
+ )
270
+ # flow matching loss
271
+ if self.cfg.precondition_outputs:
272
+ target = latents
273
+ else:
274
+ target = noise - latents
275
+
276
+ # Compute regular loss.
277
+ loss = torch.mean(
278
+ (weighting.float() * (output.float() - target.float()) ** 2).reshape(
279
+ target.shape[0], -1
280
+ ),
281
+ 1,
282
+ )
283
+ loss = loss.mean()
284
+
285
+ return {
286
+ "loss_diffusion": loss,
287
+ "latents": latents,
288
+ "x_t": noisy_z,
289
+ "noise": noise,
290
+ "noise_pred": output,
291
+ "timesteps": timesteps,
292
+ }
293
+
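A self-contained sketch of the rectified-flow construction used in forward() above: the noisy latent is a straight-line interpolation between data and noise, and the regression target is either the velocity (noise - latents) or, with precondition_outputs, the clean latent recovered from the predicted velocity. Shapes and values are illustrative only.

import torch

x0 = torch.randn(2, 512, 64)       # clean latents [B, n_token, dim] (illustrative shape)
eps = torch.randn_like(x0)         # Gaussian noise
sigma = torch.rand(2, 1, 1)        # per-sample noise level in (0, 1)

x_t = (1.0 - sigma) * x0 + sigma * eps   # noisy latents fed to the denoiser
v_target = eps - x0                      # target when precondition_outputs is False
v_pred = torch.randn_like(x0)            # stand-in for the model output
x0_pred = v_pred * (-sigma) + x_t        # equals x0 exactly when v_pred == eps - x0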
294
+ def training_step(self, batch, batch_idx):
295
+ out = self(batch)
296
+
297
+ loss = 0.0
298
+ for name, value in out.items():
299
+ if name.startswith("loss_"):
300
+ self.log(f"train/{name}", value)
301
+ loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
302
+ if name.startswith("log_"):
303
+ self.log(f"log/{name.replace('log_', '')}", value.mean())
304
+
305
+ for name, value in self.cfg.loss.items():
306
+ if name.startswith("lambda_"):
307
+ self.log(f"train_params/{name}", self.C(value))
308
+
309
+ return {"loss": loss}
310
+
311
+ @torch.no_grad()
312
+ def validation_step(self, batch, batch_idx):
313
+ if self.cfg.skip_validation:
314
+ return {}
315
+ self.eval()
316
+
317
+ if get_rank() == 0:
318
+ sample_inputs = json.loads(
319
+ open(self.cfg.val_samples_json).read()
320
+ ) # condition
321
+ sample_inputs_ = copy.deepcopy(sample_inputs)
322
+ sample_outputs = self.sample(sample_inputs) # list
323
+ for i, latents in enumerate(sample_outputs["latents"]):
324
+ meshes = self.shape_model.extract_geometry(
325
+ latents,
326
+ bounds=self.cfg.bounds,
327
+ mc_level=self.cfg.mc_level,
328
+ octree_resolution=self.cfg.octree_resolution,
329
+ enable_pbar=False,
330
+ )
331
+
332
+ for j in range(len(meshes)):
333
+ name = ""
334
+ if "image" in sample_inputs_:
335
+ name += (
336
+ sample_inputs_["image"][j]
337
+ .split("/")[-1]
338
+ .replace(".png", "")
339
+ )
340
+
341
+ elif "mvimages" in sample_inputs_:
342
+ name += (
343
+ sample_inputs_["mvimages"][j][0]
344
+ .split("/")[-2]
345
+ .replace(".png", "")
346
+ )
347
+
348
+ if "caption" in sample_inputs_:
349
+ name += "_" + sample_inputs_["caption"][j].replace(
350
+ " ", "_"
351
+ ).replace(".", "")
352
+
353
+ if "label" in sample_inputs_:
354
+ name += (
355
+ "_"
356
+ + sample_inputs_["label"][j]["symmetry"]
357
+ + sample_inputs_["label"][j]["edge_type"]
358
+ )
359
+
360
+ if (
361
+ meshes[j].verts is not None
362
+ and meshes[j].verts.shape[0] > 0
363
+ and meshes[j].faces is not None
364
+ and meshes[j].faces.shape[0] > 0
365
+ ):
366
+ self.save_mesh(
367
+ f"it{self.true_global_step}/{name}_{i}.obj",
368
+ meshes[j].verts,
369
+ meshes[j].faces,
370
+ )
371
+ torch.cuda.empty_cache()
372
+
373
+ out = self(batch)
374
+ if self.global_step == 0:
375
+ latents = self.shape_model.decode(out["latents"])
376
+ meshes = self.shape_model.extract_geometry(
377
+ latents,
378
+ bounds=self.cfg.bounds,
379
+ mc_level=self.cfg.mc_level,
380
+ octree_resolution=self.cfg.octree_resolution,
381
+ enable_pbar=False,
382
+ )
383
+
384
+ for i, mesh in enumerate(meshes):
385
+ self.save_mesh(
386
+ f"it{self.true_global_step}/{batch['uid'][i]}.obj",
387
+ mesh.verts,
388
+ mesh.faces,
389
+ )
390
+
391
+ return {"val/loss": out["loss_diffusion"]}
392
+
393
+ @torch.no_grad()
394
+ def sample(
395
+ self,
396
+ sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]],
397
+ sample_times: int = 1,
398
+ steps: Optional[int] = None,
399
+ guidance_scale: Optional[float] = None,
400
+ eta: float = 0.0,
401
+ seed: Optional[int] = None,
402
+ **kwargs,
403
+ ):
404
+
405
+ if steps is None:
406
+ steps = self.cfg.num_inference_steps
407
+ if guidance_scale is None:
408
+ guidance_scale = self.cfg.guidance_scale
409
+ do_classifier_free_guidance = guidance_scale != 1.0
410
+
411
+ # conditional encode
412
+ visual_cond = None
413
+ if "image" in sample_inputs:
414
+ sample_inputs["image"] = [
415
+ Image.open(img) if type(img) == str else img
416
+ for img in sample_inputs["image"]
417
+ ]
418
+ sample_inputs["image"] = preprocess_image(sample_inputs["image"], **kwargs)
419
+ cond = self.visual_condition.encode_image(sample_inputs["image"])
420
+ if do_classifier_free_guidance:
421
+ un_cond = self.visual_condition.empty_image_embeds.repeat(
422
+ len(sample_inputs["image"]), 1, 1
423
+ ).to(cond)
424
+ visual_cond = torch.cat([un_cond, cond], dim=0)
425
+ caption_cond = None
426
+ if "caption" in sample_inputs:
427
+ cond = self.label_condition.encode_label(sample_inputs["caption"])
428
+ if do_classifier_free_guidance:
429
+ un_cond = self.caption_condition.empty_caption_embeds.repeat(
430
+ len(sample_inputs["caption"]), 1, 1
431
+ ).to(cond)
432
+ caption_cond = torch.cat([un_cond, cond], dim=0)
433
+ label_cond = None
434
+ if "label" in sample_inputs:
435
+ cond = self.label_condition.encode_label(sample_inputs["label"])
436
+ if do_classifier_free_guidance:
437
+ un_cond = self.label_condition.empty_label_embeds.repeat(
438
+ len(sample_inputs["label"]), 1, 1
439
+ ).to(cond)
440
+ label_cond = torch.cat([un_cond, cond], dim=0)
441
+
442
+ latents_list = []
443
+ if seed is not None:
444
+ generator = torch.Generator(device="cuda").manual_seed(seed)
445
+ else:
446
+ generator = None
447
+
448
+ for _ in range(sample_times):
449
+ sample_loop = flow_sample(
450
+ self.denoise_scheduler,
451
+ self.denoiser_model.eval(),
452
+ shape=self.shape_model.latent_shape,
453
+ visual_cond=visual_cond,
454
+ caption_cond=caption_cond,
455
+ label_cond=label_cond,
456
+ steps=steps,
457
+ guidance_scale=guidance_scale,
458
+ do_classifier_free_guidance=do_classifier_free_guidance,
459
+ device=self.device,
460
+ eta=eta,
461
+ disable_prog=False,
462
+ generator=generator,
463
+ )
464
+ for sample, t in sample_loop:
465
+ latents = sample
466
+ latents_list.append(self.shape_model.decode(latents))
467
+
468
+ return {"latents": latents_list, "inputs": sample_inputs}
469
+
470
+ def on_validation_epoch_end(self):
471
+ pass
472
+
473
+ def test_step(self, batch, batch_idx):
474
+ return
step1x3d_geometry/systems/utils.py ADDED
@@ -0,0 +1,391 @@
1
+ import inspect
+ import torch
2
+ import numpy as np
3
+
4
+ import rembg
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from diffusers import DDIMScheduler
8
+ from torchvision import transforms
9
+
10
+ from step1x3d_geometry.utils.typing import *
11
+ from step1x3d_geometry.utils.misc import get_device
12
+
13
+
14
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
15
+ def retrieve_timesteps(
16
+ scheduler,
17
+ num_inference_steps: Optional[int] = None,
18
+ device: Optional[Union[str, torch.device]] = None,
19
+ timesteps: Optional[List[int]] = None,
20
+ sigmas: Optional[List[float]] = None,
21
+ **kwargs,
22
+ ):
23
+ r"""
24
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
25
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
26
+
27
+ Args:
28
+ scheduler (`SchedulerMixin`):
29
+ The scheduler to get timesteps from.
30
+ num_inference_steps (`int`):
31
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
32
+ must be `None`.
33
+ device (`str` or `torch.device`, *optional*):
34
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
35
+ timesteps (`List[int]`, *optional*):
36
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
37
+ `num_inference_steps` and `sigmas` must be `None`.
38
+ sigmas (`List[float]`, *optional*):
39
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
40
+ `num_inference_steps` and `timesteps` must be `None`.
41
+
42
+ Returns:
43
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
44
+ second element is the number of inference steps.
45
+ """
46
+ if timesteps is not None and sigmas is not None:
47
+ raise ValueError(
48
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
49
+ )
50
+ if timesteps is not None:
51
+ accepts_timesteps = "timesteps" in set(
52
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
53
+ )
54
+ if not accepts_timesteps:
55
+ raise ValueError(
56
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
57
+ f" timestep schedules. Please check whether you are using the correct scheduler."
58
+ )
59
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
60
+ timesteps = scheduler.timesteps
61
+ num_inference_steps = len(timesteps)
62
+ elif sigmas is not None:
63
+ accept_sigmas = "sigmas" in set(
64
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
65
+ )
66
+ if not accept_sigmas:
67
+ raise ValueError(
68
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
69
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
70
+ )
71
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
72
+ timesteps = scheduler.timesteps
73
+ num_inference_steps = len(timesteps)
74
+ else:
75
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
76
+ timesteps = scheduler.timesteps
77
+ return timesteps, num_inference_steps
78
+
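Usage sketch for retrieve_timesteps with a stock DDIMScheduler (the project may configure a different denoise scheduler; this only shows the plain num_inference_steps path).

from diffusers import DDIMScheduler

sched = DDIMScheduler(num_train_timesteps=1000)
timesteps, n_steps = retrieve_timesteps(sched, num_inference_steps=50, device="cpu")
assert n_steps == 50 and timesteps.shape[0] == 50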
79
+
80
+ @torch.no_grad()
81
+ def ddim_sample(
82
+ ddim_scheduler: DDIMScheduler,
83
+ diffusion_model: torch.nn.Module,
84
+ shape: Union[List[int], Tuple[int]],
85
+ visual_cond: torch.FloatTensor,
86
+ caption_cond: torch.FloatTensor,
87
+ label_cond: torch.FloatTensor,
88
+ steps: int,
89
+ eta: float = 0.0,
90
+ guidance_scale: float = 3.0,
91
+ do_classifier_free_guidance: bool = True,
92
+ generator: Optional[torch.Generator] = None,
93
+ device: torch.device = "cuda:0",
94
+ disable_prog: bool = True,
95
+ ):
96
+
97
+ assert steps > 0, f"{steps} must > 0."
98
+
99
+ # init latents
100
+ if visual_cond is not None:
101
+ bsz = visual_cond.shape[0]
102
+ device = visual_cond.device
103
+ dtype = visual_cond.dtype
104
+ if caption_cond is not None:
105
+ bsz = caption_cond.shape[0]
106
+ device = caption_cond.device
107
+ dtype = caption_cond.dtype
108
+ if label_cond is not None:
109
+ bsz = label_cond.shape[0]
110
+ device = label_cond.device
111
+ dtype = label_cond.dtype
112
+
113
+ if do_classifier_free_guidance:
114
+ bsz = bsz // 2
115
+ latents = torch.randn(
116
+ (bsz, *shape),
117
+ generator=generator,
118
+ device=device,
119
+ dtype=dtype,
120
+ )
121
+ try:
122
+ # scale the initial noise by the standard deviation required by the scheduler
123
+ latents = latents * ddim_scheduler.init_noise_sigma
124
+ except AttributeError:
125
+ pass
126
+
127
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
128
+ extra_step_kwargs = {"generator": generator}
129
+
130
+ # set timesteps
131
+ timesteps, num_inference_steps = retrieve_timesteps(
132
+ ddim_scheduler,
133
+ steps,
134
+ device,
135
+ )
136
+ if eta > 0:
137
+ assert 0 <= eta <= 1, f"eta must be between [0, 1]. Got {eta}."
138
+ assert (
139
+ ddim_scheduler.__class__.__name__ == "DDIMScheduler"
140
+ ), f"eta is only used with the DDIMScheduler."
141
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
142
+ # eta (η) is only used with the DDIMScheduler, and between [0, 1]
143
+ extra_step_kwargs["eta"] = eta
144
+
145
+ # reverse
146
+ for i, t in enumerate(
147
+ tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)
148
+ ):
149
+ # expand the latents if we are doing classifier free guidance
150
+ latent_model_input = (
151
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
152
+ )
153
+
154
+ # predict the noise residual
155
+ timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
156
+ timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
157
+ noise_pred = diffusion_model.forward(
158
+ latent_model_input, timestep_tensor, visual_cond, caption_cond, label_cond
159
+ ).sample
160
+
161
+ # perform guidance
162
+ if do_classifier_free_guidance:
163
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
164
+ noise_pred = noise_pred_uncond + guidance_scale * (
165
+ noise_pred_text - noise_pred_uncond
166
+ )
167
+
168
+ # compute the previous noisy sample x_t -> x_t-1
169
+ latents = ddim_scheduler.step(
170
+ noise_pred, t, latents, **extra_step_kwargs
171
+ ).prev_sample
172
+
173
+ yield latents, t
174
+
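Sketch of the classifier-free guidance combination applied inside the loop above: the batch is laid out as [unconditional; conditional] and the two halves are recombined with the guidance scale. Tensor shapes are illustrative.

import torch

guidance_scale = 7.5
noise_pred = torch.randn(2 * 3, 512, 64)             # 3 samples, doubled batch
uncond, cond = noise_pred.chunk(2)
guided = uncond + guidance_scale * (cond - uncond)   # reduces to cond when the scale is 1.0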
175
+
176
+ @torch.no_grad()
177
+ def flow_sample(
178
+ scheduler: DDIMScheduler,
179
+ diffusion_model: torch.nn.Module,
180
+ shape: Union[List[int], Tuple[int]],
181
+ visual_cond: torch.FloatTensor,
182
+ caption_cond: torch.FloatTensor,
183
+ label_cond: torch.FloatTensor,
184
+ steps: int,
185
+ eta: float = 0.0,
186
+ guidance_scale: float = 3.0,
187
+ do_classifier_free_guidance: bool = True,
188
+ generator: Optional[torch.Generator] = None,
189
+ device: torch.device = "cuda:0",
190
+ disable_prog: bool = True,
191
+ ):
192
+
193
+ assert steps > 0, f"{steps} must > 0."
194
+
195
+ # init latents
196
+ if visual_cond is not None:
197
+ bsz = visual_cond.shape[0]
198
+ device = visual_cond.device
199
+ dtype = visual_cond.dtype
200
+ if caption_cond is not None:
201
+ bsz = caption_cond.shape[0]
202
+ device = caption_cond.device
203
+ dtype = caption_cond.dtype
204
+ if label_cond is not None:
205
+ bsz = label_cond.shape[0]
206
+ device = label_cond.device
207
+ dtype = label_cond.dtype
208
+
209
+ if do_classifier_free_guidance:
210
+ bsz = bsz // 2
211
+ latents = torch.randn(
212
+ (bsz, *shape),
213
+ generator=generator,
214
+ device=device,
215
+ dtype=dtype,
216
+ )
217
+ try:
218
+ # scale the initial noise by the standard deviation required by the scheduler
219
+ latents = latents * scheduler.init_noise_sigma
220
+ except AttributeError:
221
+ pass
222
+
223
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
224
+ extra_step_kwargs = {"generator": generator}
225
+
226
+ # set timesteps
227
+ timesteps, num_inference_steps = retrieve_timesteps(
228
+ scheduler,
229
+ steps + 1,
230
+ device,
231
+ )
232
+ if eta > 0:
233
+ assert 0 <= eta <= 1, f"eta must be between [0, 1]. Got {eta}."
234
+ assert (
235
+ scheduler.__class__.__name__ == "DDIMScheduler"
236
+ ), f"eta is only used with the DDIMScheduler."
237
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
238
+ # eta (η) is only used with the DDIMScheduler, and between [0, 1]
239
+ extra_step_kwargs["eta"] = eta
240
+
241
+ # reverse
242
+ distance = (timesteps[:-1] - timesteps[1:]) / scheduler.config.num_train_timesteps
243
+ for i, t in enumerate(
244
+ tqdm(timesteps[:-1], disable=disable_prog, desc="Flow Sampling:", leave=False)
245
+ ):
246
+ # expand the latents if we are doing classifier free guidance
247
+ latent_model_input = (
248
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
249
+ )
250
+ # predict the noise residual
251
+ timestep_tensor = torch.tensor([t], dtype=latents.dtype, device=device)
252
+ timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
253
+ noise_pred = diffusion_model.forward(
254
+ latent_model_input, timestep_tensor, visual_cond, caption_cond, label_cond
255
+ ).sample
256
+ if isinstance(noise_pred, tuple):
257
+ noise_pred, layer_idx_list, ones_list, pred_c_list = noise_pred
258
+
259
+ # perform guidance
260
+ if do_classifier_free_guidance:
261
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
262
+ noise_pred = noise_pred_uncond + guidance_scale * (
263
+ noise_pred_text - noise_pred_uncond
264
+ )
265
+
266
+ # compute the previous noisy sample x_t -> x_t-1
267
+ latents = latents - distance[i] * noise_pred
268
+
269
+ yield latents, t
270
+
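The update latents - distance[i] * noise_pred above is an explicit Euler step along the predicted velocity. The sketch below replays it on a uniform sigma grid with a zero-velocity stand-in for the model (the real grid comes from the configured flow-matching scheduler).

import torch

steps = 30
sigmas = torch.linspace(1.0, 0.0, steps + 1)   # illustrative grid
x = torch.randn(1, 512, 64)                    # start from pure noise at sigma = 1
for i in range(steps):
    v = torch.zeros_like(x)                    # stand-in for the denoiser's velocity prediction
    x = x - (sigmas[i] - sigmas[i + 1]) * v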
271
+
272
+ def compute_snr(noise_scheduler, timesteps):
273
+ """
274
+ Computes SNR as per
275
+ https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
276
+ """
277
+ alphas_cumprod = noise_scheduler.alphas_cumprod
278
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
279
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
280
+
281
+ # Expand the tensors.
282
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
283
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
284
+ timesteps
285
+ ].float()
286
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
287
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
288
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
289
+
290
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
291
+ device=timesteps.device
292
+ )[timesteps].float()
293
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
294
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
295
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
296
+
297
+ # Compute SNR.
298
+ snr = (alpha / sigma) ** 2
299
+ return snr
300
+
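compute_snr mirrors the diffusers helper; for a DDPM-style scheduler it reduces to alpha_bar / (1 - alpha_bar). The quick check below verifies that against a stock DDPMScheduler, using the compute_snr defined above.

import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler(num_train_timesteps=1000)
t = torch.tensor([10, 500, 990])
snr = compute_snr(sched, t)
alpha_bar = sched.alphas_cumprod[t]
assert torch.allclose(snr, alpha_bar / (1.0 - alpha_bar), rtol=1e-3, atol=1e-6)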
301
+
302
+ def read_image(img, img_size=224):
303
+ transform = transforms.Compose(
304
+ [
305
+ transforms.Resize(
306
+ img_size, transforms.InterpolationMode.BICUBIC, antialias=True
307
+ ),
308
+ transforms.CenterCrop(img_size), # crop a (224, 224) square
309
+ transforms.ToTensor(),
310
+ ]
311
+ )
312
+ rgb = Image.open(img)
313
+ rgb = transform(rgb)[:3, ...].permute(1, 2, 0)
314
+ return rgb
315
+
316
+
317
+ def preprocess_image(
318
+ images_pil: List[Image.Image],
319
+ force: bool = False,
320
+ background_color: List[int] = [255, 255, 255],
321
+ foreground_ratio: float = 0.95,
322
+ ):
323
+ r"""
324
+ Crop and remove the background of the input images.
325
+ Args:
326
+ images_pil (`List[PIL.Image.Image]`):
327
+ List of `PIL.Image.Image` objects representing the input images.
328
+ force (`bool`, *optional*, defaults to `False`):
329
+ Whether to force remove the background even if the image has an alpha channel.
330
+ Returns:
331
+ `List[PIL.Image.Image]`: List of `PIL.Image.Image` objects representing the preprocessed image.
332
+ """
333
+ preprocessed_images = []
334
+ for i in range(len(images_pil)):
335
+ image = images_pil[i]
336
+ width, height, size = image.width, image.height, image.size
337
+ do_remove = True
338
+ if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
339
+ # the image already carries a usable alpha channel, so skip background removal
340
+ print(
341
+ "alhpa channl not empty, skip remove background, using alpha channel as mask"
342
+ )
343
+ do_remove = False
344
+ do_remove = do_remove or force
345
+ if do_remove:
346
+ image = rembg.remove(image)
347
+
348
+ # calculate the min bbox of the image
349
+ alpha = image.split()[-1]
350
+ bboxs = alpha.getbbox()
351
+ x1, y1, x2, y2 = bboxs
352
+ dy, dx = y2 - y1, x2 - x1
353
+ s = min(height * foreground_ratio / dy, width * foreground_ratio / dx)
354
+ Ht, Wt = int(dy * s), int(dx * s)
355
+
356
+ background = Image.new("RGBA", image.size, (*background_color, 255))
357
+ image = Image.alpha_composite(background, image)
358
+ image = image.crop(alpha.getbbox())
359
+ alpha = alpha.crop(alpha.getbbox())
360
+
361
+ # Calculate the new size after rescaling
362
+ new_size = tuple(int(dim * foreground_ratio) for dim in size)
363
+ # Resize the image while maintaining the aspect ratio
364
+ resized_image = image.resize((Wt, Ht))
365
+ resized_alpha = alpha.resize((Wt, Ht))
366
+ # Create a new image with the original size and white background
367
+ padded_image = Image.new("RGB", size, tuple(background_color))
368
+ padded_alpha = Image.new("L", size, (0))
369
+ paste_position = (
370
+ (width - resized_image.width) // 2,
371
+ (height - resized_image.height) // 2,
372
+ )
373
+ padded_image.paste(resized_image, paste_position)
374
+ padded_alpha.paste(resized_alpha, paste_position)
375
+
376
+ # expand image to 1:1
377
+ width, height = padded_image.size
378
+ if width == height:
379
+ padded_image.putalpha(padded_alpha)
380
+ preprocessed_images.append(padded_image)
381
+ continue
382
+ new_size = (max(width, height), max(width, height))
383
+ new_image = Image.new("RGB", new_size, tuple(background_color))
384
+ new_alpha = Image.new("L", new_size, (0))
385
+ paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
386
+ new_image.paste(padded_image, paste_position)
387
+ new_alpha.paste(padded_alpha, paste_position)
388
+ new_image.putalpha(new_alpha)
389
+ preprocessed_images.append(new_image)
390
+
391
+ return preprocessed_images
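Usage sketch for preprocess_image. The path is one of the bundled examples, and rembg downloads its segmentation model on first use when background removal is triggered.

from PIL import Image

img = Image.open("examples/images/000.png").convert("RGBA")
out = preprocess_image([img], foreground_ratio=0.95)[0]   # square RGBA image, object recentered
out.convert("RGB").save("preprocessed_000.png")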
step1x3d_geometry/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import base
step1x3d_geometry/utils/base.py ADDED
@@ -0,0 +1,215 @@
1
+ from dataclasses import dataclass
2
+
3
+ import os
4
+ import copy
5
+ import json
6
+ from omegaconf import OmegaConf
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.utils import (
13
+ extract_commit_hash,
14
+ )
15
+
16
+ from step1x3d_geometry.utils.config import parse_structured
17
+ from step1x3d_geometry.utils.misc import get_device, load_module_weights
18
+ from step1x3d_geometry.utils.typing import *
19
+
20
+
21
+ class Configurable:
22
+ @dataclass
23
+ class Config:
24
+ pass
25
+
26
+ def __init__(self, cfg: Optional[dict] = None) -> None:
27
+ super().__init__()
28
+ self.cfg = parse_structured(self.Config, cfg)
29
+
30
+
31
+ class Updateable:
32
+ def do_update_step(
33
+ self, epoch: int, global_step: int, on_load_weights: bool = False
34
+ ):
35
+ for attr in self.__dir__():
36
+ if attr.startswith("_"):
37
+ continue
38
+ try:
39
+ module = getattr(self, attr)
40
+ except:
41
+ continue # ignore attributes like properties, which can't be retrieved using getattr
42
+ if isinstance(module, Updateable):
43
+ module.do_update_step(
44
+ epoch, global_step, on_load_weights=on_load_weights
45
+ )
46
+ self.update_step(epoch, global_step, on_load_weights=on_load_weights)
47
+
48
+ def do_update_step_end(self, epoch: int, global_step: int):
49
+ for attr in self.__dir__():
50
+ if attr.startswith("_"):
51
+ continue
52
+ try:
53
+ module = getattr(self, attr)
54
+ except:
55
+ continue # ignore attributes like properties, which can't be retrieved using getattr
56
+ if isinstance(module, Updateable):
57
+ module.do_update_step_end(epoch, global_step)
58
+ self.update_step_end(epoch, global_step)
59
+
60
+ def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
61
+ # override this method to implement custom update logic
62
+ # if on_load_weights is True, you should be careful doing things related to model evaluations,
63
+ # as the models and tensors are not guaranteed to be on the same device
64
+ pass
65
+
66
+ def update_step_end(self, epoch: int, global_step: int):
67
+ pass
68
+
69
+
70
+ def update_if_possible(module: Any, epoch: int, global_step: int) -> None:
71
+ if isinstance(module, Updateable):
72
+ module.do_update_step(epoch, global_step)
73
+
74
+
75
+ def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None:
76
+ if isinstance(module, Updateable):
77
+ module.do_update_step_end(epoch, global_step)
78
+
79
+
80
+ class BaseObject(Updateable):
81
+ @dataclass
82
+ class Config:
83
+ pass
84
+
85
+ cfg: Config # add this to every subclass of BaseObject to enable static type checking
86
+
87
+ def __init__(
88
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
89
+ ) -> None:
90
+ super().__init__()
91
+ self.cfg = parse_structured(self.Config, cfg)
92
+ self.device = get_device()
93
+ self.configure(*args, **kwargs)
94
+
95
+ def configure(self, *args, **kwargs) -> None:
96
+ pass
97
+
98
+
99
+ class BaseModule(ModelMixin, Updateable, nn.Module):
100
+ @dataclass
101
+ class Config:
102
+ weights: Optional[str] = None
103
+
104
+ cfg: Config # add this to every subclass of BaseModule to enable static type checking
105
+ config_name = "config.json"
106
+
107
+ def __init__(
108
+ self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs
109
+ ) -> None:
110
+ super().__init__()
111
+ self.cfg = parse_structured(self.Config, cfg)
112
+ # self.device = get_device()
113
+ self.configure(*args, **kwargs)
114
+ if self.cfg.weights is not None:
115
+ # format: path/to/weights:module_name
116
+ weights_path, module_name = self.cfg.weights.split(":")
117
+ state_dict, epoch, global_step = load_module_weights(
118
+ weights_path, module_name=module_name, map_location="cpu"
119
+ )
120
+ self.load_state_dict(state_dict)
121
+ self.do_update_step(
122
+ epoch, global_step, on_load_weights=True
123
+ ) # restore states
124
+ # dummy tensor to indicate model state
125
+ self._dummy: Float[Tensor, "..."]
126
+ self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False)
127
+
128
+ def configure(self, *args, **kwargs) -> None:
129
+ pass
130
+
131
+ @classmethod
132
+ def load_config(
133
+ cls,
134
+ pretrained_model_name_or_path: Union[str, os.PathLike],
135
+ return_unused_kwargs=False,
136
+ return_commit_hash=False,
137
+ **kwargs,
138
+ ):
139
+ subfolder = kwargs.pop("subfolder", None)
140
+
141
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
142
+ if os.path.isfile(pretrained_model_name_or_path):
143
+ config_file = pretrained_model_name_or_path
144
+ elif os.path.isdir(pretrained_model_name_or_path):
145
+ if subfolder is not None and os.path.isfile(
146
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
147
+ ):
148
+ config_file = os.path.join(
149
+ pretrained_model_name_or_path, subfolder, cls.config_name
150
+ )
151
+ elif os.path.isfile(
152
+ os.path.join(pretrained_model_name_or_path, cls.config_name)
153
+ ):
154
+ # Load from a PyTorch checkpoint
155
+ config_file = os.path.join(
156
+ pretrained_model_name_or_path, cls.config_name
157
+ )
158
+ else:
159
+ raise EnvironmentError(
160
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
161
+ )
162
+ else:
163
+ raise ValueError(
f"{pretrained_model_name_or_path} is neither a local file nor a local directory."
+ )
164
+
165
+ config_dict = json.load(open(config_file, "r"))
166
+ commit_hash = extract_commit_hash(config_file)
167
+
168
+ outputs = (config_dict,)
169
+
170
+ if return_unused_kwargs:
171
+ outputs += (kwargs,)
172
+
173
+ if return_commit_hash:
174
+ outputs += (commit_hash,)
175
+
176
+ return outputs
177
+
178
+ @classmethod
179
+ def from_config(cls, config: Dict[str, Any] = None, **kwargs):
180
+ model = cls(config)
181
+ return model
182
+
183
+ def register_to_config(self, **kwargs):
184
+ pass
185
+
186
+ def save_config(self, save_directory: Union[str, os.PathLike], **kwargs):
187
+ """
188
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
189
+ [`~ConfigMixin.from_config`] class method.
190
+
191
+ Args:
192
+ save_directory (`str` or `os.PathLike`):
193
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
194
+ kwargs (`Dict[str, Any]`, *optional*):
195
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
196
+ """
197
+ if os.path.isfile(save_directory):
198
+ raise AssertionError(
199
+ f"Provided path ({save_directory}) should be a directory, not a file"
200
+ )
201
+
202
+ os.makedirs(save_directory, exist_ok=True)
203
+
204
+ # If we save using the predefined names, we can load using `from_config`
205
+ output_config_file = os.path.join(save_directory, self.config_name)
206
+
207
+ config_dict = OmegaConf.to_container(self.cfg, resolve=True)
208
+ for k in copy.deepcopy(config_dict).keys():
209
+ if k.startswith("pretrained"):
210
+ config_dict.pop(k)
211
+ config_dict.pop("weights")
212
+ with open(output_config_file, "w", encoding="utf-8") as f:
213
+ json.dump(config_dict, f, ensure_ascii=False, indent=4)
214
+
215
+ print(f"Configuration saved in {output_config_file}")
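An illustrative sketch of the Config-dataclass pattern that BaseObject / BaseModule subclasses in this repo follow. The class name and field are made up, and it assumes the step1x3d_geometry package (and its parse_structured helper) is importable; it is a sketch, not a definitive API example.

from dataclasses import dataclass

from step1x3d_geometry.utils.base import BaseObject


class ToyComponent(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        hidden_dim: int = 64   # hypothetical field

    cfg: Config

    def configure(self) -> None:
        # cfg is parsed from a dict/DictConfig by BaseObject.__init__
        self.hidden_dim = self.cfg.hidden_dim


component = ToyComponent({"hidden_dim": 128})  # assumes parse_structured accepts a plain dict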