|
pattern_size = 6  # number of slots in one repeating pattern period; get_pattern_str places each of the six symbols "FfBbWw" into one slot
|
from collections import Counter |
|
from dataclasses import dataclass |
|
|
|
@dataclass(frozen=True)
class ScheduledNode:
    """A single pass of one microbatch placed on a pipeline stage, with timing.

    Immutable; equality and hashing are generated from all fields in order.
    """

    type: str  # upper-cased pass symbol (transform_schedule passes p.upper())
    chunk: int  # model chunk index the pass belongs to
    stage: int  # pipeline stage index
    minibatch: int  # microbatch index within the stage
    start_time: int  # simulated start time
    completion_time: int  # simulated completion time
|
|
|
def _simulate_schedule(schedule, f, b, w, c):
    """Build the dependency model of ``schedule`` plus a memoised timing oracle.

    Each stage of ``schedule`` is an iterable of one-character pass kinds
    ("F", "f", "B", "b", "W", "w"); whitespace entries are idle slots and are
    skipped. The k-th occurrence of a kind on a stage is microbatch k.

    A pass may start once
      * the previous pass on the same stage has completed, and
      * for "F"/"B", the same pass on stage s-1 has completed plus ``c``;
        for "f"/"b", the same pass on stage s+1 has completed plus ``c``.
    It then runs for ``f``, ``b`` or ``w`` time units depending on its kind
    (upper and lower case cost the same).

    Returns:
        ``(stage_order, cost, num_microbatches, completion_time)`` where
        ``stage_order[s]`` lists the ``(stage, kind, microbatch)`` nodes of
        stage ``s`` in execution order, ``cost`` maps each kind to its
        duration, ``num_microbatches`` is the largest per-kind count on any
        stage, and ``completion_time(node)`` returns a node's completion time.

    Raises:
        ValueError: if ``schedule`` has no stages, or if a completion time is
            requested for a node that depends on itself (cyclic schedule).
    """
    if not schedule:
        raise ValueError("schedule must contain at least one stage")
    num_stages = len(schedule)
    cost = {'F': f, 'B': b, 'W': w, 'f': f, 'b': b, 'w': w}

    stage_order = []
    local_prev = {}  # node -> node that precedes it on the same stage
    num_microbatches = 0
    for sid, stage in enumerate(schedule):
        seen = Counter()
        order = []
        for kind in stage:
            if not kind.strip():
                continue
            node = (sid, kind, seen[kind])
            if order:
                local_prev[node] = order[-1]
            order.append(node)
            seen[kind] += 1
        stage_order.append(order)
        num_microbatches = max(num_microbatches, max(seen.values(), default=0))

    def upstream(node):
        # Cross-stage predecessor: "F"/"B" flow towards higher stages,
        # "f"/"b" towards lower stages; "W"/"w" have none.
        sid, kind, mb = node
        if kind in "FB" and sid > 0:
            return (sid - 1, kind, mb)
        if kind in "fb" and sid + 1 < num_stages:
            return (sid + 1, kind, mb)
        return None

    completion = {}

    def completion_time(node):
        # Iterative post-order DFS: long schedules would otherwise exceed the
        # interpreter's recursion limit.
        if node in completion:
            return completion[node]
        expanding = set()
        stack = [(node, False)]
        while stack:
            current, ready = stack.pop()
            if current in completion:
                continue
            prev = local_prev.get(current)
            remote = upstream(current)
            if ready:
                # All dependencies were resolved above this entry on the stack.
                start = completion[prev] if prev is not None else 0
                if remote is not None:
                    start = max(start, completion[remote] + c)
                completion[current] = start + cost[current[1]]
                expanding.discard(current)
                continue
            if current in expanding:
                raise ValueError("cyclic dependency involving {}".format(current))
            expanding.add(current)
            stack.append((current, True))
            for dep in (prev, remote):
                if dep is not None and dep not in completion:
                    stack.append((dep, False))
        return completion[node]

    return stage_order, cost, num_microbatches, completion_time


def transform_schedule(schedule, f, b, w, c):
    """Convert a character schedule into rows of timed ScheduledNode objects.

    Args:
        schedule: one iterable of pass characters per stage (see
            _simulate_schedule); whitespace entries are ignored.
        f, b, w: durations of forward, backward and weight-gradient passes.
        c: communication delay between adjacent stages.

    Returns:
        One list per stage, in execution order, of ScheduledNode with the
        upper-cased kind, chunk 1 for "f"/"B"/"W" and 0 for "F"/"b"/"w",
        stage, microbatch, start time and completion time.
    """
    stage_order, cost, _, completion_time = _simulate_schedule(schedule, f, b, w, c)
    result = []
    for order in stage_order:
        row = []
        for node in order:
            sid, kind, mb = node
            end = completion_time(node)
            row.append(ScheduledNode(
                kind.upper(),
                int(kind in "fBW"),
                sid,
                mb,
                end - cost[kind],
                end,
            ))
        result.append(row)
    return result


def evaluate_schedule(schedule, f, b, w, c):
    """Return the longest per-stage busy span of ``schedule``.

    For every stage, the span runs from the start of its first "F" (its
    completion time minus ``f``) to the completion of the "W" and of the "w"
    of the final microbatch (index ``num_microbatches - 1``); the maximum
    over stages and over both kinds is returned. Arguments are as for
    transform_schedule. A schedule with no passes evaluates to 0.
    """
    _, _, num_microbatches, completion_time = _simulate_schedule(schedule, f, b, w, c)
    if num_microbatches == 0:
        return 0
    last = num_microbatches - 1
    r = 0
    for sid in range(len(schedule)):
        for kind in "Ww":
            r = max(completion_time((sid, kind, last)) - completion_time((sid, 'F', 0)) + f, r)
    return r
|
|
|
def get_pattern_str(pos):
    """Render a slot-assignment vector as a ``pattern_size``-character string.

    ``pos[k]`` is the slot that the k-th symbol of "FfBbWw" occupies; a
    negative value means the symbol is not placed. Unoccupied slots are ' '.
    """
    slots = [" "] * pattern_size
    for symbol_idx, slot in enumerate(pos):
        if slot >= 0:
            slots[slot] = "FfBbWw"[symbol_idx]
    return "".join(slots)
|
|
|
|
|
def get_peak_mem(schedules, return_all=False):
    """Return the peak memory of each stage, or the largest such peak.

    Walking a row left to right, each "F"/"f" adds one unit and each "W"/"w"
    removes one; a stage's peak is the highest running total reached, never
    below 0.

    Args:
        schedules: iterable of rows, each an iterable of pass symbols.
        return_all: when true, return the list of per-stage peaks.

    Returns:
        The maximum peak over all stages (0 if there are none), or the list
        of per-stage peaks when ``return_all`` is true.
    """
    stage_peaks = []
    for row in schedules:
        live = 0
        highest = 0
        for symbol in row:
            if symbol in "Ff":
                live += 1
            elif symbol in "Ww":
                live -= 1
            if live > highest:
                highest = live
        stage_peaks.append(highest)
    if return_all:
        return stage_peaks
    return max(stage_peaks, default=0)
|
|
|
|
|
def calc_bubble(schedules):
    """Return a per-stage idle count for a character schedule.

    For stage i the value is (index of its last non-' ' entry + 1)
    - (number of non-' ' entries) - i, i.e. the idle slots inside the
    stage's active prefix minus the stage index (presumably discounting the
    start-up offset of stage i — confirm). An all-idle row contributes -i.
    """
    bubbles = []
    for stage_idx, row in enumerate(schedules):
        busy_cols = [col for col, symbol in enumerate(row) if symbol != ' ']
        span = busy_cols[-1] + 1 if busy_cols else 0
        bubbles.append(span - len(busy_cols) - stage_idx)
    return bubbles
|
|
|
|
|
def init_repeated_schedule(p, m, patterns):
    """Tile each stage's pattern into a long, mutable list of characters.

    Stage ``i`` gets ``get_pattern_str(patterns[i])`` repeated
    ``4 * p + m + 1`` times, split into single-character entries.
    """
    repeats = 4 * p + m + 1
    return [list(get_pattern_str(patterns[stage]) * repeats) for stage in range(p)]
|
|
|
|
|
def clear_invalid(repeated, stage, pos, offset=-1):
    """Blank cells of one row at a stride of ``offset * pattern_size``.

    Starting at column ``pos`` of ``repeated[stage]``, the cell is set to ' '
    and the column advances by ``offset * pattern_size`` until it leaves the
    row. With ``offset=-1`` this clears ``pos`` and every earlier cell of the
    same pattern phase; with ``offset=1``, ``pos`` and every later one. If
    ``pos`` is already outside the row, nothing is changed.

    Args:
        repeated: list of rows (lists of characters); mutated in place.
        stage: index of the row to modify.
        pos: first column to clear.
        offset: direction/stride multiplier; must be non-zero.

    Returns:
        ``repeated`` itself.

    Raises:
        ValueError: if ``offset`` is 0, which would never leave the row.
    """
    step = offset * pattern_size
    if step == 0:
        raise ValueError("offset must be non-zero")
    row = repeated[stage]
    while 0 <= pos < len(row):
        row[pos] = ' '
        pos += step
    return repeated
|
|
|
|
|
def clear_invalid_index(repeated, m):
    """Keep a window of at most ``m`` occurrences of each pass kind per stage.

    For each kind in "F", "f", "B", "b" (F/B visiting stages 0..p-1, f/b
    visiting p-1..0), a single cursor ``index`` (starting at
    ``pattern_size`` and only ever moving forward) is advanced until it
    finds that kind on the current stage. All earlier occurrences of that
    kind on the row (stride ``pattern_size``) and all occurrences from
    ``pattern_size * m`` columns later onwards are blanked via
    clear_invalid. Because the cursor moves past the kept occurrence before
    the next stage is scanned, each stage's first kept occurrence lies at a
    later column than the previous stage's.

    For "B"/"b", the first "W"/"w" within the next ``pattern_size`` columns
    after the kept occurrence is treated the same way.

    Args:
        repeated: list of p rows (lists of characters); mutated in place.
        m: number of occurrences to keep per kind.

    Returns:
        ``repeated`` itself.
    """
    p = len(repeated)
    index = pattern_size
    for identifier in "FfBb":
        if identifier in "FB":
            _iter = range(p)
        else:
            _iter = range(p - 1, -1, -1)
        for i in _iter:
            # Bounded scan: look at most pattern_size columns ahead on this stage.
            for j in range(pattern_size):
                if repeated[i][index] == identifier:
                    clear_invalid(repeated, i, index - pattern_size, offset=-1)
                    clear_invalid(repeated, i, index + pattern_size * m, offset=1)
                    index += 1
                    if identifier in "Bb":
                        # Apply the same window to the matching weight pass.
                        w_identifier = {'B': 'W', 'b': 'w'}[identifier]
                        for k in range(pattern_size):
                            if repeated[i][index + k] == w_identifier:
                                clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
                                clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
                                break
                    # Done with this stage; the next stage continues from the advanced cursor.
                    break
                index += 1
    return repeated
|
|
|
|
|
def process_warmup_without_increasing_peak_mem(schedules, m):
    """Rearrange the start of each stage's timeline, up to its first 'b'.

    FFFFFFFFFF fBWfBWfBWfBWfBW b
    FFFFFFFFF f fBWfBWfBWfBWFBWb
    FFFFFFFF f f fBWfBWfBWFBW b
    FFFFFFF f f f fBWfBWFBW Bb
    FFFFFF f f f f fBWFBWFBWb
    FFFFFfFf f f f BWFBW b
    FFFfFfFfFf f BW Bb
    FfFfFfFfFfF BWb
    We reorganize the warmup phase in the following way (i -> pipeline stage from 0):
    1. Before the first B, we set #f = min(i+1, peak_mem//2), #F = peak_mem - #f
    2. Before the first b, #f = peak_mem//2
    3. The offset between the first B is 1
    4. Before the first b, we use the pattern of (BWf)*j + (BWF)*k,
        where j = max(0, peak_mem//2 - (i+1)), k = max(0, #W - j - 1)

    In the code, ``peak_mem`` is get_peak_mem(schedules) capped at 2 * p, and
    both per-stage targets are additionally capped at ``m``:
    ``cnt_ff[i]`` (number of 'f') = min(i + 1, peak_mem // 2, m) and
    ``cnt_f[i]`` (number of 'F') = min(peak_mem - cnt_ff[i], m).

    Per stage, the entries before the first 'b' are counted, some later
    'f'/'B'/'W'/'F' entries are blanked and added to those counts, and the
    counted entries are then rewritten starting at column i as:
    up to 2*(p-1-i) 'F', then up to i+1 ('F', 'f') pairs, then repeating
    ('B', 'W', 'f' or 'F') triples until every counted entry is placed.

    Args:
        schedules: list of p rows (lists of one-character entries from
            "FfBbWw" or ' '); rewritten in place.
        m: upper bound applied to the per-stage 'F'/'f' targets.

    Returns:
        ``schedules`` itself.
    """
    p = len(schedules)
    peak_mem = get_peak_mem(schedules)
    peak_mem = min(peak_mem, 2 * p)
    cnt_f, cnt_ff = [], []
    for i in range(p):
        cc_ff = min(i + 1, peak_mem // 2)
        cc_ff = min(cc_ff, m)
        cc_f = min(peak_mem - cc_ff, m)
        cnt_f.append(cc_f)
        cnt_ff.append(cc_ff)
    # Column distance, on the last stage, from its first 'B' to the next 'b'.
    distance_b2bb = 0
    for j in range(len(schedules[p - 1])):
        if schedules[p - 1][j] == 'B':
            for k in range(j, len(schedules[p - 1])):
                if schedules[p - 1][k] == 'b':
                    distance_b2bb = k - j
                    break
            break
    for i in range(p):
        c_f, c_ff, c_b, c_w = 0, 0, 0, 0
        for j in range(len(schedules[i])):
            char = schedules[i][j]
            # Count (without blanking) the entries that precede the first 'b'.
            if char == 'F':
                c_f += 1
            elif char == 'f':
                c_ff += 1
            elif char == 'B':
                c_b += 1
            elif char == 'W':
                c_w += 1
            elif char == 'b':
                # From the first 'b' onwards, pull selected 'f', 'B' and 'W'
                # entries forward: blank them here and add them to the counts
                # so the rewrite below places them before this point.
                bj = j
                while j < len(schedules[i]):
                    char = schedules[i][j]
                    if char == 'f' and c_ff < cnt_ff[p - 1]:
                        schedules[i][j] = ' '
                        c_ff += 1
                    if char == 'B' and c_b < c_ff:
                        if c_b < (2 * (p - i) + distance_b2bb) // 3 or c_b < cnt_ff[p - 1] - cnt_ff[i]:
                            schedules[i][j] = ' '
                            c_b += 1
                    if char == 'W' and c_w < c_b:
                        if c_w < (2 * (p - i) + distance_b2bb - 1) // 3 or c_w < cnt_ff[p - 1] - cnt_ff[i]:
                            schedules[i][j] = ' '
                            c_w += 1
                    j += 1
                # Second sweep from the first 'b': pull 'F' entries forward too.
                j = bj
                while j < len(schedules[i]):
                    if schedules[i][j] == 'F':
                        if c_f < c_ff or c_f < cnt_f[i] or c_f - cnt_f[i] + c_ff - cnt_ff[i] < c_w - 1:
                            schedules[i][j] = ' '
                            c_f += 1
                    j += 1
                break
            else:
                # Only F/f/B/W or idle slots may precede the first 'b'.
                assert char == ' '
                schedules[i][j] = ' '
        assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
        assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
        # NOTE(review): the F/f/B/W entries counted before the first 'b' are
        # not blanked above; the rewrite below overwrites from column i, so an
        # entry lying beyond the last rewritten column would survive.
        # Presumably the rewrite always covers them — confirm.
        j = i
        u_f, u_ff, u_b, u_w = 0, 0, 0, 0
        # Leading run of 'F' (2 * (p - 1 - i) slots).
        for _ in range(2 * (p - 1 - i)):
            if u_f < cnt_f[i] and u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
        # Then i + 1 ('F', 'f') pairs, each capped by its per-stage target.
        for _ in range(i + 1):
            if u_f < cnt_f[i] and u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
            if u_ff < cnt_ff[i] and u_ff < c_ff:
                schedules[i][j] = 'f'
                u_ff += 1
            j += 1
        # Then ('B', 'W', 'f' or 'F') triples until every counted entry is placed.
        while u_f < c_f or u_ff < c_ff or u_b < c_b or u_w < c_w:
            if u_b < c_b:
                schedules[i][j] = 'B'
                u_b += 1
            j += 1
            if u_w < c_w:
                schedules[i][j] = 'W'
                u_w += 1
            j += 1
            if u_ff < c_ff:
                # Every 'f' must follow an already placed 'F'.
                assert u_ff < u_f
                schedules[i][j] = 'f'
                u_ff += 1
            elif u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
    return schedules
|
|
|
|
|
|
|
def squeeze_without_change_order(schedules, m):
    """Left-compact every stage's timeline without reordering its passes.

    Columns of ``schedules`` are visited left to right. Within a column the
    "F"/"B"/"W" entries are placed first (stages 0..p-1), then the
    "f"/"b"/"w" entries (stages p-1..0). Each entry moves to the earliest
    column of its stage that is
      * not before the stage's previously placed entry, and
      * strictly after its cross-stage predecessor: the same pass of the same
        microbatch on stage i-1 for "F"/"B", or on stage i+1 for "f"/"b".
    "W"/"w", "F"/"B" on stage 0 and "f"/"b" on the last stage have no
    cross-stage predecessor. The k-th occurrence of a kind on a stage is
    treated as microbatch k.

    Args:
        schedules: p rows of equal length, each a list of one-character
            entries from "FfBbWw" or ' ' (idle). Not modified.
        m: maximum number of occurrences of each kind per stage.

    Returns:
        A new list of p rows, with trailing columns that are idle on every
        stage removed (one column is kept when the rows are non-empty).

    Raises:
        AssertionError: on unequal row lengths, more than ``m`` occurrences
            of a kind on a stage, or a pass that appears before its
            predecessor.
        ValueError: on an entry that is neither ' ' nor one of "FfBbWw".
    """
    p = len(schedules)
    row_lengths = {len(row) for row in schedules}
    assert len(row_lengths) <= 1, "all stages must have the same length"
    max_len = row_lengths.pop() if row_lengths else 0

    kinds = "FfBbWw"
    squeezed = [[' '] * max_len for _ in range(p)]
    identifier_cnt = [dict.fromkeys(kinds, 0) for _ in range(p)]
    # identifier_index[mb * p + stage][kind] -> column where that pass was placed (-1: not yet).
    identifier_index = [dict.fromkeys(kinds, -1) for _ in range(p * m)]
    stage_index = [0] * p  # next free column on each stage

    sweeps = ((range(p), "FBW"), (range(p - 1, -1, -1), "fbw"))
    for j in range(max_len):
        for stage_iter, sweep_kinds in sweeps:
            for i in stage_iter:
                identifier = schedules[i][j]
                if identifier == ' ':
                    continue
                if identifier not in identifier_cnt[i]:
                    raise ValueError("unknown pass kind {!r} at stage {}, column {}".format(identifier, i, j))
                if identifier not in sweep_kinds:
                    continue
                _cnt = identifier_cnt[i][identifier]
                assert _cnt < m, "{} - {}, {}".format(i, identifier, _cnt)
                slot = _cnt * p + i
                if identifier in "Ww" or (i == 0 and identifier in "FB") or (i == p - 1 and identifier in "fb"):
                    # Only constrained by earlier work on the same stage.
                    if i == 0 and identifier == 'B':
                        assert identifier_index[slot]['f'] >= 0
                    if i == p - 1 and identifier == 'f':
                        assert identifier_index[slot]['F'] >= 0
                    if i == p - 1 and identifier == 'b':
                        assert identifier_index[slot]['B'] >= 0
                    index = stage_index[i]
                else:
                    # "F"/"B" wait for stage i-1, "f"/"b" wait for stage i+1.
                    upstream = slot - 1 if identifier in "FB" else slot + 1
                    upstream_col = identifier_index[upstream][identifier]
                    assert upstream_col >= 0, "{} {} {}".format(i, identifier, _cnt)
                    index = max(upstream_col + 1, stage_index[i])
                squeezed[i][index] = identifier
                identifier_cnt[i][identifier] += 1
                identifier_index[slot][identifier] = index
                stage_index[i] = index + 1

    # Every placed entry lies before max(stage_index); keep at least one column.
    keep = max(1, max(stage_index, default=0))
    for row in squeezed:
        del row[keep:]
    return squeezed
|
|
|
|
|
def process_cooldown(schedules, m):
    """Rearrange the end of each stage's timeline, after its last 'f'.

    fBW bwbwbwbw
    fBWBW bwbwbwbw
    fBWBWBW bwbwbwbw
    fBWBWBWBW bwbwbwbw
    f BWBWBWBbWbwbwbww
    f BWBWBbBbWbWbwwww
    f BWBbBbBbWbWWwwww
    f BbBbBbBbWWWWwwww
    We reorganize the cooldown phase in the following way (i -> pipeline stage from 0):
    1. After the last f, we set #b = (peak_mem+1)//2, and #B = min(i+1, peak_mem - #b)
    2. After the last f, we make all the dependencies as tight as possible

    Steps, per stage i:
      * blank the rightmost min(i + 1, max_b) 'B' entries and the rightmost
        max_bb 'b' entries, then the rightmost W/w entries until peak_mem of
        them have been removed;
      * re-insert the blanked 'b'/'B' at fixed columns measured from stage
        0's last 'f' (asserting those columns are idle);
      * squeeze, then blank the last two W/w, and refill idle columns after
        the remaining last W/w with 'W' while B's outnumber W's seen so far,
        else 'w' while b's outnumber w's;
      * squeeze again.

    Args:
        schedules: list of p rows (lists of one-character entries); the rows
            are mutated before being replaced by squeezed copies.
        m: per-kind occurrence cap, also passed to squeeze_without_change_order.

    Returns:
        The squeezed schedule (a new list of rows).
    """
    p = len(schedules)

    peak_mem = get_peak_mem(schedules)
    assert peak_mem <= 2 * p
    max_bb = (peak_mem + 1) // 2
    max_bb = min(max_bb, m)
    max_b = min(peak_mem - max_bb, m)

    # Column of stage 0's last 'f'; used as the anchor for every stage below.
    starting_index = -1
    for i in range(p):
        c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
        last_ff_index = -1

        # Right-to-left: remember the last 'f' and blank the trailing B/b entries.
        for j in range(len(schedules[i]) - 1, -1, -1):
            char = schedules[i][j]
            if char == 'f' and last_ff_index == -1:
                last_ff_index = j
            if char == 'B' and c_b < i + 1 and c_b < max_b:
                schedules[i][j] = ' '
                c_b += 1
            if char == 'b' and c_bb < max_bb:
                schedules[i][j] = ' '
                c_bb += 1

        # Right-to-left: blank trailing W/w entries, peak_mem of them in total.
        for j in range(len(schedules[i]) - 1, -1, -1):
            char = schedules[i][j]
            if char == 'W' and c_w + c_ww < peak_mem:
                schedules[i][j] = ' '
                c_w += 1
            if char == 'w' and c_w + c_ww < peak_mem:
                schedules[i][j] = ' '
                c_ww += 1
        if i == 0:
            starting_index = last_ff_index

        # Re-insert the blanked 'b' and 'B' entries every other column, going left.
        for k in range(c_bb):
            index = starting_index - i + 2 * p - 2 * k
            assert schedules[i][index] == ' ', "{} {} {}".format(schedules[i][index], k, i)
            schedules[i][index] = 'b'
        for k in range(c_b):
            index = starting_index + 1 + i - 2 * k
            assert schedules[i][index] == ' ', schedules[i][index]
            schedules[i][index] = 'B'

    schedules = squeeze_without_change_order(schedules, m)

    for i in range(p):
        c_w, c_ww = 0, 0
        # Blank the last two W/w and record the column of the one before them.
        # With fewer than three W/w, last_w_index ends at -2, -1 or 0.
        last_w_index = -2
        for j in range(len(schedules[i]) - 1, -1, -1):
            if schedules[i][j] in "Ww":
                if last_w_index < 0:
                    schedules[i][j] = ' '
                    last_w_index += 1
                else:
                    last_w_index = j
                    break
        # c_w / c_ww track pending weight passes (B minus W, b minus w) seen so
        # far; idle columns after last_w_index are filled with W first, then w.
        for j in range(len(schedules[i])):
            char = schedules[i][j]
            if char == 'B':
                c_w += 1
            elif char == 'b':
                c_ww += 1
            elif char == 'W':
                c_w -= 1
            elif char == 'w':
                c_ww -= 1
            if char == ' ' and j > last_w_index:
                if c_w > 0:
                    schedules[i][j] = 'W'
                    c_w -= 1
                elif c_ww > 0:
                    schedules[i][j] = 'w'
                    c_ww -= 1

    schedules = squeeze_without_change_order(schedules, m)
    return schedules
|
|
|
|
|
def schedule_by_pattern(p, m, patterns, max_mem):
    """Build a schedule from per-stage repeating patterns, or report failure.

    The patterns are tiled and windowed using ``max(m, 2 * p)`` microbatches,
    the warm-up is rearranged, every stage is trimmed to its first ``m``
    occurrences of each pass kind, and the result is squeezed and given a
    rearranged cool-down.

    Returns:
        ``(schedules, peak_mem, stage_bubbles)`` on success. On failure
        (initial peak memory above ``max_mem``, or a later step raising the
        peak above the initial one) the first element is None, the second
        is the initial peak memory and the third a placeholder bubble list.
    """
    padded_m = max(m, 2 * p)
    schedules = clear_invalid_index(init_repeated_schedule(p, padded_m, patterns), padded_m)
    init_peak_mem = get_peak_mem(schedules)
    if init_peak_mem > max_mem:
        return None, init_peak_mem, [6 * padded_m] * p
    schedules = process_warmup_without_increasing_peak_mem(schedules, padded_m)

    # Keep only the first m occurrences of every pass kind on each stage.
    for row in schedules:
        seen = dict.fromkeys("FfBbWw", 0)
        for col, kind in enumerate(row):
            if kind == ' ':
                continue
            if seen[kind] < m:
                seen[kind] += 1
            else:
                row[col] = ' '
    if get_peak_mem(schedules) > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p

    schedules = process_cooldown(squeeze_without_change_order(schedules, m), m)
    peak_mem = get_peak_mem(schedules)
    if peak_mem > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p
    return schedules, peak_mem, calc_bubble(schedules)
|
|
|
|
|
def fill_w_in_pattern(pattern):
    """Assign the W and w slots of a pattern, mutating and returning it.

    ``pattern`` holds the slot of each pass in the order F, f, B, b, W, w
    (negative = unassigned). Starting at the B slot (then the b slot) and
    wrapping around modulo ``pattern_size``, the first unoccupied slot is
    given to W (then w). If no slot is free, that entry is left unchanged.
    """
    idx_b, idx_bb, idx_w, idx_ww = 2, 3, 4, 5
    occupied = [False] * pattern_size
    for slot in pattern:
        if slot >= 0:
            occupied[slot] = True
    assert pattern[idx_b] >= 0 and pattern[idx_bb] >= 0
    for src, dst in ((idx_b, idx_w), (idx_bb, idx_ww)):
        for step in range(pattern_size):
            candidate = (pattern[src] + step) % pattern_size
            if not occupied[candidate]:
                pattern[dst] = candidate
                occupied[candidate] = True
                break
    return pattern
|
|
|
|
|
def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
    """Derive one slot pattern per stage by shifting the previous stage's.

    Stage s+1 takes the F/f/B/b slots of stage s shifted by ``offset_0``
    (for the first ``len_0`` transitions) or ``offset_1`` (afterwards),
    modulo ``pattern_size``; its W/w slots come from fill_w_in_pattern.

    Returns:
        A list of p patterns (the first is ``pattern_0`` itself), or None if
        two shifted slots of some stage collide.
    """
    patterns = [pattern_0]
    for transition in range(p - 1):
        shift = offset_0 if transition < len_0 else offset_1
        previous = patterns[-1]
        derived = [-1] * pattern_size
        used = [False] * pattern_size
        for kind, delta in enumerate(shift):
            slot = (previous[kind] + delta + pattern_size) % pattern_size
            assert 0 <= slot < pattern_size
            if used[slot]:
                return None
            used[slot] = True
            derived[kind] = slot
        patterns.append(fill_w_in_pattern(derived))
    return patterns
|
|
|
|
|
|
|
def schedule(p, m, cost, max_mem):
    """Search repeating per-stage slot patterns for the best schedule.

    Every stage-0 pattern with F in slot 0 and distinct f/B/b slots is
    combined with pairs of per-stage shifts (``offset_0`` for the first
    ``len_0`` stage transitions, ``offset_1`` afterwards). Each candidate is
    built with schedule_by_pattern and scored with evaluate_schedule; the
    lowest score wins.

    Args:
        p: number of pipeline stages.
        m: number of microbatches.
        cost: ``(f, b, w, c)`` tuple forwarded to evaluate_schedule and
            transform_schedule.
        max_mem: peak-memory budget (get_peak_mem units) for candidates.

    Returns:
        The best schedule converted with transform_schedule.

    Raises:
        ValueError: if no candidate yields a feasible schedule (this is
            always the case for p < 2, where no ``len_0`` is tried).
    """
    available_patterns = []
    for ff_i in range(1, pattern_size):
        for b_i in range(1, pattern_size):
            for bb_i in range(1, pattern_size):
                if len({ff_i, b_i, bb_i}) < 3:
                    continue
                available_patterns.append(fill_w_in_pattern([0, ff_i, b_i, bb_i, -1, -1]))

    # Shift applied to the (F, f, B, b) slots between consecutive stages.
    available_offsets = [
        [1, -1, 1, -1],
        [2, -1, 2, -1],
        [3, -1, 3, -1],
        [4, -1, 4, -1],
        [5, -1, 5, -1],
    ]

    best_schedule = None
    best_bubble = None
    for pattern_0 in available_patterns:
        for i_0 in range(len(available_offsets)):
            for i_1 in range(i_0 + 1):
                for len_0 in range(1, p):
                    whole_pattern = get_whole_pattern(
                        pattern_0, available_offsets[i_0], available_offsets[i_1], len_0, p)
                    if whole_pattern is None:
                        continue
                    s, peak_mem, _ = schedule_by_pattern(p, m, whole_pattern, min(2 * p, max_mem))
                    if peak_mem > 2 * p or peak_mem > max_mem:
                        # Over budget: skip the remaining len_0 values for this
                        # pattern/offset pair.
                        break
                    if s is None:
                        continue
                    bubble = evaluate_schedule(s, *cost)
                    if best_schedule is None or bubble < best_bubble:
                        best_schedule, best_bubble = s, bubble
    if best_schedule is None:
        raise ValueError(
            "no feasible schedule found for p={}, m={}, max_mem={}".format(p, m, max_mem))
    return transform_schedule(best_schedule, *cost)