Narsil HF staff commited on
Commit
0c2d8eb
1 Parent(s): a5c2203

https://github.com/huggingface/safetensors/pull/102

Files changed (1) hide show
  1. convert.py +4 -3
convert.py CHANGED
@@ -45,7 +45,8 @@ def check_file_size(sf_filename: str, pt_filename: str):
45
 
46
 
47
  def rename(pt_filename: str) -> str:
48
- local = pt_filename.replace(".bin", ".safetensors")
 
49
  local = local.replace("pytorch_model", "model")
50
  return local
51
 
@@ -103,7 +104,7 @@ def convert_file(
103
  # For tensors to be contiguous
104
  loaded = {k: v.contiguous() for k, v in loaded.items()}
105
 
106
- dirname = sf_filename.rsplit(os.path.sep, 1)[0]
107
  os.makedirs(dirname, exist_ok=True)
108
  save_file(loaded, sf_filename, metadata={"format": "pt"})
109
  check_file_size(sf_filename, pt_filename)
@@ -199,7 +200,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
199
  prefix, ext = os.path.splitext(filename)
200
  if ext in extensions:
201
  pt_filename = hf_hub_download(model_id, filename=filename)
202
- sf_in_repo = f"{filename}.safetensors"
203
  sf_filename = os.path.join(folder, sf_in_repo)
204
  convert_file(pt_filename, sf_filename)
205
  operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
 
45
 
46
 
47
  def rename(pt_filename: str) -> str:
48
+ filename, ext = os.path.splitext(pt_filename)
49
+ local = f"{filename}.safetensors"
50
  local = local.replace("pytorch_model", "model")
51
  return local
52
 
 
104
  # For tensors to be contiguous
105
  loaded = {k: v.contiguous() for k, v in loaded.items()}
106
 
107
+ dirname = os.path.dirname(sf_filename)
108
  os.makedirs(dirname, exist_ok=True)
109
  save_file(loaded, sf_filename, metadata={"format": "pt"})
110
  check_file_size(sf_filename, pt_filename)
 
200
  prefix, ext = os.path.splitext(filename)
201
  if ext in extensions:
202
  pt_filename = hf_hub_download(model_id, filename=filename)
203
+ sf_in_repo = f"{prefix}.safetensors"
204
  sf_filename = os.path.join(folder, sf_in_repo)
205
  convert_file(pt_filename, sf_filename)
206
  operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))