"""
Base STT Implementation
=======================
Common audio processing and validation for all STT providers
"""
import struct
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, Tuple, List

from .stt_interface import STTInterface, STTConfig, TranscriptionResult
from utils.logger import log_info, log_error, log_debug, log_warning
class STTBase(STTInterface, ABC):
    """Base class for all STT implementations with common audio processing.

    ``transcribe`` is a template method: it validates the input, analyses the
    audio, runs the overridable ``_preprocess_audio`` hook and then delegates
    to the provider-specific ``_transcribe_impl``.

    Audio buffers are decoded with the ``struct`` format ``h``, i.e. as 16-bit
    signed samples in native byte order.
    """

    # Size in bytes of one sample decoded with struct format 'h'.
    _BYTES_PER_SAMPLE = 2
    # Number of slices reported in ``AudioAnalysis.sections``.
    _SECTION_COUNT = 10
    # Audio whose peak amplitude is below this value is rejected as silent.
    _MIN_PEAK_AMPLITUDE = 100
    # Audio with a larger fraction of exact-zero samples is rejected.
    _MAX_ZERO_RATIO = 0.95
    # Window length, in samples, used for RMS-based speech detection.
    _SPEECH_WINDOW = 100
    # Samples kept on each side of the detected signal when trimming silence.
    _TRIM_PADDING = 100

    def __init__(self):
        super().__init__()

    async def transcribe(self, audio_data: bytes, config: STTConfig) -> Optional[TranscriptionResult]:
        """Validate, analyse and preprocess audio, then transcribe it.

        Args:
            audio_data: Raw audio buffer.
            config: Provider configuration; ``config.sample_rate`` is passed
                to the audio analysis.

        Returns:
            Whatever ``_transcribe_impl`` returns, or ``None`` when the input
            is empty, fails validation, or any exception is raised.
        """
        try:
            # 1. Validate input
            if not audio_data:
                log_warning("⚠️ No audio data provided")
                return None

            log_info(f"📊 Transcribing {len(audio_data)} bytes of audio")

            # 2. Analyze and validate audio
            analysis_result = self._analyze_audio(audio_data, config.sample_rate)
            if not analysis_result.is_valid:
                log_warning(f"⚠️ Audio validation failed: {analysis_result.reason}")
                return None

            # 3. Preprocess audio if needed
            processed_audio = self._preprocess_audio(audio_data, config)

            # 4. Call provider-specific implementation
            return await self._transcribe_impl(processed_audio, config, analysis_result)
        except Exception as e:
            # Boundary handler: callers receive None rather than a provider
            # exception, but the failure is logged with its traceback.
            import traceback
            log_error(f"❌ Error during transcription: {e}")
            log_error(f"Traceback: {traceback.format_exc()}")
            return None

    @abstractmethod
    async def _transcribe_impl(self, audio_data: bytes, config: STTConfig, analysis: 'AudioAnalysis') -> Optional[TranscriptionResult]:
        """Provider-specific transcription implementation.

        Declared abstract so that a provider which forgets to override it
        fails at instantiation instead of silently returning ``None``.

        Args:
            audio_data: Audio as returned by ``_preprocess_audio``.
            config: Provider configuration passed to ``transcribe``.
            analysis: Result of ``_analyze_audio`` for the original buffer.
        """

    @staticmethod
    def _unpack_samples(audio_data: bytes) -> Tuple[int, ...]:
        """Decode a buffer into 16-bit samples, ignoring a dangling odd byte."""
        sample_count = len(audio_data) // STTBase._BYTES_PER_SAMPLE
        usable = audio_data[:sample_count * STTBase._BYTES_PER_SAMPLE]
        return struct.unpack(f'{sample_count}h', usable)

    @staticmethod
    def _amplitude_stats(samples) -> Tuple[int, float, int]:
        """Return ``(zero_count, avg_abs, max_abs)``; avg/max cover non-zero samples only."""
        magnitudes = [abs(s) for s in samples if s != 0]
        zero_count = len(samples) - len(magnitudes)
        if not magnitudes:
            return zero_count, 0.0, 0
        return zero_count, sum(magnitudes) / len(magnitudes), max(magnitudes)

    def _analyze_sections(self, samples) -> List[dict]:
        """Split ``samples`` into ``_SECTION_COUNT`` slices and log/collect stats for each.

        The last slice absorbs the remainder.  With fewer samples than
        sections the leading slices are empty; they report zeros instead of
        dividing by their (zero) length.
        """
        total_samples = len(samples)
        section_size = total_samples // self._SECTION_COUNT
        sections = []
        for index in range(self._SECTION_COUNT):
            start_idx = index * section_size
            is_last = index == self._SECTION_COUNT - 1
            end_idx = total_samples if is_last else start_idx + section_size
            section = samples[start_idx:end_idx]

            zeros, section_avg, section_max = self._amplitude_stats(section)
            zero_ratio = zeros / len(section) if section else 0.0

            sections.append({
                'max': section_max,
                'avg': section_avg,
                'zero_ratio': zero_ratio,
            })
            log_info(f" Section {index + 1}: max={section_max}, avg={section_avg:.1f}, zeros={zero_ratio:.1%}")
        return sections

    def _analyze_audio(self, audio_data: bytes, sample_rate: int) -> 'AudioAnalysis':
        """Analyze audio quality and content.

        Args:
            audio_data: Raw audio buffer; a trailing odd byte is ignored.
            sample_rate: Samples per second, used to convert the speech start
                index into seconds.

        Returns:
            An ``AudioAnalysis``.  ``is_valid`` is ``False`` (with ``reason``
            set) when the buffer holds no complete sample, the peak amplitude
            is below ``_MIN_PEAK_AMPLITUDE``, more than ``_MAX_ZERO_RATIO`` of
            the samples are zero, no speech window is found, or the analysis
            itself raises.
        """
        try:
            samples = self._unpack_samples(audio_data)
            total_samples = len(samples)
            if total_samples == 0:
                # Nothing to measure; bail out before any ratio is computed.
                return AudioAnalysis(
                    total_samples=0,
                    sample_rate=sample_rate,
                    is_valid=False,
                    reason="Audio contains no complete samples"
                )

            # Basic statistics
            zero_count, avg_amplitude, max_amplitude = self._amplitude_stats(samples)
            zero_fraction = zero_count / total_samples
            log_info(f"🔍 Audio stats: {total_samples} total samples, {zero_count} zeros ({zero_fraction:.1%})")
            log_info(f"🔍 Non-zero stats: avg={avg_amplitude:.1f}, max={max_amplitude}")

            # Section analysis
            sections = self._analyze_sections(samples)

            # Find speech start
            speech_start_idx = self._find_speech_start(samples, sample_rate)
            if speech_start_idx >= 0:
                speech_start_time = speech_start_idx / sample_rate
                log_info(f"🎤 Speech detected starting at sample {speech_start_idx} ({speech_start_time:.2f}s)")
            else:
                speech_start_time = -1.0
                log_warning("⚠️ No speech detected above threshold in entire audio")

            # Validation: the first failing check determines the reason.
            is_valid = True
            reason = ""
            if max_amplitude < self._MIN_PEAK_AMPLITUDE:
                is_valid = False
                reason = f"Audio appears silent: max_amplitude={max_amplitude}"
            elif zero_fraction > self._MAX_ZERO_RATIO:
                is_valid = False
                reason = f"Audio is mostly zeros: {zero_fraction:.1%}"
            elif speech_start_idx < 0:
                is_valid = False
                reason = "No speech detected"

            return AudioAnalysis(
                total_samples=total_samples,
                sample_rate=sample_rate,
                zero_count=zero_count,
                avg_amplitude=avg_amplitude,
                max_amplitude=max_amplitude,
                sections=sections,
                speech_start_idx=speech_start_idx,
                speech_start_time=speech_start_time,
                is_valid=is_valid,
                reason=reason
            )
        except Exception as e:
            # Best effort: an analysis failure is reported as invalid audio.
            log_error(f"Audio analysis failed: {e}")
            return AudioAnalysis(
                total_samples=0,
                sample_rate=sample_rate,
                is_valid=False,
                reason=f"Analysis failed: {e}"
            )

    def _find_speech_start(self, samples: List[int], sample_rate: int, threshold: int = 500) -> int:
        """Find the starting point of speech in audio.

        Scans consecutive, non-overlapping windows of ``_SPEECH_WINDOW``
        samples and returns the start index of the first one whose RMS exceeds
        ``threshold``.  Only full windows are examined; a shorter tail is
        ignored.

        Args:
            samples: Decoded audio samples.
            sample_rate: Unused here; kept for interface compatibility.
            threshold: RMS level a window must exceed to count as speech.

        Returns:
            Index of the first sample of the matching window, or ``-1``.
        """
        window_size = self._SPEECH_WINDOW
        # "+ 1" so that the final full window is also examined.
        last_start = len(samples) - window_size + 1
        for start in range(0, last_start, window_size):
            window = samples[start:start + window_size]
            rms = (sum(s * s for s in window) / window_size) ** 0.5
            if rms > threshold:
                return start
        return -1

    def _preprocess_audio(self, audio_data: bytes, config: STTConfig) -> bytes:
        """Preprocess audio if needed (can be overridden by providers)."""
        # Default: no preprocessing
        return audio_data

    def _clean_audio_silence(self, audio_data: bytes, threshold: int = 50) -> bytes:
        """Remove leading/trailing silence.

        Args:
            audio_data: Raw audio buffer; a trailing odd byte is ignored.
            threshold: Absolute amplitude a sample must exceed to count as
                non-silent.

        Returns:
            The span from the first to the last non-silent sample, extended by
            up to ``_TRIM_PADDING`` samples on each side.  The original buffer
            is returned unchanged when no sample exceeds ``threshold`` or when
            cleaning fails.
        """
        try:
            samples = self._unpack_samples(audio_data)

            first_loud = next(
                (i for i, sample in enumerate(samples) if abs(sample) > threshold),
                None,
            )
            if first_loud is None:
                # Nothing above the threshold: there is no signal to trim around.
                log_debug(f"Audio cleaning: {len(audio_data)} → {len(audio_data)} bytes")
                return audio_data

            # A loud sample exists, so the reverse scan is guaranteed to find one.
            last_loud = next(
                i for i in range(len(samples) - 1, -1, -1)
                if abs(samples[i]) > threshold
            )

            # Add padding
            start_idx = max(0, first_loud - self._TRIM_PADDING)
            end_idx = min(len(samples) - 1, last_loud + self._TRIM_PADDING)

            # Slice the original bytes directly; re-packing the unpacked
            # samples would yield exactly the same bytes.
            width = self._BYTES_PER_SAMPLE
            cleaned_audio = bytes(audio_data[start_idx * width:(end_idx + 1) * width])

            log_debug(f"Audio cleaning: {len(audio_data)} → {len(cleaned_audio)} bytes")
            return cleaned_audio
        except Exception as e:
            log_warning(f"Audio cleaning failed: {e}, using original")
            return audio_data
@dataclass
class AudioAnalysis:
    """Audio analysis results.

    Declared as a dataclass so that it can be built with keyword arguments and
    so that ``sections`` gets a fresh list per instance.
    """
    # Number of decoded samples in the analysed buffer.
    total_samples: int = 0
    # Samples per second of the analysed audio.
    sample_rate: int = 16000
    # Number of samples that are exactly zero.
    zero_count: int = 0
    # Mean absolute amplitude over the non-zero samples.
    avg_amplitude: float = 0.0
    # Largest absolute amplitude over the non-zero samples.
    max_amplitude: int = 0
    # Per-section statistics: dicts with 'max', 'avg' and 'zero_ratio' keys.
    sections: List[dict] = field(default_factory=list)
    # Index of the first sample detected as speech, or -1 when none was found.
    speech_start_idx: int = -1
    # ``speech_start_idx`` expressed in seconds, or -1.0 when none was found.
    speech_start_time: float = -1.0
    # Whether the audio passed validation.
    is_valid: bool = False
    # Explanation of why validation failed; empty when valid.
    reason: str = ""