Nyamdavaa Amar
committed on
Commit
·
cf49f13
1
Parent(s):
3d4d40d
Edit presets
Browse files- adaptive_schedule.py +25 -50
- app.py +18 -16
- description1.md +4 -0
- description2.md +1 -0
- description3.md +0 -0
- interleaved_variant.py +9 -14
- schedule1f1bv.py +12 -12
adaptive_schedule.py
CHANGED
@@ -46,9 +46,9 @@ def transform_schedule(schedule, f, b, w, c):
|
|
46 |
time = 0
|
47 |
if (stage, type, mb) in local_prev:
|
48 |
time = get_time(stage, *local_prev[(stage, type, mb)])
|
49 |
-
if type in
|
50 |
time = max(time, get_time(stage - 1, type, mb) + c)
|
51 |
-
if type in
|
52 |
time = max(time, get_time(stage + 1, type, mb) + c)
|
53 |
# print(f'{stage} {type}:{mb}', time + cost[type])
|
54 |
time_map[(stage, type, mb)] = time + cost[type]
|
@@ -63,7 +63,7 @@ def transform_schedule(schedule, f, b, w, c):
|
|
63 |
for p, mb in stage:
|
64 |
result_stage.append(ScheduledNode(
|
65 |
p.upper(),
|
66 |
-
p in
|
67 |
sid,
|
68 |
mb,
|
69 |
get_time(sid, p, mb) - cost[p],
|
@@ -110,9 +110,9 @@ def evaluate_schedule(schedule, f, b, w, c):
|
|
110 |
time = 0
|
111 |
if (stage, type, mb) in local_prev:
|
112 |
time = get_time(stage, *local_prev[(stage, type, mb)])
|
113 |
-
if type in
|
114 |
time = max(time, get_time(stage - 1, type, mb) + c)
|
115 |
-
if type in
|
116 |
time = max(time, get_time(stage + 1, type, mb) + c)
|
117 |
# print(f'{stage} {type}:{mb}', time + cost[type])
|
118 |
time_map[(stage, type, mb)] = time + cost[type]
|
@@ -153,16 +153,6 @@ def get_peak_mem(schedules, return_all=False):
|
|
153 |
return all_peak
|
154 |
return max_peak
|
155 |
|
156 |
-
debug = False
|
157 |
-
def print_schedules(schedules):
|
158 |
-
if not debug:
|
159 |
-
return
|
160 |
-
for seq in schedules:
|
161 |
-
_str = ""
|
162 |
-
for v in seq:
|
163 |
-
_str += v
|
164 |
-
print(_str)
|
165 |
-
|
166 |
|
167 |
def calc_bubble(schedules):
|
168 |
stage_bubbles = []
|
@@ -199,8 +189,8 @@ def clear_invalid(repeated, stage, pos, offset=-1):
|
|
199 |
def clear_invalid_index(repeated, m):
|
200 |
p = len(repeated)
|
201 |
index = pattern_size
|
202 |
-
for identifier in
|
203 |
-
if identifier in
|
204 |
_iter = range(p)
|
205 |
else:
|
206 |
_iter = range(p - 1, -1, -1)
|
@@ -210,7 +200,7 @@ def clear_invalid_index(repeated, m):
|
|
210 |
clear_invalid(repeated, i, index - pattern_size, offset=-1)
|
211 |
clear_invalid(repeated, i, index + pattern_size * m, offset=1)
|
212 |
index += 1
|
213 |
-
if identifier in
|
214 |
w_identifier = {'B': 'W', 'b': 'w'}[identifier]
|
215 |
for k in range(pattern_size):
|
216 |
if repeated[i][index + k] == w_identifier:
|
@@ -386,6 +376,17 @@ def squeeze_without_change_order(schedules, m):
|
|
386 |
identifier_cnt[i][identifier] += 1
|
387 |
identifier_index[_cnt * p + i][identifier] = index
|
388 |
stage_index[i] = index + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
389 |
return squeezed
|
390 |
|
391 |
|
@@ -485,12 +486,11 @@ def process_cooldown(schedules, m):
|
|
485 |
return schedules
|
486 |
|
487 |
|
488 |
-
def schedule_by_pattern(p, m, patterns):
|
489 |
schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
|
490 |
schedules = clear_invalid_index(schedules, max(m, 2 * p))
|
491 |
-
print_schedules(schedules)
|
492 |
init_peak_mem = get_peak_mem(schedules)
|
493 |
-
if init_peak_mem >
|
494 |
return None, init_peak_mem, [6 * max(m, 2 * p)] * p
|
495 |
schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
|
496 |
|
@@ -503,20 +503,16 @@ def schedule_by_pattern(p, m, patterns):
|
|
503 |
schedules[sid][i] = ' '
|
504 |
else:
|
505 |
cnt[schedules[sid][i]] += 1
|
506 |
-
print_schedules(schedules)
|
507 |
peak_mem = get_peak_mem(schedules)
|
508 |
if peak_mem > init_peak_mem:
|
509 |
return None, init_peak_mem, [6 * m] * p
|
510 |
|
511 |
schedules = squeeze_without_change_order(schedules, m)
|
512 |
-
print_schedules(schedules)
|
513 |
|
514 |
schedules = process_cooldown(schedules, m)
|
515 |
-
print_schedules(schedules)
|
516 |
peak_mem = get_peak_mem(schedules)
|
517 |
if peak_mem > init_peak_mem:
|
518 |
return None, init_peak_mem, [6 * m] * p
|
519 |
-
|
520 |
stage_bubbles = calc_bubble(schedules)
|
521 |
return schedules, peak_mem, stage_bubbles
|
522 |
|
@@ -572,25 +568,8 @@ def schedule(p, m, cost, max_mem):
|
|
572 |
pattern = [0, ff_i, b_i, bb_i, -1, -1]
|
573 |
pattern = fill_w_in_pattern(pattern)
|
574 |
available_patterns.append(pattern)
|
575 |
-
available_offsets = []
|
576 |
-
for f_o in range(1, pattern_size + 1):
|
577 |
-
for ff_o in range(1, pattern_size + 1):
|
578 |
-
for b_o in range(1, pattern_size + 1):
|
579 |
-
if f_o != b_o:
|
580 |
-
continue
|
581 |
-
bb_o = ff_o + b_o - f_o
|
582 |
-
if bb_o < 1 or bb_o > pattern_size:
|
583 |
-
continue
|
584 |
-
if bb_o + ff_o + b_o + f_o > 2 * pattern_size:
|
585 |
-
continue
|
586 |
-
# if bb_o + ff_o + b_o + f_o != 6:
|
587 |
-
# continue
|
588 |
-
offset = [f_o, - ff_o, b_o, - bb_o]
|
589 |
-
if min(ff_o, bb_o) > 1:
|
590 |
-
continue
|
591 |
-
available_offsets.append(offset)
|
592 |
|
593 |
-
print(
|
594 |
available_offsets = [
|
595 |
[1, -1, 1, -1],
|
596 |
[2, -1, 2, -1],
|
@@ -601,7 +580,6 @@ def schedule(p, m, cost, max_mem):
|
|
601 |
|
602 |
best_schedule = None
|
603 |
best_bubble = None
|
604 |
-
peak_mem2min_bubble = {}
|
605 |
for pattern_0 in available_patterns:
|
606 |
for i_0 in range(len(available_offsets)):
|
607 |
for i_1 in range(i_0 + 1):
|
@@ -611,13 +589,10 @@ def schedule(p, m, cost, max_mem):
|
|
611 |
whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
|
612 |
if whole_pattern is None:
|
613 |
continue
|
614 |
-
|
615 |
-
# print(get_pattern_str(pattern))
|
616 |
-
# print(offset)
|
617 |
-
s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern)
|
618 |
-
if s is None:
|
619 |
-
continue
|
620 |
if peak_mem > 2 * p or peak_mem > max_mem:
|
|
|
|
|
621 |
continue
|
622 |
max_bubble = max(bubbles)
|
623 |
max_bubble = evaluate_schedule(s, *cost)
|
|
|
46 |
time = 0
|
47 |
if (stage, type, mb) in local_prev:
|
48 |
time = get_time(stage, *local_prev[(stage, type, mb)])
|
49 |
+
if type in "FB" and stage > 0:
|
50 |
time = max(time, get_time(stage - 1, type, mb) + c)
|
51 |
+
if type in "fb" and stage + 1< len(schedule):
|
52 |
time = max(time, get_time(stage + 1, type, mb) + c)
|
53 |
# print(f'{stage} {type}:{mb}', time + cost[type])
|
54 |
time_map[(stage, type, mb)] = time + cost[type]
|
|
|
63 |
for p, mb in stage:
|
64 |
result_stage.append(ScheduledNode(
|
65 |
p.upper(),
|
66 |
+
p in "fBW",
|
67 |
sid,
|
68 |
mb,
|
69 |
get_time(sid, p, mb) - cost[p],
|
|
|
110 |
time = 0
|
111 |
if (stage, type, mb) in local_prev:
|
112 |
time = get_time(stage, *local_prev[(stage, type, mb)])
|
113 |
+
if type in "FB" and stage > 0:
|
114 |
time = max(time, get_time(stage - 1, type, mb) + c)
|
115 |
+
if type in "fb" and stage + 1< len(schedule):
|
116 |
time = max(time, get_time(stage + 1, type, mb) + c)
|
117 |
# print(f'{stage} {type}:{mb}', time + cost[type])
|
118 |
time_map[(stage, type, mb)] = time + cost[type]
|
|
|
153 |
return all_peak
|
154 |
return max_peak
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
157 |
def calc_bubble(schedules):
|
158 |
stage_bubbles = []
|
|
|
189 |
def clear_invalid_index(repeated, m):
|
190 |
p = len(repeated)
|
191 |
index = pattern_size
|
192 |
+
for identifier in "FfBb":
|
193 |
+
if identifier in "FB":
|
194 |
_iter = range(p)
|
195 |
else:
|
196 |
_iter = range(p - 1, -1, -1)
|
|
|
200 |
clear_invalid(repeated, i, index - pattern_size, offset=-1)
|
201 |
clear_invalid(repeated, i, index + pattern_size * m, offset=1)
|
202 |
index += 1
|
203 |
+
if identifier in "Bb":
|
204 |
w_identifier = {'B': 'W', 'b': 'w'}[identifier]
|
205 |
for k in range(pattern_size):
|
206 |
if repeated[i][index + k] == w_identifier:
|
|
|
376 |
identifier_cnt[i][identifier] += 1
|
377 |
identifier_index[_cnt * p + i][identifier] = index
|
378 |
stage_index[i] = index + 1
|
379 |
+
while True:
|
380 |
+
if(len(squeezed[0]) == 1):
|
381 |
+
break
|
382 |
+
allempty = True
|
383 |
+
for x in squeezed:
|
384 |
+
if x[-1] != ' ':
|
385 |
+
allempty = False
|
386 |
+
if allempty == False:
|
387 |
+
break
|
388 |
+
for x in squeezed:
|
389 |
+
del x[-1]
|
390 |
return squeezed
|
391 |
|
392 |
|
|
|
486 |
return schedules
|
487 |
|
488 |
|
489 |
+
def schedule_by_pattern(p, m, patterns, max_mem):
|
490 |
schedules = init_repeated_schedule(p, max(m, 2 * p), patterns)
|
491 |
schedules = clear_invalid_index(schedules, max(m, 2 * p))
|
|
|
492 |
init_peak_mem = get_peak_mem(schedules)
|
493 |
+
if init_peak_mem > max_mem:
|
494 |
return None, init_peak_mem, [6 * max(m, 2 * p)] * p
|
495 |
schedules = process_warmup_without_increasing_peak_mem(schedules, max(m, 2 * p))
|
496 |
|
|
|
503 |
schedules[sid][i] = ' '
|
504 |
else:
|
505 |
cnt[schedules[sid][i]] += 1
|
|
|
506 |
peak_mem = get_peak_mem(schedules)
|
507 |
if peak_mem > init_peak_mem:
|
508 |
return None, init_peak_mem, [6 * m] * p
|
509 |
|
510 |
schedules = squeeze_without_change_order(schedules, m)
|
|
|
511 |
|
512 |
schedules = process_cooldown(schedules, m)
|
|
|
513 |
peak_mem = get_peak_mem(schedules)
|
514 |
if peak_mem > init_peak_mem:
|
515 |
return None, init_peak_mem, [6 * m] * p
|
|
|
516 |
stage_bubbles = calc_bubble(schedules)
|
517 |
return schedules, peak_mem, stage_bubbles
|
518 |
|
|
|
568 |
pattern = [0, ff_i, b_i, bb_i, -1, -1]
|
569 |
pattern = fill_w_in_pattern(pattern)
|
570 |
available_patterns.append(pattern)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
571 |
|
572 |
+
print(len(available_patterns))
|
573 |
available_offsets = [
|
574 |
[1, -1, 1, -1],
|
575 |
[2, -1, 2, -1],
|
|
|
580 |
|
581 |
best_schedule = None
|
582 |
best_bubble = None
|
|
|
583 |
for pattern_0 in available_patterns:
|
584 |
for i_0 in range(len(available_offsets)):
|
585 |
for i_1 in range(i_0 + 1):
|
|
|
589 |
whole_pattern = get_whole_pattern(pattern_0, offset_0, offset_1, len_0, p)
|
590 |
if whole_pattern is None:
|
591 |
continue
|
592 |
+
s, peak_mem, bubbles = schedule_by_pattern(p, m, whole_pattern, min(2 * p, max_mem))
|
|
|
|
|
|
|
|
|
|
|
593 |
if peak_mem > 2 * p or peak_mem > max_mem:
|
594 |
+
break
|
595 |
+
if s is None:
|
596 |
continue
|
597 |
max_bubble = max(bubbles)
|
598 |
max_bubble = evaluate_schedule(s, *cost)
|
app.py
CHANGED
@@ -136,9 +136,10 @@ with gr.Blocks() as demo:
|
|
136 |
gr.Markdown(open("description1.md").read())
|
137 |
gr.Markdown("# Pipeline Scheduler Playground")
|
138 |
presets = {
|
139 |
-
'
|
140 |
-
'Ideal Case': (
|
141 |
-
'
|
|
|
142 |
}
|
143 |
preset_buttons = {}
|
144 |
|
@@ -153,30 +154,30 @@ with gr.Blocks() as demo:
|
|
153 |
with gr.Group():
|
154 |
gr.Markdown("Basic Parameters")
|
155 |
with gr.Row():
|
156 |
-
p=gr.Number(label="Number of stages (p)", value=
|
157 |
-
m=gr.Number(label="Number of microbatches (m)", value=
|
158 |
with gr.Column(scale=2):
|
159 |
with gr.Group():
|
160 |
gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
|
161 |
-
with gr.Row():
|
162 |
-
f=gr.Number(label="Time of F", value=
|
163 |
-
b=gr.Number(label="Time of B", value=
|
164 |
-
w=gr.Number(label="Time of W", value=
|
165 |
-
c=gr.Number(label="Time of one P2P communication", value=
|
166 |
with gr.Group():
|
167 |
gr.Markdown("Activation memory limit.")
|
168 |
def update_mem(p, s, mem):
|
169 |
print("update")
|
170 |
if s == "custom":
|
171 |
return mem
|
172 |
-
if s == "V-Min":
|
173 |
return (p + 4) // 3
|
174 |
-
if s == "V-Half":
|
175 |
return (p + 2) // 2
|
176 |
-
if s == "V-ZB":
|
177 |
return p
|
178 |
assert False
|
179 |
-
memsel=gr.Radio(choices=["V-Min", "V-Half", "V-ZB", "custom"], value="V-Half")
|
180 |
mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
|
181 |
memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
182 |
p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
@@ -212,7 +213,7 @@ with gr.Blocks() as demo:
|
|
212 |
with gr.Column(scale=4):
|
213 |
schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
214 |
with gr.Group():
|
215 |
-
gr.Markdown("
|
216 |
with gr.Row():
|
217 |
with gr.Column(scale=1):
|
218 |
type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
@@ -221,7 +222,7 @@ with gr.Blocks() as demo:
|
|
221 |
with gr.Column(scale=4):
|
222 |
type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
223 |
with gr.Group():
|
224 |
-
gr.Markdown("Interleaved 1F1B Schedule")
|
225 |
with gr.Row():
|
226 |
with gr.Column(scale=1):
|
227 |
interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
@@ -234,6 +235,7 @@ with gr.Blocks() as demo:
|
|
234 |
schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
|
235 |
type2_acceleration, type2_mem, type2_bubble, type2_image,
|
236 |
interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
|
|
|
237 |
|
238 |
for (k, v) in presets.items():
|
239 |
def update_preset(pb, p, m, f, b, w, c, mem):
|
|
|
136 |
gr.Markdown(open("description1.md").read())
|
137 |
gr.Markdown("# Pipeline Scheduler Playground")
|
138 |
presets = {
|
139 |
+
'Default Case': (4, 10, 100, 110, 90, 5, 'V-Half (1/2)'),
|
140 |
+
'Ideal Case': (4, 10, 20, 20, 20, 0, 'V-Min (1/3)'),
|
141 |
+
'Real Case': (4, 10, 1049, 1122, 903, 79, 'V-Half (1/2)'),
|
142 |
+
'Zero Bubble Case': (4, 10, 1049, 1122, 903, 79, 'V-ZB (1)')
|
143 |
}
|
144 |
preset_buttons = {}
|
145 |
|
|
|
154 |
with gr.Group():
|
155 |
gr.Markdown("Basic Parameters")
|
156 |
with gr.Row():
|
157 |
+
p=gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
|
158 |
+
m=gr.Number(label="Number of microbatches (m)", value=10, interactive=True, precision=0)
|
159 |
with gr.Column(scale=2):
|
160 |
with gr.Group():
|
161 |
gr.Markdown("Costs. All costs are used as integers. For chunked schedules, this is the time of two virtual stages on a stage combined.")
|
162 |
+
with gr.Row():
|
163 |
+
f=gr.Number(label="Time of F", value=100, interactive=True, precision=0)
|
164 |
+
b=gr.Number(label="Time of B", value=110, interactive=True, precision=0)
|
165 |
+
w=gr.Number(label="Time of W", value=90, interactive=True, precision=0)
|
166 |
+
c=gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
|
167 |
with gr.Group():
|
168 |
gr.Markdown("Activation memory limit.")
|
169 |
def update_mem(p, s, mem):
|
170 |
print("update")
|
171 |
if s == "custom":
|
172 |
return mem
|
173 |
+
if s == "V-Min (1/3)":
|
174 |
return (p + 4) // 3
|
175 |
+
if s == "V-Half (1/2)":
|
176 |
return (p + 2) // 2
|
177 |
+
if s == "V-ZB (1)":
|
178 |
return p
|
179 |
assert False
|
180 |
+
memsel=gr.Radio(choices=["V-Min (1/3)", "V-Half (1/2)", "V-ZB (1)", "custom"], value="V-Half (1/2)")
|
181 |
mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For chunked schedules, this is relative to two virtual stages on a stage combined.", value=(p.value + 2) // 2, interactive=True, precision=0)
|
182 |
memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
183 |
p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
|
|
213 |
with gr.Column(scale=4):
|
214 |
schedule1f1bv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
215 |
with gr.Group():
|
216 |
+
gr.Markdown("Zero bubble schedule with 2/3 1F1B memory")
|
217 |
with gr.Row():
|
218 |
with gr.Column(scale=1):
|
219 |
type2_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
|
|
222 |
with gr.Column(scale=4):
|
223 |
type2_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
224 |
with gr.Group():
|
225 |
+
gr.Markdown("Variation of Interleaved 1F1B Schedule")
|
226 |
with gr.Row():
|
227 |
with gr.Column(scale=1):
|
228 |
interleaved_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
|
|
235 |
schedule1f1bv_acceleration, schedule1f1bv_mem, schedule1f1bv_bubble, schedule1f1bv_image,
|
236 |
type2_acceleration, type2_mem, type2_bubble, type2_image,
|
237 |
interleaved_acceleration, interleaved_mem, interleaved_bubble, interleaved_image])
|
238 |
+
gr.Markdown(open("description3.md").read())
|
239 |
|
240 |
for (k, v) in presets.items():
|
241 |
def update_preset(pb, p, m, f, b, w, c, mem):
|
description1.md
CHANGED
@@ -1,5 +1,9 @@
|
|
1 |
# Pipeline Parallelism with Controllable Memory
|
2 |
|
|
|
|
|
|
|
|
|
3 |
Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
|
4 |
|
5 |
Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
|
|
|
1 |
# Pipeline Parallelism with Controllable Memory
|
2 |
|
3 |
+
Pipeline Parallelism with Controllable Memory introduces a framework for designing pipeline schedules and uses that framework to find efficient, memory-optimal schedules.
|
4 |
+
|
5 |
+
Our findings show that approximately 1/3 of the memory is needed under ideal conditions (F, B and W have the same runtime), and 1/2 of the memory is needed to create a zero-bubble schedule in realistic scenarios (with the necessary condition being W + 2B ≥ 2F and W + 2F ≥ 2B).
|
6 |
+
|
7 |
Check out our paper at [Arxiv](https://arxiv.org/abs/2405.15362).
|
8 |
|
9 |
Bubble Rate here is calculated as (1 - longest stage time/(F+B+W)/m).
|
description2.md
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
## Alternative schedules
|
2 |
|
3 |
By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate a few of them below:
|
|
|
4 |
* 1F1B-V schedule without doing any B-W split.
|
5 |
* Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
|
6 |
* Variation of interleaved 1F1B with lower memory
|
|
|
1 |
## Alternative schedules
|
2 |
|
3 |
By utilizing the building block, we can search for different types of schedules depending on the need. We illustrate a few of them below:
|
4 |
+
|
5 |
* 1F1B-V schedule without doing any B-W split.
|
6 |
* Schedule with 2/3rd 1F1B memory by utilising B-W split. Note that two microbatches are included in a single building block to avoid collision.
|
7 |
* Variation of interleaved 1F1B with lower memory
|
description3.md
ADDED
File without changes
|
interleaved_variant.py
CHANGED
@@ -65,11 +65,6 @@ def get_interleaved_variation(_p, _n, cost):
|
|
65 |
'B': _b+_w,
|
66 |
'b': _b+_w
|
67 |
}
|
68 |
-
pred = {
|
69 |
-
'f': 'F',
|
70 |
-
'B': 'f',
|
71 |
-
'b': 'B'
|
72 |
-
}
|
73 |
|
74 |
time_map = {}
|
75 |
def get_time(stage, type, minibatch):
|
@@ -78,16 +73,16 @@ def get_interleaved_variation(_p, _n, cost):
|
|
78 |
time = 0
|
79 |
if (stage, type, minibatch) in local_prev:
|
80 |
time = get_time(*local_prev[(stage, type, minibatch)])
|
81 |
-
if stage > 0 and type in
|
82 |
time = max(time, get_time(stage - 1, type, minibatch) + _c)
|
83 |
-
if stage == 0 and type
|
84 |
-
time = max(time, get_time(_p - 1,
|
85 |
-
if stage != _p - 1 and type in
|
86 |
time = max(time, get_time(stage + 1, type, minibatch) + _c)
|
87 |
-
if stage == _p - 1 and type
|
88 |
-
time = max(time, get_time(0,
|
89 |
-
if stage == _p - 1 and type
|
90 |
-
time = max(time, get_time(stage,
|
91 |
|
92 |
time_map[(stage, type, minibatch)] = time + cost[type]
|
93 |
return time_map[(stage, type, minibatch)]
|
@@ -97,7 +92,7 @@ def get_interleaved_variation(_p, _n, cost):
|
|
97 |
for type, minibatch in stage:
|
98 |
result_stage.append(ScheduledNode(
|
99 |
type.upper(),
|
100 |
-
type in
|
101 |
sid,
|
102 |
minibatch,
|
103 |
get_time(sid, type, minibatch) - cost[type],
|
|
|
65 |
'B': _b+_w,
|
66 |
'b': _b+_w
|
67 |
}
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
time_map = {}
|
70 |
def get_time(stage, type, minibatch):
|
|
|
73 |
time = 0
|
74 |
if (stage, type, minibatch) in local_prev:
|
75 |
time = get_time(*local_prev[(stage, type, minibatch)])
|
76 |
+
if stage > 0 and type in "Ff":
|
77 |
time = max(time, get_time(stage - 1, type, minibatch) + _c)
|
78 |
+
if stage == 0 and type == 'f':
|
79 |
+
time = max(time, get_time(_p - 1, 'F', minibatch) + _c)
|
80 |
+
if stage != _p - 1 and type in "Bb":
|
81 |
time = max(time, get_time(stage + 1, type, minibatch) + _c)
|
82 |
+
if stage == _p - 1 and type == 'b':
|
83 |
+
time = max(time, get_time(0, 'B', minibatch) + _c)
|
84 |
+
if stage == _p - 1 and type == 'B':
|
85 |
+
time = max(time, get_time(stage, 'f', minibatch))
|
86 |
|
87 |
time_map[(stage, type, minibatch)] = time + cost[type]
|
88 |
return time_map[(stage, type, minibatch)]
|
|
|
92 |
for type, minibatch in stage:
|
93 |
result_stage.append(ScheduledNode(
|
94 |
type.upper(),
|
95 |
+
type in "fBW",
|
96 |
sid,
|
97 |
minibatch,
|
98 |
get_time(sid, type, minibatch) - cost[type],
|
schedule1f1bv.py
CHANGED
@@ -44,9 +44,9 @@ def transform_schedule(schedule, f, b, w, c):
|
|
44 |
time = 0
|
45 |
if (stage, type, mb) in local_prev:
|
46 |
time = get_time(stage, *local_prev[(stage, type, mb)])
|
47 |
-
if type in
|
48 |
time = max(time, get_time(stage - 1, type, mb) + c)
|
49 |
-
if type in
|
50 |
time = max(time, get_time(stage + 1, type, mb) + c)
|
51 |
time_map[(stage, type, mb)] = time + cost[type]
|
52 |
return time_map[(stage, type, mb)]
|
@@ -59,7 +59,7 @@ def transform_schedule(schedule, f, b, w, c):
|
|
59 |
for p, mb in stage:
|
60 |
result_stage.append(ScheduledNode(
|
61 |
p.upper(),
|
62 |
-
p in
|
63 |
sid,
|
64 |
mb,
|
65 |
get_time(sid, p, mb) - cost[p],
|
@@ -104,8 +104,8 @@ def clear_invalid(repeated, stage, pos, offset=-1):
|
|
104 |
def clear_invalid_index(repeated, m):
|
105 |
p = len(repeated)
|
106 |
index = pattern_size
|
107 |
-
for identifier in
|
108 |
-
if identifier in
|
109 |
_iter = range(p)
|
110 |
else:
|
111 |
_iter = range(p - 1, -1, -1)
|
@@ -115,7 +115,7 @@ def clear_invalid_index(repeated, m):
|
|
115 |
clear_invalid(repeated, i, index - pattern_size, offset=-1)
|
116 |
clear_invalid(repeated, i, index + pattern_size * m, offset=1)
|
117 |
index += 1
|
118 |
-
if identifier in
|
119 |
w_identifier = {'B': 'W', 'b': 'w'}[identifier]
|
120 |
for k in range(pattern_size):
|
121 |
if repeated[i][index + k] == w_identifier:
|
@@ -135,9 +135,9 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
|
|
135 |
for sid in range(len(schedules)):
|
136 |
cur = 0
|
137 |
for i in range(len(schedules[sid])):
|
138 |
-
if schedules[sid][i] in
|
139 |
cur += 1
|
140 |
-
if schedules[sid][i] in
|
141 |
cur -= 1
|
142 |
mem[sid][i] = cur
|
143 |
peak_mem = max(peak_mem, cur)
|
@@ -177,16 +177,16 @@ def process_warmup_without_increasing_peak_mem(schedules, m):
|
|
177 |
pos += 1
|
178 |
while schedules[sid][pos] != ' ' and pos < i:
|
179 |
pos += 1
|
180 |
-
if schedules[sid][i] in
|
181 |
while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
|
182 |
pos += 1
|
183 |
if pos == i:
|
184 |
loc[sid][cnt][schedules[sid][i]] = i
|
185 |
continue
|
186 |
-
if schedules[sid][i] in
|
187 |
schedules[sid][pos] = schedules[sid][i]
|
188 |
schedules[sid][i] = ' '
|
189 |
-
if schedules[sid][pos] in
|
190 |
for j in range(pos, i):
|
191 |
mem[sid][j] -= 1
|
192 |
loc[sid][cnt][schedules[sid][pos]] = pos
|
@@ -265,7 +265,7 @@ def schedule(p, m, cost):
|
|
265 |
s = schedule_by_pattern(p, m, whole_pattern)
|
266 |
for sid in range(len(s)):
|
267 |
for i in range(len(s[sid])):
|
268 |
-
if s[sid][i] in
|
269 |
s[sid][i] = ' '
|
270 |
res = transform_schedule(s, *cost)
|
271 |
return res
|
|
|
44 |
time = 0
|
45 |
if (stage, type, mb) in local_prev:
|
46 |
time = get_time(stage, *local_prev[(stage, type, mb)])
|
47 |
+
if type in "FB"and stage > 0:
|
48 |
time = max(time, get_time(stage - 1, type, mb) + c)
|
49 |
+
if type in "fb" and stage + 1< len(schedule):
|
50 |
time = max(time, get_time(stage + 1, type, mb) + c)
|
51 |
time_map[(stage, type, mb)] = time + cost[type]
|
52 |
return time_map[(stage, type, mb)]
|
|
|
59 |
for p, mb in stage:
|
60 |
result_stage.append(ScheduledNode(
|
61 |
p.upper(),
|
62 |
+
p in "fBW",
|
63 |
sid,
|
64 |
mb,
|
65 |
get_time(sid, p, mb) - cost[p],
|
|
|
104 |
def clear_invalid_index(repeated, m):
|
105 |
p = len(repeated)
|
106 |
index = pattern_size
|
107 |
+
for identifier in "FfBb":
|
108 |
+
if identifier in "FB":
|
109 |
_iter = range(p)
|
110 |
else:
|
111 |
_iter = range(p - 1, -1, -1)
|
|
|
115 |
clear_invalid(repeated, i, index - pattern_size, offset=-1)
|
116 |
clear_invalid(repeated, i, index + pattern_size * m, offset=1)
|
117 |
index += 1
|
118 |
+
if identifier in "Bb":
|
119 |
w_identifier = {'B': 'W', 'b': 'w'}[identifier]
|
120 |
for k in range(pattern_size):
|
121 |
if repeated[i][index + k] == w_identifier:
|
|
|
135 |
for sid in range(len(schedules)):
|
136 |
cur = 0
|
137 |
for i in range(len(schedules[sid])):
|
138 |
+
if schedules[sid][i] in "Ff":
|
139 |
cur += 1
|
140 |
+
if schedules[sid][i] in "Ww":
|
141 |
cur -= 1
|
142 |
mem[sid][i] = cur
|
143 |
peak_mem = max(peak_mem, cur)
|
|
|
177 |
pos += 1
|
178 |
while schedules[sid][pos] != ' ' and pos < i:
|
179 |
pos += 1
|
180 |
+
if schedules[sid][i] in "Bb":
|
181 |
while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
|
182 |
pos += 1
|
183 |
if pos == i:
|
184 |
loc[sid][cnt][schedules[sid][i]] = i
|
185 |
continue
|
186 |
+
if schedules[sid][i] in "BbWw":
|
187 |
schedules[sid][pos] = schedules[sid][i]
|
188 |
schedules[sid][i] = ' '
|
189 |
+
if schedules[sid][pos] in "Ww":
|
190 |
for j in range(pos, i):
|
191 |
mem[sid][j] -= 1
|
192 |
loc[sid][cnt][schedules[sid][pos]] = pos
|
|
|
265 |
s = schedule_by_pattern(p, m, whole_pattern)
|
266 |
for sid in range(len(s)):
|
267 |
for i in range(len(s[sid])):
|
268 |
+
if s[sid][i] in "Ww":
|
269 |
s[sid][i] = ' '
|
270 |
res = transform_schedule(s, *cost)
|
271 |
return res
|