"""Convert a causal-LM checkpoint stored with PyTorch weights into a Flax checkpoint.

The model directory given by ``--model_path`` is loaded with
``FlaxAutoModelForCausalLM.from_pretrained(..., from_pt=True)``. That call reads
the PyTorch weights and converts them to Flax parameters. The resulting Flax
model is then written to ``--save_directory`` with ``save_pretrained``.

Running the script with no arguments uses the original hard-coded paths.
"""
import argparse
import logging
import os

import numpy as np  # noqa: F401  (unused here; kept from the original script)
import torch  # noqa: F401  (from_pt conversion requires PyTorch to be installed)
from transformers import AutoConfig, FlaxAutoModelForCausalLM

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Default locations, preserved from the original hard-coded script.
model_path = "./distilgpt2-base-pretrained-he"
save_directory = "./tmp/flax/"


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_path`` and ``save_directory``.
        Both default to the module-level values above.
    """
    parser = argparse.ArgumentParser(
        description="Convert a PyTorch causal-LM checkpoint to Flax."
    )
    parser.add_argument(
        "--model_path",
        default=model_path,
        help="Directory holding config.json and the PyTorch weights.",
    )
    parser.add_argument(
        "--save_directory",
        default=save_directory,
        help="Directory the converted Flax model is written to.",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load the PyTorch checkpoint, convert it to Flax, and save it.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    config_path = os.path.join(args.model_path, "config.json")
    logger.info("Loading config from %s", config_path)
    config = AutoConfig.from_pretrained(config_path)

    # from_pt=True: weights are read from a PyTorch checkpoint and converted
    # to Flax params on the fly (slower than loading native Flax weights).
    logger.info("Loading PyTorch weights from %s", args.model_path)
    model = FlaxAutoModelForCausalLM.from_pretrained(
        args.model_path, from_pt=True, config=config
    )

    model.save_pretrained(args.save_directory)
    logger.info("Saved Flax model to %s", args.save_directory)


if __name__ == "__main__":
    main()