# risk_biased_prediction / export_waymo_to_json.py
# Last commit 5769ee4 by jmercat: "Removed history to avoid any unverified information being released"
# Original file size: 3.19 kB
import json
import os
from json import JSONEncoder

import numpy
import torch
from mmcv import Config

from risk_biased.utils.waymo_dataloader import WaymoDataloaders


class NumpyArrayEncoder(JSONEncoder):
    """JSON encoder that also serializes NumPy arrays and NumPy scalars.

    * ``numpy.ndarray`` values become (nested) Python lists via ``tolist()``.
    * NumPy scalar values (any ``numpy.generic``, e.g. ``numpy.float32``,
      ``numpy.int64``, ``numpy.bool_``), which the stock encoder rejects,
      become the equivalent Python scalar via ``item()``.
    * Anything else is deferred to ``JSONEncoder.default``, which raises
      ``TypeError`` for unsupported types.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*, or raise TypeError."""
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        # NumPy scalars (e.g. an element pulled out of an array) are not
        # Python ints/floats/bools, so the base encoder cannot handle them.
        if isinstance(obj, numpy.generic):
            return obj.item()
        return super().default(obj)


DEFAULT_CONFIG_PATH = "risk_biased/config/waymo_config.py"
DEFAULT_OUTPUT_PATH = "../risk_biased_dataset/data.json"

# Names of the tensors returned by the sample dataloader's collate_fn, in that
# order, each paired with the NumPy dtype used when decoding the JSON back.
# numpy.bool_ replaces the deprecated numpy.bool8 alias (removed in NumPy 2.0).
_FIELDS = (
    ("x", numpy.float32),
    ("mask_x", numpy.bool_),
    ("y", numpy.float32),
    ("mask_y", numpy.bool_),
    ("mask_loss", numpy.bool_),
    ("map_data", numpy.float32),
    ("mask_map", numpy.bool_),
    ("offset", numpy.float32),
    ("x_ego", numpy.float32),
    ("y_ego", numpy.float32),
)


def _load_sample_batch(config_path):
    """Collate the whole sample dataset described by *config_path*.

    Returns a dict mapping each name in ``_FIELDS`` to its tensor, detached
    and moved to CPU so it can be converted with ``.numpy()``.

    Raises:
        ValueError: if collate_fn does not return exactly one item per field.
    """
    cfg = Config.fromfile(config_path)
    sample_dataloader = WaymoDataloaders(cfg).sample_dataloader()
    batch = sample_dataloader.collate_fn(sample_dataloader.dataset)
    # zip() would silently truncate, so check the arity explicitly.
    if len(batch) != len(_FIELDS):
        raise ValueError(
            f"collate_fn returned {len(batch)} items, expected {len(_FIELDS)}"
        )
    return {
        name: tensor.detach().cpu()
        for (name, _), tensor in zip(_FIELDS, batch)
    }


def _write_json(tensors, output_path):
    """Serialize *tensors* to *output_path* as a JSON object of nested lists."""
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    data = {name: tensor.numpy() for name, tensor in tensors.items()}
    with open(output_path, "w", encoding="utf-8") as f:
        # json.dump streams to the file instead of building one huge string.
        json.dump(data, f, cls=NumpyArrayEncoder)


def _find_roundtrip_mismatches(tensors, output_path):
    """Reload *output_path* and return the names whose decoded tensor differs."""
    with open(output_path, encoding="utf-8") as f:
        decoded = json.load(f)
    mismatched = []
    for name, dtype in _FIELDS:
        restored = torch.from_numpy(numpy.asarray(decoded[name], dtype=dtype))
        # Compare in the decode dtype; comparing mixed dtypes raises in torch.
        original = tensors[name].to(restored.dtype)
        if original.shape != restored.shape:
            same = False
        elif restored.dtype == torch.bool:
            same = torch.equal(original, restored)
        else:
            # Python's json writes/reads NaN as the (non-standard) NaN token,
            # so NaNs survive the round-trip and should compare equal.
            same = torch.allclose(original, restored, equal_nan=True)
        if not same:
            mismatched.append(name)
    return mismatched


def main(config_path=DEFAULT_CONFIG_PATH, output_path=DEFAULT_OUTPUT_PATH):
    """Export one Waymo sample batch to JSON and check it decodes back intact.

    Raises:
        ValueError: if collate_fn does not return one tensor per ``_FIELDS`` entry.
        RuntimeError: if any tensor does not survive the JSON round-trip.
    """
    tensors = _load_sample_batch(config_path)
    for name, tensor in tensors.items():
        print(f"{name}: {tuple(tensor.shape)}")
    _write_json(tensors, output_path)
    # Raise instead of assert so the check still runs under `python -O`.
    mismatched = _find_roundtrip_mismatches(tensors, output_path)
    if mismatched:
        raise RuntimeError(
            f"JSON round-trip mismatch for: {', '.join(mismatched)}"
        )
    print("All good!")


if __name__ == "__main__":
    main()