Spaces:
Sleeping
Sleeping
import cv2 | |
import os | |
import numpy as np | |
import torch | |
import imageio | |
from torchvision.utils import make_grid, save_image | |
from .ray_marcher import RayMarcher, generate_colored_boxes | |
def get_pose_on_orbit(radius, height, angles, world_up=None):
    """Build [rotation | translation] camera matrices for points on a horizontal orbit.

    For each azimuth ``a`` in ``angles`` the orbit point is
    ``(radius * cos(a), height, radius * sin(a))``. The rotation rows are
    ``right``, ``up`` and ``forward`` (all unit length), where ``forward`` is the
    direction from the origin to the orbit point,
    ``right = -(world_up x forward)`` and ``up = forward x right``. The
    translation column is always ``(0, 0, radius)``, so ``height`` changes the
    elevation of the view direction but not the translation.

    Args:
        radius: orbit radius.
        height: y coordinate shared by all orbit points.
        angles: 1-D tensor of azimuth angles in radians, shape ``[N]``.
        world_up: up direction of length 3; defaults to ``(0, 1, 0)``. It is
            cast to the dtype/device of the computed positions.

    Returns:
        Tensor of shape ``[N, 3, 4]`` on the device of ``angles``.
    """
    if world_up is None:
        world_up = torch.tensor([0.0, 1.0, 0.0])
    num_points = angles.shape[0]
    x = radius * torch.cos(angles)
    z = radius * torch.sin(angles)
    y = torch.ones_like(x) * height
    position = torch.stack([x, y, z], dim=-1)
    forward = position / torch.norm(position, p=2, dim=-1, keepdim=True)
    up_dir = torch.as_tensor(world_up).to(device=forward.device, dtype=forward.dtype).expand_as(forward)
    # `dim` must be explicit: without it torch.cross uses the first axis of size 3,
    # which is the batch axis whenever exactly three poses are requested.
    right = -torch.cross(up_dir, forward, dim=-1)
    right = right / torch.norm(right, p=2, dim=-1, keepdim=True)
    up = torch.cross(forward, right, dim=-1)
    up = up / torch.norm(up, p=2, dim=-1, keepdim=True)
    rotation = torch.stack([right, up, forward], dim=1)
    translation = torch.zeros((num_points, 3, 1), dtype=rotation.dtype, device=rotation.device)
    translation[:, 2, 0] = radius
    return torch.concat([rotation, translation], dim=2)
def render_mvp_boxes(rm, batch, preds):
    """Ray-march the predicted primitives as flat-colored boxes, without gradients.

    ``generate_colored_boxes`` builds a replacement payload from
    ``preds["prim_rgba"]`` and ``preds["prim_rot"]``. That payload is rendered
    by ``rm`` with the predicted positions, scales and rotations and the camera
    stored in ``batch["Rt"]`` / ``batch["K"]``.

    Returns:
        The first three channels of ``rgba_image``, moved channels-last with
        ``permute(0, 2, 3, 1)``.
    """
    with torch.no_grad():
        box_payload = generate_colored_boxes(preds["prim_rgba"], preds["prim_rot"])
        camera = {"RT": batch["Rt"], "K": batch["K"]}
        marched = rm(
            prim_rgba=box_payload,
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            **camera,
        )
        rgb_channels_first = marched["rgba_image"][:, :3]
        return rgb_channels_first.permute(0, 2, 3, 1)
def save_image_summary(path, batch, preds):
    """Save rendered RGB images stacked above their primitive-box renderings.

    ``preds["rgb"]`` and ``preds["rgb_boxes"]`` are channels-last image batches
    (values presumably in [0, 255], since they are divided by 255 below). They
    are moved channels-first and concatenated along the height axis, then tiled
    with ``make_grid`` and written to ``path``.

    If ``batch`` has both ``"folder"`` and ``"key"``, each box image is stamped
    with its folder and key text. The ``folder/key`` lines are written to a
    ``.txt`` file next to ``path``. When there are more images than labels (for
    example, multi-view renders concatenated view by view along dim 0), image
    ``i`` uses label ``i % len(labels)``.

    Args:
        path: output image path. Its extension is replaced by ``.txt`` for the
            label file.
        batch: dict, optionally holding ``"folder"`` and ``"key"`` sequences.
        preds: dict with ``"rgb"`` and ``"rgb_boxes"`` tensors.
    """
    rgb = preds["rgb"].detach().permute(0, 3, 1, 2)
    # clone(): detach() shares storage, so without it the labels drawn below
    # would be written into the caller's preds["rgb_boxes"].
    rgb_boxes = preds["rgb_boxes"].detach().permute(0, 3, 1, 2).clone()
    num_images = rgb_boxes.shape[0]

    if "folder" in batch and "key" in batch:
        folders = batch["folder"]
        keys = batch["key"]
        obj_list = []
        for img_idx in range(num_images):
            folder = folders[img_idx % len(folders)]
            key = keys[img_idx % len(keys)]
            obj_list.append("{}/{}\n".format(folder, key))
            # cv2 draws on a contiguous HxWx3 uint8 array. Clamp first so
            # out-of-range values saturate instead of wrapping around.
            canvas = rgb_boxes[img_idx].permute(1, 2, 0).clamp(0, 255).to(torch.uint8).cpu().numpy()
            canvas = np.ascontiguousarray(canvas)
            cv2.putText(canvas, "{}".format(folder), (200, 200), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2)
            cv2.putText(canvas, "{}".format(key), (200, 400), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 2)
            rgb_boxes[img_idx] = torch.as_tensor(canvas).permute(2, 0, 1).to(rgb_boxes)
        with open(os.path.splitext(path)[0] + ".txt", "w", encoding="utf-8") as f:
            f.writelines(obj_list)

    # Concatenate along dim 2 (height): render on top, box rendering below.
    img = make_grid(torch.cat([rgb, rgb_boxes], dim=2) / 255.0).clip(0.0, 1.0)
    save_image(img, path)
def visualize_primsdf_box(image_save_path, model, rm: RayMarcher, device):
    """Render a PrimSDF model's stored primitive grids from a fixed camera and save a summary.

    The RGBA payload is built from the model's stored grids. Alpha is
    ``model.sdf2alpha(model.feat_geo)`` * 255 and color is ``model.feat_tex`` * 255.
    The payload is rendered with ``rm``, rendered again as colored boxes via
    ``render_mvp_boxes``, and both images are written with ``save_image_summary``.

    Ray-marcher inputs (per the original notes):
        prim_rgba:  [B, K, 4, S, S, S] primitive payload (K primitives, grid size S)
        prim_pos:   [B, K, 3] locations
        prim_rot:   [B, K, 3, 3] rotations
        prim_scale: [B, K, 3] scales
        K:          [B, 3, 3] intrinsics
        RT:         [B, 3, 4] extrinsics

    Args:
        image_save_path: output image path, forwarded to ``save_image_summary``.
        model: object providing ``num_prims``, ``prim_shape``, ``feat_geo``,
            ``feat_tex``, ``pos``, ``scale`` and ``sdf2alpha``.
        rm: ray marcher. ``volradius``, ``image_width`` and ``image_height`` are read.
        device: device on which the camera tensors are created.
    """
    num_prims = model.num_prims
    prim_shape = model.prim_shape
    grid = (prim_shape, prim_shape, prim_shape)
    preds = {}
    batch = {}
    # Visualization only: no autograd graph is needed for the renders.
    with torch.no_grad():
        prim_alpha = model.sdf2alpha(model.feat_geo).reshape(1, num_prims, 1, *grid) * 255
        prim_rgb = model.feat_tex.reshape(1, num_prims, 3, *grid) * 255
        preds["prim_rgba"] = torch.concat([prim_rgb, prim_alpha], dim=2)
        preds["prim_pos"] = model.pos.reshape(1, num_prims, 3) * rm.volradius
        # Every primitive gets the identity rotation.
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(1, num_prims, 1, 1)
        # The ray marcher receives 1 / model.scale, applied equally to all three axes.
        preds["prim_scale"] = 1 / model.scale.reshape(1, num_prims, 1).repeat(1, 1, 3)

        # Fixed extrinsics R = diag(1, -1, -1), t = (0, 0, 5 * volradius).
        batch["Rt"] = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 5 * rm.volradius],
            ],
            device=device,
        )[None, ...]
        # Intrinsics specified for a 1024x1024 image, rescaled to the ray marcher's size.
        intrinsics = torch.tensor(
            [
                [2084.9526697685183, 0.0, 512.0],
                [0.0, 2084.9526697685183, 512.0],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...]
        # Assuming the usual pinhole convention (u = fx * x / z + cx along the width),
        # row 0 scales with the width and row 1 with the height.
        intrinsics[:, 0, :] *= rm.image_width / 1024.0
        intrinsics[:, 1, :] *= rm.image_height / 1024.0
        batch["K"] = intrinsics

        # Per the original comment, the ray marcher works in mm.
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
def render_primsdf(image_save_path, model, rm, device):
    """Query a PrimSDF model at its voxel sample points, render the result, and save a summary.

    ``model`` is called once per voxel index ``i`` in ``range(prim_shape ** 3)``
    with ``model.sdf_sampled_point[:, i, :]``. The returned ``"sdf"`` and
    ``"tex"`` values are stacked into per-primitive grids and turned into an
    RGBA payload: texture * 255 for color, ``model.sdf2alpha(sdf)`` * 255 for
    alpha. The payload is rendered with ``rm`` from a fixed camera, and the
    result is written with ``save_image_summary`` together with a colored-box
    rendering.

    Args:
        image_save_path: output image path, forwarded to ``save_image_summary``.
        model: callable PrimSDF model providing ``num_prims``, ``prim_shape``,
            ``pos``, ``scale``, ``sdf_sampled_point`` and ``sdf2alpha``.
        rm: ray marcher. ``volradius``, ``image_width`` and ``image_height`` are read.
        device: device for the camera tensors and the model queries.
    """
    num_prims = model.num_prims
    prim_shape = model.prim_shape
    grid = (prim_shape, prim_shape, prim_shape)
    preds = {}
    batch = {}
    # Visualization only: no autograd graph is needed anywhere below.
    with torch.no_grad():
        preds["prim_pos"] = model.pos.reshape(1, num_prims, 3) * rm.volradius
        # Every primitive gets the identity rotation.
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(1, num_prims, 1, 1)
        # The ray marcher receives 1 / model.scale, applied equally to all three axes.
        preds["prim_scale"] = 1 / model.scale.reshape(1, num_prims, 1).repeat(1, 1, 3)

        # Fixed extrinsics R = diag(1, -1, -1), t = (0, 0, 5 * volradius).
        batch["Rt"] = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 5 * rm.volradius],
            ],
            device=device,
        )[None, ...]
        # Intrinsics specified for a 1024x1024 image, rescaled to the ray marcher's size.
        intrinsics = torch.tensor(
            [
                [2084.9526697685183, 0.0, 512.0],
                [0.0, 2084.9526697685183, 512.0],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...]
        # Assuming the usual pinhole convention (u = fx * x / z + cx along the width),
        # row 0 scales with the width and row 1 with the height.
        intrinsics[:, 0, :] *= rm.image_width / 1024.0
        intrinsics[:, 1, :] *= rm.image_height / 1024.0
        batch["K"] = intrinsics

        # Query the model one voxel index at a time.
        sdf_samples = []
        tex_samples = []
        for voxel_idx in range(prim_shape ** 3):
            prediction = model(model.sdf_sampled_point[:, voxel_idx, :].to(device))
            sdf_samples.append(prediction["sdf"])
            tex_samples.append(prediction["tex"])
        sampled_sdf = torch.stack(sdf_samples, dim=1)
        # The stacked texture is 3-D with the voxel axis at dim 1 (presumably
        # [num_prims, S^3, 3] -- TODO confirm). It is made channel-major before
        # being reshaped to [1, num_prims, 3, S, S, S].
        prim_rgb = torch.stack(tex_samples, dim=1).permute(0, 2, 1).reshape(1, num_prims, 3, *grid) * 255
        prim_alpha = model.sdf2alpha(sampled_sdf).reshape(1, num_prims, 1, *grid) * 255
        preds["prim_rgba"] = torch.concat([prim_rgb, prim_alpha], dim=2)

        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
def visualize_primvolume(image_save_path, batch, prim_volume, rm: RayMarcher, device):
    """Render a batch of packed primitive volumes from a fixed camera and save a summary.

    Args:
        image_save_path: output image path, forwarded to ``save_image_summary``.
        batch: dict. Its ``"Rt"`` and ``"K"`` entries are overwritten in place
            with the camera used here. Optional ``"folder"``/``"key"`` entries
            are used by ``save_image_summary`` for labels.
        prim_volume: tensor ``[B, nprims, 4 + 6 * S^3]``. Per primitive,
            channel 0 is the scale and channels 1:4 the position. These are
            followed by S^3 SDF values, 3 * S^3 texture values and 2 * S^3
            material values. The material part is not rendered here.
        rm: ray marcher. ``volradius``, ``image_width`` and ``image_height`` are read.
        device: device on which the camera tensors are created.
    """
    def sdf2alpha(sdf):
        # Gaussian falloff: alpha is 1 where sdf == 0 and decays with width 0.005.
        return torch.exp(-(sdf / 0.005) ** 2)

    bs = prim_volume.shape[0]
    num_prims = prim_volume.shape[1]
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1 / 3)))
    num_voxels = prim_shape ** 3
    grid = (prim_shape, prim_shape, prim_shape)
    geo_start = 4
    tex_start = geo_start + num_voxels
    tex_end = tex_start + 3 * num_voxels

    preds = {}
    # Visualization only: no autograd graph is needed for the renders.
    with torch.no_grad():
        prim_alpha = sdf2alpha(prim_volume[:, :, geo_start:tex_start]).reshape(bs, num_prims, 1, *grid) * 255
        prim_rgb = prim_volume[:, :, tex_start:tex_end].reshape(bs, num_prims, 3, *grid) * 255
        preds["prim_rgba"] = torch.concat([prim_rgb, prim_alpha], dim=2)
        preds["prim_pos"] = prim_volume[:, :, 1:4].reshape(bs, num_prims, 3) * rm.volradius
        # Every primitive gets the identity rotation.
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(bs, num_prims, 1, 1)
        # The ray marcher receives 1 / scale, applied equally to all three axes.
        preds["prim_scale"] = 1 / prim_volume[:, :, 0:1].reshape(bs, num_prims, 1).repeat(1, 1, 3)

        # Fixed extrinsics R = diag(1, -1, -1), t = (0, 0, 5 * volradius), one per batch item.
        batch["Rt"] = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, -1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 5 * rm.volradius],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)
        # Intrinsics specified for a 1024x1024 image, rescaled to the ray marcher's size.
        intrinsics = torch.tensor(
            [
                [2084.9526697685183, 0.0, 512.0],
                [0.0, 2084.9526697685183, 512.0],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)
        # Assuming the usual pinhole convention (u = fx * x / z + cx along the width),
        # row 0 scales with the width and row 1 with the height.
        intrinsics[:, 0, :] *= rm.image_width / 1024.0
        intrinsics[:, 1, :] *= rm.image_height / 1024.0
        batch["K"] = intrinsics

        # Per the original comment, the ray marcher works in mm.
        rm_preds = rm(
            prim_rgba=preds["prim_rgba"],
            prim_pos=preds["prim_pos"],
            prim_scale=preds["prim_scale"],
            prim_rot=preds["prim_rot"],
            RT=batch["Rt"],
            K=batch["K"],
        )
        rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
        preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
        preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
    save_image_summary(image_save_path, batch, preds)
def visualize_multiview_primvolume(image_save_path, batch, prim_volume, view_counts, rm: RayMarcher, device):
    """Render packed primitive volumes from several orbit views and save one summary image.

    Cameras come from ``get_pose_on_orbit`` with radius ``5 * rm.volradius``
    and height 0, at ``view_counts`` azimuths evenly spaced over
    ``[0.5*pi, 2.5*pi)`` (the closing endpoint is dropped). Renders are
    concatenated view by view along dim 0, so image ``v * B + b`` is view ``v``
    of batch item ``b``.

    Args:
        image_save_path: output image path, forwarded to ``save_image_summary``.
        batch: dict. ``"K"`` and ``"Rt"`` are overwritten in place; ``"Rt"``
            ends up holding the last view's pose.
        prim_volume: tensor ``[B, nprims, 4 + 6 * S^3]``. Per primitive,
            channel 0 is the scale and channels 1:4 the position, followed by
            S^3 SDF values, 3 * S^3 texture values and 2 * S^3 material values
            (material is not rendered here).
        view_counts: number of views to render.
        rm: ray marcher. ``volradius``, ``image_width`` and ``image_height`` are read.
        device: device on which the intrinsics tensor is created.
    """
    def sdf2alpha(sdf):
        # Gaussian falloff: alpha is 1 where sdf == 0 and decays with width 0.005.
        return torch.exp(-(sdf / 0.005) ** 2)

    bs = prim_volume.shape[0]
    num_prims = prim_volume.shape[1]
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1 / 3)))
    num_voxels = prim_shape ** 3
    grid = (prim_shape, prim_shape, prim_shape)
    geo_start = 4
    tex_start = geo_start + num_voxels
    tex_end = tex_start + 3 * num_voxels

    # view_counts + 1 samples with the last dropped: evenly spaced, no duplicate azimuth.
    view_angles = (torch.linspace(0.5, 2.5, view_counts + 1) * torch.pi)[:-1]

    preds = {}
    rgb_views = []
    box_views = []
    # Visualization only: no autograd graph is needed for the renders.
    with torch.no_grad():
        prim_alpha = sdf2alpha(prim_volume[:, :, geo_start:tex_start]).reshape(bs, num_prims, 1, *grid) * 255
        prim_rgb = prim_volume[:, :, tex_start:tex_end].reshape(bs, num_prims, 3, *grid) * 255
        preds["prim_rgba"] = torch.concat([prim_rgb, prim_alpha], dim=2)
        preds["prim_pos"] = prim_volume[:, :, 1:4].reshape(bs, num_prims, 3) * rm.volradius
        # Every primitive gets the identity rotation.
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(bs, num_prims, 1, 1)
        # The ray marcher receives 1 / scale, applied equally to all three axes.
        preds["prim_scale"] = 1 / prim_volume[:, :, 0:1].reshape(bs, num_prims, 1).repeat(1, 1, 3)

        # Intrinsics specified for a 1024x1024 image, rescaled to the ray marcher's size.
        intrinsics = torch.tensor(
            [
                [2084.9526697685183, 0.0, 512.0],
                [0.0, 2084.9526697685183, 512.0],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)
        # Assuming the usual pinhole convention (u = fx * x / z + cx along the width),
        # row 0 scales with the width and row 1 with the height.
        intrinsics[:, 0, :] *= rm.image_width / 1024.0
        intrinsics[:, 1, :] *= rm.image_height / 1024.0
        batch["K"] = intrinsics

        for view_ang in view_angles:
            # The same azimuth for every batch item.
            batch["Rt"] = get_pose_on_orbit(
                radius=5 * rm.volradius, height=0, angles=view_ang.repeat(bs)
            ).to(prim_volume)
            # Per the original comment, the ray marcher works in mm.
            rm_preds = rm(
                prim_rgba=preds["prim_rgba"],
                prim_pos=preds["prim_pos"],
                prim_scale=preds["prim_scale"],
                prim_rot=preds["prim_rot"],
                RT=batch["Rt"],
                K=batch["K"],
            )
            rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
            preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
            preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
            rgb_views.append(preds["rgb"])
            box_views.append(preds["rgb_boxes"])

    final_preds = {
        "rgb": torch.concat(rgb_views, dim=0),
        "rgb_boxes": torch.concat(box_views, dim=0),
    }
    save_image_summary(image_save_path, batch, final_preds)
def visualize_video_primvolume(video_save_folder, batch, prim_volume, view_counts, rm: RayMarcher, device):
    """Render orbit videos (RGB, primitive boxes, material) of packed primitive volumes.

    Cameras come from ``get_pose_on_orbit`` with radius ``5 * rm.volradius``
    and height 0, at ``view_counts + 1`` azimuths evenly spaced over
    ``[1.5*pi, 3.5*pi]``. Both endpoints are included, so the first and last
    frames show the same azimuth. Three 20 fps videos are written to
    ``video_save_folder``:

    * ``rgb.mp4``: SDF alpha with texture color.
    * ``prim.mp4``: colored primitive boxes from ``render_mvp_boxes``.
    * ``mat.mp4``: SDF alpha with the two material channels in G and B and a
      zero R channel.

    Frames are grouped per batch item: all views of item 0, then all views of
    item 1, and so on.

    Args:
        video_save_folder: output directory, created if missing.
        batch: dict. ``"K"`` and ``"Rt"`` are overwritten in place.
        prim_volume: tensor ``[B, nprims, 4 + 6 * S^3]``. Per primitive,
            channel 0 is the scale and channels 1:4 the position, followed by
            S^3 SDF values, 3 * S^3 texture values and 2 * S^3 material values.
        view_counts: number of azimuth intervals (``view_counts + 1`` frames per item).
        rm: ray marcher. ``volradius``, ``image_width`` and ``image_height`` are read.
        device: device on which the intrinsics tensor is created.
    """
    def sdf2alpha(sdf):
        # Gaussian falloff: alpha is 1 where sdf == 0 and decays with width 0.005.
        return torch.exp(-(sdf / 0.005) ** 2)

    def to_uint8_frames(per_view):
        # Each view is [B, H, W, 3]; stacking on dim 1 and flattening gives
        # [B * V, H, W, 3] with all views of one item contiguous.
        frames = torch.stack(per_view, dim=1).flatten(0, 1)
        return np.clip(frames.detach().cpu().numpy(), 0, 255).astype(np.uint8)

    bs = prim_volume.shape[0]
    num_prims = prim_volume.shape[1]
    prim_shape = int(np.round(((prim_volume.shape[2] - 4) / 6) ** (1 / 3)))
    num_voxels = prim_shape ** 3
    grid = (prim_shape, prim_shape, prim_shape)
    geo_start = 4
    tex_start = geo_start + num_voxels
    mat_start = tex_start + 3 * num_voxels
    mat_end = mat_start + 2 * num_voxels

    view_angles = torch.linspace(1.5, 3.5, view_counts + 1) * torch.pi

    preds = {}
    rgb_views = []
    box_views = []
    mat_views = []
    # Visualization only: no autograd graph is needed for the renders.
    with torch.no_grad():
        prim_alpha = sdf2alpha(prim_volume[:, :, geo_start:tex_start]).reshape(bs, num_prims, 1, *grid) * 255
        prim_rgb = prim_volume[:, :, tex_start:mat_start].reshape(bs, num_prims, 3, *grid) * 255
        prim_mat = prim_volume[:, :, mat_start:mat_end].reshape(bs, num_prims, 2, *grid) * 255
        # Prepend a zero channel so the two material channels render as G and B.
        prim_mat = torch.concat([torch.zeros_like(prim_mat[:, :, 0:1, ...]), prim_mat], dim=2)
        preds["prim_rgba"] = torch.concat([prim_rgb, prim_alpha], dim=2)
        preds["prim_mata"] = torch.concat([prim_mat, prim_alpha], dim=2)
        preds["prim_pos"] = prim_volume[:, :, 1:4].reshape(bs, num_prims, 3) * rm.volradius
        # Every primitive gets the identity rotation.
        preds["prim_rot"] = torch.eye(3).to(preds["prim_pos"])[None, None, ...].repeat(bs, num_prims, 1, 1)
        # The ray marcher receives 1 / scale, applied equally to all three axes.
        preds["prim_scale"] = 1 / prim_volume[:, :, 0:1].reshape(bs, num_prims, 1).repeat(1, 1, 3)

        # Intrinsics specified for a 1024x1024 image, rescaled to the ray marcher's size.
        intrinsics = torch.tensor(
            [
                [2084.9526697685183, 0.0, 512.0],
                [0.0, 2084.9526697685183, 512.0],
                [0.0, 0.0, 1.0],
            ],
            device=device,
        )[None, ...].repeat(bs, 1, 1)
        # Assuming the usual pinhole convention (u = fx * x / z + cx along the width),
        # row 0 scales with the width and row 1 with the height.
        intrinsics[:, 0, :] *= rm.image_width / 1024.0
        intrinsics[:, 1, :] *= rm.image_height / 1024.0
        batch["K"] = intrinsics

        for view_ang in view_angles:
            # The same azimuth for every batch item.
            batch["Rt"] = get_pose_on_orbit(
                radius=5 * rm.volradius, height=0, angles=view_ang.repeat(bs)
            ).to(prim_volume)
            # Per the original comment, the ray marcher works in mm.
            rm_preds = rm(
                prim_rgba=preds["prim_rgba"],
                prim_pos=preds["prim_pos"],
                prim_scale=preds["prim_scale"],
                prim_rot=preds["prim_rot"],
                RT=batch["Rt"],
                K=batch["K"],
            )
            rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
            preds.update(alpha=rgba[..., -1].contiguous(), rgb=rgba[..., :3].contiguous())
            preds["rgb_boxes"] = render_mvp_boxes(rm, batch, preds)
            rm_preds = rm(
                prim_rgba=preds["prim_mata"],
                prim_pos=preds["prim_pos"],
                prim_scale=preds["prim_scale"],
                prim_rot=preds["prim_rot"],
                RT=batch["Rt"],
                K=batch["K"],
            )
            mat_rgba = rm_preds["rgba_image"].permute(0, 2, 3, 1)
            preds.update(mat_rgb=mat_rgba[..., :3].contiguous())
            rgb_views.append(preds["rgb"])
            box_views.append(preds["rgb_boxes"])
            mat_views.append(preds["mat_rgb"])

    os.makedirs(video_save_folder, exist_ok=True)
    outputs = (("rgb.mp4", rgb_views), ("prim.mp4", box_views), ("mat.mp4", mat_views))
    for filename, per_view in outputs:
        frames = to_uint8_frames(per_view)
        # The context manager closes (finalizes) the file even if writing fails.
        with imageio.get_writer(os.path.join(video_save_folder, filename), fps=20) as writer:
            for frame in frames:
                writer.append_data(frame)