Vivek committed on
Commit: d7d81d6
1 Parent(s): fd148ff

test cosmos

Files changed (2):
  1. .DS_Store +0 -0
  2. src/testcosmos.py +79 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
src/testcosmos.py ADDED
@@ -0,0 +1,79 @@
+ import jax
+ print(jax.local_device_count())
+ import jax.numpy as jnp
+
+ import flax
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+
+ from transformers import GPTNeoConfig
+ from transformers.models.gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoPreTrainedModel
+ from transformers import GPT2Tokenizer
+
+ from datasets import load_dataset
+ import pandas as pd
+
+ num_choices = 4
+ dataset = load_dataset("cosmos_qa")
+
+ # Build the multiple-choice inputs: the same context+question string is
+ # paired with each of the four candidate answers.
+ def preprocess(example):
+     example['context&question'] = example['context'] + example['question']
+     example['first_sentence'] = [example['context&question']] * num_choices
+     example['second_sentence'] = [example[f'answer{i}'] for i in range(num_choices)]
+     return example
+
+ test_dataset = dataset['test'].map(preprocess)
+
+ len_test_dataset = 100
+
+ test_dataset = test_dataset.select(range(len_test_dataset))
+
+ tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-1.3B', pad_token='<|endoftext|>')
+
+ remove_col = test_dataset.column_names
+
+ def tokenize(examples):
+     tokenized_examples = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=256, return_tensors='jax')
+     return tokenized_examples
+
+ test_dataset = test_dataset.map(tokenize)
+
+ test_dataset = test_dataset.remove_columns(remove_col)
+ list1 = []
+
+ # Yield permuted batches; the permutation used for each batch is recorded in
+ # list1 so predictions can later be mapped back to dataset rows.
+ def glue_test_data_loader(rng, dataset, batch_size):
+     steps_per_epoch = len_test_dataset // batch_size
+     perms = jax.random.permutation(rng, len_test_dataset)
+     perms = perms[:steps_per_epoch * batch_size]  # drop examples that don't fill a complete batch
+     perms = perms.reshape((steps_per_epoch, batch_size))
+     for perm in perms:
+         list1.append(perm)
+         batch = dataset[perm]
+         # print(jnp.array(batch['label']))
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+         # batch = shard(batch)
+         yield batch
+
+ seed = 0
+ rng = jax.random.PRNGKey(seed)
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+ input_id = jnp.array(test_dataset['input_ids'])
+ att_mask = jnp.array(test_dataset['attention_mask'])
+
+ total_batch_size = 16
+
+ from model_file import FlaxGPTNeoForMultipleChoice
+
+ model = FlaxGPTNeoForMultipleChoice.from_pretrained('Vivek/gptneo_hellaswag', input_shape=(1, num_choices, 1))
+
+ # Run inference batch by batch and keep the argmax over the choice logits.
+ restored_output = []
+ rng, input_rng = jax.random.split(rng)
+ for idx, batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_batch_size)):
+     outputs = model(batch['input_ids'], batch['attention_mask'])
+     final_output = jnp.argmax(outputs, axis=-1)
+     restored_output.append(final_output)
+
+ finall = pd.DataFrame({'predictions': restored_output, 'permutation': list1})
+ finall.to_csv('./cosmos_predictions.csv')
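
The loader above visits examples in a permuted order, so mapping the saved predictions back to dataset rows requires the recorded permutations. A minimal post-processing sketch, using synthetic stand-ins for restored_output and list1 (the flat output path and column names are illustrative, not part of this commit):

    import numpy as np
    import pandas as pd

    # Synthetic stand-ins for the script's outputs: two batches of four
    # predictions and the permutation indices the loader used for each batch.
    restored_output = [np.array([1, 0, 3, 2]), np.array([0, 2, 1, 3])]
    list1 = [np.array([5, 2, 7, 0]), np.array([3, 6, 1, 4])]

    preds = np.concatenate(restored_output)  # predictions in loader order
    order = np.concatenate(list1)            # original dataset row indices

    # One row per example, sorted back into dataset order.
    flat = pd.DataFrame({'example_index': order, 'prediction': preds})
    flat = flat.sort_values('example_index')
    flat.to_csv('./cosmos_predictions_flat.csv', index=False)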