nguyenvulebinh committed
Commit 7a8d5af
1 Parent(s): 40f5535

Create README.md

Files changed (1): README.md (+130, −0)
README.md ADDED
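
The README below documents a realtime wake-word ("hey armar") demo built on this model. The script records microphone audio with PyAudio, uses Silero VAD to skip windows that contain no speech, and classifies a sliding window of the most recent audio with `WavLMForSequenceClassification`, reporting a detection when the positive-class probability exceeds 0.6. Judging from the imports, it needs `torch`, `torchaudio`, `transformers`, `pyaudio`, and `numpy` installed.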
```python
"""Interface for interacting with the wake-word model."""
import pyaudio
import threading
import time
import torchaudio
import torch
import numpy as np
import queue
from transformers import WavLMForSequenceClassification
from transformers import AutoFeatureExtractor


def int2float(sound):
    """Convert int16 PCM samples to normalized float32 in [-1, 1]."""
    abs_max = np.abs(sound).max()
    sound = sound.astype('float32')
    if abs_max > 0:
        sound *= 1 / abs_max
    sound = sound.squeeze()  # depends on the use case
    return sound


class RealtimeDecoder():

    def __init__(self, model, feature_extractor) -> None:
        self.model = model
        self.feature_extractor = feature_extractor
        # Silero VAD gates the classifier: chunks with no speech are never scored
        self.vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                               model='silero_vad',
                                               force_reload=False,
                                               onnx=False)
        (self.get_speech_timestamps, _, _, _, _) = utils
        self.SAMPLE_RATE = 16000
        self.cache_output = {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }
        self.continue_recording = threading.Event()
        self.frame_duration_ms = 1000
        self.audio_queue = queue.SimpleQueue()
        self.speech_queue = queue.SimpleQueue()

    def start_recording(self, wait_enter_to_stop=True):
        def stop():
            input("Press Enter to stop the recording:\n\n")
            self.continue_recording.set()

        def record():
            audio = pyaudio.PyAudio()
            stream = audio.open(format=pyaudio.paInt16,
                                channels=1,
                                rate=self.SAMPLE_RATE,
                                input=True,
                                frames_per_buffer=int(self.SAMPLE_RATE / 10))
            while not self.continue_recording.is_set():
                # Read one frame_duration_ms worth of audio and queue it for decoding
                audio_chunk = stream.read(int(self.SAMPLE_RATE * self.frame_duration_ms / 1000.0),
                                          exception_on_overflow=False)
                audio_int16 = np.frombuffer(audio_chunk, np.int16)
                audio_float32 = int2float(audio_int16)
                waveform = torch.from_numpy(audio_float32)
                self.audio_queue.put(waveform)
            print("Finished recording")
            stream.close()

        if wait_enter_to_stop:
            stop_listener_thread = threading.Thread(target=stop, daemon=False)
        else:
            stop_listener_thread = None
        recording_thread = threading.Thread(target=record, daemon=False)
        return stop_listener_thread, recording_thread

    def finish_realtime_decode(self):
        self.cache_output = {
            "cache": torch.zeros(0, 0, 0, dtype=torch.float),
            "wavchunks": [],
        }

    def start_decoding(self):
        def decode():
            while not self.continue_recording.is_set():
                if self.audio_queue.qsize() > 0:
                    current_waveform = self.audio_queue.get()
                    if current_waveform is not None:
                        # Keep a sliding window of the most recent chunks
                        self.cache_output['wavchunks'].append(current_waveform)
                        self.cache_output['wavchunks'] = self.cache_output['wavchunks'][-4:]

                    if len(self.cache_output['wavchunks']) > 1:
                        # Score the last two chunks (~2 s of audio)
                        waveform = torch.cat(self.cache_output['wavchunks'][-2:], dim=-1)
                        speech_timestamps = self.get_speech_timestamps(waveform, self.vad_model,
                                                                       sampling_rate=self.SAMPLE_RATE)
                        logits = [1, 0]  # default: no keyword
                        if len(speech_timestamps) > 0:
                            input_features = self.feature_extractor.pad([{"input_values": waveform}],
                                                                        padding=True, return_tensors="pt")
                            logits = self.model(**input_features).logits.softmax(dim=-1).squeeze()
                        if logits[1] > 0.6:
                            print("hey armar", logits, waveform.size(-1) / self.SAMPLE_RATE)
                            self.cache_output['wavchunks'] = []
                        else:
                            print('.' + '.' * self.audio_queue.qsize())
                else:
                    time.sleep(0.01)
            print("KWS thread finished")

        kws_decode_thread = threading.Thread(target=decode, daemon=False)
        return kws_decode_thread


if __name__ == "__main__":
    print("Model loading....")
    kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
    feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')
    print("Model loaded....")

    # Offline test on a wav file (see the standalone example after this block):
    # file_wave = './99.wav'
    # wav, rate = torchaudio.load(file_wave)
    # input_features = feature_extractor.pad([{"input_values": item} for item in wav], padding=True, return_tensors="pt")
    # output = kws_model(**input_features)
    # print(output.logits.softmax(dim=-1))

    obj_decode = RealtimeDecoder(kws_model, feature_extractor)
    recording_threads = obj_decode.start_recording()
    kws_decode_thread = obj_decode.start_decoding()
    for thread in recording_threads:
        if thread is not None:
            thread.start()
    kws_decode_thread.start()
    for thread in recording_threads:
        if thread is not None:
            thread.join()
    kws_decode_thread.join()
```
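
A note on the design: the decoder keeps at most four one-second chunks and scores only the most recent two, so the model always sees roughly two seconds of context, and the chunk cache is cleared after each detection to avoid repeated triggers on the same utterance.

For a quick check without a microphone, the commented-out snippet in `__main__` can be run standalone. Below is a minimal sketch of that offline path; it assumes a mono recording, the `./99.wav` path is just a placeholder for your own file, and the resampling guard is an addition not present in the original snippet:

```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, WavLMForSequenceClassification

kws_model = WavLMForSequenceClassification.from_pretrained('nguyenvulebinh/heyarmar')
feature_extractor = AutoFeatureExtractor.from_pretrained('nguyenvulebinh/heyarmar')

# Load a test clip; replace the placeholder path with your own recording
wav, rate = torchaudio.load('./99.wav')
if rate != 16000:
    # Precautionary resample to the 16 kHz rate the model expects
    wav = torchaudio.functional.resample(wav, rate, 16000)

# Pad/batch each channel and run the classifier
input_features = feature_extractor.pad([{"input_values": item} for item in wav],
                                       padding=True, return_tensors="pt")
with torch.no_grad():
    probs = kws_model(**input_features).logits.softmax(dim=-1)
print(probs)  # column 1 is the "hey armar" probability
```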