File size: 3,253 Bytes
be3048f
 
 
 
 
 
 
 
 
 
 
 
 
ac0b05c
be3048f
ac0b05c
 
be3048f
 
 
 
 
 
 
 
 
 
 
 
ac0b05c
be3048f
 
ac0b05c
 
be3048f
ac0b05c
be3048f
 
 
 
 
 
 
 
 
 
 
ac0b05c
 
be3048f
 
 
 
 
 
 
 
ac0b05c
 
be3048f
 
ac0b05c
be3048f
 
 
 
ac0b05c
 
be3048f
 
 
 
 
 
 
 
 
 
 
 
ac0b05c
be3048f
 
 
 
 
 
 
ac0b05c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from dataclasses import dataclass

@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One pass of one microbatch on one pipeline rank, with its simulated timing.

    Instances are immutable and compare/hash by value (``eq=True, frozen=True``).
    """

    type: str  # pass kind; get_hand_schedule emits "F", "B" or "W"
    stage: int  # pipeline rank that executes the pass
    minibatch: int  # microbatch index
    start_time: int  # completion_time minus the pass duration
    completion_time: int  # simulated finish time of the pass
    rollback: bool = False  # NOTE(review): never set by get_hand_schedule; meaning defined by callers


def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1, *, verbose=True):
    """Build a hand-written 1F1B-style pipeline schedule and simulate its timing.

    Pass order per rank ``r``:
      * ``(_p - r - 1) * warmup_c`` warm-up forward (F) passes;
      * then, per step, one F (while microbatches remain) followed by one B;
      * the weight-gradient pass W of microbatch ``j`` is placed right after
        the B of microbatch ``j + r * warmup_c``; the remaining
        ``r * warmup_c`` W passes are flushed at the end.

    Timing: a pass finishes ``duration`` after its rank becomes free. An F
    additionally waits for the same microbatch's F on rank ``r - 1`` plus
    ``_c``. A B waits for the same microbatch's B on rank ``r + 1`` plus
    ``_c``. A W only needs its own rank's earlier B, which always precedes it
    in the pass order.

    Args:
        _p: number of pipeline ranks.
        _n: number of microbatches.
        _f: duration of an F pass.
        _b: duration of a B pass.
        _w: duration of a W pass.
        _c: communication latency between adjacent ranks.
        warmup_c: multiplier on the number of warm-up forwards per rank.
        verbose: print the per-rank pass-order diagram (``True`` keeps the
            historical behaviour).

    Returns:
        A list with one entry per rank. Each entry is that rank's list of
        ``ScheduledNode`` (F, B and W for every microbatch), sorted by
        ``start_time``.

    NOTE(review): with ``warmup_c > 1`` and ``_n < (_p - 1) * warmup_c`` some
    ranks' pass lists exceed ``3 * _n`` entries, so trailing passes are never
    simulated and keep time -1 (e.g. ``_p=16, _n=16, warmup_c=2``).
    """
    FWD, BWD, WGT = 0, 1, 2
    real_n = _n
    # Pad the pattern to at least _p microbatches so every warm-up fits when
    # warmup_c == 1; the padded (phantom) passes are skipped via ``real_n``.
    _n = max(_n, _p)

    stage = [[] for _ in range(_p)]
    for rank in range(_p):
        warmup = (_p - rank - 1) * warmup_c
        order = stage[rank]
        order.extend([FWD] * warmup)
        for i in range(_n):
            if warmup + i < _n:
                order.append(FWD)
            order.append(BWD)
            # Once past this rank's deferral window, each B is followed by a W.
            if warmup + i >= (_p - 1) * warmup_c:
                order.append(WGT)
        # Flush the W passes that were deferred during the first steps.
        order.extend([WGT] * ((_p - 1) * warmup_c - warmup))

    if verbose:
        labels = "FBW."
        for rank, order in enumerate(stage):
            print(" " * rank + "".join(labels[kind] for kind in order))

    def get_id(kind, rank, mb):
        return kind * _p * _n + rank * _n + mb

    durations = (_f, _b, _w)
    t = [-1] * (3 * _p * _n)  # completion time per (kind, rank, mb); -1 = unscheduled
    e = [0] * _p  # time at which each rank becomes free
    done = [[0] * _p for _ in range(3)]  # passes of each kind scheduled so far, per rank
    for i in range(3 * _n):
        for rank in range(_p):
            kind = stage[rank][i]
            mb = done[kind][rank]
            if mb >= real_n:
                continue  # phantom pass introduced by the padding above
            end = e[rank] + durations[kind]
            if kind == FWD and rank > 0:
                dep = t[get_id(FWD, rank - 1, mb)]
                # Use >= 0 (not > 0): a zero-duration pass legitimately ends at 0.
                assert dep >= 0, f"{rank} {i} {mb}"
                end = max(end, dep + _c + _f)
            elif kind == BWD and rank < _p - 1:
                dep = t[get_id(BWD, rank + 1, mb)]
                assert dep >= 0, f"{rank} {i} {mb}"
                end = max(end, dep + _c + _b)
            # W passes are simulated too, so they occupy the rank and get real times.
            e[rank] = end
            t[get_id(kind, rank, mb)] = end
            done[kind][rank] = mb + 1

    res = []
    for rank in range(_p):
        nodes = []
        for mb in range(real_n):
            for kind, label in enumerate("FBW"):
                end = t[get_id(kind, rank, mb)]
                nodes.append(ScheduledNode(label, rank, mb, end - durations[kind], end))
        # Stable sort: ties keep the F, B, W per-microbatch insertion order.
        nodes.sort(key=lambda node: node.start_time)
        res.append(nodes)
    return res

if __name__ == "__main__":
    # Demo run: 16 ranks, 16 microbatches, unit F/B/W cost, no communication latency.
    demo_schedule = get_hand_schedule(16, 16, 1, 1, 1, 0)
    print(demo_schedule)