Realcat
fix: eloftr
63f3cf2
raw
history blame
3.63 kB
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File pram -> recdataset
@IDE PyCharm
@Author fx221@cam.ac.uk
@Date 29/01/2024 14:42
=================================================='''
import numpy as np
from torch.utils.data import Dataset
class RecDataset(Dataset):
def __init__(self, sub_sets=[]):
assert len(sub_sets) >= 1
self.sub_sets = sub_sets
self.names = []
self.sub_set_index = []
self.seg_offsets = []
self.sub_set_item_index = []
self.dataset_names = []
self.scene_names = []
start_index_valid_seg = 1 # start from 1, 0 is for invalid
total_subset = 0
for scene_set in sub_sets: # [0, n_class]
name = scene_set.dataset
self.names.append(name)
n_samples = len(scene_set)
n_class = scene_set.n_class
self.seg_offsets = self.seg_offsets + [start_index_valid_seg for v in range(len(scene_set))]
start_index_valid_seg = start_index_valid_seg + n_class - 1
self.sub_set_index = self.sub_set_index + [total_subset for k in range(n_samples)]
self.sub_set_item_index = self.sub_set_item_index + [k for k in range(n_samples)]
# self.dataset_names = self.dataset_names + [name for k in range(n_samples)]
self.scene_names = self.scene_names + [name for k in range(n_samples)]
total_subset += 1
self.n_class = start_index_valid_seg
print('Load {} images {} segs from {} subsets from {}'.format(len(self.sub_set_item_index), self.n_class,
len(sub_sets), self.names))
def __len__(self):
return len(self.sub_set_item_index)
def __getitem__(self, idx):
subset_idx = self.sub_set_index[idx]
item_idx = self.sub_set_item_index[idx]
scene_name = self.scene_names[idx]
out = self.sub_sets[subset_idx][item_idx]
org_gt_seg = out['gt_seg']
org_gt_cls = out['gt_cls']
org_gt_cls_dist = out['gt_cls_dist']
org_gt_n_seg = out['gt_n_seg']
offset = self.seg_offsets[idx]
org_n_class = self.sub_sets[subset_idx].n_class
gt_seg = np.zeros(shape=(org_gt_seg.shape[0],), dtype=int) # [0, ..., n_features]
gt_n_seg = np.zeros(shape=(self.n_class,), dtype=int)
gt_cls = np.zeros(shape=(self.n_class,), dtype=int)
gt_cls_dist = np.zeros(shape=(self.n_class,), dtype=float)
# copy invalid segments
gt_n_seg[0] = org_gt_n_seg[0]
gt_cls[0] = org_gt_cls[0]
gt_cls_dist[0] = org_gt_cls_dist[0]
# print('org: ', org_n_class, org_gt_seg.shape, org_gt_n_seg.shape, org_gt_seg)
# copy valid segments
gt_seg[org_gt_seg > 0] = org_gt_seg[org_gt_seg > 0] + offset - 1 # [0, ..., 1023]
gt_n_seg[offset:offset + org_n_class - 1] = org_gt_n_seg[1:] # [0...,n_seg]
gt_cls[offset:offset + org_n_class - 1] = org_gt_cls[1:] # [0, ..., n_seg]
gt_cls_dist[offset:offset + org_n_class - 1] = org_gt_cls_dist[1:] # [0, ..., n_seg]
out['gt_seg'] = gt_seg
out['gt_cls'] = gt_cls
out['gt_cls_dist'] = gt_cls_dist
out['gt_n_seg'] = gt_n_seg
# print('gt: ', org_n_class, gt_seg.shape, gt_n_seg.shape, gt_seg)
out['scene_name'] = scene_name
# out['org_gt_seg'] = org_gt_seg
# out['org_gt_n_seg'] = org_gt_n_seg
# out['org_gt_cls'] = org_gt_cls
# out['org_gt_cls_dist'] = org_gt_cls_dist
return out