File size: 6,823 Bytes
2cba4ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python3

import json
import string

_punctuation = "".join(c for c in string.punctuation if c not in ["-", "'"]) + "。,!?:”、…"

def split_long_segments(segments, max_length, use_space = True, punctuation = None):
    """Split transcription segments whose text is longer than ``max_length``.

    A too-long segment is cut at word boundaries using its per-word timing
    metadata (``segment["words"]``, each a dict with ``"text"``, ``"start"``
    and ``"end"``). Whenever the accumulated text grows past ``max_length``,
    the cut is placed right after the most recent word ending in punctuation
    if there is one, otherwise right before the word that overflowed.

    Args:
        segments: list of segment dicts with ``"text"``, ``"start"``, ``"end"``
            and, for segments that need splitting, ``"words"``.
        max_length: maximum number of characters per output segment. A single
            word longer than this is never split, so a piece may still exceed it.
        use_space: True when words are separated by spaces; False for scripts
            written without spaces, where word texts are concatenated directly.
        punctuation: characters after which a cut is preferred. Defaults to the
            module-level ``_punctuation``.

    Returns:
        A new list of segments. Segments that already fit, and long segments
        without word timings, are passed through unchanged. Pieces of split
        segments are new dicts with only ``"text"``, ``"start"`` and ``"end"``.
    """
    if punctuation is None:
        punctuation = _punctuation
    # Separator placed between consecutive words. It is also what gets skipped
    # when cutting right after a punctuation mark. It is empty when use_space is
    # False, so no character of the following word is lost.
    sep = " " if use_space else ""

    new_segments = []
    for segment in segments:
        text = segment["text"]
        if len(text) <= max_length:
            new_segments.append(segment)
            continue

        meta_words = segment.get("words") or []
        if not meta_words:
            # Without word timings there is nowhere to cut: keep the segment
            # whole instead of silently dropping it.
            print(f"WARNING: cannot split segment without word timestamps: {text}")
            new_segments.append(segment)
            continue

        # Prefer the words of the segment text itself, because punctuation may
        # have been removed from the per-word texts.
        if use_space:
            # Split on whitespace; punctuation stays attached to its word
            words = text.split()
        else:
            words = [w["text"] for w in meta_words]
        if len(words) != len(meta_words):
            new_words = [w["text"] for w in meta_words]
            print(f"WARNING: {' '.join(words)} != {' '.join(new_words)}")
            words = new_words

        current_text = ""
        current_start = segment["start"]
        # Best cut seen so far: position just after a punctuation mark, the end
        # time of that word, and the start time of the word following it.
        best_idx = best_end = best_next_start = None
        for i, (word, meta) in enumerate(zip(words, meta_words)):
            current_text_before = current_text
            if current_text:
                current_text += sep
            current_text += word

            if len(current_text) > max_length and current_text_before:
                start = current_start
                if best_idx is not None:
                    # Cut after the last punctuation mark and skip the separator
                    # that follows it.
                    piece = current_text[:best_idx]
                    end = best_end
                    current_text = current_text[best_idx + len(sep):]
                    current_start = best_next_start
                else:
                    # No punctuation: cut right before the overflowing word.
                    piece = current_text_before
                    end = meta_words[i - 1]["end"]
                    current_text = word
                    current_start = meta["start"]

                best_idx = best_end = best_next_start = None
                new_segments.append({"text": piece, "start": start, "end": end})

            # Remember a cut point right after punctuation
            if current_text and current_text[-1] in punctuation:
                best_idx = len(current_text)
                best_end = meta["end"]
                best_next_start = meta_words[i + 1]["start"] if i + 1 < len(meta_words) else None

        if current_text:
            new_segments.append({"text": current_text, "start": current_start, "end": segment["end"]})

    return new_segments

def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    """Render a non-negative time in seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to the nearest millisecond first. The hours field is
    emitted when ``always_include_hours`` is set or when the time is at least
    one hour. ``decimal_marker`` separates seconds from milliseconds ('.' for
    VTT, ',' for SRT).
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    # Peel off each unit from the largest down; divmod yields quotient and rest.
    hh, remainder = divmod(total_ms, 3_600_000)
    mm, remainder = divmod(remainder, 60_000)
    ss, ms = divmod(remainder, 1_000)

    if always_include_hours or hh > 0:
        prefix = f"{hh:02d}:"
    else:
        prefix = ""
    return f"{prefix}{mm:02d}:{ss:02d}{decimal_marker}{ms:03d}"

def write_vtt(result, file):
    """Write ``result`` segments to the text stream ``file`` as WebVTT.

    Each segment needs ``"start"`` and ``"end"`` (seconds) and ``"text"``. The
    text is stripped, and any "-->" inside it is rewritten to "->" so it cannot
    be mistaken for a cue timing separator.
    """
    print("WEBVTT\n", file=file)
    for seg in result:
        begin = format_timestamp(seg['start'])
        finish = format_timestamp(seg['end'])
        caption = seg['text'].strip().replace('-->', '->')
        print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)

def write_srt(result, file):
    """Write ``result`` segments to the text stream ``file`` as SubRip (SRT).

    Cues are numbered from 1. Timestamps always include hours and use ',' as
    the decimal marker. Text is stripped, and "-->" inside it becomes "->".
    """
    for index, seg in enumerate(result, start=1):
        begin = format_timestamp(seg['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(seg['end'], always_include_hours=True, decimal_marker=',')
        caption = seg['text'].strip().replace('-->', '->')
        print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)

def cli():
    """Command-line entry point: convert whisper_timestamped ``.words.json``
    files into SRT and/or VTT subtitles, optionally splitting long segments.

    Arguments: ``input output [--max_length N] [--format {srt,vtt,all}]``.

    - ``input`` is a JSON file or a folder; for a folder, every
      ``*.words.json`` inside is converted.
    - If ``output`` ends in ``.srt``/``.vtt`` (and ``input`` is a file), it is
      written as a single file.
    - Otherwise ``output`` is treated as a folder, created if needed, that
      receives one file per requested format.
    - A non-zero ``--max_length`` splits segments longer than that many
      characters.
    """
    import os
    import argparse

    supported_formats = ["srt", "vtt"]

    def output_stem(filename):
        # Base name for generated subtitles: drop the ".words.json" suffix, or
        # just the last extension for other names. A fixed 11-character slice
        # would mangle names such as "talk.json".
        suffix = ".words.json"
        if filename.endswith(suffix):
            return filename[:-len(suffix)]
        return os.path.splitext(filename)[0]

    parser = argparse.ArgumentParser(
        description='Convert .words.json transcription files (output of whisper_timestamped) to srt or vtt, being able to cut long segments',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('input', type=str, help='Input json file, or input folder')
    parser.add_argument('output', type=str, help='Output srt or vtt file, or output folder')
    parser.add_argument('--max_length', default=200, help='Maximum length of a segment in characters', type=int)
    parser.add_argument('--format', type=str, default="all", help='Output format (if the output is a folder, i.e. not a file with an explicit extension)', choices= supported_formats + ["all"])
    args = parser.parse_args()

    input_is_dir = os.path.isdir(args.input)
    output_is_file = any(args.output.endswith("." + e) for e in supported_formats)

    if input_is_dir or not output_is_file:
        # Output is a folder: one subtitle file per input and requested format.
        if input_is_dir:
            names = [f for f in os.listdir(args.input) if f.endswith(".words.json")]
            input_files = [os.path.join(args.input, f) for f in names]
        else:
            names = [os.path.basename(args.input)]
            input_files = [args.input]
        extensions = [args.format] if args.format != "all" else supported_formats
        output_files = [[os.path.join(args.output, output_stem(f) + "." + e) for e in extensions] for f in names]
        os.makedirs(args.output, exist_ok=True)
    else:
        input_files = [args.input]
        output_files = [[args.output]]
        output_dir = os.path.dirname(args.output)
        # dirname is "" for a bare file name like "out.srt", and os.makedirs("")
        # raises. Only create a directory when the path actually names one.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    for fn, outputs in zip(input_files, output_files):
        with open(fn, "r", encoding="utf-8") as f:
            transcript = json.load(f)
        segments = transcript["segments"]
        if args.max_length:
            # These languages are written without spaces between words, so word
            # texts are concatenated directly when rebuilding split segments.
            language = transcript.get("language")
            use_space = language not in ["zh", "ja", "th", "lo", "my"]
            segments = split_long_segments(segments, args.max_length, use_space=use_space)
        for output in outputs:
            if output.endswith(".srt"):
                writer = write_srt
            elif output.endswith(".vtt"):
                writer = write_vtt
            else:
                raise RuntimeError(f"Unknown output format for {output}")
            with open(output, "w", encoding="utf-8") as f:
                writer(segments, file=f)

if __name__ == "__main__":
    cli()