Spaces:
Running
Running
More robust with tied weight keys (#41)
Browse files- More robust with tied weight keys (843addb2c0b07af224d9a3a98fef76059d40ebef)
- Update convert.py (5c59f9a20bf82e23b6d7fe8a772c85e218b9563b)
- Update convert.py (385ed6892301c6c3ec71498dbedcce550dedd7a3)
Co-authored-by: Cyril Vallez <cyrilvallez@users.noreply.huggingface.co>
- convert.py +6 -5
convert.py
CHANGED
|
@@ -241,7 +241,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
|
|
| 241 |
|
| 242 |
|
| 243 |
def convert_generic(
|
| 244 |
-
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
|
| 245 |
) -> ConversionResult:
|
| 246 |
operations = []
|
| 247 |
errors = []
|
|
@@ -262,7 +262,7 @@ def convert_generic(
|
|
| 262 |
sf_in_repo = f"{prefix}.safetensors"
|
| 263 |
sf_filename = os.path.join(folder, sf_in_repo)
|
| 264 |
try:
|
| 265 |
-
convert_file(pt_filename, sf_filename, discard_names=
|
| 266 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
| 267 |
except Exception as e:
|
| 268 |
errors.append((pt_filename, e))
|
|
@@ -280,11 +280,13 @@ def convert(
|
|
| 280 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
| 281 |
os.makedirs(folder)
|
| 282 |
new_pr = None
|
|
|
|
|
|
|
| 283 |
try:
|
| 284 |
operations = None
|
| 285 |
pr = previous_pr(api, model_id, pr_title, revision=revision)
|
| 286 |
-
|
| 287 |
library_name = getattr(info, "library_name", None)
|
|
|
|
| 288 |
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
|
| 289 |
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
|
| 290 |
elif pr is not None and not force:
|
|
@@ -293,7 +295,6 @@ def convert(
|
|
| 293 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
| 294 |
elif library_name == "transformers":
|
| 295 |
|
| 296 |
-
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
|
| 297 |
if "pytorch_model.bin" in filenames:
|
| 298 |
operations, errors = convert_single(
|
| 299 |
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
|
|
@@ -306,7 +307,7 @@ def convert(
|
|
| 306 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 307 |
else:
|
| 308 |
operations, errors = convert_generic(
|
| 309 |
-
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
|
| 310 |
)
|
| 311 |
|
| 312 |
if operations:
|
|
|
|
| 241 |
|
| 242 |
|
| 243 |
def convert_generic(
|
| 244 |
+
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str], discard_names: List[str],
|
| 245 |
) -> ConversionResult:
|
| 246 |
operations = []
|
| 247 |
errors = []
|
|
|
|
| 262 |
sf_in_repo = f"{prefix}.safetensors"
|
| 263 |
sf_filename = os.path.join(folder, sf_in_repo)
|
| 264 |
try:
|
| 265 |
+
convert_file(pt_filename, sf_filename, discard_names=discard_names)
|
| 266 |
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
|
| 267 |
except Exception as e:
|
| 268 |
errors.append((pt_filename, e))
|
|
|
|
| 280 |
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
| 281 |
os.makedirs(folder)
|
| 282 |
new_pr = None
|
| 283 |
+
# Exception handling already happen inside this function
|
| 284 |
+
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
|
| 285 |
try:
|
| 286 |
operations = None
|
| 287 |
pr = previous_pr(api, model_id, pr_title, revision=revision)
|
|
|
|
| 288 |
library_name = getattr(info, "library_name", None)
|
| 289 |
+
|
| 290 |
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
|
| 291 |
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
|
| 292 |
elif pr is not None and not force:
|
|
|
|
| 295 |
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
| 296 |
elif library_name == "transformers":
|
| 297 |
|
|
|
|
| 298 |
if "pytorch_model.bin" in filenames:
|
| 299 |
operations, errors = convert_single(
|
| 300 |
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
|
|
|
|
| 307 |
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
|
| 308 |
else:
|
| 309 |
operations, errors = convert_generic(
|
| 310 |
+
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token, discard_names=discard_names
|
| 311 |
)
|
| 312 |
|
| 313 |
if operations:
|