# Kiss3DGen/utils/tool.py
import os

import cv2
import glm
import numpy as np
import scipy.ndimage
import torch
import torchvision.transforms.v2 as T
from rembg import remove
from tqdm import tqdm

from models.lrm.utils import render_utils

def load_mipmap(env_path):
    """Load a prefiltered environment map: one diffuse mip plus six specular mip levels."""
    diffuse_path = os.path.join(env_path, "diffuse.pth")
    diffuse = torch.load(diffuse_path, map_location=torch.device('cpu'))

    specular = []
    for i in range(6):
        specular_path = os.path.join(env_path, f"specular_{i}.pth")
        specular.append(torch.load(specular_path, map_location=torch.device('cpu')))
    return [specular, diffuse]
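
# Example usage (a minimal sketch; "env_dir" is a hypothetical directory containing
# diffuse.pth and specular_0.pth ... specular_5.pth):
#
#     specular_mips, diffuse_mip = load_mipmap("env_dir")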

def get_background(img_tensor):
    """
    Extract a binary foreground mask for a batch of images using rembg.

    Args:
        img_tensor: input image tensor of shape (B, 3, H, W), values in [0, 1].
    Returns:
        mask_tensor: binary mask tensor of shape (B, 1, H, W), values in {0, 1}.
    """
    B, C, H, W = img_tensor.shape
    assert C == 3, "Input tensor must have 3 channels (RGB)."

    # Convert the tensor to a uint8 numpy array of shape (B, H, W, C) in [0, 255].
    img_numpy = (img_tensor.permute(0, 2, 3, 1) * 255).byte().cpu().numpy()

    masks = []
    for i in range(B):
        # Run rembg to get a soft alpha matte, then binarize it.
        mask = remove(img_numpy[i], only_mask=True)
        mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]
        masks.append(mask_binary[..., None])  # (H, W, 1)

    # Stack into (B, H, W, 1), then convert to a (B, 1, H, W) float tensor in {0, 1}.
    masks = np.stack(masks, axis=0)
    mask_tensor = torch.from_numpy(masks).permute(0, 3, 1, 2).float() / 255.0
    return mask_tensor
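
# Example usage (a sketch; any float image batch in [0, 1] works):
#
#     imgs = torch.rand(2, 3, 512, 512)
#     fg_mask = get_background(imgs)   # (2, 1, 512, 512), values in {0, 1}
#     fg_only = imgs * fg_mask         # zero out the background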

def get_render_cameras_video(batch_size=1, M=120, radius=4.0, elevation=20.0, is_flexicubes=False, fov=50):
    """
    Get the rendering camera parameters for an M-frame orbit video.
    `elevation` is in degrees; a (low, high) tuple renders two orbits of M//2 frames each.
    """
    train_res = [512, 512]
    cam_near_far = [0.1, 1000.0]
    fovy = np.deg2rad(fov)
    proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
    all_mv = []
    all_mvp = []
    all_campos = []

    def append_view(azimuth, polar):
        # Camera on a sphere of the given radius; `polar` is the angle from the
        # +y axis, in radians.
        z = radius * np.cos(azimuth) * np.sin(polar)
        x = radius * np.sin(azimuth) * np.sin(polar)
        y = radius * np.cos(polar)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        mv = torch.from_numpy(np.array(glm.lookAt(eye, at, up)))  # world-to-camera
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]
        all_mv.append(mv[None, ...].cuda())
        all_mvp.append(mvp[None, ...].cuda())
        all_campos.append(campos[None, ...].cuda())

    if isinstance(elevation, tuple):
        elevation_0 = np.deg2rad(elevation[0])
        elevation_1 = np.deg2rad(elevation[1])
        for i in range(M // 2):
            append_view(2 * np.pi * i / (M // 2), elevation_0)
        for i in range(M // 2):
            append_view(2 * np.pi * i / (M // 2), elevation_1)
    else:
        # Convert degrees to radians, matching the tuple branch.
        elevation = np.deg2rad(elevation)
        for i in range(M):
            append_view(2 * np.pi * i / M, elevation)

    all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
    all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
    all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
    return all_mv, all_mvp, all_campos
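
# Example usage (a sketch; requires a CUDA device since the matrices are moved to the GPU):
#
#     mv, mvp, campos = get_render_cameras_video(M=120, radius=4.5, elevation=(20.0, -10.0))
#     # mv, mvp: (1, 120, 4, 4); campos: (1, 120, 3)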

def get_render_cameras_frames(batch_size=1, radius=4.0, azimuths=0, elevations=20.0, fov=30):
    """
    Get the rendering camera parameters for a set of (azimuth, elevation) views.
    Angles are in degrees; `azimuths`/`elevations` may be scalars, numpy arrays, or torch tensors.
    """
    train_res = [512, 512]
    cam_near_far = [0.1, 1000.0]
    fovy = np.deg2rad(fov)
    proj_mtx = render_utils.perspective(fovy, train_res[1] / train_res[0], cam_near_far[0], cam_near_far[1])
    all_mv = []
    all_mvp = []
    all_campos = []

    def make_mv(azimuth, polar):
        # World-to-camera matrix for a camera on a sphere of the given radius;
        # `polar` is the angle from the +y axis, in radians.
        z = radius * np.cos(azimuth) * np.sin(polar)
        x = radius * np.sin(azimuth) * np.sin(polar)
        y = radius * np.cos(polar)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        return torch.from_numpy(np.array(glm.lookAt(eye, at, up)))

    def append_view(azimuth, polar):
        mv = make_mv(azimuth, polar)
        mvp = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]
        all_mv.append(mv[None, ...].cuda())
        all_mvp.append(mvp[None, ...].cuda())
        all_campos.append(campos[None, ...].cuda())

    # Convert elevation (angle above the horizon) to a polar angle from the +y axis.
    elevations = 90 - elevations
    if isinstance(elevations, (np.ndarray, torch.Tensor)):
        if isinstance(elevations, torch.Tensor):
            elevations = elevations.cpu().numpy()
        if isinstance(azimuths, torch.Tensor):
            azimuths = azimuths.cpu().numpy()
        azimuths = np.deg2rad(azimuths)
        elevations = np.deg2rad(elevations)
        for azi, ele in zip(azimuths, elevations):
            append_view(azi, ele)
    else:
        # Scalar case: convert degrees to radians, matching the array branch.
        append_view(np.deg2rad(azimuths), np.deg2rad(elevations))

    # Identity pose: camera on the +z axis (polar angle 90 degrees) looking at the origin.
    identity_mv = make_mv(0.0, np.deg2rad(90.0))

    all_mv = torch.stack(all_mv, dim=0).unsqueeze(0).squeeze(2)
    all_mvp = torch.stack(all_mvp, dim=0).unsqueeze(0).squeeze(2)
    all_campos = torch.stack(all_campos, dim=0).unsqueeze(0).squeeze(2)
    return all_mv, all_mvp, all_campos, identity_mv
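
# Example usage (a sketch; requires a CUDA device):
#
#     azis = np.array([0.0, 90.0, 180.0, 270.0])
#     eles = np.array([5.0, 5.0, 5.0, 5.0])
#     mv, mvp, campos, identity_mv = get_render_cameras_frames(
#         radius=4.5, azimuths=azis, elevations=eles, fov=30)
#     # mv, mvp: (1, 4, 4, 4); campos: (1, 4, 3); identity_mv: (4, 4) on the CPU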

def worldNormal2camNormal(rot_w2c, normal_map_world):
    """Rotate an (H, W, 3+) world-space normal map into camera space with a 3x3 world-to-camera rotation."""
    normal_map_world = normal_map_world[..., :3]
    # (H*W, 3) @ R^T applies the rotation R to every normal at once.
    normal_map_flat = normal_map_world.contiguous().view(-1, 3)
    normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float())
    # Reshape the transformed normal map back to its original shape.
    normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape)
    return normal_map_camera

def trans_normal(normal, RT_w2c, RT_w2c_target):
    """Re-express a normal map given in the RT_w2c camera frame in the RT_w2c_target camera frame."""
    relative_RT = torch.matmul(RT_w2c_target[:3, :3], torch.linalg.inv(RT_w2c[:3, :3]))
    normal_target_cam = worldNormal2camNormal(relative_RT, normal)
    return normal_target_cam
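
# Sanity check (a sketch): with identical source and target poses the relative
# rotation is the identity, so the normal map is returned unchanged:
#
#     n = torch.nn.functional.normalize(torch.randn(64, 64, 3), dim=-1)
#     RT = torch.eye(4)
#     assert torch.allclose(trans_normal(n, RT, RT), n, atol=1e-6)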

def render_frames(model, planes, render_cameras, camera_pos, env, materials, render_size=512, chunk_size=1,
                  is_flexicubes=False, render_mv=None, local_normal=False, identity_mv=None):
    """
    Render frames from triplanes.
    """
    frames = []
    albedos = []
    pbr_spec_lights = []
    pbr_diffuse_lights = []
    normals = []
    alphas = []
    for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
        out = model.forward_geometry(
            planes,
            render_cameras[:, i:i+chunk_size],
            camera_pos[:, i:i+chunk_size],
            [[env] * chunk_size],
            [[materials] * chunk_size],
            render_size=render_size,
        )
        frame = out['pbr_img']
        albedo = out['albedo']
        pbr_spec_light = out['pbr_spec_light']
        pbr_diffuse_light = out['pbr_diffuse_light']
        normal = out['normal']
        alpha = out['mask']
        if local_normal:
            # Re-express the global (identity-frame) normal map in the target camera's local frame.
            target_w2c = render_mv[0, i, :3, :3]
            identity_w2c = identity_mv[:3, :3]
            normal = trans_normal(normal.squeeze(0), identity_w2c.cuda(), target_w2c.cuda())
            normal = normal / torch.norm(normal, dim=-1, keepdim=True)
            background_normal = torch.tensor([1, 1, 1], dtype=torch.float32, device=normal.device)
            normal = normal.unsqueeze(0)
            normal[..., 0] *= -1  # flip x to match the target normal-map convention
            alpha_hw = alpha.squeeze(0).permute(0, 2, 3, 1)
            normal = normal * alpha_hw + background_normal * (1 - alpha_hw)
        frames.append(frame)
        albedos.append(albedo)
        pbr_spec_lights.append(pbr_spec_light)
        pbr_diffuse_lights.append(pbr_diffuse_light)
        normals.append(normal)
        alphas.append(alpha)
    frames = torch.cat(frames, dim=1)[0]  # we assume batch size is always 1
    alphas = torch.cat(alphas, dim=1)[0]
    albedos = torch.cat(albedos, dim=1)[0]
    pbr_spec_lights = torch.cat(pbr_spec_lights, dim=1)[0]
    pbr_diffuse_lights = torch.cat(pbr_diffuse_lights, dim=1)[0]
    normals = torch.cat(normals, dim=0).permute(0, 3, 1, 2)[:, :3]
    return frames, albedos, pbr_spec_lights, pbr_diffuse_lights, normals, alphas
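
# Example usage (a sketch; `model`, `planes`, `env`, and `materials` are assumed to
# come from the LRM pipeline set up elsewhere in the repo):
#
#     mv, mvp, campos = get_render_cameras_video(M=120, radius=4.5, elevation=20.0)
#     frames, albedos, spec, diff, normals, alphas = render_frames(
#         model, planes, mvp, campos, env, materials, render_size=512)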

# Adapted from https://github.com/cubiq/ComfyUI_essentials
def mask_fix(mask, erode_dilate=0, smooth=0, remove_isolated_pixels=0, blur=0, fill_holes=0):
    """Morphologically clean up a batch of (H, W) masks; sizes are in pixels."""
    masks = []
    for m in mask:
        # Erode (negative size) or dilate (positive size).
        if erode_dilate != 0:
            if erode_dilate < 0:
                m = torch.from_numpy(scipy.ndimage.grey_erosion(m.cpu().numpy(), size=(-erode_dilate, -erode_dilate)))
            else:
                m = torch.from_numpy(scipy.ndimage.grey_dilation(m.cpu().numpy(), size=(erode_dilate, erode_dilate)))
        # Fill holes via grey closing.
        if fill_holes > 0:
            m = torch.from_numpy(scipy.ndimage.grey_closing(m.cpu().numpy(), size=(fill_holes, fill_holes)))
        # Remove isolated pixels via grey opening.
        if remove_isolated_pixels > 0:
            m = torch.from_numpy(scipy.ndimage.grey_opening(m.cpu().numpy(), size=(remove_isolated_pixels, remove_isolated_pixels)))
        # Smooth the binarized mask (gaussian kernel size must be odd).
        if smooth > 0:
            if smooth % 2 == 0:
                smooth += 1
            m = T.functional.gaussian_blur((m > 0.5).float().unsqueeze(0), smooth).squeeze(0)
        # Blur the mask (gaussian kernel size must be odd).
        if blur > 0:
            if blur % 2 == 0:
                blur += 1
            m = T.functional.gaussian_blur(m.float().unsqueeze(0), blur).squeeze(0)
        masks.append(m.float())
    masks = torch.stack(masks, dim=0).float()
    return masks
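
# Example usage (a sketch; `mask_fix` expects a batch of (H, W) masks in [0, 1]):
#
#     raw = get_background(torch.rand(1, 3, 256, 256)).squeeze(1)   # (1, 256, 256)
#     cleaned = mask_fix(raw, erode_dilate=2, remove_isolated_pixels=3, smooth=5)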

class NormalTransfer:
    def __init__(self):
        # World-to-camera pose of the identity view: a camera on the +x axis at
        # radius 4.5 (azimuth 0, elevation 0 in generate_target_pose's convention).
        self.identity_w2c = torch.tensor([
            [ 0.0, 0.0, 1.0, 0.0],
            [ 0.0, 1.0, 0.0, 0.0],
            [-1.0, 0.0, 0.0, 4.5]]).float()

    def look_at(self, camera_position, target_position, up_vector=np.array([0, 0, 1])):
        """Build a 4x4 world-to-camera matrix for a camera at `camera_position` looking at `target_position`."""
        forward = camera_position - target_position
        forward = forward / np.linalg.norm(forward)
        right = np.cross(up_vector, forward)
        right = right / np.linalg.norm(right)
        up = np.cross(forward, right)
        rotation_matrix = np.array([right, up, forward]).T
        translation_matrix = np.eye(4)
        translation_matrix[:3, 3] = -camera_position
        rotation_homogeneous = np.eye(4)
        rotation_homogeneous[:3, :3] = rotation_matrix
        w2c = rotation_homogeneous @ translation_matrix
        return w2c
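
    # Sanity check (a sketch): the camera center always maps to the origin of the
    # camera frame, whatever the rotation convention:
    #
    #     nt = NormalTransfer()
    #     w2c = nt.look_at(np.array([4.5, 0.0, 0.0]), np.zeros(3))
    #     assert np.allclose(w2c @ np.array([4.5, 0.0, 0.0, 1.0]), [0.0, 0.0, 0.0, 1.0])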

    def generate_target_pose(self, azimuths_deg, elevations_deg, radius=4.5):
        """Build (N, 4, 4) world-to-camera matrices from azimuth/elevation angles in degrees."""
        if isinstance(azimuths_deg, torch.Tensor):
            azimuths_deg = azimuths_deg.cpu().numpy()
        if isinstance(elevations_deg, torch.Tensor):
            elevations_deg = elevations_deg.cpu().numpy()
        azimuths = np.deg2rad(azimuths_deg)
        elevations = np.deg2rad(elevations_deg)
        x = radius * np.cos(azimuths) * np.cos(elevations)
        y = radius * np.sin(azimuths) * np.cos(elevations)
        z = radius * np.sin(elevations)
        camera_positions = np.stack([x, y, z], axis=-1)
        target_position = np.array([0, 0, 0])  # all cameras look at the origin
        # Build a w2c matrix for each camera position.
        w2c_matrices = [self.look_at(cam_pos, target_position) for cam_pos in camera_positions]
        w2c_matrices = np.stack(w2c_matrices, axis=0)
        return w2c_matrices
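
    # Example usage (a sketch): one pose per (azimuth, elevation) pair, in degrees:
    #
    #     nt = NormalTransfer()
    #     w2c = nt.generate_target_pose(np.array([0.0, 90.0]), np.array([30.0, 30.0]))
    #     # w2c: (2, 4, 4) numpy array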

    def convert_to_blender(self, pose):
        """Convert an OpenGL-convention w2c pose to a (3, 4) OpenCV-convention [R|t]."""
        w2c_opengl = pose.copy()  # avoid mutating the caller's array
        # Swap the y and z axes.
        w2c_opengl[[1, 2], :] = w2c_opengl[[2, 1], :]
        # Invert the y axis.
        w2c_opengl[1] *= -1
        R = w2c_opengl[:3, :3]
        t = w2c_opengl[:3, 3]
        cam_rec = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
        R = R.T
        t = -R @ t
        R_world2cv = cam_rec @ R
        t_world2cv = cam_rec @ t
        RT = np.concatenate([R_world2cv, t_world2cv[:, None]], 1)
        return RT
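
    # Example usage (a sketch): convert an OpenGL-convention pose from
    # generate_target_pose to the OpenCV convention:
    #
    #     nt = NormalTransfer()
    #     w2c = nt.generate_target_pose(np.array([0.0]), np.array([30.0]))[0]
    #     RT = nt.convert_to_blender(w2c)   # (3, 4) [R|t] matrix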

    def worldNormal2camNormal(self, rot_w2c, normal_map_world):
        """Rotate an (H, W, 3+) world-space normal map into camera space with a 3x3 world-to-camera rotation."""
        normal_map_world = normal_map_world[..., :3]
        # (H*W, 3) @ R^T applies the rotation R to every normal at once.
        normal_map_flat = normal_map_world.contiguous().view(-1, 3)
        normal_map_camera_flat = torch.matmul(normal_map_flat.float(), rot_w2c.T.float())
        # Reshape the transformed normal map back to its original shape.
        normal_map_camera = normal_map_camera_flat.view(normal_map_world.shape)
        return normal_map_camera
def trans_normal(self, normal, RT_w2c, RT_w2c_target):
"""
:param normal: (H,W,3), torch tensor, range [-1,1]
:param RT_w2c: (4,4), torch tensor, world to camera
:param RT_w2c_target: (4,4), torch tensor, world to camera
:return: normal_target_cam: (H,W,3), torch tensor, range [-1,1]
"""
relative_RT = torch.matmul(RT_w2c_target[:3,:3], torch.linalg.inv(RT_w2c[:3,:3]))
normal_target_cam = self.worldNormal2camNormal(relative_RT[:3,:3], normal)
return normal_target_cam

    def trans_local_2_global(self, normal_local, azimuths_deg, elevations_deg, radius=4.5, for_lotus=True):
        """
        :param normal_local: (B,H,W,3), torch tensor, range [-1,1]
        :param azimuths_deg: (B,), numpy array, range [0,360]
        :param elevations_deg: (B,), numpy array, range [-90,90]
        :param radius: float, default 4.5
        :param for_lotus: bool, flip the x component to match the Lotus normal convention
        :return: global_normal: (B,H,W,3), torch tensor, range [-1,1]
        """
        assert normal_local.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0]
        identity_w2c = self.identity_w2c
        # Generate the per-view target poses, converted to the OpenCV/Blender convention.
        target_w2c = self.generate_target_pose(azimuths_deg, elevations_deg, radius)
        target_w2c = torch.from_numpy(np.stack([self.convert_to_blender(w2c) for w2c in target_w2c])).float()
        global_normal = []
        # Re-express each per-view (local) normal map in the identity camera frame.
        for i in range(normal_local.shape[0]):
            normal_zero123 = self.trans_normal(normal_local[i], target_w2c[i], identity_w2c)
            global_normal.append(normal_zero123)
        global_normal = torch.stack(global_normal, dim=0)
        if for_lotus:
            global_normal[..., 0] *= -1
        global_normal = global_normal / torch.norm(global_normal, dim=-1, keepdim=True)
        return global_normal
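
    # Example usage (a sketch; random stand-in normals for four views):
    #
    #     nt = NormalTransfer()
    #     local_n = torch.nn.functional.normalize(torch.randn(4, 64, 64, 3), dim=-1)
    #     azis, eles = np.array([0.0, 90.0, 180.0, 270.0]), np.zeros(4)
    #     global_n = nt.trans_local_2_global(local_n, azis, eles)   # (4, 64, 64, 3)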

    def trans_global_2_local(self, normal_global, azimuths_deg, elevations_deg, radius=4.5):
        """
        :param normal_global: (B,H,W,3), torch tensor, range [-1,1]
        :param azimuths_deg: (B,), numpy array, range [0,360]
        :param elevations_deg: (B,), numpy array, range [-90,90]
        :param radius: float, default 4.5
        :return: local_normal: (B,H,W,3), torch tensor, range [-1,1]
        """
        assert normal_global.shape[0] == azimuths_deg.shape[0] == elevations_deg.shape[0]
        identity_w2c = self.identity_w2c
        # Generate the per-view target poses (OpenGL convention, unlike trans_local_2_global).
        target_w2c = torch.from_numpy(self.generate_target_pose(azimuths_deg, elevations_deg, radius)).float()
        local_normal = []
        # Re-express the global (identity-frame) normal map in each target camera frame.
        for i in range(normal_global.shape[0]):
            normal = self.trans_normal(normal_global[i], identity_w2c, target_w2c[i])
            local_normal.append(normal)
        local_normal = torch.stack(local_normal, dim=0)
        local_normal = local_normal / torch.norm(local_normal, dim=-1, keepdim=True)
        return local_normal
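
# Example usage (a sketch; mirrors trans_local_2_global above):
#
#     nt = NormalTransfer()
#     global_n = torch.nn.functional.normalize(torch.randn(4, 64, 64, 3), dim=-1)
#     azis, eles = np.array([0.0, 90.0, 180.0, 270.0]), np.zeros(4)
#     local_n = nt.trans_global_2_local(global_n, azis, eles)   # (4, 64, 64, 3)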