File size: 3,327 Bytes
3d4d40d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf49f13
3d4d40d
cf49f13
 
 
3d4d40d
cf49f13
 
 
 
3d4d40d
 
 
 
 
 
 
 
 
cf49f13
3d4d40d
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from dataclasses import dataclass

@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One timed operation in a single stage's schedule (immutable, hashable)."""

    # Operation kind; get_interleaved_variation emits 'F' or 'B'.
    type: str
    # Model chunk the operation belongs to (0 or 1 in the interleaved schedule).
    chunk: int
    # Index of the pipeline stage that runs the operation.
    stage: int
    # Index of the minibatch the operation processes.
    minibatch: int
    # Time at which the operation starts.
    start_time: int
    # Time at which the operation finishes (start_time + its duration).
    completion_time: int


def get_interleaved_variation(_p, _n, cost):
    """Build a timed, two-chunk interleaved schedule for every stage.

    Internally each operation is tagged with a letter: 'F'/'f' are the
    forward passes of chunk 0/chunk 1 and 'B'/'b' are the backward passes
    of chunk 1/chunk 0.  The letter case is folded away in the result and
    replaced by the ``chunk`` field of ``ScheduledNode``.

    Args:
        _p: Number of pipeline stages.
        _n: Number of minibatches.
        cost: Sequence ``(f, b, w, c)``.  ``f`` is the duration of a forward
            operation, ``b + w`` the duration of a backward operation, and
            ``c`` the delay added to every cross-stage dependency.

    Returns:
        A list with one entry per stage; each entry is the list of
        ``ScheduledNode`` objects for that stage in execution order.
    """
    _f, _b, _w, _c = cost

    # Minibatches issued before/after the interleaving part of the order.
    head = min(_n, _p)
    # Minibatches (per chunk) handled by the alternating loop below.
    remaining = max(0, _n - _p)

    f_order = [('F', mb) for mb in range(head)]
    f_order += [('f', mb) for mb in range(head)]
    b_order = [('B', mb) for mb in range(head)]

    left = [remaining, remaining]
    i = 0
    cur = 0
    while left[0] > 0 or left[1] > 0:
        # Once i has reached _p, alternate between the two chunks for as
        # long as the other chunk still has minibatches left.
        if i >= _p and left[1 - cur] > 0:
            cur = 1 - cur
        if left[cur] > 0:
            mb = _n - left[cur]
            if cur == 0:
                f_order.append(('F', mb))
                # Chunk-0 backward trails its forward by _p minibatches.
                b_order.append(('b', mb - _p))
            else:
                f_order.append(('f', mb))
                b_order.append(('B', mb))
            left[cur] -= 1
        i += 3

    # Flush the chunk-0 backwards that the loop has not emitted yet.  The
    # offset is clamped so that _n < _p yields minibatches 0.._n-1 instead
    # of negative indices.
    b_order += [('b', remaining + mb) for mb in range(head)]

    schedule = []
    # Maps (stage, kind, minibatch) -> the operation run just before it on
    # the same stage.
    local_prev = {}
    for stage in range(_p):
        # Number of forwards issued up front (and backwards drained at the
        # end) on this stage; earlier stages get a longer warm-up.
        diff = min(_p + _p - stage, len(f_order))
        stage_schedule = list(f_order[:diff])
        for b_op, f_op in zip(b_order, f_order[diff:]):
            stage_schedule.append(b_op)
            stage_schedule.append(f_op)
        stage_schedule.extend(b_order[len(b_order) - diff:])
        for prev_op, next_op in zip(stage_schedule, stage_schedule[1:]):
            local_prev[(stage, *next_op)] = (stage, *prev_op)
        schedule.append(stage_schedule)

    duration = {
        'F': _f,
        'f': _f,
        'B': _b + _w,
        'b': _b + _w,
    }

    # Memoised completion times keyed by (stage, kind, minibatch).
    time_map = {}

    def get_time(stage, kind, minibatch):
        """Return the completion time of an operation, resolving its dependencies."""
        key = (stage, kind, minibatch)
        if key in time_map:
            return time_map[key]
        start = 0
        # Same-stage ordering: wait for the previous operation on this stage.
        if key in local_prev:
            start = get_time(*local_prev[key])
        # Forward passes flow from stage s-1 to stage s ...
        if stage > 0 and kind in "Ff":
            start = max(start, get_time(stage - 1, kind, minibatch) + _c)
        # ... and chunk 1 on stage 0 continues from chunk 0 on the last stage.
        if stage == 0 and kind == 'f':
            start = max(start, get_time(_p - 1, 'F', minibatch) + _c)
        # Backward passes flow from stage s+1 to stage s ...
        if stage != _p - 1 and kind in "Bb":
            start = max(start, get_time(stage + 1, kind, minibatch) + _c)
        # ... and chunk 0 on the last stage continues from chunk 1 on stage 0.
        if stage == _p - 1 and kind == 'b':
            start = max(start, get_time(0, 'B', minibatch) + _c)
        # The first backward waits for the last forward on the same stage
        # (no communication delay).
        if stage == _p - 1 and kind == 'B':
            start = max(start, get_time(stage, 'f', minibatch))

        time_map[key] = start + duration[kind]
        return time_map[key]

    result = []
    for sid, stage_schedule in enumerate(schedule):
        result_stage = []
        for kind, minibatch in stage_schedule:
            end = get_time(sid, kind, minibatch)
            result_stage.append(ScheduledNode(
                kind.upper(),
                # 'f' and 'B' are the chunk-1 operations; stored as int to
                # match the declared field type.
                int(kind in "fBW"),
                sid,
                minibatch,
                end - duration[kind],
                end,
            ))
        result.append(result_stage)
    return result