ydshieh committed on
Commit 8eca5ba
1 Parent(s): a244e91

clean generate.py

Files changed (1)
  1. generate.py +43 -32
generate.py CHANGED
@@ -3,7 +3,8 @@ import sys, os
 current_path = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(current_path)
 
-
+# Main model - ViTGPT2LM
+from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
 
 # Vit - as encoder
 from transformers import ViTFeatureExtractor
@@ -11,53 +12,63 @@ from PIL import Image
 import requests
 import numpy as np
 
+# GPT2 / GPT2LM - as decoder
+from transformers import ViTFeatureExtractor, GPT2Tokenizer
+
+model_name_or_path = './outputs/ckpt_2/'
+flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
+
+vit_model_name = 'google/vit-base-patch16-224-in21k'
+feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
+
+gpt2_model_name = 'asi/gpt-fr-cased-small'
+tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
+
+max_length = 32
+num_beams = 16
+gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+
+# encoder data
 url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
 image = Image.open(requests.get(url, stream=True).raw)
-
-feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
+# batch dim is added automatically
 encoder_inputs = feature_extractor(images=image, return_tensors="jax")
 pixel_values = encoder_inputs.pixel_values
-
-# GPT2 / GPT2LM - as decoder
-from transformers import ViTFeatureExtractor, GPT2Tokenizer
-
-name = 'asi/gpt-fr-cased-small'
-tokenizer = GPT2Tokenizer.from_pretrained(name)
-decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax", )
+print(f'pixel_values.shape = {pixel_values.shape}')
+
+# decoder data
+sentence = 'mon chien est mignon'
+# IMPORTANT: For training/evaluation/attention_mask/loss
+sentence += ' ' + tokenizer.eos_token
+# batch dim is added automatically
+decoder_inputs = tokenizer(sentence, return_tensors="jax")
 print(decoder_inputs)
+print(f'input_ids.shape = {decoder_inputs.input_ids.shape}')
 
-# Setup the tokenizer for targets
-with tokenizer.as_target_tokenizer():
-    labels = tokenizer(
-        ['un chien super beau' + ' ' + tokenizer.eos_token, 'un chat' + ' ' + tokenizer.eos_token], max_length=5, padding="max_length", truncation=True, return_tensors="np"
-    )
-print(labels)
-exit(0)
-
+# model data
 inputs = dict(decoder_inputs)
 inputs['pixel_values'] = pixel_values
-#print(inputs)
 
 
-# With the LM head in GPT2LM
-from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
-flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./outputs-small-ds/ckpt_3',)
-
 logits = flax_vit_gpt2_lm(**inputs)[0]
 preds = np.argmax(logits, axis=-1)
 print('=' * 60)
-print('Flax: Vit + modified GPT2LM')
-#print(preds)
+print('Flax: Vit-GPT2-LM')
+print('predicted token ids:')
+print(preds)
+print('=' * 60)
 
-max_length = 32
-num_beams = 16
-gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
+
+# Generation!
 batch = {'pixel_values': pixel_values}
 generation = flax_vit_gpt2_lm.generate(batch['pixel_values'], **gen_kwargs)
+print('generation:')
 print(generation)
+print('=' * 60)
 
 token_ids = np.array(generation.sequences)[0]
-generation = tokenizer.decode(token_ids)
-print(generation)
-
-del flax_vit_gpt2_lm
+caption = tokenizer.decode(token_ids)
+print(f'token_ids: {token_ids}')
+print(f'caption: {caption}')
+print('=' * 60)
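
Note that the new script decodes the best beam with a bare `tokenizer.decode(token_ids)`, so the EOS token that was appended to the training targets will also show up in the printed caption. A minimal follow-up sketch, assuming the `generation`, `tokenizer`, and `np` objects from generate.py above (`skip_special_tokens` is the standard `transformers` decoding flag):

```python
# Decode every returned beam sequence at once and strip special tokens
# (e.g. the EOS token) from the text. Assumes `generation` and `tokenizer`
# exist as in generate.py above.
captions = tokenizer.batch_decode(
    np.array(generation.sequences), skip_special_tokens=True
)
print(captions)
```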