| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import fire |
|
|
| from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
|
|
def save_randomly_initialized_version(config_name: str, save_dir: str, **config_kwargs):
    """Create a randomly initialized seq2seq model from a pretrained config and save it.

    The architecture comes from ``config_name``, optionally overridden by
    ``config_kwargs``. The model is built with ``from_config``, so no pretrained
    weights are loaded. The tokenizer for ``config_name`` is written to the same
    directory as the model.

    Args:
        config_name: which config to use (forwarded to ``AutoConfig.from_pretrained``
            and ``AutoTokenizer.from_pretrained``).
        save_dir: where to save the resulting model and tokenizer.
        config_kwargs: extra keyword arguments passed to ``AutoConfig.from_pretrained``,
            e.g. to shrink the layer counts.

    Returns:
        The freshly constructed (untrained) model instance that was saved.

    Usage::

        save_randomly_initialized_version("facebook/bart-large-cnn", "distilbart_random_cnn_6_3", encoder_layers=6, decoder_layers=3, num_beams=3)
    """
    config = AutoConfig.from_pretrained(config_name, **config_kwargs)
    random_model = AutoModelForSeq2SeqLM.from_config(config)
    random_model.save_pretrained(save_dir)

    # The tokenizer is fetched only after the model has been written, matching
    # the original ordering of side effects.
    tokenizer = AutoTokenizer.from_pretrained(config_name)
    tokenizer.save_pretrained(save_dir)

    return random_model
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand save_randomly_initialized_version to python-fire
    # so it can be invoked from the command line.
    fire.Fire(save_randomly_initialized_version)
|
|