JiantaoLin committed
Commit 98bebfc · 0 Parent(s)
initial commit
This view is limited to 50 files because it contains too many changes.
- app.py +319 -0
- extension/put_here.txt +0 -0
- image_to_mesh.py +437 -0
- models/ISOMER/__init__.py +0 -0
- models/ISOMER/data/__init__.py +0 -0
- models/ISOMER/data/utils.py +87 -0
- models/ISOMER/mesh_reconstruction/__init__.py +0 -0
- models/ISOMER/mesh_reconstruction/func.py +227 -0
- models/ISOMER/mesh_reconstruction/opt.py +191 -0
- models/ISOMER/mesh_reconstruction/recon.py +58 -0
- models/ISOMER/mesh_reconstruction/refine.py +86 -0
- models/ISOMER/mesh_reconstruction/remesh.py +363 -0
- models/ISOMER/mesh_reconstruction/render.py +142 -0
- models/ISOMER/model/__init__.py +0 -0
- models/ISOMER/model/inference_pipeline.py +189 -0
- models/ISOMER/projection_func.py +86 -0
- models/ISOMER/reconstruction_func.py +88 -0
- models/ISOMER/scripts/__init__.py +0 -0
- models/ISOMER/scripts/all_typing.py +42 -0
- models/ISOMER/scripts/fast_geo.py +86 -0
- models/ISOMER/scripts/load_onnx.py +48 -0
- models/ISOMER/scripts/mesh_init.py +142 -0
- models/ISOMER/scripts/normal_to_height_map.py +205 -0
- models/ISOMER/scripts/proj_commands.py +69 -0
- models/ISOMER/scripts/project_mesh.py +401 -0
- models/ISOMER/scripts/refine_lr_to_sr.py +60 -0
- models/ISOMER/scripts/sd_model_zoo.py +131 -0
- models/ISOMER/scripts/upsampler.py +260 -0
- models/ISOMER/scripts/utils.py +611 -0
- models/lrm/config/PRM_inference.yaml +22 -0
- models/lrm/models/__init__.py +0 -0
- models/lrm/models/decoder/__init__.py +0 -0
- models/lrm/models/decoder/transformer.py +123 -0
- models/lrm/models/encoder/__init__.py +0 -0
- models/lrm/models/encoder/dino.py +550 -0
- models/lrm/models/encoder/dino_wrapper.py +80 -0
- models/lrm/models/geometry/__init__.py +7 -0
- models/lrm/models/geometry/camera/__init__.py +16 -0
- models/lrm/models/geometry/camera/perspective_camera.py +35 -0
- models/lrm/models/geometry/render/__init__.py +8 -0
- models/lrm/models/geometry/render/neural_render.py +293 -0
- models/lrm/models/geometry/render/renderutils/__init__.py +11 -0
- models/lrm/models/geometry/render/renderutils/bsdf.py +151 -0
- models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu +710 -0
- models/lrm/models/geometry/render/renderutils/c_src/bsdf.h +84 -0
- models/lrm/models/geometry/render/renderutils/c_src/common.cpp +74 -0
- models/lrm/models/geometry/render/renderutils/c_src/common.h +41 -0
- models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu +350 -0
- models/lrm/models/geometry/render/renderutils/c_src/cubemap.h +38 -0
- models/lrm/models/geometry/render/renderutils/c_src/loss.cu +210 -0
app.py
ADDED
@@ -0,0 +1,319 @@
import gradio as gr
import os
import subprocess
import shlex
subprocess.run(
    shlex.split(
        "pip install ./extension/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
    )
)

subprocess.run(
    shlex.split(
        "pip install ./extension/renderutils_plugin-1.0-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
    )
)
import torch
import numpy as np
from PIL import Image
from einops import rearrange
from diffusers import FluxPipeline
from models.lrm.utils.camera_util import get_flux_input_cameras
from models.lrm.utils.infer_util import save_video
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.train_util import instantiate_from_config
from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
import os
from einops import rearrange
from omegaconf import OmegaConf
import spaces
import torch
import numpy as np
import trimesh
import torchvision
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2
from diffusers import HeunDiscreteScheduler
from diffusers import FluxPipeline
from pytorch_lightning import seed_everything
import os
from huggingface_hub import hf_hub_download

from utils.tool import NormalTransfer, get_background, get_render_cameras_video, load_mipmap, render_frames

device = "cuda"
resolution = 512
save_dir = "./outputs"
normal_transfer = NormalTransfer()
isomer_azimuths = torch.from_numpy(np.array([0, 90, 180, 270])).float().to(device)
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).float().to(device)
isomer_radius = 4.5
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)

# model initialization and loading
# flux
flux_pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)
flux_lora_ckpt_path = hf_hub_download(repo_id="LTT/xxx-ckpt", filename="rgb_normal_large.safetensors", repo_type="model")
flux_pipe.load_lora_weights(flux_lora_ckpt_path)

flux_pipe.to(device=device, dtype=torch.bfloat16)
generator = torch.Generator(device=device).manual_seed(10)

# lrm
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
model_config = config.model_config
infer_config = config.infer_config
model = instantiate_from_config(model_config)
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
model.init_flexicubes_geometry(device, fovy=50.0)
model = model.eval()

@spaces.GPU
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False):
    images = image.unsqueeze(0).to(device)
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
    # breakpoint()
    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        mesh_path_idx = os.path.join(save_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            **infer_config,
        )
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        render_size = 512
        if if_save_video:
            video_path_idx = os.path.join(save_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)

            all_mv, all_mvp, all_campos = get_render_cameras_video(
                batch_size=1,
                M=240,
                radius=4.5,
                elevation=(90, 60.0),
                is_flexicubes=True,
                fov=30
            )

            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                chunk_size=20,
                is_flexicubes=True,
            )
            normals = (torch.nn.functional.normalize(normals) + 1) / 2
            normals = normals * alphas + (1 - alphas)
            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)

            save_video(
                all_frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")

    return vertices, faces


def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
    global_normal[..., 0] *= -1
    global_normal = (global_normal + 1) / 2
    global_normal = global_normal.permute(0, 3, 1, 2)
    return global_normal

# generate multi-view images
@spaces.GPU
def generate_multi_view_images(prompt, seed):
    generator = torch.manual_seed(seed)
    with torch.no_grad():
        images = flux_pipe(
            prompt=prompt,
            num_inference_steps=30,
            guidance_scale=3.5,
            num_images_per_prompt=1,
            width=resolution * 4,
            height=resolution * 2,
            output_type='np',
            generator=generator
        ).images
    return images

# reconstruct the 3D model
@spaces.GPU
def reconstruct_3d_model(images, prompt):
    rgb_normal_grid = images
    save_dir_path = os.path.join(save_dir, prompt.replace(" ", "_"))
    os.makedirs(save_dir_path, exist_ok=True)

    images = torch.from_numpy(rgb_normal_grid).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 2048)
    images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
    rgb_multi_view = images[:4, :3, :, :]
    normal_multi_view = images[4:, :3, :, :]
    multi_view_mask = get_background(normal_multi_view)
    rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)  # composite the views onto a white background
    input_cameras = get_flux_input_cameras(batch_size=1, radius=4.2, fov=30).to(device)
    vertices, faces = lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm', export_texmap=False, if_save_video=False)
    # local normal to global normal
    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1), isomer_azimuths, isomer_elevations)
    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    global_normal = global_normal.permute(0, 2, 3, 1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.permute(0, 2, 3, 1).squeeze(-1)
    vertices = torch.from_numpy(vertices).to(device)
    faces = torch.from_numpy(faces).to(device)
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]

    # global_normal: B,H,W,3
    # multi_view_mask: B,H,W
    # rgb_multi_view: B,H,W,3

    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=isomer_geo_weights,
        fov=30,
        radius=isomer_radius,
        camera_angles_azi=isomer_azimuths,
        camera_angles_ele=isomer_elevations,
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=0,
        stage2_steps=50,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    save_glb_addr = projection(
        meshes,
        masks=multi_view_mask,
        images=rgb_multi_view,
        azimuths=isomer_azimuths,
        elevations=isomer_elevations,
        weights=isomer_color_weights,
        fov=30,
        radius=isomer_radius,
        save_dir=f"{save_dir_path}/ISOMER/",
    )

    return save_glb_addr

# Gradio interface function
def gradio_pipeline(prompt, seed):
    # generate multi-view images
    rgb_normal_grid = generate_multi_view_images(prompt, seed)
    image_preview = Image.fromarray((rgb_normal_grid[0] * 255).astype(np.uint8))  # take the single batch element for the preview

    # 3d reconstruction: reconstruct the 3D model and return the glb path
    save_glb_addr = reconstruct_3d_model(rgb_normal_grid, prompt)

    return image_preview, save_glb_addr

# Gradio Blocks app
with gr.Blocks() as demo:
    with gr.Row(variant="panel"):
        # left: input area
        with gr.Column():
            with gr.Row():
                prompt_input = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="Describe your 3D model...",
                    lines=2,
                    elem_id="prompt_input"
                )

            with gr.Row():
                sample_seed = gr.Number(value=42, label="Seed Value", precision=0)

            with gr.Row():
                submit = gr.Button("Generate", elem_id="generate", variant="primary")

            with gr.Row(variant="panel"):
                gr.Markdown("Examples:")
                gr.Examples(
                    examples=[
                        ["a castle on a hill"],
                        ["an owl wearing a hat"],
                        ["a futuristic car"]
                    ],
                    inputs=[prompt_input],
                    label="Prompt Examples"
                )

        # right: output area
        with gr.Column():
            with gr.Row():
                rgb_normal_grid_image = gr.Image(
                    label="RGB Normal Grid",
                    type="pil",
                    interactive=False
                )

            with gr.Row():
                with gr.Tab("GLB"):
                    output_glb_model = gr.Model3D(
                        label="Generated 3D Model (GLB Format)",
                        interactive=False
                    )
                    gr.Markdown("Download the model for proper visualization.")

    # wire the button to the pipeline
    submit.click(
        fn=gradio_pipeline, inputs=[prompt_input, sample_seed],
        outputs=[rgb_normal_grid_image, output_glb_model]
    )

# launch the app
demo.queue(max_size=10)
demo.launch(server_port=1211)
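Note on the grid layout used above: reconstruct_3d_model expects the Flux output as a 2x4 grid whose top row holds the four RGB views and whose bottom row holds the four normal maps. A minimal, self-contained sketch of that split with dummy data (an illustration, not code from this commit):

import torch
from einops import rearrange

grid = torch.rand(3, 1024, 2048)                                    # (c, 2*512, 4*512) RGB-over-normal grid
views = rearrange(grid, 'c (n h) (m w) -> (n m) c h w', n=2, m=4)   # (8, 3, 512, 512)
rgb_views, normal_views = views[:4], views[4:]                      # top row -> RGB, bottom row -> normals
print(rgb_views.shape, normal_views.shape)                          # torch.Size([4, 3, 512, 512]) each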
extension/put_here.txt
ADDED
File without changes
image_to_mesh.py
ADDED
@@ -0,0 +1,437 @@
import os
from einops import rearrange
from omegaconf import OmegaConf
import torch
import numpy as np
import trimesh
import torchvision
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.transforms import v2
from transformers import AutoProcessor, AutoModelForCausalLM
import rembg
from diffusers import FluxPipeline, FluxControlNetImg2ImgPipeline
from diffusers.models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler
from pytorch_lightning import seed_everything
import os

from models.ISOMER.reconstruction_func import reconstruction
from models.ISOMER.projection_func import projection
from models.lrm.utils.infer_util import remove_background, resize_foreground, save_video
from models.lrm.utils.mesh_util import save_obj, save_obj_with_mtl
from models.lrm.utils.render_utils import rotate_x, rotate_y
from models.lrm.utils.train_util import instantiate_from_config
from models.lrm.utils.camera_util import get_zero123plus_input_cameras, get_custom_zero123plus_input_cameras, get_flux_input_cameras
from utils.tool import NormalTransfer, get_render_cameras_frames, load_mipmap
from utils.tool import get_background, get_render_cameras_video, render_frames
import time

device = "cuda"
resolution = 512
save_dir = "./outputs"
zero123plus_diffusion_steps = 75
normal_transfer = NormalTransfer()
rembg_session = rembg.new_session()
isomer_azimuths = torch.from_numpy(np.array([270, 0, 90, 180])).to(device)
isomer_elevations = torch.from_numpy(np.array([5, 5, 5, 5])).to(device)
isomer_radius = 4.1
isomer_geo_weights = torch.from_numpy(np.array([1, 0.9, 1, 0.9])).float().to(device)
isomer_color_weights = torch.from_numpy(np.array([1, 0.5, 1, 0.5])).float().to(device)
# seed_everything(42)

# model initialization and loading
# flux
print('==> Loading Flux model ...')
flux_base_model_pth = "/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/models--black-forest-labs--FLUX.1-dev"
flux_controlnet = FluxControlNetModel.from_pretrained("/hpc2hdd/JH_DATA/share/yingcongchen/PrivateShareGroup/yingcongchen_datasets/model_checkpoint/flux_controlnets/FLUX.1-dev-ControlNet-Union-Pro")
flux_pipe = FluxControlNetImg2ImgPipeline.from_pretrained(flux_base_model_pth, controlnet=[flux_controlnet], torch_dtype=torch.bfloat16).to(device=device, dtype=torch.bfloat16)

flux_pipe.load_lora_weights('./checkpoint/flux_lora/rgb_normal_large.safetensors')

flux_pipe.to(device=device, dtype=torch.bfloat16)
generator = torch.Generator(device=device).manual_seed(0)

# lrm
print('==> Loading LRM model ...')
config = OmegaConf.load("./models/lrm/config/PRM_inference.yaml")
model_config = config.model_config
infer_config = config.infer_config
model = instantiate_from_config(model_config)
model_ckpt_path = "./checkpoint/lrm/final_ckpt.ckpt"
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
model.load_state_dict(state_dict, strict=True)

model = model.to(device)
model.init_flexicubes_geometry(device, fovy=50.0)
model = model.eval()

# zero123++
print('==> Loading diffusion model ...')
zero123plus_pipeline = DiffusionPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    custom_pipeline="./models/zero123plus",
    torch_dtype=torch.float16,
)
zero123plus_pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    zero123plus_pipeline.scheduler.config, timestep_spacing='trailing'
)
unet_ckpt_path = "./checkpoint/zero123++/flexgen_19w.ckpt"
state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
zero123plus_pipeline = zero123plus_pipeline.to(device)

# unet_ckpt_path = "checkpoint/zero123++/diffusion_pytorch_model.bin"
# state_dict = torch.load(unet_ckpt_path, map_location='cpu')
# zero123plus_pipeline.unet.load_state_dict(state_dict, strict=True)
# zero123plus_pipeline = zero123plus_pipeline.to(device)

# florence
caption_model = AutoModelForCausalLM.from_pretrained(
    "/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", torch_dtype=torch.bfloat16, trust_remote_code=True,
).to(device)
caption_processor = AutoProcessor.from_pretrained("/hpc2hdd/home/jlin695/.cache/huggingface/hub/models--multimodalart--Florence-2-large-no-flash-attn/snapshots/8db3793cf5b453b2ccfb3a4f613b403b2e6b7ca2", trust_remote_code=True)

# Flux multi-view generation
def multi_view_rgb_normal_generation_with_controlnet(prompt, image, strength=1.0,
                                                     control_image=[],
                                                     control_mode=[],
                                                     control_guidance_start=None,
                                                     control_guidance_end=None,
                                                     controlnet_conditioning_scale=None,
                                                     lora_scale=1.0
                                                     ):
    control_mode_dict = {
        'canny': 0,
        'tile': 1,
        'depth': 2,
        'blur': 3,
        'pose': 4,
        'gray': 5,
        'lq': 6,
    }  # for https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union only

    hparam_dict = {
        'prompt': prompt,
        'image': image,
        'strength': strength,
        'num_inference_steps': 30,
        'guidance_scale': 3.5,
        'num_images_per_prompt': 1,
        'width': resolution*4,
        'height': resolution*2,
        'output_type': 'np',
        'generator': generator,
        'joint_attention_kwargs': {"scale": lora_scale}
    }

    # append controlnet hparams
    if len(control_image) > 0:
        assert len(control_mode) == len(control_image)  # the number of control images must match the number of control modes

        ctrl_hparams = {
            'control_mode': [control_mode_dict[mode_] for mode_ in control_mode],
            'control_image': control_image,
            'control_guidance_start': control_guidance_start or [0.0 for i in range(len(control_image))],
            'control_guidance_end': control_guidance_end or [1.0 for i in range(len(control_image))],
            'controlnet_conditioning_scale': controlnet_conditioning_scale or [1.0 for i in range(len(control_image))],
        }

        hparam_dict.update(ctrl_hparams)

    # generate multi-view images
    with torch.no_grad():
        image = flux_pipe(
            **hparam_dict
        ).images
    return image

# captioning
def run_captioning(image):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.bfloat16

    if isinstance(image, str):  # if image is a file path
        image = Image.open(image).convert("RGB")

    prompt = "<MORE_DETAILED_CAPTION>"
    inputs = caption_processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
    # print(f"inputs {inputs}")

    generated_ids = caption_model.generate(
        input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3
    )

    generated_text = caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = caption_processor.post_process_generation(
        generated_text, task=prompt, image_size=(image.width, image.height)
    )
    # print(f"parsed_answer = {parsed_answer}")
    caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"].replace("The image is ", "")
    return caption_text


# zero123++ multi-view generation
def multi_view_rgb_generation(cond_img):
    # generate multi-view images
    with torch.no_grad():
        output_image = zero123plus_pipeline(
            cond_img,
            num_inference_steps=zero123plus_diffusion_steps,
            width=resolution*2,
            height=resolution*2,
        ).images[0]
    return output_image

# lrm reconstructions
def lrm_reconstructions(image, input_cameras, save_path=None, name="temp", export_texmap=False, if_save_video=False, render_azimuths=None, render_elevations=None, render_radius=None, render_fov=30):
    images = image.unsqueeze(0).to(device)
    images = v2.functional.resize(images, 512, interpolation=3, antialias=True).clamp(0, 1)
    # breakpoint()
    with torch.no_grad():
        # get triplane
        planes = model.forward_planes(images, input_cameras)

        mesh_path_idx = os.path.join(save_path, f'{name}.obj')

        mesh_out = model.extract_mesh(
            planes,
            use_texture_map=export_texmap,
            **infer_config,
        )
        if export_texmap:
            vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
            save_obj_with_mtl(
                vertices.data.cpu().numpy(),
                uvs.data.cpu().numpy(),
                faces.data.cpu().numpy(),
                mesh_tex_idx.data.cpu().numpy(),
                tex_map.permute(1, 2, 0).data.cpu().numpy(),
                mesh_path_idx,
            )
        else:
            vertices, faces, vertex_colors = mesh_out
            save_obj(vertices, faces, vertex_colors, mesh_path_idx)
        print(f"Mesh saved to {mesh_path_idx}")

        render_size = 512
        if if_save_video:
            video_path_idx = os.path.join(save_path, f'{name}.mp4')
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)

            all_mv, all_mvp, all_campos = get_render_cameras_video(
                batch_size=1,
                M=240,
                radius=4.5,
                elevation=(90, 60.0),
                is_flexicubes=True,
                fov=30
            )

            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                chunk_size=20,
                is_flexicubes=True,
            )
            normals = (torch.nn.functional.normalize(normals) + 1) / 2
            normals = normals * alphas + (1 - alphas)
            all_frames = torch.cat([frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals], dim=3)

            # breakpoint()
            save_video(
                all_frames,
                video_path_idx,
                fps=30,
            )
            print(f"Video saved to {video_path_idx}")

        if render_azimuths is not None and render_elevations is not None and render_radius is not None:
            render_size = infer_config.render_resolution
            ENV = load_mipmap("models/lrm/env_mipmap/6")
            materials = (0.0, 0.9)
            all_mv, all_mvp, all_campos, identity_mv = get_render_cameras_frames(
                batch_size=1,
                radius=render_radius,
                azimuths=render_azimuths,
                elevations=render_elevations,
                fov=30
            )
            frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas = render_frames(
                model,
                planes,
                render_cameras=all_mvp,
                camera_pos=all_campos,
                env=ENV,
                materials=materials,
                render_size=render_size,
                render_mv=all_mv,
                local_normal=True,
                identity_mv=identity_mv,
            )
        else:
            normals = None
            frames = None
            albedos = None

    return vertices, faces, normals, frames, albedos


def transform_normal(input_normal, azimuths_deg, elevations_deg, radius=4.5, is_global_to_local=False):
    """
    input_normal: in range [-1, 1], shape (b c h w)
    """

    input_normal = input_normal.permute(0, 2, 3, 1).cpu()

    azimuths_deg = np.array(azimuths_deg)
    elevations_deg = np.array(elevations_deg)

    if is_global_to_local:
        local_normal = normal_transfer.trans_global_2_local(input_normal, azimuths_deg, elevations_deg)
        return local_normal.permute(0, 3, 1, 2)
    else:
        global_normal = normal_transfer.trans_local_2_global(input_normal, azimuths_deg, elevations_deg, radius=radius, for_lotus=False)
        global_normal[..., 0] *= -1
        return global_normal.permute(0, 3, 1, 2)

def local_normal_global_transform(local_normal_images, azimuths_deg, elevations_deg):
    if local_normal_images.min() >= 0:
        local_normal = local_normal_images.float() * 2 - 1
    else:
        local_normal = local_normal_images.float()
    global_normal = normal_transfer.trans_local_2_global(local_normal, azimuths_deg, elevations_deg, radius=4.5, for_lotus=False)
    global_normal[..., 0] *= -1
    global_normal = (global_normal + 1) / 2
    global_normal = global_normal.permute(0, 3, 1, 2)
    return global_normal

def main():
    image_pth = "examples/蓝色小怪物.webp"
    save_dir_path = os.path.join(save_dir, image_pth.split("/")[-1].split(".")[0])
    os.makedirs(save_dir_path, exist_ok=True)
    input_image = Image.open(image_pth)
    # if not args.no_rembg:
    input_image = remove_background(input_image, rembg_session)
    input_image = resize_foreground(input_image, 0.85)

    # generate caption
    image_caption = run_captioning(image_pth)

    # generate multi-view images
    output_image = multi_view_rgb_generation(input_image)

    # lrm reconstructions
    rgb_multi_view = np.asarray(output_image, dtype=np.float32) / 255.0
    rgb_multi_view = torch.from_numpy(rgb_multi_view).squeeze(0).permute(2, 0, 1).contiguous().float()  # (3, 1024, 1024)
    rgb_multi_view = rearrange(rgb_multi_view, 'c (n h) (m w) -> (n m) c h w', n=2, m=2)  # (4, 3, 512, 512)

    input_cameras = get_custom_zero123plus_input_cameras(batch_size=1, radius=3.5, fov=30).to(device)

    vertices, faces, lrm_multi_view_normals, lrm_multi_view_rgb, lrm_multi_view_albedo = \
        lrm_reconstructions(rgb_multi_view, input_cameras, save_path=save_dir_path, name='lrm',
                            export_texmap=False, if_save_video=False, render_azimuths=isomer_azimuths,
                            render_elevations=isomer_elevations, render_radius=isomer_radius, render_fov=30)

    vertices = torch.from_numpy(vertices).to(device)
    faces = torch.from_numpy(faces).to(device)
    vertices = vertices @ rotate_x(np.pi / 2, device=vertices.device)[:3, :3]
    vertices = vertices @ rotate_y(np.pi / 2, device=vertices.device)[:3, :3]

    # lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([lrm_multi_view_rgb.cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0)  # range [0, 1]
    lrm_3D_bundle_image = torchvision.utils.make_grid(torch.cat([rgb_multi_view[[3, 0, 1, 2]].cpu(), (lrm_multi_view_normals.cpu() + 1) / 2], dim=0), nrow=4, padding=0).unsqueeze(0)  # range [0, 1]
    # rgb_multi_view[[3,0,1,2]] : (B,3,H,W)
    # lrm_multi_view_normals : (B,3,H,W)
    # combined_images = 0.5 * rgb_multi_view[[3,0,1,2]].cpu() + 0.5 * (lrm_multi_view_normals.cpu() + 1) / 2
    # torchvision.utils.save_image(combined_images, os.path.join("debug_output", 'combined.png'))
    # breakpoint()
    # Use the low-quality controlnet by default, feel free to try the others
    control_image = [lrm_3D_bundle_image * 2 - 1]
    control_mode = ['tile']
    control_guidance_start = [0.0]
    control_guidance_end = [0.3]
    controlnet_conditioning_scale = [0.8]

    flux_pipe.controlnet = FluxMultiControlNetModel([flux_controlnet for _ in control_mode])
    # breakpoint()
    rgb_normal_grid = multi_view_rgb_normal_generation_with_controlnet(
        prompt=' '.join(['A grid of 2x4 multi-view image, elevation 5. White background.', image_caption]),
        image=lrm_3D_bundle_image,
        strength=0.6,
        control_image=control_image,
        control_mode=control_mode,
        control_guidance_start=control_guidance_start,
        control_guidance_end=control_guidance_end,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        lora_scale=1.0
    )  # note that rgb_normal_grid is a (b, h, w, c) numpy array

    rgb_normal_grid = torch.from_numpy(rgb_normal_grid).contiguous().float()
    rgb_normal_grid = rearrange(rgb_normal_grid.squeeze(0), '(n h) (m w) c -> (n m) c h w', n=2, m=4)  # (8, 3, 512, 512)
    rgb_multi_view = rgb_normal_grid[:4, :3, :, :].cuda()
    normal_multi_view = rgb_normal_grid[4:, :3, :, :].cuda()
    multi_view_mask = get_background(normal_multi_view).cuda()
    rgb_multi_view = rgb_multi_view * multi_view_mask + (1 - multi_view_mask)

    # local normal to global normal
    global_normal = local_normal_global_transform(normal_multi_view.permute(0, 2, 3, 1).cpu(), isomer_azimuths, isomer_elevations).cuda()

    global_normal = global_normal * multi_view_mask + (1 - multi_view_mask)

    global_normal = global_normal.permute(0, 2, 3, 1)
    multi_view_mask = multi_view_mask.squeeze(1)
    rgb_multi_view = rgb_multi_view.permute(0, 2, 3, 1)
    # global_normal: B,H,W,3
    # multi_view_mask: B,H,W
    # rgb_multi_view: B,H,W,3

    meshes = reconstruction(
        normal_pils=global_normal,
        masks=multi_view_mask,
        weights=isomer_geo_weights,
        fov=30,
        radius=isomer_radius,
        camera_angles_azi=isomer_azimuths,
        camera_angles_ele=isomer_elevations,
        expansion_weight_stage1=0.1,
        init_type="file",
        init_verts=vertices,
        init_faces=faces,
        stage1_steps=0,
        stage2_steps=50,
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
    )

    save_glb_addr = projection(
        meshes=meshes,
        masks=multi_view_mask,
        images=rgb_multi_view,
        azimuths=isomer_azimuths,
        elevations=isomer_elevations,
        weights=isomer_color_weights,
        fov=30,
        radius=isomer_radius,
        save_dir=f"{save_dir_path}/ISOMER/",
    )
    print(f'saved to {save_glb_addr}')


if __name__ == '__main__':
    main()
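The vertex re-orientation in main() assumes that rotate_x and rotate_y return 4x4 homogeneous rotation matrices, which is why only the [:3, :3] block is used. A minimal stand-alone sketch of that step, with hypothetical stand-ins for the two helpers (dummy data, not code from this commit):

import numpy as np
import torch

def rot_x(a):  # hypothetical stand-in for models.lrm.utils.render_utils.rotate_x
    c, s = np.cos(a), np.sin(a)
    return torch.tensor([[1., 0., 0.], [0., c, -s], [0., s, c]], dtype=torch.float32)

def rot_y(a):  # hypothetical stand-in for rotate_y
    c, s = np.cos(a), np.sin(a)
    return torch.tensor([[c, 0., s], [0., 1., 0.], [-s, 0., c]], dtype=torch.float32)

verts = torch.rand(100, 3)                            # dummy LRM vertices
verts = verts @ rot_x(np.pi / 2) @ rot_y(np.pi / 2)   # rotate the mesh into the frame ISOMER expects
print(verts.shape)                                    # torch.Size([100, 3])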
models/ISOMER/__init__.py
ADDED
File without changes
models/ISOMER/data/__init__.py
ADDED
File without changes
models/ISOMER/data/utils.py
ADDED
@@ -0,0 +1,87 @@
import torch
import numpy as np
from PIL import Image
import os
from pytorch3d.io import load_obj
import trimesh
from pytorch3d.structures import Meshes
# from rembg import remove

def remove_color(arr):
    if arr.shape[-1] == 4:
        arr = arr[..., :3]

    # Convert to torch tensor
    if type(arr) is not torch.Tensor:
        arr = torch.tensor(arr, dtype=torch.int32)

    # Calculate diffs
    base = arr[0, 0]
    diffs = torch.abs(arr - base).sum(dim=-1)
    alpha = (diffs <= 80)

    arr[alpha] = 255
    alpha = ~alpha
    alpha = alpha.unsqueeze(-1).int() * 255
    arr = torch.cat([arr, alpha], dim=-1)

    return arr

def simple_remove_bkg_normal(imgs, rm_bkg_with_rembg, return_Image=False):
    """Only works for normal"""
    rets = []
    for img in imgs:
        if rm_bkg_with_rembg:
            from rembg import remove
            image = Image.fromarray(img.to(torch.uint8).detach().cpu().numpy()) if isinstance(img, torch.Tensor) else img
            removed_image = remove(image)
            arr = np.array(removed_image)
            arr = torch.tensor(arr, dtype=torch.uint8)
        else:
            arr = remove_color(img)

        if return_Image:
            rets.append(Image.fromarray(arr.to(torch.uint8).detach().cpu().numpy()))
        else:
            rets.append(arr.to(torch.uint8))

    return rets


def load_glb(file_path):
    # Load the .glb file as a scene and merge all meshes
    scene_or_mesh = trimesh.load(file_path)

    mesh = scene_or_mesh.dump(concatenate=True) if isinstance(scene_or_mesh, trimesh.Scene) else scene_or_mesh

    # Extract vertices and faces from the merged mesh
    verts = torch.tensor(mesh.vertices, dtype=torch.float32)
    faces = torch.tensor(mesh.faces, dtype=torch.int64)

    textured_mesh = Meshes(verts=[verts], faces=[faces])

    return textured_mesh

def load_obj_with_verts_faces(file_path, return_mesh=True):
    verts, faces, _ = load_obj(file_path)

    verts = torch.tensor(verts, dtype=torch.float32)
    faces = faces.verts_idx
    faces = torch.tensor(faces, dtype=torch.int64)

    if return_mesh:
        return Meshes(verts=[verts], faces=[faces])
    else:
        return verts, faces

def normalize_mesh(vertices):
    min_vals, _ = torch.min(vertices, axis=0)
    max_vals, _ = torch.max(vertices, axis=0)
    center = (max_vals + min_vals) / 2
    vertices = vertices - center
    max_extent = torch.max(max_vals - min_vals)
    scale = 2.0 / max_extent
    vertices = vertices * scale
    return vertices
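A small usage sketch for normalize_mesh (assuming the function is imported from this module; not code from this commit): vertices are recentred on the bounding-box centre and uniformly scaled so the longest axis spans 2 units.

import torch
from models.ISOMER.data.utils import normalize_mesh

verts = torch.tensor([[0.0, 0.0, 0.0],
                      [4.0, 2.0, 1.0]])
print(normalize_mesh(verts))
# tensor([[-1.0000, -0.5000, -0.2500],
#         [ 1.0000,  0.5000,  0.2500]])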
models/ISOMER/mesh_reconstruction/__init__.py
ADDED
File without changes
models/ISOMER/mesh_reconstruction/func.py
ADDED
@@ -0,0 +1,227 @@
# modified from https://github.com/Profactor/continuous-remeshing
import torch
import numpy as np
import trimesh
from typing import Tuple
from pytorch3d.renderer.cameras import camera_position_from_spherical_angles, look_at_rotation
from pytorch3d.renderer import (
    FoVOrthographicCameras,
    look_at_view_transform,
)

def to_numpy(*args):
    def convert(a):
        if isinstance(a, torch.Tensor):
            return a.detach().cpu().numpy()
        assert a is None or isinstance(a, np.ndarray)
        return a

    return convert(args[0]) if len(args) == 1 else tuple(convert(a) for a in args)

def laplacian(
        num_verts: int,
        edges: torch.Tensor  # E,2
        ) -> torch.Tensor:  # sparse V,V
    """create sparse Laplacian matrix"""
    V = num_verts
    E = edges.shape[0]

    # adjacency matrix
    idx = torch.cat([edges, edges.fliplr()], dim=0).type(torch.long).T  # (2, 2*E)
    ones = torch.ones(2*E, dtype=torch.float32, device=edges.device)
    A = torch.sparse.FloatTensor(idx, ones, (V, V))

    # degree matrix
    deg = torch.sparse.sum(A, dim=1).to_dense()
    idx = torch.arange(V, device=edges.device)
    idx = torch.stack([idx, idx], dim=0)
    D = torch.sparse.FloatTensor(idx, deg, (V, V))

    return D - A

def _translation(x, y, z, device):
    return torch.tensor([[1., 0, 0, x],
                         [0, 1, 0, y],
                         [0, 0, 1, z],
                         [0, 0, 0, 1]], device=device)  # 4,4


def _perspective(fovy, aspect=1.0, n=0.1, f=1000.0, device=None):
    fovy = fovy * torch.pi / 180
    y = np.tan(fovy / 2)
    return torch.tensor([[1/(y*aspect), 0, 0, 0],
                         [0, 1/-y, 0, 0],
                         [0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                         [0, 0, -1, 0]], dtype=torch.float32, device=device)

def _projection(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    """
    see https://blog.csdn.net/wodownload2/article/details/85069240/
    """
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    p = torch.zeros([4, 4], device=device)
    p[0, 0] = 2*n/(r-l)
    p[0, 2] = (r+l)/(r-l)
    p[1, 1] = 2*n/(t-b) * (-1 if flip_y else 1)
    p[1, 2] = (t+b)/(t-b)
    p[2, 2] = -(f+n)/(f-n)
    p[2, 3] = -(2*f*n)/(f-n)
    p[3, 2] = -1
    return p  # 4,4

def _orthographic(r, device, l=None, t=None, b=None, n=1.0, f=50.0, flip_y=True):
    if l is None:
        l = -r
    if t is None:
        t = r
    if b is None:
        b = -t
    o = torch.zeros([4, 4], device=device)
    o[0, 0] = 2/(r-l)
    o[0, 3] = -(r+l)/(r-l)
    o[1, 1] = 2/(t-b) * (-1 if flip_y else 1)
    o[1, 3] = -(t+b)/(t-b)
    o[2, 2] = -2/(f-n)
    o[2, 3] = -(f+n)/(f-n)
    o[3, 3] = 1
    return o  # 4,4

def make_star_cameras_orig(phis, pol_count, distance: float = 10., r=None, image_size=[512, 512], device='cuda'):
    if r is None:
        r = 1/distance
    A = len(phis)
    P = pol_count
    C = A * P  # total number of cameras

    phi = phis * torch.pi / 180
    phi_rot = torch.eye(3, device=device)[None, None].expand(A, 1, 3, 3).clone()
    phi_rot[:, 0, 2, 2] = phi.cos()
    phi_rot[:, 0, 2, 0] = -phi.sin()
    phi_rot[:, 0, 0, 2] = phi.sin()
    phi_rot[:, 0, 0, 0] = phi.cos()

    theta = torch.arange(1, P+1) * (torch.pi/(P+1)) - torch.pi/2
    theta_rot = torch.eye(3, device=device)[None, None].expand(1, P, 3, 3).clone()
    theta_rot[0, :, 1, 1] = theta.cos()
    theta_rot[0, :, 1, 2] = -theta.sin()
    theta_rot[0, :, 2, 1] = theta.sin()
    theta_rot[0, :, 2, 2] = theta.cos()

    mv = torch.empty((C, 4, 4), device=device)
    mv[:] = torch.eye(4, device=device)
    mv[:, :3, :3] = (theta_rot @ phi_rot).reshape(C, 3, 3)
    mv_ = _translation(0, 0, -distance, device) @ mv

    return mv_, _projection(r, device)

def make_star_cameras_mv_new(phis, eles, distance: float = 10., r=None, fov=None, image_size=[512, 512], device='cuda', translation=True):
    import glm
    def sample_spherical(phi, theta, cam_radius):
        theta = torch.deg2rad(theta)
        phi = torch.deg2rad(phi)

        z = cam_radius * torch.cos(phi) * torch.sin(theta)
        x = cam_radius * torch.sin(phi) * torch.sin(theta)
        y = cam_radius * torch.cos(theta)

        return x, y, z

    all_mvs = []
    for i in range(len(phis)):
        azimuth = - phis[i] + 1e-10
        ele = - eles[i] + 1e-10 + 90
        x, y, z = sample_spherical(azimuth, ele, distance)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        view_matrix = glm.lookAt(eye, at, up)
        all_mvs.append(torch.from_numpy(np.array(view_matrix)).cuda())
    mv = torch.stack(all_mvs)

    return mv

def make_star_cameras_mv(phis, eles, distance: float = 10., r=None, fov=None, image_size=[512, 512], device='cuda', translation=True):
    if r is None:
        r = 0.15
    A = len(phis)
    assert len(eles) == A, f'len(phis): {len(phis)}, len(eles): {len(eles)}'

    phi = phis * torch.pi / 180
    phi_rot = torch.eye(3, device=device)[None].expand(A, 3, 3).clone()
    phi_rot[:, 2, 2] = phi.cos()
    phi_rot[:, 2, 0] = -phi.sin()
    phi_rot[:, 0, 2] = phi.sin()
    phi_rot[:, 0, 0] = phi.cos()

    theta = eles * torch.pi / 180
    theta_rot = torch.eye(3, device=device)[None].expand(A, 3, 3).clone()
    theta_rot[:, 1, 1] = theta.cos()
    theta_rot[:, 1, 2] = -theta.sin()
    theta_rot[:, 2, 1] = theta.sin()
    theta_rot[:, 2, 2] = theta.cos()

    mv = torch.empty((A, 4, 4), device=device)
    mv[:] = torch.eye(4, device=device)
    mv[:, :3, :3] = (theta_rot @ phi_rot).reshape(A, 3, 3)

    if translation:
        mv_ = _translation(0, 0, -distance, device) @ mv
    else:
        mv_ = mv
    return mv_

def make_star_cameras(phis, eles, distance: float = 10., r=None, fov=None, image_size=[512, 512], device='cuda', translation=True):
    mv_ = make_star_cameras_mv_new(phis, eles, distance, r, device=device, translation=translation)
    return mv_, _perspective(fov, device=device)

def make_star_cameras_perspective(phis, eles, distance: float = 10., r=None, fov=None, device='cuda'):

    return make_star_cameras(phis, eles, distance, r, fov=fov, device=device, translation=True)

def make_star_cameras_orthographic(phis, eles, distance: float = 10., r=None, device='cuda'):

    mv = make_star_cameras_mv_new(phis, eles, distance, r, device=device)
    if r is None:
        r = 1
    return mv, _orthographic(r, device)

def make_sphere(level: int = 2, radius=1., device='cuda') -> Tuple[torch.Tensor, torch.Tensor]:
    sphere = trimesh.creation.icosphere(subdivisions=level, radius=1.0, color=None)
    vertices = torch.tensor(sphere.vertices, device=device, dtype=torch.float32) * radius
    faces = torch.tensor(sphere.faces, device=device, dtype=torch.long)
    return vertices, faces


def get_camera(R, T, focal_length=1 / (2**0.5)):
    focal_length = 1 / focal_length
    camera = FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
    return camera

def make_star_cameras_orthographic_py3d(azim_list, device, focal=2/1.35, dist=1.1):
    R, T = look_at_view_transform(dist, 0, azim_list)
    focal_length = 1 / focal
    return FoVOrthographicCameras(device=R.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length).to(device)


def rotation_matrix_to_euler_angles(R, return_degrees=True):
    sy = torch.sqrt(R[0, 0] * R[0, 0] + R[1, 0] * R[1, 0])
    singular = sy < 1e-6
    if not singular:
        x = torch.atan2(R[2, 1], R[2, 2])
        y = torch.atan2(-R[2, 0], sy)
        z = torch.atan2(R[1, 0], R[0, 0])
    else:
        x = torch.atan2(-R[1, 2], R[1, 1])
        y = torch.atan2(-R[2, 0], sy)
        z = 0

    if return_degrees:
        return torch.tensor([x, y, z]) * 180 / np.pi
    else:
        return torch.tensor([x, y, z])
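A minimal usage sketch for the camera and sphere helpers defined above (CUDA assumed, since the helpers default to device='cuda'; an illustration, not code from this commit):

import torch
from models.ISOMER.mesh_reconstruction.func import make_star_cameras_mv, _perspective, make_sphere

phis = torch.tensor([0., 90., 180., 270.], device='cuda')   # azimuths in degrees
eles = torch.tensor([5., 5., 5., 5.], device='cuda')        # elevations in degrees

mv = make_star_cameras_mv(phis, eles, distance=4.5)         # (4, 4, 4) model-view matrices
proj = _perspective(30.0, device='cuda')                    # (4, 4) perspective projection, fovy in degrees
verts, faces = make_sphere(level=2, radius=1.0)             # icosphere used as an initial mesh
print(mv.shape, proj.shape, verts.shape, faces.shape)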
models/ISOMER/mesh_reconstruction/opt.py
ADDED
@@ -0,0 +1,191 @@
# modified from https://github.com/Profactor/continuous-remeshing
import time
import torch
import torch_scatter
from typing import Tuple
from ..mesh_reconstruction.remesh import calc_edge_length, calc_edges, calc_face_collapses, calc_face_normals, calc_vertex_normals, collapse_edges, flip_edges, pack, prepend_dummies, remove_dummies, split_edges

@torch.no_grad()
def remesh(
        vertices_etc: torch.Tensor,  # V,D
        faces: torch.Tensor,  # F,3 long
        min_edgelen: torch.Tensor,  # V
        max_edgelen: torch.Tensor,  # V
        flip: bool,
        max_vertices=1e6
        ):

    # dummies
    vertices_etc, faces = prepend_dummies(vertices_etc, faces)
    vertices = vertices_etc[:, :3]  # V,3
    nan_tensor = torch.tensor([torch.nan], device=min_edgelen.device)
    min_edgelen = torch.concat((nan_tensor, min_edgelen))
    max_edgelen = torch.concat((nan_tensor, max_edgelen))

    # collapse
    edges, face_to_edge = calc_edges(faces)  # E,2 F,3
    edge_length = calc_edge_length(vertices, edges)  # E
    face_normals = calc_face_normals(vertices, faces, normalize=False)  # F,3
    vertex_normals = calc_vertex_normals(vertices, faces, face_normals)  # V,3
    # then calculate the face collapses, i.e. the faces that can be removed without changing the overall shape of the object
    face_collapse = calc_face_collapses(vertices, faces, edges, face_to_edge, edge_length, face_normals, vertex_normals, min_edgelen, area_ratio=0.5)
    shortness = (1 - edge_length / min_edgelen[edges].mean(dim=-1)).clamp_min_(0)  # e[0,1] 0...ok, 1...edgelen=0
    priority = face_collapse.float() + shortness
    vertices_etc, faces = collapse_edges(vertices_etc, faces, edges, priority)

    # split: if the number of vertices is less than the maximum allowed, split the edges that are longer than the maximum edge length
    if vertices.shape[0] < max_vertices:
        edges, face_to_edge = calc_edges(faces)  # E,2 F,3
        vertices = vertices_etc[:, :3]  # V,3
        edge_length = calc_edge_length(vertices, edges)  # E
        splits = edge_length > max_edgelen[edges].mean(dim=-1)
        vertices_etc, faces = split_edges(vertices_etc, faces, edges, face_to_edge, splits, pack_faces=False)

    vertices_etc, faces = pack(vertices_etc, faces)
    vertices = vertices_etc[:, :3]

    if flip:  # flip the edges of the faces
        edges, _, edge_to_face = calc_edges(faces, with_edge_to_face=True)  # E,2 F,3
        flip_edges(vertices, faces, edges, edge_to_face, with_border=False)

    return remove_dummies(vertices_etc, faces)

def lerp_unbiased(a: torch.Tensor, b: torch.Tensor, weight: float, step: int):
    """lerp with adam's bias correction"""
    c_prev = 1-weight**(step-1)
    c = 1-weight**step
    a_weight = weight*c_prev/c
    b_weight = (1-weight)/c
    a.mul_(a_weight).add_(b, alpha=b_weight)


class MeshOptimizer:
    """Use this like a pytorch Optimizer, but after calling opt.step(), do vertices,faces = opt.remesh()."""

    def __init__(self,
            vertices: torch.Tensor,  # V,3
            faces: torch.Tensor,  # F,3
            lr=0.3,  # learning rate
            betas=(0.8, 0.8, 0),  # betas[0:2] are the same as in Adam, betas[2] may be used to time-smooth the relative velocity nu
            gammas=(0, 0, 0),  # optional spatial smoothing for m1,m2,nu, values between 0 (no smoothing) and 1 (max. smoothing)
            nu_ref=0.3,  # reference velocity for edge length controller
            edge_len_lims=(.01, .15),  # smallest and largest allowed reference edge length
            edge_len_tol=.5,  # edge length tolerance for split and collapse
            gain=.2,  # gain value for edge length controller
            laplacian_weight=.02,  # for laplacian smoothing/regularization
            ramp=1,  # learning rate ramp, actual ramp width is ramp/(1-betas[0])
            grad_lim=10.,  # gradients are clipped to m1.abs()*grad_lim
            remesh_interval=1,  # larger intervals are faster but with worse mesh quality
            local_edgelen=True,  # set to False to use a global scalar reference edge length instead
            ):
        self._vertices = vertices
        self._faces = faces
        self._lr = lr
        self._betas = betas
        self._gammas = gammas
        self._nu_ref = nu_ref
        self._edge_len_lims = edge_len_lims
        self._edge_len_tol = edge_len_tol
        self._gain = gain
        self._laplacian_weight = laplacian_weight
        self._ramp = ramp
        self._grad_lim = grad_lim
        self._remesh_interval = remesh_interval
        self._local_edgelen = local_edgelen
        self._step = 0

        V = self._vertices.shape[0]
        # prepare continuous tensor for all vertex-based data
        self._vertices_etc = torch.zeros([V, 9], device=vertices.device)
        self._split_vertices_etc()
        self.vertices.copy_(vertices)  # initialize vertices
        self._vertices.requires_grad_()
        self._ref_len.fill_(edge_len_lims[1])

    @property
    def vertices(self):
        return self._vertices

    @property
    def faces(self):
        return self._faces

    def _split_vertices_etc(self):
        self._vertices = self._vertices_etc[:, :3]
        self._m2 = self._vertices_etc[:, 3]
        self._nu = self._vertices_etc[:, 4]
        self._m1 = self._vertices_etc[:, 5:8]
        self._ref_len = self._vertices_etc[:, 8]

        with_gammas = any(g != 0 for g in self._gammas)
        self._smooth = self._vertices_etc[:, :8] if with_gammas else self._vertices_etc[:, :3]

    def zero_grad(self):
        self._vertices.grad = None

    @torch.no_grad()
    def step(self):

        eps = 1e-8

        self._step += 1

        # spatial smoothing
        edges, _ = calc_edges(self._faces)  # E,2
        E = edges.shape[0]
        edge_smooth = self._smooth[edges]  # E,2,S
        neighbor_smooth = torch.zeros_like(self._smooth)  # V,S
        torch_scatter.scatter_mean(src=edge_smooth.flip(dims=[1]).reshape(E*2, -1), index=edges.reshape(E*2, 1), dim=0, out=neighbor_smooth)

        # apply optional smoothing of m1,m2,nu
        if self._gammas[0]:
            self._m1.lerp_(neighbor_smooth[:, 5:8], self._gammas[0])
        if self._gammas[1]:
            self._m2.lerp_(neighbor_smooth[:, 3], self._gammas[1])
        if self._gammas[2]:
            self._nu.lerp_(neighbor_smooth[:, 4], self._gammas[2])

        # add laplace smoothing to gradients
        laplace = self._vertices - neighbor_smooth[:, :3]
        grad = torch.addcmul(self._vertices.grad, laplace, self._nu[:, None], value=self._laplacian_weight)

        # gradient clipping
        if self._step > 1:
            grad_lim = self._m1.abs().mul_(self._grad_lim)
            grad.clamp_(min=-grad_lim, max=grad_lim)

        # moment updates
        lerp_unbiased(self._m1, grad, self._betas[0], self._step)
        lerp_unbiased(self._m2, (grad**2).sum(dim=-1), self._betas[1], self._step)

        velocity = self._m1 / self._m2[:, None].sqrt().add_(eps)  # V,3
        speed = velocity.norm(dim=-1)  # V

        if self._betas[2]:
            lerp_unbiased(self._nu, speed, self._betas[2], self._step)  # V
        else:
            self._nu.copy_(speed)  # V

        # update vertices
        ramped_lr = self._lr * min(1, self._step * (1-self._betas[0]) / self._ramp)
        self._vertices.add_(velocity * self._ref_len[:, None], alpha=-ramped_lr)

        # update target edge length
        if self._step % self._remesh_interval == 0:
            if self._local_edgelen:
|
176 |
+
len_change = (1 + (self._nu - self._nu_ref) * self._gain)
|
177 |
+
else:
|
178 |
+
len_change = (1 + (self._nu.mean() - self._nu_ref) * self._gain)
|
179 |
+
self._ref_len *= len_change
|
180 |
+
self._ref_len.clamp_(*self._edge_len_lims)
|
181 |
+
|
182 |
+
def remesh(self, flip:bool=True, poisson=False)->Tuple[torch.Tensor,torch.Tensor]:
|
183 |
+
min_edge_len = self._ref_len * (1 - self._edge_len_tol)
|
184 |
+
max_edge_len = self._ref_len * (1 + self._edge_len_tol)
|
185 |
+
|
186 |
+
self._vertices_etc,self._faces = remesh(self._vertices_etc,self._faces,min_edge_len,max_edge_len,flip, max_vertices=1e6)
|
187 |
+
|
188 |
+
self._split_vertices_etc()
|
189 |
+
self._vertices.requires_grad_()
|
190 |
+
|
191 |
+
return self._vertices, self._faces
|
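For orientation, a minimal sketch of how MeshOptimizer is intended to be driven; the loss function below is a placeholder, not part of this repository.

    # hypothetical driver loop; render_and_compare is a stand-in for any differentiable loss
    opt = MeshOptimizer(vertices, faces, edge_len_lims=(0.005, 0.15))
    vertices = opt.vertices
    for it in range(200):
        opt.zero_grad()
        loss = render_and_compare(vertices, opt.faces)   # placeholder, defined by the caller
        loss.backward()
        opt.step()                                       # Adam-like update + edge-length controller
        vertices, faces = opt.remesh()                   # collapse/split/flip, returns fresh tensors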
models/ISOMER/mesh_reconstruction/recon.py
ADDED
@@ -0,0 +1,58 @@
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from torchvision.utils import make_grid
from typing import List
from ..mesh_reconstruction.remesh import calc_vertex_normals
from ..mesh_reconstruction.opt import MeshOptimizer
from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from ..mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from ..scripts.utils import to_py3d_mesh, init_target

def reconstruct_stage1(pils: List[Image.Image], mv, proj, steps=100, vertices=None, faces=None, start_edge_len=0.15, end_edge_len=0.005, decay=0.995, return_mesh=True, loss_expansion_weight=0.1, gain=0.1, use_remesh=True):

    vertices, faces = vertices.to("cuda"), faces.to("cuda")

    renderer = NormalsRenderer(mv,proj,list(pils[0].size))

    target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s

    opt = MeshOptimizer(vertices,faces, local_edgelen=False, gain=gain, edge_len_lims=(end_edge_len, start_edge_len))

    vertices = opt.vertices

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt._lr *= decay

        normals = calc_vertex_normals(vertices,faces)
        images = renderer.render(vertices,normals,faces)

        loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()

        t_mask = images[..., -1] > 0.5
        loss_target_l2 = (images[t_mask] - target_images[t_mask]).abs().pow(2).mean()

        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_target_l2 + loss_alpha_target_mask_l2 + loss_expand * loss_expansion_weight

        loss_oob = (vertices.abs() > 0.99).float().mean() * 10
        loss = loss + loss_oob

        loss.backward()
        opt.step()

        if use_remesh:
            vertices,faces = opt.remesh(poisson=False)

    vertices, faces = vertices.detach(), faces.detach()

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
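A usage sketch with assumed inputs: four RGBA normal maps at azimuths 0/90/180/270 with zero elevation, and an initial sphere from scripts.fast_geo.create_sphere; the camera helper call mirrors how it is used in inference_pipeline.py.

    # assumed inputs: normal_pils is a list of 4 RGBA PIL images
    azi = torch.tensor([0., 90., 180., 270.])
    ele = torch.zeros(4)
    mv, proj = make_star_cameras_orthographic(azi, ele)
    init = create_sphere(0.5)  # hypothetical initialization from ..scripts.fast_geo
    vertices, faces = reconstruct_stage1(normal_pils, mv=mv, proj=proj, steps=200,
                                         vertices=init.verts_list()[0], faces=init.faces_list()[0],
                                         return_mesh=False)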
models/ISOMER/mesh_reconstruction/refine.py
ADDED
@@ -0,0 +1,86 @@
from tqdm import tqdm
from PIL import Image
import torch
import numpy as np
from typing import List
from ..mesh_reconstruction.remesh import calc_vertex_normals
from ..mesh_reconstruction.opt import MeshOptimizer
from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_orthographic_py3d
from ..mesh_reconstruction.render import NormalsRenderer, Pytorch3DNormalsRenderer
from ..scripts.project_mesh import multiview_color_projection, get_cameras_list
from ..scripts.utils import to_py3d_mesh, from_py3d_mesh, init_target

def run_mesh_refine(vertices, faces, pils: List[Image.Image], mv, proj, weights, cameras, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=10, update_warmup=10, return_mesh=True, process_inputs=True, process_outputs=True, use_remesh=True, loss_expansion_weight=0):

    if process_inputs:
        vertices = vertices * 2 / 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    poission_steps = []

    renderer = NormalsRenderer(mv,proj,list(pils[0].size))

    target_images = init_target(pils, new_bkgd=(0., 0., 0.)) # 4s

    opt = MeshOptimizer(vertices,faces, ramp=5, edge_len_lims=(end_edge_len, start_edge_len), local_edgelen=False, laplacian_weight=0.02)

    vertices = opt.vertices
    alpha_init = None

    mask = target_images[..., -1] < 0.5

    for i in tqdm(range(steps)):
        opt.zero_grad()
        opt._lr *= decay
        normals = calc_vertex_normals(vertices,faces)
        images = renderer.render(vertices,normals,faces)

        if alpha_init is None:
            alpha_init = images.detach()

        # update explicit target and render images for L_ET calculation
        if i < update_warmup or i % update_normal_interval == 0:
            with torch.no_grad():
                py3d_mesh = to_py3d_mesh(vertices, faces, normals)
                _, _, target_normal = from_py3d_mesh(multiview_color_projection(py3d_mesh, pils, cameras_list=cameras, weights=weights, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy='original', reweight_with_cosangle='linear'))
                target_normal = target_normal * 2 - 1
                target_normal = torch.nn.functional.normalize(target_normal, dim=-1)
                debug_images = renderer.render(vertices,target_normal,faces)

        d_mask = images[..., -1] > 0.5
        loss_debug_l2 = (images[..., :3][d_mask] - debug_images[..., :3][d_mask]).pow(2).mean()

        loss_alpha_target_mask_l2 = (images[..., -1][mask] - target_images[..., -1][mask]).pow(2).mean()

        loss = loss_debug_l2 + loss_alpha_target_mask_l2

        loss_oob = (vertices.abs() > 0.99).float().mean() * 10

        loss = loss + loss_oob

        # this loss_expand does not exist in original ISOMER. we add it here (but default loss_expansion_weight is 0)
        loss_expand = 0.5 * ((vertices+normals).detach() - vertices).pow(2).mean()
        loss += loss_expand * loss_expansion_weight

        loss.backward()
        opt.step()

        if use_remesh:
            vertices,faces = opt.remesh(poisson=(i in poission_steps))

    vertices, faces = vertices.detach(), faces.detach()

    if process_outputs:
        vertices = vertices / 2 * 1.35
        vertices[..., [0, 2]] = - vertices[..., [0, 2]]

    if return_mesh:
        return to_py3d_mesh(vertices, faces)
    else:
        return vertices, faces
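Note that the process_inputs and process_outputs transforms above are exact inverses, so a mesh passed through both comes back in its original coordinates; a small self-contained check (arbitrary assumed shapes):

    v = torch.randn(10, 3)
    v_in = v * 2 / 1.35
    v_in[..., [0, 2]] = -v_in[..., [0, 2]]
    v_out = v_in / 2 * 1.35
    v_out[..., [0, 2]] = -v_out[..., [0, 2]]
    assert torch.allclose(v, v_out)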
models/ISOMER/mesh_reconstruction/remesh.py
ADDED
@@ -0,0 +1,363 @@
# modified from https://github.com/Profactor/continuous-remeshing
import torch
import torch.nn.functional as tfunc
import torch_scatter
from typing import Tuple

def prepend_dummies(
        vertices:torch.Tensor, #V,D
        faces:torch.Tensor, #F,3 long
        )->Tuple[torch.Tensor,torch.Tensor]:
    """prepend dummy elements to vertices and faces to enable "masked" scatter operations"""
    V,D = vertices.shape
    vertices = torch.concat((torch.full((1,D),fill_value=torch.nan,device=vertices.device),vertices),dim=0)
    faces = torch.concat((torch.zeros((1,3),dtype=torch.long,device=faces.device),faces+1),dim=0)
    return vertices,faces

def remove_dummies(
        vertices:torch.Tensor, #V,D - first vertex all nan and unreferenced
        faces:torch.Tensor, #F,3 long - first face all zeros
        )->Tuple[torch.Tensor,torch.Tensor]:
    """remove dummy elements added with prepend_dummies()"""
    return vertices[1:],faces[1:]-1


def calc_edges(
        faces: torch.Tensor,  # F,3 long - first face may be dummy with all zeros
        with_edge_to_face: bool = False
        ) -> Tuple[torch.Tensor, ...]:
    """
    returns Tuple of
    - edges E,2 long, 0 for unused, lower vertex index first
    - face_to_edge F,3 long
    - (optional) edge_to_face shape=E,[left,right],[face,side]

    o-<-----e1     e0,e1...edge, e0<e1
    |      /A      L,R....left and right face
    |  L /  |      both triangles ordered counter clockwise
    |  / R  |      normals pointing out of screen
    V/      |
    e0---->-o
    """

    F = faces.shape[0]

    # make full edges, lower vertex index first
    face_edges = torch.stack((faces,faces.roll(-1,1)),dim=-1) #F*3,3,2
    full_edges = face_edges.reshape(F*3,2)
    sorted_edges,_ = full_edges.sort(dim=-1) #F*3,2

    # make unique edges
    edges,full_to_unique = torch.unique(input=sorted_edges,sorted=True,return_inverse=True,dim=0) #(E,2),(F*3)
    E = edges.shape[0]
    face_to_edge = full_to_unique.reshape(F,3) #F,3

    if not with_edge_to_face:
        return edges, face_to_edge

    is_right = full_edges[:,0]!=sorted_edges[:,0] #F*3
    edge_to_face = torch.zeros((E,2,2),dtype=torch.long,device=faces.device) #E,LR=2,S=2
    scatter_src = torch.cartesian_prod(torch.arange(0,F,device=faces.device),torch.arange(0,3,device=faces.device)) #F*3,2
    edge_to_face.reshape(2*E,2).scatter_(dim=0,index=(2*full_to_unique+is_right)[:,None].expand(F*3,2),src=scatter_src) #E,LR=2,S=2
    edge_to_face[0] = 0
    return edges, face_to_edge, edge_to_face

def calc_edge_length(
        vertices:torch.Tensor, #V,3 first may be dummy
        edges:torch.Tensor, #E,2 long, lower vertex index first, (0,0) for unused
        )->torch.Tensor: #E

    full_vertices = vertices[edges] #E,2,3
    a,b = full_vertices.unbind(dim=1) #E,3
    return torch.norm(a-b,p=2,dim=-1)

def calc_face_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """
         n
         |
         c0     corners ordered counterclockwise when
        /  \    looking onto surface (in neg normal direction)
      c1----c2
    """
    full_vertices = vertices[faces] #F,C=3,3
    v0,v1,v2 = full_vertices.unbind(dim=1) #F,3
    face_normals = torch.cross(v1-v0,v2-v0, dim=1) #F,3
    if normalize:
        face_normals = tfunc.normalize(face_normals, eps=1e-6, dim=1)
    return face_normals #F,3

def calc_vertex_normals(
        vertices:torch.Tensor, #V,3 first vertex may be unreferenced
        faces:torch.Tensor, #F,3 long, first face may be all zero
        face_normals:torch.Tensor=None, #F,3, not normalized
        )->torch.Tensor: #F,3

    F = faces.shape[0]

    if face_normals is None:
        face_normals = calc_face_normals(vertices,faces) # this no grad

    vertex_normals = torch.zeros((vertices.shape[0],3,3),dtype=vertices.dtype,device=vertices.device) #V,C=3,3

    vertex_normals.scatter_add_(dim=0,index=faces[:,:,None].expand(F,3,3),src=face_normals[:,None,:].expand(F,3,3)) # This no grad
    vertex_normals = vertex_normals.sum(dim=1) #V,3
    return tfunc.normalize(vertex_normals, eps=1e-6, dim=1)

def calc_face_ref_normals(
        faces:torch.Tensor, #F,3 long, 0 for unused
        vertex_normals:torch.Tensor, #V,3 first unused
        normalize:bool=False,
        )->torch.Tensor: #F,3
    """calculate reference normals for face flip detection"""
    full_normals = vertex_normals[faces] #F,C=3,3
    ref_normals = full_normals.sum(dim=1) #F,3
    if normalize:
        ref_normals = tfunc.normalize(ref_normals, eps=1e-6, dim=1)
    return ref_normals

def pack(
        vertices:torch.Tensor, #V,3 first unused and nan
        faces:torch.Tensor, #F,3 long, 0 for unused
        )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces), keeps first vertex unused
    """removes unused elements in vertices and faces"""
    V = vertices.shape[0]

    # remove unused faces
    used_faces = faces[:,0]!=0
    used_faces[0] = True
    faces = faces[used_faces] #sync

    # remove unused vertices
    used_vertices = torch.zeros(V,3,dtype=torch.bool,device=vertices.device)
    used_vertices.scatter_(dim=0,index=faces,value=True,reduce='add')
    used_vertices = used_vertices.any(dim=1)
    used_vertices[0] = True
    vertices = vertices[used_vertices] #sync

    # update used faces
    ind = torch.zeros(V,dtype=torch.long,device=vertices.device)
    V1 = used_vertices.sum()
    ind[used_vertices] = torch.arange(0,V1,device=vertices.device) #sync
    faces = ind[faces]

    return vertices,faces

def split_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        splits, #E bool
        pack_faces:bool=True,
        )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)

    #   c2                    c2               c...corners = faces
    #    . .                   . .             s...side_vert, 0 means no split
    #    .   .                 .N2.            S...shrunk_face
    #    .     .               .   .           Ni...new_faces
    #   s2      s1           s2|c2...s1|c1
    #    .        .           .     .  .
    #    .          .         . S .      .
    #    .            .       . .     N1  .
    #   c0...(s0=0)....c1    s0|c0...........c1
    #
    # pseudo-code:
    #   S = [s0|c0,s1|c1,s2|c2] example:[c0,s1,s2]
    #   split = side_vert!=0 example:[False,True,True]
    #   N0 = split[0]*[c0,s0,s2|c2] example:[0,0,0]
    #   N1 = split[1]*[c1,s1,s0|c0] example:[c1,s1,c0]
    #   N2 = split[2]*[c2,s2,s1|c1] example:[c2,s2,s1]

    V = vertices.shape[0]
    F = faces.shape[0]
    S = splits.sum().item() #sync

    if S==0:
        return vertices,faces

    edge_vert = torch.zeros_like(splits, dtype=torch.long) #E
    edge_vert[splits] = torch.arange(V,V+S,dtype=torch.long,device=vertices.device) #E 0 for no split, sync
    side_vert = edge_vert[face_to_edge] #F,3 long, 0 for no split
    split_edges = edges[splits] #S sync

    #vertices
    split_vertices = vertices[split_edges].mean(dim=1) #S,3
    vertices = torch.concat((vertices,split_vertices),dim=0)

    #faces
    side_split = side_vert!=0 #F,3
    shrunk_faces = torch.where(side_split,side_vert,faces) #F,3 long, 0 for no split
    new_faces = side_split[:,:,None] * torch.stack((faces,side_vert,shrunk_faces.roll(1,dims=-1)),dim=-1) #F,N=3,C=3
    faces = torch.concat((shrunk_faces,new_faces.reshape(F*3,3))) #4F,3
    if pack_faces:
        mask = faces[:,0]!=0
        mask[0] = True
        faces = faces[mask] #F',3 sync

    return vertices,faces

def collapse_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        priorities:torch.Tensor, #E float
        stable:bool=False, #only for unit testing
        )->Tuple[torch.Tensor,torch.Tensor]: #(vertices,faces)

    V = vertices.shape[0]

    # check spacing
    _,order = priorities.sort(stable=stable) #E
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vert_rank = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    edge_rank = rank #E
    for i in range(3):
        torch_scatter.scatter_max(src=edge_rank[:,None].expand(-1,2).reshape(-1),index=edges.reshape(-1),dim=0,out=vert_rank)
        edge_rank,_ = vert_rank[edges].max(dim=-1) #E
    candidates = edges[(edge_rank==rank).logical_and_(priorities>0)] #E',2

    # check connectivity
    vert_connections = torch.zeros(V,dtype=torch.long,device=vertices.device) #V
    vert_connections[candidates[:,0]] = 1 #start
    edge_connections = vert_connections[edges].sum(dim=-1) #E, edge connected to start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1))# one edge from start
    vert_connections[candidates] = 0 #clear start and end
    edge_connections = vert_connections[edges].sum(dim=-1) #E, one or two edges from start
    vert_connections.scatter_add_(dim=0,index=edges.reshape(-1),src=edge_connections[:,None].expand(-1,2).reshape(-1)) #one or two edges from start
    collapses = candidates[vert_connections[candidates[:,1]] <= 2] # E" not more than two connections between start and end

    # mean vertices
    vertices[collapses[:,0]] = vertices[collapses].mean(dim=1)

    # update faces
    dest = torch.arange(0,V,dtype=torch.long,device=vertices.device) #V
    dest[collapses[:,1]] = dest[collapses[:,0]]
    faces = dest[faces] #F,3
    c0,c1,c2 = faces.unbind(dim=-1)
    collapsed = (c0==c1).logical_or_(c1==c2).logical_or_(c0==c2)
    faces[collapsed] = 0

    return vertices,faces

def calc_face_collapses(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, 0 for unused
        edges:torch.Tensor, #E,2 long 0 for unused, lower vertex index first
        face_to_edge:torch.Tensor, #F,3 long 0 for unused
        edge_length:torch.Tensor, #E
        face_normals:torch.Tensor, #F,3
        vertex_normals:torch.Tensor, #V,3 first unused
        min_edge_length:torch.Tensor=None, #V
        area_ratio = 0.5, #collapse if area < min_edge_length**2 * area_ratio
        shortest_probability = 0.8
        )->torch.Tensor: #E edges to collapse

    E = edges.shape[0]
    F = faces.shape[0]

    # face flips
    ref_normals = calc_face_ref_normals(faces,vertex_normals,normalize=False) #F,3
    face_collapses = (face_normals*ref_normals).sum(dim=-1)<0 #F

    # small faces
    if min_edge_length is not None:
        min_face_length = min_edge_length[faces].mean(dim=-1) #F
        min_area = min_face_length**2 * area_ratio #F
        face_collapses.logical_or_(face_normals.norm(dim=-1) < min_area*2) #F
        face_collapses[0] = False

    # faces to edges
    face_length = edge_length[face_to_edge] #F,3

    if shortest_probability<1:
        #select shortest edge with shortest_probability chance
        randlim = round(2/(1-shortest_probability))
        rand_ind = torch.randint(0,randlim,size=(F,),device=faces.device).clamp_max_(2) #selected edge local index in face
        sort_ind = torch.argsort(face_length,dim=-1,descending=True) #F,3
        local_ind = sort_ind.gather(dim=-1,index=rand_ind[:,None])
    else:
        local_ind = torch.argmin(face_length,dim=-1)[:,None] #F,1 0...2 shortest edge local index in face

    edge_ind = face_to_edge.gather(dim=1,index=local_ind)[:,0] #F 0...E selected edge global index
    edge_collapses = torch.zeros(E,dtype=torch.long,device=vertices.device)
    edge_collapses.scatter_add_(dim=0,index=edge_ind,src=face_collapses.long())

    return edge_collapses.bool()

def flip_edges(
        vertices:torch.Tensor, #V,3 first unused
        faces:torch.Tensor, #F,3 long, first must be 0, 0 for unused
        edges:torch.Tensor, #E,2 long, first must be 0, 0 for unused, lower vertex index first
        edge_to_face:torch.Tensor, #E,[left,right],[face,side]
        with_border:bool=True, #handle border edges (D=4 instead of D=6)
        with_normal_check:bool=True, #check face normal flips
        stable:bool=False, #only for unit testing
        ):
    V = vertices.shape[0]
    E = edges.shape[0]
    device=vertices.device
    vertex_degree = torch.zeros(V,dtype=torch.long,device=device) #V long
    vertex_degree.scatter_(dim=0,index=edges.reshape(E*2),value=1,reduce='add')
    neighbor_corner = (edge_to_face[:,:,1] + 2) % 3 #go from side to corner
    neighbors = faces[edge_to_face[:,:,0],neighbor_corner] #E,LR=2
    edge_is_inside = neighbors.all(dim=-1) #E

    if with_border:
        # inside vertices should have D=6, border edges D=4, so we subtract 2 for all inside vertices
        # need to use float for masks in order to use scatter(reduce='multiply')
        vertex_is_inside = torch.ones(V,2,dtype=torch.float32,device=vertices.device) #V,2 float
        src = edge_is_inside.type(torch.float32)[:,None].expand(E,2) #E,2 float
        vertex_is_inside.scatter_(dim=0,index=edges,src=src,reduce='multiply')
        vertex_is_inside = vertex_is_inside.prod(dim=-1,dtype=torch.long) #V long
        vertex_degree -= 2 * vertex_is_inside #V long

    neighbor_degrees = vertex_degree[neighbors] #E,LR=2
    edge_degrees = vertex_degree[edges] #E,2
    #
    # loss = Sum_over_affected_vertices((new_degree-6)**2)
    # loss_change = Sum_over_neighbor_vertices((degree+1-6)**2-(degree-6)**2)
    #             + Sum_over_edge_vertices((degree-1-6)**2-(degree-6)**2)
    #             = 2 * (2 + Sum_over_neighbor_vertices(degree) - Sum_over_edge_vertices(degree))
    #
    loss_change = 2 + neighbor_degrees.sum(dim=-1) - edge_degrees.sum(dim=-1) #E
    candidates = torch.logical_and(loss_change<0, edge_is_inside) #E
    loss_change = loss_change[candidates] #E'
    if loss_change.shape[0]==0:
        return

    edges_neighbors = torch.concat((edges[candidates],neighbors[candidates]),dim=-1) #E',4
    _,order = loss_change.sort(descending=True, stable=stable) #E'
    rank = torch.zeros_like(order)
    rank[order] = torch.arange(0,len(rank),device=rank.device)
    vertex_rank = torch.zeros((V,4),dtype=torch.long,device=device) #V,4
    torch_scatter.scatter_max(src=rank[:,None].expand(-1,4),index=edges_neighbors,dim=0,out=vertex_rank)
    vertex_rank,_ = vertex_rank.max(dim=-1) #V
    neighborhood_rank,_ = vertex_rank[edges_neighbors].max(dim=-1) #E'
    flip = rank==neighborhood_rank #E'

    if with_normal_check:
        #  cl-<-----e1     e0,e1...edge, e0<e1
        #   |      /A      L,R....left and right face
        #   |  L /  |      both triangles ordered counter clockwise
        #   |  / R  |      normals pointing out of screen
        #   V/      |
        #   e0---->-cr
        v = vertices[edges_neighbors] #E",4,3
        v = v - v[:,0:1] #make relative to e0
        e1 = v[:,1]
        cl = v[:,2]
        cr = v[:,3]
        n = torch.cross(e1,cl) + torch.cross(cr,e1) #sum of old normal vectors
        flip.logical_and_(torch.sum(n*torch.cross(cr,cl),dim=-1)>0) #first new face
        flip.logical_and_(torch.sum(n*torch.cross(cl-e1,cr-e1),dim=-1)>0) #second new face

    flip_edges_neighbors = edges_neighbors[flip] #E",4
    flip_edge_to_face = edge_to_face[candidates,:,0][flip] #E",2
    flip_faces = flip_edges_neighbors[:,[[0,3,2],[1,2,3]]] #E",2,3
    faces.scatter_(dim=0,index=flip_edge_to_face.reshape(-1,1).expand(-1,3),src=flip_faces.reshape(-1,3))
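A small sanity-check sketch of the dummy-element convention used throughout this module, on two triangles sharing one edge (a quad split along its diagonal); values shown in the comments are what the convention implies.

    verts = torch.tensor([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
    faces = torch.tensor([[0, 1, 2], [0, 2, 3]], dtype=torch.long)
    verts, faces = prepend_dummies(verts, faces)      # first vertex becomes nan, first face all zeros
    edges, face_to_edge = calc_edges(faces)           # 6 edges: the (0,0) dummy plus 5 real edges
    lengths = calc_edge_length(verts, edges)          # lengths[0] is nan (dummy edge), the rest are finite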
models/ISOMER/mesh_reconstruction/render.py
ADDED
@@ -0,0 +1,142 @@
# modified from https://github.com/Profactor/continuous-remeshing
import nvdiffrast.torch as dr
import torch
from typing import Tuple

def _warmup(glctx, device=None):
    device = 'cuda' if device is None else device
    #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
    def tensor(*args, **kwargs):
        return torch.tensor(*args, device=device, **kwargs)

    # defines a triangle in homogeneous coordinates and calls dr.rasterize to render this triangle, which may help to initialize or warm up the GPU context
    pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
    tri = tensor([[0, 1, 2]], dtype=torch.int32)
    dr.rasterize(glctx, pos, tri, resolution=[256, 256])

# glctx = dr.RasterizeGLContext(output_db=False, device="cuda")
glctx = dr.RasterizeCudaContext(device="cuda")

class NormalsRenderer:

    _glctx:dr.RasterizeCudaContext = None

    def __init__(
        self,
        mv: torch.Tensor, #C,4,4 # normal column-major (unlike pytorch3d)
        proj: torch.Tensor, #C,4,4
        image_size: Tuple[int,int],
        mvp = None,
        device=None,
        ):
        if mvp is None:
            self._mvp = proj @ mv #C,4,4
        else:
            self._mvp = mvp
        self._image_size = image_size
        self._glctx = glctx
        _warmup(self._glctx, device)

    def render(self,
        vertices: torch.Tensor, #V,3 float
        normals: torch.Tensor, #V,3 float in [-1, 1]
        faces: torch.Tensor, #F,3 long
        ) ->torch.Tensor: #C,H,W,4

        V = vertices.shape[0]
        faces = faces.type(torch.int32)
        vert_hom = torch.cat((vertices, torch.ones(V,1,device=vertices.device)),axis=-1) #V,3 -> V,4
        # transforms the vertices into clip space using the mvp matrix.
        vertices_clip = vert_hom @ self._mvp.transpose(-2,-1) #C,V,4 # the .transpose(-2,-1) operation ensures that the matrix multiplication aligns with the row-major convention.
        rast_out,_ = dr.rasterize(self._glctx, vertices_clip, faces, resolution=self._image_size, grad_db=False) #C,H,W,4 -> 4 includes the barycentric coordinates and other data.
        vert_col = (normals+1)/2 #V,3
        # this function takes the attributes (colors) defined at the vertices and computes their values at each pixel (or fragment) within the triangles
        col,_ = dr.interpolate(vert_col, rast_out, faces) #C,H,W,3
        alpha = torch.clamp(rast_out[..., -1:], max=1) #C,H,W,1
        col = torch.concat((col,alpha),dim=-1) #C,H,W,4
        col = dr.antialias(col, rast_out, vertices_clip, faces) #C,H,W,4
        return col #C,H,W,4


from pytorch3d.structures import Meshes
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.renderer import (
    RasterizationSettings,
    MeshRendererWithFragments,
    TexturesVertex,
    MeshRasterizer,
    BlendParams,
    FoVOrthographicCameras,
    look_at_view_transform,
    hard_rgb_blend,
)

class VertexColorShader(ShaderBase):
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = meshes.sample_textures(fragments)
        return hard_rgb_blend(texels, fragments, blend_params)

def render_mesh_vertex_color(mesh, cameras, H, W, blur_radius=0.0, faces_per_pixel=1, bkgd=(0., 0., 0.), dtype=torch.float32, device="cuda"):
    if len(mesh) != len(cameras):
        if len(cameras) % len(mesh) == 0:
            mesh = mesh.extend(len(cameras))
        else:
            raise NotImplementedError()

    # render requires everything in float16 or float32
    input_dtype = dtype
    blend_params = BlendParams(1e-4, 1e-4, bkgd)

    # Define the settings for rasterization and shading
    raster_settings = RasterizationSettings(
        image_size=(H, W),
        blur_radius=blur_radius,
        faces_per_pixel=faces_per_pixel,
        clip_barycentric_coords=True,
        bin_size=None,
        max_faces_per_bin=None,
    )

    # Create a renderer by composing a rasterizer and a shader
    # We simply render vertex colors through the custom VertexColorShader (no lighting, materials are used)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=VertexColorShader(
            device=device,
            cameras=cameras,
            blend_params=blend_params
        )
    )

    # render RGB and depth, get mask
    with torch.autocast(dtype=input_dtype, device_type=torch.device(device).type):
        images, _ = renderer(mesh)
    return images # BHW4

class Pytorch3DNormalsRenderer: # 100 times slower!!!
    def __init__(self, cameras, image_size, device):
        self.cameras = cameras.to(device)
        self._image_size = image_size
        self.device = device

    def render(self,
        vertices: torch.Tensor, #V,3 float
        normals: torch.Tensor, #V,3 float in [-1, 1]
        faces: torch.Tensor, #F,3 long
        ) ->torch.Tensor: #C,H,W,4
        mesh = Meshes(verts=[vertices], faces=[faces], textures=TexturesVertex(verts_features=[(normals + 1) / 2])).to(self.device)
        return render_mesh_vertex_color(mesh, self.cameras, self._image_size[0], self._image_size[1], device=self.device)

def save_tensor_to_img(tensor, save_dir):
    from PIL import Image
    import numpy as np
    for idx, img in enumerate(tensor):
        img = img[..., :3].cpu().numpy()
        img = (img * 255).astype(np.uint8)
        img = Image.fromarray(img)
        img.save(save_dir + f"{idx}.png")
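A sketch of how NormalsRenderer is driven elsewhere in this repo; vertices and faces are assumed to be CUDA tensors and the two views (front and back) are an arbitrary choice for illustration.

    # from ..mesh_reconstruction.func import make_star_cameras_orthographic
    # from ..mesh_reconstruction.remesh import calc_vertex_normals
    mv, proj = make_star_cameras_orthographic(torch.tensor([0., 180.]), torch.zeros(2))
    renderer = NormalsRenderer(mv, proj, [512, 512])
    normals = calc_vertex_normals(vertices, faces)
    images = renderer.render(vertices, normals, faces)   # C,H,W,4: normals mapped to [0,1] plus alpha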
models/ISOMER/model/__init__.py
ADDED
File without changes
models/ISOMER/model/inference_pipeline.py
ADDED
@@ -0,0 +1,189 @@
import os
import numpy as np
import torch
from PIL import Image

from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex

from ..scripts.fast_geo import fast_geo, create_sphere, create_box
from ..scripts.project_mesh import get_cameras_list_azi_ele
from ..mesh_reconstruction.recon import reconstruct_stage1
from ..mesh_reconstruction.refine import run_mesh_refine
from ..mesh_reconstruction.func import make_star_cameras_orthographic, make_star_cameras_perspective

from ..data.utils import (
    simple_remove_bkg_normal,
    load_glb,
    load_obj_with_verts_faces)
from ..scripts.utils import (
    to_pyml_mesh,
    simple_clean_mesh,
    normal_rotation_img2img_c2w,
    rotate_normal_R,
    get_rotation_matrix_azi_ele,
    manage_elevation_azimuth)

@torch.enable_grad()
def reconstruction_pipe(normal_pils,
                        rotation_angles_azi,
                        rotation_angles_ele,
                        front_index=0,
                        back_index=2,
                        side_index=1,
                        weights=None,
                        expansion_weight=0.1,
                        expansion_weight_stage2=0.0,
                        init_type="ball",
                        sphere_r=None, # only used if init_type=="ball"
                        box_width=1.0, # only used if init_type=="box"
                        box_length=1.0, # only used if init_type=="box"
                        box_height=1.0, # only used if init_type=="box"
                        init_verts=None,
                        init_faces=None,
                        init_mesh_from_file="",
                        stage1_steps=200,
                        stage2_steps=200,
                        projection_type="orthographic",
                        fovy=None,
                        radius=None,
                        ortho_dist=1.1,
                        camera_angles_azi=None,
                        camera_angles_ele=None,
                        rm_bkg=False,
                        rm_bkg_with_rembg=False, # only used if rm_bkg
                        normal_rotation_R=None,
                        train_stage1=True,
                        train_stage2=True,
                        use_remesh_stage1=True,
                        use_remesh_stage2=True,
                        start_edge_len_stage1=0.1,
                        end_edge_len_stage1=0.02,
                        start_edge_len_stage2=0.02,
                        end_edge_len_stage2=0.005,
                        ):

    assert projection_type in ['perspective', 'orthographic'], f"projection_type ({projection_type}) should be one of ['perspective', 'orthographic']"

    if stage1_steps == 0:
        train_stage1 = False
    if stage2_steps == 0:
        train_stage2 = False

    if normal_rotation_R is not None:
        assert normal_rotation_R.shape[-2] == 3 and normal_rotation_R.shape[-1] == 3
        assert len(normal_rotation_R.shape) == 2
        normal_rotation_R = normal_rotation_R.float()

    camera_angles_azi = camera_angles_azi.float()
    camera_angles_ele = camera_angles_ele.float()

    camera_angles_ele, camera_angles_azi = manage_elevation_azimuth(camera_angles_ele, camera_angles_azi)

    if init_type in ["std", "thin"]:
        assert camera_angles_azi[front_index]%360==0, f"the camera_angles_azi associated with front image (index {front_index}) should be 0 not {camera_angles_azi[front_index]}"
        assert camera_angles_azi[back_index]%360==180, f"the camera_angles_azi associated with back image (index {back_index}) should be 180 not {camera_angles_azi[back_index]}"
        assert camera_angles_azi[side_index]%360==90, f"the camera_angles_azi associated with left side image (index {side_index}) should be 90, not {camera_angles_azi[back_index]}"

    if rm_bkg:
        if rm_bkg_with_rembg:
            os.environ["OMP_NUM_THREADS"] = '8'
        normal_pils = simple_remove_bkg_normal(normal_pils,rm_bkg_with_rembg)

    if rotation_angles_azi is not None:
        rotation_angles_azi = -rotation_angles_azi.float()
        rotation_angles_ele = rotation_angles_ele.float()

        rotation_angles_ele, rotation_angles_azi = manage_elevation_azimuth(rotation_angles_ele, rotation_angles_azi)

        assert len(normal_pils) == len(rotation_angles_azi), f'len(normal_pils) ({len(normal_pils)}) != len(rotation_angles_azi) ({len(rotation_angles_azi)})'
        if rotation_angles_ele is None:
            rotation_angles_ele = [0] * len(normal_pils)

        normal_pils_rotated = []
        for i in range(len(normal_pils)):
            c2w_R = get_rotation_matrix_azi_ele(rotation_angles_azi[i], rotation_angles_ele[i])
            rotated_ = normal_rotation_img2img_c2w(normal_pils[i], c2w=c2w_R)
            normal_pils_rotated.append(rotated_)

        normal_pils = normal_pils_rotated

    if normal_rotation_R is not None:
        normal_pils_rotated = []
        for i in range(len(normal_pils)):
            rotated_ = rotate_normal_R(normal_pils[i], normal_rotation_R, save_addr="", device="cuda")
            normal_pils_rotated.append(rotated_)

        normal_pils = normal_pils_rotated

    normal_stg1 = [img for img in normal_pils]

    if init_type in ['thin', 'std']:
        front_ = normal_stg1[front_index]
        back_ = normal_stg1[back_index]
        side_ = normal_stg1[side_index]
        meshes, depth_front, depth_back, mesh_front, mesh_back = fast_geo(front_, back_, side_, init_type=init_type, return_depth_and_sep_mesh=True)

    elif init_type in ["ball", "box"]:

        if init_type == "ball":
            assert sphere_r is not None, f"sphere_r ({sphere_r}) should not be None when init_type is 'ball'"
            meshes = create_sphere(sphere_r)

        if init_type == "box":
            assert box_width is not None and box_length is not None and box_height is not None, f"box_width ({box_width}), box_length ({box_length}), and box_height ({box_height}) should not be None when init_type is 'box'"
            meshes = create_box(width=box_width, length=box_length, height=box_height)

        # add texture just in case
        num_meshes = len(meshes)
        num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes
        black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda")
        textures = TexturesVertex(verts_features=black_texture)
        meshes.textures = textures

    elif init_type == "file":
        assert init_mesh_from_file or (init_verts is not None and init_faces is not None), f"init_mesh_from_file ({init_mesh_from_file}) should not be None when init_type is 'file', else init_verts and init_faces should not be None"

        if init_verts is not None and init_faces is not None:
            meshes = Meshes(verts=[init_verts], faces=[init_faces]).to('cuda')
        elif init_mesh_from_file.endswith('.glb'):
            meshes = load_glb(init_mesh_from_file).to('cuda')
        else:
            meshes = load_obj_with_verts_faces(init_mesh_from_file).to('cuda')

        # add texture just in case
        num_meshes = len(meshes)
        num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes
        black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda")
        textures = TexturesVertex(verts_features=black_texture)
        meshes.textures = textures

    if projection_type == 'perspective':
        assert fovy is not None and radius is not None, f"fovy ({fovy}) and radius ({radius}) should not be None when projection_type is 'perspective'"
        cameras = get_cameras_list_azi_ele(camera_angles_azi, camera_angles_ele, fov_in_degrees=fovy,device="cuda", dist=radius, cam_type='fov')

    elif projection_type == 'orthographic':
        cameras = get_cameras_list_azi_ele(camera_angles_azi, camera_angles_ele, fov_in_degrees=fovy, device="cuda", focal=1., dist=ortho_dist, cam_type='orthographic')

    vertices, faces = meshes.verts_list()[0], meshes.faces_list()[0]

    render_camera_angles_azi = -camera_angles_azi
    render_camera_angles_ele = camera_angles_ele
    if projection_type == 'orthographic':
        mv, proj = make_star_cameras_orthographic(render_camera_angles_azi, render_camera_angles_ele)
    else:
        mv, proj = make_star_cameras_perspective(render_camera_angles_azi, render_camera_angles_ele, distance=radius, r=radius, fov=fovy, device='cuda')

    # stage 1
    if train_stage1:
        vertices, faces = reconstruct_stage1(normal_stg1, mv=mv, proj=proj, steps=stage1_steps, vertices=vertices, faces=faces, start_edge_len=start_edge_len_stage1, end_edge_len=end_edge_len_stage1, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight, use_remesh=use_remesh_stage1)

    # stage 2
    if train_stage2:
        vertices, faces = run_mesh_refine(vertices, faces, normal_pils, mv=mv, proj=proj, weights=weights, steps=stage2_steps, start_edge_len=start_edge_len_stage2, end_edge_len=end_edge_len_stage2, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False, cameras=cameras, use_remesh=use_remesh_stage2, loss_expansion_weight=expansion_weight_stage2)

    meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda")

    return meshes
models/ISOMER/projection_func.py
ADDED
@@ -0,0 +1,86 @@
import os
import numpy as np
import torch
from PIL import Image
from .scripts.proj_commands import projection as isomer_projection
from .data.utils import simple_remove_bkg_normal

# mesh_address,
def projection(
        meshes,
        masks,
        images,
        azimuths,
        elevations,
        weights,
        fov,
        radius,
        save_dir,
        save_glb_addr=None,
        remove_background=False,
        auto_center=False,
        projection_type="perspective",
        below_confidence_strategy="smooth",
        complete_unseen=True,
        mesh_scale_factor=1.0,
        rm_bkg_with_rembg=True,
        ):

    if save_glb_addr is None:
        os.makedirs(save_dir, exist_ok=True)
        save_glb_addr = os.path.join(save_dir, "rgb_projected.glb")

    bs = len(images)
    assert len(azimuths) == bs, f'len(azimuths) ({len(azimuths)} != batchsize ({bs}))'
    assert len(elevations) == bs, f'len(elevations) ({len(elevations)} != batchsize ({bs}))'
    assert len(weights) == bs, f'len(weights) ({len(weights)} != batchsize ({bs}))'

    image_rgba = torch.cat([images[:,:,:,:3], masks.unsqueeze(-1)], dim=-1)

    assert image_rgba.shape[-1] == 4, f'image_rgba.shape is {image_rgba.shape}'

    img_list = [Image.fromarray((image.cpu()*255).numpy().astype(np.uint8)) for image in image_rgba]

    if remove_background:
        if rm_bkg_with_rembg:
            os.environ["OMP_NUM_THREADS"] = '8'
        img_list = simple_remove_bkg_normal(img_list, rm_bkg_with_rembg, return_Image=True)

    resolution = img_list[0].size[0]
    new_img_list = []
    for i in range(len(img_list)):
        new_img = img_list[i].resize((resolution,resolution))

        path_dir = os.path.join(save_dir, f'projection_images')
        os.makedirs(path_dir, exist_ok=True)

        path_ = os.path.join(path_dir, f'ProjectionImg{i}.png')

        new_img.save(path_)

        new_img_list.append(new_img)

    img_list = new_img_list

    isomer_projection(meshes,
                    img_list=img_list,
                    weights=weights,
                    azimuths=azimuths,
                    elevations=elevations,
                    projection_type=projection_type,
                    auto_center=auto_center,
                    resolution=resolution,
                    fovy=fov,
                    radius=radius,
                    scale_factor=mesh_scale_factor,
                    save_glb_addr=save_glb_addr,
                    scale_verts=True,
                    complete_unseen=complete_unseen,
                    below_confidence_strategy=below_confidence_strategy
                    )

    return save_glb_addr
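A hedged usage sketch; the tensor shapes and the concrete angle/weight values are assumptions, not values fixed by this file (images B,H,W,3 in [0,1], masks B,H,W, both on the same device as the mesh).

    save_glb_addr = projection(
        meshes,                                   # pytorch3d Meshes from the reconstruction step
        masks=masks, images=images,
        azimuths=[0., 90., 180., 270.], elevations=[0., 0., 0., 0.],
        weights=[2.0, 0.5, 1.0, 0.5],             # assumed per-view weights
        fov=30, radius=4.0,
        save_dir="./out", projection_type="perspective",
    )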
models/ISOMER/reconstruction_func.py
ADDED
@@ -0,0 +1,88 @@
import os
import numpy as np
import torch
from PIL import Image
from .model.inference_pipeline import reconstruction_pipe

def reconstruction(
        normal_pils,
        masks,
        weights,
        fov,
        radius,
        camera_angles_azi,
        camera_angles_ele,
        expansion_weight_stage1=0.1,
        init_type="ball",
        init_verts=None,
        init_faces=None,
        init_mesh_from_file="",
        stage1_steps=200,
        stage2_steps=200,
        projection_type="perspective",
        need_normal_rotation=False,
        rotation_angles_azi=None, # only used if need_normal_rotation
        rotation_angles_ele=None, # only used if need_normal_rotation
        normal_rotation_R=None, # only used if need_normal_rotation
        rm_bkg=False,
        rm_bkg_with_rembg=True, # only used if rm_bkg
        start_edge_len_stage1=0.1,
        end_edge_len_stage1=0.02,
        start_edge_len_stage2=0.02,
        end_edge_len_stage2=0.005,
        expansion_weight_stage2=0.0,
        ):

    if init_type == "file":
        assert ((init_verts is not None and init_faces is not None) or init_mesh_from_file), f'init_mesh_from_file or (init_verts and init_faces) must be provided if init_type=="file"'

    if not need_normal_rotation:
        rotation_angles_azi = None
        rotation_angles_ele = None
        normal_rotation_R = None

    bs = len(normal_pils)

    assert len(camera_angles_azi) == bs, f'len(camera_angles_azi) ({len(camera_angles_azi)} != batchsize ({bs}))'
    assert len(camera_angles_ele) == bs, f'len(camera_angles_ele) ({len(camera_angles_ele)} != batchsize ({bs}))'

    normal_pils_rgba = torch.cat([normal_pils[:,:,:,:3], masks.unsqueeze(-1)], dim=-1)

    assert normal_pils_rgba.shape[-1] == 4, f'normal_pils_rgba.shape is {normal_pils_rgba.shape}'

    normal_pils = [Image.fromarray((normal_pil.cpu()*255).numpy().astype(np.uint8)) for normal_pil in normal_pils_rgba]

    meshes = reconstruction_pipe(
        normal_pils=normal_pils,
        rotation_angles_azi=rotation_angles_azi,
        rotation_angles_ele=rotation_angles_ele,
        weights=weights,
        expansion_weight=expansion_weight_stage1,
        init_type=init_type,
        stage1_steps=stage1_steps,
        stage2_steps=stage2_steps,
        projection_type=projection_type,
        fovy=fov,
        radius=radius,
        camera_angles_azi=camera_angles_azi,
        camera_angles_ele=camera_angles_ele,
        rm_bkg=rm_bkg, rm_bkg_with_rembg=rm_bkg_with_rembg,
        normal_rotation_R=normal_rotation_R,
        init_mesh_from_file=init_mesh_from_file,
        start_edge_len_stage1=start_edge_len_stage1,
        end_edge_len_stage1=end_edge_len_stage1,
        start_edge_len_stage2=start_edge_len_stage2,
        end_edge_len_stage2=end_edge_len_stage2,
        expansion_weight_stage2=expansion_weight_stage2,
        init_verts=init_verts,
        init_faces=init_faces,
    )

    return meshes
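A usage sketch under stated assumptions: normal_maps is a B,H,W,3 tensor in [0,1] and masks is B,H,W on the GPU; the angles match the front/side/back ordering expected by the "std" initialization, and the per-view weights are illustrative values only.

    meshes = reconstruction(
        normal_pils=normal_maps, masks=masks,
        weights=[2.0, 0.5, 1.0, 0.5],                     # assumed per-view weights
        fov=30, radius=4.0,
        camera_angles_azi=torch.tensor([0., 90., 180., 270.]),
        camera_angles_ele=torch.zeros(4),
        init_type="std",                                  # front/back/side initialization
        projection_type="perspective",
    )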
models/ISOMER/scripts/__init__.py
ADDED
File without changes
models/ISOMER/scripts/all_typing.py
ADDED
@@ -0,0 +1,42 @@
# code from https://github.com/threestudio-project

"""
This module contains type annotations for the project, using
1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors

Two types of typing checking can be used:
1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
"""

# Basic types
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    NamedTuple,
    NewType,
    Optional,
    Sized,
    Tuple,
    Type,
    TypeVar,
    Union,
)

# Tensor dtype
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt

# Config type
from omegaconf import DictConfig

# PyTorch Tensor type
from torch import Tensor

# Runtime type checking decorator
from typeguard import typechecked as typechecker
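These re-exports are intended to be star-imported so tensor shapes and dtypes appear directly in signatures; a small illustrative sketch (the function name and shape string are hypothetical, not part of this repo).

    # hypothetical helper annotated with the symbols re-exported above
    def normalize_points(points: Float[Tensor, "N 3"]) -> Float[Tensor, "N 3"]:
        return points / points.norm(dim=-1, keepdim=True).clamp_min(1e-8)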
models/ISOMER/scripts/fast_geo.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
from .mesh_init import build_mesh, calc_w_over_h, fix_border_with_pymeshlab_fast
|
4 |
+
from pytorch3d.structures import Meshes, join_meshes_as_scene
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from pytorch3d.structures import Meshes
|
9 |
+
from pytorch3d.utils import ico_sphere
|
10 |
+
|
11 |
+
def create_sphere(radius, device='cuda'):
|
12 |
+
|
13 |
+
sphere_mesh = ico_sphere(3, device=device) # Increase the subdivision level (e.g., 2) for higher resolution sphere
|
14 |
+
sphere_mesh = sphere_mesh.scale_verts(radius)
|
15 |
+
|
16 |
+
meshes = Meshes(verts=[sphere_mesh.verts_list()[0]], faces=[sphere_mesh.faces_list()[0]])
|
17 |
+
return meshes
|
18 |
+
|
19 |
+
|
20 |
+
def create_box(width, length, height, device='cuda'):
|
21 |
+
"""
|
22 |
+
Create a box mesh given the width, length, and height.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
width (float): Width of the box.
|
26 |
+
length (float): Length of the box.
|
27 |
+
height (float): Height of the box.
|
28 |
+
device (str): Device for the tensor operations, default is 'cuda'.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
Meshes: A PyTorch3D Meshes object representing the box.
|
32 |
+
"""
|
33 |
+
# Define the 8 vertices of the box
|
34 |
+
verts = torch.tensor([
|
35 |
+
[-width / 2, -length / 2, -height / 2],
|
36 |
+
[ width / 2, -length / 2, -height / 2],
|
37 |
+
[ width / 2, length / 2, -height / 2],
|
38 |
+
[-width / 2, length / 2, -height / 2],
|
39 |
+
[-width / 2, -length / 2, height / 2],
|
40 |
+
[ width / 2, -length / 2, height / 2],
|
41 |
+
[ width / 2, length / 2, height / 2],
|
42 |
+
[-width / 2, length / 2, height / 2]
|
43 |
+
], device=device)
|
44 |
+
|
45 |
+
# Define the 12 triangles (faces) of the box using vertex indices
|
46 |
+
faces = torch.tensor([
|
47 |
+
[0, 1, 2], [0, 2, 3], # Bottom face
|
48 |
+
[4, 5, 6], [4, 6, 7], # Top face
|
49 |
+
[0, 1, 5], [0, 5, 4], # Front face
|
50 |
+
[1, 2, 6], [1, 6, 5], # Right face
|
51 |
+
[2, 3, 7], [2, 7, 6], # Back face
|
52 |
+
[3, 0, 4], [3, 4, 7] # Left face
|
53 |
+
], device=device)
|
54 |
+
|
55 |
+
# Create the Meshes object
|
56 |
+
meshes = Meshes(verts=[verts], faces=[faces])
|
57 |
+
|
58 |
+
return meshes
|
59 |
+
|
60 |
+
|
61 |
+
# stage 0: initial mesh estimation
|
62 |
+
def fast_geo(front_normal: Image.Image, back_normal: Image.Image, side_normal: Image.Image, clamp=0., init_type="std", return_depth_and_sep_mesh=False):
|
63 |
+
|
64 |
+
import time
|
65 |
+
assert front_normal.mode != "RGB"
|
66 |
+
assert back_normal.mode != "RGB"
|
67 |
+
assert side_normal.mode != "RGB"
|
68 |
+
|
69 |
+
front_normal = front_normal.resize((192, 192))
|
70 |
+
back_normal = back_normal.resize((192, 192))
|
71 |
+
side_normal = side_normal.resize((192, 192))
|
72 |
+
|
73 |
+
# build mesh with front back projection # ~3s
|
74 |
+
side_w_over_h = calc_w_over_h(side_normal)
|
75 |
+
mesh_front, depth_front = build_mesh(front_normal, front_normal, clamp_min=clamp, scale=side_w_over_h, init_type=init_type, return_depth=True)
|
76 |
+
mesh_back, depth_back = build_mesh(back_normal, back_normal, is_back=True, clamp_min=clamp, scale=side_w_over_h, init_type=init_type, return_depth=True)
|
77 |
+
meshes = join_meshes_as_scene([mesh_front, mesh_back])
|
78 |
+
|
79 |
+
# poisson reconstruction which guarantees a smooth connection between meshes
|
80 |
+
# and simplify the result to roughly 2000 faces
|
81 |
+
meshes = fix_border_with_pymeshlab_fast(meshes, poissson_depth=6, simplification=2000)
|
82 |
+
|
83 |
+
|
84 |
+
if return_depth_and_sep_mesh:
|
85 |
+
return meshes, depth_front, depth_back, mesh_front, mesh_back
|
86 |
+
return meshes
|
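A minimal usage sketch for fast_geo, assuming placeholder RGBA normal-map files and a CUDA device (the asserts above reject plain RGB inputs):
from PIL import Image
from models.ISOMER.scripts.fast_geo import fast_geo

front = Image.open("front_normal.png").convert("RGBA")   # placeholder paths
back = Image.open("back_normal.png").convert("RGBA")
side = Image.open("side_normal.png").convert("RGBA")

# stage-0 coarse mesh built from front/back projection plus Poisson smoothing
coarse_mesh = fast_geo(front, back, side, clamp=0., init_type="std")
print(coarse_mesh.verts_packed().shape, coarse_mesh.faces_packed().shape)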
models/ISOMER/scripts/load_onnx.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
import onnxruntime
|
2 |
+
import torch
|
3 |
+
|
4 |
+
providers = [
|
5 |
+
('TensorrtExecutionProvider', {
|
6 |
+
'device_id': 0,
|
7 |
+
'trt_max_workspace_size': 8 * 1024 * 1024 * 1024,
|
8 |
+
'trt_fp16_enable': True,
|
9 |
+
'trt_engine_cache_enable': True,
|
10 |
+
}),
|
11 |
+
('CUDAExecutionProvider', {
|
12 |
+
'device_id': 0,
|
13 |
+
'arena_extend_strategy': 'kSameAsRequested',
|
14 |
+
'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
|
15 |
+
'cudnn_conv_algo_search': 'HEURISTIC',
|
16 |
+
})
|
17 |
+
]
|
18 |
+
|
19 |
+
def load_onnx(file_path: str):
|
20 |
+
assert file_path.endswith(".onnx")
|
21 |
+
sess_opt = onnxruntime.SessionOptions()
|
22 |
+
ort_session = onnxruntime.InferenceSession(file_path, sess_opt=sess_opt, providers=providers)
|
23 |
+
return ort_session
|
24 |
+
|
25 |
+
|
26 |
+
def load_onnx_caller(file_path: str, single_output=False):
|
27 |
+
ort_session = load_onnx(file_path)
|
28 |
+
def caller(*args):
|
29 |
+
torch_input = isinstance(args[0], torch.Tensor)
|
30 |
+
if torch_input:
|
31 |
+
torch_input_dtype = args[0].dtype
|
32 |
+
torch_input_device = args[0].device
|
33 |
+
# check all are torch.Tensor and have same dtype and device
|
34 |
+
assert all([isinstance(arg, torch.Tensor) for arg in args]), "All inputs should be torch.Tensor, if first input is torch.Tensor"
|
35 |
+
assert all([arg.dtype == torch_input_dtype for arg in args]), "All inputs should have same dtype, if first input is torch.Tensor"
|
36 |
+
assert all([arg.device == torch_input_device for arg in args]), "All inputs should have same device, if first input is torch.Tensor"
|
37 |
+
args = [arg.cpu().float().numpy() for arg in args]
|
38 |
+
|
39 |
+
ort_inputs = {ort_session.get_inputs()[idx].name: args[idx] for idx in range(len(args))}
|
40 |
+
ort_outs = ort_session.run(None, ort_inputs)
|
41 |
+
|
42 |
+
if torch_input:
|
43 |
+
ort_outs = [torch.tensor(ort_out, dtype=torch_input_dtype, device=torch_input_device) for ort_out in ort_outs]
|
44 |
+
|
45 |
+
if single_output:
|
46 |
+
return ort_outs[0]
|
47 |
+
return ort_outs
|
48 |
+
return caller
|
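A minimal usage sketch for load_onnx_caller; the .onnx path is a placeholder and a CUDA-capable onnxruntime build is assumed by the provider list above:
import torch
from models.ISOMER.scripts.load_onnx import load_onnx_caller

run_model = load_onnx_caller("ckpt/example_model.onnx", single_output=True)
x = torch.randn(1, 3, 256, 256, device="cuda")
y = run_model(x)  # inputs round-trip through numpy; the output returns on x's device and dtype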
models/ISOMER/scripts/mesh_init.py
ADDED
@@ -0,0 +1,142 @@
1 |
+
from PIL import Image
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from pytorch3d.structures import Meshes
|
5 |
+
from pytorch3d.renderer import TexturesVertex
|
6 |
+
from .utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
|
7 |
+
import pymeshlab
|
8 |
+
|
9 |
+
_MAX_THREAD = 8
|
10 |
+
|
11 |
+
# rgb and depth to mesh
|
12 |
+
def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
|
13 |
+
pixel_center = 0.5 if use_pixel_centers else 0
|
14 |
+
i, j = np.meshgrid(
|
15 |
+
np.arange(W, dtype=np.float32) + pixel_center,
|
16 |
+
np.arange(H, dtype=np.float32) + pixel_center,
|
17 |
+
indexing='xy'
|
18 |
+
)
|
19 |
+
i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)
|
20 |
+
|
21 |
+
origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1) # W, H, 3
|
22 |
+
directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3
|
23 |
+
|
24 |
+
return origins, directions
|
25 |
+
|
26 |
+
def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
|
27 |
+
if valid_HWC is None:
|
28 |
+
valid_HWC = torch.ones_like(pred_HWC).bool()
|
29 |
+
H, W = rgb_BCHW.shape[-2:]
|
30 |
+
rgb_BCHW = rgb_BCHW.flip(-2)
|
31 |
+
pred_HWC = pred_HWC.flip(0)
|
32 |
+
valid_HWC = valid_HWC.flip(0)
|
33 |
+
rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
|
34 |
+
verts = rays_o + rays_d * pred_HWC # [H, W, 3]
|
35 |
+
verts = verts.reshape(-1, 3) # [V, 3]
|
36 |
+
indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
|
37 |
+
faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
|
38 |
+
# faces1_valid = valid_HWC[:-1, :-1] | valid_HWC[:-1, 1:] | valid_HWC[1:, :-1]
|
39 |
+
faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
|
40 |
+
faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
|
41 |
+
# faces2_valid = valid_HWC[1:, 1:] | valid_HWC[1:, :-1] | valid_HWC[:-1, 1:]
|
42 |
+
faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
|
43 |
+
faces = torch.cat([faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3),
|
44 |
+
faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3)],
|
45 |
+
dim=0) # (F, 3)
|
46 |
+
colors = (rgb_BCHW[0].permute((1,2,0)) / 2 + 0.5).reshape(-1, 3) # (V, 3)
|
47 |
+
if is_back:
|
48 |
+
verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)
|
49 |
+
|
50 |
+
used_verts = faces.unique()
|
51 |
+
old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
|
52 |
+
old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
|
53 |
+
new_faces = old_to_new_mapping[faces]
|
54 |
+
mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
|
55 |
+
return mesh
|
56 |
+
|
57 |
+
def normalmap_to_depthmap(normal_np):
|
58 |
+
from .normal_to_height_map import estimate_height_map
|
59 |
+
height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
|
60 |
+
return height
|
61 |
+
|
62 |
+
def transform_back_normal_to_front(normal_pil):
|
63 |
+
arr = np.array(normal_pil) # in [0, 255]
|
64 |
+
arr[..., 0] = 255-arr[..., 0]
|
65 |
+
arr[..., 2] = 255-arr[..., 2]
|
66 |
+
return Image.fromarray(arr.astype(np.uint8))
|
67 |
+
|
68 |
+
def calc_w_over_h(normal_pil):
|
69 |
+
if isinstance(normal_pil, Image.Image):
|
70 |
+
arr = np.array(normal_pil)
|
71 |
+
else:
|
72 |
+
assert isinstance(normal_pil, np.ndarray)
|
73 |
+
arr = normal_pil
|
74 |
+
if arr.shape[-1] == 4:
|
75 |
+
alpha = arr[..., -1] / 255.
|
76 |
+
alpha[alpha >= 0.5] = 1
|
77 |
+
alpha[alpha < 0.5] = 0
|
78 |
+
else:
|
79 |
+
alpha = ~(arr.min(axis=-1) >= 250)
|
80 |
+
h_min, w_min = np.min(np.where(alpha), axis=1)
|
81 |
+
h_max, w_max = np.max(np.where(alpha), axis=1)
|
82 |
+
return (w_max - w_min) / (h_max - h_min)
|
83 |
+
|
84 |
+
def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0, return_depth=False):
|
85 |
+
if is_back:
|
86 |
+
normal_pil = transform_back_normal_to_front(normal_pil)
|
87 |
+
normal_img = np.array(normal_pil)
|
88 |
+
rgb_img = np.array(rgb_pil)
|
89 |
+
if normal_img.shape[-1] == 4:
|
90 |
+
valid_HWC = normal_img[..., [3]] / 255
|
91 |
+
elif rgb_img.shape[-1] == 4:
|
92 |
+
valid_HWC = rgb_img[..., [3]] / 255
|
93 |
+
else:
|
94 |
+
raise ValueError("invalid input, either normal or rgb should have alpha channel")
|
95 |
+
|
96 |
+
# object area pixels height
|
97 |
+
real_height_pix = np.max(np.where(valid_HWC>0.5)[0]) - np.min(np.where(valid_HWC>0.5)[0])
|
98 |
+
|
99 |
+
heights = normalmap_to_depthmap(normal_img)
|
100 |
+
rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2,0,1))[None]
|
101 |
+
valid_HWC[valid_HWC < 0.5] = 0
|
102 |
+
valid_HWC[valid_HWC >= 0.5] = 1
|
103 |
+
valid_HWC = torch.from_numpy(valid_HWC).bool()
|
104 |
+
|
105 |
+
if init_type == "std":
|
106 |
+
# accurate but not stable
|
107 |
+
pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
|
108 |
+
elif init_type == "thin":
|
109 |
+
heights = heights - heights.min()
|
110 |
+
heights = (heights / heights.max() * 0.2)
|
111 |
+
pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
|
112 |
+
else:
|
113 |
+
# stable but not accurate
|
114 |
+
heights = heights - heights.min()
|
115 |
+
heights = (heights / heights.max() * (1-offset)) + offset # to [offset, 1]
|
116 |
+
pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
|
117 |
+
|
118 |
+
# set the border pixels to 0 height
|
119 |
+
import cv2
|
120 |
+
# edge filter
|
121 |
+
edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
|
122 |
+
edge = torch.from_numpy(edge).bool()[..., None]
|
123 |
+
pred_HWC[edge] = 0
|
124 |
+
|
125 |
+
valid_HWC[pred_HWC < clamp_min] = False
|
126 |
+
rt_mesh = depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)
|
127 |
+
|
128 |
+
if return_depth:
|
129 |
+
return rt_mesh, pred_HWC
|
130 |
+
return rt_mesh
|
131 |
+
|
132 |
+
# poisson reconstruction which guarantees a smooth connection between meshes
|
133 |
+
# and optionally simplify the result to `simplification` faces (2000 when called from fast_geo)
|
134 |
+
def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
|
135 |
+
ms = pymeshlab.MeshSet()
|
136 |
+
ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
|
137 |
+
if simplification > 0:
|
138 |
+
ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
|
139 |
+
ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads = 6, depth = poissson_depth, preclean = True)
|
140 |
+
if simplification > 0:
|
141 |
+
ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
|
142 |
+
return meshlab_mesh_to_py3dmesh(ms.current_mesh())
|
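A minimal usage sketch for build_mesh and fix_border_with_pymeshlab_fast, assuming a placeholder RGBA normal map and a CUDA device (depth_and_color_to_mesh moves its inputs to CUDA):
from PIL import Image
from models.ISOMER.scripts.mesh_init import build_mesh, fix_border_with_pymeshlab_fast

normal = Image.open("front_normal.png").convert("RGBA")  # placeholder path
front_mesh = build_mesh(normal, normal, clamp_min=0., scale=0.3, init_type="std")
smoothed = fix_border_with_pymeshlab_fast(front_mesh, poissson_depth=6, simplification=2000)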
models/ISOMER/scripts/normal_to_height_map.py
ADDED
@@ -0,0 +1,205 @@
1 |
+
# code modified from https://github.com/YertleTurtleGit/depth-from-normals
|
2 |
+
import numpy as np
|
3 |
+
import cv2 as cv
|
4 |
+
from multiprocessing.pool import ThreadPool as Pool
|
5 |
+
from multiprocessing import cpu_count
|
6 |
+
from typing import Tuple, List, Union
|
7 |
+
import numba
|
8 |
+
|
9 |
+
|
10 |
+
def calculate_gradients(
|
11 |
+
normals: np.ndarray, mask: np.ndarray
|
12 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
13 |
+
horizontal_angle_map = np.arccos(np.clip(normals[:, :, 0], -1, 1))
|
14 |
+
left_gradients = np.zeros(normals.shape[:2])
|
15 |
+
left_gradients[mask != 0] = (1 - np.sin(horizontal_angle_map[mask != 0])) * np.sign(
|
16 |
+
horizontal_angle_map[mask != 0] - np.pi / 2
|
17 |
+
)
|
18 |
+
|
19 |
+
vertical_angle_map = np.arccos(np.clip(normals[:, :, 1], -1, 1))
|
20 |
+
top_gradients = np.zeros(normals.shape[:2])
|
21 |
+
top_gradients[mask != 0] = -(1 - np.sin(vertical_angle_map[mask != 0])) * np.sign(
|
22 |
+
vertical_angle_map[mask != 0] - np.pi / 2
|
23 |
+
)
|
24 |
+
|
25 |
+
return left_gradients, top_gradients
|
26 |
+
|
27 |
+
|
28 |
+
@numba.jit(nopython=True)
|
29 |
+
def integrate_gradient_field(
|
30 |
+
gradient_field: np.ndarray, axis: int, mask: np.ndarray
|
31 |
+
) -> np.ndarray:
|
32 |
+
heights = np.zeros(gradient_field.shape)
|
33 |
+
|
34 |
+
for d1 in numba.prange(heights.shape[1 - axis]): # numba.prange: executes the loop in parallel
|
35 |
+
sum_value = 0
|
36 |
+
for d2 in range(heights.shape[axis]):
|
37 |
+
coordinates = (d1, d2) if axis == 1 else (d2, d1)
|
38 |
+
|
39 |
+
if mask[coordinates] != 0:
|
40 |
+
sum_value = sum_value + gradient_field[coordinates] # equation 1 in paper along `axis` axis
|
41 |
+
heights[coordinates] = sum_value
|
42 |
+
else:
|
43 |
+
sum_value = 0
|
44 |
+
|
45 |
+
return heights
|
46 |
+
|
47 |
+
# equation 1 in paper wrt these directions
|
48 |
+
def calculate_heights(
|
49 |
+
left_gradients: np.ndarray, top_gradients, mask: np.ndarray
|
50 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
51 |
+
left_heights = integrate_gradient_field(left_gradients, 1, mask)
|
52 |
+
right_heights = np.fliplr(
|
53 |
+
integrate_gradient_field(np.fliplr(-left_gradients), 1, np.fliplr(mask))
|
54 |
+
)
|
55 |
+
top_heights = integrate_gradient_field(top_gradients, 0, mask)
|
56 |
+
bottom_heights = np.flipud(
|
57 |
+
integrate_gradient_field(np.flipud(-top_gradients), 0, np.flipud(mask))
|
58 |
+
)
|
59 |
+
return left_heights, right_heights, top_heights, bottom_heights
|
60 |
+
|
61 |
+
|
62 |
+
def combine_heights(*heights: np.ndarray) -> np.ndarray:
|
63 |
+
return np.mean(np.stack(heights, axis=0), axis=0)
|
64 |
+
|
65 |
+
|
66 |
+
def rotate(matrix: np.ndarray, angle: float) -> np.ndarray:
|
67 |
+
h, w = matrix.shape[:2]
|
68 |
+
center = (w / 2, h / 2)
|
69 |
+
|
70 |
+
rotation_matrix = cv.getRotationMatrix2D(center, angle, 1.0)
|
71 |
+
corners = cv.transform(
|
72 |
+
np.array([[[0, 0], [w, 0], [w, h], [0, h]]]), rotation_matrix
|
73 |
+
)[0]
|
74 |
+
|
75 |
+
_, _, w, h = cv.boundingRect(corners)
|
76 |
+
|
77 |
+
rotation_matrix[0, 2] += w / 2 - center[0]
|
78 |
+
rotation_matrix[1, 2] += h / 2 - center[1]
|
79 |
+
result = cv.warpAffine(matrix, rotation_matrix, (w, h), flags=cv.INTER_LINEAR)
|
80 |
+
|
81 |
+
return result
|
82 |
+
|
83 |
+
|
84 |
+
def rotate_vector_field_normals(normals: np.ndarray, angle: float) -> np.ndarray:
|
85 |
+
angle = np.radians(angle)
|
86 |
+
cos_angle = np.cos(angle)
|
87 |
+
sin_angle = np.sin(angle)
|
88 |
+
|
89 |
+
rotated_normals = np.empty_like(normals)
|
90 |
+
rotated_normals[:, :, 0] = (
|
91 |
+
normals[:, :, 0] * cos_angle - normals[:, :, 1] * sin_angle
|
92 |
+
)
|
93 |
+
rotated_normals[:, :, 1] = (
|
94 |
+
normals[:, :, 0] * sin_angle + normals[:, :, 1] * cos_angle
|
95 |
+
)
|
96 |
+
|
97 |
+
return rotated_normals
|
98 |
+
|
99 |
+
|
100 |
+
def centered_crop(image: np.ndarray, target_resolution: Tuple[int, int]) -> np.ndarray:
|
101 |
+
return image[
|
102 |
+
(image.shape[0] - target_resolution[0])
|
103 |
+
// 2 : (image.shape[0] - target_resolution[0])
|
104 |
+
// 2
|
105 |
+
+ target_resolution[0],
|
106 |
+
(image.shape[1] - target_resolution[1])
|
107 |
+
// 2 : (image.shape[1] - target_resolution[1])
|
108 |
+
// 2
|
109 |
+
+ target_resolution[1],
|
110 |
+
]
|
111 |
+
|
112 |
+
|
113 |
+
def integrate_vector_field(
|
114 |
+
vector_field: np.ndarray,
|
115 |
+
mask: np.ndarray,
|
116 |
+
target_iteration_count: int,
|
117 |
+
thread_count: int,
|
118 |
+
) -> np.ndarray:
|
119 |
+
shape = vector_field.shape[:2]
|
120 |
+
angles = np.linspace(0, 90, target_iteration_count, endpoint=False)
|
121 |
+
|
122 |
+
def integrate_vector_field_angles(angles: List[float]) -> np.ndarray:
|
123 |
+
all_combined_heights = np.zeros(shape)
|
124 |
+
|
125 |
+
for angle in angles:
|
126 |
+
rotated_vector_field = rotate_vector_field_normals(
|
127 |
+
rotate(vector_field, angle), angle
|
128 |
+
) # rotate twice: first rotate the whole in image level, then rotate the individual normal vectors
|
129 |
+
|
130 |
+
rotated_mask = rotate(mask, angle)
|
131 |
+
|
132 |
+
left_gradients, top_gradients = calculate_gradients(
|
133 |
+
rotated_vector_field, rotated_mask
|
134 |
+
)
|
135 |
+
(
|
136 |
+
left_heights,
|
137 |
+
right_heights,
|
138 |
+
top_heights,
|
139 |
+
bottom_heights,
|
140 |
+
) = calculate_heights(left_gradients, top_gradients, rotated_mask)
|
141 |
+
|
142 |
+
combined_heights = combine_heights(
|
143 |
+
left_heights, right_heights, top_heights, bottom_heights
|
144 |
+
) # = mean of these heights
|
145 |
+
combined_heights = centered_crop(rotate(combined_heights, -angle), shape)
|
146 |
+
all_combined_heights += combined_heights / len(angles)
|
147 |
+
|
148 |
+
return all_combined_heights
|
149 |
+
|
150 |
+
with Pool(processes=thread_count) as pool:
|
151 |
+
heights = pool.map(
|
152 |
+
integrate_vector_field_angles,
|
153 |
+
np.array(
|
154 |
+
np.array_split(angles, thread_count),
|
155 |
+
dtype=object,
|
156 |
+
),
|
157 |
+
)
|
158 |
+
pool.close()
|
159 |
+
pool.join()
|
160 |
+
|
161 |
+
isotropic_height = np.zeros(shape)
|
162 |
+
for height in heights:
|
163 |
+
isotropic_height += height / thread_count
|
164 |
+
|
165 |
+
return isotropic_height
|
166 |
+
|
167 |
+
|
168 |
+
def estimate_height_map(
|
169 |
+
normal_map: np.ndarray,
|
170 |
+
mask: Union[np.ndarray, None] = None,
|
171 |
+
height_divisor: float = 1,
|
172 |
+
target_iteration_count: int = 250,
|
173 |
+
thread_count: int = cpu_count(),
|
174 |
+
raw_values: bool = False,
|
175 |
+
) -> np.ndarray:
|
176 |
+
if mask is None:
|
177 |
+
if normal_map.shape[-1] == 4:
|
178 |
+
mask = normal_map[:, :, 3] / 255
|
179 |
+
mask[mask < 0.5] = 0
|
180 |
+
mask[mask >= 0.5] = 1
|
181 |
+
else:
|
182 |
+
mask = np.ones(normal_map.shape[:2], dtype=np.uint8)
|
183 |
+
|
184 |
+
normals = ((normal_map[:, :, :3].astype(np.float64) / 255) - 0.5) * 2
|
185 |
+
heights = integrate_vector_field(
|
186 |
+
normals, mask, target_iteration_count, thread_count
|
187 |
+
) # equation 1 in the paper: integrate at target_iteration_count rotations sampled from np.linspace(0, 90, target_iteration_count) degrees, then average
|
188 |
+
# mesh_init.py calls this with target_iteration_count=96 and thread_count=_MAX_THREAD (8)
|
189 |
+
|
190 |
+
if raw_values:
|
191 |
+
return heights
|
192 |
+
|
193 |
+
heights /= height_divisor
|
194 |
+
heights[mask > 0] += 1 / 2
|
195 |
+
heights[mask == 0] = 1 / 2
|
196 |
+
|
197 |
+
heights *= 2**16 - 1
|
198 |
+
|
199 |
+
if np.min(heights) < 0 or np.max(heights) > 2**16 - 1:
|
200 |
+
raise OverflowError("Height values are clipping.")
|
201 |
+
|
202 |
+
heights = np.clip(heights, 0, 2**16 - 1)
|
203 |
+
heights = heights.astype(np.uint16)
|
204 |
+
|
205 |
+
return heights
|
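A minimal usage sketch for estimate_height_map, with a placeholder input path; raw_values=True returns the float height field before the uint16 conversion, matching how mesh_init.py calls it:
import numpy as np
from PIL import Image
from models.ISOMER.scripts.normal_to_height_map import estimate_height_map

normal_np = np.array(Image.open("front_normal.png").convert("RGBA"))
heights = estimate_height_map(normal_np, raw_values=True, target_iteration_count=96, thread_count=8)
print(heights.shape)  # same H x W as the input normal map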
models/ISOMER/scripts/proj_commands.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from pytorch3d.renderer import (
|
5 |
+
TexturesVertex,
|
6 |
+
)
|
7 |
+
from .project_mesh import (
|
8 |
+
get_cameras_list_azi_ele,
|
9 |
+
multiview_color_projection
|
10 |
+
|
11 |
+
)
|
12 |
+
from .utils import save_py3dmesh_with_trimesh_fast
|
13 |
+
|
14 |
+
def projection(meshes,
|
15 |
+
img_list,
|
16 |
+
weights,
|
17 |
+
azimuths,
|
18 |
+
elevations,
|
19 |
+
projection_type='orthographic',
|
20 |
+
auto_center=True,
|
21 |
+
resolution=1024,
|
22 |
+
fovy=None,
|
23 |
+
radius=None,
|
24 |
+
ortho_dist=1.1,
|
25 |
+
scale_factor=1.0,
|
26 |
+
save_glb_addr=None,
|
27 |
+
scale_verts=True,
|
28 |
+
complete_unseen=True,
|
29 |
+
below_confidence_strategy="smooth"
|
30 |
+
):
|
31 |
+
|
32 |
+
assert len(img_list) == len(azimuths) == len(elevations) == len(weights), f"len(img_list) ({len(img_list)}) != len(azimuths) ({len(azimuths)}) != len(elevations) ({len(elevations)}) != len(weights) ({len(weights)})"
|
33 |
+
|
34 |
+
projection_types = ['perspective', 'orthographic']
|
35 |
+
assert projection_type in projection_types, f"projection_type ({projection_type}) should be one of {projection_types}"
|
36 |
+
|
37 |
+
if auto_center:
|
38 |
+
verts = meshes.verts_packed()
|
39 |
+
max_bb = (verts - 0).max(0)[0]
|
40 |
+
min_bb = (verts - 0).min(0)[0]
|
41 |
+
scale = (max_bb - min_bb).max() / 2
|
42 |
+
center = (max_bb + min_bb) / 2
|
43 |
+
meshes.offset_verts_(-center)
|
44 |
+
if scale_verts:
|
45 |
+
meshes.scale_verts_((scale_factor / float(scale)))
|
46 |
+
elif scale_verts:
|
47 |
+
meshes.scale_verts_((scale_factor))
|
48 |
+
|
49 |
+
if projection_type == 'perspective':
|
50 |
+
assert fovy is not None and radius is not None, f"fovy ({fovy}) and radius ({radius}) should not be None when projection_type is 'perspective'"
|
51 |
+
cameras = get_cameras_list_azi_ele(azimuths, elevations, fov_in_degrees=fovy,device="cuda", dist=radius, cam_type='fov')
|
52 |
+
elif projection_type == 'orthographic':
|
53 |
+
cameras = get_cameras_list_azi_ele(azimuths, elevations, fov_in_degrees=fovy, device="cuda", focal=2/1.35, dist=ortho_dist, cam_type='orthographic')
|
54 |
+
|
55 |
+
|
56 |
+
num_meshes = len(meshes)
|
57 |
+
num_verts_per_mesh = meshes.verts_packed().shape[0] // num_meshes
|
58 |
+
black_texture = torch.zeros((num_meshes, num_verts_per_mesh, 3), device="cuda")
|
59 |
+
textures = TexturesVertex(verts_features=black_texture)
|
60 |
+
meshes.textures = textures
|
61 |
+
|
62 |
+
|
63 |
+
proj_mesh = multiview_color_projection(meshes, img_list, cameras, weights=weights, eps=0.05, resolution=resolution, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=complete_unseen, below_confidence_strategy=below_confidence_strategy)
|
64 |
+
|
65 |
+
|
66 |
+
if save_glb_addr is not None:
|
67 |
+
save_py3dmesh_with_trimesh_fast(proj_mesh, save_glb_addr)
|
68 |
+
|
69 |
+
|
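A minimal usage sketch for projection; the sphere, image paths, weights, and camera angles are placeholder choices for illustration, and a CUDA device is required:
from PIL import Image
from models.ISOMER.scripts.fast_geo import create_sphere
from models.ISOMER.scripts.proj_commands import projection

meshes = create_sphere(0.5)
imgs = [Image.open(f"view_{i}.png") for i in range(4)]  # placeholder renders
projection(
    meshes,
    img_list=imgs,
    weights=[2.0, 0.5, 1.0, 0.5],
    azimuths=[0, 90, 180, 270],
    elevations=[0, 0, 0, 0],
    projection_type='orthographic',
    resolution=1024,
    save_glb_addr="textured_mesh.glb",  # the colored mesh is written here
)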
models/ISOMER/scripts/project_mesh.py
ADDED
@@ -0,0 +1,401 @@
1 |
+
from typing import List
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
from pytorch3d.renderer.cameras import look_at_view_transform, OrthographicCameras, CamerasBase
|
6 |
+
from pytorch3d.io import load_objs_as_meshes
|
7 |
+
from pytorch3d.renderer.mesh.rasterizer import Fragments
|
8 |
+
from pytorch3d.structures import Meshes
|
9 |
+
from pytorch3d.renderer import (
|
10 |
+
RasterizationSettings,
|
11 |
+
TexturesVertex,
|
12 |
+
FoVPerspectiveCameras,
|
13 |
+
FoVOrthographicCameras,
|
14 |
+
)
|
15 |
+
from pytorch3d.renderer import MeshRasterizer
|
16 |
+
|
17 |
+
def get_camera(world_to_cam, fov_in_degrees=60, focal_length=1 / (2**0.5), cam_type='fov'):
|
18 |
+
# pytorch3d expects transforms as row-vectors, so flip rotation: https://github.com/facebookresearch/pytorch3d/issues/1183
|
19 |
+
R = world_to_cam[:3, :3].t()[None, ...]
|
20 |
+
T = world_to_cam[:3, 3][None, ...]
|
21 |
+
if cam_type == 'fov':
|
22 |
+
assert fov_in_degrees is not None, "fov_in_degrees should not be None when cam_type is fov"
|
23 |
+
camera = FoVPerspectiveCameras(device=world_to_cam.device, R=R, T=T, fov=fov_in_degrees, degrees=True)
|
24 |
+
else:
|
25 |
+
focal_length = 1 / focal_length
|
26 |
+
camera = FoVOrthographicCameras(device=world_to_cam.device, R=R, T=T, min_x=-focal_length, max_x=focal_length, min_y=-focal_length, max_y=focal_length)
|
27 |
+
return camera
|
28 |
+
|
29 |
+
def render_pix2faces_py3d(meshes, cameras, H=512, W=512, blur_radius=0.0, faces_per_pixel=1):
|
30 |
+
"""
|
31 |
+
Renders pix2face of visible faces.
|
32 |
+
|
33 |
+
:param mesh: Pytorch3d.structures.Meshes
|
34 |
+
:param cameras: pytorch3d.renderer.Cameras
|
35 |
+
:param H: target image height
|
36 |
+
:param W: target image width
|
37 |
+
:param blur_radius: Float distance in the range [0, 2] used to expand the face
|
38 |
+
bounding boxes for rasterization. Setting blur radius
|
39 |
+
results in blurred edges around the shape instead of a
|
40 |
+
hard boundary. Set to 0 for no blur.
|
41 |
+
:param faces_per_pixel: (int) Number of faces to keep track of per pixel.
|
42 |
+
We return the nearest faces_per_pixel faces along the z-axis.
|
43 |
+
"""
|
44 |
+
# Define the settings for rasterization and shading
|
45 |
+
raster_settings = RasterizationSettings(
|
46 |
+
image_size=(H, W),
|
47 |
+
blur_radius=blur_radius,
|
48 |
+
faces_per_pixel=faces_per_pixel
|
49 |
+
)
|
50 |
+
rasterizer=MeshRasterizer(
|
51 |
+
cameras=cameras,
|
52 |
+
raster_settings=raster_settings
|
53 |
+
)
|
54 |
+
fragments: Fragments = rasterizer(meshes, cameras=cameras)
|
55 |
+
return {
|
56 |
+
"pix_to_face": fragments.pix_to_face[..., 0],
|
57 |
+
}
|
58 |
+
|
59 |
+
import nvdiffrast.torch as dr
|
60 |
+
|
61 |
+
def _warmup(glctx, device=None):
|
62 |
+
device = 'cuda' if device is None else device
|
63 |
+
#windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
|
64 |
+
def tensor(*args, **kwargs):
|
65 |
+
return torch.tensor(*args, device=device, **kwargs)
|
66 |
+
pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1], [-0.8, 0.8, 0, 1]]], dtype=torch.float32)
|
67 |
+
tri = tensor([[0, 1, 2]], dtype=torch.int32)
|
68 |
+
dr.rasterize(glctx, pos, tri, resolution=[256, 256])
|
69 |
+
|
70 |
+
class Pix2FacesRenderer:
|
71 |
+
def __init__(self, device="cuda"):
|
72 |
+
# self._glctx = dr.RasterizeGLContext(output_db=False, device=device)
|
73 |
+
self._glctx = dr.RasterizeCudaContext(device=device)
|
74 |
+
self.device = device
|
75 |
+
_warmup(self._glctx, device)
|
76 |
+
|
77 |
+
def transform_vertices(self, meshes: Meshes, cameras: CamerasBase):
|
78 |
+
vertices = cameras.transform_points_ndc(meshes.verts_padded())
|
79 |
+
|
80 |
+
perspective_correct = cameras.is_perspective()
|
81 |
+
znear = cameras.get_znear()
|
82 |
+
if isinstance(znear, torch.Tensor):
|
83 |
+
znear = znear.min().item()
|
84 |
+
z_clip = None #if not perspective_correct or znear is None else znear / 2
|
85 |
+
|
86 |
+
if z_clip:
|
87 |
+
vertices = vertices[vertices[..., 2] >= cameras.get_znear()][None] # clip
|
88 |
+
vertices = vertices * torch.tensor([-1, -1, 1]).to(vertices)
|
89 |
+
vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1).to(torch.float32)
|
90 |
+
return vertices
|
91 |
+
|
92 |
+
def render_pix2faces_nvdiff(self, meshes: Meshes, cameras: CamerasBase, H=512, W=512):
|
93 |
+
meshes = meshes.to(self.device)
|
94 |
+
cameras = cameras.to(self.device)
|
95 |
+
vertices = self.transform_vertices(meshes, cameras)
|
96 |
+
faces = meshes.faces_packed().to(torch.int32)
|
97 |
+
rast_out,_ = dr.rasterize(self._glctx, vertices, faces, resolution=(H, W), grad_db=False) #C,H,W,4
|
98 |
+
pix_to_face = rast_out[..., -1].to(torch.int32) - 1
|
99 |
+
return pix_to_face
|
100 |
+
|
101 |
+
pix2faces_renderer = Pix2FacesRenderer()
|
102 |
+
|
103 |
+
def get_visible_faces(meshes: Meshes, cameras: CamerasBase, resolution=1024):
|
104 |
+
# pix_to_face = render_pix2faces_py3d(meshes, cameras, H=resolution, W=resolution)['pix_to_face']
|
105 |
+
pix_to_face = pix2faces_renderer.render_pix2faces_nvdiff(meshes, cameras, H=resolution, W=resolution)
|
106 |
+
|
107 |
+
unique_faces = torch.unique(pix_to_face.flatten())
|
108 |
+
unique_faces = unique_faces[unique_faces != -1]
|
109 |
+
return unique_faces
|
110 |
+
|
111 |
+
def project_color(meshes: Meshes, cameras: CamerasBase, pil_image: Image.Image, use_alpha=True, eps=0.05, resolution=1024, device="cuda") -> dict:
|
112 |
+
"""
|
113 |
+
Projects color from a given image onto a 3D mesh.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
meshes (pytorch3d.structures.Meshes): The 3D mesh object.
|
117 |
+
cameras (pytorch3d.renderer.cameras.CamerasBase): The camera object.
|
118 |
+
pil_image (PIL.Image.Image): The input image.
|
119 |
+
use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
|
120 |
+
eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
|
121 |
+
resolution (int, optional): The resolution of the projection. Defaults to 1024.
|
122 |
+
device (str, optional): The device to use for computation. Defaults to "cuda".
|
123 |
+
debug (bool, optional): Whether to save debug images. Defaults to False.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
dict: A dictionary containing the following keys:
|
127 |
+
- "new_texture" (TexturesVertex): The updated texture with interpolated colors.
|
128 |
+
- "valid_verts" (Tensor of [M,3]): The indices of the vertices being projected.
|
129 |
+
- "valid_colors" (Tensor of [M,3]): The interpolated colors for the valid vertices.
|
130 |
+
"""
|
131 |
+
meshes = meshes.to(device)
|
132 |
+
cameras = cameras.to(device)
|
133 |
+
image = torch.from_numpy(np.array(pil_image.convert("RGBA")) / 255.).permute((2, 0, 1)).float().to(device) # in CHW format of [0, 1.]
|
134 |
+
unique_faces = get_visible_faces(meshes, cameras, resolution=resolution)
|
135 |
+
|
136 |
+
# visible faces
|
137 |
+
faces_normals = meshes.faces_normals_packed()[unique_faces]
|
138 |
+
faces_normals = faces_normals / faces_normals.norm(dim=1, keepdim=True)
|
139 |
+
world_points = cameras.unproject_points(torch.tensor([[[0., 0., 0.1], [0., 0., 0.2]]]).to(device))[0]
|
140 |
+
view_direction = world_points[1] - world_points[0]
|
141 |
+
view_direction = view_direction / view_direction.norm(dim=0, keepdim=True)
|
142 |
+
|
143 |
+
|
144 |
+
# find invalid faces
|
145 |
+
cos_angles = (faces_normals * view_direction).sum(dim=1)
|
146 |
+
# assert cos_angles.mean() < 0, f"The view direction is not correct. cos_angles.mean()={cos_angles.mean()}"
|
147 |
+
selected_faces = unique_faces[cos_angles < -eps]
|
148 |
+
|
149 |
+
# find verts
|
150 |
+
faces = meshes.faces_packed()[selected_faces] # [N, 3]
|
151 |
+
verts = torch.unique(faces.flatten()) # [N, 1]
|
152 |
+
verts_coordinates = meshes.verts_packed()[verts] # [N, 3]
|
153 |
+
|
154 |
+
# compute color
|
155 |
+
pt_tensor = cameras.transform_points(verts_coordinates)[..., :2] # NDC space points
|
156 |
+
valid = ~((pt_tensor.isnan()|(pt_tensor<-1)|(1<pt_tensor)).any(dim=1)) # checked, correct
|
157 |
+
valid_pt = pt_tensor[valid, :]
|
158 |
+
valid_idx = verts[valid]
|
159 |
+
valid_color = torch.nn.functional.grid_sample(image[None].flip((-1, -2)), valid_pt[None, :, None, :], align_corners=False, padding_mode="reflection", mode="bilinear")[0, :, :, 0].T.clamp(0, 1) # [N, 4], note that bicubic may give invalid value
|
160 |
+
alpha, valid_color = valid_color[:, 3:], valid_color[:, :3]
|
161 |
+
if not use_alpha:
|
162 |
+
alpha = torch.ones_like(alpha)
|
163 |
+
|
164 |
+
# modify color
|
165 |
+
old_colors = meshes.textures.verts_features_packed()
|
166 |
+
old_colors[valid_idx] = valid_color * alpha + old_colors[valid_idx] * (1 - alpha)
|
167 |
+
new_texture = TexturesVertex(verts_features=[old_colors])
|
168 |
+
|
169 |
+
valid_verts_normals = meshes.verts_normals_packed()[valid_idx]
|
170 |
+
valid_verts_normals = valid_verts_normals / valid_verts_normals.norm(dim=1, keepdim=True).clamp_min(0.001)
|
171 |
+
cos_angles = (valid_verts_normals * view_direction).sum(dim=1)
|
172 |
+
return {
|
173 |
+
"new_texture": new_texture,
|
174 |
+
"valid_verts": valid_idx,
|
175 |
+
"valid_colors": valid_color,
|
176 |
+
"valid_alpha": alpha,
|
177 |
+
"cos_angles": cos_angles,
|
178 |
+
}
|
179 |
+
|
180 |
+
def complete_unseen_vertex_color(meshes: Meshes, valid_index: torch.Tensor) -> dict:
|
181 |
+
"""
|
182 |
+
meshes: the mesh with vertex color to be completed.
|
183 |
+
valid_index: the index of the valid vertices, where valid means colors are fixed. [V, 1]
|
184 |
+
"""
|
185 |
+
valid_index = valid_index.to(meshes.device)
|
186 |
+
colors = meshes.textures.verts_features_packed() # [V, 3]
|
187 |
+
V = colors.shape[0]
|
188 |
+
|
189 |
+
invalid_index = torch.ones_like(colors[:, 0]).bool() # [V]
|
190 |
+
invalid_index[valid_index] = False
|
191 |
+
invalid_index = torch.arange(V).to(meshes.device)[invalid_index]
|
192 |
+
|
193 |
+
L = meshes.laplacian_packed() # connectivity
|
194 |
+
E = torch.sparse_coo_tensor(torch.tensor([list(range(V))] * 2), torch.ones((V,)), size=(V, V)).to(meshes.device)
|
195 |
+
L = L + E
|
196 |
+
# E = torch.eye(V, layout=torch.sparse_coo, device=meshes.device)
|
197 |
+
# L = L + E
|
198 |
+
colored_count = torch.ones_like(colors[:, 0]) # [V]
|
199 |
+
colored_count[invalid_index] = 0
|
200 |
+
L_invalid = torch.index_select(L, 0, invalid_index) # sparse [IV, V]
|
201 |
+
|
202 |
+
total_colored = colored_count.sum()
|
203 |
+
coloring_round = 0
|
204 |
+
stage = "uncolored"
|
205 |
+
from tqdm import tqdm
|
206 |
+
pbar = tqdm(miniters=100)
|
207 |
+
while stage == "uncolored" or coloring_round > 0:
|
208 |
+
new_color = torch.matmul(L_invalid, colors * colored_count[:, None]) # [IV, 3]
|
209 |
+
new_count = torch.matmul(L_invalid, colored_count)[:, None] # [IV, 1]
|
210 |
+
colors[invalid_index] = torch.where(new_count > 0, new_color / new_count, colors[invalid_index])
|
211 |
+
colored_count[invalid_index] = (new_count[:, 0] > 0).float()
|
212 |
+
|
213 |
+
new_total_colored = colored_count.sum()
|
214 |
+
if new_total_colored > total_colored:
|
215 |
+
total_colored = new_total_colored
|
216 |
+
coloring_round += 1
|
217 |
+
else:
|
218 |
+
stage = "colored"
|
219 |
+
coloring_round -= 1
|
220 |
+
pbar.update(1)
|
221 |
+
if coloring_round > 10000:
|
222 |
+
print("coloring_round > 10000, break")
|
223 |
+
break
|
224 |
+
assert not torch.isnan(colors).any()
|
225 |
+
meshes.textures = TexturesVertex(verts_features=[colors])
|
226 |
+
return meshes
|
227 |
+
|
228 |
+
def load_glb_mesh(glb_path, device="cuda"):
|
229 |
+
meshes = load_objs_as_meshes([glb_path], device=device)
|
230 |
+
return meshes
|
231 |
+
|
232 |
+
def get_separated_images_from_img_grid(img_grid_path, image_num):
|
233 |
+
img_list = []
|
234 |
+
grid = Image.open(img_grid_path)
|
235 |
+
w, h = grid.size
|
236 |
+
for i in range(0, image_num):
|
237 |
+
img_list.append(grid.crop((i*h, 0, i*h + h, h)))
|
238 |
+
return img_list
|
239 |
+
|
240 |
+
def get_fov_camera_(azimuth, elevation, fovy, radius, mesh, auto_center, scale_factor, device='cuda'):
|
241 |
+
if auto_center:
|
242 |
+
verts = mesh.verts_packed()
|
243 |
+
max_bb = (verts - 0).max(0)[0]
|
244 |
+
min_bb = (verts - 0).min(0)[0]
|
245 |
+
scale = (max_bb - min_bb).max() / 2
|
246 |
+
center = (max_bb + min_bb) / 2
|
247 |
+
mesh.offset_verts_(-center)
|
248 |
+
mesh.scale_verts_((scale_factor / float(scale)))
|
249 |
+
else:
|
250 |
+
mesh.scale_verts_((scale_factor))
|
251 |
+
R, T = look_at_view_transform(radius, azimuth, elevation, device=device)
|
252 |
+
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=fovy)
|
253 |
+
return cameras
|
254 |
+
|
255 |
+
def multiview_color_projection(meshes: Meshes, image_list: List[Image.Image], cameras_list: List[CamerasBase], weights=None, eps=0.05, resolution=1024, device="cuda", reweight_with_cosangle="square", use_alpha=True, confidence_threshold=0.1, complete_unseen=False, below_confidence_strategy="smooth") -> Meshes:
|
256 |
+
"""
|
257 |
+
Projects color from a given image onto a 3D mesh.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
meshes (pytorch3d.structures.Meshes): The 3D mesh object, only one mesh.
|
261 |
+
image_list (PIL.Image.Image): List of images.
|
262 |
+
cameras_list (list): List of cameras.
|
263 |
+
weights (list, optional): List of weights for each image, for ['front', 'front_right', 'right', 'back', 'left', 'front_left']. Defaults to None.
|
264 |
+
eps (float, optional): The threshold for selecting visible faces. Defaults to 0.05.
|
265 |
+
resolution (int, optional): The resolution of the projection. Defaults to 1024.
|
266 |
+
device (str, optional): The device to use for computation. Defaults to "cuda".
|
267 |
+
reweight_with_cosangle (str, optional): Whether to reweight the color with the angle between the view direction and the vertex normal. Defaults to None.
|
268 |
+
use_alpha (bool, optional): Whether to use the alpha channel of the image. Defaults to True.
|
269 |
+
confidence_threshold (float, optional): The threshold for the confidence of the projected color, if final projection weight is less than this, we will use the original color. Defaults to 0.1.
|
270 |
+
complete_unseen (bool, optional): Whether to complete the unseen vertex color using laplacian. Defaults to False.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
Meshes: the colored mesh
|
274 |
+
"""
|
275 |
+
|
276 |
+
if image_list is None:
|
277 |
+
raise ValueError("image_list is None")
|
278 |
+
|
279 |
+
|
280 |
+
meshes = meshes.clone().to(device)
|
281 |
+
if weights is None:
|
282 |
+
weights = [1. for _ in range(len(cameras_list))]
|
283 |
+
|
284 |
+
assert len(cameras_list) == len(image_list) == len(weights), f'the following three lengths should be equal: len(cameras_list)({len(cameras_list)}), len(image_list)({len(image_list)}), len(weights)({len(weights)})'
|
285 |
+
|
286 |
+
original_color = meshes.textures.verts_features_packed()
|
287 |
+
assert not torch.isnan(original_color).any()
|
288 |
+
texture_counts = torch.zeros_like(original_color[..., :1])
|
289 |
+
texture_values = torch.zeros_like(original_color)
|
290 |
+
max_texture_counts = torch.zeros_like(original_color[..., :1])
|
291 |
+
max_texture_values = torch.zeros_like(original_color)
|
292 |
+
for camera, image, weight in zip(cameras_list, image_list, weights):
|
293 |
+
ret = project_color(meshes, camera, image, eps=eps, resolution=resolution, device=device, use_alpha=use_alpha)
|
294 |
+
if reweight_with_cosangle == "linear":
|
295 |
+
weight = (ret['cos_angles'].abs() * weight)[:, None]
|
296 |
+
elif reweight_with_cosangle == "square":
|
297 |
+
weight = (ret['cos_angles'].abs() ** 2 * weight)[:, None]
|
298 |
+
if use_alpha:
|
299 |
+
weight = weight * ret['valid_alpha']
|
300 |
+
|
301 |
+
try:
|
302 |
+
assert weight.min() > -0.0001, f'weight.min() is {weight.min()}, but should be > -0.0001'
|
303 |
+
except Exception as e:
|
304 |
+
raise e
|
305 |
+
|
306 |
+
texture_counts[ret['valid_verts']] += weight
|
307 |
+
texture_values[ret['valid_verts']] += ret['valid_colors'] * weight
|
308 |
+
max_texture_values[ret['valid_verts']] = torch.where(weight > max_texture_counts[ret['valid_verts']], ret['valid_colors'], max_texture_values[ret['valid_verts']])
|
309 |
+
max_texture_counts[ret['valid_verts']] = torch.max(max_texture_counts[ret['valid_verts']], weight)
|
310 |
+
|
311 |
+
texture_values = torch.where(texture_counts > confidence_threshold, texture_values / texture_counts, texture_values)
|
312 |
+
if below_confidence_strategy == "smooth":
|
313 |
+
texture_values = torch.where(texture_counts <= confidence_threshold, (original_color * (confidence_threshold - texture_counts) + texture_values) / confidence_threshold, texture_values)
|
314 |
+
elif below_confidence_strategy == "original":
|
315 |
+
texture_values = torch.where(texture_counts <= confidence_threshold, original_color, texture_values)
|
316 |
+
else:
|
317 |
+
raise ValueError(f"below_confidence_strategy={below_confidence_strategy} is not supported")
|
318 |
+
assert not torch.isnan(texture_values).any()
|
319 |
+
meshes.textures = TexturesVertex(verts_features=[texture_values])
|
320 |
+
|
321 |
+
if complete_unseen:
|
322 |
+
meshes = complete_unseen_vertex_color(meshes, torch.arange(texture_values.shape[0]).to(device)[texture_counts[:, 0] >= confidence_threshold])
|
323 |
+
ret_mesh = meshes.detach()
|
324 |
+
del meshes
|
325 |
+
return ret_mesh
|
326 |
+
|
327 |
+
def get_cameras_list(azim_list, device, elevation, fov_in_degrees=None, focal=2/1.35, dist=1.1, cam_type='orthographic'):
|
328 |
+
ret = []
|
329 |
+
for azim in azim_list:
|
330 |
+
R, T = look_at_view_transform(dist, elevation, azim)
|
331 |
+
w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
|
332 |
+
cameras = get_camera(w2c, fov_in_degrees=fov_in_degrees, focal_length=focal, cam_type=cam_type).to(device)
|
333 |
+
ret.append(cameras)
|
334 |
+
return ret
|
335 |
+
|
336 |
+
def get_cameras_list_azi_ele(azim_list, elev_list, device, fov_in_degrees=None, focal=2/1.35, dist=1.1, cam_type='orthographic'):
|
337 |
+
ret = []
|
338 |
+
for i in range(len(azim_list)):
|
339 |
+
R, T = look_at_view_transform(dist, elev_list[i], azim_list[i])
|
340 |
+
w2c = torch.cat([R[0].T, T[0, :, None]], dim=1)
|
341 |
+
cameras = get_camera(w2c, fov_in_degrees=fov_in_degrees, focal_length=focal, cam_type=cam_type).to(device)
|
342 |
+
ret.append(cameras)
|
343 |
+
return ret
|
344 |
+
|
345 |
+
def get_8view_cameras(device, focal=2/1.35):
|
346 |
+
return get_cameras_list(azim_list = [180, 225, 270, 315, 0, 45, 90, 135], elevation=0, device=device, focal=focal)
|
347 |
+
|
348 |
+
def get_6view_cameras(device, focal=2/1.35):
|
349 |
+
return get_cameras_list(azim_list = [180, 225, 270, 0, 90, 135], elevation=0, device=device, focal=focal)
|
350 |
+
|
351 |
+
def get_4view_cameras(device, focal=2/1.35):
|
352 |
+
return get_cameras_list(azim_list = [180, 270, 0, 90], elevation=0, device=device, focal=focal)
|
353 |
+
|
354 |
+
def get_2view_cameras(device, focal=2/1.35):
|
355 |
+
return get_cameras_list(azim_list = [180, 0], elevation=0, device=device, focal=focal)
|
356 |
+
|
357 |
+
def get_multiple_view_cameras(device, focal=2/1.35, offset=180, num_views=8, dist=1.1):
|
358 |
+
return get_cameras_list(azim_list = (np.linspace(0, 360, num_views+1)[:-1] + offset) % 360, elevation=0, device=device, focal=focal, dist=dist)
|
359 |
+
|
360 |
+
def align_with_alpha_bbox(source_img, target_img, final_size=1024):
|
361 |
+
# align source_img with target_img using alpha channel
|
362 |
+
# source_img and target_img are PIL.Image.Image
|
363 |
+
source_img = source_img.convert("RGBA")
|
364 |
+
target_img = target_img.convert("RGBA").resize((final_size, final_size))
|
365 |
+
source_np = np.array(source_img)
|
366 |
+
target_np = np.array(target_img)
|
367 |
+
source_alpha = source_np[:, :, 3]
|
368 |
+
target_alpha = target_np[:, :, 3]
|
369 |
+
bbox_source_min, bbox_source_max = np.argwhere(source_alpha > 0).min(axis=0), np.argwhere(source_alpha > 0).max(axis=0)
|
370 |
+
bbox_target_min, bbox_target_max = np.argwhere(target_alpha > 0).min(axis=0), np.argwhere(target_alpha > 0).max(axis=0)
|
371 |
+
source_content = source_np[bbox_source_min[0]:bbox_source_max[0]+1, bbox_source_min[1]:bbox_source_max[1]+1, :]
|
372 |
+
# resize source_content to fit in the position of target_content
|
373 |
+
source_content = Image.fromarray(source_content).resize((bbox_target_max[1]-bbox_target_min[1]+1, bbox_target_max[0]-bbox_target_min[0]+1), resample=Image.BICUBIC)
|
374 |
+
target_np[bbox_target_min[0]:bbox_target_max[0]+1, bbox_target_min[1]:bbox_target_max[1]+1, :] = np.array(source_content)
|
375 |
+
return Image.fromarray(target_np)
|
376 |
+
|
377 |
+
def load_image_list_from_mvdiffusion(mvdiffusion_path, front_from_pil_or_path=None):
|
378 |
+
import os
|
379 |
+
image_list = []
|
380 |
+
for dir in ['front', 'front_right', 'right', 'back', 'left', 'front_left']:
|
381 |
+
image_path = os.path.join(mvdiffusion_path, f"rgb_000_{dir}.png")
|
382 |
+
pil = Image.open(image_path)
|
383 |
+
if dir == 'front':
|
384 |
+
if front_from_pil_or_path is not None:
|
385 |
+
if isinstance(front_from_pil_or_path, str):
|
386 |
+
replace_pil = Image.open(front_from_pil_or_path)
|
387 |
+
else:
|
388 |
+
replace_pil = front_from_pil_or_path
|
389 |
+
# align replace_pil with pil using bounding box in alpha channel
|
390 |
+
pil = align_with_alpha_bbox(replace_pil, pil, final_size=1024)
|
391 |
+
image_list.append(pil)
|
392 |
+
return image_list
|
393 |
+
|
394 |
+
def load_image_list_from_img_grid(img_grid_path, resolution = 1024):
|
395 |
+
img_list = []
|
396 |
+
grid = Image.open(img_grid_path)
|
397 |
+
w, h = grid.size
|
398 |
+
for row in range(0, h, resolution):
|
399 |
+
for col in range(0, w, resolution):
|
400 |
+
img_list.append(grid.crop((col, row, col + resolution, row + resolution)))
|
401 |
+
return img_list
|
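A minimal usage sketch for multiview_color_projection, assuming placeholder view images and a CUDA device (the module creates an nvdiffrast CUDA context at import time):
import torch
from PIL import Image
from pytorch3d.renderer import TexturesVertex
from models.ISOMER.scripts.fast_geo import create_sphere
from models.ISOMER.scripts.project_mesh import get_4view_cameras, multiview_color_projection

meshes = create_sphere(0.5)
meshes.textures = TexturesVertex(verts_features=torch.zeros_like(meshes.verts_padded()))
cameras = get_4view_cameras("cuda")
images = [Image.open(f"view_{i}.png") for i in range(4)]  # placeholder renders
colored = multiview_color_projection(meshes, images, cameras, resolution=1024, complete_unseen=True)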
models/ISOMER/scripts/refine_lr_to_sr.py
ADDED
@@ -0,0 +1,60 @@
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from hashlib import md5
|
6 |
+
def hash_img(img):
|
7 |
+
return md5(np.array(img).tobytes()).hexdigest()
|
8 |
+
def hash_any(obj):
|
9 |
+
return md5(str(obj).encode()).hexdigest()
|
10 |
+
|
11 |
+
def refine_lr_with_sd(pil_image_list, concept_img_list, control_image_list, prompt_list, pipe=None, strength=0.35, neg_prompt_list="", output_size=(512, 512), controlnet_conditioning_scale=1.):
|
12 |
+
with torch.no_grad():
|
13 |
+
images = pipe(
|
14 |
+
image=pil_image_list,
|
15 |
+
ip_adapter_image=concept_img_list,
|
16 |
+
prompt=prompt_list,
|
17 |
+
neg_prompt=neg_prompt_list,
|
18 |
+
num_inference_steps=50,
|
19 |
+
strength=strength,
|
20 |
+
height=output_size[0],
|
21 |
+
width=output_size[1],
|
22 |
+
control_image=control_image_list,
|
23 |
+
guidance_scale=5.0,
|
24 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
25 |
+
generator=torch.manual_seed(233),
|
26 |
+
).images
|
27 |
+
return images
|
28 |
+
|
29 |
+
SR_cache = None
|
30 |
+
|
31 |
+
def run_sr_fast(source_pils, scale=4):
|
32 |
+
from PIL import Image
|
33 |
+
from scripts.upsampler import RealESRGANer
|
34 |
+
import numpy as np
|
35 |
+
global SR_cache
|
36 |
+
if SR_cache is not None:
|
37 |
+
upsampler = SR_cache
|
38 |
+
else:
|
39 |
+
upsampler = RealESRGANer(
|
40 |
+
scale=4,
|
41 |
+
onnx_path="ckpt/realesrgan-x4.onnx",
|
42 |
+
tile=0,
|
43 |
+
tile_pad=10,
|
44 |
+
pre_pad=0,
|
45 |
+
half=True,
|
46 |
+
gpu_id=0,
|
47 |
+
)
|
48 |
+
ret_pils = []
|
49 |
+
for idx, img_pils in enumerate(source_pils):
|
50 |
+
np_in = isinstance(img_pils, np.ndarray)
|
51 |
+
assert isinstance(img_pils, (Image.Image, np.ndarray))
|
52 |
+
img = np.array(img_pils)
|
53 |
+
output, _ = upsampler.enhance(img, outscale=scale)
|
54 |
+
if np_in:
|
55 |
+
ret_pils.append(output)
|
56 |
+
else:
|
57 |
+
ret_pils.append(Image.fromarray(output))
|
58 |
+
if SR_cache is None:
|
59 |
+
SR_cache = upsampler
|
60 |
+
return ret_pils
|
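A minimal usage sketch for run_sr_fast; the input path is a placeholder, the ONNX weights are expected at ckpt/realesrgan-x4.onnx as hard-coded above, and the internal `from scripts.upsampler import ...` assumes models/ISOMER is on PYTHONPATH:
from PIL import Image
from models.ISOMER.scripts.refine_lr_to_sr import run_sr_fast

low_res = [Image.open("render_256.png")]   # placeholder low-resolution render
high_res = run_sr_fast(low_res, scale=4)   # 4x RealESRGAN upsampling; the upsampler is cached after the first call
high_res[0].save("render_1024.png")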
models/ISOMER/scripts/sd_model_zoo.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
|
2 |
+
from transformers import CLIPVisionModelWithProjection
|
3 |
+
import torch
|
4 |
+
from copy import deepcopy
|
5 |
+
|
6 |
+
ENABLE_CPU_CACHE = False
|
7 |
+
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"
|
8 |
+
|
9 |
+
cached_models = {} # cache for models to avoid repeated loading, key is model name
|
10 |
+
def cache_model(func):
|
11 |
+
def wrapper(*args, **kwargs):
|
12 |
+
if ENABLE_CPU_CACHE:
|
13 |
+
model_name = func.__name__ + str(args) + str(kwargs)
|
14 |
+
if model_name not in cached_models:
|
15 |
+
cached_models[model_name] = func(*args, **kwargs)
|
16 |
+
return cached_models[model_name]
|
17 |
+
else:
|
18 |
+
return func(*args, **kwargs)
|
19 |
+
return wrapper
|
20 |
+
|
21 |
+
def copied_cache_model(func):
|
22 |
+
def wrapper(*args, **kwargs):
|
23 |
+
if ENABLE_CPU_CACHE:
|
24 |
+
model_name = func.__name__ + str(args) + str(kwargs)
|
25 |
+
if model_name not in cached_models:
|
26 |
+
cached_models[model_name] = func(*args, **kwargs)
|
27 |
+
return deepcopy(cached_models[model_name])
|
28 |
+
else:
|
29 |
+
return func(*args, **kwargs)
|
30 |
+
return wrapper
|
31 |
+
|
32 |
+
def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
|
33 |
+
if ckpt_or_pretrained.endswith(".safetensors"):
|
34 |
+
pipe = model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
|
35 |
+
else:
|
36 |
+
pipe = model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
|
37 |
+
return pipe
|
38 |
+
|
39 |
+
@copied_cache_model
|
40 |
+
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
|
41 |
+
model_kwargs = dict(
|
42 |
+
torch_dtype=torch_dtype,
|
43 |
+
requires_safety_checker=False,
|
44 |
+
safety_checker=None,
|
45 |
+
)
|
46 |
+
pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
|
47 |
+
base_model,
|
48 |
+
StableDiffusionPipeline,
|
49 |
+
**model_kwargs
|
50 |
+
)
|
51 |
+
pipe.to("cpu")
|
52 |
+
return pipe.components
|
53 |
+
|
54 |
+
@cache_model
|
55 |
+
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
|
56 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
|
57 |
+
return controlnet
|
58 |
+
|
59 |
+
@cache_model
|
60 |
+
def load_image_encoder():
|
61 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
62 |
+
"h94/IP-Adapter",
|
63 |
+
subfolder="models/image_encoder",
|
64 |
+
torch_dtype=torch.float16,
|
65 |
+
)
|
66 |
+
return image_encoder
|
67 |
+
|
68 |
+
def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="auto", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
|
69 |
+
model_kwargs = dict(
|
70 |
+
torch_dtype=torch_dtype,
|
71 |
+
device_map=device,
|
72 |
+
requires_safety_checker=False,
|
73 |
+
safety_checker=None,
|
74 |
+
)
|
75 |
+
components = load_base_model_components(base_model=base_model, torch_dtype=torch_dtype)
|
76 |
+
model_kwargs.update(components)
|
77 |
+
model_kwargs.update(kwargs)
|
78 |
+
|
79 |
+
if controlnet is not None:
|
80 |
+
if isinstance(controlnet, list):
|
81 |
+
controlnet = [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
|
82 |
+
else:
|
83 |
+
controlnet = load_controlnet(controlnet, torch_dtype=torch_dtype)
|
84 |
+
model_kwargs.update(controlnet=controlnet)
|
85 |
+
|
86 |
+
if pipeline_class is None:
|
87 |
+
if controlnet is not None:
|
88 |
+
pipeline_class = StableDiffusionControlNetPipeline
|
89 |
+
else:
|
90 |
+
pipeline_class = StableDiffusionPipeline
|
91 |
+
|
92 |
+
pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
|
93 |
+
base_model,
|
94 |
+
pipeline_class,
|
95 |
+
**model_kwargs
|
96 |
+
)
|
97 |
+
|
98 |
+
if ip_adapter:
|
99 |
+
image_encoder = load_image_encoder()
|
100 |
+
pipe.image_encoder = image_encoder
|
101 |
+
if plus_model:
|
102 |
+
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.safetensors")
|
103 |
+
else:
|
104 |
+
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.safetensors")
|
105 |
+
pipe.set_ip_adapter_scale(1.0)
|
106 |
+
else:
|
107 |
+
pipe.unload_ip_adapter()
|
108 |
+
|
109 |
+
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
110 |
+
|
111 |
+
if model_cpu_offload_seq is None:
|
112 |
+
if isinstance(pipe, StableDiffusionControlNetPipeline):
|
113 |
+
pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
|
114 |
+
elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
|
115 |
+
pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
|
116 |
+
else:
|
117 |
+
pipe.model_cpu_offload_seq = model_cpu_offload_seq
|
118 |
+
|
119 |
+
if enable_sequential_cpu_offload:
|
120 |
+
pipe.enable_sequential_cpu_offload()
|
121 |
+
else:
|
122 |
+
pipe = pipe.to("cuda")
|
123 |
+
pass
|
124 |
+
# pipe.enable_model_cpu_offload()
|
125 |
+
if vae_slicing:
|
126 |
+
pipe.enable_vae_slicing()
|
127 |
+
|
128 |
+
import gc
|
129 |
+
gc.collect()
|
130 |
+
return pipe
|
131 |
+
|
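A minimal usage sketch for the pipeline loader defined above, assuming the file is importable as models.ISOMER.scripts.sd_model_zoo and that a CUDA device and the referenced Hugging Face weights are available; the prompt and reference image in the commented call are placeholders, not part of this commit.

from models.ISOMER.scripts.sd_model_zoo import load_common_sd15_pipe

# build an SD 1.5 pipeline with the IP-Adapter Plus weights attached
pipe = load_common_sd15_pipe(
    ip_adapter=True,   # loads h94/IP-Adapter weights and sets the adapter scale to 1.0
    plus_model=True,   # ip-adapter-plus_sd15.safetensors instead of the base adapter
    vae_slicing=True,  # lower VRAM use when decoding
)
# images = pipe(prompt="a wooden chair", ip_adapter_image=ref_image, num_inference_steps=30).images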
models/ISOMER/scripts/upsampler.py
ADDED
@@ -0,0 +1,260 @@
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from scripts.load_onnx import load_onnx_caller
|
8 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
9 |
+
|
10 |
+
|
11 |
+
class RealESRGANer():
|
12 |
+
"""A helper class for upsampling images with RealESRGAN.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
16 |
+
onnx_path (str): Path to the pretrained ONNX model used for inference.
|
17 |
+
|
18 |
+
tile (int): Since very large inputs can exhaust GPU memory, this option first crops the
|
19 |
+
input image into tiles and processes each tile separately; the results are then merged back into one image.
|
20 |
+
0 disables tiling. Default: 0.
|
21 |
+
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
22 |
+
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
23 |
+
half (bool): Whether to use half precision during inference. Default: False.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self,
|
27 |
+
scale,
|
28 |
+
onnx_path,
|
29 |
+
tile=0,
|
30 |
+
tile_pad=10,
|
31 |
+
pre_pad=10,
|
32 |
+
half=False,
|
33 |
+
device=None,
|
34 |
+
gpu_id=None):
|
35 |
+
self.scale = scale
|
36 |
+
self.tile_size = tile
|
37 |
+
self.tile_pad = tile_pad
|
38 |
+
self.pre_pad = pre_pad
|
39 |
+
self.mod_scale = None
|
40 |
+
self.half = half
|
41 |
+
|
42 |
+
print('about to initialize model')
|
43 |
+
# initialize model
|
44 |
+
if gpu_id:
|
45 |
+
self.device = torch.device(
|
46 |
+
f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
|
47 |
+
else:
|
48 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
49 |
+
print('self.device set')
|
50 |
+
print(f'about to self.model = load_onnx_caller({onnx_path}, single_output=True)')
|
51 |
+
self.model = load_onnx_caller(onnx_path, single_output=True)
|
52 |
+
print('self.model loaded')
|
53 |
+
|
54 |
+
print('about to warm up')
|
55 |
+
# warm up
|
56 |
+
sample_input = torch.randn(1, 3, 512, 512, device=self.device).float()  # warm up on the selected device
|
57 |
+
print(f'sample_input.shape = {sample_input.shape}')
|
58 |
+
self.model(sample_input)
|
59 |
+
print('finished warming up')
|
60 |
+
|
61 |
+
def pre_process(self, img):
|
62 |
+
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
|
63 |
+
"""
|
64 |
+
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
65 |
+
self.img = img.unsqueeze(0).to(self.device)
|
66 |
+
if self.half:
|
67 |
+
self.img = self.img.half()
|
68 |
+
|
69 |
+
# pre_pad
|
70 |
+
if self.pre_pad != 0:
|
71 |
+
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
|
72 |
+
# mod pad for divisible borders
|
73 |
+
if self.scale == 2:
|
74 |
+
self.mod_scale = 2
|
75 |
+
elif self.scale == 1:
|
76 |
+
self.mod_scale = 4
|
77 |
+
if self.mod_scale is not None:
|
78 |
+
self.mod_pad_h, self.mod_pad_w = 0, 0
|
79 |
+
_, _, h, w = self.img.size()
|
80 |
+
if (h % self.mod_scale != 0):
|
81 |
+
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
|
82 |
+
if (w % self.mod_scale != 0):
|
83 |
+
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
|
84 |
+
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
|
85 |
+
|
86 |
+
def process(self):
|
87 |
+
# model inference
|
88 |
+
self.output = self.model(self.img)
|
89 |
+
|
90 |
+
def tile_process(self):
|
91 |
+
"""It will first crop input images to tiles, and then process each tile.
|
92 |
+
Finally, all the processed tiles are merged into one image.
|
93 |
+
|
94 |
+
Modified from: https://github.com/ata4/esrgan-launcher
|
95 |
+
"""
|
96 |
+
batch, channel, height, width = self.img.shape
|
97 |
+
output_height = height * self.scale
|
98 |
+
output_width = width * self.scale
|
99 |
+
output_shape = (batch, channel, output_height, output_width)
|
100 |
+
|
101 |
+
# start with black image
|
102 |
+
self.output = self.img.new_zeros(output_shape)
|
103 |
+
tiles_x = math.ceil(width / self.tile_size)
|
104 |
+
tiles_y = math.ceil(height / self.tile_size)
|
105 |
+
|
106 |
+
# loop over all tiles
|
107 |
+
for y in range(tiles_y):
|
108 |
+
for x in range(tiles_x):
|
109 |
+
# extract tile from input image
|
110 |
+
ofs_x = x * self.tile_size
|
111 |
+
ofs_y = y * self.tile_size
|
112 |
+
# input tile area on total image
|
113 |
+
input_start_x = ofs_x
|
114 |
+
input_end_x = min(ofs_x + self.tile_size, width)
|
115 |
+
input_start_y = ofs_y
|
116 |
+
input_end_y = min(ofs_y + self.tile_size, height)
|
117 |
+
|
118 |
+
# input tile area on total image with padding
|
119 |
+
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
120 |
+
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
121 |
+
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
122 |
+
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
123 |
+
|
124 |
+
# input tile dimensions
|
125 |
+
input_tile_width = input_end_x - input_start_x
|
126 |
+
input_tile_height = input_end_y - input_start_y
|
127 |
+
tile_idx = y * tiles_x + x + 1
|
128 |
+
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
129 |
+
|
130 |
+
# upscale tile
|
131 |
+
try:
|
132 |
+
with torch.no_grad():
|
133 |
+
output_tile = self.model(input_tile)
|
134 |
+
except RuntimeError as error:
|
135 |
+
print('Error', error)
|
136 |
+
print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
|
137 |
+
|
138 |
+
# output tile area on total image
|
139 |
+
output_start_x = input_start_x * self.scale
|
140 |
+
output_end_x = input_end_x * self.scale
|
141 |
+
output_start_y = input_start_y * self.scale
|
142 |
+
output_end_y = input_end_y * self.scale
|
143 |
+
|
144 |
+
# output tile area without padding
|
145 |
+
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
146 |
+
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
147 |
+
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
148 |
+
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
149 |
+
|
150 |
+
# put tile into output image
|
151 |
+
self.output[:, :, output_start_y:output_end_y,
|
152 |
+
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
|
153 |
+
output_start_x_tile:output_end_x_tile]
|
154 |
+
|
155 |
+
def post_process(self):
|
156 |
+
# remove extra pad
|
157 |
+
if self.mod_scale is not None:
|
158 |
+
_, _, h, w = self.output.size()
|
159 |
+
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
|
160 |
+
# remove prepad
|
161 |
+
if self.pre_pad != 0:
|
162 |
+
_, _, h, w = self.output.size()
|
163 |
+
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
|
164 |
+
return self.output
|
165 |
+
|
166 |
+
@torch.no_grad()
|
167 |
+
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
|
168 |
+
print('inside enhance')
|
169 |
+
h_input, w_input = img.shape[0:2]
|
170 |
+
# img: numpy
|
171 |
+
img = img.astype(np.float32)
|
172 |
+
if np.max(img) > 256: # 16-bit image
|
173 |
+
max_range = 65535
|
174 |
+
print('\tInput is a 16-bit image')
|
175 |
+
else:
|
176 |
+
max_range = 255
|
177 |
+
img = img / max_range
|
178 |
+
if len(img.shape) == 2: # gray image
|
179 |
+
img_mode = 'L'
|
180 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
181 |
+
elif img.shape[2] == 4: # RGBA image with alpha channel
|
182 |
+
img_mode = 'RGBA'
|
183 |
+
alpha = img[:, :, 3]
|
184 |
+
img = img[:, :, 0:3]
|
185 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
186 |
+
if alpha_upsampler == 'realesrgan':
|
187 |
+
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
188 |
+
else:
|
189 |
+
img_mode = 'RGB'
|
190 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
191 |
+
|
192 |
+
# ------------------- process image (without the alpha channel) ------------------- #
|
193 |
+
print('about to process image (without the alpha channel)')
|
194 |
+
self.pre_process(img)
|
195 |
+
if self.tile_size > 0:
|
196 |
+
print(f'self.tile_size is {self.tile_size}, thus about to self.tile_process()')
|
197 |
+
self.tile_process()
|
198 |
+
print('finished self.tile_process()')
|
199 |
+
else:
|
200 |
+
print('about to self.process()')
|
201 |
+
self.process()
|
202 |
+
print('finished self.process()')
|
203 |
+
|
204 |
+
print('about to self.post_process()')
|
205 |
+
output_img = self.post_process()
|
206 |
+
print('finished self.post_process()')
|
207 |
+
output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
208 |
+
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
209 |
+
if img_mode == 'L':
|
210 |
+
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
211 |
+
print('finished process image (without the alpha channel)')
|
212 |
+
|
213 |
+
# ------------------- process the alpha channel if necessary ------------------- #
|
214 |
+
if img_mode == 'RGBA':
|
215 |
+
print("img_mode == 'RGBA' thus about to process alpha channel")
|
216 |
+
if alpha_upsampler == 'realesrgan':
|
217 |
+
print(f"alpha_upsampler == 'realesrgan', about to self.pre_process({alpha})")
|
218 |
+
self.pre_process(alpha)
|
219 |
+
print('finished self.pre_process')
|
220 |
+
if self.tile_size > 0:
|
221 |
+
print(f'self.tile_size is {self.tile_size}, thus about to self.tile_process()')
|
222 |
+
self.tile_process()
|
223 |
+
print('finished self.tile_process()')
|
224 |
+
else:
|
225 |
+
print('about to self.process()')
|
226 |
+
self.process()
|
227 |
+
print('finished self.process()')
|
228 |
+
print('about to self.post_process()')
|
229 |
+
output_alpha = self.post_process()
|
230 |
+
print('finished self.post_process()')
|
231 |
+
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
232 |
+
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
233 |
+
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
234 |
+
else: # use the cv2 resize for alpha channel
|
235 |
+
print('about to use the cv2 resize for alpha channel')
|
236 |
+
h, w = alpha.shape[0:2]
|
237 |
+
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
|
238 |
+
|
239 |
+
print('about to merge the alpha channel')
|
240 |
+
# merge the alpha channel
|
241 |
+
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
242 |
+
output_img[:, :, 3] = output_alpha
|
243 |
+
print('finished process alpha channel')
|
244 |
+
|
245 |
+
print('about to resize and return')
|
246 |
+
# ------------------------------ return ------------------------------ #
|
247 |
+
if max_range == 65535: # 16-bit image
|
248 |
+
output = (output_img * 65535.0).round().astype(np.uint16)
|
249 |
+
else:
|
250 |
+
output = (output_img * 255.0).round().astype(np.uint8)
|
251 |
+
|
252 |
+
if outscale is not None and outscale != float(self.scale):
|
253 |
+
output = cv2.resize(
|
254 |
+
output, (
|
255 |
+
int(w_input * outscale),
|
256 |
+
int(h_input * outscale),
|
257 |
+
), interpolation=cv2.INTER_LANCZOS4)
|
258 |
+
|
259 |
+
return output, img_mode
|
260 |
+
|
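A hedged usage sketch for the ONNX-backed RealESRGANer above; the checkpoint and image paths are placeholders, and the warm-up in __init__ assumes a CUDA-capable machine. cv2.imread returns a BGR array, which is what enhance() expects.

import cv2
from models.ISOMER.scripts.upsampler import RealESRGANer

# tile=256 keeps memory bounded for large inputs; scale must match the ONNX model
upsampler = RealESRGANer(scale=4, onnx_path="ckpts/realesrgan-x4.onnx", tile=256)
bgr = cv2.imread("input.png", cv2.IMREAD_UNCHANGED)
sr, img_mode = upsampler.enhance(bgr, outscale=4.0)
cv2.imwrite("output_x4.png", sr)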
models/ISOMER/scripts/utils.py
ADDED
@@ -0,0 +1,611 @@
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import pymeshlab
|
5 |
+
import pymeshlab as ml
|
6 |
+
from pymeshlab import PercentageValue
|
7 |
+
from pytorch3d.renderer import TexturesVertex
|
8 |
+
from pytorch3d.structures import Meshes
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from typing import List, Tuple
|
12 |
+
from PIL import Image
|
13 |
+
import trimesh
|
14 |
+
|
15 |
+
EPSILON = 1e-8
|
16 |
+
|
17 |
+
def load_mesh_with_trimesh(file_name, file_type=None):
|
18 |
+
import trimesh
|
19 |
+
mesh: trimesh.Trimesh = trimesh.load(file_name, file_type=file_type)
|
20 |
+
if isinstance(mesh, trimesh.Scene):
|
21 |
+
assert len(mesh.geometry) > 0
|
22 |
+
# save to obj first and load again to avoid offset issue
|
23 |
+
from io import BytesIO
|
24 |
+
with BytesIO() as f:
|
25 |
+
mesh.export(f, file_type="obj")
|
26 |
+
f.seek(0)
|
27 |
+
mesh = trimesh.load(f, file_type="obj")
|
28 |
+
if isinstance(mesh, trimesh.Scene):
|
29 |
+
# we lose texture information here
|
30 |
+
mesh = trimesh.util.concatenate(
|
31 |
+
tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces)
|
32 |
+
for g in mesh.geometry.values()))
|
33 |
+
assert isinstance(mesh, trimesh.Trimesh)
|
34 |
+
|
35 |
+
vertices = torch.from_numpy(mesh.vertices).T
|
36 |
+
faces = torch.from_numpy(mesh.faces).T
|
37 |
+
colors = None
|
38 |
+
if mesh.visual is not None:
|
39 |
+
if hasattr(mesh.visual, 'vertex_colors'):
|
40 |
+
colors = torch.from_numpy(mesh.visual.vertex_colors)[..., :3].T / 255.
|
41 |
+
if colors is None:
|
42 |
+
colors = torch.ones_like(vertices) * 0.5
|
43 |
+
return vertices, faces, colors
|
44 |
+
|
45 |
+
def meshlab_mesh_to_py3dmesh(mesh: pymeshlab.Mesh) -> Meshes:
|
46 |
+
verts = torch.from_numpy(mesh.vertex_matrix()).float()
|
47 |
+
faces = torch.from_numpy(mesh.face_matrix()).long()
|
48 |
+
colors = torch.from_numpy(mesh.vertex_color_matrix()[..., :3]).float()
|
49 |
+
textures = TexturesVertex(verts_features=[colors])
|
50 |
+
return Meshes(verts=[verts], faces=[faces], textures=textures)
|
51 |
+
|
52 |
+
|
53 |
+
def py3dmesh_to_meshlab_mesh(meshes: Meshes) -> pymeshlab.Mesh:
|
54 |
+
colors_in = F.pad(meshes.textures.verts_features_packed().cpu().float(), [0,1], value=1).numpy().astype(np.float64)
|
55 |
+
m1 = pymeshlab.Mesh(
|
56 |
+
vertex_matrix=meshes.verts_packed().cpu().float().numpy().astype(np.float64),
|
57 |
+
face_matrix=meshes.faces_packed().cpu().long().numpy().astype(np.int32),
|
58 |
+
v_normals_matrix=meshes.verts_normals_packed().cpu().float().numpy().astype(np.float64),
|
59 |
+
v_color_matrix=colors_in)
|
60 |
+
return m1
|
61 |
+
|
62 |
+
|
63 |
+
def to_pyml_mesh(vertices,faces):
|
64 |
+
m1 = pymeshlab.Mesh(
|
65 |
+
vertex_matrix=vertices.cpu().float().numpy().astype(np.float64),
|
66 |
+
face_matrix=faces.cpu().long().numpy().astype(np.int32),
|
67 |
+
)
|
68 |
+
return m1
|
69 |
+
|
70 |
+
|
71 |
+
def to_py3d_mesh(vertices, faces, normals=None):
|
72 |
+
from pytorch3d.structures import Meshes
|
73 |
+
from pytorch3d.renderer.mesh.textures import TexturesVertex
|
74 |
+
mesh = Meshes(verts=[vertices], faces=[faces], textures=None)
|
75 |
+
if normals is None:
|
76 |
+
normals = mesh.verts_normals_packed()
|
77 |
+
# set normals as vertex colors
|
78 |
+
mesh.textures = TexturesVertex(verts_features=[normals / 2 + 0.5])
|
79 |
+
return mesh
|
80 |
+
|
81 |
+
|
82 |
+
def from_py3d_mesh(mesh):
|
83 |
+
return mesh.verts_list()[0], mesh.faces_list()[0], mesh.textures.verts_features_packed()
|
84 |
+
|
85 |
+
def rotate_normalmap_by_angle(normal_map: np.ndarray, angle: float):
|
86 |
+
"""
|
87 |
+
rotate along y-axis
|
88 |
+
normal_map: np.array, shape=(H, W, 3) in [-1, 1]
|
89 |
+
angle: float, in degree
|
90 |
+
"""
|
91 |
+
angle = angle / 180 * np.pi
|
92 |
+
R = np.array([[np.cos(angle), 0, np.sin(angle)], [0, 1, 0], [-np.sin(angle), 0, np.cos(angle)]])
|
93 |
+
return np.dot(normal_map.reshape(-1, 3), R.T).reshape(normal_map.shape)
|
94 |
+
|
95 |
+
# from view coord to front view world coord
|
96 |
+
def rotate_normals(normal_pils, return_types='np', rotate_direction=1) -> np.ndarray: # [0, 255]
|
97 |
+
n_views = len(normal_pils)
|
98 |
+
ret = []
|
99 |
+
for idx, rgba_normal in enumerate(normal_pils):
|
100 |
+
# rotate normal
|
101 |
+
normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1]
|
102 |
+
alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
|
103 |
+
normal_np = normal_np * 2 - 1
|
104 |
+
normal_np = rotate_normalmap_by_angle(normal_np, rotate_direction * idx * (360 / n_views))
|
105 |
+
normal_np = (normal_np + 1) / 2
|
106 |
+
normal_np = normal_np * alpha_np[..., None] # make bg black
|
107 |
+
rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[:, :, None] * 255] , axis=-1)
|
108 |
+
if return_types == 'np':
|
109 |
+
ret.append(rgba_normal_np)
|
110 |
+
elif return_types == 'pil':
|
111 |
+
ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
|
112 |
+
else:
|
113 |
+
raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
|
114 |
+
return ret
|
115 |
+
|
116 |
+
|
117 |
+
def rotate_normalmap_by_angle_torch(normal_map, angle):
|
118 |
+
"""
|
119 |
+
rotate along y-axis
|
120 |
+
normal_map: torch.Tensor, shape=(H, W, 3) in [-1, 1], device='cuda'
|
121 |
+
angle: float, in degree
|
122 |
+
"""
|
123 |
+
angle = torch.tensor(angle / 180 * np.pi).to(normal_map)
|
124 |
+
R = torch.tensor([[torch.cos(angle), 0, torch.sin(angle)],
|
125 |
+
[0, 1, 0],
|
126 |
+
[-torch.sin(angle), 0, torch.cos(angle)]]).to(normal_map)
|
127 |
+
return torch.matmul(normal_map.view(-1, 3), R.T).view(normal_map.shape)
|
128 |
+
|
129 |
+
def do_rotate(rgba_normal, angle):
|
130 |
+
rgba_normal = torch.from_numpy(rgba_normal).float().cuda() / 255
|
131 |
+
rotated_normal_tensor = rotate_normalmap_by_angle_torch(rgba_normal[..., :3] * 2 - 1, angle)
|
132 |
+
rotated_normal_tensor = (rotated_normal_tensor + 1) / 2
|
133 |
+
rotated_normal_tensor = rotated_normal_tensor * rgba_normal[:, :, [3]] # make bg black
|
134 |
+
rgba_normal_np = torch.cat([rotated_normal_tensor * 255, rgba_normal[:, :, [3]] * 255], dim=-1).cpu().numpy()
|
135 |
+
return rgba_normal_np
|
136 |
+
|
137 |
+
def rotate_normals_torch(normal_pils, return_types='np', rotate_direction=1):
|
138 |
+
n_views = len(normal_pils)
|
139 |
+
ret = []
|
140 |
+
for idx, rgba_normal in enumerate(normal_pils):
|
141 |
+
# rotate normal
|
142 |
+
angle = rotate_direction * idx * (360 / n_views)
|
143 |
+
rgba_normal_np = do_rotate(np.array(rgba_normal), angle)
|
144 |
+
if return_types == 'np':
|
145 |
+
ret.append(rgba_normal_np)
|
146 |
+
elif return_types == 'pil':
|
147 |
+
ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
|
148 |
+
else:
|
149 |
+
raise ValueError(f"return_types should be 'np' or 'pil', but got {return_types}")
|
150 |
+
return ret
|
151 |
+
|
152 |
+
def change_bkgd(img_pils, new_bkgd=(0., 0., 0.)):
|
153 |
+
ret = []
|
154 |
+
new_bkgd = np.array(new_bkgd).reshape(1, 1, 3)
|
155 |
+
for rgba_img in img_pils:
|
156 |
+
img_np = np.array(rgba_img)[:, :, :3] / 255
|
157 |
+
alpha_np = np.array(rgba_img)[:, :, 3] / 255
|
158 |
+
ori_bkgd = img_np[:1, :1]
|
159 |
+
# color = ori_color * alpha + bkgd * (1-alpha)
|
160 |
+
# ori_color = (color - bkgd * (1-alpha)) / alpha
|
161 |
+
alpha_np_clamp = np.clip(alpha_np, 1e-6, 1) # avoid divide by zero
|
162 |
+
ori_img_np = (img_np - ori_bkgd * (1 - alpha_np[..., None])) / alpha_np_clamp[..., None]
|
163 |
+
img_np = np.where(alpha_np[..., None] > 0.05, ori_img_np * alpha_np[..., None] + new_bkgd * (1 - alpha_np[..., None]), new_bkgd)
|
164 |
+
rgba_img_np = np.concatenate([img_np * 255, alpha_np[..., None] * 255], axis=-1)
|
165 |
+
ret.append(Image.fromarray(rgba_img_np.astype(np.uint8)))
|
166 |
+
return ret
|
167 |
+
|
168 |
+
def change_bkgd_to_normal(normal_pils) -> List[Image.Image]:
|
169 |
+
n_views = len(normal_pils)
|
170 |
+
ret = []
|
171 |
+
for idx, rgba_normal in enumerate(normal_pils):
|
172 |
+
# calculate background normal
|
173 |
+
target_bkgd = rotate_normalmap_by_angle(np.array([[[0., 0., 1.]]]), idx * (360 / n_views))
|
174 |
+
normal_np = np.array(rgba_normal)[:, :, :3] / 255 # in [-1, 1]
|
175 |
+
alpha_np = np.array(rgba_normal)[:, :, 3] / 255 # in [0, 1]
|
176 |
+
normal_np = normal_np * 2 - 1
|
177 |
+
old_bkgd = normal_np[:1,:1]
|
178 |
+
normal_np[alpha_np > 0.05] = (normal_np[alpha_np > 0.05] - old_bkgd * (1 - alpha_np[alpha_np > 0.05][..., None])) / alpha_np[alpha_np > 0.05][..., None]
|
179 |
+
normal_np = normal_np * alpha_np[..., None] + target_bkgd * (1 - alpha_np[..., None])
|
180 |
+
normal_np = (normal_np + 1) / 2
|
181 |
+
rgba_normal_np = np.concatenate([normal_np * 255, alpha_np[..., None] * 255] , axis=-1)
|
182 |
+
ret.append(Image.fromarray(rgba_normal_np.astype(np.uint8)))
|
183 |
+
return ret
|
184 |
+
|
185 |
+
|
186 |
+
def fix_vert_color_glb(mesh_path):
|
187 |
+
from pygltflib import GLTF2, Material, PbrMetallicRoughness
|
188 |
+
obj1 = GLTF2().load(mesh_path)
|
189 |
+
obj1.meshes[0].primitives[0].material = 0
|
190 |
+
obj1.materials.append(Material(
|
191 |
+
pbrMetallicRoughness = PbrMetallicRoughness(
|
192 |
+
baseColorFactor = [1.0, 1.0, 1.0, 1.0],
|
193 |
+
metallicFactor = 0.,
|
194 |
+
roughnessFactor = 1.0,
|
195 |
+
),
|
196 |
+
emissiveFactor = [0.0, 0.0, 0.0],
|
197 |
+
doubleSided = True,
|
198 |
+
))
|
199 |
+
obj1.save(mesh_path)
|
200 |
+
|
201 |
+
|
202 |
+
def srgb_to_linear(c_srgb):
|
203 |
+
c_linear = np.where(c_srgb <= 0.04045, c_srgb / 12.92, ((c_srgb + 0.055) / 1.055) ** 2.4)
|
204 |
+
return c_linear.clip(0, 1.)
|
205 |
+
|
206 |
+
|
207 |
+
def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
|
208 |
+
# convert from pytorch3d meshes to trimesh mesh
|
209 |
+
vertices = meshes.verts_packed().cpu().float().numpy()
|
210 |
+
triangles = meshes.faces_packed().cpu().long().numpy()
|
211 |
+
np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
|
212 |
+
if save_glb_path.endswith(".glb"):
|
213 |
+
# rotate 180 along +Y
|
214 |
+
vertices[:, [0, 2]] = -vertices[:, [0, 2]]
|
215 |
+
|
216 |
+
if apply_sRGB_to_LinearRGB:
|
217 |
+
np_color = srgb_to_linear(np_color)
|
218 |
+
assert vertices.shape[0] == np_color.shape[0]
|
219 |
+
assert np_color.shape[1] == 3
|
220 |
+
assert 0 <= np_color.min() and np_color.max() <= 1, f"min={np_color.min()}, max={np_color.max()}"
|
221 |
+
mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
|
222 |
+
mesh.remove_unreferenced_vertices()
|
223 |
+
# save mesh
|
224 |
+
mesh.export(save_glb_path)
|
225 |
+
if save_glb_path.endswith(".glb"):
|
226 |
+
fix_vert_color_glb(save_glb_path)
|
227 |
+
|
228 |
+
|
229 |
+
def save_glb_and_video(save_mesh_prefix: str, meshes: Meshes, with_timestamp=True, dist=3.5, azim_offset=180, resolution=512, fov_in_degrees=1 / 1.15, cam_type="ortho", view_padding=60, export_video=True) -> Tuple[str, str]:
|
230 |
+
import time
|
231 |
+
if '.' in save_mesh_prefix:
|
232 |
+
save_mesh_prefix = ".".join(save_mesh_prefix.split('.')[:-1])
|
233 |
+
if with_timestamp:
|
234 |
+
save_mesh_prefix = save_mesh_prefix + f"_{int(time.time())}"
|
235 |
+
ret_mesh = save_mesh_prefix + ".glb"
|
236 |
+
# optimized version
|
237 |
+
save_py3dmesh_with_trimesh_fast(meshes, ret_mesh)
|
238 |
+
return ret_mesh, None
|
239 |
+
|
240 |
+
|
241 |
+
def simple_clean_mesh(pyml_mesh: ml.Mesh, apply_smooth=True, stepsmoothnum=1, apply_sub_divide=False, sub_divide_threshold=0.25):
|
242 |
+
ms = ml.MeshSet()
|
243 |
+
ms.add_mesh(pyml_mesh, "cube_mesh")
|
244 |
+
|
245 |
+
if apply_smooth:
|
246 |
+
ms.apply_filter("apply_coord_laplacian_smoothing", stepsmoothnum=stepsmoothnum, cotangentweight=False)
|
247 |
+
if apply_sub_divide: # 5s, slow
|
248 |
+
ms.apply_filter("meshing_repair_non_manifold_vertices")
|
249 |
+
ms.apply_filter("meshing_repair_non_manifold_edges", method='Remove Faces')
|
250 |
+
ms.apply_filter("meshing_surface_subdivision_loop", iterations=2, threshold=PercentageValue(sub_divide_threshold))
|
251 |
+
return meshlab_mesh_to_py3dmesh(ms.current_mesh())
|
252 |
+
|
253 |
+
|
254 |
+
def expand2square(pil_img, background_color):
|
255 |
+
width, height = pil_img.size
|
256 |
+
if width == height:
|
257 |
+
return pil_img
|
258 |
+
elif width > height:
|
259 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
260 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
261 |
+
return result
|
262 |
+
else:
|
263 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
264 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
265 |
+
return result
|
266 |
+
|
267 |
+
|
268 |
+
|
269 |
+
def init_target(img_pils, new_bkgd=(0., 0., 0.), device="cuda"):
|
270 |
+
new_bkgd = torch.tensor(new_bkgd, dtype=torch.float32).view(1, 1, 3).to(device)
|
271 |
+
|
272 |
+
imgs = torch.stack([torch.from_numpy(np.array(img, dtype=np.float32)) for img in img_pils]).to(device) / 255
|
273 |
+
img_nps = imgs[..., :3]
|
274 |
+
alpha_nps = imgs[..., 3]
|
275 |
+
ori_bkgds = img_nps[:, :1, :1]
|
276 |
+
|
277 |
+
alpha_nps_clamp = torch.clamp(alpha_nps, 1e-6, 1)
|
278 |
+
ori_img_nps = (img_nps - ori_bkgds * (1 - alpha_nps.unsqueeze(-1))) / alpha_nps_clamp.unsqueeze(-1)
|
279 |
+
ori_img_nps = torch.clamp(ori_img_nps, 0, 1)
|
280 |
+
img_nps = torch.where(alpha_nps.unsqueeze(-1) > 0.05, ori_img_nps * alpha_nps.unsqueeze(-1) + new_bkgd * (1 - alpha_nps.unsqueeze(-1)), new_bkgd)
|
281 |
+
|
282 |
+
rgba_img_np = torch.cat([img_nps, alpha_nps.unsqueeze(-1)], dim=-1)
|
283 |
+
return rgba_img_np
|
284 |
+
|
285 |
+
|
286 |
+
|
287 |
+
def rotation_matrix_axis_angle(axis, angle, device='cuda'):
|
288 |
+
"""
|
289 |
+
Return the rotation matrix associated with counterclockwise rotation about
|
290 |
+
the given axis by angle degrees, using PyTorch.
|
291 |
+
"""
|
292 |
+
if not isinstance(axis, torch.Tensor):
|
293 |
+
axis = torch.tensor(axis, device=device)
|
294 |
+
axis = axis.float().to(device)
|
295 |
+
if not isinstance(angle, torch.Tensor):
|
296 |
+
angle = torch.tensor(angle, device=device)
|
297 |
+
angle = angle.float().to(device)
|
298 |
+
|
299 |
+
theta = angle * torch.pi / 180.0
|
300 |
+
axis = torch.tensor(axis, dtype=torch.float32)
|
301 |
+
if torch.dot(axis, axis) > 0:
|
302 |
+
denom = torch.sqrt(torch.dot(axis, axis))
|
303 |
+
denom = torch.where(denom == 0, torch.tensor(EPSILON).to(denom.device), denom)
|
304 |
+
axis = axis / denom  # normalize to a unit rotation axis
|
305 |
+
a = torch.cos(theta / 2.0)
|
306 |
+
b, c, d = -axis[0] * torch.sin(theta / 2.0), -axis[1] * torch.sin(theta / 2.0), -axis[2] * torch.sin(theta / 2.0)
|
307 |
+
|
308 |
+
aa, bb, cc, dd = a*a, b*b, c*c, d*d
|
309 |
+
bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d
|
310 |
+
return torch.stack([
|
311 |
+
torch.stack([aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)]),
|
312 |
+
torch.stack([2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)]),
|
313 |
+
torch.stack([2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc])
|
314 |
+
])
|
315 |
+
else:
|
316 |
+
return torch.eye(3)
|
317 |
+
|
318 |
+
|
319 |
+
|
320 |
+
def normal_rotation_img2img_angle_axis(image, angle, axis=None, device='cuda'):
|
321 |
+
"""
|
322 |
+
Rotate an image by a given angle around a given axis using PyTorch.
|
323 |
+
|
324 |
+
Args:
|
325 |
+
image: Input Image to rotate.
|
326 |
+
angle: Rotation angle in degrees.
|
327 |
+
axis: Rotation axis as an array of 3 floats.
|
328 |
+
|
329 |
+
Returns:
|
330 |
+
Image: Rotated Image.
|
331 |
+
"""
|
332 |
+
if axis is None:
|
333 |
+
axis = [0,1,0]
|
334 |
+
axis = torch.tensor(axis, device=device)
|
335 |
+
|
336 |
+
|
337 |
+
if type(image) == Image.Image:
|
338 |
+
image_array = torch.tensor(np.array(image, dtype='float32'))
|
339 |
+
else:
|
340 |
+
image_array = image
|
341 |
+
image_array = image_array.to(device)
|
342 |
+
|
343 |
+
if type(angle) != torch.Tensor:
|
344 |
+
angle = torch.tensor(angle)
|
345 |
+
angle = angle.to(device)
|
346 |
+
|
347 |
+
if image_array.shape[2] == 4:
|
348 |
+
rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
|
349 |
+
else:
|
350 |
+
rgb_array = image_array
|
351 |
+
alpha_array = None
|
352 |
+
|
353 |
+
rgb_array = rgb_array / 255.0 - 0.5
|
354 |
+
|
355 |
+
rgb_array = rgb_array.permute(2, 0, 1)
|
356 |
+
|
357 |
+
rotated_tensor = apply_rotation_angle_axis(rgb_array.unsqueeze(0), axis, torch.tensor([angle], device=rgb_array.device))
|
358 |
+
|
359 |
+
|
360 |
+
rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
|
361 |
+
|
362 |
+
rotated_array = (rotated_array/2 + 0.5) * 255
|
363 |
+
|
364 |
+
if alpha_array is not None:
|
365 |
+
rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
|
366 |
+
|
367 |
+
rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
|
368 |
+
|
369 |
+
rotated_normal = Image.fromarray(rotated_array_uint8)
|
370 |
+
|
371 |
+
return rotated_normal
|
372 |
+
|
373 |
+
def normal_rotation_img2img_c2w(image, c2w, device='cuda'):
|
374 |
+
|
375 |
+
if type(image) != torch.Tensor:
|
376 |
+
image_array = torch.tensor(np.array(image, dtype='float32'))
|
377 |
+
else:
|
378 |
+
image_array = image
|
379 |
+
|
380 |
+
|
381 |
+
image_array = image_array.to(device)
|
382 |
+
|
383 |
+
if image_array.shape[2] == 4:
|
384 |
+
rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
|
385 |
+
else:
|
386 |
+
rgb_array = image_array
|
387 |
+
alpha_array = None
|
388 |
+
|
389 |
+
rgb_array = rgb_array / 255.0 - 0.5
|
390 |
+
|
391 |
+
rotation_matrix = c2w
|
392 |
+
|
393 |
+
rotated_tensor = transform_normals_R(rgb_array, rotation_matrix)
|
394 |
+
|
395 |
+
rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
|
396 |
+
rotated_array = (rotated_array/2 + 0.5) * 255
|
397 |
+
|
398 |
+
if alpha_array is not None:
|
399 |
+
rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
|
400 |
+
|
401 |
+
rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
|
402 |
+
|
403 |
+
rotated_normal = Image.fromarray(rotated_array_uint8)
|
404 |
+
|
405 |
+
return rotated_normal
|
406 |
+
|
407 |
+
def normal_rotation_img2img_azi_ele(image, azi, ele, device='cuda'):
|
408 |
+
"""
|
409 |
+
Rotate an image by a given angle around a given axis using PyTorch.
|
410 |
+
|
411 |
+
Args:
|
412 |
+
image: Input Image to rotate.
|
413 |
+
|
414 |
+
Returns:
|
415 |
+
Image: Rotated Image.
|
416 |
+
"""
|
417 |
+
|
418 |
+
if type(image) == Image.Image:
|
419 |
+
image_array = torch.tensor(np.array(image, dtype='float32'))
|
420 |
+
else:
|
421 |
+
image_array = image
|
422 |
+
image_array = image_array.to(device)
|
423 |
+
|
424 |
+
if type(azi) != torch.Tensor:
|
425 |
+
azi = torch.tensor(azi)
|
426 |
+
azi = azi.to(device)
|
427 |
+
|
428 |
+
if type(ele) != torch.Tensor:
|
429 |
+
ele = torch.tensor(ele)
|
430 |
+
ele = ele.to(device)
|
431 |
+
|
432 |
+
if image_array.shape[2] == 4:
|
433 |
+
rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
|
434 |
+
else:
|
435 |
+
rgb_array = image_array
|
436 |
+
alpha_array = None
|
437 |
+
|
438 |
+
rgb_array = rgb_array / 255.0 - 0.5
|
439 |
+
|
440 |
+
rotation_matrix = get_rotation_matrix_azi_ele(azi, ele)
|
441 |
+
rotated_tensor = transform_normals_R(rgb_array, rotation_matrix)
|
442 |
+
|
443 |
+
rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
|
444 |
+
|
445 |
+
rotated_array = (rotated_array/2 + 0.5) * 255
|
446 |
+
|
447 |
+
if alpha_array is not None:
|
448 |
+
rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
|
449 |
+
|
450 |
+
rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
|
451 |
+
|
452 |
+
rotated_normal = Image.fromarray(rotated_array_uint8)
|
453 |
+
|
454 |
+
return rotated_normal
|
455 |
+
|
456 |
+
|
457 |
+
def rotate_normal_R(image, rotation_matrix, save_addr="", device="cuda"):
|
458 |
+
"""
|
459 |
+
Rotate a normal map by a given Rotation matrix using PyTorch.
|
460 |
+
|
461 |
+
Args:
|
462 |
+
image: Input Image to rotate.
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
Image: Rotated Image.
|
466 |
+
"""
|
467 |
+
|
468 |
+
if not isinstance(image, torch.Tensor):
|
469 |
+
image_array = torch.tensor(np.array(image, dtype='float32'))
|
470 |
+
else:
|
471 |
+
image_array = image
|
472 |
+
image_array = image_array.to(device)
|
473 |
+
|
474 |
+
if image_array.shape[2] == 4:
|
475 |
+
rgb_array, alpha_array = image_array[:, :, :3], image_array[:, :, 3]
|
476 |
+
else:
|
477 |
+
rgb_array = image_array
|
478 |
+
alpha_array = None
|
479 |
+
|
480 |
+
rgb_array = rgb_array / 255.0 - 0.5
|
481 |
+
|
482 |
+
rotated_tensor = transform_normals_R(rgb_array, rotation_matrix.to(device))
|
483 |
+
|
484 |
+
rotated_array = rotated_tensor.squeeze().permute(1, 2, 0)
|
485 |
+
|
486 |
+
rotated_array = (rotated_array/2 + 0.5) * 255
|
487 |
+
|
488 |
+
if alpha_array is not None:
|
489 |
+
rotated_array = torch.cat((rotated_array, alpha_array.unsqueeze(2)), dim=2)
|
490 |
+
|
491 |
+
rotated_array_uint8 = np.array(rotated_array.detach().cpu()).astype('uint8')
|
492 |
+
|
493 |
+
rotated_normal = Image.fromarray(rotated_array_uint8)
|
494 |
+
|
495 |
+
if save_addr:
|
496 |
+
rotated_normal.save(save_addr)
|
497 |
+
return rotated_normal
|
498 |
+
|
499 |
+
|
500 |
+
|
501 |
+
def transform_normals_R(local_normals, rotation_matrix):
|
502 |
+
assert local_normals.shape[2] ==3 ,f'local_normals.shape[2]: {local_normals.shape[2]}. only support rgb image'
|
503 |
+
|
504 |
+
h, w = local_normals.shape[:2]
|
505 |
+
local_normals_flat = local_normals.view(-1, 3).permute(1, 0)
|
506 |
+
|
507 |
+
images_flat = local_normals_flat.unsqueeze(0)
|
508 |
+
rotation_matrices = rotation_matrix.unsqueeze(0)
|
509 |
+
rotated_images_flat = torch.bmm(rotation_matrices, images_flat)
|
510 |
+
|
511 |
+
rotated_images = rotated_images_flat.view(1, 3, h, w)
|
512 |
+
|
513 |
+
norms = torch.norm(rotated_images, p=2, dim=1, keepdim=True)
|
514 |
+
norms = torch.where(norms == 0, torch.tensor(EPSILON).to(norms.device), norms)
|
515 |
+
normalized_images = rotated_images / norms
|
516 |
+
|
517 |
+
return normalized_images
|
518 |
+
|
519 |
+
|
520 |
+
def manage_elevation_azimuth(ele_list, azi_list):
|
521 |
+
"""deal with cases when elevation > 90"""
|
522 |
+
|
523 |
+
for i in range(len(ele_list)):
|
524 |
+
elevation = ele_list[i] % 360
|
525 |
+
azimuth = azi_list[i] % 360
|
526 |
+
if elevation > 90 and elevation<=270:
|
527 |
+
# when elevation is too big, the camera ends up on the other side
|
528 |
+
# print(f'!!! elevation({elevation}) > 90 and <=270, set to 180-elevation, and add 180 to azimuth')
|
529 |
+
elevation = 180 - elevation
|
530 |
+
azimuth = azimuth + 180
|
531 |
+
# print(f'new elevation: {elevation}, new azimuth: {azimuth}')
|
532 |
+
|
533 |
+
elif elevation>270:
|
534 |
+
# print(f'!!! elevation({elevation}) > 270, set to elevation-360, and use original azimuth')
|
535 |
+
elevation = elevation - 360
|
536 |
+
azimuth = azimuth
|
537 |
+
# print(f'new elevation: {elevation}, new azimuth: {azimuth}')
|
538 |
+
|
539 |
+
ele_list[i] = elevation
|
540 |
+
azi_list[i] = azimuth
|
541 |
+
|
542 |
+
return ele_list, azi_list
|
543 |
+
|
544 |
+
def get_rotation_matrix_azi_ele(azimuth, elevation):
|
545 |
+
|
546 |
+
ele = elevation/180 * torch.pi
|
547 |
+
azi = azimuth/180 * torch.pi
|
548 |
+
|
549 |
+
Rz = torch.tensor([
|
550 |
+
[torch.cos(azi), 0, -torch.sin(azi)],
|
551 |
+
[0, 1, 0],
|
552 |
+
[torch.sin(azi), 0, torch.cos(azi)],
|
553 |
+
]).to(azimuth.device)
|
554 |
+
|
555 |
+
Re = torch.tensor([
|
556 |
+
[1, 0, 0],
|
557 |
+
[0, torch.cos(ele), torch.sin(ele)],
|
558 |
+
[0, -torch.sin(ele), torch.cos(ele)],
|
559 |
+
]).to(elevation.device)
|
560 |
+
|
561 |
+
return torch.matmul(Rz,Re).to(azimuth.device)
|
562 |
+
|
563 |
+
|
564 |
+
def rotate_vector(vector, axis, angle, device='cuda'):
|
565 |
+
rot_matrix = rotation_matrix_axis_angle(axis, angle)
|
566 |
+
return torch.matmul(vector.to(device).float(), rot_matrix.to(device).float())
|
567 |
+
|
568 |
+
def apply_rotation_angle_axis(image, axis, angle, device='cuda'):
|
569 |
+
"""Apply rotation to a batch of images with shape [batch_size, 3(rgb), h, w] using PyTorch.
|
570 |
+
|
571 |
+
Args:
|
572 |
+
image (torch.Tensor): Input RGB image tensor of shape [batch_size, 3, h, w]. each pixel's rgb channels refer to direction of normal (can be negative)
|
573 |
+
axis (torch.Tensor): Rotation axis of shape [3].
|
574 |
+
angle (torch.Tensor): Rotation angles in degrees, of shape [batch_size].
|
575 |
+
Returns:
|
576 |
+
torch.Tensor: Rotated image tensor of shape [batch_size, 3, h, w]. values between [-1., 1.]
|
577 |
+
|
578 |
+
"""
|
579 |
+
|
580 |
+
if not isinstance(image, torch.Tensor):
|
581 |
+
image_tensor = torch.tensor(image).to(device)
|
582 |
+
else:
|
583 |
+
image_tensor = image.to(device)
|
584 |
+
|
585 |
+
if not isinstance(axis, torch.Tensor):
|
586 |
+
axis = torch.tensor(axis)
|
587 |
+
axis = axis.to(device)
|
588 |
+
|
589 |
+
if not isinstance(angle, torch.Tensor):
|
590 |
+
angle = torch.tensor(angle)
|
591 |
+
angle = angle.to(device)
|
592 |
+
|
593 |
+
batch_size, channels, h, w = image_tensor.shape
|
594 |
+
rot_matrix = rotation_matrix_axis_angle(axis, angle)
|
595 |
+
|
596 |
+
rotation_matrices = rot_matrix.permute(2, 0, 1)
|
597 |
+
|
598 |
+
batch_size, c, h, w = image_tensor.shape
|
599 |
+
images_flat = image_tensor.view(batch_size, c, h * w)
|
600 |
+
|
601 |
+
rotated_images_flat = torch.bmm(rotation_matrices, images_flat)
|
602 |
+
|
603 |
+
rotated_images = rotated_images_flat.view(batch_size, c, h, w)
|
604 |
+
|
605 |
+
norms = torch.norm(rotated_images, p=2, dim=1, keepdim=True)
|
606 |
+
|
607 |
+
norms = torch.where(norms == 0, torch.tensor(EPSILON).to(norms.device), norms)
|
608 |
+
|
609 |
+
normalized_images = rotated_images / norms
|
610 |
+
|
611 |
+
return normalized_images
|
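A small sketch of the normal-map rotation helpers above, assuming a CUDA device (the functions default to device='cuda') and a placeholder RGBA normal-map file.

import torch
from PIL import Image
from models.ISOMER.scripts.utils import normal_rotation_img2img_azi_ele

normal = Image.open("normal_view0.png")  # RGBA normal map, values in [0, 255]
# rotate the normals from this view into the frame of a camera at azimuth 90 deg, elevation 0 deg
rotated = normal_rotation_img2img_azi_ele(normal, torch.tensor(90.0), torch.tensor(0.0))
rotated.save("normal_view0_azi90.png")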
models/lrm/config/PRM_inference.yaml
ADDED
@@ -0,0 +1,22 @@
1 |
+
model_config:
|
2 |
+
target: models.lrm.models.lrm_mesh.PRM
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/final_ckpt.ckpt
|
21 |
+
texture_resolution: 2048
|
22 |
+
render_resolution: 512
|
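A sketch of how this inference config might be consumed; that OmegaConf is the loader used elsewhere in this repository is an assumption, and the checkpoint paths listed above must exist for a real run.

from omegaconf import OmegaConf

cfg = OmegaConf.load("models/lrm/config/PRM_inference.yaml")
print(cfg.model_config.target)          # models.lrm.models.lrm_mesh.PRM
print(dict(cfg.model_config.params))    # encoder / transformer / triplane hyper-parameters
print(cfg.infer_config.model_path)      # ckpts/final_ckpt.ckpt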
models/lrm/models/__init__.py
ADDED
File without changes
|
models/lrm/models/decoder/__init__.py
ADDED
File without changes
|
models/lrm/models/decoder/transformer.py
ADDED
@@ -0,0 +1,123 @@
1 |
+
# Copyright (c) 2023, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class BasicTransformerBlock(nn.Module):
|
21 |
+
"""
|
22 |
+
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
|
23 |
+
"""
|
24 |
+
# use attention from torch.nn.MultiHeadAttention
|
25 |
+
# Block contains a cross-attention layer, a self-attention layer, and a MLP
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
inner_dim: int,
|
29 |
+
cond_dim: int,
|
30 |
+
num_heads: int,
|
31 |
+
eps: float,
|
32 |
+
attn_drop: float = 0.,
|
33 |
+
attn_bias: bool = False,
|
34 |
+
mlp_ratio: float = 4.,
|
35 |
+
mlp_drop: float = 0.,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
|
40 |
+
self.cross_attn = nn.MultiheadAttention(
|
41 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
42 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
43 |
+
self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
|
44 |
+
self.self_attn = nn.MultiheadAttention(
|
45 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
46 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
47 |
+
self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
|
48 |
+
self.mlp = nn.Sequential(
|
49 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
50 |
+
nn.GELU(),
|
51 |
+
nn.Dropout(mlp_drop),
|
52 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
53 |
+
nn.Dropout(mlp_drop),
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, x, cond):
|
57 |
+
# x: [N, L, D]
|
58 |
+
# cond: [N, L_cond, D_cond]
|
59 |
+
x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
|
60 |
+
before_sa = self.norm2(x)
|
61 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
|
62 |
+
x = x + self.mlp(self.norm3(x))
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class TriplaneTransformer(nn.Module):
|
67 |
+
"""
|
68 |
+
Transformer with condition that generates a triplane representation.
|
69 |
+
|
70 |
+
Reference:
|
71 |
+
Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
|
72 |
+
"""
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
inner_dim: int,
|
76 |
+
image_feat_dim: int,
|
77 |
+
triplane_low_res: int,
|
78 |
+
triplane_high_res: int,
|
79 |
+
triplane_dim: int,
|
80 |
+
num_layers: int,
|
81 |
+
num_heads: int,
|
82 |
+
eps: float = 1e-6,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
# attributes
|
87 |
+
self.triplane_low_res = triplane_low_res
|
88 |
+
self.triplane_high_res = triplane_high_res
|
89 |
+
self.triplane_dim = triplane_dim
|
90 |
+
|
91 |
+
# modules
|
92 |
+
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
93 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
|
94 |
+
self.layers = nn.ModuleList([
|
95 |
+
BasicTransformerBlock(
|
96 |
+
inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
|
97 |
+
for _ in range(num_layers)
|
98 |
+
])
|
99 |
+
self.norm = nn.LayerNorm(inner_dim, eps=eps)
|
100 |
+
self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
|
101 |
+
|
102 |
+
def forward(self, image_feats):
|
103 |
+
# image_feats: [N, L_cond, D_cond]
|
104 |
+
|
105 |
+
N = image_feats.shape[0]
|
106 |
+
H = W = self.triplane_low_res
|
107 |
+
L = 3 * H * W
|
108 |
+
|
109 |
+
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
|
110 |
+
for layer in self.layers:
|
111 |
+
x = layer(x, image_feats)
|
112 |
+
x = self.norm(x)
|
113 |
+
|
114 |
+
# separate each plane and apply deconv
|
115 |
+
x = x.view(N, 3, H, W, -1)
|
116 |
+
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
117 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
118 |
+
x = self.deconv(x) # [3*N, D', H', W']
|
119 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
120 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
121 |
+
x = x.contiguous()
|
122 |
+
|
123 |
+
return x
|
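A quick shape check for the TriplaneTransformer above, using the triplane dimensions from PRM_inference.yaml but only two layers so it runs quickly on CPU; this is a sketch, not the way the repository wires the module into the full model.

import torch
from models.lrm.models.decoder.transformer import TriplaneTransformer

model = TriplaneTransformer(
    inner_dim=1024, image_feat_dim=768,
    triplane_low_res=32, triplane_high_res=64, triplane_dim=80,
    num_layers=2, num_heads=16,
)
image_feats = torch.randn(1, 197, 768)   # e.g. tokens from a DINO ViT-B/16 encoder
planes = model(image_feats)
print(planes.shape)                      # torch.Size([1, 3, 80, 64, 64]) -> one 64x64 feature map per plane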
models/lrm/models/encoder/__init__.py
ADDED
File without changes
|
models/lrm/models/encoder/dino.py
ADDED
@@ -0,0 +1,550 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch ViT model."""
|
16 |
+
|
17 |
+
|
18 |
+
import collections.abc
|
19 |
+
import math
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import (
|
27 |
+
BaseModelOutput,
|
28 |
+
BaseModelOutputWithPooling,
|
29 |
+
)
|
30 |
+
from transformers import PreTrainedModel, ViTConfig
|
31 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
32 |
+
|
33 |
+
|
34 |
+
class ViTEmbeddings(nn.Module):
|
35 |
+
"""
|
36 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
43 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
44 |
+
self.patch_embeddings = ViTPatchEmbeddings(config)
|
45 |
+
num_patches = self.patch_embeddings.num_patches
|
46 |
+
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
47 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
48 |
+
self.config = config
|
49 |
+
|
50 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
51 |
+
"""
|
52 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
53 |
+
resolution images.
|
54 |
+
|
55 |
+
Source:
|
56 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
57 |
+
"""
|
58 |
+
|
59 |
+
num_patches = embeddings.shape[1] - 1
|
60 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
61 |
+
if num_patches == num_positions and height == width:
|
62 |
+
return self.position_embeddings
|
63 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
64 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
65 |
+
dim = embeddings.shape[-1]
|
66 |
+
h0 = height // self.config.patch_size
|
67 |
+
w0 = width // self.config.patch_size
|
68 |
+
# we add a small number to avoid floating point error in the interpolation
|
69 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
70 |
+
h0, w0 = h0 + 0.1, w0 + 0.1
|
71 |
+
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
72 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
73 |
+
patch_pos_embed = nn.functional.interpolate(
|
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings


class ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings


class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class ViTSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class ViTIntermediate(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class ViTOutput(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class ViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = ViTAttention(config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
        )
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        adaln_input: torch.Tensor = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

        self_attention_outputs = self.attention(
            modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


class ViTEncoder(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        adaln_input: torch.Tensor = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    adaln_input,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)


class ViTModel(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        adaln_input: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            adaln_input=adaln_input,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ViTPooler(nn.Module):
    def __init__(self, config: ViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
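The adaLN path is the main change this file makes to the stock ViT: each ViTLayer predicts per-sample shift/scale pairs from `adaln_input` and applies them around both layernorms via `modulate`, with the modulation branch zero-initialised. A minimal standalone sketch of that conditioning (illustrative only; the sizes and tensors below are assumptions, not values fixed by this commit):

# Illustrative sketch, not part of the commit; assumes hidden_size=768, batch=2, 197 tokens.
import torch
import torch.nn as nn

hidden_size, batch, tokens = 768, 2, 197
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 4 * hidden_size))
nn.init.zeros_(adaLN[-1].weight)
nn.init.zeros_(adaLN[-1].bias)

x = torch.randn(batch, tokens, hidden_size)   # token features after layernorm
cond = torch.randn(batch, hidden_size)        # conditioning vector, e.g. a camera embedding
shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN(cond).chunk(4, dim=1)

modulated = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
# With zero-initialised adaLN weights the modulation starts as the identity.
assert torch.allclose(modulated, x)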
models/lrm/models/encoder/dino_wrapper.py
ADDED
@@ -0,0 +1,80 @@
# Copyright (c) 2023, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch.nn as nn
from transformers import ViTImageProcessor
from einops import rearrange, repeat
from .dino import ViTModel


class DinoWrapper(nn.Module):
    """
    Dino v1 wrapper using huggingface transformer implementation.
    """
    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        self.camera_embedder = nn.Sequential(
            nn.Linear(16, self.model.config.hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
        )
        if freeze:
            self._freeze()

    def forward(self, image, camera):
        # image: [B, N, C, H, W]
        # camera: [B, N, D]
        # RGB image with [0,1] scale and properly sized
        if image.ndim == 5:
            image = rearrange(image, 'b n c h w -> (b n) c h w')
        dtype = image.dtype
        inputs = self.processor(
            images=image.float(),
            return_tensors="pt",
            do_rescale=False,
            do_resize=False,
        ).to(self.model.device).to(dtype)
        # embed camera
        N = camera.shape[1]
        camera_embeddings = self.camera_embedder(camera)
        camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
        embeddings = camera_embeddings
        # This resampling of positional embedding uses bicubic interpolation
        outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states

    def _freeze(self):
        print(f"======== Freezing DinoWrapper ========")
        self.model.eval()
        for name, param in self.model.named_parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        import requests
        try:
            model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
            processor = ViTImageProcessor.from_pretrained(model_name)
            return model, processor
        except requests.exceptions.ProxyError as err:
            if proxy_error_retries > 0:
                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
                import time
                time.sleep(proxy_error_cooldown)
                return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
            else:
                raise err
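A rough usage sketch for the wrapper above (the checkpoint name, image resolution, and batch shapes are illustrative assumptions, not values fixed by this commit): images arrive as [B, N, C, H, W] in [0, 1], cameras as 16-dim flattened matrices per view, and the output is one token sequence per view.

# Illustrative sketch, not part of the commit; checkpoint name and shapes are assumptions.
import torch

wrapper = DinoWrapper('facebook/dino-vitb16', freeze=True)
images = torch.rand(1, 4, 3, 448, 448)   # B=1 object, N=4 views, values in [0, 1]
cameras = torch.rand(1, 4, 16)           # flattened 4x4 camera matrix per view
tokens = wrapper(images, cameras)        # -> [(B*N), seq_len, hidden_size]
print(tokens.shape)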
models/lrm/models/geometry/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
models/lrm/models/geometry/camera/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
from torch import nn


class Camera(nn.Module):
    def __init__(self):
        super(Camera, self).__init__()
        pass
models/lrm/models/geometry/camera/perspective_camera.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
from . import Camera
import numpy as np


def projection(x=0.1, n=1.0, f=50.0, near_plane=None):
    if near_plane is None:
        near_plane = n
    return np.array(
        [[n / x, 0, 0, 0],
         [0, n / -x, 0, 0],
         [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)],
         [0, 0, -1, 0]]).astype(np.float32)


class PerspectiveCamera(Camera):
    def __init__(self, fovy=49.0, device='cuda'):
        super(PerspectiveCamera, self).__init__()
        self.device = device
        focal = np.tan(fovy / 180.0 * np.pi * 0.5)
        self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0)

    def project(self, points_bxnx4):
        out = torch.matmul(
            points_bxnx4,
            torch.transpose(self.proj_mtx, 1, 2))
        return out
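As a quick sanity check of the projection above, homogeneous camera-space points of shape [B, N, 4] are mapped to clip space by a single batched matmul (illustrative only; assumes a CUDA device is available, since PerspectiveCamera defaults to device='cuda'):

# Illustrative sketch, not part of the commit; assumes a CUDA device is available.
import torch

cam = PerspectiveCamera(fovy=49.0, device='cuda')
points = torch.rand(1, 8, 3, device='cuda')                    # batch of 8 camera-space points
points_h = torch.nn.functional.pad(points, (0, 1), value=1.0)  # homogeneous coordinates [1, 8, 4]
clip = cam.project(points_h)                                   # clip-space output [1, 8, 4]
print(clip.shape)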
models/lrm/models/geometry/render/__init__.py
ADDED
@@ -0,0 +1,8 @@
import torch

class Renderer():
    def __init__(self):
        pass

    def forward(self):
        pass
models/lrm/models/geometry/render/neural_render.py
ADDED
@@ -0,0 +1,293 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import torch.nn.functional as F
import nvdiffrast.torch as dr
from . import Renderer
from . import util
from . import renderutils as ru
_FG_LUT = None


def interpolate(attr, rast, attr_idx, rast_db=None):
    return dr.interpolate(
        attr.contiguous(), rast, attr_idx, rast_db=rast_db,
        diff_attrs=None if rast_db is None else 'all')


def xfm_points(points, matrix, use_python=True):
    '''Transform points.
    Args:
        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]
        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]
        use_python: Use PyTorch's torch.matmul (for validation)
    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
    '''
    out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out


def dot(x, y):
    return torch.sum(x * y, -1, keepdim=True)


def compute_vertex_normal(v_pos, t_pos_idx):
    i0 = t_pos_idx[:, 0]
    i1 = t_pos_idx[:, 1]
    i2 = t_pos_idx[:, 2]

    v0 = v_pos[i0, :]
    v1 = v_pos[i1, :]
    v2 = v_pos[i2, :]

    face_normals = torch.cross(v1 - v0, v2 - v0)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(v_pos)
    v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
    v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
    v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    v_nrm = torch.where(
        dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
    )
    v_nrm = F.normalize(v_nrm, dim=1)
    assert torch.all(torch.isfinite(v_nrm))

    return v_nrm


class NeuralRender(Renderer):
    def __init__(self, device='cuda', camera_model=None):
        super(NeuralRender, self).__init__()
        self.device = device
        self.ctx = dr.RasterizeCudaContext(device=device)
        self.projection_mtx = None
        self.camera = camera_model

    # ==============================================================================================
    #  pixel shader
    # ==============================================================================================
    # def shade(
    #         self,
    #         gb_pos,
    #         gb_geometric_normal,
    #         gb_normal,
    #         gb_tangent,
    #         gb_texc,
    #         gb_texc_deriv,
    #         view_pos,
    #         ):

    #     ################################################################################
    #     # Texture lookups
    #     ################################################################################
    #     breakpoint()
    #     # Separate kd into alpha and color, default alpha = 1
    #     alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
    #     kd = kd[..., 0:3]

    #     ################################################################################
    #     # Normal perturbation & normal bend
    #     ################################################################################

    #     perturbed_nrm = None

    #     gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)

    #     ################################################################################
    #     # Evaluate BSDF
    #     ################################################################################

    #     assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
    #     bsdf = material['bsdf'] if bsdf is None else bsdf
    #     if bsdf == 'pbr':
    #         if isinstance(lgt, light.EnvironmentLight):
    #             shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=True)
    #         else:
    #             assert False, "Invalid light type"
    #     elif bsdf == 'diffuse':
    #         if isinstance(lgt, light.EnvironmentLight):
    #             shaded_col = lgt.shade(gb_pos, gb_normal, kd, ks, view_pos, specular=False)
    #         else:
    #             assert False, "Invalid light type"
    #     elif bsdf == 'normal':
    #         shaded_col = (gb_normal + 1.0)*0.5
    #     elif bsdf == 'tangent':
    #         shaded_col = (gb_tangent + 1.0)*0.5
    #     elif bsdf == 'kd':
    #         shaded_col = kd
    #     elif bsdf == 'ks':
    #         shaded_col = ks
    #     else:
    #         assert False, "Invalid BSDF '%s'" % bsdf

    #     # Return multiple buffers
    #     buffers = {
    #         'shaded' : torch.cat((shaded_col, alpha), dim=-1),
    #         'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
    #         'occlusion' : torch.cat((ks[..., :1], alpha), dim=-1)
    #     }
    #     return buffers

    # ==============================================================================================
    #  Render a depth slice of the mesh (scene), some limitations:
    #  - Single mesh
    #  - Single light
    #  - Single material
    # ==============================================================================================
    def render_layer(
            self,
            rast,
            rast_deriv,
            mesh,
            view_pos,
            resolution,
            spp,
            msaa
    ):

        # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
        rast_out_s = rast
        rast_out_deriv_s = rast_deriv

        ################################################################################
        # Interpolate attributes
        ################################################################################

        # Interpolate world space position
        gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())

        # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
        v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
        v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
        v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
        face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
        face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
        gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())

        # Compute tangent space
        assert mesh.v_nrm is not None and mesh.v_tng is not None
        gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
        gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int())  # Interpolate tangents

        # Texture coordinate
        # assert mesh.v_tex is not None
        # gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)
        perturbed_nrm = None
        gb_normal = ru.prepare_shading_normal(gb_pos, view_pos[:,None,None,:], perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)

        return gb_pos, gb_normal

    def render_mesh(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            mesh,
            camera_mv_bx4x4,
            camera_pos,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        assert not hierarchical_mask

        mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        # view_pos = torch.linalg.inv(mtx_in)[:, :3, 3]
        view_pos = camera_pos
        v_nrm = mesh.v_nrm #compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        num_layers = 1
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes

        mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)  # Concatenate the pos [org_pos, clip space pose for rasterization]

        layers = []
        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh.t_pos_idx.int(), [resolution * spp, resolution * spp]) as peeler:
            for _ in range(num_layers):
                rast, db = peeler.rasterize_next_layer()
                gb_pos, gb_normal = self.render_layer(rast, db, mesh, view_pos, resolution, spp, msaa=False)

        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
            for _ in range(num_layers):
                rast, db = peeler.rasterize_next_layer()
                gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip,
            mesh_t_pos_idx_fx3)

        depth = gb_feat[..., -2:-1]
        ori_mesh_feature = gb_feat[..., :-4]

        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
        # normal = F.normalize(normal, dim=-1)
        # normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background
        return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal, gb_normal

    def render_mesh_light(
            self,
            mesh_v_pos_bxnx3,
            mesh_t_pos_idx_fx3,
            mesh,
            camera_mv_bx4x4,
            mesh_v_feat_bxnxd,
            resolution=256,
            spp=1,
            device='cuda',
            hierarchical_mask=False
    ):
        assert not hierarchical_mask

        mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4
        v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in)  # Rotate it to camera coordinates
        v_pos_clip = self.camera.project(v_pos)  # Projection in the camera

        v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long())  # vertex normals in world coordinates

        # Render the image,
        # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render
        num_layers = 1
        mask_pyramid = None
        assert mesh_t_pos_idx_fx3.shape[0] > 0  # Make sure we have shapes
        mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1)  # Concatenate the pos

        with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler:
            for _ in range(num_layers):
                rast, db = peeler.rasterize_next_layer()
                gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3)

        hard_mask = torch.clamp(rast[..., -1:], 0, 1)
        antialias_mask = dr.antialias(
            hard_mask.clone().contiguous(), rast, v_pos_clip,
            mesh_t_pos_idx_fx3)

        depth = gb_feat[..., -2:-1]
        ori_mesh_feature = gb_feat[..., :-4]

        normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3)
        normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3)
        normal = F.normalize(normal, dim=-1)
        normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float())  # black background

        return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal
models/lrm/models/geometry/render/renderutils/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]
models/lrm/models/geometry/render/renderutils/bsdf.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import math
import torch

NORMAL_THRESHOLD = 0.1

################################################################################
# Vector utility functions
################################################################################

def _dot(x, y):
    return torch.sum(x*y, -1, keepdim=True)

def _reflect(x, n):
    return 2*_dot(x, n)*n - x

def _safe_normalize(x):
    return torch.nn.functional.normalize(x, dim = -1)

def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
    # Swap normal direction for backfacing surfaces
    if two_sided_shading:
        smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)
        geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)

    t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)
    return torch.lerp(geom_nrm, smooth_nrm, t)


def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
    smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))
    if opengl:
        shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
    else:
        shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)
    return _safe_normalize(shading_nrm)

def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):
    smooth_nrm = _safe_normalize(smooth_nrm)
    smooth_tng = _safe_normalize(smooth_tng)
    view_vec = _safe_normalize(view_pos - pos)
    shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)
    return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)

################################################################################
# Simple lambertian diffuse BSDF
################################################################################

def bsdf_lambert(nrm, wi):
    return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi

################################################################################
# Frostbite diffuse
################################################################################

def bsdf_frostbite(nrm, wi, wo, linearRoughness):
    wiDotN = _dot(wi, nrm)
    woDotN = _dot(wo, nrm)

    h = _safe_normalize(wo + wi)
    wiDotH = _dot(wi, h)

    energyBias = 0.5 * linearRoughness
    energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness
    f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness
    f0 = 1.0

    wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)
    woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)
    res = wiScatter * woScatter * energyFactor
    return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))

################################################################################
# Phong specular, loosely based on mitsuba implementation
################################################################################

def bsdf_phong(nrm, wo, wi, N):
    dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
    dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
    return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)

################################################################################
# PBR's implementation of GGX specular
################################################################################

specular_epsilon = 1e-4

def bsdf_fresnel_shlick(f0, f90, cosTheta):
    _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0

def bsdf_ndf_ggx(alphaSqr, cosTheta):
    _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1
    return alphaSqr / (d * d * math.pi)

def bsdf_lambda_ggx(alphaSqr, cosTheta):
    _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    cosThetaSqr = _cosTheta * _cosTheta
    tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr
    res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)
    return res

def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
    lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)
    lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)
    return 1 / (1 + lambdaI + lambdaO)

def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
    _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
    alphaSqr = _alpha * _alpha

    h = _safe_normalize(wo + wi)
    woDotN = _dot(wo, nrm)
    wiDotN = _dot(wi, nrm)
    woDotH = _dot(wo, h)
    nDotH = _dot(nrm, h)

    D = bsdf_ndf_ggx(alphaSqr, nDotH)
    G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)
    F = bsdf_fresnel_shlick(col, 1, woDotH)

    w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)

    frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)
    return torch.where(frontfacing, w, torch.zeros_like(w))

def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
    wo = _safe_normalize(view_pos - pos)
    wi = _safe_normalize(light_pos - pos)

    spec_str = arm[..., 0:1]  # x component
    roughness = arm[..., 1:2]  # y component
    metallic = arm[..., 2:3]  # z component
    ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
    kd = kd * (1.0 - metallic)

    if BSDF == 0:
        diffuse = kd * bsdf_lambert(nrm, wi)
    else:
        diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
    specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
    return diffuse + specular
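A small sketch of how the PyTorch reference BSDF above can be evaluated on random buffers (shapes and offsets are illustrative assumptions; the CUDA kernels under c_src implement the same math for the fused path):

# Illustrative sketch, not part of the commit; shapes and light/view offsets are assumptions.
import torch

B, H, W = 1, 4, 4
kd = torch.rand(B, H, W, 3)                                            # albedo
arm = torch.rand(B, H, W, 3)                                           # (occlusion/spec_str, roughness, metallic)
pos = torch.rand(B, H, W, 3)
nrm = torch.nn.functional.normalize(torch.rand(B, H, W, 3), dim=-1)
view_pos = pos + torch.tensor([0.0, 0.0, 2.0])
light_pos = pos + torch.tensor([0.0, 2.0, 0.0])

rgb = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, BSDF=0)
print(rgb.shape)   # [1, 4, 4, 3]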
models/lrm/models/geometry/render/renderutils/c_src/bsdf.cu
ADDED
@@ -0,0 +1,710 @@
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include "common.h"
#include "bsdf.h"

#define SPECULAR_EPSILON 1e-4f

//------------------------------------------------------------------------
// Lambert functions

__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)
{
    return max(dot(nrm, wi) / M_PI, 0.0f);
}

__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)
{
    if (dot(nrm, wi) > 0.0f)
        bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);
}

//------------------------------------------------------------------------
// Fresnel Schlick

__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = powf(1.0f - _cosTheta, 5.0f);
    return f0 * (1.0f - scale) + f90 * scale;
}

__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
    d_f0 += d_out * (1.0 - scale);
    d_f90 += d_out * scale;
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);
    }
}

__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = powf(1.0f - _cosTheta, 5.0f);
    return f0 * (1.0f - scale) + f90 * scale;
}

__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);
    d_f0 += d_out * (1.0 - scale);
    d_f90 += d_out * scale;
    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));
    }
}

//------------------------------------------------------------------------
// Frostbite diffuse

__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)
{
    float wiDotN = dot(wi, nrm);
    float woDotN = dot(wo, nrm);
    if (wiDotN > 0.0f && woDotN > 0.0f)
    {
        vec3f h = safeNormalize(wo + wi);
        float wiDotH = dot(wi, h);

        float energyBias = 0.5f * linearRoughness;
        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float f0 = 1.f;

        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
        float woScatter = fwdFresnelSchlick(f0, f90, woDotN);

        return wiScatter * woScatter * energyFactor;
    }
    else return 0.0f;
}

__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)
{
    float wiDotN = dot(wi, nrm);
    float woDotN = dot(wo, nrm);

    if (wiDotN > 0.0f && woDotN > 0.0f)
    {
        vec3f h = safeNormalize(wo + wi);
        float wiDotH = dot(wi, h);

        float energyBias = 0.5f * linearRoughness;
        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float f0 = 1.f;

        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);
        float woScatter = fwdFresnelSchlick(f0, f90, woDotN);

        // -------------- BWD --------------
        // Backprop: return wiScatter * woScatter * energyFactor;
        float d_wiScatter = d_out * woScatter * energyFactor;
        float d_woScatter = d_out * wiScatter * energyFactor;
        float d_energyFactor = d_out * wiScatter * woScatter;

        // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);
        float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;
        bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);

        // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);
        float d_wiDotN = 0.0f;
        bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);

        // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;
        float d_energyBias = d_f90;
        float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;
        d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;

        // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;
        d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;

        // Backprop: float energyBias = 0.5f * linearRoughness;
        d_linearRoughness += 0.5 * d_energyBias;

        // Backprop: float wiDotH = dot(wi, h);
        vec3f d_h(0);
        bwdDot(wi, h, d_wi, d_h, d_wiDotH);

        // Backprop: vec3f h = safeNormalize(wo + wi);
        vec3f d_wo_wi(0);
        bwdSafeNormalize(wo + wi, d_wo_wi, d_h);
        d_wi += d_wo_wi; d_wo += d_wo_wi;

        bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
        bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
    }
}

//------------------------------------------------------------------------
// Ndf GGX

__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
    return alphaSqr / (d * d * M_PI);
}

__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
{
    // Torch only back propagates if clamp doesn't trigger
    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
    float cosThetaSqr = _cosTheta * _cosTheta;
    d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
    {
        d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));
    }
}

//------------------------------------------------------------------------
|
177 |
+
// Lambda GGX
|
178 |
+
|
179 |
+
__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)
|
180 |
+
{
|
181 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
182 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
183 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
184 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
185 |
+
return res;
|
186 |
+
}
|
187 |
+
|
188 |
+
__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)
|
189 |
+
{
|
190 |
+
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);
|
191 |
+
float cosThetaSqr = _cosTheta * _cosTheta;
|
192 |
+
float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;
|
193 |
+
float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);
|
194 |
+
|
195 |
+
d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);
|
196 |
+
if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)
|
197 |
+
d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));
|
198 |
+
}
|
199 |
+
|
200 |
+
//------------------------------------------------------------------------
|
201 |
+
// Masking GGX
|
202 |
+
|
203 |
+
__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)
|
204 |
+
{
|
205 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
206 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
207 |
+
return 1.0f / (1.0f + lambdaI + lambdaO);
|
208 |
+
}
|
209 |
+
|
210 |
+
__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)
|
211 |
+
{
|
212 |
+
// FWD eval
|
213 |
+
float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);
|
214 |
+
float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);
|
215 |
+
|
216 |
+
// BWD eval
|
217 |
+
float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);
|
218 |
+
bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);
|
219 |
+
bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
|
220 |
+
}
|
221 |
+
|
222 |
+
//------------------------------------------------------------------------
|
223 |
+
// GGX specular
|
224 |
+
|
225 |
+
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
|
226 |
+
{
|
227 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
228 |
+
float alphaSqr = _alpha * _alpha;
|
229 |
+
|
230 |
+
vec3f h = safeNormalize(wo + wi);
|
231 |
+
float woDotN = dot(wo, nrm);
|
232 |
+
float wiDotN = dot(wi, nrm);
|
233 |
+
float woDotH = dot(wo, h);
|
234 |
+
float nDotH = dot(nrm, h);
|
235 |
+
|
236 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
237 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
238 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
239 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
240 |
+
|
241 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
242 |
+
return frontfacing ? w : 0.0f;
|
243 |
+
}
|
244 |
+
|
245 |
+
__device__ void bwdPbrSpecular(
|
246 |
+
const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
|
247 |
+
vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
|
248 |
+
{
|
249 |
+
///////////////////////////////////////////////////////////////////////
|
250 |
+
// FWD eval
|
251 |
+
|
252 |
+
float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
|
253 |
+
float alphaSqr = _alpha * _alpha;
|
254 |
+
|
255 |
+
vec3f h = safeNormalize(wo + wi);
|
256 |
+
float woDotN = dot(wo, nrm);
|
257 |
+
float wiDotN = dot(wi, nrm);
|
258 |
+
float woDotH = dot(wo, h);
|
259 |
+
float nDotH = dot(nrm, h);
|
260 |
+
|
261 |
+
float D = fwdNdfGGX(alphaSqr, nDotH);
|
262 |
+
float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
|
263 |
+
vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
|
264 |
+
vec3f w = F * D * G * 0.25 / woDotN;
|
265 |
+
bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
|
266 |
+
|
267 |
+
if (frontfacing)
|
268 |
+
{
|
269 |
+
///////////////////////////////////////////////////////////////////////
|
270 |
+
// BWD eval
|
271 |
+
|
272 |
+
vec3f d_F = d_out * D * G * 0.25f / woDotN;
|
273 |
+
float d_D = sum(d_out * F * G * 0.25f / woDotN);
|
274 |
+
float d_G = sum(d_out * F * D * 0.25f / woDotN);
|
275 |
+
|
276 |
+
float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));
|
277 |
+
|
278 |
+
vec3f d_f90(0);
|
279 |
+
float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
|
280 |
+
bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
|
281 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
|
282 |
+
bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);
|
283 |
+
|
284 |
+
vec3f d_h(0);
|
285 |
+
bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
|
286 |
+
bwdDot(wo, h, d_wo, d_h, d_woDotH);
|
287 |
+
bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
|
288 |
+
bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);
|
289 |
+
|
290 |
+
vec3f d_h_unnorm(0);
|
291 |
+
bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
|
292 |
+
d_wo += d_h_unnorm;
|
293 |
+
d_wi += d_h_unnorm;
|
294 |
+
|
295 |
+
if (alpha > min_roughness * min_roughness)
|
296 |
+
d_alpha += d_alphaSqr * 2 * alpha;
|
297 |
+
}
|
298 |
+
}
|
299 |
+
|
300 |
+
//------------------------------------------------------------------------
|
301 |
+
// Full PBR BSDF
|
302 |
+
|
303 |
+
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
|
304 |
+
{
|
305 |
+
vec3f wo = safeNormalize(view_pos - pos);
|
306 |
+
vec3f wi = safeNormalize(light_pos - pos);
|
307 |
+
|
308 |
+
float alpha = arm.y * arm.y;
|
309 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
310 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
311 |
+
|
312 |
+
float diff = 0.0f;
|
313 |
+
if (BSDF == 0)
|
314 |
+
diff = fwdLambert(nrm, wi);
|
315 |
+
else
|
316 |
+
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
317 |
+
vec3f diffuse = diff_col * diff;
|
318 |
+
vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);
|
319 |
+
|
320 |
+
return diffuse + specular;
|
321 |
+
}
|
322 |
+
|
323 |
+
__device__ void bwdPbrBSDF(
|
324 |
+
const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
|
325 |
+
vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
|
326 |
+
{
|
327 |
+
////////////////////////////////////////////////////////////////////////
|
328 |
+
// FWD
|
329 |
+
vec3f _wi = light_pos - pos;
|
330 |
+
vec3f _wo = view_pos - pos;
|
331 |
+
vec3f wi = safeNormalize(_wi);
|
332 |
+
vec3f wo = safeNormalize(_wo);
|
333 |
+
|
334 |
+
float alpha = arm.y * arm.y;
|
335 |
+
vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
|
336 |
+
vec3f diff_col = kd * (1.0f - arm.z);
|
337 |
+
float diff = 0.0f;
|
338 |
+
if (BSDF == 0)
|
339 |
+
diff = fwdLambert(nrm, wi);
|
340 |
+
else
|
341 |
+
diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
|
342 |
+
|
343 |
+
////////////////////////////////////////////////////////////////////////
|
344 |
+
// BWD
|
345 |
+
|
346 |
+
float d_alpha(0);
|
347 |
+
vec3f d_spec_col(0), d_wi(0), d_wo(0);
|
348 |
+
bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
349 |
+
|
350 |
+
float d_diff = sum(diff_col * d_out);
|
351 |
+
if (BSDF == 0)
|
352 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
|
353 |
+
else
|
354 |
+
bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);
|
355 |
+
|
356 |
+
// Backprop: diff_col = kd * (1.0f - arm.z)
|
357 |
+
vec3f d_diff_col = d_out * diff;
|
358 |
+
d_kd += d_diff_col * (1.0f - arm.z);
|
359 |
+
d_arm.z -= sum(d_diff_col * kd);
|
360 |
+
|
361 |
+
// Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
|
362 |
+
d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
|
363 |
+
d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
|
364 |
+
d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));
|
365 |
+
|
366 |
+
// Backprop: alpha = arm.y * arm.y
|
367 |
+
d_arm.y += d_alpha * 2 * arm.y;
|
368 |
+
|
369 |
+
// Backprop: vec3f wi = safeNormalize(light_pos - pos);
|
370 |
+
vec3f d__wi(0);
|
371 |
+
bwdSafeNormalize(_wi, d__wi, d_wi);
|
372 |
+
d_light_pos += d__wi;
|
373 |
+
d_pos -= d__wi;
|
374 |
+
|
375 |
+
// Backprop: vec3f wo = safeNormalize(view_pos - pos);
|
376 |
+
vec3f d__wo(0);
|
377 |
+
bwdSafeNormalize(_wo, d__wo, d_wo);
|
378 |
+
d_view_pos += d__wo;
|
379 |
+
d_pos -= d__wo;
|
380 |
+
}
|
381 |
+
|
382 |
+
//------------------------------------------------------------------------
|
383 |
+
// Kernels
|
384 |
+
|
385 |
+
__global__ void LambertFwdKernel(LambertKernelParams p)
|
386 |
+
{
|
387 |
+
// Calculate pixel position.
|
388 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
389 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
390 |
+
unsigned int pz = blockIdx.z;
|
391 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
392 |
+
return;
|
393 |
+
|
394 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
395 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
396 |
+
|
397 |
+
float res = fwdLambert(nrm, wi);
|
398 |
+
|
399 |
+
p.out.store(px, py, pz, res);
|
400 |
+
}
|
401 |
+
|
402 |
+
__global__ void LambertBwdKernel(LambertKernelParams p)
|
403 |
+
{
|
404 |
+
// Calculate pixel position.
|
405 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
406 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
407 |
+
unsigned int pz = blockIdx.z;
|
408 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
409 |
+
return;
|
410 |
+
|
411 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
412 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
413 |
+
float d_out = p.out.fetch1(px, py, pz);
|
414 |
+
|
415 |
+
vec3f d_nrm(0), d_wi(0);
|
416 |
+
bwdLambert(nrm, wi, d_nrm, d_wi, d_out);
|
417 |
+
|
418 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
419 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
420 |
+
}
|
421 |
+
|
422 |
+
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
|
423 |
+
{
|
424 |
+
// Calculate pixel position.
|
425 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
426 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
427 |
+
unsigned int pz = blockIdx.z;
|
428 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
429 |
+
return;
|
430 |
+
|
431 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
432 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
433 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
434 |
+
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
435 |
+
|
436 |
+
float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
|
437 |
+
|
438 |
+
p.out.store(px, py, pz, res);
|
439 |
+
}
|
440 |
+
|
441 |
+
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
|
442 |
+
{
|
443 |
+
// Calculate pixel position.
|
444 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
445 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
446 |
+
unsigned int pz = blockIdx.z;
|
447 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
448 |
+
return;
|
449 |
+
|
450 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
451 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
452 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
453 |
+
float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
|
454 |
+
float d_out = p.out.fetch1(px, py, pz);
|
455 |
+
|
456 |
+
float d_linearRoughness = 0.0f;
|
457 |
+
vec3f d_nrm(0), d_wi(0), d_wo(0);
|
458 |
+
bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);
|
459 |
+
|
460 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
461 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
462 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
463 |
+
p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
|
464 |
+
}
|
465 |
+
|
466 |
+
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
|
467 |
+
{
|
468 |
+
// Calculate pixel position.
|
469 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
470 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
471 |
+
unsigned int pz = blockIdx.z;
|
472 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
473 |
+
return;
|
474 |
+
|
475 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
476 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
477 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
478 |
+
|
479 |
+
vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
|
480 |
+
p.out.store(px, py, pz, res);
|
481 |
+
}
|
482 |
+
|
483 |
+
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
|
484 |
+
{
|
485 |
+
// Calculate pixel position.
|
486 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
487 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
488 |
+
unsigned int pz = blockIdx.z;
|
489 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
490 |
+
return;
|
491 |
+
|
492 |
+
vec3f f0 = p.f0.fetch3(px, py, pz);
|
493 |
+
vec3f f90 = p.f90.fetch3(px, py, pz);
|
494 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
495 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
496 |
+
|
497 |
+
vec3f d_f0(0), d_f90(0);
|
498 |
+
float d_cosTheta(0);
|
499 |
+
bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);
|
500 |
+
|
501 |
+
p.f0.store_grad(px, py, pz, d_f0);
|
502 |
+
p.f90.store_grad(px, py, pz, d_f90);
|
503 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
504 |
+
}
|
505 |
+
|
506 |
+
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
|
507 |
+
{
|
508 |
+
// Calculate pixel position.
|
509 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
510 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
511 |
+
unsigned int pz = blockIdx.z;
|
512 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
513 |
+
return;
|
514 |
+
|
515 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
516 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
517 |
+
float res = fwdNdfGGX(alphaSqr, cosTheta);
|
518 |
+
|
519 |
+
p.out.store(px, py, pz, res);
|
520 |
+
}
|
521 |
+
|
522 |
+
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
|
523 |
+
{
|
524 |
+
// Calculate pixel position.
|
525 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
526 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
527 |
+
unsigned int pz = blockIdx.z;
|
528 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
529 |
+
return;
|
530 |
+
|
531 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
532 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
533 |
+
float d_out = p.out.fetch1(px, py, pz);
|
534 |
+
|
535 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
536 |
+
bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
537 |
+
|
538 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
539 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
540 |
+
}
|
541 |
+
|
542 |
+
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
|
543 |
+
{
|
544 |
+
// Calculate pixel position.
|
545 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
546 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
547 |
+
unsigned int pz = blockIdx.z;
|
548 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
549 |
+
return;
|
550 |
+
|
551 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
552 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
553 |
+
float res = fwdLambdaGGX(alphaSqr, cosTheta);
|
554 |
+
|
555 |
+
p.out.store(px, py, pz, res);
|
556 |
+
}
|
557 |
+
|
558 |
+
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
|
559 |
+
{
|
560 |
+
// Calculate pixel position.
|
561 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
562 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
563 |
+
unsigned int pz = blockIdx.z;
|
564 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
565 |
+
return;
|
566 |
+
|
567 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
568 |
+
float cosTheta = p.cosTheta.fetch1(px, py, pz);
|
569 |
+
float d_out = p.out.fetch1(px, py, pz);
|
570 |
+
|
571 |
+
float d_alphaSqr(0), d_cosTheta(0);
|
572 |
+
bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);
|
573 |
+
|
574 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
575 |
+
p.cosTheta.store_grad(px, py, pz, d_cosTheta);
|
576 |
+
}
|
577 |
+
|
578 |
+
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
|
579 |
+
{
|
580 |
+
// Calculate pixel position.
|
581 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
582 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
583 |
+
unsigned int pz = blockIdx.z;
|
584 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
585 |
+
return;
|
586 |
+
|
587 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
588 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
589 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
590 |
+
float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);
|
591 |
+
|
592 |
+
p.out.store(px, py, pz, res);
|
593 |
+
}
|
594 |
+
|
595 |
+
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
|
596 |
+
{
|
597 |
+
// Calculate pixel position.
|
598 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
599 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
600 |
+
unsigned int pz = blockIdx.z;
|
601 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
602 |
+
return;
|
603 |
+
|
604 |
+
float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
|
605 |
+
float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
|
606 |
+
float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
|
607 |
+
float d_out = p.out.fetch1(px, py, pz);
|
608 |
+
|
609 |
+
float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
|
610 |
+
bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);
|
611 |
+
|
612 |
+
p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
|
613 |
+
p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
|
614 |
+
p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
|
615 |
+
}
|
616 |
+
|
617 |
+
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
|
618 |
+
{
|
619 |
+
// Calculate pixel position.
|
620 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
621 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
622 |
+
unsigned int pz = blockIdx.z;
|
623 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
624 |
+
return;
|
625 |
+
|
626 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
627 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
628 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
629 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
630 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
631 |
+
|
632 |
+
vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
|
633 |
+
|
634 |
+
p.out.store(px, py, pz, res);
|
635 |
+
}
|
636 |
+
|
637 |
+
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
|
638 |
+
{
|
639 |
+
// Calculate pixel position.
|
640 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
641 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
642 |
+
unsigned int pz = blockIdx.z;
|
643 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
644 |
+
return;
|
645 |
+
|
646 |
+
vec3f col = p.col.fetch3(px, py, pz);
|
647 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
648 |
+
vec3f wo = p.wo.fetch3(px, py, pz);
|
649 |
+
vec3f wi = p.wi.fetch3(px, py, pz);
|
650 |
+
float alpha = p.alpha.fetch1(px, py, pz);
|
651 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
652 |
+
|
653 |
+
float d_alpha(0);
|
654 |
+
vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
|
655 |
+
bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);
|
656 |
+
|
657 |
+
p.col.store_grad(px, py, pz, d_col);
|
658 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
659 |
+
p.wo.store_grad(px, py, pz, d_wo);
|
660 |
+
p.wi.store_grad(px, py, pz, d_wi);
|
661 |
+
p.alpha.store_grad(px, py, pz, d_alpha);
|
662 |
+
}
|
663 |
+
|
664 |
+
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
|
665 |
+
{
|
666 |
+
// Calculate pixel position.
|
667 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
668 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
669 |
+
unsigned int pz = blockIdx.z;
|
670 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
671 |
+
return;
|
672 |
+
|
673 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
674 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
675 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
676 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
677 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
678 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
679 |
+
|
680 |
+
vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);
|
681 |
+
|
682 |
+
p.out.store(px, py, pz, res);
|
683 |
+
}
|
684 |
+
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
|
685 |
+
{
|
686 |
+
// Calculate pixel position.
|
687 |
+
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
|
688 |
+
unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
|
689 |
+
unsigned int pz = blockIdx.z;
|
690 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
691 |
+
return;
|
692 |
+
|
693 |
+
vec3f kd = p.kd.fetch3(px, py, pz);
|
694 |
+
vec3f arm = p.arm.fetch3(px, py, pz);
|
695 |
+
vec3f pos = p.pos.fetch3(px, py, pz);
|
696 |
+
vec3f nrm = p.nrm.fetch3(px, py, pz);
|
697 |
+
vec3f view_pos = p.view_pos.fetch3(px, py, pz);
|
698 |
+
vec3f light_pos = p.light_pos.fetch3(px, py, pz);
|
699 |
+
vec3f d_out = p.out.fetch3(px, py, pz);
|
700 |
+
|
701 |
+
vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
|
702 |
+
bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);
|
703 |
+
|
704 |
+
p.kd.store_grad(px, py, pz, d_kd);
|
705 |
+
p.arm.store_grad(px, py, pz, d_arm);
|
706 |
+
p.pos.store_grad(px, py, pz, d_pos);
|
707 |
+
p.nrm.store_grad(px, py, pz, d_nrm);
|
708 |
+
p.view_pos.store_grad(px, py, pz, d_view_pos);
|
709 |
+
p.light_pos.store_grad(px, py, pz, d_light_pos);
|
710 |
+
}
|
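Note on the Fresnel Schlick helpers above: the forward pass evaluates f0 * (1 - s) + f90 * s with s = (1 - cosTheta)^5 after clamping cosTheta away from 0 and 1. Below is a minimal host-side sketch of the same scalar math; the epsilon value is a stand-in for the SPECULAR_EPSILON constant defined earlier in bsdf.cu, and the main() driver is only a quick sanity check (cosTheta near 1 returns roughly f0, cosTheta near 0 returns roughly f90).

// Host-side sanity check of the scalar Schlick approximation used by fwdFresnelSchlick.
// EPS is an assumed placeholder for the SPECULAR_EPSILON constant defined earlier in bsdf.cu.
#include <algorithm>
#include <cmath>
#include <cstdio>

static float schlick(float f0, float f90, float cosTheta)
{
    const float EPS = 1e-4f;                                  // assumed clamp epsilon
    float c = std::min(std::max(cosTheta, EPS), 1.0f - EPS);  // same clamp as the kernel
    float s = std::pow(1.0f - c, 5.0f);
    return f0 * (1.0f - s) + f90 * s;
}

int main()
{
    // Normal incidence approaches f0, grazing angles approach f90.
    std::printf("cos=1: %f (~= f0)\n", schlick(0.04f, 1.0f, 1.0f));
    std::printf("cos=0: %f (~= f90)\n", schlick(0.04f, 1.0f, 0.0f));
    return 0;
}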
models/lrm/models/geometry/render/renderutils/c_src/bsdf.h
ADDED
@@ -0,0 +1,84 @@
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include "common.h"

struct LambertKernelParams
{
    Tensor nrm;
    Tensor wi;
    Tensor out;
    dim3 gridSize;
};

struct FrostbiteDiffuseKernelParams
{
    Tensor nrm;
    Tensor wi;
    Tensor wo;
    Tensor linearRoughness;
    Tensor out;
    dim3 gridSize;
};

struct FresnelShlickKernelParams
{
    Tensor f0;
    Tensor f90;
    Tensor cosTheta;
    Tensor out;
    dim3 gridSize;
};

struct NdfGGXParams
{
    Tensor alphaSqr;
    Tensor cosTheta;
    Tensor out;
    dim3 gridSize;
};

struct MaskingSmithParams
{
    Tensor alphaSqr;
    Tensor cosThetaI;
    Tensor cosThetaO;
    Tensor out;
    dim3 gridSize;
};

struct PbrSpecular
{
    Tensor col;
    Tensor nrm;
    Tensor wo;
    Tensor wi;
    Tensor alpha;
    Tensor out;
    dim3 gridSize;
    float min_roughness;
};

struct PbrBSDF
{
    Tensor kd;
    Tensor arm;
    Tensor pos;
    Tensor nrm;
    Tensor view_pos;
    Tensor light_pos;
    Tensor out;
    dim3 gridSize;
    float min_roughness;
    int BSDF;
};
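The PbrBSDF parameter block mirrors the inputs of fwdPbrBSDF/bwdPbrBSDF in bsdf.cu: kd is the diffuse albedo and arm packs the three scalar material channels. The standalone C++ sketch below restates how those channels feed the two lobes, copied from the forward math above; reading arm.x/arm.y/arm.z as occlusion/roughness/metalness is an inference from that math, not a name taken from this commit.

// How fwdPbrBSDF derives its lobe inputs from kd and arm (the occlusion/roughness/metalness
// naming is inferred from the math, not from the repository).
struct vec3 { float x, y, z; };

struct LobeInputs { float alpha; vec3 spec_col; vec3 diff_col; };

static LobeInputs unpack(vec3 kd, vec3 arm)
{
    LobeInputs r;
    r.alpha = arm.y * arm.y;                            // roughness squared, clamped later by min_roughness
    float ks = 0.04f * (1.0f - arm.z);                  // dielectric base reflectance
    float occ = 1.0f - arm.x;                           // arm.x attenuates the specular color
    r.spec_col = { (ks + kd.x * arm.z) * occ,
                   (ks + kd.y * arm.z) * occ,
                   (ks + kd.z * arm.z) * occ };         // metalness lerp between 0.04 and kd
    r.diff_col = { kd.x * (1.0f - arm.z),
                   kd.y * (1.0f - arm.z),
                   kd.z * (1.0f - arm.z) };             // metals get no diffuse term
    return r;
}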
models/lrm/models/geometry/render/renderutils/c_src/common.cpp
ADDED
@@ -0,0 +1,74 @@
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <cuda_runtime.h>
#include <algorithm>

//------------------------------------------------------------------------
// Block and grid size calculators for kernel launches.

dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
{
    int maxThreads = maxWidth * maxHeight;
    if (maxThreads <= 1 || (dims.x * dims.y) <= 1)
        return dim3(1, 1, 1); // Degenerate.

    // Start from max size.
    int bw = maxWidth;
    int bh = maxHeight;

    // Optimizations for weirdly sized buffers.
    if (dims.x < bw)
    {
        // Decrease block width to smallest power of two that covers the buffer width.
        while ((bw >> 1) >= dims.x)
            bw >>= 1;

        // Maximize height.
        bh = maxThreads / bw;
        if (bh > dims.y)
            bh = dims.y;
    }
    else if (dims.y < bh)
    {
        // Halve height and double width until fits completely inside buffer vertically.
        while (bh > dims.y)
        {
            bh >>= 1;
            if (bw < dims.x)
                bw <<= 1;
        }
    }

    // Done.
    return dim3(bw, bh, 1);
}

// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)
dim3 getWarpSize(dim3 blockSize)
{
    return dim3(
        std::min(blockSize.x, 32u),
        std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)),
        std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))
    );
}

dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
{
    dim3 gridSize;
    gridSize.x = (dims.x - 1) / blockSize.x + 1;
    gridSize.y = (dims.y - 1) / blockSize.y + 1;
    gridSize.z = (dims.z - 1) / blockSize.z + 1;
    return gridSize;
}

//------------------------------------------------------------------------
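getLaunchBlockSize and getLaunchGridSize are the host-side sizing helpers for the per-pixel kernels in bsdf.cu, cubemap.cu and loss.cu. The sketch below shows how a launch would look, assuming the LambertKernelParams fields (including gridSize) have already been filled by the binding code and that an 8x8 thread budget is used; neither assumption is taken from this file.

// Hypothetical host-side launch using the helpers above (not the repository's binding code).
// Assumes bsdf.h and common.h are included and LambertKernelParams p is fully populated.
void launchLambertFwd(LambertKernelParams p, cudaStream_t stream)
{
    dim3 blockSize = getLaunchBlockSize(8, 8, p.gridSize);      // fit a <= 64-thread block to the image
    dim3 gridSize  = getLaunchGridSize(blockSize, p.gridSize);  // cover gridSize.x * gridSize.y * gridSize.z
    LambertFwdKernel<<<gridSize, blockSize, 0, stream>>>(p);
}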
models/lrm/models/geometry/render/renderutils/c_src/common.h
ADDED
@@ -0,0 +1,41 @@
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once
#include <cuda.h>
#include <stdint.h>

#include "vec3f.h"
#include "vec4f.h"
#include "tensor.h"

dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);
dim3 getLaunchGridSize(dim3 blockSize, dim3 dims);

#ifdef __CUDACC__

#ifdef _MSC_VER
#define M_PI 3.14159265358979323846f
#endif

__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)
{
    return dim3(
        min(blockSize.x, 32u),
        min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),
        min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))
    );
}

__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }
#else
dim3 getWarpSize(dim3 blockSize);
#endif
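getWarpSize carves a thread block into 32-thread tiles so that warp shuffle reductions (see warpSum in loss.cu below) operate on exactly one warp. A small host-side check of the same arithmetic follows; the 16x8x1 block shape is only an example, not a value taken from the repository.

// Worked example of the warp-size mapping above, mirrored on the host with plain ints.
#include <algorithm>
#include <cstdio>

int main()
{
    unsigned bx = 16, by = 8, bz = 1;  // example block shape (assumption, not from the repo)
    unsigned wx = std::min(bx, 32u);
    unsigned wy = std::min(std::max(32u / bx, 1u), std::min(32u, by));
    unsigned wz = std::min(std::max(32u / (bx * by), 1u), std::min(32u, bz));
    std::printf("warp tile: %u x %u x %u (= %u threads)\n", wx, wy, wz, wx * wy * wz);  // 16 x 2 x 1 = 32
    return 0;
}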
models/lrm/models/geometry/render/renderutils/c_src/cubemap.cu
ADDED
@@ -0,0 +1,350 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
3 |
+
*
|
4 |
+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
5 |
+
* property and proprietary rights in and to this material, related
|
6 |
+
* documentation and any modifications thereto. Any use, reproduction,
|
7 |
+
* disclosure or distribution of this material and related documentation
|
8 |
+
* without an express license agreement from NVIDIA CORPORATION or
|
9 |
+
* its affiliates is strictly prohibited.
|
10 |
+
*/
|
11 |
+
|
12 |
+
#include "common.h"
|
13 |
+
#include "cubemap.h"
|
14 |
+
#include <float.h>
|
15 |
+
|
16 |
+
// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
|
17 |
+
__device__ float pixel_area(int x, int y, int N)
|
18 |
+
{
|
19 |
+
if (N > 1)
|
20 |
+
{
|
21 |
+
int H = N / 2;
|
22 |
+
x = abs(x - H);
|
23 |
+
y = abs(y - H);
|
24 |
+
float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);
|
25 |
+
float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);
|
26 |
+
return dx * dy;
|
27 |
+
}
|
28 |
+
else
|
29 |
+
return 1;
|
30 |
+
}
|
31 |
+
|
32 |
+
__device__ vec3f cube_to_dir(int x, int y, int side, int N)
|
33 |
+
{
|
34 |
+
float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
|
35 |
+
float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;
|
36 |
+
switch (side)
|
37 |
+
{
|
38 |
+
case 0: return safeNormalize(vec3f(1, -fy, -fx));
|
39 |
+
case 1: return safeNormalize(vec3f(-1, -fy, fx));
|
40 |
+
case 2: return safeNormalize(vec3f(fx, 1, fy));
|
41 |
+
case 3: return safeNormalize(vec3f(fx, -1, -fy));
|
42 |
+
case 4: return safeNormalize(vec3f(fx, -fy, 1));
|
43 |
+
case 5: return safeNormalize(vec3f(-fx, -fy, -1));
|
44 |
+
}
|
45 |
+
return vec3f(0,0,0); // Unreachable
|
46 |
+
}
|
47 |
+
|
48 |
+
__device__ vec3f dir_to_side(int side, vec3f v)
|
49 |
+
{
|
50 |
+
switch (side)
|
51 |
+
{
|
52 |
+
case 0: return vec3f(-v.z, -v.y, v.x);
|
53 |
+
case 1: return vec3f( v.z, -v.y, -v.x);
|
54 |
+
case 2: return vec3f( v.x, v.z, v.y);
|
55 |
+
case 3: return vec3f( v.x, -v.z, -v.y);
|
56 |
+
case 4: return vec3f( v.x, -v.y, v.z);
|
57 |
+
case 5: return vec3f(-v.x, -v.y, -v.z);
|
58 |
+
}
|
59 |
+
return vec3f(0,0,0); // Unreachable
|
60 |
+
}
|
61 |
+
|
62 |
+
__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
|
63 |
+
{
|
64 |
+
float l = sqrtf(x * x + z * z);
|
65 |
+
float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;
|
66 |
+
float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;
|
67 |
+
if (pzl <= 0.00001f)
|
68 |
+
_min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
|
69 |
+
else
|
70 |
+
_min = pxl / pzl;
|
71 |
+
if (pzr <= 0.00001f)
|
72 |
+
_max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
|
73 |
+
else
|
74 |
+
_max = pxr / pzr;
|
75 |
+
}
|
76 |
+
|
77 |
+
__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
|
78 |
+
{
|
79 |
+
vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1
|
80 |
+
|
81 |
+
if (theta < 0.785398f) // PI/4
|
82 |
+
{
|
83 |
+
float xmin, xmax, ymin, ymax;
|
84 |
+
extents_1d(c.x, c.z, theta, xmin, xmax);
|
85 |
+
extents_1d(c.y, c.z, theta, ymin, ymax);
|
86 |
+
|
87 |
+
if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
|
88 |
+
{
|
89 |
+
_xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
|
90 |
+
}
|
91 |
+
else
|
92 |
+
{
|
93 |
+
_xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
94 |
+
_xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
95 |
+
_ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
96 |
+
_ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
|
97 |
+
}
|
98 |
+
}
|
99 |
+
else
|
100 |
+
{
|
101 |
+
_xmin = 0.0f;
|
102 |
+
_xmax = (float)(N-1);
|
103 |
+
_ymin = 0.0f;
|
104 |
+
_ymax = (float)(N-1);
|
105 |
+
}
|
106 |
+
}
|
107 |
+
|
108 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
109 |
+
// Diffuse kernel
|
110 |
+
__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
|
111 |
+
{
|
112 |
+
// Calculate pixel position.
|
113 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
114 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
115 |
+
int pz = blockIdx.z;
|
116 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
117 |
+
return;
|
118 |
+
|
119 |
+
int Npx = p.cubemap.dims[1];
|
120 |
+
vec3f N = cube_to_dir(px, py, pz, Npx);
|
121 |
+
|
122 |
+
vec3f col(0);
|
123 |
+
|
124 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
125 |
+
{
|
126 |
+
for (int y = 0; y < Npx; ++y)
|
127 |
+
{
|
128 |
+
for (int x = 0; x < Npx; ++x)
|
129 |
+
{
|
130 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
131 |
+
float costheta = min(max(dot(N, L), 0.0f), 0.999f);
|
132 |
+
float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
|
133 |
+
col += p.cubemap.fetch3(x, y, s) * w;
|
134 |
+
}
|
135 |
+
}
|
136 |
+
}
|
137 |
+
|
138 |
+
p.out.store(px, py, pz, col);
|
139 |
+
}
|
140 |
+
|
141 |
+
__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
|
142 |
+
{
|
143 |
+
// Calculate pixel position.
|
144 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
145 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
146 |
+
int pz = blockIdx.z;
|
147 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
148 |
+
return;
|
149 |
+
|
150 |
+
int Npx = p.cubemap.dims[1];
|
151 |
+
vec3f N = cube_to_dir(px, py, pz, Npx);
|
152 |
+
vec3f grad = p.out.fetch3(px, py, pz);
|
153 |
+
|
154 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
155 |
+
{
|
156 |
+
for (int y = 0; y < Npx; ++y)
|
157 |
+
{
|
158 |
+
for (int x = 0; x < Npx; ++x)
|
159 |
+
{
|
160 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
161 |
+
float costheta = min(max(dot(N, L), 0.0f), 0.999f);
|
162 |
+
float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
|
163 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
|
164 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
|
165 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
|
166 |
+
}
|
167 |
+
}
|
168 |
+
}
|
169 |
+
}
|
170 |
+
|
171 |
+
///////////////////////////////////////////////////////////////////////////////////////////////////////////
|
172 |
+
// GGX splitsum kernel
|
173 |
+
|
174 |
+
__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
|
175 |
+
{
|
176 |
+
float _cosTheta = clamp(cosTheta, 0.0, 1.0f);
|
177 |
+
float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
|
178 |
+
return alphaSqr / (d * d * M_PI);
|
179 |
+
}
|
180 |
+
|
181 |
+
__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
|
182 |
+
{
|
183 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
184 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
185 |
+
int pz = blockIdx.z;
|
186 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
187 |
+
return;
|
188 |
+
|
189 |
+
int Npx = p.gridSize.x;
|
190 |
+
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
191 |
+
|
192 |
+
const int TILE_SIZE = 16;
|
193 |
+
|
194 |
+
// Brute force entire cubemap and compute bounds for the cone
|
195 |
+
for (int s = 0; s < p.gridSize.z; ++s)
|
196 |
+
{
|
197 |
+
// Assume empty BBox
|
198 |
+
int _min_x = p.gridSize.x - 1, _max_x = 0;
|
199 |
+
int _min_y = p.gridSize.y - 1, _max_y = 0;
|
200 |
+
|
201 |
+
// For each (8x8) tile
|
202 |
+
for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
|
203 |
+
{
|
204 |
+
for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
|
205 |
+
{
|
206 |
+
// Compute tile extents
|
207 |
+
int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
|
208 |
+
int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);
|
209 |
+
|
210 |
+
// Use some blunt interval arithmetics to cull tiles
|
211 |
+
vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
|
212 |
+
vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);
|
213 |
+
|
214 |
+
float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
|
215 |
+
float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
|
216 |
+
float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));
|
217 |
+
|
218 |
+
float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
|
219 |
+
if (maxdp >= p.costheta_cutoff)
|
220 |
+
{
|
221 |
+
// Test all pixels in tile.
|
222 |
+
for (int y = tsy; y < tey; ++y)
|
223 |
+
{
|
224 |
+
for (int x = tsx; x < tex; ++x)
|
225 |
+
{
|
226 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
227 |
+
if (dot(L, VNR) >= p.costheta_cutoff)
|
228 |
+
{
|
229 |
+
_min_x = min(_min_x, x);
|
230 |
+
_max_x = max(_max_x, x);
|
231 |
+
_min_y = min(_min_y, y);
|
232 |
+
_max_y = max(_max_y, y);
|
233 |
+
}
|
234 |
+
}
|
235 |
+
}
|
236 |
+
}
|
237 |
+
}
|
238 |
+
}
|
239 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
|
240 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
|
241 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
|
242 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
|
243 |
+
}
|
244 |
+
}
|
245 |
+
|
246 |
+
__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
|
247 |
+
{
|
248 |
+
// Calculate pixel position.
|
249 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
250 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
251 |
+
int pz = blockIdx.z;
|
252 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
253 |
+
return;
|
254 |
+
|
255 |
+
int Npx = p.cubemap.dims[1];
|
256 |
+
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
257 |
+
|
258 |
+
float alpha = p.roughness * p.roughness;
|
259 |
+
float alphaSqr = alpha * alpha;
|
260 |
+
|
261 |
+
float wsum = 0.0f;
|
262 |
+
vec3f col(0);
|
263 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
264 |
+
{
|
265 |
+
int xmin, xmax, ymin, ymax;
|
266 |
+
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
|
267 |
+
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
|
268 |
+
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
|
269 |
+
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
|
270 |
+
|
271 |
+
if (xmin <= xmax)
|
272 |
+
{
|
273 |
+
for (int y = ymin; y <= ymax; ++y)
|
274 |
+
{
|
275 |
+
for (int x = xmin; x <= xmax; ++x)
|
276 |
+
{
|
277 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
278 |
+
if (dot(L, VNR) >= p.costheta_cutoff)
|
279 |
+
{
|
280 |
+
vec3f H = safeNormalize(L + VNR);
|
281 |
+
|
282 |
+
float wiDotN = max(dot(L, VNR), 0.0f);
|
283 |
+
float VNRDotH = max(dot(VNR, H), 0.0f);
|
284 |
+
|
285 |
+
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
|
286 |
+
col += p.cubemap.fetch3(x, y, s) * w;
|
287 |
+
wsum += w;
|
288 |
+
}
|
289 |
+
}
|
290 |
+
}
|
291 |
+
}
|
292 |
+
}
|
293 |
+
|
294 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
|
295 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
|
296 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
|
297 |
+
p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
|
298 |
+
}
|
299 |
+
|
300 |
+
__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
|
301 |
+
{
|
302 |
+
// Calculate pixel position.
|
303 |
+
int px = blockIdx.x * blockDim.x + threadIdx.x;
|
304 |
+
int py = blockIdx.y * blockDim.y + threadIdx.y;
|
305 |
+
int pz = blockIdx.z;
|
306 |
+
if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
|
307 |
+
return;
|
308 |
+
|
309 |
+
int Npx = p.cubemap.dims[1];
|
310 |
+
vec3f VNR = cube_to_dir(px, py, pz, Npx);
|
311 |
+
|
312 |
+
vec3f grad = p.out.fetch3(px, py, pz);
|
313 |
+
|
314 |
+
float alpha = p.roughness * p.roughness;
|
315 |
+
float alphaSqr = alpha * alpha;
|
316 |
+
|
317 |
+
vec3f col(0);
|
318 |
+
for (int s = 0; s < p.cubemap.dims[0]; ++s)
|
319 |
+
{
|
320 |
+
int xmin, xmax, ymin, ymax;
|
321 |
+
xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
|
322 |
+
xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
|
323 |
+
ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
|
324 |
+
ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));
|
325 |
+
|
326 |
+
if (xmin <= xmax)
|
327 |
+
{
|
328 |
+
for (int y = ymin; y <= ymax; ++y)
|
329 |
+
{
|
330 |
+
for (int x = xmin; x <= xmax; ++x)
|
331 |
+
{
|
332 |
+
vec3f L = cube_to_dir(x, y, s, Npx);
|
333 |
+
if (dot(L, VNR) >= p.costheta_cutoff)
|
334 |
+
{
|
335 |
+
vec3f H = safeNormalize(L + VNR);
|
336 |
+
|
337 |
+
float wiDotN = max(dot(L, VNR), 0.0f);
|
338 |
+
float VNRDotH = max(dot(VNR, H), 0.0f);
|
339 |
+
|
340 |
+
float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
|
341 |
+
|
342 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
|
343 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
|
344 |
+
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
|
345 |
+
}
|
346 |
+
}
|
347 |
+
}
|
348 |
+
}
|
349 |
+
}
|
350 |
+
}
|
models/lrm/models/geometry/render/renderutils/c_src/cubemap.h
ADDED
@@ -0,0 +1,38 @@
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include "common.h"

struct DiffuseCubemapKernelParams
{
    Tensor cubemap;
    Tensor out;
    dim3 gridSize;
};

struct SpecularCubemapKernelParams
{
    Tensor cubemap;
    Tensor bounds;
    Tensor out;
    dim3 gridSize;
    float costheta_cutoff;
    float roughness;
};

struct SpecularBoundsKernelParams
{
    float costheta_cutoff;
    Tensor out;
    dim3 gridSize;
};
models/lrm/models/geometry/render/renderutils/c_src/loss.cu
ADDED
@@ -0,0 +1,210 @@
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include <cuda.h>

#include "common.h"
#include "loss.h"

//------------------------------------------------------------------------
// Utils

__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }

__device__ float warpSum(float val) {
    for (int i = 1; i < 32; i *= 2)
        val += __shfl_xor_sync(0xFFFFFFFF, val, i);
    return val;
}

//------------------------------------------------------------------------
// Tonemapping

__device__ inline float fwdSRGB(float x)
{
    return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
}

__device__ inline void bwdSRGB(float x, float &d_x, float d_out)
{
    if (x > 0.0031308f)
        d_x += d_out * 0.439583f / powf(x, 0.583333f);
    else if (x > 0.0f)
        d_x += d_out * 12.92f;
}

__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
{
    return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
}

__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
{
    if (x.x > 0.0f && x.x < 65535.0f)
    {
        bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
        d_x.x *= 1 / (x.x + 1.0f);
    }
    if (x.y > 0.0f && x.y < 65535.0f)
    {
        bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
        d_x.y *= 1 / (x.y + 1.0f);
    }
    if (x.z > 0.0f && x.z < 65535.0f)
    {
        bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
        d_x.z *= 1 / (x.z + 1.0f);
    }
}

__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
{
    return (img - target) * (img - target) / (img * img + target * target + eps);
}

__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
{
    float denom = (target * target + img * img + eps);
    d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
    d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
}

__device__ inline float fwdSMAPE(float img, float target, float eps = 0.01f)
{
    return abs(img - target) / (img + target + eps);
}

__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
{
    float denom = (target + img + eps);
    d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
    d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
}

//------------------------------------------------------------------------
// Kernels

__global__ void imgLossFwdKernel(LossKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;

    float floss = 0.0f;
    if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)
    {
        vec3f img = p.img.fetch3(px, py, pz);
        vec3f target = p.target.fetch3(px, py, pz);

        img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));
        target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));

        if (p.tonemapper == TONEMAPPER_LOG_SRGB)
        {
            img = fwdTonemapLogSRGB(img);
            target = fwdTonemapLogSRGB(target);
        }

        vec3f vloss(0);
        if (p.loss == LOSS_MSE)
            vloss = (img - target) * (img - target);
        else if (p.loss == LOSS_RELMSE)
            vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));
        else if (p.loss == LOSS_SMAPE)
            vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));
        else
            vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));

        floss = sum(vloss) / 3.0f;
    }

    floss = warpSum(floss);

    dim3 warpSize = getWarpSize(blockDim);
    if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)
        p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);
}

__global__ void imgLossBwdKernel(LossKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;

    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    dim3 warpSize = getWarpSize(blockDim);

    vec3f _img = p.img.fetch3(px, py, pz);
    vec3f _target = p.target.fetch3(px, py, pz);
    float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);

    /////////////////////////////////////////////////////////////////////
    // FWD

    vec3f img = _img, target = _target;
    if (p.tonemapper == TONEMAPPER_LOG_SRGB)
    {
        img = fwdTonemapLogSRGB(img);
        target = fwdTonemapLogSRGB(target);
    }

    /////////////////////////////////////////////////////////////////////
    // BWD

    vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;

    vec3f d_img(0), d_target(0);
    if (p.loss == LOSS_MSE)
    {
        d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.z * 2 * (img.z - target.z));
        d_target = -d_img;
    }
    else if (p.loss == LOSS_RELMSE)
    {
        bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
        bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
        bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
    }
    else if (p.loss == LOSS_SMAPE)
    {
        bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);
        bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);
        bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);
    }
    else
    {
        d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));
        d_target = -d_img;
    }

    if (p.tonemapper == TONEMAPPER_LOG_SRGB)
    {
        vec3f d__img(0), d__target(0);
        bwdTonemapLogSRGB(_img, d__img, d_img);
        bwdTonemapLogSRGB(_target, d__target, d_target);
        d_img = d__img; d_target = d__target;
    }

    if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;
    if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;
    if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;
    if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;
    if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;
    if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;

    p.img.store_grad(px, py, pz, d_img);
    p.target.store_grad(px, py, pz, d_target);
}