#!/usr/bin/python3
# -*- coding: utf-8 -*-
import wave
from typing import Tuple, Union

import numpy as np
import sherpa
import sherpa_onnx
import torch
import torchaudio


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Read a single-channel, 16-bit PCM wave file.

    :param wave_filename:
      Path to a wave file. It should be single channel and each sample should
      be 16-bit. Its sample rate does not need to be 16kHz.
    :return:
      Return a tuple containing:
       - signal: A 1-D array of dtype np.float32 containing the samples,
         which are normalized to the range [-1, 1].
       - sample_rate: sample rate of the wave file
    :raises AssertionError:
      If the file is not single channel or its samples are not 16-bit. The
      check is an explicit ``raise`` rather than ``assert`` so it still runs
      under ``python -O``.
    """
    with wave.open(wave_filename) as f:
        num_channels = f.getnchannels()
        if num_channels != 1:
            raise AssertionError(num_channels)
        sample_width = f.getsampwidth()
        if sample_width != 2:
            raise AssertionError(sample_width)
        sample_rate = f.getframerate()
        num_samples = f.getnframes()
        samples = f.readframes(num_samples)

    # WAV PCM data is little-endian; an explicit "<i2" dtype keeps decoding
    # correct on big-endian hosts as well.
    samples_int16 = np.frombuffer(samples, dtype="<i2")
    # Scale int16 into [-1, 1).
    samples_float32 = samples_int16.astype(np.float32) / 32768
    return samples_float32, sample_rate


def decode_offline_recognizer(recognizer: sherpa.OfflineRecognizer,
                              filename: str,
                              ) -> str:
    """Decode a wave file with a ``sherpa`` offline recognizer.

    The file is handed to the stream via ``accept_wave_file``; the decoded
    text is printed and returned stripped and lower-cased.

    :param recognizer: A ``sherpa.OfflineRecognizer`` instance.
    :param filename: Path to the audio file to decode.
    :return: The recognized text, stripped and lower-cased.
    """
    s = recognizer.create_stream()
    s.accept_wave_file(filename)
    recognizer.decode_stream(s)
    text = s.result.text.strip()
    print("text: {}".format(text))
    return text.lower()


def decode_online_recognizer(recognizer: sherpa.OnlineRecognizer,
                             filename: str,
                             expected_sample_rate: int = 16000,
                             tail_padding_seconds: float = 0.3,
                             ) -> str:
    """Decode an audio file with a ``sherpa`` streaming (online) recognizer.

    :param recognizer: A ``sherpa.OnlineRecognizer`` instance.
    :param filename: Path to the audio file; it is loaded with ``torchaudio``.
    :param expected_sample_rate:
      Sample rate the file must have; nothing is resampled.
    :param tail_padding_seconds:
      Seconds of zeros appended after the audio before the input is marked
      finished (presumably so the model can emit its last tokens — confirm
      against the sherpa docs).
    :return: The recognized text, stripped and lower-cased.
    :raises AssertionError:
      If the file's sample rate differs from ``expected_sample_rate``.
    """
    samples, actual_sample_rate = torchaudio.load(filename)
    if expected_sample_rate != actual_sample_rate:
        raise AssertionError(
            "expected sample rate: {}, but actually: {}".format(
                expected_sample_rate, actual_sample_rate
            )
        )
    # torchaudio returns (channels, frames); only the first channel is decoded.
    samples = samples[0].contiguous()

    s = recognizer.create_stream()
    tail_padding = torch.zeros(
        int(expected_sample_rate * tail_padding_seconds), dtype=torch.float32
    )
    s.accept_waveform(expected_sample_rate, samples)
    s.accept_waveform(expected_sample_rate, tail_padding)
    s.input_finished()

    # Keep decoding while the recognizer reports the stream has pending frames.
    while recognizer.is_ready(s):
        recognizer.decode_stream(s)

    text = recognizer.get_result(s).text
    return text.strip().lower()


def decode_offline_recognizer_sherpa_onnx(recognizer: sherpa_onnx.OfflineRecognizer,
                                          filename: str,
                                          ) -> str:
    """Decode a wave file with a ``sherpa_onnx`` offline recognizer.

    :param recognizer: A ``sherpa_onnx.OfflineRecognizer`` instance.
    :param filename: Path to a mono 16-bit wave file (see ``read_wave``).
    :return:
      The recognized text, stripped and lower-cased (stripped for consistency
      with the other decoders in this module).
    """
    s = recognizer.create_stream()
    samples, sample_rate = read_wave(filename)
    s.accept_waveform(sample_rate, samples)
    recognizer.decode_stream(s)
    return s.result.text.strip().lower()


def decode_online_recognizer_sherpa_onnx(recognizer: sherpa_onnx.OnlineRecognizer,
                                         filename: str,
                                         tail_padding_seconds: float = 0.3,
                                         ) -> str:
    """Decode a wave file with a ``sherpa_onnx`` streaming (online) recognizer.

    :param recognizer: A ``sherpa_onnx.OnlineRecognizer`` instance.
    :param filename: Path to a mono 16-bit wave file (see ``read_wave``).
    :param tail_padding_seconds:
      Seconds of zeros appended after the audio before the input is marked
      finished.
    :return: The recognized text, stripped and lower-cased.
    """
    s = recognizer.create_stream()
    samples, sample_rate = read_wave(filename)
    s.accept_waveform(sample_rate, samples)

    tail_paddings = np.zeros(int(tail_padding_seconds * sample_rate), dtype=np.float32)
    s.accept_waveform(sample_rate, tail_paddings)
    s.input_finished()

    while recognizer.is_ready(s):
        recognizer.decode_stream(s)

    # NOTE(review): assumes sherpa_onnx's OnlineRecognizer.get_result returns a
    # plain str (no `.text` attribute, unlike the sherpa variant) — confirm
    # against the installed sherpa_onnx version.
    return recognizer.get_result(s).strip().lower()


def decode_by_recognizer(
        recognizer: Union[
            sherpa.OfflineRecognizer,
            sherpa.OnlineRecognizer,
            sherpa_onnx.OfflineRecognizer,
            sherpa_onnx.OnlineRecognizer,
        ],
        filename: str,
) -> str:
    """Decode ``filename`` with the helper matching ``recognizer``'s type.

    :param recognizer: A sherpa or sherpa_onnx offline/online recognizer.
    :param filename: Path to the audio file to decode.
    :return: The recognized text as returned by the dispatched helper.
    :raises ValueError: If ``recognizer`` is none of the supported types.
    """
    if isinstance(recognizer, sherpa.OfflineRecognizer):
        return decode_offline_recognizer(recognizer, filename)
    elif isinstance(recognizer, sherpa.OnlineRecognizer):
        return decode_online_recognizer(recognizer, filename)
    elif isinstance(recognizer, sherpa_onnx.OfflineRecognizer):
        return decode_offline_recognizer_sherpa_onnx(recognizer, filename)
    elif isinstance(recognizer, sherpa_onnx.OnlineRecognizer):
        return decode_online_recognizer_sherpa_onnx(recognizer, filename)
    else:
        raise ValueError(f"Unknown recognizer type {type(recognizer)}")


if __name__ == "__main__":
    pass