|
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
from mpl_toolkits.mplot3d import Axes3D |
|
from pytorch3d.vis.plotly_vis import get_camera_wireframe |
|
|
|
|
|
def plot_cameras(ax, cameras, color: str = "blue"):
    """
    Plots a set of `cameras` objects into the matplotlib 3D axis `ax` with
    color `color`.

    Each camera is drawn as the canonical wireframe returned by
    `get_camera_wireframe`, mapped into world space by the inverse of the
    camera's world-to-view transform.

    Args:
        ax: A matplotlib 3D axis (e.g. one created with
            `fig.add_subplot(projection="3d")`).
        cameras: A batch of PyTorch3D cameras exposing
            `get_world_to_view_transform()`.
        color: Any matplotlib color specification for the wireframe lines.

    Returns:
        A list with one matplotlib line handle per camera, in batch order.
    """
    cam_trans = cameras.get_world_to_view_transform().inverse()
    # Match the wireframe to the device/dtype of the camera transform instead
    # of hard-coding `.cuda()`, so CPU (and non-float32) cameras work too.
    cam_mat = cam_trans.get_matrix()
    cam_wires_canonical = get_camera_wireframe().to(
        device=cam_mat.device, dtype=cam_mat.dtype
    )[None]
    cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
    # One device->host transfer for the whole batch rather than per camera.
    cam_wires_np = cam_wires_trans.detach().cpu().numpy().astype(float)
    plot_handles = []
    for wire in cam_wires_np:
        # The Z and Y axes are swapped intentionally: world y is drawn on
        # matplotlib's z axis (vertical in the default 3D view) and world z
        # on matplotlib's y axis.
        x_, z_, y_ = wire.T
        (h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
        plot_handles.append(h)
    return plot_handles
|
|
|
|
|
def plot_camera_scene(cameras, cameras_gt, status: str):
    """
    Plots a set of predicted cameras `cameras` and their corresponding
    ground truth locations `cameras_gt`. The plot is titled with the
    string passed in the `status` argument.

    The figure is shown with `plt.show()` before it is returned.

    Args:
        cameras: Batch of estimated cameras, drawn in orange.
        cameras_gt: Batch of ground-truth cameras, drawn in purple.
        status: Title of the plot.

    Returns:
        The created matplotlib figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.set_title(status)
    handle_cam = plot_cameras(ax, cameras, color="#FF7D1E")
    handle_cam_gt = plot_cameras(ax, cameras_gt, color="#812CE5")
    plot_radius = 3
    ax.set_xlim3d([-plot_radius, plot_radius])
    # The plot's y axis shows world z (plot_cameras swaps y and z), and its
    # range is centred at 3 instead of 0.
    # NOTE(review): presumably the scene's cameras sit around z = 3 — confirm.
    ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
    ax.set_zlim3d([-plot_radius, plot_radius])
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    ax.set_zlabel("y")
    # Use the first line of each camera set as its legend proxy. Empty sets
    # are skipped, so an empty batch no longer raises IndexError.
    labels_handles = {
        label: handles[0]
        for label, handles in (
            ("Estimated cameras", handle_cam),
            ("GT cameras", handle_cam_gt),
        )
        if handles
    }
    if labels_handles:
        ax.legend(
            list(labels_handles.values()),
            list(labels_handles.keys()),
            loc="upper center",
            bbox_to_anchor=(0.5, 0),
        )
    plt.show()
    return fig
|
|