Wan Xinyi commited on
Commit
ac0b05c
1 Parent(s): be3048f

Add some presets, support 1f1b with fewer microbatches

Browse files
Files changed (2) hide show
  1. app.py +36 -6
  2. hand_schedule.py +20 -11
app.py CHANGED
@@ -46,6 +46,7 @@ def calculate(p, m, f, b, w, c, mem):
46
  baseline_bubble=None
47
  baseline_acceleration=None
48
  baseline_image=None
 
49
  else:
50
  baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
51
  baseline_result = [
@@ -70,11 +71,12 @@ def calculate(p, m, f, b, w, c, mem):
70
  zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
71
  zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
72
 
73
- if mem < p:
74
  zbv_time=None
75
  zbv_bubble=None
76
  zbv_acceleration=None
77
  zbv_image=None
 
78
  else:
79
  zbv_graph = v_schedule.PipelineGraph(
80
  n_stage=p,
@@ -94,10 +96,13 @@ def calculate(p, m, f, b, w, c, mem):
94
  zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
95
  zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
96
 
97
- max_time = max([baseline_time, zb_time, zbv_time])
98
- print(max_time)
 
99
  baseline_image = get_schedule_image(baseline_result, max_time)
 
100
  zb_image = get_schedule_image(zb_result, max_time)
 
101
  zbv_image = get_schedule_image(zbv_result, max_time)
102
 
103
  return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
@@ -105,6 +110,20 @@ def calculate(p, m, f, b, w, c, mem):
105
  with gr.Blocks() as demo:
106
  gr.Markdown(open("description1.md").read())
107
  gr.Markdown("# Pipeline Scheduler Playground")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  with gr.Row():
109
  with gr.Column(scale=1):
110
  with gr.Group():
@@ -142,7 +161,7 @@ with gr.Blocks() as demo:
142
  baseline_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
143
  baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
144
  with gr.Column(scale=4):
145
- baseline_image=gr.Image(None, interactive=False, label="Schedule Image")
146
 
147
  with gr.Group():
148
  gr.Markdown("Zero Bubble Schedule")
@@ -152,7 +171,7 @@ with gr.Blocks() as demo:
152
  zb_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
153
  zb_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
154
  with gr.Column(scale=4):
155
- zb_image=gr.Image(None, interactive=False, label="Schedule Image")
156
  with gr.Group():
157
  gr.Markdown("Zero Bubble V Schedule (ZBV)")
158
  with gr.Row():
@@ -161,7 +180,18 @@ with gr.Blocks() as demo:
161
  zbv_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
162
  zbv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
163
  with gr.Column(scale=4):
164
- zbv_image=gr.Image(None, interactive=False, label="Schedule Image")
165
  button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
 
 
 
 
 
 
 
 
 
 
 
166
  gr.Markdown(open("description2.md").read())
167
  demo.launch()
 
46
  baseline_bubble=None
47
  baseline_acceleration=None
48
  baseline_image=None
49
+ baseline_result=None
50
  else:
51
  baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
52
  baseline_result = [
 
71
  zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
72
  zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
73
 
74
+ if mem < p or m < 2 * p:
75
  zbv_time=None
76
  zbv_bubble=None
77
  zbv_acceleration=None
78
  zbv_image=None
79
+ zbv_result=None
80
  else:
81
  zbv_graph = v_schedule.PipelineGraph(
82
  n_stage=p,
 
96
  zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
97
  zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
98
 
99
+ max_time = max(filter(lambda x: x is not None, [baseline_time, zb_time, zbv_time]))
100
+ print(max_time)
101
+ if baseline_result is not None:
102
  baseline_image = get_schedule_image(baseline_result, max_time)
103
+ if zb_result is not None:
104
  zb_image = get_schedule_image(zb_result, max_time)
105
+ if zbv_result is not None:
106
  zbv_image = get_schedule_image(zbv_result, max_time)
107
 
108
  return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
 
110
  with gr.Blocks() as demo:
111
  gr.Markdown(open("description1.md").read())
112
  gr.Markdown("# Pipeline Scheduler Playground")
113
+ presets = {
114
+ 'Ideal Case 1p': (4, 12, 20, 20, 20, 0, '1p (Same as 1F1B)'),
115
+ 'Ideal Case 2p': (4, 12, 20, 20, 20, 0, '2p'),
116
+ 'Real Case 1p': (4, 12, 1049, 1122, 903, 79, '1p (Same as 1F1B)'),
117
+ 'Real Case 2p': (4, 12, 1049, 1122, 903, 79, '2p'),
118
+ }
119
+ preset_buttons = {}
120
+
121
+ with gr.Group():
122
+ gr.Markdown("Preset Setups")
123
+ with gr.Row():
124
+ for (k, v) in presets.items():
125
+ preset_buttons[k] = gr.Button(k, variant="secondary")
126
+
127
  with gr.Row():
128
  with gr.Column(scale=1):
129
  with gr.Group():
 
161
  baseline_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
162
  baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
163
  with gr.Column(scale=4):
164
+ baseline_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
165
 
166
  with gr.Group():
167
  gr.Markdown("Zero Bubble Schedule")
 
171
  zb_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
172
  zb_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
173
  with gr.Column(scale=4):
174
+ zb_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
175
  with gr.Group():
176
  gr.Markdown("Zero Bubble V Schedule (ZBV)")
177
  with gr.Row():
 
180
  zbv_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
181
  zbv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
182
  with gr.Column(scale=4):
183
+ zbv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
184
  button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
185
+
186
+ for (k, v) in presets.items():
187
+ def update_preset(pb, p, m, f, b, w, c, mem):
188
+ print(pb)
189
+ print(presets[pb])
190
+ print(presets[pb][-1])
191
+ return *presets[pb],*calculate(*presets[pb][:-1], update_mem(p, presets[pb][-1], -1))
192
+ preset_buttons[k].click(
193
+ update_preset,
194
+ inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
195
+ outputs=[p, m, f, b, w, c, memsel, baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
196
  gr.Markdown(open("description2.md").read())
197
  demo.launch()
hand_schedule.py CHANGED
@@ -11,8 +11,10 @@ class ScheduledNode:
11
 
12
 
13
  def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
14
- assert _n >= 2 * _p
15
  stage = [[] for _ in range(_p)]
 
 
16
  for rank in range(_p):
17
  warmup = (_p - rank - 1) * warmup_c
18
  for _ in range(warmup):
@@ -25,12 +27,13 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
25
  stage[rank].append(2)
26
  for _ in range((_p - 1) * warmup_c - warmup):
27
  stage[rank].append(2)
28
- labels = ["F", "B", "W"]
29
  for rank in range(_p):
30
  rank_str = " " * rank
31
- for i in range(_n * 3):
 
32
  rank_str += labels[stage[rank][i]]
33
- # print(rank_str)
34
  size = _p * _n * 3
35
  def get_id(_i, _j, _k):
36
  return _i * _p * _n + _j * _n + _k
@@ -42,6 +45,8 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
42
  for rank in range(_p):
43
  last = e[rank]
44
  if stage[rank][i] == 0:
 
 
45
  tmp = e[rank] + _f
46
  if rank > 0:
47
  assert t[get_id(0, rank - 1, fc[rank])] > 0
@@ -50,17 +55,17 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
50
  t[get_id(0, rank, fc[rank])] = tmp
51
  fc[rank] += 1
52
  elif stage[rank][i] == 1:
 
 
53
  tmp = e[rank] + _b
54
  if rank < _p - 1:
55
- assert t[get_id(1, rank + 1, bc[rank])] > 0
56
  tmp = max(tmp, t[get_id(1, rank + 1, bc[rank])] + _c + _b)
57
  e[rank] = tmp
58
  t[get_id(1, rank, bc[rank])] = tmp
59
  bc[rank] += 1
60
- else:
61
- tmp = e[rank] + _w
62
- e[rank] = tmp
63
- t[get_id(2, rank, i - fc[rank] - bc[rank])] = tmp
64
  # if rank == _p - 1:
65
  # print(_f, _b, _w, _c, "->", rank, i, stage[rank][i], e[rank], e[rank] - last)
66
  max_time = 0
@@ -73,7 +78,7 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
73
  # exit(0)
74
  res = [[] for _ in range(_p)]
75
  for rank in range(_p):
76
- for i in range(_n):
77
  res[rank].append(ScheduledNode(
78
  "F", rank, i, t[get_id(0, rank, i)] - _f, t[get_id(0, rank, i)]))
79
  res[rank].append(ScheduledNode(
@@ -81,4 +86,8 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
81
  res[rank].append(ScheduledNode(
82
  "W", rank, i, t[get_id(2, rank, i)] - _w, t[get_id(2, rank, i)]))
83
  res[rank] = sorted(res[rank], key=lambda x: x.start_time)
84
- return res
 
 
 
 
 
11
 
12
 
13
  def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
14
+ # assert _n >= 2 * _p
15
  stage = [[] for _ in range(_p)]
16
+ real_n = _n
17
+ _n = max(_n, _p)
18
  for rank in range(_p):
19
  warmup = (_p - rank - 1) * warmup_c
20
  for _ in range(warmup):
 
27
  stage[rank].append(2)
28
  for _ in range((_p - 1) * warmup_c - warmup):
29
  stage[rank].append(2)
30
+ labels = ["F", "B", "W", '.']
31
  for rank in range(_p):
32
  rank_str = " " * rank
33
+ # for i in range(_n * 3):
34
+ for i in range(len(stage[rank])):
35
  rank_str += labels[stage[rank][i]]
36
+ print(rank_str)
37
  size = _p * _n * 3
38
  def get_id(_i, _j, _k):
39
  return _i * _p * _n + _j * _n + _k
 
45
  for rank in range(_p):
46
  last = e[rank]
47
  if stage[rank][i] == 0:
48
+ if fc[rank] >= real_n:
49
+ continue
50
  tmp = e[rank] + _f
51
  if rank > 0:
52
  assert t[get_id(0, rank - 1, fc[rank])] > 0
 
55
  t[get_id(0, rank, fc[rank])] = tmp
56
  fc[rank] += 1
57
  elif stage[rank][i] == 1:
58
+ if bc[rank] >= real_n:
59
+ continue
60
  tmp = e[rank] + _b
61
  if rank < _p - 1:
62
+ assert t[get_id(1, rank + 1, bc[rank])] > 0, f"{rank} {i} {bc[rank]}"
63
  tmp = max(tmp, t[get_id(1, rank + 1, bc[rank])] + _c + _b)
64
  e[rank] = tmp
65
  t[get_id(1, rank, bc[rank])] = tmp
66
  bc[rank] += 1
67
+ elif stage[rank][i] == 2:
68
+ continue
 
 
69
  # if rank == _p - 1:
70
  # print(_f, _b, _w, _c, "->", rank, i, stage[rank][i], e[rank], e[rank] - last)
71
  max_time = 0
 
78
  # exit(0)
79
  res = [[] for _ in range(_p)]
80
  for rank in range(_p):
81
+ for i in range(real_n):
82
  res[rank].append(ScheduledNode(
83
  "F", rank, i, t[get_id(0, rank, i)] - _f, t[get_id(0, rank, i)]))
84
  res[rank].append(ScheduledNode(
 
86
  res[rank].append(ScheduledNode(
87
  "W", rank, i, t[get_id(2, rank, i)] - _w, t[get_id(2, rank, i)]))
88
  res[rank] = sorted(res[rank], key=lambda x: x.start_time)
89
+ return res
90
+
91
+ if __name__ == "__main__":
92
+ print(get_hand_schedule(16, 16, 1, 1, 1, 0))
93
+