"""Run openai/whisper-tiny.en with ONNX Runtime: the encoder on the CPU
execution provider and the statically shaped, quantized decoder on the
Vitis AI execution provider, wrapped so the stock generate() loop works."""

import os
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import onnxruntime as onnxrt
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoProcessor,
    GenerationConfig,
    WhisperForConditionalGeneration,
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

# Silence TensorFlow's C++ logging in case transformers imports TF.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

model_name = "openai/whisper-tiny.en"
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Static dimensions of whisper-tiny.en (from its config.json).
batch_size = 1
encoder_num_attention_heads = 6
decoder_num_attention_heads = 6
hidden_size = 384
encoder_sequence_length = 1500  # encoder output frames for a 30 s input
decoder_max_length = 448  # max_target_positions
num_hidden_layers = 4

# Per-layer shapes of the key/value cache tensors:
# (batch, num_heads, sequence_length, head_dim).
encoder_shape = (
    batch_size,
    encoder_num_attention_heads,
    encoder_sequence_length,
    hidden_size // encoder_num_attention_heads,
)
decoder_shape = (
    batch_size,
    decoder_num_attention_heads,
    decoder_max_length,
    hidden_size // decoder_num_attention_heads,
)
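
# Added sanity check (an assumption check, not part of the pipeline): the
# hidden size must divide evenly across heads; for whisper-tiny, 384 / 6 = 64.
assert hidden_size % encoder_num_attention_heads == 0
assert hidden_size % decoder_num_attention_heads == 0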

# Load one utterance from the dummy LibriSpeech split used in HF tests.
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
idx = 4
inputs = processor.feature_extractor(
    ds[idx]["audio"]["array"],
    sampling_rate=ds[idx]["audio"]["sampling_rate"],
    return_tensors="pt",
)
input_features = inputs.input_features
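# input_features is a (1, 80, 3000) log-mel spectrogram: 80 mel bins over a
# 30 s window (the feature extractor pads/truncates to 30 s).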

# Paths to the exported ONNX artifacts and the Vitis AI provider config.
# The directory name encodes the static decoder sequence length (SL-448).
onnx_model_path = Path("whisper-tiny-static-shape-quantized-SL-448")
config_file = str(Path("other_libs_qdq") / "vaip_config_gemm_asr_decoder.json")
encoder_model_path = str(onnx_model_path / "encoder_model.onnx")
decoder_model_path = str(onnx_model_path / "decoder_model_quantized.onnx")

print(decoder_model_path)
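
# Added guard (assumption: the artifacts come from a separate export and
# quantization step and are expected to exist locally before running this).
for _path in (encoder_model_path, decoder_model_path, config_file):
    assert Path(_path).exists(), f"missing artifact: {_path}"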


class ORTEncoder(torch.nn.Module):
    """Whisper encoder backed by an ONNX Runtime CPU session."""

    def __init__(self):
        super().__init__()
        self.main_input_name = "input_features"
        self.session = onnxrt.InferenceSession(
            encoder_model_path, providers=["CPUExecutionProvider"]
        )
        self.output_names = {
            output_key.name: idx
            for idx, output_key in enumerate(self.session.get_outputs())
        }

    def forward(
        self,
        input_features: torch.FloatTensor,
        **kwargs,
    ) -> BaseModelOutput:
        onnx_inputs = {"input_features": input_features.cpu().detach().numpy()}

        outputs = self.session.run(None, onnx_inputs)
        last_hidden_state = torch.from_numpy(
            outputs[self.output_names["last_hidden_state"]]
        )

        return BaseModelOutput(last_hidden_state=last_hidden_state)
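

# Hypothetical quick check of the encoder in isolation (the full pipeline
# below drives it through generate() instead):
#   enc = ORTEncoder()
#   enc(input_features).last_hidden_state.shape  # expected: (1, 1500, 384)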


class ORTDecoder(torch.nn.Module):
    """Whisper decoder backed by a statically shaped, quantized ONNX model
    running on the Vitis AI execution provider."""

    def __init__(self):
        super().__init__()
        sess_options = onnxrt.SessionOptions()
        self.provider = "VitisAIExecutionProvider"
        self.provider_options = {"config_file": config_file}
        # The graph is already quantized (QDQ); keep ORT from rewriting it.
        sess_options.graph_optimization_level = (
            onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL
        )
        sess_options.add_session_config_entry("session.disable_quant_qdq", "1")
        self.session = onnxrt.InferenceSession(
            decoder_model_path,
            providers=[self.provider],
            sess_options=sess_options,
            provider_options=[self.provider_options],
        )

        self.generation_config = GenerationConfig.from_model_config(config)
        self.max_length = decoder_max_length

        self.input_names = {
            input_key.name: idx
            for idx, input_key in enumerate(self.session.get_inputs())
        }
        self.output_names = {
            output_key.name: idx
            for idx, output_key in enumerate(self.session.get_outputs())
        }
        # Cache tensors are identified by ".key"/".value" in their names.
        self.key_value_input_names = [
            key for key in self.input_names if (".key" in key) or (".value" in key)
        ]
        self.key_value_output_names = [
            key for key in self.output_names if (".key" in key) or (".value" in key)
        ]

        self.reset()

    def reset(self):
        """Re-initialize the per-utterance decoding state (attention mask and
        position counter for the static-shape cache)."""
        self.decoder_attention_mask = np.zeros((batch_size, self.max_length)).astype(
            np.int64
        )
        self.decoder_attention_mask[0, 0] = 1
        self.position_ids = np.array([[0]]).astype(np.int64)
        # Four cache tensors per layer: decoder key/value and encoder key/value.
        self.num_pkv = 4

    def prepare_pkv(self):
        """Build placeholder past key/values for the first decoding step.

        The static-shape graph always expects cache inputs, so step 0 feeds
        random tensors; the attention mask keeps them from being attended to.
        """
        decoder_key_value = torch.rand(*decoder_shape).to(torch.float32)
        encoder_key_value = torch.rand(*encoder_shape).to(torch.float32)

        repeat_count = len(self.key_value_input_names) // self.num_pkv
        past_key_values = tuple(
            (decoder_key_value, decoder_key_value, encoder_key_value, encoder_key_value)
            for _ in range(repeat_count)
        )

        return past_key_values
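
    # Note (assumed layout, implied by the slicing in forward below): each
    # per-layer group is ordered (decoder.key, decoder.value, encoder.key,
    # encoder.value).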

    def forward(
        self,
        input_ids: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Seq2SeqLMOutput:
        if past_key_values is None:
            self.reset()

        # Once the static cache is exhausted, force EOS instead of running
        # the model again.
        if self.position_ids[0][0] == self.max_length:
            logits = torch.zeros((len(input_ids), 1, config.vocab_size))
            logits[:, :, config.eos_token_id] = 1

            return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)

        onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}
        onnx_inputs["position_ids"] = self.position_ids
        onnx_inputs["decoder_attention_mask"] = self.decoder_attention_mask
        onnx_inputs["encoder_hidden_states"] = (
            encoder_hidden_states.cpu().detach().numpy()
        )

        # Step 0 has no real cache yet; feed the masked-out placeholders.
        if self.position_ids[0][0] == 0:
            past_key_values = self.prepare_pkv()

        # Flatten the per-layer tuples into one tensor per cache input.
        past_key_values = tuple(
            past_key_value
            for pkv_per_layer in past_key_values
            for past_key_value in pkv_per_layer
        )

        for input_name, past_key_value in zip(
            self.key_value_input_names, past_key_values
        ):
            onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()

        outputs = self.session.run(None, onnx_inputs)

        logits = torch.from_numpy(outputs[self.output_names["logits"]])

        out_past_key_values = tuple(
            torch.from_numpy(outputs[self.output_names[key]])
            for key in self.key_value_output_names
        )

        if self.position_ids[0][0] == 0:
            # Step 0: regroup the flat outputs into per-layer 4-tuples.
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        else:
            # Later steps: take the fresh decoder key/values, keep the encoder
            # key/values computed at step 0.
            out_past_key_values = tuple(
                out_past_key_values[i : i + 2] + past_key_values[i + 2 : i + 4]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )

        # Unmask the next cache slot and advance the position counter.
        if self.position_ids[0][0] < self.max_length - 1:
            self.decoder_attention_mask[:, self.position_ids[0][0] + 1] = 1
        self.position_ids += 1

        return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values)


class ORTModelForWhisper(WhisperForConditionalGeneration):
    """WhisperForConditionalGeneration with the encoder and decoder swapped
    for ONNX Runtime sessions, so the stock generate() loop can drive them."""

    def __init__(self, *args, **kwargs):
        config = AutoConfig.from_pretrained(model_name)
        super().__init__(config)

        self.encoder = ORTEncoder()
        self.decoder = ORTDecoder()

    def get_encoder(self):
        return self.encoder

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_features=input_features)

        # The static-shape decoder consumes one token per step, so pass only
        # the most recent decoder input id.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids[:, -1:],
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            past_key_values=past_key_values,
        )

        return Seq2SeqLMOutput(
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
        )

    def can_generate(self):
        return True

    def reset(self):
        self.decoder.reset()


# Instantiate both pipelines once; the test helpers below reuse them.
model_ort = ORTModelForWhisper()
model = WhisperForConditionalGeneration.from_pretrained(model_name)


def test_ort():
    generated_ids = model_ort.generate(input_features)
    model_output = processor.tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    print("ORT: ", model_output, generated_ids)


def test_original():
    generated_ids = model.generate(input_features)
    model_output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("Torch: ", model_output, generated_ids)


if __name__ == "__main__":
    test_ort()
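    # Uncomment to compare against the reference PyTorch pipeline:
    # test_original()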