Narsil HF staff commited on
Commit
dfef5b2
1 Parent(s): 65428ae

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +9 -9
convert.py CHANGED
@@ -71,8 +71,8 @@ def rename(pt_filename: str) -> str:
71
  return local
72
 
73
 
74
- def convert_multi(model_id: str, folder: str) -> ConversionResult:
75
- filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
76
  with open(filename, "r") as f:
77
  data = json.load(f)
78
 
@@ -102,8 +102,8 @@ def convert_multi(model_id: str, folder: str) -> ConversionResult:
102
  return operations, errors
103
 
104
 
105
- def convert_single(model_id: str, folder: str) -> ConversionResult:
106
- pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
107
 
108
  sf_name = "model.safetensors"
109
  sf_filename = os.path.join(folder, sf_name)
@@ -236,7 +236,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
236
  return None
237
 
238
 
239
- def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> ConversionResult:
240
  operations = []
241
  errors = []
242
 
@@ -244,7 +244,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> Conversi
244
  for filename in filenames:
245
  prefix, ext = os.path.splitext(filename)
246
  if ext in extensions:
247
- pt_filename = hf_hub_download(model_id, filename=filename)
248
  dirname, raw_filename = os.path.split(filename)
249
  if raw_filename == "pytorch_model.bin":
250
  # XXX: This is a special case to handle `transformers` and the
@@ -283,14 +283,14 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
283
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
284
  elif library_name == "transformers":
285
  if "pytorch_model.bin" in filenames:
286
- operations, errors = convert_single(model_id, folder)
287
  elif "pytorch_model.bin.index.json" in filenames:
288
- operations, errors = convert_multi(model_id, folder)
289
  else:
290
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
291
  check_final_model(model_id, folder)
292
  else:
293
- operations, errors = convert_generic(model_id, folder, filenames)
294
 
295
  if operations:
296
  new_pr = api.create_commit(
 
71
  return local
72
 
73
 
74
+ def convert_multi(model_id: str, folder: str, token: str) -> ConversionResult:
75
+ filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token)
76
  with open(filename, "r") as f:
77
  data = json.load(f)
78
 
 
102
  return operations, errors
103
 
104
 
105
+ def convert_single(model_id: str, folder: str, token: str) -> ConversionResult:
106
+ pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token)
107
 
108
  sf_name = "model.safetensors"
109
  sf_filename = os.path.join(folder, sf_name)
 
236
  return None
237
 
238
 
239
+ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: str) -> ConversionResult:
240
  operations = []
241
  errors = []
242
 
 
244
  for filename in filenames:
245
  prefix, ext = os.path.splitext(filename)
246
  if ext in extensions:
247
+ pt_filename = hf_hub_download(model_id, filename=filename, token=token)
248
  dirname, raw_filename = os.path.split(filename)
249
  if raw_filename == "pytorch_model.bin":
250
  # XXX: This is a special case to handle `transformers` and the
 
283
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
284
  elif library_name == "transformers":
285
  if "pytorch_model.bin" in filenames:
286
+ operations, errors = convert_single(model_id, folder, token=api.token)
287
  elif "pytorch_model.bin.index.json" in filenames:
288
+ operations, errors = convert_multi(model_id, folder, token=api.token)
289
  else:
290
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
291
  check_final_model(model_id, folder)
292
  else:
293
+ operations, errors = convert_generic(model_id, folder, filenames, token=api.token)
294
 
295
  if operations:
296
  new_pr = api.create_commit(