|
# NOTE(review): not referenced anywhere in the visible part of this file.
# Presumably the width, in time slots, of one repeating F/B/W pattern;
# schedule() hard-codes the same value 6 as its slot stride - TODO confirm.
pattern_size = 6
|
from collections import Counter |
|
from dataclasses import dataclass |
|
|
|
@dataclass(frozen=True)
class ScheduledNode:
    """Immutable, hashable record describing one timed op of a schedule.

    Instances compare equal field by field.  ``transform_schedule`` builds
    one of these for every non-blank cell of its input grid.
    """

    # Op letter; transform_schedule emits 'F', 'B' or 'W'.
    type: str
    # Index of the schedule row (stage) the op belongs to.
    stage: int
    # 0-based occurrence index of this op letter within its stage.
    minibatch: int
    # Time at which the op starts (completion_time minus the op's cost).
    start_time: int
    # Time at which the op finishes.
    completion_time: int
|
|
|
def transform_schedule(schedule, f, b, w, c):
    """Turn a slot-grid schedule into per-stage lists of timed nodes.

    ``schedule`` has one row per stage.  Each cell is a string: ``'F'``,
    ``'B'``, ``'W'`` or a blank (whitespace-only) cell.  Blank cells are
    skipped, so only the *order* of ops inside a row matters, not the column
    they sit in.  The k-th occurrence (0-based) of an op letter within a row
    is treated as minibatch ``k`` of that op.

    An op starts as soon as everything it waits for has completed:

    * the op immediately before it in the same row (if any);
    * for ``'F'`` on a stage > 0: the same-minibatch ``'F'`` of the previous
      stage, plus ``c``;
    * for ``'B'`` on any stage but the last: the same-minibatch ``'B'`` of
      the next stage, plus ``c``.

    An op with no predecessor starts at time 0.

    Args:
        schedule: list of rows (one per stage) of single-letter op strings.
        f: duration of an ``'F'`` op.
        b: duration of a ``'B'`` op.
        w: duration of a ``'W'`` op.
        c: delay added on top of a neighbouring-stage dependency
            (presumably the inter-stage communication time).

    Returns:
        A list with one entry per stage; each entry is the list of
        ``ScheduledNode`` objects for that row, in row order.  A row with no
        ops yields an empty list.

    Raises:
        KeyError: if a non-blank cell is not one of ``'F'``, ``'B'``, ``'W'``.
        ValueError: if the dependencies described above form a cycle.
    """
    cost = {'F': f, 'B': b, 'W': w}
    num_stages = len(schedule)

    # Per-stage op order, plus each op's immediate predecessor in its row.
    # Ops are identified everywhere by the key (stage, letter, minibatch).
    stage_order = []
    local_prev = {}
    for sid, row in enumerate(schedule):
        seen = Counter()
        order = []
        for op in row:
            if not op.strip():
                continue
            mb = seen[op]
            if order:
                local_prev[(sid, op, mb)] = (sid, *order[-1])
            order.append((op, mb))
            seen[op] += 1
        stage_order.append(order)

    end_time = {}

    def finish_time(target):
        """Return the completion time of ``target``, memoised in end_time.

        Dependencies are resolved with an explicit stack rather than
        recursion so that long rows cannot exhaust the interpreter's
        recursion limit.
        """
        stack = [target]
        expanding = set()
        while stack:
            node = stack[-1]
            if node in end_time:
                stack.pop()
                continue

            sid, op, mb = node
            prev = local_prev.get(node)
            # A neighbouring-stage dependency is timed even when that row
            # does not contain the op; it is then treated as an op with no
            # local predecessor.
            if op == 'F' and sid > 0:
                remote = (sid - 1, op, mb)
            elif op == 'B' and sid + 1 < num_stages:
                remote = (sid + 1, op, mb)
            else:
                remote = None

            waiting = [
                dep for dep in (prev, remote)
                if dep is not None and dep not in end_time
            ]
            if waiting:
                # Everything above an expanded node on the stack is one of
                # its transitive dependencies, so meeting it again unresolved
                # means the dependency graph loops back on itself.
                if node in expanding:
                    raise ValueError(
                        f"cyclic dependency involving stage {sid}, "
                        f"op {op!r}, minibatch {mb}"
                    )
                expanding.add(node)
                stack.extend(waiting)
                continue

            start = end_time[prev] if prev is not None else 0
            if remote is not None:
                start = max(start, end_time[remote] + c)
            end_time[node] = start + cost[op]
            stack.pop()
        return end_time[target]

    result = []
    for sid, order in enumerate(stage_order):
        nodes = []
        for op, mb in order:
            end = finish_time((sid, op, mb))
            nodes.append(ScheduledNode(op, sid, mb, end - cost[op], end))
        result.append(nodes)
    return result
|
|
|
|
|
|
|
|
|
def process_warmup_without_increasing_peak_mem(schedules, m):
    """Shift ops to earlier free slots without raising the peak F-minus-W count.

    ``schedules`` is a grid with one row per stage and one column per time
    slot; each cell is ``'F'``, ``'B'``, ``'W'`` or ``' '`` (free).  The grid
    is modified **in place** and also returned.

    Slots are visited column by column, stages top to bottom.  The k-th op
    (1-based) of a letter within a stage may not be placed at or before the
    final position of:

    * the (k-1)-th op of the same letter in the same stage;
    * for ``'W'``: the k-th ``'B'`` of the same stage;
    * for ``'F'`` on a stage > 0: the k-th ``'F'`` of the previous stage;
    * for ``'B'``: the k-th ``'B'`` of the next stage, or, on the last
      stage, the k-th ``'F'`` of the same stage.

    ``'B'`` and ``'W'`` move to the first free slot after that bound.  ``'F'``
    moves left only across slots whose running (number of F) - (number of W)
    count is still below the peak measured on the input grid, so that peak is
    never exceeded (the count presumably models live activation memory).

    Args:
        schedules: list of rows (lists of single-character strings).
            Row 0 determines the number of columns that are processed.
        m: upper bound on the number of ops of one letter per stage; the
            position table is sized from it.

    Returns:
        The same ``schedules`` list object, after compaction.
    """
    if not schedules:
        return schedules

    num_stages = len(schedules)
    num_slots = len(schedules[0])

    # mem[sid][i]: (number of F) - (number of W) in row sid up to and including slot i.
    mem = [[0 for _ in range(num_slots)] for _ in range(num_stages)]
    # loc[sid][k][letter]: final slot of the k-th (1-based) op of that letter, -1 if not placed yet.
    loc = [
        [{key: -1 for key in ('F', 'B', 'W')} for _ in range(m + 2)]
        for _ in range(num_stages)
    ]
    # placed[sid][letter]: how many ops of that letter have been visited so far.
    placed = [{key: 0 for key in ('F', 'B', 'W')} for _ in range(num_stages)]

    peak_mem = 0
    for sid, row in enumerate(schedules):
        live = 0
        for i, op in enumerate(row):
            if op == 'F':
                live += 1
            elif op == 'W':
                live -= 1
            mem[sid][i] = live
            peak_mem = max(peak_mem, live)

    for i in range(num_slots):
        for sid in range(num_stages):
            row = schedules[sid]
            op = row[i]
            if op == ' ':
                continue

            placed[sid][op] += 1
            cnt = placed[sid][op]

            # Latest slot occupied by anything this op has to wait for.
            pos = loc[sid][cnt - 1][op] if cnt > 1 else -1
            if op == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            elif op == 'F':
                if sid > 0:
                    pos = max(pos, loc[sid - 1][cnt]['F'])
            elif op == 'B':
                if sid != num_stages - 1:
                    pos = max(pos, loc[sid + 1][cnt]['B'])
                else:
                    pos = max(pos, loc[sid][cnt]['F'])

            # First free slot after the dependencies, never beyond slot i.
            # The bound is tested first so the row is never indexed past i.
            pos += 1
            while pos < i and row[pos] != ' ':
                pos += 1
            if pos == i:
                loc[sid][cnt][op] = i
                continue

            if op in ('B', 'W'):
                row[pos] = op
                row[i] = ' '
                if op == 'W':
                    # The W now happens earlier, so the count drops earlier.
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][op] = pos
                continue

            # 'F': walk left only while the preceding slot is below the peak,
            # then skip forward over occupied slots.
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and row[place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][op] = i
                continue

            row[place] = op
            row[i] = ' '
            # The F now happens earlier, so the count rises earlier.
            for j in range(place, i):
                mem[sid][j] += 1
            loc[sid][cnt][op] = place

    return schedules
|
|
|
def schedule(p, m, cost):
    """Build a timed F/B/W schedule for ``p`` stages and ``m`` minibatches.

    A blank grid of ``p`` rows is filled with ``m`` ``'F'`` and ``m`` ``'B'``
    ops per row on a repeating 6-slot stride, every ``'B'`` is followed by a
    ``'W'`` in the next free slot, and the grid is then passed through
    ``process_warmup_without_increasing_peak_mem`` and ``transform_schedule``.

    Args:
        p: number of stages (rows of the grid).
        m: number of F/B pairs placed in each row.
        cost: sequence unpacked as the ``(f, b, w, c)`` arguments of
            ``transform_schedule``.

    Returns:
        Whatever ``transform_schedule`` returns for the compacted grid.
    """
    width = 6 * m + 2 * p + 6
    grid = [[' '] * width for _ in range(p)]

    # Rows are filled from the last stage to the first; each step moves the
    # F columns one slot left and the B columns one slot right.
    for depth, sid in enumerate(range(p - 1, -1, -1)):
        row = grid[sid]
        even_f, even_b = p - 1 - depth, p + depth
        odd_f, odd_b = p + 1 - depth, p + 2 + depth

        for pair in range((m + 1) // 2):
            base = pair * 6
            # Minibatch 2 * pair always exists inside this range.
            row[even_f + base] = 'F'
            row[even_b + base] = 'B'
            # Minibatch 2 * pair + 1 is missing from the last pair when m is odd.
            if pair * 2 + 1 < m:
                row[odd_f + base] = 'F'
                row[odd_b + base] = 'B'

        # Give every B a matching W in the earliest free slot after it.
        owed = 0
        for slot in range(width):
            cell = row[slot]
            if cell == 'B':
                owed += 1
            elif cell == ' ' and owed > 0:
                owed -= 1
                row[slot] = 'W'

    compacted = process_warmup_without_increasing_peak_mem(grid, m)
    return transform_schedule(compacted, *cost)