QPHutu committed on
Commit
eff388a
1 Parent(s): 07554d1

Fix corner case when m is too small

Browse files
Files changed (1) hide show
  1. adaptive_schedule.py +59 -12
adaptive_schedule.py CHANGED
@@ -345,10 +345,7 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
345
  def squeeze_without_change_order(schedules, m):
346
  p = len(schedules)
347
  squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
348
- max_len = 0
349
- for seq in squeezed:
350
- assert max_len == 0 or max_len == len(seq)
351
- max_len = max(max_len, len(seq))
352
 
353
  identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
354
  identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
@@ -389,6 +386,9 @@ def squeeze_without_change_order(schedules, m):
389
  identifier_cnt[i][identifier] = _cnt + 1
390
  identifier_index[_cnt * p + i][identifier] = index
391
  stage_index[i] = index + 1
 
 
 
392
  return squeezed
393
 
394
 
@@ -454,6 +454,7 @@ def process_cooldown(schedules, m):
454
  schedules[i][index] = 'B'
455
 
456
  # 2: add W back in cooldown phase
 
457
  for i in range(p):
458
  c_w, c_ww = 0, 0
459
  last_w_index = -1
@@ -478,12 +479,57 @@ def process_cooldown(schedules, m):
478
  elif c_ww > 0:
479
  schedules[i][j] = 'w'
480
  c_ww -= 1
 
 
 
 
 
 
 
 
481
 
482
  schedules = squeeze_without_change_order(schedules, m)
483
  return schedules
484
 
485
 
486
- def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index = None, ending_index = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  """
488
  We iterate all the cells from left to right. If a vacant cell (which means a bubble) is encountered, we try to
489
  find a computation pass to fill this bubble. We iterate all the following computation passes in the same device,
@@ -491,17 +537,15 @@ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index =
491
  to the vacant cell, and the bubble is filled.
492
  """
493
  p = len(schedules)
494
- max_len = 0
495
- for seq in schedules:
496
- assert max_len == 0 or max_len == len(seq)
497
- max_len = max(max_len, len(seq))
498
  if starting_index is not None:
499
  assert isinstance(starting_index, list) and len(starting_index) == p
500
  if ending_index is not None:
501
  assert isinstance(ending_index, list) and len(ending_index) == p
 
 
 
502
  starting_index = starting_index or [0] * p
503
  ending_index = ending_index or [max_len] * p
504
-
505
  last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
506
  for i in range(p):
507
  for j in range(max_len):
@@ -510,7 +554,6 @@ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index =
510
  continue
511
  last_index[i][identifier] = j
512
 
513
- peak_mem = get_peak_mem(schedules)
514
  stage_mem = [0] * p
515
  def update_mem(stage_i, pass_c):
516
  if pass_c in "Ff":
@@ -645,6 +688,7 @@ def check_correctness(schedules, m, raise_exception=False):
645
  return False
646
  return True
647
 
 
648
  def relabel_w(schedules, m):
649
  p = len(schedules)
650
  c_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
@@ -654,7 +698,7 @@ def relabel_w(schedules, m):
654
  continue
655
  c_cnt[i][schedules[i][j]] += 1
656
  for c in "FfBbWw":
657
- assert c_cnt[i][c] == m
658
  for i in range(p):
659
  w_queue = deque(maxlen=2 * m)
660
  for j in range(len(schedules[i])):
@@ -722,6 +766,8 @@ def schedule_by_building_block(p, m, building_block, max_mem, keep_stable_phase=
722
  if m < redundant_m:
723
  # 4. remove redundancy
724
  schedules = remove_redundancy(schedules, m)
 
 
725
  schedules = squeeze_without_change_order(schedules, m)
726
  print_schedules(schedules, "after removing redundancy")
727
  init_peak_mem = peak_mem = get_peak_mem(schedules)
@@ -820,6 +866,7 @@ def schedule(p, m, cost, max_mem):
820
  [4, -1, 4, -1],
821
  [5, -1, 5, -1]
822
  ]
 
823
 
824
  best_schedule = None
825
  best_bubble = None
 
345
  def squeeze_without_change_order(schedules, m):
346
  p = len(schedules)
347
  squeezed = [[' '] * len(schedules[_]) for _ in range(p)]
348
+ max_len = check_and_get_schedule_len(schedules)
 
 
 
349
 
350
  identifier_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
351
  identifier_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p * m)]
 
386
  identifier_cnt[i][identifier] = _cnt + 1
387
  identifier_index[_cnt * p + i][identifier] = index
388
  stage_index[i] = index + 1
389
+ new_len = max(stage_index)
390
+ for i in range(p):
391
+ squeezed[i] = squeezed[i][:new_len]
392
  return squeezed
393
 
394
 
 
454
  schedules[i][index] = 'B'
455
 
456
  # 2: add W back in cooldown phase
457
+ max_len = 0
458
  for i in range(p):
459
  c_w, c_ww = 0, 0
460
  last_w_index = -1
 
479
  elif c_ww > 0:
480
  schedules[i][j] = 'w'
481
  c_ww -= 1
482
+ for _ in range(c_w):
483
+ schedules[i].append('W')
484
+ for _ in range(c_ww):
485
+ schedules[i].append('w')
486
+ max_len = max(max_len, len(schedules[i]))
487
+ for i in range(p):
488
+ for _ in range(len(schedules[i]), max_len):
489
+ schedules[i].append(' ')
490
 
491
  schedules = squeeze_without_change_order(schedules, m)
492
  return schedules
493
 
494
 
495
def check_and_get_schedule_len(schedules):
    """Return the common length shared by every row of ``schedules``.

    Args:
        schedules: An iterable of sized sequences (one row per stage).

    Returns:
        The length shared by all rows, or ``0`` when ``schedules`` is empty.

    Raises:
        AssertionError: If the rows do not all have the same length.
    """
    # Collect the distinct lengths instead of using 0 as a "nothing seen yet"
    # sentinel: with the sentinel, an empty row followed by a non-empty one
    # slipped through the check while the reverse order was rejected.
    lengths = {len(seq) for seq in schedules}
    assert len(lengths) <= 1, f"inconsistent schedule lengths: {sorted(lengths)}"
    # max() keeps the historical "longest row" result if asserts are disabled.
    return max(lengths, default=0)
501
+
502
+
503
def release_w_in_warmup_if_under_memory(schedules, peak_mem=None):
    """
    Postpone leading W/w passes of each stage to the end of its row.

    Examples (before -> after):
        FF fBWfBW bwbw     ->  FF fBfBWW bwbw
        FF f fBW BW bwbw   ->  FF f fBWBW bwbw
        FF f f BW BbWbww   ->  FF f f BWBbWbww
        FfFf BbWBbwWw      ->  FfFf BbBbWwWw

    When the number of micro-batches is small relative to the memory budget,
    the warmup phase is not optimal. The earliest W/w passes of a row are
    blanked out and re-appended at the tail of that row, until the stage's
    peak plus the number of moved passes reaches ``peak_mem``.

    ``schedules`` is modified in place and also returned. Every row grows by
    ``peak_mem - min(per-stage peaks)`` cells, so all rows keep equal length.
    If ``peak_mem`` is falsy, the largest per-stage peak is used instead.
    """
    row_len = check_and_get_schedule_len(schedules)
    # Presumably one peak-memory value per stage -- confirm against get_peak_mem.
    stage_peaks = get_peak_mem(schedules, return_all=True)
    peak_mem = peak_mem or max(stage_peaks)
    tail_len = peak_mem - min(stage_peaks)

    for stage, row in enumerate(schedules):
        moved = 0
        tail = [" "] * tail_len
        for pos in range(row_len):
            # Stop once moving another W would exceed the memory budget.
            if stage_peaks[stage] + moved >= peak_mem:
                break
            cell = row[pos]
            if cell not in "Ww":
                continue
            # Blank the pass here and park it in the tail, preserving order.
            tail[moved] = cell
            row[pos] = ' '
            moved += 1
        row.extend(tail)
    return schedules
530
+
531
+
532
+ def reorder_greedily_without_increasing_peak_mem(schedules, m, starting_index = None, ending_index = None, peak_mem = None):
533
  """
534
  We iterate all the cells from left to right. If a vacant cell (which means a bubble) is encountered, we try to
535
  find a computation pass to fill this bubble. We iterate all the following computation passes in the same device,
 
537
  to the vacant cell, and the bubble is filled.
538
  """
539
  p = len(schedules)
 
 
 
 
540
  if starting_index is not None:
541
  assert isinstance(starting_index, list) and len(starting_index) == p
542
  if ending_index is not None:
543
  assert isinstance(ending_index, list) and len(ending_index) == p
544
+
545
+ peak_mem = peak_mem or get_peak_mem(schedules)
546
+ max_len = check_and_get_schedule_len(schedules)
547
  starting_index = starting_index or [0] * p
548
  ending_index = ending_index or [max_len] * p
 
549
  last_index = [{_id: -1 for _id in "FfBbWw"} for _ in range(p)]
550
  for i in range(p):
551
  for j in range(max_len):
 
554
  continue
555
  last_index[i][identifier] = j
556
 
 
557
  stage_mem = [0] * p
558
  def update_mem(stage_i, pass_c):
559
  if pass_c in "Ff":
 
688
  return False
689
  return True
690
 
691
+
692
  def relabel_w(schedules, m):
693
  p = len(schedules)
694
  c_cnt = [{_id: 0 for _id in "FfBbWw"} for _ in range(p)]
 
698
  continue
699
  c_cnt[i][schedules[i][j]] += 1
700
  for c in "FfBbWw":
701
+ assert c_cnt[i][c] == m, f"{i}, {c}, {c_cnt[i][c]}"
702
  for i in range(p):
703
  w_queue = deque(maxlen=2 * m)
704
  for j in range(len(schedules[i])):
 
766
  if m < redundant_m:
767
  # 4. remove redundancy
768
  schedules = remove_redundancy(schedules, m)
769
+ if m <= p and 2 * m <= max_mem:
770
+ schedules = release_w_in_warmup_if_under_memory(schedules, peak_mem=min(2 * p, peak_mem))
771
  schedules = squeeze_without_change_order(schedules, m)
772
  print_schedules(schedules, "after removing redundancy")
773
  init_peak_mem = peak_mem = get_peak_mem(schedules)
 
866
  [4, -1, 4, -1],
867
  [5, -1, 5, -1]
868
  ]
869
+ # available_starting_patterns = available_starting_patterns[:1]
870
 
871
  best_schedule = None
872
  best_bubble = None