Narsil HF staff committed on
Commit
1420ae0
1 Parent(s): 705ab71

Update convert.py

Browse files
Files changed (1) hide show
  1. convert.py +44 -17
convert.py CHANGED
@@ -5,7 +5,7 @@ import shutil
5
  from collections import defaultdict
6
  from inspect import signature
7
  from tempfile import TemporaryDirectory
8
- from typing import Dict, List, Optional, Set
9
 
10
  import torch
11
 
@@ -33,6 +33,8 @@ If you find any issues: please report here: https://huggingface.co/spaces/safete
33
  Feel free to ignore this PR.
34
  """
35
 
 
 
36
 
37
  class AlreadyExists(Exception):
38
  pass
@@ -69,7 +71,7 @@ def rename(pt_filename: str) -> str:
69
  return local
70
 
71
 
72
- def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
73
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
74
  with open(filename, "r") as f:
75
  data = json.load(f)
@@ -95,18 +97,20 @@ def convert_multi(model_id: str, folder: str) -> List["CommitOperationAdd"]:
95
  operations = [
96
  CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
97
  ]
 
98
 
99
- return operations
100
 
101
 
102
- def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
103
  pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
104
 
105
  sf_name = "model.safetensors"
106
  sf_filename = os.path.join(folder, sf_name)
107
  convert_file(pt_filename, sf_filename)
108
  operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
109
- return operations
 
110
 
111
 
112
  def convert_file(
@@ -204,18 +208,22 @@ def check_final_model(model_id: str, folder: str):
204
 
205
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
206
  try:
 
207
  discussions = api.get_repo_discussions(repo_id=model_id)
208
  except Exception:
209
  return None
210
  for discussion in discussions:
211
  if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
212
- details = api.get_discussion_details(repo_id=model_id, discussion_num=discussion.num)
213
- if details.target_branch == "refs/heads/main":
 
214
  return discussion
 
215
 
216
 
217
- def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["CommitOperationAdd"]:
218
  operations = []
 
219
 
220
  extensions = set([".bin", ".ckpt"])
221
  for filename in filenames:
@@ -230,12 +238,15 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
230
  else:
231
  sf_in_repo = f"{prefix}.safetensors"
232
  sf_filename = os.path.join(folder, sf_in_repo)
233
- convert_file(pt_filename, sf_filename)
234
- operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
235
- return operations
 
 
 
236
 
237
 
238
- def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
239
  pr_title = "Adding `safetensors` variant of this model"
240
  info = api.model_info(model_id)
241
  filenames = set(s.rfilename for s in info.siblings)
@@ -257,14 +268,14 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
257
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
258
  elif library_name == "transformers":
259
  if "pytorch_model.bin" in filenames:
260
- operations = convert_single(model_id, folder)
261
  elif "pytorch_model.bin.index.json" in filenames:
262
- operations = convert_multi(model_id, folder)
263
  else:
264
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
265
  check_final_model(model_id, folder)
266
  else:
267
- operations = convert_generic(model_id, folder, filenames)
268
 
269
  if operations:
270
  new_pr = api.create_commit(
@@ -279,7 +290,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
279
  print("No files to convert")
280
  finally:
281
  shutil.rmtree(folder)
282
- return new_pr
283
 
284
 
285
  if __name__ == "__main__":
@@ -300,7 +311,23 @@ if __name__ == "__main__":
300
  action="store_true",
301
  help="Create the PR even if it already exists of if the model was already converted.",
302
  )
 
 
 
 
 
303
  args = parser.parse_args()
304
  model_id = args.model_id
305
  api = HfApi()
306
- convert(api, model_id, force=args.force)
 
 
 
 
 
 
 
 
 
 
 
 
5
  from collections import defaultdict
6
  from inspect import signature
7
  from tempfile import TemporaryDirectory
8
+ from typing import Dict, List, Optional, Set, Tuple
9
 
10
  import torch
11
 
 
33
  Feel free to ignore this PR.
34
  """
35
 
36
+ ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]
37
+
38
 
39
  class AlreadyExists(Exception):
40
  pass
 
71
  return local
72
 
73
 
74
+ def convert_multi(model_id: str, folder: str) -> ConversionResult:
75
  filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json")
76
  with open(filename, "r") as f:
77
  data = json.load(f)
 
97
  operations = [
98
  CommitOperationAdd(path_in_repo=local.split("/")[-1], path_or_fileobj=local) for local in local_filenames
99
  ]
100
+ errors: List[Tuple[str, "Exception"]] = []
101
 
102
+ return operations, errors
103
 
104
 
105
+ def convert_single(model_id: str, folder: str) -> ConversionResult:
106
  pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
107
 
108
  sf_name = "model.safetensors"
109
  sf_filename = os.path.join(folder, sf_name)
110
  convert_file(pt_filename, sf_filename)
111
  operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
112
+ errors: List[Tuple[str, "Exception"]] = []
113
+ return operations, errors
114
 
115
 
116
  def convert_file(
 
208
 
209
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
210
  try:
211
+ main_commit = api.list_repo_commits(model_id)[0].commit_id
212
  discussions = api.get_repo_discussions(repo_id=model_id)
213
  except Exception:
214
  return None
215
  for discussion in discussions:
216
  if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
217
+ commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
218
+
219
+ if main_commit == commits[1].commit_id:
220
  return discussion
221
+ return None
222
 
223
 
224
+ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> ConversionResult:
225
  operations = []
226
+ errors = []
227
 
228
  extensions = set([".bin", ".ckpt"])
229
  for filename in filenames:
 
238
  else:
239
  sf_in_repo = f"{prefix}.safetensors"
240
  sf_filename = os.path.join(folder, sf_in_repo)
241
+ try:
242
+ convert_file(pt_filename, sf_filename)
243
+ operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
244
+ except Exception as e:
245
+ errors.append((pt_filename, e))
246
+ return operations, errors
247
 
248
 
249
+ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List["Exception"]]:
250
  pr_title = "Adding `safetensors` variant of this model"
251
  info = api.model_info(model_id)
252
  filenames = set(s.rfilename for s in info.siblings)
 
268
  raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
269
  elif library_name == "transformers":
270
  if "pytorch_model.bin" in filenames:
271
+ operations, errors = convert_single(model_id, folder)
272
  elif "pytorch_model.bin.index.json" in filenames:
273
+ operations, errors = convert_multi(model_id, folder)
274
  else:
275
  raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
276
  check_final_model(model_id, folder)
277
  else:
278
+ operations, errors = convert_generic(model_id, folder, filenames)
279
 
280
  if operations:
281
  new_pr = api.create_commit(
 
290
  print("No files to convert")
291
  finally:
292
  shutil.rmtree(folder)
293
+ return new_pr, errors
294
 
295
 
296
  if __name__ == "__main__":
 
311
  action="store_true",
312
  help="Create the PR even if it already exists of if the model was already converted.",
313
  )
314
+ parser.add_argument(
315
+ "-y",
316
+ action="store_true",
317
+ help="Ignore safety prompt",
318
+ )
319
  args = parser.parse_args()
320
  model_id = args.model_id
321
  api = HfApi()
322
+ if args.y:
323
+ txt = "y"
324
+ else:
325
+ txt = input(
326
+ "This conversion script will unpickle a pickled file, which is inherently unsafe. If you do not trust this file, we invite you to use"
327
+ " https://huggingface.co/spaces/safetensors/convert or google colab or other hosted solution to avoid potential issues with this file."
328
+ " Continue [Y/n] ?"
329
+ )
330
+ if txt.lower() in {"", "y"}:
331
+ _commit_info, _errors = convert(api, model_id, force=args.force)
332
+ else:
333
+ print(f"Answer was `{txt}` aborting.")