wusize committed (verified)
Commit 1a2a9f7
1 parent: 550aa5d

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the rest.
Files changed (50):
  1. .gitattributes +1 -0
  2. .idea/.gitignore +8 -0
  3. .idea/Puffin.iml +12 -0
  4. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  5. .idea/misc.xml +4 -0
  6. .idea/modules.xml +8 -0
  7. .idea/workspace.xml +50 -0
  8. LICENSE +35 -0
  9. README.md +4 -4
  10. app.py +248 -0
  11. assets/1.jpg +3 -0
  12. assets/2.jpg +3 -0
  13. assets/3.jpg +3 -0
  14. assets/4.jpg +3 -0
  15. assets/5.jpg +3 -0
  16. assets/6.jpg +3 -0
  17. configs/models/qwen2_5_1_5b_radio_sd3_dynamic_puffin.py +87 -0
  18. configs/pipelines/stage_2_base.py +10 -0
  19. configs/pipelines/stage_3_thinking.py +11 -0
  20. configs/pipelines/stage_4_instruction_tuning.py +9 -0
  21. configs/qwen2.5/config.json +27 -0
  22. configs/qwen2.5/generation_config.json +14 -0
  23. configs/qwen2.5/tokenizer.json +0 -0
  24. configs/qwen2.5/tokenizer_config.json +207 -0
  25. configs/qwen2.5/vocab.json +0 -0
  26. configs/radio3/config.json +241 -0
  27. configs/sd3/scheduler/scheduler_config.json +6 -0
  28. configs/sd3/transformer/config.json +15 -0
  29. configs/sd3/vae/config.json +36 -0
  30. requirements.txt +17 -0
  31. scripts/camera/cam_dataset.py +107 -0
  32. scripts/camera/geometry/__init__.py +0 -0
  33. scripts/camera/geometry/base_camera.py +518 -0
  34. scripts/camera/geometry/camera.py +281 -0
  35. scripts/camera/geometry/gravity.py +129 -0
  36. scripts/camera/geometry/jacobians.py +63 -0
  37. scripts/camera/geometry/manifolds.py +113 -0
  38. scripts/camera/geometry/perspective_fields.py +379 -0
  39. scripts/camera/utils/conversions.py +150 -0
  40. scripts/camera/utils/image.py +182 -0
  41. scripts/camera/utils/tensor.py +249 -0
  42. scripts/camera/utils/text.py +47 -0
  43. scripts/camera/visualization/visualize_batch.py +188 -0
  44. scripts/camera/visualization/viz2d.py +521 -0
  45. src/datasets/utils.py +162 -0
  46. src/models/connector/__init__.py +2 -0
  47. src/models/connector/configuration_connector.py +27 -0
  48. src/models/connector/modeling_connector.py +507 -0
  49. src/models/connector/modeling_qwen2.py +50 -0
  50. src/models/puffin/model.py +790 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/Puffin.iml ADDED
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="$USER_HOME$/envs/pt2.7" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="$USER_HOME$/envs/pt2.7" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/Puffin.iml" filepath="$PROJECT_DIR$/.idea/Puffin.iml" />
    </modules>
  </component>
</project>
.idea/workspace.xml ADDED
@@ -0,0 +1,50 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ChangeListManager">
4
+ <list default="true" id="9dd87dac-8a5e-4178-a1d7-afa664ac2f6a" name="Changes" comment="" />
5
+ <option name="SHOW_DIALOG" value="false" />
6
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
7
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
8
+ <option name="LAST_RESOLUTION" value="IGNORE" />
9
+ </component>
10
+ <component name="ProjectColorInfo"><![CDATA[{
11
+ "associatedIndex": 6
12
+ }]]></component>
13
+ <component name="ProjectId" id="33qty5NdkHzw3ffLzNseYxoODo4" />
14
+ <component name="ProjectViewState">
15
+ <option name="hideEmptyMiddlePackages" value="true" />
16
+ <option name="showLibraryContents" value="true" />
17
+ </component>
18
+ <component name="PropertiesComponent"><![CDATA[{
19
+ "keyToString": {
20
+ "ModuleVcsDetector.initialDetectionPerformed": "true",
21
+ "RunOnceActivity.ShowReadmeOnStart": "true",
22
+ "last_opened_file_path": "/Users/wusize/projects/Puffin",
23
+ "nodejs_package_manager_path": "npm",
24
+ "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable",
25
+ "vue.rearranger.settings.migration": "true"
26
+ }
27
+ }]]></component>
28
+ <component name="SharedIndexes">
29
+ <attachedChunks>
30
+ <set>
31
+ <option value="bundled-js-predefined-d6986cc7102b-6a121458b545-JavaScript-PY-251.25410.159" />
32
+ <option value="bundled-python-sdk-e0ed3721d81e-36ea0e71a18c-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-251.25410.159" />
33
+ </set>
34
+ </attachedChunks>
35
+ </component>
36
+ <component name="TaskManager">
37
+ <task active="true" id="Default" summary="Default task">
38
+ <changelist id="9dd87dac-8a5e-4178-a1d7-afa664ac2f6a" name="Changes" comment="" />
39
+ <created>1760056680813</created>
40
+ <option name="number" value="Default" />
41
+ <option name="presentableId" value="Default" />
42
+ <updated>1760056680813</updated>
43
+ <workItem from="1760056681869" duration="11000" />
44
+ </task>
45
+ <servers />
46
+ </component>
47
+ <component name="TypeScriptGeneratedFilesManager">
48
+ <option name="version" value="3" />
49
+ </component>
50
+ </project>
LICENSE ADDED
@@ -0,0 +1,35 @@
S-Lab License 1.0

Copyright 2025 S-Lab

Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in
   the documentation and/or other materials provided with the
   distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
 title: Puffin
-emoji: 📚
-colorFrom: gray
-colorTo: red
+emoji: 👀
+colorFrom: red
+colorTo: purple
 sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 5.23.1
 app_file: app.py
 pinned: false
 ---
app.py ADDED
@@ -0,0 +1,248 @@
import gradio as gr
import torch
import io
from PIL import Image
import numpy as np
import spaces  # Import spaces for ZeroGPU compatibility
import math
import re
from einops import rearrange
from mmengine.config import Config
from xtuner.registry import BUILDER

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from scripts.camera.cam_dataset import Cam_Generator
from scripts.camera.visualization.visualize_batch import make_perspective_figures

from huggingface_hub import snapshot_download
import os

# Download the Puffin checkpoint from the Hub into checkpoints/
local_path = snapshot_download(
    repo_id="KangLiao/Puffin",
    repo_type="model",
    # filename="Puffin-Base.pth",
    local_dir="checkpoints/",
    local_dir_use_symlinks=False,
    revision="main",
)


NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(r"(?:camera parameters.*?:|roll.*?:)\s*("+NUM+r")\s*,\s*("+NUM+r")\s*,\s*("+NUM+r")", re.IGNORECASE|re.DOTALL)


def center_crop(image):
    w, h = image.size
    s = min(w, h)
    l = (w - s) // 2
    t = (h - s) // 2
    return image.crop((l, t, l + s, t + s))


##### load model
config = "configs/pipelines/stage_2_base.py"
config = Config.fromfile(config)
model = BUILDER.build(config.model).cuda().bfloat16().eval()
checkpoint_path = "checkpoints/Puffin-Base.pth"
checkpoint = torch.load(checkpoint_path)
info = model.load_state_dict(checkpoint, strict=False)


def fig_to_image(fig):
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    img = Image.open(buf).convert('RGB')
    buf.close()
    return img


def extract_up_lat_figs(fig_dict):
    fig_up, fig_lat = None, None
    others = {}
    for k, fig in fig_dict.items():
        if ("up_field" in k) and (fig_up is None):
            fig_up = fig
        elif ("latitude_field" in k) and (fig_lat is None):
            fig_lat = fig
        else:
            others[k] = fig
    return fig_up, fig_lat, others


# Multimodal understanding function
@torch.inference_mode()
@spaces.GPU(duration=120)
def camera_understanding(image_src, seed, progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # set seed
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # torch.cuda.manual_seed(seed)
    print(torch.cuda.is_available())

    prompt = ("Describe the image in detail. Then reason its spatial distribution and estimate its camera parameters (roll, pitch, and field-of-view).")

    image = Image.fromarray(image_src).convert('RGB')
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, 'h w c -> c h w')

    with torch.no_grad():
        outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)

    text = outputs[0]

    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    bgr = np.array(image)[:, :, ::-1].astype(np.float32) / 255.0
    rgb = bgr[:, :, ::-1].copy()
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    single_batch = {}
    single_batch["image"] = image_tensor
    single_batch["up_field"] = cam[:2].unsqueeze(0)
    single_batch["latitude_field"] = cam[2:].unsqueeze(0)

    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
    up_img = lat_img = None
    for k, fig in figs.items():
        if "up_field" in k:
            up_img = fig_to_image(fig)
        elif "latitude_field" in k:
            lat_img = fig_to_image(fig)
        plt.close(fig)

    return text  # , up_img, lat_img


@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt_scene,
                   seed=42,
                   roll=0.1,
                   pitch=0.1,
                   fov=1.0,
                   progress=gr.Progress(track_tqdm=True)):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    # if seed is not None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    print(torch.cuda.is_available())

    generator = torch.Generator().manual_seed(seed)
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)

    prompt = prompt_scene + " " + prompt_camera
    print("prompt:", prompt)

    bsz = 4
    with torch.no_grad():
        images, output_reasoning = model.generate(
            prompt=[prompt]*bsz,
            cfg_prompt=[""]*bsz,
            pixel_values_init=None,
            cfg_scale=4.5,
            num_steps=50,
            cam_values=[[cam_map]]*bsz,
            progress_bar=False,
            reasoning=False,
            prompt_reasoning=[""]*bsz,
            generator=generator,
            height=512,
            width=512
        )

    images = rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    ret_images = [Image.fromarray(image) for image in images]
    return ret_images


# Gradio interface
css = '''
.gradio-container {max-width: 960px !important}
'''
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")

        prompt_input = gr.Textbox(label="Prompt.")

        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                # camera parameters are given in radians (0.7854 rad ~ 45 deg, 1.8326 rad ~ 105 deg)
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1000, step=0.1000, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1000, step=0.1000, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5000, step=0.1000, label="fov value")
            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=42)

        generation_button = gr.Button("Generate Images")

        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples.",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn, light-yellow walls room with herringbone terracotta floors and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Camera Understanding"):
        gr.Markdown(value="## Camera Understanding")
        image_input = gr.Image()

        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")

        # camera1 = gr.Gallery(label="Camera Maps", columns=1, rows=1)
        # camera2 = gr.Gallery(label="Camera Maps", columns=1, rows=1)

        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)

        examples_inpainting = gr.Examples(
            label="Camera Understanding examples",
            examples=[
                "assets/1.jpg",
                "assets/2.jpg",
                "assets/3.jpg",
                "assets/4.jpg",
                "assets/5.jpg",
                "assets/6.jpg",
            ],
            inputs=image_input,
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output
    )

    understanding_button.click(
        camera_understanding,
        inputs=[image_input, und_seed_input],
        outputs=[understanding_output]  # , camera1, camera2]
    )

demo.launch(share=True)
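For reference, the CAM_PATTERN regular expression defined near the top of app.py pulls the three numeric camera parameters out of a model response. A minimal, self-contained check of that pattern (the sample response string is made up for illustration):

import re

NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)

# Hypothetical response text; real model outputs are longer.
text = "The camera parameters (roll, pitch, and field-of-view) are: 0.1000, -0.1000, 1.5000."
roll, pitch, fov = (float(g) for g in CAM_PATTERN.search(text).groups())
print(roll, pitch, fov)  # 0.1 -0.1 1.5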
assets/1.jpg ADDED

Git LFS Details

  • SHA256: 6293913ecf6462b8aeb11516568e20b117fdf8498db71f13c02afc05e2c04a4a
  • Pointer size: 131 Bytes
  • Size of remote file: 172 kB
assets/2.jpg ADDED

Git LFS Details

  • SHA256: fb208f03a5d544b2aa580926ed5957ef275cb694c2070cd2eff9d84cdb788a5a
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB
assets/3.jpg ADDED

Git LFS Details

  • SHA256: beba586a41f6bea1b5b14aa8be8b6566cc8e14545ebe6067a79a8cfa8e100da9
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
assets/4.jpg ADDED

Git LFS Details

  • SHA256: 2bd1447dcf4ca1ffa212a09c53f5ff1495de37818b7800e80f9125f463ba9a2f
  • Pointer size: 131 Bytes
  • Size of remote file: 109 kB
assets/5.jpg ADDED

Git LFS Details

  • SHA256: 7708e04ad0115344a3b1c166a32b9903809756c73d6a02d93f8fd2e8b668a758
  • Pointer size: 131 Bytes
  • Size of remote file: 124 kB
assets/6.jpg ADDED

Git LFS Details

  • SHA256: 24613661b0d0d297dd6c9617ea5dcd70546a3df1406249ffef0cbd649830daf4
  • Pointer size: 131 Bytes
  • Size of remote file: 161 kB
configs/models/qwen2_5_1_5b_radio_sd3_dynamic_puffin.py ADDED
@@ -0,0 +1,87 @@
import torch
import os
from src.models.puffin.model import Qwen2p5RadioStableDiffusion3HFDynamic
from src.models.stable_diffusion3.transformer_sd3_dynamic import SD3Transformer2DModel
from src.models.radiov3.hf_model import RADIOModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from transformers import AutoModelForCausalLM, AutoTokenizer

llm_name_or_path = 'Qwen/Qwen2.5-1.5B-Instruct'
sd3_model_name_or_path = "configs/sd3"
radiov3_model_name_or_path = "configs/radiov3"

prompt_template = dict(
    SYSTEM=('<|im_start|>system\n{system}<|im_end|>\n'),
    INSTRUCTION=('<|im_start|>user\n{input}<|im_end|>\n'
                 '<|im_start|>assistant\n'),
    SUFFIX='<|im_end|>',
    IMG_START_TOKEN='<|vision_start|>',
    IMG_END_TOKEN='<|vision_end|>',
    IMG_CONTEXT_TOKEN='<|image_pad|>',
    GENERATION='Generate an image: {input}',
    GENERATION_CROSS='Generate a target image given an initial view: {input}',
    SUFFIX_AS_EOS=True,
    SEP='\n',
    STOP_WORDS=['<|im_end|>', '<|endoftext|>']
)

model = dict(
    type=Qwen2p5RadioStableDiffusion3HFDynamic,
    num_queries=64,
    connector_1=dict(
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=6,
        # _attn_implementation='flash_attention_2',
        num_attention_heads=16,
    ),
    connector_2=dict(
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=6,
        # _attn_implementation='flash_attention_2',
        num_attention_heads=16,
    ),
    transformer=dict(
        type=SD3Transformer2DModel.from_config,
        pretrained_model_name_or_path=sd3_model_name_or_path,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        # local_files_only=True,
    ),
    test_scheduler=dict(
        type=FlowMatchEulerDiscreteScheduler.from_config,
        pretrained_model_name_or_path=sd3_model_name_or_path,
        subfolder="scheduler",
        # local_files_only=True,
    ),
    train_scheduler=dict(
        type=FlowMatchEulerDiscreteScheduler.from_config,
        pretrained_model_name_or_path=sd3_model_name_or_path,
        subfolder="scheduler",
        # local_files_only=True,
    ),
    vae=dict(
        type=AutoencoderKL.from_config,
        pretrained_model_name_or_path=sd3_model_name_or_path,
        subfolder="vae",
        torch_dtype=torch.bfloat16,
        # local_files_only=True,
    ),
    freeze_visual_encoder=True,
    freeze_llm=True,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        torch_dtype=torch.bfloat16,
        # attn_implementation='flash_attention_2',
    ),
    tokenizer=dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path),
    prompt_template=prompt_template,
    pretrained_pth=None,
    use_activation_checkpointing=False,
    visual_encoder=dict(
        type=RADIOModel.from_pretrained,
        pretrained_model_name_or_path="nvidia/C-RADIOv3-H",
        torch_dtype=torch.bfloat16,
    ),
)
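As a quick illustration of how an xtuner-style template like the one above turns a user query into the Qwen2.5 chat format (the exact call site inside Puffin is not shown in this commit, so treat the rendering below as an assumption):

# Hypothetical rendering of the INSTRUCTION template for a single-turn query.
prompt_template = {
    "SYSTEM": "<|im_start|>system\n{system}<|im_end|>\n",
    "INSTRUCTION": "<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n",
    "SUFFIX": "<|im_end|>",
}
question = "Estimate the camera parameters of this image."
prompt = prompt_template["INSTRUCTION"].format(input=question)
print(prompt)  # the assistant's reply is then terminated by SUFFIX ("<|im_end|>")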
configs/pipelines/stage_2_base.py ADDED
@@ -0,0 +1,10 @@
from mmengine.config import read_base

with read_base():
    from ..models.qwen2_5_1_5b_radio_sd3_dynamic_puffin import model

model.freeze_visual_encoder = False
model.freeze_llm = False
model.freeze_transformer = False
model.use_activation_checkpointing = True
model.visual_encoder_grad_scale = 0.1
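These pipeline files are plain mmengine configs; app.py builds the demo model from this exact file. A minimal loading sketch mirroring app.py (the checkpoint loading step is omitted here):

from mmengine.config import Config
from xtuner.registry import BUILDER

cfg = Config.fromfile("configs/pipelines/stage_2_base.py")
model = BUILDER.build(cfg.model)   # instantiates Qwen2p5RadioStableDiffusion3HFDynamic
print(cfg.model.freeze_llm)        # False: overridden above for stage-2 training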
configs/pipelines/stage_3_thinking.py ADDED
@@ -0,0 +1,11 @@
from mmengine.config import read_base

with read_base():
    from ..models.qwen2_5_1_5b_radio_sd3_dynamic_puffin import model

model.freeze_visual_encoder = False
model.freeze_llm = False
model.freeze_transformer = False
model.use_activation_checkpointing = True
model.visual_encoder_grad_scale = 0.1
# model.pretrained_pth = 'work_dirs/stage_2_base/iter_30000.pth'
configs/pipelines/stage_4_instruction_tuning.py ADDED
@@ -0,0 +1,9 @@
from mmengine.config import read_base
with read_base():
    from ..models.qwen2_5_1_5b_radio_sd3_dynamic_puffin import model

model.freeze_visual_encoder = True
model.freeze_llm = False
model.freeze_transformer = False
model.use_activation_checkpointing = True
model.unconditional_cross_view = 0.1
configs/qwen2.5/config.json ADDED
@@ -0,0 +1,27 @@
{
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151645,
  "hidden_act": "silu",
  "hidden_size": 1536,
  "initializer_range": 0.02,
  "intermediate_size": 8960,
  "max_position_embeddings": 32768,
  "max_window_layers": 21,
  "model_type": "qwen2",
  "num_attention_heads": 12,
  "num_hidden_layers": 28,
  "num_key_value_heads": 2,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 32768,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.43.1",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 151936
}
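The attention geometry implied by this config, worked out explicitly (standard Qwen2 conventions, nothing Puffin-specific):

hidden_size = 1536
num_attention_heads = 12
num_key_value_heads = 2
num_hidden_layers = 28

head_dim = hidden_size // num_attention_heads            # 128
gqa_group = num_attention_heads // num_key_value_heads   # 6 query heads share each KV head
kv_values_per_token = 2 * num_key_value_heads * head_dim * num_hidden_layers  # 14336 cached values per token
print(head_dim, gqa_group, kv_values_per_token)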
configs/qwen2.5/generation_config.json ADDED
@@ -0,0 +1,14 @@
{
  "bos_token_id": 151643,
  "pad_token_id": 151643,
  "do_sample": true,
  "eos_token_id": [
    151645,
    151643
  ],
  "repetition_penalty": 1.1,
  "temperature": 0.7,
  "top_p": 0.8,
  "top_k": 20,
  "transformers_version": "4.37.0"
}
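These sampling defaults can be loaded with the standard transformers helper; whether Puffin's understand() path actually consumes them is not shown in this commit, so the snippet below is only a sketch:

from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("configs/qwen2.5")
print(gen_cfg.temperature, gen_cfg.top_p, gen_cfg.top_k)  # 0.7 0.8 20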
configs/qwen2.5/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/qwen2.5/tokenizer_config.json ADDED
@@ -0,0 +1,207 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ }
181
+ },
182
+ "additional_special_tokens": [
183
+ "<|im_start|>",
184
+ "<|im_end|>",
185
+ "<|object_ref_start|>",
186
+ "<|object_ref_end|>",
187
+ "<|box_start|>",
188
+ "<|box_end|>",
189
+ "<|quad_start|>",
190
+ "<|quad_end|>",
191
+ "<|vision_start|>",
192
+ "<|vision_end|>",
193
+ "<|vision_pad|>",
194
+ "<|image_pad|>",
195
+ "<|video_pad|>"
196
+ ],
197
+ "bos_token": null,
198
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
199
+ "clean_up_tokenization_spaces": false,
200
+ "eos_token": "<|im_end|>",
201
+ "errors": "replace",
202
+ "model_max_length": 131072,
203
+ "pad_token": "<|endoftext|>",
204
+ "split_special_tokens": false,
205
+ "tokenizer_class": "Qwen2Tokenizer",
206
+ "unk_token": null
207
+ }
configs/qwen2.5/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/radio3/config.json ADDED
@@ -0,0 +1,241 @@
1
+ {
2
+ "adaptor_configs": {},
3
+ "adaptor_names": null,
4
+ "architectures": [
5
+ "RADIOModel"
6
+ ],
7
+ "args": {
8
+ "aa": null,
9
+ "amp": true,
10
+ "amp_dtype": "bfloat16",
11
+ "amp_impl": "native",
12
+ "aug_repeats": 0,
13
+ "aug_splits": 0,
14
+ "bn_eps": null,
15
+ "bn_momentum": null,
16
+ "cache_dir": null,
17
+ "channels_last": false,
18
+ "checkpoint_hist": 10,
19
+ "chk_keep_forever": 100,
20
+ "class_map": "",
21
+ "clip_grad": null,
22
+ "clip_mode": "norm",
23
+ "cls_token_per_teacher": true,
24
+ "coco_annotations_file": "/datasets/coco2017-adlsa/annotations/captions_val2017.json",
25
+ "coco_image_dir": "/datasets/coco2017-adlsa/val2017",
26
+ "color_jitter": 0.4,
27
+ "cooldown_epochs": 0,
28
+ "cpe_max_size": 2048,
29
+ "cpe_num_registers": 4,
30
+ "crd_loss": false,
31
+ "crd_loss_weight": 0.8,
32
+ "crop_pct": null,
33
+ "cutmix": 0.0,
34
+ "cutmix_minmax": null,
35
+ "dataset_download": false,
36
+ "debug_full_knn": false,
37
+ "decay_epochs": 90,
38
+ "decay_milestones": [
39
+ 90,
40
+ 180,
41
+ 270
42
+ ],
43
+ "decay_rate": 0.1,
44
+ "depchain": true,
45
+ "detect_anomaly": false,
46
+ "dist_bn": "reduce",
47
+ "dist_norm_weight": 0.0,
48
+ "distributed": true,
49
+ "drop": 0.0,
50
+ "drop_block": null,
51
+ "drop_connect": null,
52
+ "drop_path": null,
53
+ "dtype": "float32",
54
+ "epoch_repeats": 0.0,
55
+ "eval": false,
56
+ "eval_metric": "knn_top1",
57
+ "eval_teacher": false,
58
+ "eval_teacher_only": false,
59
+ "eval_throughput": false,
60
+ "fast_norm": false,
61
+ "fd_loss_fn": "MSE",
62
+ "feature_normalization": "PHI_STANDARDIZE",
63
+ "feature_summarizer": "cls_token",
64
+ "feature_upscale_factor": null,
65
+ "force_new_wandb_id": false,
66
+ "force_spectral_reparam": false,
67
+ "freeze_bn": false,
68
+ "fsdp": true,
69
+ "full_equivariance": false,
70
+ "fuser": "",
71
+ "gp": null,
72
+ "grad_accum_steps": 1,
73
+ "grad_checkpointing": false,
74
+ "head_init_bias": null,
75
+ "head_init_scale": null,
76
+ "head_lr": null,
77
+ "head_warmup": 5,
78
+ "head_weight_decay": 0.01,
79
+ "hflip": 0.5,
80
+ "img_size": null,
81
+ "in_chans": null,
82
+ "initial_checkpoint": null,
83
+ "input_size": null,
84
+ "interpolation": "",
85
+ "layer_decay": null,
86
+ "local_rank": 0,
87
+ "log_interval": 50,
88
+ "log_mlflow": false,
89
+ "log_wandb": true,
90
+ "loss_auto_balance": false,
91
+ "lr_base": 0.1,
92
+ "lr_base_scale": "",
93
+ "lr_base_size": 256,
94
+ "lr_cycle_decay": 0.5,
95
+ "lr_cycle_limit": 1,
96
+ "lr_cycle_mul": 1.0,
97
+ "lr_k_decay": 1.0,
98
+ "lr_noise": null,
99
+ "lr_noise_pct": 0.67,
100
+ "lr_noise_std": 1.0,
101
+ "mean": null,
102
+ "mesa": false,
103
+ "min_lr": 0.0001,
104
+ "mixup": 0.0,
105
+ "mixup_mode": "batch",
106
+ "mixup_off_epoch": 0,
107
+ "mixup_prob": 1.0,
108
+ "mixup_switch_prob": 0.5,
109
+ "mlp_hidden_size": 2560,
110
+ "mlp_num_inner": 1,
111
+ "mlp_version": "v2",
112
+ "model": "vit_huge_patch16_224",
113
+ "model_kwargs": {},
114
+ "model_norm": false,
115
+ "momentum": 0.9,
116
+ "no_aug": false,
117
+ "no_custom_validation": false,
118
+ "no_ddp_bb": true,
119
+ "no_knn": false,
120
+ "no_prefetcher": false,
121
+ "no_resume_opt": false,
122
+ "num_classes": null,
123
+ "one_logger_app_tag": "",
124
+ "one_logger_is_baseline": false,
125
+ "one_logger_run_name": "",
126
+ "onelogger": null,
127
+ "opt_betas": null,
128
+ "opt_eps": null,
129
+ "patience_epochs": 10,
130
+ "pin_mem": false,
131
+ "prefetcher": true,
132
+ "pretrained": false,
133
+ "rank": 0,
134
+ "ratio": [
135
+ 0.75,
136
+ 1.3333333333333333
137
+ ],
138
+ "recount": 1,
139
+ "recovery_interval": 0,
140
+ "register_multiple": 0,
141
+ "remode": "pixel",
142
+ "reprob": 0.0,
143
+ "reset_loss_state": true,
144
+ "resplit": false,
145
+ "sample_tracking": false,
146
+ "save_images": false,
147
+ "scale": [
148
+ 0.5,
149
+ 1.0
150
+ ],
151
+ "sched": "cosine",
152
+ "seed": 42,
153
+ "shift_equivariance": true,
154
+ "smoothing": 0.1,
155
+ "spectral_heads": false,
156
+ "spectral_reparam": false,
157
+ "spectral_weight_decay": null,
158
+ "split_bn": false,
159
+ "start_epoch": null,
160
+ "std": null,
161
+ "stream_teachers": true,
162
+ "sync_bn": false,
163
+ "synchronize_step": false,
164
+ "teachers": [
165
+ {
166
+ "fd_normalize": false,
167
+ "feature_distillation": true,
168
+ "input_size": 378,
169
+ "model": "ViT-H-14-378-quickgelu",
170
+ "name": "clip",
171
+ "pretrained": "dfn5b",
172
+ "type": "open_clip",
173
+ "use_summary": true
174
+ },
175
+ {
176
+ "fd_normalize": false,
177
+ "feature_distillation": true,
178
+ "input_size": 384,
179
+ "model": "siglip2-g-384",
180
+ "name": "siglip2-g",
181
+ "type": "siglip2",
182
+ "use_summary": true
183
+ },
184
+ {
185
+ "fd_normalize": false,
186
+ "feature_distillation": true,
187
+ "input_size": 224,
188
+ "model": "dinov2_vitg14_reg",
189
+ "name": "dino_v2",
190
+ "type": "dino_v2",
191
+ "use_summary": true
192
+ },
193
+ {
194
+ "fd_normalize": false,
195
+ "feature_distillation": true,
196
+ "input_size": 1024,
197
+ "model": "vit-h",
198
+ "name": "sam",
199
+ "type": "sam",
200
+ "use_summary": false
201
+ }
202
+ ],
203
+ "torchcompile": null,
204
+ "torchscript": false,
205
+ "train_interpolation": "random",
206
+ "train_split": "train",
207
+ "tta": 0,
208
+ "use_coco": false,
209
+ "use_multi_epochs_loader": false,
210
+ "val_ema_only": false,
211
+ "val_split": "val",
212
+ "vflip": 0.0,
213
+ "vitdet_version": 1,
214
+ "wandb_entity": "",
215
+ "wandb_id": "",
216
+ "wandb_job_type": "",
217
+ "wandb_name": "",
218
+ "wandb_project": "",
219
+ "warmup_lr": 1e-05,
220
+ "warmup_prefix": false,
221
+ "worker_seeding": "all",
222
+ "workers": 8,
223
+ "world_size": 256
224
+ },
225
+ "auto_map": {
226
+ "AutoConfig": "hf_model.RADIOConfig",
227
+ "AutoModel": "hf_model.RADIOModel"
228
+ },
229
+ "feature_normalizer_config": null,
230
+ "inter_feature_normalizer_config": null,
231
+ "max_resolution": 2048,
232
+ "patch_size": 16,
233
+ "preferred_resolution": [
234
+ 512,
235
+ 512
236
+ ],
237
+ "torch_dtype": "float32",
238
+ "transformers_version": "4.51.3",
239
+ "version": "c-radio_v3-h",
240
+ "vitdet_window_size": null
241
+ }
configs/sd3/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,6 @@
{
  "_class_name": "FlowMatchEulerDiscreteScheduler",
  "_diffusers_version": "0.29.0.dev0",
  "num_train_timesteps": 1000,
  "shift": 3.0
}
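A shift of 3.0 pushes the sigma schedule toward the noisier end, the usual SD3 setting at 512-1024 px. Assuming the standard flow-matching shift formula used by diffusers' FlowMatch schedulers (an assumption, not read from this repo):

# sigma_shifted = shift * sigma / (1 + (shift - 1) * sigma)
shift = 3.0
for sigma in (0.1, 0.5, 0.9):
    print(sigma, round(shift * sigma / (1 + (shift - 1) * sigma), 3))
# 0.1 -> 0.25, 0.5 -> 0.75, 0.9 -> 0.964: more of the trajectory sits at high noise levels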
configs/sd3/transformer/config.json ADDED
@@ -0,0 +1,15 @@
{
  "_class_name": "SD3Transformer2DModel",
  "_diffusers_version": "0.29.0.dev0",
  "attention_head_dim": 64,
  "caption_projection_dim": 1536,
  "in_channels": 16,
  "joint_attention_dim": 4096,
  "num_attention_heads": 24,
  "num_layers": 24,
  "out_channels": 16,
  "patch_size": 2,
  "pooled_projection_dim": 2048,
  "pos_embed_max_size": 192,
  "sample_size": 128
}
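Some quantities implied by this transformer config, for orientation (the 512x512 figure matches the resolution used in app.py; the 8x VAE downsampling is an assumption based on the SD3-style VAE in this repo):

attention_head_dim = 64
num_attention_heads = 24
patch_size = 2

inner_dim = num_attention_heads * attention_head_dim   # 1536, the transformer width
latent_hw = 512 // 8                                    # 64: latent side length for a 512x512 image
tokens = (latent_hw // patch_size) ** 2                 # 1024 image tokens per 512x512 sample
print(inner_dim, latent_hw, tokens)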
configs/sd3/vae/config.json ADDED
@@ -0,0 +1,36 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.29.0.dev0",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "force_upcast": true,
  "in_channels": 3,
  "latent_channels": 16,
  "latents_mean": null,
  "latents_std": null,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 1.5305,
  "shift_factor": 0.0609,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ],
  "use_post_quant_conv": false,
  "use_quant_conv": false
}
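The 16-channel latents from this VAE are normalised with scaling_factor and shift_factor before entering the diffusion transformer. The convention below is the one diffusers uses for SD3-style VAEs, stated here as an assumption rather than read from Puffin's code:

scaling_factor = 1.5305
shift_factor = 0.0609

def to_model_space(latents):
    # after encoding: VAE latents -> transformer input
    return (latents - shift_factor) * scaling_factor

def to_vae_space(latents):
    # before decoding: transformer output -> VAE decoder input
    return latents / scaling_factor + shift_factor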
requirements.txt ADDED
@@ -0,0 +1,17 @@
torch
accelerate
diffusers==0.34.0
gradio
torchvision
safetensors
matplotlib==3.10.1
matplotlib-inline==0.1.7
mmengine==0.10.7
numpy==2.2.5
pillow==11.2.1
scipy==1.15.2
timm==0.9.12
transformers==4.49.0
xtuner==0.1.23
deepspeed
scripts/camera/cam_dataset.py ADDED
@@ -0,0 +1,107 @@
import os
import json
import argparse
import torch
from tqdm import tqdm

from scripts.camera.geometry.camera import SimpleRadial
from scripts.camera.geometry.gravity import Gravity
from scripts.camera.geometry.perspective_fields import get_perspective_field
from scripts.camera.utils.conversions import fov2focal
from scripts.camera.utils.text import parse_camera_params


class Cam_Generator:
    def __init__(self, mode="base"):
        self.mode = mode

    def _load_text(self, caption, h=512, w=512, k1=0, k2=0):
        # Parse camera params from caption
        roll, pitch, vfov = parse_camera_params(caption, self.mode)

        # Convert vertical FoV to focal length
        f = fov2focal(torch.tensor(vfov), h)
        px, py = w / 2, h / 2
        params = torch.tensor([w, h, f, f, px, py, k1, k2]).float()
        gravity = torch.tensor([roll, pitch]).float()
        return params, gravity

    def _read_param(self, parameters, gravity):
        # Build camera and gravity objects
        camera = SimpleRadial(parameters).float()
        roll, pitch = gravity.unbind(-1)
        gravity_obj = Gravity.from_rp(roll, pitch)
        camera = camera.scale(torch.Tensor([1, 1]))
        return {"camera": camera, "gravity": gravity_obj}

    def _get_perspective(self, data):
        # Generate up and latitude fields
        camera = data["camera"]
        gravity_obj = data["gravity"]
        up_field, lat_field = get_perspective_field(
            camera, gravity_obj, use_up=True, use_latitude=True
        )
        del camera, gravity_obj
        return torch.cat([up_field[0], lat_field[0]], dim=0)

    def get_cam(self, caption):
        params, gravity = self._load_text(caption)
        data = self._read_param(params, gravity)
        return self._get_perspective(data)


def process_folders(input_root, output_root, start_idx=0, num_folders=None, mode="base"):
    gen = Cam_Generator(mode=mode)
    all_dirs = sorted([
        d for d in os.listdir(input_root)
        if os.path.isdir(os.path.join(input_root, d))
    ])
    if num_folders is None:
        num_folders = len(all_dirs) - start_idx
    selected = all_dirs[start_idx:start_idx + num_folders]

    for sub in tqdm(selected, desc="Subfolders"):
        in_sub = os.path.join(input_root, sub)
        out_sub = os.path.join(output_root, sub)
        os.makedirs(out_sub, exist_ok=True)

        json_files = sorted([
            f for f in os.listdir(in_sub)
            if f.lower().endswith('.json')
        ])

        for jf in tqdm(json_files, desc=f"Processing {sub}", leave=False):
            in_path = os.path.join(in_sub, jf)
            with open(in_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            caption = data.get('caption', '')
            cam = gen.get_cam(caption)
            out_name = os.path.splitext(jf)[0] + '.pt'
            out_path = os.path.join(out_sub, out_name)
            torch.save(cam, out_path)


def main():
    parser = argparse.ArgumentParser(
        description="Batch process the captions to the camera maps and save as .pt"
    )
    parser.add_argument('--input_root', type=str,
                        help='Root directory of JSON subfolders')
    parser.add_argument('--output_root', type=str,
                        help='Root directory to save .pt files')
    parser.add_argument('--start_idx', type=int, default=0,
                        help='Start index of subfolders (0-based, default=0)')
    parser.add_argument('--num_folders', type=int, default=None,
                        help='Number of subfolders to process (default: all)')
    parser.add_argument('--mode', type=str, default='base',
                        help='parse_camera_params mode')
    args = parser.parse_args()

    process_folders(
        args.input_root,
        args.output_root,
        start_idx=args.start_idx,
        num_folders=args.num_folders,
        mode=args.mode
    )


if __name__ == '__main__':
    main()
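Typical use of Cam_Generator outside the batch script, matching how app.py drives it; the expected (3, 512, 512) shape, a 2-channel up field plus a 1-channel latitude field, follows from _get_perspective above:

from scripts.camera.cam_dataset import Cam_Generator

caption = "The camera parameters (roll, pitch, and field-of-view) are: 0.1000, -0.1000, 1.5000."
gen = Cam_Generator(mode="base")
cam = gen.get_cam(caption)                # torch.Tensor, expected shape (3, 512, 512)
up_field, lat_field = cam[:2], cam[2:]    # split exactly as app.py does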
scripts/camera/geometry/__init__.py ADDED
File without changes
scripts/camera/geometry/base_camera.py ADDED
@@ -0,0 +1,518 @@
1
+ """Convenience classes a for camera models.
2
+
3
+ Based on PyTorch tensors: differentiable, batched, with GPU support.
4
+ Adapted from https://github.com/cvg/GeoCalib
5
+ """
6
+
7
+ from abc import abstractmethod
8
+ from typing import Dict, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch.func import jacfwd, vmap
12
+ from torch.nn import functional as F
13
+
14
+ from scripts.camera.geometry.gravity import Gravity
15
+ from scripts.camera.utils.conversions import deg2rad, focal2fov, fov2focal, rad2rotmat
16
+ from scripts.camera.utils.tensor import TensorWrapper, autocast
17
+
18
+ # mypy: ignore-errors
19
+
20
+
21
+ class BaseCamera(TensorWrapper):
22
+ """Camera tensor class."""
23
+
24
+ eps = 1e-3
25
+
26
+ @autocast
27
+ def __init__(self, data: torch.Tensor):
28
+ """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}).
29
+
30
+ Tensor convention: (..., {w, h, fx, fy, cx, cy, pitch, roll, *dist}) where
31
+ - w, h: image size in pixels
32
+ - fx, fy: focal lengths in pixels
33
+ - cx, cy: principal points in normalized image coordinates
34
+ - dist: distortion parameters
35
+
36
+ Args:
37
+ data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}).
38
+ """
39
+ # w, h, fx, fy, cx, cy, dist
40
+ assert data.shape[-1] in {6, 7, 8}, data.shape
41
+
42
+ pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],))
43
+ data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data
44
+ super().__init__(data)
45
+
46
+ @classmethod
47
+ def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera":
48
+ """Create a Camera object from a dictionary of parameters.
49
+
50
+ Args:
51
+ param_dict (Dict[str, torch.Tensor]): Dictionary of parameters.
52
+
53
+ Returns:
54
+ Camera: Camera object.
55
+ """
56
+ for key, value in param_dict.items():
57
+ if not isinstance(value, torch.Tensor):
58
+ param_dict[key] = torch.tensor(value)
59
+
60
+ h, w = param_dict["height"], param_dict["width"]
61
+ cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2)
62
+
63
+ vfov = param_dict.get("vfov")
64
+ f = param_dict.get("f", fov2focal(vfov, h))
65
+
66
+ if "dist" in param_dict:
67
+ k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1]
68
+ elif "k1_hat" in param_dict:
69
+ k1 = param_dict["k1_hat"] * (f / h) ** 2
70
+
71
+ k2 = param_dict.get("k2", torch.zeros_like(k1))
72
+ else:
73
+ k1 = param_dict.get("k1", torch.zeros_like(f))
74
+ k2 = param_dict.get("k2", torch.zeros_like(f))
75
+
76
+ fx, fy = f, f
77
+ if "scales" in param_dict:
78
+ scales = param_dict["scales"]
79
+ fx = fx * scales[..., 0] / scales[..., 1]
80
+
81
+ params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1)
82
+ return cls(params)
83
+
84
+ def pinhole(self):
85
+ """Return the pinhole camera model."""
86
+ return self.__class__(self._data[..., :6])
87
+
88
+ @property
89
+ def size(self) -> torch.Tensor:
90
+ """Size (width height) of the images, with shape (..., 2)."""
91
+ return self._data[..., :2]
92
+
93
+ @property
94
+ def f(self) -> torch.Tensor:
95
+ """Focal lengths (fx, fy) with shape (..., 2)."""
96
+ return self._data[..., 2:4]
97
+
98
+ @property
99
+ def vfov(self) -> torch.Tensor:
100
+ """Vertical field of view in radians."""
101
+ return focal2fov(self.f[..., 1], self.size[..., 1])
102
+
103
+ @property
104
+ def hfov(self) -> torch.Tensor:
105
+ """Horizontal field of view in radians."""
106
+ return focal2fov(self.f[..., 0], self.size[..., 0])
107
+
108
+ @property
109
+ def c(self) -> torch.Tensor:
110
+ """Principal points (cx, cy) with shape (..., 2)."""
111
+ return self._data[..., 4:6]
112
+
113
+ @property
114
+ def K(self) -> torch.Tensor:
115
+ """Returns the self intrinsic matrix with shape (..., 3, 3)."""
116
+ shape = self.shape + (3, 3)
117
+ K = self._data.new_zeros(shape)
118
+ K[..., 0, 0] = self.f[..., 0]
119
+ K[..., 1, 1] = self.f[..., 1]
120
+ K[..., 0, 2] = self.c[..., 0]
121
+ K[..., 1, 2] = self.c[..., 1]
122
+ K[..., 2, 2] = 1
123
+ return K
124
+
125
+ def update_focal(self, delta: torch.Tensor, as_log: bool = False):
126
+ """Update the self parameters after changing the focal length."""
127
+ f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta
128
+
129
+ # clamp focal length to a reasonable range for stability during training
130
+ min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1])
131
+ max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1])
132
+ min_f = min_f.unsqueeze(-1).expand(-1, 2)
133
+ max_f = max_f.unsqueeze(-1).expand(-1, 2)
134
+ f = f.clamp(min=min_f, max=max_f)
135
+
136
+ # make sure focal ration stays the same (avoid inplace operations)
137
+ fx = f[..., 1] * self.f[..., 0] / self.f[..., 1]
138
+ f = torch.stack([fx, f[..., 1]], -1)
139
+
140
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
141
+ return self.__class__(torch.cat([self.size, f, self.c, dist], -1))
142
+
143
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
144
+ """Update the self parameters after resizing an image."""
145
+ scales = (scales, scales) if isinstance(scales, (int, float)) else scales
146
+ s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales)
147
+
148
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
149
+ return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1))
150
+
151
+ def crop(self, pad: Tuple[float]):
152
+ """Update the self parameters after cropping an image."""
153
+ pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad)
154
+ size = self.size + pad.to(self.size)
155
+ c = self.c + pad.to(self.c) / 2
156
+
157
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
158
+ return self.__class__(torch.cat([size, self.f, c, dist], -1))
159
+
160
+ def undo_scale_crop(self, data: Dict[str, torch.Tensor]):
161
+ """Undo transforms done during scaling and cropping."""
162
+ camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self
163
+ return camera.scale(1.0 / data["scales"])
164
+
165
+ @autocast
166
+ def in_image(self, p2d: torch.Tensor):
167
+ """Check if 2D points are within the image boundaries."""
168
+ assert p2d.shape[-1] == 2
169
+ size = self.size.unsqueeze(-2)
170
+ return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
171
+
172
+ @autocast
173
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
174
+ """Project 3D points into the self plane and check for visibility."""
175
+ z = p3d[..., -1]
176
+ valid = z > self.eps
177
+ z = z.clamp(min=self.eps)
178
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
179
+ return p2d, valid
180
+
181
+ def J_project(self, p3d: torch.Tensor):
182
+ """Jacobian of the projection function."""
183
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
184
+ zero = torch.zeros_like(z)
185
+ z = z.clamp(min=self.eps)
186
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
187
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
188
+ return J # N x 2 x 3
189
+
190
+ @abstractmethod
191
+ def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
192
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
193
+ raise NotImplementedError("distort() must be implemented.")
194
+
195
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
196
+ """Jacobian of the distortion function."""
197
+ if wrt == "scale2pts": # (..., 2)
198
+ J = [
199
+ vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None]
200
+ for idx in range(p2d.shape[0])
201
+ ]
202
+
203
+ return torch.cat(J, dim=0).squeeze(-3, -2)
204
+
205
+ elif wrt == "scale2dist": # (..., 1)
206
+ J = []
207
+ for idx in range(p2d.shape[0]): # loop to batch pts dimension
208
+
209
+ def func(x):
210
+ params = torch.cat([self._data[idx, :6], x[None]], -1)
211
+ return self.__class__(params).distort(p2d[idx], return_scale=True)[0]
212
+
213
+ J.append(vmap(jacfwd(func))(self[idx].dist))
214
+
215
+ return torch.cat(J, dim=0)
216
+
217
+ else:
218
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
219
+
220
+ @abstractmethod
221
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
222
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
223
+ raise NotImplementedError("undistort() must be implemented.")
224
+
225
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
226
+ """Jacobian of the undistortion function."""
227
+ if wrt == "pts": # (..., 2, 2)
228
+ J = [
229
+ vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None]
230
+ for idx in range(p2d.shape[0])
231
+ ]
232
+
233
+ return torch.cat(J, dim=0).squeeze(-3)
234
+
235
+ elif wrt == "dist": # (..., 1)
236
+ J = []
237
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
238
+
239
+ def func(x):
240
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)
241
+ return self.__class__(params).undistort(p2d[batch_idx])[0]
242
+
243
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
244
+
245
+ return torch.cat(J, dim=0)
246
+ else:
247
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
248
+
249
+ @autocast
250
+ def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor:
251
+ """Compute the offset for the up-projection."""
252
+ return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2)
253
+
254
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
255
+ """Jacobian of the distortion offset for up-projection."""
256
+ if wrt == "uv": # (B, N, 2, 2)
257
+ J = [
258
+ vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None]
259
+ for idx in range(p2d.shape[0])
260
+ ]
261
+
262
+ return torch.cat(J, dim=0)
263
+
264
+ elif wrt == "dist": # (B, N, 2)
265
+ J = []
266
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
267
+
268
+ def func(x):
269
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None]
270
+ return self.__class__(params).up_projection_offset(p2d[batch_idx][None])
271
+
272
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
273
+
274
+ return torch.cat(J, dim=0).squeeze(1)
275
+ else:
276
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
277
+
278
+ @autocast
279
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
280
+ """Convert normalized 2D coordinates into pixel coordinates."""
281
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
282
+
283
+ def J_denormalize(self):
284
+ """Jacobian of the denormalization function."""
285
+ return torch.diag_embed(self.f) # ..., 2 x 2
286
+
287
+ @autocast
288
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
289
+ """Convert pixel coordinates into normalized 2D coordinates."""
290
+ return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2))
291
+
292
+ def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"):
293
+ """Jacobian of the normalization function."""
294
+ # ... x N x 2 x 2
295
+ if wrt == "f":
296
+ J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2)
297
+ return torch.diag_embed(J_f)
298
+ elif wrt == "pts":
299
+ J_pts = 1 / self.f
300
+ return torch.diag_embed(J_pts)
301
+ else:
302
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
303
+
304
+ def pixel_coordinates(self) -> torch.Tensor:
305
+ """Pixel coordinates in self frame.
306
+
307
+ Returns:
308
+ torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2).
309
+ """
310
+ w, h = self.size[0].unbind(-1)
311
+ h, w = h.round().to(int), w.round().to(int)
312
+
313
+ # create grid
314
+ x = torch.arange(0, w, dtype=self.dtype, device=self.device)
315
+ y = torch.arange(0, h, dtype=self.dtype, device=self.device)
316
+ x, y = torch.meshgrid(x, y, indexing="xy")
317
+ xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2)
318
+
319
+ # add batch dimension (normalize() would broadcast but we make it explicit)
320
+ B = self.shape[0]
321
+ xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy
322
+
323
+ return xy.to(self.device).to(self.dtype)
324
+
325
+ def normalized_image_coordinates(self) -> torch.Tensor:
326
+ """Normalized image coordinates in self frame.
327
+
328
+ Returns:
329
+ torch.Tensor: Normalized image coordinates as a tensor of shape (B, h * w, 3).
330
+ """
331
+ xy = self.pixel_coordinates()
332
+ uv1, _ = self.image2world(xy)
333
+
334
+ B = self.shape[0]
335
+ uv1 = uv1.reshape(B, -1, 3)
336
+ return uv1.to(self.device).to(self.dtype)
337
+
338
+ @autocast
339
+ def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor:
340
+ """Get the bearing vectors of pixel coordinates.
341
+
342
+ Args:
343
+ p3d (torch.Tensor): 3D points (e.g. homogeneous image coordinates) as a tensor of shape (..., 3).
344
+
345
+ Returns:
346
+ torch.Tensor: Bearing vectors as a tensor of shape (..., 3).
347
+ """
348
+ return F.normalize(p3d, dim=-1)
349
+
350
+ @autocast
351
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
352
+ """Transform 3D points into 2D pixel coordinates."""
353
+ p2d, visible = self.project(p3d)
354
+ p2d, mask = self.distort(p2d)
355
+ p2d = self.denormalize(p2d)
356
+ valid = visible & mask & self.in_image(p2d)
357
+ return p2d, valid
358
+
359
+ @autocast
360
+ def J_world2image(self, p3d: torch.Tensor):
361
+ """Jacobian of the world2image function."""
362
+ p2d_proj, valid = self.project(p3d)
363
+
364
+ J_dnorm = self.J_denormalize()
365
+ J_dist = self.J_distort(p2d_proj)
366
+ J_proj = self.J_project(p3d)
367
+
368
+ J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj)
369
+ return J, valid
370
+
371
+ @autocast
372
+ def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
373
+ """Transform point in the image plane to 3D world coordinates."""
374
+ p2d = self.normalize(p2d)
375
+ p2d, valid = self.undistort(p2d)
376
+ ones = p2d.new_ones(p2d.shape[:-1] + (1,))
377
+ p3d = torch.cat([p2d, ones], -1)
378
+ return p3d, valid
379
+
380
+ @autocast
381
+ def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]:
382
+ """Jacobian of the image2world function."""
383
+ if wrt == "dist":
384
+ p2d_norm = self.normalize(p2d)
385
+ return self.J_undistort(p2d_norm, wrt)
386
+ elif wrt == "f":
387
+ J_norm2f = self.J_normalize(p2d, wrt)
388
+ p2d_norm = self.normalize(p2d)
389
+ J_dist2norm = self.J_undistort(p2d_norm, "pts")
390
+
391
+ return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f)
392
+ else:
393
+ raise ValueError(f"Unknown wrt: {wrt}")
394
+
395
+ @autocast
396
+ def undistort_image(self, img: torch.Tensor) -> torch.Tensor:
397
+ """Undistort an image using the distortion model."""
398
+ assert self.shape[0] == 1, "Batch size must be 1."
399
+ W, H = self.size.unbind(-1)
400
+ H, W = H.int().item(), W.int().item()
401
+
402
+ x, y = torch.arange(0, W), torch.arange(0, H)
403
+ x, y = torch.meshgrid(x, y, indexing="xy")
404
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
405
+
406
+ p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype))
407
+ p2d, _ = self.world2image(p3d)
408
+
409
+ mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W))
410
+ grid = torch.stack((mapx, mapy), dim=-1)
411
+ grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1
412
+ return F.grid_sample(img, grid, align_corners=True)
413
+
414
+ def get_img_from_pano(
415
+ self,
416
+ pano_img: torch.Tensor,
417
+ gravity: Gravity,
418
+ yaws: torch.Tensor = 0.0,
419
+ resize_factor: Optional[torch.Tensor] = None,
420
+ ) -> torch.Tensor:
421
+ """Render an image from a panorama.
422
+
423
+ Args:
424
+ pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1].
425
+ gravity (Gravity): Gravity direction of the camera.
426
+ yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0.
427
+ resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the
428
+ field of view. Defaults to None (no resizing).
429
+
430
+ Returns:
431
+ torch.Tensor: Image rendered from the panorama.
432
+ """
433
+ B = self.shape[0]
434
+ if B > 0:
435
+ assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width."
436
+ assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height."
437
+
438
+ w, h = self.size[0].unbind(-1)
439
+ h, w = h.round().to(int), w.round().to(int)
440
+
441
+ if isinstance(yaws, (int, float)):
442
+ yaws = [yaws]
443
+ if isinstance(resize_factor, (int, float)):
444
+ resize_factor = [resize_factor]
445
+
446
+ yaws = (
447
+ yaws.to(self.dtype).to(self.device)
448
+ if isinstance(yaws, torch.Tensor)
449
+ else self.new_tensor(yaws)
450
+ )
451
+
452
+ if isinstance(resize_factor, torch.Tensor):
453
+ resize_factor = resize_factor.to(self.dtype).to(self.device)
454
+ elif resize_factor is not None:
455
+ resize_factor = self.new_tensor(resize_factor)
456
+
457
+ assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor."
458
+ pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x 3 x H x W
459
+
460
+ pano_imgs = []
461
+ for i, yaw in enumerate(yaws):
462
+ if resize_factor is not None:
463
+ # resize the panorama such that the fov of the panorama has the same height as the
464
+ # image
465
+ vfov = self.vfov[i] if B != 0 else self.vfov
466
+ scale = torch.pi / float(vfov) * float(h) / pano_img.shape[-2] * resize_factor[i]
467
+ pano_shape = (int(pano_img.shape[-2] * scale), int(pano_img.shape[-1] * scale))
468
+
469
+ mode = "bicubic" if scale >= 1 else "area"
470
+ resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode)
471
+ else:
472
+ # no resizing requested: reuse the panorama as-is (note that this does not copy the tensor)
473
+ resized_pano = pano_img
474
+ pano_shape = pano_img.shape[-2:][::-1]
475
+
476
+ pano_imgs.append((resized_pano, pano_shape))
477
+
478
+ xy = self.pixel_coordinates()
479
+ uv1, valid = self.image2world(xy)
480
+ bearings = self.pixel_bearing_many(uv1)
481
+
482
+ # rotate bearings
483
+ R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws)
484
+ rotated_bearings = bearings @ gravity.R @ R_yaw
485
+
486
+ # spherical coordinates
487
+ lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2])
488
+ lat = torch.atan2(
489
+ rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1)
490
+ )
491
+
492
+ images = []
493
+ for idx, (resized_pano, pano_shape) in enumerate(pano_imgs):
494
+ min_lon, max_lon = -torch.pi, torch.pi
495
+ min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0
496
+ min_x, max_x = 0, pano_shape[0] - 1.0
497
+ min_y, max_y = 0, pano_shape[1] - 1.0
498
+
499
+ # map Spherical Coordinates to Panoramic Coordinates
500
+ nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x
501
+ ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y
502
+
503
+ # reshape and cast to numpy for remap
504
+ mapx = nx.reshape((1, h, w))
505
+ mapy = ny.reshape((1, h, w))
506
+
507
+ grid = torch.stack((mapx, mapy), dim=-1) # Add batch dimension
508
+ # Normalize to [-1, 1]
509
+ grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1
510
+ # Apply grid sample
511
+ image = F.grid_sample(resized_pano, grid, align_corners=True)
512
+ images.append(image)
513
+
514
+ return torch.concatenate(images, 0) if B > 0 else images[0]
515
+
516
+ def __repr__(self):
517
+ """Print the Camera object."""
518
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
scripts/camera/geometry/camera.py ADDED
@@ -0,0 +1,281 @@
1
+ """Implementation of the pinhole, simple radial, and simple divisional camera models."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ from typing import Tuple
5
+
6
+ import torch
7
+
8
+ from scripts.camera.geometry.base_camera import BaseCamera
9
+ from scripts.camera.utils.tensor import autocast
10
+
11
+ # flake8: noqa: E741
12
+
13
+ # mypy: ignore-errors
14
+
15
+
16
+ class Pinhole(BaseCamera):
17
+ """Implementation of the pinhole camera model."""
18
+
19
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
20
+ """Distort normalized 2D coordinates."""
21
+ if return_scale:
22
+ return p2d.new_ones(p2d.shape[:-1] + (1,))
23
+
24
+ return p2d, p2d.new_ones((p2d.shape[0], 1)).bool()
25
+
26
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
27
+ """Jacobian of the distortion function."""
28
+ if wrt == "pts":
29
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
30
+ else:
31
+ raise ValueError(f"Unknown wrt: {wrt}")
32
+
33
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
34
+ """Undistort normalized 2D coordinates."""
35
+ return pts, pts.new_ones((pts.shape[0], 1)).bool()
36
+
37
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
38
+ """Jacobian of the undistortion function."""
39
+ if wrt == "pts":
40
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
41
+ else:
42
+ raise ValueError(f"Unknown wrt: {wrt}")
43
+
44
+
45
+ class SimpleRadial(BaseCamera):
46
+ """Implementation of the simple radial camera model."""
47
+
48
+ @property
49
+ def dist(self) -> torch.Tensor:
50
+ """Distortion parameters, with shape (..., 1)."""
51
+ return self._data[..., 6:]
52
+
53
+ @property
54
+ def k1(self) -> torch.Tensor:
55
+ """Distortion parameters, with shape (...)."""
56
+ return self._data[..., 6]
57
+
58
+ @property
59
+ def k1_hat(self) -> torch.Tensor:
60
+ """Distortion parameters, with shape (...)."""
61
+ return self.k1 / (self.f[..., 1] / self.size[..., 1]) ** 2
62
+
63
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)):
64
+ """Update the self parameters after changing the k1 distortion parameter."""
65
+ delta_dist = self.new_ones(self.dist.shape) * delta
66
+ dist = (self.dist + delta_dist).clamp(*dist_range)
67
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
68
+ return self.__class__(data)
69
+
70
+ @autocast
71
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
72
+ """Check if the distorted points are valid."""
73
+ return p2d.new_ones(p2d.shape[:-1]).bool()
74
+
75
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
76
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
77
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
78
+ radial = 1 + self.k1[..., None, None] * r2
79
+
80
+ if return_scale:
81
+ return radial, None
82
+
83
+ return p2d * radial, self.check_valid(p2d)
84
+
85
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
86
+ """Jacobian of the distortion function."""
87
+ k1 = self.k1[..., None, None]
88
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
89
+ if wrt == "pts": # (..., 2, 2)
90
+ radial = 1 + k1 * r2
91
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
92
+ return (2 * k1 * ppT) + torch.diag_embed(radial.expand(radial.shape[:-1] + (2,)))
93
+ elif wrt == "dist": # (..., 2)
94
+ return r2 * p2d
95
+ elif wrt == "scale2dist": # (..., 1)
96
+ return r2
97
+ elif wrt == "scale2pts": # (..., 2)
98
+ return 2 * k1 * p2d
99
+ else:
100
+ return super().J_distort(p2d, wrt)
101
+
102
+ @autocast
103
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
104
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
105
+ b1 = -self.k1[..., None, None]
106
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
107
+ radial = 1 + b1 * r2
108
+ return p2d * radial, self.check_valid(p2d)
109
+
110
+ @autocast
111
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
112
+ """Jacobian of the undistortion function."""
113
+ b1 = -self.k1[..., None, None]
114
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
115
+ if wrt == "dist":
116
+ return -r2 * p2d
117
+ elif wrt == "pts":
118
+ radial = 1 + b1 * r2
119
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
120
+ return (2 * b1[..., None] * ppT) + torch.diag_embed(
121
+ radial.expand(radial.shape[:-1] + (2,))
122
+ )
123
+ else:
124
+ return super().J_undistort(p2d, wrt)
125
+
126
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
127
+ """Jacobian of the up-projection offset."""
128
+ if wrt == "uv": # (..., 2, 2)
129
+ return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,)))
130
+ elif wrt == "dist":
131
+ return 2 * p2d # (..., 2)
132
+ else:
133
+ return super().J_up_projection_offset(p2d, wrt)
134
+
135
+
136
+ class SimpleDivisional(BaseCamera):
137
+ """Implementation of the simple divisional camera model."""
138
+
139
+ @property
140
+ def dist(self) -> torch.Tensor:
141
+ """Distortion parameters, with shape (..., 1)."""
142
+ return self._data[..., 6:]
143
+
144
+ @property
145
+ def k1(self) -> torch.Tensor:
146
+ """Distortion parameters, with shape (...)."""
147
+ return self._data[..., 6]
148
+
149
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)):
150
+ """Update the self parameters after changing the k1 distortion parameter."""
151
+ delta_dist = self.new_ones(self.dist.shape) * delta
152
+ dist = (self.dist + delta_dist).clamp(*dist_range)
153
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
154
+ return self.__class__(data)
155
+
156
+ @autocast
157
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
158
+ """Check if the distorted points are valid."""
159
+ return p2d.new_ones(p2d.shape[:-1]).bool()
160
+
161
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
162
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
163
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
164
+ radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0))
165
+ denom = 2 * self.k1[..., None, None] * r2
166
+
167
+ ones = radial.new_ones(radial.shape)
168
+ radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6))
169
+
170
+ if return_scale:
171
+ return radial, None
172
+
173
+ return p2d * radial, self.check_valid(p2d)
174
+
175
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
176
+ """Jacobian of the distortion function."""
177
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
178
+ t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6))
179
+ if wrt == "scale2pts": # (B, N, 2)
180
+ d1 = t0 * 2 * r2
181
+ d2 = self.k1[..., None, None] * r2**2
182
+ denom = d1 * d2
183
+ return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
184
+
185
+ elif wrt == "scale2dist":
186
+ d1 = 2 * self.k1[..., None, None] * t0
187
+ d2 = 2 * r2 * self.k1[..., None, None] ** 2
188
+ denom = d1 * d2
189
+ return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
190
+
191
+ else:
192
+ return super().J_distort(p2d, wrt)
193
+
194
+ @autocast
195
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
196
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
197
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
198
+ denom = 1 + self.k1[..., None, None] * r2
199
+ radial = 1 / denom.masked_fill(denom == 0, 1e6)
200
+ return p2d * radial, self.check_valid(p2d)
201
+
202
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
203
+ """Jacobian of the undistortion function."""
204
+ # return super().J_undistort(p2d, wrt)
205
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
206
+ k1 = self.k1[..., None, None]
207
+ if wrt == "dist":
208
+ denom = (1 + k1 * r2) ** 2
209
+ return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d
210
+ elif wrt == "pts":
211
+ t0 = 1 + k1 * r2
212
+ t0 = t0.masked_fill(t0 == 0, 1e6)
213
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
214
+ J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,)))
215
+ return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2)
216
+
217
+ else:
218
+ return super().J_undistort(p2d, wrt)
219
+
220
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
221
+ """Jacobian of the up-projection offset.
222
+
223
+ func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv
224
+ - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv
225
+ """
226
+ k1 = self.k1[..., None, None]
227
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
228
+ t0 = (1 - 4 * k1 * r2).clamp(min=1e-6)
229
+ t1 = torch.sqrt(t0)
230
+ if wrt == "dist":
231
+ denom = 4 * t0 ** (3 / 2)
232
+ denom = denom.masked_fill(denom == 0, 1e6)
233
+ J = 16 / denom
234
+
235
+ denom = r2 * t1 * k1
236
+ denom = denom.masked_fill(denom == 0, 1e6)
237
+ J = J - 2 / denom
238
+
239
+ denom = (r2 * k1) ** 2
240
+ denom = denom.masked_fill(denom == 0, 1e6)
241
+ J = J + (1 - t1) / denom
242
+
243
+ return J * p2d
244
+ elif wrt == "uv":
245
+ # ! unstable (gradient checker might fail), rewrite to use single division (by denom)
246
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
247
+
248
+ denom = 2 * r2 * t1
249
+ denom = denom.masked_fill(denom == 0, 1e6)
250
+ J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,)))
251
+
252
+ denom = 4 * t1 * r2**2
253
+ denom = denom.masked_fill(denom == 0, 1e6)
254
+ J = J - 16 / denom[..., None] * ppT
255
+
256
+ denom = 4 * r2 * t0 ** (3 / 2)
257
+ denom = denom.masked_fill(denom == 0, 1e6)
258
+ J = J + (32 * k1[..., None]) / denom[..., None] * ppT
259
+
260
+ denom = r2**2 * t1
261
+ denom = denom.masked_fill(denom == 0, 1e6)
262
+ J = J - 4 / denom[..., None] * ppT
263
+
264
+ denom = k1 * r2**3
265
+ denom = denom.masked_fill(denom == 0, 1e6)
266
+ J = J + (4 * (1 - t1) / denom)[..., None] * ppT
267
+
268
+ denom = k1 * r2**2
269
+ denom = denom.masked_fill(denom == 0, 1e6)
270
+ J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,)))
271
+
272
+ return J
273
+ else:
274
+ return super().J_up_projection_offset(p2d, wrt)
275
+
276
+
277
+ camera_models = {
278
+ "pinhole": Pinhole,
279
+ "simple_radial": SimpleRadial,
280
+ "simple_divisional": SimpleDivisional,
281
+ }
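A minimal usage sketch for the concrete camera models above (not part of the uploaded files). It assumes the repository root is on PYTHONPATH and the per-camera parameter layout [w, h, fx, fy, cx, cy, k1] implied by update_dist() and the dist property; treat it as an illustration rather than a reference:

import torch
from scripts.camera.geometry.camera import SimpleRadial

# one camera: 640x480 image, f = 500 px, principal point at the center, k1 = -0.1
params = torch.tensor([[640.0, 480.0, 500.0, 500.0, 320.0, 240.0, -0.1]])
cam = SimpleRadial(params)

p3d = torch.tensor([[[0.2, -0.1, 2.0], [0.0, 0.0, 1.0]]])   # (B, N, 3) points in the camera frame
p2d, valid = cam.world2image(p3d)                           # pixel coordinates + visibility mask
rays, _ = cam.image2world(p2d)                              # back to homogeneous rays (u, v, 1)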
scripts/camera/geometry/gravity.py ADDED
@@ -0,0 +1,129 @@
1
+ """Tensor class for gravity vector in camera frame."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ from scripts.camera.geometry.manifolds import EuclideanManifold, SphericalManifold
8
+ from scripts.camera.utils.conversions import rad2rotmat
9
+ from scripts.camera.utils.tensor import TensorWrapper, autocast
10
+
11
+ # mypy: ignore-errors
12
+
13
+
14
+ class Gravity(TensorWrapper):
15
+ """Gravity vector in camera frame."""
16
+
17
+ eps = 1e-4
18
+
19
+ @autocast
20
+ def __init__(self, data: torch.Tensor) -> None:
21
+ """Create gravity vector from data.
22
+
23
+ Args:
24
+ data (torch.Tensor): gravity vector as 3D vector in camera frame.
25
+ """
26
+ assert data.shape[-1] == 3, data.shape
27
+
28
+ data = F.normalize(data, dim=-1)
29
+
30
+ super().__init__(data)
31
+
32
+ @classmethod
33
+ def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
34
+ """Create gravity vector from roll and pitch angles."""
35
+ if not isinstance(roll, torch.Tensor):
36
+ roll = torch.tensor(roll)
37
+ if not isinstance(pitch, torch.Tensor):
38
+ pitch = torch.tensor(pitch)
39
+
40
+ sr, cr = torch.sin(roll), torch.cos(roll)
41
+ sp, cp = torch.sin(pitch), torch.cos(pitch)
42
+ return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
43
+
44
+ @property
45
+ def vec3d(self) -> torch.Tensor:
46
+ """Return the gravity vector in the representation."""
47
+ return self._data
48
+
49
+ @property
50
+ def x(self) -> torch.Tensor:
51
+ """Return first component of the gravity vector."""
52
+ return self._data[..., 0]
53
+
54
+ @property
55
+ def y(self) -> torch.Tensor:
56
+ """Return second component of the gravity vector."""
57
+ return self._data[..., 1]
58
+
59
+ @property
60
+ def z(self) -> torch.Tensor:
61
+ """Return third component of the gravity vector."""
62
+ return self._data[..., 2]
63
+
64
+ @property
65
+ def roll(self) -> torch.Tensor:
66
+ """Return the roll angle of the gravity vector."""
67
+ roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
68
+ offset = -torch.pi * torch.sign(self.x)
69
+ return torch.where(self.y < 0, roll, -roll + offset)
70
+
71
+ def J_roll(self) -> torch.Tensor:
72
+ """Return the Jacobian of the roll angle of the gravity vector."""
73
+ cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
74
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
75
+ Jr = self.new_zeros(self.shape + (3,))
76
+ Jr[..., 0] = -cr * cp
77
+ Jr[..., 1] = sr * cp
78
+ return Jr
79
+
80
+ @property
81
+ def pitch(self) -> torch.Tensor:
82
+ """Return the pitch angle of the gravity vector."""
83
+ return torch.asin(self.z)
84
+
85
+ def J_pitch(self) -> torch.Tensor:
86
+ """Return the Jacobian of the pitch angle of the gravity vector."""
87
+ cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
88
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
89
+
90
+ Jp = self.new_zeros(self.shape + (3,))
91
+ Jp[..., 0] = sr * sp
92
+ Jp[..., 1] = cr * sp
93
+ Jp[..., 2] = cp
94
+ return Jp
95
+
96
+ @property
97
+ def rp(self) -> torch.Tensor:
98
+ """Return the roll and pitch angles of the gravity vector."""
99
+ return torch.stack([self.roll, self.pitch], dim=-1)
100
+
101
+ def J_rp(self) -> torch.Tensor:
102
+ """Return the Jacobian of the roll and pitch angles of the gravity vector."""
103
+ return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
104
+
105
+ @property
106
+ def R(self) -> torch.Tensor:
107
+ """Return the rotation matrix from the gravity vector."""
108
+ return rad2rotmat(roll=self.roll, pitch=self.pitch)
109
+
110
+ def J_R(self) -> torch.Tensor:
111
+ """Return the Jacobian of the rotation matrix from the gravity vector."""
112
+ raise NotImplementedError
113
+
114
+ def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
115
+ """Update the gravity vector by adding a delta."""
116
+ if spherical:
117
+ data = SphericalManifold.plus(self.vec3d, delta)
118
+ return self.__class__(data)
119
+
120
+ data = EuclideanManifold.plus(self.rp, delta)
121
+ return self.from_rp(data[..., 0], data[..., 1])
122
+
123
+ def J_update(self, spherical: bool = False) -> torch.Tensor:
124
+ """Return the Jacobian of the update."""
125
+ return SphericalManifold if spherical else EuclideanManifold
126
+
127
+ def __repr__(self):
128
+ """Print the Camera object."""
129
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
scripts/camera/geometry/jacobians.py ADDED
@@ -0,0 +1,63 @@
1
+ """Jacobians for optimization."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import torch
5
+
6
+
7
+ @torch.jit.script
8
+ def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
9
+ """Compute the jacobian of vec / norm2(vec).
10
+
11
+ Args:
12
+ vec (torch.Tensor): [..., D] tensor.
13
+
14
+ Returns:
15
+ torch.Tensor: [..., D, D] Jacobian.
16
+ """
17
+ D = vec.shape[-1]
18
+ norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1)
19
+
20
+ if (norm_x == 0).any():
21
+ norm_x = norm_x + 1e-6
22
+
23
+ xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D)
24
+ identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D)
25
+
26
+ return identity / norm_x - (xxT / norm_x**3) # (..., D, D)
27
+
28
+
29
+ @torch.jit.script
30
+ def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
31
+ """Compute the jacobian of the focal2fov function."""
32
+ return -4 * h / (4 * focal**2 + h**2)
33
+
34
+
35
+ @torch.jit.script
36
+ def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
37
+ """Compute the jacobian of the up-vector projection.
38
+
39
+ Args:
40
+ uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
41
+ abc (torch.Tensor): Gravity vector of shape (..., 3).
42
+ wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".
43
+
44
+ Raises:
45
+ ValueError: If the wrt parameter is unknown.
46
+
47
+ Returns:
48
+ torch.Tensor: Jacobian with respect to the parameter.
49
+ """
50
+ if wrt == "uv":
51
+ c = abc[..., 2][..., None, None, None]
52
+ return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
53
+
54
+ elif wrt == "abc":
55
+ J = uv.new_zeros(uv.shape[:-1] + (2, 3))
56
+ J[..., 0, 0] = 1
57
+ J[..., 1, 1] = 1
58
+ J[..., 0, 2] = -uv[..., 0]
59
+ J[..., 1, 2] = -uv[..., 1]
60
+ return J
61
+
62
+ else:
63
+ raise ValueError(f"Unknown wrt: {wrt}")
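J_vecnorm can be sanity-checked against forward-mode autodiff of F.normalize. A minimal sketch (assumes the repository root is on PYTHONPATH and torch >= 2.0 for torch.func):

import torch
import torch.nn.functional as F
from torch.func import jacfwd, vmap

from scripts.camera.geometry.jacobians import J_vecnorm

vec = torch.randn(16, 3)
vec[..., 2] = vec[..., 2].abs() + 1.0        # keep norms well away from zero
J_auto = vmap(jacfwd(lambda v: F.normalize(v, dim=-1)))(vec)
assert torch.allclose(J_auto, J_vecnorm(vec), atol=1e-5)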
scripts/camera/geometry/manifolds.py ADDED
@@ -0,0 +1,113 @@
1
+ """Implementation of manifolds."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import logging
5
+
6
+ import torch
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class EuclideanManifold:
12
+ """Simple euclidean manifold."""
13
+
14
+ @staticmethod
15
+ def J_plus(x: torch.Tensor) -> torch.Tensor:
16
+ """Plus operator Jacobian."""
17
+ return torch.eye(x.shape[-1]).to(x)
18
+
19
+ @staticmethod
20
+ def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
21
+ """Plus operator."""
22
+ return x + delta
23
+
24
+
25
+ class SphericalManifold:
26
+ """Implementation of the spherical manifold.
27
+
28
+ Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
29
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
30
+
31
+ Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
32
+ Golub et al.
33
+ """
34
+
35
+ @staticmethod
36
+ def householder_vector(x: torch.Tensor) -> torch.Tensor:
37
+ """Return the Householder vector and beta.
38
+
39
+ Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
40
+ in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
41
+ first.
42
+
43
+ This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
44
+ orthogonal and H * x = ||x||_2 * e_n.
45
+
46
+ Args:
47
+ x (torch.Tensor): [..., n] tensor.
48
+
49
+ Returns:
50
+ torch.Tensor: v of shape [..., n]
51
+ torch.Tensor: beta of shape [...]
52
+ """
53
+ sigma = torch.sum(x[..., :-1] ** 2, -1)
54
+ xpiv = x[..., -1]
55
+ norm = torch.norm(x, dim=-1)
56
+ if torch.any(sigma < 1e-7):
57
+ sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
58
+ logger.warning("sigma < 1e-7")
59
+
60
+ vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
61
+ beta = 2 * vpiv**2 / (sigma + vpiv**2)
62
+ v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
63
+ return v, beta
64
+
65
+ @staticmethod
66
+ def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
67
+ """Apply Householder transformation.
68
+
69
+ Args:
70
+ y (torch.Tensor): Vector to transform of shape [..., n].
71
+ v (torch.Tensor): Householder vector of shape [..., n].
72
+ beta (torch.Tensor): Householder beta of shape [...].
73
+
74
+ Returns:
75
+ torch.Tensor: Transformed vector of shape [..., n].
76
+ """
77
+ return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]
78
+
79
+ @classmethod
80
+ def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
81
+ """Plus operator Jacobian."""
82
+ v, beta = cls.householder_vector(x)
83
+ H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
84
+ H = H + torch.eye(H.shape[-1]).to(H)
85
+ return H[..., :-1] # J
86
+
87
+ @classmethod
88
+ def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
89
+ """Plus operator.
90
+
91
+ Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
92
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
93
+ element of the input vector as pivot instead of first.
94
+
95
+ Args:
96
+ x: point on the manifold
97
+ delta: tangent vector
98
+ """
99
+ eps = 1e-7
100
+ # keep the original norm in case it is not equal to 1
101
+ nx = torch.norm(x, dim=-1, keepdim=True)
102
+ nd = torch.norm(delta, dim=-1, keepdim=True)
103
+
104
+ # make sure we don't divide by zero in backward as torch.where computes grad for both
105
+ # branches
106
+ nd_ = torch.where(nd < eps, nd + eps, nd)
107
+ sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
108
+
109
+ # cos is applied to last dim instead of first
110
+ exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
111
+
112
+ v, beta = cls.householder_vector(x)
113
+ return nx * cls.apply_householder(exp_delta, v, beta)
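The spherical retraction preserves the norm of the input point, so updates stay on the sphere. A small check (illustrative; assumes the repository root is on PYTHONPATH):

import torch
import torch.nn.functional as F

from scripts.camera.geometry.manifolds import SphericalManifold

x = F.normalize(torch.randn(4, 3), dim=-1)    # points on the unit sphere
delta = 0.1 * torch.randn(4, 2)               # tangent-space updates

x_new = SphericalManifold.plus(x, delta)
assert torch.allclose(x_new.norm(dim=-1), torch.ones(4), atol=1e-5)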
scripts/camera/geometry/perspective_fields.py ADDED
@@ -0,0 +1,379 @@
1
+ """Implementation of perspective fields.
2
+
3
+ Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py
4
+ """
5
+
6
+ from typing import Tuple
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+
11
+ from scripts.camera.geometry.base_camera import BaseCamera
12
+ from scripts.camera.geometry.gravity import Gravity
13
+ from scripts.camera.geometry.jacobians import J_up_projection, J_vecnorm
14
+ from scripts.camera.geometry.manifolds import SphericalManifold
15
+
16
+ # flake8: noqa: E266
17
+
18
+
19
+ def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor:
20
+ """Get the horizon line from the camera parameters.
21
+
22
+ Args:
23
+ camera (Camera): Camera parameters.
24
+ gravity (Gravity): Gravity vector.
25
+ relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True.
26
+
27
+ Returns:
28
+ torch.Tensor: In image frame, fraction of image left/right border intersection with
29
+ respect to image height.
30
+ """
31
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
32
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
33
+
34
+ # project horizon midpoint to image plane
35
+ horizon_midpoint = camera.new_tensor([0, 0, 1])
36
+ horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint
37
+ midpoint = horizon_midpoint[:2] / horizon_midpoint[2]
38
+
39
+ # compute left and right offset to borders
40
+ left_offset = midpoint[0] * torch.tan(gravity.roll)
41
+ right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll)
42
+ left, right = midpoint[1] + left_offset, midpoint[1] - right_offset
43
+
44
+ horizon = camera.new_tensor([left, right])
45
+ return horizon / camera.size[1] if relative else horizon
46
+
47
+
48
+ def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor:
49
+ """Get the up vector field from the camera parameters.
50
+
51
+ Args:
52
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
53
+ normalize (bool, optional): Whether to normalize the up vector. Defaults to True.
54
+
55
+ Returns:
56
+ torch.Tensor: up vector field as tensor of shape (..., h, w, 2).
57
+ """
58
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
59
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
60
+
61
+ w, h = camera.size[0].unbind(-1)
62
+ h, w = h.round().to(int), w.round().to(int)
63
+
64
+ uv = camera.normalize(camera.pixel_coordinates())
65
+
66
+ # projected up is (a, b) - c * (u, v)
67
+ abc = gravity.vec3d
68
+ projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2)
69
+
70
+ if hasattr(camera, "dist"):
71
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
72
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
73
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
74
+ offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
75
+
76
+ # (..., N, 2)
77
+ projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d)
78
+
79
+ if normalize:
80
+ projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2)
81
+
82
+ try:
83
+ del uv, abc, d_uv, offset
84
+ except NameError:
85
+ pass
86
+
87
+ return projected_up2d.reshape(camera.shape[0], h, w, 2)
88
+
89
+
90
+ def J_up_field(
91
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
92
+ ) -> torch.Tensor:
93
+ """Get the jacobian of the up field.
94
+
95
+ Args:
96
+ camera (Camera): Camera parameters.
97
+ gravity (Gravity): Gravity vector.
98
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
99
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
100
+
101
+ Returns:
102
+ torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3).
103
+ """
104
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
105
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
106
+
107
+ w, h = camera.size[0].unbind(-1)
108
+ h, w = h.round().to(int), w.round().to(int)
109
+
110
+ # Forward
111
+ xy = camera.pixel_coordinates()
112
+ uv = camera.normalize(xy)
113
+
114
+ projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv
115
+
116
+ # Backward
117
+ J = []
118
+
119
+ # (..., N, 2, 2)
120
+ J_norm2proj = J_vecnorm(
121
+ get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2)
122
+ )
123
+
124
+ # distortion values
125
+ if hasattr(camera, "dist"):
126
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
127
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
128
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
129
+ offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
130
+
131
+ ######################
132
+ ## Gravity Jacobian ##
133
+ ######################
134
+
135
+ J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3)
136
+
137
+ if hasattr(camera, "dist"):
138
+ # (..., N, 2, 3)
139
+ J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc)
140
+
141
+ J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
142
+ J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta)
143
+ J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta)
144
+ J.append(J_up2delta)
145
+
146
+ ######################
147
+ ### Focal Jacobian ###
148
+ ######################
149
+
150
+ J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2)
151
+
152
+ if hasattr(camera, "dist"):
153
+ J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv)
154
+ J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d)
155
+
156
+ inner = (uv * projected_up2d).sum(-1)[..., None, None]
157
+ J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv")
158
+ J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2)
159
+ J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up
160
+
161
+ J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2)
162
+
163
+ if log_focal:
164
+ J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2)
165
+
166
+ J_uv2f = J_uv2f.sum(-1) # (..., N, 2)
167
+
168
+ J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2)
169
+ J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1)
170
+ J.append(J_up2f)
171
+
172
+ ######################
173
+ ##### K1 Jacobian ####
174
+ ######################
175
+
176
+ if hasattr(camera, "dist"):
177
+ J_duv = camera.J_distort(uv, wrt="scale2dist")
178
+ J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2)
179
+ J_offset = torch.einsum(
180
+ "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv
181
+ )
182
+ J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d)
183
+ J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None]
184
+ J.append(J_k1)
185
+
186
+ n_params = sum(j.shape[-1] for j in J)
187
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params)
188
+
189
+
190
+ def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor:
191
+ """Get the latitudes of the camera pixels in radians.
192
+
193
+ Latitudes are defined as the angle between the ray and the up vector.
194
+
195
+ Args:
196
+ camera (Camera): Camera parameters.
197
+ gravity (Gravity): Gravity vector.
198
+
199
+ Returns:
200
+ torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1).
201
+ """
202
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
203
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
204
+
205
+ w, h = camera.size[0].unbind(-1)
206
+ h, w = h.round().to(int), w.round().to(int)
207
+
208
+ uv1, _ = camera.image2world(camera.pixel_coordinates())
209
+ rays = camera.pixel_bearing_many(uv1)
210
+
211
+ lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d)
212
+
213
+ eps = 1e-6
214
+ lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps))
215
+
216
+ try:
217
+ del uv1, rays
218
+ except NameError:
219
+ pass
220
+
221
+ return lat_asin.reshape(camera.shape[0], h, w, 1)
222
+
223
+
224
+ def J_latitude_field(
225
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
226
+ ) -> torch.Tensor:
227
+ """Get the jacobian of the latitude field.
228
+
229
+ Args:
230
+ camera (Camera): Camera parameters.
231
+ gravity (Gravity): Gravity vector.
232
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
233
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
234
+
235
+ Returns:
236
+ torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3).
237
+ """
238
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
239
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
240
+
241
+ w, h = camera.size[0].unbind(-1)
242
+ h, w = h.round().to(int), w.round().to(int)
243
+
244
+ # Forward
245
+ xy = camera.pixel_coordinates()
246
+ uv1, _ = camera.image2world(xy)
247
+ uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3)
248
+
249
+ # Backward
250
+ J = []
251
+ J_norm2w_to_img = J_vecnorm(uv1)[..., :2] # (..., N, 2)
252
+
253
+ ######################
254
+ ## Gravity Jacobian ##
255
+ ######################
256
+
257
+ J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
258
+ J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2)
259
+ J.append(J_delta)
260
+
261
+ ######################
262
+ ### Focal Jacobian ###
263
+ ######################
264
+
265
+ J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2)
266
+ if log_focal:
267
+ J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :]
268
+ J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2)
269
+
270
+ J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3)
271
+ J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1)
272
+ J.append(J_f)
273
+
274
+ ######################
275
+ ##### K1 Jacobian ####
276
+ ######################
277
+
278
+ if hasattr(camera, "dist"):
279
+ J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2)
280
+ # (..., N, 2)
281
+ J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1)
282
+ # (..., N, 1)
283
+ J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1)
284
+ J.append(J_k1)
285
+
286
+ n_params = sum(j.shape[-1] for j in J)
287
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params)
288
+
289
+
290
+ def get_perspective_field(
291
+ camera: BaseCamera,
292
+ gravity: Gravity,
293
+ use_up: bool = True,
294
+ use_latitude: bool = True,
295
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
296
+ """Get the perspective field from the camera parameters.
297
+
298
+ Args:
299
+ camera (Camera): Camera parameters.
300
+ gravity (Gravity): Gravity vector.
301
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
302
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
303
+
304
+ Returns:
305
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape
306
+ (..., 2, h, w) and (..., 1, h, w).
307
+ """
308
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
309
+
310
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
311
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
312
+
313
+ w, h = camera.size[0].unbind(-1)
314
+ h, w = h.round().to(int), w.round().to(int)
315
+
316
+ if use_up:
317
+ permute = (0, 3, 1, 2)
318
+ # (..., 2, h, w)
319
+ up = get_up_field(camera, gravity).permute(permute)
320
+ else:
321
+ shape = (camera.shape[0], 2, h, w)
322
+ up = camera.new_zeros(shape)
323
+
324
+ if use_latitude:
325
+ permute = (0, 3, 1, 2)
326
+ # (..., 1, h, w)
327
+ lat = get_latitude_field(camera, gravity).permute(permute)
328
+ else:
329
+ shape = (camera.shape[0], 1, h, w)
330
+ lat = camera.new_zeros(shape)
331
+
332
+ torch.cuda.empty_cache()
333
+
334
+ return up, lat
335
+
336
+
337
+ def J_perspective_field(
338
+ camera: BaseCamera,
339
+ gravity: Gravity,
340
+ use_up: bool = True,
341
+ use_latitude: bool = True,
342
+ spherical: bool = False,
343
+ log_focal: bool = False,
344
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
345
+ """Get the jacobian of the perspective field.
346
+
347
+ Args:
348
+ camera (Camera): Camera parameters.
349
+ gravity (Gravity): Gravity vector.
350
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
351
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
352
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
353
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
354
+
355
+ Returns:
356
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape
357
+ (..., h, w, 2, 4) and (..., h, w, 1, 4).
358
+ """
359
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
360
+
361
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
362
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
363
+
364
+ w, h = camera.size[0].unbind(-1)
365
+ h, w = h.round().to(int), w.round().to(int)
366
+
367
+ if use_up:
368
+ J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4)
369
+ else:
370
+ shape = (camera.shape[0], h, w, 2, 4)
371
+ J_up = camera.new_zeros(shape)
372
+
373
+ if use_latitude:
374
+ J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4)
375
+ else:
376
+ shape = (camera.shape[0], h, w, 1, 4)
377
+ J_lat = camera.new_zeros(shape)
378
+
379
+ return J_up, J_lat
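A usage sketch for the perspective fields (not part of the uploaded files; assumes the [w, h, fx, fy, cx, cy] parameter layout used above and the repository root on PYTHONPATH):

import torch

from scripts.camera.geometry.camera import Pinhole
from scripts.camera.geometry.gravity import Gravity
from scripts.camera.geometry.perspective_fields import get_perspective_field

cam = Pinhole(torch.tensor([[64.0, 48.0, 50.0, 50.0, 32.0, 24.0]]))   # small 64x48 image
grav = Gravity.from_rp(torch.tensor([0.05]), torch.tensor([-0.1]))

up, lat = get_perspective_field(cam, grav)
# up: (1, 2, 48, 64) per-pixel up vectors, lat: (1, 1, 48, 64) latitudes in radians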
scripts/camera/utils/conversions.py ADDED
@@ -0,0 +1,150 @@
1
+ """Utility functions for conversions between different representations."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+
9
+ def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
10
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
11
+
12
+ Args:
13
+ (torch.Tensor): Vector of size (..., 3).
14
+
15
+ Returns:
16
+ (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3).
17
+ """
18
+ z = torch.zeros_like(v[..., 0])
19
+ return torch.stack(
20
+ [
21
+ z,
22
+ -v[..., 2],
23
+ v[..., 1],
24
+ v[..., 2],
25
+ z,
26
+ -v[..., 0],
27
+ -v[..., 1],
28
+ v[..., 0],
29
+ z,
30
+ ],
31
+ dim=-1,
32
+ ).reshape(v.shape[:-1] + (3, 3))
33
+
34
+
35
+ def rad2rotmat(
36
+ roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None
37
+ ) -> torch.Tensor:
38
+ """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix.
39
+
40
+ Args:
41
+ roll (torch.Tensor): Roll angle in radians.
42
+ pitch (torch.Tensor): Pitch angle in radians.
43
+ yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None.
44
+
45
+ Returns:
46
+ torch.Tensor: Rotation matrix of shape (..., 3, 3).
47
+ """
48
+ if yaw is None:
49
+ yaw = roll.new_zeros(roll.shape)
50
+
51
+ Rx = pitch.new_zeros(pitch.shape + (3, 3))
52
+ Rx[..., 0, 0] = 1
53
+ Rx[..., 1, 1] = torch.cos(pitch)
54
+ Rx[..., 1, 2] = torch.sin(pitch)
55
+ Rx[..., 2, 1] = -torch.sin(pitch)
56
+ Rx[..., 2, 2] = torch.cos(pitch)
57
+
58
+ Ry = yaw.new_zeros(yaw.shape + (3, 3))
59
+ Ry[..., 0, 0] = torch.cos(yaw)
60
+ Ry[..., 0, 2] = -torch.sin(yaw)
61
+ Ry[..., 1, 1] = 1
62
+ Ry[..., 2, 0] = torch.sin(yaw)
63
+ Ry[..., 2, 2] = torch.cos(yaw)
64
+
65
+ Rz = roll.new_zeros(roll.shape + (3, 3))
66
+ Rz[..., 0, 0] = torch.cos(roll)
67
+ Rz[..., 0, 1] = torch.sin(roll)
68
+ Rz[..., 1, 0] = -torch.sin(roll)
69
+ Rz[..., 1, 1] = torch.cos(roll)
70
+ Rz[..., 2, 2] = 1
71
+
72
+ return Rz @ Rx @ Ry
73
+
74
+
75
+ def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
76
+ """Compute focal length from (vertical/horizontal) field of view.
77
+
78
+ Args:
79
+ fov (torch.Tensor): Field of view in radians.
80
+ size (torch.Tensor): Image height / width in pixels.
81
+
82
+ Returns:
83
+ torch.Tensor: Focal length in pixels.
84
+ """
85
+ return size / 2 / torch.tan(fov / 2)
86
+
87
+
88
+ def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
89
+ """Compute (vertical/horizontal) field of view from focal length.
90
+
91
+ Args:
92
+ focal (torch.Tensor): Focal length in pixels.
93
+ size (torch.Tensor): Image height / width in pixels.
94
+
95
+ Returns:
96
+ torch.Tensor: Field of view in radians.
97
+ """
98
+ return 2 * torch.arctan(size / (2 * focal))
99
+
100
+
101
+ def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
102
+ """Compute the distance from principal point to the horizon.
103
+
104
+ Args:
105
+ pitch (torch.Tensor): Pitch angle in radians.
106
+ f (torch.Tensor): Focal length in pixels.
107
+ h (torch.Tensor): Image height in pixels.
108
+
109
+ Returns:
110
+ torch.Tensor: Relative distance to the horizon.
111
+ """
112
+ return torch.tan(pitch) * f / h
113
+
114
+
115
+ def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
116
+ """Compute the pitch angle from the distance to the horizon.
117
+
118
+ Args:
119
+ rho (torch.Tensor): Relative distance to the horizon.
120
+ f (torch.Tensor): Focal length in pixels.
121
+ h (torch.Tensor): Image height in pixels.
122
+
123
+ Returns:
124
+ torch.Tensor: Pitch angle in radians.
125
+ """
126
+ return torch.atan(rho * h / f)
127
+
128
+
129
+ def rad2deg(rad: torch.Tensor) -> torch.Tensor:
130
+ """Convert radians to degrees.
131
+
132
+ Args:
133
+ rad (torch.Tensor): Angle in radians.
134
+
135
+ Returns:
136
+ torch.Tensor: Angle in degrees.
137
+ """
138
+ return rad / torch.pi * 180
139
+
140
+
141
+ def deg2rad(deg: torch.Tensor) -> torch.Tensor:
142
+ """Convert degrees to radians.
143
+
144
+ Args:
145
+ deg (torch.Tensor): Angle in degrees.
146
+
147
+ Returns:
148
+ torch.Tensor: Angle in radians.
149
+ """
150
+ return deg / 180 * torch.pi
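The conversion helpers compose into simple round trips, e.g. field of view <-> focal length. A short example:

import torch

from scripts.camera.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg

h = torch.tensor(480.0)
vfov = deg2rad(torch.tensor(60.0))

f = fov2focal(vfov, h)                         # focal length in pixels
assert torch.allclose(focal2fov(f, h), vfov)
assert torch.allclose(rad2deg(vfov), torch.tensor(60.0))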
scripts/camera/utils/image.py ADDED
@@ -0,0 +1,182 @@
1
+ """Image preprocessing utilities."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ import collections.abc as collections
5
+ from pathlib import Path
6
+ from typing import Optional, Tuple
7
+
8
+ import cv2
9
+ import kornia
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image
15
+
16
+ from scripts.camera.utils.tensor import fit_features_to_multiple
17
+
18
+ # mypy: ignore-errors
19
+
20
+
21
+ class ImagePreprocessor:
22
+ """Preprocess images for calibration."""
23
+
24
+ default_conf = {
25
+ "resize": None, # target edge length (320), None for no resizing
26
+ "edge_divisible_by": None,
27
+ "side": "short",
28
+ "interpolation": "bilinear",
29
+ "align_corners": None,
30
+ "antialias": True,
31
+ "square_crop": False,
32
+ "add_padding_mask": False,
33
+ "resize_backend": "kornia", # torchvision, kornia
34
+ }
35
+
36
+ def __init__(self, conf) -> None:
37
+ """Initialize the image preprocessor."""
38
+ super().__init__()
39
+ default_conf = OmegaConf.create(self.default_conf)
40
+ OmegaConf.set_struct(default_conf, True)
41
+ self.conf = OmegaConf.merge(default_conf, conf)
42
+
43
+ def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
44
+ """Resize and preprocess an image, return image and resize scale."""
45
+ h, w = img.shape[-2:]
46
+ size = h, w
47
+
48
+ if self.conf.square_crop:
49
+ min_size = min(h, w)
50
+ offset = (h - min_size) // 2, (w - min_size) // 2
51
+ img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size]
52
+ size = img.shape[-2:]
53
+
54
+ if self.conf.resize is not None:
55
+ if interpolation is None:
56
+ interpolation = self.conf.interpolation
57
+ size = self.get_new_image_size(h, w)
58
+ img = self.resize(img, size, interpolation)
59
+
60
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
61
+ T = np.diag([scale[0].cpu(), scale[1].cpu(), 1])
62
+
63
+ data = {
64
+ "scales": scale,
65
+ "image_size": np.array(size[::-1]),
66
+ "transform": T,
67
+ "original_image_size": np.array([w, h]),
68
+ }
69
+
70
+ if self.conf.edge_divisible_by is not None:
71
+ # crop to make the edge divisible by a number
72
+ w_, h_ = img.shape[-1], img.shape[-2]
73
+ img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True)
74
+ crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img)
75
+ data["crop_pad"] = crop_pad
76
+ data["image_size"] = np.array([img.shape[-1], img.shape[-2]])
77
+
78
+ data["image"] = img
79
+ return data
80
+
81
+ def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor:
82
+ """Resize an image using the specified backend."""
83
+ if self.conf.resize_backend == "kornia":
84
+ return kornia.geometry.transform.resize(
85
+ img,
86
+ size,
87
+ side=self.conf.side,
88
+ antialias=self.conf.antialias,
89
+ align_corners=self.conf.align_corners,
90
+ interpolation=interpolation,
91
+ )
92
+ elif self.conf.resize_backend == "PIL":
93
+ device = img.device
94
+ imgs = []
95
+ has_batch_dim = img.ndim == 4
96
+ img = img if has_batch_dim else img[None]
97
+ for im in img:
98
+ im = (im.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
99
+ im = Image.fromarray(im).resize(size[::-1], Image.BILINEAR)
100
+ im = torch.tensor(np.array(im)).permute(2, 0, 1) / 255.0
101
+ imgs.append(im.to(device))
102
+ imgs = torch.stack(imgs)
103
+ return imgs if has_batch_dim else imgs[0]
104
+
105
+ elif self.conf.resize_backend == "torchvision":
106
+ return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img)
107
+ else:
108
+ raise ValueError(f"{self.conf.resize_backend} not implemented.")
109
+
110
+ def load_image(self, image_path: Path) -> dict:
111
+ """Load an image from a path and preprocess it."""
112
+ return self(load_image(image_path))
113
+
114
+ def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
115
+ """Get the new image size after resizing."""
116
+ side = self.conf.side
117
+ if isinstance(self.conf.resize, collections.Iterable):
118
+ assert len(self.conf.resize) == 2
119
+ return tuple(self.conf.resize)
120
+ side_size = self.conf.resize
121
+ aspect_ratio = w / h
122
+ if side not in ("short", "long", "vert", "horz"):
123
+ raise ValueError(
124
+ f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
125
+ )
126
+ return (
127
+ (side_size, int(side_size * aspect_ratio))
128
+ if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0))
129
+ else (int(side_size / aspect_ratio), side_size)
130
+ )
131
+
132
+
133
+ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
134
+ """Normalize the image tensor and reorder the dimensions."""
135
+ if image.ndim == 3:
136
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
137
+ elif image.ndim == 2:
138
+ image = image[None] # add channel axis
139
+ else:
140
+ raise ValueError(f"Not an image: {image.shape}")
141
+ return torch.tensor(image / 255.0, dtype=torch.float)
142
+
143
+
144
+ def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray:
145
+ """Normalize and reorder the dimensions of an image tensor."""
146
+ if image.ndim == 3:
147
+ image = image.permute((1, 2, 0)) # CxHxW to HxWxC
148
+ elif image.ndim == 2:
149
+ image = image[None] # add channel axis
150
+ else:
151
+ raise ValueError(f"Not an image: {image.shape}")
152
+ return (image.cpu().detach().numpy() * 255).astype(np.uint8)
153
+
154
+
155
+ def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
156
+ """Read an image from path as RGB or grayscale."""
157
+ if not Path(path).exists():
158
+ raise FileNotFoundError(f"No image at path {path}.")
159
+ mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
160
+ image = cv2.imread(str(path), mode)
161
+ if image is None:
162
+ raise IOError(f"Could not read image at {path}.")
163
+ if not grayscale:
164
+ image = image[..., ::-1]
165
+ return image
166
+
167
+
168
+ def write_image(img: torch.Tensor, path: Path):
169
+ """Write an image tensor to a file."""
170
+ img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img
171
+ cv2.imwrite(str(path), img[..., ::-1])
172
+
173
+
174
+ def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor:
175
+ """Load an image from a path and return as a tensor."""
176
+ image = read_image(path, grayscale=grayscale)
177
+ if return_tensor:
178
+ return numpy_image_to_torch(image)
179
+
180
+ assert image.ndim in [2, 3], f"Not an image: {image.shape}"
181
+ image = image[None] if image.ndim == 2 else image
182
+ return torch.tensor(image.copy(), dtype=torch.uint8)
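A minimal usage sketch for the preprocessor defined above. The class name ImagePreprocessor is assumed (its class statement sits above this excerpt, and the code is adapted from GeoCalib); the resize value and the random input tensor are placeholders.

import torch
from omegaconf import OmegaConf

from scripts.camera.utils.image import ImagePreprocessor  # assumed class name

conf = OmegaConf.create({"resize": 320, "square_crop": True})
preprocessor = ImagePreprocessor(conf)
out = preprocessor(torch.rand(3, 480, 640))   # C x H x W image in [0, 1]
print(out["image"].shape, out["scales"], out["original_image_size"])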
scripts/camera/utils/tensor.py ADDED
@@ -0,0 +1,249 @@
1
+ """Adapted from https://github.com/cvg/GeoCalib"""
2
+
3
+ import collections.abc as collections
4
+ import functools
5
+ import inspect
6
+ from typing import Callable, List, Tuple
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ # flake8: noqa
12
+ # mypy: ignore-errors
13
+
14
+
15
+ string_classes = (str, bytes)
16
+
17
+
18
+ def autocast(func: Callable) -> Callable:
19
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
20
+
21
+ Use the device and dtype of the wrapper.
22
+
23
+ Args:
24
+ func (Callable): Method of a TensorWrapper class.
25
+
26
+ Returns:
27
+ Callable: Wrapped method.
28
+ """
29
+
30
+ @functools.wraps(func)
31
+ def wrap(self, *args):
32
+ device = torch.device("cpu")
33
+ dtype = None
34
+ if isinstance(self, TensorWrapper):
35
+ if self._data is not None:
36
+ device = self.device
37
+ dtype = self.dtype
38
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
39
+ raise ValueError(self)
40
+
41
+ cast_args = []
42
+ for arg in args:
43
+ if isinstance(arg, np.ndarray):
44
+ arg = torch.from_numpy(arg)
45
+ arg = arg.to(device=device, dtype=dtype)
46
+ cast_args.append(arg)
47
+ return func(self, *cast_args)
48
+
49
+ return wrap
50
+
51
+
52
+ class TensorWrapper:
53
+ """Wrapper for PyTorch tensors."""
54
+
55
+ _data = None
56
+
57
+ @autocast
58
+ def __init__(self, data: torch.Tensor):
59
+ """Wrapper for PyTorch tensors."""
60
+ self._data = data
61
+
62
+ @property
63
+ def shape(self) -> torch.Size:
64
+ """Shape of the underlying tensor."""
65
+ return self._data.shape[:-1]
66
+
67
+ @property
68
+ def device(self) -> torch.device:
69
+ """Get the device of the underlying tensor."""
70
+ return self._data.device
71
+
72
+ @property
73
+ def dtype(self) -> torch.dtype:
74
+ """Get the dtype of the underlying tensor."""
75
+ return self._data.dtype
76
+
77
+ def __getitem__(self, index) -> torch.Tensor:
78
+ """Get the underlying tensor."""
79
+ return self.__class__(self._data[index])
80
+
81
+ def __setitem__(self, index, item):
82
+ """Set the underlying tensor."""
83
+ self._data[index] = item.data
84
+
85
+ def to(self, *args, **kwargs):
86
+ """Move the underlying tensor to a new device."""
87
+ return self.__class__(self._data.to(*args, **kwargs))
88
+
89
+ def cpu(self):
90
+ """Move the underlying tensor to the CPU."""
91
+ return self.__class__(self._data.cpu())
92
+
93
+ def cuda(self):
94
+ """Move the underlying tensor to the GPU."""
95
+ return self.__class__(self._data.cuda())
96
+
97
+ def pin_memory(self):
98
+ """Pin the underlying tensor to memory."""
99
+ return self.__class__(self._data.pin_memory())
100
+
101
+ def float(self):
102
+ """Cast the underlying tensor to float."""
103
+ return self.__class__(self._data.float())
104
+
105
+ def double(self):
106
+ """Cast the underlying tensor to double."""
107
+ return self.__class__(self._data.double())
108
+
109
+ def detach(self):
110
+ """Detach the underlying tensor."""
111
+ return self.__class__(self._data.detach())
112
+
113
+ def numpy(self):
114
+ """Convert the underlying tensor to a numpy array."""
115
+ return self._data.detach().cpu().numpy()
116
+
117
+ def new_tensor(self, *args, **kwargs):
118
+ """Create a new tensor of the same type and device."""
119
+ return self._data.new_tensor(*args, **kwargs)
120
+
121
+ def new_zeros(self, *args, **kwargs):
122
+ """Create a new tensor of the same type and device."""
123
+ return self._data.new_zeros(*args, **kwargs)
124
+
125
+ def new_ones(self, *args, **kwargs):
126
+ """Create a new tensor of the same type and device."""
127
+ return self._data.new_ones(*args, **kwargs)
128
+
129
+ def new_full(self, *args, **kwargs):
130
+ """Create a new tensor of the same type and device."""
131
+ return self._data.new_full(*args, **kwargs)
132
+
133
+ def new_empty(self, *args, **kwargs):
134
+ """Create a new tensor of the same type and device."""
135
+ return self._data.new_empty(*args, **kwargs)
136
+
137
+ def unsqueeze(self, *args, **kwargs):
138
+ """Create a new tensor of the same type and device."""
139
+ return self.__class__(self._data.unsqueeze(*args, **kwargs))
140
+
141
+ def squeeze(self, *args, **kwargs):
142
+ """Create a new tensor of the same type and device."""
143
+ return self.__class__(self._data.squeeze(*args, **kwargs))
144
+
145
+ @classmethod
146
+ def stack(cls, objects: List, dim=0, *, out=None):
147
+ """Stack a list of objects with the same type and shape."""
148
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
149
+ return cls(data)
150
+
151
+ @classmethod
152
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
153
+ """Support torch functions."""
154
+ if kwargs is None:
155
+ kwargs = {}
156
+ return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
157
+
158
+
159
+ def map_tensor(input_, func):
160
+ if isinstance(input_, string_classes):
161
+ return input_
162
+ elif isinstance(input_, collections.Mapping):
163
+ return {k: map_tensor(sample, func) for k, sample in input_.items()}
164
+ elif isinstance(input_, collections.Sequence):
165
+ return [map_tensor(sample, func) for sample in input_]
166
+ elif input_ is None:
167
+ return None
168
+ else:
169
+ return func(input_)
170
+
171
+
172
+ def batch_to_numpy(batch):
173
+ return map_tensor(batch, lambda tensor: tensor.cpu().numpy())
174
+
175
+
176
+ def batch_to_device(batch, device, non_blocking=True, detach=False):
177
+ def _func(tensor):
178
+ t = tensor.to(device=device, non_blocking=non_blocking, dtype=torch.float32)
179
+ return t.detach() if detach else t
180
+
181
+ return map_tensor(batch, _func)
182
+
183
+
184
+ def remove_batch_dim(data: dict) -> dict:
185
+ """Remove batch dimension from elements in data"""
186
+ return {
187
+ k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()
188
+ }
189
+
190
+
191
+ def add_batch_dim(data: dict) -> dict:
192
+ """Add batch dimension to elements in data"""
193
+ return {
194
+ k: v[None] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
195
+ for k, v in data.items()
196
+ }
197
+
198
+
199
+ def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
200
+ """Get padding to make the image size a multiple of the given number.
201
+
202
+ Args:
203
+ x (torch.Tensor): Input tensor.
204
+ multiple (int, optional): Multiple.
205
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
206
+
207
+ Returns:
208
+ torch.Tensor: Padding.
209
+ """
210
+ h, w = x.shape[-2:]
211
+
212
+ if crop:
213
+ pad_w = (w // multiple) * multiple - w
214
+ pad_h = (h // multiple) * multiple - h
215
+ else:
216
+ pad_w = (multiple - w % multiple) % multiple
217
+ pad_h = (multiple - h % multiple) % multiple
218
+
219
+ if mode == "center":
220
+ pad_l = pad_w // 2
221
+ pad_r = pad_w - pad_l
222
+ pad_t = pad_h // 2
223
+ pad_b = pad_h - pad_t
224
+ elif mode == "left":
225
+ pad_l = 0
226
+ pad_r = pad_w
227
+ pad_t = 0
228
+ pad_b = pad_h
229
+ else:
230
+ raise ValueError(f"Unknown mode {mode}")
231
+
232
+ return (pad_l, pad_r, pad_t, pad_b)
233
+
234
+
235
+ def fit_features_to_multiple(
236
+ features: torch.Tensor, multiple: int = 32, crop: bool = False
237
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
238
+ """Pad image to a multiple of the given number.
239
+
240
+ Args:
241
+ features (torch.Tensor): Input features.
242
+ multiple (int, optional): Multiple. Defaults to 32.
243
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
244
+
245
+ Returns:
246
+ Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding.
247
+ """
248
+ pad = fit_to_multiple(features, multiple, crop=crop)
249
+ return torch.nn.functional.pad(features, pad, mode="reflect"), pad
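A quick sanity-check sketch for the padding and batch helpers above; the feature shape is arbitrary.

import torch

from scripts.camera.utils.tensor import batch_to_device, fit_features_to_multiple

feat = torch.rand(1, 3, 250, 333)
padded, pad = fit_features_to_multiple(feat, multiple=32)
print(padded.shape, pad)   # torch.Size([1, 3, 256, 352]), reflect-padded by (9, 10, 3, 3)

batch = {"image": feat, "name": "sample"}   # strings pass through map_tensor unchanged
batch = batch_to_device(batch, "cpu")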
scripts/camera/utils/text.py ADDED
@@ -0,0 +1,47 @@
1
+ import re
2
+ from typing import Tuple
3
+
4
+ def parse_camera_params(
5
+ text: str,
6
+ mode: str = "base"
7
+ ) -> Tuple[float, float, float]:
8
+ """
9
+ Extract roll, pitch, fov from text using one of two patterns:
10
+ - 'base' mode: ... are: roll, pitch, fov.
11
+ - 'cot' mode: <answer>roll, pitch, fov</answer>
12
+
13
+ Args:
14
+ text: The full text to search.
15
+ mode: One of {"base", "cot"}.
16
+
17
+ Returns:
18
+ roll, pitch, fov as floats.
19
+
20
+ Raises:
21
+ ValueError if the chosen pattern is not found, or mode is invalid.
22
+ """
23
+ # compile both regexes
24
+ pat_base = re.compile(
25
+ r"are:\s*([+-]?\d+(?:\.\d+)?)\s*,\s*"
26
+ r"([+-]?\d+(?:\.\d+)?)\s*,\s*"
27
+ r"([+-]?\d+(?:\.\d+)?)[\.\s]*$"
28
+ )
29
+ pat_cot = re.compile(
30
+ r"<answer>\s*([+-]?\d+(?:\.\d+)?)\s*,\s*"
31
+ r"([+-]?\d+(?:\.\d+)?)\s*,\s*"
32
+ r"([+-]?\d+(?:\.\d+)?)\s*</answer>"
33
+ )
34
+
35
+ m = None
36
+ if mode == "base":
37
+ m = pat_base.search(text)
38
+ elif mode == "cot":
39
+ m = pat_cot.search(text)
40
+ else:
41
+ raise ValueError(f"Invalid mode: {mode!r}. Choose 'base', 'cot', or 'auto'.")
42
+
43
+ if not m:
44
+ raise ValueError(f"No camera parameters found using mode '{mode}'.")
45
+
46
+ roll_s, pitch_s, fov_s = m.group(1), m.group(2), m.group(3)
47
+ return float(roll_s), float(pitch_s), float(fov_s)
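Two made-up strings showing what each parsing mode expects:

from scripts.camera.utils.text import parse_camera_params

base_text = "The estimated roll, pitch and field of view are: -3.2, 12.5, 60.0."
cot_text = "reasoning goes here <answer>-3.2, 12.5, 60.0</answer>"

print(parse_camera_params(base_text, mode="base"))   # (-3.2, 12.5, 60.0)
print(parse_camera_params(cot_text, mode="cot"))     # (-3.2, 12.5, 60.0)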
scripts/camera/visualization/visualize_batch.py ADDED
@@ -0,0 +1,188 @@
1
+ """Visualization of predicted and ground truth for a single batch."""
2
+ """Adapted from https://github.com/cvg/GeoCalib"""
3
+
4
+ from typing import Any, Dict
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from scripts.camera.geometry.perspective_fields import get_latitude_field
10
+ from scripts.camera.utils.conversions import rad2deg
11
+ from scripts.camera.utils.tensor import batch_to_device
12
+ from scripts.camera.visualization.viz2d import (
13
+ plot_confidences,
14
+ plot_heatmaps,
15
+ plot_image_grid,
16
+ plot_latitudes,
17
+ plot_vector_fields,
18
+ )
19
+
20
+
21
+ def make_up_figure(
22
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
23
+ ) -> Dict[str, Any]:
24
+ """Get predicted and ground truth up fields and errors.
25
+
26
+ Args:
27
+ pred (Dict[str, torch.Tensor]): Predicted up field.
28
+ data (Dict[str, torch.Tensor]): Ground truth up field.
29
+ n_pairs (int): Number of pairs to visualize.
30
+
31
+ Returns:
32
+ Dict[str, Any]: Dictionary with figure.
33
+ """
34
+ pred = batch_to_device(pred, "cpu", detach=True)
35
+ data = batch_to_device(data, "cpu", detach=True)
36
+
37
+ n_pairs = min(n_pairs, len(data["image"]))
38
+
39
+ if "up_field" not in pred.keys():
40
+ return {}
41
+
42
+ up_fields = []
43
+ for i in range(n_pairs):
44
+ row = [data["up_field"][i]]
45
+ titles = ["Up GT"]
46
+
47
+ if "up_confidence" in pred.keys():
48
+ row += [pred["up_confidence"][i]]
49
+ titles += ["Up Confidence"]
50
+
51
+ row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
52
+ up_fields.append(row)
53
+
54
+ # create figure
55
+ N, M = len(up_fields), len(up_fields[0]) + 1
56
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
57
+ fig, ax = plot_image_grid(imgs, return_fig=True, set_lim=True)
58
+ ax = np.array(ax)
59
+
60
+ for i in range(n_pairs):
61
+ plot_vector_fields([up_fields[i][0]], axes=ax[i, [1]])
62
+ #plot_heatmaps([up_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])
63
+
64
+ if "up_confidence" in pred.keys():
65
+ plot_confidences([up_fields[i][1]], axes=ax[i, [2]])
66
+
67
+ return {"up": fig}
68
+
69
+
70
+ def make_latitude_figure(
71
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
72
+ ) -> Dict[str, Any]:
73
+ """Get predicted and ground truth latitude fields and errors.
74
+
75
+ Args:
76
+ pred (Dict[str, torch.Tensor]): Predicted latitude field.
77
+ data (Dict[str, torch.Tensor]): Ground truth latitude field.
78
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
79
+
80
+ Returns:
81
+ Dict[str, Any]: Dictionary with figure.
82
+ """
83
+ pred = batch_to_device(pred, "cpu", detach=True)
84
+ data = batch_to_device(data, "cpu", detach=True)
85
+
86
+ n_pairs = min(n_pairs, len(data["image"]))
87
+ latitude_fields = []
88
+
89
+ if "latitude_field" not in pred.keys():
90
+ return {}
91
+
92
+ for i in range(n_pairs):
93
+ row = [
94
+ rad2deg(data["latitude_field"][i][0]),
95
+ #rad2deg(pred["latitude_field"][i][0]),
96
+ #errors[i],
97
+ ]
98
+ titles = ["Latitude GT"]
99
+
100
+ if "latitude_confidence" in pred.keys():
101
+ row += [pred["latitude_confidence"][i]]
102
+ titles += ["Latitude Confidence"]
103
+
104
+ row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
105
+ latitude_fields.append(row)
106
+
107
+ # create figure
108
+ N, M = len(latitude_fields), len(latitude_fields[0]) + 1
109
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
110
+ fig, ax = plot_image_grid(imgs, return_fig=True, set_lim=True)
111
+ ax = np.array(ax)
112
+
113
+ for i in range(n_pairs):
114
+ plot_latitudes([latitude_fields[i][0]], is_radians=False, axes=ax[i, [1]])
115
+ #plot_heatmaps([latitude_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])
116
+
117
+ if "latitude_confidence" in pred.keys():
118
+ plot_confidences([latitude_fields[i][1]], axes=ax[i, [2]])
119
+
120
+ return {"latitude": fig}
121
+
122
+
123
+ def make_camera_figure(
124
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
125
+ ) -> Dict[str, Any]:
126
+ """Get predicted and ground truth camera parameters.
127
+
128
+ Args:
129
+ pred (Dict[str, torch.Tensor]): Predicted camera parameters.
130
+ data (Dict[str, torch.Tensor]): Ground truth camera parameters.
131
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
132
+
133
+ Returns:
134
+ Dict[str, Any]: Dictionary with figure.
135
+ """
136
+ pred = batch_to_device(pred, "cpu", detach=True)
137
+ data = batch_to_device(data, "cpu", detach=True)
138
+
139
+ n_pairs = min(n_pairs, len(data["image"]))
140
+
141
+ if "camera" not in pred.keys():
142
+ return {}
143
+
144
+ latitudes = []
145
+ for i in range(n_pairs):
146
+ titles = ["Cameras GT"]
147
+ row = [get_latitude_field(data["camera"][i], data["gravity"][i])]
148
+
149
+ if "camera" in pred.keys() and "gravity" in pred.keys():
150
+ row += [get_latitude_field(pred["camera"][i], pred["gravity"][i])]
151
+ titles += ["Cameras Pred"]
152
+
153
+ row = [rad2deg(r).squeeze(-1).float().numpy()[0] for r in row]
154
+ latitudes.append(row)
155
+
156
+ # create figure
157
+ N, M = len(latitudes), len(latitudes[0]) + 1
158
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
159
+ fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
160
+ ax = np.array(ax)
161
+
162
+ for i in range(n_pairs):
163
+ plot_latitudes(latitudes[i], is_radians=False, axes=ax[i, 1:])
164
+
165
+ return {"camera": fig}
166
+
167
+
168
+ def make_perspective_figures(
169
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
170
+ ) -> Dict[str, Any]:
171
+ """Get predicted and ground truth perspective fields.
172
+
173
+ Args:
174
+ pred (Dict[str, torch.Tensor]): Predicted perspective fields.
175
+ data (Dict[str, torch.Tensor]): Ground truth perspective fields.
176
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
177
+
178
+ Returns:
179
+ Dict[str, Any]: Dictionary with figure.
180
+ """
181
+ n_pairs = min(n_pairs, len(data["image"]))
182
+ figures = make_up_figure(pred, data, n_pairs)
183
+ figures |= make_latitude_figure(pred, data, n_pairs)
184
+ #figures |= make_camera_figure(pred, data, n_pairs)
185
+
186
+ {f.tight_layout() for f in figures.values()}
187
+
188
+ return figures
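A rough smoke-test sketch for the figure builders above. The tensors are random and only shaped like the batch keys the functions read ("image", "up_field", "latitude_field"); a real batch would come from the dataset and model.

import torch

from scripts.camera.visualization.visualize_batch import make_perspective_figures

data = {
    "image": torch.rand(2, 3, 160, 160),
    "up_field": torch.rand(2, 2, 160, 160),
    "latitude_field": torch.rand(2, 1, 160, 160),
}
pred = {
    "up_field": torch.rand(2, 2, 160, 160),
    "latitude_field": torch.rand(2, 1, 160, 160),
}
figures = make_perspective_figures(pred, data, n_pairs=1)
print(list(figures.keys()))   # ['up', 'latitude']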
scripts/camera/visualization/viz2d.py ADDED
@@ -0,0 +1,521 @@
1
+ """
2
+ 2D visualization primitives based on Matplotlib.
3
+ 1) Plot images with `plot_images`.
4
+ 2) Overlay fields with `plot_heatmaps`, `plot_vector_fields`, `plot_latitudes`, or `plot_confidences`.
5
+ 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
6
+ """
7
+ """Adapted from https://github.com/cvg/GeoCalib"""
8
+
9
+ import matplotlib.patheffects as path_effects
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+
14
+ from scripts.camera.geometry.perspective_fields import get_perspective_field
15
+ from scripts.camera.utils.conversions import rad2deg
16
+
17
+ # flake8: noqa
18
+ # mypy: ignore-errors
19
+
20
+
21
+ def cm_ranking(sc, ths=None):
22
+ if ths is None:
23
+ ths = [512, 1024, 2048, 4096]
24
+
25
+ ls = sc.shape[0]
26
+ colors = ["red", "yellow", "lime", "cyan", "blue"]
27
+ out = ["gray"] * ls
28
+ for i in range(ls):
29
+ for c, th in zip(colors[: len(ths) + 1], ths + [ls]):
30
+ if i < th:
31
+ out[i] = c
32
+ break
33
+ sid = np.argsort(sc, axis=0).flip(0)
34
+ return np.array(out)[sid]
35
+
36
+
37
+ def cm_RdBl(x):
38
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
39
+ x = np.clip(x, 0, 1)[..., None] * 2
40
+ c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]])
41
+ return np.clip(c, 0, 1)
42
+
43
+
44
+ def cm_RdGn(x):
45
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
46
+ x = np.clip(x, 0, 1)[..., None] * 2
47
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
48
+ return np.clip(c, 0, 1)
49
+
50
+
51
+ def cm_BlRdGn(x_):
52
+ """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
53
+ x = np.clip(x_, 0, 1)[..., None] * 2
54
+ c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
55
+
56
+ xn = -np.clip(x_, -1, 0)[..., None] * 2
57
+ cn = xn * np.array([[0, 1.0, 0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
58
+ return np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
59
+
60
+
61
+ def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
62
+ """Plot a list of images.
63
+
64
+ Args:
65
+ imgs (List[np.ndarray]): List of images to plot.
66
+ titles (List[str], optional): Titles. Defaults to None.
67
+ cmaps (str, optional): Colormaps. Defaults to "gray".
68
+ dpi (int, optional): Dots per inch. Defaults to 200.
69
+ pad (float, optional): Padding. Defaults to 0.5.
70
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
71
+
72
+ Returns:
73
+ plt.Figure: Figure of the images.
74
+ """
75
+ n = len(imgs)
76
+ if not isinstance(cmaps, (list, tuple)):
77
+ cmaps = [cmaps] * n
78
+
79
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
80
+ figsize = [sum(ratios) * 4.5, 4.5]
81
+ fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
82
+ if n == 1:
83
+ axs = [axs]
84
+ for i, (img, ax) in enumerate(zip(imgs, axs)):
85
+ ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
86
+ ax.set_axis_off()
87
+ if titles:
88
+ ax.set_title(titles[i])
89
+ fig.tight_layout(pad=pad)
90
+
91
+ return fig
92
+
93
+
94
+ def plot_image_grid(
95
+ imgs,
96
+ titles=None,
97
+ cmaps="gray",
98
+ dpi=100,
99
+ pad=0.5,
100
+ fig=None,
101
+ adaptive=True,
102
+ figs=3.0,
103
+ return_fig=False,
104
+ set_lim=False,
105
+ ) -> plt.Figure:
106
+ """Plot a grid of images.
107
+
108
+ Args:
109
+ imgs (List[np.ndarray]): List of images to plot.
110
+ titles (List[str], optional): Titles. Defaults to None.
111
+ cmaps (str, optional): Colormaps. Defaults to "gray".
112
+ dpi (int, optional): Dots per inch. Defaults to 100.
113
+ pad (float, optional): Padding. Defaults to 0.5.
114
+ fig (_type_, optional): Figure to plot on. Defaults to None.
115
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
116
+ figs (float, optional): Figure size. Defaults to 3.0.
117
+ return_fig (bool, optional): Whether to return the figure. Defaults to False.
118
+ set_lim (bool, optional): Whether to set the limits. Defaults to False.
119
+
120
+ Returns:
121
+ plt.Figure: Figure and axes or just axes.
122
+ """
123
+ nr, n = len(imgs), len(imgs[0])
124
+ if not isinstance(cmaps, (list, tuple)):
125
+ cmaps = [cmaps] * n
126
+
127
+ if adaptive:
128
+ ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H
129
+ else:
130
+ ratios = [4 / 3] * n
131
+
132
+ figsize = [sum(ratios) * figs, nr * figs]
133
+ if fig is None:
134
+ fig, axs = plt.subplots(
135
+ nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
136
+ )
137
+ else:
138
+ axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
139
+ fig.figure.set_size_inches(figsize)
140
+
141
+ if nr == 1 and n == 1:
142
+ axs = [[axs]]
143
+ elif n == 1:
144
+ axs = axs[:, None]
145
+ elif nr == 1:
146
+ axs = [axs]
147
+
148
+ for j in range(nr):
149
+ for i in range(n):
150
+ ax = axs[j][i]
151
+ ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
152
+ ax.set_axis_off()
153
+ if set_lim:
154
+ ax.set_xlim([0, imgs[j][i].shape[1]])
155
+ ax.set_ylim([imgs[j][i].shape[0], 0])
156
+ if titles:
157
+ ax.set_title(titles[j][i])
158
+ if isinstance(fig, plt.Figure):
159
+ fig.tight_layout(pad=pad)
160
+ return (fig, axs) if return_fig else axs
161
+
162
+
163
+ def add_text(
164
+ idx,
165
+ text,
166
+ pos=(0.01, 0.99),
167
+ fs=15,
168
+ color="w",
169
+ lcolor="k",
170
+ lwidth=4,
171
+ ha="left",
172
+ va="top",
173
+ axes=None,
174
+ **kwargs,
175
+ ):
176
+ """Add text to a plot.
177
+
178
+ Args:
179
+ idx (int): Index of the axes.
180
+ text (str): Text to add.
181
+ pos (tuple, optional): Text position. Defaults to (0.01, 0.99).
182
+ fs (int, optional): Font size. Defaults to 15.
183
+ color (str, optional): Text color. Defaults to "w".
184
+ lcolor (str, optional): Line color. Defaults to "k".
185
+ lwidth (int, optional): Line width. Defaults to 4.
186
+ ha (str, optional): Horizontal alignment. Defaults to "left".
187
+ va (str, optional): Vertical alignment. Defaults to "top".
188
+ axes (List[plt.Axes], optional): Axes to put text on. Defaults to None.
189
+
190
+ Returns:
191
+ plt.Text: Text object.
192
+ """
193
+ if axes is None:
194
+ axes = plt.gcf().axes
195
+
196
+ ax = axes[idx]
197
+
198
+ t = ax.text(
199
+ *pos,
200
+ text,
201
+ fontsize=fs,
202
+ ha=ha,
203
+ va=va,
204
+ color=color,
205
+ transform=ax.transAxes,
206
+ zorder=5,
207
+ **kwargs,
208
+ )
209
+ if lcolor is not None:
210
+ t.set_path_effects(
211
+ [
212
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
213
+ path_effects.Normal(),
214
+ ]
215
+ )
216
+ return t
217
+
218
+
219
+ def plot_heatmaps(
220
+ heatmaps,
221
+ vmin=-1e-6, # include negative zero
222
+ vmax=None,
223
+ cmap="Spectral",
224
+ a=0.5,
225
+ axes=None,
226
+ contours_every=None,
227
+ contour_style="solid",
228
+ colorbar=False,
229
+ ):
230
+ """Plot heatmaps with optional contours.
231
+
232
+ To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.
233
+
234
+ Args:
235
+ heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps.
236
+ vmin (float, optional): Min Value. Defaults to -1e-6.
237
+ vmax (float, optional): Max Value. Defaults to None.
238
+ cmap (str, optional): Colormap. Defaults to "Spectral".
239
+ a (float, optional): Alpha value. Defaults to 0.5.
240
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
241
+ contours_every (int, optional): If not none, will draw contours. Defaults to None.
242
+ contour_style (str, optional): Style of the contours. Defaults to "solid".
243
+ colorbar (bool, optional): Whether to show colorbar. Defaults to False.
244
+
245
+ Returns:
246
+ List[plt.Artist]: List of artists.
247
+ """
248
+ if axes is None:
249
+ axes = plt.gcf().axes
250
+ artists = []
251
+
252
+ for i in range(len(axes)):
253
+ a_ = a if isinstance(a, float) else a[i]
254
+
255
+ if isinstance(heatmaps[i], torch.Tensor):
256
+ heatmaps[i] = heatmaps[i].detach().cpu().numpy()
257
+
258
+ alpha = a_
259
+ # Plot the heatmap
260
+ art = axes[i].imshow(
261
+ heatmaps[i],
262
+ alpha=alpha,
263
+ vmin=vmin,
264
+ vmax=vmax,
265
+ cmap=cmap,
266
+ )
267
+ if colorbar:
268
+ cmax = vmax or np.percentile(heatmaps[i], 99)
269
+ art.set_clim(vmin, cmax)
270
+ cbar = plt.colorbar(art, ax=axes[i])
271
+ artists.append(cbar)
272
+
273
+ artists.append(art)
274
+
275
+ if contours_every is not None:
276
+ # Add contour lines to the heatmap
277
+ contour_data = np.arange(vmin, vmax + contours_every, contours_every)
278
+
279
+ # Get the colormap colors for contour lines
280
+ contour_colors = [
281
+ plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level))
282
+ for level in contour_data
283
+ ]
284
+ contours = axes[i].contour(
285
+ heatmaps[i],
286
+ levels=contour_data,
287
+ linewidths=2,
288
+ colors=contour_colors,
289
+ linestyles=contour_style,
290
+ )
291
+
292
+ contours.set_clim(vmin, vmax)
293
+
294
+ fmt = {
295
+ level: f"{label}°"
296
+ for level, label in zip(contour_data, contour_data.astype(int).astype(str))
297
+ }
298
+ t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
299
+
300
+ for label in t:
301
+ label.set_path_effects(
302
+ [
303
+ path_effects.Stroke(linewidth=1, foreground="k"),
304
+ path_effects.Normal(),
305
+ ]
306
+ )
307
+ artists.append(contours)
308
+
309
+ return artists
310
+
311
+
312
+ def plot_horizon_lines(
313
+ cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
314
+ ):
315
+ """Plot horizon lines on the perspective field.
316
+
317
+ Args:
318
+ cameras (List[Camera]): List of cameras.
319
+ gravities (List[Gravity]): Gravities.
320
+ line_colors (str, optional): Line Colors. Defaults to "orange".
321
+ lw (int, optional): Line width. Defaults to 2.
322
+ styles (str, optional): Line styles. Defaults to "solid".
323
+ alpha (float, optional): Alphas. Defaults to 1.0.
324
+ ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None.
325
+ """
326
+ if not isinstance(line_colors, list):
327
+ line_colors = [line_colors] * len(cameras)
328
+
329
+ if not isinstance(styles, list):
330
+ styles = [styles] * len(cameras)
331
+
332
+ fig = plt.gcf()
333
+ ax = fig.gca() if ax is None else ax
334
+
335
+ if isinstance(ax, plt.Axes):
336
+ ax = [ax] * len(cameras)
337
+
338
+ assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"
339
+
340
+ for i in range(len(cameras)):
341
+ _, lat = get_perspective_field(cameras[i], gravities[i])
342
+ # horizon line is zero level of the latitude field
343
+ lat = lat[0, 0].cpu().numpy()
344
+ contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i])
345
+ for contour_line in contours.collections:
346
+ contour_line.set_linestyle(styles[i])
347
+
348
+
349
+ def plot_vector_fields(
350
+ vector_fields,
351
+ cmap="lime",
352
+ subsample=15,
353
+ scale=None,
354
+ lw=None,
355
+ alphas=0.8,
356
+ axes=None,
357
+ ):
358
+ """Plot vector fields.
359
+
360
+ Args:
361
+ vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
362
+ cmap (str, optional): Color of the vectors. Defaults to "lime".
363
+ subsample (int, optional): Subsample the vector field. Defaults to 15.
364
+ scale (float, optional): Scale of the vectors. Defaults to None.
365
+ lw (float, optional): Line width of the vectors. Defaults to None.
366
+ alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
367
+ axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None.
368
+
369
+ Returns:
370
+ List[plt.Artist]: List of artists.
371
+ """
372
+ if axes is None:
373
+ axes = plt.gcf().axes
374
+
375
+ vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields]
376
+
377
+ artists = []
378
+
379
+ H, W = vector_fields[0].shape[-2:]
380
+ if scale is None:
381
+ scale = subsample / min(H, W)
382
+
383
+ if lw is None:
384
+ lw = 0.1 / subsample
385
+
386
+ if alphas is None:
387
+ alphas = np.ones_like(vector_fields[0][0])
388
+ alphas = np.stack([alphas] * len(vector_fields), 0)
389
+ elif isinstance(alphas, float):
390
+ alphas = np.ones_like(vector_fields[0][0]) * alphas
391
+ alphas = np.stack([alphas] * len(vector_fields), 0)
392
+ else:
393
+ alphas = np.array(alphas)
394
+
395
+ subsample = min(W, H) // subsample
396
+ offset_x = ((W % subsample) + subsample) // 2
397
+
398
+ samples_x = np.arange(offset_x, W, subsample)
399
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
400
+
401
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
402
+
403
+ for i in range(len(axes)):
404
+ # vector field of shape (2, H, W) with vectors of norm == 1
405
+ vector_field = vector_fields[i]
406
+
407
+ a = alphas[i][samples_y][:, samples_x]
408
+ x, y = vector_field[:, samples_y][:, :, samples_x]
409
+
410
+ c = cmap
411
+ if not isinstance(cmap, str):
412
+ c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)
413
+
414
+ s = scale * min(H, W)
415
+ arrows = axes[i].quiver(
416
+ x_grid,
417
+ y_grid,
418
+ x,
419
+ y,
420
+ scale=s,
421
+ scale_units="width" if H > W else "height",
422
+ units="width" if H > W else "height",
423
+ alpha=a,
424
+ color=c,
425
+ angles="xy",
426
+ antialiased=True,
427
+ width=lw,
428
+ headaxislength=3.5,
429
+ zorder=5,
430
+ )
431
+
432
+ artists.append(arrows)
433
+
434
+ return artists
435
+
436
+
437
+ def plot_latitudes(
438
+ latitude,
439
+ is_radians=True,
440
+ vmin=-90,
441
+ vmax=90,
442
+ cmap="seismic",
443
+ contours_every=15,
444
+ alpha=0.4,
445
+ axes=None,
446
+ **kwargs,
447
+ ):
448
+ """Plot latitudes.
449
+
450
+ Args:
451
+ latitude (List[torch.Tensor]): List of latitudes.
452
+ is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True.
453
+ vmin (int, optional): Min value to clip to. Defaults to -90.
454
+ vmax (int, optional): Max value to clip to. Defaults to 90.
455
+ cmap (str, optional): Colormap. Defaults to "seismic".
456
+ contours_every (int, optional): Contours every. Defaults to 15.
457
+ alpha (float, optional): Alpha value. Defaults to 0.4.
458
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
459
+
460
+ Returns:
461
+ List[plt.Artist]: List of artists.
462
+ """
463
+ if axes is None:
464
+ axes = plt.gcf().axes
465
+
466
+ assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
467
+ lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
468
+ return plot_heatmaps(
469
+ lat,
470
+ vmin=vmin,
471
+ vmax=vmax,
472
+ cmap=cmap,
473
+ a=alpha,
474
+ axes=axes,
475
+ contours_every=contours_every,
476
+ **kwargs,
477
+ )
478
+
479
+
480
+ def plot_confidences(
481
+ confidence,
482
+ as_log=True,
483
+ vmin=-4,
484
+ vmax=0,
485
+ cmap="turbo",
486
+ alpha=0.4,
487
+ axes=None,
488
+ **kwargs,
489
+ ):
490
+ """Plot confidences.
491
+
492
+ Args:
493
+ confidence (List[torch.Tensor]): Confidence maps.
494
+ as_log (bool, optional): Whether to plot in log scale. Defaults to True.
495
+ vmin (int, optional): Min value to clip to. Defaults to -4.
496
+ vmax (int, optional): Max value to clip to. Defaults to 0.
497
+ cmap (str, optional): Colormap. Defaults to "turbo".
498
+ alpha (float, optional): Alpha value. Defaults to 0.4.
499
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
500
+
501
+ Returns:
502
+ List[plt.Artist]: List of artists.
503
+ """
504
+ if axes is None:
505
+ axes = plt.gcf().axes
506
+
507
+ confidence = [c.cpu() if isinstance(c, torch.Tensor) else torch.tensor(c) for c in confidence]
508
+
509
+ assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"
510
+
511
+ if as_log:
512
+ confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]
513
+
514
+ # normalize to [0, 1]
515
+ confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence]
516
+ return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
517
+
518
+
519
+ def save_plot(path, **kw):
520
+ """Save the current figure without any white margin."""
521
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
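A small illustrative snippet for the generic plotting helpers above; the output path is a placeholder.

import numpy as np

from scripts.camera.visualization.viz2d import add_text, plot_images, save_plot

imgs = [np.random.rand(240, 320, 3), np.random.rand(240, 320)]
fig = plot_images(imgs, titles=["rgb", "gray"], cmaps=["gray", "viridis"])
add_text(0, "sample", fs=12)      # annotate the first axis of the current figure
save_plot("viz_preview.png")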
src/datasets/utils.py ADDED
@@ -0,0 +1,162 @@
1
+ import copy
2
+ import random
3
+ from xtuner.dataset.utils import get_bos_eos_token_ids
4
+ from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
5
+ import json
6
+
7
+ INPUT_IMAGE_TOKEN_INDEX = IMAGE_TOKEN_INDEX
8
+ OUTPUT_IMAGE_TOKEN_INDEX = -300
9
+ QUERY_TOKEN_INDEX = -400
10
+ QUERY_TOKEN = '<query>'
11
+
12
+ def crop2square(pil_img):
13
+ width, height = pil_img.width, pil_img.height
14
+
15
+ if width > height:
16
+ y0, y1 = 0, height
17
+ x0 = random.randint(0, width - height)
18
+ x1 = x0 + height
19
+ else:
20
+ x0, x1 = 0, width
21
+ y0 = random.randint(0, height - width)
22
+ y1 = y0 + width
23
+
24
+ return pil_img.crop(box=(x0, y0, x1, y1))
25
+
26
+ def load_jsonl(json_file):
27
+ with open(json_file) as f:
28
+ lines = f.readlines()
29
+ data = []
30
+ for line in lines:
31
+ data.append(json.loads(line))
32
+ return data
33
+
34
+
35
+ def encode_fn(example,
36
+ tokenizer,
37
+ max_length=None,
38
+ image_length=1,
39
+ query_length=1,
40
+ input_ids_with_output=True,
41
+ with_image_token=False,
42
+ prompt_template=None,
43
+ truncation='right'):
44
+ """Only support the following three scenarios:
45
+
46
+ 1. Incremental pretraining dataset.
47
+ example['conversation'] = [
48
+ {
49
+ 'input': '',
50
+ 'output': '### Human: Can you write xxx'
51
+ }
52
+ ]
53
+
54
+ 2. Single-turn conversation dataset.
55
+ example['conversation'] = [
56
+ {
57
+ 'input': 'Give three tips for staying healthy.',
58
+ 'output': '1.Eat a balanced diet xxx'
59
+ }
60
+ ]
61
+
62
+ 3. Multi-turn conversation dataset.
63
+ example['conversation'] = [
64
+ {
65
+ 'input': 'Give three tips for staying healthy.',
66
+ 'output': '1.Eat a balanced diet xxx'
67
+ },
68
+ {
69
+ 'input': 'Please expand on the second point.',
70
+ 'output': 'Here is an expanded explanation of the xxx'
71
+ }
72
+ ]
73
+ """
74
+
75
+ bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
76
+ is_multi_turn_conversation = len(example['conversation']) > 1
77
+ if is_multi_turn_conversation:
78
+ assert input_ids_with_output
79
+
80
+ input_ids, labels = [], []
81
+ next_needs_bos_token = True
82
+ for single_turn_conversation in example['conversation']:
83
+ input = single_turn_conversation['input']
84
+ if DEFAULT_IMAGE_TOKEN in input and with_image_token:
85
+ chunk_encode = [
86
+ tokenizer.encode(chunk, add_special_tokens=False)
87
+ for chunk in input.split(DEFAULT_IMAGE_TOKEN)
88
+ ]
89
+ assert len(chunk_encode) == 2
90
+ input_encode = []
91
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
92
+ input_encode.extend(cur_chunk_encode)
93
+ if idx != len(chunk_encode) - 1:
94
+ input_encode += [INPUT_IMAGE_TOKEN_INDEX] * image_length
95
+ else:
96
+ input_encode = tokenizer.encode(input, add_special_tokens=False)
97
+ if next_needs_bos_token:
98
+ input_ids += bos_token_id
99
+ labels += [IGNORE_INDEX] * len(bos_token_id)
100
+ input_ids += input_encode
101
+ labels += [IGNORE_INDEX] * len(input_encode)
102
+ if input_ids_with_output and 'output' in single_turn_conversation:
103
+ # Add output
104
+ output_with_loss = single_turn_conversation.get(
105
+ 'output_with_loss', True)
106
+ output = single_turn_conversation['output']
107
+ if DEFAULT_IMAGE_TOKEN in output and with_image_token:
108
+ chunk_encode = [
109
+ tokenizer.encode(chunk, add_special_tokens=False)
110
+ for chunk in output.split(DEFAULT_IMAGE_TOKEN)
111
+ ]
112
+ assert len(chunk_encode) == 2
113
+ output_encode = []
114
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
115
+ output_encode.extend(cur_chunk_encode)
116
+ if idx != len(chunk_encode) - 1:
117
+ output_encode += [OUTPUT_IMAGE_TOKEN_INDEX] * image_length
118
+ elif QUERY_TOKEN in output:
119
+ chunk_encode = [
120
+ tokenizer.encode(chunk, add_special_tokens=False)
121
+ for chunk in output.split(QUERY_TOKEN)
122
+ ]
123
+ assert len(chunk_encode) == 2
124
+ output_encode = []
125
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
126
+ output_encode.extend(cur_chunk_encode)
127
+ if idx != len(chunk_encode) - 1:
128
+ output_encode += [QUERY_TOKEN_INDEX] * query_length
129
+ else:
130
+ output_encode = tokenizer.encode(output, add_special_tokens=False)
131
+ input_ids += output_encode
132
+ if output_with_loss:
133
+ labels += copy.deepcopy(output_encode)
134
+ else:
135
+ labels += [IGNORE_INDEX] * len(output_encode)
136
+ # Add EOS_TOKEN (with loss)
137
+ if single_turn_conversation.get('need_eos_token', True):
138
+ next_needs_bos_token = True
139
+ input_ids += eos_token_id
140
+ if output_with_loss:
141
+ labels += copy.deepcopy(eos_token_id)
142
+ else:
143
+ labels += [IGNORE_INDEX] * len(eos_token_id)
144
+ else:
145
+ next_needs_bos_token = False
146
+ # Add SEP (without loss)
147
+ sep = single_turn_conversation.get('sep', '')
148
+ if sep != '':
149
+ sep_encode = tokenizer.encode(sep, add_special_tokens=False)
150
+ input_ids += sep_encode
151
+ labels += [IGNORE_INDEX] * len(sep_encode)
152
+
153
+ if max_length is not None and len(input_ids) > max_length:
154
+ if truncation == 'right':
155
+ input_ids = input_ids[:max_length]
156
+ labels = labels[:max_length]
157
+ elif truncation == 'left':
158
+ input_ids = input_ids[-max_length:]
159
+ labels = labels[-max_length:]
160
+ else:
161
+ assert truncation is None
162
+ return {'input_ids': input_ids, 'labels': labels}
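A hedged sketch of encode_fn on a single-turn sample. The tokenizer name and the image token length (256) are placeholders rather than values taken from this repository.

from transformers import AutoTokenizer

from src.datasets.utils import DEFAULT_IMAGE_TOKEN, encode_fn

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
example = {
    "conversation": [{
        "input": f"{DEFAULT_IMAGE_TOKEN}\nEstimate the camera parameters.",
        "output": "The roll, pitch and field of view are: 0.0, 5.0, 60.0.",
    }]
}
result = encode_fn(example, tokenizer, max_length=2048,
                   image_length=256, with_image_token=True)
print(len(result["input_ids"]), len(result["labels"]))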
src/models/connector/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .configuration_connector import ConnectorConfig
2
+ from .modeling_connector import ConnectorEncoder
src/models/connector/configuration_connector.py ADDED
@@ -0,0 +1,27 @@
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.utils import logging
3
+
4
+ logger = logging.get_logger(__name__)
5
+
6
+
7
+ class ConnectorConfig(PretrainedConfig):
8
+ def __init__(
9
+ self,
10
+ hidden_size=768,
11
+ intermediate_size=3072,
12
+ num_hidden_layers=12,
13
+ num_attention_heads=12,
14
+ hidden_act="gelu_pytorch_tanh",
15
+ layer_norm_eps=1e-6,
16
+ attention_dropout=0.0,
17
+ **kwargs,
18
+ ):
19
+ super().__init__(**kwargs)
20
+
21
+ self.hidden_size = hidden_size
22
+ self.intermediate_size = intermediate_size
23
+ self.num_hidden_layers = num_hidden_layers
24
+ self.num_attention_heads = num_attention_heads
25
+ self.attention_dropout = attention_dropout
26
+ self.layer_norm_eps = layer_norm_eps
27
+ self.hidden_act = hidden_act
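Illustrative only: the defaults above can be overridden per experiment, and the values below are arbitrary.

from src.models.connector.configuration_connector import ConnectorConfig

cfg = ConnectorConfig(hidden_size=1536, intermediate_size=6144,
                      num_hidden_layers=6, num_attention_heads=12)
print(cfg.hidden_size, cfg.hidden_act)   # 1536 gelu_pytorch_tanh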
src/models/connector/modeling_connector.py ADDED
@@ -0,0 +1,507 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Connector model."""
16
+
17
+ import math
18
+ import warnings
19
+ from typing import Any, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn.init import _calculate_fan_in_and_fan_out
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
28
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.utils import (
31
+ ModelOutput,
32
+ is_flash_attn_2_available,
33
+ is_flash_attn_greater_or_equal_2_10,
34
+ logging,
35
+ replace_return_docstrings,
36
+ torch_int,
37
+ )
38
+ from .configuration_connector import ConnectorConfig
39
+
40
+
41
+ if is_flash_attn_2_available():
42
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+
48
+ def init_weights(module):
49
+ """Initialize the weights"""
50
+ if isinstance(module, nn.Embedding):
51
+ default_flax_embed_init(module.weight)
52
+ elif isinstance(module, ConnectorAttention):
53
+ nn.init.xavier_uniform_(module.q_proj.weight)
54
+ nn.init.xavier_uniform_(module.k_proj.weight)
55
+ nn.init.xavier_uniform_(module.v_proj.weight)
56
+ nn.init.xavier_uniform_(module.out_proj.weight)
57
+ nn.init.zeros_(module.q_proj.bias)
58
+ nn.init.zeros_(module.k_proj.bias)
59
+ nn.init.zeros_(module.v_proj.bias)
60
+ nn.init.zeros_(module.out_proj.bias)
61
+ elif isinstance(module, ConnectorMLP):
62
+ nn.init.xavier_uniform_(module.fc1.weight)
63
+ nn.init.xavier_uniform_(module.fc2.weight)
64
+ nn.init.normal_(module.fc1.bias, std=1e-6)
65
+ nn.init.normal_(module.fc2.bias, std=1e-6)
66
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
67
+ lecun_normal_(module.weight)
68
+ if module.bias is not None:
69
+ nn.init.zeros_(module.bias)
70
+ elif isinstance(module, nn.LayerNorm):
71
+ module.bias.data.zero_()
72
+ module.weight.data.fill_(1.0)
73
+
74
+
75
+ def _trunc_normal_(tensor, mean, std, a, b):
76
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
77
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
78
+ def norm_cdf(x):
79
+ # Computes standard normal cumulative distribution function
80
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
81
+
82
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
83
+ warnings.warn(
84
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
85
+ "The distribution of values may be incorrect.",
86
+ stacklevel=2,
87
+ )
88
+
89
+ # Values are generated by using a truncated uniform distribution and
90
+ # then using the inverse CDF for the normal distribution.
91
+ # Get upper and lower cdf values
92
+ l = norm_cdf((a - mean) / std)
93
+ u = norm_cdf((b - mean) / std)
94
+
95
+ # Uniformly fill tensor with values from [l, u], then translate to
96
+ # [2l-1, 2u-1].
97
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
98
+
99
+ # Use inverse cdf transform for normal distribution to get truncated
100
+ # standard normal
101
+ tensor.erfinv_()
102
+
103
+ # Transform to proper mean, std
104
+ tensor.mul_(std * math.sqrt(2.0))
105
+ tensor.add_(mean)
106
+
107
+ # Clamp to ensure it's in the proper range
108
+ tensor.clamp_(min=a, max=b)
109
+
110
+
111
+ def trunc_normal_tf_(
112
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
113
+ ) -> torch.Tensor:
114
+ """Fills the input Tensor with values drawn from a truncated
115
+ normal distribution. The values are effectively drawn from the
116
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
117
+ with values outside :math:`[a, b]` redrawn until they are within
118
+ the bounds. The method used for generating the random values works
119
+ best when :math:`a \\leq \text{mean} \\leq b`.
120
+
121
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
122
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
123
+ and the result is subsequently scaled and shifted by the mean and std args.
124
+
125
+ Args:
126
+ tensor: an n-dimensional `torch.Tensor`
127
+ mean: the mean of the normal distribution
128
+ std: the standard deviation of the normal distribution
129
+ a: the minimum cutoff value
130
+ b: the maximum cutoff value
131
+ """
132
+ with torch.no_grad():
133
+ _trunc_normal_(tensor, 0, 1.0, a, b)
134
+ tensor.mul_(std).add_(mean)
135
+
136
+
137
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
138
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
139
+ if mode == "fan_in":
140
+ denom = fan_in
141
+ elif mode == "fan_out":
142
+ denom = fan_out
143
+ elif mode == "fan_avg":
144
+ denom = (fan_in + fan_out) / 2
145
+
146
+ variance = scale / denom
147
+
148
+ if distribution == "truncated_normal":
149
+ # constant is stddev of standard normal truncated to (-2, 2)
150
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
151
+ elif distribution == "normal":
152
+ with torch.no_grad():
153
+ tensor.normal_(std=math.sqrt(variance))
154
+ elif distribution == "uniform":
155
+ bound = math.sqrt(3 * variance)
156
+ with torch.no_grad():
157
+ tensor.uniform_(-bound, bound)
158
+ else:
159
+ raise ValueError(f"invalid distribution {distribution}")
160
+
161
+
162
+ def lecun_normal_(tensor):
163
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
164
+
165
+
166
+ def default_flax_embed_init(tensor):
167
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
168
+
169
+
170
+ class ConnectorAttention(nn.Module):
171
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
172
+
173
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.config = config
177
+ self.embed_dim = config.hidden_size
178
+ self.num_heads = config.num_attention_heads
179
+ self.head_dim = self.embed_dim // self.num_heads
180
+ if self.head_dim * self.num_heads != self.embed_dim:
181
+ raise ValueError(
182
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
183
+ f" {self.num_heads})."
184
+ )
185
+ self.scale = self.head_dim**-0.5
186
+ self.dropout = config.attention_dropout
187
+
188
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
189
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
190
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
191
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
192
+
193
+ def forward(
194
+ self,
195
+ hidden_states: torch.Tensor,
196
+ attention_mask: Optional[torch.Tensor] = None,
197
+ output_attentions: Optional[bool] = False,
198
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
199
+ """Input shape: Batch x Time x Channel"""
200
+
201
+ batch_size, q_len, _ = hidden_states.size()
202
+
203
+ query_states = self.q_proj(hidden_states)
204
+ key_states = self.k_proj(hidden_states)
205
+ value_states = self.v_proj(hidden_states)
206
+
207
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
208
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
209
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
210
+
211
+ k_v_seq_len = key_states.shape[-2]
212
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
213
+
214
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
215
+ raise ValueError(
216
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
217
+ f" {attn_weights.size()}"
218
+ )
219
+
220
+ if attention_mask is not None:
221
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
222
+ raise ValueError(
223
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
224
+ )
225
+ attn_weights = attn_weights + attention_mask
226
+
227
+ # upcast attention to fp32
228
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
229
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
230
+ attn_output = torch.matmul(attn_weights, value_states)
231
+
232
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
233
+ raise ValueError(
234
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
235
+ f" {attn_output.size()}"
236
+ )
237
+
238
+ attn_output = attn_output.transpose(1, 2).contiguous()
239
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
240
+
241
+ attn_output = self.out_proj(attn_output)
242
+
243
+ return attn_output, attn_weights
244
+
245
+
246
+ class ConnectorFlashAttention2(ConnectorAttention):
247
+ """
248
+ ConnectorAttention flash attention module. This module inherits from `ConnectorAttention` as the weights of the module stays
249
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
250
+ flash attention and deal with padding tokens in case the input contains any of them.
251
+ """
252
+
253
+ is_causal = False
254
+
255
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
256
+ def __init__(self, *args, **kwargs):
257
+ super().__init__(*args, **kwargs)
258
+
259
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
260
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
261
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
262
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
263
+
264
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
265
+ def forward(
266
+ self,
267
+ hidden_states: torch.Tensor,
268
+ attention_mask: Optional[torch.LongTensor] = None,
269
+ output_attentions: bool = False,
270
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
271
+ output_attentions = False
272
+
273
+ batch_size, q_len, _ = hidden_states.size()
274
+
275
+ query_states = self.q_proj(hidden_states)
276
+ key_states = self.k_proj(hidden_states)
277
+ value_states = self.v_proj(hidden_states)
278
+
279
+ # Flash attention requires the input to have the shape
280
+ # batch_size x seq_length x head_dim x hidden_dim
281
+ # therefore we just need to keep the original shape
282
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
283
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
284
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
285
+
286
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
287
+ # to be able to avoid many of these transpose/reshape/view.
288
+ query_states = query_states.transpose(1, 2)
289
+ key_states = key_states.transpose(1, 2)
290
+ value_states = value_states.transpose(1, 2)
291
+
292
+ dropout_rate = self.dropout if self.training else 0.0
293
+
294
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
295
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
296
+ # cast them back to the correct dtype just to be sure everything works as expected.
297
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
298
+ # in fp32.
299
+
300
+ input_dtype = query_states.dtype
301
+ if input_dtype == torch.float32:
302
+ if torch.is_autocast_enabled():
303
+ target_dtype = torch.get_autocast_gpu_dtype()
304
+ # Handle the case where the model is quantized
305
+ elif hasattr(self.config, "_pre_quantization_dtype"):
306
+ target_dtype = self.config._pre_quantization_dtype
307
+ else:
308
+ target_dtype = self.q_proj.weight.dtype
309
+
310
+ logger.warning_once(
311
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
312
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
313
+ f" {target_dtype}."
314
+ )
315
+
316
+ query_states = query_states.to(target_dtype)
317
+ key_states = key_states.to(target_dtype)
318
+ value_states = value_states.to(target_dtype)
319
+
320
+ attn_output = _flash_attention_forward(
321
+ query_states,
322
+ key_states,
323
+ value_states,
324
+ attention_mask,
325
+ q_len,
326
+ dropout=dropout_rate,
327
+ is_causal=self.is_causal,
328
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
329
+ )
330
+
331
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
332
+ attn_output = self.out_proj(attn_output)
333
+
334
+ if not output_attentions:
335
+ attn_weights = None
336
+
337
+ return attn_output, attn_weights
338
+
339
+
340
+ class ConnectorSdpaAttention(ConnectorAttention):
341
+ """
342
+ Connector attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
343
+ `ConnectorAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to the
344
+ SDPA API.
345
+ """
346
+
347
+ is_causal = False
348
+
349
+ # Adapted from ConnectorAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
350
+ def forward(
351
+ self,
352
+ hidden_states: torch.Tensor,
353
+ attention_mask: Optional[torch.Tensor] = None,
354
+ output_attentions: Optional[bool] = False,
355
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
356
+ if output_attentions:
357
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
358
+ logger.warning_once(
359
+ "ConnectorModel is using ConnectorSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
360
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
361
+ )
362
+ return super().forward(
363
+ hidden_states=hidden_states,
364
+ attention_mask=attention_mask,
365
+ output_attentions=output_attentions,
366
+ )
367
+
368
+ batch_size, q_len, _ = hidden_states.size()
369
+
370
+ query_states = self.q_proj(hidden_states)
371
+ key_states = self.k_proj(hidden_states)
372
+ value_states = self.v_proj(hidden_states)
373
+
374
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
375
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
376
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
377
+
378
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
379
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
380
+ if query_states.device.type == "cuda" and attention_mask is not None:
381
+ query_states = query_states.contiguous()
382
+ key_states = key_states.contiguous()
383
+ value_states = value_states.contiguous()
384
+
385
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
386
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
387
+ is_causal = True if self.is_causal and q_len > 1 else False
388
+
389
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
390
+ query_states,
391
+ key_states,
392
+ value_states,
393
+ attn_mask=attention_mask,
394
+ dropout_p=self.dropout if self.training else 0.0,
395
+ is_causal=is_causal,
396
+ )
397
+
398
+ attn_output = attn_output.transpose(1, 2).contiguous()
399
+ attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
400
+
401
+ attn_output = self.out_proj(attn_output)
402
+
403
+ return attn_output, None
404
+
405
+
406
+ CONNECTOR_ATTENTION_CLASSES = {
407
+ "eager": ConnectorAttention,
408
+ "flash_attention_2": ConnectorFlashAttention2,
409
+ "sdpa": ConnectorSdpaAttention,
410
+ }
411
+
412
+
413
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Connector
414
+ class ConnectorMLP(nn.Module):
415
+ def __init__(self, config):
416
+ super().__init__()
417
+ self.config = config
418
+ self.activation_fn = ACT2FN[config.hidden_act]
419
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
420
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
421
+
422
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
423
+ hidden_states = self.fc1(hidden_states)
424
+ hidden_states = self.activation_fn(hidden_states)
425
+ hidden_states = self.fc2(hidden_states)
426
+ return hidden_states
427
+
428
+
429
+ class ConnectorEncoderLayer(nn.Module):
430
+ def __init__(self, config: ConnectorConfig):
431
+ super().__init__()
432
+ self.embed_dim = config.hidden_size
433
+ self.self_attn = CONNECTOR_ATTENTION_CLASSES[config._attn_implementation](config=config)
434
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
435
+ self.mlp = ConnectorMLP(config)
436
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
437
+
438
+ # Ignore copy
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: torch.Tensor,
443
+ output_attentions: Optional[bool] = False,
444
+ ) -> Tuple[torch.FloatTensor]:
445
+ """
446
+ Args:
447
+ hidden_states (`torch.FloatTensor`):
448
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
449
+ attention_mask (`torch.FloatTensor`):
450
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
451
+ output_attentions (`bool`, *optional*, defaults to `False`):
452
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
453
+ returned tensors for more detail.
454
+ """
455
+ residual = hidden_states
456
+
457
+ hidden_states = self.layer_norm1(hidden_states)
458
+ hidden_states, attn_weights = self.self_attn(
459
+ hidden_states=hidden_states,
460
+ attention_mask=attention_mask,
461
+ output_attentions=output_attentions,
462
+ )
463
+ hidden_states = residual + hidden_states
464
+
465
+ residual = hidden_states
466
+ hidden_states = self.layer_norm2(hidden_states)
467
+ hidden_states = self.mlp(hidden_states)
468
+ hidden_states = residual + hidden_states
469
+
470
+ outputs = (hidden_states,)
471
+
472
+ if output_attentions:
473
+ outputs += (attn_weights,)
474
+
475
+ return outputs
476
+
477
+
478
+ # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Connector
479
+ class ConnectorEncoder(nn.Module):
480
+ def __init__(self, config: ConnectorConfig):
481
+ super().__init__()
482
+ self.config = config
483
+ self.layers = nn.ModuleList([ConnectorEncoderLayer(config) for _ in range(config.num_hidden_layers)])
484
+ self.gradient_checkpointing = False
485
+ self.apply(init_weights)
486
+
487
+ def forward(self, inputs_embeds):
488
+ hidden_states = inputs_embeds
489
+ for encoder_layer in self.layers:
490
+ if self.gradient_checkpointing and self.training:
491
+ layer_outputs = torch.utils.checkpoint.checkpoint(
492
+ encoder_layer.__call__,
493
+ hidden_states,
494
+ None,
495
+ False,
496
+ use_reentrant=False
497
+ )
498
+ else:
499
+ layer_outputs = encoder_layer(
500
+ hidden_states,
501
+ None,
502
+ output_attentions=False,
503
+ )
504
+
505
+ hidden_states = layer_outputs[0]
506
+
507
+ return hidden_states
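
Note (not part of the uploaded files): a minimal usage sketch of the connector encoder defined above. The ConnectorConfig field values below are illustrative assumptions; the actual defaults live in the repo's configs and configuration_connector.py.

import torch
from src.models.connector import ConnectorConfig, ConnectorEncoder

# Hypothetical CLIP-style config values; the real ones come from the repo's configs.
config = ConnectorConfig(
    hidden_size=1024,
    intermediate_size=4096,
    num_hidden_layers=2,
    num_attention_heads=16,
)
config._attn_implementation = "sdpa"   # selects ConnectorSdpaAttention via CONNECTOR_ATTENTION_CLASSES

encoder = ConnectorEncoder(config).eval()
queries = torch.randn(2, 64, config.hidden_size)   # (batch, num_queries, hidden_size)
with torch.no_grad():
    out = encoder(queries)
print(out.shape)   # torch.Size([2, 64, 1024]) -- the encoder preserves the input shape
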
src/models/connector/modeling_qwen2.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import Qwen2PreTrainedModel, Qwen2Config
4
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer
5
+
6
+
7
+ class Qwen2Connector(Qwen2PreTrainedModel):
8
+ def __init__(self, config: Qwen2Config):
9
+ super().__init__(config)
10
+ self.layers = nn.ModuleList(
11
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
12
+ )
13
+
14
+ for layer in self.layers:
15
+ layer.self_attn.is_causal = False
16
+
17
+ self._attn_implementation = config._attn_implementation
18
+ assert self._attn_implementation == 'flash_attention_2'
19
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
20
+
21
+ self.gradient_checkpointing = False
22
+ # Initialize weights and apply final processing
23
+ self.post_init()
24
+
25
+ def forward(self, inputs_embeds):
26
+ position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
27
+ position_ids = position_ids.expand(inputs_embeds.shape[0], -1)
28
+ hidden_states = inputs_embeds
29
+
30
+ for encoder_layer in self.layers:
31
+ if self.gradient_checkpointing and self.training:
32
+ layer_outputs = self._gradient_checkpointing_func(
33
+ encoder_layer.__call__,
34
+ hidden_states,
35
+ None,
36
+ position_ids,
37
+ use_reentrant=False
38
+ )
39
+ else:
40
+ layer_outputs = encoder_layer(
41
+ hidden_states,
42
+ attention_mask=None,
43
+ position_ids=position_ids,
44
+ )
45
+
46
+ hidden_states = layer_outputs[0]
47
+
48
+ hidden_states = self.norm(hidden_states)
49
+
50
+ return hidden_states
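
Note (not part of the uploaded files): a minimal usage sketch of Qwen2Connector. Because __init__ asserts flash_attention_2, this only runs on a CUDA device with flash-attn installed and half-precision inputs; the config sizes below are illustrative.

import torch
from transformers import Qwen2Config
from src.models.connector.modeling_qwen2 import Qwen2Connector

config = Qwen2Config(
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=2,
    num_attention_heads=16,
    num_key_value_heads=16,
    attn_implementation="flash_attention_2",   # required by the assert in __init__
)
connector = Qwen2Connector(config).to("cuda", torch.bfloat16).eval()
x = torch.randn(2, 64, config.hidden_size, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    y = connector(x)   # (2, 64, 1024): non-causal Qwen2 decoder layers + final RMSNorm
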
src/models/puffin/model.py ADDED
@@ -0,0 +1,790 @@
1
+ import random
2
+ import torch
3
+ import math
4
+ from tqdm import tqdm
5
+ from einops import rearrange
6
+ from copy import deepcopy
7
+ from six.moves import zip
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.autograd.function import Function
11
+ from torch.nn.utils.rnn import pad_sequence
12
+ from mmengine.logging import print_log
13
+ from mmengine.model import BaseModel
14
+ from xtuner.utils import IGNORE_INDEX
15
+ from xtuner.registry import BUILDER
16
+ from xtuner.model.utils import guess_load_checkpoint
17
+ from xtuner.dataset.map_fns.template_map_fn import template_map_fn
18
+ from transformers.cache_utils import DynamicCache
19
+ from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
20
+
21
+ from src.models.connector import ConnectorConfig, ConnectorEncoder
22
+ from src.models.stable_diffusion3.pipeline_stable_diffusion_3_dynamic import StableDiffusion3Pipeline
23
+ from src.datasets.utils import encode_fn, QUERY_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, INPUT_IMAGE_TOKEN_INDEX
24
+
25
+ class _ScaleGradient(Function):
26
+ @staticmethod
27
+ def forward(ctx, input, scale):
28
+ ctx.scale = scale
29
+ return input
30
+
31
+ @staticmethod
32
+ def backward(ctx, grad_output):
33
+ return grad_output * ctx.scale, None
34
+
35
+ def build_mlp(hidden_size, projector_dim, z_dim):
36
+ return nn.Sequential(
37
+ nn.Linear(hidden_size, projector_dim),
38
+ nn.SiLU(),
39
+ nn.Linear(projector_dim, z_dim),)
40
+
41
+ def pad_an_image_tensor(image, pad_value=0):
42
+ h, w = image.shape[-2:]
43
+ if h > w:
44
+ pad_left = (h - w) // 2
45
+ pad_right = h - w - pad_left
46
+ p2d = (pad_left, pad_right, 0, 0)
47
+ else:
48
+ pad_top = (w - h) // 2
49
+ pad_bottom = w - h - pad_top
50
+ p2d = (0, 0, pad_top, pad_bottom)
51
+
52
+ image = F.pad(image, p2d, "constant", pad_value)
53
+
54
+ return image
55
+
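
Note (not part of the uploaded files): pad_an_image_tensor pads the shorter spatial side so the output is square; importing it as below assumes the repo's dependencies are installed. For a 3x384x512 input, 64 rows of zeros are added on top and bottom:

import torch
from src.models.puffin.model import pad_an_image_tensor

img = torch.zeros(3, 384, 512)          # H < W
padded = pad_an_image_tensor(img)       # pads 64 px on top and 64 px on bottom
assert padded.shape == (3, 512, 512)
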
56
+ class Qwen2p5RadioStableDiffusion3HFDynamic(BaseModel):
57
+ def __init__(self,
58
+ llm,
59
+ tokenizer,
60
+ prompt_template,
61
+ visual_encoder,
62
+ vae,
63
+ transformer,
64
+ train_scheduler,
65
+ test_scheduler,
66
+ connector_1,
67
+ connector_2,
68
+ num_queries=64,
69
+ freeze_transformer=True,
70
+ max_length=256,
71
+ freeze_visual_encoder=True,
72
+ freeze_llm=True,
73
+ visual_encoder_grad_scale=0.1,
74
+ fold_size=2,
75
+ unconditional=0.1,
76
+ unconditional_cross_view=0.1,
77
+ pretrained_pth=None,
78
+ use_activation_checkpointing=False,
79
+ *args, **kwargs):
80
+ super().__init__()
81
+
82
+ # basic settings
83
+ self.max_length = max_length
84
+ self.fold_size = fold_size
85
+ self.prompt_template = prompt_template
86
+ self.unconditional = unconditional
87
+ self.unconditional_cross_view = unconditional_cross_view
88
+
89
+ # networks building
90
+ # understanding branch
91
+ self.visual_encoder = BUILDER.build(visual_encoder)
92
+ self.llm = BUILDER.build(llm)
93
+ self.tokenizer = BUILDER.build(tokenizer)
94
+ self.projector = build_mlp(hidden_size=self.visual_encoder.model.embed_dim*fold_size**2,
95
+ projector_dim=self.llm.config.hidden_size,
96
+ z_dim=self.llm.config.hidden_size)
97
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(prompt_template['IMG_CONTEXT_TOKEN'])
98
+
99
+ # generation branch
100
+ self.vae = BUILDER.build(vae)
101
+ self.vae.requires_grad_(False)
102
+ self.transformer = BUILDER.build(transformer)
103
+ self.num_queries = num_queries
104
+ self.connector_1 = ConnectorEncoder(ConnectorConfig(**connector_1))
105
+ self.connector_2 = ConnectorEncoder(ConnectorConfig(**connector_2))
106
+
107
+ self.llm2connector_1 = nn.Linear(self.llm.config.hidden_size, self.connector_1.config.hidden_size)
108
+ self.llm2connector_2 = nn.Linear(self.llm.config.hidden_size, self.connector_2.config.hidden_size)
109
+ self.projector_1 = nn.Linear(self.connector_1.config.hidden_size, self.transformer.config.pooled_projection_dim)
110
+ self.projector_2 = nn.Linear(self.connector_2.config.hidden_size, self.transformer.config.joint_attention_dim)
111
+ nn.init.zeros_(self.projector_1.weight)
112
+ nn.init.zeros_(self.projector_2.weight)
113
+ nn.init.zeros_(self.projector_1.bias)
114
+ nn.init.zeros_(self.projector_2.bias)
115
+
116
+ self.meta_queries = nn.Parameter(
117
+ torch.zeros(num_queries, self.llm.config.hidden_size))
118
+ nn.init.normal_(self.meta_queries, std=1 / math.sqrt(self.llm.config.hidden_size))
119
+
120
+ # networks and training initialization
121
+ if freeze_visual_encoder:
122
+ self.visual_encoder.requires_grad_(False)
123
+ self.freeze_visual_encoder = freeze_visual_encoder
124
+ if freeze_llm:
125
+ self.llm.requires_grad_(False)
126
+ self.freeze_llm = freeze_llm
127
+ if freeze_transformer:
128
+ self.transformer.requires_grad_(False)
129
+ self.freeze_transformer = freeze_transformer
130
+
131
+ self.visual_encoder_grad_scale = visual_encoder_grad_scale
132
+ self.train_scheduler = BUILDER.build(train_scheduler)
133
+ self.test_scheduler = BUILDER.build(test_scheduler)
134
+
135
+ self.use_activation_checkpointing = use_activation_checkpointing
136
+ if use_activation_checkpointing:
137
+ self.llm.enable_input_require_grads()
138
+ self.gradient_checkpointing_enable()
139
+
140
+ if pretrained_pth is not None:
141
+ pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
142
+ info = self.load_state_dict(pretrained_state_dict, strict=False)
143
+ print_log(f'Loaded pretrained weights from {pretrained_pth}')
144
+
145
+ @property
146
+ def device(self):
147
+ return self.llm.device
148
+
149
+ @property
150
+ def dtype(self):
151
+ return self.llm.dtype
152
+
153
+ def gradient_checkpointing_enable(self):
154
+ self.activation_checkpointing_enable()
155
+
156
+ def activation_checkpointing_enable(self):
157
+ self.llm.gradient_checkpointing_enable()
158
+ self.transformer.enable_gradient_checkpointing()
159
+ self.connector_1.gradient_checkpointing = True
160
+ self.connector_2.gradient_checkpointing = True
161
+
162
+ def gradient_checkpointing_disable(self):
163
+ self.activation_checkpointing_disable()
164
+
165
+ def activation_checkpointing_disable(self):
166
+ self.llm.gradient_checkpointing_disable()
167
+ self.transformer.disable_gradient_checkpointing()
168
+ self.connector_1.gradient_checkpointing = False
169
+ self.connector_2.gradient_checkpointing = False
170
+
171
+ def forward(self, data, data_samples=None, mode='loss'):
172
+ if mode == 'loss':
173
+ return self.compute_loss(data_dict=data)
174
+ else:
175
+ raise NotImplementedError
176
+
177
+ def extract_visual_features(self, pixel_values):
178
+ pixel_values = (pixel_values + 1.0) / 2 # [0, 1]
179
+ height, width = pixel_values.shape[-2:]
180
+ summary, features = self.visual_encoder(pixel_values)
181
+ patch_size = int((height * width // features.shape[1]) ** 0.5)
182
+ height, width = height // (patch_size * self.fold_size), width // (patch_size * self.fold_size)
183
+ features = rearrange(features, 'b (h p w q) d -> b (h w) (p q d)',
184
+ h=height, w=width, p=self.fold_size, q=self.fold_size)
185
+
186
+ return features
187
+
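
Note (not part of the uploaded files): a shape sketch of the 2x2 token folding performed by extract_visual_features. Assuming a 512x512 input and a patch size of 16, the vision encoder returns 32x32 = 1024 patch tokens; with fold_size=2 each 2x2 patch neighborhood is concatenated channel-wise, so the projector sees 16x16 = 256 tokens of 4x the width (the embed_dim of 768 below is a dummy value):

import torch
from einops import rearrange

features = torch.randn(1, 1024, 768)    # (batch, 32*32 patch tokens, embed_dim)
folded = rearrange(features, 'b (h p w q) d -> b (h w) (p q d)', h=16, w=16, p=2, q=2)
assert folded.shape == (1, 256, 768 * 4)
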
188
+ def llm2dit(self, x):
189
+ x_1 = self.connector_1(self.llm2connector_1(x))
190
+ x_1 = self.projector_1(x_1.mean(1))
191
+ x_2 = self.connector_2(self.llm2connector_2(x))
192
+ x_2 = self.projector_2(x_2)
193
+
194
+ return x_1, x_2
195
+
196
+
197
+ @torch.no_grad()
198
+ def prepare_gen_prompts(self, texts, data_type='text2image', num_refs=None, ref_lens=None, gen_type='GENERATION_CROSS'):
199
+ if data_type == 'text2image':
200
+ prompts = [self.prompt_template['GENERATION'].format(input=text) for text in texts]
201
+ prompts = [self.prompt_template['INSTRUCTION'].format(input=text) for text in prompts]
202
+
203
+ elif data_type == 'image2image':
204
+ assert num_refs is not None and ref_lens is not None, "num_refs and ref_lens are required for image2image"
205
+ prompts = []
206
+ cnt = 0
207
+ for text, num_ref in zip(texts, num_refs):
208
+ image_tokens = ''
209
+ for _ in range(num_ref):
210
+ image_tokens += (
211
+ self.prompt_template['IMG_START_TOKEN'] +
212
+ self.prompt_template['IMG_CONTEXT_TOKEN'] * ref_lens[cnt] +
213
+ self.prompt_template['IMG_END_TOKEN']
214
+ )
215
+ cnt += 1
216
+
217
+ text = self.prompt_template[gen_type].format(input=text)
218
+ prompt = self.prompt_template['INSTRUCTION'].format(input=f'{image_tokens}\n{text}')
219
+ prompts.append(prompt)
220
+ else:
221
+ raise ValueError(f"Unsupported data_type: {data_type}")
222
+
223
+ return self.tokenizer(
224
+ prompts, add_special_tokens=True, return_tensors='pt', padding=True, padding_side='left').to(self.device)
225
+
226
+
227
+ @torch.no_grad()
228
+ def prepare_und_prompts(self, conversations, data_type='image2text', image_lengths=None, input_ids_with_output=True):
229
+ input_ids, labels, input_lengths = [], [], []
230
+
231
+ if data_type == 'image2text':
232
+ assert image_lengths is not None, "`image_lengths` must be provided for image2text"
233
+ if isinstance(image_lengths, int):
234
+ image_lengths = [image_lengths] * len(conversations)
235
+ elif data_type == 'text2text':
236
+ image_lengths = [None] * len(conversations)
237
+ else:
238
+ raise ValueError(f"Unsupported data_type: {data_type}")
239
+
240
+ for conv, image_len in zip(conversations, image_lengths):
241
+ data_dict = template_map_fn(example=dict(conversation=deepcopy(conv)), template=self.prompt_template)
242
+ data_dict.update(encode_fn(data_dict,
243
+ tokenizer=self.tokenizer,
244
+ max_length=None,
245
+ input_ids_with_output=input_ids_with_output,
246
+ with_image_token=(data_type == 'image2text'),
247
+ image_length=image_len,
248
+ prompt_template=self.prompt_template))
249
+
250
+ input_ids.append(torch.tensor(data_dict['input_ids'], dtype=torch.long, device=self.device))
251
+ labels.append(torch.tensor(data_dict['labels'], dtype=torch.long, device=self.device))
252
+ input_lengths.append(len(data_dict['input_ids']))
253
+
254
+ input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0, padding_side='left')
255
+ labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX, padding_side='left')
256
+
257
+ attention_mask = torch.zeros_like(input_ids).bool()
258
+ for i in range(len(input_ids)):
259
+ attention_mask[i, -input_lengths[i]:] = True
260
+
261
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
262
+ position_ids[position_ids < 0] = 0
263
+
264
+ return dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels, position_ids=position_ids)
265
+
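
Note (not part of the uploaded files): because sequences are left-padded, the cumulative-sum trick above gives all padded positions index 0 and real tokens consecutive positions, e.g.:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]], dtype=torch.bool)
position_ids = torch.cumsum(attention_mask, dim=1) - 1
position_ids[position_ids < 0] = 0
# position_ids -> tensor([[0, 0, 0, 1, 2],
#                         [0, 1, 2, 3, 4]])
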
266
+ def train(self, mode=True):
267
+ super().train(mode=mode)
268
+ self.vae.train(mode=False)
269
+ if not mode:
270
+ self.gradient_checkpointing_disable()
271
+
272
+ return self
273
+
274
+ @torch.no_grad()
275
+ def pixels_to_latents(self, x):
276
+ z = self.vae.encode(x).latent_dist.sample()
277
+ z = (z - self.vae.config.shift_factor) * self.vae.config.scaling_factor
278
+ return z
279
+
280
+ @torch.no_grad()
281
+ def latents_to_pixels(self, z):
282
+ z = (z / self.vae.config.scaling_factor) + self.vae.config.shift_factor
283
+ x_rec = self.vae.decode(z).sample
284
+ return x_rec
285
+
286
+ def prepare_forward_input(self,
287
+ query_embeds,
288
+ input_ids=None,
289
+ image_embeds=None,
290
+ attention_mask=None,
291
+ past_key_values=None,
292
+ append_queries=True):
293
+ b, l, _ = query_embeds.shape
294
+ assert l > 0
295
+ attention_mask = attention_mask.to(device=self.device, dtype=torch.bool)
296
+ assert l == self.num_queries
297
+
298
+ if append_queries:
299
+ input_ids = torch.cat([
300
+ input_ids, input_ids.new_full(size=(b, l), fill_value=QUERY_TOKEN_INDEX)], dim=1)
301
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, l)], dim=1)
302
+
303
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
304
+ position_ids[position_ids < 0] = 0
305
+
306
+ # prepare context
307
+ if past_key_values is not None:
308
+ inputs_embeds = query_embeds
309
+ position_ids = position_ids[..., -l:]
310
+ else:
311
+ inputs_embeds = torch.zeros(*input_ids.shape, self.llm.config.hidden_size,
312
+ device=self.device, dtype=self.dtype)
313
+ if image_embeds is not None:
314
+ inputs_embeds[input_ids == self.image_token_id] = \
315
+ image_embeds.contiguous().view(-1, self.llm.config.hidden_size)
316
+
317
+ inputs_embeds[input_ids == QUERY_TOKEN_INDEX] = \
318
+ query_embeds.contiguous().view(-1, self.llm.config.hidden_size)
319
+
320
+ text_places = torch.logical_and(input_ids != self.image_token_id, input_ids != QUERY_TOKEN_INDEX)
321
+
322
+ inputs_embeds[text_places] = self.llm.get_input_embeddings()(input_ids[text_places])
323
+
324
+ inputs = dict(inputs_embeds=inputs_embeds,
325
+ attention_mask=attention_mask,
326
+ position_ids=position_ids,
327
+ past_key_values=past_key_values)
328
+
329
+ return inputs
330
+
331
+ def get_sigmas(self, timesteps, n_dim=4):
332
+ sigmas = self.train_scheduler.sigmas.to(device=self.device, dtype=self.dtype)
333
+ schedule_timesteps = self.train_scheduler.timesteps.to(self.device)
334
+ timesteps = timesteps.to(self.device)
335
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
336
+
337
+ sigma = sigmas[step_indices].flatten()
338
+ while len(sigma.shape) < n_dim:
339
+ sigma = sigma.unsqueeze(-1)
340
+ return sigma
341
+
342
+ def diff_loss(self, model_input, pooled_prompt_embeds, prompt_embeds, cond_input=None):
343
+ noise = [torch.randn_like(x) for x in model_input]
344
+ bsz = len(model_input)
345
+
346
+ u = compute_density_for_timestep_sampling(
347
+ weighting_scheme='none',
348
+ batch_size=bsz,
349
+ logit_mean=0.0,
350
+ logit_std=1.0,
351
+ )
352
+ indices = (u * self.train_scheduler.config.num_train_timesteps).long()
353
+ timesteps = self.train_scheduler.timesteps[indices].to(device=self.device)
354
+
355
+ # Add noise according to flow matching
356
+ sigmas = self.get_sigmas(timesteps, n_dim=model_input[0].ndim + 1)
357
+ noisy_model_input = [(1.0 - x) * y + x * z for x, y, z in zip(sigmas, model_input, noise)]
358
+
359
+ # Predict the noise residual
360
+ model_pred = self.transformer(
361
+ hidden_states=noisy_model_input,
362
+ cond_hidden_states=cond_input,
363
+ encoder_hidden_states=prompt_embeds,
364
+ pooled_projections=pooled_prompt_embeds,
365
+ timestep=timesteps,
366
+ return_dict=False,
367
+ )[0]
368
+
369
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme='none', sigmas=sigmas)
370
+
371
+ # flow matching loss
372
+ target = [x - y for x, y in zip(noise, model_input)]
373
+
374
+ loss = [(x.float() * (y.float() - z.float()) ** 2).mean() for x, y, z in zip(weighting, model_pred, target)]
375
+ loss = sum(loss) / len(loss)
376
+
377
+ return loss
378
+
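
Note (not part of the uploaded files): diff_loss follows the SD3 rectified-flow recipe: sample a noise level sigma, interpolate the clean latent toward Gaussian noise, and regress the velocity (noise - latent) with uniform ('none') weighting. A single-sample sketch with a stand-in prediction:

import torch

latent = torch.randn(16, 64, 64)                 # clean VAE latent
noise = torch.randn_like(latent)
sigma = torch.rand(())                           # noise level in [0, 1)
noisy_latent = (1.0 - sigma) * latent + sigma * noise
velocity_target = noise - latent                 # what the DiT is trained to predict
dit_pred = torch.zeros_like(velocity_target)     # stand-in for the transformer output
loss = ((dit_pred - velocity_target) ** 2).mean()   # weighting_scheme='none' -> plain MSE
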
379
+ '''text-to-image generation (single-view)'''
380
+ def text2image_loss(self, data_dict):
381
+ pixel_values = [p.to(dtype=self.dtype, device=self.device) for p in data_dict['pixel_values']]
382
+ image_latents = [self.pixels_to_latents(p[None])[0] for p in pixel_values]
383
+
384
+ b = len(image_latents)
385
+
386
+ texts = ['' if random.uniform(0, 1) < self.unconditional else text
387
+ for text in data_dict['texts']]
388
+
389
+ text_inputs = self.prepare_gen_prompts(texts)
390
+ hidden_states = self.meta_queries[None].expand(b, self.num_queries, -1)
391
+
392
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
393
+
394
+ max_length = self.max_length + self.num_queries
395
+ inputs_embeds = inputs['inputs_embeds'][:, -max_length:]
396
+ attention_mask = inputs['attention_mask'][:, -max_length:]
397
+ position_ids = inputs['position_ids'][:, -max_length:]
398
+
399
+ output = self.llm.model(
400
+ inputs_embeds=inputs_embeds,
401
+ attention_mask=attention_mask,
402
+ position_ids=position_ids,
403
+ return_dict=True)
404
+
405
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
406
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
407
+
408
+ loss_diff = self.diff_loss(model_input=image_latents,
409
+ pooled_prompt_embeds=pooled_prompt_embeds,
410
+ prompt_embeds=prompt_embeds)
411
+
412
+ return loss_diff
413
+
414
+ '''text-to-image generation (single-view) with camera map'''
415
+ def cam2image_loss(self, data_dict):
416
+ pixel_values = [p.to(dtype=self.dtype, device=self.device) for p in data_dict['pixel_values']]
417
+ image_latents = [self.pixels_to_latents(p[None])[0] for p in pixel_values]
418
+ b = len(image_latents)
419
+ # camera map as condition for the diffusion model
420
+ cam_values = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
421
+ for ref_images in data_dict['cam_values']]
422
+ cam_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
423
+ for ref_images in cam_values]
424
+
425
+ texts = ['' if random.uniform(0, 1) < self.unconditional else text
426
+ for text in data_dict['texts']]
427
+
428
+ text_inputs = self.prepare_gen_prompts(texts)
429
+ hidden_states = self.meta_queries[None].expand(b, self.num_queries, -1)
430
+
431
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
432
+
433
+ max_length = self.max_length + self.num_queries
434
+ inputs_embeds = inputs['inputs_embeds'][:, -max_length:]
435
+ attention_mask = inputs['attention_mask'][:, -max_length:]
436
+ position_ids = inputs['position_ids'][:, -max_length:]
437
+
438
+ output = self.llm.model(
439
+ inputs_embeds=inputs_embeds,
440
+ attention_mask=attention_mask,
441
+ position_ids=position_ids,
442
+ return_dict=True)
443
+
444
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
445
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
446
+
447
+ loss_diff = self.diff_loss(model_input=image_latents,
448
+ pooled_prompt_embeds=pooled_prompt_embeds,
449
+ prompt_embeds=prompt_embeds,
450
+ cond_input=cam_latents)
451
+
452
+ return loss_diff
453
+
454
+ '''image-to-image (cross-view) generation'''
455
+ def image2image_loss(self, data_dict):
456
+ # condition for the diffusion model (concat the camera map and the initial view)
457
+ cam_values = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
458
+ for ref_images in data_dict['cam_values']]
459
+ cam_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
460
+ for ref_images in cam_values]
461
+ pixel_values_init = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
462
+ for ref_images in data_dict['pixel_values_init']]
463
+ image_latents_init = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
464
+ for ref_images in pixel_values_init]
465
+ mix_latents = [cam + img for cam, img in zip(cam_latents, image_latents_init)]
466
+
467
+ # condition embedding for querying the LLM (only initial view)
468
+ num_refs = [len(ref_images) for ref_images in pixel_values_init]
469
+ image_embeds = self.extract_visual_features(
470
+ torch.stack([pad_an_image_tensor(img) for ref_images in pixel_values_init for img in ref_images]))
471
+
472
+ image_embeds = self.projector(image_embeds)
473
+ ref_lens = [len(x) for x in image_embeds]
474
+ text_inputs = self.prepare_gen_prompts(data_dict['texts'], data_type='image2image',
475
+ num_refs=num_refs, ref_lens=ref_lens)
476
+
477
+ # input for the diffusion model
478
+ pixel_values = [p.to(dtype=self.dtype, device=self.device) for p in data_dict['pixel_values']]
479
+ image_latents = [self.pixels_to_latents(p[None])[0] for p in pixel_values]
480
+
481
+ # querying the LLM
482
+ b = len(image_latents)
483
+ hidden_states = self.meta_queries[None].expand(b, self.num_queries, -1)
484
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, image_embeds=image_embeds, **text_inputs)
485
+
486
+ max_length = self.max_length + max(num_refs) * max(ref_lens) + self.num_queries
487
+ inputs_embeds = inputs['inputs_embeds'][:, -max_length:]
488
+ attention_mask = inputs['attention_mask'][:, -max_length:]
489
+ position_ids = inputs['position_ids'][:, -max_length:]
490
+
491
+ output = self.llm.model(inputs_embeds=inputs_embeds,
492
+ attention_mask=attention_mask,
493
+ position_ids=position_ids,
494
+ return_dict=True)
495
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
496
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
497
+ loss_diff = self.diff_loss(model_input=image_latents,
498
+ pooled_prompt_embeds=pooled_prompt_embeds,
499
+ prompt_embeds=prompt_embeds,
500
+ cond_input=mix_latents)
501
+
502
+ return loss_diff
503
+
504
+ '''image-to-text(camera) understanding, mixed base, thinking, and instruction tuning'''
505
+ def image2text_loss(self, data_dict):
506
+ pixel_values = [pad_an_image_tensor(img) for img in data_dict['pixel_values']]
507
+ pixel_values = torch.stack(pixel_values).to(dtype=self.dtype, device=self.device)
508
+ image_embeds = self.extract_visual_features(pixel_values)
509
+
510
+ if not self.freeze_visual_encoder:
511
+ image_embeds = _ScaleGradient.apply(image_embeds, self.visual_encoder_grad_scale)
512
+
513
+ image_embeds = self.projector(image_embeds)
514
+ text_inputs = self.prepare_und_prompts(conversations=data_dict['conversations'],
515
+ data_type='image2text',
516
+ image_lengths=image_embeds.shape[1])
517
+
518
+ labels, input_ids, attention_mask, position_ids = \
519
+ text_inputs['labels'], text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
520
+
521
+
522
+ inputs_embeds = torch.zeros(*input_ids.shape, self.llm.config.hidden_size,
523
+ device=self.device, dtype=self.dtype)
524
+ inputs_embeds[input_ids == INPUT_IMAGE_TOKEN_INDEX] = image_embeds.flatten(0, 1)
525
+ inputs_embeds[input_ids != INPUT_IMAGE_TOKEN_INDEX] = \
526
+ self.llm.get_input_embeddings()(input_ids[input_ids != INPUT_IMAGE_TOKEN_INDEX])
527
+
528
+ max_length = self.max_length + image_embeds.shape[1]
529
+ inputs_embeds = inputs_embeds[:, -max_length:]
530
+ attention_mask = attention_mask[:, -max_length:]
531
+ position_ids = position_ids[:, -max_length:]
532
+ labels = labels[:, -max_length:]
533
+
534
+ output = self.llm.model(inputs_embeds=inputs_embeds,
535
+ attention_mask=attention_mask,
536
+ position_ids=position_ids,
537
+ return_dict=True)
538
+
539
+ hidden_states = output.last_hidden_state[:, :-1]
540
+ labels = labels[:, 1:]
541
+ hidden_states = hidden_states[labels >= 0]
542
+ labels = labels[labels >= 0]
543
+
544
+ logits = self.llm.get_output_embeddings()(hidden_states)
545
+ loss = F.cross_entropy(input=logits, target=labels)
546
+
547
+ return loss
548
+
549
+ '''text-to-text understanding, offering the enhanced caption for the generation'''
550
+ def text2text_loss(self, data_dict):
551
+ text_inputs = self.prepare_und_prompts(conversations=data_dict['conversations'], data_type='text2text')
552
+ labels, input_ids, attention_mask, position_ids = \
553
+ text_inputs['labels'], text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
554
+
555
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
556
+ max_length = self.max_length
557
+ inputs_embeds = inputs_embeds[:, -max_length:]
558
+ attention_mask = attention_mask[:, -max_length:]
559
+ position_ids = position_ids[:, -max_length:]
560
+ labels = labels[:, -max_length:]
561
+
562
+ output = self.llm.model(inputs_embeds=inputs_embeds,
563
+ attention_mask=attention_mask,
564
+ position_ids=position_ids,
565
+ return_dict=True)
566
+
567
+ hidden_states = output.last_hidden_state[:, :-1]
568
+ labels = labels[:, 1:]
569
+ hidden_states = hidden_states[labels >= 0]
570
+ labels = labels[labels >= 0]
571
+
572
+ logits = self.llm.get_output_embeddings()(hidden_states)
573
+ loss = F.cross_entropy(input=logits, target=labels)
574
+
575
+ return loss
576
+
577
+ '''distribute different losses for each task'''
578
+ def compute_loss(self, data_dict):
579
+ loss_fn_map = {
580
+ 'text2image': self.text2image_loss,
581
+ 'cam2image': self.cam2image_loss,
582
+ 'image2text': self.image2text_loss,
583
+ 'text2text': self.text2text_loss,
584
+ 'image2image': self.image2image_loss,
585
+ 'image2text_cross_view': self.image2text_loss,
586
+ }
587
+
588
+ losses = {}
589
+ for data_type, batch_data in data_dict.items():
590
+ if data_type not in loss_fn_map:
591
+ raise ValueError(f"Unsupported data_type: {data_type}")
592
+ loss_fn = loss_fn_map[data_type]
593
+ loss = loss_fn(batch_data)
594
+ losses[f'loss_{data_type}'] = loss
595
+ return losses
596
+
597
+ @torch.no_grad()
598
+ def generate(self,
599
+ prompt,
600
+ cfg_prompt,
601
+ cam_values=None,
602
+ pixel_values_init=None,
603
+ cfg_scale=4.5,
604
+ num_steps=50,
605
+ generator=None,
606
+ height=512,
607
+ width=512,
608
+ max_new_tokens=512,
609
+ reasoning=False,
610
+ prompt_reasoning=None,
611
+ progress_bar=True):
612
+ assert len(prompt) == len(cfg_prompt)
613
+ b = len(prompt)
614
+ output_reasoning = [''] * b
615
+
616
+ if reasoning:
617
+ # enrich the prompt if required reasoning generation
618
+ assert prompt_reasoning is not None, \
619
+ "prompt_reasoning must be provided for reasoning generation"
620
+ if isinstance(prompt_reasoning, str):
621
+ prompt_reasoning = [prompt_reasoning]
622
+ if isinstance(prompt, str):
623
+ prompt = [prompt]
624
+
625
+ conversations = [[{'input': f"{p1} {p2}",}]
626
+ for p1, p2 in zip(prompt_reasoning, prompt)]
627
+
628
+ text_inputs = self.prepare_und_prompts(
629
+ conversations=conversations, data_type="text2text", input_ids_with_output=False)
630
+ input_ids, attention_mask, position_ids = \
631
+ text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
632
+
633
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
634
+ past_key_values = DynamicCache.from_legacy_cache()
635
+
636
+ output_ids = []
637
+ for _ in tqdm(range(max_new_tokens), disable=not progress_bar):
638
+ output = self.llm.model(
639
+ inputs_embeds=inputs_embeds,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ past_key_values=past_key_values,
643
+ use_cache=True,
644
+ return_dict=True)
645
+ logits = self.llm.get_output_embeddings()(output.last_hidden_state[:, -1:])
646
+ input_ids = torch.argmax(logits, dim=-1) # b 1
647
+ if len(output_ids) > 0:
648
+ input_ids = torch.where(output_ids[-1] == self.tokenizer.eos_token_id,
649
+ output_ids[-1], input_ids)
650
+ output_ids.append(input_ids)
651
+
652
+ if (input_ids == self.tokenizer.eos_token_id).all():
653
+ break
654
+
655
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
656
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones(b, 1)], dim=1)
657
+ position_ids = torch.max(position_ids, dim=1, keepdim=True).values + 1
658
+ past_key_values = output.past_key_values
659
+
660
+ output_ids = torch.cat(output_ids, dim=1)
661
+ output_reasoning = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
662
+ prompt = [f"{p} {o}" for p, o in zip(prompt, output_reasoning)]
663
+
664
+ if cam_values is not None:
665
+ # for the generation with the camera map
666
+ cam_values = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
667
+ for ref_images in cam_values]
668
+ cond_latents = [[self.pixels_to_latents(img[None])[0] for img in ref_images]
669
+ for ref_images in cam_values]
670
+ text_inputs = self.prepare_gen_prompts(prompt + cfg_prompt)
671
+ if pixel_values_init is not None:
672
+ # for the generation with the camera map and initial view (cross-view generation)
673
+ num_refs = [len(ref_images) for ref_images in pixel_values_init]
674
+ pixel_values_init = [[img.to(dtype=self.dtype, device=self.device) for img in ref_images]
675
+ for ref_images in pixel_values_init]
676
+ image_embeds = self.extract_visual_features(
677
+ torch.stack([pad_an_image_tensor(img) for ref_images in pixel_values_init for img in ref_images]))
678
+ image_embeds = self.projector(image_embeds)
679
+
680
+ ref_lens = [len(x) for x in image_embeds]
681
+ text_inputs = self.prepare_gen_prompts(prompt + cfg_prompt, data_type='image2image', num_refs=num_refs*2, ref_lens=ref_lens*2)
682
+ text_inputs.update(image_embeds=torch.cat([image_embeds]*2))
683
+
684
+ cond_latents_init = [[self.pixels_to_latents(img[None])[0] for img in ref_imgs]
685
+ for ref_imgs in pixel_values_init]
686
+ cond_latents = [cam + img for cam, img in zip(cond_latents, cond_latents_init)]
687
+
688
+ cond_latents = cond_latents * 2
689
+ else:
690
+ # for the text2image generation
691
+ text_inputs = self.prepare_gen_prompts(prompt + cfg_prompt)
692
+ cond_latents = None
693
+
694
+ hidden_states = self.meta_queries[None].expand(2*b, self.num_queries, -1)
695
+ inputs = self.prepare_forward_input(query_embeds=hidden_states, **text_inputs)
696
+
697
+ output = self.llm.model(**inputs, return_dict=True)
698
+ hidden_states = output.last_hidden_state[:, -self.num_queries:]
699
+ pooled_prompt_embeds, prompt_embeds = self.llm2dit(hidden_states)
700
+
701
+ pipeline = StableDiffusion3Pipeline(
702
+ transformer=self.transformer,
703
+ scheduler=self.test_scheduler,
704
+ vae=self.vae,
705
+ text_encoder=None,
706
+ tokenizer=None,
707
+ text_encoder_2=None,
708
+ tokenizer_2=None,
709
+ text_encoder_3=None,
710
+ tokenizer_3=None,
711
+ )
712
+
713
+ pipeline.set_progress_bar_config(disable=not progress_bar)
714
+
715
+ samples = pipeline(
716
+ height=height,
717
+ width=width,
718
+ guidance_scale=cfg_scale,
719
+ num_inference_steps=num_steps,
720
+ prompt_embeds=prompt_embeds[:b],
721
+ pooled_prompt_embeds=pooled_prompt_embeds[:b],
722
+ negative_prompt_embeds=prompt_embeds[b:],
723
+ negative_pooled_prompt_embeds=pooled_prompt_embeds[b:],
724
+ generator=generator,
725
+ output_type='latent',
726
+ cond_latents=cond_latents
727
+ ).images.to(self.dtype)
728
+
729
+ return self.latents_to_pixels(samples), output_reasoning
730
+
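
Note (not part of the uploaded files): generate() runs the conditional prompts and the cfg (negative) prompts through the LLM as one batch of size 2*b, then splits the resulting query embeddings for classifier-free guidance in the SD3 pipeline:

import torch

b, num_queries, dim = 2, 64, 4096                 # illustrative sizes
prompt_embeds = torch.randn(2 * b, num_queries, dim)   # rows 0..b-1: prompts, rows b..2b-1: cfg prompts
cond, uncond = prompt_embeds[:b], prompt_embeds[b:]
# handed to the pipeline as prompt_embeds=cond and negative_prompt_embeds=uncond with guidance_scale=cfg_scale
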
731
+ @torch.no_grad()
732
+ def understand(self, prompt, pixel_values, max_new_tokens=512, progress_bar=True):
733
+ if isinstance(prompt, str):
734
+ prompt = [prompt]
735
+ if isinstance(pixel_values, torch.Tensor):
736
+ pixel_values = [pixel_values]
737
+
738
+ bsz = len(prompt)
739
+ assert len(pixel_values) == bsz
740
+
741
+ pixel_values = [pad_an_image_tensor(img) for img in pixel_values]
742
+ pixel_values = torch.stack(pixel_values).to(dtype=self.dtype, device=self.device)
743
+ image_embeds = self.extract_visual_features(pixel_values)
744
+ image_embeds = self.projector(image_embeds)
745
+
746
+ conversations = [[{'input': f"{DEFAULT_IMAGE_TOKEN}\n{p}",}] for p in prompt]
747
+
748
+ text_inputs = self.prepare_und_prompts(conversations=conversations, image_lengths=image_embeds.shape[1],
749
+ input_ids_with_output=False)
750
+
751
+ input_ids, attention_mask, position_ids = \
752
+ text_inputs['input_ids'], text_inputs['attention_mask'], text_inputs['position_ids']
753
+
754
+ inputs_embeds = torch.zeros(*input_ids.shape, self.llm.config.hidden_size,
755
+ device=self.device, dtype=self.dtype)
756
+ inputs_embeds[input_ids == INPUT_IMAGE_TOKEN_INDEX] = image_embeds.flatten(0, 1)
757
+ inputs_embeds[input_ids != INPUT_IMAGE_TOKEN_INDEX] = \
758
+ self.llm.get_input_embeddings()(input_ids[input_ids != INPUT_IMAGE_TOKEN_INDEX])
759
+
760
+ past_key_values = DynamicCache.from_legacy_cache()
761
+
762
+ output_ids = []
763
+
764
+ for _ in tqdm(range(max_new_tokens), disable=not progress_bar):
765
+ output = self.llm.model(
766
+ inputs_embeds=inputs_embeds,
767
+ attention_mask=attention_mask,
768
+ position_ids=position_ids,
769
+ past_key_values=past_key_values,
770
+ use_cache=True,
771
+ return_dict=True)
772
+ logits = self.llm.get_output_embeddings()(output.last_hidden_state[:, -1:])
773
+ input_ids = torch.argmax(logits, dim=-1) # b 1
774
+ if len(output_ids) > 0:
775
+ input_ids = torch.where(output_ids[-1] == self.tokenizer.eos_token_id,
776
+ output_ids[-1], input_ids)
777
+ output_ids.append(input_ids)
778
+
779
+ if (input_ids == self.tokenizer.eos_token_id).all():
780
+ break
781
+
782
+ inputs_embeds = self.llm.get_input_embeddings()(input_ids)
783
+ attention_mask = torch.cat([attention_mask, attention_mask.new_ones(bsz, 1)], dim=1)
784
+ position_ids = torch.max(position_ids, dim=1, keepdim=True).values + 1
785
+ past_key_values = output.past_key_values
786
+
787
+ output_ids = torch.cat(output_ids, dim=1)
788
+ output_text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
789
+
790
+ return output_text
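
Note (not part of the uploaded files): both decode loops above (the reasoning branch of generate() and understand()) decode greedily with a KV cache and freeze finished sequences on their EOS token, so batch_decode(skip_special_tokens=True) drops the trailing repeats. The freezing step in isolation, with illustrative token ids:

import torch

eos_id = 151645                              # illustrative eos token id
prev_ids = torch.tensor([[eos_id], [1042]])  # sequence 0 already finished last step
new_ids = torch.tensor([[7265], [3311]])     # fresh argmax predictions for this step
new_ids = torch.where(prev_ids == eos_id, prev_ids, new_ids)
# new_ids -> tensor([[151645], [3311]]); the loop stops once every row equals eos_id
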