medusa-maker / app.py
Narsil's picture
Narsil HF staff
Update app.py
359575a
raw
history blame
4.08 kB
import multiprocessing as mp
import os

import gradio as gr
import torch
import torch.distributed.run as distributed_run
from git import Repo
from huggingface_hub import HfApi
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
def _clone_if_missing(url: str, dest: str) -> None:
    """Clone the git repository at *url* into *dest* unless *dest* is already populated.

    Cloning into an existing, non-empty directory makes git (and therefore
    ``Repo.clone_from``) fail, which would crash the app at import time whenever
    the working directory survives a restart or this module is imported again.
    """
    if os.path.isdir(dest) and os.listdir(dest):
        print(f"'{dest}' already exists, skipping clone")
        return
    Repo.clone_from(url, dest)


# Clone the medusa repo locally
print("Cloning the medusa repo locally...")
_clone_if_missing("https://github.com/FasterDecoding/Medusa.git", "medusa")
# Clone the training data next to it
print("Cloning the Vicuna data locally...")
_clone_if_missing("https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered", "data")
print("Done")
def create_medusa_heads(model_id: str):
    """Train medusa heads for ``model_id`` and push the result to the Hub.

    The training script ``medusa/medusa/train/train.py`` is started through the
    ``torch.distributed.run`` launcher with two processes per node and writes
    into the local ``medusa_heads`` folder. That folder is then uploaded to the
    ``{model_id}-medusa`` repo, which is created when it does not exist yet.

    Args:
        model_id: Value forwarded to the script as ``--model_name_or_path``.

    Returns:
        The id of the repo the heads were uploaded to.
    """
    # Options forwarded to the training script; dict order is the argv order.
    training_options = {
        "--model_name_or_path": model_id,
        "--data_path": "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
        "--bf16": "True",
        "--output_dir": "medusa_heads",
        "--num_train_epochs": "1",
        "--per_device_train_batch_size": "4",
        "--per_device_eval_batch_size": "4",
        "--gradient_accumulation_steps": "4",
        "--evaluation_strategy": "no",
        "--save_strategy": "no",
        "--learning_rate": "1e-3",
        "--weight_decay": "0.0",
        "--warmup_ratio": "0.1",
        "--lr_scheduler_type": "cosine",
        "--logging_steps": "1",
        "--tf32": "True",
        "--model_max_length": "2048",
        "--lazy_preprocess": "True",
        "--medusa_num_heads": "3",
        "--medusa_num_layers": "1",
    }

    # Launcher options come first, then the script path, then the script's own flags.
    launcher_argv = ["--nproc_per_node", "2", "medusa/medusa/train/train.py"]
    for flag, value in training_options.items():
        launcher_argv.extend((flag, value))

    launcher_args = distributed_run.get_args_parser().parse_args(launcher_argv)
    distributed_run.run(launcher_args)

    # Upload the medusa heads to the Hub
    target_repo = f"{model_id}-medusa"
    hub = HfApi()
    hub.create_repo(repo_id=target_repo, exist_ok=True)
    hub.upload_folder(folder_path="medusa_heads", repo_id=target_repo)
    return target_repo
def run(model_id: str) -> str:
    """Validate ``model_id``, train medusa heads for it and report the outcome.

    The heavy lifting happens in a separate process running
    ``create_medusa_heads``. Exceptions raised there never reach this process,
    so the child's exit code is inspected to tell success from failure.

    Args:
        model_id: Hub id of the base model, as typed by the user.

    Returns:
        A Markdown snippet: an "Invalid input" notice, a load or training error,
        or a success message naming the repo the heads were uploaded to.
    """
    print(f"\n\n\nNEW RUN: {model_id}")

    # Input validation: a missing or whitespace-only id counts as empty.
    model_id = (model_id or "").strip()
    if not model_id:
        return "\n### Invalid input 🐞\nPlease fill a model_id.\n"
    print(f"Valid inputs ✅\nValidating model_id: {model_id}")

    # Attempt to load the base model; the objects are only needed as a check.
    try:
        config = AutoConfig.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        del config, tokenizer, model
    except Exception as e:
        return f"\n### {model_id} can't be loaded with AutoClasses 🐞\n{e}\n"
    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")

    # Run the medusa heads creation
    try:
        proc = mp.Process(target=create_medusa_heads, args=(model_id,))
        proc.start()
        proc.join()
        # A crash in the child only shows up as a non-zero exit code here.
        if proc.exitcode != 0:
            raise RuntimeError(
                f"Medusa heads creation process exited with code {proc.exitcode}"
            )
    except Exception as e:
        print("Error ❌\n", e)
        return f"\n### Error 😒😒😒\n{e}\n"

    # Must match the repo id built in create_medusa_heads.
    repo_id = f"{model_id}-medusa"
    print("Success ✅\nMedusa heads uploaded to: ", repo_id)
    return (
        "\n### Success 🔥\n"
        f"Yay! Medusa heads were successfully created and uploaded to, {repo_id}\n"
    )
# Markdown instructions rendered below the page heading.
DESCRIPTION = """
The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
1. Input a public model id from the Hub
2. Click "Submit"
3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the new repo 🔥
"""

# Used both as the Blocks (page) title and as the visible heading.
title = "Create LLM medusa heads in a new repo 🐍"
# UI: heading and instructions on top, then the input form next to the status area.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"""# {title}""")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_id = gr.Text(max_lines=1, label="model_id")
            with gr.Row():
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            status_box = gr.Markdown()

    # A ClearButton created without components resets nothing; register the
    # input and the status area now that both exist.
    clean.add([model_id, status_box])

    # concurrency_limit=1: only one heads-creation run is processed at a time.
    submit.click(run, inputs=[model_id], outputs=status_box, concurrency_limit=1)

# Keep at most 10 pending requests in the queue while a run is in progress.
demo.queue(max_size=10).launch(show_api=True)