zhigangjiang's picture
no message
88b0dcb
raw
history blame contribute delete
No virus
4.27 kB
"""
@date: 2021/6/16
@description:
"""
import math
import os
import numpy as np
from dataset.communal.read import read_image, read_label
from dataset.communal.base_dataset import BaseDataset
from utils.logger import get_logger
class PanoS2D3DDataset(BaseDataset):
    """Room-layout dataset over the combined PanoContext / Stanford2D3D panoramas.

    Each sample pairs an image ``<root_dir>/<split>/img/<name>.png`` with a
    corner label ``<root_dir>/<split>/label_cor/<name>.txt``, where ``<split>``
    is ``'valid'`` when ``mode == 'val'`` and ``mode`` otherwise.
    """

    def __init__(self, root_dir, mode, shape=None, max_wall_num=0, aug=None, camera_height=1.6, logger=None,
                 split_list=None, patch_num=256, keys=None, for_test_index=None, subset=None):
        """Index the valid image/label pairs of one split into ``self.data``.

        Args:
            root_dir: Dataset root that contains the split folders.
            mode: Split name; ``'val'`` is read from the ``valid`` folder. When
                ``None``, only ``BaseDataset`` is initialised and no files are
                indexed.
            shape, max_wall_num, aug, camera_height, patch_num, keys: Forwarded
                unchanged to ``BaseDataset.__init__``.
            logger: Logger used for progress/warnings; defaults to ``get_logger()``.
            split_list: Sample names (file stems) to load. Defaults to the
                sorted stems of every non-hidden ``.txt`` file in the split's
                ``label_cor`` folder.
            for_test_index: If given, keep only the first ``for_test_index``
                names (applied after subset filtering).
            subset: ``None`` keeps every name, ``'pano'`` keeps names containing
                ``'pano_'`` (PanoContext), ``'s2d3d'`` keeps names containing
                ``'camera_'`` (Stanford2D3D).

        Raises:
            ValueError: If ``subset`` is not ``None``, ``'pano'`` or ``'s2d3d'``.
        """
        super().__init__(mode, shape, max_wall_num, aug, camera_height, patch_num, keys)
        if logger is None:
            logger = get_logger()
        self.root_dir = root_dir
        if mode is None:
            # NOTE(review): self.data is not assigned on this path; presumably
            # BaseDataset or the caller provides it - confirm before indexing.
            return

        # Validate before touching the filesystem; an assert would be stripped
        # when Python runs with -O.
        if subset not in (None, 'pano', 's2d3d'):
            raise ValueError(f"error subset: {subset!r}")

        split_dir = os.path.join(root_dir, 'valid' if mode == 'val' else mode)
        label_dir = os.path.join(split_dir, 'label_cor')
        img_dir = os.path.join(split_dir, 'img')

        if split_list is None:
            # splitext keeps dotted stems intact (name.split('.')[0] truncated them).
            split_list = sorted(os.path.splitext(name)[0] for name in os.listdir(label_dir)
                                if not name.startswith('.') and name.endswith('.txt'))

        if subset == 'pano':
            split_list = [name for name in split_list if 'pano_' in name]
            logger.info("Use PanoContext Dataset")
        elif subset == 's2d3d':
            split_list = [name for name in split_list if 'camera_' in name]
            logger.info("Use Stanford2D3D Dataset")

        if for_test_index is not None:
            split_list = split_list[:for_test_index]

        self.data = []
        invalid_num = 0
        for name in split_list:
            img_path = os.path.join(img_dir, f"{name}.png")
            label_path = os.path.join(label_dir, f"{name}.txt")
            if not os.path.exists(img_path):
                logger.warning(f"{img_path} not exists")
                invalid_num += 1
                continue
            if not os.path.exists(label_path):
                logger.warning(f"{label_path} not exists")
                invalid_num += 1
                continue
            if self._count_corner_lines(label_path) % 2 != 0:
                # Corner lines are expected to come in pairs; an odd count is
                # treated as a malformed label.
                logger.warning(f"{label_path} has an odd number of corner lines")
                invalid_num += 1
                continue
            self.data.append([img_path, label_path])
        logger.info(
            f"Build dataset mode: {self.mode} valid: {len(self.data)} invalid: {invalid_num}")

    @staticmethod
    def _count_corner_lines(label_path):
        """Return how many lines of ``label_path`` hold at least two tokens starting with a digit.

        Tokens are split on any whitespace, so repeated or leading spaces do not
        yield empty tokens (whose ``[0]`` would raise IndexError).
        """
        with open(label_path, 'r') as f:
            return sum(
                1 for line in f
                if sum(token[0].isnumeric() for token in line.split()) > 1
            )

    def __getitem__(self, idx):
        """Load sample ``idx`` and return the result of ``self.process_data``.

        The label is parsed with ``read_label(..., data_type='Pano_S2D3D')``, the
        image is loaded with ``read_image`` at ``self.shape``, and both are passed
        with ``self.patch_num`` to the inherited ``process_data``.
        """
        rgb_path, label_path = self.data[idx]
        label = read_label(label_path, data_type='Pano_S2D3D')
        image = read_image(rgb_path, self.shape)
        return self.process_data(label, image, self.patch_num)
if __name__ == '__main__':
    # Smoke test: build every split; the constructor logs how many samples are
    # valid/invalid. Set VISUALIZE = True to additionally render each sample's
    # boundary and floorplan into ../src/dataset/pano_s2d3d/visualization/<mode>.
    # NOTE(review): the visualization path uses tqdm, PIL's Image,
    # depth2boundaries, draw_boundaries, draw_floorplan and uv2xyz, none of
    # which are imported at the top of this file - import them before enabling.
    VISUALIZE = False
    modes = ['test', 'val', 'train']
    for mode in modes:
        print(mode)
        dataset = PanoS2D3DDataset(root_dir='../src/dataset/pano_s2d3d', mode=mode, aug={
            # 'STRETCH': True,
            # 'ROTATE': True,
            # 'FLIP': True,
            # 'GAMMA': True
        })
        if not VISUALIZE:
            continue
        save_dir = f'../src/dataset/pano_s2d3d/visualization/{mode}'
        os.makedirs(save_dir, exist_ok=True)
        bar = tqdm(dataset, ncols=100)
        for data in bar:
            bar.set_description(f"Processing {data['id']}")
            boundary_list = depth2boundaries(data['ratio'], data['depth'], step=None)
            pano_img = draw_boundaries(data['image'].transpose(1, 2, 0), boundary_list=boundary_list, show=False)
            Image.fromarray((pano_img * 255).astype(np.uint8)).save(
                os.path.join(save_dir, f"{data['id']}_boundary.png"))
            floorplan = draw_floorplan(uv2xyz(boundary_list[0])[..., ::2], show=False,
                                       marker_color=None, center_color=0.8, show_radius=None)
            Image.fromarray((floorplan.squeeze() * 255).astype(np.uint8)).save(
                os.path.join(save_dir, f"{data['id']}_floorplan.png"))