# OpenLRM / scripts/convert_hf.py
# Last commit: f2a2544 ("Update spaces") by zxhezexin; file size 2.52 kB.
# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from omegaconf import OmegaConf
import torch.nn as nn
from accelerate import Accelerator
import safetensors
import sys
sys.path.append(".")
from openlrm.utils.hf_hub import wrap_model_hub
from openlrm.models import model_dict
def auto_load_model(cfg, model: nn.Module) -> int:
    """Load a trained checkpoint's weights into ``model`` and return its step.

    Checkpoints are expected at
    ``<saver.checkpoint_root>/<experiment.parent>/<experiment.child>/<step>/model.safetensors``,
    where ``<step>`` is a directory whose name is an integer.

    Args:
        cfg: Merged config. Reads ``saver.checkpoint_root``,
            ``experiment.parent``, ``experiment.child`` and, if present,
            ``convert.global_step``. If ``convert.global_step`` is missing or
            ``None``, the numerically highest step directory is used.
            Otherwise that step is used; an exact directory name is preferred,
            with a zero-padded spelling of the same number as fallback.
        model: Module whose weights are loaded in place via
            ``safetensors.torch.load_model``.

    Returns:
        The integer step of the loaded checkpoint.

    Raises:
        FileNotFoundError: If the checkpoint root does not exist, contains no
            integer-named checkpoint directories, or the selected checkpoint
            has no ``model.safetensors`` file.
    """
    ckpt_root = os.path.join(
        cfg.saver.checkpoint_root,
        cfg.experiment.parent, cfg.experiment.child,
    )
    if not os.path.isdir(ckpt_root):
        raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")

    # Only integer-named directories are step checkpoints. Skipping anything
    # else stops stray files/folders from being selected as "latest" and from
    # crashing the final int() conversion.
    step_dirs = [
        name for name in os.listdir(ckpt_root)
        if name.isdecimal() and os.path.isdir(os.path.join(ckpt_root, name))
    ]
    if not step_dirs:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")

    # The `convert` section is optional; a missing section/key means "latest".
    convert_cfg = getattr(cfg, "convert", None)
    global_step = getattr(convert_cfg, "global_step", None)

    if global_step is None:
        # Compare numerically, not lexicographically, so "1000" outranks "999"
        # even when step directories are not zero-padded.
        load_step = max(step_dirs, key=int)
    else:
        load_step = f"{global_step}"
        if load_step not in step_dirs and load_step.isdecimal():
            # Accept a differently padded spelling of the same step,
            # e.g. global_step=5000 -> directory "005000".
            same_step = [name for name in step_dirs if int(name) == int(load_step)]
            if same_step:
                load_step = same_step[0]

    load_model_path = os.path.join(ckpt_root, load_step, 'model.safetensors')
    if not os.path.isfile(load_model_path):
        raise FileNotFoundError(f"Checkpoint file not found: {load_model_path}")
    print(f"Loading from {load_model_path}")
    # Import the submodule explicitly: `import safetensors` on its own does not
    # guarantee that `safetensors.torch` has been loaded.
    import safetensors.torch
    safetensors.torch.load_model(model, load_model_path)
    return int(load_step)
if __name__ == '__main__':
    # Config = YAML file from --config, overridden by any remaining dotlist
    # CLI arguments (e.g. `convert.global_step=1000`).
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./assets/config.yaml')
    args, unknown = parser.parse_known_args()
    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(unknown)
    cfg = OmegaConf.merge(cfg, cli_cfg)

    # Expected [cfg.convert] section:
    #   global_step: int | None  -- checkpoint step to convert; None lets
    #                               auto_load_model pick one automatically
    #   save_dir: str | None     -- release root; defaults to ./exps/releases

    # NOTE(review): `accelerator` is never used below; kept in case anything
    # relies on the side effects of constructing Accelerator().
    accelerator = Accelerator()

    hf_model_cls = wrap_model_hub(model_dict[cfg.experiment.type])
    # to_container converts nested DictConfig/ListConfig values into plain
    # dicts/lists (dict() only converts the top level). This presumably becomes
    # `hf_model.config`, which is handed to save_pretrained below.
    hf_model = hf_model_cls(OmegaConf.to_container(cfg.model, resolve=True))
    loaded_step = auto_load_model(cfg, hf_model)

    # Honour the documented convert.save_dir; fall back to the historical
    # default when it is absent or null.
    save_root = OmegaConf.select(cfg, "convert.save_dir") or "./exps/releases"
    dump_path = os.path.join(
        save_root,
        cfg.experiment.parent, cfg.experiment.child,
        f'step_{loaded_step:06d}',
    )
    print(f"Saving locally to {dump_path}")
    os.makedirs(dump_path, exist_ok=True)
    hf_model.save_pretrained(
        save_directory=dump_path,
        config=hf_model.config,
    )