# Nyamdavaa Amar
# Pipeline Parallelism with Controllable Memory
# Commit: 3d4d40d
# NOTE(review): not referenced in this section; presumably the column period of
# the repeating two-microbatch F/B pattern (schedule() advances each pair by 6
# columns) -- confirm before relying on it.
pattern_size = 6
from collections import Counter
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One timed pass of a pipeline schedule (immutable and hashable)."""

    type: str  # pass kind, upper-cased: 'F', 'B' or 'W'
    stage: int  # pipeline stage index
    minibatch: int  # 0-based index among this stage's passes of the same kind
    start_time: int  # completion_time minus the cost of this pass kind
    completion_time: int  # time at which the pass finishes


def transform_schedule(schedule, f, b, w, c):
    """Turn a per-stage pass grid into timed ``ScheduledNode`` lists.

    Each non-blank element of a stage's sequence is one pass ('F', 'B' or
    'W'); blank (whitespace-only) elements are idle slots and are skipped.
    The k-th pass of a given kind on a stage gets minibatch index k.

    A pass starts as soon as all of its dependencies are complete:

    * the previous pass on the same stage (no extra delay);
    * for F on stage s > 0: the F with the same minibatch on stage s - 1,
      plus ``c``;
    * for B on stage s < last: the B with the same minibatch on stage s + 1,
      plus ``c``.

    A dependency that does not appear in ``schedule`` is still timed with
    the same rules (it simply has no same-stage predecessor).

    Args:
        schedule: One sequence of single-character pass labels per stage.
        f, b, w: Duration of an F, B and W pass respectively.
        c: Delay added to every cross-stage dependency.

    Returns:
        One list of ``ScheduledNode`` per stage, in that stage's order.

    Raises:
        ValueError: if the dependencies form a cycle (infeasible schedule).
        KeyError: if a pass label is not one of 'F', 'B', 'W'.
    """
    cost = {'F': f, 'B': b, 'W': w}
    num_stages = len(schedule)

    # Per-stage execution order as (kind, minibatch), plus each pass's
    # predecessor on the same stage.
    stage_order = []
    local_prev = {}
    for sid, stage in enumerate(schedule):
        seen = Counter()
        order = []
        for p in stage:
            if not p.strip():
                continue
            mb = seen[p]
            if order:
                local_prev[(sid, p, mb)] = order[-1]
            order.append((p, mb))
            seen[p] += 1
        stage_order.append(order)

    def local_dep(node):
        # Same-stage predecessor key, or None for the first pass of a stage.
        prev = local_prev.get(node)
        return None if prev is None else (node[0], *prev)

    def remote_dep(node):
        # Cross-stage dependency key (delayed by c), or None.
        sid, kind, mb = node
        if kind == 'F' and sid > 0:
            return (sid - 1, 'F', mb)
        if kind == 'B' and sid + 1 < num_stages:
            return (sid + 1, 'B', mb)
        return None

    time_map = {}

    def completion_time(root):
        # Iterative memoized DFS: recursion would overflow Python's stack on
        # long dependency chains (thousands of passes).
        stack = [root]
        expanded = set()
        while stack:
            node = stack[-1]
            if node in time_map:
                stack.pop()
                continue
            local, remote = local_dep(node), remote_dep(node)
            pending = [
                dep for dep in (local, remote)
                if dep is not None and dep not in time_map
            ]
            if pending:
                # Every stack entry above an expanded node is one of its
                # transitive dependencies, so meeting it again unfinished
                # means the dependencies loop back on themselves.
                if node in expanded:
                    raise ValueError(
                        f'schedule has a dependency cycle through {node}')
                expanded.add(node)
                stack.extend(pending)
                continue
            stack.pop()
            start = 0 if local is None else time_map[local]
            if remote is not None:
                start = max(start, time_map[remote] + c)
            time_map[node] = start + cost[node[1]]
        return time_map[root]

    result = []
    for sid, order in enumerate(stage_order):
        row = []
        for p, mb in order:
            end = completion_time((sid, p, mb))
            row.append(ScheduledNode(p.upper(), sid, mb, end - cost[p], end))
        result.append(row)
    return result
def process_warmup_without_increasing_peak_mem(schedules, m):
    """Slide passes left into idle slots without raising peak memory.

    ``schedules`` holds one mutable row per stage, all of the same length;
    each cell is 'F', 'B', 'W' or ' ' (idle).  Columns are scanned left to
    right and, within a column, stages in index order.  For every pass the
    latest column among its dependencies is found:

    * the previous pass of the same kind on the same stage;
    * W: the B with the same index on the same stage;
    * F on stage > 0: the F with the same index on the previous stage;
    * B: the B with the same index on the next stage, or on the last stage
      the F with the same index on that stage.

    The candidate slot is the first idle cell after that column and before
    the pass's current column.  B and W move there directly.  F moves only
    as far left as keeps every column it passes over strictly below the
    initial peak (max over all stages of the running count where F adds one
    and W subtracts one), so the peak is never exceeded.

    Args:
        schedules: The grid; modified in place.
        m: Sizes the bookkeeping table, which holds up to ``m + 1`` passes
            of each kind per stage (presumably the microbatch count).

    Returns:
        ``schedules`` itself, after compaction.
    """
    if not schedules:
        return schedules
    num_stages = len(schedules)
    width = len(schedules[0])
    last_stage = num_stages - 1
    kinds = ('F', 'B', 'W')

    # mem[sid][i]: running F-minus-W count on stage sid after column i.
    peak_mem = 0
    mem = []
    for row in schedules:
        cur = 0
        profile = []
        for cell in row:
            if cell == 'F':
                cur += 1
            elif cell == 'W':
                cur -= 1
            profile.append(cur)
            peak_mem = max(peak_mem, cur)
        mem.append(profile)

    # loc[sid][n][kind]: column of the n-th (1-based) pass of `kind` on stage
    # sid once placed, -1 while not yet seen.
    loc = [[{k: -1 for k in kinds} for _ in range(m + 2)] for _ in range(num_stages)]
    seen = [{k: 0 for k in kinds} for _ in range(num_stages)]

    for i in range(width):
        for sid in range(num_stages):
            row = schedules[sid]
            kind = row[i]
            if kind == ' ':
                continue
            seen[sid][kind] += 1
            cnt = seen[sid][kind]

            # Latest column among this pass's dependencies.
            pos = loc[sid][cnt - 1][kind] if cnt > 1 else -1
            if kind == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            elif kind == 'F':
                if sid > 0:
                    pos = max(pos, loc[sid - 1][cnt]['F'])
            elif kind == 'B':
                if sid != last_stage:
                    pos = max(pos, loc[sid + 1][cnt]['B'])
                else:
                    pos = max(pos, loc[sid][cnt]['F'])

            # First idle slot after the dependencies; bounds are checked
            # before indexing because pos may exceed i.
            pos += 1
            while pos < i and row[pos] != ' ':
                pos += 1
            if pos == i:
                loc[sid][cnt][kind] = i
                continue

            if kind in ('B', 'W'):
                row[pos], row[i] = kind, ' '
                if kind == 'W':
                    # Releasing earlier lowers the count between the columns.
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][kind] = pos
                continue

            # F: slide left only through columns still below the peak.
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and row[place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][kind] = i
                continue
            row[place], row[i] = kind, ' '
            for j in range(place, i):
                mem[sid][j] += 1
            loc[sid][cnt][kind] = place
    return schedules
def schedule(p, m, cost):
    """Build, compact and time an F/B/W grid for ``p`` stages and ``m`` microbatches.

    Microbatches are laid out in pairs; pair ``k`` occupies columns shifted
    by ``6 * k``.  Moving from the last stage towards stage 0, each stage's
    F columns shift one left and its B columns one right.  Each stage's
    idle slots after its Bs are then filled greedily with one W per
    preceding B.  The grid is compacted by
    ``process_warmup_without_increasing_peak_mem`` and timed by
    ``transform_schedule``.

    Args:
        p: Number of pipeline stages.
        m: Number of microbatches.
        cost: ``(f, b, w, c)`` forwarded to ``transform_schedule``.

    Returns:
        Whatever ``transform_schedule`` returns for the final grid.
    """
    width = 6 * m + 2 * p + 6
    grid = [[' '] * width for _ in range(p)]

    # Column of the first (even) and second (odd) microbatch of a pair.
    first_f, second_f, first_b, second_b = p - 1, p + 1, p, p + 2
    for sid in reversed(range(p)):
        row = grid[sid]
        for pair in range((m + 1) // 2):
            offset = 6 * pair
            if 2 * pair < m:
                row[first_f + offset] = 'F'
                row[first_b + offset] = 'B'
            if 2 * pair + 1 < m:
                row[second_f + offset] = 'F'
                row[second_b + offset] = 'B'
        first_f -= 1
        second_f -= 1
        first_b += 1
        second_b += 1

        # Each B owes one W; pay it into the next idle slot.
        owed = 0
        for col in range(width):
            cell = row[col]
            if cell == 'B':
                owed += 1
            elif cell == ' ' and owed > 0:
                owed -= 1
                row[col] = 'W'

    compacted = process_warmup_without_increasing_peak_mem(grid, m)
    return transform_schedule(compacted, *cost)