Spaces:
Paused
Paused
File size: 4,077 Bytes
632990b b6a509f 7bea5b8 b6a509f 9471796 d56736b 9471796 8bece88 f5b4b7c 6b71df1 256858e 9471796 4f6680e 9471796 d3e5ded 9471796 f5b4b7c 8bece88 9471796 359575a 9471796 7bea5b8 359575a 9471796 d56736b 9471796 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import multiprocessing as mp
import os

import gradio as gr
import torch
import torch.distributed.run as distributed_run
from git import Repo
from huggingface_hub import HfApi
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
def _clone_if_missing(url: str, dest: str) -> None:
    """Clone the git repository at *url* into *dest* unless *dest* is already populated.

    ``Repo.clone_from`` fails when *dest* exists and is not empty (git refuses
    to clone into a non-empty directory). Without this guard, any restart
    that reuses the working directory would crash at import time.

    Args:
        url: Remote git URL to clone.
        dest: Local directory to clone into.
    """
    if os.path.isdir(dest) and os.listdir(dest):
        print(f"{dest!r} already exists and is not empty, skipping clone")
        return
    Repo.clone_from(url, dest)


# Clone the medusa repo locally (training script lives in medusa/medusa/train/).
print("Cloning the medusa repo locally...")
_clone_if_missing("https://github.com/FasterDecoding/Medusa.git", "medusa")
# NOTE(review): the training JSON in this dataset repo is presumably stored
# with git-lfs; without git-lfs installed only a pointer file is cloned — confirm.
print("Cloning the Vicuna data locally...")
_clone_if_missing("https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered", "data")
print("Done")
def create_medusa_heads(model_id: str):
    """Train Medusa heads for *model_id* and upload them to the Hub.

    Runs ``medusa/medusa/train/train.py`` through ``torch.distributed.run``
    (the module behind ``torchrun``) with 2 processes on this node, using the
    cloned ShareGPT data. The heads are written to the local ``medusa_heads``
    directory, which is then uploaded to the Hub repo ``"<model_id>-medusa"``.
    That repo is created first if it does not exist (``exist_ok=True``).

    NOTE(review): ``HfApi()`` is built without an explicit token, so it
    presumably picks up the Space's HF token from the environment — confirm.

    Args:
        model_id: Hub id of the base causal LM to attach the heads to.

    Returns:
        The Hub repo id the heads were uploaded to.
    """
    # Options consumed by the launcher itself; they must come before the script.
    launcher_opts = ["--nproc_per_node", "2"]
    train_script = "medusa/medusa/train/train.py"
    # Options forwarded verbatim to train.py, in this exact order.
    train_opts = [
        ("--model_name_or_path", model_id),
        ("--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json"),
        ("--bf16", "True"),
        ("--output_dir", "medusa_heads"),
        ("--num_train_epochs", "1"),
        ("--per_device_train_batch_size", "4"),
        ("--per_device_eval_batch_size", "4"),
        ("--gradient_accumulation_steps", "4"),
        ("--evaluation_strategy", "no"),
        ("--save_strategy", "no"),
        ("--learning_rate", "1e-3"),
        ("--weight_decay", "0.0"),
        ("--warmup_ratio", "0.1"),
        ("--lr_scheduler_type", "cosine"),
        ("--logging_steps", "1"),
        ("--tf32", "True"),
        ("--model_max_length", "2048"),
        ("--lazy_preprocess", "True"),
        ("--medusa_num_heads", "3"),
        ("--medusa_num_layers", "1"),
    ]
    argv = launcher_opts + [train_script] + [tok for pair in train_opts for tok in pair]

    launcher_parser = distributed_run.get_args_parser()
    distributed_run.run(launcher_parser.parse_args(argv))

    # Upload the medusa heads to the Hub
    target_repo = f"{model_id}-medusa"
    hub = HfApi()
    hub.create_repo(repo_id=target_repo, exist_ok=True)
    hub.upload_folder(folder_path="medusa_heads", repo_id=target_repo)
    return target_repo
def run(model_id: str) -> str:
    """Validate *model_id*, train Medusa heads for it and report the outcome.

    Used as the Gradio "Submit" handler; the returned string is rendered as
    Markdown. Failures are reported in the returned message.

    Steps:
      1. Reject empty or whitespace-only input.
      2. Check that the model loads with the transformers AutoClasses.
      3. Run ``create_medusa_heads`` in a child process and wait for it.

    Args:
        model_id: Hub id of the base model. Surrounding whitespace is ignored.

    Returns:
        A Markdown status message: invalid input, load error, training/upload
        error, or success with a link to the created repo.
    """
    print(f"\n\n\nNEW RUN: {model_id}")
    # Input validation: strip so blank input is rejected and stray spaces
    # around a pasted id don't break the Hub lookups below.
    model_id = (model_id or "").strip()
    if not model_id:
        return """
### Invalid input π
Please fill a model_id.
"""
    print(f"Valid inputs ✅\nValidating model_id: {model_id}")

    # Attempt to load the base model. The objects are discarded immediately;
    # this only checks loadability. Note it loads the full weights (bf16),
    # which costs real time and RAM for large models.
    try:
        config = AutoConfig.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        del config, tokenizer, model
    except Exception as e:
        return f"""
### {model_id} can't be loaded with AutoClasses π
{e}
"""
    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")

    # Run the medusa heads creation in a separate process.
    # NOTE(review): presumably isolates the torchrun launch from the Gradio
    # server process — confirm.
    try:
        proc = mp.Process(target=create_medusa_heads, args=(model_id,))
        proc.start()
        proc.join()
        # Exceptions in the child do not propagate here; the only failure
        # signal is a non-zero exit code (1 on an uncaught exception, -N if
        # killed by signal N).
        if proc.exitcode != 0:
            raise RuntimeError(
                f"Medusa heads creation failed (process exit code {proc.exitcode})"
            )
    except Exception as e:
        print("Error ❌\n", e)
        return f"""
### Error 😢😢😢
{e}
"""

    # Must match the repo id that create_medusa_heads uploads to.
    repo_id = f"{model_id}-medusa"
    print("Success ✅\nMedusa heads uploaded to: ", repo_id)
    return f"""
### Success 🔥
Yay! Medusa heads were successfully created and uploaded to [{repo_id}](https://huggingface.co/{repo_id})
"""
# Markdown shown under the page title.
DESCRIPTION = """
The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
1. Input a public model id from the Hub
2. Click "Submit"
3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the new repo 🔥
"""
title="Create LLM medusa heads in a new repo π"

# Layout: title + description, then a row holding the input column
# (model id field and Clear/Submit buttons) next to the status column.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"""# {title}""")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            model_id = gr.Text(max_lines=1, label="model_id")
            with gr.Row():
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")
        with gr.Column():
            status_box = gr.Markdown()
    # A ClearButton clears only the components registered on it; attach both
    # the input and the status output (status_box exists only at this point).
    clean.add([model_id, status_box])
    # concurrency_limit=1: only one training run at a time; others wait in the queue.
    submit.click(run, inputs=[model_id], outputs=status_box, concurrency_limit=1)

# Queue holds at most 10 pending requests.
demo.queue(max_size=10).launch(show_api=True)
|