import torch
import torch.nn as nn

from transformers import Qwen2AudioForConditionalGeneration

class Qwen2AudioEncoderWrapper(nn.Module):
    """包装Qwen2Audio的编码器和映射层用于ONNX导出"""
    
    def __init__(self, model):
        super().__init__()
        self.audio_tower = model.audio_tower
        self.projector = model.multi_modal_projector
        
    def forward(self, input_features, feature_attention_mask):
        # Convert mel-frame lengths to post-convolution feature lengths; the
        # audio tower downsamples time by 2x, so the raw mask sum cannot be
        # compared against the downsampled sequence range directly
        audio_feat_lengths, _ = self.audio_tower._get_feat_extract_output_lengths(
            feature_attention_mask.sum(-1)
        )
        batch_size, _, max_mel_seq_len = input_features.shape

        # Sequence length after the convolutional downsampling
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        seq_range = torch.arange(0, max_seq_len, device=input_features.device).unsqueeze(0)
        seq_range = seq_range.expand(batch_size, max_seq_len)

        # Build an additive attention mask: 0 for valid positions, -inf for padding
        lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
        padding_mask = seq_range >= lengths_expand
        audio_attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len)
        audio_attention_mask = audio_attention_mask.expand(batch_size, 1, max_seq_len, max_seq_len)
        audio_attention_mask = audio_attention_mask.float()
        audio_attention_mask = audio_attention_mask.masked_fill(audio_attention_mask.bool(), float("-inf"))

        # Encode the audio features
        audio_outputs = self.audio_tower(input_features, attention_mask=audio_attention_mask)
        audio_features = audio_outputs.last_hidden_state

        # Project into the LLM's text embedding space
        projected_features = self.projector(audio_features)

        return projected_features

def export_qwen2audio_encoder(model, save_path, input_shape=(1, 128, 3000)):
    """
    Export the Qwen2Audio audio encoder to ONNX.

    Args:
        model: a Qwen2AudioForConditionalGeneration model
        save_path: path for the exported ONNX model
        input_shape: shape of the input audio features (batch_size, n_mels, seq_len);
            Qwen2-Audio's feature extractor produces 128 mel bins
    """
    
    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()
    
    # Prepare example inputs; the attention mask is an integer tensor, matching
    # the feature_attention_mask produced by the processor
    batch_size, n_mels, seq_len = input_shape
    dummy_input = torch.randn(input_shape)
    dummy_mask = torch.ones((batch_size, seq_len), dtype=torch.long)
    
    # Configure dynamic axes; the output sequence axis is shorter than the mel
    # axis (the encoder downsamples time), so the two get distinct names
    dynamic_axes = {
        'input_features': {0: 'batch_size', 2: 'mel_seq_len'},
        'feature_attention_mask': {0: 'batch_size', 1: 'mel_seq_len'},
        'output': {0: 'batch_size', 1: 'output_seq_len'}
    }
    
    # Export to ONNX
    torch.onnx.export(
        wrapper,
        (dummy_input, dummy_mask),
        save_path,
        input_names=['input_features', 'feature_attention_mask'],
        output_names=['output'],
        dynamic_axes=dynamic_axes,
        opset_version=17,
        do_constant_folding=True
    )
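
# Optional parity check, as a minimal sketch: run the exported graph with
# onnxruntime and compare it against the PyTorch wrapper on the same random
# input. `verify_onnx_export` is a helper introduced here, not part of the
# original script; it assumes `onnxruntime` and `numpy` are installed and
# that the model sits on CPU in float32.
def verify_onnx_export(model, onnx_path, input_shape=(1, 128, 3000)):
    import numpy as np
    import onnxruntime as ort

    wrapper = Qwen2AudioEncoderWrapper(model)
    wrapper.eval()

    batch_size, _, seq_len = input_shape
    feats = torch.randn(input_shape)
    mask = torch.ones((batch_size, seq_len), dtype=torch.long)

    # Reference output from the PyTorch wrapper
    with torch.no_grad():
        torch_out = wrapper(feats, mask).numpy()

    # Output from the exported ONNX graph
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    (onnx_out,) = session.run(
        None,
        {"input_features": feats.numpy(), "feature_attention_mask": mask.numpy()},
    )

    print("max |torch - onnx| =", np.abs(torch_out - onnx_out).max())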

if __name__ == "__main__":
    # Load the model
    model = Qwen2AudioForConditionalGeneration.from_pretrained("../Qwen2-Audio-7B-Instruct/")
    model.eval()

    # Export the encoder to ONNX
    export_qwen2audio_encoder(
        model,
        "audio_encoder.onnx",
        input_shape=(1, 128, 3000)  # batch_size=1, n_mels=128, seq_len=3000
    )
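
    # Sanity-check the export; see verify_onnx_export above (illustrative helper)
    verify_onnx_export(model, "audio_encoder.onnx", input_shape=(1, 128, 3000))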