"""
Speech to Text Service
Wrapper for Whisper STT
"""

import whisper
import torch
import warnings
import os
from typing import Dict, Optional
from app.core.device import get_device, optimize_for_device
warnings.filterwarnings('ignore')


class SpeechToTextService:
    """Speech-to-Text service using Whisper"""
    
    def __init__(self, model_name: str = "medium", device: Optional[str] = None, language: str = "id"):
        """Initialize Whisper model"""
        print("🎙️ Initializing Speech-to-Text service")
        print(f"📦 Loading Whisper model: {model_name}")
        
        # Auto-detect device if not specified
        if device is None or device == "auto":
            self.device = get_device()
            optimize_for_device(self.device)
        else:
            self.device = device
            print(f"💻 Using device: {self.device}")
        
        # Check if model is already cached
        # Use /data/.cache for Whisper (persistent storage on HF Pro)
        cache_dir = os.environ.get('WHISPER_CACHE', '/data/.cache')
        model_cache_path = os.path.join(cache_dir, f'{model_name}.pt')
        
        # Load Whisper model
        try:
            if os.path.exists(model_cache_path):
                print("✅ Loading from cache (pre-downloaded during build)")
            else:
                print(f"📥 Model not in cache, downloading '{model_name}'...")
                print("   This may take 1-2 minutes...")

            self.model = whisper.load_model(model_name, device=self.device, download_root=cache_dir)
            print("✅ Whisper model ready!\n")
        except Exception as e:
            print(f"❌ Failed to load model '{model_name}': {e}")
            print("⚙️ Falling back to 'base' model...")

            base_cache_path = os.path.join(cache_dir, 'base.pt')
            if os.path.exists(base_cache_path):
                print("✅ Loading base model from cache")
            else:
                print("📥 Downloading base model...")

            self.model = whisper.load_model("base", device=self.device, download_root=cache_dir)
            print("✅ Base model ready!\n")
        
        self.language = language
    
    def transcribe(self, audio_path: str, **kwargs) -> Dict:
        """
        Transcribe audio file to text
        
        Args:
            audio_path: Path to the audio file
            **kwargs: Additional Whisper parameters
            
        Returns:
            Dict: {'text': str, 'segments': list, 'language': str}
        """
        print(f"🎧 Transcribing: {audio_path}")
        
        try:
            # Try with word_timestamps first
            # Use FP16 for GPU to reduce memory and improve speed
            fp16 = self.device == "cuda"
            
            result = self.model.transcribe(
                audio_path,
                language=self.language,
                task="transcribe",
                word_timestamps=True,
                condition_on_previous_text=False,
                fp16=fp16,
                **kwargs
            )
        except Exception as e:
            print(f"⚠️ Transcription with word_timestamps failed: {e}")
            print("🔄 Retrying without word_timestamps...")
            
            # Fallback: transcribe without word_timestamps
            fp16 = self.device == "cuda"
            
            result = self.model.transcribe(
                audio_path,
                language=self.language,
                task="transcribe",
                condition_on_previous_text=False,
                fp16=fp16,
                **kwargs
            )
        
        print("✅ Transcription complete!\n")
        
        return {
            'text': result['text'],
            'segments': result.get('segments', []),
            'language': result.get('language', self.language)
        }
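

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service itself: it assumes a local
    # "sample.wav" file exists and that app.core.device is importable from the
    # current working directory. Adjust the model name and language as needed.
    stt = SpeechToTextService(model_name="base", device="auto", language="id")
    output = stt.transcribe("sample.wav")

    # Full transcript followed by per-segment timestamps returned by Whisper.
    print(output['text'])
    for segment in output['segments']:
        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] {segment['text']}")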