happyme531 committed
Commit f3a1217 · verified · 1 Parent(s): 0b3bd77

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ transformer.rknn filter=lfs diff=lfs merge=lfs -text
+ vae_decoder.rknn filter=lfs diff=lfs merge=lfs -text
convert_rknn.py ADDED
@@ -0,0 +1,115 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import datetime
+ import argparse
+ from rknn.api import RKNN
+ from sys import exit
+
+ AUDIO_LENGTH = 645  # audio latent length; 645 corresponds to 10 seconds
+ TEXT_LENGTH = 64  # text length (tokens)
+
+ # Model configuration
+ MODELS = {
+     'transformer': 'transformer.onnx',
+     'vae_decoder': 'vae_decoder.onnx',
+ }
+
+ SHAPES = {
+     'transformer': [
+         [
+             [1, AUDIO_LENGTH, 64],   # hidden_states
+             [1,],                    # timestep
+             [2, 1024],               # pooled_text
+             [2, TEXT_LENGTH, 1024],  # encoder_hidden_states
+             [1, TEXT_LENGTH, 3],     # txt_ids
+             [1, AUDIO_LENGTH, 3],    # img_ids
+         ],
+     ],
+     'vae_decoder': [
+         [
+             [1, 64, AUDIO_LENGTH],
+         ],
+     ],
+ }
+
+ QUANTIZE = False
+ detailed_performance_log = True
+
+ def convert_model(model_type):
+     """Convert the given model type to RKNN format."""
+     if model_type not in MODELS:
+         print(f"Error: unsupported model type {model_type}")
+         return False
+
+     onnx_model = MODELS[model_type]
+     rknn_model = onnx_model.replace(".onnx", ".rknn")
+
+     timedate_iso = datetime.datetime.now().isoformat()
+
+     rknn = RKNN(verbose=True)
+     rknn.config(
+         quantized_dtype='w8a8',
+         quantized_algorithm='normal',
+         quantized_method='channel',
+         quantized_hybrid_level=0,
+         target_platform='rk3588',
+         quant_img_RGB2BGR=False,
+         float_dtype='float16',
+         optimization_level=3,
+         custom_string=f"converted at {timedate_iso}",
+         remove_weight=False,
+         compress_weight=False,
+         inputs_yuv_fmt=None,
+         single_core_mode=False,
+         dynamic_input=SHAPES[model_type],
+         model_pruning=False,
+         op_target=None,
+         quantize_weight=False,
+         remove_reshape=False,
+         sparse_infer=False,
+         enable_flash_attention=False,
+         # disable_rules=['convert_gemm_by_exmatmul']
+     )
+
+     print(f"Converting the {model_type} model...")
+     ret = rknn.load_onnx(model=onnx_model)
+     if ret != 0:
+         print("Failed to load the ONNX model")
+         return False
+
+     ret = rknn.build(do_quantization=False, rknn_batch_size=None)
+     if ret != 0:
+         print("Failed to build the RKNN model")
+         return False
+
+     ret = rknn.export_rknn(rknn_model)
+     if ret != 0:
+         print("Failed to export the RKNN model")
+         return False
+
+     print(f"Successfully converted model: {rknn_model}")
+     return True
+
+ def main():
+     parser = argparse.ArgumentParser(description='Convert ONNX models to RKNN format')
+     parser.add_argument('model_type', nargs='?', default='all',
+                         choices=['all', 'transformer', 'vae_decoder'],
+                         help='Model type to convert (default: all)')
+
+     args = parser.parse_args()
+
+     if args.model_type == 'all':
+         # Convert every model
+         for model_type in MODELS.keys():
+             if not convert_model(model_type):
+                 print(f"Failed to convert {model_type}")
+     else:
+         # Convert only the requested model
+         if not convert_model(args.model_type):
+             print(f"Failed to convert {args.model_type}")
+
+ if __name__ == '__main__':
+     main()
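For reference, the converter can also be driven from Python instead of the CLI. A minimal sketch, assuming transformer.onnx and vae_decoder.onnx sit in the working directory and rknn-toolkit2 is installed:

from convert_rknn import convert_model

# Writes transformer.rknn / vae_decoder.rknn next to the ONNX files.
for name in ('transformer', 'vae_decoder'):
    if not convert_model(name):
        print(f"conversion failed for {name}")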
duration_embedder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b2bd04d4bbd075e7c663711e55b6d09c68bbd35a772587ae46d8339599e03e3
+ size 1061046
export_onnx.py ADDED
@@ -0,0 +1,234 @@
+ import torch
+ import torch.nn as nn
+ from diffusers import AutoencoderOobleck
+ from diffusers import FluxTransformer2DModel
+ from tangoflux import TangoFluxInference
+ from tangoflux.model import DurationEmbedder, TangoFlux
+
+ def export_vae_encoder(vae, save_path, batch_size=1, audio_length=441000):
+     """Export the VAE encoder to ONNX format.
+
+     Args:
+         vae: AutoencoderOobleck instance
+         save_path: output path
+         batch_size: batch size
+         audio_length: audio length in samples (default 10 s at a 44100 Hz sample rate)
+     """
+     vae.eval()
+
+     # Dummy input - note that this is stereo audio
+     dummy_input = torch.randn(batch_size, 2, audio_length)
+
+     # Wrapper module to adapt the forward call
+     class VAEEncoderWrapper(nn.Module):
+         def __init__(self, vae):
+             super().__init__()
+             self.vae = vae
+
+         def forward(self, audio):
+             return self.vae.encode(audio).latent_dist.sample()
+
+     wrapper = VAEEncoderWrapper(vae)
+
+     # Export the encoder part
+     torch.onnx.export(
+         wrapper,
+         dummy_input,
+         save_path,
+         input_names=['audio'],
+         output_names=['latent'],
+         dynamic_axes={
+             'audio': {0: 'batch_size', 2: 'audio_length'},
+             'latent': {0: 'batch_size', 2: 'latent_length'}
+         },
+         opset_version=17
+     )
+
+ def export_vae_decoder(vae, save_path, batch_size=1, latent_length=645):
+     """Export the VAE decoder to ONNX format.
+
+     Args:
+         vae: AutoencoderOobleck instance
+         save_path: output path
+         batch_size: batch size
+         latent_length: latent sequence length
+     """
+     vae.eval()
+
+     # Dummy input
+     dummy_input = torch.randn(batch_size, 64, latent_length)
+
+     # Wrapper module to adapt the forward call
+     class VAEDecoderWrapper(nn.Module):
+         def __init__(self, vae):
+             super().__init__()
+             self.vae = vae
+
+         def forward(self, latent):
+             return self.vae.decode(latent).sample
+
+     wrapper = VAEDecoderWrapper(vae)
+
+     # Export the decoder part
+     torch.onnx.export(
+         wrapper,
+         dummy_input,
+         save_path,
+         input_names=['latent'],
+         output_names=['audio'],
+         dynamic_axes={
+             'latent': {0: 'batch_size', 2: 'latent_length'},
+             'audio': {0: 'batch_size', 2: 'audio_length'}
+         },
+         opset_version=17
+     )
+
+ def export_duration_embedder(duration_embedder, save_path, batch_size=1):
+     """Export the duration embedder to ONNX format.
+
+     Args:
+         duration_embedder: DurationEmbedder instance
+         save_path: output path
+         batch_size: batch size
+     """
+     duration_embedder.eval()
+
+     # Dummy input - note that this is a scalar value
+     dummy_input = torch.tensor([[10.0]], dtype=torch.float32)  # 10 seconds
+
+     # Export
+     torch.onnx.export(
+         duration_embedder,
+         dummy_input,
+         save_path,
+         input_names=['duration'],
+         output_names=['embedding'],
+         dynamic_axes={
+             'duration': {0: 'batch_size'},
+             'embedding': {0: 'batch_size'}
+         },
+         opset_version=17
+     )
+
+ def export_flux_transformer(transformer, save_path, batch_size=1, seq_length=645):
+     """Export the FluxTransformer2D to ONNX format.
+
+     Args:
+         transformer: FluxTransformer2DModel instance
+         save_path: output path
+         batch_size: batch size
+         seq_length: sequence length
+     """
+     transformer.eval()
+
+     # Dummy inputs - mind the shape of every input
+     hidden_states = torch.randn(batch_size, seq_length, 64)  # [B, S, C]
+     timestep = torch.tensor([0.5])  # [1]
+     pooled_text = torch.randn(batch_size, 1024)  # [B, D]
+     encoder_hidden_states = torch.randn(batch_size, 64, 1024)  # [B, L, D]
+     txt_ids = torch.zeros(batch_size, 64, 3).to(torch.int64)  # [B, L, 3]
+     img_ids = torch.arange(seq_length).unsqueeze(0).unsqueeze(-1).repeat(batch_size, 1, 3).to(torch.int64)  # [B, S, 3]
+
+     # Wrapper module to adapt the forward call
+     class TransformerWrapper(nn.Module):
+         def __init__(self, transformer):
+             super().__init__()
+             self.transformer = transformer
+
+         def forward(self, hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids):
+             return self.transformer(
+                 hidden_states=hidden_states,
+                 timestep=timestep,
+                 guidance=None,
+                 pooled_projections=pooled_text,
+                 encoder_hidden_states=encoder_hidden_states,
+                 txt_ids=txt_ids,
+                 img_ids=img_ids,
+                 return_dict=False
+             )[0]
+
+     wrapper = TransformerWrapper(transformer)
+
+     # Export
+     torch.onnx.export(
+         wrapper,
+         (hidden_states, timestep, pooled_text, encoder_hidden_states, txt_ids, img_ids),
+         save_path,
+         input_names=['hidden_states', 'timestep', 'pooled_text', 'encoder_hidden_states', 'txt_ids', 'img_ids'],
+         output_names=['output'],
+         dynamic_axes={
+             'hidden_states': {0: 'batch_size', 1: 'sequence_length'},
+             'pooled_text': {0: 'batch_size'},
+             'encoder_hidden_states': {0: 'batch_size', 1: 'text_length'},
+             'txt_ids': {0: 'batch_size', 1: 'text_length'},
+             'img_ids': {0: 'batch_size', 1: 'sequence_length'}
+         },
+         opset_version=17
+     )
+
+ def export_proj_layer(proj_layer, save_path, batch_size=1):
+     """Export the projection layer to ONNX format.
+
+     Args:
+         proj_layer: projection (fc) layer instance
+         save_path: output path
+         batch_size: batch size
+     """
+     proj_layer.eval()
+
+     # Dummy input - uses the T5 hidden size
+     dummy_input = torch.randn(batch_size, 1024)  # T5-large hidden size
+
+     # Export
+     torch.onnx.export(
+         proj_layer,
+         dummy_input,
+         save_path,
+         input_names=['text_embedding'],
+         output_names=['projected'],
+         dynamic_axes={
+             'text_embedding': {0: 'batch_size'},
+             'projected': {0: 'batch_size'}
+         },
+         opset_version=17
+     )
+
+ def export_all(model_path, output_dir):
+     """Export every component to ONNX format.
+
+     Args:
+         model_path: TangoFlux model path
+         output_dir: output directory
+     """
+     import os
+
+     # Load the model
+     model = TangoFluxInference(name=model_path, device="cpu")
+
+     # Create the output directory
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Export the VAE
+     export_vae_encoder(model.vae, f"{output_dir}/vae_encoder.onnx")
+     export_vae_decoder(model.vae, f"{output_dir}/vae_decoder.onnx")
+
+     # Export the duration embedder ("duration_emebdder" is the attribute name as spelled upstream in tangoflux)
+     export_duration_embedder(model.model.duration_emebdder, f"{output_dir}/duration_embedder.onnx")
+
+     # Export the transformer
+     export_flux_transformer(model.model.transformer, f"{output_dir}/transformer.onnx")
+
+     # Export the projection layer
+     export_proj_layer(model.model.fc, f"{output_dir}/proj.onnx")
+
+     print(f"All models exported to: {output_dir}")
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Export TangoFlux models to ONNX format")
+     parser.add_argument("--model_path", type=str, required=True, help="TangoFlux model path")
+     parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
+
+     args = parser.parse_args()
+     export_all(args.model_path, args.output_dir)
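A typical export invocation, sketched under the assumption that the tangoflux package is installed; "declare-lab/TangoFlux" below stands in for whatever checkpoint name or local path you actually use:

from export_onnx import export_all

# Checkpoint name is illustrative; substitute your own TangoFlux path.
export_all("declare-lab/TangoFlux", "./onnx_models")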
inference.py ADDED
@@ -0,0 +1,365 @@
+ import numpy as np
+ # import onnxruntime as ort
+ import ztu_somemodelruntime_rknnlite2 as ort
+ import sentencepiece as spm
+ import soundfile as sf
+
+ ort.set_default_logger_verbosity(0)
+
+ def load_onnx_model(model_path):
+     """Load an ONNX model."""
+     return ort.InferenceSession(
+         model_path,
+         providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+     )
+
+ class SimpleT5Tokenizer:
+     def __init__(self, model_path, max_length=128):
+         """Initialize the tokenizer.
+
+         Args:
+             model_path: sentencepiece model path
+             max_length: maximum sequence length, default 128
+         """
+         self.sp = spm.SentencePieceProcessor()
+         self.sp.Load(model_path)
+
+         # IDs of the special T5 tokens
+         self.pad_token_id = 0
+         self.eos_token_id = 1
+         self.max_length = max_length
+
+     def __call__(self, texts, padding=True, truncation=True, max_length=None, return_tensors="np"):
+         """Tokenize a text or list of texts.
+
+         Args:
+             texts: a text or a list of texts
+             padding: whether to pad
+             truncation: whether to truncate
+             max_length: optional override for the default max_length
+             return_tensors: return type (only "np" is supported)
+
+         Returns:
+             dict: contains input_ids and attention_mask
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+
+         max_len = max_length if max_length is not None else self.max_length
+
+         # Tokenize and convert to IDs
+         input_ids = []
+         attention_mask = []
+         for text in texts:
+             ids = self.sp.EncodeAsIds(text)
+
+             # Truncate (reserving room for the EOS token)
+             if truncation and len(ids) > max_len - 1:
+                 ids = ids[:max_len-1]
+             ids.append(self.eos_token_id)
+
+             # Build the attention mask
+             mask = [1] * len(ids)
+
+             # Pad
+             if padding:
+                 pad_length = max_len - len(ids)
+                 ids.extend([self.pad_token_id] * pad_length)
+                 mask.extend([0] * pad_length)
+
+             input_ids.append(ids)
+             attention_mask.append(mask)
+
+         # Convert to numpy arrays
+         input_ids = np.array(input_ids, dtype=np.int64)
+         attention_mask = np.array(attention_mask, dtype=np.int64)
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask
+         }
+
+ def encode_text(prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=None):
+     """Encode text, handling the conditional and unconditional prompts together.
+
+     Args:
+         prompt: text prompt
+         negative_prompt: negative text prompt
+         tokenizer: T5 tokenizer
+         text_encoder_onnx: T5 ONNX model
+         guidance_scale: guidance scale
+     """
+     if not isinstance(prompt, list):
+         prompt = [prompt]
+
+     if guidance_scale is not None and guidance_scale > 1.0:
+         # Encode the conditional and unconditional prompts in one batch
+         all_prompts = [negative_prompt] + prompt
+         batch = tokenizer(
+             all_prompts,
+             padding=True,
+             truncation=True,
+             return_tensors="np"
+         )
+
+         # ONNX inference
+         all_hidden_states = text_encoder_onnx.run(
+             ['last_hidden_state'],
+             {
+                 'input_ids': batch['input_ids'].astype(np.int64),
+                 'attention_mask': batch['attention_mask'].astype(np.int64)
+             }
+         )[0]
+
+         # Split the unconditional and conditional results
+         uncond_hidden_states = all_hidden_states[0:1]
+         cond_hidden_states = all_hidden_states[1:]
+         uncond_mask = batch['attention_mask'][0:1]
+         cond_mask = batch['attention_mask'][1:]
+
+         return (uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask)
+     else:
+         # Encode the conditional prompt only
+         batch = tokenizer(
+             prompt,
+             padding=True,
+             truncation=True,
+             return_tensors="np"
+         )
+
+         # ONNX inference
+         hidden_states = text_encoder_onnx.run(
+             ['last_hidden_state'],
+             {
+                 'input_ids': batch['input_ids'].astype(np.int64),
+                 'attention_mask': batch['attention_mask'].astype(np.int64)
+             }
+         )[0]
+
+         return hidden_states, batch['attention_mask']
+
+ def retrieve_timesteps(scheduler, num_inference_steps, device, timesteps=None, sigmas=None):
+     """Fetch the timesteps."""
+     if sigmas is not None:
+         scheduler.set_timesteps(sigmas=sigmas)
+         timesteps = scheduler.timesteps
+         num_inference_steps = len(timesteps)
+     else:
+         scheduler.set_timesteps(num_inference_steps)
+         timesteps = scheduler.timesteps
+     return timesteps, num_inference_steps
+
+ # A simple flow-matching scheduler class
+ class SimpleFlowMatchScheduler:
+     def __init__(self, num_train_timesteps=1000, shift=1.0):
+         """Initialize the scheduler.
+
+         Args:
+             num_train_timesteps: number of training timesteps
+             shift: timestep shift
+         """
+         # Generate linear timesteps
+         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+
+         # Compute sigmas
+         sigmas = timesteps / num_train_timesteps
+         sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+         # Append the terminal sigma
+         self.sigmas = np.append(sigmas, 0.0)
+         self.timesteps = sigmas * num_train_timesteps
+         self.step_index = None
+
+     def set_timesteps(self, num_inference_steps):
+         """Set the timesteps used at inference time.
+
+         Args:
+             num_inference_steps: number of inference steps
+         """
+         timesteps = np.linspace(1, len(self.timesteps), num_inference_steps, dtype=np.float32)[::-1].copy()
+         sigmas = timesteps / len(self.timesteps)
+         self.sigmas = np.append(sigmas, 0.0)
+         self.timesteps = sigmas * len(self.timesteps)
+         self.step_index = 0
+
+     def step(self, model_output, timestep, sample):
+         """Perform one Euler update.
+
+         Args:
+             model_output: model output
+             timestep: current timestep
+             sample: current sample
+
+         Returns:
+             prev_sample: the updated sample
+         """
+         sigma = self.sigmas[self.step_index]
+         sigma_next = self.sigmas[self.step_index + 1]
+
+         # Euler update
+         prev_sample = sample + (sigma_next - sigma) * model_output
+
+         self.step_index += 1
+         return prev_sample
+
+ def generate_audio_onnx(
+     prompt="",
+     negative_prompt="",
+     duration=10,
+     steps=50,
+     guidance_scale=4.5,
+     onnx_dir="./onnx_models",
+     output_path="output_onnx.wav",
+     seed=None
+ ):
+     """Generate audio from a text prompt using the exported ONNX/RKNN models."""
+     if seed is not None:
+         np.random.seed(seed)
+
+     # Load the tokenizer and ONNX models, using a fixed sequence length
+     tokenizer = SimpleT5Tokenizer(f"{onnx_dir}/spiece.model", max_length=63)
+     text_encoder_onnx = load_onnx_model(f"{onnx_dir}/text_encoder_nf4.onnx")
+
+     # Load the remaining ONNX models
+     vae_decoder = load_onnx_model(f"{onnx_dir}/vae_decoder.onnx")
+     duration_embedder = load_onnx_model(f"{onnx_dir}/duration_embedder.onnx")
+     transformer = load_onnx_model(f"{onnx_dir}/transformer.onnx")
+     proj_layer = load_onnx_model(f"{onnx_dir}/proj.onnx")
+
+     # 1. Duration embedding
+     duration_input = np.array([[duration]], dtype=np.float32)
+     print(f"[Shape] duration input: {duration_input.shape}")
+
+     duration_hidden_states = duration_embedder.run(
+         ['embedding'],
+         {'duration': duration_input}
+     )[0]
+     print(f"[Shape] duration embedding: {duration_hidden_states.shape}")
+
+     if guidance_scale > 1.0:
+         duration_hidden_states = np.concatenate([duration_hidden_states] * 2, axis=0)
+         print(f"[Shape] duplicated duration embedding: {duration_hidden_states.shape}")
+
+     # 2. Text encoder
+     if guidance_scale > 1.0:
+         (uncond_hidden_states, uncond_mask), (cond_hidden_states, cond_mask) = encode_text(
+             prompt, negative_prompt, tokenizer, text_encoder_onnx, guidance_scale=guidance_scale
+         )
+         print(cond_hidden_states)
+         encoder_hidden_states = np.concatenate([uncond_hidden_states, cond_hidden_states])
+         attention_mask = np.concatenate([uncond_mask, cond_mask])
+     else:
+         # Note: the argument order must match encode_text(prompt, negative_prompt, tokenizer, ...)
+         encoder_hidden_states, attention_mask = encode_text(
+             prompt, negative_prompt, tokenizer, text_encoder_onnx
+         )
+
+     # 3. pooled_text
+     boolean_encoder_mask = (attention_mask == 1)
+     mask_expanded = boolean_encoder_mask[..., None].repeat(encoder_hidden_states.shape[-1], axis=-1)
+     masked_data = np.where(mask_expanded, encoder_hidden_states, np.nan)
+     pooled = np.nanmean(masked_data, axis=1)
+
+     # Run through the projection layer
+     pooled_text = proj_layer.run(
+         ['projected'],
+         {'text_embedding': pooled.astype(np.float32)}
+     )[0]
+
+     # 4. Concatenate the duration and text features
+     encoder_hidden_states = np.concatenate(
+         [encoder_hidden_states, duration_hidden_states],
+         axis=1
+     )
+
+     # 5. Build the remaining inputs
+     txt_ids = np.zeros((1, encoder_hidden_states.shape[1], 3), dtype=np.int64)
+     img_ids = np.tile(
+         np.arange(645, dtype=np.int64)[None, :, None],
+         (1, 1, 3)
+     )
+
+     # 6. Scheduler
+     scheduler = SimpleFlowMatchScheduler(num_train_timesteps=1000)
+     scheduler.set_timesteps(steps)
+
+     # Initialize the latents
+     latents = np.random.randn(1, 645, 64).astype(np.float32)
+
+     # 7. Generation loop
+     for i in range(steps):
+         # Transformer forward pass
+         noise_pred = transformer.run(
+             ['output'],
+             {
+                 'hidden_states': latents,
+                 'timestep': np.array([scheduler.timesteps[i]/1000], dtype=np.float32),
+                 'pooled_text': pooled_text,
+                 'encoder_hidden_states': encoder_hidden_states,
+                 'txt_ids': txt_ids,
+                 'img_ids': img_ids
+             }
+         )[0]
+
+         if i == 0:  # print only on the first step
+             print(f"[Shape] noise prediction output: {noise_pred.shape}")
+
+         # Apply classifier-free guidance
+         if guidance_scale > 1.0:
+             noise_pred_uncond, noise_pred_text = noise_pred[0:1], noise_pred[1:2]
+             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+         # Update the latents through the scheduler
+         latents = scheduler.step(noise_pred, scheduler.timesteps[i], latents)
+
+         if i % 10 == 0:
+             print(f"Generation progress: {i}/{steps}")
+
+     # 8. Pre-decode processing
+     latents = latents / scheduler.sigmas[0]
+     latents = np.transpose(latents, (0, 2, 1))
+
+     # 9. VAE decode
+     wave = vae_decoder.run(['audio'], {'latent': latents})[0]
+
+     # 10. Trim
+     sample_rate = 44100
+     waveform_end = int(duration * sample_rate)
+     wave = wave[:, :, :waveform_end]
+     print(f"[Shape] final trimmed waveform: {wave.shape}")
+
+     # 11. Save the audio
+     wave = wave[0]  # drop the batch dimension
+     sf.write(output_path, wave.T, sample_rate)  # soundfile expects (samples, channels)
+
+     return wave
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Test ONNX model inference")
+     parser.add_argument("--prompt", type=str, default="What does the fox say?", help="Text prompt")
+     parser.add_argument("--negative_prompt", type=str, default="", help="Negative text prompt")
+     parser.add_argument("--onnx_dir", type=str, default=".", help="ONNX model directory")
+     parser.add_argument("--duration", type=float, default=10.0, help="Duration of the generated audio (seconds)")
+     parser.add_argument("--steps", type=int, default=30, help="Number of inference steps")
+     parser.add_argument("--guidance_scale", type=float, default=4.5, help="Guidance scale")
+     parser.add_argument("--output", type=str, default="output_onnx.wav", help="Output audio path")
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+     args = parser.parse_args()
+
+     # Generate the audio
+     wave = generate_audio_onnx(
+         # prompt="What does the fox say?",
+         # prompt="Never gonna give you up, never gonna let you down",
+         # prompt="Electronic music, future house style",
+         prompt=args.prompt,
+         negative_prompt=args.negative_prompt,
+         duration=args.duration,
+         steps=args.steps,
+         guidance_scale=args.guidance_scale,
+         onnx_dir=args.onnx_dir,
+         output_path=args.output,
+         seed=args.seed
+     )
+
+     print(f"Generated audio shape: {wave.shape}")
+     print(f"Audio saved to: {args.output}")
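The pipeline can also be called programmatically. A sketch, assuming the exported models, the converted .rknn files, and spiece.model all live in the current directory:

from inference import generate_audio_onnx

wave = generate_audio_onnx(
    prompt="What does the fox say?",
    duration=10,
    steps=30,
    guidance_scale=4.5,
    onnx_dir=".",
    output_path="output_onnx.wav",
    seed=42,
)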
proj.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af92c1c262e6a217a75ef59c922304eb90770f9a67a6253c9c477fbe3fa9eba8
+ size 4198734
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+ size 791656
text_encoder_bnb4.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e6334ec5d0eeaba54449c24ec0784c3e224db6c483a968c0fca055e001b80e39
+ size 305592280
transformer.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c9079229578ab2b683c271f9f585fddabcfeb588191d9c02c597f0aa4b6a383b
+ size 2068637351
transformer.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dffd321f320c949a0cffc9b3bf92b371fccaebb5f25826710c8b89d84184d2c7
+ size 1118028281
vae_decoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25cdc6d16f896906df2cea9374b2746842ef808563e6daeb7b48b2eb6360a4a2
+ size 312595968
vae_decoder.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c9782440915ffe1698717576b7889fc7d387a9c62b4874ff83337b8473b5049
+ size 352599027
vae_encoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:846b8fd27f2f6309954fb1420066b56123bfe34c3adcc100809c578011179980
+ size 312074746
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,535 @@
+ # Module-level constants and functions
+ from rknnlite.api import RKNNLite
+ import numpy as np
+ import os
+ import warnings
+ import logging
+ from typing import List, Dict, Union, Optional
+
+ try:
+     import onnxruntime as ort
+     HAS_ORT = True
+ except ImportError:
+     HAS_ORT = False
+     warnings.warn("onnxruntime is not installed; only the RKNN backend is available", ImportWarning)
+
+ # Configure logging
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
+ logger.setLevel(logging.ERROR)  # only log errors by default
+ if not logger.handlers:
+     handler = logging.StreamHandler()
+     handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logger.addHandler(handler)
+
+ # Mapping from ONNX Runtime log levels to Python logging levels
+ _LOGGING_LEVEL_MAP = {
+     0: logging.DEBUG,     # Verbose
+     1: logging.INFO,      # Info
+     2: logging.WARNING,   # Warning
+     3: logging.ERROR,     # Error
+     4: logging.CRITICAL   # Fatal
+ }
+
+ def set_default_logger_severity(level: int) -> None:
+     """
+     Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
+
+     Args:
+         level: log level (0-4)
+     """
+     if level not in _LOGGING_LEVEL_MAP:
+         raise ValueError(f"Invalid log level: {level}, expected an integer between 0 and 4")
+     logger.setLevel(_LOGGING_LEVEL_MAP[level])
+
+ def set_default_logger_verbosity(level: int) -> None:
+     """
+     Sets the default logging verbosity level. To activate the verbose log,
+     you need to set the default logging severity to 0:Verbose level.
+
+     Args:
+         level: log level (0-4)
+     """
+     set_default_logger_severity(level)
+
+ # NPU core mode constants
+ NPU_CORE_AUTO = 0      # automatic selection
+ NPU_CORE_0 = 1         # use core 0
+ NPU_CORE_1 = 2         # use core 1
+ NPU_CORE_2 = 4         # use core 2
+ NPU_CORE_0_1 = 3       # use cores 0 and 1
+ NPU_CORE_0_1_2 = 7     # use all three cores
+ NPU_CORE_ALL = 0xffff  # use all cores
+
+ # Mapping from RKNN tensor types to numpy dtypes
+ RKNN_DTYPE_MAP = {
+     0: np.float32,  # RKNN_TENSOR_FLOAT32
+     1: np.float16,  # RKNN_TENSOR_FLOAT16
+     2: np.int8,     # RKNN_TENSOR_INT8
+     3: np.uint8,    # RKNN_TENSOR_UINT8
+     4: np.int16,    # RKNN_TENSOR_INT16
+     5: np.uint16,   # RKNN_TENSOR_UINT16
+     6: np.int32,    # RKNN_TENSOR_INT32
+     7: np.uint32,   # RKNN_TENSOR_UINT32
+     8: np.int64,    # RKNN_TENSOR_INT64
+     9: bool,        # RKNN_TENSOR_BOOL
+     10: np.int8,    # RKNN_TENSOR_INT4 (represented as int8)
+ }
+
+ def get_available_providers() -> List[str]:
+     """
+     Get the list of available execution providers (placeholder kept for API compatibility)
+
+     Returns:
+         list: available providers, always ["CPUExecutionProvider"]
+     """
+     return ["CPUExecutionProvider"]
+
+ def get_version_info() -> Dict[str, str]:
+     """
+     Get version information
+
+     Returns:
+         dict: API and driver version information
+     """
+     runtime = RKNNLite()
+     version = runtime.get_sdk_version()
+     return {
+         "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
+         "driver_version": version.split('\n')[3].split(': ')[1]
+     }
+
+ class IOTensor:
+     """Wrapper describing an input/output tensor"""
+     def __init__(self, name, shape, type=None):
+         self.name = name.decode() if isinstance(name, bytes) else name
+         self.shape = shape
+         self.type = type
+
+     def __str__(self):
+         return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
+
+ class SessionOptions:
+     """Session options"""
+     def __init__(self):
+         self.async_mode = False  # whether to run in async mode
+         self.core_mask = 0       # NPU core selection
+         self.perf_debug = False  # whether to enable performance profiling
+
+ class InferenceSession:
+     """
+     RKNNLite runtime wrapper with an ONNX Runtime-like API
+     """
+
+     def __new__(cls, model_path: str, verbose: bool = False, sess_options: Optional[SessionOptions] = None, fallback: bool = True, **kwargs):
+         """
+         Create a runtime instance
+
+         Args:
+             model_path: model file path (.rknn or .onnx)
+             verbose: whether to print detailed logs
+             sess_options: session options
+             fallback: whether to automatically load a same-named .rknn file
+             **kwargs: extra initialization arguments
+         """
+         # Only enable verbose logging when verbose=True
+         if verbose:
+             set_default_logger_severity(0)
+
+         if not os.path.exists(model_path):
+             logger.error(f"Model file does not exist: {model_path}")
+             raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+         # Check whether this is an ONNX file
+         is_onnx = model_path.lower().endswith('.onnx')
+
+         if is_onnx and fallback:
+             # Look for a matching RKNN file
+             rknn_path = os.path.splitext(model_path)[0] + '.rknn'
+             if os.path.exists(rknn_path):
+                 logger.info(f"Found a matching RKNN model, using it instead: {rknn_path}")
+                 # Create an RKNN runtime instance
+                 instance = super().__new__(cls)
+                 instance.model_path = rknn_path
+                 return instance
+
+         if is_onnx:
+             # Use ONNX Runtime
+             logger.info(f"Loading model with ONNX Runtime: {model_path}")
+             if not HAS_ORT:
+                 raise RuntimeError("onnxruntime is not installed; cannot load ONNX models")
+             return ort.InferenceSession(model_path, sess_options=sess_options, **kwargs)
+
+         # Create an RKNN runtime instance
+         instance = super().__new__(cls)
+         instance.model_path = model_path
+         return instance
+
+     def __init__(self, model_path: str, verbose: bool = False, sess_options: Optional[SessionOptions] = None, fallback: bool = True, **kwargs):
+         """
+         Initialize the RKNN runtime
+
+         Args:
+             model_path: model file path (.rknn or .onnx)
+             verbose: whether to print detailed logs
+             sess_options: session options
+             fallback: whether to automatically load a same-named .rknn file
+             **kwargs: extra initialization arguments
+         """
+         # For ONNX models __init__ is not called on this class
+         if not hasattr(self, 'model_path'):  # ONNX Runtime instance
+             return
+
+         self.runtime = RKNNLite(verbose=verbose)
+
+         # Load the model
+         logger.debug(f"Loading model: {self.model_path}")
+         ret = self.runtime.load_rknn(self.model_path)
+         if ret != 0:
+             logger.error(f"Failed to load RKNN model: {self.model_path}")
+             raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
+         logger.debug("Model loaded successfully")
+
+         # Apply the session options
+         options = sess_options or SessionOptions()
+
+         # Initialize the runtime
+         logger.debug("Initializing the runtime environment")
+         ret = self.runtime.init_runtime(
+             async_mode=options.async_mode,
+             core_mask=options.core_mask
+         )
+         if ret != 0:
+             logger.error("Failed to initialize the runtime environment")
+             raise RuntimeError('Failed to initialize the runtime environment')
+         logger.debug("Runtime environment initialized successfully")
+
+         # Collect input/output info
+         self._init_io_info()
+
+         # Keep the options
+         self.options = options
+
+     def get_performance_info(self) -> Dict[str, float]:
+         """
+         Get performance information
+
+         Returns:
+             dict: performance information
+         """
+         if not self.options.perf_debug:
+             raise RuntimeError("Profiling is not enabled; set perf_debug=True in SessionOptions")
+
+         perf = self.runtime.rknn_runtime.get_run_perf()
+         return {
+             "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
+         }
+
+     def set_core_mask(self, core_mask: int) -> None:
+         """
+         Set the NPU core usage mode
+
+         Args:
+             core_mask: NPU core mask, one of the NPU_CORE_* constants
+         """
+         ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
+         if ret != 0:
+             raise RuntimeError("Failed to set the NPU core mode")
+
+     def _convert_nhwc_to_nchw(self, shape):
+         """Convert an NHWC shape to NCHW"""
+         if len(shape) == 4:
+             # NHWC -> NCHW
+             n, h, w, c = shape
+             return [n, c, h, w]
+         return shape
+
+     def _init_io_info(self):
+         """Initialize the model's input/output information"""
+         runtime = self.runtime.rknn_runtime
+
+         # Get the number of inputs and outputs
+         n_input, n_output = runtime.get_in_out_num()
+
+         # Collect the input info
+         self.input_tensors = []
+         for i in range(n_input):
+             attr = runtime.get_tensor_attr(i)
+             shape = [attr.dims[j] for j in range(attr.n_dims)]
+             # Convert 4-D inputs from NHWC to NCHW
+             shape = self._convert_nhwc_to_nchw(shape)
+             # Look up the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.input_tensors.append(tensor)
+
+         # Collect the output info
+         self.output_tensors = []
+         for i in range(n_output):
+             attr = runtime.get_tensor_attr(i, is_output=True)
+             shape = runtime.get_output_shape(i)
+             # Look up the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.output_tensors.append(tensor)
+
+     def get_inputs(self):
+         """
+         Get the model input info
+
+         Returns:
+             list: input descriptions
+         """
+         return self.input_tensors
+
+     def get_outputs(self):
+         """
+         Get the model output info
+
+         Returns:
+             list: output descriptions
+         """
+         return self.output_tensors
+
+     def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
+         """
+         Run inference
+
+         Args:
+             output_names: list of output node names to return
+             input_feed: input data as a dict or a list
+             data_format: input data format, "nchw" or "nhwc"
+             **kwargs: extra runtime arguments
+
+         Returns:
+             list: model outputs; only the named ones when output_names is given
+         """
+         if input_feed is None:
+             logger.error("input_feed must not be None")
+             raise ValueError("input_feed must not be None")
+
+         # Prepare the input data
+         if isinstance(input_feed, dict):
+             # For a dict, order the inputs as the model expects
+             inputs = []
+             input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
+             for tensor in self.input_tensors:
+                 if tensor.name not in input_feed:
+                     raise ValueError(f"Missing input: {tensor.name}")
+                 inputs.append(input_feed[tensor.name])
+         elif isinstance(input_feed, (list, tuple)):
+             # For a list, check that the length matches
+             if len(input_feed) != len(self.input_tensors):
+                 raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
+             inputs = list(input_feed)
+         else:
+             logger.error("input_feed must be a dict or a list")
+             raise ValueError("input_feed must be a dict or a list")
+
+         # Run inference
+         try:
+             logger.debug("Starting inference")
+             all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
+
+             # Without output_names, return every output
+             if output_names is None:
+                 return all_outputs
+
+             # Pick the requested outputs
+             output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
+             selected_outputs = []
+             for name in output_names:
+                 if name not in output_map:
+                     raise ValueError(f"Output node not found: {name}")
+                 selected_outputs.append(all_outputs[output_map[name]])
+
+             return selected_outputs
+
+         except Exception as e:
+             logger.error(f"Inference failed: {str(e)}")
+             raise RuntimeError(f"Inference failed: {str(e)}")
+
+     def close(self):
+         """
+         Close the session and release resources
+         """
+         if self.runtime is not None:
+             logger.info("Releasing runtime resources")
+             self.runtime.release()
+             self.runtime = None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def end_profiling(self) -> Optional[str]:
+         """
+         Stub for ending profiling
+
+         Returns:
+             Optional[str]: None
+         """
+         warnings.warn("end_profiling() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return None
+
+     def get_profiling_start_time_ns(self) -> int:
+         """
+         Stub for getting the profiling start time
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_profiling_start_time_ns() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def get_modelmeta(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_modelmeta() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_options(self) -> SessionOptions:
+         """
+         Get the session options
+
+         Returns:
+             SessionOptions: the current session options
+         """
+         return self.options
+
+     def get_providers(self) -> List[str]:
+         """
+         Stub for getting the providers in use
+
+         Returns:
+             List[str]: ["CPUExecutionProvider"]
+         """
+         warnings.warn("get_providers() is a stub; it always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
+         return ["CPUExecutionProvider"]
+
+     def get_provider_options(self) -> Dict[str, Dict[str, str]]:
+         """
+         Stub for getting provider options
+
+         Returns:
+             Dict[str, Dict[str, str]]: empty dict
+         """
+         warnings.warn("get_provider_options() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_config(self) -> Dict[str, str]:
+         """
+         Stub for getting the session config
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_state(self) -> Dict[str, str]:
+         """
+         Stub for getting the session state
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_state() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_session_config(self, config: Dict[str, str]) -> None:
+         """
+         Stub for setting the session config
+
+         Args:
+             config: session config dict
+         """
+         warnings.warn("set_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_memory_info(self) -> Dict[str, int]:
+         """
+         Stub for getting memory usage info
+
+         Returns:
+             Dict[str, int]: empty dict
+         """
+         warnings.warn("get_memory_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_memory_pattern(self, enable: bool) -> None:
+         """
+         Stub for setting the memory pattern
+
+         Args:
+             enable: whether to enable the memory pattern
+         """
+         warnings.warn("set_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def disable_memory_pattern(self) -> None:
+         """
+         Stub for disabling the memory pattern
+         """
+         warnings.warn("disable_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_optimization_level(self) -> int:
+         """
+         Stub for getting the optimization level
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def set_optimization_level(self, level: int) -> None:
+         """
+         Stub for setting the optimization level
+
+         Args:
+             level: optimization level
+         """
+         warnings.warn("set_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_model_metadata(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata (a different interface from get_modelmeta)
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_model_metadata() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_model_path(self) -> str:
+         """
+         Get the model path
+
+         Returns:
+             str: model file path
+         """
+         return self.model_path
+
+     def get_input_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting input type info
+
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_input_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
+
+     def get_output_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting output type info
+
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_output_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
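Because the wrapper mirrors the onnxruntime API surface, existing ORT code can usually switch backends by changing only the import. A minimal sketch against the vae_decoder model from this repo, using the input/output names the way inference.py does:

import numpy as np
import ztu_somemodelruntime_rknnlite2 as ort

sess = ort.InferenceSession("vae_decoder.rknn")
latent = np.random.randn(1, 64, 645).astype(np.float32)  # [B, C, T], per SHAPES in convert_rknn.py
audio = sess.run(['audio'], {'latent': latent})[0]
print(audio.shape)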