# splatter_image/utils/app_utils.py
# Last commit: "Remove halo and add rotation" (ccc0216), Stanislaw Szymanowicz
from PIL import Image
from typing import Any
import rembg
import numpy as np
from torchvision import transforms
from plyfile import PlyData, PlyElement
import os
import torch
from .camera_utils import get_loop_cameras
from .graphics_utils import getProjectionMatrix
from .general_utils import matrix_to_quaternion, quaternion_raw_multiply
import math
def remove_background(image, rembg_session):
    """Run rembg on ``image`` unless it already contains transparency.

    Args:
        image: image object exposing ``mode`` and ``getextrema()``
            (PIL-style).
        rembg_session: session object forwarded to ``rembg.remove``.

    Returns:
        ``image`` itself when it is RGBA and the minimum of its fourth band
        is below 255; otherwise whatever ``rembg.remove`` returns.
    """
    # An RGBA input whose alpha minimum is < 255 is treated as already cut out.
    already_has_alpha = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if already_has_alpha:
        return image
    return rembg.remove(image, session=rembg_session)
def set_white_background(image):
    """Blend a 4-channel image over a white backdrop.

    Args:
        image: anything ``np.array`` converts to an H x W x 4 array with
            values on a 0-255 scale (channel 3 is used as the blend mask).

    Returns:
        A PIL image built from the 3-channel uint8 blend result.
    """
    rgba = np.array(image).astype(np.float32) / 255.0
    alpha = rgba[:, :, 3:4]  # kept 3-D so it broadcasts over the colour channels
    colour = rgba[:, :, :3]
    # "over" compositing against white: white is 1.0 after normalisation
    blended = colour * alpha + (1 - alpha)
    return Image.fromarray((blended * 255.0).astype(np.uint8))
def resize_foreground(image, ratio):
    """Crop a 4-channel image to its foreground, square it and add a margin.

    The foreground is every pixel whose channel 3 (alpha) is > 0. Its tight
    bounding box is cropped, padded symmetrically to a square, and then padded
    again so that the square occupies roughly ``ratio`` of the output side
    (output side is ``int(side / ratio)``). All padding is filled with 255 in
    every channel.

    Args:
        image: anything ``np.array`` converts to an array whose last axis has
            length 4.
        ratio: fraction of the output side covered by the squared foreground;
            presumably expected in (0, 1] -- larger values would require
            negative padding.

    Returns:
        A PIL image of the padded result.

    Raises:
        AssertionError: if the last axis does not have length 4.
        ValueError: if no pixel has alpha > 0.
    """
    image = np.array(image)
    assert image.shape[-1] == 4
    rows, cols = np.where(image[..., 3] > 0)
    if rows.size == 0:
        raise ValueError("image has no foreground pixels (alpha is 0 everywhere)")
    # modify so that cropping doesn't change the world center
    y1, y2 = rows.min(), rows.max()
    x1, x2 = cols.min(), cols.max()
    # crop the foreground; min/max are inclusive indices, so the slice end
    # needs +1 or the last foreground row/column is lost (and a one-pixel-wide
    # foreground would give an empty crop)
    fg = image[y1:y2 + 1, x1:x2 + 1]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((255, 255), (255, 255), (0, 0)),
    )
    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=((255, 255), (255, 255), (0, 0)),
    )
    return Image.fromarray(new_image)
def resize_to_128(img):
    """Resize ``img`` via torchvision with ``size=128`` and Lanczos interpolation.

    Args:
        img: image accepted by ``transforms.functional.resize``.

    Returns:
        The resized image as returned by torchvision.
    """
    lanczos = transforms.InterpolationMode.LANCZOS
    return transforms.functional.resize(img, 128, interpolation=lanczos)
def to_tensor(img):
    """Convert channel-last image data to a channel-first tensor scaled by 1/255.

    Args:
        img: 3-D array-like (axes are permuted as ``(2, 0, 1)``), accepted by
            ``torch.tensor``.

    Returns:
        Tensor with the last input axis moved to the front, divided by 255.0.
    """
    channel_last = torch.tensor(img)
    channel_first = channel_last.permute(2, 0, 1)
    return channel_first / 255.0
def get_source_camera_v2w_rmo_and_quats(num_imgs_in_loop=200):
    """Return the first loop camera (transposed) and its rotation quaternion.

    Args:
        num_imgs_in_loop: forwarded to ``get_loop_cameras``.

    Returns:
        Tuple ``(camera, quaternions)``: the first camera matrix from
        ``get_loop_cameras``, transposed and given two leading singleton
        dimensions, and the stacked ``matrix_to_quaternion`` result for the
        transposed top-left 3x3 block, also with an extra leading dimension.
        NOTE(review): the function name suggests a view-to-world matrix in
        row-major order -- confirm against ``get_loop_cameras``.
    """
    first_camera = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop)[0]
    camera_batch = torch.from_numpy(first_camera).transpose(0, 1).unsqueeze(0)
    # undo the transpose on the 3x3 block before converting it to a quaternion
    quats = [
        matrix_to_quaternion(cam[:3, :3].transpose(0, 1))
        for cam in camera_batch
    ]
    return camera_batch.unsqueeze(0), torch.stack(quats, dim=0).unsqueeze(0)
def get_target_cameras(num_imgs_in_loop=200):
    """
    Returns camera parameters for rendering a loop around the object:
        world_to_view_transforms,
        full_proj_transforms,
        camera_centers

    Args:
        num_imgs_in_loop: number of cameras requested from ``get_loop_cameras``.

    Each returned tensor is stacked over the cameras along dimension 0.
    All matrices are stored transposed relative to what ``get_loop_cameras``
    and ``getProjectionMatrix`` return, so a full projection is
    ``world_view @ projection`` in that transposed convention.
    """
    # same field of view (in radians) for both axes
    fov = 49.134342641202636 * 2 * np.pi / 360
    projection_matrix = getProjectionMatrix(
        znear=0.8, zfar=3.2,
        fovX=fov,
        fovY=fov).transpose(0, 1)

    target_cameras = get_loop_cameras(num_imgs_in_loop=num_imgs_in_loop,
                                      max_elevation=np.pi/4,
                                      elevation_freq=1.5)
    world_view_transforms = []
    camera_centers = []
    for loop_camera_c2w_cmo in target_cameras:
        c2w = torch.from_numpy(loop_camera_c2w_cmo)
        view_world_transform = c2w.transpose(0, 1)
        world_view_transform = c2w.inverse().transpose(0, 1)
        # after the transpose the translation sits in the last row
        camera_center = view_world_transform[3, :3].clone()

        world_view_transforms.append(world_view_transform)
        camera_centers.append(camera_center)

    world_view_transforms = torch.stack(world_view_transforms)
    camera_centers = torch.stack(camera_centers)

    # broadcast the single projection matrix across all cameras for bmm
    full_proj_transforms = world_view_transforms.bmm(
        projection_matrix.unsqueeze(0).expand(
            world_view_transforms.shape[0], 4, 4))

    return world_view_transforms, full_proj_transforms, camera_centers
def construct_list_of_attributes():
    """Return the ordered per-vertex property names used for the .ply export.

    Layout (taken from the gaussian splatting repo): position, normals,
    3 DC colour channels, 9 remaining SH channels (order 1), opacity,
    3 scale components and 4 rotation components.
    """
    names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    # 3 channels for DC
    names.extend('f_dc_{}'.format(idx) for idx in range(3))
    # 9 channels for SH order 1
    names.extend('f_rest_{}'.format(idx) for idx in range(9))
    names.append('opacity')
    names.extend('scale_{}'.format(idx) for idx in range(3))
    names.extend('rot_{}'.format(idx) for idx in range(4))
    return names
def export_to_obj(reconstruction, ply_out_path):
    """
    Write a reconstruction to a .ply file (despite the function name, the
    output format is PLY).

    Args:
        reconstruction: dict with xyz, opacity, features dc, etc with leading batch size
            of 1. "features_dc" and "features_rest" must be 4-D, every other
            entry 3-D. The dict passed in is not modified.
        ply_out_path: file path where to save the output

    Raises:
        AssertionError: if an entry has an unexpected rank or a batch size
            other than 1.
    """
    # dirname is "" for a bare file name, and os.makedirs("") raises
    out_dir = os.path.dirname(ply_out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Strip the batch dimension into a local dict so the caller's dict stays
    # intact (and the function can be called again with the same input).
    unbatched = {}
    for k, v in reconstruction.items():
        # check dimensions
        if k not in ["features_dc", "features_rest"]:
            assert len(v.shape) == 3, "Unexpected size for {}".format(k)
        else:
            assert len(v.shape) == 4, "Unexpected size for {}".format(k)
        assert v.shape[0] == 1, "Expected batch size to be 1"
        unbatched[k] = v[0]
    reconstruction = unbatched

    # keep only Gaussians above the opacity threshold
    # NOTE(review): -2.5 is presumably a pre-activation value -- confirm
    valid_gaussians = torch.where(reconstruction["opacity"] > -2.5)[0]

    # transforms for visualisation in Gradio
    # ============= Transform locations =============
    xyz = reconstruction["xyz"][valid_gaussians].detach().cpu().clone()
    # t1: axis swap (y <-> z with a sign flip)
    t1 = torch.tensor([[1, 0, 0],
                       [0, 0, 1],
                       [0, -1, 0]], dtype=torch.float32)
    # t2: 30 degree rotation about the z axis
    angle1 = 30 * math.pi * 2 / 360
    t2 = torch.tensor([[math.cos(angle1), -math.sin(angle1), 0],
                       [math.sin(angle1), math.cos(angle1), 0],
                       [0, 0, 1]], dtype=torch.float32)
    # t3: -60 degree rotation about the y axis
    angle2 = -60 * math.pi * 2 / 360
    t3 = torch.tensor([[math.cos(angle2), 0, math.sin(angle2)],
                       [0, 1, 0],
                       [-math.sin(angle2), 0, math.cos(angle2)]], dtype=torch.float32)

    overall_transform_matrix = (t1 @ t2)@ t3

    # points are row vectors, hence the right-multiplication
    xyz = torch.matmul(xyz, overall_transform_matrix).numpy()
    normals = np.zeros_like(xyz)

    # ============= Transform rotations =============
    # apply the matching rotation to every Gaussian's quaternion
    camera_transformation_matrix = overall_transform_matrix.inverse()
    camera_quaternions = matrix_to_quaternion(camera_transformation_matrix).to(reconstruction["rotation"].device)
    rotation = reconstruction["rotation"].clone()
    rotation = rotation.unsqueeze(0)
    rotation = quaternion_raw_multiply(camera_quaternions.unsqueeze(0).unsqueeze(0).expand(*rotation.shape),
                                       rotation).squeeze(0)

    f_dc = reconstruction["features_dc"][valid_gaussians].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
    f_rest = reconstruction["features_rest"][valid_gaussians].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
    opacities = reconstruction["opacity"][valid_gaussians].detach().contiguous().cpu().numpy()

    # enlarge Gaussians - otherwise transforming them to .ply results in artefacts
    scale = (reconstruction["scaling"][valid_gaussians] + torch.abs(reconstruction["scaling"][valid_gaussians] * 0.1)).detach().contiguous().cpu().numpy()

    rotation = rotation[valid_gaussians].detach().contiguous().cpu().numpy()

    dtype_full = [(attribute, 'f4') for attribute in construct_list_of_attributes()]

    # one structured record per Gaussian; column order must match
    # construct_list_of_attributes()
    elements = np.empty(xyz.shape[0], dtype=dtype_full)
    attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
    elements[:] = list(map(tuple, attributes))
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(ply_out_path)