MengniWang commited on
Commit
653ae5f
1 Parent(s): 8289f2e

update script

Browse files
Files changed (1) hide show
  1. evaluation.ipynb +2 -1
evaluation.ipynb CHANGED
@@ -53,6 +53,7 @@
53
  "source": [
54
  "from transformers import AutoTokenizer\n",
55
  "import torch\n",
 
56
  "from datasets import load_dataset\n",
57
  "import onnxruntime as ort\n",
58
  "from torch.nn.functional import pad\n",
@@ -157,7 +158,7 @@
157
  " inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
158
  " inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
159
  "\n",
160
- " for i in range(32):\n",
161
  "\n",
162
  " output = session.run(None, inp)\n",
163
  " logits = output[0]\n",
 
53
  "source": [
54
  "from transformers import AutoTokenizer\n",
55
  "import torch\n",
56
+ "import numpy as np\n",
57
  "from datasets import load_dataset\n",
58
  "import onnxruntime as ort\n",
59
  "from torch.nn.functional import pad\n",
 
158
  " inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
159
  " inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
160
  "\n",
161
+ " for step in range(32):\n",
162
  "\n",
163
  " output = session.run(None, inp)\n",
164
  " logits = output[0]\n",