JiantaoLin committed on
Commit 98bebfc · 0 Parent(s):

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. app.py +319 -0
  2. extension/put_here.txt +0 -0
  3. image_to_mesh.py +437 -0
  4. models/ISOMER/__init__.py +0 -0
  5. models/ISOMER/data/__init__.py +0 -0
  6. models/ISOMER/data/utils.py +87 -0
  7. models/ISOMER/mesh_reconstruction/__init__.py +0 -0
  8. models/ISOMER/mesh_reconstruction/func.py +227 -0
  9. models/ISOMER/mesh_reconstruction/opt.py +191 -0
  10. models/ISOMER/mesh_reconstruction/recon.py +58 -0
  11. models/ISOMER/mesh_reconstruction/refine.py +86 -0
  12. models/ISOMER/mesh_reconstruction/remesh.py +363 -0
  13. models/ISOMER/mesh_reconstruction/render.py +142 -0
  14. models/ISOMER/model/__init__.py +0 -0
  15. models/ISOMER/model/inference_pipeline.py +189 -0
  16. models/ISOMER/projection_func.py +86 -0
  17. models/ISOMER/reconstruction_func.py +88 -0
  18. models/ISOMER/scripts/__init__.py +0 -0
  19. models/ISOMER/scripts/all_typing.py +42 -0
  20. models/ISOMER/scripts/fast_geo.py +86 -0
  21. models/ISOMER/scripts/load_onnx.py +48 -0
  22. models/ISOMER/scripts/mesh_init.py +142 -0
  23. models/ISOMER/scripts/normal_to_height_map.py +205 -0
  24. models/ISOMER/scripts/proj_commands.py +69 -0
  25. models/ISOMER/scripts/project_mesh.py +401 -0
  26. models/ISOMER/scripts/refine_lr_to_sr.py +60 -0
  27. models/ISOMER/scripts/sd_model_zoo.py +131 -0
  28. models/ISOMER/scripts/upsampler.py +260 -0
  29. models/ISOMER/scripts/utils.py +611 -0
  30. models/lrm/config/PRM_inference.yaml +22 -0
  31. models/lrm/models/__init__.py +0 -0
  32. models/lrm/models/decoder/__init__.py +0 -0
  33. models/lrm/models/decoder/transformer.py +123 -0
  34. models/lrm/models/encoder/__init__.py +0 -0
  35. models/lrm/models/encoder/dino.py +550 -0
  36. models/lrm/models/encoder/dino_wrapper.py +80 -0
  37. models/lrm/models/geometry/__init__.py +7 -0
  38. models/lrm/models/geometry/camera/__init__.py +16 -0
  39. models/lrm/models/geometry/camera/perspective_camera.py +35 -0
  40. models/lrm/models/geometry/render/__init__.py +8 -0
  41. models/lrm/models/geometry/render/neural_render.py +293 -0
  42. models/lrm/models/geometry/render/renderutils/__init__.py +11 -0
  43. models/lrm/models/geometry/render/renderutils/bsdf.py +151 -0
  44. models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu +710 -0
  45. models/lrm/models/geometry/render/renderutils/c_src/bsdf.h +84 -0
  46. models/lrm/models/geometry/render/renderutils/c_src/common.cpp +74 -0
  47. models/lrm/models/geometry/render/renderutils/c_src/common.h +41 -0
  48. models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu +350 -0
  49. models/lrm/models/geometry/render/renderutils/c_src/cubemap.h +38 -0
  50. models/lrm/models/geometry/render/renderutils/c_src/loss.cu +210 -0
app.py ADDED
@@ -0,0 +1,319 @@
import gradio as gr
import os
import subprocess
import shlex
subprocess.run(
    shlex.split(
        "pip install ./extension/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
    )
)

subprocess.run(
    shlex.split(
        "pip install ./extension/renderutils_plugin-1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
    )
)
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from diffusers import FluxPipeline
from models.lrm.utils.camera_util import get_flux_input_cameras
from models.lrm.utils.infer_util import save_video
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.train_util import instantiate_from_config
from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
import os
from einops import rearrange
from omegaconf import OmegaConf
import spaces
import torch
import numpy as np
import trimesh
import torchvision
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2
from diffusers import HeunDiscreteScheduler
from diffusers import FluxPipeline
from pytorch_lightning import seed_everything
import os
from huggingface_hub import hf_hub_download


from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames

device = "cuda"
resolution = 512
save_dir = "./outputs"
normal_transfer = NormalTransfer()
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
isomer_radius = 4.5
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)

# model initialization and loading
# flux
flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
flux_pipe.load_lora_weights(flux_lora_ckpt_path)

flux_pipe.to(device=device, dtype=torch.bfloat16)
generator = torch.Generator(device=device).manual_seed(10)

# lrm
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
model_config = config.model_config
infer_config = config.infer_config
model = instantiate_from_config(model_config)
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
model.init_flexicubes_geometry(device, fovy=50.0)
model = model.eval()

@spaces.GPU
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
    images = image.unsqueeze(0).to(device)
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
    # breakpoint()
    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        mesh_path_idx = os.path.join(save_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            **infer_config,
        )
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        render_size = 512
        if if_save_video:
            video_path_idx = os.path.join(save_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)

            all_mv, all_mvp, all_campos = get_render_cameras_video(
                batch_size=1,
                M=240,
                radius=4.5,
                elevation=(90, 60.0),
                is_flexicubes=True,
                fov=30
            )

            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                chunk_size=20,
                is_flexicubes=True,
            )
            normals = (torch.nn.functional.normalize(normals) + 1) / 2
            normals = normals * alphas + (1 - alphas)
            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)

            save_video(
                all_frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")

    return vertices, faces


def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
    global_normal[..., 0] *= -1
    global_normal = (global_normal + 1) / 2
    global_normal = global_normal.permute(0, 3, 1, 2)
    return global_normal

# generate multi-view images
@spaces.GPU
def generate_multi_view_images(prompt, seed):
    generator = torch.manual_seed(seed)
    with torch.no_grad():
        images = flux_pipe(
            prompt=prompt,
            num_inference_steps=30,
            guidance_scale=3.5,
            num_images_per_prompt=1,
            width=resolution * 4,
            height=resolution * 2,
            output_type='np',
            generator=generator
        ).images
    return images

# reconstruct the 3D model
@spaces.GPU
def reconstruct_3d_model(images, prompt):
    rgb_normal_grid = images
    save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
    os.makedirs(save_dir_path, exist_ok=True)

    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
    rgb_multi_view = images[:4, :3, :, :]
    normal_multi_view = images[4:, :3, :, :]
    multi_view_mask = get_background(normal_multi_view)
    rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)  # composite onto white with the mask (was `rgb_multi_view * rgb_multi_view`; matches image_to_mesh.py)
    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
    vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
    # local normal to global normal

    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    global_normal = global_normal.permute(0, 2, 3, 1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
    vertices = torch.from_numpy(vertices).to(device)
    faces = torch.from_numpy(faces).to(device)
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]

    # global_normal: B,H,W,3
    # multi_view_mask: B,H,W
    # rgb_multi_view: B,H,W,3

    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=isomer_geo_weights,
        fov=30,
        radius=isomer_radius,
        camera_angles_azi=isomer_azimuths,
        camera_angles_ele=isomer_elevations,
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=0,
        stage2_steps=50,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )


    save_glb_addr = projection(
        meshes,
        masks=multi_view_mask,
        images=rgb_multi_view,
        azimuths=isomer_azimuths,
        elevations=isomer_elevations,
        weights=isomer_color_weights,
        fov=30,
        radius=isomer_radius,
        save_dir=f"{save_dir_path}/ISOMER/",
    )

    return save_glb_addr

# Gradio interface function
def gradio_pipeline(prompt, seed):
    # generate multi-view images
    rgb_normal_grid = generate_multi_view_images(prompt, seed)
    image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))  # take the single image out of the batch before converting to PIL

    # 3d reconstruction


    # reconstruct the 3D model and return the glb path
    save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)

    return image_preview, save_glb_addr

# Gradio Blocks app
with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        # left-hand input column
        with gr.Column():
            with gr.Row():
                prompt_input = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="Describe your 3D model...",
                    lines=2,
                    elem_id="prompt_input"
                )

            with gr.Row():
                sample_seed = gr.Number(value=42, label="Seed Value", precision=0)

            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")

            with gr.Row(variant="panel"):
                gr.Markdown("Examples:")
                gr.Examples(
                    examples=[
                        ["a castle on a hill"],
                        ["an owl wearing a hat"],
                        ["a futuristic car"]
                    ],
                    inputs=[prompt_input],
                    label="Prompt Examples"
                )

        # right-hand output column
        with gr.Column():
            with gr.Row():
                rgb_normal_grid_image = gr.Image(
                    label="RGB Normal Grid",
                    type="pil",
                    interactive=False
                )

            with gr.Row():
                with gr.Tab("GLB"):
                    output_glb_model = gr.Model3D(
                        label="Generated 3D Model (GLB Format)",
                        interactive=False
                    )
                    gr.Markdown("Download the model for proper visualization.")

    # wiring
    submit.click(
        fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
        outputs=[rgb_normal_grid_image, output_glb_model]
    )

# launch the app
demo.queue(max_size=10)
demo.launch(server_port=1211)
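The grid bookkeeping in reconstruct_3d_model is easy to sanity-check in isolation. A minimal sketch with a dummy tensor (not part of the commit): the Flux LoRA emits a 2×4 grid of 512-pixel tiles, the top row holding the four RGB views and the bottom row the four normal views.

import torch
from einops import rearrange

grid = torch.zeros(3, 2 * 512, 4 * 512)                     # (c, n*h, m*w): the 1024x2048 RGB/normal grid
views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)
assert views.shape == (8, 3, 512, 512)                       # views 0-3: RGB, views 4-7: normals
rgb_views, normal_views = views[:4], views[4:]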
extension/put_here.txt ADDED
File without changes
image_to_mesh.py ADDED
@@ -0,0 +1,437 @@
import os
from einops import rearrange
from omegaconf import OmegaConf
import torch
import numpy as np
import trimesh
import torchvision
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2
from transformers import AutoProcessor, AutoModelForCausalLM
import rembg
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
from pytorch_lightning import seed_everything
import os

from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.train_util import instantiate_from_config
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
from utils.tool import get_background, get_render_cameras_video, render_frames
import time

device = "cuda"
resolution = 512
save_dir = "./outputs"
zero123plus_diffusion_steps = 75
normal_transfer = NormalTransfer()
rembg_session = rembg.new_session()
isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
isomer_radius = 4.1
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
# seed_everything(42)

# model initialization and loading
# flux
print('==> Loading Flux model ...')
flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)

flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')


flux_pipe.to(device=device, dtype=torch.bfloat16)
generator = torch.Generator(device=device).manual_seed(0)

# lrm
print('==> Loading LRM model ...')
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
model_config = config.model_config
infer_config = config.infer_config
model = instantiate_from_config(model_config)
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
model.init_flexicubes_geometry(device, fovy=50.0)
model = model.eval()

# zero123++
print('==> Loading diffusion model ...')
zero123plus_pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="./models/zero123plus",
    torch_dtype=torch.float16,
)
zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
)
unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
zero123plus_pipeline = zero123plus_pipeline.to(device)

# unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
# state_dict = torch.load(unet_ckpt_path, map_location='cpu')
# zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
# zero123plus_pipeline = zero123plus_pipeline.to(device)

# florence
caption_model = AutoModelForCausalLM.from_pretrained(
    "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
).to(device)
caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)

# Flux multi-view generation
def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
                                                     control_image=[],
                                                     control_mode=[],
                                                     control_guidance_start=None,
                                                     control_guidance_end=None,
                                                     controlnet_conditioning_scale=None,
                                                     lora_scale=1.0
                                                     ):
    control_mode_dict = {
        'canny': 0,
        'tile': 1,
        'depth': 2,
        'blur': 3,
        'pose': 4,
        'gray': 5,
        'lq': 6,
    }  # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only

    hparam_dict = {
        'prompt': prompt,
        'image': image,
        'strength': strength,
        'num_inference_steps': 30,
        'guidance_scale': 3.5,
        'num_images_per_prompt': 1,
        'width': resolution*4,
        'height': resolution*2,
        'output_type': 'np',
        'generator': generator,
        'joint_attention_kwargs': {"scale": lora_scale}
    }

    # append controlnet hparams
    if len(control_image) > 0:
        assert len(control_mode) == len(control_image)  # the number of control images must match the number of control modes

        ctrl_hparams = {
            'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
            'control_image': control_image,
            'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
            'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
            'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
        }

        hparam_dict.update(ctrl_hparams)

    # generate multi-view images
    with torch.no_grad():
        image = flux_pipe(
            **hparam_dict
        ).images
    return image

# captioning
def run_captioning(image):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16

    if isinstance(image, str):  # If image is a file path
        image = Image.open(image).convert("RGB")

    prompt = "<MORE_DETAILED_CAPTION>"
    inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    # print(f"inputs {inputs}")

    generated_ids = caption_model.generate(
        input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
    )

    generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = caption_processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )
    # print(f"parsed_answer = {parsed_answer}")
    caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
    return caption_text


# zero123++ multi-view generation
def multi_view_rgb_generation(cond_img):
    # generate multi-view images
    with torch.no_grad():
        output_image = zero123plus_pipeline(
            cond_img,
            num_inference_steps=zero123plus_diffusion_steps,
            width=resolution*2,
            height=resolution*2,
        ).images[0]
    return output_image

# lrm reconstructions
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
    images = image.unsqueeze(0).to(device)
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
    # breakpoint()
    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        mesh_path_idx = os.path.join(save_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            **infer_config,
        )
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        render_size = 512
        if if_save_video:
            video_path_idx = os.path.join(save_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)

            all_mv, all_mvp, all_campos = get_render_cameras_video(
                batch_size=1,
                M=240,
                radius=4.5,
                elevation=(90, 60.0),
                is_flexicubes=True,
                fov=30
            )

            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                chunk_size=20,
                is_flexicubes=True,
            )
            normals = (torch.nn.functional.normalize(normals) + 1) / 2
            normals = normals * alphas + (1 - alphas)
            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)

            # breakpoint()
            save_video(
                all_frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")

        if render_azimuths is not None and render_elevations is not None and render_radius is not None:
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)
            all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
                batch_size=1,
                radius=render_radius,
                azimuths=render_azimuths,
                elevations=render_elevations,
                fov=30
            )
            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                render_mv=all_mv,
                local_normal=True,
                identity_mv=identity_mv,
            )
        else:
            normals = None
            frames = None
            albedos = None

    return vertices, faces, normals, frames, albedos


def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
    """
    input_normal: in range [-1, 1], shape (b c h w)
    """

    input_normal = input_normal.permute(0, 2, 3, 1).cpu()

    azimuths_deg = np.array(azimuths_deg)
    elevations_deg = np.array(elevations_deg)

    if is_global_to_local:
        local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
        return local_normal.permute(0, 3, 1, 2)
    else:
        global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
        global_normal[..., 0] *= -1
        return global_normal.permute(0, 3, 1, 2)

def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
    global_normal[..., 0] *= -1
    global_normal = (global_normal + 1) / 2
    global_normal = global_normal.permute(0, 3, 1, 2)
    return global_normal

def main():
    image_pth = "examples/蓝色小怪物.webp"
    save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
    os.makedirs(save_dir_path, exist_ok=True)
    input_image = Image.open(image_pth)
    # if not args.no_rembg:
    input_image = remove_background(input_image, rembg_session)
    input_image = resize_foreground(input_image, 0.85)

    # generate caption
    image_caption = run_captioning(image_pth)

    # generate multi-view images
    output_image = multi_view_rgb_generation(input_image)

    # lrm reconstructions
    rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
    rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 1024)
    rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2)  # (4, 3, 512, 512)

    input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)

    vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
        lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
                            export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
                            render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)

    vertices = torch.from_numpy(vertices).to(device)
    faces = torch.from_numpy(faces).to(device)
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]


    # lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0) # range [0, 1]
    lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3, 0, 1, 2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0)  # range [0, 1]
    # rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
    # lrm_multi_view_normals : (B,3,H,W)
    # combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
    # torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
    # breakpoint()
    # Use the low-quality controlnet by default, feel free to try the others
    control_image = [lrm_3D_bundle_image * 2 - 1]
    control_mode = ['tile']
    control_guidance_start = [0.0]
    control_guidance_end = [0.3]
    controlnet_conditioning_scale = [0.8]

    flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
    # breakpoint()
    rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
        prompt=' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
        image=lrm_3D_bundle_image,
        strength=0.6,
        control_image=control_image,
        control_mode=control_mode,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        lora_scale=1.0
    )  # note that rgb_normal_grid is a (b, h, w, c) numpy array

    rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
    rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
    rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
    normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
    multi_view_mask = get_background(normal_multi_view).cuda()
    rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)

    # local normal to global normal
    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()

    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    global_normal = global_normal.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.squeeze(1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
    # global_normal: B,H,W,3
    # multi_view_mask: B,H,W
    # rgb_multi_view: B,H,W,3


    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=isomer_geo_weights,
        fov=30,
        radius=isomer_radius,
        camera_angles_azi=isomer_azimuths,
        camera_angles_ele=isomer_elevations,
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=0,
        stage2_steps=50,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    save_glb_addr = projection(
        meshes=meshes,
        masks=multi_view_mask,
        images=rgb_multi_view,
        azimuths=isomer_azimuths,
        elevations=isomer_elevations,
        weights=isomer_color_weights,
        fov=30,
        radius=isomer_radius,
        save_dir=f"{save_dir_path}/ISOMER/",
    )
    print(f'saved to {save_glb_addr}')



if __name__ == '__main__':
    main()
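For the zero123++ branch above, the grid is 2×2 rather than 2×4 (width and height are both resolution*2), so the split yields four views. A quick shape check with a dummy tensor standing in for the 1024×1024 output:

import torch
from einops import rearrange

grid = torch.zeros(3, 2 * 512, 2 * 512)                     # zero123++ output arranged as a 2x2 grid of 512-px views
views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=2, m=2)
assert views.shape == (4, 3, 512, 512)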
models/ISOMER/__init__.py ADDED
File without changes
models/ISOMER/data/__init__.py ADDED
File without changes
models/ISOMER/data/utils.py ADDED
@@ -0,0 +1,87 @@
import torch
import numpy as np
from PIL import Image
import os
from pytorch3d.io import load_obj
import trimesh
from pytorch3d.structures import Meshes
# from rembg import remove

def remove_color(arr):
    if arr.shape[-1] == 4:
        arr = arr[..., :3]

    # Convert to torch tensor
    if type(arr) is not torch.Tensor:
        arr = torch.tensor(arr, dtype=torch.int32)

    # Calculate diffs
    base = arr[0, 0]
    diffs = torch.abs(arr - base).sum(dim=-1)
    alpha = (diffs <= 80)

    arr[alpha] = 255
    alpha = ~alpha
    alpha = alpha.unsqueeze(-1).int() * 255
    arr = torch.cat([arr, alpha], dim=-1)

    return arr

def simple_remove_bkg_normal(imgs, rm_bkg_with_rembg, return_Image=False):
    """Only works for normal"""
    rets = []
    for img in imgs:
        if rm_bkg_with_rembg:
            from rembg import remove
            image = Image.fromarray(img.to(torch.uint8).detach().cpu().numpy()) if isinstance(img, torch.Tensor) else img
            removed_image = remove(image)
            arr = np.array(removed_image)
            arr = torch.tensor(arr, dtype=torch.uint8)
        else:
            arr = remove_color(img)

        if return_Image:
            rets.append(Image.fromarray(arr.to(torch.uint8).detach().cpu().numpy()))
        else:
            rets.append(arr.to(torch.uint8))

    return rets


def load_glb(file_path):
    # Load the .glb file as a scene and merge all meshes
    scene_or_mesh = trimesh.load(file_path)

    mesh = scene_or_mesh.dump(concatenate=True) if isinstance(scene_or_mesh, trimesh.Scene) else scene_or_mesh

    # Extract vertices and faces from the merged mesh
    verts = torch.tensor(mesh.vertices, dtype=torch.float32)
    faces = torch.tensor(mesh.faces, dtype=torch.int64)


    textured_mesh = Meshes(verts=[verts], faces=[faces])


    return textured_mesh

def load_obj_with_verts_faces(file_path, return_mesh=True):
    verts, faces, _ = load_obj(file_path)

    verts = torch.tensor(verts, dtype=torch.float32)
    faces = faces.verts_idx
    faces = torch.tensor(faces, dtype=torch.int64)

    if return_mesh:
        return Meshes(verts=[verts], faces=[faces])
    else:
        return verts, faces

def normalize_mesh(vertices):
    min_vals, _ = torch.min(vertices, axis=0)
    max_vals, _ = torch.max(vertices, axis=0)
    center = (max_vals + min_vals) / 2
    vertices = vertices - center
    max_extent = torch.max(max_vals - min_vals)
    scale = 2.0 / max_extent
    vertices = vertices * scale
    return vertices
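normalize_mesh recentres the vertex bounding box at the origin and scales its longest side to length 2. A small worked check (assuming the module's pytorch3d/trimesh dependencies are installed so it can be imported):

import torch
from models.ISOMER.data.utils import normalize_mesh

pts = torch.tensor([[0., 0., 0.],
                    [4., 2., 1.]])
print(normalize_mesh(pts))
# center (2, 1, 0.5) is subtracted, then everything is scaled by 2 / max_extent = 0.5:
# tensor([[-1.0000, -0.5000, -0.2500],
#         [ 1.0000,  0.5000,  0.2500]])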
models/ISOMER/mesh_reconstruction/__init__.py ADDED
File without changes
models/ISOMER/mesh_reconstruction/func.py ADDED
@@ -0,0 +1,227 @@
# modified from https://github.com/Profactor/continuous-remeshing
import torch
import numpy as np
import trimesh
from typing import Tuple
from pytorch3d.renderer.cameras import camera_position_from_spherical_angles, look_at_rotation
from pytorch3d.renderer import (
    FoVOrthographicCameras,
    look_at_view_transform,
)

def to_numpy(*args):
    def convert(a):
        if isinstance(a, torch.Tensor):
            return a.detach().cpu().numpy()
        assert a is None or isinstance(a, np.ndarray)
        return a

    return convert(args[0]) if len(args) == 1 else tuple(convert(a) for a in args)

def laplacian(
        num_verts: int,
        edges: torch.Tensor  # E,2
        ) -> torch.Tensor:  # sparse V,V
    """create sparse Laplacian matrix"""
    V = num_verts
    E = edges.shape[0]

    # adjacency matrix
    idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T  # (2, 2*E)
    ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
    A = torch.sparse.FloatTensor(idx, ones, (V, V))

    # degree matrix
    deg = torch.sparse.sum(A, dim=1).to_dense()
    idx = torch.arange(V, device=edges.device)
    idx = torch.stack([idx, idx], dim=0)
    D = torch.sparse.FloatTensor(idx, deg, (V, V))

    return D - A

def _translation(x, y, z, device):
    return torch.tensor([[1., 0, 0, x],
                         [0, 1, 0, y],
                         [0, 0, 1, z],
                         [0, 0, 0, 1]], device=device)  # 4,4


def _perspective(fovy, aspect=1.0, n=0.1, f=1000.0, device=None):
    fovy = fovy * torch.pi / 180
    y = np.tan(fovy / 2)
    return torch.tensor([[1/(y*aspect), 0, 0, 0],
                         [0, 1/-y, 0, 0],
                         [0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                         [0, 0, -1, 0]], dtype=torch.float32, device=device)

def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    """
    see https://blog.csdn.net/wodownload2/article/details/85069240/
    """
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    p = torch.zeros([4, 4], device=device)
    p[0, 0] = 2*n/(r-l)
    p[0, 2] = (r+l)/(r-l)
    p[1, 1] = 2*n/(t-b) * (-1 if flip_y else 1)
    p[1, 2] = (t+b)/(t-b)
    p[2, 2] = -(f+n)/(f-n)
    p[2, 3] = -(2*f*n)/(f-n)
    p[3, 2] = -1
    return p  # 4,4

def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    o = torch.zeros([4, 4], device=device)
    o[0, 0] = 2/(r-l)
    o[0, 3] = -(r+l)/(r-l)
    o[1, 1] = 2/(t-b) * (-1 if flip_y else 1)
    o[1, 3] = -(t+b)/(t-b)
    o[2, 2] = -2/(f-n)
    o[2, 3] = -(f+n)/(f-n)
    o[3, 3] = 1
    return o  # 4,4

def make_star_cameras_orig(phis, pol_count, distance: float = 10., r=None, image_size=[512, 512], device='cuda'):
    if r is None:
        r = 1/distance
    A = len(phis)
    P = pol_count
    C = A * P  # total number of cameras

    phi = phis * torch.pi / 180
    phi_rot = torch.eye(3, device=device)[None, None].expand(A, 1, 3, 3).clone()
    phi_rot[:, 0, 2, 2] = phi.cos()
    phi_rot[:, 0, 2, 0] = -phi.sin()
    phi_rot[:, 0, 0, 2] = phi.sin()
    phi_rot[:, 0, 0, 0] = phi.cos()

    theta = torch.arange(1, P+1) * (torch.pi/(P+1)) - torch.pi/2
    theta_rot = torch.eye(3, device=device)[None, None].expand(1, P, 3, 3).clone()
    theta_rot[0, :, 1, 1] = theta.cos()
    theta_rot[0, :, 1, 2] = -theta.sin()
    theta_rot[0, :, 2, 1] = theta.sin()
    theta_rot[0, :, 2, 2] = theta.cos()

    mv = torch.empty((C, 4, 4), device=device)
    mv[:] = torch.eye(4, device=device)
    mv[:, :3, :3] = (theta_rot @ phi_rot).reshape(C, 3, 3)
    mv_ = _translation(0, 0, -distance, device) @ mv

    return mv_, _projection(r, device)

def make_star_cameras_mv_new(phis, eles, distance: float = 10., r=None, fov=None, image_size=[512, 512], device='cuda', translation=True):
    import glm
    def sample_spherical(phi, theta, cam_radius):
        theta = torch.deg2rad(theta)
        phi = torch.deg2rad(phi)

        z = cam_radius * torch.cos(phi) * torch.sin(theta)
        x = cam_radius * torch.sin(phi) * torch.sin(theta)
        y = cam_radius * torch.cos(theta)

        return x, y, z

    all_mvs = []
    for i in range(len(phis)):
        azimuth = - phis[i] + 1e-10
        ele = - eles[i] + 1e-10 + 90
        x, y, z = sample_spherical(azimuth, ele, distance)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        view_matrix = glm.lookAt(eye, at, up)
        all_mvs.append(torch.from_numpy(np.array(view_matrix)).cuda())
    mv = torch.stack(all_mvs)

    return mv

def make_star_cameras_mv(phis, eles, distance: float = 10., r=None, fov=None, image_size=[512, 512], device='cuda', translation=True):
    if r is None:
        r = 0.15
    A = len(phis)
    assert len(eles) == A, f'len(phis): {len(phis)}, len(eles): {len(eles)}'

    phi = phis * torch.pi / 180
    phi_rot = torch.eye(3, device=device)[None].expand(A, 3, 3).clone()
    phi_rot[:, 2, 2] = phi.cos()
    phi_rot[:, 2, 0] = -phi.sin()
    phi_rot[:, 0, 2] = phi.sin()
    phi_rot[:, 0, 0] = phi.cos()


    theta = eles * torch.pi / 180
    theta_rot = torch.eye(3, device=device)[None].expand(A, 3, 3).clone()
    theta_rot[:, 1, 1] = theta.cos()
    theta_rot[:, 1, 2] = -theta.sin()
    theta_rot[:, 2, 1] = theta.sin()
    theta_rot[:, 2, 2] = theta.cos()

    mv = torch.empty((A, 4, 4), device=device)
    mv[:] = torch.eye(4, device=device)
    mv[:, :3, :3] = (theta_rot @ phi_rot).reshape(A, 3, 3)

    if translation:
        mv_ = _translation(0, 0, -distance, device) @ mv
    else:
        mv_ = mv
    return mv_

def make_star_cameras(phis, eles, distance: float = 10., r=None, fov=None, image_size=[512, 512], device='cuda', translation=True):
    mv_ = make_star_cameras_mv_new(phis, eles, distance, r, device=device, translation=translation)
    return mv_, _perspective(fov, device=device)

def make_star_cameras_perspective(phis, eles, distance: float = 10., r=None, fov=None, device='cuda'):

    return make_star_cameras(phis, eles, distance, r, fov=fov, device=device, translation=True)

def make_star_cameras_orthographic(phis, eles, distance: float = 10., r=None, device='cuda'):

    mv = make_star_cameras_mv_new(phis, eles, distance, r, device=device)
    if r is None:
        r = 1
    return mv, _orthographic(r, device)

def make_sphere(level: int = 2, radius=1., device='cuda') -> Tuple[torch.Tensor, torch.Tensor]:
    sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
    vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
    faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
    return vertices, faces


def get_camera(R, T, focal_length=1 / (2**0.5)):
    focal_length = 1 / focal_length
    camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
    return camera

def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
    R, T = look_at_view_transform(dist, 0, azim_list)
    focal_length = 1 / focal
    return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)


def rotation_matrix_to_euler_angles(R, return_degrees=True):
    sy = torch.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])
    singular = sy < 1e-6
    if not singular:
        x = torch.atan2(R[2, 1], R[2, 2])
        y = torch.atan2(-R[2, 0], sy)
        z = torch.atan2(R[1, 0], R[0, 0])
    else:
        x = torch.atan2(-R[1, 2], R[1, 1])
        y = torch.atan2(-R[2, 0], sy)
        z = 0

    if return_degrees:
        return torch.tensor([x, y, z]) * 180 / np.pi
    else:
        return torch.tensor([x, y, z])
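A short sketch of how two of these helpers compose, run on CPU to sidestep the CUDA defaults (edges are derived directly from the faces here rather than with the repo's calc_edges helper in remesh.py):

import torch
from models.ISOMER.mesh_reconstruction.func import make_sphere, laplacian

verts, faces = make_sphere(level=1, radius=1.0, device='cpu')     # 42 vertices, 80 faces
# build the unique undirected edge list from the faces
edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0)
edges = torch.unique(torch.sort(edges, dim=1).values, dim=0)
L = laplacian(verts.shape[0], edges)                              # sparse (V, V) graph Laplacian, D - A
print(L.shape, edges.shape)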
models/ISOMER/mesh_reconstruction/opt.py ADDED
@@ -0,0 +1,191 @@
# modified from https://github.com/Profactor/continuous-remeshing
import time
import torch
import torch_scatter
from typing import Tuple
from ..mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges

@torch.no_grad()
def remesh(
        vertices_etc: torch.Tensor,  # V,D
        faces: torch.Tensor,  # F,3 long
        min_edgelen: torch.Tensor,  # V
        max_edgelen: torch.Tensor,  # V
        flip: bool,
        max_vertices=1e6
        ):

    # dummies
    vertices_etc, faces = prepend_dummies(vertices_etc, faces)
    vertices = vertices_etc[:, :3]  # V,3
    nan_tensor = torch.tensor([torch.nan], device=min_edgelen.device)
    min_edgelen = torch.concat((nan_tensor, min_edgelen))
    max_edgelen = torch.concat((nan_tensor, max_edgelen))

    # collapse
    edges, face_to_edge = calc_edges(faces)  # E,2 F,3
    edge_length = calc_edge_length(vertices, edges)  # E
    face_normals = calc_face_normals(vertices, faces, normalize=False)  # F,3
    vertex_normals = calc_vertex_normals(vertices, faces, face_normals)  # V,3
    # then calculates the face collapses, which are the faces that can be removed without changing the overall shape of the object.
    face_collapse = calc_face_collapses(vertices, faces, edges, face_to_edge, edge_length, face_normals, vertex_normals, min_edgelen, area_ratio=0.5)
    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0)  # e[0,1] 0...ok, 1...edgelen=0
    priority = face_collapse.float() + shortness
    vertices_etc, faces = collapse_edges(vertices_etc, faces, edges, priority)

    # split: If the number of vertices is less than the maximum allowed, the function splits the edges that are longer than the maximum edge length.
    if vertices.shape[0] < max_vertices:
        edges, face_to_edge = calc_edges(faces)  # E,2 F,3
        vertices = vertices_etc[:, :3]  # V,3
        edge_length = calc_edge_length(vertices, edges)  # E
        splits = edge_length > max_edgelen[edges].mean(dim=-1)
        vertices_etc, faces = split_edges(vertices_etc, faces, edges, face_to_edge, splits, pack_faces=False)

    vertices_etc, faces = pack(vertices_etc, faces)
    vertices = vertices_etc[:, :3]

    if flip:  # flips the edges of the faces
        edges, _, edge_to_face = calc_edges(faces, with_edge_to_face=True)  # E,2 F,3
        flip_edges(vertices, faces, edges, edge_to_face, with_border=False)

    return remove_dummies(vertices_etc, faces)

def lerp_unbiased(a: torch.Tensor, b: torch.Tensor, weight: float, step: int):
    """lerp with adam's bias correction"""
    c_prev = 1 - weight**(step-1)
    c = 1 - weight**step
    a_weight = weight * c_prev / c
    b_weight = (1 - weight) / c
    a.mul_(a_weight).add_(b, alpha=b_weight)


class MeshOptimizer:
    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""

    def __init__(self,
                 vertices: torch.Tensor,  # V,3
                 faces: torch.Tensor,  # F,3
                 lr=0.3,  # learning rate
                 betas=(0.8, 0.8, 0),  # betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
                 gammas=(0, 0, 0),  # optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
                 nu_ref=0.3,  # reference velocity for edge length controller
                 edge_len_lims=(.01, .15),  # smallest and largest allowed reference edge length
                 edge_len_tol=.5,  # edge length tolerance for split and collapse
                 gain=.2,  # gain value for edge length controller
                 laplacian_weight=.02,  # for laplacian smoothing/regularization
                 ramp=1,  # learning rate ramp, actual ramp width is ramp/(1-betas[0])
                 grad_lim=10.,  # gradients are clipped to m1.abs()*grad_lim
                 remesh_interval=1,  # larger intervals are faster but with worse mesh quality
                 local_edgelen=True,  # set to False to use a global scalar reference edge length instead
                 ):
        self._vertices = vertices
        self._faces = faces
        self._lr = lr
        self._betas = betas
        self._gammas = gammas
        self._nu_ref = nu_ref
        self._edge_len_lims = edge_len_lims
        self._edge_len_tol = edge_len_tol
        self._gain = gain
        self._laplacian_weight = laplacian_weight
        self._ramp = ramp
        self._grad_lim = grad_lim
        self._remesh_interval = remesh_interval
        self._local_edgelen = local_edgelen
        self._step = 0

        V = self._vertices.shape[0]
        # prepare continuous tensor for all vertex-based data
        self._vertices_etc = torch.zeros([V, 9], device=vertices.device)
        self._split_vertices_etc()
        self.vertices.copy_(vertices)  # initialize vertices
        self._vertices.requires_grad_()
        self._ref_len.fill_(edge_len_lims[1])

    @property
    def vertices(self):
        return self._vertices

    @property
    def faces(self):
        return self._faces

    def _split_vertices_etc(self):
        self._vertices = self._vertices_etc[:, :3]
        self._m2 = self._vertices_etc[:, 3]
        self._nu = self._vertices_etc[:, 4]
        self._m1 = self._vertices_etc[:, 5:8]
        self._ref_len = self._vertices_etc[:, 8]

        with_gammas = any(g != 0 for g in self._gammas)
        self._smooth = self._vertices_etc[:, :8] if with_gammas else self._vertices_etc[:, :3]

    def zero_grad(self):
        self._vertices.grad = None

    @torch.no_grad()
    def step(self):

        eps = 1e-8

        self._step += 1

        # spatial smoothing
        edges, _ = calc_edges(self._faces)  # E,2
        E = edges.shape[0]
        edge_smooth = self._smooth[edges]  # E,2,S
        neighbor_smooth = torch.zeros_like(self._smooth)  # V,S
        torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2, -1), index=edges.reshape(E*2, 1), dim=0, out=neighbor_smooth)

        # apply optional smoothing of m1,m2,nu
        if self._gammas[0]:
            self._m1.lerp_(neighbor_smooth[:, 5:8], self._gammas[0])
        if self._gammas[1]:
            self._m2.lerp_(neighbor_smooth[:, 3], self._gammas[1])
        if self._gammas[2]:
            self._nu.lerp_(neighbor_smooth[:, 4], self._gammas[2])

        # add laplace smoothing to gradients
        laplace = self._vertices - neighbor_smooth[:, :3]
        grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:, None], value=self._laplacian_weight)

        # gradient clipping
        if self._step > 1:
            grad_lim = self._m1.abs().mul_(self._grad_lim)
            grad.clamp_(min=-grad_lim, max=grad_lim)

        # moment updates
        lerp_unbiased(self._m1, grad, self._betas[0], self._step)
        lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)

        velocity = self._m1 / self._m2[:, None].sqrt().add_(eps)  # V,3
        speed = velocity.norm(dim=-1)  # V

        if self._betas[2]:
            lerp_unbiased(self._nu, speed, self._betas[2], self._step)  # V
        else:
            self._nu.copy_(speed)  # V

        # update vertices
        ramped_lr = self._lr * min(1, self._step * (1-self._betas[0]) / self._ramp)
        self._vertices.add_(velocity * self._ref_len[:, None], alpha=-ramped_lr)

        # update target edge length
        if self._step % self._remesh_interval == 0:
            if self._local_edgelen:
                len_change = (1 + (self._nu - self._nu_ref) * self._gain)
            else:
                len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
            self._ref_len *= len_change
            self._ref_len.clamp_(*self._edge_len_lims)

    def remesh(self, flip: bool = True, poisson=False) -> Tuple[torch.Tensor, torch.Tensor]:
        min_edge_len = self._ref_len * (1 - self._edge_len_tol)
        max_edge_len = self._ref_len * (1 + self._edge_len_tol)

        self._vertices_etc, self._faces = remesh(self._vertices_etc, self._faces, min_edge_len, max_edge_len, flip, max_vertices=1e6)

        self._split_vertices_etc()
        self._vertices.requires_grad_()

        return self._vertices, self._faces
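The MeshOptimizer docstring summarizes the intended loop; a minimal usage sketch mirroring recon.py and refine.py, with a toy loss standing in for the rendering losses (assumes the repo is importable, torch_scatter is installed, and a CUDA device is available):

import torch
from models.ISOMER.mesh_reconstruction.func import make_sphere
from models.ISOMER.mesh_reconstruction.opt import MeshOptimizer

verts, faces = make_sphere(level=2, radius=0.7)        # start from a small icosphere on the default CUDA device
opt = MeshOptimizer(verts, faces)
verts = opt.vertices                                   # optimize the optimizer's own vertex tensor

for step in range(50):
    opt.zero_grad()
    loss = (verts.norm(dim=-1) - 1.0).pow(2).mean()    # toy objective: pull the vertices onto the unit sphere
    loss.backward()
    opt.step()                                         # Adam-like update plus the edge-length controller
    verts, faces = opt.remesh()                        # collapse/split/flip edges, then keep optimizing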
models/ISOMER/mesh_reconstruction/recon.py ADDED
@@ -0,0 +1,58 @@
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from torchvision.utils import make_grid
from typing import List
from ..mesh_reconstruction.remesh import calc_vertex_normals
from ..mesh_reconstruction.opt import MeshOptimizer
from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from ..mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from ..scripts.utils import to_py3d_mesh, init_target

def reconstruct_stage1(pils: List[Image.Image], mv, proj, steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1, use_remesh=True):

    vertices, faces = vertices.to("cuda"), faces.to("cuda")

    renderer = NormalsRenderer(mv, proj, list(pils[0].size))


    target_images = init_target(pils, new_bkgd=(0., 0., 0.))  # 4s

    opt = MeshOptimizer(vertices, faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))

    vertices = opt.vertices

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt._lr *= decay

        normals = calc_vertex_normals(vertices, faces)
        images = renderer.render(vertices, normals, faces)

        loss_expand = 0.5 * ((vertices + normals).detach() - vertices).pow(2).mean()

        t_mask = images[..., -1] > 0.5
        loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()

        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight

        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob


        loss.backward()
        opt.step()

        if use_remesh:
            vertices, faces = opt.remesh(poisson=False)

    vertices, faces = vertices.detach(), faces.detach()

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
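A note on the expansion term used above (an observation, not in the original comments): with c = (v + n).detach(), the loss 0.5·‖c − v‖² has gradient ∂L/∂v = −(c − v) = −n at the current iterate, so a descent step on this term alone moves each vertex along +n, i.e. slightly outward along its normal; loss_expansion_weight scales this inflation pressure.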
models/ISOMER/mesh_reconstruction/refine.py ADDED
@@ -0,0 +1,86 @@
1
+ from tqdm import tqdm
2
+ from PIL import Image
3
+ import torch
4
+ import numpy as np
5
+ from typing import List
6
+ from ..mesh_reconstruction.remesh import calc_vertex_normals
7
+ from ..mesh_reconstruction.opt import MeshOptimizer
8
+ from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
9
+ from ..mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
10
+ from ..scripts.project_mesh import multiview_color_projection, get_cameras_list
11
+ from ..scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target
12
+
13
+ def run_mesh_refine(vertices, faces, pils: List[Image.Image], mv, proj, weights, cameras, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True, use_remesh=True, loss_expansion_weight=0):
14
+
15
+ if process_inputs:
16
+ vertices = vertices * 2 / 1.35
17
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
18
+
19
+ poisson_steps = []
20
+
21
+ renderer = NormalsRenderer(mv,proj,list(pils[0].size))
22
+
23
+
24
+ target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s
25
+
26
+ opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)
27
+
28
+ vertices = opt.vertices
29
+ alpha_init = None
30
+
31
+ mask = target_images[..., -1] < 0.5
32
+
33
+ for i in tqdm(range(steps)):
34
+ opt.zero_grad()
35
+ opt._lr *= decay
36
+ normals = calc_vertex_normals(vertices,faces)
37
+ images = renderer.render(vertices,normals,faces)
38
+
39
+ if alpha_init is None:
40
+ alpha_init = images.detach()
41
+
42
+ # update explicit target and render images for L_ET calculation
43
+ if i < update_warmup or i % update_normal_interval == 0:
44
+ with torch.no_grad():
45
+
46
+ py3d_mesh = to_py3d_mesh(vertices, faces, normals)
47
+
48
+ _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=weights, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
49
+
50
+ target_normal = target_normal * 2 - 1
51
+ target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
52
+ debug_images = renderer.render(vertices,target_normal,faces)
53
+
54
+ d_mask = images[..., -1] > 0.5
55
+ loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()
56
+
57
+ loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()
58
+
59
+ loss = loss_debug_l2 + loss_alpha_target_mask_l2
60
+
61
+ loss_oob = (vertices.abs() > 0.99).float().mean() * 10
62
+
63
+ loss = loss + loss_oob
64
+
65
+
66
+ # this loss_expand does not exist in original ISOMER. we add it here (but default loss_expansion_weight is 0)
67
+ loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()
68
+ loss += loss_expand * loss_expansion_weight
69
+
70
+ loss.backward()
71
+ opt.step()
72
+
73
+
74
+ if use_remesh:
75
+ vertices,faces = opt.remesh(poisson=(i in poisson_steps))
76
+
77
+ vertices, faces = vertices.detach(), faces.detach()
78
+
79
+ if process_outputs:
80
+ vertices = vertices / 2 * 1.35
81
+ vertices[..., [0, 2]] = - vertices[..., [0, 2]]
82
+
83
+ if return_mesh:
84
+ return to_py3d_mesh(vertices, faces)
85
+ else:
86
+ return vertices, faces
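Note: the projected normal target in the loop above is only refreshed during warm-up and then every update_normal_interval steps; a standalone restatement of that schedule, mirroring the condition used in the loop:

    def refreshes_target(i, update_warmup=10, update_normal_interval=10):
        # True on the steps where multiview_color_projection is re-run
        return i < update_warmup or i % update_normal_interval == 0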
models/ISOMER/mesh_reconstruction/remesh.py ADDED
@@ -0,0 +1,363 @@
1
+ # modified from https://github.com/Profactor/continuous-remeshing
2
+ import torch
3
+ import torch.nn.functional as tfunc
4
+ import torch_scatter
5
+ from typing import Tuple
6
+
7
+ def prepend_dummies(
8
+ vertices:torch.Tensor, #V,D
9
+ faces:torch.Tensor, #F,3 long
10
+ )->Tuple[torch.Tensor,torch.Tensor]:
11
+ """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
12
+ V,D = vertices.shape
13
+ vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
14
+ faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
15
+ return vertices,faces
16
+
17
+ def remove_dummies(
18
+ vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
19
+ faces:torch.Tensor, #F,3 long - first face all zeros
20
+ )->Tuple[torch.Tensor,torch.Tensor]:
21
+ """remove dummy elements added with prepend_dummies()"""
22
+ return vertices[1:],faces[1:]-1
23
+
24
+
25
+ def calc_edges(
26
+ faces: torch.Tensor, # F,3 long - first face may be dummy with all zeros
27
+ with_edge_to_face: bool = False
28
+ ) -> Tuple[torch.Tensor, ...]:
29
+ """
30
+ returns Tuple of
31
+ - edges E,2 long, 0 for unused, lower vertex index first
32
+ - face_to_edge F,3 long
33
+ - (optional) edge_to_face shape=E,[left,right],[face,side]
34
+
35
+ o-<-----e1 e0,e1...edge, e0<e1
36
+ | /A L,R....left and right face
37
+ | L / | both triangles ordered counter clockwise
38
+ | / R | normals pointing out of screen
39
+ V/ |
40
+ e0---->-o
41
+ """
42
+
43
+ F = faces.shape[0]
44
+
45
+ # make full edges, lower vertex index first
46
+ face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
47
+ full_edges = face_edges.reshape(F*3,2)
48
+ sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2
49
+
50
+ # make unique edges
51
+ edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
52
+ E = edges.shape[0]
53
+ face_to_edge = full_to_unique.reshape(F,3) #F,3
54
+
55
+ if not with_edge_to_face:
56
+ return edges, face_to_edge
57
+
58
+ is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
59
+ edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
60
+ scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
61
+ edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
62
+ edge_to_face[0] = 0
63
+ return edges, face_to_edge, edge_to_face
64
+
65
+ def calc_edge_length(
66
+ vertices:torch.Tensor, #V,3 first may be dummy
67
+ edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
68
+ )->torch.Tensor: #E
69
+
70
+ full_vertices = vertices[edges] #E,2,3
71
+ a,b = full_vertices.unbind(dim=1) #E,3
72
+ return torch.norm(a-b,p=2,dim=-1)
73
+
74
+ def calc_face_normals(
75
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
76
+ faces:torch.Tensor, #F,3 long, first face may be all zero
77
+ normalize:bool=False,
78
+ )->torch.Tensor: #F,3
79
+ """
80
+ n
81
+ |
82
+ c0 corners ordered counterclockwise when
83
+ / \ looking onto surface (in neg normal direction)
84
+ c1---c2
85
+ """
86
+ full_vertices = vertices[faces] #F,C=3,3
87
+ v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
88
+ face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
89
+ if normalize:
90
+ face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1)
91
+ return face_normals #F,3
92
+
93
+ def calc_vertex_normals(
94
+ vertices:torch.Tensor, #V,3 first vertex may be unreferenced
95
+ faces:torch.Tensor, #F,3 long, first face may be all zero
96
+ face_normals:torch.Tensor=None, #F,3, not normalized
97
+ )->torch.Tensor: #V,3
98
+
99
+ F = faces.shape[0]
100
+
101
+ if face_normals is None:
102
+ face_normals = calc_face_normals(vertices,faces) # this has no grad
103
+
104
+ vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3
105
+
106
+
107
+ vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) # this has no grad
108
+ vertex_normals = vertex_normals.sum(dim=1) #V,3
109
+ return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)
110
+
111
+ def calc_face_ref_normals(
112
+ faces:torch.Tensor, #F,3 long, 0 for unused
113
+ vertex_normals:torch.Tensor, #V,3 first unused
114
+ normalize:bool=False,
115
+ )->torch.Tensor: #F,3
116
+ """calculate reference normals for face flip detection"""
117
+ full_normals = vertex_normals[faces] #F,C=3,3
118
+ ref_normals = full_normals.sum(dim=1) #F,3
119
+ if normalize:
120
+ ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
121
+ return ref_normals
122
+
123
+ def pack(
124
+ vertices:torch.Tensor, #V,3 first unused and nan
125
+ faces:torch.Tensor, #F,3 long, 0 for unused
126
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
127
+ """removes unused elements in vertices and faces"""
128
+ V = vertices.shape[0]
129
+
130
+ # remove unused faces
131
+ used_faces = faces[:,0]!=0
132
+ used_faces[0] = True
133
+ faces = faces[used_faces] #sync
134
+
135
+ # remove unused vertices
136
+ used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
137
+ used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add')
138
+ used_vertices = used_vertices.any(dim=1)
139
+ used_vertices[0] = True
140
+ vertices = vertices[used_vertices] #sync
141
+
142
+ # update used faces
143
+ ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
144
+ V1 = used_vertices.sum()
145
+ ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
146
+ faces = ind[faces]
147
+
148
+ return vertices,faces
149
+
150
+ def split_edges(
151
+ vertices:torch.Tensor, #V,3 first unused
152
+ faces:torch.Tensor, #F,3 long, 0 for unused
153
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
154
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
155
+ splits, #E bool
156
+ pack_faces:bool=True,
157
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
158
+
159
+ # c2 c2 c...corners = faces
160
+ # . . . . s...side_vert, 0 means no split
161
+ # . . .N2 . S...shrunk_face
162
+ # . . . . Ni...new_faces
163
+ # s2 s1 s2|c2...s1|c1
164
+ # . . . . .
165
+ # . . . S . .
166
+ # . . . . N1 .
167
+ # c0...(s0=0)....c1 s0|c0...........c1
168
+ #
169
+ # pseudo-code:
170
+ # S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
171
+ # split = side_vert!=0 example:[False,True,True]
172
+ # N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
173
+ # N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
174
+ # N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]
175
+
176
+ V = vertices.shape[0]
177
+ F = faces.shape[0]
178
+ S = splits.sum().item() #sync
179
+
180
+ if S==0:
181
+ return vertices,faces
182
+
183
+ edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
184
+ edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
185
+ side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
186
+ split_edges = edges[splits] #S sync
187
+
188
+ #vertices
189
+ split_vertices = vertices[split_edges].mean(dim=1) #S,3
190
+ vertices = torch.concat((vertices,split_vertices),dim=0)
191
+
192
+ #faces
193
+ side_split = side_vert!=0 #F,3
194
+ shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
195
+ new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
196
+ faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
197
+ if pack_faces:
198
+ mask = faces[:,0]!=0
199
+ mask[0] = True
200
+ faces = faces[mask] #F',3 sync
201
+
202
+ return vertices,faces
203
+
204
+ def collapse_edges(
205
+ vertices:torch.Tensor, #V,3 first unused
206
+ faces:torch.Tensor, #F,3 long 0 for unused
207
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
208
+ priorities:torch.Tensor, #E float
209
+ stable:bool=False, #only for unit testing
210
+ )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)
211
+
212
+ V = vertices.shape[0]
213
+
214
+ # check spacing
215
+ _,order = priorities.sort(stable=stable) #E
216
+ rank = torch.zeros_like(order)
217
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
218
+ vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
219
+ edge_rank = rank #E
220
+ for i in range(3):
221
+ torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
222
+ edge_rank,_ = vert_rank[edges].max(dim=-1) #E
223
+ candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2
224
+
225
+ # check connectivity
226
+ vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
227
+ vert_connections[candidates[:,0]] = 1 #start
228
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
229
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
230
+ vert_connections[candidates] = 0 #clear start and end
231
+ edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
232
+ vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
233
+ collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end
234
+
235
+ # mean vertices
236
+ vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)
237
+
238
+ # update faces
239
+ dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
240
+ dest[collapses[:,1]] = dest[collapses[:,0]]
241
+ faces = dest[faces] #F,3
242
+ c0,c1,c2 = faces.unbind(dim=-1)
243
+ collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
244
+ faces[collapsed] = 0
245
+
246
+ return vertices,faces
247
+
248
+ def calc_face_collapses(
249
+ vertices:torch.Tensor, #V,3 first unused
250
+ faces:torch.Tensor, #F,3 long, 0 for unused
251
+ edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
252
+ face_to_edge:torch.Tensor, #F,3 long 0 for unused
253
+ edge_length:torch.Tensor, #E
254
+ face_normals:torch.Tensor, #F,3
255
+ vertex_normals:torch.Tensor, #V,3 first unused
256
+ min_edge_length:torch.Tensor=None, #V
257
+ area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
258
+ shortest_probability = 0.8
259
+ )->torch.Tensor: #E edges to collapse
260
+
261
+ E = edges.shape[0]
262
+ F = faces.shape[0]
263
+
264
+ # face flips
265
+ ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
266
+ face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F
267
+
268
+ # small faces
269
+ if min_edge_length is not None:
270
+ min_face_length = min_edge_length[faces].mean(dim=-1) #F
271
+ min_area = min_face_length**2 * area_ratio #F
272
+ face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
273
+ face_collapses[0] = False
274
+
275
+ # faces to edges
276
+ face_length = edge_length[face_to_edge] #F,3
277
+
278
+ if shortest_probability<1:
279
+ #select shortest edge with shortest_probability chance
280
+ randlim = round(2/(1-shortest_probability))
281
+ rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
282
+ sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
283
+ local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
284
+ else:
285
+ local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face
286
+
287
+ edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
288
+ edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
289
+ edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())
290
+
291
+ return edge_collapses.bool()
292
+
293
+ def flip_edges(
294
+ vertices:torch.Tensor, #V,3 first unused
295
+ faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
296
+ edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
297
+ edge_to_face:torch.Tensor, #E,[left,right],[face,side]
298
+ with_border:bool=True, #handle border edges (D=4 instead of D=6)
299
+ with_normal_check:bool=True, #check face normal flips
300
+ stable:bool=False, #only for unit testing
301
+ ):
302
+ V = vertices.shape[0]
303
+ E = edges.shape[0]
304
+ device=vertices.device
305
+ vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
306
+ vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
307
+ neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
308
+ neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
309
+ edge_is_inside = neighbors.all(dim=-1) #E
310
+
311
+ if with_border:
312
+ # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
313
+ # need to use float for masks in order to use scatter(reduce='multiply')
314
+ vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
315
+ src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
316
+ vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
317
+ vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
318
+ vertex_degree -= 2 * vertex_is_inside #V long
319
+
320
+ neighbor_degrees = vertex_degree[neighbors] #E,LR=2
321
+ edge_degrees = vertex_degree[edges] #E,2
322
+ #
323
+ # loss = Sum_over_affected_vertices((new_degree-6)**2)
324
+ # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
325
+ # + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
326
+ # = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
327
+ #
328
+ loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
329
+ candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
330
+ loss_change = loss_change[candidates] #E'
331
+ if loss_change.shape[0]==0:
332
+ return
333
+
334
+ edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
335
+ _,order = loss_change.sort(descending=True, stable=stable) #E'
336
+ rank = torch.zeros_like(order)
337
+ rank[order] = torch.arange(0,len(rank),device=rank.device)
338
+ vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
339
+ torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
340
+ vertex_rank,_ = vertex_rank.max(dim=-1) #V
341
+ neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
342
+ flip = rank==neighborhood_rank #E'
343
+
344
+ if with_normal_check:
345
+ # cl-<-----e1 e0,e1...edge, e0<e1
346
+ # | /A L,R....left and right face
347
+ # | L / | both triangles ordered counter clockwise
348
+ # | / R | normals pointing out of screen
349
+ # V/ |
350
+ # e0---->-cr
351
+ v = vertices[edges_neighbors] #E",4,3
352
+ v = v - v[:,0:1] #make relative to e0
353
+ e1 = v[:,1]
354
+ cl = v[:,2]
355
+ cr = v[:,3]
356
+ n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
357
+ flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
358
+ flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face
359
+
360
+ flip_edges_neighbors = edges_neighbors[flip] #E",4
361
+ flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
362
+ flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
363
+ faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
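Note: a small sanity check of calc_edges on a two-triangle quad (dummy padding via prepend_dummies is only needed for the scatter-based passes, not for a plain edge query):

    import torch
    faces = torch.tensor([[0, 1, 2], [0, 2, 3]], dtype=torch.long)
    edges, face_to_edge = calc_edges(faces)
    print(edges.shape, face_to_edge.shape)   # torch.Size([5, 2]) torch.Size([2, 3])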
models/ISOMER/mesh_reconstruction/render.py ADDED
@@ -0,0 +1,142 @@
1
+ # modified from https://github.com/Profactor/continuous-remeshing
2
+ import nvdiffrast.torch as dr
3
+ import torch
4
+ from typing import Tuple
5
+
6
+ def _warmup(glctx, device=None):
7
+ device = 'cuda' if device is None else device
8
+ #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
9
+ def tensor(*args, **kwargs):
10
+ return torch.tensor(*args, device=device, **kwargs)
11
+
12
+ # defines a triangle in homogeneous coordinates and calls dr.rasterize to render this triangle, which may help to initialize or warm up the GPU context
13
+ pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
14
+ tri = tensor([[0, 1, 2]], dtype=torch.int32)
15
+ dr.rasterize(glctx, pos, tri, resolution=[256, 256])
16
+
17
+ # glctx = dr.RasterizeGLContext(output_db=False, device="cuda")
18
+ glctx = dr.RasterizeCudaContext(device="cuda")
19
+
20
+ class NormalsRenderer:
21
+
22
+ _glctx:dr.RasterizeCudaContext = None
23
+
24
+ def __init__(
25
+ self,
26
+ mv: torch.Tensor, #C,4,4 # normal column-major (unlike pytorch3d)
27
+ proj: torch.Tensor, #C,4,4
28
+ image_size: Tuple[int,int],
29
+ mvp = None,
30
+ device=None,
31
+ ):
32
+ if mvp is None:
33
+ self._mvp = proj @ mv #C,4,4
34
+ else:
35
+ self._mvp = mvp
36
+ self._image_size = image_size
37
+ self._glctx = glctx
38
+ _warmup(self._glctx, device)
39
+
40
+ def render(self,
41
+ vertices: torch.Tensor, #V,3 float
42
+ normals: torch.Tensor, #V,3 float in [-1, 1]
43
+ faces: torch.Tensor, #F,3 long
44
+ ) ->torch.Tensor: #C,H,W,4
45
+
46
+ V = vertices.shape[0]
47
+ faces = faces.type(torch.int32)
48
+ vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
49
+ # transforms the vertices into clip space using the mvp matrix.
50
+ vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 # the .transpose(-2,-1) operation ensures that the matrix multiplication aligns with the row-major convention.
51
+ rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 -> 4 includes the barycentric coordinates and other data.
52
+ vert_col = (normals+1)/2 #V,3
53
+ # this function takes the attributes (colors) defined at the vertices and computes their values at each pixel (or fragment) within the triangles
54
+ col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
55
+ alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
56
+ col = torch.concat((col,alpha),dim=-1) #C,H,W,4
57
+ col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
58
+ return col #C,H,W,4
59
+
60
+
61
+
62
+ from pytorch3d.structures import Meshes
63
+ from pytorch3d.renderer.mesh.shader import ShaderBase
64
+ from pytorch3d.renderer import (
65
+ RasterizationSettings,
66
+ MeshRendererWithFragments,
67
+ TexturesVertex,
68
+ MeshRasterizer,
69
+ BlendParams,
70
+ FoVOrthographicCameras,
71
+ look_at_view_transform,
72
+ hard_rgb_blend,
73
+ )
74
+
75
+ class VertexColorShader(ShaderBase):
76
+ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
77
+ blend_params = kwargs.get("blend_params", self.blend_params)
78
+ texels = meshes.sample_textures(fragments)
79
+ return hard_rgb_blend(texels, fragments, blend_params)
80
+
81
+ def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
82
+ if len(mesh) != len(cameras):
83
+ if len(cameras) % len(mesh) == 0:
84
+ mesh = mesh.extend(len(cameras))
85
+ else:
86
+ raise NotImplementedError()
87
+
88
+ # render requires everything in float16 or float32
89
+ input_dtype = dtype
90
+ blend_params = BlendParams(1e-4, 1e-4, bkgd)
91
+
92
+ # Define the settings for rasterization and shading
93
+ raster_settings = RasterizationSettings(
94
+ image_size=(H, W),
95
+ blur_radius=blur_radius,
96
+ faces_per_pixel=faces_per_pixel,
97
+ clip_barycentric_coords=True,
98
+ bin_size=None,
99
+ max_faces_per_bin=None,
100
+ )
101
+
102
+ # Create a renderer by composing a rasterizer and a shader
103
+ # We simply render vertex colors through the custom VertexColorShader (no lighting or materials are used)
104
+ renderer = MeshRendererWithFragments(
105
+ rasterizer=MeshRasterizer(
106
+ cameras=cameras,
107
+ raster_settings=raster_settings
108
+ ),
109
+ shader=VertexColorShader(
110
+ device=device,
111
+ cameras=cameras,
112
+ blend_params=blend_params
113
+ )
114
+ )
115
+
116
+ # render RGB and depth, get mask
117
+ with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
118
+ images, _ = renderer(mesh)
119
+ return images # BHW4
120
+
121
+ class Pytorch3DNormalsRenderer: # 100 times slower!!!
122
+ def __init__(self, cameras, image_size, device):
123
+ self.cameras = cameras.to(device)
124
+ self._image_size = image_size
125
+ self.device = device
126
+
127
+ def render(self,
128
+ vertices: torch.Tensor, #V,3 float
129
+ normals: torch.Tensor, #V,3 float in [-1, 1]
130
+ faces: torch.Tensor, #F,3 long
131
+ ) ->torch.Tensor: #C,H,W,4
132
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
133
+ return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)
134
+
135
+ def save_tensor_to_img(tensor, save_dir):
136
+ from PIL import Image
137
+ import numpy as np
138
+ for idx, img in enumerate(tensor):
139
+ img = img[..., :3].cpu().numpy()
140
+ img = (img * 255).astype(np.uint8)
141
+ img = Image.fromarray(img)
142
+ img.save(save_dir + f"{idx}.png")
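Note: a hedged sketch of driving NormalsRenderer directly; mv and proj come from the star-camera helpers in func.py, and the resolution is illustrative:

    normals = calc_vertex_normals(vertices, faces)      # V,3 in [-1, 1]
    renderer = NormalsRenderer(mv, proj, (512, 512))
    images = renderer.render(vertices, normals, faces)  # C,H,W,4; RGB stores (n + 1) / 2, alpha is coverage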
models/ISOMER/model/__init__.py ADDED
File without changes
models/ISOMER/model/inference_pipeline.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+
6
+ from pytorch3d.structures import Meshes
7
+ from pytorch3d.renderer import TexturesVertex
8
+
9
+ from ..scripts.fast_geo import fast_geo, create_sphere, create_box
10
+ from ..scripts.project_mesh import get_cameras_list_azi_ele
11
+ from ..mesh_reconstruction.recon import reconstruct_stage1
12
+ from ..mesh_reconstruction.refine import run_mesh_refine
13
+ from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_perspective
14
+
15
+ from ..data.utils import (
16
+ simple_remove_bkg_normal,
17
+ load_glb,
18
+ load_obj_with_verts_faces)
19
+ from ..scripts.utils import (
20
+ to_pyml_mesh,
21
+ simple_clean_mesh,
22
+ normal_rotation_img2img_c2w,
23
+ rotate_normal_R,
24
+ get_rotation_matrix_azi_ele,
25
+ manage_elevation_azimuth)
26
+
27
+ @torch.enable_grad()
28
+ def reconstruction_pipe(normal_pils,
29
+ rotation_angles_azi,
30
+ rotation_angles_ele,
31
+ front_index=0,
32
+ back_index=2,
33
+ side_index=1,
34
+ weights=None,
35
+ expansion_weight=0.1,
36
+ expansion_weight_stage2=0.0,
37
+ init_type="ball",
38
+ sphere_r=None, # only used if init_type=="ball"
39
+ box_width=1.0, # only used if init_type=="box"
40
+ box_length=1.0, # only used if init_type=="box"
41
+ box_height=1.0, # only used if init_type=="box"
42
+ init_verts=None,
43
+ init_faces=None,
44
+ init_mesh_from_file="",
45
+ stage1_steps=200,
46
+ stage2_steps=200,
47
+ projection_type="orthographic",
48
+ fovy=None,
49
+ radius=None,
50
+ ortho_dist=1.1,
51
+ camera_angles_azi=None,
52
+ camera_angles_ele=None,
53
+ rm_bkg=False,
54
+ rm_bkg_with_rembg=False, # only used if rm_bkg
55
+ normal_rotation_R=None,
56
+ train_stage1=True,
57
+ train_stage2=True,
58
+ use_remesh_stage1=True,
59
+ use_remesh_stage2=True,
60
+ start_edge_len_stage1=0.1,
61
+ end_edge_len_stage1=0.02,
62
+ start_edge_len_stage2=0.02,
63
+ end_edge_len_stage2=0.005,
64
+ ):
65
+
66
+ assert projection_type in ['perspective', 'orthographic'], f"projection_type ({projection_type}) should be one of ['perspective', 'orthographic']"
67
+
68
+ if stage1_steps == 0:
69
+ train_stage1 = False
70
+ if stage2_steps == 0:
71
+ train_stage2 = False
72
+
73
+ if normal_rotation_R is not None:
74
+ assert normal_rotation_R.shape[-2] == 3 and normal_rotation_R.shape[-1] == 3
75
+ assert len(normal_rotation_R.shape) == 2
76
+ normal_rotation_R = normal_rotation_R.float()
77
+
78
+ camera_angles_azi = camera_angles_azi.float()
79
+ camera_angles_ele = camera_angles_ele.float()
80
+
81
+ camera_angles_ele, camera_angles_azi = manage_elevation_azimuth(camera_angles_ele, camera_angles_azi)
82
+
83
+ if init_type in ["std", "thin"]:
84
+ assert camera_angles_azi[front_index]%360==0, f"the camera_angles_azi associated with front image (index {front_index}) should be 0 not {camera_angles_azi[front_index]}"
85
+ assert camera_angles_azi[back_index]%360==180, f"the camera_angles_azi associated with back image (index {back_index}) should be 180 not {camera_angles_azi[back_index]}"
86
+ assert camera_angles_azi[side_index]%360==90, f"the camera_angles_azi associated with left side image (index {side_index}) should be 90, not {camera_angles_azi[side_index]}"
87
+
88
+ if rm_bkg:
89
+ if rm_bkg_with_rembg:
90
+ os.environ["OMP_NUM_THREADS"] = '8'
91
+ normal_pils = simple_remove_bkg_normal(normal_pils,rm_bkg_with_rembg)
92
+
93
+ if rotation_angles_azi is not None:
94
+ rotation_angles_azi = -rotation_angles_azi.float()
95
+ rotation_angles_ele = rotation_angles_ele.float()
96
+
97
+ rotation_angles_ele, rotation_angles_azi = manage_elevation_azimuth(rotation_angles_ele, rotation_angles_azi)
98
+
99
+ assert len(normal_pils) == len(rotation_angles_azi), f'len(normal_pils) ({len(normal_pils)}) != len(rotation_angles_azi) ({len(rotation_angles_azi)})'
100
+ if rotation_angles_ele is None:
101
+ rotation_angles_ele = [0] * len(normal_pils)
102
+
103
+ normal_pils_rotated = []
104
+ for i in range(len(normal_pils)):
105
+ c2w_R = get_rotation_matrix_azi_ele(rotation_angles_azi[i], rotation_angles_ele[i])
106
+
107
+ rotated_ = normal_rotation_img2img_c2w(normal_pils[i], c2w=c2w_R)
108
+ normal_pils_rotated.append(rotated_)
109
+
110
+ normal_pils = normal_pils_rotated
111
+
112
+ if normal_rotation_R is not None:
113
+ normal_pils_rotated = []
114
+ for i in range(len(normal_pils)):
115
+ rotated_ = rotate_normal_R(normal_pils[i], normal_rotation_R, save_addr="", device="cuda")
116
+ normal_pils_rotated.append(rotated_)
117
+
118
+ normal_pils = normal_pils_rotated
119
+
120
+ normal_stg1 = [img for img in normal_pils]
121
+
122
+ if init_type in ['thin', 'std']:
123
+ front_ = normal_stg1[front_index]
124
+ back_ = normal_stg1[back_index]
125
+ side_ = normal_stg1[side_index]
126
+ meshes, depth_front, depth_back, mesh_front, mesh_back = fast_geo(front_, back_, side_, init_type=init_type, return_depth_and_sep_mesh=True)
127
+
128
+
129
+ elif init_type in ["ball", "box"]:
130
+
131
+ if init_type == "ball":
132
+ assert sphere_r is not None, f"sphere_r ({sphere_r}) should not be None when init_type is 'ball'"
133
+ meshes = create_sphere(sphere_r)
134
+
135
+ if init_type == "box":
136
+ assert box_width is not None and box_length is not None and box_height is not None, f"box_width ({box_width}), box_length ({box_length}), and box_height ({box_height}) should not be None when init_type is 'box'"
137
+ meshes = create_box(width=box_width, length=box_length, height=box_height)
138
+
139
+ # add texture just in case
140
+ num_meshes = len(meshes)
141
+ num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes
142
+ black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda")
143
+ textures = TexturesVertex(verts_features=black_texture)
144
+ meshes.textures = textures
145
+
146
+ elif init_type == "file":
147
+ assert init_mesh_from_file or (init_verts is not None and init_faces is not None), f"init_mesh_from_file ({init_mesh_from_file}) should not be None when init_type is 'file', else init_verts and init_faces should not be None"
148
+
149
+ if init_verts is not None and init_faces is not None:
150
+ meshes = Meshes(verts=[init_verts], faces=[init_faces]).to('cuda')
151
+ elif init_mesh_from_file.endswith('.glb'):
152
+ meshes = load_glb(init_mesh_from_file).to('cuda')
153
+ else:
154
+ meshes = load_obj_with_verts_faces(init_mesh_from_file).to('cuda')
155
+
156
+ # add texture just in case
157
+ num_meshes = len(meshes)
158
+ num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes
159
+ black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda")
160
+ textures = TexturesVertex(verts_features=black_texture)
161
+ meshes.textures = textures
162
+
163
+ if projection_type == 'perspective':
164
+ assert fovy is not None and radius is not None, f"fovy ({fovy}) and radius ({radius}) should not be None when projection_type is 'perspective'"
165
+ cameras = get_cameras_list_azi_ele(camera_angles_azi, camera_angles_ele, fov_in_degrees=fovy,device="cuda", dist=radius, cam_type='fov')
166
+
167
+ elif projection_type == 'orthographic':
168
+ cameras = get_cameras_list_azi_ele(camera_angles_azi, camera_angles_ele, fov_in_degrees=fovy, device="cuda", focal=1., dist=ortho_dist, cam_type='orthographic')
169
+
170
+ vertices, faces = meshes.verts_list()[0], meshes.faces_list()[0]
171
+
172
+ render_camera_angles_azi = -camera_angles_azi
173
+ render_camera_angles_ele = camera_angles_ele
174
+ if projection_type == 'orthographic':
175
+ mv, proj = make_star_cameras_orthographic(render_camera_angles_azi, render_camera_angles_ele)
176
+ else:
177
+ mv, proj = make_star_cameras_perspective(render_camera_angles_azi, render_camera_angles_ele, distance=radius, r=radius, fov=fovy, device='cuda')
178
+
179
+ # stage 1
180
+ if train_stage1:
181
+ vertices, faces = reconstruct_stage1(normal_stg1, mv=mv, proj=proj, steps=stage1_steps, vertices=vertices, faces=faces, start_edge_len=start_edge_len_stage1, end_edge_len=end_edge_len_stage1, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight, use_remesh=use_remesh_stage1)
182
+
183
+ # stage 2
184
+ if train_stage2:
185
+ vertices, faces = run_mesh_refine(vertices, faces, normal_pils, mv=mv, proj=proj, weights=weights, steps=stage2_steps, start_edge_len=start_edge_len_stage2, end_edge_len=end_edge_len_stage2, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False, cameras=cameras, use_remesh=use_remesh_stage2, loss_expansion_weight=expansion_weight_stage2)
186
+
187
+ meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda")
188
+
189
+ return meshes
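Note: a minimal, hedged invocation of reconstruction_pipe with the sphere initialisation (angles, weights and radius are illustrative; perspective mode additionally needs fovy and radius):

    import torch
    meshes = reconstruction_pipe(
        normal_pils=normal_pils,                         # list of RGBA normal maps (PIL)
        rotation_angles_azi=None, rotation_angles_ele=None,
        camera_angles_azi=torch.tensor([0., 90., 180., 270.]),
        camera_angles_ele=torch.zeros(4),
        weights=[1., 1., 1., 1.],                        # per-view projection weights, illustrative
        init_type="ball", sphere_r=0.5,
        projection_type="orthographic",
    )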
models/ISOMER/projection_func.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ import os
6
+ from .scripts.proj_commands import projection as isomer_projection
7
+ from .data.utils import simple_remove_bkg_normal
8
+
9
+ # mesh_address,
10
+ def projection(
11
+ meshes,
12
+ masks,
13
+ images,
14
+ azimuths,
15
+ elevations,
16
+ weights,
17
+ fov,
18
+ radius,
19
+ save_dir,
20
+ save_glb_addr=None,
21
+ remove_background=False,
22
+ auto_center=False,
23
+ projection_type="perspective",
24
+ below_confidence_strategy="smooth",
25
+ complete_unseen=True,
26
+ mesh_scale_factor=1.0,
27
+ rm_bkg_with_rembg=True,
28
+ ):
29
+
30
+ if save_glb_addr is None:
31
+ os.makedirs(save_dir, exist_ok=True)
32
+ save_glb_addr=os.path.join(save_dir, "rgb_projected.glb")
33
+
34
+ bs = len(images)
35
+ assert len(azimuths) == bs, f'len(azimuths) ({len(azimuths)} != batchsize ({bs}))'
36
+ assert len(elevations) == bs, f'len(elevations) ({len(elevations)} != batchsize ({bs}))'
37
+ assert len(weights) == bs, f'len(weights) ({len(weights)} != batchsize ({bs}))'
38
+
39
+ image_rgba = torch.cat([images[:,:,:,:3], masks.unsqueeze(-1)], dim=-1)
40
+
41
+ assert image_rgba.shape[-1] == 4, f'image_rgba.shape is {image_rgba.shape}'
42
+
43
+ img_list = [Image.fromarray((image.cpu()*255).numpy().astype(np.uint8)) for image in image_rgba]
44
+
45
+
46
+ if remove_background:
47
+ if rm_bkg_with_rembg:
48
+ os.environ["OMP_NUM_THREADS"] = '8'
49
+ img_list = simple_remove_bkg_normal(img_list, rm_bkg_with_rembg, return_Image=True)
50
+
51
+ resolution = img_list[0].size[0]
52
+ new_img_list = []
53
+ for i in range(len(img_list)):
54
+ new_img = img_list[i].resize((resolution,resolution))
55
+
56
+ path_dir = os.path.join(save_dir, f'projection_images')
57
+ os.makedirs(path_dir, exist_ok=True)
58
+
59
+ path_ = os.path.join(path_dir, f'ProjectionImg{i}.png')
60
+
61
+ new_img.save(path_)
62
+
63
+ new_img_list.append(new_img)
64
+
65
+ img_list = new_img_list
66
+
67
+ isomer_projection(meshes,
68
+ img_list=img_list,
69
+ weights=weights,
70
+ azimuths=azimuths,
71
+ elevations=elevations,
72
+ projection_type=projection_type,
73
+ auto_center=auto_center,
74
+ resolution=resolution,
75
+ fovy=fov,
76
+ radius=radius,
77
+ scale_factor=mesh_scale_factor,
78
+ save_glb_addr=save_glb_addr,
79
+ scale_verts=True,
80
+ complete_unseen=complete_unseen,
81
+ below_confidence_strategy=below_confidence_strategy
82
+ )
83
+
84
+ return save_glb_addr
85
+
86
+
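Note: a hedged call sketch for projection(); images is expected as a B,H,W,3 float tensor in [0, 1] and masks as B,H,W, matching the RGBA concatenation above (fov, radius and the output directory are illustrative):

    glb_path = projection(meshes, masks=masks, images=images,
                          azimuths=azimuths, elevations=elevations, weights=weights,
                          fov=30, radius=4.0, save_dir="./isomer_out")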
models/ISOMER/reconstruction_func.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ import os
6
+ from .model.inference_pipeline import reconstruction_pipe
7
+
8
+ def reconstruction(
9
+ normal_pils,
10
+ masks,
11
+ weights,
12
+ fov,
13
+ radius,
14
+ camera_angles_azi,
15
+ camera_angles_ele,
16
+ expansion_weight_stage1=0.1,
17
+ init_type="ball",
18
+ init_verts=None,
19
+ init_faces=None,
20
+ init_mesh_from_file="",
21
+ stage1_steps=200,
22
+ stage2_steps=200,
23
+ projection_type="perspective",
24
+ need_normal_rotation=False,
25
+ rotation_angles_azi=None, # only used if need_normal_rotation
26
+ rotation_angles_ele=None, # only used if need_normal_rotation
27
+ normal_rotation_R=None, # only used if need_normal_rotation
28
+ rm_bkg=False,
29
+ rm_bkg_with_rembg=True, # only used if rm_bkg
30
+ start_edge_len_stage1=0.1,
31
+ end_edge_len_stage1=0.02,
32
+ start_edge_len_stage2=0.02,
33
+ end_edge_len_stage2=0.005,
34
+ expansion_weight_stage2=0.0,
35
+ ):
36
+
37
+ if init_type == "file":
38
+ assert ((init_verts is not None and init_faces is not None) or init_mesh_from_file), f'init_mesh_from_file or (init_verts and init_faces) must be provided if init_type=="file"'
39
+
40
+ if not need_normal_rotation:
41
+ rotation_angles_azi = None
42
+ rotation_angles_ele = None
43
+ normal_rotation_R = None
44
+
45
+ bs = len(normal_pils)
46
+
47
+ assert len(camera_angles_azi) == bs, f'len(camera_angles_azi) ({len(camera_angles_azi)} != batchsize ({bs}))'
48
+ assert len(camera_angles_ele) == bs, f'len(camera_angles_ele) ({len(camera_angles_ele)} != batchsize ({bs}))'
49
+
50
+ normal_pils_rgba = torch.cat([normal_pils[:,:,:,:3], masks.unsqueeze(-1)], dim=-1)
51
+
52
+ assert normal_pils_rgba.shape[-1] == 4, f'normal_pils_rgba.shape is {normal_pils_rgba.shape}'
53
+
54
+
55
+ normal_pils = [Image.fromarray((normal_pil.cpu()*255).numpy().astype(np.uint8)) for normal_pil in normal_pils_rgba]
56
+
57
+
58
+ meshes = reconstruction_pipe(
59
+ normal_pils=normal_pils,
60
+ rotation_angles_azi=rotation_angles_azi,
61
+ rotation_angles_ele=rotation_angles_ele,
62
+ weights=weights,
63
+ expansion_weight=expansion_weight_stage1,
64
+ init_type=init_type,
65
+ stage1_steps=stage1_steps,
66
+ stage2_steps=stage2_steps,
67
+ projection_type=projection_type,
68
+ fovy=fov,
69
+ radius=radius,
70
+ camera_angles_azi=camera_angles_azi,
71
+ camera_angles_ele=camera_angles_ele,
72
+ rm_bkg=rm_bkg, rm_bkg_with_rembg=rm_bkg_with_rembg,
73
+ normal_rotation_R=normal_rotation_R,
74
+ init_mesh_from_file=init_mesh_from_file,
75
+ start_edge_len_stage1=start_edge_len_stage1,
76
+ end_edge_len_stage1=end_edge_len_stage1,
77
+ start_edge_len_stage2=start_edge_len_stage2,
78
+ end_edge_len_stage2=end_edge_len_stage2,
79
+ expansion_weight_stage2=expansion_weight_stage2,
80
+ init_verts=init_verts,
81
+ init_faces=init_faces,
82
+
83
+ )
84
+
85
+
86
+ return meshes
87
+
88
+
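Note: a hedged call sketch for the reconstruction wrapper above; normal_pils here is a B,H,W,3 tensor in [0, 1] and masks a B,H,W mask, and with init_type="std" the views must include the 0/90/180-degree front/side/back azimuths asserted in inference_pipeline.py (values illustrative):

    import torch
    meshes = reconstruction(
        normal_pils=normal_maps,                        # B,H,W,3 tensor in [0, 1]
        masks=masks,                                    # B,H,W foreground masks
        weights=[2.0, 0.05, 1.0, 0.05],                 # per-view weights, illustrative
        fov=30, radius=4.0,
        camera_angles_azi=torch.tensor([0., 90., 180., 270.]),
        camera_angles_ele=torch.zeros(4),
        init_type="std",
    )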
models/ISOMER/scripts/__init__.py ADDED
File without changes
models/ISOMER/scripts/all_typing.py ADDED
@@ -0,0 +1,42 @@
1
+ # code from https://github.com/threestudio-project
2
+
3
+ """
4
+ This module contains type annotations for the project, using
5
+ 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
6
+ 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
7
+
8
+ Two kinds of type checking can be used:
9
+ 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
10
+ 2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
11
+ """
12
+
13
+ # Basic types
14
+ from typing import (
15
+ Any,
16
+ Callable,
17
+ Dict,
18
+ Iterable,
19
+ List,
20
+ Literal,
21
+ NamedTuple,
22
+ NewType,
23
+ Optional,
24
+ Sized,
25
+ Tuple,
26
+ Type,
27
+ TypeVar,
28
+ Union,
29
+ )
30
+
31
+ # Tensor dtype
32
+ # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
33
+ from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
34
+
35
+ # Config type
36
+ from omegaconf import DictConfig
37
+
38
+ # PyTorch Tensor type
39
+ from torch import Tensor
40
+
41
+ # Runtime type checking decorator
42
+ from typeguard import typechecked as typechecker
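Note: a small illustration of the annotation style this module enables (the function is hypothetical):

    from torch import Tensor
    from jaxtyping import Float

    def scale_verts(verts: Float[Tensor, "V 3"], s: float) -> Float[Tensor, "V 3"]:
        # shaped-tensor annotation; typeguard can check it at runtime
        return verts * s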
models/ISOMER/scripts/fast_geo.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ from PIL import Image
3
+ from .mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast
4
+ from pytorch3d.structures import Meshes, join_meshes_as_scene
5
+ import numpy as np
6
+
7
+ import torch
8
+ from pytorch3d.structures import Meshes
9
+ from pytorch3d.utils import ico_sphere
10
+
11
+ def create_sphere(radius, device='cuda'):
12
+
13
+ sphere_mesh = ico_sphere(3, device=device) # Increase the subdivision level (e.g., 2) for higher resolution sphere
14
+ sphere_mesh = sphere_mesh.scale_verts(radius)
15
+
16
+ meshes = Meshes(verts=[sphere_mesh.verts_list()[0]], faces=[sphere_mesh.faces_list()[0]])
17
+ return meshes
18
+
19
+
20
+ def create_box(width, length, height, device='cuda'):
21
+ """
22
+ Create a box mesh given the width, length, and height.
23
+
24
+ Args:
25
+ width (float): Width of the box.
26
+ length (float): Length of the box.
27
+ height (float): Height of the box.
28
+ device (str): Device for the tensor operations, default is 'cuda'.
29
+
30
+ Returns:
31
+ Meshes: A PyTorch3D Meshes object representing the box.
32
+ """
33
+ # Define the 8 vertices of the box
34
+ verts = torch.tensor([
35
+ [-width / 2, -length / 2, -height / 2],
36
+ [ width / 2, -length / 2, -height / 2],
37
+ [ width / 2, length / 2, -height / 2],
38
+ [-width / 2, length / 2, -height / 2],
39
+ [-width / 2, -length / 2, height / 2],
40
+ [ width / 2, -length / 2, height / 2],
41
+ [ width / 2, length / 2, height / 2],
42
+ [-width / 2, length / 2, height / 2]
43
+ ], device=device)
44
+
45
+ # Define the 12 triangles (faces) of the box using vertex indices
46
+ faces = torch.tensor([
47
+ [0, 1, 2], [0, 2, 3], # Bottom face
48
+ [4, 5, 6], [4, 6, 7], # Top face
49
+ [0, 1, 5], [0, 5, 4], # Front face
50
+ [1, 2, 6], [1, 6, 5], # Right face
51
+ [2, 3, 7], [2, 7, 6], # Back face
52
+ [3, 0, 4], [3, 4, 7] # Left face
53
+ ], device=device)
54
+
55
+ # Create the Meshes object
56
+ meshes = Meshes(verts=[verts], faces=[faces])
57
+
58
+ return meshes
59
+
60
+
61
+ # stage 0: initial mesh estimation
62
+ def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std", return_depth_and_sep_mesh=False):
63
+
64
+ import time
65
+ assert front_normal.mode != "RGB"
66
+ assert back_normal.mode != "RGB"
67
+ assert side_normal.mode != "RGB"
68
+
69
+ front_normal = front_normal.resize((192, 192))
70
+ back_normal = back_normal.resize((192, 192))
71
+ side_normal = side_normal.resize((192, 192))
72
+
73
+ # build mesh with front back projection # ~3s
74
+ side_w_over_h = calc_w_over_h(side_normal)
75
+ mesh_front, depth_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type, return_depth=True)
76
+ mesh_back, depth_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type, return_depth=True)
77
+ meshes = join_meshes_as_scene([mesh_front, mesh_back])
78
+
79
+ # poisson reconstruction which guarantees a smooth connection between meshes
80
+ # and simplify down to 2000 faces
81
+ meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000)
82
+
83
+
84
+ if return_depth_and_sep_mesh:
85
+ return meshes, depth_front, depth_back, mesh_front, mesh_back
86
+ return meshes
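Note: quick usage of the primitive initialisers above (the radius and box extents are illustrative):

    sphere = create_sphere(0.5)                           # ico-sphere scaled to radius 0.5
    box = create_box(width=1.0, length=1.0, height=1.5)   # 8 vertices, 12 triangles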
models/ISOMER/scripts/load_onnx.py ADDED
@@ -0,0 +1,48 @@
1
+ import onnxruntime
2
+ import torch
3
+
4
+ providers = [
5
+ ('TensorrtExecutionProvider', {
6
+ 'device_id': 0,
7
+ 'trt_max_workspace_size': 8 * 1024 * 1024 * 1024,
8
+ 'trt_fp16_enable': True,
9
+ 'trt_engine_cache_enable': True,
10
+ }),
11
+ ('CUDAExecutionProvider', {
12
+ 'device_id': 0,
13
+ 'arena_extend_strategy': 'kSameAsRequested',
14
+ 'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
15
+ 'cudnn_conv_algo_search': 'HEURISTIC',
16
+ })
17
+ ]
18
+
19
+ def load_onnx(file_path: str):
20
+ assert file_path.endswith(".onnx")
21
+ sess_opt = onnxruntime.SessionOptions()
22
+ ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers)
23
+ return ort_session
24
+
25
+
26
+ def load_onnx_caller(file_path: str, single_output=False):
27
+ ort_session = load_onnx(file_path)
28
+ def caller(*args):
29
+ torch_input = isinstance(args[0], torch.Tensor)
30
+ if torch_input:
31
+ torch_input_dtype = args[0].dtype
32
+ torch_input_device = args[0].device
33
+ # check all are torch.Tensor and have same dtype and device
34
+ assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor"
35
+ assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor"
36
+ assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor"
37
+ args = [arg.cpu().float().numpy() for arg in args]
38
+
39
+ ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))}
40
+ ort_outs = ort_session.run(None, ort_inputs)
41
+
42
+ if torch_input:
43
+ ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs]
44
+
45
+ if single_output:
46
+ return ort_outs[0]
47
+ return ort_outs
48
+ return caller
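Note: a hedged sketch of the ONNX wrapper above; the model path is hypothetical, and torch inputs are converted to numpy then returned with the original dtype and device:

    import torch
    run_model = load_onnx_caller("checkpoints/example.onnx", single_output=True)  # hypothetical path
    out = run_model(torch.rand(1, 3, 256, 256, device="cuda"))                    # comes back as a CUDA tensor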
models/ISOMER/scripts/mesh_init.py ADDED
@@ -0,0 +1,142 @@
1
+ from PIL import Image
2
+ import torch
3
+ import numpy as np
4
+ from pytorch3d.structures import Meshes
5
+ from pytorch3d.renderer import TexturesVertex
6
+ from .utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
7
+ import pymeshlab
8
+
9
+ _MAX_THREAD = 8
10
+
11
+ # rgb and depth to mesh
12
+ def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
13
+ pixel_center = 0.5 if use_pixel_centers else 0
14
+ i, j = np.meshgrid(
15
+ np.arange(W, dtype=np.float32) + pixel_center,
16
+ np.arange(H, dtype=np.float32) + pixel_center,
17
+ indexing='xy'
18
+ )
19
+ i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)
20
+
21
+ origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3
22
+ directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3
23
+
24
+ return origins, directions
25
+
26
+ def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
27
+ if valid_HWC is None:
28
+ valid_HWC = torch.ones_like(pred_HWC).bool()
29
+ H, W = rgb_BCHW.shape[-2:]
30
+ rgb_BCHW = rgb_BCHW.flip(-2)
31
+ pred_HWC = pred_HWC.flip(0)
32
+ valid_HWC = valid_HWC.flip(0)
33
+ rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
34
+ verts = rays_o + rays_d * pred_HWC # [H, W, 3]
35
+ verts = verts.reshape(-1, 3) # [V, 3]
36
+ indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
37
+ faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
38
+ # faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1]
39
+ faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
40
+ faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
41
+ # faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:]
42
+ faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
43
+ faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3),
44
+ faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)],
45
+ dim=0) # (F, 3)
46
+ colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3)
47
+ if is_back:
48
+ verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)
49
+
50
+ used_verts = faces.unique()
51
+ old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
52
+ old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
53
+ new_faces = old_to_new_mapping[faces]
54
+ mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
55
+ return mesh
56
+
57
+ def normalmap_to_depthmap(normal_np):
58
+ from .normal_to_height_map import estimate_height_map
59
+ height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
60
+ return height
61
+
62
+ def transform_back_normal_to_front(normal_pil):
63
+ arr = np.array(normal_pil) # in [0, 255]
64
+ arr[..., 0] = 255-arr[..., 0]
65
+ arr[..., 2] = 255-arr[..., 2]
66
+ return Image.fromarray(arr.astype(np.uint8))
67
+
68
+ def calc_w_over_h(normal_pil):
69
+ if isinstance(normal_pil, Image.Image):
70
+ arr = np.array(normal_pil)
71
+ else:
72
+ assert isinstance(normal_pil, np.ndarray)
73
+ arr = normal_pil
74
+ if arr.shape[-1] == 4:
75
+ alpha = arr[..., -1] / 255.
76
+ alpha[alpha >= 0.5] = 1
77
+ alpha[alpha < 0.5] = 0
78
+ else:
79
+ alpha = ~(arr.min(axis=-1) >= 250)
80
+ h_min, w_min = np.min(np.where(alpha), axis=1)
81
+ h_max, w_max = np.max(np.where(alpha), axis=1)
82
+ return (w_max - w_min) / (h_max - h_min)
83
+
84
+ def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0, return_depth=False):
85
+ if is_back:
86
+ normal_pil = transform_back_normal_to_front(normal_pil)
87
+ normal_img = np.array(normal_pil)
88
+ rgb_img = np.array(rgb_pil)
89
+ if normal_img.shape[-1] == 4:
90
+ valid_HWC = normal_img[..., [3]] / 255
91
+ elif rgb_img.shape[-1] == 4:
92
+ valid_HWC = rgb_img[..., [3]] / 255
93
+ else:
94
+ raise ValueError("invalid input, either normal or rgb should have alpha channel")
95
+
96
+ # object area pixels height
97
+ real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0])
98
+
99
+ heights = normalmap_to_depthmap(normal_img)
100
+ rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None]
101
+ valid_HWC[valid_HWC < 0.5] = 0
102
+ valid_HWC[valid_HWC >= 0.5] = 1
103
+ valid_HWC = torch.from_numpy(valid_HWC).bool()
104
+
105
+ if init_type == "std":
106
+ # accurate but not stable
107
+ pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
108
+ elif init_type == "thin":
109
+ heights = heights - heights.min()
110
+ heights = (heights / heights.max() * 0.2)
111
+ pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
112
+ else:
113
+ # stable but not accurate
114
+ heights = heights - heights.min()
115
+ heights = (heights / heights.max() * (1-offset)) + offset # to [offset, 1]
116
+ pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
117
+
118
+ # set the border pixels to 0 height
119
+ import cv2
120
+ # edge filter
121
+ edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
122
+ edge = torch.from_numpy(edge).bool()[..., None]
123
+ pred_HWC[edge] = 0
124
+
125
+ valid_HWC[pred_HWC < clamp_min] = False
126
+ rt_mesh = depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)
127
+
128
+ if return_depth:
129
+ return rt_mesh, pred_HWC
130
+ return rt_mesh
131
+
132
+ # poisson reconstruction which guarantees a smooth connection between meshes
133
+ # and optionally simplify down to the requested target face count
134
+ def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
135
+ ms = pymeshlab.MeshSet()
136
+ ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
137
+ if simplification > 0:
138
+ ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
139
+ ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True)
140
+ if simplification > 0:
141
+ ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
142
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
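Note: a hedged sketch of the per-view mesh initialisation above; fast_geo.py resizes its inputs to 192x192 before calling this, and the file name and scale here are illustrative:

    from PIL import Image
    front = Image.open("front_normal.png")   # RGBA normal map, alpha = foreground mask (hypothetical file)
    mesh_front, depth_front = build_mesh(front, front, clamp_min=0., scale=0.8,
                                         init_type="std", return_depth=True)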
models/ISOMER/scripts/normal_to_height_map.py ADDED
@@ -0,0 +1,205 @@
1
+ # code modified from https://github.com/YertleTurtleGit/depth-from-normals
2
+ import numpy as np
3
+ import cv2 as cv
4
+ from multiprocessing.pool import ThreadPool as Pool
5
+ from multiprocessing import cpu_count
6
+ from typing import Tuple, List, Union
7
+ import numba
8
+
9
+
10
+ def calculate_gradients(
11
+ normals: np.ndarray, mask: np.ndarray
12
+ ) -> Tuple[np.ndarray, np.ndarray]:
13
+ horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1))
14
+ left_gradients = np.zeros(normals.shape[:2])
15
+ left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign(
16
+ horizontal_angle_map[mask != 0] - np.pi / 2
17
+ )
18
+
19
+ vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1))
20
+ top_gradients = np.zeros(normals.shape[:2])
21
+ top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign(
22
+ vertical_angle_map[mask != 0] - np.pi / 2
23
+ )
24
+
25
+ return left_gradients, top_gradients
26
+
27
+
28
+ @numba.jit(nopython=True)
29
+ def integrate_gradient_field(
30
+ gradient_field: np.ndarray, axis: int, mask: np.ndarray
31
+ ) -> np.ndarray:
32
+ heights = np.zeros(gradient_field.shape)
33
+
34
+ for d1 in numba.prange(heights.shape[1 - axis]): # numba.prange: executes the loop in parallel
35
+ sum_value = 0
36
+ for d2 in range(heights.shape[axis]):
37
+ coordinates = (d1, d2) if axis == 1 else (d2, d1)
38
+
39
+ if mask[coordinates] != 0:
40
+ sum_value = sum_value + gradient_field[coordinates] # equation 1 in paper along `axis` axis
41
+ heights[coordinates] = sum_value
42
+ else:
43
+ sum_value = 0
44
+
45
+ return heights
46
+
47
+ # equation 1 in paper wrt these directions
48
+ def calculate_heights(
49
+ left_gradients: np.ndarray, top_gradients, mask: np.ndarray
50
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
51
+ left_heights = integrate_gradient_field(left_gradients, 1, mask)
52
+ right_heights = np.fliplr(
53
+ integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask))
54
+ )
55
+ top_heights = integrate_gradient_field(top_gradients, 0, mask)
56
+ bottom_heights = np.flipud(
57
+ integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask))
58
+ )
59
+ return left_heights, right_heights, top_heights, bottom_heights
60
+
61
+
62
+ def combine_heights(*heights: np.ndarray) -> np.ndarray:
63
+ return np.mean(np.stack(heights, axis=0), axis=0)
64
+
65
+
66
+ def rotate(matrix: np.ndarray, angle: float) -> np.ndarray:
67
+ h, w = matrix.shape[:2]
68
+ center = (w / 2, h / 2)
69
+
70
+ rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0)
71
+ corners = cv.transform(
72
+ np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix
73
+ )[0]
74
+
75
+ _, _, w, h = cv.boundingRect(corners)
76
+
77
+ rotation_matrix[0, 2] += w / 2 - center[0]
78
+ rotation_matrix[1, 2] += h / 2 - center[1]
79
+ result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR)
80
+
81
+ return result
82
+
83
+
84
+ def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray:
85
+ angle = np.radians(angle)
86
+ cos_angle = np.cos(angle)
87
+ sin_angle = np.sin(angle)
88
+
89
+ rotated_normals = np.empty_like(normals)
90
+ rotated_normals[:, :, 0] = (
91
+ normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle
92
+ )
93
+ rotated_normals[:, :, 1] = (
94
+ normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle
95
+ )
96
+
97
+ return rotated_normals
98
+
99
+
100
+ def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray:
101
+ return image[
102
+ (image.shape[0] - target_resolution[0])
103
+ // 2 : (image.shape[0] - target_resolution[0])
104
+ // 2
105
+ + target_resolution[0],
106
+ (image.shape[1] - target_resolution[1])
107
+ // 2 : (image.shape[1] - target_resolution[1])
108
+ // 2
109
+ + target_resolution[1],
110
+ ]
111
+
112
+
113
+ def integrate_vector_field(
114
+ vector_field: np.ndarray,
115
+ mask: np.ndarray,
116
+ target_iteration_count: int,
117
+ thread_count: int,
118
+ ) -> np.ndarray:
119
+ shape = vector_field.shape[:2]
120
+ angles = np.linspace(0, 90, target_iteration_count, endpoint=False)
121
+
122
+ def integrate_vector_field_angles(angles: List[float]) -> np.ndarray:
123
+ all_combined_heights = np.zeros(shape)
124
+
125
+ for angle in angles:
126
+ rotated_vector_field = rotate_vector_field_normals(
127
+ rotate(vector_field, angle), angle
128
+ ) # rotate twice: first rotate the whole in image level, then rotate the individual normal vectors
129
+
130
+ rotated_mask = rotate(mask, angle)
131
+
132
+ left_gradients, top_gradients = calculate_gradients(
133
+ rotated_vector_field, rotated_mask
134
+ )
135
+ (
136
+ left_heights,
137
+ right_heights,
138
+ top_heights,
139
+ bottom_heights,
140
+ ) = calculate_heights(left_gradients, top_gradients, rotated_mask)
141
+
142
+ combined_heights = combine_heights(
143
+ left_heights, right_heights, top_heights, bottom_heights
144
+ ) # = mean of these heights
145
+ combined_heights = centered_crop(rotate(combined_heights, -angle), shape)
146
+ all_combined_heights += combined_heights / len(angles)
147
+
148
+ return all_combined_heights
149
+
150
+ with Pool(processes=thread_count) as pool:
151
+ heights = pool.map(
152
+ integrate_vector_field_angles,
153
+ np.array(
154
+ np.array_split(angles, thread_count),
155
+ dtype=object,
156
+ ),
157
+ )
158
+ pool.close()
159
+ pool.join()
160
+
161
+ isotropic_height = np.zeros(shape)
162
+ for height in heights:
163
+ isotropic_height += height / thread_count
164
+
165
+ return isotropic_height
166
+
167
+
168
+ def estimate_height_map(
169
+ normal_map: np.ndarray,
170
+ mask: Union[np.ndarray, None] = None,
171
+ height_divisor: float = 1,
172
+ target_iteration_count: int = 250,
173
+ thread_count: int = cpu_count(),
174
+ raw_values: bool = False,
175
+ ) -> np.ndarray:
176
+ if mask is None:
177
+ if normal_map.shape[-1] == 4:
178
+ mask = normal_map[:, :, 3] / 255
179
+ mask[mask < 0.5] = 0
180
+ mask[mask >= 0.5] = 1
181
+ else:
182
+ mask = np.ones(normal_map.shape[:2], dtype=np.uint8)
183
+
184
+ normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2
185
+ heights = integrate_vector_field(
186
+ normals, mask, target_iteration_count, thread_count
187
+ ) # equation 1 in paper, repeat `target_iteration_count` (8?) times with rotation in angle np.linspace(0, 90, target_iteration_count), then find mean
188
+ # target_iteration_count=8 ? defined _MAX_THREAD = 8 in mesh_init.py
189
+
190
+ if raw_values:
191
+ return heights
192
+
193
+ heights /= height_divisor
194
+ heights[mask > 0] += 1 / 2
195
+ heights[mask == 0] = 1 / 2
196
+
197
+ heights *= 2**16 - 1
198
+
199
+ if np.min(heights) < 0 or np.max(heights) > 2**16 - 1:
200
+ raise OverflowError("Height values are clipping.")
201
+
202
+ heights = np.clip(heights, 0, 2**16 - 1)
203
+ heights = heights.astype(np.uint16)
204
+
205
+ return heights
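A minimal usage sketch of `estimate_height_map` (the file name, iteration count and thread count are assumptions; any RGBA normal map stored as a uint8 array works the same way):

    import numpy as np
    from PIL import Image
    normal_rgba = np.array(Image.open("front_normal.png"))  # hypothetical RGBA normal map, HxWx4 uint8
    height_u16 = estimate_height_map(normal_rgba, target_iteration_count=8, thread_count=4)
    Image.fromarray(height_u16).save("front_height.png")    # 16-bit height map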
models/ISOMER/scripts/proj_commands.py ADDED
@@ -0,0 +1,69 @@
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from pytorch3d.renderer import (
5
+ TexturesVertex,
6
+ )
7
+ from .project_mesh import (
8
+ get_cameras_list_azi_ele,
9
+ multiview_color_projection
10
+
11
+ )
12
+ from .utils import save_py3dmesh_with_trimesh_fast
13
+
14
+ def projection(meshes,
15
+ img_list,
16
+ weights,
17
+ azimuths,
18
+ elevations,
19
+ projection_type='orthographic',
20
+ auto_center=True,
21
+ resolution=1024,
22
+ fovy=None,
23
+ radius=None,
24
+ ortho_dist=1.1,
25
+ scale_factor=1.0,
26
+ save_glb_addr=None,
27
+ scale_verts=True,
28
+ complete_unseen=True,
29
+ below_confidence_strategy="smooth"
30
+ ):
31
+
32
+ assert len(img_list) == len(azimuths) == len(elevations) == len(weights), f"len(img_list) ({len(img_list)}) != len(azimuths) ({len(azimuths)}) != len(elevations) ({len(elevations)}) != len(weights) ({len(weights)})"
33
+
34
+ projection_types = ['perspective', 'orthographic']
35
+ assert projection_type in projection_types, f"projection_type ({projection_type}) should be one of {projection_types}"
36
+
37
+ if auto_center:
38
+ verts = meshes.verts_packed()
39
+ max_bb = (verts - 0).max(0)[0]
40
+ min_bb = (verts - 0).min(0)[0]
41
+ scale = (max_bb - min_bb).max() / 2
42
+ center = (max_bb + min_bb) / 2
43
+ meshes.offset_verts_(-center)
44
+ if scale_verts:
45
+ meshes.scale_verts_((scale_factor / float(scale)))
46
+ elif scale_verts:
47
+ meshes.scale_verts_((scale_factor))
48
+
49
+ if projection_type == 'perspective':
50
+ assert fovy is not None and radius is not None, f"fovy ({fovy}) and radius ({radius}) should not be None when projection_type is 'perspective'"
51
+ cameras = get_cameras_list_azi_ele(azimuths, elevations, fov_in_degrees=fovy,device="cuda", dist=radius, cam_type='fov')
52
+ elif projection_type == 'orthographic':
53
+ cameras = get_cameras_list_azi_ele(azimuths, elevations, fov_in_degrees=fovy, device="cuda", focal=2/1.35, dist=ortho_dist, cam_type='orthographic')
54
+
55
+
56
+ num_meshes = len(meshes)
57
+ num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes
58
+ black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda")
59
+ textures = TexturesVertex(verts_features=black_texture)
60
+ meshes.textures = textures
61
+
62
+
63
+ proj_mesh = multiview_color_projection(meshes, img_list, cameras, weights=weights, eps=0.05, resolution=resolution, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=complete_unseen, below_confidence_strategy=below_confidence_strategy)
64
+
65
+
66
+ if save_glb_addr is not None:
67
+ save_py3dmesh_with_trimesh_fast(proj_mesh, save_glb_addr)
68
+
69
+
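A minimal usage sketch of `projection` (the mesh, image files, weights and view angles are assumptions; in the full pipeline they come from the multi-view generation stage):

    from PIL import Image
    # `recon_mesh`: hypothetical pytorch3d Meshes from the reconstruction stage
    views = [Image.open(f"rgb_{i}.png") for i in range(4)]   # hypothetical RGBA view images
    projection(recon_mesh,
               img_list=views,
               weights=[2.0, 0.5, 1.0, 0.5],
               azimuths=[0, 90, 180, 270],
               elevations=[0, 0, 0, 0],
               projection_type='orthographic',
               save_glb_addr="proj_mesh.glb")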
models/ISOMER/scripts/project_mesh.py ADDED
@@ -0,0 +1,401 @@
1
+ from typing import List
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase
6
+ from pytorch3d.io import load_objs_as_meshes
7
+ from pytorch3d.renderer.mesh.rasterizer import Fragments
8
+ from pytorch3d.structures import Meshes
9
+ from pytorch3d.renderer import (
10
+ RasterizationSettings,
11
+ TexturesVertex,
12
+ FoVPerspectiveCameras,
13
+ FoVOrthographicCameras,
14
+ )
15
+ from pytorch3d.renderer import MeshRasterizer
16
+
17
+ def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
18
+ # pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
19
+ R = world_to_cam[:3, :3].t()[None, ...]
20
+ T = world_to_cam[:3, 3][None, ...]
21
+ if cam_type == 'fov':
22
+ assert fov_in_degrees is not None, "fov_in_degrees should not be None when cam_type is fov"
23
+ camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
24
+ else:
25
+ focal_length = 1 / focal_length
26
+ camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
27
+ return camera
28
+
29
+ def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
30
+ """
31
+ Renders pix2face of visible faces.
32
+
33
+ :param mesh: Pytorch3d.structures.Meshes
34
+ :param cameras: pytorch3d.renderer.Cameras
35
+ :param H: target image height
36
+ :param W: target image width
37
+ :param blur_radius: Float distance in the range [0, 2] used to expand the face
38
+ bounding boxes for rasterization. Setting blur radius
39
+ results in blurred edges around the shape instead of a
40
+ hard boundary. Set to 0 for no blur.
41
+ :param faces_per_pixel: (int) Number of faces to keep track of per pixel.
42
+ We return the nearest faces_per_pixel faces along the z-axis.
43
+ """
44
+ # Define the settings for rasterization and shading
45
+ raster_settings = RasterizationSettings(
46
+ image_size=(H, W),
47
+ blur_radius=blur_radius,
48
+ faces_per_pixel=faces_per_pixel
49
+ )
50
+ rasterizer=MeshRasterizer(
51
+ cameras=cameras,
52
+ raster_settings=raster_settings
53
+ )
54
+ fragments: Fragments = rasterizer(meshes, cameras=cameras)
55
+ return {
56
+ "pix_to_face": fragments.pix_to_face[..., 0],
57
+ }
58
+
59
+ import nvdiffrast.torch as dr
60
+
61
+ def _warmup(glctx, device=None):
62
+ device = 'cuda' if device is None else device
63
+ #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
64
+ def tensor(*args, **kwargs):
65
+ return torch.tensor(*args, device=device, **kwargs)
66
+ pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
67
+ tri = tensor([[0, 1, 2]], dtype=torch.int32)
68
+ dr.rasterize(glctx, pos, tri, resolution=[256, 256])
69
+
70
+ class Pix2FacesRenderer:
71
+ def __init__(self, device="cuda"):
72
+ # self._glctx = dr.RasterizeGLContext(output_db=False, device=device)
73
+ self._glctx = dr.RasterizeCudaContext(device=device)
74
+ self.device = device
75
+ _warmup(self._glctx, device)
76
+
77
+ def transform_vertices(self, meshes: Meshes, cameras: CamerasBase):
78
+ vertices = cameras.transform_points_ndc(meshes.verts_padded())
79
+
80
+ perspective_correct = cameras.is_perspective()
81
+ znear = cameras.get_znear()
82
+ if isinstance(znear, torch.Tensor):
83
+ znear = znear.min().item()
84
+ z_clip = None #if not perspective_correct or znear is None else znear / 2
85
+
86
+ if z_clip:
87
+ vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip
88
+ vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices)
89
+ vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32)
90
+ return vertices
91
+
92
+ def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512):
93
+ meshes = meshes.to(self.device)
94
+ cameras = cameras.to(self.device)
95
+ vertices = self.transform_vertices(meshes, cameras)
96
+ faces = meshes.faces_packed().to(torch.int32)
97
+ rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4
98
+ pix_to_face = rast_out[..., -1].to(torch.int32) - 1
99
+ return pix_to_face
100
+
101
+ pix2faces_renderer = Pix2FacesRenderer()
102
+
103
+ def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
104
+ # pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
105
+ pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
106
+
107
+ unique_faces = torch.unique(pix_to_face.flatten())
108
+ unique_faces = unique_faces[unique_faces != -1]
109
+ return unique_faces
110
+
111
+ def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict:
112
+ """
113
+ Projects color from a given image onto a 3D mesh.
114
+
115
+ Args:
116
+ meshes (pytorch3d.structures.Meshes): The 3D mesh object.
117
+ cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object.
118
+ pil_image (PIL.Image.Image): The input image.
119
+ use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
120
+ eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
121
+ resolution (int, optional): The resolution of the projection. Defaults to 1024.
122
+ device (str, optional): The device to use for computation. Defaults to "cuda".
123
+ debug (bool, optional): Whether to save debug images. Defaults to False.
124
+
125
+ Returns:
126
+ dict: A dictionary containing the following keys:
127
+ - "new_texture" (TexturesVertex): The updated texture with interpolated colors.
128
+ - "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected.
129
+ - "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices.
130
+ """
131
+ meshes = meshes.to(device)
132
+ cameras = cameras.to(device)
133
+ image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.]
134
+ unique_faces = get_visible_faces(meshes, cameras, resolution=resolution)
135
+
136
+ # visible faces
137
+ faces_normals = meshes.faces_normals_packed()[unique_faces]
138
+ faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True)
139
+ world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0]
140
+ view_direction = world_points[1] - world_points[0]
141
+ view_direction = view_direction / view_direction.norm(dim=0, keepdim=True)
142
+
143
+
144
+ # find invalid faces
145
+ cos_angles = (faces_normals * view_direction).sum(dim=1)
146
+ # assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}"
147
+ selected_faces = unique_faces[cos_angles < -eps]
148
+
149
+ # find verts
150
+ faces = meshes.faces_packed()[selected_faces] # [N, 3]
151
+ verts = torch.unique(faces.flatten()) # [N, 1]
152
+ verts_coordinates = meshes.verts_packed()[verts] # [N, 3]
153
+
154
+ # compute color
155
+ pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points
156
+ valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1<pt_tensor)).any(dim=1)) # checked, correct
157
+ valid_pt = pt_tensor[valid, :]
158
+ valid_idx = verts[valid]
159
+ valid_color = torch.nn.functional.grid_sample(image[None].flip((-1, -2)), valid_pt[None, :, None, :], align_corners=False, padding_mode="reflection", mode="bilinear")[0, :, :, 0].T.clamp(0, 1) # [N, 4], note that bicubic may give invalid value
160
+ alpha, valid_color = valid_color[:, 3:], valid_color[:, :3]
161
+ if not use_alpha:
162
+ alpha = torch.ones_like(alpha)
163
+
164
+ # modify color
165
+ old_colors = meshes.textures.verts_features_packed()
166
+ old_colors[valid_idx] = valid_color * alpha + old_colors[valid_idx] * (1 - alpha)
167
+ new_texture = TexturesVertex(verts_features=[old_colors])
168
+
169
+ valid_verts_normals = meshes.verts_normals_packed()[valid_idx]
170
+ valid_verts_normals = valid_verts_normals / valid_verts_normals.norm(dim=1, keepdim=True).clamp_min(0.001)
171
+ cos_angles = (valid_verts_normals * view_direction).sum(dim=1)
172
+ return {
173
+ "new_texture": new_texture,
174
+ "valid_verts": valid_idx,
175
+ "valid_colors": valid_color,
176
+ "valid_alpha": alpha,
177
+ "cos_angles": cos_angles,
178
+ }
179
+
180
+ def complete_unseen_vertex_color(meshes: Meshes, valid_index: torch.Tensor) -> dict:
181
+ """
182
+ meshes: the mesh with vertex color to be completed.
183
+ valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1]
184
+ """
185
+ valid_index = valid_index.to(meshes.device)
186
+ colors = meshes.textures.verts_features_packed() # [V, 3]
187
+ V = colors.shape[0]
188
+
189
+ invalid_index = torch.ones_like(colors[:, 0]).bool() # [V]
190
+ invalid_index[valid_index] = False
191
+ invalid_index = torch.arange(V).to(meshes.device)[invalid_index]
192
+
193
+ L = meshes.laplacian_packed() # connectivity
194
+ E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device)
195
+ L = L + E
196
+ # E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device)
197
+ # L = L + E
198
+ colored_count = torch.ones_like(colors[:, 0]) # [V]
199
+ colored_count[invalid_index] = 0
200
+ L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V]
201
+
202
+ total_colored = colored_count.sum()
203
+ coloring_round = 0
204
+ stage = "uncolored"
205
+ from tqdm import tqdm
206
+ pbar = tqdm(miniters=100)
207
+ while stage == "uncolored" or coloring_round > 0:
208
+ new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3]
209
+ new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1]
210
+ colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index])
211
+ colored_count[invalid_index] = (new_count[:, 0] > 0).float()
212
+
213
+ new_total_colored = colored_count.sum()
214
+ if new_total_colored > total_colored:
215
+ total_colored = new_total_colored
216
+ coloring_round += 1
217
+ else:
218
+ stage = "colored"
219
+ coloring_round -= 1
220
+ pbar.update(1)
221
+ if coloring_round > 10000:
222
+ print("coloring_round > 10000, break")
223
+ break
224
+ assert not torch.isnan(colors).any()
225
+ meshes.textures = TexturesVertex(verts_features=[colors])
226
+ return meshes
227
+
228
+ def load_glb_mesh(glb_path, device="cuda"):
229
+ meshes = load_objs_as_meshes([glb_path], device=device)
230
+ return meshes
231
+
232
+ def get_separated_images_from_img_grid(img_grid_path, image_num):
233
+ img_list = []
234
+ grid = Image.open(img_grid_path)
235
+ w, h = grid.size
236
+ for i in range(0, image_num):
237
+ img_list.append(grid.crop((i*h, 0, i*h + h, h)))
238
+ return img_list
239
+
240
+ def get_fov_camera_(azimuth, elevation, fovy, radius, mesh, auto_center, scale_factor, device='cuda'):
241
+ if auto_center:
242
+ verts = mesh.verts_packed()
243
+ max_bb = (verts - 0).max(0)[0]
244
+ min_bb = (verts - 0).min(0)[0]
245
+ scale = (max_bb - min_bb).max() / 2
246
+ center = (max_bb + min_bb) / 2
247
+ mesh.offset_verts_(-center)
248
+ mesh.scale_verts_((scale_factor / float(scale)))
249
+ else:
250
+ mesh.scale_verts_((scale_factor))
251
+ R, T = look_at_view_transform(radius, azimuth, elevation, device=device)
252
+ cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=fovy)
253
+ return cameras
254
+
255
+ def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase], weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth") -> Meshes:
256
+ """
257
+ Projects colors from a list of view images onto a 3D mesh.
258
+
259
+ Args:
260
+ meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh.
261
+ image_list (PIL.Image.Image): List of images.
262
+ cameras_list (list): List of cameras.
263
+ weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None.
264
+ eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
265
+ resolution (int, optional): The resolution of the projection. Defaults to 1024.
266
+ device (str, optional): The device to use for computation. Defaults to "cuda".
267
+ reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None.
268
+ use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
269
+ confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1.
270
+ complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False.
271
+
272
+ Returns:
273
+ Meshes: the colored mesh
274
+ """
275
+
276
+ if image_list is None:
277
+ raise ValueError("image_list is None")
278
+
279
+
280
+ meshes = meshes.clone().to(device)
281
+ if weights is None:
282
+ weights = [1. for _ in range(len(cameras_list))]
283
+
284
+ assert len(cameras_list) == len(image_list) == len(weights), f'the following three lengths should be equal: len(cameras_list)({len(cameras_list)}), len(image_list)({len(image_list)}), len(weights)({len(weights)})'
285
+
286
+ original_color = meshes.textures.verts_features_packed()
287
+ assert not torch.isnan(original_color).any()
288
+ texture_counts = torch.zeros_like(original_color[..., :1])
289
+ texture_values = torch.zeros_like(original_color)
290
+ max_texture_counts = torch.zeros_like(original_color[..., :1])
291
+ max_texture_values = torch.zeros_like(original_color)
292
+ for camera, image, weight in zip(cameras_list, image_list, weights):
293
+ ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha)
294
+ if reweight_with_cosangle == "linear":
295
+ weight = (ret['cos_angles'].abs() * weight)[:, None]
296
+ elif reweight_with_cosangle == "square":
297
+ weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None]
298
+ if use_alpha:
299
+ weight = weight * ret['valid_alpha']
300
+
301
+ try:
302
+ assert weight.min() > -0.0001, f'weight.min() is {weight.min()}, but should be > -0.0001'
303
+ except Exception as e:
304
+ raise e
305
+
306
+ texture_counts[ret['valid_verts']] += weight
307
+ texture_values[ret['valid_verts']] += ret['valid_colors'] * weight
308
+ max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']])
309
+ max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight)
310
+
311
+ texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values)
312
+ if below_confidence_strategy == "smooth":
313
+ texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values)
314
+ elif below_confidence_strategy == "original":
315
+ texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values)
316
+ else:
317
+ raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported")
318
+ assert not torch.isnan(texture_values).any()
319
+ meshes.textures = TexturesVertex(verts_features=[texture_values])
320
+
321
+ if complete_unseen:
322
+ meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold])
323
+ ret_mesh = meshes.detach()
324
+ del meshes
325
+ return ret_mesh
326
+
327
+ def get_cameras_list(azim_list, device, elevation, fov_in_degrees=None, focal=2/1.35, dist=1.1, cam_type='orthographic'):
328
+ ret = []
329
+ for azim in azim_list:
330
+ R, T = look_at_view_transform(dist, elevation, azim)
331
+ w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
332
+ cameras = get_camera(w2c, fov_in_degrees=fov_in_degrees, focal_length=focal, cam_type=cam_type).to(device)
333
+ ret.append(cameras)
334
+ return ret
335
+
336
+ def get_cameras_list_azi_ele(azim_list, elev_list, device, fov_in_degrees=None, focal=2/1.35, dist=1.1, cam_type='orthographic'):
337
+ ret = []
338
+ for i in range(len(azim_list)):
339
+ R, T = look_at_view_transform(dist, elev_list[i], azim_list[i])
340
+ w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
341
+ cameras = get_camera(w2c, fov_in_degrees=fov_in_degrees, focal_length=focal, cam_type=cam_type).to(device)
342
+ ret.append(cameras)
343
+ return ret
344
+
345
+ def get_8view_cameras(device, focal=2/1.35):
346
+ return get_cameras_list(azim_list = [180, 225, 270, 315, 0, 45, 90, 135], elevation=0, device=device, focal=focal)
347
+
348
+ def get_6view_cameras(device, focal=2/1.35):
349
+ return get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], elevation=0, device=device, focal=focal)
350
+
351
+ def get_4view_cameras(device, focal=2/1.35):
352
+ return get_cameras_list(azim_list = [180, 270, 0, 90], elevation=0, device=device, focal=focal)
353
+
354
+ def get_2view_cameras(device, focal=2/1.35):
355
+ return get_cameras_list(azim_list = [180, 0], elevation=0, device=device, focal=focal)
356
+
357
+ def get_multiple_view_cameras(device, focal=2/1.35, offset=180, num_views=8, dist=1.1):
358
+ return get_cameras_list(azim_list = (np.linspace(0, 360, num_views+1)[:-1] + offset) % 360, elevation=0, device=device, focal=focal, dist=dist)
359
+
360
+ def align_with_alpha_bbox(source_img, target_img, final_size=1024):
361
+ # align source_img with target_img using alpha channel
362
+ # source_img and target_img are PIL.Image.Image
363
+ source_img = source_img.convert("RGBA")
364
+ target_img = target_img.convert("RGBA").resize((final_size, final_size))
365
+ source_np = np.array(source_img)
366
+ target_np = np.array(target_img)
367
+ source_alpha = source_np[:, :, 3]
368
+ target_alpha = target_np[:, :, 3]
369
+ bbox_source_min, bbox_source_max = np.argwhere(source_alpha > 0).min(axis=0), np.argwhere(source_alpha > 0).max(axis=0)
370
+ bbox_target_min, bbox_target_max = np.argwhere(target_alpha > 0).min(axis=0), np.argwhere(target_alpha > 0).max(axis=0)
371
+ source_content = source_np[bbox_source_min[0]:bbox_source_max[0]+1, bbox_source_min[1]:bbox_source_max[1]+1, :]
372
+ # resize source_content to fit in the position of target_content
373
+ source_content = Image.fromarray(source_content).resize((bbox_target_max[1]-bbox_target_min[1]+1, bbox_target_max[0]-bbox_target_min[0]+1), resample=Image.BICUBIC)
374
+ target_np[bbox_target_min[0]:bbox_target_max[0]+1, bbox_target_min[1]:bbox_target_max[1]+1, :] = np.array(source_content)
375
+ return Image.fromarray(target_np)
376
+
377
+ def load_image_list_from_mvdiffusion(mvdiffusion_path, front_from_pil_or_path=None):
378
+ import os
379
+ image_list = []
380
+ for dir in ['front', 'front_right', 'right', 'back', 'left', 'front_left']:
381
+ image_path = os.path.join(mvdiffusion_path, f"rgb_000_{dir}.png")
382
+ pil = Image.open(image_path)
383
+ if dir == 'front':
384
+ if front_from_pil_or_path is not None:
385
+ if isinstance(front_from_pil_or_path, str):
386
+ replace_pil = Image.open(front_from_pil_or_path)
387
+ else:
388
+ replace_pil = front_from_pil_or_path
389
+ # align replace_pil with pil using bounding box in alpha channel
390
+ pil = align_with_alpha_bbox(replace_pil, pil, final_size=1024)
391
+ image_list.append(pil)
392
+ return image_list
393
+
394
+ def load_image_list_from_img_grid(img_grid_path, resolution = 1024):
395
+ img_list = []
396
+ grid = Image.open(img_grid_path)
397
+ w, h = grid.size
398
+ for row in range(0, h, resolution):
399
+ for col in range(0, w, resolution):
400
+ img_list.append(grid.crop((col, row, col + resolution, row + resolution)))
401
+ return img_list
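A minimal usage sketch of building cameras and projecting colors with `multiview_color_projection` (the mesh, image files and view angles are assumptions; the mesh must already carry a TexturesVertex texture, since its verts_features are used as the fallback color):

    from PIL import Image
    # `mesh`: hypothetical pytorch3d Meshes with a TexturesVertex texture
    cams = get_cameras_list_azi_ele([0, 90, 180, 270], [0, 0, 0, 0], device="cuda",
                                    dist=1.1, cam_type='orthographic')
    imgs = [Image.open(f"view_{i}.png") for i in range(4)]  # hypothetical RGBA views
    colored = multiview_color_projection(mesh, imgs, cams, resolution=1024, complete_unseen=True)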
models/ISOMER/scripts/refine_lr_to_sr.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ import os
3
+
4
+ import numpy as np
5
+ from hashlib import md5
6
+ def hash_img(img):
7
+ return md5(np.array(img).tobytes()).hexdigest()
8
+ def hash_any(obj):
9
+ return md5(str(obj).encode()).hexdigest()
10
+
11
+ def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.):
12
+ with torch.no_grad():
13
+ images = pipe(
14
+ image=pil_image_list,
15
+ ip_adapter_image=concept_img_list,
16
+ prompt=prompt_list,
17
+ neg_prompt=neg_prompt_list,
18
+ num_inference_steps=50,
19
+ strength=strength,
20
+ height=output_size[0],
21
+ width=output_size[1],
22
+ control_image=control_image_list,
23
+ guidance_scale=5.0,
24
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
25
+ generator=torch.manual_seed(233),
26
+ ).images
27
+ return images
28
+
29
+ SR_cache = None
30
+
31
+ def run_sr_fast(source_pils, scale=4):
32
+ from PIL import Image
33
+ from scripts.upsampler import RealESRGANer
34
+ import numpy as np
35
+ global SR_cache
36
+ if SR_cache is not None:
37
+ upsampler = SR_cache
38
+ else:
39
+ upsampler = RealESRGANer(
40
+ scale=4,
41
+ onnx_path="ckpt/realesrgan-x4.onnx",
42
+ tile=0,
43
+ tile_pad=10,
44
+ pre_pad=0,
45
+ half=True,
46
+ gpu_id=0,
47
+ )
48
+ ret_pils = []
49
+ for idx, img_pils in enumerate(source_pils):
50
+ np_in = isinstance(img_pils, np.ndarray)
51
+ assert isinstance(img_pils, (Image.Image, np.ndarray))
52
+ img = np.array(img_pils)
53
+ output, _ = upsampler.enhance(img, outscale=scale)
54
+ if np_in:
55
+ ret_pils.append(output)
56
+ else:
57
+ ret_pils.append(Image.fromarray(output))
58
+ if SR_cache is None:
59
+ SR_cache = upsampler
60
+ return ret_pils
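A minimal usage sketch of `run_sr_fast` (the input file name is a placeholder; the ONNX checkpoint path is the one hard-coded above):

    from PIL import Image
    low_res = [Image.open("render_512.png")]   # hypothetical low-resolution render
    high_res = run_sr_fast(low_res, scale=4)   # returns PIL images upscaled 4x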
models/ISOMER/scripts/sd_model_zoo.py ADDED
@@ -0,0 +1,131 @@
1
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
2
+ from transformers import CLIPVisionModelWithProjection
3
+ import torch
4
+ from copy import deepcopy
5
+
6
+ ENABLE_CPU_CACHE = False
7
+ DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"
8
+
9
+ cached_models = {} # cache for models to avoid repeated loading, key is model name
10
+ def cache_model(func):
11
+ def wrapper(*args, **kwargs):
12
+ if ENABLE_CPU_CACHE:
13
+ model_name = func.__name__ + str(args) + str(kwargs)
14
+ if model_name not in cached_models:
15
+ cached_models[model_name] = func(*args, **kwargs)
16
+ return cached_models[model_name]
17
+ else:
18
+ return func(*args, **kwargs)
19
+ return wrapper
20
+
21
+ def copied_cache_model(func):
22
+ def wrapper(*args, **kwargs):
23
+ if ENABLE_CPU_CACHE:
24
+ model_name = func.__name__ + str(args) + str(kwargs)
25
+ if model_name not in cached_models:
26
+ cached_models[model_name] = func(*args, **kwargs)
27
+ return deepcopy(cached_models[model_name])
28
+ else:
29
+ return func(*args, **kwargs)
30
+ return wrapper
31
+
32
+ def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
33
+ if ckpt_or_pretrained.endswith(".safetensors"):
34
+ pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
35
+ else:
36
+ pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
37
+ return pipe
38
+
39
+ @copied_cache_model
40
+ def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
41
+ model_kwargs = dict(
42
+ torch_dtype=torch_dtype,
43
+ requires_safety_checker=False,
44
+ safety_checker=None,
45
+ )
46
+ pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
47
+ base_model,
48
+ StableDiffusionPipeline,
49
+ **model_kwargs
50
+ )
51
+ pipe.to("cpu")
52
+ return pipe.components
53
+
54
+ @cache_model
55
+ def load_controlnet(controlnet_path, torch_dtype=torch.float16):
56
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
57
+ return controlnet
58
+
59
+ @cache_model
60
+ def load_image_encoder():
61
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
62
+ "h94/IP-Adapter",
63
+ subfolder="models/image_encoder",
64
+ torch_dtype=torch.float16,
65
+ )
66
+ return image_encoder
67
+
68
+ def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
69
+ model_kwargs = dict(
70
+ torch_dtype=torch_dtype,
71
+ device_map=device,
72
+ requires_safety_checker=False,
73
+ safety_checker=None,
74
+ )
75
+ components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype)
76
+ model_kwargs.update(components)
77
+ model_kwargs.update(kwargs)
78
+
79
+ if controlnet is not None:
80
+ if isinstance(controlnet, list):
81
+ controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
82
+ else:
83
+ controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
84
+ model_kwargs.update(controlnet=controlnet)
85
+
86
+ if pipeline_class is None:
87
+ if controlnet is not None:
88
+ pipeline_class = StableDiffusionControlNetPipeline
89
+ else:
90
+ pipeline_class = StableDiffusionPipeline
91
+
92
+ pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
93
+ base_model,
94
+ pipeline_class,
95
+ **model_kwargs
96
+ )
97
+
98
+ if ip_adapter:
99
+ image_encoder = load_image_encoder()
100
+ pipe.image_encoder = image_encoder
101
+ if plus_model:
102
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")
103
+ else:
104
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors")
105
+ pipe.set_ip_adapter_scale(1.0)
106
+ else:
107
+ pipe.unload_ip_adapter()
108
+
109
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
110
+
111
+ if model_cpu_offload_seq is None:
112
+ if isinstance(pipe, StableDiffusionControlNetPipeline):
113
+ pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
114
+ elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
115
+ pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
116
+ else:
117
+ pipe.model_cpu_offload_seq = model_cpu_offload_seq
118
+
119
+ if enable_sequential_cpu_offload:
120
+ pipe.enable_sequential_cpu_offload()
121
+ else:
122
+ pipe = pipe.to("cuda")
123
+ pass
124
+ # pipe.enable_model_cpu_offload()
125
+ if vae_slicing:
126
+ pipe.enable_vae_slicing()
127
+
128
+ import gc
129
+ gc.collect()
130
+ return pipe
131
+
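A minimal usage sketch of `load_common_sd15_pipe` (the ControlNet repository id is an assumption; any diffusers-compatible ControlNet checkpoint can be passed the same way):

    from diffusers import StableDiffusionControlNetImg2ImgPipeline
    pipe = load_common_sd15_pipe(
        base_model="runwayml/stable-diffusion-v1-5",
        controlnet="lllyasviel/control_v11f1p_sd15_depth",   # assumed depth ControlNet checkpoint
        ip_adapter=True,
        pipeline_class=StableDiffusionControlNetImg2ImgPipeline,
    )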
models/ISOMER/scripts/upsampler.py ADDED
@@ -0,0 +1,260 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torch.nn import functional as F
7
+ from scripts.load_onnx import load_onnx_caller
8
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
+
10
+
11
+ class RealESRGANer():
12
+ """A helper class for upsampling images with RealESRGAN.
13
+
14
+ Args:
15
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
16
+ onnx_path (str): The path to the exported ONNX model used for inference.
17
+ device (torch.device, optional): Device used for inference. Default: None (auto-selected from CUDA availability).
18
+ tile (int): Since very large images can run out of GPU memory, this tile option will first crop
19
+ input images into tiles, and then process each of them. Finally, they will be merged into one image.
20
+ 0 means tiling is disabled. Default: 0.
21
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
22
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
23
+ half (bool): Whether to use half precision during inference. Default: False.
24
+ """
25
+
26
+ def __init__(self,
27
+ scale,
28
+ onnx_path,
29
+ tile=0,
30
+ tile_pad=10,
31
+ pre_pad=10,
32
+ half=False,
33
+ device=None,
34
+ gpu_id=None):
35
+ self.scale = scale
36
+ self.tile_size = tile
37
+ self.tile_pad = tile_pad
38
+ self.pre_pad = pre_pad
39
+ self.mod_scale = None
40
+ self.half = half
41
+
42
+ print('about to initialize model')
43
+ # initialize model
44
+ if gpu_id:
45
+ self.device = torch.device(
46
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
47
+ else:
48
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
49
+ print('self.device set')
50
+ print(f'about to self.model = load_onnx_caller({onnx_path}, single_output=True)')
51
+ self.model = load_onnx_caller(onnx_path, single_output=True)
52
+ print('self.model loaded')
53
+
54
+ print('about to warm up')
55
+ # warm up
56
+ sample_input = torch.randn(1,3,512,512).cuda().float()
57
+ print(f'sample_input.shape = {sample_input.shape}')
58
+ self.model(sample_input)
59
+ print('finished warming up')
60
+
61
+ def pre_process(self, img):
62
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
63
+ """
64
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
65
+ self.img = img.unsqueeze(0).to(self.device)
66
+ if self.half:
67
+ self.img = self.img.half()
68
+
69
+ # pre_pad
70
+ if self.pre_pad != 0:
71
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
72
+ # mod pad for divisible borders
73
+ if self.scale == 2:
74
+ self.mod_scale = 2
75
+ elif self.scale == 1:
76
+ self.mod_scale = 4
77
+ if self.mod_scale is not None:
78
+ self.mod_pad_h, self.mod_pad_w = 0, 0
79
+ _, _, h, w = self.img.size()
80
+ if (h % self.mod_scale != 0):
81
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
82
+ if (w % self.mod_scale != 0):
83
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
84
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
85
+
86
+ def process(self):
87
+ # model inference
88
+ self.output = self.model(self.img)
89
+
90
+ def tile_process(self):
91
+ """It will first crop input images to tiles, and then process each tile.
92
+ Finally, all the processed tiles are merged into one image.
93
+
94
+ Modified from: https://github.com/ata4/esrgan-launcher
95
+ """
96
+ batch, channel, height, width = self.img.shape
97
+ output_height = height * self.scale
98
+ output_width = width * self.scale
99
+ output_shape = (batch, channel, output_height, output_width)
100
+
101
+ # start with black image
102
+ self.output = self.img.new_zeros(output_shape)
103
+ tiles_x = math.ceil(width / self.tile_size)
104
+ tiles_y = math.ceil(height / self.tile_size)
105
+
106
+ # loop over all tiles
107
+ for y in range(tiles_y):
108
+ for x in range(tiles_x):
109
+ # extract tile from input image
110
+ ofs_x = x * self.tile_size
111
+ ofs_y = y * self.tile_size
112
+ # input tile area on total image
113
+ input_start_x = ofs_x
114
+ input_end_x = min(ofs_x + self.tile_size, width)
115
+ input_start_y = ofs_y
116
+ input_end_y = min(ofs_y + self.tile_size, height)
117
+
118
+ # input tile area on total image with padding
119
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
120
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
121
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
122
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
123
+
124
+ # input tile dimensions
125
+ input_tile_width = input_end_x - input_start_x
126
+ input_tile_height = input_end_y - input_start_y
127
+ tile_idx = y * tiles_x + x + 1
128
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
129
+
130
+ # upscale tile
131
+ try:
132
+ with torch.no_grad():
133
+ output_tile = self.model(input_tile)
134
+ except RuntimeError as error:
135
+ print('Error', error)
136
+ print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
137
+
138
+ # output tile area on total image
139
+ output_start_x = input_start_x * self.scale
140
+ output_end_x = input_end_x * self.scale
141
+ output_start_y = input_start_y * self.scale
142
+ output_end_y = input_end_y * self.scale
143
+
144
+ # output tile area without padding
145
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
146
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
147
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
148
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
149
+
150
+ # put tile into output image
151
+ self.output[:, :, output_start_y:output_end_y,
152
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
153
+ output_start_x_tile:output_end_x_tile]
154
+
155
+ def post_process(self):
156
+ # remove extra pad
157
+ if self.mod_scale is not None:
158
+ _, _, h, w = self.output.size()
159
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
160
+ # remove prepad
161
+ if self.pre_pad != 0:
162
+ _, _, h, w = self.output.size()
163
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
164
+ return self.output
165
+
166
+ @torch.no_grad()
167
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
168
+ print('inside enhance')
169
+ h_input, w_input = img.shape[0:2]
170
+ # img: numpy
171
+ img = img.astype(np.float32)
172
+ if np.max(img) > 256: # 16-bit image
173
+ max_range = 65535
174
+ print('\tInput is a 16-bit image')
175
+ else:
176
+ max_range = 255
177
+ img = img / max_range
178
+ if len(img.shape) == 2: # gray image
179
+ img_mode = 'L'
180
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
181
+ elif img.shape[2] == 4: # RGBA image with alpha channel
182
+ img_mode = 'RGBA'
183
+ alpha = img[:, :, 3]
184
+ img = img[:, :, 0:3]
185
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
186
+ if alpha_upsampler == 'realesrgan':
187
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
188
+ else:
189
+ img_mode = 'RGB'
190
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
191
+
192
+ # ------------------- process image (without the alpha channel) ------------------- #
193
+ print('about to process image (without the alpha channel)')
194
+ self.pre_process(img)
195
+ if self.tile_size > 0:
196
+ print(f'self.tile_size is {self.tile_size}, thus about to self.tile_process()')
197
+ self.tile_process()
198
+ print('finished self.tile_process()')
199
+ else:
200
+ print('about to self.process()')
201
+ self.process()
202
+ print('finished self.process()')
203
+
204
+ print('about to self.post_process()')
205
+ output_img = self.post_process()
206
+ print('finished self.post_process()')
207
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
208
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
209
+ if img_mode == 'L':
210
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
211
+ print('finished process image (without the alpha channel)')
212
+
213
+ # ------------------- process the alpha channel if necessary ------------------- #
214
+ if img_mode == 'RGBA':
215
+ print("img_mode == 'RGBA' thus about to process alpha channel")
216
+ if alpha_upsampler == 'realesrgan':
217
+ print(f"alpha_upsampler == 'realesrgan', about to self.pre_process({alpha})")
218
+ self.pre_process(alpha)
219
+ print('finished self.pre_process')
220
+ if self.tile_size > 0:
221
+ print(f'self.tile_size is {self.tile_size}, thus about to self.tile_process()')
222
+ self.tile_process()
223
+ print('finished self.tile_process()')
224
+ else:
225
+ print('about to self.process()')
226
+ self.process()
227
+ print('finished self.process()')
228
+ print('about to self.post_process()')
229
+ output_alpha = self.post_process()
230
+ print('finished self.post_process()')
231
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
232
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
233
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
234
+ else: # use the cv2 resize for alpha channel
235
+ print('about to use the cv2 resize for alpha channel')
236
+ h, w = alpha.shape[0:2]
237
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
238
+
239
+ print('about to merge the alpha channel')
240
+ # merge the alpha channel
241
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
242
+ output_img[:, :, 3] = output_alpha
243
+ print('finished process alpha channel')
244
+
245
+ print('about to resize and return')
246
+ # ------------------------------ return ------------------------------ #
247
+ if max_range == 65535: # 16-bit image
248
+ output = (output_img * 65535.0).round().astype(np.uint16)
249
+ else:
250
+ output = (output_img * 255.0).round().astype(np.uint8)
251
+
252
+ if outscale is not None and outscale != float(self.scale):
253
+ output = cv2.resize(
254
+ output, (
255
+ int(w_input * outscale),
256
+ int(h_input * outscale),
257
+ ), interpolation=cv2.INTER_LANCZOS4)
258
+
259
+ return output, img_mode
260
+
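A minimal usage sketch of `RealESRGANer` (the ONNX path mirrors the one used by refine_lr_to_sr.py and is an assumption here; a CUDA device is required because the warm-up call runs on the GPU, and `enhance` expects OpenCV-style BGR input):

    import cv2
    upsampler = RealESRGANer(scale=4, onnx_path="ckpt/realesrgan-x4.onnx",
                             tile=0, tile_pad=10, pre_pad=0, half=True, gpu_id=0)
    bgr = cv2.imread("input.png")                 # OpenCV loads images as BGR
    out_bgr, img_mode = upsampler.enhance(bgr, outscale=4)
    cv2.imwrite("input_x4.png", out_bgr)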
models/ISOMER/scripts/utils.py ADDED
@@ -0,0 +1,611 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ import pymeshlab
5
+ import pymeshlab as ml
6
+ from pymeshlab import PercentageValue
7
+ from pytorch3d.renderer import TexturesVertex
8
+ from pytorch3d.structures import Meshes
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import List, Tuple
12
+ from PIL import Image
13
+ import trimesh
14
+
15
+ EPSILON = 1e-8
16
+
17
+ def load_mesh_with_trimesh(file_name, file_type=None):
18
+ import trimesh
19
+ mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type)
20
+ if isinstance(mesh, trimesh.Scene):
21
+ assert len(mesh.geometry) > 0
22
+ # save to obj first and load again to avoid offset issue
23
+ from io import BytesIO
24
+ with BytesIO() as f:
25
+ mesh.export(f, file_type="obj")
26
+ f.seek(0)
27
+ mesh = trimesh.load(f, file_type="obj")
28
+ if isinstance(mesh, trimesh.Scene):
29
+ # we lose texture information here
30
+ mesh = trimesh.util.concatenate(
31
+ tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
32
+ for g in mesh.geometry.values()))
33
+ assert isinstance(mesh, trimesh.Trimesh)
34
+
35
+ vertices = torch.from_numpy(mesh.vertices).T
36
+ faces = torch.from_numpy(mesh.faces).T
37
+ colors = None
38
+ if mesh.visual is not None:
39
+ if hasattr(mesh.visual, 'vertex_colors'):
40
+ colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
41
+ if colors is None:
42
+ colors = torch.ones_like(vertices) * 0.5
43
+ return vertices, faces, colors
44
+
45
+ def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
46
+ verts = torch.from_numpy(mesh.vertex_matrix()).float()
47
+ faces = torch.from_numpy(mesh.face_matrix()).long()
48
+ colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
49
+ textures = TexturesVertex(verts_features=[colors])
50
+ return Meshes(verts=[verts], faces=[faces], textures=textures)
51
+
52
+
53
+ def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
54
+ colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
55
+ m1 = pymeshlab.Mesh(
56
+ vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
57
+ face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
58
+ v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
59
+ v_color_matrix=colors_in)
60
+ return m1
61
+
62
+
63
+ def to_pyml_mesh(vertices,faces):
64
+ m1 = pymeshlab.Mesh(
65
+ vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
66
+ face_matrix=faces.cpu().long().numpy().astype(np.int32),
67
+ )
68
+ return m1
69
+
70
+
71
+ def to_py3d_mesh(vertices, faces, normals=None):
72
+ from pytorch3d.structures import Meshes
73
+ from pytorch3d.renderer.mesh.textures import TexturesVertex
74
+ mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
75
+ if normals is None:
76
+ normals = mesh.verts_normals_packed()
77
+ # set normals as vertex colors
78
+ mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
79
+ return mesh
80
+
81
+
82
+ def from_py3d_mesh(mesh):
83
+ return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
84
+
85
+ def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
86
+ """
87
+ rotate along y-axis
88
+ normal_map: np.array, shape=(H, W, 3) in [-1, 1]
89
+ angle: float, in degree
90
+ """
91
+ angle = angle / 180 * np.pi
92
+ R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]])
93
+ return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
94
+
95
+ # from view coord to front view world coord
96
+ def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255]
97
+ n_views = len(normal_pils)
98
+ ret = []
99
+ for idx, rgba_normal in enumerate(normal_pils):
100
+ # rotate normal
101
+ normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1]
102
+ alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
103
+ normal_np = normal_np * 2 - 1
104
+ normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
105
+ normal_np = (normal_np + 1) / 2
106
+ normal_np = normal_np * alpha_np[..., None] # make bg black
107
+ rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1)
108
+ if return_types == 'np':
109
+ ret.append(rgba_normal_np)
110
+ elif return_types == 'pil':
111
+ ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
112
+ else:
113
+ raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
114
+ return ret
115
+
116
+
117
+ def rotate_normalmap_by_angle_torch(normal_map, angle):
118
+ """
119
+ rotate along y-axis
120
+ normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda'
121
+ angle: float, in degree
122
+ """
123
+ angle = torch.tensor(angle / 180 * np.pi).to(normal_map)
124
+ R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)],
125
+ [0, 1, 0],
126
+ [-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map)
127
+ return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
128
+
129
+ def do_rotate(rgba_normal, angle):
130
+ rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
131
+ rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
132
+ rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
133
+ rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
134
+ rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy()
135
+ return rgba_normal_np
136
+
137
+ def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1):
138
+ n_views = len(normal_pils)
139
+ ret = []
140
+ for idx, rgba_normal in enumerate(normal_pils):
141
+ # rotate normal
142
+ angle = rotate_direction * idx * (360 / n_views)
143
+ rgba_normal_np = do_rotate(np.array(rgba_normal), angle)
144
+ if return_types == 'np':
145
+ ret.append(rgba_normal_np)
146
+ elif return_types == 'pil':
147
+ ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
148
+ else:
149
+ raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
150
+ return ret
151
+
152
+ def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
153
+ ret = []
154
+ new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
155
+ for rgba_img in img_pils:
156
+ img_np = np.array(rgba_img)[:, :, :3] / 255
157
+ alpha_np = np.array(rgba_img)[:, :, 3] / 255
158
+ ori_bkgd = img_np[:1, :1]
159
+ # color = ori_color * alpha + bkgd * (1-alpha)
160
+ # ori_color = (color - bkgd * (1-alpha)) / alpha
161
+ alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero
162
+ ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
163
+ img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
164
+ rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
165
+ ret.append(Image.fromarray(rgba_img_np.astype(np.uint8)))
166
+ return ret
167
+
168
+ def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
169
+ n_views = len(normal_pils)
170
+ ret = []
171
+ for idx, rgba_normal in enumerate(normal_pils):
172
+ # calculate background normal
173
+ target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views))
174
+ normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1]
175
+ alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
176
+ normal_np = normal_np * 2 - 1
177
+ old_bkgd = normal_np[:1,:1]
178
+ normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None]
179
+ normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None])
180
+ normal_np = (normal_np + 1) / 2
181
+ rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1)
182
+ ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
183
+ return ret
184
+
185
+
186
+ def fix_vert_color_glb(mesh_path):
187
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
188
+ obj1 = GLTF2().load(mesh_path)
189
+ obj1.meshes[0].primitives[0].material = 0
190
+ obj1.materials.append(Material(
191
+ pbrMetallicRoughness = PbrMetallicRoughness(
192
+ baseColorFactor = [1.0, 1.0, 1.0, 1.0],
193
+ metallicFactor = 0.,
194
+ roughnessFactor = 1.0,
195
+ ),
196
+ emissiveFactor = [0.0, 0.0, 0.0],
197
+ doubleSided = True,
198
+ ))
199
+ obj1.save(mesh_path)
200
+
201
+
202
+ def srgb_to_linear(c_srgb):
203
+ c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
204
+ return c_linear.clip(0, 1.)
205
+
206
+
207
+ def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
208
+ # convert from pytorch3d meshes to trimesh mesh
209
+ vertices = meshes.verts_packed().cpu().float().numpy()
210
+ triangles = meshes.faces_packed().cpu().long().numpy()
211
+ np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
212
+ if save_glb_path.endswith(".glb"):
213
+ # rotate 180 along +Y
214
+ vertices[:, [0, 2]] = -vertices[:, [0, 2]]
215
+
216
+ if apply_sRGB_to_LinearRGB:
217
+ np_color = srgb_to_linear(np_color)
218
+ assert vertices.shape[0] == np_color.shape[0]
219
+ assert np_color.shape[1] == 3
220
+ assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
221
+ mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
222
+ mesh.remove_unreferenced_vertices()
223
+ # save mesh
224
+ mesh.export(save_glb_path)
225
+ if save_glb_path.endswith(".glb"):
226
+ fix_vert_color_glb(save_glb_path)
227
+
228
+
229
+ def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
230
+ import time
231
+ if '.' in save_mesh_prefix:
232
+ save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
233
+ if with_timestamp:
234
+ save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
235
+ ret_mesh = save_mesh_prefix + ".glb"
236
+ # optimized version
237
+ save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
238
+ return ret_mesh, None
239
+
240
+
241
+ def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
242
+ ms = ml.MeshSet()
243
+ ms.add_mesh(pyml_mesh, "cube_mesh")
244
+
245
+ if apply_smooth:
246
+ ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
247
+ if apply_sub_divide: # 5s, slow
248
+ ms.apply_filter("meshing_repair_non_manifold_vertices")
249
+ ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
250
+ ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
251
+ return meshlab_mesh_to_py3dmesh(ms.current_mesh())
252
+
253
+
254
+ def expand2square(pil_img, background_color):
255
+ width, height = pil_img.size
256
+ if width == height:
257
+ return pil_img
258
+ elif width > height:
259
+ result = Image.new(pil_img.mode, (width, width), background_color)
260
+ result.paste(pil_img, (0, (width - height) // 2))
261
+ return result
262
+ else:
263
+ result = Image.new(pil_img.mode, (height, height), background_color)
264
+ result.paste(pil_img, ((height - width) // 2, 0))
265
+ return result
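A quick usage sketch for `expand2square`; the input image here is a made-up in-memory example, not a file from this repo:

```python
from PIL import Image

img = Image.new("RGBA", (300, 200), (255, 0, 0, 255))       # hypothetical non-square input
squared = expand2square(img, background_color=(0, 0, 0, 0))  # pad to a square canvas
print(squared.size)                                          # (300, 300); original pasted vertically centred
```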
266
+
267
+
268
+
269
+ def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
270
+ new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
271
+
272
+ imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
273
+ img_nps = imgs[..., :3]
274
+ alpha_nps = imgs[..., 3]
275
+ ori_bkgds = img_nps[:, :1, :1]
276
+
277
+ alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
278
+ ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
279
+ ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
280
+ img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
281
+
282
+ rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
283
+ return rgba_img_np
284
+
285
+
286
+
287
+ def rotation_matrix_axis_angle(axis, angle, device='cuda'):
288
+ """
289
+ Return the rotation matrix associated with counterclockwise rotation about
290
+ the given axis by angle degrees, using PyTorch.
291
+ """
292
+ if not isinstance(axis, torch.Tensor):
293
+ axis = torch.tensor(axis, device=device)
294
+ axis = axis.float().to(device)
295
+ if not isinstance(angle, torch.Tensor):
296
+ angle = torch.tensor(angle, device=device)
297
+ angle = angle.float().to(device)
298
+
299
+ theta = angle * torch.pi / 180.0
300
+ axis = axis.float()
301
+ if torch.dot(axis, axis) > 0:
302
+ denom = torch.sqrt(torch.dot(axis, axis))
303
+ denom_safe = torch.where(denom == 0, torch.tensor(EPSILON).to(denom.device), denom)
304
+ axis = axis / denom_safe  # normalize the axis (denom is already the vector norm)
305
+ a = torch.cos(theta / 2.0)
306
+ b, c, d = -axis[0] * torch.sin(theta / 2.0), -axis[1] * torch.sin(theta / 2.0), -axis[2] * torch.sin(theta / 2.0)
307
+
308
+ aa, bb, cc, dd = a*a, b*b, c*c, d*d
309
+ bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d
310
+ return torch.stack([
311
+ torch.stack([aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)]),
312
+ torch.stack([2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)]),
313
+ torch.stack([2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc])
314
+ ])
315
+ else:
316
+ return torch.eye(3)
317
+
318
+
319
+
320
+ def normal_rotation_img2img_angle_axis(image, angle, axis=None, device='cuda'):
321
+ """
322
+ Rotate an image by a given angle around a given axis using PyTorch.
323
+
324
+ Args:
325
+ image: Input Image to rotate.
326
+ angle: Rotation angle in degrees.
327
+ axis: Rotation axis as an array of 3 floats.
328
+
329
+ Returns:
330
+ Image: Rotated Image.
331
+ """
332
+ if axis is None:
333
+ axis = [0,1,0]
334
+ axis = torch.tensor(axis, device=device)
335
+
336
+
337
+ if type(image) == Image.Image:
338
+ image_array = torch.tensor(np.array(image, dtype='float32'))
339
+ else:
340
+ image_array = image
341
+ image_array = image_array.to(device)
342
+
343
+ if type(angle) != torch.Tensor:
344
+ angle = torch.tensor(angle)
345
+ angle = angle.to(device)
346
+
347
+ if image_array.shape[2] == 4:
348
+ rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
349
+ else:
350
+ rgb_array = image_array
351
+ alpha_array = None
352
+
353
+ rgb_array = rgb_array / 255.0 - 0.5
354
+
355
+ rgb_array = rgb_array.permute(2, 0, 1)
356
+
357
+ rotated_tensor = apply_rotation_angle_axis(rgb_array.unsqueeze(0), axis, torch.tensor([angle], device=rgb_array.device))
358
+
359
+
360
+ rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
361
+
362
+ rotated_array = (rotated_array/2 + 0.5) * 255
363
+
364
+ if alpha_array is not None:
365
+ rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
366
+
367
+ rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
368
+
369
+ rotated_normal = Image.fromarray(rotated_array_uint8)
370
+
371
+ return rotated_normal
372
+
373
+ def normal_rotation_img2img_c2w(image, c2w, device='cuda'):
374
+
375
+ if type(image) != torch.Tensor:
376
+ image_array = torch.tensor(np.array(image, dtype='float32'))
377
+ else:
378
+ image_array = image
379
+
380
+
381
+ image_array = image_array.to(device)
382
+
383
+ if image_array.shape[2] == 4:
384
+ rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
385
+ else:
386
+ rgb_array = image_array
387
+ alpha_array = None
388
+
389
+ rgb_array = rgb_array / 255.0 - 0.5
390
+
391
+ rotation_matrix = c2w
392
+
393
+ rotated_tensor = transform_normals_R(rgb_array, rotation_matrix)
394
+
395
+ rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
396
+ rotated_array = (rotated_array/2 + 0.5) * 255
397
+
398
+ if alpha_array is not None:
399
+ rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
400
+
401
+ rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
402
+
403
+ rotated_normal = Image.fromarray(rotated_array_uint8)
404
+
405
+ return rotated_normal
406
+
407
+ def normal_rotation_img2img_azi_ele(image, azi, ele, device='cuda'):
408
+ """
409
+ Rotate a normal map by given azimuth and elevation angles using PyTorch.
410
+
411
+ Args:
412
+ image: Input Image to rotate.
413
+
414
+ Returns:
415
+ Image: Rotated Image.
416
+ """
417
+
418
+ if type(image) == Image.Image:
419
+ image_array = torch.tensor(np.array(image, dtype='float32'))
420
+ else:
421
+ image_array = image
422
+ image_array = image_array.to(device)
423
+
424
+ if type(azi) != torch.Tensor:
425
+ azi = torch.tensor(azi)
426
+ azi = azi.to(device)
427
+
428
+ if type(ele) != torch.Tensor:
429
+ ele = torch.tensor(ele)
430
+ ele = ele.to(device)
431
+
432
+ if image_array.shape[2] == 4:
433
+ rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
434
+ else:
435
+ rgb_array = image_array
436
+ alpha_array = None
437
+
438
+ rgb_array = rgb_array / 255.0 - 0.5
439
+
440
+ rotation_matrix = get_rotation_matrix_azi_ele(azi, ele)
441
+ rotated_tensor = transform_normals_R(rgb_array, rotation_matrix)
442
+
443
+ rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
444
+
445
+ rotated_array = (rotated_array/2 + 0.5) * 255
446
+
447
+ if alpha_array is not None:
448
+ rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
449
+
450
+ rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
451
+
452
+ rotated_normal = Image.fromarray(rotated_array_uint8)
453
+
454
+ return rotated_normal
455
+
456
+
457
+ def rotate_normal_R(image, rotation_matrix, save_addr="", device="cuda"):
458
+ """
459
+ Rotate a normal map by a given rotation matrix using PyTorch.
460
+
461
+ Args:
462
+ image: Input Image to rotate.
463
+
464
+ Returns:
465
+ Image: Rotated Image.
466
+ """
467
+
468
+ if not isinstance(image, torch.Tensor):
469
+ image_array = torch.tensor(np.array(image, dtype='float32'))
470
+ else:
471
+ image_array = image
472
+ image_array = image_array.to(device)
473
+
474
+ if image_array.shape[2] == 4:
475
+ rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
476
+ else:
477
+ rgb_array = image_array
478
+ alpha_array = None
479
+
480
+ rgb_array = rgb_array / 255.0 - 0.5
481
+
482
+ rotated_tensor = transform_normals_R(rgb_array, rotation_matrix.to(device))
483
+
484
+ rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
485
+
486
+ rotated_array = (rotated_array/2 + 0.5) * 255
487
+
488
+ if alpha_array is not None:
489
+ rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
490
+
491
+ rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
492
+
493
+ rotated_normal = Image.fromarray(rotated_array_uint8)
494
+
495
+ if save_addr:
496
+ rotated_normal.save(save_addr)
497
+ return rotated_normal
498
+
499
+
500
+
501
+ def transform_normals_R(local_normals, rotation_matrix):
502
+ assert local_normals.shape[2] == 3, f'local_normals.shape[2]: {local_normals.shape[2]}; only RGB (3-channel) images are supported'
503
+
504
+ h, w = local_normals.shape[:2]
505
+ local_normals_flat = local_normals.view(-1, 3).permute(1, 0)
506
+
507
+ images_flat = local_normals_flat.unsqueeze(0)
508
+ rotation_matrices = rotation_matrix.unsqueeze(0)
509
+ rotated_images_flat = torch.bmm(rotation_matrices, images_flat)
510
+
511
+ rotated_images = rotated_images_flat.view(1, 3, h, w)
512
+
513
+ norms = torch.norm(rotated_images, p=2, dim=1, keepdim=True)
514
+ norms = torch.where(norms == 0, torch.tensor(EPSILON).to(norms.device), norms)
515
+ normalized_images = rotated_images / norms
516
+
517
+ return normalized_images
518
+
519
+
520
+ def manage_elevation_azimuth(ele_list, azi_list):
521
+ """deal with cases when elevation > 90"""
522
+
523
+ for i in range(len(ele_list)):
524
+ elevation = ele_list[i] % 360
525
+ azimuth = azi_list[i] % 360
526
+ if elevation > 90 and elevation<=270:
527
+ # when elevation is too large, the camera ends up on the other side
528
+ # print(f'!!! elevation({elevation}) > 90 and <=270, set to 180-elevation, and add 180 to azimuth')
529
+ elevation = 180 - elevation
530
+ azimuth = azimuth + 180
531
+ # print(f'new elevation: {elevation}, new azimuth: {azimuth}')
532
+
533
+ elif elevation>270:
534
+ # print(f'!!! elevation({elevation}) > 270, set to elevation-360, and use original azimuth')
535
+ elevation = elevation - 360
536
+ azimuth = azimuth
537
+ # print(f'new elevation: {elevation}, new azimuth: {azimuth}')
538
+
539
+ ele_list[i] = elevation
540
+ azi_list[i] = azimuth
541
+
542
+ return ele_list, azi_list
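A small check of the wrap-around rules above (the input values are arbitrary illustrations):

```python
ele, azi = manage_elevation_azimuth([100.0, 300.0], [10.0, 10.0])
# 100 > 90   -> elevation becomes 180 - 100 = 80, azimuth becomes 10 + 180 = 190
# 300 > 270  -> elevation becomes 300 - 360 = -60, azimuth unchanged
print(ele, azi)   # [80.0, -60.0] [190.0, 10.0]
```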
543
+
544
+ def get_rotation_matrix_azi_ele(azimuth, elevation):
545
+
546
+ ele = elevation/180 * torch.pi
547
+ azi = azimuth/180 * torch.pi
548
+
549
+ Rz = torch.tensor([
550
+ [torch.cos(azi), 0, -torch.sin(azi)],
551
+ [0, 1, 0],
552
+ [torch.sin(azi), 0, torch.cos(azi)],
553
+ ]).to(azimuth.device)
554
+
555
+ Re = torch.tensor([
556
+ [1, 0, 0],
557
+ [0, torch.cos(ele), torch.sin(ele)],
558
+ [0, -torch.sin(ele), torch.cos(ele)],
559
+ ]).to(elevation.device)
560
+
561
+ return torch.matmul(Rz,Re).to(azimuth.device)
562
+
563
+
564
+ def rotate_vector(vector, axis, angle, device='cuda'):
565
+ rot_matrix = rotation_matrix_axis_angle(axis, angle)
566
+ return torch.matmul(vector.to(device).float(), rot_matrix.to(device).float())
567
+
568
+ def apply_rotation_angle_axis(image, axis, angle, device='cuda'):
569
+ """Apply rotation to a batch of images with shape [batch_size, 3(rgb), h, w] using PyTorch.
570
+
571
+ Args:
572
+ image (torch.Tensor): Input RGB image tensor of shape [batch_size, 3, h, w]. each pixel's rgb channels refer to direction of normal (can be negative)
573
+ axis (torch.Tensor): Rotation axis of shape [3].
574
+ angle (torch.Tensor): Rotation angles in degrees, of shape [batch_size].
575
+ Returns:
576
+ torch.Tensor: Rotated image tensor of shape [batch_size, 3, h, w]. values between [-1., 1.]
577
+
578
+ """
579
+
580
+ if not isinstance(image, torch.Tensor):
581
+ image_tensor = torch.tensor(image).to(device)
582
+ else:
583
+ image_tensor = image.to(device)
584
+
585
+ if not isinstance(axis, torch.Tensor):
586
+ axis = torch.tensor(axis)
587
+ axis = axis.to(device)
588
+
589
+ if not isinstance(angle, torch.Tensor):
590
+ angle = torch.tensor(angle)
591
+ angle = angle.to(device)
592
+
593
+ batch_size, channels, h, w = image_tensor.shape
594
+ rot_matrix = rotation_matrix_axis_angle(axis, angle)
595
+
596
+ rotation_matrices = rot_matrix.permute(2, 0, 1)
597
+
598
+ batch_size, c, h, w = image_tensor.shape
599
+ images_flat = image_tensor.view(batch_size, c, h * w)
600
+
601
+ rotated_images_flat = torch.bmm(rotation_matrices, images_flat)
602
+
603
+ rotated_images = rotated_images_flat.view(batch_size, c, h, w)
604
+
605
+ norms = torch.norm(rotated_images, p=2, dim=1, keepdim=True)
606
+
607
+ norms = torch.where(norms == 0, torch.tensor(EPSILON).to(norms.device), norms)
608
+
609
+ normalized_images = rotated_images / norms
610
+
611
+ return normalized_images
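A sketch of how the multi-view normal helpers in this file might be chained; the four RGBA file paths and the CUDA device are assumptions for illustration only:

```python
from PIL import Image

# Hypothetical inputs: four RGBA normal maps rendered at 0/90/180/270 degrees around the object.
normal_pils = [Image.open(f"normal_{i:02d}.png").convert("RGBA") for i in range(4)]

# Rotate every view's normals back into the frontal frame (needs a CUDA device, see do_rotate),
# then force the background to black before further processing.
aligned = rotate_normals_torch(normal_pils, return_types='pil', rotate_direction=1)
aligned = change_bkgd(aligned, new_bkgd=(0., 0., 0.))
aligned[0].save("normal_00_aligned.png")
```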
models/lrm/config/PRM_inference.yaml ADDED
@@ -0,0 +1,22 @@
1
+ model_config:
2
+ target: models.lrm.models.lrm_mesh.PRM
3
+ params:
4
+ encoder_feat_dim: 768
5
+ encoder_freeze: false
6
+ encoder_model_name: facebook/dino-vitb16
7
+ transformer_dim: 1024
8
+ transformer_layers: 16
9
+ transformer_heads: 16
10
+ triplane_low_res: 32
11
+ triplane_high_res: 64
12
+ triplane_dim: 80
13
+ rendering_samples_per_ray: 128
14
+ grid_res: 128
15
+ grid_scale: 2.1
16
+
17
+
18
+ infer_config:
19
+ unet_path: ckpts/diffusion_pytorch_model.bin
20
+ model_path: ckpts/final_ckpt.ckpt
21
+ texture_resolution: 2048
22
+ render_resolution: 512
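This config follows the usual `target`/`params` instantiation pattern. A loading sketch, assuming OmegaConf is available (the exact helper function this repo uses for instantiation is not shown in this diff):

```python
import importlib
from omegaconf import OmegaConf

cfg = OmegaConf.load("models/lrm/config/PRM_inference.yaml")
module_path, cls_name = cfg.model_config.target.rsplit(".", 1)    # models.lrm.models.lrm_mesh / PRM
model_cls = getattr(importlib.import_module(module_path), cls_name)
model = model_cls(**cfg.model_config.params)                       # PRM(encoder_feat_dim=768, ...)
print(cfg.infer_config.render_resolution)                          # 512
```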
models/lrm/models/__init__.py ADDED
File without changes
models/lrm/models/decoder/__init__.py ADDED
File without changes
models/lrm/models/decoder/transformer.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ class BasicTransformerBlock(nn.Module):
21
+ """
22
+ Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
23
+ """
24
+ # use attention from torch.nn.MultiHeadAttention
25
+ # Block contains a cross-attention layer, a self-attention layer, and a MLP
26
+ def __init__(
27
+ self,
28
+ inner_dim: int,
29
+ cond_dim: int,
30
+ num_heads: int,
31
+ eps: float,
32
+ attn_drop: float = 0.,
33
+ attn_bias: bool = False,
34
+ mlp_ratio: float = 4.,
35
+ mlp_drop: float = 0.,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.norm1 = nn.LayerNorm(inner_dim)
40
+ self.cross_attn = nn.MultiheadAttention(
41
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
42
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
43
+ self.norm2 = nn.LayerNorm(inner_dim)
44
+ self.self_attn = nn.MultiheadAttention(
45
+ embed_dim=inner_dim, num_heads=num_heads,
46
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
47
+ self.norm3 = nn.LayerNorm(inner_dim)
48
+ self.mlp = nn.Sequential(
49
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
50
+ nn.GELU(),
51
+ nn.Dropout(mlp_drop),
52
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
53
+ nn.Dropout(mlp_drop),
54
+ )
55
+
56
+ def forward(self, x, cond):
57
+ # x: [N, L, D]
58
+ # cond: [N, L_cond, D_cond]
59
+ x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
60
+ before_sa = self.norm2(x)
61
+ x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
62
+ x = x + self.mlp(self.norm3(x))
63
+ return x
64
+
65
+
66
+ class TriplaneTransformer(nn.Module):
67
+ """
68
+ Transformer with condition that generates a triplane representation.
69
+
70
+ Reference:
71
+ Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
72
+ """
73
+ def __init__(
74
+ self,
75
+ inner_dim: int,
76
+ image_feat_dim: int,
77
+ triplane_low_res: int,
78
+ triplane_high_res: int,
79
+ triplane_dim: int,
80
+ num_layers: int,
81
+ num_heads: int,
82
+ eps: float = 1e-6,
83
+ ):
84
+ super().__init__()
85
+
86
+ # attributes
87
+ self.triplane_low_res = triplane_low_res
88
+ self.triplane_high_res = triplane_high_res
89
+ self.triplane_dim = triplane_dim
90
+
91
+ # modules
92
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
93
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
94
+ self.layers = nn.ModuleList([
95
+ BasicTransformerBlock(
96
+ inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
97
+ for _ in range(num_layers)
98
+ ])
99
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
100
+ self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
101
+
102
+ def forward(self, image_feats):
103
+ # image_feats: [N, L_cond, D_cond]
104
+
105
+ N = image_feats.shape[0]
106
+ H = W = self.triplane_low_res
107
+ L = 3 * H * W
108
+
109
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
110
+ for layer in self.layers:
111
+ x = layer(x, image_feats)
112
+ x = self.norm(x)
113
+
114
+ # separate each plane and apply deconv
115
+ x = x.view(N, 3, H, W, -1)
116
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
117
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
118
+ x = self.deconv(x) # [3*N, D', H', W']
119
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
120
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
121
+ x = x.contiguous()
122
+
123
+ return x
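A shape walkthrough for `TriplaneTransformer` with the sizes from `PRM_inference.yaml` above; the dummy token count of 197 is an assumption matching a 224x224 DINO-B/16 input:

```python
import torch

model = TriplaneTransformer(
    inner_dim=1024, image_feat_dim=768,
    triplane_low_res=32, triplane_high_res=64, triplane_dim=80,
    num_layers=16, num_heads=16,
)
image_feats = torch.randn(1, 197, 768)   # [N, L_cond, D_cond] dummy image tokens
planes = model(image_feats)
print(planes.shape)                      # torch.Size([1, 3, 80, 64, 64]); 32x32 tokens per plane, deconv x2 -> 64
```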
models/lrm/models/encoder/__init__.py ADDED
File without changes
models/lrm/models/encoder/dino.py ADDED
@@ -0,0 +1,550 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch ViT model."""
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ from typing import Dict, List, Optional, Set, Tuple, Union
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ BaseModelOutputWithPooling,
29
+ )
30
+ from transformers import PreTrainedModel, ViTConfig
31
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
32
+
33
+
34
+ class ViTEmbeddings(nn.Module):
35
+ """
36
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
37
+ """
38
+
39
+ def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
40
+ super().__init__()
41
+
42
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
43
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
44
+ self.patch_embeddings = ViTPatchEmbeddings(config)
45
+ num_patches = self.patch_embeddings.num_patches
46
+ self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
47
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
48
+ self.config = config
49
+
50
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
51
+ """
52
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
53
+ resolution images.
54
+
55
+ Source:
56
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
57
+ """
58
+
59
+ num_patches = embeddings.shape[1] - 1
60
+ num_positions = self.position_embeddings.shape[1] - 1
61
+ if num_patches == num_positions and height == width:
62
+ return self.position_embeddings
63
+ class_pos_embed = self.position_embeddings[:, 0]
64
+ patch_pos_embed = self.position_embeddings[:, 1:]
65
+ dim = embeddings.shape[-1]
66
+ h0 = height // self.config.patch_size
67
+ w0 = width // self.config.patch_size
68
+ # we add a small number to avoid floating point error in the interpolation
69
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
70
+ h0, w0 = h0 + 0.1, w0 + 0.1
71
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
72
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
73
+ patch_pos_embed = nn.functional.interpolate(
74
+ patch_pos_embed,
75
+ scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
76
+ mode="bicubic",
77
+ align_corners=False,
78
+ )
79
+ assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
80
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
81
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
82
+
83
+ def forward(
84
+ self,
85
+ pixel_values: torch.Tensor,
86
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
87
+ interpolate_pos_encoding: bool = False,
88
+ ) -> torch.Tensor:
89
+ batch_size, num_channels, height, width = pixel_values.shape
90
+ embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
91
+
92
+ if bool_masked_pos is not None:
93
+ seq_length = embeddings.shape[1]
94
+ mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
95
+ # replace the masked visual tokens by mask_tokens
96
+ mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
97
+ embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
98
+
99
+ # add the [CLS] token to the embedded patch tokens
100
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
101
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
102
+
103
+ # add positional encoding to each token
104
+ if interpolate_pos_encoding:
105
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
106
+ else:
107
+ embeddings = embeddings + self.position_embeddings
108
+
109
+ embeddings = self.dropout(embeddings)
110
+
111
+ return embeddings
112
+
113
+
114
+ class ViTPatchEmbeddings(nn.Module):
115
+ """
116
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
117
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
118
+ Transformer.
119
+ """
120
+
121
+ def __init__(self, config):
122
+ super().__init__()
123
+ image_size, patch_size = config.image_size, config.patch_size
124
+ num_channels, hidden_size = config.num_channels, config.hidden_size
125
+
126
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
127
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
128
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
129
+ self.image_size = image_size
130
+ self.patch_size = patch_size
131
+ self.num_channels = num_channels
132
+ self.num_patches = num_patches
133
+
134
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
135
+
136
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
137
+ batch_size, num_channels, height, width = pixel_values.shape
138
+ if num_channels != self.num_channels:
139
+ raise ValueError(
140
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
141
+ f" Expected {self.num_channels} but got {num_channels}."
142
+ )
143
+ if not interpolate_pos_encoding:
144
+ if height != self.image_size[0] or width != self.image_size[1]:
145
+ raise ValueError(
146
+ f"Input image size ({height}*{width}) doesn't match model"
147
+ f" ({self.image_size[0]}*{self.image_size[1]})."
148
+ )
149
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
150
+ return embeddings
151
+
152
+
153
+ class ViTSelfAttention(nn.Module):
154
+ def __init__(self, config: ViTConfig) -> None:
155
+ super().__init__()
156
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
+ raise ValueError(
158
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
159
+ f"heads {config.num_attention_heads}."
160
+ )
161
+
162
+ self.num_attention_heads = config.num_attention_heads
163
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
164
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
165
+
166
+ self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
167
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
168
+ self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
169
+
170
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
171
+
172
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
173
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
174
+ x = x.view(new_x_shape)
175
+ return x.permute(0, 2, 1, 3)
176
+
177
+ def forward(
178
+ self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
179
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
180
+ mixed_query_layer = self.query(hidden_states)
181
+
182
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
183
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
184
+ query_layer = self.transpose_for_scores(mixed_query_layer)
185
+
186
+ # Take the dot product between "query" and "key" to get the raw attention scores.
187
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
188
+
189
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
190
+
191
+ # Normalize the attention scores to probabilities.
192
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
193
+
194
+ # This is actually dropping out entire tokens to attend to, which might
195
+ # seem a bit unusual, but is taken from the original Transformer paper.
196
+ attention_probs = self.dropout(attention_probs)
197
+
198
+ # Mask heads if we want to
199
+ if head_mask is not None:
200
+ attention_probs = attention_probs * head_mask
201
+
202
+ context_layer = torch.matmul(attention_probs, value_layer)
203
+
204
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
205
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
206
+ context_layer = context_layer.view(new_context_layer_shape)
207
+
208
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
209
+
210
+ return outputs
211
+
212
+
213
+ class ViTSelfOutput(nn.Module):
214
+ """
215
+ The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
216
+ layernorm applied before each block.
217
+ """
218
+
219
+ def __init__(self, config: ViTConfig) -> None:
220
+ super().__init__()
221
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
222
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
223
+
224
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
225
+ hidden_states = self.dense(hidden_states)
226
+ hidden_states = self.dropout(hidden_states)
227
+
228
+ return hidden_states
229
+
230
+
231
+ class ViTAttention(nn.Module):
232
+ def __init__(self, config: ViTConfig) -> None:
233
+ super().__init__()
234
+ self.attention = ViTSelfAttention(config)
235
+ self.output = ViTSelfOutput(config)
236
+ self.pruned_heads = set()
237
+
238
+ def prune_heads(self, heads: Set[int]) -> None:
239
+ if len(heads) == 0:
240
+ return
241
+ heads, index = find_pruneable_heads_and_indices(
242
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
243
+ )
244
+
245
+ # Prune linear layers
246
+ self.attention.query = prune_linear_layer(self.attention.query, index)
247
+ self.attention.key = prune_linear_layer(self.attention.key, index)
248
+ self.attention.value = prune_linear_layer(self.attention.value, index)
249
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
250
+
251
+ # Update hyper params and store pruned heads
252
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
253
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
254
+ self.pruned_heads = self.pruned_heads.union(heads)
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states: torch.Tensor,
259
+ head_mask: Optional[torch.Tensor] = None,
260
+ output_attentions: bool = False,
261
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
262
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions)
263
+
264
+ attention_output = self.output(self_outputs[0], hidden_states)
265
+
266
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
267
+ return outputs
268
+
269
+
270
+ class ViTIntermediate(nn.Module):
271
+ def __init__(self, config: ViTConfig) -> None:
272
+ super().__init__()
273
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
274
+ if isinstance(config.hidden_act, str):
275
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
276
+ else:
277
+ self.intermediate_act_fn = config.hidden_act
278
+
279
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
280
+ hidden_states = self.dense(hidden_states)
281
+ hidden_states = self.intermediate_act_fn(hidden_states)
282
+
283
+ return hidden_states
284
+
285
+
286
+ class ViTOutput(nn.Module):
287
+ def __init__(self, config: ViTConfig) -> None:
288
+ super().__init__()
289
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
290
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
291
+
292
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.dropout(hidden_states)
295
+
296
+ hidden_states = hidden_states + input_tensor
297
+
298
+ return hidden_states
299
+
300
+
301
+ def modulate(x, shift, scale):
302
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
303
+
304
+
305
+ class ViTLayer(nn.Module):
306
+ """This corresponds to the Block class in the timm implementation."""
307
+
308
+ def __init__(self, config: ViTConfig) -> None:
309
+ super().__init__()
310
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
311
+ self.seq_len_dim = 1
312
+ self.attention = ViTAttention(config)
313
+ self.intermediate = ViTIntermediate(config)
314
+ self.output = ViTOutput(config)
315
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
316
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
317
+
318
+ self.adaLN_modulation = nn.Sequential(
319
+ nn.SiLU(),
320
+ nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
321
+ )
322
+ nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
323
+ nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states: torch.Tensor,
328
+ adaln_input: torch.Tensor = None,
329
+ head_mask: Optional[torch.Tensor] = None,
330
+ output_attentions: bool = False,
331
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
332
+ shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
333
+
334
+ self_attention_outputs = self.attention(
335
+ modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
336
+ head_mask,
337
+ output_attentions=output_attentions,
338
+ )
339
+ attention_output = self_attention_outputs[0]
340
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
341
+
342
+ # first residual connection
343
+ hidden_states = attention_output + hidden_states
344
+
345
+ # in ViT, layernorm is also applied after self-attention
346
+ layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
347
+ layer_output = self.intermediate(layer_output)
348
+
349
+ # second residual connection is done here
350
+ layer_output = self.output(layer_output, hidden_states)
351
+
352
+ outputs = (layer_output,) + outputs
353
+
354
+ return outputs
355
+
356
+
357
+ class ViTEncoder(nn.Module):
358
+ def __init__(self, config: ViTConfig) -> None:
359
+ super().__init__()
360
+ self.config = config
361
+ self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
362
+ self.gradient_checkpointing = False
363
+
364
+ def forward(
365
+ self,
366
+ hidden_states: torch.Tensor,
367
+ adaln_input: torch.Tensor = None,
368
+ head_mask: Optional[torch.Tensor] = None,
369
+ output_attentions: bool = False,
370
+ output_hidden_states: bool = False,
371
+ return_dict: bool = True,
372
+ ) -> Union[tuple, BaseModelOutput]:
373
+ all_hidden_states = () if output_hidden_states else None
374
+ all_self_attentions = () if output_attentions else None
375
+
376
+ for i, layer_module in enumerate(self.layer):
377
+ if output_hidden_states:
378
+ all_hidden_states = all_hidden_states + (hidden_states,)
379
+
380
+ layer_head_mask = head_mask[i] if head_mask is not None else None
381
+
382
+ if self.gradient_checkpointing and self.training:
383
+ layer_outputs = self._gradient_checkpointing_func(
384
+ layer_module.__call__,
385
+ hidden_states,
386
+ adaln_input,
387
+ layer_head_mask,
388
+ output_attentions,
389
+ )
390
+ else:
391
+ layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
392
+
393
+ hidden_states = layer_outputs[0]
394
+
395
+ if output_attentions:
396
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
397
+
398
+ if output_hidden_states:
399
+ all_hidden_states = all_hidden_states + (hidden_states,)
400
+
401
+ if not return_dict:
402
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
403
+ return BaseModelOutput(
404
+ last_hidden_state=hidden_states,
405
+ hidden_states=all_hidden_states,
406
+ attentions=all_self_attentions,
407
+ )
408
+
409
+
410
+ class ViTPreTrainedModel(PreTrainedModel):
411
+ """
412
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
413
+ models.
414
+ """
415
+
416
+ config_class = ViTConfig
417
+ base_model_prefix = "vit"
418
+ main_input_name = "pixel_values"
419
+ supports_gradient_checkpointing = True
420
+ _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
421
+
422
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
423
+ """Initialize the weights"""
424
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
425
+ # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
426
+ # `trunc_normal_cpu` not implemented in `half` issues
427
+ module.weight.data = nn.init.trunc_normal_(
428
+ module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
429
+ ).to(module.weight.dtype)
430
+ if module.bias is not None:
431
+ module.bias.data.zero_()
432
+ elif isinstance(module, nn.LayerNorm):
433
+ module.bias.data.zero_()
434
+ module.weight.data.fill_(1.0)
435
+ elif isinstance(module, ViTEmbeddings):
436
+ module.position_embeddings.data = nn.init.trunc_normal_(
437
+ module.position_embeddings.data.to(torch.float32),
438
+ mean=0.0,
439
+ std=self.config.initializer_range,
440
+ ).to(module.position_embeddings.dtype)
441
+
442
+ module.cls_token.data = nn.init.trunc_normal_(
443
+ module.cls_token.data.to(torch.float32),
444
+ mean=0.0,
445
+ std=self.config.initializer_range,
446
+ ).to(module.cls_token.dtype)
447
+
448
+
449
+ class ViTModel(ViTPreTrainedModel):
450
+ def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
451
+ super().__init__(config)
452
+ self.config = config
453
+
454
+ self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
455
+ self.encoder = ViTEncoder(config)
456
+
457
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
458
+ self.pooler = ViTPooler(config) if add_pooling_layer else None
459
+
460
+ # Initialize weights and apply final processing
461
+ self.post_init()
462
+
463
+ def get_input_embeddings(self) -> ViTPatchEmbeddings:
464
+ return self.embeddings.patch_embeddings
465
+
466
+ def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
467
+ """
468
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
469
+ class PreTrainedModel
470
+ """
471
+ for layer, heads in heads_to_prune.items():
472
+ self.encoder.layer[layer].attention.prune_heads(heads)
473
+
474
+ def forward(
475
+ self,
476
+ pixel_values: Optional[torch.Tensor] = None,
477
+ adaln_input: Optional[torch.Tensor] = None,
478
+ bool_masked_pos: Optional[torch.BoolTensor] = None,
479
+ head_mask: Optional[torch.Tensor] = None,
480
+ output_attentions: Optional[bool] = None,
481
+ output_hidden_states: Optional[bool] = None,
482
+ interpolate_pos_encoding: Optional[bool] = None,
483
+ return_dict: Optional[bool] = None,
484
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
485
+ r"""
486
+ bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
487
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
488
+ """
489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
+ output_hidden_states = (
491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
+ )
493
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
494
+
495
+ if pixel_values is None:
496
+ raise ValueError("You have to specify pixel_values")
497
+
498
+ # Prepare head mask if needed
499
+ # 1.0 in head_mask indicate we keep the head
500
+ # attention_probs has shape bsz x n_heads x N x N
501
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
502
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
503
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
504
+
505
+ # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
506
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
507
+ if pixel_values.dtype != expected_dtype:
508
+ pixel_values = pixel_values.to(expected_dtype)
509
+
510
+ embedding_output = self.embeddings(
511
+ pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
512
+ )
513
+
514
+ encoder_outputs = self.encoder(
515
+ embedding_output,
516
+ adaln_input=adaln_input,
517
+ head_mask=head_mask,
518
+ output_attentions=output_attentions,
519
+ output_hidden_states=output_hidden_states,
520
+ return_dict=return_dict,
521
+ )
522
+ sequence_output = encoder_outputs[0]
523
+ sequence_output = self.layernorm(sequence_output)
524
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
525
+
526
+ if not return_dict:
527
+ head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
528
+ return head_outputs + encoder_outputs[1:]
529
+
530
+ return BaseModelOutputWithPooling(
531
+ last_hidden_state=sequence_output,
532
+ pooler_output=pooled_output,
533
+ hidden_states=encoder_outputs.hidden_states,
534
+ attentions=encoder_outputs.attentions,
535
+ )
536
+
537
+
538
+ class ViTPooler(nn.Module):
539
+ def __init__(self, config: ViTConfig):
540
+ super().__init__()
541
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
542
+ self.activation = nn.Tanh()
543
+
544
+ def forward(self, hidden_states):
545
+ # We "pool" the model by simply taking the hidden state corresponding
546
+ # to the first token.
547
+ first_token_tensor = hidden_states[:, 0]
548
+ pooled_output = self.dense(first_token_tensor)
549
+ pooled_output = self.activation(pooled_output)
550
+ return pooled_output
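The main change relative to the stock Hugging Face ViT is the adaLN conditioning: each `ViTLayer` predicts a per-sample shift/scale from `adaln_input` (camera embeddings in this codebase) and applies it via `modulate` before the attention and MLP sub-blocks. A tiny numeric illustration of `modulate` alone (dummy values):

```python
import torch

x = torch.ones(2, 5, 8)             # [N, L, D] token features
shift = torch.full((2, 8), 0.5)     # per-sample shift predicted by adaLN_modulation
scale = torch.full((2, 8), 0.25)    # per-sample scale predicted by adaLN_modulation
out = modulate(x, shift, scale)     # x * (1 + scale) + shift, broadcast over the token dimension
print(out[0, 0, 0].item())          # 1.75
```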
models/lrm/models/encoder/dino_wrapper.py ADDED
@@ -0,0 +1,80 @@
1
+ # Copyright (c) 2023, Zexin He
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch.nn as nn
17
+ from transformers import ViTImageProcessor
18
+ from einops import rearrange, repeat
19
+ from .dino import ViTModel
20
+
21
+
22
+ class DinoWrapper(nn.Module):
23
+ """
24
+ Dino v1 wrapper using huggingface transformer implementation.
25
+ """
26
+ def __init__(self, model_name: str, freeze: bool = True):
27
+ super().__init__()
28
+ self.model, self.processor = self._build_dino(model_name)
29
+ self.camera_embedder = nn.Sequential(
30
+ nn.Linear(16, self.model.config.hidden_size, bias=True),
31
+ nn.SiLU(),
32
+ nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
33
+ )
34
+ if freeze:
35
+ self._freeze()
36
+
37
+ def forward(self, image, camera):
38
+ # image: [B, N, C, H, W]
39
+ # camera: [B, N, D]
40
+ # RGB image with [0,1] scale and properly sized
41
+ if image.ndim == 5:
42
+ image = rearrange(image, 'b n c h w -> (b n) c h w')
43
+ dtype = image.dtype
44
+ inputs = self.processor(
45
+ images=image.float(),
46
+ return_tensors="pt",
47
+ do_rescale=False,
48
+ do_resize=False,
49
+ ).to(self.model.device).to(dtype)
50
+ # embed camera
51
+ N = camera.shape[1]
52
+ camera_embeddings = self.camera_embedder(camera)
53
+ camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
54
+ embeddings = camera_embeddings
55
+ # This resampling of positional embedding uses bicubic interpolation
56
+ outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
57
+ last_hidden_states = outputs.last_hidden_state
58
+ return last_hidden_states
59
+
60
+ def _freeze(self):
61
+ print(f"======== Freezing DinoWrapper ========")
62
+ self.model.eval()
63
+ for name, param in self.model.named_parameters():
64
+ param.requires_grad = False
65
+
66
+ @staticmethod
67
+ def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
68
+ import requests
69
+ try:
70
+ model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
71
+ processor = ViTImageProcessor.from_pretrained(model_name)
72
+ return model, processor
73
+ except requests.exceptions.ProxyError as err:
74
+ if proxy_error_retries > 0:
75
+ print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
76
+ import time
77
+ time.sleep(proxy_error_cooldown)
78
+ return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
79
+ else:
80
+ raise err
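A usage sketch for `DinoWrapper` (it downloads the `facebook/dino-vitb16` weights named in the config; the batch/view sizes and random tensors below are placeholders):

```python
import torch

wrapper = DinoWrapper(model_name="facebook/dino-vitb16", freeze=True)
images = torch.rand(1, 6, 3, 224, 224)   # [B, N, C, H, W], RGB in [0, 1]
cameras = torch.rand(1, 6, 16)           # flattened 4x4 poses; camera_embedder expects 16-D input
tokens = wrapper(images, cameras)
print(tokens.shape)                      # [B*N, num_tokens, hidden] = [6, 197, 768]
```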
models/lrm/models/geometry/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
models/lrm/models/geometry/camera/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+
13
+ class Camera(nn.Module):
14
+ def __init__(self):
15
+ super(Camera, self).__init__()
16
+ pass
models/lrm/models/geometry/camera/perspective_camera.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from . import Camera
11
+ import numpy as np
12
+
13
+
14
+ def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
15
+ if near_plane is None:
16
+ near_plane = n
17
+ return np.array(
18
+ [[n / x, 0, 0, 0],
19
+ [0, n / -x, 0, 0],
20
+ [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
21
+ [0, 0, -1, 0]]).astype(np.float32)
22
+
23
+
24
+ class PerspectiveCamera(Camera):
25
+ def __init__(self, fovy=49.0, device='cuda'):
26
+ super(PerspectiveCamera, self).__init__()
27
+ self.device = device
28
+ focal = np.tan(fovy / 180.0 * np.pi * 0.5)
29
+ self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)
30
+
31
+ def project(self, points_bxnx4):
32
+ out = torch.matmul(
33
+ points_bxnx4,
34
+ torch.transpose(self.proj_mtx, 1, 2))
35
+ return out
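A small CPU-only check of the projection (the device is overridden so the sketch runs without CUDA): a camera-space point two units in front of the camera is mapped to clip space, then to NDC by the perspective divide.

```python
import torch

cam = PerspectiveCamera(fovy=49.0, device='cpu')
points = torch.tensor([[[0.0, 0.0, -2.0, 1.0]]])    # [B, N, 4] homogeneous camera-space point
clip = cam.project(points)                          # [1, 1, 4] clip-space coordinates
ndc = clip[..., :3] / clip[..., 3:4]                # perspective divide; on-axis point maps to x = y = 0
print(ndc)
```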
models/lrm/models/geometry/render/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ import torch
2
+
3
+ class Renderer():
4
+ def __init__(self):
5
+ pass
6
+
7
+ def forward(self):
8
+ pass
models/lrm/models/geometry/render/neural_render.py ADDED
@@ -0,0 +1,293 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import nvdiffrast.torch as dr
12
+ from . import Renderer
13
+ from . import util
14
+ from . import renderutils as ru
15
+ _FG_LUT = None
16
+
17
+
18
+ def interpolate(attr, rast, attr_idx, rast_db=None):
19
+ return dr.interpolate(
20
+ attr.contiguous(), rast, attr_idx, rast_db=rast_db,
21
+ diff_attrs=None if rast_db is None else 'all')
22
+
23
+
24
+ def xfm_points(points, matrix, use_python=True):
25
+ '''Transform points.
26
+ Args:
27
+ points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
28
+ matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
29
+ use_python: Use PyTorch's torch.matmul (for validation)
30
+ Returns:
31
+ Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
32
+ '''
33
+ out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
34
+ if torch.is_anomaly_enabled():
35
+ assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
36
+ return out
37
+
38
+
39
+ def dot(x, y):
40
+ return torch.sum(x * y, -1, keepdim=True)
41
+
42
+
43
+ def compute_vertex_normal(v_pos, t_pos_idx):
44
+ i0 = t_pos_idx[:, 0]
45
+ i1 = t_pos_idx[:, 1]
46
+ i2 = t_pos_idx[:, 2]
47
+
48
+ v0 = v_pos[i0, :]
49
+ v1 = v_pos[i1, :]
50
+ v2 = v_pos[i2, :]
51
+
52
+ face_normals = torch.cross(v1 - v0, v2 - v0)
53
+
54
+ # Splat face normals to vertices
55
+ v_nrm = torch.zeros_like(v_pos)
56
+ v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
57
+ v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
58
+ v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
59
+
60
+ # Normalize, replace zero (degenerated) normals with some default value
61
+ v_nrm = torch.where(
62
+ dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
63
+ )
64
+ v_nrm = F.normalize(v_nrm, dim=1)
65
+ assert torch.all(torch.isfinite(v_nrm))
66
+
67
+ return v_nrm
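A sanity check for `compute_vertex_normal` on a single counter-clockwise triangle in the z = 0 plane; every vertex normal should come out as +z:

```python
import torch

v_pos = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
t_pos_idx = torch.tensor([[0, 1, 2]])
print(compute_vertex_normal(v_pos, t_pos_idx))   # three rows, each approximately [0., 0., 1.]
```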
68
+
69
+
70
+ class NeuralRender(Renderer):
71
+ def __init__(self, device='cuda', camera_model=None):
72
+ super(NeuralRender, self).__init__()
73
+ self.device = device
74
+ self.ctx = dr.RasterizeCudaContext(device=device)
75
+ self.projection_mtx = None
76
+ self.camera = camera_model
77
+
78
+ # ==============================================================================================
79
+ # pixel shader
80
+ # ==============================================================================================
81
+ # def shade(
82
+ # self,
83
+ # gb_pos,
84
+ # gb_geometric_normal,
85
+ # gb_normal,
86
+ # gb_tangent,
87
+ # gb_texc,
88
+ # gb_texc_deriv,
89
+ # view_pos,
90
+ # ):
91
+
92
+ # ################################################################################
93
+ # # Texture lookups
94
+ # ################################################################################
95
+ # breakpoint()
96
+ # # Separate kd into alpha and color, default alpha = 1
97
+ # alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
98
+ # kd = kd[..., 0:3]
99
+
100
+ # ################################################################################
101
+ # # Normal perturbation & normal bend
102
+ # ################################################################################
103
+
104
+ # perturbed_nrm = None
105
+
106
+ # gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
107
+
108
+ # ################################################################################
109
+ # # Evaluate BSDF
110
+ # ################################################################################
111
+
112
+ # assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
113
+ # bsdf = material['bsdf'] if bsdf is None else bsdf
114
+ # if bsdf == 'pbr':
115
+ # if isinstance(lgt, light.EnvironmentLight):
116
+ # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
117
+ # else:
118
+ # assert False, "Invalid light type"
119
+ # elif bsdf == 'diffuse':
120
+ # if isinstance(lgt, light.EnvironmentLight):
121
+ # shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False)
122
+ # else:
123
+ # assert False, "Invalid light type"
124
+ # elif bsdf == 'normal':
125
+ # shaded_col = (gb_normal + 1.0)*0.5
126
+ # elif bsdf == 'tangent':
127
+ # shaded_col = (gb_tangent + 1.0)*0.5
128
+ # elif bsdf == 'kd':
129
+ # shaded_col = kd
130
+ # elif bsdf == 'ks':
131
+ # shaded_col = ks
132
+ # else:
133
+ # assert False, "Invalid BSDF '%s'" % bsdf
134
+
135
+ # # Return multiple buffers
136
+ # buffers = {
137
+ # 'shaded' : torch.cat((shaded_col, alpha), dim=-1),
138
+ # 'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
139
+ # 'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1)
140
+ # }
141
+ # return buffers
142
+
143
+ # ==============================================================================================
144
+ # Render a depth slice of the mesh (scene), some limitations:
145
+ # - Single mesh
146
+ # - Single light
147
+ # - Single material
148
+ # ==============================================================================================
149
+ def render_layer(
150
+ self,
151
+ rast,
152
+ rast_deriv,
153
+ mesh,
154
+ view_pos,
155
+ resolution,
156
+ spp,
157
+ msaa
158
+ ):
159
+
160
+ # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
161
+ rast_out_s = rast
162
+ rast_out_deriv_s = rast_deriv
163
+
164
+ ################################################################################
165
+ # Interpolate attributes
166
+ ################################################################################
167
+
168
+ # Interpolate world space position
169
+ gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())
170
+
171
+ # Compute geometric normals. We need these for the bent-normals trick (used for bump mapping)
172
+ v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
173
+ v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
174
+ v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
175
+ face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0, dim=-1))
176
+ face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
177
+ gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())
178
+
179
+ # Compute tangent space
180
+ assert mesh.v_nrm is not None and mesh.v_tng is not None
181
+ gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
182
+ gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents
183
+
184
+ # Texture coordinate
185
+ # assert mesh.v_tex is not None
186
+ # gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)
187
+ perturbed_nrm = None
188
+ gb_normal = ru.prepare_shading_normal(gb_pos, view_pos[:,None,None,:], perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)
189
+
190
+ return gb_pos, gb_normal
191
+
192
+ def render_mesh(
193
+ self,
194
+ mesh_v_pos_bxnx3,
195
+ mesh_t_pos_idx_fx3,
196
+ mesh,
197
+ camera_mv_bx4x4,
198
+ camera_pos,
199
+ mesh_v_feat_bxnxd,
200
+ resolution=256,
201
+ spp=1,
202
+ device='cuda',
203
+ hierarchical_mask=False
204
+ ):
205
+ assert not hierarchical_mask
206
+
207
+ mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
208
+ v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
209
+ v_pos_clip = self.camera.project(v_pos) # Projection in the camera
210
+
211
+ # view_pos = torch.linalg.inv(mtx_in)[:, :3, 3]
212
+ view_pos = camera_pos
213
+ v_nrm = mesh.v_nrm #compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
214
+
215
+ # Render the image,
216
+ # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
217
+ num_layers = 1
218
+ mask_pyramid = None
219
+ assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
220
+
221
+ mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos [org_pos, clip-space pos for rasterization]
222
+
223
+ layers = []
224
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution * spp, resolution * spp]) as peeler:
225
+ for _ in range(num_layers):
226
+ rast, db = peeler.rasterize_next_layer()
227
+ gb_pos, gb_normal = self.render_layer(rast, db, mesh, view_pos, resolution, spp, msaa=False)
228
+
229
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
230
+ for _ in range(num_layers):
231
+ rast, db = peeler.rasterize_next_layer()
232
+ gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
233
+
234
+ hard_mask = torch.clamp(rast[..., -1:], 0, 1)
235
+ antialias_mask = dr.antialias(
236
+ hard_mask.clone().contiguous(), rast, v_pos_clip,
237
+ mesh_t_pos_idx_fx3)
238
+
239
+ depth = gb_feat[..., -2:-1]
240
+ ori_mesh_feature = gb_feat[..., :-4]
241
+
242
+ normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
243
+ normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
244
+ # normal = F.normalize(normal, dim=-1)
245
+ # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
246
+ return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal
247
+
248
+ def render_mesh_light(
249
+ self,
250
+ mesh_v_pos_bxnx3,
251
+ mesh_t_pos_idx_fx3,
252
+ mesh,
253
+ camera_mv_bx4x4,
254
+ mesh_v_feat_bxnxd,
255
+ resolution=256,
256
+ spp=1,
257
+ device='cuda',
258
+ hierarchical_mask=False
259
+ ):
260
+ assert not hierarchical_mask
261
+
262
+ mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
263
+ v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates
264
+ v_pos_clip = self.camera.project(v_pos) # Projection in the camera
265
+
266
+ v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates
267
+
268
+ # Render the image,
269
+ # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
270
+ num_layers = 1
271
+ mask_pyramid = None
272
+ assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes
273
+ mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos
274
+
275
+ with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
276
+ for _ in range(num_layers):
277
+ rast, db = peeler.rasterize_next_layer()
278
+ gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)
279
+
280
+ hard_mask = torch.clamp(rast[..., -1:], 0, 1)
281
+ antialias_mask = dr.antialias(
282
+ hard_mask.clone().contiguous(), rast, v_pos_clip,
283
+ mesh_t_pos_idx_fx3)
284
+
285
+ depth = gb_feat[..., -2:-1]
286
+ ori_mesh_feature = gb_feat[..., :-4]
287
+
288
+ normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
289
+ normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
290
+ normal = F.normalize(normal, dim=-1)
291
+ normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
292
+
293
+ return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
models/lrm/models/geometry/render/renderutils/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
11
+ __all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
models/lrm/models/geometry/render/renderutils/bsdf.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
4
+ # property and proprietary rights in and to this material, related
5
+ # documentation and any modifications thereto. Any use, reproduction,
6
+ # disclosure or distribution of this material and related documentation
7
+ # without an express license agreement from NVIDIA CORPORATION or
8
+ # its affiliates is strictly prohibited.
9
+
10
+ import math
11
+ import torch
12
+
13
+ NORMAL_THRESHOLD = 0.1
14
+
15
+ ################################################################################
16
+ # Vector utility functions
17
+ ################################################################################
18
+
19
+ def _dot(x, y):
20
+ return torch.sum(x*y, -1, keepdim=True)
21
+
22
+ def _reflect(x, n):
23
+ return 2*_dot(x, n)*n - x
24
+
25
+ def _safe_normalize(x):
26
+ return torch.nn.functional.normalize(x, dim = -1)
27
+
28
+ def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
29
+ # Swap normal direction for backfacing surfaces
30
+ if two_sided_shading:
31
+ smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
32
+ geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)
33
+
34
+ t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
35
+ return torch.lerp(geom_nrm, smooth_nrm, t)
36
+
37
+
38
+ def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
39
+ smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
40
+ if opengl:
41
+ shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
42
+ else:
43
+ shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
44
+ return _safe_normalize(shading_nrm)
45
+
46
+ def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
47
+ smooth_nrm = _safe_normalize(smooth_nrm)
48
+ smooth_tng = _safe_normalize(smooth_tng)
49
+ view_vec = _safe_normalize(view_pos - pos)
50
+ shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
51
+ return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)
52
+
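A hypothetical single-point example (not part of the patch) of how bsdf_prepare_shading_normal behaves with an identity perturbation, i.e. a tangent-space normal of (0, 0, 1) and aligned smooth/geometric normals:

import torch

pos        = torch.zeros(1, 3)
view_pos   = torch.tensor([[0.0, 0.0, 2.0]])
perturbed  = torch.tensor([[0.0, 0.0, 1.0]])   # identity perturbation (no bump)
smooth_nrm = torch.tensor([[0.0, 0.0, 1.0]])
smooth_tng = torch.tensor([[1.0, 0.0, 0.0]])
geom_nrm   = torch.tensor([[0.0, 0.0, 1.0]])

n = bsdf_prepare_shading_normal(pos, view_pos, perturbed, smooth_nrm,
                                smooth_tng, geom_nrm,
                                two_sided_shading=True, opengl=True)
# The view direction is well above NORMAL_THRESHOLD, so the smooth normal wins the lerp.
assert torch.allclose(n, torch.tensor([[0.0, 0.0, 1.0]]))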
53
+ ################################################################################
54
+ # Simple lambertian diffuse BSDF
55
+ ################################################################################
56
+
57
+ def bsdf_lambert(nrm, wi):
58
+ return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi
59
+
60
+ ################################################################################
61
+ # Frostbite diffuse
62
+ ################################################################################
63
+
64
+ def bsdf_frostbite(nrm, wi, wo, linearRoughness):
65
+ wiDotN = _dot(wi, nrm)
66
+ woDotN = _dot(wo, nrm)
67
+
68
+ h = _safe_normalize(wo + wi)
69
+ wiDotH = _dot(wi, h)
70
+
71
+ energyBias = 0.5 * linearRoughness
72
+ energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
73
+ f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
74
+ f0 = 1.0
75
+
76
+ wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
77
+ woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
78
+ res = wiScatter * woScatter * energyFactor
79
+ return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))
80
+
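For orientation (an illustrative sketch only, assuming the functions above are in scope): bsdf_lambert peaks at 1/pi for aligned normal and light, while the Frostbite term approaches 1 at zero roughness because both Fresnel scatter factors go to 1.

import math
import torch

nrm = torch.tensor([[0.0, 0.0, 1.0]])
wi  = torch.tensor([[0.0, 0.0, 1.0]])
wo  = torch.tensor([[0.0, 0.0, 1.0]])

# Clamped cosine over pi for a head-on light.
assert torch.allclose(bsdf_lambert(nrm, wi), torch.tensor([[1.0 / math.pi]]))

# Frostbite scatter factor at zero linear roughness.
frostbite = bsdf_frostbite(nrm, wi, wo, torch.tensor([[0.0]]))
assert torch.allclose(frostbite, torch.ones(1, 1), atol=1e-4)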
81
+ ################################################################################
82
+ # Phong specular, loosely based on the Mitsuba implementation
83
+ ################################################################################
84
+
85
+ def bsdf_phong(nrm, wo, wi, N):
86
+ dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
87
+ dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
88
+ return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)
89
+
90
+ ################################################################################
91
+ # PBR's implementation of GGX specular
92
+ ################################################################################
93
+
94
+ specular_epsilon = 1e-4
95
+
96
+ def bsdf_fresnel_shlick(f0, f90, cosTheta):
97
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
98
+ return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0
99
+
100
+ def bsdf_ndf_ggx(alphaSqr, cosTheta):
101
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
102
+ d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
103
+ return alphaSqr / (d * d * math.pi)
104
+
105
+ def bsdf_lambda_ggx(alphaSqr, cosTheta):
106
+ _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
107
+ cosThetaSqr = _cosTheta * _cosTheta
108
+ tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
109
+ res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
110
+ return res
111
+
112
+ def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
113
+ lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
114
+ lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
115
+ return 1 / (1 + lambdaI + lambdaO)
116
+
117
+ def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
118
+ _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
119
+ alphaSqr = _alpha * _alpha
120
+
121
+ h = _safe_normalize(wo + wi)
122
+ woDotN = _dot(wo, nrm)
123
+ wiDotN = _dot(wi, nrm)
124
+ woDotH = _dot(wo, h)
125
+ nDotH = _dot(nrm, h)
126
+
127
+ D = bsdf_ndf_ggx(alphaSqr, nDotH)
128
+ G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
129
+ F = bsdf_fresnel_shlick(col, 1, woDotH)
130
+
131
+ w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)
132
+
133
+ frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
134
+ return torch.where(frontfacing, w, torch.zeros_like(w))
135
+
136
+ def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
137
+ wo = _safe_normalize(view_pos - pos)
138
+ wi = _safe_normalize(light_pos - pos)
139
+
140
+ spec_str = arm[..., 0:1] # x component
141
+ roughness = arm[..., 1:2] # y component
142
+ metallic = arm[..., 2:3] # z component
143
+ ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
144
+ kd = kd * (1.0 - metallic)
145
+
146
+ if BSDF == 0:
147
+ diffuse = kd * bsdf_lambert(nrm, wi)
148
+ else:
149
+ diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
150
+ specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
151
+ return diffuse + specular
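End to end, bsdf_pbr combines the diffuse and specular lobes defined above. A hypothetical single-point call (shapes and values chosen only for illustration; arm is consumed as (spec_str, roughness, metallic) per the comments above):

import torch

kd        = torch.tensor([[0.8, 0.2, 0.2]])
arm       = torch.tensor([[1.0, 0.5, 0.0]])
pos       = torch.zeros(1, 3)
nrm       = torch.tensor([[0.0, 0.0, 1.0]])
view_pos  = torch.tensor([[0.0, 0.0, 3.0]])
light_pos = torch.tensor([[0.0, 0.0, 3.0]])

rgb = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, BSDF=0)
print(rgb.shape)   # torch.Size([1, 3])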
models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu ADDED
@@ -0,0 +1,710 @@
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include "common.h"
13
+ #include "bsdf.h"
14
+
15
+ #define SPECULAR_EPSILON 1e-4f
16
+
17
+ //------------------------------------------------------------------------
18
+ // Lambert functions
19
+
20
+ __device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
21
+ {
22
+ return max(dot(nrm, wi) / M_PI, 0.0f);
23
+ }
24
+
25
+ __device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
26
+ {
27
+ if (dot(nrm, wi) > 0.0f)
28
+ bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
29
+ }
30
+
31
+ //------------------------------------------------------------------------
32
+ // Fresnel Schlick
33
+
34
+ __device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
35
+ {
36
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
37
+ float scale = powf(1.0f - _cosTheta, 5.0f);
38
+ return f0 * (1.0f - scale) + f90 * scale;
39
+ }
40
+
41
+ __device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
42
+ {
43
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
44
+ float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
45
+ d_f0 += d_out * (1.0 - scale);
46
+ d_f90 += d_out * scale;
47
+ if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
48
+ {
49
+ d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
50
+ }
51
+ }
52
+
53
+ __device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
54
+ {
55
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
56
+ float scale = powf(1.0f - _cosTheta, 5.0f);
57
+ return f0 * (1.0f - scale) + f90 * scale;
58
+ }
59
+
60
+ __device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
61
+ {
62
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
63
+ float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
64
+ d_f0 += d_out * (1.0 - scale);
65
+ d_f90 += d_out * scale;
66
+ if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
67
+ {
68
+ d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
69
+ }
70
+ }
71
+
72
+ //------------------------------------------------------------------------
73
+ // Frostbite diffuse
74
+
75
+ __device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
76
+ {
77
+ float wiDotN = dot(wi, nrm);
78
+ float woDotN = dot(wo, nrm);
79
+ if (wiDotN > 0.0f && woDotN > 0.0f)
80
+ {
81
+ vec3f h = safeNormalize(wo + wi);
82
+ float wiDotH = dot(wi, h);
83
+
84
+ float energyBias = 0.5f * linearRoughness;
85
+ float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
86
+ float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
87
+ float f0 = 1.f;
88
+
89
+ float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
90
+ float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
91
+
92
+ return wiScatter * woScatter * energyFactor;
93
+ }
94
+ else return 0.0f;
95
+ }
96
+
97
+ __device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
98
+ {
99
+ float wiDotN = dot(wi, nrm);
100
+ float woDotN = dot(wo, nrm);
101
+
102
+ if (wiDotN > 0.0f && woDotN > 0.0f)
103
+ {
104
+ vec3f h = safeNormalize(wo + wi);
105
+ float wiDotH = dot(wi, h);
106
+
107
+ float energyBias = 0.5f * linearRoughness;
108
+ float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
109
+ float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
110
+ float f0 = 1.f;
111
+
112
+ float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
113
+ float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
114
+
115
+ // -------------- BWD --------------
116
+ // Backprop: return wiScatter * woScatter * energyFactor;
117
+ float d_wiScatter = d_out * woScatter * energyFactor;
118
+ float d_woScatter = d_out * wiScatter * energyFactor;
119
+ float d_energyFactor = d_out * wiScatter * woScatter;
120
+
121
+ // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
122
+ float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
123
+ bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);
124
+
125
+ // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
126
+ float d_wiDotN = 0.0f;
127
+ bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);
128
+
129
+ // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
130
+ float d_energyBias = d_f90;
131
+ float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
132
+ d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;
133
+
134
+ // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
135
+ d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;
136
+
137
+ // Backprop: float energyBias = 0.5f * linearRoughness;
138
+ d_linearRoughness += 0.5 * d_energyBias;
139
+
140
+ // Backprop: float wiDotH = dot(wi, h);
141
+ vec3f d_h(0);
142
+ bwdDot(wi, h, d_wi, d_h, d_wiDotH);
143
+
144
+ // Backprop: vec3f h = safeNormalize(wo + wi);
145
+ vec3f d_wo_wi(0);
146
+ bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
147
+ d_wi += d_wo_wi; d_wo += d_wo_wi;
148
+
149
+ bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
150
+ bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
151
+ }
152
+ }
153
+
154
+ //------------------------------------------------------------------------
155
+ // Ndf GGX
156
+
157
+ __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
158
+ {
159
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
160
+ float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
161
+ return alphaSqr / (d * d * M_PI);
162
+ }
163
+
164
+ __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
165
+ {
166
+ // Torch only back propagates if clamp doesn't trigger
167
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
168
+ float cosThetaSqr = _cosTheta * _cosTheta;
169
+ d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
170
+ if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
171
+ {
172
+ d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
173
+ }
174
+ }
175
+
176
+ //------------------------------------------------------------------------
177
+ // Lambda GGX
178
+
179
+ __device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
180
+ {
181
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
182
+ float cosThetaSqr = _cosTheta * _cosTheta;
183
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
184
+ float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
185
+ return res;
186
+ }
187
+
188
+ __device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
189
+ {
190
+ float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
191
+ float cosThetaSqr = _cosTheta * _cosTheta;
192
+ float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
193
+ float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
194
+
195
+ d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
196
+ if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
197
+ d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
198
+ }
199
+
200
+ //------------------------------------------------------------------------
201
+ // Masking GGX
202
+
203
+ __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
204
+ {
205
+ float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
206
+ float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
207
+ return 1.0f / (1.0f + lambdaI + lambdaO);
208
+ }
209
+
210
+ __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
211
+ {
212
+ // FWD eval
213
+ float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
214
+ float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
215
+
216
+ // BWD eval
217
+ float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
218
+ bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
219
+ bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
220
+ }
221
+
222
+ //------------------------------------------------------------------------
223
+ // GGX specular
224
+
225
+ __device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
226
+ {
227
+ float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
228
+ float alphaSqr = _alpha * _alpha;
229
+
230
+ vec3f h = safeNormalize(wo + wi);
231
+ float woDotN = dot(wo, nrm);
232
+ float wiDotN = dot(wi, nrm);
233
+ float woDotH = dot(wo, h);
234
+ float nDotH = dot(nrm, h);
235
+
236
+ float D = fwdNdfGGX(alphaSqr, nDotH);
237
+ float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
238
+ vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
239
+ vec3f w = F * D * G * 0.25 / woDotN;
240
+
241
+ bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
242
+ return frontfacing ? w : 0.0f;
243
+ }
244
+
245
+ __device__ void bwdPbrSpecular(
246
+ const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
247
+ vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
248
+ {
249
+ ///////////////////////////////////////////////////////////////////////
250
+ // FWD eval
251
+
252
+ float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
253
+ float alphaSqr = _alpha * _alpha;
254
+
255
+ vec3f h = safeNormalize(wo + wi);
256
+ float woDotN = dot(wo, nrm);
257
+ float wiDotN = dot(wi, nrm);
258
+ float woDotH = dot(wo, h);
259
+ float nDotH = dot(nrm, h);
260
+
261
+ float D = fwdNdfGGX(alphaSqr, nDotH);
262
+ float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
263
+ vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
264
+ vec3f w = F * D * G * 0.25 / woDotN;
265
+ bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
266
+
267
+ if (frontfacing)
268
+ {
269
+ ///////////////////////////////////////////////////////////////////////
270
+ // BWD eval
271
+
272
+ vec3f d_F = d_out * D * G * 0.25f / woDotN;
273
+ float d_D = sum(d_out * F * G * 0.25f / woDotN);
274
+ float d_G = sum(d_out * F * D * 0.25f / woDotN);
275
+
276
+ float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
277
+
278
+ vec3f d_f90(0);
279
+ float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
280
+ bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
281
+ bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
282
+ bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
283
+
284
+ vec3f d_h(0);
285
+ bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
286
+ bwdDot(wo, h, d_wo, d_h, d_woDotH);
287
+ bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
288
+ bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
289
+
290
+ vec3f d_h_unnorm(0);
291
+ bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
292
+ d_wo += d_h_unnorm;
293
+ d_wi += d_h_unnorm;
294
+
295
+ if (alpha > min_roughness * min_roughness)
296
+ d_alpha += d_alphaSqr * 2 * alpha;
297
+ }
298
+ }
299
+
300
+ //------------------------------------------------------------------------
301
+ // Full PBR BSDF
302
+
303
+ __device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
304
+ {
305
+ vec3f wo = safeNormalize(view_pos - pos);
306
+ vec3f wi = safeNormalize(light_pos - pos);
307
+
308
+ float alpha = arm.y * arm.y;
309
+ vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
310
+ vec3f diff_col = kd * (1.0f - arm.z);
311
+
312
+ float diff = 0.0f;
313
+ if (BSDF == 0)
314
+ diff = fwdLambert(nrm, wi);
315
+ else
316
+ diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
317
+ vec3f diffuse = diff_col * diff;
318
+ vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
319
+
320
+ return diffuse + specular;
321
+ }
322
+
323
+ __device__ void bwdPbrBSDF(
324
+ const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
325
+ vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
326
+ {
327
+ ////////////////////////////////////////////////////////////////////////
328
+ // FWD
329
+ vec3f _wi = light_pos - pos;
330
+ vec3f _wo = view_pos - pos;
331
+ vec3f wi = safeNormalize(_wi);
332
+ vec3f wo = safeNormalize(_wo);
333
+
334
+ float alpha = arm.y * arm.y;
335
+ vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
336
+ vec3f diff_col = kd * (1.0f - arm.z);
337
+ float diff = 0.0f;
338
+ if (BSDF == 0)
339
+ diff = fwdLambert(nrm, wi);
340
+ else
341
+ diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
342
+
343
+ ////////////////////////////////////////////////////////////////////////
344
+ // BWD
345
+
346
+ float d_alpha(0);
347
+ vec3f d_spec_col(0), d_wi(0), d_wo(0);
348
+ bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
349
+
350
+ float d_diff = sum(diff_col * d_out);
351
+ if (BSDF == 0)
352
+ bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
353
+ else
354
+ bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
355
+
356
+ // Backprop: diff_col = kd * (1.0f - arm.z)
357
+ vec3f d_diff_col = d_out * diff;
358
+ d_kd += d_diff_col * (1.0f - arm.z);
359
+ d_arm.z -= sum(d_diff_col * kd);
360
+
361
+ // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
362
+ d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
363
+ d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
364
+ d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
365
+
366
+ // Backprop: alpha = arm.y * arm.y
367
+ d_arm.y += d_alpha * 2 * arm.y;
368
+
369
+ // Backprop: vec3f wi = safeNormalize(light_pos - pos);
370
+ vec3f d__wi(0);
371
+ bwdSafeNormalize(_wi, d__wi, d_wi);
372
+ d_light_pos += d__wi;
373
+ d_pos -= d__wi;
374
+
375
+ // Backprop: vec3f wo = safeNormalize(view_pos - pos);
376
+ vec3f d__wo(0);
377
+ bwdSafeNormalize(_wo, d__wo, d_wo);
378
+ d_view_pos += d__wo;
379
+ d_pos -= d__wo;
380
+ }
381
+
382
+ //------------------------------------------------------------------------
383
+ // Kernels
384
+
385
+ __global__ void LambertFwdKernel(LambertKernelParams p)
386
+ {
387
+ // Calculate pixel position.
388
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
389
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
390
+ unsigned int pz = blockIdx.z;
391
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
392
+ return;
393
+
394
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
395
+ vec3f wi = p.wi.fetch3(px, py, pz);
396
+
397
+ float res = fwdLambert(nrm, wi);
398
+
399
+ p.out.store(px, py, pz, res);
400
+ }
401
+
402
+ __global__ void LambertBwdKernel(LambertKernelParams p)
403
+ {
404
+ // Calculate pixel position.
405
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
406
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
407
+ unsigned int pz = blockIdx.z;
408
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
409
+ return;
410
+
411
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
412
+ vec3f wi = p.wi.fetch3(px, py, pz);
413
+ float d_out = p.out.fetch1(px, py, pz);
414
+
415
+ vec3f d_nrm(0), d_wi(0);
416
+ bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
417
+
418
+ p.nrm.store_grad(px, py, pz, d_nrm);
419
+ p.wi.store_grad(px, py, pz, d_wi);
420
+ }
421
+
422
+ __global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
423
+ {
424
+ // Calculate pixel position.
425
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
426
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
427
+ unsigned int pz = blockIdx.z;
428
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
429
+ return;
430
+
431
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
432
+ vec3f wi = p.wi.fetch3(px, py, pz);
433
+ vec3f wo = p.wo.fetch3(px, py, pz);
434
+ float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
435
+
436
+ float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
437
+
438
+ p.out.store(px, py, pz, res);
439
+ }
440
+
441
+ __global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
442
+ {
443
+ // Calculate pixel position.
444
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
445
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
446
+ unsigned int pz = blockIdx.z;
447
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
448
+ return;
449
+
450
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
451
+ vec3f wi = p.wi.fetch3(px, py, pz);
452
+ vec3f wo = p.wo.fetch3(px, py, pz);
453
+ float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
454
+ float d_out = p.out.fetch1(px, py, pz);
455
+
456
+ float d_linearRoughness = 0.0f;
457
+ vec3f d_nrm(0), d_wi(0), d_wo(0);
458
+ bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
459
+
460
+ p.nrm.store_grad(px, py, pz, d_nrm);
461
+ p.wi.store_grad(px, py, pz, d_wi);
462
+ p.wo.store_grad(px, py, pz, d_wo);
463
+ p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
464
+ }
465
+
466
+ __global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
467
+ {
468
+ // Calculate pixel position.
469
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
470
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
471
+ unsigned int pz = blockIdx.z;
472
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
473
+ return;
474
+
475
+ vec3f f0 = p.f0.fetch3(px, py, pz);
476
+ vec3f f90 = p.f90.fetch3(px, py, pz);
477
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
478
+
479
+ vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
480
+ p.out.store(px, py, pz, res);
481
+ }
482
+
483
+ __global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
484
+ {
485
+ // Calculate pixel position.
486
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
487
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
488
+ unsigned int pz = blockIdx.z;
489
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
490
+ return;
491
+
492
+ vec3f f0 = p.f0.fetch3(px, py, pz);
493
+ vec3f f90 = p.f90.fetch3(px, py, pz);
494
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
495
+ vec3f d_out = p.out.fetch3(px, py, pz);
496
+
497
+ vec3f d_f0(0), d_f90(0);
498
+ float d_cosTheta(0);
499
+ bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
500
+
501
+ p.f0.store_grad(px, py, pz, d_f0);
502
+ p.f90.store_grad(px, py, pz, d_f90);
503
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
504
+ }
505
+
506
+ __global__ void ndfGGXFwdKernel(NdfGGXParams p)
507
+ {
508
+ // Calculate pixel position.
509
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
510
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
511
+ unsigned int pz = blockIdx.z;
512
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
513
+ return;
514
+
515
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
516
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
517
+ float res = fwdNdfGGX(alphaSqr, cosTheta);
518
+
519
+ p.out.store(px, py, pz, res);
520
+ }
521
+
522
+ __global__ void ndfGGXBwdKernel(NdfGGXParams p)
523
+ {
524
+ // Calculate pixel position.
525
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
526
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
527
+ unsigned int pz = blockIdx.z;
528
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
529
+ return;
530
+
531
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
532
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
533
+ float d_out = p.out.fetch1(px, py, pz);
534
+
535
+ float d_alphaSqr(0), d_cosTheta(0);
536
+ bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
537
+
538
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
539
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
540
+ }
541
+
542
+ __global__ void lambdaGGXFwdKernel(NdfGGXParams p)
543
+ {
544
+ // Calculate pixel position.
545
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
546
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
547
+ unsigned int pz = blockIdx.z;
548
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
549
+ return;
550
+
551
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
552
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
553
+ float res = fwdLambdaGGX(alphaSqr, cosTheta);
554
+
555
+ p.out.store(px, py, pz, res);
556
+ }
557
+
558
+ __global__ void lambdaGGXBwdKernel(NdfGGXParams p)
559
+ {
560
+ // Calculate pixel position.
561
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
562
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
563
+ unsigned int pz = blockIdx.z;
564
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
565
+ return;
566
+
567
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
568
+ float cosTheta = p.cosTheta.fetch1(px, py, pz);
569
+ float d_out = p.out.fetch1(px, py, pz);
570
+
571
+ float d_alphaSqr(0), d_cosTheta(0);
572
+ bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
573
+
574
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
575
+ p.cosTheta.store_grad(px, py, pz, d_cosTheta);
576
+ }
577
+
578
+ __global__ void maskingSmithFwdKernel(MaskingSmithParams p)
579
+ {
580
+ // Calculate pixel position.
581
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
582
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
583
+ unsigned int pz = blockIdx.z;
584
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
585
+ return;
586
+
587
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
588
+ float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
589
+ float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
590
+ float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
591
+
592
+ p.out.store(px, py, pz, res);
593
+ }
594
+
595
+ __global__ void maskingSmithBwdKernel(MaskingSmithParams p)
596
+ {
597
+ // Calculate pixel position.
598
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
599
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
600
+ unsigned int pz = blockIdx.z;
601
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
602
+ return;
603
+
604
+ float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
605
+ float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
606
+ float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
607
+ float d_out = p.out.fetch1(px, py, pz);
608
+
609
+ float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
610
+ bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
611
+
612
+ p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
613
+ p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
614
+ p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
615
+ }
616
+
617
+ __global__ void pbrSpecularFwdKernel(PbrSpecular p)
618
+ {
619
+ // Calculate pixel position.
620
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
621
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
622
+ unsigned int pz = blockIdx.z;
623
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
624
+ return;
625
+
626
+ vec3f col = p.col.fetch3(px, py, pz);
627
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
628
+ vec3f wo = p.wo.fetch3(px, py, pz);
629
+ vec3f wi = p.wi.fetch3(px, py, pz);
630
+ float alpha = p.alpha.fetch1(px, py, pz);
631
+
632
+ vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
633
+
634
+ p.out.store(px, py, pz, res);
635
+ }
636
+
637
+ __global__ void pbrSpecularBwdKernel(PbrSpecular p)
638
+ {
639
+ // Calculate pixel position.
640
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
641
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
642
+ unsigned int pz = blockIdx.z;
643
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
644
+ return;
645
+
646
+ vec3f col = p.col.fetch3(px, py, pz);
647
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
648
+ vec3f wo = p.wo.fetch3(px, py, pz);
649
+ vec3f wi = p.wi.fetch3(px, py, pz);
650
+ float alpha = p.alpha.fetch1(px, py, pz);
651
+ vec3f d_out = p.out.fetch3(px, py, pz);
652
+
653
+ float d_alpha(0);
654
+ vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
655
+ bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
656
+
657
+ p.col.store_grad(px, py, pz, d_col);
658
+ p.nrm.store_grad(px, py, pz, d_nrm);
659
+ p.wo.store_grad(px, py, pz, d_wo);
660
+ p.wi.store_grad(px, py, pz, d_wi);
661
+ p.alpha.store_grad(px, py, pz, d_alpha);
662
+ }
663
+
664
+ __global__ void pbrBSDFFwdKernel(PbrBSDF p)
665
+ {
666
+ // Calculate pixel position.
667
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
668
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
669
+ unsigned int pz = blockIdx.z;
670
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
671
+ return;
672
+
673
+ vec3f kd = p.kd.fetch3(px, py, pz);
674
+ vec3f arm = p.arm.fetch3(px, py, pz);
675
+ vec3f pos = p.pos.fetch3(px, py, pz);
676
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
677
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
678
+ vec3f light_pos = p.light_pos.fetch3(px, py, pz);
679
+
680
+ vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
681
+
682
+ p.out.store(px, py, pz, res);
683
+ }
684
+ __global__ void pbrBSDFBwdKernel(PbrBSDF p)
685
+ {
686
+ // Calculate pixel position.
687
+ unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
688
+ unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
689
+ unsigned int pz = blockIdx.z;
690
+ if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
691
+ return;
692
+
693
+ vec3f kd = p.kd.fetch3(px, py, pz);
694
+ vec3f arm = p.arm.fetch3(px, py, pz);
695
+ vec3f pos = p.pos.fetch3(px, py, pz);
696
+ vec3f nrm = p.nrm.fetch3(px, py, pz);
697
+ vec3f view_pos = p.view_pos.fetch3(px, py, pz);
698
+ vec3f light_pos = p.light_pos.fetch3(px, py, pz);
699
+ vec3f d_out = p.out.fetch3(px, py, pz);
700
+
701
+ vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
702
+ bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
703
+
704
+ p.kd.store_grad(px, py, pz, d_kd);
705
+ p.arm.store_grad(px, py, pz, d_arm);
706
+ p.pos.store_grad(px, py, pz, d_pos);
707
+ p.nrm.store_grad(px, py, pz, d_nrm);
708
+ p.view_pos.store_grad(px, py, pz, d_view_pos);
709
+ p.light_pos.store_grad(px, py, pz, d_light_pos);
710
+ }
models/lrm/models/geometry/render/renderutils/c_src/bsdf.h ADDED
@@ -0,0 +1,84 @@
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+
14
+ #include "common.h"
15
+
16
+ struct LambertKernelParams
17
+ {
18
+ Tensor nrm;
19
+ Tensor wi;
20
+ Tensor out;
21
+ dim3 gridSize;
22
+ };
23
+
24
+ struct FrostbiteDiffuseKernelParams
25
+ {
26
+ Tensor nrm;
27
+ Tensor wi;
28
+ Tensor wo;
29
+ Tensor linearRoughness;
30
+ Tensor out;
31
+ dim3 gridSize;
32
+ };
33
+
34
+ struct FresnelShlickKernelParams
35
+ {
36
+ Tensor f0;
37
+ Tensor f90;
38
+ Tensor cosTheta;
39
+ Tensor out;
40
+ dim3 gridSize;
41
+ };
42
+
43
+ struct NdfGGXParams
44
+ {
45
+ Tensor alphaSqr;
46
+ Tensor cosTheta;
47
+ Tensor out;
48
+ dim3 gridSize;
49
+ };
50
+
51
+ struct MaskingSmithParams
52
+ {
53
+ Tensor alphaSqr;
54
+ Tensor cosThetaI;
55
+ Tensor cosThetaO;
56
+ Tensor out;
57
+ dim3 gridSize;
58
+ };
59
+
60
+ struct PbrSpecular
61
+ {
62
+ Tensor col;
63
+ Tensor nrm;
64
+ Tensor wo;
65
+ Tensor wi;
66
+ Tensor alpha;
67
+ Tensor out;
68
+ dim3 gridSize;
69
+ float min_roughness;
70
+ };
71
+
72
+ struct PbrBSDF
73
+ {
74
+ Tensor kd;
75
+ Tensor arm;
76
+ Tensor pos;
77
+ Tensor nrm;
78
+ Tensor view_pos;
79
+ Tensor light_pos;
80
+ Tensor out;
81
+ dim3 gridSize;
82
+ float min_roughness;
83
+ int BSDF;
84
+ };
models/lrm/models/geometry/render/renderutils/c_src/common.cpp ADDED
@@ -0,0 +1,74 @@
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include <cuda_runtime.h>
13
+ #include <algorithm>
14
+
15
+ //------------------------------------------------------------------------
16
+ // Block and grid size calculators for kernel launches.
17
+
18
+ dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
19
+ {
20
+ int maxThreads = maxWidth * maxHeight;
21
+ if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
22
+ return dim3(1, 1, 1); // Degenerate.
23
+
24
+ // Start from max size.
25
+ int bw = maxWidth;
26
+ int bh = maxHeight;
27
+
28
+ // Optimizations for weirdly sized buffers.
29
+ if (dims.x < bw)
30
+ {
31
+ // Decrease block width to smallest power of two that covers the buffer width.
32
+ while ((bw >> 1) >= dims.x)
33
+ bw >>= 1;
34
+
35
+ // Maximize height.
36
+ bh = maxThreads / bw;
37
+ if (bh > dims.y)
38
+ bh = dims.y;
39
+ }
40
+ else if (dims.y < bh)
41
+ {
42
+ // Halve height and double width until it fits completely inside the buffer vertically.
43
+ while (bh > dims.y)
44
+ {
45
+ bh >>= 1;
46
+ if (bw < dims.x)
47
+ bw <<= 1;
48
+ }
49
+ }
50
+
51
+ // Done.
52
+ return dim3(bw, bh, 1);
53
+ }
54
+
55
+ // returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
56
+ dim3 getWarpSize(dim3 blockSize)
57
+ {
58
+ return dim3(
59
+ std::min(blockSize.x, 32u),
60
+ std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
61
+ std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
62
+ );
63
+ }
64
+
65
+ dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
66
+ {
67
+ dim3 gridSize;
68
+ gridSize.x = (dims.x - 1) / blockSize.x + 1;
69
+ gridSize.y = (dims.y - 1) / blockSize.y + 1;
70
+ gridSize.z = (dims.z - 1) / blockSize.z + 1;
71
+ return gridSize;
72
+ }
73
+
74
+ //------------------------------------------------------------------------
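getLaunchGridSize is a per-axis ceiling division; a minimal Python sketch of the same arithmetic (illustrative only, helper name launch_grid_size is ad hoc):

def launch_grid_size(block, dims):
    # Ceiling division per axis, mirroring getLaunchGridSize above.
    return tuple((d - 1) // b + 1 for b, d in zip(block, dims))

# e.g. a 512x512 cubemap face times 6 sides, shaded with 8x8x1 blocks
assert launch_grid_size((8, 8, 1), (512, 512, 6)) == (64, 64, 6)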
models/lrm/models/geometry/render/renderutils/c_src/common.h ADDED
@@ -0,0 +1,41 @@
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #pragma once
13
+ #include <cuda.h>
14
+ #include <stdint.h>
15
+
16
+ #include "vec3f.h"
17
+ #include "vec4f.h"
18
+ #include "tensor.h"
19
+
20
+ dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
21
+ dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);
22
+
23
+ #ifdef __CUDACC__
24
+
25
+ #ifdef _MSC_VER
26
+ #define M_PI 3.14159265358979323846f
27
+ #endif
28
+
29
+ __host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
30
+ {
31
+ return dim3(
32
+ min(blockSize.x, 32u),
33
+ min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
34
+ min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
35
+ );
36
+ }
37
+
38
+ __device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
39
+ #else
40
+ dim3 getWarpSize(dim3 blockSize);
41
+ #endif
models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu ADDED
@@ -0,0 +1,350 @@
1
+ /*
2
+ * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ *
4
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
5
+ * property and proprietary rights in and to this material, related
6
+ * documentation and any modifications thereto. Any use, reproduction,
7
+ * disclosure or distribution of this material and related documentation
8
+ * without an express license agreement from NVIDIA CORPORATION or
9
+ * its affiliates is strictly prohibited.
10
+ */
11
+
12
+ #include "common.h"
13
+ #include "cubemap.h"
14
+ #include <float.h>
15
+
16
+ // https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
17
+ __device__ float pixel_area(int x, int y, int N)
18
+ {
19
+ if (N > 1)
20
+ {
21
+ int H = N / 2;
22
+ x = abs(x - H);
23
+ y = abs(y - H);
24
+ float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);
25
+ float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);
26
+ return dx * dy;
27
+ }
28
+ else
29
+ return 1;
30
+ }
31
+
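pixel_area weights each texel by (approximately) the solid angle it subtends on the cube face; a Python transcription for experimentation (illustrative only, hypothetical name pixel_area_py, same conventions as the kernel above):

import math

def pixel_area_py(x, y, N):
    # Product of the texel's angular extents in x and y (N x N face, H = N / 2).
    if N <= 1:
        return 1.0
    H = N // 2
    x, y = abs(x - H), abs(y - H)
    dx = math.atan((x + 1) / H) - math.atan(x / H)
    dy = math.atan((y + 1) / H) - math.atan(y / H)
    return dx * dy

# Texels near the face centre subtend more solid angle than corner texels.
N = 16
assert pixel_area_py(N // 2, N // 2, N) > pixel_area_py(0, 0, N)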
32
+ __device__ vec3f cube_to_dir(int x, int y, int side, int N)
33
+ {
34
+ float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
35
+ float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;
36
+ switch (side)
37
+ {
38
+ case 0: return safeNormalize(vec3f(1, -fy, -fx));
39
+ case 1: return safeNormalize(vec3f(-1, -fy, fx));
40
+ case 2: return safeNormalize(vec3f(fx, 1, fy));
41
+ case 3: return safeNormalize(vec3f(fx, -1, -fy));
42
+ case 4: return safeNormalize(vec3f(fx, -fy, 1));
43
+ case 5: return safeNormalize(vec3f(-fx, -fy, -1));
44
+ }
45
+ return vec3f(0,0,0); // Unreachable
46
+ }
47
+
48
+ __device__ vec3f dir_to_side(int side, vec3f v)
49
+ {
50
+ switch (side)
51
+ {
52
+ case 0: return vec3f(-v.z, -v.y, v.x);
53
+ case 1: return vec3f( v.z, -v.y, -v.x);
54
+ case 2: return vec3f( v.x, v.z, v.y);
55
+ case 3: return vec3f( v.x, -v.z, -v.y);
56
+ case 4: return vec3f( v.x, -v.y, v.z);
57
+ case 5: return vec3f(-v.x, -v.y, -v.z);
58
+ }
59
+ return vec3f(0,0,0); // Unreachable
60
+ }
61
+
62
+ __device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
63
+ {
64
+ float l = sqrtf(x * x + z * z);
65
+ float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;
66
+ float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;
67
+ if (pzl <= 0.00001f)
68
+ _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
69
+ else
70
+ _min = pxl / pzl;
71
+ if (pzr <= 0.00001f)
72
+ _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
73
+ else
74
+ _max = pxr / pzr;
75
+ }
76
+
77
+ __device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
78
+ {
79
+ vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1
80
+
81
+ if (theta < 0.785398f) // PI/4
82
+ {
83
+ float xmin, xmax, ymin, ymax;
84
+ extents_1d(c.x, c.z, theta, xmin, xmax);
85
+ extents_1d(c.y, c.z, theta, ymin, ymax);
86
+
87
+ if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
88
+ {
89
+ _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
90
+ }
91
+ else
92
+ {
93
+ _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
94
+ _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
95
+ _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
96
+ _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
97
+ }
98
+ }
99
+ else
100
+ {
101
+ _xmin = 0.0f;
102
+ _xmax = (float)(N-1);
103
+ _ymin = 0.0f;
104
+ _ymax = (float)(N-1);
105
+ }
106
+ }
107
+
+ ///////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // Diffuse kernel
+ __global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
+ {
+     // Calculate pixel position.
+     int px = blockIdx.x * blockDim.x + threadIdx.x;
+     int py = blockIdx.y * blockDim.y + threadIdx.y;
+     int pz = blockIdx.z;
+     if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
+         return;
+
+     int Npx = p.cubemap.dims[1];
+     vec3f N = cube_to_dir(px, py, pz, Npx);
+
+     vec3f col(0);
+
+     for (int s = 0; s < p.cubemap.dims[0]; ++s)
+     {
+         for (int y = 0; y < Npx; ++y)
+         {
+             for (int x = 0; x < Npx; ++x)
+             {
+                 vec3f L = cube_to_dir(x, y, s, Npx);
+                 float costheta = min(max(dot(N, L), 0.0f), 0.999f);
+                 float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
+                 col += p.cubemap.fetch3(x, y, s) * w;
+             }
+         }
+     }
+
+     p.out.store(px, py, pz, col);
+ }
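The diffuse forward kernel above is a brute-force quadrature over the whole cubemap: for an output texel with direction n it accumulates every source texel's radiance, weighted by the clamped cosine and that texel's solid angle, then normalizes by pi (the value of the cosine integral over the hemisphere). As a hedged reading of what the loops discretize, in our own notation rather than anything from the source:

E(n) \approx \frac{1}{\pi} \sum_{s=0}^{5} \sum_{y=0}^{N-1} \sum_{x=0}^{N-1} L_{s,y,x}\, \max(n \cdot \omega_{s,y,x},\, 0)\, \Delta\omega_{s,y,x}

where \omega_{s,y,x} is cube_to_dir(x, y, s, N) and \Delta\omega_{s,y,x} is pixel_area(x, y, N); the 1/\pi normalization means a constant environment map is reproduced unchanged. The backward kernel that follows scatters the incoming gradient back to the source texels with the same weights via atomicAdd.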
+
+ __global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
+ {
+     // Calculate pixel position.
+     int px = blockIdx.x * blockDim.x + threadIdx.x;
+     int py = blockIdx.y * blockDim.y + threadIdx.y;
+     int pz = blockIdx.z;
+     if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
+         return;
+
+     int Npx = p.cubemap.dims[1];
+     vec3f N = cube_to_dir(px, py, pz, Npx);
+     vec3f grad = p.out.fetch3(px, py, pz);
+
+     for (int s = 0; s < p.cubemap.dims[0]; ++s)
+     {
+         for (int y = 0; y < Npx; ++y)
+         {
+             for (int x = 0; x < Npx; ++x)
+             {
+                 vec3f L = cube_to_dir(x, y, s, Npx);
+                 float costheta = min(max(dot(N, L), 0.0f), 0.999f);
+                 float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
+                 atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
+                 atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
+                 atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
+             }
+         }
+     }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////////////////////////////////
+ // GGX splitsum kernel
+
+ __device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
+ {
+     float _cosTheta = clamp(cosTheta, 0.0, 1.0f);
+     float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
+     return alphaSqr / (d * d * M_PI);
+ }
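ndfGGX is the standard Trowbridge-Reitz (GGX) normal distribution function in a numerically convenient grouping. With alphaSqr = \alpha^2 (and \alpha = roughness^2 in the kernels below), it evaluates

D_{GGX}(\cos\theta_h) = \frac{\alpha^2}{\pi\,\big(\cos^2\theta_h\,(\alpha^2 - 1) + 1\big)^2}

which matches the code exactly: d = \cos^2\theta_h(\alpha^2 - 1) + 1 followed by \alpha^2 / (\pi d^2).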
+
+ __global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
+ {
+     int px = blockIdx.x * blockDim.x + threadIdx.x;
+     int py = blockIdx.y * blockDim.y + threadIdx.y;
+     int pz = blockIdx.z;
+     if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
+         return;
+
+     int Npx = p.gridSize.x;
+     vec3f VNR = cube_to_dir(px, py, pz, Npx);
+
+     const int TILE_SIZE = 16;
+
+     // Brute force entire cubemap and compute bounds for the cone
+     for (int s = 0; s < p.gridSize.z; ++s)
+     {
+         // Assume empty BBox
+         int _min_x = p.gridSize.x - 1, _max_x = 0;
+         int _min_y = p.gridSize.y - 1, _max_y = 0;
+
+         // For each (TILE_SIZE x TILE_SIZE) tile
+         for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
+         {
+             for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
+             {
+                 // Compute tile extents
+                 int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
+                 int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);
+
+                 // Use some blunt interval arithmetic to cull tiles
+                 vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
+                 vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);
+
+                 float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
+                 float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
+                 float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));
+
+                 float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
+                 if (maxdp >= p.costheta_cutoff)
+                 {
+                     // Test all pixels in tile.
+                     for (int y = tsy; y < tey; ++y)
+                     {
+                         for (int x = tsx; x < tex; ++x)
+                         {
+                             vec3f L = cube_to_dir(x, y, s, Npx);
+                             if (dot(L, VNR) >= p.costheta_cutoff)
+                             {
+                                 _min_x = min(_min_x, x);
+                                 _max_x = max(_max_x, x);
+                                 _min_y = min(_min_y, y);
+                                 _max_y = max(_max_y, y);
+                             }
+                         }
+                     }
+                 }
+             }
+         }
+         p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
+         p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
+         p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
+         p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
+     }
+ }
+
+ __global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
+ {
+     // Calculate pixel position.
+     int px = blockIdx.x * blockDim.x + threadIdx.x;
+     int py = blockIdx.y * blockDim.y + threadIdx.y;
+     int pz = blockIdx.z;
+     if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
+         return;
+
+     int Npx = p.cubemap.dims[1];
+     vec3f VNR = cube_to_dir(px, py, pz, Npx);
+
+     float alpha = p.roughness * p.roughness;
+     float alphaSqr = alpha * alpha;
+
+     float wsum = 0.0f;
+     vec3f col(0);
+     for (int s = 0; s < p.cubemap.dims[0]; ++s)
+     {
+         int xmin, xmax, ymin, ymax;
+         xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
+         xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
+         ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
+         ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
+
+         if (xmin <= xmax)
+         {
+             for (int y = ymin; y <= ymax; ++y)
+             {
+                 for (int x = xmin; x <= xmax; ++x)
+                 {
+                     vec3f L = cube_to_dir(x, y, s, Npx);
+                     if (dot(L, VNR) >= p.costheta_cutoff)
+                     {
+                         vec3f H = safeNormalize(L + VNR);
+
+                         float wiDotN = max(dot(L, VNR), 0.0f);
+                         float VNRDotH = max(dot(VNR, H), 0.0f);
+
+                         float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
+                         col += p.cubemap.fetch3(x, y, s) * w;
+                         wsum += w;
+                     }
+                 }
+             }
+         }
+     }
+
+     p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
+     p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
+     p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
+     p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
+ }
+
+ __global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
+ {
+     // Calculate pixel position.
+     int px = blockIdx.x * blockDim.x + threadIdx.x;
+     int py = blockIdx.y * blockDim.y + threadIdx.y;
+     int pz = blockIdx.z;
+     if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
+         return;
+
+     int Npx = p.cubemap.dims[1];
+     vec3f VNR = cube_to_dir(px, py, pz, Npx);
+
+     vec3f grad = p.out.fetch3(px, py, pz);
+
+     float alpha = p.roughness * p.roughness;
+     float alphaSqr = alpha * alpha;
+
+     vec3f col(0);
+     for (int s = 0; s < p.cubemap.dims[0]; ++s)
+     {
+         int xmin, xmax, ymin, ymax;
+         xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
+         xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
+         ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
+         ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
+
+         if (xmin <= xmax)
+         {
+             for (int y = ymin; y <= ymax; ++y)
+             {
+                 for (int x = xmin; x <= xmax; ++x)
+                 {
+                     vec3f L = cube_to_dir(x, y, s, Npx);
+                     if (dot(L, VNR) >= p.costheta_cutoff)
+                     {
+                         vec3f H = safeNormalize(L + VNR);
+
+                         float wiDotN = max(dot(L, VNR), 0.0f);
+                         float VNRDotH = max(dot(VNR, H), 0.0f);
+
+                         float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
+
+                         atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
+                         atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
+                         atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
+                     }
+                 }
+             }
+         }
+     }
+ }
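Both specular kernels weight each source texel by pixel_area(x, y, N), an approximation of that texel's solid angle on its cube face. A quick sanity check for any such approximation is that the per-texel solid angles of a single face should sum to 4*pi/6. The host-side C++ sketch below uses the common area-element formula for this check; texel_solid_angle is our own illustrative helper, not the pixel_area defined in this file.

// Standalone CPU sanity check: one cube face should cover 4*pi/6 ~= 2.0944 sr.
#include <cmath>
#include <cstdio>

static const double PI = 3.14159265358979323846;

// Area element of the cube-face parameterization, see the standard cubemap
// texel solid-angle derivation.
static double area_element(double x, double y)
{
    return std::atan2(x * y, std::sqrt(x * x + y * y + 1.0));
}

// Solid angle of texel (x, y) on an N x N face spanning [-1, 1]^2.
static double texel_solid_angle(int x, int y, int N)
{
    double x0 = 2.0 * x / N - 1.0, x1 = 2.0 * (x + 1) / N - 1.0;
    double y0 = 2.0 * y / N - 1.0, y1 = 2.0 * (y + 1) / N - 1.0;
    return area_element(x0, y0) - area_element(x0, y1)
         - area_element(x1, y0) + area_element(x1, y1);
}

int main()
{
    const int N = 64;
    double total = 0.0;
    for (int y = 0; y < N; ++y)
        for (int x = 0; x < N; ++x)
            total += texel_solid_angle(x, y, N);
    std::printf("face solid angle = %f (expected %f)\n", total, 4.0 * PI / 6.0);
    return 0;
}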
models/lrm/models/geometry/render/renderutils/c_src/cubemap.h ADDED
@@ -0,0 +1,38 @@
+ /*
+  * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+  *
+  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+  * property and proprietary rights in and to this material, related
+  * documentation and any modifications thereto. Any use, reproduction,
+  * disclosure or distribution of this material and related documentation
+  * without an express license agreement from NVIDIA CORPORATION or
+  * its affiliates is strictly prohibited.
+  */
+
+ #pragma once
+
+ #include "common.h"
+
+ struct DiffuseCubemapKernelParams
+ {
+     Tensor cubemap;
+     Tensor out;
+     dim3 gridSize;
+ };
+
+ struct SpecularCubemapKernelParams
+ {
+     Tensor cubemap;
+     Tensor bounds;
+     Tensor out;
+     dim3 gridSize;
+     float costheta_cutoff;
+     float roughness;
+ };
+
+ struct SpecularBoundsKernelParams
+ {
+     float costheta_cutoff;
+     Tensor out;
+     dim3 gridSize;
+ };
models/lrm/models/geometry/render/renderutils/c_src/loss.cu ADDED
@@ -0,0 +1,210 @@
+ /*
+  * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+  *
+  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+  * property and proprietary rights in and to this material, related
+  * documentation and any modifications thereto. Any use, reproduction,
+  * disclosure or distribution of this material and related documentation
+  * without an express license agreement from NVIDIA CORPORATION or
+  * its affiliates is strictly prohibited.
+  */
+
+ #include <cuda.h>
+
+ #include "common.h"
+ #include "loss.h"
+
+ //------------------------------------------------------------------------
+ // Utils
+
+ __device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }
+
+ __device__ float warpSum(float val) {
+     for (int i = 1; i < 32; i *= 2)
+         val += __shfl_xor_sync(0xFFFFFFFF, val, i);
+     return val;
+ }
+
+ //------------------------------------------------------------------------
+ // Tonemapping
+
+ __device__ inline float fwdSRGB(float x)
+ {
+     return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
+ }
+
+ __device__ inline void bwdSRGB(float x, float &d_x, float d_out)
+ {
+     if (x > 0.0031308f)
+         d_x += d_out * 0.439583f / powf(x, 0.583333f);
+     else if (x > 0.0f)
+         d_x += d_out * 12.92f;
+ }
+
+ __device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
+ {
+     return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
+ }
+
+ __device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
+ {
+     if (x.x > 0.0f && x.x < 65535.0f)
+     {
+         bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
+         d_x.x *= 1 / (x.x + 1.0f);
+     }
+     if (x.y > 0.0f && x.y < 65535.0f)
+     {
+         bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
+         d_x.y *= 1 / (x.y + 1.0f);
+     }
+     if (x.z > 0.0f && x.z < 65535.0f)
+     {
+         bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
+         d_x.z *= 1 / (x.z + 1.0f);
+     }
+ }
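The constants in bwdSRGB are simply the derivative of the forward transfer curve. For x > 0.0031308, differentiating the sRGB encode gives

\frac{d}{dx}\big(1.055\,x^{1/2.4} - 0.055\big) = \frac{1.055}{2.4}\,x^{1/2.4 - 1} \approx 0.439583\,x^{-0.583333}

and 12.92 on the linear segment, matching the two branches above; bwdTonemapLogSRGB then multiplies by d/dx log(x + 1) = 1/(x + 1) per channel, the chain-rule factor for the log tonemap.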
+
+ __device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
+ {
+     return (img - target) * (img - target) / (img * img + target * target + eps);
+ }
+
+ __device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
+ {
+     float denom = (target * target + img * img + eps);
+     d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
+     d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
+ }
+
+ __device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)
+ {
+     return abs(img - target) / (img + target + eps);
+ }
+
+ __device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
+ {
+     float denom = (target + img + eps);
+     d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
+     d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
+ }
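The backward helpers encode hand-derived gradients. For SMAPE, writing i for img, t for target and d = i + t + \epsilon (our notation, not the source's):

\frac{\partial}{\partial i}\,\frac{|i-t|}{d} = \operatorname{sign}(i-t)\,\frac{2t+\epsilon}{d^2}, \qquad \frac{\partial}{\partial t}\,\frac{|i-t|}{d} = -\operatorname{sign}(i-t)\,\frac{2i+\epsilon}{d^2}

which is exactly the (2 * target + eps) and (2 * img + eps) pair in bwdSMAPE. The RELMSE gradient follows the same pattern with d = i^2 + t^2 + \epsilon:

\frac{\partial}{\partial i}\,\frac{(i-t)^2}{d} = \frac{2\,(i-t)\,\big(t\,(t+i)+\epsilon\big)}{d^2}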
+
+ //------------------------------------------------------------------------
+ // Kernels
+
+ __global__ void imgLossFwdKernel(LossKernelParams p)
+ {
+     // Calculate pixel position.
+     unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
+     unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
+     unsigned int pz = blockIdx.z;
+
+     float floss = 0.0f;
+     if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
+     {
+         vec3f img = p.img.fetch3(px, py, pz);
+         vec3f target = p.target.fetch3(px, py, pz);
+
+         img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
+         target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));
+
+         if (p.tonemapper == TONEMAPPER_LOG_SRGB)
+         {
+             img = fwdTonemapLogSRGB(img);
+             target = fwdTonemapLogSRGB(target);
+         }
+
+         vec3f vloss(0);
+         if (p.loss == LOSS_MSE)
+             vloss = (img - target) * (img - target);
+         else if (p.loss == LOSS_RELMSE)
+             vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
+         else if (p.loss == LOSS_SMAPE)
+             vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
+         else
+             vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));
+
+         floss = sum(vloss) / 3.0f;
+     }
+
+     floss = warpSum(floss);
+
+     dim3 warpSize = getWarpSize(blockDim);
+     if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
+         p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
+ }
+
+ __global__ void imgLossBwdKernel(LossKernelParams p)
+ {
+     // Calculate pixel position.
+     unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
+     unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
+     unsigned int pz = blockIdx.z;
+
+     if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
+         return;
+
+     dim3 warpSize = getWarpSize(blockDim);
+
+     vec3f _img = p.img.fetch3(px, py, pz);
+     vec3f _target = p.target.fetch3(px, py, pz);
+     float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);
+
+     /////////////////////////////////////////////////////////////////////
+     // FWD
+
+     vec3f img = _img, target = _target;
+     if (p.tonemapper == TONEMAPPER_LOG_SRGB)
+     {
+         img = fwdTonemapLogSRGB(img);
+         target = fwdTonemapLogSRGB(target);
+     }
+
+     /////////////////////////////////////////////////////////////////////
+     // BWD
+
+     vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;
+
+     vec3f d_img(0), d_target(0);
+     if (p.loss == LOSS_MSE)
+     {
+         d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.z * 2 * (img.z - target.z));
+         d_target = -d_img;
+     }
+     else if (p.loss == LOSS_RELMSE)
+     {
+         bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
+         bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
+         bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
+     }
+     else if (p.loss == LOSS_SMAPE)
+     {
+         bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
+         bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
+         bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
+     }
+     else
+     {
+         d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
+         d_target = -d_img;
+     }
+
+     if (p.tonemapper == TONEMAPPER_LOG_SRGB)
+     {
+         vec3f d__img(0), d__target(0);
+         bwdTonemapLogSRGB(_img, d__img, d_img);
+         bwdTonemapLogSRGB(_target, d__target, d_target);
+         d_img = d__img; d_target = d__target;
+     }
+
+     if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
+     if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
+     if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
+     if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
+     if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
+     if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;
+
+     p.img.store_grad(px, py, pz, d_img);
+     p.target.store_grad(px, py, pz, d_target);
+ }
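imgLossFwdKernel does not emit one loss value per pixel: warpSum performs a butterfly reduction with __shfl_xor_sync so every lane ends up holding the warp total, and only the first lane per warp writes it out (hence the px / warpSize.x indexing into a smaller output tensor, with the final mean left to the caller). Below is a minimal, self-contained illustration of that reduction pattern; it is our own toy example, not part of this commit.

// Toy warp-level butterfly reduction, same pattern as warpSum() in loss.cu.
#include <cstdio>
#include <cuda_runtime.h>

__device__ float warp_sum(float val)
{
    // After log2(32) XOR-shuffle steps, every lane holds the sum over its warp.
    for (int i = 1; i < 32; i *= 2)
        val += __shfl_xor_sync(0xFFFFFFFF, val, i);
    return val;
}

__global__ void reduce_demo(const float* in, float* out)
{
    float v = warp_sum(in[threadIdx.x]);
    if (threadIdx.x % 32 == 0)              // one write per warp
        out[threadIdx.x / 32] = v;
}

int main()
{
    float h_in[64], h_out[2];
    for (int i = 0; i < 64; ++i) h_in[i] = 1.0f;    // each warp should sum to 32
    float *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    reduce_demo<<<1, 64>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("warp sums: %f %f\n", h_out[0], h_out[1]);
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}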