Vincentqyw
update: limit keypoints number
e15a186
raw history blame
No virus
5.26 kB
"""
3D visualization based on plotly.
Works for a small number of points and cameras, might be slow otherwise.
1) Initialize a figure with `init_figure`
2) Add 3D points, camera frustums, or both as a pycolmap.Reconstruction
Written by Paul-Edouard Sarlin and Philipp Lindenberger.
"""
from typing import Optional
import numpy as np
import pycolmap
import plotly.graph_objects as go
def to_homogeneous(points):
    """Append a constant 1 as an extra last coordinate of every point.

    Args:
        points: array whose last axis holds the point coordinates; any
            number of leading axes is accepted.

    Returns:
        A new array with the same leading shape and dtype as `points`,
        whose last axis is one element longer and ends with 1.
    """
    # Same leading dimensions as the input, with a last axis of length 1.
    ones_shape = (*points.shape[:-1], 1)
    ones = np.ones(ones_shape, dtype=points.dtype)
    return np.concatenate((points, ones), axis=-1)
def init_figure(height: int = 800) -> go.Figure:
    """Create an empty, dark-themed plotly figure set up for 3D content.

    The three scene axes are hidden, the scene camera uses an orthographic
    projection, margins are zero and the legend is laid out horizontally.

    Args:
        height: figure height passed to the layout.

    Returns:
        The configured figure, with no traces added yet.
    """
    # One settings dict shared by the x, y and z axes.
    axis_style = {
        "visible": False,
        "showbackground": False,
        "showgrid": False,
        "showline": False,
        "showticklabels": True,
        "autorange": True,
    }
    camera = {
        "eye": {"x": 0.0, "y": -0.1, "z": -2},
        "up": {"x": 0, "y": -1.0, "z": 0},
        "projection": {"type": "orthographic"},
    }
    scene = {
        "xaxis": axis_style,
        "yaxis": axis_style,
        "zaxis": axis_style,
        "aspectmode": "data",
        "dragmode": "orbit",
    }
    legend = {
        "orientation": "h",
        "yanchor": "top",
        "y": 0.99,
        "xanchor": "left",
        "x": 0.1,
    }
    margin = {"l": 0, "r": 0, "b": 0, "t": 0, "pad": 0}

    figure = go.Figure()
    figure.update_layout(
        template="plotly_dark",
        height=height,
        scene_camera=camera,
        scene=scene,
        margin=margin,
        legend=legend,
    )
    return figure
def plot_points(
    fig: go.Figure,
    pts: np.ndarray,
    color: str = "rgba(255, 0, 0, 1)",
    ps: int = 2,
    colorscale: Optional[str] = None,
    name: Optional[str] = None,
):
    """Add a set of 3D points to the figure as a marker-only trace.

    Args:
        fig: figure that receives the new trace.
        pts: point array; its transpose is unpacked into the x, y and z
            coordinate sequences, so the transpose must have exactly
            three rows.
        color: marker color.
        ps: marker size.
        colorscale: optional colorscale forwarded to the marker.
        name: trace name, also used as its legend group.
    """
    xs, ys, zs = pts.T
    marker_style = dict(
        size=ps, color=color, line_width=0.0, colorscale=colorscale
    )
    fig.add_trace(
        go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode="markers",
            name=name,
            legendgroup=name,
            marker=marker_style,
        )
    )
def plot_camera(
    fig: go.Figure,
    R: np.ndarray,
    t: np.ndarray,
    K: np.ndarray,
    color: str = "rgb(0, 0, 255)",
    name: Optional[str] = None,
    legendgroup: Optional[str] = None,
    size: float = 1.0,
):
    """Draw one camera frustum from a pose and an intrinsic matrix.

    Three traces are added to `fig`, in this order: the outline of the image
    rectangle, a filled pyramid joining `t` to that rectangle, and thin lines
    along the pyramid's triangle edges. None of them appears in the legend.

    Args:
        fig: figure that receives the traces.
        R: rotation applied to the frustum corners (as `corners @ R.T`);
            presumably camera-to-world, to be confirmed against callers.
        t: translation added after the rotation; also the pyramid apex.
        K: 3x3 intrinsic matrix. The image size is taken as twice the
            values stored in `K[0, 2]` and `K[1, 2]`.
        color: color shared by all three traces.
        name: name assigned to the traces.
        legendgroup: legend group assigned to the traces.
        size: frustum size factor; with `None` a fixed scale of 1.0 is used.
    """
    width = K[0, 2] * 2
    height = K[1, 2] * 2
    # Closed loop over the image corners: the first corner is repeated so
    # the outline trace ends where it started.
    outline_px = np.array(
        [[0, 0], [width, 0], [width, height], [0, height], [0, 0]]
    )

    scale = 1.0
    if size is not None:
        image_extent = max(size * width / 1024.0, size * height / 1024.0)
        world_extent = max(width, height) / (K[0, 0] + K[1, 1]) / 0.5
        scale = 0.5 * image_extent / world_extent

    # Pixel corners -> rays through inv(K), shrunk by the scale, then moved
    # with the given rotation and translation.
    rays = to_homogeneous(outline_px) @ np.linalg.inv(K).T
    outline = (rays / 2 * scale) @ R.T + t

    out_x, out_y, out_z = outline.T
    fig.add_trace(
        go.Scatter3d(
            x=out_x,
            y=out_y,
            z=out_z,
            line=dict(color=color),
            legendgroup=legendgroup,
            name=name,
            marker=dict(size=0.0001),
            showlegend=False,
        )
    )

    # Vertex 0 is `t`; vertices 1..5 are the outline points.
    vertices = np.concatenate(([t], outline))
    apex_idx = [0, 0, 0, 0]
    first_idx = [1, 2, 3, 4]
    second_idx = [2, 3, 4, 1]

    vert_x, vert_y, vert_z = vertices.T
    fig.add_trace(
        go.Mesh3d(
            x=vert_x,
            y=vert_y,
            z=vert_z,
            color=color,
            i=apex_idx,
            j=first_idx,
            k=second_idx,
            legendgroup=legendgroup,
            name=name,
            showlegend=False,
        )
    )

    # Visit the vertices triangle by triangle to draw the pyramid edges.
    visit_order = np.vstack((apex_idx, first_idx, second_idx)).T.reshape(-1)
    edge_x, edge_y, edge_z = vertices[visit_order].T
    fig.add_trace(
        go.Scatter3d(
            x=edge_x,
            y=edge_y,
            z=edge_z,
            mode="lines",
            legendgroup=legendgroup,
            name=name,
            line=dict(color=color, width=1),
            showlegend=False,
        )
    )
def plot_camera_colmap(
    fig: go.Figure,
    image: pycolmap.Image,
    camera: pycolmap.Camera,
    name: Optional[str] = None,
    **kwargs
):
    """Plot a camera frustum described by PyCOLMAP image and camera objects.

    Args:
        fig: figure that receives the frustum traces.
        image: supplies the pose: the transpose of `image.rotmat()` is used
            as rotation and `image.projection_center()` as translation.
        camera: supplies the intrinsics via `camera.calibration_matrix()`.
        name: trace name; when empty or None, the image id is used instead.
        **kwargs: forwarded unchanged to `plot_camera`.
    """
    rotation = image.rotmat().T
    center = image.projection_center()
    intrinsics = camera.calibration_matrix()
    label = name or str(image.image_id)
    plot_camera(fig, rotation, center, intrinsics, name=label, **kwargs)
def plot_cameras(
    fig: go.Figure, reconstruction: pycolmap.Reconstruction, **kwargs
):
    """Plot a camera frustum for every image of a reconstruction.

    Args:
        fig: figure that receives the frustum traces.
        reconstruction: its `images` are iterated, and each image's camera
            is looked up in `cameras` through `image.camera_id`.
        **kwargs: forwarded unchanged to `plot_camera_colmap` for each image.
    """
    camera_by_id = reconstruction.cameras
    for _, image in reconstruction.images.items():
        plot_camera_colmap(
            fig, image, camera_by_id[image.camera_id], **kwargs
        )
def plot_reconstruction(
    fig: go.Figure,
    rec: pycolmap.Reconstruction,
    max_reproj_error: float = 6.0,
    color: str = "rgb(0, 0, 255)",
    name: Optional[str] = None,
    min_track_length: int = 2,
    points: bool = True,
    cameras: bool = True,
    cs: float = 1.0,
):
    """Plot the 3D points and/or the camera frustums of a reconstruction.

    Args:
        fig: figure that receives the traces.
        rec: reconstruction providing `points3D`, `images` and `cameras`.
        max_reproj_error: points whose `error` exceeds this are dropped.
        color: color used for both the points and the cameras.
        name: name of the point trace; also the legend group of the cameras.
        min_track_length: points with a shorter track are dropped.
        points: whether to plot the filtered 3D points.
        cameras: whether to plot the camera frustums.
        cs: camera frustum size, forwarded as `size`.
    """
    if points:
        # Bounding box used to reject outliers; the two arguments are
        # presumably the lower/upper percentiles kept by pycolmap.
        bbs = rec.compute_bounding_box(0.001, 0.999)
        box_min, box_max = bbs[0], bbs[1]
        # Filter points, use original reproj error here
        xyzs = [
            p3D.xyz
            for _, p3D in rec.points3D.items()
            if (
                (p3D.xyz >= box_min).all()
                and (p3D.xyz <= box_max).all()
                and p3D.error <= max_reproj_error
                and p3D.track.length() >= min_track_length
            )
        ]
        # An empty list becomes a 1-D array of shape (0,), which cannot be
        # unpacked into x/y/z when plotted; force an (N, 3) layout so a
        # reconstruction with no surviving point yields an empty trace
        # instead of a ValueError.
        pts = np.array(xyzs).reshape(-1, 3)
        plot_points(fig, pts, color=color, ps=1, name=name)
    if cameras:
        plot_cameras(fig, rec, color=color, legendgroup=name, size=cs)