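"""Gradio demo for MonoScene.

Loads a pretrained MonoScene checkpoint (KITTI weights) and exposes a simple
interface that takes an image and returns a Plotly 3D figure. Running this
script launches the Gradio interface via `demo.launch()`.
"""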
from pytorch_lightning import Trainer
from monoscene.models.monoscene import MonoScene
from monoscene.data.NYU.nyu_dm import NYUDataModule
from monoscene.data.semantic_kitti.kitti_dm import KittiDataModule
from monoscene.data.kitti_360.kitti_360_dm import Kitti360DataModule
# import hydra
from omegaconf import DictConfig
import torch
import numpy as np
import os
from hydra.utils import get_original_cwd
import gradio as gr
import plotly.express as px
import pandas as pd
# @hydra.main(config_name="../config/monoscene.yaml")
def plot(input_img):
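    """Load the pretrained MonoScene KITTI checkpoint and build a Plotly figure.

    `input_img` is the image provided by the Gradio interface. The returned
    figure is currently a placeholder 3D scatter plot rather than a rendering
    of the model's prediction.
    """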
    torch.set_grad_enabled(False)

    # Setup dataloader
    # if config.dataset == "kitti" or config.dataset == "kitti_360":

    # Model hyperparameters for the KITTI configuration (the commented-out NYU
    # branch below uses feature=200 and full_scene_size=(60, 36, 60)).
    feature = 64
    project_scale = 2
    full_scene_size = (256, 256, 32)
# if config.dataset == "kitti":
# data_module = KittiDataModule(
# root=config.kitti_root,
# preprocess_root=config.kitti_preprocess_root,
# frustum_size=config.frustum_size,
# batch_size=int(config.batch_size / config.n_gpus),
# num_workers=int(config.num_workers_per_gpu * config.n_gpus),
# )
# data_module.setup()
# data_loader = data_module.val_dataloader()
# # data_loader = data_module.test_dataloader() # use this if you want to infer on test set
# else:
# data_module = Kitti360DataModule(
# root=config.kitti_360_root,
# sequences=[config.kitti_360_sequence],
# n_scans=2000,
# batch_size=1,
# num_workers=3,
# )
# data_module.setup()
# data_loader = data_module.dataloader()
# elif config.dataset == "NYU":
# project_scale = 1
# feature = 200
# full_scene_size = (60, 36, 60)
# data_module = NYUDataModule(
# root=config.NYU_root,
# preprocess_root=config.NYU_preprocess_root,
# n_relations=config.n_relations,
# frustum_size=config.frustum_size,
# batch_size=int(config.batch_size / config.n_gpus),
# num_workers=int(config.num_workers_per_gpu * config.n_gpus),
# )
# data_module.setup()
# data_loader = data_module.val_dataloader()
# # data_loader = data_module.test_dataloader() # use this if you want to infer on test set
# else:
# print("dataset not support")
# Load pretrained models
# if config.dataset == "NYU":
# model_path = os.path.join(
# get_original_cwd(), "trained_models", "monoscene_nyu.ckpt"
# )
# else:
# model_path = os.path.join(
# get_original_cwd(), "trained_models", "monoscene_kitti.ckpt"
# )
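    # Hard-coded checkpoint path for the demo; the config-driven selection
    # above is kept commented out for reference.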
    model_path = "trained_models/monoscene_kitti.ckpt"

    model = MonoScene.load_from_checkpoint(
        model_path,
        feature=feature,
        project_scale=project_scale,
        fp_loss=False,
        full_scene_size=full_scene_size,
    )
    model.cuda()
    model.eval()
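    # NOTE: the loaded model is not used below yet; the returned figure is
    # built from dummy data so the Gradio wiring can be exercised.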
    print(input_img.shape)

    x = np.arange(12).reshape(4, 3) / 12
    data = pd.DataFrame(data=x, columns=['x', 'y', 'z'])
    fig = px.scatter_3d(data, x="x", y="y", z="z")
    return fig
demo = gr.Interface(plot, gr.Image(shape=(200, 200)), gr.Plot())
demo.launch()
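
# NOTE: the commented-out block below iterates a dataloader, runs the model,
# and pickles the predictions; it is not executed by this demo and is kept
# for reference only.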
# Save prediction and additional data
# to draw the viewing frustum and remove scene outside the room for NYUv2
# output_path = os.path.join(config.output_path, config.dataset)
# with torch.no_grad():
# for batch in tqdm(data_loader):
# batch["img"] = batch["img"].cuda()
# pred = model(batch)
# y_pred = torch.softmax(pred["ssc_logit"], dim=1).detach().cpu().numpy()
# y_pred = np.argmax(y_pred, axis=1)
# for i in range(config.batch_size):
# out_dict = {"y_pred": y_pred[i].astype(np.uint16)}
# if "target" in batch:
# out_dict["target"] = (
# batch["target"][i].detach().cpu().numpy().astype(np.uint16)
# )
# if config.dataset == "NYU":
# write_path = output_path
# filepath = os.path.join(write_path, batch["name"][i] + ".pkl")
# out_dict["cam_pose"] = batch["cam_pose"][i].detach().cpu().numpy()
# out_dict["vox_origin"] = (
# batch["vox_origin"][i].detach().cpu().numpy()
# )
# else:
# write_path = os.path.join(output_path, batch["sequence"][i])
# filepath = os.path.join(write_path, batch["frame_id"][i] + ".pkl")
# out_dict["fov_mask_1"] = (
# batch["fov_mask_1"][i].detach().cpu().numpy()
# )
# out_dict["cam_k"] = batch["cam_k"][i].detach().cpu().numpy()
# out_dict["T_velo_2_cam"] = (
# batch["T_velo_2_cam"][i].detach().cpu().numpy()
# )
# os.makedirs(write_path, exist_ok=True)
# with open(filepath, "wb") as handle:
# pickle.dump(out_dict, handle)
# print("wrote to", filepath)