import logging
import math
from typing import List, Optional, Sequence, Union

import imageio
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.distributions import Normal
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from pytorch3d.renderer import (
    AlphaCompositor,
    NormWeightedCompositor,
    OrthographicCameras,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
    look_at_view_transform,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch

# Disable unnecessary imageio logging
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)


def rotation_matrix(axis, theta):
    """
    Return the rotation matrix associated with counterclockwise rotation about
    the given axis by theta radians.
    """
    axis = np.asarray(axis)
    axis = axis / np.sqrt(np.dot(axis, axis))
    a = np.cos(theta / 2.0)
    b, c, d = -axis * np.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                     [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                     [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])

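
# Hedged sanity check (illustrative, not part of the original module): a
# counterclockwise rotation of pi/2 about the z-axis maps the x-axis onto
# the y-axis.
def _example_rotation_matrix():
    R = rotation_matrix([0, 0, 1], np.pi / 2)
    v = R.dot(np.array([1.0, 0.0, 0.0]))
    assert np.allclose(v, [0.0, 1.0, 0.0])
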
def rotate(vertices, faces):
    '''
    vertices: [num_points, 3]
    faces: [num_faces, 3]
    Permutes the axes and applies three fixed rotations to bring the mesh into
    a canonical orientation.
    '''
    M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
    N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
    K = rotation_matrix([0, 0, 1], np.pi).transpose()
    v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]]
    return v, f

def norm(v, f):
    # Rescale vertices to fit inside a unit cube centered at the origin
    v = (v - v.min()) / (v.max() - v.min()) - 0.5
    return v, f

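
# Hedged usage sketch: canonicalize a toy mesh with rotate() and norm(); the
# random vertex and face arrays are placeholders for a real mesh.
def _example_rotate_and_norm():
    vertices = np.random.rand(100, 3)
    faces = np.random.randint(0, 100, size=(50, 3))
    v, f = rotate(vertices, faces)
    v, f = norm(v, f)
    print(v.min(), v.max())  # approximately -0.5 and 0.5
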
def getGradNorm(net):
    pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
    # Skip parameters without gradients (e.g. frozen parameters)
    gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters() if p.grad is not None))
    return pNorm, gradNorm

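
# Hedged usage sketch (illustrative, not original code): inspect parameter and
# gradient norms after a backward pass; the tiny linear model is a placeholder.
def _example_grad_norm():
    net = torch.nn.Linear(4, 2)
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    p_norm, grad_norm = getGradNorm(net)
    print(f'param norm: {p_norm.item():.4f}, grad norm: {grad_norm.item():.4f}')
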
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 and m.weight is not None:
        torch.nn.init.xavier_normal_(m.weight)
    elif classname.find('BatchNorm') != -1 and m.weight is not None:
        m.weight.data.normal_()
        m.bias.data.fill_(0)

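
# Hedged usage sketch: weights_init is written for Module.apply, which calls
# it recursively on every submodule; the toy network below is a placeholder.
def _example_weights_init():
    net = torch.nn.Sequential(torch.nn.Conv1d(3, 8, 1), torch.nn.BatchNorm1d(8))
    net.apply(weights_init)
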
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    # Assumes data was discretized and rescaled to [0, 1]
    assert x.shape == means.shape == log_scales.shape
    px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 0.5)
    cdf_plus = px0.cdf(plus_in)
    min_in = inv_stdv * (centered_x - 0.5)
    cdf_min = px0.cdf(min_in)
    log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
    log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min) * 1e-12))
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(
        x < 0.001, log_cdf_plus,
        torch.where(x > 0.999, log_one_minus_cdf_min,
                    torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))))
    assert log_probs.shape == x.shape
    return log_probs

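
# Hedged sketch: evaluate the discretized log-likelihood on toy data. The
# +/- 0.5 bin offsets imply integer-spaced bins; the zero-valued tensors here
# are placeholders, not real data.
def _example_discretized_log_likelihood():
    x = torch.zeros(2, 3)
    means = torch.zeros(2, 3)
    log_scales = torch.zeros(2, 3)
    ll = discretized_gaussian_log_likelihood(x, means=means, log_scales=log_scales)
    print(ll.shape)  # torch.Size([2, 3])
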
def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    import io
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img

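
# Hedged usage sketch (assumes matplotlib is installed; this module does not
# itself import it):
def _example_fig2img():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    img = fig2img(fig)  # a PIL.Image
    plt.close(fig)
    return img
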
def visualize_distance_transform(
    path_stem: str,
    images: torch.Tensor,
) -> str:
    output_file_image = f'{path_stem}.png'
    if images.shape[3] in [1, 3]:  # convert (B, H, W, C) to (B, C, H, W)
        images = images.permute(0, 3, 1, 2)
    images = images[:, -1:]  # (B, 1, H, W) get only distances (not vectors, for now, for simplicity)
    image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1, normalize=True)
    to_pil_image(image_grid).save(output_file_image)
    return output_file_image

def visualize_image(
    path_stem: str,
    images: torch.Tensor,
    mean: Union[torch.Tensor, float] = 0.5,
    std: Union[torch.Tensor, float] = 0.5,
) -> str:
    output_file_image = f'{path_stem}.png'
    if images.shape[3] in [1, 3, 4]:  # convert (B, H, W, C) to (B, C, H, W)
        images = images.permute(0, 3, 1, 2)
    if images.shape[1] in [3, 4]:  # single-channel images are not normalized
        images[:, :3] = images[:, :3] * std + mean  # denormalize (color channels only, not alpha channel)
    if images.shape[1] == 4:  # composite RGBA images over a solid background color
        image_alpha = images[:, 3:]  # (B, 1, H, W)
        bg_color = torch.tensor([230, 220, 250], device=images.device).reshape(1, 3, 1, 1) / 255
        images = images[:, :3] * image_alpha + bg_color * (1 - image_alpha)  # (B, 3, H, W)
    image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1)
    to_pil_image(image_grid).save(output_file_image)
    return output_file_image

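
# Hedged usage sketch: save a grid of random images given in (B, H, W, C)
# layout and normalized with mean=0.5, std=0.5; the path stem 'example' is a
# placeholder.
def _example_visualize_image():
    images = torch.rand(4, 64, 64, 3) * 2 - 1  # normalized to [-1, 1]
    output_path = visualize_image('example', images)
    print(output_path)  # example.png
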
def ensure_point_cloud_has_colors(pointcloud: Pointclouds):
    if pointcloud.features_padded() is None:
        pointcloud = type(pointcloud)(
            points=pointcloud.points_padded(),
            normals=pointcloud.normals_padded(),
            features=torch.zeros_like(pointcloud.points_padded()))
    return pointcloud

def render_pointcloud_batch_pytorch3d(
    cameras: CamerasBase,
    pointclouds: Pointclouds,
    image_size: int = 224,
    radius: float = 0.01,
    points_per_pixel: int = 10,
    background_color: Sequence[float] = (0.78431373, 0.78431373, 0.78431373),
    compositor: str = 'norm_weighted'
):
    # Define the settings for rasterization and shading. The output image is
    # square with side length image_size. Refer to rasterize_points.py in
    # PyTorch3D for explanations of these parameters.
    raster_settings = PointsRasterizationSettings(
        image_size=image_size,
        radius=radius,
        points_per_pixel=points_per_pixel,
    )

    # Rasterizer
    rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)

    # Compositor
    if compositor == 'alpha':
        compositor = AlphaCompositor(background_color=background_color)
    elif compositor == 'norm_weighted':
        compositor = NormWeightedCompositor(background_color=background_color)
    else:
        raise ValueError(compositor)

    # Create a points renderer by compositing points using a weighted compositor
    # (3D points are weighted according to their distance to a pixel and
    # accumulated using a weighted sum)
    renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)

    # We cannot render a point cloud without colors, so add them if the
    # point cloud does not already have them
    pointclouds = ensure_point_cloud_has_colors(pointclouds)

    # Render batch of images
    images = renderer(pointclouds)
    return images

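
# Hedged usage sketch (assumes PyTorch3D is installed): render a random colored
# point cloud from a single orthographic camera; all values are placeholders.
def _example_render_pointcloud():
    R, T = look_at_view_transform(dist=10.0, elev=30, azim=0)
    cameras = OrthographicCameras(focal_length=0.25, R=R, T=T)
    clouds = Pointclouds(points=torch.rand(1, 1024, 3) - 0.5,
                         features=torch.rand(1, 1024, 3))
    images = render_pointcloud_batch_pytorch3d(cameras, clouds)
    print(images.shape)  # (1, 224, 224, 3)
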
def visualize_pointcloud_batch_pytorch3d(
    pointclouds: Pointclouds,
    output_file_video: Optional[str] = None,
    output_file_image: Optional[str] = None,
    cameras: Optional[CamerasBase] = None,  # if None, we rotate
    scale_factor: float = 1.0,
    num_frames: int = 1,  # note that rendering takes a while with 30 * batch_size frames
    elev: int = 30,
):
    """Saves a video and a single image of a point cloud"""
    assert 360 % num_frames == 0, 'please select a number of frames that divides 360'

    # Sizes
    B, N, C, F = *pointclouds.points_padded().shape, num_frames
    device = pointclouds.device

    # If a camera has not been provided, we render from rotating views around the object
    if cameras is None:

        # Create view transforms - R is (F, 3, 3) and T is (F, 3)
        R, T = look_at_view_transform(dist=10.0, elev=elev, azim=list(range(0, 360, 360 // F)), degrees=True, device=device)

        # Repeat cameras and point clouds so that every frame sees every batch element
        R = R.repeat_interleave(B, dim=0)  # (F * B, 3, 3)
        T = T.repeat_interleave(B, dim=0)  # (F * B, 3)
        points = pointclouds.points_padded().tile(F, 1, 1)  # (F * B, num_points, 3)
        colors = (torch.zeros_like(points) if pointclouds.features_padded() is None else
                  pointclouds.features_padded().tile(F, 1, 1))  # (F * B, num_points, 3)

        # Initialize batch of cameras
        cameras = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)

        # Wrap in Pointclouds (with color, even if the original point cloud had no color)
        pointclouds = Pointclouds(points=points, features=colors).to(device)

    # Render images
    images = render_pointcloud_batch_pytorch3d(cameras, pointclouds)

    # Convert images into grids
    image_grids = []
    images_for_grids = images.reshape(F, B, *images.shape[1:]).permute(0, 1, 4, 2, 3)
    for image_for_grids in images_for_grids:
        image_grid = make_grid(image_for_grids, nrow=int(math.sqrt(B)), pad_value=1)
        image_grids.append(image_grid)
    image_grids = torch.stack(image_grids, dim=0)
    image_grids = image_grids.detach().cpu()

    # Save image
    if output_file_image is not None:
        to_pil_image(image_grids[0]).save(output_file_image)

    # Save video
    if output_file_video:
        video = (image_grids * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy()
        imageio.mimwrite(output_file_video, video, fps=10)

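
# Hedged usage sketch: render a rotating 4-frame video and a still image of a
# random two-cloud batch; the output file names are placeholders, and video
# writing assumes an ffmpeg backend for imageio.
def _example_visualize_pointcloud_batch():
    clouds = Pointclouds(points=torch.rand(2, 512, 3) - 0.5)
    visualize_pointcloud_batch_pytorch3d(
        clouds, output_file_video='clouds.mp4', output_file_image='clouds.png',
        num_frames=4)
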
def visualize_pointcloud_evolution_pytorch3d(
    pointclouds: Pointclouds,
    output_file_video: str,
    camera: Optional[CamerasBase] = None,  # if None, we use a fixed default view
    scale_factor: float = 1.0,
):
    # Device
    B, device = len(pointclouds), pointclouds.device

    # Cameras
    if camera is None:
        R, T = look_at_view_transform(dist=10.0, elev=30, azim=0, device=device)
        camera = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)

    # Render each point cloud in the batch as one video frame
    frames = render_pointcloud_batch_pytorch3d(camera, pointclouds)

    # Save video
    video = (frames.detach().cpu() * 255).to(torch.uint8).numpy()
    imageio.mimwrite(output_file_video, video, fps=10)

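
# Hedged usage sketch: treat a batch of point clouds as successive time steps
# (e.g. diffusion sampling steps) and write one frame per cloud; the file name
# is a placeholder and video writing assumes an ffmpeg backend.
def _example_visualize_evolution():
    steps = Pointclouds(points=torch.rand(8, 512, 3) - 0.5)
    visualize_pointcloud_evolution_pytorch3d(steps, 'evolution.mp4')
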
def get_camera_index(cameras: CamerasBase, index: Optional[int] = None):
    if index is None:
        return cameras
    kwargs = dict(
        R=cameras.R[index].unsqueeze(0),
        T=cameras.T[index].unsqueeze(0),
        K=cameras.K[index].unsqueeze(0) if cameras.K is not None else None,
    )
    if hasattr(cameras, 'focal_length'):
        kwargs['focal_length'] = cameras.focal_length[index].unsqueeze(0)
    if hasattr(cameras, 'principal_point'):
        kwargs['principal_point'] = cameras.principal_point[index].unsqueeze(0)
    return type(cameras)(**kwargs).to(cameras.device)

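
# Hedged usage sketch: slice a single camera out of a batch of three views.
def _example_get_camera_index():
    R, T = look_at_view_transform(dist=10.0, elev=30, azim=[0, 90, 180])
    cameras = OrthographicCameras(R=R, T=T)
    cam0 = get_camera_index(cameras, 0)
    print(cam0.R.shape)  # torch.Size([1, 3, 3])
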
def get_metadata(item) -> str:
    """Returns a string summary of the keys and values of a dataset item"""
    s = '-------------\n'
    for key in item.keys():
        value = item[key]
        if torch.is_tensor(value) and value.numel() < 25:
            value_str = value
        elif torch.is_tensor(value):
            value_str = value.shape
        elif isinstance(value, str):
            value_str = value
        elif isinstance(value, list) and 0 < len(value) < 25 and isinstance(value[0], str):
            value_str = value
        elif isinstance(value, dict):
            value_str = str({k: type(v) for k, v in value.items()})
        else:
            value_str = type(value)
        s += f"{key:<30} {value_str}\n"
    return s
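

# Hedged usage sketch: summarize a dataset item; the keys below are placeholders.
def _example_get_metadata():
    item = {'points': torch.rand(1024, 3), 'id': 'chair_001', 'scale': 1.0}
    print(get_metadata(item))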