File size: 2,407 Bytes
2252f3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from __future__ import division
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
class RandomSampler(Sampler):
    """Sample dataset indices in a random order that can be resumed mid-epoch.

    The full ordering for the epoch is kept in ``self.dataset_perm``; the indices
    still to be yielded are kept in ``self.perm``. When a checkpoint with a saved
    ordering is given, that ordering is reused and the first
    ``checkpoint['batch_size'] * checkpoint['batch_idx']`` indices are skipped.

    Args:
        data_source: dataset to sample from; only ``len(data_source)`` is used.
        checkpoint: ``None`` or a mapping with keys ``'dataset_perm'``,
            ``'batch_size'`` and ``'batch_idx'``. If ``'dataset_perm'`` is
            ``None``, a fresh random permutation is drawn.
    """

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        if checkpoint is not None and checkpoint['dataset_perm'] is not None:
            self.dataset_perm = checkpoint['dataset_perm']
            # Skip the samples already consumed by batches completed before the checkpoint.
            self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:]
        else:
            self.dataset_perm = torch.randperm(len(self.data_source)).tolist()
            # Iterate exactly the order stored in dataset_perm: resuming slices
            # dataset_perm, so the served order must be that same permutation.
            # Drawing a second, independent randperm here made a resumed run skip
            # or repeat samples.
            self.perm = self.dataset_perm

    def __iter__(self):
        return iter(self.perm)

    def __len__(self):
        return len(self.perm)
class SequentialSampler(Sampler):
    """Yield dataset indices in a fixed order, with support for mid-epoch resume.

    Without a checkpoint (or with a checkpoint whose ``'dataset_perm'`` is
    ``None``), the order is simply ``0 .. len(data_source) - 1``. Otherwise the
    ordering saved under ``checkpoint['dataset_perm']`` is reused, and iteration
    starts at offset ``checkpoint['batch_size'] * checkpoint['batch_idx']``.

    Attributes:
        data_source: the dataset passed in; only its length is consulted.
        dataset_perm: full index ordering for the epoch.
        perm: the suffix of ``dataset_perm`` still to be yielded.
    """

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        saved_order = None if checkpoint is None else checkpoint['dataset_perm']
        if saved_order is None:
            # Fresh epoch: natural index order, nothing skipped.
            self.dataset_perm = list(range(len(self.data_source)))
            self.perm = self.dataset_perm
            return
        self.dataset_perm = saved_order
        # Drop the indices already served by batches completed before the checkpoint.
        offset = checkpoint['batch_size'] * checkpoint['batch_idx']
        self.perm = saved_order[offset:]

    def __iter__(self):
        return iter(self.perm)

    def __len__(self):
        return len(self.perm)
class CheckpointDataLoader(DataLoader):
    """
    Extends torch.utils.data.DataLoader to handle resuming training from an arbitrary point within an epoch.

    Sample order is driven by ``RandomSampler`` (``shuffle=True``) or
    ``SequentialSampler`` (``shuffle=False``), each constructed from
    ``checkpoint``.

    Args:
        dataset: dataset to load from.
        checkpoint: ``None`` or a mapping. This class reads ``'batch_idx'``;
            the samplers also read ``'dataset_perm'`` and ``'batch_size'``.
        batch_size, num_workers, pin_memory, drop_last, timeout, worker_init_fn:
            forwarded unchanged to ``DataLoader``.
        shuffle: selects the random sampler instead of the sequential one.
            ``DataLoader`` itself always receives ``shuffle=False`` because an
            explicit sampler is supplied.

    Attributes:
        checkpoint_batch_idx: ``checkpoint['batch_idx']``, or 0 without a
            checkpoint. Presumably read by the training loop to offset its batch
            counter (confirm with callers).
    """

    def __init__(
        self,
        dataset,
        checkpoint=None,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        timeout=0,
        worker_init_fn=None
    ):
        if shuffle:
            sampler = RandomSampler(dataset, checkpoint)
        else:
            sampler = SequentialSampler(dataset, checkpoint)
        if checkpoint is not None:
            self.checkpoint_batch_idx = checkpoint['batch_idx']
        else:
            self.checkpoint_batch_idx = 0
        super(CheckpointDataLoader, self).__init__(
            dataset,
            sampler=sampler,
            # Ordering is fully controlled by `sampler`; DataLoader rejects
            # shuffle=True together with an explicit sampler.
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            timeout=timeout,
            # Forward the caller's hook instead of silently discarding it.
            worker_init_fn=worker_init_fn,
        )
|