Script for speaker-diarization & subtitle export

#3
by tintwotin - opened
# Install: accelerate, transformers==4.57.6, pyannote.audio
# Accept the model terms at: https://huggingface.co/pyannote/speaker-diarization-community-1
# Create a Hugging Face access token and insert it into HF_TOKEN below

import os
import time
import gc
import warnings
import re
import torch
import numpy as np
import soundfile as sf
import torchaudio.functional as F
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from pyannote.audio import Pipeline

# Suppress harmless Pyannote warnings about standard deviation on tiny audio frames
warnings.filterwarnings("ignore", message=".*degrees of freedom is <= 0.*")

# Enable TF32 for faster compute on Ampere+ GPUs
torch.backends.cuda.matmul.allow_tf32 = True

# =====================================================================
# CONFIGURATION
# =====================================================================
MODEL_ID = "syvai/hviske-v5.3"
TARGET_SR = 16000
AUDIO_PATH = r"MY_AUDIO.wav"
# splitext only strips a real file extension, so a dot inside a directory
# name (e.g. "rec.v2/audio") cannot truncate the output path.
SRT_OUTPUT_PATH = os.path.splitext(AUDIO_PATH)[0] + ".srt"
BATCH_SIZE = 8

# IMPORTANT: Insert your Hugging Face Token here for Pyannote Diarization
HF_TOKEN = "MY_HUGGINGFACE_TOKEN"

# Placeholder values that mean "the user has not configured a token yet".
# The guard must include the default assigned above, otherwise it never fires.
_HF_TOKEN_PLACEHOLDERS = {"MY_HUGGINGFACE_TOKEN", "YOUR_HF_TOKEN_HERE"}

if not HF_TOKEN or HF_TOKEN in _HF_TOKEN_PLACEHOLDERS:
    raise ValueError(
        "Please insert your Hugging Face token into the 'HF_TOKEN' variable to use diarization. "
        "You also need to accept the terms at https://hf.co/pyannote/speaker-diarization-community-1"
    )

# =====================================================================
# STEP 1: SPEAKER DIARIZATION (Finding who speaks when)
# =====================================================================
print("--- STEP 1: SPEAKER DIARIZATION ---")
# Load the new Community-1 Pyannote pipeline to GPU
try:
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-community-1",
        token=HF_TOKEN
    )
    diarization_pipeline.to(torch.device("cuda"))
except Exception as e:
    # Chain the original exception so its traceback is kept as the explicit cause.
    raise RuntimeError(f"Failed to load Pyannote Community-1. Ensure your HF_TOKEN is correct and terms are accepted.\nError: {e}") from e

print("Analyzing audio for speakers and timestamps... (This may take a minute)")
diarization_result = diarization_pipeline(AUDIO_PATH)

# Extract segments using the new Community-1 API: one dict per speaker turn
segments = [
    {"start": turn.start, "end": turn.end, "speaker": speaker}
    for turn, speaker in diarization_result.speaker_diarization
]

# Merge contiguous segments of the same speaker if the gap is less than 1.5 seconds.
# Only the most recently kept segment is considered, so a turn by another
# speaker in between prevents the merge.
merged_segments = []
for seg in segments:
    last_seg = merged_segments[-1] if merged_segments else None
    if (
        last_seg is not None
        and seg["speaker"] == last_seg["speaker"]
        and (seg["start"] - last_seg["end"]) < 1.5
    ):
        last_seg["end"] = max(last_seg["end"], seg["end"])
    else:
        merged_segments.append(seg)

# Filter out extremely short segments (< 0.5s)
valid_segments = [seg for seg in merged_segments if (seg["end"] - seg["start"]) >= 0.5]
valid_segments.sort(key=lambda x: x["start"])
print(f"Found {len(valid_segments)} consolidated speaker segments.")

# FREE UP VRAM COMPLETELY BEFORE LOADING THE ASR MODEL
del diarization_pipeline
gc.collect()
torch.cuda.empty_cache()

# =====================================================================
# STEP 2: AUDIO PREPROCESSING
# =====================================================================
print("\n--- STEP 2: AUDIO PREPROCESSING ---")
# Load the whole file into memory as float32 samples plus its sample rate
audio, sr = sf.read(AUDIO_PATH)
audio = np.asarray(audio, dtype=np.float32)

# Multi-channel input: average across axis 1 to get a single mono track
is_multichannel = audio.ndim > 1
if is_multichannel:
    audio = np.mean(audio, axis=1)

# Bring the signal to TARGET_SR when the file uses a different sample rate
if sr != TARGET_SR:
    print(f"Resampling from {sr}Hz to {TARGET_SR}Hz...")
    audio_tensor = torch.from_numpy(audio)
    resampled = F.resample(audio_tensor, orig_freq=sr, new_freq=TARGET_SR)
    audio = resampled.numpy()
    sr = TARGET_SR

# =====================================================================
# STEP 3: ASR TRANSCRIPTION
# =====================================================================
print("\n--- STEP 3: ASR TRANSCRIPTION ---")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
).to(device="cuda").eval()

print(f"Transcribing {len(valid_segments)} segments in batches of {BATCH_SIZE}...")
start_time = time.time()

# Inference only: no gradients are needed
with torch.no_grad():
    for batch_offset in range(0, len(valid_segments), BATCH_SIZE):
        batch = valid_segments[batch_offset:batch_offset + BATCH_SIZE]

        # Cut each segment out of the waveform, clamping the sample
        # indices to the bounds of the audio array
        audio_arrays = [
            audio[max(0, int(seg["start"] * sr)):min(len(audio), int(seg["end"] * sr))]
            for seg in batch
        ]

        # NOTE(review): transcribe() is presumably provided by the model's
        # remote code (trust_remote_code=True) -- confirm its contract there
        outputs = model.transcribe(
            processor=processor,
            language="da",
            audio_arrays=audio_arrays,
            sample_rates=[sr] * len(audio_arrays),
        )

        # Store each transcription on the segment dict it belongs to
        for seg, text in zip(batch, outputs):
            seg["text"] = text.strip()

elapsed = time.time() - start_time
print(f"Transcription took: {elapsed:.1f}s")

# =====================================================================
# STEP 4: GENERATING PROFESSIONAL SRT FILE
# =====================================================================
print("\n--- STEP 4: GENERATING PROFESSIONAL SRT ---")

def format_time(seconds):
    """Format an offset in seconds as an SRT timestamp (HH:MM:SS,mmm).

    Args:
        seconds: Offset in seconds (int or float). Negative values are
            clamped to zero, since SRT timestamps cannot be negative.

    Returns:
        The timestamp string, rounded to the nearest millisecond.
    """
    # Work in whole milliseconds so a rounding carry (e.g. 59.9996 s) moves
    # into seconds, minutes and hours via divmod instead of a manual cascade.
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_secs, ms = divmod(total_ms, 1000)
    total_minutes, secs = divmod(total_secs, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"

def format_two_lines(text):
    """Break a long subtitle string into two lines at the space nearest its middle.

    Strings of 45 characters or fewer, and strings containing no space,
    are returned unchanged. When two spaces are equally close to the
    midpoint, the earlier one is used.
    """
    if len(text) <= 45:
        return text

    midpoint = len(text) // 2
    # min() keeps the first of several equally-near candidates
    break_at = min(
        (pos for pos, ch in enumerate(text) if ch == ' '),
        key=lambda pos: abs(pos - midpoint),
        default=None,
    )
    if break_at is None:
        return text

    first_line = text[:break_at]
    second_line = text[break_at + 1:]
    return f"{first_line}\n{second_line}"

def break_into_subtitles(start, end, text, speaker, max_chars=80):
    """Split one transcribed segment into subtitle-sized cues.

    The text is whitespace-normalized, split after ``.``, ``!``, ``?`` or
    ``,`` and greedily packed into chunks of at most ``max_chars``
    characters. A piece that is itself longer than ``max_chars`` is packed
    word by word; a single word longer than ``max_chars`` is kept whole as
    its own chunk. The segment duration is then shared between the chunks
    in proportion to their character counts.

    Args:
        start: Segment start time in seconds.
        end: Segment end time in seconds.
        text: Transcribed text of the segment.
        speaker: Speaker label copied onto every cue.
        max_chars: Target maximum number of characters per cue.

    Returns:
        A list of dicts with keys ``start``, ``end``, ``speaker`` and
        ``text``, in chronological order. Empty when ``text`` is blank.
    """
    text = re.sub(r'\s+', ' ', text).strip()
    if not text:
        return []

    chunks = []
    current_chunk = ""

    # Split by punctuation first to preserve sentence flow
    for part in re.split(r'(?<=[.!?,]) +', text):
        candidate = f"{current_chunk} {part}" if current_chunk else part
        if len(candidate) <= max_chars:
            current_chunk = candidate
            continue

        if current_chunk:
            chunks.append(current_chunk)

        if len(part) <= max_chars:
            current_chunk = part
            continue

        # A single piece without usable punctuation is still too long: pack by words
        temp_chunk = ""
        for w in part.split():
            candidate = f"{temp_chunk} {w}" if temp_chunk else w
            if len(candidate) <= max_chars:
                temp_chunk = candidate
            else:
                # Only flush a non-empty chunk; an oversized first word must
                # not leave a blank subtitle behind.
                if temp_chunk:
                    chunks.append(temp_chunk)
                temp_chunk = w
        current_chunk = temp_chunk

    if current_chunk:
        chunks.append(current_chunk)

    # Distribute the audio duration proportionally across the generated text chunks
    total_chars = sum(len(c) for c in chunks)
    if total_chars == 0:
        return []
    total_duration = end - start

    subs = []
    current_start = start
    for c in chunks:
        original_duration = total_duration * (len(c) / total_chars)

        # Prevent short text from hanging on screen for 15+ seconds:
        # cap the display time at 0.12 s per character, but never below 2 s
        max_allowed = max(2.0, len(c) * 0.12)
        display_duration = min(original_duration, max_allowed)

        subs.append({
            "start": current_start,
            "end": current_start + display_duration,
            "speaker": speaker,
            "text": c.strip()
        })
        # The next cue starts after the chunk's full share, even if the
        # displayed duration was capped
        current_start += original_duration

    return subs

# 1. Turn every transcribed segment into one or more subtitle cues
all_subtitles = []
for seg in valid_segments:
    text = seg.get("text", "")
    if not text:
        continue

    # Map a label like "SPEAKER_00" to "Speaker 1"; labels whose suffix is
    # not an integer are kept as they are
    raw_label = seg["speaker"]
    label_suffix = raw_label.replace("SPEAKER_", "")
    try:
        speaker = f"Speaker {int(label_suffix) + 1}"
    except ValueError:
        speaker = raw_label

    all_subtitles.extend(
        break_into_subtitles(seg["start"], seg["end"], text, speaker)
    )

# 2. Order the cues by start time and render them as numbered SRT entries
all_subtitles.sort(key=lambda x: x["start"])
srt_entries = []
previous_speaker = None
last_end_time = 0.0

for srt_index, sub in enumerate(all_subtitles, start=1):
    timing_line = f"{format_time(sub['start'])} --> {format_time(sub['end'])}"
    final_text = format_two_lines(sub["text"])

    # Prefix the speaker tag when the speaker changes or after a pause
    # of more than 3 seconds since the previous cue ended
    speaker_changed = sub["speaker"] != previous_speaker
    long_pause = (sub["start"] - last_end_time) > 3.0
    if speaker_changed or long_pause:
        final_text = f"[{sub['speaker']}] {final_text}"

    srt_entries.append(f"{srt_index}\n{timing_line}\n{final_text}\n\n")

    previous_speaker = sub["speaker"]
    last_end_time = sub["end"]

srt_content = "".join(srt_entries)

# Write to disk
with open(SRT_OUTPUT_PATH, "w", encoding="utf-8") as f:
    f.write(srt_content)

print(f"Successfully saved subtitles to: {SRT_OUTPUT_PATH}\n")
print("--- FULL SRT PREVIEW ---")
print(srt_content)
print("\n--- SCRIPT COMPLETED SUCCESSFULLY ---")

Output example:

134
00:08:56,437 --> 00:09:01,031
[Speaker 2] Jeg læste i et interview, at det var din mor, som introducerede dig for film,

135
00:09:01,031 --> 00:09:05,052
da du var lille. Var det noget, I kunne holde fast i gennem hele livet?

136
00:09:04,925 --> 00:09:06,648
[Speaker 3] Ja, det gjorde vi tit.

137
00:09:05,052 --> 00:09:06,258
[Speaker 2] Den fælles kærlighed?

138
00:09:06,648 --> 00:09:11,984
[Speaker 3] Det var... altså, hun var amerikansk, og så sagde hun en gang imellem:

139
00:09:11,984 --> 00:09:14,364
"Let's go to the movies." "Okay."

Sign up or log in to comment