{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This tutorial demonstrates how to perform evaluation on a gpt-j-6B-int8 model."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install onnx onnxruntime torch transformers datasets accelerate"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run\n",
    "\n",
    "### 1. Get lambada acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "import onnxruntime as ort\n",
    "from torch.nn.functional import pad\n",
    "\n",
    "# Path of the ONNX model to evaluate; replace the placeholder with a real file.\n",
    "MODEL_PATH = '/path/to/model.onnx'\n",
    "\n",
    "# The model takes one past key and one past value input per layer.\n",
    "# Axis 2 is the past-sequence axis (length 1 = a single dummy entry);\n",
    "# presumably the layout is (batch, heads, past_len, head_dim) -- confirm against the export.\n",
    "NUM_LAYERS = 28\n",
    "PAST_SHAPE = (1, 16, 1, 256)\n",
    "\n",
    "# load tokenizer\n",
    "model_id = \"EleutherAI/gpt-j-6B\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize the 'text' column of a batch of dataset examples.\"\"\"\n",
    "    return tokenizer(examples['text'])\n",
    "\n",
    "# create dataset\n",
    "dataset = load_dataset('lambada', split='validation')\n",
    "dataset = dataset.shuffle(seed=42)\n",
    "dataset = dataset.map(tokenize_function, batched=True)\n",
    "dataset.set_format(type='torch', columns=['input_ids'])\n",
    "\n",
    "# create session\n",
    "options = ort.SessionOptions()\n",
    "options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "session = ort.InferenceSession(MODEL_PATH, options, providers=ort.get_available_providers())\n",
    "\n",
    "total, hit = 0, 0\n",
    "# Number of padding tokens appended on the right of every sample (0 = no padding).\n",
    "# NOTE(review): the attention mask below is all ones, so padded positions would be\n",
    "# attended to if this were ever set above 0.\n",
    "pad_len = 0\n",
    "\n",
    "# inference\n",
    "for batch in dataset:\n",
    "    input_ids = batch['input_ids'].unsqueeze(0)\n",
    "    # The last token of the passage is the target the model has to predict.\n",
    "    label = input_ids[:, -1]\n",
    "    input_ids = pad(input_ids, (0, pad_len), value=1)\n",
    "    # The mask is one entry longer than input_ids to cover the dummy past entry.\n",
    "    # NOTE(review): the dummy past slot is left unmasked here, while the text\n",
    "    # generation cell masks it (attention_mask[0] = 0) -- confirm which is intended.\n",
    "    ort_inputs = {\n",
    "        'input_ids': input_ids.detach().cpu().numpy(),\n",
    "        'attention_mask': np.ones((1, input_ids.shape[1] + 1), dtype='int64'),\n",
    "    }\n",
    "    for i in range(NUM_LAYERS):\n",
    "        ort_inputs[f\"past_key_values.{i}.key\"] = np.zeros(PAST_SHAPE, dtype='float32')\n",
    "        ort_inputs[f\"past_key_values.{i}.value\"] = np.zeros(PAST_SHAPE, dtype='float32')\n",
    "    predictions = session.run(None, ort_inputs)\n",
    "    outputs = torch.from_numpy(predictions[0])\n",
    "    # The logits at the second-to-last unpadded position are compared with the label.\n",
    "    last_token_logits = outputs[:, -2 - pad_len, :]\n",
    "    pred = last_token_logits.argmax(dim=-1)\n",
    "    total += label.size(0)\n",
    "    hit += (pred == label).sum().item()\n",
    "\n",
    "acc = hit / total\n",
    "print('acc: ', acc)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Text Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import sys\n",
    "\n",
    "# create session\n",
    "sess_options = ort.SessionOptions()\n",
    "sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL\n",
    "# Providers are passed explicitly, as in the evaluation cell: since ONNX Runtime 1.9\n",
    "# builds with non-CPU execution providers raise an error when they are omitted.\n",
    "session = ort.InferenceSession(MODEL_PATH, sess_options, providers=ort.get_available_providers())\n",
    "\n",
    "# input prompt\n",
    "# 32 tokens input\n",
    "prompt = (\"Once upon a time, there existed a little girl, who liked to have adventures.\"\n",
    "          \" She wanted to go to places and meet new people, and have fun.\")\n",
    "\n",
    "print(\"prompt: \", prompt)\n",
    "\n",
    "total_time = 0.0\n",
    "num_iter = 10\n",
    "num_warmup = 3\n",
    "max_new_tokens = 32\n",
    "\n",
    "# start\n",
    "for idx in range(num_iter):\n",
    "    tic = time.perf_counter()\n",
    "\n",
    "    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
    "\n",
    "    # The mask has one leading entry for the dummy past key/value; it is zeroed\n",
    "    # so that the placeholder is not attended to.\n",
    "    attention_mask = torch.ones(input_ids.shape[1] + 1)\n",
    "    attention_mask[0] = 0\n",
    "    attention_mask = attention_mask.unsqueeze(0)\n",
    "\n",
    "    inp = {\n",
    "        'input_ids': input_ids.detach().cpu().numpy(),\n",
    "        'attention_mask': attention_mask.detach().cpu().numpy().astype('int64'),\n",
    "    }\n",
    "    for i in range(NUM_LAYERS):\n",
    "        inp[f\"past_key_values.{i}.key\"] = np.zeros(PAST_SHAPE, dtype='float32')\n",
    "        inp[f\"past_key_values.{i}.value\"] = np.zeros(PAST_SHAPE, dtype='float32')\n",
    "\n",
    "    for step in range(max_new_tokens):\n",
    "        output = session.run(None, inp)\n",
    "        logits = torch.from_numpy(output[0])\n",
    "        # Greedy decoding: the argmax of the logits is the argmax of their softmax.\n",
    "        next_tokens = torch.argmax(logits[:, -1, :], dim=-1)\n",
    "\n",
    "        # Outputs 1..2*NUM_LAYERS are fed back as the next step's past key/values.\n",
    "        # After the first step they still start with the dummy past entry, so it is\n",
    "        # dropped from the cache and from the attention mask.\n",
    "        drop = 1 if step == 0 else 0\n",
    "        for i in range(NUM_LAYERS):\n",
    "            inp[f\"past_key_values.{i}.key\"] = output[2 * i + 1][:, :, drop:, :]\n",
    "            inp[f\"past_key_values.{i}.value\"] = output[2 * i + 2][:, :, drop:, :]\n",
    "\n",
    "        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
    "        attention_mask = torch.cat([attention_mask[:, drop:], torch.ones([1, 1])], dim=-1)\n",
    "\n",
    "        # Only the newly generated token is fed in; earlier tokens live in the cache.\n",
    "        inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
    "        inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
    "\n",
    "    toc = time.perf_counter()\n",
    "    # Decoding and printing happen after the timer stops so they do not count as latency.\n",
    "    print(tokenizer.decode(input_ids[0]))\n",
    "    if idx >= num_warmup:\n",
    "        total_time += (toc - tic)\n",
    "print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}