alpindale committed
Commit 6e19318
1 Parent(s): 5f8353d

Create llama2_onnx_inference.py

Files changed (1)
  1. llama2_onnx_inference.py +176 -0
llama2_onnx_inference.py ADDED
@@ -0,0 +1,176 @@
+# This program will run the ONNX version of the LlamaV2 model.
+# Copyright (c) Microsoft
+# https://github.com/microsoft/Llama-2-Onnx/blob/38d310991a21203ac6cacc35298f420f60a527dd/MinimumExample/Example_ONNX_LlamaV2.py
+import torch
+import onnxruntime
+import numpy as np
+from sentencepiece import SentencePieceProcessor
+from typing import List
+import os
+import argparse
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        return self.sp_model.decode(t)
+
+
+def run_onnx_llamav2(
+    prompt: str,
+    onnx_file: str,
+    embedding_file: str,
+    tokenizer_path: str,
+    max_gen_len: int = 256,
+) -> str:
+    # Create the ONNX session
+    options = onnxruntime.SessionOptions()
+    llm_session = onnxruntime.InferenceSession(
+        onnx_file,
+        sess_options=options,
+        providers=[
+            "DmlExecutionProvider",
+            "CUDAExecutionProvider",
+            "CPUExecutionProvider",
+        ],
+    )
+
+    # get the data type used by the model
+    data_type_str = llm_session.get_inputs()[0].type
+    if data_type_str == "tensor(float16)":
+        data_type = np.float16
+    elif data_type_str == "tensor(float32)" or data_type_str == "tensor(float)":
+        data_type = np.float32
+    else:
+        raise Exception(f"Unknown data type {data_type_str}")
+
+    # Get the relevant shapes so we can create the inputs
+    for inputs_meta in llm_session._inputs_meta:
+        if inputs_meta.name == "x":
+            x_shape = inputs_meta.shape
+        elif inputs_meta.name == "attn_mask":
+            attn_mask_shape = inputs_meta.shape
+        elif inputs_meta.name == "k_cache":
+            k_cache_shape = inputs_meta.shape
+
+    hidden_size = x_shape[2]
+    max_seq_len = attn_mask_shape[1]
+    n_layers = k_cache_shape[1]
+    n_heads = k_cache_shape[3]
+
+    # Initialize the tokenizer and produce the initial tokens.
+    tokenizer = Tokenizer(model_path=tokenizer_path)
+    tokens = tokenizer.encode(prompt, bos=True, eos=False)
+
+    # create the embedding layer.
+    embedding_layer = torch.nn.Embedding(tokenizer.n_words, hidden_size)
+    embedding_layer.load_state_dict(torch.load(embedding_file))
+    embedding_layer.eval()
+
+    # Create the embeddings of the initial prompt.
+    x = embedding_layer(torch.tensor(tokens)).detach().cpu().numpy()
+    x = np.expand_dims(x, axis=0).astype(data_type)
+
+    # Create the attention mask.
+    attn_mask = -10000.0 * torch.triu(
+        torch.ones(attn_mask_shape), diagonal=1
+    ).cpu().detach().numpy().astype(data_type)
+
+    # Create the K and V caches.
+    head_dim = int(hidden_size / n_heads)
+    k_cache = np.zeros([1, n_layers, max_seq_len, n_heads, head_dim], dtype=data_type)
+    v_cache = np.zeros([1, n_layers, max_seq_len, n_heads, head_dim], dtype=data_type)
+
+    # Iteratively generate tokens.
+    pos = np.array(0)
+    output_tokens = []
+    for idx in range(max_gen_len):
+        results = llm_session.run(
+            None,
+            {
+                "x": x,
+                "attn_mask": attn_mask,
+                "k_cache": k_cache[:, :, :pos],
+                "v_cache": v_cache[:, :, :pos],
+                "pos": pos.astype(np.int64),
+            },
+        )
+        logits, k_out, v_out = results[:3]
+
+        # Decide the next token using your preferred sampling strategy.
+        next_token = np.argmax(logits, axis=-1).astype(np.int64)
+        output_tokens.extend(next_token)
+
+        # Stop if/when we get an ENDOFTEXT token before reaching maximum sequence length
+        if next_token == tokenizer.eos_id:
+            break
+
+        # Update the cache
+        seq_len = x.shape[1]
+        k_cache[:, :, pos : pos + seq_len] = k_out
+        v_cache[:, :, pos : pos + seq_len] = v_out
+
+        # Update pos and x ready for the next round.
+        pos = np.array(int(pos) + seq_len, dtype=np.int64)
+        x = embedding_layer(torch.tensor(next_token)).unsqueeze(0)
+        x = x.cpu().detach().numpy().astype(data_type)
+
+    output_str = tokenizer.decode(torch.tensor(output_tokens).tolist())
+
+    return output_str
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--onnx_file",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--embedding_file",
+        type=str,
+        required=True,
+    )
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        required=True,
+    )
+    parser.add_argument("--max_gen_len", type=int, default=256)
+    args = parser.parse_args()
+    response = run_onnx_llamav2(
+        args.prompt,
+        args.onnx_file,
+        args.embedding_file,
+        args.tokenizer_path,
+        args.max_gen_len,
+    )
+
+    print(response)
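
For reference, a minimal sketch of invoking the added script from the command line; the ONNX model, embedding, and tokenizer paths below are placeholders and depend on which Llama-2-Onnx variant has been downloaded:

    python llama2_onnx_inference.py \
        --prompt "What is the lightest element?" \
        --onnx_file 7B_FT_float16/ONNX/LlamaV2_7B_FT_float16.onnx \
        --embedding_file 7B_FT_float16/embeddings.pth \
        --tokenizer_path tokenizer.model \
        --max_gen_len 256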