from collections import Counter
from dataclasses import dataclass

# Number of slots in one repeating per-stage pattern: one slot for each of the
# six operation letters "FfBbWw".
# NOTE(review): the letters presumably denote forward / backward / weight
# passes of two model chunks (upper vs. lower case) -- this file does not say.
pattern_size = 6


@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One timed operation of a stage, as produced by ``transform_schedule``."""

    type: str  # upper-cased operation letter taken from the schedule
    chunk: int  # truthy for the letters 'f', 'B', 'W'; falsy for 'F', 'b', 'w'
    stage: int  # index of the stage (row) the operation belongs to
    minibatch: int  # how many times the same letter occurred earlier in the stage
    start_time: int  # completion_time minus the operation's cost
    completion_time: int


def _build_timeline(schedule, f, b, w, c):
    """Collect per-stage operation order and build a completion-time resolver.

    Shared by ``transform_schedule`` and ``evaluate_schedule``.

    Args:
        schedule: sequence of stages; each stage is an iterable of
            one-character strings. Blank entries are ignored.
        f, b, w: cost of an 'F'/'f', 'B'/'b' and 'W'/'w' operation.
        c: delay added when an operation waits for the same operation on the
            neighbouring stage.

    Returns:
        ``(stage_order, cost, get_time)`` where ``stage_order[sid]`` is the list
        of ``(letter, minibatch)`` pairs of stage ``sid`` in execution order,
        ``cost`` maps a letter to its cost and ``get_time(stage, kind, mb)``
        returns the memoised completion time of an operation.
    """
    stages = len(schedule)
    stage_order = []
    # (stage, letter, minibatch) -> (letter, minibatch) executed just before it
    # on the same stage.
    local_prev = {}
    for sid, stage in enumerate(schedule):
        counter = Counter()
        order = []
        for p in stage:
            if not p.strip():
                continue
            mb = counter[p]
            if order:
                local_prev[(sid, p, mb)] = order[-1]
            order.append((p, mb))
            counter[p] += 1
        stage_order.append(order)

    cost = {
        'F': f,
        'B': b,
        'W': w,
        'f': f,
        'b': b,
        'w': w,
    }
    time_map = {}

    # NOTE: recursive; the depth grows with the number of operations per stage,
    # so very long schedules can hit Python's recursion limit.
    def get_time(stage, kind, mb):
        key = (stage, kind, mb)
        if key in time_map:
            return time_map[key]
        time = 0
        if key in local_prev:
            time = get_time(stage, *local_prev[key])
        # 'F' and 'B' wait for the previous stage, 'f' and 'b' for the next one.
        if kind in "FB" and stage > 0:
            time = max(time, get_time(stage - 1, kind, mb) + c)
        if kind in "fb" and stage + 1 < stages:
            time = max(time, get_time(stage + 1, kind, mb) + c)
        time_map[key] = time + cost[kind]
        return time_map[key]

    return stage_order, cost, get_time


def transform_schedule(schedule, f, b, w, c):
    """Convert a character schedule into timed ``ScheduledNode`` lists.

    Args:
        schedule: sequence of stages, each an iterable of one-character
            operation letters (blank entries are skipped).
        f, b, w: cost of an 'F'/'f', 'B'/'b' and 'W'/'w' operation.
        c: cross-stage delay (see ``_build_timeline``).

    Returns:
        One list per stage with a ``ScheduledNode`` for every non-blank entry,
        in the order the entries appear in that stage.
    """
    stage_order, cost, get_time = _build_timeline(schedule, f, b, w, c)
    result = []
    for sid, stage in enumerate(stage_order):
        result_stage = []
        for p, mb in stage:
            completion = get_time(sid, p, mb)
            result_stage.append(ScheduledNode(
                p.upper(),
                p in "fBW",
                sid,
                mb,
                completion - cost[p],
                completion,
            ))
        result.append(result_stage)
    return result


def evaluate_schedule(schedule, f, b, w, c):
    """Score a character schedule; ``schedule()`` keeps the lowest score.

    For every stage the span from the start of its first 'F' to the completion
    of its last 'W' and last 'w' is measured; the largest span is returned.
    "Last" means minibatch ``nmb - 1`` where ``nmb`` is the highest per-letter
    count found on the final stage.

    Args:
        schedule: sequence of stages of one-character operation letters.
        f, b, w, c: operation costs and cross-stage delay.

    Returns:
        The largest per-stage span described above.

    Raises:
        ValueError: if ``schedule`` is empty or its last stage has no operation.
    """
    stage_order, _, get_time = _build_timeline(schedule, f, b, w, c)
    if not stage_order or not stage_order[-1]:
        raise ValueError("schedule must end with a non-empty stage")
    # Minibatch ids are 0-based per letter, so the largest id + 1 is the
    # highest per-letter count of the last stage.
    nmb = max(mb for _, mb in stage_order[-1]) + 1
    r = 0
    for sid in range(len(stage_order)):
        r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
        r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
    return r


def get_pattern_str(pos):
    """Render a slot assignment as a ``pattern_size``-character string.

    ``pos[i]`` is the slot of the i-th letter of "FfBbWw"; negative values mean
    the letter is not placed. Unused slots are spaces.
    """
    pattern = [" "] * pattern_size
    notations = "FfBbWw"
    for i, v in enumerate(pos):
        if v < 0:
            continue
        pattern[v] = notations[i]
    return "".join(pattern)


def get_peak_mem(schedules, return_all=False):
    """Return the peak of the running (#F + #f) - (#W + #w) count.

    Args:
        schedules: iterable of stages (iterables of operation letters).
        return_all: when true, return the list of per-stage peaks instead of
            the maximum over all stages.
    """
    max_peak = 0
    all_peak = []
    for schedule_ in schedules:
        peak, mem = 0, 0
        for v in schedule_:
            if v in "Ff":
                mem += 1
            elif v in "Ww":
                mem -= 1
            peak = max(peak, mem)
        all_peak.append(peak)
        max_peak = max(max_peak, peak)
    if return_all:
        return all_peak
    return max_peak


def calc_bubble(schedules):
    """Return, per stage ``i``, (index after last operation) - (#operations) - i."""
    stage_bubbles = []
    for i, row in enumerate(schedules):
        max_len = 0
        count = 0
        for j, op in enumerate(row):
            if op != ' ':
                max_len = j + 1
                count += 1
        stage_bubbles.append(max_len - count - i)
    return stage_bubbles


def init_repeated_schedule(p, m, patterns):
    """Build ``p`` stages, each its pattern string repeated ``4 * p + m + 1`` times.

    Returns a list of ``p`` mutable lists of single characters.
    """
    _len = 4 * p + m + 1
    return [list(get_pattern_str(patterns[i]) * _len) for i in range(p)]


def clear_invalid(repeated, stage, pos, offset=-1):
    """Blank every ``pattern_size``-th slot of a stage, starting at ``pos``.

    Walks in the direction given by the sign of ``offset`` until the index
    leaves the row. Mutates and returns ``repeated``.
    """
    while 0 <= pos < len(repeated[stage]):
        repeated[stage][pos] = ' '
        pos += offset * pattern_size
    return repeated


def clear_invalid_index(repeated, m):
    """Keep only ``m`` consecutive repetitions of every letter on every stage.

    A cursor sweeps forward over the rows: 'F' and 'B' are located stage by
    stage from first to last, 'f' and 'b' from last to first. For each located
    letter the repetitions before it and from ``m`` periods after it onwards
    are blanked; for 'B'/'b' the next 'W'/'w' is treated the same way.
    Mutates and returns ``repeated``.
    """
    p = len(repeated)
    index = pattern_size
    for identifier in "FfBb":
        if identifier in "FB":
            _iter = range(p)
        else:
            _iter = range(p - 1, -1, -1)
        for i in _iter:
            for j in range(pattern_size):
                if repeated[i][index] == identifier:
                    clear_invalid(repeated, i, index - pattern_size, offset=-1)
                    clear_invalid(repeated, i, index + pattern_size * m, offset=1)
                    index += 1
                    if identifier in "Bb":
                        w_identifier = {'B': 'W', 'b': 'w'}[identifier]
                        for k in range(pattern_size):
                            if repeated[i][index + k] == w_identifier:
                                clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
                                clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
                                break
                    break
                index += 1
    return repeated


def process_warmup_without_increasing_peak_mem(schedules, m):
    """Rewrite the part of every stage that precedes its first 'b'.

    Illustration, one row per stage (NOTE: the column alignment of this
    diagram was not preserved in the source it was recovered from; only the
    per-row content and the one-slot-per-stage start offset are reliable):

        FFFFFFFFFF fBWfBWfBWfBWfBW b
         FFFFFFFFF f fBWfBWfBWfBWFBWb
          FFFFFFFF f f fBWfBWfBWFBW b
           FFFFFFF f f f fBWfBWFBW Bb
            FFFFFF f f f f fBWFBWFBWb
             FFFFFfFf f f f BWFBW b
              FFFfFfFfFf f BW Bb
               FfFfFfFfFfF BWb

    We reorganize the warmup phase in the following way
    (i -> pipeline stage from 0):
        1. Before the first B, we set #f = min(i+1, peak_mem//2),
           #F = peak_mem - #f
        2. Before the first b, #f = peak_mem//2
        3. The offset between the first B is 1
        4. Before the first b, we use the pattern of (BWf)*j + (BWF)*k,
           where j = max(0, peak_mem//2 - (i+1)), k = max(0, #W - j - 1)

    Args:
        schedules: list of stages (mutable lists of characters); modified in
            place.
        m: upper bound applied to the per-stage 'F' and 'f' warmup counts.

    Returns:
        The same ``schedules`` object.

    Raises:
        AssertionError: if a stage does not provide enough operations for the
            layout above, or contains an unexpected character before its
            first 'b'.
    """
    # process warmup phase (before the first b)
    p = len(schedules)
    peak_mem = get_peak_mem(schedules)
    peak_mem = min(peak_mem, 2 * p)
    # Target number of 'F' (cnt_f) and 'f' (cnt_ff) before the first 'B'.
    cnt_f, cnt_ff = [], []
    for i in range(p):
        cc_ff = min(i + 1, peak_mem // 2)
        cc_ff = min(cc_ff, m)
        cc_f = min(peak_mem - cc_ff, m)
        cnt_f.append(cc_f)
        cnt_ff.append(cc_ff)
    # Distance between the first 'B' and the first following 'b' on the last stage.
    distance_b2bb = 0
    for j in range(len(schedules[p - 1])):
        if schedules[p - 1][j] == 'B':
            for k in range(j, len(schedules[p - 1])):
                if schedules[p - 1][k] == 'b':
                    distance_b2bb = k - j
                    break
            break
    for i in range(p):
        # Count (and blank) everything in front of the first 'b'.
        c_f, c_ff, c_b, c_w = 0, 0, 0, 0
        for j in range(len(schedules[i])):
            char = schedules[i][j]
            if char == 'F':
                c_f += 1
            elif char == 'f':
                c_ff += 1
            elif char == 'B':
                c_b += 1
            elif char == 'W':
                c_w += 1
            elif char == 'b':
                # From the first 'b' on, pull additional operations out of the
                # steady phase where the rules below allow it, then stop.
                bj = j
                while j < len(schedules[i]):
                    char = schedules[i][j]
                    if char == 'f' and c_ff < cnt_ff[p - 1]:
                        schedules[i][j] = ' '
                        c_ff += 1
                    if char == 'B' and c_b < c_ff:
                        if c_b < (2 * (p - i) + distance_b2bb) // 3 or c_b < cnt_ff[p - 1] - cnt_ff[i]:
                            # there is empty space, or the number of B is not enough to cover extra f
                            schedules[i][j] = ' '
                            c_b += 1
                    if char == 'W' and c_w < c_b:
                        if c_w < (2 * (p - i) + distance_b2bb - 1) // 3 or c_w < cnt_ff[p - 1] - cnt_ff[i]:
                            # there is empty space, or the number of W is not enough to cover extra f
                            schedules[i][j] = ' '
                            c_w += 1
                    j += 1
                j = bj
                while j < len(schedules[i]):
                    if schedules[i][j] == 'F':
                        if c_f < c_ff or c_f < cnt_f[i] or c_f - cnt_f[i] + c_ff - cnt_ff[i] < c_w - 1:
                            # put enough F, or there are some unused BW
                            schedules[i][j] = ' '
                            c_f += 1
                    j += 1
                break
            else:
                assert char == ' '
            schedules[i][j] = ' '
        assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
        assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
        # Write the collected operations back, starting at slot i.
        j = i
        u_f, u_ff, u_b, u_w = 0, 0, 0, 0
        for _ in range(2 * (p - 1 - i)):
            if u_f < cnt_f[i] and u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
        for _ in range(i + 1):
            if u_f < cnt_f[i] and u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
            if u_ff < cnt_ff[i] and u_ff < c_ff:
                schedules[i][j] = 'f'
                u_ff += 1
            j += 1
        while u_f < c_f or u_ff < c_ff or u_b < c_b or u_w < c_w:
            if u_b < c_b:
                schedules[i][j] = 'B'
                u_b += 1
            j += 1
            if u_w < c_w:
                schedules[i][j] = 'W'
                u_w += 1
            j += 1
            if u_ff < c_ff:
                assert u_ff < u_f
                schedules[i][j] = 'f'
                u_ff += 1
            elif u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
    return schedules


def squeeze_without_change_order(schedules, m):
    """Move every operation to the earliest slot that keeps dependencies.

    The per-stage order is preserved. 'F'/'B' must come after the same
    operation of the same minibatch on the previous stage, 'f'/'b' after the
    one on the next stage; 'W'/'w' and the operations of the boundary stages
    only follow their own stage's previous operation.

    Args:
        schedules: list of equally long stages (lists of characters); not
            modified.
        m: maximum number of occurrences of any letter per stage.

    Returns:
        A new list of stages with the same lengths as the input.

    Raises:
        AssertionError: on unequal stage lengths, more than ``m`` occurrences
            of a letter, or an operation whose prerequisite was not placed yet.
        KeyError: if a stage contains a non-blank character outside "FfBbWw".
    """
    p = len(schedules)
    squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
    max_len = 0
    for seq in squeezed:
        assert max_len == 0 or max_len == len(seq)
        max_len = max(max_len, len(seq))

    # Occurrences placed so far, per stage and letter.
    identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
    # Slot of each placed operation, indexed by ``minibatch * p + stage``.
    identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
    # Next free slot of each stage.
    stage_index = [0 for _ in range(p)]
    for j in range(max_len):
        # Upper-case letters are handled walking the stages forward, lower-case
        # letters walking them backward, so prerequisites in the same column
        # are placed first.
        for _dir in range(2):
            if _dir == 0:
                _iter = range(p)
            else:
                _iter = range(p - 1, -1, -1)
            for i in _iter:
                identifier = schedules[i][j]
                if identifier == ' ':
                    continue
                if _dir == 0 and identifier in "fbw":
                    continue
                if _dir == 1 and identifier in "FBW":
                    continue
                _cnt = identifier_cnt[i][identifier]
                assert _cnt < m, "{} - {}, {}".format(i, identifier, _cnt)
                if identifier in "Ww" or (i == 0 and identifier in "FB") or (i == p - 1 and identifier in "fb"):
                    if i == 0 and identifier == 'B':
                        assert identifier_index[_cnt * p + i]['f'] >= 0
                    if i == p - 1 and identifier == 'f':
                        assert identifier_index[_cnt * p + i]['F'] >= 0
                    if i == p - 1 and identifier == 'b':
                        assert identifier_index[_cnt * p + i]['B'] >= 0
                    index = stage_index[i]
                elif identifier in "FB":
                    assert identifier_index[_cnt * p + i - 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
                    index = max(identifier_index[_cnt * p + i - 1][identifier] + 1, stage_index[i])
                elif identifier in "fb":
                    assert identifier_index[_cnt * p + i + 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
                    index = max(identifier_index[_cnt * p + i + 1][identifier] + 1, stage_index[i])
                else:
                    # Unreachable for letters in "FfBbWw"; kept as a guard.
                    raise RuntimeError("unexpected operation {!r} on stage {}".format(identifier, i))
                squeezed[i][index] = identifier
                identifier_cnt[i][identifier] += 1
                identifier_index[_cnt * p + i][identifier] = index
                stage_index[i] = index + 1
    return squeezed


def process_cooldown(schedules, m):
    """Rewrite the tail of every stage (after its last 'f').

    Illustration, one row per stage (NOTE: the column alignment of this
    diagram was not preserved in the source it was recovered from):

        fBW bwbwbwbw
        fBWBW bwbwbwbw
        fBWBWBW bwbwbwbw
        fBWBWBWBW bwbwbwbw
        f BWBWBWBbWbwbwbww
        f BWBWBbBbWbWbwwww
        f BWBbBbBbWbWWwwww
        f BbBbBbBbWWWWwwww

    We reorganize the cooldown phase in the following way
    (i -> pipeline stage from 0):
        1. After the last f, we set #b = (peak_mem+1)//2,
           and #B = min(i+1, peak_mem - #b)
        2. After the last f, we make all the dependencies as tight as possible

    Args:
        schedules: list of stages (mutable lists of characters); modified in
            place during step 1.
        m: number of occurrences of each letter per stage (also caps the
            number of reorganized 'B'/'b').

    Returns:
        A new list of stages (the result of ``squeeze_without_change_order``).

    Raises:
        AssertionError: if the peak memory exceeds ``2 * p`` or a target slot
            for a reorganized 'B'/'b' is already occupied.
    """
    p = len(schedules)

    peak_mem = get_peak_mem(schedules)
    assert peak_mem <= 2 * p
    max_bb = (peak_mem + 1) // 2
    max_bb = min(max_bb, m)
    max_b = min(peak_mem - max_bb, m)

    # 1: reorganize B/b and remove W/w in cooldown phase
    starting_index = -1
    for i in range(p):
        c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
        last_ff_index = -1
        # collect B/b which can be reorganized
        for j in range(len(schedules[i]) - 1, -1, -1):
            char = schedules[i][j]
            if char == 'f' and last_ff_index == -1:
                last_ff_index = j
            if char == 'B' and c_b < i + 1 and c_b < max_b:
                schedules[i][j] = ' '
                c_b += 1
            if char == 'b' and c_bb < max_bb:
                schedules[i][j] = ' '
                c_bb += 1
        # clear W in the tail (#W + #w = peak_mem)
        for j in range(len(schedules[i]) - 1, -1, -1):
            char = schedules[i][j]
            if char == 'W' and c_w + c_ww < peak_mem:
                schedules[i][j] = ' '
                c_w += 1
            if char == 'w' and c_w + c_ww < peak_mem:
                schedules[i][j] = ' '
                c_ww += 1
        # All stages are laid out relative to the last 'f' of stage 0.
        if i == 0:
            starting_index = last_ff_index
        # reorganize B/b in the tail
        for k in range(c_bb):
            index = starting_index - i + 2 * p - 2 * k
            assert schedules[i][index] == ' ', "{} {} {}".format(schedules[i][index], k, i)
            schedules[i][index] = 'b'
        for k in range(c_b):
            index = starting_index + 1 + i - 2 * k
            assert schedules[i][index] == ' ', schedules[i][index]
            schedules[i][index] = 'B'

    # 2: squeeze cooldown phase without change order
    schedules = squeeze_without_change_order(schedules, m)

    # 3: add W back in cooldown phase
    for i in range(p):
        c_w, c_ww = 0, 0
        # Drop the last two remaining W/w and remember where the third-last sits.
        last_w_index = -2
        for j in range(len(schedules[i]) - 1, -1, -1):
            if schedules[i][j] in "Ww":
                if last_w_index < 0:
                    schedules[i][j] = ' '
                    last_w_index += 1
                else:
                    last_w_index = j
                    break
        # c_w / c_ww track B/b operations that still lack their W/w; fill free
        # slots after last_w_index with the missing ones, 'W' first.
        for j in range(len(schedules[i])):
            char = schedules[i][j]
            if char == 'B':
                c_w += 1
            elif char == 'b':
                c_ww += 1
            elif char == 'W':
                c_w -= 1
            elif char == 'w':
                c_ww -= 1
            if char == ' ' and j > last_w_index:
                if c_w > 0:
                    schedules[i][j] = 'W'
                    c_w -= 1
                elif c_ww > 0:
                    schedules[i][j] = 'w'
                    c_ww -= 1

    schedules = squeeze_without_change_order(schedules, m)
    return schedules


def schedule_by_pattern(p, m, patterns, max_mem):
    """Build a full schedule from per-stage repeating patterns.

    Args:
        p: number of stages.
        m: number of occurrences of each letter to keep per stage.
        patterns: ``p`` slot assignments (see ``get_pattern_str``).
        max_mem: largest acceptable peak memory of the initial schedule.

    Returns:
        ``(schedules, peak_mem, stage_bubbles)``. ``schedules`` is ``None``
        when the initial peak memory exceeds ``max_mem`` or when a later step
        raises the peak above the initial one; in that case the second item is
        the initial peak and the third a placeholder list of length ``p``.
    """
    schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
    schedules = clear_invalid_index(schedules, max(m, 2 * p))
    init_peak_mem = get_peak_mem(schedules)
    if init_peak_mem > max_mem:
        return None, init_peak_mem, [6 * max(m, 2 * p)] * p
    schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
    # The schedule was built for max(m, 2 * p) repetitions; keep only the
    # first m occurrences of every letter on each stage.
    for sid in range(len(schedules)):
        cnt = {_id: 0 for _id in "FfBbWw"}
        for i in range(len(schedules[sid])):
            if schedules[sid][i] == ' ':
                continue
            if cnt[schedules[sid][i]] >= m:
                schedules[sid][i] = ' '
            else:
                cnt[schedules[sid][i]] += 1
    peak_mem = get_peak_mem(schedules)
    if peak_mem > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p

    schedules = squeeze_without_change_order(schedules, m)
    schedules = process_cooldown(schedules, m)
    peak_mem = get_peak_mem(schedules)
    if peak_mem > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p

    stage_bubbles = calc_bubble(schedules)
    return schedules, peak_mem, stage_bubbles


def fill_w_in_pattern(pattern):
    """Assign slots to 'W' and 'w' in a slot assignment, in place.

    ``pattern`` holds the slots of "FfBbWw" in that order. 'W' gets the first
    free slot at or after the slot of 'B' (cyclically), then 'w' the first free
    slot at or after the slot of 'b'. Returns the same list.

    Raises:
        AssertionError: if 'B' or 'b' has no slot yet.
    """
    b, bb, w, ww = 2, 3, 4, 5
    vis = [False] * pattern_size
    for v in pattern:
        if v >= 0:
            vis[v] = True
    assert pattern[b] >= 0 and pattern[bb] >= 0
    for v, vw in [(b, w), (bb, ww)]:
        for j in range(pattern_size):
            pos = (pattern[v] + j) % pattern_size
            if not vis[pos]:
                pattern[vw] = pos
                vis[pos] = True
                break
    return pattern


def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
    """Derive the slot assignments of all ``p`` stages from the first one.

    Stage ``i + 1`` shifts the F/f/B/b slots of stage ``i`` by ``offset_0``
    while ``i < len_0`` and by ``offset_1`` afterwards (modulo
    ``pattern_size``); W/w are then placed by ``fill_w_in_pattern``.

    Returns:
        A list of ``p`` slot assignments whose first element is ``pattern_0``
        itself, or ``None`` if a shift makes two letters share a slot.
    """
    whole_pattern = [pattern_0]
    for i in range(p - 1):
        last_pattern = whole_pattern[i]
        new_pattern = [-1] * pattern_size
        vis = [False] * pattern_size
        if i < len_0:
            offset = offset_0
        else:
            offset = offset_1
        for v, v_o in enumerate(offset):
            pos = (last_pattern[v] + v_o + pattern_size) % pattern_size
            assert 0 <= pos < pattern_size
            if vis[pos]:
                return None
            vis[pos] = True
            new_pattern[v] = pos
        new_pattern = fill_w_in_pattern(new_pattern)
        whole_pattern.append(new_pattern)
    return whole_pattern


def schedule(p, m, cost, max_mem):
    """Search the pattern space and return the best timed schedule.

    Every first-stage pattern (with 'F' in slot 0) is combined with every pair
    of stage-to-stage offsets and every switch point ``len_0``; each feasible
    candidate is scored with ``evaluate_schedule`` and the lowest score wins.

    Args:
        p: number of stages.
        m: number of occurrences of each letter per stage.
        cost: sequence unpacked as ``(f, b, w, c)`` for ``evaluate_schedule``
            and ``transform_schedule``.
        max_mem: upper bound on the peak memory (see ``get_peak_mem``).

    Returns:
        The winning schedule converted by ``transform_schedule``.

    Raises:
        ValueError: if no candidate satisfies the memory bound (this includes
            ``p < 2``, for which no candidate is generated).
    """
    available_patterns = []
    for ff_i in range(1, pattern_size):
        for b_i in range(1, pattern_size):
            for bb_i in range(1, pattern_size):
                if ff_i == b_i or ff_i == bb_i or b_i == bb_i:
                    continue
                pattern = [0, ff_i, b_i, bb_i, -1, -1]
                pattern = fill_w_in_pattern(pattern)
                available_patterns.append(pattern)

    # Per-stage slot shifts for F, f, B, b.
    available_offsets = [
        [1, -1, 1, -1],
        [2, -1, 2, -1],
        [3, -1, 3, -1],
        [4, -1, 4, -1],
        [5, -1, 5, -1],
    ]

    best_schedule = None
    best_bubble = None
    for pattern_0 in available_patterns:
        for i_0, offset_0 in enumerate(available_offsets):
            for offset_1 in available_offsets[:i_0 + 1]:
                for len_0 in range(1, p):
                    whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
                    if whole_pattern is None:
                        continue
                    s, peak_mem, _ = schedule_by_pattern(p, m, whole_pattern, min(2 * p, max_mem))
                    if peak_mem > 2 * p or peak_mem > max_mem:
                        break
                    if s is None:
                        continue
                    max_bubble = evaluate_schedule(s, *cost)
                    if best_schedule is None or max_bubble < best_bubble:
                        best_schedule, best_bubble = s, max_bubble

    if best_schedule is None:
        raise ValueError(
            "no feasible schedule for p={}, m={}, max_mem={}".format(p, m, max_mem)
        )
    return transform_schedule(best_schedule, *cost)