import base64
import re
from typing import Any, Dict, List

import numpy as np
import onnxruntime
import pycantonese
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import whisper
from transformers import PreTrainedModel, PretrainedConfig, Wav2Vec2CTCTokenizer


def parse_jyutping(jyutping: str) -> str:
    """Helper to normalize a Jyutping syllable using pycantonese.

    A misplaced tone digit is moved to the end of the syllable, which is then
    re-assembled as onset + nucleus + coda + tone. On parse failure the
    original string is returned unchanged.
    """
    # Move a misplaced tone digit (1-6) to the end of the syllable.
    if jyutping and not jyutping[-1].isdigit():
        match = re.search(r"([1-6])", jyutping)
        if match:
            tone = match.group(1)
            jyutping = jyutping.replace(tone, "") + tone

    try:
        parsed_jyutping = pycantonese.parse_jyutping(jyutping)[0]
        onset = parsed_jyutping.onset if parsed_jyutping.onset else ""
        nucleus = parsed_jyutping.nucleus if parsed_jyutping.nucleus else ""
        coda = parsed_jyutping.coda if parsed_jyutping.coda else ""
        tone_val = str(parsed_jyutping.tone) if parsed_jyutping.tone else ""
        return "".join([onset, nucleus, coda, tone_val])
    except Exception as e:
        print(f"Failed to parse Jyutping '{jyutping}': {e}. Returning original.")
        return jyutping
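
# A minimal usage sketch (the syllables below are illustrative, not drawn from
# the training data): a well-formed syllable passes through unchanged, while a
# misplaced tone digit is moved to the end before parsing.
#   parse_jyutping("jyut6")  # -> "jyut6"
#   parse_jyutping("jyu6t")  # -> "jyut6"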


class CTCTransformerConfig(PretrainedConfig):
    """Configuration for the CTC transformer phoneme recognizer."""

    def __init__(
        self,
        vocab_size=100,
        num_labels=50,
        eos_token_id=2,
        bos_token_id=1,
        pad_token_id=0,
        blank_id=0,
        hidden_size=384,
        num_hidden_layers=50,
        num_attention_heads=4,
        intermediate_size=2048,
        dropout=0.1,
        max_position_embeddings=1024,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.dropout = dropout
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.pad_token_id = pad_token_id
        self.blank_id = blank_id
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity
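
# A minimal sketch of constructing the config by hand; the values shown are
# simply this file's defaults, so adjust vocab_size/num_labels to match your
# tokenizer and phoneme inventory.
#   config = CTCTransformerConfig(vocab_size=100, num_labels=50, hidden_size=384)
#   model = CTCTransformerModel(config)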


class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal positional embeddings for sequences."""

    def __init__(self, d_model=384, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout_rate)

    def encode(
        self,
        positions: torch.Tensor = None,
        depth: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        if depth is None:
            depth = self.d_model

        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device

        # Timescales form a geometric progression from 1 to 10000, as in the
        # standard sinusoidal encoding.
        depth_float = float(depth)
        log_timescale_increment = torch.log(
            torch.tensor([10000.0], dtype=dtype, device=device)
        ) / (depth_float / 2.0 - 1.0)
        inv_timescales = torch.exp(
            torch.arange(depth_float // 2, device=device, dtype=dtype)
            * (-log_timescale_increment)
        )

        # (batch * time, 1) x (1, depth // 2) -> (batch * time, depth // 2)
        pos_seq = positions.view(-1, 1)
        inv_timescales = inv_timescales.view(1, -1)
        scaled_time = pos_seq * inv_timescales

        sin_encodings = torch.sin(scaled_time)
        cos_encodings = torch.cos(scaled_time)

        # Interleave: sine on even feature indices, cosine on odd ones.
        pos_encodings = torch.zeros(
            positions.shape[0], positions.shape[1], depth, device=device, dtype=dtype
        )
        even_indices = torch.arange(0, depth, 2, device=device)
        odd_indices = torch.arange(1, depth, 2, device=device)
        pos_encodings[:, :, even_indices] = sin_encodings.view(
            batch_size, -1, depth // 2
        )
        pos_encodings[:, :, odd_indices] = cos_encodings.view(
            batch_size, -1, depth // 2
        )

        return pos_encodings

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()

        # 1-based positions, broadcast across the batch.
        positions = (
            torch.arange(1, timesteps + 1, device=x.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        position_encoding = self.encode(positions, input_dim, x.dtype)

        return self.dropout(x + position_encoding)
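
# A quick shape sanity check (a sketch, not part of the handler's runtime path):
#   enc = SinusoidalPositionEncoder(d_model=384)
#   out = enc(torch.zeros(2, 10, 384))  # -> torch.Size([2, 10, 384])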


class CTCTransformerModel(PreTrainedModel):
    """Transformer encoder over speech-token IDs with a CTC phoneme head."""

    config_class = CTCTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        # Extra embedding row at index vocab_size acts as the padding token.
        self.embed = nn.Embedding(
            config.vocab_size + 1,
            config.hidden_size,
            padding_idx=config.vocab_size,
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers
        )
        self.pos_embed = SinusoidalPositionEncoder(
            d_model=config.hidden_size, dropout_rate=config.dropout
        )
        self.norm = nn.LayerNorm(config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        x = self.embed(input_ids)
        x = self.norm(x)
        x = self.pos_embed(x)

        # Positions where attention_mask == 0 are padding for the encoder.
        if attention_mask is not None:
            src_key_padding_mask = attention_mask == 0
        else:
            src_key_padding_mask = None

        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        x = self.norm(x)

        logits = self.classifier(x)

        loss = None
        if labels is not None:
            if attention_mask is not None:
                input_lengths = attention_mask.sum(-1)
            else:
                input_lengths = torch.full(
                    (input_ids.size(0),),
                    input_ids.size(1),
                    dtype=torch.long,
                    device=input_ids.device,
                )

            # Label positions with values < 0 are treated as padding.
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss expects (time, batch, classes) log-probabilities.
            log_probs = nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.blank_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return {"loss": loss, "logits": logits}

    @torch.inference_mode()
    def predict(self, input_ids: torch.Tensor):
        """Greedy CTC decoding for a single (1, num_tokens) batch of speech tokens."""
        blank_id = self.config.blank_id

        attention_mask = torch.ones(
            input_ids.shape, dtype=torch.long, device=input_ids.device
        )

        # Mirror the training-time forward pass: embed -> norm -> pos -> encoder -> norm.
        x = self.embed(input_ids)
        x = self.norm(x)
        x = self.pos_embed(x)

        encoded = self.encoder(x, src_key_padding_mask=(attention_mask == 0))
        encoded = self.norm(encoded)
        logits = self.classifier(encoded)
        log_probs = F.log_softmax(logits, dim=-1)
        pred_ids = torch.argmax(log_probs, dim=-1).squeeze(0).tolist()

        # Standard greedy CTC decoding: collapse repeats, then drop blanks.
        pred_phoneme_ids = []
        prev = None

        for idx in pred_ids:
            if idx != blank_id and idx != prev:
                pred_phoneme_ids.append(idx)
            prev = idx

        return pred_phoneme_ids
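
# A minimal sketch of using the model on its own (the path is a placeholder and
# the speech-token tensor is random; the real pipeline below derives tokens
# from audio via the ONNX speech tokenizer):
#   model = CTCTransformerModel.from_pretrained("path/to/model_dir").eval()
#   speech_token = torch.randint(0, model.config.vocab_size, (1, 64))
#   phoneme_ids = model.predict(speech_token)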


def load_speech_tokenizer(speech_tokenizer_path: str):
    """Load speech tokenizer ONNX model."""
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    session = onnxruntime.InferenceSession(
        speech_tokenizer_path,
        sess_options=option,
        providers=["CPUExecutionProvider"],
    )
    return session


def extract_speech_token(audio, speech_tokenizer_session):
    """
    Extract speech tokens from audio using the speech tokenizer.

    Args:
        audio: audio signal (torch.Tensor or numpy.ndarray), shape (T,) at 16 kHz
        speech_tokenizer_session: ONNX speech tokenizer session

    Returns:
        speech_token: tensor of shape (1, num_tokens)
        speech_token_len: tensor of shape (1,) with the token sequence length
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().numpy()
    elif isinstance(audio, np.ndarray):
        pass
    else:
        raise ValueError("Audio must be torch.Tensor or numpy.ndarray")

    # (T,) -> (1, T) so the mel extractor sees a batch dimension.
    audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)

    # 128-bin log-mel features, matching the Whisper-style front end.
    feat = whisper.log_mel_spectrogram(audio_tensor, n_mels=128)

    # The ONNX model takes the mel features and their frame count and returns
    # a flat sequence of discrete speech-token IDs.
    onnx_inputs = speech_tokenizer_session.get_inputs()
    feat_np = feat.detach().cpu().numpy()
    feat_len = np.array([feat.shape[2]], dtype=np.int32)
    speech_token = (
        speech_tokenizer_session.run(
            None,
            {onnx_inputs[0].name: feat_np, onnx_inputs[1].name: feat_len},
        )[0]
        .flatten()
        .tolist()
    )

    speech_token = torch.tensor([speech_token], dtype=torch.int32)
    speech_token_len = torch.tensor([len(speech_token[0])], dtype=torch.int32)

    return speech_token, speech_token_len
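
# A minimal usage sketch (both file paths are placeholders for wherever the
# ONNX checkpoint and a test recording live on your machine):
#   session = load_speech_tokenizer("speech_tokenizer_v2.onnx")
#   audio, sr = torchaudio.load("sample.wav")  # expected mono, 16 kHz
#   tokens, token_len = extract_speech_token(audio.flatten(), session)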


class EndpointHandler:
    """Inference handler: base64-encoded WAV in, Jyutping transcription out."""

    def __init__(self, model_dir: str, **kwargs: Any):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.speech_tokenizer_session = load_speech_tokenizer(
            f"{model_dir}/speech_tokenizer_v2.onnx"
        )
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_dir)
        self.model = (
            CTCTransformerModel.from_pretrained(
                model_dir,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            .eval()
            .to(device)
        )

    def preprocess(self, inputs):
        """Load an audio file and return a flat 16 kHz numpy waveform."""
        waveform, original_sampling_rate = torchaudio.load(inputs)

        if original_sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sampling_rate, new_freq=16000
            )
            audio_array = resampler(waveform).numpy().flatten()
        else:
            audio_array = waveform.numpy().flatten()
        return audio_array

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        inputs = data.pop("inputs", data)

        # The request carries the audio as a base64-encoded WAV payload.
        audio = inputs["audio"]
        audio_bytes = base64.b64decode(audio)
        temp_wav_path = "/tmp/temp.wav"
        with open(temp_wav_path, "wb") as f:
            f.write(audio_bytes)

        audio_array = self.preprocess(temp_wav_path)

        # Audio -> discrete speech tokens -> greedy CTC phoneme IDs.
        speech_token, speech_token_len = extract_speech_token(
            audio_array, self.speech_tokenizer_session
        )

        with torch.no_grad():
            speech_token = speech_token.to(next(self.model.parameters()).device)
            outputs = self.model.predict(speech_token)

        transcription = self.tokenizer.decode(outputs, skip_special_tokens=True)
        print(transcription)
        # Normalize each decoded syllable into canonical Jyutping.
        transcription = " ".join(
            [parse_jyutping(jyt) for jyt in transcription.split(" ")]
        )

        return {"transcription": transcription}