|
from typing import TYPE_CHECKING |
|
from threading import Thread, Event |
|
from queue import Queue |
|
import time |
|
import numpy as np |
|
import torch |
|
from easydict import EasyDict |
|
from ding.framework import task |
|
from ding.data import Dataset, DataLoader |
|
from ding.utils import get_rank, get_world_size |
|
|
|
if TYPE_CHECKING: |
|
from ding.framework import OfflineRLContext |
|
|
|
|
|
class OfflineMemoryDataFetcher:
    """Prefetch training batches from an in-memory dataset on a background thread.

    Each call stores one batch in ``ctx.train_data``. A batch is a list holding one
    tensor per field of the dataset items: the items returned by ``dataset[idx]`` are
    regrouped field by field and every group is combined with ``torch.stack``.
    When ``cfg.policy.cuda`` is set, the producer thread copies each batch to
    ``cuda:{rank % device_count}`` on a side CUDA stream.

    In a multi-process run every shuffled epoch is cut into global batches of
    ``batch_size * world_size`` indices, and each rank reads only its own
    ``batch_size``-long slice of every global batch, so ranks see disjoint samples.
    """

    def __new__(cls, *args, **kwargs):
        # With an active router, only nodes holding the FETCHER role run this
        # middleware; every other node gets ``task.void()`` instead of an instance.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(OfflineMemoryDataFetcher, cls).__new__(cls)

    def __init__(self, cfg: "EasyDict", dataset: "Dataset"):
        """Set up the prefetch queue and producer thread (started lazily on first call).

        Arguments:
            - cfg: config; ``cfg.policy.cuda`` selects the device and
              ``cfg.policy.batch_size`` is the per-rank batch size.
            - dataset: supports ``len()`` and integer indexing; each item is presumably
              a sequence of tensors whose shapes match across items (TODO confirm).
        """
        rank = get_rank()
        world_size = get_world_size()
        device = 'cuda:{}'.format(rank % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        # Bind the side stream to the target device, not to whichever device is current.
        stream = torch.cuda.Stream(device=device) if device != 'cpu' else None

        self.queue = Queue(maxsize=50)
        self.event = Event()
        # Daemon thread: the producer holds nothing that needs cleanup, and a
        # non-daemon thread would keep the interpreter alive at exit.
        # The target is a staticmethod, so the thread holds no reference to ``self``
        # and ``__del__`` can still run.
        self.producer_thread = Thread(
            target=OfflineMemoryDataFetcher._produce,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event, stream, rank, world_size),
            name='cuda_fetcher_producer',
            daemon=True,
        )

    @staticmethod
    def _shard_indices(perm, batch_size, rank, world_size):
        """Return this rank's slice of every complete global batch in ``perm``.

        A global batch is ``batch_size * world_size`` consecutive entries of ``perm``;
        rank ``r`` takes entries ``[r * batch_size, (r + 1) * batch_size)`` of each one.
        A trailing partial global batch is dropped.
        """
        global_batch = batch_size * world_size
        shard = []
        for i in range(len(perm) // global_batch):
            start = i * global_batch + rank * batch_size
            shard.extend(perm[start:start + batch_size])
        return shard

    @staticmethod
    def _epoch_indices(dataset_len, batch_size, rank, world_size):
        """Shuffle ``range(dataset_len)`` and return this rank's indices for one epoch.

        Raises:
            - ValueError: if the dataset is empty (no batch could ever be built).
        """
        if dataset_len == 0:
            raise ValueError('OfflineMemoryDataFetcher got an empty dataset')
        perm = np.random.permutation(dataset_len)
        shard = OfflineMemoryDataFetcher._shard_indices(perm, batch_size, rank, world_size)
        # Dataset smaller than one global batch: fall back to the whole permutation
        # (unsharded), so small datasets still yield batches.
        return shard if shard else list(perm)

    @staticmethod
    def _produce(out_queue, dataset, batch_size, device, stop_event, stream, rank, world_size):
        """Producer-thread body: build batches and push them to ``out_queue`` until ``stop_event`` is set."""
        from contextlib import nullcontext
        from queue import Full

        # NOTE(review): torch.set_num_threads is process-wide, so this also caps the
        # intra-op threads of the training code; kept as in the original implementation.
        torch.set_num_threads(4)
        fetcher_cls = OfflineMemoryDataFetcher
        idx_iter = iter(fetcher_cls._epoch_indices(len(dataset), batch_size, rank, world_size))
        stream_ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
        with stream_ctx:
            while not stop_event.is_set():
                samples = []
                for _ in range(batch_size):
                    try:
                        idx = next(idx_iter)
                    except StopIteration:
                        # Epoch exhausted: reshuffle, keeping the same per-rank sharding.
                        idx_iter = iter(fetcher_cls._epoch_indices(len(dataset), batch_size, rank, world_size))
                        idx = next(idx_iter)
                    samples.append(dataset[idx])
                # Regroup per field and stack each field along a new leading batch dim.
                batch = [torch.stack(field) for field in zip(*samples)]
                if device != 'cpu':
                    batch = [tensor.to(device) for tensor in batch]
                    # Ensure the side-stream copies are finished before the consumer,
                    # which runs on a different stream, can read them.
                    # NOTE(review): consumers may also need Tensor.record_stream to keep
                    # the caching allocator from reusing these blocks early — confirm.
                    stream.synchronize()
                # Bounded waits so a stop request is noticed even while the queue is full.
                while not stop_event.is_set():
                    try:
                        out_queue.put(batch, timeout=0.1)
                        break
                    except Full:
                        pass

    def __call__(self, ctx: "OfflineRLContext"):
        """Store the next prefetched batch in ``ctx.train_data``, blocking until one is ready.

        Raises:
            - RuntimeError: if the producer thread has exited and no batch is queued
              (e.g. the dataset raised; the thread prints its own traceback).
        """
        from queue import Empty

        if self.producer_thread.ident is None:  # thread has never been started
            # NOTE(review): fixed delay inherited from the original code; its purpose
            # is not visible here — confirm before removing.
            time.sleep(5)
            self.producer_thread.start()
        while True:
            try:
                ctx.train_data = self.queue.get(timeout=0.1)
                return
            except Empty:
                # Fail loudly instead of waiting forever on a dead producer.
                if not self.producer_thread.is_alive():
                    raise RuntimeError(
                        'OfflineMemoryDataFetcher producer thread is not running; no more batches will arrive.'
                    ) from None

    def __del__(self):
        # Ask the producer to stop; it checks the event between batches and while
        # waiting for queue space, so it exits shortly afterwards.
        # getattr guards against __init__ having failed before ``event`` was set.
        event = getattr(self, 'event', None)
        if event is not None:
            event.set()
|
|