Nyamdavaa Amar commited on
Commit
cf49f13
·
1 Parent(s): 3d4d40d

Edit presets

Browse files
adaptive_schedule.py CHANGED
@@ -46,9 +46,9 @@ def transform_schedule(schedule, f, b, w, c):
46
  time = 0
47
  if (stage, type, mb) in local_prev:
48
  time = get_time(stage, *local_prev[(stage, type, mb)])
49
- if type in ('F', 'B') and stage > 0:
50
  time = max(time, get_time(stage - 1, type, mb) + c)
51
- if type in ('f', 'b') and stage + 1< len(schedule):
52
  time = max(time, get_time(stage + 1, type, mb) + c)
53
  # print(f'{stage} {type}:{mb}', time + cost[type])
54
  time_map[(stage, type, mb)] = time + cost[type]
@@ -63,7 +63,7 @@ def transform_schedule(schedule, f, b, w, c):
63
  for p, mb in stage:
64
  result_stage.append(ScheduledNode(
65
  p.upper(),
66
- p in ('f', 'B', 'W'),
67
  sid,
68
  mb,
69
  get_time(sid, p, mb) - cost[p],
@@ -110,9 +110,9 @@ def evaluate_schedule(schedule, f, b, w, c):
110
  time = 0
111
  if (stage, type, mb) in local_prev:
112
  time = get_time(stage, *local_prev[(stage, type, mb)])
113
- if type in ('F', 'B') and stage > 0:
114
  time = max(time, get_time(stage - 1, type, mb) + c)
115
- if type in ('f', 'b') and stage + 1< len(schedule):
116
  time = max(time, get_time(stage + 1, type, mb) + c)
117
  # print(f'{stage} {type}:{mb}', time + cost[type])
118
  time_map[(stage, type, mb)] = time + cost[type]
@@ -153,16 +153,6 @@ def get_peak_mem(schedules, return_all=False):
153
  return all_peak
154
  return max_peak
155
 
156
- debug = False
157
- def print_schedules(schedules):
158
- if not debug:
159
- return
160
- for seq in schedules:
161
- _str = ""
162
- for v in seq:
163
- _str += v
164
- print(_str)
165
-
166
 
167
  def calc_bubble(schedules):
168
  stage_bubbles = []
@@ -199,8 +189,8 @@ def clear_invalid(repeated, stage, pos, offset=-1):
199
  def clear_invalid_index(repeated, m):
200
  p = len(repeated)
201
  index = pattern_size
202
- for identifier in ['F', 'f', 'B', 'b']:
203
- if identifier in ['F', 'B']:
204
  _iter = range(p)
205
  else:
206
  _iter = range(p - 1, -1, -1)
@@ -210,7 +200,7 @@ def clear_invalid_index(repeated, m):
210
  clear_invalid(repeated, i, index - pattern_size, offset=-1)
211
  clear_invalid(repeated, i, index + pattern_size * m, offset=1)
212
  index += 1
213
- if identifier in ['B', 'b']:
214
  w_identifier = {'B': 'W', 'b': 'w'}[identifier]
215
  for k in range(pattern_size):
216
  if repeated[i][index + k] == w_identifier:
@@ -386,6 +376,17 @@ def squeeze_without_change_order(schedules, m):
386
  identifier_cnt[i][identifier] += 1
387
  identifier_index[_cnt * p + i][identifier] = index
388
  stage_index[i] = index + 1
 
 
 
 
 
 
 
 
 
 
 
389
  return squeezed
390
 
391
 
@@ -485,12 +486,11 @@ def process_cooldown(schedules, m):
485
  return schedules
486
 
487
 
488
- def schedule_by_pattern(p, m, patterns):
489
  schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
490
  schedules = clear_invalid_index(schedules, max(m, 2 * p))
491
- print_schedules(schedules)
492
  init_peak_mem = get_peak_mem(schedules)
493
- if init_peak_mem > 2 * p:
494
  return None, init_peak_mem, [6 * max(m, 2 * p)] * p
495
  schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
496
 
@@ -503,20 +503,16 @@ def schedule_by_pattern(p, m, patterns):
503
  schedules[sid][i] = ' '
504
  else:
505
  cnt[schedules[sid][i]] += 1
506
- print_schedules(schedules)
507
  peak_mem = get_peak_mem(schedules)
508
  if peak_mem > init_peak_mem:
509
  return None, init_peak_mem, [6 * m] * p
510
 
511
  schedules = squeeze_without_change_order(schedules, m)
512
- print_schedules(schedules)
513
 
514
  schedules = process_cooldown(schedules, m)
515
- print_schedules(schedules)
516
  peak_mem = get_peak_mem(schedules)
517
  if peak_mem > init_peak_mem:
518
  return None, init_peak_mem, [6 * m] * p
519
-
520
  stage_bubbles = calc_bubble(schedules)
521
  return schedules, peak_mem, stage_bubbles
522
 
@@ -572,25 +568,8 @@ def schedule(p, m, cost, max_mem):
572
  pattern = [0, ff_i, b_i, bb_i, -1, -1]
573
  pattern = fill_w_in_pattern(pattern)
574
  available_patterns.append(pattern)
575
- available_offsets = []
576
- for f_o in range(1, pattern_size + 1):
577
- for ff_o in range(1, pattern_size + 1):
578
- for b_o in range(1, pattern_size + 1):
579
- if f_o != b_o:
580
- continue
581
- bb_o = ff_o + b_o - f_o
582
- if bb_o < 1 or bb_o > pattern_size:
583
- continue
584
- if bb_o + ff_o + b_o + f_o > 2 * pattern_size:
585
- continue
586
- # if bb_o + ff_o + b_o + f_o != 6:
587
- # continue
588
- offset = [f_o, - ff_o, b_o, - bb_o]
589
- if min(ff_o, bb_o) > 1:
590
- continue
591
- available_offsets.append(offset)
592
 
593
- print(available_offsets, len(available_patterns))
594
  available_offsets = [
595
  [1, -1, 1, -1],
596
  [2, -1, 2, -1],
@@ -601,7 +580,6 @@ def schedule(p, m, cost, max_mem):
601
 
602
  best_schedule = None
603
  best_bubble = None
604
- peak_mem2min_bubble = {}
605
  for pattern_0 in available_patterns:
606
  for i_0 in range(len(available_offsets)):
607
  for i_1 in range(i_0 + 1):
@@ -611,13 +589,10 @@ def schedule(p, m, cost, max_mem):
611
  whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
612
  if whole_pattern is None:
613
  continue
614
- # for pattern in whole_pattern:
615
- # print(get_pattern_str(pattern))
616
- # print(offset)
617
- s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern)
618
- if s is None:
619
- continue
620
  if peak_mem > 2 * p or peak_mem > max_mem:
 
 
621
  continue
622
  max_bubble = max(bubbles)
623
  max_bubble = evaluate_schedule(s, *cost)
 
46
  time = 0
47
  if (stage, type, mb) in local_prev:
48
  time = get_time(stage, *local_prev[(stage, type, mb)])
49
+ if type in "FB" and stage > 0:
50
  time = max(time, get_time(stage - 1, type, mb) + c)
51
+ if type in "fb" and stage + 1< len(schedule):
52
  time = max(time, get_time(stage + 1, type, mb) + c)
53
  # print(f'{stage} {type}:{mb}', time + cost[type])
54
  time_map[(stage, type, mb)] = time + cost[type]
 
63
  for p, mb in stage:
64
  result_stage.append(ScheduledNode(
65
  p.upper(),
66
+ p in "fBW",
67
  sid,
68
  mb,
69
  get_time(sid, p, mb) - cost[p],
 
110
  time = 0
111
  if (stage, type, mb) in local_prev:
112
  time = get_time(stage, *local_prev[(stage, type, mb)])
113
+ if type in "FB" and stage > 0:
114
  time = max(time, get_time(stage - 1, type, mb) + c)
115
+ if type in "fb" and stage + 1< len(schedule):
116
  time = max(time, get_time(stage + 1, type, mb) + c)
117
  # print(f'{stage} {type}:{mb}', time + cost[type])
118
  time_map[(stage, type, mb)] = time + cost[type]
 
153
  return all_peak
154
  return max_peak
155
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  def calc_bubble(schedules):
158
  stage_bubbles = []
 
189
  def clear_invalid_index(repeated, m):
190
  p = len(repeated)
191
  index = pattern_size
192
+ for identifier in "FfBb":
193
+ if identifier in "FB":
194
  _iter = range(p)
195
  else:
196
  _iter = range(p - 1, -1, -1)
 
200
  clear_invalid(repeated, i, index - pattern_size, offset=-1)
201
  clear_invalid(repeated, i, index + pattern_size * m, offset=1)
202
  index += 1
203
+ if identifier in "Bb":
204
  w_identifier = {'B': 'W', 'b': 'w'}[identifier]
205
  for k in range(pattern_size):
206
  if repeated[i][index + k] == w_identifier:
 
376
  identifier_cnt[i][identifier] += 1
377
  identifier_index[_cnt * p + i][identifier] = index
378
  stage_index[i] = index + 1
379
+ while True:
380
+ if(len(squeezed[0]) == 1):
381
+ break
382
+ allempty = True
383
+ for x in squeezed:
384
+ if x[-1] != ' ':
385
+ allempty = False
386
+ if allempty == False:
387
+ break
388
+ for x in squeezed:
389
+ del x[-1]
390
  return squeezed
391
 
392
 
 
486
  return schedules
487
 
488
 
489
+ def schedule_by_pattern(p, m, patterns, max_mem):
490
  schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
491
  schedules = clear_invalid_index(schedules, max(m, 2 * p))
 
492
  init_peak_mem = get_peak_mem(schedules)
493
+ if init_peak_mem > max_mem:
494
  return None, init_peak_mem, [6 * max(m, 2 * p)] * p
495
  schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
496
 
 
503
  schedules[sid][i] = ' '
504
  else:
505
  cnt[schedules[sid][i]] += 1
 
506
  peak_mem = get_peak_mem(schedules)
507
  if peak_mem > init_peak_mem:
508
  return None, init_peak_mem, [6 * m] * p
509
 
510
  schedules = squeeze_without_change_order(schedules, m)
 
511
 
512
  schedules = process_cooldown(schedules, m)
 
513
  peak_mem = get_peak_mem(schedules)
514
  if peak_mem > init_peak_mem:
515
  return None, init_peak_mem, [6 * m] * p
 
516
  stage_bubbles = calc_bubble(schedules)
517
  return schedules, peak_mem, stage_bubbles
518
 
 
568
  pattern = [0, ff_i, b_i, bb_i, -1, -1]
569
  pattern = fill_w_in_pattern(pattern)
570
  available_patterns.append(pattern)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
571
 
572
+ print(len(available_patterns))
573
  available_offsets = [
574
  [1, -1, 1, -1],
575
  [2, -1, 2, -1],
 
580
 
581
  best_schedule = None
582
  best_bubble = None
 
583
  for pattern_0 in available_patterns:
584
  for i_0 in range(len(available_offsets)):
585
  for i_1 in range(i_0 + 1):
 
589
  whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
590
  if whole_pattern is None:
591
  continue
592
+ s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern, min(2 * p, max_mem))
 
 
 
 
 
593
  if peak_mem > 2 * p or peak_mem > max_mem:
594
+ break
595
+ if s is None:
596
  continue
597
  max_bubble = max(bubbles)
598
  max_bubble = evaluate_schedule(s, *cost)
app.py CHANGED
@@ -136,9 +136,10 @@ with gr.Blocks() as demo:
136
  gr.Markdown(open("description1.md").read())
137
  gr.Markdown("# Pipeline Scheduler Playground")
138
  presets = {
139
- 'Real Case': (6, 12, 1049, 1122, 903, 79, 'V-Half'),
140
- 'Ideal Case': (6, 12, 20, 20, 20, 0, 'V-Min'),
141
- 'Zero Bubble Case': (6, 12, 1049, 1122, 903, 79, 'V-ZB')
 
142
  }
143
  preset_buttons = {}
144
 
@@ -153,30 +154,30 @@ with gr.Blocks() as demo:
153
  with gr.Group():
154
  gr.Markdown("Basic Parameters")
155
  with gr.Row():
156
- p=gr.Number(label="Number of stages (p)", value=6, interactive=True, precision=0)
157
- m=gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
158
  with gr.Column(scale=2):
159
  with gr.Group():
160
  gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
161
- with gr.Row():
162
- f=gr.Number(label="Time of F", value=1049, interactive=True, precision=0)
163
- b=gr.Number(label="Time of B", value=1122, interactive=True, precision=0)
164
- w=gr.Number(label="Time of W", value=903, interactive=True, precision=0)
165
- c=gr.Number(label="Time of one P2P communication", value=79, interactive=True, precision=0)
166
  with gr.Group():
167
  gr.Markdown("Activation memory limit.")
168
  def update_mem(p, s, mem):
169
  print("update")
170
  if s == "custom":
171
  return mem
172
- if s == "V-Min":
173
  return (p + 4) // 3
174
- if s == "V-Half":
175
  return (p + 2) // 2
176
- if s == "V-ZB":
177
  return p
178
  assert False
179
- memsel=gr.Radio(choices=["V-Min", "V-Half", "V-ZB", "custom"], value="V-Half")
180
  mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
181
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
182
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
@@ -212,7 +213,7 @@ with gr.Blocks() as demo:
212
  with gr.Column(scale=4):
213
  schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
214
  with gr.Group():
215
- gr.Markdown("Two microbatch in one building block schedule")
216
  with gr.Row():
217
  with gr.Column(scale=1):
218
  type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
@@ -221,7 +222,7 @@ with gr.Blocks() as demo:
221
  with gr.Column(scale=4):
222
  type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
223
  with gr.Group():
224
- gr.Markdown("Interleaved 1F1B Schedule")
225
  with gr.Row():
226
  with gr.Column(scale=1):
227
  interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
@@ -234,6 +235,7 @@ with gr.Blocks() as demo:
234
  schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
235
  type2_acceleration, type2_mem, type2_bubble, type2_image,
236
  interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
 
237
 
238
  for (k, v) in presets.items():
239
  def update_preset(pb, p, m, f, b, w, c, mem):
 
136
  gr.Markdown(open("description1.md").read())
137
  gr.Markdown("# Pipeline Scheduler Playground")
138
  presets = {
139
+ 'Default Case': (4, 10, 100, 110, 90, 5, 'V-Half (1/2)'),
140
+ 'Ideal Case': (4, 10, 20, 20, 20, 0, 'V-Min (1/3)'),
141
+ 'Real Case': (4, 10, 1049, 1122, 903, 79, 'V-Half (1/2)'),
142
+ 'Zero Bubble Case': (4, 10, 1049, 1122, 903, 79, 'V-ZB (1)')
143
  }
144
  preset_buttons = {}
145
 
 
154
  with gr.Group():
155
  gr.Markdown("Basic Parameters")
156
  with gr.Row():
157
+ p=gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
158
+ m=gr.Number(label="Number of microbatches (m)", value=10, interactive=True, precision=0)
159
  with gr.Column(scale=2):
160
  with gr.Group():
161
  gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
162
+ with gr.Row():
163
+ f=gr.Number(label="Time of F", value=100, interactive=True, precision=0)
164
+ b=gr.Number(label="Time of B", value=110, interactive=True, precision=0)
165
+ w=gr.Number(label="Time of W", value=90, interactive=True, precision=0)
166
+ c=gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
167
  with gr.Group():
168
  gr.Markdown("Activation memory limit.")
169
  def update_mem(p, s, mem):
170
  print("update")
171
  if s == "custom":
172
  return mem
173
+ if s == "V-Min (1/3)":
174
  return (p + 4) // 3
175
+ if s == "V-Half (1/2)":
176
  return (p + 2) // 2
177
+ if s == "V-ZB (1)":
178
  return p
179
  assert False
180
+ memsel=gr.Radio(choices=["V-Min (1/3)", "V-Half (1/2)", "V-ZB (1)", "custom"], value="V-Half (1/2)")
181
  mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
182
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
183
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
 
213
  with gr.Column(scale=4):
214
  schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
215
  with gr.Group():
216
+ gr.Markdown("Zero bubble schedule with 2/3 1F1B memory")
217
  with gr.Row():
218
  with gr.Column(scale=1):
219
  type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
 
222
  with gr.Column(scale=4):
223
  type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
224
  with gr.Group():
225
+ gr.Markdown("Variation of Interleaved 1F1B Schedule")
226
  with gr.Row():
227
  with gr.Column(scale=1):
228
  interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
 
235
  schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
236
  type2_acceleration, type2_mem, type2_bubble, type2_image,
237
  interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
238
+ gr.Markdown(open("description3.md").read())
239
 
240
  for (k, v) in presets.items():
241
  def update_preset(pb, p, m, f, b, w, c, mem):
description1.md CHANGED
@@ -1,5 +1,9 @@
1
  # Pipeline Parallellism with Controllable Memory
2
 
 
 
 
 
3
  Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
4
 
5
  Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
 
1
  # Pipeline Parallellism with Controllable Memory
2
 
3
+ Pipeline Parallelism with Controllable Memory creates a framework on designing pipeline schedules and uses the framework to find memory optimal efficient schedules.
4
+
5
+ From our findings, we need approximately 1/3 memory under ideal conditions (F, B and W have same runtime), and 1/2 memory to create zero bubble schedule in realistic scenarios (with the necessary condition being W + 2B ≥ 2F and W + 2F ≥ 2B ).
6
+
7
  Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
8
 
9
  Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
description2.md CHANGED
@@ -1,6 +1,7 @@
1
  ## Alternative schedules
2
 
3
  By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate few of them here below:
 
4
  * 1F1B-V schedule without doing any B-W split.
5
  * Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
6
  * Variation of interleaved 1F1B with lower memory
 
1
  ## Alternative schedules
2
 
3
  By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate few of them here below:
4
+
5
  * 1F1B-V schedule without doing any B-W split.
6
  * Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
7
  * Variation of interleaved 1F1B with lower memory
description3.md ADDED
File without changes
interleaved_variant.py CHANGED
@@ -65,11 +65,6 @@ def get_interleaved_variation(_p, _n, cost):
65
  'B': _b+_w,
66
  'b': _b+_w
67
  }
68
- pred = {
69
- 'f': 'F',
70
- 'B': 'f',
71
- 'b': 'B'
72
- }
73
 
74
  time_map = {}
75
  def get_time(stage, type, minibatch):
@@ -78,16 +73,16 @@ def get_interleaved_variation(_p, _n, cost):
78
  time = 0
79
  if (stage, type, minibatch) in local_prev:
80
  time = get_time(*local_prev[(stage, type, minibatch)])
81
- if stage > 0 and type in ('F', 'f'):
82
  time = max(time, get_time(stage - 1, type, minibatch) + _c)
83
- if stage == 0 and type in ('f'):
84
- time = max(time, get_time(_p - 1, pred[type], minibatch) + _c)
85
- if stage != _p - 1 and type in ('B', 'b'):
86
  time = max(time, get_time(stage + 1, type, minibatch) + _c)
87
- if stage == _p - 1 and type in ('b'):
88
- time = max(time, get_time(0, pred[type], minibatch) + _c)
89
- if stage == _p - 1 and type in ('B'):
90
- time = max(time, get_time(stage, pred[type], minibatch))
91
 
92
  time_map[(stage, type, minibatch)] = time + cost[type]
93
  return time_map[(stage, type, minibatch)]
@@ -97,7 +92,7 @@ def get_interleaved_variation(_p, _n, cost):
97
  for type, minibatch in stage:
98
  result_stage.append(ScheduledNode(
99
  type.upper(),
100
- type in ('f', 'B', 'W'),
101
  sid,
102
  minibatch,
103
  get_time(sid, type, minibatch) - cost[type],
 
65
  'B': _b+_w,
66
  'b': _b+_w
67
  }
 
 
 
 
 
68
 
69
  time_map = {}
70
  def get_time(stage, type, minibatch):
 
73
  time = 0
74
  if (stage, type, minibatch) in local_prev:
75
  time = get_time(*local_prev[(stage, type, minibatch)])
76
+ if stage > 0 and type in "Ff":
77
  time = max(time, get_time(stage - 1, type, minibatch) + _c)
78
+ if stage == 0 and type == 'f':
79
+ time = max(time, get_time(_p - 1, 'F', minibatch) + _c)
80
+ if stage != _p - 1 and type in "Bb":
81
  time = max(time, get_time(stage + 1, type, minibatch) + _c)
82
+ if stage == _p - 1 and type == 'b':
83
+ time = max(time, get_time(0, 'B', minibatch) + _c)
84
+ if stage == _p - 1 and type == 'B':
85
+ time = max(time, get_time(stage, 'f', minibatch))
86
 
87
  time_map[(stage, type, minibatch)] = time + cost[type]
88
  return time_map[(stage, type, minibatch)]
 
92
  for type, minibatch in stage:
93
  result_stage.append(ScheduledNode(
94
  type.upper(),
95
+ type in "fBW",
96
  sid,
97
  minibatch,
98
  get_time(sid, type, minibatch) - cost[type],
schedule1f1bv.py CHANGED
@@ -44,9 +44,9 @@ def transform_schedule(schedule, f, b, w, c):
44
  time = 0
45
  if (stage, type, mb) in local_prev:
46
  time = get_time(stage, *local_prev[(stage, type, mb)])
47
- if type in ('F', 'B') and stage > 0:
48
  time = max(time, get_time(stage - 1, type, mb) + c)
49
- if type in ('f', 'b') and stage + 1< len(schedule):
50
  time = max(time, get_time(stage + 1, type, mb) + c)
51
  time_map[(stage, type, mb)] = time + cost[type]
52
  return time_map[(stage, type, mb)]
@@ -59,7 +59,7 @@ def transform_schedule(schedule, f, b, w, c):
59
  for p, mb in stage:
60
  result_stage.append(ScheduledNode(
61
  p.upper(),
62
- p in ('f', 'B', 'W'),
63
  sid,
64
  mb,
65
  get_time(sid, p, mb) - cost[p],
@@ -104,8 +104,8 @@ def clear_invalid(repeated, stage, pos, offset=-1):
104
  def clear_invalid_index(repeated, m):
105
  p = len(repeated)
106
  index = pattern_size
107
- for identifier in ['F', 'f', 'B', 'b']:
108
- if identifier in ['F', 'B']:
109
  _iter = range(p)
110
  else:
111
  _iter = range(p - 1, -1, -1)
@@ -115,7 +115,7 @@ def clear_invalid_index(repeated, m):
115
  clear_invalid(repeated, i, index - pattern_size, offset=-1)
116
  clear_invalid(repeated, i, index + pattern_size * m, offset=1)
117
  index += 1
118
- if identifier in ['B', 'b']:
119
  w_identifier = {'B': 'W', 'b': 'w'}[identifier]
120
  for k in range(pattern_size):
121
  if repeated[i][index + k] == w_identifier:
@@ -135,9 +135,9 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
135
  for sid in range(len(schedules)):
136
  cur = 0
137
  for i in range(len(schedules[sid])):
138
- if schedules[sid][i] in ('F', 'f'):
139
  cur += 1
140
- if schedules[sid][i] in ('W', 'w'):
141
  cur -= 1
142
  mem[sid][i] = cur
143
  peak_mem = max(peak_mem, cur)
@@ -177,16 +177,16 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
177
  pos += 1
178
  while schedules[sid][pos] != ' ' and pos < i:
179
  pos += 1
180
- if schedules[sid][i] in ('B', 'b'):
181
  while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
182
  pos += 1
183
  if pos == i:
184
  loc[sid][cnt][schedules[sid][i]] = i
185
  continue
186
- if schedules[sid][i] in ('B', 'b', 'W', 'w'):
187
  schedules[sid][pos] = schedules[sid][i]
188
  schedules[sid][i] = ' '
189
- if schedules[sid][pos] in ('W', 'w'):
190
  for j in range(pos, i):
191
  mem[sid][j] -= 1
192
  loc[sid][cnt][schedules[sid][pos]] = pos
@@ -265,7 +265,7 @@ def schedule(p, m, cost):
265
  s = schedule_by_pattern(p, m, whole_pattern)
266
  for sid in range(len(s)):
267
  for i in range(len(s[sid])):
268
- if s[sid][i] in ('W', 'w'):
269
  s[sid][i] = ' '
270
  res = transform_schedule(s, *cost)
271
  return res
 
44
  time = 0
45
  if (stage, type, mb) in local_prev:
46
  time = get_time(stage, *local_prev[(stage, type, mb)])
47
+ if type in "FB"and stage > 0:
48
  time = max(time, get_time(stage - 1, type, mb) + c)
49
+ if type in "fb" and stage + 1< len(schedule):
50
  time = max(time, get_time(stage + 1, type, mb) + c)
51
  time_map[(stage, type, mb)] = time + cost[type]
52
  return time_map[(stage, type, mb)]
 
59
  for p, mb in stage:
60
  result_stage.append(ScheduledNode(
61
  p.upper(),
62
+ p in "fBW",
63
  sid,
64
  mb,
65
  get_time(sid, p, mb) - cost[p],
 
104
  def clear_invalid_index(repeated, m):
105
  p = len(repeated)
106
  index = pattern_size
107
+ for identifier in "FfBb":
108
+ if identifier in "FB":
109
  _iter = range(p)
110
  else:
111
  _iter = range(p - 1, -1, -1)
 
115
  clear_invalid(repeated, i, index - pattern_size, offset=-1)
116
  clear_invalid(repeated, i, index + pattern_size * m, offset=1)
117
  index += 1
118
+ if identifier in "Bb":
119
  w_identifier = {'B': 'W', 'b': 'w'}[identifier]
120
  for k in range(pattern_size):
121
  if repeated[i][index + k] == w_identifier:
 
135
  for sid in range(len(schedules)):
136
  cur = 0
137
  for i in range(len(schedules[sid])):
138
+ if schedules[sid][i] in "Ff":
139
  cur += 1
140
+ if schedules[sid][i] in "Ww":
141
  cur -= 1
142
  mem[sid][i] = cur
143
  peak_mem = max(peak_mem, cur)
 
177
  pos += 1
178
  while schedules[sid][pos] != ' ' and pos < i:
179
  pos += 1
180
+ if schedules[sid][i] in "Bb":
181
  while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
182
  pos += 1
183
  if pos == i:
184
  loc[sid][cnt][schedules[sid][i]] = i
185
  continue
186
+ if schedules[sid][i] in "BbWw":
187
  schedules[sid][pos] = schedules[sid][i]
188
  schedules[sid][i] = ' '
189
+ if schedules[sid][pos] in "Ww":
190
  for j in range(pos, i):
191
  mem[sid][j] -= 1
192
  loc[sid][cnt][schedules[sid][pos]] = pos
 
265
  s = schedule_by_pattern(p, m, whole_pattern)
266
  for sid in range(len(s)):
267
  for i in range(len(s[sid])):
268
+ if s[sid][i] in "Ww":
269
  s[sid][i] = ' '
270
  res = transform_schedule(s, *cost)
271
  return res