## Introduction

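This tutorial demonstrates how to evaluate an INT8-quantized GPT-J-6B ONNX model with ONNX Runtime. It measures LAMBADA last-token accuracy (one example at a time and in batches) and text-generation latency.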

## Prerequisites

In [None]:
!pip install onnx onnxruntime torch transformers datasets accelerate

## Run

### 1. Measure LAMBADA accuracy

In [None]:
from transformers import AutoTokenizer
import torch
import numpy as np
from datasets import load_dataset
import onnxruntime as ort
from torch.nn.functional import pad

# load model: only the tokenizer is taken from the Hugging Face checkpoint;
# the (quantized) network itself is the ONNX file loaded into the session below.
model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch of dataset rows with the global ``tokenizer``."""
    return tokenizer(examples['text'])

# create dataset: LAMBADA validation split, shuffled with a fixed seed (so the
# iteration order is reproducible), tokenized in batches, and exposed as torch
# tensors holding only the token ids.
dataset = (
    load_dataset('lambada', split='validation')
    .shuffle(seed=42)
    .map(tokenize_function, batched=True)
)
dataset.set_format(type='torch', columns=['input_ids'])

# create session
# ORT_ENABLE_ALL turns on every graph optimization; all execution providers
# available in this onnxruntime build are offered to the session.
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())
# Running counts of scored examples and of correct last-token predictions.
total, hit = 0, 0

# inference
# GPT-J-6B dimensions (n_layer=28, n_head=16, head_dim = n_embd 4096 / 16 = 256).
# They shape the dummy one-step past KV cache that the exported graph takes as input.
num_layers, num_heads, head_dim = 28, 16, 256

# The dummy cache is identical for every example, so build it once.
past_inputs = {}
for i in range(num_layers):
    past_inputs[f"past_key_values.{i}.key"] = np.zeros((1, num_heads, 1, head_dim), dtype='float32')
    past_inputs[f"past_key_values.{i}.value"] = np.zeros((1, num_heads, 1, head_dim), dtype='float32')

for batch in dataset:
    input_ids = batch['input_ids'].unsqueeze(0)  # add a batch axis of size 1
    # LAMBADA target: the final token of the passage.
    label = input_ids[:, -1]
    ort_inputs = {
        'input_ids': input_ids.detach().cpu().numpy(),
        # One mask entry per prompt token, plus one for the dummy past slot.
        # NOTE(review): the text-generation cell masks that slot (0), while this
        # cell and the batched cell attend to it (1). Kept as-is so the two
        # accuracy cells stay comparable; confirm which behavior is intended.
        'attention_mask': np.ones((1, input_ids.shape[1] + 1), dtype='int64'),
        **past_inputs,
    }
    predictions = session.run(None, ort_inputs)
    outputs = torch.from_numpy(predictions[0])
    # The logits at position t score token t + 1, so the second-to-last
    # position is the prediction for the label.
    last_token_logits = outputs[:, -2, :]
    pred = last_token_logits.argmax(dim=-1)
    total += label.size(0)
    hit += (pred == label).sum().item()

acc = hit / total
print('acc: ', acc)

In [None]:
# batch inference

from transformers import AutoTokenizer
import torch
import numpy as np
from datasets import load_dataset
import onnxruntime as ort
from torch.nn.functional import pad
from torch.utils.data import DataLoader

# Batching configuration: examples per ONNX call, and the fixed sequence
# length every example is padded (or truncated) to.
batch_size, pad_max = 2, 196

# load model (tokenizer only; the network weights live in the ONNX file below)
model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)

def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch of dataset rows with the global ``tokenizer``."""
    return tokenizer(examples['text'])

# create dataloader
class Dataloader:
    """Batch LAMBADA examples into fixed-length inputs for the ONNX model.

    Each example is right-padded with token id 1 to exactly ``pad_max`` tokens.
    The index of its last real token is returned alongside, so the caller can
    read the label and the logits that predict it.

    Iterating yields ``([input_ids, attention_mask], last_ind)``:
    ``input_ids`` is an int64 array of shape (batch, pad_max);
    ``attention_mask`` is an int64 array of shape (batch, pad_max + 1), all ones;
    ``last_ind`` is a 1-D int64 array of last-real-token indices.
    """

    def __init__(self, pad_max=196, batch_size=1, sub_folder='validation'):
        self.pad_max = pad_max
        self.batch_size = batch_size
        dataset = load_dataset('lambada', split=sub_folder)
        dataset = dataset.map(tokenize_function, batched=True)
        dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
        self.dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.collate_batch,
        )

    def collate_batch(self, batch):
        """Pad/truncate each example to ``pad_max`` tokens and stack the batch.

        Returns ``((input_ids, attention_mask), last_ind)`` as torch tensors.
        """
        input_ids_padded = []
        attention_mask_padded = []
        last_ind = []
        for text in batch:
            ids = text["input_ids"]
            # Keep the *last* pad_max tokens of an over-long example. The LAMBADA
            # target is the final token, so cutting the tail would throw away
            # the label and score an arbitrary mid-passage token instead.
            input_ids = ids[max(ids.shape[0] - self.pad_max, 0):]
            pad_len = self.pad_max - input_ids.shape[0]
            last_ind.append(input_ids.shape[0] - 1)
            input_ids_padded.append(pad(input_ids, (0, pad_len), value=1))
            # One extra entry covers the dummy one-step past KV cache the caller
            # feeds. Padding sits after every real token; presumably harmless
            # under GPT-J's causal attention (TODO confirm for the exported graph).
            attention_mask_padded.append(torch.ones(self.pad_max + 1))
        return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)

    def __iter__(self):
        # The for-loop itself consumes the underlying loader's StopIteration,
        # so no explicit handling is needed.
        for (input_ids, attention_mask), last_ind in self.dataloader:
            data = [
                input_ids.detach().cpu().numpy().astype('int64'),
                attention_mask.detach().cpu().numpy().astype('int64'),
            ]
            yield data, last_ind.detach().cpu().numpy()

# create session
# ORT_ENABLE_ALL turns on every graph optimization; every execution provider
# this onnxruntime build offers is passed to the session.
options = ort.SessionOptions()
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession(
    '/path/to/model.onnx',
    options,
    providers=ort.get_available_providers(),
)
total = hit = 0  # scored examples / correct last-token predictions

dataloader = Dataloader(batch_size=batch_size, pad_max=pad_max)

# inference
# GPT-J-6B dimensions (n_layer=28, n_head=16, head_dim = 4096 / 16 = 256).
# They shape the dummy one-step past KV cache the exported graph expects.
num_layers, num_heads, head_dim = 28, 16, 256

for batch, last_ind in dataloader:
    # Size the cache from the *actual* batch. DataLoader keeps the last partial
    # batch, which is smaller than batch_size whenever the split size is not a
    # multiple of it. A batch_size-shaped cache would then mismatch input_ids.
    cur_bs = len(last_ind)
    rows = np.arange(cur_bs)
    # Label = last real (non-pad) token of each example.
    label = torch.from_numpy(batch[0][rows, last_ind])
    # Number of pad tokens after each example's last real token.
    pad_len = pad_max - last_ind - 1
    ort_inputs = {
        'input_ids': batch[0],
        'attention_mask': batch[1],
    }
    for i in range(num_layers):
        ort_inputs[f"past_key_values.{i}.key"] = np.zeros((cur_bs, num_heads, 1, head_dim), dtype='float32')
        ort_inputs[f"past_key_values.{i}.value"] = np.zeros((cur_bs, num_heads, 1, head_dim), dtype='float32')

    predictions = session.run(None, ort_inputs)
    outputs = torch.from_numpy(predictions[0])
    # Counting from the end of the pad_max-long sequence, -2 - pad_len is the
    # position just before each example's last real token, i.e. the one whose
    # logits predict the label.
    last_token_logits = outputs[torch.arange(cur_bs), -2 - pad_len, :]
    pred = last_token_logits.argmax(dim=-1)
    total += len(label)
    hit += (pred == label).sum().item()

acc = hit / total
print('acc: ', acc)

### 2. Text Generation

In [None]:
import os
import time
import sys

# create session
# Pass providers explicitly, as the accuracy cells do. Since onnxruntime 1.9,
# InferenceSession raises ValueError when the build enables execution providers
# beyond CPU and no `providers` argument is given.
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session = ort.InferenceSession('/path/to/model.onnx', sess_options, providers=ort.get_available_providers())

# input prompt
# 32 tokens input
prompt = (
    "Once upon a time, there existed a little girl, who liked to have adventures."
    " She wanted to go to places and meet new people, and have fun."
)

print("prompt: ", prompt)

# Benchmark schedule: num_iter runs in total. The first num_warmup runs are
# excluded from the reported average latency.
num_iter, num_warmup = 10, 3
total_time = 0.0  # accumulated seconds over the timed (post-warm-up) runs

# start
max_new_tokens = 32  # tokens generated greedily per benchmark iteration
# GPT-J-6B dimensions (n_layer=28, n_head=16, head_dim = 4096 / 16 = 256) for
# the dummy one-step past KV cache fed on the first step.
num_layers, num_heads, head_dim = 28, 16, 256

for idx in range(num_iter):
    tic = time.perf_counter()  # monotonic, high-resolution interval timer

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # One mask entry for the dummy past slot plus one per prompt token. The
    # dummy slot is masked out (0) so its zero key/value is ignored.
    attention_mask = torch.ones(input_ids.shape[1] + 1)
    attention_mask[0] = 0
    attention_mask = attention_mask.unsqueeze(0)

    inp = {'input_ids': input_ids.detach().cpu().numpy(),
           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}
    for i in range(num_layers):
        inp[f"past_key_values.{i}.key"] = torch.zeros([1, num_heads, 1, head_dim]).detach().cpu().numpy()
        inp[f"past_key_values.{i}.value"] = torch.zeros([1, num_heads, 1, head_dim]).detach().cpu().numpy()

    for step in range(max_new_tokens):
        output = session.run(None, inp)
        # Greedy decoding: argmax over the last position's logits. Softmax is
        # monotonic, so it is not needed to pick the maximum.
        next_tokens = torch.from_numpy(output[0])[:, -1, :].argmax(dim=-1)

        # Presumably output[1:] holds the present (key, value) pairs in layer
        # order, matching the past_key_values inputs (TODO confirm against the
        # graph's output names). On the first step, drop the dummy past slot at
        # sequence index 0 so it is not carried forward.
        past_start = 1 if step == 0 else 0
        for i in range(num_layers):
            inp[f"past_key_values.{i}.key"] = output[2 * i + 1][:, :, past_start:, :]
            inp[f"past_key_values.{i}.value"] = output[2 * i + 2][:, :, past_start:, :]

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        # Keep the mask aligned with the cache: drop the dummy slot's entry on
        # the first step, then extend by one for the newly generated token.
        if step == 0:
            attention_mask = attention_mask[:, 1:]
        attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)

        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')
        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()

    toc = time.perf_counter()
    # Decode and print only after stopping the clock, so console I/O is not
    # counted as inference latency.
    print(tokenizer.decode(input_ids[0]))
    if idx >= num_warmup:
        total_time += toc - tic

measured_iters = num_iter - num_warmup
if measured_iters > 0:
    print("Inference latency: %.3f s." % (total_time / measured_iters))
else:
    print("No timed iterations: num_iter must be greater than num_warmup.")