# vit-gpt2 / tests/test_model.py
# ydshieh — commit 845642f: "update test_model.py" (5.26 kB)
import sys, os
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)
from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel
from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel
# Main model - ViTGPT2LM
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
# ViT - as encoder
from transformers import ViTFeatureExtractor
from PIL import Image
import requests
import numpy as np
import jax
import jax.numpy as jnp
# GPT2+LM - as decoder
from transformers import GPT2Tokenizer
# Pretrained checkpoints: a ViT image encoder and a GPT-2 text decoder.
vision_model_name = 'google/vit-base-patch16-224-in21k'
text_model_name = 'asi/gpt-fr-cased-small'
# Target length in tokens for the label sequence; later in this script it is
# also used as the generation length.
max_length = 8

# Build the combined encoder-decoder model from the two pretrained checkpoints.
# `model` is a shorter alias for the same object.
model = flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    vision_pretrained_model_name_or_path=vision_model_name,
    text_pretrained_model_name_or_path=text_model_name,
)

# Preprocessors that match each checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)
# encoder data: a single COCO val2017 image, downloaded over HTTP.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# The feature extractor adds the batch dim automatically.
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs["pixel_values"]

print('=' * 60)
print(f'pixel_values.shape = {pixel_values.shape}')
# decoder data
sentence = 'mon chien est mignon'
# IMPORTANT: For training/evaluation/attention_mask/loss
# Append the EOS token explicitly so that the label sequence ends with it.
sentence += ' ' + tokenizer.eos_token

# Tokenize the target caption, padded/truncated to exactly `max_length` tokens.
# The batch dim is added automatically (return_tensors="np").
# NOTE: this call used to be wrapped in `tokenizer.as_target_tokenizer()`.
# That context manager is deprecated in Transformers. GPT2Tokenizer does not
# customize target tokenization, so the wrapper had no effect other than
# (in newer versions) emitting a deprecation warning.
labels = tokenizer(sentence, max_length=max_length, padding="max_length", truncation=True, return_tensors="np")
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.

    The ids are rotated by one position along the last axis. Position 0 is then
    overwritten with ``decoder_start_token_id``, so the former last token is
    dropped. Finally, every ``-100`` label placeholder is replaced by
    ``pad_token_id``.

    Args:
        input_ids: Integer token ids. The shift is applied along the last axis.
        pad_token_id: Id substituted for ``-100`` entries.
        decoder_start_token_id: Id written at position 0 of every sequence.

    Returns:
        A new array with the same shape as ``input_ids``.
    """
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    # `jax.ops.index_update` was deprecated and later removed from JAX; the
    # functional `.at[...].set(...)` update is the supported equivalent.
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
# Decoder inputs are the label ids shifted one position to the right, with the
# model's decoder start token written at position 0. The result is converted
# back to a NumPy array.
decoder_input_ids = np.asarray(
    shift_tokens_right(
        jnp.array(labels["input_ids"]),
        model.config.text_config.pad_token_id,
        model.config.decoder_start_token_id,
    )
)

# We need decoder_attention_mask so we can ignore pad tokens from loss
decoder_attention_mask = labels["attention_mask"]

print('=' * 60)
print(f'decoder_inputs = {decoder_input_ids}')
print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
print(f'decoder_attention_mask = {decoder_attention_mask}')
print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
# Sanity check: load the stock Transformers Flax GPT-2 and the project's
# FlaxGPT2LMHeadModel from the same checkpoint, then compare their generations.
orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)

# Generation!
num_beams = 1
gen_kwargs = dict(max_length=6, num_beams=num_beams)
# Both models receive the same prompt: the first 3 decoder input ids of each row.
prompt_ids = decoder_input_ids[:, 0:3]
orig_gpt2_generated, gpt2_generated = (
    lm.generate(prompt_ids, **gen_kwargs) for lm in (orig_gpt2_lm, gpt2_lm)
)

# Keep only the first sequence of each batch and decode it back to text.
orig_token_ids, token_ids = (
    np.array(out.sequences)[0] for out in (orig_gpt2_generated, gpt2_generated)
)
orig_caption, caption = (tokenizer.decode(ids) for ids in (orig_token_ids, token_ids))

print('=' * 60)
print(f'orig. GPT2 generated token ids: {orig_token_ids}')
print(f'GPT2 generated token ids: {token_ids}')
print('=' * 60)
print(f'orig. GPT2 generated caption: {orig_caption}')
print(f'GPT2 generated caption: {caption}')
# model data
# Inputs for one forward pass. `attention_mask` and `decoder_position_ids` are
# passed explicitly as None, presumably so the model falls back to its own
# defaults (TODO confirm against the model implementation).
model_inputs = dict(
    pixel_values=pixel_values,
    attention_mask=None,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    decoder_position_ids=None,
)

# Model call
model_outputs = model(**model_inputs)
# The first output is used as the logits; argmax over the last axis gives the
# most likely token id at every position.
logits = model_outputs[0]
preds = np.argmax(logits, axis=-1)

print('=' * 60, 'Flax: Vit-GPT2-LM', 'predicted token ids:', sep='\n')
print(preds)
# encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
# print(encoder_last_hidden_state)
# encoder_kwargs = {}
# encoder_outputs = flax_vit_gpt2_lm.encode(pixel_values, return_dict=True, **encoder_kwargs)
# print(encoder_outputs['last_hidden_state'])
# Generation!
# Caption the image with the full encoder-decoder model, using a single beam
# and up to `max_length` tokens.
num_beams = 1
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
batch = {'pixel_values': pixel_values}

generated = model.generate(batch['pixel_values'], **gen_kwargs)
# Only the first sequence of the batch is inspected.
token_ids = np.array(generated.sequences)[0]
print('=' * 60)
print(f'generated token ids: {token_ids}')

caption = tokenizer.decode(token_ids)
print('=' * 60)
print(f'generated caption: {caption}')
# Round-trip check: save the model, reload it, and confirm that the reloaded
# copy generates the same token ids for the same image and generation settings.
save_dir = './model/'

# save
os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_directory=save_dir)

# load
_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(save_dir)

# check if the result is the same as before
_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
_token_ids = np.array(_generated.sequences)[0]
print('=' * 60)
print(f'new generated token ids: {_token_ids}')
# Compare whole sequences. `==` on arrays is elementwise, so it prints an array
# rather than a verdict, and depending on the NumPy version it warns or raises
# on a shape mismatch. `np.array_equal` returns a single bool.
print(f'token_ids == new_token_ids: {np.array_equal(token_ids, _token_ids)}')