Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import gradio as gr | |
# Background colours for diff highlighting: a dark red for removed lines and
# a dark green for added lines (text on top of them is rendered white).
_remove_color = "rgb(103,6,12)"
_addition_color = "rgb(6,103,12)"


def mark_text(text, add=True):
    """Wrap *text* in an inline-styled HTML ``<mark>`` element.

    Args:
        text: Content to place inside the element (formatted with ``str``
            semantics via the f-string; it is NOT HTML-escaped).
        add: Truthy selects the "added line" background colour, falsy the
            "removed line" colour.

    Returns:
        The ``<mark ...>text</mark>`` HTML snippet. Both declarations carry
        ``!important`` so they win over conflicting stylesheet rules.
    """
    background = _addition_color if add else _remove_color
    style = f"background-color:{background}!important;color:white!important"
    return f'<mark style="{style}">{text}</mark>'
def highlight(option):
    """Render the code sample for *option* with diff-style markers.

    The sample is read from ``code_samples/<name>``, where ``<name>`` is
    *option* lower-cased with spaces replaced by underscores (for example
    ``"Gradient Accumulation"`` -> ``code_samples/gradient_accumulation``).

    Each line of the sample is rewritten as follows:

    * a line starting with ``-`` becomes ``"- " + rest`` wrapped by
      ``mark_text(..., False)`` (removed-line colour);
    * a line starting with ``+`` becomes ``"+ " + rest`` wrapped by
      ``mark_text(..., True)`` (added-line colour);
    * any other line is prefixed with a single space.

    Trailing whitespace is stripped from the joined result.

    Args:
        option: Feature label selected in the UI.

    Returns:
        The rendered lines joined with newlines.

    Raises:
        ValueError: If *option* maps to a name that could point outside
            ``code_samples/`` (empty, ``.``, ``..``, or containing a path
            separator).
        OSError: If the sample file cannot be opened or read.
    """
    filename = option.lower().replace(" ", "_")
    # The value arrives from the web client, so it is not guaranteed to be one
    # of the radio choices; refuse anything that could escape code_samples/.
    if (
        not filename
        or filename in (".", "..")
        or "/" in filename
        or "\\" in filename
    ):
        raise ValueError(f"Unknown code sample: {option!r}")
    with open(f"code_samples/{filename}", encoding="utf-8") as f:
        output = f.read()
    rendered = [_render_diff_line(line) for line in output.split("\n")]
    return "\n".join(rendered).rstrip()


def _render_diff_line(line):
    """Convert one raw sample line into its prefixed (and maybe highlighted) form."""
    if line.startswith("-"):
        return mark_text("- " + line[1:], False)
    if line.startswith("+"):
        return mark_text("+ " + line[1:], True)
    # Context line. NOTE(review): a single-space prefix keeps alignment only if
    # sample files use unified-diff convention (context lines start with a
    # space) — confirm against code_samples/.
    return " " + line
def _read_sample(name):
    """Return the full text of ``code_samples/<name>``, decoded as UTF-8.

    The encoding is explicit so the result does not depend on the host's
    locale-default encoding.
    """
    with open(f"code_samples/{name}", encoding="utf-8") as f:
        return f.read()


# Starting-point code shown in the "Initial Code" column. These are read once
# at import time and returned unchanged by the UI callback:
#   template             -> "Basic"
#   metrics_template     -> "Calculating Metrics"
#   accelerated_template -> every other feature choice
template = _read_sample("initial")
accelerated_template = _read_sample("accelerate")
metrics_template = _read_sample("initial_with_metrics")
def change(inp):
    """Compute the three UI values to display for the selected feature *inp*.

    Returns a tuple ``(initial_code, highlighted_code, heading)``:

    * ``initial_code``: ``template`` for ``"Basic"``, ``metrics_template``
      for ``"Calculating Metrics"``, and ``accelerated_template`` otherwise;
    * ``highlighted_code``: ``highlight(inp)``;
    * ``heading``: a Markdown heading naming the feature. ``"Basic"`` is
      shown as "Base Integration".
    """
    if inp == "Basic":
        return template, highlight(inp), "## Accelerate Code (Base Integration)"
    if inp == "Calculating Metrics":
        starting_code = metrics_template
    else:
        starting_code = accelerated_template
    return starting_code, highlight(inp), f"## Accelerate Code ({inp})"
# Page layout: an intro, a feature picker, then a row with two columns. The
# first column shows the starting code; the second shows a heading plus the
# highlighted Accelerate version. Selecting a feature calls `change`, which
# refills the starting code, the highlighted code and the heading.
with gr.Blocks() as demo:
    gr.Markdown(
        value='''# Accelerate Template Generator
Here is a very basic Python training loop.
Select how you would like to introduce an Accelerate capability to add to it.'''
    )
    inp = gr.Radio(
        choices=[
            "Basic",
            "Calculating Metrics",
            "Checkpointing",
            "Gradient Accumulation",
        ],
        label="Select a feature",
    )
    with gr.Row():
        # First column: the starting training loop.
        with gr.Column():
            gr.Markdown(value="## Initial Code")
            code = gr.Markdown(value=template)
        # Second column: heading and highlighted code, both set by `change`.
        with gr.Column():
            feature = gr.Markdown(value="## Accelerate Code")
            out = gr.Markdown()
    # The outputs order must match the tuple `change` returns:
    # (starting code, highlighted code, heading).
    inp.change(fn=change, inputs=inp, outputs=[code, out, feature])

demo.launch()