File size: 714 Bytes
6e1c9c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import argparse
import logging

import numpy as np
import torch
import os
from transformers import AutoConfig, FlaxAutoModelForCausalLM

# Configure root logging for this script: INFO level and above, with a
# timestamp, level and logger name on every record.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
# Logger for this module.
logger = logging.getLogger(__name__)

# Local Hugging Face checkpoint directory containing config.json and the
# PyTorch weights to convert.
model_path = "./distilgpt2-base-pretrained-he"
# Output directory for the converted Flax checkpoint.
save_directory = "./tmp/flax/"

config_path = os.path.join(model_path, 'config.json')

# Fail early with a clear message; otherwise transformers may treat a missing
# local path as a Hub repo id and raise a much less obvious error.
if not os.path.isdir(model_path):
    raise FileNotFoundError(f"Model directory not found: {model_path}")

# Convert a PyTorch checkpoint to Flax: from_pt=True loads the PyTorch weights
# and converts them into Flax parameters. This is slower than loading a native
# Flax checkpoint and requires PyTorch to be installed.
logger.info("Loading config from %s", config_path)
config = AutoConfig.from_pretrained(config_path)
logger.info("Loading PyTorch weights from %s and converting to Flax", model_path)
model = FlaxAutoModelForCausalLM.from_pretrained(model_path, from_pt=True, config=config)
model.save_pretrained(save_directory)
logger.info("Saved Flax model to %s", save_directory)