joaogante HF staff commited on
Commit
f5b4b7c
1 Parent(s): a503cd3
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -17,7 +17,10 @@ print("Done")
17
 
18
  def create_medusa_heads(model_id: str):
19
  parser = distributed_run.get_args_parser()
20
- training_args = [
 
 
 
21
  "--model_name_or_path", model_id,
22
  "--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
23
  "--bf16", "True",
@@ -38,10 +41,7 @@ def create_medusa_heads(model_id: str):
38
  "--lazy_preprocess", "True",
39
  "--medusa_num_heads", "3",
40
  "--medusa_num_layers", "1",
41
- ]
42
- args = parser.parse_args(
43
- ["medusa/medusa/train/train.py", "training_script_args"] + training_args
44
- )
45
  distributed_run.run(args)
46
 
47
  # Upload the medusa heads to the Hub
 
17
 
18
  def create_medusa_heads(model_id: str):
19
  parser = distributed_run.get_args_parser()
20
+ args = parser.parse_args([
21
+ "--nproc_per_node", "4",
22
+ "training_script", "medusa/medusa/train/train.py",
23
+ "training_script_args",
24
  "--model_name_or_path", model_id,
25
  "--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
26
  "--bf16", "True",
 
41
  "--lazy_preprocess", "True",
42
  "--medusa_num_heads", "3",
43
  "--medusa_num_layers", "1",
44
+ ])
 
 
 
45
  distributed_run.run(args)
46
 
47
  # Upload the medusa heads to the Hub