import sys, os
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)
from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel
from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel
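# Local, project-specific copy of FlaxGPT2LMHeadModel; its generations are compared against the stock transformers implementation further below.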
# Main model - ViTGPT2LM
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
# ViT - as encoder
from transformers import ViTFeatureExtractor
from PIL import Image
import requests
import numpy as np
import jax
import jax.numpy as jnp
# GPT2+LM - as decoder
from transformers import GPT2Tokenizer
max_length = 8
vision_model_name = 'google/vit-base-patch16-224-in21k'
text_model_name = 'asi/gpt-fr-cased-small'
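# Build an encoder-decoder model from two separately pretrained checkpoints: ViT as the image encoder and a French GPT-2 as the text decoder.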
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    vision_pretrained_model_name_or_path=vision_model_name,
    text_pretrained_model_name_or_path=text_model_name
)
model = flax_vit_gpt2_lm
feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)
# encoder data
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
# batch dim is added automatically
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values
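# pixel_values has shape (batch_size, num_channels, height, width); for this checkpoint that should be (1, 3, 224, 224).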
print('=' * 60)
print(f'pixel_values.shape = {pixel_values.shape}')
# decoder data
sentence = 'mon chien est mignon'
# IMPORTANT: append the EOS token so that training / evaluation, the attention mask and the loss all see a proper end of sequence
sentence += ' ' + tokenizer.eos_token
# batch dim is added automatically
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
    labels = tokenizer(sentence, max_length=max_length, padding="max_length", truncation=True, return_tensors="np")
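# labels is a BatchEncoding holding numpy arrays for `input_ids` and `attention_mask`, both padded/truncated to max_length.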
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
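# Example with toy values: shift_tokens_right(jnp.array([[5, 6, 7, 8]]), 0, 1) -> [[1, 5, 6, 7]]
# Teacher forcing: the decoder is fed the labels shifted one position to the right, starting from decoder_start_token_id.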
decoder_input_ids = shift_tokens_right(
jnp.array(labels["input_ids"]),
model.config.text_config.pad_token_id,
model.config.decoder_start_token_id
)
decoder_input_ids = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens from loss
decoder_attention_mask = labels["attention_mask"]
print('=' * 60)
print(f'decoder_input_ids = {decoder_input_ids}')
print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
print(f'decoder_attention_mask = {decoder_attention_mask}')
print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
# Generation!
num_beams = 1
gen_kwargs = {"max_length": 6, "num_beams": num_beams}
orig_gpt2_generated = orig_gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
gpt2_generated = gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
orig_token_ids = np.array(orig_gpt2_generated.sequences)[0]
token_ids = np.array(gpt2_generated.sequences)[0]
orig_caption = tokenizer.decode(orig_token_ids)
caption = tokenizer.decode(token_ids)
print('=' * 60)
print(f'orig. GPT2 generated token ids: {orig_token_ids}')
print(f'GPT2 generated token ids: {token_ids}')
print('=' * 60)
print(f'orig. GPT2 generated caption: {orig_caption}')
print(f'GPT2 generated caption: {caption}')
# model data
model_inputs = {
    'pixel_values': pixel_values,
    'attention_mask': None,
    'decoder_input_ids': decoder_input_ids,
    'decoder_attention_mask': decoder_attention_mask,
    'decoder_position_ids': None,
}
# Model call
model_outputs = model(**model_inputs)
logits = model_outputs[0]
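# logits has shape (batch_size, sequence_length, vocab_size); argmax over the last axis gives the greedy token prediction at each position.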
preds = np.argmax(logits, axis=-1)
print('=' * 60)
print('Flax: Vit-GPT2-LM')
print('predicted token ids:')
print(preds)
# encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
# print(encoder_last_hidden_state)
# encoder_kwargs = {}
# encoder_outputs = flax_vit_gpt2_lm.encode(pixel_values, return_dict=True, **encoder_kwargs)
# print(encoder_outputs['last_hidden_state'])
# Generation!
num_beams = 1
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
batch = {'pixel_values': pixel_values}
generated = model.generate(batch['pixel_values'], **gen_kwargs)
token_ids = np.array(generated.sequences)[0]
print('=' * 60)
print(f'generated token ids: {token_ids}')
caption = tokenizer.decode(token_ids)
print('=' * 60)
print(f'generated caption: {caption}')
# save
os.makedirs('./model/', exist_ok=True)
model.save_pretrained(save_directory='./model/')
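# save_pretrained writes the model config and the Flax weights to the target directory, so from_pretrained below can reload them.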
# load
_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./model/')
# check if the result is the same as before
_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
_token_ids = np.array(_generated.sequences)[0]
print('=' * 60)
print(f'new generated token ids: {_token_ids}')
print(f'token_ids == new_token_ids: {np.array_equal(token_ids, _token_ids)}')