Wauplin HF staff commited on
Commit
211a715
1 Parent(s): 3fd8bf0
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -7,8 +7,9 @@ import gradio as gr
7
  import huggingface_hub
8
  import torch
9
  import yaml
10
- from gradio_logsview.logsview import Log, LogsView
11
  from mergekit.common import parse_kmb
 
12
  from mergekit.merge import run_merge
13
  from mergekit.options import MergeOptions
14
 
@@ -83,7 +84,7 @@ def merge(
83
  if not yaml_config:
84
  raise gr.Error("Empty yaml, pick an example below")
85
  try:
86
- merge_config = yaml.safe_load(yaml_config)
87
  except Exception as e:
88
  raise gr.Error(f"Invalid yaml {e}")
89
 
@@ -94,6 +95,13 @@ def merge(
94
  config_path = merged_path / "config.yaml"
95
  config_path.write_text(yaml_config)
96
 
 
 
 
 
 
 
 
97
  # Taken from https://github.com/arcee-ai/mergekit/blob/main/mergekit/scripts/run_yaml.py
98
  yield from LogsView.run_thread(
99
  run_merge,
@@ -104,7 +112,7 @@ def merge(
104
  config_source=config_path,
105
  )
106
 
107
- ## TODO(implement upload at the end of the merge, and display the repo URL)
108
  api = huggingface_hub.HfApi(token=hf_token)
109
  repo_url = api.create_repo(repo_name, exist_ok=True)
110
  api.upload_folder(repo_id=repo_url.repo_id, folder_path=merged_path)
 
7
  import huggingface_hub
8
  import torch
9
  import yaml
10
+ from gradio_logsview.logsview import LogsView
11
  from mergekit.common import parse_kmb
12
+ from mergekit.config import MergeConfiguration
13
  from mergekit.merge import run_merge
14
  from mergekit.options import MergeOptions
15
 
 
84
  if not yaml_config:
85
  raise gr.Error("Empty yaml, pick an example below")
86
  try:
87
+ merge_config = MergeConfiguration.model_validate(yaml.safe_load(yaml_config))
88
  except Exception as e:
89
  raise gr.Error(f"Invalid yaml {e}")
90
 
 
95
  config_path = merged_path / "config.yaml"
96
  config_path.write_text(yaml_config)
97
 
98
+ if repo_name == "":
99
+ name = "-".join(
100
+ model.model.path for model in merge_config.referenced_models()
101
+ )
102
+ repo_name = f"mergekit-{merge_config.merge_method}-{name}".replace("/", "-")
103
+ print(f"Will save in {repo_name}")
104
+
105
  # Taken from https://github.com/arcee-ai/mergekit/blob/main/mergekit/scripts/run_yaml.py
106
  yield from LogsView.run_thread(
107
  run_merge,
 
112
  config_source=config_path,
113
  )
114
 
115
+ # TODO: nicely display things
116
  api = huggingface_hub.HfApi(token=hf_token)
117
  repo_url = api.create_repo(repo_name, exist_ok=True)
118
  api.upload_folder(repo_id=repo_url.repo_id, folder_path=merged_path)