# ldm-vq-f16 / push_to_hub.py
import argparse

import huggingface_hub
import torch

from vqmodel.configuration_vqmodel import VQModelConfig
from vqmodel.image_processing_vqmodel import VQModelImageProcessor
from vqmodel.modeling_vqmodel import VQModel
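
# Register the custom classes so they can be loaded from the Hub via the
# Auto* classes (AutoConfig / AutoModel / AutoImageProcessor) with trust_remote_code=True.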
VQModelConfig.register_for_auto_class()
VQModel.register_for_auto_class()
VQModelImageProcessor.register_for_auto_class()


def main():
    args = parse_args()

    # Build the config/model from the original LDM YAML and load the trained weights
    config = VQModelConfig(yaml_path=args.yaml_path)
    model = VQModel(config)
    load_model_weights(model, args.ckpt_path)

    # Define image processor
    ddconfig = model.vq_cfg.model.params.ddconfig
    image_processor = VQModelImageProcessor(
        size=ddconfig.resolution,
        convert_rgb=ddconfig.in_channels == 3,
    )

    # Edit config so the pushed model points at the target repo and the uploaded YAML
    model.config.repo_id = args.repo_id
    model.config.yaml_path = "config.yaml"

    # Push to hub
    model.push_to_hub(args.repo_id, private=True)
    image_processor.push_to_hub(args.repo_id, private=True)
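
    # Also upload the raw LDM YAML (referenced by config.yaml_path above),
    # this script, and requirements.txt to the same repo.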
    api = huggingface_hub.HfApi()
    api.upload_file(
        path_or_fileobj=args.yaml_path,
        path_in_repo="config.yaml",
        repo_id=args.repo_id,
    )
    api.upload_file(
        path_or_fileobj=__file__,
        path_in_repo="push_to_hub.py",
        repo_id=args.repo_id,
    )
    api.upload_file(
        path_or_fileobj="requirements.txt",
        path_in_repo="requirements.txt",
        repo_id=args.repo_id,
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_id", type=str, required=True, help="Repository ID")
    parser.add_argument(
        "--yaml_path", type=str, required=True, help="Path to YAML file"
    )
    parser.add_argument(
        "--ckpt_path", type=str, required=True, help="Path to checkpoint file"
    )
    return parser.parse_args()


def load_model_weights(model, ckpt_path):
    # Load checkpoint
    ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # Remove loss-related states (e.g. discriminator / perceptual-loss weights
    # stored in the training checkpoint); they are not part of the inference model
    for key in list(ckpt.keys()):
        if key.startswith("loss."):
            del ckpt[key]

    model.model.load_state_dict(ckpt)


if __name__ == "__main__":
    main()
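
# Example invocation (repo id and paths below are placeholders):
#   python push_to_hub.py \
#       --repo_id <user>/ldm-vq-f16 \
#       --yaml_path path/to/vq-f16.yaml \
#       --ckpt_path path/to/model.ckpt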