import pytest
import threading
import time
import torch
import torch.nn as nn
from functools import partial
from itertools import product

from ding.utils import EasyTimer
from ding.utils.data import AsyncDataLoader

# parameter grids: the GPU test sweeps the full product; the CPU test runs one small case
batch_size_args = [3, 6]
num_workers_args = [0, 4]
chunk_size_args = [1, 3]
args = list(product(batch_size_args, num_workers_args, chunk_size_args))
unittest_args = list(product([3], [2], [1]))


class Dataset(object):
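    """Toy dataset whose ``__getitem__`` sleeps 0.5 s to emulate a slow data source."""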

    def __init__(self):
        self.data = torch.randn(256, 256)

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        time.sleep(0.5)
        return [self.data, idx]


class TestAsyncDataLoader:

    def get_data_source(self):
        dataset = Dataset()

        def data_source_fn(batch_size):
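            # one zero-argument callable per sample; the loader invokes them to build a batch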
            return [partial(dataset.__getitem__, idx=i) for i in range(batch_size)]

        return data_source_fn

    def get_model(self):

        class Model(nn.Module):

            def __init__(self):
                super(Model, self).__init__()
                self.main = nn.Sequential(*[nn.Linear(256, 256) for _ in range(10)])

            def forward(self, x):
                idx = x[1]
                x = self.main(x[0])
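                # emulate 1 s of model computation so data loading can overlap with it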
                time.sleep(1)
                return [x, idx]

        return Model()

    # @pytest.mark.unittest
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', unittest_args)
    def test_cpu(self, batch_size, num_workers, chunk_size):
        self.entry(batch_size, num_workers, chunk_size, use_cuda=False)

    @pytest.mark.cudatest
    @pytest.mark.parametrize('batch_size, num_workers, chunk_size', args)
    def test_gpu(self, batch_size, num_workers, chunk_size):
        self.entry(batch_size, num_workers, chunk_size, use_cuda=True)
        torch.cuda.empty_cache()

    def entry(self, batch_size, num_workers, chunk_size, use_cuda):
        model = self.get_model()
        if use_cuda:
            model.cuda()
        timer = EasyTimer()
        data_source = self.get_data_source()
        device = 'cuda' if use_cuda else 'cpu'
        dataloader = AsyncDataLoader(data_source, batch_size, device, num_workers=num_workers, chunk_size=chunk_size)
        count = 0
        total_data_time = 0.
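        # run 10 iterations, timing the data fetch and the model forward separately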
        while True:
            with timer:
                data = next(dataloader)
            data_time = timer.value
            if count > 2:  # skip the first 3 iterations to exclude warm-up overhead
                total_data_time += data_time
            with timer:
                with torch.no_grad():
                    _, idx = model(data)
                if use_cuda:
                    idx = idx.cpu()
                sorted_idx = torch.sort(idx)[0]
                assert sorted_idx.eq(torch.arange(batch_size)).sum() == batch_size, idx
            model_time = timer.value
            print('count {}, data_time: {}, model_time: {}'.format(count, data_time, model_time))
            count += 1
            if count == 10:
                break
        if num_workers < 1:
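            # without extra workers, each counted batch costs about batch_size * 0.5 s
            # of sequential reads, minus the ~1 s of model compute it overlaps with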
            assert total_data_time <= 7 * batch_size * 0.5 + 7 * 0.01 - 7 * 1
        else:
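            # with worker prefetching, fetching a ready batch should be nearly free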
            assert total_data_time <= 7 * 0.008
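        # explicitly tear down the loader and verify its worker threads have exited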
        dataloader.__del__()
        time.sleep(0.5)
        assert len(threading.enumerate()) <= 2, threading.enumerate()