# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
from __future__ import print_function
import os
import sys
import copy
import math
import yaml
import logging
from typing import Tuple
import torch
import numpy as np
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
from wenet.bin.export_onnx_cpu import get_args, to_numpy, print_input_output_info
try:
import onnx
import onnxruntime
except ImportError:
print("Please install onnx and onnxruntime!")
sys.exit(1)
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
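# NOTE: This script exports a streaming WeNet Conformer model for the Horizon
# XJ3 BPU: each submodule is re-expressed with a 4-D NCHW dataflow (Conv2d in
# place of Linear / LayerNorm-style ops) so that the exported encoder and CTC
# ONNX graphs can be compiled by the Horizon toolchain (see
# onnx2horizonbin.py::generate_config()). A sketch of a typical invocation,
# assuming the flag spellings follow the attribute names read from get_args()
# below (config, checkpoint, output_dir, chunk_size, num_decoding_left_chunks):
#   python export_onnx_bpu.py --config train.yaml --checkpoint final.pt \
#       --output_dir onnx_bpu --chunk_size 8 --num_decoding_left_chunks 16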
class BPULayerNorm(torch.nn.Module):
"""Refactor torch.nn.LayerNorm to meet 4-D dataflow."""
def __init__(self, module, chunk_size=8, run_on_bpu=False):
super().__init__()
original = copy.deepcopy(module)
self.hidden = module.weight.size(0)
self.chunk_size = chunk_size
self.run_on_bpu = run_on_bpu
if self.run_on_bpu:
self.weight = torch.nn.Parameter(
module.weight.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size)
)
self.bias = torch.nn.Parameter(
module.bias.reshape(1, self.hidden, 1, 1).repeat(1, 1, 1, chunk_size)
)
self.negtive = torch.nn.Parameter(
torch.ones((1, self.hidden, 1, chunk_size)) * -1.0
)
self.eps = torch.nn.Parameter(
torch.zeros((1, self.hidden, 1, chunk_size)) + module.eps
)
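            # NOTE: The two 1x1 convs below realize LayerNorm statistics with
            # BPU-friendly ops. All weights are fixed to 1/hidden, so each conv
            # reduces the channel (hidden) axis to its mean and broadcasts it
            # back to every channel: mean_conv_1 gives E[x], and mean_conv_2,
            # applied to the squared deviations in forward(), gives Var[x].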
self.mean_conv_1 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
self.mean_conv_1.weight = torch.nn.Parameter(
torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)
)
self.mean_conv_2 = torch.nn.Conv2d(self.hidden, 1, 1, bias=False)
self.mean_conv_2.weight = torch.nn.Parameter(
torch.ones(self.hidden, self.hidden, 1, 1) / (1.0 * self.hidden)
)
else:
self.norm = module
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, self.chunk_size, self.hidden)
orig_out = module(random_data)
new_out = self.forward(random_data.transpose(1, 2).unsqueeze(2))
np.testing.assert_allclose(
to_numpy(orig_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.run_on_bpu:
u = self.mean_conv_1(x) # (1, h, 1, c)
numerator = x + u * self.negtive # (1, h, 1, c)
s = torch.pow(numerator, 2) # (1, h, 1, c)
s = self.mean_conv_2(s) # (1, h, 1, c)
denominator = torch.sqrt(s + self.eps) # (1, h, 1, c)
x = torch.div(numerator, denominator) # (1, h, 1, c)
x = x * self.weight + self.bias
else:
x = x.squeeze(2).transpose(1, 2).contiguous()
x = self.norm(x)
x = x.transpose(1, 2).contiguous().unsqueeze(2)
return x
class BPUIdentity(torch.nn.Module):
"""Refactor torch.nn.Identity().
For inserting BPU node whose input == output.
"""
def __init__(self, channels):
super().__init__()
self.channels = channels
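        # A depthwise 1x1 conv with dirac-initialized weights copies every
        # channel unchanged, i.e. it behaves exactly like nn.Identity while
        # still being a Conv2d node the BPU can execute (used to keep ops such
        # as `concat` on the BPU, see BPUConvolution below).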
self.identity_conv = torch.nn.Conv2d(
channels, channels, 1, groups=channels, bias=False
)
torch.nn.init.dirac_(self.identity_conv.weight.data, groups=channels)
self.check_equal()
def check_equal(self):
random_data = torch.randn(1, self.channels, 1, 10)
result = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(random_data), to_numpy(result), rtol=1e-02, atol=1e-03
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Identity with 4-D dataflow, input == output.
Args:
x (torch.Tensor): (batch, in_channel, 1, time)
Returns:
(torch.Tensor): (batch, in_channel, 1, time).
"""
return self.identity_conv(x)
class BPULinear(torch.nn.Module):
"""Refactor torch.nn.Linear or pointwise_conv"""
def __init__(self, module, is_pointwise_conv=False):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.idim = module.weight.size(1)
self.odim = module.weight.size(0)
self.is_pointwise_conv = is_pointwise_conv
# Modify weight & bias
self.linear = torch.nn.Conv2d(self.idim, self.odim, 1, 1)
if is_pointwise_conv:
# (odim, idim, kernel=1) -> (odim, idim, 1, 1)
self.linear.weight = torch.nn.Parameter(module.weight.unsqueeze(-1))
else:
# (odim, idim) -> (odim, idim, 1, 1)
self.linear.weight = torch.nn.Parameter(
module.weight.unsqueeze(2).unsqueeze(3)
)
self.linear.bias = module.bias
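        # Shape sketch of the equivalence checked below: for x of shape
        # (1, idim, 1, T),
        #   self.linear(x) == Linear(idim, odim)(x.squeeze(2).transpose(1, 2))
        #                       .transpose(1, 2).unsqueeze(2)
        # i.e. a 1x1 Conv2d over the channel axis is the original Linear
        # applied independently at every time step.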
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 8, self.idim)
if self.is_pointwise_conv:
random_data = random_data.transpose(1, 2)
original_result = module(random_data)
if self.is_pointwise_conv:
random_data = random_data.transpose(1, 2)
original_result = original_result.transpose(1, 2)
random_data = random_data.transpose(1, 2).unsqueeze(2)
new_result = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(original_result),
to_numpy(new_result.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Linear with 4-D dataflow.
Args:
x (torch.Tensor): (batch, in_channel, 1, time)
Returns:
(torch.Tensor): (batch, out_channel, 1, time).
"""
return self.linear(x)
class BPUGlobalCMVN(torch.nn.Module):
"""Refactor wenet/transformer/cmvn.py::GlobalCMVN"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
self.norm_var = module.norm_var
# NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
self.mean = module.mean.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
self.istd = module.istd.unsqueeze(-1).unsqueeze(0).unsqueeze(0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""CMVN with 4-D dataflow.
Args:
x (torch.Tensor): (batch, 1, mel_dim, time)
Returns:
(torch.Tensor): normalized feature with same shape.
"""
x = x - self.mean
if self.norm_var:
x = x * self.istd
return x
class BPUConv2dSubsampling8(torch.nn.Module):
"""Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8
NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.right_context = module.right_context
self.subsampling_rate = module.subsampling_rate
assert isinstance(module.pos_enc, NoPositionalEncoding)
# 1. Modify self.conv
# NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
# to (1, 1, mel_dim, frames) for more efficient computation.
self.conv = module.conv
for idx in [0, 2, 4]:
self.conv[idx].weight = torch.nn.Parameter(
module.conv[idx].weight.transpose(2, 3)
)
# 2. Modify self.linear
        # NOTE(xcsong): Split the final projection to meet the requirement of
        #   maximum kernel_size (7 for XJ3)
self.linear = torch.nn.ModuleList()
odim = module.linear.weight.size(0) # 512, in this case
freq = module.linear.weight.size(1) // odim # 4608 // 512 == 9
self.odim, self.freq = odim, freq
weight = module.linear.weight.reshape(
odim, odim, freq, 1
) # (odim, odim * freq) -> (odim, odim, freq, 1)
self.split_size = []
num_split = (freq - 1) // 7 + 1 # XJ3 requires kernel_size <= 7
slice_begin = 0
for idx in range(num_split):
kernel_size = min(freq, (idx + 1) * 7) - idx * 7
conv_ele = torch.nn.Conv2d(odim, odim, (kernel_size, 1), (kernel_size, 1))
conv_ele.weight = torch.nn.Parameter(
weight[:, :, slice_begin : slice_begin + kernel_size, :]
)
conv_ele.bias = torch.nn.Parameter(torch.zeros_like(conv_ele.bias))
self.linear.append(conv_ele)
self.split_size.append(kernel_size)
slice_begin += kernel_size
self.linear[0].bias = torch.nn.Parameter(module.linear.bias)
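        # e.g. with freq == 9 (the 80-dim fbank case noted above) the projection
        # becomes two convs with kernel sizes (7, 1) and (2, 1); forward() sums
        # their outputs, which equals the original single Linear over all
        # frequency bins. The bias lives on the first split only, so it is
        # added exactly once.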
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 67, 80)
mask = torch.zeros(1, 1, 67)
original_result, _, _ = module(random_data, mask) # (1, 8, 512)
random_data = random_data.transpose(1, 2).unsqueeze(0) # (1, 1, 80, 67)
new_result = self.forward(random_data) # (1, 512, 1, 8)
np.testing.assert_allclose(
to_numpy(original_result),
to_numpy(new_result.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x with 4-D dataflow.
Args:
x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
where time' = time // 8.
"""
x = self.conv(x) # (1, odim, freq, time')
x_out = torch.zeros(x.size(0), self.odim, 1, x.size(3))
x = torch.split(x, self.split_size, dim=2)
for idx, (x_part, layer) in enumerate(zip(x, self.linear)):
x_out += layer(x_part)
return x_out
class BPUMultiHeadedAttention(torch.nn.Module):
"""Refactor wenet/transformer/attention.py::MultiHeadedAttention
    NOTE(xcsong): Only support attention_class == MultiHeadedAttention;
    RelPositionMultiHeadedAttention is not covered yet.
"""
def __init__(self, module, chunk_size, left_chunks):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.d_k = module.d_k
self.h = module.h
n_feat = self.d_k * self.h
self.chunk_size = chunk_size
self.left_chunks = left_chunks
self.time = chunk_size * (left_chunks + 1)
self.activation = torch.nn.Softmax(dim=-1)
# 1. Modify self.linear_x
self.linear_q = BPULinear(module.linear_q)
self.linear_k = BPULinear(module.linear_k)
self.linear_v = BPULinear(module.linear_v)
self.linear_out = BPULinear(module.linear_out)
# 2. denom
self.register_buffer(
"denom", torch.full((1, self.h, 1, 1), 1.0 / math.sqrt(self.d_k))
)
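        # 1/sqrt(d_k) is pre-expanded to a 4-D buffer so the attention scaling
        # becomes a plain elementwise multiply in the exported graph.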
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, self.chunk_size, self.d_k * self.h)
mask = torch.ones((1, self.h, self.chunk_size, self.time), dtype=torch.bool)
cache = torch.zeros(1, self.h, self.chunk_size * self.left_chunks, self.d_k * 2)
original_out, original_cache = module(
random_data,
random_data,
random_data,
mask[:, 0, :, :],
torch.empty(0),
cache,
)
random_data = random_data.transpose(1, 2).unsqueeze(2)
cache = cache.reshape(
1, self.h, self.d_k * 2, self.chunk_size * self.left_chunks
)
new_out, new_cache = self.forward(
random_data, random_data, random_data, mask, cache
)
np.testing.assert_allclose(
to_numpy(original_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
np.testing.assert_allclose(
to_numpy(original_cache),
to_numpy(new_cache.transpose(2, 3)),
rtol=1e-02,
atol=1e-03,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor,
cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute scaled dot product attention.
Args:
q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
mask (torch.Tensor): Mask tensor,
(#batch, head, chunk_size, cache_t + chunk_size).
cache (torch.Tensor): Cache tensor
(1, head, d_k * 2, cache_t),
where `cache_t == chunk_size * left_chunks`.
Returns:
torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
torch.Tensor: Cache tensor
(1, head, d_k * 2, cache_t + chunk_size)
where `cache_t == chunk_size * left_chunks`
"""
# 1. Forward QKV
q = self.linear_q(q) # (1, d, 1, c) d == size, c == chunk_size
k = self.linear_k(k) # (1, d, 1, c)
v = self.linear_v(v) # (1, d, 1, c)
q = q.view(1, self.h, self.d_k, self.chunk_size)
k = k.view(1, self.h, self.d_k, self.chunk_size)
v = v.view(1, self.h, self.d_k, self.chunk_size)
q = q.transpose(2, 3) # (batch, head, time1, d_k)
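        # k and v stay in (batch, head, d_k, time) layout: matmul(q, k) then
        # yields (batch, head, time1, time2) scores directly, and the k/v
        # caches can simply be concatenated along the last (time) axis.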
k_cache, v_cache = torch.split(cache, cache.size(2) // 2, dim=2)
k = torch.cat((k_cache, k), dim=3)
v = torch.cat((v_cache, v), dim=3)
new_cache = torch.cat((k, v), dim=2)
# 2. (Q^T)K
scores = torch.matmul(q, k) * self.denom # (#b, n_head, time1, time2)
# 3. Forward attention
mask = mask.eq(0)
scores = scores.masked_fill(mask, -float("inf"))
attn = self.activation(scores).masked_fill(mask, 0.0)
attn = attn.transpose(2, 3)
x = torch.matmul(v, attn)
x = x.view(1, self.d_k * self.h, 1, self.chunk_size)
x_out = self.linear_out(x)
return x_out, new_cache
class BPUConvolution(torch.nn.Module):
"""Refactor wenet/transformer/convolution.py::ConvolutionModule
    NOTE(xcsong): Only support use_layer_norm == False
"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.lorder = module.lorder
self.use_layer_norm = False
self.activation = module.activation
channels = module.pointwise_conv1.weight.size(1)
self.channels = channels
kernel_size = module.depthwise_conv.weight.size(2)
assert module.use_layer_norm is False
# 1. Modify self.pointwise_conv1
self.pointwise_conv1 = BPULinear(module.pointwise_conv1, True)
# 2. Modify self.depthwise_conv
self.depthwise_conv = torch.nn.Conv2d(
channels, channels, (1, kernel_size), stride=1, groups=channels
)
self.depthwise_conv.weight = torch.nn.Parameter(
module.depthwise_conv.weight.unsqueeze(-2)
)
self.depthwise_conv.bias = torch.nn.Parameter(module.depthwise_conv.bias)
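        # The causal 1-D depthwise conv over time is re-expressed as a Conv2d
        # with kernel (1, kernel_size) by inserting a dummy height dimension
        # into the original weights.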
# 3. Modify self.norm, Only support batchnorm2d
self.norm = torch.nn.BatchNorm2d(channels)
self.norm.training = False
self.norm.num_features = module.norm.num_features
self.norm.eps = module.norm.eps
self.norm.momentum = module.norm.momentum
self.norm.weight = torch.nn.Parameter(module.norm.weight)
self.norm.bias = torch.nn.Parameter(module.norm.bias)
self.norm.running_mean = module.norm.running_mean
self.norm.running_var = module.norm.running_var
# 4. Modify self.pointwise_conv2
self.pointwise_conv2 = BPULinear(module.pointwise_conv2, True)
# 5. Identity conv, for running `concat` on BPU
self.identity = BPUIdentity(channels)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 8, self.channels)
cache = torch.zeros((1, self.channels, self.lorder))
original_out, original_cache = module(random_data, cache=cache)
random_data = random_data.transpose(1, 2).unsqueeze(2)
cache = cache.unsqueeze(2)
new_out, new_cache = self.forward(random_data, cache)
np.testing.assert_allclose(
to_numpy(original_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
np.testing.assert_allclose(
to_numpy(original_cache),
to_numpy(new_cache.squeeze(2)),
rtol=1e-02,
atol=1e-03,
)
def forward(
self, x: torch.Tensor, cache: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
            cache (torch.Tensor): Left-context cache; only used in causal
                convolution (#batch, channels, 1, cache_t).
Returns:
torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
"""
# Concat cache
x = torch.cat((self.identity(cache), self.identity(x)), dim=3)
new_cache = x[:, :, :, -self.lorder :]
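        # The identity convs keep the concat on the BPU; the trailing `lorder`
        # frames of (cache + chunk) become the causal left context for the
        # next chunk.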
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, 1, dim)
x = torch.nn.functional.glu(x, dim=1) # (b, channel, 1, dim)
# Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x)
return x, new_cache
class BPUFFN(torch.nn.Module):
"""Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.activation = module.activation
# 1. Modify self.w_x
self.w_1 = BPULinear(module.w_1)
self.w_2 = BPULinear(module.w_2)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 8, self.w_1.idim)
original_out = module(random_data)
random_data = random_data.transpose(1, 2).unsqueeze(2)
new_out = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(original_out),
to_numpy(new_out.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
            x: input tensor (B, D, 1, L)
Returns:
output tensor, (B, D, 1, L)
"""
return self.w_2(self.activation(self.w_1(x)))
class BPUConformerEncoderLayer(torch.nn.Module):
"""Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer"""
def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.size = module.size
assert module.normalize_before is True
assert module.concat_after is False
# 1. Modify submodules
self.feed_forward_macaron = BPUFFN(module.feed_forward_macaron)
self.self_attn = BPUMultiHeadedAttention(
module.self_attn, chunk_size, left_chunks
)
self.conv_module = BPUConvolution(module.conv_module)
self.feed_forward = BPUFFN(module.feed_forward)
# 2. Modify norms
self.norm_ff = BPULayerNorm(module.norm_ff, chunk_size, ln_run_on_bpu)
self.norm_mha = BPULayerNorm(module.norm_mha, chunk_size, ln_run_on_bpu)
self.norm_ff_macron = BPULayerNorm(
module.norm_ff_macaron, chunk_size, ln_run_on_bpu
)
self.norm_conv = BPULayerNorm(module.norm_conv, chunk_size, ln_run_on_bpu)
self.norm_final = BPULayerNorm(module.norm_final, chunk_size, ln_run_on_bpu)
# 3. 4-D ff_scale
self.register_buffer(
"ff_scale", torch.full((1, self.size, 1, 1), module.ff_scale)
)
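        # ff_scale (0.5 when the macaron FFN is present) is expanded to a 4-D
        # buffer so it broadcasts over the (batch, size, 1, chunk_size)
        # activations.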
self.check_equal(original)
def check_equal(self, module):
time1 = self.self_attn.chunk_size
time2 = self.self_attn.time
h, d_k = self.self_attn.h, self.self_attn.d_k
random_x = torch.randn(1, time1, self.size)
att_mask = torch.ones(1, h, time1, time2)
att_cache = torch.zeros(1, h, time2 - time1, d_k * 2)
cnn_cache = torch.zeros(1, self.size, self.conv_module.lorder)
original_x, _, original_att_cache, original_cnn_cache = module(
random_x,
att_mask[:, 0, :, :],
torch.empty(0),
att_cache=att_cache,
cnn_cache=cnn_cache,
)
random_x = random_x.transpose(1, 2).unsqueeze(2)
att_cache = att_cache.reshape(1, h, d_k * 2, time2 - time1)
cnn_cache = cnn_cache.unsqueeze(2)
new_x, new_att_cache, new_cnn_cache = self.forward(
random_x, att_mask, att_cache, cnn_cache
)
np.testing.assert_allclose(
to_numpy(original_att_cache),
to_numpy(new_att_cache.transpose(2, 3)),
rtol=1e-02,
atol=1e-03,
)
np.testing.assert_allclose(
to_numpy(original_x),
to_numpy(new_x.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
np.testing.assert_allclose(
to_numpy(original_cnn_cache),
to_numpy(new_cnn_cache.squeeze(2)),
rtol=1e-02,
atol=1e-03,
)
def forward(
self,
x: torch.Tensor,
att_mask: torch.Tensor,
att_cache: torch.Tensor,
cnn_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, size, 1, chunk_size)
att_mask (torch.Tensor): Mask tensor for the input
(#batch, head, chunk_size, cache_t1 + chunk_size),
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, 1, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
torch.Tensor: att_cache tensor,
(1, head, d_k * 2, cache_t1 + chunk_size).
            torch.Tensor: cnn_cache tensor (#batch, size, 1, cache_t2).
"""
# 1. ffn_macaron
residual = x
x = self.norm_ff_macron(x)
x = residual + self.ff_scale * self.feed_forward_macaron(x)
# 2. attention
residual = x
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(x, x, x, att_mask, att_cache)
x = residual + x_att
# 3. convolution
residual = x
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, cnn_cache)
x = residual + x
# 4. ffn
residual = x
x = self.norm_ff(x)
x = residual + self.ff_scale * self.feed_forward(x)
# 5. final post-norm
x = self.norm_final(x)
return x, new_att_cache, new_cnn_cache
class BPUConformerEncoder(torch.nn.Module):
"""Refactor wenet/transformer/encoder.py::ConformerEncoder"""
def __init__(self, module, chunk_size, left_chunks, ln_run_on_bpu=False):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
output_size = module.output_size()
self._output_size = module.output_size()
self.after_norm = module.after_norm
self.chunk_size = chunk_size
self.left_chunks = left_chunks
self.head = module.encoders[0].self_attn.h
self.layers = len(module.encoders)
# 1. Modify submodules
self.global_cmvn = BPUGlobalCMVN(module.global_cmvn)
self.embed = BPUConv2dSubsampling8(module.embed)
self.encoders = torch.nn.ModuleList()
for layer in module.encoders:
self.encoders.append(
BPUConformerEncoderLayer(layer, chunk_size, left_chunks, ln_run_on_bpu)
)
# 2. Auxiliary conv
self.identity_cnncache = BPUIdentity(output_size)
self.check_equal(original)
def check_equal(self, module):
time1 = self.encoders[0].self_attn.chunk_size
time2 = self.encoders[0].self_attn.time
layers = self.layers
h, d_k = self.head, self.encoders[0].self_attn.d_k
decoding_window = (
(self.chunk_size - 1) * module.embed.subsampling_rate
+ module.embed.right_context
+ 1
)
lorder = self.encoders[0].conv_module.lorder
random_x = torch.randn(1, decoding_window, 80)
att_mask = torch.ones(1, h, time1, time2)
att_cache = torch.zeros(layers, h, time2 - time1, d_k * 2)
cnn_cache = torch.zeros(layers, 1, self._output_size, lorder)
orig_x, orig_att_cache, orig_cnn_cache = module.forward_chunk(
random_x,
0,
time2 - time1,
att_mask=att_mask[:, 0, :, :],
att_cache=att_cache,
cnn_cache=cnn_cache,
)
random_x = random_x.unsqueeze(0)
att_cache = att_cache.reshape(1, h * layers, d_k * 2, time2 - time1)
cnn_cache = cnn_cache.reshape(1, self._output_size, layers, lorder)
new_x, new_att_cache, new_cnn_cache = self.forward(
random_x, att_cache, cnn_cache, att_mask
)
caches = torch.split(new_att_cache, h, dim=1)
caches = [c.transpose(2, 3) for c in caches]
np.testing.assert_allclose(
to_numpy(orig_att_cache),
to_numpy(torch.cat(caches, dim=0)),
rtol=1e-02,
atol=1e-03,
)
np.testing.assert_allclose(
to_numpy(orig_x),
to_numpy(new_x.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
np.testing.assert_allclose(
to_numpy(orig_cnn_cache),
to_numpy(new_cnn_cache.transpose(0, 2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
def forward(
self,
xs: torch.Tensor,
att_cache: torch.Tensor,
cnn_cache: torch.Tensor,
att_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(1, head * elayers, d_k * 2, cache_t1), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(1, hidden-dim, elayers, cache_t2), where
                `cache_t2 == cnn.lorder`
att_mask (torch.Tensor): Mask tensor for the input
(#batch, head, chunk_size, cache_t1 + chunk_size),
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, hidden-dim, 1, chunk_size).
torch.Tensor: new attention cache required for next chunk, with
same shape as the original att_cache.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
# xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
xs = xs.transpose(2, 3)
xs = self.global_cmvn(xs)
# xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
xs = self.embed(xs)
att_cache = torch.split(att_cache, self.head, dim=1)
cnn_cache = self.identity_cnncache(cnn_cache)
cnn_cache = torch.split(cnn_cache, 1, dim=2)
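        # att_cache packs all layers along the head axis (head * elayers) and
        # cnn_cache packs them along dim 2, so both are split into per-layer
        # slices before the encoder loop.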
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoders):
xs, new_att_cache, new_cnn_cache = layer(
xs, att_mask, att_cache=att_cache[i], cnn_cache=cnn_cache[i]
)
r_att_cache.append(new_att_cache[:, :, :, self.chunk_size :])
r_cnn_cache.append(new_cnn_cache)
r_att_cache = torch.cat(r_att_cache, dim=1)
r_cnn_cache = self.identity_cnncache(torch.cat(r_cnn_cache, dim=2))
xs = xs.squeeze(2).transpose(1, 2).contiguous()
xs = self.after_norm(xs)
        # NOTE(xcsong): 4D in, 4D out to meet the requirement of CTC input.
xs = xs.transpose(1, 2).contiguous().unsqueeze(2) # (B, C, 1, T)
return (xs, r_att_cache, r_cnn_cache)
class BPUCTC(torch.nn.Module):
"""Refactor wenet/transformer/ctc.py::CTC"""
def __init__(self, module):
super().__init__()
# Unchanged submodules and attributes
original = copy.deepcopy(module)
self.idim = module.ctc_lo.weight.size(1)
num_class = module.ctc_lo.weight.size(0)
        # 1. Modify self.ctc_lo: split the final projection to meet the
        #    requirement of maximum in/out channels (2048 for XJ3)
self.ctc_lo = torch.nn.ModuleList()
self.split_size = []
num_split = (num_class - 1) // 2048 + 1
for idx in range(num_split):
out_channel = min(num_class, (idx + 1) * 2048) - idx * 2048
conv_ele = torch.nn.Conv2d(self.idim, out_channel, 1, 1)
self.ctc_lo.append(conv_ele)
self.split_size.append(out_channel)
orig_weight = torch.split(module.ctc_lo.weight, self.split_size, dim=0)
orig_bias = torch.split(module.ctc_lo.bias, self.split_size, dim=0)
for i, (w, b) in enumerate(zip(orig_weight, orig_bias)):
w = w.unsqueeze(2).unsqueeze(3)
self.ctc_lo[i].weight = torch.nn.Parameter(w)
self.ctc_lo[i].bias = torch.nn.Parameter(b)
self.check_equal(original)
def check_equal(self, module):
random_data = torch.randn(1, 100, self.idim)
original_result = module.ctc_lo(random_data)
random_data = random_data.transpose(1, 2).unsqueeze(2)
new_result = self.forward(random_data)
np.testing.assert_allclose(
to_numpy(original_result),
to_numpy(new_result.squeeze(2).transpose(1, 2)),
rtol=1e-02,
atol=1e-03,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""frame activations, without softmax.
Args:
Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size)
Returns:
torch.Tensor: (B, num_class, 1, chunk_size)
"""
out = []
for i, layer in enumerate(self.ctc_lo):
out.append(layer(x))
out = torch.cat(out, dim=1)
return out
def export_encoder(asr_model, args):
logger.info("Stage-1: export encoder")
decode_window, mel_dim = args.decoding_window, args.feature_size
encoder = BPUConformerEncoder(
asr_model.encoder,
args.chunk_size,
args.num_decoding_left_chunks,
args.ln_run_on_bpu,
)
encoder.eval()
encoder_outpath = os.path.join(args.output_dir, "encoder.onnx")
logger.info("Stage-1.1: prepare inputs for encoder")
chunk = torch.randn((1, 1, decode_window, mel_dim))
required_cache_size = encoder.chunk_size * encoder.left_chunks
kv_time = required_cache_size + encoder.chunk_size
hidden, layers = encoder._output_size, len(encoder.encoders)
head = encoder.encoders[0].self_attn.h
d_k = hidden // head
lorder = encoder.encoders[0].conv_module.lorder
att_cache = torch.zeros(1, layers * head, d_k * 2, required_cache_size)
att_mask = torch.ones((1, head, encoder.chunk_size, kv_time))
att_mask[:, :, :, :required_cache_size] = 0
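    # For the very first chunk the whole left-context cache is still empty, so
    # its positions are masked out; the consistency check below re-enables them
    # chunk by chunk as the caches fill up.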
cnn_cache = torch.zeros((1, hidden, layers, lorder))
inputs = (chunk, att_cache, cnn_cache, att_mask)
logger.info(
"chunk.size(): {} att_cache.size(): {} "
"cnn_cache.size(): {} att_mask.size(): {}".format(
list(chunk.size()),
list(att_cache.size()),
list(cnn_cache.size()),
list(att_mask.size()),
)
)
logger.info("Stage-1.2: torch.onnx.export")
# NOTE(xcsong): Below attributes will be used in
# onnx2horizonbin.py::generate_config()
attributes = {}
attributes["input_name"] = "chunk;att_cache;cnn_cache;att_mask"
attributes["output_name"] = "output;r_att_cache;r_cnn_cache"
attributes["input_type"] = "featuremap;featuremap;featuremap;featuremap"
attributes["norm_type"] = "no_preprocess;no_preprocess;no_preprocess;no_preprocess"
attributes["input_layout_train"] = "NCHW;NCHW;NCHW;NCHW"
attributes["input_layout_rt"] = "NCHW;NCHW;NCHW;NCHW"
attributes[
"input_shape"
] = "{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}".format(
chunk.size(0),
chunk.size(1),
chunk.size(2),
chunk.size(3),
att_cache.size(0),
att_cache.size(1),
att_cache.size(2),
att_cache.size(3),
cnn_cache.size(0),
cnn_cache.size(1),
cnn_cache.size(2),
cnn_cache.size(3),
att_mask.size(0),
att_mask.size(1),
att_mask.size(2),
att_mask.size(3),
)
torch.onnx.export( # NOTE(xcsong): only support opset==11
encoder,
inputs,
encoder_outpath,
opset_version=11,
export_params=True,
do_constant_folding=True,
input_names=attributes["input_name"].split(";"),
output_names=attributes["output_name"].split(";"),
dynamic_axes=None,
verbose=False,
)
onnx_encoder = onnx.load(encoder_outpath)
for k in vars(args):
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(getattr(args, k))
for k in attributes:
meta = onnx_encoder.metadata_props.add()
meta.key, meta.value = str(k), str(attributes[k])
onnx.checker.check_model(onnx_encoder)
onnx.helper.printable_graph(onnx_encoder.graph)
onnx.save(onnx_encoder, encoder_outpath)
print_input_output_info(onnx_encoder, "onnx_encoder")
logger.info("Export onnx_encoder, done! see {}".format(encoder_outpath))
logger.info("Stage-1.3: check onnx_encoder and torch_encoder")
torch_output = []
torch_chunk, torch_att_mask = copy.deepcopy(chunk), copy.deepcopy(att_mask)
torch_att_cache = copy.deepcopy(att_cache)
torch_cnn_cache = copy.deepcopy(cnn_cache)
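    # Feed the same random chunk 10 times, carrying the caches forward and
    # widening att_mask each step to expose the newly valid left context, then
    # compare the concatenated torch outputs against ONNX Runtime below.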
for i in range(10):
logger.info(
"torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
", att_mask: {}".format(
i,
list(torch_chunk.size()),
list(torch_att_cache.size()),
list(torch_cnn_cache.size()),
list(torch_att_mask.size()),
)
)
torch_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)) :] = 1
out, torch_att_cache, torch_cnn_cache = encoder(
torch_chunk, torch_att_cache, torch_cnn_cache, torch_att_mask
)
torch_output.append(out)
torch_output = torch.cat(torch_output, dim=-1)
onnx_output = []
onnx_chunk, onnx_att_mask = to_numpy(chunk), to_numpy(att_mask)
onnx_att_cache = to_numpy(att_cache)
onnx_cnn_cache = to_numpy(cnn_cache)
ort_session = onnxruntime.InferenceSession(encoder_outpath)
input_names = [node.name for node in onnx_encoder.graph.input]
for i in range(10):
logger.info(
"onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
" att_mask: {}".format(
i,
onnx_chunk.shape,
onnx_att_cache.shape,
onnx_cnn_cache.shape,
onnx_att_mask.shape,
)
)
onnx_att_mask[:, :, :, -(encoder.chunk_size * (i + 1)) :] = 1
ort_inputs = {
"chunk": onnx_chunk,
"att_cache": onnx_att_cache,
"cnn_cache": onnx_cnn_cache,
"att_mask": onnx_att_mask,
}
ort_outs = ort_session.run(None, ort_inputs)
onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
onnx_output.append(ort_outs[0])
onnx_output = np.concatenate(onnx_output, axis=-1)
np.testing.assert_allclose(
to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-04
)
meta = ort_session.get_modelmeta()
logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
logger.info("Check onnx_encoder, pass!")
return encoder, ort_session
def export_ctc(asr_model, args):
logger.info("Stage-2: export ctc")
ctc = BPUCTC(asr_model.ctc).eval()
ctc_outpath = os.path.join(args.output_dir, "ctc.onnx")
logger.info("Stage-2.1: prepare inputs for ctc")
hidden = torch.randn((1, args.output_size, 1, args.chunk_size))
logger.info("Stage-2.2: torch.onnx.export")
# NOTE(xcsong): Below attributes will be used in
# onnx2horizonbin.py::generate_config()
attributes = {}
attributes["input_name"], attributes["input_type"] = "hidden", "featuremap"
attributes["norm_type"] = "no_preprocess"
attributes["input_layout_train"] = "NCHW"
attributes["input_layout_rt"] = "NCHW"
attributes["input_shape"] = "{}x{}x{}x{}".format(
hidden.size(0),
hidden.size(1),
hidden.size(2),
hidden.size(3),
)
torch.onnx.export(
ctc,
hidden,
ctc_outpath,
opset_version=11,
export_params=True,
do_constant_folding=True,
input_names=["hidden"],
output_names=["probs"],
dynamic_axes=None,
verbose=False,
)
onnx_ctc = onnx.load(ctc_outpath)
for k in vars(args):
meta = onnx_ctc.metadata_props.add()
meta.key, meta.value = str(k), str(getattr(args, k))
for k in attributes:
meta = onnx_ctc.metadata_props.add()
meta.key, meta.value = str(k), str(attributes[k])
onnx.checker.check_model(onnx_ctc)
onnx.helper.printable_graph(onnx_ctc.graph)
onnx.save(onnx_ctc, ctc_outpath)
print_input_output_info(onnx_ctc, "onnx_ctc")
logger.info("Export onnx_ctc, done! see {}".format(ctc_outpath))
logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
torch_output = ctc(hidden)
ort_session = onnxruntime.InferenceSession(ctc_outpath)
onnx_output = ort_session.run(None, {"hidden": to_numpy(hidden)})
np.testing.assert_allclose(
to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-04
)
meta = ort_session.get_modelmeta()
logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
logger.info("Check onnx_ctc, pass!")
return ctc, ort_session
def export_decoder(asr_model, args):
logger.info("Currently, Decoder is not supported.")
if __name__ == "__main__":
torch.manual_seed(777)
args = get_args()
args.ln_run_on_bpu = False
# NOTE(xcsong): XJ3 BPU only support static shapes
assert args.chunk_size > 0
assert args.num_decoding_left_chunks > 0
os.system("mkdir -p " + args.output_dir)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
with open(args.config, "r") as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
model = init_model(configs)
load_checkpoint(model, args.checkpoint)
model.eval()
print(model)
args.feature_size = configs["input_dim"]
args.output_size = model.encoder.output_size()
args.decoding_window = (
(args.chunk_size - 1) * model.encoder.embed.subsampling_rate
+ model.encoder.embed.right_context
+ 1
)
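    # decoding_window: number of raw fbank frames needed to emit `chunk_size`
    # subsampled frames, i.e. (right_context + 1) frames for the first output
    # plus `subsampling_rate` more for each of the remaining (chunk_size - 1)
    # outputs.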
export_encoder(model, args)
export_ctc(model, args)
export_decoder(model, args)