Eason Lu commited on
Commit
6113bd9
·
1 Parent(s): cf5f1c9

merge segments

Browse files

Former-commit-id: 3b73651a94d5dac62b0c7577f59b3d59509839f9

Files changed (2) hide show
  1. SRT.py +60 -15
  2. pipeline.py +13 -7
SRT.py CHANGED
@@ -8,26 +8,31 @@ class SRT_segment(object):
8
  segment = args[0]
9
  self.start_time_str = str(0)+str(timedelta(seconds=int(segment['start'])))+',000'
10
  self.end_time_str = str(0)+str(timedelta(seconds=int(segment['end'])))+',000'
11
- self.segment_id = segment['id']+1
12
  self.source_text = segment['text']
13
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
14
  self.translation = ""
15
  elif isinstance(args[0], list):
16
- self.segment_id = args[0][0]
17
  self.source_text = args[0][2]
18
  self.duration = args[0][1]
19
- self.start_time_str = self.duration.split("-->")[0]
20
- self.end_time_str = self.duration.split("-->")[1]
21
  self.translation = ""
22
 
 
 
 
 
 
 
 
23
  def __str__(self) -> str:
24
- return f'{self.segment_id}\n{self.duration}\n{self.source_text}\n\n'
25
 
26
  def get_trans_str(self) -> str:
27
- return f'{self.segment_id}\n{self.duration}\n{self.translation}\n\n'
28
 
29
  def get_bilingual_str(self) -> str:
30
- return f'{self.segment_id}\n{self.duration}\n{self.source_text}\n{self.translation}\n\n'
31
 
32
  class SRT_script():
33
  def __init__(self, segments) -> None:
@@ -48,42 +53,82 @@ class SRT_script():
48
 
49
  return cls(segments)
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  def set_translation(self, translate:str, id_range:tuple):
52
  start_seg_id = id_range[0]
53
  end_seg_id = id_range[1]
54
 
55
  lines = translate.split('\n\n')
56
- print(id_range)
57
- print(translate)
58
- # print(len(translate))
 
 
59
 
60
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
61
  seg.translation = lines[i]
62
  pass
 
 
 
 
 
 
 
 
63
 
64
  def get_source_only(self):
65
  # return a string with pure source text
66
  result = ""
67
- for seg in self.segments:
68
- result+=f'{seg.source_text}\n\n'
69
 
70
  return result
71
 
72
  def reform_src_str(self):
73
  result = ""
74
- for seg in self.segments:
 
75
  result += str(seg)
76
  return result
77
 
78
  def reform_trans_str(self):
79
  result = ""
80
- for seg in self.segments:
 
81
  result += seg.get_trans_str()
82
  return result
83
 
84
  def form_bilingual_str(self):
85
  result = ""
86
- for seg in self.segments:
 
87
  result += seg.get_bilingual_str()
88
  return result
89
 
 
8
  segment = args[0]
9
  self.start_time_str = str(0)+str(timedelta(seconds=int(segment['start'])))+',000'
10
  self.end_time_str = str(0)+str(timedelta(seconds=int(segment['end'])))+',000'
 
11
  self.source_text = segment['text']
12
  self.duration = f"{self.start_time_str} --> {self.end_time_str}"
13
  self.translation = ""
14
  elif isinstance(args[0], list):
 
15
  self.source_text = args[0][2]
16
  self.duration = args[0][1]
17
+ self.start_time_str = self.duration.split(" --> ")[0]
18
+ self.end_time_str = self.duration.split(" --> ")[1]
19
  self.translation = ""
20
 
21
+ def merge_seg(self, seg):
22
+ self.source_text += seg.source_text
23
+ self.translation += seg.translation
24
+ self.end_time_str = seg.end_time_str
25
+ self.duration = f"{self.start_time_str} --> {self.end_time_str}"
26
+ pass
27
+
28
  def __str__(self) -> str:
29
+ return f'{self.duration}\n{self.source_text}\n\n'
30
 
31
  def get_trans_str(self) -> str:
32
+ return f'{self.duration}\n{self.translation}\n\n'
33
 
34
  def get_bilingual_str(self) -> str:
35
+ return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'
36
 
37
  class SRT_script():
38
  def __init__(self, segments) -> None:
 
53
 
54
  return cls(segments)
55
 
56
+ def merge_segs(self, idx_list) -> SRT_segment:
57
+ final_seg = self.segments[idx_list[0]]
58
+ if len(idx_list) == 1:
59
+ return final_seg
60
+
61
+ for idx in range(1, len(idx_list)):
62
+ final_seg.merge_seg(self.segments[idx_list[idx]])
63
+
64
+ return final_seg
65
+
66
+ def form_whole_sentence(self):
67
+ merge_list = [] # a list of indices that should be merged e.g. [[0], [2, 3, 4], [5, 6], [7]]
68
+ sentence = []
69
+ for i, seg in enumerate(self.segments):
70
+ if seg.source_text[-1] == '.':
71
+ sentence.append(i)
72
+ merge_list.append(sentence)
73
+ sentence = []
74
+ else:
75
+ sentence.append(i)
76
+
77
+ segments = []
78
+ for idx_list in merge_list:
79
+ segments.append(self.merge_segs(idx_list))
80
+
81
+ self.segments = segments # need memory release?
82
+
83
  def set_translation(self, translate:str, id_range:tuple):
84
  start_seg_id = id_range[0]
85
  end_seg_id = id_range[1]
86
 
87
  lines = translate.split('\n\n')
88
+ if len(lines) != (end_seg_id - start_seg_id + 1):
89
+ print(id_range)
90
+ for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
91
+ print(seg.source_text)
92
+ print(translate)
93
 
94
  for i, seg in enumerate(self.segments[start_seg_id-1:end_seg_id]):
95
  seg.translation = lines[i]
96
  pass
97
+
98
+ def split_seg(self, seg_id):
99
+ # TODO: evenly split seg to 2 parts and add new seg into self.segments
100
+ pass
101
+
102
+ def check_len_and_split(self, threshold):
103
+ # TODO: if sentence length >= threshold, split this segments to two
104
+ pass
105
 
106
  def get_source_only(self):
107
  # return a string with pure source text
108
  result = ""
109
+ for i, seg in enumerate(self.segments):
110
+ result+=f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
111
 
112
  return result
113
 
114
  def reform_src_str(self):
115
  result = ""
116
+ for i, seg in enumerate(self.segments):
117
+ result += f'{i+1}\n'
118
  result += str(seg)
119
  return result
120
 
121
  def reform_trans_str(self):
122
  result = ""
123
+ for i, seg in enumerate(self.segments):
124
+ result += f'{i+1}\n'
125
  result += seg.get_trans_str()
126
  return result
127
 
128
  def form_bilingual_str(self):
129
  result = ""
130
+ for i, seg in enumerate(self.segments):
131
+ result += f'{i+1}\n'
132
  result += seg.get_bilingual_str()
133
  return result
134
 
pipeline.py CHANGED
@@ -88,8 +88,6 @@ if not os.path.exists(f'{RESULT_PATH}/{VIDEO_NAME}'):
88
  srt_file_en = args.srt_file
89
 
90
  if srt_file_en is not None:
91
- # with open(srt_file_en, 'r', encoding='utf-8') as f:
92
- # script_input = f.read()
93
  srt = SRT_script.parse_from_srt_file(srt_file_en)
94
  script_input = srt.get_source_only()
95
  else:
@@ -106,12 +104,20 @@ else:
106
 
107
  # use stable-whisper
108
  model = stable_whisper.load_model('base')
109
- transcript = model.transcribe(audio_path)
110
- transcript.to_srt_vtt(srt_file_en)
 
 
 
 
 
 
 
111
  transcript = transcript.to_dict()
112
  srt = SRT_script(transcript['segments']) # read segments to SRT class
 
113
  script_input = srt.get_source_only()
114
-
115
  #Write SRT file
116
 
117
  # from whisper.utils import WriteSRT
@@ -168,7 +174,7 @@ if not args.only_srt:
168
 
169
  # script_input_withForceTerm = re.sub('\n ', '\n', "".join(ready_words))
170
 
171
- srt.correct_with_force_term()
172
 
173
  # Split the video script by sentences and create chunks within the token limit
174
  def script_split(script_in, chunk_size = 1000):
@@ -199,8 +205,8 @@ script_arr, range_arr = script_split(script_input)
199
 
200
  # Translate and save
201
  for s, range in tqdm(zip(script_arr, range_arr)):
202
- print(s)
203
  # using chatgpt model
 
204
  if model_name == "gpt-3.5-turbo":
205
  # print(s + "\n")
206
  response = openai.ChatCompletion.create(
 
88
  srt_file_en = args.srt_file
89
 
90
  if srt_file_en is not None:
 
 
91
  srt = SRT_script.parse_from_srt_file(srt_file_en)
92
  script_input = srt.get_source_only()
93
  else:
 
104
 
105
  # use stable-whisper
106
  model = stable_whisper.load_model('base')
107
+ transcript = model.transcribe(audio_path, regroup = False)
108
+ (
109
+ transcript
110
+ .split_by_punctuation(['.', '。', '?'])
111
+ .merge_by_gap(.15, max_words=3)
112
+ .merge_by_punctuation([' '])
113
+ .split_by_punctuation(['.', '。', '?'])
114
+ )
115
+ # transcript.to_srt_vtt(srt_file_en)
116
  transcript = transcript.to_dict()
117
  srt = SRT_script(transcript['segments']) # read segments to SRT class
118
+ srt.form_whole_sentence()
119
  script_input = srt.get_source_only()
120
+ srt.write_srt_file_src(srt_file_en)
121
  #Write SRT file
122
 
123
  # from whisper.utils import WriteSRT
 
174
 
175
  # script_input_withForceTerm = re.sub('\n ', '\n', "".join(ready_words))
176
 
177
+ # srt.correct_with_force_term()
178
 
179
  # Split the video script by sentences and create chunks within the token limit
180
  def script_split(script_in, chunk_size = 1000):
 
205
 
206
  # Translate and save
207
  for s, range in tqdm(zip(script_arr, range_arr)):
 
208
  # using chatgpt model
209
+ print(f"now translating sentences {range}")
210
  if model_name == "gpt-3.5-turbo":
211
  # print(s + "\n")
212
  response = openai.ChatCompletion.create(