```python
"""Interface for interacting with the wake-word model in real time."""
import queue
import threading
import time

import numpy as np
import pyaudio
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMForSequenceClassification


def int2float(sound):
    """Convert int16 PCM samples to normalized float32 in [-1, 1]."""
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1 / abs_max
    sound = sound.squeeze()  # depends on the use case
    return sound


class RealtimeDecoder:
    def __init__(self, model, feature_extractor) -> None:
        self.model = model
        self.feature_extractor = feature_extractor
        # Silero VAD is used to skip classification on non-speech chunks.
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                               model='silero_vad',
                                               force_reload=False,
                                               onnx=False)
        (self.get_speech_timestamps, _, _, _, _) = utils
        self.SAMPLE_RATE = 16000
        self.cache_output = {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }
        self.continue_recording = threading.Event()
        self.frame_duration_ms = 1000  # one audio chunk per second
        self.audio_queue = queue.SimpleQueue()
        self.speech_queue = queue.SimpleQueue()

    def start_recording(self, wait_enter_to_stop=True):
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()

        def record():
            audio = pyaudio.PyAudio()
            stream = audio.open(format=pyaudio.paInt16,
                                channels=1,
                                rate=self.SAMPLE_RATE,
                                input=True,
                                frames_per_buffer=int(self.SAMPLE_RATE / 10))
            while not self.continue_recording.is_set():
                # Read one chunk of frame_duration_ms worth of audio.
                audio_chunk = stream.read(int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0),
                                          exception_on_overflow=False)
                audio_int16 = np.frombuffer(audio_chunk, np.int16)
                audio_float32 = int2float(audio_int16)
                waveform = torch.from_numpy(audio_float32)
                self.audio_queue.put(waveform)
            print("Finished recording")
            stream.close()
            audio.terminate()

        if wait_enter_to_stop:
            stop_listener_thread = threading.Thread(target=stop, daemon=False)
        else:
            stop_listener_thread = None
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        self.cache_output = {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }

    def start_decoding(self):
        def decode():
            while not self.continue_recording.is_set():
                if self.audio_queue.qsize() > 0:
                    current_waveform = self.audio_queue.get()
                    if current_waveform is not None:
                        # Keep a rolling buffer of the last 4 one-second chunks.
                        self.cache_output['wavchunks'].append(current_waveform)
                        self.cache_output['wavchunks'] = self.cache_output['wavchunks'][-4:]
                        if len(self.cache_output['wavchunks']) > 1:
                            # Classify a two-second window built from the last 2 chunks.
                            waveform = torch.cat(self.cache_output['wavchunks'][-2:], dim=-1)
                            speech_timestamps = self.get_speech_timestamps(waveform,
                                                                           self.vad_model,
                                                                           sampling_rate=self.SAMPLE_RATE)
                            logits = [1, 0]  # default: no wake word
                            if len(speech_timestamps) > 0:
                                # Only run the wake-word model when the VAD detects speech.
                                input_features = self.feature_extractor.pad([{"input_values": waveform}],
                                                                            padding=True,
                                                                            return_tensors="pt")
                                logits = self.model(**input_features).logits.softmax(dim=-1).squeeze()
                            if logits[1] > 0.6:
                                print("hey armar", logits, waveform.size(-1) / self.SAMPLE_RATE)
                                self.cache_output['wavchunks'] = []
                            else:
                                print('.' + '.' * self.audio_queue.qsize())
                else:
                    time.sleep(0.01)
            print("KWS thread finished")

        kws_decode_thread = threading.Thread(target=decode, daemon=False)
        return kws_decode_thread


if __name__ == "__main__":
    print("Model loading....")
    kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
    feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')
    print("Model loaded....")

    # Offline sanity check on a wav file:
    # file_wave = './99.wav'
    # wav, rate = torchaudio.load(file_wave)
    # input_features = feature_extractor.pad([{"input_values": item} for item in wav],
    #                                        padding=True, return_tensors="pt")
    # output = kws_model(**input_features)
    # print(output.logits.softmax(dim=-1))

    obj_decode = RealtimeDecoder(kws_model, feature_extractor)
    recording_threads = obj_decode.start_recording()
    kws_decode_thread = obj_decode.start_decoding()
    for thread in recording_threads:
        if thread is not None:
            thread.start()
    kws_decode_thread.start()
    for thread in recording_threads:
        if thread is not None:
            thread.join()
    kws_decode_thread.join()
```
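For a quick sanity check without a microphone, the commented-out lines at the bottom of the script can be run on their own against a recorded clip. A minimal standalone sketch is below; the path `./99.wav` is taken from the original comments and is only illustrative, and the file is assumed to be a 16 kHz mono recording.

```python
# Offline wake-word check on a single wav file (standalone version of the
# commented-out snippet in the script; './99.wav' is an illustrative path).
import torchaudio
from transformers import AutoFeatureExtractor, WavLMForSequenceClassification

kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')

wav, rate = torchaudio.load('./99.wav')  # expected: 16 kHz mono audio
input_features = feature_extractor.pad([{"input_values": item} for item in wav],
                                       padding=True, return_tensors="pt")
probs = kws_model(**input_features).logits.softmax(dim=-1)
print(probs)  # column 1 is the wake-word ("hey armar") probability
```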