david
committed on
Commit
·
e1e0093
1
Parent(s):
aca8f40
fix seg id error
Browse files- transcribe/strategy.py +14 -11
transcribe/strategy.py
CHANGED
|
@@ -147,7 +147,14 @@ class TranscriptBuffer:
|
|
| 147 |
self._segments: List[str] = [] # 确认的完整段落
|
| 148 |
self._sentences: List[str] = [] # 当前段落中的短句
|
| 149 |
self._buffer: str = "" # 当前缓冲中的文本
|
|
|
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def update_pending_text(self, text: str) -> None:
|
| 152 |
"""更新临时缓冲字符串"""
|
| 153 |
self._buffer = text
|
|
@@ -158,16 +165,16 @@ class TranscriptBuffer:
|
|
| 158 |
self._sentences.append(self._buffer)
|
| 159 |
self._buffer = ""
|
| 160 |
|
| 161 |
-
def commit_paragraph(self
|
| 162 |
"""
|
| 163 |
提交当前短句为完整段落(如句子结束)
|
| 164 |
|
| 165 |
Args:
|
| 166 |
end_of_sentence: 是否为句子结尾(如检测到句号)
|
| 167 |
"""
|
| 168 |
-
self.
|
| 169 |
-
if end_of_sentence and self._sentences:
|
| 170 |
self._segments.append("".join(self._sentences))
|
|
|
|
| 171 |
self._sentences.clear()
|
| 172 |
|
| 173 |
|
|
@@ -175,12 +182,11 @@ class TranscriptBuffer:
|
|
| 175 |
self.update_pending_text(stable_string)
|
| 176 |
if is_end_sentence:
|
| 177 |
self.commit_paragraph(end_of_sentence=True)
|
|
|
|
| 178 |
else:
|
| 179 |
self.commit_line()
|
| 180 |
self.update_pending_text(remaining_string)
|
| 181 |
|
| 182 |
-
def get_seg_id(self) -> int:
|
| 183 |
-
return len(self._segments)
|
| 184 |
|
| 185 |
@property
|
| 186 |
def paragraph(self) -> str:
|
|
@@ -197,7 +203,6 @@ class TranscriptBuffer:
|
|
| 197 |
"""最新确认的段落"""
|
| 198 |
return self._segments[-1] if self._segments else ""
|
| 199 |
|
| 200 |
-
|
| 201 |
@property
|
| 202 |
def current_not_commit_text(self) -> str:
|
| 203 |
return self.paragraph + self.pending_text
|
|
@@ -274,23 +279,21 @@ class TranscriptStabilityAnalyzer:
|
|
| 274 |
|
| 275 |
prev_seg_id = self._transcript_buffer.get_seg_id()
|
| 276 |
self._transcript_buffer.update_and_commit(stable_str, remaining_str, is_end_sentence)
|
| 277 |
-
curr_seg_id = self._transcript_buffer.get_seg_id()
|
| 278 |
|
| 279 |
logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
|
| 280 |
|
| 281 |
-
if
|
| 282 |
# 表示生成了一个新段落 换行
|
| 283 |
yield TranscriptResult(
|
| 284 |
-
seg_id=
|
| 285 |
cut_index=frame_cut_index,
|
| 286 |
context=self._transcript_buffer.latest_paragraph,
|
| 287 |
is_end_sentence=True
|
| 288 |
)
|
| 289 |
-
|
| 290 |
# 如果还有挂起的文本
|
| 291 |
if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
|
| 292 |
yield TranscriptResult(
|
| 293 |
-
seg_id=
|
| 294 |
cut_index=frame_cut_index,
|
| 295 |
context=current_not_commit_text
|
| 296 |
)
|
|
|
|
| 147 |
self._segments: List[str] = [] # 确认的完整段落
|
| 148 |
self._sentences: List[str] = [] # 当前段落中的短句
|
| 149 |
self._buffer: str = "" # 当前缓冲中的文本
|
| 150 |
+
self._current_seg_id: int =0
|
| 151 |
|
| 152 |
+
def get_seg_id(self) -> int:
|
| 153 |
+
return self._current_seg_id
|
| 154 |
+
|
| 155 |
+
def next_seg_id(self) -> int:
|
| 156 |
+
return self._current_seg_id + 1
|
| 157 |
+
|
| 158 |
def update_pending_text(self, text: str) -> None:
|
| 159 |
"""更新临时缓冲字符串"""
|
| 160 |
self._buffer = text
|
|
|
|
| 165 |
self._sentences.append(self._buffer)
|
| 166 |
self._buffer = ""
|
| 167 |
|
| 168 |
+
def commit_paragraph(self) -> None:
|
| 169 |
"""
|
| 170 |
提交当前短句为完整段落(如句子结束)
|
| 171 |
|
| 172 |
Args:
|
| 173 |
end_of_sentence: 是否为句子结尾(如检测到句号)
|
| 174 |
"""
|
| 175 |
+
if self._sentences:
|
|
|
|
| 176 |
self._segments.append("".join(self._sentences))
|
| 177 |
+
self._current_seg_id += 1
|
| 178 |
self._sentences.clear()
|
| 179 |
|
| 180 |
|
|
|
|
| 182 |
self.update_pending_text(stable_string)
|
| 183 |
if is_end_sentence:
|
| 184 |
self.commit_paragraph(end_of_sentence=True)
|
| 185 |
+
# if len() >=20
|
| 186 |
else:
|
| 187 |
self.commit_line()
|
| 188 |
self.update_pending_text(remaining_string)
|
| 189 |
|
|
|
|
|
|
|
| 190 |
|
| 191 |
@property
|
| 192 |
def paragraph(self) -> str:
|
|
|
|
| 203 |
"""最新确认的段落"""
|
| 204 |
return self._segments[-1] if self._segments else ""
|
| 205 |
|
|
|
|
| 206 |
@property
|
| 207 |
def current_not_commit_text(self) -> str:
|
| 208 |
return self.paragraph + self.pending_text
|
|
|
|
| 279 |
|
| 280 |
prev_seg_id = self._transcript_buffer.get_seg_id()
|
| 281 |
self._transcript_buffer.update_and_commit(stable_str, remaining_str, is_end_sentence)
|
|
|
|
| 282 |
|
| 283 |
logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
|
| 284 |
|
| 285 |
+
if self._transcript_buffer.get_seg_id() > prev_seg_id:
|
| 286 |
# 表示生成了一个新段落 换行
|
| 287 |
yield TranscriptResult(
|
| 288 |
+
seg_id=prev_seg_id,
|
| 289 |
cut_index=frame_cut_index,
|
| 290 |
context=self._transcript_buffer.latest_paragraph,
|
| 291 |
is_end_sentence=True
|
| 292 |
)
|
|
|
|
| 293 |
# 如果还有挂起的文本
|
| 294 |
if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
|
| 295 |
yield TranscriptResult(
|
| 296 |
+
seg_id=self._transcript_buffer.get_seg_id(),
|
| 297 |
cut_index=frame_cut_index,
|
| 298 |
context=current_not_commit_text
|
| 299 |
)
|