happyme531 committed on
Commit 043b275 · verified · 1 Parent(s): eda4c05

Split part of the vision encoder onto the CPU and optimize Transpose ops.
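The split works as follows: convert.py cuts vision_encoder.onnx at the first channel-attention Add node, keeps the front part (vision_encoder_part1.onnx) as plain ONNX for the CPU, and compiles the remainder (vision_encoder_part2.rknn) for the NPU after rewriting its Transpose/Reshape chains. Below is a minimal sketch of how the two halves are chained at inference, mirroring rknnrun.py in this commit; the file names and the 1x128x1x36864 intermediate shape are taken from these scripts, not from any other documentation.

import numpy as np
import onnxruntime as ort
from rknnlite.api.rknn_lite import RKNNLite

# CPU half: plain ONNX, run with onnxruntime
part1 = ort.InferenceSession("vision_encoder_part1.onnx", providers=["CPUExecutionProvider"])
# NPU half: compiled RKNN model
part2 = RKNNLite()
part2.load_rknn("./vision_encoder_part2.rknn")
part2.init_runtime()

pixel_values = np.zeros((1, 3, 768, 768), dtype=np.float32)  # dummy 768x768 input
feats = part1.run(None, {"pixel_values": pixel_values})[0]
# The NPU half expects the intermediate reshaped to (1, 128, 1, 36864), as in rknnrun.py.
image_features = part2.inference(inputs=[feats.reshape(1, 128, 1, 36864)])[0]
part2.release()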

Files changed (2)
  1. convert.py +344 -0
  2. rknnrun.py +325 -0
convert.py ADDED
@@ -0,0 +1,344 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ from rknn.api import RKNN
+ from math import exp
+ from sys import exit
+
+ import onnx
+ import onnxscript
+
+ batch_size = 1
+ # embed_seq_len = 590
+
+ prompt_tokens_list = [15, 17, 21, 25]
+
+ encoder_seq_len_list = [577 + p for p in prompt_tokens_list]
+
+ decoder_seq_len = 1
+
+ # set current directory to the directory of this file
+ import os
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+ import subprocess
+ import select
+
+ def run_python_code(code):
+     # Launch a subprocess and execute the code
+     process = subprocess.Popen(
+         ['python', '-c', code],
+         stdout=subprocess.PIPE,
+         stderr=subprocess.PIPE,
+         text=True
+     )
+
+     # Read the subprocess's stdout and stderr in real time
+     while True:
+         reads = [process.stdout.fileno(), process.stderr.fileno()]
+         ret = select.select(reads, [], [])
+
+         for fd in ret[0]:
+             if fd == process.stdout.fileno():
+                 output = process.stdout.readline()
+                 if output:
+                     print(output.strip())
+             if fd == process.stderr.fileno():
+                 err = process.stderr.readline()
+                 if err:
+                     print(f"Error: {err.strip()}")
+
+         if process.poll() is not None:
+             break
+
+ def convert_decoder():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL="decoder_model.onnx"
+     RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+     DATASET="dataset.txt"
+     QUANTIZE=False
+
+     # [[batch_size, encoder_seq_len],
+     #  [batch_size, encoder_seq_len, 768],
+     #  [batch_size, decoder_seq_len, 768]]
+     input_shapes = [[[batch_size, encoder_seq_len],
+                      [batch_size, encoder_seq_len, 768],
+                      [batch_size, decoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]
+     # pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True,
+                 dynamic_input=input_shapes)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL,
+                          )
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+ def convert_encoder():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL="encoder_model.onnx"
+     RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+     DATASET="dataset.txt"
+     QUANTIZE=False
+
+     # [[batch_size, encoder_seq_len], [batch_size, encoder_seq_len, 768]]
+     input_shapes = [[[batch_size, encoder_seq_len], [batch_size, encoder_seq_len, 768]] for encoder_seq_len in encoder_seq_len_list]
+     # pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True, dynamic_input=input_shapes)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL
+                          )
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+ def convert_vision():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL="vision_encoder.onnx"
+     DATASET="dataset.txt"
+     QUANTIZE=False
+
+     # split the first transformer block into a separate model because it's too large to fit in the RKNN graph
+     onnx.utils.extract_model(ONNX_MODEL, "vision_encoder_part1.onnx", ['pixel_values'], ['/blocks.0/blocks.0.0/channel_block/channel_attn/Add_output_0'])
+
+     ##### Build stage 1. This will crash the Python process, so run it in a separate process.
+     code = f"""
+ from rknn.api import RKNN
+ rknn = RKNN(verbose=True)
+ ONNX_MODEL="vision_encoder.onnx"
+ RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+ DATASET="dataset.txt"
+ QUANTIZE=False
+ batch_size = {batch_size}
+ # pre-process config
+ print('--> Config model')
+ rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True)
+ print('done')
+
+ # Load ONNX model
+ print('--> Loading model')
+ ret = rknn.load_onnx(model=ONNX_MODEL,
+                      inputs=["pixel_values"],
+                      input_size_list=[[batch_size, 3, 768, 768]],
+                      )
+ if ret != 0:
+     print('Load model failed!')
+     exit(ret)
+ print('done')
+
+ print('--> Building model stage 1')
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+ if ret != 0:
+     print('Build model failed!')
+     exit(ret)
+ print('done')
+ """
+     run_python_code(code)
+     print("Build stage 1 done")
+
+     intermidiate_model = onnx.load("check3_fuse_ops.onnx")
+
+     # fuse ops
+     from onnxscript.rewriter import pattern
+     import onnx.numpy_helper as onh
+     import numpy as np
+     def tp_rs_tp_rs_tp_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
+         i1 = op.Transpose(input1, perm=perm1)
+         i2 = op.Reshape(i1, shape2)
+         i3 = op.Transpose(i2, perm=perm3)
+         i4 = op.Reshape(i3, shape4)
+         i5 = op.Transpose(i4, perm=perm5)
+         return i5
+
+     def fused_pattern(op, input1, perm1, shape2, perm3, shape4, perm5):
+         rs1_shape = op.Constant(value=onh.from_array(np.array([input1.shape[0] * 3, input1.shape[1] // 3, input1.shape[2], input1.shape[3]], dtype=np.int64)))
+         fi1 = op.Reshape(input1, rs1_shape)
+         fi2 = op.Transpose(fi1, perm=[0, 2, 1, 3])
+         elems = input1.shape[0] * input1.shape[1] * input1.shape[2] * input1.shape[3]
+         rs4_shape = op.Constant(value=onh.from_array(np.array([elems / 32 / 144, 32, 1, 144], dtype=np.int64)))
+         fi3 = op.Reshape(fi2, rs4_shape)
+         return fi3
+
+     rewrite_rule = pattern.RewriteRule(tp_rs_tp_rs_tp_pattern, fused_pattern)
+     rewrite_rule_set = pattern.RewriteRuleSet([rewrite_rule], commute=True)
+     fused_model = onnxscript.rewriter.rewrite(
+         intermidiate_model,
+         pattern_rewrite_rules=rewrite_rule_set
+     )
+     onnx.save(fused_model, "vision_encoder_part2.onnx")
+     ONNX_MODEL = "vision_encoder_part2.onnx"
+     RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+     del intermidiate_model
+     del fused_model
+
+
+     rknn = RKNN(verbose=True)
+
+     # pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model="check3_fuse_ops.onnx",
+                          inputs=["/blocks.0/blocks.0.0/channel_block/channel_attn/Add_output_0-rs"],
+                          input_size_list=[[batch_size, 128, 1, 36864]],)
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model stage 2')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+
+ def check_vision_model():
+     rknn = RKNN(verbose=True)
+
+     ONNX_MODEL="vision_encoder.onnx"
+     RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
+     DATASET="dataset.txt"
+     QUANTIZE=False
+     vision_size = (768, 768)  # same 768x768 input resolution used in convert_vision
+
+     # pre-process config
+     print('--> Config model')
+     rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3, single_core_mode=True)
+     print('done')
+
+     # Load ONNX model
+     print('--> Loading model')
+     ret = rknn.load_onnx(model=ONNX_MODEL,
+                          inputs=["pixel_values"],
+                          input_size_list=[[batch_size, 3, vision_size[0], vision_size[1]]],
+                          )
+     if ret != 0:
+         print('Load model failed!')
+         exit(ret)
+     print('done')
+
+     # Build model
+     print('--> Building model')
+     ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+     if ret != 0:
+         print('Build model failed!')
+         exit(ret)
+     print('done')
+
+     # Export RKNN model
+     print('--> Export RKNN model')
+     ret = rknn.export_rknn(RKNN_MODEL)
+     if ret != 0:
+         print('Export RKNN model failed!')
+         exit(ret)
+     print('done')
+
+     # init runtime
+     print('--> Init runtime environment')
+     ret = rknn.init_runtime(target='rk3588')
+     if ret != 0:
+         print('Init runtime environment failed!')
+         exit(ret)
+     print('done')
+
+     # precision check
+     print('--> Precision check')
+     ret = rknn.accuracy_analysis(inputs=["lena.png"], target='rk3588')
+     if ret != 0:
+         print('Precision check failed!')
+         exit(ret)
+     print('done')
+
+
+ import argparse
+ # python convert.py <decoder|encoder|vision|all>
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("model", type=str, help="Model to convert")
+     parser.add_argument("--check", action="store_true", help="Check model")
+     args = parser.parse_args()
+     if args.model == "decoder":
+         convert_decoder()
+     elif args.model == "encoder":
+         convert_encoder()
+     # elif args.model == "embed":  # embed is faster with cpu
+     #     convert_embed()
+     elif args.model == "vision":
+         if args.check:
+             check_vision_model()
+         else:
+             convert_vision()
+     elif args.model == "all":
+         convert_decoder()
+         convert_encoder()
+         # convert_embed()
+         convert_vision()
+     else:
+         print("Invalid model")
+         exit(1)
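The tp_rs_tp_rs_tp_pattern → fused_pattern rewrite above replaces a Transpose–Reshape–Transpose–Reshape–Transpose chain with a single Reshape–Transpose–Reshape, so it is worth confirming the rewritten graph still matches the intermediate it was derived from. A sanity-check sketch (not part of convert.py): it assumes float32 inputs, treats any symbolic dimension as 1, and reads the input names and shapes from the graphs themselves.

import numpy as np
import onnxruntime as ort

def run_graph(path, feed=None):
    # Run an ONNX graph on (shared) random inputs; return the feed and the outputs.
    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    if feed is None:
        feed = {}
        for inp in sess.get_inputs():
            shape = [d if isinstance(d, int) else 1 for d in inp.shape]
            feed[inp.name] = np.random.rand(*shape).astype(np.float32)
    return feed, sess.run(None, feed)

feed, before = run_graph("check3_fuse_ops.onnx")         # intermediate dumped by build stage 1
_, after = run_graph("vision_encoder_part2.onnx", feed)   # same graph after the Transpose/Reshape rewrite
for a, b in zip(before, after):
    print("max abs diff:", np.abs(a - b).max())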
rknnrun.py ADDED
@@ -0,0 +1,325 @@
+ import random
+ from rknnlite.api.rknn_lite import RKNNLite
+ from transformers import AutoProcessor
+ from PIL import Image, ImageDraw
+ import numpy as np
+ import onnxruntime as ort
+ import time
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ # set current working directory to the directory of this file
+ import os
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
+
+ # Initialize the total inference-time counter
+ total_time = 0
+
+ # Initialize RKNNLite instances
+ rknn_vision_encoder = RKNNLite(verbose=False)
+ rknn_encoder = RKNNLite(verbose=False)
+ rknn_decoder_prefill = RKNNLite(verbose=False)
+
+ # Load RKNN models
+ ret = rknn_vision_encoder.load_rknn('./vision_encoder_part2.rknn')
+ ret = rknn_encoder.load_rknn('./encoder_model.rknn')
+ ret = rknn_decoder_prefill.load_rknn('./decoder_model.rknn')
+
+ # Init runtime environment for each model
+ ret = rknn_vision_encoder.init_runtime()
+ ret = rknn_encoder.init_runtime()
+ ret = rknn_decoder_prefill.init_runtime()
+
+ text_embed = ort.InferenceSession("embed_tokens_fp16.onnx", providers=['CPUExecutionProvider'])
+ decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
+ vision_encoder = ort.InferenceSession("vision_encoder_part1.onnx", providers=['CPUExecutionProvider'])
+ prompt_tokens_list = [15, 17, 21, 25]
+
+ # 1. prepare inputs
+ processor = AutoProcessor.from_pretrained("/home/firefly/mnt/zt-rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)
+
+ # 2. prepare image
+ image = Image.open("./test.jpg")
+ original_image = image.copy()
+ original_size = image.size
+ # resize image to 768x768
+ image = image.resize((768, 768))
+ # 3. prepare text
+ prompt = "<MORE_DETAILED_CAPTION>"
+
+ ## tokenize first to get the prompt length
+ input_tokens_len = processor.tokenizer(prompt, return_tensors="np")["input_ids"].shape[1]
+ print("input_tokens_len: ", input_tokens_len)
+ ## select the closest greater supported length
+ pad_to = 0
+ for i in prompt_tokens_list:
+     if i >= input_tokens_len:
+         pad_to = i
+         break
+ print("pad_to: ", pad_to)
+ inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False, padding="max_length", max_length=pad_to + 577, truncation=True)
+ for k, v in inputs.items():
+     print(k, v.shape)
+
+ # 4. run the vision encoder (part 1 on the CPU via onnxruntime, part 2 on the NPU via RKNN)
+ start_time = time.time()
+ image_features0 = vision_encoder.run(None, {
+     "pixel_values": inputs["pixel_values"]
+ })[0]
+ image_features = rknn_vision_encoder.inference(inputs=[image_features0.reshape(1, 128, 1, 36864)])[0]
+
+ end_time = time.time()
+ vision_encoder_time = (end_time - start_time) * 1000
+ total_time += vision_encoder_time
+ print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
+ print(image_features.shape)
+ np.save("image_features.npy", image_features)
+
+ # 5. run the text embedding (onnxruntime, CPU)
+ start_time = time.time()
+ inputs_embeds = text_embed.run(None, {
+     "input_ids": inputs["input_ids"]
+ })[0]
+ end_time = time.time()
+ text_embed_time = (end_time - start_time) * 1000
+ total_time += text_embed_time
+ print(f"Text embed time: {text_embed_time:.2f} ms")
+ print(inputs_embeds.shape)
+
+ # 6. concat image features and text embeddings
+ batch_size, image_token_length = image_features.shape[:-1]
+ image_attention_mask = np.ones((batch_size, image_token_length))
+ task_prefix_embeds = inputs_embeds
+ task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
+ if len(task_prefix_attention_mask.shape) == 3:
+     task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+ inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
+ attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)
+
+ # 7. run the encoder using RKNN
+ start_time = time.time()
+ encoder_out = rknn_encoder.inference(inputs=[attention_mask.astype(np.int64), inputs_embeds])
+ end_time = time.time()
+ encoder_time = (end_time - start_time) * 1000
+ total_time += encoder_time
+ print(f"Encoder time: {encoder_time:.2f} ms")
+ encoder_hidden_states = encoder_out[0]
+ print(encoder_hidden_states.shape)
+
+ # 8. run the decoder prefill stage using RKNN
+ start_time = time.time()
+ next_token = processor.tokenizer.bos_token_id
+ next_input_embeds = text_embed.run(None, {
+     "input_ids": np.array([[next_token]], dtype=np.int64)
+ })[0]
+ decoder_outs = rknn_decoder_prefill.inference(inputs=[attention_mask.astype(np.int64), encoder_hidden_states, inputs_embeds[:, -1:]])
+ end_time = time.time()
+ decoder_prefill_time = (end_time - start_time) * 1000
+ total_time += decoder_prefill_time
+ print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")
+ # for output in decoder_outs:
+ #     print(output.shape)
+
+ encoder_kv = decoder_outs[1:]
+
+ # 9. run the decoder decode stage (autoregressive, using onnxruntime)
+ generated_tokens = []
+ max_new_tokens = 512
+ decoder_decode_total_time = 0
+ while len(generated_tokens) < max_new_tokens:
+     # get the outputs of the previous step
+     logits = decoder_outs[0]
+     decoder_kv = decoder_outs[1:]
+
+     # take the logits of the last token
+     next_token_logits = logits[:, -1, :]
+
+     # pick the next token with argmax (greedy decoding)
+     next_token = np.argmax(next_token_logits, axis=-1)[0]
+     print("next_token: ", next_token)
+     # append the newly generated token to the result
+     generated_tokens.append(next_token)
+
+     # stop once the end-of-sequence token is generated
+     if next_token == 2:  # </s>
+         break
+
+     # prepare the inputs for the next step
+     start_time = time.time()
+     next_input_embeds = text_embed.run(None, {
+         "input_ids": np.array([[next_token]], dtype=np.int64)
+     })[0]
+     end_time = time.time()
+     text_embed_time = (end_time - start_time) * 1000
+     decoder_decode_total_time += text_embed_time
+
+     # run the decode stage of the decoder
+     start_time = time.time()
+     decoder_outs = decoder_decode.run(None, {
+         "use_cache_branch": np.array([True], dtype=np.bool_),
+         "inputs_embeds": next_input_embeds,
+         "encoder_hidden_states": encoder_hidden_states,
+         "encoder_attention_mask": attention_mask.astype(np.int64),
+         "past_key_values.0.decoder.key": decoder_kv[0],
+         "past_key_values.0.decoder.value": decoder_kv[1],
+         "past_key_values.0.encoder.key": encoder_kv[2],
+         "past_key_values.0.encoder.value": encoder_kv[3],
+         "past_key_values.1.decoder.key": decoder_kv[4],
+         "past_key_values.1.decoder.value": decoder_kv[5],
+         "past_key_values.1.encoder.key": encoder_kv[6],
+         "past_key_values.1.encoder.value": encoder_kv[7],
+         "past_key_values.2.decoder.key": decoder_kv[8],
+         "past_key_values.2.decoder.value": decoder_kv[9],
+         "past_key_values.2.encoder.key": encoder_kv[10],
+         "past_key_values.2.encoder.value": encoder_kv[11],
+         "past_key_values.3.decoder.key": decoder_kv[12],
+         "past_key_values.3.decoder.value": decoder_kv[13],
+         "past_key_values.3.encoder.key": encoder_kv[14],
+         "past_key_values.3.encoder.value": encoder_kv[15],
+         "past_key_values.4.decoder.key": decoder_kv[16],
+         "past_key_values.4.decoder.value": decoder_kv[17],
+         "past_key_values.4.encoder.key": encoder_kv[18],
+         "past_key_values.4.encoder.value": encoder_kv[19],
+         "past_key_values.5.decoder.key": decoder_kv[20],
+         "past_key_values.5.decoder.value": decoder_kv[21],
+         "past_key_values.5.encoder.key": encoder_kv[22],
+         "past_key_values.5.encoder.value": encoder_kv[23],
+     })
+     end_time = time.time()
+     decoder_decode_time = (end_time - start_time) * 1000
+     decoder_decode_total_time += decoder_decode_time
+
+ total_time += decoder_decode_total_time
+ print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")
+
+ # convert the generated tokens into text
+ print("generated_tokens: ", generated_tokens)
+ generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
+ print("Generated Text:", generated_text)
+ parsed_answer = processor.post_process_generation(generated_text, task=prompt.split(">")[0].strip() + ">", image_size=original_size)
+ print("Parsed Answer:", parsed_answer)
+
+ print(f"Total inference time: {total_time:.2f} ms")
+
+ # postprocess
+ from PIL import Image, ImageDraw, ImageFont
+
+ def plot_bbox(image, data):
+     # Convert the image to a PIL Image if it's not already
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+
+     # Create a drawing context
+     draw = ImageDraw.Draw(image)
+
+     # Load a larger font
+     try:
+         font = ImageFont.truetype("arial.ttf", 20)  # try to load the Arial font at size 20
+     except IOError:
+         font = ImageFont.load_default().font_variant(size=20)  # if Arial is unavailable, use the default font scaled up
+
+     # Plot each bounding box
+     for bbox, label in zip(data['bboxes'], data['labels']):
+         # Unpack the bounding box coordinates
+         x1, y1, x2, y2 = bbox
+         # Draw the rectangle with a thicker outline
+         draw.rectangle([x1, y1, x2, y2], outline="red", width=3)  # increase the line width to 3
+
+         # Annotate the label
+         left, top, right, bottom = font.getbbox(label)
+         text_width = right - left
+         text_height = bottom - top
+
+         # enlarge the text background box
+         padding = 5
+         draw.rectangle([x1, y1 - text_height - padding*2, x1 + text_width + padding*2, y1], fill="red")
+         draw.text((x1 + padding, y1 - text_height - padding), label, fill="white", font=font)
+
+     # Save the image
+     image.save("result_image.jpg")
+
+ colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
+             'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
+
+ def draw_polygons(image, prediction, fill_mask=False):
+     """
+     Draws segmentation masks with polygons on an image.
+
+     Parameters:
+     - image: PIL image to draw on.
+     - prediction: Dictionary containing 'polygons' and 'labels' keys.
+                   'polygons' is a list of lists, each containing vertices of a polygon.
+                   'labels' is a list of labels corresponding to each polygon.
+     - fill_mask: Boolean indicating whether to fill the polygons with color.
+     """
+     draw = ImageDraw.Draw(image)
+
+     # Set up scale factor if needed (use 1 if not scaling)
+     scale = 1
+
+     # Iterate over polygons and labels
+     for polygons, label in zip(prediction['polygons'], prediction['labels']):
+         color = random.choice(colormap)
+         fill_color = random.choice(colormap) if fill_mask else None
+
+         for _polygon in polygons:
+             _polygon = np.array(_polygon).reshape(-1, 2)
+             if len(_polygon) < 3:
+                 print('Invalid polygon:', _polygon)
+                 continue
+
+             _polygon = (_polygon * scale).reshape(-1).tolist()
+
+             # Draw the polygon
+             if fill_mask:
+                 draw.polygon(_polygon, outline=color, fill=fill_color)
+             else:
+                 draw.polygon(_polygon, outline=color)
+
+             # Draw the label text
+             draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
+
+     # Save or display the image
+     # image.show()  # Display the image
+     # display(image)
+     image.save("result_image.jpg")
+
+
+ def draw_ocr_bboxes(image, prediction, scale=1):
+     draw = ImageDraw.Draw(image)
+
+     # Load a larger font
+     try:
+         font = ImageFont.truetype("arial.ttf", 18)  # try to load the Arial font at size 18
+     except IOError:
+         font = ImageFont.load_default().font_variant(size=18)  # if Arial is unavailable, use the default font scaled up
+     bboxes, labels = prediction['quad_boxes'], prediction['labels']
+     for box, label in zip(bboxes, labels):
+         color = random.choice(colormap)
+         new_box = (np.array(box) * scale).tolist()
+         draw.polygon(new_box, width=3, outline=color)
+         draw.text((new_box[0]+8, new_box[1]+2),
+                   "{}".format(label),
+                   align="right",
+                   fill=color)
+
+     # display(image)
+     image.save("result_image.jpg")
+
+
+ # draw_polygons(original_image, parsed_answer['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
+ # plot_bbox(original_image, parsed_answer[prompt.split(">")[0].strip() + ">"])
+ # draw_ocr_bboxes(original_image, parsed_answer["<OCR_WITH_REGION>"], scale=1)
+
+
+ # Release RKNNLite instances
+ rknn_vision_encoder.release()
+ rknn_encoder.release()
+ rknn_decoder_prefill.release()
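Taken together with the argparse block at the end of convert.py, the intended workflow appears to be: run `python convert.py all` (or `decoder`, `encoder`, `vision` individually, with `--check` available for the vision model) on the host to produce the .rknn files, then run `python rknnrun.py` on the RK3588 board, which loads the exported RKNN models plus the CPU-side ONNX files and captions ./test.jpg end to end.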