import logging
from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp

import flax
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze

import transformers
from transformers import GPT2Config, GPT2Tokenizer

from datasets import Dataset, load_dataset, load_metric

# Custom Flax GPT-2 model with a multiple-choice head, defined locally.
from model_file import FlaxGPT2ForMultipleChoice

# basicConfig installs a stream handler; setLevel alone would leave INFO
# records unprinted, since the root logger has no handler by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

# Show how many local accelerator devices JAX can see.
print(jax.local_device_count())

# GPT-2 has no pad token of its own, so reuse the end-of-text token for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')

# Evaluation data from CSV: expected columns are id, context, question,
# answer0-answer3, and label (the gold answer index).
run_dataset = Dataset.from_csv('......')

def preprocess(example):
    # Concatenate context and question into a single prompt, then pair it
    # with each of the four candidate answers.
    example['context&question'] = example['context'] + example['question']
    example['first_sentence'] = [example['context&question']] * 4
    example['second_sentence'] = [example['answer0'], example['answer1'],
                                  example['answer2'], example['answer3']]
    return example

run_dataset = run_dataset.map(preprocess)

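# After preprocess, every row holds four (context+question, answer_i) string
# pairs in 'first_sentence'/'second_sentence', ready for pairwise tokenization.
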
def tokenize(examples):
    # Encode the four (prompt, answer) pairs of one example as fixed-length
    # blocks: 'input_ids' and 'attention_mask' each come out as (4, 256).
    tokenized = tokenizer(examples['first_sentence'], examples['second_sentence'],
                          padding='max_length', truncation=True, max_length=256,
                          return_tensors='jax')
    tokenized['labels'] = examples['label']
    return tokenized

run_dataset = run_dataset.map(tokenize)

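# Each mapped row now stores a (4, 256) block of 'input_ids' and
# 'attention_mask' (datasets keeps them as nested lists until the
# jnp.array calls below).
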
remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2',
             'answer3', 'labels', 'context&question', 'first_sentence',
             'second_sentence']

run_dataset = run_dataset.remove_columns(remov_col)

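# What survives remove_columns: 'input_ids', 'attention_mask', and the raw
# 'label' column (only the tokenizer's 'labels' copy is in remov_col).
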
# input_shape=(1, 4, 1) presumably stands for (batch, num_choices, seq_len),
# matching the (4, 256) encodings built above.
model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos",
                                                  input_shape=(1, 4, 1))

input_id = jnp.array(run_dataset['input_ids'])
att_mask = jnp.array(run_dataset['attention_mask'])

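# Shape sanity check: one (num_choices, seq_len) = (4, 256) block per example.
assert input_id.shape[1:] == (4, 256), input_id.shape
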
# Single forward pass over the whole dataset; the multiple-choice head is
# assumed to return one logit per answer choice, i.e. shape (num_examples, 4).
outputs = model(input_id, att_mask)

# The predicted answer for each question is the index of the highest logit.
final_output = jnp.argmax(outputs, axis=-1)

logger.info(f"predictions for the dataset: {final_output}")
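
# Optional accuracy check, a sketch assuming the 'label' column kept above
# holds the gold answer index for every row:
accuracy = jnp.mean(final_output == jnp.array(run_dataset['label']))
logger.info(f"accuracy: {accuracy}")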