|
import torch |
|
from torch.utils.data import get_worker_info |
|
from torch.utils.data import DataLoader |
|
|
|
import random |
|
import time |
|
|
|
from functools import partial |
|
|
|
from itertools import chain |
|
|
|
|
|
from petrel_client.utils.data import DataLoader as MyDataLoader |
|
# Rebind the petrel loader with fixed defaults: keep worker processes alive
# across epochs and prefetch 4 batches per worker.
MyDataLoader = partial(MyDataLoader, persistent_workers=True, prefetch_factor=4)
|
|
|
|
|
def assert_equal(lhs, rhs):
    """Recursively assert that two nested batch structures are identical.

    Dicts must have equal key sets, lists/tuples equal lengths, tensors must
    satisfy ``torch.equal``; any other leaf is compared with ``==``.

    Args:
        lhs: reference structure (e.g. the list of batches from one loader).
        rhs: structure to compare against ``lhs``.

    Raises:
        AssertionError: on the first mismatch, with a path such as
            ``root[3]['val']`` pointing at it. Raised explicitly rather than
            via ``assert`` so the check still runs under ``python -O``.
    """
    _assert_equal_at(lhs, rhs, 'root')


def _assert_equal_at(lhs, rhs, path):
    """Compare ``lhs`` with ``rhs``; ``path`` locates them for error messages."""
    if isinstance(lhs, dict):
        if not isinstance(rhs, dict):
            raise AssertionError('%s: expected dict, got %s' % (path, type(rhs).__name__))
        if lhs.keys() != rhs.keys():
            raise AssertionError('%s: key mismatch %r vs %r' % (path, list(lhs), list(rhs)))
        for key in lhs:
            _assert_equal_at(lhs[key], rhs[key], '%s[%r]' % (path, key))
    elif isinstance(lhs, (list, tuple)):
        if not isinstance(rhs, (list, tuple)):
            raise AssertionError('%s: expected sequence, got %s' % (path, type(rhs).__name__))
        if len(lhs) != len(rhs):
            raise AssertionError('%s: length mismatch %d vs %d' % (path, len(lhs), len(rhs)))
        for index, (left, right) in enumerate(zip(lhs, rhs)):
            _assert_equal_at(left, right, '%s[%d]' % (path, index))
    elif isinstance(lhs, torch.Tensor):
        # torch.equal raises TypeError on a non-tensor, so check the type first.
        if not isinstance(rhs, torch.Tensor) or not torch.equal(lhs, rhs):
            raise AssertionError('%s: tensors differ' % path)
    elif not lhs == rhs:
        raise AssertionError('%s: %r != %r' % (path, lhs, rhs))
|
|
|
|
|
def wait(dt):
    """Pause the calling thread for ``dt`` seconds using ``time.sleep``."""
    seconds = dt
    time.sleep(seconds)
    return None
|
|
|
|
|
class Dataset(list):
    """List-backed map-style dataset that simulates a variable fetch latency.

    Each ``__getitem__`` sleeps 0.05, 0.10, 0.15 or 0.20 s (chosen
    pseudo-randomly, doubled inside DataLoader worker 0) and returns the indexed
    value wrapped as ``{'val': value}``.
    """

    def __init__(self, *args, **kwargs):
        super(Dataset, self).__init__(*args, **kwargs)
        self._seed_inited = False
        # Private RNG, created lazily in whichever process fetches items.
        # Seeding the module-level ``random`` instead would reset the caller's
        # RNG stream whenever items are fetched in the main process
        # (num_workers=0).
        self._rng = None

    def __getitem__(self, *args, **kwargs):
        worker_info = get_worker_info()
        if not self._seed_inited:
            # Seed once per process: 0 in the main process, the worker id in a
            # DataLoader worker, so each worker's latency sequence is
            # reproducible. random.Random(seed) yields the same sequence as
            # random.seed(seed) followed by random.randint.
            seed = 0 if worker_info is None else worker_info.id
            self._rng = random.Random(seed)
            self._seed_inited = True
        time_to_sleep = self._rng.randint(1, 4) * 0.05
        if worker_info is not None and worker_info.id == 0:
            # Make worker 0 a deliberate straggler.
            time_to_sleep *= 2
        wait(time_to_sleep)
        val = super(Dataset, self).__getitem__(*args, **kwargs)
        return {'val': val}
|
|
|
|
|
def test(dataloader, result):
    """Iterate ``dataloader`` for two epochs while simulating per-step compute.

    Prints the time spent waiting for each batch (ms, ten per row), then the
    total elapsed time, the accumulated data-wait time, and the loader type.

    Args:
        dataloader: re-iterable yielding batches; it is iterated twice
            back-to-back via ``itertools.chain``.
        result: list to which every received batch is appended, in order.
    """
    print('\ntest')
    # Private RNG seeded like the former global ``random.seed(0)``: same
    # step-time sequence, without mutating the module-level RNG state.
    rng = random.Random(0)
    data_time = 0.0
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    tstart = t1 = time.perf_counter()
    for i, data in enumerate(chain(dataloader, dataloader), 1):
        t2 = time.perf_counter()
        d = t2 - t1  # time blocked waiting for this batch
        print('{0:>5}'.format(int(d * 1000)), end='')
        # Tab-separate entries; start a new row after every 10th batch.
        print('\t' if i % 10 else '\n', end='')

        result.append(data)
        data_time += d

        # Simulated compute of 50, 100 or 150 ms (randrange excludes 4).
        wait(0.05 * rng.randrange(1, 4))

        t1 = time.perf_counter()
    tend = time.perf_counter()
    print('\ntotal time: %.3f' % (tend - tstart))
    print('total data time: %.3f' % data_time)
    print(type(dataloader))
|
|
|
|
|
def worker_init_fn(worker_id):
    """DataLoader worker start hook: announce ``worker_id``, then stall 3 s.

    The fixed delay presumably makes worker start-up cost visible in the
    benchmark timings — confirm intent.
    """
    startup_delay = 3
    print('start worker:', worker_id)
    wait(startup_delay)
|
|
|
|
|
# Shared loader configuration; both loaders below read the same Dataset instance.
dataloader_args = {
    'dataset': Dataset(range(1024)),
    'drop_last': False,
    'shuffle': False,
    'batch_size': 32,
    'num_workers': 8,
    'worker_init_fn': worker_init_fn,
}


if __name__ == '__main__':
    # Guard required for multiprocessing start methods that re-import the main
    # module in each worker (spawn/forkserver): without it every worker would
    # re-run this benchmark and try to start workers of its own.

    # Petrel loader first, then the stock torch loader, each from the same seed.
    torch.manual_seed(0)
    l2 = MyDataLoader(**dataloader_args)
    r2 = []
    test(l2, r2)

    torch.manual_seed(0)
    l1 = DataLoader(**dataloader_args)
    r1 = []
    test(l1, r1)

    print('len l1:', len(l1))
    print('len l2:', len(l2))

    # Both loaders must have produced identical batches in identical order.
    assert_equal(r1, r2)
    print(torch)
|
|