"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
"""
import os.path as op
import torch
import logging
from custom_mesh_graphormer.utils.comm import get_world_size
from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)
def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
print(yaml_file)
if not op.isfile(yaml_file):
yaml_file = op.join(args.data_dir, yaml_file)
    assert op.isfile(yaml_file), "Cannot find dataset yaml file: {}".format(yaml_file)
return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)
class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
"""
Wraps a BatchSampler, resampling from it until
a specified number of iterations have been sampled
"""
def __init__(self, batch_sampler, num_iterations, start_iter=0):
self.batch_sampler = batch_sampler
self.num_iterations = num_iterations
self.start_iter = start_iter
def __iter__(self):
iteration = self.start_iter
while iteration <= self.num_iterations:
# if the underlying sampler has a set_epoch method, like
# DistributedSampler, used for making each process see
# a different split of the dataset, then set it
if hasattr(self.batch_sampler.sampler, "set_epoch"):
self.batch_sampler.sampler.set_epoch(iteration)
for batch in self.batch_sampler:
iteration += 1
if iteration > self.num_iterations:
break
yield batch
def __len__(self):
return self.num_iterations
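# Illustrative sketch (not part of the training pipeline): IterationBasedBatchSampler
# keeps re-walking the wrapped BatchSampler until exactly num_iterations batches have
# been yielded, so training length is controlled by iterations rather than epochs.
#
#     sampler = torch.utils.data.sampler.RandomSampler(range(10))
#     batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size=4, drop_last=False)
#     iter_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=100)
#     assert sum(1 for _ in iter_sampler) == 100  # 100 batches, spanning many passes over the 10 items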
def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
batch_sampler = torch.utils.data.sampler.BatchSampler(
sampler, images_per_gpu, drop_last=False
)
if num_iters is not None and num_iters >= 0:
batch_sampler = IterationBasedBatchSampler(
batch_sampler, num_iters, start_iter
)
return batch_sampler
def make_data_sampler(dataset, shuffle, distributed):
if distributed:
return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
sampler = torch.utils.data.sampler.RandomSampler(dataset)
else:
sampler = torch.utils.data.sampler.SequentialSampler(dataset)
return sampler
def make_data_loader(args, yaml_file, is_distributed=True,
is_train=True, start_iter=0, scale_factor=1):
dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
logger = logging.getLogger(__name__)
    if is_train:
shuffle = True
images_per_gpu = args.per_gpu_train_batch_size
images_per_batch = images_per_gpu * get_world_size()
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
logger.info("Train with {} images per GPU.".format(images_per_gpu))
logger.info("Total batch size {}".format(images_per_batch))
logger.info("Total training steps {}".format(num_iters))
else:
shuffle = False
images_per_gpu = args.per_gpu_eval_batch_size
num_iters = None
start_iter = 0
sampler = make_data_sampler(dataset, shuffle, is_distributed)
batch_sampler = make_batch_data_sampler(
sampler, images_per_gpu, num_iters, start_iter
)
data_loader = torch.utils.data.DataLoader(
dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
pin_memory=True,
)
return data_loader
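# Typical call site (sketch): the attributes read from `args` are data_dir,
# per_gpu_train_batch_size / per_gpu_eval_batch_size, num_train_epochs and
# num_workers; the YAML path below is purely illustrative.
#
#     args = argparse.Namespace(data_dir='datasets', per_gpu_train_batch_size=32,
#                               per_gpu_eval_batch_size=32, num_train_epochs=200,
#                               num_workers=4)
#     train_loader = make_data_loader(args, 'human3.6m/train.yaml',
#                                     is_distributed=False, is_train=True)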
#==============================================================================================
def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
print(yaml_file)
if not op.isfile(yaml_file):
yaml_file = op.join(args.data_dir, yaml_file)
    assert op.isfile(yaml_file), "Cannot find dataset yaml file: {}".format(yaml_file)
return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)
def make_hand_data_loader(args, yaml_file, is_distributed=True,
is_train=True, start_iter=0, scale_factor=1):
dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
logger = logging.getLogger(__name__)
    if is_train:
shuffle = True
images_per_gpu = args.per_gpu_train_batch_size
images_per_batch = images_per_gpu * get_world_size()
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
logger.info("Train with {} images per GPU.".format(images_per_gpu))
logger.info("Total batch size {}".format(images_per_batch))
logger.info("Total training steps {}".format(num_iters))
else:
shuffle = False
images_per_gpu = args.per_gpu_eval_batch_size
num_iters = None
start_iter = 0
sampler = make_data_sampler(dataset, shuffle, is_distributed)
batch_sampler = make_batch_data_sampler(
sampler, images_per_gpu, num_iters, start_iter
)
data_loader = torch.utils.data.DataLoader(
dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
pin_memory=True,
)
return data_loader
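# Minimal self-check of the sampler utilities (a sketch for manual runs; it does not
# touch any TSV/YAML data, only the batch-sampler plumbing defined above).
if __name__ == "__main__":
    _dummy = torch.utils.data.TensorDataset(torch.arange(100))
    _sampler = make_data_sampler(_dummy, shuffle=True, distributed=False)
    _batch_sampler = make_batch_data_sampler(_sampler, images_per_gpu=8, num_iters=50, start_iter=0)
    assert len(_batch_sampler) == 50
    assert sum(1 for _ in _batch_sampler) == 50
    print("sampler utilities: OK")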