Upload 2 files
Browse files
convert_repo_to_safetensors_gr.py
CHANGED
@@ -269,7 +269,7 @@ def convert_openai_text_enc_state_dict(text_enc_dict):
|
|
269 |
return text_enc_dict
|
270 |
|
271 |
|
272 |
-
def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True,
|
273 |
progress(0, desc="Start converting...")
|
274 |
# Path for safetensors
|
275 |
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
@@ -329,12 +329,7 @@ def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, u
|
|
329 |
if half:
|
330 |
state_dict = {k: v.half() for k, v in state_dict.items()}
|
331 |
|
332 |
-
|
333 |
-
save_file(state_dict, checkpoint_path)
|
334 |
-
else:
|
335 |
-
state_dict = {"state_dict": state_dict}
|
336 |
-
torch.save(state_dict, checkpoint_path)
|
337 |
-
|
338 |
progress(1, desc="Converted.")
|
339 |
|
340 |
|
@@ -343,7 +338,7 @@ def download_repo(repo_id, dir_path, progress=gr.Progress(track_tqdm=True)):
|
|
343 |
try:
|
344 |
snapshot_download(repo_id=repo_id, local_dir=dir_path)
|
345 |
except Exception as e:
|
346 |
-
print(f"Error: Failed to download {repo_id}. ")
|
347 |
return
|
348 |
|
349 |
|
@@ -366,7 +361,7 @@ def upload_safetensors_to_repo(filename, progress=gr.Progress(track_tqdm=True)):
|
|
366 |
return url
|
367 |
|
368 |
|
369 |
-
def convert_repo_to_safetensors(repo_id, progress=gr.Progress(track_tqdm=True)):
|
370 |
download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
|
371 |
output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
|
372 |
download_repo(repo_id, download_dir)
|
@@ -374,7 +369,7 @@ def convert_repo_to_safetensors(repo_id, progress=gr.Progress(track_tqdm=True)):
|
|
374 |
return output_filename
|
375 |
|
376 |
|
377 |
-
def convert_repo_to_safetensors_multi(repo_id, files, is_upload, urls, progress=gr.Progress(track_tqdm=True)):
|
378 |
file = convert_repo_to_safetensors(repo_id)
|
379 |
if not urls: urls = []
|
380 |
url = ""
|
@@ -393,11 +388,12 @@ if __name__ == "__main__":
|
|
393 |
parser = argparse.ArgumentParser()
|
394 |
|
395 |
parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
|
|
|
396 |
|
397 |
args = parser.parse_args()
|
398 |
assert args.repo_id is not None, "Must provide a Repo ID!"
|
399 |
|
400 |
-
convert_repo_to_safetensors(args.repo_id)
|
401 |
|
402 |
|
403 |
# Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl
|
|
|
269 |
return text_enc_dict
|
270 |
|
271 |
|
272 |
+
def convert_diffusers_to_safetensors(model_path, checkpoint_path, half = True, progress=gr.Progress(track_tqdm=True)):
|
273 |
progress(0, desc="Start converting...")
|
274 |
# Path for safetensors
|
275 |
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
|
|
329 |
if half:
|
330 |
state_dict = {k: v.half() for k, v in state_dict.items()}
|
331 |
|
332 |
+
save_file(state_dict, checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
333 |
progress(1, desc="Converted.")
|
334 |
|
335 |
|
|
|
338 |
try:
|
339 |
snapshot_download(repo_id=repo_id, local_dir=dir_path)
|
340 |
except Exception as e:
|
341 |
+
print(f"Error: Failed to download {repo_id}. {e}")
|
342 |
return
|
343 |
|
344 |
|
|
|
361 |
return url
|
362 |
|
363 |
|
364 |
+
def convert_repo_to_safetensors(repo_id, half=True, progress=gr.Progress(track_tqdm=True)):
|
365 |
download_dir = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}"
|
366 |
output_filename = f"{repo_id.split('/')[0]}_{repo_id.split('/')[-1]}.safetensors"
|
367 |
download_repo(repo_id, download_dir)
|
|
|
369 |
return output_filename
|
370 |
|
371 |
|
372 |
+
def convert_repo_to_safetensors_multi(repo_id, files, is_upload, urls, half=True, progress=gr.Progress(track_tqdm=True)):
|
373 |
file = convert_repo_to_safetensors(repo_id)
|
374 |
if not urls: urls = []
|
375 |
url = ""
|
|
|
388 |
parser = argparse.ArgumentParser()
|
389 |
|
390 |
parser.add_argument("--repo_id", default=None, type=str, required=True, help="HF Repo ID of the model to convert.")
|
391 |
+
parser.add_argument("--half", default=True, help="Save weights in half precision.")
|
392 |
|
393 |
args = parser.parse_args()
|
394 |
assert args.repo_id is not None, "Must provide a Repo ID!"
|
395 |
|
396 |
+
convert_repo_to_safetensors(args.repo_id, args.half)
|
397 |
|
398 |
|
399 |
# Usage: python convert_repo_to_safetensors.py --repo_id GraydientPlatformAPI/goodfit-pony41-xl
|