# Source: flare/stt/stt_base.py
# Author: ciyidogan
# Commit: 0b9eed5 ("Create stt_base.py", verified)
"""
Base STT Implementation
======================
Common audio processing and validation for all STT providers
"""
import struct
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, Tuple, List

from utils.logger import log_info, log_error, log_debug, log_warning
from .stt_interface import STTInterface, STTConfig, TranscriptionResult
class STTBase(STTInterface, ABC):
    """Base class for STT providers with shared audio validation.

    ``transcribe`` checks and analyzes the raw audio, lets subclasses
    preprocess it, and only then delegates to the provider-specific
    ``_transcribe_impl``. Audio bytes are decoded as 16-bit signed integers
    in the platform's native byte order (``struct`` format ``'h'``).
    """

    def __init__(self):
        super().__init__()

    async def transcribe(self, audio_data: bytes, config: STTConfig) -> Optional[TranscriptionResult]:
        """Validate and analyze audio, then delegate to the provider.

        Args:
            audio_data: Raw audio bytes, decoded as 16-bit signed samples.
            config: Provider configuration; ``config.sample_rate`` is used to
                convert the detected speech start into seconds.

        Returns:
            The provider's ``TranscriptionResult``. Returns ``None`` when the
            input is empty or fails validation. Also returns ``None`` when any
            step raises; those errors are logged, not propagated.
        """
        try:
            # 1. Validate input
            if not audio_data:
                log_warning("⚠️ No audio data provided")
                return None

            log_info(f"📊 Transcribing {len(audio_data)} bytes of audio")

            # 2. Analyze and validate audio
            analysis_result = self._analyze_audio(audio_data, config.sample_rate)
            if not analysis_result.is_valid:
                log_warning(f"⚠️ Audio validation failed: {analysis_result.reason}")
                return None

            # 3. Preprocess audio (no-op unless a provider overrides it)
            processed_audio = self._preprocess_audio(audio_data, config)

            # 4. Call provider-specific implementation
            return await self._transcribe_impl(processed_audio, config, analysis_result)

        except Exception as e:
            import traceback
            log_error(f"❌ Error during transcription: {str(e)}")
            log_error(f"Traceback: {traceback.format_exc()}")
            return None

    @abstractmethod
    async def _transcribe_impl(self, audio_data: bytes, config: STTConfig, analysis: 'AudioAnalysis') -> Optional[TranscriptionResult]:
        """Provider-specific transcription implementation.

        Args:
            audio_data: Audio as returned by ``_preprocess_audio``.
            config: The configuration passed to ``transcribe``.
            analysis: The ``_analyze_audio`` result for the raw audio; it has
                already passed validation when ``transcribe`` calls this.
        """
        pass

    @staticmethod
    def _unpack_samples(audio_data: bytes) -> Tuple[int, ...]:
        """Decode bytes as native-order 16-bit signed samples.

        A trailing odd byte cannot form a complete sample. It is ignored
        instead of making ``struct.unpack`` raise on the size mismatch.
        """
        sample_count = len(audio_data) // 2
        return struct.unpack(f'{sample_count}h', audio_data[:sample_count * 2])

    @staticmethod
    def _analyze_sections(samples, section_count: int = 10) -> List[dict]:
        """Compute per-section stats over ``section_count`` contiguous slices.

        Each dict has ``'max'`` and ``'avg'`` (absolute amplitude of the
        section's non-zero samples) and ``'zero_ratio'``. The last section
        absorbs any remainder. Empty sections report zeros instead of raising
        ZeroDivisionError; they occur when there are fewer samples than
        sections.
        """
        total = len(samples)
        size = total // section_count
        sections = []
        for i in range(section_count):
            start = i * size
            end = (i + 1) * size if i < section_count - 1 else total
            section = samples[start:end]
            magnitudes = [abs(s) for s in section if s != 0]

            section_max = max(magnitudes, default=0)
            section_avg = sum(magnitudes) / len(magnitudes) if magnitudes else 0
            zero_ratio = (len(section) - len(magnitudes)) / len(section) if section else 0.0

            sections.append({
                'max': section_max,
                'avg': section_avg,
                'zero_ratio': zero_ratio
            })
            log_info(f" Section {i+1}: max={section_max}, avg={section_avg:.1f}, zeros={zero_ratio:.1%}")
        return sections

    def _analyze_audio(self, audio_data: bytes, sample_rate: int) -> 'AudioAnalysis':
        """Compute amplitude statistics, per-section stats and speech onset.

        The audio is marked invalid in any of these cases:
        - its peak absolute amplitude is below 100;
        - more than 95% of its samples are exactly zero;
        - ``_find_speech_start`` finds no window above its threshold.

        Any exception is logged and reported as an invalid analysis rather
        than raised.

        Args:
            audio_data: Raw audio bytes (16-bit signed samples).
            sample_rate: Samples per second, used to convert the speech-start
                index into seconds.

        Returns:
            An ``AudioAnalysis`` whose ``is_valid``/``reason`` carry the verdict.
        """
        try:
            samples = self._unpack_samples(audio_data)
            total_samples = len(samples)
            if total_samples == 0:
                # Fewer than two bytes means there is no complete sample.
                # Reporting it here avoids dividing by total_samples below.
                return AudioAnalysis(
                    total_samples=0,
                    sample_rate=sample_rate,
                    is_valid=False,
                    reason="Audio contains no complete 16-bit samples"
                )

            # Basic statistics (amplitudes are taken over non-zero samples only)
            magnitudes = [abs(s) for s in samples if s != 0]
            zero_count = total_samples - len(magnitudes)
            avg_amplitude = sum(magnitudes) / len(magnitudes) if magnitudes else 0
            max_amplitude = max(magnitudes, default=0)

            log_info(f"🔍 Audio stats: {total_samples} total samples, {zero_count} zeros ({zero_count/total_samples:.1%})")
            log_info(f"🔍 Non-zero stats: avg={avg_amplitude:.1f}, max={max_amplitude}")

            # Section analysis (10 sections)
            sections = self._analyze_sections(samples)

            # Find speech start
            speech_start_idx = self._find_speech_start(samples, sample_rate)
            speech_start_time = speech_start_idx / sample_rate if speech_start_idx >= 0 else -1

            if speech_start_idx >= 0:
                log_info(f"🎀 Speech detected starting at sample {speech_start_idx} ({speech_start_time:.2f}s)")
            else:
                log_warning("⚠️ No speech detected above threshold in entire audio")

            # Validation (first failing check wins)
            is_valid = True
            reason = ""
            if max_amplitude < 100:
                is_valid = False
                reason = f"Audio appears silent: max_amplitude={max_amplitude}"
            elif zero_count / total_samples > 0.95:
                is_valid = False
                reason = f"Audio is mostly zeros: {zero_count/total_samples:.1%}"
            elif speech_start_idx < 0:
                is_valid = False
                reason = "No speech detected"

            return AudioAnalysis(
                total_samples=total_samples,
                sample_rate=sample_rate,
                zero_count=zero_count,
                avg_amplitude=avg_amplitude,
                max_amplitude=max_amplitude,
                sections=sections,
                speech_start_idx=speech_start_idx,
                speech_start_time=speech_start_time,
                is_valid=is_valid,
                reason=reason
            )

        except Exception as e:
            log_error(f"Audio analysis failed: {e}")
            return AudioAnalysis(
                total_samples=0,
                sample_rate=sample_rate,
                is_valid=False,
                reason=f"Analysis failed: {e}"
            )

    def _find_speech_start(self, samples: List[int], sample_rate: int, threshold: int = 500) -> int:
        """Return the start index of the first loud 100-sample window.

        Windows are non-overlapping and start at multiples of 100. A window
        is loud when its RMS exceeds ``threshold``. A trailing partial window
        is not examined.

        Returns:
            The window's start index, or -1 if no window qualifies.

        ``sample_rate`` is unused here; it is kept for interface compatibility.
        """
        window_size = 100
        # "+ 1" makes the final full window eligible; without it, the last
        # complete window (or the only one, for exactly 100 samples) was skipped.
        for start in range(0, len(samples) - window_size + 1, window_size):
            window = samples[start:start + window_size]
            rms = (sum(s * s for s in window) / window_size) ** 0.5
            if rms > threshold:
                return start
        return -1

    def _preprocess_audio(self, audio_data: bytes, config: STTConfig) -> bytes:
        """Preprocess audio if needed (can be overridden by providers)"""
        # Default: no preprocessing
        return audio_data

    def _clean_audio_silence(self, audio_data: bytes, threshold: int = 50) -> bytes:
        """Trim leading/trailing silence, keeping 100 samples of padding per side.

        A sample is silent when its absolute value is <= ``threshold``. If no
        sample is louder than that, the whole audio is kept. A trailing odd
        byte is dropped, since it is not a complete sample. On any error, the
        original bytes are returned (best effort).
        """
        try:
            samples = self._unpack_samples(audio_data)
            last_index = len(samples) - 1

            # First / last non-silent samples; default to the full range.
            first_loud = next((i for i, s in enumerate(samples) if abs(s) > threshold), 0)
            last_loud = next(
                (i for i in range(last_index, -1, -1) if abs(samples[i]) > threshold),
                last_index
            )

            # Add padding, clamped to the buffer
            start_idx = max(0, first_loud - 100)
            end_idx = min(last_index, last_loud + 100)

            # Slicing the original bytes yields the same result as re-packing
            # the sample slice, without a decode/encode roundtrip.
            cleaned_audio = audio_data[start_idx * 2:(end_idx + 1) * 2]

            log_debug(f"Audio cleaning: {len(audio_data)} → {len(cleaned_audio)} bytes")
            return cleaned_audio

        except Exception as e:
            log_warning(f"Audio cleaning failed: {e}, using original")
            return audio_data
@dataclass
class AudioAnalysis:
    """Audio analysis results, as populated by ``STTBase._analyze_audio``.

    All fields have defaults, so a failed analysis can be reported with only
    ``sample_rate``, ``is_valid`` and ``reason`` set.
    """
    # Number of 16-bit samples decoded from the audio bytes.
    total_samples: int = 0
    # Samples per second; used to convert speech_start_idx into seconds.
    sample_rate: int = 16000
    # Count of samples whose value is exactly zero.
    zero_count: int = 0
    # Mean absolute amplitude over the non-zero samples (0 when there are none).
    avg_amplitude: float = 0.0
    # Peak absolute amplitude over the non-zero samples (0 when there are none).
    max_amplitude: int = 0
    # Per-section stats dicts with keys 'max', 'avg' and 'zero_ratio'.
    sections: List[dict] = field(default_factory=list)
    # Index of the first sample of detected speech, or -1 if none was found.
    speech_start_idx: int = -1
    # speech_start_idx expressed in seconds, or -1.0 if none was found.
    speech_start_time: float = -1.0
    # Whether the audio passed validation; reason explains a failure.
    is_valid: bool = False
    reason: str = ""