ydshieh commited on
Commit
ec3ceb6
1 Parent(s): 7d3b1a0

Update test_model.py

Browse files
Files changed (1) hide show
  1. tests/test_model.py +29 -20
tests/test_model.py CHANGED
@@ -95,7 +95,7 @@ print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
95
  orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
96
  gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
97
 
98
- # Generation!
99
  num_beams = 1
100
  gen_kwargs = {"max_length": 6, "num_beams": num_beams}
101
 
@@ -138,20 +138,27 @@ logits = model_outputs[0]
138
  preds = np.argmax(logits, axis=-1)
139
 
140
  print('=' * 60)
141
- print('Flax: Vit-GPT2-LM')
142
- print('predicted token ids:')
143
  print(preds)
144
 
145
- # encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
146
- # print(encoder_last_hidden_state)
147
- # encoder_kwargs = {}
148
- # encoder_outputs = flax_vit_gpt2_lm.encode(pixel_values, return_dict=True, **encoder_kwargs)
149
- # print(encoder_outputs['last_hidden_state'])
 
 
 
 
 
 
 
 
150
 
151
  # ================================================================================
152
- # Check generation
153
 
154
- # Generation!
155
  num_beams = 1
156
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
157
 
@@ -215,17 +222,19 @@ logits = text_model_pt_outputs[0]
215
  preds = np.argmax(logits.detach().numpy(), axis=-1)
216
 
217
  print('=' * 60)
218
- print('PyTroch: Vit --> GPT2-LM')
219
  print('predicted token ids:')
220
  print(preds)
221
 
222
- #generated = text_model_pt.generate(encoder_outputs=vision_model_pt_outputs, **gen_kwargs)
223
- #token_ids = np.array(generated.sequences)[0]
 
224
 
225
- #print('=' * 60)
226
- #print(f'Pytorch\'s GPT2 LM generated token ids: {token_ids}')
227
-
228
- #caption = tokenizer.decode(token_ids)
229
-
230
- #print('=' * 60)
231
- #print(f'Pytorch\'s GPT2 LM generated caption: {caption}')
 
95
  orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
96
  gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
97
 
98
+ # generation!
99
  num_beams = 1
100
  gen_kwargs = {"max_length": 6, "num_beams": num_beams}
101
 
138
  preds = np.argmax(logits, axis=-1)
139
 
140
  print('=' * 60)
141
+ print('Flax ViT-GPT2-LM - predicted token ids:')
 
142
  print(preds)
143
 
144
+ encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
145
+ print('=' * 60)
146
+ print("encoder_last_hidden_state given by model.__call__():")
147
+ print(encoder_last_hidden_state)
148
+
149
+ encoder_outputs = model.encode(pixel_values, return_dict=True)
150
+ print('=' * 60)
151
+ print("encoder's last_hidden_state given by model.encode():")
152
+ print(encoder_outputs['last_hidden_state'])
153
+
154
+ total_diff = np.sum(np.abs(encoder_outputs['last_hidden_state'] - encoder_last_hidden_state))
155
+ print('=' * 60)
156
+ print(f"total difference: {total_diff}")
157
 
158
  # ================================================================================
159
+ # Check model generation
160
 
161
+ # generation
162
  num_beams = 1
163
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
164
 
222
  preds = np.argmax(logits.detach().numpy(), axis=-1)
223
 
224
  print('=' * 60)
225
+ print('PyTroch: ViT --> GPT2-LM')
226
  print('predicted token ids:')
227
  print(preds)
228
 
229
+ model_logits = np.array(model_outputs.logits)
230
+ text_model_pt_logits = text_model_pt_outputs.logits.detach().cpu().numpy()
231
+ total_diff = np.sum(np.abs(model_logits - text_model_pt_logits))
232
 
233
+ print('=' * 60)
234
+ print("model_logits:")
235
+ print(model_logits)
236
+ print('=' * 60)
237
+ print("text_model_pt_logits:")
238
+ print(text_model_pt_logits)
239
+ print('=' * 60)
240
+ print(f"total difference between logits: {total_diff}")