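# Inference script: run the Flax GPT-2 multiple-choice model
# (flax-community/gpt2-Cosmos) over a CosmosQA-style CSV and log the
# predicted answer index for every example.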
import logging

import jax
import jax.numpy as jnp
from datasets import Dataset
from transformers import GPT2Tokenizer

from model_file import FlaxGPT2ForMultipleChoice

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Report how many accelerator devices JAX can see.
print(jax.local_device_count())

# GPT-2 has no dedicated pad token, so reuse the end-of-text token for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
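# Load the evaluation data; each row holds a context, a question, four
# candidate answers (answer0..answer3), and a gold label.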
run_dataset = Dataset.from_csv('......')
def preprocess(example):
    # Concatenate context and question into a single prompt, then pair that
    # prompt with each of the four candidate answers.
    example['context&question'] = example['context'] + example['question']
    example['first_sentence'] = [example['context&question']] * 4
    example['second_sentence'] = [example['answer0'], example['answer1'],
                                  example['answer2'], example['answer3']]
    return example

run_dataset = run_dataset.map(preprocess)
def tokenize(examples):
    # Tokenize the four (prompt, answer) pairs; fixed-length padding keeps the
    # resulting arrays rectangular.
    tokenized = tokenizer(examples['first_sentence'], examples['second_sentence'],
                          padding='max_length', truncation=True, max_length=256,
                          return_tensors='jax')
    tokenized['labels'] = examples['label']
    return tokenized

run_dataset = run_dataset.map(tokenize)
# Drop every column except the tokenized inputs (the raw 'label' column
# survives, which allows a quick accuracy check at the end).
remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2',
             'answer3', 'labels', 'context&question', 'first_sentence',
             'second_sentence']
run_dataset = run_dataset.remove_columns(remov_col)
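# The (1, 4, 1) input shape appears to mirror a (batch, num_choices, seq_len)
# layout for the 4-way multiple-choice head; this reading is an assumption.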
model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos",
                                                  input_shape=(1, 4, 1))
input_id = jnp.array(run_dataset['input_ids'])
att_mask = jnp.array(run_dataset['attention_mask'])

# Score all four choices per example and pick the highest-scoring one.
outputs = model(input_id, att_mask)
final_output = jnp.argmax(outputs, axis=-1)
logger.info(f"predictions for the dataset: {final_output}")
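# Optional sanity check (a sketch, not part of the original script): the raw
# 'label' column still holds the gold answer index, so accuracy can be
# computed directly, assuming `outputs` holds the per-choice logits.
gold = jnp.array(run_dataset['label'])
accuracy = jnp.mean(final_output == gold)
logger.info(f"accuracy: {accuracy}")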