from dataclasses import dataclass
from typing import List, Optional, Set
@dataclass
class GraphConfig:
mem_f: float = 2
mem_b: float = -1
mem_w: float = -1
    max_mem: Optional[float] = None  # per-stage memory cap; manual_order defaults it to mem_f * nmb * 3 when unset
cost_f: int = 1
cost_b: int = 1
cost_w: int = 1
cost_comm: int = 0
print_scaling: int = 1
def __post_init__(self):
assert type(self.cost_f) is int
assert type(self.cost_b) is int
assert type(self.cost_w) is int
assert type(self.cost_comm) is int
assert self.mem_f + self.mem_b + self.mem_w == 0
@dataclass(eq=True, frozen=True)
class ScheduledNode:
type: str
stage: int
minibatch: int
start_time: int
completion_time: int
rollback: bool = False
@dataclass
class Graph:
nstages: int
nmb: int
nnodes: int
config: GraphConfig
    parents: Optional[List[Set[int]]] = None
    name: Optional[List[str]] = None
    # ID mapping:
    #   F[stage][minibatch]: [0, STAGE * MB)
    #   B[stage][minibatch]: [STAGE * MB, 2 * STAGE * MB)
    #   W[stage][minibatch]: [2 * STAGE * MB, 3 * STAGE * MB)
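    # Example: with nstages = 4 and nmb = 12, get_id(1, 2, 5) == 1 * 48 + 2 * 12 + 5 == 77,
    # i.e. the B pass of minibatch 5 on stage 2, and get_stage(77) == 2.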
def get_id(self, type, stage, mb):
return type * (self.nstages * self.nmb) + stage * self.nmb + mb
def get_stage(self, id):
return (id // self.nmb) % self.nstages
def get_cost(self, id):
type = id // (self.nstages * self.nmb)
return [self.config.cost_f, self.config.cost_b, self.config.cost_w][type]
def get_mem(self, id):
type = id // (self.nstages * self.nmb)
return [self.config.mem_f, self.config.mem_b, self.config.mem_w][type]
@classmethod
def build_graph(cls, nstages, nmb, config):
nnodes = nstages * nmb * 3
g = Graph(nstages=nstages, nmb=nmb, nnodes=nnodes, config=config)
parents = []
name = []
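        # Dependency structure: an F pass waits for the same minibatch's F on the previous
        # stage; a B pass waits for the next stage's B (or, on the last stage, its own F);
        # a W pass waits for its own B. Same-type passes on a stage also run in minibatch order.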
for type in range(3):
for stage in range(nstages):
for mb in range(nmb):
p = set()
if type == 0:
name.append(f'F{mb}')
if stage > 0:
p.add(g.get_id(type, stage - 1, mb))
if mb > 0:
p.add(g.get_id(type, stage, mb - 1))
elif type == 1:
name.append(f'B{mb}')
if stage == nstages - 1:
p.add(g.get_id(0, stage, mb))
else:
p.add(g.get_id(type, stage + 1, mb))
if mb > 0:
p.add(g.get_id(type, stage, mb - 1))
elif type == 2:
name.append(f'W{mb}')
p.add(g.get_id(1, stage, mb))
if mb > 0:
p.add(g.get_id(type, stage, mb - 1))
else:
assert False
parents.append(p)
g.name = name
g.parents = parents
return g
# Manual ordering producing this kind of schedule:
# fffffffbfbfbfbfbfbwbwbwbwbwbwbwwwwww
# fffffbfbfbfbfbfbfbfbwbwbwbwbwwwwwwww
# fffbfbfbfbfbfbfbfbfbfbwbwbwwwwwwwwww
# fbfbfbfbfbfbfbfbfbfbfbfbwwwwwwwwwwww
    # Returns (order, t, best_time): the order index of each node on its own stage,
    # the completion time of each node, and the longest per-stage span (first F start
    # to last W end).
def manual_order(
self, allow_bubble_before_first_b=False, prioritize_b=False, no_bubble_greedy=True
):
        order = [0] * self.nnodes  # order index assigned to each node on its own stage
        f = [0] * self.nstages  # F passes scheduled so far on each stage
        b = [0] * self.nstages  # B passes scheduled so far on each stage
        w = [0] * self.nstages  # W passes scheduled so far on each stage
        o = [0] * self.nstages  # next order index to hand out on each stage
        m = [0] * self.nstages  # activation memory currently held by each stage
        e = [0] * self.nstages  # current end time of each stage
        t = [0] * self.nnodes  # completion time of each node
        max_mem = self.config.max_mem or self.get_mem(self.get_id(0, 0, 0)) * self.nmb * 3
        comm = self.config.cost_comm
        order_str = [""] * self.nstages  # per-stage string of scheduled pass letters (debugging aid)
        stage_bubble = [0] * self.nstages  # accumulated bubble (idle) time per stage
def get_max_bubble():
max_bubble = 0
for bb in stage_bubble:
max_bubble = max(max_bubble, bb)
return max_bubble
def put(stage_j, type_k):
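            # Schedule the next pending pass of type_k (0 = F, 1 = B, 2 = W) on stage stage_j:
            # check the memory cap, start it no earlier than its cross-stage dependency plus
            # comm, count any idle gap (after the stage's first F) as bubble, then record the
            # end time, memory usage and order index.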
if type_k == 0:
_i = f[stage_j]
elif type_k == 1:
_i = b[stage_j]
else:
_i = w[stage_j]
_j = stage_j
_id = self.get_id(type_k, _j, _i)
_mem = self.get_mem(_id)
_cost = self.get_cost(_id)
assert m[_j] + _mem <= max_mem
tmp = e[_j] + _cost
no_bubble = tmp
if _j > 0 and type_k == 0:
tmp = max(tmp, t[self.get_id(0, _j - 1, _i)] + comm + _cost)
if _j < self.nstages - 1 and type_k == 1:
tmp = max(tmp, t[self.get_id(1, _j + 1, _i)] + comm + _cost)
if f[_j] > 0:
stage_bubble[_j] += tmp - no_bubble
e[_j] = tmp
t[_id] = tmp
m[_j] += _mem
order[_id] = o[_j]
if type_k == 0:
f[_j] += 1
elif type_k == 1:
b[_j] += 1
else:
w[_j] += 1
o[_j] += 1
fbw = "fbw"
order_str[stage_j] += fbw[type_k]
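        # Greedily schedule one minibatch per iteration. The first iteration also decides how
        # many warmup F passes each stage needs before its first B; later iterations optionally
        # insert an extra F and/or W passes to free memory and fill waiting time.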
for i in range(self.nmb):
if i == 0:
for j in range(self.nstages):
put(j, 0)
f_required = [0] * self.nstages
last_t = 0
for j in range(self.nstages - 1, -1, -1):
if j == self.nstages - 1:
last_t = t[self.get_id(0, j, i)] + self.get_cost(self.get_id(1, j, i))
continue
mem = m[j]
cost = e[j]
while True:
f_id = self.get_id(0, j, f[j] + f_required[j])
if f[j] + f_required[j] < self.nmb and mem + self.get_mem(f_id) <= max_mem:
if allow_bubble_before_first_b:
if cost + self.get_cost(f_id) > last_t + comm:
break
else:
if cost >= last_t + comm:
break
mem += self.get_mem(f_id)
cost += self.get_cost(f_id)
f_required[j] += 1
else:
break
last_t = max(cost, last_t + comm) + self.get_cost(self.get_id(1, j, i))
for j in range(self.nstages):
while j > 0 and f_required[j] > 0 and f_required[j] >= f_required[j - 1] and f[j] + f_required[j] < self.nmb:
f_required[j] -= 1
for j in range(self.nstages - 1, -1, -1):
for _ in range(f_required[j]):
put(j, 0)
put(j, 1)
continue
f_required = [0] * self.nstages
for j in range(self.nstages):
if f[j] >= self.nmb:
continue
if j + 1 < self.nstages and f[j] >= f[j + 1] + 2 and prioritize_b:
next_plus_fw = (
e[j + 1]
+ self.get_cost(self.get_id(0, j + 1, f[j + 1]))
+ self.get_cost(self.get_id(1, j + 1, b[j + 1]))
+ comm
)
if e[j] >= next_plus_fw:
continue
f_id = self.get_id(0, j, f[j])
f_mem = self.get_mem(f_id)
w_cost, w_cnt = 0, 0
mem_with_w = m[j] + f_mem
while mem_with_w > max_mem and w[j] + w_cnt < b[j]:
w_id = self.get_id(2, j, w[j] + w_cnt)
w_cost += self.get_cost(w_id)
mem_with_w += self.get_mem(w_id)
w_cnt += 1
if e[j] + self.get_cost(f_id) + w_cost <= next_plus_fw:
f_required[j] = 1
continue
w_cost, w_cnt = 0, 0
# mem_with_w = m[j]
# while w[j] + w_cnt < b[j]:
# w_id = self.get_id(2, j, w[j] + w_cnt)
# w_cost += self.get_cost(w_id)
# mem_with_w += self.get_mem(w_id)
# w_cnt += 1
# if e[j] + w_cost >= next_plus_fw:
# continue
if next_plus_fw - (e[j] + w_cost) <= get_max_bubble() - stage_bubble[j]:
# TODO: can sample here
continue
f_required[j] = 1
for j in range(self.nstages - 2, -1, -1):
f_required[j] = min(f_required[j], f_required[j + 1])
for j in range(self.nstages):
if f_required[j] == 0:
continue
f_id = self.get_id(0, j, f[j])
mem = self.get_mem(f_id)
while m[j] + mem > max_mem:
if w[j] >= b[j]:
raise ValueError("Cannot fit memory")
put(j, 2)
if j > 0:
while (
w[j] < b[j]
and e[j] + self.get_cost(self.get_id(2, j, w[j]))
<= t[self.get_id(0, j - 1, f[j])] + comm
):
put(j, 2)
if w[j] < b[j] and e[j] < t[self.get_id(0, j - 1, f[j])] + comm:
# TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(0, j - 1, f[j])] + comm
if (
t[self.get_id(0, j - 1, f[j])] + comm - e[j]
<= get_max_bubble() - stage_bubble[j]
):
# TODO: can sample here
if no_bubble_greedy:
put(j, 2)
else:
put(j, 2)
put(j, 0)
for j in range(self.nstages - 1, -1, -1):
assert b[j] == i
b_id = self.get_id(1, j, b[j])
mem = self.get_mem(b_id)
while m[j] + mem > max_mem:
if w[j] >= b[j]:
raise ValueError("Cannot fit memory")
put(j, 2)
if j + 1 < self.nstages:
while (
w[j] < b[j]
and e[j] + self.get_cost(self.get_id(2, j, w[j]))
<= t[self.get_id(1, j + 1, i)] + comm
):
put(j, 2)
if w[j] < b[j] and e[j] < t[self.get_id(1, j + 1, i)] + comm:
# TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(1, j + 1, i)] + comm
if (
t[self.get_id(1, j + 1, i)] + comm - e[j]
<= get_max_bubble() - stage_bubble[j]
):
# TODO: can sample here
if no_bubble_greedy:
put(j, 2)
else:
put(j, 2)
if j == 0 and f[j] == self.nmb:
while w[j] < b[j]:
put(j, 2)
put(j, 1)
for i in range(self.nstages):
while w[i] < self.nmb:
put(i, 2)
# print(f"{' ' * i}{order_str[i]} -> {e[i]}")
for i in range(self.nstages):
for j in range(self.nmb):
f_id = self.get_id(0, i, j)
b_id = self.get_id(1, i, j)
w_id = self.get_id(2, i, j)
f_cost = self.get_cost(f_id)
b_cost = self.get_cost(b_id)
w_cost = self.get_cost(w_id)
assert t[b_id] >= t[f_id] + b_cost
assert t[w_id] >= t[b_id] + w_cost, f"{i}-{j}, {t[w_id]} >= {t[b_id]} + {w_cost}"
if i > 0:
assert t[f_id] >= t[self.get_id(0, i - 1, j)] + comm + f_cost, f"{i}-{j}"
if i < self.nstages - 1:
assert t[b_id] >= t[self.get_id(1, i + 1, j)] + comm + b_cost
# print(order)
best_time = 0
for i in range(self.nstages):
time_i = (
t[self.get_id(2, i, self.nmb - 1)]
- t[self.get_id(0, i, 0)]
+ self.get_cost(self.get_id(0, i, 0))
)
best_time = max(best_time, time_i)
return order, t, best_time
def initial_solution(graph):
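    # Try every combination of the three ordering heuristics and keep the schedule with the
    # smallest longest-stage time.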
best_time, order, complete_time = None, None, None
for allow_bubble_before_first_b in [True, False]:
for prioritize_b in [True, False]:
for no_bubble_greedy in [True, False]:
order_t, complete_time_t, best_time_t = graph.manual_order(
allow_bubble_before_first_b=allow_bubble_before_first_b,
prioritize_b=prioritize_b,
no_bubble_greedy=no_bubble_greedy,
)
if best_time is None or best_time_t < best_time:
best_time = best_time_t
order = order_t
complete_time = complete_time_t
print_detail(graph, complete_time)
print("-" * 20, best_time, "-" * 20)
return best_time, order, complete_time
def print_detail(graph, F):
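    # Render an ASCII timeline per stage from the completion times in F (one character per
    # config.print_scaling time units): each pass is drawn as its type letter and minibatch
    # index followed by dashes, closing with the type letter at its completion time.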
typenames = ['F', 'B', 'W']
times = []
for stage in range(graph.nstages):
stage_str = ['.'] * int(F[graph.get_id(2, stage, graph.nmb - 1)] / graph.config.print_scaling)
for _type in range(3):
for _mb in range(graph.nmb):
_id = graph.get_id(_type, stage, _mb)
end = int(F[_id] / graph.config.print_scaling)
start = int((F[_id] - graph.get_cost(_id)) / graph.config.print_scaling)
for j in range(start, end):
if j == start or j == end - 1:
stage_str[j] = typenames[_type]
elif j == start + 1:
if _mb >= 10:
stage_str[j] = str(_mb // 10)
else:
stage_str[j] = str(_mb)
elif j == start + 2 and _mb >= 10:
stage_str[j] = str(_mb % 10)
else:
stage_str[j] = "-"
_str = ""
for _c in stage_str:
_str += _c
times.append(
F[graph.get_id(2, stage, graph.nmb - 1)]
- F[graph.get_id(0, stage, 0)]
+ graph.get_cost(graph.get_id(0, stage, 0))
)
print(_str)
print('Longest stage time: ', max(times))
def ilp_results(graph, F):
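    # Expand a vector of node completion times F (indexed by node id) into per-stage lists of
    # ScheduledNode, inserting SEND/RECV communication nodes and POST_VALIDATION nodes, and
    # flagging the receives that may need to be rolled back.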
typenames = ['F', 'B', 'W']
local_order = []
end_time = []
for i in range(graph.nnodes):
end_time.append(F[i])
for stage in range(graph.nstages):
order = []
for type in range(3):
for mb in range(graph.nmb):
id = graph.get_id(type, stage, mb)
order.append(
ScheduledNode(
type=typenames[type],
stage=stage,
minibatch=mb,
start_time=end_time[id] - graph.get_cost(id),
completion_time=F[id],
)
)
local_order.append(order)
    # For each F/B, append a send/recv node. The recv node gets the same timestamp as the
    # send node to guarantee a global order.
comm_id = {}
comm_id_counter = 0
post_validation_time = 0
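    # POST_VALIDATION (used by the optimizer-rollback mechanism) is timed at the start of the
    # last warmup F of each stage (the last F finishing before that stage's first B), minus
    # one communication delay, carried as a running maximum from the last stage downwards.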
for i in range(graph.nstages - 1, -1, -1):
warmup_f_count = -1
first_b_end = end_time[graph.get_id(1, i, 0)]
for j in range(graph.nmb):
if end_time[graph.get_id(0, i, j)] < first_b_end:
warmup_f_count += 1
assert warmup_f_count >= 0
pv_id = warmup_f_count
_id = graph.get_id(0, i, pv_id)
_cost = graph.get_cost(_id)
post_validation_time = max(post_validation_time, end_time[_id] - _cost - graph.config.cost_comm)
# post_validation_time = 0
# print(i, pv_id, post_validation_time)
for it in ["RECV_", "SEND_", ""]:
if i == 0 and it == "SEND_":
continue
if i == graph.nstages - 1 and it == "RECV_":
continue
# stage_ = i - 1 if it == "RECV_" else i
stage_ = i
local_order[stage_].append(ScheduledNode(
type=it + "POST_VALIDATION",
stage=stage_,
minibatch=0,
start_time=post_validation_time,
completion_time=post_validation_time,
))
comm_id[local_order[stage_][-1]] = comm_id_counter
comm_id_counter += 1
for stage in range(graph.nstages):
for node in local_order[stage]:
if node.type == 'F' and node.stage != graph.nstages - 1:
local_order[stage].append(
ScheduledNode(
type='SEND_FORWARD',
stage=stage,
minibatch=node.minibatch,
start_time=node.completion_time,
completion_time=node.completion_time, # TODO: consider comm cost in completion time
)
)
local_order[stage + 1].append(
ScheduledNode(
type='RECV_FORWARD',
stage=stage + 1,
minibatch=node.minibatch,
start_time=node.completion_time,
completion_time=node.completion_time, # TODO: consider comm cost in completion time
)
)
comm_id[local_order[stage][-1]] = comm_id_counter
comm_id[local_order[stage + 1][-1]] = comm_id_counter
comm_id_counter += 1
if node.type == 'B' and node.stage != 0:
local_order[stage].append(
ScheduledNode(
type='SEND_BACKWARD',
stage=stage,
minibatch=node.minibatch,
start_time=node.completion_time,
completion_time=node.completion_time, # TODO: consider comm cost in completion time
)
)
local_order[stage - 1].append(
ScheduledNode(
type='RECV_BACKWARD',
stage=stage - 1,
minibatch=node.minibatch,
start_time=node.completion_time,
completion_time=node.completion_time, # TODO: consider comm cost in completion time
)
)
comm_id[local_order[stage][-1]] = comm_id_counter
comm_id[local_order[stage - 1][-1]] = comm_id_counter
comm_id_counter += 1
for stage in range(graph.nstages):
# For nodes with the same timestamp on the same stage, communication will be prioritized.
def even_breaker(x: ScheduledNode):
# Compute nodes are always delayed.
if x.type in ['F', 'B', 'W']:
return comm_id_counter
# For comm nodes, order by their unique comm id
return comm_id[x]
local_order[stage] = list(sorted(
local_order[stage], key=lambda x: (x.start_time, even_breaker(x))
))
        # If a recv overlaps with the preceding computation, reorder them so that the recv
        # is executed before the computation and can therefore be overlapped with it.
for i in range(len(local_order[stage])):
if i > 0 and local_order[stage][i - 1].type in {'F', 'B', 'W'} and \
local_order[stage][i].type.startswith('RECV') and \
"POST_VALIDATION" not in local_order[stage][i].type and \
local_order[stage][i].start_time <= local_order[stage][i - 1].completion_time:
(local_order[stage][i], local_order[stage][i - 1]) = (local_order[stage][i - 1], local_order[stage][i])
# print([(x.type, x.start_time, x.completion_time) for x in local_order[stage]])
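    # Mark RECV_FORWARD nodes whose matching SEND_FORWARD is issued on the previous rank
    # before that rank's POST_VALIDATION; these receives carry the rollback flag.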
local_order_with_rollback = [[] for _ in range(graph.nstages)]
for rank in range(graph.nstages):
rollback_comm = set()
if rank > 0:
for node in local_order[rank - 1]:
if node.type == "POST_VALIDATION":
break
if node.type == "SEND_FORWARD":
rollback_comm.add(node.minibatch)
for node in local_order[rank]:
if node.type == "RECV_FORWARD" and node.minibatch in rollback_comm:
rollback = True
rollback_comm.remove(node.minibatch)
else:
rollback = False
local_order_with_rollback[rank].append(ScheduledNode(
type=node.type,
stage=node.stage,
minibatch=node.minibatch,
start_time=node.start_time,
completion_time=node.completion_time,
rollback=rollback,
))
assert len(rollback_comm) == 0
# for node in local_order_with_rollback[rank]:
# print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
# print()
print_detail(graph, end_time)
return local_order_with_rollback
def auto_schedule(nstages, nmb, config):
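    # Build the dependency graph, pick the best of the hand-crafted greedy orderings, then
    # expand it into per-stage ScheduledNode lists with communication and validation nodes.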
graph = Graph.build_graph(nstages, nmb, config)
best_time, order, complete_time = initial_solution(graph)
return ilp_results(graph, complete_time)
if __name__ == "__main__":
# auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
# auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=14))
auto_schedule(24, 72, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=100))
auto_schedule(4, 12, GraphConfig(
cost_f=5478,
cost_b=5806,
cost_w=3534,
cost_comm=200,
max_mem=32,
print_scaling=1000
))
auto_schedule(32, 16, GraphConfig(
cost_f=1,
cost_b=1,
cost_w=1,
cost_comm=0,
max_mem=64,
))
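    # Example (sketch), mirroring the calls above: auto_schedule returns one list of
    # ScheduledNode per pipeline stage, e.g.
    #   schedule = auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
    #   for rank, nodes in enumerate(schedule):
    #       print(rank, [(n.type, n.minibatch, int(n.rollback)) for n in nodes])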