|
from dataclasses import dataclass |
|
|
|
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One timed operation in a pipeline schedule.

    Frozen (immutable) and compared by value, so instances are hashable.
    """

    # Operation kind; get_interleaved_variation emits 'F' or 'B'.
    type: str
    # Model-chunk index; get_interleaved_variation emits 0 or 1.
    chunk: int
    # Pipeline stage index the operation is scheduled on.
    stage: int
    # Micro-batch index.
    minibatch: int
    # Start time, in the units of the cost tuple.
    start_time: int
    # Completion time (start_time plus the operation's cost).
    completion_time: int


def get_interleaved_variation(_p, _n, cost):
    """Build a two-chunk interleaved pipeline schedule and time every op.

    Four op kinds are scheduled for every micro-batch:

    * ``'F'`` (chunk 0): depends on the same op on the previous stage,
      i.e. it flows stage 0 -> ``_p - 1``;
    * ``'f'`` (chunk 1): flows the same way; on stage 0 it waits for ``'F'``
      to finish on the last stage (plus communication);
    * ``'B'`` (chunk 1): flows stage ``_p - 1`` -> 0; on the last stage it
      waits for that stage's ``'f'``;
    * ``'b'`` (chunk 0): flows the same way; on the last stage it waits for
      ``'B'`` to finish on stage 0 (plus communication).

    Every op starts as soon as both the previous op in its stage's order and
    its cross-stage dependency (plus communication cost) have completed.

    Args:
        _p: Number of pipeline stages.
        _n: Number of micro-batches.
        cost: ``(f, b, w, c)`` tuple. ``'F'``/``'f'`` ops cost ``f``,
            ``'B'``/``'b'`` ops cost ``b + w``, and ``c`` is added to every
            cross-stage dependency.

    Returns:
        One list per stage of ``ScheduledNode`` objects in execution order.
        ``type`` is upper-cased (``'F'`` or ``'B'``) and ``chunk`` is 1 for
        ``'f'``/``'B'`` ops and 0 otherwise.

    Note:
        Completion times are computed by a memoised recursive walk of the
        dependency graph, so extremely deep dependency chains could in
        principle run into Python's recursion limit.
    """
    f_cost, b_cost, w_cost, comm = cost
    warmup = min(_n, _p)      # micro-batches issued before the steady phase
    steady = max(0, _n - _p)  # micro-batches issued during the steady phase

    # Forward order: F then f for the warm-up micro-batches.
    f_order = [('F', mb) for mb in range(warmup)]
    f_order += [('f', mb) for mb in range(warmup)]
    # Backward order starts with the warm-up B's.
    b_order = [('B', mb) for mb in range(warmup)]

    # Steady phase: each issued F (micro-batch mb) is paired with the b of
    # micro-batch mb - _p, each issued f with its own B.  No switching
    # between the two kinds happens until `step` reaches _p; afterwards
    # the kind alternates while both kinds still have work left.
    left = [steady, steady]
    cur = 0
    step = 0
    while left[0] > 0 or left[1] > 0:
        if step >= _p and left[1 - cur] > 0:
            cur = 1 - cur
        if left[cur] > 0:
            mb = _n - left[cur]
            if cur == 0:
                f_order.append(('F', mb))
                b_order.append(('b', mb - _p))
            else:
                f_order.append(('f', mb))
                b_order.append(('B', mb))
            left[cur] -= 1
        step += 3

    # Cool-down: the b's not emitted above belong to micro-batches
    # steady .. _n - 1.  Offsetting by `_n - _p` instead would yield negative
    # ids (and drop the real ones) whenever _n < _p.
    b_order += [('b', steady + mb) for mb in range(warmup)]

    schedule = []
    local_prev = {}  # (stage, kind, mb) -> op run just before it on that stage
    for stage in range(_p):
        # Number of forward ops this stage runs before its first backward;
        # it shrinks by one per stage.
        diff = min(2 * _p - stage, len(f_order))
        order = f_order[:diff]
        for b_op, f_op in zip(b_order, f_order[diff:]):
            order.append(b_op)
            order.append(f_op)
        # Explicit start index so diff == 0 yields an empty tail.
        order += b_order[len(b_order) - diff:]
        for prev, nxt in zip(order, order[1:]):
            local_prev[(stage, *nxt)] = (stage, *prev)
        schedule.append(order)

    op_cost = {
        'F': f_cost,
        'f': f_cost,
        'B': b_cost + w_cost,
        'b': b_cost + w_cost,
    }
    time_map = {}

    def get_time(stage, kind, mb):
        """Return the (memoised) completion time of op ``(stage, kind, mb)``."""
        key = (stage, kind, mb)
        if key in time_map:
            return time_map[key]
        time = 0
        if key in local_prev:
            time = get_time(*local_prev[key])
        if stage > 0 and kind in "Ff":
            time = max(time, get_time(stage - 1, kind, mb) + comm)
        if stage == 0 and kind == 'f':
            time = max(time, get_time(_p - 1, 'F', mb) + comm)
        if stage != _p - 1 and kind in "Bb":
            time = max(time, get_time(stage + 1, kind, mb) + comm)
        if stage == _p - 1 and kind == 'b':
            time = max(time, get_time(0, 'B', mb) + comm)
        if stage == _p - 1 and kind == 'B':
            # Same stage, so no communication cost.
            time = max(time, get_time(stage, 'f', mb))
        time_map[key] = time + op_cost[kind]
        return time_map[key]

    result = []
    for sid, order in enumerate(schedule):
        nodes = []
        for kind, mb in order:
            end = get_time(sid, kind, mb)
            nodes.append(ScheduledNode(
                kind.upper(),
                int(kind in ('f', 'B')),
                sid,
                mb,
                end - op_cost[kind],
                end,
            ))
        result.append(nodes)
    return result