Vivek commited on
Commit
2c8f24a
1 Parent(s): a6d9333
Files changed (1) hide show
  1. src/test_hellaswag.py +5 -5
src/test_hellaswag.py CHANGED
@@ -39,6 +39,7 @@ def tokenize(examples):
39
  test_dataset=test_dataset.map(tokenize)
40
 
41
  test_dataset=test_dataset.remove_columns(remove_col)
 
42
 
43
  def glue_test_data_loader(rng,dataset,batch_size):
44
  steps_per_epoch=len_test_dataset//batch_size
@@ -46,10 +47,11 @@ def glue_test_data_loader(rng,dataset,batch_size):
46
  perms=perms[:steps_per_epoch*batch_size]
47
  perms=perms.reshape((steps_per_epoch,batch_size))
48
  for perm in perms:
 
49
  batch=dataset[perm]
50
  #print(jnp.array(batch['label']))
51
  batch={k:jnp.array(v) for k,v in batch.items()}
52
- batch=shard(batch)
53
  yield batch
54
 
55
  seed=0
@@ -59,7 +61,7 @@ dropout_rngs=jax.random.split(rng,jax.local_device_count())
59
  input_id=jnp.array(test_dataset['input_ids'])
60
  att_mask=jnp.array(test_dataset['attention_mask'])
61
 
62
- total_batch_size=32
63
 
64
  from model_file import FlaxGPTNeoForMultipleChoice
65
 
@@ -69,12 +71,10 @@ restored_output=[]
69
  rng, input_rng = jax.random.split(rng)
70
  for idx,batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_batch_size)):
71
  outputs=model(batch['input_ids'],batch['attention_mask'])
72
- #outputs=outputs['logits'].reshape(total_batch_size,-1)
73
- print(outputs.shape)
74
  final_output=jnp.argmax(outputs,axis=-1)
75
  restored_output.append(final_output)
76
 
77
- finall=pd.DataFrame({'predictions':restored_output})
78
  finall.to_csv('../predictions.csv')
79
 
80
 
 
39
  test_dataset=test_dataset.map(tokenize)
40
 
41
  test_dataset=test_dataset.remove_columns(remove_col)
42
+ list1=[]
43
 
44
  def glue_test_data_loader(rng,dataset,batch_size):
45
  steps_per_epoch=len_test_dataset//batch_size
 
47
  perms=perms[:steps_per_epoch*batch_size]
48
  perms=perms.reshape((steps_per_epoch,batch_size))
49
  for perm in perms:
50
+ list1.append(perm)
51
  batch=dataset[perm]
52
  #print(jnp.array(batch['label']))
53
  batch={k:jnp.array(v) for k,v in batch.items()}
54
+ #batch=shard(batch)
55
  yield batch
56
 
57
  seed=0
 
61
  input_id=jnp.array(test_dataset['input_ids'])
62
  att_mask=jnp.array(test_dataset['attention_mask'])
63
 
64
+ total_batch_size=16
65
 
66
  from model_file import FlaxGPTNeoForMultipleChoice
67
 
 
71
  rng, input_rng = jax.random.split(rng)
72
  for idx,batch in enumerate(glue_test_data_loader(input_rng, test_dataset, total_batch_size)):
73
  outputs=model(batch['input_ids'],batch['attention_mask'])
 
 
74
  final_output=jnp.argmax(outputs,axis=-1)
75
  restored_output.append(final_output)
76
 
77
+ finall=pd.DataFrame({'predictions':restored_output,'permutation':list1})
78
  finall.to_csv('../predictions.csv')
79
 
80