# Debug script: visualize trajectory tokenization for one Waymo scenario.
# (Removed stray huggingface_hub web-page residue — "gzzyyxy's picture /
# Upload folder using huggingface_hub / c1a7f73 verified" — which is not
# valid Python and would break the file at import time.)
import os
import pickle
import torch
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from argparse import ArgumentParser
from dev.datasets.preprocess import TokenProcessor
from dev.transforms.target_builder import WaymoTargetBuilder
# Per-token-class drawing colors as (facecolor, edgecolor) pairs.
# draw_map indexes this with `int(color_type) - 1`, where color_type is
# 1=valid, 2=bos, 3=eos, 4=invalid (see agent_colors assignments below).
colors = [
    ('#1f77b4', '#1a5a8a'),  # blue
    ('#2ca02c', '#217721'),  # green
    ('#ff7f0e', '#cc660b'),  # orange
    ('#9467bd', '#6f4a91'),  # purple
    ('#d62728', '#a31d1d'),  # red
    ('#000000', '#000000'),  # black
]
def draw_map(tokenize_data, token_processor: TokenProcessor, index, posfix):
    """Render tokenized agent trajectories frame by frame and build a GIF.

    For each 5-step motion token, the selected agents are drawn as rectangles
    color-coded by token class (valid/bos/eos/invalid), one PNG per timestep
    under ``debug/tokenize/steps/``; the frames are then stitched into
    ``debug/tokenize/{scenario_id}_tokenize_{posfix}.gif``.

    Args:
        tokenize_data: dict produced by ``TokenProcessor.preprocess`` with
            ``scenario_id``, ``map_point`` and ``agent`` entries.
        token_processor: provides the vehicle token vocabularies
            (``trajectory_token`` and ``trajectory_token_all``).
        index: iterable of agent indices to draw.
        posfix: suffix for the output gif name (e.g. "smart" or "ours").
    """
    print("Drawing raw data ...")
    shift = 5          # timesteps represented by a single motion token
    token_size = 2048  # vocabulary size; ids token_size..token_size+2 are bos/eos/invalid
    traj_token = token_processor.trajectory_token["veh"]
    traj_token_all = token_processor.trajectory_token_all["veh"]
    fig, ax = plt.subplots()
    # Fixed: the original called plt.subplots_adjust() *before* plt.subplots(),
    # which configured a throwaway implicit figure; apply margins to the real one.
    fig.subplots_adjust(left=0.3, right=0.7, top=0.7, bottom=0.3)
    ax.set_axis_off()
    # Fixed: read the scenario id from the function argument, not the
    # module-level ``data`` global the original accidentally relied on.
    scenario_id = tokenize_data['scenario_id']

    # static map background
    ax.scatter(tokenize_data["map_point"]["position"][:, 0],
               tokenize_data["map_point"]["position"][:, 1], s=0.2, c='black', edgecolors='none')

    index = np.array(index).astype(np.int32)
    agent_data = tokenize_data["agent"]
    token_index = agent_data["token_idx"][index]
    token_valid_mask = agent_data["agent_valid_mask"][index]
    num_agent, num_token = token_index.shape
    # Token contours: (num_agent, num_token, 4 corners, xy) plus a variant with
    # 6 intermediate poses per token.
    tokens = traj_token[token_index.view(-1)].reshape(num_agent, num_token, 4, 2)
    tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)
    position = agent_data['position'][index, :, :2]  # (num_agent, 91, 2)
    heading = agent_data['heading'][index]  # (num_agent, 91)
    valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)  # (num_agent, 91)
    # TODO: fix this
    # NOTE(review): ``args`` is the module-level namespace parsed in __main__.
    if args.smart:
        for shifted_tid in range(token_valid_mask.shape[1]):
            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_valid_mask[:, shifted_tid : shifted_tid + 1].repeat(1, shift)
    else:
        for shifted_tid in range(token_index.shape[1]):
            valid_mask[:, shifted_tid * shift : (shifted_tid + 1) * shift] = token_index[:, shifted_tid : shifted_tid + 1] != token_size + 2
    # Last timestep at which each drawn agent is still valid, keyed by agent id.
    last_valid_step = valid_mask.shape[1] - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
    last_valid_step = {int(index[i]): int(last_valid_step[i]) for i in range(len(index))}

    _, token_num, token_contour_dim, feat_dim = tokens.shape
    tokens_src = tokens.reshape(num_agent, token_num * token_contour_dim, feat_dim)
    tokens_all_src = tokens_all.reshape(num_agent, token_num * 6 * token_contour_dim, feat_dim)
    prev_heading = heading[:, 0]
    prev_pos = position[:, 0]
    fig_paths = []
    agent_colors = np.zeros((num_agent, position.shape[1]))
    shape = np.zeros((num_agent, position.shape[1], 2)) + 3.  # fixed 3x3 boxes for every agent/step
    for tid in tqdm(range(shift, position.shape[1], shift), leave=False, desc="Token ..."):
        # Rotate token-local contours into the world frame using the previous heading.
        cos, sin = prev_heading.cos(), prev_heading.sin()
        rot_mat = prev_heading.new_zeros(num_agent, 2, 2)
        rot_mat[:, 0, 0] = cos
        rot_mat[:, 0, 1] = sin
        rot_mat[:, 1, 0] = -sin
        rot_mat[:, 1, 1] = cos
        tokens_world = torch.bmm(torch.from_numpy(tokens_src).float(), rot_mat).reshape(num_agent,
                                                                                        token_num,
                                                                                        token_contour_dim,
                                                                                        feat_dim)
        tokens_all_world = torch.bmm(torch.from_numpy(tokens_all_src).float(), rot_mat).reshape(num_agent,
                                                                                               token_num,
                                                                                               6,
                                                                                               token_contour_dim,
                                                                                               feat_dim)
        tokens_world += prev_pos[:, None, None, :2]
        tokens_all_world += prev_pos[:, None, None, None, :2]

        tokens_select = tokens_world[:, tid // shift - 1]  # (num_agent, token_contour_dim, feat_dim)
        tokens_all_select = tokens_all_world[:, tid // shift - 1]  # (num_agent, 6, token_contour_dim, feat_dim)
        diff_xy = tokens_select[:, 0, :] - tokens_select[:, 3, :]

        # Advance the frame using the ground-truth pose; the token-derived
        # updates are kept disabled, as in the original.
        prev_heading = heading[:, tid].clone()
        # prev_heading[valid_mask[:, tid - shift]] = torch.arctan2(diff_xy[:, 1], diff_xy[:, 0])[
        #     valid_mask[:, tid - shift]]
        prev_pos = position[:, tid].clone()
        # prev_pos[valid_mask[:, tid - shift]] = tokens_select.mean(dim=1)[valid_mask[:, tid - shift]]

        # NOTE tokens_pos equals to tokens_all_pos[:, -1]
        tokens_pos = tokens_select.mean(dim=1)  # (num_agent, 2)
        tokens_all_pos = tokens_all_select.mean(dim=2)  # (num_agent, 6, 2)

        # Per-timestep color classes: 1=valid, 2=bos, 3=eos, 4=invalid
        # (indexes the module-level ``colors`` list via int(color) - 1).
        cur_token_index = token_index[:, tid // shift - 1]
        is_bos = cur_token_index == token_size
        is_eos = cur_token_index == token_size + 1
        is_invalid = cur_token_index == token_size + 2
        is_valid = ~is_bos & ~is_eos & ~is_invalid
        agent_colors[is_valid, tid - shift : tid] = 1
        agent_colors[is_bos, tid - shift : tid] = 2
        agent_colors[is_eos, tid - shift : tid] = 3
        agent_colors[is_invalid, tid - shift : tid] = 4

        for i in tqdm(range(shift), leave=False, desc="Timestep ..."):
            global_tid = tid - shift + i
            cur_valid_mask = valid_mask[:, tid - shift]  # only when the last tokenized timestep is valid the current shifts trajectory is valid
            xs = tokens_all_pos[cur_valid_mask, i, 0]
            ys = tokens_all_pos[cur_valid_mask, i, 1]
            widths = shape[cur_valid_mask, global_tid, 1]
            lengths = shape[cur_valid_mask, global_tid, 0]
            angles = heading[cur_valid_mask, global_tid]
            cur_agent_colors = agent_colors[cur_valid_mask, global_tid]
            current_index = index[cur_valid_mask]

            drawn_agents = []
            drawn_texts = []
            for x, y, width, length, angle, color_type, id in zip(
                    xs, ys, widths, lengths, angles, cur_agent_colors, current_index):
                # NOTE(review): debug filter — skips everything with x < 3000,
                # presumably tuned to this scenario's coordinate range; confirm.
                if x < 3000: continue
                agent = plt.Rectangle((x, y), width, length,  # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                      linewidth=0.2,
                                      facecolor=colors[int(color_type) - 1][0],
                                      edgecolor=colors[int(color_type) - 1][1])
                ax.add_patch(agent)
                text = plt.text(x - 4, y - 4, f"{str(id)}:{str(global_tid)}", fontdict={'family': 'serif', 'size': 3, 'color': 'red'})
                # Keep each agent's final pose on screen; everything else is
                # removed again after the frame is saved.
                if global_tid != last_valid_step[id]:
                    drawn_agents.append(agent)
                    drawn_texts.append(text)
                # draw timestep to be tokenized
                if global_tid % shift == 0:
                    tokenize_agent = plt.Rectangle((x, y), width, length,  # angle=((angle + np.pi / 2) / np.pi * 360) % 360,
                                                   linewidth=0.2, fill=False,
                                                   edgecolor=colors[int(color_type) - 1][1])
                    ax.add_patch(tokenize_agent)
            plt.gca().set_aspect('equal', adjustable='box')
            fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
            plt.savefig(fig_path, dpi=600, bbox_inches="tight")
            fig_paths.append(fig_path)
            for drawn_agent, drawn_text in zip(drawn_agents, drawn_texts):
                drawn_agent.remove()
                drawn_text.remove()
    plt.close()

    # generate gif
    import imageio.v2 as imageio
    images = []
    for fig_path in tqdm(fig_paths, leave=False, desc="Generate gif ..."):
        images.append(imageio.imread(fig_path))
    imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)


# Backward-compatible alias: ``main`` calls ``draw_tokenize``, which was never
# defined in this file — bind it to this implementation (matching signature).
draw_tokenize = draw_map
def main(data):
    """Tokenize one scenario, render the tokenization GIF, and build targets.

    Args:
        data: raw scenario dict loaded from the processed Waymo pickle; must
            contain at least ``scenario_id``.
    """
    token_size = 2048
    os.makedirs("debug/tokenize/steps/", exist_ok=True)
    scenario_id = data["scenario_id"]
    selected_agents_index = [1, 21, 35, 36, 46]  # agents to visualize
    # raw data
    if not os.path.exists(f"debug/tokenize/{scenario_id}_raw.gif"):
        # NOTE(review): ``draw_raw`` is not defined in this file — this branch
        # raises NameError unless the raw gif already exists; confirm which
        # module it was meant to come from.
        draw_raw(data, selected_agents_index)
    # tokenization
    token_processor = TokenProcessor(token_size, disable_invalid=args.smart)
    print(f"Loaded token processor with token_size: {token_size}")
    data = token_processor.preprocess(data)
    # tokenized data
    posfix = "smart" if args.smart else "ours"
    # if not os.path.exists(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif"):
    # Fixed: the original called the undefined ``draw_tokenize``; ``draw_map``
    # is the implementation in this file with the matching signature.
    draw_map(data, token_processor, selected_agents_index, posfix)
    target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
    data = target_builder(data)
if __name__ == "__main__":
    parser = ArgumentParser(description="Testing script parameters")
    parser.add_argument("--smart", action="store_true")
    parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
    args = parser.parse_args()

    # Hard-coded debug scenario.
    scenario_id = "74ad7b76d5906d39"
    data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
    # Fixed: close the file handle deterministically (the original leaked it via
    # pickle.load(open(...))). NOTE: pickle is only safe on trusted local data.
    with open(data_path, "rb") as f:
        data = pickle.load(f)
    print(f"Loaded scenario {scenario_id}")
    main(data)