Spaces:
Paused
Paused
"""
Gradio app for creating Medusa heads for a model from the Hugging Face Hub.
"""
import os
import subprocess
import sys

import gradio as gr

from src.calibration_datasets import CalibrationDataset
from src.train_workflow import run, DEFAULT_TRAINING_ARGS
# TODO: install FA2 (flash-attn) in a better way, e.g. a Docker image.
# Install through the *running* interpreter (`sys.executable -m pip`) so the
# package lands in the same environment the app imports from, and without
# going through a shell. This stays best-effort like the original
# `os.system` call: a failed install is reported but does not stop the app.
_fa2_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    check=False,
)
if _fa2_install.returncode != 0:
    print(
        "WARNING: installing flash-attn failed "
        f"(exit code {_fa2_install.returncode}); continuing anyway."
    )
# Markdown text rendered under the page title, describing how to use the app.
# The closing emoji was previously mis-encoded as "π₯" (UTF-8 bytes of 🔥
# decoded with the wrong codec); it is restored here.
DESCRIPTION = """
The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
1. Input a public model id from the Hub
2. Select a dataset to train the medusa heads on. The dataset should be representative of the downstream use case.
3. Click "Submit"
4. That's it! You'll get feedback if it works or not, and if it worked, you'll get the name of the new repo 🔥
"""
# NOTE(review): the trailing "π" looks like a mis-encoded emoji; confirm the
# intended character before changing it.
title = "Create LLM medusa heads in a new repo π"

# Build the UI: a left column with the inputs and buttons, a right column
# where `run` reports its status as Markdown.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"""# {title}""")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_id = gr.Text(max_lines=1, label="model_id")
            # One choice per *direct* subclass of CalibrationDataset
            # (`__subclasses__` does not recurse), named by its `dataset`
            # class attribute.
            dataset_names = [
                cls.dataset for cls in CalibrationDataset.__subclasses__()
            ]
            dataset = gr.Dropdown(dataset_names, label="dataset")
            with gr.Accordion("Training arguments (advanced)", open=False):
                training_args = gr.Textbox(
                    DEFAULT_TRAINING_ARGS,
                    interactive=True,
                    lines=20,
                    label="training_args",
                )
            with gr.Row():
                # A ClearButton with no components clears nothing, so wire it
                # to the user inputs. `training_args` is left out on purpose
                # so clearing does not wipe the default training arguments.
                clean = gr.ClearButton([model_id, dataset])
                submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            status_box = gr.Markdown()

    # `status_box` is created after the button, so register it afterwards.
    clean.add(status_box)

    # Only one training job runs at a time.
    submit.click(
        run,
        inputs=[model_id, training_args, dataset],
        outputs=status_box,
        concurrency_limit=1,
    )
# Enable the request queue (holding at most 10 pending events), then start
# the server with the API page exposed.
demo.queue(max_size=10)
demo.launch(show_api=True)