Spaces:
Build error
Build error
from pytorch_lightning import Trainer | |
from monoscene.models.monoscene import MonoScene | |
from monoscene.data.NYU.nyu_dm import NYUDataModule | |
from monoscene.data.semantic_kitti.kitti_dm import KittiDataModule | |
from monoscene.data.kitti_360.kitti_360_dm import Kitti360DataModule | |
import hydra | |
from omegaconf import DictConfig | |
import torch | |
import numpy as np | |
import os | |
from hydra.utils import get_original_cwd | |
from tqdm import tqdm | |
import pickle | |
def main(config: DictConfig):
    """Run a pretrained MonoScene model over a dataset and pickle each prediction.

    One ``.pkl`` file is written per sample:

    * ``NYU``: ``<config.output_path>/NYU/<name>.pkl`` containing ``y_pred``,
      ``cam_pose`` and ``vox_origin``.
    * ``kitti`` / ``kitti_360``:
      ``<config.output_path>/<dataset>/<sequence>/<frame_id>.pkl`` containing
      ``y_pred``, ``fov_mask_1``, ``cam_k`` and ``T_velo_2_cam``.

    ``target`` is added to the file whenever the batch provides it.

    Args:
        config: Run configuration. Fields read here: ``dataset``,
            ``output_path``, ``fp_loss`` and, depending on the dataset,
            ``kitti_root``, ``kitti_preprocess_root``, ``kitti_360_root``,
            ``kitti_360_sequence``, ``NYU_root``, ``NYU_preprocess_root``,
            ``n_relations``, ``frustum_size``, ``batch_size``, ``n_gpus``
            and ``num_workers_per_gpu``.

    Raises:
        ValueError: If ``config.dataset`` is not ``"kitti"``, ``"kitti_360"``
            or ``"NYU"``.

    NOTE(review): this function calls ``get_original_cwd()`` and the script
    entry point calls ``main()`` without arguments, so it presumably has to be
    wrapped by a ``@hydra.main(...)`` decorator; none is visible here -- confirm.
    """
    torch.set_grad_enabled(False)

    dataset = config.dataset
    if dataset not in ("kitti", "kitti_360", "NYU"):
        # Fail early with a clear message instead of printing and then
        # crashing later on undefined model parameters.
        raise ValueError(
            f"dataset not support: {dataset!r} "
            "(expected 'kitti', 'kitti_360' or 'NYU')"
        )

    # Setup dataloader
    if dataset == "NYU":
        project_scale = 1
        feature = 200
        full_scene_size = (60, 36, 60)
        data_module = NYUDataModule(
            root=config.NYU_root,
            preprocess_root=config.NYU_preprocess_root,
            n_relations=config.n_relations,
            frustum_size=config.frustum_size,
            batch_size=int(config.batch_size / config.n_gpus),
            num_workers=int(config.num_workers_per_gpu * config.n_gpus),
        )
        data_module.setup()
        data_loader = data_module.val_dataloader()
        # data_loader = data_module.test_dataloader()  # use this if you want to infer on test set
    else:
        feature = 64
        project_scale = 2
        full_scene_size = (256, 256, 32)
        if dataset == "kitti":
            data_module = KittiDataModule(
                root=config.kitti_root,
                preprocess_root=config.kitti_preprocess_root,
                frustum_size=config.frustum_size,
                batch_size=int(config.batch_size / config.n_gpus),
                num_workers=int(config.num_workers_per_gpu * config.n_gpus),
            )
            data_module.setup()
            data_loader = data_module.val_dataloader()
            # data_loader = data_module.test_dataloader()  # use this if you want to infer on test set
        else:
            data_module = Kitti360DataModule(
                root=config.kitti_360_root,
                sequences=[config.kitti_360_sequence],
                n_scans=2000,
                batch_size=1,
                num_workers=3,
            )
            data_module.setup()
            data_loader = data_module.dataloader()

    # Load pretrained models (both KITTI variants use the same checkpoint file)
    checkpoint_name = (
        "monoscene_nyu.ckpt" if dataset == "NYU" else "monoscene_kitti.ckpt"
    )
    model_path = os.path.join(get_original_cwd(), "trained_models", checkpoint_name)
    model = MonoScene.load_from_checkpoint(
        model_path,
        feature=feature,
        project_scale=project_scale,
        fp_loss=config.fp_loss,
        full_scene_size=full_scene_size,
    )
    model.cuda()
    model.eval()

    def to_numpy(tensor):
        # Detach and copy to host memory so the value can be pickled.
        return tensor.detach().cpu().numpy()

    # Save prediction and additional data
    # to draw the viewing frustum and remove scene outside the room for NYUv2
    output_path = os.path.join(config.output_path, dataset)
    with torch.no_grad():
        for batch in tqdm(data_loader):
            batch["img"] = batch["img"].cuda()
            pred = model(batch)
            # Label per voxel = argmax over axis 1 (presumably the class axis).
            y_pred = to_numpy(torch.softmax(pred["ssc_logit"], dim=1))
            y_pred = np.argmax(y_pred, axis=1)

            # Iterate over the samples actually present in this batch rather
            # than config.batch_size: the loader's batch size is
            # config.batch_size / config.n_gpus (or 1 for kitti_360) and the
            # final batch may be smaller still.
            for i in range(y_pred.shape[0]):
                out_dict = {"y_pred": y_pred[i].astype(np.uint16)}
                if "target" in batch:
                    out_dict["target"] = to_numpy(batch["target"][i]).astype(
                        np.uint16
                    )

                if dataset == "NYU":
                    write_path = output_path
                    filepath = os.path.join(write_path, batch["name"][i] + ".pkl")
                    out_dict["cam_pose"] = to_numpy(batch["cam_pose"][i])
                    out_dict["vox_origin"] = to_numpy(batch["vox_origin"][i])
                else:
                    write_path = os.path.join(output_path, batch["sequence"][i])
                    filepath = os.path.join(
                        write_path, batch["frame_id"][i] + ".pkl"
                    )
                    out_dict["fov_mask_1"] = to_numpy(batch["fov_mask_1"][i])
                    out_dict["cam_k"] = to_numpy(batch["cam_k"][i])
                    out_dict["T_velo_2_cam"] = to_numpy(batch["T_velo_2_cam"][i])

                os.makedirs(write_path, exist_ok=True)
                with open(filepath, "wb") as handle:
                    pickle.dump(out_dict, handle)
                    print("wrote to", filepath)
# Script entry point: run inference only when this file is executed directly.
# NOTE(review): main() is called without a config argument, which presumably
# relies on a Hydra decorator on main() supplying it -- confirm.
if __name__ == "__main__":
    main()