"""
Real-time speech recognition (ASR) from the microphone, with an optional
offline TTS mode, built on sherpa-onnx.
"""

import argparse
import asyncio
import logging
import os
import struct
import time

import sherpa_onnx
import soundfile

try:
    import pyaudio
except ImportError:
    raise ImportError('Please install pyaudio with `pip install pyaudio`')

logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000

pactx = pyaudio.PyAudio()
models_root: str | None = None  # resolved in the __main__ block below
num_threads: int = 1


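# Builds a streaming (online) bilingual zh-en zipformer transducer recognizer.
# Endpointing: rule1 fires after 2.4 s of trailing silence, rule2 after 1.2 s
# once some text has been decoded, rule3 caps an utterance at 20 s.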
def create_zipformer(args) -> sherpa_onnx.OnlineRecognizer:
    d = os.path.join(
        models_root, 'sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20')
    encoder = os.path.join(d, "encoder-epoch-99-avg-1.onnx")
    decoder = os.path.join(d, "decoder-epoch-99-avg-1.onnx")
    joiner = os.path.join(d, "joiner-epoch-99-avg-1.onnx")
    tokens = os.path.join(d, "tokens.txt")

    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
        provider=args.provider,
        num_threads=num_threads,
        sample_rate=SAMPLE_RATE,
        feature_dim=80,
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,
        rule2_min_trailing_silence=1.2,
        rule3_min_utterance_length=20,
    )
    return recognizer


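# Builds a non-streaming (offline) SenseVoice recognizer; use_itn=True enables
# inverse text normalization (e.g. spoken numbers rendered as digits).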
def create_sensevoice(args) -> sherpa_onnx.OfflineRecognizer:
    model = os.path.join(
        models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'model.onnx')
    tokens = os.path.join(
        models_root, 'sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17', 'tokens.txt')
    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=model,
        tokens=tokens,
        num_threads=num_threads,
        use_itn=True,
        debug=0,
        language=args.lang,
    )
    return recognizer


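# Streaming decode loop: feeds microphone samples to the online recognizer as
# they arrive, logs changing partial results, and resets the stream at each
# detected endpoint.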
async def run_online(buf, recognizer):
    stream = recognizer.create_stream()
    last_result = ""
    segment_id = 0
    logger.info('Start real-time recognizer')
    while True:
        samples = await buf.get()
        stream.accept_waveform(SAMPLE_RATE, samples)
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)

        is_endpoint = recognizer.is_endpoint(stream)
        result = recognizer.get_result(stream)

        if result and (last_result != result):
            last_result = result
            logger.info(f' > {segment_id}:{result}')

        if is_endpoint:
            if result:
                logger.info(f'{segment_id}: {result}')
                segment_id += 1
            recognizer.reset(stream)


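# Offline decode loop: silero VAD buffers microphone audio and emits complete
# speech segments, each of which is decoded in one shot by the offline model.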
async def run_offline(buf, recognizer):
    config = sherpa_onnx.VadModelConfig()
    config.silero_vad.model = os.path.join(
        models_root, 'silero_vad', 'silero_vad.onnx')
    config.silero_vad.min_silence_duration = 0.25
    config.sample_rate = SAMPLE_RATE
    vad = sherpa_onnx.VoiceActivityDetector(
        config, buffer_size_in_seconds=100)

    logger.info('Start offline recognizer with VAD')
    texts = []
    while True:
        samples = await buf.get()
        vad.accept_waveform(samples)
        while not vad.empty():
            stream = recognizer.create_stream()
            stream.accept_waveform(SAMPLE_RATE, vad.front.samples)

            vad.pop()
            recognizer.decode_stream(stream)

            text = stream.result.text.strip().lower()
            if len(text):
                idx = len(texts)
                texts.append(text)
                logger.info(f"{idx}: {text}")


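# Dispatch for the 'asr' sub-command: picks the recognizer, then runs the
# recorder and the decode loop concurrently, connected by an asyncio.Queue.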
async def handle_asr(args):
    if args.model == 'zipformer':
        recognizer = create_zipformer(args)
        action_func = run_online
    elif args.model == 'sensevoice':
        recognizer = create_sensevoice(args)
        action_func = run_offline
    else:
        raise ValueError(f'Unknown model: {args.model}')
    buf = asyncio.Queue()
    recorder_task = asyncio.create_task(run_record(buf))
    asr_task = asyncio.create_task(action_func(buf, recognizer))
    await asyncio.gather(asr_task, recorder_task)


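# Dispatch for the 'tts' sub-command: synthesizes args.text with the VITS
# MeloTTS zh_en model, reports the real-time factor, and writes a WAV file.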
async def handle_tts(args):
    model = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'model.onnx')
    lexicon = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'lexicon.txt')
    dict_dir = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'dict')
    tokens = os.path.join(
        models_root, 'vits-melo-tts-zh_en', 'tokens.txt')
    tts_config = sherpa_onnx.OfflineTtsConfig(
        model=sherpa_onnx.OfflineTtsModelConfig(
            vits=sherpa_onnx.OfflineTtsVitsModelConfig(
                model=model,
                lexicon=lexicon,
                dict_dir=dict_dir,
                tokens=tokens,
            ),
            provider=args.provider,
            debug=0,
            num_threads=num_threads,
        ),
        max_num_sentences=args.max_num_sentences,
    )
    if not tts_config.validate():
        raise ValueError("Please check your config")

    tts = sherpa_onnx.OfflineTts(tts_config)

    start = time.time()
    audio = tts.generate(args.text, sid=args.sid, speed=args.speed)
    elapsed_seconds = time.time() - start
    audio_duration = len(audio.samples) / audio.sample_rate
    real_time_factor = elapsed_seconds / audio_duration

    if args.output:
        soundfile.write(
            args.output,
            audio.samples,
            samplerate=audio.sample_rate,
            subtype="PCM_16",
        )
        logger.info(f"Saved to {args.output}")

    logger.info(f"The text is '{args.text}'")
    logger.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
    logger.info(f"Audio duration in seconds: {audio_duration:.3f}")
    logger.info(
        f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")


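# Microphone capture: PyAudio invokes on_input() on its own thread for every
# 320-frame (20 ms at 16 kHz) buffer of int16 samples.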
async def run_record(buf: asyncio.Queue[list[float]]):
    loop = asyncio.get_running_loop()

    def on_input(in_data, frame_count, time_info, status):
        # Convert little-endian int16 PCM to floats in [-1, 1).
        samples = [
            v / 32768.0 for v in struct.unpack('<' + 'h' * frame_count, in_data)]
        # This callback runs on PyAudio's thread, so hand the samples to the
        # event loop thread-safely rather than calling loop.create_task() here.
        asyncio.run_coroutine_threadsafe(buf.put(samples), loop)
        return (None, pyaudio.paContinue)

    frame_size = 320
    recorder = pactx.open(format=pyaudio.paInt16, channels=1,
                          rate=SAMPLE_RATE, input=True,
                          frames_per_buffer=frame_size,
                          stream_callback=on_input)
    recorder.start_stream()
    logger.info('Start recording')

    while recorder.is_active():
        await asyncio.sleep(0.1)


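# CLI entry point: a top-level --provider flag plus 'asr' and 'tts'
# sub-commands, each of which sets args.func to its handler coroutine.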
async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--provider', default='cpu',
                        help='onnxruntime provider, default is cpu; use cuda for GPU')

    subparsers = parser.add_subparsers(help='sub-command help')

    asr_parser = subparsers.add_parser('asr', help='run asr mode')
    asr_parser.add_argument('--model', default='zipformer',
                            choices=['zipformer', 'sensevoice'],
                            help='model name, default is zipformer')
    asr_parser.add_argument('--lang', default='zh',
                            help='language, default is zh')
    asr_parser.set_defaults(func=handle_asr)

    tts_parser = subparsers.add_parser('tts', help='run tts mode')
    tts_parser.add_argument('--sid', type=int, default=0,
                            help="""Speaker ID. Used only for multi-speaker models,
        e.g. models trained using the VCTK dataset. Not used for
        single-speaker models, e.g., models trained using the LJ speech
        dataset.""")
    tts_parser.add_argument('--output', type=str, default='output.wav',
                            help='output file name, default is output.wav')
    tts_parser.add_argument(
        "--speed",
        type=float,
        default=1.0,
        help="Speech speed. Larger -> faster; smaller -> slower",
    )
    tts_parser.add_argument(
        "--max-num-sentences",
        type=int,
        default=2,
        help="""Max number of sentences in a batch, to avoid OOM if the input
        text is very long. Set it to -1 to process all the sentences in a
        single batch. A smaller value does not mean it is slower compared
        to a larger one on CPU.""",
    )
    tts_parser.add_argument(
        "text",
        type=str,
        help="The input text to generate audio for",
    )
    tts_parser.set_defaults(func=handle_tts)

    args = parser.parse_args()

    if hasattr(args, 'func'):
        await args.func(args)
    else:
        parser.print_help()


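# Startup: verify a default input device exists, locate the models directory
# relative to the working directory, then run the asyncio entry point.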
if __name__ == '__main__':
    logging.basicConfig(
        format='%(levelname)s: %(asctime)s %(name)s:%(lineno)s %(message)s',
        level=logging.INFO)
    painfo = pactx.get_default_input_device_info()
    assert painfo['maxInputChannels'] >= 1, 'No input device'
    logger.info('Default input device: %s', painfo['name'])

    for d in ['.', '..', '../..']:
        if os.path.isdir(f'{d}/models'):
            models_root = f'{d}/models'
            break
    assert models_root, 'Could not find models directory'
    asyncio.run(main())