from dataclasses import dataclass


@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """A single timed operation in a per-rank schedule.

    Instances are immutable and compare/hash by value.
    """

    type: str  # operation label: "F", "B" or "W" as produced by get_hand_schedule
    stage: int  # rank (row of the schedule) the operation runs on
    minibatch: int  # minibatch index the operation belongs to
    start_time: int
    completion_time: int
    rollback: bool = False  # never set by get_hand_schedule in this file


def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1, *, verbose=True):
    """Build and time a hand-written F/B/W op order for ``_p`` ranks.

    Each rank first runs ``(_p - rank - 1) * warmup_c`` F ops, then
    interleaves F and B ops; W ops start once ``rank * warmup_c`` B ops have
    been issued, and the remaining W ops are drained at the end. The fixed
    op order is then simulated column by column across ranks to compute
    completion times.

    Args:
        _p: Number of ranks; the result has one list per rank.
        _n: Number of minibatches. Internally padded to at least ``_p`` to
            build the op order; ops of padded minibatches take no time and
            are not returned.
        _f, _b, _w: Durations of one F, B and W op respectively.
        _c: Extra delay added on each cross-rank dependency (an F waits for
            the F of the same minibatch on ``rank - 1``, a B for the B on
            ``rank + 1``); presumably a communication cost.
        warmup_c: Multiplier on the number of leading F ops per rank.
        verbose: If true (the default), print each rank's op order as a
            string of F/B/W letters, indented by the rank index.

    Returns:
        A list of ``_p`` lists of ScheduledNode. Each inner list holds one F,
        one B and one W node per real minibatch, sorted by ``start_time``
        (stable, so ties keep F, B, W / minibatch order).

    Raises:
        AssertionError: if the op order makes a rank wait on a cross-rank
            dependency that has not been simulated yet.
    """
    labels = "FBW"  # op codes 0/1/2 index into this and into node types
    durations = (_f, _b, _w)

    real_n = _n
    _n = max(_n, _p)

    # Per-rank op order as codes: 0 = F, 1 = B, 2 = W.
    stage = [[] for _ in range(_p)]
    for rank in range(_p):
        order = stage[rank]
        warmup = (_p - rank - 1) * warmup_c
        order.extend([0] * warmup)
        for i in range(_n):
            if warmup + i < _n:
                order.append(0)
            order.append(1)
            if warmup + i >= (_p - 1) * warmup_c:
                order.append(2)
        order.extend([2] * ((_p - 1) * warmup_c - warmup))

    if verbose:
        for rank in range(_p):
            print(" " * rank + "".join(labels[op] for op in stage[rank]))

    def get_id(kind, rank, mb):
        # Flat index of (op kind, rank, minibatch) into ``t``.
        return kind * _p * _n + rank * _n + mb

    t = [-1] * (_p * _n * 3)  # completion time per op; -1 = not simulated yet
    e = [0] * _p  # completion time of each rank's last simulated op
    done = [[0, 0, 0] for _ in range(_p)]  # ops of each kind simulated per rank

    # Walk the orders column by column with ranks in increasing order, so an
    # F on ``rank`` can use the F of ``rank - 1`` from the same column.
    # NOTE(review): orders can be longer than 3 * _n when
    # warmup_c * (_p - 1) > _n; such tails are not simulated here.
    for i in range(3 * _n):
        for rank in range(_p):
            kind = stage[rank][i]
            mb = done[rank][kind]
            if mb >= real_n:
                continue  # op of a padded minibatch: skipped, takes no time
            end = e[rank] + durations[kind]
            if kind == 0 and rank > 0:
                dep = t[get_id(0, rank - 1, mb)]
                # Test against the -1 sentinel, not 0: with zero durations a
                # simulated op can legitimately complete at time 0.
                assert dep >= 0, f"{rank} {i} {mb}"
                end = max(end, dep + _c + _f)
            elif kind == 1 and rank < _p - 1:
                dep = t[get_id(1, rank + 1, mb)]
                assert dep >= 0, f"{rank} {i} {mb}"
                end = max(end, dep + _c + _b)
            # W ops have no cross-rank dependency but still occupy the rank
            # for _w, so they advance e[rank] and get a real completion time.
            e[rank] = end
            t[get_id(kind, rank, mb)] = end
            done[rank][kind] += 1

    res = []
    for rank in range(_p):
        nodes = []
        for mb in range(real_n):
            for kind, label in enumerate(labels):
                end = t[get_id(kind, rank, mb)]
                nodes.append(
                    ScheduledNode(label, rank, mb, end - durations[kind], end))
        nodes.sort(key=lambda node: node.start_time)
        res.append(nodes)
    return res


if __name__ == "__main__":
    print(get_hand_schedule(16, 16, 1, 1, 1, 0))