|
from dataclasses import dataclass |
|
|
|
@dataclass(frozen=True)
class ScheduledNode:
    """Immutable record of one timed pass in a pipeline schedule.

    Instances compare field-by-field (``eq`` is the dataclass default) and,
    being frozen as well, are hashable, so they can be used in sets or as
    dictionary keys.
    """

    # Pass label; get_hand_schedule emits "F", "B" or "W".
    type: str
    # Pipeline rank (stage index) the pass runs on.
    stage: int
    # Index of the minibatch the pass processes.
    minibatch: int
    # Time at which the pass begins.
    start_time: int
    # Time at which the pass finishes.
    completion_time: int
    # NOTE(review): never set by the code visible here; meaning of a
    # "rollback" node should be confirmed against its users.
    rollback: bool = False
|
|
|
|
|
def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
    """Build a fixed F/B/W pass order per rank and compute its timings.

    Each rank runs ``_n`` passes of each of three kinds, labelled "F", "B"
    and "W" (presumably forward, backward and weight-gradient passes --
    confirm against callers).  The order is hand-written: rank ``r`` starts
    with ``(_p - r - 1) * warmup_c`` extra F passes, then interleaves
    F/B/W, and finishes with the W passes it postponed.

    Timing rules applied while walking that order:

    * an F pass on rank ``r > 0`` cannot finish earlier than
      ``_c + _f`` after the same-index F pass on rank ``r - 1`` finished;
    * a B pass on rank ``r < _p - 1`` cannot finish earlier than
      ``_c + _b`` after the same-index B pass on rank ``r + 1`` finished;
    * a W pass only depends on the previous pass of its own rank.

    Args:
        _p: Number of pipeline ranks.
        _n: Number of minibatches; must be at least ``2 * _p``.
        _f: Duration of one F pass.
        _b: Duration of one B pass.
        _w: Duration of one W pass.
        _c: Delay added between dependent passes on adjacent ranks
            (presumably the communication cost -- confirm).
        warmup_c: Multiplier for the number of warm-up F passes per rank.

    Returns:
        A list with one entry per rank; each entry is a list of ``3 * _n``
        ``ScheduledNode`` objects sorted by ``start_time``.  A node's
        ``start_time`` is its completion time minus the pass duration.

    Raises:
        AssertionError: If ``_n < 2 * _p``, or if a pass is reached before
            the pass it depends on has been given a positive time.
    """
    assert _n >= 2 * _p, "_n must be at least 2 * _p"

    # Pass order per rank, encoded as 0 = F, 1 = B, 2 = W.
    stage = [[] for _ in range(_p)]
    full_warmup = (_p - 1) * warmup_c  # warm-up length of rank 0
    for rank in range(_p):
        warmup = (_p - rank - 1) * warmup_c
        order = stage[rank]
        # Warm-up: extra F passes issued before the first B.
        order.extend([0] * warmup)
        for i in range(_n):
            # The warm-up already consumed some F passes, so stop early.
            if warmup + i < _n:
                order.append(0)
            order.append(1)
            # Later ranks delay their first W passes ...
            if warmup + i >= full_warmup:
                order.append(2)
        # ... and run the postponed W passes at the very end.
        order.extend([2] * (full_warmup - warmup))

    def get_id(_i, _j, _k):
        # Flat index of (pass kind, rank, minibatch) in the table below.
        return _i * _p * _n + _j * _n + _k

    t = [-1] * (_p * _n * 3)  # completion time of every pass; -1 = unset
    e = [0] * _p              # time at which each rank becomes free
    fc = [0] * _p             # F passes timed so far, per rank
    bc = [0] * _p             # B passes timed so far, per rank
    for i in range(3 * _n):
        for rank in range(_p):
            kind = stage[rank][i]
            if kind == 0:
                tmp = e[rank] + _f
                if rank > 0:
                    # Same minibatch must have left the previous rank.
                    upstream = t[get_id(0, rank - 1, fc[rank])]
                    assert upstream > 0, "F dependency not scheduled yet"
                    tmp = max(tmp, upstream + _c + _f)
                e[rank] = tmp
                t[get_id(0, rank, fc[rank])] = tmp
                fc[rank] += 1
            elif kind == 1:
                tmp = e[rank] + _b
                if rank < _p - 1:
                    # Same minibatch must have left the next rank.
                    downstream = t[get_id(1, rank + 1, bc[rank])]
                    assert downstream > 0, "B dependency not scheduled yet"
                    tmp = max(tmp, downstream + _c + _b)
                e[rank] = tmp
                t[get_id(1, rank, bc[rank])] = tmp
                bc[rank] += 1
            else:
                tmp = e[rank] + _w
                e[rank] = tmp
                # i passes precede this one, so i - fc - bc of them were W:
                # that count is the index of the current W pass.
                t[get_id(2, rank, i - fc[rank] - bc[rank])] = tmp

    res = [[] for _ in range(_p)]
    for rank in range(_p):
        for i in range(_n):
            f_end = t[get_id(0, rank, i)]
            b_end = t[get_id(1, rank, i)]
            w_end = t[get_id(2, rank, i)]
            res[rank].append(ScheduledNode("F", rank, i, f_end - _f, f_end))
            res[rank].append(ScheduledNode("B", rank, i, b_end - _b, b_end))
            res[rank].append(ScheduledNode("W", rank, i, w_end - _w, w_end))
        res[rank] = sorted(res[rank], key=lambda x: x.start_time)
    return res