import gradio as gr
import auto_schedule
import v_schedule
import hand_schedule
from PIL import Image
from svg_event import render_manual_graph
import pathlib
from collections import deque


def greet(name, is_morning, temperature):
    """Return a greeting for *name* and *temperature* (°F) converted to °C.

    Not referenced by the UI in this file; kept so existing importers still work.
    """
    salutation = "Good morning" if is_morning else "Good evening"
    greeting = f"{salutation} {name}. It is {temperature} degrees today"
    celsius = (temperature - 32) * 5 / 9
    return greeting, round(celsius, 2)


def percentage(x):
    """Format the ratio *x* as a percentage string with two decimals (0.1234 -> '12.34%')."""
    return f"{x*100:.2f}%"


# Node types that count toward schedule timing and rendering; any other node
# types produced by the schedulers are dropped before measuring/drawing.
_COMPUTE_TYPES = frozenset({'F', 'B', 'W'})


def _compute_nodes(result):
    """Return *result* (a list of per-stage node lists) keeping only F/B/W nodes."""
    return [[x for x in stage if x.type in _COMPUTE_TYPES] for stage in result]


def get_schedule_time(result):
    """Return the longest per-stage span of F/B/W work.

    For each stage the span is (latest ``completion_time`` - earliest
    ``start_time``) over its F/B/W nodes; the maximum over stages is returned.
    Raises ValueError if some stage has no F/B/W node.
    """
    return max(
        max(x.completion_time for x in stage) - min(x.start_time for x in stage)
        for stage in _compute_nodes(result)
    )


# Most recently rendered schedule images; once more than _MAX_CACHED_IMAGES
# exist, the oldest files are deleted so the disk does not fill up.
_MAX_CACHED_IMAGES = 32
img_queue = deque()


def get_schedule_image(result, max_time):
    """Render the F/B/W nodes of *result* and return the rendered file as a Path.

    *max_time* is forwarded to ``render_manual_graph`` so all schedules can be
    drawn on a common time scale.
    """
    result = _compute_nodes(result)
    # NOTE(review): third argument presumably toggles a detailed rendering for
    # small schedules (<= 72 nodes on stage 0) — confirm against svg_event.
    svg = render_manual_graph(result, max_time, len(result[0]) <= 72)
    img_queue.append(svg)
    while len(img_queue) > _MAX_CACHED_IMAGES:
        # missing_ok: a file already removed externally must not break rendering.
        pathlib.Path(img_queue.popleft()).unlink(missing_ok=True)
    return pathlib.Path(svg)


def _bubble_rate(time, f, b, w, m):
    """Return the bubble rate ``time / ((f + b + w) * m) - 1`` as a percentage string."""
    return percentage(time / (f + b + w) / m - 1)


def _acceleration(baseline_time, time):
    """Return the speed-up of *time* over *baseline_time*, or None without a baseline."""
    if baseline_time is None:
        return None
    return percentage(baseline_time / time - 1)


def calculate(p, m, f, b, w, c, mem):
    """Build the 1F1B, ZB and ZBV schedules and return their statistics and images.

    Args:
        p: number of pipeline stages.
        m: number of microbatches.
        f, b, w: costs of the F, B and W passes.
        c: cost of one P2P communication.
        mem: activation memory limit, in pending F passes per stage.

    Returns:
        A 12-element list: [time, bubble, acceleration, image] for 1F1B, then
        ZB, then ZBV. The 1F1B and ZBV entries are all None when ``mem < p``.
    """
    baseline_time = baseline_bubble = baseline_acceleration = None
    baseline_image = baseline_result = None
    if mem >= p:
        # The baseline is scheduled with B and W merged (B cost b + w, W cost 0),
        # so only F and B nodes are kept from it.
        baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
        baseline_result = [
            [x for x in stage if x.type in {'F', 'B'}] for stage in baseline_result
        ]
        baseline_time = get_schedule_time(baseline_result)
        baseline_bubble = _bubble_rate(baseline_time, f, b, w, m)
        baseline_acceleration = percentage(0)

    zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
        cost_f=f,
        cost_b=b,
        cost_w=w,
        cost_comm=c,
        max_mem=mem * 2,
        print_scaling=1000,
    ))
    zb_time = get_schedule_time(zb_result)
    zb_bubble = _bubble_rate(zb_time, f, b, w, m)
    zb_acceleration = _acceleration(baseline_time, zb_time)
    # Initialised so the return below cannot hit an unbound local.
    zb_image = None

    zbv_time = zbv_bubble = zbv_acceleration = None
    zbv_image = zbv_result = None
    if mem >= p:
        # User costs describe two virtual stages combined, so each virtual
        # stage gets half of every compute cost.
        zbv_graph = v_schedule.PipelineGraph(
            n_stage=p,
            n_micro=m,
            f_cost=f / 2,
            b_cost=b / 2,
            w_cost=w / 2,
            c_cost=c,
            f_mem=2,
            b_mem=-1,
            w_mem=-1,
            max_mem=mem * 4,
        )
        zbv_result = zbv_graph.get_v_schedule()
        zbv_time = get_schedule_time(zbv_result)
        zbv_bubble = _bubble_rate(zbv_time, f, b, w, m)
        zbv_acceleration = _acceleration(baseline_time, zbv_time)

    # Shared horizontal scale for all rendered images.
    max_time = max(t for t in (baseline_time, zb_time, zbv_time) if t is not None)

    if baseline_result is not None:
        baseline_image = get_schedule_image(baseline_result, max_time)
    if zb_result is not None:
        zb_image = get_schedule_image(zb_result, max_time)
    if zbv_result is not None:
        zbv_image = get_schedule_image(zbv_result, max_time)

    return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
            zb_time, zb_bubble, zb_acceleration, zb_image,
            zbv_time, zbv_bubble, zbv_acceleration, zbv_image]


def update_mem(p, s, mem):
    """Return the memory limit for selector choice *s*.

    For "custom" the current *mem* is returned unchanged; otherwise the
    multiplier before the first 'p' (e.g. "1.5p" -> 1.5) times *p*, rounded
    half-up to an int.
    """
    if s == "custom":
        return mem
    return int(p * float(s.split('p')[0]) + 0.5)


def _read_text(path):
    """Return the UTF-8 contents of the file at *path*, closing it afterwards."""
    return pathlib.Path(path).read_text(encoding="utf-8")


# Label shared by the three bubble-rate boxes; matches _bubble_rate.
_BUBBLE_LABEL = "Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1)."


def _result_group(title):
    """Create one schedule's result panel inside the active Blocks context.

    Returns the (time, bubble, acceleration, image) components.
    """
    with gr.Group():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column(scale=1):
                time_box = gr.Textbox("", label="Longest Stage Time")
                bubble_box = gr.Textbox("", label=_BUBBLE_LABEL)
                accel_box = gr.Textbox("", label="Acceleration compared to 1F1B")
            with gr.Column(scale=4):
                image = gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
    return time_box, bubble_box, accel_box, image


with gr.Blocks() as demo:
    gr.Markdown(_read_text("description1.md"))
    gr.Markdown("# Pipeline Scheduler Playground")
    # name -> (p, m, F, B, W, comm cost, memory selector choice)
    presets = {
        'Ideal Case 1p': (4, 12, 20, 20, 20, 0, '1p (Same as 1F1B)'),
        'Ideal Case 2p': (4, 12, 20, 20, 20, 0, '2p'),
        'Real Case 1p': (4, 12, 1049, 1122, 903, 79, '1p (Same as 1F1B)'),
        'Real Case 2p': (4, 12, 1049, 1122, 903, 79, '2p'),
    }
    preset_buttons = {}
    with gr.Group():
        gr.Markdown("Preset Setups")
        with gr.Row():
            for k in presets:
                preset_buttons[k] = gr.Button(k, variant="secondary")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("Basic Parameters")
                with gr.Row():
                    p = gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
                    m = gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
                with gr.Row():
                    f = gr.Number(label="Time of F", value=100, interactive=True, precision=0)
                    b = gr.Number(label="Time of B", value=110, interactive=True, precision=0)
                    w = gr.Number(label="Time of W", value=90, interactive=True, precision=0)
                    c = gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)

    with gr.Group():
        gr.Markdown("Activation memory limit.")
        memsel = gr.Radio(choices=["1p (Same as 1F1B)", "1.5p", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
        mem = gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
        # Keep the numeric limit in sync with the selector and the stage count.
        memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
        p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)

    button = gr.Button("Calculate", variant="primary")

    baseline_time, baseline_bubble, baseline_acceleration, baseline_image = _result_group("1F1B")
    zb_time, zb_bubble, zb_acceleration, zb_image = _result_group("ZB Schedule")
    zbv_time, zbv_bubble, zbv_acceleration, zbv_image = _result_group("ZBV Schedule")

    # Output order must match the list returned by calculate().
    result_outputs = [baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
                      zb_time, zb_bubble, zb_acceleration, zb_image,
                      zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
    button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=result_outputs)

    def update_preset(pb, p, m, f, b, w, c, mem):
        """Fill the inputs from preset *pb* (the clicked button's label) and compute it.

        The current UI values are ignored: the memory limit is derived from the
        preset's own stage count, not from whatever is in the "p" box.
        """
        preset = presets[pb]
        preset_mem = update_mem(preset[0], preset[-1], -1)
        return (*preset, *calculate(*preset[:-1], preset_mem))

    for k in presets:
        preset_buttons[k].click(
            update_preset,
            inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
            outputs=[p, m, f, b, w, c, memsel] + result_outputs,
        )

    gr.Markdown(_read_text("description2.md"))

demo.launch()