{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This tutorial demonstrates how to perform evaluation on a gpt-j-6B-int8 model with ONNX Runtime: last-token accuracy on the LAMBADA validation set, followed by a text-generation latency measurement.\n",
    "\n",
    "Replace `/path/to/model.onnx` in the code cells below with the path to your exported model."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!pip install onnx onnxruntime torch transformers datasets accelerate"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run\n",
    "\n",
    "### 1. Get LAMBADA accuracy\n",
    "\n",
    "The first cell evaluates one sample at a time; the second cell runs the same evaluation in padded batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "import onnxruntime as ort\n",
    "from torch.nn.functional import pad\n",
    "\n",
    "# Shape of the placeholder past key/value inputs fed to the model.\n",
    "# Assumes the GPT-J-6B layout (28 decoder layers, 16 heads, head size 256);\n",
    "# adjust if your exported model differs.\n",
    "NUM_LAYERS = 28\n",
    "NUM_HEADS = 16\n",
    "HEAD_DIM = 256\n",
    "\n",
    "# load model\n",
    "model_id = \"EleutherAI/gpt-j-6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize the 'text' field of a batch of dataset examples.\"\"\"\n",
    "    example = tokenizer(examples['text'])\n",
    "    return example\n",
    "\n",
    "# create dataset\n",
    "dataset = load_dataset('lambada', split='validation')\n",
    "dataset = dataset.shuffle(seed=42)\n",
    "dataset = dataset.map(tokenize_function, batched=True)\n",
    "dataset.set_format(type='torch', columns=['input_ids'])\n",
    "\n",
    "# create session\n",
    "options = ort.SessionOptions()\n",
    "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())\n",
    "total, hit = 0, 0\n",
    "\n",
    "# inference\n",
    "for idx, batch in enumerate(dataset):\n",
    "    input_ids = batch['input_ids'].unsqueeze(0)\n",
    "    # The label is the last token of the text.\n",
    "    label = input_ids[:, -1]\n",
    "    pad_len = 0 ##set to 0\n",
    "    input_ids = pad(input_ids, (0, pad_len), value=1)\n",
    "    ort_inputs = {\n",
    "        'input_ids': input_ids.detach().cpu().numpy(),\n",
    "        # One extra mask entry on top of the tokens, matching the single-position\n",
    "        # past key/value placeholders created below.\n",
    "        'attention_mask': torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1).detach().cpu().numpy().astype('int64')\n",
    "    }\n",
    "    for i in range(NUM_LAYERS):\n",
    "        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((1, NUM_HEADS, 1, HEAD_DIM), dtype='float32')\n",
    "        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((1, NUM_HEADS, 1, HEAD_DIM), dtype='float32')\n",
    "    predictions = session.run(None, ort_inputs)\n",
    "    outputs = torch.from_numpy(predictions[0])\n",
    "    # Logits at the second-to-last position are the prediction for the last token.\n",
    "    last_token_logits = outputs[:, -2 - pad_len, :]\n",
    "    pred = last_token_logits.argmax(dim=-1)\n",
    "    total += label.size(0)\n",
    "    hit += (pred == label).sum().item()\n",
    "\n",
    "acc = hit / total\n",
    "print('acc: ', acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# batch inference\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "import onnxruntime as ort\n",
    "from torch.nn.functional import pad\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "batch_size = 2\n",
    "pad_max = 196\n",
    "\n",
    "# Shape of the placeholder past key/value inputs fed to the model.\n",
    "# Assumes the GPT-J-6B layout (28 decoder layers, 16 heads, head size 256);\n",
    "# adjust if your exported model differs.\n",
    "NUM_LAYERS = 28\n",
    "NUM_HEADS = 16\n",
    "HEAD_DIM = 256\n",
    "\n",
    "# load model\n",
    "model_id = \"EleutherAI/gpt-j-6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize the 'text' field of a batch of dataset examples.\"\"\"\n",
    "    example = tokenizer(examples['text'])\n",
    "    return example\n",
    "\n",
    "# create dataloader\n",
    "class Dataloader:\n",
    "    \"\"\"Iterate over a tokenized LAMBADA split in fixed-width, right-padded batches.\n",
    "\n",
    "    Each iteration yields ([input_ids, attention_mask], last_ind) where both\n",
    "    arrays are int64 numpy arrays and last_ind holds, per sample, the index\n",
    "    of the last non-padding token.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pad_max=196, batch_size=1, sub_folder='validation'):\n",
    "        self.pad_max = pad_max\n",
    "        self.batch_size = batch_size\n",
    "        dataset = load_dataset('lambada', split=sub_folder)\n",
    "        dataset = dataset.map(tokenize_function, batched=True)\n",
    "        dataset.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\"])\n",
    "        self.dataloader = DataLoader(\n",
    "            dataset,\n",
    "            batch_size=self.batch_size,\n",
    "            shuffle=False,\n",
    "            collate_fn=self.collate_batch,\n",
    "        )\n",
    "\n",
    "    def collate_batch(self, batch):\n",
    "        \"\"\"Pad every sample to pad_max tokens and record its last real token index.\"\"\"\n",
    "        input_ids_padded = []\n",
    "        attention_mask_padded = []\n",
    "        last_ind = []\n",
    "        for text in batch:\n",
    "            # Samples longer than pad_max are cut to pad_max - 1 tokens, so for\n",
    "            # those samples the label is the last kept token, not the text's end.\n",
    "            input_ids = text[\"input_ids\"] if text[\"input_ids\"].shape[0] <= self.pad_max else text[\"input_ids\"][0:int(self.pad_max-1)]\n",
    "            pad_len = self.pad_max - input_ids.shape[0]\n",
    "            last_ind.append(input_ids.shape[0] - 1)\n",
    "            # Right-pad with token id 1 up to pad_max.\n",
    "            input_ids = pad(input_ids, (0, pad_len), value=1)\n",
    "            input_ids_padded.append(input_ids)\n",
    "            # All-ones mask over pad_max + 1 positions; the extra slot presumably\n",
    "            # pairs with the single-position past key/value placeholder.\n",
    "            attention_mask = torch.ones(input_ids.shape[0] + 1)\n",
    "            attention_mask_padded.append(attention_mask)\n",
    "        return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)\n",
    "\n",
    "    def __iter__(self):\n",
    "        # A for loop already stops cleanly when the DataLoader is exhausted,\n",
    "        # so no StopIteration handling is needed here.\n",
    "        for (input_ids, attention_mask), last_ind in self.dataloader:\n",
    "            data = [input_ids.detach().cpu().numpy().astype('int64')]\n",
    "            data.append(attention_mask.detach().cpu().numpy().astype('int64'))\n",
    "            yield data, last_ind.detach().cpu().numpy()\n",
    "\n",
    "# create session\n",
    "options = ort.SessionOptions()\n",
    "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession('/path/to/model.onnx', options, providers=ort.get_available_providers())\n",
    "total, hit = 0, 0\n",
    "\n",
    "dataloader = Dataloader(pad_max=pad_max, batch_size=batch_size)\n",
    "\n",
    "# inference\n",
    "for idx, (batch, last_ind) in enumerate(dataloader):\n",
    "    # The final batch may hold fewer than batch_size samples, so size the past\n",
    "    # key/value placeholders from the batch actually received.\n",
    "    cur_batch_size = batch[0].shape[0]\n",
    "    label = torch.from_numpy(batch[0][torch.arange(len(last_ind)), last_ind])\n",
    "    pad_len = pad_max - last_ind - 1\n",
    "    ort_inputs = {\n",
    "        'input_ids': batch[0],\n",
    "        'attention_mask': batch[1]\n",
    "    }\n",
    "    for i in range(NUM_LAYERS):\n",
    "        ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((cur_batch_size, NUM_HEADS, 1, HEAD_DIM), dtype='float32')\n",
    "        ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((cur_batch_size, NUM_HEADS, 1, HEAD_DIM), dtype='float32')\n",
    "\n",
    "    predictions = session.run(None, ort_inputs)\n",
    "    outputs = torch.from_numpy(predictions[0])\n",
    "    # Per sample, take the logits one position before its last real token,\n",
    "    # counted from the end of the padded sequence.\n",
    "    last_token_logits = outputs[torch.arange(len(last_ind)), -2 - pad_len, :]\n",
    "    pred = last_token_logits.argmax(dim=-1)\n",
    "    total += len(label)\n",
    "    hit += (pred == label).sum().item()\n",
    "\n",
    "acc = hit / total\n",
    "print('acc: ', acc)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Text Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "import onnxruntime as ort\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "# Shape of the placeholder past key/value inputs fed to the model.\n",
    "# Assumes the GPT-J-6B layout (28 decoder layers, 16 heads, head size 256);\n",
    "# adjust if your exported model differs.\n",
    "NUM_LAYERS = 28\n",
    "NUM_HEADS = 16\n",
    "HEAD_DIM = 256\n",
    "\n",
    "# load tokenizer (repeated here so this cell does not depend on the cells above)\n",
    "model_id = \"EleutherAI/gpt-j-6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "# create session\n",
    "sess_options = ort.SessionOptions()\n",
    "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession('/path/to/model.onnx', sess_options, providers=ort.get_available_providers())\n",
    "\n",
    "# input prompt\n",
    "# 32 tokens input\n",
    "prompt = \"Once upon a time, there existed a little girl, who liked to have adventures.\" + \\\n",
    "         \" She wanted to go to places and meet new people, and have fun.\"\n",
    "\n",
    "print(\"prompt: \", prompt)\n",
    "\n",
    "total_time = 0.0\n",
    "num_iter = 10\n",
    "num_warmup = 3\n",
    "max_new_tokens = 32\n",
    "\n",
    "# start\n",
    "for idx in range(num_iter):\n",
    "    tic = time.time()\n",
    "\n",
    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
    "\n",
    "    # The mask has one more entry than the prompt. Entry 0 is zeroed and is\n",
    "    # dropped after the first step together with position 0 of the returned\n",
    "    # key/value tensors, i.e. it lines up with the all-zero placeholder below.\n",
    "    attention_mask = torch.ones(input_ids.shape[1] + 1)\n",
    "    attention_mask[0] = 0\n",
    "    attention_mask = attention_mask.unsqueeze(0)\n",
    "\n",
    "    inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
    "           'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}\n",
    "    for i in range(NUM_LAYERS):\n",
    "        inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1, NUM_HEADS, 1, HEAD_DIM]).detach().cpu().numpy()\n",
    "        inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1, NUM_HEADS, 1, HEAD_DIM]).detach().cpu().numpy()\n",
    "\n",
    "    for step in range(max_new_tokens):\n",
    "\n",
    "        output = session.run(None, inp)\n",
    "        logits = torch.from_numpy(output[0])\n",
    "        next_token_logits = logits[:, -1, :]\n",
    "        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
    "        # Greedy decoding: pick the most probable next token.\n",
    "        next_tokens = torch.argmax(probs, dim=-1)\n",
    "\n",
    "        # output[0] is the logits; outputs 2*i+1 and 2*i+2 are fed back as the\n",
    "        # key and value cache of layer i for the next step.\n",
    "        for i in range(NUM_LAYERS):\n",
    "            if step == 0:\n",
    "                # Drop position 0 (the placeholder) from the cache after the first step.\n",
    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1][:, :, 1:, :]\n",
    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2][:, :, 1:, :]\n",
    "            else:\n",
    "                inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1]\n",
    "                inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2]\n",
    "\n",
    "        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
    "        if step == 0:\n",
    "            # Remove the placeholder's mask entry and add one for the new token.\n",
    "            attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)\n",
    "        else:\n",
    "            attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)\n",
    "\n",
    "        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
    "        # With the cache in place only the newest token is fed as input.\n",
    "        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
    "\n",
    "    print(tokenizer.decode(input_ids[0]))\n",
    "    toc = time.time()\n",
    "    # Skip the warm-up iterations when accumulating latency.\n",
    "    if idx >= num_warmup:\n",
    "        total_time += (toc - tic)\n",
    "print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}