Vivek committed on
Commit 42bbcd0
1 Parent(s): 2afb876
Files changed (2)
  1. .DS_Store +0 -0
  2. src/test_piqa.py +78 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
src/test_piqa.py ADDED
@@ -0,0 +1,78 @@
+ import jax
+ print(jax.local_device_count())
+ import jax.numpy as jnp
+
+ import flax
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+
+ from transformers import GPTNeoConfig
+ from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoPreTrainedModel
+ from transformers import GPT2Tokenizer
+
+ from datasets import load_dataset
+ import pandas as pd
+
+ num_choices = 2
+ dataset = load_dataset("piqa")
+
+ # Expand each PIQA example into (goal, solution) pairs, one per answer choice.
+ def preprocess(example):
+     example['first_sentence'] = [example['goal']] * num_choices
+     example['second_sentence'] = [example[f'sol{i}'] for i in [1, 2]]
+     return example
+
+ test_dataset = dataset['test'].map(preprocess)
+
+ len_test_dataset = 100  # evaluate on the first 100 test examples only
+
+ test_dataset = test_dataset.select(range(len_test_dataset))
+
+ tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B', pad_token='<|endoftext|>')
+
+ remove_col = test_dataset.column_names
+
+ def tokenize(examples):
+     tokenized_examples = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=512, return_tensors='jax')
+     return tokenized_examples
+
+ test_dataset = test_dataset.map(tokenize)
+
+ test_dataset = test_dataset.remove_columns(remove_col)
+ list1 = []  # records each batch's shuffled indices so predictions can be mapped back
+
+ def glue_test_data_loader(rng, dataset, batch_size):
+     steps_per_epoch = len_test_dataset // batch_size
+     perms = jax.random.permutation(rng, len_test_dataset)
+     perms = perms[:steps_per_epoch * batch_size]  # drop the incomplete last batch
+     perms = perms.reshape((steps_per_epoch, batch_size))
+     for perm in perms:
+         list1.append(perm)
+         batch = dataset[perm]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+         # batch = shard(batch)
+         yield batch
+
+ seed = 0
+ rng = jax.random.PRNGKey(seed)
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+ input_id = jnp.array(test_dataset['input_ids'])
+ att_mask = jnp.array(test_dataset['attention_mask'])
+
+ total_batch_size = 16
+
+ from model_file import FlaxGPTNeoForMultipleChoice
+
+ model = FlaxGPTNeoForMultipleChoice.from_pretrained('Vivek/gptneo_hellaswag', input_shape=(1, num_choices, 1))
+
+ restored_output = []
+ rng, input_rng = jax.random.split(rng)
+ for idx, batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_batch_size)):
+     outputs = model(batch['input_ids'], batch['attention_mask'])
+     final_output = jnp.argmax(outputs, axis=-1)  # index of the highest-scoring choice
+     restored_output.append(final_output)
+
+ finall = pd.DataFrame({'predictions': restored_output, 'permutation': list1})
+ finall.to_csv('./piqa_predictions.csv')
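
Because the loader shuffles the test set and drops the incomplete final batch, piqa_predictions.csv stores one batch-sized array per row, keyed by the shuffled indices in the permutation column. A minimal post-processing sketch for mapping predictions back to dataset order, assuming the in-memory restored_output and list1 built by the loop above (the variable names and output filename here are illustrative):

import numpy as np

# Flatten the per-batch prediction arrays and their matching shuffled indices.
flat_indices = np.concatenate([np.asarray(p) for p in list1])
flat_predictions = np.concatenate([np.asarray(o) for o in restored_output])

# Sort by the original example index to restore dataset order
# (the last 100 % 16 = 4 examples were dropped by the loader).
ordered = (
    pd.DataFrame({'example_index': flat_indices, 'prediction': flat_predictions})
    .sort_values('example_index')
    .reset_index(drop=True)
)
ordered.to_csv('./piqa_predictions_ordered.csv', index=False)  # illustrative filename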