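"""Evaluate an ONNX export of openai/whisper-tiny.en.

The float encoder runs on ONNX Runtime's CPU execution provider, and the
quantized, statically shaped decoder runs on the VitisAI execution provider.
Generation is driven through transformers' WhisperForConditionalGeneration,
and the output can be compared against the original PyTorch model.
"""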
import os
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import onnxruntime as onnxrt
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoProcessor,
    GenerationConfig,
    WhisperForConditionalGeneration,
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
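# Silence TensorFlow's C++ logging in case transformers pulls in a TF backend.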
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
model_name = "openai/whisper-tiny.en"
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
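# Static shapes used by the export. These match the whisper-tiny.en
# architecture (4 layers, 6 attention heads, hidden size 384); the encoder
# always sees 1500 frames (30 s of audio), and the decoder was exported with
# a fixed maximum sequence length of 448 tokens.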
batch_size = 1
encoder_num_attention_heads = 6
decoder_num_attention_heads = 6
hidden_size = 384
encoder_sequence_length = 1500
decoder_max_length = 448
num_hidden_layers = 4
encoder_shape = (
    batch_size,
    encoder_num_attention_heads,
    encoder_sequence_length,
    hidden_size // encoder_num_attention_heads,
)
decoder_shape = (
    batch_size,
    decoder_num_attention_heads,
    decoder_max_length,
    hidden_size // decoder_num_attention_heads,
)
# load dataset
ds = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation"
)
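# Pick one utterance and compute its log-mel features; Whisper's feature
# extractor pads audio to 30 s, which matches the encoder's static length of
# 1500 frames.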
idx = 4
inputs = processor.feature_extractor(
    ds[idx]["audio"]["array"],
    sampling_rate=ds[idx]["audio"]["sampling_rate"],
    return_tensors="pt",
)
input_features = inputs.input_features
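# Paths to the statically shaped quantized export and the VitisAI config file.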
# onnx_model_path = "/home/ubuntu/optimum/output_whisper_smooth_quant_4_oct_static_testing"
onnx_model_path = ".\\whisper-tiny-static-shape-quantized-SL-448"
config_file = ".\\other_libs_qdq\\vaip_config_gemm_asr_decoder.json"
encoder_model_path = ".\\whisper-tiny-static-shape-quantized-SL-448\\encoder_model.onnx"
decoder_model_path = ".\\whisper-tiny-static-shape-quantized-SL-448\\decoder_model_quantized.onnx"
print("Decoder model:", decoder_model_path)

class ORTEncoder(torch.nn.Module):
    """Runs the ONNX encoder on CPU behind a torch-friendly interface."""

    def __init__(self):
        super().__init__()
        self.main_input_name = "input_features"
        self.session = onnxrt.InferenceSession(
            encoder_model_path, providers=["CPUExecutionProvider"]
        )
        self.output_names = {
            output_key.name: idx
            for idx, output_key in enumerate(self.session.get_outputs())
        }

    def forward(
        self,
        input_features: torch.FloatTensor,
        **kwargs,
    ) -> BaseModelOutput:
        onnx_inputs = {"input_features": input_features.cpu().detach().numpy()}
        # Run inference
        outputs = self.session.run(None, onnx_inputs)
        last_hidden_state = torch.from_numpy(
            outputs[self.output_names["last_hidden_state"]]
        )
        return BaseModelOutput(last_hidden_state=last_hidden_state)

class ORTDecoder(torch.nn.Module):
    """Runs the quantized ONNX decoder on the VitisAI execution provider."""

    def __init__(self):
        super().__init__()
        sess_options = onnxrt.SessionOptions()
        self.provider = "VitisAIExecutionProvider"
        self.provider_options = {"config_file": config_file}
        sess_options.graph_optimization_level = (
            onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL
        )
        sess_options.add_session_config_entry("session.disable_quant_qdq", "1")
        self.session = onnxrt.InferenceSession(
            decoder_model_path,
            providers=[self.provider],
            sess_options=sess_options,
            provider_options=[self.provider_options],
        )
        self.generation_config = GenerationConfig.from_model_config(config)
        self.max_length = decoder_max_length
        self.input_names = {
            input_key.name: idx
            for idx, input_key in enumerate(self.session.get_inputs())
        }
        self.output_names = {
            output_key.name: idx
            for idx, output_key in enumerate(self.session.get_outputs())
        }
        self.key_value_input_names = [
            key for key in self.input_names if (".key" in key) or (".value" in key)
        ]
        self.key_value_output_names = [
            key for key in self.output_names if (".key" in key) or (".value" in key)
        ]
        self.reset()

    def reset(self):
        # Reset the per-sequence state: only position 0 is unmasked at the start.
        self.decoder_attention_mask = np.zeros((batch_size, self.max_length)).astype(
            np.int64
        )
        self.decoder_attention_mask[0, 0] = 1
        self.position_ids = np.array([[0]]).astype(np.int64)
        # Each layer carries 4 past tensors: decoder key/value and encoder key/value.
        self.num_pkv = 4

    def prepare_pkv(self):
        # Dummy past key/values for the first step; their positions are masked
        # out by decoder_attention_mask, so the random contents are never read.
        decoder_key_value = torch.rand(*decoder_shape).to(torch.float32)
        encoder_key_value = torch.rand(*encoder_shape).to(torch.float32)
        repeat_count = len(self.key_value_input_names) // self.num_pkv
        return tuple(
            (decoder_key_value, decoder_key_value, encoder_key_value, encoder_key_value)
            for _ in range(repeat_count)
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        encoder_hidden_states: torch.FloatTensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> Seq2SeqLMOutput:
        if past_key_values is None:
            self.reset()
        if self.position_ids[0][0] == self.max_length:
            # The static graph cannot decode past max_length; emit EOS so that
            # generation stops.
            logits = torch.zeros((len(input_ids), 1, config.vocab_size))
            logits[:, :, config.eos_token_id] = 1
            return Seq2SeqLMOutput(logits=logits, past_key_values=past_key_values)
        onnx_inputs = {"input_ids": input_ids.cpu().detach().numpy()}
        onnx_inputs["position_ids"] = self.position_ids
        onnx_inputs["decoder_attention_mask"] = self.decoder_attention_mask
        onnx_inputs["encoder_hidden_states"] = (
            encoder_hidden_states.cpu().detach().numpy()
        )
        if self.position_ids[0][0] == 0:
            past_key_values = self.prepare_pkv()
        # Flatten the per-layer tuples so they can be zipped with the input names.
        past_key_values = tuple(
            past_key_value
            for pkv_per_layer in past_key_values
            for past_key_value in pkv_per_layer
        )
        for input_name, past_key_value in zip(
            self.key_value_input_names, past_key_values
        ):
            onnx_inputs[input_name] = past_key_value.cpu().detach().numpy()
        # Run inference
        outputs = self.session.run(None, onnx_inputs)
        logits = torch.from_numpy(outputs[self.output_names["logits"]])
        out_past_key_values = tuple(
            torch.from_numpy(outputs[self.output_names[key]])
            for key in self.key_value_output_names
        )
        if self.position_ids[0][0] == 0:
            # First step: keep all four outputs (decoder and encoder key/values).
            out_past_key_values = tuple(
                out_past_key_values[i : i + self.num_pkv]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        else:
            # Later steps: take the fresh decoder key/values and reuse the
            # cached encoder key/values, which do not change across steps.
            out_past_key_values = tuple(
                out_past_key_values[i : i + 2] + past_key_values[i + 2 : i + 4]
                for i in range(0, len(out_past_key_values), self.num_pkv)
            )
        if self.position_ids[0][0] < self.max_length - 1:
            self.decoder_attention_mask[:, self.position_ids[0][0] + 1] = 1
        self.position_ids += 1
        return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values)

class ORTModelForWhisper(WhisperForConditionalGeneration):
    """Whisper model whose encoder and decoder are backed by ONNX Runtime sessions."""

    def __init__(self, *args, **kwargs):
        config = AutoConfig.from_pretrained(model_name)
        super().__init__(config)
        self.encoder = ORTEncoder()
        self.decoder = ORTDecoder()

    def get_encoder(self):
        return self.encoder

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        **kwargs,
    ) -> Seq2SeqLMOutput:
        if encoder_outputs is None:
            encoder_outputs = self.encoder(input_features=input_features)
        # Decode a single step: the static-shape decoder consumes only the last token.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids[:, -1:],
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            past_key_values=past_key_values,
        )
        return Seq2SeqLMOutput(
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
        )

    def can_generate(self):
        return True

    def reset(self):
        self.decoder.reset()

def test_ort():
    model = ORTModelForWhisper()
    generated_ids = model.generate(input_features)
    model_output = processor.tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    print("ORT: ", model_output, generated_ids)

def test_original():
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    generated_ids = model.generate(input_features)
    model_output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("Torch: ", model_output, generated_ids)

test_ort()
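# Uncomment to also decode with the reference PyTorch model for comparison:
# test_original()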