Nyamdavaa Amar commited on
Commit
3d4d40d
·
1 Parent(s): f8e95f6

Pipeline Parallelism with Controllable Memory

Browse files
Files changed (11) hide show
  1. README.md +4 -10
  2. adaptive_schedule.py +627 -0
  3. app.py +152 -96
  4. auto_schedule.py +0 -564
  5. description1.md +3 -9
  6. description2.md +5 -32
  7. interleaved_variant.py +107 -0
  8. schedule1f1bv.py +271 -0
  9. svg_event.py +1 -1
  10. type2.py +163 -0
  11. v_schedule.py +0 -474
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Zero Bubble Pipeline Parallellism
3
  emoji: 🏆
4
  colorFrom: indigo
5
  colorTo: red
@@ -11,14 +11,8 @@ license: apache-2.0
11
  ---
12
 
13
 
14
- # Zero Bubble Pipeline Parallelism
15
 
16
- Zero Bubble Pipeline Parallelism is a novel pipeline parallelism algorithm able to reduce the bubble of pipeline parallelism to almost zero while preserving synchronous semantics.
17
 
18
- Check out our paper at:
19
- * [Arxiv Version with ZBV](https://arxiv.org/abs/2401.10241)
20
- * [ICLR Accepted version with ZB1P and ZB2P](https://openreview.net/pdf?id=tuzTN0eIO5)
21
-
22
- Try out our implementation based on Megatron on [https://github.com/sail-sg/zero-bubble-pipeline-parallelism](https://github.com/sail-sg/zero-bubble-pipeline-parallelism)
23
-
24
- Experiments shows zero bubble pipeline parallelism can accelerate training up to 30% with a similar memory comsumption. A detailed table of experiments is coming soon.
 
1
  ---
2
+ title: Pipeline Parallellism with Controllable Memory
3
  emoji: 🏆
4
  colorFrom: indigo
5
  colorTo: red
 
11
  ---
12
 
13
 
14
+ # Pipeline Parallellism with Controllable Memory
15
 
16
+ Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
17
 
18
+ Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
 
 
 
 
 
 
adaptive_schedule.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pattern_size = 6
2
+ from collections import Counter
3
+ from dataclasses import dataclass
4
+
5
+ @dataclass(eq=True, frozen=True)
6
+ class ScheduledNode:
7
+ type: str
8
+ chunk: int
9
+ stage: int
10
+ minibatch: int
11
+ start_time: int
12
+ completion_time: int
13
+
14
+ def transform_schedule(schedule, f, b, w, c):
15
+ result = []
16
+
17
+ stage_order = []
18
+ local_prev = {}
19
+ stages = len(schedule)
20
+
21
+ for sid, stage in enumerate(schedule):
22
+ counter = Counter()
23
+ order = []
24
+ for p in stage:
25
+ if not p.strip():
26
+ continue
27
+ mb = counter.get(p, 0)
28
+ if order:
29
+ local_prev[(sid, p, mb)] = order[-1]
30
+ order.append((p, mb))
31
+ counter.update(p)
32
+ stage_order.append(order)
33
+ nmb = max(counter.values())
34
+ time_map = {}
35
+ cost = {
36
+ 'F': f,
37
+ 'B': b,
38
+ 'W': w,
39
+ 'f': f,
40
+ 'b': b,
41
+ 'w': w,
42
+ }
43
+ def get_time(stage, type, mb):
44
+ if (stage, type, mb) in time_map:
45
+ return time_map.get((stage, type, mb))
46
+ time = 0
47
+ if (stage, type, mb) in local_prev:
48
+ time = get_time(stage, *local_prev[(stage, type, mb)])
49
+ if type in ('F', 'B') and stage > 0:
50
+ time = max(time, get_time(stage - 1, type, mb) + c)
51
+ if type in ('f', 'b') and stage + 1< len(schedule):
52
+ time = max(time, get_time(stage + 1, type, mb) + c)
53
+ # print(f'{stage} {type}:{mb}', time + cost[type])
54
+ time_map[(stage, type, mb)] = time + cost[type]
55
+ return time_map[(stage, type, mb)]
56
+ r = 0
57
+ for sid, stage in enumerate(schedule):
58
+ r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
59
+ r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
60
+
61
+ for sid, stage in enumerate(stage_order):
62
+ result_stage = []
63
+ for p, mb in stage:
64
+ result_stage.append(ScheduledNode(
65
+ p.upper(),
66
+ p in ('f', 'B', 'W'),
67
+ sid,
68
+ mb,
69
+ get_time(sid, p, mb) - cost[p],
70
+ get_time(sid, p, mb)
71
+ )
72
+ )
73
+ result.append(result_stage)
74
+ return result
75
+
76
+
77
+
78
+
79
+
80
+ def evaluate_schedule(schedule, f, b, w, c):
81
+ stage_order = []
82
+ local_prev = {}
83
+ stages = len(schedule)
84
+
85
+ for sid, stage in enumerate(schedule):
86
+ counter = Counter()
87
+ order = []
88
+ for p in stage:
89
+ if not p.strip():
90
+ continue
91
+ mb = counter.get(p, 0)
92
+ if order:
93
+ local_prev[(sid, p, mb)] = order[-1]
94
+ order.append((p, mb))
95
+ counter.update(p)
96
+ stage_order.append(order)
97
+ nmb = max(counter.values())
98
+ time_map = {}
99
+ cost = {
100
+ 'F': f,
101
+ 'B': b,
102
+ 'W': w,
103
+ 'f': f,
104
+ 'b': b,
105
+ 'w': w,
106
+ }
107
+ def get_time(stage, type, mb):
108
+ if (stage, type, mb) in time_map:
109
+ return time_map.get((stage, type, mb))
110
+ time = 0
111
+ if (stage, type, mb) in local_prev:
112
+ time = get_time(stage, *local_prev[(stage, type, mb)])
113
+ if type in ('F', 'B') and stage > 0:
114
+ time = max(time, get_time(stage - 1, type, mb) + c)
115
+ if type in ('f', 'b') and stage + 1< len(schedule):
116
+ time = max(time, get_time(stage + 1, type, mb) + c)
117
+ # print(f'{stage} {type}:{mb}', time + cost[type])
118
+ time_map[(stage, type, mb)] = time + cost[type]
119
+ return time_map[(stage, type, mb)]
120
+ r = 0
121
+ for sid, stage in enumerate(schedule):
122
+ r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
123
+ r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
124
+ return r
125
+
126
+ def get_pattern_str(pos):
127
+ pattern = [" "] * pattern_size
128
+ notations = "FfBbWw"
129
+ for i, v in enumerate(pos):
130
+ if v < 0:
131
+ continue
132
+ pattern[v] = notations[i]
133
+ _str = ""
134
+ for v in pattern:
135
+ _str += v
136
+ return _str
137
+
138
+
139
+ def get_peak_mem(schedules, return_all=False):
140
+ max_peak = 0
141
+ all_peak = []
142
+ for schedule_ in schedules:
143
+ peak, mem = 0, 0
144
+ for v in schedule_:
145
+ if v in "Ff":
146
+ mem += 1
147
+ elif v in "Ww":
148
+ mem -= 1
149
+ peak = max(peak, mem)
150
+ all_peak.append(peak)
151
+ max_peak = max(max_peak, peak)
152
+ if return_all:
153
+ return all_peak
154
+ return max_peak
155
+
156
+ debug = False
157
+ def print_schedules(schedules):
158
+ if not debug:
159
+ return
160
+ for seq in schedules:
161
+ _str = ""
162
+ for v in seq:
163
+ _str += v
164
+ print(_str)
165
+
166
+
167
+ def calc_bubble(schedules):
168
+ stage_bubbles = []
169
+ for i in range(len(schedules)):
170
+ max_len = 0
171
+ count = 0
172
+ for j in range(len(schedules[i])):
173
+ if schedules[i][j] != ' ':
174
+ max_len = j + 1
175
+ count += 1
176
+ stage_bubbles.append(max_len - count - i)
177
+ return stage_bubbles
178
+
179
+
180
+ def init_repeated_schedule(p, m, patterns):
181
+ repeated = []
182
+ _len = 4 * p + m + 1
183
+ for i in range(p):
184
+ str_i = get_pattern_str(patterns[i]) * _len
185
+ repeated_i = []
186
+ for v in str_i:
187
+ repeated_i.append(v)
188
+ repeated.append(repeated_i)
189
+ return repeated
190
+
191
+
192
+ def clear_invalid(repeated, stage, pos, offset=-1):
193
+ while 0 <= pos < len(repeated[stage]):
194
+ repeated[stage][pos] = ' '
195
+ pos += offset * pattern_size
196
+ return repeated
197
+
198
+
199
+ def clear_invalid_index(repeated, m):
200
+ p = len(repeated)
201
+ index = pattern_size
202
+ for identifier in ['F', 'f', 'B', 'b']:
203
+ if identifier in ['F', 'B']:
204
+ _iter = range(p)
205
+ else:
206
+ _iter = range(p - 1, -1, -1)
207
+ for i in _iter:
208
+ for j in range(pattern_size):
209
+ if repeated[i][index] == identifier:
210
+ clear_invalid(repeated, i, index - pattern_size, offset=-1)
211
+ clear_invalid(repeated, i, index + pattern_size * m, offset=1)
212
+ index += 1
213
+ if identifier in ['B', 'b']:
214
+ w_identifier = {'B': 'W', 'b': 'w'}[identifier]
215
+ for k in range(pattern_size):
216
+ if repeated[i][index + k] == w_identifier:
217
+ clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
218
+ clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
219
+ break
220
+ break
221
+ index += 1
222
+ return repeated
223
+
224
+
225
+ def process_warmup_without_increasing_peak_mem(schedules, m):
226
+ """
227
+ FFFFFFFFFF fBWfBWfBWfBWfBW b
228
+ FFFFFFFFF f fBWfBWfBWfBWFBWb
229
+ FFFFFFFF f f fBWfBWfBWFBW b
230
+ FFFFFFF f f f fBWfBWFBW Bb
231
+ FFFFFF f f f f fBWFBWFBWb
232
+ FFFFFfFf f f f BWFBW b
233
+ FFFfFfFfFf f BW Bb
234
+ FfFfFfFfFfF BWb
235
+ We reorganize the warmup phase in the following way (i -> pipeline stage from 0):
236
+ 1. Before the first B, we set #f = min(i+1, peak_mem//2), #F = peak_mem - #f
237
+ 2. Before the first b, #f = peak_mem//2
238
+ 3. The offset between the first B is 1
239
+ 4. Before the first b, we use the pattern of (BWf)*j + (BWF)*k,
240
+ where j = max(0, peak_mem//2 - (i+1)), k = max(0, #W - j - 1)
241
+ """
242
+ # process warmup phase (before the first b)
243
+ p = len(schedules)
244
+ peak_mem = get_peak_mem(schedules)
245
+ peak_mem = min(peak_mem, 2 * p)
246
+ cnt_f, cnt_ff = [], []
247
+ for i in range(p):
248
+ cc_ff = min(i + 1, peak_mem // 2)
249
+ cc_ff = min(cc_ff, m)
250
+ cc_f = min(peak_mem - cc_ff, m)
251
+ cnt_f.append(cc_f)
252
+ cnt_ff.append(cc_ff)
253
+ distance_b2bb = 0
254
+ for j in range(len(schedules[p - 1])):
255
+ if schedules[p - 1][j] == 'B':
256
+ for k in range(j, len(schedules[p - 1])):
257
+ if schedules[p - 1][k] == 'b':
258
+ distance_b2bb = k - j
259
+ break
260
+ break
261
+ for i in range(p):
262
+ c_f, c_ff, c_b, c_w = 0, 0, 0, 0
263
+ for j in range(len(schedules[i])):
264
+ char = schedules[i][j]
265
+ if char == 'F':
266
+ c_f += 1
267
+ elif char == 'f':
268
+ c_ff += 1
269
+ elif char == 'B':
270
+ c_b += 1
271
+ elif char == 'W':
272
+ c_w += 1
273
+ elif char == 'b':
274
+ bj = j
275
+ while j < len(schedules[i]):
276
+ char = schedules[i][j]
277
+ if char == 'f' and c_ff < cnt_ff[p - 1]:
278
+ schedules[i][j] = ' '
279
+ c_ff += 1
280
+ if char == 'B' and c_b < c_ff:
281
+ if c_b < (2 * (p - i) + distance_b2bb) // 3 or c_b < cnt_ff[p - 1] - cnt_ff[i]:
282
+ # there is empty space, or the number of B is not enough to cover extra f
283
+ schedules[i][j] = ' '
284
+ c_b += 1
285
+ if char == 'W' and c_w < c_b:
286
+ if c_w < (2 * (p - i) + distance_b2bb - 1) // 3 or c_w < cnt_ff[p - 1] - cnt_ff[i]:
287
+ # there is empty space, or the number of W is not enough to cover extra f
288
+ schedules[i][j] = ' '
289
+ c_w += 1
290
+ j += 1
291
+ j = bj
292
+ while j < len(schedules[i]):
293
+ if schedules[i][j] == 'F':
294
+ if c_f < c_ff or c_f < cnt_f[i] or c_f - cnt_f[i] + c_ff - cnt_ff[i] < c_w - 1:
295
+ # put enough F, or there are some unused BW
296
+ schedules[i][j] = ' '
297
+ c_f += 1
298
+ j += 1
299
+ break
300
+ else:
301
+ assert char == ' '
302
+ schedules[i][j] = ' '
303
+ assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
304
+ assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
305
+ j = i
306
+ u_f, u_ff, u_b, u_w = 0, 0, 0, 0
307
+ for _ in range(2 * (p - 1 - i)):
308
+ if u_f < cnt_f[i] and u_f < c_f:
309
+ schedules[i][j] = 'F'
310
+ u_f += 1
311
+ j += 1
312
+ for _ in range(i + 1):
313
+ if u_f < cnt_f[i] and u_f < c_f:
314
+ schedules[i][j] = 'F'
315
+ u_f += 1
316
+ j += 1
317
+ if u_ff < cnt_ff[i] and u_ff < c_ff:
318
+ schedules[i][j] = 'f'
319
+ u_ff += 1
320
+ j += 1
321
+ while u_f < c_f or u_ff < c_ff or u_b < c_b or u_w < c_w:
322
+ if u_b < c_b:
323
+ schedules[i][j] = 'B'
324
+ u_b += 1
325
+ j += 1
326
+ if u_w < c_w:
327
+ schedules[i][j] = 'W'
328
+ u_w += 1
329
+ j += 1
330
+ if u_ff < c_ff:
331
+ assert u_ff < u_f
332
+ schedules[i][j] = 'f'
333
+ u_ff += 1
334
+ elif u_f < c_f:
335
+ schedules[i][j] = 'F'
336
+ u_f += 1
337
+ j += 1
338
+ return schedules
339
+
340
+
341
+
342
+ def squeeze_without_change_order(schedules, m):
343
+ p = len(schedules)
344
+ squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
345
+ max_len = 0
346
+ for seq in squeezed:
347
+ assert max_len == 0 or max_len == len(seq)
348
+ max_len = max(max_len, len(seq))
349
+
350
+ identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
351
+ identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
352
+ stage_index = [0 for _ in range(p)]
353
+ for j in range(max_len):
354
+ for _dir in range(2):
355
+ if _dir == 0:
356
+ _iter = range(p)
357
+ else:
358
+ _iter = range(p - 1, -1, -1)
359
+ for i in _iter:
360
+ identifier = schedules[i][j]
361
+ if identifier == ' ':
362
+ continue
363
+ if _dir == 0 and identifier in "fbw":
364
+ continue
365
+ if _dir == 1 and identifier in "FBW":
366
+ continue
367
+ _cnt = identifier_cnt[i][identifier]
368
+ assert _cnt < m, "{} - {}, {}".format(i, identifier, _cnt)
369
+ if identifier in "Ww" or (i == 0 and identifier in "FB") or (i == p - 1 and identifier in "fb"):
370
+ if i == 0 and identifier == 'B':
371
+ assert identifier_index[_cnt * p + i]['f'] >= 0
372
+ if i == p - 1 and identifier == 'f':
373
+ assert identifier_index[_cnt * p + i]['F'] >= 0
374
+ if i == p - 1 and identifier == 'b':
375
+ assert identifier_index[_cnt * p + i]['B'] >= 0
376
+ index = stage_index[i]
377
+ elif identifier in "FB":
378
+ assert identifier_index[_cnt * p + i - 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
379
+ index = max(identifier_index[_cnt * p + i - 1][identifier] + 1, stage_index[i])
380
+ elif identifier in "fb":
381
+ assert identifier_index[_cnt * p + i + 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
382
+ index = max(identifier_index[_cnt * p + i + 1][identifier] + 1, stage_index[i])
383
+ else:
384
+ raise
385
+ squeezed[i][index] = identifier
386
+ identifier_cnt[i][identifier] += 1
387
+ identifier_index[_cnt * p + i][identifier] = index
388
+ stage_index[i] = index + 1
389
+ return squeezed
390
+
391
+
392
+ def process_cooldown(schedules, m):
393
+ """
394
+ fBW bwbwbwbw
395
+ fBWBW bwbwbwbw
396
+ fBWBWBW bwbwbwbw
397
+ fBWBWBWBW bwbwbwbw
398
+ f BWBWBWBbWbwbwbww
399
+ f BWBWBbBbWbWbwwww
400
+ f BWBbBbBbWbWWwwww
401
+ f BbBbBbBbWWWWwwww
402
+ We reorganize the cooldown phase in the following way (i -> pipeline stage from 0):
403
+ 1. After the last f, we set #b = (peak_mem+1)//2, and #B = min(i+1, peak_mem - #b)
404
+ 2. After the last f, we make all the dependencies as tight as possible
405
+ """
406
+ p = len(schedules)
407
+
408
+ peak_mem = get_peak_mem(schedules)
409
+ assert peak_mem <= 2 * p
410
+ max_bb = (peak_mem + 1) // 2
411
+ max_bb = min(max_bb, m)
412
+ max_b = min(peak_mem - max_bb, m)
413
+
414
+ # 1: reorganize B/b and remove W/w in cooldown phase
415
+ starting_index = -1
416
+ for i in range(p):
417
+ c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
418
+ last_ff_index = -1
419
+ # collect B/b which can be reorganized
420
+ for j in range(len(schedules[i]) - 1, -1, -1):
421
+ char = schedules[i][j]
422
+ if char == 'f' and last_ff_index == -1:
423
+ last_ff_index = j
424
+ if char == 'B' and c_b < i + 1 and c_b < max_b:
425
+ schedules[i][j] = ' '
426
+ c_b += 1
427
+ if char == 'b' and c_bb < max_bb:
428
+ schedules[i][j] = ' '
429
+ c_bb += 1
430
+ # clear W in the tail (#W + #w = peak_mem)
431
+ for j in range(len(schedules[i]) - 1, -1, -1):
432
+ char = schedules[i][j]
433
+ if char == 'W' and c_w + c_ww < peak_mem:
434
+ schedules[i][j] = ' '
435
+ c_w += 1
436
+ if char == 'w' and c_w + c_ww < peak_mem:
437
+ schedules[i][j] = ' '
438
+ c_ww += 1
439
+ if i == 0:
440
+ starting_index = last_ff_index
441
+ # reorganize B/b in the tail
442
+ for k in range(c_bb):
443
+ index = starting_index - i + 2 * p - 2 * k
444
+ assert schedules[i][index] == ' ', "{} {} {}".format(schedules[i][index], k, i)
445
+ schedules[i][index] = 'b'
446
+ for k in range(c_b):
447
+ index = starting_index + 1 + i - 2 * k
448
+ assert schedules[i][index] == ' ', schedules[i][index]
449
+ schedules[i][index] = 'B'
450
+
451
+ # 2: squeeze cooldown phase without change order
452
+ schedules = squeeze_without_change_order(schedules, m)
453
+
454
+ # 3: add W back in cooldown phase
455
+ for i in range(p):
456
+ c_w, c_ww = 0, 0
457
+ last_w_index = -2
458
+ for j in range(len(schedules[i]) - 1, -1, -1):
459
+ if schedules[i][j] in "Ww":
460
+ if last_w_index < 0:
461
+ schedules[i][j] = ' '
462
+ last_w_index += 1
463
+ else:
464
+ last_w_index = j
465
+ break
466
+ for j in range(len(schedules[i])):
467
+ char = schedules[i][j]
468
+ if char == 'B':
469
+ c_w += 1
470
+ elif char == 'b':
471
+ c_ww += 1
472
+ elif char == 'W':
473
+ c_w -= 1
474
+ elif char == 'w':
475
+ c_ww -= 1
476
+ if char == ' ' and j > last_w_index:
477
+ if c_w > 0:
478
+ schedules[i][j] = 'W'
479
+ c_w -= 1
480
+ elif c_ww > 0:
481
+ schedules[i][j] = 'w'
482
+ c_ww -= 1
483
+
484
+ schedules = squeeze_without_change_order(schedules, m)
485
+ return schedules
486
+
487
+
488
+ def schedule_by_pattern(p, m, patterns):
489
+ schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
490
+ schedules = clear_invalid_index(schedules, max(m, 2 * p))
491
+ print_schedules(schedules)
492
+ init_peak_mem = get_peak_mem(schedules)
493
+ if init_peak_mem > 2 * p:
494
+ return None, init_peak_mem, [6 * max(m, 2 * p)] * p
495
+ schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
496
+
497
+ for sid in range(len(schedules)):
498
+ cnt = {_id: 0 for _id in "FfBbWw"}
499
+ for i in range(len(schedules[sid])):
500
+ if(schedules[sid][i] == ' '):
501
+ continue
502
+ if cnt[schedules[sid][i]] >= m:
503
+ schedules[sid][i] = ' '
504
+ else:
505
+ cnt[schedules[sid][i]] += 1
506
+ print_schedules(schedules)
507
+ peak_mem = get_peak_mem(schedules)
508
+ if peak_mem > init_peak_mem:
509
+ return None, init_peak_mem, [6 * m] * p
510
+
511
+ schedules = squeeze_without_change_order(schedules, m)
512
+ print_schedules(schedules)
513
+
514
+ schedules = process_cooldown(schedules, m)
515
+ print_schedules(schedules)
516
+ peak_mem = get_peak_mem(schedules)
517
+ if peak_mem > init_peak_mem:
518
+ return None, init_peak_mem, [6 * m] * p
519
+
520
+ stage_bubbles = calc_bubble(schedules)
521
+ return schedules, peak_mem, stage_bubbles
522
+
523
+
524
+ def fill_w_in_pattern(pattern):
525
+ f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
526
+ vis = [False] * pattern_size
527
+ for v in pattern:
528
+ if v >= 0:
529
+ vis[v] = True
530
+ assert pattern[b] >= 0 and pattern[bb] >= 0
531
+ for v, vw in [(b, w), (bb, ww)]:
532
+ for j in range(pattern_size):
533
+ pos = (pattern[v] + j) % pattern_size
534
+ if not vis[pos]:
535
+ pattern[vw] = pos
536
+ vis[pos] = True
537
+ break
538
+ return pattern
539
+
540
+
541
+ def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
542
+ whole_pattern = [pattern_0]
543
+ for i in range(p - 1):
544
+ last_pattern = whole_pattern[i]
545
+ new_pattern = [-1] * pattern_size
546
+ vis = [False] * pattern_size
547
+ if i < len_0:
548
+ offset = offset_0
549
+ else:
550
+ offset = offset_1
551
+ for v, v_o in enumerate(offset):
552
+ pos = (last_pattern[v] + v_o + pattern_size) % pattern_size
553
+ assert 0 <= pos < pattern_size
554
+ if vis[pos]:
555
+ return None
556
+ vis[pos] = True
557
+ new_pattern[v] = pos
558
+ new_pattern = fill_w_in_pattern(new_pattern)
559
+ whole_pattern.append(new_pattern)
560
+ return whole_pattern
561
+
562
+
563
+
564
+ def schedule(p, m, cost, max_mem):
565
+ f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
566
+ available_patterns = []
567
+ for ff_i in range(1, pattern_size):
568
+ for b_i in range(1, pattern_size):
569
+ for bb_i in range(1, pattern_size):
570
+ if ff_i == b_i or ff_i == bb_i or b_i == bb_i:
571
+ continue
572
+ pattern = [0, ff_i, b_i, bb_i, -1, -1]
573
+ pattern = fill_w_in_pattern(pattern)
574
+ available_patterns.append(pattern)
575
+ available_offsets = []
576
+ for f_o in range(1, pattern_size + 1):
577
+ for ff_o in range(1, pattern_size + 1):
578
+ for b_o in range(1, pattern_size + 1):
579
+ if f_o != b_o:
580
+ continue
581
+ bb_o = ff_o + b_o - f_o
582
+ if bb_o < 1 or bb_o > pattern_size:
583
+ continue
584
+ if bb_o + ff_o + b_o + f_o > 2 * pattern_size:
585
+ continue
586
+ # if bb_o + ff_o + b_o + f_o != 6:
587
+ # continue
588
+ offset = [f_o, - ff_o, b_o, - bb_o]
589
+ if min(ff_o, bb_o) > 1:
590
+ continue
591
+ available_offsets.append(offset)
592
+
593
+ print(available_offsets, len(available_patterns))
594
+ available_offsets = [
595
+ [1, -1, 1, -1],
596
+ [2, -1, 2, -1],
597
+ [3, -1, 3, -1],
598
+ [4, -1, 4, -1],
599
+ [5, -1, 5, -1]
600
+ ]
601
+
602
+ best_schedule = None
603
+ best_bubble = None
604
+ peak_mem2min_bubble = {}
605
+ for pattern_0 in available_patterns:
606
+ for i_0 in range(len(available_offsets)):
607
+ for i_1 in range(i_0 + 1):
608
+ for len_0 in range(1, p):
609
+ offset_0 = available_offsets[i_0]
610
+ offset_1 = available_offsets[i_1]
611
+ whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
612
+ if whole_pattern is None:
613
+ continue
614
+ # for pattern in whole_pattern:
615
+ # print(get_pattern_str(pattern))
616
+ # print(offset)
617
+ s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern)
618
+ if s is None:
619
+ continue
620
+ if peak_mem > 2 * p or peak_mem > max_mem:
621
+ continue
622
+ max_bubble = max(bubbles)
623
+ max_bubble = evaluate_schedule(s, *cost)
624
+ if best_schedule is None or max_bubble < best_bubble:
625
+ best_schedule, best_bubble = s, max_bubble
626
+ res = transform_schedule(best_schedule, *cost)
627
+ return res
app.py CHANGED
@@ -1,15 +1,12 @@
1
  import gradio as gr
2
- import auto_schedule
3
- import v_schedule
4
  import hand_schedule
 
 
 
 
5
  from PIL import Image
6
  from svg_event import render_manual_graph
7
  import pathlib
8
- def greet(name, is_morning, temperature):
9
- salutation = "Good morning" if is_morning else "Good evening"
10
- greeting = f"{salutation} {name}. It is {temperature} degrees today"
11
- celsius = (temperature - 32) * 5 / 9
12
- return greeting, round(celsius, 2)
13
 
14
  def percentage(x):
15
  return f"{x*100:.2f}%"
@@ -25,6 +22,26 @@ def get_schedule_time(result):
25
  )
26
  return time
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  img_queue = []
29
  def get_schedule_image(result, max_time):
30
  result = [
@@ -41,80 +58,87 @@ def get_schedule_image(result, max_time):
41
 
42
 
43
  def calculate(p, m, f, b, w, c, mem):
44
- if mem < p:
45
- baseline_time=None
46
- baseline_bubble=None
47
- baseline_acceleration=None
48
- baseline_image=None
49
- baseline_result=None
50
- else:
51
- baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
52
- baseline_result = [
53
- list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
54
- ]
55
- baseline_time = get_schedule_time(baseline_result)
56
- baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
57
- baseline_acceleration=percentage(0)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
61
- cost_f=f,
62
- cost_b=b,
63
- cost_w=w,
64
- cost_comm=c,
65
- max_mem=mem * 2,
66
- print_scaling=1000
67
- ))
68
-
69
- zb_time=get_schedule_time(zb_result)
70
-
71
- zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
72
- zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
73
-
74
- if mem < p:
75
- zbv_time=None
76
- zbv_bubble=None
77
- zbv_acceleration=None
78
- zbv_image=None
79
- zbv_result=None
80
- else:
81
- zbv_graph = v_schedule.PipelineGraph(
82
- n_stage=p,
83
- n_micro=m,
84
- f_cost=f/2,
85
- b_cost=b/2,
86
- w_cost=w/2,
87
- c_cost=c,
88
- f_mem=2,
89
- b_mem=-1,
90
- w_mem=-1,
91
- max_mem=mem * 4,
92
- )
93
- zbv_result = zbv_graph.get_v_schedule()
94
-
95
- zbv_time = get_schedule_time(zbv_result)
96
- zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
97
- zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
98
-
99
- max_time = max(filter(lambda x: x is not None, [baseline_time, zb_time, zbv_time]))
100
  print(max_time)
101
  if baseline_result is not None:
102
  baseline_image = get_schedule_image(baseline_result, max_time)
103
- if zb_result is not None:
104
- zb_image = get_schedule_image(zb_result, max_time)
105
- if zbv_result is not None:
106
- zbv_image = get_schedule_image(zbv_result, max_time)
 
 
 
 
107
 
108
- return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
 
 
 
 
109
 
110
  with gr.Blocks() as demo:
111
  gr.Markdown(open("description1.md").read())
112
  gr.Markdown("# Pipeline Scheduler Playground")
113
  presets = {
114
- 'Ideal Case 1p': (4, 12, 20, 20, 20, 0, '1p (Same as 1F1B)'),
115
- 'Ideal Case 2p': (4, 12, 20, 20, 20, 0, '2p'),
116
- 'Real Case 1p': (4, 12, 1049, 1122, 903, 79, '1p (Same as 1F1B)'),
117
- 'Real Case 2p': (4, 12, 1049, 1122, 903, 79, '2p'),
118
  }
119
  preset_buttons = {}
120
 
@@ -129,25 +153,31 @@ with gr.Blocks() as demo:
129
  with gr.Group():
130
  gr.Markdown("Basic Parameters")
131
  with gr.Row():
132
- p=gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
133
  m=gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
134
  with gr.Column(scale=2):
135
  with gr.Group():
136
- gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
137
  with gr.Row():
138
- f=gr.Number(label="Time of F", value=100, interactive=True, precision=0)
139
- b=gr.Number(label="Time of B", value=110, interactive=True, precision=0)
140
- w=gr.Number(label="Time of W", value=90, interactive=True, precision=0)
141
- c=gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
142
  with gr.Group():
143
  gr.Markdown("Activation memory limit.")
144
  def update_mem(p, s, mem):
145
  print("update")
146
- if s=="custom":
147
  return mem
148
- return int(p*float(s.split('p')[0]) + 0.5)
149
- memsel=gr.Radio(choices=["1p (Same as 1F1B)", "1.5p", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
150
- mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
 
 
 
 
 
 
151
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
152
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
153
 
@@ -157,31 +187,53 @@ with gr.Blocks() as demo:
157
  gr.Markdown("1F1B")
158
  with gr.Row():
159
  with gr.Column(scale=1):
160
- baseline_time=gr.Textbox("", label="Longest Stage Time")
161
- baseline_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
162
  baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
 
 
163
  with gr.Column(scale=4):
164
  baseline_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
165
-
166
  with gr.Group():
167
- gr.Markdown("ZB Schedule")
 
 
 
 
 
 
 
 
 
 
168
  with gr.Row():
169
  with gr.Column(scale=1):
170
- zb_time=gr.Textbox("", label="Longest Stage Time")
171
- zb_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
172
- zb_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
173
  with gr.Column(scale=4):
174
- zb_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
175
  with gr.Group():
176
- gr.Markdown("ZBV Schedule")
177
  with gr.Row():
178
  with gr.Column(scale=1):
179
- zbv_time=gr.Textbox("", label="Longest Stage Time")
180
- zbv_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
181
- zbv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
182
  with gr.Column(scale=4):
183
- zbv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
184
- button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
 
 
 
 
 
 
 
 
 
 
 
 
 
185
 
186
  for (k, v) in presets.items():
187
  def update_preset(pb, p, m, f, b, w, c, mem):
@@ -192,6 +244,10 @@ with gr.Blocks() as demo:
192
  preset_buttons[k].click(
193
  update_preset,
194
  inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
195
- outputs=[p, m, f, b, w, c, memsel, baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
196
- gr.Markdown(open("description2.md").read())
 
 
 
 
197
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import hand_schedule
3
+ import adaptive_schedule
4
+ import interleaved_variant
5
+ import type2
6
+ import schedule1f1bv
7
  from PIL import Image
8
  from svg_event import render_manual_graph
9
  import pathlib
 
 
 
 
 
10
 
11
  def percentage(x):
12
  return f"{x*100:.2f}%"
 
22
  )
23
  return time
24
 
25
+
26
+ def get_memory_usage(result):
27
+ max_mem = 0
28
+ has_w = False
29
+ for r in result:
30
+ for x in r:
31
+ if x.type in ('W', 'w'):
32
+ has_w = True
33
+ for r in result:
34
+ cur = 0
35
+ for x in r:
36
+ if x.type in ('F', 'f'):
37
+ cur += 1
38
+ if x.type in ('W', 'w'):
39
+ cur -= 1
40
+ if has_w == False and x.type in ('B', 'b'):
41
+ cur -= 1
42
+ max_mem = max(max_mem, cur)
43
+ return max_mem
44
+
45
  img_queue = []
46
  def get_schedule_image(result, max_time):
47
  result = [
 
58
 
59
 
60
  def calculate(p, m, f, b, w, c, mem):
61
+ baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
62
+ baseline_result = [
63
+ list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
64
+ ]
65
+ baseline_time = get_schedule_time(baseline_result)
66
+ baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
67
+ baseline_mem = get_memory_usage(baseline_result)
68
+ baseline_acceleration=percentage(0)
 
 
 
 
 
 
69
 
70
+ adapt_result = adaptive_schedule.schedule(
71
+ p,
72
+ m,
73
+ [f/2, b/2, w/2, c],
74
+ max_mem=mem * 2
75
+ )
76
+
77
+ adapt_time = get_schedule_time(adapt_result)
78
+ adapt_mem = get_memory_usage(adapt_result) / 2
79
+ adapt_bubble=percentage(adapt_time/(f+b+w)/m - 1)
80
+ adapt_acceleration=percentage(baseline_time/adapt_time - 1) if baseline_time is not None else None
81
+
82
+ schedule1f1bv_result = schedule1f1bv.schedule(
83
+ p,
84
+ m,
85
+ [f / 2, b / 2, w / 2, c]
86
+ )
87
+
88
+ schedule1f1bv_time = get_schedule_time(schedule1f1bv_result)
89
+ schedule1f1bv_mem = get_memory_usage(schedule1f1bv_result) / 2
90
+ schedule1f1bv_bubble=percentage(schedule1f1bv_time/(f+b+w)/m - 1)
91
+ schedule1f1bv_acceleration=percentage(baseline_time/schedule1f1bv_time - 1) if baseline_time is not None else None
92
+
93
+ type2_result = type2.schedule(
94
+ p,
95
+ m,
96
+ [f, b, w, c]
97
+ )
98
+
99
+ type2_time = get_schedule_time(type2_result)
100
+ type2_mem = get_memory_usage(type2_result)
101
+ type2_bubble=percentage(type2_time/(f+b+w)/m - 1)
102
+ type2_acceleration=percentage(baseline_time/type2_time - 1) if baseline_time is not None else None
103
+
104
+ interleaved_result = interleaved_variant.get_interleaved_variation(
105
+ p,
106
+ m,
107
+ [f/2, b/2, w/2, c]
108
+ )
109
+
110
+ interleaved_time = get_schedule_time(interleaved_result)
111
+ interleaved_mem = get_memory_usage(interleaved_result) / 2
112
+ interleaved_bubble=percentage(interleaved_time/(f+b+w)/m - 1)
113
+ interleaved_acceleration=percentage(baseline_time/interleaved_time - 1) if baseline_time is not None else None
114
 
115
+
116
+ max_time = max(filter(lambda x: x is not None, [baseline_time, adapt_time, interleaved_time, type2_time, schedule1f1bv_time]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  print(max_time)
118
  if baseline_result is not None:
119
  baseline_image = get_schedule_image(baseline_result, max_time)
120
+ if adapt_result is not None:
121
+ adapt_image = get_schedule_image(adapt_result, max_time)
122
+ if interleaved_result is not None:
123
+ interleaved_image = get_schedule_image(interleaved_result, max_time)
124
+ if type2_result is not None:
125
+ type2_image = get_schedule_image(type2_result, max_time)
126
+ if schedule1f1bv_result is not None:
127
+ schedule1f1bv_image = get_schedule_image(schedule1f1bv_result, max_time)
128
 
129
+ return [baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
130
+ adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
131
+ schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
132
+ type2_acceleration, type2_mem, type2_bubble, type2_image,
133
+ interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image]
134
 
135
  with gr.Blocks() as demo:
136
  gr.Markdown(open("description1.md").read())
137
  gr.Markdown("# Pipeline Scheduler Playground")
138
  presets = {
139
+ 'Real Case': (6, 12, 1049, 1122, 903, 79, 'V-Half'),
140
+ 'Ideal Case': (6, 12, 20, 20, 20, 0, 'V-Min'),
141
+ 'Zero Bubble Case': (6, 12, 1049, 1122, 903, 79, 'V-ZB')
 
142
  }
143
  preset_buttons = {}
144
 
 
153
  with gr.Group():
154
  gr.Markdown("Basic Parameters")
155
  with gr.Row():
156
+ p=gr.Number(label="Number of stages (p)", value=6, interactive=True, precision=0)
157
  m=gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
158
  with gr.Column(scale=2):
159
  with gr.Group():
160
+ gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
161
  with gr.Row():
162
+ f=gr.Number(label="Time of F", value=1049, interactive=True, precision=0)
163
+ b=gr.Number(label="Time of B", value=1122, interactive=True, precision=0)
164
+ w=gr.Number(label="Time of W", value=903, interactive=True, precision=0)
165
+ c=gr.Number(label="Time of one P2P communication", value=79, interactive=True, precision=0)
166
  with gr.Group():
167
  gr.Markdown("Activation memory limit.")
168
  def update_mem(p, s, mem):
169
  print("update")
170
+ if s == "custom":
171
  return mem
172
+ if s == "V-Min":
173
+ return (p + 4) // 3
174
+ if s == "V-Half":
175
+ return (p + 2) // 2
176
+ if s == "V-ZB":
177
+ return p
178
+ assert False
179
+ memsel=gr.Radio(choices=["V-Min", "V-Half", "V-ZB", "custom"], value="V-Half")
180
+ mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
181
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
182
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
183
 
 
187
  gr.Markdown("1F1B")
188
  with gr.Row():
189
  with gr.Column(scale=1):
 
 
190
  baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
191
+ baseline_mem=gr.Textbox("", label="Maximum memory usage")
192
+ baseline_bubble=gr.Textbox("", label="Bubble Rate")
193
  with gr.Column(scale=4):
194
  baseline_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
 
195
  with gr.Group():
196
+ gr.Markdown("Adaptive Scheduler")
197
+ with gr.Row():
198
+ with gr.Column(scale=1):
199
+ adapt_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
200
+ adapt_mem=gr.Textbox("", label="Maximum memory usage")
201
+ adapt_bubble=gr.Textbox("", label="Bubble Rate")
202
+ with gr.Column(scale=4):
203
+ adapt_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
204
+ gr.Markdown(open("description2.md").read())
205
+ with gr.Group():
206
+ gr.Markdown("1F1B-V Schedule")
207
  with gr.Row():
208
  with gr.Column(scale=1):
209
+ schedule1f1bv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
210
+ schedule1f1bv_mem=gr.Textbox("", label="Maximum memory usage")
211
+ schedule1f1bv_bubble=gr.Textbox("", label="Bubble Rate")
212
  with gr.Column(scale=4):
213
+ schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
214
  with gr.Group():
215
+ gr.Markdown("Two microbatch in one building block schedule")
216
  with gr.Row():
217
  with gr.Column(scale=1):
218
+ type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
219
+ type2_mem=gr.Textbox("", label="Maximum memory usage")
220
+ type2_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
221
  with gr.Column(scale=4):
222
+ type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
223
+ with gr.Group():
224
+ gr.Markdown("Interleaved 1F1B Schedule")
225
+ with gr.Row():
226
+ with gr.Column(scale=1):
227
+ interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
228
+ interleaved_mem=gr.Textbox("", label="Maximum memory usage")
229
+ interleaved_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
230
+ with gr.Column(scale=4):
231
+ interleaved_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
232
+ button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
233
+ adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
234
+ schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
235
+ type2_acceleration, type2_mem, type2_bubble, type2_image,
236
+ interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
237
 
238
  for (k, v) in presets.items():
239
  def update_preset(pb, p, m, f, b, w, c, mem):
 
244
  preset_buttons[k].click(
245
  update_preset,
246
  inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
247
+ outputs=[p, m, f, b, w, c, memsel,
248
+ baseline_acceleration, baseline_mem, baseline_bubble, baseline_image,
249
+ adapt_acceleration, adapt_mem, adapt_bubble, adapt_image,
250
+ schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
251
+ type2_acceleration, type2_mem, type2_bubble, type2_image,
252
+ interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
253
  demo.launch()
auto_schedule.py DELETED
@@ -1,564 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import List, Set
3
-
4
-
5
-
6
- @dataclass
7
- class GraphConfig:
8
- mem_f: float = 2
9
- mem_b: float = -1
10
- mem_w: float = -1
11
- max_mem: float = None
12
- cost_f: int = 1
13
- cost_b: int = 1
14
- cost_w: int = 1
15
- cost_comm: int = 0
16
- print_scaling: int = 1
17
-
18
- def __post_init__(self):
19
- assert type(self.cost_f) is int
20
- assert type(self.cost_b) is int
21
- assert type(self.cost_w) is int
22
- assert type(self.cost_comm) is int
23
- assert self.mem_f + self.mem_b + self.mem_w == 0
24
-
25
- @dataclass(eq=True, frozen=True)
26
- class ScheduledNode:
27
- type: str
28
- stage: int
29
- minibatch: int
30
- start_time: int
31
- completion_time: int
32
- rollback: bool = False
33
-
34
-
35
- @dataclass
36
- class Graph:
37
- nstages: int
38
- nmb: int
39
- nnodes: int
40
- config: GraphConfig
41
- parents: List[Set[int]] = None
42
- name: List[str] = None
43
-
44
- # ID mapping:
45
- # F[stage][minibatch]: 0..STAGE* MB
46
- # B[stage][minibatch]: STAGE* MB .. 2 * STAGE * MB
47
- # W[stage][minibatch]: 2 * STAGE* MB .. 3 * STAGE * MB
48
-
49
- def get_id(self, type, stage, mb):
50
- return type * (self.nstages * self.nmb) + stage * self.nmb + mb
51
-
52
- def get_stage(self, id):
53
- return (id // self.nmb) % self.nstages
54
-
55
- def get_cost(self, id):
56
- type = id // (self.nstages * self.nmb)
57
- return [self.config.cost_f, self.config.cost_b, self.config.cost_w][type]
58
-
59
- def get_mem(self, id):
60
- type = id // (self.nstages * self.nmb)
61
- return [self.config.mem_f, self.config.mem_b, self.config.mem_w][type]
62
-
63
- @classmethod
64
- def build_graph(cls, nstages, nmb, config):
65
- nnodes = nstages * nmb * 3
66
- g = Graph(nstages=nstages, nmb=nmb, nnodes=nnodes, config=config)
67
- parents = []
68
- name = []
69
- for type in range(3):
70
- for stage in range(nstages):
71
- for mb in range(nmb):
72
- p = set()
73
- if type == 0:
74
- name.append(f'F{mb}')
75
- if stage > 0:
76
- p.add(g.get_id(type, stage - 1, mb))
77
- if mb > 0:
78
- p.add(g.get_id(type, stage, mb - 1))
79
- elif type == 1:
80
- name.append(f'B{mb}')
81
- if stage == nstages - 1:
82
- p.add(g.get_id(0, stage, mb))
83
- else:
84
- p.add(g.get_id(type, stage + 1, mb))
85
- if mb > 0:
86
- p.add(g.get_id(type, stage, mb - 1))
87
- elif type == 2:
88
- name.append(f'W{mb}')
89
- p.add(g.get_id(1, stage, mb))
90
- if mb > 0:
91
- p.add(g.get_id(type, stage, mb - 1))
92
- else:
93
- assert False
94
- parents.append(p)
95
-
96
- g.name = name
97
- g.parents = parents
98
- return g
99
-
100
- # Manual ordering producing this kind of schedule:
101
- # fffffffbfbfbfbfbfbwbwbwbwbwbwbwwwwww
102
- # fffffbfbfbfbfbfbfbfbwbwbwbwbwwwwwwww
103
- # fffbfbfbfbfbfbfbfbfbfbwbwbwwwwwwwwww
104
- # fbfbfbfbfbfbfbfbfbfbfbfbwwwwwwwwwwww
105
- # Returns the order index of each node on its own stage
106
- def manual_order(
107
- self, allow_bubble_before_first_b=False, prioritize_b=False, no_bubble_greedy=True
108
- ):
109
- order = [0] * self.nnodes
110
- f = [0] * self.nstages
111
- b = [0] * self.nstages
112
- w = [0] * self.nstages
113
- o = [0] * self.nstages
114
- m = [0] * self.nstages
115
- e = [0] * self.nstages
116
- t = [0] * self.nnodes
117
- max_mem = self.config.max_mem or self.get_mem(self.get_id(0, 0, 0)) * self.nmb * 3
118
- comm = self.config.cost_comm
119
- order_str = [""] * self.nstages
120
- stage_bubble = [0] * self.nstages
121
-
122
- def get_max_bubble():
123
- max_bubble = 0
124
- for bb in stage_bubble:
125
- max_bubble = max(max_bubble, bb)
126
- return max_bubble
127
-
128
- def put(stage_j, type_k):
129
- if type_k == 0:
130
- _i = f[stage_j]
131
- elif type_k == 1:
132
- _i = b[stage_j]
133
- else:
134
- _i = w[stage_j]
135
- _j = stage_j
136
- _id = self.get_id(type_k, _j, _i)
137
- _mem = self.get_mem(_id)
138
- _cost = self.get_cost(_id)
139
- assert m[_j] + _mem <= max_mem
140
-
141
- tmp = e[_j] + _cost
142
- no_bubble = tmp
143
- if _j > 0 and type_k == 0:
144
- tmp = max(tmp, t[self.get_id(0, _j - 1, _i)] + comm + _cost)
145
- if _j < self.nstages - 1 and type_k == 1:
146
- tmp = max(tmp, t[self.get_id(1, _j + 1, _i)] + comm + _cost)
147
- if f[_j] > 0:
148
- stage_bubble[_j] += tmp - no_bubble
149
- e[_j] = tmp
150
- t[_id] = tmp
151
- m[_j] += _mem
152
- order[_id] = o[_j]
153
- if type_k == 0:
154
- f[_j] += 1
155
- elif type_k == 1:
156
- b[_j] += 1
157
- else:
158
- w[_j] += 1
159
- o[_j] += 1
160
- fbw = "fbw"
161
- order_str[stage_j] += fbw[type_k]
162
-
163
- for i in range(self.nmb):
164
- if i == 0:
165
- for j in range(self.nstages):
166
- put(j, 0)
167
- f_required = [0] * self.nstages
168
- last_t = 0
169
- for j in range(self.nstages - 1, -1, -1):
170
- if j == self.nstages - 1:
171
- last_t = t[self.get_id(0, j, i)] + self.get_cost(self.get_id(1, j, i))
172
- continue
173
- mem = m[j]
174
- cost = e[j]
175
- while True:
176
- f_id = self.get_id(0, j, f[j] + f_required[j])
177
- if f[j] + f_required[j] < self.nmb and mem + self.get_mem(f_id) <= max_mem:
178
- if allow_bubble_before_first_b:
179
- if cost + self.get_cost(f_id) > last_t + comm:
180
- break
181
- else:
182
- if cost >= last_t + comm:
183
- break
184
- mem += self.get_mem(f_id)
185
- cost += self.get_cost(f_id)
186
- f_required[j] += 1
187
- else:
188
- break
189
- last_t = max(cost, last_t + comm) + self.get_cost(self.get_id(1, j, i))
190
- for j in range(self.nstages):
191
- while j > 0 and f_required[j] > 0 and f_required[j] >= f_required[j - 1] and f[j] + f_required[j] < self.nmb:
192
- f_required[j] -= 1
193
- for j in range(self.nstages - 1, -1, -1):
194
- for _ in range(f_required[j]):
195
- put(j, 0)
196
- put(j, 1)
197
- continue
198
- f_required = [0] * self.nstages
199
- for j in range(self.nstages):
200
- if f[j] >= self.nmb:
201
- continue
202
- if j + 1 < self.nstages and f[j] >= f[j + 1] + 2 and prioritize_b:
203
- next_plus_fw = (
204
- e[j + 1]
205
- + self.get_cost(self.get_id(0, j + 1, f[j + 1]))
206
- + self.get_cost(self.get_id(1, j + 1, b[j + 1]))
207
- + comm
208
- )
209
- if e[j] >= next_plus_fw:
210
- continue
211
- f_id = self.get_id(0, j, f[j])
212
- f_mem = self.get_mem(f_id)
213
- w_cost, w_cnt = 0, 0
214
- mem_with_w = m[j] + f_mem
215
- while mem_with_w > max_mem and w[j] + w_cnt < b[j]:
216
- w_id = self.get_id(2, j, w[j] + w_cnt)
217
- w_cost += self.get_cost(w_id)
218
- mem_with_w += self.get_mem(w_id)
219
- w_cnt += 1
220
- if e[j] + self.get_cost(f_id) + w_cost <= next_plus_fw:
221
- f_required[j] = 1
222
- continue
223
-
224
- w_cost, w_cnt = 0, 0
225
- # mem_with_w = m[j]
226
- # while w[j] + w_cnt < b[j]:
227
- # w_id = self.get_id(2, j, w[j] + w_cnt)
228
- # w_cost += self.get_cost(w_id)
229
- # mem_with_w += self.get_mem(w_id)
230
- # w_cnt += 1
231
- # if e[j] + w_cost >= next_plus_fw:
232
- # continue
233
- if next_plus_fw - (e[j] + w_cost) <= get_max_bubble() - stage_bubble[j]:
234
- # TODO: can sample here
235
- continue
236
- f_required[j] = 1
237
- for j in range(self.nstages - 2, -1, -1):
238
- f_required[j] = min(f_required[j], f_required[j + 1])
239
- for j in range(self.nstages):
240
- if f_required[j] == 0:
241
- continue
242
- f_id = self.get_id(0, j, f[j])
243
- mem = self.get_mem(f_id)
244
- while m[j] + mem > max_mem:
245
- if w[j] >= b[j]:
246
- raise ValueError("Cannot fit memory")
247
- put(j, 2)
248
- if j > 0:
249
- while (
250
- w[j] < b[j]
251
- and e[j] + self.get_cost(self.get_id(2, j, w[j]))
252
- <= t[self.get_id(0, j - 1, f[j])] + comm
253
- ):
254
- put(j, 2)
255
- if w[j] < b[j] and e[j] < t[self.get_id(0, j - 1, f[j])] + comm:
256
- # TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(0, j - 1, f[j])] + comm
257
- if (
258
- t[self.get_id(0, j - 1, f[j])] + comm - e[j]
259
- <= get_max_bubble() - stage_bubble[j]
260
- ):
261
- # TODO: can sample here
262
- if no_bubble_greedy:
263
- put(j, 2)
264
- else:
265
- put(j, 2)
266
- put(j, 0)
267
- for j in range(self.nstages - 1, -1, -1):
268
- assert b[j] == i
269
- b_id = self.get_id(1, j, b[j])
270
- mem = self.get_mem(b_id)
271
- while m[j] + mem > max_mem:
272
- if w[j] >= b[j]:
273
- raise ValueError("Cannot fit memory")
274
- put(j, 2)
275
- if j + 1 < self.nstages:
276
- while (
277
- w[j] < b[j]
278
- and e[j] + self.get_cost(self.get_id(2, j, w[j]))
279
- <= t[self.get_id(1, j + 1, i)] + comm
280
- ):
281
- put(j, 2)
282
- if w[j] < b[j] and e[j] < t[self.get_id(1, j + 1, i)] + comm:
283
- # TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(1, j + 1, i)] + comm
284
- if (
285
- t[self.get_id(1, j + 1, i)] + comm - e[j]
286
- <= get_max_bubble() - stage_bubble[j]
287
- ):
288
- # TODO: can sample here
289
- if no_bubble_greedy:
290
- put(j, 2)
291
- else:
292
- put(j, 2)
293
- if j == 0 and f[j] == self.nmb:
294
- while w[j] < b[j]:
295
- put(j, 2)
296
- put(j, 1)
297
-
298
- for i in range(self.nstages):
299
- while w[i] < self.nmb:
300
- put(i, 2)
301
- # print(f"{' ' * i}{order_str[i]} -> {e[i]}")
302
-
303
- for i in range(self.nstages):
304
- for j in range(self.nmb):
305
- f_id = self.get_id(0, i, j)
306
- b_id = self.get_id(1, i, j)
307
- w_id = self.get_id(2, i, j)
308
- f_cost = self.get_cost(f_id)
309
- b_cost = self.get_cost(b_id)
310
- w_cost = self.get_cost(w_id)
311
- assert t[b_id] >= t[f_id] + b_cost
312
- assert t[w_id] >= t[b_id] + w_cost, f"{i}-{j}, {t[w_id]} >= {t[b_id]} + {w_cost}"
313
- if i > 0:
314
- assert t[f_id] >= t[self.get_id(0, i - 1, j)] + comm + f_cost, f"{i}-{j}"
315
- if i < self.nstages - 1:
316
- assert t[b_id] >= t[self.get_id(1, i + 1, j)] + comm + b_cost
317
-
318
- # print(order)
319
- best_time = 0
320
- for i in range(self.nstages):
321
- time_i = (
322
- t[self.get_id(2, i, self.nmb - 1)]
323
- - t[self.get_id(0, i, 0)]
324
- + self.get_cost(self.get_id(0, i, 0))
325
- )
326
- best_time = max(best_time, time_i)
327
-
328
- return order, t, best_time
329
-
330
-
331
- def initial_solution(graph):
332
- best_time, order, complete_time = None, None, None
333
- for allow_bubble_before_first_b in [True, False]:
334
- for prioritize_b in [True, False]:
335
- for no_bubble_greedy in [True, False]:
336
- order_t, complete_time_t, best_time_t = graph.manual_order(
337
- allow_bubble_before_first_b=allow_bubble_before_first_b,
338
- prioritize_b=prioritize_b,
339
- no_bubble_greedy=no_bubble_greedy,
340
- )
341
- if best_time is None or best_time_t < best_time:
342
- best_time = best_time_t
343
- order = order_t
344
- complete_time = complete_time_t
345
-
346
- print_detail(graph, complete_time)
347
- print("-" * 20, best_time, "-" * 20)
348
- return best_time, order, complete_time
349
-
350
-
351
- def print_detail(graph, F):
352
- typenames = ['F', 'B', 'W']
353
- times = []
354
- for stage in range(graph.nstages):
355
- stage_str = ['.'] * int(F[graph.get_id(2, stage, graph.nmb - 1)] / graph.config.print_scaling)
356
- for _type in range(3):
357
- for _mb in range(graph.nmb):
358
- _id = graph.get_id(_type, stage, _mb)
359
- end = int(F[_id] / graph.config.print_scaling)
360
- start = int((F[_id] - graph.get_cost(_id)) / graph.config.print_scaling)
361
- for j in range(start, end):
362
- if j == start or j == end - 1:
363
- stage_str[j] = typenames[_type]
364
- elif j == start + 1:
365
- if _mb >= 10:
366
- stage_str[j] = str(_mb // 10)
367
- else:
368
- stage_str[j] = str(_mb)
369
- elif j == start + 2 and _mb >= 10:
370
- stage_str[j] = str(_mb % 10)
371
- else:
372
- stage_str[j] = "-"
373
- _str = ""
374
- for _c in stage_str:
375
- _str += _c
376
- times.append(
377
- F[graph.get_id(2, stage, graph.nmb - 1)]
378
- - F[graph.get_id(0, stage, 0)]
379
- + graph.get_cost(graph.get_id(0, stage, 0))
380
- )
381
- print(_str)
382
- print('Longest stage time: ', max(times))
383
-
384
-
385
- def ilp_results(graph, F):
386
- typenames = ['F', 'B', 'W']
387
- local_order = []
388
- end_time = []
389
- for i in range(graph.nnodes):
390
- end_time.append(F[i])
391
- for stage in range(graph.nstages):
392
- order = []
393
- for type in range(3):
394
- for mb in range(graph.nmb):
395
- id = graph.get_id(type, stage, mb)
396
- order.append(
397
- ScheduledNode(
398
- type=typenames[type],
399
- stage=stage,
400
- minibatch=mb,
401
- start_time=end_time[id] - graph.get_cost(id),
402
- completion_time=F[id],
403
- )
404
- )
405
- local_order.append(order)
406
- # For each F/B, append a send/recv node. The timestamp of recv node is the same as send node to guarrentee a global order.
407
- comm_id = {}
408
- comm_id_counter = 0
409
- post_validation_time = 0
410
- for i in range(graph.nstages - 1, -1, -1):
411
- warmup_f_count = -1
412
- first_b_end = end_time[graph.get_id(1, i, 0)]
413
- for j in range(graph.nmb):
414
- if end_time[graph.get_id(0, i, j)] < first_b_end:
415
- warmup_f_count += 1
416
- assert warmup_f_count >= 0
417
- pv_id = warmup_f_count
418
- _id = graph.get_id(0, i, pv_id)
419
- _cost = graph.get_cost(_id)
420
- post_validation_time = max(post_validation_time, end_time[_id] - _cost - graph.config.cost_comm)
421
- # post_validation_time = 0
422
- # print(i, pv_id, post_validation_time)
423
- for it in ["RECV_", "SEND_", ""]:
424
- if i == 0 and it == "SEND_":
425
- continue
426
- if i == graph.nstages - 1 and it == "RECV_":
427
- continue
428
- # stage_ = i - 1 if it == "RECV_" else i
429
- stage_ = i
430
- local_order[stage_].append(ScheduledNode(
431
- type=it + "POST_VALIDATION",
432
- stage=stage_,
433
- minibatch=0,
434
- start_time=post_validation_time,
435
- completion_time=post_validation_time,
436
- ))
437
- comm_id[local_order[stage_][-1]] = comm_id_counter
438
- comm_id_counter += 1
439
- for stage in range(graph.nstages):
440
- for node in local_order[stage]:
441
- if node.type == 'F' and node.stage != graph.nstages - 1:
442
- local_order[stage].append(
443
- ScheduledNode(
444
- type='SEND_FORWARD',
445
- stage=stage,
446
- minibatch=node.minibatch,
447
- start_time=node.completion_time,
448
- completion_time=node.completion_time, # TODO: consider comm cost in completion time
449
- )
450
- )
451
- local_order[stage + 1].append(
452
- ScheduledNode(
453
- type='RECV_FORWARD',
454
- stage=stage + 1,
455
- minibatch=node.minibatch,
456
- start_time=node.completion_time,
457
- completion_time=node.completion_time, # TODO: consider comm cost in completion time
458
- )
459
- )
460
- comm_id[local_order[stage][-1]] = comm_id_counter
461
- comm_id[local_order[stage + 1][-1]] = comm_id_counter
462
- comm_id_counter += 1
463
- if node.type == 'B' and node.stage != 0:
464
- local_order[stage].append(
465
- ScheduledNode(
466
- type='SEND_BACKWARD',
467
- stage=stage,
468
- minibatch=node.minibatch,
469
- start_time=node.completion_time,
470
- completion_time=node.completion_time, # TODO: consider comm cost in completion time
471
- )
472
- )
473
- local_order[stage - 1].append(
474
- ScheduledNode(
475
- type='RECV_BACKWARD',
476
- stage=stage - 1,
477
- minibatch=node.minibatch,
478
- start_time=node.completion_time,
479
- completion_time=node.completion_time, # TODO: consider comm cost in completion time
480
- )
481
- )
482
- comm_id[local_order[stage][-1]] = comm_id_counter
483
- comm_id[local_order[stage - 1][-1]] = comm_id_counter
484
- comm_id_counter += 1
485
- for stage in range(graph.nstages):
486
- # For nodes with the same timestamp on the same stage, communication will be prioritized.
487
- def even_breaker(x: ScheduledNode):
488
- # Compute nodes are always delayed.
489
- if x.type in ['F', 'B', 'W']:
490
- return comm_id_counter
491
- # For comm nodes, order by their unique comm id
492
- return comm_id[x]
493
-
494
- local_order[stage] = list(sorted(
495
- local_order[stage], key=lambda x: (x.start_time, even_breaker(x))
496
- ))
497
- # If a recv with intersects with previous computation, reorder them so that recv
498
- # is executed before computation and hence can be overlapped.
499
- for i in range(len(local_order[stage])):
500
- if i > 0 and local_order[stage][i - 1].type in {'F', 'B', 'W'} and \
501
- local_order[stage][i].type.startswith('RECV') and \
502
- "POST_VALIDATION" not in local_order[stage][i].type and \
503
- local_order[stage][i].start_time <= local_order[stage][i - 1].completion_time:
504
- (local_order[stage][i], local_order[stage][i - 1]) = (local_order[stage][i - 1], local_order[stage][i])
505
- # print([(x.type, x.start_time, x.completion_time) for x in local_order[stage]])
506
-
507
- local_order_with_rollback = [[] for _ in range(graph.nstages)]
508
- for rank in range(graph.nstages):
509
- rollback_comm = set()
510
- if rank > 0:
511
- for node in local_order[rank - 1]:
512
- if node.type == "POST_VALIDATION":
513
- break
514
- if node.type == "SEND_FORWARD":
515
- rollback_comm.add(node.minibatch)
516
- for node in local_order[rank]:
517
- if node.type == "RECV_FORWARD" and node.minibatch in rollback_comm:
518
- rollback = True
519
- rollback_comm.remove(node.minibatch)
520
- else:
521
- rollback = False
522
- local_order_with_rollback[rank].append(ScheduledNode(
523
- type=node.type,
524
- stage=node.stage,
525
- minibatch=node.minibatch,
526
- start_time=node.start_time,
527
- completion_time=node.completion_time,
528
- rollback=rollback,
529
- ))
530
- assert len(rollback_comm) == 0
531
- # for node in local_order_with_rollback[rank]:
532
- # print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
533
- # print()
534
-
535
- print_detail(graph, end_time)
536
- return local_order_with_rollback
537
-
538
-
539
- def auto_schedule(nstages, nmb, config):
540
- graph = Graph.build_graph(nstages, nmb, config)
541
-
542
- best_time, order, complete_time = initial_solution(graph)
543
- return ilp_results(graph, complete_time)
544
-
545
-
546
- if __name__ == "__main__":
547
- # auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
548
- # auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=14))
549
- auto_schedule(24, 72, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=100))
550
- auto_schedule(4, 12, GraphConfig(
551
- cost_f=5478,
552
- cost_b=5806,
553
- cost_w=3534,
554
- cost_comm=200,
555
- max_mem=32,
556
- print_scaling=1000
557
- ))
558
- auto_schedule(32, 16, GraphConfig(
559
- cost_f=1,
560
- cost_b=1,
561
- cost_w=1,
562
- cost_comm=0,
563
- max_mem=64,
564
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
description1.md CHANGED
@@ -1,11 +1,5 @@
1
- # Zero Bubble Pipeline Parallelism
2
 
3
- Zero Bubble Pipeline Parallelism is a novel pipeline parallelism algorithm able to reduce the bubble of pipeline parallelism to almost zero while preserving synchronous semantics.
4
 
5
- Check out our paper at:
6
- * [Arxiv Version with ZBV](https://arxiv.org/abs/2401.10241)
7
- * [ICLR Accepted version with ZB1P and ZB2P](https://openreview.net/pdf?id=tuzTN0eIO5)
8
-
9
- Try out our implementation based on Megatron on [https://github.com/sail-sg/zero-bubble-pipeline-parallelism](https://github.com/sail-sg/zero-bubble-pipeline-parallelism)
10
-
11
- Experiments shows zero bubble pipeline parallelism can accelerate training up to 30% with a similar memory comsumption. A detailed table of experiments is coming soon.
 
1
+ # Pipeline Parallellism with Controllable Memory
2
 
3
+ Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
4
 
5
+ Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
 
 
 
 
 
 
description2.md CHANGED
@@ -1,33 +1,6 @@
1
- ## Zero Bubble Schedules
2
- The key of achieving zero bubble is to breaking a backward pass into a B pass and W pass. B on one stage will only depend on the B on its next stage, compared to depending on both B and W of in 1F1B.
3
 
4
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/8B9thyMiLgysNi_m_O3Qn.png)
5
-
6
- ### Comparision of Schedules
7
- * 1F1B
8
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/Q3yxf4BQIESQ_M7lKKlhf.png)
9
- * ZB1P
10
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/EcTFvbjfM7soUXDYyn1Xu.png)
11
- * ZB2P
12
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/8jFI_rO69BREKqiSFHIOL.png)
13
- * ZBV - Each device is assigned to exactly 2 chunks (virtual stages), where white text colors represent the first chunk and black text colors represent the second chunk. The sequence of dependencies among model chunks follows a ”V” shape pattern for both the forward and backward passes.
14
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/VRfjNVXakAU3MQK3h6OKa.png)
15
-
16
-
17
-
18
- | Comparison assuming T_F=T_B=T_W | 1F1B | ZB1P | ZB2P | ZBV (Recommended) |
19
- | ----------------------------------------------------- | ------- | -------- | ---- | --- |
20
- | Bubble Rate | (p-1)/m | (p-1)/3m | 0 | 0 |
21
- | Activation Memory <br> (Compared to 1F1B) | 1x | 1x | 2x | 1x |
22
- | Pipeline Communication Volume <br> (Compared to 1F1B) | 1x | 1x | 1x | 2x |
23
-
24
-
25
-
26
- ## Optimizer Post Validation
27
-
28
- In most practices of PP there's an all-reduce cross all pipeline stages for numerical robustness, e.g. global gradient norm for gradient clipping. INF/NAN check for mixed precision training, etc. This all-reduce breaks parallelogram and makes zero bubble impossible.
29
- Under the observation that during a stable training both the gradient clipping and INF/NAN rarely triggers, we replace the before-hand synchronizations with a post update validation.
30
-
31
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/63510eea0b94548566dad923/hRPFqaFxJ20wm2omwyKmO.png)
32
-
33
- We eagerly step the optimizers assuming the grad cliping, INF/NAN conditions are not triggered. In case an amendment to the gradient is required, a rollback will be issued and then we redo the optimizer step based on the fully reduced global state.
 
1
+ ## Alternative schedules
 
2
 
3
+ By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate few of them here below:
4
+ * 1F1B-V schedule without doing any B-W split.
5
+ * Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
6
+ * Variation of interleaved 1F1B with lower memory
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
interleaved_variant.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass(eq=True, frozen=True)
4
+ class ScheduledNode:
5
+ type: str
6
+ chunk: int
7
+ stage: int
8
+ minibatch: int
9
+ start_time: int
10
+ completion_time: int
11
+
12
+ def get_interleaved_variation(_p, _n, cost):
13
+ _f, _b, _w, _c = cost
14
+ schedule = []
15
+ local_prev = {}
16
+
17
+ f_order = []
18
+ b_order = []
19
+
20
+ left = [_n, _n]
21
+ for id in range(min(_n, _p)):
22
+ f_order.append(('F', id))
23
+ for id in range(min(_n, _p)):
24
+ f_order.append(('f', id))
25
+
26
+ left = [max(0, _n - _p), max(0, _n - _p)]
27
+
28
+ i = 0
29
+ cur = 0
30
+ for id in range(min(_n, _p)):
31
+ b_order.append(('B', id))
32
+ while left[0] > 0 or left[1] > 0:
33
+ if i >= _p and left[1 - cur] > 0:
34
+ cur = 1 - cur
35
+ if left[cur] > 0:
36
+ if cur == 0:
37
+ f_order.append(('F', _n - left[cur]))
38
+ b_order.append(('b', _n - left[cur] - _p))
39
+ else:
40
+ f_order.append(('f', _n - left[cur]))
41
+ b_order.append(('B', _n - left[cur]))
42
+ left[cur] -= 1
43
+ i += 3
44
+ for id in range(min(_n, _p)):
45
+ b_order.append(('b', _n - _p + id))
46
+
47
+ for stage in range(_p):
48
+ diff = min(_p + _p - stage, len(f_order))
49
+ stage_schedule = []
50
+ for i in range(diff):
51
+ stage_schedule.append(f_order[i])
52
+ for i in range(len(f_order) - diff):
53
+ stage_schedule.append(b_order[i])
54
+ stage_schedule.append(f_order[i + diff])
55
+ for i in range(diff):
56
+ stage_schedule.append(b_order[len(b_order) - diff + i])
57
+ for i in range(len(stage_schedule) - 1):
58
+ local_prev[(stage, *stage_schedule[i + 1])] = (stage, *stage_schedule[i])
59
+ schedule.append(stage_schedule)
60
+ # print(stage_schedule)
61
+ # return None
62
+ cost = {
63
+ 'F': _f,
64
+ 'f': _f,
65
+ 'B': _b+_w,
66
+ 'b': _b+_w
67
+ }
68
+ pred = {
69
+ 'f': 'F',
70
+ 'B': 'f',
71
+ 'b': 'B'
72
+ }
73
+
74
+ time_map = {}
75
+ def get_time(stage, type, minibatch):
76
+ if (stage, type, minibatch) in time_map:
77
+ return time_map.get((stage, type, minibatch))
78
+ time = 0
79
+ if (stage, type, minibatch) in local_prev:
80
+ time = get_time(*local_prev[(stage, type, minibatch)])
81
+ if stage > 0 and type in ('F', 'f'):
82
+ time = max(time, get_time(stage - 1, type, minibatch) + _c)
83
+ if stage == 0 and type in ('f'):
84
+ time = max(time, get_time(_p - 1, pred[type], minibatch) + _c)
85
+ if stage != _p - 1 and type in ('B', 'b'):
86
+ time = max(time, get_time(stage + 1, type, minibatch) + _c)
87
+ if stage == _p - 1 and type in ('b'):
88
+ time = max(time, get_time(0, pred[type], minibatch) + _c)
89
+ if stage == _p - 1 and type in ('B'):
90
+ time = max(time, get_time(stage, pred[type], minibatch))
91
+
92
+ time_map[(stage, type, minibatch)] = time + cost[type]
93
+ return time_map[(stage, type, minibatch)]
94
+ result = []
95
+ for sid, stage in enumerate(schedule):
96
+ result_stage = []
97
+ for type, minibatch in stage:
98
+ result_stage.append(ScheduledNode(
99
+ type.upper(),
100
+ type in ('f', 'B', 'W'),
101
+ sid,
102
+ minibatch,
103
+ get_time(sid, type, minibatch) - cost[type],
104
+ get_time(sid, type, minibatch)
105
+ ))
106
+ result.append(result_stage)
107
+ return result
schedule1f1bv.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pattern_size = 6
2
+ from collections import Counter
3
+ from dataclasses import dataclass
4
+
5
+ @dataclass(eq=True, frozen=True)
6
+ class ScheduledNode:
7
+ type: str
8
+ chunk: int
9
+ stage: int
10
+ minibatch: int
11
+ start_time: int
12
+ completion_time: int
13
+
14
+ def transform_schedule(schedule, f, b, w, c):
15
+ result = []
16
+
17
+ stage_order = []
18
+ local_prev = {}
19
+ stages = len(schedule)
20
+
21
+ for sid, stage in enumerate(schedule):
22
+ counter = Counter()
23
+ order = []
24
+ for p in stage:
25
+ if not p.strip():
26
+ continue
27
+ mb = counter.get(p, 0)
28
+ if order:
29
+ local_prev[(sid, p, mb)] = order[-1]
30
+ order.append((p, mb))
31
+ counter.update(p)
32
+ stage_order.append(order)
33
+ nmb = max(counter.values())
34
+ time_map = {}
35
+ cost = {
36
+ 'F': f,
37
+ 'B': b + w,
38
+ 'f': f,
39
+ 'b': b + w,
40
+ }
41
+ def get_time(stage, type, mb):
42
+ if (stage, type, mb) in time_map:
43
+ return time_map.get((stage, type, mb))
44
+ time = 0
45
+ if (stage, type, mb) in local_prev:
46
+ time = get_time(stage, *local_prev[(stage, type, mb)])
47
+ if type in ('F', 'B') and stage > 0:
48
+ time = max(time, get_time(stage - 1, type, mb) + c)
49
+ if type in ('f', 'b') and stage + 1< len(schedule):
50
+ time = max(time, get_time(stage + 1, type, mb) + c)
51
+ time_map[(stage, type, mb)] = time + cost[type]
52
+ return time_map[(stage, type, mb)]
53
+ r = 0
54
+ for sid, stage in enumerate(schedule):
55
+ r = max(get_time(sid, 'b', nmb - 1) - get_time(sid, 'F', 0) + f, r)
56
+
57
+ for sid, stage in enumerate(stage_order):
58
+ result_stage = []
59
+ for p, mb in stage:
60
+ result_stage.append(ScheduledNode(
61
+ p.upper(),
62
+ p in ('f', 'B', 'W'),
63
+ sid,
64
+ mb,
65
+ get_time(sid, p, mb) - cost[p],
66
+ get_time(sid, p, mb)
67
+ )
68
+ )
69
+ result.append(result_stage)
70
+ return result
71
+
72
+
73
+ def get_pattern_str(pos):
74
+ pattern = [" "] * pattern_size
75
+ notations = "FfBbWw"
76
+ for i, v in enumerate(pos):
77
+ if v < 0:
78
+ continue
79
+ pattern[v] = notations[i]
80
+ _str = ""
81
+ for v in pattern:
82
+ _str += v
83
+ return _str
84
+
85
+ def init_repeated_schedule(p, m, patterns):
86
+ repeated = []
87
+ _len = 4 * p + m + 1
88
+ for i in range(p):
89
+ str_i = get_pattern_str(patterns[i]) * _len
90
+ repeated_i = []
91
+ for v in str_i:
92
+ repeated_i.append(v)
93
+ repeated.append(repeated_i)
94
+ return repeated
95
+
96
+
97
+ def clear_invalid(repeated, stage, pos, offset=-1):
98
+ while 0 <= pos < len(repeated[stage]):
99
+ repeated[stage][pos] = ' '
100
+ pos += offset * pattern_size
101
+ return repeated
102
+
103
+
104
+ def clear_invalid_index(repeated, m):
105
+ p = len(repeated)
106
+ index = pattern_size
107
+ for identifier in ['F', 'f', 'B', 'b']:
108
+ if identifier in ['F', 'B']:
109
+ _iter = range(p)
110
+ else:
111
+ _iter = range(p - 1, -1, -1)
112
+ for i in _iter:
113
+ for j in range(pattern_size):
114
+ if repeated[i][index] == identifier:
115
+ clear_invalid(repeated, i, index - pattern_size, offset=-1)
116
+ clear_invalid(repeated, i, index + pattern_size * m, offset=1)
117
+ index += 1
118
+ if identifier in ['B', 'b']:
119
+ w_identifier = {'B': 'W', 'b': 'w'}[identifier]
120
+ for k in range(pattern_size):
121
+ if repeated[i][index + k] == w_identifier:
122
+ clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
123
+ clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
124
+ break
125
+ break
126
+ index += 1
127
+ return repeated
128
+
129
+
130
+ def process_warmup_without_increasing_peak_mem(schedules, m):
131
+ peak_mem = 0
132
+ mem = [[0 for _ in range(len(schedules[0]))] for _ in range(len(schedules))]
133
+ loc = [[{key: -1 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(m + 2)] for _ in range(len(schedules))]
134
+ cntr = [{key: 0 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(len(schedules))]
135
+ for sid in range(len(schedules)):
136
+ cur = 0
137
+ for i in range(len(schedules[sid])):
138
+ if schedules[sid][i] in ('F', 'f'):
139
+ cur += 1
140
+ if schedules[sid][i] in ('W', 'w'):
141
+ cur -= 1
142
+ mem[sid][i] = cur
143
+ peak_mem = max(peak_mem, cur)
144
+
145
+ for i in range(len(schedules[0])):
146
+ for sid in range(len(schedules)):
147
+ if schedules[sid][i] == ' ':
148
+ continue
149
+ cntr[sid][schedules[sid][i]] += 1
150
+ cnt = cntr[sid][schedules[sid][i]]
151
+ pos = -1
152
+ if cnt > 1:
153
+ pos = loc[sid][cnt - 1][schedules[sid][i]]
154
+ if schedules[sid][i] == 'W':
155
+ pos = max(pos, loc[sid][cnt]['B'])
156
+ if schedules[sid][i] == 'w':
157
+ pos = max(pos, loc[sid][cnt]['b'])
158
+ if schedules[sid][i] == 'F' and sid > 0:
159
+ pos = max(pos, loc[sid - 1][cnt]['F'])
160
+ if schedules[sid][i] == 'f':
161
+ if sid != len(schedules) - 1:
162
+ pos = max(pos, loc[sid + 1][cnt]['f'])
163
+ else :
164
+ pos = max(pos, loc[sid][cnt]['F'])
165
+ if schedules[sid][i] == 'B':
166
+ if sid != 0:
167
+ #Because B and W are always combined
168
+ pos = max(pos, loc[sid - 1][cnt]['W'])
169
+ else :
170
+ pos = max(pos, loc[sid][cnt]['f'])
171
+ if schedules[sid][i] == 'b':
172
+ if sid != len(schedules) - 1:
173
+ #Because B and W are always combined
174
+ pos = max(pos, loc[sid + 1][cnt]['w'])
175
+ else :
176
+ pos = max(pos, loc[sid][cnt]['W'])
177
+ pos += 1
178
+ while schedules[sid][pos] != ' ' and pos < i:
179
+ pos += 1
180
+ if schedules[sid][i] in ('B', 'b'):
181
+ while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
182
+ pos += 1
183
+ if pos == i:
184
+ loc[sid][cnt][schedules[sid][i]] = i
185
+ continue
186
+ if schedules[sid][i] in ('B', 'b', 'W', 'w'):
187
+ schedules[sid][pos] = schedules[sid][i]
188
+ schedules[sid][i] = ' '
189
+ if schedules[sid][pos] in ('W', 'w'):
190
+ for j in range(pos, i):
191
+ mem[sid][j] -= 1
192
+ loc[sid][cnt][schedules[sid][pos]] = pos
193
+ continue
194
+
195
+ #If F or f:
196
+
197
+ place = i
198
+ while place > pos and mem[sid][place - 1] < peak_mem:
199
+ place -= 1
200
+ while place < i and schedules[sid][place] != ' ':
201
+ place += 1
202
+ if place == i:
203
+ loc[sid][cnt][schedules[sid][i]] = i
204
+ continue
205
+ pos = place
206
+ schedules[sid][pos] = schedules[sid][i]
207
+ schedules[sid][i] = ' '
208
+ for j in range(pos, i):
209
+ mem[sid][j] += 1
210
+ loc[sid][cnt][schedules[sid][pos]] = pos
211
+ return schedules
212
+
213
+ def schedule_by_pattern(p, m, patterns):
214
+ schedules = init_repeated_schedule(p, m, patterns)
215
+ schedules = clear_invalid_index(schedules, m)
216
+
217
+ schedules = process_warmup_without_increasing_peak_mem(schedules, m)
218
+ for sid in range(len(schedules)):
219
+ cnt = {_id: 0 for _id in "FfBbWw"}
220
+ for i in range(len(schedules[sid])):
221
+ if(schedules[sid][i] == ' '):
222
+ continue
223
+ if cnt[schedules[sid][i]] >= m:
224
+ schedules[sid][i] = ' '
225
+ else:
226
+ cnt[schedules[sid][i]] += 1
227
+
228
+
229
+ return schedules
230
+
231
+ def create_whole_pattern(p):
232
+ whole_pattern = [[0 for _ in range(6)] for _ in range(p)]
233
+ now = 0
234
+ for i in range(p):
235
+ now += 1
236
+ whole_pattern[i][0] = now
237
+ for i in range(p):
238
+ now += 1
239
+ whole_pattern[p - 1 - i][1] = now
240
+ now += 1
241
+ if p % 3 == 0:
242
+ now += 3
243
+ cyc = (3 - (p + 2) % 3) % 3
244
+ for i in range(p):
245
+ whole_pattern[i][2], whole_pattern[i][4] = now, now + 1
246
+ cyc += 1
247
+ now += 2
248
+ if(cyc == 3):
249
+ cyc = 0
250
+ now += 3
251
+ for i in range(p):
252
+ whole_pattern[p - 1 - i][3], whole_pattern[p - 1 - i][5] = now, now + 1
253
+ cyc += 1
254
+ now += 2
255
+ if(cyc == 3):
256
+ cyc = 0
257
+ now += 3
258
+ for sid in range(p):
259
+ for i in range(6):
260
+ whole_pattern[sid][i] %= 6
261
+ return whole_pattern
262
+
263
+ def schedule(p, m, cost):
264
+ whole_pattern = create_whole_pattern(p)
265
+ s = schedule_by_pattern(p, m, whole_pattern)
266
+ for sid in range(len(s)):
267
+ for i in range(len(s[sid])):
268
+ if s[sid][i] in ('W', 'w'):
269
+ s[sid][i] = ' '
270
+ res = transform_schedule(s, *cost)
271
+ return res
svg_event.py CHANGED
@@ -234,7 +234,7 @@ def plot_events(ctx, events, title_text: str, canvas_info: CanvasInfo, include_w
234
  if ENABLE_BATCH_ID:
235
  minibatch = str(e["minibatch"])
236
  center = (start + end) // 2
237
- data_ctx.text(h, center, minibatch, font_scale=0.6, fill='black' if e["chunk"] == 0 else 'white')
238
  if ENABLE_BORDER:
239
  data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
240
 
 
234
  if ENABLE_BATCH_ID:
235
  minibatch = str(e["minibatch"])
236
  center = (start + end) // 2
237
+ data_ctx.text(h, center, minibatch, font_scale=0.6, fill='white' if e["chunk"] == 0 else 'black')
238
  if ENABLE_BORDER:
239
  data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
240
 
type2.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pattern_size = 6
2
+ from collections import Counter
3
+ from dataclasses import dataclass
4
+
5
+ @dataclass(eq=True, frozen=True)
6
+ class ScheduledNode:
7
+ type: str
8
+ stage: int
9
+ minibatch: int
10
+ start_time: int
11
+ completion_time: int
12
+
13
+ def transform_schedule(schedule, f, b, w, c):
14
+ result = []
15
+
16
+ stage_order = []
17
+ local_prev = {}
18
+ stages = len(schedule)
19
+
20
+ for sid, stage in enumerate(schedule):
21
+ counter = Counter()
22
+ order = []
23
+ for p in stage:
24
+ if not p.strip():
25
+ continue
26
+ mb = counter.get(p, 0)
27
+ if order:
28
+ local_prev[(sid, p, mb)] = order[-1]
29
+ order.append((p, mb))
30
+ counter.update(p)
31
+ stage_order.append(order)
32
+ nmb = max(counter.values())
33
+ time_map = {}
34
+ cost = {
35
+ 'F': f,
36
+ 'B': b,
37
+ 'W': w,
38
+ }
39
+ def get_time(stage, type, mb):
40
+ if (stage, type, mb) in time_map:
41
+ return time_map.get((stage, type, mb))
42
+ time = 0
43
+ if (stage, type, mb) in local_prev:
44
+ time = get_time(stage, *local_prev[(stage, type, mb)])
45
+ if type in ('F') and stage > 0:
46
+ time = max(time, get_time(stage - 1, type, mb) + c)
47
+ if type in ('B') and stage + 1< len(schedule):
48
+ time = max(time, get_time(stage + 1, type, mb) + c)
49
+ # print(f'{stage} {type}:{mb}', time + cost[type])
50
+ time_map[(stage, type, mb)] = time + cost[type]
51
+ return time_map[(stage, type, mb)]
52
+ r = 0
53
+ for sid, stage in enumerate(schedule):
54
+ r = max(get_time(sid, 'W', nmb - 1) - get_time(sid, 'F', 0) + f, r)
55
+
56
+ for sid, stage in enumerate(stage_order):
57
+ result_stage = []
58
+ for p, mb in stage:
59
+ result_stage.append(ScheduledNode(
60
+ p.upper(),
61
+ sid,
62
+ mb,
63
+ get_time(sid, p, mb) - cost[p],
64
+ get_time(sid, p, mb)
65
+ )
66
+ )
67
+ result.append(result_stage)
68
+ return result
69
+
70
+
71
+
72
+
73
+ def process_warmup_without_increasing_peak_mem(schedules, m):
74
+ peak_mem = 0
75
+ mem = [[0 for _ in range(len(schedules[0]))] for _ in range(len(schedules))]
76
+ loc = [[{key: -1 for key in ('F', 'B', 'W')} for _ in range(m + 2)] for _ in range(len(schedules))]
77
+ cntr = [{key: 0 for key in ('F', 'B', 'W')} for _ in range(len(schedules))]
78
+ for sid in range(len(schedules)):
79
+ cur = 0
80
+ for i in range(len(schedules[sid])):
81
+ if schedules[sid][i] in ('F'):
82
+ cur += 1
83
+ if schedules[sid][i] in ('W'):
84
+ cur -= 1
85
+ mem[sid][i] = cur
86
+ peak_mem = max(peak_mem, cur)
87
+ for i in range(len(schedules[0])):
88
+ for sid in range(len(schedules)):
89
+ if schedules[sid][i] == ' ':
90
+ continue
91
+ cntr[sid][schedules[sid][i]] += 1
92
+ cnt = cntr[sid][schedules[sid][i]]
93
+ pos = -1
94
+ if cnt > 1:
95
+ pos = loc[sid][cnt - 1][schedules[sid][i]]
96
+ if schedules[sid][i] == 'W':
97
+ pos = max(pos, loc[sid][cnt]['B'])
98
+ if schedules[sid][i] == 'F' and sid > 0:
99
+ pos = max(pos, loc[sid - 1][cnt]['F'])
100
+ if schedules[sid][i] == 'B':
101
+ if sid != len(schedules) - 1:
102
+ pos = max(pos, loc[sid + 1][cnt]['B'])
103
+ else :
104
+ pos = max(pos, loc[sid][cnt]['F'])
105
+ pos += 1
106
+ while schedules[sid][pos] != ' ' and pos < i:
107
+ pos += 1
108
+ if pos == i:
109
+ loc[sid][cnt][schedules[sid][i]] = i
110
+ continue
111
+ if schedules[sid][i] in ('B', 'W'):
112
+ schedules[sid][pos] = schedules[sid][i]
113
+ schedules[sid][i] = ' '
114
+ if schedules[sid][pos] in ('W'):
115
+ for j in range(pos, i):
116
+ mem[sid][j] -= 1
117
+ loc[sid][cnt][schedules[sid][pos]] = pos
118
+ continue
119
+
120
+ #If F:
121
+ if (sid == 0):
122
+ print(cnt, pos, i)
123
+ place = i
124
+ while place > pos and mem[sid][place - 1] < peak_mem:
125
+ place -= 1
126
+ while place < i and schedules[sid][place] != ' ':
127
+ place += 1
128
+ if place == i:
129
+ loc[sid][cnt][schedules[sid][i]] = i
130
+ continue
131
+ if (sid == 0):
132
+ print(place)
133
+ pos = place
134
+ schedules[sid][pos] = schedules[sid][i]
135
+ schedules[sid][i] = ' '
136
+ for j in range(pos, i):
137
+ mem[sid][j] += 1
138
+ loc[sid][cnt][schedules[sid][pos]] = pos
139
+ return schedules
140
+
141
+ def schedule(p, m, cost):
142
+ schedules = [[' ' for _ in range(6 * m + 2 * p + 6)] for _ in range(p)]
143
+ f_0, f_1, b_0, b_1= p-1, p+1, p, p + 2
144
+ for sid in range(p - 1, -1, -1):
145
+ for mid in range((m + 1) // 2):
146
+ if mid * 2 < m:
147
+ schedules[sid][f_0 + mid * 6], schedules[sid][b_0 + mid * 6] = 'F', 'B'
148
+ if mid * 2 + 1 < m:
149
+ schedules[sid][f_1 + mid * 6], schedules[sid][b_1 + mid * 6] = 'F', 'B'
150
+ f_0 -= 1
151
+ f_1 -= 1
152
+ b_0 += 1
153
+ b_1 += 1
154
+ cnt = 0
155
+ for i in range(len(schedules[0])):
156
+ if schedules[sid][i] == 'B':
157
+ cnt += 1
158
+ if schedules[sid][i] == ' ' and cnt > 0:
159
+ cnt -= 1
160
+ schedules[sid][i] = 'W'
161
+ schedules = process_warmup_without_increasing_peak_mem(schedules, m)
162
+ res = transform_schedule(schedules, *cost)
163
+ return res
v_schedule.py DELETED
@@ -1,474 +0,0 @@
1
- from collections import deque
2
- from dataclasses import dataclass
3
-
4
- @dataclass(eq=True, frozen=True)
5
- class ScheduledNode:
6
- type: str
7
- chunk: int
8
- stage: int
9
- minibatch: int
10
- start_time: int
11
- completion_time: int
12
- rollback: bool = False
13
-
14
-
15
- class PipelineGraph(object):
16
- def __init__(
17
- self, n_stage, n_micro, f_cost, b_cost, w_cost, c_cost,
18
- f_mem, b_mem, w_mem, max_mem=None,
19
- ):
20
- self.n_node = 6 * n_stage * n_micro
21
- self.n_stage = n_stage
22
- self.n_micro = n_micro
23
- self.f_cost = f_cost
24
- self.b_cost = b_cost
25
- self.w_cost = w_cost
26
- self.c_cost = c_cost
27
- self.f_mem = f_mem
28
- self.b_mem = b_mem
29
- self.w_mem = w_mem
30
- self.fbw_cost = [f_cost, b_cost, w_cost]
31
- self.fbw_mem = [f_mem, b_mem, w_mem]
32
- self.max_mem = max_mem or f_mem * self.n_stage * 2
33
-
34
- def get_id(self, cat, chunk, stage, micro):
35
- return cat * 2 * self.n_stage * self.n_micro + \
36
- chunk * self.n_stage * self.n_micro + \
37
- stage * self.n_micro + \
38
- micro
39
-
40
- def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
41
- count = []
42
- for i in range(self.n_stage):
43
- count.append([0] * 6)
44
-
45
- end_time = [-1] * self.n_node
46
- cur_time = [0] * self.n_stage
47
- mem = [0] * self.n_stage
48
- stage_bubble = [0] * self.n_stage
49
- pending_w = [deque() for _ in range(self.n_stage)]
50
- schedule = [[] for _ in range(self.n_stage)]
51
- stage_str = [" " * i for i in range(self.n_stage)]
52
-
53
- if approved_bubble is None:
54
- approved_bubble = [-1] * self.n_stage
55
- max_approved_bubble = max(approved_bubble)
56
-
57
- def get_max_stage_bubble(stage=-1):
58
- max_stage_bubble = 0
59
- for bb in stage_bubble:
60
- max_stage_bubble = max(max_stage_bubble, bb)
61
- if stage >= 0:
62
- max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
63
- return max_stage_bubble
64
-
65
- def put_w(stage):
66
- assert len(pending_w[stage]) > 0
67
- _, chunk_, _ = pending_w[stage].popleft()
68
- put(2, chunk_, stage)
69
-
70
- def put(cat, chunk, stage, assert_cnt=True):
71
- _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
72
- _cnt = count[stage][cat * 2 + chunk]
73
- # assert _cnt < self.n_micro
74
- if _cnt >= self.n_micro:
75
- if not assert_cnt:
76
- stage_str[stage] += " "
77
- cur_time[stage] = _tmp # TODO
78
- return
79
- assert False
80
- assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
81
- stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
82
- if cat > 0 or chunk > 0:
83
- last_id = cat * 2 + chunk - 1
84
- if cat < 2:
85
- # if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0:
86
- # print(cat, chunk, stage, _cnt)
87
- # self.print_details(end_time)
88
- assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
89
- else:
90
- assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
91
- if chunk == 1 and cat < 2:
92
- if stage < self.n_stage - 1:
93
- _fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
94
- assert end_time[_fa_id] >= 0
95
- _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
96
- if chunk == 0 and cat < 2:
97
- if stage > 0:
98
- _fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
99
- # if end_time[_fa_id] < 0:
100
- # print(cat, chunk, stage, _cnt)
101
- # self.print_details(end_time)
102
- assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
103
- _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
104
- _id = self.get_id(cat, chunk, stage, _cnt)
105
- if count[stage][0] > 0:
106
- stage_bubble[stage] += _tmp - _no_bubble
107
- end_time[_id] = _tmp
108
- cur_time[stage] = _tmp
109
- mem[stage] += self.fbw_mem[cat]
110
- # noinspection PyTypeChecker
111
- schedule[stage].append((cat, chunk, _cnt))
112
- if cat == 1:
113
- pending_w[stage].append((2, chunk, _cnt))
114
- count[stage][cat * 2 + chunk] += 1
115
-
116
- # for _ in range(2 * self.n_stage):
117
- # for i in range(self.n_stage):
118
- # if count[i][1] >= count[i][0]:
119
- # put(0, 0, i, assert_cnt=False)
120
- # continue
121
- # if i == self.n_stage - 1:
122
- # put(0, 1, i, assert_cnt=False)
123
- # continue
124
- # fa_id = self.get_id(0, 1, i + 1, count[i][1])
125
- # if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO
126
- # put(0, 1, i, assert_cnt=False)
127
- # else:
128
- # put(0, 0, i, assert_cnt=False)
129
-
130
- for i in range(self.n_stage):
131
- put(0, 0, i)
132
- for i in range(self.n_stage - 1, -1, -1):
133
- if i == self.n_stage - 1:
134
- put(0, 1, i)
135
- continue
136
- tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
137
- while mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp and count[i][0] < self.n_micro:
138
- for j in range(i + 1):
139
- put(0, 0, j)
140
- put(0, 1, i)
141
- iter_chunk_ = 0
142
- end_tmp = 0
143
- for i in range(self.n_stage):
144
- if i == 0:
145
- end_tmp = cur_time[0] + self.fbw_cost[1]
146
- continue
147
- tmp = end_tmp + self.c_cost
148
- while count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1] or count[i][1] <= count[i - 1][1] < self.n_micro:
149
- for j in range(self.n_stage - 1, i - 1, -1):
150
- if count[j][iter_chunk_] < self.n_micro:
151
- put(0, iter_chunk_, j)
152
- iter_chunk_ = 1 - iter_chunk_
153
- # while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp:
154
- # if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]:
155
- # break
156
- # for j in range(self.n_stage - 1, i - 1, -1):
157
- # if count[j][iter_chunk_] < self.n_micro:
158
- # put(0, iter_chunk_, j)
159
- # iter_chunk_ = 1 - iter_chunk_
160
- # end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1]
161
-
162
- # init_bubble = get_max_stage_bubble()
163
- # print(stage_bubble)
164
- for _ in range(2 * self.n_micro):
165
- # check mem before putting b
166
- for i in range(self.n_stage):
167
- while mem[i] + self.fbw_mem[1] > self.max_mem:
168
- assert len(pending_w[i]) > 0
169
- put_w(i)
170
- b0_ranks, b1_ranks = [], []
171
- for i in range(self.n_stage):
172
- if count[i][3] >= count[i][2]:
173
- b0_ranks.append(i)
174
- elif i == self.n_stage - 1:
175
- b1_ranks.append(i)
176
- else:
177
- fa_id = self.get_id(1, 1, i + 1, count[i][3])
178
- if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
179
- b1_ranks.append(i)
180
- else:
181
- b0_ranks.append(i)
182
- b_ranks = []
183
- # put b1
184
- for i in reversed(b1_ranks):
185
- b_ranks.append((i, 1))
186
- # put b0
187
- for i in b0_ranks:
188
- b_ranks.append((i, 0))
189
- for i, _chunk_ in b_ranks:
190
- fa_id = -1
191
- if _chunk_ == 1 and i < self.n_stage - 1:
192
- fa_id = self.get_id(1, 1, i + 1, count[i][3])
193
- if _chunk_ == 0 and i > 0:
194
- fa_id = self.get_id(1, 0, i - 1, count[i][2])
195
- while len(pending_w[i]) > 0 and fa_id >= 0 and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]:
196
- # fill the bubble
197
- put_w(i)
198
- if len(pending_w[i]) > 0 and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]:
199
- if _chunk_ == 1:
200
- put_w(i)
201
- elif fill_b:
202
- put_w(i)
203
- put(1, _chunk_, i)
204
-
205
- # put f
206
- for i in range(self.n_stage):
207
- if count[i][1] >= self.n_micro:
208
- continue
209
- put_item = None
210
- if count[i][1] >= count[i][0]:
211
- put_item = 0
212
- elif i == self.n_stage - 1:
213
- put_item = 1
214
- else:
215
- if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
216
- put_item = 1
217
- elif count[i][0] < self.n_micro:
218
- if i == 0:
219
- put_item = 0
220
- elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
221
- put_item = 0
222
- if put_item is None:
223
- continue
224
- # check mem before putting f
225
- while mem[i] + self.fbw_mem[0] > self.max_mem:
226
- assert len(pending_w[i]) > 0
227
- put_w(i)
228
- fa_id = -1
229
- if put_item == 0 and i > 0:
230
- fa_id = self.get_id(0, 0, i - 1, count[i][0])
231
- if put_item == 1 and i < self.n_stage - 1:
232
- fa_id = self.get_id(0, 1, i + 1, count[i][1])
233
- while len(pending_w[i]) > 0 and fa_id >= 0 and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]:
234
- # fill the bubble
235
- put_w(i)
236
- if len(pending_w[i]) > 0 and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]:
237
- if fill_f:
238
- put_w(i)
239
- put(0, put_item, i)
240
-
241
- for i in range(self.n_stage):
242
- while len(pending_w[i]) > 0:
243
- put_w(i)
244
-
245
- # for i in range(self.n_stage):
246
- # print(stage_str[i])
247
-
248
- max_bubble = get_max_stage_bubble()
249
- expected_time = sum(self.fbw_cost) * self.n_micro * 2
250
- bubble_rate = max_bubble / expected_time
251
- # print("%6.4f" % bubble_rate, "->", stage_bubble)
252
- if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
253
- _schedule, _end_time, _max_bubble = self.try_v_schedule(
254
- fill_f=fill_f, fill_b=fill_b,
255
- approved_bubble=stage_bubble,
256
- )
257
- if _max_bubble < max_bubble:
258
- return _schedule, _end_time, _max_bubble
259
- # print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \
260
- # (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble)
261
- return schedule, end_time, max_bubble
262
-
263
- def print_details(self, end_time, print_scaling=1):
264
- for stage in range(self.n_stage):
265
- stage_str = ['.'] * int(max(end_time) / print_scaling)
266
- for _cat in range(3):
267
- for _chunk in range(2):
268
- for _micro in range(self.n_micro):
269
- _id = self.get_id(_cat, _chunk, stage, _micro)
270
- if end_time[_id] < 0:
271
- continue
272
- end = int(end_time[_id] / print_scaling)
273
- start = int((end_time[_id] - self.fbw_cost[_cat]) / print_scaling)
274
- for j in range(start, end):
275
- if j == start or j == end - 1:
276
- stage_str[j] = "FfBbWw"[_cat * 2 + _chunk]
277
- elif j == start + 1:
278
- if _micro >= 10:
279
- stage_str[j] = str(_micro // 10)
280
- else:
281
- stage_str[j] = str(_micro)
282
- elif j == start + 2 and _micro >= 10:
283
- stage_str[j] = str(_micro % 10)
284
- else:
285
- stage_str[j] = "-"
286
- _str = ""
287
- for _c in stage_str:
288
- _str += _c
289
- print(_str)
290
-
291
- def get_v_schedule(self, only_run_time=False):
292
- schedule, end_time, max_bubble = None, None, None
293
- expected_time = sum(self.fbw_cost) * self.n_micro * 2
294
- for fill_b in [True, False]:
295
- for fill_f in [True, False]:
296
- _schedule, _end_time, _max_bubble = self.try_v_schedule(
297
- fill_b=fill_b, fill_f=fill_f
298
- )
299
- # print("")
300
- if max_bubble is None or _max_bubble < max_bubble:
301
- max_bubble = _max_bubble
302
- schedule = _schedule
303
- end_time = _end_time
304
- if only_run_time:
305
- return max_bubble + expected_time
306
- # self.print_details(end_time, print_scaling=1)
307
- bubble_rate = max_bubble / (expected_time + max_bubble)
308
- print("%2d %3d, [%5d %5d %5d %5d], %6d -> %6.4f" % \
309
- (self.n_stage, self.n_micro, *self.fbw_cost, self.c_cost, self.max_mem // self.f_mem, bubble_rate))
310
- local_order = [[] for _ in range(self.n_stage)]
311
- comm_id = {}
312
- comm_id_counter = 0
313
- post_validation_time = 0
314
- for i in range(self.n_stage - 1, -1, -1):
315
- pv_id = min(2 * (self.n_stage - 1 - i), self.n_micro - 1)
316
- post_validation_time = max(post_validation_time, end_time[self.get_id(0, 0, i, pv_id)] - self.fbw_cost[0] - self.c_cost)
317
- # post_validation_time = 0
318
- # print(i, pv_id, post_validation_time)
319
- for it in ["RECV_", "SEND_", ""]:
320
- if i == 0 and it == "SEND_":
321
- continue
322
- if i == self.n_stage - 1 and it == "RECV_":
323
- continue
324
- # stage_ = i - 1 if it == "RECV_" else i
325
- stage_ = i
326
- local_order[stage_].append(ScheduledNode(
327
- type=it + "POST_VALIDATION",
328
- chunk=0,
329
- stage=stage_,
330
- minibatch=0,
331
- start_time=post_validation_time,
332
- completion_time=post_validation_time,
333
- ))
334
- comm_id[local_order[stage_][-1]] = comm_id_counter
335
- comm_id_counter += 1
336
- for i in range(self.n_stage):
337
- for _cat_, _chunk_, _micro_ in schedule[i]:
338
- complete_time = end_time[self.get_id(_cat_, _chunk_, i, _micro_)]
339
- local_order[i].append(ScheduledNode(
340
- type="FBW"[_cat_],
341
- chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
342
- stage=i,
343
- minibatch=_micro_,
344
- start_time=complete_time - self.fbw_cost[_cat_],
345
- completion_time=complete_time,
346
- ))
347
- if _cat_ == 2: # no communication for W
348
- continue
349
- cat_str = "FORWARD" if _cat_ == 0 else "BACKWARD"
350
- def communicate(send_recv, stage_):
351
- # noinspection PyTypeChecker
352
- local_order[stage_].append(ScheduledNode(
353
- type=send_recv + cat_str,
354
- chunk=_chunk_ if _cat_ == 0 else 1 - _chunk_,
355
- stage=stage_,
356
- minibatch=_micro_,
357
- start_time=complete_time,
358
- completion_time=complete_time,
359
- ))
360
- comm_id[local_order[stage_][-1]] = comm_id_counter
361
-
362
- if _chunk_ == 1 and i > 0:
363
- communicate("SEND_", i)
364
- communicate("RECV_", i - 1)
365
- if _chunk_ == 0 and i < self.n_stage - 1:
366
- communicate("SEND_", i)
367
- communicate("RECV_", i + 1)
368
- comm_id_counter += 1
369
- for rank in range(self.n_stage):
370
- # For nodes with the same timestamp on the same stage, communication will be prioritized.
371
- def even_breaker(x: ScheduledNode):
372
- # Compute nodes are always delayed.
373
- if x.type in ['F', 'B', 'W']:
374
- return comm_id_counter
375
- # For comm nodes, order by their unique comm id
376
- return comm_id[x]
377
-
378
- local_order[rank] = list(sorted(
379
- local_order[rank],
380
- key=lambda x: (x.start_time, even_breaker(x))
381
- ))
382
- # If a recv with intersects with previous computation, reorder them so that recv
383
- # is executed before computation and hence can be overlapped.
384
- for i in range(len(local_order[rank])):
385
- if i > 0 and local_order[rank][i - 1].type in {'F', 'B', 'W'} and \
386
- local_order[rank][i].type.startswith('RECV') and \
387
- "POST_VALIDATION" not in local_order[rank][i].type and \
388
- local_order[rank][i].start_time <= local_order[rank][i - 1].completion_time:
389
- local_order[rank][i], local_order[rank][i - 1] = local_order[rank][i - 1], local_order[rank][i]
390
-
391
- local_order_with_rollback = [[] for _ in range(self.n_stage)]
392
- for rank in range(self.n_stage):
393
- rollback_comm = set()
394
- if rank > 0:
395
- for node in local_order[rank - 1]:
396
- if node.type == "POST_VALIDATION":
397
- break
398
- if node.type == "SEND_FORWARD":
399
- assert node.chunk == 0
400
- rollback_comm.add(node.minibatch)
401
- for node in local_order[rank]:
402
- if node.type == "RECV_FORWARD" and node.chunk == 0 and node.minibatch in rollback_comm:
403
- rollback = True
404
- rollback_comm.remove(node.minibatch)
405
- else:
406
- rollback = False
407
- local_order_with_rollback[rank].append(ScheduledNode(
408
- type=node.type,
409
- chunk=node.chunk,
410
- stage=node.stage,
411
- minibatch=node.minibatch,
412
- start_time=node.start_time,
413
- completion_time=node.completion_time,
414
- rollback=rollback,
415
- ))
416
- assert len(rollback_comm) == 0
417
- for node in local_order_with_rollback[rank]:
418
- print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
419
- print()
420
-
421
- return local_order_with_rollback
422
-
423
-
424
- if __name__ == '__main__':
425
- settings = [
426
- # p, n, f, b, w, c, h, a, l
427
- # (8, 24, 18522, 18086, 9337, 601, 2304, 24, 24),
428
- # (8, 32, 18513, 18086, 9331, 626, 2304, 24, 24),
429
- # (8, 64, 18546, 18097, 9321, 762, 2304, 24, 24),
430
- # (8, 24, 29718, 29444, 19927, 527, 4096, 32, 32),
431
- # (8, 32, 29802, 29428, 19530, 577, 4096, 32, 32),
432
- # (8, 64, 29935, 29621, 19388, 535, 4096, 32, 32),
433
- # (16, 48, 11347, 11248, 8132, 377, 5120, 40, 48),
434
- # (16, 64, 11307, 11254, 8101, 379, 5120, 40, 48),
435
- # (16, 128, 11325, 11308, 8109, 378, 5120, 40, 48),
436
- # (32, 96, 10419, 10207, 7715, 408, 6144, 48, 64),
437
- # (32, 128, 10408, 10204, 7703, 408, 6144, 48, 64),
438
- # (32, 256, 10402, 10248, 7698, 460, 6144, 48, 64),
439
- # (4, 8, 6, 4, 4, 1, 4096, 32, 32),
440
- # (8, 24, 29444, 29718, 19927, 527, 4096, 32, 32),
441
- # ( 8, 32, 16099, 16504, 7589, 540, 2304, 24, 16),
442
- (16, 48, 14407, 14380, 9676, 1610, 4096, 32, 32),
443
- (16, 64, 14412, 14393, 9688, 1621, 4096, 32, 32),
444
- (16, 128,14316, 14306, 9639, 1619, 4096, 32, 32),
445
- (24, 72, 6763, 6969, 5251, 755, 5120, 40, 48),
446
- (24, 96, 6783, 6984, 5259, 758, 5120, 40, 48),
447
- (24, 192, 6785, 6990, 5260, 770, 5120, 40, 48),
448
- (32, 96, 9458, 9748, 7288, 879, 6144, 48, 64),
449
- (32, 128, 9469, 9744, 7306, 892, 6144, 48, 64),
450
- (32, 256, 9447, 9644, 7193, 887, 6144, 48, 64),
451
- ]
452
- s = 1024
453
-
454
- # h, a, s = 4096, 32, 1024
455
- # cost_f, cost_b, cost_w, cost_c = 29718, 29444, 19927, 527
456
- for p, n, f, b, w, c, h, a, _ in settings:
457
- mem_f = 34 * h + 5 * a * s
458
- mem_w = - 32 * h
459
- mem_b = - mem_w - mem_f
460
- for m_offset in range(p + 1):
461
- graph = PipelineGraph(
462
- n_stage=p,
463
- n_micro=n,
464
- f_cost=f,
465
- b_cost=b,
466
- w_cost=w,
467
- c_cost=c,
468
- f_mem=mem_f,
469
- b_mem=mem_b,
470
- w_mem=mem_w,
471
- max_mem=mem_f * (p * 2 + m_offset),
472
- )
473
- graph.get_v_schedule()
474
- break