ydshieh committed on
Commit
f338d56
1 Parent(s): 8eca5ba

add model.py and app.py

Browse files
Files changed (3) hide show
  1. app.py +1 -0
  2. model.py +57 -0
  3. test_model.py +74 -0
app.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from model import *
model.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+ # Main model - ViTGPT2LM
7
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
8
+
9
+ # Vit - as encoder
10
+ from transformers import ViTFeatureExtractor
11
+ from PIL import Image
12
+ import requests
13
+ import numpy as np
14
+
15
+ # GPT2 / GPT2LM - as decoder
16
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
17
+
18
# --- Configuration ----------------------------------------------------------
# Location the captioning model (ViT encoder + GPT2 LM decoder) is loaded from.
model_name_or_path = './outputs/ckpt_2/'
# Identifier used to build the image feature extractor.
vit_model_name = 'google/vit-base-patch16-224-in21k'
# Identifier used to build the tokenizer that decodes generated ids.
gpt2_model_name = 'asi/gpt-fr-cased-small'

# Keyword arguments forwarded to `generate()` by `predict`.
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

# --- Objects loaded once, at import time ------------------------------------
# Loading happens as a module-level side effect, so importing this module
# (e.g. `from model import *`) triggers all three `from_pretrained` calls.
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
30
+
31
+
32
def predict(image):
    """Generate a caption for an already-loaded image.

    Args:
        image: The image to caption, in any form accepted by
            ``feature_extractor(images=...)``. The ``__main__`` block of this
            file passes the result of ``PIL.Image.open``.

    Returns:
        A ``(caption, token_ids)`` tuple. ``token_ids`` is the first generated
        sequence converted to a NumPy array, and ``caption`` is the string
        obtained by decoding those ids with the module-level ``tokenizer``.
    """
    # Use the caller's image directly. Do NOT re-fetch it from a global `url`:
    # that name is only defined under the `__main__` guard, so doing so raises
    # NameError whenever this function is imported (as app.py does) and
    # silently ignores the `image` argument otherwise.
    # The feature extractor adds the batch dimension automatically.
    encoder_inputs = feature_extractor(images=image, return_tensors="jax")
    pixel_values = encoder_inputs.pixel_values

    # Generation settings (max length, number of beams) come from the
    # module-level `gen_kwargs`.
    generation = flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)

    # A single image was passed in, so keep only the first sequence.
    token_ids = np.array(generation.sequences)[0]
    caption = tokenizer.decode(token_ids)

    return caption, token_ids
47
+
48
+
49
if __name__ == '__main__':
    # Manual smoke test: download one sample image, caption it, and print both
    # the raw token ids and the decoded text.
    url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    response = requests.get(url, stream=True)
    image = Image.open(response.raw)

    result = predict(image)
    caption, token_ids = result

    print(f'token_ids: {token_ids}')
    print(f'caption: {caption}')
test_model.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+
3
+ current_path = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(current_path)
5
+
6
+ # Main model - ViTGPT2LM
7
+ from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
8
+
9
+ # Vit - as encoder
10
+ from transformers import ViTFeatureExtractor
11
+ from PIL import Image
12
+ import requests
13
+ import numpy as np
14
+
15
+ # GPT2 / GPT2LM - as decoder
16
+ from transformers import ViTFeatureExtractor, GPT2Tokenizer
17
+
18
# Manual end-to-end check of the ViT-GPT2 captioning model: one teacher-forced
# forward pass followed by one `generate()` call, with everything printed.

SEPARATOR = '=' * 60

# --- Configuration ----------------------------------------------------------
model_name_or_path = './outputs/ckpt_2/'
vit_model_name = 'google/vit-base-patch16-224-in21k'
gpt2_model_name = 'asi/gpt-fr-cased-small'

max_length = 32
num_beams = 16
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

# --- Model, feature extractor and tokenizer ---------------------------------
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# --- Encoder input: one sample image ----------------------------------------
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
response = requests.get(url, stream=True)
image = Image.open(response.raw)
# The feature extractor adds the batch dimension automatically.
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values
print(f'pixel_values.shape = {pixel_values.shape}')

# --- Decoder input: a fixed sentence ----------------------------------------
# The EOS token is appended by hand; the original author marks this as
# important for training/evaluation (attention mask and loss).
sentence = 'mon chien est mignon' + ' ' + tokenizer.eos_token
# The tokenizer adds the batch dimension automatically.
decoder_inputs = tokenizer(sentence, return_tensors="jax")
print(decoder_inputs)
print(f'input_ids.shape = {decoder_inputs.input_ids.shape}')

# --- Forward pass -----------------------------------------------------------
# Combine the tokenizer output with the image tensor into one kwargs dict.
inputs = dict(decoder_inputs, pixel_values=pixel_values)

outputs = flax_vit_gpt2_lm(**inputs)
logits = outputs[0]
preds = np.argmax(logits, axis=-1)
print(SEPARATOR)
print('Flax: Vit-GPT2-LM')
print('predicted token ids:')
print(preds)
print(SEPARATOR)

# --- Generation -------------------------------------------------------------
generation = flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)
print('generation:')
print(generation)
print(SEPARATOR)

# Keep the first (only) generated sequence and decode it to text.
token_ids = np.array(generation.sequences)[0]
caption = tokenizer.decode(token_ids)
print(f'token_ids: {token_ids}')
print(f'caption: {caption}')
print(SEPARATOR)