from typing import TYPE_CHECKING
from threading import Thread, Event
from queue import Queue
import time
import numpy as np
import torch
from easydict import EasyDict
from ding.framework import task
from ding.data import Dataset, DataLoader
from ding.utils import get_rank, get_world_size

if TYPE_CHECKING:
    from ding.framework import OfflineRLContext


class OfflineMemoryDataFetcher:
    """
    Middleware that prefetches offline RL training batches on a background
    producer thread and hands them to the training context via a bounded queue.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode, only processes holding the FETCHER role
        # instantiate this middleware; all others get a no-op placeholder.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(OfflineMemoryDataFetcher, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, dataset: Dataset):
        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        if device != 'cpu':
            # Dedicated CUDA stream so host-to-device copies can overlap with compute.
            stream = torch.cuda.Stream()

        def producer(queue, dataset, batch_size, device, event):
            torch.set_num_threads(4)
            sbatch_size = batch_size * get_world_size()
            rank = get_rank()
            # Shard the shuffled indices across ranks: each global batch covers
            # `sbatch_size` consecutive indices, and this rank takes its own
            # `batch_size`-sized slice out of every such span.
            idx_list = np.random.permutation(len(dataset))
            temp_idx_list = []
            for i in range(len(dataset) // sbatch_size):
                start = i * sbatch_size + rank * batch_size
                temp_idx_list.extend(idx_list[start:start + batch_size])
            idx_iter = iter(temp_idx_list)

            def fetch_batch():
                # Draw one batch of samples; when the index iterator is
                # exhausted, reshuffle the full dataset and start a new epoch.
                nonlocal idx_iter, idx_list
                data = []
                for _ in range(batch_size):
                    try:
                        data.append(dataset[next(idx_iter)])
                    except StopIteration:
                        idx_list = np.random.permutation(len(dataset))
                        idx_iter = iter(idx_list)
                        data.append(dataset[next(idx_iter)])
                # Transpose the list of samples into per-field lists, then
                # stack each field into a batch tensor on the target device
                # (`.to('cpu')` is a no-op for CPU tensors).
                data = [[sample[j] for sample in data] for j in range(len(data[0]))]
                return [torch.stack(x).to(device) for x in data]

            def run_loop():
                while not event.is_set():
                    if queue.full():
                        time.sleep(0.1)
                    else:
                        queue.put(fetch_batch())

            if device != 'cpu':
                # Issue the host-to-device copies on the dedicated CUDA stream.
                with torch.cuda.stream(stream):
                    run_loop()
            else:
                run_loop()

        self.queue = Queue(maxsize=50)  # bounded buffer of prefetched batches
        self.event = Event()
        self.producer_thread = Thread(
            target=producer,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
            name='cuda_fetcher_producer'
        )

    def __call__(self, ctx: "OfflineRLContext"):
        # Lazily start the producer thread on the first call, then block
        # until a prefetched batch is available.
        if not self.producer_thread.is_alive():
            time.sleep(5)
            self.producer_thread.start()
        while self.queue.empty():
            time.sleep(0.001)
        ctx.train_data = self.queue.get()

    def __del__(self):
        if self.producer_thread.is_alive():
            # Signal the producer loop to exit; it checks the event on every
            # iteration, so no blocking join is needed here.
            self.event.set()
            del self.queue