joaogante (HF staff) committed
Commit 4edd5e3 • 1 Parent(s): 2fda365
Files changed (1):
  1. app.py +119 -117
app.py CHANGED
@@ -5,120 +5,122 @@ import torch.distributed.run as distributed_run
 from git import Repo
 from huggingface_hub import HfApi
 
-
-# Clone the medusa repo locally
-print("Cloning the medusa repo locally...")
-Repo.clone_from("https://github.com/FasterDecoding/Medusa.git", "medusa")
-print("Done")
-
-
-def create_medusa_heads(model_id: str):
-    parser = distributed_run.get_args_parser()
-    args = parser.parse_args([
-        "--nproc_per_node", "4",
-        "training_script", "medusa/medusa/train/train.py",
-        "training_script_args",
-        "--model_name_or_path", model_id,
-        "--data_path", "ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json",
-        "--bf16", "True",
-        "--output_dir", "medusa_heads",
-        "--num_train_epochs", "1",
-        "--per_device_train_batch_size", "8",
-        "--per_device_eval_batch_size", "8",
-        "--gradient_accumulation_steps", "4",
-        "--evaluation_strategy", "no",
-        "--save_strategy", "no",
-        "--learning_rate", "1e-3",
-        "--weight_decay", "0.0",
-        "--warmup_ratio", "0.1",
-        "--lr_scheduler_type", "cosine",
-        "--logging_steps", "1",
-        "--tf32", "True",
-        "--model_max_length", "2048",
-        "--lazy_preprocess", "True",
-        "--medusa_num_heads", "3",
-        "--medusa_num_layers", "1",
-    ])
-    distributed_run.run(args)
-
-    # Upload the medusa heads to the Hub
-    repo_id = f"medusa-{model_id}"
-    api = HfApi()
-    api.create_repo(
-        repo_id=repo_id,
-        exist_ok=True,
-    )
-    api.upload_folder(
-        folder_path="medusa_heads",
-        repo_id=repo_id,
-    )
-    return repo_id
-
-def run(model_id: str) -> str:
-    # Input validation
-    if model_id == "":
-        return """
-        ### Invalid input 🐞
-
-        Please fill a model_id.
-        """
-    print(f"Valid inputs ✅\nValidating model_id: {model_id}")
-
-    # Attempt to load the base model
-    try:
-        config = AutoConfig.from_pretrained(model_id)
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-        del config, tokenizer, model
-    except Exception as e:
-        return f"""
-        ### {model_id} can't be loaded with AutoClasses 🐞
-
-        {e}
-        """
-    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")
-
-    # Run the medusa heads creation
-    try:
-        repo_id = create_medusa_heads(model_id=model_id)
-        print("Success ✅\nMedusa heads uploaded to: ", repo_id)
-        return f"""
-        ### Success 🔥
-
-        Yay! Medusa heads were successfully created and uploaded to, {repo_id}
-        """
-    except Exception as e:
-        return f"""
-        ### Error 😒😒😒
-
-        {e}
-        """
-
-
-DESCRIPTION = """
-The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
-
-1. Input a public model id from the Hub
-2. Click "Submit"
-3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the new repo 🔥
-"""
-
-title="Create LLM medusa heads in a new repo 🐍"
-
-with gr.Blocks(title=title) as demo:
-    description = gr.Markdown(f"""# {title}""")
-    description = gr.Markdown(DESCRIPTION)
-
-    with gr.Row() as r:
-        with gr.Column() as c:
-            model_id = gr.Text(max_lines=1, label="model_id")
-            with gr.Row() as c:
-                clean = gr.ClearButton()
-                submit = gr.Button("Submit", variant="primary")
-
-        with gr.Column() as d:
-            status_box = gr.Markdown()
-
-    submit.click(run, inputs=[model_id], outputs=status_box, concurrency_limit=1)
-
-demo.queue(max_size=10).launch(show_api=True)
+if __name__ == "__main__":
+    # Clone the medusa repo locally
+    print("Cloning the medusa repo locally...")
+    Repo.clone_from("https://github.com/FasterDecoding/Medusa.git", "medusa")
+    print("Done")
+
+
+    def create_medusa_heads(model_id: str):
+        parser = distributed_run.get_args_parser()
+        args = parser.parse_args([
+            "--nproc_per_node", "4",
+            "training_script", "medusa/medusa/train/train.py",
+            "training_script_args",
+            "--model_name_or_path", model_id,
+            "--data_path", "ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json",
+            "--bf16", "True",
+            "--output_dir", "medusa_heads",
+            "--num_train_epochs", "1",
+            "--per_device_train_batch_size", "8",
+            "--per_device_eval_batch_size", "8",
+            "--gradient_accumulation_steps", "4",
+            "--evaluation_strategy", "no",
+            "--save_strategy", "no",
+            "--learning_rate", "1e-3",
+            "--weight_decay", "0.0",
+            "--warmup_ratio", "0.1",
+            "--lr_scheduler_type", "cosine",
+            "--logging_steps", "1",
+            "--tf32", "True",
+            "--model_max_length", "2048",
+            "--lazy_preprocess", "True",
+            "--medusa_num_heads", "3",
+            "--medusa_num_layers", "1",
+        ])
+        distributed_run.run(args)
+
+        # Upload the medusa heads to the Hub
+        repo_id = f"medusa-{model_id}"
+        api = HfApi()
+        api.create_repo(
+            repo_id=repo_id,
+            exist_ok=True,
+        )
+        api.upload_folder(
+            folder_path="medusa_heads",
+            repo_id=repo_id,
+        )
+        return repo_id
+
+    def run(model_id: str) -> str:
+        print(f"\n\n\nNEW RUN: {model_id}")
+
+        # Input validation
+        if model_id == "":
+            return """
+            ### Invalid input 🐞
+
+            Please fill a model_id.
+            """
+        print(f"Valid inputs ✅\nValidating model_id: {model_id}")
+
+        # Attempt to load the base model
+        try:
+            config = AutoConfig.from_pretrained(model_id)
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+            del config, tokenizer, model
+        except Exception as e:
+            return f"""
+            ### {model_id} can't be loaded with AutoClasses 🐞
+
+            {e}
+            """
+        print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")
+
+        # Run the medusa heads creation
+        try:
+            repo_id = create_medusa_heads(model_id=model_id)
+            print("Success ✅\nMedusa heads uploaded to: ", repo_id)
+            return f"""
+            ### Success 🔥
+
+            Yay! Medusa heads were successfully created and uploaded to, {repo_id}
+            """
+        except Exception as e:
+            return f"""
+            ### Error 😒😒😒
+
+            {e}
+            """
+
+
+    DESCRIPTION = """
+    The steps to create [medusa](https://sites.google.com/view/medusa-llm) heads are the following:
+
+    1. Input a public model id from the Hub
+    2. Click "Submit"
+    3. That's it! You'll get feedback if it works or not, and if it worked, you'll get the URL of the new repo 🔥
+    """
+
+    title="Create LLM medusa heads in a new repo 🐍"
+
+    with gr.Blocks(title=title) as demo:
+        description = gr.Markdown(f"""# {title}""")
+        description = gr.Markdown(DESCRIPTION)
+
+        with gr.Row() as r:
+            with gr.Column() as c:
+                model_id = gr.Text(max_lines=1, label="model_id")
+                with gr.Row() as c:
+                    clean = gr.ClearButton()
+                    submit = gr.Button("Submit", variant="primary")
+
+            with gr.Column() as d:
+                status_box = gr.Markdown()
+
+        submit.click(run, inputs=[model_id], outputs=status_box, concurrency_limit=1)
+
+    demo.queue(max_size=10).launch(show_api=True)
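
The net change in this commit: the whole script body (repo clone, helper functions, and the Gradio app) moves inside an `if __name__ == "__main__":` guard, and `run` gains a `NEW RUN` log line. Below is a minimal sketch, not from the repo, of why the guard matters when a script launches distributed workers via `torch.distributed.run`; `worker.py` and `--some-flag` are hypothetical placeholders.

```python
import torch.distributed.run as distributed_run


def launch_workers():
    # Build the same argparse namespace the `torchrun` CLI would produce:
    # the first positional argument is the training script, and everything
    # after it is forwarded to that script untouched.
    parser = distributed_run.get_args_parser()
    args = parser.parse_args([
        "--nproc_per_node", "2",
        "worker.py",         # hypothetical training script
        "--some-flag", "1",  # forwarded to worker.py as-is
    ])
    distributed_run.run(args)  # spawns and monitors the worker processes


if __name__ == "__main__":
    # Side effects (cloning a repo, starting a Gradio server) run only when
    # the script is executed directly, never when the module is imported.
    launch_workers()
```

Without the guard, anything that imports the module (tooling, or a spawn-based launcher re-importing it) would re-execute the module-level clone and UI setup; with it, the script behaves identically when run directly but is safe to import.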