# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
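"""Visualization utilities for 3D point tracks: read a video, colorize
per-track (x, y, depth) trajectories, and save tracking-overlay videos."""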
import os
import numpy as np
import cv2
import torch
import flow_vis
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt
from tqdm import tqdm
def read_video_from_path(path):
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        print("Error opening video file")
        return None
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return np.stack(frames)
class Visualizer:
def __init__(
self,
save_dir: str = "./results",
grayscale: bool = False,
pad_value: int = 0,
fps: int = 10,
mode: str = "rainbow", # 'cool', 'optical_flow'
linewidth: int = 1,
show_first_frame: int = 10,
tracks_leave_trace: int = 0, # -1 for infinite
):
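        """Configure drawing. `mode` selects track coloring: 'rainbow' and
        'cool' map query-frame position through a matplotlib colormap, while
        'optical_flow' colors points by their displacement from the query
        frame. `tracks_leave_trace` controls how many past positions stay
        visible (-1 keeps the full history)."""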
self.mode = mode
self.save_dir = save_dir
self.vtxt_path = os.path.join(save_dir, "videos.txt")
self.ttxt_path = os.path.join(save_dir, "trackings.txt")
        if mode == "rainbow":
            self.color_map = plt.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = plt.get_cmap(mode)
self.show_first_frame = show_first_frame
self.grayscale = grayscale
self.tracks_leave_trace = tracks_leave_trace
self.pad_value = pad_value
self.linewidth = linewidth
self.fps = fps
def visualize(
self,
video: torch.Tensor, # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,3), last dim is (x, y, depth)
visibility: torch.Tensor = None, # (B, T, N, 1) bool
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer=None, # tensorboard Summary Writer, used for visualization during training
step: int = 0,
query_frame: int = 0,
save_video: bool = True,
compensate_for_camera_motion: bool = False,
        rigid_part=None,  # optional per-track rigid-cluster labels
        video_depth=None,  # (B,T,1,H,W), assumed normalized to [0, 1]
):
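        """Overlay `tracks` on `video` and optionally save the result.

        Tracks are (x, y, depth) per point per frame; depth drives both the
        point colors and the back-to-front drawing order. Returns the
        tracking video as a (1, T, 3, H, W) uint8 tensor.
        """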
if compensate_for_camera_motion:
assert segm_mask is not None
        if segm_mask is not None:
            # sample the mask at each track's query-frame location to get per-track labels
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
video = F.pad(
video,
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
"constant",
255,
)
        if video_depth is not None:
            # colorize depth maps (assumed normalized to [0, 1]) with an inferno colormap
            video_depth = (video_depth * 255).cpu().numpy().astype(np.uint8)
            video_depth = [
                cv2.applyColorMap(video_depth[0, i, 0], cv2.COLORMAP_INFERNO)
                for i in range(video_depth.shape[1])
            ]
            video_depth = np.stack(video_depth, axis=0)
            video_depth = torch.from_numpy(video_depth).permute(0, 3, 1, 2)[None]
tracks = tracks + self.pad_value
if self.grayscale:
transform = transforms.Grayscale()
video = transform(video)
video = video.repeat(1, 1, 3, 1, 1)
tracking_video = self.draw_tracks_on_video(
video=video,
tracks=tracks,
visibility=visibility,
segm_mask=segm_mask,
gt_tracks=gt_tracks,
query_frame=query_frame,
compensate_for_camera_motion=compensate_for_camera_motion,
rigid_part=rigid_part
)
if save_video:
tracking_dir = os.path.join(self.save_dir, "tracking")
if not os.path.exists(tracking_dir):
os.makedirs(tracking_dir)
self.save_video(tracking_video, filename=filename+"_tracking",
savedir=tracking_dir, writer=writer, step=step)
videos_dir = os.path.join(self.save_dir, "videos")
if not os.path.exists(videos_dir):
os.makedirs(videos_dir)
self.save_video(video, filename=filename,
savedir=videos_dir, writer=writer, step=step)
if video_depth is not None:
self.save_video(video_depth, filename=filename+"_depth",
savedir=os.path.join(self.save_dir, "depth"), writer=writer, step=step)
return tracking_video
def save_video(self, video, filename, savedir=None, writer=None, step=0):
if writer is not None:
writer.add_video(
f"{filename}",
video.to(torch.uint8),
global_step=step,
fps=self.fps,
)
else:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
clip = ImageSequenceClip(wide_list, fps=self.fps)
# Write the video file
if savedir is None:
save_path = os.path.join(self.save_dir, f"{filename}.mp4")
else:
save_path = os.path.join(savedir, f"{filename}.mp4")
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
print(f"Video saved to {save_path}")
def draw_tracks_on_video(
self,
video: torch.Tensor,
tracks: torch.Tensor,
visibility: torch.Tensor = None,
segm_mask: torch.Tensor = None,
gt_tracks=None,
query_frame: int = 0,
compensate_for_camera_motion=False,
rigid_part=None,
):
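        """Draw colored tracks on black frames the size of the input video;
        points are sorted by depth so nearer points are drawn on top."""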
B, T, C, H, W = video.shape
_, _, N, D = tracks.shape
assert D == 3
assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # T, H, W, C
        tracks = tracks[0].detach().cpu().numpy()  # T, N, 3
if gt_tracks is not None:
gt_tracks = gt_tracks[0].detach().cpu().numpy()
        res_video = []
        # draw on black frames with the same shape as the input video
        for rgb in video:
            res_video.append(np.zeros_like(rgb))
vector_colors = np.zeros((T, N, 3))
        if self.mode == "optical_flow":
            # flow_vis expects 2-channel (u, v) displacements, so drop the depth channel
            vector_colors = flow_vis.flow_to_color(
                tracks[..., :2] - tracks[query_frame][None, :, :2]
            )
elif segm_mask is None:
if self.mode == "rainbow":
                x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
                y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
                # use robust percentiles of inverse depth so outliers do not skew the colors
                z_inv = 1 / tracks[0, :, 2]
                z_min, z_max = np.percentile(z_inv, [2, 98])
                norm_x = plt.Normalize(x_min, x_max)
                norm_y = plt.Normalize(y_min, y_max)
                norm_z = plt.Normalize(z_min, z_max)
                for n in range(N):
                    # map each track's query-frame (x, y, 1/z) to an RGB color
                    r = norm_x(tracks[0, n, 0])
                    g = norm_y(tracks[0, n, 1])
                    b = norm_z(1 / tracks[0, n, 2])
                    color = np.array([r, g, b])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(self.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
else:
if self.mode == "rainbow":
vector_colors[:, segm_mask <= 0, :] = 255
x_min, x_max = tracks[0, :, 0].min(), tracks[0, :, 0].max()
y_min, y_max = tracks[0, :, 1].min(), tracks[0, :, 1].max()
z_min, z_max = tracks[0, :, 2].min(), tracks[0, :, 2].max()
norm_x = plt.Normalize(x_min, x_max)
norm_y = plt.Normalize(y_min, y_max)
norm_z = plt.Normalize(z_min, z_max)
for n in range(N):
r = norm_x(tracks[0, n, 0])
g = norm_y(tracks[0, n, 1])
b = norm_z(tracks[0, n, 2])
color = np.array([r, g, b])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with segm class
segm_mask = segm_mask.cpu()
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
vector_colors = np.repeat(color[None], T, axis=0)
# Draw tracks
if self.tracks_leave_trace != 0:
for t in range(1, T):
first_ind = (
max(0, t - self.tracks_leave_trace)
if self.tracks_leave_trace >= 0
else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
if compensate_for_camera_motion:
diff = (
tracks[first_ind : t + 1, segm_mask <= 0]
- tracks[t : t + 1, segm_mask <= 0]
).mean(1)[:, None]
curr_tracks = curr_tracks - diff
curr_tracks = curr_tracks[:, segm_mask > 0]
curr_colors = curr_colors[:, segm_mask > 0]
res_video[t] = self._draw_pred_tracks(
res_video[t],
curr_tracks,
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(
res_video[t], gt_tracks[first_ind : t + 1]
)
if rigid_part is not None:
            cls_label = torch.unique(rigid_part)
            cls_num = len(cls_label)
            # visualize the clustering results: one color per rigid-motion cluster
            cmap = plt.get_cmap("jet")
            colors = cmap(np.linspace(0, 1, cls_num))
            colors = colors[:, :3] * 255
            color_map = {label.item(): color for label, color in zip(cls_label, colors)}
# Draw points
        for t in tqdm(range(T)):
            # Collect (index, coord, depth, visibility) for each drawable point
            points_info = []
            for i in range(N):
                coord = (int(tracks[t, i, 0]), int(tracks[t, i, 1]))
                depth = tracks[t, i, 2]  # the third track dimension is depth
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        points_info.append((i, coord, depth, visible))
            # Sort by depth so nearer points (smaller depth) are drawn last, on top
            points_info.sort(key=lambda x: x[2], reverse=True)
            for i, coord, _, visible in points_info:
                if rigid_part is not None:
                    color = color_map[rigid_part.squeeze()[i].item()]
                    cv2.circle(
                        res_video[t],
                        coord,
                        int(self.linewidth * 2),
                        color.tolist(),
                        thickness=-1 if visible else 1,  # filled if visible, thin outline otherwise
                    )
                else:
                    # Size the rectangle on the first frame from the nearest-neighbor
                    # distance between tracks, and reuse it for later frames
                    if t == 0:
                        distances = np.linalg.norm(tracks[0] - tracks[0, i], axis=1)
                        distances = distances[distances > 0]
                        rect_size = int(np.min(distances)) / 2
                    # Rectangle is 1.5x wider than tall (video aspect ratio is 1.5:1)
                    top_left = (int(coord[0] - rect_size), int(coord[1] - rect_size / 1.5))
                    bottom_right = (int(coord[0] + rect_size), int(coord[1] + rect_size / 1.5))
                    cv2.rectangle(
                        res_video[t],
                        top_left,
                        bottom_right,
                        vector_colors[t, i].tolist(),
                        thickness=-1,  # the original "0 - 1" collapses to -1, i.e. always filled
                    )
# Construct the final rgb sequence
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
def _draw_pred_tracks(
self,
rgb: np.ndarray, # H x W x 3
        tracks: np.ndarray,  # T x N x 2
        vector_colors: np.ndarray,  # T x N x 3
alpha: float = 0.5,
):
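        """Connect consecutive track positions with anti-aliased line
        segments; when traces are enabled, older segments are blended into
        the frame so the trace fades over time."""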
T, N, _ = tracks.shape
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
            alpha = (s / T) ** 2  # older segments fade more (overrides the alpha argument)
for i in range(N):
                start = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                end = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if start[0] != 0 and start[1] != 0:
                    cv2.line(
                        rgb,
                        start,
                        end,
vector_color[i].tolist(),
self.linewidth,
cv2.LINE_AA,
)
if self.tracks_leave_trace > 0:
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
return rgb
def _draw_gt_tracks(
self,
        rgb: np.ndarray,  # H x W x 3
        gt_tracks: np.ndarray,  # T x N x 2
):
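        """Mark each ground-truth position with a red diagonal cross."""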
        T, N, _ = gt_tracks.shape
        color = (211, 0, 0)
        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]  # single point; do not overwrite the array
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_a = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_b = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_a,
                        coord_b,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    coord_a = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_b = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_a,
                        coord_b,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
return rgb
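# Minimal usage sketch (not part of the original pipeline). The video path is a
# hypothetical placeholder, and the random tracks stand in for the (x, y, depth)
# output of a point-tracking model.
if __name__ == "__main__":
    video_np = read_video_from_path("./assets/example.mp4")  # hypothetical path
    if video_np is not None:
        video = torch.from_numpy(video_np).permute(0, 3, 1, 2)[None].float()  # (B,T,C,H,W)
        T = video.shape[1]
        N = 16  # number of illustrative tracks
        tracks = torch.rand(1, T, N, 3)  # random (x, y, depth) tracks, for illustration only
        tracks[..., 0] *= video.shape[-1]  # scale x into image width
        tracks[..., 1] *= video.shape[-2]  # scale y into image height
        vis = Visualizer(save_dir="./results", fps=10)
        vis.visualize(video=video, tracks=tracks, filename="demo")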