|
|
from contextlib import contextmanager |
|
|
import os |
|
|
import re |
|
|
from time import time |
|
|
import torch |
|
|
from rb import ReplayBuffer, SocketManager |
|
|
|
|
|
|
|
|
@contextmanager
def timer(header):
    """Print ``header`` and the wall-clock seconds spent inside the ``with`` body.

    The elapsed time is measured with ``time.time()`` and is printed only when
    the body finishes without raising; there is no ``finally`` clause here.

    Args:
        header: Value printed in front of the elapsed seconds.
    """
    began = time()
    yield
    elapsed = time() - began
    print(header, elapsed)
|
|
|
|
|
|
|
|
class DataLoader:
    """Yield training batches sampled from a ``ReplayBuffer``, one iteration per pass.

    Every ``iter(loader)`` pass handles a single iteration: it waits for the
    iteration's data (self-play mode) or announces it over ZeroMQ (otherwise),
    registers the data with the replay buffer via ``add_iter`` and then yields
    ``self._total`` sampled batches.
    """

    def __init__(self, port, cpus, batch_size, sgf_prefix, is_selfplay):
        """Start the replay buffer and the socket manager.

        Args:
            port: Passed to ``SocketManager.run``; outside self-play mode it is
                also the local TCP port the ZeroMQ DEALER socket connects to.
            cpus: Forwarded to ``ReplayBuffer.run``.
            batch_size: Forwarded to ``ReplayBuffer.run`` and used to derive
                the number of batches yielded per iteration.
            sgf_prefix: Location handed to ``ReplayBuffer.add_iter`` by
                ``__iter__``.
            is_selfplay: When true, ``__iter__`` blocks on
                ``SocketManager.wait``; otherwise a ZeroMQ message is sent at
                the start of every iteration instead.
        """
        seed = torch.randint(0, 2**32, [1]).item()

        rb = ReplayBuffer()
        rb.max_iters = 20
        rb.moves_per_iter = 5000 * 60
        rb.run(seed, cpus, batch_size)
        # Stored as soon as it is running so __del__ can terminate it even if
        # a later step of the constructor fails.
        self._rb = rb

        sock = SocketManager()
        sock.run(port)
        self._sock = sock

        # Number of batches drawn per iteration.
        self._total = int(rb.moves_per_iter * 1 / batch_size)
        self._sgf_prefix = sgf_prefix
        self._iter = 0
        self._is_selfplay = is_selfplay

        if not is_selfplay:
            import zmq

            # NOTE(review): this DEALER socket talks to the port the
            # SocketManager was started on; presumably it stands in for a
            # self-play worker -- confirm against the SocketManager protocol.
            self._ctx = zmq.Context()
            zsock = self._ctx.socket(zmq.DEALER)
            self._zsock = zsock
            zsock.setsockopt(zmq.LINGER, 0)
            zsock.setsockopt(zmq.ROUTING_ID, b"0")
            zsock.connect(f"tcp://127.0.0.1:{port}")
            zsock.send_multipart([b"0", b"0"])

    def load(self, sgf_prefix, epoch_ckpt):
        """Step through ``epoch_ckpt`` iterations without yielding any batch.

        For every iteration the socket manager is notified (except before the
        very first iteration of this loader) and the iteration counter is
        advanced.  Data is only registered, and ``self._total`` samples drawn
        and freed, for the first iteration and for iterations within
        ``rb.max_iters`` of ``epoch_ckpt``.
        NOTE(review): presumably used to restore loader state when resuming
        from a checkpoint -- confirm with the caller.

        Args:
            sgf_prefix: Directory scanned for ``iter-<iteration>-<index>.sgf``
                files; also passed to ``ReplayBuffer.add_iter``.
            epoch_ckpt: Number of iterations to step through.

        Raises:
            ValueError: If an iteration that must be loaded has no matching
                SGF file in ``sgf_prefix``.
        """
        rb, sock = self._rb, self._sock
        for i in range(epoch_ckpt):
            if self._iter > 0:
                sock.notify()
            if i == 0 or i + rb.max_iters >= epoch_ckpt:
                print(f"[{i:3d}] Load selfplay")
                # The dot is escaped so that only a literal ".sgf" matches.
                pattern = re.compile(rf"iter-{self._iter}-(\d+)\.sgf")
                indices = [
                    int(m.group(1))
                    for name in os.listdir(sgf_prefix)
                    if (m := pattern.search(name))
                ]
                if not indices:
                    raise ValueError(
                        f"no SGF files for iteration {self._iter} "
                        f"found in {sgf_prefix!r}"
                    )
                # Indices start at 0, so the count is the highest index + 1.
                nodes = 1 + max(indices)
                rb.add_iter(sgf_prefix, self._iter, nodes)
                for _ in range(self._total):
                    rb.sample().free()
            self._iter += 1

    def __del__(self):
        """Terminate the replay buffer and socket manager and release ZeroMQ.

        Attributes are looked up defensively because ``__del__`` also runs on
        instances whose ``__init__`` did not complete.
        """
        rb = getattr(self, "_rb", None)
        if rb is not None:
            rb.terminate()
        sock = getattr(self, "_sock", None)
        if sock is not None:
            sock.terminate()
        ctx = getattr(self, "_ctx", None)
        if ctx is not None:
            # Closes the DEALER socket created from this context and then
            # terminates the context; linger=0 drops unsent messages.
            ctx.destroy(linger=0)

    def __iter__(self):
        """Yield the ``self._total`` batches of the current iteration.

        The generator notifies the socket manager (after the first iteration),
        obtains the iteration's data, registers it with the replay buffer and
        yields samples.  Each sample is freed as soon as the consumer asks for
        the next one or closes the generator, so a sample must not be used
        after the loop body that received it.

        Raises:
            SystemExit: In self-play mode, when ``SocketManager.wait`` returns
                a truthy value.
        """
        rb, sock = self._rb, self._sock
        if self._iter > 0:
            sock.notify()

        if self._is_selfplay:
            with timer("[{:3d}] Time for selfplay:".format(self._iter)):
                if sock.wait():
                    # Same effect as the site-provided exit(0), but does not
                    # depend on the ``site`` module being loaded.
                    raise SystemExit(0)
        else:
            finished = rb.moves_per_iter + 1
            self._zsock.send_multipart(
                [str(self._iter).encode("utf-8"), str(finished).encode("utf-8")]
            )

        with timer("[{:3d}] Time for training:".format(self._iter)):
            rb.add_iter(self._sgf_prefix, self._iter, sock.nodes)
            for _ in range(self._total):
                sample = rb.sample()
                try:
                    yield sample
                finally:
                    # Also runs when the consumer stops early, so the sample
                    # in flight is not leaked.
                    sample.free()

        self._iter += 1
|
|
|