File size: 2,796 Bytes
a63e06d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import jax
# Report how many local accelerator devices JAX sees.
print(jax.local_device_count())
import jax.numpy as jnp

import flax
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.training.common_utils import get_metrics,onehot,shard,shard_prng_key

from typing import Any, Optional, Tuple

from transformers import (
    GPT2Config)

import transformers
from transformers import GPT2Tokenizer
# GPT-2 ships without a pad token; reuse '<|endoftext|>' so that
# padding='max_length' in tokenize() has a token to pad with.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2",pad_token='<|endoftext|>')
from datasets import load_dataset,load_metric

from model_file import FlaxGPT2ForMultipleChoice

import logging

# Without a handler on the root logger, INFO records fall through to
# logging.lastResort, which only emits WARNING and above, so the final
# prediction log would never be shown. basicConfig attaches a stderr handler;
# it is a no-op if an imported library already configured the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Cosmos QA from the Hugging Face hub; only the 'test' split is used below.
dataset=load_dataset('cosmos_qa')

# Number of test rows to predict on (presumably the whole Cosmos QA test
# split -- verify against dataset['test'].num_rows).
len_test_dataset=6963

test_dataset=dataset['test'].select(range(len_test_dataset))

def preprocess(example):
    """Attach multiple-choice text pairs to one Cosmos QA row, in place.

    The context and question are joined (no separator) and that text is
    repeated once per answer option as the first sentence of each pair; the
    four answers form the matching second sentences.

    Args:
        example: mutable mapping with 'context', 'question' and
            'answer0'..'answer3' entries.

    Returns:
        The same ``example`` object, with 'context&question',
        'first_sentence' (list of 4 identical strings) and
        'second_sentence' (tuple of the 4 answers) added.
    """
    answer_keys = ('answer0', 'answer1', 'answer2', 'answer3')
    joined = example['context'] + example['question']

    example['context&question'] = joined
    example['first_sentence'] = [joined] * len(answer_keys)
    example['second_sentence'] = tuple(example[key] for key in answer_keys)
    return example

# Build the (context+question, answer_i) text pairs for every test row.
test_dataset = test_dataset.map(function=preprocess)

def tokenize(examples):
    """Tokenize one preprocessed row's four sentence pairs.

    Uses the module-level GPT-2 ``tokenizer`` to encode
    ``first_sentence[i]`` together with ``second_sentence[i]`` for each
    answer option, padded/truncated to exactly 256 tokens, and copies the
    row's 'label' into a 'labels' entry.

    Args:
        examples: a single row (this is mapped without ``batched=True``)
            carrying 'first_sentence', 'second_sentence' and 'label'.

    Returns:
        The tokenizer's encoding (input_ids, attention_mask, ... as JAX
        arrays) plus 'labels'.
    """
    encoding = tokenizer(
        examples['first_sentence'],
        examples['second_sentence'],
        padding='max_length',
        truncation=True,
        max_length=256,
        return_tensors='jax',
    )
    encoding['labels'] = examples['label']
    return encoding

# Encode every row's four (context+question, answer_i) pairs.
test_dataset = test_dataset.map(tokenize)

# Drop the raw text, the helper columns created by preprocess(), and the
# 'labels' copy added by tokenize(); input_ids / attention_mask (and the
# original 'label' column) remain.
remov_col = [
    'id',
    'context',
    'question',
    'answer0',
    'answer1',
    'answer2',
    'answer3',
    'labels',
    'context&question',
    'first_sentence',
    'second_sentence',
]
test_dataset = test_dataset.remove_columns(remov_col)

seed = 0
total_batch_size = 32

# Fine-tuned multiple-choice GPT-2 checkpoint; input_shape is the dummy
# shape used to initialise the parameters.
model = FlaxGPT2ForMultipleChoice.from_pretrained(
    "flax-community/gpt2-Cosmos",
    input_shape=(1, 4, 1),
)

def glue_train_data_loader(rng, dataset, batch_size, shuffle=True, drop_last=True):
    """Yield batches of ``dataset`` as ``jnp`` arrays passed through ``shard``.

    The defaults keep the training-style behaviour this loader was adapted
    from: rows are visited in a random order drawn from ``rng``, and a
    trailing partial batch is dropped. For prediction, pass ``shuffle=False``
    and ``drop_last=False`` so every row is visited once, in dataset order.

    Args:
        rng: JAX PRNG key for the permutation; unused when ``shuffle`` is False.
        dataset: object supporting ``len()`` and lookup by a list of row
            indices that returns a mapping of column name -> values (as a
            Hugging Face ``Dataset`` does).
        batch_size: rows per yielded batch. ``shard`` splits the leading axis
            across local devices, so this presumably must be a multiple of
            ``jax.local_device_count()`` -- TODO confirm.
        shuffle: visit rows in random order (True) or in dataset order (False).
        drop_last: drop the trailing partial batch (True), or pad it up to
            ``batch_size`` by repeating its last row (False). When padding,
            the caller must discard the extra trailing predictions.

    Yields:
        dict of column name -> sharded ``jnp`` array, one per batch.
    """
    # Size from the dataset actually passed in. The original read the global
    # len_test_dataset, which is wrong for any other dataset.
    num_rows = len(dataset)
    if shuffle:
        order = jax.random.permutation(rng, num_rows)
    else:
        order = jnp.arange(num_rows)

    num_steps = num_rows // batch_size
    if not drop_last and num_rows % batch_size:
        num_steps += 1  # one extra, padded batch for the leftover rows

    for step in range(num_steps):
        indices = order[step * batch_size:(step + 1) * batch_size].tolist()
        # Only the final batch can be short (and only when drop_last=False).
        # Repeat its last row so every batch has the same leading size and
        # shard() can split it evenly.
        indices += [indices[-1]] * (batch_size - len(indices))
        batch = dataset[indices]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        yield shard(batch)

rng = jax.random.PRNGKey(seed)
rng, input_rng = jax.random.split(rng)

# Predict in dataset order and keep the trailing partial batch, so that
# prediction i belongs to row i of test_dataset. The original shuffled the
# test set (discarding the permutation) and dropped
# len(test_dataset) % total_batch_size rows, so its predictions could not be
# matched back to examples.
restored_output = []
for batch in glue_train_data_loader(input_rng, test_dataset, total_batch_size, shuffle=False, drop_last=False):
    outputs = model(batch['input_ids'], batch['attention_mask'])
    # argmax over the last axis (presumably the answer choices -- confirm
    # against model_file). Flattening undoes shard()'s leading device axis,
    # assuming the model keeps the (device, per-device batch) leading axes or
    # returns (batch, choices); either way each batch then contributes
    # total_batch_size predictions in row order.
    final_output = jnp.argmax(outputs, axis=-1).reshape(-1)
    restored_output.append(final_output)

# Discard the predictions for the padding rows appended to the last batch.
predictions = jnp.concatenate(restored_output)[:len(test_dataset)]

logger.info(f"the prediction of the test dataset : {predictions[:30]}")