# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os | |
import numpy as np | |
import cv2 | |
import torch | |
import flow_vis | |
from matplotlib import cm | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from moviepy.editor import ImageSequenceClip | |
import matplotlib.pyplot as plt | |
from tqdm import tqdm | |
def read_video_from_path(path):
    """Decode every frame of a video file into one RGB array.

    Args:
        path: Location of the video, handed to ``cv2.VideoCapture``.

    Returns:
        The decoded frames, each converted from BGR to RGB and stacked
        along a new leading axis with ``np.stack``; ``None`` when the
        file cannot be opened (an error message is printed in that case).

    Raises:
        ValueError: Raised by ``np.stack`` when the file opens but no
            frame could be read.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        print("Error opening video file")
        # Previously execution fell through to np.stack(frames) with
        # ``frames`` unbound, which raised UnboundLocalError.
        return None

    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture even if decoding raises.
        cap.release()
    return np.stack(frames)
class Visualizer:
    """Draw point tracks onto blank frames and save the results as videos.

    Tracks are expected to carry three values per point (x, y, depth);
    ``draw_tracks_on_video`` asserts this.
    """

    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 1,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        """Store the rendering options.

        Args:
            save_dir: Root directory for the written videos.
            grayscale: Convert the input video to grayscale before saving.
            pad_value: Border width (pixels) added around every frame.
            fps: Frame rate of the written videos.
            mode: Coloring scheme: ``"rainbow"``, ``"cool"`` or
                ``"optical_flow"``.
            linewidth: Line width of traces; also scales marker sizes.
            show_first_frame: Stored on the instance.
            tracks_leave_trace: Number of past frames whose motion is drawn
                as a trace; ``0`` disables traces, a negative value keeps
                the whole history.
        """
        self.mode = mode
        self.save_dir = save_dir
        self.vtxt_path = os.path.join(save_dir, "videos.txt")
        self.ttxt_path = os.path.join(save_dir, "trackings.txt")
        # plt.get_cmap is used instead of cm.get_cmap, which newer
        # Matplotlib releases no longer provide.
        if mode == "rainbow":
            self.color_map = plt.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = plt.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,3): x, y, depth
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        rigid_part=None,
        video_depth=None,  # (B,T,C,H,W)
    ):
        """Render the tracks and optionally save the resulting videos.

        Args:
            video: Input frames, ``(B, T, C, H, W)``.
            tracks: Point tracks, ``(B, T, N, 3)`` holding x, y and depth.
            visibility: Optional per-point visibility flags.
            gt_tracks: Optional ground-truth tracks drawn as red crosses.
            segm_mask: Optional segmentation mask; sampled at the track
                positions of ``query_frame`` to get one label per point.
            filename: Base name of the written files.
            writer: Optional object with an ``add_video`` method; when
                given, videos are logged there instead of written to disk.
            step: Global step forwarded to ``writer``.
            query_frame: Frame used for mask sampling and as the reference
                frame in ``"optical_flow"`` mode.
            save_video: Whether to save the tracking video, the (padded)
                input video and, if present, the depth video.
            compensate_for_camera_motion: Subtract the mean motion of
                points outside the mask; requires ``segm_mask``.
            rigid_part: Optional per-point cluster labels; when given,
                points are drawn as circles colored by label.
            video_depth: Optional depth frames, ``(B, T, C, H, W)``;
                scaled by 255 and color-mapped with ``COLORMAP_INFERNO``.

        Returns:
            The rendered tracking video as produced by
            ``draw_tracks_on_video``.
        """
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            # One mask value per track, read at the query-frame positions.
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )

        if video_depth is not None:
            depth_frames = (video_depth * 255).cpu().numpy().astype(np.uint8)
            colored = [
                cv2.applyColorMap(depth_frames[0, i, 0], cv2.COLORMAP_INFERNO)
                for i in range(depth_frames.shape[1])
            ]
            video_depth = torch.from_numpy(np.stack(colored, axis=0))
            video_depth = video_depth.permute(0, 3, 1, 2)[None]

        # The padding is spatial: shift x and y only. Adding the pad to the
        # whole last axis would also offset the depth channel.
        tracks = tracks.clone()
        tracks[..., :2] += self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        tracking_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            rigid_part=rigid_part,
        )

        if save_video:
            tracking_dir = os.path.join(self.save_dir, "tracking")
            os.makedirs(tracking_dir, exist_ok=True)
            self.save_video(
                tracking_video,
                filename=filename + "_tracking",
                savedir=tracking_dir,
                writer=writer,
                step=step,
            )

            videos_dir = os.path.join(self.save_dir, "videos")
            os.makedirs(videos_dir, exist_ok=True)
            self.save_video(
                video, filename=filename, savedir=videos_dir, writer=writer, step=step
            )

            if video_depth is not None:
                # save_video creates this directory before writing.
                self.save_video(
                    video_depth,
                    filename=filename + "_depth",
                    savedir=os.path.join(self.save_dir, "depth"),
                    writer=writer,
                    step=step,
                )
        return tracking_video

    def save_video(self, video, filename, savedir=None, writer=None, step=0):
        """Log ``video`` to ``writer`` or write it to ``<dir>/<filename>.mp4``.

        Args:
            video: Tensor indexed as ``(B, T, C, H, W)``; only the first
                batch element is written to disk.
            filename: File name without extension (also the writer tag).
            savedir: Target directory; defaults to ``self.save_dir``.
            writer: Optional object with an ``add_video`` method. When
                given, nothing is written to disk.
            step: Global step forwarded to ``writer``.
        """
        if writer is not None:
            writer.add_video(
                f"{filename}",
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
            return

        os.makedirs(self.save_dir, exist_ok=True)
        target_dir = self.save_dir if savedir is None else savedir
        # The target may be a sub-directory nobody created yet (e.g. "depth").
        os.makedirs(target_dir, exist_ok=True)

        frames = [
            frame[0].permute(1, 2, 0).cpu().numpy() for frame in video.unbind(1)
        ]
        clip = ImageSequenceClip(frames, fps=self.fps)

        # Write the video file
        save_path = os.path.join(target_dir, f"{filename}.mp4")
        clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
        print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
        rigid_part=None,
    ):
        """Draw the tracks on black frames shaped like ``video``.

        Args:
            video: ``(B, T, C, H, W)`` tensor; only its shape is used, the
                output starts from all-black frames.
            tracks: ``(B, T, N, 3)`` tensor of x, y, depth.
            visibility: Optional flags indexed as ``[0, t, i]``.
            segm_mask: Optional per-point labels (length ``N``).
            gt_tracks: Optional ground-truth tracks.
            query_frame: Reference frame for ``"optical_flow"`` coloring.
            compensate_for_camera_motion: Remove the mean motion of points
                with ``segm_mask <= 0`` from the traces and draw only
                points with ``segm_mask > 0``.
            rigid_part: Optional per-point labels; switches the markers
                from rectangles to label-colored circles.

        Returns:
            ``uint8`` tensor of shape ``(1, T, 3, H, W)``.
        """
        _, T, C, H, W = video.shape
        _, _, N, D = tracks.shape
        assert D == 3
        assert C == 3

        tracks = tracks[0].detach().cpu().numpy()  # T, N, 3
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()
        if torch.is_tensor(segm_mask):
            # NumPy arrays are indexed with this mask below.
            segm_mask = segm_mask.detach().cpu().numpy()

        # Black, C-contiguous canvases (OpenCV draws in place on them).
        res_video = [np.zeros((H, W, C), dtype=np.uint8) for _ in range(T)]

        vector_colors = self._compute_track_colors(tracks, segm_mask, query_frame)

        # Draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace)
                    if self.tracks_leave_trace >= 0
                    else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    background = segm_mask <= 0
                    foreground = segm_mask > 0
                    diff = (
                        tracks[first_ind : t + 1, background]
                        - tracks[t : t + 1, background]
                    ).mean(1)[:, None]
                    curr_tracks = (curr_tracks - diff)[:, foreground]
                    curr_colors = curr_colors[:, foreground]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        part_labels = None
        label_colors = None
        rect_sizes = None
        if rigid_part is not None:
            # One 'jet' color per distinct cluster label.
            cls_label = torch.unique(rigid_part)
            palette = plt.get_cmap('jet')(np.linspace(0, 1, len(cls_label)))
            palette = palette[:, :3] * 255
            label_colors = {
                label.item(): color for label, color in zip(cls_label, palette)
            }
            part_labels = rigid_part.squeeze()
        else:
            rect_sizes = self._first_frame_rect_sizes(tracks[0])

        # Draw points
        for t in tqdm(range(T)):
            # Collect (index, coord, depth, visible) for every drawable point.
            points_info = []
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                depth = tracks[t, i, 2]  # assume the third dimension is depth
                visibile = True
                if visibility is not None:
                    visibile = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or segm_mask[i] > 0:
                        points_info.append((i, coord, depth, visibile))

            # Far points first, so nearer points are painted on top.
            points_info.sort(key=lambda info: info[2], reverse=True)

            for i, coord, _, visibile in points_info:
                if part_labels is not None:
                    color = label_colors[part_labels[i].item()]
                    cv2.circle(
                        res_video[t],
                        # OpenCV requires integer pixel coordinates.
                        (int(coord[0]), int(coord[1])),
                        int(self.linewidth * 2),
                        color.tolist(),
                        # Filled when visible, 1-pixel outline otherwise.
                        thickness=-1 if visibile else 1,
                    )
                else:
                    rect_size = rect_sizes[i]
                    # Height is width / 1.5 (video aspect ratio is 1.5:1).
                    top_left = (
                        int(coord[0] - rect_size),
                        int(coord[1] - rect_size / 1.5),
                    )
                    bottom_right = (
                        int(coord[0] + rect_size),
                        int(coord[1] + rect_size / 1.5),
                    )
                    cv2.rectangle(
                        res_video[t],
                        top_left,
                        bottom_right,
                        vector_colors[t, i].tolist(),
                        thickness=-1,  # always filled
                    )

        # Construct the final rgb sequence
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _compute_track_colors(self, tracks, segm_mask, query_frame):
        """Return a ``(T, N, 3)`` array of per-frame, per-point colors."""
        T, N, _ = tracks.shape

        if self.mode == "optical_flow":
            # flow_to_color takes 2-channel (u, v) input; drop the depth.
            displacement = tracks - tracks[query_frame][None]
            return flow_vis.flow_to_color(displacement[..., :2])

        if self.mode == "rainbow":
            # Without a mask the blue channel uses robust inverse depth,
            # with a mask it uses raw depth.
            return self._position_colors(tracks, inverse_depth=segm_mask is None)

        if segm_mask is None:
            # color changes with time
            vector_colors = np.zeros((T, N, 3))
            for t in range(T):
                vector_colors[t] = np.array(self.color_map(t / T)[:3]) * 255
            return vector_colors

        # color changes with segm class
        color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
        color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
        color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
        return np.repeat(color[None], T, axis=0)

    def _position_colors(self, tracks, inverse_depth):
        """Color each point by its normalized first-frame x, y and depth."""
        T = tracks.shape[0]
        first = tracks[0]
        xs, ys = first[:, 0], first[:, 1]
        if inverse_depth:
            zs = 1 / first[:, 2]
            # Percentiles keep a few extreme depths from flattening the range.
            z_min, z_max = np.percentile(zs, [2, 98])
        else:
            zs = first[:, 2]
            z_min, z_max = zs.min(), zs.max()

        channels = [
            np.ma.getdata(plt.Normalize(xs.min(), xs.max())(xs)),
            np.ma.getdata(plt.Normalize(ys.min(), ys.max())(ys)),
            np.ma.getdata(plt.Normalize(z_min, z_max)(zs)),
        ]
        colors = np.stack(channels, axis=-1).astype(np.float64) * 255  # N, 3
        # The color of a point is constant over time.
        return np.repeat(colors[None], T, axis=0)

    def _first_frame_rect_sizes(self, first_frame_tracks):
        """Return each point's rectangle half-width.

        The half-width is half the (integer-truncated) distance to the
        nearest distinct point in the first frame. Points without any
        finite, non-zero neighbour distance fall back to
        ``self.linewidth * 2``.
        """
        sizes = np.empty(len(first_frame_tracks), dtype=np.float64)
        for i, point in enumerate(first_frame_tracks):
            distances = np.linalg.norm(first_frame_tracks - point, axis=1)
            distances = distances[np.isfinite(distances) & (distances > 0)]
            if distances.size:
                sizes[i] = int(distances.min()) / 2
            else:
                sizes[i] = self.linewidth * 2
        return sizes

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x N x 3
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        """Draw the motion trace of every point onto ``rgb``.

        Args:
            rgb: Frame to draw on (lines are drawn in place).
            tracks: Track positions for the trace window, ``(T, N, 3)``.
            vector_colors: Colors indexed as ``[s][i]``.
            alpha: Not used as given; the blend weight is recomputed for
                every segment as ``(s / T) ** 2``.

        Returns:
            The frame with traces; a new array when
            ``self.tracks_leave_trace > 0`` caused blending.
        """
        T, N, _ = tracks.shape

        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                start = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                end = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if start[0] != 0 and start[1] != 0:
                    cv2.line(
                        rgb,
                        start,
                        end,
                        vector_color[i].tolist(),
                        self.linewidth,
                        cv2.LINE_AA,
                    )
            if self.tracks_leave_trace > 0:
                # Older segments get a smaller weight and fade out.
                rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x N x 2
    ):
        """Mark every ground-truth point with a red cross on ``rgb``.

        Points whose x or y is not strictly positive are skipped.

        Returns:
            ``rgb`` itself (drawn on in place).
        """
        T, N, _ = gt_tracks.shape
        color = (211.0, 0.0, 0.0)
        length = self.linewidth * 3

        for t in range(T):
            for i in range(N):
                # A separate name: rebinding ``gt_tracks`` here destroyed the
                # array after the first point.
                point = gt_tracks[t][i]
                if point[0] > 0 and point[1] > 0:
                    x, y = int(point[0]), int(point[1])
                    # draw a red cross
                    cv2.line(
                        rgb,
                        (x + length, y + length),
                        (x - length, y - length),
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    cv2.line(
                        rgb,
                        (x - length, y + length),
                        (x + length, y - length),
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
        return rgb