File size: 24,386 Bytes
55c7989
66e606c
55c7989
 
 
9a42032
b39d769
55c7989
09cabee
 
cf5f1c9
 
 
f1a218d
 
55c7989
 
f1a218d
55c7989
 
f1a218d
 
 
 
55c7989
5f10ef2
55c7989
 
f1a218d
55c7989
5f10ef2
55c7989
 
698fd4d
cf5f1c9
 
da02d45
cf5f1c9
f1a218d
cf5f1c9
6113bd9
 
7baae45
 
55c7989
 
7baae45
55c7989
7baae45
55c7989
cf5f1c9
55c7989
6113bd9
55c7989
 
 
 
 
 
bd773a2
 
6113bd9
9b3283d
 
6113bd9
 
9b3283d
55c7989
 
 
 
 
 
 
9a42032
55c7989
 
 
 
 
 
 
 
 
9b3283d
55c7989
 
 
 
 
 
 
9b3283d
cf5f1c9
55c7989
 
cf5f1c9
6113bd9
55c7989
cf5f1c9
6113bd9
09cabee
55c7989
09cabee
 
 
 
 
 
cf5f1c9
 
55c7989
cf5f1c9
48f0069
cf5f1c9
 
 
 
55c7989
7baae45
cf5f1c9
 
6113bd9
55c7989
 
 
 
 
 
 
 
6113bd9
55c7989
 
6113bd9
55c7989
 
 
6113bd9
 
55c7989
 
 
 
 
9a42032
55c7989
6113bd9
 
25b7b0c
6113bd9
 
 
 
 
 
 
 
9a42032
 
6113bd9
da02d45
25b7b0c
55c7989
9b3283d
55c7989
 
 
 
9b3283d
 
9a42032
6113bd9
55c7989
cf5f1c9
 
55c7989
0d268b7
55c7989
 
 
0d268b7
55c7989
ce7a58b
55c7989
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d268b7
 
 
 
ce7a58b
55c7989
9e140e9
fb45ef4
7fea39b
b9085da
55c7989
7fea39b
 
55c7989
 
 
a161330
 
 
55c7989
 
ce7a58b
 
 
55c7989
 
ce7a58b
55c7989
 
 
ce7a58b
55c7989
ce7a58b
 
a161330
55c7989
 
 
7fea39b
b9085da
7fea39b
55c7989
a161330
 
9518a49
 
 
55c7989
9518a49
 
55c7989
 
9518a49
 
 
55c7989
9518a49
 
55c7989
 
 
7fea39b
55c7989
 
b39d769
55c7989
 
 
fdce050
 
 
 
55c7989
fdce050
9b3283d
5f10ef2
 
b39d769
a161330
b39d769
 
3cc60a3
9b3283d
e75254e
3cc60a3
 
 
 
 
e75254e
 
3cc60a3
f1a218d
 
9b3283d
 
f1a218d
 
7baae45
55c7989
 
7baae45
 
55c7989
 
 
f2c3799
 
7baae45
 
55c7989
 
7baae45
55c7989
 
9b3283d
55c7989
9b3283d
f1a218d
7baae45
f1a218d
7baae45
9b3283d
f1a218d
55c7989
f1a218d
9b3283d
f1a218d
 
 
 
 
 
6113bd9
f1a218d
 
 
 
 
 
 
7baae45
9b3283d
 
7baae45
 
55c7989
9b3283d
 
7baae45
 
f1a218d
7baae45
f1a218d
9b3283d
 
9a42032
7baae45
9a42032
9b3283d
 
9a42032
7baae45
 
 
55c7989
7baae45
9a42032
f1a218d
6113bd9
09cabee
9b3283d
9a42032
9b3283d
e3825f8
 
0b15a1d
e3825f8
55c7989
9b3283d
 
e3825f8
55c7989
e3825f8
 
55c7989
 
0b15a1d
cf5f1c9
66e606c
 
9a42032
66e606c
55c7989
 
9a42032
 
 
66e606c
9a42032
 
 
3f09796
9a42032
 
cf5f1c9
d438792
 
9a42032
1de6702
 
fd3d7cb
cf5f1c9
d438792
 
 
 
1de6702
0a976db
1de6702
0a976db
44646a0
9a42032
 
 
0a976db
259f806
1de6702
44646a0
 
1de6702
 
 
d438792
 
 
55c7989
d438792
55c7989
d438792
 
22b6efb
55c7989
d438792
 
 
 
 
 
 
 
 
55c7989
 
d438792
 
 
 
55c7989
 
d438792
 
 
 
 
fd3d7cb
d438792
 
55c7989
1de6702
55c7989
 
1de6702
 
 
 
 
 
 
 
 
55c7989
9b3283d
 
 
 
 
 
 
259f806
 
9b3283d
55c7989
9b3283d
 
 
55c7989
9b3283d
 
 
 
 
 
55c7989
9b3283d
 
55c7989
9b3283d
 
 
55c7989
9b3283d
 
 
55c7989
9b3283d
 
 
 
 
55c7989
9a42032
9b3283d
 
 
 
55c7989
9a42032
9b3283d
 
 
 
55c7989
9b3283d
e3825f8
 
 
9ad8b62
 
 
 
55c7989
 
 
e3825f8
 
 
55c7989
9b3283d
e3825f8
 
 
9ad8b62
55c7989
 
 
e3825f8
9b3283d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
import os
import re
from copy import copy, deepcopy
from csv import reader
from datetime import timedelta
import logging
import openai


class SRT_segment(object):
    """
    One subtitle entry: a time range, the source-language text and its translation.

    Two construction forms are accepted:

    * ``SRT_segment({'start': float, 'end': float, 'text': str})`` -- times in
      seconds. Times are rounded to the nearest millisecond, and a segment whose
      start and end render identically is padded by 500 ms so it is not empty.
    * ``SRT_segment([index, "HH:MM:SS,mmm --> HH:MM:SS,mmm", text, ...])`` -- an
      entry as read from an ``.srt`` file (only items 1 and 2 are used).
    """

    @staticmethod
    def _split_seconds(t):
        """Split float seconds *t* into ``(whole_seconds, milliseconds)``, rounding to the nearest ms."""
        return divmod(int(round(t * 1000)), 1000)

    @staticmethod
    def _format_timestamp(td: timedelta) -> str:
        """Render *td* as an SRT timestamp ``HH:MM:SS,mmm`` (hours zero-padded to at least 2 digits)."""
        total_seconds = td.days * 86400 + td.seconds
        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{td.microseconds // 1000:03d}"

    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            segment = args[0]
            self.start = segment['start']
            self.end = segment['end']
            # Round instead of truncating: 1.15 is stored as 1.1499999..., and
            # truncation would render it as 149 ms.
            start_sec, self.start_ms = self._split_seconds(segment['start'])
            end_sec, self.end_ms = self._split_seconds(segment['end'])

            if self.start_ms == self.end_ms and start_sec == end_sec:  # avoid empty time stamp
                self.end_ms += 500  # may reach 1000+; timedelta carries it into seconds

            self.start_time = timedelta(seconds=start_sec, milliseconds=self.start_ms)
            self.end_time = timedelta(seconds=end_sec, milliseconds=self.end_ms)
            self.start_time_str = self._format_timestamp(self.start_time)
            self.end_time_str = self._format_timestamp(self.end_time)
            self.source_text = segment['text'].lstrip()
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
            self.translation = ""

        elif isinstance(args[0], list):
            self.source_text = args[0][2]
            self.duration = args[0][1]
            self.start_time_str = self.duration.split(" --> ")[0]
            self.end_time_str = self.duration.split(" --> ")[1]

            # parse the time to float
            # NOTE(review): in this branch start_ms/end_ms hold centiseconds
            # (the ",mmm" field divided by 10), unlike the dict branch where
            # they are milliseconds; start/end in seconds are correct in both.
            self.start_ms = int(self.start_time_str.split(',')[1]) / 10
            self.end_ms = int(self.end_time_str.split(',')[1]) / 10
            start_list = self.start_time_str.split(',')[0].split(':')
            self.start = int(start_list[0]) * 3600 + int(start_list[1]) * 60 + int(start_list[2]) + self.start_ms / 100
            end_list = self.end_time_str.split(',')[0].split(':')
            self.end = int(end_list[0]) * 3600 + int(end_list[1]) * 60 + int(end_list[2]) + self.end_ms / 100
            self.translation = ""

    def merge_seg(self, seg):
        """
        Merge the segment seg with the current segment in place.

        Texts are joined with a single space, and the end time and rendered
        ``duration`` are taken from ``seg``.

        :param seg: Another segment that is strictly next to current one.
        :return: None
        """
        # assert seg.start_ms == self.end_ms, f"cannot merge discontinuous segments."
        self.source_text += f' {seg.source_text}'
        self.translation += f' {seg.translation}'
        self.end_time_str = seg.end_time_str
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __add__(self, other):
        """
        Merge the segment seg with the current segment, and return the new constructed segment.
        No in-place modification.

        :param other: Another segment that is strictly next to added segment.
        :return: new segment spanning from self's start to other's end
        """
        result = deepcopy(self)
        # merge_seg also refreshes `duration`, which previously kept self's end time.
        result.merge_seg(other)
        return result

    def remove_trans_punc(self):
        """
        Replace each punctuation mark listed in ``punc_cn`` in the translation with a space.

        :return: None
        """
        punc_cn = ",。!?"
        translator = str.maketrans(punc_cn, ' ' * len(punc_cn))
        self.translation = self.translation.translate(translator)

    def __str__(self) -> str:
        """SRT body (timestamps + source text) without the index line."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """SRT body (timestamps + translation) without the index line."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """SRT body with source text followed by translation, without the index line."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'


class SRT_script():
    """
    An ordered list of ``SRT_segment`` objects plus the subtitle pipeline steps
    that operate on it: re-segmenting into whole sentences, aligning translated
    lines to segments, splitting long segments, term / spell correction, and
    SRT serialisation.
    """

    def __init__(self, segments) -> None:
        """
        :param segments: iterable of items accepted by ``SRT_segment`` (dicts with
            'start'/'end'/'text' keys, or parsed SRT entry lists).
        """
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """
        Build a script from an ``.srt`` file.

        Entries are separated by blank lines. Each entry is an index line, a
        ``start --> end`` line and zero or more text lines; several text lines
        are joined with a single space. Extra blank lines (e.g. at the end of
        the file) are tolerated, and entries with fewer than two lines are
        skipped with a warning.

        :param path: path to a UTF-8 encoded SRT file.
        :return: a new instance of ``cls``.
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = [line.rstrip() for line in f]

        segments = []
        entry = []
        # The trailing '' sentinel flushes the last entry when the file has no final blank line.
        for line in script_lines + ['']:
            if line:
                entry.append(line)
                continue
            if len(entry) >= 2:
                # [index, "start --> end", text]; an entry without text lines gets ''.
                segments.append([entry[0], entry[1], ' '.join(entry[2:])])
            elif entry:
                logging.warning("skipping malformed SRT entry in %s: %r", path, entry)
            entry = []

        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """
        Merge the segments at the given indices into a single new segment.

        The stored segments are not modified: the first one is deep-copied and
        the others are merged into that copy in order.

        :param idx_list: non-empty list of indices into ``self.segments``.
        :return: merged segment whose ``duration`` spans from the first
            segment's start to the last segment's end.
        :raises NotImplementedError: if ``idx_list`` is empty.
        """
        if not idx_list:
            raise NotImplementedError('Empty idx_list')
        seg_result = deepcopy(self.segments[idx_list[0]])
        for idx in idx_list[1:]:
            # merge_seg updates end time *and* the rendered duration string.
            seg_result.merge_seg(self.segments[idx])
        return seg_result

    def form_whole_sentence(self):
        """
        Concatenate or Strip sentences and reconstruct segments list. This is because of
        improper segmentation from openai-whisper.

        A segment closes a sentence when its text ends with '.', '!' or '?', is
        longer than 10 characters and does not contain 'vs.'. Segments after
        the last closing segment are kept as one final group.

        :return: None
        """
        logging.info("Forming whole sentences...")
        merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            text = seg.source_text
            # endswith() is safe on empty text, where text[-1] would raise.
            if text.endswith(('.', '!', '?')) and len(text) > 10 and 'vs.' not in text:
                merge_list.append(sentence)
                sentence = []
        if sentence:
            # Previously dropped: trailing segments without sentence-final punctuation.
            merge_list.append(sentence)

        segments = []
        for idx_list in merge_list:
            if len(idx_list) > 1:
                logging.info("merging segments: %s", idx_list)
            segments.append(self.merge_segs(idx_list))

        self.segments = segments

    def remove_trans_punctuation(self):
        """
        Post-process: remove all punc after translation and split
        :return: None
        """
        for seg in self.segments:
            seg.remove_trans_punc()
        logging.info("Removed punctuation in translation.")

    def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
        """
        Distribute a translated chunk over the segments covered by ``id_range``.

        ``translate`` is split on blank lines, one line per segment. If that
        yields fewer lines than segments, gpt-3.5-turbo is asked (up to 5
        iterations) to regroup the text into the right number of lines; the
        attempt is logged to a CSV under ``./logs`` and the resulting lines are
        printed. Lines containing "Note:" are then discarded, and a "label:"
        prefix (text up to the first ':') is stripped from each line.

        :param translate: translated text for the range.
        :param id_range: ``(start, end)`` segment ids, 1-based and inclusive.
        :param model: currently unused -- the regroup request is pinned to "gpt-3.5-turbo".
        :param video_name: written to ``./logs/log_name.csv`` when no link is given.
        :param video_link: if truthy, written to ``./logs/log_link.csv`` instead.
        :return: None; sets ``translation`` on the affected segments in place.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        target = end_seg_id - start_seg_id + 1
        max_api_attempts = 5  # per regroup iteration; the old loop retried forever

        def inner_func(target, input_str):
            # Ask the model to regroup `input_str` into exactly `target` lines.
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system",
                     "content": "你的任务是按照要求合并或拆分句子到指定行数,你需要尽可能保证句意,但必要时可以将一句话分为两行输出"},
                    {"role": "system", "content": "注意:你只需要输出处理过的中文句子,如果你要输出序号,请使用冒号隔开"},
                    {"role": "user", "content": '请将下面的句子拆分或组合为{}句:\n{}'.format(target, input_str)}
                ],
                temperature=0.15
            )
            return response['choices'][0]['message']['content'].strip()

        lines = translate.split('\n\n')
        if len(lines) < target:
            import time

            count = 0
            while count < 5 and len(lines) != target:
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))
                for attempt in range(1, max_api_attempts + 1):
                    try:
                        translate = inner_func(target, translate)
                        break
                    except Exception as e:  # any client/network error: report and retry
                        print("An error has occurred during solving unmatched lines:", e)
                        if attempt < max_api_attempts:
                            print("Retrying...")
                            time.sleep(attempt)  # simple linear backoff
                else:
                    print("Giving up after {} failed API attempts".format(max_api_attempts))
                    break
                # The model answers one sentence per line; blank lines are separators, not sentences.
                lines = [line for line in translate.split('\n') if line.strip()]

            solved = len(lines) >= target
            if not solved:
                print("Failed Solving unmatched lines, Manually parse needed")

            self._log_unmatched_lines(id_range, count, solved, video_name, video_link)
            print(lines)

        # Discard commentary lines the model sometimes adds (e.g. "Note: ...").
        lines = [line for line in lines if "Note:" not in line]
        for seg, line in zip(self.segments[start_seg_id - 1:end_seg_id], lines):
            # Strip a "label:" prefix such as "1:". The old `":" or ":" or "."`
            # expression always evaluated to ":"; unlike the old split(...)[1],
            # keep everything after the first colon. No colon -> use the line as-is.
            _, sep, rest = line.partition(":")
            seg.translation = rest if sep else line

    def _log_unmatched_lines(self, id_range, iterations, solved, video_name, video_link):
        """
        Append one CSV row describing an unmatched-lines repair attempt.

        Rows go to ``./logs/log_link.csv`` when ``video_link`` is truthy, else to
        ``./logs/log_name.csv``; a header row is written when the file is new.
        ``csv.writer`` quotes fields containing commas, e.g. ``str(id_range)``.
        """
        from csv import writer

        os.makedirs("./logs", exist_ok=True)
        if video_link:
            log_file, id_column, video_id = "./logs/log_link.csv", "video_link", video_link
        else:
            log_file, id_column, video_id = "./logs/log_name.csv", "video_name", video_name
        log_exist = os.path.exists(log_file)
        with open(log_file, "a", encoding="utf-8", newline="") as log:
            csv_writer = writer(log, lineterminator="\n")
            if not log_exist:
                csv_writer.writerow(["range_of_text", "iterations_solving", "solved", "file_length", id_column])
            csv_writer.writerow([str(id_range), iterations, solved, len(self.segments), video_id])

    def split_seg(self, seg, text_threshold, time_threshold):
        """
        Recursively split ``seg`` into shorter segments.

        Source text is cut at its middle comma (or middle space when it has no
        comma, else at 0); the translation is cut at its middle comma (or its
        midpoint). The time span is divided in proportion to the translation
        cut position. A piece is split again while its translation is longer
        than ``text_threshold`` characters and it lasts more than
        ``time_threshold`` seconds.

        :param seg: segment to split; a leading ", " / "," may be stripped from it in place.
        :param text_threshold: translation length (characters) above which a piece is split.
        :param time_threshold: duration (seconds) a piece must exceed to be split.
        :return: list of segments in time order.
        """
        def middle(positions):
            # Median position; for an even count, the lower of the two middles.
            return positions[(len(positions) - 1) // 2]

        # ignore the initial comma to solve the recursion problem
        if len(seg.source_text) > 2 and seg.source_text.startswith(', '):
            seg.source_text = seg.source_text[2:]
        if seg.translation.startswith(','):
            seg.translation = seg.translation[1:]

        source_text = seg.source_text
        translation = seg.translation
        if len(translation) < 2:
            # Nothing left to divide; splitting would recurse without shrinking.
            return [seg]

        src_commas = [m.start() for m in re.finditer(',', source_text)]
        if src_commas:
            src_split_idx = middle(src_commas)
        else:
            src_spaces = [m.start() for m in re.finditer(' ', source_text)]
            src_split_idx = middle(src_spaces) if src_spaces else 0

        trans_commas = [m.start() for m in re.finditer(',', translation)]
        trans_split_idx = middle(trans_commas) if trans_commas else len(translation) // 2
        if trans_split_idx == 0:
            # A comma at index 0 gives an empty first half and an unchanged second
            # half of the same duration -> infinite recursion. Use the midpoint.
            trans_split_idx = len(translation) // 2

        # split the time duration based on text length
        time_split_ratio = trans_split_idx / (len(translation) - 1)
        split_time = seg.start + (seg.end - seg.start) * time_split_ratio

        seg1 = SRT_segment({'text': source_text[:src_split_idx], 'start': seg.start, 'end': split_time})
        seg1.translation = translation[:trans_split_idx]
        seg2 = SRT_segment({'text': source_text[src_split_idx:], 'start': split_time, 'end': seg.end})
        seg2.translation = translation[trans_split_idx:]

        result_list = []
        for piece in (seg1, seg2):
            if len(piece.translation) > text_threshold and (piece.end - piece.start) > time_threshold:
                result_list += self.split_seg(piece, text_threshold, time_threshold)
            else:
                result_list.append(piece)
        return result_list

    def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
        """
        Split every segment whose translation is longer than ``text_threshold``
        characters and whose duration exceeds ``time_threshold`` seconds.

        :return: None; replaces ``self.segments``.
        """
        logging.info("performing check_len_and_split")
        segments = []
        for i, seg in enumerate(self.segments):
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                logging.info("splitting segment {} in to {} parts".format(i+1, len(seg_list)))
                segments += seg_list
            else:
                segments.append(seg)

        self.segments = segments
        logging.info("check_len_and_split finished")

    def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
        """
        DEPRECATED. Like ``check_len_and_split`` but only for the 1-based,
        inclusive segment range ``range``.

        :return: number of segments added by splitting.
        """
        start_seg_id = range[0]
        end_seg_id = range[1]
        extra_len = 0
        segments = []
        for seg in self.segments[start_seg_id - 1:end_seg_id]:
            if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                seg_list = self.split_seg(seg, text_threshold, time_threshold)
                segments += seg_list
                extra_len += len(seg_list) - 1
            else:
                segments.append(seg)

        self.segments[start_seg_id - 1:end_seg_id] = segments
        return extra_len

    def correct_with_force_term(self):
        """
        Replace known English terms in each segment's source text with the
        forced translations from ``./finetune_data/dict_enzh.csv``.

        Each CSV row is ``term,translation``; rows with fewer than two fields
        are skipped. Longer terms are applied first. A term matches
        case-insensitively as a whole word, optionally followed by "s"/"es".
        """
        ## force term correction
        logging.info("performing force term correction")
        with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
            term_enzh_dict = {rows[0]: rows[1] for rows in reader(f) if len(rows) >= 2}

        keywords = sorted(term_enzh_dict, key=len, reverse=True)

        for word in keywords:
            replacement = term_enzh_dict[word]
            escaped = re.escape(word)  # terms may contain regex metacharacters ("+", ".")
            # (?<!\w)/(?!\w) stop matches inside longer words and also work for
            # terms that start or end with a non-word character.
            pattern = re.compile(fr"(?<!\w)(?:{escaped}es|{escaped}s?)(?!\w)", flags=re.IGNORECASE)
            for i, seg in enumerate(self.segments):
                if word not in seg.source_text.lower():
                    continue
                # Callable replacement: the term is inserted literally (no backslash escapes).
                seg.source_text, n_subs = pattern.subn(lambda _m, r=replacement: r, seg.source_text)
                if n_subs:
                    logging.info("replace term: " + word + " --> " + replacement + " in time stamp {}".format(i+1))
                    logging.info("source text becomes: " + seg.source_text)

    def spell_check_term(self):
        """
        Replace misspelled words with close matches from the term word list
        ``./finetune_data/dict_freq.txt``.

        A word is replaced only when neither the en_US dictionary nor the term
        list accepts it and the top term suggestion is within a Levenshtein
        distance below a quarter of the two words' combined length.
        Known bug: "I've" is replaced because "i've" is not in the term list.
        """
        logging.info("performing spell check")
        import enchant
        en_dict = enchant.Dict('en_US')
        term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')

        for seg in self.segments:
            ready_words = seg.source_text.split(" ")
            for i, word in enumerate(ready_words):
                real_word, pos = self.get_real_word(word)
                if not real_word:
                    continue  # empty tokens (double spaces, lone punctuation): enchant rejects ""
                if en_dict.check(word[:pos]) or term_spellDict.check(real_word):
                    continue
                suggest = term_spellDict.suggest(real_word)
                if not suggest:
                    continue
                distance = enchant.utils.levenshtein(real_word, suggest[0])
                if distance < (len(real_word) + len(suggest[0])) / 4:  # relax spell check
                    logging.info(real_word + "\t" + suggest[0] + "\t" + str(distance) + '\n')
                    ready_words[i] = word.replace(word[:pos], suggest[0])
            seg.source_text = " ".join(ready_words)

    def spell_correction(self, word: str, arg: int):
        """
        Correct a single word.

        :param word: word, possibly with trailing punctuation (see ``get_real_word``).
        :param arg: 0 = replace with the forced translation from
            ``finetune_data/dict_enzh.csv``; 1 = replace with the top suggestion
            from ``./finetune_data/dict_freq.txt`` if en_US rejects the word.
            Any other value prints a message and returns ``word`` unchanged.
        :return: the corrected word, trailing punctuation preserved.
        """
        if arg not in (0, 1):
            print('only 0 or 1 for argument')
            return word

        real_word, pos = self.get_real_word(word)
        new_word = word
        if arg == 0:  # term translate mode
            with open("finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
                term_enzh_dict = {rows[0]: rows[1] for rows in reader(f) if len(rows) >= 2}
            if real_word in term_enzh_dict:
                new_word = word.replace(word[:pos], term_enzh_dict[real_word])
        elif real_word:  # term spell check mode; enchant rejects empty strings
            import enchant
            en_dict = enchant.Dict('en_US')
            term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
            if not en_dict.check(real_word):
                suggestions = term_spellDict.suggest(real_word)
                if suggestions:  # relax spell check
                    new_word = word.replace(word[:pos], suggestions[0])
        return new_word

    def get_real_word(self, word: str):
        """
        Strip one trailing ".\\n", or one trailing '.', newline, ',', '!' or '?'.

        :return: ``(lower-cased stripped word, length of the stripped word)``.
        """
        if word[-2:] == ".\n":
            real_word = word[:-2].lower()
            n = -2
        elif word[-1:] in [".", "\n", ",", "!", "?"]:
            real_word = word[:-1].lower()
            n = -1
        else:
            real_word = word.lower()
            n = 0
        return real_word, len(word) + n

    ## WRITE AND READ FUNCTIONS ##

    def get_source_only(self):
        """Return all source texts, each followed by three newlines."""
        return "".join(f'{seg.source_text}\n\n\n' for seg in self.segments)

    def reform_src_str(self):
        """Return the script as SRT text containing the source text."""
        return "".join(f'{i}\n{str(seg)}' for i, seg in enumerate(self.segments, start=1))

    def reform_trans_str(self):
        """Return the script as SRT text containing the translation."""
        return "".join(f'{i}\n{seg.get_trans_str()}' for i, seg in enumerate(self.segments, start=1))

    def form_bilingual_str(self):
        """Return the script as SRT text with source text and translation per entry."""
        return "".join(f'{i}\n{seg.get_bilingual_str()}' for i, seg in enumerate(self.segments, start=1))

    def write_srt_file_src(self, path: str):
        """Write the source-text SRT to ``path`` (UTF-8)."""
        logging.info("writing to " + path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated SRT to ``path`` (UTF-8)."""
        logging.info("writing to " + path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to ``path`` (UTF-8)."""
        logging.info("writing to " + path)
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def realtime_write_srt(self, path, range, length, idx):
        """
        DEPRECATED. Append translated entries with 0-based positions
        ``range[0]-1 .. range[1]+length-1`` to ``path``, numbered ``position + idx``.
        """
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < range[0] - 1:
                    continue
                if i >= range[1] + length:
                    break
                f.write(f'{i + idx}\n')
                f.write(seg.get_trans_str())

    def realtime_bilingual_write_srt(self, path, range, length, idx):
        """
        DEPRECATED. Append bilingual entries with 0-based positions
        ``range[0]-1 .. range[1]+length-1`` to ``path``, numbered ``position + idx``.
        """
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments):
                if i < range[0] - 1:
                    continue
                if i >= range[1] + length:
                    break
                f.write(f'{i + idx}\n')
                f.write(seg.get_bilingual_str())