import sys, os
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)

from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel
from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel

# Main model - ViTGPT2LM
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration

# ViT - as encoder
from transformers import ViTFeatureExtractor
from PIL import Image
import requests

import numpy as np
import jax
import jax.numpy as jnp

# GPT2+LM - as decoder
from transformers import GPT2Tokenizer

max_length = 8

# ================================================================================
# Models preparation

vision_model_name = 'google/vit-base-patch16-224-in21k'
text_model_name = 'asi/gpt-fr-cased-small'

project_encoder = False
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    vision_pretrained_model_name_or_path=vision_model_name,
    text_pretrained_model_name_or_path=text_model_name,
    project_encoder=project_encoder,
)
model = flax_vit_gpt2_lm

# ================================================================================
# Inputs preparation

feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)

# encoder data
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
# batch dim is added automatically
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values
print('=' * 60)
print(f'pixel_values.shape = {pixel_values.shape}')

# decoder data
sentence = 'mon chien est mignon'
# IMPORTANT: append the EOS token (needed for training/evaluation/attention_mask/loss)
sentence += ' ' + tokenizer.eos_token
# batch dim is added automatically

# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
    labels = tokenizer(
        sentence, max_length=max_length, padding="max_length", truncation=True, return_tensors="np"
    )


def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


decoder_input_ids = shift_tokens_right(
    jnp.array(labels["input_ids"]),
    model.config.text_config.pad_token_id,
    model.config.decoder_start_token_id,
)
decoder_input_ids = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens in the loss
decoder_attention_mask = labels["attention_mask"]

print('=' * 60)
print(f'decoder_input_ids = {decoder_input_ids}')
print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
print(f'decoder_attention_mask = {decoder_attention_mask}')
print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')

# ================================================================================
# Check that `FlaxGPT2LMHeadModel` gives the same results in the new version
# (when no `encoder_outputs` is provided).

orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
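# Illustrative extra sanity check (not part of the original script), assuming the modified
# `FlaxGPT2LMHeadModel` can still be called on plain `input_ids` without encoder hidden states:
# its logits on the same prefix should match the original implementation before we compare generation.
_orig_logits = orig_gpt2_lm(decoder_input_ids[:, 0:3]).logits
_new_logits = gpt2_lm(decoder_input_ids[:, 0:3]).logits
print('=' * 60)
print(f'GPT2-LM logits difference (orig vs. new): {np.sum(np.abs(np.asarray(_orig_logits) - np.asarray(_new_logits)))}')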
# generation!
num_beams = 1
gen_kwargs = {"max_length": 6, "num_beams": num_beams}

orig_gpt2_generated = orig_gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
gpt2_generated = gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)

orig_token_ids = np.array(orig_gpt2_generated.sequences)[0]
token_ids = np.array(gpt2_generated.sequences)[0]

orig_caption = tokenizer.decode(orig_token_ids)
caption = tokenizer.decode(token_ids)

print('=' * 60)
print(f'orig. GPT2 generated token ids: {orig_token_ids}')
print(f'GPT2 generated token ids: {token_ids}')
print('=' * 60)
print(f'orig. GPT2 generated caption: {orig_caption}')
print(f'GPT2 generated caption: {caption}')

assert list(orig_token_ids) == list(token_ids)

# ================================================================================
# model data
model_inputs = {
    'pixel_values': pixel_values,
    'attention_mask': None,
    'decoder_input_ids': decoder_input_ids,
    'decoder_attention_mask': decoder_attention_mask,
    'decoder_position_ids': None,
}

# ================================================================================
# Check `model.__call__()`

# Model call
model_outputs = model(**model_inputs)
logits = model_outputs[0]
preds = np.argmax(logits, axis=-1)
print('=' * 60)
print('Flax ViT-GPT2-LM - predicted token ids:')
print(preds)

encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
print('=' * 60)
print("encoder_last_hidden_state given by model.__call__():")
print(encoder_last_hidden_state)

encoder_outputs = model.encode(pixel_values, return_dict=True)
print('=' * 60)
print("encoder's last_hidden_state given by model.encode():")
print(encoder_outputs['last_hidden_state'])

total_diff = np.sum(np.abs(encoder_outputs['last_hidden_state'] - encoder_last_hidden_state))
print('=' * 60)
print(f"total difference: {total_diff}")

# ================================================================================
# Check model generation

# generation
num_beams = 1
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

batch = {'pixel_values': pixel_values}
generated = model.generate(batch['pixel_values'], **gen_kwargs)
token_ids = np.array(generated.sequences)[0]
print('=' * 60)
print(f'generated token ids: {token_ids}')
caption = tokenizer.decode(token_ids)
print('=' * 60)
print(f'generated caption: {caption}')

# ================================================================================
# Check save & load

# save
os.makedirs('./model/', exist_ok=True)
model.save_pretrained(save_directory='./model/')

# load
_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./model/')

# check if the result is the same as before
_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
_token_ids = np.array(_generated.sequences)[0]
print('=' * 60)
print(f'new generated token ids: {_token_ids}')
print(f'token_ids == new_token_ids: {token_ids == _token_ids}')

# ================================================================================
# Check the PyTorch version's output - it should be the same as above

import torch
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel

vision_model_pt = ViTModel.from_pretrained(vision_model_name)

config = GPT2Config.from_pretrained(text_model_name)
# config.is_encoder_decoder = True
config.add_cross_attention = True
text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)

encoder_pt_inputs = feature_extractor(images=image, return_tensors="pt")
encoder_pt_outputs = vision_model_pt(**encoder_pt_inputs)
encoder_hidden_states = encoder_pt_outputs.last_hidden_state
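# Illustrative extra check (not part of the original script): the PyTorch ViT encoder output
# should agree with the Flax encoder output computed earlier, up to small numerical error.
pt_vs_flax_encoder_diff = np.sum(
    np.abs(encoder_hidden_states.detach().numpy() - np.asarray(encoder_last_hidden_state))
)
print('=' * 60)
print(f'total difference between PyTorch and Flax encoder outputs: {pt_vs_flax_encoder_diff}')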
# model data
text_model_pt_inputs = {
    'input_ids': torch.tensor(decoder_input_ids, dtype=torch.long),
    'attention_mask': torch.tensor(decoder_attention_mask, dtype=torch.long),
    'position_ids': None,
    'encoder_hidden_states': encoder_hidden_states,
}

# Model call
text_model_pt_outputs = text_model_pt(**text_model_pt_inputs)
logits = text_model_pt_outputs[0]
preds = np.argmax(logits.detach().numpy(), axis=-1)
print('=' * 60)
print('PyTorch: ViT --> GPT2-LM')
print('predicted token ids:')
print(preds)

model_logits = np.array(model_outputs.logits)
text_model_pt_logits = text_model_pt_outputs.logits.detach().cpu().numpy()

total_diff = np.sum(np.abs(model_logits - text_model_pt_logits))

print('=' * 60)
print("model_logits:")
print(model_logits)
print('=' * 60)
print("text_model_pt_logits:")
print(text_model_pt_logits)
print('=' * 60)
print(f"total difference between logits: {total_diff}")
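# Illustrative extra check (not part of the original script): if the logits agree closely, the
# greedy (argmax) predictions from the Flax model and from the PyTorch ViT -> GPT2-LM pipeline
# should coincide as well.
flax_preds = np.argmax(model_logits, axis=-1)
print('=' * 60)
print(f'Flax preds == PyTorch preds: {np.array_equal(flax_preds, preds)}')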