Hotfix
Browse fileshttps://github.com/huggingface/safetensors/pull/102
- convert.py +4 -3
convert.py
CHANGED
@@ -45,7 +45,8 @@ def check_file_size(sf_filename: str, pt_filename: str):
|
|
45 |
|
46 |
|
47 |
def rename(pt_filename: str) -> str:
|
48 |
-
|
|
|
49 |
local = local.replace("pytorch_model", "model")
|
50 |
return local
|
51 |
|
@@ -103,7 +104,7 @@ def convert_file(
|
|
103 |
# For tensors to be contiguous
|
104 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
105 |
|
106 |
-
dirname =
|
107 |
os.makedirs(dirname, exist_ok=True)
|
108 |
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
109 |
check_file_size(sf_filename, pt_filename)
|
@@ -199,7 +200,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
199 |
prefix, ext = os.path.splitext(filename)
|
200 |
if ext in extensions:
|
201 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
202 |
-
sf_in_repo = f"{
|
203 |
sf_filename = os.path.join(folder, sf_in_repo)
|
204 |
convert_file(pt_filename, sf_filename)
|
205 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
|
|
45 |
|
46 |
|
47 |
def rename(pt_filename: str) -> str:
|
48 |
+
filename, ext = os.path.splitext(pt_filename)
|
49 |
+
local = f"{filename}.safetensors"
|
50 |
local = local.replace("pytorch_model", "model")
|
51 |
return local
|
52 |
|
|
|
104 |
# For tensors to be contiguous
|
105 |
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
106 |
|
107 |
+
dirname = os.path.dirname(sf_filename)
|
108 |
os.makedirs(dirname, exist_ok=True)
|
109 |
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
110 |
check_file_size(sf_filename, pt_filename)
|
|
|
200 |
prefix, ext = os.path.splitext(filename)
|
201 |
if ext in extensions:
|
202 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
203 |
+
sf_in_repo = f"{prefix}.safetensors"
|
204 |
sf_filename = os.path.join(folder, sf_in_repo)
|
205 |
convert_file(pt_filename, sf_filename)
|
206 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|