PhoWhisperBaseAPI / app /core /asr_engine.py
bichnhan2701's picture
Tempt remove preprocess
f58f9b6
import logging
import time
import re
from typing import List, Dict, Tuple
import torch
from transformers import pipeline
from transformers import logging as transformers_logging
import warnings
import os
from app.core.chunking import split_audio_to_chunks
from app.core.audio_utils import get_audio_info
from app.config.settings import MODEL_NAME
# Module logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# ===============================
# Post-processing utilities
# ===============================
def _clean_transcript(text: str) -> str:
    """
    Clean up common ASR artifacts in a transcript.

    Steps, in order:
      1. Collapse runs of 4+ dots (silence markers) into a single ".".
      2. Collapse whitespace runs to single spaces so the repeat detectors
         below compare phrases with uniform spacing.
      3. Collapse a word that appears 3+ times in a row into one copy
         (case-insensitive, e.g. "chuyền chuyền chuyền" -> "chuyền").
      4. Collapse consecutive repeats of 5-, 4-, 3- and 2-word phrases,
         including a repeat that ends the text.
      5. For texts longer than 10 words, run ``_remove_long_repeats``
         (consecutive 3-8 word phrase repeats).
      6. Normalize whitespace, drop spaces before . , ! ? and strip leading /
         trailing spaces and dots.

    Args:
        text: raw transcript; empty or None yields "".

    Returns:
        The cleaned transcript.
    """
    if not text:
        return ""

    # 1. Excessive dots (more than 3 consecutive) -> one dot.
    text = re.sub(r'\.{4,}', '.', text)

    # 2. Uniform single spacing, so phrase back-references below are not
    #    defeated by mixed whitespace between otherwise identical phrases.
    text = re.sub(r'\s+', ' ', text)

    # 3. A word followed by 2+ further copies of itself (3+ occurrences).
    text = re.sub(r'\b(\w+)(\s+\1){2,}\b', r'\1', text, flags=re.IGNORECASE)

    # 4. Repeated short phrases (2-5 words, 2+ consecutive occurrences).
    #    (?<!\S) makes the phrase start at a word boundary, so a word suffix
    #    (the "a" in "xa") cannot start a phrase; the captured phrase excludes
    #    the trailing space and each repeat only needs to be followed by
    #    whitespace or end-of-text, so repeats at the very end are caught too.
    for phrase_len in (5, 4, 3, 2):
        pattern = (
            r'(?<!\S)((?:\S+\s+){' + str(phrase_len - 1) + r'}\S+)'
            r'(?:\s+\1(?!\S))+'
        )
        text = re.sub(pattern, r'\1', text)

    # 5. Longer repeated phrases (e.g. "thế giới trên cầu" many times).
    if len(text.split()) > 10:
        text = _remove_long_repeats(text)

    # 6. Final whitespace / punctuation tidy-up.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'\s+([.,!?])', r'\1', text)
    return text.strip(' .')
def _remove_long_repeats(text: str) -> str:
    """
    Collapse consecutive repeats of 3-8 word phrases into a single copy.

    Tokens are whitespace-separated and compared exactly (case- and
    punctuation-sensitive). At each position the longest phrase length that
    is immediately repeated wins. Texts with fewer than 10 tokens are returned
    unchanged; otherwise the result is re-joined with single spaces.
    """
    tokens = text.split()
    total = len(tokens)
    if total < 10:
        return text

    kept = []
    pos = 0
    while pos < total:
        # Longest phrase length (8 down to 3) whose next copy follows at once.
        span = next(
            (
                size
                for size in (8, 7, 6, 5, 4, 3)
                if pos + 2 * size <= total
                and tokens[pos:pos + size] == tokens[pos + size:pos + 2 * size]
            ),
            None,
        )
        if span is None:
            kept.append(tokens[pos])
            pos += 1
            continue

        unit = tokens[pos:pos + span]
        kept += unit
        # Skip over every consecutive copy of the phrase.
        cursor = pos + span
        while cursor + span <= total and tokens[cursor:cursor + span] == unit:
            cursor += span
        pos = cursor

    return ' '.join(kept)
def _deduplicate_chunks(prev_text: str, curr_text: str, overlap_words: int = 15) -> str:
    """
    Strip from `curr_text` the words that repeat the tail of `prev_text`.

    The last k words of `prev_text` are compared with the first k words of
    `curr_text` for k up to `overlap_words`, after lowercasing and removing
    non-word characters from each token. The largest matching k is removed
    from the front of `curr_text` (re-joined with single spaces; may be "").

    `curr_text` is returned unchanged when either text is empty, either has
    fewer than 3 words, or no overlap is found.
    """
    if not (prev_text and curr_text):
        return curr_text

    prev_tokens = prev_text.split()
    curr_tokens = curr_text.split()
    if min(len(prev_tokens), len(curr_tokens)) < 3:
        return curr_text

    window = min(overlap_words, len(prev_tokens), len(curr_tokens))

    def _norm(tokens):
        # Lowercase and drop punctuation so "Three." matches "three,".
        return [re.sub(r'[^\w]', '', tok.lower()) for tok in tokens]

    # Search from the longest candidate down; the first hit is the largest overlap.
    overlap = 0
    for size in range(window, 0, -1):
        if _norm(prev_tokens[-size:]) == _norm(curr_tokens[:size]):
            overlap = size
            break

    if overlap:
        return ' '.join(curr_tokens[overlap:])
    return curr_text
# ===============================
# Global model cache
# ===============================
# Cached ASR pipeline: None until load_model() builds it, then returned as-is by later calls.
_ASR_MODEL = None
def load_model(chunk_length_s: float = 30.0):
    """
    Build the Hugging Face ASR pipeline for MODEL_NAME once and cache it.

    Later calls return the cached pipeline unchanged, so ``chunk_length_s``
    only has an effect on the call that actually builds it. Device 0 with
    float16 is used when CUDA is available, otherwise CPU (-1) with float32.
    Building also lowers transformers' log verbosity to ERROR and installs a
    process-wide ``warnings`` filter for messages mentioning chunk_length_s.
    """
    global _ASR_MODEL
    if _ASR_MODEL is not None:
        return _ASR_MODEL

    logger.info("Loading ASR model %s", MODEL_NAME)
    on_gpu = torch.cuda.is_available()

    # Best effort: quiet transformers' own logger.
    try:
        transformers_logging.set_verbosity_error()
    except Exception:
        pass
    # Hide the noisy warnings about (experimental) chunk_length_s chunking.
    warnings.filterwarnings("ignore", message=r".*chunk_length_s.*")

    pipeline_kwargs = dict(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        device=0 if on_gpu else -1,
        dtype=torch.float16 if on_gpu else torch.float32,
        chunk_length_s=chunk_length_s,
        return_timestamps=True,
        ignore_warning=True,
    )
    _ASR_MODEL = pipeline(**pipeline_kwargs)

    logger.info("ASR model loaded (device=%s)", "cuda" if on_gpu else "cpu")
    return _ASR_MODEL
# ===============================
# Transcribe full text
# ===============================
def transcribe_file_unified(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
    """
    🔥 UNIFIED: Return both full transcript text AND timestamped chunks in ONE inference pass.

    The whole file is handed to the pipeline and its built-in chunking
    (``chunk_length_s`` / ``stride_length_s``) handles long audio; manual
    splitting caused text repetition and hallucination.

    Args:
        model: callable ASR pipeline, e.g. the object returned by ``load_model``.
        wav_path: path to the audio file; empty/None returns ``("", [])``.
        chunk_length_s: window length passed to the pipeline.
        stride_s: left/right context per window in seconds. Capped at
            ``chunk_length_s // 6`` (the same cap ``transcribe_long_audio``
            uses) and floored at 0.

    Returns:
        (text, chunks) where chunks = [{"start": float, "end": float, "text": str}, ...].
        Both the text and each chunk's text are passed through ``_clean_transcript``.
    """
    if not wav_path:
        return "", []

    start_time = time.time()
    logger.info("[ASR] Starting unified transcription for %s", wav_path)

    info = get_audio_info(wav_path) or {}
    # `or 0` also covers an explicit None duration, which %.2f cannot format.
    duration = info.get("duration") or 0
    logger.info("[ASR] Audio duration: %.2fs", duration)

    # Honour the caller's stride_s (it used to be ignored), capped like
    # transcribe_long_audio does; with the defaults (30 s / 5 s) this is the
    # same 5 s of left/right context as before.
    stride = max(0.0, min(stride_s, chunk_length_s // 6))
    out = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=(stride, stride),
        return_timestamps=True,
    )

    # Full text, falling back to the concatenated segment texts.
    text = (out.get("text") or "").strip()
    if not text:
        segs = out.get("chunks") or out.get("segments") or []
        parts = [(s.get("text") or "").strip() for s in segs]
        text = " ".join(p for p in parts if p).strip()

    chunks = _extract_chunks_from_output(out)

    # Clean up ASR artifacts (repeated words/phrases, hallucinations).
    text = _clean_transcript(text)
    for chunk in chunks:
        if chunk.get("text"):
            chunk["text"] = _clean_transcript(chunk["text"])

    elapsed = time.time() - start_time
    logger.info("[ASR] Transcription completed in %.2fs (%.2fx realtime)",
                elapsed, elapsed / duration if duration else 0)
    return text, chunks
def _extract_chunks_from_output(out: dict) -> List[Dict]:
    """
    Extract timestamped segments from an ASR pipeline output dict.

    Segments are read from ``out["chunks"]``, falling back to
    ``out["segments"]``. Each segment may carry its times either as
    ``"timestamp": (start, end)`` or as separate ``"start"`` / ``"end"`` keys.

    A segment is skipped when it is not a dict, its stripped text is empty,
    it has no start time, or its times cannot be converted to float. A
    missing end time (None) falls back to the start time rather than dropping
    the segment, so its text is not silently lost.

    Returns:
        list of {"start": float, "end": float, "text": str}, in input order;
        [] when ``out`` is not a dict.
    """
    if not isinstance(out, dict):
        return []
    raw_segments = out.get("chunks") or out.get("segments") or []
    chunks: List[Dict] = []
    for seg in raw_segments:
        if not isinstance(seg, dict):
            continue
        start = end = None
        ts = seg.get("timestamp")
        if isinstance(ts, (list, tuple)) and len(ts) >= 2:
            start, end = ts[0], ts[1]
        elif seg.get("start") is not None:
            start, end = seg.get("start"), seg.get("end")

        text = (seg.get("text") or "").strip()
        if not text or start is None:
            continue
        if end is None:
            # e.g. the HF Whisper pipeline can leave the last segment open-ended
            # when no closing timestamp is predicted; keep its text as a
            # zero-length segment instead of discarding it.
            end = start
        try:
            chunks.append({"start": float(start), "end": float(end), "text": text})
        except (TypeError, ValueError, OverflowError):
            continue
    return chunks
def transcribe_file(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> str:
    """
    Return only the full transcript text for `wav_path`.

    ⚠️ DEPRECATED: call transcribe_file_unified(), which yields the text and
    the timestamped chunks from the same single inference pass.
    """
    transcript, _unused_chunks = transcribe_file_unified(
        model, wav_path, chunk_length_s=chunk_length_s, stride_s=stride_s
    )
    return transcript
def _safe_audio_info(path: str) -> Dict:
    """Return ``get_audio_info(path)``, or ``{}`` when it returns nothing or raises."""
    try:
        return get_audio_info(path) or {}
    except Exception:
        logger.debug("get_audio_info failed for %s", path, exc_info=True)
        return {}


def _chunk_start_offsets(
    chunk_infos: List[Dict], chunk_length_s: float, overlap_s: float
) -> List[float]:
    """Estimate the start time (seconds) of each chunk within the original audio.

    overlap_s > 0 (fixed windows): consecutive chunks start
    ``chunk_length_s - overlap_s`` seconds apart - assumes split_audio_to_chunks
    advances by exactly that hop; TODO confirm.
    overlap_s <= 0 (VAD chunks): chunks are treated as back-to-back, each one
    starting where the previous ended (its probed duration, or chunk_length_s
    when unknown) - assumes the VAD splitter drops no audio between chunks;
    TODO confirm.
    """
    hop = max(chunk_length_s - overlap_s, 0.0)
    offsets: List[float] = []
    cursor = 0.0
    for info in chunk_infos:
        offsets.append(cursor)
        if overlap_s > 0:
            cursor += hop
            continue
        try:
            dur = float(info.get("duration"))
        except (TypeError, ValueError):
            dur = 0.0
        cursor += dur if dur > 0 else float(chunk_length_s)
    return offsets


def transcribe_long_audio(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    overlap_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
    """
    Split `wav_path` into chunks and run inference on each chunk sequentially.

    Splitting first tries ``app.core.chunking.split_audio_with_vad``; if that
    raises, fixed windows of ``chunk_length_s`` seconds overlapping by
    ``overlap_s`` seconds are produced by ``split_audio_to_chunks``. Chunk
    files are removed afterwards (best effort; a chunk path equal to
    ``wav_path`` is never removed).

    Args:
        model: callable ASR pipeline, e.g. the object returned by ``load_model``.
        wav_path: path to the audio file; empty/None returns ``("", [])``.
        chunk_length_s: fixed-window length; also passed to the pipeline.
        overlap_s: overlap between fixed windows; not applied to VAD chunks.

    Returns:
        (full_text, chunks) where chunks are {"start", "end", "text"} dicts whose
        times are shifted by the estimated start of their chunk (see
        ``_chunk_start_offsets``). All text is passed through ``_clean_transcript``.
    """
    if not wav_path:
        return "", []

    split_start = time.time()
    # Prefer VAD-based splitting if available.
    used_vad = True
    try:
        from app.core.chunking import split_audio_with_vad
        chunk_paths = split_audio_with_vad(wav_path)
        logger.info("[ASR] VAD split into %d chunks in %.2fs", len(chunk_paths), time.time() - split_start)
    except Exception as e:
        used_vad = False
        logger.warning("[ASR] VAD split failed (%s), using fixed windows", e)
        chunk_paths = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
        logger.info("[ASR] Fixed-window split into %d chunks in %.2fs", len(chunk_paths), time.time() - split_start)

    combined_text_parts: List[str] = []
    combined_chunks: List[Dict] = []
    prev_chunk_text = ""  # raw text of the previous chunk, for overlap dedup
    # Small stride (left, right context in seconds) for better timestamp
    # accuracy within each chunk.
    stride = min(5, chunk_length_s // 6)

    try:
        # Probe each chunk once; reused for offsets, debug logging and fallbacks.
        chunk_infos = [_safe_audio_info(cp) for cp in chunk_paths]
        # Overlap only applies to fixed windows; VAD chunks do not overlap.
        offsets = _chunk_start_offsets(chunk_infos, chunk_length_s, 0.0 if used_vad else overlap_s)

        for i, (cp, cinfo, base_offset) in enumerate(zip(chunk_paths, chunk_infos, offsets)):
            if cinfo:
                logger.debug(
                    "chunk[%d]=%s duration=%.3fs samplerate=%s", i, cp, cinfo.get("duration"), cinfo.get("samplerate")
                )
            else:
                logger.debug("chunk[%d]=%s (info unavailable)", i, cp)

            try:
                out = model(
                    cp,
                    chunk_length_s=chunk_length_s,
                    stride_length_s=(stride, stride),
                    return_timestamps=True,
                )
            except Exception:
                logger.exception("model inference failed for chunk %s", cp)
                continue
            if not isinstance(out, dict):
                logger.warning("unexpected model output type for chunk %s: %s", cp, type(out))
                continue
            if i < 5:  # only the first few chunks, to avoid huge logs
                logger.debug("model out keys for chunk[%d]: %s", i, list(out.keys()))

            raw_segs = out.get("chunks") or out.get("segments") or []
            part_text = (out.get("text") or "").strip()
            if not part_text:
                parts = [(s.get("text") or "").strip() for s in raw_segs if isinstance(s, dict)]
                part_text = " ".join(p for p in parts if p).strip()

            if part_text:
                # Drop words repeating the tail of the previous chunk (window overlap).
                deduped_text = _deduplicate_chunks(prev_chunk_text, part_text)
                if deduped_text:
                    combined_text_parts.append(deduped_text)
                prev_chunk_text = part_text  # compare against the raw text next time

            if raw_segs:
                # Same segment parsing as the unified path, shifted to global time.
                for seg in _extract_chunks_from_output(out):
                    seg["start"] += base_offset
                    seg["end"] += base_offset
                    combined_chunks.append(seg)
            elif part_text:
                # Text but no timestamped segments: one chunk spanning the chunk file.
                try:
                    cdur = float(cinfo.get("duration") or chunk_length_s)
                    combined_chunks.append({
                        "start": float(base_offset),
                        "end": float(base_offset) + cdur,
                        "text": part_text,
                    })
                except (TypeError, ValueError):
                    logger.exception("failed to create fallback chunk for %s", cp)
    finally:
        # Defensive: a splitter might return the input path itself (e.g. for
        # short audio) - presumably; never delete the caller's source file.
        source = os.path.abspath(wav_path)
        for p in chunk_paths:
            try:
                if p and os.path.abspath(p) != source and os.path.exists(p):
                    os.remove(p)
            except Exception:
                logger.debug("Failed to remove chunk file %s", p)

    full_text = " ".join(p for p in combined_text_parts if p).strip()
    # Clean up ASR artifacts (repeated words, excessive dots, etc.).
    full_text = _clean_transcript(full_text)
    for chunk in combined_chunks:
        if chunk.get("text"):
            chunk["text"] = _clean_transcript(chunk["text"])
    return full_text, combined_chunks
# ===============================
# Transcribe chunks with timestamps
# ===============================
def transcribe_file_chunks(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> List[Dict]:
    """
    Return only the timestamped chunks for `wav_path`: [{start, end, text}, ...].

    ⚠️ DEPRECATED: call transcribe_file_unified(), which yields the text and
    the timestamped chunks from the same single inference pass.
    """
    unified = transcribe_file_unified(
        model, wav_path, chunk_length_s=chunk_length_s, stride_s=stride_s
    )
    return unified[1]