Nyamdavaa Amar
commited on
Commit
•
3d4d40d
1
Parent(s):
f8e95f6
Pipeline Parallelism with Controllable Memory
Browse files- README.md +4 -10
- adaptive_schedule.py +627 -0
- app.py +152 -96
- auto_schedule.py +0 -564
- description1.md +3 -9
- description2.md +5 -32
- interleaved_variant.py +107 -0
- schedule1f1bv.py +271 -0
- svg_event.py +1 -1
- type2.py +163 -0
- v_schedule.py +0 -474
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 🏆
|
4 |
colorFrom: indigo
|
5 |
colorTo: red
|
@@ -11,14 +11,8 @@ license: apache-2.0
|
|
11 |
---
|
12 |
|
13 |
|
14 |
-
#
|
15 |
|
16 |
-
|
17 |
|
18 |
-
|
19 |
-
* [Arxiv Version with ZBV](https://arxiv.org/abs/2401.10241)
|
20 |
-
* [ICLR Accepted version with ZB1P and ZB2P](https://openreview.net/pdf?id=tuzTN0eIO5)
|
21 |
-
|
22 |
-
Try out our implementation based on Megatron on [https://github.com/sail-sg/zero-bubble-pipeline-parallelism](https://github.com/sail-sg/zero-bubble-pipeline-parallelism)
|
23 |
-
|
24 |
-
Experiments shows zero bubble pipeline parallelism can accelerate training up to 30% with a similar memory comsumption. A detailed table of experiments is coming soon.
|
|
|
1 |
---
|
2 |
+
title: Pipeline Parallellism with Controllable Memory
|
3 |
emoji: 🏆
|
4 |
colorFrom: indigo
|
5 |
colorTo: red
|
|
|
11 |
---
|
12 |
|
13 |
|
14 |
+
# Pipeline Parallellism with Controllable Memory
|
15 |
|
16 |
+
Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
|
17 |
|
18 |
+
Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
|
|
|
|
|
|
|
|
|
|
|
|
adaptive_schedule.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pattern_size = 6
|
2 |
+
from collections import Counter
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
@dataclass(eq=True, frozen=True)
|
6 |
+
class ScheduledNode:
|
7 |
+
type: str
|
8 |
+
chunk: int
|
9 |
+
stage: int
|
10 |
+
minibatch: int
|
11 |
+
start_time: int
|
12 |
+
completion_time: int
|
13 |
+
|
14 |
+
def transform_schedule(schedule, f, b, w, c):
|
15 |
+
result = []
|
16 |
+
|
17 |
+
stage_order = []
|
18 |
+
local_prev = {}
|
19 |
+
stages = len(schedule)
|
20 |
+
|
21 |
+
for sid, stage in enumerate(schedule):
|
22 |
+
counter = Counter()
|
23 |
+
order = []
|
24 |
+
for p in stage:
|
25 |
+
if not p.strip():
|
26 |
+
continue
|
27 |
+
mb = counter.get(p, 0)
|
28 |
+
if order:
|
29 |
+
local_prev[(sid, p, mb)] = order[-1]
|
30 |
+
order.append((p, mb))
|
31 |
+
counter.update(p)
|
32 |
+
stage_order.append(order)
|
33 |
+
nmb = max(counter.values())
|
34 |
+
time_map = {}
|
35 |
+
cost = {
|
36 |
+
'F': f,
|
37 |
+
'B': b,
|
38 |
+
'W': w,
|
39 |
+
'f': f,
|
40 |
+
'b': b,
|
41 |
+
'w': w,
|
42 |
+
}
|
43 |
+
def get_time(stage, type, mb):
|
44 |
+
if (stage, type, mb) in time_map:
|
45 |
+
return time_map.get((stage, type, mb))
|
46 |
+
time = 0
|
47 |
+
if (stage, type, mb) in local_prev:
|
48 |
+
time = get_time(stage, *local_prev[(stage, type, mb)])
|
49 |
+
if type in ('F', 'B') and stage > 0:
|
50 |
+
time = max(time, get_time(stage - 1, type, mb) + c)
|
51 |
+
if type in ('f', 'b') and stage + 1< len(schedule):
|
52 |
+
time = max(time, get_time(stage + 1, type, mb) + c)
|
53 |
+
# print(f'{stage} {type}:{mb}', time + cost[type])
|
54 |
+
time_map[(stage, type, mb)] = time + cost[type]
|
55 |
+
return time_map[(stage, type, mb)]
|
56 |
+
r = 0
|
57 |
+
for sid, stage in enumerate(schedule):
|
58 |
+
r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
|
59 |
+
r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
|
60 |
+
|
61 |
+
for sid, stage in enumerate(stage_order):
|
62 |
+
result_stage = []
|
63 |
+
for p, mb in stage:
|
64 |
+
result_stage.append(ScheduledNode(
|
65 |
+
p.upper(),
|
66 |
+
p in ('f', 'B', 'W'),
|
67 |
+
sid,
|
68 |
+
mb,
|
69 |
+
get_time(sid, p, mb) - cost[p],
|
70 |
+
get_time(sid, p, mb)
|
71 |
+
)
|
72 |
+
)
|
73 |
+
result.append(result_stage)
|
74 |
+
return result
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
def evaluate_schedule(schedule, f, b, w, c):
|
81 |
+
stage_order = []
|
82 |
+
local_prev = {}
|
83 |
+
stages = len(schedule)
|
84 |
+
|
85 |
+
for sid, stage in enumerate(schedule):
|
86 |
+
counter = Counter()
|
87 |
+
order = []
|
88 |
+
for p in stage:
|
89 |
+
if not p.strip():
|
90 |
+
continue
|
91 |
+
mb = counter.get(p, 0)
|
92 |
+
if order:
|
93 |
+
local_prev[(sid, p, mb)] = order[-1]
|
94 |
+
order.append((p, mb))
|
95 |
+
counter.update(p)
|
96 |
+
stage_order.append(order)
|
97 |
+
nmb = max(counter.values())
|
98 |
+
time_map = {}
|
99 |
+
cost = {
|
100 |
+
'F': f,
|
101 |
+
'B': b,
|
102 |
+
'W': w,
|
103 |
+
'f': f,
|
104 |
+
'b': b,
|
105 |
+
'w': w,
|
106 |
+
}
|
107 |
+
def get_time(stage, type, mb):
|
108 |
+
if (stage, type, mb) in time_map:
|
109 |
+
return time_map.get((stage, type, mb))
|
110 |
+
time = 0
|
111 |
+
if (stage, type, mb) in local_prev:
|
112 |
+
time = get_time(stage, *local_prev[(stage, type, mb)])
|
113 |
+
if type in ('F', 'B') and stage > 0:
|
114 |
+
time = max(time, get_time(stage - 1, type, mb) + c)
|
115 |
+
if type in ('f', 'b') and stage + 1< len(schedule):
|
116 |
+
time = max(time, get_time(stage + 1, type, mb) + c)
|
117 |
+
# print(f'{stage} {type}:{mb}', time + cost[type])
|
118 |
+
time_map[(stage, type, mb)] = time + cost[type]
|
119 |
+
return time_map[(stage, type, mb)]
|
120 |
+
r = 0
|
121 |
+
for sid, stage in enumerate(schedule):
|
122 |
+
r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
|
123 |
+
r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
|
124 |
+
return r
|
125 |
+
|
126 |
+
def get_pattern_str(pos):
|
127 |
+
pattern = [" "] * pattern_size
|
128 |
+
notations = "FfBbWw"
|
129 |
+
for i, v in enumerate(pos):
|
130 |
+
if v < 0:
|
131 |
+
continue
|
132 |
+
pattern[v] = notations[i]
|
133 |
+
_str = ""
|
134 |
+
for v in pattern:
|
135 |
+
_str += v
|
136 |
+
return _str
|
137 |
+
|
138 |
+
|
139 |
+
def get_peak_mem(schedules, return_all=False):
|
140 |
+
max_peak = 0
|
141 |
+
all_peak = []
|
142 |
+
for schedule_ in schedules:
|
143 |
+
peak, mem = 0, 0
|
144 |
+
for v in schedule_:
|
145 |
+
if v in "Ff":
|
146 |
+
mem += 1
|
147 |
+
elif v in "Ww":
|
148 |
+
mem -= 1
|
149 |
+
peak = max(peak, mem)
|
150 |
+
all_peak.append(peak)
|
151 |
+
max_peak = max(max_peak, peak)
|
152 |
+
if return_all:
|
153 |
+
return all_peak
|
154 |
+
return max_peak
|
155 |
+
|
156 |
+
debug = False
|
157 |
+
def print_schedules(schedules):
|
158 |
+
if not debug:
|
159 |
+
return
|
160 |
+
for seq in schedules:
|
161 |
+
_str = ""
|
162 |
+
for v in seq:
|
163 |
+
_str += v
|
164 |
+
print(_str)
|
165 |
+
|
166 |
+
|
167 |
+
def calc_bubble(schedules):
|
168 |
+
stage_bubbles = []
|
169 |
+
for i in range(len(schedules)):
|
170 |
+
max_len = 0
|
171 |
+
count = 0
|
172 |
+
for j in range(len(schedules[i])):
|
173 |
+
if schedules[i][j] != ' ':
|
174 |
+
max_len = j + 1
|
175 |
+
count += 1
|
176 |
+
stage_bubbles.append(max_len - count - i)
|
177 |
+
return stage_bubbles
|
178 |
+
|
179 |
+
|
180 |
+
def init_repeated_schedule(p, m, patterns):
|
181 |
+
repeated = []
|
182 |
+
_len = 4 * p + m + 1
|
183 |
+
for i in range(p):
|
184 |
+
str_i = get_pattern_str(patterns[i]) * _len
|
185 |
+
repeated_i = []
|
186 |
+
for v in str_i:
|
187 |
+
repeated_i.append(v)
|
188 |
+
repeated.append(repeated_i)
|
189 |
+
return repeated
|
190 |
+
|
191 |
+
|
192 |
+
def clear_invalid(repeated, stage, pos, offset=-1):
|
193 |
+
while 0 <= pos < len(repeated[stage]):
|
194 |
+
repeated[stage][pos] = ' '
|
195 |
+
pos += offset * pattern_size
|
196 |
+
return repeated
|
197 |
+
|
198 |
+
|
199 |
+
def clear_invalid_index(repeated, m):
|
200 |
+
p = len(repeated)
|
201 |
+
index = pattern_size
|
202 |
+
for identifier in ['F', 'f', 'B', 'b']:
|
203 |
+
if identifier in ['F', 'B']:
|
204 |
+
_iter = range(p)
|
205 |
+
else:
|
206 |
+
_iter = range(p - 1, -1, -1)
|
207 |
+
for i in _iter:
|
208 |
+
for j in range(pattern_size):
|
209 |
+
if repeated[i][index] == identifier:
|
210 |
+
clear_invalid(repeated, i, index - pattern_size, offset=-1)
|
211 |
+
clear_invalid(repeated, i, index + pattern_size * m, offset=1)
|
212 |
+
index += 1
|
213 |
+
if identifier in ['B', 'b']:
|
214 |
+
w_identifier = {'B': 'W', 'b': 'w'}[identifier]
|
215 |
+
for k in range(pattern_size):
|
216 |
+
if repeated[i][index + k] == w_identifier:
|
217 |
+
clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
|
218 |
+
clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
|
219 |
+
break
|
220 |
+
break
|
221 |
+
index += 1
|
222 |
+
return repeated
|
223 |
+
|
224 |
+
|
225 |
+
def process_warmup_without_increasing_peak_mem(schedules, m):
|
226 |
+
"""
|
227 |
+
FFFFFFFFFF fBWfBWfBWfBWfBW b
|
228 |
+
FFFFFFFFF f fBWfBWfBWfBWFBWb
|
229 |
+
FFFFFFFF f f fBWfBWfBWFBW b
|
230 |
+
FFFFFFF f f f fBWfBWFBW Bb
|
231 |
+
FFFFFF f f f f fBWFBWFBWb
|
232 |
+
FFFFFfFf f f f BWFBW b
|
233 |
+
FFFfFfFfFf f BW Bb
|
234 |
+
FfFfFfFfFfF BWb
|
235 |
+
We reorganize the warmup phase in the following way (i -> pipeline stage from 0):
|
236 |
+
1. Before the first B, we set #f = min(i+1, peak_mem//2), #F = peak_mem - #f
|
237 |
+
2. Before the first b, #f = peak_mem//2
|
238 |
+
3. The offset between the first B is 1
|
239 |
+
4. Before the first b, we use the pattern of (BWf)*j + (BWF)*k,
|
240 |
+
where j = max(0, peak_mem//2 - (i+1)), k = max(0, #W - j - 1)
|
241 |
+
"""
|
242 |
+
# process warmup phase (before the first b)
|
243 |
+
p = len(schedules)
|
244 |
+
peak_mem = get_peak_mem(schedules)
|
245 |
+
peak_mem = min(peak_mem, 2 * p)
|
246 |
+
cnt_f, cnt_ff = [], []
|
247 |
+
for i in range(p):
|
248 |
+
cc_ff = min(i + 1, peak_mem // 2)
|
249 |
+
cc_ff = min(cc_ff, m)
|
250 |
+
cc_f = min(peak_mem - cc_ff, m)
|
251 |
+
cnt_f.append(cc_f)
|
252 |
+
cnt_ff.append(cc_ff)
|
253 |
+
distance_b2bb = 0
|
254 |
+
for j in range(len(schedules[p - 1])):
|
255 |
+
if schedules[p - 1][j] == 'B':
|
256 |
+
for k in range(j, len(schedules[p - 1])):
|
257 |
+
if schedules[p - 1][k] == 'b':
|
258 |
+
distance_b2bb = k - j
|
259 |
+
break
|
260 |
+
break
|
261 |
+
for i in range(p):
|
262 |
+
c_f, c_ff, c_b, c_w = 0, 0, 0, 0
|
263 |
+
for j in range(len(schedules[i])):
|
264 |
+
char = schedules[i][j]
|
265 |
+
if char == 'F':
|
266 |
+
c_f += 1
|
267 |
+
elif char == 'f':
|
268 |
+
c_ff += 1
|
269 |
+
elif char == 'B':
|
270 |
+
c_b += 1
|
271 |
+
elif char == 'W':
|
272 |
+
c_w += 1
|
273 |
+
elif char == 'b':
|
274 |
+
bj = j
|
275 |
+
while j < len(schedules[i]):
|
276 |
+
char = schedules[i][j]
|
277 |
+
if char == 'f' and c_ff < cnt_ff[p - 1]:
|
278 |
+
schedules[i][j] = ' '
|
279 |
+
c_ff += 1
|
280 |
+
if char == 'B' and c_b < c_ff:
|
281 |
+
if c_b < (2 * (p - i) + distance_b2bb) // 3 or c_b < cnt_ff[p - 1] - cnt_ff[i]:
|
282 |
+
# there is empty space, or the number of B is not enough to cover extra f
|
283 |
+
schedules[i][j] = ' '
|
284 |
+
c_b += 1
|
285 |
+
if char == 'W' and c_w < c_b:
|
286 |
+
if c_w < (2 * (p - i) + distance_b2bb - 1) // 3 or c_w < cnt_ff[p - 1] - cnt_ff[i]:
|
287 |
+
# there is empty space, or the number of W is not enough to cover extra f
|
288 |
+
schedules[i][j] = ' '
|
289 |
+
c_w += 1
|
290 |
+
j += 1
|
291 |
+
j = bj
|
292 |
+
while j < len(schedules[i]):
|
293 |
+
if schedules[i][j] == 'F':
|
294 |
+
if c_f < c_ff or c_f < cnt_f[i] or c_f - cnt_f[i] + c_ff - cnt_ff[i] < c_w - 1:
|
295 |
+
# put enough F, or there are some unused BW
|
296 |
+
schedules[i][j] = ' '
|
297 |
+
c_f += 1
|
298 |
+
j += 1
|
299 |
+
break
|
300 |
+
else:
|
301 |
+
assert char == ' '
|
302 |
+
schedules[i][j] = ' '
|
303 |
+
assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
|
304 |
+
assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
|
305 |
+
j = i
|
306 |
+
u_f, u_ff, u_b, u_w = 0, 0, 0, 0
|
307 |
+
for _ in range(2 * (p - 1 - i)):
|
308 |
+
if u_f < cnt_f[i] and u_f < c_f:
|
309 |
+
schedules[i][j] = 'F'
|
310 |
+
u_f += 1
|
311 |
+
j += 1
|
312 |
+
for _ in range(i + 1):
|
313 |
+
if u_f < cnt_f[i] and u_f < c_f:
|
314 |
+
schedules[i][j] = 'F'
|
315 |
+
u_f += 1
|
316 |
+
j += 1
|
317 |
+
if u_ff < cnt_ff[i] and u_ff < c_ff:
|
318 |
+
schedules[i][j] = 'f'
|
319 |
+
u_ff += 1
|
320 |
+
j += 1
|
321 |
+
while u_f < c_f or u_ff < c_ff or u_b < c_b or u_w < c_w:
|
322 |
+
if u_b < c_b:
|
323 |
+
schedules[i][j] = 'B'
|
324 |
+
u_b += 1
|
325 |
+
j += 1
|
326 |
+
if u_w < c_w:
|
327 |
+
schedules[i][j] = 'W'
|
328 |
+
u_w += 1
|
329 |
+
j += 1
|
330 |
+
if u_ff < c_ff:
|
331 |
+
assert u_ff < u_f
|
332 |
+
schedules[i][j] = 'f'
|
333 |
+
u_ff += 1
|
334 |
+
elif u_f < c_f:
|
335 |
+
schedules[i][j] = 'F'
|
336 |
+
u_f += 1
|
337 |
+
j += 1
|
338 |
+
return schedules
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
def squeeze_without_change_order(schedules, m):
|
343 |
+
p = len(schedules)
|
344 |
+
squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
|
345 |
+
max_len = 0
|
346 |
+
for seq in squeezed:
|
347 |
+
assert max_len == 0 or max_len == len(seq)
|
348 |
+
max_len = max(max_len, len(seq))
|
349 |
+
|
350 |
+
identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
|
351 |
+
identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
|
352 |
+
stage_index = [0 for _ in range(p)]
|
353 |
+
for j in range(max_len):
|
354 |
+
for _dir in range(2):
|
355 |
+
if _dir == 0:
|
356 |
+
_iter = range(p)
|
357 |
+
else:
|
358 |
+
_iter = range(p - 1, -1, -1)
|
359 |
+
for i in _iter:
|
360 |
+
identifier = schedules[i][j]
|
361 |
+
if identifier == ' ':
|
362 |
+
continue
|
363 |
+
if _dir == 0 and identifier in "fbw":
|
364 |
+
continue
|
365 |
+
if _dir == 1 and identifier in "FBW":
|
366 |
+
continue
|
367 |
+
_cnt = identifier_cnt[i][identifier]
|
368 |
+
assert _cnt < m, "{} - {}, {}".format(i, identifier, _cnt)
|
369 |
+
if identifier in "Ww" or (i == 0 and identifier in "FB") or (i == p - 1 and identifier in "fb"):
|
370 |
+
if i == 0 and identifier == 'B':
|
371 |
+
assert identifier_index[_cnt * p + i]['f'] >= 0
|
372 |
+
if i == p - 1 and identifier == 'f':
|
373 |
+
assert identifier_index[_cnt * p + i]['F'] >= 0
|
374 |
+
if i == p - 1 and identifier == 'b':
|
375 |
+
assert identifier_index[_cnt * p + i]['B'] >= 0
|
376 |
+
index = stage_index[i]
|
377 |
+
elif identifier in "FB":
|
378 |
+
assert identifier_index[_cnt * p + i - 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
|
379 |
+
index = max(identifier_index[_cnt * p + i - 1][identifier] + 1, stage_index[i])
|
380 |
+
elif identifier in "fb":
|
381 |
+
assert identifier_index[_cnt * p + i + 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
|
382 |
+
index = max(identifier_index[_cnt * p + i + 1][identifier] + 1, stage_index[i])
|
383 |
+
else:
|
384 |
+
raise
|
385 |
+
squeezed[i][index] = identifier
|
386 |
+
identifier_cnt[i][identifier] += 1
|
387 |
+
identifier_index[_cnt * p + i][identifier] = index
|
388 |
+
stage_index[i] = index + 1
|
389 |
+
return squeezed
|
390 |
+
|
391 |
+
|
392 |
+
def process_cooldown(schedules, m):
|
393 |
+
"""
|
394 |
+
fBW bwbwbwbw
|
395 |
+
fBWBW bwbwbwbw
|
396 |
+
fBWBWBW bwbwbwbw
|
397 |
+
fBWBWBWBW bwbwbwbw
|
398 |
+
f BWBWBWBbWbwbwbww
|
399 |
+
f BWBWBbBbWbWbwwww
|
400 |
+
f BWBbBbBbWbWWwwww
|
401 |
+
f BbBbBbBbWWWWwwww
|
402 |
+
We reorganize the cooldown phase in the following way (i -> pipeline stage from 0):
|
403 |
+
1. After the last f, we set #b = (peak_mem+1)//2, and #B = min(i+1, peak_mem - #b)
|
404 |
+
2. After the last f, we make all the dependencies as tight as possible
|
405 |
+
"""
|
406 |
+
p = len(schedules)
|
407 |
+
|
408 |
+
peak_mem = get_peak_mem(schedules)
|
409 |
+
assert peak_mem <= 2 * p
|
410 |
+
max_bb = (peak_mem + 1) // 2
|
411 |
+
max_bb = min(max_bb, m)
|
412 |
+
max_b = min(peak_mem - max_bb, m)
|
413 |
+
|
414 |
+
# 1: reorganize B/b and remove W/w in cooldown phase
|
415 |
+
starting_index = -1
|
416 |
+
for i in range(p):
|
417 |
+
c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
|
418 |
+
last_ff_index = -1
|
419 |
+
# collect B/b which can be reorganized
|
420 |
+
for j in range(len(schedules[i]) - 1, -1, -1):
|
421 |
+
char = schedules[i][j]
|
422 |
+
if char == 'f' and last_ff_index == -1:
|
423 |
+
last_ff_index = j
|
424 |
+
if char == 'B' and c_b < i + 1 and c_b < max_b:
|
425 |
+
schedules[i][j] = ' '
|
426 |
+
c_b += 1
|
427 |
+
if char == 'b' and c_bb < max_bb:
|
428 |
+
schedules[i][j] = ' '
|
429 |
+
c_bb += 1
|
430 |
+
# clear W in the tail (#W + #w = peak_mem)
|
431 |
+
for j in range(len(schedules[i]) - 1, -1, -1):
|
432 |
+
char = schedules[i][j]
|
433 |
+
if char == 'W' and c_w + c_ww < peak_mem:
|
434 |
+
schedules[i][j] = ' '
|
435 |
+
c_w += 1
|
436 |
+
if char == 'w' and c_w + c_ww < peak_mem:
|
437 |
+
schedules[i][j] = ' '
|
438 |
+
c_ww += 1
|
439 |
+
if i == 0:
|
440 |
+
starting_index = last_ff_index
|
441 |
+
# reorganize B/b in the tail
|
442 |
+
for k in range(c_bb):
|
443 |
+
index = starting_index - i + 2 * p - 2 * k
|
444 |
+
assert schedules[i][index] == ' ', "{} {} {}".format(schedules[i][index], k, i)
|
445 |
+
schedules[i][index] = 'b'
|
446 |
+
for k in range(c_b):
|
447 |
+
index = starting_index + 1 + i - 2 * k
|
448 |
+
assert schedules[i][index] == ' ', schedules[i][index]
|
449 |
+
schedules[i][index] = 'B'
|
450 |
+
|
451 |
+
# 2: squeeze cooldown phase without change order
|
452 |
+
schedules = squeeze_without_change_order(schedules, m)
|
453 |
+
|
454 |
+
# 3: add W back in cooldown phase
|
455 |
+
for i in range(p):
|
456 |
+
c_w, c_ww = 0, 0
|
457 |
+
last_w_index = -2
|
458 |
+
for j in range(len(schedules[i]) - 1, -1, -1):
|
459 |
+
if schedules[i][j] in "Ww":
|
460 |
+
if last_w_index < 0:
|
461 |
+
schedules[i][j] = ' '
|
462 |
+
last_w_index += 1
|
463 |
+
else:
|
464 |
+
last_w_index = j
|
465 |
+
break
|
466 |
+
for j in range(len(schedules[i])):
|
467 |
+
char = schedules[i][j]
|
468 |
+
if char == 'B':
|
469 |
+
c_w += 1
|
470 |
+
elif char == 'b':
|
471 |
+
c_ww += 1
|
472 |
+
elif char == 'W':
|
473 |
+
c_w -= 1
|
474 |
+
elif char == 'w':
|
475 |
+
c_ww -= 1
|
476 |
+
if char == ' ' and j > last_w_index:
|
477 |
+
if c_w > 0:
|
478 |
+
schedules[i][j] = 'W'
|
479 |
+
c_w -= 1
|
480 |
+
elif c_ww > 0:
|
481 |
+
schedules[i][j] = 'w'
|
482 |
+
c_ww -= 1
|
483 |
+
|
484 |
+
schedules = squeeze_without_change_order(schedules, m)
|
485 |
+
return schedules
|
486 |
+
|
487 |
+
|
488 |
+
def schedule_by_pattern(p, m, patterns):
|
489 |
+
schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
|
490 |
+
schedules = clear_invalid_index(schedules, max(m, 2 * p))
|
491 |
+
print_schedules(schedules)
|
492 |
+
init_peak_mem = get_peak_mem(schedules)
|
493 |
+
if init_peak_mem > 2 * p:
|
494 |
+
return None, init_peak_mem, [6 * max(m, 2 * p)] * p
|
495 |
+
schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
|
496 |
+
|
497 |
+
for sid in range(len(schedules)):
|
498 |
+
cnt = {_id: 0 for _id in "FfBbWw"}
|
499 |
+
for i in range(len(schedules[sid])):
|
500 |
+
if(schedules[sid][i] == ' '):
|
501 |
+
continue
|
502 |
+
if cnt[schedules[sid][i]] >= m:
|
503 |
+
schedules[sid][i] = ' '
|
504 |
+
else:
|
505 |
+
cnt[schedules[sid][i]] += 1
|
506 |
+
print_schedules(schedules)
|
507 |
+
peak_mem = get_peak_mem(schedules)
|
508 |
+
if peak_mem > init_peak_mem:
|
509 |
+
return None, init_peak_mem, [6 * m] * p
|
510 |
+
|
511 |
+
schedules = squeeze_without_change_order(schedules, m)
|
512 |
+
print_schedules(schedules)
|
513 |
+
|
514 |
+
schedules = process_cooldown(schedules, m)
|
515 |
+
print_schedules(schedules)
|
516 |
+
peak_mem = get_peak_mem(schedules)
|
517 |
+
if peak_mem > init_peak_mem:
|
518 |
+
return None, init_peak_mem, [6 * m] * p
|
519 |
+
|
520 |
+
stage_bubbles = calc_bubble(schedules)
|
521 |
+
return schedules, peak_mem, stage_bubbles
|
522 |
+
|
523 |
+
|
524 |
+
def fill_w_in_pattern(pattern):
|
525 |
+
f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
|
526 |
+
vis = [False] * pattern_size
|
527 |
+
for v in pattern:
|
528 |
+
if v >= 0:
|
529 |
+
vis[v] = True
|
530 |
+
assert pattern[b] >= 0 and pattern[bb] >= 0
|
531 |
+
for v, vw in [(b, w), (bb, ww)]:
|
532 |
+
for j in range(pattern_size):
|
533 |
+
pos = (pattern[v] + j) % pattern_size
|
534 |
+
if not vis[pos]:
|
535 |
+
pattern[vw] = pos
|
536 |
+
vis[pos] = True
|
537 |
+
break
|
538 |
+
return pattern
|
539 |
+
|
540 |
+
|
541 |
+
def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
|
542 |
+
whole_pattern = [pattern_0]
|
543 |
+
for i in range(p - 1):
|
544 |
+
last_pattern = whole_pattern[i]
|
545 |
+
new_pattern = [-1] * pattern_size
|
546 |
+
vis = [False] * pattern_size
|
547 |
+
if i < len_0:
|
548 |
+
offset = offset_0
|
549 |
+
else:
|
550 |
+
offset = offset_1
|
551 |
+
for v, v_o in enumerate(offset):
|
552 |
+
pos = (last_pattern[v] + v_o + pattern_size) % pattern_size
|
553 |
+
assert 0 <= pos < pattern_size
|
554 |
+
if vis[pos]:
|
555 |
+
return None
|
556 |
+
vis[pos] = True
|
557 |
+
new_pattern[v] = pos
|
558 |
+
new_pattern = fill_w_in_pattern(new_pattern)
|
559 |
+
whole_pattern.append(new_pattern)
|
560 |
+
return whole_pattern
|
561 |
+
|
562 |
+
|
563 |
+
|
564 |
+
def schedule(p, m, cost, max_mem):
|
565 |
+
f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
|
566 |
+
available_patterns = []
|
567 |
+
for ff_i in range(1, pattern_size):
|
568 |
+
for b_i in range(1, pattern_size):
|
569 |
+
for bb_i in range(1, pattern_size):
|
570 |
+
if ff_i == b_i or ff_i == bb_i or b_i == bb_i:
|
571 |
+
continue
|
572 |
+
pattern = [0, ff_i, b_i, bb_i, -1, -1]
|
573 |
+
pattern = fill_w_in_pattern(pattern)
|
574 |
+
available_patterns.append(pattern)
|
575 |
+
available_offsets = []
|
576 |
+
for f_o in range(1, pattern_size + 1):
|
577 |
+
for ff_o in range(1, pattern_size + 1):
|
578 |
+
for b_o in range(1, pattern_size + 1):
|
579 |
+
if f_o != b_o:
|
580 |
+
continue
|
581 |
+
bb_o = ff_o + b_o - f_o
|
582 |
+
if bb_o < 1 or bb_o > pattern_size:
|
583 |
+
continue
|
584 |
+
if bb_o + ff_o + b_o + f_o > 2 * pattern_size:
|
585 |
+
continue
|
586 |
+
# if bb_o + ff_o + b_o + f_o != 6:
|
587 |
+
# continue
|
588 |
+
offset = [f_o, - ff_o, b_o, - bb_o]
|
589 |
+
if min(ff_o, bb_o) > 1:
|
590 |
+
continue
|
591 |
+
available_offsets.append(offset)
|
592 |
+
|
593 |
+
print(available_offsets, len(available_patterns))
|
594 |
+
available_offsets = [
|
595 |
+
[1, -1, 1, -1],
|
596 |
+
[2, -1, 2, -1],
|
597 |
+
[3, -1, 3, -1],
|
598 |
+
[4, -1, 4, -1],
|
599 |
+
[5, -1, 5, -1]
|
600 |
+
]
|
601 |
+
|
602 |
+
best_schedule = None
|
603 |
+
best_bubble = None
|
604 |
+
peak_mem2min_bubble = {}
|
605 |
+
for pattern_0 in available_patterns:
|
606 |
+
for i_0 in range(len(available_offsets)):
|
607 |
+
for i_1 in range(i_0 + 1):
|
608 |
+
for len_0 in range(1, p):
|
609 |
+
offset_0 = available_offsets[i_0]
|
610 |
+
offset_1 = available_offsets[i_1]
|
611 |
+
whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
|
612 |
+
if whole_pattern is None:
|
613 |
+
continue
|
614 |
+
# for pattern in whole_pattern:
|
615 |
+
# print(get_pattern_str(pattern))
|
616 |
+
# print(offset)
|
617 |
+
s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern)
|
618 |
+
if s is None:
|
619 |
+
continue
|
620 |
+
if peak_mem > 2 * p or peak_mem > max_mem:
|
621 |
+
continue
|
622 |
+
max_bubble = max(bubbles)
|
623 |
+
max_bubble = evaluate_schedule(s, *cost)
|
624 |
+
if best_schedule is None or max_bubble < best_bubble:
|
625 |
+
best_schedule, best_bubble = s, max_bubble
|
626 |
+
res = transform_schedule(best_schedule, *cost)
|
627 |
+
return res
|
app.py
CHANGED
@@ -1,15 +1,12 @@
|
|
1 |
import gradio as gr
|
2 |
-
import auto_schedule
|
3 |
-
import v_schedule
|
4 |
import hand_schedule
|
|
|
|
|
|
|
|
|
5 |
from PIL import Image
|
6 |
from svg_event import render_manual_graph
|
7 |
import pathlib
|
8 |
-
def greet(name, is_morning, temperature):
|
9 |
-
salutation = "Good morning" if is_morning else "Good evening"
|
10 |
-
greeting = f"{salutation} {name}. It is {temperature} degrees today"
|
11 |
-
celsius = (temperature - 32) * 5 / 9
|
12 |
-
return greeting, round(celsius, 2)
|
13 |
|
14 |
def percentage(x):
|
15 |
return f"{x*100:.2f}%"
|
@@ -25,6 +22,26 @@ def get_schedule_time(result):
|
|
25 |
)
|
26 |
return time
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
img_queue = []
|
29 |
def get_schedule_image(result, max_time):
|
30 |
result = [
|
@@ -41,80 +58,87 @@ def get_schedule_image(result, max_time):
|
|
41 |
|
42 |
|
43 |
def calculate(p, m, f, b, w, c, mem):
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
baseline_result = [
|
53 |
-
list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
|
54 |
-
]
|
55 |
-
baseline_time = get_schedule_time(baseline_result)
|
56 |
-
baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
|
57 |
-
baseline_acceleration=percentage(0)
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
cost_b=b,
|
63 |
-
cost_w=w,
|
64 |
-
cost_comm=c,
|
65 |
-
max_mem=mem * 2,
|
66 |
-
print_scaling=1000
|
67 |
-
))
|
68 |
-
|
69 |
-
zb_time=get_schedule_time(zb_result)
|
70 |
-
|
71 |
-
zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
|
72 |
-
zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
|
73 |
-
|
74 |
-
if mem < p:
|
75 |
-
zbv_time=None
|
76 |
-
zbv_bubble=None
|
77 |
-
zbv_acceleration=None
|
78 |
-
zbv_image=None
|
79 |
-
zbv_result=None
|
80 |
-
else:
|
81 |
-
zbv_graph = v_schedule.PipelineGraph(
|
82 |
-
n_stage=p,
|
83 |
-
n_micro=m,
|
84 |
-
f_cost=f/2,
|
85 |
-
b_cost=b/2,
|
86 |
-
w_cost=w/2,
|
87 |
-
c_cost=c,
|
88 |
-
f_mem=2,
|
89 |
-
b_mem=-1,
|
90 |
-
w_mem=-1,
|
91 |
-
max_mem=mem * 4,
|
92 |
-
)
|
93 |
-
zbv_result = zbv_graph.get_v_schedule()
|
94 |
-
|
95 |
-
zbv_time = get_schedule_time(zbv_result)
|
96 |
-
zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
|
97 |
-
zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
|
98 |
-
|
99 |
-
max_time = max(filter(lambda x: x is not None, [baseline_time, zb_time, zbv_time]))
|
100 |
print(max_time)
|
101 |
if baseline_result is not None:
|
102 |
baseline_image = get_schedule_image(baseline_result, max_time)
|
103 |
-
if
|
104 |
-
|
105 |
-
if
|
106 |
-
|
|
|
|
|
|
|
|
|
107 |
|
108 |
-
return [
|
|
|
|
|
|
|
|
|
109 |
|
110 |
with gr.Blocks() as demo:
|
111 |
gr.Markdown(open("description1.md").read())
|
112 |
gr.Markdown("# Pipeline Scheduler Playground")
|
113 |
presets = {
|
114 |
-
'
|
115 |
-
'Ideal Case
|
116 |
-
'
|
117 |
-
'Real Case 2p': (4, 12, 1049, 1122, 903, 79, '2p'),
|
118 |
}
|
119 |
preset_buttons = {}
|
120 |
|
@@ -129,25 +153,31 @@ with gr.Blocks() as demo:
|
|
129 |
with gr.Group():
|
130 |
gr.Markdown("Basic Parameters")
|
131 |
with gr.Row():
|
132 |
-
p=gr.Number(label="Number of stages (p)", value=
|
133 |
m=gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
|
134 |
with gr.Column(scale=2):
|
135 |
with gr.Group():
|
136 |
-
gr.Markdown("Costs. All costs are used as integers. For
|
137 |
with gr.Row():
|
138 |
-
f=gr.Number(label="Time of F", value=
|
139 |
-
b=gr.Number(label="Time of B", value=
|
140 |
-
w=gr.Number(label="Time of W", value=
|
141 |
-
c=gr.Number(label="Time of one P2P communication", value=
|
142 |
with gr.Group():
|
143 |
gr.Markdown("Activation memory limit.")
|
144 |
def update_mem(p, s, mem):
|
145 |
print("update")
|
146 |
-
if s=="custom":
|
147 |
return mem
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
152 |
p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
153 |
|
@@ -157,31 +187,53 @@ with gr.Blocks() as demo:
|
|
157 |
gr.Markdown("1F1B")
|
158 |
with gr.Row():
|
159 |
with gr.Column(scale=1):
|
160 |
-
baseline_time=gr.Textbox("", label="Longest Stage Time")
|
161 |
-
baseline_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
162 |
baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
|
|
|
|
163 |
with gr.Column(scale=4):
|
164 |
baseline_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
165 |
-
|
166 |
with gr.Group():
|
167 |
-
gr.Markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
with gr.Row():
|
169 |
with gr.Column(scale=1):
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
with gr.Column(scale=4):
|
174 |
-
|
175 |
with gr.Group():
|
176 |
-
gr.Markdown("
|
177 |
with gr.Row():
|
178 |
with gr.Column(scale=1):
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
with gr.Column(scale=4):
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
for (k, v) in presets.items():
|
187 |
def update_preset(pb, p, m, f, b, w, c, mem):
|
@@ -192,6 +244,10 @@ with gr.Blocks() as demo:
|
|
192 |
preset_buttons[k].click(
|
193 |
update_preset,
|
194 |
inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
|
195 |
-
outputs=[p, m, f, b, w, c, memsel,
|
196 |
-
|
|
|
|
|
|
|
|
|
197 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
|
|
|
|
2 |
import hand_schedule
|
3 |
+
import adaptive_schedule
|
4 |
+
import interleaved_variant
|
5 |
+
import type2
|
6 |
+
import schedule1f1bv
|
7 |
from PIL import Image
|
8 |
from svg_event import render_manual_graph
|
9 |
import pathlib
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def percentage(x):
|
12 |
return f"{x*100:.2f}%"
|
|
|
22 |
)
|
23 |
return time
|
24 |
|
25 |
+
|
26 |
+
def get_memory_usage(result):
|
27 |
+
max_mem = 0
|
28 |
+
has_w = False
|
29 |
+
for r in result:
|
30 |
+
for x in r:
|
31 |
+
if x.type in ('W', 'w'):
|
32 |
+
has_w = True
|
33 |
+
for r in result:
|
34 |
+
cur = 0
|
35 |
+
for x in r:
|
36 |
+
if x.type in ('F', 'f'):
|
37 |
+
cur += 1
|
38 |
+
if x.type in ('W', 'w'):
|
39 |
+
cur -= 1
|
40 |
+
if has_w == False and x.type in ('B', 'b'):
|
41 |
+
cur -= 1
|
42 |
+
max_mem = max(max_mem, cur)
|
43 |
+
return max_mem
|
44 |
+
|
45 |
img_queue = []
|
46 |
def get_schedule_image(result, max_time):
|
47 |
result = [
|
|
|
58 |
|
59 |
|
60 |
def calculate(p, m, f, b, w, c, mem):
|
61 |
+
baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
|
62 |
+
baseline_result = [
|
63 |
+
list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
|
64 |
+
]
|
65 |
+
baseline_time = get_schedule_time(baseline_result)
|
66 |
+
baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
|
67 |
+
baseline_mem = get_memory_usage(baseline_result)
|
68 |
+
baseline_acceleration=percentage(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
+
adapt_result = adaptive_schedule.schedule(
|
71 |
+
p,
|
72 |
+
m,
|
73 |
+
[f/2, b/2, w/2, c],
|
74 |
+
max_mem=mem * 2
|
75 |
+
)
|
76 |
+
|
77 |
+
adapt_time = get_schedule_time(adapt_result)
|
78 |
+
adapt_mem = get_memory_usage(adapt_result) / 2
|
79 |
+
adapt_bubble=percentage(adapt_time/(f+b+w)/m - 1)
|
80 |
+
adapt_acceleration=percentage(baseline_time/adapt_time - 1) if baseline_time is not None else None
|
81 |
+
|
82 |
+
schedule1f1bv_result = schedule1f1bv.schedule(
|
83 |
+
p,
|
84 |
+
m,
|
85 |
+
[f / 2, b / 2, w / 2, c]
|
86 |
+
)
|
87 |
+
|
88 |
+
schedule1f1bv_time = get_schedule_time(schedule1f1bv_result)
|
89 |
+
schedule1f1bv_mem = get_memory_usage(schedule1f1bv_result) / 2
|
90 |
+
schedule1f1bv_bubble=percentage(schedule1f1bv_time/(f+b+w)/m - 1)
|
91 |
+
schedule1f1bv_acceleration=percentage(baseline_time/schedule1f1bv_time - 1) if baseline_time is not None else None
|
92 |
+
|
93 |
+
type2_result = type2.schedule(
|
94 |
+
p,
|
95 |
+
m,
|
96 |
+
[f, b, w, c]
|
97 |
+
)
|
98 |
+
|
99 |
+
type2_time = get_schedule_time(type2_result)
|
100 |
+
type2_mem = get_memory_usage(type2_result)
|
101 |
+
type2_bubble=percentage(type2_time/(f+b+w)/m - 1)
|
102 |
+
type2_acceleration=percentage(baseline_time/type2_time - 1) if baseline_time is not None else None
|
103 |
+
|
104 |
+
interleaved_result = interleaved_variant.get_interleaved_variation(
|
105 |
+
p,
|
106 |
+
m,
|
107 |
+
[f/2, b/2, w/2, c]
|
108 |
+
)
|
109 |
+
|
110 |
+
interleaved_time = get_schedule_time(interleaved_result)
|
111 |
+
interleaved_mem = get_memory_usage(interleaved_result) / 2
|
112 |
+
interleaved_bubble=percentage(interleaved_time/(f+b+w)/m - 1)
|
113 |
+
interleaved_acceleration=percentage(baseline_time/interleaved_time - 1) if baseline_time is not None else None
|
114 |
|
115 |
+
|
116 |
+
max_time = max(filter(lambda x: x is not None, [baseline_time, adapt_time, interleaved_time, type2_time, schedule1f1bv_time]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
print(max_time)
|
118 |
if baseline_result is not None:
|
119 |
baseline_image = get_schedule_image(baseline_result, max_time)
|
120 |
+
if adapt_result is not None:
|
121 |
+
adapt_image = get_schedule_image(adapt_result, max_time)
|
122 |
+
if interleaved_result is not None:
|
123 |
+
interleaved_image = get_schedule_image(interleaved_result, max_time)
|
124 |
+
if type2_result is not None:
|
125 |
+
type2_image = get_schedule_image(type2_result, max_time)
|
126 |
+
if schedule1f1bv_result is not None:
|
127 |
+
schedule1f1bv_image = get_schedule_image(schedule1f1bv_result, max_time)
|
128 |
|
129 |
+
return [baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
|
130 |
+
adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
|
131 |
+
schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
|
132 |
+
type2_acceleration, type2_mem, type2_bubble, type2_image,
|
133 |
+
interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image]
|
134 |
|
135 |
with gr.Blocks() as demo:
|
136 |
gr.Markdown(open("description1.md").read())
|
137 |
gr.Markdown("# Pipeline Scheduler Playground")
|
138 |
presets = {
|
139 |
+
'Real Case': (6, 12, 1049, 1122, 903, 79, 'V-Half'),
|
140 |
+
'Ideal Case': (6, 12, 20, 20, 20, 0, 'V-Min'),
|
141 |
+
'Zero Bubble Case': (6, 12, 1049, 1122, 903, 79, 'V-ZB')
|
|
|
142 |
}
|
143 |
preset_buttons = {}
|
144 |
|
|
|
153 |
with gr.Group():
|
154 |
gr.Markdown("Basic Parameters")
|
155 |
with gr.Row():
|
156 |
+
p=gr.Number(label="Number of stages (p)", value=6, interactive=True, precision=0)
|
157 |
m=gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
|
158 |
with gr.Column(scale=2):
|
159 |
with gr.Group():
|
160 |
+
gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
|
161 |
with gr.Row():
|
162 |
+
f=gr.Number(label="Time of F", value=1049, interactive=True, precision=0)
|
163 |
+
b=gr.Number(label="Time of B", value=1122, interactive=True, precision=0)
|
164 |
+
w=gr.Number(label="Time of W", value=903, interactive=True, precision=0)
|
165 |
+
c=gr.Number(label="Time of one P2P communication", value=79, interactive=True, precision=0)
|
166 |
with gr.Group():
|
167 |
gr.Markdown("Activation memory limit.")
|
168 |
def update_mem(p, s, mem):
|
169 |
print("update")
|
170 |
+
if s == "custom":
|
171 |
return mem
|
172 |
+
if s == "V-Min":
|
173 |
+
return (p + 4) // 3
|
174 |
+
if s == "V-Half":
|
175 |
+
return (p + 2) // 2
|
176 |
+
if s == "V-ZB":
|
177 |
+
return p
|
178 |
+
assert False
|
179 |
+
memsel=gr.Radio(choices=["V-Min", "V-Half", "V-ZB", "custom"], value="V-Half")
|
180 |
+
mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
|
181 |
memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
182 |
p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
183 |
|
|
|
187 |
gr.Markdown("1F1B")
|
188 |
with gr.Row():
|
189 |
with gr.Column(scale=1):
|
|
|
|
|
190 |
baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
191 |
+
baseline_mem=gr.Textbox("", label="Maximum memory usage")
|
192 |
+
baseline_bubble=gr.Textbox("", label="Bubble Rate")
|
193 |
with gr.Column(scale=4):
|
194 |
baseline_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
|
|
195 |
with gr.Group():
|
196 |
+
gr.Markdown("Adaptive Scheduler")
|
197 |
+
with gr.Row():
|
198 |
+
with gr.Column(scale=1):
|
199 |
+
adapt_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
200 |
+
adapt_mem=gr.Textbox("", label="Maximum memory usage")
|
201 |
+
adapt_bubble=gr.Textbox("", label="Bubble Rate")
|
202 |
+
with gr.Column(scale=4):
|
203 |
+
adapt_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
204 |
+
gr.Markdown(open("description2.md").read())
|
205 |
+
with gr.Group():
|
206 |
+
gr.Markdown("1F1B-V Schedule")
|
207 |
with gr.Row():
|
208 |
with gr.Column(scale=1):
|
209 |
+
schedule1f1bv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
210 |
+
schedule1f1bv_mem=gr.Textbox("", label="Maximum memory usage")
|
211 |
+
schedule1f1bv_bubble=gr.Textbox("", label="Bubble Rate")
|
212 |
with gr.Column(scale=4):
|
213 |
+
schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
214 |
with gr.Group():
|
215 |
+
gr.Markdown("Two microbatch in one building block schedule")
|
216 |
with gr.Row():
|
217 |
with gr.Column(scale=1):
|
218 |
+
type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
219 |
+
type2_mem=gr.Textbox("", label="Maximum memory usage")
|
220 |
+
type2_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
221 |
with gr.Column(scale=4):
|
222 |
+
type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
223 |
+
with gr.Group():
|
224 |
+
gr.Markdown("Interleaved 1F1B Schedule")
|
225 |
+
with gr.Row():
|
226 |
+
with gr.Column(scale=1):
|
227 |
+
interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
228 |
+
interleaved_mem=gr.Textbox("", label="Maximum memory usage")
|
229 |
+
interleaved_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
230 |
+
with gr.Column(scale=4):
|
231 |
+
interleaved_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
232 |
+
button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
|
233 |
+
adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
|
234 |
+
schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
|
235 |
+
type2_acceleration, type2_mem, type2_bubble, type2_image,
|
236 |
+
interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
|
237 |
|
238 |
for (k, v) in presets.items():
|
239 |
def update_preset(pb, p, m, f, b, w, c, mem):
|
|
|
244 |
preset_buttons[k].click(
|
245 |
update_preset,
|
246 |
inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
|
247 |
+
outputs=[p, m, f, b, w, c, memsel,
|
248 |
+
baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
|
249 |
+
adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
|
250 |
+
schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
|
251 |
+
type2_acceleration, type2_mem, type2_bubble, type2_image,
|
252 |
+
interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
|
253 |
demo.launch()
|
auto_schedule.py
DELETED
@@ -1,564 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from typing import List, Set
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
@dataclass
|
7 |
-
class GraphConfig:
|
8 |
-
mem_f: float = 2
|
9 |
-
mem_b: float = -1
|
10 |
-
mem_w: float = -1
|
11 |
-
max_mem: float = None
|
12 |
-
cost_f: int = 1
|
13 |
-
cost_b: int = 1
|
14 |
-
cost_w: int = 1
|
15 |
-
cost_comm: int = 0
|
16 |
-
print_scaling: int = 1
|
17 |
-
|
18 |
-
def __post_init__(self):
|
19 |
-
assert type(self.cost_f) is int
|
20 |
-
assert type(self.cost_b) is int
|
21 |
-
assert type(self.cost_w) is int
|
22 |
-
assert type(self.cost_comm) is int
|
23 |
-
assert self.mem_f + self.mem_b + self.mem_w == 0
|
24 |
-
|
25 |
-
@dataclass(eq=True, frozen=True)
|
26 |
-
class ScheduledNode:
|
27 |
-
type: str
|
28 |
-
stage: int
|
29 |
-
minibatch: int
|
30 |
-
start_time: int
|
31 |
-
completion_time: int
|
32 |
-
rollback: bool = False
|
33 |
-
|
34 |
-
|
35 |
-
@dataclass
|
36 |
-
class Graph:
|
37 |
-
nstages: int
|
38 |
-
nmb: int
|
39 |
-
nnodes: int
|
40 |
-
config: GraphConfig
|
41 |
-
parents: List[Set[int]] = None
|
42 |
-
name: List[str] = None
|
43 |
-
|
44 |
-
# ID mapping:
|
45 |
-
# F[stage][minibatch]: 0..STAGE* MB
|
46 |
-
# B[stage][minibatch]: STAGE* MB .. 2 * STAGE * MB
|
47 |
-
# W[stage][minibatch]: 2 * STAGE* MB .. 3 * STAGE * MB
|
48 |
-
|
49 |
-
def get_id(self, type, stage, mb):
|
50 |
-
return type * (self.nstages * self.nmb) + stage * self.nmb + mb
|
51 |
-
|
52 |
-
def get_stage(self, id):
|
53 |
-
return (id // self.nmb) % self.nstages
|
54 |
-
|
55 |
-
def get_cost(self, id):
|
56 |
-
type = id // (self.nstages * self.nmb)
|
57 |
-
return [self.config.cost_f, self.config.cost_b, self.config.cost_w][type]
|
58 |
-
|
59 |
-
def get_mem(self, id):
|
60 |
-
type = id // (self.nstages * self.nmb)
|
61 |
-
return [self.config.mem_f, self.config.mem_b, self.config.mem_w][type]
|
62 |
-
|
63 |
-
@classmethod
|
64 |
-
def build_graph(cls, nstages, nmb, config):
|
65 |
-
nnodes = nstages * nmb * 3
|
66 |
-
g = Graph(nstages=nstages, nmb=nmb, nnodes=nnodes, config=config)
|
67 |
-
parents = []
|
68 |
-
name = []
|
69 |
-
for type in range(3):
|
70 |
-
for stage in range(nstages):
|
71 |
-
for mb in range(nmb):
|
72 |
-
p = set()
|
73 |
-
if type == 0:
|
74 |
-
name.append(f'F{mb}')
|
75 |
-
if stage > 0:
|
76 |
-
p.add(g.get_id(type, stage - 1, mb))
|
77 |
-
if mb > 0:
|
78 |
-
p.add(g.get_id(type, stage, mb - 1))
|
79 |
-
elif type == 1:
|
80 |
-
name.append(f'B{mb}')
|
81 |
-
if stage == nstages - 1:
|
82 |
-
p.add(g.get_id(0, stage, mb))
|
83 |
-
else:
|
84 |
-
p.add(g.get_id(type, stage + 1, mb))
|
85 |
-
if mb > 0:
|
86 |
-
p.add(g.get_id(type, stage, mb - 1))
|
87 |
-
elif type == 2:
|
88 |
-
name.append(f'W{mb}')
|
89 |
-
p.add(g.get_id(1, stage, mb))
|
90 |
-
if mb > 0:
|
91 |
-
p.add(g.get_id(type, stage, mb - 1))
|
92 |
-
else:
|
93 |
-
assert False
|
94 |
-
parents.append(p)
|
95 |
-
|
96 |
-
g.name = name
|
97 |
-
g.parents = parents
|
98 |
-
return g
|
99 |
-
|
100 |
-
# Manual ordering producing this kind of schedule:
|
101 |
-
# fffffffbfbfbfbfbfbwbwbwbwbwbwbwwwwww
|
102 |
-
# fffffbfbfbfbfbfbfbfbwbwbwbwbwwwwwwww
|
103 |
-
# fffbfbfbfbfbfbfbfbfbfbwbwbwwwwwwwwww
|
104 |
-
# fbfbfbfbfbfbfbfbfbfbfbfbwwwwwwwwwwww
|
105 |
-
# Returns the order index of each node on its own stage
|
106 |
-
def manual_order(
|
107 |
-
self, allow_bubble_before_first_b=False, prioritize_b=False, no_bubble_greedy=True
|
108 |
-
):
|
109 |
-
order = [0] * self.nnodes
|
110 |
-
f = [0] * self.nstages
|
111 |
-
b = [0] * self.nstages
|
112 |
-
w = [0] * self.nstages
|
113 |
-
o = [0] * self.nstages
|
114 |
-
m = [0] * self.nstages
|
115 |
-
e = [0] * self.nstages
|
116 |
-
t = [0] * self.nnodes
|
117 |
-
max_mem = self.config.max_mem or self.get_mem(self.get_id(0, 0, 0)) * self.nmb * 3
|
118 |
-
comm = self.config.cost_comm
|
119 |
-
order_str = [""] * self.nstages
|
120 |
-
stage_bubble = [0] * self.nstages
|
121 |
-
|
122 |
-
def get_max_bubble():
|
123 |
-
max_bubble = 0
|
124 |
-
for bb in stage_bubble:
|
125 |
-
max_bubble = max(max_bubble, bb)
|
126 |
-
return max_bubble
|
127 |
-
|
128 |
-
def put(stage_j, type_k):
|
129 |
-
if type_k == 0:
|
130 |
-
_i = f[stage_j]
|
131 |
-
elif type_k == 1:
|
132 |
-
_i = b[stage_j]
|
133 |
-
else:
|
134 |
-
_i = w[stage_j]
|
135 |
-
_j = stage_j
|
136 |
-
_id = self.get_id(type_k, _j, _i)
|
137 |
-
_mem = self.get_mem(_id)
|
138 |
-
_cost = self.get_cost(_id)
|
139 |
-
assert m[_j] + _mem <= max_mem
|
140 |
-
|
141 |
-
tmp = e[_j] + _cost
|
142 |
-
no_bubble = tmp
|
143 |
-
if _j > 0 and type_k == 0:
|
144 |
-
tmp = max(tmp, t[self.get_id(0, _j - 1, _i)] + comm + _cost)
|
145 |
-
if _j < self.nstages - 1 and type_k == 1:
|
146 |
-
tmp = max(tmp, t[self.get_id(1, _j + 1, _i)] + comm + _cost)
|
147 |
-
if f[_j] > 0:
|
148 |
-
stage_bubble[_j] += tmp - no_bubble
|
149 |
-
e[_j] = tmp
|
150 |
-
t[_id] = tmp
|
151 |
-
m[_j] += _mem
|
152 |
-
order[_id] = o[_j]
|
153 |
-
if type_k == 0:
|
154 |
-
f[_j] += 1
|
155 |
-
elif type_k == 1:
|
156 |
-
b[_j] += 1
|
157 |
-
else:
|
158 |
-
w[_j] += 1
|
159 |
-
o[_j] += 1
|
160 |
-
fbw = "fbw"
|
161 |
-
order_str[stage_j] += fbw[type_k]
|
162 |
-
|
163 |
-
for i in range(self.nmb):
|
164 |
-
if i == 0:
|
165 |
-
for j in range(self.nstages):
|
166 |
-
put(j, 0)
|
167 |
-
f_required = [0] * self.nstages
|
168 |
-
last_t = 0
|
169 |
-
for j in range(self.nstages - 1, -1, -1):
|
170 |
-
if j == self.nstages - 1:
|
171 |
-
last_t = t[self.get_id(0, j, i)] + self.get_cost(self.get_id(1, j, i))
|
172 |
-
continue
|
173 |
-
mem = m[j]
|
174 |
-
cost = e[j]
|
175 |
-
while True:
|
176 |
-
f_id = self.get_id(0, j, f[j] + f_required[j])
|
177 |
-
if f[j] + f_required[j] < self.nmb and mem + self.get_mem(f_id) <= max_mem:
|
178 |
-
if allow_bubble_before_first_b:
|
179 |
-
if cost + self.get_cost(f_id) > last_t + comm:
|
180 |
-
break
|
181 |
-
else:
|
182 |
-
if cost >= last_t + comm:
|
183 |
-
break
|
184 |
-
mem += self.get_mem(f_id)
|
185 |
-
cost += self.get_cost(f_id)
|
186 |
-
f_required[j] += 1
|
187 |
-
else:
|
188 |
-
break
|
189 |
-
last_t = max(cost, last_t + comm) + self.get_cost(self.get_id(1, j, i))
|
190 |
-
for j in range(self.nstages):
|
191 |
-
while j > 0 and f_required[j] > 0 and f_required[j] >= f_required[j - 1] and f[j] + f_required[j] < self.nmb:
|
192 |
-
f_required[j] -= 1
|
193 |
-
for j in range(self.nstages - 1, -1, -1):
|
194 |
-
for _ in range(f_required[j]):
|
195 |
-
put(j, 0)
|
196 |
-
put(j, 1)
|
197 |
-
continue
|
198 |
-
f_required = [0] * self.nstages
|
199 |
-
for j in range(self.nstages):
|
200 |
-
if f[j] >= self.nmb:
|
201 |
-
continue
|
202 |
-
if j + 1 < self.nstages and f[j] >= f[j + 1] + 2 and prioritize_b:
|
203 |
-
next_plus_fw = (
|
204 |
-
e[j + 1]
|
205 |
-
+ self.get_cost(self.get_id(0, j + 1, f[j + 1]))
|
206 |
-
+ self.get_cost(self.get_id(1, j + 1, b[j + 1]))
|
207 |
-
+ comm
|
208 |
-
)
|
209 |
-
if e[j] >= next_plus_fw:
|
210 |
-
continue
|
211 |
-
f_id = self.get_id(0, j, f[j])
|
212 |
-
f_mem = self.get_mem(f_id)
|
213 |
-
w_cost, w_cnt = 0, 0
|
214 |
-
mem_with_w = m[j] + f_mem
|
215 |
-
while mem_with_w > max_mem and w[j] + w_cnt < b[j]:
|
216 |
-
w_id = self.get_id(2, j, w[j] + w_cnt)
|
217 |
-
w_cost += self.get_cost(w_id)
|
218 |
-
mem_with_w += self.get_mem(w_id)
|
219 |
-
w_cnt += 1
|
220 |
-
if e[j] + self.get_cost(f_id) + w_cost <= next_plus_fw:
|
221 |
-
f_required[j] = 1
|
222 |
-
continue
|
223 |
-
|
224 |
-
w_cost, w_cnt = 0, 0
|
225 |
-
# mem_with_w = m[j]
|
226 |
-
# while w[j] + w_cnt < b[j]:
|
227 |
-
# w_id = self.get_id(2, j, w[j] + w_cnt)
|
228 |
-
# w_cost += self.get_cost(w_id)
|
229 |
-
# mem_with_w += self.get_mem(w_id)
|
230 |
-
# w_cnt += 1
|
231 |
-
# if e[j] + w_cost >= next_plus_fw:
|
232 |
-
# continue
|
233 |
-
if next_plus_fw - (e[j] + w_cost) <= get_max_bubble() - stage_bubble[j]:
|
234 |
-
# TODO: can sample here
|
235 |
-
continue
|
236 |
-
f_required[j] = 1
|
237 |
-
for j in range(self.nstages - 2, -1, -1):
|
238 |
-
f_required[j] = min(f_required[j], f_required[j + 1])
|
239 |
-
for j in range(self.nstages):
|
240 |
-
if f_required[j] == 0:
|
241 |
-
continue
|
242 |
-
f_id = self.get_id(0, j, f[j])
|
243 |
-
mem = self.get_mem(f_id)
|
244 |
-
while m[j] + mem > max_mem:
|
245 |
-
if w[j] >= b[j]:
|
246 |
-
raise ValueError("Cannot fit memory")
|
247 |
-
put(j, 2)
|
248 |
-
if j > 0:
|
249 |
-
while (
|
250 |
-
w[j] < b[j]
|
251 |
-
and e[j] + self.get_cost(self.get_id(2, j, w[j]))
|
252 |
-
<= t[self.get_id(0, j - 1, f[j])] + comm
|
253 |
-
):
|
254 |
-
put(j, 2)
|
255 |
-
if w[j] < b[j] and e[j] < t[self.get_id(0, j - 1, f[j])] + comm:
|
256 |
-
# TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(0, j - 1, f[j])] + comm
|
257 |
-
if (
|
258 |
-
t[self.get_id(0, j - 1, f[j])] + comm - e[j]
|
259 |
-
<= get_max_bubble() - stage_bubble[j]
|
260 |
-
):
|
261 |
-
# TODO: can sample here
|
262 |
-
if no_bubble_greedy:
|
263 |
-
put(j, 2)
|
264 |
-
else:
|
265 |
-
put(j, 2)
|
266 |
-
put(j, 0)
|
267 |
-
for j in range(self.nstages - 1, -1, -1):
|
268 |
-
assert b[j] == i
|
269 |
-
b_id = self.get_id(1, j, b[j])
|
270 |
-
mem = self.get_mem(b_id)
|
271 |
-
while m[j] + mem > max_mem:
|
272 |
-
if w[j] >= b[j]:
|
273 |
-
raise ValueError("Cannot fit memory")
|
274 |
-
put(j, 2)
|
275 |
-
if j + 1 < self.nstages:
|
276 |
-
while (
|
277 |
-
w[j] < b[j]
|
278 |
-
and e[j] + self.get_cost(self.get_id(2, j, w[j]))
|
279 |
-
<= t[self.get_id(1, j + 1, i)] + comm
|
280 |
-
):
|
281 |
-
put(j, 2)
|
282 |
-
if w[j] < b[j] and e[j] < t[self.get_id(1, j + 1, i)] + comm:
|
283 |
-
# TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(1, j + 1, i)] + comm
|
284 |
-
if (
|
285 |
-
t[self.get_id(1, j + 1, i)] + comm - e[j]
|
286 |
-
<= get_max_bubble() - stage_bubble[j]
|
287 |
-
):
|
288 |
-
# TODO: can sample here
|
289 |
-
if no_bubble_greedy:
|
290 |
-
put(j, 2)
|
291 |
-
else:
|
292 |
-
put(j, 2)
|
293 |
-
if j == 0 and f[j] == self.nmb:
|
294 |
-
while w[j] < b[j]:
|
295 |
-
put(j, 2)
|
296 |
-
put(j, 1)
|
297 |
-
|
298 |
-
for i in range(self.nstages):
|
299 |
-
while w[i] < self.nmb:
|
300 |
-
put(i, 2)
|
301 |
-
# print(f"{' ' * i}{order_str[i]} -> {e[i]}")
|
302 |
-
|
303 |
-
for i in range(self.nstages):
|
304 |
-
for j in range(self.nmb):
|
305 |
-
f_id = self.get_id(0, i, j)
|
306 |
-
b_id = self.get_id(1, i, j)
|
307 |
-
w_id = self.get_id(2, i, j)
|
308 |
-
f_cost = self.get_cost(f_id)
|
309 |
-
b_cost = self.get_cost(b_id)
|
310 |
-
w_cost = self.get_cost(w_id)
|
311 |
-
assert t[b_id] >= t[f_id] + b_cost
|
312 |
-
assert t[w_id] >= t[b_id] + w_cost, f"{i}-{j}, {t[w_id]} >= {t[b_id]} + {w_cost}"
|
313 |
-
if i > 0:
|
314 |
-
assert t[f_id] >= t[self.get_id(0, i - 1, j)] + comm + f_cost, f"{i}-{j}"
|
315 |
-
if i < self.nstages - 1:
|
316 |
-
assert t[b_id] >= t[self.get_id(1, i + 1, j)] + comm + b_cost
|
317 |
-
|
318 |
-
# print(order)
|
319 |
-
best_time = 0
|
320 |
-
for i in range(self.nstages):
|
321 |
-
time_i = (
|
322 |
-
t[self.get_id(2, i, self.nmb - 1)]
|
323 |
-
- t[self.get_id(0, i, 0)]
|
324 |
-
+ self.get_cost(self.get_id(0, i, 0))
|
325 |
-
)
|
326 |
-
best_time = max(best_time, time_i)
|
327 |
-
|
328 |
-
return order, t, best_time
|
329 |
-
|
330 |
-
|
331 |
-
def initial_solution(graph):
|
332 |
-
best_time, order, complete_time = None, None, None
|
333 |
-
for allow_bubble_before_first_b in [True, False]:
|
334 |
-
for prioritize_b in [True, False]:
|
335 |
-
for no_bubble_greedy in [True, False]:
|
336 |
-
order_t, complete_time_t, best_time_t = graph.manual_order(
|
337 |
-
allow_bubble_before_first_b=allow_bubble_before_first_b,
|
338 |
-
prioritize_b=prioritize_b,
|
339 |
-
no_bubble_greedy=no_bubble_greedy,
|
340 |
-
)
|
341 |
-
if best_time is None or best_time_t < best_time:
|
342 |
-
best_time = best_time_t
|
343 |
-
order = order_t
|
344 |
-
complete_time = complete_time_t
|
345 |
-
|
346 |
-
print_detail(graph, complete_time)
|
347 |
-
print("-" * 20, best_time, "-" * 20)
|
348 |
-
return best_time, order, complete_time
|
349 |
-
|
350 |
-
|
351 |
-
def print_detail(graph, F):
|
352 |
-
typenames = ['F', 'B', 'W']
|
353 |
-
times = []
|
354 |
-
for stage in range(graph.nstages):
|
355 |
-
stage_str = ['.'] * int(F[graph.get_id(2, stage, graph.nmb - 1)] / graph.config.print_scaling)
|
356 |
-
for _type in range(3):
|
357 |
-
for _mb in range(graph.nmb):
|
358 |
-
_id = graph.get_id(_type, stage, _mb)
|
359 |
-
end = int(F[_id] / graph.config.print_scaling)
|
360 |
-
start = int((F[_id] - graph.get_cost(_id)) / graph.config.print_scaling)
|
361 |
-
for j in range(start, end):
|
362 |
-
if j == start or j == end - 1:
|
363 |
-
stage_str[j] = typenames[_type]
|
364 |
-
elif j == start + 1:
|
365 |
-
if _mb >= 10:
|
366 |
-
stage_str[j] = str(_mb // 10)
|
367 |
-
else:
|
368 |
-
stage_str[j] = str(_mb)
|
369 |
-
elif j == start + 2 and _mb >= 10:
|
370 |
-
stage_str[j] = str(_mb % 10)
|
371 |
-
else:
|
372 |
-
stage_str[j] = "-"
|
373 |
-
_str = ""
|
374 |
-
for _c in stage_str:
|
375 |
-
_str += _c
|
376 |
-
times.append(
|
377 |
-
F[graph.get_id(2, stage, graph.nmb - 1)]
|
378 |
-
- F[graph.get_id(0, stage, 0)]
|
379 |
-
+ graph.get_cost(graph.get_id(0, stage, 0))
|
380 |
-
)
|
381 |
-
print(_str)
|
382 |
-
print('Longest stage time: ', max(times))
|
383 |
-
|
384 |
-
|
385 |
-
def ilp_results(graph, F):
|
386 |
-
typenames = ['F', 'B', 'W']
|
387 |
-
local_order = []
|
388 |
-
end_time = []
|
389 |
-
for i in range(graph.nnodes):
|
390 |
-
end_time.append(F[i])
|
391 |
-
for stage in range(graph.nstages):
|
392 |
-
order = []
|
393 |
-
for type in range(3):
|
394 |
-
for mb in range(graph.nmb):
|
395 |
-
id = graph.get_id(type, stage, mb)
|
396 |
-
order.append(
|
397 |
-
ScheduledNode(
|
398 |
-
type=typenames[type],
|
399 |
-
stage=stage,
|
400 |
-
minibatch=mb,
|
401 |
-
start_time=end_time[id] - graph.get_cost(id),
|
402 |
-
completion_time=F[id],
|
403 |
-
)
|
404 |
-
)
|
405 |
-
local_order.append(order)
|
406 |
-
# For each F/B, append a send/recv node. The timestamp of recv node is the same as send node to guarrentee a global order.
|
407 |
-
comm_id = {}
|
408 |
-
comm_id_counter = 0
|
409 |
-
post_validation_time = 0
|
410 |
-
for i in range(graph.nstages - 1, -1, -1):
|
411 |
-
warmup_f_count = -1
|
412 |
-
first_b_end = end_time[graph.get_id(1, i, 0)]
|
413 |
-
for j in range(graph.nmb):
|
414 |
-
if end_time[graph.get_id(0, i, j)] < first_b_end:
|
415 |
-
warmup_f_count += 1
|
416 |
-
assert warmup_f_count >= 0
|
417 |
-
pv_id = warmup_f_count
|
418 |
-
_id = graph.get_id(0, i, pv_id)
|
419 |
-
_cost = graph.get_cost(_id)
|
420 |
-
post_validation_time = max(post_validation_time, end_time[_id] - _cost - graph.config.cost_comm)
|
421 |
-
# post_validation_time = 0
|
422 |
-
# print(i, pv_id, post_validation_time)
|
423 |
-
for it in ["RECV_", "SEND_", ""]:
|
424 |
-
if i == 0 and it == "SEND_":
|
425 |
-
continue
|
426 |
-
if i == graph.nstages - 1 and it == "RECV_":
|
427 |
-
continue
|
428 |
-
# stage_ = i - 1 if it == "RECV_" else i
|
429 |
-
stage_ = i
|
430 |
-
local_order[stage_].append(ScheduledNode(
|
431 |
-
type=it + "POST_VALIDATION",
|
432 |
-
stage=stage_,
|
433 |
-
minibatch=0,
|
434 |
-
start_time=post_validation_time,
|
435 |
-
completion_time=post_validation_time,
|
436 |
-
))
|
437 |
-
comm_id[local_order[stage_][-1]] = comm_id_counter
|
438 |
-
comm_id_counter += 1
|
439 |
-
for stage in range(graph.nstages):
|
440 |
-
for node in local_order[stage]:
|
441 |
-
if node.type == 'F' and node.stage != graph.nstages - 1:
|
442 |
-
local_order[stage].append(
|
443 |
-
ScheduledNode(
|
444 |
-
type='SEND_FORWARD',
|
445 |
-
stage=stage,
|
446 |
-
minibatch=node.minibatch,
|
447 |
-
start_time=node.completion_time,
|
448 |
-
completion_time=node.completion_time, # TODO: consider comm cost in completion time
|
449 |
-
)
|
450 |
-
)
|
451 |
-
local_order[stage + 1].append(
|
452 |
-
ScheduledNode(
|
453 |
-
type='RECV_FORWARD',
|
454 |
-
stage=stage + 1,
|
455 |
-
minibatch=node.minibatch,
|
456 |
-
start_time=node.completion_time,
|
457 |
-
completion_time=node.completion_time, # TODO: consider comm cost in completion time
|
458 |
-
)
|
459 |
-
)
|
460 |
-
comm_id[local_order[stage][-1]] = comm_id_counter
|
461 |
-
comm_id[local_order[stage + 1][-1]] = comm_id_counter
|
462 |
-
comm_id_counter += 1
|
463 |
-
if node.type == 'B' and node.stage != 0:
|
464 |
-
local_order[stage].append(
|
465 |
-
ScheduledNode(
|
466 |
-
type='SEND_BACKWARD',
|
467 |
-
stage=stage,
|
468 |
-
minibatch=node.minibatch,
|
469 |
-
start_time=node.completion_time,
|
470 |
-
completion_time=node.completion_time, # TODO: consider comm cost in completion time
|
471 |
-
)
|
472 |
-
)
|
473 |
-
local_order[stage - 1].append(
|
474 |
-
ScheduledNode(
|
475 |
-
type='RECV_BACKWARD',
|
476 |
-
stage=stage - 1,
|
477 |
-
minibatch=node.minibatch,
|
478 |
-
start_time=node.completion_time,
|
479 |
-
completion_time=node.completion_time, # TODO: consider comm cost in completion time
|
480 |
-
)
|
481 |
-
)
|
482 |
-
comm_id[local_order[stage][-1]] = comm_id_counter
|
483 |
-
comm_id[local_order[stage - 1][-1]] = comm_id_counter
|
484 |
-
comm_id_counter += 1
|
485 |
-
for stage in range(graph.nstages):
|
486 |
-
# For nodes with the same timestamp on the same stage, communication will be prioritized.
|
487 |
-
def even_breaker(x: ScheduledNode):
|
488 |
-
# Compute nodes are always delayed.
|
489 |
-
if x.type in ['F', 'B', 'W']:
|
490 |
-
return comm_id_counter
|
491 |
-
# For comm nodes, order by their unique comm id
|
492 |
-
return comm_id[x]
|
493 |
-
|
494 |
-
local_order[stage] = list(sorted(
|
495 |
-
local_order[stage], key=lambda x: (x.start_time, even_breaker(x))
|
496 |
-
))
|
497 |
-
# If a recv with intersects with previous computation, reorder them so that recv
|
498 |
-
# is executed before computation and hence can be overlapped.
|
499 |
-
for i in range(len(local_order[stage])):
|
500 |
-
if i > 0 and local_order[stage][i - 1].type in {'F', 'B', 'W'} and \
|
501 |
-
local_order[stage][i].type.startswith('RECV') and \
|
502 |
-
"POST_VALIDATION" not in local_order[stage][i].type and \
|
503 |
-
local_order[stage][i].start_time <= local_order[stage][i - 1].completion_time:
|
504 |
-
(local_order[stage][i], local_order[stage][i - 1]) = (local_order[stage][i - 1], local_order[stage][i])
|
505 |
-
# print([(x.type, x.start_time, x.completion_time) for x in local_order[stage]])
|
506 |
-
|
507 |
-
local_order_with_rollback = [[] for _ in range(graph.nstages)]
|
508 |
-
for rank in range(graph.nstages):
|
509 |
-
rollback_comm = set()
|
510 |
-
if rank > 0:
|
511 |
-
for node in local_order[rank - 1]:
|
512 |
-
if node.type == "POST_VALIDATION":
|
513 |
-
break
|
514 |
-
if node.type == "SEND_FORWARD":
|
515 |
-
rollback_comm.add(node.minibatch)
|
516 |
-
for node in local_order[rank]:
|
517 |
-
if node.type == "RECV_FORWARD" and node.minibatch in rollback_comm:
|
518 |
-
rollback = True
|
519 |
-
rollback_comm.remove(node.minibatch)
|
520 |
-
else:
|
521 |
-
rollback = False
|
522 |
-
local_order_with_rollback[rank].append(ScheduledNode(
|
523 |
-
type=node.type,
|
524 |
-
stage=node.stage,
|
525 |
-
minibatch=node.minibatch,
|
526 |
-
start_time=node.start_time,
|
527 |
-
completion_time=node.completion_time,
|
528 |
-
rollback=rollback,
|
529 |
-
))
|
530 |
-
assert len(rollback_comm) == 0
|
531 |
-
# for node in local_order_with_rollback[rank]:
|
532 |
-
# print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
|
533 |
-
# print()
|
534 |
-
|
535 |
-
print_detail(graph, end_time)
|
536 |
-
return local_order_with_rollback
|
537 |
-
|
538 |
-
|
539 |
-
def auto_schedule(nstages, nmb, config):
|
540 |
-
graph = Graph.build_graph(nstages, nmb, config)
|
541 |
-
|
542 |
-
best_time, order, complete_time = initial_solution(graph)
|
543 |
-
return ilp_results(graph, complete_time)
|
544 |
-
|
545 |
-
|
546 |
-
if __name__ == "__main__":
|
547 |
-
# auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
|
548 |
-
# auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=14))
|
549 |
-
auto_schedule(24, 72, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=100))
|
550 |
-
auto_schedule(4, 12, GraphConfig(
|
551 |
-
cost_f=5478,
|
552 |
-
cost_b=5806,
|
553 |
-
cost_w=3534,
|
554 |
-
cost_comm=200,
|
555 |
-
max_mem=32,
|
556 |
-
print_scaling=1000
|
557 |
-
))
|
558 |
-
auto_schedule(32, 16, GraphConfig(
|
559 |
-
cost_f=1,
|
560 |
-
cost_b=1,
|
561 |
-
cost_w=1,
|
562 |
-
cost_comm=0,
|
563 |
-
max_mem=64,
|
564 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
description1.md
CHANGED
@@ -1,11 +1,5 @@
|
|
1 |
-
#
|
2 |
|
3 |
-
|
4 |
|
5 |
-
|
6 |
-
* [Arxiv Version with ZBV](https://arxiv.org/abs/2401.10241)
|
7 |
-
* [ICLR Accepted version with ZB1P and ZB2P](https://openreview.net/pdf?id=tuzTN0eIO5)
|
8 |
-
|
9 |
-
Try out our implementation based on Megatron on [https://github.com/sail-sg/zero-bubble-pipeline-parallelism](https://github.com/sail-sg/zero-bubble-pipeline-parallelism)
|
10 |
-
|
11 |
-
Experiments shows zero bubble pipeline parallelism can accelerate training up to 30% with a similar memory comsumption. A detailed table of experiments is coming soon.
|
|
|
1 |
+
# Pipeline Parallellism with Controllable Memory
|
2 |
|
3 |
+
Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
|
4 |
|
5 |
+
Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
|
|
|
|
|
|
|
|
|
|
|
|
description2.md
CHANGED
@@ -1,33 +1,6 @@
|
|
1 |
-
##
|
2 |
-
The key of achieving zero bubble is to breaking a backward pass into a B pass and W pass. B on one stage will only depend on the B on its next stage, compared to depending on both B and W of in 1F1B.
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
* 1F1B
|
8 |
-
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/Q3yxf4BQIESQ_M7lKKlhf.png)
|
9 |
-
* ZB1P
|
10 |
-
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/EcTFvbjfM7soUXDYyn1Xu.png)
|
11 |
-
* ZB2P
|
12 |
-
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/8jFI_rO69BREKqiSFHIOL.png)
|
13 |
-
* ZBV - Each device is assigned to exactly 2 chunks (virtual stages), where white text colors represent the first chunk and black text colors represent the second chunk. The sequence of dependencies among model chunks follows a ”V” shape pattern for both the forward and backward passes.
|
14 |
-
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/VRfjNVXakAU3MQK3h6OKa.png)
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
| Comparison assuming T_F=T_B=T_W | 1F1B | ZB1P | ZB2P | ZBV (Recommended) |
|
19 |
-
| ----------------------------------------------------- | ------- | -------- | ---- | --- |
|
20 |
-
| Bubble Rate | (p-1)/m | (p-1)/3m | 0 | 0 |
|
21 |
-
| Activation Memory <br> (Compared to 1F1B) | 1x | 1x | 2x | 1x |
|
22 |
-
| Pipeline Communication Volume <br> (Compared to 1F1B) | 1x | 1x | 1x | 2x |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
## Optimizer Post Validation
|
27 |
-
|
28 |
-
In most practices of PP there's an all-reduce cross all pipeline stages for numerical robustness, e.g. global gradient norm for gradient clipping. INF/NAN check for mixed precision training, etc. This all-reduce breaks parallelogram and makes zero bubble impossible.
|
29 |
-
Under the observation that during a stable training both the gradient clipping and INF/NAN rarely triggers, we replace the before-hand synchronizations with a post update validation.
|
30 |
-
|
31 |
-
![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/hRPFqaFxJ20wm2omwyKmO.png)
|
32 |
-
|
33 |
-
We eagerly step the optimizers assuming the grad cliping, INF/NAN conditions are not triggered. In case an amendment to the gradient is required, a rollback will be issued and then we redo the optimizer step based on the fully reduced global state.
|
|
|
1 |
+
## Alternative schedules
|
|
|
2 |
|
3 |
+
By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate few of them here below:
|
4 |
+
* 1F1B-V schedule without doing any B-W split.
|
5 |
+
* Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
|
6 |
+
* Variation of interleaved 1F1B with lower memory
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
interleaved_variant.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
@dataclass(eq=True, frozen=True)
|
4 |
+
class ScheduledNode:
|
5 |
+
type: str
|
6 |
+
chunk: int
|
7 |
+
stage: int
|
8 |
+
minibatch: int
|
9 |
+
start_time: int
|
10 |
+
completion_time: int
|
11 |
+
|
12 |
+
def get_interleaved_variation(_p, _n, cost):
|
13 |
+
_f, _b, _w, _c = cost
|
14 |
+
schedule = []
|
15 |
+
local_prev = {}
|
16 |
+
|
17 |
+
f_order = []
|
18 |
+
b_order = []
|
19 |
+
|
20 |
+
left = [_n, _n]
|
21 |
+
for id in range(min(_n, _p)):
|
22 |
+
f_order.append(('F', id))
|
23 |
+
for id in range(min(_n, _p)):
|
24 |
+
f_order.append(('f', id))
|
25 |
+
|
26 |
+
left = [max(0, _n - _p), max(0, _n - _p)]
|
27 |
+
|
28 |
+
i = 0
|
29 |
+
cur = 0
|
30 |
+
for id in range(min(_n, _p)):
|
31 |
+
b_order.append(('B', id))
|
32 |
+
while left[0] > 0 or left[1] > 0:
|
33 |
+
if i >= _p and left[1 - cur] > 0:
|
34 |
+
cur = 1 - cur
|
35 |
+
if left[cur] > 0:
|
36 |
+
if cur == 0:
|
37 |
+
f_order.append(('F', _n - left[cur]))
|
38 |
+
b_order.append(('b', _n - left[cur] - _p))
|
39 |
+
else:
|
40 |
+
f_order.append(('f', _n - left[cur]))
|
41 |
+
b_order.append(('B', _n - left[cur]))
|
42 |
+
left[cur] -= 1
|
43 |
+
i += 3
|
44 |
+
for id in range(min(_n, _p)):
|
45 |
+
b_order.append(('b', _n - _p + id))
|
46 |
+
|
47 |
+
for stage in range(_p):
|
48 |
+
diff = min(_p + _p - stage, len(f_order))
|
49 |
+
stage_schedule = []
|
50 |
+
for i in range(diff):
|
51 |
+
stage_schedule.append(f_order[i])
|
52 |
+
for i in range(len(f_order) - diff):
|
53 |
+
stage_schedule.append(b_order[i])
|
54 |
+
stage_schedule.append(f_order[i + diff])
|
55 |
+
for i in range(diff):
|
56 |
+
stage_schedule.append(b_order[len(b_order) - diff + i])
|
57 |
+
for i in range(len(stage_schedule) - 1):
|
58 |
+
local_prev[(stage, *stage_schedule[i + 1])] = (stage, *stage_schedule[i])
|
59 |
+
schedule.append(stage_schedule)
|
60 |
+
# print(stage_schedule)
|
61 |
+
# return None
|
62 |
+
cost = {
|
63 |
+
'F': _f,
|
64 |
+
'f': _f,
|
65 |
+
'B': _b+_w,
|
66 |
+
'b': _b+_w
|
67 |
+
}
|
68 |
+
pred = {
|
69 |
+
'f': 'F',
|
70 |
+
'B': 'f',
|
71 |
+
'b': 'B'
|
72 |
+
}
|
73 |
+
|
74 |
+
time_map = {}
|
75 |
+
def get_time(stage, type, minibatch):
|
76 |
+
if (stage, type, minibatch) in time_map:
|
77 |
+
return time_map.get((stage, type, minibatch))
|
78 |
+
time = 0
|
79 |
+
if (stage, type, minibatch) in local_prev:
|
80 |
+
time = get_time(*local_prev[(stage, type, minibatch)])
|
81 |
+
if stage > 0 and type in ('F', 'f'):
|
82 |
+
time = max(time, get_time(stage - 1, type, minibatch) + _c)
|
83 |
+
if stage == 0 and type in ('f'):
|
84 |
+
time = max(time, get_time(_p - 1, pred[type], minibatch) + _c)
|
85 |
+
if stage != _p - 1 and type in ('B', 'b'):
|
86 |
+
time = max(time, get_time(stage + 1, type, minibatch) + _c)
|
87 |
+
if stage == _p - 1 and type in ('b'):
|
88 |
+
time = max(time, get_time(0, pred[type], minibatch) + _c)
|
89 |
+
if stage == _p - 1 and type in ('B'):
|
90 |
+
time = max(time, get_time(stage, pred[type], minibatch))
|
91 |
+
|
92 |
+
time_map[(stage, type, minibatch)] = time + cost[type]
|
93 |
+
return time_map[(stage, type, minibatch)]
|
94 |
+
result = []
|
95 |
+
for sid, stage in enumerate(schedule):
|
96 |
+
result_stage = []
|
97 |
+
for type, minibatch in stage:
|
98 |
+
result_stage.append(ScheduledNode(
|
99 |
+
type.upper(),
|
100 |
+
type in ('f', 'B', 'W'),
|
101 |
+
sid,
|
102 |
+
minibatch,
|
103 |
+
get_time(sid, type, minibatch) - cost[type],
|
104 |
+
get_time(sid, type, minibatch)
|
105 |
+
))
|
106 |
+
result.append(result_stage)
|
107 |
+
return result
|
schedule1f1bv.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pattern_size = 6
|
2 |
+
from collections import Counter
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
@dataclass(eq=True, frozen=True)
|
6 |
+
class ScheduledNode:
|
7 |
+
type: str
|
8 |
+
chunk: int
|
9 |
+
stage: int
|
10 |
+
minibatch: int
|
11 |
+
start_time: int
|
12 |
+
completion_time: int
|
13 |
+
|
14 |
+
def transform_schedule(schedule, f, b, w, c):
|
15 |
+
result = []
|
16 |
+
|
17 |
+
stage_order = []
|
18 |
+
local_prev = {}
|
19 |
+
stages = len(schedule)
|
20 |
+
|
21 |
+
for sid, stage in enumerate(schedule):
|
22 |
+
counter = Counter()
|
23 |
+
order = []
|
24 |
+
for p in stage:
|
25 |
+
if not p.strip():
|
26 |
+
continue
|
27 |
+
mb = counter.get(p, 0)
|
28 |
+
if order:
|
29 |
+
local_prev[(sid, p, mb)] = order[-1]
|
30 |
+
order.append((p, mb))
|
31 |
+
counter.update(p)
|
32 |
+
stage_order.append(order)
|
33 |
+
nmb = max(counter.values())
|
34 |
+
time_map = {}
|
35 |
+
cost = {
|
36 |
+
'F': f,
|
37 |
+
'B': b + w,
|
38 |
+
'f': f,
|
39 |
+
'b': b + w,
|
40 |
+
}
|
41 |
+
def get_time(stage, type, mb):
|
42 |
+
if (stage, type, mb) in time_map:
|
43 |
+
return time_map.get((stage, type, mb))
|
44 |
+
time = 0
|
45 |
+
if (stage, type, mb) in local_prev:
|
46 |
+
time = get_time(stage, *local_prev[(stage, type, mb)])
|
47 |
+
if type in ('F', 'B') and stage > 0:
|
48 |
+
time = max(time, get_time(stage - 1, type, mb) + c)
|
49 |
+
if type in ('f', 'b') and stage + 1< len(schedule):
|
50 |
+
time = max(time, get_time(stage + 1, type, mb) + c)
|
51 |
+
time_map[(stage, type, mb)] = time + cost[type]
|
52 |
+
return time_map[(stage, type, mb)]
|
53 |
+
r = 0
|
54 |
+
for sid, stage in enumerate(schedule):
|
55 |
+
r = max(get_time(sid, 'b', nmb - 1) - get_time(sid, 'F', 0) + f, r)
|
56 |
+
|
57 |
+
for sid, stage in enumerate(stage_order):
|
58 |
+
result_stage = []
|
59 |
+
for p, mb in stage:
|
60 |
+
result_stage.append(ScheduledNode(
|
61 |
+
p.upper(),
|
62 |
+
p in ('f', 'B', 'W'),
|
63 |
+
sid,
|
64 |
+
mb,
|
65 |
+
get_time(sid, p, mb) - cost[p],
|
66 |
+
get_time(sid, p, mb)
|
67 |
+
)
|
68 |
+
)
|
69 |
+
result.append(result_stage)
|
70 |
+
return result
|
71 |
+
|
72 |
+
|
73 |
+
def get_pattern_str(pos):
|
74 |
+
pattern = [" "] * pattern_size
|
75 |
+
notations = "FfBbWw"
|
76 |
+
for i, v in enumerate(pos):
|
77 |
+
if v < 0:
|
78 |
+
continue
|
79 |
+
pattern[v] = notations[i]
|
80 |
+
_str = ""
|
81 |
+
for v in pattern:
|
82 |
+
_str += v
|
83 |
+
return _str
|
84 |
+
|
85 |
+
def init_repeated_schedule(p, m, patterns):
|
86 |
+
repeated = []
|
87 |
+
_len = 4 * p + m + 1
|
88 |
+
for i in range(p):
|
89 |
+
str_i = get_pattern_str(patterns[i]) * _len
|
90 |
+
repeated_i = []
|
91 |
+
for v in str_i:
|
92 |
+
repeated_i.append(v)
|
93 |
+
repeated.append(repeated_i)
|
94 |
+
return repeated
|
95 |
+
|
96 |
+
|
97 |
+
def clear_invalid(repeated, stage, pos, offset=-1):
|
98 |
+
while 0 <= pos < len(repeated[stage]):
|
99 |
+
repeated[stage][pos] = ' '
|
100 |
+
pos += offset * pattern_size
|
101 |
+
return repeated
|
102 |
+
|
103 |
+
|
104 |
+
def clear_invalid_index(repeated, m):
|
105 |
+
p = len(repeated)
|
106 |
+
index = pattern_size
|
107 |
+
for identifier in ['F', 'f', 'B', 'b']:
|
108 |
+
if identifier in ['F', 'B']:
|
109 |
+
_iter = range(p)
|
110 |
+
else:
|
111 |
+
_iter = range(p - 1, -1, -1)
|
112 |
+
for i in _iter:
|
113 |
+
for j in range(pattern_size):
|
114 |
+
if repeated[i][index] == identifier:
|
115 |
+
clear_invalid(repeated, i, index - pattern_size, offset=-1)
|
116 |
+
clear_invalid(repeated, i, index + pattern_size * m, offset=1)
|
117 |
+
index += 1
|
118 |
+
if identifier in ['B', 'b']:
|
119 |
+
w_identifier = {'B': 'W', 'b': 'w'}[identifier]
|
120 |
+
for k in range(pattern_size):
|
121 |
+
if repeated[i][index + k] == w_identifier:
|
122 |
+
clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
|
123 |
+
clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
|
124 |
+
break
|
125 |
+
break
|
126 |
+
index += 1
|
127 |
+
return repeated
|
128 |
+
|
129 |
+
|
130 |
+
def process_warmup_without_increasing_peak_mem(schedules, m):
|
131 |
+
peak_mem = 0
|
132 |
+
mem = [[0 for _ in range(len(schedules[0]))] for _ in range(len(schedules))]
|
133 |
+
loc = [[{key: -1 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(m + 2)] for _ in range(len(schedules))]
|
134 |
+
cntr = [{key: 0 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(len(schedules))]
|
135 |
+
for sid in range(len(schedules)):
|
136 |
+
cur = 0
|
137 |
+
for i in range(len(schedules[sid])):
|
138 |
+
if schedules[sid][i] in ('F', 'f'):
|
139 |
+
cur += 1
|
140 |
+
if schedules[sid][i] in ('W', 'w'):
|
141 |
+
cur -= 1
|
142 |
+
mem[sid][i] = cur
|
143 |
+
peak_mem = max(peak_mem, cur)
|
144 |
+
|
145 |
+
for i in range(len(schedules[0])):
|
146 |
+
for sid in range(len(schedules)):
|
147 |
+
if schedules[sid][i] == ' ':
|
148 |
+
continue
|
149 |
+
cntr[sid][schedules[sid][i]] += 1
|
150 |
+
cnt = cntr[sid][schedules[sid][i]]
|
151 |
+
pos = -1
|
152 |
+
if cnt > 1:
|
153 |
+
pos = loc[sid][cnt - 1][schedules[sid][i]]
|
154 |
+
if schedules[sid][i] == 'W':
|
155 |
+
pos = max(pos, loc[sid][cnt]['B'])
|
156 |
+
if schedules[sid][i] == 'w':
|
157 |
+
pos = max(pos, loc[sid][cnt]['b'])
|
158 |
+
if schedules[sid][i] == 'F' and sid > 0:
|
159 |
+
pos = max(pos, loc[sid - 1][cnt]['F'])
|
160 |
+
if schedules[sid][i] == 'f':
|
161 |
+
if sid != len(schedules) - 1:
|
162 |
+
pos = max(pos, loc[sid + 1][cnt]['f'])
|
163 |
+
else :
|
164 |
+
pos = max(pos, loc[sid][cnt]['F'])
|
165 |
+
if schedules[sid][i] == 'B':
|
166 |
+
if sid != 0:
|
167 |
+
#Because B and W are always combined
|
168 |
+
pos = max(pos, loc[sid - 1][cnt]['W'])
|
169 |
+
else :
|
170 |
+
pos = max(pos, loc[sid][cnt]['f'])
|
171 |
+
if schedules[sid][i] == 'b':
|
172 |
+
if sid != len(schedules) - 1:
|
173 |
+
#Because B and W are always combined
|
174 |
+
pos = max(pos, loc[sid + 1][cnt]['w'])
|
175 |
+
else :
|
176 |
+
pos = max(pos, loc[sid][cnt]['W'])
|
177 |
+
pos += 1
|
178 |
+
while schedules[sid][pos] != ' ' and pos < i:
|
179 |
+
pos += 1
|
180 |
+
if schedules[sid][i] in ('B', 'b'):
|
181 |
+
while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
|
182 |
+
pos += 1
|
183 |
+
if pos == i:
|
184 |
+
loc[sid][cnt][schedules[sid][i]] = i
|
185 |
+
continue
|
186 |
+
if schedules[sid][i] in ('B', 'b', 'W', 'w'):
|
187 |
+
schedules[sid][pos] = schedules[sid][i]
|
188 |
+
schedules[sid][i] = ' '
|
189 |
+
if schedules[sid][pos] in ('W', 'w'):
|
190 |
+
for j in range(pos, i):
|
191 |
+
mem[sid][j] -= 1
|
192 |
+
loc[sid][cnt][schedules[sid][pos]] = pos
|
193 |
+
continue
|
194 |
+
|
195 |
+
#If F or f:
|
196 |
+
|
197 |
+
place = i
|
198 |
+
while place > pos and mem[sid][place - 1] < peak_mem:
|
199 |
+
place -= 1
|
200 |
+
while place < i and schedules[sid][place] != ' ':
|
201 |
+
place += 1
|
202 |
+
if place == i:
|
203 |
+
loc[sid][cnt][schedules[sid][i]] = i
|
204 |
+
continue
|
205 |
+
pos = place
|
206 |
+
schedules[sid][pos] = schedules[sid][i]
|
207 |
+
schedules[sid][i] = ' '
|
208 |
+
for j in range(pos, i):
|
209 |
+
mem[sid][j] += 1
|
210 |
+
loc[sid][cnt][schedules[sid][pos]] = pos
|
211 |
+
return schedules
|
212 |
+
|
213 |
+
def schedule_by_pattern(p, m, patterns):
|
214 |
+
schedules = init_repeated_schedule(p, m, patterns)
|
215 |
+
schedules = clear_invalid_index(schedules, m)
|
216 |
+
|
217 |
+
schedules = process_warmup_without_increasing_peak_mem(schedules, m)
|
218 |
+
for sid in range(len(schedules)):
|
219 |
+
cnt = {_id: 0 for _id in "FfBbWw"}
|
220 |
+
for i in range(len(schedules[sid])):
|
221 |
+
if(schedules[sid][i] == ' '):
|
222 |
+
continue
|
223 |
+
if cnt[schedules[sid][i]] >= m:
|
224 |
+
schedules[sid][i] = ' '
|
225 |
+
else:
|
226 |
+
cnt[schedules[sid][i]] += 1
|
227 |
+
|
228 |
+
|
229 |
+
return schedules
|
230 |
+
|
231 |
+
def create_whole_pattern(p):
|
232 |
+
whole_pattern = [[0 for _ in range(6)] for _ in range(p)]
|
233 |
+
now = 0
|
234 |
+
for i in range(p):
|
235 |
+
now += 1
|
236 |
+
whole_pattern[i][0] = now
|
237 |
+
for i in range(p):
|
238 |
+
now += 1
|
239 |
+
whole_pattern[p - 1 - i][1] = now
|
240 |
+
now += 1
|
241 |
+
if p % 3 == 0:
|
242 |
+
now += 3
|
243 |
+
cyc = (3 - (p + 2) % 3) % 3
|
244 |
+
for i in range(p):
|
245 |
+
whole_pattern[i][2], whole_pattern[i][4] = now, now + 1
|
246 |
+
cyc += 1
|
247 |
+
now += 2
|
248 |
+
if(cyc == 3):
|
249 |
+
cyc = 0
|
250 |
+
now += 3
|
251 |
+
for i in range(p):
|
252 |
+
whole_pattern[p - 1 - i][3], whole_pattern[p - 1 - i][5] = now, now + 1
|
253 |
+
cyc += 1
|
254 |
+
now += 2
|
255 |
+
if(cyc == 3):
|
256 |
+
cyc = 0
|
257 |
+
now += 3
|
258 |
+
for sid in range(p):
|
259 |
+
for i in range(6):
|
260 |
+
whole_pattern[sid][i] %= 6
|
261 |
+
return whole_pattern
|
262 |
+
|
263 |
+
def schedule(p, m, cost):
|
264 |
+
whole_pattern = create_whole_pattern(p)
|
265 |
+
s = schedule_by_pattern(p, m, whole_pattern)
|
266 |
+
for sid in range(len(s)):
|
267 |
+
for i in range(len(s[sid])):
|
268 |
+
if s[sid][i] in ('W', 'w'):
|
269 |
+
s[sid][i] = ' '
|
270 |
+
res = transform_schedule(s, *cost)
|
271 |
+
return res
|
svg_event.py
CHANGED
@@ -234,7 +234,7 @@ def plot_events(ctx, events, title_text: str, canvas_info: CanvasInfo, include_w
|
|
234 |
if ENABLE_BATCH_ID:
|
235 |
minibatch = str(e["minibatch"])
|
236 |
center = (start + end) // 2
|
237 |
-
data_ctx.text(h, center, minibatch, font_scale=0.6, fill='
|
238 |
if ENABLE_BORDER:
|
239 |
data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
|
240 |
|
|
|
234 |
if ENABLE_BATCH_ID:
|
235 |
minibatch = str(e["minibatch"])
|
236 |
center = (start + end) // 2
|
237 |
+
data_ctx.text(h, center, minibatch, font_scale=0.6, fill='white' if e["chunk"] == 0 else 'black')
|
238 |
if ENABLE_BORDER:
|
239 |
data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
|
240 |
|
type2.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pattern_size = 6
|
2 |
+
from collections import Counter
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
@dataclass(eq=True, frozen=True)
|
6 |
+
class ScheduledNode:
|
7 |
+
type: str
|
8 |
+
stage: int
|
9 |
+
minibatch: int
|
10 |
+
start_time: int
|
11 |
+
completion_time: int
|
12 |
+
|
13 |
+
def transform_schedule(schedule, f, b, w, c):
|
14 |
+
result = []
|
15 |
+
|
16 |
+
stage_order = []
|
17 |
+
local_prev = {}
|
18 |
+
stages = len(schedule)
|
19 |
+
|
20 |
+
for sid, stage in enumerate(schedule):
|
21 |
+
counter = Counter()
|
22 |
+
order = []
|
23 |
+
for p in stage:
|
24 |
+
if not p.strip():
|
25 |
+
continue
|
26 |
+
mb = counter.get(p, 0)
|
27 |
+
if order:
|
28 |
+
local_prev[(sid, p, mb)] = order[-1]
|
29 |
+
order.append((p, mb))
|
30 |
+
counter.update(p)
|
31 |
+
stage_order.append(order)
|
32 |
+
nmb = max(counter.values())
|
33 |
+
time_map = {}
|
34 |
+
cost = {
|
35 |
+
'F': f,
|
36 |
+
'B': b,
|
37 |
+
'W': w,
|
38 |
+
}
|
39 |
+
def get_time(stage, type, mb):
|
40 |
+
if (stage, type, mb) in time_map:
|
41 |
+
return time_map.get((stage, type, mb))
|
42 |
+
time = 0
|
43 |
+
if (stage, type, mb) in local_prev:
|
44 |
+
time = get_time(stage, *local_prev[(stage, type, mb)])
|
45 |
+
if type in ('F') and stage > 0:
|
46 |
+
time = max(time, get_time(stage - 1, type, mb) + c)
|
47 |
+
if type in ('B') and stage + 1< len(schedule):
|
48 |
+
time = max(time, get_time(stage + 1, type, mb) + c)
|
49 |
+
# print(f'{stage} {type}:{mb}', time + cost[type])
|
50 |
+
time_map[(stage, type, mb)] = time + cost[type]
|
51 |
+
return time_map[(stage, type, mb)]
|
52 |
+
r = 0
|
53 |
+
for sid, stage in enumerate(schedule):
|
54 |
+
r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
|
55 |
+
|
56 |
+
for sid, stage in enumerate(stage_order):
|
57 |
+
result_stage = []
|
58 |
+
for p, mb in stage:
|
59 |
+
result_stage.append(ScheduledNode(
|
60 |
+
p.upper(),
|
61 |
+
sid,
|
62 |
+
mb,
|
63 |
+
get_time(sid, p, mb) - cost[p],
|
64 |
+
get_time(sid, p, mb)
|
65 |
+
)
|
66 |
+
)
|
67 |
+
result.append(result_stage)
|
68 |
+
return result
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
def process_warmup_without_increasing_peak_mem(schedules, m):
|
74 |
+
peak_mem = 0
|
75 |
+
mem = [[0 for _ in range(len(schedules[0]))] for _ in range(len(schedules))]
|
76 |
+
loc = [[{key: -1 for key in ('F', 'B', 'W')} for _ in range(m + 2)] for _ in range(len(schedules))]
|
77 |
+
cntr = [{key: 0 for key in ('F', 'B', 'W')} for _ in range(len(schedules))]
|
78 |
+
for sid in range(len(schedules)):
|
79 |
+
cur = 0
|
80 |
+
for i in range(len(schedules[sid])):
|
81 |
+
if schedules[sid][i] in ('F'):
|
82 |
+
cur += 1
|
83 |
+
if schedules[sid][i] in ('W'):
|
84 |
+
cur -= 1
|
85 |
+
mem[sid][i] = cur
|
86 |
+
peak_mem = max(peak_mem, cur)
|
87 |
+
for i in range(len(schedules[0])):
|
88 |
+
for sid in range(len(schedules)):
|
89 |
+
if schedules[sid][i] == ' ':
|
90 |
+
continue
|
91 |
+
cntr[sid][schedules[sid][i]] += 1
|
92 |
+
cnt = cntr[sid][schedules[sid][i]]
|
93 |
+
pos = -1
|
94 |
+
if cnt > 1:
|
95 |
+
pos = loc[sid][cnt - 1][schedules[sid][i]]
|
96 |
+
if schedules[sid][i] == 'W':
|
97 |
+
pos = max(pos, loc[sid][cnt]['B'])
|
98 |
+
if schedules[sid][i] == 'F' and sid > 0:
|
99 |
+
pos = max(pos, loc[sid - 1][cnt]['F'])
|
100 |
+
if schedules[sid][i] == 'B':
|
101 |
+
if sid != len(schedules) - 1:
|
102 |
+
pos = max(pos, loc[sid + 1][cnt]['B'])
|
103 |
+
else :
|
104 |
+
pos = max(pos, loc[sid][cnt]['F'])
|
105 |
+
pos += 1
|
106 |
+
while schedules[sid][pos] != ' ' and pos < i:
|
107 |
+
pos += 1
|
108 |
+
if pos == i:
|
109 |
+
loc[sid][cnt][schedules[sid][i]] = i
|
110 |
+
continue
|
111 |
+
if schedules[sid][i] in ('B', 'W'):
|
112 |
+
schedules[sid][pos] = schedules[sid][i]
|
113 |
+
schedules[sid][i] = ' '
|
114 |
+
if schedules[sid][pos] in ('W'):
|
115 |
+
for j in range(pos, i):
|
116 |
+
mem[sid][j] -= 1
|
117 |
+
loc[sid][cnt][schedules[sid][pos]] = pos
|
118 |
+
continue
|
119 |
+
|
120 |
+
#If F:
|
121 |
+
if (sid == 0):
|
122 |
+
print(cnt, pos, i)
|
123 |
+
place = i
|
124 |
+
while place > pos and mem[sid][place - 1] < peak_mem:
|
125 |
+
place -= 1
|
126 |
+
while place < i and schedules[sid][place] != ' ':
|
127 |
+
place += 1
|
128 |
+
if place == i:
|
129 |
+
loc[sid][cnt][schedules[sid][i]] = i
|
130 |
+
continue
|
131 |
+
if (sid == 0):
|
132 |
+
print(place)
|
133 |
+
pos = place
|
134 |
+
schedules[sid][pos] = schedules[sid][i]
|
135 |
+
schedules[sid][i] = ' '
|
136 |
+
for j in range(pos, i):
|
137 |
+
mem[sid][j] += 1
|
138 |
+
loc[sid][cnt][schedules[sid][pos]] = pos
|
139 |
+
return schedules
|
140 |
+
|
141 |
+
def schedule(p, m, cost):
|
142 |
+
schedules = [[' ' for _ in range(6 * m + 2 * p + 6)] for _ in range(p)]
|
143 |
+
f_0, f_1, b_0, b_1= p-1, p+1, p, p + 2
|
144 |
+
for sid in range(p - 1, -1, -1):
|
145 |
+
for mid in range((m + 1) // 2):
|
146 |
+
if mid * 2 < m:
|
147 |
+
schedules[sid][f_0 + mid * 6], schedules[sid][b_0 + mid * 6] = 'F', 'B'
|
148 |
+
if mid * 2 + 1 < m:
|
149 |
+
schedules[sid][f_1 + mid * 6], schedules[sid][b_1 + mid * 6] = 'F', 'B'
|
150 |
+
f_0 -= 1
|
151 |
+
f_1 -= 1
|
152 |
+
b_0 += 1
|
153 |
+
b_1 += 1
|
154 |
+
cnt = 0
|
155 |
+
for i in range(len(schedules[0])):
|
156 |
+
if schedules[sid][i] == 'B':
|
157 |
+
cnt += 1
|
158 |
+
if schedules[sid][i] == ' ' and cnt > 0:
|
159 |
+
cnt -= 1
|
160 |
+
schedules[sid][i] = 'W'
|
161 |
+
schedules = process_warmup_without_increasing_peak_mem(schedules, m)
|
162 |
+
res = transform_schedule(schedules, *cost)
|
163 |
+
return res
|
v_schedule.py
DELETED
@@ -1,474 +0,0 @@
|
|
1 |
-
from collections import deque
|
2 |
-
from dataclasses import dataclass
|
3 |
-
|
4 |
-
@dataclass(eq=True, frozen=True)
|
5 |
-
class ScheduledNode:
|
6 |
-
type: str
|
7 |
-
chunk: int
|
8 |
-
stage: int
|
9 |
-
minibatch: int
|
10 |
-
start_time: int
|
11 |
-
completion_time: int
|
12 |
-
rollback: bool = False
|
13 |
-
|
14 |
-
|
15 |
-
class PipelineGraph(object):
|
16 |
-
def __init__(
|
17 |
-
self, n_stage, n_micro, f_cost, b_cost, w_cost, c_cost,
|
18 |
-
f_mem, b_mem, w_mem, max_mem=None,
|
19 |
-
):
|
20 |
-
self.n_node = 6 * n_stage * n_micro
|
21 |
-
self.n_stage = n_stage
|
22 |
-
self.n_micro = n_micro
|
23 |
-
self.f_cost = f_cost
|
24 |
-
self.b_cost = b_cost
|
25 |
-
self.w_cost = w_cost
|
26 |
-
self.c_cost = c_cost
|
27 |
-
self.f_mem = f_mem
|
28 |
-
self.b_mem = b_mem
|
29 |
-
self.w_mem = w_mem
|
30 |
-
self.fbw_cost = [f_cost, b_cost, w_cost]
|
31 |
-
self.fbw_mem = [f_mem, b_mem, w_mem]
|
32 |
-
self.max_mem = max_mem or f_mem * self.n_stage * 2
|
33 |
-
|
34 |
-
def get_id(self, cat, chunk, stage, micro):
|
35 |
-
return cat * 2 * self.n_stage * self.n_micro + \
|
36 |
-
chunk * self.n_stage * self.n_micro + \
|
37 |
-
stage * self.n_micro + \
|
38 |
-
micro
|
39 |
-
|
40 |
-
def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
|
41 |
-
count = []
|
42 |
-
for i in range(self.n_stage):
|
43 |
-
count.append([0] * 6)
|
44 |
-
|
45 |
-
end_time = [-1] * self.n_node
|
46 |
-
cur_time = [0] * self.n_stage
|
47 |
-
mem = [0] * self.n_stage
|
48 |
-
stage_bubble = [0] * self.n_stage
|
49 |
-
pending_w = [deque() for _ in range(self.n_stage)]
|
50 |
-
schedule = [[] for _ in range(self.n_stage)]
|
51 |
-
stage_str = [" " * i for i in range(self.n_stage)]
|
52 |
-
|
53 |
-
if approved_bubble is None:
|
54 |
-
approved_bubble = [-1] * self.n_stage
|
55 |
-
max_approved_bubble = max(approved_bubble)
|
56 |
-
|
57 |
-
def get_max_stage_bubble(stage=-1):
|
58 |
-
max_stage_bubble = 0
|
59 |
-
for bb in stage_bubble:
|
60 |
-
max_stage_bubble = max(max_stage_bubble, bb)
|
61 |
-
if stage >= 0:
|
62 |
-
max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
|
63 |
-
return max_stage_bubble
|
64 |
-
|
65 |
-
def put_w(stage):
|
66 |
-
assert len(pending_w[stage]) > 0
|
67 |
-
_, chunk_, _ = pending_w[stage].popleft()
|
68 |
-
put(2, chunk_, stage)
|
69 |
-
|
70 |
-
def put(cat, chunk, stage, assert_cnt=True):
|
71 |
-
_tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
|
72 |
-
_cnt = count[stage][cat * 2 + chunk]
|
73 |
-
# assert _cnt < self.n_micro
|
74 |
-
if _cnt >= self.n_micro:
|
75 |
-
if not assert_cnt:
|
76 |
-
stage_str[stage] += " "
|
77 |
-
cur_time[stage] = _tmp # TODO
|
78 |
-
return
|
79 |
-
assert False
|
80 |
-
assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
|
81 |
-
stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
|
82 |
-
if cat > 0 or chunk > 0:
|
83 |
-
last_id = cat * 2 + chunk - 1
|
84 |
-
if cat < 2:
|
85 |
-
# if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0:
|
86 |
-
# print(cat, chunk, stage, _cnt)
|
87 |
-
# self.print_details(end_time)
|
88 |
-
assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
|
89 |
-
else:
|
90 |
-
assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
|
91 |
-
if chunk == 1 and cat < 2:
|
92 |
-
if stage < self.n_stage - 1:
|
93 |
-
_fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
|
94 |
-
assert end_time[_fa_id] >= 0
|
95 |
-
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
|
96 |
-
if chunk == 0 and cat < 2:
|
97 |
-
if stage > 0:
|
98 |
-
_fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
|
99 |
-
# if end_time[_fa_id] < 0:
|
100 |
-
# print(cat, chunk, stage, _cnt)
|
101 |
-
# self.print_details(end_time)
|
102 |
-
assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
|
103 |
-
_tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
|
104 |
-
_id = self.get_id(cat, chunk, stage, _cnt)
|
105 |
-
if count[stage][0] > 0:
|
106 |
-
stage_bubble[stage] += _tmp - _no_bubble
|
107 |
-
end_time[_id] = _tmp
|
108 |
-
cur_time[stage] = _tmp
|
109 |
-
mem[stage] += self.fbw_mem[cat]
|
110 |
-
# noinspection PyTypeChecker
|
111 |
-
schedule[stage].append((cat, chunk, _cnt))
|
112 |
-
if cat == 1:
|
113 |
-
pending_w[stage].append((2, chunk, _cnt))
|
114 |
-
count[stage][cat * 2 + chunk] += 1
|
115 |
-
|
116 |
-
# for _ in range(2 * self.n_stage):
|
117 |
-
# for i in range(self.n_stage):
|
118 |
-
# if count[i][1] >= count[i][0]:
|
119 |
-
# put(0, 0, i, assert_cnt=False)
|
120 |
-
# continue
|
121 |
-
# if i == self.n_stage - 1:
|
122 |
-
# put(0, 1, i, assert_cnt=False)
|
123 |
-
# continue
|
124 |
-
# fa_id = self.get_id(0, 1, i + 1, count[i][1])
|
125 |
-
# if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO
|
126 |
-
# put(0, 1, i, assert_cnt=False)
|
127 |
-
# else:
|
128 |
-
# put(0, 0, i, assert_cnt=False)
|
129 |
-
|
130 |
-
for i in range(self.n_stage):
|
131 |
-
put(0, 0, i)
|
132 |
-
for i in range(self.n_stage - 1, -1, -1):
|
133 |
-
if i == self.n_stage - 1:
|
134 |
-
put(0, 1, i)
|
135 |
-
continue
|
136 |
-
tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
|
137 |
-
while mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp and count[i][0] < self.n_micro:
|
138 |
-
for j in range(i + 1):
|
139 |
-
put(0, 0, j)
|
140 |
-
put(0, 1, i)
|
141 |
-
iter_chunk_ = 0
|
142 |
-
end_tmp = 0
|
143 |
-
for i in range(self.n_stage):
|
144 |
-
if i == 0:
|
145 |
-
end_tmp = cur_time[0] + self.fbw_cost[1]
|
146 |
-
continue
|
147 |
-
tmp = end_tmp + self.c_cost
|
148 |
-
while count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1] or count[i][1] <= count[i - 1][1] < self.n_micro:
|
149 |
-
for j in range(self.n_stage - 1, i - 1, -1):
|
150 |
-
if count[j][iter_chunk_] < self.n_micro:
|
151 |
-
put(0, iter_chunk_, j)
|
152 |
-
iter_chunk_ = 1 - iter_chunk_
|
153 |
-
# while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp:
|
154 |
-
# if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]:
|
155 |
-
# break
|
156 |
-
# for j in range(self.n_stage - 1, i - 1, -1):
|
157 |
-
# if count[j][iter_chunk_] < self.n_micro:
|
158 |
-
# put(0, iter_chunk_, j)
|
159 |
-
# iter_chunk_ = 1 - iter_chunk_
|
160 |
-
# end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1]
|
161 |
-
|
162 |
-
# init_bubble = get_max_stage_bubble()
|
163 |
-
# print(stage_bubble)
|
164 |
-
for _ in range(2 * self.n_micro):
|
165 |
-
# check mem before putting b
|
166 |
-
for i in range(self.n_stage):
|
167 |
-
while mem[i] + self.fbw_mem[1] > self.max_mem:
|
168 |
-
assert len(pending_w[i]) > 0
|
169 |
-
put_w(i)
|
170 |
-
b0_ranks, b1_ranks = [], []
|
171 |
-
for i in range(self.n_stage):
|
172 |
-
if count[i][3] >= count[i][2]:
|
173 |
-
b0_ranks.append(i)
|
174 |
-
elif i == self.n_stage - 1:
|
175 |
-
b1_ranks.append(i)
|
176 |
-
else:
|
177 |
-
fa_id = self.get_id(1, 1, i + 1, count[i][3])
|
178 |
-
if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
|
179 |
-
b1_ranks.append(i)
|
180 |
-
else:
|
181 |
-
b0_ranks.append(i)
|
182 |
-
b_ranks = []
|
183 |
-
# put b1
|
184 |
-
for i in reversed(b1_ranks):
|
185 |
-
b_ranks.append((i, 1))
|
186 |
-
# put b0
|
187 |
-
for i in b0_ranks:
|
188 |
-
b_ranks.append((i, 0))
|
189 |
-
for i, _chunk_ in b_ranks:
|
190 |
-
fa_id = -1
|
191 |
-
if _chunk_ == 1 and i < self.n_stage - 1:
|
192 |
-
fa_id = self.get_id(1, 1, i + 1, count[i][3])
|
193 |
-
if _chunk_ == 0 and i > 0:
|
194 |
-
fa_id = self.get_id(1, 0, i - 1, count[i][2])
|
195 |
-
while len(pending_w[i]) > 0 and fa_id >= 0 and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]:
|
196 |
-
# fill the bubble
|
197 |
-
put_w(i)
|
198 |
-
if len(pending_w[i]) > 0 and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]:
|
199 |
-
if _chunk_ == 1:
|
200 |
-
put_w(i)
|
201 |
-
elif fill_b:
|
202 |
-
put_w(i)
|
203 |
-
put(1, _chunk_, i)
|
204 |
-
|
205 |
-
# put f
|
206 |
-
for i in range(self.n_stage):
|
207 |
-
if count[i][1] >= self.n_micro:
|
208 |
-
continue
|
209 |
-
put_item = None
|
210 |
-
if count[i][1] >= count[i][0]:
|
211 |
-
put_item = 0
|
212 |
-
elif i == self.n_stage - 1:
|
213 |
-
put_item = 1
|
214 |
-
else:
|
215 |
-
if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
|
216 |
-
put_item = 1
|
217 |
-
elif count[i][0] < self.n_micro:
|
218 |
-
if i == 0:
|
219 |
-
put_item = 0
|
220 |
-
elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
|
221 |
-
put_item = 0
|
222 |
-
if put_item is None:
|
223 |
-
continue
|
224 |
-
# check mem before putting f
|
225 |
-
while mem[i] + self.fbw_mem[0] > self.max_mem:
|
226 |
-
assert len(pending_w[i]) > 0
|
227 |
-
put_w(i)
|
228 |
-
fa_id = -1
|
229 |
-
if put_item == 0 and i > 0:
|
230 |
-
fa_id = self.get_id(0, 0, i - 1, count[i][0])
|
231 |
-
if put_item == 1 and i < self.n_stage - 1:
|
232 |
-
fa_id = self.get_id(0, 1, i + 1, count[i][1])
|
233 |
-
while len(pending_w[i]) > 0 and fa_id >= 0 and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]:
|
234 |
-
# fill the bubble
|
235 |
-
put_w(i)
|
236 |
-
if len(pending_w[i]) > 0 and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]:
|
237 |
-
if fill_f:
|
238 |
-
put_w(i)
|
239 |
-
put(0, put_item, i)
|
240 |
-
|
241 |
-
for i in range(self.n_stage):
|
242 |
-
while len(pending_w[i]) > 0:
|
243 |
-
put_w(i)
|
244 |
-
|
245 |
-
# for i in range(self.n_stage):
|
246 |
-
# print(stage_str[i])
|
247 |
-
|
248 |
-
max_bubble = get_max_stage_bubble()
|
249 |
-
expected_time = sum(self.fbw_cost) * self.n_micro * 2
|
250 |
-
bubble_rate = max_bubble / expected_time
|
251 |
-
# print("%6.4f" % bubble_rate, "->", stage_bubble)
|
252 |
-
if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
|
253 |
-
_schedule, _end_time, _max_bubble = self.try_v_schedule(
|
254 |
-
fill_f=fill_f, fill_b=fill_b,
|
255 |
-
approved_bubble=stage_bubble,
|
256 |
-
)
|
257 |
-
if _max_bubble < max_bubble:
|
258 |
-
return _schedule, _end_time, _max_bubble
|
259 |
-
# print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \
|
260 |
-
# (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble)
|
261 |
-
return schedule, end_time, max_bubble
|
262 |
-
|
263 |
-
def print_details(self, end_time, print_scaling=1):
|
264 |
-
for stage in range(self.n_stage):
|
265 |
-
stage_str = ['.'] * int(max(end_time) / print_scaling)
|
266 |
-
for _cat in range(3):
|
267 |
-
for _chunk in range(2):
|
268 |
-
for _micro in range(self.n_micro):
|
269 |
-
_id = self.get_id(_cat, _chunk, stage, _micro)
|
270 |
-
if end_time[_id] < 0:
|
271 |
-
continue
|
272 |
-
end = int(end_time[_id] / print_scaling)
|
273 |
-
start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
|
274 |
-
for j in range(start, end):
|
275 |
-
if j == start or j == end - 1:
|
276 |
-
stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
|
277 |
-
elif j == start + 1:
|
278 |
-
if _micro >= 10:
|
279 |
-
stage_str[j] = str(_micro // 10)
|
280 |
-
else:
|
281 |
-
stage_str[j] = str(_micro)
|
282 |
-
elif j == start + 2 and _micro >= 10:
|
283 |
-
stage_str[j] = str(_micro % 10)
|
284 |
-
else:
|
285 |
-
stage_str[j] = "-"
|
286 |
-
_str = ""
|
287 |
-
for _c in stage_str:
|
288 |
-
_str += _c
|
289 |
-
print(_str)
|
290 |
-
|
291 |
-
def get_v_schedule(self, only_run_time=False):
|
292 |
-
schedule, end_time, max_bubble = None, None, None
|
293 |
-
expected_time = sum(self.fbw_cost) * self.n_micro * 2
|
294 |
-
for fill_b in [True, False]:
|
295 |
-
for fill_f in [True, False]:
|
296 |
-
_schedule, _end_time, _max_bubble = self.try_v_schedule(
|
297 |
-
fill_b=fill_b, fill_f=fill_f
|
298 |
-
)
|
299 |
-
# print("")
|
300 |
-
if max_bubble is None or _max_bubble < max_bubble:
|
301 |
-
max_bubble = _max_bubble
|
302 |
-
schedule = _schedule
|
303 |
-
end_time = _end_time
|
304 |
-
if only_run_time:
|
305 |
-
return max_bubble + expected_time
|
306 |
-
# self.print_details(end_time, print_scaling=1)
|
307 |
-
bubble_rate = max_bubble / (expected_time + max_bubble)
|
308 |
-
print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \
|
309 |
-
(self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate))
|
310 |
-
local_order = [[] for _ in range(self.n_stage)]
|
311 |
-
comm_id = {}
|
312 |
-
comm_id_counter = 0
|
313 |
-
post_validation_time = 0
|
314 |
-
for i in range(self.n_stage - 1, -1, -1):
|
315 |
-
pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
|
316 |
-
post_validation_time = max(post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost)
|
317 |
-
# post_validation_time = 0
|
318 |
-
# print(i, pv_id, post_validation_time)
|
319 |
-
for it in ["RECV_", "SEND_", ""]:
|
320 |
-
if i == 0 and it == "SEND_":
|
321 |
-
continue
|
322 |
-
if i == self.n_stage - 1 and it == "RECV_":
|
323 |
-
continue
|
324 |
-
# stage_ = i - 1 if it == "RECV_" else i
|
325 |
-
stage_ = i
|
326 |
-
local_order[stage_].append(ScheduledNode(
|
327 |
-
type=it + "POST_VALIDATION",
|
328 |
-
chunk=0,
|
329 |
-
stage=stage_,
|
330 |
-
minibatch=0,
|
331 |
-
start_time=post_validation_time,
|
332 |
-
completion_time=post_validation_time,
|
333 |
-
))
|
334 |
-
comm_id[local_order[stage_][-1]] = comm_id_counter
|
335 |
-
comm_id_counter += 1
|
336 |
-
for i in range(self.n_stage):
|
337 |
-
for _cat_, _chunk_, _micro_ in schedule[i]:
|
338 |
-
complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
|
339 |
-
local_order[i].append(ScheduledNode(
|
340 |
-
type="FBW"[_cat_],
|
341 |
-
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
|
342 |
-
stage=i,
|
343 |
-
minibatch=_micro_,
|
344 |
-
start_time=complete_time - self.fbw_cost[_cat_],
|
345 |
-
completion_time=complete_time,
|
346 |
-
))
|
347 |
-
if _cat_ == 2: # no communication for W
|
348 |
-
continue
|
349 |
-
cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
|
350 |
-
def communicate(send_recv, stage_):
|
351 |
-
# noinspection PyTypeChecker
|
352 |
-
local_order[stage_].append(ScheduledNode(
|
353 |
-
type=send_recv + cat_str,
|
354 |
-
chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
|
355 |
-
stage=stage_,
|
356 |
-
minibatch=_micro_,
|
357 |
-
start_time=complete_time,
|
358 |
-
completion_time=complete_time,
|
359 |
-
))
|
360 |
-
comm_id[local_order[stage_][-1]] = comm_id_counter
|
361 |
-
|
362 |
-
if _chunk_ == 1 and i > 0:
|
363 |
-
communicate("SEND_", i)
|
364 |
-
communicate("RECV_", i - 1)
|
365 |
-
if _chunk_ == 0 and i < self.n_stage - 1:
|
366 |
-
communicate("SEND_", i)
|
367 |
-
communicate("RECV_", i + 1)
|
368 |
-
comm_id_counter += 1
|
369 |
-
for rank in range(self.n_stage):
|
370 |
-
# For nodes with the same timestamp on the same stage, communication will be prioritized.
|
371 |
-
def even_breaker(x: ScheduledNode):
|
372 |
-
# Compute nodes are always delayed.
|
373 |
-
if x.type in ['F', 'B', 'W']:
|
374 |
-
return comm_id_counter
|
375 |
-
# For comm nodes, order by their unique comm id
|
376 |
-
return comm_id[x]
|
377 |
-
|
378 |
-
local_order[rank] = list(sorted(
|
379 |
-
local_order[rank],
|
380 |
-
key=lambda x: (x.start_time, even_breaker(x))
|
381 |
-
))
|
382 |
-
# If a recv with intersects with previous computation, reorder them so that recv
|
383 |
-
# is executed before computation and hence can be overlapped.
|
384 |
-
for i in range(len(local_order[rank])):
|
385 |
-
if i > 0 and local_order[rank][i - 1].type in {'F', 'B', 'W'} and \
|
386 |
-
local_order[rank][i].type.startswith('RECV') and \
|
387 |
-
"POST_VALIDATION" not in local_order[rank][i].type and \
|
388 |
-
local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time:
|
389 |
-
local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
|
390 |
-
|
391 |
-
local_order_with_rollback = [[] for _ in range(self.n_stage)]
|
392 |
-
for rank in range(self.n_stage):
|
393 |
-
rollback_comm = set()
|
394 |
-
if rank > 0:
|
395 |
-
for node in local_order[rank - 1]:
|
396 |
-
if node.type == "POST_VALIDATION":
|
397 |
-
break
|
398 |
-
if node.type == "SEND_FORWARD":
|
399 |
-
assert node.chunk == 0
|
400 |
-
rollback_comm.add(node.minibatch)
|
401 |
-
for node in local_order[rank]:
|
402 |
-
if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
|
403 |
-
rollback = True
|
404 |
-
rollback_comm.remove(node.minibatch)
|
405 |
-
else:
|
406 |
-
rollback = False
|
407 |
-
local_order_with_rollback[rank].append(ScheduledNode(
|
408 |
-
type=node.type,
|
409 |
-
chunk=node.chunk,
|
410 |
-
stage=node.stage,
|
411 |
-
minibatch=node.minibatch,
|
412 |
-
start_time=node.start_time,
|
413 |
-
completion_time=node.completion_time,
|
414 |
-
rollback=rollback,
|
415 |
-
))
|
416 |
-
assert len(rollback_comm) == 0
|
417 |
-
for node in local_order_with_rollback[rank]:
|
418 |
-
print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
|
419 |
-
print()
|
420 |
-
|
421 |
-
return local_order_with_rollback
|
422 |
-
|
423 |
-
|
424 |
-
if __name__ == '__main__':
|
425 |
-
settings = [
|
426 |
-
# p, n, f, b, w, c, h, a, l
|
427 |
-
# (8, 24, 18522, 18086, 9337, 601, 2304, 24, 24),
|
428 |
-
# (8, 32, 18513, 18086, 9331, 626, 2304, 24, 24),
|
429 |
-
# (8, 64, 18546, 18097, 9321, 762, 2304, 24, 24),
|
430 |
-
# (8, 24, 29718, 29444, 19927, 527, 4096, 32, 32),
|
431 |
-
# (8, 32, 29802, 29428, 19530, 577, 4096, 32, 32),
|
432 |
-
# (8, 64, 29935, 29621, 19388, 535, 4096, 32, 32),
|
433 |
-
# (16, 48, 11347, 11248, 8132, 377, 5120, 40, 48),
|
434 |
-
# (16, 64, 11307, 11254, 8101, 379, 5120, 40, 48),
|
435 |
-
# (16, 128, 11325, 11308, 8109, 378, 5120, 40, 48),
|
436 |
-
# (32, 96, 10419, 10207, 7715, 408, 6144, 48, 64),
|
437 |
-
# (32, 128, 10408, 10204, 7703, 408, 6144, 48, 64),
|
438 |
-
# (32, 256, 10402, 10248, 7698, 460, 6144, 48, 64),
|
439 |
-
# (4, 8, 6, 4, 4, 1, 4096, 32, 32),
|
440 |
-
# (8, 24, 29444, 29718, 19927, 527, 4096, 32, 32),
|
441 |
-
# ( 8, 32, 16099, 16504, 7589, 540, 2304, 24, 16),
|
442 |
-
(16, 48, 14407, 14380, 9676, 1610, 4096, 32, 32),
|
443 |
-
(16, 64, 14412, 14393, 9688, 1621, 4096, 32, 32),
|
444 |
-
(16, 128,14316, 14306, 9639, 1619, 4096, 32, 32),
|
445 |
-
(24, 72, 6763, 6969, 5251, 755, 5120, 40, 48),
|
446 |
-
(24, 96, 6783, 6984, 5259, 758, 5120, 40, 48),
|
447 |
-
(24, 192, 6785, 6990, 5260, 770, 5120, 40, 48),
|
448 |
-
(32, 96, 9458, 9748, 7288, 879, 6144, 48, 64),
|
449 |
-
(32, 128, 9469, 9744, 7306, 892, 6144, 48, 64),
|
450 |
-
(32, 256, 9447, 9644, 7193, 887, 6144, 48, 64),
|
451 |
-
]
|
452 |
-
s = 1024
|
453 |
-
|
454 |
-
# h, a, s = 4096, 32, 1024
|
455 |
-
# cost_f, cost_b, cost_w, cost_c = 29718, 29444, 19927, 527
|
456 |
-
for p, n, f, b, w, c, h, a, _ in settings:
|
457 |
-
mem_f = 34 * h + 5 * a * s
|
458 |
-
mem_w = - 32 * h
|
459 |
-
mem_b = - mem_w - mem_f
|
460 |
-
for m_offset in range(p + 1):
|
461 |
-
graph = PipelineGraph(
|
462 |
-
n_stage=p,
|
463 |
-
n_micro=n,
|
464 |
-
f_cost=f,
|
465 |
-
b_cost=b,
|
466 |
-
w_cost=w,
|
467 |
-
c_cost=c,
|
468 |
-
f_mem=mem_f,
|
469 |
-
b_mem=mem_b,
|
470 |
-
w_mem=mem_w,
|
471 |
-
max_mem=mem_f * (p * 2 + m_offset),
|
472 |
-
)
|
473 |
-
graph.get_v_schedule()
|
474 |
-
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|