Spaces:
Paused
Paused
import os
import wave
from typing import Iterator, Tuple

import torchaudio
from vad import EnergyVAD
# Sample rate that buffered audio is resampled to before voice-activity
# detection; also used by create_frames() to convert a frame duration into a
# frame size.
TARGET_SAMPLING_RATE = 16_000
def create_frames(data: bytes, frame_duration: int) -> Tuple[Iterator[bytes], int]:
    """Split *data* into consecutive fixed-size chunks.

    Args:
        data: Raw audio bytes to split.
        frame_duration: Length of one frame in milliseconds.

    Returns:
        A ``(frames, frame_size)`` pair. ``frames`` is a lazy iterator over
        consecutive slices of *data*, each ``frame_size`` long except possibly
        the last one, which holds whatever is left over. ``frame_size`` is
        ``TARGET_SAMPLING_RATE * frame_duration / 1000`` truncated to an int.

    Raises:
        ValueError: If *frame_duration* results in a frame size of zero or
            less, which cannot be used as a slicing stride.
    """
    frame_size = int(TARGET_SAMPLING_RATE * (frame_duration / 1000))
    if frame_size <= 0:
        raise ValueError(
            f"frame_duration={frame_duration!r} ms yields a non-positive "
            f"frame size ({frame_size})"
        )

    # NOTE(review): frame_size is a number of samples, yet it is applied as a
    # byte stride below. For multi-byte samples (e.g. 16-bit PCM) each chunk
    # therefore covers less audio than frame_duration -- confirm with callers
    # before changing, since they receive frame_size as well.
    frames = (
        data[start:start + frame_size]
        for start in range(0, len(data), frame_size)
    )
    return frames, frame_size
def detect_activity(energies: list):
    """Decide whether a sequence of per-frame flags contains activity.

    Two conditions must both hold for the result to be ``True``:

    * the flags sum to at least one twelfth of the sequence length, and
    * somewhere in the sequence there are 12 consecutive entries equal to 1.

    Args:
        energies: Sequence of per-frame values; entries equal to ``1`` count
            as active, anything else resets the current run.

    Returns:
        ``True`` if both conditions are met, ``False`` otherwise (including
        for an empty sequence).
    """
    # Density gate: too few active frames overall means no activity,
    # regardless of how they are distributed.
    if sum(energies) < len(energies) / 12:
        return False

    streak = 0
    for value in energies:
        # Extend the current run of active frames, or start over.
        streak = streak + 1 if value == 1 else 0
        if streak == 12:
            return True
    return False
class Client:
    """Per-client audio buffer with resampling and voice-activity detection.

    Raw audio bytes are accumulated with :meth:`add_bytes`.
    :meth:`resample_and_write_to_file` then converts everything buffered so
    far to ``TARGET_SAMPLING_RATE``, runs the VAD on it and empties the
    buffer. Scratch WAV files named after ``sid`` are removed when the object
    is destroyed.
    """

    def __init__(self, sid, client_id, call_id=None):
        """Create an empty buffer for one client.

        Args:
            sid: Session identifier; must be a string because it is used as
                the prefix of the scratch WAV file names.
            client_id: Stored on the instance; not used by this class itself.
            call_id: Stored on the instance; not used by this class itself.
        """
        self.sid = sid
        self.client_id = client_id
        self.call_id = call_id
        self.buffer = bytearray()
        # Nothing in this class writes this file; it is only removed in
        # __del__ if something else created it.
        self.output_path = self.sid + "_output_audio.wav"
        self.target_language = None
        # Sample rate of the incoming audio. Must be assigned by the caller
        # before resample_and_write_to_file() is used.
        self.original_sr = None
        self.vad = EnergyVAD(
            sample_rate=TARGET_SAMPLING_RATE,
            frame_length=25,
            frame_shift=20,
            energy_threshold=0.05,
            pre_emphasis=0.95,
        )  # PM - Default values given in the docs for this class

    @property
    def _original_path(self):
        """Path of the scratch WAV file holding the un-resampled audio."""
        return self.sid + "_OG.wav"

    def add_bytes(self, new_bytes):
        """Append *new_bytes* to the pending audio buffer."""
        self.buffer += new_bytes

    def resample_and_write_to_file(self):
        """Resample the buffered audio and check it for activity.

        The buffer is written to ``<sid>_OG.wav`` as uncompressed mono audio
        with 2-byte samples at ``original_sr``, loaded back with torchaudio,
        resampled to ``TARGET_SAMPLING_RATE`` and passed to the VAD. The
        buffer is emptied once the VAD has run.

        Returns:
            A ``(activity_detected, resampled_waveform)`` pair, where
            ``activity_detected`` is the result of ``detect_activity`` on the
            VAD output.

        Raises:
            ValueError: If ``original_sr`` has not been set to a positive
                value. Raised before any file is created.
        """
        # Fail early: handing None/0 to wave.setframerate() errors out only
        # after the scratch file has already been created.
        if not self.original_sr or self.original_sr <= 0:
            raise ValueError(
                "original_sr must be set to a positive sample rate before "
                f"resampling (got {self.original_sr!r})"
            )

        print("Audio being written to file....\n")
        original_path = self._original_path
        with wave.open(original_path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.original_sr)
            wf.setnframes(0)
            wf.setcomptype("NONE", "not compressed")
            wf.writeframes(self.buffer)

        waveform, sample_rate = torchaudio.load(original_path)
        resampler = torchaudio.transforms.Resample(
            sample_rate, TARGET_SAMPLING_RATE, dtype=waveform.dtype
        )
        resampled_waveform = resampler(waveform)
        vad_waveform = self.vad(resampled_waveform)

        # Only reached when every step above succeeded, so a failure keeps
        # the buffered audio.
        self.buffer = bytearray()
        return detect_activity(vad_waveform), resampled_waveform

    def get_length(self):
        """Return the number of bytes currently buffered."""
        return len(self.buffer)

    def __del__(self):
        """Warn about unprocessed audio and remove the scratch WAV files."""
        # __del__ also runs for instances whose __init__ raised part-way, so
        # no attribute can be assumed to exist.
        sid = getattr(self, "sid", None)
        buffer = getattr(self, "buffer", None)
        if buffer:
            print(f"🚨 [ClientAudioBuffer] Buffer not empty for {sid} ({len(buffer)} bytes)!")

        scratch_files = [getattr(self, "output_path", None)]
        if isinstance(sid, str):
            scratch_files.append(sid + "_OG.wav")

        for path in scratch_files:
            if path is None:
                continue
            # Remove directly instead of exists()+remove(): the file may
            # never have been written, or may vanish between the two calls.
            try:
                os.remove(path)
            except FileNotFoundError:
                pass