KangLiao committed on
Commit ace9173 · 1 Parent(s): 12b117e
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +35 -0
  2. app.py +227 -0
  3. configs/models/qwen2_5_1_5b_radio_sd3_dynamic_puffin.py +76 -0
  4. configs/pipelines/stage_2_base.py +10 -0
  5. configs/pipelines/stage_3_thinking.py +11 -0
  6. configs/pipelines/stage_4_instruction_tuning.py +9 -0
  7. requirements.txt +48 -0
  8. scripts/camera/cam_dataset.py +107 -0
  9. scripts/camera/geometry/__init__.py +0 -0
  10. scripts/camera/geometry/base_camera.py +518 -0
  11. scripts/camera/geometry/camera.py +281 -0
  12. scripts/camera/geometry/gravity.py +129 -0
  13. scripts/camera/geometry/jacobians.py +63 -0
  14. scripts/camera/geometry/manifolds.py +113 -0
  15. scripts/camera/geometry/perspective_fields.py +379 -0
  16. scripts/camera/utils/conversions.py +150 -0
  17. scripts/camera/utils/image.py +182 -0
  18. scripts/camera/utils/tensor.py +249 -0
  19. scripts/camera/utils/text.py +47 -0
  20. scripts/camera/visualization/visualize_batch.py +188 -0
  21. scripts/camera/visualization/viz2d.py +521 -0
  22. src/datasets/utils.py +162 -0
  23. src/models/connector/__init__.py +2 -0
  24. src/models/connector/configuration_connector.py +27 -0
  25. src/models/connector/modeling_connector.py +507 -0
  26. src/models/connector/modeling_qwen2.py +50 -0
  27. src/models/puffin/model.py +790 -0
  28. src/models/radiov3/adaptor_base.py +37 -0
  29. src/models/radiov3/adaptor_generic.py +69 -0
  30. src/models/radiov3/adaptor_mlp.py +174 -0
  31. src/models/radiov3/adaptor_registry.py +37 -0
  32. src/models/radiov3/cls_token.py +59 -0
  33. src/models/radiov3/common.py +134 -0
  34. src/models/radiov3/dinov2_arch.py +1016 -0
  35. src/models/radiov3/dual_hybrid_vit.py +213 -0
  36. src/models/radiov3/enable_cpe_support.py +170 -0
  37. src/models/radiov3/enable_spectral_reparam.py +277 -0
  38. src/models/radiov3/eradio_model.py +1392 -0
  39. src/models/radiov3/extra_models.py +206 -0
  40. src/models/radiov3/extra_timm_models.py +206 -0
  41. src/models/radiov3/feature_normalizer.py +111 -0
  42. src/models/radiov3/forward_intermediates.py +138 -0
  43. src/models/radiov3/hf_model.py +202 -0
  44. src/models/radiov3/input_conditioner.py +49 -0
  45. src/models/radiov3/open_clip_adaptor.py +41 -0
  46. src/models/radiov3/radio_model.py +344 -0
  47. src/models/radiov3/vit_patch_generator.py +288 -0
  48. src/models/radiov3/vitdet.py +188 -0
  49. src/models/stable_diffusion3/pipeline_stable_diffusion_3.py +1256 -0
  50. src/models/stable_diffusion3/pipeline_stable_diffusion_3_dynamic.py +1257 -0
LICENSE ADDED
@@ -0,0 +1,35 @@
S-Lab License 1.0

Copyright 2025 S-Lab

Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
app.py ADDED
@@ -0,0 +1,227 @@
import math

import gradio as gr
import numpy as np
import spaces  # required for ZeroGPU compatibility
import torch
from PIL import Image
from einops import rearrange
from mmengine.config import Config
from xtuner.registry import BUILDER
from xtuner.model.utils import guess_load_checkpoint

from scripts.camera.cam_dataset import Cam_Generator

# Load model
config = "configs/pipelines/stage_2_base.py"
config = Config.fromfile(config)
model = BUILDER.build(config.model).eval()
checkpoint_path = "checkpoints/Puffin-Base.pth"
state_dict = guess_load_checkpoint(checkpoint_path)
model.load_state_dict(state_dict, strict=False)

if torch.cuda.is_available():
    model = model.to(torch.bfloat16).cuda()
else:
    model = model.to(torch.float32)


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


# Multimodal understanding function (currently a placeholder: the pipeline below is disabled).
@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # set seed
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # torch.cuda.manual_seed(seed)
    print(torch.cuda.is_available())

    max_new_tokens = 512
    image_size = 512
    '''
    assert image_size == 512
    image = Image.fromarray(image).convert('RGB')
    image = expand2square(
        image, (127, 127, 127))
    image = image.resize(size=(image_size, image_size))
    image = torch.from_numpy(np.array(image)).to(dtype=model.dtype, device=model.device)
    image = rearrange(image, 'h w c -> c h w')[None]
    image = 2 * (image / 255) - 1

    prompt = PROMPT_TEMPLATE['INSTRUCTION'].format(input="<image>\n" + question)
    assert '<image>' in prompt
    image_length = (image_size // 16) ** 2 + harmon_model.mar.buffer_size
    prompt = prompt.replace('<image>', '<image>' * image_length)
    input_ids = harmon_tokenizer.encode(
        prompt, add_special_tokens=True, return_tensors='pt').to(harmon_model.device)
    _, z_enc = harmon_model.extract_visual_feature(harmon_model.encode(image))
    inputs_embeds = z_enc.new_zeros(*input_ids.shape, harmon_model.llm.config.hidden_size)
    inputs_embeds[input_ids == image_token_idx] = z_enc.flatten(0, 1)
    inputs_embeds[input_ids != image_token_idx] = harmon_model.llm.get_input_embeddings()(
        input_ids[input_ids != image_token_idx]
    )
    output = harmon_model.llm.generate(inputs_embeds=inputs_embeds,
                                       eos_token_id=harmon_tokenizer.eos_token_id,
                                       pad_token_id=harmon_tokenizer.pad_token_id
                                       if harmon_tokenizer.pad_token_id is not None else
                                       harmon_tokenizer.eos_token_id,
                                       max_new_tokens=max_new_tokens,
                                       do_sample=False,  # if temperature == 0 else True,
                                       use_cache=True,
                                       # temperature=temperature,
                                       # top_p=top_p
                                       )
    '''
    return 1  # harmon_tokenizer.decode(output[0], skip_special_tokens=True)


@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt_scene,
                   seed=42,
                   roll=3,
                   pitch=1.0,
                   fov=1.0,
                   progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    # if seed is not None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print(torch.cuda.is_available())

    generator = torch.Generator().manual_seed(seed)
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)

    prompt = prompt_scene + " " + prompt_camera

    bsz = 4
    with torch.no_grad():
        images, output_reasoning = model.generate(
            prompt=[prompt] * bsz,
            cfg_prompt=[""] * bsz,
            pixel_values_init=None,
            cfg_scale=4.5,
            num_steps=50,
            cam_values=[[cam_map]] * bsz,
            progress_bar=False,
            reasoning=False,
            prompt_reasoning=[""] * bsz,
            generator=generator,
            height=512,
            width=512
        )

    images = rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    ret_images = [Image.fromarray(image) for image in images]
    return ret_images


# Gradio interface
css = '''
.gradio-container {max-width: 960px !important}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")

        prompt_input = gr.Textbox(label="Prompt")

        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
                seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)

        generation_button = gr.Button("Generate Images")

        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn, light-yellow walled room with herringbone terracotta floors and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Multimodal Understanding"):
        gr.Markdown(value="## Multimodal Understanding")
        image_input = gr.Image()
        with gr.Column():
            question_input = gr.Textbox(label="Question")

        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")

        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")

        examples_inpainting = gr.Examples(
            label="Multimodal Understanding examples",
            examples=[
                [
                    "Is the picture taken in winter?",
                    "view.jpg",
                ],
                [
                    "Briefly describe the image.",
                    "view.jpg",
                ],
            ],
            inputs=[question_input, image_input],
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output
    )

    understanding_button.click(
        multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output
    )

demo.launch(share=True)
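
Headless usage sketch (not part of the committed file): the same generation path that generate_image() drives from the UI, following the calls made above. The checkpoint path, prompt, and camera values are placeholders.

# Exercise the camera-controllable generation path without launching Gradio.
import math
import torch
from PIL import Image
from einops import rearrange
from mmengine.config import Config
from xtuner.registry import BUILDER
from xtuner.model.utils import guess_load_checkpoint
from scripts.camera.cam_dataset import Cam_Generator

cfg = Config.fromfile("configs/pipelines/stage_2_base.py")
model = BUILDER.build(cfg.model).eval()
model.load_state_dict(guess_load_checkpoint("checkpoints/Puffin-Base.pth"), strict=False)
model = model.to(torch.bfloat16).cuda() if torch.cuda.is_available() else model.float()

roll, pitch, fov = 0.1, -0.1, 1.5  # radians, within the slider ranges used in the UI
prompt_camera = ("The camera parameters (roll, pitch, and field-of-view) are: "
                 f"{roll:.4f}, {pitch:.4f}, {fov:.4f}.")
cam_map = Cam_Generator().get_cam(prompt_camera).to(model.device) / (math.pi / 2)

with torch.inference_mode():
    images, _ = model.generate(
        prompt=["A grand, historic castle under a clear blue sky. " + prompt_camera],
        cfg_prompt=[""], pixel_values_init=None, cfg_scale=4.5, num_steps=50,
        cam_values=[[cam_map]], progress_bar=False, reasoning=False,
        prompt_reasoning=[""], generator=torch.Generator().manual_seed(42),
        height=512, width=512,
    )

# Same post-processing as the demo: map [-1, 1] outputs back to uint8 pixels.
images = rearrange(images, 'b c h w -> b h w c')
images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
Image.fromarray(images[0]).save("sample.png")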
configs/models/qwen2_5_1_5b_radio_sd3_dynamic_puffin.py ADDED
@@ -0,0 +1,76 @@
import torch
from src.models.puffin.model import Qwen2p5RadioStableDiffusion3HFDynamic
from src.models.stable_diffusion3.transformer_sd3_dynamic import SD3Transformer2DModel
from src.models.radiov3.hf_model import RADIOModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoModelForCausalLM, AutoTokenizer

llm_name_or_path = 'Qwen/Qwen2.5-1.5B-Instruct'
sd3_model_name_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"

prompt_template = dict(
    SYSTEM=('<|im_start|>system\n{system}<|im_end|>\n'),
    INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
                 '<|im_start|>assistant\n'),
    SUFFIX='<|im_end|>',
    IMG_START_TOKEN='<|vision_start|>',
    IMG_END_TOKEN='<|vision_end|>',
    IMG_CONTEXT_TOKEN='<|image_pad|>',
    GENERATION='Generate an image: {input}',
    GENERATION_CROSS='Generate a target image given an initial view: {input}',
    SUFFIX_AS_EOS=True,
    SEP='\n',
    STOP_WORDS=['<|im_end|>', '<|endoftext|>']
)

model = dict(type=Qwen2p5RadioStableDiffusion3HFDynamic,
             num_queries=64,
             connector_1=dict(
                 hidden_size=1024,
                 intermediate_size=4096,
                 num_hidden_layers=6,
                 _attn_implementation='flash_attention_2',
                 num_attention_heads=16, ),
             connector_2=dict(
                 hidden_size=1024,
                 intermediate_size=4096,
                 num_hidden_layers=6,
                 _attn_implementation='flash_attention_2',
                 num_attention_heads=16, ),
             transformer=dict(
                 type=SD3Transformer2DModel.from_pretrained,
                 pretrained_model_name_or_path=sd3_model_name_or_path,
                 subfolder="transformer",
                 torch_dtype=torch.bfloat16),
             test_scheduler=dict(
                 type=FlowMatchEulerDiscreteScheduler.from_pretrained,
                 pretrained_model_name_or_path=sd3_model_name_or_path,
                 subfolder="scheduler"),
             train_scheduler=dict(
                 type=FlowMatchEulerDiscreteScheduler.from_pretrained,
                 pretrained_model_name_or_path=sd3_model_name_or_path,
                 subfolder="scheduler"),
             vae=dict(
                 type=AutoencoderKL.from_pretrained,
                 pretrained_model_name_or_path=sd3_model_name_or_path,
                 subfolder="vae",
                 torch_dtype=torch.bfloat16),
             freeze_visual_encoder=True,
             freeze_llm=True,
             llm=dict(
                 type=AutoModelForCausalLM.from_pretrained,
                 pretrained_model_name_or_path=llm_name_or_path,
                 torch_dtype=torch.bfloat16,
                 attn_implementation='flash_attention_2',
             ),
             tokenizer=dict(
                 type=AutoTokenizer.from_pretrained,
                 pretrained_model_name_or_path=llm_name_or_path),
             prompt_template=prompt_template,
             pretrained_pth=None,
             use_activation_checkpointing=False,
             visual_encoder=dict(
                 type=RADIOModel.from_pretrained,
                 pretrained_model_name_or_path="nvidia/C-RADIOv3-H",
                 torch_dtype=torch.bfloat16,),
             )
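
The config follows the mmengine/xtuner lazy-build convention: each dict that carries a `type` key names a callable, and the remaining keys become its keyword arguments when the registry builds the model (so `type=AutoencoderKL.from_pretrained` loads the SD3 VAE at build time). A minimal sketch of that pattern, simplified for illustration rather than xtuner's actual BUILDER implementation:

# Illustrative sketch of how a `type=...` dict resolves at build time (simplified, not xtuner's code).
def build_from_cfg(cfg):
    if isinstance(cfg, dict) and "type" in cfg:
        cfg = dict(cfg)
        factory = cfg.pop("type")                  # e.g. AutoencoderKL.from_pretrained
        kwargs = {k: build_from_cfg(v) for k, v in cfg.items()}
        return factory(**kwargs)                   # nested dicts built the same way
    return cfg

# e.g. build_from_cfg(model["vae"]) would call
# AutoencoderKL.from_pretrained(pretrained_model_name_or_path=..., subfolder="vae", torch_dtype=torch.bfloat16)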
configs/pipelines/stage_2_base.py ADDED
@@ -0,0 +1,10 @@
from mmengine.config import read_base

with read_base():
    from ..models.qwen2_5_1_5b_radio_sd3_dynamic_puffin import model

model.freeze_visual_encoder = False
model.freeze_llm = False
model.freeze_transformer = False
model.use_activation_checkpointing = True
model.visual_encoder_grad_scale = 0.1
configs/pipelines/stage_3_thinking.py ADDED
@@ -0,0 +1,11 @@
from mmengine.config import read_base

with read_base():
    from ..models.qwen2_5_1_5b_radio_sd3_dynamic_puffin import model

model.freeze_visual_encoder = False
model.freeze_llm = False
model.freeze_transformer = False
model.use_activation_checkpointing = True
model.visual_encoder_grad_scale = 0.1
# model.pretrained_pth = 'work_dirs/stage_2_base/iter_30000.pth'
configs/pipelines/stage_4_instruction_tuning.py ADDED
@@ -0,0 +1,9 @@
from mmengine.config import read_base

with read_base():
    from ..models.qwen2_5_1_5b_radio_sd3_dynamic_puffin import model

model.freeze_visual_encoder = True
model.freeze_llm = False
model.freeze_transformer = False
model.use_activation_checkpointing = True
model.unconditional_cross_view = 0.1
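
Each pipeline config only overrides flags on the shared model dict imported through `read_base()`, so loading a stage config yields the full model spec with that stage's freezing policy applied. A small check, assuming it is run from the repository root so the relative config import resolves:

# Sketch: confirm a stage config carries the base model spec plus its stage-specific overrides.
from mmengine.config import Config

cfg = Config.fromfile("configs/pipelines/stage_4_instruction_tuning.py")
print(cfg.model.type)                      # the Qwen2p5RadioStableDiffusion3HFDynamic class
print(cfg.model.freeze_visual_encoder)     # True  (stage 4 freezes the visual encoder)
print(cfg.model.freeze_llm)                # False
print(cfg.model.unconditional_cross_view)  # 0.1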
requirements.txt ADDED
@@ -0,0 +1,48 @@
accelerate==1.6.0
click==8.1.8
decorator==5.2.1
deepspeed==0.16.7
diffusers==0.34.0
einops==0.8.1
feedparser==6.0.11
flash_attn==2.3.4
huggingface-hub==0.31.1
hyperframe==6.1.0
idna==3.10
imageio==2.37.0
importlib_metadata==8.7.0
json5==0.12.0
lazy_loader==0.4
lightning-utilities==0.14.3
matplotlib==3.10.1
matplotlib-inline==0.1.7
mmengine==0.10.7
networkx==3.4.2
ninja==1.11.1.4
numpy==2.2.5
opencv-python==4.11.0.86
opencv-python-headless==4.12.0.88
openpyxl==3.1.5
pandas==2.2.3
peft==0.15.2
pillow==11.2.1
pytz==2025.2
PyYAML==6.0.2
safetensors==0.5.3
scikit-image==0.25.2
scipy==1.15.2
six==1.17.0
timm==0.9.12
tokenizers==0.21.2
torch==2.7.0
torch-fidelity==0.3.0
torchmetrics==1.7.2
torchvision==0.22.0
tornado==6.4.2
tqdm==4.67.1
transformers==4.49.0
transformers-stream-generator==0.0.5
triton==3.3.0
xtuner==0.1.23
yarl==1.20.0
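
Several of these pins interact: the model config requests `attn_implementation='flash_attention_2'`, so `flash_attn` must import cleanly against the pinned `torch==2.7.0` and a CUDA build. A small pre-flight check, offered only as a sketch:

# Pre-flight sketch: verify a few pinned packages before launching app.py.
import importlib.metadata as md
import torch

for pkg, expected in [("torch", "2.7.0"), ("transformers", "4.49.0"),
                      ("diffusers", "0.34.0"), ("xtuner", "0.1.23")]:
    print(pkg, md.version(pkg), "(expected", expected + ")")

print("CUDA available:", torch.cuda.is_available())
try:
    import flash_attn  # needed because the config sets attn_implementation='flash_attention_2'
    print("flash_attn", flash_attn.__version__)
except ImportError as err:
    print("flash_attn missing:", err)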
scripts/camera/cam_dataset.py ADDED
@@ -0,0 +1,107 @@
import os
import json
import argparse
import torch
from tqdm import tqdm

from scripts.camera.geometry.camera import SimpleRadial
from scripts.camera.geometry.gravity import Gravity
from scripts.camera.geometry.perspective_fields import get_perspective_field
from scripts.camera.utils.conversions import fov2focal
from scripts.camera.utils.text import parse_camera_params


class Cam_Generator:
    def __init__(self, mode="base"):
        self.mode = mode

    def _load_text(self, caption, h=512, w=512, k1=0, k2=0):
        # Parse camera params from caption
        roll, pitch, vfov = parse_camera_params(caption, self.mode)

        # Convert vertical FoV to focal length
        f = fov2focal(torch.tensor(vfov), h)
        px, py = w / 2, h / 2
        params = torch.tensor([w, h, f, f, px, py, k1, k2]).float()
        gravity = torch.tensor([roll, pitch]).float()
        return params, gravity

    def _read_param(self, parameters, gravity):
        # Build camera and gravity objects
        camera = SimpleRadial(parameters).float()
        roll, pitch = gravity.unbind(-1)
        gravity_obj = Gravity.from_rp(roll, pitch)
        camera = camera.scale(torch.Tensor([1, 1]))
        return {"camera": camera, "gravity": gravity_obj}

    def _get_perspective(self, data):
        # Generate up and latitude fields
        camera = data["camera"]
        gravity_obj = data["gravity"]
        up_field, lat_field = get_perspective_field(
            camera, gravity_obj, use_up=True, use_latitude=True
        )
        del camera, gravity_obj
        return torch.cat([up_field[0], lat_field[0]], dim=0)

    def get_cam(self, caption):
        params, gravity = self._load_text(caption)
        data = self._read_param(params, gravity)
        return self._get_perspective(data)


def process_folders(input_root, output_root, start_idx=0, num_folders=None, mode="base"):
    gen = Cam_Generator(mode=mode)
    all_dirs = sorted([
        d for d in os.listdir(input_root)
        if os.path.isdir(os.path.join(input_root, d))
    ])
    if num_folders is None:
        num_folders = len(all_dirs) - start_idx
    selected = all_dirs[start_idx:start_idx + num_folders]

    for sub in tqdm(selected, desc="Subfolders"):
        in_sub = os.path.join(input_root, sub)
        out_sub = os.path.join(output_root, sub)
        os.makedirs(out_sub, exist_ok=True)

        json_files = sorted([
            f for f in os.listdir(in_sub)
            if f.lower().endswith('.json')
        ])

        for jf in tqdm(json_files, desc=f"Processing {sub}", leave=False):
            in_path = os.path.join(in_sub, jf)
            with open(in_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            caption = data.get('caption', '')
            cam = gen.get_cam(caption)
            out_name = os.path.splitext(jf)[0] + '.pt'
            out_path = os.path.join(out_sub, out_name)
            torch.save(cam, out_path)


def main():
    parser = argparse.ArgumentParser(
        description="Batch-convert captions into camera maps and save them as .pt files"
    )
    parser.add_argument('--input_root', type=str,
                        help='Root directory of JSON subfolders')
    parser.add_argument('--output_root', type=str,
                        help='Root directory to save .pt files')
    parser.add_argument('--start_idx', type=int, default=0,
                        help='Start index of subfolders (0-based, default=0)')
    parser.add_argument('--num_folders', type=int, default=None,
                        help='Number of subfolders to process (default: all)')
    parser.add_argument('--mode', type=str, default='base',
                        help='parse_camera_params mode')
    args = parser.parse_args()

    process_folders(
        args.input_root,
        args.output_root,
        start_idx=args.start_idx,
        num_folders=args.num_folders,
        mode=args.mode
    )


if __name__ == '__main__':
    main()
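
Cam_Generator turns the camera sentence of a caption into a pixel-wise perspective field (an up-vector field concatenated with a latitude field) that app.py later normalizes by π/2 before passing it to the model. A usage sketch; the caption string mirrors the format app.py constructs, and the expected map shape (3 × 512 × 512 for the default image size) is stated as an assumption:

# Sketch: build a camera map from a camera caption and inspect it.
import math
import torch
from scripts.camera.cam_dataset import Cam_Generator

caption = "The camera parameters (roll, pitch, and field-of-view) are: 0.1000, -0.1000, 1.5000."

gen = Cam_Generator(mode="base")
cam = gen.get_cam(caption)        # expected: (3, 512, 512) = 2-channel up field + 1-channel latitude
print(cam.shape, cam.dtype)

cam = cam / (math.pi / 2)         # same normalization app.py applies before model.generate
print(cam.min().item(), cam.max().item())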
scripts/camera/geometry/__init__.py ADDED
File without changes
scripts/camera/geometry/base_camera.py ADDED
@@ -0,0 +1,518 @@
1
+ """Convenience classes for camera models.
2
+
3
+ Based on PyTorch tensors: differentiable, batched, with GPU support.
4
+ Adapted from https://github.com/cvg/GeoCalib
5
+ """
6
+
7
+ from abc import abstractmethod
8
+ from typing import Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch.func import jacfwd, vmap
12
+ from torch.nn import functional as F
13
+
14
+ from scripts.camera.geometry.gravity import Gravity
15
+ from scripts.camera.utils.conversions import deg2rad, focal2fov, fov2focal, rad2rotmat
16
+ from scripts.camera.utils.tensor import TensorWrapper, autocast
17
+
18
+ # mypy: ignore-errors
19
+
20
+
21
+ class BaseCamera(TensorWrapper):
22
+ """Camera tensor class."""
23
+
24
+ eps = 1e-3
25
+
26
+ @autocast
27
+ def __init__(self, data: torch.Tensor):
28
+ """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}).
29
+
30
+ Tensor convention: (..., {w, h, fx, fy, cx, cy, pitch, roll, *dist}) where
31
+ - w, h: image size in pixels
32
+ - fx, fy: focal lengths in pixels
33
+ - cx, cy: principal points in normalized image coordinates
34
+ - dist: distortion parameters
35
+
36
+ Args:
37
+ data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}).
38
+ """
39
+ # w, h, fx, fy, cx, cy, dist
40
+ assert data.shape[-1] in {6, 7, 8}, data.shape
41
+
42
+ pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],))
43
+ data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data
44
+ super().__init__(data)
45
+
46
+ @classmethod
47
+ def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera":
48
+ """Create a Camera object from a dictionary of parameters.
49
+
50
+ Args:
51
+ param_dict (Dict[str, torch.Tensor]): Dictionary of parameters.
52
+
53
+ Returns:
54
+ Camera: Camera object.
55
+ """
56
+ for key, value in param_dict.items():
57
+ if not isinstance(value, torch.Tensor):
58
+ param_dict[key] = torch.tensor(value)
59
+
60
+ h, w = param_dict["height"], param_dict["width"]
61
+ cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2)
62
+
63
+ vfov = param_dict.get("vfov")
64
+ f = param_dict.get("f", fov2focal(vfov, h))
65
+
66
+ if "dist" in param_dict:
67
+ k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1]
68
+ elif "k1_hat" in param_dict:
69
+ k1 = param_dict["k1_hat"] * (f / h) ** 2
70
+
71
+ k2 = param_dict.get("k2", torch.zeros_like(k1))
72
+ else:
73
+ k1 = param_dict.get("k1", torch.zeros_like(f))
74
+ k2 = param_dict.get("k2", torch.zeros_like(f))
75
+
76
+ fx, fy = f, f
77
+ if "scales" in param_dict:
78
+ scales = param_dict["scales"]
79
+ fx = fx * scales[..., 0] / scales[..., 1]
80
+
81
+ params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1)
82
+ return cls(params)
83
+
84
+ def pinhole(self):
85
+ """Return the pinhole camera model."""
86
+ return self.__class__(self._data[..., :6])
87
+
88
+ @property
89
+ def size(self) -> torch.Tensor:
90
+ """Size (width height) of the images, with shape (..., 2)."""
91
+ return self._data[..., :2]
92
+
93
+ @property
94
+ def f(self) -> torch.Tensor:
95
+ """Focal lengths (fx, fy) with shape (..., 2)."""
96
+ return self._data[..., 2:4]
97
+
98
+ @property
99
+ def vfov(self) -> torch.Tensor:
100
+ """Vertical field of view in radians."""
101
+ return focal2fov(self.f[..., 1], self.size[..., 1])
102
+
103
+ @property
104
+ def hfov(self) -> torch.Tensor:
105
+ """Horizontal field of view in radians."""
106
+ return focal2fov(self.f[..., 0], self.size[..., 0])
107
+
108
+ @property
109
+ def c(self) -> torch.Tensor:
110
+ """Principal points (cx, cy) with shape (..., 2)."""
111
+ return self._data[..., 4:6]
112
+
113
+ @property
114
+ def K(self) -> torch.Tensor:
115
+ """Returns the self intrinsic matrix with shape (..., 3, 3)."""
116
+ shape = self.shape + (3, 3)
117
+ K = self._data.new_zeros(shape)
118
+ K[..., 0, 0] = self.f[..., 0]
119
+ K[..., 1, 1] = self.f[..., 1]
120
+ K[..., 0, 2] = self.c[..., 0]
121
+ K[..., 1, 2] = self.c[..., 1]
122
+ K[..., 2, 2] = 1
123
+ return K
124
+
125
+ def update_focal(self, delta: torch.Tensor, as_log: bool = False):
126
+ """Update the self parameters after changing the focal length."""
127
+ f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta
128
+
129
+ # clamp focal length to a reasonable range for stability during training
130
+ min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1])
131
+ max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1])
132
+ min_f = min_f.unsqueeze(-1).expand(-1, 2)
133
+ max_f = max_f.unsqueeze(-1).expand(-1, 2)
134
+ f = f.clamp(min=min_f, max=max_f)
135
+
136
+ # make sure the focal ratio stays the same (avoid in-place operations)
137
+ fx = f[..., 1] * self.f[..., 0] / self.f[..., 1]
138
+ f = torch.stack([fx, f[..., 1]], -1)
139
+
140
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
141
+ return self.__class__(torch.cat([self.size, f, self.c, dist], -1))
142
+
143
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
144
+ """Update the self parameters after resizing an image."""
145
+ scales = (scales, scales) if isinstance(scales, (int, float)) else scales
146
+ s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales)
147
+
148
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
149
+ return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1))
150
+
151
+ def crop(self, pad: Tuple[float]):
152
+ """Update the self parameters after cropping an image."""
153
+ pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad)
154
+ size = self.size + pad.to(self.size)
155
+ c = self.c + pad.to(self.c) / 2
156
+
157
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
158
+ return self.__class__(torch.cat([size, self.f, c, dist], -1))
159
+
160
+ def undo_scale_crop(self, data: Dict[str, torch.Tensor]):
161
+ """Undo transforms done during scaling and cropping."""
162
+ camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self
163
+ return camera.scale(1.0 / data["scales"])
164
+
165
+ @autocast
166
+ def in_image(self, p2d: torch.Tensor):
167
+ """Check if 2D points are within the image boundaries."""
168
+ assert p2d.shape[-1] == 2
169
+ size = self.size.unsqueeze(-2)
170
+ return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
171
+
172
+ @autocast
173
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
174
+ """Project 3D points into the self plane and check for visibility."""
175
+ z = p3d[..., -1]
176
+ valid = z > self.eps
177
+ z = z.clamp(min=self.eps)
178
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
179
+ return p2d, valid
180
+
181
+ def J_project(self, p3d: torch.Tensor):
182
+ """Jacobian of the projection function."""
183
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
184
+ zero = torch.zeros_like(z)
185
+ z = z.clamp(min=self.eps)
186
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
187
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
188
+ return J # N x 2 x 3
189
+
190
+ @abstractmethod
191
+ def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
192
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
193
+ raise NotImplementedError("distort() must be implemented.")
194
+
195
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
196
+ """Jacobian of the distortion function."""
197
+ if wrt == "scale2pts": # (..., 2)
198
+ J = [
199
+ vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None]
200
+ for idx in range(p2d.shape[0])
201
+ ]
202
+
203
+ return torch.cat(J, dim=0).squeeze(-3, -2)
204
+
205
+ elif wrt == "scale2dist": # (..., 1)
206
+ J = []
207
+ for idx in range(p2d.shape[0]): # loop to batch pts dimension
208
+
209
+ def func(x):
210
+ params = torch.cat([self._data[idx, :6], x[None]], -1)
211
+ return self.__class__(params).distort(p2d[idx], return_scale=True)[0]
212
+
213
+ J.append(vmap(jacfwd(func))(self[idx].dist))
214
+
215
+ return torch.cat(J, dim=0)
216
+
217
+ else:
218
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
219
+
220
+ @abstractmethod
221
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
222
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
223
+ raise NotImplementedError("undistort() must be implemented.")
224
+
225
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
226
+ """Jacobian of the undistortion function."""
227
+ if wrt == "pts": # (..., 2, 2)
228
+ J = [
229
+ vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None]
230
+ for idx in range(p2d.shape[0])
231
+ ]
232
+
233
+ return torch.cat(J, dim=0).squeeze(-3)
234
+
235
+ elif wrt == "dist": # (..., 1)
236
+ J = []
237
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
238
+
239
+ def func(x):
240
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)
241
+ return self.__class__(params).undistort(p2d[batch_idx])[0]
242
+
243
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
244
+
245
+ return torch.cat(J, dim=0)
246
+ else:
247
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
248
+
249
+ @autocast
250
+ def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor:
251
+ """Compute the offset for the up-projection."""
252
+ return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2)
253
+
254
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
255
+ """Jacobian of the distortion offset for up-projection."""
256
+ if wrt == "uv": # (B, N, 2, 2)
257
+ J = [
258
+ vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None]
259
+ for idx in range(p2d.shape[0])
260
+ ]
261
+
262
+ return torch.cat(J, dim=0)
263
+
264
+ elif wrt == "dist": # (B, N, 2)
265
+ J = []
266
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
267
+
268
+ def func(x):
269
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None]
270
+ return self.__class__(params).up_projection_offset(p2d[batch_idx][None])
271
+
272
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
273
+
274
+ return torch.cat(J, dim=0).squeeze(1)
275
+ else:
276
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
277
+
278
+ @autocast
279
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
280
+ """Convert normalized 2D coordinates into pixel coordinates."""
281
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
282
+
283
+ def J_denormalize(self):
284
+ """Jacobian of the denormalization function."""
285
+ return torch.diag_embed(self.f) # ..., 2 x 2
286
+
287
+ @autocast
288
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
289
+ """Convert pixel coordinates into normalized 2D coordinates."""
290
+ return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2))
291
+
292
+ def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"):
293
+ """Jacobian of the normalization function."""
294
+ # ... x N x 2 x 2
295
+ if wrt == "f":
296
+ J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2)
297
+ return torch.diag_embed(J_f)
298
+ elif wrt == "pts":
299
+ J_pts = 1 / self.f
300
+ return torch.diag_embed(J_pts)
301
+ else:
302
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
303
+
304
+ def pixel_coordinates(self) -> torch.Tensor:
305
+ """Pixel coordinates in self frame.
306
+
307
+ Returns:
308
+ torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2).
309
+ """
310
+ w, h = self.size[0].unbind(-1)
311
+ h, w = h.round().to(int), w.round().to(int)
312
+
313
+ # create grid
314
+ x = torch.arange(0, w, dtype=self.dtype, device=self.device)
315
+ y = torch.arange(0, h, dtype=self.dtype, device=self.device)
316
+ x, y = torch.meshgrid(x, y, indexing="xy")
317
+ xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2)
318
+
319
+ # add batch dimension (normalize() would broadcast but we make it explicit)
320
+ B = self.shape[0]
321
+ xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy
322
+
323
+ return xy.to(self.device).to(self.dtype)
324
+
325
+ def normalized_image_coordinates(self) -> torch.Tensor:
326
+ """Normalized image coordinates in self frame.
327
+
328
+ Returns:
329
+ torch.Tensor: Normalized image coordinates as a tensor of shape (B, h * w, 3).
330
+ """
331
+ xy = self.pixel_coordinates()
332
+ uv1, _ = self.image2world(xy)
333
+
334
+ B = self.shape[0]
335
+ uv1 = uv1.reshape(B, -1, 3)
336
+ return uv1.to(self.device).to(self.dtype)
337
+
338
+ @autocast
339
+ def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor:
340
+ """Get the bearing vectors of pixel coordinates.
341
+
342
+ Args:
343
+ p2d (torch.Tensor): Pixel coordinates as a tensor of shape (..., 3).
344
+
345
+ Returns:
346
+ torch.Tensor: Bearing vectors as a tensor of shape (..., 3).
347
+ """
348
+ return F.normalize(p3d, dim=-1)
349
+
350
+ @autocast
351
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
352
+ """Transform 3D points into 2D pixel coordinates."""
353
+ p2d, visible = self.project(p3d)
354
+ p2d, mask = self.distort(p2d)
355
+ p2d = self.denormalize(p2d)
356
+ valid = visible & mask & self.in_image(p2d)
357
+ return p2d, valid
358
+
359
+ @autocast
360
+ def J_world2image(self, p3d: torch.Tensor):
361
+ """Jacobian of the world2image function."""
362
+ p2d_proj, valid = self.project(p3d)
363
+
364
+ J_dnorm = self.J_denormalize()
365
+ J_dist = self.J_distort(p2d_proj)
366
+ J_proj = self.J_project(p3d)
367
+
368
+ J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj)
369
+ return J, valid
370
+
371
+ @autocast
372
+ def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
373
+ """Transform point in the image plane to 3D world coordinates."""
374
+ p2d = self.normalize(p2d)
375
+ p2d, valid = self.undistort(p2d)
376
+ ones = p2d.new_ones(p2d.shape[:-1] + (1,))
377
+ p3d = torch.cat([p2d, ones], -1)
378
+ return p3d, valid
379
+
380
+ @autocast
381
+ def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]:
382
+ """Jacobian of the image2world function."""
383
+ if wrt == "dist":
384
+ p2d_norm = self.normalize(p2d)
385
+ return self.J_undistort(p2d_norm, wrt)
386
+ elif wrt == "f":
387
+ J_norm2f = self.J_normalize(p2d, wrt)
388
+ p2d_norm = self.normalize(p2d)
389
+ J_dist2norm = self.J_undistort(p2d_norm, "pts")
390
+
391
+ return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f)
392
+ else:
393
+ raise ValueError(f"Unknown wrt: {wrt}")
394
+
395
+ @autocast
396
+ def undistort_image(self, img: torch.Tensor) -> torch.Tensor:
397
+ """Undistort an image using the distortion model."""
398
+ assert self.shape[0] == 1, "Batch size must be 1."
399
+ W, H = self.size.unbind(-1)
400
+ H, W = H.int().item(), W.int().item()
401
+
402
+ x, y = torch.arange(0, W), torch.arange(0, H)
403
+ x, y = torch.meshgrid(x, y, indexing="xy")
404
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
405
+
406
+ p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype))
407
+ p2d, _ = self.world2image(p3d)
408
+
409
+ mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W))
410
+ grid = torch.stack((mapx, mapy), dim=-1)
411
+ grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1
412
+ return F.grid_sample(img, grid, align_corners=True)
413
+
414
+ def get_img_from_pano(
415
+ self,
416
+ pano_img: torch.Tensor,
417
+ gravity: Gravity,
418
+ yaws: torch.Tensor = 0.0,
419
+ resize_factor: Optional[torch.Tensor] = None,
420
+ ) -> torch.Tensor:
421
+ """Render an image from a panorama.
422
+
423
+ Args:
424
+ pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1].
425
+ gravity (Gravity): Gravity direction of the camera.
426
+ yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0.
427
+ resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the
428
+ field of view. Defaults to 1.
429
+
430
+ Returns:
431
+ torch.Tensor: Image rendered from the panorama.
432
+ """
433
+ B = self.shape[0]
434
+ if B > 0:
435
+ assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width."
436
+ assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height."
437
+
438
+ w, h = self.size[0].unbind(-1)
439
+ h, w = h.round().to(int), w.round().to(int)
440
+
441
+ if isinstance(yaws, (int, float)):
442
+ yaws = [yaws]
443
+ if isinstance(resize_factor, (int, float)):
444
+ resize_factor = [resize_factor]
445
+
446
+ yaws = (
447
+ yaws.to(self.dtype).to(self.device)
448
+ if isinstance(yaws, torch.Tensor)
449
+ else self.new_tensor(yaws)
450
+ )
451
+
452
+ if isinstance(resize_factor, torch.Tensor):
453
+ resize_factor = resize_factor.to(self.dtype).to(self.device)
454
+ elif resize_factor is not None:
455
+ resize_factor = self.new_tensor(resize_factor)
456
+
457
+ assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor."
458
+ pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x 3 x H x W
459
+
460
+ pano_imgs = []
461
+ for i, yaw in enumerate(yaws):
462
+ if resize_factor is not None:
463
+ # resize the panorama such that the fov of the panorama has the same height as the
464
+ # image
465
+ vfov = self.vfov[i] if B != 0 else self.vfov
466
+ scale = torch.pi / float(vfov) * float(h) / pano_img.shape[-2] * resize_factor[i]
467
+ pano_shape = (int(pano_img.shape[-2] * scale), int(pano_img.shape[-1] * scale))
468
+
469
+ mode = "bicubic" if scale >= 1 else "area"
470
+ resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode)
471
+ else:
472
+ # make sure to copy: resized_pano = pano_img
473
+ resized_pano = pano_img
474
+ pano_shape = pano_img.shape[-2:][::-1]
475
+
476
+ pano_imgs.append((resized_pano, pano_shape))
477
+
478
+ xy = self.pixel_coordinates()
479
+ uv1, valid = self.image2world(xy)
480
+ bearings = self.pixel_bearing_many(uv1)
481
+
482
+ # rotate bearings
483
+ R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws)
484
+ rotated_bearings = bearings @ gravity.R @ R_yaw
485
+
486
+ # spherical coordinates
487
+ lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2])
488
+ lat = torch.atan2(
489
+ rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1)
490
+ )
491
+
492
+ images = []
493
+ for idx, (resized_pano, pano_shape) in enumerate(pano_imgs):
494
+ min_lon, max_lon = -torch.pi, torch.pi
495
+ min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0
496
+ min_x, max_x = 0, pano_shape[0] - 1.0
497
+ min_y, max_y = 0, pano_shape[1] - 1.0
498
+
499
+ # map Spherical Coordinates to Panoramic Coordinates
500
+ nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x
501
+ ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y
502
+
503
+ # reshape and cast to numpy for remap
504
+ mapx = nx.reshape((1, h, w))
505
+ mapy = ny.reshape((1, h, w))
506
+
507
+ grid = torch.stack((mapx, mapy), dim=-1) # Add batch dimension
508
+ # Normalize to [-1, 1]
509
+ grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1
510
+ # Apply grid sample
511
+ image = F.grid_sample(resized_pano, grid, align_corners=True)#True
512
+ images.append(image)
513
+
514
+ return torch.concatenate(images, 0) if B > 0 else images[0]
515
+
516
+ def __repr__(self):
517
+ """Print the Camera object."""
518
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
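
BaseCamera stores (w, h, fx, fy, cx, cy, dist...) in one batched tensor, so a camera can be built from a parameter dict and used to map pixels to bearings and back. A short sketch using the SimpleRadial subclass defined in the next file; all parameter values are illustrative:

# Sketch: round-trip pixels through image2world / world2image with a SimpleRadial camera.
import torch
from scripts.camera.geometry.camera import SimpleRadial

cam = SimpleRadial.from_dict({
    "height": torch.tensor([512.0]), "width": torch.tensor([512.0]),
    "vfov": torch.tensor([1.0]),     # radians; converted to focal length via fov2focal
    "k1": torch.tensor([-0.05]),     # illustrative radial distortion
})

p2d = torch.tensor([[[256.0, 256.0], [100.0, 400.0]]])   # (B=1, N=2, 2) pixel coordinates
p3d, valid = cam.image2world(p2d)                        # normalized rays with z = 1
p2d_back, visible = cam.world2image(p3d)

print(p3d.shape, valid.all().item())
print((p2d - p2d_back).abs().max())   # small residual: undistort() is a first-order inverse of distort()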
scripts/camera/geometry/camera.py ADDED
@@ -0,0 +1,281 @@
1
+ """Implementation of the pinhole, simple radial, and simple divisional camera models."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+ from scripts.camera.geometry.base_camera import BaseCamera
9
+ from scripts.camera.utils.tensor import autocast
10
+
11
+ # flake8: noqa: E741
12
+
13
+ # mypy: ignore-errors
14
+
15
+
16
+ class Pinhole(BaseCamera):
17
+ """Implementation of the pinhole camera model."""
18
+
19
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
20
+ """Distort normalized 2D coordinates."""
21
+ if return_scale:
22
+ return p2d.new_ones(p2d.shape[:-1] + (1,))
23
+
24
+ return p2d, p2d.new_ones((p2d.shape[0], 1)).bool()
25
+
26
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
27
+ """Jacobian of the distortion function."""
28
+ if wrt == "pts":
29
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
30
+ else:
31
+ raise ValueError(f"Unknown wrt: {wrt}")
32
+
33
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
34
+ """Undistort normalized 2D coordinates."""
35
+ return pts, pts.new_ones((pts.shape[0], 1)).bool()
36
+
37
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
38
+ """Jacobian of the undistortion function."""
39
+ if wrt == "pts":
40
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
41
+ else:
42
+ raise ValueError(f"Unknown wrt: {wrt}")
43
+
44
+
45
+ class SimpleRadial(BaseCamera):
46
+ """Implementation of the simple radial camera model."""
47
+
48
+ @property
49
+ def dist(self) -> torch.Tensor:
50
+ """Distortion parameters, with shape (..., 1)."""
51
+ return self._data[..., 6:]
52
+
53
+ @property
54
+ def k1(self) -> torch.Tensor:
55
+ """Distortion parameters, with shape (...)."""
56
+ return self._data[..., 6]
57
+
58
+ @property
59
+ def k1_hat(self) -> torch.Tensor:
60
+ """Distortion parameters, with shape (...)."""
61
+ return self.k1 / (self.f[..., 1] / self.size[..., 1]) ** 2
62
+
63
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)):
64
+ """Update the self parameters after changing the k1 distortion parameter."""
65
+ delta_dist = self.new_ones(self.dist.shape) * delta
66
+ dist = (self.dist + delta_dist).clamp(*dist_range)
67
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
68
+ return self.__class__(data)
69
+
70
+ @autocast
71
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
72
+ """Check if the distorted points are valid."""
73
+ return p2d.new_ones(p2d.shape[:-1]).bool()
74
+
75
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
76
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
77
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
78
+ radial = 1 + self.k1[..., None, None] * r2
79
+
80
+ if return_scale:
81
+ return radial, None
82
+
83
+ return p2d * radial, self.check_valid(p2d)
84
+
85
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
86
+ """Jacobian of the distortion function."""
87
+ k1 = self.k1[..., None, None]
88
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
89
+ if wrt == "pts": # (..., 2, 2)
90
+ radial = 1 + k1 * r2
91
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
92
+ return (2 * k1 * ppT) + torch.diag_embed(radial.expand(radial.shape[:-1] + (2,)))
93
+ elif wrt == "dist": # (..., 2)
94
+ return r2 * p2d
95
+ elif wrt == "scale2dist": # (..., 1)
96
+ return r2
97
+ elif wrt == "scale2pts": # (..., 2)
98
+ return 2 * k1 * p2d
99
+ else:
100
+ return super().J_distort(p2d, wrt)
101
+
102
+ @autocast
103
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
104
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
105
+ b1 = -self.k1[..., None, None]
106
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
107
+ radial = 1 + b1 * r2
108
+ return p2d * radial, self.check_valid(p2d)
109
+
110
+ @autocast
111
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
112
+ """Jacobian of the undistortion function."""
113
+ b1 = -self.k1[..., None, None]
114
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
115
+ if wrt == "dist":
116
+ return -r2 * p2d
117
+ elif wrt == "pts":
118
+ radial = 1 + b1 * r2
119
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
120
+ return (2 * b1[..., None] * ppT) + torch.diag_embed(
121
+ radial.expand(radial.shape[:-1] + (2,))
122
+ )
123
+ else:
124
+ return super().J_undistort(p2d, wrt)
125
+
126
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
127
+ """Jacobian of the up-projection offset."""
128
+ if wrt == "uv": # (..., 2, 2)
129
+ return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,)))
130
+ elif wrt == "dist":
131
+ return 2 * p2d # (..., 2)
132
+ else:
133
+ return super().J_up_projection_offset(p2d, wrt)
134
+
135
+
136
+ class SimpleDivisional(BaseCamera):
137
+ """Implementation of the simple divisional camera model."""
138
+
139
+ @property
140
+ def dist(self) -> torch.Tensor:
141
+ """Distortion parameters, with shape (..., 1)."""
142
+ return self._data[..., 6:]
143
+
144
+ @property
145
+ def k1(self) -> torch.Tensor:
146
+ """Distortion parameters, with shape (...)."""
147
+ return self._data[..., 6]
148
+
149
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)):
150
+ """Update the self parameters after changing the k1 distortion parameter."""
151
+ delta_dist = self.new_ones(self.dist.shape) * delta
152
+ dist = (self.dist + delta_dist).clamp(*dist_range)
153
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
154
+ return self.__class__(data)
155
+
156
+ @autocast
157
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
158
+ """Check if the distorted points are valid."""
159
+ return p2d.new_ones(p2d.shape[:-1]).bool()
160
+
161
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
162
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
163
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
164
+ radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0))
165
+ denom = 2 * self.k1[..., None, None] * r2
166
+
167
+ ones = radial.new_ones(radial.shape)
168
+ radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6))
169
+
170
+ if return_scale:
171
+ return radial, None
172
+
173
+ return p2d * radial, self.check_valid(p2d)
174
+
175
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
176
+ """Jacobian of the distortion function."""
177
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
178
+ t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6))
179
+ if wrt == "scale2pts": # (B, N, 2)
180
+ d1 = t0 * 2 * r2
181
+ d2 = self.k1[..., None, None] * r2**2
182
+ denom = d1 * d2
183
+ return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
184
+
185
+ elif wrt == "scale2dist":
186
+ d1 = 2 * self.k1[..., None, None] * t0
187
+ d2 = 2 * r2 * self.k1[..., None, None] ** 2
188
+ denom = d1 * d2
189
+ return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
190
+
191
+ else:
192
+ return super().J_distort(p2d, wrt)
193
+
194
+ @autocast
195
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
196
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
197
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
198
+ denom = 1 + self.k1[..., None, None] * r2
199
+ radial = 1 / denom.masked_fill(denom == 0, 1e6)
200
+ return p2d * radial, self.check_valid(p2d)
201
+
202
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
203
+ """Jacobian of the undistortion function."""
204
+ # return super().J_undistort(p2d, wrt)
205
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
206
+ k1 = self.k1[..., None, None]
207
+ if wrt == "dist":
208
+ denom = (1 + k1 * r2) ** 2
209
+ return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d
210
+ elif wrt == "pts":
211
+ t0 = 1 + k1 * r2
212
+ t0 = t0.masked_fill(t0 == 0, 1e6)
213
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
214
+ J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,)))
215
+ return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2)
216
+
217
+ else:
218
+ return super().J_undistort(p2d, wrt)
219
+
220
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
221
+ """Jacobian of the up-projection offset.
222
+
223
+ func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv
224
+ - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv
225
+ """
226
+ k1 = self.k1[..., None, None]
227
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
228
+ t0 = (1 - 4 * k1 * r2).clamp(min=1e-6)
229
+ t1 = torch.sqrt(t0)
230
+ if wrt == "dist":
231
+ denom = 4 * t0 ** (3 / 2)
232
+ denom = denom.masked_fill(denom == 0, 1e6)
233
+ J = 16 / denom
234
+
235
+ denom = r2 * t1 * k1
236
+ denom = denom.masked_fill(denom == 0, 1e6)
237
+ J = J - 2 / denom
238
+
239
+ denom = (r2 * k1) ** 2
240
+ denom = denom.masked_fill(denom == 0, 1e6)
241
+ J = J + (1 - t1) / denom
242
+
243
+ return J * p2d
244
+ elif wrt == "uv":
245
+ # ! unstable (gradient checker might fail), rewrite to use single division (by denom)
246
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
247
+
248
+ denom = 2 * r2 * t1
249
+ denom = denom.masked_fill(denom == 0, 1e6)
250
+ J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,)))
251
+
252
+ denom = 4 * t1 * r2**2
253
+ denom = denom.masked_fill(denom == 0, 1e6)
254
+ J = J - 16 / denom[..., None] * ppT
255
+
256
+ denom = 4 * r2 * t0 ** (3 / 2)
257
+ denom = denom.masked_fill(denom == 0, 1e6)
258
+ J = J + (32 * k1[..., None]) / denom[..., None] * ppT
259
+
260
+ denom = r2**2 * t1
261
+ denom = denom.masked_fill(denom == 0, 1e6)
262
+ J = J - 4 / denom[..., None] * ppT
263
+
264
+ denom = k1 * r2**3
265
+ denom = denom.masked_fill(denom == 0, 1e6)
266
+ J = J + (4 * (1 - t1) / denom)[..., None] * ppT
267
+
268
+ denom = k1 * r2**2
269
+ denom = denom.masked_fill(denom == 0, 1e6)
270
+ J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,)))
271
+
272
+ return J
273
+ else:
274
+ return super().J_up_projection_offset(p2d, wrt)
275
+
276
+
277
+ camera_models = {
278
+ "pinhole": Pinhole,
279
+ "simple_radial": SimpleRadial,
280
+ "simple_divisional": SimpleDivisional,
281
+ }
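
The three models differ only in their distortion term: Pinhole is the identity, SimpleRadial scales by 1 + k1·r², and SimpleDivisional uses a one-parameter division model. A quick comparison sketch on the same normalized points; the camera parameters (512×512 image, f = 500 px, k1 = -0.1) are illustrative:

# Sketch: compare the distortion of the three registered camera models.
import torch
from scripts.camera.geometry.camera import camera_models

pts = torch.tensor([[[0.0, 0.0], [0.2, 0.1], [0.4, -0.3]]])   # (B=1, N=3, 2), normalized coords

for name, cls in camera_models.items():
    # parameters: w, h, fx, fy, cx, cy, k1, k2 (Pinhole ignores the distortion entries)
    cam = cls(torch.tensor([[512.0, 512.0, 500.0, 500.0, 256.0, 256.0, -0.1, 0.0]]))
    distorted, _ = cam.distort(pts)
    print(name, distorted[0, -1].tolist())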
scripts/camera/geometry/gravity.py ADDED
@@ -0,0 +1,129 @@
1
+ """Tensor class for gravity vector in camera frame."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ from scripts.camera.geometry.manifolds import EuclideanManifold, SphericalManifold
8
+ from scripts.camera.utils.conversions import rad2rotmat
9
+ from scripts.camera.utils.tensor import TensorWrapper, autocast
10
+
11
+ # mypy: ignore-errors
12
+
13
+
14
+ class Gravity(TensorWrapper):
15
+ """Gravity vector in camera frame."""
16
+
17
+ eps = 1e-4
18
+
19
+ @autocast
20
+ def __init__(self, data: torch.Tensor) -> None:
21
+ """Create gravity vector from data.
22
+
23
+ Args:
24
+ data (torch.Tensor): gravity vector as 3D vector in camera frame.
25
+ """
26
+ assert data.shape[-1] == 3, data.shape
27
+
28
+ data = F.normalize(data, dim=-1)
29
+
30
+ super().__init__(data)
31
+
32
+ @classmethod
33
+ def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
34
+ """Create gravity vector from roll and pitch angles."""
35
+ if not isinstance(roll, torch.Tensor):
36
+ roll = torch.tensor(roll)
37
+ if not isinstance(pitch, torch.Tensor):
38
+ pitch = torch.tensor(pitch)
39
+
40
+ sr, cr = torch.sin(roll), torch.cos(roll)
41
+ sp, cp = torch.sin(pitch), torch.cos(pitch)
42
+ return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
43
+
44
+ @property
45
+ def vec3d(self) -> torch.Tensor:
46
+ """Return the gravity vector in the representation."""
47
+ return self._data
48
+
49
+ @property
50
+ def x(self) -> torch.Tensor:
51
+ """Return first component of the gravity vector."""
52
+ return self._data[..., 0]
53
+
54
+ @property
55
+ def y(self) -> torch.Tensor:
56
+ """Return second component of the gravity vector."""
57
+ return self._data[..., 1]
58
+
59
+ @property
60
+ def z(self) -> torch.Tensor:
61
+ """Return third component of the gravity vector."""
62
+ return self._data[..., 2]
63
+
64
+ @property
65
+ def roll(self) -> torch.Tensor:
66
+ """Return the roll angle of the gravity vector."""
67
+ roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
68
+ offset = -torch.pi * torch.sign(self.x)
69
+ return torch.where(self.y < 0, roll, -roll + offset)
70
+
71
+ def J_roll(self) -> torch.Tensor:
72
+ """Return the Jacobian of the roll angle of the gravity vector."""
73
+ cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
74
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
75
+ Jr = self.new_zeros(self.shape + (3,))
76
+ Jr[..., 0] = -cr * cp
77
+ Jr[..., 1] = sr * cp
78
+ return Jr
79
+
80
+ @property
81
+ def pitch(self) -> torch.Tensor:
82
+ """Return the pitch angle of the gravity vector."""
83
+ return torch.asin(self.z)
84
+
85
+ def J_pitch(self) -> torch.Tensor:
86
+ """Return the Jacobian of the pitch angle of the gravity vector."""
87
+ cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
88
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
89
+
90
+ Jp = self.new_zeros(self.shape + (3,))
91
+ Jp[..., 0] = sr * sp
92
+ Jp[..., 1] = cr * sp
93
+ Jp[..., 2] = cp
94
+ return Jp
95
+
96
+ @property
97
+ def rp(self) -> torch.Tensor:
98
+ """Return the roll and pitch angles of the gravity vector."""
99
+ return torch.stack([self.roll, self.pitch], dim=-1)
100
+
101
+ def J_rp(self) -> torch.Tensor:
102
+ """Return the Jacobian of the roll and pitch angles of the gravity vector."""
103
+ return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
104
+
105
+ @property
106
+ def R(self) -> torch.Tensor:
107
+ """Return the rotation matrix from the gravity vector."""
108
+ return rad2rotmat(roll=self.roll, pitch=self.pitch)
109
+
110
+ def J_R(self) -> torch.Tensor:
111
+ """Return the Jacobian of the rotation matrix from the gravity vector."""
112
+ raise NotImplementedError
113
+
114
+ def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
115
+ """Update the gravity vector by adding a delta."""
116
+ if spherical:
117
+ data = SphericalManifold.plus(self.vec3d, delta)
118
+ return self.__class__(data)
119
+
120
+ data = EuclideanManifold.plus(self.rp, delta)
121
+ return self.from_rp(data[..., 0], data[..., 1])
122
+
123
+ def J_update(self, spherical: bool = False) -> torch.Tensor:
124
+ """Return the Jacobian of the update."""
125
+ return SphericalManifold if spherical else EuclideanManifold
126
+
127
+ def __repr__(self):
128
+ """Print the Camera object."""
129
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
scripts/camera/geometry/jacobians.py ADDED
@@ -0,0 +1,63 @@
1
+ """Jacobians for optimization."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import torch
5
+
6
+
7
+ @torch.jit.script
8
+ def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
9
+ """Compute the jacobian of vec / norm2(vec).
10
+
11
+ Args:
12
+ vec (torch.Tensor): [..., D] tensor.
13
+
14
+ Returns:
15
+ torch.Tensor: [..., D, D] Jacobian.
16
+ """
17
+ D = vec.shape[-1]
18
+ norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1)
19
+
20
+ if (norm_x == 0).any():
21
+ norm_x = norm_x + 1e-6
22
+
23
+ xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D)
24
+ identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D)
25
+
26
+ return identity / norm_x - (xxT / norm_x**3) # (..., D, D)
27
+
28
+
29
+ @torch.jit.script
30
+ def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
31
+ """Compute the jacobian of the focal2fov function."""
32
+ return -4 * h / (4 * focal**2 + h**2)
33
+
34
+
35
+ @torch.jit.script
36
+ def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
37
+ """Compute the jacobian of the up-vector projection.
38
+
39
+ Args:
40
+ uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
41
+ abc (torch.Tensor): Gravity vector of shape (..., 3).
42
+ wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".
43
+
44
+ Raises:
45
+ ValueError: If the wrt parameter is unknown.
46
+
47
+ Returns:
48
+ torch.Tensor: Jacobian with respect to the parameter.
49
+ """
50
+ if wrt == "uv":
51
+ c = abc[..., 2][..., None, None, None]
52
+ return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
53
+
54
+ elif wrt == "abc":
55
+ J = uv.new_zeros(uv.shape[:-1] + (2, 3))
56
+ J[..., 0, 0] = 1
57
+ J[..., 1, 1] = 1
58
+ J[..., 0, 2] = -uv[..., 0]
59
+ J[..., 1, 2] = -uv[..., 1]
60
+ return J
61
+
62
+ else:
63
+ raise ValueError(f"Unknown wrt: {wrt}")
scripts/camera/geometry/manifolds.py ADDED
@@ -0,0 +1,113 @@
1
+ """Implementation of manifolds."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import logging
5
+
6
+ import torch
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class EuclideanManifold:
12
+ """Simple euclidean manifold."""
13
+
14
+ @staticmethod
15
+ def J_plus(x: torch.Tensor) -> torch.Tensor:
16
+ """Plus operator Jacobian."""
17
+ return torch.eye(x.shape[-1]).to(x)
18
+
19
+ @staticmethod
20
+ def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
21
+ """Plus operator."""
22
+ return x + delta
23
+
24
+
25
+ class SphericalManifold:
26
+ """Implementation of the spherical manifold.
27
+
28
+ Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
29
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
30
+
31
+ Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
32
+ Golub et al.
33
+ """
34
+
35
+ @staticmethod
36
+ def householder_vector(x: torch.Tensor) -> torch.Tensor:
37
+ """Return the Householder vector and beta.
38
+
39
+ Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
40
+ in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
41
+ first.
42
+
43
+ This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
44
+ orthogonal and H * x = ||x||_2 * e_n.
45
+
46
+ Args:
47
+ x (torch.Tensor): [..., n] tensor.
48
+
49
+ Returns:
50
+ torch.Tensor: v of shape [..., n]
51
+ torch.Tensor: beta of shape [...]
52
+ """
53
+ sigma = torch.sum(x[..., :-1] ** 2, -1)
54
+ xpiv = x[..., -1]
55
+ norm = torch.norm(x, dim=-1)
56
+ if torch.any(sigma < 1e-7):
57
+ sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
58
+ logger.warning("sigma < 1e-7")
59
+
60
+ vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
61
+ beta = 2 * vpiv**2 / (sigma + vpiv**2)
62
+ v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
63
+ return v, beta
64
+
65
+ @staticmethod
66
+ def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
67
+ """Apply Householder transformation.
68
+
69
+ Args:
70
+ y (torch.Tensor): Vector to transform of shape [..., n].
71
+ v (torch.Tensor): Householder vector of shape [..., n].
72
+ beta (torch.Tensor): Householder beta of shape [...].
73
+
74
+ Returns:
75
+ torch.Tensor: Transformed vector of shape [..., n].
76
+ """
77
+ return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]
78
+
79
+ @classmethod
80
+ def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
81
+ """Plus operator Jacobian."""
82
+ v, beta = cls.householder_vector(x)
83
+ H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
84
+ H = H + torch.eye(H.shape[-1]).to(H)
85
+ return H[..., :-1] # J
86
+
87
+ @classmethod
88
+ def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
89
+ """Plus operator.
90
+
91
+ Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
92
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
93
+ element of the input vector as pivot instead of first.
94
+
95
+ Args:
96
+ x: point on the manifold
97
+ delta: tangent vector
98
+ """
99
+ eps = 1e-7
100
+ # preserve the norm of x, which may not be exactly 1
101
+ nx = torch.norm(x, dim=-1, keepdim=True)
102
+ nd = torch.norm(delta, dim=-1, keepdim=True)
103
+
104
+ # make sure we don't divide by zero in backward as torch.where computes grad for both
105
+ # branches
106
+ nd_ = torch.where(nd < eps, nd + eps, nd)
107
+ sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
108
+
109
+ # cos is applied to last dim instead of first
110
+ exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
111
+
112
+ v, beta = cls.householder_vector(x)
113
+ return nx * cls.apply_householder(exp_delta, v, beta)
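A short sketch of the key property of `SphericalManifold.plus`: the 2D tangent update keeps the point on the sphere of the input's norm (here the unit sphere), while `EuclideanManifold.plus` is plain addition. Values are arbitrary examples.

```python
import torch
from scripts.camera.geometry.manifolds import EuclideanManifold, SphericalManifold

x = torch.nn.functional.normalize(torch.tensor([[0.2, -0.5, 1.0]]), dim=-1)  # point on S^2
delta = torch.tensor([[0.03, -0.01]])                                        # 2D tangent update

y = SphericalManifold.plus(x, delta)
print(torch.norm(y, dim=-1))  # ~1.0: the update stays on the unit sphere

print(EuclideanManifold.plus(torch.zeros(2), torch.tensor([0.1, 0.2])))  # tensor([0.1000, 0.2000])
```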
scripts/camera/geometry/perspective_fields.py ADDED
@@ -0,0 +1,379 @@
1
+ """Implementation of perspective fields.
2
+
3
+ Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+
11
+ from scripts.camera.geometry.base_camera import BaseCamera
12
+ from scripts.camera.geometry.gravity import Gravity
13
+ from scripts.camera.geometry.jacobians import J_up_projection, J_vecnorm
14
+ from scripts.camera.geometry.manifolds import SphericalManifold
15
+
16
+ # flake8: noqa: E266
17
+
18
+
19
+ def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor:
20
+ """Get the horizon line from the camera parameters.
21
+
22
+ Args:
23
+ camera (Camera): Camera parameters.
24
+ gravity (Gravity): Gravity vector.
25
+ relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True.
26
+
27
+ Returns:
28
+ torch.Tensor: In image frame, fraction of image left/right border intersection with
29
+ respect to image height.
30
+ """
31
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
32
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
33
+
34
+ # project horizon midpoint to image plane
35
+ horizon_midpoint = camera.new_tensor([0, 0, 1])
36
+ horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint
37
+ midpoint = horizon_midpoint[:2] / horizon_midpoint[2]
38
+
39
+ # compute left and right offset to borders
40
+ left_offset = midpoint[0] * torch.tan(gravity.roll)
41
+ right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll)
42
+ left, right = midpoint[1] + left_offset, midpoint[1] - right_offset
43
+
44
+ horizon = camera.new_tensor([left, right])
45
+ return horizon / camera.size[1] if relative else horizon
46
+
47
+
48
+ def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor:
49
+ """Get the up vector field from the camera parameters.
50
+
51
+ Args:
52
+ camera (Camera): Camera parameters.
53
+ normalize (bool, optional): Whether to normalize the up vector. Defaults to True.
54
+
55
+ Returns:
56
+ torch.Tensor: up vector field as tensor of shape (..., h, w, 2).
57
+ """
58
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
59
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
60
+
61
+ w, h = camera.size[0].unbind(-1)
62
+ h, w = h.round().to(int), w.round().to(int)
63
+
64
+ uv = camera.normalize(camera.pixel_coordinates())
65
+
66
+ # projected up is (a, b) - c * (u, v)
67
+ abc = gravity.vec3d
68
+ projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2)
69
+
70
+ if hasattr(camera, "dist"):
71
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
72
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
73
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
74
+ offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
75
+
76
+ # (..., N, 2)
77
+ projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d)
78
+
79
+ if normalize:
80
+ projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2)
81
+
82
+ try:
83
+ del uv, abc, d_uv, offset
84
+ except NameError:
85
+ pass
86
+
87
+ return projected_up2d.reshape(camera.shape[0], h, w, 2)
88
+
89
+
90
+ def J_up_field(
91
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
92
+ ) -> torch.Tensor:
93
+ """Get the jacobian of the up field.
94
+
95
+ Args:
96
+ camera (Camera): Camera parameters.
97
+ gravity (Gravity): Gravity vector.
98
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
99
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
100
+
101
+ Returns:
102
+ torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3).
103
+ """
104
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
105
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
106
+
107
+ w, h = camera.size[0].unbind(-1)
108
+ h, w = h.round().to(int), w.round().to(int)
109
+
110
+ # Forward
111
+ xy = camera.pixel_coordinates()
112
+ uv = camera.normalize(xy)
113
+
114
+ projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv
115
+
116
+ # Backward
117
+ J = []
118
+
119
+ # (..., N, 2, 2)
120
+ J_norm2proj = J_vecnorm(
121
+ get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2)
122
+ )
123
+
124
+ # distortion values
125
+ if hasattr(camera, "dist"):
126
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
127
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
128
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
129
+ offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
130
+
131
+ ######################
132
+ ## Gravity Jacobian ##
133
+ ######################
134
+
135
+ J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3)
136
+
137
+ if hasattr(camera, "dist"):
138
+ # (..., N, 2, 3)
139
+ J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc)
140
+
141
+ J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
142
+ J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta)
143
+ J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta)
144
+ J.append(J_up2delta)
145
+
146
+ ######################
147
+ ### Focal Jacobian ###
148
+ ######################
149
+
150
+ J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2)
151
+
152
+ if hasattr(camera, "dist"):
153
+ J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv)
154
+ J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d)
155
+
156
+ inner = (uv * projected_up2d).sum(-1)[..., None, None]
157
+ J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv")
158
+ J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2)
159
+ J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up
160
+
161
+ J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2)
162
+
163
+ if log_focal:
164
+ J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2)
165
+
166
+ J_uv2f = J_uv2f.sum(-1) # (..., N, 2)
167
+
168
+ J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2)
169
+ J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1)
170
+ J.append(J_up2f)
171
+
172
+ ######################
173
+ ##### K1 Jacobian ####
174
+ ######################
175
+
176
+ if hasattr(camera, "dist"):
177
+ J_duv = camera.J_distort(uv, wrt="scale2dist")
178
+ J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2)
179
+ J_offset = torch.einsum(
180
+ "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv
181
+ )
182
+ J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d)
183
+ J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None]
184
+ J.append(J_k1)
185
+
186
+ n_params = sum(j.shape[-1] for j in J)
187
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params)
188
+
189
+
190
+ def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor:
191
+ """Get the latitudes of the camera pixels in radians.
192
+
193
+ Latitudes are defined as the angle between the ray and the up vector.
194
+
195
+ Args:
196
+ camera (Camera): Camera parameters.
197
+ gravity (Gravity): Gravity vector.
198
+
199
+ Returns:
200
+ torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1).
201
+ """
202
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
203
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
204
+
205
+ w, h = camera.size[0].unbind(-1)
206
+ h, w = h.round().to(int), w.round().to(int)
207
+
208
+ uv1, _ = camera.image2world(camera.pixel_coordinates())
209
+ rays = camera.pixel_bearing_many(uv1)
210
+
211
+ lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d)
212
+
213
+ eps = 1e-6
214
+ lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps))
215
+
216
+ try:
217
+ del uv1, rays
218
+ except NameError:
219
+ pass
220
+
221
+ return lat_asin.reshape(camera.shape[0], h, w, 1)
222
+
223
+
224
+ def J_latitude_field(
225
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
226
+ ) -> torch.Tensor:
227
+ """Get the jacobian of the latitude field.
228
+
229
+ Args:
230
+ camera (Camera): Camera parameters.
231
+ gravity (Gravity): Gravity vector.
232
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
233
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
234
+
235
+ Returns:
236
+ torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3).
237
+ """
238
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
239
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
240
+
241
+ w, h = camera.size[0].unbind(-1)
242
+ h, w = h.round().to(int), w.round().to(int)
243
+
244
+ # Forward
245
+ xy = camera.pixel_coordinates()
246
+ uv1, _ = camera.image2world(xy)
247
+ uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3)
248
+
249
+ # Backward
250
+ J = []
251
+ J_norm2w_to_img = J_vecnorm(uv1)[..., :2] # (..., N, 2)
252
+
253
+ ######################
254
+ ## Gravity Jacobian ##
255
+ ######################
256
+
257
+ J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
258
+ J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2)
259
+ J.append(J_delta)
260
+
261
+ ######################
262
+ ### Focal Jacobian ###
263
+ ######################
264
+
265
+ J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2)
266
+ if log_focal:
267
+ J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :]
268
+ J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2)
269
+
270
+ J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3)
271
+ J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1)
272
+ J.append(J_f)
273
+
274
+ ######################
275
+ ##### K1 Jacobian ####
276
+ ######################
277
+
278
+ if hasattr(camera, "dist"):
279
+ J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2)
280
+ # (..., N, 2)
281
+ J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1)
282
+ # (..., N, 1)
283
+ J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1)
284
+ J.append(J_k1)
285
+
286
+ n_params = sum(j.shape[-1] for j in J)
287
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params)
288
+
289
+
290
+ def get_perspective_field(
291
+ camera: BaseCamera,
292
+ gravity: Gravity,
293
+ use_up: bool = True,
294
+ use_latitude: bool = True,
295
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
296
+ """Get the perspective field from the camera parameters.
297
+
298
+ Args:
299
+ camera (Camera): Camera parameters.
300
+ gravity (Gravity): Gravity vector.
301
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
302
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
303
+
304
+ Returns:
305
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape
306
+ (..., 2, h, w) and (..., 1, h, w).
307
+ """
308
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
309
+
310
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
311
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
312
+
313
+ w, h = camera.size[0].unbind(-1)
314
+ h, w = h.round().to(int), w.round().to(int)
315
+
316
+ if use_up:
317
+ permute = (0, 3, 1, 2)
318
+ # (..., 2, h, w)
319
+ up = get_up_field(camera, gravity).permute(permute)
320
+ else:
321
+ shape = (camera.shape[0], 2, h, w)
322
+ up = camera.new_zeros(shape)
323
+
324
+ if use_latitude:
325
+ permute = (0, 3, 1, 2)
326
+ # (..., 1, h, w)
327
+ lat = get_latitude_field(camera, gravity).permute(permute)
328
+ else:
329
+ shape = (camera.shape[0], 1, h, w)
330
+ lat = camera.new_zeros(shape)
331
+
332
+ torch.cuda.empty_cache()
333
+
334
+ return up, lat
335
+
336
+
337
+ def J_perspective_field(
338
+ camera: BaseCamera,
339
+ gravity: Gravity,
340
+ use_up: bool = True,
341
+ use_latitude: bool = True,
342
+ spherical: bool = False,
343
+ log_focal: bool = False,
344
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
345
+ """Get the jacobian of the perspective field.
346
+
347
+ Args:
348
+ camera (Camera): Camera parameters.
349
+ gravity (Gravity): Gravity vector.
350
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
351
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
352
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
353
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
354
+
355
+ Returns:
356
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape
357
+ (..., h, w, 2, 4) and (..., h, w, 1, 4).
358
+ """
359
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
360
+
361
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
362
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
363
+
364
+ w, h = camera.size[0].unbind(-1)
365
+ h, w = h.round().to(int), w.round().to(int)
366
+
367
+ if use_up:
368
+ J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4)
369
+ else:
370
+ shape = (camera.shape[0], h, w, 2, 4)
371
+ J_up = camera.new_zeros(shape)
372
+
373
+ if use_latitude:
374
+ J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4)
375
+ else:
376
+ shape = (camera.shape[0], h, w, 1, 4)
377
+ J_lat = camera.new_zeros(shape)
378
+
379
+ return J_up, J_lat
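The latitude definition used by `get_latitude_field` can be illustrated with plain tensors: it is the arcsine of the dot product between a unit bearing ray and the unit gravity vector. The gravity below corresponds to zero roll and pitch (cf. `Gravity.from_rp`); the ray is an arbitrary example.

```python
import torch

ray = torch.nn.functional.normalize(torch.tensor([0.1, -0.4, 1.0]), dim=-1)  # unit bearing ray
gravity = torch.tensor([0.0, -1.0, 0.0])                                     # roll = pitch = 0

latitude = torch.asin(torch.dot(ray, gravity))
print(torch.rad2deg(latitude))  # ~21.7 degrees
```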
scripts/camera/utils/conversions.py ADDED
@@ -0,0 +1,150 @@
1
+ """Utility functions for conversions between different representations."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+
9
+ def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
10
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
11
+
12
+ Args:
13
+ (torch.Tensor): Vector of size (..., 3).
14
+
15
+ Returns:
16
+ (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3).
17
+ """
18
+ z = torch.zeros_like(v[..., 0])
19
+ return torch.stack(
20
+ [
21
+ z,
22
+ -v[..., 2],
23
+ v[..., 1],
24
+ v[..., 2],
25
+ z,
26
+ -v[..., 0],
27
+ -v[..., 1],
28
+ v[..., 0],
29
+ z,
30
+ ],
31
+ dim=-1,
32
+ ).reshape(v.shape[:-1] + (3, 3))
33
+
34
+
35
+ def rad2rotmat(
36
+ roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None
37
+ ) -> torch.Tensor:
38
+ """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix.
39
+
40
+ Args:
41
+ roll (torch.Tensor): Roll angle in radians.
42
+ pitch (torch.Tensor): Pitch angle in radians.
43
+ yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None.
44
+
45
+ Returns:
46
+ torch.Tensor: Rotation matrix of shape (..., 3, 3).
47
+ """
48
+ if yaw is None:
49
+ yaw = roll.new_zeros(roll.shape)
50
+
51
+ Rx = pitch.new_zeros(pitch.shape + (3, 3))
52
+ Rx[..., 0, 0] = 1
53
+ Rx[..., 1, 1] = torch.cos(pitch)
54
+ Rx[..., 1, 2] = torch.sin(pitch)
55
+ Rx[..., 2, 1] = -torch.sin(pitch)
56
+ Rx[..., 2, 2] = torch.cos(pitch)
57
+
58
+ Ry = yaw.new_zeros(yaw.shape + (3, 3))
59
+ Ry[..., 0, 0] = torch.cos(yaw)
60
+ Ry[..., 0, 2] = -torch.sin(yaw)
61
+ Ry[..., 1, 1] = 1
62
+ Ry[..., 2, 0] = torch.sin(yaw)
63
+ Ry[..., 2, 2] = torch.cos(yaw)
64
+
65
+ Rz = roll.new_zeros(roll.shape + (3, 3))
66
+ Rz[..., 0, 0] = torch.cos(roll)
67
+ Rz[..., 0, 1] = torch.sin(roll)
68
+ Rz[..., 1, 0] = -torch.sin(roll)
69
+ Rz[..., 1, 1] = torch.cos(roll)
70
+ Rz[..., 2, 2] = 1
71
+
72
+ return Rz @ Rx @ Ry
73
+
74
+
75
+ def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
76
+ """Compute focal length from (vertical/horizontal) field of view.
77
+
78
+ Args:
79
+ fov (torch.Tensor): Field of view in radians.
80
+ size (torch.Tensor): Image height / width in pixels.
81
+
82
+ Returns:
83
+ torch.Tensor: Focal length in pixels.
84
+ """
85
+ return size / 2 / torch.tan(fov / 2)
86
+
87
+
88
+ def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
89
+ """Compute (vertical/horizontal) field of view from focal length.
90
+
91
+ Args:
92
+ focal (torch.Tensor): Focal length in pixels.
93
+ size (torch.Tensor): Image height / width in pixels.
94
+
95
+ Returns:
96
+ torch.Tensor: Field of view in radians.
97
+ """
98
+ return 2 * torch.arctan(size / (2 * focal))
99
+
100
+
101
+ def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
102
+ """Compute the distance from principal point to the horizon.
103
+
104
+ Args:
105
+ pitch (torch.Tensor): Pitch angle in radians.
106
+ f (torch.Tensor): Focal length in pixels.
107
+ h (torch.Tensor): Image height in pixels.
108
+
109
+ Returns:
110
+ torch.Tensor: Relative distance to the horizon.
111
+ """
112
+ return torch.tan(pitch) * f / h
113
+
114
+
115
+ def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
116
+ """Compute the pitch angle from the distance to the horizon.
117
+
118
+ Args:
119
+ rho (torch.Tensor): Relative distance to the horizon.
120
+ f (torch.Tensor): Focal length in pixels.
121
+ h (torch.Tensor): Image height in pixels.
122
+
123
+ Returns:
124
+ torch.Tensor: Pitch angle in radians.
125
+ """
126
+ return torch.atan(rho * h / f)
127
+
128
+
129
+ def rad2deg(rad: torch.Tensor) -> torch.Tensor:
130
+ """Convert radians to degrees.
131
+
132
+ Args:
133
+ rad (torch.Tensor): Angle in radians.
134
+
135
+ Returns:
136
+ torch.Tensor: Angle in degrees.
137
+ """
138
+ return rad / torch.pi * 180
139
+
140
+
141
+ def deg2rad(deg: torch.Tensor) -> torch.Tensor:
142
+ """Convert degrees to radians.
143
+
144
+ Args:
145
+ deg (torch.Tensor): Angle in degrees.
146
+
147
+ Returns:
148
+ torch.Tensor: Angle in radians.
149
+ """
150
+ return deg / 180 * torch.pi
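A quick round-trip sketch of the conversion helpers above (values are arbitrary):

```python
import torch
from scripts.camera.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg

h = torch.tensor(480.0)             # image height in pixels
vfov = deg2rad(torch.tensor(60.0))  # vertical field of view in radians

f = fov2focal(vfov, h)              # ~415.7 px
print(rad2deg(focal2fov(f, h)))     # recovers 60.0 degrees
```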
scripts/camera/utils/image.py ADDED
@@ -0,0 +1,182 @@
1
+ """Image preprocessing utilities."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import collections.abc as collections
5
+ from pathlib import Path
6
+ from typing import Optional, Tuple
7
+
8
+ import cv2
9
+ import kornia
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+
16
+ from scripts.camera.utils.tensor import fit_features_to_multiple
17
+
18
+ # mypy: ignore-errors
19
+
20
+
21
+ class ImagePreprocessor:
22
+ """Preprocess images for calibration."""
23
+
24
+ default_conf = {
25
+ "resize": None, # target edge length (320), None for no resizing
26
+ "edge_divisible_by": None,
27
+ "side": "short",
28
+ "interpolation": "bilinear",
29
+ "align_corners": None,
30
+ "antialias": True,
31
+ "square_crop": False,
32
+ "add_padding_mask": False,
33
+ "resize_backend": "kornia", # torchvision, kornia
34
+ }
35
+
36
+ def __init__(self, conf) -> None:
37
+ """Initialize the image preprocessor."""
38
+ super().__init__()
39
+ default_conf = OmegaConf.create(self.default_conf)
40
+ OmegaConf.set_struct(default_conf, True)
41
+ self.conf = OmegaConf.merge(default_conf, conf)
42
+
43
+ def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
44
+ """Resize and preprocess an image, return image and resize scale."""
45
+ h, w = img.shape[-2:]
46
+ size = h, w
47
+
48
+ if self.conf.square_crop:
49
+ min_size = min(h, w)
50
+ offset = (h - min_size) // 2, (w - min_size) // 2
51
+ img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size]
52
+ size = img.shape[-2:]
53
+
54
+ if self.conf.resize is not None:
55
+ if interpolation is None:
56
+ interpolation = self.conf.interpolation
57
+ size = self.get_new_image_size(h, w)
58
+ img = self.resize(img, size, interpolation)
59
+
60
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
61
+ T = np.diag([scale[0].cpu(), scale[1].cpu(), 1])
62
+
63
+ data = {
64
+ "scales": scale,
65
+ "image_size": np.array(size[::-1]),
66
+ "transform": T,
67
+ "original_image_size": np.array([w, h]),
68
+ }
69
+
70
+ if self.conf.edge_divisible_by is not None:
71
+ # crop to make the edge divisible by a number
72
+ w_, h_ = img.shape[-1], img.shape[-2]
73
+ img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True)
74
+ crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img)
75
+ data["crop_pad"] = crop_pad
76
+ data["image_size"] = np.array([img.shape[-1], img.shape[-2]])
77
+
78
+ data["image"] = img
79
+ return data
80
+
81
+ def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor:
82
+ """Resize an image using the specified backend."""
83
+ if self.conf.resize_backend == "kornia":
84
+ return kornia.geometry.transform.resize(
85
+ img,
86
+ size,
87
+ side=self.conf.side,
88
+ antialias=self.conf.antialias,
89
+ align_corners=self.conf.align_corners,
90
+ interpolation=interpolation,
91
+ )
92
+ elif self.conf.resize_backend == "PIL":
93
+ device = img.device
94
+ imgs = []
95
+ has_batch_dim = img.ndim == 4
96
+ img = img if has_batch_dim else img[None]
97
+ for im in img:
98
+ im = (im.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
99
+ im = Image.fromarray(im).resize(size[::-1], Image.BILINEAR)
100
+ im = torch.tensor(np.array(im)).permute(2, 0, 1) / 255.0
101
+ imgs.append(im.to(device))
102
+ imgs = torch.stack(imgs)
103
+ return imgs if has_batch_dim else imgs[0]
104
+
105
+ elif self.conf.resize_backend == "torchvision":
106
+ return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img)
107
+ else:
108
+ raise ValueError(f"{self.conf.resize_backend} not implemented.")
109
+
110
+ def load_image(self, image_path: Path) -> dict:
111
+ """Load an image from a path and preprocess it."""
112
+ return self(load_image(image_path))
113
+
114
+ def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
115
+ """Get the new image size after resizing."""
116
+ side = self.conf.side
117
+ if isinstance(self.conf.resize, collections.Iterable):
118
+ assert len(self.conf.resize) == 2
119
+ return tuple(self.conf.resize)
120
+ side_size = self.conf.resize
121
+ aspect_ratio = w / h
122
+ if side not in ("short", "long", "vert", "horz"):
123
+ raise ValueError(
124
+ f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
125
+ )
126
+ return (
127
+ (side_size, int(side_size * aspect_ratio))
128
+ if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0))
129
+ else (int(side_size / aspect_ratio), side_size)
130
+ )
131
+
132
+
133
+ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
134
+ """Normalize the image tensor and reorder the dimensions."""
135
+ if image.ndim == 3:
136
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
137
+ elif image.ndim == 2:
138
+ image = image[None] # add channel axis
139
+ else:
140
+ raise ValueError(f"Not an image: {image.shape}")
141
+ return torch.tensor(image / 255.0, dtype=torch.float)
142
+
143
+
144
+ def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray:
145
+ """Normalize and reorder the dimensions of an image tensor."""
146
+ if image.ndim == 3:
147
+ image = image.permute((1, 2, 0)) # CxHxW to HxWxC
148
+ elif image.ndim == 2:
149
+ image = image[None] # add channel axis
150
+ else:
151
+ raise ValueError(f"Not an image: {image.shape}")
152
+ return (image.cpu().detach().numpy() * 255).astype(np.uint8)
153
+
154
+
155
+ def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
156
+ """Read an image from path as RGB or grayscale."""
157
+ if not Path(path).exists():
158
+ raise FileNotFoundError(f"No image at path {path}.")
159
+ mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
160
+ image = cv2.imread(str(path), mode)
161
+ if image is None:
162
+ raise IOError(f"Could not read image at {path}.")
163
+ if not grayscale:
164
+ image = image[..., ::-1]
165
+ return image
166
+
167
+
168
+ def write_image(img: torch.Tensor, path: Path):
169
+ """Write an image tensor to a file."""
170
+ img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img
171
+ cv2.imwrite(str(path), img[..., ::-1])
172
+
173
+
174
+ def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor:
175
+ """Load an image from a path and return as a tensor."""
176
+ image = read_image(path, grayscale=grayscale)
177
+ if return_tensor:
178
+ return numpy_image_to_torch(image)
179
+
180
+ assert image.ndim in [2, 3], f"Not an image: {image.shape}"
181
+ image = image[None] if image.ndim == 2 else image
182
+ return torch.tensor(image.copy(), dtype=torch.uint8)
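A minimal sketch of `ImagePreprocessor`, assuming a plain OmegaConf dict for the configuration and a dummy batched BCHW image; the exact output size follows from the input aspect ratio and the `side: short` default.

```python
import torch
from omegaconf import OmegaConf
from scripts.camera.utils.image import ImagePreprocessor

preprocessor = ImagePreprocessor(OmegaConf.create({"resize": 320}))

img = torch.rand(1, 3, 480, 640)  # dummy RGB batch in [0, 1]
data = preprocessor(img)

print(data["image"].shape)          # torch.Size([1, 3, 320, 426]): short side resized to 320
print(data["scales"])               # per-axis resize factors (w, h)
print(data["original_image_size"])  # [640 480]
```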
scripts/camera/utils/tensor.py ADDED
@@ -0,0 +1,249 @@
1
+ """Adapted from https://github.com/cvg/GeoCalib"""
2
+
3
+ import collections.abc as collections
4
+ import functools
5
+ import inspect
6
+ from typing import Callable, List, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ # flake8: noqa
12
+ # mypy: ignore-errors
13
+
14
+
15
+ string_classes = (str, bytes)
16
+
17
+
18
+ def autocast(func: Callable) -> Callable:
19
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
20
+
21
+ Use the device and dtype of the wrapper.
22
+
23
+ Args:
24
+ func (Callable): Method of a TensorWrapper class.
25
+
26
+ Returns:
27
+ Callable: Wrapped method.
28
+ """
29
+
30
+ @functools.wraps(func)
31
+ def wrap(self, *args):
32
+ device = torch.device("cpu")
33
+ dtype = None
34
+ if isinstance(self, TensorWrapper):
35
+ if self._data is not None:
36
+ device = self.device
37
+ dtype = self.dtype
38
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
39
+ raise ValueError(self)
40
+
41
+ cast_args = []
42
+ for arg in args:
43
+ if isinstance(arg, np.ndarray):
44
+ arg = torch.from_numpy(arg)
45
+ arg = arg.to(device=device, dtype=dtype)
46
+ cast_args.append(arg)
47
+ return func(self, *cast_args)
48
+
49
+ return wrap
50
+
51
+
52
+ class TensorWrapper:
53
+ """Wrapper for PyTorch tensors."""
54
+
55
+ _data = None
56
+
57
+ @autocast
58
+ def __init__(self, data: torch.Tensor):
59
+ """Wrapper for PyTorch tensors."""
60
+ self._data = data
61
+
62
+ @property
63
+ def shape(self) -> torch.Size:
64
+ """Shape of the underlying tensor."""
65
+ return self._data.shape[:-1]
66
+
67
+ @property
68
+ def device(self) -> torch.device:
69
+ """Get the device of the underlying tensor."""
70
+ return self._data.device
71
+
72
+ @property
73
+ def dtype(self) -> torch.dtype:
74
+ """Get the dtype of the underlying tensor."""
75
+ return self._data.dtype
76
+
77
+ def __getitem__(self, index) -> torch.Tensor:
78
+ """Get the underlying tensor."""
79
+ return self.__class__(self._data[index])
80
+
81
+ def __setitem__(self, index, item):
82
+ """Set the underlying tensor."""
83
+ self._data[index] = item.data
84
+
85
+ def to(self, *args, **kwargs):
86
+ """Move the underlying tensor to a new device."""
87
+ return self.__class__(self._data.to(*args, **kwargs))
88
+
89
+ def cpu(self):
90
+ """Move the underlying tensor to the CPU."""
91
+ return self.__class__(self._data.cpu())
92
+
93
+ def cuda(self):
94
+ """Move the underlying tensor to the GPU."""
95
+ return self.__class__(self._data.cuda())
96
+
97
+ def pin_memory(self):
98
+ """Pin the underlying tensor to memory."""
99
+ return self.__class__(self._data.pin_memory())
100
+
101
+ def float(self):
102
+ """Cast the underlying tensor to float."""
103
+ return self.__class__(self._data.float())
104
+
105
+ def double(self):
106
+ """Cast the underlying tensor to double."""
107
+ return self.__class__(self._data.double())
108
+
109
+ def detach(self):
110
+ """Detach the underlying tensor."""
111
+ return self.__class__(self._data.detach())
112
+
113
+ def numpy(self):
114
+ """Convert the underlying tensor to a numpy array."""
115
+ return self._data.detach().cpu().numpy()
116
+
117
+ def new_tensor(self, *args, **kwargs):
118
+ """Create a new tensor of the same type and device."""
119
+ return self._data.new_tensor(*args, **kwargs)
120
+
121
+ def new_zeros(self, *args, **kwargs):
122
+ """Create a new tensor of the same type and device."""
123
+ return self._data.new_zeros(*args, **kwargs)
124
+
125
+ def new_ones(self, *args, **kwargs):
126
+ """Create a new tensor of the same type and device."""
127
+ return self._data.new_ones(*args, **kwargs)
128
+
129
+ def new_full(self, *args, **kwargs):
130
+ """Create a new tensor of the same type and device."""
131
+ return self._data.new_full(*args, **kwargs)
132
+
133
+ def new_empty(self, *args, **kwargs):
134
+ """Create a new tensor of the same type and device."""
135
+ return self._data.new_empty(*args, **kwargs)
136
+
137
+ def unsqueeze(self, *args, **kwargs):
138
+ """Create a new tensor of the same type and device."""
139
+ return self.__class__(self._data.unsqueeze(*args, **kwargs))
140
+
141
+ def squeeze(self, *args, **kwargs):
142
+ """Create a new tensor of the same type and device."""
143
+ return self.__class__(self._data.squeeze(*args, **kwargs))
144
+
145
+ @classmethod
146
+ def stack(cls, objects: List, dim=0, *, out=None):
147
+ """Stack a list of objects with the same type and shape."""
148
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
149
+ return cls(data)
150
+
151
+ @classmethod
152
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
153
+ """Support torch functions."""
154
+ if kwargs is None:
155
+ kwargs = {}
156
+ return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
157
+
158
+
159
+ def map_tensor(input_, func):
160
+ if isinstance(input_, string_classes):
161
+ return input_
162
+ elif isinstance(input_, collections.Mapping):
163
+ return {k: map_tensor(sample, func) for k, sample in input_.items()}
164
+ elif isinstance(input_, collections.Sequence):
165
+ return [map_tensor(sample, func) for sample in input_]
166
+ elif input_ is None:
167
+ return None
168
+ else:
169
+ return func(input_)
170
+
171
+
172
+ def batch_to_numpy(batch):
173
+ return map_tensor(batch, lambda tensor: tensor.cpu().numpy())
174
+
175
+
176
+ def batch_to_device(batch, device, non_blocking=True, detach=False):
177
+ def _func(tensor):
178
+ t = tensor.to(device=device, non_blocking=non_blocking, dtype=torch.float32)
179
+ return t.detach() if detach else t
180
+
181
+ return map_tensor(batch, _func)
182
+
183
+
184
+ def remove_batch_dim(data: dict) -> dict:
185
+ """Remove batch dimension from elements in data"""
186
+ return {
187
+ k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()
188
+ }
189
+
190
+
191
+ def add_batch_dim(data: dict) -> dict:
192
+ """Add batch dimension to elements in data"""
193
+ return {
194
+ k: v[None] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
195
+ for k, v in data.items()
196
+ }
197
+
198
+
199
+ def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
200
+ """Get padding to make the image size a multiple of the given number.
201
+
202
+ Args:
203
+ x (torch.Tensor): Input tensor.
204
+ multiple (int, optional): Multiple.
205
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
206
+
207
+ Returns:
208
+ torch.Tensor: Padding.
209
+ """
210
+ h, w = x.shape[-2:]
211
+
212
+ if crop:
213
+ pad_w = (w // multiple) * multiple - w
214
+ pad_h = (h // multiple) * multiple - h
215
+ else:
216
+ pad_w = (multiple - w % multiple) % multiple
217
+ pad_h = (multiple - h % multiple) % multiple
218
+
219
+ if mode == "center":
220
+ pad_l = pad_w // 2
221
+ pad_r = pad_w - pad_l
222
+ pad_t = pad_h // 2
223
+ pad_b = pad_h - pad_t
224
+ elif mode == "left":
225
+ pad_l = 0
226
+ pad_r = pad_w
227
+ pad_t = 0
228
+ pad_b = pad_h
229
+ else:
230
+ raise ValueError(f"Unknown mode {mode}")
231
+
232
+ return (pad_l, pad_r, pad_t, pad_b)
233
+
234
+
235
+ def fit_features_to_multiple(
236
+ features: torch.Tensor, multiple: int = 32, crop: bool = False
237
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
238
+ """Pad image to a multiple of the given number.
239
+
240
+ Args:
241
+ features (torch.Tensor): Input features.
242
+ multiple (int, optional): Multiple. Defaults to 32.
243
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
244
+
245
+ Returns:
246
+ Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding.
247
+ """
248
+ pad = fit_to_multiple(features, multiple, crop=crop)
249
+ return torch.nn.functional.pad(features, pad, mode="reflect"), pad
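A small sketch of the padding helpers at the bottom of this file, with a dummy NCHW feature map:

```python
import torch
from scripts.camera.utils.tensor import fit_features_to_multiple

feats = torch.rand(1, 3, 250, 333)
padded, pad = fit_features_to_multiple(feats, multiple=32)

print(padded.shape)  # torch.Size([1, 3, 256, 352])
print(pad)           # (pad_left, pad_right, pad_top, pad_bottom) = (9, 10, 3, 3)
```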
scripts/camera/utils/text.py ADDED
@@ -0,0 +1,47 @@
1
+ import re
2
+ from typing import Tuple
3
+
4
+ def parse_camera_params(
5
+ text: str,
6
+ mode: str = "base"
7
+ ) -> Tuple[float, float, float]:
8
+ """
9
+ Extract roll, pitch, fov from text using one of two patterns:
10
+ - 'base' mode: ... are: roll, pitch, fov.
11
+ - 'cot' mode: <answer>roll, pitch, fov</answer>
12
+
13
+ Args:
14
+ text: The full text to search.
15
+ mode: One of {"base", "cot"}.
16
+
17
+ Returns:
18
+ roll, pitch, fov as floats.
19
+
20
+ Raises:
21
+ ValueError if the chosen pattern is not found, or mode is invalid.
22
+ """
23
+ # compile both regexes
24
+ pat_base = re.compile(
25
+ r"are:\s*([+-]?\d+(?:\.\d+)?)\s*,\s*"
26
+ r"([+-]?\d+(?:\.\d+)?)\s*,\s*"
27
+ r"([+-]?\d+(?:\.\d+)?)[\.\s]*$"
28
+ )
29
+ pat_cot = re.compile(
30
+ r"<answer>\s*([+-]?\d+(?:\.\d+)?)\s*,\s*"
31
+ r"([+-]?\d+(?:\.\d+)?)\s*,\s*"
32
+ r"([+-]?\d+(?:\.\d+)?)\s*</answer>"
33
+ )
34
+
35
+ m = None
36
+ if mode == "base":
37
+ m = pat_base.search(text)
38
+ elif mode == "cot":
39
+ m = pat_cot.search(text)
40
+ else:
41
+ raise ValueError(f"Invalid mode: {mode!r}. Choose 'base', 'cot', or 'auto'.")
42
+
43
+ if not m:
44
+ raise ValueError(f"No camera parameters found using mode '{mode}'.")
45
+
46
+ roll_s, pitch_s, fov_s = m.group(1), m.group(2), m.group(3)
47
+ return float(roll_s), float(pitch_s), float(fov_s)
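A usage sketch of `parse_camera_params` with two hypothetical model outputs, one per supported mode; the import path mirrors the file location in this commit.

```python
from scripts.camera.utils.text import parse_camera_params

base_out = "The estimated roll, pitch and FoV are: 3.2, -10.5, 85.0."
cot_out = "<think>...</think><answer>3.2, -10.5, 85.0</answer>"

print(parse_camera_params(base_out, mode="base"))  # (3.2, -10.5, 85.0)
print(parse_camera_params(cot_out, mode="cot"))    # (3.2, -10.5, 85.0)
```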
scripts/camera/visualization/visualize_batch.py ADDED
@@ -0,0 +1,188 @@
1
+ """Visualization of predicted and ground truth for a single batch."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ from typing import Any, Dict
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from scripts.camera.geometry.perspective_fields import get_latitude_field
10
+ from scripts.camera.utils.conversions import rad2deg
11
+ from scripts.camera.utils.tensor import batch_to_device
12
+ from scripts.camera.visualization.viz2d import (
13
+ plot_confidences,
14
+ plot_heatmaps,
15
+ plot_image_grid,
16
+ plot_latitudes,
17
+ plot_vector_fields,
18
+ )
19
+
20
+
21
+ def make_up_figure(
22
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
23
+ ) -> Dict[str, Any]:
24
+ """Get predicted and ground truth up fields and errors.
25
+
26
+ Args:
27
+ pred (Dict[str, torch.Tensor]): Predicted up field.
28
+ data (Dict[str, torch.Tensor]): Ground truth up field.
29
+ n_pairs (int): Number of pairs to visualize.
30
+
31
+ Returns:
32
+ Dict[str, Any]: Dictionary with figure.
33
+ """
34
+ pred = batch_to_device(pred, "cpu", detach=True)
35
+ data = batch_to_device(data, "cpu", detach=True)
36
+
37
+ n_pairs = min(n_pairs, len(data["image"]))
38
+
39
+ if "up_field" not in pred.keys():
40
+ return {}
41
+
42
+ up_fields = []
43
+ for i in range(n_pairs):
44
+ row = [data["up_field"][i]]
45
+ titles = ["Up GT"]
46
+
47
+ if "up_confidence" in pred.keys():
48
+ row += [pred["up_confidence"][i]]
49
+ titles += ["Up Confidence"]
50
+
51
+ row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
52
+ up_fields.append(row)
53
+
54
+ # create figure
55
+ N, M = len(up_fields), len(up_fields[0]) + 1
56
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
57
+ fig, ax = plot_image_grid(imgs, return_fig=True, set_lim=True)
58
+ ax = np.array(ax)
59
+
60
+ for i in range(n_pairs):
61
+ plot_vector_fields([up_fields[i][0]], axes=ax[i, [1]])
62
+ #plot_heatmaps([up_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])
63
+
64
+ if "up_confidence" in pred.keys():
65
+ plot_confidences([up_fields[i][3]], axes=ax[i, [4]])
66
+
67
+ return {"up": fig}
68
+
69
+
70
+ def make_latitude_figure(
71
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
72
+ ) -> Dict[str, Any]:
73
+ """Get predicted and ground truth latitude fields and errors.
74
+
75
+ Args:
76
+ pred (Dict[str, torch.Tensor]): Predicted latitude field.
77
+ data (Dict[str, torch.Tensor]): Ground truth latitude field.
78
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
79
+
80
+ Returns:
81
+ Dict[str, Any]: Dictionary with figure.
82
+ """
83
+ pred = batch_to_device(pred, "cpu", detach=True)
84
+ data = batch_to_device(data, "cpu", detach=True)
85
+
86
+ n_pairs = min(n_pairs, len(data["image"]))
87
+ latitude_fields = []
88
+
89
+ if "latitude_field" not in pred.keys():
90
+ return {}
91
+
92
+ for i in range(n_pairs):
93
+ row = [
94
+ rad2deg(data["latitude_field"][i][0]),
95
+ #rad2deg(pred["latitude_field"][i][0]),
96
+ #errors[i],
97
+ ]
98
+ titles = ["Latitude GT"]
99
+
100
+ if "latitude_confidence" in pred.keys():
101
+ row += [pred["latitude_confidence"][i]]
102
+ titles += ["Latitude Confidence"]
103
+
104
+ row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
105
+ latitude_fields.append(row)
106
+
107
+ # create figure
108
+ N, M = len(latitude_fields), len(latitude_fields[0]) + 1
109
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
110
+ fig, ax = plot_image_grid(imgs, return_fig=True, set_lim=True)
111
+ ax = np.array(ax)
112
+
113
+ for i in range(n_pairs):
114
+ plot_latitudes([latitude_fields[i][0]], is_radians=False, axes=ax[i, [1]])
115
+ #plot_heatmaps([latitude_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])
116
+
117
+ if "latitude_confidence" in pred.keys():
118
+ plot_confidences([latitude_fields[i][3]], axes=ax[i, [4]])
119
+
120
+ return {"latitude": fig}
121
+
122
+
123
+ def make_camera_figure(
124
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
125
+ ) -> Dict[str, Any]:
126
+ """Get predicted and ground truth camera parameters.
127
+
128
+ Args:
129
+ pred (Dict[str, torch.Tensor]): Predicted camera parameters.
130
+ data (Dict[str, torch.Tensor]): Ground truth camera parameters.
131
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
132
+
133
+ Returns:
134
+ Dict[str, Any]: Dictionary with figure.
135
+ """
136
+ pred = batch_to_device(pred, "cpu", detach=True)
137
+ data = batch_to_device(data, "cpu", detach=True)
138
+
139
+ n_pairs = min(n_pairs, len(data["image"]))
140
+
141
+ if "camera" not in pred.keys():
142
+ return {}
143
+
144
+ latitudes = []
145
+ for i in range(n_pairs):
146
+ titles = ["Cameras GT"]
147
+ row = [get_latitude_field(data["camera"][i], data["gravity"][i])]
148
+
149
+ if "camera" in pred.keys() and "gravity" in pred.keys():
150
+ row += [get_latitude_field(pred["camera"][i], pred["gravity"][i])]
151
+ titles += ["Cameras Pred"]
152
+
153
+ row = [rad2deg(r).squeeze(-1).float().numpy()[0] for r in row]
154
+ latitudes.append(row)
155
+
156
+ # create figure
157
+ N, M = len(latitudes), len(latitudes[0]) + 1
158
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
159
+ fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
160
+ ax = np.array(ax)
161
+
162
+ for i in range(n_pairs):
163
+ plot_latitudes(latitudes[i], is_radians=False, axes=ax[i, 1:])
164
+
165
+ return {"camera": fig}
166
+
167
+
168
+ def make_perspective_figures(
169
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
170
+ ) -> Dict[str, Any]:
171
+ """Get predicted and ground truth perspective fields.
172
+
173
+ Args:
174
+ pred (Dict[str, torch.Tensor]): Predicted perspective fields.
175
+ data (Dict[str, torch.Tensor]): Ground truth perspective fields.
176
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
177
+
178
+ Returns:
179
+ Dict[str, Any]: Dictionary with figure.
180
+ """
181
+ n_pairs = min(n_pairs, len(data["image"]))
182
+ figures = make_up_figure(pred, data, n_pairs)
183
+ figures |= make_latitude_figure(pred, data, n_pairs)
184
+ #figures |= make_camera_figure(pred, data, n_pairs)
185
+
186
+ for f in figures.values(): f.tight_layout()
187
+
188
+ return figures
scripts/camera/visualization/viz2d.py ADDED
@@ -0,0 +1,521 @@
1
+ """
2
+ 2D visualization primitives based on Matplotlib.
3
+ 1) Plot images with `plot_images`.
4
+ 2) Call TODO: add functions
5
+ 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
6
+ """
7
+ """Adapted from https://github.com/cvg/GeoCalib"""
8
+
9
+ import matplotlib.patheffects as path_effects
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+
14
+ from scripts.camera.geometry.perspective_fields import get_perspective_field
15
+ from scripts.camera.utils.conversions import rad2deg
16
+
17
+ # flake8: noqa
18
+ # mypy: ignore-errors
19
+
20
+
21
+ def cm_ranking(sc, ths=None):
22
+ if ths is None:
23
+ ths = [512, 1024, 2048, 4096]
24
+
25
+ ls = sc.shape[0]
26
+ colors = ["red", "yellow", "lime", "cyan", "blue"]
27
+ out = ["gray"] * ls
28
+ for i in range(ls):
29
+ for c, th in zip(colors[: len(ths) + 1], ths + [ls]):
30
+ if i < th:
31
+ out[i] = c
32
+ break
33
+ sid = np.argsort(sc, axis=0).flip(0)
34
+ return np.array(out)[sid]
35
+
36
+
37
+ def cm_RdBl(x):
38
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
39
+ x = np.clip(x, 0, 1)[..., None] * 2
40
+ c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]])
41
+ return np.clip(c, 0, 1)
42
+
43
+
44
+ def cm_RdGn(x):
45
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
46
+ x = np.clip(x, 0, 1)[..., None] * 2
47
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
48
+ return np.clip(c, 0, 1)
49
+
50
+
51
+ def cm_BlRdGn(x_):
52
+ """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
53
+ x = np.clip(x_, 0, 1)[..., None] * 2
54
+ c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
55
+
56
+ xn = -np.clip(x_, -1, 0)[..., None] * 2
57
+ cn = xn * np.array([[0, 1.0, 0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
58
+ return np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
59
+
60
+
61
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
62
+ """Plot a list of images.
63
+
64
+ Args:
65
+ imgs (List[np.ndarray]): List of images to plot.
66
+ titles (List[str], optional): Titles. Defaults to None.
67
+ cmaps (str, optional): Colormaps. Defaults to "gray".
68
+ dpi (int, optional): Dots per inch. Defaults to 200.
69
+ pad (float, optional): Padding. Defaults to 0.5.
70
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
71
+
72
+ Returns:
73
+ plt.Figure: Figure of the images.
74
+ """
75
+ n = len(imgs)
76
+ if not isinstance(cmaps, (list, tuple)):
77
+ cmaps = [cmaps] * n
78
+
79
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
80
+ figsize = [sum(ratios) * 4.5, 4.5]
81
+ fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
82
+ if n == 1:
83
+ axs = [axs]
84
+ for i, (img, ax) in enumerate(zip(imgs, axs)):
85
+ ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
86
+ ax.set_axis_off()
87
+ if titles:
88
+ ax.set_title(titles[i])
89
+ fig.tight_layout(pad=pad)
90
+
91
+ return fig
92
+
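+ # Hedged usage sketch (assumption, not part of the original file): with `imgs` a list of
+ # HxWx3 arrays, the calls below render them side by side and save the current figure
+ # without white margins via `save_plot` defined further down in this module.
+ # fig = plot_images(imgs, titles=["image", "prediction"], dpi=150)
+ # save_plot("preview.pdf")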
93
+
94
+ def plot_image_grid(
95
+ imgs,
96
+ titles=None,
97
+ cmaps="gray",
98
+ dpi=100,
99
+ pad=0.5,
100
+ fig=None,
101
+ adaptive=True,
102
+ figs=3.0,
103
+ return_fig=False,
104
+ set_lim=False,
105
+ ) -> plt.Figure:
106
+ """Plot a grid of images.
107
+
108
+ Args:
109
+ imgs (List[np.ndarray]): List of images to plot.
110
+ titles (List[str], optional): Titles. Defaults to None.
111
+ cmaps (str, optional): Colormaps. Defaults to "gray".
112
+ dpi (int, optional): Dots per inch. Defaults to 100.
113
+ pad (float, optional): Padding. Defaults to 0.5.
114
+ fig (_type_, optional): Figure to plot on. Defaults to None.
115
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
116
+ figs (float, optional): Figure size. Defaults to 3.0.
117
+ return_fig (bool, optional): Whether to return the figure. Defaults to False.
118
+ set_lim (bool, optional): Whether to set the limits. Defaults to False.
119
+
120
+ Returns:
121
+ plt.Figure: Figure and axes or just axes.
122
+ """
123
+ nr, n = len(imgs), len(imgs[0])
124
+ if not isinstance(cmaps, (list, tuple)):
125
+ cmaps = [cmaps] * n
126
+
127
+ if adaptive:
128
+ ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H
129
+ else:
130
+ ratios = [4 / 3] * n
131
+
132
+ figsize = [sum(ratios) * figs, nr * figs]
133
+ if fig is None:
134
+ fig, axs = plt.subplots(
135
+ nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
136
+ )
137
+ else:
138
+ axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
139
+ fig.figure.set_size_inches(figsize)
140
+
141
+ if nr == 1 and n == 1:
142
+ axs = [[axs]]
143
+ elif n == 1:
144
+ axs = axs[:, None]
145
+ elif nr == 1:
146
+ axs = [axs]
147
+
148
+ for j in range(nr):
149
+ for i in range(n):
150
+ ax = axs[j][i]
151
+ ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
152
+ ax.set_axis_off()
153
+ if set_lim:
154
+ ax.set_xlim([0, imgs[j][i].shape[1]])
155
+ ax.set_ylim([imgs[j][i].shape[0], 0])
156
+ if titles:
157
+ ax.set_title(titles[j][i])
158
+ if isinstance(fig, plt.Figure):
159
+ fig.tight_layout(pad=pad)
160
+ return (fig, axs) if return_fig else axs
161
+
162
+
163
+ def add_text(
164
+ idx,
165
+ text,
166
+ pos=(0.01, 0.99),
167
+ fs=15,
168
+ color="w",
169
+ lcolor="k",
170
+ lwidth=4,
171
+ ha="left",
172
+ va="top",
173
+ axes=None,
174
+ **kwargs,
175
+ ):
176
+ """Add text to a plot.
177
+
178
+ Args:
179
+ idx (int): Index of the axes.
180
+ text (str): Text to add.
181
+ pos (tuple, optional): Text position. Defaults to (0.01, 0.99).
182
+ fs (int, optional): Font size. Defaults to 15.
183
+ color (str, optional): Text color. Defaults to "w".
184
+ lcolor (str, optional): Line color. Defaults to "k".
185
+ lwidth (int, optional): Line width. Defaults to 4.
186
+ ha (str, optional): Horizontal alignment. Defaults to "left".
187
+ va (str, optional): Vertical alignment. Defaults to "top".
188
+ axes (List[plt.Axes], optional): Axes to put text on. Defaults to None.
189
+
190
+ Returns:
191
+ plt.Text: Text object.
192
+ """
193
+ if axes is None:
194
+ axes = plt.gcf().axes
195
+
196
+ ax = axes[idx]
197
+
198
+ t = ax.text(
199
+ *pos,
200
+ text,
201
+ fontsize=fs,
202
+ ha=ha,
203
+ va=va,
204
+ color=color,
205
+ transform=ax.transAxes,
206
+ zorder=5,
207
+ **kwargs,
208
+ )
209
+ if lcolor is not None:
210
+ t.set_path_effects(
211
+ [
212
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
213
+ path_effects.Normal(),
214
+ ]
215
+ )
216
+ return t
217
+
218
+
219
+ def plot_heatmaps(
220
+ heatmaps,
221
+ vmin=-1e-6, # include negative zero
222
+ vmax=None,
223
+ cmap="Spectral",
224
+ a=0.5,
225
+ axes=None,
226
+ contours_every=None,
227
+ contour_style="solid",
228
+ colorbar=False,
229
+ ):
230
+ """Plot heatmaps with optional contours.
231
+
232
+ To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.
233
+
234
+ Args:
235
+ heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps.
236
+ vmin (float, optional): Min Value. Defaults to -1e-6.
237
+ vmax (float, optional): Max Value. Defaults to None.
238
+ cmap (str, optional): Colormap. Defaults to "Spectral".
239
+ a (float, optional): Alpha value. Defaults to 0.5.
240
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
241
+ contours_every (int, optional): If not none, will draw contours. Defaults to None.
242
+ contour_style (str, optional): Style of the contours. Defaults to "solid".
243
+ colorbar (bool, optional): Whether to show colorbar. Defaults to False.
244
+
245
+ Returns:
246
+ List[plt.Artist]: List of artists.
247
+ """
248
+ if axes is None:
249
+ axes = plt.gcf().axes
250
+ artists = []
251
+
252
+ for i in range(len(axes)):
253
+ a_ = a if isinstance(a, float) else a[i]
254
+
255
+ if isinstance(heatmaps[i], torch.Tensor):
256
+ heatmaps[i] = heatmaps[i].detach().cpu().numpy()
257
+
258
+ alpha = a_
259
+ # Plot the heatmap
260
+ art = axes[i].imshow(
261
+ heatmaps[i],
262
+ alpha=alpha,
263
+ vmin=vmin,
264
+ vmax=vmax,
265
+ cmap=cmap,
266
+ )
267
+ if colorbar:
268
+ cmax = vmax or np.percentile(heatmaps[i], 99)
269
+ art.set_clim(vmin, cmax)
270
+ cbar = plt.colorbar(art, ax=axes[i])
271
+ artists.append(cbar)
272
+
273
+ artists.append(art)
274
+
275
+ if contours_every is not None:
276
+ # Add contour lines to the heatmap
277
+ contour_data = np.arange(vmin, vmax + contours_every, contours_every)
278
+
279
+ # Get the colormap colors for contour lines
280
+ contour_colors = [
281
+ plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level))
282
+ for level in contour_data
283
+ ]
284
+ contours = axes[i].contour(
285
+ heatmaps[i],
286
+ levels=contour_data,
287
+ linewidths=2,
288
+ colors=contour_colors,
289
+ linestyles=contour_style,
290
+ )
291
+
292
+ contours.set_clim(vmin, vmax)
293
+
294
+ fmt = {
295
+ level: f"{label}°"
296
+ for level, label in zip(contour_data, contour_data.astype(int).astype(str))
297
+ }
298
+ t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
299
+
300
+ for label in t:
301
+ label.set_path_effects(
302
+ [
303
+ path_effects.Stroke(linewidth=1, foreground="k"),
304
+ path_effects.Normal(),
305
+ ]
306
+ )
307
+ artists.append(contours)
308
+
309
+ return artists
310
+
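+ # Hedged example of the latitude use case mentioned in the docstring above: plot the image
+ # first, then overlay the latitude map (in degrees) with contours every 15 degrees.
+ # plot_images([img])
+ # plot_heatmaps([lat_deg], vmin=-90, vmax=90, contours_every=15, cmap="seismic")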
311
+
312
+ def plot_horizon_lines(
313
+ cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
314
+ ):
315
+ """Plot horizon lines on the perspective field.
316
+
317
+ Args:
318
+ cameras (List[Camera]): List of cameras.
319
+ gravities (List[Gravity]): Gravities.
320
+ line_colors (str, optional): Line Colors. Defaults to "orange".
321
+ lw (int, optional): Line width. Defaults to 2.
322
+ styles (str, optional): Line styles. Defaults to "solid".
323
+ alpha (float, optional): Alphas. Defaults to 1.0.
324
+ ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None.
325
+ """
326
+ if not isinstance(line_colors, list):
327
+ line_colors = [line_colors] * len(cameras)
328
+
329
+ if not isinstance(styles, list):
330
+ styles = [styles] * len(cameras)
331
+
332
+ fig = plt.gcf()
333
+ ax = fig.gca() if ax is None else ax
334
+
335
+ if isinstance(ax, plt.Axes):
336
+ ax = [ax] * len(cameras)
337
+
338
+ assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"
339
+
340
+ for i in range(len(cameras)):
341
+ _, lat = get_perspective_field(cameras[i], gravities[i])
342
+ # horizon line is zero level of the latitude field
343
+ lat = lat[0, 0].cpu().numpy()
344
+ contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i])
345
+ for contour_line in contours.collections:
346
+ contour_line.set_linestyle(styles[i])
347
+
348
+
349
+ def plot_vector_fields(
350
+ vector_fields,
351
+ cmap="lime",
352
+ subsample=15,
353
+ scale=None,
354
+ lw=None,
355
+ alphas=0.8,
356
+ axes=None,
357
+ ):
358
+ """Plot vector fields.
359
+
360
+ Args:
361
+ vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
362
+ cmap (str, optional): Color of the vectors. Defaults to "lime".
363
+ subsample (int, optional): Subsample the vector field. Defaults to 15.
364
+ scale (float, optional): Scale of the vectors. Defaults to None.
365
+ lw (float, optional): Line width of the vectors. Defaults to None.
366
+ alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
367
+ axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None.
368
+
369
+ Returns:
370
+ List[plt.Artist]: List of artists.
371
+ """
372
+ if axes is None:
373
+ axes = plt.gcf().axes
374
+
375
+ vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields]
376
+
377
+ artists = []
378
+
379
+ H, W = vector_fields[0].shape[-2:]
380
+ if scale is None:
381
+ scale = subsample / min(H, W)
382
+
383
+ if lw is None:
384
+ lw = 0.1 / subsample
385
+
386
+ if alphas is None:
387
+ alphas = np.ones_like(vector_fields[0][0])
388
+ alphas = np.stack([alphas] * len(vector_fields), 0)
389
+ elif isinstance(alphas, float):
390
+ alphas = np.ones_like(vector_fields[0][0]) * alphas
391
+ alphas = np.stack([alphas] * len(vector_fields), 0)
392
+ else:
393
+ alphas = np.array(alphas)
394
+
395
+ subsample = min(W, H) // subsample
396
+ offset_x = ((W % subsample) + subsample) // 2
397
+
398
+ samples_x = np.arange(offset_x, W, subsample)
399
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
400
+
401
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
402
+
403
+ for i in range(len(axes)):
404
+ # vector field of shape (2, H, W) with vectors of norm == 1
405
+ vector_field = vector_fields[i]
406
+
407
+ a = alphas[i][samples_y][:, samples_x]
408
+ x, y = vector_field[:, samples_y][:, :, samples_x]
409
+
410
+ c = cmap
411
+ if not isinstance(cmap, str):
412
+ c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)
413
+
414
+ s = scale * min(H, W)
415
+ arrows = axes[i].quiver(
416
+ x_grid,
417
+ y_grid,
418
+ x,
419
+ y,
420
+ scale=s,
421
+ scale_units="width" if H > W else "height",
422
+ units="width" if H > W else "height",
423
+ alpha=a,
424
+ color=c,
425
+ angles="xy",
426
+ antialiased=True,
427
+ width=lw,
428
+ headaxislength=3.5,
429
+ zorder=5,
430
+ )
431
+
432
+ artists.append(arrows)
433
+
434
+ return artists
435
+
436
+
437
+ def plot_latitudes(
438
+ latitude,
439
+ is_radians=True,
440
+ vmin=-90,
441
+ vmax=90,
442
+ cmap="seismic",
443
+ contours_every=15,
444
+ alpha=0.4,
445
+ axes=None,
446
+ **kwargs,
447
+ ):
448
+ """Plot latitudes.
449
+
450
+ Args:
451
+ latitude (List[torch.Tensor]): List of latitudes.
452
+ is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True.
453
+ vmin (int, optional): Min value to clip to. Defaults to -90.
454
+ vmax (int, optional): Max value to clip to. Defaults to 90.
455
+ cmap (str, optional): Colormap. Defaults to "seismic".
456
+ contours_every (int, optional): Contours every. Defaults to 15.
457
+ alpha (float, optional): Alpha value. Defaults to 0.4.
458
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
459
+
460
+ Returns:
461
+ List[plt.Artist]: List of artists.
462
+ """
463
+ if axes is None:
464
+ axes = plt.gcf().axes
465
+
466
+ assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
467
+ lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
468
+ return plot_heatmaps(
469
+ lat,
470
+ vmin=vmin,
471
+ vmax=vmax,
472
+ cmap=cmap,
473
+ a=alpha,
474
+ axes=axes,
475
+ contours_every=contours_every,
476
+ **kwargs,
477
+ )
478
+
479
+
480
+ def plot_confidences(
481
+ confidence,
482
+ as_log=True,
483
+ vmin=-4,
484
+ vmax=0,
485
+ cmap="turbo",
486
+ alpha=0.4,
487
+ axes=None,
488
+ **kwargs,
489
+ ):
490
+ """Plot confidences.
491
+
492
+ Args:
493
+ confidence (List[torch.Tensor]): Confidence maps.
494
+ as_log (bool, optional): Whether to plot in log scale. Defaults to True.
495
+ vmin (int, optional): Min value to clip to. Defaults to -4.
496
+ vmax (int, optional): Max value to clip to. Defaults to 0.
497
+ cmap (str, optional): Colormap. Defaults to "turbo".
498
+ alpha (float, optional): Alpha value. Defaults to 0.4.
499
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
500
+
501
+ Returns:
502
+ List[plt.Artist]: List of artists.
503
+ """
504
+ if axes is None:
505
+ axes = plt.gcf().axes
506
+
507
+ confidence = [c.cpu() if isinstance(c, torch.Tensor) else torch.tensor(c) for c in confidence]
508
+
509
+ assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"
510
+
511
+ if as_log:
512
+ confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]
513
+
514
+ # normalize to [0, 1]
515
+ confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence]
516
+ return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
517
+
518
+
519
+ def save_plot(path, **kw):
520
+ """Save the current figure without any white margin."""
521
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
src/datasets/utils.py ADDED
@@ -0,0 +1,162 @@
1
+ import copy
2
+ import random
3
+ from xtuner.dataset.utils import get_bos_eos_token_ids
4
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
5
+ import json
6
+
7
+ INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
8
+ OUTPUT_IMAGE_TOKEN_INDEX = -300
9
+ QUERY_TOKEN_INDEX = -400
10
+ QUERY_TOKEN = '<query>'
11
+
12
+ def crop2square(pil_img):
13
+ width, height = pil_img.width, pil_img.height
14
+
15
+ if width > height:
16
+ y0, y1 = 0, height
17
+ x0 = random.randint(0, width - height)
18
+ x1 = x0 + height
19
+ else:
20
+ x0, x1 = 0, width
21
+ y0 = random.randint(0, height - width)
22
+ y1 = y0 + width
23
+
24
+ return pil_img.crop(box=(x0, y0, x1, y1))
25
+
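+ # Usage sketch (assumption): randomly crop a PIL image to a square before resizing.
+ # from PIL import Image
+ # square = crop2square(Image.open("example.jpg"))   # side length == min(width, height)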
26
+ def load_jsonl(json_file):
27
+ with open(json_file) as f:
28
+ lines = f.readlines()
29
+ data = []
30
+ for line in lines:
31
+ data.append(json.loads(line))
32
+ return data
33
+
34
+
35
+ def encode_fn(example,
36
+ tokenizer,
37
+ max_length=None,
38
+ image_length=1,
39
+ query_length=1,
40
+ input_ids_with_output=True,
41
+ with_image_token=False,
42
+ prompt_template=None,
43
+ truncation='right'):
44
+ """Only support the following three scenarios:
45
+
46
+ 1. Incremental pretraining dataset.
47
+ example['conversation'] = [
48
+ {
49
+ 'input': '',
50
+ 'output': '### Human: Can you write xxx'
51
+ }
52
+ ]
53
+
54
+ 2. Single-turn conversation dataset.
55
+ example['conversation'] = [
56
+ {
57
+ 'input': 'Give three tips for staying healthy.',
58
+ 'output': '1.Eat a balanced diet xxx'
59
+ }
60
+ ]
61
+
62
+ 3. Multi-turn conversation dataset.
63
+ example['conversation'] = [
64
+ {
65
+ 'input': 'Give three tips for staying healthy.',
66
+ 'output': '1.Eat a balanced diet xxx'
67
+ },
68
+ {
69
+ 'input': 'Please expand on the second point.',
70
+ 'output': 'Here is an expanded explanation of the xxx'
71
+ }
72
+ ]
73
+ """
74
+
75
+ bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
76
+ is_multi_turn_conversation = len(example['conversation']) > 1
77
+ if is_multi_turn_conversation:
78
+ assert input_ids_with_output
79
+
80
+ input_ids, labels = [], []
81
+ next_needs_bos_token = True
82
+ for single_turn_conversation in example['conversation']:
83
+ input = single_turn_conversation['input']
84
+ if DEFAULT_IMAGE_TOKEN in input and with_image_token:
85
+ chunk_encode = [
86
+ tokenizer.encode(chunk, add_special_tokens=False)
87
+ for chunk in input.split(DEFAULT_IMAGE_TOKEN)
88
+ ]
89
+ assert len(chunk_encode) == 2
90
+ input_encode = []
91
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
92
+ input_encode.extend(cur_chunk_encode)
93
+ if idx != len(chunk_encode) - 1:
94
+ input_encode += [INPUT_IMAGE_TOKEN_INDEX] * image_length
95
+ else:
96
+ input_encode = tokenizer.encode(input, add_special_tokens=False)
97
+ if next_needs_bos_token:
98
+ input_ids += bos_token_id
99
+ labels += [IGNORE_INDEX] * len(bos_token_id)
100
+ input_ids += input_encode
101
+ labels += [IGNORE_INDEX] * len(input_encode)
102
+ if input_ids_with_output and 'output' in single_turn_conversation:
103
+ # Add output
104
+ output_with_loss = single_turn_conversation.get(
105
+ 'output_with_loss', True)
106
+ output = single_turn_conversation['output']
107
+ if DEFAULT_IMAGE_TOKEN in output and with_image_token:
108
+ chunk_encode = [
109
+ tokenizer.encode(chunk, add_special_tokens=False)
110
+ for chunk in output.split(DEFAULT_IMAGE_TOKEN)
111
+ ]
112
+ assert len(chunk_encode) == 2
113
+ output_encode = []
114
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
115
+ output_encode.extend(cur_chunk_encode)
116
+ if idx != len(chunk_encode) - 1:
117
+ output_encode += [OUTPUT_IMAGE_TOKEN_INDEX] * image_length
118
+ elif QUERY_TOKEN in output:
119
+ chunk_encode = [
120
+ tokenizer.encode(chunk, add_special_tokens=False)
121
+ for chunk in output.split(QUERY_TOKEN)
122
+ ]
123
+ assert len(chunk_encode) == 2
124
+ output_encode = []
125
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
126
+ output_encode.extend(cur_chunk_encode)
127
+ if idx != len(chunk_encode) - 1:
128
+ output_encode += [QUERY_TOKEN_INDEX] * query_length
129
+ else:
130
+ output_encode = tokenizer.encode(output, add_special_tokens=False)
131
+ input_ids += output_encode
132
+ if output_with_loss:
133
+ labels += copy.deepcopy(output_encode)
134
+ else:
135
+ labels += [IGNORE_INDEX] * len(output_encode)
136
+ # Add EOS_TOKEN (with loss)
137
+ if single_turn_conversation.get('need_eos_token', True):
138
+ next_needs_bos_token = True
139
+ input_ids += eos_token_id
140
+ if output_with_loss:
141
+ labels += copy.deepcopy(eos_token_id)
142
+ else:
143
+ labels += [IGNORE_INDEX] * len(eos_token_id)
144
+ else:
145
+ next_needs_bos_token = False
146
+ # Add SEP (without loss)
147
+ sep = single_turn_conversation.get('sep', '')
148
+ if sep != '':
149
+ sep_encode = tokenizer.encode(sep, add_special_tokens=False)
150
+ input_ids += sep_encode
151
+ labels += [IGNORE_INDEX] * len(sep_encode)
152
+
153
+ if max_length is not None and len(input_ids) > max_length:
154
+ if truncation == 'right':
155
+ input_ids = input_ids[:max_length]
156
+ labels = labels[:max_length]
157
+ elif truncation == 'left':
158
+ input_ids = input_ids[-max_length:]
159
+ labels = labels[-max_length:]
160
+ else:
161
+ assert truncation is None
162
+ return {'input_ids': input_ids, 'labels': labels}
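+ # Minimal usage sketch (assumption, not part of the original file): encode a single-turn
+ # conversation containing one image placeholder; `tokenizer` is any HF tokenizer.
+ # example = {'conversation': [{'input': f'{DEFAULT_IMAGE_TOKEN}\nDescribe the image.',
+ #                              'output': 'A photo of a cat.'}]}
+ # out = encode_fn(example, tokenizer, image_length=256, with_image_token=True)
+ # print(len(out['input_ids']), len(out['labels']))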
src/models/connector/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .configuration_connector import ConnectorConfig
2
+ from .modeling_connector import ConnectorEncoder
src/models/connector/configuration_connector.py ADDED
@@ -0,0 +1,27 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+
7
+ class ConnectorConfig(PretrainedConfig):
8
+ def __init__(
9
+ self,
10
+ hidden_size=768,
11
+ intermediate_size=3072,
12
+ num_hidden_layers=12,
13
+ num_attention_heads=12,
14
+ hidden_act="gelu_pytorch_tanh",
15
+ layer_norm_eps=1e-6,
16
+ attention_dropout=0.0,
17
+ **kwargs,
18
+ ):
19
+ super().__init__(**kwargs)
20
+
21
+ self.hidden_size = hidden_size
22
+ self.intermediate_size = intermediate_size
23
+ self.num_hidden_layers = num_hidden_layers
24
+ self.num_attention_heads = num_attention_heads
25
+ self.attention_dropout = attention_dropout
26
+ self.layer_norm_eps = layer_norm_eps
27
+ self.hidden_act = hidden_act
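+ # Illustrative sketch (assumption): build a small connector encoder from this config;
+ # `ConnectorEncoder` lives in modeling_connector.py in the same package.
+ # cfg = ConnectorConfig(hidden_size=512, intermediate_size=2048,
+ #                       num_hidden_layers=4, num_attention_heads=8)
+ # encoder = ConnectorEncoder(cfg)   # maps (B, N, 512) embeddings to (B, N, 512)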
src/models/connector/modeling_connector.py ADDED
@@ -0,0 +1,507 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Connector model."""
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn.init import _calculate_fan_in_and_fan_out
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
28
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ is_flash_attn_2_available,
33
+ is_flash_attn_greater_or_equal_2_10,
34
+ logging,
35
+ replace_return_docstrings,
36
+ torch_int,
37
+ )
38
+ from .configuration_connector import ConnectorConfig
39
+
40
+
41
+ if is_flash_attn_2_available():
42
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+
48
+ def init_weights(module):
49
+ """Initialize the weights"""
50
+ if isinstance(module, nn.Embedding):
51
+ default_flax_embed_init(module.weight)
52
+ elif isinstance(module, ConnectorAttention):
53
+ nn.init.xavier_uniform_(module.q_proj.weight)
54
+ nn.init.xavier_uniform_(module.k_proj.weight)
55
+ nn.init.xavier_uniform_(module.v_proj.weight)
56
+ nn.init.xavier_uniform_(module.out_proj.weight)
57
+ nn.init.zeros_(module.q_proj.bias)
58
+ nn.init.zeros_(module.k_proj.bias)
59
+ nn.init.zeros_(module.v_proj.bias)
60
+ nn.init.zeros_(module.out_proj.bias)
61
+ elif isinstance(module, ConnectorMLP):
62
+ nn.init.xavier_uniform_(module.fc1.weight)
63
+ nn.init.xavier_uniform_(module.fc2.weight)
64
+ nn.init.normal_(module.fc1.bias, std=1e-6)
65
+ nn.init.normal_(module.fc2.bias, std=1e-6)
66
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
67
+ lecun_normal_(module.weight)
68
+ if module.bias is not None:
69
+ nn.init.zeros_(module.bias)
70
+ elif isinstance(module, nn.LayerNorm):
71
+ module.bias.data.zero_()
72
+ module.weight.data.fill_(1.0)
73
+
74
+
75
+ def _trunc_normal_(tensor, mean, std, a, b):
76
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
77
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
78
+ def norm_cdf(x):
79
+ # Computes standard normal cumulative distribution function
80
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
81
+
82
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
83
+ warnings.warn(
84
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
85
+ "The distribution of values may be incorrect.",
86
+ stacklevel=2,
87
+ )
88
+
89
+ # Values are generated by using a truncated uniform distribution and
90
+ # then using the inverse CDF for the normal distribution.
91
+ # Get upper and lower cdf values
92
+ l = norm_cdf((a - mean) / std)
93
+ u = norm_cdf((b - mean) / std)
94
+
95
+ # Uniformly fill tensor with values from [l, u], then translate to
96
+ # [2l-1, 2u-1].
97
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
98
+
99
+ # Use inverse cdf transform for normal distribution to get truncated
100
+ # standard normal
101
+ tensor.erfinv_()
102
+
103
+ # Transform to proper mean, std
104
+ tensor.mul_(std * math.sqrt(2.0))
105
+ tensor.add_(mean)
106
+
107
+ # Clamp to ensure it's in the proper range
108
+ tensor.clamp_(min=a, max=b)
109
+
110
+
111
+ def trunc_normal_tf_(
112
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
113
+ ) -> torch.Tensor:
114
+ """Fills the input Tensor with values drawn from a truncated
115
+ normal distribution. The values are effectively drawn from the
116
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
117
+ with values outside :math:`[a, b]` redrawn until they are within
118
+ the bounds. The method used for generating the random values works
119
+ best when :math:`a \\leq \text{mean} \\leq b`.
120
+
121
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
122
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
123
+ and the result is subsequently scaled and shifted by the mean and std args.
124
+
125
+ Args:
126
+ tensor: an n-dimensional `torch.Tensor`
127
+ mean: the mean of the normal distribution
128
+ std: the standard deviation of the normal distribution
129
+ a: the minimum cutoff value
130
+ b: the maximum cutoff value
131
+ """
132
+ with torch.no_grad():
133
+ _trunc_normal_(tensor, 0, 1.0, a, b)
134
+ tensor.mul_(std).add_(mean)
135
+
136
+
137
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
138
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
139
+ if mode == "fan_in":
140
+ denom = fan_in
141
+ elif mode == "fan_out":
142
+ denom = fan_out
143
+ elif mode == "fan_avg":
144
+ denom = (fan_in + fan_out) / 2
145
+
146
+ variance = scale / denom
147
+
148
+ if distribution == "truncated_normal":
149
+ # constant is stddev of standard normal truncated to (-2, 2)
150
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
151
+ elif distribution == "normal":
152
+ with torch.no_grad():
153
+ tensor.normal_(std=math.sqrt(variance))
154
+ elif distribution == "uniform":
155
+ bound = math.sqrt(3 * variance)
156
+ with torch.no_grad():
157
+ tensor.uniform_(-bound, bound)
158
+ else:
159
+ raise ValueError(f"invalid distribution {distribution}")
160
+
161
+
162
+ def lecun_normal_(tensor):
163
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
164
+
165
+
166
+ def default_flax_embed_init(tensor):
167
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
168
+
169
+
170
+ class ConnectorAttention(nn.Module):
171
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
172
+
173
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.config = config
177
+ self.embed_dim = config.hidden_size
178
+ self.num_heads = config.num_attention_heads
179
+ self.head_dim = self.embed_dim // self.num_heads
180
+ if self.head_dim * self.num_heads != self.embed_dim:
181
+ raise ValueError(
182
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
183
+ f" {self.num_heads})."
184
+ )
185
+ self.scale = self.head_dim**-0.5
186
+ self.dropout = config.attention_dropout
187
+
188
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
189
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
190
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
191
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ attention_mask: Optional[torch.Tensor] = None,
197
+ output_attentions: Optional[bool] = False,
198
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
199
+ """Input shape: Batch x Time x Channel"""
200
+
201
+ batch_size, q_len, _ = hidden_states.size()
202
+
203
+ query_states = self.q_proj(hidden_states)
204
+ key_states = self.k_proj(hidden_states)
205
+ value_states = self.v_proj(hidden_states)
206
+
207
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
208
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
209
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
210
+
211
+ k_v_seq_len = key_states.shape[-2]
212
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
213
+
214
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
215
+ raise ValueError(
216
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
217
+ f" {attn_weights.size()}"
218
+ )
219
+
220
+ if attention_mask is not None:
221
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
222
+ raise ValueError(
223
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
224
+ )
225
+ attn_weights = attn_weights + attention_mask
226
+
227
+ # upcast attention to fp32
228
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
229
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
230
+ attn_output = torch.matmul(attn_weights, value_states)
231
+
232
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
233
+ raise ValueError(
234
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
235
+ f" {attn_output.size()}"
236
+ )
237
+
238
+ attn_output = attn_output.transpose(1, 2).contiguous()
239
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
240
+
241
+ attn_output = self.out_proj(attn_output)
242
+
243
+ return attn_output, attn_weights
244
+
245
+
246
+ class ConnectorFlashAttention2(ConnectorAttention):
247
+ """
248
+ ConnectorAttention flash attention module. This module inherits from `ConnectorAttention` as the weights of the module stays
249
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
250
+ flash attention and deal with padding tokens in case the input contains any of them.
251
+ """
252
+
253
+ is_causal = False
254
+
255
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
256
+ def __init__(self, *args, **kwargs):
257
+ super().__init__(*args, **kwargs)
258
+
259
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
260
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
261
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
262
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
263
+
264
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
265
+ def forward(
266
+ self,
267
+ hidden_states: torch.Tensor,
268
+ attention_mask: Optional[torch.LongTensor] = None,
269
+ output_attentions: bool = False,
270
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
271
+ output_attentions = False
272
+
273
+ batch_size, q_len, _ = hidden_states.size()
274
+
275
+ query_states = self.q_proj(hidden_states)
276
+ key_states = self.k_proj(hidden_states)
277
+ value_states = self.v_proj(hidden_states)
278
+
279
+ # Flash attention requires the input to have the shape
280
+ # batch_size x seq_length x head_dim x hidden_dim
281
+ # therefore we just need to keep the original shape
282
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
283
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
284
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
285
+
286
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
287
+ # to be able to avoid many of these transpose/reshape/view.
288
+ query_states = query_states.transpose(1, 2)
289
+ key_states = key_states.transpose(1, 2)
290
+ value_states = value_states.transpose(1, 2)
291
+
292
+ dropout_rate = self.dropout if self.training else 0.0
293
+
294
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
295
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
296
+ # cast them back in the correct dtype just to be sure everything works as expected.
297
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
298
+ # in fp32.
299
+
300
+ input_dtype = query_states.dtype
301
+ if input_dtype == torch.float32:
302
+ if torch.is_autocast_enabled():
303
+ target_dtype = torch.get_autocast_gpu_dtype()
304
+ # Handle the case where the model is quantized
305
+ elif hasattr(self.config, "_pre_quantization_dtype"):
306
+ target_dtype = self.config._pre_quantization_dtype
307
+ else:
308
+ target_dtype = self.q_proj.weight.dtype
309
+
310
+ logger.warning_once(
311
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
312
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
313
+ f" {target_dtype}."
314
+ )
315
+
316
+ query_states = query_states.to(target_dtype)
317
+ key_states = key_states.to(target_dtype)
318
+ value_states = value_states.to(target_dtype)
319
+
320
+ attn_output = _flash_attention_forward(
321
+ query_states,
322
+ key_states,
323
+ value_states,
324
+ attention_mask,
325
+ q_len,
326
+ dropout=dropout_rate,
327
+ is_causal=self.is_causal,
328
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
329
+ )
330
+
331
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
332
+ attn_output = self.out_proj(attn_output)
333
+
334
+ if not output_attentions:
335
+ attn_weights = None
336
+
337
+ return attn_output, attn_weights
338
+
339
+
340
+ class ConnectorSdpaAttention(ConnectorAttention):
341
+ """
342
+ Connector attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
343
+ `ConnectorAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
344
+ SDPA API.
345
+ """
346
+
347
+ is_causal = False
348
+
349
+ # Adapted from ConnectorAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
350
+ def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ output_attentions: Optional[bool] = False,
355
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
356
+ if output_attentions:
357
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
358
+ logger.warning_once(
359
+ "ConnectorModel is using ConnectorSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
360
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
361
+ )
362
+ return super().forward(
363
+ hidden_states=hidden_states,
364
+ attention_mask=attention_mask,
365
+ output_attentions=output_attentions,
366
+ )
367
+
368
+ batch_size, q_len, _ = hidden_states.size()
369
+
370
+ query_states = self.q_proj(hidden_states)
371
+ key_states = self.k_proj(hidden_states)
372
+ value_states = self.v_proj(hidden_states)
373
+
374
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
375
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
376
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
377
+
378
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
379
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
380
+ if query_states.device.type == "cuda" and attention_mask is not None:
381
+ query_states = query_states.contiguous()
382
+ key_states = key_states.contiguous()
383
+ value_states = value_states.contiguous()
384
+
385
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
386
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
387
+ is_causal = True if self.is_causal and q_len > 1 else False
388
+
389
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
390
+ query_states,
391
+ key_states,
392
+ value_states,
393
+ attn_mask=attention_mask,
394
+ dropout_p=self.dropout if self.training else 0.0,
395
+ is_causal=is_causal,
396
+ )
397
+
398
+ attn_output = attn_output.transpose(1, 2).contiguous()
399
+ attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
400
+
401
+ attn_output = self.out_proj(attn_output)
402
+
403
+ return attn_output, None
404
+
405
+
406
+ CONNECTOR_ATTENTION_CLASSES = {
407
+ "eager": ConnectorAttention,
408
+ "flash_attention_2": ConnectorFlashAttention2,
409
+ "sdpa": ConnectorSdpaAttention,
410
+ }
411
+
412
+
413
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Connector
414
+ class ConnectorMLP(nn.Module):
415
+ def __init__(self, config):
416
+ super().__init__()
417
+ self.config = config
418
+ self.activation_fn = ACT2FN[config.hidden_act]
419
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
420
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
421
+
422
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
423
+ hidden_states = self.fc1(hidden_states)
424
+ hidden_states = self.activation_fn(hidden_states)
425
+ hidden_states = self.fc2(hidden_states)
426
+ return hidden_states
427
+
428
+
429
+ class ConnectorEncoderLayer(nn.Module):
430
+ def __init__(self, config: ConnectorConfig):
431
+ super().__init__()
432
+ self.embed_dim = config.hidden_size
433
+ self.self_attn = CONNECTOR_ATTENTION_CLASSES[config._attn_implementation](config=config)
434
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
435
+ self.mlp = ConnectorMLP(config)
436
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
437
+
438
+ # Ignore copy
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: torch.Tensor,
443
+ output_attentions: Optional[bool] = False,
444
+ ) -> Tuple[torch.FloatTensor]:
445
+ """
446
+ Args:
447
+ hidden_states (`torch.FloatTensor`):
448
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
449
+ attention_mask (`torch.FloatTensor`):
450
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
451
+ output_attentions (`bool`, *optional*, defaults to `False`):
452
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
453
+ returned tensors for more detail.
454
+ """
455
+ residual = hidden_states
456
+
457
+ hidden_states = self.layer_norm1(hidden_states)
458
+ hidden_states, attn_weights = self.self_attn(
459
+ hidden_states=hidden_states,
460
+ attention_mask=attention_mask,
461
+ output_attentions=output_attentions,
462
+ )
463
+ hidden_states = residual + hidden_states
464
+
465
+ residual = hidden_states
466
+ hidden_states = self.layer_norm2(hidden_states)
467
+ hidden_states = self.mlp(hidden_states)
468
+ hidden_states = residual + hidden_states
469
+
470
+ outputs = (hidden_states,)
471
+
472
+ if output_attentions:
473
+ outputs += (attn_weights,)
474
+
475
+ return outputs
476
+
477
+
478
+ # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Connector
479
+ class ConnectorEncoder(nn.Module):
480
+ def __init__(self, config: ConnectorConfig):
481
+ super().__init__()
482
+ self.config = config
483
+ self.layers = nn.ModuleList([ConnectorEncoderLayer(config) for _ in range(config.num_hidden_layers)])
484
+ self.gradient_checkpointing = False
485
+ self.apply(init_weights)
486
+
487
+ def forward(self, inputs_embeds):
488
+ hidden_states = inputs_embeds
489
+ for encoder_layer in self.layers:
490
+ if self.gradient_checkpointing and self.training:
491
+ layer_outputs = torch.utils.checkpoint.checkpoint(
492
+ encoder_layer.__call__,
493
+ hidden_states,
494
+ None,
495
+ False,
496
+ use_reentrant=False
497
+ )
498
+ else:
499
+ layer_outputs = encoder_layer(
500
+ hidden_states,
501
+ None,
502
+ output_attentions=False,
503
+ )
504
+
505
+ hidden_states = layer_outputs[0]
506
+
507
+ return hidden_states
src/models/connector/modeling_qwen2.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import Qwen2PreTrainedModel, Qwen2Config
4
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer
5
+
6
+
7
+ class Qwen2Connector(Qwen2PreTrainedModel):
8
+ def __init__(self, config: Qwen2Config):
9
+ super().__init__(config)
10
+ self.layers = nn.ModuleList(
11
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
12
+ )
13
+
14
+ for layer in self.layers:
15
+ layer.self_attn.is_causal = False
16
+
17
+ self._attn_implementation = config._attn_implementation
18
+ assert self._attn_implementation == 'flash_attention_2'
19
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
20
+
21
+ self.gradient_checkpointing = False
22
+ # Initialize weights and apply final processing
23
+ self.post_init()
24
+
25
+ def forward(self, inputs_embeds):
26
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
27
+ position_ids = position_ids.expand(inputs_embeds.shape[0], -1)
28
+ hidden_states = inputs_embeds
29
+
30
+ for encoder_layer in self.layers:
31
+ if self.gradient_checkpointing and self.training:
32
+ layer_outputs = self._gradient_checkpointing_func(
33
+ encoder_layer.__call__,
34
+ hidden_states,
35
+ None,
36
+ position_ids,
37
+ use_reentrant=False
38
+ )
39
+ else:
40
+ layer_outputs = encoder_layer(
41
+ hidden_states,
42
+ attention_mask=None,
43
+ position_ids=position_ids,
44
+ )
45
+
46
+ hidden_states = layer_outputs[0]
47
+
48
+ hidden_states = self.norm(hidden_states)
49
+
50
+ return hidden_states
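+ # Usage sketch (assumption, not part of the original commit): the connector runs a stack of
+ # non-causal Qwen2 decoder layers over query embeddings of shape (B, N, hidden_size);
+ # flash-attn is required because of the assertion in __init__.
+ # cfg = Qwen2Config(hidden_size=896, num_hidden_layers=2, num_attention_heads=14,
+ #                   attn_implementation='flash_attention_2')
+ # connector = Qwen2Connector(cfg)
+ # out = connector(query_embeds)   # (B, N, hidden_size)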
src/models/puffin/model.py ADDED
@@ -0,0 +1,790 @@
1
+ import random
2
+ import torch
3
+ import math
4
+ from tqdm import tqdm
5
+ from einops import rearrange
6
+ from copy import deepcopy
7
+ from six.moves import zip
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.autograd.function import Function
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from mmengine.logging import print_log
13
+ from mmengine.model import BaseModel
14
+ from xtuner.utils import IGNORE_INDEX
15
+ from xtuner.registry import BUILDER
16
+ from xtuner.model.utils import guess_load_checkpoint
17
+ from xtuner.dataset.map_fns.template_map_fn import template_map_fn
18
+ from transformers.cache_utils import DynamicCache
19
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
20
+
21
+ from src.models.connector import ConnectorConfig, ConnectorEncoder
22
+ from src.models.stable_diffusion3.pipeline_stable_diffusion_3_dynamic import StableDiffusion3Pipeline
23
+ from src.datasets.utils import encode_fn, QUERY_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, INPUT_IMAGE_TOKEN_INDEX
24
+
25
+ class _ScaleGradient(Function):
26
+ @staticmethod
27
+ def forward(ctx, input, scale):
28
+ ctx.scale = scale
29
+ return input
30
+
31
+ @staticmethod
32
+ def backward(ctx, grad_output):
33
+ return grad_output * ctx.scale, None
34
+
35
+ def build_mlp(hidden_size, projector_dim, z_dim):
36
+ return nn.Sequential(
37
+ nn.Linear(hidden_size, projector_dim),
38
+ nn.SiLU(),
39
+ nn.Linear(projector_dim, z_dim),)
40
+
41
+ def pad_an_image_tensor(image, pad_value=0):
42
+ h, w = image.shape[-2:]
43
+ if h > w:
44
+ pad_left = (h - w) // 2
45
+ pad_right = h - w - pad_left
46
+ p2d = (pad_left, pad_right, 0, 0)
47
+ else:
48
+ pad_top = (w - h) // 2
49
+ pad_bottom = w - h - pad_top
50
+ p2d = (0, 0, pad_top, pad_bottom)
51
+
52
+ image = F.pad(image, p2d, "constant", pad_value)
53
+
54
+ return image
55
+
56
+ class Qwen2p5RadioStableDiffusion3HFDynamic(BaseModel):
57
+ def __init__(self,
58
+ llm,
59
+ tokenizer,
60
+ prompt_template,
61
+ visual_encoder,
62
+ vae,
63
+ transformer,
64
+ train_scheduler,
65
+ test_scheduler,
66
+ connector_1,
67
+ connector_2,
68
+ num_queries=64,
69
+ freeze_transformer=True,
70
+ max_length=256,
71
+ freeze_visual_encoder=True,
72
+ freeze_llm=True,
73
+ visual_encoder_grad_scale=0.1,
74
+ fold_size=2,
75
+ unconditional=0.1,
76
+ unconditional_cross_view=0.1,
77
+ pretrained_pth=None,
78
+ use_activation_checkpointing=False,
79
+ *args, **kwargs):
80
+ super().__init__()
81
+
82
+ # basic settings
83
+ self.max_length = max_length
84
+ self.fold_size = fold_size
85
+ self.prompt_template = prompt_template
86
+ self.unconditional = unconditional
87
+ self.unconditional_cross_view = unconditional_cross_view
88
+
89
+ # networks building
90
+ # understanding branch
91
+ self.visual_encoder = BUILDER.build(visual_encoder)
92
+ self.llm = BUILDER.build(llm)
93
+ self.tokenizer = BUILDER.build(tokenizer)
94
+ self.projector = build_mlp(hidden_size=self.visual_encoder.model.embed_dim*fold_size**2,
95
+ projector_dim=self.llm.config.hidden_size,
96
+ z_dim=self.llm.config.hidden_size)
97
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(prompt_template['IMG_CONTEXT_TOKEN'])
98
+
99
+ # generation branch
100
+ self.vae = BUILDER.build(vae)
101
+ self.vae.requires_grad_(False)
102
+ self.transformer = BUILDER.build(transformer)
103
+ self.num_queries = num_queries
104
+ self.connector_1 = ConnectorEncoder(ConnectorConfig(**connector_1))
105
+ self.connector_2 = ConnectorEncoder(ConnectorConfig(**connector_2))
106
+
107
+ self.llm2connector_1 = nn.Linear(self.llm.config.hidden_size, self.connector_1.config.hidden_size)
108
+ self.llm2connector_2 = nn.Linear(self.llm.config.hidden_size, self.connector_2.config.hidden_size)
109
+ self.projector_1 = nn.Linear(self.connector_1.config.hidden_size, self.transformer.config.pooled_projection_dim)
110
+ self.projector_2 = nn.Linear(self.connector_2.config.hidden_size, self.transformer.config.joint_attention_dim)
111
+ nn.init.zeros_(self.projector_1.weight)
112
+ nn.init.zeros_(self.projector_2.weight)
113
+ nn.init.zeros_(self.projector_1.bias)
114
+ nn.init.zeros_(self.projector_2.bias)
115
+
116
+ self.meta_queries = nn.Parameter(
117
+ torch.zeros(num_queries, self.llm.config.hidden_size))
118
+ nn.init.normal_(self.meta_queries, std=1 / math.sqrt(self.llm.config.hidden_size))
119
+
120
+ # networks and training initialization
121
+ if freeze_visual_encoder:
122
+ self.visual_encoder.requires_grad_(False)
123
+ self.freeze_visual_encoder = freeze_visual_encoder
124
+ if freeze_llm:
125
+ self.llm.requires_grad_(False)
126
+ self.freeze_llm = freeze_llm
127
+ if freeze_transformer:
128
+ self.transformer.requires_grad_(False)
129
+ self.freeze_transformer = freeze_transformer
130
+
131
+ self.visual_encoder_grad_scale = visual_encoder_grad_scale
132
+ self.train_scheduler = BUILDER.build(train_scheduler)
133
+ self.test_scheduler = BUILDER.build(test_scheduler)
134
+
135
+ self.use_activation_checkpointing = use_activation_checkpointing
136
+ if use_activation_checkpointing:
137
+ self.llm.enable_input_require_grads()
138
+ self.gradient_checkpointing_enable()
139
+
140
+ if pretrained_pth is not None:
141
+ pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
142
+ info = self.load_state_dict(pretrained_state_dict, strict=False)
143
+ print_log(f'Load pretrained weight from {pretrained_pth}')
144
+
145
+ @property
146
+ def device(self):
147
+ return self.llm.device
148
+
149
+ @property
150
+ def dtype(self):
151
+ return self.llm.dtype
152
+
153
+ def gradient_checkpointing_enable(self):
154
+ self.activation_checkpointing_enable()
155
+
156
+ def activation_checkpointing_enable(self):
157
+ self.llm.gradient_checkpointing_enable()
158
+ self.transformer.enable_gradient_checkpointing()
159
+ self.connector_1.gradient_checkpointing = True
160
+ self.connector_2.gradient_checkpointing = True
161
+
162
+ def gradient_checkpointing_disable(self):
163
+ self.activation_checkpointing_disable()
164
+
165
+ def activation_checkpointing_disable(self):
166
+ self.llm.gradient_checkpointing_disable()
167
+ self.transformer.disable_gradient_checkpointing()
168
+ self.connector_1.gradient_checkpointing = False
169
+ self.connector_2.gradient_checkpointing = False
170
+
171
+ def forward(self, data, data_samples=None, mode='loss'):
172
+ if mode == 'loss':
173
+ return self.compute_loss(data_dict=data)
174
+ else:
175
+ raise NotImplementedError
176
+
177
+ def extract_visual_features(self, pixel_values):
178
+ pixel_values = (pixel_values + 1.0) / 2 # [0, 1]
179
+ height, width = pixel_values.shape[-2:]
180
+ summary, features = self.visual_encoder(pixel_values)
181
+ patch_size = int((height * width // features.shape[1]) ** 0.5)
182
+ height, width = height // (patch_size * self.fold_size), width // (patch_size * self.fold_size)
183
+ features = rearrange(features, 'b (h p w q) d -> b (h w) (p q d)',
184
+ h=height, w=width, p=self.fold_size, q=self.fold_size)
185
+
186
+ return features
187
+
188
+ def llm2dit(self, x):
189
+ x_1 = self.connector_1(self.llm2connector_1(x))
190
+ x_1 = self.projector_1(x_1.mean(1))
191
+ x_2 = self.connector_2(self.llm2connector_2(x))
192
+ x_2 = self.projector_2(x_2)
193
+
194
+ return x_1, x_2
195
+
196
+
197
+ @torch.no_grad()
198
+ def prepare_gen_prompts(self, texts, data_type='text2image', num_refs=None, ref_lens=None, gen_type='GENERATION_CROSS'):
199
+ if data_type == 'text2image':
200
+ prompts = [self.prompt_template['GENERATION'].format(input=text) for text in texts]
201
+ prompts = [self.prompt_template['INSTRUCTION'].format(input=text) for text in prompts]
202
+
203
+ elif data_type == 'image2image':
204
+ assert num_refs is not None and ref_lens is not None, "num_refs and ref_lens are required for image2image"
205
+ prompts = []
206
+ cnt = 0
207
+ for text, num_ref in zip(texts, num_refs):
208
+ image_tokens = ''
209
+ for _ in range(num_ref):
210
+ image_tokens += (
211
+ self.prompt_template['IMG_START_TOKEN'] +
212
+ self.prompt_template['IMG_CONTEXT_TOKEN'] * ref_lens[cnt] +
213
+ self.prompt_template['IMG_END_TOKEN']
214
+ )
215
+ cnt += 1
216
+
217
+ text = self.prompt_template[gen_type].format(input=text)
218
+ prompt = self.prompt_template['INSTRUCTION'].format(input=f'{image_tokens}\n{text}')
219
+ prompts.append(prompt)
220
+ else:
221
+ raise ValueError(f"Unsupported data_type: {data_type}")
222
+
223
+ return self.tokenizer(
224
+ prompts, add_special_tokens=True, return_tensors='pt', padding=True, padding_side='left').to(self.device)
225
+
226
+
227
+ @torch.no_grad()
228
+ def prepare_und_prompts(self, conversations, data_type='image2text', image_lengths=None, input_ids_with_output=True):
229
+ input_ids, labels, input_lengths = [], [], []
230
+
231
+ if data_type == 'image2text':
232
+ assert image_lengths is not None, "`image_lengths` must be provided for image2text"
233
+ if isinstance(image_lengths, int):
234
+ image_lengths = [image_lengths] * len(conversations)
235
+ elif data_type == 'text2text':
236
+ image_lengths = [None] * len(conversations)
237
+ else:
238
+ raise ValueError(f"Unsupported data_type: {data_type}")
239
+
240
+ for conv, image_len in zip(conversations, image_lengths):
241
+ data_dict = template_map_fn(example=dict(conversation=deepcopy(conv)), template=self.prompt_template)
242
+ data_dict.update(encode_fn(data_dict,
243
+ tokenizer=self.tokenizer,
244
+ max_length=None,
245
+ input_ids_with_output=input_ids_with_output,
246
+ with_image_token=(data_type == 'image2text'),
247
+ image_length=image_len,
248
+ prompt_template=self.prompt_template))
249
+
250
+ input_ids.append(torch.tensor(data_dict['input_ids'], dtype=torch.long, device=self.device))
251
+ labels.append(torch.tensor(data_dict['labels'], dtype=torch.long, device=self.device))
252
+ input_lengths.append(len(data_dict['input_ids']))
253
+
254
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0, padding_side='left')
255
+ labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX, padding_side='left')
256
+
257
+ attention_mask = torch.zeros_like(input_ids).bool()
258
+ for i in range(len(input_ids)):
259
+ attention_mask[i, -input_lengths[i]:] = True
260
+
261
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
262
+ position_ids[position_ids < 0] = 0
263
+
264
+ return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
265
+
266
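A minimal sketch (toy sequence lengths, assumed values) of the left-padded attention mask and position ids built at the end of prepare_und_prompts above.

import torch

input_lengths = [3, 5]                       # assumed per-sample token counts
max_len = max(input_lengths)
attention_mask = torch.zeros(len(input_lengths), max_len, dtype=torch.bool)
for i, n in enumerate(input_lengths):
    attention_mask[i, -n:] = True            # valid tokens sit on the right
position_ids = torch.cumsum(attention_mask, dim=1) - 1
position_ids[position_ids < 0] = 0           # padded slots all map to position 0
print(attention_mask.int())
print(position_ids)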
+ def train(self, mode=True):
267
+ super().train(mode=mode)
268
+ self.vae.train(mode=False)
269
+ if not mode:
270
+ self.gradient_checkpointing_disable()
271
+
272
+ return self
273
+
274
+ @torch.no_grad()
275
+ def pixels_to_latents(self, x):
276
+ z = self.vae.encode(x).latent_dist.sample()
277
+ z = (z - self.vae.config.shift_factor) * self.vae.config.scaling_factor
278
+ return z
279
+
280
+ @torch.no_grad()
281
+ def latents_to_pixels(self, z):
282
+ z = (z / self.vae.config.scaling_factor) + self.vae.config.shift_factor
283
+ x_rec = self.vae.decode(z).sample
284
+ return x_rec
285
+
286
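A tiny round-trip check of the latent shift/scale used by pixels_to_latents and latents_to_pixels above; the scaling_factor and shift_factor values are typical SD3 VAE numbers, used here only as assumptions in place of vae.config.

scaling_factor, shift_factor = 1.5305, 0.0609   # assumed, normally read from vae.config
z_raw = 0.7                                     # a raw VAE latent value
z = (z_raw - shift_factor) * scaling_factor     # pixels_to_latents scaling
z_back = z / scaling_factor + shift_factor      # latents_to_pixels inverse
assert abs(z_back - z_raw) < 1e-9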
+ def prepare_forward_input(self,
287
+ query_embeds,
288
+ input_ids=None,
289
+ image_embeds=None,
290
+ attention_mask=None,
291
+ past_key_values=None,
292
+ append_queries=True):
293
+ b, l, _ = query_embeds.shape
294
+ assert l > 0
295
+ attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
296
+ assert l == self.num_queries
297
+
298
+ if append_queries:
299
+ input_ids = torch.cat([
300
+ input_ids, input_ids.new_full(size=(b, l), fill_value=QUERY_TOKEN_INDEX)], dim=1)
301
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, l)], dim=1)
302
+
303
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
304
+ position_ids[position_ids < 0] = 0
305
+
306
+ # prepare context
307
+ if past_key_values is not None:
308
+ inputs_embeds = query_embeds
309
+ position_ids = position_ids[..., -l:]
310
+ else:
311
+ inputs_embeds = torch.zeros(*input_ids.shape, self.llm.config.hidden_size,
312
+ device=self.device, dtype=self.dtype)
313
+ if image_embeds is not None:
314
+ inputs_embeds[input_ids == self.image_token_id] = \
315
+ image_embeds.contiguous().view(-1, self.llm.config.hidden_size)
316
+
317
+ inputs_embeds[input_ids == QUERY_TOKEN_INDEX] = \
318
+ query_embeds.contiguous().view(-1, self.llm.config.hidden_size)
319
+
320
+ text_places = torch.logical_and(input_ids != self.image_token_id, input_ids != QUERY_TOKEN_INDEX)
321
+
322
+ inputs_embeds[text_places] = self.llm.get_input_embeddings()(input_ids[text_places])
323
+
324
+ inputs = dict(inputs_embeds=inputs_embeds,
325
+ attention_mask=attention_mask,
326
+ position_ids=position_ids,
327
+ past_key_values=past_key_values)
328
+
329
+ return inputs
330
+
331
+ def get_sigmas(self, timesteps, n_dim=4):
332
+ sigmas = self.train_scheduler.sigmas.to(device=self.device, dtype=self.dtype)
333
+ schedule_timesteps = self.train_scheduler.timesteps.to(self.device)
334
+ timesteps = timesteps.to(self.device)
335
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
336
+
337
+ sigma = sigmas[step_indices].flatten()
338
+ while len(sigma.shape) < n_dim:
339
+ sigma = sigma.unsqueeze(-1)
340
+ return sigma
341
+
342
+ def diff_loss(self, model_input, pooled_prompt_embeds, prompt_embeds, cond_input=None):
343
+ noise = [torch.randn_like(x) for x in model_input]
344
+ bsz = len(model_input)
345
+
346
+ u = compute_density_for_timestep_sampling(
347
+ weighting_scheme='none',
348
+ batch_size=bsz,
349
+ logit_mean=0.0,
350
+ logit_std=1.0,
351
+ )
352
+ indices = (u * self.train_scheduler.config.num_train_timesteps).long()
353
+ timesteps = self.train_scheduler.timesteps[indices].to(device=self.device)
354
+
355
+ # Add noise according to flow matching
356
+ sigmas = self.get_sigmas(timesteps, n_dim=model_input[0].ndim + 1)
357
+ noisy_model_input = [(1.0 - x) * y + x * z for x, y, z in zip(sigmas, model_input, noise)]
358
+
359
+ # Predict the noise residual
360
+ model_pred = self.transformer(
361
+ hidden_states=noisy_model_input,
362
+ cond_hidden_states=cond_input,
363
+ encoder_hidden_states=prompt_embeds,
364
+ pooled_projections=pooled_prompt_embeds,
365
+ timestep=timesteps,
366
+ return_dict=False,
367
+ )[0]
368
+
369
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme='none', sigmas=sigmas)
370
+
371
+ # flow matching loss
372
+ target = [x - y for x, y in zip(noise, model_input)]
373
+
374
+ loss = [(x.float() * (y.float() - z.float()) ** 2).mean() for x, y, z in zip(weighting, model_pred, target)]
375
+ loss = sum(loss) / len(loss)
376
+
377
+ return loss
378
+
379
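Before the per-task losses, a short sketch (illustrative shapes only) of the rectified-flow objective implemented by diff_loss above: the noisy input interpolates between the clean latent and Gaussian noise, and the transformer regresses the velocity, noise minus latent.

import torch

x0 = torch.randn(1, 16, 64, 64)          # clean latent (assumed shape)
eps = torch.randn_like(x0)               # Gaussian noise
sigma = torch.rand(1, 1, 1, 1)           # per-sample sigma in [0, 1]
x_t = (1.0 - sigma) * x0 + sigma * eps   # matches noisy_model_input above
target = eps - x0                        # flow-matching regression target
# loss = mean(weighting * (model_pred - target) ** 2), averaged over the batch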
+ '''text-to-image generation (single-view)'''
380
+ def text2image_loss(self, data_dict):
381
+ pixel_values = [p.to(dtype=self.dtype, device=self.device) for p in data_dict['pixel_values']]
382
+ image_latents = [self.pixels_to_latents(p[None])[0] for p in pixel_values]
383
+
384
+ b = len(image_latents)
385
+
386
+ texts = ['' if random.uniform(0, 1) < self.unconditional else text
387
+ for text in data_dict['texts']]
388
+
389
+ text_inputs = self.prepare_gen_prompts(texts)
390
+ hidden_states = self.meta_queries[None].expand(b, self.num_queries, -1)
391
+
392
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
393
+
394
+ max_length = self.max_length + self.num_queries
395
+ inputs_embeds = inputs['inputs_embeds'][:, -max_length:]
396
+ attention_mask = inputs['attention_mask'][:, -max_length:]
397
+ position_ids = inputs['position_ids'][:, -max_length:]
398
+
399
+ output = self.llm.model(
400
+ inputs_embeds=inputs_embeds,
401
+ attention_mask=attention_mask,
402
+ position_ids=position_ids,
403
+ return_dict=True)
404
+
405
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
406
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
407
+
408
+ loss_diff = self.diff_loss(model_input=image_latents,
409
+ pooled_prompt_embeds=pooled_prompt_embeds,
410
+ prompt_embeds=prompt_embeds)
411
+
412
+ return loss_diff
413
+
414
+ '''text-to-image generation (single-view) with camera map'''
415
+ def cam2image_loss(self, data_dict):
416
+ pixel_values = [p.to(dtype=self.dtype, device=self.device) for p in data_dict['pixel_values']]
417
+ image_latents = [self.pixels_to_latents(p[None])[0] for p in pixel_values]
418
+ b = len(image_latents)
419
+ # camera map as condition for the diffusion model
420
+ cam_values = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
421
+ for ref_images in data_dict['cam_values']]
422
+ cam_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
423
+ for ref_images in cam_values]
424
+
425
+ texts = ['' if random.uniform(0, 1) < self.unconditional else text
426
+ for text in data_dict['texts']]
427
+
428
+ text_inputs = self.prepare_gen_prompts(texts)
429
+ hidden_states = self.meta_queries[None].expand(b, self.num_queries, -1)
430
+
431
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
432
+
433
+ max_length = self.max_length + self.num_queries
434
+ inputs_embeds = inputs['inputs_embeds'][:, -max_length:]
435
+ attention_mask = inputs['attention_mask'][:, -max_length:]
436
+ position_ids = inputs['position_ids'][:, -max_length:]
437
+
438
+ output = self.llm.model(
439
+ inputs_embeds=inputs_embeds,
440
+ attention_mask=attention_mask,
441
+ position_ids=position_ids,
442
+ return_dict=True)
443
+
444
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
445
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
446
+
447
+ loss_diff = self.diff_loss(model_input=image_latents,
448
+ pooled_prompt_embeds=pooled_prompt_embeds,
449
+ prompt_embeds=prompt_embeds,
450
+ cond_input=cam_latents)
451
+
452
+ return loss_diff
453
+
454
+ '''image-to-image (cross-view) generation'''
455
+ def image2image_loss(self, data_dict):
456
+ # condition for the diffusion model (concat the camera map and the initial view)
457
+ cam_values = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
458
+ for ref_images in data_dict['cam_values']]
459
+ cam_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
460
+ for ref_images in cam_values]
461
+ pixel_values_init = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
462
+ for ref_images in data_dict['pixel_values_init']]
463
+ image_latents_init = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
464
+ for ref_images in pixel_values_init]
465
+ mix_latents = [cam + img for cam, img in zip(cam_latents, image_latents_init)]
466
+
467
+ # condition embedding for querying the LLM (only initial view)
468
+ num_refs = [len(ref_images) for ref_images in pixel_values_init]
469
+ image_embeds = self.extract_visual_features(
470
+ torch.stack([pad_an_image_tensor(img) for ref_images in pixel_values_init for img in ref_images]))
471
+
472
+ image_embeds = self.projector(image_embeds)
473
+ ref_lens = [len(x) for x in image_embeds]
474
+ text_inputs = self.prepare_gen_prompts(data_dict['texts'], data_type='image2image',
475
+ num_refs=num_refs, ref_lens=ref_lens)
476
+
477
+ # input for the diffusion model
478
+ pixel_values = [p.to(dtype=self.dtype, device=self.device) for p in data_dict['pixel_values']]
479
+ image_latents = [self.pixels_to_latents(p[None])[0] for p in pixel_values]
480
+
481
+ # querying the LLM
482
+ b = len(image_latents)
483
+ hidden_states = self.meta_queries[None].expand(b, self.num_queries, -1)
484
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, image_embeds=image_embeds, **text_inputs)
485
+
486
+ max_length = self.max_length + max(num_refs) * max(ref_lens) + self.num_queries
487
+ inputs_embeds = inputs['inputs_embeds'][:, -max_length:]
488
+ attention_mask = inputs['attention_mask'][:, -max_length:]
489
+ position_ids = inputs['position_ids'][:, -max_length:]
490
+
491
+ output = self.llm.model(inputs_embeds=inputs_embeds,
492
+ attention_mask=attention_mask,
493
+ position_ids=position_ids,
494
+ return_dict=True)
495
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
496
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
497
+ loss_diff = self.diff_loss(model_input=image_latents,
498
+ pooled_prompt_embeds=pooled_prompt_embeds,
499
+ prompt_embeds=prompt_embeds,
500
+ cond_input=mix_latents)
501
+
502
+ return loss_diff
503
+
504
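A small sketch (toy tensors standing in for VAE latents) of how the cross-view branch above builds its conditioning: the per-sample camera-map latents and initial-view latents are concatenated as Python lists, one list per target image.

import torch

cam_latents  = [[torch.randn(16, 64, 64)]]    # one camera map for sample 0
init_latents = [[torch.randn(16, 64, 64)]]    # one initial view for sample 0
mix_latents  = [cam + img for cam, img in zip(cam_latents, init_latents)]
print(len(mix_latents), len(mix_latents[0]))  # 1 2: two conditioning latents for the sample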
+ '''image-to-text (camera) understanding, mixing base, thinking, and instruction-tuning data'''
505
+ def image2text_loss(self, data_dict):
506
+ pixel_values = [pad_an_image_tensor(img) for img in data_dict['pixel_values']]
507
+ pixel_values = torch.stack(pixel_values).to(dtype=self.dtype, device=self.device)
508
+ image_embeds = self.extract_visual_features(pixel_values)
509
+
510
+ if not self.freeze_visual_encoder:
511
+ image_embeds = _ScaleGradient.apply(image_embeds, self.visual_encoder_grad_scale)
512
+
513
+ image_embeds = self.projector(image_embeds)
514
+ text_inputs = self.prepare_und_prompts(conversations=data_dict['conversations'],
515
+ data_type='image2text',
516
+ image_lengths=image_embeds.shape[1])
517
+
518
+ labels, input_ids, attention_mask, position_ids = \
519
+ text_inputs['labels'], text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
520
+
521
+
522
+ inputs_embeds = torch.zeros(*input_ids.shape, self.llm.config.hidden_size,
523
+ device=self.device, dtype=self.dtype)
524
+ inputs_embeds[input_ids == INPUT_IMAGE_TOKEN_INDEX] = image_embeds.flatten(0, 1)
525
+ inputs_embeds[input_ids != INPUT_IMAGE_TOKEN_INDEX] = \
526
+ self.llm.get_input_embeddings()(input_ids[input_ids != INPUT_IMAGE_TOKEN_INDEX])
527
+
528
+ max_length = self.max_length + image_embeds.shape[1]
529
+ inputs_embeds = inputs_embeds[:, -max_length:]
530
+ attention_mask = attention_mask[:, -max_length:]
531
+ position_ids = position_ids[:, -max_length:]
532
+ labels = labels[:, -max_length:]
533
+
534
+ output = self.llm.model(inputs_embeds=inputs_embeds,
535
+ attention_mask=attention_mask,
536
+ position_ids=position_ids,
537
+ return_dict=True)
538
+
539
+ hidden_states = output.last_hidden_state[:, :-1]
540
+ labels = labels[:, 1:]
541
+ hidden_states = hidden_states[labels >= 0]
542
+ labels = labels[labels >= 0]
543
+
544
+ logits = self.llm.get_output_embeddings()(hidden_states)
545
+ loss = F.cross_entropy(input=logits, target=labels)
546
+
547
+ return loss
548
+
549
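A self-contained sketch of the next-token shift used by image2text_loss above (and by text2text_loss below); IGNORE_INDEX = -100 and the toy shapes are assumptions.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
logits = torch.randn(1, 5, 10)                     # (batch, seq, vocab)
labels = torch.tensor([[IGNORE_INDEX, 4, 7, 1, 2]])
shift_logits = logits[:, :-1]                      # position t predicts ...
shift_labels = labels[:, 1:]                       # ... the label at t + 1
keep = shift_labels >= 0                           # drop ignored positions
loss = F.cross_entropy(shift_logits[keep], shift_labels[keep])
print(float(loss))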
+ '''text-to-text understanding, producing enhanced captions for generation'''
550
+ def text2text_loss(self, data_dict):
551
+ text_inputs = self.prepare_und_prompts(conversations=data_dict['conversations'], data_type='text2text')
552
+ labels, input_ids, attention_mask, position_ids = \
553
+ text_inputs['labels'], text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
554
+
555
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
556
+ max_length = self.max_length
557
+ inputs_embeds = inputs_embeds[:, -max_length:]
558
+ attention_mask = attention_mask[:, -max_length:]
559
+ position_ids = position_ids[:, -max_length:]
560
+ labels = labels[:, -max_length:]
561
+
562
+ output = self.llm.model(inputs_embeds=inputs_embeds,
563
+ attention_mask=attention_mask,
564
+ position_ids=position_ids,
565
+ return_dict=True)
566
+
567
+ hidden_states = output.last_hidden_state[:, :-1]
568
+ labels = labels[:, 1:]
569
+ hidden_states = hidden_states[labels >= 0]
570
+ labels = labels[labels >= 0]
571
+
572
+ logits = self.llm.get_output_embeddings()(hidden_states)
573
+ loss = F.cross_entropy(input=logits, target=labels)
574
+
575
+ return loss
576
+
577
+ '''dispatch the corresponding loss function for each task'''
578
+ def compute_loss(self, data_dict):
579
+ loss_fn_map = {
580
+ 'text2image': self.text2image_loss,
581
+ 'cam2image': self.cam2image_loss,
582
+ 'image2text': self.image2text_loss,
583
+ 'text2text': self.text2text_loss,
584
+ 'image2image': self.image2image_loss,
585
+ 'image2text_cross_view': self.image2text_loss,
586
+ }
587
+
588
+ losses = {}
589
+ for data_type, batch_data in data_dict.items():
590
+ if data_type not in loss_fn_map:
591
+ raise ValueError(f"Unsupported data_type: {data_type}")
592
+ loss_fn = loss_fn_map[data_type]
593
+ loss = loss_fn(batch_data)
594
+ losses[f'loss_{data_type}'] = loss
595
+ return losses
596
+
597
+ @torch.no_grad()
598
+ def generate(self,
599
+ prompt,
600
+ cfg_prompt,
601
+ cam_values=None,
602
+ pixel_values_init=None,
603
+ cfg_scale=4.5,
604
+ num_steps=50,
605
+ generator=None,
606
+ height=512,
607
+ width=512,
608
+ max_new_tokens=512,
609
+ reasoning=False,
610
+ prompt_reasoning=None,
611
+ progress_bar=True):
612
+ assert len(prompt) == len(cfg_prompt)
613
+ b = len(prompt)
614
+ output_reasoning = [''] * b
615
+
616
+ if reasoning:
617
+ # enrich the prompt when reasoning generation is requested
618
+ assert prompt_reasoning is not None, \
619
+ "prompt_reasoning must be provided for reasoning generation"
620
+ if isinstance(prompt_reasoning, str):
621
+ prompt_reasoning = [prompt_reasoning]
622
+ if isinstance(prompt, str):
623
+ prompt = [prompt]
624
+
625
+ conversations = [[{'input': f"{p1} {p2}",}]
626
+ for p1, p2 in zip(prompt_reasoning, prompt)]
627
+
628
+ text_inputs = self.prepare_und_prompts(
629
+ conversations=conversations, data_type="text2text", input_ids_with_output=False)
630
+ input_ids, attention_mask, position_ids = \
631
+ text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
632
+
633
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
634
+ past_key_values = DynamicCache.from_legacy_cache()
635
+
636
+ output_ids = []
637
+ for _ in tqdm(range(max_new_tokens), disable=not progress_bar):
638
+ output = self.llm.model(
639
+ inputs_embeds=inputs_embeds,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ past_key_values=past_key_values,
643
+ use_cache=True,
644
+ return_dict=True)
645
+ logits = self.llm.get_output_embeddings()(output.last_hidden_state[:, -1:])
646
+ input_ids = torch.argmax(logits, dim=-1) # shape: (b, 1)
647
+ if len(output_ids) > 0:
648
+ input_ids = torch.where(output_ids[-1] == self.tokenizer.eos_token_id,
649
+ output_ids[-1], input_ids)
650
+ output_ids.append(input_ids)
651
+
652
+ if (input_ids == self.tokenizer.eos_token_id).all():
653
+ break
654
+
655
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
656
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, 1)], dim=1)
657
+ position_ids = torch.max(position_ids, dim=1, keepdim=True).values + 1
658
+ past_key_values = output.past_key_values
659
+
660
+ output_ids = torch.cat(output_ids, dim=1)
661
+ output_reasoning = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
662
+ prompt = [f"{p} {o}" for p, o in zip(prompt, output_reasoning)]
663
+
664
+ if cam_values is not None:
665
+ # for the generation with the camera map
666
+ cam_values = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
667
+ for ref_images in cam_values]
668
+ cond_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
669
+ for ref_images in cam_values]
670
+ text_inputs = self.prepare_gen_prompts(prompt + cfg_prompt)
671
+ if pixel_values_init is not None:
672
+ # for the generation with the camera map and initial view (cross-view generation)
673
+ num_refs = [len(ref_images) for ref_images in pixel_values_init]
674
+ pixel_values_init = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
675
+ for ref_images in pixel_values_init]
676
+ image_embeds = self.extract_visual_features(
677
+ torch.stack([pad_an_image_tensor(img) for ref_images in pixel_values_init for img in ref_images]))
678
+ image_embeds = self.projector(image_embeds)
679
+
680
+ ref_lens = [len(x) for x in image_embeds]
681
+ text_inputs = self.prepare_gen_prompts(prompt + cfg_prompt, data_type='image2image', num_refs=num_refs*2, ref_lens=ref_lens*2)
682
+ text_inputs.update(image_embeds=torch.cat([image_embeds]*2))
683
+
684
+ cond_latents_init = [[self.pixels_to_latents(img[None])[0] for img in ref_imgs]
685
+ for ref_imgs in pixel_values_init]
686
+ cond_latents = [cam + img for cam, img in zip(cond_latents, cond_latents_init)]
687
+
688
+ cond_latents = cond_latents * 2
689
+ else:
690
+ # for the text2image generation
691
+ text_inputs = self.prepare_gen_prompts(prompt + cfg_prompt)
692
+ cond_latents = None
693
+
694
+ hidden_states = self.meta_queries[None].expand(2*b, self.num_queries, -1)
695
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
696
+
697
+ output = self.llm.model(**inputs, return_dict=True)
698
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
699
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
700
+
701
+ pipeline = StableDiffusion3Pipeline(
702
+ transformer=self.transformer,
703
+ scheduler=self.test_scheduler,
704
+ vae=self.vae,
705
+ text_encoder=None,
706
+ tokenizer=None,
707
+ text_encoder_2=None,
708
+ tokenizer_2=None,
709
+ text_encoder_3=None,
710
+ tokenizer_3=None,
711
+ )
712
+
713
+ pipeline.set_progress_bar_config(disable=not progress_bar)
714
+
715
+ samples = pipeline(
716
+ height=height,
717
+ width=width,
718
+ guidance_scale=cfg_scale,
719
+ num_inference_steps=num_steps,
720
+ prompt_embeds=prompt_embeds[:b],
721
+ pooled_prompt_embeds=pooled_prompt_embeds[:b],
722
+ negative_prompt_embeds=prompt_embeds[b:],
723
+ negative_pooled_prompt_embeds=pooled_prompt_embeds[b:],
724
+ generator=generator,
725
+ output_type='latent',
726
+ cond_latents=cond_latents
727
+ ).images.to(self.dtype)
728
+
729
+ return self.latents_to_pixels(samples), output_reasoning
730
+
731
+ @torch.no_grad()
732
+ def understand(self, prompt, pixel_values, max_new_tokens=512, progress_bar=True):
733
+ if isinstance(prompt, str):
734
+ prompt = [prompt]
735
+ if isinstance(pixel_values, torch.Tensor):
736
+ pixel_values = [pixel_values]
737
+
738
+ bsz = len(prompt)
739
+ assert len(pixel_values) == bsz
740
+
741
+ pixel_values = [pad_an_image_tensor(img) for img in pixel_values]
742
+ pixel_values = torch.stack(pixel_values).to(dtype=self.dtype, device=self.device)
743
+ image_embeds = self.extract_visual_features(pixel_values)
744
+ image_embeds = self.projector(image_embeds)
745
+
746
+ conversations = [[{'input': f"{DEFAULT_IMAGE_TOKEN}\n{p}",}] for p in prompt]
747
+
748
+ text_inputs = self.prepare_und_prompts(conversations=conversations, image_lengths=image_embeds.shape[1],
749
+ input_ids_with_output=False)
750
+
751
+ input_ids, attention_mask, position_ids = \
752
+ text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
753
+
754
+ inputs_embeds = torch.zeros(*input_ids.shape, self.llm.config.hidden_size,
755
+ device=self.device, dtype=self.dtype)
756
+ inputs_embeds[input_ids == INPUT_IMAGE_TOKEN_INDEX] = image_embeds.flatten(0, 1)
757
+ inputs_embeds[input_ids != INPUT_IMAGE_TOKEN_INDEX] = \
758
+ self.llm.get_input_embeddings()(input_ids[input_ids != INPUT_IMAGE_TOKEN_INDEX])
759
+
760
+ past_key_values = DynamicCache.from_legacy_cache()
761
+
762
+ output_ids = []
763
+
764
+ for _ in tqdm(range(max_new_tokens), disable=not progress_bar):
765
+ output = self.llm.model(
766
+ inputs_embeds=inputs_embeds,
767
+ attention_mask=attention_mask,
768
+ position_ids=position_ids,
769
+ past_key_values=past_key_values,
770
+ use_cache=True,
771
+ return_dict=True)
772
+ logits = self.llm.get_output_embeddings()(output.last_hidden_state[:, -1:])
773
+ input_ids = torch.argmax(logits, dim=-1) # shape: (b, 1)
774
+ if len(output_ids) > 0:
775
+ input_ids = torch.where(output_ids[-1] == self.tokenizer.eos_token_id,
776
+ output_ids[-1], input_ids)
777
+ output_ids.append(input_ids)
778
+
779
+ if (input_ids == self.tokenizer.eos_token_id).all():
780
+ break
781
+
782
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
783
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones(bsz, 1)], dim=1)
784
+ position_ids = torch.max(position_ids, dim=1, keepdim=True).values + 1
785
+ past_key_values = output.past_key_values
786
+
787
+ output_ids = torch.cat(output_ids, dim=1)
788
+ output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
789
+
790
+ return output_text
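A self-contained sketch (toy argmax outputs, assumed eos id) of the EOS "freezing" trick shared by the decoding loops in generate and understand above: once a row emits EOS it keeps re-emitting EOS, so finished sequences stay inert while the rest of the batch continues.

import torch

eos_id = 2
output_ids = []
steps = [torch.tensor([[5], [2]]),   # step 0: row 1 already hits EOS
         torch.tensor([[7], [9]]),
         torch.tensor([[2], [4]])]
for next_ids in steps:
    if output_ids:
        next_ids = torch.where(output_ids[-1] == eos_id, output_ids[-1], next_ids)
    output_ids.append(next_ids)
    if (next_ids == eos_id).all():
        break
print(torch.cat(output_ids, dim=1))  # tensor([[5, 7, 2], [2, 2, 2]])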
src/models/radiov3/adaptor_base.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+ from typing import NamedTuple, Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class AdaptorInput(NamedTuple):
17
+ images: torch.Tensor
18
+ summary: torch.Tensor
19
+ features: torch.Tensor
20
+ feature_fmt: str
21
+ patch_size: int
22
+
23
+
24
+ class RadioOutput(NamedTuple):
25
+ summary: torch.Tensor
26
+ features: torch.Tensor
27
+
28
+ def to(self, *args, **kwargs):
29
+ return RadioOutput(
30
+ self.summary.to(*args, **kwargs) if self.summary is not None else None,
31
+ self.features.to(*args, **kwargs) if self.features is not None else None,
32
+ )
33
+
34
+
35
+ class AdaptorBase(nn.Module):
36
+ def forward(self, input: AdaptorInput) -> RadioOutput:
37
+ raise NotImplementedError("Subclasses must implement this!")
src/models/radiov3/adaptor_generic.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
15
+ from .adaptor_mlp import create_mlp_from_state, create_mlp_from_config
16
+
17
+
18
+ class GenericAdaptor(AdaptorBase):
19
+ def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
20
+ super().__init__()
21
+
22
+ extra_args = dict()
23
+ ups = None
24
+ ups_rank = None
25
+ if adaptor_config is not None:
26
+ ups = adaptor_config.get('fd_upsample_factor', None)
27
+ ups_rank = adaptor_config.get('fd_upsample_rank', None)
28
+ elif mlp_config is not None:
29
+ ups = mlp_config["feature"].get('upsample_factor', None)
30
+ ups_rank = mlp_config["feature"].get('upsample_rank', None)
31
+ if ups is not None:
32
+ extra_args['upsample_factor'] = ups
33
+ extra_args['upsample_rank'] = ups_rank
34
+
35
+ if state is not None:
36
+ spectral_heads = getattr(main_config, 'spectral_heads', False)
37
+ self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
38
+ self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
39
+ else:
40
+ assert mlp_config is not None, "Config must not be None if state is None"
41
+
42
+ self.head_mlp = create_mlp_from_config(
43
+ main_config.mlp_version,
44
+ mlp_config["summary"]["input_dim"],
45
+ mlp_config["summary"]["hidden_dim"],
46
+ mlp_config["summary"]["output_dim"],
47
+ mlp_config["summary"]["num_inner"],
48
+ )
49
+ self.feat_mlp = create_mlp_from_config(
50
+ main_config.mlp_version,
51
+ mlp_config["feature"]["input_dim"],
52
+ mlp_config["feature"]["hidden_dim"],
53
+ mlp_config["feature"]["output_dim"],
54
+ mlp_config["feature"]["num_inner"],
55
+ **extra_args
56
+ )
57
+
58
+ def forward(self, input: AdaptorInput) -> RadioOutput:
59
+ # Convert the input's dtype to the dtype of the adaptor's first parameter.
60
+ first_param = next(self.parameters())
61
+ summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
62
+ feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)
63
+
64
+ if input.feature_fmt == 'NCHW':
65
+ feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
66
+ .permute(0, 3, 1, 2)
67
+ )
68
+
69
+ return RadioOutput(summary, feat)
src/models/radiov3/adaptor_mlp.py ADDED
@@ -0,0 +1,174 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ import math
9
+ from typing import Dict, Optional
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Block
16
+
17
+ from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
18
+
19
+
20
+ class MLP(nn.Module):
21
+ def __init__(self, input_size: int, hidden_size: int, output_size: int,
22
+ num_inner: int = 0, device: torch.device = None, **kwargs):
23
+ super(MLP, self).__init__()
24
+ self.fc1 = nn.Linear(input_size, hidden_size, device=device)
25
+ self.norm = nn.LayerNorm(hidden_size, device=device)
26
+ self.relu = nn.ReLU()
27
+
28
+ inner = []
29
+ for _ in range(num_inner):
30
+ inner.extend([
31
+ nn.Linear(hidden_size, hidden_size, device=device),
32
+ nn.LayerNorm(hidden_size, device=device),
33
+ nn.ReLU(),
34
+ ])
35
+ if inner:
36
+ self.inner = nn.Sequential(*inner)
37
+ else:
38
+ self.inner = nn.Identity()
39
+
40
+ self.fc2 = nn.Linear(hidden_size, output_size, device=device)
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ x = self.fc1(x)
44
+ x = self.norm(x)
45
+ x = self.relu(x)
46
+ x = self.inner(x)
47
+ x = self.fc2(x)
48
+ return x
49
+
50
+
51
+ class MLP2(nn.Module):
52
+ def __init__(self, input_size: int, hidden_size: int, output_size: int,
53
+ num_inner: int = 0,
54
+ pre_norm: bool = False, device: torch.device = None,
55
+ upsample_factor: int = 1,
56
+ upsample_rank: int = None,
57
+ from_config: bool = False,
58
+ **kwargs):
59
+ super().__init__()
60
+
61
+ self.pre_norm = nn.Sequential(
62
+ nn.LayerNorm(input_size),
63
+ nn.GELU(),
64
+ ) if pre_norm else nn.Identity()
65
+
66
+ self.upsample_factor = upsample_factor
67
+ sq_ups = upsample_factor ** 2
68
+
69
+ self._real_output_dim = output_size // sq_ups
70
+
71
+ # hidden_size *= upsample_factor
72
+ # output_size *= (upsample_factor ** 2)
73
+
74
+ self.fc1 = nn.Linear(input_size, hidden_size, device=device)
75
+
76
+ blocks = []
77
+ for _ in range(num_inner):
78
+ blocks.append(nn.Sequential(
79
+ nn.LayerNorm(hidden_size, device=device),
80
+ nn.GELU(),
81
+ nn.Linear(hidden_size, hidden_size, device=device),
82
+ ))
83
+ self.blocks = nn.ModuleList(blocks)
84
+
85
+ self.final = nn.Sequential(
86
+ nn.LayerNorm(hidden_size, device=device),
87
+ nn.GELU(),
88
+ nn.Linear(hidden_size, output_size, device=device),
89
+ )
90
+
91
+ def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
92
+ x = self.pre_norm(x)
93
+ x = self.fc1(x)
94
+ for block in self.blocks:
95
+ x = x + block(x)
96
+ x = self.final(x)
97
+
98
+ if self.upsample_factor > 1:
99
+ if images is None:
100
+ raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
101
+ if patch_size is None:
102
+ raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
103
+ h, w = tuple(d // patch_size for d in images.shape[-2:])
104
+ x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
105
+ h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
106
+ c=self._real_output_dim)
107
+
108
+ return x
109
+
110
+
111
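An illustrative shape check (assumed sizes, not from any checkpoint) of the upsampling rearrange in MLP2.forward above: with upsample_factor u, each patch token carries u*u sub-tokens that are scattered back onto a u-times finer grid.

import torch
from einops import rearrange

b, h, w, u, c = 1, 4, 4, 2, 8
x = torch.randn(b, h * w, u * u * c)              # head output per patch token
y = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
              h=h, w=w, u1=u, u2=u, c=c)
print(y.shape)                                     # torch.Size([1, 64, 8])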
+ MLP_FACTORY = {
112
+ 'v1': MLP,
113
+ 'v2': MLP2,
114
+ }
115
+
116
+
117
+ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
118
+ state = {
119
+ k[len(prefix):]: v
120
+ for k, v in state.items()
121
+ if k.startswith(prefix)
122
+ }
123
+ return state
124
+
125
+
126
+ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
127
+ state = strip_prefix(state, prefix)
128
+
129
+ weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
130
+
131
+ if version == 'v1':
132
+ hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
133
+ output_dim = state[f'fc2.{weight_suffix}'].shape[0]
134
+
135
+ for num_inner in range(1000):
136
+ k = f'inner.{num_inner}.0.weight'
137
+ if k not in state:
138
+ break
139
+ elif version == 'v2':
140
+ hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
141
+ output_dim = state[f'final.2.{weight_suffix}'].shape[0]
142
+
143
+ for num_inner in range(1000):
144
+ k = f'blocks.{num_inner}.0.weight'
145
+ if k not in state:
146
+ break
147
+ else:
148
+ raise ValueError(f'Unsupported MLP version: {version}')
149
+
150
+ return input_dim, hidden_dim, output_dim, num_inner
151
+
152
+
153
+ def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
154
+ ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
155
+
156
+ return ret
157
+
158
+
159
+ def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
160
+ state = strip_prefix(state, prefix)
161
+
162
+ input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
163
+
164
+ ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)
165
+
166
+ if spectral_weights:
167
+ enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
168
+
169
+ ret.load_state_dict(state)
170
+
171
+ if spectral_weights:
172
+ disable_spectral_reparam(ret)
173
+
174
+ return ret
src/models/radiov3/adaptor_registry.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+ from typing import Dict, Any
10
+
11
+ import torch
12
+
13
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
14
+
15
+ dict_t = Dict[str, Any]
16
+ state_t = Dict[str, torch.Tensor]
17
+
18
+
19
+ class AdaptorRegistry:
20
+ def __init__(self):
21
+ self._registry = {}
22
+
23
+ def register_adaptor(self, name):
24
+ def decorator(factory_function):
25
+ if name in self._registry:
26
+ raise ValueError(f"Model '{name}' already registered")
27
+ self._registry[name] = factory_function
28
+ return factory_function
29
+ return decorator
30
+
31
+ def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
32
+ if name not in self._registry:
33
+ return GenericAdaptor(main_config, adaptor_config, state)
34
+ return self._registry[name](main_config, adaptor_config, state)
35
+
36
+ # Creating an instance of the registry
37
+ adaptor_registry = AdaptorRegistry()
src/models/radiov3/cls_token.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from typing import Optional
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ class ClsToken(nn.Module):
15
+ def __init__(self, ndim: int,
16
+ num_tokens: int = 1,
17
+ enabled: bool = True,
18
+ register_multiple: Optional[int] = None,
19
+ num_registers: Optional[int] = None,
20
+ ):
21
+ super().__init__()
22
+
23
+ self.ndim = ndim
24
+ self.enabled = enabled
25
+ self.num_registers = 0
26
+ self.num_tokens = num_tokens
27
+ if enabled:
28
+ if num_registers:
29
+ self.num_registers = num_registers
30
+ elif register_multiple:
31
+ self.num_registers = register_multiple - (num_tokens % register_multiple)
32
+
33
+ scale = ndim ** -0.5
34
+ self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
35
+ else:
36
+ self.token = None
37
+
38
+ self.num_patches = self.num_tokens + self.num_registers
39
+
40
+ def disable(self):
41
+ self.token = None
42
+ self.enabled = False
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ if self.token is None:
46
+ return x
47
+
48
+ token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
49
+ x = torch.cat([
50
+ token,
51
+ x,
52
+ ], dim=1)
53
+
54
+ return x
55
+
56
+ def no_weight_decay(self):
57
+ return [
58
+ 'token',
59
+ ]
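A worked example (assumed values) of the register sizing in ClsToken.__init__ above: with one class token and register_multiple=8, seven registers pad the token prefix to a multiple of eight.

num_tokens, register_multiple = 1, 8
num_registers = register_multiple - (num_tokens % register_multiple)
print(num_registers, num_tokens + num_registers)   # 7 8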
src/models/radiov3/common.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ from .radio_model import Resolution
13
+
14
+
15
+ @dataclass
16
+ class RadioResource:
17
+ url: str
18
+ patch_size: int
19
+ max_resolution: int
20
+ preferred_resolution: Resolution
21
+ supports_vitdet: bool = True
22
+ vitdet_num_windowed: Optional[int] = None
23
+ vitdet_num_global: Optional[int] = None
24
+
25
+
26
+ RESOURCE_MAP = {
27
+ # RADIOv2.5
28
+ "radio_v2.5-b": RadioResource(
29
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-b_half.pth.tar?download=true",
30
+ patch_size=16,
31
+ max_resolution=2048,
32
+ preferred_resolution=(768, 768),
33
+ vitdet_num_global=4,
34
+ ),
35
+ "radio_v2.5-l": RadioResource(
36
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio-v2.5-l_half.pth.tar?download=true",
37
+ patch_size=16,
38
+ max_resolution=2048,
39
+ preferred_resolution=(768, 768),
40
+ vitdet_num_global=4,
41
+ ),
42
+ "radio_v2.5-h": RadioResource(
43
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h.pth.tar?download=true",
44
+ patch_size=16,
45
+ max_resolution=2048,
46
+ preferred_resolution=(768, 768),
47
+ vitdet_num_global=4,
48
+ ),
49
+ "radio_v2.5-h-norm": RadioResource(
50
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-h-norm.pth.tar?download=true",
51
+ patch_size=16,
52
+ max_resolution=2048,
53
+ preferred_resolution=(768, 768),
54
+ vitdet_num_global=4,
55
+ ),
56
+ "radio_v2.5-g": RadioResource(
57
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.5-g.pth.tar?download=true",
58
+ patch_size=14,
59
+ max_resolution=1792,
60
+ preferred_resolution=(896, 896),
61
+ vitdet_num_global=8,
62
+ ),
63
+ # RADIO
64
+ "radio_v2.1": RadioResource(
65
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.1_bf16.pth.tar?download=true",
66
+ patch_size=16,
67
+ max_resolution=2048,
68
+ preferred_resolution=Resolution(432, 432),
69
+ vitdet_num_windowed=5,
70
+ ),
71
+ "radio_v2": RadioResource(
72
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v2.pth.tar?download=true",
73
+ patch_size=16,
74
+ max_resolution=2048,
75
+ preferred_resolution=Resolution(432, 432),
76
+ vitdet_num_windowed=5,
77
+ ),
78
+ "radio_v1": RadioResource(
79
+ "https://huggingface.co/nvidia/RADIO/resolve/main/radio_v1.pth.tar?download=true",
80
+ patch_size=14,
81
+ max_resolution=1050,
82
+ preferred_resolution=Resolution(378, 378),
83
+ ),
84
+ # E-RADIO
85
+ "e-radio_v2": RadioResource(
86
+ "https://huggingface.co/nvidia/RADIO/resolve/main/eradio_v2.pth.tar?download=true",
87
+ patch_size=16,
88
+ max_resolution=2048,
89
+ preferred_resolution=Resolution(512, 512),
90
+ ),
91
+ # C-RADIO
92
+ "c-radio_v2.5-g": RadioResource(
93
+ "https://huggingface.co/nvidia/C-RADIOv2-g/resolve/main/c-radio_v2-g_half.pth.tar",
94
+ patch_size=16,
95
+ max_resolution=2048,
96
+ preferred_resolution=(768, 768),
97
+ vitdet_num_global=8,
98
+ ),
99
+ "c-radio_v3-b": RadioResource(
100
+ # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-B
101
+ # and accept the license terms.
102
+ "https://huggingface.co/nvidia/C-RADIOv3-B/resolve/main/c-radio-v3_b_half.pth.tar?download=true",
103
+ patch_size=16,
104
+ max_resolution=2048,
105
+ preferred_resolution=Resolution(512, 512),
106
+ supports_vitdet=False,
107
+ ),
108
+ "c-radio_v3-l": RadioResource(
109
+ # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
110
+ # and accept the license terms.
111
+ "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
112
+ patch_size=16,
113
+ max_resolution=2048,
114
+ preferred_resolution=Resolution(512, 512),
115
+ ),
116
+ "c-radio_v3-h": RadioResource(
117
+ # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-H
118
+ # and accept the license terms.
119
+ "https://huggingface.co/nvidia/C-RADIOv3-H/resolve/main/c-radio_v3-h_half.pth.tar?download=true",
120
+ patch_size=16,
121
+ max_resolution=2048,
122
+ preferred_resolution=Resolution(512, 512),
123
+ ),
124
+ "c-radio_v3-g": RadioResource(
125
+ # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-G
126
+ # and accept the license terms.
127
+ "https://huggingface.co/nvidia/C-RADIOv3-G/resolve/main/c-radio-v3_g_half.pth.tar?download=true",
128
+ patch_size=16,
129
+ max_resolution=2048,
130
+ preferred_resolution=Resolution(512, 512),
131
+ ),
132
+ }
133
+
134
+ DEFAULT_VERSION = "c-radio_v3-h"
src/models/radiov3/dinov2_arch.py ADDED
@@ -0,0 +1,1016 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ # Nvidia
11
+ # NOTE: We re-define this model architecture primarily so that we don't have to worry about version compatibility breaking,
12
+ # but also because Huggingface does a string replace of `gamma` to something else when loading the model state,
13
+ # and this breaks loading of this model.
14
+
15
+ from enum import Enum
16
+ from functools import partial
17
+ import logging
18
+ import math
19
+ import os
20
+ import sys
21
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
22
+ import warnings
23
+
24
+ import torch
25
+ from torch import nn
26
+ from torch.nn import functional as F
27
+ from torch.nn.init import trunc_normal_
28
+
29
+ _torch_has_sdpa = hasattr(F, 'scaled_dot_product_attention')
30
+
31
+
32
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
33
+ try:
34
+ if XFORMERS_ENABLED:
35
+ from xformers.ops import fmha, scaled_index_add, index_select_cat, SwiGLU, memory_efficient_attention, unbind
36
+
37
+ XFORMERS_AVAILABLE = True
38
+ else:
39
+ raise ImportError
40
+ except ImportError:
41
+ XFORMERS_AVAILABLE = False
42
+
43
+
44
+ def make_2tuple(x):
45
+ if isinstance(x, tuple):
46
+ assert len(x) == 2
47
+ return x
48
+
49
+ assert isinstance(x, int)
50
+ return (x, x)
51
+
52
+
53
+ class PatchEmbed(nn.Module):
54
+ """
55
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
56
+
57
+ Args:
58
+ img_size: Image size.
59
+ patch_size: Patch token size.
60
+ in_chans: Number of input image channels.
61
+ embed_dim: Number of linear projection output channels.
62
+ norm_layer: Normalization layer.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ img_size: Union[int, Tuple[int, int]] = 224,
68
+ patch_size: Union[int, Tuple[int, int]] = 16,
69
+ in_chans: int = 3,
70
+ embed_dim: int = 768,
71
+ norm_layer: Optional[Callable] = None,
72
+ flatten_embedding: bool = True,
73
+ ) -> None:
74
+ super().__init__()
75
+
76
+ image_HW = make_2tuple(img_size)
77
+ patch_HW = make_2tuple(patch_size)
78
+ patch_grid_size = (
79
+ image_HW[0] // patch_HW[0],
80
+ image_HW[1] // patch_HW[1],
81
+ )
82
+
83
+ self.img_size = image_HW
84
+ self.patch_size = patch_HW
85
+ self.patches_resolution = patch_grid_size
86
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
87
+
88
+ self.in_chans = in_chans
89
+ self.embed_dim = embed_dim
90
+
91
+ self.flatten_embedding = flatten_embedding
92
+
93
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
94
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
95
+
96
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
97
+ _, _, H, W = x.shape
98
+ patch_H, patch_W = self.patch_size
99
+
100
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
101
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
102
+
103
+ x = self.proj(x) # B C H W
104
+ H, W = x.size(2), x.size(3)
105
+ x = x.flatten(2).transpose(1, 2) # B HW C
106
+ x = self.norm(x)
107
+ if not self.flatten_embedding:
108
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
109
+ return x
110
+
111
+ def flops(self) -> float:
112
+ Ho, Wo = self.patches_resolution
113
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
114
+ if self.norm is not None:
115
+ flops += Ho * Wo * self.embed_dim
116
+ return flops
117
+
118
+
119
+ class Attention(nn.Module):
120
+ def __init__(
121
+ self,
122
+ dim: int,
123
+ num_heads: int = 8,
124
+ qkv_bias: bool = False,
125
+ proj_bias: bool = True,
126
+ attn_drop: float = 0.0,
127
+ proj_drop: float = 0.0,
128
+ ) -> None:
129
+ super().__init__()
130
+ self.num_heads = num_heads
131
+ head_dim = dim // num_heads
132
+ self.scale = head_dim**-0.5
133
+
134
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
135
+ self.attn_drop = nn.Dropout(attn_drop)
136
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
137
+ self.proj_drop = nn.Dropout(proj_drop)
138
+
139
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
140
+ B, N, C = x.shape
141
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
142
+
143
+ q, k, v = qkv[0], qkv[1], qkv[2]
144
+ if _torch_has_sdpa:
145
+ x = F.scaled_dot_product_attention(
146
+ q, k, v,
147
+ is_causal=False,
148
+ dropout_p=self.attn_drop.p if self.training else 0.,
149
+ scale=self.scale,
150
+ )
151
+ else:
152
+ q = q * self.scale
153
+ attn = q @ k.transpose(-2, -1)
154
+
155
+ attn = attn.softmax(dim=-1)
156
+ attn = self.attn_drop(attn)
157
+ x = attn @ v
158
+
159
+ x = x.transpose(1, 2).reshape(B, N, C)
160
+ x = self.proj(x)
161
+ x = self.proj_drop(x)
162
+ return x
163
+
164
+
165
+ class MemEffAttention(Attention):
166
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
167
+ if not XFORMERS_AVAILABLE:
168
+ if attn_bias is not None:
169
+ raise AssertionError("xFormers is required for using nested tensors")
170
+ return super().forward(x)
171
+
172
+ B, N, C = x.shape
173
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
174
+
175
+ q, k, v = unbind(qkv, 2)
176
+
177
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
178
+ x = x.reshape([B, N, C])
179
+
180
+ x = self.proj(x)
181
+ x = self.proj_drop(x)
182
+ return x
183
+
184
+
185
+ class Mlp(nn.Module):
186
+ def __init__(
187
+ self,
188
+ in_features: int,
189
+ hidden_features: Optional[int] = None,
190
+ out_features: Optional[int] = None,
191
+ act_layer: Callable[..., nn.Module] = nn.GELU,
192
+ drop: float = 0.0,
193
+ bias: bool = True,
194
+ ) -> None:
195
+ super().__init__()
196
+ out_features = out_features or in_features
197
+ hidden_features = hidden_features or in_features
198
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
199
+ self.act = act_layer()
200
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
201
+ self.drop = nn.Dropout(drop)
202
+
203
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
204
+ x = self.fc1(x)
205
+ x = self.act(x)
206
+ x = self.drop(x)
207
+ x = self.fc2(x)
208
+ x = self.drop(x)
209
+ return x
210
+
211
+
212
+ class SwiGLUFFN(nn.Module):
213
+ def __init__(
214
+ self,
215
+ in_features: int,
216
+ hidden_features: Optional[int] = None,
217
+ out_features: Optional[int] = None,
218
+ act_layer: Callable[..., nn.Module] = None,
219
+ drop: float = 0.0,
220
+ bias: bool = True,
221
+ ) -> None:
222
+ super().__init__()
223
+ out_features = out_features or in_features
224
+ hidden_features = hidden_features or in_features
225
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
226
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
227
+
228
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
229
+ x12 = self.w12(x)
230
+ x1, x2 = x12.chunk(2, dim=-1)
231
+ hidden = F.silu(x1) * x2
232
+ return self.w3(hidden)
233
+
234
+
235
+ if not XFORMERS_AVAILABLE:
236
+ SwiGLU = SwiGLUFFN
237
+
238
+
239
+ class SwiGLUFFNFused(SwiGLU):
240
+ def __init__(
241
+ self,
242
+ in_features: int,
243
+ hidden_features: Optional[int] = None,
244
+ out_features: Optional[int] = None,
245
+ act_layer: Callable[..., nn.Module] = None,
246
+ drop: float = 0.0,
247
+ bias: bool = True,
248
+ ) -> None:
249
+ out_features = out_features or in_features
250
+ hidden_features = hidden_features or in_features
251
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
252
+ super().__init__(
253
+ in_features=in_features,
254
+ hidden_features=hidden_features,
255
+ out_features=out_features,
256
+ bias=bias,
257
+ )
258
+
259
+
260
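A worked example (assumed request size) of the hidden-width rule in SwiGLUFFNFused above: roughly two thirds of the requested width, rounded up to a multiple of 8.

hidden_features = 4 * 1024                       # e.g. an mlp_ratio=4 request
fused = (int(hidden_features * 2 / 3) + 7) // 8 * 8
print(fused)                                     # 2736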
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
261
+ if drop_prob == 0.0 or not training:
262
+ return x
263
+ keep_prob = 1 - drop_prob
264
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
265
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
266
+ if keep_prob > 0.0:
267
+ random_tensor.div_(keep_prob)
268
+ output = x * random_tensor
269
+ return output
270
+
271
+
272
+ class DropPath(nn.Module):
273
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
274
+
275
+ def __init__(self, drop_prob=None):
276
+ super(DropPath, self).__init__()
277
+ self.drop_prob = drop_prob
278
+
279
+ def forward(self, x):
280
+ return drop_path(x, self.drop_prob, self.training)
281
+
282
+
283
+ class LayerScale(nn.Module):
284
+ def __init__(
285
+ self,
286
+ dim: int,
287
+ init_values: Union[float, torch.Tensor] = 1e-5,
288
+ inplace: bool = False,
289
+ ) -> None:
290
+ super().__init__()
291
+ self.inplace = inplace
292
+ self.grandma = nn.Parameter(init_values * torch.ones(dim))
293
+
294
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
295
+ return x.mul_(self.grandma) if self.inplace else x * self.grandma
296
+
297
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
298
+ # Huggingface renames state-dict keys that contain `gamma` when loading, which means that the normal DINO implementation
299
+ # of LayerScale won't work with HFHub. So we rename the variable to 'grandma', and support loading checkpoints in either
300
+ # format
301
+ key_a = f'{prefix}gamma'
302
+ key_b = f'{prefix}grandma'
303
+ if key_a in state_dict:
304
+ gamma = state_dict[key_a]
305
+ elif key_b in state_dict:
306
+ gamma = state_dict[key_b]
307
+ else:
308
+ if strict:
309
+ raise KeyError(f"Couldn't find the key {key_a} nor {key_b} in the state dict!")
310
+ else:
311
+ missing_keys.append(key_a)
312
+ missing_keys.append(key_b)
313
+ unexpected_keys.extend(state_dict.keys())
314
+ gamma = None
315
+
316
+ if gamma is not None:
317
+ self.grandma.data.copy_(gamma)
318
+
319
+ # return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
320
+
321
+
322
+ class Block(nn.Module):
323
+ def __init__(
324
+ self,
325
+ dim: int,
326
+ num_heads: int,
327
+ mlp_ratio: float = 4.0,
328
+ qkv_bias: bool = False,
329
+ proj_bias: bool = True,
330
+ ffn_bias: bool = True,
331
+ drop: float = 0.0,
332
+ attn_drop: float = 0.0,
333
+ init_values=None,
334
+ drop_path: float = 0.0,
335
+ act_layer: Callable[..., nn.Module] = nn.GELU,
336
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
337
+ attn_class: Callable[..., nn.Module] = Attention,
338
+ ffn_layer: Callable[..., nn.Module] = Mlp,
339
+ ) -> None:
340
+ super().__init__()
341
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
342
+ self.norm1 = norm_layer(dim)
343
+ self.attn = attn_class(
344
+ dim,
345
+ num_heads=num_heads,
346
+ qkv_bias=qkv_bias,
347
+ proj_bias=proj_bias,
348
+ attn_drop=attn_drop,
349
+ proj_drop=drop,
350
+ )
351
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
352
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
353
+
354
+ self.norm2 = norm_layer(dim)
355
+ mlp_hidden_dim = int(dim * mlp_ratio)
356
+ self.mlp = ffn_layer(
357
+ in_features=dim,
358
+ hidden_features=mlp_hidden_dim,
359
+ act_layer=act_layer,
360
+ drop=drop,
361
+ bias=ffn_bias,
362
+ )
363
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
364
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
365
+
366
+ self.sample_drop_ratio = drop_path
367
+
368
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
369
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
370
+ return self.ls1(self.attn(self.norm1(x)))
371
+
372
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
373
+ return self.ls2(self.mlp(self.norm2(x)))
374
+
375
+ if self.training and self.sample_drop_ratio > 0.1:
376
+ # the overhead is compensated only for a drop path rate larger than 0.1
377
+ x = drop_add_residual_stochastic_depth(
378
+ x,
379
+ residual_func=attn_residual_func,
380
+ sample_drop_ratio=self.sample_drop_ratio,
381
+ )
382
+ x = drop_add_residual_stochastic_depth(
383
+ x,
384
+ residual_func=ffn_residual_func,
385
+ sample_drop_ratio=self.sample_drop_ratio,
386
+ )
387
+ elif self.training and self.sample_drop_ratio > 0.0:
388
+ x = x + self.drop_path1(attn_residual_func(x))
389
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
390
+ else:
391
+ x = x + attn_residual_func(x)
392
+ x = x + ffn_residual_func(x)
393
+ return x
394
+
395
+
396
+ class NestedTensorBlock(Block):
397
+ def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
398
+ """
399
+ x_list contains a list of tensors to nest together and run
400
+ """
401
+ assert isinstance(self.attn, MemEffAttention)
402
+
403
+ if self.training and self.sample_drop_ratio > 0.0:
404
+
405
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
406
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
407
+
408
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
409
+ return self.mlp(self.norm2(x))
410
+
411
+ x_list = drop_add_residual_stochastic_depth_list(
412
+ x_list,
413
+ residual_func=attn_residual_func,
414
+ sample_drop_ratio=self.sample_drop_ratio,
415
+ scaling_vector=self.ls1.grandma if isinstance(self.ls1, LayerScale) else None,
416
+ )
417
+ x_list = drop_add_residual_stochastic_depth_list(
418
+ x_list,
419
+ residual_func=ffn_residual_func,
420
+ sample_drop_ratio=self.sample_drop_ratio,
421
+ scaling_vector=self.ls2.grandma if isinstance(self.ls2, LayerScale) else None,
422
+ )
423
+ return x_list
424
+ else:
425
+
426
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
427
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
428
+
429
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
430
+ return self.ls2(self.mlp(self.norm2(x)))
431
+
432
+ attn_bias, x = get_attn_bias_and_cat(x_list)
433
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
434
+ x = x + ffn_residual_func(x)
435
+ return attn_bias.split(x)
436
+
437
+ def forward(self, x_or_x_list):
438
+ if isinstance(x_or_x_list, torch.Tensor):
439
+ return super().forward(x_or_x_list)
440
+ elif isinstance(x_or_x_list, list):
441
+ if not XFORMERS_AVAILABLE:
442
+ raise AssertionError("xFormers is required for using nested tensors")
443
+ return self.forward_nested(x_or_x_list)
444
+ else:
445
+ raise AssertionError
446
+
447
+
448
+ def drop_add_residual_stochastic_depth(
449
+ x: torch.Tensor,
450
+ residual_func: Callable[[torch.Tensor], torch.Tensor],
451
+ sample_drop_ratio: float = 0.0,
452
+ ) -> torch.Tensor:
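+ # Sample-level stochastic depth (the DINOv2 variant): the residual branch is applied to a random
+ # subset of the batch, rescaled by b / subset_size to keep the expectation unchanged, and scattered
+ # back into the full batch with index_add.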
453
+ # 1) extract subset using permutation
454
+ b, n, d = x.shape
455
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
456
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
457
+ x_subset = x[brange]
458
+
459
+ # 2) apply residual_func to get residual
460
+ residual = residual_func(x_subset)
461
+
462
+ x_flat = x.flatten(1)
463
+ residual = residual.flatten(1)
464
+
465
+ residual_scale_factor = b / sample_subset_size
466
+
467
+ # 3) add the residual
468
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
469
+ return x_plus_residual.view_as(x)
470
+
471
+
472
+ def get_branges_scales(x, sample_drop_ratio=0.0):
473
+ b, n, d = x.shape
474
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
475
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
476
+ residual_scale_factor = b / sample_subset_size
477
+ return brange, residual_scale_factor
478
+
479
+
480
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
481
+ if scaling_vector is None:
482
+ x_flat = x.flatten(1)
483
+ residual = residual.flatten(1)
484
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
485
+ else:
486
+ x_plus_residual = scaled_index_add(
487
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
488
+ )
489
+ return x_plus_residual
490
+
491
+
492
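+ # Cache of xFormers BlockDiagonalMask objects, keyed by the (batch_size, seq_len) shapes of the
+ # nested inputs; the same nesting pattern recurs across steps, so each mask is built only once.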
+ attn_bias_cache: Dict[Tuple, Any] = {}
493
+
494
+
495
+ def get_attn_bias_and_cat(x_list, branges=None):
496
+ """
497
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
498
+ """
499
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
500
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
501
+ if all_shapes not in attn_bias_cache.keys():
502
+ seqlens = []
503
+ for b, x in zip(batch_sizes, x_list):
504
+ for _ in range(b):
505
+ seqlens.append(x.shape[1])
506
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
507
+ attn_bias._batch_sizes = batch_sizes
508
+ attn_bias_cache[all_shapes] = attn_bias
509
+
510
+ if branges is not None:
511
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
512
+ else:
513
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
514
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
515
+
516
+ return attn_bias_cache[all_shapes], cat_tensors
517
+
518
+
519
+ def drop_add_residual_stochastic_depth_list(
520
+ x_list: List[torch.Tensor],
521
+ residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
522
+ sample_drop_ratio: float = 0.0,
523
+ scaling_vector=None,
524
+ ) -> torch.Tensor:
525
+ # 1) generate random set of indices for dropping samples in the batch
526
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
527
+ branges = [s[0] for s in branges_scales]
528
+ residual_scale_factors = [s[1] for s in branges_scales]
529
+
530
+ # 2) get attention bias and index+concat the tensors
531
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
532
+
533
+ # 3) apply residual_func to get residual, and split the result
534
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
535
+
536
+ outputs = []
537
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
538
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
539
+ return outputs
540
+
541
+
542
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
543
+ if not depth_first and include_root:
544
+ fn(module=module, name=name)
545
+ for child_name, child_module in module.named_children():
546
+ child_name = ".".join((name, child_name)) if name else child_name
547
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
548
+ if depth_first and include_root:
549
+ fn(module=module, name=name)
550
+ return module
551
+
552
+
553
+ class BlockChunk(nn.ModuleList):
554
+ def forward(self, x):
555
+ for b in self:
556
+ x = b(x)
557
+ return x
558
+
559
+
560
+ class DinoVisionTransformer(nn.Module):
561
+ def __init__(
562
+ self,
563
+ img_size=224,
564
+ patch_size=16,
565
+ in_chans=3,
566
+ embed_dim=768,
567
+ depth=12,
568
+ num_heads=12,
569
+ mlp_ratio=4.0,
570
+ qkv_bias=True,
571
+ ffn_bias=True,
572
+ proj_bias=True,
573
+ drop_path_rate=0.0,
574
+ drop_path_uniform=False,
575
+ init_values=None, # for layerscale: None or 0 => no layerscale
576
+ embed_layer=PatchEmbed,
577
+ act_layer=nn.GELU,
578
+ block_fn=Block,
579
+ ffn_layer="mlp",
580
+ block_chunks=1,
581
+ num_register_tokens=0,
582
+ interpolate_antialias=False,
583
+ interpolate_offset=0.1,
584
+ ):
585
+ """
586
+ Args:
587
+ img_size (int, tuple): input image size
588
+ patch_size (int, tuple): patch size
589
+ in_chans (int): number of input channels
590
+ embed_dim (int): embedding dimension
591
+ depth (int): depth of transformer
592
+ num_heads (int): number of attention heads
593
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
594
+ qkv_bias (bool): enable bias for qkv if True
595
+ proj_bias (bool): enable bias for proj in attn if True
596
+ ffn_bias (bool): enable bias for ffn if True
597
+ drop_path_rate (float): stochastic depth rate
598
+ drop_path_uniform (bool): apply uniform drop rate across blocks
599
+ weight_init (str): weight init scheme
600
+ init_values (float): layer-scale init values
601
+ embed_layer (nn.Module): patch embedding layer
602
+ act_layer (nn.Module): MLP activation layer
603
+ block_fn (nn.Module): transformer block class
604
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
605
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
606
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
607
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
608
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
609
+ """
610
+ super().__init__()
611
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
612
+
613
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
614
+ self.num_tokens = 1
615
+ self.n_blocks = depth
616
+ self.num_heads = num_heads
617
+ self.patch_size = patch_size
618
+ self.num_register_tokens = num_register_tokens
619
+ self.interpolate_antialias = interpolate_antialias
620
+ self.interpolate_offset = interpolate_offset
621
+
622
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
623
+ num_patches = self.patch_embed.num_patches
624
+
625
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
626
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
627
+ assert num_register_tokens >= 0
628
+ self.register_tokens = (
629
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
630
+ )
631
+
632
+ if drop_path_uniform is True:
633
+ dpr = [drop_path_rate] * depth
634
+ else:
635
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
636
+
637
+ if ffn_layer == "mlp":
638
+ ffn_layer = Mlp
639
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
640
+ ffn_layer = SwiGLUFFNFused
641
+ elif ffn_layer == "identity":
642
+ def f(*args, **kwargs):
643
+ return nn.Identity()
644
+
645
+ ffn_layer = f
646
+ else:
647
+ raise NotImplementedError
648
+
649
+ blocks_list = [
650
+ block_fn(
651
+ dim=embed_dim,
652
+ num_heads=num_heads,
653
+ mlp_ratio=mlp_ratio,
654
+ qkv_bias=qkv_bias,
655
+ proj_bias=proj_bias,
656
+ ffn_bias=ffn_bias,
657
+ drop_path=dpr[i],
658
+ norm_layer=norm_layer,
659
+ act_layer=act_layer,
660
+ ffn_layer=ffn_layer,
661
+ init_values=init_values,
662
+ )
663
+ for i in range(depth)
664
+ ]
665
+ if block_chunks > 0:
666
+ self.chunked_blocks = True
667
+ chunked_blocks = []
668
+ chunksize = depth // block_chunks
669
+ for i in range(0, depth, chunksize):
670
+ # this is to keep the block index consistent if we chunk the block list
671
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
672
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
673
+ else:
674
+ self.chunked_blocks = False
675
+ self.blocks = nn.ModuleList(blocks_list)
676
+
677
+ self.norm = norm_layer(embed_dim)
678
+ self.head = nn.Identity()
679
+
680
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
681
+
682
+ def interpolate_pos_encoding(self, x, w, h):
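+ # Rescales the patch position-embedding grid to the current input resolution with bicubic
+ # interpolation; the class-token embedding is left untouched and re-attached at the front.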
683
+ previous_dtype = x.dtype
684
+ npatch = x.shape[1] - 1
685
+ N = self.pos_embed.shape[1] - 1
686
+ if npatch == N and w == h:
687
+ return self.pos_embed
688
+ pos_embed = self.pos_embed.float()
689
+ class_pos_embed = pos_embed[:, 0]
690
+ patch_pos_embed = pos_embed[:, 1:]
691
+ dim = x.shape[-1]
692
+ w0 = w // self.patch_size
693
+ h0 = h // self.patch_size
694
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
695
+ assert N == M * M
696
+ kwargs = {}
697
+ if self.interpolate_offset:
698
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
699
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
700
+ sx = float(w0 + self.interpolate_offset) / M
701
+ sy = float(h0 + self.interpolate_offset) / M
702
+ kwargs["scale_factor"] = (sx, sy)
703
+ else:
704
+ # Simply specify an output size instead of a scale factor
705
+ kwargs["size"] = (w0, h0)
706
+ patch_pos_embed = nn.functional.interpolate(
707
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
708
+ mode="bicubic",
709
+ antialias=self.interpolate_antialias,
710
+ **kwargs,
711
+ )
712
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
713
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
714
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
715
+
716
+ def prepare_tokens_with_masks(self, x, masks=None):
717
+ B, nc, w, h = x.shape
718
+ x = self.patch_embed(x)
719
+ if masks is not None:
720
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
721
+
722
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
723
+ x = x + self.interpolate_pos_encoding(x, w, h)
724
+
725
+ if self.register_tokens is not None:
726
+ x = torch.cat(
727
+ (
728
+ x[:, :1],
729
+ self.register_tokens.expand(x.shape[0], -1, -1),
730
+ x[:, 1:],
731
+ ),
732
+ dim=1,
733
+ )
734
+
735
+ return x
736
+
737
+ def forward_features_list(self, x_list, masks_list):
738
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
739
+ for blk in self.blocks:
740
+ x = blk(x)
741
+
742
+ all_x = x
743
+ output = []
744
+ for x, masks in zip(all_x, masks_list):
745
+ x_norm = self.norm(x)
746
+ output.append(
747
+ {
748
+ "x_norm_clstoken": x_norm[:, 0],
749
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
750
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
751
+ "x_prenorm": x,
752
+ "masks": masks,
753
+ }
754
+ )
755
+ return output
756
+
757
+ def forward_features(self, x, masks=None):
758
+ if isinstance(x, list):
759
+ return self.forward_features_list(x, masks)
760
+
761
+ x = self.prepare_tokens_with_masks(x, masks)
762
+
763
+ for blk in self.blocks:
764
+ x = blk(x)
765
+
766
+ x_norm = self.norm(x)
767
+ return {
768
+ "x_norm_clstoken": x_norm[:, 0],
769
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
770
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
771
+ "x_prenorm": x,
772
+ "masks": masks,
773
+ }
774
+
775
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
776
+ x = self.prepare_tokens_with_masks(x)
777
+ # If n is an int, take the n last blocks. If it's a list, take them
778
+ output, total_block_len = [], len(self.blocks)
779
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
780
+ for i, blk in enumerate(self.blocks):
781
+ x = blk(x)
782
+ if i in blocks_to_take:
783
+ output.append(x)
784
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
785
+ return output
786
+
787
+ def _get_intermediate_layers_chunked(self, x, n=1):
788
+ x = self.prepare_tokens_with_masks(x)
789
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
790
+ # If n is an int, take the n last blocks. If it's a list, take them
791
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
792
+ for block_chunk in self.blocks:
793
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
794
+ x = blk(x)
795
+ if i in blocks_to_take:
796
+ output.append(x)
797
+ i += 1
798
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
799
+ return output
800
+
801
+ def get_intermediate_layers(
802
+ self,
803
+ x: torch.Tensor,
804
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
805
+ reshape: bool = False,
806
+ return_class_token: bool = False,
807
+ norm=True,
808
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
809
+ if self.chunked_blocks:
810
+ outputs = self._get_intermediate_layers_chunked(x, n)
811
+ else:
812
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
813
+ if norm:
814
+ outputs = [self.norm(out) for out in outputs]
815
+ class_tokens = [out[:, 0] for out in outputs]
816
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
817
+ if reshape:
818
+ B, _, w, h = x.shape
819
+ outputs = [
820
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
821
+ for out in outputs
822
+ ]
823
+ if return_class_token:
824
+ return tuple(zip(outputs, class_tokens))
825
+ return tuple(outputs)
826
+
827
+ def forward(self, *args, is_training=False, **kwargs):
828
+ ret = self.forward_features(*args, **kwargs)
829
+ if is_training:
830
+ return ret
831
+ else:
832
+ return self.head(ret["x_norm_clstoken"])
833
+
834
+
835
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
836
+ model = DinoVisionTransformer(
837
+ patch_size=patch_size,
838
+ embed_dim=384,
839
+ depth=12,
840
+ num_heads=6,
841
+ mlp_ratio=4,
842
+ block_fn=partial(Block, attn_class=MemEffAttention),
843
+ num_register_tokens=num_register_tokens,
844
+ **kwargs,
845
+ )
846
+ return model
847
+
848
+
849
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
850
+ model = DinoVisionTransformer(
851
+ patch_size=patch_size,
852
+ embed_dim=768,
853
+ depth=12,
854
+ num_heads=12,
855
+ mlp_ratio=4,
856
+ block_fn=partial(Block, attn_class=MemEffAttention),
857
+ num_register_tokens=num_register_tokens,
858
+ **kwargs,
859
+ )
860
+ return model
861
+
862
+
863
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
864
+ model = DinoVisionTransformer(
865
+ patch_size=patch_size,
866
+ embed_dim=1024,
867
+ depth=24,
868
+ num_heads=16,
869
+ mlp_ratio=4,
870
+ block_fn=partial(Block, attn_class=MemEffAttention),
871
+ num_register_tokens=num_register_tokens,
872
+ **kwargs,
873
+ )
874
+ return model
875
+
876
+
877
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
878
+ """
879
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
880
+ """
881
+ model = DinoVisionTransformer(
882
+ patch_size=patch_size,
883
+ embed_dim=1536,
884
+ depth=40,
885
+ num_heads=24,
886
+ mlp_ratio=4,
887
+ block_fn=partial(Block, attn_class=MemEffAttention),
888
+ num_register_tokens=num_register_tokens,
889
+ **kwargs,
890
+ )
891
+ return model
892
+
893
+
894
+ class Weights(Enum):
895
+ LVD142M = "LVD142M"
896
+
897
+
898
+ def _make_dinov2_model(
899
+ *,
900
+ arch_name: str = "vit_large",
901
+ img_size: int = 518,
902
+ patch_size: int = 14,
903
+ init_values: float = 1.0,
904
+ ffn_layer: str = "mlp",
905
+ block_chunks: int = 0,
906
+ num_register_tokens: int = 0,
907
+ interpolate_antialias: bool = False,
908
+ interpolate_offset: float = 0.1,
909
+ weights: Union[Weights, str] = Weights.LVD142M,
910
+ **kwargs,
911
+ ):
912
+ if isinstance(weights, str):
913
+ try:
914
+ weights = Weights[weights]
915
+ except KeyError:
916
+ raise AssertionError(f"Unsupported weights: {weights}")
917
+
918
+ vit_kwargs = dict(
919
+ img_size=img_size,
920
+ patch_size=patch_size,
921
+ init_values=init_values,
922
+ ffn_layer=ffn_layer,
923
+ block_chunks=block_chunks,
924
+ num_register_tokens=num_register_tokens,
925
+ interpolate_antialias=interpolate_antialias,
926
+ interpolate_offset=interpolate_offset,
927
+ )
928
+ vit_kwargs.update(**kwargs)
929
+ model = sys.modules[__name__].__dict__[arch_name](**vit_kwargs)
930
+
931
+ return model
932
+
933
+
934
+ def dinov2_vits14(**kwargs):
935
+ """
936
+ DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
937
+ """
938
+ return _make_dinov2_model(arch_name="vit_small", **kwargs)
939
+
940
+
941
+ def dinov2_vitb14(**kwargs):
942
+ """
943
+ DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
944
+ """
945
+ return _make_dinov2_model(arch_name="vit_base", **kwargs)
946
+
947
+
948
+ def dinov2_vitl14(**kwargs):
949
+ """
950
+ DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
951
+ """
952
+ return _make_dinov2_model(arch_name="vit_large", **kwargs)
953
+
954
+
955
+ def dinov2_vitg14(**kwargs):
956
+ """
957
+ DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
958
+ """
959
+ return _make_dinov2_model(
960
+ arch_name="vit_giant2",
961
+ ffn_layer="swiglufused",
962
+ **kwargs,
963
+ )
964
+
965
+
966
+ def dinov2_vits14_reg(**kwargs):
967
+ """
968
+ DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
969
+ """
970
+ return _make_dinov2_model(
971
+ arch_name="vit_small",
972
+ num_register_tokens=4,
973
+ interpolate_antialias=True,
974
+ interpolate_offset=0.0,
975
+ **kwargs,
976
+ )
977
+
978
+
979
+ def dinov2_vitb14_reg(**kwargs):
980
+ """
981
+ DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
982
+ """
983
+ return _make_dinov2_model(
984
+ arch_name="vit_base",
985
+ num_register_tokens=4,
986
+ interpolate_antialias=True,
987
+ interpolate_offset=0.0,
988
+ **kwargs,
989
+ )
990
+
991
+
992
+ def dinov2_vitl14_reg(**kwargs):
993
+ """
994
+ DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
995
+ """
996
+ return _make_dinov2_model(
997
+ arch_name="vit_large",
998
+ num_register_tokens=4,
999
+ interpolate_antialias=True,
1000
+ interpolate_offset=0.0,
1001
+ **kwargs,
1002
+ )
1003
+
1004
+
1005
+ def dinov2_vitg14_reg(**kwargs):
1006
+ """
1007
+ DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
1008
+ """
1009
+ return _make_dinov2_model(
1010
+ arch_name="vit_giant2",
1011
+ ffn_layer="swiglufused",
1012
+ num_register_tokens=4,
1013
+ interpolate_antialias=True,
1014
+ interpolate_offset=0.0,
1015
+ **kwargs,
1016
+ )
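+ # Usage sketch (illustrative only; assumes the companion modules referenced above, e.g. Mlp,
+ # PatchEmbed and MemEffAttention, are importable from this package):
+ #   model = dinov2_vitl14_reg()                                  # ViT-L/14 with 4 register tokens
+ #   cls_feats = model(torch.randn(1, 3, 518, 518))               # (1, 1024) class-token features
+ #   out = model.forward_features(torch.randn(1, 3, 518, 518))
+ #   patch_feats = out["x_norm_patchtokens"]                      # (1, 1369, 1024) patch features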
src/models/radiov3/dual_hybrid_vit.py ADDED
@@ -0,0 +1,213 @@
1
+ from logging import getLogger
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from timm.models import register_model
9
+ from timm.models import vision_transformer as tvit
10
+ from timm.models import convnext as tconv
11
+
12
+ from einops import rearrange
13
+
14
+ from . import extra_timm_models as et
15
+
16
+
17
+ class Fuser(nn.Module):
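+ # Cross-branch fusion block: the source features are passed through a small conv MLP and added to
+ # the target features, optionally gated by a sigmoid; token inputs (B, N, C) are reshaped to 2D
+ # maps, and any prefix (cls/register) tokens on the target are carried through unchanged.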
18
+ def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
19
+ super().__init__()
20
+ self.gated = gated
21
+
22
+ mid_dim = max(src_dim, tgt_dim) * 2
23
+
24
+ self.fwd = nn.Sequential(
25
+ nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
26
+ nn.GELU(),
27
+ nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),
28
+ )
29
+
30
+ def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
31
+ if src.ndim == 3:
32
+ shape = tgt.shape[-2:]
33
+ else:
34
+ shape = src.shape[-2:]
35
+
36
+ nd = shape[0] * shape[1]
37
+
38
+ if src.ndim == 3:
39
+ src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)
40
+
41
+ if tgt.ndim == 3:
42
+ tgt_pre = tgt[:, :-nd]
43
+ tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)
44
+ else:
45
+ tgt_pre = None
46
+
47
+ pred = self.fwd(src)
48
+
49
+ if self.gated:
50
+ g, pred = torch.chunk(pred, 2, dim=1)
51
+
52
+ g = F.sigmoid(g)
53
+
54
+ pred = g * pred
55
+
56
+ tgt = tgt + pred
57
+
58
+ if tgt_pre is not None:
59
+ tgt = rearrange(tgt, 'b c h w -> b (h w) c')
60
+ tgt = torch.cat([tgt_pre, tgt], dim=1)
61
+
62
+ return tgt
63
+
64
+
65
+ class AttnDownsample(nn.Module):
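+ # Attention-based pooling: a single learned query attends over each window_size x window_size
+ # group of tokens, reducing the spatial resolution by window_size while leaving any prefix
+ # (cls/register) tokens untouched.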
66
+ def __init__(self, dim: int, window_size: int, num_heads: int = 16):
67
+ super().__init__()
68
+ self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
69
+ self.kv = nn.Linear(dim, dim * 2)
70
+ self.proj = nn.Linear(dim, dim)
71
+ self.window_size = window_size
72
+ self.num_heads = num_heads
73
+ self.head_dim = dim // num_heads
74
+ self.scale = self.head_dim ** -0.5
75
+
76
+ def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
77
+ ntok = twod_shape[0] * twod_shape[1]
78
+ x_pre = x[:, :-ntok]
79
+
80
+ B = x.shape[0]
81
+ ds_hw = tuple(s // self.window_size for s in twod_shape)
82
+
83
+ x_spat = rearrange(
84
+ x[:, -ntok:],
85
+ 'b (h d1 w d2) c -> (b h w) (d1 d2) c',
86
+ h=ds_hw[0], w=ds_hw[1],
87
+ d1=self.window_size, d2=self.window_size,
88
+ )
89
+
90
+ B, N, C = x_spat.shape
91
+
92
+ k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
93
+
94
+ q = (self.q * self.scale).expand(B, -1, -1, -1)
95
+ attn = q @ k.transpose(-2, -1)
96
+ attn = F.softmax(attn, dim=-1)
97
+ x = attn @ v
98
+
99
+ x = x.transpose(1, 2).reshape(B, C)
100
+ x = self.proj(x)
101
+
102
+ x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])
103
+
104
+ x = torch.cat([x_pre, x], dim=1)
105
+ return x
106
+
107
+
108
+ class HybridModel(nn.Module):
109
+ def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
110
+ concatenate: bool = False, **kwargs):
111
+ super().__init__()
112
+ self.conv = conv
113
+ self.vit = vit
114
+ self.concatenate = concatenate
115
+
116
+ conv.stages = nn.ModuleList(conv.stages)
117
+ vit.blocks = nn.ModuleList(vit.blocks)
118
+
119
+ self._half_vit_idx = len(vit.blocks) // 2 + 1
120
+
121
+ self._half_conv_idx = None
122
+ x = torch.empty(1, 3, 256, 256)
123
+ x = self.conv.stem(x)
124
+ for i in range(len(conv.stages)):
125
+ x = conv.stages[i](x)
126
+ if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
127
+ self._half_conv_idx = i + 1
128
+ half_conv_dim = x.shape[1]
129
+ final_conv_dim = x.shape[1]
130
+
131
+ self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
132
+ self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
133
+ self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)
134
+
135
+ embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
136
+ if not concatenate:
137
+ self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
138
+ self.final_block = tvit.Block(embed_dim, num_heads=16)
139
+
140
+ self.embed_dim = embed_dim
141
+
142
+ @property
143
+ def patch_size(self):
144
+ return 32
145
+
146
+ @property
147
+ def no_fsdp_wrap_types(self):
148
+ return {tvit.VisionTransformer, tconv.ConvNeXt}
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ return self.forward_features(x)
152
+
153
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
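+ # Runs the first halves of the ViT and the ConvNeXt in parallel, exchanges features between the
+ # two branches at the midpoint, downsamples the ViT tokens to the conv resolution, runs the second
+ # halves, then either concatenates or gate-fuses the outputs before splitting summary and patch features.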
154
+ y_vit = self.vit.patch_generator(x)
155
+
156
+ for i in range(self._half_vit_idx):
157
+ y_vit = self.vit.blocks[i](y_vit)
158
+
159
+ y_conv = self.conv.stem(x)
160
+ for i in range(self._half_conv_idx):
161
+ y_conv = self.conv.stages[i](y_conv)
162
+
163
+ y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)
164
+
165
+ y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])
166
+
167
+ for i in range(self._half_vit_idx, len(self.vit.blocks)):
168
+ y_vit = self.vit.blocks[i](y_vit)
169
+
170
+ for i in range(self._half_conv_idx, len(self.conv.stages)):
171
+ y_conv = self.conv.stages[i](y_conv)
172
+
173
+ if self.concatenate:
174
+ y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
175
+ # Average pool across the board, and replicate for each cls/register token
176
+ conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
177
+ y_conv = torch.cat([conv_summary, y_conv], dim=1)
178
+ y = torch.cat([y_vit, y_conv], dim=2)
179
+ else:
180
+ y = self.final_fuse(y_conv, y_vit)
181
+ y = self.final_block(y)
182
+
183
+ summary = y[:, :self.vit.patch_generator.num_cls_tokens]
184
+ features = y[:, self.vit.patch_generator.num_cls_patches:]
185
+
186
+ return summary, features
187
+
188
+
189
+ @register_model
190
+ def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
191
+ cfg = dict(num_classes=0, **kwargs)
192
+ conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)
193
+ vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
194
+
195
+ return HybridModel(vit, conv, pretrained, concatenate=concatenate)
196
+
197
+
198
+ @register_model
199
+ def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
200
+ cfg = dict(num_classes=0, **kwargs)
201
+ conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)
202
+ vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
203
+
204
+ return HybridModel(vit, conv, pretrained, concatenate=concatenate)
205
+
206
+
207
+ @register_model
208
+ def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
209
+ cfg = dict(num_classes=0, **kwargs)
210
+ conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)
211
+ vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
212
+
213
+ return HybridModel(vit, conv, pretrained, concatenate=concatenate)
src/models/radiov3/enable_cpe_support.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from typing import List, Optional, Set, Tuple, Union
10
+ from types import MethodType
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from timm.models import VisionTransformer, checkpoint_seq
16
+
17
+ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
18
+
19
+ from .extra_models import DinoWrapper
20
+ from .vit_patch_generator import ViTPatchGenerator
21
+ from .forward_intermediates import forward_intermediates
22
+ from .dual_hybrid_vit import HybridModel
23
+
24
+
25
+ def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
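+ # Replacement forward_features: tokens are produced by the CPE ViTPatchGenerator (which handles
+ # variable input resolutions) instead of the original patch_embed / pos_embed / pos_drop path.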
26
+ x = self.patch_generator(x)
27
+ if getattr(self, 'grad_checkpointing', False) and not torch.jit.is_scripting():
28
+ x = checkpoint_seq(self.blocks, x)
29
+ else:
30
+ x = self.blocks(x)
31
+ x = self.norm(x)
32
+ return x
33
+
34
+
35
+ def _take_indices(
36
+ num_blocks: int,
37
+ n: Optional[Union[int, List[int], Tuple[int]]],
38
+ ) -> Tuple[Set[int], int]:
39
+ if isinstance(n, int):
40
+ assert n >= 0
41
+ take_indices = {x for x in range(num_blocks - n, num_blocks)}
42
+ else:
43
+ take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
44
+ return take_indices, max(take_indices)
45
+
46
+
47
+ def _forward_intermediates_cpe(
48
+ self,
49
+ x: torch.Tensor,
50
+ norm: bool = False,
51
+ **kwargs,
52
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
53
+ return forward_intermediates(
54
+ self,
55
+ patch_extractor=self.patch_generator,
56
+ num_summary_tokens=self.patch_generator.num_skip,
57
+ num_cls_tokens=self.patch_generator.num_cls_tokens,
58
+ norm=self.norm if norm else lambda y: y,
59
+ x=x,
60
+ **kwargs,
61
+ )
62
+
63
+
64
+ def _forward_cpe_dinov2(self: DinoWrapper, x: torch.Tensor) -> torch.Tensor:
65
+ y = _forward_cpe(self.inner, x)
66
+
67
+ return y[:, 0], y[:, self.num_summary_tokens:]
68
+
69
+
70
+ def _forward_intermediates_cpe_dinov2(self: DinoWrapper, *args, **kwargs):
71
+ return _forward_intermediates_cpe(self.inner, *args, **kwargs)
72
+
73
+
74
+ def _enable_cpe_for_timm_vit(model: VisionTransformer,
75
+ max_img_size: Union[int, Tuple[int, int]] = 1024,
76
+ num_cls_tokens: int = 1,
77
+ pos_dropout: float = 0.1,
78
+ register_multiple: Optional[int] = None,
79
+ num_registers: Optional[int] = None,
80
+ ):
81
+ if not isinstance(model, VisionTransformer):
82
+ raise ValueError("CPE is only supported for VisionTransformer models!")
83
+
84
+ patch_size = model.patch_embed.patch_size[0]
85
+ embed_dim = model.embed_dim
86
+ input_dims = model.patch_embed.img_size
87
+ normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
88
+ cls_token = model.cls_token is not None
89
+
90
+ max_img_size = int(round(max_img_size / patch_size) * patch_size)
91
+
92
+ patch_generator = ViTPatchGenerator(
93
+ patch_size=patch_size,
94
+ embed_dim=embed_dim,
95
+ input_dims=input_dims,
96
+ normalize_patches=normalize_patches,
97
+ cls_token=cls_token,
98
+ max_input_dims=max_img_size,
99
+ pos_dropout=pos_dropout,
100
+ num_cls_tokens=num_cls_tokens,
101
+ register_multiple=register_multiple,
102
+ num_registers=num_registers,
103
+ )
104
+
105
+ model.patch_generator = patch_generator
106
+ model.patch_embed = None
107
+ model.cls_token = None
108
+ model.pos_embed = None
109
+ model.pos_drop = None
110
+ model.patch_size = patch_size
111
+ model.num_cls_tokens = num_cls_tokens
112
+ model.num_registers = patch_generator.num_registers
113
+
114
+ model.forward_features = MethodType(_forward_cpe, model)
115
+ model.forward_intermediates = MethodType(_forward_intermediates_cpe, model)
116
+
117
+
118
+ def _enable_cpe_for_dv2_reg_vit(model: DinoWrapper,
119
+ max_img_size: Union[int, Tuple[int, int]] = 1024,
120
+ num_cls_tokens: int = 1,
121
+ pos_dropout: float = 0.1,
122
+ register_multiple: Optional[int] = None,
123
+ num_registers: Optional[int] = None,
124
+ ):
125
+ patch_size = model.patch_size
126
+ embed_dim = model.embed_dim
127
+ input_dims = model.inner.patch_embed.patches_resolution
128
+ normalize_patches = not isinstance(model.inner.patch_embed.norm, nn.Identity)
129
+ cls_token = True
130
+
131
+ max_img_size = int(round(max_img_size / patch_size) * patch_size)
132
+
133
+ patch_generator = ViTPatchGenerator(
134
+ patch_size=patch_size,
135
+ embed_dim=embed_dim,
136
+ input_dims=input_dims,
137
+ normalize_patches=normalize_patches,
138
+ cls_token=cls_token,
139
+ max_input_dims=max_img_size,
140
+ pos_dropout=pos_dropout,
141
+ num_cls_tokens=num_cls_tokens,
142
+ register_multiple=register_multiple,
143
+ num_registers=num_registers,
144
+ patch_bias=True,
145
+ )
146
+
147
+ inner = model.inner
148
+ inner.patch_generator = patch_generator
149
+ inner.patch_embed = None
150
+ inner.cls_token = None
151
+ inner.pos_embed = None
152
+ inner.register_tokens = None
153
+ inner.patch_size = patch_size
154
+
155
+ model.forward_features = MethodType(_forward_cpe_dinov2, model)
156
+ model.forward_intermediates = MethodType(_forward_intermediates_cpe_dinov2, model)
157
+
158
+
159
+ def enable_cpe(model: nn.Module,
160
+ *args,
161
+ **kwargs,
162
+ ):
163
+ if isinstance(model, VisionTransformer):
164
+ _enable_cpe_for_timm_vit(model, *args, **kwargs)
165
+ elif isinstance(model, DinoWrapper):
166
+ _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)
167
+ elif isinstance(model, HybridModel):
168
+ _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)
169
+ else:
170
+ raise ValueError(f'CPE not supported for this model type: {type(model)}')
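+ # Usage sketch (illustrative only; assumes a standard timm ViT):
+ #   model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
+ #   enable_cpe(model, max_img_size=1024, num_cls_tokens=1, num_registers=4)
+ #   feats = model.forward_features(torch.randn(1, 3, 512, 768))  # non-square sizes now supported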
src/models/radiov3/enable_spectral_reparam.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from logging import getLogger
10
+ import math
11
+ import os
12
+ from typing import Dict, List, Optional, Union, Tuple
13
+ from types import MethodType
14
+
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.utils import parametrize
19
+ from torch.nn.utils.parametrizations import _SpectralNorm
20
+
21
+ from timm.models.vision_transformer import Attention, Mlp
22
+
23
+ _EPS = 1e-5
24
+
25
+
26
+ class _SNReweight(_SpectralNorm):
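+ # Spectral reparameterization of a weight: the effective weight is W * scale / sigma_max(W), where
+ # sigma_max is estimated by power iteration and, in version 2, scale = softplus(s) + alpha keeps
+ # the learned target norm strictly positive.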
27
+ def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
28
+ super().__init__(weight, *args, **kwargs)
29
+
30
+ self.alpha = alpha
31
+ self.version = version
32
+ self.register_buffer('_sn_version', torch.tensor(version))
33
+
34
+ if init_norm_to_current:
35
+ # This will set the numerator to match the denominator, which should preserve the original values
36
+ init_scale = self._get_sigma(weight, n_power_iterations=20).item()
37
+ else:
38
+ init_scale = 1.0
39
+
40
+ if version == 1:
41
+ init_value = init_scale
42
+ elif version == 2:
43
+ t = init_scale - alpha
44
+ if t < _EPS:
45
+ getLogger("spectral_reparam").warning(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
46
+ t = _EPS
47
+
48
+ init_value = math.log(math.exp(t) - 1)
49
+ else:
50
+ raise ValueError(f'Unsupported version: {version}')
51
+
52
+ # Make 2D so that weight decay gets applied
53
+ self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))
54
+
55
+ # Re-implementing this because we need to make division by sigma safe
56
+ def _get_sigma(self, weight: torch.Tensor, n_power_iterations: int = None) -> torch.Tensor:
57
+ if not n_power_iterations:
58
+ n_power_iterations = self.n_power_iterations
59
+ if weight.ndim == 1:
60
+ # Faster and more exact path, no need to approximate anything
61
+ sigma = weight.norm()
62
+ else:
63
+ weight_mat = self._reshape_weight_to_matrix(weight)
64
+ if self.training:
65
+ self._power_method(weight_mat, n_power_iterations)
66
+ # See above on why we need to clone
67
+ u = self._u.clone(memory_format=torch.contiguous_format)
68
+ v = self._v.clone(memory_format=torch.contiguous_format)
69
+ # The proper way of computing this should be through F.bilinear, but
70
+ # it seems to have some efficiency issues:
71
+ # https://github.com/pytorch/pytorch/issues/58093
72
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
73
+
74
+ return sigma + self.eps
75
+
76
+ def forward(self, weight: torch.Tensor, *args, **kwargs):
77
+ dtype = weight.dtype
78
+ sigma = self._get_sigma(weight, *args, **kwargs)
79
+
80
+ if self.version == 1:
81
+ scale = self.scale
82
+ elif self.version == 2:
83
+ scale = F.softplus(self.scale) + self.alpha
84
+ else:
85
+ raise ValueError(f'Unsupported version: {self.version}')
86
+
87
+ scale = scale.float() / sigma.float()
88
+
89
+ y = weight * scale
90
+
91
+ if dtype in (torch.float16, torch.bfloat16):
92
+ y = y.to(dtype)
93
+ return y
94
+
95
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
96
+ version_key = f'{prefix}_sn_version'
97
+ if version_key not in state_dict:
98
+ self.version = 1
99
+ state_dict[version_key] = torch.tensor(1)
100
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
101
+
102
+
103
+ class _ChunkedSNReweight(nn.Module):
104
+ def __init__(self, weight: torch.Tensor, num_chunks: int, *args, init_norm_to_current: bool = False, **kwargs):
105
+ super().__init__()
106
+
107
+ self.num_chunks = num_chunks
108
+ parts = weight.split(weight.shape[0] // num_chunks, dim=0)
109
+
110
+ self.parts = nn.ModuleList([
111
+ _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs)
112
+ for p in parts
113
+ ])
114
+
115
+ def forward(self, weight: torch.Tensor, *args, **kwargs):
116
+ parts = weight.split(weight.shape[0] // self.num_chunks, dim=0)
117
+
118
+ parts = [
119
+ fn(p)
120
+ for fn, p in zip(self.parts, parts)
121
+ ]
122
+
123
+ return torch.cat(parts, dim=0)
124
+
125
+
126
+ class _AttnSNReweight(_ChunkedSNReweight):
127
+ def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
128
+ super().__init__(weight, 3, *args, init_norm_to_current=init_norm_to_current, **kwargs)
129
+
130
+ if not renorm_values:
131
+ self.parts[2] = nn.Identity()
132
+
133
+
134
+ def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
135
+ n_power_iterations: int = 1,
136
+ eps: float = 1e-6,
137
+ init_norm_to_current: bool = False,
138
+ renorm_values: bool = True,
139
+ renorm_mlp: bool = True,
140
+ state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
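+ # Walks the module tree and registers spectral-norm reparameterizations on attention qkv/proj
+ # weights, fused MLP weights and remaining Linear layers; when given, state_dict_guidance restricts
+ # this to layers that were already parametrized in a reference checkpoint.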
141
+ if isinstance(model, (list, tuple)):
142
+ for i, sub in enumerate(model):
143
+ sub_sd = state_dict_guidance[i] if isinstance(state_dict_guidance, (list, tuple)) else state_dict_guidance
144
+ enable_spectral_reparam(sub, n_power_iterations=n_power_iterations, eps=eps,
145
+ init_norm_to_current=init_norm_to_current, renorm_values=renorm_values,
146
+ renorm_mlp=renorm_mlp, state_dict_guidance=sub_sd)
147
+ return
148
+
149
+ print('Enabling spectral reparametrization')
150
+ args = dict(n_power_iterations=n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current)
151
+ visited_prefixes = set()
152
+
153
+ def is_guidance_parametrized(name: str):
154
+ if state_dict_guidance is None:
155
+ return True
156
+
157
+ p_name = f'{name}.parametrizations'
158
+ is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
159
+ return is_prm
160
+
161
+ def parametrize_linear(linear: nn.Linear):
162
+ parametrize.register_parametrization(
163
+ linear,
164
+ 'weight',
165
+ _SNReweight(linear.weight, **args)
166
+ )
167
+
168
+ for name, mod in model.named_modules():
169
+ pref = '.'.join(name.split('.')[:-1])
170
+ if pref in visited_prefixes:
171
+ continue
172
+
173
+ if isinstance(mod, Attention) or name.endswith('.attn'):
174
+ if is_guidance_parametrized(f'{name}.qkv'):
175
+ parametrize.register_parametrization(
176
+ mod.qkv,
177
+ 'weight',
178
+ _AttnSNReweight(mod.qkv.weight, renorm_values=renorm_values, **args),
179
+ )
180
+ if hasattr(mod, 'proj') and is_guidance_parametrized(f'{name}.proj'):
181
+ parametrize_linear(mod.proj)
182
+ visited_prefixes.add(name)
183
+ elif name.endswith('mlp') and renorm_mlp and hasattr(mod, 'w12'):
184
+ if is_guidance_parametrized(f'{name}.w12'):
185
+ parametrize.register_parametrization(
186
+ mod.w12,
187
+ 'weight',
188
+ _ChunkedSNReweight(mod.w12.weight, num_chunks=2, **args),
189
+ )
190
+ if is_guidance_parametrized(f'{name}.w3'):
191
+ parametrize_linear(mod.w3)
192
+ visited_prefixes.add(name)
193
+ elif isinstance(mod, nn.Linear) and 'patch_generator' not in name and is_guidance_parametrized(name):
194
+ parametrize_linear(mod)
195
+
196
+
197
+ def configure_spectral_reparam_from_args(model: nn.Module, args, state_dict_guidance: Optional[Dict[str, torch.Tensor]] = None):
198
+ spectral_reparam = getattr(args, 'spectral_reparam', False)
199
+ if isinstance(spectral_reparam, bool) and spectral_reparam:
200
+ enable_spectral_reparam(model, init_norm_to_current=True, state_dict_guidance=state_dict_guidance)
201
+ elif isinstance(spectral_reparam, dict):
202
+ enable_spectral_reparam(
203
+ model,
204
+ n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
205
+ eps=spectral_reparam.get('eps', 1e-12),
206
+ init_norm_to_current=True,
207
+ state_dict_guidance=state_dict_guidance,
208
+ )
209
+
210
+
211
+ def disable_spectral_reparam(model: nn.Module):
212
+ print('Disabling spectral reparametrization')
213
+ for name, mod in model.named_modules():
214
+ if parametrize.is_parametrized(mod):
215
+ parametrize.remove_parametrizations(mod, 'weight')
216
+ pass
217
+
218
+
219
+
220
+ if __name__ == '__main__':
221
+ import argparse
222
+ from . import radio_model as create_model
223
+
224
+ parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
225
+ parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
226
+ parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
227
+ parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
228
+ parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')
229
+
230
+ args = parser.parse_args()
231
+
232
+ if not args.output:
233
+ chk_dir, chk_name = os.path.split(args.checkpoint)
234
+ args.output = os.path.join(chk_dir, f'clean_{chk_name}')
235
+ print(f'Set output to "{args.output}"')
236
+
237
+ chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)
238
+
239
+ model = create_model.create_model_from_args(chk['args'])
240
+
241
+ key = 'base_model.'
242
+ mod_state = dict()
243
+ extra_state = dict()
244
+ for k, v in chk['state_dict'].items():
245
+ if k.startswith(key):
246
+ mod_state[k[len(key):]] = v
247
+ else:
248
+ extra_state[k] = v
249
+
250
+ chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
251
+ if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
252
+ print(chk_load_info)
253
+
254
+ if chk['args'].spectral_reparam:
255
+ disable_spectral_reparam(model)
256
+
257
+ if hasattr(chk['args'], 'dtype'):
258
+ model.to(dtype=chk['args'].dtype)
259
+
260
+ mod_state = model.state_dict()
261
+ final_state = dict()
262
+ final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
263
+ final_state.update(extra_state)
264
+
265
+ chk['state_dict'] = final_state
266
+ chk['args'].spectral_reparam = False
267
+
268
+ if args.release:
269
+ chk = {
270
+ 'arch': chk['arch'],
271
+ 'epoch': chk['epoch'],
272
+ 'state_dict': chk['state_dict'],
273
+ 'args': chk['args'],
274
+ }
275
+
276
+ torch.save(chk, args.output)
277
+ pass
src/models/radiov3/eradio_model.py ADDED
@@ -0,0 +1,1392 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ # E-RADIO model from
12
+ # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
13
+
14
+ # based on FasterViT, Swin Transformer, YOLOv8
15
+
16
+ # FasterViT:
17
+ # Ali Hatamizadeh, Greg Heinrich, Hongxu Yin, Andrew Tao, Jose M. Alvarez, Jan Kautz, and Pavlo Molchanov. "FasterViT: Fast Vision Transformers with Hierarchical Attention." arXiv preprint arXiv:2306.06189 (2023).
18
+
19
+ import timm
20
+ import torch
21
+ import torch.nn as nn
22
+ from timm.models.registry import register_model
23
+
24
+ from timm.models.layers import trunc_normal_, DropPath, LayerNorm2d
25
+ import numpy as np
26
+ import torch.nn.functional as F
27
+ import math
28
+ import warnings
29
+
30
+ #######################
31
+ ## Codebase from YOLOv8
32
+ ## BEGINNING
33
+ #######################
34
+
35
+ class C2f(nn.Module):
36
+ """Faster Implementation of CSP Bottleneck with 2 convolutions."""
37
+ """From YOLOv8 codebase"""
38
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None): # ch_in, ch_out, number, shortcut, groups, expansion
39
+ super().__init__()
40
+ if drop_path is None:
41
+ drop_path = [0.0] * n
42
+
43
+ self.c = int(c2 * e) # hidden channels
44
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
45
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
46
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i]) for i in range(n))
47
+
48
+ def forward(self, x):
49
+ """Forward pass through C2f layer."""
50
+ y = list(self.cv1(x).chunk(2, 1))
51
+ y.extend(m(y[-1]) for m in self.m)
52
+ return self.cv2(torch.cat(y, 1))
53
+
54
+ def forward_split(self, x):
55
+ """Forward pass using split() instead of chunk()."""
56
+ y = list(self.cv1(x).split((self.c, self.c), 1))
57
+ y.extend(m(y[-1]) for m in self.m)
58
+ return self.cv2(torch.cat(y, 1))
59
+
60
+ class Bottleneck(nn.Module):
61
+ """Standard bottleneck."""
62
+
63
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0): # ch_in, ch_out, shortcut, groups, kernels, expand
64
+ super().__init__()
65
+ c_ = int(c2 * e) # hidden channels
66
+ self.cv1 = Conv(c1, c_, k[0], 1)
67
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
68
+ self.add = shortcut and c1 == c2
69
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
70
+
71
+ def forward(self, x):
72
+ """'forward()' applies the YOLOv5 FPN to input data."""
73
+ return x + self.drop_path1(self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
74
+
75
+
76
+ class Conv(nn.Module):
77
+ """Modified to support layer fusion"""
78
+ default_act = nn.SiLU() # default activation
79
+
80
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=None, g=1, dilation=1, bn_weight_init=1, bias=False, act=True):
81
+ super().__init__()
82
+
83
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, autopad(kernel_size, padding, dilation), dilation, g, bias=False)
84
+ if 1:
85
+ self.bn = torch.nn.BatchNorm2d(b)
86
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
87
+ torch.nn.init.constant_(self.bn.bias, 0)
88
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
89
+
90
+
91
+ def forward(self,x):
92
+ x = self.conv(x)
93
+ x = self.bn(x)
94
+ x = self.act(x)
95
+ return x
96
+
97
+ @torch.no_grad()
98
+ def switch_to_deploy(self):
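+ # Folds the BatchNorm into the convolution for inference:
+ # w' = w * gamma / sqrt(var + eps), b' = beta - mean * gamma / sqrt(var + eps)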
99
+ # return 1
100
+ if not isinstance(self.bn, nn.Identity):
101
+ c, bn = self.conv, self.bn
102
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
103
+ w = c.weight * w[:, None, None, None]
104
+ b = bn.bias - bn.running_mean * bn.weight / \
105
+ (bn.running_var + bn.eps)**0.5
106
+
107
+ self.conv.weight.data.copy_(w)
108
+ self.conv.bias = nn.Parameter(b)
109
+
110
+ self.bn = nn.Identity()
111
+
112
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
113
+ """Pad to 'same' shape outputs."""
114
+ if d > 1:
115
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
116
+ if p is None:
117
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
118
+ return p
119
+
120
+
121
+ #######################
122
+ ## Codebase from YOLOv8
123
+ ## END
124
+ #######################
125
+
126
+ def pixel_unshuffle(data, factor=2):
127
+ # performs nn.PixelShuffle(factor) in reverse; PixelShuffle has export issues with ONNX and TRT, so it is done manually
128
+ B, C, H, W = data.shape
129
+ return data.view(B, C, factor, H//factor, factor, W//factor).permute(0,1,2,4,3,5).reshape(B, -1, H//factor, W//factor)
130
+
131
+ class SwiGLU(nn.Module):
132
+ # could be made more advanced, but that doesn't improve results so far
133
+ def forward(self, x):
134
+ x, gate = x.chunk(2, dim=-1)
135
+ return F.silu(gate) * x
136
+
137
+
138
+ def window_partition(x, window_size):
139
+ """
140
+ Function for partitioning image into windows and later do windowed attention
141
+ Args:
142
+ x: (B, C, H, W)
143
+ window_size: window size
144
+ Returns:
145
+ windows - local window features (num_windows*B, window_size*window_size, C)
146
+ (Hp, Wp) - the size of the padded image
147
+ """
148
+ B, C, H, W = x.shape
149
+
150
+ if window_size == 0 or (window_size==H and window_size==W):
151
+ windows = x.flatten(2).transpose(1, 2)
152
+ Hp, Wp = H, W
153
+ else:
154
+ pad_h = (window_size - H % window_size) % window_size
155
+ pad_w = (window_size - W % window_size) % window_size
156
+ if pad_h > 0 or pad_w > 0:
157
+ x = F.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
158
+ Hp, Wp = H + pad_h, W + pad_w
159
+
160
+ x = x.view(B, C, Hp // window_size, window_size, Wp // window_size, window_size)
161
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
162
+
163
+ return windows, (Hp, Wp)
164
+
165
+ class Conv2d_BN(nn.Module):
166
+ '''
167
+ Conv2d + BN layer with folding capability to speed up inference
168
+ Can be merged with Conv() function with additional arguments
169
+ '''
170
+ def __init__(self, a, b, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1, bias=False):
171
+ super().__init__()
172
+ self.conv = torch.nn.Conv2d(a, b, kernel_size, stride, padding, dilation, groups, bias=False)
173
+ if 1:
174
+ self.bn = torch.nn.BatchNorm2d(b)
175
+ torch.nn.init.constant_(self.bn.weight, bn_weight_init)
176
+ torch.nn.init.constant_(self.bn.bias, 0)
177
+
178
+ def forward(self,x):
179
+ x = self.conv(x)
180
+ x = self.bn(x)
181
+ return x
182
+
183
+ @torch.no_grad()
184
+ def switch_to_deploy(self):
185
+ if not isinstance(self.bn, nn.Identity):
186
+ c, bn = self.conv, self.bn
187
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
188
+ w = c.weight * w[:, None, None, None]
189
+ b = bn.bias - bn.running_mean * bn.weight / \
190
+ (bn.running_var + bn.eps)**0.5
191
+ self.conv.weight.data.copy_(w)
192
+ self.conv.bias = nn.Parameter(b)
193
+ self.bn = nn.Identity()
194
+
195
+
196
+
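A minimal BN-folding check for the Conv2d_BN block above (illustrative sizes; eval mode so BatchNorm uses its running statistics):

    m = Conv2d_BN(8, 16, kernel_size=3, stride=1, padding=1).eval()
    x = torch.randn(1, 8, 14, 14)
    y_ref = m(x)
    m.switch_to_deploy()                      # folds BN into the conv weights and bias
    assert torch.allclose(y_ref, m(x), atol=1e-5)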
197
+ def window_reverse(windows, window_size, H, W, pad_hw):
198
+ """
199
+ Windows to the full feature map
200
+ Args:
201
+ windows: local window features (num_windows*B, window_size*window_size, C)
202
+ window_size: Window size
203
+ H: Height of image
204
+ W: Width of image
205
+ pad_hw - a tuple with the padded image size used in the windowing step
206
+ Returns:
207
+ x: (B, C, H, W)
208
+
209
+ """
210
+ # print(f"window_reverse, windows.shape {windows.shape}")
211
+ Hp, Wp = pad_hw
212
+ if window_size == 0 or (window_size==H and window_size==W):
213
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
214
+ x = windows.transpose(1, 2).view(B, -1, H, W)
215
+ else:
216
+ B = int(windows.shape[0] / (Hp * Wp / window_size / window_size))
217
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
218
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B,windows.shape[2], Hp, Wp)
219
+
220
+ if Hp > H or Wp > W:
221
+ x = x[:, :, :H, :W, ].contiguous()
222
+
223
+ return x
224
+
225
+
226
+
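The two windowing helpers above are exact inverses when the resolution is divisible by the window size; a round-trip sketch with illustrative sizes:

    x = torch.randn(2, 64, 28, 28)
    windows, pad_hw = window_partition(x, 7)            # -> (2*4*4, 49, 64)
    assert windows.shape == (32, 49, 64)
    x_back = window_reverse(windows, 7, 28, 28, pad_hw)
    assert torch.allclose(x, x_back)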
227
+ class PosEmbMLPSwinv2D(nn.Module):
228
+ """
229
+ 2D positional embedding from Swin Transformer v2
230
+ Added functionality to store the positional embedding in the model and not recompute it every time
231
+ """
232
+ def __init__(
233
+ self, window_size, pretrained_window_size, num_heads, seq_length, no_log=False, cpb_mlp_hidden=512,
234
+ ):
235
+ super().__init__()
236
+ self.window_size = window_size
237
+ self.num_heads = num_heads
238
+ # mlp to generate continuous relative position bias
239
+ self.cpb_mlp = nn.Sequential(
240
+ nn.Linear(2, cpb_mlp_hidden, bias=True),
241
+ nn.ReLU(inplace=True),
242
+ nn.Linear(cpb_mlp_hidden, num_heads, bias=False),
243
+ )
244
+
245
+ self.grid_exists = False
246
+ self.seq_length = seq_length
247
+ self.deploy = False
248
+ self.num_heads = num_heads
249
+ self.no_log = no_log
250
+ self.pretrained_window_size = pretrained_window_size
251
+ self.relative_bias_window_size = window_size
252
+
253
+ relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(window_size, num_heads,
254
+ pretrained_window_size, seq_length,
255
+ no_log)
256
+
257
+ self.register_buffer("relative_coords_table", relative_coords_table)
258
+ self.register_buffer("relative_position_index", relative_position_index)
259
+ self.register_buffer("relative_bias", relative_bias) # for EMA
260
+
261
+ def relative_bias_initialization(self, window_size, num_heads, pretrained_window_size, seq_length, no_log):
262
+ # kept as a separate function to support window size change after model weights are loaded
263
+ relative_coords_h = torch.arange(
264
+ -(window_size[0] - 1), window_size[0], dtype=torch.float32
265
+ )
266
+ relative_coords_w = torch.arange(
267
+ -(window_size[1] - 1), window_size[1], dtype=torch.float32
268
+ )
269
+ relative_coords_table = (
270
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
271
+ .permute(1, 2, 0)
272
+ .contiguous()
273
+ .unsqueeze(0)
274
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
275
+ if pretrained_window_size[0] > 0:
276
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
277
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
278
+ else:
279
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
280
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
281
+
282
+ if not no_log:
283
+ relative_coords_table *= 8 # normalize to -8, 8
284
+ relative_coords_table = (
285
+ torch.sign(relative_coords_table)
286
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
287
+ / np.log2(8)
288
+ )
289
+
290
+ # get pair-wise relative position index for each token inside the window
291
+ coords_h = torch.arange(self.window_size[0])
292
+ coords_w = torch.arange(self.window_size[1])
293
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
294
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
295
+ relative_coords = (
296
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
297
+ ) # 2, Wh*Ww, Wh*Ww
298
+ relative_coords = relative_coords.permute(
299
+ 1, 2, 0
300
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
301
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
302
+ relative_coords[:, :, 1] += self.window_size[1] - 1
303
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
304
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
305
+
306
+ relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
307
+
308
+ self.relative_bias_window_size = window_size
309
+
310
+ return relative_coords_table, relative_position_index, relative_bias
311
+
312
+
313
+ def switch_to_deploy(self):
314
+ self.deploy = True
315
+ self.grid_exists = True
316
+
317
+ def forward(self, input_tensor):
318
+ # for efficiency, we want this forward to be folded into a single operation (sum)
319
+ # if the resolution stays the same, then we don't need to recompute the MLP layers
320
+
321
+ if not self.deploy or self.training:
322
+ self.grid_exists = False
323
+
324
+ # check whether all elements in the self.window_size list match those in self.relative_bias_window_size
325
+ if not all([self.window_size[i] == self.relative_bias_window_size[i] for i in range(len(self.window_size))]):
326
+ relative_coords_table, relative_position_index, relative_bias = self.relative_bias_initialization(self.window_size, self.num_heads,
327
+ self.pretrained_window_size, self.seq_length,
328
+ self.no_log)
329
+
330
+ self.relative_coords_table = relative_coords_table.to(self.relative_coords_table.device)
331
+ self.relative_position_index = relative_position_index.to(self.relative_position_index.device)
332
+ self.relative_bias = relative_bias.to(self.relative_bias.device)
333
+
334
+ if self.deploy and self.grid_exists:
335
+ input_tensor = input_tensor + self.relative_bias
336
+ return input_tensor
337
+
338
+ if 1:
339
+ self.grid_exists = True
340
+
341
+ relative_position_bias_table = self.cpb_mlp(
342
+ self.relative_coords_table
343
+ ).view(-1, self.num_heads)
344
+ relative_position_bias = relative_position_bias_table[
345
+ self.relative_position_index.view(-1)
346
+ ].view(
347
+ self.window_size[0] * self.window_size[1],
348
+ self.window_size[0] * self.window_size[1],
349
+ -1,
350
+ ) # Wh*Ww,Wh*Ww,nH
351
+
352
+ relative_position_bias = relative_position_bias.permute(
353
+ 2, 0, 1
354
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
355
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
356
+
357
+ self.relative_bias = relative_position_bias.unsqueeze(0)
358
+
359
+ input_tensor = input_tensor + self.relative_bias
360
+ return input_tensor
361
+
362
+
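The module above adds a (1, num_heads, seq_length, seq_length) relative-position bias to the attention logits; a minimal shape sketch (illustrative sizes):

    pos = PosEmbMLPSwinv2D(window_size=[7, 7], pretrained_window_size=[7, 7],
                           num_heads=4, seq_length=49)
    attn_logits = torch.zeros(8, 4, 49, 49)   # (num_windows*B, heads, tokens, tokens)
    assert pos(attn_logits).shape == (8, 4, 49, 49)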
363
+ class GRAAttentionBlock(nn.Module):
364
+ def __init__(self, window_size, dim_in, dim_out,
365
+ num_heads, drop_path=0., qk_scale=None, qkv_bias=False,
366
+ norm_layer=nn.LayerNorm, layer_scale=None,
367
+ use_swiglu=True,
368
+ subsample_ratio=1, dim_ratio=1, conv_base=False,
369
+ do_windowing=True, multi_query=False, use_shift=0,
370
+ cpb_mlp_hidden=512, conv_groups_ratio=0):
371
+ '''
372
+ Global Resolution Attention (GRA) block, see README for details
373
+ Attention with subsampling to get a bigger receptive field
374
+ conv_base - use conv2d instead of avgpool2d for downsample / upsample
375
+
376
+
377
+ '''
378
+ super().__init__()
379
+
380
+ self.shift_size=window_size//2 if use_shift else 0
381
+
382
+ self.do_windowing = do_windowing
383
+ self.subsample_ratio = subsample_ratio
384
+
385
+
386
+
387
+ if do_windowing:
388
+ if conv_base:
389
+ self.downsample_op = nn.Conv2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
390
+
391
+
392
+ self.downsample_mixer = nn.Identity()
393
+ self.upsample_mixer = nn.Identity()
394
+ self.upsample_op = nn.ConvTranspose2d(dim_in, dim_out, kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
395
+ else:
396
+ self.downsample_op = nn.AvgPool2d(kernel_size=subsample_ratio, stride=subsample_ratio) if subsample_ratio > 1 else nn.Identity()
397
+ self.downsample_mixer = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1) if subsample_ratio > 1 else nn.Identity()
398
+ self.upsample_mixer = nn.Upsample(scale_factor=subsample_ratio, mode='nearest') if subsample_ratio > 1 else nn.Identity()
399
+ self.upsample_op = Conv2d_BN(dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False) if subsample_ratio > 1 else nn.Identity()
400
+
401
+
402
+ # in case there is no downsampling conv we want to have it separately
403
+ # will help with information propagation between windows
404
+ if subsample_ratio == 1:
405
+ # conv_groups_ratio=0
406
+ self.pre_conv = Conv2d_BN(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
407
+ # self.pre_conv = nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=1, groups=max(1,int(conv_groups_ratio*dim_in)), bias=False)
408
+ # self.pre_conv_act = nn.ReLU6()
409
+ #for simplicity:
410
+ self.pre_conv_act = nn.Identity()
411
+ if conv_groups_ratio == -1:
412
+ self.pre_conv = nn.Identity()
413
+ self.pre_conv_act = nn.Identity()
414
+
415
+ self.window_size = window_size
416
+
417
+ self.norm1 = norm_layer(dim_in)
418
+
419
+ self.attn = WindowAttention(
420
+ dim_in,
421
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
422
+ resolution=window_size,
423
+ seq_length=window_size**2, dim_out=dim_in, multi_query=multi_query,
424
+ shift_size=self.shift_size, cpb_mlp_hidden=cpb_mlp_hidden)
425
+
426
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
427
+
428
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
429
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim_in)) if use_layer_scale else 1
430
+
431
+ ### mlp layer
432
+ mlp_ratio = 4
433
+ self.norm2 = norm_layer(dim_in)
434
+ mlp_hidden_dim = int(dim_in * mlp_ratio)
435
+
436
+ activation = nn.GELU if not use_swiglu else SwiGLU
437
+ mlp_hidden_dim = int((4 * dim_in * 1 / 2) / 64) * 64 if use_swiglu else mlp_hidden_dim
438
+
439
+ self.mlp = Mlp(in_features=dim_in, hidden_features=mlp_hidden_dim, act_layer=activation, use_swiglu=use_swiglu)
440
+
441
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim_in)) if layer_scale else 1
442
+ self.drop_path2=DropPath(drop_path) if drop_path > 0. else nn.Identity()
443
+
444
+
445
+ def forward(self, x):
446
+ skip_connection = x
447
+ attn_mask = None
448
+
449
+ # in case there is no downsampling conv we want to have it separately
450
+ # will help with information propagation
451
+ if self.subsample_ratio == 1:
452
+ x = self.pre_conv_act(self.pre_conv(x)) + skip_connection
453
+
454
+ if self.do_windowing:
455
+ # performing windowing if required
456
+ x = self.downsample_op(x)
457
+ x = self.downsample_mixer(x)
458
+
459
+ if self.window_size>0:
460
+ H, W = x.shape[2], x.shape[3]
461
+
462
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
463
+ # Swin-like cyclic shift; doesn't show better performance
464
+ x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3))
465
+
466
+ x, pad_hw = window_partition(x, self.window_size)
467
+
468
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
469
+ # set the attention matrix to -100 at the top-right square
470
+ # attn[:, :, :-self.shift_size, -self.shift_size:] = -100.0
471
+ # calculate attention mask for SW-MSA
472
+ # not used in the final version; can be useful in some cases, especially for high resolutions
473
+ H, W = pad_hw
474
+ img_mask = torch.zeros((1, H, W, 1), device=x.device) # 1 H W 1
475
+ h_slices = (slice(0, -self.window_size),
476
+ slice(-self.window_size, -self.shift_size),
477
+ slice(-self.shift_size, None))
478
+ w_slices = (slice(0, -self.window_size),
479
+ slice(-self.window_size, -self.shift_size),
480
+ slice(-self.shift_size, None))
481
+ cnt = 0
482
+ for h in h_slices:
483
+ for w in w_slices:
484
+ img_mask[:, h, w, :] = cnt
485
+ cnt += 1
486
+ img_mask = img_mask.transpose(1,2).transpose(1,3)
487
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
488
+
489
+ mask_windows = mask_windows[0].view(-1, self.window_size * self.window_size)
490
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
491
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
492
+
493
+ # window attention
494
+ x = x + self.drop_path1(self.gamma1*self.attn(self.norm1(x), attn_mask=attn_mask)) # or pass H,W
495
+ # mlp layer
496
+ x = x + self.drop_path2(self.gamma2*self.mlp(self.norm2(x)))
497
+
498
+ if self.do_windowing:
499
+ if self.window_size > 0:
500
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
501
+
502
+ # reverse cyclic shift
503
+ if self.shift_size > 0 and H>self.window_size and W>self.window_size:
504
+ # Swin-like reverse cyclic shift, not tested
505
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(2, 3))
506
+
507
+ x = self.upsample_mixer(x)
508
+ x = self.upsample_op(x)
509
+
510
+
511
+ if x.shape[2] != skip_connection.shape[2] or x.shape[3] != skip_connection.shape[3]:
512
+ x = torch.nn.functional.pad(x, ( 0, -x.shape[3] + skip_connection.shape[3], 0, -x.shape[2] + skip_connection.shape[2]), mode="reflect")
513
+ # need to add skip connection because downsampling and upsampling will break residual connection
514
+ # 0.5 is needed to make sure that the skip connection is not too strong
515
+ # in case of no downsample / upsample we can show that 0.5 compensates for the residual connection
516
+ x = 0.5 * x + 0.5 * skip_connection
517
+ return x
518
+
519
+
520
+
521
+
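The block above subsamples, runs windowed attention, upsamples, and averages with the skip connection, so the spatial shape is preserved; a minimal sketch with illustrative sizes:

    blk = GRAAttentionBlock(window_size=7, dim_in=64, dim_out=64, num_heads=4,
                            subsample_ratio=2, do_windowing=True).eval()
    x = torch.randn(1, 64, 28, 28)
    assert blk(x).shape == x.shape            # 0.5 * attended path + 0.5 * skip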
522
+ class MultiResolutionAttention(nn.Module):
523
+ """
524
+ MultiResolutionAttention (MRA) module
525
+ The idea is to use multiple attention blocks with different resolution
526
+ Feature maps are downsampled / upsampled for each attention block on different blocks
527
+ Every attention block supports windowing
528
+ """
529
+
530
+ def __init__(self, window_size, sr_ratio,
531
+ dim, dim_ratio, num_heads,
532
+ do_windowing=True,
533
+ layer_scale=1e-5, norm_layer=nn.LayerNorm,
534
+ drop_path = 0, qkv_bias=False, qk_scale=1.0,
535
+ use_swiglu=True, multi_query=False, conv_base=False,
536
+ use_shift=0, cpb_mlp_hidden=512, conv_groups_ratio=0) -> None:
537
+ """
538
+ Args:
539
+ input_resolution: input image resolution
540
+ window_size: window size
541
+ compression_ratio: compression ratio
542
+ max_depth: maximum depth of the GRA module
543
+ use_shift: do window shifting
544
+ """
545
+ super().__init__()
546
+
547
+ depth = len(sr_ratio)
548
+
549
+ self.attention_blocks = nn.ModuleList()
550
+
551
+
552
+ for i in range(depth):
553
+ subsample_ratio = sr_ratio[i]
554
+ if len(window_size) > i:
555
+ window_size_local = window_size[i]
556
+ else:
557
+ window_size_local = window_size[0]
558
+
559
+ self.attention_blocks.append(GRAAttentionBlock(window_size=window_size_local,
560
+ dim_in=dim, dim_out=dim, num_heads=num_heads,
561
+ qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer,
562
+ layer_scale=layer_scale, drop_path=drop_path,
563
+ use_swiglu=use_swiglu, subsample_ratio=subsample_ratio, dim_ratio=dim_ratio,
564
+ do_windowing=do_windowing, multi_query=multi_query, conv_base=conv_base,
565
+ use_shift=use_shift, cpb_mlp_hidden=cpb_mlp_hidden, conv_groups_ratio=conv_groups_ratio),
566
+ )
567
+
568
+ def forward(self, x):
569
+
570
+ for attention_block in self.attention_blocks:
571
+ x = attention_block(x)
572
+
573
+ return x
574
+
575
+
576
+
577
+ class Mlp(nn.Module):
578
+ """
579
+ Multi-Layer Perceptron (MLP) block
580
+ """
581
+
582
+ def __init__(self,
583
+ in_features,
584
+ hidden_features=None,
585
+ out_features=None,
586
+ act_layer=nn.GELU,
587
+ use_swiglu=True,
588
+ drop=0.):
589
+ """
590
+ Args:
591
+ in_features: input features dimension.
592
+ hidden_features: hidden features dimension.
593
+ out_features: output features dimension.
594
+ act_layer: activation function.
595
+ drop: dropout rate.
596
+ """
597
+
598
+ super().__init__()
599
+ out_features = out_features or in_features
600
+ hidden_features = hidden_features or in_features
601
+ self.fc1 = nn.Linear(in_features, hidden_features * (2 if use_swiglu else 1), bias=False)
602
+ self.act = act_layer()
603
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
604
+
605
+ def forward(self, x):
606
+ x_size = x.size()
607
+ x = x.view(-1, x_size[-1])
608
+ x = self.fc1(x)
609
+ x = self.act(x)
610
+ x = self.fc2(x)
611
+ x = x.view(x_size)
612
+ return x
613
+
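With use_swiglu=True, fc1 emits twice the hidden width and the SwiGLU gate brings it back down; a minimal sketch (illustrative sizes):

    mlp = Mlp(in_features=64, hidden_features=128, act_layer=SwiGLU, use_swiglu=True)
    assert mlp.fc1.out_features == 256 and mlp.fc2.in_features == 128
    assert mlp(torch.randn(4, 49, 64)).shape == (4, 49, 64)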
614
+ class Downsample(nn.Module):
615
+ """
616
+ Down-sampling block
617
+ Pixel unshuffle is used for down-sampling; it works well accuracy-wise but takes 10% more TRT time
618
+ """
619
+
620
+ def __init__(self,
621
+ dim,
622
+ shuffle = False,
623
+ ):
624
+ """
625
+ Args:
626
+ dim: feature size dimension.
627
+ shuffle: use pixel unshuffle for down-sampling instead of a strided conv.
628
+ keep_dim: bool argument for maintaining the resolution.
629
+ """
630
+
631
+ super().__init__()
632
+ dim_out = 2 * dim
633
+
634
+ if shuffle:
635
+ self.norm = lambda x: pixel_unshuffle(x, factor=2)
636
+ self.reduction = Conv2d_BN(dim*4, dim_out, 1, 1, 0, bias=False)
637
+ # pixel unshuffling works well but doesn't provide any speedup
638
+ else:
639
+ # removed layer norm; in this formulation we get 10% better speed
640
+ # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
641
+ # therefore we remove it compared to the original implementation in FasterViT
642
+ self.norm = nn.Identity()
643
+ self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
644
+
645
+
646
+ def forward(self, x):
647
+ x = self.norm(x)
648
+ x = self.reduction(x)
649
+ return x
650
+
651
+
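The default (non-shuffle) path above halves the resolution with a stride-2 Conv2d_BN and doubles the channels; a minimal sketch:

    down = Downsample(dim=64)                 # 3x3 stride-2 conv, channels 64 -> 128
    assert down(torch.randn(1, 64, 28, 28)).shape == (1, 128, 14, 14)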
652
+ class PatchEmbed(nn.Module):
653
+ """
654
+ Patch embedding block
655
+ Used to convert image into an initial set of feature maps with lower resolution
656
+ """
657
+
658
+ def __init__(self, in_chans=3, in_dim=64, dim=96, shuffle_down=False):
659
+ """
660
+ Args:
661
+ in_chans: number of input channels.
662
+ in_dim: intermediate feature size dimension to speed up stem.
663
+ dim: final stem channel number
664
+ shuffle_down: use PixelUnshuffle for down-sampling, effectively increases the receptive field
665
+ """
666
+
667
+ super().__init__()
668
+ # shuffle_down = False
669
+ if not shuffle_down:
670
+ self.proj = nn.Identity()
671
+ self.conv_down = nn.Sequential(
672
+ Conv2d_BN(in_chans, in_dim, 3, 2, 1, bias=False),
673
+ nn.ReLU(),
674
+ Conv2d_BN(in_dim, dim, 3, 2, 1, bias=False),
675
+ nn.ReLU()
676
+ )
677
+ else:
678
+ self.proj = lambda x: pixel_unshuffle(x, factor=4)
679
+ self.conv_down = nn.Sequential(Conv2d_BN(in_chans*16, dim, 3, 1, 1),
680
+ nn.ReLU(),
681
+ )
682
+
683
+ def forward(self, x):
684
+ x = self.proj(x)
685
+ x = self.conv_down(x)
686
+ return x
687
+
688
+
689
+
690
+ class ConvBlock(nn.Module):
691
+ """
692
+ Convolutional block, used in first couple of stages
693
+ Experimented with plain ResNet-18-like modules; they are the best in terms of throughput
694
+ Finally, the YOLOv8 idea seems to work fine (ResNet-18-like block with a squeezed feature dimension and feature concatenation at the end)
695
+ """
696
+ def __init__(self, dim,
697
+ drop_path=0.,
698
+ layer_scale=None,
699
+ kernel_size=3,
700
+ ):
701
+ super().__init__()
702
+
703
+ self.conv1 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
704
+ self.act1 = nn.GELU()
705
+
706
+ self.conv2 = Conv2d_BN(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
707
+
708
+ self.layer_scale = layer_scale
709
+ if layer_scale is not None and type(layer_scale) in [int, float]:
710
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
711
+ self.layer_scale = True
712
+ else:
713
+ self.layer_scale = False
714
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
715
+
716
+ def forward(self, x):
717
+ input = x
718
+
719
+ x = self.conv1(x)
720
+ x = self.act1(x)
721
+ x = self.conv2(x)
722
+
723
+ if self.layer_scale:
724
+ x = x * self.gamma.view(1, -1, 1, 1)
725
+ x = input + self.drop_path(x)
726
+ return x
727
+
728
+
729
+ class WindowAttention(nn.Module):
730
+ # Windowed Attention from SwinV2
731
+ # uses an MLP trick to deal with various input image resolutions, then folds it to improve speed
732
+
733
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, resolution=0,
734
+ seq_length=0, dim_out=None, multi_query=False, shift_size=0, cpb_mlp_hidden=512):
735
+ # taken from EdgeViT and tweaked with attention bias.
736
+ super().__init__()
737
+ if not dim_out: dim_out = dim
738
+ self.shift_size = shift_size
739
+ self.multi_query = multi_query
740
+ self.num_heads = num_heads
741
+ head_dim = dim // num_heads
742
+ self.head_dim = dim // num_heads
743
+
744
+ self.dim_internal = dim
745
+
746
+ self.scale = qk_scale or head_dim ** -0.5
747
+ if not multi_query:
748
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
749
+ else:
750
+ self.qkv = nn.Linear(dim, dim + 2*self.head_dim, bias=qkv_bias)
751
+
752
+ self.proj = nn.Linear(dim, dim_out, bias=False)
753
+ # attention positional bias
754
+ self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
755
+ pretrained_window_size=[resolution, resolution],
756
+ num_heads=num_heads,
757
+ seq_length=seq_length,
758
+ cpb_mlp_hidden=cpb_mlp_hidden)
759
+
760
+ self.resolution = resolution
761
+
762
+ def forward(self, x, attn_mask = None):
763
+ B, N, C = x.shape
764
+
765
+ if not self.multi_query:
766
+ qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
767
+ q, k, v = qkv[0], qkv[1], qkv[2]
768
+ else:
769
+ qkv = self.qkv(x)
770
+ (q, k, v) = qkv.split([self.dim_internal, self.head_dim, self.head_dim], dim=2)
771
+
772
+ q = q.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
773
+ k = k.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
774
+ v = v.reshape(B, -1, 1, C // self.num_heads).permute(0, 2, 1, 3)
775
+
776
+ attn = (q @ k.transpose(-2, -1)) * self.scale
777
+
778
+ attn = self.pos_emb_funct(attn)
779
+
780
+ #add window shift
781
+ if attn_mask is not None:
782
+ nW = attn_mask.shape[0]
783
+ attn = attn.view(B // nW, nW, self.num_heads, N, N) + attn_mask.unsqueeze(1).unsqueeze(0)
784
+ attn = attn.view(-1, self.num_heads, N, N)
785
+
786
+ attn = attn.softmax(dim=-1)
787
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
788
+ x = self.proj(x)
789
+ return x
790
+
791
+
792
+
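With multi_query=True the attention above shares a single key/value head across all query heads; a minimal shape sketch (illustrative sizes):

    attn = WindowAttention(dim=64, num_heads=4, resolution=7, seq_length=49, multi_query=True)
    tokens = torch.randn(8, 49, 64)           # (num_windows*B, window_size**2, dim)
    assert attn(tokens).shape == (8, 49, 64)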
793
+ class ERADIOLayer(nn.Module):
794
+ """
795
+ E-RADIO Layer
796
+ """
797
+
798
+ def __init__(self,
799
+ dim,
800
+ depth,
801
+ num_heads,
802
+ window_size,
803
+ conv=False,
804
+ downsample=True,
805
+ mlp_ratio=4.,
806
+ qkv_bias=False,
807
+ qk_scale=None,
808
+ norm_layer=nn.LayerNorm,
809
+ drop_path=0.,
810
+ layer_scale=None,
811
+ layer_scale_conv=None,
812
+ sr_dim_ratio=1,
813
+ sr_ratio=1,
814
+ multi_query=False,
815
+ use_swiglu=True,
816
+ yolo_arch=False,
817
+ downsample_shuffle=False,
818
+ conv_base=False,
819
+ use_shift=False,
820
+ cpb_mlp_hidden=512,
821
+ conv_groups_ratio=0,
822
+ verbose: bool = True,
823
+
824
+ ):
825
+ """
826
+ Args:
827
+ dim: feature size dimension.
828
+ depth: number of layers in each stage.
829
+ input_resolution: input image resolution.
830
+ window_size: window size in each stage.
831
+ downsample: bool argument for down-sampling.
832
+ mlp_ratio: MLP ratio.
833
+ num_heads: number of heads in each stage.
834
+ qkv_bias: bool argument for query, key, value learnable bias.
835
+ qk_scale: bool argument to scaling query, key.
836
+ drop: dropout rate.
837
+ attn_drop: attention dropout rate.
838
+ drop_path: drop path rate.
839
+ norm_layer: normalization layer.
840
+ layer_scale: layer scaling coefficient.
841
+ use_shift: Swin-like window shifting by half the window size on every alternating layer (considering multi-resolution)
842
+ conv_groups_ratio: group ratio for conv when no subsampling in multi-res attention
843
+ """
844
+
845
+ super().__init__()
846
+ self.conv = conv
847
+ self.yolo_arch=False
848
+ self.verbose = verbose
849
+ if conv:
850
+ if not yolo_arch:
851
+ self.blocks = nn.ModuleList([
852
+ ConvBlock(dim=dim,
853
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
854
+ layer_scale=layer_scale_conv)
855
+ for i in range(depth)])
856
+ self.blocks = nn.Sequential(*self.blocks)
857
+ else:
858
+ self.blocks = C2f(dim,dim,n=depth,shortcut=True,e=0.5)
859
+ self.yolo_arch=True
860
+ else:
861
+ if not isinstance(window_size, list): window_size = [window_size]
862
+ self.window_size = window_size[0]
863
+ self.do_single_windowing = True
864
+ if not isinstance(sr_ratio, list): sr_ratio = [sr_ratio]
865
+ self.sr_ratio = sr_ratio
866
+ if any([sr!=1 for sr in sr_ratio]) or len(set(window_size))>1:
867
+ self.do_single_windowing = False
868
+ do_windowing = True
869
+ else:
870
+ self.do_single_windowing = True
871
+ do_windowing = False
872
+
873
+ #for v2_2
874
+ if conv_groups_ratio != -1:
875
+ self.do_single_windowing = False
876
+ do_windowing = True
877
+
878
+ self.blocks = nn.ModuleList()
879
+ for i in range(depth):
880
+ self.blocks.append(
881
+ MultiResolutionAttention(window_size=window_size,
882
+ sr_ratio=sr_ratio,
883
+ dim=dim,
884
+ dim_ratio = sr_dim_ratio,
885
+ num_heads=num_heads,
886
+ norm_layer=norm_layer,
887
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
888
+ layer_scale=layer_scale,
889
+ qkv_bias=qkv_bias,
890
+ qk_scale=qk_scale,
891
+ use_swiglu=use_swiglu,
892
+ do_windowing=do_windowing,
893
+ multi_query=multi_query,
894
+ conv_base=conv_base,
895
+ cpb_mlp_hidden=cpb_mlp_hidden,
896
+ use_shift =0 if ((not use_shift) or ((i) % 2 == 0)) else True ,
897
+ conv_groups_ratio=conv_groups_ratio,
898
+ ))
899
+ self.blocks = nn.Sequential(*self.blocks)
900
+
901
+ self.transformer = not conv
902
+ self.downsample = None if not downsample else Downsample(dim=dim, shuffle=downsample_shuffle)
903
+
904
+
905
+ def forward(self, x):
906
+ B, C, H, W = x.shape
907
+
908
+ # do padding for the transformer
909
+ interpolate = True
910
+ if self.transformer and interpolate:
911
+ # Windowed Attention will split feature map into windows with the size of window_size x window_size
912
+ # if the resolution is not divisible by window_size, we need to interpolate the feature map
913
+ # can be done via padding, but doing so after training hurts the model performance.
914
+ # interpolation affects the performance as well, but not as much as padding
915
+ if isinstance(self.window_size, list) or isinstance(self.window_size, tuple):
916
+ current_max_window_size = max(self.window_size)
917
+ else:
918
+ current_max_window_size = self.window_size
919
+
920
+ max_window_size = max([res_upsample*current_max_window_size for res_upsample in self.sr_ratio])
921
+ if H % max_window_size != 0 or W % max_window_size != 0:
922
+ new_h = int(np.ceil(H/max_window_size)*max_window_size)
923
+ new_w = int(np.ceil(W/max_window_size)*max_window_size)
924
+ x = F.interpolate(x, size=(new_h, new_w), mode='nearest')
925
+ if self.verbose:
926
+ warnings.warn(f"Chosen window size is not optimal for the given resolution. Interpolation of feature maps will be done and it may affect the performance. Max window size is {max_window_size}, feature map size is {H}x{W}, interpolated feature map size is {new_h}x{new_w}.")
927
+
928
+
929
+ if self.transformer and self.do_single_windowing:
930
+ H, W = x.shape[2], x.shape[3]
931
+ x, pad_hw = window_partition(x, self.window_size)
932
+
933
+ #run main blocks
934
+ x = self.blocks(x)
935
+
936
+ if self.transformer and self.do_single_windowing:
937
+ x = window_reverse(x, self.window_size, H, W, pad_hw)
938
+
939
+ if self.transformer and interpolate:
940
+ # let's keep the original resolution; it might not be ideal, but the upsampling tower needs the expected resolution.
941
+ x = F.interpolate(x, size=(H, W), mode='nearest')
942
+
943
+ if self.downsample is None:
944
+ return x, x
945
+
946
+ return self.downsample(x), x # also return pre-downsampled features
947
+
948
+
949
+ class InterpolateLayer(nn.Module):
950
+ def __init__(self, size=None, scale_factor=None, mode='nearest'):
951
+ super(InterpolateLayer, self).__init__()
952
+ self.size = size
953
+ self.scale_factor = scale_factor
954
+ self.mode = mode
955
+
956
+ def forward(self, x):
957
+ return F.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode)
958
+
959
+
960
+ class HiResNeck(nn.Module):
961
+ """
962
+ The block is used to output dense features from all stages
963
+ Otherwise, by default, only the last stage features are returned with E-RADIO
964
+ """
965
+ def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):
966
+
967
+ '''
968
+ Hi Resolution neck to support output of high res features that are useful for dense tasks.
969
+ depths - list of per-stage depths of the base model
970
+ neck_start_stage - when to start the neck, 0 - start from the first stage, 1 - start from the second stage etc.
971
+ earlier layers result in higher resolution features at the cost of compute
972
+ full_features_head_dim - number of channels in the dense features head
973
+ '''
974
+ super().__init__()
975
+ # create feature projection layers for segmentation output
976
+ self.neck_features_proj = nn.ModuleList()
977
+ self.neck_start_stage = neck_start_stage
978
+ upsample_ratio = 1
979
+ for i in range(len(depths)):
980
+ level_n_features_output = int(dim * 2 ** i)
981
+
982
+ if self.neck_start_stage > i: continue
983
+
984
+ if (upsample_ratio > 1) or full_features_head_dim!=level_n_features_output:
985
+ feature_projection = nn.Sequential()
986
+ if False:
987
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output)) #fast, but worse
988
+ feature_projection.add_module("dconv", nn.ConvTranspose2d(level_n_features_output,
989
+ full_features_head_dim, kernel_size=upsample_ratio, stride=upsample_ratio))
990
+ else:
991
+ # B, in_channels, H, W -> B, in_channels, H*upsample_ratio, W*upsample_ratio
992
+ # print("upsample ratio", upsample_ratio, level_n_features_output, level_n_features_output)
993
+ feature_projection.add_module("upsample", InterpolateLayer(scale_factor=upsample_ratio, mode='nearest'))
994
+ feature_projection.add_module("conv1", nn.Conv2d(level_n_features_output, level_n_features_output, kernel_size=3, stride=1, padding=1, groups=level_n_features_output))
995
+ feature_projection.add_module("norm",nn.BatchNorm2d(level_n_features_output))
996
+ # B, in_channels, H*upsample_ratio, W*upsample_ratio -> B, full_features_head_dim, H*upsample_ratio, W*upsample_ratio
997
+ feature_projection.add_module("conv2", nn.Conv2d(level_n_features_output, full_features_head_dim, kernel_size=1, stride=1, padding=0))
998
+ else:
999
+ feature_projection = nn.Sequential()
1000
+
1001
+ self.neck_features_proj.append(feature_projection)
1002
+
1003
+ if i>0 and downsample_enabled[i]:
1004
+ upsample_ratio *= 2
1005
+
1006
+ def forward(self, x, il_level=-1, full_features=None):
1007
+ if self.neck_start_stage > il_level:
1008
+ return full_features
1009
+
1010
+ if full_features is None:
1011
+ full_features = self.neck_features_proj[il_level - self.neck_start_stage](x)
1012
+ else:
1013
+ #upsample torch tensor x to match full_features size, and add to full_features
1014
+ feature_projection = self.neck_features_proj[il_level - self.neck_start_stage](x)
1015
+ if feature_projection.shape[2] != full_features.shape[2] or feature_projection.shape[3] != full_features.shape[3]:
1016
+ feature_projection = torch.nn.functional.pad(feature_projection, ( 0, -feature_projection.shape[3] + full_features.shape[3], 0, -feature_projection.shape[2] + full_features.shape[2]))
1017
+ full_features = full_features + feature_projection
1018
+ return full_features
1019
+
1020
+ class ERADIO(nn.Module):
1021
+ """
1022
+ Efficient RADIO
1023
+ """
1024
+
1025
+ def __init__(self,
1026
+ dim,
1027
+ in_dim,
1028
+ depths,
1029
+ window_size,
1030
+ mlp_ratio,
1031
+ num_heads,
1032
+ drop_path_rate=0.2,
1033
+ in_chans=3,
1034
+ num_classes=1000,
1035
+ qkv_bias=False,
1036
+ qk_scale=None,
1037
+ layer_scale=None,
1038
+ layer_scale_conv=None,
1039
+ layer_norm_last=False,
1040
+ sr_ratio = [1, 1, 1, 1],
1041
+ max_depth = -1,
1042
+ conv_base=False,
1043
+ use_swiglu=False,
1044
+ multi_query=False,
1045
+ norm_layer=nn.LayerNorm,
1046
+ drop_uniform=False,
1047
+ yolo_arch=False,
1048
+ shuffle_down=False,
1049
+ downsample_shuffle=False,
1050
+ return_full_features=False,
1051
+ full_features_head_dim=128,
1052
+ neck_start_stage=1,
1053
+ use_neck=False,
1054
+ use_shift=False,
1055
+ cpb_mlp_hidden=512,
1056
+ conv_groups_ratio=0,
1057
+ verbose: bool = False,
1058
+ **kwargs):
1059
+ """
1060
+ Args:
1061
+ dim: feature size dimension.
1062
+ depths: number of layers in each stage.
1063
+ window_size: window size in each stage.
1064
+ mlp_ratio: MLP ratio.
1065
+ num_heads: number of heads in each stage.
1066
+ drop_path_rate: drop path rate.
1067
+ in_chans: number of input channels.
1068
+ num_classes: number of classes.
1069
+ qkv_bias: bool argument for query, key, value learnable bias.
1070
+ qk_scale: bool argument to scaling query, key.
1071
+ drop_rate: dropout rate.
1072
+ attn_drop_rate: attention dropout rate.
1073
+ norm_layer: normalization layer.
1074
+ layer_scale: layer scaling coefficient.
1075
+ return_full_features: output dense features as well as logits
1076
+ full_features_head_dim: number of channels in the dense features head
1077
+ neck_start_stage: stage id at which to start the full-feature neck. The model has 4 stages, index starts at 0
1078
+ for 224 resolution, the output of the stage before downsample:
1079
+ stage 0: 56x56, stage 1: 28x28, stage 2: 14x14, stage 3: 7x7
1080
+ use_neck: use the neck even for the summarization embedding
1081
+ use_shift: SWIN like window shifting but without masking attention
1082
+ conv_groups_ratio: will be used for conv blocks where there is no multires attention,
1083
+ if 0 then normal conv,
1084
+ if 1 then channels are independent,
1085
+ if -1 then no conv at all
1086
+
1087
+ """
1088
+ super().__init__()
1089
+
1090
+ num_features = int(dim * 2 ** (len(depths) - 1))
1091
+ self.num_classes = num_classes
1092
+ self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim, shuffle_down=shuffle_down)
1093
+ # set return_full_features true if we want to return full features from all stages
1094
+ self.return_full_features = return_full_features
1095
+ self.use_neck = use_neck
1096
+
1097
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1098
+ if drop_uniform:
1099
+ dpr = [drop_path_rate for x in range(sum(depths))]
1100
+
1101
+ if not isinstance(max_depth, list): max_depth = [max_depth] * len(depths)
1102
+
1103
+ self.levels = nn.ModuleList()
1104
+ for i in range(len(depths)):
1105
+ conv = True if (i == 0 or i == 1) else False
1106
+
1107
+ level = ERADIOLayer(dim=int(dim * 2 ** i),
1108
+ depth=depths[i],
1109
+ num_heads=num_heads[i],
1110
+ window_size=window_size[i],
1111
+ mlp_ratio=mlp_ratio,
1112
+ qkv_bias=qkv_bias,
1113
+ qk_scale=qk_scale,
1114
+ conv=conv,
1115
+ drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
1116
+ downsample=(i < len(depths) - 1),
1117
+ layer_scale=layer_scale,
1118
+ layer_scale_conv=layer_scale_conv,
1119
+ sr_ratio=sr_ratio[i],
1120
+ use_swiglu=use_swiglu,
1121
+ multi_query=multi_query,
1122
+ norm_layer=norm_layer,
1123
+ yolo_arch=yolo_arch,
1124
+ downsample_shuffle=downsample_shuffle,
1125
+ conv_base=conv_base,
1126
+ cpb_mlp_hidden=cpb_mlp_hidden,
1127
+ use_shift=use_shift,
1128
+ conv_groups_ratio=conv_groups_ratio,
1129
+ verbose=verbose)
1130
+
1131
+ self.levels.append(level)
1132
+
1133
+ if self.return_full_features or self.use_neck:
1134
+ #num_heads
1135
+ downsample_enabled = [self.levels[i-1].downsample is not None for i in range(len(self.levels))]
1136
+ self.high_res_neck = HiResNeck(dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled)
1137
+
1138
+ self.switched_to_deploy = False
1139
+
1140
+ self.norm = LayerNorm2d(num_features) if layer_norm_last else nn.BatchNorm2d(num_features)
1141
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
1142
+ self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()
1143
+ self.apply(self._init_weights)
1144
+
1145
+ def _init_weights(self, m):
1146
+ if isinstance(m, nn.Linear):
1147
+ trunc_normal_(m.weight, std=.02)
1148
+ if isinstance(m, nn.Linear) and m.bias is not None:
1149
+ nn.init.constant_(m.bias, 0)
1150
+ elif isinstance(m, nn.LayerNorm):
1151
+ nn.init.constant_(m.bias, 0)
1152
+ nn.init.constant_(m.weight, 1.0)
1153
+ elif isinstance(m, LayerNorm2d):
1154
+ nn.init.constant_(m.bias, 0)
1155
+ nn.init.constant_(m.weight, 1.0)
1156
+ elif isinstance(m, nn.BatchNorm2d):
1157
+ nn.init.ones_(m.weight)
1158
+ nn.init.zeros_(m.bias)
1159
+
1160
+ @torch.jit.ignore
1161
+ def no_weight_decay_keywords(self):
1162
+ return {'rpb'}
1163
+
1164
+ def forward_features(self, x):
1165
+ _, _, H, W = x.shape
1166
+ if H % 32 != 0 or W % 32 != 0:
1167
+ raise ValueError(f"E-RADIO requires input dimensions to be divisible by 32 but got H x W: {H} x {W}")
1168
+ x = self.patch_embed(x)
1169
+ full_features = None
1170
+ for il, level in enumerate(self.levels):
1171
+ x, pre_downsample_x = level(x)
1172
+
1173
+ if self.return_full_features or self.use_neck:
1174
+ full_features = self.high_res_neck(pre_downsample_x, il, full_features)
1175
+
1176
+ # x = self.norm(full_features if (self.return_full_features or self.use_neck) else x)
1177
+ x = self.norm(x) # new version: normalize only the summary branch
1178
+
1179
+ if not self.return_full_features:
1180
+ return x, None
1181
+
1182
+ return x, full_features
1183
+
1184
+ def forward(self, x):
1185
+ x, full_features = self.forward_features(x)
1186
+
1187
+ x = self.avgpool(x)
1188
+ x = torch.flatten(x, 1)
1189
+
1190
+ x = self.head(x)
1191
+ if full_features is not None:
1192
+ return x, full_features
1193
+ return x
1194
+
1195
+ def switch_to_deploy(self):
1196
+ '''
1197
+ A method to perform model self-compression
1198
+ merges BN into conv layers
1199
+ converts MLP relative positional bias into precomputed buffers
1200
+ '''
1201
+ if not self.switched_to_deploy:
1202
+ for level in [self.patch_embed, self.levels, self.head]:
1203
+ for module in level.modules():
1204
+ if hasattr(module, 'switch_to_deploy'):
1205
+ module.switch_to_deploy()
1206
+ self.switched_to_deploy = True
1207
+
1208
+
1209
+ def change_window_size(self, new_window_size):
1210
+ """
1211
+ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
1212
+ especially in cases of uneven partitioning of the feature maps.
1213
+ E-RADIO allows for the adjustment of the window size after training,
1214
+ making it adaptable to different input image resolutions.
1215
+ The recommended values for window size based on input resolution are as follows:
1216
+
1217
+ Input Resolution | Window Size
1218
+ 224 | 7
1219
+ 256 | 8
1220
+ 386 | 12
1221
+ 512 | 16
1222
+ Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
1223
+ img_res/16/2
1224
+ for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.
1225
+ Manual way to change resolution -> model.change_window_size(resolution)
1226
+ """
1227
+ window_size = new_window_size
1228
+ print(f"Setting window size to {window_size}")
1229
+ for module in self.modules():
1230
+ if hasattr(module, "window_size"):
1231
+ # check if tuple or a number
1232
+ if isinstance(module.window_size, tuple):
1233
+ if module.window_size[0] != window_size:
1234
+ module.window_size = (window_size, window_size)
1235
+ elif isinstance(module.window_size, list):
1236
+ if module.window_size[0] != window_size:
1237
+ module.window_size = [window_size, window_size]
1238
+ else:
1239
+ module.window_size = window_size
1240
+
1241
+
1242
+ def set_optimal_window_size(self, image_dim, max_window_size = 16):
1243
+ """
1244
+ Using hand picked window size for various resolutions.
1245
+
1246
+ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
1247
+ especially in cases of uneven partitioning of the feature maps.
1248
+ E-RADIO allows for the adjustment of the window size after training,
1249
+ making it adaptable to different input image resolutions.
1250
+ The recommended values for window size based on input resolution are as follows:
1251
+
1252
+ Input Resolution | Window Size
1253
+ 224 | 7
1254
+ 256 | 8
1255
+ 386 | 12
1256
+ 512 | 16
1257
+ Ideally, the window size should be a factor of the input resolution. In the third stage, we divide the resolution by 16, so the window size should be
1258
+ img_res/16/2
1259
+ for the third stage and img_res/32 for the last stage. While this can be applied in a brute-force manner, a better way is to do model.change_window_size.
1260
+ Manual way to change the window size -> model.change_window_size(new_window_size)
1261
+
1262
+ """
1263
+ # import math
1264
+
1265
+ def divisorGenerator(n):
1266
+ large_divisors = []
1267
+ for i in range(1, int(math.sqrt(n) + 1)):
1268
+ if n % i == 0:
1269
+ yield i
1270
+ if i*i != n:
1271
+ large_divisors.append(n / i)
1272
+ for divisor in reversed(large_divisors):
1273
+ yield divisor
1274
+
1275
+ if isinstance(image_dim, list) or isinstance(image_dim, tuple):
1276
+ image_dim = min(image_dim)
1277
+
1278
+ # we do windowed attention in the 3rd stage for the first time, therefore //16,
1279
+ # we do subsampled attention with downsample by 2 so need to get //32 actually
1280
+ # ideally this should depend on the structure of the model (e.g., what if the subsampled attention is removed)
1281
+ all_divisors = np.array(list(divisorGenerator(image_dim//32)))
1282
+ new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
1283
+
1284
+ # for image_dim in [128, 224, 256, 384, 512, 768, 1024]:
1285
+ # all_divisors = np.array(list(divisorGenerator(image_dim//32)))
1286
+ # new_window_size = int(min(all_divisors[all_divisors <= max_window_size][-1], max_window_size))
1287
+ # print(f"Setting window size to {new_window_size} for image resolution {image_dim}")
1288
+
1289
+ self.change_window_size(new_window_size = new_window_size)
1290
+
1291
+
1292
+ @register_model
1293
+ def eradio_large_fullres_ws16(pretrained=False, **kwargs):
1294
+ model = ERADIO(
1295
+ depths=[3, 3, 5, 5],
1296
+ num_heads=[2, 4, 8, 16],
1297
+ window_size=[None, None, [16, 16], 16],
1298
+ dim=192,
1299
+ in_dim=64,
1300
+ mlp_ratio=4,
1301
+ drop_path_rate=0.0,
1302
+ sr_ratio=[1, 1, [2, 1], 1],
1303
+ use_swiglu=False,
1304
+ yolo_arch=True,
1305
+ shuffle_down=False,
1306
+ conv_base=True,
1307
+ use_neck=True,
1308
+ full_features_head_dim=1536,
1309
+ neck_start_stage=2,
1310
+ **kwargs,
1311
+ )
1312
+ if pretrained:
1313
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1314
+ return model
1315
+
1316
+
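A hedged usage sketch for the registered model above (no pretrained weights; the input size is only an example, and torch/math are assumed to be imported at module level as elsewhere in this file):

    model = eradio_large_fullres_ws16(pretrained=False).eval()
    model.set_optimal_window_size(512)        # 512 // 32 = 16 -> window size 16
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 512, 512))   # summary head output by default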
1317
+ @register_model
1318
+ def eradio_xxxtiny(pretrained=False, **kwargs): # ,
1319
+ model = ERADIO(
1320
+ depths=[1, 3, 4, 5],
1321
+ num_heads=[2, 4, 8, 16],
1322
+ window_size=[None, None, [16, 16], 16],
1323
+ dim=32,
1324
+ in_dim=32,
1325
+ mlp_ratio=4,
1326
+ drop_path_rate=0.0,
1327
+ sr_ratio=[1, 1, [2, 1], 1],
1328
+ use_swiglu=False,
1329
+ yolo_arch=True,
1330
+ shuffle_down=False,
1331
+ conv_base=True,
1332
+ use_neck=True,
1333
+ full_features_head_dim=256,
1334
+ neck_start_stage=2,
1335
+ **kwargs,
1336
+ )
1337
+ if pretrained:
1338
+ model.load_state_dict(torch.load(pretrained))
1339
+ return model
1340
+
1341
+ @register_model
1342
+ def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
1343
+ model = ERADIO(depths=[1, 3, 4, 5],
1344
+ num_heads=[2, 4, 8, 16],
1345
+ window_size=[None, None, [12, 12], 12],
1346
+ dim=32,
1347
+ in_dim=32,
1348
+ mlp_ratio=4,
1349
+ drop_path_rate=0.0,
1350
+ sr_ratio=[1, 1, [2, 1], 1],
1351
+ use_swiglu=False,
1352
+ downsample_shuffle=False,
1353
+ yolo_arch=True,
1354
+ shuffle_down=False,
1355
+ cpb_mlp_hidden=64,
1356
+ use_neck=True,
1357
+ full_features_head_dim=256,
1358
+ neck_start_stage=2,
1359
+ conv_groups_ratio = 1,
1360
+ **kwargs)
1361
+ if pretrained:
1362
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1363
+ return model
1364
+
1365
+
1366
+ @register_model
1367
+ def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
1368
+ model = ERADIO(depths=[1, 3, 4, 5],
1369
+ num_heads=[2, 4, 8, 16],
1370
+ window_size=[None, None, [16, 16], 16],
1371
+ dim=32,
1372
+ in_dim=32,
1373
+ mlp_ratio=4,
1374
+ drop_path_rate=0.0,
1375
+ sr_ratio=[1, 1, [2, 1], 1],
1376
+ use_swiglu=False,
1377
+ downsample_shuffle=False,
1378
+ yolo_arch=True,
1379
+ shuffle_down=False,
1380
+ cpb_mlp_hidden=64,
1381
+ use_neck=True,
1382
+ full_features_head_dim=256,
1383
+ neck_start_stage=1,
1384
+ conv_groups_ratio = 1,
1385
+ **kwargs)
1386
+ if pretrained:
1387
+ model.load_state_dict(torch.load(pretrained)["state_dict"])
1388
+ return model
1389
+
1390
+ @register_model
1391
+ def eradio(pretrained=False, **kwargs):
1392
+ return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs)
src/models/radiov3/extra_models.py ADDED
@@ -0,0 +1,206 @@
1
+ from distutils.version import LooseVersion
2
+ from types import MethodType
3
+ from typing import List, Optional, Tuple, Union
4
+ import warnings
5
+
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from timm.models.registry import register_model
11
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
+
13
+ from .forward_intermediates import forward_intermediates
14
+ from .input_conditioner import InputConditioner
15
+
16
+ _has_torch_sdpa = hasattr(F, 'scaled_dot_product_attention')
17
+
18
+
19
+ class PaliGemmaWrapper(nn.Module):
20
+ def __init__(self, vis_model: nn.Module, embed_dim: int):
21
+ super().__init__()
22
+
23
+ self.vis_model = vis_model
24
+ self.embed_dim = embed_dim
25
+
26
+ @property
27
+ def patch_size(self):
28
+ return self.vis_model.embeddings.patch_size
29
+
30
+ @property
31
+ def blocks(self):
32
+ return self.vis_model.encoder.layers
33
+
34
+ @property
35
+ def embed_dim(self):
36
+ return self.vis_model.embeddings.embed_dim
37
+
38
+ def forward(self, x: torch.Tensor):
39
+ outputs = self.vis_model(
40
+ x,
41
+ return_dict=False,
42
+ interpolate_pos_encoding=True,
43
+ )
44
+
45
+ features = outputs[0].to(torch.float32)
46
+
47
+ summary = features.mean(dim=1)
48
+
49
+ return summary, features
50
+
51
+ def forward_features(self, x: torch.Tensor):
52
+ return self(x)
53
+
54
+
55
+ def _get_paligemma_model(repo: str, embed_dim: int = None, dtype: torch.dtype = torch.bfloat16):
56
+ from transformers import PaliGemmaForConditionalGeneration, __version__ as tx_version
57
+
58
+ if LooseVersion(tx_version) > LooseVersion('4.44.2'):
59
+ warnings.warn(f'Your transformers version "{tx_version}" is newer than 4.44.2; PaliGemma may be broken with it.')
60
+
61
+ extra_args = dict()
62
+
63
+ if dtype is not None:
64
+ extra_args['torch_dtype'] = dtype
65
+ rev = str(dtype).split('.')[-1]
66
+ extra_args['revision'] = rev
67
+
68
+ model = PaliGemmaForConditionalGeneration.from_pretrained(repo, **extra_args)
69
+
70
+ vis_model = model.vision_tower.vision_model
71
+
72
+ vis_model = PaliGemmaWrapper(vis_model, embed_dim)
73
+
74
+ return vis_model
75
+
76
+ @register_model
77
+ def paligemma_896_student(**kwargs):
78
+ model = _get_paligemma_model('google/paligemma-3b-pt-896', embed_dim=1152, dtype=None)
79
+
80
+ return model
81
+
82
+
83
+ def dv2_sdpa(self, x: torch.Tensor) -> torch.Tensor:
84
+ B, N, C = x.shape
85
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
86
+
87
+ q, k, v = qkv[0], qkv[1], qkv[2]
88
+ x = F.scaled_dot_product_attention(
89
+ q, k, v,
90
+ is_causal=False,
91
+ dropout_p=self.attn_drop.p if self.training else 0.,
92
+ scale=self.scale,
93
+ )
94
+ x = x.transpose(1, 2).reshape(B, N, C)
95
+ x = self.proj(x)
96
+ x = self.proj_drop(x)
97
+ return x
98
+
99
+ def _load_dino_v2(dino_v2_model, cache_dir: Optional[str] = None, pretrained=True, **kwargs):
100
+ if cache_dir:
101
+ torch.hub.set_dir(cache_dir)
102
+ model: nn.Module = torch.hub.load(
103
+ 'facebookresearch/dinov2',
104
+ dino_v2_model,
105
+ pretrained=pretrained,
106
+ # **kwargs,
107
+ )
108
+
109
+ if _has_torch_sdpa:
110
+ for n, m in model.named_modules():
111
+ if n.endswith('.attn'):
112
+ m.forward = MethodType(dv2_sdpa, m)
113
+
114
+ return model
115
+
116
+ class DinoWrapper(nn.Module):
117
+ def __init__(self, dino_model: nn.Module):
118
+ super().__init__()
119
+
120
+ self.inner = dino_model
121
+ dino_model.blocks = nn.Sequential(*dino_model.blocks)
122
+
123
+ @property
124
+ def embed_dim(self):
125
+ return self.inner.embed_dim
126
+
127
+ @property
128
+ def patch_size(self):
129
+ return self.inner.patch_size
130
+
131
+ @property
132
+ def num_cls_tokens(self):
133
+ return getattr(self.inner, 'num_tokens', 1)
134
+
135
+ @property
136
+ def num_registers(self):
137
+ return getattr(self.inner, 'num_register_tokens', 0)
138
+
139
+ @property
140
+ def num_summary_tokens(self):
141
+ return self.num_cls_tokens + self.num_registers
142
+
143
+ @property
144
+ def blocks(self):
145
+ return self.inner.blocks
146
+
147
+ def forward(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
148
+ parts = self.inner.forward_features(*args, **kwargs)
149
+
150
+ cls_token = parts['x_norm_clstoken']
151
+ features = parts['x_norm_patchtokens']
152
+
153
+ return cls_token, features
154
+
155
+ def forward_features(self, x: torch.Tensor):
156
+ x = self.inner.prepare_tokens_with_masks(x)
157
+ x = self.inner.blocks(x)
158
+ x_norm = self.inner.norm(x)
159
+
160
+ return x_norm[:, 0], x_norm[:, self.num_summary_tokens:]
161
+
162
+ def patchify(self, x: torch.Tensor) -> torch.Tensor:
163
+ return self.inner.prepare_tokens_with_masks(x)
164
+
165
+ def forward_intermediates(self,
166
+ x: torch.Tensor,
167
+ norm: bool = False,
168
+ **kwargs,
169
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
170
+ return forward_intermediates(
171
+ self,
172
+ patch_extractor=self.inner.prepare_tokens_with_masks,
173
+ num_summary_tokens=self.num_summary_tokens,
174
+ num_cls_tokens=self.num_cls_tokens,
175
+ norm=self.inner.norm if norm else lambda y: y,
176
+ x=x,
177
+ **kwargs,
178
+ )
179
+
180
+
181
+ def _dino_student(arch: str, **kwargs):
182
+ from . import dinov2_arch
183
+
184
+ factory = getattr(dinov2_arch, arch)
185
+ model = factory()
186
+
187
+ model = DinoWrapper(model)
188
+
189
+ conditioner = InputConditioner(
190
+ input_scale=1.0,
191
+ norm_mean=IMAGENET_DEFAULT_MEAN,
192
+ norm_std=IMAGENET_DEFAULT_STD,
193
+ )
194
+
195
+ model.input_conditioner = conditioner
196
+
197
+ return model
198
+
199
+
200
+ @register_model
201
+ def dino_v2_l_student(**kwargs):
202
+ return _dino_student('dinov2_vitl14_reg', **kwargs)
203
+
204
+ @register_model
205
+ def dino_v2_g_student(**kwargs):
206
+ return _dino_student('dinov2_vitg14_reg', **kwargs)
src/models/radiov3/extra_timm_models.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import math
10
+ import warnings
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+ from timm.models import register_model
17
+ from timm.models.vision_transformer import (
18
+ VisionTransformer,
19
+ _create_vision_transformer as _timm_create_vision_transformer,
20
+ Mlp,
21
+ Block,
22
+ LayerScale as TIMMLayerScale,
23
+ )
24
+
25
+ # Import these to also register them
26
+ from . import dinov2_arch
27
+
28
+
29
+ @register_model
30
+ def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
31
+ """ ViT-Tiny (Vit-Ti/16)
32
+ """
33
+ model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
34
+ model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
35
+ return model
36
+
37
+
38
+ @register_model
39
+ def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
40
+ """ ViT-Small (ViT-S/16)
41
+ """
42
+ model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
43
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
44
+ return model
45
+
46
+
47
+ @register_model
48
+ def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
49
+ """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).
50
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
51
+ """
52
+ model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
53
+ model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
54
+ return model
55
+
56
+
57
+ @register_model
58
+ def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
59
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
60
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
61
+ """
62
+ model_args = dict(
63
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
64
+ reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
65
+ )
66
+ model = _create_vision_transformer(
67
+ 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
68
+ return model
69
+
70
+
71
+ @register_model
72
+ def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
73
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
74
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
75
+ """
76
+ name = 'vit_large_patch14_reg4_dinov2'
77
+ model_args = dict(
78
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
79
+ reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
80
+ )
81
+ model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs))
82
+
83
+ return model
84
+
85
+ @register_model
86
+ def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
87
+ """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
88
+ """
89
+ model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
90
+ if pretrained:
91
+ # There is no pretrained version of ViT-H/16, but we can adapt a ViT-H/14 for this purpose
92
+ model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))
93
+ else:
94
+ model = _create_vision_transformer('vit_huge_patch16_224', pretrained=False, **dict(model_args, **kwargs))
95
+ return model
96
+
97
+
98
+ @register_model
99
+ def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
100
+ """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
101
+ """
102
+ model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)
103
+
104
+ for m in model.modules():
105
+ if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):
106
+ m.norm = nn.LayerNorm(m.fc1.out_features)
107
+
108
+ return model
109
+
110
+
111
+ @register_model
112
+ def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
113
+ """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).
114
+ """
115
+ model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
116
+ model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
117
+ if scaled_ln:
118
+ _apply_scaled_ln(model)
119
+ return model
120
+
121
+
122
+ @register_model
123
+ def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
124
+ model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
125
+ model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
126
+ return model
127
+
128
+
129
+ def _create_vision_transformer(*args, **kwargs):
130
+ model = _timm_create_vision_transformer(*args, **kwargs)
131
+ _patch_layer_scale(model)
132
+ return model
133
+
134
+
135
+ def _patch_layer_scale(model: VisionTransformer):
136
+ def replace_ls(old_ls: TIMMLayerScale):
137
+ new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)
138
+ new_ls.load_state_dict(old_ls.state_dict())
139
+ return new_ls
140
+
141
+ # Monkey patch: Replace TIMM's LayerScale with our modified DINOv2 one, that uses a param name
142
+ # other than gamma, so that HFHub doesn't mess with it!
143
+ for mod in model.modules():
144
+ if isinstance(mod, Block):
145
+ if isinstance(mod.ls1, TIMMLayerScale):
146
+ mod.ls1 = replace_ls(mod.ls1)
147
+ if isinstance(mod.ls2, TIMMLayerScale):
148
+ mod.ls2 = replace_ls(mod.ls2)
149
+ pass
150
+
151
+
152
+ class ScaledLayerNorm(nn.LayerNorm):
153
+ '''
154
+ https://arxiv.org/pdf/2502.05795v1
155
+ '''
156
+ def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
157
+ super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
158
+ self.load_state_dict(ln_base.state_dict())
159
+ self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)
160
+
161
+ def forward(self, x):
162
+ y = super().forward(x)
163
+ y = y * self.ln_scale
164
+ return y
165
+
166
+
167
+ class DyT(nn.Module):
168
+ def __init__(self, C: int, init_alpha: float):
169
+ super().__init__()
170
+ self.alpha = nn.Parameter(torch.full((1,), init_alpha))
171
+ self.gamma = nn.Parameter(torch.ones(C))
172
+ self.beta = nn.Parameter(torch.zeros(C))
173
+
174
+ def forward(self, x: torch.Tensor):
175
+ x = torch.tanh(self.alpha * x)
176
+ return self.gamma * x + self.beta
177
+
178
+ @register_model
179
+ def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
180
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
181
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
182
+ """
183
+ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
184
+ model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
185
+
186
+ def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int):
187
+ return DyT(ln.normalized_shape[0], init_alpha=0.9)
188
+ _replace_ln(model, _replace_ln_with_dyt)
189
+
190
+ return model
191
+
192
+
193
+ def _apply_scaled_ln(model: VisionTransformer):
194
+ warnings.warn('Post-LayerNorm scaling activated!')
195
+
196
+ _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))
197
+
198
+ def _replace_ln(model: VisionTransformer, fn):
199
+ def _inner_replace_ln(block: Block, depth: int, key: str):
200
+ prev = getattr(block, key)
201
+ if isinstance(prev, nn.LayerNorm):
202
+ setattr(block, key, fn(prev, depth=depth))
203
+
204
+ for i, block in enumerate(model.blocks):
205
+ _inner_replace_ln(block, i + 1, 'norm1')
206
+ _inner_replace_ln(block, i + 1, 'norm2')
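The registrations above wrap timm's `VisionTransformer` factory and swap in custom normalization layers such as `DyT`. As a quick sanity check of the `DyT` module defined in this file, here is a minimal editorial sketch (not part of the commit); it assumes the repository root is on `PYTHONPATH` so the `src.models.radiov3` package is importable.

```python
# Editorial sketch: DyT is an elementwise, shape-preserving alternative to LayerNorm,
# computing gamma * tanh(alpha * x) + beta. Assumes the repo root is on PYTHONPATH.
import torch
from src.models.radiov3.extra_timm_models import DyT

x = torch.randn(2, 196, 1024)        # (batch, tokens, channels)
dyt = DyT(C=1024, init_alpha=0.9)    # same init_alpha used by vit_large_dyt_patch16_224
y = dyt(x)
assert y.shape == x.shape
```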
src/models/radiov3/feature_normalizer.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from collections import namedtuple
9
+ from typing import NamedTuple, Optional, Tuple
10
+ import torch
11
+ from torch import nn
12
+
13
+
14
+ def _run_kernel(x: torch.Tensor, mean: torch.Tensor, tx: torch.Tensor):
15
+ if x.ndim <= 3:
16
+ x = x - mean
17
+ x = x @ tx.T
18
+ elif x.ndim == 4:
19
+ x = x - mean.reshape(1, -1, 1, 1)
20
+ kernel = tx.reshape(*tx.shape, 1, 1)
21
+ x = torch.nn.functional.conv2d(x, weight=kernel, bias=None, stride=1, padding=0)
22
+ else:
23
+ raise ValueError(f'Unsupported input dimension: {x.ndim}, shape: {x.shape}')
24
+ return x
25
+
26
+
27
+ class FeatureNormalizer(nn.Module):
28
+ def __init__(self, embed_dim: int, dtype: torch.dtype = torch.float32):
29
+ super().__init__()
30
+
31
+ self.register_buffer('mean', torch.zeros(embed_dim, dtype=dtype))
32
+ self.register_buffer('tx', torch.eye(embed_dim, dtype=dtype))
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ x = _run_kernel(x, self.mean, self.tx)
36
+ return x
37
+
38
+
39
+ class InterFeatState(NamedTuple):
40
+ y: torch.Tensor
41
+ alpha: torch.Tensor
42
+
43
+
44
+ class IntermediateFeatureNormalizerBase(nn.Module):
45
+ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
46
+ raise NotImplementedError()
47
+
48
+
49
+ class IntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
50
+ def __init__(self, num_intermediates: int, embed_dim: int, rot_per_layer: bool = False, dtype: torch.dtype = torch.float32):
51
+ super().__init__()
52
+ self.register_buffer('alphas', torch.ones(num_intermediates, dtype=dtype))
53
+
54
+ rot = torch.eye(embed_dim, dtype=dtype)
55
+ if rot_per_layer:
56
+ rot = rot.unsqueeze(0).repeat(num_intermediates, 1, 1)
57
+
58
+ self.register_buffer('rotation', rot.contiguous())
59
+ self.register_buffer('means', torch.zeros(num_intermediates, embed_dim, dtype=dtype))
60
+
61
+ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
62
+ if rot_index is None:
63
+ rot_index = index
64
+
65
+ if skip:
66
+ assert x.ndim == 3, f'Cannot use the `skip` parameter when the `x` tensor isn\'t 3-dimensional.'
67
+ prefix, x = x[:, :skip], x[:, skip:]
68
+
69
+ rotation = self._get_rotation(rot_index)
70
+ y = _run_kernel(x, self.means[index], rotation)
71
+
72
+ alpha = self.alphas[index]
73
+ if skip:
74
+ alpha = torch.cat([
75
+ torch.ones(skip, dtype=alpha.dtype, device=alpha.device),
76
+ alpha[None].expand(y.shape[1]),
77
+ ]).reshape(1, -1, 1)
78
+ y = torch.cat([prefix, y], dim=1)
79
+ else:
80
+ if x.ndim == 3:
81
+ alpha = alpha.reshape(1, 1, 1).expand(1, y.shape[1], 1)
82
+ elif x.ndim == 4:
83
+ alpha = alpha.reshape(1, 1, 1, 1).expand(1, 1, *y.shape[2:])
84
+ else:
85
+ raise ValueError(f'Unsupported input dimension: {x.ndim}')
86
+
87
+ return InterFeatState(y, alpha)
88
+
89
+ def _get_rotation(self, rot_index: int) -> torch.Tensor:
90
+ if self.rotation.ndim == 2:
91
+ return self.rotation
92
+ return self.rotation[rot_index]
93
+
94
+
95
+ class NullIntermediateFeatureNormalizer(IntermediateFeatureNormalizerBase):
96
+ instances = dict()
97
+
98
+ def __init__(self, dtype: torch.dtype, device: torch.device):
99
+ super().__init__()
100
+ self.register_buffer('alpha', torch.tensor(1, dtype=dtype, device=device))
101
+
102
+ @staticmethod
103
+ def get_instance(dtype: torch.dtype, device: torch.device):
104
+ instance = NullIntermediateFeatureNormalizer.instances.get((dtype, device), None)
105
+ if instance is None:
106
+ instance = NullIntermediateFeatureNormalizer(dtype, device)
107
+ NullIntermediateFeatureNormalizer.instances[(dtype, device)] = instance
108
+ return instance
109
+
110
+ def forward(self, x: torch.Tensor, index: int, rot_index: int = None, skip: Optional[int] = None) -> InterFeatState:
111
+ return InterFeatState(x, self.alpha)
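A brief usage sketch for `FeatureNormalizer` (editorial addition, not part of the commit): `_run_kernel` handles both token-major `(B, L, C)` and spatial `(B, C, H, W)` inputs, and with the default zero mean and identity `tx` the module is a no-op. The import path follows the file location in this diff and assumes the repo packages are importable.

```python
# Editorial sketch: FeatureNormalizer applies (x - mean) @ tx.T to (B, L, C) inputs
# and an equivalent 1x1 convolution to (B, C, H, W) inputs. Defaults are identity.
import torch
from src.models.radiov3.feature_normalizer import FeatureNormalizer

norm = FeatureNormalizer(embed_dim=64)
tokens = torch.randn(2, 49, 64)       # NLC path
spatial = torch.randn(2, 64, 7, 7)    # NCHW path
assert torch.allclose(norm(tokens), tokens)
assert torch.allclose(norm(spatial), spatial)
```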
src/models/radiov3/forward_intermediates.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable
10
+ from types import MethodType
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
16
+
17
+
18
+ def _take_indices(
19
+ num_blocks: int,
20
+ n: Optional[Union[int, List[int], Tuple[int]]],
21
+ ) -> Tuple[Set[int], int]:
22
+ if isinstance(n, int):
23
+ assert n >= 0
24
+ take_indices = {x for x in range(num_blocks - n, num_blocks)}
25
+ else:
26
+ take_indices = {num_blocks + idx if idx < 0 else idx for idx in n}
27
+ return take_indices, max(take_indices)
28
+
29
+
30
+ def forward_intermediates(
31
+ model: nn.Module,
32
+ patch_extractor: Callable[[torch.Tensor], torch.Tensor],
33
+ norm: nn.Module,
34
+ num_summary_tokens: int,
35
+ num_cls_tokens: int,
36
+ x: torch.Tensor,
37
+ indices: Optional[Union[int, List[int], Tuple[int]]] = None,
38
+ return_prefix_tokens: bool = False,
39
+ stop_early: bool = False,
40
+ output_fmt: str = 'NCHW',
41
+ intermediates_only: bool = False,
42
+ aggregation: Optional[str] = "sparse",
43
+ inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,
44
+ norm_alpha_scheme = "post-alpha",
45
+ block_kwargs: Dict = None,
46
+ ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
47
+ """ Forward features that returns intermediates.
48
+
49
+ The Dense layer aggregation method is inspired from the paper: "Dense Connector for MLLMs"
50
+ by Yao, Huanjin et al. (2024). arXiv preprint arXiv:2405.13800}
51
+
52
+ Args:
53
+ x: Input image tensor
54
+ indices: Take last n blocks if int, select matching indices if sequence
55
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
56
+ norm: Apply norm layer to all intermediates
57
+ stop_early: Stop iterating over blocks when last desired intermediate hit
58
+ output_fmt: Shape of intermediate feature outputs
59
+ intermediates_only: Only return intermediate features
60
+ aggregation: intermediate layer aggregation method (sparse or dense)
61
+ norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha")
62
+ Returns:
63
+ """
64
+ assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
65
+ assert aggregation in ('sparse', 'dense'), 'Aggregation must be one of sparse or dense.'
66
+ reshape = output_fmt == 'NCHW'
67
+ intermediates = []
68
+
69
+ block_kwargs = block_kwargs or dict()
70
+
71
+ blocks = model.blocks
72
+
73
+ take_indices, max_index = _take_indices(len(blocks), indices)
74
+ take_indices = sorted(take_indices)
75
+ # forward pass
76
+ B, _, height, width = x.shape
77
+
78
+ x = patch_extractor(x)
79
+
80
+ if stop_early:
81
+ blocks = blocks[:max_index + 1]
82
+
83
+ if inter_feature_normalizer is None or norm_alpha_scheme == 'none':
84
+ inter_feature_normalizer = NullIntermediateFeatureNormalizer.get_instance(x.dtype, x.device)
85
+
86
+ assert norm_alpha_scheme in ('none', 'pre-alpha', 'post-alpha'), f'Unsupported alpha scheme: {norm_alpha_scheme}'
87
+ post_alpha_scheme = norm_alpha_scheme == 'post-alpha'
88
+
89
+ accumulator = 0
90
+ alpha_sum = 0
91
+ num_accumulated = 0
92
+
93
+ take_off = 0
94
+
95
+ for i, blk in enumerate(blocks):
96
+ x = blk(x, **block_kwargs)
97
+ if aggregation == "dense":
98
+ # Arbitrarily use the rotation matrix from the final layer in the dense group
99
+ y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)
100
+ if post_alpha_scheme:
101
+ accumulator = accumulator + y
102
+ alpha_sum = alpha_sum + alpha
103
+ else:
104
+ accumulator = accumulator + (alpha * y)
105
+ alpha_sum += 1
106
+ num_accumulated += 1
107
+ if i == take_indices[take_off]:
108
+ if aggregation == "dense":
109
+ alpha = alpha_sum / num_accumulated
110
+ x_ = alpha * accumulator / num_accumulated
111
+ num_accumulated = 0
112
+ accumulator = 0
113
+ alpha_sum = 0
114
+ else:
115
+ y, alpha = inter_feature_normalizer(x, i, skip=num_summary_tokens)
116
+ x_ = alpha * y
117
+ # normalize intermediates with final norm layer if enabled
118
+ intermediates.append(norm(x_))
119
+ take_off = min(take_off + 1, len(take_indices) - 1)
120
+
121
+ # process intermediates
122
+
123
+ # split prefix (e.g. class, distill) and spatial feature tokens
124
+ prefix_tokens = [y[:, :num_cls_tokens] for y in intermediates]
125
+ intermediates = [y[:, num_summary_tokens:] for y in intermediates]
126
+
127
+ if reshape:
128
+ # reshape to BCHW output format
129
+ H = height // model.patch_size
130
+ W = width // model.patch_size
131
+ intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
132
+ if not torch.jit.is_scripting() and return_prefix_tokens:
133
+ # return_prefix not support in torchscript due to poor type handling
134
+ intermediates = list(zip(prefix_tokens, intermediates))
135
+ if intermediates_only:
136
+ return intermediates
137
+ x = norm(x)
138
+ return x, intermediates
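To make the `indices` handling concrete, here is a small editorial sketch of `_take_indices` (not part of the commit; import path assumed from the file location above): an integer selects the last `n` blocks, while a sequence may mix positive and negative indices.

```python
# Editorial sketch: how forward_intermediates resolves `indices` via _take_indices.
from src.models.radiov3.forward_intermediates import _take_indices

take, max_idx = _take_indices(12, 4)
assert take == {8, 9, 10, 11} and max_idx == 11   # int n -> last n of 12 blocks

take, max_idx = _take_indices(12, [3, -1])
assert take == {3, 11} and max_idx == 11          # negative indices wrap around
```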
src/models/radiov3/hf_model.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from collections import namedtuple
15
+ from typing import Callable, Dict, Optional, List, Union
16
+
17
+ from timm.models import VisionTransformer
18
+ import torch
19
+ from torch import nn
20
+ from transformers import PretrainedConfig, PreTrainedModel
21
+
22
+
23
+ from .common import RESOURCE_MAP, DEFAULT_VERSION
24
+
25
+ # Import all required modules.
26
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
27
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
28
+ from .adaptor_mlp import create_mlp_from_config
29
+ from .adaptor_registry import adaptor_registry
30
+ from .cls_token import ClsToken
31
+ from .dinov2_arch import dinov2_vitg14_reg
32
+ from .enable_cpe_support import enable_cpe
33
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
34
+ from .eradio_model import eradio
35
+ from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
36
+ from .forward_intermediates import forward_intermediates
37
+ from .radio_model import create_model_from_args
38
+ from .radio_model import RADIOModel as RADIOModelBase, Resolution
39
+ from .input_conditioner import get_default_conditioner, InputConditioner
40
+ from .open_clip_adaptor import OpenCLIP_RADIO
41
+ from .vit_patch_generator import ViTPatchGenerator
42
+ from .vitdet import apply_vitdet_arch, VitDetArgs
43
+
44
+ # Register extra models
45
+ from .extra_timm_models import *
46
+ from .extra_models import *
47
+
48
+
49
+ class RADIOConfig(PretrainedConfig):
50
+ """Pretrained Hugging Face configuration for RADIO models."""
51
+
52
+ def __init__(
53
+ self,
54
+ args: Optional[dict] = None,
55
+ version: Optional[str] = DEFAULT_VERSION,
56
+ patch_size: Optional[int] = None,
57
+ max_resolution: Optional[int] = None,
58
+ preferred_resolution: Optional[Resolution] = None,
59
+ adaptor_names: Union[str, List[str]] = None,
60
+ adaptor_configs: Dict[str, Dict[str, int]] = None,
61
+ vitdet_window_size: Optional[int] = None,
62
+ feature_normalizer_config: Optional[dict] = None,
63
+ inter_feature_normalizer_config: Optional[dict] = None,
64
+ **kwargs,
65
+ ):
66
+ self.args = args
67
+ for field in ["dtype", "amp_dtype"]:
68
+ if self.args is not None and field in self.args:
69
+ # Convert to a string in order to make it serializable.
70
+ # For example, torch.float32 is stored as "float32",
71
+ # and a value that is already the string "bfloat16" stays "bfloat16".
72
+ self.args[field] = str(args[field]).split(".")[-1]
73
+ self.version = version
74
+ resource = RESOURCE_MAP[version]
75
+ self.patch_size = patch_size or resource.patch_size
76
+ self.max_resolution = max_resolution or resource.max_resolution
77
+ self.preferred_resolution = (
78
+ preferred_resolution or resource.preferred_resolution
79
+ )
80
+ self.adaptor_names = adaptor_names
81
+ self.adaptor_configs = adaptor_configs
82
+ self.vitdet_window_size = vitdet_window_size
83
+ self.feature_normalizer_config = feature_normalizer_config
84
+ self.inter_feature_normalizer_config = inter_feature_normalizer_config
85
+ super().__init__(**kwargs)
86
+
87
+
88
+
89
+ class RADIOModel(PreTrainedModel):
90
+ """Pretrained Hugging Face model for RADIO.
91
+
92
+ This class inherits from PreTrainedModel, which provides
93
+ HuggingFace's functionality for loading and saving models.
94
+ """
95
+
96
+ config_class = RADIOConfig
97
+
98
+ def __init__(self, config: RADIOConfig):
99
+ super().__init__(config)
100
+
101
+ RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
102
+ args = RADIOArgs(**config.args)
103
+ self.config = config
104
+
105
+ model = create_model_from_args(args)
106
+ input_conditioner: InputConditioner = get_default_conditioner()
107
+
108
+ dtype = getattr(args, "dtype", torch.float32)
109
+ if isinstance(dtype, str):
110
+ # Convert the dtype's string representation back to a dtype.
111
+ dtype = getattr(torch, dtype)
112
+ model.to(dtype=dtype)
113
+ input_conditioner.dtype = dtype
114
+
115
+ summary_idxs = torch.tensor(
116
+ [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)],
117
+ dtype=torch.int64,
118
+ )
119
+
120
+ adaptor_configs = config.adaptor_configs
121
+ adaptor_names = config.adaptor_names or []
122
+
123
+ adaptors = dict()
124
+ for adaptor_name in adaptor_names:
125
+ mlp_config = adaptor_configs[adaptor_name]
126
+ adaptor = GenericAdaptor(args, None, None, mlp_config)
127
+ adaptor.head_idx = mlp_config["head_idx"]
128
+ adaptors[adaptor_name] = adaptor
129
+
130
+ feature_normalizer = None
131
+ if config.feature_normalizer_config is not None:
132
+ # Actual normalization values will be restored when loading checkpoint weights.
133
+ feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"])
134
+
135
+ inter_feature_normalizer = None
136
+ if config.inter_feature_normalizer_config is not None:
137
+ inter_feature_normalizer = IntermediateFeatureNormalizer(
138
+ config.inter_feature_normalizer_config["num_intermediates"],
139
+ config.inter_feature_normalizer_config["embed_dim"],
140
+ rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"],
141
+ dtype=dtype)
142
+
143
+ self.radio_model = RADIOModelBase(
144
+ model,
145
+ input_conditioner,
146
+ summary_idxs=summary_idxs,
147
+ patch_size=config.patch_size,
148
+ max_resolution=config.max_resolution,
149
+ window_size=config.vitdet_window_size,
150
+ preferred_resolution=config.preferred_resolution,
151
+ adaptors=adaptors,
152
+ feature_normalizer=feature_normalizer,
153
+ inter_feature_normalizer=inter_feature_normalizer,
154
+ )
155
+
156
+ @property
157
+ def adaptors(self) -> nn.ModuleDict:
158
+ return self.radio_model.adaptors
159
+
160
+ @property
161
+ def model(self) -> VisionTransformer:
162
+ return self.radio_model.model
163
+
164
+ @property
165
+ def input_conditioner(self) -> InputConditioner:
166
+ return self.radio_model.input_conditioner
167
+
168
+ @property
169
+ def num_summary_tokens(self) -> int:
170
+ return self.radio_model.num_summary_tokens
171
+
172
+ @property
173
+ def patch_size(self) -> int:
174
+ return self.radio_model.patch_size
175
+
176
+ @property
177
+ def max_resolution(self) -> int:
178
+ return self.radio_model.max_resolution
179
+
180
+ @property
181
+ def preferred_resolution(self) -> Resolution:
182
+ return self.radio_model.preferred_resolution
183
+
184
+ @property
185
+ def window_size(self) -> int:
186
+ return self.radio_model.window_size
187
+
188
+ @property
189
+ def min_resolution_step(self) -> int:
190
+ return self.radio_model.min_resolution_step
191
+
192
+ def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
193
+ return self.radio_model.make_preprocessor_external()
194
+
195
+ def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
196
+ return self.radio_model.get_nearest_supported_resolution(height, width)
197
+
198
+ def switch_to_deploy(self):
199
+ return self.radio_model.switch_to_deploy()
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ return self.radio_model.forward(x)
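Since `RADIOModel` subclasses `PreTrainedModel`, a checkpoint that ships this code can typically be loaded through the standard `transformers` remote-code path. The sketch below is editorial (not part of the commit); `<radio-checkpoint>` is a placeholder, not a real repo id, and the tuple unpacking assumes no adaptors are configured so `forward` returns a `RadioOutput` pair.

```python
# Editorial sketch: loading a RADIO checkpoint that exports this hf_model.py.
# "<radio-checkpoint>" is a placeholder id/path, not a real repository.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<radio-checkpoint>", trust_remote_code=True)
model.eval()

x = torch.rand(1, 3, *model.preferred_resolution)  # inputs are expected in [0, 1]
with torch.no_grad():
    summary, features = model(x)                   # RadioOutput(summary, features) when no adaptors are set
```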
src/models/radiov3/input_conditioner.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from typing import Union, Tuple
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+
15
+ norm_t = Union[Tuple[float, float, float], torch.Tensor]
16
+
17
+ class InputConditioner(nn.Module):
18
+ def __init__(self,
19
+ input_scale: float,
20
+ norm_mean: norm_t,
21
+ norm_std: norm_t,
22
+ dtype: torch.dtype = None,
23
+ ):
24
+ super().__init__()
25
+
26
+ self.dtype = dtype
27
+
28
+ self.register_buffer("norm_mean", _to_tensor(norm_mean) / input_scale)
29
+ self.register_buffer("norm_std", _to_tensor(norm_std) / input_scale)
30
+
31
+ def forward(self, x: torch.Tensor):
32
+ y = (x - self.norm_mean) / self.norm_std
33
+ # if self.dtype is not None:
34
+ # y = y.to(self.dtype)
35
+ return y
36
+
37
+
38
+ def get_default_conditioner():
39
+ from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
40
+
41
+ return InputConditioner(
42
+ input_scale=1.0,
43
+ norm_mean=OPENAI_CLIP_MEAN,
44
+ norm_std=OPENAI_CLIP_STD,
45
+ )
46
+
47
+
48
+ def _to_tensor(v: norm_t):
49
+ return torch.as_tensor(v, dtype=torch.float32).view(-1, 1, 1)
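The default conditioner reduces to per-channel CLIP normalization of `[0, 1]` images. The following editorial sketch (not part of the commit) verifies that arithmetic directly; it only assumes the import path shown in this diff.

```python
# Editorial sketch: InputConditioner with input_scale=1.0 is plain (x - mean) / std.
import torch
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from src.models.radiov3.input_conditioner import InputConditioner

cond = InputConditioner(input_scale=1.0, norm_mean=OPENAI_CLIP_MEAN, norm_std=OPENAI_CLIP_STD)
x = torch.rand(1, 3, 224, 224)

mean = torch.as_tensor(OPENAI_CLIP_MEAN).view(1, 3, 1, 1)
std = torch.as_tensor(OPENAI_CLIP_STD).view(1, 3, 1, 1)
assert torch.allclose(cond(x), (x - mean) / std)
```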
src/models/radiov3/open_clip_adaptor.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from argparse import Namespace
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+ from .adaptor_registry import adaptor_registry, dict_t, state_t
15
+
16
+ from .adaptor_generic import GenericAdaptor
17
+
18
+
19
+ class OpenCLIP_RADIO(GenericAdaptor):
20
+ def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
21
+ super().__init__(main_config, adaptor_config, state)
22
+
23
+ import open_clip
24
+
25
+ self.oc_model = open_clip.create_model_from_pretrained(
26
+ model_name=adaptor_config['model'],
27
+ pretrained=adaptor_config['pretrained'],
28
+ return_transform=False,
29
+ )
30
+ # Unload these parameters
31
+ self.oc_model.visual = None
32
+
33
+ self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])
34
+
35
+ def encode_text(self, text, normalize: bool = False):
36
+ return self.oc_model.encode_text(text, normalize=normalize)
37
+
38
+
39
+ @adaptor_registry.register_adaptor("open_clip")
40
+ def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
41
+ return OpenCLIP_RADIO(main_config, adaptor_config, state)
src/models/radiov3/radio_model.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+ from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import nn
12
+
13
+ from timm.models import create_model, VisionTransformer
14
+
15
+ from .enable_cpe_support import enable_cpe
16
+ from .input_conditioner import InputConditioner
17
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
18
+ from . import eradio_model
19
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
20
+ from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
21
+ from . import dual_hybrid_vit
22
+
23
+
24
+ class Resolution(NamedTuple):
25
+ height: int
26
+ width: int
27
+
28
+
29
+ class RADIOModel(nn.Module):
30
+ def __init__(
31
+ self,
32
+ model: nn.Module,
33
+ input_conditioner: InputConditioner,
34
+ patch_size: int,
35
+ max_resolution: int,
36
+ preferred_resolution: Resolution,
37
+ summary_idxs: Optional[torch.Tensor] = None,
38
+ window_size: int = None,
39
+ adaptors: Dict[str, AdaptorBase] = None,
40
+ feature_normalizer: Optional[FeatureNormalizer] = None,
41
+ inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None,
42
+ ):
43
+ super().__init__()
44
+
45
+ self.model = model
46
+ self.input_conditioner = input_conditioner
47
+ if summary_idxs is not None:
48
+ self.register_buffer('summary_idxs', summary_idxs)
49
+ else:
50
+ self.summary_idxs = None
51
+
52
+ self._preferred_resolution = preferred_resolution
53
+ self._patch_size = patch_size
54
+ self._max_resolution = max_resolution
55
+ self._window_size = window_size
56
+
57
+ adaptors = adaptors or dict()
58
+ self.adaptors = nn.ModuleDict(adaptors)
59
+
60
+ if feature_normalizer is None:
61
+ feature_normalizer = nn.Identity()
62
+ self.feature_normalizer = feature_normalizer
63
+ self.inter_feature_normalizer = inter_feature_normalizer
64
+
65
+ @property
66
+ def num_summary_tokens(self) -> int:
67
+ if hasattr(self.model, 'num_summary_tokens'):
68
+ return self.model.num_summary_tokens
69
+
70
+ patch_gen = getattr(self.model, "patch_generator", None)
71
+ if patch_gen is not None:
72
+ return patch_gen.num_skip
73
+ elif getattr(self.model, 'global_pool', None) == 'avg':
74
+ return 0
75
+ return 1
76
+
77
+ @property
78
+ def num_cls_tokens(self) -> int:
79
+ if hasattr(self.model, 'num_cls_tokens'):
80
+ return self.model.num_cls_tokens
81
+
82
+ patch_gen = getattr(self.model, 'patch_generator', None)
83
+ if patch_gen is not None:
84
+ return patch_gen.num_cls_tokens
85
+ elif getattr(self.model, 'global_pool', None) == 'avg':
86
+ return 0
87
+ return 1
88
+
89
+ @property
90
+ def patch_size(self) -> int:
91
+ if self._patch_size is not None:
92
+ return self._patch_size
93
+ if hasattr(self.model, "patch_size"):
94
+ return self.model.patch_size
95
+ patch_gen = getattr(self.model, "patch_generator", None)
96
+ if patch_gen is not None:
97
+ return patch_gen.patch_size
98
+ return None
99
+
100
+ @property
101
+ def max_resolution(self) -> int:
102
+ return self._max_resolution
103
+
104
+ @property
105
+ def preferred_resolution(self) -> Resolution:
106
+ return self._preferred_resolution
107
+
108
+ @property
109
+ def window_size(self) -> int:
110
+ return self._window_size
111
+
112
+ @property
113
+ def min_resolution_step(self) -> int:
114
+ res = self.patch_size
115
+ if self.window_size is not None:
116
+ res *= self.window_size
117
+ return res
118
+
119
+ @property
120
+ def blocks(self) -> Iterable[nn.Module]:
121
+ blocks = getattr(self.model, 'blocks', None)
122
+ if blocks is not None:
123
+ return blocks
124
+ return None
125
+
126
+ @property
127
+ def embed_dim(self) -> int:
128
+ return self.model.embed_dim
129
+
130
+ def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
131
+ ret = self.input_conditioner
132
+ self.input_conditioner = nn.Identity()
133
+ return ret
134
+
135
+ def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
136
+ height = int(round(height / self.min_resolution_step) * self.min_resolution_step)
137
+ width = int(round(width / self.min_resolution_step) * self.min_resolution_step)
138
+
139
+ height = max(height, self.min_resolution_step)
140
+ width = max(width, self.min_resolution_step)
141
+
142
+ return Resolution(height=height, width=width)
143
+
144
+ def switch_to_deploy(self):
145
+ fn = getattr(self.model, 'switch_to_deploy', None)
146
+ if fn is not None:
147
+ fn()
148
+
149
+ def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
150
+ '''
151
+ Forward process for model.
152
+ Args:
153
+ x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,
154
+ otherwise `x` is expected to be mean centered with unit standard deviation.
155
+ feature_format: ['NLC', 'NCHW'] - The output format for the features.
156
+ '''
157
+ res_step = self.min_resolution_step
158
+ if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
159
+ raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
160
+ '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
161
+ f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
162
+
163
+ # import pdb; pdb.set_trace()
164
+ x = self.input_conditioner(x)
165
+ y = self.model.forward_features(x)
166
+ ret = self._extract_final(x, y, feature_fmt=feature_fmt)
167
+ return ret
168
+
169
+ def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'):
170
+ if isinstance(self.model, VisionTransformer):
171
+ patch_gen = getattr(self.model, "patch_generator", None)
172
+ if patch_gen is not None:
173
+ all_summary = y[:, : patch_gen.num_cls_tokens]
174
+ if self.summary_idxs is not None:
175
+ bb_summary = all_summary[:, self.summary_idxs]
176
+ else:
177
+ bb_summary = all_summary
178
+ all_feat = y[:, patch_gen.num_skip :]
179
+ elif self.model.global_pool == "avg":
180
+ all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
181
+ bb_summary = all_summary
182
+ all_feat = y
183
+ else:
184
+ all_summary = y[:, 0]
185
+ bb_summary = all_summary
186
+ all_feat = y[:, 1:]
187
+ elif isinstance(self.model, eradio_model.ERADIO):
188
+ _, f = y
189
+ all_feat = f.flatten(2).transpose(1, 2)
190
+ all_summary = all_feat.mean(dim=1)
191
+ bb_summary = all_summary
192
+ elif isinstance(y, (list, tuple)):
193
+ all_summary, all_feat = y
194
+ bb_summary = all_summary
195
+ else:
196
+ all_summary = y[:, :self.num_cls_tokens]
197
+ if self.summary_idxs is not None and all_summary.shape[1] > 1:
198
+ if all_summary.shape[1] == 1:
199
+ # Create dummy duplicates
200
+ all_summary = all_summary.expand(-1, 128, -1)
201
+ bb_summary = all_summary[:, self.summary_idxs]
202
+ else:
203
+ bb_summary = all_summary
204
+ all_feat = y[:, self.num_summary_tokens:]
205
+
206
+ all_feat = self.feature_normalizer(all_feat)
207
+
208
+ if feature_fmt == 'NCHW':
209
+ fmt_feat = (all_feat.reshape(all_feat.shape[0], x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size, all_feat.shape[2])
210
+ .permute(0, 3, 1, 2)
211
+ )
212
+ elif feature_fmt == 'NLC':
213
+ fmt_feat = all_feat
214
+ else:
215
+ raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]')
216
+
217
+ ret = RadioOutput(bb_summary.flatten(1), fmt_feat)
218
+
219
+ if self.adaptors:
220
+ ret = dict(backbone=ret)
221
+ for name, adaptor in self.adaptors.items():
222
+ if all_summary.ndim == 3:
223
+ if all_summary.shape[1] == 1:
224
+ summary = all_summary[:, 0]
225
+ else:
226
+ summary = all_summary[:, adaptor.head_idx]
227
+ else:
228
+ summary = all_summary
229
+ ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
230
+ v = adaptor(ada_input).to(torch.float32)
231
+ ret[name] = v
232
+
233
+ return ret
234
+
235
+ def forward_intermediates(
236
+ self,
237
+ x: torch.Tensor,
238
+ indices: Optional[Union[int, List[int], Tuple[int]]] = None,
239
+ return_prefix_tokens: bool = False,
240
+ norm: bool = False,
241
+ stop_early: bool = False,
242
+ output_fmt: str = 'NCHW',
243
+ intermediates_only: bool = False,
244
+ aggregation: Optional[str] = "sparse",
245
+ norm_alpha_scheme: Optional[str] = "post-alpha",
246
+ ) -> List[RadioOutput]:
247
+ """ Forward features that returns intermediates.
248
+ Args:
249
+ x: Input image tensor
250
+ indices: Take last n blocks if int, select matching indices if sequence
251
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
252
+ norm: Apply norm layer to all intermediates
253
+ stop_early: Stop iterating over blocks when last desired intermediate hit
254
+ output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC
255
+ intermediates_only: Only return intermediate features
256
+ aggregation: intermediate layer aggregation method (sparse or dense).
257
+ Dense accumulation is done by averaging the features in each group.
258
+ norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none")
259
+ Only affects dense aggregation
260
+ Returns:
261
+ List of RadioOutput objects.
262
+ """
263
+ x = self.input_conditioner(x)
264
+ intermediates = self.model.forward_intermediates(
265
+ x,
266
+ indices=indices,
267
+ return_prefix_tokens=return_prefix_tokens,
268
+ norm=norm,
269
+ stop_early=stop_early,
270
+ output_fmt=output_fmt,
271
+ intermediates_only=intermediates_only,
272
+ aggregation=aggregation,
273
+ inter_feature_normalizer=self.inter_feature_normalizer,
274
+ norm_alpha_scheme=norm_alpha_scheme,
275
+ )
276
+
277
+ if not intermediates_only:
278
+ final, intermediates = intermediates
279
+
280
+ def prepare_summary(summ: Optional[torch.Tensor]):
281
+ if summ is None:
282
+ return summ
283
+ if self.summary_idxs is not None and summ.shape[1] > 1:
284
+ summ = summ[:, self.summary_idxs]
285
+ return summ.flatten(1)
286
+
287
+ if return_prefix_tokens:
288
+ radio_outputs = [
289
+ RadioOutput(prepare_summary(summary), features)
290
+ for summary, features in intermediates
291
+ ]
292
+ else:
293
+ radio_outputs = intermediates
294
+
295
+ if intermediates_only:
296
+ return radio_outputs
297
+ else:
298
+ final = self._extract_final(x, final, feature_fmt=output_fmt)
299
+ return final, radio_outputs
300
+
301
+
302
+ def create_model_from_args(args) -> nn.Module:
303
+ in_chans = 3
304
+ if args.in_chans is not None:
305
+ in_chans = args.in_chans
306
+ elif args.input_size is not None:
307
+ in_chans = args.input_size[0]
308
+
309
+ # Skip weight initialization unless it's explicitly requested.
310
+ weight_init = args.model_kwargs.pop("weight_init", "skip")
311
+
312
+ model = create_model(
313
+ args.model,
314
+ pretrained=args.pretrained,
315
+ in_chans=in_chans,
316
+ num_classes=args.num_classes,
317
+ drop_rate=args.drop,
318
+ drop_path_rate=args.drop_path,
319
+ drop_block_rate=args.drop_block,
320
+ global_pool=args.gp,
321
+ bn_momentum=args.bn_momentum,
322
+ bn_eps=args.bn_eps,
323
+ scriptable=args.torchscript,
324
+ checkpoint_path=args.initial_checkpoint,
325
+ weight_init=weight_init,
326
+ **args.model_kwargs,
327
+ )
328
+
329
+ if hasattr(model, 'norm') and not getattr(args, 'model_norm', False):
330
+ model.norm = nn.Identity()
331
+
332
+ model.head = nn.Identity()
333
+
334
+ if args.cpe_max_size is not None:
335
+ uq_teachers = set(t['name'] for t in args.teachers)
336
+ enable_cpe(
337
+ model,
338
+ args.cpe_max_size,
339
+ num_cls_tokens=len(uq_teachers) if args.cls_token_per_teacher else 1,
340
+ register_multiple=getattr(args, 'register_multiple', None),
341
+ num_registers=getattr(args, 'cpe_num_registers', None),
342
+ )
343
+
344
+ return model
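The resolution snapping done by `get_nearest_supported_resolution` is simple rounding to multiples of `min_resolution_step` (`patch_size`, times `window_size` when windowing is enabled), clamped to at least one step. The sketch below is an editorial, standalone restatement of that arithmetic, not a call into the class.

```python
# Editorial sketch: standalone restatement of RADIOModel.get_nearest_supported_resolution.
from typing import Optional, Tuple

def nearest_supported(height: int, width: int, patch_size: int = 16,
                      window_size: Optional[int] = None) -> Tuple[int, int]:
    step = patch_size * (window_size or 1)           # min_resolution_step
    snap = lambda d: max(int(round(d / step) * step), step)
    return snap(height), snap(width)

assert nearest_supported(500, 700) == (496, 704)                  # multiples of 16
assert nearest_supported(500, 700, window_size=8) == (512, 640)   # multiples of 128
```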
src/models/radiov3/vit_patch_generator.py ADDED
@@ -0,0 +1,288 @@
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import math
10
+ from typing import Union, Tuple, Optional
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+ from einops import rearrange
16
+
17
+ from .cls_token import ClsToken
18
+
19
+ input_dim_t = Union[int, Tuple[int, int]]
20
+
21
+ try:
22
+ # raise ImportError()
23
+ from indirect_grid_sample import indirect_grid_sample
24
+ except ImportError:
25
+ indirect_grid_sample = None
26
+
27
+ class ViTPatchGenerator(nn.Module):
28
+ def __init__(self,
29
+ patch_size: int,
30
+ embed_dim: int,
31
+ input_dims: input_dim_t,
32
+ abs_pos: bool = True,
33
+ normalize_patches: bool = False,
34
+ cls_token: bool = False,
35
+ max_input_dims: Optional[input_dim_t] = None,
36
+ pos_dropout: float = 0.0,
37
+ return_pos_enc: bool = False,
38
+ num_cls_tokens: int = 1,
39
+ register_multiple: Optional[int] = None,
40
+ num_registers: Optional[int] = None,
41
+ patch_bias: bool = False,
42
+ device=None, dtype=None,
43
+ ):
44
+ super().__init__()
45
+
46
+ if isinstance(input_dims, int):
47
+ input_dims = (input_dims, input_dims)
48
+
49
+ if max_input_dims is None:
50
+ max_input_dims = input_dims
51
+ if isinstance(max_input_dims, int):
52
+ max_input_dims = (max_input_dims, max_input_dims)
53
+
54
+ max_input_dims = tuple(
55
+ int(math.ceil(d / patch_size) * patch_size)
56
+ for d in max_input_dims
57
+ )
58
+
59
+ self.cpe_mode = max_input_dims != input_dims
60
+ self.pos_dropout = pos_dropout
61
+ self.return_pos_enc = return_pos_enc
62
+
63
+ factory = dict(device=device, dtype=dtype)
64
+
65
+ self.patch_size = patch_size
66
+ self.abs_pos = abs_pos
67
+ self.embed_dim = embed_dim
68
+
69
+ self.num_rows = max_input_dims[0] // patch_size
70
+ self.num_cols = max_input_dims[1] // patch_size
71
+ self.input_dims = tuple(d // patch_size for d in input_dims)
72
+ self.num_patches = self.num_rows * self.num_cols
73
+ self.max_input_dims = max_input_dims
74
+
75
+ self.im_to_patches = Im2Patches(patch_size)
76
+ self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias, **factory)
77
+
78
+ if abs_pos:
79
+ scale = embed_dim ** -0.5
80
+ self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
81
+
82
+ self.cls_token = ClsToken(
83
+ embed_dim,
84
+ num_tokens=num_cls_tokens,
85
+ enabled=cls_token,
86
+ register_multiple=register_multiple,
87
+ num_registers=num_registers,
88
+ )
89
+
90
+ self.patch_normalizer = nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
91
+
92
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
93
+ patches = self.embed_patches(x)
94
+ patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
95
+ patches = self.cls_token(patches)
96
+ patches = self.patch_normalizer(patches)
97
+ if self.return_pos_enc:
98
+ return patches, pos_enc
99
+ return patches
100
+
101
+ @property
102
+ def apply_cls_token(self):
103
+ return self.cls_token.enabled
104
+
105
+ @property
106
+ def num_cls_tokens(self):
107
+ return self.cls_token.num_tokens
108
+
109
+ @property
110
+ def num_cls_patches(self):
111
+ return self.cls_token.num_patches
112
+
113
+ @property
114
+ def num_registers(self):
115
+ return self.cls_token.num_registers
116
+
117
+ @property
118
+ def num_skip(self):
119
+ return self.num_cls_tokens + self.num_registers
120
+
121
+ def no_weight_decay(self):
122
+ return [
123
+ 'pos_embed',
124
+ ]
125
+
126
+ def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
127
+ if src_embed.shape != targ_embed.shape:
128
+ src_size = int(math.sqrt(src_embed.shape[1]))
129
+
130
+ assert src_size ** 2 == src_embed.shape[1], 'Unable to interpolate non-square embedding'
131
+
132
+ src_embed = rearrange(src_embed, 'b (h w) c -> b c h w', h=src_size, w=src_size)
133
+ src_embed = F.interpolate(src_embed, size=(self.num_rows, self.num_cols), mode='bicubic', align_corners=True, antialias=False)
134
+ src_embed = rearrange(src_embed, 'b c h w -> b (h w) c')
135
+ targ_embed.data.copy_(src_embed)
136
+
137
+ def _load_projection(self, src_proj_weight: torch.Tensor, targ_proj_weight: torch.Tensor):
138
+ if src_proj_weight.shape != targ_proj_weight.shape:
139
+ src_patch_size = int(math.sqrt(src_proj_weight.shape[1] // 3))
140
+
141
+ assert (src_patch_size ** 2) * 3 == src_proj_weight.shape[1], 'Unable to interpolate non-square patch size'
142
+
143
+ src_proj_weight = rearrange(src_proj_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
144
+ src_proj_weight = F.interpolate(src_proj_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
145
+ src_proj_weight = rearrange(src_proj_weight, 'b c h w -> b (c h w)')
146
+ targ_proj_weight.data.copy_(src_proj_weight)
147
+
148
+ def embed_patches(self, x: torch.Tensor) -> torch.Tensor:
149
+ # import pdb; pdb.set_trace()
150
+ patches = self.im_to_patches(x)
151
+ patches = self.embedder(patches)
152
+ return patches
153
+
154
+ def apply_pos_enc(self,
155
+ patches: torch.Tensor,
156
+ patch_idxs: Optional[torch.Tensor] = None,
157
+ input_size: Optional[Tuple[int, int]] = None,
158
+ ) -> torch.Tensor:
159
+ if not self.abs_pos:
160
+ return patches
161
+
162
+ pos_enc = self.get_pos_enc(patches.shape[0], patch_idxs, input_size)
163
+
164
+ if self.training and self.pos_dropout > 0:
165
+ keeps = torch.rand(patches.shape[0], 1, 1, dtype=pos_enc.dtype, device=pos_enc.device) > self.pos_dropout
166
+ pos_enc_drop = torch.where(keeps, pos_enc, 0)
167
+ else:
168
+ pos_enc_drop = pos_enc
169
+
170
+ return patches + pos_enc_drop, pos_enc
171
+
172
+ def get_pos_enc(self,
173
+ batch_size: int,
174
+ patch_idxs: Optional[torch.Tensor] = None,
175
+ input_size: Optional[Tuple[int, int]] = None,
176
+ ) -> torch.Tensor:
177
+ if input_size is None:
178
+ input_dims = self.input_dims
179
+ else:
180
+ input_dims = tuple(d // self.patch_size for d in input_size)
181
+
182
+ pos_embed = self._get_pos_embeddings(batch_size, input_dims)
183
+
184
+ if patch_idxs is None:
185
+ return pos_embed
186
+
187
+ exp_patch_idxs = patch_idxs.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1])
188
+
189
+ pos_embed = torch.gather(pos_embed.expand(patch_idxs.shape[0], -1, -1), dim=1, index=exp_patch_idxs)
190
+ return pos_embed
191
+
192
+
193
+ def _get_pos_embeddings(self, batch_size: int, input_dims: Tuple[int, int]):
194
+ if (self.num_rows, self.num_cols) == input_dims:
195
+ return self.pos_embed
196
+
197
+ pos_embed = self.pos_embed.reshape(1, self.num_rows, self.num_cols, -1).permute(0, 3, 1, 2)
198
+
199
+ def window_select(pos_embed):
200
+ if input_dims[0] < pos_embed.shape[-2]:
201
+ pos_embed = pos_embed[..., :input_dims[0], :]
202
+ if input_dims[1] < pos_embed.shape[-1]:
203
+ pos_embed = pos_embed[..., :, :input_dims[1]]
204
+ return pos_embed
205
+
206
+ if self.cpe_mode:
207
+ if self.training:
208
+ min_scale = math.sqrt(0.1)
209
+ scale = torch.rand(batch_size, 1, 1, device=pos_embed.device) * (1 - min_scale) + min_scale
210
+ aspect_min = math.log(3 / 4)
211
+ aspect_max = -aspect_min
212
+ aspect = torch.exp(torch.rand(batch_size, 1, 1, device=pos_embed.device) * (aspect_max - aspect_min) + aspect_min)
213
+
214
+ scale_x = scale * aspect
215
+ scale_y = scale * (1 / aspect)
216
+ scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
217
+
218
+ pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (1 - scale_xy)
219
+
220
+ lin_x = torch.linspace(0, 1, steps=input_dims[1], device=pos_embed.device)[None, None].expand(batch_size, input_dims[0], -1)
221
+ lin_y = torch.linspace(0, 1, steps=input_dims[0], device=pos_embed.device)[None, :, None].expand(batch_size, -1, input_dims[1])
222
+
223
+ lin_xy = torch.stack([lin_x, lin_y], dim=-1)
224
+
225
+ grid_xy = lin_xy * scale_xy + pos_xy
226
+
227
+ # Convert to [-1, 1] range
228
+ grid_xy.mul_(2).sub_(1)
229
+
230
+ pos_embed = F.grid_sample(
231
+ pos_embed.float().expand(batch_size, -1, -1, -1),
232
+ grid=grid_xy,
233
+ mode='bilinear',
234
+ padding_mode='zeros',
235
+ align_corners=True,
236
+ ).to(pos_embed.dtype)
237
+ else:
238
+ # i_rows, i_cols = input_dims
239
+ # p_rows, p_cols = pos_embed.shape[2:]
240
+ # if i_rows <= p_rows and i_cols <= p_cols:
241
+ # left = (p_cols - i_cols) // 2
242
+ # top = (p_rows - i_rows) // 2
243
+ # pos_embed = pos_embed[..., top:top+i_rows, left:left+i_cols]
244
+ # else:
245
+ max_dim = max(input_dims)
246
+ pos_embed = F.interpolate(pos_embed.float(), size=(max_dim, max_dim), align_corners=True, mode='bilinear').to(pos_embed.dtype)
247
+
248
+ pos_embed = window_select(pos_embed)
249
+ else:
250
+ pos_embed = window_select(pos_embed)
251
+
252
+ if pos_embed.shape[-2:] != input_dims:
253
+ pos_embed = F.interpolate(pos_embed.float(), size=input_dims, align_corners=True, mode='bilinear').to(pos_embed.dtype)
254
+
255
+ pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
256
+
257
+ return pos_embed
258
+
259
+
260
+ class Im2Patches(nn.Module):
261
+ def __init__(self, patch_size: int):
262
+ super().__init__()
263
+ self.patch_size = patch_size
264
+
265
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
266
+ if self.patch_size == 1:
267
+ patches = x.flatten(2)
268
+ patches = patches.permute(0, 2, 1)
269
+ return patches
270
+
271
+ py = x.shape[-2] // self.patch_size
272
+ px = x.shape[-1] // self.patch_size
273
+ patches = rearrange(x, 'b c (py yy) (px xx) -> b (py px) (c yy xx)',
274
+ py=py, yy=self.patch_size,
275
+ px=px, xx=self.patch_size,
276
+ )
277
+ return patches
278
+
279
+
280
+ class ViTPatchLinear(nn.Linear):
281
+ def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
282
+ super().__init__(
283
+ 3 * (patch_size ** 2),
284
+ embed_dim,
285
+ bias=bias,
286
+ **factory
287
+ )
288
+ self.patch_size = patch_size
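The patch extraction in `Im2Patches` is a pure einops rearrangement: non-overlapping `patch_size x patch_size` tiles become tokens of length `C * patch_size**2`. A small editorial sketch (not part of the commit), assuming the import path from this diff:

```python
# Editorial sketch: Im2Patches turns (B, C, H, W) into (B, num_patches, C * p * p).
import torch
from src.models.radiov3.vit_patch_generator import Im2Patches

x = torch.randn(2, 3, 32, 32)
patches = Im2Patches(patch_size=16)(x)
assert patches.shape == (2, 4, 3 * 16 * 16)   # 2x2 grid of 16x16 tiles -> 4 tokens of 768
```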
src/models/radiov3/vitdet.py ADDED
@@ -0,0 +1,188 @@
1
+ from collections import defaultdict
2
+ from contextlib import contextmanager
3
+ from logging import getLogger
4
+ import math
5
+ import sys
6
+ from typing import List, Union, Iterable
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch import nn
11
+
12
+ from timm.models import VisionTransformer
13
+ from einops import rearrange
14
+
15
+ from .extra_models import DinoWrapper
16
+
17
+ DEFAULT_NUM_WINDOWED = 5
18
+ DEFAULT_NUM_GLOBAL = 4
19
+
20
+
21
+ class VitDetArgs:
22
+ def __init__(self,
23
+ window_size: int,
24
+ num_summary_tokens: int,
25
+ num_windowed: int = None,
26
+ num_global: int = None,
27
+ ):
28
+ self.window_size = window_size
29
+ self.num_summary_tokens = num_summary_tokens
30
+ self.num_windowed = num_windowed
31
+ self.num_global = num_global
32
+
33
+
34
+ def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):
35
+ if isinstance(model, VisionTransformer):
36
+ patch_embed = getattr(model, 'patch_generator', model.patch_embed)
37
+
38
+ return ViTDetHook(patch_embed, model.blocks, args)
39
+ elif isinstance(model, DinoWrapper):
40
+ inner = model.inner
41
+
42
+ patch_embed = getattr(inner, 'patch_generator', inner.patch_embed)
43
+ return ViTDetHook(patch_embed, inner.blocks, args)
44
+ else:
45
+ print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr)
46
+
47
+
48
+ class ViTDetHook:
49
+ def __init__(self,
50
+ embedder: nn.Module,
51
+ blocks: nn.Sequential,
52
+ args: VitDetArgs,
53
+ ):
54
+ self.blocks = blocks
55
+ self.num_summary_tokens = args.num_summary_tokens
56
+ self.window_size = args.window_size
57
+
58
+ self._input_resolution = None
59
+ self._num_windows = None
60
+ self._cls_patch = None
61
+ self._order_cache = dict()
62
+
63
+ embedder.register_forward_pre_hook(self._enter_model)
64
+
65
+ # This will decide if we window-fy the patches
66
+ # and enable vit-det for this iteration, and if so,
67
+ # rearrange the patches for efficient mode switching
68
+ blocks.register_forward_pre_hook(self._enter_blocks)
69
+
70
+ is_global = True
71
+ if args.num_windowed is not None:
72
+ period = args.num_windowed + 1
73
+ else:
74
+ num_global = args.num_global or DEFAULT_NUM_GLOBAL
75
+ period = max(len(blocks) // num_global, 1)
76
+
77
+ for i, layer in enumerate(blocks[:-1]):
78
+ ctr = i % period
79
+ if ctr == 0:
80
+ layer.register_forward_pre_hook(self._to_windows)
81
+ is_global = False
82
+ elif ctr == period - 1:
83
+ layer.register_forward_pre_hook(self._to_global)
84
+ is_global = True
85
+
86
+ # Always ensure the final layer is a global layer
87
+ if not is_global:
88
+ blocks[-1].register_forward_pre_hook(self._to_global)
89
+
90
+ blocks.register_forward_hook(self._exit_model)
91
+
92
+ def _enter_model(self, _, input: List[torch.Tensor]):
93
+ self._input_resolution = input[0].shape[-2:]
94
+
95
+ def _enter_blocks(self, _, input: List[torch.Tensor]):
96
+ # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)
97
+
98
+ patches = input[0]
99
+ patches = self._rearrange_patches(patches)
100
+
101
+ return (patches,) + input[1:]
102
+
103
+ def _to_windows(self, _, input: List[torch.Tensor]):
104
+ patches = input[0]
105
+
106
+ if self.num_summary_tokens:
107
+ self._cls_patch = patches[:, :self.num_summary_tokens]
108
+ patches = patches[:, self.num_summary_tokens:]
109
+
110
+ patches = rearrange(
111
+ patches, 'b (p t) c -> (b p) t c',
112
+ p=self._num_windows, t=self.window_size ** 2,
113
+ )
114
+
115
+ return (patches,) + input[1:]
116
+
117
+ def _to_global(self, _, input: List[torch.Tensor]):
118
+ patches = input[0]
119
+
120
+ patches = rearrange(
121
+ patches, '(b p) t c -> b (p t) c',
122
+ p=self._num_windows, t=self.window_size ** 2,
123
+ b=patches.shape[0] // self._num_windows,
124
+ )
125
+
126
+ if self.num_summary_tokens:
127
+ patches = torch.cat([
128
+ self._cls_patch,
129
+ patches,
130
+ ], dim=1)
131
+
132
+ return (patches,) + input[1:]
133
+
134
+ def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
135
+ # Return patches to their original order
136
+ patch_order = self._order_cache[self._input_resolution][0]
137
+ patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
138
+
139
+ ret_patches = torch.empty_like(patches)
140
+ ret_patches = torch.scatter(
141
+ ret_patches,
142
+ dim=1,
143
+ index=patch_order,
144
+ src=patches,
145
+ )
146
+
147
+ return ret_patches
148
+
149
+ def _rearrange_patches(self, patches: torch.Tensor):
150
+ # We rearrange the patches so that we can efficiently
151
+ # switch between windowed and global mode by just
152
+ # reshaping the tensor
153
+
154
+ patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
155
+ if patch_order is None:
156
+ num_feat_patches = patches.shape[1] - self.num_summary_tokens
157
+ num_pixels = self._input_resolution[0] * self._input_resolution[1]
158
+
159
+ patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
160
+ rows = self._input_resolution[-2] // patch_size
161
+ cols = self._input_resolution[-1] // patch_size
162
+
163
+ w_rows = rows // self.window_size
164
+ w_cols = cols // self.window_size
165
+
166
+ patch_order = torch.arange(0, num_feat_patches, device=patches.device)
167
+
168
+ patch_order = rearrange(
169
+ patch_order, '(wy py wx px) -> (wy wx py px)',
170
+ wy=w_rows, wx=w_cols,
171
+ py=self.window_size, px=self.window_size,
172
+ )
173
+
174
+ if self.num_summary_tokens:
175
+ patch_order = torch.cat([
176
+ torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
177
+ patch_order + self.num_summary_tokens,
178
+ ])
179
+
180
+ self._num_windows = w_rows * w_cols
181
+ self._order_cache[self._input_resolution] = (
182
+ patch_order,
183
+ self._num_windows,
184
+ )
185
+
186
+ patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
187
+ patches = torch.gather(patches, dim=1, index=patch_order)
188
+ return patches
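+
+ # Illustrative round-trip check (a minimal sketch, not part of the original file; the sizes
+ # num_windows=4, window_size=8 and dim=768 are assumptions for the example). It shows why the
+ # window-major ordering produced by `_rearrange_patches` lets `_to_windows` and `_to_global`
+ # switch attention modes with nothing more than a reshape:
+ #
+ #   x = torch.randn(2, 4 * 8 ** 2, 768)                    # window-major patch tokens
+ #   w = rearrange(x, 'b (p t) c -> (b p) t c', p=4, t=64)  # windowed layout
+ #   g = rearrange(w, '(b p) t c -> b (p t) c', p=4, b=2)   # back to global layout
+ #   assert torch.equal(x, g)                               # lossless round trip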
src/models/stable_diffusion3/pipeline_stable_diffusion_3.py ADDED
@@ -0,0 +1,1256 @@
1
+ # Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ SiglipImageProcessor,
23
+ SiglipVisionModel,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from diffusers.models.transformers import SD3Transformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import StableDiffusion3Pipeline
61
+
62
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
63
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
64
+ ... )
65
+ >>> pipe.to("cuda")
66
+ >>> prompt = "A cat holding a sign that says hello world"
67
+ >>> image = pipe(prompt).images[0]
68
+ >>> image.save("sd3.png")
69
+ ```
70
+ """
71
+
72
+
73
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
74
+ def calculate_shift(
75
+ image_seq_len,
76
+ base_seq_len: int = 256,
77
+ max_seq_len: int = 4096,
78
+ base_shift: float = 0.5,
79
+ max_shift: float = 1.15,
80
+ ):
81
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
82
+ b = base_shift - m * base_seq_len
83
+ mu = image_seq_len * m + b
84
+ return mu
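+
+ # Worked example (assuming the default arguments above): m = (1.15 - 0.5) / (4096 - 256)
+ # is roughly 1.69e-4 and b is roughly 0.457, so an image_seq_len of 1024 gives mu of about
+ # 0.63, while the maximum sequence length of 4096 recovers the full shift of 1.15.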
85
+
86
+
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ device: Optional[Union[str, torch.device]] = None,
92
+ timesteps: Optional[List[int]] = None,
93
+ sigmas: Optional[List[float]] = None,
94
+ **kwargs,
95
+ ):
96
+ r"""
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
+
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
105
+ must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
110
+ `num_inference_steps` and `sigmas` must be `None`.
111
+ sigmas (`List[float]`, *optional*):
112
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
113
+ `num_inference_steps` and `timesteps` must be `None`.
114
+
115
+ Returns:
116
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
117
+ second element is the number of inference steps.
118
+ """
119
+ if timesteps is not None and sigmas is not None:
120
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
123
+ if not accepts_timesteps:
124
+ raise ValueError(
125
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
126
+ f" timestep schedules. Please check whether you are using the correct scheduler."
127
+ )
128
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ num_inference_steps = len(timesteps)
131
+ elif sigmas is not None:
132
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accept_sigmas:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ else:
142
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ return timesteps, num_inference_steps
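+
+ # Usage sketch (illustrative only; the sigma values below are made up): custom `timesteps`
+ # and `sigmas` are mutually exclusive, so a caller either relies on the scheduler's own
+ # spacing or overrides it explicitly:
+ #
+ #   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=28, device=device)
+ #   timesteps, n = retrieve_timesteps(scheduler, device=device, sigmas=[1.0, 0.75, 0.5, 0.25, 0.0])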
145
+
146
+
147
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
148
+ r"""
149
+ Args:
150
+ transformer ([`SD3Transformer2DModel`]):
151
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
152
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
153
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
154
+ vae ([`AutoencoderKL`]):
155
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
156
+ text_encoder ([`CLIPTextModelWithProjection`]):
157
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
158
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
159
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
160
+ as its dimension.
161
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
162
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
163
+ specifically the
164
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
165
+ variant.
166
+ text_encoder_3 ([`T5EncoderModel`]):
167
+ Frozen text-encoder. Stable Diffusion 3 uses
168
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
169
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
170
+ tokenizer (`CLIPTokenizer`):
171
+ Tokenizer of class
172
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
173
+ tokenizer_2 (`CLIPTokenizer`):
174
+ Second Tokenizer of class
175
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
176
+ tokenizer_3 (`T5TokenizerFast`):
177
+ Tokenizer of class
178
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
179
+ image_encoder (`SiglipVisionModel`, *optional*):
180
+ Pre-trained Vision Model for IP Adapter.
181
+ feature_extractor (`SiglipImageProcessor`, *optional*):
182
+ Image processor for IP Adapter.
183
+ """
184
+
185
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
186
+ _optional_components = ["image_encoder", "feature_extractor"]
187
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
188
+
189
+ def __init__(
190
+ self,
191
+ transformer: SD3Transformer2DModel,
192
+ scheduler: FlowMatchEulerDiscreteScheduler,
193
+ vae: AutoencoderKL,
194
+ text_encoder: CLIPTextModelWithProjection,
195
+ tokenizer: CLIPTokenizer,
196
+ text_encoder_2: CLIPTextModelWithProjection,
197
+ tokenizer_2: CLIPTokenizer,
198
+ text_encoder_3: T5EncoderModel,
199
+ tokenizer_3: T5TokenizerFast,
200
+ image_encoder: SiglipVisionModel = None,
201
+ feature_extractor: SiglipImageProcessor = None,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.register_modules(
206
+ vae=vae,
207
+ text_encoder=text_encoder,
208
+ text_encoder_2=text_encoder_2,
209
+ text_encoder_3=text_encoder_3,
210
+ tokenizer=tokenizer,
211
+ tokenizer_2=tokenizer_2,
212
+ tokenizer_3=tokenizer_3,
213
+ transformer=transformer,
214
+ scheduler=scheduler,
215
+ image_encoder=image_encoder,
216
+ feature_extractor=feature_extractor,
217
+ )
218
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
219
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
220
+ self.tokenizer_max_length = (
221
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
222
+ )
223
+ self.default_sample_size = (
224
+ self.transformer.config.sample_size
225
+ if hasattr(self, "transformer") and self.transformer is not None
226
+ else 128
227
+ )
228
+ self.patch_size = (
229
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
230
+ )
231
+
232
+ def _get_t5_prompt_embeds(
233
+ self,
234
+ prompt: Union[str, List[str]] = None,
235
+ num_images_per_prompt: int = 1,
236
+ max_sequence_length: int = 256,
237
+ device: Optional[torch.device] = None,
238
+ dtype: Optional[torch.dtype] = None,
239
+ ):
240
+ device = device or self._execution_device
241
+ dtype = dtype or self.text_encoder.dtype
242
+
243
+ prompt = [prompt] if isinstance(prompt, str) else prompt
244
+ batch_size = len(prompt)
245
+
246
+ if self.text_encoder_3 is None:
247
+ return torch.zeros(
248
+ (
249
+ batch_size * num_images_per_prompt,
250
+ self.tokenizer_max_length,
251
+ self.transformer.config.joint_attention_dim,
252
+ ),
253
+ device=device,
254
+ dtype=dtype,
255
+ )
256
+
257
+ text_inputs = self.tokenizer_3(
258
+ prompt,
259
+ padding="max_length",
260
+ max_length=max_sequence_length,
261
+ truncation=True,
262
+ add_special_tokens=True,
263
+ return_tensors="pt",
264
+ )
265
+ text_input_ids = text_inputs.input_ids
266
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
267
+
268
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
269
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
270
+ # logger.warning(
271
+ # "The following part of your input was truncated because `max_sequence_length` is set to "
272
+ # f" {max_sequence_length} tokens: {removed_text}"
273
+ # )
274
+
275
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
276
+
277
+ dtype = self.text_encoder_3.dtype
278
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
279
+
280
+ _, seq_len, _ = prompt_embeds.shape
281
+
282
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
283
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
284
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
285
+
286
+ return prompt_embeds
287
+
288
+ def _get_clip_prompt_embeds(
289
+ self,
290
+ prompt: Union[str, List[str]],
291
+ num_images_per_prompt: int = 1,
292
+ device: Optional[torch.device] = None,
293
+ clip_skip: Optional[int] = None,
294
+ clip_model_index: int = 0,
295
+ ):
296
+ device = device or self._execution_device
297
+
298
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
299
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
300
+
301
+ tokenizer = clip_tokenizers[clip_model_index]
302
+ text_encoder = clip_text_encoders[clip_model_index]
303
+
304
+ prompt = [prompt] if isinstance(prompt, str) else prompt
305
+ batch_size = len(prompt)
306
+
307
+ text_inputs = tokenizer(
308
+ prompt,
309
+ padding="max_length",
310
+ max_length=self.tokenizer_max_length,
311
+ truncation=True,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ text_input_ids = text_inputs.input_ids
316
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
318
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
319
+ # logger.warning(
320
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
321
+ # f" {self.tokenizer_max_length} tokens: {removed_text}"
322
+ # )
323
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
324
+ pooled_prompt_embeds = prompt_embeds[0]
325
+
326
+ if clip_skip is None:
327
+ prompt_embeds = prompt_embeds.hidden_states[-2]
328
+ else:
329
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
330
+
331
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
332
+
333
+ _, seq_len, _ = prompt_embeds.shape
334
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
335
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
336
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
337
+
338
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
340
+
341
+ return prompt_embeds, pooled_prompt_embeds
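+
+ # Note on `clip_skip` (descriptive only): with `clip_skip=None` the penultimate hidden state
+ # (`hidden_states[-2]`) is used, and `clip_skip=1` moves one layer further back to
+ # `hidden_states[-3]`, matching the indexing above.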
342
+
343
+ def encode_pooled_prompt(
344
+ self,
345
+ prompt: Union[str, List[str]],
346
+ prompt_2: Union[str, List[str]],
347
+ device: Optional[torch.device] = None,
348
+ num_images_per_prompt: int = 1,
349
+ do_classifier_free_guidance: bool = True,
350
+ negative_prompt: Optional[Union[str, List[str]]] = None,
351
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
352
+ prompt_embeds: Optional[torch.FloatTensor] = None,
353
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
354
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
355
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
356
+ clip_skip: Optional[int] = None,
357
+ lora_scale: Optional[float] = None,
358
+ ):
359
+ device = device or self._execution_device
360
+
361
+ # set lora scale so that monkey patched LoRA
362
+ # function of text encoder can correctly access it
363
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
364
+ self._lora_scale = lora_scale
365
+
366
+ # dynamically adjust the LoRA scale
367
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
368
+ scale_lora_layers(self.text_encoder, lora_scale)
369
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
370
+ scale_lora_layers(self.text_encoder_2, lora_scale)
371
+
372
+ prompt = [prompt] if isinstance(prompt, str) else prompt
373
+ if prompt is not None:
374
+ batch_size = len(prompt)
375
+ else:
376
+ batch_size = prompt_embeds.shape[0]
377
+
378
+ if prompt_embeds is None:
379
+ prompt_2 = prompt_2 or prompt
380
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
381
+
382
+ _, pooled_prompt_embed = self._get_clip_prompt_embeds(
383
+ prompt=prompt,
384
+ device=device,
385
+ num_images_per_prompt=num_images_per_prompt,
386
+ clip_skip=clip_skip,
387
+ clip_model_index=0,
388
+ )
389
+ _, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
390
+ prompt=prompt_2,
391
+ device=device,
392
+ num_images_per_prompt=num_images_per_prompt,
393
+ clip_skip=clip_skip,
394
+ clip_model_index=1,
395
+ )
396
+
397
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
398
+
399
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
400
+ negative_prompt = negative_prompt or ""
401
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
402
+
403
+ # normalize str to list
404
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
405
+ negative_prompt_2 = (
406
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
407
+ )
408
+
409
+
410
+ if prompt is not None and type(prompt) is not type(negative_prompt):
411
+ raise TypeError(
412
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
413
+ f" {type(prompt)}."
414
+ )
415
+ elif batch_size != len(negative_prompt):
416
+ raise ValueError(
417
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
418
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
419
+ " the batch size of `prompt`."
420
+ )
421
+
422
+ _, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
423
+ negative_prompt,
424
+ device=device,
425
+ num_images_per_prompt=num_images_per_prompt,
426
+ clip_skip=None,
427
+ clip_model_index=0,
428
+ )
429
+ _, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
430
+ negative_prompt_2,
431
+ device=device,
432
+ num_images_per_prompt=num_images_per_prompt,
433
+ clip_skip=None,
434
+ clip_model_index=1,
435
+ )
436
+
437
+ negative_pooled_prompt_embeds = torch.cat(
438
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
439
+ )
440
+
441
+ if self.text_encoder is not None:
442
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
443
+ # Retrieve the original scale by scaling back the LoRA layers
444
+ unscale_lora_layers(self.text_encoder, lora_scale)
445
+
446
+ if self.text_encoder_2 is not None:
447
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
448
+ # Retrieve the original scale by scaling back the LoRA layers
449
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
450
+
451
+ return pooled_prompt_embeds, negative_pooled_prompt_embeds
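+
+ # Usage sketch (a minimal illustration; the argument values are assumptions): this helper
+ # mirrors `encode_prompt` below but only builds the pooled CLIP embeddings, e.g.
+ #
+ #   pooled, neg_pooled = self.encode_pooled_prompt(prompt, None, device=device)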
452
+
453
+
454
+ def encode_prompt(
455
+ self,
456
+ prompt: Union[str, List[str]],
457
+ prompt_2: Union[str, List[str]],
458
+ prompt_3: Union[str, List[str]],
459
+ device: Optional[torch.device] = None,
460
+ num_images_per_prompt: int = 1,
461
+ do_classifier_free_guidance: bool = True,
462
+ negative_prompt: Optional[Union[str, List[str]]] = None,
463
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
464
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
465
+ prompt_embeds: Optional[torch.FloatTensor] = None,
466
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
467
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
468
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
469
+ clip_skip: Optional[int] = None,
470
+ max_sequence_length: int = 256,
471
+ lora_scale: Optional[float] = None,
472
+ ):
473
+ r"""
474
+
475
+ Args:
476
+ prompt (`str` or `List[str]`, *optional*):
477
+ prompt to be encoded
478
+ prompt_2 (`str` or `List[str]`, *optional*):
479
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
480
+ used in all text-encoders
481
+ prompt_3 (`str` or `List[str]`, *optional*):
482
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
483
+ used in all text-encoders
484
+ device: (`torch.device`):
485
+ torch device
486
+ num_images_per_prompt (`int`):
487
+ number of images that should be generated per prompt
488
+ do_classifier_free_guidance (`bool`):
489
+ whether to use classifier free guidance or not
490
+ negative_prompt (`str` or `List[str]`, *optional*):
491
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
492
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
493
+ less than `1`).
494
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
495
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
496
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
497
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
498
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
499
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
500
+ prompt_embeds (`torch.FloatTensor`, *optional*):
501
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
502
+ provided, text embeddings will be generated from `prompt` input argument.
503
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
504
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
505
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
506
+ argument.
507
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
508
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
509
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
510
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
511
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
512
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
513
+ input argument.
514
+ clip_skip (`int`, *optional*):
515
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
516
+ the output of the pre-final layer will be used for computing the prompt embeddings.
517
+ lora_scale (`float`, *optional*):
518
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
519
+ """
520
+ device = device or self._execution_device
521
+
522
+ # set lora scale so that monkey patched LoRA
523
+ # function of text encoder can correctly access it
524
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
525
+ self._lora_scale = lora_scale
526
+
527
+ # dynamically adjust the LoRA scale
528
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
529
+ scale_lora_layers(self.text_encoder, lora_scale)
530
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
531
+ scale_lora_layers(self.text_encoder_2, lora_scale)
532
+
533
+ prompt = [prompt] if isinstance(prompt, str) else prompt
534
+ if prompt is not None:
535
+ batch_size = len(prompt)
536
+ else:
537
+ batch_size = prompt_embeds.shape[0]
538
+
539
+ if prompt_embeds is None:
540
+ prompt_2 = prompt_2 or prompt
541
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
542
+
543
+ prompt_3 = prompt_3 or prompt
544
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
545
+
546
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
547
+ prompt=prompt,
548
+ device=device,
549
+ num_images_per_prompt=num_images_per_prompt,
550
+ clip_skip=clip_skip,
551
+ clip_model_index=0,
552
+ )
553
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
554
+ prompt=prompt_2,
555
+ device=device,
556
+ num_images_per_prompt=num_images_per_prompt,
557
+ clip_skip=clip_skip,
558
+ clip_model_index=1,
559
+ )
560
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
561
+
562
+ t5_prompt_embed = self._get_t5_prompt_embeds(
563
+ prompt=prompt_3,
564
+ num_images_per_prompt=num_images_per_prompt,
565
+ max_sequence_length=max_sequence_length,
566
+ device=device,
567
+ )
568
+
569
+ clip_prompt_embeds = torch.nn.functional.pad(
570
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
571
+ )
572
+
573
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
574
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
575
+
576
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
577
+ negative_prompt = negative_prompt or ""
578
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
579
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
580
+
581
+ # normalize str to list
582
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
583
+ negative_prompt_2 = (
584
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
585
+ )
586
+ negative_prompt_3 = (
587
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
588
+ )
589
+
590
+ if prompt is not None and type(prompt) is not type(negative_prompt):
591
+ raise TypeError(
592
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
593
+ f" {type(prompt)}."
594
+ )
595
+ elif batch_size != len(negative_prompt):
596
+ raise ValueError(
597
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
598
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
599
+ " the batch size of `prompt`."
600
+ )
601
+
602
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
603
+ negative_prompt,
604
+ device=device,
605
+ num_images_per_prompt=num_images_per_prompt,
606
+ clip_skip=None,
607
+ clip_model_index=0,
608
+ )
609
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
610
+ negative_prompt_2,
611
+ device=device,
612
+ num_images_per_prompt=num_images_per_prompt,
613
+ clip_skip=None,
614
+ clip_model_index=1,
615
+ )
616
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
617
+
618
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
619
+ prompt=negative_prompt_3,
620
+ num_images_per_prompt=num_images_per_prompt,
621
+ max_sequence_length=max_sequence_length,
622
+ device=device,
623
+ )
624
+
625
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
626
+ negative_clip_prompt_embeds,
627
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
628
+ )
629
+
630
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
631
+ negative_pooled_prompt_embeds = torch.cat(
632
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
633
+ )
634
+
635
+ if self.text_encoder is not None:
636
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
637
+ # Retrieve the original scale by scaling back the LoRA layers
638
+ unscale_lora_layers(self.text_encoder, lora_scale)
639
+
640
+ if self.text_encoder_2 is not None:
641
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
642
+ # Retrieve the original scale by scaling back the LoRA layers
643
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
644
+
645
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
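+
+ # Shape sketch (assuming the standard SD3 text encoders and default lengths, which may differ
+ # per checkpoint): the two CLIP streams contribute (B, 77, 768) and (B, 77, 1280) tokens, are
+ # concatenated and zero-padded on the channel axis to the T5 width of 4096, then concatenated
+ # with the (B, 256, 4096) T5 tokens along the sequence axis, giving `prompt_embeds` of shape
+ # (B, 333, 4096) plus 2048-d pooled embeddings.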
646
+
647
+ def check_inputs(
648
+ self,
649
+ prompt,
650
+ prompt_2,
651
+ prompt_3,
652
+ height,
653
+ width,
654
+ negative_prompt=None,
655
+ negative_prompt_2=None,
656
+ negative_prompt_3=None,
657
+ prompt_embeds=None,
658
+ negative_prompt_embeds=None,
659
+ pooled_prompt_embeds=None,
660
+ negative_pooled_prompt_embeds=None,
661
+ callback_on_step_end_tensor_inputs=None,
662
+ max_sequence_length=None,
663
+ ):
664
+ if (
665
+ height % (self.vae_scale_factor * self.patch_size) != 0
666
+ or width % (self.vae_scale_factor * self.patch_size) != 0
667
+ ):
668
+ raise ValueError(
669
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}. "
670
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
671
+ )
672
+
673
+ if callback_on_step_end_tensor_inputs is not None and not all(
674
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
675
+ ):
676
+ raise ValueError(
677
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
678
+ )
679
+
680
+ if prompt is not None and prompt_embeds is not None:
681
+ raise ValueError(
682
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
683
+ " only forward one of the two."
684
+ )
685
+ elif prompt_2 is not None and prompt_embeds is not None:
686
+ raise ValueError(
687
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
688
+ " only forward one of the two."
689
+ )
690
+ elif prompt_3 is not None and prompt_embeds is not None:
691
+ raise ValueError(
692
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
693
+ " only forward one of the two."
694
+ )
695
+ elif prompt is None and prompt_embeds is None:
696
+ raise ValueError(
697
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
698
+ )
699
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
700
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
701
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
702
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
703
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
704
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
705
+
706
+ if negative_prompt is not None and negative_prompt_embeds is not None:
707
+ raise ValueError(
708
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
709
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
710
+ )
711
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
712
+ raise ValueError(
713
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
714
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
715
+ )
716
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
717
+ raise ValueError(
718
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
719
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
720
+ )
721
+
722
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
723
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
724
+ raise ValueError(
725
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
726
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
727
+ f" {negative_prompt_embeds.shape}."
728
+ )
729
+
730
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
731
+ raise ValueError(
732
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
733
+ )
734
+
735
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
736
+ raise ValueError(
737
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
738
+ )
739
+
740
+ if max_sequence_length is not None and max_sequence_length > 512:
741
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
742
+
743
+ def prepare_latents(
744
+ self,
745
+ batch_size,
746
+ num_channels_latents,
747
+ height,
748
+ width,
749
+ dtype,
750
+ device,
751
+ generator,
752
+ latents=None,
753
+ ):
754
+ if latents is not None:
755
+ return latents.to(device=device, dtype=dtype)
756
+
757
+ shape = (
758
+ batch_size,
759
+ num_channels_latents,
760
+ int(height) // self.vae_scale_factor,
761
+ int(width) // self.vae_scale_factor,
762
+ )
763
+
764
+ if isinstance(generator, list) and len(generator) != batch_size:
765
+ raise ValueError(
766
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
767
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
768
+ )
769
+
770
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
771
+
772
+ return latents
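+
+ # Example (assuming the usual SD3 configuration of 16 latent channels and a VAE scale factor
+ # of 8): a 1024x1024 request produces latents of shape (batch_size, 16, 128, 128).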
773
+
774
+ @property
775
+ def guidance_scale(self):
776
+ return self._guidance_scale
777
+
778
+ @property
779
+ def skip_guidance_layers(self):
780
+ return self._skip_guidance_layers
781
+
782
+ @property
783
+ def clip_skip(self):
784
+ return self._clip_skip
785
+
786
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
787
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
788
+ # corresponds to doing no classifier free guidance.
789
+ @property
790
+ def do_classifier_free_guidance(self):
791
+ return self._guidance_scale > 1
792
+
793
+ @property
794
+ def joint_attention_kwargs(self):
795
+ return self._joint_attention_kwargs
796
+
797
+ @property
798
+ def num_timesteps(self):
799
+ return self._num_timesteps
800
+
801
+ @property
802
+ def interrupt(self):
803
+ return self._interrupt
804
+
805
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
806
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
807
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
808
+
809
+ Args:
810
+ image (`PipelineImageInput`):
811
+ Input image to be encoded.
812
+ device: (`torch.device`):
813
+ Torch device.
814
+
815
+ Returns:
816
+ `torch.Tensor`: The encoded image feature representation.
817
+ """
818
+ if not isinstance(image, torch.Tensor):
819
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
820
+
821
+ image = image.to(device=device, dtype=self.dtype)
822
+
823
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
824
+
825
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
826
+ def prepare_ip_adapter_image_embeds(
827
+ self,
828
+ ip_adapter_image: Optional[PipelineImageInput] = None,
829
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
830
+ device: Optional[torch.device] = None,
831
+ num_images_per_prompt: int = 1,
832
+ do_classifier_free_guidance: bool = True,
833
+ ) -> torch.Tensor:
834
+ """Prepares image embeddings for use in the IP-Adapter.
835
+
836
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
837
+
838
+ Args:
839
+ ip_adapter_image (`PipelineImageInput`, *optional*):
840
+ The input image to extract features from for IP-Adapter.
841
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
842
+ Precomputed image embeddings.
843
+ device: (`torch.device`, *optional*):
844
+ Torch device.
845
+ num_images_per_prompt (`int`, defaults to 1):
846
+ Number of images that should be generated per prompt.
847
+ do_classifier_free_guidance (`bool`, defaults to True):
848
+ Whether to use classifier free guidance or not.
849
+ """
850
+ device = device or self._execution_device
851
+
852
+ if ip_adapter_image_embeds is not None:
853
+ if do_classifier_free_guidance:
854
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
855
+ else:
856
+ single_image_embeds = ip_adapter_image_embeds
857
+ elif ip_adapter_image is not None:
858
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
859
+ if do_classifier_free_guidance:
860
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
861
+ else:
862
+ raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
863
+
864
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
865
+
866
+ if do_classifier_free_guidance:
867
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
868
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
869
+
870
+ return image_embeds.to(device=device)
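+
+ # Note (descriptive only): under classifier-free guidance the negative (zero) image embeddings
+ # are stacked before the positive ones, matching the [uncond, cond] ordering used for the
+ # prompt embeddings and latents in `__call__`.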
871
+
872
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
873
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
874
+ logger.warning(
875
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
876
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
877
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
878
+ )
879
+
880
+ super().enable_sequential_cpu_offload(*args, **kwargs)
881
+
882
+ @torch.no_grad()
883
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
884
+ def __call__(
885
+ self,
886
+ prompt: Union[str, List[str]] = None,
887
+ prompt_2: Optional[Union[str, List[str]]] = None,
888
+ prompt_3: Optional[Union[str, List[str]]] = None,
889
+ height: Optional[int] = None,
890
+ width: Optional[int] = None,
891
+ num_inference_steps: int = 28,
892
+ sigmas: Optional[List[float]] = None,
893
+ guidance_scale: float = 7.0,
894
+ negative_prompt: Optional[Union[str, List[str]]] = None,
895
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
896
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
897
+ num_images_per_prompt: Optional[int] = 1,
898
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
899
+ latents: Optional[torch.FloatTensor] = None,
900
+ cond_latents: Optional[torch.FloatTensor] = None,
901
+ prompt_embeds: Optional[torch.FloatTensor] = None,
902
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
903
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
904
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
905
+ ip_adapter_image: Optional[PipelineImageInput] = None,
906
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
907
+ output_type: Optional[str] = "pil",
908
+ return_dict: bool = True,
909
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
910
+ clip_skip: Optional[int] = None,
911
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
912
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
913
+ max_sequence_length: int = 256,
914
+ skip_guidance_layers: List[int] = None,
915
+ skip_layer_guidance_scale: float = 2.8,
916
+ skip_layer_guidance_stop: float = 0.2,
917
+ skip_layer_guidance_start: float = 0.01,
918
+ mu: Optional[float] = None,
919
+ ):
920
+ r"""
921
+ Function invoked when calling the pipeline for generation.
922
+
923
+ Args:
924
+ prompt (`str` or `List[str]`, *optional*):
925
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
926
+ instead.
927
+ prompt_2 (`str` or `List[str]`, *optional*):
928
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
929
+ will be used instead
930
+ prompt_3 (`str` or `List[str]`, *optional*):
931
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
932
+ will be used instead
933
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
934
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
935
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
936
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
937
+ num_inference_steps (`int`, *optional*, defaults to 28):
938
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
939
+ expense of slower inference.
940
+ sigmas (`List[float]`, *optional*):
941
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
942
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
943
+ will be used.
944
+ guidance_scale (`float`, *optional*, defaults to 7.0):
945
+ Guidance scale as defined in [Classifier-Free Diffusion
946
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
947
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
948
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
949
+ the text `prompt`, usually at the expense of lower image quality.
950
+ negative_prompt (`str` or `List[str]`, *optional*):
951
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
952
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
953
+ less than `1`).
954
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
955
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
956
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
957
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
958
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
959
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
960
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
961
+ The number of images to generate per prompt.
962
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
963
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
964
+ to make generation deterministic.
965
+ latents (`torch.FloatTensor`, *optional*):
966
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
967
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
968
+ tensor will be generated by sampling using the supplied random `generator`.
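+ cond_latents (`torch.FloatTensor`, *optional*):
+ Conditioning latents forwarded to the transformer as `cond_hidden_states` (specific to this
+ repository's modified SD3 transformer). When classifier-free guidance is enabled and the
+ batch sizes match, they are duplicated alongside the latents.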
969
+ prompt_embeds (`torch.FloatTensor`, *optional*):
970
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
971
+ provided, text embeddings will be generated from `prompt` input argument.
972
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
973
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
974
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
975
+ argument.
976
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
977
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
978
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
979
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
980
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
981
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
982
+ input argument.
983
+ ip_adapter_image (`PipelineImageInput`, *optional*):
984
+ Optional image input to work with IP Adapters.
985
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
986
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
987
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
988
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
989
+ output_type (`str`, *optional*, defaults to `"pil"`):
990
+ The output format of the generated image. Choose between
991
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
992
+ return_dict (`bool`, *optional*, defaults to `True`):
993
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
994
+ a plain tuple.
995
+ joint_attention_kwargs (`dict`, *optional*):
996
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
997
+ `self.processor` in
998
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
999
+ callback_on_step_end (`Callable`, *optional*):
1000
+ A function that is called at the end of each denoising step during inference. The function is called
1001
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1002
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1003
+ `callback_on_step_end_tensor_inputs`.
1004
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1005
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1006
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1007
+ `._callback_tensor_inputs` attribute of your pipeline class.
1008
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
1009
+ skip_guidance_layers (`List[int]`, *optional*):
1010
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
1011
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
1012
+ Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
1013
+ skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
1014
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
1015
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
1016
+ with a scale of `1`.
1017
+ skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in
1018
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
1019
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
1020
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
1021
+ skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in
1022
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
1023
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
1024
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
1025
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
1026
+
1027
+ Examples:
1028
+
1029
+ Returns:
1030
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
1031
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
1032
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1033
+ """
1034
+
1035
+ height = height or self.default_sample_size * self.vae_scale_factor
1036
+ width = width or self.default_sample_size * self.vae_scale_factor
1037
+
1038
+ # 1. Check inputs. Raise error if not correct
1039
+ self.check_inputs(
1040
+ prompt,
1041
+ prompt_2,
1042
+ prompt_3,
1043
+ height,
1044
+ width,
1045
+ negative_prompt=negative_prompt,
1046
+ negative_prompt_2=negative_prompt_2,
1047
+ negative_prompt_3=negative_prompt_3,
1048
+ prompt_embeds=prompt_embeds,
1049
+ negative_prompt_embeds=negative_prompt_embeds,
1050
+ pooled_prompt_embeds=pooled_prompt_embeds,
1051
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1052
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1053
+ max_sequence_length=max_sequence_length,
1054
+ )
1055
+
1056
+ self._guidance_scale = guidance_scale
1057
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
1058
+ self._clip_skip = clip_skip
1059
+ self._joint_attention_kwargs = joint_attention_kwargs
1060
+ self._interrupt = False
1061
+
1062
+ # 2. Define call parameters
1063
+ if prompt is not None and isinstance(prompt, str):
1064
+ batch_size = 1
1065
+ elif prompt is not None and isinstance(prompt, list):
1066
+ batch_size = len(prompt)
1067
+ else:
1068
+ batch_size = prompt_embeds.shape[0]
1069
+
1070
+ device = self._execution_device
1071
+
1072
+ lora_scale = (
1073
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1074
+ )
1075
+ (
1076
+ prompt_embeds,
1077
+ negative_prompt_embeds,
1078
+ pooled_prompt_embeds,
1079
+ negative_pooled_prompt_embeds,
1080
+ ) = self.encode_prompt(
1081
+ prompt=prompt,
1082
+ prompt_2=prompt_2,
1083
+ prompt_3=prompt_3,
1084
+ negative_prompt=negative_prompt,
1085
+ negative_prompt_2=negative_prompt_2,
1086
+ negative_prompt_3=negative_prompt_3,
1087
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1088
+ prompt_embeds=prompt_embeds,
1089
+ negative_prompt_embeds=negative_prompt_embeds,
1090
+ pooled_prompt_embeds=pooled_prompt_embeds,
1091
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1092
+ device=device,
1093
+ clip_skip=self.clip_skip,
1094
+ num_images_per_prompt=num_images_per_prompt,
1095
+ max_sequence_length=max_sequence_length,
1096
+ lora_scale=lora_scale,
1097
+ )
1098
+
1099
+ if self.do_classifier_free_guidance:
1100
+ if skip_guidance_layers is not None:
1101
+ original_prompt_embeds = prompt_embeds
1102
+ original_pooled_prompt_embeds = pooled_prompt_embeds
1103
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1104
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1105
+
1106
+ # 4. Prepare latent variables
1107
+ num_channels_latents = self.transformer.config.in_channels
1108
+ latents = self.prepare_latents(
1109
+ batch_size * num_images_per_prompt,
1110
+ num_channels_latents,
1111
+ height,
1112
+ width,
1113
+ prompt_embeds.dtype,
1114
+ device,
1115
+ generator,
1116
+ latents,
1117
+ )
1118
+
1119
+ # 5. Prepare timesteps
1120
+ scheduler_kwargs = {}
1121
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
1122
+ _, _, height, width = latents.shape
1123
+ image_seq_len = (height // self.transformer.config.patch_size) * (
1124
+ width // self.transformer.config.patch_size
1125
+ )
1126
+ mu = calculate_shift(
1127
+ image_seq_len,
1128
+ self.scheduler.config.get("base_image_seq_len", 256),
1129
+ self.scheduler.config.get("max_image_seq_len", 4096),
1130
+ self.scheduler.config.get("base_shift", 0.5),
1131
+ self.scheduler.config.get("max_shift", 1.16),
1132
+ )
1133
+ scheduler_kwargs["mu"] = mu
1134
+ elif mu is not None:
1135
+ scheduler_kwargs["mu"] = mu
1136
+ timesteps, num_inference_steps = retrieve_timesteps(
1137
+ self.scheduler,
1138
+ num_inference_steps,
1139
+ device,
1140
+ sigmas=sigmas,
1141
+ **scheduler_kwargs,
1142
+ )
1143
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1144
+ self._num_timesteps = len(timesteps)
1145
+
1146
+ # 6. Prepare image embeddings
1147
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
1148
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
1149
+ ip_adapter_image,
1150
+ ip_adapter_image_embeds,
1151
+ device,
1152
+ batch_size * num_images_per_prompt,
1153
+ self.do_classifier_free_guidance,
1154
+ )
1155
+
1156
+ if self.joint_attention_kwargs is None:
1157
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
1158
+ else:
1159
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
1160
+
1161
+ if cond_latents is not None and self.do_classifier_free_guidance:
1162
+ if cond_latents.shape[0] == latents.shape[0]:
1163
+ cond_latents = torch.cat([cond_latents]*2)
1164
+
1165
+ # 7. Denoising loop
1166
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1167
+ for i, t in enumerate(timesteps):
1168
+ if self.interrupt:
1169
+ continue
1170
+
1171
+ # expand the latents if we are doing classifier free guidance
1172
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1173
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1174
+ timestep = t.expand(latent_model_input.shape[0])
1175
+
1176
+ noise_pred = self.transformer(
1177
+ hidden_states=latent_model_input,
1178
+ cond_hidden_states=cond_latents,
1179
+ timestep=timestep,
1180
+ encoder_hidden_states=prompt_embeds,
1181
+ pooled_projections=pooled_prompt_embeds,
1182
+ joint_attention_kwargs=self.joint_attention_kwargs,
1183
+ return_dict=False,
1184
+ )[0]
1185
+
1186
+ # perform guidance
1187
+ if self.do_classifier_free_guidance:
1188
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1189
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1190
+ should_skip_layers = (
1191
+ True
1192
+ if i > num_inference_steps * skip_layer_guidance_start
1193
+ and i < num_inference_steps * skip_layer_guidance_stop
1194
+ else False
1195
+ )
1196
+ if skip_guidance_layers is not None and should_skip_layers:
1197
+ timestep = t.expand(latents.shape[0])
1198
+ latent_model_input = latents
1199
+ noise_pred_skip_layers = self.transformer(
1200
+ hidden_states=latent_model_input,
1201
+ timestep=timestep,
1202
+ encoder_hidden_states=original_prompt_embeds,
1203
+ pooled_projections=original_pooled_prompt_embeds,
1204
+ joint_attention_kwargs=self.joint_attention_kwargs,
1205
+ return_dict=False,
1206
+ skip_layers=skip_guidance_layers,
1207
+ )[0]
1208
+ noise_pred = (
1209
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
1210
+ )
1211
+
1212
+ # compute the previous noisy sample x_t -> x_t-1
1213
+ latents_dtype = latents.dtype
1214
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1215
+
1216
+ if latents.dtype != latents_dtype:
1217
+ if torch.backends.mps.is_available():
1218
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1219
+ latents = latents.to(latents_dtype)
1220
+
1221
+ if callback_on_step_end is not None:
1222
+ callback_kwargs = {}
1223
+ for k in callback_on_step_end_tensor_inputs:
1224
+ callback_kwargs[k] = locals()[k]
1225
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1226
+
1227
+ latents = callback_outputs.pop("latents", latents)
1228
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1229
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1230
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1231
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1232
+ )
1233
+
1234
+ # call the callback, if provided
1235
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1236
+ progress_bar.update()
1237
+
1238
+ if XLA_AVAILABLE:
1239
+ xm.mark_step()
1240
+
1241
+ if output_type == "latent":
1242
+ image = latents
1243
+
1244
+ else:
1245
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1246
+
1247
+ image = self.vae.decode(latents, return_dict=False)[0]
1248
+ image = self.image_processor.postprocess(image, output_type=output_type)
1249
+
1250
+ # Offload all models
1251
+ self.maybe_free_model_hooks()
1252
+
1253
+ if not return_dict:
1254
+ return (image,)
1255
+
1256
+ return StableDiffusion3PipelineOutput(images=image)
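Relative to the stock diffusers pipeline, the `__call__` above threads an optional `cond_latents` tensor into the transformer as `cond_hidden_states`, duplicating it along the batch dimension when classifier-free guidance is active. A minimal usage sketch, assuming the modified class keeps the upstream `StableDiffusion3Pipeline` name exported from this module; the checkpoint id and the random `cond_latents` below are illustrative placeholders:

```py
import torch
from src.models.stable_diffusion3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Placeholder conditioning latents; in practice these would come from encoding a
# conditioning image with pipe.vae and applying the usual scaling/shift.
cond_latents = torch.randn(1, 16, 128, 128, dtype=torch.float16, device="cuda")

image = pipe(
    prompt="A cat holding a sign that says hello world",
    cond_latents=cond_latents,   # forwarded to the transformer as cond_hidden_states
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_conditioned.png")
```

Because the guidance branch concatenates `cond_latents` with itself, the tensor only needs to cover the conditional batch once.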
src/models/stable_diffusion3/pipeline_stable_diffusion_3_dynamic.py ADDED
@@ -0,0 +1,1257 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPTextModelWithProjection,
21
+ CLIPTokenizer,
22
+ SiglipImageProcessor,
23
+ SiglipVisionModel,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from diffusers.models.transformers import SD3Transformer2DModel
32
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_xla_available,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
44
+
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import StableDiffusion3Pipeline
61
+
62
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
63
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
64
+ ... )
65
+ >>> pipe.to("cuda")
66
+ >>> prompt = "A cat holding a sign that says hello world"
67
+ >>> image = pipe(prompt).images[0]
68
+ >>> image.save("sd3.png")
69
+ ```
70
+ """
71
+
72
+
73
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
74
+ def calculate_shift(
75
+ image_seq_len,
76
+ base_seq_len: int = 256,
77
+ max_seq_len: int = 4096,
78
+ base_shift: float = 0.5,
79
+ max_shift: float = 1.15,
80
+ ):
81
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
82
+ b = base_shift - m * base_seq_len
83
+ mu = image_seq_len * m + b
84
+ return mu
85
+
86
+
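`calculate_shift` is just the straight line through the two anchor points `(base_seq_len, base_shift)` and `(max_seq_len, max_shift)`, evaluated at the latent token count. A quick sanity check with the function defaults; the resolutions, VAE factor of 8 and patch size of 2 are illustrative assumptions:

```py
# 1024x1024 -> latent 128x128 -> 64x64 = 4096 patch tokens (assuming VAE /8, patch size 2)
print(round(calculate_shift((1024 // 8 // 2) ** 2), 2))   # 1.15, i.e. hits max_shift
# 512x512   -> latent 64x64  -> 32x32 = 1024 patch tokens
print(round(calculate_shift((512 // 8 // 2) ** 2), 2))    # 0.63
```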
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ device: Optional[Union[str, torch.device]] = None,
92
+ timesteps: Optional[List[int]] = None,
93
+ sigmas: Optional[List[float]] = None,
94
+ **kwargs,
95
+ ):
96
+ r"""
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
+
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
105
+ must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
110
+ `num_inference_steps` and `sigmas` must be `None`.
111
+ sigmas (`List[float]`, *optional*):
112
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
113
+ `num_inference_steps` and `timesteps` must be `None`.
114
+
115
+ Returns:
116
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
117
+ second element is the number of inference steps.
118
+ """
119
+ if timesteps is not None and sigmas is not None:
120
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
123
+ if not accepts_timesteps:
124
+ raise ValueError(
125
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
126
+ f" timestep schedules. Please check whether you are using the correct scheduler."
127
+ )
128
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ num_inference_steps = len(timesteps)
131
+ elif sigmas is not None:
132
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accept_sigmas:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ else:
142
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ return timesteps, num_inference_steps
145
+
146
+
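For reference, a minimal sketch of how this helper is typically driven; the scheduler instance and sigma values are illustrative, and `pipe` is assumed to be a constructed pipeline such as the class defined below:

```py
# 1) plain step count: the scheduler spaces the timesteps itself
timesteps, n_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=28, device="cuda")

# 2) explicit sigmas, valid only if the scheduler's set_timesteps accepts a `sigmas` kwarg
timesteps, n_steps = retrieve_timesteps(pipe.scheduler, device="cuda",
                                        sigmas=[1.0, 0.75, 0.5, 0.25, 0.0])

# 3) passing both `timesteps` and `sigmas` raises a ValueError
```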
147
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
148
+ r"""
149
+ Args:
150
+ transformer ([`SD3Transformer2DModel`]):
151
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
152
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
153
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
154
+ vae ([`AutoencoderKL`]):
155
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
156
+ text_encoder ([`CLIPTextModelWithProjection`]):
157
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
158
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
159
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
160
+ as its dimension.
161
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
162
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
163
+ specifically the
164
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
165
+ variant.
166
+ text_encoder_3 ([`T5EncoderModel`]):
167
+ Frozen text-encoder. Stable Diffusion 3 uses
168
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
169
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
170
+ tokenizer (`CLIPTokenizer`):
171
+ Tokenizer of class
172
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
173
+ tokenizer_2 (`CLIPTokenizer`):
174
+ Second Tokenizer of class
175
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
176
+ tokenizer_3 (`T5TokenizerFast`):
177
+ Tokenizer of class
178
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
179
+ image_encoder (`SiglipVisionModel`, *optional*):
180
+ Pre-trained Vision Model for IP Adapter.
181
+ feature_extractor (`SiglipImageProcessor`, *optional*):
182
+ Image processor for IP Adapter.
183
+ """
184
+
185
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
186
+ _optional_components = ["image_encoder", "feature_extractor"]
187
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
188
+
189
+ def __init__(
190
+ self,
191
+ transformer: SD3Transformer2DModel,
192
+ scheduler: FlowMatchEulerDiscreteScheduler,
193
+ vae: AutoencoderKL,
194
+ text_encoder: CLIPTextModelWithProjection,
195
+ tokenizer: CLIPTokenizer,
196
+ text_encoder_2: CLIPTextModelWithProjection,
197
+ tokenizer_2: CLIPTokenizer,
198
+ text_encoder_3: T5EncoderModel,
199
+ tokenizer_3: T5TokenizerFast,
200
+ image_encoder: SiglipVisionModel = None,
201
+ feature_extractor: SiglipImageProcessor = None,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.register_modules(
206
+ vae=vae,
207
+ text_encoder=text_encoder,
208
+ text_encoder_2=text_encoder_2,
209
+ text_encoder_3=text_encoder_3,
210
+ tokenizer=tokenizer,
211
+ tokenizer_2=tokenizer_2,
212
+ tokenizer_3=tokenizer_3,
213
+ transformer=transformer,
214
+ scheduler=scheduler,
215
+ image_encoder=image_encoder,
216
+ feature_extractor=feature_extractor,
217
+ )
218
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
219
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
220
+ self.tokenizer_max_length = (
221
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
222
+ )
223
+ self.default_sample_size = (
224
+ self.transformer.config.sample_size
225
+ if hasattr(self, "transformer") and self.transformer is not None
226
+ else 128
227
+ )
228
+ self.patch_size = (
229
+ self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
230
+ )
231
+
232
+ def _get_t5_prompt_embeds(
233
+ self,
234
+ prompt: Union[str, List[str]] = None,
235
+ num_images_per_prompt: int = 1,
236
+ max_sequence_length: int = 256,
237
+ device: Optional[torch.device] = None,
238
+ dtype: Optional[torch.dtype] = None,
239
+ ):
240
+ device = device or self._execution_device
241
+ dtype = dtype or self.text_encoder.dtype
242
+
243
+ prompt = [prompt] if isinstance(prompt, str) else prompt
244
+ batch_size = len(prompt)
245
+
246
+ if self.text_encoder_3 is None:
247
+ return torch.zeros(
248
+ (
249
+ batch_size * num_images_per_prompt,
250
+ self.tokenizer_max_length,
251
+ self.transformer.config.joint_attention_dim,
252
+ ),
253
+ device=device,
254
+ dtype=dtype,
255
+ )
256
+
257
+ text_inputs = self.tokenizer_3(
258
+ prompt,
259
+ padding="max_length",
260
+ max_length=max_sequence_length,
261
+ truncation=True,
262
+ add_special_tokens=True,
263
+ return_tensors="pt",
264
+ )
265
+ text_input_ids = text_inputs.input_ids
266
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
267
+
268
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
269
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
270
+ # logger.warning(
271
+ # "The following part of your input was truncated because `max_sequence_length` is set to "
272
+ # f" {max_sequence_length} tokens: {removed_text}"
273
+ # )
274
+
275
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
276
+
277
+ dtype = self.text_encoder_3.dtype
278
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
279
+
280
+ _, seq_len, _ = prompt_embeds.shape
281
+
282
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
283
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
284
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
285
+
286
+ return prompt_embeds
287
+
288
+ def _get_clip_prompt_embeds(
289
+ self,
290
+ prompt: Union[str, List[str]],
291
+ num_images_per_prompt: int = 1,
292
+ device: Optional[torch.device] = None,
293
+ clip_skip: Optional[int] = None,
294
+ clip_model_index: int = 0,
295
+ ):
296
+ device = device or self._execution_device
297
+
298
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
299
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
300
+
301
+ tokenizer = clip_tokenizers[clip_model_index]
302
+ text_encoder = clip_text_encoders[clip_model_index]
303
+
304
+ prompt = [prompt] if isinstance(prompt, str) else prompt
305
+ batch_size = len(prompt)
306
+
307
+ text_inputs = tokenizer(
308
+ prompt,
309
+ padding="max_length",
310
+ max_length=self.tokenizer_max_length,
311
+ truncation=True,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ text_input_ids = text_inputs.input_ids
316
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
318
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
319
+ # logger.warning(
320
+ # "The following part of your input was truncated because CLIP can only handle sequences up to"
321
+ # f" {self.tokenizer_max_length} tokens: {removed_text}"
322
+ # )
323
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
324
+ pooled_prompt_embeds = prompt_embeds[0]
325
+
326
+ if clip_skip is None:
327
+ prompt_embeds = prompt_embeds.hidden_states[-2]
328
+ else:
329
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
330
+
331
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
332
+
333
+ _, seq_len, _ = prompt_embeds.shape
334
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
335
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
336
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
337
+
338
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
340
+
341
+ return prompt_embeds, pooled_prompt_embeds
342
+
343
+ def encode_pooled_prompt(
344
+ self,
345
+ prompt: Union[str, List[str]],
346
+ prompt_2: Union[str, List[str]],
347
+ device: Optional[torch.device] = None,
348
+ num_images_per_prompt: int = 1,
349
+ do_classifier_free_guidance: bool = True,
350
+ negative_prompt: Optional[Union[str, List[str]]] = None,
351
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
352
+ prompt_embeds: Optional[torch.FloatTensor] = None,
353
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
354
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
355
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
356
+ clip_skip: Optional[int] = None,
357
+ lora_scale: Optional[float] = None,
358
+ ):
359
+ device = device or self._execution_device
360
+
361
+ # set lora scale so that monkey patched LoRA
362
+ # function of text encoder can correctly access it
363
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
364
+ self._lora_scale = lora_scale
365
+
366
+ # dynamically adjust the LoRA scale
367
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
368
+ scale_lora_layers(self.text_encoder, lora_scale)
369
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
370
+ scale_lora_layers(self.text_encoder_2, lora_scale)
371
+
372
+ prompt = [prompt] if isinstance(prompt, str) else prompt
373
+ if prompt is not None:
374
+ batch_size = len(prompt)
375
+ else:
376
+ batch_size = prompt_embeds.shape[0]
377
+
378
+ if prompt_embeds is None:
379
+ prompt_2 = prompt_2 or prompt
380
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
381
+
382
+ _, pooled_prompt_embed = self._get_clip_prompt_embeds(
383
+ prompt=prompt,
384
+ device=device,
385
+ num_images_per_prompt=num_images_per_prompt,
386
+ clip_skip=clip_skip,
387
+ clip_model_index=0,
388
+ )
389
+ _, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
390
+ prompt=prompt_2,
391
+ device=device,
392
+ num_images_per_prompt=num_images_per_prompt,
393
+ clip_skip=clip_skip,
394
+ clip_model_index=1,
395
+ )
396
+
397
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
398
+
399
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
400
+ negative_prompt = negative_prompt or ""
401
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
402
+
403
+ # normalize str to list
404
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
405
+ negative_prompt_2 = (
406
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
407
+ )
408
+
409
+
410
+ if prompt is not None and type(prompt) is not type(negative_prompt):
411
+ raise TypeError(
412
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
413
+ f" {type(prompt)}."
414
+ )
415
+ elif batch_size != len(negative_prompt):
416
+ raise ValueError(
417
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
418
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
419
+ " the batch size of `prompt`."
420
+ )
421
+
422
+ _, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
423
+ negative_prompt,
424
+ device=device,
425
+ num_images_per_prompt=num_images_per_prompt,
426
+ clip_skip=None,
427
+ clip_model_index=0,
428
+ )
429
+ _, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
430
+ negative_prompt_2,
431
+ device=device,
432
+ num_images_per_prompt=num_images_per_prompt,
433
+ clip_skip=None,
434
+ clip_model_index=1,
435
+ )
436
+
437
+ negative_pooled_prompt_embeds = torch.cat(
438
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
439
+ )
440
+
441
+ if self.text_encoder is not None:
442
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
443
+ # Retrieve the original scale by scaling back the LoRA layers
444
+ unscale_lora_layers(self.text_encoder, lora_scale)
445
+
446
+ if self.text_encoder_2 is not None:
447
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
448
+ # Retrieve the original scale by scaling back the LoRA layers
449
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
450
+
451
+ return pooled_prompt_embeds, negative_pooled_prompt_embeds
452
+
453
+
454
+ def encode_prompt(
455
+ self,
456
+ prompt: Union[str, List[str]],
457
+ prompt_2: Union[str, List[str]],
458
+ prompt_3: Union[str, List[str]],
459
+ device: Optional[torch.device] = None,
460
+ num_images_per_prompt: int = 1,
461
+ do_classifier_free_guidance: bool = True,
462
+ negative_prompt: Optional[Union[str, List[str]]] = None,
463
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
464
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
465
+ prompt_embeds: Optional[torch.FloatTensor] = None,
466
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
467
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
468
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
469
+ clip_skip: Optional[int] = None,
470
+ max_sequence_length: int = 256,
471
+ lora_scale: Optional[float] = None,
472
+ ):
473
+ r"""
474
+
475
+ Args:
476
+ prompt (`str` or `List[str]`, *optional*):
477
+ prompt to be encoded
478
+ prompt_2 (`str` or `List[str]`, *optional*):
479
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
480
+ used in all text-encoders
481
+ prompt_3 (`str` or `List[str]`, *optional*):
482
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
483
+ used in all text-encoders
484
+ device: (`torch.device`):
485
+ torch device
486
+ num_images_per_prompt (`int`):
487
+ number of images that should be generated per prompt
488
+ do_classifier_free_guidance (`bool`):
489
+ whether to use classifier free guidance or not
490
+ negative_prompt (`str` or `List[str]`, *optional*):
491
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
492
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
493
+ less than `1`).
494
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
495
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
496
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
497
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
498
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
499
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
500
+ prompt_embeds (`torch.FloatTensor`, *optional*):
501
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
502
+ provided, text embeddings will be generated from `prompt` input argument.
503
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
504
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
505
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
506
+ argument.
507
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
508
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
509
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
510
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
511
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
512
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
513
+ input argument.
514
+ clip_skip (`int`, *optional*):
515
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
516
+ the output of the pre-final layer will be used for computing the prompt embeddings.
517
+ lora_scale (`float`, *optional*):
518
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
519
+ """
520
+ device = device or self._execution_device
521
+
522
+ # set lora scale so that monkey patched LoRA
523
+ # function of text encoder can correctly access it
524
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
525
+ self._lora_scale = lora_scale
526
+
527
+ # dynamically adjust the LoRA scale
528
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
529
+ scale_lora_layers(self.text_encoder, lora_scale)
530
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
531
+ scale_lora_layers(self.text_encoder_2, lora_scale)
532
+
533
+ prompt = [prompt] if isinstance(prompt, str) else prompt
534
+ if prompt is not None:
535
+ batch_size = len(prompt)
536
+ else:
537
+ batch_size = prompt_embeds.shape[0]
538
+
539
+ if prompt_embeds is None:
540
+ prompt_2 = prompt_2 or prompt
541
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
542
+
543
+ prompt_3 = prompt_3 or prompt
544
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
545
+
546
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
547
+ prompt=prompt,
548
+ device=device,
549
+ num_images_per_prompt=num_images_per_prompt,
550
+ clip_skip=clip_skip,
551
+ clip_model_index=0,
552
+ )
553
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
554
+ prompt=prompt_2,
555
+ device=device,
556
+ num_images_per_prompt=num_images_per_prompt,
557
+ clip_skip=clip_skip,
558
+ clip_model_index=1,
559
+ )
560
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
561
+
562
+ t5_prompt_embed = self._get_t5_prompt_embeds(
563
+ prompt=prompt_3,
564
+ num_images_per_prompt=num_images_per_prompt,
565
+ max_sequence_length=max_sequence_length,
566
+ device=device,
567
+ )
568
+
569
+ clip_prompt_embeds = torch.nn.functional.pad(
570
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
571
+ )
572
+
573
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
574
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
575
+
576
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
577
+ negative_prompt = negative_prompt or ""
578
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
579
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
580
+
581
+ # normalize str to list
582
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
583
+ negative_prompt_2 = (
584
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
585
+ )
586
+ negative_prompt_3 = (
587
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
588
+ )
589
+
590
+ if prompt is not None and type(prompt) is not type(negative_prompt):
591
+ raise TypeError(
592
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
593
+ f" {type(prompt)}."
594
+ )
595
+ elif batch_size != len(negative_prompt):
596
+ raise ValueError(
597
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
598
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
599
+ " the batch size of `prompt`."
600
+ )
601
+
602
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
603
+ negative_prompt,
604
+ device=device,
605
+ num_images_per_prompt=num_images_per_prompt,
606
+ clip_skip=None,
607
+ clip_model_index=0,
608
+ )
609
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
610
+ negative_prompt_2,
611
+ device=device,
612
+ num_images_per_prompt=num_images_per_prompt,
613
+ clip_skip=None,
614
+ clip_model_index=1,
615
+ )
616
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
617
+
618
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
619
+ prompt=negative_prompt_3,
620
+ num_images_per_prompt=num_images_per_prompt,
621
+ max_sequence_length=max_sequence_length,
622
+ device=device,
623
+ )
624
+
625
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
626
+ negative_clip_prompt_embeds,
627
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
628
+ )
629
+
630
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
631
+ negative_pooled_prompt_embeds = torch.cat(
632
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
633
+ )
634
+
635
+ if self.text_encoder is not None:
636
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
637
+ # Retrieve the original scale by scaling back the LoRA layers
638
+ unscale_lora_layers(self.text_encoder, lora_scale)
639
+
640
+ if self.text_encoder_2 is not None:
641
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
642
+ # Retrieve the original scale by scaling back the LoRA layers
643
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
644
+
645
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
646
+
647
+ def check_inputs(
648
+ self,
649
+ prompt,
650
+ prompt_2,
651
+ prompt_3,
652
+ height,
653
+ width,
654
+ negative_prompt=None,
655
+ negative_prompt_2=None,
656
+ negative_prompt_3=None,
657
+ prompt_embeds=None,
658
+ negative_prompt_embeds=None,
659
+ pooled_prompt_embeds=None,
660
+ negative_pooled_prompt_embeds=None,
661
+ callback_on_step_end_tensor_inputs=None,
662
+ max_sequence_length=None,
663
+ ):
664
+ if (
665
+ height % (self.vae_scale_factor * self.patch_size) != 0
666
+ or width % (self.vae_scale_factor * self.patch_size) != 0
667
+ ):
668
+ raise ValueError(
669
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}."
670
+ f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
671
+ )
672
+
673
+ if callback_on_step_end_tensor_inputs is not None and not all(
674
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
675
+ ):
676
+ raise ValueError(
677
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
678
+ )
679
+
680
+ if prompt is not None and prompt_embeds is not None:
681
+ raise ValueError(
682
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
683
+ " only forward one of the two."
684
+ )
685
+ elif prompt_2 is not None and prompt_embeds is not None:
686
+ raise ValueError(
687
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
688
+ " only forward one of the two."
689
+ )
690
+ elif prompt_3 is not None and prompt_embeds is not None:
691
+ raise ValueError(
692
+ f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
693
+ " only forward one of the two."
694
+ )
695
+ elif prompt is None and prompt_embeds is None:
696
+ raise ValueError(
697
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
698
+ )
699
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
700
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
701
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
702
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
703
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
704
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
705
+
706
+ if negative_prompt is not None and negative_prompt_embeds is not None:
707
+ raise ValueError(
708
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
709
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
710
+ )
711
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
712
+ raise ValueError(
713
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
714
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
715
+ )
716
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
717
+ raise ValueError(
718
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
719
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
720
+ )
721
+
722
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
723
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
724
+ raise ValueError(
725
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
726
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
727
+ f" {negative_prompt_embeds.shape}."
728
+ )
729
+
730
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
731
+ raise ValueError(
732
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
733
+ )
734
+
735
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
736
+ raise ValueError(
737
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
738
+ )
739
+
740
+ if max_sequence_length is not None and max_sequence_length > 512:
741
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
742
+
743
+ def prepare_latents(
744
+ self,
745
+ batch_size,
746
+ num_channels_latents,
747
+ height,
748
+ width,
749
+ dtype,
750
+ device,
751
+ generator,
752
+ latents=None,
753
+ ):
754
+ if latents is not None:
755
+ return latents.to(device=device, dtype=dtype)
756
+
757
+ shape = (
758
+ batch_size,
759
+ num_channels_latents,
760
+ int(height) // self.vae_scale_factor,
761
+ int(width) // self.vae_scale_factor,
762
+ )
763
+
764
+ if isinstance(generator, list) and len(generator) != batch_size:
765
+ raise ValueError(
766
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
767
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
768
+ )
769
+
770
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
771
+
772
+ return latents
773
+
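The sampled latent grid is simply the requested pixel size divided by the VAE scale factor. As a sketch, assuming the usual SD3 configuration of 16 latent channels and a VAE factor of 8, a single 1024x1024 image per prompt gives:

```py
batch_size, num_channels_latents = 1, 16   # 16 latent channels assumed (stock SD3 value)
height = width = 1024
vae_scale_factor = 8
shape = (batch_size, num_channels_latents,
         height // vae_scale_factor, width // vae_scale_factor)
print(shape)                               # (1, 16, 128, 128)
```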
774
+ @property
775
+ def guidance_scale(self):
776
+ return self._guidance_scale
777
+
778
+ @property
779
+ def skip_guidance_layers(self):
780
+ return self._skip_guidance_layers
781
+
782
+ @property
783
+ def clip_skip(self):
784
+ return self._clip_skip
785
+
786
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
787
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
788
+ # corresponds to doing no classifier free guidance.
789
+ @property
790
+ def do_classifier_free_guidance(self):
791
+ return self._guidance_scale > 1
792
+
793
+ @property
794
+ def joint_attention_kwargs(self):
795
+ return self._joint_attention_kwargs
796
+
797
+ @property
798
+ def num_timesteps(self):
799
+ return self._num_timesteps
800
+
801
+ @property
802
+ def interrupt(self):
803
+ return self._interrupt
804
+
805
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
806
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
807
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
808
+
809
+ Args:
810
+ image (`PipelineImageInput`):
811
+ Input image to be encoded.
812
+ device: (`torch.device`):
813
+ Torch device.
814
+
815
+ Returns:
816
+ `torch.Tensor`: The encoded image feature representation.
817
+ """
818
+ if not isinstance(image, torch.Tensor):
819
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
820
+
821
+ image = image.to(device=device, dtype=self.dtype)
822
+
823
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
824
+
825
+ # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
826
+ def prepare_ip_adapter_image_embeds(
827
+ self,
828
+ ip_adapter_image: Optional[PipelineImageInput] = None,
829
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
830
+ device: Optional[torch.device] = None,
831
+ num_images_per_prompt: int = 1,
832
+ do_classifier_free_guidance: bool = True,
833
+ ) -> torch.Tensor:
834
+ """Prepares image embeddings for use in the IP-Adapter.
835
+
836
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
837
+
838
+ Args:
839
+ ip_adapter_image (`PipelineImageInput`, *optional*):
840
+ The input image to extract features from for IP-Adapter.
841
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
842
+ Precomputed image embeddings.
843
+ device: (`torch.device`, *optional*):
844
+ Torch device.
845
+ num_images_per_prompt (`int`, defaults to 1):
846
+ Number of images that should be generated per prompt.
847
+ do_classifier_free_guidance (`bool`, defaults to True):
848
+ Whether to use classifier free guidance or not.
849
+ """
850
+ device = device or self._execution_device
851
+
852
+ if ip_adapter_image_embeds is not None:
853
+ if do_classifier_free_guidance:
854
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
855
+ else:
856
+ single_image_embeds = ip_adapter_image_embeds
857
+ elif ip_adapter_image is not None:
858
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
859
+ if do_classifier_free_guidance:
860
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
861
+ else:
862
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
863
+
864
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
865
+
866
+ if do_classifier_free_guidance:
867
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
868
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
869
+
870
+ return image_embeds.to(device=device)
871
+
872
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
873
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
874
+ logger.warning(
875
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
876
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
877
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
878
+ )
879
+
880
+ super().enable_sequential_cpu_offload(*args, **kwargs)
881
+
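Following the warning above, the image encoder can be excluded from sequential offloading before enabling it; a short sketch, with `pipe` assumed to be an instance of this pipeline:

```py
# Keep the SigLIP image encoder resident on the GPU while the other modules are offloaded.
pipe._exclude_from_cpu_offload.append("image_encoder")
pipe.enable_sequential_cpu_offload()
```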
882
+ @torch.no_grad()
883
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
884
+ def __call__(
885
+ self,
886
+ prompt: Union[str, List[str]] = None,
887
+ prompt_2: Optional[Union[str, List[str]]] = None,
888
+ prompt_3: Optional[Union[str, List[str]]] = None,
889
+ height: Optional[int] = None,
890
+ width: Optional[int] = None,
891
+ num_inference_steps: int = 28,
892
+ sigmas: Optional[List[float]] = None,
893
+ guidance_scale: float = 7.0,
894
+ negative_prompt: Optional[Union[str, List[str]]] = None,
895
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
896
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
897
+ num_images_per_prompt: Optional[int] = 1,
898
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
899
+ latents: Optional[torch.FloatTensor] = None,
900
+ cond_latents: Optional[list[torch.FloatTensor]] = None,
901
+ prompt_embeds: Optional[torch.FloatTensor] = None,
902
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
903
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
904
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
905
+ ip_adapter_image: Optional[PipelineImageInput] = None,
906
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
907
+ output_type: Optional[str] = "pil",
908
+ return_dict: bool = True,
909
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
910
+ clip_skip: Optional[int] = None,
911
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
912
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
913
+ max_sequence_length: int = 256,
914
+ skip_guidance_layers: List[int] = None,
915
+ skip_layer_guidance_scale: float = 2.8,
916
+ skip_layer_guidance_stop: float = 0.2,
917
+ skip_layer_guidance_start: float = 0.01,
918
+ mu: Optional[float] = None,
919
+ ):
920
+ r"""
921
+ Function invoked when calling the pipeline for generation.
922
+
923
+ Args:
924
+ prompt (`str` or `List[str]`, *optional*):
925
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
926
+ instead.
927
+ prompt_2 (`str` or `List[str]`, *optional*):
928
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
929
+ will be used instead
930
+ prompt_3 (`str` or `List[str]`, *optional*):
931
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
932
+ will be used instead
933
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
934
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
935
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
936
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
937
+ num_inference_steps (`int`, *optional*, defaults to 28):
938
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
939
+ expense of slower inference.
940
+ sigmas (`List[float]`, *optional*):
941
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
942
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
943
+ will be used.
944
+ guidance_scale (`float`, *optional*, defaults to 7.0):
945
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
946
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
947
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
948
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
949
+ usually at the expense of lower image quality.
950
+ negative_prompt (`str` or `List[str]`, *optional*):
951
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
952
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
953
+ less than `1`).
954
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
955
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
956
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
957
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
958
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
959
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
960
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
961
+ The number of images to generate per prompt.
962
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
963
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
964
+ to make generation deterministic.
965
+ latents (`torch.FloatTensor`, *optional*):
966
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
967
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
968
+ tensor will be generated by sampling using the supplied random `generator`.
969
+ prompt_embeds (`torch.FloatTensor`, *optional*):
970
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
971
+ provided, text embeddings will be generated from `prompt` input argument.
972
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
973
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
974
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
975
+ argument.
976
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
977
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
978
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
979
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
980
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
981
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
982
+ input argument.
983
+ ip_adapter_image (`PipelineImageInput`, *optional*):
984
+ Optional image input to work with IP Adapters.
985
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
986
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
987
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
988
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
989
+ output_type (`str`, *optional*, defaults to `"pil"`):
990
+ The output format of the generated image. Choose between
991
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
992
+ return_dict (`bool`, *optional*, defaults to `True`):
993
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
994
+ a plain tuple.
995
+ joint_attention_kwargs (`dict`, *optional*):
996
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
997
+ `self.processor` in
998
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
999
+ callback_on_step_end (`Callable`, *optional*):
1000
+ A function that is called at the end of each denoising step during inference. The function is called
1001
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1002
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1003
+ `callback_on_step_end_tensor_inputs`.
1004
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1005
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1006
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1007
+ `._callback_tensor_inputs` attribute of your pipeline class.
1008
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
1009
+ skip_guidance_layers (`List[int]`, *optional*):
1010
+ A list of integers that specify layers to skip during guidance. If not provided, all layers will be
1011
+ used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
1012
+ Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
1013
+ skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in
1014
+ `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
1015
+ with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
1016
+ with a scale of `1`.
1017
+ skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in
1018
+ `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
1019
+ `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
1020
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
1021
+ skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in
1022
+ `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
1023
+ `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
1024
+ StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
1025
+ mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
1026
+
1027
+ Examples:
1028
+
1029
+ Returns:
1030
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
1031
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
1032
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1033
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._skip_layer_guidance_scale = skip_layer_guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
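+ # Note (editorial): in SD3 `prompt_embeds` are the token-level embeddings built from the CLIP and T5
+ # text encoders, while `pooled_prompt_embeds` are the pooled CLIP vectors used as a global conditioning
+ # signal; the `negative_*` counterparts are only consumed when classifier-free guidance is enabled.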
+
+ if self.do_classifier_free_guidance:
+ if skip_guidance_layers is not None:
+ original_prompt_embeds = prompt_embeds
+ original_pooled_prompt_embeds = pooled_prompt_embeds
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
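+ # Note (editorial): `prepare_latents` works in latent space, so for a hypothetical 1024x1024 request with
+ # a VAE downsampling factor of 8 the tensor has spatial size 128x128, i.e. roughly
+ # (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8).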
+
+ # 5. Prepare timesteps
+ scheduler_kwargs = {}
+ if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+ _, _, height, width = latents.shape
+ image_seq_len = (height // self.transformer.config.patch_size) * (
+ width // self.transformer.config.patch_size
+ )
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.16),
+ )
+ scheduler_kwargs["mu"] = mu
+ elif mu is not None:
+ scheduler_kwargs["mu"] = mu
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
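+ # Note (editorial): `cond_latents` are the extra conditional latents consumed by this dynamic pipeline;
+ # assuming they are provided as a per-sample sequence, repeating them below keeps them aligned with the
+ # classifier-free-guidance batch, whose latents are duplicated into [negative, positive] halves in the loop.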
+ if cond_latents is not None and self.do_classifier_free_guidance:
+ if len(cond_latents) == latents.shape[0]:
+ cond_latents = cond_latents * 2
+
+ # 7. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier-free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ cond_hidden_states=cond_latents,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ should_skip_layers = (
+ True
+ if i > num_inference_steps * skip_layer_guidance_start
+ and i < num_inference_steps * skip_layer_guidance_stop
+ else False
+ )
+ if skip_guidance_layers is not None and should_skip_layers:
+ timestep = t.expand(latents.shape[0])
+ latent_model_input = latents
+ noise_pred_skip_layers = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=original_prompt_embeds,
+ pooled_projections=original_pooled_prompt_embeds,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ skip_layers=skip_guidance_layers,
+ )[0]
+ noise_pred = (
+ noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
+ )
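+ # Note (editorial): the update above composes classifier-free guidance with skip-layer guidance,
+ # i.e. noise_pred = uncond + guidance_scale * (text - uncond)
+ #                          + skip_layer_guidance_scale * (text - text_with_skipped_layers),
+ # where the second correction is only active between skip_layer_guidance_start and skip_layer_guidance_stop.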
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ # update the progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
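+ # Note (editorial): this inverts the encoding-time normalization z = (x - shift_factor) * scaling_factor
+ # used by the SD3 VAE, recovering the latent expected by `vae.decode`.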
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusion3PipelineOutput(images=image)