# Page header captured together with this file (not part of the program):
#   Nyamdavaa Amar
#   Pipeline Parallelism with Controllable Memory
#   3d4d40d
#   raw
#   history blame
#   No virus
#   22 kB
# Number of slots in one repeating per-stage pattern; get_pattern_str() maps
# the six operation letters "FfBbWw" onto these slots.
pattern_size = 6
from collections import Counter
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One timed operation of a per-stage pipeline schedule.

    Instances are immutable and hashable (``frozen=True`` with ``eq=True``).
    The field comments describe how ``transform_schedule`` in this file
    populates them.
    """
    # Operation letter; transform_schedule stores the upper-cased schedule
    # character ('F', 'B' or 'W').
    type: str
    # Chunk marker; transform_schedule sets it truthy when the schedule
    # character is 'f', 'B' or 'W' and falsy otherwise.
    chunk: int
    # Index of the stage (row of the schedule) the operation belongs to.
    stage: int
    # Zero-based occurrence number of this schedule character within its stage.
    minibatch: int
    # Completion time minus the cost of the operation.
    start_time: int
    # Time at which the operation finishes.
    completion_time: int
def transform_schedule(schedule, f, b, w, c):
    """Convert a character schedule into lists of timed ``ScheduledNode``.

    Every operation finishes at ``start + cost`` where ``start`` is the
    latest of:

    * the finish time of the previous operation in the same stage, and
    * for 'F'/'B', the finish time of the same operation (same occurrence
      number) on stage ``sid - 1`` plus ``c``; for 'f'/'b', the same on
      stage ``sid + 1`` plus ``c``.

    Args:
        schedule: One sequence per pipeline stage. Each element is a
            single-character operation ('F', 'f', 'B', 'b', 'W' or 'w');
            whitespace-only elements are idle slots and are skipped.
        f: Cost of an 'F' or 'f' operation.
        b: Cost of a 'B' or 'b' operation.
        w: Cost of a 'W' or 'w' operation.
        c: Delay added to a dependency on a neighbouring stage.

    Returns:
        A list with one list per stage, containing one ``ScheduledNode`` per
        operation in schedule order. ``type`` is the upper-cased character,
        ``chunk`` is 1 for 'f', 'B' and 'W' and 0 otherwise, ``minibatch`` is
        the occurrence number of the character within its stage. An empty
        schedule yields an empty list.

    Raises:
        KeyError: If an element is not one of the six operation letters.
        ValueError: If the dependencies between operations form a cycle.
    """
    num_stages = len(schedule)
    cost = {
        'F': f,
        'B': b,
        'W': w,
        'f': f,
        'b': b,
        'w': w,
    }

    # Per stage: the ordered (op, occurrence) pairs, plus a link from every
    # operation to the one executed immediately before it on the same stage.
    stage_order = []
    local_prev = {}
    for sid, stage in enumerate(schedule):
        counter = Counter()
        order = []
        for op in stage:
            if not op.strip():
                continue
            mb = counter[op]
            if order:
                local_prev[(sid, op, mb)] = order[-1]
            order.append((op, mb))
            counter[op] += 1
        stage_order.append(order)

    finish = {}

    def get_time(sid, op, mb):
        """Return the finish time of (sid, op, mb), memoised in ``finish``.

        Uses an explicit stack instead of recursion: the same-stage
        predecessor chain is as long as the stage itself, which exceeds
        Python's default recursion limit for long schedules.
        """
        target = (sid, op, mb)
        stack = [target]
        expanded = set()
        while stack:
            node = stack[-1]
            if node in finish:
                stack.pop()
                continue
            n_sid, n_op, n_mb = node
            local = local_prev.get(node)
            local_dep = None if local is None else (n_sid,) + local
            if n_op in ('F', 'B') and n_sid > 0:
                remote_dep = (n_sid - 1, n_op, n_mb)
            elif n_op in ('f', 'b') and n_sid + 1 < num_stages:
                remote_dep = (n_sid + 1, n_op, n_mb)
            else:
                remote_dep = None
            waiting = [
                dep for dep in (local_dep, remote_dep)
                if dep is not None and dep not in finish
            ]
            if waiting:
                # A node is only re-expanded if one of its own dependencies
                # depends on it again, i.e. the schedule is cyclic.
                if node in expanded:
                    raise ValueError(
                        "cyclic dependency in schedule at {}".format(node)
                    )
                expanded.add(node)
                stack.extend(waiting)
                continue
            start = 0 if local_dep is None else finish[local_dep]
            if remote_dep is not None:
                start = max(start, finish[remote_dep] + c)
            finish[node] = start + cost[n_op]
            stack.pop()
        return finish[target]

    result = []
    for sid, order in enumerate(stage_order):
        nodes = []
        for op, mb in order:
            done = get_time(sid, op, mb)
            nodes.append(ScheduledNode(
                type=op.upper(),
                # The field is declared int, so store 0/1 rather than a bool.
                chunk=int(op in ('f', 'B', 'W')),
                stage=sid,
                minibatch=mb,
                start_time=done - cost[op],
                completion_time=done,
            ))
        result.append(nodes)
    return result
def evaluate_schedule(schedule, f, b, w, c):
    """Return the longest per-stage span of a character schedule.

    Every operation finishes at ``start + cost`` where ``start`` is the
    latest of:

    * the finish time of the previous operation in the same stage, and
    * for 'F'/'B', the finish time of the same operation (same occurrence
      number) on stage ``sid - 1`` plus ``c``; for 'f'/'b', the same on
      stage ``sid + 1`` plus ``c``.

    Args:
        schedule: One sequence per pipeline stage. Each element is a
            single-character operation ('F', 'f', 'B', 'b', 'W' or 'w');
            whitespace-only elements are idle slots and are skipped.
        f: Cost of an 'F' or 'f' operation.
        b: Cost of a 'B' or 'b' operation.
        w: Cost of a 'W' or 'w' operation.
        c: Delay added to a dependency on a neighbouring stage.

    Returns:
        The maximum over all stages of
        ``finish(last 'W' or 'w') - finish(first 'F') + f``, where "last"
        means occurrence ``nmb - 1`` and ``nmb`` is the largest per-letter
        count on the final stage. Returns 0 for an empty schedule.

    Raises:
        ValueError: If the final stage contains no operations, or if the
            dependencies between operations form a cycle.
    """
    num_stages = len(schedule)
    if num_stages == 0:
        return 0

    cost = {
        'F': f,
        'B': b,
        'W': w,
        'f': f,
        'b': b,
        'w': w,
    }

    # Link every operation to the one executed right before it on its stage.
    local_prev = {}
    counter = Counter()
    for sid, stage in enumerate(schedule):
        counter = Counter()
        previous = None
        for op in stage:
            if not op.strip():
                continue
            mb = counter[op]
            if previous is not None:
                local_prev[(sid, op, mb)] = previous
            previous = (op, mb)
            counter[op] += 1
    # As before, the microbatch count is taken from the final stage.
    nmb = max(counter.values())

    finish = {}

    def get_time(sid, op, mb):
        """Return the finish time of (sid, op, mb), memoised in ``finish``.

        Uses an explicit stack instead of recursion: the same-stage
        predecessor chain is as long as the stage itself, which exceeds
        Python's default recursion limit for long schedules.
        """
        target = (sid, op, mb)
        stack = [target]
        expanded = set()
        while stack:
            node = stack[-1]
            if node in finish:
                stack.pop()
                continue
            n_sid, n_op, n_mb = node
            local = local_prev.get(node)
            local_dep = None if local is None else (n_sid,) + local
            if n_op in ('F', 'B') and n_sid > 0:
                remote_dep = (n_sid - 1, n_op, n_mb)
            elif n_op in ('f', 'b') and n_sid + 1 < num_stages:
                remote_dep = (n_sid + 1, n_op, n_mb)
            else:
                remote_dep = None
            waiting = [
                dep for dep in (local_dep, remote_dep)
                if dep is not None and dep not in finish
            ]
            if waiting:
                # A node is only re-expanded if one of its own dependencies
                # depends on it again, i.e. the schedule is cyclic.
                if node in expanded:
                    raise ValueError(
                        "cyclic dependency in schedule at {}".format(node)
                    )
                expanded.add(node)
                stack.extend(waiting)
                continue
            start = 0 if local_dep is None else finish[local_dep]
            if remote_dep is not None:
                start = max(start, finish[remote_dep] + c)
            finish[node] = start + cost[n_op]
            stack.pop()
        return finish[target]

    r = 0
    for sid in range(num_stages):
        first_forward = get_time(sid, 'F', 0)
        r = max(get_time(sid, 'W', nmb - 1) - first_forward + f, r)
        r = max(get_time(sid, 'w', nmb - 1) - first_forward + f, r)
    return r
def get_pattern_str(pos):
    """Render slot positions as a pattern string of ``pattern_size`` characters.

    ``pos[i]`` is the slot assigned to the i-th letter of "FfBbWw"; a negative
    value means the letter is absent. Slots that receive no letter stay as
    spaces.
    """
    letters = "FfBbWw"
    slots = [" " for _ in range(pattern_size)]
    for letter_index, slot in enumerate(pos):
        if slot >= 0:
            slots[slot] = letters[letter_index]
    return "".join(slots)
def get_peak_mem(schedules, return_all=False):
    """Compute the peak of the running F/f-minus-W/w count for each stage.

    Walking a stage left to right, every 'F' or 'f' adds one to the running
    count and every 'W' or 'w' removes one; the peak of a stage is the
    highest value reached (never below 0).

    Args:
        schedules: Iterable of per-stage operation sequences.
        return_all: When true, return the list of per-stage peaks instead of
            their maximum.

    Returns:
        The largest per-stage peak (0 when there are no stages), or the list
        of per-stage peaks when ``return_all`` is true.
    """
    def stage_peak(ops):
        live = 0
        highest = 0
        for op in ops:
            if op in "Ff":
                live += 1
            elif op in "Ww":
                live -= 1
            if live > highest:
                highest = live
        return highest

    peaks = [stage_peak(ops) for ops in schedules]
    if return_all:
        return peaks
    return max(peaks, default=0)
# When False, print_schedules() returns without printing anything.
debug = False
def print_schedules(schedules):
    """Print each stage's schedule as one line, but only when ``debug`` is set."""
    if debug:
        for seq in schedules:
            print("".join(seq))
def calc_bubble(schedules):
    """Count the idle slots of every stage.

    For stage ``i`` the result is ``(index of last non-space slot + 1)``
    minus the number of non-space slots minus ``i``.

    Returns:
        A list with one such value per stage.
    """
    bubbles = []
    for stage, ops in enumerate(schedules):
        busy_columns = [col for col, op in enumerate(ops) if op != ' ']
        span = busy_columns[-1] + 1 if busy_columns else 0
        bubbles.append(span - len(busy_columns) - stage)
    return bubbles
def init_repeated_schedule(p, m, patterns):
    """Build the initial schedule by tiling each stage's pattern.

    Stage ``i`` gets the string of ``patterns[i]`` repeated
    ``4 * p + m + 1`` times, split into a list of single-character slots.

    Returns:
        A list of ``p`` such slot lists.
    """
    repeats = 4 * p + m + 1
    return [
        list(get_pattern_str(patterns[stage]) * repeats)
        for stage in range(p)
    ]
def clear_invalid(repeated, stage, pos, offset=-1):
    """Blank one slot per pattern period, walking away from ``pos``.

    Starting at ``pos`` and stepping by ``offset * pattern_size``, every
    visited slot of ``repeated[stage]`` is set to a space until the index
    leaves the row. ``repeated`` is modified in place and returned.
    """
    stride = offset * pattern_size
    cursor = pos
    while 0 <= cursor < len(repeated[stage]):
        repeated[stage][cursor] = ' '
        cursor += stride
    return repeated
def clear_invalid_index(repeated, m):
    """Trim the tiled schedule so each operation keeps ``m`` periodic copies.

    A single cursor sweeps the columns, starting at ``pattern_size``. The
    sweep handles 'F' over stages 0..p-1, then 'f' over stages p-1..0, then
    'B' over stages 0..p-1, then 'b' over stages p-1..0. For each stage it
    looks at up to ``pattern_size`` columns for the letter; on a hit, the
    periodic copies before the hit and from ``m`` periods after it are
    blanked via ``clear_invalid``. For 'B'/'b' the first matching 'W'/'w'
    within the next ``pattern_size`` columns is trimmed the same way.

    ``repeated`` is modified in place and returned.
    """
    num_stages = len(repeated)
    ascending = range(num_stages)
    descending = range(num_stages - 1, -1, -1)
    # (letter to anchor on, paired weight letter or None, stage visiting order)
    passes = (
        ('F', None, ascending),
        ('f', None, descending),
        ('B', 'W', ascending),
        ('b', 'w', descending),
    )

    cursor = pattern_size
    for op, weight_op, stage_ids in passes:
        for stage in stage_ids:
            row = repeated[stage]
            for _ in range(pattern_size):
                matched = row[cursor] == op
                # The cursor advances past the inspected column either way.
                cursor += 1
                if not matched:
                    continue
                hit = cursor - 1
                clear_invalid(repeated, stage, hit - pattern_size, offset=-1)
                clear_invalid(repeated, stage, hit + pattern_size * m, offset=1)
                if weight_op is not None:
                    for shift in range(pattern_size):
                        if row[cursor + shift] == weight_op:
                            w_col = cursor + shift
                            clear_invalid(repeated, stage, w_col - pattern_size, offset=-1)
                            clear_invalid(repeated, stage, w_col + pattern_size * m, offset=1)
                            break
                break
    return repeated
def process_warmup_without_increasing_peak_mem(schedules, m):
    """Re-lay out the beginning of every stage's schedule, in place.

    Sketch of the intended result, one row per stage (the column alignment of
    this sketch is not preserved in this copy of the file):

        FFFFFFFFFF fBWfBWfBWfBWfBW b
        FFFFFFFFF f fBWfBWfBWfBWFBWb
        FFFFFFFF f f fBWfBWfBWFBW b
        FFFFFFF f f f fBWfBWFBW Bb
        FFFFFF f f f f fBWFBWFBWb
        FFFFFfFf f f f BWFBW b
        FFFfFfFfFf f BW Bb
        FfFfFfFfFfF BWb

    We reorganize the warmup phase in the following way (i -> pipeline stage from 0):
        1. Before the first B, we set #f = min(i+1, peak_mem//2), #F = peak_mem - #f
        2. Before the first b, #f = peak_mem//2
        3. The offset between the first B is 1
        4. Before the first b, we use the pattern of (BWf)*j + (BWF)*k,
           where j = max(0, peak_mem//2 - (i+1)), k = max(0, #W - j - 1)

    Args:
        schedules: One list of single-character slots per stage; modified in
            place.
        m: Upper bound applied to the per-stage 'F' and 'f' warm-up targets
            (presumably the number of microbatches; confirm with the caller).

    Returns:
        The same ``schedules`` object.
    """
    # process warmup phase (before the first b)
    p = len(schedules)
    peak_mem = get_peak_mem(schedules)
    peak_mem = min(peak_mem, 2 * p)
    # Per-stage targets for the number of 'F' (cnt_f) and 'f' (cnt_ff) placed
    # at the start of the row, both capped by m.
    cnt_f, cnt_ff = [], []
    for i in range(p):
        cc_ff = min(i + 1, peak_mem // 2)
        cc_ff = min(cc_ff, m)
        cc_f = min(peak_mem - cc_ff, m)
        cnt_f.append(cc_f)
        cnt_ff.append(cc_ff)
    # Distance, on the last stage, from its first 'B' to the first 'b' at or
    # after it (stays 0 when either is missing).
    distance_b2bb = 0
    for j in range(len(schedules[p - 1])):
        if schedules[p - 1][j] == 'B':
            for k in range(j, len(schedules[p - 1])):
                if schedules[p - 1][k] == 'b':
                    distance_b2bb = k - j
                    break
            break
    for i in range(p):
        # Phase 1: pull operations out of the row and count them.
        c_f, c_ff, c_b, c_w = 0, 0, 0, 0
        for j in range(len(schedules[i])):
            char = schedules[i][j]
            if char == 'F':
                c_f += 1
            elif char == 'f':
                c_ff += 1
            elif char == 'B':
                c_b += 1
            elif char == 'W':
                c_w += 1
            elif char == 'b':
                # First 'b' of this stage: everything before it has been
                # counted and blanked. Remember its column.
                bj = j
                # From here on, additionally pull out 'f', 'B' and 'W' while
                # the conditions below hold.
                while j < len(schedules[i]):
                    char = schedules[i][j]
                    if char == 'f' and c_ff < cnt_ff[p - 1]:
                        schedules[i][j] = ' '
                        c_ff += 1
                    if char == 'B' and c_b < c_ff:
                        if c_b < (2 * (p - i) + distance_b2bb) // 3 or c_b < cnt_ff[p - 1] - cnt_ff[i]:
                            # there is empty space, or the number of B is not enough to cover extra f
                            schedules[i][j] = ' '
                            c_b += 1
                    if char == 'W' and c_w < c_b:
                        if c_w < (2 * (p - i) + distance_b2bb - 1) // 3 or c_w < cnt_ff[p - 1] - cnt_ff[i]:
                            # there is empty space, or the number of W is not enough to cover extra f
                            schedules[i][j] = ' '
                            c_w += 1
                    j += 1
                # Second pass from the first 'b': pull out further 'F'.
                j = bj
                while j < len(schedules[i]):
                    if schedules[i][j] == 'F':
                        if c_f < c_ff or c_f < cnt_f[i] or c_f - cnt_f[i] + c_ff - cnt_ff[i] < c_w - 1:
                            # put enough F, or there are some unused BW
                            schedules[i][j] = ' '
                            c_f += 1
                    j += 1
                # The 'b' itself is left in place; stop scanning this row.
                break
            else:
                # Only blanks are expected before the first 'b' besides F/f/B/W.
                assert char == ' '
            # Reached for every slot before the first 'b': blank it.
            schedules[i][j] = ' '
        assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
        assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
        # Phase 2: write the collected operations back, starting at column i.
        j = i
        u_f, u_ff, u_b, u_w = 0, 0, 0, 0
        # 2 * (p - 1 - i) columns: an 'F' each while the target allows,
        # otherwise the column is skipped.
        for _ in range(2 * (p - 1 - i)):
            if u_f < cnt_f[i] and u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
        # (i + 1) pairs of columns: one for 'F', one for 'f'.
        for _ in range(i + 1):
            if u_f < cnt_f[i] and u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
            if u_ff < cnt_ff[i] and u_ff < c_ff:
                schedules[i][j] = 'f'
                u_ff += 1
            j += 1
        # Remaining operations in groups of three columns: 'B', 'W', then an
        # 'f' if any is left, else an 'F'.
        while u_f < c_f or u_ff < c_ff or u_b < c_b or u_w < c_w:
            if u_b < c_b:
                schedules[i][j] = 'B'
                u_b += 1
            j += 1
            if u_w < c_w:
                schedules[i][j] = 'W'
                u_w += 1
            j += 1
            if u_ff < c_ff:
                assert u_ff < u_f
                schedules[i][j] = 'f'
                u_ff += 1
            elif u_f < c_f:
                schedules[i][j] = 'F'
                u_f += 1
            j += 1
    return schedules
def squeeze_without_change_order(schedules, m):
    """Compact every stage leftwards while keeping order and dependencies.

    Columns are processed left to right. Within a column, 'F', 'B' and 'W'
    are handled going from stage 0 upwards, then 'f', 'b' and 'w' going from
    the last stage downwards. Each operation is moved to the earliest column
    that is

    * not before the next free column of its own stage, and
    * for 'F'/'B' on stages > 0: after the same operation (same occurrence)
      on stage ``i - 1``; for 'f'/'b' on stages < p - 1: after the same
      operation on stage ``i + 1``.

    Args:
        schedules: Equal-length lists of single-character slots, one per
            stage. Not modified.
        m: Maximum number of occurrences of any letter within one stage.

    Returns:
        A new list of rows of the same shape holding the compacted schedule.
    """
    p = len(schedules)
    squeezed = [[' '] * len(row) for row in schedules]
    width = 0
    for row in squeezed:
        assert width == 0 or width == len(row)
        width = max(width, len(row))

    ops = "FfBbWw"
    # placed_count[stage][op]: occurrences of op already placed on that stage.
    placed_count = [dict.fromkeys(ops, 0) for _ in range(p)]
    # placed_at[occurrence * p + stage][op]: column chosen for that operation.
    placed_at = [dict.fromkeys(ops, -1) for _ in range(p * m)]
    next_free = [0] * p
    # (stage visiting order, letters skipped during that sweep)
    sweeps = (
        (range(p), "fbw"),
        (range(p - 1, -1, -1), "FBW"),
    )

    for col in range(width):
        for stage_ids, skipped in sweeps:
            for i in stage_ids:
                op = schedules[i][col]
                if op == ' ':
                    continue
                if op in skipped:
                    continue
                nth = placed_count[i][op]
                assert nth < m, "{} - {}, {}".format(i, op, nth)
                slot = nth * p + i
                if op in "Ww" or (i == 0 and op in "FB") or (i == p - 1 and op in "fb"):
                    # No neighbouring stage constrains this operation; only
                    # sanity-check what must already have been placed.
                    if i == 0 and op == 'B':
                        assert placed_at[slot]['f'] >= 0
                    if i == p - 1 and op == 'f':
                        assert placed_at[slot]['F'] >= 0
                    if i == p - 1 and op == 'b':
                        assert placed_at[slot]['B'] >= 0
                    index = next_free[i]
                elif op in "FB":
                    assert placed_at[slot - 1][op] >= 0, "{} {} {}".format(i,op,nth)
                    index = max(placed_at[slot - 1][op] + 1, next_free[i])
                elif op in "fb":
                    assert placed_at[slot + 1][op] >= 0, "{} {} {}".format(i,op,nth)
                    index = max(placed_at[slot + 1][op] + 1, next_free[i])
                else:
                    raise
                squeezed[i][index] = op
                placed_count[i][op] += 1
                placed_at[slot][op] = index
                next_free[i] = index + 1
    return squeezed
def process_cooldown(schedules, m):
    """Re-lay out the end of every stage's schedule.

    Sketch of the intended result, one row per stage (the column alignment of
    this sketch is not preserved in this copy of the file):

        fBW bwbwbwbw
        fBWBW bwbwbwbw
        fBWBWBW bwbwbwbw
        fBWBWBWBW bwbwbwbw
        f BWBWBWBbWbwbwbww
        f BWBWBbBbWbWbwwww
        f BWBbBbBbWbWWwwww
        f BbBbBbBbWWWWwwww

    We reorganize the cooldown phase in the following way (i -> pipeline stage from 0):
        1. After the last f, we set #b = (peak_mem+1)//2, and #B = min(i+1, peak_mem - #b)
        2. After the last f, we make all the dependencies as tight as possible

    Args:
        schedules: One list of single-character slots per stage. The rows are
            modified in place during step 1.
        m: Upper bound applied to the number of 'B'/'b' that are moved, and
            passed on to ``squeeze_without_change_order``.

    Returns:
        The schedule produced by the final ``squeeze_without_change_order``
        call (a new list of rows).
    """
    p = len(schedules)
    peak_mem = get_peak_mem(schedules)
    assert peak_mem <= 2 * p
    # Upper bounds on how many 'b' (max_bb) and 'B' (max_b) are moved per stage.
    max_bb = (peak_mem + 1) // 2
    max_bb = min(max_bb, m)
    max_b = min(peak_mem - max_bb, m)
    # 1: reorganize B/b and remove W/w in cooldown phase
    starting_index = -1
    for i in range(p):
        c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
        last_ff_index = -1
        # collect B/b which can be reorganized
        # (scanning from the end of the row; also records the last 'f')
        for j in range(len(schedules[i]) - 1, -1, -1):
            char = schedules[i][j]
            if char == 'f' and last_ff_index == -1:
                last_ff_index = j
            if char == 'B' and c_b < i + 1 and c_b < max_b:
                schedules[i][j] = ' '
                c_b += 1
            if char == 'b' and c_bb < max_bb:
                schedules[i][j] = ' '
                c_bb += 1
        # clear W in the tail (#W + #w = peak_mem)
        for j in range(len(schedules[i]) - 1, -1, -1):
            char = schedules[i][j]
            if char == 'W' and c_w + c_ww < peak_mem:
                schedules[i][j] = ' '
                c_w += 1
            if char == 'w' and c_w + c_ww < peak_mem:
                schedules[i][j] = ' '
                c_ww += 1
        # All stages are laid out relative to the last 'f' of stage 0.
        if i == 0:
            starting_index = last_ff_index
        # reorganize B/b in the tail
        # ('b' every second column going left from starting_index - i + 2p,
        #  'B' every second column going left from starting_index + 1 + i)
        for k in range(c_bb):
            index = starting_index - i + 2 * p - 2 * k
            assert schedules[i][index] == ' ', "{} {} {}".format(schedules[i][index], k, i)
            schedules[i][index] = 'b'
        for k in range(c_b):
            index = starting_index + 1 + i - 2 * k
            assert schedules[i][index] == ' ', schedules[i][index]
            schedules[i][index] = 'B'
    # 2: squeeze cooldown phase without change order
    schedules = squeeze_without_change_order(schedules, m)
    # 3: add W back in cooldown phase
    for i in range(p):
        c_w, c_ww = 0, 0
        # Starting at -2 makes the scan below blank the last two W/w of the
        # row, then stop at the next W/w to the left and record its column.
        last_w_index = -2
        for j in range(len(schedules[i]) - 1, -1, -1):
            if schedules[i][j] in "Ww":
                if last_w_index < 0:
                    schedules[i][j] = ' '
                    last_w_index += 1
                else:
                    last_w_index = j
                    break
        # Left to right: c_w / c_ww track how many 'B' / 'b' have no matching
        # 'W' / 'w' yet; blanks after last_w_index are filled with 'W' first,
        # then 'w'.
        for j in range(len(schedules[i])):
            char = schedules[i][j]
            if char == 'B':
                c_w += 1
            elif char == 'b':
                c_ww += 1
            elif char == 'W':
                c_w -= 1
            elif char == 'w':
                c_ww -= 1
            if char == ' ' and j > last_w_index:
                if c_w > 0:
                    schedules[i][j] = 'W'
                    c_w -= 1
                elif c_ww > 0:
                    schedules[i][j] = 'w'
                    c_ww -= 1
    schedules = squeeze_without_change_order(schedules, m)
    return schedules
def schedule_by_pattern(p, m, patterns):
    """Build a complete schedule from one slot pattern per stage.

    The patterns are tiled and trimmed using ``max(m, 2 * p)`` as the
    microbatch count, the warm-up is reorganised, every letter is cut back to
    at most ``m`` occurrences per stage, and the result is squeezed and its
    cool-down reorganised.

    Returns:
        ``(schedules, peak_mem, stage_bubbles)`` on success. When the peak
        memory of the tiled schedule exceeds ``2 * p``, or a later step
        raises the peak above the initial value, the result is
        ``(None, init_peak_mem, penalty)`` where ``penalty`` is a list of
        ``p`` equal values.
    """
    padded_m = max(m, 2 * p)
    tiled = init_repeated_schedule(p, padded_m, patterns)
    schedules = clear_invalid_index(tiled, padded_m)
    print_schedules(schedules)

    init_peak_mem = get_peak_mem(schedules)
    if init_peak_mem > 2 * p:
        return None, init_peak_mem, [6 * padded_m] * p

    schedules = process_warmup_without_increasing_peak_mem(schedules, padded_m)
    # Keep only the first m occurrences of every letter on each stage.
    for row in schedules:
        kept = dict.fromkeys("FfBbWw", 0)
        for col, op in enumerate(row):
            if op == ' ':
                continue
            if kept[op] < m:
                kept[op] += 1
            else:
                row[col] = ' '
    print_schedules(schedules)

    if get_peak_mem(schedules) > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p

    schedules = squeeze_without_change_order(schedules, m)
    print_schedules(schedules)
    schedules = process_cooldown(schedules, m)
    print_schedules(schedules)

    peak_mem = get_peak_mem(schedules)
    if peak_mem > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p
    return schedules, peak_mem, calc_bubble(schedules)
def fill_w_in_pattern(pattern):
    """Assign slots to 'W' and 'w' in a pattern, in place.

    ``pattern`` holds one slot per letter of "FfBbWw" (negative = unset).
    'W' receives the first unoccupied slot found when walking cyclically
    forward from the slot of 'B'; 'w' then receives the first unoccupied slot
    found walking from the slot of 'b'. If no slot is free the entry is left
    unchanged.

    Returns:
        The same ``pattern`` list.
    """
    b_at, bb_at, w_at, ww_at = 2, 3, 4, 5
    taken = [False] * pattern_size
    for slot in pattern:
        if slot >= 0:
            taken[slot] = True
    assert pattern[b_at] >= 0 and pattern[bb_at] >= 0
    for source, target in ((b_at, w_at), (bb_at, ww_at)):
        candidates = (
            (pattern[source] + step) % pattern_size
            for step in range(pattern_size)
        )
        free = next((slot for slot in candidates if not taken[slot]), None)
        if free is not None:
            pattern[target] = free
            taken[free] = True
    return pattern
def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
    """Derive the slot patterns of all ``p`` stages from the first one.

    Each stage's pattern is the previous stage's pattern with every entry
    covered by the offset list shifted by that offset, modulo
    ``pattern_size``. The first ``len_0`` derivations use ``offset_0``, the
    remaining ones ``offset_1``. The 'W'/'w' slots are then assigned by
    ``fill_w_in_pattern``.

    Returns:
        The list of ``p`` patterns (the first element is ``pattern_0``
        itself), or ``None`` if two shifted entries land on the same slot.
    """
    chain = [pattern_0]
    for stage in range(p - 1):
        shift = offset_0 if stage < len_0 else offset_1
        previous = chain[-1]
        derived = [-1] * pattern_size
        occupied = set()
        for slot, delta in enumerate(shift):
            pos = (previous[slot] + delta + pattern_size) % pattern_size
            assert 0 <= pos < pattern_size
            if pos in occupied:
                return None
            occupied.add(pos)
            derived[slot] = pos
        chain.append(fill_w_in_pattern(derived))
    return chain
def schedule(p, m, cost, max_mem):
    """Search pattern-based schedules and return the best one as timed nodes.

    Every stage-0 pattern is combined with every pair of per-stage offsets
    and every switch-over stage; each combination is expanded with
    ``get_whole_pattern`` and ``schedule_by_pattern``. Candidates whose peak
    memory exceeds ``2 * p`` or ``max_mem`` are discarded, and the candidate
    with the smallest ``evaluate_schedule`` value wins (the first one found
    on ties).

    Args:
        p: Number of pipeline stages.
        m: Microbatch count passed to ``schedule_by_pattern``.
        cost: Sequence unpacked positionally as ``(f, b, w, c)`` into
            ``evaluate_schedule`` and ``transform_schedule``.
        max_mem: Upper bound on the peak memory reported by
            ``schedule_by_pattern``.

    Returns:
        The result of ``transform_schedule`` for the best schedule.

    Raises:
        ValueError: If no candidate satisfies the constraints (for example
            when ``p`` is 1, because no switch-over stage can be tried).
    """
    # Stage-0 patterns: 'F' sits in slot 0; 'f', 'B' and 'b' take three
    # distinct other slots; 'W' and 'w' are filled in afterwards.
    available_patterns = []
    for ff_i in range(1, pattern_size):
        for b_i in range(1, pattern_size):
            for bb_i in range(1, pattern_size):
                if ff_i == b_i or ff_i == bb_i or b_i == bb_i:
                    continue
                pattern = fill_w_in_pattern([0, ff_i, b_i, bb_i, -1, -1])
                available_patterns.append(pattern)

    # Per-stage slot shifts for (F, f, B, b). The enumeration that used to
    # precede this list yielded exactly these five entries and was then
    # overwritten by it, so only the list is kept.
    available_offsets = [
        [1, -1, 1, -1],
        [2, -1, 2, -1],
        [3, -1, 3, -1],
        [4, -1, 4, -1],
        [5, -1, 5, -1]
    ]
    if debug:
        print(available_offsets, len(available_patterns))

    best_schedule = None
    best_bubble = None
    for pattern_0 in available_patterns:
        for i_0 in range(len(available_offsets)):
            for i_1 in range(i_0 + 1):
                for len_0 in range(1, p):
                    offset_0 = available_offsets[i_0]
                    offset_1 = available_offsets[i_1]
                    whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
                    if whole_pattern is None:
                        continue
                    s, peak_mem, _bubbles = schedule_by_pattern(p, m, whole_pattern)
                    if s is None:
                        continue
                    if peak_mem > 2 * p or peak_mem > max_mem:
                        continue
                    max_bubble = evaluate_schedule(s, *cost)
                    if best_schedule is None or max_bubble < best_bubble:
                        best_schedule, best_bubble = s, max_bubble

    if best_schedule is None:
        # Without this check the failure surfaces as a TypeError deep inside
        # transform_schedule.
        raise ValueError(
            "no feasible schedule found for p={}, m={}, max_mem={}".format(p, m, max_mem)
        )
    res = transform_schedule(best_schedule, *cost)
    return res