File size: 17,211 Bytes
09cabee
66e606c
7baae45
66e606c
b39d769
e3825f8
09cabee
 
cf5f1c9
 
 
f1a218d
 
 
 
 
 
 
 
 
 
 
 
5f10ef2
7baae45
f1a218d
 
5f10ef2
7baae45
698fd4d
cf5f1c9
 
da02d45
cf5f1c9
f1a218d
cf5f1c9
6113bd9
 
7baae45
 
 
 
 
 
 
 
cf5f1c9
 
6113bd9
 
 
 
 
 
 
cf5f1c9
6113bd9
cf5f1c9
 
6113bd9
cf5f1c9
 
6113bd9
09cabee
 
 
 
 
 
 
cf5f1c9
 
 
 
 
 
 
 
 
 
7baae45
cf5f1c9
 
6113bd9
 
 
 
 
 
 
 
 
 
 
5f10ef2
6113bd9
 
 
 
 
 
 
 
 
 
 
 
da02d45
6113bd9
ce7a58b
 
6113bd9
1e2d254
cf5f1c9
 
b39d769
ce7a58b
 
1e2d254
ce7a58b
 
 
 
 
 
 
 
 
 
9e140e9
fb45ef4
7fea39b
 
 
 
 
 
 
 
 
 
 
 
 
ce7a58b
 
 
 
 
 
 
 
 
 
 
7fea39b
 
 
 
 
b39d769
 
 
 
e48cdd5
9e140e9
e48cdd5
cf5f1c9
fdce050
 
 
 
 
 
5f10ef2
 
b39d769
 
 
 
882925c
09cabee
6113bd9
e3825f8
6113bd9
f1a218d
 
 
 
7baae45
 
 
 
 
 
 
e3825f8
7baae45
 
 
f1a218d
7baae45
f1a218d
7baae45
f1a218d
 
 
 
 
 
 
 
 
6113bd9
f1a218d
 
 
 
 
 
 
7baae45
 
 
 
 
 
 
 
 
 
f1a218d
7baae45
f1a218d
 
9ad8b62
6113bd9
7baae45
 
 
 
 
 
 
 
 
f1a218d
6113bd9
09cabee
9ad8b62
e3825f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf5f1c9
 
 
6113bd9
 
cf5f1c9
 
 
 
 
6113bd9
 
cf5f1c9
 
 
 
 
6113bd9
 
cf5f1c9
 
 
 
 
6113bd9
 
cf5f1c9
 
 
 
09cabee
cf5f1c9
 
09cabee
 
cf5f1c9
 
 
 
 
 
 
 
 
 
66e606c
 
 
 
cf5f1c9
66e606c
fd3d7cb
1de6702
66e606c
 
 
d438792
66e606c
 
1de6702
 
 
 
 
 
d438792
cf5f1c9
 
d438792
 
 
 
1de6702
 
fd3d7cb
cf5f1c9
d438792
 
 
 
1de6702
03b6946
1de6702
 
 
 
 
 
d438792
 
 
 
 
 
 
 
22b6efb
d438792
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd3d7cb
d438792
 
 
1de6702
 
 
 
 
 
 
 
 
 
 
 
e3825f8
 
 
 
 
 
9ad8b62
 
 
 
 
 
e3825f8
 
 
 
 
 
 
 
9ad8b62
 
 
e3825f8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
from datetime import timedelta
from csv import reader
from datetime import datetime
import re
import openai
from collections import deque

class SRT_segment(object):
    """A single subtitle cue: timing, source text and translation.

    Construct it from either:

    * a dict with ``'start'`` and ``'end'`` (seconds, int or float) and a
      ``'text'`` key (leading whitespace is stripped), or
    * a list shaped like one parsed SRT cue,
      ``[index, "HH:MM:SS,mmm --> HH:MM:SS,mmm", text, ...]``; only items 1
      and 2 are read and the timing strings are kept verbatim.

    Both constructors set: ``start``/``end`` (seconds), ``start_ms``/``end_ms``
    (millisecond part of each timestamp, 0-999), ``start_time``/``end_time``
    (``timedelta``), ``start_time_str``/``end_time_str`` (SRT timestamps),
    ``duration`` (the SRT timing line), ``source_text`` and ``translation``
    (initially empty).
    """

    # Added to the end of a cue whose start and end timestamps coincide, so the
    # written subtitle never has an empty (zero-length) time span.
    _MIN_DURATION_MS = 500

    def __init__(self, *args) -> None:
        if isinstance(args[0], dict):
            self._init_from_dict(args[0])
        elif isinstance(args[0], list):
            self._init_from_srt_cue(args[0])

    def _init_from_dict(self, segment: dict) -> None:
        """Fill the cue from a {'start', 'end', 'text'} dict (times in seconds)."""
        self.start = segment['start']
        self.end = segment['end']
        start_total_ms = self._seconds_to_ms(self.start)
        end_total_ms = self._seconds_to_ms(self.end)
        if start_total_ms == end_total_ms:  # avoid empty time stamp
            end_total_ms += self._MIN_DURATION_MS
        self._set_times(start_total_ms, end_total_ms)
        self.source_text = segment['text'].lstrip()
        self.translation = ""

    def _init_from_srt_cue(self, cue: list) -> None:
        """Fill the cue from a parsed SRT block ([index, timing line, text, ...])."""
        self.source_text = cue[2]
        self.duration = cue[1]
        parts = self.duration.split(" --> ")
        self.start_time_str = parts[0]
        self.end_time_str = parts[1]
        start_total_ms = self._parse_timestamp(self.start_time_str)
        end_total_ms = self._parse_timestamp(self.end_time_str)
        self.start = start_total_ms / 1000
        self.end = end_total_ms / 1000
        self.start_ms = start_total_ms % 1000
        self.end_ms = end_total_ms % 1000
        self.start_time = timedelta(milliseconds=start_total_ms)
        self.end_time = timedelta(milliseconds=end_total_ms)
        self.translation = ""

    def _set_times(self, start_total_ms: int, end_total_ms: int) -> None:
        """Derive every timing attribute from absolute millisecond offsets."""
        self.start_ms = start_total_ms % 1000
        self.end_ms = end_total_ms % 1000
        self.start_time = timedelta(milliseconds=start_total_ms)
        self.end_time = timedelta(milliseconds=end_total_ms)
        self.start_time_str = self._format_timestamp(start_total_ms)
        self.end_time_str = self._format_timestamp(end_total_ms)
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    @staticmethod
    def _seconds_to_ms(seconds) -> int:
        """Convert seconds to whole milliseconds, rounding away float noise."""
        return int(round(seconds * 1000))

    @staticmethod
    def _format_timestamp(total_ms: int) -> str:
        """Format absolute milliseconds as an SRT ``HH:MM:SS,mmm`` timestamp."""
        hours, rest = divmod(total_ms, 3_600_000)
        minutes, rest = divmod(rest, 60_000)
        seconds, millis = divmod(rest, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

    @staticmethod
    def _parse_timestamp(stamp: str) -> int:
        """Parse ``HH:MM:SS,mmm`` (``.`` also accepted) into absolute milliseconds."""
        clock, _, millis = stamp.strip().replace('.', ',').partition(',')
        hours, minutes, seconds = clock.split(':')[:3]
        return ((int(hours) * 60 + int(minutes)) * 60 + int(seconds)) * 1000 + int(millis or 0)

    def merge_seg(self, seg):
        """Append ``seg`` to this cue in place.

        Source texts are joined with a single space when neither side already
        has whitespace at the seam; translations are concatenated directly.
        The end time (all representations) is taken from ``seg``.
        """
        head, tail = self.source_text, seg.source_text
        needs_space = bool(head) and bool(tail) and not head[-1].isspace() and not tail[0].isspace()
        self.source_text = head + (" " if needs_space else "") + tail
        self.translation += seg.translation
        self.end = seg.end
        self.end_ms = seg.end_ms
        self.end_time = seg.end_time
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the SRT body (timing + translation) without the index line."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the SRT body with source text followed by translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'

class SRT_script():
    """An ordered list of SRT_segment cues plus helpers to process them.

    Covers .srt parsing/writing, merging cues into whole sentences, splitting
    over-long cues, GPT-assisted alignment of translated lines, and term
    replacement / spell correction of the source text.
    """

    # Maximum GPT round-trips used to fix a translation line-count mismatch.
    _MAX_ALIGN_ATTEMPTS = 5
    # Maximum API calls per round-trip before giving up on that round-trip.
    _MAX_API_RETRIES = 5

    def __init__(self, segments) -> None:
        """Wrap each raw segment (dict or parsed SRT cue list) in an SRT_segment."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Build a script from the .srt file at ``path``.

        Cues are separated by blank lines. Each cue is an index line, a
        ``start --> end`` timing line and zero or more text lines (multi-line
        text is joined with newlines). Blocks without a timing line are skipped.
        """
        # utf-8-sig reads plain UTF-8 unchanged and also drops a leading BOM.
        with open(path, 'r', encoding="utf-8-sig") as f:
            script_lines = f.read().splitlines()

        segments = []
        block = []
        for line in script_lines + [""]:  # trailing "" flushes the last cue
            if line.strip():
                block.append(line)
                continue
            if len(block) >= 2 and "-->" in block[1]:
                segments.append([block[0], block[1], "\n".join(block[2:]), ""])
            block = []

        return cls(segments)

    def merge_segs(self, idx_list) -> "SRT_segment":
        """Merge the segments at ``idx_list`` into the first one and return it.

        The first segment is modified in place; ``self.segments`` itself is
        not changed here.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self, terminators=('.',)):
        """Merge consecutive segments so each one ends a sentence.

        A segment closes a sentence when its source text (ignoring trailing
        whitespace) ends with one of ``terminators``. Segments after the last
        terminator are kept as one final merged segment.
        """
        ends = terminators if isinstance(terminators, str) else tuple(terminators)
        merge_list = []  # groups of indices, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.rstrip().endswith(ends):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            # Trailing segments with no terminator must not be dropped.
            merge_list.append(sentence)

        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate: str, id_range: tuple, model):
        """Distribute ``translate`` over segments ``id_range`` (1-based, inclusive).

        ``translate`` holds one translated line per segment, separated by blank
        lines. If it has fewer lines than segments, the chat model ``model`` is
        asked (up to ``_MAX_ALIGN_ATTEMPTS`` times) to re-split it. Lines
        containing "(Note:" are discarded; text after the first ":" of each
        line (if any) becomes the segment's translation.
        """
        start_seg_id, end_seg_id = id_range[0], id_range[1]
        target = self.segments[start_seg_id - 1:end_seg_id]
        expected = end_seg_id - start_seg_id + 1

        lines = translate.split('\n\n')
        if len(lines) < expected:
            count = 0
            while count < self._MAX_ALIGN_ATTEMPTS and len(lines) != expected:
                count += 1
                print("Solving Unmatched Lines|iteration {}".format(count))
                prompt = self._build_alignment_prompt(target, translate)
                reply = self._request_alignment(prompt, model, self._MAX_API_RETRIES)
                if reply is None:
                    break  # API kept failing; keep the best translation we have
                translate = reply
                lines = translate.split('\n\n')
            if len(lines) < expected:
                print("Failed Solving unmatched lines, Manually parse needed")
            print(lines)

        self._assign_translations(target, lines)

    @staticmethod
    def _build_alignment_prompt(segs, translate):
        """Number each source sentence, then append the current translation."""
        numbered = "".join(f"Sentence {i}: {seg.source_text}\n" for i, seg in enumerate(segs, 1))
        return "\n" + numbered + translate

    @staticmethod
    def _request_alignment(input_str, model, max_retries=5):
        """Ask the chat model to realign lines; return None if every call fails."""
        import time

        for attempt in range(1, max_retries + 1):
            try:
                response = openai.ChatCompletion.create(
                    model=model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant that help calibrates English to Chinese subtitle translations in starcraft2."},
                        {"role": "system", "content": "You are provided with a translated Chinese transcript; you must modify or split the Chinese sentence to match the meaning and the number of the English transcript exactly one by one. You must not merge ANY Chinese lines, you can only split them but the total Chinese lines MUST equals to number of English lines."},
                        {"role": "system", "content": "There is no need for you to add any comments or notes, and do not modify the English transcript."},
                        {"role": "user", "content": 'You are given the English transcript and line number, your task is to merge or split the Chinese to match the exact number of lines in English transcript, no more no less. For example, if there are more Chinese lines than English lines, merge some the Chinese lines to match the number of English lines. If Chinese lines is less than English lines, split some Chinese lines to match the english lines: "{}"'.format(input_str)}
                    ],
                    temperature=0.7
                )
                return response['choices'][0]['message']['content'].strip()
            except Exception as e:  # API boundary: error types vary by client version
                print("An error has occurred during solving unmatched lines:", e)
                if attempt < max_retries:
                    print("Retrying...")
                    time.sleep(attempt)  # linear backoff instead of a tight loop
        return None

    @staticmethod
    def _assign_translations(segs, lines):
        """Assign ``lines`` to ``segs`` pairwise, skipping "(Note:" lines."""
        kept = [line for line in lines if "(Note:" not in line]
        for seg, line in zip(segs, kept):
            # Drop a leading "label:" but keep any later colons (e.g. "10:30").
            _, colon, text = line.partition(":")
            seg.translation = text.strip() if colon else line

    @staticmethod
    def _split_index(text, delimiters, after):
        """Index of the middle delimiter in ``text`` (just past it if ``after``).

        Returns None when there is no delimiter or the split would leave one
        side empty.
        """
        positions = [i for i, ch in enumerate(text) if ch in delimiters]
        if not positions:
            return None
        idx = positions[(len(positions) - 1) // 2] + (1 if after else 0)
        return idx if 0 < idx < len(text) else None

    def split_seg(self, seg, threshold=500):
        """Split ``seg`` in two (recursively) until translations fit ``threshold``.

        Source text is split after its middle comma, else at its middle space,
        else at its midpoint; translation after its middle "," or "，", else at
        its midpoint. The time span is divided evenly. Returns the new segments.
        """
        source_text = seg.source_text
        translation = seg.translation
        src_idx = (self._split_index(source_text, ",", after=True)
                   or self._split_index(source_text, " ", after=False)
                   or len(source_text) // 2)
        trans_idx = self._split_index(translation, ",，", after=True) or len(translation) // 2

        midpoint = seg.start + (seg.end - seg.start) / 2
        halves = [
            (source_text[:src_idx], translation[:trans_idx], seg.start, midpoint),
            (source_text[src_idx:], translation[trans_idx:].lstrip(), midpoint, seg.end),
        ]

        result_list = []
        for text, trans, start, end in halves:
            part = SRT_segment({'text': text, 'start': start, 'end': end})
            part.translation = trans
            # Recurse only while still too long AND the split actually shrank it,
            # otherwise an unsplittable translation would recurse forever.
            if threshold < len(trans) < len(translation):
                result_list += self.split_seg(part, threshold)
            else:
                result_list.append(part)
        return result_list

    def _split_long(self, segs, threshold):
        """Return ``segs`` with every over-long translation split via split_seg."""
        result = []
        for seg in segs:
            if len(seg.translation) > threshold:
                result += self.split_seg(seg, threshold)
            else:
                result.append(seg)
        return result

    def check_len_and_split(self, threshold=30000):
        """Split every segment whose translation is longer than ``threshold``."""
        self.segments = self._split_long(self.segments, threshold)

    def check_len_and_split_range(self, range, threshold=30000):
        """Split long segments within ``range`` (1-based, inclusive) in place.

        Returns the number of segments that now occupy that range.
        """
        start_seg_id, end_seg_id = range[0], range[1]
        segments = self._split_long(self.segments[start_seg_id - 1:end_seg_id], threshold)
        self.segments[start_seg_id - 1:end_seg_id] = segments
        return len(segments)

    def get_source_only(self):
        """Return all source texts as numbered "SENTENCE n:" paragraphs."""
        return "".join(f'SENTENCE {i}: {seg.source_text}\n\n\n' for i, seg in enumerate(self.segments, 1))

    def reform_src_str(self):
        """Render the script as SRT text using the source text."""
        return "".join(f'{i}\n{seg}' for i, seg in enumerate(self.segments, 1))

    def reform_trans_str(self):
        """Render the script as SRT text using the translations."""
        return "".join(f'{i}\n{seg.get_trans_str()}' for i, seg in enumerate(self.segments, 1))

    def form_bilingual_str(self):
        """Render the script as SRT text with source and translation per cue."""
        return "".join(f'{i}\n{seg.get_bilingual_str()}' for i, seg in enumerate(self.segments, 1))

    def write_srt_file_src(self, path: str):
        """Write the source-language SRT to ``path`` (overwrites)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated SRT to ``path`` (overwrites)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to ``path`` (overwrites)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    @staticmethod
    def _load_term_dict(path):
        """Load a CSV mapping column 0 -> column 1; rows with <2 columns are skipped.

        Presumably English term -> Chinese term, per the dict_enzh file name.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return {row[0]: row[1] for row in reader(f) if len(row) >= 2}

    def correct_with_force_term(self, dict_path="./finetune_data/dict_enzh.csv"):
        """Replace source words found in the term dictionary at ``dict_path``.

        Lookup uses the lower-cased word without trailing punctuation (see
        get_real_word); the punctuation is kept after the replacement.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        term_enzh_dict = self._load_term_dict(dict_path)

        for seg in self.segments:
            words = seg.source_text.split(" ")
            for i, word in enumerate(words):
                real_word, pos = self.get_real_word(word)
                term = term_enzh_dict.get(real_word)
                if term is not None:
                    words[i] = term + word[pos:]
            seg.source_text = " ".join(words)

    def spell_check_term(self):
        """Replace words flagged by the en_US dictionary with a project term.

        The first suggestion from ./finetune_data/dict_freq.txt replaces the
        word (trailing punctuation kept); words without a suggestion are left
        unchanged. Known bug: I've will be replaced because i've is not in the dict.
        """
        import enchant  # only needed for spell checking
        en_dict = enchant.Dict('en_US')
        term_spell_dict = enchant.PyPWL('./finetune_data/dict_freq.txt')

        for seg in self.segments:
            words = seg.source_text.split(" ")
            for i, word in enumerate(words):
                real_word, pos = self.get_real_word(word)
                stem = word[:pos]
                # Skip empty tokens (e.g. from double spaces); presumably enchant
                # rejects empty strings — confirm.
                if not stem or en_dict.check(stem):
                    continue
                suggestions = term_spell_dict.suggest(real_word)
                if suggestions:  # relax spell check
                    words[i] = suggestions[0] + word[pos:]
            seg.source_text = " ".join(words)

    def spell_correction(self, word: str, arg: int):
        """Correct a single word.

        ``arg == 0``: term translate mode (finetune_data/dict_enzh.csv).
        ``arg == 1``: term spell check mode (enchant + dict_freq.txt).
        Any other ``arg`` prints a message and returns ``word`` unchanged.
        """
        if arg not in (0, 1):
            print('only 0 or 1 for argument')
            return word

        real_word, pos = self.get_real_word(word)
        if arg == 0:  # term translate mode
            term = self._load_term_dict("finetune_data/dict_enzh.csv").get(real_word)
            return word if term is None else term + word[pos:]

        # arg == 1: term spell check mode
        import enchant
        en_dict = enchant.Dict('en_US')
        term_spell_dict = enchant.PyPWL('./finetune_data/dict_freq.txt')
        if real_word and not en_dict.check(real_word):
            suggestions = term_spell_dict.suggest(real_word)
            if suggestions:  # relax spell check
                return suggestions[0] + word[pos:]
        return word

    def get_real_word(self, word: str):
        """Return (lower-cased word without trailing punctuation, its length).

        Strips a trailing ".\\n", or one trailing ".", "\\n", ",", "!" or "?".
        """
        if word.endswith(".\n"):
            cut = 2
        elif word[-1:] in (".", "\n", ",", "!", "?"):
            cut = 1
        else:
            cut = 0
        pos = len(word) - cut
        return word[:pos].lower(), pos

    def _append_range(self, path, seg_range, idx, render):
        """Append cues ``seg_range`` (1-based, inclusive) to ``path``.

        Each cue is numbered with its 0-based position in self.segments plus
        ``idx`` and rendered with ``render(seg)``.
        """
        first = max(seg_range[0] - 1, 0)
        stop = max(seg_range[1], 0)
        with open(path, "a", encoding='utf-8') as f:
            for i, seg in enumerate(self.segments[first:stop], start=first):
                f.write(f'{i+idx}\n')
                f.write(render(seg))

    def realtime_write_srt(self, path, range, length, idx):
        """Append the translated cues in ``range`` to ``path``; ``length`` is unused."""
        self._append_range(path, range, idx, lambda seg: seg.get_trans_str())

    def realtime_bilingual_write_srt(self, path, range, length, idx):
        """Append the bilingual cues in ``range`` to ``path``; ``length`` is unused."""
        self._append_range(path, range, idx, lambda seg: seg.get_bilingual_str())