"""Assemble a tiny ViT encoder + GPT-2 decoder into a Flax
vision-encoder-decoder model, save/reload it, propagate the special token
ids onto the composite config, and demonstrate tokenizing text targets.

The model dimensions are deliberately tiny: this script exercises the
save / ``from_encoder_decoder_pretrained`` / reload wiring, not model
quality.
"""

from transformers import (
    FlaxAutoModelForVision2Seq,
    FlaxGPT2Model,
    FlaxViTModel,
    FlaxVisionEncoderDecoderModel,
    GPT2Config,
    GPT2Tokenizer,
    ViTConfig,
    ViTFeatureExtractor,
)

# --- tiny encoder (ViT) hyper-parameters --------------------------------
hidden_size = 8
num_hidden_layers = 2
num_attention_heads = 2
intermediate_size = 16

# --- tiny decoder (GPT-2) hyper-parameters ------------------------------
n_embd = 8
n_layer = 2
n_head = 2
n_inner = 16

encoder_config = ViTConfig(
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)
decoder_config = GPT2Config(
    n_embd=n_embd,
    n_layer=n_layer,
    n_head=n_head,
    n_inner=n_inner,
)

encoder = FlaxViTModel(encoder_config)
decoder = FlaxGPT2Model(decoder_config)
encoder.save_pretrained("./encoder-decoder/encoder")
decoder.save_pretrained("./encoder-decoder/decoder")

encoder_decoder = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "./encoder-decoder/encoder",
    "./encoder-decoder/decoder",
)
encoder_decoder.save_pretrained("./encoder-decoder")
# Round-trip through the auto class to prove the saved artifact loads.
encoder_decoder = FlaxAutoModelForVision2Seq.from_pretrained("./encoder-decoder")

config = encoder_decoder.config


def _resolve_token_id(cfg, attr):
    """Return ``cfg.<attr>``, falling back to ``cfg.decoder.<attr>``.

    Compares against ``None`` explicitly so a legitimate token id of 0
    is not mistaken for "unset" (the original truthiness test had that
    bug).
    """
    value = getattr(cfg, attr, None)
    if value is None and getattr(cfg, "decoder", None) is not None:
        value = getattr(cfg.decoder, attr, None)
    return value


decoder_start_token_id = _resolve_token_id(config, "decoder_start_token_id")
bos_token_id = _resolve_token_id(config, "bos_token_id")
eos_token_id = _resolve_token_id(config, "eos_token_id")
pad_token_id = _resolve_token_id(config, "pad_token_id")

# Generation needs a decoder-start token; GPT-2 conventionally reuses BOS.
if decoder_start_token_id is None:
    decoder_start_token_id = bos_token_id
# GPT-2 defines no pad token of its own; reuse EOS for padding.
if pad_token_id is None:
    pad_token_id = eos_token_id

config.decoder_start_token_id = decoder_start_token_id
config.bos_token_id = bos_token_id
config.eos_token_id = eos_token_id
config.pad_token_id = pad_token_id
# Keep the nested decoder config in sync with the composite config.
if getattr(config, "decoder", None):
    config.decoder.decoder_start_token_id = decoder_start_token_id
    config.decoder.bos_token_id = bos_token_id
    config.decoder.eos_token_id = eos_token_id
    config.decoder.pad_token_id = pad_token_id

fe = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Give the tokenizer a pad token matching the config's pad id (EOS for GPT-2).
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
fe.save_pretrained("./encoder-decoder/encoder")
tokenizer.save_pretrained("./encoder-decoder/decoder")

targets = ['i love dog', 'you cat is very cute']

# Setup the tokenizer for targets
# NOTE(review): `as_target_tokenizer()` is deprecated in recent
# transformers releases; prefer `tokenizer(text_target=..., ...)` once the
# pinned version allows — confirm against the installed version.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(
        targets, max_length=8, padding="max_length", truncation=True, return_tensors="np"
    )
print(labels)