"""Export a ReLiK reader PyTorch Lightning checkpoint and load such exports.

Run as a script, this converts a Lightning checkpoint into a directory
holding the bare core model (``model.ckpt``) and its config (``cfg.yaml``).
``load_model_and_conf`` reads a directory with that layout back.
"""

import argparse
import os
import sys
from typing import Tuple

import omegaconf
import torch

from relik.common.utils import from_cache
from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
from relik.reader.relik_reader_core import RelikReaderCoreModel

# File names used inside an exported model directory.
CKPT_FILE_NAME = "model.ckpt"
CONFIG_FILE_NAME = "cfg.yaml"


def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
    """Extract the core model and config from a Lightning checkpoint.

    Creates ``output_dir``, loads the ``RelikReaderPLModule`` stored at
    ``pl_module_ckpt_path``, saves its ``relik_reader_core_model`` attribute
    with ``torch.save`` as ``CKPT_FILE_NAME`` and writes ``hparams["cfg"]``
    as ``CONFIG_FILE_NAME`` through OmegaConf.

    Args:
        pl_module_ckpt_path: Path to the PyTorch Lightning checkpoint.
        output_dir: Directory to create for the exported files. It must not
            exist yet.

    Raises:
        SystemExit: With exit code 1 if ``output_dir`` already exists.
    """
    # Create-and-check in a single call so there is no window between testing
    # for the directory and creating it.
    try:
        os.makedirs(output_dir)
    except FileExistsError:
        print(f"{output_dir} already exists, aborting operation")
        # sys.exit rather than the `exit` builtin, which is only injected by
        # the `site` module and is missing under `python -S`.
        sys.exit(1)

    relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
        pl_module_ckpt_path
    )

    # The whole module object is pickled here (not just a state dict), which
    # is what load_model_and_conf expects to get back from torch.load.
    torch.save(
        relik_pl_module.relik_reader_core_model,
        os.path.join(output_dir, CKPT_FILE_NAME),
    )

    config_path = os.path.join(output_dir, CONFIG_FILE_NAME)
    with open(config_path, "w", encoding="utf-8") as f:
        omegaconf.OmegaConf.save(
            omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
        )


def load_model_and_conf(
    model_dir_path: str,
) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]:
    """Load an exported core model and its config.

    Args:
        model_dir_path: Identifier handed to ``from_cache`` to resolve the
            directory containing ``CKPT_FILE_NAME`` and ``CONFIG_FILE_NAME``.
            NOTE(review): per the TODO below this may be a HF hub model name
            as well as a local path - confirm against ``from_cache``.

    Returns:
        A ``(model, config)`` pair. The model is loaded onto the CPU.
    """
    # TODO: quick workaround to load the model from HF hub
    model_dir = from_cache(
        model_dir_path,
        filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
        cache_dir=None,
        force_download=False,
    )

    ckpt_path = os.path.join(model_dir, CKPT_FILE_NAME)
    # NOTE(security): this unpickles a full model object, which can execute
    # arbitrary code; only load checkpoints from trusted sources.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects pickled module objects - confirm the
    # supported torch range and pass weights_only explicitly if needed.
    model = torch.load(ckpt_path, map_location=torch.device("cpu"))

    model_cfg_path = os.path.join(model_dir, CONFIG_FILE_NAME)
    model_conf = omegaconf.OmegaConf.load(model_cfg_path)
    return model, model_conf


def parse_arg() -> argparse.Namespace:
    """Parse the command line options ``--ckpt`` and ``--output-dir``/``-o``.

    Both options are required.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt",
        help="Path to the pytorch lightning ckpt you want to convert.",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "-o",
        help="The output dir to store the bare models and the config.",
        required=True,
    )
    return parser.parse_args()


def main() -> None:
    """Script entry point: convert the checkpoint named on the command line."""
    args = parse_arg()
    convert_pl_module(args.ckpt, args.output_dir)


if __name__ == "__main__":
    main()