"""Run the fine-tuned GPT-Neo multiple-choice model on the PIQA test split.

The script loads the ``piqa`` dataset, pairs each goal with its candidate
solutions, tokenizes the pairs, runs ``FlaxGPTNeoForMultipleChoice`` over the
test examples in shuffled batches and writes the argmax predictions together
with the dataset indices of each batch to ``./piqa_predictions.csv``.
"""
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import pandas as pd
from datasets import load_dataset
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from transformers import GPT2Tokenizer
from transformers import GPTNeoConfig
from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoPreTrainedModel

from model_file import FlaxGPTNeoForMultipleChoice

print(jax.local_device_count())

# Number of candidate solutions per goal (the dataset columns sol1..solN).
num_choices = 2
dataset = load_dataset("piqa")


def preprocess(example):
    """Add paired sentence columns to one example.

    ``first_sentence`` repeats the goal once per choice and
    ``second_sentence`` lists the candidate solutions ``sol1..sol{num_choices}``
    so the two lists line up element by element for the tokenizer.
    """
    example['first_sentence'] = [example['goal']] * num_choices
    example['second_sentence'] = [
        example[f'sol{i}'] for i in range(1, num_choices + 1)
    ]
    return example


test_dataset = dataset['test'].map(preprocess)
len_test_dataset = 3084
test_dataset = test_dataset.select(range(len_test_dataset))

tokenizer = GPT2Tokenizer.from_pretrained(
    'EleutherAI/gpt-neo-1.3B', pad_token='<|endoftext|>'
)

# Captured before tokenization so that only the tokenizer outputs remain
# once these columns are removed below.
remove_col = test_dataset.column_names


def tokenize(examples):
    """Tokenize the (goal, solution) pairs of one example.

    Every pair is padded or truncated to exactly 512 tokens.
    """
    return tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors='jax',
    )


test_dataset = test_dataset.map(tokenize)
test_dataset = test_dataset.remove_columns(remove_col)

# Dataset indices of every batch produced by glue_test_data_loader, in the
# order the batches were yielded; filled as a side effect of iterating it.
list1 = []


def glue_test_data_loader(rng, dataset, batch_size):
    """Yield shuffled batches that together cover every row of ``dataset``.

    Args:
        rng: JAX PRNG key used to draw the permutation of row indices.
        dataset: Indexable dataset; ``dataset[indices]`` must return a
            mapping from column name to the selected values.
        batch_size: Number of examples per batch; must be positive.

    Yields:
        A dict mapping each column name to a ``jnp`` array for one batch.
        The last batch holds fewer than ``batch_size`` examples when the
        dataset size is not a multiple of ``batch_size``, so that no example
        is left without a prediction.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Side effects:
        Appends the index array of each yielded batch to the module-level
        ``list1`` so predictions can be mapped back to dataset rows.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    num_examples = len(dataset)
    perms = jax.random.permutation(rng, num_examples)
    for start in range(0, num_examples, batch_size):
        perm = perms[start:start + batch_size]
        list1.append(perm)
        batch = dataset[perm]
        yield {k: jnp.array(v) for k, v in batch.items()}


seed = 0
rng = jax.random.PRNGKey(seed)
total_batch_size = 16

model = FlaxGPTNeoForMultipleChoice.from_pretrained(
    'Vivek/gptneo_piqa', input_shape=(1, num_choices, 1)
)

restored_output = []
rng, input_rng = jax.random.split(rng)
for batch in glue_test_data_loader(input_rng, test_dataset, total_batch_size):
    outputs = model(batch['input_ids'], batch['attention_mask'])
    # Index of the highest-scoring entry along the last axis of the output.
    final_output = jnp.argmax(outputs, axis=-1)
    restored_output.append(final_output)

# One row per batch: the predicted indices and the dataset indices they
# belong to, in matching order.
finall = pd.DataFrame({'predictions': restored_output, 'permutation': list1})
finall.to_csv('./piqa_predictions.csv')