# vit-gpt2 / create_dummy_pretrained_models.py
# ydshieh
# Add a script to create dummy pretrained models for testing
# fd8c682
from transformers import ViTConfig, FlaxViTModel, GPT2Config, FlaxGPT2Model, FlaxAutoModelForVision2Seq, FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
# Small sizes keep the dummy encoder/decoder lightweight for testing.
# Encoder (ViT) dimensions.
hidden_size, num_hidden_layers, num_attention_heads, intermediate_size = 8, 2, 2, 16
# Decoder (GPT-2) dimensions, spelled with GPT2Config's own parameter names.
n_embd, n_layer, n_head, n_inner = 8, 2, 2, 16

encoder_config = ViTConfig(
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)
decoder_config = GPT2Config(n_embd=n_embd, n_layer=n_layer, n_head=n_head, n_inner=n_inner)
# Randomly initialise each half from its config and write it to its own
# directory, so the composite model can be assembled from the two checkpoints.
encoder = FlaxViTModel(encoder_config)
encoder.save_pretrained("./encoder-decoder/encoder")

decoder = FlaxGPT2Model(decoder_config)
decoder.save_pretrained("./encoder-decoder/decoder")

# Combine the two saved halves into one vision-encoder-decoder checkpoint.
# (The variable keeps its original spelling because later code reads it.)
enocder_decoder = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path="./encoder-decoder/encoder",
    decoder_pretrained_model_name_or_path="./encoder-decoder/decoder",
)
enocder_decoder.save_pretrained("./encoder-decoder")

# Reload the saved checkpoint through the auto class; the reloaded instance
# replaces the in-memory one for the rest of the script.
enocder_decoder = FlaxAutoModelForVision2Seq.from_pretrained("./encoder-decoder")
config = enocder_decoder.config


def _lookup_token_id(model_config, name):
    """Return the special-token id ``name`` from a possibly composite config.

    The top-level value wins when it is set. Otherwise the ``decoder``
    sub-config, if there is one, is consulted. Returns ``None`` when neither
    defines the id.
    """
    token_id = getattr(model_config, name, None)
    # Compare against None rather than relying on truthiness: 0 is a valid
    # token id and must not trigger the fallback.
    if token_id is None:
        sub_config = getattr(model_config, "decoder", None)
        if sub_config is not None:
            token_id = getattr(sub_config, name, None)
    return token_id


def _assign_token_ids(model_config, **token_ids):
    """Set every ``token_ids`` entry on ``model_config`` and on its ``decoder`` sub-config, if present."""
    targets = [model_config]
    sub_config = getattr(model_config, "decoder", None)
    if sub_config is not None:
        targets.append(sub_config)
    for target in targets:
        for name, value in token_ids.items():
            setattr(target, name, value)


decoder_start_token_id = _lookup_token_id(config, "decoder_start_token_id")
bos_token_id = _lookup_token_id(config, "bos_token_id")
eos_token_id = _lookup_token_id(config, "eos_token_id")
pad_token_id = _lookup_token_id(config, "pad_token_id")

# The decoder config may define only BOS/EOS (as is typical for GPT-2 --
# confirm for other decoders), so borrow those for any missing start/pad id.
if decoder_start_token_id is None:
    decoder_start_token_id = bos_token_id
if pad_token_id is None:
    pad_token_id = eos_token_id

# Keep the composite config and the decoder sub-config in agreement.
_assign_token_ids(
    config,
    decoder_start_token_id=decoder_start_token_id,
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
)

# Persist the resolved ids. The checkpoint in ./encoder-decoder was written
# before they were set, so without this they would only exist in memory.
config.save_pretrained("./encoder-decoder")
# Reuse the preprocessing assets of real Hub checkpoints (requires network
# access or a warm local cache), stored next to the matching dummy weights.
fe = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
fe.save_pretrained("./encoder-decoder/encoder")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Use the token that corresponds to the model's resolved pad id as padding.
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
tokenizer.save_pretrained("./encoder-decoder/decoder")
targets = ["i love dog", "you cat is very cute"]

# Encode the sample captions as labels: every row is padded or truncated to
# exactly 8 token ids and returned as NumPy arrays.
# NOTE(review): newer transformers releases deprecate as_target_tokenizer()
# in favour of tokenizer(text_target=...); confirm against the pinned version.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(
        targets,
        max_length=8,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

print(labels)