Narsil HF staff commited on
Commit
bac319c
1 Parent(s): 0c2d8eb

Make this work on diffusers out of the box.

Browse files
Files changed (1) hide show
  1. convert.py +8 -2
convert.py CHANGED
@@ -200,7 +200,13 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str]) -> List["Co
200
  prefix, ext = os.path.splitext(filename)
201
  if ext in extensions:
202
  pt_filename = hf_hub_download(model_id, filename=filename)
203
- sf_in_repo = f"{prefix}.safetensors"
 
 
 
 
 
 
204
  sf_filename = os.path.join(folder, sf_in_repo)
205
  convert_file(pt_filename, sf_filename)
206
  operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
@@ -219,7 +225,7 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["Commi
219
  try:
220
  operations = None
221
  pr = previous_pr(api, model_id, pr_title)
222
- if ("model.safetensors" in filenames or "model_index.safetensors.index.json" in filenames) and not force:
223
  raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
224
  elif pr is not None and not force:
225
  url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
 
200
  prefix, ext = os.path.splitext(filename)
201
  if ext in extensions:
202
  pt_filename = hf_hub_download(model_id, filename=filename)
203
+ _, raw_filename = os.path.split(filename)
204
+ if raw_filename == "pytorch_model.bin":
205
+ # XXX: This is a special case to handle `transformers` and the
206
+ # `transformers` part of the model which is actually loaded by `transformers`.
207
+ sf_in_repo = "model.safetensors"
208
+ else:
209
+ sf_in_repo = f"{prefix}.safetensors"
210
  sf_filename = os.path.join(folder, sf_in_repo)
211
  convert_file(pt_filename, sf_filename)
212
  operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
 
225
  try:
226
  operations = None
227
  pr = previous_pr(api, model_id, pr_title)
228
+ if any(filename.endswith(".safetensors") for filename in filenames) and not force:
229
  raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
230
  elif pr is not None and not force:
231
  url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"