Michael Hu committed
Commit 619b266 · 1 Parent(s): c8d736e

refactor: remove parakeet ASR provider and update all references to Whisper only

src/application/dtos/processing_request_dto.py CHANGED
@@ -35,7 +35,7 @@ class ProcessingRequestDto:
             raise ValueError("ASR model cannot be empty")
 
         # Validate ASR model options
-        supported_asr_models = ['parakeet', 'whisper-small', 'whisper-medium', 'whisper-large']
+        supported_asr_models = ['whisper-small', 'whisper-medium', 'whisper-large']
         if self.asr_model not in supported_asr_models:
            raise ValueError(f"Unsupported ASR model: {self.asr_model}. Supported: {supported_asr_models}")
 
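Note: a minimal, self-contained sketch of how the narrowed validation behaves after this change. The dataclass below is a stand-in that reproduces only the asr_model field and the check shown in the hunk; the real DTO has more fields, and placing the check in __post_init__ is an assumption for illustration.

    from dataclasses import dataclass

    @dataclass
    class ProcessingRequestDto:  # simplified stand-in; the real DTO carries additional fields
        asr_model: str

        def __post_init__(self):  # validation shown in the hunk; placement here is assumed
            if not self.asr_model:
                raise ValueError("ASR model cannot be empty")
            # Validate ASR model options ('parakeet' is no longer accepted)
            supported_asr_models = ['whisper-small', 'whisper-medium', 'whisper-large']
            if self.asr_model not in supported_asr_models:
                raise ValueError(f"Unsupported ASR model: {self.asr_model}. Supported: {supported_asr_models}")

    ProcessingRequestDto(asr_model='whisper-large')  # accepted
    ProcessingRequestDto(asr_model='parakeet')       # raises ValueError after this commit
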
src/application/services/audio_processing_service.py CHANGED
@@ -634,7 +634,7 @@ class AudioProcessingApplicationService:
             Dict[str, Any]: Supported configurations
         """
         return {
-            'asr_models': ['parakeet', 'whisper-large'],
+            'asr_models': ['whisper-large'],
             'voices': ['chatterbox'],
             'languages': ['en', 'zh'],
             'audio_formats': self._config.get_processing_config()['supported_audio_formats'],
src/application/services/configuration_service.py CHANGED
@@ -331,7 +331,7 @@ class ConfigurationApplicationService:
         Raises:
             ConfigurationException: If validation fails
         """
-        valid_providers = ['whisper', 'parakeet']
+        valid_providers = ['whisper']
 
         for key, value in updates.items():
             if key == 'preferred_providers':
@@ -524,7 +524,7 @@ class ConfigurationApplicationService:
 
         # Check STT providers
         stt_factory = self._container.resolve(type(self._container._get_stt_factory()))
-        for provider in ['whisper', 'parakeet']:
+        for provider in ['whisper']:
             try:
                 stt_factory.create_provider(provider)
                 availability['stt'][provider] = True
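Note: the availability probe in the second hunk now only exercises the whisper provider. A rough sketch of the resulting structure, assuming create_provider raises when a provider cannot be constructed (error handling simplified; names taken from the hunk):

    availability = {'stt': {}}
    for provider in ['whisper']:  # 'parakeet' is no longer probed
        try:
            stt_factory.create_provider(provider)
            availability['stt'][provider] = True
        except Exception:
            availability['stt'][provider] = False
    # e.g. availability == {'stt': {'whisper': True}} when the Whisper dependencies are installed
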
src/domain/interfaces/speech_recognition.py CHANGED
@@ -5,7 +5,7 @@ audio content into textual representation. The interface supports multiple STT
 models and providers with consistent error handling.
 
 The interface is designed to be:
-- Model-agnostic: Works with any STT implementation (Whisper, Parakeet, etc.)
+- Model-agnostic: Works with any STT implementation (Whisper, etc.)
 - Language-aware: Handles multiple languages and dialects
 - Error-resilient: Provides detailed error information for debugging
 - Performance-conscious: Supports both batch and streaming transcription
@@ -65,7 +65,6 @@ class ISpeechRecognitionService(ABC):
             model: The STT model identifier to use for transcription. Examples:
                 - "whisper-small": Fast, lower accuracy
                 - "whisper-large": Slower, higher accuracy
-                - "parakeet": Real-time optimized
                 Must be supported by the implementation.
 
         Returns:
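Note: for orientation only, a toy provider conforming to the documented contract. The abstract method name transcribe and its signature are illustrative assumptions; the actual abstract methods of ISpeechRecognitionService are not shown in this diff.

    from abc import ABC, abstractmethod
    from pathlib import Path

    class ISpeechRecognitionService(ABC):  # simplified stand-in for the real interface
        @abstractmethod
        def transcribe(self, audio_path: Path, model: str) -> str:
            """Transcribe audio using an STT model identifier such as 'whisper-small' or 'whisper-large'."""

    class EchoWhisperService(ISpeechRecognitionService):  # hypothetical implementation for illustration
        def transcribe(self, audio_path: Path, model: str = "whisper-large") -> str:
            # A real provider would run Whisper here; this stub only demonstrates the contract.
            return f"<transcript of {audio_path.name} via {model}>"
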
src/infrastructure/config/app_config.py CHANGED
@@ -23,8 +23,8 @@ class TTSConfig:
 @dataclass
 class STTConfig:
     """Configuration for STT providers."""
-    preferred_providers: List[str] = field(default_factory=lambda: ['parakeet', 'whisper'])
-    default_model: str = 'parakeet'
+    preferred_providers: List[str] = field(default_factory=lambda: ['whisper'])
+    default_model: str = 'whisper'
     chunk_length_s: int = 30
     batch_size: int = 16
     enable_vad: bool = True
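Note: a quick standalone check of the new STT defaults, reproducing only the fields shown in the hunk above:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class STTConfig:
        """Configuration for STT providers (fields as shown above)."""
        preferred_providers: List[str] = field(default_factory=lambda: ['whisper'])
        default_model: str = 'whisper'
        chunk_length_s: int = 30
        batch_size: int = 16
        enable_vad: bool = True

    cfg = STTConfig()
    assert cfg.preferred_providers == ['whisper']  # 'parakeet' removed from the provider list
    assert cfg.default_model == 'whisper'          # default switched away from 'parakeet'
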
src/infrastructure/stt/__init__.py CHANGED
@@ -1,13 +1,11 @@
 """STT provider implementations."""
 
 from .whisper_provider import WhisperSTTProvider
-from .parakeet_provider import ParakeetSTTProvider
 from .provider_factory import STTProviderFactory, ASRFactory
 from .legacy_compatibility import transcribe_audio, create_audio_content_from_file
 
 __all__ = [
     'WhisperSTTProvider',
-    'ParakeetSTTProvider',
     'STTProviderFactory',
     'ASRFactory',
     'transcribe_audio',
src/infrastructure/stt/legacy_compatibility.py CHANGED
@@ -11,7 +11,7 @@ from ...domain.exceptions import SpeechRecognitionException
 logger = logging.getLogger(__name__)
 
 
-def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet") -> str:
+def transcribe_audio(audio_path: Union[str, Path], model_name: str = "whisper") -> str:
     """
     Convert audio file to text using specified STT model (legacy interface).
 
@@ -19,7 +19,7 @@ def transcribe_audio(audio_path: Union[str, Path], model_name: str = "parakeet")
 
     Args:
         audio_path: Path to input audio file
-        model_name: Name of the STT model/provider to use (whisper or parakeet)
+        model_name: Name of the STT model/provider to use (whisper)
 
     Returns:
         str: Transcribed English text
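Note: a minimal usage sketch of the legacy helper after this change. The import path assumes the repository root is on PYTHONPATH, and the audio file name is illustrative:

    from src.infrastructure.stt import transcribe_audio

    text = transcribe_audio("sample.wav")                        # now defaults to model_name="whisper"
    text = transcribe_audio("sample.wav", model_name="whisper")  # equivalent explicit form
    # passing model_name="parakeet" would no longer resolve to a provider
    print(text)
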
src/infrastructure/stt/parakeet_provider.py DELETED
@@ -1,168 +0,0 @@
-"""Parakeet STT provider implementation using Hugging Face Transformers."""
-
-import logging
-import torch
-import librosa
-from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Tuple
-
-if TYPE_CHECKING:
-    from ...domain.models.audio_content import AudioContent
-    from ...domain.models.text_content import TextContent
-
-from ..base.stt_provider_base import STTProviderBase
-from ...domain.exceptions import SpeechRecognitionException
-
-logger = logging.getLogger(__name__)
-
-
-class ParakeetSTTProvider(STTProviderBase):
-    """Parakeet STT provider using Hugging Face Transformers CTC model."""
-
-    def __init__(self):
-        """Initialize the Parakeet STT provider."""
-        super().__init__(
-            provider_name="Parakeet",
-            supported_languages=["en"]  # Parakeet primarily supports English
-        )
-        self.model = None
-        self.processor = None
-        self.current_model_name = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    def _perform_transcription(self, audio_path: Path, model: str) -> str:
-        """
-        Perform transcription using Parakeet CTC model.
-
-        Args:
-            audio_path: Path to the preprocessed audio file
-            model: The Parakeet model to use
-
-        Returns:
-            str: The transcribed text
-        """
-        try:
-            # Load model if not already loaded or if different model requested
-            if self.model is None or self.current_model_name != model:
-                self._load_model(model)
-
-            logger.info(f"Starting Parakeet transcription with model {model}")
-
-            # Load and preprocess audio
-            audio_array, sample_rate = self._load_audio(audio_path)
-
-            # Process audio with the processor
-            inputs = self.processor(
-                audio_array,
-                sampling_rate=sample_rate,
-                return_tensors="pt"
-            )
-
-            inputs.to(self.device, dtype="auto")
-
-            # Decode the predictions
-            outputs = this.model.generate(**inputs)
-            transcription = self.processor.batch_decode(outputs)
-
-            logger.info("Parakeet transcription completed successfully")
-            return transcription
-
-        except Exception as e:
-            self._handle_provider_error(e, "transcription")
-
-    def _load_model(self, model_name: str):
-        """
-        Load the Parakeet model using Hugging Face Transformers.
-
-        Args:
-            model_name: Name of the model to load
-        """
-        try:
-            from transformers import AutoProcessor, AutoModelForCTC
-
-            logger.info(f"Loading Parakeet model: {model_name}")
-
-            # Map model names to actual model identifiers
-            model_mapping = {
-                "parakeet-ctc-0.6b": "nvidia/parakeet-ctc-0.6b",
-                "default": "nvidia/parakeet-ctc-0.6b"
-            }
-
-            actual_model_name = model_mapping.get(model_name, model_mapping["default"])
-
-            # Load processor and model
-            self.processor = AutoProcessor.from_pretrained(actual_model_name)
-            self.model = AutoModelForCTC.from_pretrained(actual_model_name, dtype="auto", device_map=self.device)
-            self.current_model_name = model_name
-            logger.info(f"Parakeet processor {processor}")
-            logger.info(f"Parakeet model {model}")
-
-            # Set model to evaluation mode
-            self.model.eval()
-
-            logger.info(f"Parakeet model {model_name} loaded successfully")
-
-        except ImportError as e:
-            raise SpeechRecognitionException(
-                "transformers library not available. Please install with: pip install transformers[audio]"
-            ) from e
-        except Exception as e:
-            raise SpeechRecognitionException(f"Failed to load Parakeet model {model_name}: {str(e)}") from e
-
-    def _load_audio(self, audio_path: Path) -> Tuple[torch.Tensor, int]:
-        """
-        Load audio file and return as tensor with sample rate.
-
-        Args:
-            audio_path: Path to the audio file
-
-        Returns:
-            Tuple[torch.Tensor, int]: Audio tensor and sample rate
-        """
-        try:
-            # Load audio using librosa
-            audio_array, sample_rate = librosa.load(str(audio_path), sr=None)
-
-            # Convert to torch tensor
-            audio_tensor = torch.from_numpy(audio_array).float()
-
-            return audio_tensor, sample_rate
-
-        except Exception as e:
-            raise SpeechRecognitionException(f"Failed to load audio file {audio_path}: {str(e)}") from e
-
-    def is_available(self) -> bool:
-        """
-        Check if the Parakeet provider is available.
-
-        Returns:
-            bool: True if transformers and required libraries are available, False otherwise
-        """
-        try:
-            from transformers import AutoProcessor, AutoModelForCTC
-            import torch
-            import librosa
-            return True
-        except ImportError:
-            logger.warning("Required libraries (transformers, torch, librosa) not available")
-            return False
-
-    def get_available_models(self) -> list[str]:
-        """
-        Get list of available Parakeet models.
-
-        Returns:
-            list[str]: List of available model names
-        """
-        return [
-            "parakeet-ctc-0.6b"
-        ]
-
-    def get_default_model(self) -> str:
-        """
-        Get the default model for this provider.
-
-        Returns:
-            str: Default model name
-        """
-        return "parakeet-ctc-0.6b"
src/infrastructure/stt/provider_factory.py CHANGED
@@ -5,7 +5,6 @@ from typing import Dict, Type, Optional
 
 from ..base.stt_provider_base import STTProviderBase
 from .whisper_provider import WhisperSTTProvider
-from .parakeet_provider import ParakeetSTTProvider
 from ...domain.exceptions import SpeechRecognitionException
 
 logger = logging.getLogger(__name__)
@@ -15,11 +14,10 @@ class STTProviderFactory:
     """Factory for creating STT provider instances with availability checking and fallback logic."""
 
     _providers: Dict[str, Type[STTProviderBase]] = {
-        "whisper": WhisperSTTProvider,
-        "parakeet": ParakeetSTTProvider
+        "whisper": WhisperSTTProvider
     }
 
-    _fallback_order = ["whisper", "parakeet"]
+    _fallback_order = ["whisper"]
 
     @classmethod
     def create_provider(cls, provider_name: str) -> STTProviderBase:
@@ -162,7 +160,7 @@ class ASRFactory:
     """Legacy ASRFactory for backward compatibility."""
 
     @staticmethod
-    def get_model(model_name: str = "parakeet") -> STTProviderBase:
+    def get_model(model_name: str = "whisper") -> STTProviderBase:
         """
         Get STT provider by model name (legacy interface).
 
@@ -175,7 +173,6 @@ class ASRFactory:
         # Map legacy model names to provider names
         provider_mapping = {
            "whisper": "whisper",
-            "parakeet": "parakeet",
            "faster-whisper": "whisper"
        }
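Note: with parakeet removed from _providers and _fallback_order, whisper is the only provider the factory will construct or fall back to. A rough usage sketch based on the methods shown above (the import path assumes the repository root is on PYTHONPATH; the behavior of create_provider for unregistered names is inferred from the SpeechRecognitionException import):

    from src.infrastructure.stt import STTProviderFactory, ASRFactory

    provider = STTProviderFactory.create_provider("whisper")  # still supported
    legacy = ASRFactory.get_model()                           # now defaults to "whisper"
    legacy = ASRFactory.get_model("faster-whisper")           # legacy alias still maps to whisper
    # STTProviderFactory.create_provider("parakeet") would now fail: the provider is no longer registered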