hengyu commited on
Commit
f27847f
1 Parent(s): d4e876c

update model

Browse files
Files changed (3) hide show
  1. evaluation.ipynb +56 -12
  2. model.onnx +2 -2
  3. weights.pb +2 -2
evaluation.ipynb CHANGED
@@ -86,14 +86,18 @@
86
  " input_ids = pad(input_ids, (0, pad_len), value=1)\n",
87
  " ort_inputs = {\n",
88
  " 'input_ids': input_ids.detach().cpu().numpy(),\n",
89
- " 'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')\n",
90
  " }\n",
 
 
 
91
  " predictions = session.run(None, ort_inputs)\n",
92
  " outputs = torch.from_numpy(predictions[0]) \n",
93
  " last_token_logits = outputs[:, -2 - pad_len, :]\n",
94
  " pred = last_token_logits.argmax(dim=-1)\n",
95
  " total += label.size(0)\n",
96
  " hit += (pred == label).sum().item()\n",
 
97
  "acc = hit / total\n",
98
  "print('acc: ', acc)"
99
  ]
@@ -132,19 +136,59 @@
132
  "\n",
133
  "print(\"prompt: \", prompt)\n",
134
  "\n",
 
 
 
 
135
  "# start\n",
136
- "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
137
- "for i in range(32):\n",
 
 
 
 
 
 
 
 
138
  " inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
139
- " 'attention_mask': torch.ones(input_ids.shape).detach().cpu().numpy().astype('int64')}\n",
140
- " output = session.run(None, inp)\n",
141
- " logits = output[0]\n",
142
- " logits = torch.from_numpy(logits)\n",
143
- " next_token_logits = logits[:, -1, :]\n",
144
- " probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
145
- " next_tokens = torch.argmax(probs, dim=-1)\n",
146
- " input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
147
- "print(tokenizer.decode(input_ids[0]))"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  ]
149
  }
150
  ],
 
86
  " input_ids = pad(input_ids, (0, pad_len), value=1)\n",
87
  " ort_inputs = {\n",
88
  " 'input_ids': input_ids.detach().cpu().numpy(),\n",
89
+ " 'attention_mask': torch.cat([torch.ones(input_ids.shape), torch.ones([1, 1])], dim=-1).detach().cpu().numpy().astype('int64')\n",
90
  " }\n",
91
+ " for i in range(28):\n",
92
+ " ort_inputs[\"past_key_values.{}.key\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
93
+ " ort_inputs[\"past_key_values.{}.value\".format(i)] = np.zeros((1,16,1,256), dtype='float32')\n",
94
  " predictions = session.run(None, ort_inputs)\n",
95
  " outputs = torch.from_numpy(predictions[0]) \n",
96
  " last_token_logits = outputs[:, -2 - pad_len, :]\n",
97
  " pred = last_token_logits.argmax(dim=-1)\n",
98
  " total += label.size(0)\n",
99
  " hit += (pred == label).sum().item()\n",
100
+ "\n",
101
  "acc = hit / total\n",
102
  "print('acc: ', acc)"
103
  ]
 
136
  "\n",
137
  "print(\"prompt: \", prompt)\n",
138
  "\n",
139
+ "total_time = 0.0\n",
140
+ "num_iter = 10\n",
141
+ "num_warmup = 3\n",
142
+ "\n",
143
  "# start\n",
144
+ "for idx in range(num_iter):\n",
145
+ " text = []\n",
146
+ " tic = time.time()\n",
147
+ "\n",
148
+ " input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
149
+ "\n",
150
+ " attention_mask = torch.ones(input_ids.shape[1] +1)\n",
151
+ " attention_mask[0] = 0\n",
152
+ " attention_mask = attention_mask.unsqueeze(0)\n",
153
+ "\n",
154
  " inp = {'input_ids': input_ids.detach().cpu().numpy(),\n",
155
+ " 'attention_mask': attention_mask.detach().cpu().numpy().astype('int64')}\n",
156
+ " for i in range(28):\n",
157
+ " inp[\"past_key_values.{}.key\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
158
+ " inp[\"past_key_values.{}.value\".format(i)] = torch.zeros([1,16,1,256]).detach().cpu().numpy()\n",
159
+ "\n",
160
+ " for i in range(32):\n",
161
+ "\n",
162
+ " output = session.run(None, inp)\n",
163
+ " logits = output[0]\n",
164
+ " logits = torch.from_numpy(logits)\n",
165
+ " next_token_logits = logits[:, -1, :]\n",
166
+ " probs = torch.nn.functional.softmax(next_token_logits, dim=-1)\n",
167
+ " next_tokens = torch.argmax(probs, dim=-1)\n",
168
+ " present_kv = output[1]\n",
169
+ " for i in range(28):\n",
170
+ "\n",
171
+ " if step == 0:\n",
172
+ " inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1][:, :, 1:, :]\n",
173
+ " inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2][:, :, 1:, :]\n",
174
+ " else:\n",
175
+ " inp[\"past_key_values.{}.key\".format(i)] = output[2*i+1]\n",
176
+ " inp[\"past_key_values.{}.value\".format(i)] = output[2*i+2]\n",
177
+ "\n",
178
+ " input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)\n",
179
+ " if step == 0:\n",
180
+ " attention_mask = torch.cat([attention_mask[:, 1:], torch.ones([1, 1])], dim=-1)\n",
181
+ " else:\n",
182
+ " attention_mask = torch.cat([attention_mask, torch.ones([1, 1])], dim=-1)\n",
183
+ "\n",
184
+ " inp['attention_mask'] = attention_mask.detach().cpu().numpy().astype('int64')\n",
185
+ " inp['input_ids'] = input_ids[:, -1:].detach().cpu().numpy()\n",
186
+ "\n",
187
+ " print(tokenizer.decode(input_ids[0]))\n",
188
+ " toc = time.time()\n",
189
+ " if idx >= num_warmup:\n",
190
+ " total_time += (toc - tic)\n",
191
+ "print(\"Inference latency: %.3f s.\" % (total_time / (num_iter - num_warmup)))"
192
  ]
193
  }
194
  ],
model.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:246a7a41aa525a327bf019840ff1f726c7562681b03fc09b9aec5633695fbd95
3
- size 5852363
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4638154e78188303e3819fff286a65547673a84d5938beec68d3feb8676508c6
3
+ size 6173261
weights.pb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:747ec4484853a81ae90ac3a75518f2dc4623a88b06db69300d961b69166c3087
3
- size 6790222720
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcdc4899701755241dfb65b06569c7b5e5d3a6be67c2054f7de431fd6bb1cd48
3
+ size 6057661312