File size: 3,730 Bytes
632990b
b6a509f
 
 
 
 
 
9471796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os

import gradio as gr
import torch
import torch.distributed.run as distributed_run
from git import Repo
from huggingface_hub import HfApi
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


# Clone the medusa repo locally; create_medusa_heads runs its training script
# from "medusa/medusa/train/train.py". git refuses to clone into an existing
# non-empty directory, so reuse an existing checkout (e.g. when the app is
# restarted) instead of crashing at import time.
if os.path.isdir(os.path.join("medusa", ".git")):
    print("Medusa repo already cloned locally, skipping clone")
else:
    print("Cloning the medusa repo locally...")
    Repo.clone_from("https://github.com/FasterDecoding/Medusa.git", "medusa")
    print("Done")


def create_medusa_heads(model_id: str):
    """Train Medusa heads for ``model_id`` and upload them to a Hub repo.

    Runs the training script of the locally cloned Medusa repo inside this
    interpreter with a fixed set of hyper-parameters (``--medusa_num_heads 3``,
    ``--medusa_num_layers 1``, one epoch), then uploads the ``medusa_heads``
    output directory to a Hub repo.

    Args:
        model_id: Hub id of the base model, e.g. ``"org/name"`` or ``"name"``.

    Returns:
        The name of the Hub repo the heads were uploaded to.

    Raises:
        RuntimeError: if the training script exits with a non-zero status.
        Exceptions raised by the training script or by the Hub API calls are
        propagated unchanged.
    """
    training_args = [
        "--model_name_or_path", model_id,
        # NOTE(review): relative path; nothing in this file downloads the
        # ShareGPT data, so it presumably must already exist locally — confirm.
        "--data_path", "ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json",
        "--bf16", "True",
        "--output_dir", "medusa_heads",
        "--num_train_epochs", "1",
        "--per_device_train_batch_size", "8",
        "--per_device_eval_batch_size", "8",
        "--gradient_accumulation_steps", "4",
        "--evaluation_strategy", "no",
        "--save_strategy", "no",
        "--learning_rate", "1e-3",
        "--weight_decay", "0.0",
        "--warmup_ratio", "0.1",
        "--lr_scheduler_type", "cosine",
        "--logging_steps", "1",
        "--tf32", "True",
        "--model_max_length", "2048",
        "--lazy_preprocess", "True",
        "--medusa_num_heads", "3",
        "--medusa_num_layers", "1",
    ]
    try:
        distributed_run.run_script_path("medusa/medusa/train/train.py", *training_args)
    except SystemExit as err:
        # The script runs in this interpreter, so an argument-parsing error or
        # an explicit sys.exit() surfaces as SystemExit, which the caller's
        # `except Exception` would not catch. A None/0 exit code is a normal
        # termination; anything else is reported as an ordinary error.
        if err.code not in (None, 0):
            raise RuntimeError(
                f"Medusa training script exited with status {err.code}"
            ) from err

    # Upload the medusa heads to the Hub. With an "org/name" model id,
    # f"medusa-{model_id}" would read as repo "name" in a namespace called
    # "medusa-org" (normally not owned by the user), so flatten the id into a
    # single repo name instead.
    repo_id = f"medusa-{model_id.replace('/', '-')}"
    api = HfApi()
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
    )
    api.upload_folder(
        folder_path="medusa_heads",
        repo_id=repo_id,
    )
    return repo_id

def run(model_id: str) -> str:
    """Validate ``model_id``, create Medusa heads for it and report the outcome.

    Used as the Gradio "Submit" click handler. Every outcome (invalid input,
    model that cannot be loaded, training/upload failure, success) is returned
    as a Markdown message for the status box rather than raised.

    Args:
        model_id: Hub id of the base model as typed by the user. ``None`` and
            whitespace-only input are treated as empty; surrounding whitespace
            is stripped.

    Returns:
        A Markdown status message.
    """
    print(f"\n\n\nNEW RUN: {model_id}")

    # Input validation: normalise None / stray whitespace before checking, so
    # "   " is rejected and " org/name " is looked up as "org/name".
    model_id = (model_id or "").strip()
    if not model_id:
        return """
        ### Invalid input 🐞

        Please fill a model_id.
        """
    print(f"Valid inputs ✅\nValidating model_id: {model_id}")

    # Attempt to load the base model with the Auto classes before spending
    # hours on training; the loaded objects are discarded immediately.
    try:
        config = AutoConfig.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        del config, tokenizer, model
    except Exception as e:
        return f"""
        ### {model_id} can't be loaded with AutoClasses 🐞

        {e}
        """
    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")

    # Run the medusa heads creation; failures are reported in the UI instead
    # of being raised into Gradio.
    try:
        repo_id = create_medusa_heads(model_id=model_id)
        print("Success ✅\nMedusa heads uploaded to: ", repo_id)
        return f"""
        ### Success 🔥

        Yay! Medusa heads were successfully created and uploaded to {repo_id}
        """
    except Exception as e:
        return f"""
        ### Error 😒😒😒

        {e}
        """


# Markdown rendered under the page heading, explaining how to use the app.
DESCRIPTION = """
The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:

1. Input a public model id from the Hub
2. Click "Submit"
3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the new repo 🔥
"""

# Browser tab title, also rendered as the top-level page heading.
title = "Create LLM medusa heads in a new repo 🐍"

# Build the UI: the model-id input with Clear/Submit buttons on the left and a
# Markdown status box on the right, then start the queued app.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"""# {title}""")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            model_id = gr.Text(max_lines=1, label="model_id")
            with gr.Row():
                clean = gr.ClearButton()
                submit = gr.Button("Submit", variant="primary")

        with gr.Column():
            status_box = gr.Markdown()

    # A ClearButton created without components clears nothing; register the
    # input and the status box now that both exist.
    clean.add([model_id, status_box])

    # concurrency_limit=1: runs are serialised (they all write to the same
    # "medusa_heads" output directory).
    submit.click(run, inputs=[model_id], outputs=status_box, concurrency_limit=1)

demo.queue(max_size=10).launch(show_api=True)