Michael Hu committed
Commit: 619b266
1 Parent(s): c8d736e
refactor: remove parakeet ASR provider and update all references to Whisper only
- src/application/dtos/processing_request_dto.py +1 -1
- src/application/services/audio_processing_service.py +1 -1
- src/application/services/configuration_service.py +2 -2
- src/domain/interfaces/speech_recognition.py +1 -2
- src/infrastructure/config/app_config.py +2 -2
- src/infrastructure/stt/__init__.py +0 -2
- src/infrastructure/stt/legacy_compatibility.py +2 -2
- src/infrastructure/stt/parakeet_provider.py +0 -168
- src/infrastructure/stt/provider_factory.py +3 -6
src/application/dtos/processing_request_dto.py
CHANGED
@@ -35,7 +35,7 @@ class ProcessingRequestDto:
             raise ValueError("ASR model cannot be empty")

         # Validate ASR model options
-        supported_asr_models = ['
+        supported_asr_models = ['whisper-small', 'whisper-medium', 'whisper-large']
         if self.asr_model not in supported_asr_models:
             raise ValueError(f"Unsupported ASR model: {self.asr_model}. Supported: {supported_asr_models}")

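For context, a minimal sketch of the validation behaviour after this change. Only asr_model and the supported list come from the diff; the surrounding dataclass layout and the __post_init__ placement are assumptions.

    from dataclasses import dataclass

    @dataclass
    class ProcessingRequestDto:
        asr_model: str  # assumed field layout; other request fields omitted

        def __post_init__(self):
            if not self.asr_model:
                raise ValueError("ASR model cannot be empty")
            # Validate ASR model options; 'parakeet' is no longer accepted
            supported_asr_models = ['whisper-small', 'whisper-medium', 'whisper-large']
            if self.asr_model not in supported_asr_models:
                raise ValueError(f"Unsupported ASR model: {self.asr_model}. Supported: {supported_asr_models}")

    ProcessingRequestDto(asr_model='whisper-large')  # accepted
    ProcessingRequestDto(asr_model='parakeet')       # raises ValueError after this commit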
src/application/services/audio_processing_service.py
CHANGED
@@ -634,7 +634,7 @@ class AudioProcessingApplicationService:
             Dict[str, Any]: Supported configurations
         """
         return {
-            'asr_models': ['
+            'asr_models': ['whisper-large'],
             'voices': ['chatterbox'],
             'languages': ['en', 'zh'],
             'audio_formats': self._config.get_processing_config()['supported_audio_formats'],
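The capability map returned by the service now advertises whisper-large only. A hedged usage sketch; the accessor name get_supported_configurations and the guard function are assumptions, while the dictionary keys come from the diff.

    def is_request_supported(service, asr_model: str, language: str) -> bool:
        caps = service.get_supported_configurations()  # assumed accessor for the dict shown above
        # 'parakeet' entries are gone, so only 'whisper-large' passes the asr_models check
        return asr_model in caps['asr_models'] and language in caps['languages']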
src/application/services/configuration_service.py
CHANGED
@@ -331,7 +331,7 @@ class ConfigurationApplicationService:
         Raises:
             ConfigurationException: If validation fails
         """
-        valid_providers = ['whisper'
+        valid_providers = ['whisper']

         for key, value in updates.items():
             if key == 'preferred_providers':

@@ -524,7 +524,7 @@ class ConfigurationApplicationService:

         # Check STT providers
         stt_factory = self._container.resolve(type(self._container._get_stt_factory()))
-        for provider in ['whisper'
+        for provider in ['whisper']:
             try:
                 stt_factory.create_provider(provider)
                 availability['stt'][provider] = True
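The availability check now only probes Whisper. A simplified sketch of the loop's effect, with the container wiring replaced by a plain factory argument and the failure branch assumed for illustration:

    def check_stt_availability(stt_factory) -> dict:
        availability = {'stt': {}}
        for provider in ['whisper']:  # 'parakeet' removed from the probe list
            try:
                stt_factory.create_provider(provider)
                availability['stt'][provider] = True
            except Exception:
                availability['stt'][provider] = False  # assumed handling; not shown in the diff
        return availability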
src/domain/interfaces/speech_recognition.py
CHANGED
@@ -5,7 +5,7 @@ audio content into textual representation. The interface supports multiple STT
 models and providers with consistent error handling.

 The interface is designed to be:
-- Model-agnostic: Works with any STT implementation (Whisper,
+- Model-agnostic: Works with any STT implementation (Whisper, etc.)
 - Language-aware: Handles multiple languages and dialects
 - Error-resilient: Provides detailed error information for debugging
 - Performance-conscious: Supports both batch and streaming transcription

@@ -65,7 +65,6 @@ class ISpeechRecognitionService(ABC):
             model: The STT model identifier to use for transcription. Examples:
                 - "whisper-small": Fast, lower accuracy
                 - "whisper-large": Slower, higher accuracy
-                - "parakeet": Real-time optimized
                 Must be supported by the implementation.

         Returns:
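For reference, a minimal stub that still satisfies the now Whisper-only wording of the interface. The abstract method name transcribe is an assumption; only the model identifiers are taken from the docstring in the diff.

    from abc import ABC, abstractmethod
    from pathlib import Path

    class ISpeechRecognitionService(ABC):
        @abstractmethod
        def transcribe(self, audio_path: Path, model: str) -> str:
            """Transcribe audio with an STT model such as 'whisper-small' or 'whisper-large'."""

    class DummyWhisperService(ISpeechRecognitionService):
        def transcribe(self, audio_path: Path, model: str) -> str:
            # Placeholder only; a real implementation would invoke Whisper here
            return f"[{model}] transcript of {audio_path.name}"

    print(DummyWhisperService().transcribe(Path("clip.wav"), "whisper-small"))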
src/infrastructure/config/app_config.py
CHANGED
@@ -23,8 +23,8 @@ class TTSConfig:
 @dataclass
 class STTConfig:
     """Configuration for STT providers."""
-    preferred_providers: List[str] = field(default_factory=lambda: ['
-    default_model: str = '
+    preferred_providers: List[str] = field(default_factory=lambda: ['whisper'])
+    default_model: str = 'whisper'
     chunk_length_s: int = 30
     batch_size: int = 16
     enable_vad: bool = True
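A quick check of the new defaults; the field names are copied from the diff, and the standalone assertion is only illustrative.

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class STTConfig:
        """Configuration for STT providers."""
        preferred_providers: List[str] = field(default_factory=lambda: ['whisper'])
        default_model: str = 'whisper'
        chunk_length_s: int = 30
        batch_size: int = 16
        enable_vad: bool = True

    cfg = STTConfig()
    assert cfg.preferred_providers == ['whisper'] and cfg.default_model == 'whisper'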
src/infrastructure/stt/__init__.py
CHANGED
@@ -1,13 +1,11 @@
 """STT provider implementations."""

 from .whisper_provider import WhisperSTTProvider
-from .parakeet_provider import ParakeetSTTProvider
 from .provider_factory import STTProviderFactory, ASRFactory
 from .legacy_compatibility import transcribe_audio, create_audio_content_from_file

 __all__ = [
     'WhisperSTTProvider',
-    'ParakeetSTTProvider',
     'STTProviderFactory',
     'ASRFactory',
     'transcribe_audio',
src/infrastructure/stt/legacy_compatibility.py
CHANGED
@@ -11,7 +11,7 @@ from ...domain.exceptions import SpeechRecognitionException
 logger = logging.getLogger(__name__)


-def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet") -> str:
+def transcribe_audio(audio_path: Union[str, Path], model_name: str = "whisper") -> str:
     """
     Convert audio file to text using specified STT model (legacy interface).

@@ -19,7 +19,7 @@ def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet")

     Args:
         audio_path: Path to input audio file
-        model_name: Name of the STT model/provider to use (whisper
+        model_name: Name of the STT model/provider to use (whisper)

     Returns:
         str: Transcribed English text
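Existing callers of the legacy helper keep working if they relied on the default or passed "whisper" explicitly; requesting "parakeet" would now fail in the provider factory. Hypothetical call sites (the audio file name is made up):

    text = transcribe_audio("meeting.wav")                        # default is now "whisper"
    text = transcribe_audio("meeting.wav", model_name="whisper")  # explicit, equivalent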
src/infrastructure/stt/parakeet_provider.py
DELETED
@@ -1,168 +0,0 @@
-"""Parakeet STT provider implementation using Hugging Face Transformers."""
-
-import logging
-import torch
-import librosa
-from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Tuple
-
-if TYPE_CHECKING:
-    from ...domain.models.audio_content import AudioContent
-    from ...domain.models.text_content import TextContent
-
-from ..base.stt_provider_base import STTProviderBase
-from ...domain.exceptions import SpeechRecognitionException
-
-logger = logging.getLogger(__name__)
-
-
-class ParakeetSTTProvider(STTProviderBase):
-    """Parakeet STT provider using Hugging Face Transformers CTC model."""
-
-    def __init__(self):
-        """Initialize the Parakeet STT provider."""
-        super().__init__(
-            provider_name="Parakeet",
-            supported_languages=["en"]  # Parakeet primarily supports English
-        )
-        self.model = None
-        self.processor = None
-        self.current_model_name = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    def _perform_transcription(self, audio_path: Path, model: str) -> str:
-        """
-        Perform transcription using Parakeet CTC model.
-
-        Args:
-            audio_path: Path to the preprocessed audio file
-            model: The Parakeet model to use
-
-        Returns:
-            str: The transcribed text
-        """
-        try:
-            # Load model if not already loaded or if different model requested
-            if self.model is None or self.current_model_name != model:
-                self._load_model(model)
-
-            logger.info(f"Starting Parakeet transcription with model {model}")
-
-            # Load and preprocess audio
-            audio_array, sample_rate = self._load_audio(audio_path)
-
-            # Process audio with the processor
-            inputs = self.processor(
-                audio_array,
-                sampling_rate=sample_rate,
-                return_tensors="pt"
-            )
-
-            inputs.to(self.device, dtype="auto")
-
-            # Decode the predictions
-            outputs = this.model.generate(**inputs)
-            transcription = self.processor.batch_decode(outputs)
-
-            logger.info("Parakeet transcription completed successfully")
-            return transcription
-
-        except Exception as e:
-            self._handle_provider_error(e, "transcription")
-
-    def _load_model(self, model_name: str):
-        """
-        Load the Parakeet model using Hugging Face Transformers.
-
-        Args:
-            model_name: Name of the model to load
-        """
-        try:
-            from transformers import AutoProcessor, AutoModelForCTC
-
-            logger.info(f"Loading Parakeet model: {model_name}")
-
-            # Map model names to actual model identifiers
-            model_mapping = {
-                "parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
-                "default": "nvidia/parakeet-ctc-0.6b"
-            }
-
-            actual_model_name = model_mapping.get(model_name, model_mapping["default"])
-
-            # Load processor and model
-            self.processor = AutoProcessor.from_pretrained(actual_model_name)
-            self.model = AutoModelForCTC.from_pretrained(actual_model_name, dtype="auto", device_map=self.device)
-            self.current_model_name = model_name
-            logger.info(f"Parakeet processor {processor}")
-            logger.info(f"Parakeet model {model}")
-
-            # Set model to evaluation mode
-            self.model.eval()
-
-            logger.info(f"Parakeet model {model_name} loaded successfully")
-
-        except ImportError as e:
-            raise SpeechRecognitionException(
-                "transformers library not available. Please install with: pip install transformers[audio]"
-            ) from e
-        except Exception as e:
-            raise SpeechRecognitionException(f"Failed to load Parakeet model {model_name}: {str(e)}") from e
-
-    def _load_audio(self, audio_path: Path) -> Tuple[torch.Tensor, int]:
-        """
-        Load audio file and return as tensor with sample rate.
-
-        Args:
-            audio_path: Path to the audio file
-
-        Returns:
-            Tuple[torch.Tensor, int]: Audio tensor and sample rate
-        """
-        try:
-            # Load audio using librosa
-            audio_array, sample_rate = librosa.load(str(audio_path), sr=None)
-
-            # Convert to torch tensor
-            audio_tensor = torch.from_numpy(audio_array).float()
-
-            return audio_tensor, sample_rate
-
-        except Exception as e:
-            raise SpeechRecognitionException(f"Failed to load audio file {audio_path}: {str(e)}") from e
-
-    def is_available(self) -> bool:
-        """
-        Check if the Parakeet provider is available.
-
-        Returns:
-            bool: True if transformers and required libraries are available, False otherwise
-        """
-        try:
-            from transformers import AutoProcessor, AutoModelForCTC
-            import torch
-            import librosa
-            return True
-        except ImportError:
-            logger.warning("Required libraries (transformers, torch, librosa) not available")
-            return False
-
-    def get_available_models(self) -> list[str]:
-        """
-        Get list of available Parakeet models.
-
-        Returns:
-            list[str]: List of available model names
-        """
-        return [
-            "parakeet-ctc-0.6b"
-        ]
-
-    def get_default_model(self) -> str:
-        """
-        Get the default model for this provider.
-
-        Returns:
-            str: Default model name
-        """
-        return "parakeet-ctc-0.6b"
src/infrastructure/stt/provider_factory.py
CHANGED
@@ -5,7 +5,6 @@ from typing import Dict, Type, Optional

 from ..base.stt_provider_base import STTProviderBase
 from .whisper_provider import WhisperSTTProvider
-from .parakeet_provider import ParakeetSTTProvider
 from ...domain.exceptions import SpeechRecognitionException

 logger = logging.getLogger(__name__)

@@ -15,11 +14,10 @@ class STTProviderFactory:
     """Factory for creating STT provider instances with availability checking and fallback logic."""

     _providers: Dict[str, Type[STTProviderBase]] = {
-        "whisper": WhisperSTTProvider,
-        "parakeet": ParakeetSTTProvider
+        "whisper": WhisperSTTProvider
     }

-    _fallback_order = ["whisper"
+    _fallback_order = ["whisper"]

     @classmethod
     def create_provider(cls, provider_name: str) -> STTProviderBase:

@@ -162,7 +160,7 @@ class ASRFactory:
     """Legacy ASRFactory for backward compatibility."""

     @staticmethod
-    def get_model(model_name: str = "
+    def get_model(model_name: str = "whisper") -> STTProviderBase:
        """
        Get STT provider by model name (legacy interface).

@@ -175,7 +173,6 @@ class ASRFactory:
        # Map legacy model names to provider names
        provider_mapping = {
            "whisper": "whisper",
-           "parakeet": "parakeet",
            "faster-whisper": "whisper"
        }

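A hedged sketch of how provider resolution looks after the cleanup. STTProviderFactory, create_provider, _fallback_order and SpeechRecognitionException appear in the diff; the import paths, the wrapper function, and the exception raised for unknown providers are assumptions.

    from src.infrastructure.stt.provider_factory import STTProviderFactory  # assumed package root
    from src.domain.exceptions import SpeechRecognitionException            # assumed import path

    def resolve_stt_provider(requested: str):
        try:
            return STTProviderFactory.create_provider(requested)  # only "whisper" is registered now
        except SpeechRecognitionException:
            # _fallback_order is just ["whisper"] after this commit
            return STTProviderFactory.create_provider("whisper")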