NHLOCAL committed
Commit e6ed4d0 · 1 parent: 4ad247f

Dynamic detection of seconds/milliseconds

Files changed (1):
  1. main.py (+81 -88)
main.py CHANGED
@@ -9,6 +9,7 @@ import yaml
 import json
 import io
 import os
+import re
 from datetime import timedelta
 import logging
 import asyncio
@@ -48,27 +49,60 @@ def load_system_prompt():
         logging.error(f"Error loading instruct.yml: {e}")
         raise HTTPException(status_code=500, detail="Server configuration error.")
 
+# ---
+# *** The final, smart time-parsing function ***
+# Adaptively distinguishes between the HH:MM:SS and MM:SS:mmm formats
+# ---
 def parse_time_str_to_ms(time_str: str) -> int:
+    """
+    Parses a timestamp string into milliseconds with adaptive format detection.
+    Correctly interprets HH:MM:SS,mmm and MM:SS:mmm formats, even with inconsistent separators.
+    """
     if not isinstance(time_str, str):
-        raise TypeError("Time string must be a string.")
+        raise TypeError(f"Time string must be a string, got {type(time_str)}")
+
+    # Normalize decimal separator to period
     time_str = time_str.replace(',', '.')
+
+    # Find the last separator to distinguish format
+    last_colon_pos = time_str.rfind(':')
+    last_period_pos = time_str.rfind('.')
+
+    h, m, s, ms = 0, 0, 0, 0
+
     try:
-        if '.' in time_str:
-            parts = time_str.split('.')
-            hms_part, ms_part = parts[0], parts[1]
+        # Case 1: Format includes milliseconds (e.g., "MM:SS.mmm" or "HH:MM:SS.mmm")
+        if last_period_pos > last_colon_pos:
+            hms_part = time_str[:last_period_pos]
+            ms_part = time_str[last_period_pos+1:]
             ms = int(ms_part.ljust(3, '0')[:3])
+
+            time_components = list(map(int, hms_part.split(':')))
+            if len(time_components) == 3: h, m, s = time_components
+            elif len(time_components) == 2: m, s = time_components
+            elif len(time_components) == 1: s = time_components[0]
+
+        # Case 2: Format uses colon for milliseconds (e.g., "MM:SS:mmm")
+        # Or it's a standard HH:MM:SS format.
         else:
-            hms_part, ms = time_str, 0
-
-        time_components = list(map(int, hms_part.split(':')))
-        h, m, s = 0, 0, 0
-        if len(time_components) == 3: h, m, s = time_components
-        elif len(time_components) == 2: m, s = time_components
-        elif len(time_components) == 1: s = time_components[0]
-        else: raise ValueError("Too many ':' separators.")
+            components = list(map(int, time_str.split(':')))
+            # If the last component is > 59, it must be milliseconds
+            if len(components) >= 2 and components[-1] > 59:
+                ms = components[-1]
+                s = components[-2]
+                if len(components) == 3: m = components[0]
+                elif len(components) > 3: h, m = components[0], components[1]  # For very long times
+            # Otherwise, it's a standard HH:MM:SS format
+            else:
+                if len(components) == 3: h, m, s = components
+                elif len(components) == 2: m, s = components
+                elif len(components) == 1: s = components[0]
 
         return (h * 3600000) + (m * 60000) + (s * 1000) + ms
+
     except (ValueError, IndexError) as e:
-        raise ValueError(f"Could not parse time string: '{time_str}'. Error: {e}")
+        raise ValueError(f"Could not parse adaptive time string: '{time_str}'. Error: {e}")
+
 
 def format_ms_to_srt_time(ms: int) -> str:
     td = timedelta(milliseconds=ms)
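To sanity-check the adaptive detection above: the new parser reads a trailing component greater than 59 as milliseconds and otherwise falls back to plain H/M/S. A minimal sketch, assuming parse_time_str_to_ms from this commit is in scope:

cases = {
    "00:01:23,450": 83450,    # HH:MM:SS,mmm -> comma normalized, period branch
    "01:23.5": 83500,         # MM:SS.mmm -> short ms part padded to "500"
    "01:23:456": 83456,       # MM:SS:mmm -> last component > 59, read as ms
    "01:02:03": 3723000,      # plain HH:MM:SS -> every component <= 59
}
for text, expected in cases.items():
    assert parse_time_str_to_ms(text) == expected, text

One ambiguity remains by construction: a string whose components are all under 60, like "01:02:03", is always read as HH:MM:SS, so a genuine MM:SS:mmm timestamp with ms <= 59 would still be misread; the bounds check in validate_and_correct_segments below is the backstop for such cases.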
@@ -78,114 +112,82 @@ def format_ms_to_srt_time(ms: int) -> str:
     milliseconds = td.microseconds // 1000
     return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
 
+
 def find_silence_points_webrtcvad(audio_segment: AudioSegment, min_silence_len_ms: int, vad_aggressiveness: int):
-    # This function remains unchanged
     if audio_segment.frame_rate not in [8000, 16000, 32000, 48000]:
         audio_segment = audio_segment.set_frame_rate(16000)
-    if audio_segment.channels > 1:
-        audio_segment = audio_segment.set_channels(1)
-    if audio_segment.sample_width != 2:
-        audio_segment = audio_segment.set_sample_width(2)
+    if audio_segment.channels > 1: audio_segment = audio_segment.set_channels(1)
+    if audio_segment.sample_width != 2: audio_segment = audio_segment.set_sample_width(2)
     vad = webrtcvad.Vad(vad_aggressiveness)
     frame_duration_ms = 30
-    frame_size_bytes = int(audio_segment.frame_rate * (frame_duration_ms / 1000.0) * audio_segment.sample_width)
+    frame_size_bytes = int(audio_segment.frame_rate * (frame_duration_ms / 1000.0) * 2)
     silence_points_ms, silence_start_ms = [], None
-    raw_data = audio_segment.raw_data
-    num_frames = len(raw_data) // frame_size_bytes
+    raw_data, num_frames = audio_segment.raw_data, len(audio_segment.raw_data) // frame_size_bytes
     for i in range(num_frames):
-        start_byte, end_byte = i * frame_size_bytes, i * frame_size_bytes + frame_size_bytes
-        frame = raw_data[start_byte:end_byte]
+        frame = raw_data[i*frame_size_bytes:(i+1)*frame_size_bytes]
         if len(frame) < frame_size_bytes: break
-        is_speech = vad.is_speech(frame, audio_segment.frame_rate)
-        current_time_ms = i * frame_duration_ms
+        is_speech, current_time_ms = vad.is_speech(frame, audio_segment.frame_rate), i * frame_duration_ms
         if not is_speech:
            if silence_start_ms is None: silence_start_ms = current_time_ms
-        else:
-            if silence_start_ms is not None:
-                if current_time_ms - silence_start_ms >= min_silence_len_ms:
-                    silence_points_ms.append(silence_start_ms + (current_time_ms - silence_start_ms) // 2)
-                silence_start_ms = None
+        elif silence_start_ms is not None:
+            if current_time_ms - silence_start_ms >= min_silence_len_ms:
+                silence_points_ms.append(silence_start_ms + (current_time_ms - silence_start_ms) // 2)
+            silence_start_ms = None
     if silence_start_ms is not None and len(audio_segment) - silence_start_ms >= min_silence_len_ms:
         silence_points_ms.append(silence_start_ms + (len(audio_segment) - silence_start_ms) // 2)
     return silence_points_ms
 
+
 def split_audio_webrtcvad(audio_segment, min_silence_len):
-    # This function remains unchanged
     logging.info(f"Splitting with WebRTCVAD: Target Chunk {TARGET_CHUNK_DURATION_MIN}m, VAD Aggressiveness {VAD_AGGRESSIVENESS}")
     silence_points = find_silence_points_webrtcvad(audio_segment, min_silence_len, VAD_AGGRESSIVENESS)
     if not silence_points:
         logging.warning("WebRTCVAD found no significant silences. Splitting into fixed chunks.")
         return [audio_segment[i:i + TARGET_CHUNK_DURATION_MS] for i in range(0, len(audio_segment), TARGET_CHUNK_DURATION_MS)]
-    final_chunks, current_offset = [], 0
-    total_length = len(audio_segment)
+    final_chunks, current_offset, total_length = [], 0, len(audio_segment)
     while current_offset < total_length:
-        remaining_length = total_length - current_offset
-        if remaining_length <= MAX_SPLIT_SEARCH_END_MS:
+        if total_length - current_offset <= MAX_SPLIT_SEARCH_END_MS:
             final_chunks.append(audio_segment[current_offset:])
             break
         ideal_split_point = current_offset + TARGET_CHUNK_DURATION_MS
         candidate_points = [p for p in silence_points if (current_offset + MIN_SPLIT_SEARCH_START_MS) <= p < (current_offset + MAX_SPLIT_SEARCH_END_MS)]
         best_split_point = min(candidate_points, key=lambda p: abs(p - ideal_split_point)) if candidate_points else -1
-        split_at = best_split_point if best_split_point != -1 else current_offset + TARGET_CHUNK_DURATION_MS
+        split_at = best_split_point if best_split_point != -1 else ideal_split_point
         final_chunks.append(audio_segment[current_offset:int(split_at)])
         current_offset = int(split_at)
     logging.info(f"File successfully split into {len(final_chunks)} chunks using WebRTCVAD.")
     logging.info(f"Chunk durations (seconds): {[round(len(c) / 1000) for c in final_chunks]}")
     return final_chunks
 
-# ---
-# *** The correction function with the new, critical logic ***
-# ---
+
 def validate_and_correct_segments(segments_from_api, chunk_duration_ms):
-    corrected_segments = []
-    last_corrected_end_ms = 0
-
+    corrected_segments, last_corrected_end_ms = [], 0
     for seg in segments_from_api:
         try:
             start_ms = parse_time_str_to_ms(seg.get('start_time'))
             end_ms = parse_time_str_to_ms(seg.get('end_time'))
 
-            # --- NEW CRITICAL LOGIC TO PREVENT AVALANCHE ---
-            # If the start time is a hallucination (outside the chunk bounds),
-            # skip this segment entirely. This prevents it from wrecking
-            # last_corrected_end_ms and every segment after it.
+            # The improved parser makes this check much more reliable.
+            # We still keep it for true hallucinations.
             if start_ms >= chunk_duration_ms:
-                logging.warning(f"Skipping segment with hallucinatory start_time ({format_ms_to_srt_time(start_ms)}) outside of chunk duration ({format_ms_to_srt_time(chunk_duration_ms)}).")
+                logging.warning(f"Skipping segment with true hallucinatory start_time ({format_ms_to_srt_time(start_ms)}) outside of chunk duration ({format_ms_to_srt_time(chunk_duration_ms)}).")
                 continue
-            # --- END OF NEW LOGIC ---
-
-            # Trim the end time if it overruns (less destructive than a hallucinated start time)
-            if end_ms > chunk_duration_ms:
-                end_ms = chunk_duration_ms
-
-            # Fix inverted times
-            if start_ms >= end_ms:
-                end_ms = start_ms + 3000
-                end_ms = min(end_ms, chunk_duration_ms)
-
-            # Fix overlaps
-            if start_ms < last_corrected_end_ms:
-                start_ms = last_corrected_end_ms
-
-            # Final check
-            if start_ms >= end_ms:
-                logging.warning(f"Skipping segment after corrections resulted in zero/negative duration: {seg}")
-                continue
+
+            if end_ms > chunk_duration_ms: end_ms = chunk_duration_ms
+            if start_ms >= end_ms: end_ms = min(start_ms + 3000, chunk_duration_ms)
+            if start_ms < last_corrected_end_ms: start_ms = last_corrected_end_ms
+            if start_ms >= end_ms: continue
 
-            seg['start_time_relative'] = start_ms
-            seg['end_time_relative'] = end_ms
+            seg['start_time_relative'], seg['end_time_relative'] = start_ms, end_ms
             corrected_segments.append(seg)
-
             last_corrected_end_ms = end_ms
-
         except (ValueError, TypeError, KeyError) as e:
             logging.warning(f"Skipping segment due to invalid format or value: {seg}. Error: {e}")
             continue
-
     return corrected_segments
 
+
 def transcribe_chunk(chunk_audio, api_key, system_prompt, pydantic_schema, model_name):
-    # This function remains unchanged
     try:
         client = genai.Client(api_key=api_key)
         buffer = io.BytesIO()
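For reference, the VAD loop above consumes 16-bit mono audio in 30 ms frames, so at 16 kHz each frame is 16000 * 0.030 * 2 = 960 bytes. The correction rules are easiest to confirm on fabricated input; below is a minimal sketch with hypothetical segment dicts (the field names match what this code reads), assuming validate_and_correct_segments from this commit is in scope:

chunk_ms = 60000  # a 60-second chunk
raw = [
    {"start_time": "00:05", "end_time": "00:12", "text": "kept as-is"},
    {"start_time": "00:10", "end_time": "00:20", "text": "overlap: start pushed to 12s"},
    {"start_time": "02:30", "end_time": "02:40", "text": "hallucinated start: skipped"},
    {"start_time": "00:55", "end_time": "01:10", "text": "end clamped to 60s"},
]
fixed = validate_and_correct_segments(raw, chunk_ms)
print([(s['start_time_relative'], s['end_time_relative']) for s in fixed])
# [(5000, 12000), (12000, 20000), (55000, 60000)]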
@@ -202,22 +204,15 @@ def transcribe_chunk(chunk_audio, api_key, system_prompt, pydantic_schema, model
 def generate_srt_content(segments):
     lines = []
     for i, seg in enumerate(segments, 1):
-        lines.append(str(i))
-        start = format_ms_to_srt_time(seg['start_time_abs'])
-        end = format_ms_to_srt_time(seg['end_time_abs'])
-        lines.append(f"{start} --> {end}")
-        lines.append(seg['text'])
-        lines.append("")
+        lines.extend([str(i), f"{format_ms_to_srt_time(seg['start_time_abs'])} --> {format_ms_to_srt_time(seg['end_time_abs'])}", seg['text'], ""])
     return "\n".join(lines)
 
+
 async def _transcribe_and_stream(api_key: str, file_content: bytes, model_name: str):
-    # This function's core logic remains unchanged, but it now benefits from the improved validation
     def send_event(type: str, message: str = "", percent: int = 0, data: str = ""):
         return json.dumps({"type": type, "message": message, "percent": percent, "data": data}) + "\n\n"
-
     try:
-        system_prompt = load_system_prompt()
-        pydantic_schema = TranscriptionSegment
+        system_prompt, pydantic_schema = load_system_prompt(), TranscriptionSegment
         yield send_event("progress", "מעבד את קובץ השמע...", 5)
         audio = AudioSegment.from_file(io.BytesIO(file_content))
         yield send_event("progress", f"אורך הקובץ {len(audio) / 60000:.1f} דקות. מבצע חלוקה...", 15)
@@ -228,28 +223,26 @@ async def _transcribe_and_stream(api_key: str, file_content: bytes, model_name:
         yield send_event("progress", f"הקובץ חולק ל-{len(chunks)} מקטעים. מתחיל תמלול...", 20)
 
         all_segs, offset = [], 0
-        total_chunks = len(chunks)
         for i, ch in enumerate(chunks):
-            progress_percent = 20 + int((i / total_chunks) * 75)
-            yield send_event("progress", f"מתמלל מקטע {i+1} מתוך {total_chunks}...", progress_percent)
+            progress_percent = 20 + int(((i + 1) / len(chunks)) * 75)
+            yield send_event("progress", f"מתמלל מקטע {i+1} מתוך {len(chunks)}...", progress_percent)
             data, error_msg = await asyncio.to_thread(transcribe_chunk, ch, api_key, system_prompt, pydantic_schema, model_name)
             if error_msg: raise ValueError(f"שגיאה בעיבוד מקטע {i+1}: {error_msg}")
             if data and isinstance(data, list):
                 corrected_segments = validate_and_correct_segments(data, len(ch))
                 for seg in corrected_segments:
                     seg['start_time_abs'] = seg['start_time_relative'] + offset
-                    seg['end_time_abs'] = seg['end_time_relative'] + offset
+                    seg['end_time_abs'] = seg['end_time_relative'] + offset
                     all_segs.append(seg)
             offset += len(ch)
 
         if not all_segs: raise ValueError("התמלול נכשל. לא נוצר תוכן תקני.")
-        yield send_event("progress", "התמלול הושלם! יוצר קובץ SRT...", 98)
-        srt_content = generate_srt_content(all_segs)
-        yield send_event("result", "התהליך הושלם בהצלחה!", 100, data=srt_content)
+        yield send_event("result", "התהליך הושלם בהצלחה!", 100, data=generate_srt_content(all_segs))
     except Exception as e:
         logging.error(f"Streaming transcription failed: {e}", exc_info=True)
         yield send_event("error", f"אירעה שגיאה: {e}", 100)
 
+
 @app.get("/", response_class=HTMLResponse)
 async def read_root(request: Request):
     return templates.TemplateResponse("index.html", {"request": request})
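One behavioral detail the revised progress formula fixes: with the old (i / total_chunks) mapping, the first chunk transcribed at a reported 20% and the bar topped out below 95%. The new ((i + 1) / len(chunks)) mapping spreads evenly up to 95%; for a hypothetical four-chunk file:

for i in range(4):
    print(f"chunk {i + 1}: {20 + int(((i + 1) / 4) * 75)}%")
# chunk 1: 38%
# chunk 2: 57%
# chunk 3: 76%
# chunk 4: 95%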
 
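Finally, the condensed generate_srt_content still emits standard SRT blocks. A minimal sketch with two already-corrected segments (absolute times in ms, placeholder text), assuming generate_srt_content and format_ms_to_srt_time from this file are in scope:

segs = [
    {"start_time_abs": 0, "end_time_abs": 2500, "text": "first line"},
    {"start_time_abs": 2500, "end_time_abs": 6000, "text": "second line"},
]
print(generate_srt_content(segs))
# 1
# 00:00:00,000 --> 00:00:02,500
# first line
#
# 2
# 00:00:02,500 --> 00:00:06,000
# second line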