ydshieh committed on
Commit
fd8c682
1 Parent(s): 8ffa189

Add a script to create dummy pretrained models for testing

Browse files
create_dummy_pretrained_models.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Create tiny ViT + GPT-2 encoder-decoder checkpoints for testing.

The script builds a small ViT encoder and a small GPT-2 decoder from configs
(no pretrained weights are loaded for the models), saves them, combines them
into a ``FlaxVisionEncoderDecoderModel``, and reloads the combined checkpoint
through ``FlaxAutoModelForVision2Seq``. It then synchronises the special token
ids between the top-level config and the decoder config, stores a feature
extractor and a tokenizer next to the sub-models, and prints a tokenized
sample batch.
"""

from transformers import (
    FlaxAutoModelForVision2Seq,
    FlaxGPT2Model,
    FlaxVisionEncoderDecoderModel,
    GPT2Config,
    GPT2Tokenizer,
    ViTConfig,
    ViTFeatureExtractor,
)


def _resolve_token_id(config, name):
    """Return ``config.<name>``, falling back to ``config.decoder.<name>``.

    The fallback is only used when the top-level value is ``None``. A plain
    truthiness test would be wrong here because ``0`` is a valid token id.
    """
    token_id = getattr(config, name, None)
    decoder_config = getattr(config, "decoder", None)
    if token_id is None and decoder_config is not None:
        token_id = getattr(decoder_config, name, None)
    return token_id


OUTPUT_DIR = "./encoder-decoder"
ENCODER_DIR = f"{OUTPUT_DIR}/encoder"
DECODER_DIR = f"{OUTPUT_DIR}/decoder"

# ViT (encoder) dimensions, kept tiny so the checkpoint is cheap to create.
hidden_size = 8
num_hidden_layers = 2
num_attention_heads = 2
intermediate_size = 16

# GPT-2 (decoder) dimensions.
n_embd = 8
n_layer = 2
n_head = 2
n_inner = 16

encoder_config = ViTConfig(
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
)
decoder_config = GPT2Config(
    n_embd=n_embd,
    n_layer=n_layer,
    n_head=n_head,
    n_inner=n_inner,
)
encoder = FlaxViTModel(encoder_config) if False else None  # placeholder removed below
dataset_example.py → dataset_usage_example.py RENAMED
File without changes