# Doron Adler
# * Updated model card
# 6e1c9c6
# raw
# history blame contribute delete
# No virus
# 714 Bytes
import argparse
import logging
import numpy as np
import torch
import os
from transformers import AutoConfig, FlaxAutoModelForCausalLM
# Configure the root logger: INFO level, records rendered as
# "<timestamp> - <level> - <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Directory holding the checkpoint to convert (its config.json is read below).
model_path = "./distilgpt2-base-pretrained-he"
# Directory the converted Flax model is written to.
save_directory = "./tmp/flax/"
config_path = os.path.join(model_path, 'config.json')

# Read the model configuration from the checkpoint's config.json.
config = AutoConfig.from_pretrained(config_path)

# Instantiate the Flax causal-LM model from the checkpoint in `model_path`;
# from_pt=True asks transformers to load from PyTorch weights.
model = FlaxAutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    from_pt=True,
)

# Persist the converted model under `save_directory`.
model.save_pretrained(save_directory)