File size: 2,998 Bytes
fd8c682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from transformers import ViTConfig, FlaxViTModel, GPT2Config, FlaxGPT2Model, FlaxAutoModelForVision2Seq, FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer


# Small dimensions for the ViT encoder (presumably chosen so the script runs fast).
hidden_size, num_hidden_layers, num_attention_heads, intermediate_size = 8, 2, 2, 16

# Matching small dimensions for the GPT-2 decoder; n_embd equals the encoder's
# hidden_size above.
n_embd, n_layer, n_head, n_inner = 8, 2, 2, 16

encoder_config = ViTConfig(
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
)
decoder_config = GPT2Config(
    n_layer=n_layer,
    n_head=n_head,
    n_embd=n_embd,
    n_inner=n_inner,
)
# Instantiate randomly initialised encoder and decoder models and write each one to
# its own directory, so they can be recombined via `from_encoder_decoder_pretrained`.
encoder = FlaxViTModel(encoder_config)
decoder = FlaxGPT2Model(decoder_config)
encoder.save_pretrained("./encoder-decoder/encoder")
decoder.save_pretrained("./encoder-decoder/decoder")

# Combine both checkpoints into one vision encoder-decoder model, save it, and reload
# it through the auto class (presumably to check that the saved checkpoint loads).
# NOTE(review): this checkpoint is written before any token-id adjustments are made
# to `config` later in the script; re-save afterwards if those should persist.
encoder_decoder = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "./encoder-decoder/encoder",
    "./encoder-decoder/decoder",
)
encoder_decoder.save_pretrained("./encoder-decoder")
encoder_decoder = FlaxAutoModelForVision2Seq.from_pretrained("./encoder-decoder")


config = encoder_decoder.config

def _resolve_token_id(cfg, name):
    """Look up a special-token id on ``cfg``, falling back to ``cfg.decoder``.

    Args:
        cfg: The composite encoder-decoder config.
        name: Attribute name of the token id, e.g. ``"bos_token_id"``.

    Returns:
        ``cfg.<name>`` when it is not ``None``; otherwise ``cfg.decoder.<name>`` when
        a decoder sub-config exists; otherwise ``None``.
    """
    token_id = getattr(cfg, name, None)
    # Compare against None instead of testing truthiness: 0 is a valid token id and
    # must not be mistaken for "unset".
    if token_id is None:
        decoder_cfg = getattr(cfg, "decoder", None)
        if decoder_cfg is not None:
            token_id = getattr(decoder_cfg, name, None)
    return token_id


decoder_start_token_id = _resolve_token_id(config, "decoder_start_token_id")
bos_token_id = _resolve_token_id(config, "bos_token_id")
eos_token_id = _resolve_token_id(config, "eos_token_id")
pad_token_id = _resolve_token_id(config, "pad_token_id")

# Fill remaining gaps: start decoding from BOS and pad with EOS when no dedicated
# token is configured.
if decoder_start_token_id is None:
    decoder_start_token_id = bos_token_id
if pad_token_id is None:
    pad_token_id = eos_token_id

# Write the resolved ids back to the top-level config and, when present, to the
# decoder sub-config, so that both agree. These changes only affect the in-memory
# config.
for _cfg in (config, getattr(config, "decoder", None)):
    if _cfg is None:
        continue
    _cfg.decoder_start_token_id = decoder_start_token_id
    _cfg.bos_token_id = bos_token_id
    _cfg.eos_token_id = eos_token_id
    _cfg.pad_token_id = pad_token_id

# Image preprocessor for the encoder and tokenizer for the decoder, loaded by name.
fe = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Make the tokenizer pad with the token whose id the model config uses for padding.
# NOTE(review): config.pad_token_id is passed straight to convert_ids_to_tokens;
# this assumes it is not None at this point — confirm.
pad_token = tokenizer.convert_ids_to_tokens(config.pad_token_id)
tokenizer.pad_token = pad_token

# Store each preprocessor next to the sub-model checkpoint it belongs to.
for processor, save_dir in (
    (fe, "./encoder-decoder/encoder"),
    (tokenizer, "./encoder-decoder/decoder"),
):
    processor.save_pretrained(save_dir)

# Example caption strings to encode as decoder labels.
targets = ['i love dog', 'you cat is very cute']

# Tokenize the captions in target mode: each sequence is padded/truncated to exactly
# 8 tokens, and the result is returned as NumPy arrays.
label_kwargs = dict(
    max_length=8,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
with tokenizer.as_target_tokenizer():
    labels = tokenizer(targets, **label_kwargs)

print(labels)