CanYing0913 committed
Commit 2d29c14
Parents: 85c07a7 7d74f8e

Merge branch 'SRT_cleanup' into eason/main


Former-commit-id: 3cedf7bb4e826122d3227968510ee9811a86bcb5

doc/Installation.md ADDED
@@ -0,0 +1,7 @@
+### **Recommended:**
+We recommend configuring your environment with [mamba](https://pypi.org/project/mamba/). The following packages are required:
+```
+openai
+openai-whisper
+
+```
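For context, a typical way to satisfy this on a fresh machine (the environment name is illustrative, and if either package is not available through conda channels, plain `pip` works inside the activated environment) would be `mamba create -n <env> python`, then `mamba activate <env>`, then `pip install openai openai-whisper`.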
doc/struct.md ADDED
@@ -0,0 +1,7 @@
+# Structure of Repository
+```
+├── doc              # Documentation for the repository.
+├────── struct.md    # This document: describes the repository structure.
+├── finetune_data    # Term dictionaries (e.g. dict_enzh.csv, dict_freq.txt) used by the pipeline.
+└── README.md
+```
pipeline.py CHANGED
@@ -3,10 +3,10 @@ from pytube import YouTube
 import argparse
 import os
 from tqdm import tqdm
-from SRT import SRT_script
+from srt_util.srt import SrtScript
 import stable_whisper
 import whisper
-from srt2ass import srt2ass
+from srt_util.srt2ass import srt2ass
 import logging
 from datetime import datetime
 import torch
@@ -15,23 +15,29 @@ import subprocess
 
 import time
 
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
     parser.add_argument("--video_file", help="local video path here", default=None, type=str, required=False)
     parser.add_argument("--audio_file", help="local audio path here", default=None, type=str, required=False)
-    parser.add_argument("--srt_file", help="srt file input path here", default=None, type=str, required=False) # New argument
+    parser.add_argument("--srt_file", help="srt file input path here", default=None, type=str,
+                        required=False)  # New argument
     parser.add_argument("--download", help="download path", default='./downloads', type=str, required=False)
     parser.add_argument("--output_dir", help="translate result path", default='./results', type=str, required=False)
-    parser.add_argument("--video_name", help="video name, if use video link as input, the name will auto-filled by youtube video name", default='placeholder', type=str, required=False)
-    parser.add_argument("--model_name", help="model name only support gpt-4 and gpt-3.5-turbo", type=str, required=False, default="gpt-4") # default change to gpt-4
+    parser.add_argument("--video_name",
+                        help="video name, if use video link as input, the name will auto-filled by youtube video name",
+                        default='placeholder', type=str, required=False)
+    parser.add_argument("--model_name", help="model name only support gpt-4 and gpt-3.5-turbo", type=str,
+                        required=False, default="gpt-4")  # default change to gpt-4
     parser.add_argument("--log_dir", help="log path", default='./logs', type=str, required=False)
     parser.add_argument("-only_srt", help="set script output to only .srt file", action='store_true')
     parser.add_argument("-v", help="auto encode script with video", action='store_true')
     args = parser.parse_args()
-
+
     return args
 
+
 def get_sources(args, download_path, result_path, video_name):
     # get source audio
     audio_path = None
@@ -59,9 +65,9 @@ def get_sources(args, download_path, result_path, video_name):
                 print("Error: Audio stream not found")
         except Exception as e:
             print("Connection Error")
-            print(e)
+            print(e)
             exit()
-
+
         video_path = f'{download_path}/video/{video.default_filename}'
         audio_path = '{}/audio/{}'.format(download_path, audio.default_filename)
         audio_file = open(audio_path, "rb")
@@ -72,7 +78,7 @@ def get_sources(args, download_path, result_path, video_name):
         video_path = args.video_file
 
         if args.audio_file is not None:
-            audio_file= open(args.audio_file, "rb")
+            audio_file = open(args.audio_file, "rb")
             audio_path = args.audio_file
         else:
             output_audio_path = f'{download_path}/audio/{video_name}.mp3'
@@ -84,37 +90,41 @@ def get_sources(args, download_path, result_path, video_name):
        os.mkdir(f'{result_path}/{video_name}')
 
    if args.audio_file is not None:
-        audio_file= open(args.audio_file, "rb")
+        audio_file = open(args.audio_file, "rb")
        audio_path = args.audio_file
        pass
 
    return audio_path, audio_file, video_path, video_name
 
-def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file = None, whisper_model = 'large', method = "stable"):
+
+def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file=None, whisper_model='large',
+                  method="stable"):
     # Instead of using the script_en variable directly, we'll use script_input
-    if srt_file_en is not None:
-        srt = SRT_script.parse_from_srt_file(srt_file_en)
+    if srt_file_en is not None:
+        srt = SrtScript.parse_from_srt_file(srt_file_en)
     else:
         # using whisper to perform speech-to-text and save it in <video name>_en.txt under RESULT PATH.
         srt_file_en = "{}/{}/{}_en.srt".format(result_path, video_name, video_name)
         if not os.path.exists(srt_file_en):
-
-            devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+            devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # use OpenAI API for transcribe
             if method == "api":
-                transcript = openai.Audio.transcribe("whisper-1", audio_file)
+                transcript = openai.Audio.transcribe("whisper-1", audio_file)
 
-            # use local whisper model
+            # use local whisper model
             elif method == "basic":
-                model = whisper.load_model(whisper_model, device = devices) # using base model in local machine (may use large model on our server)
+                model = whisper.load_model(whisper_model,
+                                           device=devices)  # using base model in local machine (may use large model on our server)
                 transcript = model.transcribe(audio_path)
 
            # use stable-whisper
            elif method == "stable":
 
                # use cuda if available
-                model = stable_whisper.load_model(whisper_model, device = devices)
-                transcript = model.transcribe(audio_path, regroup = False, initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
+                model = stable_whisper.load_model(whisper_model, device=devices)
+                transcript = model.transcribe(audio_path, regroup=False,
+                                              initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
                (
                    transcript
                    .split_by_punctuation(['.', '。', '?'])
@@ -126,14 +136,15 @@ def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file =
             else:
                 raise ValueError("invalid speech to text method")
 
-            srt = SRT_script(transcript['segments']) # read segments to SRT class
+            srt = SrtScript(transcript['segments'])  # read segments to SRT class
 
         else:
-            srt = SRT_script.parse_from_srt_file(srt_file_en)
+            srt = SrtScript.parse_from_srt_file(srt_file_en)
     return srt_file_en, srt
 
+
 # Split the video script by sentences and create chunks within the token limit
-def script_split(script_in, chunk_size = 1000):
+def script_split(script_in, chunk_size=1000):
     script_split = script_in.split('\n\n')
     script_arr = []
     range_arr = []
@@ -143,20 +154,21 @@ def script_split(script_in, chunk_size = 1000):
     for sentence in script_split:
         if len(script) + len(sentence) + 1 <= chunk_size:
             script += sentence + '\n\n'
-            end+=1
+            end += 1
         else:
             range_arr.append((start, end))
-            start = end+1
+            start = end + 1
             end += 1
             script_arr.append(script.strip())
             script = sentence + '\n\n'
     if script.strip():
         script_arr.append(script.strip())
-        range_arr.append((start, len(script_split)-1))
+        range_arr.append((start, len(script_split) - 1))
 
     assert len(script_arr) == len(range_arr)
     return script_arr, range_arr
 
+
 def check_translation(sentence, translation):
     """
     check merge sentence issue from openai translation
@@ -187,24 +199,25 @@ def get_response(model_name, sentence):
     if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
         response = openai.ChatCompletion.create(
             model=model_name,
-            messages = [
-                #{"role": "system", "content": "You are a helpful assistant that translates English to Chinese and have decent background in starcraft2."},
-                #{"role": "system", "content": "Your translation has to keep the orginal format and be as accurate as possible."},
-                #{"role": "system", "content": "Your translation needs to be consistent with the number of sentences in the original."},
-                #{"role": "system", "content": "There is no need for you to add any comments or notes."},
-                #{"role": "user", "content": 'Translate the following English text to Chinese: "{}"'.format(sentence)}
-
-                {"role": "system", "content": "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"},
+            messages=[
+                # {"role": "system", "content": "You are a helpful assistant that translates English to Chinese and have decent background in starcraft2."},
+                # {"role": "system", "content": "Your translation has to keep the orginal format and be as accurate as possible."},
+                # {"role": "system", "content": "Your translation needs to be consistent with the number of sentences in the original."},
+                # {"role": "system", "content": "There is no need for you to add any comments or notes."},
+                # {"role": "user", "content": 'Translate the following English text to Chinese: "{}"'.format(sentence)}
+
+                {"role": "system",
+                 "content": "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"},
                 {"role": "user", "content": sentence}
             ],
             temperature=0.15
         )
 
     return response['choices'][0]['message']['content'].strip()
-
-
+
+
 # Translate and save
-def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count = 5):
+def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count=5):
     """
     Translates the given script array into another language using the chatgpt and writes to the SRT file.
 
@@ -226,7 +239,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
     previous_length = 0
     for sentence, range in tqdm(zip(script_arr, range_arr)):
         # update the range based on previous length
-        range = (range[0]+previous_length, range[1]+previous_length)
+        range = (range[0] + previous_length, range[1] + previous_length)
 
         # using chatgpt model
         print(f"now translating sentences {range}")
@@ -240,7 +253,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
             while not check_translation(sentence, translate) and attempts_count > 0:
                 translate = get_response(model_name, sentence)
                 attempts_count -= 1
-
+
             # if failure still happen, split into smaller tokens
             if attempts_count == 0:
                 single_sentences = sentence.split("\n\n")
@@ -252,11 +265,11 @@ def translate(srt, script_arr, range_arr, model_name, video_name, video_link, at
                     else:
                         translate += get_response(model_name, single_sentence) + "\n\n"
                     # print(single_sentence, translate.split("\n\n")[-2])
-                logging.info("solved by individually translation!")
+                logging.info("solved by individually translation!")
 
         except Exception as e:
-            logging.debug("An error has occurred during translation:",e)
-            print("An error has occurred during translation:",e)
+            logging.debug("An error has occurred during translation:", e)
+            print("An error has occurred during translation:", e)
             print("Retrying... the script will continue after 30 seconds.")
             time.sleep(30)
             flag = True
@@ -284,9 +297,9 @@ def main():
     RESULT_PATH = args.output_dir
     if not os.path.exists(RESULT_PATH):
         os.mkdir(RESULT_PATH)
-
+
     # set video name as the input file name if not specified
-    if args.video_name == 'placeholder' :
+    if args.video_name == 'placeholder':
         # set video name to upload file name
         if args.video_file is not None:
             VIDEO_NAME = args.video_file.split('/')[-1].split('.')[0]
@@ -303,7 +316,9 @@ def main():
 
     if not os.path.exists(args.log_dir):
         os.makedirs(args.log_dir)
-    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")), 'w', encoding='utf-8')])
+    logging.basicConfig(level=logging.INFO, handlers=[
+        logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")),
+                            'w', encoding='utf-8')])
     logging.info("---------------------Video Info---------------------")
     logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))
 
@@ -346,12 +361,16 @@ def main():
     if args.v:
         logging.info("encoding video file")
         if args.only_srt:
-            os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
+            os.system(
+                f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
         else:
-            os.system(f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
+            os.system(
+                f'ffmpeg -i {video_path} -vf "subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.ass" {RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4')
 
     end_time = time.time()
-    logging.info("Pipeline finished, time duration:{}".format(time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))))
+    logging.info(
+        "Pipeline finished, time duration:{}".format(time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))))
+
 
 if __name__ == "__main__":
-    main()
+    main()
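For readers tracing the control flow of `translate()` above: each chunk is retried whole a few times, then re-translated sentence by sentence if the check keeps failing. A minimal, self-contained sketch of that strategy follows; `translate_chunk` is a hypothetical helper, and `get_response`/`check_translation` are stubs standing in for the OpenAI call shown above and for the merged-sentence check (whose body is not shown in this diff).

```python
import logging

def get_response(model_name, text):
    # Stub for the openai.ChatCompletion.create call in get_response above:
    # pretend to "translate" by upper-casing each blank-line-separated sentence.
    return "\n\n".join(s.upper() for s in text.split("\n\n"))

def check_translation(sentence, translation):
    # Plausible stand-in for the pipeline's check: the translation must keep
    # the same number of blank-line-separated sentences as the source.
    return len(sentence.split("\n\n")) == len(translation.split("\n\n"))

def translate_chunk(chunk, model_name="gpt-4", attempts_count=5):
    # Retry the whole chunk while the check fails; on exhaustion, fall back
    # to translating each sentence individually, as translate() does.
    result = get_response(model_name, chunk)
    while not check_translation(chunk, result) and attempts_count > 0:
        result = get_response(model_name, chunk)
        attempts_count -= 1
    if attempts_count == 0:
        result = "\n\n".join(get_response(model_name, s) for s in chunk.split("\n\n"))
        logging.info("solved by individually translation!")
    return result

print(translate_chunk("first sentence\n\nsecond sentence"))
```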
srt_util/__init__.py ADDED
File without changes
SRT.py → srt_util/srt.py RENAMED
@@ -8,7 +8,7 @@ import openai
 from tqdm import tqdm
 
 
-class SRT_segment(object):
+class SrtSegment(object):
     def __init__(self, *args) -> None:
         if isinstance(args[0], dict):
             segment = args[0]
@@ -64,28 +64,23 @@ class SRT_segment(object):
         self.end = seg.end
         self.end_ms = seg.end_ms
         self.duration = f"{self.start_time_str} --> {self.end_time_str}"
-        pass
 
     def __add__(self, other):
         """
         Merge the segment seg with the current segment, and return the new constructed segment.
         No in-place modification.
+        This is used for '+' operator.
         :param other: Another segment that is strictly next to added segment.
         :return: new segment of the two sub-segments
         """
 
         result = deepcopy(self)
-        result.source_text += f' {other.source_text}'
-        result.translation += f' {other.translation}'
-        result.end_time_str = other.end_time_str
-        result.end = other.end
-        result.end_ms = other.end_ms
-        result.duration = f"{self.start_time_str} --> {result.end_time_str}"
+        result.merge_seg(other)
         return result
 
-    def remove_trans_punc(self):
+    def remove_trans_punc(self) -> None:
         """
-        remove punctuations in translation text
+        remove CN punctuations in translation text
         :return: None
         """
         punc_cn = ",。!?"
@@ -102,12 +97,9 @@ class SRT_segment(object):
         return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
 
 
-class SRT_script():
+class SrtScript(object):
     def __init__(self, segments) -> None:
-        self.segments = []
-        for seg in segments:
-            srt_seg = SRT_segment(seg)
-            self.segments.append(srt_seg)
+        self.segments = [SrtSegment(seg) for seg in segments]
 
     @classmethod
     def parse_from_srt_file(cls, path: str):
@@ -115,13 +107,12 @@ class SRT_script():
             script_lines = [line.rstrip() for line in f.readlines()]
 
         segments = []
-        for i in range(len(script_lines)-4):
-            if i % 4 == 0:
-                segments.append(list(script_lines[i:i + 4]))
+        for i in range(0, len(script_lines), 4):
+            segments.append(list(script_lines[i:i + 4]))
 
         return cls(segments)
 
-    def merge_segs(self, idx_list) -> SRT_segment:
+    def merge_segs(self, idx_list) -> SrtSegment:
         """
         Merge entire segment list to a single segment
         :param idx_list: List of index to merge
@@ -147,6 +138,7 @@ class SRT_script():
         logging.info("Forming whole sentences...")
         merge_list = []  # a list of indices that should be merged e.g. [[0], [1, 2, 3, 4], [5, 6], [7]]
         sentence = []
+        # Get each entire sentence of distinct segments, fill indices to merge_list
         for i, seg in enumerate(self.segments):
             if seg.source_text[-1] in ['.', '!', '?'] and len(seg.source_text) > 10 and 'vs.' not in seg.source_text:
                 sentence.append(i)
@@ -155,6 +147,7 @@ class SRT_script():
             else:
                 sentence.append(i)
 
+        # Reconstruct segments, each with an entire sentence
         segments = []
         for idx_list in merge_list:
             if len(idx_list) > 1:
@@ -254,11 +247,10 @@ class SRT_script():
                 max_num -= 1
                 if i == len(lines) - 1:
                     break
-            if lines[i][0] in [' ', '\n']:
+            if lines[i][0] in [' ', '\n']:
                 lines[i] = lines[i][1:]
             seg.translation = lines[i]
 
-
     def split_seg(self, seg, text_threshold, time_threshold):
         # evenly split seg to 2 parts and add new seg into self.segments
 
@@ -314,14 +306,14 @@ class SRT_script():
         seg1_dict['text'] = src_seg1
         seg1_dict['start'] = start_seg1
         seg1_dict['end'] = end_seg1
-        seg1 = SRT_segment(seg1_dict)
+        seg1 = SrtSegment(seg1_dict)
         seg1.translation = trans_seg1
 
         seg2_dict = {}
         seg2_dict['text'] = src_seg2
         seg2_dict['start'] = start_seg2
         seg2_dict['end'] = end_seg2
-        seg2 = SRT_segment(seg2_dict)
+        seg2 = SrtSegment(seg2_dict)
         seg2.translation = trans_seg2
 
         result_list = []
@@ -344,7 +336,7 @@ class SRT_script():
         for i, seg in enumerate(self.segments):
             if len(seg.translation) > text_threshold and (seg.end - seg.start) > time_threshold:
                 seg_list = self.split_seg(seg, text_threshold, time_threshold)
-                logging.info("splitting segment {} in to {} parts".format(i+1, len(seg_list)))
+                logging.info("splitting segment {} in to {} parts".format(i + 1, len(seg_list)))
                 segments += seg_list
             else:
                 segments.append(seg)
@@ -376,39 +368,41 @@ class SRT_script():
         ## force term correction
         logging.info("performing force term correction")
         # load term dictionary
-        with open("./finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
+        with open("../finetune_data/dict_enzh.csv", 'r', encoding='utf-8') as f:
             term_enzh_dict = {rows[0]: rows[1] for rows in reader(f)}
-
+
         keywords = list(term_enzh_dict.keys())
         keywords.sort(key=lambda x: len(x), reverse=True)
 
         for word in keywords:
             for i, seg in enumerate(self.segments):
                 if word in seg.source_text.lower():
-                    seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)), seg.source_text, flags=re.IGNORECASE)
-                    logging.info("replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(i+1))
+                    seg.source_text = re.sub(fr"({word}es|{word}s?)\b", "{}".format(term_enzh_dict.get(word)),
+                                             seg.source_text, flags=re.IGNORECASE)
+                    logging.info(
+                        "replace term: " + word + " --> " + term_enzh_dict.get(word) + " in time stamp {}".format(
+                            i + 1))
                     logging.info("source text becomes: " + seg.source_text)
-
-
+
     comp_dict = []
-
-    def fetchfunc(self,word,threshold):
+
+    def fetchfunc(self, word, threshold):
         import enchant
         result = word
         distance = 0
-        threshold = threshold*len(word)
-        if len(self.comp_dict)==0:
+        threshold = threshold * len(word)
+        if len(self.comp_dict) == 0:
             with open("./finetune_data/dict_freq.txt", 'r', encoding='utf-8') as f:
-                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
+                self.comp_dict = {rows[0]: 1 for rows in reader(f)}
         temp = ""
         for matched in self.comp_dict:
             if (" " in matched and " " in word) or (" " not in matched and " " not in word):
-                if enchant.utils.levenshtein(word, matched)<enchant.utils.levenshtein(word, temp):
+                if enchant.utils.levenshtein(word, matched) < enchant.utils.levenshtein(word, temp):
                     temp = matched
         if enchant.utils.levenshtein(word, temp) < threshold:
             distance = enchant.utils.levenshtein(word, temp)
             result = temp
-        return distance, result
+        return distance, result
 
     def extract_words(self, sentence, n):
         # this function split the sentence to chunks by n of words
@@ -417,9 +411,9 @@ class SRT_script():
         words = sentence.split()
         res = []
         for j in range(n, 0, -1):
-            res += [words[i:i+j] for i in range(len(words)-j+1)]
-        return res
-
+            res += [words[i:i + j] for i in range(len(words) - j + 1)]
+        return res
+
     def spell_check_term(self):
         logging.info("performing spell check")
         import enchant
@@ -435,14 +429,14 @@ class SRT_script():
                     distance, correct_term = self.fetchfunc(real_word, 0.3)
                     if distance != 0:
                         seg.source_text = re.sub(word[:pos], correct_term, seg.source_text, flags=re.IGNORECASE)
-                        logging.info("replace: " + word[:pos] + " to " + correct_term + "\t distance = " + str(distance))
+                        logging.info(
+                            "replace: " + word[:pos] + " to " + correct_term + "\t distance = " + str(distance))
 
-
-    def get_real_word(self, word_list:list):
+    def get_real_word(self, word_list: list):
         word = ""
         for w in word_list:
             word += f"{w} "
-        word = word[:-1] # "this, is"
+        word = word[:-1]  # "this, is"
         if word[-2:] == ".\n":
            real_word = word[:-2].lower()
            n = -2
@@ -460,8 +454,8 @@ class SRT_script():
         # return a string with pure source text
         result = ""
         for i, seg in enumerate(self.segments):
-            result+=f'{seg.source_text}\n\n\n'#f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
-
+            result += f'{seg.source_text}\n\n\n'  # f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
+
         return result
 
     def reform_src_str(self):
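One behavioural fix in this rename is easy to miss: the old parser looped `for i in range(len(script_lines)-4)` with an `i % 4 == 0` guard, which never reaches the final subtitle block (and iterates redundantly), while the new `range(0, len(script_lines), 4)` steps cleanly over the four-line SRT blocks. A self-contained illustration, with made-up subtitle content:

```python
# SRT files repeat four-line blocks: index, timestamp, text, blank line.
sample = """1
00:00:00,000 --> 00:00:02,000
Hello there.

2
00:00:02,000 --> 00:00:04,500
Welcome to the lecture."""

script_lines = [line.rstrip() for line in sample.splitlines()]

# Old grouping: range(len(script_lines) - 4) never reaches i == 4,
# so the second block is silently dropped.
old = [script_lines[i:i + 4] for i in range(len(script_lines) - 4) if i % 4 == 0]

# New grouping: step by four and keep every block.
new = [script_lines[i:i + 4] for i in range(0, len(script_lines), 4)]

print(len(old), len(new))  # 1 2
```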
srt2ass.py → srt_util/srt2ass.py RENAMED
File without changes