ydshieh committed
Commit 56fa8ac
Parents: 0ac6b6e dc74cb9

Merge branch 'main' of https://huggingface.co/flax-community/vit-gpt2

Files changed (1):
  tests/test_model.py +66 -28
tests/test_model.py CHANGED
@@ -6,28 +6,30 @@ sys.path.append(current_path)
 # Main model - ViTGPT2LM
 from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

-# Vit - as encoder
+# ViT - as encoder
 from transformers import ViTFeatureExtractor
 from PIL import Image
 import requests
 import numpy as np
+import jax
+import jax.numpy as jnp

-# GPT2 / GPT2LM - as decoder
-from transformers import ViTFeatureExtractor, GPT2Tokenizer
+# GPT2+LM - as decoder
+from transformers import GPT2Tokenizer

-model_name_or_path = './outputs/ckpt_2/'
-flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
+max_length = 8

-vit_model_name = 'google/vit-base-patch16-224-in21k'
-feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
+vision_model_name = 'google/vit-base-patch16-224-in21k'
+text_model_name = 'asi/gpt-fr-cased-small'

-gpt2_model_name = 'asi/gpt-fr-cased-small'
-tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
-
-max_length = 32
-num_beams = 16
-gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
+    vision_pretrained_model_name_or_path=vision_model_name,
+    text_pretrained_model_name_or_path=text_model_name
+)
+model = flax_vit_gpt2_lm
+
+feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
+tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)

 # encoder data
 url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
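Note on the new loading path: instead of restoring a fine-tuned local checkpoint, the script now composes the encoder and decoder directly from the two pretrained halves via this repo's custom from_vision_text_pretrained. Later transformers releases ship a stock equivalent, FlaxVisionEncoderDecoderModel; the following is a minimal sketch of that alternative (not part of this commit), reusing the script's model names:

# Sketch only: stock transformers equivalent of the custom loader above.
# As with the custom class, the cross-attention weights are freshly initialized.
from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer

vision_model_name = 'google/vit-base-patch16-224-in21k'
text_model_name = 'asi/gpt-fr-cased-small'

model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    vision_model_name, text_model_name
)
feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)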
@@ -42,16 +44,47 @@ sentence = 'mon chien est mignon'
 # IMPORTANT: For training/evaluation/attention_mask/loss
 sentence += ' ' + tokenizer.eos_token
 # batch dim is added automatically
-decoder_inputs = tokenizer(sentence, return_tensors="jax")
-print(decoder_inputs)
-print(f'input_ids.shape = {decoder_inputs.input_ids.shape}')
+# Setup the tokenizer for targets
+with tokenizer.as_target_tokenizer():
+    labels = tokenizer(sentence, max_length=max_length, padding="max_length", truncation=True, return_tensors="np")
+
+def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
+    """
+    Shift input ids one token to the right.
+    """
+    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
+    shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
+    # replace possible -100 values in labels by `pad_token_id`
+    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
+
+    return shifted_input_ids
+
+decoder_input_ids = shift_tokens_right(
+    jnp.array(labels["input_ids"]),
+    model.config.text_config.pad_token_id,
+    model.config.decoder_start_token_id
+)
+decoder_input_ids = np.asarray(decoder_input_ids)
+# We need decoder_attention_mask so we can ignore pad tokens from loss
+decoder_attention_mask = labels["attention_mask"]
+
+print(f'decoder_input_ids = {decoder_input_ids}')
+print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
+print(f'decoder_attention_mask = {decoder_attention_mask}')
+print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')

 # model data
-inputs = dict(decoder_inputs)
-inputs['pixel_values'] = pixel_values
-
-
-logits = flax_vit_gpt2_lm(**inputs)[0]
+model_inputs = {
+    'pixel_values': pixel_values,
+    'attention_mask': None,
+    'decoder_input_ids': decoder_input_ids,
+    'decoder_attention_mask': decoder_attention_mask,
+    'decoder_position_ids': None,
+}
+
+# Model call
+model_outputs = flax_vit_gpt2_lm(**model_inputs)
+logits = model_outputs[0]
 preds = np.argmax(logits, axis=-1)
 print('=' * 60)
 print('Flax: Vit-GPT2-LM')
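Note on shift_tokens_right above: jax.ops.index_update has since been removed from JAX. On current JAX versions the same out-of-place write is expressed with the indexed-update syntax x.at[idx].set(v); a minimal sketch under that assumption:

# Sketch only: shift_tokens_right for newer JAX, where jax.ops.index_update
# no longer exists. Behavior matches the function in the diff above.
import jax.numpy as jnp

def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Shift input ids one token to the right, prepending decoder_start_token_id."""
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    # functional update: returns a new array with position 0 set to the start token
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    # replace possible -100 label values with pad_token_id
    return jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)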
@@ -59,16 +92,21 @@ print('predicted token ids:')
 print(preds)
 print('=' * 60)

+# encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
+# print(encoder_last_hidden_state)
+# encoder_kwargs = {}
+# encoder_outputs = flax_vit_gpt2_lm.encode(pixel_values, return_dict=True, **encoder_kwargs)
+# print(encoder_outputs['last_hidden_state'])

 # Generation!
+num_beams = 1
+gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
 batch = {'pixel_values': pixel_values}
-generation = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
-print('generation:')
-print(generation)
+generated = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
+token_ids = np.array(generated.sequences)[0]
+print(f'generated token ids: {token_ids}')
 print('=' * 60)
-
-token_ids = np.array(generation.sequences)[0]
 caption = tokenizer.decode(token_ids)
-print(f'token_ids: {token_ids}')
-print(f'caption: {caption}')
+print(f'generated caption: {caption}')

 print('=' * 60)
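One usage note on the generation block: generate returns sequences that still contain special tokens (eos/pad), so the decoded caption typically reads better with skip_special_tokens=True, a standard argument of tokenizer.decode. A minimal sketch reusing the script's names:

# Sketch only: decode the generated ids without special tokens.
caption = tokenizer.decode(token_ids, skip_special_tokens=True)
print(f'generated caption: {caption}')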