ydshieh committed
Commit 3b81fb5
Parent: 5a95e3a

Add a test against PyTorch's GPT2

Files changed (3)
  1. tests/test_load.py +0 -48
  2. tests/test_model.py +47 -0
  3. tests/test_save.py +0 -48
tests/test_load.py DELETED
@@ -1,48 +0,0 @@
- import sys, os
- 
- current_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(current_path)
- 
- # Vit - as encoder
- from transformers import ViTFeatureExtractor
- from PIL import Image
- import requests
- import numpy as np
- 
- url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
- image = Image.open(requests.get(url, stream=True).raw)
- 
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
- encoder_inputs = feature_extractor(images=image, return_tensors="jax")
- pixel_values = encoder_inputs.pixel_values
- 
- # GPT2 / GPT2LM - as decoder
- from transformers import ViTFeatureExtractor, GPT2Tokenizer
- 
- name = 'asi/gpt-fr-cased-small'
- tokenizer = GPT2Tokenizer.from_pretrained(name)
- decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
- 
- inputs = dict(decoder_inputs)
- inputs['pixel_values'] = pixel_values
- print(inputs)
- 
- 
- 
- 
- 
- # With the LM head in GPT2LM
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
- flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(
-     '.',
- )
- 
- logits = flax_vit_gpt2_lm(**inputs)[0]
- preds = np.argmax(logits, axis=-1)
- print('=' * 60)
- print('Flax: Vit + modified GPT2LM')
- print(preds)
- 
- # flax_vit_gpt2_lm.save_pretrained('.')
- 
- del flax_vit_gpt2_lm
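The standalone load check above is deleted; loading is now exercised by the "Check save & load" section of tests/test_model.py. For reference, a minimal sketch (not part of the commit) of the same forward-pass check against the './model/' checkpoint that tests/test_model.py saves; the helper name and the exact keyword arguments are assumptions based on the model_inputs used in that test:

# Hedged sketch, not part of the commit: reload the saved checkpoint and run the
# forward-pass check that the deleted test_load.py used to perform.
import numpy as np
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

def check_reload_and_forward(pixel_values, decoder_input_ids, decoder_attention_mask, save_dir='./model/'):
    # Inputs are prepared as in tests/test_model.py (ViTFeatureExtractor for
    # `pixel_values`, GPT2Tokenizer + shift_tokens_right for the decoder side).
    model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(save_dir)
    logits = model(pixel_values=pixel_values,
                   decoder_input_ids=decoder_input_ids,
                   decoder_attention_mask=decoder_attention_mask)[0]
    preds = np.argmax(logits, axis=-1)
    # One predicted token id per decoder input position.
    assert preds.shape == decoder_input_ids.shape
    return preds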
tests/test_model.py CHANGED
@@ -22,6 +22,9 @@ from transformers import GPT2Tokenizer
  
  max_length = 8
  
+ # ================================================================================
+ # Models preparation
+ 
  vision_model_name = 'google/vit-base-patch16-224-in21k'
  text_model_name = 'asi/gpt-fr-cased-small'
  
@@ -34,6 +37,9 @@ flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretra
  )
  model = flax_vit_gpt2_lm
  
+ # ================================================================================
+ # Inputs preparation
+ 
  feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
  tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)
  
@@ -56,6 +62,7 @@ sentence += ' ' + tokenizer.eos_token
  with tokenizer.as_target_tokenizer():
      labels = tokenizer(sentence, max_length=max_length, padding="max_length", truncation=True, return_tensors="np")
  
+ 
  def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
      """
      Shift input ids one token to the right.
@@ -82,6 +89,9 @@ print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
  print(f'decoder_attention_mask = {decoder_attention_mask}')
  print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
  
+ # ================================================================================
+ # Check `FlaxGPT2LMHeadModel` has the same results in the new version (when no `encoder_outputs` is provided).
+ 
  orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
  gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
  
@@ -108,6 +118,8 @@ print(f'GPT2 generated caption: {caption}')
  
  assert list(orig_token_ids) == list(token_ids)
  
+ # ================================================================================
+ 
  # model data
  model_inputs = {
      'pixel_values': pixel_values,
@@ -117,6 +129,9 @@ model_inputs = {
      'decoder_position_ids': None,
  }
  
+ # ================================================================================
+ # Check `model.__call__()`
+ 
  # Model call
  model_outputs = model(**model_inputs)
  logits = model_outputs[0]
@@ -133,6 +148,9 @@ print(preds)
  # encoder_outputs = flax_vit_gpt2_lm.encode(pixel_values, return_dict=True, **encoder_kwargs)
  # print(encoder_outputs['last_hidden_state'])
  
+ # ================================================================================
+ # Check generation
+ 
  # Generation!
  num_beams = 1
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
@@ -149,6 +167,9 @@ caption = tokenizer.decode(token_ids)
  print('=' * 60)
  print(f'generated caption: {caption}')
  
+ # ================================================================================
+ # Check save & load
+ 
  # save
  os.makedirs('./model/', exist_ok=True)
  model.save_pretrained(save_directory='./model/')
@@ -163,3 +184,29 @@ _token_ids = np.array(_generated.sequences)[0]
  print('=' * 60)
  print(f'new generated token ids: {_token_ids}')
  print(f'token_ids == new_token_ids: {token_ids == _token_ids}')
+ 
+ # ================================================================================
+ # Check PyTorch version's output - it should be the same as above
+ 
+ import torch
+ from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
+ 
+ vision_model_pt = ViTModel.from_pretrained(vision_model_name)
+ config = GPT2Config.from_pretrained(text_model_name)
+ config.is_encoder_decoder = True
+ config.add_cross_attention = True
+ text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)
+ 
+ encoder_inputs_pt = feature_extractor(images=image, return_tensors="pt")
+ vision_model_pt_outputs = vision_model_pt(**encoder_inputs_pt)
+ 
+ generated = text_model_pt.generate(encoder_outputs=vision_model_pt_outputs, **gen_kwargs)
+ token_ids = np.array(generated.sequences)[0]
+ 
+ print('=' * 60)
+ print(f'Pytorch\'s GPT2 LM generated token ids: {token_ids}')
+ 
+ caption = tokenizer.decode(token_ids)
+ 
+ print('=' * 60)
+ print(f'Pytorch\'s GPT2 LM generated caption: {caption}')
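The new PyTorch block only prints its token ids and caption for eyeball comparison with the Flax output above. A minimal sketch (not part of the commit) of how the two generations could be asserted equal instead; `flax_token_ids` and `pt_token_ids` are hypothetical names for the two `generate()` results, and the Flax ids would need to be kept in a separate variable since the script reuses `token_ids` for the PyTorch output:

# Hedged sketch, not part of the commit: compare the Flax and PyTorch generations explicitly.
import numpy as np

def assert_same_generation(flax_token_ids, pt_token_ids, tokenizer):
    flax_token_ids = np.asarray(flax_token_ids)
    pt_token_ids = np.asarray(pt_token_ids)
    # Same length and the same token id at every position.
    assert flax_token_ids.shape == pt_token_ids.shape, \
        f"length mismatch: {flax_token_ids.shape} vs {pt_token_ids.shape}"
    assert np.array_equal(flax_token_ids, pt_token_ids), \
        f"token ids differ: {flax_token_ids} vs {pt_token_ids}"
    # Identical ids necessarily decode to the same caption.
    assert tokenizer.decode(flax_token_ids) == tokenizer.decode(pt_token_ids)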
tests/test_save.py DELETED
@@ -1,48 +0,0 @@
- import sys, os
- 
- current_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(current_path)
- 
- # Vit - as encoder
- from transformers import ViTFeatureExtractor
- from PIL import Image
- import requests
- import numpy as np
- 
- url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
- image = Image.open(requests.get(url, stream=True).raw)
- 
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
- encoder_inputs = feature_extractor(images=image, return_tensors="jax")
- pixel_values = encoder_inputs.pixel_values
- 
- # GPT2 / GPT2LM - as decoder
- from transformers import ViTFeatureExtractor, GPT2Tokenizer
- 
- name = 'asi/gpt-fr-cased-small'
- tokenizer = GPT2Tokenizer.from_pretrained(name)
- decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
- 
- inputs = dict(decoder_inputs)
- inputs['pixel_values'] = pixel_values
- print(inputs)
- 
- 
- 
- 
- 
- # With the LM head in GPT2LM
- from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
- flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
-     'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
- )
- 
- logits = flax_vit_gpt2_lm(**inputs)[0]
- preds = np.argmax(logits, axis=-1)
- print('=' * 60)
- print('Flax: Vit + modified GPT2LM')
- print(preds)
- 
- flax_vit_gpt2_lm.save_pretrained('.')
- 
- del flax_vit_gpt2_lm
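With both standalone scripts deleted, save/load is covered only by the "Check save & load" section of tests/test_model.py, which compares generated token ids before and after reloading. A minimal sketch (not part of the commit) of a stricter check that every weight survives the save_pretrained / from_pretrained round trip; the helper name and the use of jax.tree_util are assumptions:

# Hedged sketch, not part of the commit: verify the reloaded model's weights
# match the original ones exactly.
import jax
import numpy as np

def params_round_trip_ok(model, reloaded_model):
    orig_leaves = jax.tree_util.tree_leaves(model.params)
    new_leaves = jax.tree_util.tree_leaves(reloaded_model.params)
    if len(orig_leaves) != len(new_leaves):
        return False
    # Saved Flax weights are exact copies, so exact equality is expected.
    return all(np.array_equal(np.asarray(a), np.asarray(b))
               for a, b in zip(orig_leaves, new_leaves))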