"""Push a VQModel checkpoint to the Hugging Face Hub.

Builds a ``VQModel`` from a YAML config, loads checkpoint weights into it,
derives a matching ``VQModelImageProcessor``, and uploads the model, the
processor, the YAML config, this script and ``requirements.txt`` to a Hub
repository.
"""

import argparse
import os

import huggingface_hub
import torch

from vqmodel.configuration_vqmodel import VQModelConfig
from vqmodel.image_processing_vqmodel import VQModelImageProcessor
from vqmodel.modeling_vqmodel import VQModel

# Register the custom classes with the transformers Auto* machinery so the
# pushed repository records them (custom-code / trust_remote_code mechanism).
VQModelConfig.register_for_auto_class()
VQModel.register_for_auto_class()
VQModelImageProcessor.register_for_auto_class()

# File name under which the YAML config is stored inside the Hub repository.
# The saved model config is rewritten to point at this name, so both uses
# must agree.
REPO_YAML_NAME = "config.yaml"


def main():
    """Build the model from CLI arguments and upload everything to the Hub.

    Raises:
        FileNotFoundError: if any local file that will be uploaded is missing.
            This is checked before anything is loaded or pushed, so a missing
            file never leaves a half-populated repository behind.
    """
    args = parse_args()
    private = not args.public

    # (local path, path inside the repo) for the auxiliary files uploaded
    # after the model and processor.
    # NOTE: "requirements.txt" is resolved relative to the current working
    # directory, not relative to this script.
    uploads = (
        (args.yaml_path, REPO_YAML_NAME),
        (__file__, "push_to_hub.py"),
        ("requirements.txt", "requirements.txt"),
    )
    missing = [path for path, _ in uploads if not os.path.isfile(path)]
    if missing:
        raise FileNotFoundError(
            f"Files to upload not found: {', '.join(missing)}"
        )

    config = VQModelConfig(yaml_path=args.yaml_path)
    model = VQModel(config)
    load_model_weights(model, args.ckpt_path)

    # Derive preprocessing settings from the model's ddconfig.
    # NOTE(review): assumes `resolution` is the input size the processor
    # should produce and `in_channels == 3` means RGB input - confirm
    # against VQModelImageProcessor.
    ddconfig = model.vq_cfg.model.params.ddconfig
    image_processor = VQModelImageProcessor(
        size=ddconfig.resolution,
        convert_rgb=ddconfig.in_channels == 3,
    )

    # Make the saved config reference the Hub repo and the uploaded YAML
    # name instead of the local YAML path it was built from.
    model.config.repo_id = args.repo_id
    model.config.yaml_path = REPO_YAML_NAME

    model.push_to_hub(args.repo_id, private=private)
    image_processor.push_to_hub(args.repo_id, private=private)

    api = huggingface_hub.HfApi()
    for local_path, repo_path in uploads:
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=repo_path,
            repo_id=args.repo_id,
        )


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with ``repo_id``, ``yaml_path``, ``ckpt_path``
        (all required strings) and ``public`` (bool, default False).
    """
    parser = argparse.ArgumentParser(
        description="Push a VQModel checkpoint to the Hugging Face Hub."
    )
    parser.add_argument("--repo_id", type=str, required=True, help="Repository ID")
    parser.add_argument(
        "--yaml_path", type=str, required=True, help="Path to YAML file"
    )
    parser.add_argument(
        "--ckpt_path", type=str, required=True, help="Path to checkpoint file"
    )
    # Opt-in so the default remains a private repository.
    parser.add_argument(
        "--public",
        action="store_true",
        help="Create a public repository (default: private)",
    )
    return parser.parse_args()


def load_model_weights(model, ckpt_path):
    """Load checkpoint weights into ``model.model``.

    Args:
        model: VQModel wrapper whose ``.model`` attribute receives the weights.
        ckpt_path: Path to a checkpoint file containing a ``"state_dict"``
            entry.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: from ``load_state_dict`` (strict mode) if the remaining
            keys do not match ``model.model``.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects; only pass
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Drop entries under "loss." - presumably training-only loss modules
    # that model.model does not contain (strict loading would reject them).
    state_dict = {
        key: value
        for key, value in ckpt["state_dict"].items()
        if not key.startswith("loss.")
    }
    model.model.load_state_dict(state_dict)


if __name__ == "__main__":
    main()