Spaces:
Running
Running
Make this work on diffusers out of the box.
Browse files- convert.py +8 -2
convert.py
CHANGED
@@ -200,7 +200,13 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
|
|
200 |
prefix, ext = os.path.splitext(filename)
|
201 |
if ext in extensions:
|
202 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
sf_filename = os.path.join(folder, sf_in_repo)
|
205 |
convert_file(pt_filename, sf_filename)
|
206 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
@@ -219,7 +225,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
|
|
219 |
try:
|
220 |
operations = None
|
221 |
pr = previous_pr(api, model_id, pr_title)
|
222 |
-
if ("
|
223 |
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
|
224 |
elif pr is not None and not force:
|
225 |
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|
|
|
200 |
prefix, ext = os.path.splitext(filename)
|
201 |
if ext in extensions:
|
202 |
pt_filename = hf_hub_download(model_id, filename=filename)
|
203 |
+
_, raw_filename = os.path.split(filename)
|
204 |
+
if raw_filename == "pytorch_model.bin":
|
205 |
+
# XXX: This is a special case to handle `transformers` and the
|
206 |
+
# `transformers` part of the model which is actually loaded by `transformers`.
|
207 |
+
sf_in_repo = "model.safetensors"
|
208 |
+
else:
|
209 |
+
sf_in_repo = f"{prefix}.safetensors"
|
210 |
sf_filename = os.path.join(folder, sf_in_repo)
|
211 |
convert_file(pt_filename, sf_filename)
|
212 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
|
|
225 |
try:
|
226 |
operations = None
|
227 |
pr = previous_pr(api, model_id, pr_title)
|
228 |
+
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
|
229 |
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
|
230 |
elif pr is not None and not force:
|
231 |
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|