yichenl5 committed on
Commit
9ed44bb
·
2 Parent(s): 909a30f 4c3dc51

Merge pull request #37 from project-kxkg/eason/fix_term_check

Browse files

add log system; fix correct_with_force_term()

Former-commit-id: 617a92f17016c172dd138f6779be835c9f70e43d

Files changed (4) hide show
  1. .gitignore +2 -0
  2. SRT.py +28 -23
  3. finetune_data/dict_enzh.csv +3 -0
  4. pipeline.py +31 -7
.gitignore CHANGED
@@ -3,7 +3,9 @@ __pycache__/
3
  .idea/
4
  downloads/
5
  results/
 
6
  test/
 
7
  test.py
8
  test.srt
9
  test.txt
 
3
  .idea/
4
  downloads/
5
  results/
6
+ logs/
7
  test/
8
+ .vscode/
9
  test.py
10
  test.srt
11
  test.txt
SRT.py CHANGED
@@ -3,7 +3,7 @@ import re
3
  from copy import copy, deepcopy
4
  from csv import reader
5
  from datetime import timedelta
6
-
7
  import openai
8
 
9
 
@@ -72,7 +72,7 @@ class SRT_segment(object):
72
  :param other: Another segment that is strictly next to added segment.
73
  :return: new segment of the two sub-segments
74
  """
75
- # assert other.start_ms == self.end_ms, f"cannot merge discontinuous segments."
76
  result = deepcopy(self)
77
  result.source_text += f' {other.source_text}'
78
  result.translation += f' {other.translation}'
@@ -143,6 +143,7 @@ class SRT_script():
143
  improper segmentation from openai-whisper.
144
  :return: None
145
  """
 
146
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
147
  sentence = []
148
  for i, seg in enumerate(self.segments):
@@ -155,6 +156,8 @@ class SRT_script():
155
 
156
  segments = []
157
  for idx_list in merge_list:
 
 
158
  segments.append(self.merge_segs(idx_list))
159
 
160
  self.segments = segments
@@ -166,6 +169,7 @@ class SRT_script():
166
  """
167
  for i, seg in enumerate(self.segments):
168
  seg.remove_trans_punc()
 
169
 
170
  def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
171
  start_seg_id = id_range[0]
@@ -351,21 +355,24 @@ class SRT_script():
351
  return result_list
352
 
353
  def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
354
- # DEPRECATED
355
  # if translation length > text_threshold and segment duration > time_threshold, split this segment into multiple parts
 
356
  segments = []
357
- for seg in self.segments:
358
  if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
359
  seg_list = self.split_seg(seg, text_threshold, time_threshold)
 
360
  segments += seg_list
361
  else:
362
  segments.append(seg)
363
 
364
  self.segments = segments
 
365
 
366
  pass
367
 
368
  def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
 
369
  # if sentence length >= text_threshold, split this segments to two
370
  start_seg_id = range[0]
371
  end_seg_id = range[1]
@@ -384,28 +391,24 @@ class SRT_script():
384
 
385
  def correct_with_force_term(self):
386
  ## force term correction
387
-
388
  # load term dictionary
389
  with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
390
  term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
 
 
 
391
 
392
- # change term
393
- for seg in self.segments:
394
- ready_words = seg.source_text.split(" ")
395
- for i in range(len(ready_words)):
396
- word = ready_words[i]
397
- [real_word, pos] = self.get_real_word(word)
398
- if real_word in term_enzh_dict:
399
- new_word = word.replace(word[:pos], term_enzh_dict.get(real_word))
400
- else:
401
- new_word = word
402
- ready_words[i] = new_word
403
- seg.source_text = " ".join(ready_words)
404
- pass
405
 
406
  def spell_check_term(self):
407
  ## known bug: I've will be replaced because i've is not in the dict
408
-
409
  import enchant
410
  dict = enchant.Dict('en_US')
411
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
@@ -419,10 +422,10 @@ class SRT_script():
419
  suggest = term_spellDict.suggest(real_word)
420
  if suggest and enchant.utils.levenshtein(word, suggest[0]) < (len(word)+len(suggest[0]))/4: # relax spell check
421
 
422
- with open("dislog.log","a") as log:
423
- if not os.path.exists("dislog.log"):
424
- log.write("word \t suggest \t levenshtein \n")
425
- log.write(word + "\t" + suggest[0] + "\t" + str(enchant.utils.levenshtein(word, suggest[0]))+'\n')
426
  #print(word + ":" + suggest[0] + ":---:levenshtein:" + str(enchant.utils.levenshtein(word, suggest[0])))
427
  new_word = word.replace(word[:pos],suggest[0])
428
  else:
@@ -518,11 +521,13 @@ class SRT_script():
518
  pass
519
 
520
  def write_srt_file_translate(self, path: str):
 
521
  with open(path, "w", encoding='utf-8') as f:
522
  f.write(self.reform_trans_str())
523
  pass
524
 
525
  def write_srt_file_bilingual(self, path: str):
 
526
  with open(path, "w", encoding='utf-8') as f:
527
  f.write(self.form_bilingual_str())
528
  pass
 
3
  from copy import copy, deepcopy
4
  from csv import reader
5
  from datetime import timedelta
6
+ import logging
7
  import openai
8
 
9
 
 
72
  :param other: Another segment that is strictly next to added segment.
73
  :return: new segment of the two sub-segments
74
  """
75
+
76
  result = deepcopy(self)
77
  result.source_text += f' {other.source_text}'
78
  result.translation += f' {other.translation}'
 
143
  improper segmentation from openai-whisper.
144
  :return: None
145
  """
146
+ logging.info("Forming whole sentences...")
147
  merge_list = [] # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
148
  sentence = []
149
  for i, seg in enumerate(self.segments):
 
156
 
157
  segments = []
158
  for idx_list in merge_list:
159
+ if len(idx_list) > 1:
160
+ logging.info("merging segments: %s", idx_list)
161
  segments.append(self.merge_segs(idx_list))
162
 
163
  self.segments = segments
 
169
  """
170
  for i, seg in enumerate(self.segments):
171
  seg.remove_trans_punc()
172
+ logging.info("Removed punctuation in translation.")
173
 
174
  def set_translation(self, translate: str, id_range: tuple, model, video_name, video_link=None):
175
  start_seg_id = id_range[0]
 
355
  return result_list
356
 
357
  def check_len_and_split(self, text_threshold=30, time_threshold=1.0):
 
358
  # if translation length > text_threshold and segment duration > time_threshold, split this segment into multiple parts
359
+ logging.info("performing check_len_and_split")
360
  segments = []
361
+ for i, seg in enumerate(self.segments):
362
  if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
363
  seg_list = self.split_seg(seg, text_threshold, time_threshold)
364
+ logging.info("splitting segment {} in to {} parts".format(i+1, len(seg_list)))
365
  segments += seg_list
366
  else:
367
  segments.append(seg)
368
 
369
  self.segments = segments
370
+ logging.info("check_len_and_split finished")
371
 
372
  pass
373
 
374
  def check_len_and_split_range(self, range, text_threshold=30, time_threshold=1.0):
375
+ # DEPRECATED
376
  # if sentence length >= text_threshold, split this segments to two
377
  start_seg_id = range[0]
378
  end_seg_id = range[1]
 
391
 
392
  def correct_with_force_term(self):
393
  ## force term correction
394
+ logging.info("performing force term correction")
395
  # load term dictionary
396
  with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
397
  term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
398
+
399
+ keywords = list(term_enzh_dict.keys())
400
+ keywords.sort(key=lambda x: len(x), reverse=True)
401
 
402
+ for word in keywords:
403
+ for i, seg in enumerate(self.segments):
404
+ if word in seg.source_text.lower():
405
+ seg.source_text = seg.source_text.lower().replace(word, term_enzh_dict.get(word))
406
+ logging.info("replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(i+1))
407
+ logging.info("source text becomes: " + seg.source_text)
 
 
 
 
 
 
 
408
 
409
  def spell_check_term(self):
410
  ## known bug: I've will be replaced because i've is not in the dict
411
+ logging.info("performing spell check")
412
  import enchant
413
  dict = enchant.Dict('en_US')
414
  term_spellDict = enchant.PyPWL('./finetune_data/dict_freq.txt')
 
422
  suggest = term_spellDict.suggest(real_word)
423
  if suggest and enchant.utils.levenshtein(word, suggest[0]) < (len(word)+len(suggest[0]))/4: # relax spell check
424
 
425
+ # with open("dislog.log","a") as log:
426
+ # if not os.path.exists("dislog.log"):
427
+ # log.write("word \t suggest \t levenshtein \n")
428
+ logging.info(word + "\t" + suggest[0] + "\t" + str(enchant.utils.levenshtein(word, suggest[0]))+'\n')
429
  #print(word + ":" + suggest[0] + ":---:levenshtein:" + str(enchant.utils.levenshtein(word, suggest[0])))
430
  new_word = word.replace(word[:pos],suggest[0])
431
  else:
 
521
  pass
522
 
523
  def write_srt_file_translate(self, path: str):
524
+ logging.info("writing to " + path)
525
  with open(path, "w", encoding='utf-8') as f:
526
  f.write(self.reform_trans_str())
527
  pass
528
 
529
  def write_srt_file_bilingual(self, path: str):
530
+ logging.info("writing to " + path)
531
  with open(path, "w", encoding='utf-8') as f:
532
  f.write(self.form_bilingual_str())
533
  pass
finetune_data/dict_enzh.csv CHANGED
@@ -1,4 +1,7 @@
1
  barracks,兵营
 
 
 
2
  engineering bay,工程站
3
  forge,锻炉
4
  blink,闪现
 
1
  barracks,兵营
2
+ zerg,虫族
3
+ protoss,神族
4
+ terran,人族
5
  engineering bay,工程站
6
  forge,锻炉
7
  blink,闪现
pipeline.py CHANGED
@@ -7,6 +7,8 @@ from SRT import SRT_script
7
  import stable_whisper
8
  import whisper
9
  from srt2ass import srt2ass
 
 
10
 
11
  import subprocess
12
 
@@ -22,6 +24,7 @@ def parse_args():
22
  parser.add_argument("--output_dir", help="translate result path", default='./results', type=str, required=False)
23
  parser.add_argument("--video_name", help="video name, if use video link as input, the name will auto-filled by youtube video name", default='placeholder', type=str, required=False)
24
  parser.add_argument("--model_name", help="model name only support gpt-4 and gpt-3.5-turbo", type=str, required=False, default="gpt-4") # default change to gpt-4
 
25
  parser.add_argument("-only_srt", help="set script output to only .srt file", action='store_true')
26
  parser.add_argument("-v", help="auto encode script with video", action='store_true')
27
  args = parser.parse_args()
@@ -30,6 +33,10 @@ def parse_args():
30
 
31
  def get_sources(args, download_path, result_path, video_name):
32
  # get source audio
 
 
 
 
33
  if args.link is not None and args.video_file is None:
34
  # Download audio from YouTube
35
  video_link = args.link
@@ -198,6 +205,7 @@ def get_response(model_name, sentence):
198
 
199
  # Translate and save
200
  def translate(srt, script_arr, range_arr, model_name, video_name, video_link):
 
201
  previous_length = 0
202
  for sentence, range in tqdm(zip(script_arr, range_arr)):
203
  # update the range based on previous length
@@ -205,12 +213,14 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link):
205
 
206
  # using chatgpt model
207
  print(f"now translating sentences {range}")
 
208
  flag = True
209
  while flag:
210
  flag = False
211
  try:
212
  translate = get_response(model_name, sentence)
213
  except Exception as e:
 
214
  print("An error has occurred during translation:",e)
215
  print("Retrying... the script will continue after 30 seconds.")
216
  time.sleep(30)
@@ -227,6 +237,7 @@ def main():
227
  exit()
228
 
229
  # set up
 
230
  openai.api_key = os.getenv("OPENAI_API_KEY")
231
  DOWNLOAD_PATH = args.download
232
  if not os.path.exists(DOWNLOAD_PATH):
@@ -237,7 +248,7 @@ def main():
237
  RESULT_PATH = args.output_dir
238
  if not os.path.exists(RESULT_PATH):
239
  os.mkdir(RESULT_PATH)
240
-
241
  # set video name as the input file name if not specified
242
  if args.video_name == 'placeholder' :
243
  # set video name to upload file name
@@ -254,25 +265,34 @@ def main():
254
 
255
  audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)
256
 
 
 
 
 
257
  srt_file_en, srt = get_srt_class(args.srt_file, RESULT_PATH, VIDEO_NAME, audio_path, audio_file)
258
 
259
  # SRT class preprocess
 
 
260
  srt.form_whole_sentence()
261
- srt.spell_check_term()
262
  srt.correct_with_force_term()
263
- srt.write_srt_file_src(srt_file_en)
 
264
  script_input = srt.get_source_only()
265
 
266
  # write ass
267
  if not args.only_srt:
268
- assSub_en = srt2ass(srt_file_en, "default", "No", "Modest")
269
- print('ASS subtitle saved as: ' + assSub_en)
 
270
 
271
  script_arr, range_arr = script_split(script_input)
272
-
273
  translate(srt, script_arr, range_arr, args.model_name, VIDEO_NAME, args.link)
274
 
275
  # SRT post-processing
 
276
  srt.check_len_and_split()
277
  srt.remove_trans_punctuation()
278
  srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
@@ -280,16 +300,20 @@ def main():
280
 
281
  # write ass
282
  if not args.only_srt:
 
283
  assSub_zh = srt2ass(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt", "default", "No", "Modest")
284
- print('ASS subtitle saved as: ' + assSub_zh)
285
 
286
  # encode to .mp4 video file
287
  if args.v:
 
288
  if args.only_srt:
289
  os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
290
  else:
291
  os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
292
 
 
 
293
 
294
  if __name__ == "__main__":
295
  main()
 
7
  import stable_whisper
8
  import whisper
9
  from srt2ass import srt2ass
10
+ import logging
11
+ from datetime import datetime
12
 
13
  import subprocess
14
 
 
24
  parser.add_argument("--output_dir", help="translate result path", default='./results', type=str, required=False)
25
  parser.add_argument("--video_name", help="video name, if use video link as input, the name will auto-filled by youtube video name", default='placeholder', type=str, required=False)
26
  parser.add_argument("--model_name", help="model name only support gpt-4 and gpt-3.5-turbo", type=str, required=False, default="gpt-4") # default change to gpt-4
27
+ parser.add_argument("--log_dir", help="log path", default='./logs', type=str, required=False)
28
  parser.add_argument("-only_srt", help="set script output to only .srt file", action='store_true')
29
  parser.add_argument("-v", help="auto encode script with video", action='store_true')
30
  args = parser.parse_args()
 
33
 
34
  def get_sources(args, download_path, result_path, video_name):
35
  # get source audio
36
+ audio_path = None
37
+ audio_file = None
38
+ video_path = None
39
+
40
  if args.link is not None and args.video_file is None:
41
  # Download audio from YouTube
42
  video_link = args.link
 
205
 
206
  # Translate and save
207
  def translate(srt, script_arr, range_arr, model_name, video_name, video_link):
208
+ logging.info("start translating...")
209
  previous_length = 0
210
  for sentence, range in tqdm(zip(script_arr, range_arr)):
211
  # update the range based on previous length
 
213
 
214
  # using chatgpt model
215
  print(f"now translating sentences {range}")
216
+ logging.info(f"now translating sentences {range}, time: {datetime.now()}")
217
  flag = True
218
  while flag:
219
  flag = False
220
  try:
221
  translate = get_response(model_name, sentence)
222
  except Exception as e:
223
+ logging.debug("An error has occurred during translation:",e)
224
  print("An error has occurred during translation:",e)
225
  print("Retrying... the script will continue after 30 seconds.")
226
  time.sleep(30)
 
237
  exit()
238
 
239
  # set up
240
+ start_time = time.time()
241
  openai.api_key = os.getenv("OPENAI_API_KEY")
242
  DOWNLOAD_PATH = args.download
243
  if not os.path.exists(DOWNLOAD_PATH):
 
248
  RESULT_PATH = args.output_dir
249
  if not os.path.exists(RESULT_PATH):
250
  os.mkdir(RESULT_PATH)
251
+
252
  # set video name as the input file name if not specified
253
  if args.video_name == 'placeholder' :
254
  # set video name to upload file name
 
265
 
266
  audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)
267
 
268
+ logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")))], encoding='utf-8')
269
+ logging.info("---------------------Video Info---------------------")
270
+ logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))
271
+
272
  srt_file_en, srt = get_srt_class(args.srt_file, RESULT_PATH, VIDEO_NAME, audio_path, audio_file)
273
 
274
  # SRT class preprocess
275
+ logging.info("---------------------Start Preprocessing SRT class---------------------")
276
+ srt.write_srt_file_src(srt_file_en)
277
  srt.form_whole_sentence()
278
+ # srt.spell_check_term()
279
  srt.correct_with_force_term()
280
+ processed_srt_file_en = srt_file_en.split('.srt')[0] + '_processed.srt'
281
+ srt.write_srt_file_src(processed_srt_file_en)
282
  script_input = srt.get_source_only()
283
 
284
  # write ass
285
  if not args.only_srt:
286
+ logging.info("write English .srt file to .ass")
287
+ assSub_en = srt2ass(processed_srt_file_en, "default", "No", "Modest")
288
+ logging.info('ASS subtitle saved as: ' + assSub_en)
289
 
290
  script_arr, range_arr = script_split(script_input)
291
+ logging.info("---------------------Start Translation--------------------")
292
  translate(srt, script_arr, range_arr, args.model_name, VIDEO_NAME, args.link)
293
 
294
  # SRT post-processing
295
+ logging.info("---------------------Start Post-processing SRT class---------------------")
296
  srt.check_len_and_split()
297
  srt.remove_trans_punctuation()
298
  srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
 
300
 
301
  # write ass
302
  if not args.only_srt:
303
+ logging.info("write Chinese .srt file to .ass")
304
  assSub_zh = srt2ass(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt", "default", "No", "Modest")
305
+ logging.info('ASS subtitle saved as: ' + assSub_zh)
306
 
307
  # encode to .mp4 video file
308
  if args.v:
309
+ logging.info("encoding video file")
310
  if args.only_srt:
311
  os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
312
  else:
313
  os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
314
 
315
+ end_time = time.time()
316
+ logging.info("Pipeline finished, time duration:{}".format(start_time - end_time))
317
 
318
  if __name__ == "__main__":
319
  main()