Create llama2_onnx_inference.py
llama2_onnx_inference.py  +176 -0
llama2_onnx_inference.py
ADDED
@@ -0,0 +1,176 @@
# This program will run the ONNX version of the LlamaV2 model.
# Copyright (c) Microsoft
# https://github.com/microsoft/Llama-2-Onnx/blob/38d310991a21203ac6cacc35298f420f60a527dd/MinimumExample/Example_ONNX_LlamaV2.py
import torch
import onnxruntime
import numpy as np
from sentencepiece import SentencePieceProcessor
from typing import List
import os
import argparse


class Tokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()

        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)


def run_onnx_llamav2(
    prompt: str,
    onnx_file: str,
    embedding_file: str,
    tokenizer_path: str,
    max_gen_len: int = 256,
) -> str:
    # Create the ONNX session
    options = onnxruntime.SessionOptions()
    llm_session = onnxruntime.InferenceSession(
        onnx_file,
        sess_options=options,
        providers=[
            "DmlExecutionProvider",
            "CUDAExecutionProvider",
            "CPUExecutionProvider",
        ],
    )

    # get the data type used by the model
    data_type_str = llm_session.get_inputs()[0].type
    if data_type_str == "tensor(float16)":
        data_type = np.float16
    elif data_type_str == "tensor(float32)" or data_type_str == "tensor(float)":
        data_type = np.float32
    else:
        raise Exception(f"Unknown data type {data_type_str}")

    # Get the relevant shapes so we can create the inputs
    for inputs_meta in llm_session._inputs_meta:
        if inputs_meta.name == "x":
            x_shape = inputs_meta.shape
        elif inputs_meta.name == "attn_mask":
            attn_mask_shape = inputs_meta.shape
        elif inputs_meta.name == "k_cache":
            k_cache_shape = inputs_meta.shape

    hidden_size = x_shape[2]
    max_seq_len = attn_mask_shape[1]
    n_layers = k_cache_shape[1]
    n_heads = k_cache_shape[3]

    # Initialize the tokenizer and produce the initial tokens.
    tokenizer = Tokenizer(model_path=tokenizer_path)
    tokens = tokenizer.encode(prompt, bos=True, eos=False)

    # create the embedding layer.
    embedding_layer = torch.nn.Embedding(tokenizer.n_words, hidden_size)
    embedding_layer.load_state_dict(torch.load(embedding_file))
    embedding_layer.eval()

    # Create the embeddings of the initial prompt.
    x = embedding_layer(torch.tensor(tokens)).detach().cpu().numpy()
    x = np.expand_dims(x, axis=0).astype(data_type)

    # Create the attention mask.
    attn_mask = -10000.0 * torch.triu(
        torch.ones(attn_mask_shape), diagonal=1
    ).cpu().detach().numpy().astype(data_type)
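    # (The mask is 0 on and below the diagonal and -10000.0 above it, so each
    # position can attend to itself and to earlier positions but is effectively
    # blocked from attending to later ones, i.e. a standard causal mask.)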

    # Create the K and V caches.
    head_dim = int(hidden_size / n_heads)
    k_cache = np.zeros([1, n_layers, max_seq_len, n_heads, head_dim], dtype=data_type)
    v_cache = np.zeros([1, n_layers, max_seq_len, n_heads, head_dim], dtype=data_type)

    # Iteratively generate tokens.
    pos = np.array(0)
    output_tokens = []
    for idx in range(max_gen_len):
        results = llm_session.run(
            None,
            {
                "x": x,
                "attn_mask": attn_mask,
                "k_cache": k_cache[:, :, :pos],
                "v_cache": v_cache[:, :, :pos],
                "pos": pos.astype(np.int64),
            },
        )
        logits, k_out, v_out = results[:3]

        # Decide the next token using your preferred sampling strategy.
        next_token = np.argmax(logits, axis=-1).astype(np.int64)
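        # Hypothetical alternative (not part of the upstream example): replace the
        # greedy argmax above with temperature sampling. Left commented out as a
        # sketch; 0.8 is an arbitrary temperature.
        # temperature = 0.8
        # scaled = logits.astype(np.float64) / temperature
        # probs = np.exp(scaled - np.max(scaled))
        # probs = probs / probs.sum()
        # next_token = np.array(
        #     [np.random.choice(probs.shape[-1], p=probs.ravel())], dtype=np.int64
        # )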
        output_tokens.extend(next_token)

        # Stop if/when we get an ENDOFTEXT token before reaching maximum sequence length
        if next_token == tokenizer.eos_id:
            break

        # Update the cache
        seq_len = x.shape[1]
        k_cache[:, :, pos : pos + seq_len] = k_out
        v_cache[:, :, pos : pos + seq_len] = v_out

        # Update pos and x ready for the next round.
        pos = np.array(int(pos) + seq_len, dtype=np.int64)
        x = embedding_layer(torch.tensor(next_token)).unsqueeze(0)
        x = x.cpu().detach().numpy().astype(data_type)

    output_str = tokenizer.decode(torch.tensor(output_tokens).tolist())

    return output_str


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--onnx_file",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--embedding_file",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        required=True,
    )
    parser.add_argument("--max_gen_len", type=int, default=256)
    args = parser.parse_args()
    response = run_onnx_llamav2(
        args.prompt,
        args.onnx_file,
        args.embedding_file,
        args.tokenizer_path,
        args.max_gen_len,
    )

    print(response)
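
# Example invocation (illustrative only; the paths below are hypothetical placeholders
# for wherever the exported ONNX model, embedding weights, and SentencePiece tokenizer
# live on disk):
#
#   python llama2_onnx_inference.py \
#       --prompt "What is the lightest element?" \
#       --onnx_file path/to/LlamaV2_7B_float16.onnx \
#       --embedding_file path/to/embeddings.pth \
#       --tokenizer_path path/to/tokenizer.model \
#       --max_gen_len 256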