|
import os |
|
import pickle |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from tqdm import tqdm |
|
from argparse import ArgumentParser |
|
from dev.datasets.preprocess import TokenProcessor |
|
from dev.transforms.target_builder import WaymoTargetBuilder |
|
|
|
|
|
# (face colour, edge colour) pairs for the agent boxes drawn by draw_map,
# looked up as ``colors[color_code - 1]`` where draw_map assigns
# 1 = regular token, 2 = BOS, 3 = EOS, 4 = invalid. A colour code of 0
# (never assigned) resolves to ``colors[-1]``, the last entry.
colors = [
    ("#1f77b4", "#1a5a8a"),  # blue
    ("#2ca02c", "#217721"),  # green
    ("#ff7f0e", "#cc660b"),  # orange
    ("#9467bd", "#6f4a91"),  # purple
    ("#d62728", "#a31d1d"),  # red
    ("#000000", "#000000"),  # black
]
|
|
|
|
|
def _rotation_matrices(heading):
    """Build one 2x2 rotation matrix per entry of a 1-D ``heading`` tensor.

    The matrices are laid out for right-multiplication of row vectors:
    ``points @ R`` rotates each ``(x, y)`` row counter-clockwise by ``heading``.
    """
    cos, sin = heading.cos(), heading.sin()
    rot_mat = heading.new_zeros(heading.shape[0], 2, 2)
    rot_mat[:, 0, 0] = cos
    rot_mat[:, 0, 1] = sin
    rot_mat[:, 1, 0] = -sin
    rot_mat[:, 1, 1] = cos
    return rot_mat


def _build_valid_mask(position, token_index, token_valid_mask, shift, token_size, smart):
    """Return a boolean (agent, timestep) mask of which steps should be drawn.

    Starts from "position is not on either zero axis" and then overwrites each
    ``shift``-wide window with the validity of the token covering it: the
    tokenizer's ``agent_valid_mask`` in SMART mode, otherwise "token id is not
    the invalid id ``token_size + 2``". Steps beyond the last token window keep
    the position-based value.
    """
    valid_mask = (position[..., 0] != 0) & (position[..., 1] != 0)
    if smart:
        for k in range(token_valid_mask.shape[1]):
            valid_mask[:, k * shift:(k + 1) * shift] = \
                token_valid_mask[:, k:k + 1].repeat(1, shift)
    else:
        for k in range(token_index.shape[1]):
            valid_mask[:, k * shift:(k + 1) * shift] = \
                token_index[:, k:k + 1] != token_size + 2
    return valid_mask


def draw_map(tokenize_data, token_processor: TokenProcessor, index, posfix,
             *, shift=5, token_size=2048, min_x=3000, smart=None):
    """Render the tokenized trajectories of selected agents frame by frame as a GIF.

    For every token window (``shift`` timesteps) the token's contours are
    placed in the world frame using the agent's heading/position at the start
    of the window. One PNG per timestep is written to
    ``debug/tokenize/steps/{scenario_id}_{t}.png`` and the frames are combined
    into ``debug/tokenize/{scenario_id}_tokenize_{posfix}.gif``.

    Args:
        tokenize_data: Scenario dict after ``TokenProcessor.preprocess``; reads
            ``scenario_id``, ``map_point/position`` and
            ``agent/{token_idx, agent_valid_mask, position, heading}``.
        token_processor: Supplies ``trajectory_token_all["veh"]``, which must be
            a NumPy array (it is passed to ``torch.from_numpy``) reshapeable to
            ``(vocab, 6, 4, 2)`` — presumably 6 sub-steps x 4 contour points x
            xy; TODO confirm.
        index: Indices of the agents to draw.
        posfix: Suffix appended to the output GIF name.
        shift: Number of timesteps covered by one token.
        token_size: Vocabulary size; ids ``token_size``, ``token_size + 1`` and
            ``token_size + 2`` are coloured as BOS, EOS and invalid.
        min_x: Boxes whose x coordinate is below this value are not drawn
            (scene-specific crop).
        smart: Use the tokenizer's ``agent_valid_mask`` for validity. ``None``
            falls back to the script-level ``args.smart`` CLI flag.
    """
    print("Drawing tokenized data ...")
    if smart is None:
        smart = args.smart  # global set only when run as a script

    scenario_id = tokenize_data["scenario_id"]
    traj_token_all = token_processor.trajectory_token_all["veh"]

    # Create the figure first so every call below targets it (no stray figures).
    fig, ax = plt.subplots()
    ax.set_axis_off()
    map_xy = tokenize_data["map_point"]["position"]
    ax.scatter(map_xy[:, 0], map_xy[:, 1], s=0.2, c='black', edgecolors='none')
    ax.set_aspect('equal', adjustable='box')

    index = np.array(index).astype(np.int32)
    agent_data = tokenize_data["agent"]
    token_index = agent_data["token_idx"][index]
    token_valid_mask = agent_data["agent_valid_mask"][index]
    num_agent, num_token = token_index.shape

    # Local-frame contours of every token each agent emits: (A, T, 6, 4, 2).
    # Converted to a tensor once instead of on every step.
    tokens_all = traj_token_all[token_index.view(-1)].reshape(num_agent, num_token, 6, 4, 2)
    tokens_all = torch.from_numpy(tokens_all).float()

    position = agent_data['position'][index, :, :2]
    heading = agent_data['heading'][index]
    num_steps = position.shape[1]

    valid_mask = _build_valid_mask(position, token_index, token_valid_mask,
                                   shift, token_size, smart)
    last_idx = num_steps - 1 - torch.argmax(valid_mask.flip(dims=[1]).long(), dim=1)
    last_valid_step = {int(index[i]): int(last_idx[i]) for i in range(len(index))}

    # Per-(agent, step) colour code, used as ``colors[code - 1]``:
    # 1 = regular token, 2 = BOS, 3 = EOS, 4 = invalid (0 = never assigned).
    agent_colors = np.zeros((num_agent, num_steps))
    # Fixed 3 x 3 box for every agent and step: [..., 0] length, [..., 1] width.
    shape = np.zeros((num_agent, num_steps, 2)) + 3.

    prev_heading = heading[:, 0]
    prev_pos = position[:, 0]
    fig_paths = []

    for tid in tqdm(range(shift, num_steps, shift), leave=False, desc="Token ..."):
        token_id = tid // shift - 1  # token covering steps [tid - shift, tid)

        # Transform only this token's contours into the world frame, anchored
        # at the pose at the start of the window.
        rot_mat = _rotation_matrices(prev_heading)
        local = tokens_all[:, token_id].reshape(num_agent, 6 * 4, 2)
        world = torch.bmm(local, rot_mat).reshape(num_agent, 6, 4, 2)
        world += prev_pos[:, None, None, :2]
        tokens_all_pos = world.mean(dim=2)  # contour centre per sub-step: (A, 6, 2)

        prev_heading = heading[:, tid].clone()
        prev_pos = position[:, tid].clone()

        cur_token_index = token_index[:, token_id]
        is_bos = cur_token_index == token_size
        is_eos = cur_token_index == token_size + 1
        is_invalid = cur_token_index == token_size + 2
        is_valid = ~is_bos & ~is_eos & ~is_invalid
        window = slice(tid - shift, tid)
        agent_colors[is_valid, window] = 1
        agent_colors[is_bos, window] = 2
        agent_colors[is_eos, window] = 3
        agent_colors[is_invalid, window] = 4

        # Visibility is sampled once, at the first step of the window.
        cur_valid_mask = valid_mask[:, tid - shift]
        current_index = index[cur_valid_mask]

        for i in tqdm(range(shift), leave=False, desc="Timestep ..."):
            global_tid = tid - shift + i
            xs = tokens_all_pos[cur_valid_mask, i, 0]
            ys = tokens_all_pos[cur_valid_mask, i, 1]
            widths = shape[cur_valid_mask, global_tid, 1]
            lengths = shape[cur_valid_mask, global_tid, 0]
            cur_agent_colors = agent_colors[cur_valid_mask, global_tid]

            transient = []  # artists removed again once this frame is saved
            for x, y, width, length, color_type, agent_id in zip(
                    xs, ys, widths, lengths, cur_agent_colors, current_index):
                if x < min_x:
                    continue
                face, edge = colors[int(color_type) - 1]
                # NOTE(review): boxes are axis-aligned with (x, y) as the
                # lower-left corner; the agent heading is not applied.
                agent = plt.Rectangle((x, y), width, length,
                                      linewidth=0.2, facecolor=face, edgecolor=edge)
                ax.add_patch(agent)
                text = ax.text(x - 4, y - 4, f"{agent_id}:{global_tid}",
                               fontdict={'family': 'serif', 'size': 3, 'color': 'red'})

                # Boxes drawn at the agent's last valid step stay on the axes
                # (and so appear in later frames); all others are transient.
                if global_tid != last_valid_step[agent_id]:
                    transient.extend((agent, text))

                if global_tid % shift == 0:
                    # Permanent outline at the first step of each token window.
                    ax.add_patch(plt.Rectangle((x, y), width, length,
                                               linewidth=0.2, fill=False, edgecolor=edge))

            fig_path = f"debug/tokenize/steps/{scenario_id}_{global_tid}.png"
            fig.savefig(fig_path, dpi=600, bbox_inches="tight")
            fig_paths.append(fig_path)

            for artist in transient:
                artist.remove()

    plt.close(fig)

    if not fig_paths:
        print("No frames were rendered; skipping GIF export.")
        return

    import imageio.v2 as imageio

    images = [imageio.imread(path)
              for path in tqdm(fig_paths, leave=False, desc="Generate gif ...")]
    imageio.mimsave(f"debug/tokenize/{scenario_id}_tokenize_{posfix}.gif", images, duration=0.1)
|
|
|
|
|
def main(data):
    """Tokenize one scenario and render its token visualisation.

    Creates ``debug/tokenize/steps/``, renders the raw-data GIF when it does
    not exist yet and a ``draw_raw`` helper is available, tokenizes ``data``
    with ``TokenProcessor`` (``disable_invalid`` follows the ``--smart`` flag),
    draws the selected agents' tokens with ``draw_map``, and finally runs
    ``WaymoTargetBuilder`` on the tokenized data.

    Args:
        data: Raw scenario dict loaded from the processed pickle; must contain
            ``scenario_id``.
    """
    token_size = 2048

    os.makedirs("debug/tokenize/steps/", exist_ok=True)
    scenario_id = data["scenario_id"]

    # Indices (into the scenario's agent arrays) of the agents to visualise.
    selected_agents_index = [1, 21, 35, 36, 46]

    if not os.path.exists(f"debug/tokenize/{scenario_id}_raw.gif"):
        # ``draw_raw`` may not be defined in this module; look it up instead
        # of failing with NameError so the tokenized rendering still runs.
        draw_raw_fn = globals().get("draw_raw")
        if draw_raw_fn is None:
            print("Skipping raw rendering: draw_raw is not defined.")
        else:
            draw_raw_fn(data, selected_agents_index)

    token_processor = TokenProcessor(token_size, disable_invalid=args.smart)
    print(f"Loaded token processor with token_size: {token_size}")
    data = token_processor.preprocess(data)

    posfix = "smart" if args.smart else "ours"

    # draw_map is this module's tokenized-trajectory renderer.
    draw_map(data, token_processor, selected_agents_index, posfix)

    # The result is not used further here; presumably a smoke test that the
    # target builder accepts the tokenized data — TODO confirm intent.
    target_builder = WaymoTargetBuilder(num_historical_steps=11, num_future_steps=80)
    data = target_builder(data)
|
|
|
|
|
if __name__ == "__main__":
    parser = ArgumentParser(description="Testing script parameters")
    parser.add_argument("--smart", action="store_true")
    parser.add_argument("--data_path", type=str, default="/u/xiuyu/work/dev4/data/waymo_processed/training")
    parser.add_argument("--scenario_id", type=str, default="74ad7b76d5906d39",
                        help="Scenario to load: <data_path>/<scenario_id>.pkl")
    # ``args`` stays a module global: draw_map and main read ``args.smart``.
    args = parser.parse_args()

    scenario_id = args.scenario_id
    data_path = os.path.join(args.data_path, f"{scenario_id}.pkl")
    # pickle can execute arbitrary code on load: only use trusted local files.
    with open(data_path, "rb") as f:
        data = pickle.load(f)
    print(f"Loaded scenario {scenario_id}")

    main(data)