Wan Xinyi
Fix small m for zbv
c47cbb6
raw
history blame
7.76 kB
import gradio as gr
import auto_schedule
import v_schedule
import hand_schedule
from PIL import Image
from svg_event import render_manual_graph
import pathlib
def greet(name, is_morning, temperature):
    """Build a greeting and convert ``temperature`` from Fahrenheit to Celsius.

    Returns:
        A ``(greeting, celsius)`` tuple, with ``celsius`` rounded to 2 decimals.
    """
    if is_morning:
        opening = "Good morning"
    else:
        opening = "Good evening"
    message = f"{opening} {name}. It is {temperature} degrees today"
    in_celsius = round((temperature - 32) * 5 / 9, 2)
    return message, in_celsius
def percentage(x):
    """Format the fraction ``x`` as a percentage with two decimals (0.5 -> "50.00%")."""
    scaled = x * 100
    return format(scaled, ".2f") + "%"
def get_schedule_time(result):
    """Return the longest per-stage span of a schedule, counting only F/B/W nodes.

    ``result`` is an iterable of stages; each stage is an iterable of nodes
    exposing ``type``, ``start_time`` and ``completion_time``. A stage's span
    runs from its earliest ``start_time`` to its latest ``completion_time``
    among nodes whose type is 'F', 'B' or 'W'; other node types are ignored.

    Raises:
        ValueError: if ``result`` is empty or a stage has no F/B/W node.
    """
    compute_kinds = {'F', 'B', 'W'}
    compute_stages = [
        [node for node in stage if node.type in compute_kinds]
        for stage in result
    ]
    return max(
        max(node.completion_time for node in nodes)
        - min(node.start_time for node in nodes)
        for nodes in compute_stages
    )
# Paths of rendered schedule SVGs, oldest first; get_schedule_image keeps at most 32.
img_queue = []


def get_schedule_image(result, max_time):
    """Render the F/B/W nodes of a schedule to an SVG file and return its path.

    Only nodes whose ``type`` is 'F', 'B' or 'W' are passed to
    ``render_manual_graph``. Each rendered file is remembered in ``img_queue``;
    once more than 32 are held, the oldest files are deleted from disk so
    repeated use does not accumulate images without bound.

    Args:
        result: Iterable of stages, each an iterable of schedule nodes.
        max_time: Time-axis extent passed through to ``render_manual_graph``
            (callers use it to draw several schedules on a shared scale).

    Returns:
        ``pathlib.Path`` of the rendered SVG.
    """
    compute_kinds = {'F', 'B', 'W'}
    result = [[node for node in stage if node.type in compute_kinds] for stage in result]
    # Third argument is True only when the first stage has at most 72 nodes;
    # NOTE(review): its meaning is defined by render_manual_graph (presumably a
    # detail/label toggle for small schedules) — confirm in svg_event.
    svg = render_manual_graph(result, max_time, len(result[0]) <= 72)
    img_queue.append(svg)
    while len(img_queue) > 32:
        stale = img_queue.pop(0)
        # The file may already be gone (e.g. temp-dir cleanup); that must not
        # fail the current request.
        pathlib.Path(stale).unlink(missing_ok=True)
    return pathlib.Path(svg)
def calculate(p, m, f, b, w, c, mem):
    """Schedule one workload with 1F1B, ZB and ZBV and report on each.

    Args:
        p: Number of pipeline stages.
        m: Number of microbatches.
        f, b, w: Costs of the F, B and W passes on one stage. For ZBV each is
            halved and given to each of the stage's two virtual stages.
        c: Cost of one P2P communication.
        mem: Activation memory limit, in units of pending F passes on a stage.

    Returns:
        A 12-element list ``[time, bubble, acceleration, image]`` for 1F1B,
        ZB and ZBV, in that order. The 1F1B and ZBV entries are all ``None``
        when ``mem < p``; acceleration entries are ``None`` when there is no
        1F1B baseline to compare with. Bubbles and accelerations are
        percentage strings; images are whatever ``get_schedule_image`` returns.
    """

    def bubble(time):
        # Overhead relative to the ideal, bubble-free span of m * (F + B + W).
        return percentage(time / (f + b + w) / m - 1)

    # Skipped schedules report None in every field.
    baseline_time = baseline_bubble = baseline_acceleration = None
    baseline_image = baseline_result = None
    zb_image = None
    zbv_time = zbv_bubble = zbv_acceleration = None
    zbv_image = zbv_result = None

    if mem >= p:
        # 1F1B baseline: W is merged into B (cost b + w) and given zero cost on
        # its own. NOTE(review): assumes argument order (p, m, f, b, w, c) —
        # verify against hand_schedule.get_hand_schedule.
        baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
        baseline_result = [
            [node for node in stage if node.type in {'F', 'B'}]
            for stage in baseline_result
        ]
        baseline_time = get_schedule_time(baseline_result)
        baseline_bubble = bubble(baseline_time)
        baseline_acceleration = percentage(0)

    def acceleration(time):
        # Speed-up over the 1F1B baseline; None when the baseline was skipped.
        if baseline_time is None:
            return None
        return percentage(baseline_time / time - 1)

    zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
        cost_f=f,
        cost_b=b,
        cost_w=w,
        cost_comm=c,
        # NOTE(review): limit is doubled here; the unit is defined by
        # auto_schedule — confirm.
        max_mem=mem * 2,
        print_scaling=1000,
    ))
    zb_time = get_schedule_time(zb_result)
    zb_bubble = bubble(zb_time)
    zb_acceleration = acceleration(zb_time)

    if mem >= p:
        # ZBV splits each stage into two virtual stages, each taking half of
        # the per-stage costs. NOTE(review): f_mem=2 / b_mem=w_mem=-1 with a
        # 4x limit presumably expresses `mem` in half-units — confirm in v_schedule.
        zbv_graph = v_schedule.PipelineGraph(
            n_stage=p,
            n_micro=m,
            f_cost=f / 2,
            b_cost=b / 2,
            w_cost=w / 2,
            c_cost=c,
            f_mem=2,
            b_mem=-1,
            w_mem=-1,
            max_mem=mem * 4,
        )
        zbv_result = zbv_graph.get_v_schedule()
        zbv_time = get_schedule_time(zbv_result)
        zbv_bubble = bubble(zbv_time)
        zbv_acceleration = acceleration(zbv_time)

    # Render every schedule against the same time extent so the images compare.
    max_time = max(t for t in (baseline_time, zb_time, zbv_time) if t is not None)
    if baseline_result is not None:
        baseline_image = get_schedule_image(baseline_result, max_time)
    if zb_result is not None:
        zb_image = get_schedule_image(zb_result, max_time)
    if zbv_result is not None:
        zbv_image = get_schedule_image(zbv_result, max_time)
    return [
        baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
        zb_time, zb_bubble, zb_acceleration, zb_image,
        zbv_time, zbv_bubble, zbv_acceleration, zbv_image,
    ]
with gr.Blocks() as demo:
    gr.Markdown(pathlib.Path("description1.md").read_text(encoding="utf-8"))
    gr.Markdown("# Pipeline Scheduler Playground")
    # Each preset: (p, m, F, B, W, P2P comm cost, memory-limit radio choice).
    presets = {
        'Ideal Case 1p': (4, 12, 20, 20, 20, 0, '1p (Same as 1F1B)'),
        'Ideal Case 2p': (4, 12, 20, 20, 20, 0, '2p'),
        'Real Case 1p': (4, 12, 1049, 1122, 903, 79, '1p (Same as 1F1B)'),
        'Real Case 2p': (4, 12, 1049, 1122, 903, 79, '2p'),
    }
    preset_buttons = {}
    with gr.Group():
        gr.Markdown("Preset Setups")
        with gr.Row():
            for k in presets:
                preset_buttons[k] = gr.Button(k, variant="secondary")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("Basic Parameters")
                with gr.Row():
                    p = gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
                    m = gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
                with gr.Row():
                    f = gr.Number(label="Time of F", value=100, interactive=True, precision=0)
                    b = gr.Number(label="Time of B", value=110, interactive=True, precision=0)
                    w = gr.Number(label="Time of W", value=90, interactive=True, precision=0)
                    c = gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
            with gr.Group():
                gr.Markdown("Activation memory limit.")

                def update_mem(p, s, mem):
                    """Turn a memory-limit radio choice into a number of pending F passes.

                    "custom" keeps ``mem`` unchanged; otherwise the factor before
                    'p' (e.g. "1.5" in "1.5p") is multiplied by ``p`` and rounded
                    to the nearest integer (halves round up for positive values).
                    """
                    if s == "custom":
                        return mem
                    return int(p * float(s.split('p')[0]) + 0.5)

                memsel = gr.Radio(choices=["1p (Same as 1F1B)", "1.5p", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
                mem = gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
                memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
                p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
    button = gr.Button("Calculate", variant="primary")
    with gr.Group():
        gr.Markdown("1F1B")
        with gr.Row():
            with gr.Column(scale=1):
                baseline_time = gr.Textbox("", label="Longest Stage Time")
                baseline_bubble = gr.Textbox("", label="Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1).")
                baseline_acceleration = gr.Textbox("", label="Acceleration compared to 1F1B")
            with gr.Column(scale=4):
                baseline_image = gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
    with gr.Group():
        gr.Markdown("ZB Schedule")
        with gr.Row():
            with gr.Column(scale=1):
                zb_time = gr.Textbox("", label="Longest Stage Time")
                zb_bubble = gr.Textbox("", label="Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1).")
                zb_acceleration = gr.Textbox("", label="Acceleration compared to 1F1B")
            with gr.Column(scale=4):
                zb_image = gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
    with gr.Group():
        gr.Markdown("ZBV Schedule")
        with gr.Row():
            with gr.Column(scale=1):
                zbv_time = gr.Textbox("", label="Longest Stage Time")
                zbv_bubble = gr.Textbox("", label="Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1).")
                zbv_acceleration = gr.Textbox("", label="Acceleration compared to 1F1B")
            with gr.Column(scale=4):
                zbv_image = gr.Image(None, interactive=False, label="Schedule Image", show_label=False)

    # Components receiving calculate()'s 12 outputs, in its return order.
    result_outputs = [
        baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
        zb_time, zb_bubble, zb_acceleration, zb_image,
        zbv_time, zbv_bubble, zbv_acceleration, zbv_image,
    ]
    button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=result_outputs)

    def update_preset(pb, p, m, f, b, w, c, mem):
        """Load preset ``pb`` into the inputs and compute its schedules.

        Returns the preset's 7 values (p, m, F, B, W, comm, memory choice)
        followed by the 12 outputs of ``calculate``.
        """
        preset = presets[pb]
        # Derive the memory limit from the preset's own p, not the p currently
        # shown in the UI; otherwise the results mismatch the loaded preset.
        preset_mem = update_mem(preset[0], preset[-1], -1)
        return (*preset, *calculate(*preset[:-1], preset_mem))

    for k in presets:
        preset_buttons[k].click(
            update_preset,
            inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
            outputs=[p, m, f, b, w, c, memsel, *result_outputs],
        )
    gr.Markdown(pathlib.Path("description2.md").read_text(encoding="utf-8"))
demo.launch()