# Author: muellerzr (HF staff)
# Commit: 06a60a3 ("Refactor"), file size 5.89 kB
from contextlib import contextmanager
import gradio as gr
from markup import get_text, highlight
from template import get_templates
# NOTE(review): `templates` is not referenced anywhere else in this file as
# visible here; confirm whether this module-level get_templates() call is
# still needed.
templates = get_templates()
def fill_tab(title, explanation):
    """
    Render a section heading followed by its body text.

    Args:
        title (`str`):
            Markdown for the section heading.
        explanation (`str`):
            Markdown for the section body.

    Returns:
        `tuple`: The heading `gr.Markdown` component and the body
        `gr.Markdown` component, in that order.
    """
    heading = gr.Markdown(title)
    body = gr.Markdown(explanation)
    return heading, body
@contextmanager
def new_section():
    """
    Open a full-width section of the layout.

    Components created inside a `with new_section():` body are placed in a
    single `gr.Column` nested inside a `gr.Row`. This is shorthand for:
    ```python
    with gr.Row():
        with gr.Column():
            ...
    ```
    """
    # A multi-item `with` enters the Row before evaluating gr.Column(), so
    # the Column is created inside the Row, exactly like the nested form.
    with gr.Row(), gr.Column():
        yield
def change(inp, textbox):
    """Based on an `inp`, render and highlight the appropriate code sample.

    Args:
        inp (`str`):
            The option currently selected in the tab's radio selector.
        textbox (`str`):
            The tab name held by the hidden textbox in the interface; either
            `"base"` or `"training_configuration"`.

    Returns:
        `tuple`: For `"base"`, a 4-tuple of (highlighted code, section title,
        explanation, docs). For `"training_configuration"`, a 5-tuple of
        (highlighted YAML, highlighted training-script changes, command,
        explanation, docs).

    Raises:
        `ValueError`: If `textbox` is not a known tab name.
    """
    if textbox == "base":
        code, explanation, docs = get_text(inp, textbox)
        # "Basic" gets a descriptive label; every other feature is titled
        # with its own name.
        label = "Base Integration" if inp == "Basic" else inp
        return (highlight(code), f"## Accelerate Code ({label})", explanation, docs)
    if textbox == "training_configuration":
        yaml, changes, command, explanation, docs = get_text(inp, textbox)
        return (highlight(yaml), highlight(changes), command, explanation, docs)
    raise ValueError(f"Invalid tab name: {textbox}")
# Initial content for each tab, computed once for the radios' default values.
default_base = change(inp="Basic", textbox="base")
default_training_config = change(inp="Multi GPU", textbox="training_configuration")
def base_features(textbox):
    """
    Build the "Basic Training Integration" tab contents.

    Args:
        textbox (`gr.Textbox`):
            Hidden textbox holding this tab's name; passed to `change` as its
            second input so it knows which tab fired the event.
    """
    inp = gr.Radio(
        [
            "Basic",
            "Calculating Metrics",
            "Checkpointing",
            "Experiment Tracking",
            "Gradient Accumulation",
        ],
        label="Select a feature you would like to integrate",
        value="Basic",
    )
    with new_section():
        # Use the title `change` produces for the default selection so the
        # heading shown on first load matches the one shown when "Basic" is
        # selected again later.
        feature, out = fill_tab(default_base[1], default_base[0])
    with new_section():
        _, explanation = fill_tab("## Explanation", default_base[2])
    with new_section():
        _, docs = fill_tab("## Documentation Links", default_base[3])
    # `change` returns (code, title, explanation, docs) for the "base" tab,
    # so the outputs are listed in that order.
    inp.change(
        fn=change, inputs=[inp, textbox], outputs=[out, feature, explanation, docs]
    )
def training_config(textbox):
    """
    Build the "Launch Configuration" tab contents.

    Args:
        textbox (`gr.Textbox`):
            Hidden textbox holding this tab's name; passed to `change` as its
            second input so it knows which tab fired the event.
    """
    configurations = [
        "AWS SageMaker",
        "DeepSpeed",
        "Megatron-LM",
        "Multi GPU",
        "Multi Node Multi GPU",
        "PyTorch FSDP",
    ]
    inp = gr.Radio(
        configurations,
        label="Select a distributed YAML configuration you would like to view.",
        value="Multi GPU",
    )
    # One heading per element of the 5-tuple `change` returns for this tab,
    # in the same order: YAML, script changes, command, explanation, docs.
    headings = (
        "## Example YAML Configuration",
        "## Changes to Training Script",
        "## Command to Run Training",
        "## Explanation",
        "## Documentation Links",
    )
    panels = []
    for heading, initial in zip(headings, default_training_config):
        with new_section():
            _, panel = fill_tab(heading, initial)
        panels.append(panel)
    yaml, changes, command, explanation, docs = panels
    inp.change(
        fn=change,
        inputs=[inp, textbox],
        outputs=[yaml, changes, command, explanation, docs],
    )
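# NOTE(review): the two drafts below are not runnable as written:
# - `default` is undefined (compare `default_base` / `default_training_config`).
# - `value="Basic"` is not among either radio's choices.
# - `inputs` must list Gradio components (e.g. a hidden gr.Textbox), not a
#   plain string.
# - `change` has no branch for the "big_model_inference" or
#   "notebook_launcher" tab names and would raise ValueError.
# def big_model_inference():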
# inp = gr.Radio(
# ["Accelerate's Big Model Inference",], # "DeepSpeed ZeRO Stage-3 Offload"
# label="Select a feature you would like to integrate",
# value="Basic",
# )
# with gr.Row():
# with gr.Column():
# feature = gr.Markdown("## Accelerate Code")
# out = gr.Markdown(default[0])
# with gr.Row():
# with gr.Column():
# gr.Markdown(default[1])
# explanation = gr.Markdown(default[2])
# with gr.Row():
# with gr.Column():
# gr.Markdown("## Documentation Links")
# docs = gr.Markdown(default[3])
# inp.change(fn=change, inputs=[inp, "big_model_inference"], outputs=[out, feature, explanation, docs])
# def notebook_launcher():
# inp = gr.Radio(
# ["Colab GPU", "Colab TPU", "Kaggle GPU", "Kaggle Multi GPU", "Kaggle TPU", "Multi GPU VMs"],
# label="Select a feature you would like to integrate",
# value="Basic",
# )
# with gr.Row():
# with gr.Column():
# feature = gr.Markdown("## Accelerate Code")
# out = gr.Markdown(default[0])
# with gr.Row():
# with gr.Column():
# gr.Markdown(default[1])
# explanation = gr.Markdown(default[2])
# with gr.Row():
# with gr.Column():
# gr.Markdown("## Documentation Links")
# docs = gr.Markdown(default[3])
# inp.change(fn=change, inputs=[inp, "notebook_launcher"], outputs=[out, feature, explanation, docs])
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Basic Training Integration"):
            # Hidden textbox whose value tells `change` which tab fired the event.
            textbox = gr.Textbox(label="tab_name", visible=False, value="base")
            base_features(textbox)
        with gr.TabItem("Launch Configuration"):
            textbox = gr.Textbox(
                label="tab_name", visible=False, value="training_configuration"
            )
            training_config(textbox)
        with gr.TabItem("Big Model Inference"):
            # Placeholder tab: its builder is not enabled yet.
            # big_model_inference()
            pass
        with gr.TabItem("Launching from Notebooks"):
            # Placeholder tab: its builder is not enabled yet.
            # notebook_launcher()
            pass

# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module builds `demo` without blocking on a launched server.
if __name__ == "__main__":
    demo.launch()