Edit model card
YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)
"""the interface to interact with wakeword model"""
import pyaudio
import threading
import time
import torchaudio
import torch
import numpy as np
import queue
from transformers import WavLMForSequenceClassification
from transformers import AutoFeatureExtractor


def int2float(sound):
    """Convert integer PCM samples to float32, peak-normalised to [-1, 1].

    Args:
        sound: array-like of samples (the recorder in this file passes an
            int16 NumPy array decoded from the microphone buffer).

    Returns:
        A squeezed float32 array scaled so its largest magnitude is 1.0.
        All-zero input is returned unscaled; empty input yields an empty
        float32 array instead of raising.
    """
    # astype always copies, so the in-place division below never mutates
    # the caller's buffer.
    sound = np.asarray(sound).astype('float32')
    # Take the magnitude AFTER converting to float: np.abs on an int16 array
    # maps -32768 to itself (integer overflow), which under-estimates the peak.
    abs_max = np.abs(sound).max() if sound.size else 0.0
    if abs_max > 0:
        sound /= abs_max
    return sound.squeeze()  # depends on the use case

class RealtimeDecoder():
    """Real-time wake-word ("hey armar") detector fed from the microphone.

    A recording thread reads 16 kHz mono int16 audio in ``frame_duration_ms``
    chunks, normalises each chunk with ``int2float`` and pushes it onto
    ``audio_queue``. A decoding thread keeps the last few chunks, runs Silero
    VAD on the two most recent ones and, only when speech is found, scores
    that window with ``model``. Both threads loop until ``continue_recording``
    is set (despite its name, setting the event means *stop*).

    Args:
        model: classifier called as ``model(**features)``; its ``.logits`` are
            softmaxed and index 1 is treated as the wake-word class.
        feature_extractor: object whose ``pad(...)`` turns raw waveforms into
            model inputs. When ``None`` the module-level ``feature_extractor``
            global is looked up at decode time (backward compatibility with
            scripts that define it globally).
        threshold: wake-word probability above which a detection is reported.
    """

    def __init__(self,
        model,
        feature_extractor=None,
        threshold=0.6,
    ) -> None:
        self.model = model
        self.feature_extractor = feature_extractor
        self.threshold = threshold
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                              model='silero_vad',
                              force_reload=False,
                              onnx=False)

        # Only get_speech_timestamps is used from the silero-vad utils tuple.
        (self.get_speech_timestamps, _, _, _, _) = utils
        self.SAMPLE_RATE = 16000
        self.cache_output = self._empty_cache()
        # Setting this event tells every worker thread to stop.
        self.continue_recording = threading.Event()
        self.frame_duration_ms = 1000
        self.audio_queue = queue.SimpleQueue()
        # NOTE(review): not read or written anywhere else in this class.
        self.speech_queue = queue.SimpleQueue()

    @staticmethod
    def _empty_cache():
        """Return a fresh decoding cache: an empty tensor and no buffered chunks."""
        return {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }

    def start_recording(self, wait_enter_to_stop=True):
        """Create (but do not start) the stop-listener and recording threads.

        Args:
            wait_enter_to_stop: when True, also build a thread that sets the
                stop event once the user presses Enter.

        Returns:
            ``(stop_listener_thread, recording_thread)``; the first element is
            ``None`` when ``wait_enter_to_stop`` is False.
        """
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()

        def record():
            audio = pyaudio.PyAudio()
            try:
                stream = audio.open(format=pyaudio.paInt16,
                        channels=1,
                        rate=self.SAMPLE_RATE,
                        input=True,
                        frames_per_buffer=int(self.SAMPLE_RATE / 10))
                try:
                    while not self.continue_recording.is_set():
                        audio_chunk = stream.read(int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0), exception_on_overflow=False)
                        audio_int16 = np.frombuffer(audio_chunk, np.int16)
                        waveform = torch.from_numpy(int2float(audio_int16))
                        self.audio_queue.put(waveform)
                    print("Finish record")
                finally:
                    # Close the stream even if read() raised.
                    stream.stop_stream()
                    stream.close()
            finally:
                # Release the PortAudio resources held by this PyAudio instance.
                audio.terminate()

        if wait_enter_to_stop:
            stop_listener_thread = threading.Thread(target=stop, daemon=False)
        else:
            stop_listener_thread = None
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        """Reset the decoding cache, dropping any buffered audio chunks."""
        self.cache_output = self._empty_cache()

    def _score(self, wavform):
        """Return class probabilities for ``wavform``, or ``[1, 0]`` without speech.

        The model only runs when Silero VAD reports at least one speech segment
        in the window.

        Raises:
            RuntimeError: speech was found but no feature extractor is available
                (neither passed to the constructor nor defined as a global).
        """
        speech_timestamps = self.get_speech_timestamps(wavform, self.vad_model, sampling_rate=self.SAMPLE_RATE)
        if not speech_timestamps:
            return [1, 0]
        extractor = self.feature_extractor
        if extractor is None:
            extractor = globals().get("feature_extractor")
            if extractor is None:
                raise RuntimeError(
                    "RealtimeDecoder needs a feature_extractor: pass one to the "
                    "constructor or define a module-level `feature_extractor`."
                )
        input_features = extractor.pad([{"input_values": wavform}], padding=True, return_tensors="pt")
        # Inference only: avoid building an autograd graph for every window.
        with torch.no_grad():
            return self.model(**input_features).logits.softmax(dim=-1).squeeze()

    def start_decoding(self):
        """Create (but do not start) the wake-word decoding thread.

        The thread keeps the last 4 chunks and scores the 2 most recent ones.
        When the wake-word probability exceeds ``self.threshold`` it prints a
        detection and clears the buffer. Otherwise it prints a row of dots
        whose length grows with the queue backlog.
        """
        def decode():
            while not self.continue_recording.is_set():
                try:
                    current_waveform = self.audio_queue.get(timeout=0.01)
                except queue.Empty:
                    continue
                if current_waveform is not None:
                    self.cache_output['wavchunks'].append(current_waveform)
                    self.cache_output['wavchunks'] = self.cache_output['wavchunks'][-4:]

                if len(self.cache_output['wavchunks']) > 1:
                    wavform = torch.cat(self.cache_output['wavchunks'][-2:], dim=-1)
                    logits = self._score(wavform)
                    if logits[1] > self.threshold:
                        print("hey armar", logits, wavform.size(-1) / self.SAMPLE_RATE)
                        self.cache_output['wavchunks'] = []
                    else:
                        print('.'+'.'*self.audio_queue.qsize())
            print("KWS thread finish")
        kws_decode_thread = threading.Thread(target=decode, daemon=False)
        return kws_decode_thread

if __name__ == "__main__":
    print("Model loading....")

    checkpoint = 'nguyenvulebinh/heyarmar'
    kws_model = WavLMForSequenceClassification.from_pretrained(checkpoint)
    # Kept at module level on purpose: RealtimeDecoder's decoding thread reads
    # `feature_extractor` from the module globals when none is passed to it.
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)

    print("Model loaded....")

    obj_decode = RealtimeDecoder(kws_model)
    stop_and_record = obj_decode.start_recording()
    decoder_thread = obj_decode.start_decoding()

    # Start order: stop-listener (if any), recorder, decoder; join in the same order.
    workers = [worker for worker in stop_and_record if worker is not None]
    workers.append(decoder_thread)
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
Downloads last month
3