Narsil HF staff commited on
Commit
ca6ec81
1 Parent(s): dfef5b2

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +7 -7
convert.py CHANGED
@@ -71,7 +71,7 @@ def rename(pt_filename: str) -> str:
71
  return local
72
 
73
 
74
- def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
75
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
76
  with open(filename, "r") as f:
77
  data = json.load(f)
@@ -79,7 +79,7 @@ def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
79
  filenames = set(data["weight_map"].values())
80
  local_filenames = []
81
  for filename in filenames:
82
- pt_filename = hf_hub_download(repo_id=model_id, filename=filename)
83
 
84
  sf_filename = rename(pt_filename)
85
  sf_filename = os.path.join(folder, sf_filename)
@@ -102,7 +102,7 @@ def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
102
  return operations, errors
103
 
104
 
105
- def convert_single(model_id: str, folder: str, token: str) -> ConversionResult:
106
  pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
107
 
108
  sf_name = "model.safetensors"
@@ -156,8 +156,8 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
156
  return "\n".join(errors)
157
 
158
 
159
- def check_final_model(model_id: str, folder: str):
160
- config = hf_hub_download(repo_id=model_id, filename="config.json")
161
  shutil.copy(config, os.path.join(folder, "config.json"))
162
  config = AutoConfig.from_pretrained(folder)
163
 
@@ -236,7 +236,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
236
  return None
237
 
238
 
239
- def convert_generic(model_id: str, folder: str, filenames: Set[str], token: str) -> ConversionResult:
240
  operations = []
241
  errors = []
242
 
@@ -288,7 +288,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
288
  operations, errors = convert_multi(model_id, folder, token=api.token)
289
  else:
290
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
291
- check_final_model(model_id, folder)
292
  else:
293
  operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
294
 
 
71
  return local
72
 
73
 
74
+ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
75
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
76
  with open(filename, "r") as f:
77
  data = json.load(f)
 
79
  filenames = set(data["weight_map"].values())
80
  local_filenames = []
81
  for filename in filenames:
82
+ pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token)
83
 
84
  sf_filename = rename(pt_filename)
85
  sf_filename = os.path.join(folder, sf_filename)
 
102
  return operations, errors
103
 
104
 
105
+ def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
106
  pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
107
 
108
  sf_name = "model.safetensors"
 
156
  return "\n".join(errors)
157
 
158
 
159
+ def check_final_model(model_id: str, folder: str, token: Optional[str]):
160
+ config = hf_hub_download(repo_id=model_id, filename="config.json", token=token)
161
  shutil.copy(config, os.path.join(folder, "config.json"))
162
  config = AutoConfig.from_pretrained(folder)
163
 
 
236
  return None
237
 
238
 
239
+ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
240
  operations = []
241
  errors = []
242
 
 
288
  operations, errors = convert_multi(model_id, folder, token=api.token)
289
  else:
290
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
291
+ check_final_model(model_id, folder, token=api.token)
292
  else:
293
  operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
294