happyme531
committed on
Upload 13 files
Browse files- .gitattributes +2 -0
- convert_rknn.py +115 -0
- duration_embedder.onnx +3 -0
- export_onnx.py +234 -0
- inference.py +365 -0
- proj.onnx +3 -0
- spiece.model +3 -0
- text_encoder_bnb4.onnx +3 -0
- transformer.onnx +3 -0
- transformer.rknn +3 -0
- vae_decoder.onnx +3 -0
- vae_decoder.rknn +3 -0
- vae_encoder.onnx +3 -0
- ztu_somemodelruntime_rknnlite2.py +535 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
transformer.rknn filter=lfs diff=lfs merge=lfs -text
|
37 |
+
vae_decoder.rknn filter=lfs diff=lfs merge=lfs -text
|
convert_rknn.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
# coding: utf-8

import datetime
import argparse
from rknn.api import RKNN
from sys import exit

AUDIO_LENGTH = 645  # latent frames; 645 corresponds to 10 seconds of audio
TEXT_LENGTH = 64    # text length in tokens

# Model key -> source ONNX filename.
MODELS = {
    'transformer': 'transformer.onnx',
    'vae_decoder': 'vae_decoder.onnx',
}

# Candidate input-shape sets registered via RKNN "dynamic_input", per model.
SHAPES = {
    'transformer': [
        [
            [1, AUDIO_LENGTH, 64],   # hidden_states
            [1,],                    # timestep
            [2, 1024],               # pooled_text
            [2, TEXT_LENGTH, 1024],  # encoder_hidden_states
            [1, TEXT_LENGTH, 3],     # txt_ids
            [1, AUDIO_LENGTH, 3],    # img_ids
        ],
    ],
    'vae_decoder': [
        [
            [1, 64, AUDIO_LENGTH],
        ],
    ],
}

QUANTIZE=False  # w8a8 quantization switch
detailed_performance_log = True
39 |
+
def convert_model(model_type):
    """Convert one ONNX model to RKNN format for the RK3588 NPU.

    Args:
        model_type: key into MODELS ('transformer' or 'vae_decoder').

    Returns:
        bool: True on success, False on any failure.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")

    # Embedded into the model as custom_string so builds can be traced later.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=SHAPES[model_type],
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # Honor the module-level QUANTIZE switch instead of hard-coding False
        # (QUANTIZE defaults to False, so default behavior is unchanged).
        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Always free native resources held by the toolkit, on every path.
        rknn.release()
94 |
+
def main():
    """CLI entry point: convert one model, or every known model with 'all'."""
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', 'transformer', 'vae_decoder'],
                        help='要转换的模型类型 (默认: all)')
    args = parser.parse_args()

    # 'all' expands to every key in MODELS; otherwise convert just one model.
    targets = list(MODELS) if args.model_type == 'all' else [args.model_type]
    for target in targets:
        if not convert_model(target):
            print(f"转换 {target} 失败")


if __name__ == '__main__':
    main()
duration_embedder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1b2bd04d4bbd075e7c663711e55b6d09c68bbd35a772587ae46d8339599e03e3
|
3 |
+
size 1061046
|
export_onnx.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from diffusers import AutoencoderOobleck
|
4 |
+
from diffusers import FluxTransformer2DModel
|
5 |
+
from tangoflux import TangoFluxInference
|
6 |
+
from tangoflux.model import DurationEmbedder, TangoFlux
|
7 |
+
|
8 |
+
def export_vae_encoder(vae, save_path, batch_size=1, audio_length=441000):
    """Export the VAE encoder to ONNX.

    Args:
        vae: AutoencoderOobleck instance.
        save_path: destination .onnx file.
        batch_size: batch size for the dummy input.
        audio_length: samples in the dummy clip (10 s at 44.1 kHz by default).
    """
    vae.eval()

    class _Encoder(nn.Module):
        """Thin wrapper so tracing sees a plain tensor -> tensor forward."""

        def __init__(self, inner):
            super().__init__()
            self.vae = inner

        def forward(self, audio):
            # Stereo waveform in, sampled latent out.
            return self.vae.encode(audio).latent_dist.sample()

    # Dummy stereo input: (batch, 2 channels, samples).
    example = torch.randn(batch_size, 2, audio_length)

    torch.onnx.export(
        _Encoder(vae),
        example,
        save_path,
        input_names=['audio'],
        output_names=['latent'],
        dynamic_axes={
            'audio': {0: 'batch_size', 2: 'audio_length'},
            'latent': {0: 'batch_size', 2: 'latent_length'},
        },
        opset_version=17,
    )
47 |
+
def export_vae_decoder(vae, save_path, batch_size=1, latent_length=645):
    """Export the VAE decoder to ONNX.

    Args:
        vae: AutoencoderOobleck instance.
        save_path: destination .onnx file.
        batch_size: batch size for the dummy input.
        latent_length: latent sequence length (645 ~ 10 s of audio).
    """
    vae.eval()

    class _Decoder(nn.Module):
        """Thin wrapper exposing decode() as a traceable forward()."""

        def __init__(self, inner):
            super().__init__()
            self.vae = inner

        def forward(self, latent):
            return self.vae.decode(latent).sample

    # Dummy latent: (batch, 64 channels, frames).
    example = torch.randn(batch_size, 64, latent_length)

    torch.onnx.export(
        _Decoder(vae),
        example,
        save_path,
        input_names=['latent'],
        output_names=['audio'],
        dynamic_axes={
            'latent': {0: 'batch_size', 2: 'latent_length'},
            'audio': {0: 'batch_size', 2: 'audio_length'},
        },
        opset_version=17,
    )
|
86 |
+
def export_duration_embedder(duration_embedder, save_path, batch_size=1):
|
87 |
+
"""导出Duration Embedder到ONNX格式
|
88 |
+
|
89 |
+
Args:
|
90 |
+
duration_embedder: DurationEmbedder实例
|
91 |
+
save_path: 保存路径
|
92 |
+
batch_size: batch大小
|
93 |
+
"""
|
94 |
+
duration_embedder.eval()
|
95 |
+
|
96 |
+
# 创建dummy input - 注意这里是标量值
|
97 |
+
dummy_input = torch.tensor([[10.0]], dtype=torch.float32) # 10秒
|
98 |
+
|
99 |
+
# 导出
|
100 |
+
torch.onnx.export(
|
101 |
+
duration_embedder,
|
102 |
+
dummy_input,
|
103 |
+
save_path,
|
104 |
+
input_names=['duration'],
|
105 |
+
output_names=['embedding'],
|
106 |
+
dynamic_axes={
|
107 |
+
'duration': {0: 'batch_size'},
|
108 |
+
'embedding': {0: 'batch_size'}
|
109 |
+
},
|
110 |
+
opset_version=17
|
111 |
+
)
|
112 |
+
|
113 |
+
def export_flux_transformer(transformer, save_path, batch_size=1, seq_length=645):
    """Export the FluxTransformer2D backbone to ONNX.

    Args:
        transformer: FluxTransformer2DModel instance.
        save_path: destination .onnx file.
        batch_size: batch size for the dummy inputs.
        seq_length: audio latent sequence length.
    """
    transformer.eval()

    class _Backbone(nn.Module):
        """Adapter exposing positional tensor arguments for ONNX tracing."""

        def __init__(self, inner):
            super().__init__()
            self.transformer = inner

        def forward(self, hidden_states, timestep, pooled_text,
                    encoder_hidden_states, txt_ids, img_ids):
            out = self.transformer(
                hidden_states=hidden_states,
                timestep=timestep,
                guidance=None,
                pooled_projections=pooled_text,
                encoder_hidden_states=encoder_hidden_states,
                txt_ids=txt_ids,
                img_ids=img_ids,
                return_dict=False,
            )
            return out[0]

    # Dummy inputs mirroring the runtime shapes.
    hidden_states = torch.randn(batch_size, seq_length, 64)    # [B, S, 64]
    timestep = torch.tensor([0.5])                             # [1]
    pooled_text = torch.randn(batch_size, 1024)                # [B, 1024]
    encoder_hidden_states = torch.randn(batch_size, 64, 1024)  # [B, 64, 1024]
    txt_ids = torch.zeros(batch_size, 64, 3).to(torch.int64)   # [B, 64, 3]
    img_ids = torch.arange(seq_length).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 3).to(torch.int64)  # [B, S, 3]

    torch.onnx.export(
        _Backbone(transformer),
        (hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids),
        save_path,
        input_names=['hidden_states', 'timestep', 'pooled_text',
                     'encoder_hidden_states', 'txt_ids', 'img_ids'],
        output_names=['output'],
        dynamic_axes={
            'hidden_states': {0: 'batch_size', 1: 'sequence_length'},
            'pooled_text': {0: 'batch_size'},
            'encoder_hidden_states': {0: 'batch_size', 1: 'text_length'},
            'txt_ids': {0: 'batch_size', 1: 'text_length'},
            'img_ids': {0: 'batch_size', 1: 'sequence_length'},
        },
        opset_version=17,
    )
|
169 |
+
def export_proj_layer(proj_layer, save_path, batch_size=1):
|
170 |
+
"""导出projection层到ONNX格式
|
171 |
+
|
172 |
+
Args:
|
173 |
+
proj_layer: 投影层(fc层)实例
|
174 |
+
save_path: 保存路径
|
175 |
+
batch_size: batch大小
|
176 |
+
"""
|
177 |
+
proj_layer.eval()
|
178 |
+
|
179 |
+
# 创建dummy input - 使用T5的hidden size
|
180 |
+
dummy_input = torch.randn(batch_size, 1024) # T5-large hidden size
|
181 |
+
|
182 |
+
# 导出
|
183 |
+
torch.onnx.export(
|
184 |
+
proj_layer,
|
185 |
+
dummy_input,
|
186 |
+
save_path,
|
187 |
+
input_names=['text_embedding'],
|
188 |
+
output_names=['projected'],
|
189 |
+
dynamic_axes={
|
190 |
+
'text_embedding': {0: 'batch_size'},
|
191 |
+
'projected': {0: 'batch_size'}
|
192 |
+
},
|
193 |
+
opset_version=17
|
194 |
+
)
|
195 |
+
|
196 |
+
def export_all(model_path, output_dir):
    """Export every TangoFlux component to ONNX.

    Args:
        model_path: TangoFlux checkpoint name or path.
        output_dir: directory that receives the generated .onnx files.
    """
    import os

    inference = TangoFluxInference(name=model_path, device="cpu")
    os.makedirs(output_dir, exist_ok=True)

    # VAE encoder / decoder.
    export_vae_encoder(inference.vae, f"{output_dir}/vae_encoder.onnx")
    export_vae_decoder(inference.vae, f"{output_dir}/vae_decoder.onnx")

    # Duration embedder. NOTE: "duration_emebdder" is presumably the
    # (misspelled) attribute name in the upstream tangoflux package —
    # verify before renaming.
    export_duration_embedder(inference.model.duration_emebdder, f"{output_dir}/duration_embedder.onnx")

    # Flux transformer backbone.
    export_flux_transformer(inference.model.transformer, f"{output_dir}/transformer.onnx")

    # Text-embedding projection layer.
    export_proj_layer(inference.model.fc, f"{output_dir}/proj.onnx")

    print(f"所有模型已导出到: {output_dir}")
|
226 |
+
if __name__ == "__main__":
    import argparse

    # Minimal CLI: checkpoint in, ONNX directory out.
    cli = argparse.ArgumentParser(description="导出TangoFlux模型到ONNX格式")
    cli.add_argument("--model_path", type=str, required=True, help="TangoFlux模型路径")
    cli.add_argument("--output_dir", type=str, required=True, help="输出目录")
    opts = cli.parse_args()
    export_all(opts.model_path, opts.output_dir)
inference.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
# import onnxruntime as ort
|
3 |
+
import ztu_somemodelruntime_rknnlite2 as ort
|
4 |
+
import sentencepiece as spm
|
5 |
+
import soundfile as sf
|
6 |
+
|
7 |
+
ort.set_default_logger_verbosity(0)
|
8 |
+
|
9 |
+
def load_onnx_model(model_path):
    """Create an inference session for *model_path*.

    Providers are listed best-first; the runtime falls back to the first
    one that is actually available.
    """
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    return ort.InferenceSession(model_path, providers=providers)
|
16 |
+
class SimpleT5Tokenizer:
    """Minimal T5-style tokenizer built directly on a sentencepiece model."""

    def __init__(self, model_path, max_length=128):
        """Load the sentencepiece model.

        Args:
            model_path: path to the sentencepiece .model file.
            max_length: default padded sequence length (128 by default).
        """
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(model_path)
        # Fixed special-token ids used by T5 checkpoints.
        self.pad_token_id = 0
        self.eos_token_id = 1
        self.max_length = max_length

    def __call__(self, texts, padding=True, truncation=True, max_length=None, return_tensors="np"):
        """Tokenize text(s) into padded id/mask arrays.

        Args:
            texts: a string or a list of strings.
            padding: pad every row out to the max length with pad_token_id.
            truncation: clip rows that would exceed the max length.
            max_length: optional override of the instance default.
            return_tensors: only "np" is supported.

        Returns:
            dict with int64 numpy arrays "input_ids" and "attention_mask".
        """
        if isinstance(texts, str):
            texts = [texts]
        limit = self.max_length if max_length is None else max_length

        all_ids, all_masks = [], []
        for text in texts:
            ids = self.sp.EncodeAsIds(text)
            # Keep one slot free for the trailing EOS token.
            if truncation and len(ids) > limit - 1:
                ids = ids[:limit - 1]
            ids = ids + [self.eos_token_id]
            mask = [1] * len(ids)
            if padding:
                fill = limit - len(ids)
                ids = ids + [self.pad_token_id] * fill
                mask = mask + [0] * fill
            all_ids.append(ids)
            all_masks.append(mask)

        return {
            "input_ids": np.array(all_ids, dtype=np.int64),
            "attention_mask": np.array(all_masks, dtype=np.int64),
        }
82 |
+
def encode_text(prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=None):
    """Run the T5 text encoder, optionally with an unconditional branch.

    Args:
        prompt: prompt string or list of strings.
        negative_prompt: unconditional prompt for classifier-free guidance.
        tokenizer: tokenizer compatible with SimpleT5Tokenizer's interface.
        text_encoder_onnx: ONNX session exposing a 'last_hidden_state' output.
        guidance_scale: when not None and > 1.0, the negative prompt is
            encoded in the same batch as row 0.

    Returns:
        ((uncond_hidden, uncond_mask), (cond_hidden, cond_mask)) in the
        guidance case, otherwise (hidden_states, attention_mask).
    """
    prompts = prompt if isinstance(prompt, list) else [prompt]
    use_guidance = guidance_scale is not None and guidance_scale > 1.0

    # With guidance the unconditional text goes first so it can be split
    # back out by row index below.
    texts = ([negative_prompt] + prompts) if use_guidance else prompts
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="np")
    feeds = {
        'input_ids': batch['input_ids'].astype(np.int64),
        'attention_mask': batch['attention_mask'].astype(np.int64)
    }
    hidden = text_encoder_onnx.run(['last_hidden_state'], feeds)[0]

    if use_guidance:
        mask = batch['attention_mask']
        # Row 0 is the negative prompt; the rest are the conditional prompts.
        return (hidden[0:1], mask[0:1]), (hidden[1:], mask[1:])
    return hidden, batch['attention_mask']
|
140 |
+
def retrieve_timesteps(scheduler, num_inference_steps, device, timesteps=None, sigmas=None):
    """Configure *scheduler* and return (timesteps, step count).

    Explicit sigmas take precedence; in that case the step count is
    recomputed from the resulting timestep list. *device* and *timesteps*
    are accepted for interface compatibility but not used directly.
    """
    if sigmas is None:
        scheduler.set_timesteps(num_inference_steps)
    else:
        scheduler.set_timesteps(sigmas=sigmas)
        num_inference_steps = len(scheduler.timesteps)
    return scheduler.timesteps, num_inference_steps
|
151 |
+
# A minimal stand-in for diffusers' flow-match Euler scheduler.
class SimpleFlowMatchScheduler:
    """Euler flow-matching scheduler implemented with numpy only."""

    def __init__(self, num_train_timesteps=1000, shift=1.0):
        """Precompute the training-time sigma/timestep tables.

        Args:
            num_train_timesteps: number of training timesteps.
            shift: timestep shift factor (1.0 means no shift).
        """
        # Descending linear timesteps: num_train_timesteps .. 1.
        ts = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        sig = ts / num_train_timesteps
        sig = shift * sig / (1 + (shift - 1) * sig)
        # Trailing 0.0 lets step() always look one sigma ahead.
        self.sigmas = np.append(sig, 0.0)
        self.timesteps = sig * num_train_timesteps
        self.step_index = None

    def set_timesteps(self, num_inference_steps):
        """Resample the schedule down to *num_inference_steps* steps.

        Args:
            num_inference_steps: number of inference steps.
        """
        total = len(self.timesteps)
        ts = np.linspace(1, total, num_inference_steps, dtype=np.float32)[::-1].copy()
        sig = ts / total
        self.sigmas = np.append(sig, 0.0)
        self.timesteps = sig * total
        self.step_index = 0

    def step(self, model_output, timestep, sample):
        """Advance one Euler step.

        Args:
            model_output: predicted velocity from the model.
            timestep: current timestep (unused; kept for API parity).
            sample: current latent sample.

        Returns:
            The updated sample.
        """
        cur = self.sigmas[self.step_index]
        nxt = self.sigmas[self.step_index + 1]
        updated = sample + (nxt - cur) * model_output
        self.step_index += 1
        return updated
|
204 |
+
def generate_audio_onnx(
    prompt="",
    negative_prompt="",
    duration=10,
    steps=50,
    guidance_scale=4.5,
    onnx_dir="./onnx_models",
    output_path="output_onnx.wav",
    seed=None
):
    """Generate audio from a text prompt with the exported ONNX/RKNN models.

    Args:
        prompt: text prompt.
        negative_prompt: negative prompt for classifier-free guidance.
        duration: audio length in seconds.
        steps: number of flow-matching steps.
        guidance_scale: CFG scale; values > 1.0 enable the unconditional pass.
        onnx_dir: directory containing the exported model files.
        output_path: destination wav file.
        seed: optional numpy RNG seed for reproducible latents.

    Returns:
        The decoded waveform, shape (channels, samples).
    """
    if seed is not None:
        np.random.seed(seed)

    # max_length=63 presumably leaves one slot so that text tokens plus the
    # duration embedding fill the transformer's 64-token budget — TODO confirm.
    tokenizer = SimpleT5Tokenizer(f"{onnx_dir}/spiece.model", max_length=63)
    # NOTE(review): the repository ships text_encoder_bnb4.onnx — verify that
    # this nf4 filename actually exists in onnx_dir.
    text_encoder_onnx = load_onnx_model(f"{onnx_dir}/text_encoder_nf4.onnx")

    # Remaining model components.
    vae_decoder = load_onnx_model(f"{onnx_dir}/vae_decoder.onnx")
    duration_embedder = load_onnx_model(f"{onnx_dir}/duration_embedder.onnx")
    transformer = load_onnx_model(f"{onnx_dir}/transformer.onnx")
    proj_layer = load_onnx_model(f"{onnx_dir}/proj.onnx")

    # 1. Duration embedding.
    duration_input = np.array([[duration]], dtype=np.float32)
    print(f"[Shape] duration输入: {duration_input.shape}")

    duration_hidden_states = duration_embedder.run(
        ['embedding'],
        {'duration': duration_input}
    )[0]
    print(f"[Shape] duration embedding: {duration_hidden_states.shape}")

    if guidance_scale > 1.0:
        # Duplicate for the (unconditional, conditional) batch of two.
        duration_hidden_states = np.concatenate([duration_hidden_states] * 2, axis=0)
        print(f"[Shape] 复制后的duration embedding: {duration_hidden_states.shape}")

    # 2. Text encoding.
    if guidance_scale > 1.0:
        (uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask) = encode_text(
            prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=guidance_scale
        )
        print(cond_hidden_states)
        encoder_hidden_states = np.concatenate([uncond_hidden_states, cond_hidden_states])
        attention_mask = np.concatenate([uncond_mask, cond_mask])
    else:
        # BUG FIX: the original call omitted negative_prompt, shifting every
        # argument one position left (tokenizer became negative_prompt, the
        # encoder became the tokenizer) and raising a TypeError.
        encoder_hidden_states, attention_mask = encode_text(
            prompt, negative_prompt, tokenizer, text_encoder_onnx
        )

    # 3. Masked mean-pool over valid tokens, then project to pooled_text.
    boolean_encoder_mask = (attention_mask == 1)
    mask_expanded = boolean_encoder_mask[..., None].repeat(encoder_hidden_states.shape[-1], axis=-1)
    masked_data = np.where(mask_expanded, encoder_hidden_states, np.nan)
    pooled = np.nanmean(masked_data, axis=1)

    pooled_text = proj_layer.run(
        ['projected'],
        {'text_embedding': pooled.astype(np.float32)}
    )[0]

    # 4. Append the duration embedding to the text sequence.
    encoder_hidden_states = np.concatenate(
        [encoder_hidden_states, duration_hidden_states],
        axis=1
    )

    # 5. Positional id tensors expected by the transformer.
    txt_ids = np.zeros((1, encoder_hidden_states.shape[1], 3), dtype=np.int64)
    img_ids = np.tile(
        np.arange(645, dtype=np.int64)[None, :, None],
        (1, 1, 3)
    )

    # 6. Scheduler and initial noise latents (645 frames ~ 10 s).
    scheduler = SimpleFlowMatchScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(steps)
    latents = np.random.randn(1, 645, 64).astype(np.float32)

    # 7. Denoising loop.
    for i in range(steps):
        noise_pred = transformer.run(
            ['output'],
            {
                'hidden_states': latents,
                'timestep': np.array([scheduler.timesteps[i]/1000], dtype=np.float32),
                'pooled_text': pooled_text,
                'encoder_hidden_states': encoder_hidden_states,
                'txt_ids': txt_ids,
                'img_ids': img_ids
            }
        )[0]

        if i == 0:  # print once, on the first step only
            print(f"[Shape] noise预测输出: {noise_pred.shape}")

        # Classifier-free guidance mix.
        if guidance_scale > 1.0:
            noise_pred_uncond, noise_pred_text = noise_pred[0:1], noise_pred[1:2]
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, scheduler.timesteps[i], latents)

        if i % 10 == 0:
            print(f"生成进度: {i}/{steps}")

    # 8. Rescale and reshape for the VAE: (B, T, C) -> (B, C, T).
    latents = latents / scheduler.sigmas[0]
    latents = np.transpose(latents, (0, 2, 1))

    # 9. VAE decode.
    wave = vae_decoder.run(['audio'], {'latent': latents})[0]

    # 10. Trim to the requested duration.
    sample_rate = 44100
    waveform_end = int(duration * sample_rate)
    wave = wave[:, :, :waveform_end]
    print(f"[Shape] 裁剪后的最终波形: {wave.shape}")

    # 11. Save; soundfile expects (samples, channels).
    wave = wave[0]  # drop the batch dimension
    sf.write(output_path, wave.T, sample_rate)

    return wave
|
334 |
+
if __name__ == "__main__":
    import argparse

    # CLI wrapper around generate_audio_onnx.
    parser = argparse.ArgumentParser(description="测试ONNX模型推理")
    parser.add_argument("--prompt", type=str, default="What does the fox say?", help="文本提示")
    parser.add_argument("--negative_prompt", type=str, default="", help="负文本提示")
    parser.add_argument("--onnx_dir", type=str, default=".", help="ONNX模型目录")
    parser.add_argument("--duration", type=float, default=10.0, help="生成音频时长(秒)")
    parser.add_argument("--steps", type=int, default=30, help="推理步数")
    parser.add_argument("--guidance_scale", type=float, default=4.5, help="引导系数")
    parser.add_argument("--output", type=str, default="output_onnx.wav", help="输出音频路径")
    parser.add_argument("--seed", type=int, default=42, help="随机种子")
    args = parser.parse_args()

    wave = generate_audio_onnx(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        duration=args.duration,
        steps=args.steps,
        guidance_scale=args.guidance_scale,
        onnx_dir=args.onnx_dir,
        output_path=args.output,
        seed=args.seed,
    )

    print(f"生成的音频shape为: {wave.shape}")
    print(f"音频已保存到: {args.output}")
proj.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:af92c1c262e6a217a75ef59c922304eb90770f9a67a6253c9c477fbe3fa9eba8
|
3 |
+
size 4198734
|
spiece.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
|
3 |
+
size 791656
|
text_encoder_bnb4.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6334ec5d0eeaba54449c24ec0784c3e224db6c483a968c0fca055e001b80e39
|
3 |
+
size 305592280
|
transformer.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c9079229578ab2b683c271f9f585fddabcfeb588191d9c02c597f0aa4b6a383b
|
3 |
+
size 2068637351
|
transformer.rknn
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dffd321f320c949a0cffc9b3bf92b371fccaebb5f25826710c8b89d84184d2c7
|
3 |
+
size 1118028281
|
vae_decoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:25cdc6d16f896906df2cea9374b2746842ef808563e6daeb7b48b2eb6360a4a2
|
3 |
+
size 312595968
|
vae_decoder.rknn
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3c9782440915ffe1698717576b7889fc7d387a9c62b4874ff83337b8473b5049
|
3 |
+
size 352599027
|
vae_encoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:846b8fd27f2f6309954fb1420066b56123bfe34c3adcc100809c578011179980
|
3 |
+
size 312074746
|
ztu_somemodelruntime_rknnlite2.py
ADDED
@@ -0,0 +1,535 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 模块级常量和函数
|
2 |
+
from rknnlite.api import RKNNLite
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import warnings
|
6 |
+
import logging
|
7 |
+
from typing import List, Dict, Union, Optional
|
8 |
+
|
9 |
+
try:
    import onnxruntime as ort
    HAS_ORT = True
except ImportError:
    # Without onnxruntime only the RKNN backend is available.
    HAS_ORT = False
    warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)

# Configure the module-level logger (errors only by default).
logger = logging.getLogger("somemodelruntime_rknnlite2")
logger.setLevel(logging.ERROR)  # default: only emit errors
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)

# Map ONNX Runtime severity levels (0-4) to Python logging levels.
_LOGGING_LEVEL_MAP = {
    0: logging.DEBUG,    # Verbose
    1: logging.INFO,     # Info
    2: logging.WARNING,  # Warning
    3: logging.ERROR,    # Error
    4: logging.CRITICAL  # Fatal
}
|
32 |
+
|
33 |
+
def set_default_logger_severity(level: int) -> None:
    """
    Set the default logging severity for this module's logger.

    0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal

    Args:
        level: severity level, an integer in the range 0-4
    """
    mapped = _LOGGING_LEVEL_MAP.get(level)
    if mapped is None:
        raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
    logger.setLevel(mapped)
|
43 |
+
|
44 |
+
def set_default_logger_verbosity(level: int) -> None:
    """
    Alias of :func:`set_default_logger_severity`.

    To activate verbose output, set the default severity to 0 (Verbose).

    Args:
        level: severity level, an integer in the range 0-4
    """
    set_default_logger_severity(level)
|
53 |
+
|
54 |
+
# NPU core selection constants (bitmask values for init_runtime/core_mask)
NPU_CORE_AUTO = 0    # let the driver choose a core automatically
NPU_CORE_0 = 1       # core 0 only
NPU_CORE_1 = 2       # core 1 only
NPU_CORE_2 = 4       # core 2 only
NPU_CORE_0_1 = 3     # cores 0 and 1
NPU_CORE_0_1_2 = 7   # cores 0, 1 and 2
NPU_CORE_ALL = 0xffff  # all available cores

# Map RKNN tensor type enum values to numpy dtypes
RKNN_DTYPE_MAP = {
    0: np.float32,  # RKNN_TENSOR_FLOAT32
    1: np.float16,  # RKNN_TENSOR_FLOAT16
    2: np.int8,     # RKNN_TENSOR_INT8
    3: np.uint8,    # RKNN_TENSOR_UINT8
    4: np.int16,    # RKNN_TENSOR_INT16
    5: np.uint16,   # RKNN_TENSOR_UINT16
    6: np.int32,    # RKNN_TENSOR_INT32
    7: np.uint32,   # RKNN_TENSOR_UINT32
    8: np.int64,    # RKNN_TENSOR_INT64
    9: bool,        # RKNN_TENSOR_BOOL
    10: np.int8,    # RKNN_TENSOR_INT4 (represented as int8)
}
|
77 |
+
|
78 |
+
def get_available_providers() -> List[str]:
    """
    Compatibility shim mirroring onnxruntime's module-level API.

    Returns:
        list: always ["CPUExecutionProvider"]
    """
    providers = ["CPUExecutionProvider"]
    return providers
|
86 |
+
|
87 |
+
def get_version_info() -> Dict[str, str]:
    """
    Query API and driver version strings from the RKNN SDK.

    Returns:
        dict: {"api_version": ..., "driver_version": ...}
    """
    # get_sdk_version() returns a multi-line banner; the API version sits on
    # line index 2 and the driver version on line index 3 of that banner.
    banner_lines = RKNNLite().get_sdk_version().split('\n')
    api_version = banner_lines[2].split(': ')[1].split(' ')[0]
    driver_version = banner_lines[3].split(': ')[1]
    return {
        "api_version": api_version,
        "driver_version": driver_version
    }
|
100 |
+
|
101 |
+
class IOTensor:
    """Lightweight description of one model input/output tensor.

    Attributes:
        name: tensor name (bytes from the RKNN attr are decoded to str)
        shape: tensor shape as a list of ints
        type: numpy dtype of the tensor, or None when unknown
    """
    def __init__(self, name, shape, type=None):
        # RKNN attrs expose names as bytes; normalize to str for callers.
        self.name = name.decode() if isinstance(name, bytes) else name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"

    # Bug fix: without __repr__, lists returned by get_inputs()/get_outputs()
    # printed as opaque "<IOTensor object at 0x...>" — alias repr to str so
    # container/debugger output is readable.
    __repr__ = __str__
|
110 |
+
|
111 |
+
class SessionOptions:
    """Options controlling how an InferenceSession is initialized."""

    def __init__(self):
        # Run NPU inference asynchronously?
        self.async_mode = False
        # Bitmask selecting NPU cores (see NPU_CORE_* constants); 0 = auto.
        self.core_mask = 0
        # Collect per-run performance counters?
        self.perf_debug = False
|
117 |
+
|
118 |
+
class InferenceSession:
    """
    RKNNLite runtime wrapper exposing an onnxruntime-style API.

    When given an ``.onnx`` path, ``__new__`` either falls back to a sibling
    ``.rknn`` file (if ``fallback`` is True and one exists) or delegates to a
    real ``onnxruntime.InferenceSession``.
    """

    def __new__(cls, model_path: str, verbose: bool = False, sess_options: Optional[SessionOptions] = None, fallback: bool = True, **kwargs):
        """
        Create a runtime instance.

        Args:
            model_path: model file path (.rknn or .onnx)
            verbose: emit verbose logs
            sess_options: session options
            fallback: when loading .onnx, prefer a sibling .rknn file
            **kwargs: extra arguments forwarded to onnxruntime

        Raises:
            FileNotFoundError: if model_path does not exist
            RuntimeError: if an .onnx model is requested but onnxruntime is missing
        """
        # Only enable verbose logging when explicitly requested.
        if verbose:
            set_default_logger_severity(0)

        if not os.path.exists(model_path):
            logger.error(f"模型文件不存在: {model_path}")
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        is_onnx = model_path.lower().endswith('.onnx')

        if is_onnx and fallback:
            # Prefer a pre-converted RKNN model sitting next to the ONNX file.
            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
            if os.path.exists(rknn_path):
                logger.info(f"找到对应的RKNN模型,将使用RKNN: {rknn_path}")
                instance = super().__new__(cls)
                instance.model_path = rknn_path
                return instance

        if is_onnx:
            # Delegate to onnxruntime. __init__ below is then skipped because
            # the returned object is not an instance of this class.
            logger.info(f"使用ONNX Runtime加载模型: {model_path}")
            if not HAS_ORT:
                raise RuntimeError("未安装onnxruntime,无法加载ONNX模型")
            return ort.InferenceSession(model_path, sess_options=sess_options, **kwargs)

        # Plain RKNN model: create the wrapper instance.
        instance = super().__new__(cls)
        instance.model_path = model_path
        return instance

    def __init__(self, model_path: str, verbose: bool = False, sess_options: Optional[SessionOptions] = None, fallback: bool = True, **kwargs):
        """
        Initialize the RKNN runtime (no-op for onnxruntime-delegated sessions).

        Args:
            model_path: model file path (.rknn or .onnx)
            verbose: emit verbose logs
            sess_options: session options
            fallback: is ignored here; handled in __new__
            **kwargs: unused for the RKNN backend

        Raises:
            RuntimeError: if loading the model or initializing the runtime fails
        """
        # If __new__ delegated to onnxruntime, this __init__ never runs for
        # that object; the guard below protects against exotic construction.
        if not hasattr(self, 'model_path'):
            return

        self.runtime = RKNNLite(verbose=verbose)

        # Load the RKNN model blob.
        logger.debug(f"正在加载模型: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"加载RKNN模型失败: {self.model_path}")
            raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
        logger.debug("模型加载成功")

        # Apply session options (defaults when none given).
        options = sess_options or SessionOptions()

        # Bring up the NPU runtime with the requested core mask / async mode.
        logger.debug("正在初始化运行时环境")
        ret = self.runtime.init_runtime(
            async_mode=options.async_mode,
            core_mask=options.core_mask
        )
        if ret != 0:
            logger.error("初始化运行时环境失败")
            raise RuntimeError('初始化运行时环境失败')
        logger.debug("运行时环境初始化成功")

        # Cache input/output tensor metadata.
        self._init_io_info()

        self.options = options

    def get_performance_info(self) -> Dict[str, float]:
        """
        Return performance counters from the last run.

        Returns:
            dict: {"run_duration": milliseconds}

        Raises:
            RuntimeError: if perf_debug was not enabled in SessionOptions
        """
        if not self.options.perf_debug:
            raise RuntimeError("性能分析未启用,请在SessionOptions中设置perf_debug=True")

        perf = self.runtime.rknn_runtime.get_run_perf()
        return {
            "run_duration": perf.run_duration / 1000.0  # microseconds -> ms
        }

    def set_core_mask(self, core_mask: int) -> None:
        """
        Select which NPU cores to run on.

        Args:
            core_mask: NPU core bitmask, use the NPU_CORE_* constants

        Raises:
            RuntimeError: if the underlying call fails
        """
        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
        if ret != 0:
            raise RuntimeError("设置NPU核心模式失败")

    def _convert_nhwc_to_nchw(self, shape):
        """Convert a 4-D NHWC shape to NCHW; other ranks pass through."""
        if len(shape) == 4:
            n, h, w, c = shape
            return [n, c, h, w]
        return shape

    def _init_io_info(self):
        """Populate self.input_tensors / self.output_tensors from the runtime."""
        runtime = self.runtime.rknn_runtime

        n_input, n_output = runtime.get_in_out_num()

        # Inputs: RKNN reports NHWC shapes; present them as NCHW to match
        # the ONNX-style interface.
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            shape = self._convert_nhwc_to_nchw(shape)
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            self.input_tensors.append(IOTensor(attr.name, shape, dtype))

        # Outputs: use the runtime-provided output shape directly.
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            self.output_tensors.append(IOTensor(attr.name, shape, dtype))

    def get_inputs(self):
        """
        Return the model's input tensor descriptions.

        Returns:
            list[IOTensor]: one entry per model input
        """
        return self.input_tensors

    def get_outputs(self):
        """
        Return the model's output tensor descriptions.

        Returns:
            list[IOTensor]: one entry per model output
        """
        return self.output_tensors

    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
        """
        Execute inference.

        Args:
            output_names: names of outputs to return; None returns all outputs
            input_feed: dict mapping input names to arrays, or a positional list
            data_format: layout of the input data, "nchw" or "nhwc"
            **kwargs: ignored, kept for onnxruntime signature compatibility

        Returns:
            list: output arrays (filtered/ordered by output_names if given)

        Raises:
            ValueError: if input_feed is None, malformed, or missing an input
            RuntimeError: if inference fails (original exception chained)
        """
        if input_feed is None:
            logger.error("input_feed不能为None")
            raise ValueError("input_feed不能为None")

        # Normalize input_feed to a positional list in model input order.
        if isinstance(input_feed, dict):
            # Bug fix: dropped a dead 'input_map' dict that was built here
            # but never used.
            inputs = []
            for tensor in self.input_tensors:
                if tensor.name not in input_feed:
                    raise ValueError(f"缺少输入: {tensor.name}")
                inputs.append(input_feed[tensor.name])
        elif isinstance(input_feed, (list, tuple)):
            if len(input_feed) != len(self.input_tensors):
                raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
            inputs = list(input_feed)
        else:
            logger.error("input_feed必须是字典或列表类型")
            raise ValueError("input_feed必须是字典或列表类型")

        try:
            logger.debug("开始执行推理")
            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)

            if output_names is None:
                return all_outputs

            # Select only the requested outputs, in the requested order.
            output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
            selected_outputs = []
            for name in output_names:
                if name not in output_map:
                    raise ValueError(f"未找到输出节点: {name}")
                selected_outputs.append(all_outputs[output_map[name]])

            return selected_outputs

        except Exception as e:
            logger.error(f"推理执行失败: {str(e)}")
            # Bug fix: chain the original exception so the real cause
            # (traceback) is not lost.
            raise RuntimeError(f"推理执行失败: {str(e)}") from e

    def close(self):
        """
        Release runtime resources. Safe to call multiple times.
        """
        # Bug fix: use getattr — if __init__ failed before self.runtime was
        # assigned, close() previously raised AttributeError.
        if getattr(self, 'runtime', None) is not None:
            logger.info("正在释放运行时资源")
            self.runtime.release()
            self.runtime = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always release NPU resources on context exit.
        self.close()

    def end_profiling(self) -> Optional[str]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            Optional[str]: always None
        """
        warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return None

    def get_profiling_start_time_ns(self) -> int:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            int: always 0
        """
        warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return 0

    def get_modelmeta(self) -> Dict[str, str]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_options(self) -> SessionOptions:
        """
        Return the session options in effect.

        Returns:
            SessionOptions: the options this session was created with
        """
        return self.options

    def get_providers(self) -> List[str]:
        """
        Stub for onnxruntime compatibility.

        Returns:
            List[str]: always ["CPUExecutionProvider"]
        """
        warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
        return ["CPUExecutionProvider"]

    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            Dict[str, Dict[str, str]]: empty dict
        """
        warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_config(self) -> Dict[str, str]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_state(self) -> Dict[str, str]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def set_session_config(self, config: Dict[str, str]) -> None:
        """
        Stub for onnxruntime compatibility; does nothing.

        Args:
            config: session configuration dict (ignored)
        """
        warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def get_memory_info(self) -> Dict[str, int]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            Dict[str, int]: empty dict
        """
        warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def set_memory_pattern(self, enable: bool) -> None:
        """
        Stub for onnxruntime compatibility; does nothing.

        Args:
            enable: ignored
        """
        warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def disable_memory_pattern(self) -> None:
        """
        Stub for onnxruntime compatibility; does nothing.
        """
        warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def get_optimization_level(self) -> int:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            int: always 0
        """
        warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return 0

    def set_optimization_level(self, level: int) -> None:
        """
        Stub for onnxruntime compatibility; does nothing.

        Args:
            level: ignored
        """
        warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)

    def get_model_metadata(self) -> Dict[str, str]:
        """
        Stub for onnxruntime compatibility (distinct from get_modelmeta).

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return {}

    def get_model_path(self) -> str:
        """
        Return the path of the loaded model file.

        Returns:
            str: model file path
        """
        return self.model_path

    def get_input_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return []

    def get_output_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for onnxruntime compatibility; does nothing.

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
        return []
|