from typing import TYPE_CHECKING
from threading import Thread, Event
from queue import Queue
import time
import numpy as np
import torch
from easydict import EasyDict
from ding.framework import task
from ding.data import Dataset, DataLoader
from ding.utils import get_rank, get_world_size

if TYPE_CHECKING:
    from ding.framework import OfflineRLContext


class OfflineMemoryDataFetcher:
    """Pre-fetch offline dataset batches in a background thread and expose them via ``ctx.train_data``."""

    def __new__(cls, *args, **kwargs):
        # Only processes holding the FETCHER role run this middleware; others become no-ops.
        if task.router.is_active and not task.has_role(task.role.FETCHER):
            return task.void()
        return super(OfflineMemoryDataFetcher, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, dataset: Dataset):
        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
        if device != 'cpu':
            # A dedicated CUDA stream lets host-to-device copies overlap with training kernels.
            stream = torch.cuda.Stream()

        def producer(queue, dataset, batch_size, device, event):
            torch.set_num_threads(4)
            if device != 'cpu':
                nonlocal stream
            # Shard the shuffled indices so that each rank draws a disjoint, batch-sized slice
            # of every world-size chunk.
            sbatch_size = batch_size * get_world_size()
            rank = get_rank()
            idx_list = np.random.permutation(len(dataset))
            temp_idx_list = []
            for i in range(len(dataset) // sbatch_size):
                start = i * sbatch_size + rank * batch_size
                temp_idx_list.extend(idx_list[start:start + batch_size])
            idx_iter = iter(temp_idx_list)
            if device != 'cpu':
                with torch.cuda.stream(stream):
                    while True:
                        if queue.full():
                            time.sleep(0.1)
                        else:
                            data = []
                            for _ in range(batch_size):
                                try:
                                    data.append(dataset.__getitem__(next(idx_iter)))
                                except StopIteration:
                                    # Epoch exhausted: reshuffle the full index list and keep going.
                                    del idx_iter
                                    idx_list = np.random.permutation(len(dataset))
                                    idx_iter = iter(idx_list)
                                    data.append(dataset.__getitem__(next(idx_iter)))
                            # Transpose list-of-samples into per-field lists, stack and move to GPU.
                            data = [[i[j] for i in data] for j in range(len(data[0]))]
                            data = [torch.stack(x).to(device) for x in data]
                            queue.put(data)
                        if event.is_set():
                            break
            else:
                while True:
                    if queue.full():
                        time.sleep(0.1)
                    else:
                        data = []
                        for _ in range(batch_size):
                            try:
                                data.append(dataset.__getitem__(next(idx_iter)))
                            except StopIteration:
                                del idx_iter
                                idx_list = np.random.permutation(len(dataset))
                                idx_iter = iter(idx_list)
                                data.append(dataset.__getitem__(next(idx_iter)))
                        data = [[i[j] for i in data] for j in range(len(data[0]))]
                        data = [torch.stack(x) for x in data]
                        queue.put(data)
                    if event.is_set():
                        break

        self.queue = Queue(maxsize=50)  # bounded buffer: at most 50 pre-fetched batches
        self.event = Event()
        self.producer_thread = Thread(
            target=producer,
            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
            name='cuda_fetcher_producer'
        )

    def __call__(self, ctx: "OfflineRLContext"):
        # Lazily start the producer thread on the first call, then block until a batch is ready.
        if not self.producer_thread.is_alive():
            time.sleep(5)
            self.producer_thread.start()
        while self.queue.empty():
            time.sleep(0.001)
        ctx.train_data = self.queue.get()

    def __del__(self):
        # Signal the producer thread to exit before dropping the queue.
        if self.producer_thread.is_alive():
            self.event.set()
        del self.queue
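

# Usage sketch (illustrative only): how this fetcher would typically be wired into a
# DI-engine middleware pipeline. ``cfg``, ``dataset``, ``policy`` and the ``trainer``
# middleware below are assumed to come from the surrounding offline-RL entry script,
# not from this file.
#
#     with task.start(async_mode=False, ctx=OfflineRLContext()):
#         task.use(OfflineMemoryDataFetcher(cfg, dataset))
#         task.use(trainer(cfg, policy.learn_mode))
#         task.run(max_step=100000)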