# gptneo_piqa/src/test_piqa.py
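# Evaluation script: run the fine-tuned FlaxGPTNeoForMultipleChoice model on the
# PIQA test split and write its predictions, together with the batching
# permutation, to piqa_predictions.csv.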
import jax
import jax.numpy as jnp
import pandas as pd
from datasets import load_dataset
from transformers import GPT2Tokenizer

# Report how many accelerator devices JAX can see on this host.
print(jax.local_device_count())
num_choices = 2
dataset = load_dataset("piqa")

def preprocess(example):
    # Pair the goal with each of the two candidate solutions.
    example['first_sentence'] = [example['goal']] * num_choices
    example['second_sentence'] = [example[f'sol{i}'] for i in [1, 2]]
    return example

test_dataset = dataset['test'].map(preprocess)
len_test_dataset = 3084
test_dataset = test_dataset.select(range(len_test_dataset))
# GPT-Neo has no dedicated padding token, so reuse the end-of-text token.
tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B', pad_token='<|endoftext|>')
remove_col = test_dataset.column_names

def tokenize(examples):
    # Tokenize the (goal, solution) pair for both choices of one example.
    tokenized_examples = tokenizer(examples['first_sentence'], examples['second_sentence'],
                                   padding='max_length', truncation=True, max_length=512,
                                   return_tensors='jax')
    return tokenized_examples

test_dataset = test_dataset.map(tokenize)
test_dataset = test_dataset.remove_columns(remove_col)
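# After tokenization each example carries only 'input_ids' and 'attention_mask',
# both of shape (num_choices, 512); the original text columns were removed above.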
list1 = []

def glue_test_data_loader(rng, dataset, batch_size):
    # Yield shuffled batches of the test set. Each batch's permutation indices are
    # stored in list1 so predictions can later be mapped back to the original order.
    steps_per_epoch = len_test_dataset // batch_size
    perms = jax.random.permutation(rng, len_test_dataset)
    perms = perms[:steps_per_epoch * batch_size]  # drop the incomplete final batch
    perms = perms.reshape((steps_per_epoch, batch_size))
    for perm in perms:
        list1.append(perm)
        batch = dataset[perm]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield batch
seed = 0
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
input_id = jnp.array(test_dataset['input_ids'])
att_mask = jnp.array(test_dataset['attention_mask'])
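# Note: dropout_rngs, input_id and att_mask are prepared above but are not
# consumed by the evaluation loop below.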
total_batch_size = 16
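# With 3084 test examples and a batch size of 16, the loader yields
# 3084 // 16 = 192 full batches and drops the remaining 12 examples.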
# FlaxGPTNeoForMultipleChoice is the custom multiple-choice head defined in model_file.py.
from model_file import FlaxGPTNeoForMultipleChoice

model = FlaxGPTNeoForMultipleChoice.from_pretrained('Vivek/gptneo_piqa', input_shape=(1, num_choices, 1))
restored_output = []
rng, input_rng = jax.random.split(rng)

for idx, batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_batch_size)):
    outputs = model(batch['input_ids'], batch['attention_mask'])
    # Select the choice with the highest logit for every example in the batch.
    final_output = jnp.argmax(outputs, axis=-1)
    restored_output.append(final_output)
# Flatten the per-batch arrays so the CSV holds one prediction per original dataset index.
finall = pd.DataFrame({'predictions': jnp.concatenate(restored_output).tolist(),
                       'permutation': jnp.concatenate(list1).tolist()})
finall.to_csv('./piqa_predictions.csv')
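# Optional post-processing sketch (not part of the original evaluation): restore the
# shuffled predictions to the dataset's original example order via the 'permutation'
# column written above.
df = pd.read_csv('./piqa_predictions.csv')
df = df.sort_values('permutation')                 # undo the loader's shuffling
ordered_predictions = df['predictions'].tolist()   # ordered by original dataset index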