Spaces:
Configuration error
Configuration error
""" | |
Copyright (c) Microsoft Corporation. | |
Licensed under the MIT license. | |
""" | |
import os.path as op | |
import torch | |
import logging | |
import code | |
from custom_mesh_graphormer.utils.comm import get_world_size | |
from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset) | |
from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset) | |
def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a ``MeshTSVYamlDataset`` from a dataset YAML config.

    Args:
        yaml_file: Path to the dataset YAML. If it is not an existing file,
            it is resolved relative to ``args.data_dir``.
        args: Namespace-like options object; only ``args.data_dir`` is read
            here (and only when ``yaml_file`` is not found as given).
        is_train: Forwarded to ``MeshTSVYamlDataset``.
        scale_factor: Forwarded to ``MeshTSVYamlDataset``.

    Returns:
        The constructed ``MeshTSVYamlDataset``. The third positional argument
        is passed as ``False``; its meaning is defined by ``MeshTSVYamlDataset``.

    Raises:
        FileNotFoundError: If the YAML file exists neither as given nor under
            ``args.data_dir``.
    """
    print(yaml_file)
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would defer the failure into the dataset loader.
    if not op.isfile(yaml_file):
        raise FileNotFoundError("Dataset yaml file not found: {}".format(yaml_file))
    return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)
class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.

    Counting starts at ``start_iter``, so a fresh iterator yields
    ``num_iterations - start_iter`` batches (none when
    ``start_iter >= num_iterations``), re-iterating the wrapped
    ``batch_sampler`` as many times as needed. If the wrapped sampler
    produces no batches in a pass, iteration stops instead of looping forever.

    Args:
        batch_sampler: Iterable of batches, typically a ``BatchSampler``.
        num_iterations: Iteration count at which to stop.
        start_iter: Iteration count to resume from.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration < self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it.
            # NOTE: the current iteration count (not an epoch index) is passed.
            inner_sampler = getattr(self.batch_sampler, "sampler", None)
            if hasattr(inner_sampler, "set_epoch"):
                inner_sampler.set_epoch(iteration)
            pass_start = iteration
            for batch in self.batch_sampler:
                iteration += 1
                yield batch
                # Stop right after the last requested batch, without pulling
                # an extra batch or starting a needless new pass.
                if iteration >= self.num_iterations:
                    return
            if iteration == pass_start:
                # Empty wrapped sampler: ``iteration`` can never advance, so
                # re-iterating would spin forever.
                return

    def __len__(self):
        # Reports the total ``num_iterations`` even when resuming from a
        # non-zero ``start_iter``; callers may use this as the total step count.
        return self.num_iterations
def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
    """Group the indices produced by ``sampler`` into batches.

    Batches hold ``images_per_gpu`` indices each; the final batch of a pass
    may be smaller (``drop_last=False``).

    When ``num_iters`` is given and non-negative, the batches are wrapped in
    ``IterationBasedBatchSampler`` so they are re-drawn until ``num_iters``
    iterations are reached, counting from ``start_iter``. Otherwise the plain
    ``BatchSampler`` (a single pass) is returned.
    """
    batches = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_gpu, drop_last=False
    )
    if num_iters is None or num_iters < 0:
        return batches
    return IterationBasedBatchSampler(batches, num_iters, start_iter)
def make_data_sampler(dataset, shuffle, distributed):
    """Return the index sampler used to draw items from ``dataset``.

    Args:
        dataset: The dataset to sample from.
        shuffle: Whether to visit indices in random order.
        distributed: If true, return a ``DistributedSampler`` (which takes
            ``shuffle`` as a keyword); otherwise a local ``RandomSampler`` or
            ``SequentialSampler`` chosen by ``shuffle``.
    """
    data_utils = torch.utils.data
    if not distributed:
        local_cls = (data_utils.sampler.RandomSampler if shuffle
                     else data_utils.sampler.SequentialSampler)
        return local_cls(dataset)
    return data_utils.distributed.DistributedSampler(dataset, shuffle=shuffle)
def make_data_loader(args, yaml_file, is_distributed=True,
                     is_train=True, start_iter=0, scale_factor=1):
    """Create a DataLoader over the dataset built by ``build_dataset``.

    When ``is_train == True`` the indices are shuffled, batches use
    ``args.per_gpu_train_batch_size``, and the batch sampler is given an
    iteration budget of
    ``(len(dataset) // (per-GPU batch * world size)) * args.num_train_epochs``,
    resumed from ``start_iter``. Otherwise indices are read in order with
    ``args.per_gpu_eval_batch_size`` for a single pass, and ``start_iter`` is
    ignored.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``args.num_workers`` workers
        and ``pin_memory=True``.
    """
    dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)

    # Strict comparison kept on purpose: only values equal to True train.
    training = bool(is_train == True)  # noqa: E712
    if not training:
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None
        start_iter = 0  # evaluation always starts from the beginning
    else:
        images_per_gpu = args.per_gpu_train_batch_size
        global_batch = images_per_gpu * get_world_size()
        # Full global batches per epoch, times the number of epochs.
        num_iters = (len(dataset) // global_batch) * args.num_train_epochs
        logger.info("Train with {} images per GPU.".format(images_per_gpu))
        logger.info("Total batch size {}".format(global_batch))
        logger.info("Total training steps {}".format(num_iters))

    index_sampler = make_data_sampler(dataset, training, is_distributed)
    batch_sampler = make_batch_data_sampler(index_sampler, images_per_gpu, num_iters, start_iter)
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )
#============================================================================================== | |
def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a ``HandMeshTSVYamlDataset`` from a dataset YAML config.

    Args:
        yaml_file: Path to the dataset YAML. If it is not an existing file,
            it is resolved relative to ``args.data_dir``.
        args: Namespace-like options object; read here for ``data_dir`` and
            also forwarded as the first argument of ``HandMeshTSVYamlDataset``.
        is_train: Forwarded to ``HandMeshTSVYamlDataset``.
        scale_factor: Forwarded to ``HandMeshTSVYamlDataset``.

    Returns:
        The constructed ``HandMeshTSVYamlDataset``. The fourth positional
        argument is passed as ``False``; its meaning is defined by
        ``HandMeshTSVYamlDataset``.

    Raises:
        FileNotFoundError: If the YAML file exists neither as given nor under
            ``args.data_dir``.
    """
    print(yaml_file)
    if not op.isfile(yaml_file):
        yaml_file = op.join(args.data_dir, yaml_file)
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would defer the failure into the dataset loader.
    if not op.isfile(yaml_file):
        raise FileNotFoundError("Dataset yaml file not found: {}".format(yaml_file))
    return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)
def make_hand_data_loader(args, yaml_file, is_distributed=True,
                          is_train=True, start_iter=0, scale_factor=1):
    """Create a DataLoader over the dataset built by ``build_hand_dataset``.

    When ``is_train == True`` the indices are shuffled, batches use
    ``args.per_gpu_train_batch_size``, and the batch sampler is given an
    iteration budget of
    ``(len(dataset) // (per-GPU batch * world size)) * args.num_train_epochs``,
    resumed from ``start_iter``. Otherwise indices are read in order with
    ``args.per_gpu_eval_batch_size`` for a single pass, and ``start_iter`` is
    ignored.

    Returns:
        A ``torch.utils.data.DataLoader`` using ``args.num_workers`` workers
        and ``pin_memory=True``.
    """
    dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)

    # Strict comparison kept on purpose: only values equal to True train.
    training = bool(is_train == True)  # noqa: E712
    if training:
        per_gpu = args.per_gpu_train_batch_size
        world_batch = per_gpu * get_world_size()
        steps_per_epoch = len(dataset) // world_batch
        total_steps = steps_per_epoch * args.num_train_epochs
        for message in (
            "Train with {} images per GPU.".format(per_gpu),
            "Total batch size {}".format(world_batch),
            "Total training steps {}".format(total_steps),
        ):
            logger.info(message)
    else:
        per_gpu = args.per_gpu_eval_batch_size
        total_steps = None
        start_iter = 0  # evaluation always starts from the beginning

    batch_sampler = make_batch_data_sampler(
        make_data_sampler(dataset, training, is_distributed),
        per_gpu,
        total_steps,
        start_iter,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )