Wan Xinyi committed on
Commit
4b2c8d9
1 Parent(s): 933f413

initial commit

Browse files
Files changed (3) hide show
  1. app.py +126 -0
  2. auto_schedule.py +564 -0
  3. v_schedule.py +461 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import auto_schedule
3
+ import v_schedule
4
+
5
def greet(name, is_morning, temperature):
    """Build a greeting sentence and convert a temperature to Celsius.

    Args:
        name: Name inserted into the greeting.
        is_morning: Truthy selects "Good morning", otherwise "Good evening".
        temperature: Temperature treated as degrees Fahrenheit.

    Returns:
        A ``(greeting, celsius)`` tuple; ``celsius`` is rounded to 2 decimals.
    """
    salutation = "Good morning" if is_morning else "Good evening"
    greeting = f"{salutation} {name}. It is {temperature} degrees today"
    celsius = (temperature - 32) * 5 / 9
    return greeting, round(celsius, 2)
10
+
11
def percentage(x):
    """Format the ratio ``x`` as a percentage string with two decimals."""
    return f"{x*100:.2f}%"
13
+
14
def get_schedule_time_and_image(result):
    """Return the longest per-stage compute span of a schedule.

    Args:
        result: One sequence of scheduled nodes per stage. Each node must
            expose ``type``, ``start_time`` and ``completion_time``.

    Returns:
        ``(time, None)`` where ``time`` is the maximum over stages of
        (latest completion - earliest start), considering only the compute
        nodes of type 'F', 'B' or 'W'. The second element is a placeholder
        for a schedule image and is always ``None``.

    Raises:
        ValueError: If ``result`` is empty or a stage has no compute node.
    """
    compute_types = {'F', 'B', 'W'}
    stage_spans = []
    for stage in result:
        # Communication / bookkeeping nodes do not count towards stage time.
        nodes = [node for node in stage if node.type in compute_types]
        stage_spans.append(
            max(node.completion_time for node in nodes)
            - min(node.start_time for node in nodes)
        )
    return max(stage_spans), None
24
+
25
def calculate(p, m, f, b, w, c, mem):
    """Compute time, bubble rate and speed-up for 1F1B, ZB and ZB-V schedules.

    Args:
        p: Number of pipeline stages.
        m: Number of microbatches.
        f: Time of one forward pass (F).
        b: Time of one backward pass (B).
        w: Time of one weight-gradient pass (W).
        c: Time of one P2P communication.
        mem: Activation memory limit, expressed as pending F per stage.

    Returns:
        A 12-element list: ``[time, bubble, acceleration, image]`` for the
        1F1B baseline, the zero-bubble schedule and the zero-bubble V
        schedule, in that order. All image slots are ``None``.
    """
    def bubble_rate(stage_time):
        # Overhead relative to the ideal, bubble-free time (f + b + w) * m.
        return percentage(stage_time / (f + b + w) / m - 1)

    # Closed-form 1F1B estimate used as the baseline.
    baseline_time = (f + b + w) * m + (f + b + w + c) * (p - 1)
    baseline_bubble = bubble_rate(baseline_time)
    baseline_acceleration = percentage(0)
    baseline_image = None

    # GraphConfig counts one F as mem_f=2 memory units by default, hence the
    # factor of two on the user-supplied limit.
    zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
        cost_f=f,
        cost_b=b,
        cost_w=w,
        cost_comm=c,
        max_mem=mem * 2,
        print_scaling=1000
    ))
    zb_time, zb_image = get_schedule_time_and_image(zb_result)
    zb_bubble = bubble_rate(zb_time)
    zb_acceleration = percentage(baseline_time / zb_time - 1)

    # The V schedule works on two chunks per stage, so each chunk gets half
    # of the per-stage cost.
    # NOTE(review): the factor 4 on max_mem presumably combines f_mem=2 with
    # the two chunks per stage - confirm against v_schedule.
    zbv_graph = v_schedule.PipelineGraph(
        n_stage=p,
        n_micro=m,
        f_cost=f/2,
        b_cost=b/2,
        w_cost=w/2,
        c_cost=c,
        f_mem=2,
        b_mem=-1,
        w_mem=-1,
        max_mem=mem * 4,
    )
    zbv_result = zbv_graph.get_v_schedule()
    zbv_time, zbv_image = get_schedule_time_and_image(zbv_result)
    zbv_bubble = bubble_rate(zbv_time)
    zbv_acceleration = percentage(baseline_time / zbv_time - 1)

    return [
        baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
        zb_time, zb_bubble, zb_acceleration, zb_image,
        zbv_time, zbv_bubble, zbv_acceleration, zbv_image,
    ]
65
+
66
def _result_group(title):
    """Create one result panel and return its (time, bubble, acceleration, image) components."""
    with gr.Group():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column(scale=1):
                time_box = gr.Textbox("", label="Longest Stage Time")
                # The label mirrors the formula used in calculate().
                bubble_box = gr.Textbox("", label="Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1).")
                acceleration_box = gr.Textbox("", label="Acceleration compared to 1F1B")
            with gr.Column(scale=4):
                image = gr.Image(None, interactive=False, label="Schedule Image")
    return time_box, bubble_box, acceleration_box, image


# NOTE(review): the source had lost its indentation; the layout nesting below
# is reconstructed from the component order - confirm against the live app.
with gr.Blocks() as demo:
    gr.Markdown("Zero bubble pipeline parallel bubble calculator")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("Basic Parameters")
                with gr.Row():
                    p = gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
                    m = gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
                with gr.Row():
                    f = gr.Number(label="Time of F", value=8, interactive=True, precision=0)
                    b = gr.Number(label="Time of B", value=8, interactive=True, precision=0)
                    w = gr.Number(label="Time of W", value=8, interactive=True, precision=0)
                    c = gr.Number(label="Time of one P2P communication", value=1, interactive=True, precision=0)
    with gr.Group():
        gr.Markdown("Activation memory limit.")

        def update_mem(p, s, mem):
            """Return the memory limit implied by the preset ``s`` ("1p", "2p", ...)."""
            if s == "custom":
                # Keep whatever the user typed into the custom field.
                return mem
            # Presets are "<k>p": k times the number of stages.
            return p * int(s[:-1])

        memsel = gr.Radio(choices=["1p", "2p", "3p", "custom"], value="1p")
        mem = gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
        memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
        p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)

    button = gr.Button("Calculate")

    baseline_time, baseline_bubble, baseline_acceleration, baseline_image = _result_group("1F1B")
    zb_time, zb_bubble, zb_acceleration, zb_image = _result_group("Zero Bubble Schedule")
    zbv_time, zbv_bubble, zbv_acceleration, zbv_image = _result_group("Zero Bubble V Schedule")

    # Output order must match the list returned by calculate().
    button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
demo.launch()
auto_schedule.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Set
3
+
4
+
5
+
6
@dataclass
class GraphConfig:
    """Costs and memory parameters of the pipeline scheduling graph.

    ``__post_init__`` asserts that all costs are plain ``int`` and that the
    memory deltas of F, B and W sum to zero.
    """
    # Memory delta applied when an F / B / W pass is scheduled.
    mem_f: float = 2
    mem_b: float = -1
    mem_w: float = -1
    # Per-stage memory limit; a falsy value makes Graph.manual_order fall
    # back to its own default.
    max_mem: float = None
    # Duration of one F / B / W pass and of one communication.
    cost_f: int = 1
    cost_b: int = 1
    cost_w: int = 1
    cost_comm: int = 0
    # Divisor applied to timestamps when rendering the textual schedule.
    print_scaling: int = 1

    def __post_init__(self):
        for field_name in ("cost_f", "cost_b", "cost_w", "cost_comm"):
            value = getattr(self, field_name)
            # Exact type check on purpose: floats (and bools) are rejected.
            assert type(value) is int, \
                f"{field_name} must be an int, got {type(value).__name__}"
        assert self.mem_f + self.mem_b + self.mem_w == 0, \
            "mem_f + mem_b + mem_w must sum to zero"
24
+
25
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """Immutable, hashable record of one scheduled operation on a stage."""
    # Operation kind, e.g. 'F', 'B', 'W' or a communication name.
    type: str
    stage: int
    minibatch: int
    start_time: int
    completion_time: int
    rollback: bool = False
33
+
34
+
35
@dataclass
class Graph:
    """Dependency graph of the F/B/W passes of a pipeline schedule.

    Node id mapping (types are 0=F, 1=B, 2=W):
        F[stage][minibatch]: 0 .. STAGE * MB
        B[stage][minibatch]: STAGE * MB .. 2 * STAGE * MB
        W[stage][minibatch]: 2 * STAGE * MB .. 3 * STAGE * MB
    """
    nstages: int
    nmb: int
    nnodes: int
    config: "GraphConfig"
    parents: List[Set[int]] = None
    name: List[str] = None

    def get_id(self, type, stage, mb):
        """Return the node id of pass ``type`` for ``mb`` on ``stage``."""
        return type * (self.nstages * self.nmb) + stage * self.nmb + mb

    def get_stage(self, id):
        """Return the stage a node id belongs to."""
        return (id // self.nmb) % self.nstages

    def get_cost(self, id):
        """Return the configured duration of the node's pass type."""
        kind = id // (self.nstages * self.nmb)
        return [self.config.cost_f, self.config.cost_b, self.config.cost_w][kind]

    def get_mem(self, id):
        """Return the configured memory delta of the node's pass type."""
        kind = id // (self.nstages * self.nmb)
        return [self.config.mem_f, self.config.mem_b, self.config.mem_w][kind]

    @classmethod
    def build_graph(cls, nstages, nmb, config):
        """Create a graph with ``parents`` and ``name`` filled for every node."""
        nnodes = nstages * nmb * 3
        g = cls(nstages=nstages, nmb=nmb, nnodes=nnodes, config=config)
        parents = []
        name = []
        # Iteration order matches the id mapping, so list index == node id.
        for kind in range(3):
            for stage in range(nstages):
                for mb in range(nmb):
                    p = set()
                    if kind == 0:
                        name.append(f'F{mb}')
                        # F needs the same microbatch from the previous stage.
                        if stage > 0:
                            p.add(g.get_id(kind, stage - 1, mb))
                    elif kind == 1:
                        name.append(f'B{mb}')
                        # B on the last stage follows its own F; elsewhere it
                        # follows B of the next stage.
                        if stage == nstages - 1:
                            p.add(g.get_id(0, stage, mb))
                        else:
                            p.add(g.get_id(kind, stage + 1, mb))
                    else:
                        name.append(f'W{mb}')
                        p.add(g.get_id(1, stage, mb))
                    # Same-type passes on a stage run in microbatch order.
                    if mb > 0:
                        p.add(g.get_id(kind, stage, mb - 1))
                    parents.append(p)

        g.name = name
        g.parents = parents
        return g

    # Manual ordering producing this kind of schedule:
    # fffffffbfbfbfbfbfbwbwbwbwbwbwbwwwwww
    #   fffffbfbfbfbfbfbfbfbwbwbwbwbwwwwwwww
    #     fffbfbfbfbfbfbfbfbfbfbwbwbwwwwwwwwww
    #       fbfbfbfbfbfbfbfbfbfbfbfbwwwwwwwwwwww
    # Returns the order index of each node on its own stage
    def manual_order(
        self, allow_bubble_before_first_b=False, prioritize_b=False, no_bubble_greedy=True
    ):
        """Greedily build a schedule and return it with its timing.

        Args:
            allow_bubble_before_first_b: Selects the stop condition used when
                counting how many extra F to run before the first B.
            prioritize_b: Enables the check that can postpone an F on a stage
                that is at least two F ahead of the next stage.
            no_bubble_greedy: Whether a W is inserted while waiting for an
                input when the wait fits within the slack to the stage with
                the largest bubble.

        Returns:
            ``(order, t, best_time)``: per-node order index on its stage,
            per-node completion time, and the longest per-stage span.

        Raises:
            ValueError: If the memory limit cannot be satisfied.
            AssertionError: If a limit or dependency check fails.
        """
        # NOTE(review): the source had lost its indentation; block nesting
        # here is reconstructed from the control flow - confirm against the
        # upstream file.
        order = [0] * self.nnodes
        f = [0] * self.nstages  # number of F scheduled per stage
        b = [0] * self.nstages  # number of B scheduled per stage
        w = [0] * self.nstages  # number of W scheduled per stage
        o = [0] * self.nstages  # next order index per stage
        m = [0] * self.nstages  # current memory per stage
        e = [0] * self.nstages  # completion time of the last pass per stage
        t = [0] * self.nnodes   # completion time per node
        max_mem = self.config.max_mem or self.get_mem(self.get_id(0, 0, 0)) * self.nmb * 3
        comm = self.config.cost_comm
        stage_bubble = [0] * self.nstages
        counts = (f, b, w)

        def get_max_bubble():
            # Bubbles are only ever increased by non-negative amounts.
            return max(stage_bubble, default=0)

        def put(stage_j, type_k):
            """Append the next pass of ``type_k`` to ``stage_j``."""
            _i = counts[type_k][stage_j]
            _j = stage_j
            _id = self.get_id(type_k, _j, _i)
            _mem = self.get_mem(_id)
            _cost = self.get_cost(_id)
            assert m[_j] + _mem <= max_mem

            tmp = e[_j] + _cost
            no_bubble = tmp
            # Wait for the input from the neighbouring stage, if any.
            if _j > 0 and type_k == 0:
                tmp = max(tmp, t[self.get_id(0, _j - 1, _i)] + comm + _cost)
            if _j < self.nstages - 1 and type_k == 1:
                tmp = max(tmp, t[self.get_id(1, _j + 1, _i)] + comm + _cost)
            # Idle time before the very first F of a stage is not a bubble.
            if f[_j] > 0:
                stage_bubble[_j] += tmp - no_bubble
            e[_j] = tmp
            t[_id] = tmp
            m[_j] += _mem
            order[_id] = o[_j]
            counts[type_k][_j] += 1
            o[_j] += 1

        for i in range(self.nmb):
            if i == 0:
                # Warm-up: first F everywhere, then as many extra F as fit
                # before the first B reaches each stage.
                for j in range(self.nstages):
                    put(j, 0)
                f_required = [0] * self.nstages
                last_t = 0
                for j in range(self.nstages - 1, -1, -1):
                    if j == self.nstages - 1:
                        last_t = t[self.get_id(0, j, i)] + self.get_cost(self.get_id(1, j, i))
                        continue
                    mem = m[j]
                    cost = e[j]
                    while True:
                        f_id = self.get_id(0, j, f[j] + f_required[j])
                        if f[j] + f_required[j] < self.nmb and mem + self.get_mem(f_id) <= max_mem:
                            if allow_bubble_before_first_b:
                                if cost + self.get_cost(f_id) > last_t + comm:
                                    break
                            else:
                                if cost >= last_t + comm:
                                    break
                            mem += self.get_mem(f_id)
                            cost += self.get_cost(f_id)
                            f_required[j] += 1
                        else:
                            break
                    last_t = max(cost, last_t + comm) + self.get_cost(self.get_id(1, j, i))
                for j in range(self.nstages):
                    while j > 0 and f_required[j] > 0 and f_required[j] >= f_required[j - 1] and f[j] + f_required[j] < self.nmb:
                        f_required[j] -= 1
                for j in range(self.nstages - 1, -1, -1):
                    for _ in range(f_required[j]):
                        put(j, 0)
                    put(j, 1)
                continue

            # Steady state: decide per stage whether to run one more F.
            f_required = [0] * self.nstages
            for j in range(self.nstages):
                if f[j] >= self.nmb:
                    continue
                if j + 1 < self.nstages and f[j] >= f[j + 1] + 2 and prioritize_b:
                    next_plus_fw = (
                        e[j + 1]
                        + self.get_cost(self.get_id(0, j + 1, f[j + 1]))
                        + self.get_cost(self.get_id(1, j + 1, b[j + 1]))
                        + comm
                    )
                    if e[j] >= next_plus_fw:
                        continue
                    f_id = self.get_id(0, j, f[j])
                    f_mem = self.get_mem(f_id)
                    # Cost of the W passes needed to free memory for this F.
                    w_cost, w_cnt = 0, 0
                    mem_with_w = m[j] + f_mem
                    while mem_with_w > max_mem and w[j] + w_cnt < b[j]:
                        w_id = self.get_id(2, j, w[j] + w_cnt)
                        w_cost += self.get_cost(w_id)
                        mem_with_w += self.get_mem(w_id)
                        w_cnt += 1
                    if e[j] + self.get_cost(f_id) + w_cost <= next_plus_fw:
                        f_required[j] = 1
                        continue

                    if next_plus_fw - e[j] <= get_max_bubble() - stage_bubble[j]:
                        # TODO: can sample here
                        continue
                f_required[j] = 1
            # A stage only runs an F if every later stage does too.
            for j in range(self.nstages - 2, -1, -1):
                f_required[j] = min(f_required[j], f_required[j + 1])
            for j in range(self.nstages):
                if f_required[j] == 0:
                    continue
                f_id = self.get_id(0, j, f[j])
                mem = self.get_mem(f_id)
                # Free memory with W passes until the F fits.
                while m[j] + mem > max_mem:
                    if w[j] >= b[j]:
                        raise ValueError("Cannot fit memory")
                    put(j, 2)
                if j > 0:
                    # Fill the wait for the upstream F with W passes that fit.
                    while (
                        w[j] < b[j]
                        and e[j] + self.get_cost(self.get_id(2, j, w[j]))
                        <= t[self.get_id(0, j - 1, f[j])] + comm
                    ):
                        put(j, 2)
                    if w[j] < b[j] and e[j] < t[self.get_id(0, j - 1, f[j])] + comm:
                        # TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(0, j - 1, f[j])] + comm
                        if (
                            t[self.get_id(0, j - 1, f[j])] + comm - e[j]
                            <= get_max_bubble() - stage_bubble[j]
                        ):
                            # TODO: can sample here
                            if no_bubble_greedy:
                                put(j, 2)
                        else:
                            put(j, 2)
                put(j, 0)
            for j in range(self.nstages - 1, -1, -1):
                assert b[j] == i
                b_id = self.get_id(1, j, b[j])
                mem = self.get_mem(b_id)
                while m[j] + mem > max_mem:
                    if w[j] >= b[j]:
                        raise ValueError("Cannot fit memory")
                    put(j, 2)
                if j + 1 < self.nstages:
                    # Fill the wait for the downstream B with W passes that fit.
                    while (
                        w[j] < b[j]
                        and e[j] + self.get_cost(self.get_id(2, j, w[j]))
                        <= t[self.get_id(1, j + 1, i)] + comm
                    ):
                        put(j, 2)
                    if w[j] < b[j] and e[j] < t[self.get_id(1, j + 1, i)] + comm:
                        # TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(1, j + 1, i)] + comm
                        if (
                            t[self.get_id(1, j + 1, i)] + comm - e[j]
                            <= get_max_bubble() - stage_bubble[j]
                        ):
                            # TODO: can sample here
                            if no_bubble_greedy:
                                put(j, 2)
                        else:
                            put(j, 2)
                if j == 0 and f[j] == self.nmb:
                    while w[j] < b[j]:
                        put(j, 2)
                put(j, 1)

        # Cool-down: flush the remaining W passes.
        for i in range(self.nstages):
            while w[i] < self.nmb:
                put(i, 2)

        # Sanity-check every dependency of the produced schedule.
        for i in range(self.nstages):
            for j in range(self.nmb):
                f_id = self.get_id(0, i, j)
                b_id = self.get_id(1, i, j)
                w_id = self.get_id(2, i, j)
                f_cost = self.get_cost(f_id)
                b_cost = self.get_cost(b_id)
                w_cost = self.get_cost(w_id)
                assert t[b_id] >= t[f_id] + b_cost
                assert t[w_id] >= t[b_id] + w_cost, f"{i}-{j}, {t[w_id]} >= {t[b_id]} + {w_cost}"
                if i > 0:
                    assert t[f_id] >= t[self.get_id(0, i - 1, j)] + comm + f_cost, f"{i}-{j}"
                if i < self.nstages - 1:
                    assert t[b_id] >= t[self.get_id(1, i + 1, j)] + comm + b_cost

        # Longest span from the start of the first F to the end of the last W.
        best_time = 0
        for i in range(self.nstages):
            time_i = (
                t[self.get_id(2, i, self.nmb - 1)]
                - t[self.get_id(0, i, 0)]
                + self.get_cost(self.get_id(0, i, 0))
            )
            best_time = max(best_time, time_i)

        return order, t, best_time
329
+
330
+
331
def initial_solution(graph):
    """Try every ``manual_order`` flag combination and keep the fastest.

    Prints the winning schedule via ``print_detail`` followed by its time.

    Args:
        graph: Graph exposing ``manual_order`` (and whatever ``print_detail``
            needs).

    Returns:
        ``(best_time, order, complete_time)`` of the best schedule; the first
        combination wins ties.
    """
    best_time, order, complete_time = None, None, None
    for allow_bubble_before_first_b in [True, False]:
        for prioritize_b in [True, False]:
            for no_bubble_greedy in [True, False]:
                order_t, complete_time_t, best_time_t = graph.manual_order(
                    allow_bubble_before_first_b=allow_bubble_before_first_b,
                    prioritize_b=prioritize_b,
                    no_bubble_greedy=no_bubble_greedy,
                )
                if best_time is None or best_time_t < best_time:
                    best_time = best_time_t
                    order = order_t
                    complete_time = complete_time_t

    print_detail(graph, complete_time)
    print("-" * 20, best_time, "-" * 20)
    return best_time, order, complete_time
349
+
350
+
351
def print_detail(graph, F):
    """Print an ASCII timeline per stage and the longest stage time.

    Args:
        graph: Object exposing ``nstages``, ``nmb``, ``get_id``, ``get_cost``
            and ``config.print_scaling``.
        F: Completion time of every node, indexed by node id.
    """
    typenames = ['F', 'B', 'W']
    scaling = graph.config.print_scaling
    times = []
    for stage in range(graph.nstages):
        # One character per scaled time unit, '.' meaning idle.
        stage_str = ['.'] * int(F[graph.get_id(2, stage, graph.nmb - 1)] / scaling)
        for _type in range(3):
            for _mb in range(graph.nmb):
                _id = graph.get_id(_type, stage, _mb)
                end = int(F[_id] / scaling)
                start = int((F[_id] - graph.get_cost(_id)) / scaling)
                for j in range(start, end):
                    if j == start or j == end - 1:
                        # The pass letter frames the block on both ends.
                        stage_str[j] = typenames[_type]
                    elif j == start + 1:
                        # First (or only) digit of the microbatch index.
                        if _mb >= 10:
                            stage_str[j] = str(_mb // 10)
                        else:
                            stage_str[j] = str(_mb)
                    elif j == start + 2 and _mb >= 10:
                        stage_str[j] = str(_mb % 10)
                    else:
                        stage_str[j] = "-"
        times.append(
            F[graph.get_id(2, stage, graph.nmb - 1)]
            - F[graph.get_id(0, stage, 0)]
            + graph.get_cost(graph.get_id(0, stage, 0))
        )
        print("".join(stage_str))
    print('Longest stage time: ', max(times))
383
+
384
+
385
def ilp_results(graph, F):
    """Turn per-node completion times into ordered per-stage node lists.

    Besides the F/B/W compute nodes, the result contains send/recv nodes for
    every cross-stage F and B and the post-validation nodes. The schedule is
    also printed via ``print_detail``.

    Args:
        graph: Graph exposing ``nnodes``, ``nstages``, ``nmb``, ``get_id``,
            ``get_cost`` and ``config.cost_comm``.
        F: Completion time of every node, indexed by node id.

    Returns:
        One list of ``ScheduledNode`` per stage, sorted by start time, with
        ``rollback`` set on the RECV_FORWARD nodes whose matching
        SEND_FORWARD precedes POST_VALIDATION on the previous stage.
    """
    # NOTE(review): the source had lost its indentation; block nesting here
    # is reconstructed from the data flow - confirm against the upstream file.
    typenames = ['F', 'B', 'W']
    end_time = [F[i] for i in range(graph.nnodes)]
    local_order = []
    for stage in range(graph.nstages):
        order = []
        for kind in range(3):
            for mb in range(graph.nmb):
                node_id = graph.get_id(kind, stage, mb)
                order.append(
                    ScheduledNode(
                        type=typenames[kind],
                        stage=stage,
                        minibatch=mb,
                        start_time=end_time[node_id] - graph.get_cost(node_id),
                        completion_time=F[node_id],
                    )
                )
        local_order.append(order)

    # For each F/B, append a send/recv node. The timestamp of the recv node is
    # the same as that of the send node to guarantee a global order.
    comm_id = {}
    comm_id_counter = 0
    post_validation_time = 0
    for i in range(graph.nstages - 1, -1, -1):
        # Index of the last F that completes before the first B completes.
        warmup_f_count = -1
        first_b_end = end_time[graph.get_id(1, i, 0)]
        for j in range(graph.nmb):
            if end_time[graph.get_id(0, i, j)] < first_b_end:
                warmup_f_count += 1
        assert warmup_f_count >= 0
        pv_id = warmup_f_count
        _id = graph.get_id(0, i, pv_id)
        _cost = graph.get_cost(_id)
        post_validation_time = max(post_validation_time, end_time[_id] - _cost - graph.config.cost_comm)
        for it in ["RECV_", "SEND_", ""]:
            # The first stage sends nothing; the last stage receives nothing.
            if i == 0 and it == "SEND_":
                continue
            if i == graph.nstages - 1 and it == "RECV_":
                continue
            stage_ = i
            local_order[stage_].append(ScheduledNode(
                type=it + "POST_VALIDATION",
                stage=stage_,
                minibatch=0,
                start_time=post_validation_time,
                completion_time=post_validation_time,
            ))
            comm_id[local_order[stage_][-1]] = comm_id_counter
            comm_id_counter += 1

    for stage in range(graph.nstages):
        # Iterate over a snapshot: the loop appends to local_order[stage].
        for node in list(local_order[stage]):
            if node.type == 'F' and node.stage != graph.nstages - 1:
                local_order[stage].append(
                    ScheduledNode(
                        type='SEND_FORWARD',
                        stage=stage,
                        minibatch=node.minibatch,
                        start_time=node.completion_time,
                        completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                    )
                )
                local_order[stage + 1].append(
                    ScheduledNode(
                        type='RECV_FORWARD',
                        stage=stage + 1,
                        minibatch=node.minibatch,
                        start_time=node.completion_time,
                        completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                    )
                )
                # A send and its matching recv share one comm id.
                comm_id[local_order[stage][-1]] = comm_id_counter
                comm_id[local_order[stage + 1][-1]] = comm_id_counter
                comm_id_counter += 1
            if node.type == 'B' and node.stage != 0:
                local_order[stage].append(
                    ScheduledNode(
                        type='SEND_BACKWARD',
                        stage=stage,
                        minibatch=node.minibatch,
                        start_time=node.completion_time,
                        completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                    )
                )
                local_order[stage - 1].append(
                    ScheduledNode(
                        type='RECV_BACKWARD',
                        stage=stage - 1,
                        minibatch=node.minibatch,
                        start_time=node.completion_time,
                        completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                    )
                )
                comm_id[local_order[stage][-1]] = comm_id_counter
                comm_id[local_order[stage - 1][-1]] = comm_id_counter
                comm_id_counter += 1

    # For nodes with the same timestamp on the same stage, communication will
    # be prioritized.
    def even_breaker(x: ScheduledNode):
        # Compute nodes are always delayed.
        if x.type in ['F', 'B', 'W']:
            return comm_id_counter
        # For comm nodes, order by their unique comm id
        return comm_id[x]

    for stage in range(graph.nstages):
        nodes = sorted(
            local_order[stage], key=lambda x: (x.start_time, even_breaker(x))
        )
        # If a recv overlaps with the preceding computation, reorder them so
        # that the recv is issued before the computation and can be overlapped.
        for i in range(1, len(nodes)):
            if nodes[i - 1].type in {'F', 'B', 'W'} and \
                    nodes[i].type.startswith('RECV') and \
                    "POST_VALIDATION" not in nodes[i].type and \
                    nodes[i].start_time <= nodes[i - 1].completion_time:
                nodes[i], nodes[i - 1] = nodes[i - 1], nodes[i]
        local_order[stage] = nodes

    local_order_with_rollback = [[] for _ in range(graph.nstages)]
    for rank in range(graph.nstages):
        # Microbatches the previous stage sends before its POST_VALIDATION.
        rollback_comm = set()
        if rank > 0:
            for node in local_order[rank - 1]:
                if node.type == "POST_VALIDATION":
                    break
                if node.type == "SEND_FORWARD":
                    rollback_comm.add(node.minibatch)
        for node in local_order[rank]:
            if node.type == "RECV_FORWARD" and node.minibatch in rollback_comm:
                rollback = True
                rollback_comm.remove(node.minibatch)
            else:
                rollback = False
            local_order_with_rollback[rank].append(ScheduledNode(
                type=node.type,
                stage=node.stage,
                minibatch=node.minibatch,
                start_time=node.start_time,
                completion_time=node.completion_time,
                rollback=rollback,
            ))
        assert len(rollback_comm) == 0

    print_detail(graph, end_time)
    return local_order_with_rollback
537
+
538
+
539
def auto_schedule(nstages, nmb, config):
    """Build the graph, pick the best manual schedule and return its node lists.

    Args:
        nstages: Number of pipeline stages.
        nmb: Number of microbatches.
        config: ``GraphConfig`` with costs and memory parameters.

    Returns:
        The per-stage node lists produced by ``ilp_results``.
    """
    graph = Graph.build_graph(nstages, nmb, config)

    # Only the completion times are needed to build the result.
    _, _, complete_time = initial_solution(graph)
    return ilp_results(graph, complete_time)
544
+
545
+
546
if __name__ == "__main__":
    # Sample runs; each call prints its schedule.
    # auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
    # auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=14))
    auto_schedule(24, 72, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=100))
    auto_schedule(4, 12, GraphConfig(
        cost_f=5478,
        cost_b=5806,
        cost_w=3534,
        cost_comm=200,
        max_mem=32,
        print_scaling=1000
    ))
    auto_schedule(32, 16, GraphConfig(
        cost_f=1,
        cost_b=1,
        cost_w=1,
        cost_comm=0,
        max_mem=64,
    ))
v_schedule.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ from dataclasses import dataclass
3
+
4
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """Immutable, hashable record of one scheduled operation of a V schedule."""
    # Operation kind, e.g. 'F', 'B' or 'W'.
    type: str
    # Index of the model chunk on the stage (this module uses 0 and 1).
    chunk: int
    stage: int
    minibatch: int
    start_time: int
    completion_time: int
    rollback: bool = False
13
+
14
+
15
+ class PipelineGraph(object):
16
+ def __init__(
17
+ self, n_stage, n_micro, f_cost, b_cost, w_cost, c_cost,
18
+ f_mem, b_mem, w_mem, max_mem=None,
19
+ ):
20
+ self.n_node = 6 * n_stage * n_micro
21
+ self.n_stage = n_stage
22
+ self.n_micro = n_micro
23
+ self.f_cost = f_cost
24
+ self.b_cost = b_cost
25
+ self.w_cost = w_cost
26
+ self.c_cost = c_cost
27
+ self.f_mem = f_mem
28
+ self.b_mem = b_mem
29
+ self.w_mem = w_mem
30
+ self.fbw_cost = [f_cost, b_cost, w_cost]
31
+ self.fbw_mem = [f_mem, b_mem, w_mem]
32
+ self.max_mem = max_mem or f_mem * self.n_stage * 2
33
+
34
+ def get_id(self, cat, chunk, stage, micro):
35
+ return cat * 2 * self.n_stage * self.n_micro + \
36
+ chunk * self.n_stage * self.n_micro + \
37
+ stage * self.n_micro + \
38
+ micro
39
+
40
+ def try_v_schedule(self, fill_f=True, fill_b=True, approved_bubble=None):
41
+ count = []
42
+ for i in range(self.n_stage):
43
+ count.append([0] * 6)
44
+
45
+ end_time = [-1] * self.n_node
46
+ cur_time = [0] * self.n_stage
47
+ mem = [0] * self.n_stage
48
+ stage_bubble = [0] * self.n_stage
49
+ pending_w = [deque() for _ in range(self.n_stage)]
50
+ schedule = [[] for _ in range(self.n_stage)]
51
+ stage_str = [" " * i for i in range(self.n_stage)]
52
+
53
+ if approved_bubble is None:
54
+ approved_bubble = [-1] * self.n_stage
55
+ max_approved_bubble = max(approved_bubble)
56
+
57
+ def get_max_stage_bubble(stage=-1):
58
+ max_stage_bubble = 0
59
+ for bb in stage_bubble:
60
+ max_stage_bubble = max(max_stage_bubble, bb)
61
+ if stage >= 0:
62
+ max_stage_bubble = max(max_stage_bubble, max_approved_bubble - approved_bubble[stage])
63
+ return max_stage_bubble
64
+
65
+ def put_w(stage):
66
+ assert len(pending_w[stage]) > 0
67
+ _, chunk_, _ = pending_w[stage].popleft()
68
+ put(2, chunk_, stage)
69
+
70
+ def put(cat, chunk, stage, assert_cnt=True):
71
+ _tmp = _no_bubble = cur_time[stage] + self.fbw_cost[cat]
72
+ _cnt = count[stage][cat * 2 + chunk]
73
+ # assert _cnt < self.n_micro
74
+ if _cnt >= self.n_micro:
75
+ if not assert_cnt:
76
+ stage_str[stage] += " "
77
+ cur_time[stage] = _tmp # TODO
78
+ return
79
+ assert False
80
+ assert mem[stage] + self.fbw_mem[cat] <= self.max_mem
81
+ stage_str[stage] += "FfBbWw"[cat * 2 + chunk] + str(_cnt + 1) + " " * (3 - len(str(_cnt + 1)))
82
+ if cat > 0 or chunk > 0:
83
+ last_id = cat * 2 + chunk - 1
84
+ if cat < 2:
85
+ # if end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] < 0:
86
+ # print(cat, chunk, stage, _cnt)
87
+ # self.print_details(end_time)
88
+ assert end_time[self.get_id(last_id // 2, last_id % 2, stage, _cnt)] >= 0
89
+ else:
90
+ assert end_time[self.get_id(1, chunk, stage, _cnt)] >= 0
91
+ if chunk == 1 and cat < 2:
92
+ if stage < self.n_stage - 1:
93
+ _fa_id = self.get_id(cat, chunk, stage + 1, _cnt)
94
+ assert end_time[_fa_id] >= 0
95
+ _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
96
+ if chunk == 0 and cat < 2:
97
+ if stage > 0:
98
+ _fa_id = self.get_id(cat, chunk, stage - 1, _cnt)
99
+ # if end_time[_fa_id] < 0:
100
+ # print(cat, chunk, stage, _cnt)
101
+ # self.print_details(end_time)
102
+ assert end_time[_fa_id] >= 0, f"{cat}, {chunk}, {stage}, {_cnt}"
103
+ _tmp = max(_tmp, end_time[_fa_id] + self.c_cost + self.fbw_cost[cat])
104
+ _id = self.get_id(cat, chunk, stage, _cnt)
105
+ if count[stage][0] > 0:
106
+ stage_bubble[stage] += _tmp - _no_bubble
107
+ end_time[_id] = _tmp
108
+ cur_time[stage] = _tmp
109
+ mem[stage] += self.fbw_mem[cat]
110
+ # noinspection PyTypeChecker
111
+ schedule[stage].append((cat, chunk, _cnt))
112
+ if cat == 1:
113
+ pending_w[stage].append((2, chunk, _cnt))
114
+ count[stage][cat * 2 + chunk] += 1
115
+
116
+ for _ in range(2 * self.n_stage):
117
+ for i in range(self.n_stage):
118
+ if count[i][1] >= count[i][0]:
119
+ put(0, 0, i, assert_cnt=False)
120
+ continue
121
+ if i == self.n_stage - 1:
122
+ put(0, 1, i, assert_cnt=False)
123
+ continue
124
+ fa_id = self.get_id(0, 1, i + 1, count[i][1])
125
+ if 0 <= end_time[fa_id] < cur_time[i + 1]: # TODO
126
+ put(0, 1, i, assert_cnt=False)
127
+ else:
128
+ put(0, 0, i, assert_cnt=False)
129
+
130
+ # for i in range(self.n_stage):
131
+ # put(0, 0, i)
132
+ # for i in range(self.n_stage - 1, -1, -1):
133
+ # if i == self.n_stage - 1:
134
+ # put(0, 1, i)
135
+ # continue
136
+ # tmp = end_time[self.get_id(0, 1, i + 1, 0)] + self.c_cost
137
+ # while mem[i] + self.fbw_mem[0] * (2 + i * 2) <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp and count[i][0] < self.n_micro:
138
+ # for j in range(i + 1):
139
+ # put(0, 0, j)
140
+ # put(0, 1, i)
141
+ # iter_chunk_ = 0
142
+ # end_tmp = 0
143
+ # for i in range(self.n_stage):
144
+ # if i == 0:
145
+ # end_tmp = cur_time[0] + self.fbw_cost[1]
146
+ # continue
147
+ # tmp = end_tmp + self.c_cost
148
+ # while count[i][0] + count[i][1] < count[i - 1][0] + count[i - 1][1]:
149
+ # for j in range(self.n_stage - 1, i - 1, -1):
150
+ # if count[j][iter_chunk_] < self.n_micro:
151
+ # put(0, iter_chunk_, j)
152
+ # iter_chunk_ = 1 - iter_chunk_
153
+ # # while mem[i] + self.fbw_mem[0] <= self.max_mem and cur_time[i] + self.fbw_cost[0] <= tmp:
154
+ # # if iter_chunk_ == 0 and count[i][0] >= count[i - 1][0]:
155
+ # # break
156
+ # # for j in range(self.n_stage - 1, i - 1, -1):
157
+ # # if count[j][iter_chunk_] < self.n_micro:
158
+ # # put(0, iter_chunk_, j)
159
+ # # iter_chunk_ = 1 - iter_chunk_
160
+ # # end_tmp = max(tmp, cur_time[i]) + self.fbw_cost[1]
161
+
162
+ # init_bubble = get_max_stage_bubble()
163
+ # print(stage_bubble)
164
+ for _ in range(2 * self.n_micro):
165
+ # check mem before putting b
166
+ for i in range(self.n_stage):
167
+ while mem[i] + self.fbw_mem[1] > self.max_mem:
168
+ assert len(pending_w[i]) > 0
169
+ put_w(i)
170
+ b0_ranks, b1_ranks = [], []
171
+ for i in range(self.n_stage):
172
+ if count[i][3] >= count[i][2]:
173
+ b0_ranks.append(i)
174
+ elif i == self.n_stage - 1:
175
+ b1_ranks.append(i)
176
+ else:
177
+ fa_id = self.get_id(1, 1, i + 1, count[i][3])
178
+ if end_time[fa_id] >= 0 or count[i][2] >= self.n_micro:
179
+ b1_ranks.append(i)
180
+ else:
181
+ b0_ranks.append(i)
182
+ b_ranks = []
183
+ # put b1
184
+ for i in reversed(b1_ranks):
185
+ b_ranks.append((i, 1))
186
+ # put b0
187
+ for i in b0_ranks:
188
+ b_ranks.append((i, 0))
189
+ for i, _chunk_ in b_ranks:
190
+ fa_id = -1
191
+ if _chunk_ == 1 and i < self.n_stage - 1:
192
+ fa_id = self.get_id(1, 1, i + 1, count[i][3])
193
+ if _chunk_ == 0 and i > 0:
194
+ fa_id = self.get_id(1, 0, i - 1, count[i][2])
195
+ while len(pending_w[i]) > 0 and fa_id >= 0 and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]:
196
+ # fill the bubble
197
+ put_w(i)
198
+ if len(pending_w[i]) > 0 and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]:
199
+ if _chunk_ == 1:
200
+ put_w(i)
201
+ elif fill_b:
202
+ put_w(i)
203
+ put(1, _chunk_, i)
204
+
205
+ # put f
206
+ for i in range(self.n_stage):
207
+ if count[i][1] >= self.n_micro:
208
+ continue
209
+ put_item = None
210
+ if count[i][1] >= count[i][0]:
211
+ put_item = 0
212
+ elif i == self.n_stage - 1:
213
+ put_item = 1
214
+ else:
215
+ if end_time[self.get_id(0, 1, i + 1, count[i][1])] >= 0:
216
+ put_item = 1
217
+ elif count[i][0] < self.n_micro:
218
+ if i == 0:
219
+ put_item = 0
220
+ elif end_time[self.get_id(0, 0, i - 1, count[i][0])] >= 0:
221
+ put_item = 0
222
+ if put_item is None:
223
+ continue
224
+ # check mem before putting f
225
+ while mem[i] + self.fbw_mem[0] > self.max_mem:
226
+ assert len(pending_w[i]) > 0
227
+ put_w(i)
228
+ fa_id = -1
229
+ if put_item == 0 and i > 0:
230
+ fa_id = self.get_id(0, 0, i - 1, count[i][0])
231
+ if put_item == 1 and i < self.n_stage - 1:
232
+ fa_id = self.get_id(0, 1, i + 1, count[i][1])
233
+ while len(pending_w[i]) > 0 and fa_id >= 0 and end_time[fa_id] + self.c_cost >= cur_time[i] + self.fbw_cost[2]:
234
+ # fill the bubble
235
+ put_w(i)
236
+ if len(pending_w[i]) > 0 and end_time[fa_id] + self.c_cost - cur_time[i] > get_max_stage_bubble(i) - stage_bubble[i]:
237
+ if fill_f:
238
+ put_w(i)
239
+ put(0, put_item, i)
240
+
241
+ for i in range(self.n_stage):
242
+ while len(pending_w[i]) > 0:
243
+ put_w(i)
244
+
245
+ # for i in range(self.n_stage):
246
+ # print(stage_str[i])
247
+
248
+ max_bubble = get_max_stage_bubble()
249
+ expected_time = sum(self.fbw_cost) * self.n_micro * 2
250
+ bubble_rate = max_bubble / expected_time
251
+ # print("%6.4f" % bubble_rate, "->", stage_bubble)
252
+ if max_approved_bubble < 0 or max_bubble < max_approved_bubble:
253
+ _schedule, _end_time, _max_bubble = self.try_v_schedule(
254
+ fill_f=fill_f, fill_b=fill_b,
255
+ approved_bubble=stage_bubble,
256
+ )
257
+ if _max_bubble < max_bubble:
258
+ return _schedule, _end_time, _max_bubble
259
+ # print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f %6.4f" % \
260
+ # (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, init_bubble / expected_time, bubble_rate), max_bubble)
261
+ return schedule, end_time, max_bubble
262
+
263
def print_details(self, end_time, print_scaling=1):
    """Print a one-line text timeline for every stage.

    Args:
        end_time: sequence of completion times indexed by
            ``self.get_id(cat, chunk, stage, micro)``; negative entries are
            skipped (not drawn).
        print_scaling: divisor applied to times to turn them into column
            positions.

    Each drawn pass occupies the columns from its start (completion time minus
    ``self.fbw_cost[cat]``) up to its completion. The first and last column of
    a pass show a letter from ``"FfBbWw"`` (selected by ``cat * 2 + chunk``),
    the following column(s) show the micro index (two digits when it is >= 10)
    and the remaining columns are filled with ``-``. Untouched columns stay
    ``.``.
    """
    for stage in range(self.n_stage):
        # Row width is derived from the largest completion time.
        row = ['.'] * int(max(end_time) / print_scaling)
        for cat in range(3):
            for chunk in range(2):
                glyph = "FfBbWw"[cat * 2 + chunk]
                for micro in range(self.n_micro):
                    finish = end_time[self.get_id(cat, chunk, stage, micro)]
                    if finish < 0:
                        continue
                    stop = int(finish / print_scaling)
                    first = int((finish - self.fbw_cost[cat]) / print_scaling)
                    for col in range(first, stop):
                        if col in (first, stop - 1):
                            # The boundary letter wins over digits on short passes.
                            row[col] = glyph
                        elif col == first + 1:
                            row[col] = str(micro // 10) if micro >= 10 else str(micro)
                        elif col == first + 2 and micro >= 10:
                            row[col] = str(micro % 10)
                        else:
                            row[col] = "-"
        print("".join(row))
290
+
291
def get_v_schedule(self):
    """Select the best V-schedule and expand it into ordered per-stage node lists.

    Steps performed:
      1. Call ``try_v_schedule`` for every ``(fill_b, fill_f)`` combination and
         keep the candidate with the smallest reported bubble (the earliest
         candidate wins ties). A summary line is printed.
      2. For every stage emit ``POST_VALIDATION`` nodes (plain, ``SEND_`` and
         ``RECV_`` variants depending on the stage position).
      3. Emit an ``F`` / ``B`` / ``W`` node for every scheduled pass, plus a
         ``SEND_*`` / ``RECV_*`` pair towards the neighbouring stage for
         forward and backward passes.
      4. Sort each stage by start time (communication first on ties) and move
         a ``RECV_*`` node in front of the directly preceding compute node
         when it starts no later than that compute node completes.
      5. Rebuild every node with a ``rollback`` flag and print the final order.

    Returns:
        A list with one list of ``ScheduledNode`` per stage.
    """
    best_schedule, best_end_time, best_bubble = None, None, None
    expected_time = sum(self.fbw_cost) * self.n_micro * 2
    for fill_b in (True, False):
        for fill_f in (True, False):
            cand_schedule, cand_end_time, cand_bubble = self.try_v_schedule(
                fill_b=fill_b, fill_f=fill_f
            )
            if best_bubble is None or cand_bubble < best_bubble:
                best_schedule = cand_schedule
                best_end_time = cand_end_time
                best_bubble = cand_bubble

    bubble_rate = best_bubble / expected_time
    summary = (self.n_stage, self.n_micro, *self.fbw_cost, self.max_mem // self.f_mem, bubble_rate)
    print("%2d %3d, [%5d %5d %5d], %6d -> %6.4f" % summary)

    last_stage = self.n_stage - 1
    local_order = [[] for _ in range(self.n_stage)]
    comm_id = {}
    comm_id_counter = 0

    # Post-validation nodes, walking from the last stage back to the first.
    # The timestamp only ever grows while walking.
    post_validation_time = 0
    for stage in reversed(range(self.n_stage)):
        pv_id = min(2 * (last_stage - stage), self.n_micro - 1)
        pv_candidate = best_end_time[self.get_id(0, 0, stage, pv_id)] - self.fbw_cost[0] - self.c_cost
        post_validation_time = max(post_validation_time, pv_candidate)
        for prefix in ("RECV_", "SEND_", ""):
            # The first stage does not get a SEND_ variant, the last no RECV_ variant.
            if stage == 0 and prefix == "SEND_":
                continue
            if stage == last_stage and prefix == "RECV_":
                continue
            pv_node = ScheduledNode(
                type=prefix + "POST_VALIDATION",
                chunk=0,
                stage=stage,
                minibatch=0,
                start_time=post_validation_time,
                completion_time=post_validation_time,
            )
            local_order[stage].append(pv_node)
            comm_id[pv_node] = comm_id_counter
            comm_id_counter += 1

    # Compute nodes and their send/recv pairs.
    for stage in range(self.n_stage):
        for cat, chunk, micro in best_schedule[stage]:
            complete_time = best_end_time[self.get_id(cat, chunk, stage, micro)]
            # Non-forward nodes are emitted with the chunk index flipped.
            node_chunk = chunk if cat == 0 else 1 - chunk
            local_order[stage].append(ScheduledNode(
                type="FBW"[cat],
                chunk=node_chunk,
                stage=stage,
                minibatch=micro,
                start_time=complete_time - self.fbw_cost[cat],
                completion_time=complete_time,
            ))
            if cat == 2:
                # W passes involve no communication (and consume no comm id).
                continue
            direction = "FORWARD" if cat == 0 else "BACKWARD"
            # Schedule-chunk 1 talks to the previous stage, chunk 0 to the next one.
            if chunk == 1 and stage > 0:
                peer = stage - 1
            elif chunk == 0 and stage < last_stage:
                peer = stage + 1
            else:
                peer = None
            if peer is not None:
                # Both ends of the transfer share one comm id.
                for prefix, owner in (("SEND_", stage), ("RECV_", peer)):
                    comm_node = ScheduledNode(
                        type=prefix + direction,
                        chunk=node_chunk,
                        stage=owner,
                        minibatch=micro,
                        start_time=complete_time,
                        completion_time=complete_time,
                    )
                    local_order[owner].append(comm_node)
                    comm_id[comm_node] = comm_id_counter
            comm_id_counter += 1

    compute_types = {'F', 'B', 'W'}

    def tie_breaker(node):
        # Compute nodes sort after every communication node with the same
        # timestamp; communication nodes are ordered by their comm id.
        if node.type in compute_types:
            return comm_id_counter
        return comm_id[node]

    for rank in range(self.n_stage):
        ordered = sorted(
            local_order[rank],
            key=lambda node: (node.start_time, tie_breaker(node)),
        )
        local_order[rank] = ordered
        # If a recv intersects with the previous computation, reorder them so
        # that the recv is executed before the computation and can be overlapped.
        for pos in range(1, len(ordered)):
            previous, current = ordered[pos - 1], ordered[pos]
            if previous.type in compute_types and \
                    current.type.startswith('RECV') and \
                    "POST_VALIDATION" not in current.type and \
                    current.start_time <= previous.completion_time:
                ordered[pos - 1], ordered[pos] = current, previous

    local_order_with_rollback = [[] for _ in range(self.n_stage)]
    for rank in range(self.n_stage):
        # Minibatches the previous rank sends forward before its own
        # POST_VALIDATION node.
        rollback_comm = set()
        if rank > 0:
            for node in local_order[rank - 1]:
                if node.type == "POST_VALIDATION":
                    break
                if node.type == "SEND_FORWARD":
                    assert node.chunk == 0
                    rollback_comm.add(node.minibatch)
        for node in local_order[rank]:
            rollback = (
                node.type == "RECV_FORWARD"
                and node.chunk == 0
                and node.minibatch in rollback_comm
            )
            if rollback:
                rollback_comm.remove(node.minibatch)
            local_order_with_rollback[rank].append(ScheduledNode(
                type=node.type,
                chunk=node.chunk,
                stage=node.stage,
                minibatch=node.minibatch,
                start_time=node.start_time,
                completion_time=node.completion_time,
                rollback=rollback,
            ))
        # Every early send must have been matched by a receive on this rank.
        assert len(rollback_comm) == 0
        for node in local_order_with_rollback[rank]:
            print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
        print()

    return local_order_with_rollback
420
+
421
+
422
if __name__ == '__main__':
    # Script entry point: build a PipelineGraph for every configuration below
    # and for every memory budget offset, then compute (and print) its schedule.
    s = 1024

    # Columns: p, n, f, b, w, c, h, a, l
    configurations = [
        # (8, 24, 18522, 18086, 9337, 601, 2304, 24, 24),
        # (8, 32, 18513, 18086, 9331, 626, 2304, 24, 24),
        # (8, 64, 18546, 18097, 9321, 762, 2304, 24, 24),
        # (8, 24, 29718, 29444, 19927, 527, 4096, 32, 32),
        # (8, 32, 29802, 29428, 19530, 577, 4096, 32, 32),
        # (8, 64, 29935, 29621, 19388, 535, 4096, 32, 32),
        # (16, 48, 11347, 11248, 8132, 377, 5120, 40, 48),
        # (16, 64, 11307, 11254, 8101, 379, 5120, 40, 48),
        # (16, 128, 11325, 11308, 8109, 378, 5120, 40, 48),
        # (32, 96, 10419, 10207, 7715, 408, 6144, 48, 64),
        # (32, 128, 10408, 10204, 7703, 408, 6144, 48, 64),
        # (32, 256, 10402, 10248, 7698, 460, 6144, 48, 64),
        (4, 8, 6, 4, 4, 1, 4096, 32, 32),
        # (8, 24, 29444, 29718, 19927, 527, 4096, 32, 32),
    ]

    # The last column of each configuration is not used by the computation below.
    for n_stage, n_micro, f_cost, b_cost, w_cost, c_cost, h, a, _unused in configurations:
        # Per-pass memory deltas; the three values sum to zero.
        f_mem = 34 * h + 5 * a * s
        w_mem = -32 * h
        b_mem = -w_mem - f_mem
        # Sweep the memory budget from 2 * n_stage to 3 * n_stage multiples of f_mem.
        for extra in range(n_stage + 1):
            pipeline = PipelineGraph(
                n_stage=n_stage,
                n_micro=n_micro,
                f_cost=f_cost,
                b_cost=b_cost,
                w_cost=w_cost,
                c_cost=c_cost,
                f_mem=f_mem,
                b_mem=b_mem,
                w_mem=w_mem,
                max_mem=f_mem * (n_stage * 2 + extra),
            )
            pipeline.get_v_schedule()