ydshieh committed
Commit 9b4bdf2 • 1 Parent(s): 1cc17e2

create tests dir

test_coco_dataset_script.py → tests/test_coco_dataset_script.py RENAMED
File without changes
test_load.py → tests/test_load.py RENAMED
File without changes
test_model.py → tests/test_model.py RENAMED
File without changes
test_save.py → tests/test_save.py RENAMED
File without changes
test_vit_gpt2.py → tests/test_vit_gpt2.py RENAMED
@@ -1,83 +1,83 @@
-import sys, os
-
-current_path = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(current_path)
-
-# Vit - as encoder
-from transformers import ViTFeatureExtractor
-from PIL import Image
-import requests
-import numpy as np
-
-url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
-image = Image.open(requests.get(url, stream=True).raw)
-
-feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
-encoder_inputs = feature_extractor(images=image, return_tensors="jax")
-pixel_values = encoder_inputs.pixel_values
-
-# GPT2 / GPT2LM - as decoder
-from transformers import ViTFeatureExtractor, GPT2Tokenizer
-
-name = 'asi/gpt-fr-cased-small'
-tokenizer = GPT2Tokenizer.from_pretrained(name)
-decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
-
-inputs = dict(decoder_inputs)
-inputs['pixel_values'] = pixel_values
-print(inputs)
-
-# With new added LM head
-from vit_gpt2.modeling_flax_vit_gpt2 import FlaxViTGPT2ForConditionalGeneration
-flax_vit_gpt2 = FlaxViTGPT2ForConditionalGeneration.from_vit_gpt2_pretrained(
-    'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
-)
-logits = flax_vit_gpt2(**inputs)[0]
-preds = np.argmax(logits, axis=-1)
-print('=' * 60)
-print('Flax: Vit + modified GPT2 + LM')
-print(preds)
-
-del flax_vit_gpt2
-
-# With the LM head in GPT2LM
-from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
-flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
-    'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
-)
-
-logits = flax_vit_gpt2_lm(**inputs)[0]
-preds = np.argmax(logits, axis=-1)
-print('=' * 60)
-print('Flax: Vit + modified GPT2LM')
-print(preds)
-
-del flax_vit_gpt2_lm
-
-# With PyTorch [Vit + unmodified GPT2LMHeadModel]
-import torch
-from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
-
-vit_model_pt = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
-encoder_inputs = feature_extractor(images=image, return_tensors="pt")
-vit_outputs = vit_model_pt(**encoder_inputs)
-vit_last_hidden_states = vit_outputs.last_hidden_state
-
-del vit_model_pt
-
-inputs_pt = tokenizer("mon chien est mignon", return_tensors="pt")
-inputs_pt = dict(inputs_pt)
-inputs_pt['encoder_hidden_states'] = vit_last_hidden_states
-
-config = GPT2Config.from_pretrained('asi/gpt-fr-cased-small')
-config.add_cross_attention = True
-gpt2_model_pt = GPT2LMHeadModel.from_pretrained('asi/gpt-fr-cased-small', config=config)
-
-gp2lm_outputs = gpt2_model_pt(**inputs_pt)
-logits_pt = gp2lm_outputs.logits
-preds_pt = torch.argmax(logits_pt, dim=-1).cpu().detach().numpy()
-print('=' * 60)
-print('Pytorch: Vit + unmodified GPT2LM')
-print(preds_pt)
-
-del gpt2_model_pt
+import sys, os
+
+current_path = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(current_path)
+
+# Vit - as encoder
+from transformers import ViTFeatureExtractor
+from PIL import Image
+import requests
+import numpy as np
+
+url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
+image = Image.open(requests.get(url, stream=True).raw)
+
+feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
+encoder_inputs = feature_extractor(images=image, return_tensors="jax")
+pixel_values = encoder_inputs.pixel_values
+
+# GPT2 / GPT2LM - as decoder
+from transformers import ViTFeatureExtractor, GPT2Tokenizer
+
+name = 'asi/gpt-fr-cased-small'
+tokenizer = GPT2Tokenizer.from_pretrained(name)
+decoder_inputs = tokenizer("mon chien est mignon", return_tensors="jax")
+
+inputs = dict(decoder_inputs)
+inputs['pixel_values'] = pixel_values
+print(inputs)
+
+# With new added LM head
+from vit_gpt2.modeling_flax_vit_gpt2 import FlaxViTGPT2ForConditionalGeneration
+flax_vit_gpt2 = FlaxViTGPT2ForConditionalGeneration.from_vit_gpt2_pretrained(
+    'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
+)
+logits = flax_vit_gpt2(**inputs)[0]
+preds = np.argmax(logits, axis=-1)
+print('=' * 60)
+print('Flax: Vit + modified GPT2 + LM')
+print(preds)
+
+del flax_vit_gpt2
+
+# With the LM head in GPT2LM
+from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
+flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
+    'google/vit-base-patch16-224-in21k', 'asi/gpt-fr-cased-small'
+)
+
+logits = flax_vit_gpt2_lm(**inputs)[0]
+preds = np.argmax(logits, axis=-1)
+print('=' * 60)
+print('Flax: Vit + modified GPT2LM')
+print(preds)
+
+del flax_vit_gpt2_lm
+
+# With PyTorch [Vit + unmodified GPT2LMHeadModel]
+import torch
+from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
+
+vit_model_pt = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
+encoder_inputs = feature_extractor(images=image, return_tensors="pt")
+vit_outputs = vit_model_pt(**encoder_inputs)
+vit_last_hidden_states = vit_outputs.last_hidden_state
+
+del vit_model_pt
+
+inputs_pt = tokenizer("mon chien est mignon", return_tensors="pt")
+inputs_pt = dict(inputs_pt)
+inputs_pt['encoder_hidden_states'] = vit_last_hidden_states
+
+config = GPT2Config.from_pretrained('asi/gpt-fr-cased-small')
+config.add_cross_attention = True
+gpt2_model_pt = GPT2LMHeadModel.from_pretrained('asi/gpt-fr-cased-small', config=config)
+
+gp2lm_outputs = gpt2_model_pt(**inputs_pt)
+logits_pt = gp2lm_outputs.logits
+preds_pt = torch.argmax(logits_pt, dim=-1).cpu().detach().numpy()
+print('=' * 60)
+print('Pytorch: Vit + unmodified GPT2LM')
+print(preds_pt)
+
+del gpt2_model_pt
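A note on this rename, with a minimal sketch that is not part of the commit: test_vit_gpt2.py appends its own directory to sys.path before importing the local vit_gpt2 package. Once the file sits under tests/, that directory is tests/ rather than the repository root, so the vit_gpt2 imports resolve only when the script happens to be run from the root. Assuming vit_gpt2/ lives at the repository root, a hypothetical adjustment would append the parent directory instead:

import sys, os

# Hypothetical adjustment (not in this commit): with the test now under
# tests/, put the parent of this file's directory (the repo root, where
# the vit_gpt2 package is assumed to live) on sys.path.
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(current_path))

from vit_gpt2.modeling_flax_vit_gpt2 import FlaxViTGPT2ForConditionalGeneration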