# FLUX-VisionReply / core/utils/save_and_load.py (revision 2f4febc)
# Helpers for saving and loading checkpoints (.pt, .ckpt, .json, .safetensors).
import os
import torch
import json
from pathlib import Path
import safetensors
import wandb
def create_folder_if_necessary(path):
    """Create the directory that will contain ``path`` if it does not exist.

    Args:
        path: File path (``str`` or ``os.PathLike``) whose containing
            directory should exist. Everything before the last path
            separator is treated as the directory, so a path ending in a
            separator (``"a/b/"``) creates ``a/b`` itself.
    """
    # os.path.dirname honours the platform's separators (the previous
    # hard-coded "/" split did not) and, unlike Path.parent, keeps the
    # "everything before the last separator" meaning for trailing slashes.
    parent = os.path.dirname(os.fspath(path))
    # A bare filename yields "", which Path treats as the current directory;
    # exist_ok=True makes that (and any already-existing folder) a no-op.
    Path(parent).mkdir(parents=True, exist_ok=True)
def safe_save(ckpt, path):
    """Save ``ckpt`` to ``path``, keeping the previous file as ``<path>.bak``.

    The serialization format is chosen from the file extension:

    * ``.pt`` / ``.ckpt``  -> ``torch.save``
    * ``.json``            -> ``json.dump`` (UTF-8, ``indent=4``)
    * ``.safetensors``     -> ``safetensors.torch.save_file``

    Args:
        ckpt: Object to serialize; it must be compatible with the writer
            selected by the extension.
        path: Destination file path (``str`` or ``os.PathLike``).

    Raises:
        ValueError: If the extension is not one of the supported ones. This
            is checked before any file is touched, so an unsupported path
            never loses or moves an existing file.
    """
    path = os.fspath(path)
    supported_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    # Validate first: rotating the backup and *then* rejecting the extension
    # would leave the caller's existing file renamed to ".bak".
    if not path.endswith(supported_extensions):
        raise ValueError(f"File extension not supported: {path}")

    # Rotate the current file into the backup slot. os.replace overwrites an
    # existing backup in one step on every platform, and leaves the old
    # backup alone when there is no current file to replace it with.
    try:
        os.replace(path, f"{path}.bak")
    except OSError:
        # Best effort: nothing to back up (first save) or rotation failed.
        pass

    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    else:
        # Import the submodule explicitly: a bare `import safetensors` is not
        # guaranteed to expose `safetensors.torch` as an attribute.
        from safetensors.torch import save_file

        save_file(ckpt, path)
def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint from ``path``, choosing the reader by extension.

    Supported extensions are ``.pt`` / ``.ckpt`` (``torch.load`` onto CPU),
    ``.json`` (``json.load``, UTF-8) and ``.safetensors`` (every tensor read
    onto CPU into a plain ``dict``).

    Args:
        path: Checkpoint file path (``str`` or ``os.PathLike``).
        wandb_run_id: If not ``None``, a ``wandb.alert`` naming this run is
            sent before any failure is re-raised.

    Returns:
        The loaded object, or ``None`` when ``path`` has a supported
        extension but the file does not exist.

    Raises:
        AssertionError: If the extension is not supported.
        Exception: Any error raised by the underlying loader is re-raised
            unchanged after the optional alert.
    """
    accepted_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    try:
        path = os.fspath(path)
        # Raised explicitly instead of via `assert`, which is stripped under
        # `python -O` and would let an unsupported path fall through. The
        # exception type is kept so existing callers see the same error.
        if not path.endswith(accepted_extensions):
            raise AssertionError(
                f"Automatic loading not supported for this extension: {path}"
            )

        # A missing file is not an error: the caller gets None.
        if not os.path.exists(path):
            return None

        if path.endswith((".pt", ".ckpt")):
            # NOTE(security): torch.load is pickle-based; only load
            # checkpoints from trusted sources.
            return torch.load(path, map_location="cpu")

        if path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)

        # Remaining accepted extension: .safetensors
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        # Bare raise re-raises the active exception with its traceback intact.
        raise