diffumatch / shape_data /__init__.py
daidedou
forgot a few things lol
e321b92
raw
history blame
2.42 kB
import sys
import os.path as osp
import numpy as np
import torch
from collections import defaultdict
# Repository root: the parent of this package's directory (left in the
# un-normalised '<pkg_dir>/../' form). It is appended to sys.path only if not
# already present -- presumably so sibling top-level packages such as
# `diffusion_net` (imported lazily in prepare_batch) resolve.
_PKG_DIR = osp.abspath(osp.dirname(__file__))
ROOT_DIR = osp.join(_PKG_DIR, '../')
del _PKG_DIR
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)
# Dataset key -> folder name under the data root (joined with `root` in
# get_data_dirs). Most "<name>_ori" keys select the "*_r_ori" folder variant.
# All three DT4D keys ('dt4d', 'dt4dintra', 'dt4dinter') share 'DT4D_r', and
# their "_ori" variants share 'DT4D_r_ori'.
# NOTE(review): 'tosca_ori' maps to 'TOSCA_r' (not 'TOSCA_r_ori' like every
# other *_ori key) -- confirm this is intentional.
DATA_DIRS = dict(
    faust='FAUST_r',
    faust_ori='FAUST_r_ori',
    scape='SCAPE_r',
    scape_ori='SCAPE_r_ori',
    smalr='SMAL_r',
    smalr_ori='SMAL_r_ori',
    shrec19='SHREC_r',
    shrec19_ori='SHREC_r_ori',
    dt4d='DT4D_r',
    dt4dintra='DT4D_r',
    dt4dintra_ori='DT4D_r_ori',
    dt4dinter='DT4D_r',
    dt4dinter_ori='DT4D_r_ori',
    tosca='TOSCA_r',
    tosca_ori='TOSCA_r',
)
def get_data_dirs(root, name, mode):
    """Resolve the on-disk locations of a dataset registered in ``DATA_DIRS``.

    Args:
        root: Directory that contains the dataset folders.
        name: Key into ``DATA_DIRS`` (e.g. ``'faust'``, ``'scape_ori'``).
        mode: Accepted for interface compatibility; currently unused.

    Returns:
        Tuple ``(shape_dir, dir_name, corr_dir)``:
        ``shape_dir`` is ``<root>/<dir_name>/shapes``; ``dir_name`` is the bare
        folder name from ``DATA_DIRS`` (NOT joined with ``root``); ``corr_dir``
        is ``<root>/<dir_name>/correspondences``.

    Raises:
        KeyError: if ``name`` is not a key of ``DATA_DIRS``. The message lists
            the valid names (same exception type as before, more informative).
    """
    try:
        dir_name = DATA_DIRS[name]
    except KeyError:
        valid = ', '.join(sorted(DATA_DIRS))
        raise KeyError(
            'Unknown dataset {!r}; expected one of: {}'.format(name, valid)
        ) from None
    prefix = osp.join(root, dir_name)
    shape_dir = osp.join(prefix, 'shapes')
    corr_dir = osp.join(prefix, 'correspondences')
    return shape_dir, dir_name, corr_dir
# def collate_default(data_list):
# data_dict = defaultdict(list)
# for pair_dict in data_list:
# for k, v in pair_dict.items():
# data_dict[k].append(v)
# for k in data_dict.keys():
# if k.startswith('fmap') or k.startswith('evals') or k.endswith('_sub'):
# data_dict[k] = np.stack(data_dict[k], axis=0)
# batch_size = len(data_list)
# for k, v in data_dict.items():
# assert len(v) == batch_size
# return data_dict
def prepare_batch(data_dict, device):
    """Convert a collated batch dict to torch tensors on ``device``, in place.

    Per key, in this order:

    * ``np.ndarray`` value -> ``torch.from_numpy(value).to(device)``.
    * Key starting with ``'gradX'``, ``'gradY'`` or ``'L'`` -> every element is
      converted with ``diffusion_net.utils.sparse_np_to_torch`` and moved to
      ``device`` (presumably the elements are sparse matrices -- whatever that
      helper accepts).
    * Any other non-empty sequence whose first element is an ``np.ndarray`` ->
      every element converted with ``torch.from_numpy`` and moved to ``device``.
    * Anything else (scalars, strings, ``None``, empty lists, tensors, ...)
      is left untouched.

    For the two per-element cases, a single-element sequence is stacked into
    one tensor with a leading dimension of size 1; longer sequences become a
    list of tensors.

    Args:
        data_dict: Dict of batch fields; its values are replaced in place.
        device: Any value accepted by ``torch.Tensor.to``.

    Returns:
        The same ``data_dict`` object.
    """
    from collections.abc import Sequence

    for key, value in data_dict.items():
        if isinstance(value, np.ndarray):
            data_dict[key] = torch.from_numpy(value).to(device)
            continue

        if key.startswith(('gradX', 'gradY', 'L')):
            # Imported lazily so the dependency is only required when a batch
            # actually carries these operator fields.
            from diffusion_net.utils import sparse_np_to_torch
            tensors = [sparse_np_to_torch(item).to(device) for item in value]
        elif (isinstance(value, Sequence) and len(value) > 0
              and isinstance(value[0], np.ndarray)):
            # Checking for a non-empty sequence first avoids IndexError /
            # TypeError on empty lists, scalars, None or 0-dim tensors.
            tensors = [torch.from_numpy(item).to(device) for item in value]
        else:
            continue

        # A batch of one is stacked into a single tensor; larger batches keep
        # per-sample tensors (presumably because sample sizes can differ).
        data_dict[key] = (torch.stack(tensors, dim=0)
                          if len(value) == 1 else tensors)

    return data_dict