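"""Live RTMP-to-DASH transcoder with multilingual live captions.

Reads an RTMP stream, repackages it as MPEG-DASH with ffmpeg, runs voice
activity detection and Whisper transcription on the audio, translates the
English transcript with MarianMT models, and writes one WebVTT caption file
per language next to the DASH manifest.

Example invocation (the script filename below is illustrative, not part of
the original source):

    python live_captions.py --rtmp_url rtmp://localhost/live/stream \
        --output_directory /var/www/dash --model base
"""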
import subprocess
import threading
import argparse
import fcntl
import select
import whisper
import ffmpeg
import signal
import numpy as np
import queue
import time
import webrtcvad
import collections
import os
from transformers import MarianMTModel, MarianTokenizer

# Global variables
rtmp_url = ""
dash_output_path = ""
segment_duration = 2
last_activity_time = 0.0
cleanup_threshold = 10  # seconds of inactivity before cleanup
start_time = 0.0

# Languages for translation (ISO 639-1 codes)
target_languages = ["es", "zh", "ru"]  # Example: Spanish, Chinese, Russian

# Whisper model (loaded in __main__)
whisper_model = None

# Define Frame class
class Frame:
    def __init__(self, data, timestamp, duration):
        self.data = data
        self.timestamp = timestamp
        self.duration = duration

# Audio buffer and caption queues
audio_buffer = queue.Queue()
caption_queues = {lang: queue.Queue() for lang in target_languages + ["original", "en"]}

language_model_names = {
    "es": "Helsinki-NLP/opus-mt-en-es",
    "zh": "Helsinki-NLP/opus-mt-en-zh",
    "ru": "Helsinki-NLP/opus-mt-en-ru",
}
translation_models = {}
tokenizers = {}

# Initialize VAD
vad = webrtcvad.Vad(3)  # Aggressiveness mode 3 (most aggressive)

# Event to signal threads to stop
stop_event = threading.Event()

def transcode_rtmp_to_dash():
    ffmpeg_command = [
        "ffmpeg",
        "-i", rtmp_url,
        "-map", "0:v:0", "-map", "0:a:0",
        "-c:v", "libx264", "-preset", "slow",
        "-c:a", "aac", "-b:a", "128k",
        "-f", "dash",
        "-seg_duration", str(segment_duration),
        "-use_timeline", "1",
        "-use_template", "1",
        "-init_seg_name", "init_$RepresentationID$.m4s",
        "-media_seg_name", "chunk_$RepresentationID$_$Number%05d$.m4s",
        "-adaptation_sets", "id=0,streams=v id=1,streams=a",
        f"{dash_output_path}/manifest.mpd"
    ]
    process = subprocess.Popen(ffmpeg_command)
    while not stop_event.is_set():
        time.sleep(1)
    process.kill()

def capture_audio():
    global last_activity_time
    command = [
        'ffmpeg',
        '-i', rtmp_url,
        '-acodec', 'pcm_s16le',
        '-ar', '16000',
        '-ac', '1',
        '-f', 's16le',
        '-'
    ]
    sample_rate = 16000
    frame_duration_ms = 30
    sample_width = 2  # Only 16-bit audio supported
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    # Set stdout to non-blocking mode
    fd = process.stdout.fileno()
    fl = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
    frame_size = int(sample_rate * frame_duration_ms / 1000) * sample_width
    frame_count = 0
    while not stop_event.is_set():
        ready, _, _ = select.select([process.stdout], [], [], 0.1)
        if ready:
            try:
                in_bytes = os.read(fd, frame_size)
                if not in_bytes:
                    break
                if len(in_bytes) < frame_size:
                    in_bytes += b'\x00' * (frame_size - len(in_bytes))
                last_activity_time = time.time()
                timestamp = frame_count * frame_duration_ms  # milliseconds since capture start
                frame = Frame(np.frombuffer(in_bytes, np.int16), timestamp, frame_duration_ms)
                audio_buffer.put(frame)
                frame_count += 1
            except BlockingIOError:
                continue
        else:
            time.sleep(0.01)
    process.kill()

def frames_to_numpy(frames):
    all_frames = np.concatenate([f.data for f in frames])
    float_samples = all_frames.astype(np.float32) / np.iinfo(np.int16).max
    return float_samples

def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    triggered = False
    for frame in frames:
        if len(frame.data) != int(sample_rate * (frame_duration_ms / 1000.0)):
            print(f"Skipping frame with incorrect size: {len(frame.data)} samples", flush=True)
            continue
        is_speech = vad.is_speech(frame.data.tobytes(), sample_rate)
        if not triggered:
            ring_buffer.append((frame, is_speech))
            num_voiced = len([f for f, speech in ring_buffer if speech])
            if num_voiced > 0.8 * ring_buffer.maxlen:
                triggered = True
                for f, s in ring_buffer:
                    yield f
                ring_buffer.clear()
        else:
            yield frame
            ring_buffer.append((frame, is_speech))
            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
            if num_unvoiced > 0.8 * ring_buffer.maxlen:
                triggered = False
                yield None
                ring_buffer.clear()
    # Flush any frames still buffered once the input is exhausted
    for f, s in ring_buffer:
        yield f
    ring_buffer.clear()

def process_audio():
    global last_activity_time
    frames = []
    buffer_duration_ms = 1500  # About 1.5 seconds of audio
    while not stop_event.is_set():
        while not audio_buffer.empty():
            frame = audio_buffer.get(timeout=5.0)
            frames.append(frame)
        if frames and sum(f.duration for f in frames) >= buffer_duration_ms:
            vad_frames = list(vad_collector(16000, 30, 300, vad, frames))
            if vad_frames:
                audio_segment = [f for f in vad_frames if f is not None]
                if audio_segment:
                    # Transcribe the original audio
                    result = whisper_model.transcribe(frames_to_numpy(audio_segment))
                    if result["text"]:
                        timestamp = audio_segment[0].timestamp
                        caption_queues["original"].put((timestamp, result["text"]))
                        english_translation = whisper_model.transcribe(frames_to_numpy(audio_segment), task="translate")
                        caption_queues["en"].put((timestamp, english_translation["text"]))
                        # Translate to target languages
                        for lang in target_languages:
                            tokenizer = tokenizers[lang]
                            translation_model = translation_models[lang]
                            inputs = tokenizer.encode(english_translation["text"], return_tensors="pt", padding=True, truncation=True)
                            translated_tokens = translation_model.generate(inputs)
                            translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
                            caption_queues[lang].put((timestamp, translated_text))
            frames = []
        time.sleep(0.01)

def write_captions(lang):
    os.makedirs(dash_output_path, exist_ok=True)
    filename = f"{dash_output_path}/captions_{lang}.vtt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
    last_end_time = None
    while not stop_event.is_set():
        if not caption_queues[lang].empty():
            timestamp, text = caption_queues[lang].get()
            start_time = format_time(timestamp / 1000)  # Convert ms to seconds
            end_time = format_time((timestamp + 5000) / 1000)  # Assume 5-second duration for each caption
            # Adjust the previous caption's end time if necessary
            if last_end_time and start_time != last_end_time:
                adjust_previous_caption(filename, last_end_time, start_time)
            # Write the new caption
            with open(filename, "a", encoding="utf-8") as f:
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"{text}\n\n")
                f.flush()
            last_end_time = end_time
        time.sleep(0.1)

def adjust_previous_caption(filename, old_end_time, new_end_time):
    with open(filename, "r", encoding="utf-8") as f:
        lines = f.readlines()
    for i in range(len(lines) - 1, -1, -1):
        if "-->" in lines[i]:
            parts = lines[i].split("-->")
            if parts[1].strip() == old_end_time:
                lines[i] = f"{parts[0].strip()} --> {new_end_time}\n"
            break
    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(lines)

def format_time(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{int(hours):02d}:{int(minutes):02d}:{seconds:06.3f}"

def signal_handler(signum, frame):
    print(f"Received signal {signum}. Cleaning up and exiting...")
    # Signal all threads to stop
    stop_event.set()

def cleanup():
    global last_activity_time
    while not stop_event.is_set():
        current_time = time.time()
        if last_activity_time != 0.0 and current_time - last_activity_time > cleanup_threshold:
            print("No activity detected for 10 seconds. Cleaning up...", flush=True)
            # Signal all threads to stop
            stop_event.set()
            break
        time.sleep(1)  # Check for inactivity every second
    # Clear caption queues
    for lang in target_languages + ["original", "en"]:
        while not caption_queues[lang].empty():
            caption_queues[lang].get()
    # Delete DASH output files
    for root, dirs, files in os.walk(dash_output_path, topdown=False):
        for name in files:
            os.remove(os.path.join(root, name))
        for name in dirs:
            os.rmdir(os.path.join(root, name))
    print("Cleanup completed.", flush=True)

if __name__ == "__main__":
    signal.signal(signal.SIGTERM, signal_handler)
    # Get RTMP URL, DASH output path, and model size from the command line
    parser = argparse.ArgumentParser(description="Process audio for translation.")
    parser.add_argument('--rtmp_url', help='RTMP URL')
    parser.add_argument('--output_directory', help='DASH output directory')
    parser.add_argument('--model', help='Whisper model size: base|small|medium|large|large-v2')
    start_time = time.time()
    args = parser.parse_args()
    rtmp_url = args.rtmp_url
    dash_output_path = args.output_directory
    model_size = args.model
    print(f"RTMP URL: {rtmp_url}")
    print(f"DASH output path: {dash_output_path}")
    print(f"Model: {model_size}")
    print("Downloading models\n")
    print("Whisper\n")
    whisper_model = whisper.load_model(model_size, download_root="/tmp/model/")  # Adjust model size as necessary
    for lang, model_name in language_model_names.items():
        print(f"Lang: {lang}, model: {model_name}\n")
        tokenizers[lang] = MarianTokenizer.from_pretrained(model_name)
        translation_models[lang] = MarianMTModel.from_pretrained(model_name)
    # Start RTMP to DASH transcoding in a separate thread
    transcode_thread = threading.Thread(target=transcode_rtmp_to_dash)
    transcode_thread.start()
    # Start audio capture in a separate thread
    audio_capture_thread = threading.Thread(target=capture_audio)
    audio_capture_thread.start()
    # Start audio processing in a separate thread
    audio_processing_thread = threading.Thread(target=process_audio)
    audio_processing_thread.start()
    # Start caption writing threads for original and all target languages
    caption_threads = []
    for lang in target_languages + ["original", "en"]:
        caption_thread = threading.Thread(target=write_captions, args=(lang,))
        caption_threads.append(caption_thread)
        caption_thread.start()
    # Start the cleanup thread
    cleanup_thread = threading.Thread(target=cleanup)
    cleanup_thread.start()
    # Wait for all threads to complete
    print("Join transcode", flush=True)
    if transcode_thread.is_alive():
        transcode_thread.join()
    print("Join audio capture", flush=True)
    if audio_capture_thread.is_alive():
        audio_capture_thread.join()
    print("Join audio processing", flush=True)
    if audio_processing_thread.is_alive():
        audio_processing_thread.join()
    for thread in caption_threads:
        if thread.is_alive():
            thread.join()
    print("Join cleanup", flush=True)
    if cleanup_thread.is_alive():
        cleanup_thread.join()
    print("All threads have been stopped and cleaned up.")
    exit(0)