happyme531 committed
Commit
424b51e
1 Parent(s): 385f65d

Split part of vision encoder to CPU and optimize Transpose ops. (Reupload to correct path)

Files changed (1)
  1. onnx/convert.py +166 -38
onnx/convert.py CHANGED
@@ -1,31 +1,56 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-# In[1]:
-
-
-import os
-import urllib
-import traceback
-import time
-import sys
-import numpy as np
-import cv2
 from rknn.api import RKNN
 from math import exp
 from sys import exit
 
+import onnx
+import onnxscript
+
 batch_size = 1
 # embed_seq_len = 590
 
-vision_size = (512, 512)
+prompt_tokens_list = [15, 17, 21, 25]
 
-vision_tokens = 257
-prompt_tokens = 14
+encoder_seq_len_list = [577 + p for p in prompt_tokens_list]
 
-encoder_seq_len = vision_tokens + prompt_tokens
 decoder_seq_len = 1
 
+# set current directory to the directory of this file
+import os
+os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+import subprocess
+import select
+
+def run_python_code(code):
+    # Launch a child Python interpreter that executes the given code
+    process = subprocess.Popen(
+        ['python', '-c', code],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True
+    )
+
+    # Read the child's stdout and stderr in real time
+    while True:
+        reads = [process.stdout.fileno(), process.stderr.fileno()]
+        ret = select.select(reads, [], [])
+
+        for fd in ret[0]:
+            if fd == process.stdout.fileno():
+                output = process.stdout.readline()
+                if output:
+                    print(output.strip())
+            if fd == process.stderr.fileno():
+                err = process.stderr.readline()
+                if err:
+                    print(f"Error: {err.strip()}")
+
+        if process.poll() is not None:
+            break
+
 def convert_decoder():
     rknn = RKNN(verbose=True)
 
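The run_python_code helper added above isolates a crash-prone RKNN build in a child interpreter and mirrors its output as it arrives; note that select() on pipe file descriptors is POSIX-only. A minimal usage sketch (illustrative, not part of the commit):

# Run a throwaway snippet in a child interpreter; a native crash in the
# child (e.g. inside the RKNN toolchain) cannot take down this script.
run_python_code("print('hello from the child interpreter')")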
@@ -34,21 +59,21 @@ def convert_decoder():
     DATASET="dataset.txt"
     QUANTIZE=False
 
+    # [[batch_size, encoder_seq_len],
+    #  [batch_size, encoder_seq_len, 768],
+    #  [batch_size, decoder_seq_len, 768]]
+    input_shapes = [[[batch_size, encoder_seq_len],
+                     [batch_size, encoder_seq_len, 768],
+                     [batch_size, decoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]
     # pre-process config
     print('--> Config model')
-    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
+    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True,
+                dynamic_input=input_shapes)
     print('done')
 
     # Load ONNX model
     print('--> Loading model')
     ret = rknn.load_onnx(model=ONNX_MODEL,
-                         inputs=["encoder_attention_mask",
-                                 "encoder_hidden_states",
-                                 "inputs_embeds",
-                                 ],
-                         input_size_list=[[batch_size, encoder_seq_len],
-                                          [batch_size, encoder_seq_len, 768],
-                                          [batch_size, decoder_seq_len, 768]],
                          )
     if ret != 0:
         print('Load model failed!')
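With prompt_tokens_list = [15, 17, 21, 25], encoder_seq_len_list evaluates to [592, 594, 598, 602], so the decoder is registered with four candidate shape sets via dynamic_input. For illustration (not part of the commit), the first entry of input_shapes expands to:

# encoder_seq_len = 577 + 15 = 592, batch_size = 1, decoder_seq_len = 1
[[1, 592],        # encoder_attention_mask
 [1, 592, 768],   # encoder_hidden_states
 [1, 1, 768]]     # inputs_embeds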
@@ -79,16 +104,16 @@ def convert_encoder():
     DATASET="dataset.txt"
     QUANTIZE=False
 
+    # [[batch_size, encoder_seq_len], [batch_size, encoder_seq_len, 768]]
+    input_shapes = [[[batch_size, encoder_seq_len], [batch_size, encoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]
     # pre-process config
     print('--> Config model')
-    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
+    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True, dynamic_input=input_shapes)
     print('done')
 
     # Load ONNX model
     print('--> Loading model')
-    ret = rknn.load_onnx(model=ONNX_MODEL,
-                         inputs=["attention_mask", "inputs_embeds"],
-                         input_size_list=[[batch_size, encoder_seq_len], [batch_size, encoder_seq_len, 768]],
+    ret = rknn.load_onnx(model=ONNX_MODEL
                          )
     if ret != 0:
         print('Load model failed!')
@@ -111,32 +136,106 @@ def convert_encoder():
         exit(ret)
     print('done')
 
-def convert_embed():
+def convert_vision():
     rknn = RKNN(verbose=True)
 
-    ONNX_MODEL="embed_tokens.onnx"
-    RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+    ONNX_MODEL="vision_encoder.onnx"
     DATASET="dataset.txt"
     QUANTIZE=False
 
+    # split the first Transformer block into a separate model because it's too large to fit in the RKNN
+    onnx.utils.extract_model(ONNX_MODEL, "vision_encoder_part1.onnx", ['pixel_values'], ['/blocks.0/blocks.0.0/channel_block/channel_attn/Add_output_0'])
+
+    ##### Build stage 1; this can crash the Python process, so run it in a separate process
+    code = f"""
+from rknn.api import RKNN
+rknn = RKNN(verbose=True)
+ONNX_MODEL="vision_encoder.onnx"
+RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+DATASET="dataset.txt"
+QUANTIZE=False
+batch_size = {batch_size}
+# pre-process config
+print('--> Config model')
+rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True)
+print('done')
+
+# Load ONNX model
+print('--> Loading model')
+ret = rknn.load_onnx(model=ONNX_MODEL,
+                     inputs=["pixel_values"],
+                     input_size_list=[[batch_size, 3, 768, 768]],
+                     )
+if ret != 0:
+    print('Load model failed!')
+    exit(ret)
+print('done')
+
+print('--> Building model stage 1')
+ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+if ret != 0:
+    print('Build model failed!')
+    exit(ret)
+print('done')
+"""
+    run_python_code(code)
+    print("Build stage 1 done")
+
+    intermediate_model = onnx.load("check3_fuse_ops.onnx")
+
+    # fuse ops
+    from onnxscript.rewriter import pattern
+    import onnx.numpy_helper as onh
+    import numpy as np
+    def tp_rs_tp_rs_tp_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
+        i1 = op.Transpose(input1, perm=perm1)
+        i2 = op.Reshape(i1, shape2)
+        i3 = op.Transpose(i2, perm=perm3)
+        i4 = op.Reshape(i3, shape4)
+        i5 = op.Transpose(i4, perm=perm5)
+        return i5
+
+    def fused_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
+        rs1_shape = op.Constant(value=onh.from_array(np.array([input1.shape[0] * 3, input1.shape[1] // 3, input1.shape[2], input1.shape[3]], dtype=np.int64)))
+        fi1 = op.Reshape(input1, rs1_shape)
+        fi2 = op.Transpose(fi1, perm=[0, 2, 1, 3])
+        elems = input1.shape[0] * input1.shape[1] * input1.shape[2] * input1.shape[3]
+        rs4_shape = op.Constant(value=onh.from_array(np.array([elems // 32 // 144, 32, 1, 144], dtype=np.int64)))
+        fi3 = op.Reshape(fi2, rs4_shape)
+        return fi3
+
+    rewrite_rule = pattern.RewriteRule(tp_rs_tp_rs_tp_pattern, fused_pattern)
+    rewrite_rule_set = pattern.RewriteRuleSet([rewrite_rule], commute=True)
+    fused_model = onnxscript.rewriter.rewrite(
+        intermediate_model,
+        pattern_rewrite_rules=rewrite_rule_set
+    )
+    onnx.save(fused_model, "vision_encoder_part2.onnx")
+    ONNX_MODEL = "vision_encoder_part2.onnx"
+    RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+    del intermediate_model
+    del fused_model
+
+
+    rknn = RKNN(verbose=True)
+
     # pre-process config
     print('--> Config model')
-    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True )
+    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True)
     print('done')
 
     # Load ONNX model
     print('--> Loading model')
-    ret = rknn.load_onnx(model=ONNX_MODEL,
-                         inputs=["input_ids"],
-                         input_size_list=[[batch_size, embed_seq_len]],
-                         )
+    ret = rknn.load_onnx(model="check3_fuse_ops.onnx",
+                         inputs=["/blocks.0/blocks.0.0/channel_block/channel_attn/Add_output_0-rs"],
+                         input_size_list=[[batch_size, 128, 1, 36864]],)
     if ret != 0:
         print('Load model failed!')
         exit(ret)
     print('done')
 
     # Build model
-    print('--> Building model')
+    print('--> Building model stage 2')
     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
     if ret != 0:
         print('Build model failed!')
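The rewrite rule above collapses each Transpose-Reshape-Transpose-Reshape-Transpose chain into a Reshape-Transpose-Reshape with hard-coded 32/144 shape constants, so a numeric spot-check against the unfused graph is prudent. One possible check, assuming onnxruntime is installed (illustrative, not part of the commit):

import numpy as np
import onnxruntime as ort

# The rewrite only touches internal nodes, so both graphs share I/O signatures.
ref = ort.InferenceSession("check3_fuse_ops.onnx", providers=["CPUExecutionProvider"])
fused = ort.InferenceSession("vision_encoder_part2.onnx", providers=["CPUExecutionProvider"])

# Random input (assumes float32 inputs; symbolic dims fall back to 1).
feed = {i.name: np.random.rand(*[d if isinstance(d, int) else 1 for d in i.shape]).astype(np.float32)
        for i in ref.get_inputs()}

for y_ref, y_fused in zip(ref.run(None, feed), fused.run(None, feed)):
    assert np.allclose(y_ref, y_fused, atol=1e-4)
print("fused graph matches the original")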
@@ -150,8 +249,14 @@ def convert_embed():
         print('Export RKNN model failed!')
         exit(ret)
     print('done')
+
+
+
+
 
-def convert_vision():
+
+
+def check_vision_model():
    rknn = RKNN(verbose=True)
 
     ONNX_MODEL="vision_encoder.onnx"
@@ -191,12 +296,32 @@ def convert_vision():
         exit(ret)
     print('done')
 
+    # init runtime
+    print('--> Init runtime environment')
+    ret = rknn.init_runtime(target='rk3588')
+    if ret != 0:
+        print('Init runtime environment failed!')
+        exit(ret)
+    print('done')
+
+    # precision check
+    print('--> Precision check')
+    ret = rknn.accuracy_analysis(inputs=["lena.png"], target='rk3588')
+    if ret != 0:
+        print('Precision check failed!')
+        exit(ret)
+    print('done')
+
+
+
+
 
 import argparse
 # python convert.py <decoder|encoder|vision|all>
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("model", type=str, help="Model to convert")
+    parser.add_argument("--check", action="store_true", help="Check model")
     args = parser.parse_args()
     if args.model == "decoder":
         convert_decoder()
@@ -205,7 +330,10 @@ if __name__ == "__main__":
     # elif args.model == "embed": # embed is faster with cpu
     #     convert_embed()
     elif args.model == "vision":
-        convert_vision()
+        if args.check:
+            check_vision_model()
+        else:
+            convert_vision()
     elif args.model == "all":
         convert_decoder()
         convert_encoder()
@@ -213,4 +341,4 @@ if __name__ == "__main__":
         convert_vision()
     else:
         print("Invalid model")
-        exit(1)
+        exit(1)
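Per the commit title, the extracted vision_encoder_part1.onnx is meant to run on the CPU while the fused remainder runs on the RK3588 NPU. A rough sketch of chaining the two halves at inference time, assuming onnxruntime and the RKNNLite runtime are available; the reshape into the stage-2 "-rs" layout is taken from the load_onnx call above and is an assumption (illustrative, not part of the commit):

import numpy as np
import onnxruntime as ort
from rknnlite.api import RKNNLite

# Stage 1 on CPU: the front of the vision encoder extracted with onnx.utils.extract_model.
sess = ort.InferenceSession("vision_encoder_part1.onnx", providers=["CPUExecutionProvider"])
pixel_values = np.zeros((1, 3, 768, 768), dtype=np.float32)  # input shape from the stage-1 build
hidden = sess.run(None, {"pixel_values": pixel_values})[0]

# Stage 2 on the NPU: the fused remainder, fed in the reshaped "-rs" layout.
lite = RKNNLite()
lite.load_rknn("vision_encoder_part2.rknn")
lite.init_runtime()
outputs = lite.inference(inputs=[hidden.reshape(1, 128, 1, 36864)])  # layout assumed from input_size_list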
 