|
import argparse |
|
|
|
import huggingface_hub |
|
import torch |
|
from vqmodel.configuration_vqmodel import VQModelConfig |
|
from vqmodel.image_processing_vqmodel import VQModelImageProcessor |
|
from vqmodel.modeling_vqmodel import VQModel |
|
|
|
# Register the custom config, model and image-processor classes for
# auto-class loading. This runs at import time, before any ``push_to_hub``
# call below. Presumably this is so the pushed repo records these classes for
# Auto* loading; confirm against the transformers docs.
for _auto_cls in (VQModelConfig, VQModel, VQModelImageProcessor):
    _auto_cls.register_for_auto_class()
del _auto_cls  # keep the module namespace identical to the explicit calls
|
|
|
|
|
def main():
    """Build a VQModel from a checkpoint and publish it to the Hub.

    Steps, all driven by the CLI flags returned by ``parse_args``:
      1. Construct ``VQModel`` from the YAML config and load checkpoint weights.
      2. Create a ``VQModelImageProcessor`` from the model's ``ddconfig``.
      3. Push model and processor to a private repo. Then upload the YAML
         config (as ``config.yaml``), this script and ``requirements.txt``.
    """
    cli_args = parse_args()
    repo = cli_args.repo_id

    vq_model = VQModel(VQModelConfig(yaml_path=cli_args.yaml_path))
    load_model_weights(vq_model, cli_args.ckpt_path)

    # Processor settings come from the model's ddconfig: the resolution, and
    # RGB conversion only when the model takes 3 input channels.
    dd = vq_model.vq_cfg.model.params.ddconfig
    processor = VQModelImageProcessor(
        size=dd.resolution,
        convert_rgb=dd.in_channels == 3,
    )

    # The saved config points at "config.yaml", which is the name the YAML is
    # uploaded under below (presumably so the config resolves against the
    # uploaded copy rather than a local path).
    vq_model.config.repo_id = repo
    vq_model.config.yaml_path = "config.yaml"

    for pushable in (vq_model, processor):
        pushable.push_to_hub(repo, private=True)

    hub = huggingface_hub.HfApi()
    # (local source, destination path in repo), uploaded in this order.
    # NOTE(review): "requirements.txt" is resolved against the current working
    # directory, not this script's directory. Confirm that is intended.
    extra_files = (
        (cli_args.yaml_path, "config.yaml"),
        (__file__, "push_to_hub.py"),
        ("requirements.txt", "requirements.txt"),
    )
    for local_path, remote_path in extra_files:
        hub.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=remote_path,
            repo_id=repo,
        )
|
|
|
|
|
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Every option is a required string:
      --repo_id    Hub repository ID to push to
      --yaml_path  path to the model's YAML config
      --ckpt_path  path to the checkpoint file

    Returns:
        argparse.Namespace with ``repo_id``, ``yaml_path`` and ``ckpt_path``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--repo_id", "Repository ID"),
        ("--yaml_path", "Path to YAML file"),
        ("--ckpt_path", "Path to checkpoint file"),
    )
    for flag, description in option_specs:
        cli.add_argument(flag, type=str, required=True, help=description)
    return cli.parse_args()
|
|
|
|
|
def load_model_weights(model, ckpt_path):
    """Load checkpoint weights into ``model.model``.

    The checkpoint at ``ckpt_path`` is read onto the CPU and its
    ``"state_dict"`` entry is taken. Every key starting with ``loss.`` is
    removed, presumably because these are training-only loss-module weights
    that the wrapped module lacks. The remainder is then loaded with
    ``load_state_dict`` using its default strict matching.

    Args:
        model: object exposing the target ``torch.nn.Module`` as ``.model``.
        ckpt_path: path to a checkpoint dict that contains ``"state_dict"``.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: from ``load_state_dict`` on missing/unexpected keys.
    """
    # NOTE(review): torch.load unpickles the file, so only use trusted
    # checkpoints. On recent torch versions weights_only may default to True,
    # which can reject checkpoints containing non-tensor objects. Confirm
    # against the installed torch version.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # Remove entries in place rather than building a new dict. That keeps the
    # loaded mapping object, including any attributes it carries (e.g. the
    # ``_metadata`` that load_state_dict consults).
    loss_keys = [name for name in state_dict if name.startswith("loss.")]
    for name in loss_keys:
        state_dict.pop(name)

    model.model.load_state_dict(state_dict)
|
|
|
|
|
# Script entry point: run the conversion/upload only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|