Spaces:
Sleeping
Sleeping
| import logging | |
| import time | |
| import re | |
| from typing import List, Dict, Tuple | |
| import torch | |
| from transformers import pipeline | |
| from transformers import logging as transformers_logging | |
| import warnings | |
| import os | |
| from app.core.chunking import split_audio_to_chunks | |
| from app.core.audio_utils import get_audio_info | |
| from app.config.settings import MODEL_NAME | |
# Module-level logger, named after this module (``__name__``).
logger = logging.getLogger(name=__name__)
# ===============================
# Post-processing utilities
# ===============================
| def _clean_transcript(text: str) -> str: | |
| """ | |
| Clean up common ASR artifacts: | |
| - Remove excessive dots (silence markers) | |
| - Remove repeated words/phrases (hallucinations) | |
| - Clean up whitespace | |
| """ | |
| if not text: | |
| return "" | |
| # 1. Remove excessive dots (more than 3 consecutive) | |
| text = re.sub(r'\.{4,}', '.', text) | |
| # 2. Remove repeated single words (e.g., "chuyền chuyền chuyền...") | |
| # Match word repeated 2+ times consecutively | |
| text = re.sub(r'\b(\w+)(\s+\1){2,}\b', r'\1', text, flags=re.IGNORECASE) | |
| # 3. Remove repeated short phrases (2-5 words repeated 2+ times) | |
| # More aggressive pattern to catch "biết chính xác mình cần làm" repeats | |
| for phrase_len in [5, 4, 3, 2]: | |
| pattern = r'((?:\S+\s+){' + str(phrase_len) + r'})\1{1,}' | |
| text = re.sub(pattern, r'\1', text) | |
| # 4. Remove long repeated phrases (like "thế giới trên cầu" repeated many times) | |
| # Find and remove sequences where same phrase appears 3+ times | |
| words = text.split() | |
| if len(words) > 10: | |
| text = _remove_long_repeats(text) | |
| # 5. Clean up multiple spaces | |
| text = re.sub(r'\s+', ' ', text) | |
| # 6. Clean up space before punctuation | |
| text = re.sub(r'\s+([.,!?])', r'\1', text) | |
| # 7. Remove trailing/leading dots and spaces | |
| text = text.strip(' .') | |
| return text | |
| def _remove_long_repeats(text: str) -> str: | |
| """ | |
| Remove long repeated phrases that regex can't easily catch. | |
| Looks for phrases of 3-8 words that repeat consecutively. | |
| """ | |
| words = text.split() | |
| if len(words) < 10: | |
| return text | |
| result = [] | |
| i = 0 | |
| while i < len(words): | |
| # Try to find repeating patterns of length 3-8 words | |
| found_repeat = False | |
| for phrase_len in range(8, 2, -1): # Check longer phrases first | |
| if i + phrase_len * 2 > len(words): | |
| continue | |
| phrase = words[i:i+phrase_len] | |
| next_phrase = words[i+phrase_len:i+phrase_len*2] | |
| if phrase == next_phrase: | |
| # Found a repeat, skip all consecutive repeats | |
| result.extend(phrase) | |
| j = i + phrase_len | |
| while j + phrase_len <= len(words) and words[j:j+phrase_len] == phrase: | |
| j += phrase_len | |
| i = j | |
| found_repeat = True | |
| break | |
| if not found_repeat: | |
| result.append(words[i]) | |
| i += 1 | |
| return ' '.join(result) | |
| def _deduplicate_chunks(prev_text: str, curr_text: str, overlap_words: int = 15) -> str: | |
| """ | |
| Remove overlapping text between consecutive chunks. | |
| Compares the end of prev_text with the start of curr_text. | |
| """ | |
| if not prev_text or not curr_text: | |
| return curr_text | |
| prev_words = prev_text.split() | |
| curr_words = curr_text.split() | |
| if len(prev_words) < 3 or len(curr_words) < 3: | |
| return curr_text | |
| # Check last N words of prev against first N words of curr | |
| check_len = min(overlap_words, len(prev_words), len(curr_words)) | |
| best_overlap = 0 | |
| for i in range(1, check_len + 1): | |
| prev_end = prev_words[-i:] | |
| curr_start = curr_words[:i] | |
| # Normalize for comparison (lowercase, strip punctuation) | |
| prev_normalized = [re.sub(r'[^\w]', '', w.lower()) for w in prev_end] | |
| curr_normalized = [re.sub(r'[^\w]', '', w.lower()) for w in curr_start] | |
| if prev_normalized == curr_normalized: | |
| best_overlap = i | |
| if best_overlap > 0: | |
| # Remove the overlapping part from current text | |
| return ' '.join(curr_words[best_overlap:]) | |
| return curr_text | |
# ===============================
# Global model cache
# ===============================
# Pipeline instance shared by every load_model() call; None until first load.
_ASR_MODEL = None


def load_model(chunk_length_s: float = 30.0):
    """
    Load the ASR pipeline once and reuse it.

    The first call builds a transformers "automatic-speech-recognition"
    pipeline for MODEL_NAME. It uses device 0 with float16 when CUDA is
    available, otherwise CPU with float32, and stores the pipeline in the
    module-level cache. Later calls return the cached pipeline.

    Args:
        chunk_length_s: Chunk length passed to the pipeline at construction.
            Only honoured on the first (loading) call; later calls return the
            cached pipeline and ignore this value.

    Returns:
        The cached pipeline object.
    """
    global _ASR_MODEL
    if _ASR_MODEL is not None:
        return _ASR_MODEL

    logger.info("Loading ASR model %s", MODEL_NAME)
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else -1
    dtype = torch.float16 if use_cuda else torch.float32

    # Reduce noisy transformer logs. This is best effort, so a failure is only
    # noted at debug level instead of being silently swallowed.
    try:
        transformers_logging.set_verbosity_error()
    except Exception:
        logger.debug("Could not lower transformers log verbosity", exc_info=True)
    # Filter the noisy chunk_length_s warnings (regex); this is process-wide.
    warnings.filterwarnings("ignore", message=r".*chunk_length_s.*")

    _ASR_MODEL = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        device=device,
        dtype=dtype,
        chunk_length_s=chunk_length_s,
        return_timestamps=True,
        ignore_warning=True,
    )
    logger.info(
        "ASR model loaded (device=%s)", "cuda" if use_cuda else "cpu"
    )
    return _ASR_MODEL
# ===============================
# Transcribe full text
# ===============================
def transcribe_file_unified(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
    """
    Return the full transcript text AND timestamped chunks from ONE inference pass.

    The whole file goes through a single pipeline call that uses the
    pipeline's built-in chunking; manual splitting caused repeated and
    hallucinated text.

    Args:
        model: Callable ASR pipeline (e.g. the object returned by load_model).
        wav_path: Path of the audio file. Falsy values return ("", []).
        chunk_length_s: Chunk length in seconds passed to the pipeline.
        stride_s: Left/right context in seconds per chunk. It is used when
            ``0 <= stride_s`` and ``2 * stride_s < chunk_length_s``. Otherwise
            the stride falls back to ``chunk_length_s // 6`` so left + right
            context stays below the chunk length.

    Returns:
        (text, chunks) where chunks = [{"start": float, "end": float, "text": str}, ...]
    """
    if not wav_path:
        return "", []
    start_time = time.time()
    logger.info("[ASR] Starting unified transcription for %s", wav_path)
    info = get_audio_info(wav_path) or {}
    # A duration key that is present but None would break "%.2f"; treat it as unknown (0).
    duration = info.get("duration") or 0
    logger.info("[ASR] Audio duration: %.2fs", duration)

    # Honour the caller's stride when it leaves room for actual audio in each
    # chunk; otherwise use the previous heuristic (~5s for 30s chunks).
    if 0 <= stride_s and 2 * stride_s < chunk_length_s:
        stride = stride_s
    else:
        stride = chunk_length_s // 6

    # Single pipeline call with the pipeline's built-in chunking.
    out = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=(stride, stride),
        return_timestamps=True,
    )
    # Extract text, falling back to concatenated segment texts.
    text = (out.get("text") or "").strip()
    if not text:
        segs = out.get("chunks") or out.get("segments") or []
        if segs:
            parts = [(s.get("text") or "").strip() for s in segs]
            text = " ".join([p for p in parts if p]).strip()
    # Extract chunks with timestamps
    chunks = _extract_chunks_from_output(out)
    # Clean up ASR artifacts (repeated words/phrases, hallucinations)
    text = _clean_transcript(text)
    for chunk in chunks:
        if chunk.get("text"):
            chunk["text"] = _clean_transcript(chunk["text"])
    elapsed = time.time() - start_time
    logger.info("[ASR] Transcription completed in %.2fs (%.2fx realtime)",
                elapsed, elapsed / duration if duration else 0)
    return text, chunks
| def _extract_chunks_from_output(out: dict) -> List[Dict]: | |
| """Extract timestamped chunks from model output.""" | |
| raw_segments = out.get("chunks") or out.get("segments") or [] | |
| chunks = [] | |
| for c in raw_segments: | |
| start = None | |
| end = None | |
| if isinstance(c.get("timestamp"), (list, tuple)) and len(c.get("timestamp")) >= 2: | |
| ts = c.get("timestamp") | |
| start, end = ts[0], ts[1] | |
| elif c.get("start") is not None and c.get("end") is not None: | |
| start, end = c.get("start"), c.get("end") | |
| text = (c.get("text") or "").strip() | |
| if not text or start is None or end is None: | |
| continue | |
| try: | |
| chunks.append({"start": float(start), "end": float(end), "text": text}) | |
| except Exception: | |
| continue | |
| return chunks | |
def transcribe_file(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> str:
    """
    Return only the transcript text of ``wav_path``.

    ⚠️ DEPRECATED: this is the text half of transcribe_file_unified(). Call that
    directly to get both the text and the timestamped chunks in one pass.
    """
    result = transcribe_file_unified(
        model, wav_path, chunk_length_s=chunk_length_s, stride_s=stride_s
    )
    return result[0]
def _probe_chunk_info(path: str):
    """Return ``get_audio_info(path) or {}``, or None if probing raised."""
    try:
        return get_audio_info(path) or {}
    except Exception:
        return None


def _chunk_start_offsets(infos: List, chunk_length_s: float, overlap_s: float) -> List[float]:
    """Estimate where each chunk starts (seconds) in the original file.

    Each chunk advances the clock by ``duration - overlap_s`` (or its full
    duration when ``overlap_s`` is 0), but by at least 80% of its duration.
    Chunks whose info or duration is unusable advance by
    ``chunk_length_s - overlap_s``.
    """
    offsets: List[float] = []
    cursor = 0.0
    for info in infos:
        offsets.append(cursor)
        if info is None:
            cursor += chunk_length_s - overlap_s
            continue
        try:
            chunk_dur = info.get("duration", chunk_length_s)
            step = chunk_dur - overlap_s if overlap_s > 0 else chunk_dur
            # NOTE(review): the 80% floor keeps the clock moving when overlap_s is
            # large, but it overrides (duration - overlap_s) once overlap_s
            # exceeds 20% of the chunk duration.
            cursor += max(step, chunk_dur * 0.8)
        except Exception:  # best effort: duration missing (None) or not numeric
            cursor += chunk_length_s - overlap_s
    return offsets


def transcribe_long_audio(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    overlap_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
    """
    Split `wav_path` into chunks and run inference on each chunk sequentially.

    VAD-based splitting (``app.core.chunking.split_audio_with_vad``) is tried
    first. If it raises, fixed windows of ``chunk_length_s`` seconds with
    ``overlap_s`` seconds of overlap are used instead. Text overlapping the
    previous chunk is de-duplicated. All chunk files are deleted afterwards,
    even on error.

    Args:
        model: Callable ASR pipeline (e.g. the object returned by load_model).
        wav_path: Path of the audio file. Falsy values return ("", []).
        chunk_length_s: Window length for fixed splitting and for each model call.
        overlap_s: Overlap between fixed windows, in seconds.

    Returns:
        (full_text, chunks) where chunks = [{"start", "end", "text"}] with
        start/end shifted to estimated positions in the original file.
    """
    if not wav_path:
        return "", []
    split_start = time.time()
    # Prefer VAD-based splitting if available.
    vad_split = False
    try:
        from app.core.chunking import split_audio_with_vad
        chunk_paths = split_audio_with_vad(wav_path)
        logger.info("[ASR] VAD split into %d chunks in %.2fs", len(chunk_paths), time.time() - split_start)
        vad_split = True
    except Exception as e:
        logger.warning("[ASR] VAD split failed (%s), using fixed windows", e)
        chunk_paths = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
        logger.info("[ASR] Fixed-window split into %d chunks in %.2fs", len(chunk_paths), time.time() - split_start)

    # Only fixed windows overlap. VAD chunks are advanced by their full duration;
    # subtracting overlap_s for them made later timestamps drift early.
    # NOTE(review): assumes split_audio_with_vad emits back-to-back chunks with
    # no padding/overlap and no dropped silence between them — confirm.
    step_overlap = 0.0 if vad_split else overlap_s
    # Small per-call context stride (stride_length_s = (left, right)).
    stride = min(5, chunk_length_s // 6)

    combined_text_parts: List[str] = []
    combined_chunks: List[Dict] = []
    prev_chunk_text = ""  # original text of the last non-empty chunk, for dedup
    try:
        # Probe each chunk once; reused for offsets, debug logs and fallbacks.
        infos = [_probe_chunk_info(cp) for cp in chunk_paths]
        offsets = _chunk_start_offsets(infos, chunk_length_s, step_overlap)
        for i, (cp, cinfo, base_offset) in enumerate(zip(chunk_paths, infos, offsets)):
            if cinfo is None:
                logger.debug("chunk[%d]=%s (info unavailable)", i, cp)
            else:
                logger.debug(
                    "chunk[%d]=%s duration=%.3fs samplerate=%s", i, cp, cinfo.get("duration"), cinfo.get("samplerate")
                )
            try:
                out = model(
                    cp,
                    chunk_length_s=chunk_length_s,
                    stride_length_s=(stride, stride),
                    return_timestamps=True,
                )
            except Exception:
                logger.exception("model inference failed for chunk %s", cp)
                continue
            # Log the output keys of the first few chunks only, to keep logs small.
            if i < 5:
                logger.debug(
                    "model out keys for chunk[%d]: %s", i, list(out.keys()) if isinstance(out, dict) else type(out)
                )

            raw_segs = out.get("chunks") or out.get("segments") or []
            part_text = (out.get("text") or "").strip()
            if not part_text:
                parts = [(s.get("text") or "").strip() for s in raw_segs]
                part_text = " ".join(p for p in parts if p).strip()

            if part_text:
                # Deduplicate overlapping text between consecutive chunks.
                deduped_text = _deduplicate_chunks(prev_chunk_text, part_text)
                if deduped_text:
                    combined_text_parts.append(deduped_text)
                prev_chunk_text = part_text  # keep the original for the next comparison

            if raw_segs:
                # Same parsing as the unified path, shifted to global time.
                for seg in _extract_chunks_from_output(out):
                    combined_chunks.append({
                        "start": seg["start"] + base_offset,
                        "end": seg["end"] + base_offset,
                        "text": seg["text"],
                    })
            elif part_text:
                # Text but no timestamped segments: one chunk spanning the whole file.
                cdur = (cinfo or {}).get("duration") or chunk_length_s
                try:
                    combined_chunks.append({
                        "start": float(base_offset),
                        "end": float(base_offset) + float(cdur),
                        "text": part_text,
                    })
                except (TypeError, ValueError):
                    logger.exception("failed to create fallback chunk for %s", cp)
    finally:
        for p in chunk_paths:
            try:
                if p and os.path.exists(p):
                    os.remove(p)
            except Exception:  # best-effort cleanup must not mask the real error
                logger.debug("Failed to remove chunk file %s", p)
    full_text = " ".join([p for p in combined_text_parts if p]).strip()
    # Clean up ASR artifacts (repeated words, excessive dots, etc.)
    full_text = _clean_transcript(full_text)
    for chunk in combined_chunks:
        if chunk.get("text"):
            chunk["text"] = _clean_transcript(chunk["text"])
    return full_text, combined_chunks
# ===============================
# Transcribe chunks with timestamps
# ===============================
def transcribe_file_chunks(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> List[Dict]:
    """
    Return only the timestamped chunks of ``wav_path``:
    [{ start, end, text }]

    ⚠️ DEPRECATED: this is the chunk half of transcribe_file_unified(). Call that
    directly to get both the text and the chunks in one pass.
    """
    result = transcribe_file_unified(
        model, wav_path, chunk_length_s=chunk_length_s, stride_s=stride_s
    )
    return result[1]