import argparse
import json
import re
import time
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from whisper.tokenizer import get_tokenizer
from whisper_live.whisper_utils import (load_audio, load_audio_wav_format,
                                        mel_filters, pad_or_trim,
                                        store_transcripts, write_error_stats)

import tensorrt_llm
import tensorrt_llm.logger as logger
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from tensorrt_llm.runtime.session import Session, TensorInfo
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
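# With HOP_LENGTH = 160, each 30-second chunk maps to
# N_SAMPLES // HOP_LENGTH = 3000 mel frames, the input length the Whisper
# encoder expects.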


class WhisperEncoding:

    def __init__(self, engine_dir):
        self.session = self.get_session(engine_dir)

    def get_session(self, engine_dir):
        config_path = engine_dir / 'encoder_config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)

        use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
        dtype = config['builder_config']['precision']
        n_mels = config['builder_config']['n_mels']
        num_languages = config['builder_config']['num_languages']

        self.dtype = dtype
        self.n_mels = n_mels
        self.num_languages = num_languages

        serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine'
        with open(serialize_path, 'rb') as f:
            session = Session.from_serialized_engine(f.read())
        return session

    def get_audio_features(self, mel):
        inputs = OrderedDict()
        output_list = []

        inputs.update({'x': mel})
        output_list.append(
            TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape))

        output_info = self.session.infer_shapes(output_list)
        logger.debug(f'output info {output_info}')
        outputs = {
            t.name: torch.empty(tuple(t.shape),
                                dtype=trt_dtype_to_torch(t.dtype),
                                device='cuda')
            for t in output_info
        }
        stream = torch.cuda.current_stream()
        ok = self.session.run(inputs=inputs,
                              outputs=outputs,
                              stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        stream.synchronize()
        audio_features = outputs['output']
        return audio_features
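
# Note: the Whisper encoder downsamples the time axis by 2 (its second conv
# layer has stride 2), so a 30-second window of 3000 mel frames comes out as
# 1500 encoder frames; WhisperDecoding.generate uses encoder_outputs.shape[1]
# as that per-sample encoder input length.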


class WhisperDecoding:

    def __init__(self, engine_dir, runtime_mapping, debug_mode=False):
        self.decoder_config = self.get_config(engine_dir)
        self.decoder_generation_session = self.get_session(
            engine_dir, runtime_mapping, debug_mode)

    def get_config(self, engine_dir):
        config_path = engine_dir / 'decoder_config.json'
        with open(config_path, 'r') as f:
            config = json.load(f)
        decoder_config = OrderedDict()
        decoder_config.update(config['plugin_config'])
        decoder_config.update(config['builder_config'])
        return decoder_config

    def get_session(self, engine_dir, runtime_mapping, debug_mode=False):
        dtype = self.decoder_config['precision']
        serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine'
        with open(serialize_path, "rb") as f:
            decoder_engine_buffer = f.read()

        decoder_model_config = ModelConfig(
            num_heads=self.decoder_config['num_heads'],
            num_kv_heads=self.decoder_config['num_heads'],
            hidden_size=self.decoder_config['hidden_size'],
            vocab_size=self.decoder_config['vocab_size'],
            num_layers=self.decoder_config['num_layers'],
            gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'],
            remove_input_padding=self.decoder_config['remove_input_padding'],
            cross_attention=self.decoder_config['cross_attention'],
            has_position_embedding=self.decoder_config['has_position_embedding'],
            has_token_type_embedding=self.decoder_config['has_token_type_embedding'],
        )

        decoder_generation_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode)
        return decoder_generation_session

    def generate(self,
                 decoder_input_ids,
                 encoder_outputs,
                 eot_id,
                 max_new_tokens=40,
                 num_beams=1):
        encoder_input_lengths = torch.tensor(
            [encoder_outputs.shape[1] for _ in range(encoder_outputs.shape[0])],
            dtype=torch.int32,
            device='cuda')
        decoder_input_lengths = torch.tensor(
            [decoder_input_ids.shape[-1] for _ in range(decoder_input_ids.shape[0])],
            dtype=torch.int32,
            device='cuda')
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        # generation config
        sampling_config = SamplingConfig(end_id=eot_id,
                                         pad_id=eot_id,
                                         num_beams=num_beams)

        self.decoder_generation_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            beam_width=num_beams,
            encoder_max_input_length=encoder_outputs.shape[1])

        torch.cuda.synchronize()

        decoder_input_ids = decoder_input_ids.type(torch.int32).cuda()
        output_ids = self.decoder_generation_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_outputs,
            encoder_input_lengths=encoder_input_lengths,
        )
        torch.cuda.synchronize()

        # get the list of int from output_ids tensor
        output_ids = output_ids.cpu().numpy().tolist()
        return output_ids


class WhisperTRTLLM(object):

    def __init__(self,
                 engine_dir,
                 debug_mode=False,
                 assets_dir=None,
                 device=None):
        world_size = 1
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)
        engine_dir = Path(engine_dir)

        self.encoder = WhisperEncoding(engine_dir)
        self.decoder = WhisperDecoding(engine_dir,
                                       runtime_mapping,
                                       debug_mode=debug_mode)
        self.n_mels = self.encoder.n_mels
        self.device = device
        # English-only tokenizer; for a multilingual engine, pass
        # multilingual=True and num_languages=self.encoder.num_languages instead.
        self.tokenizer = get_tokenizer(
            False,
            language="en",
            task="transcribe",
        )
        self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir)
    def log_mel_spectrogram(self,
                            audio: Union[str, np.ndarray, torch.Tensor],
                            padding: int = 0,
                            return_duration: bool = True):
        """
        Compute the log-Mel spectrogram of the input audio.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
            The path to an audio file, or a NumPy array / Tensor containing the
            audio waveform sampled at 16 kHz
        padding: int
            Number of zero samples to pad on the right
        return_duration: bool
            If True, also return the audio duration in seconds

        Returns
        -------
        torch.Tensor, shape = (n_mels, n_frames)
            A Tensor that contains the log-Mel spectrogram, where n_mels is
            80 or 128 depending on the engine
        """
        if not torch.is_tensor(audio):
            if isinstance(audio, str):
                if audio.endswith('.wav'):
                    audio, _ = load_audio_wav_format(audio)
                else:
                    audio = load_audio(audio)
            assert isinstance(
                audio, np.ndarray), f"Unsupported audio type: {type(audio)}"
            duration = audio.shape[-1] / SAMPLE_RATE
            audio = pad_or_trim(audio, N_SAMPLES)
            audio = audio.astype(np.float32)
            audio = torch.from_numpy(audio)

        if self.device is not None:
            audio = audio.to(self.device)
        if padding > 0:
            audio = F.pad(audio, (0, padding))
        window = torch.hann_window(N_FFT).to(audio.device)
        stft = torch.stft(audio,
                          N_FFT,
                          HOP_LENGTH,
                          window=window,
                          return_complex=True)
        magnitudes = stft[..., :-1].abs()**2

        mel_spec = self.filters @ magnitudes

        # Same normalization as openai-whisper: clamp to within 8 of the max
        # (in log10 space) and rescale to roughly [-1, 1].
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0

        if return_duration:
            return log_spec, duration
        else:
            return log_spec
    def process_batch(
            self,
            mel,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            num_beams=1):
        prompt_id = self.tokenizer.encode(
            text_prefix,
            allowed_special=set(self.tokenizer.special_tokens.keys()))
        prompt_id = torch.tensor(prompt_id)
        batch_size = mel.shape[0]
        decoder_input_ids = prompt_id.repeat(batch_size, 1)

        encoder_output = self.encoder.get_audio_features(mel)
        output_ids = self.decoder.generate(decoder_input_ids,
                                           encoder_output,
                                           self.tokenizer.eot,
                                           max_new_tokens=96,
                                           num_beams=num_beams)
        texts = []
        for i in range(len(output_ids)):
            text = self.tokenizer.decode(output_ids[i][0]).strip()
            texts.append(text)
        return texts
    def transcribe(
            self,
            mel,
            text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
            dtype='float16',
            batch_size=1,
            num_beams=1):
        mel = mel.type(str_dtype_to_torch(dtype))
        mel = mel.unsqueeze(0)
        predictions = self.process_batch(mel, text_prefix, num_beams)
        prediction = predictions[0]

        # remove all special tokens in the prediction
        prediction = re.sub(r'<\|.*?\|>', '', prediction)
        return prediction.strip()


def decode_wav_file(
        model,
        mel,
        text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>",
        dtype='float16',
        batch_size=1,
        num_beams=1,
        normalizer=None,
        mel_filters_dir=None):
    mel = mel.type(str_dtype_to_torch(dtype))
    mel = mel.unsqueeze(0)
    # repeat the mel spectrogram to match the batch size
    mel = mel.repeat(batch_size, 1, 1)
    predictions = model.process_batch(mel, text_prefix, num_beams)
    prediction = predictions[0]

    # remove all special tokens in the prediction
    prediction = re.sub(r'<\|.*?\|>', '', prediction)
    if normalizer:
        prediction = normalizer(prediction)
    return prediction.strip()


if __name__ == "__main__":
    tensorrt_llm.logger.set_level("error")
    model = WhisperTRTLLM("/root/TensorRT-LLM/examples/whisper/whisper_small_en",
                          False,
                          "../assets",
                          device="cuda")
    mel, total_duration = model.log_mel_spectrogram(
        "../assets/1221-135766-0002.wav")
    results = model.transcribe(mel)
    print(results, total_duration)
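
    # Minimal sketch of the batched path through decode_wav_file, assuming
    # openai-whisper's normalizers module is installed alongside
    # whisper.tokenizer; drop the normalizer argument if it is not available.
    from whisper.normalizers import EnglishTextNormalizer
    normalized = decode_wav_file(model,
                                 mel,
                                 batch_size=2,
                                 normalizer=EnglishTextNormalizer())
    print(normalized)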