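# NOTE: the header below appears to be metadata copied from the hosting page's
# file viewer; it is not part of the program.
#   Last commit: "Add some presets, support 1f1b with fewer microbatches"
#   Committer: Wan Xinyi, revision ac0b05c, file size 3.25 kB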
from dataclasses import dataclass
@dataclass(frozen=True)
class ScheduledNode:
    """Immutable, hashable record of one operation placed on a pipeline stage.

    Equality and hashing are field-based (dataclass defaults combined with
    ``frozen=True``), so two nodes with identical fields compare equal.
    """

    # Operation label; get_hand_schedule in this file emits "F", "B" and "W".
    type: str
    # Index of the pipeline stage (rank) the operation runs on.
    stage: int
    # Index of the minibatch the operation belongs to.
    minibatch: int
    # Time at which the operation begins.
    start_time: int
    # Time at which the operation finishes.
    completion_time: int
    # Not set by any code visible in this file; defaults to False.
    rollback: bool = False
def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
    """Build a fixed per-rank F/B/W operation order and simulate its timing.

    For every rank the function lays out a sequence of forward ("F"),
    backward ("B") and "W" operations, prints that layout to stdout (one
    line per rank, indented by the rank index), then walks the layout slot
    by slot to compute completion times for the F and B operations.

    Args:
        _p: Number of pipeline ranks (stages).
        _n: Number of minibatches to schedule. The layout is generated for
            ``max(_n, _p)`` minibatches; operations beyond ``_n`` are
            skipped during the timing simulation.
        _f: Duration added for each F operation.
        _b: Duration added for each B operation.
        _w: Duration used only to derive the ``start_time`` of W nodes.
        _c: Extra delay added on top of a neighbouring rank's completion
            time before the dependent operation may finish (presumably the
            communication cost between adjacent ranks).
        warmup_c: Multiplier for the number of leading F operations on each
            rank, ``(_p - rank - 1) * warmup_c``.

    Returns:
        A list with one entry per rank; each entry is the list of
        ``ScheduledNode`` objects ("F", "B" and "W" for every minibatch in
        ``range(_n)``) sorted by ``start_time``.

        NOTE: W operations are placed in the layout but are not timed by
        the simulation, so every W node is returned with
        ``completion_time == -1`` and ``start_time == -1 - _w``.

    Raises:
        AssertionError: If an operation is reached before the operation it
            depends on (same minibatch on the neighbouring rank) has been
            scheduled.
    """
    op_f, op_b, op_w = 0, 1, 2

    # assert _n >= 2 * _p
    real_n = _n
    # Lay the schedule out for at least one minibatch per rank; the surplus
    # F/B slots are skipped in the simulation below.
    _n = max(_n, _p)
    full_warmup = (_p - 1) * warmup_c

    # --- 1. Per-rank operation order ------------------------------------
    stage = [[] for _ in range(_p)]
    for rank in range(_p):
        ops = stage[rank]
        warmup = (_p - rank - 1) * warmup_c
        # Leading forwards before the first backward on this rank.
        ops.extend([op_f] * warmup)
        for i in range(_n):
            if warmup + i < _n:
                ops.append(op_f)
            ops.append(op_b)
            if warmup + i >= full_warmup:
                ops.append(op_w)
        # Trailing W operations that were deferred at the start.
        ops.extend([op_w] * (full_warmup - warmup))

    # --- 2. Print the layout ---------------------------------------------
    labels = ["F", "B", "W", '.']
    for rank in range(_p):
        print(" " * rank + "".join(labels[op] for op in stage[rank]))

    # --- 3. Timing simulation --------------------------------------------
    def get_id(_i, _j, _k):
        # Flat index into ``t`` for (op type, rank, minibatch).
        return _i * _p * _n + _j * _n + _k

    t = [-1] * (_p * _n * 3)  # completion times; -1 means "not scheduled"
    e = [0] * _p              # time at which each rank becomes free
    fc = [0] * _p             # forwards completed per rank
    bc = [0] * _p             # backwards completed per rank
    for i in range(3 * _n):
        for rank in range(_p):
            op = stage[rank][i]
            if op == op_f:
                if fc[rank] >= real_n:
                    continue
                tmp = e[rank] + _f
                if rank > 0:
                    # F depends on the same minibatch's F on the previous
                    # rank. Compare against the -1 sentinel so that a valid
                    # completion time of 0 (zero-cost forward) is accepted.
                    upstream = t[get_id(0, rank - 1, fc[rank])]
                    assert upstream >= 0, f"{rank} {i} {fc[rank]}"
                    tmp = max(tmp, upstream + _c + _f)
                e[rank] = tmp
                t[get_id(0, rank, fc[rank])] = tmp
                fc[rank] += 1
            elif op == op_b:
                if bc[rank] >= real_n:
                    continue
                tmp = e[rank] + _b
                if rank < _p - 1:
                    # B depends on the same minibatch's B on the next rank.
                    downstream = t[get_id(1, rank + 1, bc[rank])]
                    assert downstream >= 0, f"{rank} {i} {bc[rank]}"
                    tmp = max(tmp, downstream + _c + _b)
                e[rank] = tmp
                t[get_id(1, rank, bc[rank])] = tmp
                bc[rank] += 1
            # W slots are intentionally not timed here.

    # --- 4. Collect nodes -------------------------------------------------
    res = [[] for _ in range(_p)]
    for rank in range(_p):
        for i in range(real_n):
            for kind, op, cost in (("F", op_f, _f), ("B", op_b, _b), ("W", op_w, _w)):
                done = t[get_id(op, rank, i)]
                res[rank].append(ScheduledNode(kind, rank, i, done - cost, done))
        res[rank] = sorted(res[rank], key=lambda x: x.start_time)
    return res
if __name__ == "__main__":
    # Demo run: 16 ranks, 16 minibatches, unit F/B/W durations and no extra
    # delay between neighbouring ranks. Prints the layout, then the nodes.
    demo_schedule = get_hand_schedule(16, 16, 1, 1, 1, 0)
    print(demo_schedule)