QPHutu committed on
Commit
07554d1
1 Parent(s): 964a6f1

A better version

Browse files
Files changed (1) hide show
  1. adaptive_schedule.py +324 -56
adaptive_schedule.py CHANGED
@@ -1,5 +1,5 @@
1
  pattern_size = 6
2
- from collections import Counter
3
  from dataclasses import dataclass
4
 
5
  @dataclass(eq=True, frozen=True)
@@ -74,9 +74,6 @@ def transform_schedule(schedule, f, b, w, c):
74
  return result
75
 
76
 
77
-
78
-
79
-
80
  def evaluate_schedule(schedule, f, b, w, c):
81
  stage_order = []
82
  local_prev = {}
@@ -123,7 +120,21 @@ def evaluate_schedule(schedule, f, b, w, c):
123
  r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
124
  return r
125
 
126
- def get_pattern_str(pos):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  pattern = [" "] * pattern_size
128
  notations = "FfBbWw"
129
  for i, v in enumerate(pos):
@@ -167,11 +178,11 @@ def calc_bubble(schedules):
167
  return stage_bubbles
168
 
169
 
170
- def init_repeated_schedule(p, m, patterns):
171
  repeated = []
172
  _len = 4 * p + m + 1
173
  for i in range(p):
174
- str_i = get_pattern_str(patterns[i]) * _len
175
  repeated_i = []
176
  for v in str_i:
177
  repeated_i.append(v)
@@ -261,6 +272,8 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
261
  elif char == 'W':
262
  c_w += 1
263
  elif char == 'b':
 
 
264
  bj = j
265
  while j < len(schedules[i]):
266
  char = schedules[i][j]
@@ -290,8 +303,8 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
290
  else:
291
  assert char == ' '
292
  schedules[i][j] = ' '
293
- assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
294
- assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
295
  j = i
296
  u_f, u_ff, u_b, u_w = 0, 0, 0, 0
297
  for _ in range(2 * (p - 1 - i)):
@@ -365,15 +378,15 @@ def squeeze_without_change_order(schedules, m):
365
  assert identifier_index[_cnt * p + i]['B'] >= 0
366
  index = stage_index[i]
367
  elif identifier in "FB":
368
- assert identifier_index[_cnt * p + i - 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
369
  index = max(identifier_index[_cnt * p + i - 1][identifier] + 1, stage_index[i])
370
  elif identifier in "fb":
371
- assert identifier_index[_cnt * p + i + 1][identifier] >= 0, "{} {} {}".format(i,identifier,_cnt)
372
  index = max(identifier_index[_cnt * p + i + 1][identifier] + 1, stage_index[i])
373
  else:
374
  raise
375
  squeezed[i][index] = identifier
376
- identifier_cnt[i][identifier] += 1
377
  identifier_index[_cnt * p + i][identifier] = index
378
  stage_index[i] = index + 1
379
  return squeezed
@@ -396,7 +409,7 @@ def process_cooldown(schedules, m):
396
  p = len(schedules)
397
 
398
  peak_mem = get_peak_mem(schedules)
399
- assert peak_mem <= 2 * p
400
  max_bb = (peak_mem + 1) // 2
401
  max_bb = min(max_bb, m)
402
  max_b = min(peak_mem - max_bb, m)
@@ -406,7 +419,7 @@ def process_cooldown(schedules, m):
406
  for i in range(p):
407
  c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
408
  last_ff_index = -1
409
- # collect B/b which can be reorganized
410
  for j in range(len(schedules[i]) - 1, -1, -1):
411
  char = schedules[i][j]
412
  if char == 'f' and last_ff_index == -1:
@@ -417,13 +430,15 @@ def process_cooldown(schedules, m):
417
  if char == 'b' and c_bb < max_bb:
418
  schedules[i][j] = ' '
419
  c_bb += 1
420
- # clear W in the tail (#W + #w = peak_mem)
421
  for j in range(len(schedules[i]) - 1, -1, -1):
422
  char = schedules[i][j]
423
- if char == 'W' and c_w + c_ww < peak_mem:
 
 
424
  schedules[i][j] = ' '
425
  c_w += 1
426
- if char == 'w' and c_w + c_ww < peak_mem:
427
  schedules[i][j] = ' '
428
  c_ww += 1
429
  if i == 0:
@@ -435,24 +450,17 @@ def process_cooldown(schedules, m):
435
  schedules[i][index] = 'b'
436
  for k in range(c_b):
437
  index = starting_index + 1 + i - 2 * k
438
- assert schedules[i][index] == ' ', schedules[i][index]
439
  schedules[i][index] = 'B'
440
 
441
- # 2: squeeze cooldown phase without change order
442
- schedules = squeeze_without_change_order(schedules, m)
443
-
444
- # 3: add W back in cooldown phase
445
  for i in range(p):
446
  c_w, c_ww = 0, 0
447
- last_w_index = -2
448
  for j in range(len(schedules[i]) - 1, -1, -1):
449
  if schedules[i][j] in "Ww":
450
- if last_w_index < 0:
451
- schedules[i][j] = ' '
452
- last_w_index += 1
453
- else:
454
- last_w_index = j
455
- break
456
  for j in range(len(schedules[i])):
457
  char = schedules[i][j]
458
  if char == 'B':
@@ -475,38 +483,281 @@ def process_cooldown(schedules, m):
475
  return schedules
476
 
477
 
478
- def schedule_by_pattern(p, m, patterns, max_mem):
479
- schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
480
- schedules = clear_invalid_index(schedules, max(m, 2 * p))
481
- init_peak_mem = get_peak_mem(schedules)
482
- if init_peak_mem > max_mem:
483
- return None, init_peak_mem, [6 * max(m, 2 * p)] * p
484
- schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
485
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  for sid in range(len(schedules)):
487
  cnt = {_id: 0 for _id in "FfBbWw"}
488
  for i in range(len(schedules[sid])):
489
- if(schedules[sid][i] == ' '):
490
  continue
491
  if cnt[schedules[sid][i]] >= m:
492
  schedules[sid][i] = ' '
493
  else:
494
  cnt[schedules[sid][i]] += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  peak_mem = get_peak_mem(schedules)
 
 
496
  if peak_mem > init_peak_mem:
497
  return None, init_peak_mem, [6 * m] * p
498
 
499
- schedules = squeeze_without_change_order(schedules, m)
 
 
 
 
 
 
 
500
 
 
501
  schedules = process_cooldown(schedules, m)
 
 
 
 
 
 
 
 
 
 
 
502
  peak_mem = get_peak_mem(schedules)
 
 
503
  if peak_mem > init_peak_mem:
504
  return None, init_peak_mem, [6 * m] * p
 
 
 
 
505
  stage_bubbles = calc_bubble(schedules)
 
 
 
506
  return schedules, peak_mem, stage_bubbles
507
 
508
 
509
- def fill_w_in_pattern(pattern):
510
  f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
511
  vis = [False] * pattern_size
512
  for v in pattern:
@@ -523,10 +774,11 @@ def fill_w_in_pattern(pattern):
523
  return pattern
524
 
525
 
526
- def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
527
- whole_pattern = [pattern_0]
 
528
  for i in range(p - 1):
529
- last_pattern = whole_pattern[i]
530
  new_pattern = [-1] * pattern_size
531
  vis = [False] * pattern_size
532
  if i < len_0:
@@ -540,26 +792,28 @@ def get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p):
540
  return None
541
  vis[pos] = True
542
  new_pattern[v] = pos
543
- new_pattern = fill_w_in_pattern(new_pattern)
544
- whole_pattern.append(new_pattern)
545
- return whole_pattern
546
 
547
 
548
 
549
  def schedule(p, m, cost, max_mem):
550
  f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
551
- available_patterns = []
 
552
  for ff_i in range(1, pattern_size):
553
  for b_i in range(1, pattern_size):
554
  for bb_i in range(1, pattern_size):
555
  if ff_i == b_i or ff_i == bb_i or b_i == bb_i:
556
  continue
557
  pattern = [0, ff_i, b_i, bb_i, -1, -1]
558
- pattern = fill_w_in_pattern(pattern)
559
- available_patterns.append(pattern)
560
 
561
- print(len(available_patterns))
562
  available_offsets = [
 
563
  [1, -1, 1, -1],
564
  [2, -1, 2, -1],
565
  [3, -1, 3, -1],
@@ -569,23 +823,37 @@ def schedule(p, m, cost, max_mem):
569
 
570
  best_schedule = None
571
  best_bubble = None
572
- for pattern_0 in available_patterns:
 
573
  for i_0 in range(len(available_offsets)):
574
  for i_1 in range(i_0 + 1):
575
  for len_0 in range(1, p):
576
  offset_0 = available_offsets[i_0]
577
  offset_1 = available_offsets[i_1]
578
- whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
579
- if whole_pattern is None:
580
  continue
581
- s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern, min(2 * p, max_mem))
582
  if peak_mem > 2 * p or peak_mem > max_mem:
583
  break
584
  if s is None:
585
  continue
586
- max_bubble = max(bubbles)
587
  max_bubble = evaluate_schedule(s, *cost)
588
  if best_schedule is None or max_bubble < best_bubble:
589
  best_schedule, best_bubble = s, max_bubble
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
  res = transform_schedule(best_schedule, *cost)
591
- return res
 
1
  pattern_size = 6
2
+ from collections import Counter, deque
3
  from dataclasses import dataclass
4
 
5
  @dataclass(eq=True, frozen=True)
 
74
  return result
75
 
76
 
 
 
 
77
  def evaluate_schedule(schedule, f, b, w, c):
78
  stage_order = []
79
  local_prev = {}
 
120
  r = max(get_time(sid, 'w', nmb - 1) - get_time(sid, 'F', 0) + f, r)
121
  return r
122
 
123
+
124
# Module-level switch: when True, print_schedules emits output without `force`.
debug = False


def print_schedules(schedules, msg=None, force=False):
    """Print each row of ``schedules`` as a single line of text.

    Nothing is printed unless the module-level ``debug`` flag is set or
    ``force`` is true.

    Args:
        schedules: iterable of rows; every row is an iterable of strings
            which are concatenated into one printed line.
        msg: optional header line printed before the rows.
        force: print even when ``debug`` is off.
    """
    if not debug and not force:
        return
    if msg is not None:
        print(msg)
    for seq in schedules:
        # join() replaces the original character-by-character concatenation.
        print("".join(seq))
135
+
136
+
137
+ def get_building_block_str(pos):
138
  pattern = [" "] * pattern_size
139
  notations = "FfBbWw"
140
  for i, v in enumerate(pos):
 
178
  return stage_bubbles
179
 
180
 
181
+ def init_repeated_schedule(p, m, building_block):
182
  repeated = []
183
  _len = 4 * p + m + 1
184
  for i in range(p):
185
+ str_i = get_building_block_str(building_block[i]) * _len
186
  repeated_i = []
187
  for v in str_i:
188
  repeated_i.append(v)
 
272
  elif char == 'W':
273
  c_w += 1
274
  elif char == 'b':
275
+ break
276
+ # This logic can be removed because it is too complicated and should not impact the optimal solution
277
  bj = j
278
  while j < len(schedules[i]):
279
  char = schedules[i][j]
 
303
  else:
304
  assert char == ' '
305
  schedules[i][j] = ' '
306
+ # assert c_f >= cnt_f[i] and c_ff >= cnt_ff[i]
307
+ # assert c_w >= cnt_ff[p - 1] - cnt_ff[i] and c_b >= cnt_ff[p - 1] - cnt_ff[i]
308
  j = i
309
  u_f, u_ff, u_b, u_w = 0, 0, 0, 0
310
  for _ in range(2 * (p - 1 - i)):
 
378
  assert identifier_index[_cnt * p + i]['B'] >= 0
379
  index = stage_index[i]
380
  elif identifier in "FB":
381
+ assert identifier_index[_cnt * p + i - 1][identifier] >= 0, "{} {} {}".format(i, identifier,_cnt)
382
  index = max(identifier_index[_cnt * p + i - 1][identifier] + 1, stage_index[i])
383
  elif identifier in "fb":
384
+ assert identifier_index[_cnt * p + i + 1][identifier] >= 0, "{} {} {}".format(i, identifier,_cnt)
385
  index = max(identifier_index[_cnt * p + i + 1][identifier] + 1, stage_index[i])
386
  else:
387
  raise
388
  squeezed[i][index] = identifier
389
+ identifier_cnt[i][identifier] = _cnt + 1
390
  identifier_index[_cnt * p + i][identifier] = index
391
  stage_index[i] = index + 1
392
  return squeezed
 
409
  p = len(schedules)
410
 
411
  peak_mem = get_peak_mem(schedules)
412
+ assert peak_mem <= 2 * p, peak_mem
413
  max_bb = (peak_mem + 1) // 2
414
  max_bb = min(max_bb, m)
415
  max_b = min(peak_mem - max_bb, m)
 
419
  for i in range(p):
420
  c_b, c_bb, c_w, c_ww = 0, 0, 0, 0
421
  last_ff_index = -1
422
+ # collect B/b which can be reordered
423
  for j in range(len(schedules[i]) - 1, -1, -1):
424
  char = schedules[i][j]
425
  if char == 'f' and last_ff_index == -1:
 
430
  if char == 'b' and c_bb < max_bb:
431
  schedules[i][j] = ' '
432
  c_bb += 1
433
+ # clear W in the tail (#W + #w >= peak_mem & #W >= #B & #w >= #b)
434
  for j in range(len(schedules[i]) - 1, -1, -1):
435
  char = schedules[i][j]
436
+ if c_w >= c_b and c_ww >= c_bb and c_w + c_ww >= peak_mem:
437
+ break
438
+ if char == 'W':
439
  schedules[i][j] = ' '
440
  c_w += 1
441
+ if char == 'w':
442
  schedules[i][j] = ' '
443
  c_ww += 1
444
  if i == 0:
 
450
  schedules[i][index] = 'b'
451
  for k in range(c_b):
452
  index = starting_index + 1 + i - 2 * k
453
+ # assert schedules[i][index] == ' ', schedules[i][index]
454
  schedules[i][index] = 'B'
455
 
456
+ # 2: add W back in cooldown phase
 
 
 
457
  for i in range(p):
458
  c_w, c_ww = 0, 0
459
+ last_w_index = -1
460
  for j in range(len(schedules[i]) - 1, -1, -1):
461
  if schedules[i][j] in "Ww":
462
+ last_w_index = j
463
+ break
 
 
 
 
464
  for j in range(len(schedules[i])):
465
  char = schedules[i][j]
466
  if char == 'B':
 
483
  return schedules
484
 
485
 
486
def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index=None, ending_index=None):
    """Greedily fill bubbles by pulling later passes forward, in place.

    We iterate all the cells from left to right. If a vacant cell (which
    means a bubble) is encountered, we try to find a computation pass to fill
    this bubble. We iterate all the following computation passes in the same
    device, and check whether it is possible to move if we keep all other
    passes unchanged. If the check succeeds, we move it to the vacant cell,
    and the bubble is filled.

    A pass may be moved to column ``j`` only when:
      * 'F'/'B' on stage i > 0: the same-numbered pass on stage i - 1 already
        sits in a column strictly before ``j``;
      * 'B' on stage 0: more 'f' than 'B' have been placed on that stage;
      * 'f'/'b' on a stage other than the last: the same-numbered pass on
        stage i + 1 already sits in a column strictly before ``j``;
      * 'f'/'b' on the last stage: more 'F' (resp. 'B') than 'f' (resp. 'b')
        have been placed on that stage;
      * 'W'/'w': more 'B' (resp. 'b') than 'W' (resp. 'w') have been placed;
      * 'F'/'f': the running memory of the stage (+1 per 'F'/'f', -1 per
        'W'/'w') plus the peak reached between ``j`` and the candidate stays
        below the peak reported by ``get_peak_mem`` for the input.

    Args:
        schedules: list of p equally long rows (lists of single characters
            from "FfBbWw" or ' '); mutated in place.
        m: number of micro-batches; sizes the per-pass index table.
        starting_index: optional list of length p; cells of row i left of
            ``starting_index[i]`` are never filled. Defaults to 0 everywhere.
        ending_index: optional list of length p; cells of row i at or right of
            ``ending_index[i]`` are neither filled nor used as a source.
            Defaults to the row length everywhere.

    Returns:
        The same ``schedules`` object.
    """
    p = len(schedules)
    max_len = 0
    for seq in schedules:
        assert max_len == 0 or max_len == len(seq)
        max_len = max(max_len, len(seq))
    if starting_index is not None:
        assert isinstance(starting_index, list) and len(starting_index) == p
    if ending_index is not None:
        assert isinstance(ending_index, list) and len(ending_index) == p
    starting_index = starting_index or [0] * p
    ending_index = ending_index or [max_len] * p

    # Column of the last occurrence of every pass kind on every stage, taken
    # from the input layout (not updated while passes are moved).
    last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
    for i in range(p):
        for j in range(max_len):
            identifier = schedules[i][j]
            if identifier == ' ':
                continue
            last_index[i][identifier] = j

    peak_mem = get_peak_mem(schedules)
    stage_mem = [0] * p

    def update_mem(stage_i, pass_c):
        if pass_c in "Ff":
            stage_mem[stage_i] += 1
        elif pass_c in "Ww":
            stage_mem[stage_i] -= 1

    # identifier_cnt[i][c]: passes of kind c placed so far on stage i.
    # identifier_index[n * p + i][c]: column of the n-th pass of kind c on stage i.
    identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
    identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
    for j in range(max_len):
        for i in range(p):
            identifier = schedules[i][j]
            if identifier in "FfBbWw":
                _cnt = identifier_cnt[i][identifier]
                identifier_cnt[i][identifier] = _cnt + 1
                identifier_index[_cnt * p + i][identifier] = j
                update_mem(i, identifier)
                continue
            assert identifier == ' '
            if j < starting_index[i] or j >= ending_index[i]:
                continue
            # Only the first upcoming pass of each kind is a candidate.
            available = set()
            for c in "FfBbWw":
                if last_index[i][c] > j:
                    available.add(c)
            mem_delta, peak_delta = 0, 0
            for k in range(j + 1, ending_index[i]):
                if not available:
                    break
                candidate = schedules[i][k]
                if candidate in "Ff":
                    mem_delta += 1
                elif candidate in "Ww":
                    mem_delta -= 1
                # Peak reached strictly before the candidate itself.
                prev_peak = peak_delta
                peak_delta = max(peak_delta, mem_delta)
                if candidate == ' ' or candidate not in available:
                    continue
                available.remove(candidate)
                if candidate in "Ff" and stage_mem[i] + prev_peak >= peak_mem:
                    # will increase peak memory
                    continue
                can_move = True
                _cnt = identifier_cnt[i][candidate]
                if candidate in "FB":
                    if i > 0:
                        _index = identifier_index[_cnt * p + i - 1][candidate]
                        if _index <= -1 or _index >= j:
                            can_move = False
                    elif candidate == 'B':
                        if identifier_cnt[i]['f'] <= _cnt:
                            can_move = False
                elif candidate in "fb":
                    if i + 1 < p:
                        _index = identifier_index[_cnt * p + i + 1][candidate]
                        if _index <= -1 or _index >= j:
                            can_move = False
                    else:
                        _pi = 'F' if candidate == 'f' else 'B'
                        if identifier_cnt[i][_pi] <= _cnt:
                            can_move = False
                elif candidate in "Ww":
                    _bi = 'B' if candidate == 'W' else 'b'
                    if identifier_cnt[i][_bi] <= _cnt:
                        can_move = False
                else:
                    assert False
                if not can_move:
                    continue
                schedules[i][j] = candidate
                schedules[i][k] = ' '
                identifier_cnt[i][candidate] = _cnt + 1
                identifier_index[_cnt * p + i][candidate] = j
                update_mem(i, candidate)
                break
    return schedules
592
+
593
+
594
def check_correctness(schedules, m, raise_exception=False):
    """Check that every stage runs each pass m times in a dependency-safe order.

    ``schedules[i][j]`` is the pass run by stage ``i`` in slot ``j``: one of
    the characters in "FfBbWw", anything else is ignored. For the n-th pass of
    each kind the following strict slot orderings are verified:

      * on every stage: 'B' before 'W' and 'b' before 'w';
      * 'F' and 'B': stage i - 1 before stage i;
      * 'f' and 'b': stage i + 1 before stage i;
      * on stage 0: 'f' before 'B';
      * on the last stage: 'F' before 'f' and 'B' before 'b'.

    The first/last-stage rules and the neighbour rules are applied
    independently, so the boundary stages are also checked against their one
    neighbour (and a single-stage schedule gets both boundary rules).

    Args:
        schedules: list of p rows of single-character strings.
        m: expected number of passes of each kind per stage.
        raise_exception: when true, a violation raises ``AssertionError``
            instead of returning False.

    Returns:
        True when all checks pass, otherwise False. A pass kind occurring
        more than ``m`` times on a stage is reported like any other
        violation.
    """
    kinds = "FfBbWw"
    p = len(schedules)

    def fail(reason):
        # With raise_exception set, the violation surfaces as AssertionError.
        assert not raise_exception, reason
        return False

    # c_index[n * p + i][c]: slot of the n-th pass of kind c on stage i.
    c_index = [{_id: -1 for _id in kinds} for _ in range(p * m)]
    for i in range(p):
        c_cnt = {_id: 0 for _id in kinds}
        for j, c in enumerate(schedules[i]):
            if c not in c_cnt:
                continue
            _cnt = c_cnt[c]
            if _cnt >= m:
                return fail(f"stage {i}: more than {m} '{c}' passes")
            c_index[_cnt * p + i][c] = j
            c_cnt[c] = _cnt + 1
        for c in kinds:
            if c_cnt[c] != m:
                return fail(f"stage {i}: {c_cnt[c]} '{c}' passes, expected {m}")

    for i in range(p):
        for j in range(m):
            cur = c_index[j * p + i]
            for c in kinds:
                if cur[c] == -1:
                    return fail(f"{i} {j} {c}")
            if cur['B'] >= cur['W']:
                return fail(f"{i} {j} W")
            if cur['b'] >= cur['w']:
                return fail(f"{i} {j} w")
            if i > 0:
                prev = c_index[j * p + i - 1]
                if prev['F'] >= cur['F']:
                    return fail(f"{i} {j} F")
                if prev['B'] >= cur['B']:
                    return fail(f"{i} {j} B")
            elif cur['f'] >= cur['B']:
                # First stage: 'B' follows this stage's own 'f'.
                return fail(f"{i} {j} B")
            if i + 1 < p:
                nxt = c_index[j * p + i + 1]
                if nxt['f'] >= cur['f']:
                    return fail(f"{i} {j} f")
                if nxt['b'] >= cur['b']:
                    return fail(f"{i} {j} b")
            else:
                # Last stage: 'f' follows its own 'F', 'b' its own 'B'.
                if cur['F'] >= cur['f']:
                    return fail(f"{i} {j} f")
                if cur['B'] >= cur['b']:
                    return fail(f"{i} {j} b")
    return True
647
+
648
def relabel_w(schedules, m):
    """Rename the 'W'/'w' cells of each row to follow the order of 'B'/'b'.

    Each row is scanned left to right; every 'B' queues a 'W' and every 'b'
    queues a 'w'. Each cell currently holding 'W' or 'w' is overwritten with
    the oldest queued label, so the k-th weight pass of a row gets the case
    of the k-th backward pass of that row.

    Args:
        schedules: list of rows (lists of characters from "FfBbWw" or ' ');
            mutated in place.
        m: required number of occurrences of every pass kind in every row.

    Returns:
        The same ``schedules`` object.

    Raises:
        AssertionError: a row does not hold exactly ``m`` of each pass kind,
            or a 'W'/'w' cell appears before enough 'B'/'b' cells.
        KeyError: a cell holds a character other than "FfBbWw" or ' '.
    """
    # Validate every row before any row is modified.
    for row in schedules:
        c_cnt = {_id: 0 for _id in "FfBbWw"}
        for c in row:
            if c == ' ':
                continue
            c_cnt[c] += 1
        for c in "FfBbWw":
            assert c_cnt[c] == m

    for i, row in enumerate(schedules):
        w_queue = deque(maxlen=2 * m)
        for j, identifier in enumerate(row):
            if identifier == 'B':
                w_queue.append('W')
            elif identifier == 'b':
                w_queue.append('w')
            elif identifier in "Ww":
                assert len(w_queue) > 0, f"{i} {j}"
                row[j] = w_queue.popleft()
        assert len(w_queue) == 0
    return schedules
671
+
672
+
673
def remove_redundancy(schedules, m):
    """Blank every pass beyond the first ``m`` of its kind in each row.

    Rows are scanned left to right; the first ``m`` occurrences of each
    character in "FfBbWw" are kept and later occurrences are overwritten
    with ' '.

    Args:
        schedules: list of rows (lists of characters from "FfBbWw" or ' ');
            mutated in place.
        m: number of passes of each kind to keep per row.

    Returns:
        The same ``schedules`` object.

    Raises:
        KeyError: a cell holds a character other than "FfBbWw" or ' '.
    """
    for row in schedules:
        kept = {_id: 0 for _id in "FfBbWw"}
        for pos, pass_c in enumerate(row):
            if pass_c == ' ':
                continue
            if kept[pass_c] >= m:
                row[pos] = ' '
            else:
                kept[pass_c] += 1
    return schedules
684
+
685
+
686
def schedule_by_building_block(p, m, building_block, max_mem, keep_stable_phase=False):
    """Build a schedule from a building block by repeating, squeezing and reordering.

    Args:
        p: number of stages (rows of the schedule).
        m: number of micro-batches.
        building_block: per-stage patterns handed to ``init_repeated_schedule``.
        max_mem: upper bound on the peak memory of the returned schedule.
        keep_stable_phase: when true, greedy reordering is restricted to the
            cells before the second 'b' of each row (warm-up) and from the
            last 'F' of each row onwards (cool-down); otherwise one
            unrestricted greedy pass runs after the cool-down step.

    Returns:
        A tuple ``(schedules, peak_mem, stage_bubbles)``. When the building
        block is rejected, ``schedules`` is None and the third item is the
        placeholder ``[6 * m] * p``.
    """
    # Apply the framework of repeating-squeezing-reordering
    # 1. repeating
    redundant_m = max(m, 2 * p)  # we add some redundant micro-batches to avoid unexpected bugs
    schedules = init_repeated_schedule(p, redundant_m, building_block)
    schedules = clear_invalid_index(schedules, redundant_m)
    init_peak_mem = get_peak_mem(schedules)
    if (m == redundant_m and init_peak_mem > max_mem) or init_peak_mem > 2 * p:
        return None, init_peak_mem, [6 * m] * p
    print_schedules(schedules, "after repeating")

    # 2. squeezing
    schedules = squeeze_without_change_order(schedules, redundant_m)
    print_schedules(schedules, "after squeezing")

    # 3. reordering
    # 3.a. reorder warm-up
    schedules = process_warmup_without_increasing_peak_mem(schedules, redundant_m)  # must work with m >= 2p
    schedules = squeeze_without_change_order(schedules, redundant_m)
    if keep_stable_phase:
        ending_index = [0] * p  # before second b
        for i in range(p):
            bb_cnt = 0
            for j in range(len(schedules[i])):
                if schedules[i][j] == 'b':
                    bb_cnt += 1
                    if bb_cnt >= 2:
                        ending_index[i] = j
                        break
        schedules = reorder_greedily_without_increasing_peak_mem(schedules, redundant_m, ending_index=ending_index)
    peak_mem = get_peak_mem(schedules)
    if debug:
        assert peak_mem <= init_peak_mem, f"{init_peak_mem}, {peak_mem}"
    if peak_mem > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p

    if m < redundant_m:
        # 4. remove redundancy
        schedules = remove_redundancy(schedules, m)
        schedules = squeeze_without_change_order(schedules, m)
        print_schedules(schedules, "after removing redundancy")
        # The memory budget is checked against the trimmed schedule, which
        # also becomes the baseline for the cool-down step.
        init_peak_mem = peak_mem = get_peak_mem(schedules)
        if peak_mem > max_mem:
            return None, peak_mem, [6 * m] * p

    # 3.b. reorder cool-down
    schedules = process_cooldown(schedules, m)
    if keep_stable_phase:
        starting_index = [0] * p
        for i in range(p):
            for j in range(len(schedules[i])):
                if schedules[i][j] == 'F':
                    # ends up as the column of the last 'F' in the row
                    starting_index[i] = j
        schedules = reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index=starting_index)
    else:
        schedules = reorder_greedily_without_increasing_peak_mem(schedules, m)
    schedules = relabel_w(schedules, m)
    print_schedules(schedules, "after reordering")
    peak_mem = get_peak_mem(schedules)
    if debug:
        assert peak_mem <= init_peak_mem, f"{init_peak_mem}, {peak_mem}"
    if peak_mem > init_peak_mem:
        return None, init_peak_mem, [6 * m] * p

    if not check_correctness(schedules, m, raise_exception=debug):
        return None, peak_mem, [6 * m] * p
    stage_bubbles = calc_bubble(schedules)
    if debug:
        print(peak_mem, stage_bubbles)
        print("-" * 100)
    return schedules, peak_mem, stage_bubbles
758
 
759
 
760
+ def fill_w_in_building_block(pattern):
761
  f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
762
  vis = [False] * pattern_size
763
  for v in pattern:
 
774
  return pattern
775
 
776
 
777
+ def get_building_block(pattern_0, offset_0, offset_1, len_0, p):
778
+ # see Appendix A in the paper
779
+ build_block = [pattern_0]
780
  for i in range(p - 1):
781
+ last_pattern = build_block[i]
782
  new_pattern = [-1] * pattern_size
783
  vis = [False] * pattern_size
784
  if i < len_0:
 
792
  return None
793
  vis[pos] = True
794
  new_pattern[v] = pos
795
+ new_pattern = fill_w_in_building_block(new_pattern)
796
+ build_block.append(new_pattern)
797
+ return build_block
798
 
799
 
800
 
801
  def schedule(p, m, cost, max_mem):
802
  f, ff, b, bb, w, ww = 0, 1, 2, 3, 4, 5
803
+ available_starting_patterns = []
804
+ # iterate available patterns for the first row/device of a building block
805
  for ff_i in range(1, pattern_size):
806
  for b_i in range(1, pattern_size):
807
  for bb_i in range(1, pattern_size):
808
  if ff_i == b_i or ff_i == bb_i or b_i == bb_i:
809
  continue
810
  pattern = [0, ff_i, b_i, bb_i, -1, -1]
811
+ pattern = fill_w_in_building_block(pattern)
812
+ available_starting_patterns.append(pattern)
813
 
814
+ # available uniform offsets, see Section 3.1 in the paper.
815
  available_offsets = [
816
+ # [\delta_F^0, \delta_F^1, \delta_B^1, \delta_B^0]
817
  [1, -1, 1, -1],
818
  [2, -1, 2, -1],
819
  [3, -1, 3, -1],
 
823
 
824
  best_schedule = None
825
  best_bubble = None
826
+ peak_mem2min_bubble = {}
827
+ for pattern_0 in available_starting_patterns:
828
  for i_0 in range(len(available_offsets)):
829
  for i_1 in range(i_0 + 1):
830
  for len_0 in range(1, p):
831
  offset_0 = available_offsets[i_0]
832
  offset_1 = available_offsets[i_1]
833
+ build_block = get_building_block(pattern_0, offset_0, offset_1, len_0, p)
834
+ if build_block is None:
835
  continue
836
+ s, peak_mem, bubbles = schedule_by_building_block(p, m, build_block, min(2 * p, max_mem))
837
  if peak_mem > 2 * p or peak_mem > max_mem:
838
  break
839
  if s is None:
840
  continue
 
841
  max_bubble = evaluate_schedule(s, *cost)
842
  if best_schedule is None or max_bubble < best_bubble:
843
  best_schedule, best_bubble = s, max_bubble
844
+
845
+ max_bubble = max(bubbles)
846
+ min_bubble = min(peak_mem2min_bubble.get(peak_mem, max_bubble), max_bubble)
847
+ peak_mem2min_bubble[peak_mem] = min_bubble
848
+ mem2bubble = {}
849
+ for peak_mem in sorted(peak_mem2min_bubble.keys()):
850
+ bubble = peak_mem2min_bubble[peak_mem]
851
+ mem2bubble[peak_mem] = bubble
852
+ # expected_bubble = max(0, 6 * p - 1 - 3 * peak_mem)
853
+ expected_bubble = 3 * p - 1 - 3 * peak_mem + max(3 * p, p - 1 + (1+(peak_mem+1)//2)*2)
854
+ # expected_bubble = 6 * p - 1 - 3 * peak_mem
855
+ print(peak_mem, bubble, expected_bubble, "|", bubble - expected_bubble)
856
+ print(mem2bubble)
857
+
858
  res = transform_schedule(best_schedule, *cost)
859
+ return res