ydshieh committed
Commit 845642f · 1 Parent(s): dc74cb9

update test_model.py

Files changed (1)
  1. tests/test_model.py +52 -4
tests/test_model.py CHANGED
@@ -3,6 +3,9 @@ import sys, os
 current_path = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(current_path)
 
+from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel
+from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel
+
 # Main model - ViTGPT2LM
 from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
 
@@ -37,6 +40,8 @@ image = Image.open(requests.get(url, stream=True).raw)
 # batch dim is added automatically
 encoder_inputs = feature_extractor(images=image, return_tensors="jax")
 pixel_values = encoder_inputs.pixel_values
+
+print('=' * 60)
 print(f'pixel_values.shape = {pixel_values.shape}')
 
 # decoder data
@@ -68,11 +73,36 @@ decoder_input_ids = np.asarray(decoder_input_ids)
 # We need decoder_attention_mask so we can ignore pad tokens from loss
 decoder_attention_mask = labels["attention_mask"]
 
+print('=' * 60)
 print(f'decoder_inputs = {decoder_input_ids}')
 print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
 print(f'decoder_attention_mask = {decoder_attention_mask}')
 print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
 
+orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
+gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
+
+# Generation!
+num_beams = 1
+gen_kwargs = {"max_length": 6, "num_beams": num_beams}
+
+orig_gpt2_generated = orig_gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
+gpt2_generated = gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
+
+orig_token_ids = np.array(orig_gpt2_generated.sequences)[0]
+token_ids = np.array(gpt2_generated.sequences)[0]
+
+orig_caption = tokenizer.decode(orig_token_ids)
+caption = tokenizer.decode(token_ids)
+
+print('=' * 60)
+print(f'orig. GPT2 generated token ids: {orig_token_ids}')
+print(f'GPT2 generated token ids: {token_ids}')
+
+print('=' * 60)
+print(f'orig. GPT2 generated caption: {orig_caption}')
+print(f'GPT2 generated caption: {caption}')
+
 # model data
 model_inputs = {
     'pixel_values': pixel_values,
@@ -83,14 +113,14 @@ model_inputs = {
 }
 
 # Model call
-model_outputs = flax_vit_gpt2_lm(**model_inputs)
+model_outputs = model(**model_inputs)
 logits = model_outputs[0]
 preds = np.argmax(logits, axis=-1)
+
 print('=' * 60)
 print('Flax: Vit-GPT2-LM')
 print('predicted token ids:')
 print(preds)
-print('=' * 60)
 
 # encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
 # print(encoder_last_hidden_state)
@@ -103,10 +133,28 @@ num_beams = 1
 gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
 
 batch = {'pixel_values': pixel_values}
-generated = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
+generated = model.generate(batch['pixel_values'], **gen_kwargs)
 token_ids = np.array(generated.sequences)[0]
-print(f'generated token ids: {token_ids}')
+
 print('=' * 60)
+print(f'generated token ids: {token_ids}')
+
 caption = tokenizer.decode(token_ids)
+
+print('=' * 60)
 print(f'generated caption: {caption}')
-print('=' * 60)
+
+# save
+os.makedirs('./model/', exist_ok=True)
+model.save_pretrained(save_directory='./model/')
+
+# load
+_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./model/')
+
+# check if the result is the same as before
+_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
+_token_ids = np.array(_generated.sequences)[0]
+
+print('=' * 60)
+print(f'new generated token ids: {_token_ids}')
+print(f'token_ids == new_token_ids: {token_ids == _token_ids}')
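The new GPT2 block in the diff above checks the fork from vit_gpt2.modeling_flax_gpt2 against the stock transformers head only by printing both generations for manual inspection. A minimal sketch of a stricter, programmatic parity check, reusing names the test already defines (text_model_name, tokenizer, decoder_input_ids, decoder_attention_mask); the assertions themselves are hypothetical and not part of this commit:

import numpy as np

# Hypothetical parity check between the stock and forked GPT2 heads;
# the commit only prints both outputs for eyeballing.
assert np.array_equal(
    np.array(orig_gpt2_generated.sequences),
    np.array(gpt2_generated.sequences),
), 'forked FlaxGPT2LMHeadModel diverges from the original during generation'

# The forward-pass logits should also agree within floating-point tolerance.
orig_logits = orig_gpt2_lm(decoder_input_ids, attention_mask=decoder_attention_mask).logits
fork_logits = gpt2_lm(decoder_input_ids, attention_mask=decoder_attention_mask).logits
assert np.allclose(np.array(orig_logits), np.array(fork_logits), atol=1e-5)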
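One caveat on the save/load round trip at the end of the diff: the final print compares token_ids and _token_ids element-wise, so it emits a boolean array rather than a single pass/fail verdict. A hedged sketch of how the same check could fail loudly instead (same variable names as the test; the assertion is not part of the commit):

import numpy as np

# Hypothetical assertion for the save/load round trip; the commit prints
# the element-wise comparison instead of failing on a mismatch.
assert np.array_equal(token_ids, _token_ids), \
    'model reloaded from ./model/ generates different token ids'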