ydshieh committed on
Commit
7a16b57
1 Parent(s): 70dcb42

upload ckpt 5

Browse files
.gitattributes CHANGED
@@ -20,3 +20,4 @@ wit_data_dir/dev/dev.tsv filter=lfs diff=lfs merge=lfs -text
20
  wit_data_dir/test/test.tsv filter=lfs diff=lfs merge=lfs -text
21
  train.json filter=lfs diff=lfs merge=lfs -text
22
  val.json filter=lfs diff=lfs merge=lfs -text
 
 
20
  wit_data_dir/test/test.tsv filter=lfs diff=lfs merge=lfs -text
21
  train.json filter=lfs diff=lfs merge=lfs -text
22
  val.json filter=lfs diff=lfs merge=lfs -text
23
+ outputs/ckpt_5/flax_model.msgpack filter=lfs diff=lfs merge=lfs -text
generate.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, datasets, json
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+ # jax
7
+ import jax
8
+
9
+ # Main model - ViTGPT2LM
10
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
11
+
12
+ # Vit - as encoder
13
+ from transformers import ViTFeatureExtractor
14
+ from PIL import Image
15
+ import requests
16
+ import numpy as np
17
+
18
+ # GPT2 / GPT2LM - as decoder
19
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
20
+
21
# Checkpoint to evaluate; the model weights and config are read from
# ./outputs/ckpt_<ckpt_no>/.
ckpt_no = 5
model_name_or_path = f'./outputs/ckpt_{ckpt_no}/'
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
    model_name_or_path
)

# Image preprocessing comes from the pretrained ViT checkpoint.
vit_model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)

# Tokenizer comes from the pretrained GPT-2 checkpoint used on the decoder side.
gpt2_model_name = 'asi/gpt-fr-cased-small'
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Keyword arguments forwarded to `generate` on every prediction.
max_length = 32
num_beams = 8
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
34
+
35
+
36
@jax.jit
def predict_fn(pixel_values):
    """Run caption generation on a batch of preprocessed pixel values.

    Forwards `pixel_values` to the module-level `flax_vit_gpt2_lm.generate`
    together with the module-level `gen_kwargs`, and returns whatever
    `generate` returns.
    """
    generation = flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
    return generation
40
+
41
def predict(image):
    """Caption a single image.

    Args:
        image: the image to caption, passed as `images=` to the module-level
            `feature_extractor`.

    Returns:
        A `(caption, token_ids)` tuple: the text produced by
        `tokenizer.decode` and the first generated sequence as a NumPy array.
    """
    # The feature extractor adds the batch dimension on its own.
    pixel_values = feature_extractor(images=image, return_tensors="jax").pixel_values

    # Generation step (jitted wrapper around the model's `generate`).
    output = predict_fn(pixel_values)

    # Keep only the first sequence of the batch and decode it as-is.
    token_ids = np.array(output.sequences)[0]
    return tokenizer.decode(token_ids), token_ids
54
+
55
+
56
if __name__ == '__main__':
    # Script entry point:
    #   1. caption one hard-coded COCO image and print the raw result;
    #   2. caption the first 100 examples of the dataset's 'train' split,
    #      printing the time taken, the reference caption and the prediction;
    #   3. dump every example, with its prediction under 'pred', to
    #      ckpt_<ckpt_no>_preds.json.

    from datetime import datetime

    split = 'val'
    image_id = 322141
    p = f'/home/33611/caption/{split}2014/COCO_{split}2014_{str(image_id).zfill(12)}.jpg'
    # Use a context manager so the file handle is released even when
    # `predict` raises.
    with Image.open(p) as image:
        caption, token_ids = predict(image)

    print(f'token_ids: {token_ids}')
    print(f'caption: {caption}')

    ds = datasets.load_dataset('./coco_dataset_script.py', data_dir='/home/33611/caption/')
    ds = ds['train']
    ds = ds.select(range(100))

    predictions = []
    for ex in ds:

        with Image.open(ex['image_file']) as image:
            s = datetime.now()
            caption, token_ids = predict(image)
            # Strip the special tokens left in the decoded text.
            caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '').strip()
        # As before, the timed span includes closing the image.
        elapsed = (datetime.now() - s).total_seconds()

        print(f' timing: {elapsed}')
        print(f' caption: {ex["fr"]}')
        print(f'prediction: {caption}')
        print('-' * 20)

        # Keep the full example next to its prediction for later inspection.
        ex['pred'] = caption
        predictions.append(ex)

    with open(f'ckpt_{ckpt_no}_preds.json', 'w', encoding='UTF-8') as fp:
        json.dump(predictions, fp, ensure_ascii=False, indent=4)
outputs/ckpt_5/config.json ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViTGPT2LMForConditionalGeneration"
4
+ ],
5
+ "bos_token_id": 0,
6
+ "decoder_start_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gpt2_config": {
9
+ "_name_or_path": "",
10
+ "activation_function": "gelu_new",
11
+ "add_cross_attention": true,
12
+ "architectures": null,
13
+ "attn_pdrop": 0.1,
14
+ "bad_words_ids": null,
15
+ "bos_token_id": 0,
16
+ "chunk_size_feed_forward": 0,
17
+ "decoder_start_token_id": null,
18
+ "diversity_penalty": 0.0,
19
+ "do_sample": false,
20
+ "early_stopping": false,
21
+ "embd_pdrop": 0.1,
22
+ "encoder_no_repeat_ngram_size": 0,
23
+ "eos_token_id": 2,
24
+ "finetuning_task": null,
25
+ "forced_bos_token_id": null,
26
+ "forced_eos_token_id": null,
27
+ "gradient_checkpointing": false,
28
+ "id2label": {
29
+ "0": "LABEL_0",
30
+ "1": "LABEL_1"
31
+ },
32
+ "initializer_range": 0.02,
33
+ "is_decoder": false,
34
+ "is_encoder_decoder": false,
35
+ "label2id": {
36
+ "LABEL_0": 0,
37
+ "LABEL_1": 1
38
+ },
39
+ "layer_norm_epsilon": 1e-05,
40
+ "length_penalty": 1.0,
41
+ "max_length": 20,
42
+ "min_length": 0,
43
+ "model_type": "gpt2",
44
+ "n_ctx": 1024,
45
+ "n_embd": 768,
46
+ "n_head": 12,
47
+ "n_inner": null,
48
+ "n_layer": 12,
49
+ "n_positions": 1024,
50
+ "no_repeat_ngram_size": 0,
51
+ "num_beam_groups": 1,
52
+ "num_beams": 1,
53
+ "num_return_sequences": 1,
54
+ "output_attentions": false,
55
+ "output_hidden_states": false,
56
+ "output_scores": false,
57
+ "pad_token_id": 1,
58
+ "prefix": null,
59
+ "problem_type": null,
60
+ "pruned_heads": {},
61
+ "remove_invalid_values": false,
62
+ "repetition_penalty": 1.0,
63
+ "resid_pdrop": 0.1,
64
+ "return_dict": true,
65
+ "return_dict_in_generate": false,
66
+ "scale_attn_weights": true,
67
+ "sep_token_id": null,
68
+ "summary_activation": null,
69
+ "summary_first_dropout": 0.1,
70
+ "summary_proj_to_labels": true,
71
+ "summary_type": "cls_index",
72
+ "summary_use_proj": true,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tie_encoder_decoder": false,
76
+ "tie_word_embeddings": true,
77
+ "tokenizer_class": null,
78
+ "top_k": 50,
79
+ "top_p": 1.0,
80
+ "torch_dtype": null,
81
+ "torchscript": false,
82
+ "transformers_version": "4.9.0.dev0",
83
+ "use_bfloat16": false,
84
+ "use_cache": true,
85
+ "vocab_size": 50000
86
+ },
87
+ "is_encoder_decoder": true,
88
+ "model_type": "vit-gpt2",
89
+ "pad_token_id": 1,
90
+ "transformers_version": null,
91
+ "vit_config": {
92
+ "_name_or_path": "",
93
+ "add_cross_attention": false,
94
+ "architectures": [
95
+ "ViTModel"
96
+ ],
97
+ "attention_probs_dropout_prob": 0.0,
98
+ "bad_words_ids": null,
99
+ "bos_token_id": null,
100
+ "chunk_size_feed_forward": 0,
101
+ "decoder_start_token_id": null,
102
+ "diversity_penalty": 0.0,
103
+ "do_sample": false,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "finetuning_task": null,
108
+ "forced_bos_token_id": null,
109
+ "forced_eos_token_id": null,
110
+ "hidden_act": "gelu",
111
+ "hidden_dropout_prob": 0.0,
112
+ "hidden_size": 768,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "initializer_range": 0.02,
119
+ "intermediate_size": 3072,
120
+ "is_decoder": false,
121
+ "is_encoder_decoder": false,
122
+ "label2id": {
123
+ "LABEL_0": 0,
124
+ "LABEL_1": 1
125
+ },
126
+ "layer_norm_eps": 1e-12,
127
+ "length_penalty": 1.0,
128
+ "max_length": 20,
129
+ "min_length": 0,
130
+ "model_type": "vit",
131
+ "no_repeat_ngram_size": 0,
132
+ "num_attention_heads": 12,
133
+ "num_beam_groups": 1,
134
+ "num_beams": 1,
135
+ "num_channels": 3,
136
+ "num_hidden_layers": 12,
137
+ "num_return_sequences": 1,
138
+ "output_attentions": false,
139
+ "output_hidden_states": false,
140
+ "output_scores": false,
141
+ "pad_token_id": null,
142
+ "patch_size": 16,
143
+ "prefix": null,
144
+ "problem_type": null,
145
+ "pruned_heads": {},
146
+ "remove_invalid_values": false,
147
+ "repetition_penalty": 1.0,
148
+ "return_dict": true,
149
+ "return_dict_in_generate": false,
150
+ "sep_token_id": null,
151
+ "task_specific_params": null,
152
+ "temperature": 1.0,
153
+ "tie_encoder_decoder": false,
154
+ "tie_word_embeddings": true,
155
+ "tokenizer_class": null,
156
+ "top_k": 50,
157
+ "top_p": 1.0,
158
+ "torch_dtype": null,
159
+ "torchscript": false,
160
+ "transformers_version": "4.9.0.dev0",
161
+ "use_bfloat16": false
162
+ }
163
+ }
outputs/ckpt_5/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8e99c510ec8b0373084cfb90b85e8555fc8dd31c6a5b34bcb4e0da6688f750a
3
+ size 1012706583
outputs/events.out.tfevents.1626474479.t1v-n-cab111a8-w-0.878944.3.v2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4e8a3d0a6269cbdb20e41469f2d609d57a48438b9fa13117dae18ce8aa723563
3
- size 116380
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26719e710837c70e552401ff0a0d54bdbb82a6bd98e373ff04766dcca1c279f2
3
+ size 228985
outputs/summary.txt CHANGED
@@ -4,3 +4,9 @@ Epoch... (2/10 | Loss: 2.1883292198181152, Learning Rate: 1.6007936210371554e-05
4
  Epoch... (2/10 | Eval Loss: 2.2480881214141846 | Eval rouge1: 15.861 | Eval rouge2: 3.108 | Eval rougeL: 13.6457 | Eval rougeLsum: 13.6531 | Eval gen_len: 31.5794 |)
5
  Epoch... (3/10 | Loss: 2.1005117893218994, Learning Rate: 1.4007936442794744e-05)
6
  Epoch... (3/10 | Eval Loss: 2.182466506958008 | Eval rouge1: 18.7278 | Eval rouge2: 3.4425 | Eval rougeL: 15.3744 | Eval rougeLsum: 15.3757 | Eval gen_len: 31.9742 |)
 
 
 
 
 
 
 
4
  Epoch... (2/10 | Eval Loss: 2.2480881214141846 | Eval rouge1: 15.861 | Eval rouge2: 3.108 | Eval rougeL: 13.6457 | Eval rougeLsum: 13.6531 | Eval gen_len: 31.5794 |)
5
  Epoch... (3/10 | Loss: 2.1005117893218994, Learning Rate: 1.4007936442794744e-05)
6
  Epoch... (3/10 | Eval Loss: 2.182466506958008 | Eval rouge1: 18.7278 | Eval rouge2: 3.4425 | Eval rougeL: 15.3744 | Eval rougeLsum: 15.3757 | Eval gen_len: 31.9742 |)
7
+ Epoch... (4/10 | Loss: 1.9504339694976807, Learning Rate: 1.2007935765723232e-05)
8
+ Epoch... (4/10 | Eval Loss: 2.1522512435913086 | Eval rouge1: 18.217 | Eval rouge2: 2.819 | Eval rougeL: 15.1391 | Eval rougeLsum: 15.1443 | Eval gen_len: 31.9922 |)
9
+ Epoch... (5/10 | Loss: 1.9127023220062256, Learning Rate: 1.0007936907641124e-05)
10
+ Epoch... (5/10 | Eval Loss: 2.1301980018615723 | Eval rouge1: 19.1425 | Eval rouge2: 3.3425 | Eval rougeL: 15.796 | Eval rougeLsum: 15.8031 | Eval gen_len: 31.9547 |)
11
+ Epoch... (6/10 | Loss: 1.9510844945907593, Learning Rate: 8.007936230569612e-06)
12
+ Epoch... (6/10 | Eval Loss: 2.1168270111083984 | Eval rouge1: 18.8478 | Eval rouge2: 3.2246 | Eval rougeL: 15.519 | Eval rougeLsum: 15.5254 | Eval gen_len: 31.9568 |)