File size: 514 Bytes
98c533d
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import argparse
from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration

def main(args):
    pt_model = T5ForConditionalGeneration.from_pretrained(args.model_dir, from_flax=True)
    pt_model.save_pretrained(args.model_dir)
    tf_model = TFT5ForConditionalGeneration.from_pretrained(args.model_dir, from_pt=True)
    tf_model.save_pretrained(args.model_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='.')