"""Load a local checkpoint and push it to a model hub repository.

The script looks up a model class in ``LHM.models.model_dict`` by the
``--model_type`` key, wraps it with ``LHM.utils.hf_hub.wrap_model_hub``,
loads it from ``--local_ckpt`` via ``from_pretrained`` and then calls
``push_to_hub`` with ``private=True``.

The relative ``sys.path`` entry below means it is presumably meant to be
run from the repository root so that the ``LHM`` package is importable
(NOTE(review): confirm the intended working directory).
"""

import sys

# Must run before the ``LHM`` imports below: adds the current working
# directory to the module search path.
sys.path.append(".")

import argparse  # noqa: E402

from accelerate import Accelerator  # noqa: E402

from LHM.models import model_dict  # noqa: E402
from LHM.utils.hf_hub import wrap_model_hub  # noqa: E402


def main():
    """Parse the command line, load the checkpoint and upload it."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        type=str,
        required=True,
        help="key used to look up the model class in LHM.models.model_dict",
    )
    parser.add_argument(
        "--local_ckpt",
        type=str,
        required=True,
        help="checkpoint location handed to from_pretrained()",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        required=True,
        help="repo_id handed to push_to_hub()",
    )
    # parse_known_args() tolerates extra command-line arguments instead of
    # exiting with an error; the unrecognized ones are deliberately ignored.
    args, _unknown = parser.parse_known_args()

    # The instance is not used afterwards. It is kept because constructing
    # it may initialize process/device state that loading relies on.
    # NOTE(review): confirm whether this is actually required here.
    accelerator = Accelerator()  # noqa: F841

    # An unknown --model_type raises KeyError from this lookup.
    hf_model_cls = wrap_model_hub(model_dict[args.model_type])
    hf_model = hf_model_cls.from_pretrained(args.local_ckpt)
    hf_model.push_to_hub(
        repo_id=args.repo_id,
        config=hf_model.config,
        private=True,
    )


if __name__ == "__main__":
    main()