# Nyamdavaa Amar
# Edit presets
# cf49f13
# raw
# history blame
# 3.33 kB
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One operation placed on a pipeline stage's simulated timeline.

    Frozen with ``eq=True``, so instances are immutable, comparable and
    hashable.
    """

    # Operation kind; get_interleaved_variation emits 'F' or 'B'.
    type: str
    # Model chunk the op belongs to (0 or 1 in get_interleaved_variation).
    chunk: int
    # Pipeline stage index the op runs on.
    stage: int
    # Micro-batch index.
    minibatch: int
    # Simulated start time (completion_time minus the op's duration).
    start_time: int
    # Simulated finish time.
    completion_time: int


def _build_orders(_p, _n):
    """Return the ``(f_order, b_order)`` op sequences shared by every stage.

    Each entry is a ``(kind, minibatch)`` pair: 'F'/'f' are the forward ops
    of chunk 0/1 and 'B'/'b' the backward ops of chunk 1/0. Both lists hold
    ``2 * _n`` entries.
    """
    warmup = min(_n, _p)
    steady = max(0, _n - _p)

    # Warm-up: chunk-0 forwards, then chunk-1 forwards, for the first
    # `warmup` micro-batches; chunk-1 backwards start in the same order.
    f_order = [('F', mb) for mb in range(warmup)]
    f_order += [('f', mb) for mb in range(warmup)]
    b_order = [('B', mb) for mb in range(warmup)]

    # Steady state: emit the remaining micro-batches of each chunk.
    # `tick` advances by 3 per iteration, so the first ceil(_p / 3)
    # iterations never switch away from chunk 0; afterwards the chunks
    # alternate while both still have work.
    left = [steady, steady]
    tick = 0
    cur = 0
    while left[0] > 0 or left[1] > 0:
        if tick >= _p and left[1 - cur] > 0:
            cur = 1 - cur
        if left[cur] > 0:
            mb = _n - left[cur]
            if cur == 0:
                f_order.append(('F', mb))
                # Chunk-0 backward trails its forward by _p micro-batches.
                b_order.append(('b', mb - _p))
            else:
                f_order.append(('f', mb))
                b_order.append(('B', mb))
            left[cur] -= 1
        tick += 3

    # Cool-down: the chunk-0 backwards not emitted above. Starting at
    # `steady` (= max(0, _n - _p)) rather than `_n - _p` keeps micro-batch
    # ids non-negative when _n < _p.
    b_order += [('b', steady + k) for k in range(warmup)]
    return f_order, b_order


def _stage_order(stage, _p, f_order, b_order):
    """Return ``stage``'s execution order as a list of ``(kind, minibatch)``.

    The stage runs ``warm = min(2 * _p - stage, len(f_order))`` forwards up
    front, then alternates one backward with one forward, and finally
    drains the last ``warm`` backwards.
    """
    warm = min(2 * _p - stage, len(f_order))
    order = f_order[:warm]
    for k in range(len(f_order) - warm):
        order.append(b_order[k])
        order.append(f_order[warm + k])
    order.extend(b_order[len(b_order) - warm + k] for k in range(warm))
    return order


def _cross_stage_deps(key, _p, _c):
    """Return ``[(dep_key, delay), ...]`` that must finish before ``key`` starts.

    The same-stage predecessor is not included here. ``key`` and each
    ``dep_key`` are ``(stage, kind, minibatch)`` tuples.
    """
    stage, kind, mb = key
    last = _p - 1
    deps = []
    if stage > 0 and kind in "Ff":
        # Forward ops flow from the previous stage.
        deps.append(((stage - 1, kind, mb), _c))
    if stage == 0 and kind == 'f':
        # Chunk-1 forward on stage 0 waits for chunk-0 forward on the last stage.
        deps.append(((last, 'F', mb), _c))
    if stage != last and kind in "Bb":
        # Backward ops flow from the next stage.
        deps.append(((stage + 1, kind, mb), _c))
    if stage == last and kind == 'b':
        # Chunk-0 backward on the last stage waits for chunk-1 backward on stage 0.
        deps.append(((0, 'B', mb), _c))
    if stage == last and kind == 'B':
        # Same stage, so no communication delay is added.
        deps.append(((stage, 'f', mb), 0))
    return deps


def _completion_times(schedule, local_prev, durations, _p, _c):
    """Return ``{(stage, kind, minibatch): completion_time}`` for every op.

    An op starts at the completion time of its same-stage predecessor (or 0),
    raised to the latest ``dep_completion + delay`` of its cross-stage deps,
    and finishes ``durations[kind]`` later. The evaluation uses an explicit
    stack rather than recursion, so long dependency chains cannot exceed
    Python's recursion limit.

    Raises:
        ValueError: if the dependencies form a cycle.
    """
    finish = {}
    for sid, order in enumerate(schedule):
        for kind, mb in order:
            stack = [(sid, kind, mb)]
            in_progress = set()  # expanded but not yet finished (current DFS path)
            while stack:
                key = stack[-1]
                if key in finish:
                    stack.pop()
                    continue
                prev = local_prev.get(key)
                cross = _cross_stage_deps(key, _p, _c)
                needed = [dep for dep, _ in cross]
                if prev is not None:
                    needed.insert(0, prev)
                pending = [dep for dep in needed if dep not in finish]
                if pending:
                    if any(dep in in_progress for dep in pending):
                        raise ValueError(f"dependency cycle involving {key}")
                    in_progress.add(key)
                    # Reversed so the same-stage predecessor is evaluated first.
                    stack.extend(reversed(pending))
                    continue
                start = finish[prev] if prev is not None else 0
                for dep, delay in cross:
                    start = max(start, finish[dep] + delay)
                finish[key] = start + durations[key[1]]
                in_progress.discard(key)
                stack.pop()
    return finish


def get_interleaved_variation(_p, _n, cost):
    """Build a two-chunk interleaved pipeline schedule and simulate its timing.

    Args:
        _p: number of pipeline stages.
        _n: number of micro-batches.
        cost: 4-tuple ``(f, b, w, c)``. Each 'F'/'f' op takes ``f`` and each
            'B'/'b' op takes ``b + w``. ``c`` is added on every cross-stage
            dependency. (Presumably these are the forward, backward,
            weight-gradient and communication costs; confirm with callers.)

    Returns:
        One list per stage holding that stage's ops in execution order as
        ``ScheduledNode`` objects. ``type`` is 'F' or 'B', ``chunk`` is 0 or
        1, and the node carries its simulated start and completion times.
    """
    _f, _b, _w, _c = cost
    f_order, b_order = _build_orders(_p, _n)
    schedule = [_stage_order(stage, _p, f_order, b_order) for stage in range(_p)]

    # Each op on a stage may start only after the op before it on that stage.
    local_prev = {}
    for stage, order in enumerate(schedule):
        for before, after in zip(order, order[1:]):
            local_prev[(stage, *after)] = (stage, *before)

    durations = {'F': _f, 'f': _f, 'B': _b + _w, 'b': _b + _w}
    finish = _completion_times(schedule, local_prev, durations, _p, _c)

    result = []
    for sid, order in enumerate(schedule):
        row = []
        for kind, mb in order:
            done = finish[(sid, kind, mb)]
            row.append(ScheduledNode(
                kind.upper(),
                # 'f' (forward) and 'B' (backward) belong to chunk 1; int()
                # matches the field's annotated type instead of passing a bool.
                int(kind in ('f', 'B')),
                sid,
                mb,
                done - durations[kind],
                done,
            ))
        result.append(row)
    return result