Spaces:
Configuration error
Configuration error
#!/usr/bin/env python3 | |
import json | |
import string | |
# Characters after which a segment may be cut: ASCII punctuation (except "-"
# and "'", which occur inside words) plus a few CJK / typographic marks.
_punctuation = "".join(c for c in string.punctuation if c not in ["-", "'"]) + "。,!?:”、…"


def split_long_segments(segments, max_length, use_space=True):
    """Split segments whose text is longer than ``max_length`` characters.

    Args:
        segments: iterable of dicts with "text", "start" and "end" keys.
            Segments that need splitting must also carry a "words" list of
            dicts with "text", "start" and "end" keys.
        max_length: maximum number of characters per output segment.
        use_space: whether words are joined with a single space when the
            text is rebuilt (False for scripts written without spaces).

    Returns:
        A new list. Segments that are short enough, or that carry no word
        timings, are passed through unchanged (same object). Long segments
        are replaced by new dicts holding only "text", "start" and "end".
        A single word longer than ``max_length`` is never broken, so a piece
        can still exceed the limit.
    """
    # Separator inserted between consecutive words when rebuilding the text.
    separator = " " if use_space else ""

    new_segments = []
    for segment in segments:
        text = segment["text"]
        meta_words = segment.get("words")
        # Without word timings there is nothing to split on: keep the segment
        # as-is instead of silently dropping it.
        if len(text) <= max_length or not meta_words:
            new_segments.append(segment)
            continue

        # Note: the text is re-split in case punctuation was removed from
        # the word entries.
        if use_space:
            words = text.split()
        else:
            words = [w["text"] for w in meta_words]
        if len(words) != len(meta_words):
            new_words = [w["text"] for w in meta_words]
            print(f"WARNING: {' '.join(words)} != {' '.join(new_words)}")
            words = new_words

        current_text = ""
        current_start = segment["start"]
        # Last position in current_text that directly follows a punctuation
        # mark, with the timings needed to cut there.
        best_idx = best_end = best_next_start = None

        for i, (word, meta) in enumerate(zip(words, meta_words)):
            text_before = current_text
            if current_text:
                current_text += separator
            current_text += word

            # Only cut when something precedes the current word; a lone
            # oversized word is kept whole.
            if len(current_text) > max_length and text_before:
                start = current_start
                if best_idx is not None:
                    # Prefer cutting right after the last punctuation mark.
                    piece = current_text[:best_idx]
                    end = best_end
                    # Skip the separator that follows the cut point. There is
                    # none when use_space is False, so nothing must be skipped
                    # or the next word would lose its first character.
                    current_text = current_text[best_idx + len(separator):]
                    current_start = best_next_start
                else:
                    # No punctuation seen: cut just before the current word.
                    piece = text_before
                    end = meta_words[i - 1]["end"]
                    current_text = word
                    current_start = meta["start"]

                best_idx = best_end = best_next_start = None
                new_segments.append({"text": piece, "start": start, "end": end})

            # Remember this position as a candidate cut point.
            if current_text and current_text[-1] in _punctuation:
                best_idx = len(current_text)
                best_end = meta["end"]
                best_next_start = meta_words[i + 1]["start"] if i + 1 < len(meta_words) else None

        if current_text:
            new_segments.append({"text": current_text, "start": current_start, "end": segment["end"]})

    return new_segments
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    """Format a non-negative duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to the nearest millisecond. The hours field is
    emitted only when it is non-zero or when ``always_include_hours`` is set.
    ``decimal_marker`` is placed between the seconds and the milliseconds.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)
    # Peel off hours, minutes and seconds, carrying the remainder down.
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        prefix = f"{hours:02d}:"
    else:
        prefix = ""
    return f"{prefix}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
def write_vtt(result, file):
    """Write ``result`` to ``file`` as WebVTT.

    ``result`` is an iterable of dicts with "start", "end" and "text" keys.
    A "WEBVTT" header is printed first, then one cue per segment. Any "-->"
    inside the text is replaced by "->" so it cannot be read as a timing
    line. The file is flushed after every cue.
    """
    print("WEBVTT\n", file=file)
    for cue in result:
        begin = format_timestamp(cue['start'])
        finish = format_timestamp(cue['end'])
        caption = cue['text'].strip().replace('-->', '->')
        print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)
def write_srt(result, file):
    """Write ``result`` to ``file`` as SubRip (SRT).

    ``result`` is an iterable of dicts with "start", "end" and "text" keys.
    Entries are numbered from 1. Timestamps always include hours and use a
    comma before the milliseconds. Any "-->" inside the text is replaced by
    "->". The file is flushed after every entry.
    """
    index = 0
    for entry in result:
        index += 1
        begin = format_timestamp(entry['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(entry['end'], always_include_hours=True, decimal_marker=',')
        caption = entry['text'].strip().replace('-->', '->')
        print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)
def cli():
    """Command-line entry point: convert transcription JSON files to SRT/VTT.

    Positional arguments are an input (a JSON file, or a folder scanned for
    ``*.words.json`` files) and an output (a ``.srt``/``.vtt`` file, or a
    folder). The output is treated as a folder when the input is a folder or
    when the output does not end with ``.srt`` or ``.vtt``; in that case one
    file per requested format is written for each input.

    Each input JSON must contain "segments", and "language" when
    ``--max_length`` is non-zero. Languages in a fixed list (zh, ja, th, lo,
    my) are split without assuming spaces between words.

    Raises:
        RuntimeError: if an output path has neither a .srt nor a .vtt suffix.
    """
    import os
    import argparse

    supported_formats = ["srt", "vtt"]
    words_suffix = ".words.json"

    parser = argparse.ArgumentParser(
        description='Convert .word.json transcription files (output of whisper_timestamped) to srt or vtt, being able to cut long segments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('input', type=str, help='Input json file, or input folder')
    parser.add_argument('output', type=str, help='Output srt or vtt file, or output folder')
    parser.add_argument('--max_length', default=200, help='Maximum length of a segment in characters', type=int)
    parser.add_argument('--format', type=str, default="all", help='Output format (if the output is a folder, i.e. not a file with an explicit extension)', choices= supported_formats + ["all"])
    args = parser.parse_args()

    input_is_dir = os.path.isdir(args.input)
    # Require the dot so that e.g. "foosrt" is not mistaken for an .srt file.
    output_is_file = args.output.endswith(tuple("." + ext for ext in supported_formats))

    if input_is_dir or not output_is_file:
        # Output is a folder: derive one output path per input and format.
        if input_is_dir:
            names = [f for f in os.listdir(args.input) if f.endswith(words_suffix)]
            input_files = [os.path.join(args.input, f) for f in names]
        else:
            names = [os.path.basename(args.input)]
            input_files = [args.input]

        extensions = supported_formats if args.format == "all" else [args.format]
        output_files = []
        for name in names:
            # Strip ".words.json" when present. A single input file may have
            # another name, in which case only its last extension is removed.
            if name.endswith(words_suffix):
                stem = name[:-len(words_suffix)]
            else:
                stem = os.path.splitext(name)[0]
            output_files.append([os.path.join(args.output, stem + "." + ext) for ext in extensions])

        os.makedirs(args.output, exist_ok=True)
    else:
        input_files = [args.input]
        output_files = [[args.output]]
        # dirname is "" for a bare filename, and os.makedirs("") raises.
        output_dir = os.path.dirname(args.output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    for fn, outputs in zip(input_files, output_files):
        with open(fn, "r", encoding="utf-8") as f:
            transcript = json.load(f)
        segments = transcript["segments"]
        if args.max_length:
            language = transcript["language"]
            use_space = language not in ["zh", "ja", "th", "lo", "my"]
            segments = split_long_segments(segments, args.max_length, use_space=use_space)
        for output in outputs:
            if output.endswith(".srt"):
                with open(output, "w", encoding="utf-8") as f:
                    write_srt(segments, file=f)
            elif output.endswith(".vtt"):
                with open(output, "w", encoding="utf-8") as f:
                    write_vtt(segments, file=f)
            else:
                raise RuntimeError(f"Unknown output format for {output}")
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    cli()