joaogante HF staff commited on
Commit
8bece88
1 Parent(s): 4f6680e
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -15,6 +15,7 @@ print("Done")
15
 
16
 
17
  def create_medusa_heads(model_id: str):
 
18
  training_args = [
19
  "--model_name_or_path", model_id,
20
  "--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
@@ -37,7 +38,10 @@ def create_medusa_heads(model_id: str):
37
  "--medusa_num_heads", "3",
38
  "--medusa_num_layers", "1",
39
  ]
40
- distributed_run.run_script_path("medusa/medusa/train/train.py", *training_args)
 
 
 
41
 
42
  # Upload the medusa heads to the Hub
43
  repo_id = f"medusa-{model_id}"
 
15
 
16
 
17
  def create_medusa_heads(model_id: str):
18
+ parser = distributed_run.get_args_parser()
19
  training_args = [
20
  "--model_name_or_path", model_id,
21
  "--data_path", "data/ShareGPT_V4.3_unfiltered_cleaned_split.json",
 
38
  "--medusa_num_heads", "3",
39
  "--medusa_num_layers", "1",
40
  ]
41
+ args = parser.parse_args(
42
+ ["training_script", "medusa/medusa/train/train.py", "training_script_args"] + training_args
43
+ )
44
+ distributed_run.run(args)
45
 
46
  # Upload the medusa heads to the Hub
47
  repo_id = f"medusa-{model_id}"