InterpreTalk / backend /Client.py
benjolo's picture
adding updates to monolingual transcript functionality
9680844 verified
raw
history blame
No virus
2.9 kB
import os
import wave
from typing import Iterator, Tuple

import torchaudio
from vad import EnergyVAD
# Sampling rate (Hz) that audio is resampled to before VAD is run.
TARGET_SAMPLING_RATE = 16000


def create_frames(data: bytes, frame_duration: int) -> Tuple[Iterator[bytes], int]:
    """Split ``data`` into consecutive fixed-size chunks.

    Args:
        data: Raw audio bytes to split.
        frame_duration: Frame length in milliseconds; must yield a positive
            frame size at ``TARGET_SAMPLING_RATE``.

    Returns:
        A ``(frames, frame_size)`` pair: ``frames`` lazily yields slices of
        ``data`` that are ``frame_size`` long (the last one may be shorter),
        and ``frame_size`` is the slice length that was used.

    Raises:
        ValueError: If ``frame_duration`` results in a non-positive frame size.
    """
    frame_size = int(TARGET_SAMPLING_RATE * (frame_duration / 1000))
    if frame_size <= 0:
        raise ValueError(
            f"frame_duration must give a positive frame size, got {frame_duration!r}"
        )
    # NOTE(review): frame_size is a sample count but is used to slice bytes;
    # for 16-bit PCM each chunk would cover half of frame_duration - confirm
    # against callers before relying on the chunk duration.
    frames = (
        data[start:start + frame_size]
        for start in range(0, len(data), frame_size)
    )
    return frames, frame_size
def detect_activity(energies: list, min_run: int = 12, density_divisor: int = 12) -> bool:
    """Decide whether a sequence of per-frame VAD flags contains activity.

    Activity is reported only when both conditions hold:

    * the number of active frames is at least ``len(energies) / density_divisor``
    * there is a run of at least ``min_run`` consecutive frames equal to ``1``

    Args:
        energies: Per-frame flags where ``1`` marks an active frame.
        min_run: Number of consecutive active frames required (default 12).
        density_divisor: Divisor used for the overall-density pre-check
            (default 12).

    Returns:
        ``True`` if a long enough run of active frames is found, else ``False``.
        An empty sequence yields ``False``.

    Raises:
        ValueError: If ``min_run`` or ``density_divisor`` is less than 1.
    """
    if min_run < 1 or density_divisor < 1:
        raise ValueError("min_run and density_divisor must both be >= 1")

    # Cheap rejection: too few active frames overall.
    if sum(energies) < len(energies) / density_divisor:
        return False

    run = 0
    for energy in energies:
        # Any non-active frame breaks the current streak.
        run = run + 1 if energy == 1 else 0
        if run >= min_run:
            return True
    return False
class Client:
    """Audio state for one connected client.

    Accumulates raw audio bytes in ``buffer``, writes them to a temporary WAV
    file, resamples to ``TARGET_SAMPLING_RATE`` and runs an energy-based VAD
    on the result. Temporary WAV files are named after ``sid`` and removed
    when the instance is finalised.
    """

    def __init__(self, sid: str, client_id, username, call_id=None, original_sr=None):
        """Store the client's identifiers and create an empty buffer and a VAD.

        Args:
            sid: Session id; also used as the prefix of the temporary WAV files.
            client_id: Identifier of the client.
            username: Name of the client.
            call_id: Identifier of the call the client is in, if any.
            original_sr: Sampling rate of the incoming audio; must be set
                before ``resample_and_clear`` is called.
        """
        self.sid = sid
        self.client_id = client_id
        # A stray trailing comma previously stored this as a 1-tuple.
        self.username = username
        self.call_id = call_id
        self.buffer = bytearray()
        self.output_path = self.sid + "_output_audio.wav"
        self.target_language = None
        self.original_sr = original_sr
        self.vad = EnergyVAD(
            sample_rate=TARGET_SAMPLING_RATE,
            frame_length=25,
            frame_shift=20,
            energy_threshold=0.05,
            pre_emphasis=0.95,
        )  # PM - Default values given in the docs for this class

    @property
    def _original_path(self) -> str:
        """Path of the temporary WAV file holding the audio before resampling."""
        return self.sid + "_OG.wav"

    def add_bytes(self, new_bytes) -> None:
        """Append ``new_bytes`` to the pending audio buffer."""
        self.buffer += new_bytes

    def resample_and_clear(self):
        """Resample the buffered audio to ``TARGET_SAMPLING_RATE`` and empty the buffer.

        The buffer is written as a mono, 2-bytes-per-sample WAV file at
        ``original_sr``, loaded back with torchaudio and resampled.

        Returns:
            The resampled waveform produced by ``torchaudio.transforms.Resample``.

        Raises:
            TypeError: If ``original_sr`` has not been set.
        """
        if self.original_sr is None:
            # Fail before touching the filesystem instead of deep inside `wave`.
            raise TypeError("original_sr must be set before calling resample_and_clear()")

        print(f"πŸ“₯ [ClientAudioBuffer] Writing {len(self.buffer)} bytes to {self.output_path}")
        original_path = self._original_path
        with wave.open(original_path, "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(self.original_sr)
            wf.setcomptype("NONE", "not compressed")
            # writeframes() fixes up the frame count in the header itself.
            wf.writeframes(self.buffer)

        waveform, sample_rate = torchaudio.load(original_path)
        resampler = torchaudio.transforms.Resample(
            sample_rate, TARGET_SAMPLING_RATE, dtype=waveform.dtype
        )
        resampled_waveform = resampler(waveform)
        self.buffer = bytearray()
        return resampled_waveform

    def vad_analyse(self, resampled_waveform) -> bool:
        """Save the waveform to ``output_path`` and report whether the VAD finds activity.

        Returns:
            The result of ``detect_activity`` on the VAD output.
        """
        self.write_to_file(resampled_waveform)
        vad_array = self.vad(resampled_waveform)
        print(f"VAD OUTPUT: {vad_array}")
        return detect_activity(vad_array)

    def write_to_file(self, resampled_waveform) -> None:
        """Save the waveform to ``output_path`` at ``TARGET_SAMPLING_RATE``."""
        torchaudio.save(self.output_path, resampled_waveform, TARGET_SAMPLING_RATE)

    def get_length(self) -> int:
        """Return the number of bytes currently buffered."""
        return len(self.buffer)

    def __del__(self):
        """Warn about unprocessed audio and delete this client's temporary WAV files."""
        # getattr guards: __del__ also runs when __init__ failed part-way.
        sid = getattr(self, "sid", None)
        buffer = getattr(self, "buffer", None)
        if buffer:
            print(f"🚨 [ClientAudioBuffer] Buffer not empty for {sid} ({len(buffer)} bytes)!")

        paths = [getattr(self, "output_path", None)]
        if isinstance(sid, str):
            paths.append(sid + "_OG.wav")
        for path in paths:
            if path is None:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                # Nothing to clean up; the file was never written or is already gone.
                pass