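"""Speech-to-text wrapper around the Kyutai STT checkpoint (kyutai/stt-2.6b-en).

Lazily loads the model, preprocesses arbitrary audio files to the sample rate
the processor expects, and returns cleaned-up transcriptions. Intended for
one-shot use; streaming support is stubbed out in transcribe_streaming.
"""
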
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import numpy as np
from typing import Union
import librosa

class KyutaiSTTProcessor:
    """Processor for Kyutai Speech-to-Text model"""
    
    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = None
        self.processor = None
        self.model_id = "kyutai/stt-2.6b-en"  # English-only model for better accuracy
        
        # Audio processing parameters
        self.sample_rate = 16000  # Updated from the processor's config once loaded
        self.chunk_length_s = 30  # Reserved for future chunked processing (currently unused)
        self.max_duration = 120  # Maximum 2 minutes of audio
    
    def load_model(self):
        """Lazy load the STT model"""
        if self.model is None:
            try:
                # Load processor and model
                self.processor = AutoProcessor.from_pretrained(self.model_id)
                
                # Use the sample rate the feature extractor actually expects
                # rather than trusting the hard-coded default
                extractor = getattr(self.processor, "feature_extractor", None)
                if extractor is not None and hasattr(extractor, "sampling_rate"):
                    self.sample_rate = extractor.sampling_rate
                
                # Half precision on GPU keeps VRAM usage low
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=self.torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True
                )
                
                self.model.to(self.device)
                self.model.eval()
                
                # Note: Whisper-style generation settings (language/task/
                # forced_decoder_ids) don't apply here; this checkpoint is
                # English-only
                
            except Exception as e:
                print(f"Failed to load STT model: {e}")
                raise
    
    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Preprocess audio file for transcription"""
        try:
            # Load audio at its native rate, downmixed to mono
            audio, sr = librosa.load(audio_path, sr=None, mono=True)
            
            # Resample if necessary
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
            
            # Limit duration
            max_samples = self.max_duration * self.sample_rate
            if len(audio) > max_samples:
                audio = audio[:max_samples]
            
            # Peak-normalize; the epsilon guards against division by zero on silence
            audio = audio / (np.max(np.abs(audio)) + 1e-7)
            
            return audio
            
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise
    
    def transcribe(self, audio_input: Union[str, np.ndarray]) -> str:
        """Transcribe audio to text"""
        try:
            # Load model if not already loaded
            self.load_model()
            
            # Process audio input
            if isinstance(audio_input, str):
                audio = self.preprocess_audio(audio_input)
            else:
                audio = audio_input
            
            # Process with model; cast floating-point inputs to the model's
            # dtype so fp16 weights don't receive fp32 features
            inputs = self.processor(
                audio, 
                sampling_rate=self.sample_rate, 
                return_tensors="pt"
            ).to(self.device, dtype=self.torch_dtype)
            
            # Generate transcription (unpack the batch so this works whatever
            # input key the processor produces)
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1  # Greedy decoding for speed
                )
            
            # Decode transcription
            transcription = self.processor.batch_decode(
                generated_ids, 
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )[0]
            
            # Clean up transcription
            transcription = self._clean_transcription(transcription)
            
            return transcription
            
        except Exception as e:
            print(f"Transcription error: {e}")
            # Return a default description on error
            return "Create a unique digital monster companion"
    
    def _clean_transcription(self, text: str) -> str:
        """Clean up transcription output"""
        # Remove extra whitespace
        text = " ".join(text.split())
        
        # Ensure proper capitalization
        if text and text[0].islower():
            text = text[0].upper() + text[1:]
        
        # Add period if missing
        if text and text[-1] not in '.!?':
            text += '.'
        
        return text
    
    def transcribe_streaming(self, audio_stream):
        """Streaming transcription (for future implementation)"""
        # This would handle real-time audio streams
        # For now, return placeholder
        raise NotImplementedError("Streaming transcription not yet implemented")
    
    def to(self, device: str):
        """Move model to specified device"""
        self.device = device
        if self.model is not None:
            self.model.to(device)
    
    def __del__(self):
        """Cleanup when object is destroyed"""
        # Guard everything: attributes and even torch itself may already be
        # torn down during interpreter shutdown
        try:
            self.model = None
            self.processor = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass
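
# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumes a hypothetical local recording "sample.wav";
# any librosa-readable file works, since preprocess_audio resamples it to the
# rate the processor expects.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    stt = KyutaiSTTProcessor(device="cuda")  # falls back to CPU automatically
    text = stt.transcribe("sample.wav")  # hypothetical example file
    print(f"Transcription: {text}")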