File size: 6,823 Bytes
2cba4ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
#!/usr/bin/env python3
import json
import string
# Characters after which a long segment may preferably be cut: ASCII punctuation
# except "-" and "'" (which occur inside words), plus a few CJK/typographic marks.
# NOTE(review): ",!?:" below are ASCII and already in string.punctuation; the
# full-width forms "，！？：" may have been intended — confirm.
_punctuation = "".join(c for c in string.punctuation if c not in ["-", "'"]) + "。,!?:”、…"


def split_long_segments(segments, max_length, use_space = True):
    """Split transcription segments whose text is longer than ``max_length`` characters.

    Segments whose ``"text"`` has at most ``max_length`` characters are kept as-is
    (the same dict object is reused). Longer segments are rebuilt word by word
    from their ``"words"`` entries. Whenever adding the next word makes the
    accumulated text exceed ``max_length``, a new segment is emitted:

    * preferably cut right after the most recent word ending with a character
      from ``_punctuation``;
    * otherwise cut just before the word that overflowed.

    Args:
        segments: iterable of segment dicts with ``"text"``, ``"start"``, ``"end"``
            and, for segments that need splitting, ``"words"`` (a list of dicts
            with ``"text"``, ``"start"`` and ``"end"``).
        max_length: maximum number of characters per segment. This is a soft
            limit: a single word longer than it is never split, and the text
            left after a punctuation cut may still exceed it.
        use_space: if True, words come from ``text.split()`` and are joined with
            one space; if False (scripts without spaces), words come from
            ``"words"`` and are concatenated directly.

    Returns:
        A new list of segment dicts. Split pieces only carry ``"text"``,
        ``"start"`` and ``"end"``.
    """
    # Number of separator characters inserted between two consecutive words.
    sep_len = 1 if use_space else 0
    new_segments = []
    for segment in segments:
        text = segment["text"]
        if len(text) <= max_length:
            new_segments.append(segment)
            continue

        meta_words = segment["words"]
        if use_space:
            # Take words from the text rather than from meta_words, in case
            # punctuation was removed from the per-word entries.
            words = text.split()
        else:
            words = [w["text"] for w in meta_words]
        if len(words) != len(meta_words):
            # Words and timings must line up one-to-one: fall back to the
            # per-word texts when the split does not match.
            new_words = [w["text"] for w in meta_words]
            print(f"WARNING: {' '.join(words)} != {' '.join(new_words)}")
            words = new_words

        current_text = ""
        current_start = segment["start"]
        # Latest cut point right after a punctuation-terminated word:
        # index into current_text, end time of that word, start time of the next.
        current_best_idx = None
        current_best_end = None
        current_best_next_start = None
        for i, (word, meta) in enumerate(zip(words, meta_words)):
            current_text_before = current_text
            if current_text and use_space:
                current_text += " "
            current_text += word

            if len(current_text) > max_length and current_text_before:
                start = current_start
                if current_best_idx is not None:
                    # Cut after the last punctuation. Skip the separator only if
                    # there is one; without spaces, skipping a character would
                    # drop the first character of the next word.
                    piece = current_text[:current_best_idx]
                    end = current_best_end
                    current_text = current_text[current_best_idx + sep_len:]
                    current_start = current_best_next_start
                else:
                    # No punctuation seen: cut just before the current word.
                    piece = current_text_before
                    end = meta_words[i - 1]["end"]
                    current_text = word
                    current_start = meta["start"]
                current_best_idx = None
                current_best_end = None
                current_best_next_start = None
                new_segments.append({"text": piece, "start": start, "end": end})

            # Remember a possible cut point after punctuation.
            if current_text and current_text[-1] in _punctuation:
                current_best_idx = len(current_text)
                current_best_end = meta["end"]
                current_best_next_start = meta_words[i + 1]["start"] if i + 1 < len(meta_words) else None

        if current_text:
            new_segments.append({"text": current_text, "start": current_start, "end": segment["end"]})
    return new_segments
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    """Render a non-negative number of seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to the nearest millisecond. The hours field is shown
    when ``always_include_hours`` is set or when it is non-zero.
    ``decimal_marker`` separates seconds from milliseconds ("." for VTT,
    "," for SRT).
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)
    hrs, remainder = divmod(total_ms, 3_600_000)
    mins, remainder = divmod(remainder, 60_000)
    secs, ms = divmod(remainder, 1_000)
    prefix = f"{hrs:02d}:" if (always_include_hours or hrs > 0) else ""
    return f"{prefix}{mins:02d}:{secs:02d}{decimal_marker}{ms:03d}"
def write_vtt(result, file):
    """Write subtitle segments to ``file`` in WebVTT format.

    Each segment must provide ``"start"`` and ``"end"`` (seconds) and ``"text"``.
    Timestamps use ``format_timestamp`` defaults, so hours are omitted when zero.
    Text is stripped, and any ``-->`` is replaced by ``->`` so it cannot be
    mistaken for a cue-timing separator.
    """
    print("WEBVTT\n", file=file)
    for segment in result:
        start = format_timestamp(segment['start'])
        end = format_timestamp(segment['end'])
        text = segment['text'].strip().replace('-->', '->')
        print(f"{start} --> {end}\n{text}\n", file=file)
    # Flush once at the end rather than after every cue.
    file.flush()
def write_srt(result, file):
    """Write subtitle segments to ``file`` in SubRip (SRT) format.

    Cues are numbered from 1. Timestamps always include hours and use "," before
    the milliseconds. Text is stripped, and any ``-->`` is replaced by ``->`` so
    it cannot be mistaken for the timing separator.
    """
    for index, segment in enumerate(result, start=1):
        start = format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')
        end = format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')
        text = segment['text'].strip().replace('-->', '->')
        print(f"{index}\n{start} --> {end}\n{text}\n", file=file)
    # Flush once at the end rather than after every cue.
    file.flush()
def cli():
    """Command-line entry point: convert whisper_timestamped JSON output to SRT/VTT.

    ``input`` is a JSON file or a folder; for a folder, every ``*.words.json``
    file in it is converted. If ``output`` ends with ``.srt`` or ``.vtt``, it is
    written as a single file. Otherwise it is treated as a folder (created if
    needed) that receives ``<stem>.<ext>`` for each format selected by
    ``--format``. Segments longer than ``--max_length`` characters are split first.
    """
    import os
    import argparse

    supported_formats = ["srt", "vtt"]

    def output_stem(filename):
        # "talk.words.json" -> "talk"; also handles plain "talk.json". The old
        # fixed f[:-11] produced "" for names without the ".words.json" suffix.
        for suffix in (".words.json", ".json"):
            if filename.endswith(suffix):
                return filename[: -len(suffix)]
        return os.path.splitext(filename)[0]

    parser = argparse.ArgumentParser(
        description='Convert .words.json transcription files (output of whisper_timestamped) to srt or vtt, being able to cut long segments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('input', type=str, help='Input json file, or input folder')
    parser.add_argument('output', type=str, help='Output srt or vtt file, or output folder')
    parser.add_argument('--max_length', default=200, help='Maximum length of a segment in characters', type=int)
    parser.add_argument('--format', type=str, default="all", help='Output format (if the output is a folder, i.e. not a file with an explicit extension)', choices= supported_formats + ["all"])
    args = parser.parse_args()

    input_is_dir = os.path.isdir(args.input)
    # Match on ".srt"/".vtt" (with the dot) so e.g. "foo_srt" is treated as a folder.
    output_is_file = args.output.endswith(tuple("." + e for e in supported_formats))

    if input_is_dir or not output_is_file:
        if input_is_dir:
            names = [f for f in os.listdir(args.input) if f.endswith(".words.json")]
            input_files = [os.path.join(args.input, f) for f in names]
        else:
            names = [os.path.basename(args.input)]
            input_files = [args.input]
        extensions = [args.format] if args.format != "all" else list(supported_formats)
        output_files = [[os.path.join(args.output, output_stem(f) + "." + e) for e in extensions] for f in names]
        os.makedirs(args.output, exist_ok=True)
    else:
        input_files = [args.input]
        output_files = [[args.output]]
        output_dir = os.path.dirname(args.output)
        # dirname is "" for a bare filename (current directory); os.makedirs("") raises.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    for fn, outputs in zip(input_files, output_files):
        with open(fn, "r", encoding="utf-8") as f:
            transcript = json.load(f)
        segments = transcript["segments"]
        if args.max_length:
            language = transcript["language"]
            # For these language codes, words are concatenated without spaces.
            use_space = language not in ["zh", "ja", "th", "lo", "my"]
            segments = split_long_segments(segments, args.max_length, use_space=use_space)
        for output in outputs:
            if output.endswith(".srt"):
                writer = write_srt
            elif output.endswith(".vtt"):
                writer = write_vtt
            else:
                raise RuntimeError(f"Unknown output format for {output}")
            with open(output, "w", encoding="utf-8") as f:
                writer(segments, file=f)
if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    cli()