# gpt2-Cosmos/src/prediction.py
import logging

import jax
import jax.numpy as jnp

from datasets import Dataset
from transformers import GPT2Tokenizer

from model_file import FlaxGPT2ForMultipleChoice

# basicConfig attaches a handler; a bare setLevel() on the root logger would
# leave INFO messages unprinted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

logger.info("Local JAX device count: %d", jax.local_device_count())

# GPT-2 has no dedicated pad token, so reuse the end-of-text token for padding.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
run_dataset = Dataset.from_csv('......')  # CSV path elided in the original source
def preprocess(example):
    # Pair the shared context+question prompt with each of the four answer choices.
    example['context&question'] = example['context'] + example['question']
    example['first_sentence'] = [example['context&question']] * 4
    example['second_sentence'] = [example['answer0'], example['answer1'],
                                  example['answer2'], example['answer3']]
    return example
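# Expand every row into four (prompt, answer) pairs, one per candidate answer.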
run_dataset = run_dataset.map(preprocess)
def tokenize(examples):
    # Tokenize the four (prompt, answer) pairs of one example into a (4, 256) block.
    tokenized = tokenizer(examples['first_sentence'], examples['second_sentence'],
                          padding='max_length', truncation=True,
                          max_length=256, return_tensors='jax')
    tokenized['labels'] = examples['label']
    return tokenized
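# Tokenize row by row; each example gains 'input_ids' and 'attention_mask'
# columns of shape (4, 256).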
run_dataset = run_dataset.map(tokenize)

# Keep only the tokenized inputs and the original 'label' column; everything
# else, including the 'labels' copy added above, is no longer needed.
remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3',
             'labels', 'context&question', 'first_sentence', 'second_sentence']
run_dataset = run_dataset.remove_columns(remov_col)
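# Load the fine-tuned multiple-choice model from the Hugging Face Hub;
# input_shape is the dummy input shape used at weight initialization
# (here: batch of 1, four choices, sequence length 1).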
model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos", input_shape=(1, 4, 1))
input_ids = jnp.array(run_dataset['input_ids'])
attention_mask = jnp.array(run_dataset['attention_mask'])

outputs = model(input_ids, attention_mask)
# The highest-scoring choice index is the predicted answer for each example.
final_output = jnp.argmax(outputs, axis=-1)
logger.info(f"Predictions for the dataset: {final_output}")