File size: 6,006 Bytes
cd1309d
0ee4f42
 
cd1309d
 
 
c72d839
c10f1ac
0ee4f42
 
 
c72d839
 
cd1309d
 
 
0ee4f42
cd1309d
0ee4f42
 
cd1309d
0ee4f42
 
 
 
 
 
 
 
 
 
 
 
2477bc4
c72d839
 
0ee4f42
 
 
c72d839
2477bc4
0ee4f42
c72d839
 
0ee4f42
 
 
 
 
 
 
 
 
 
 
 
 
 
c72d839
 
 
 
0ee4f42
 
 
 
 
 
 
 
 
 
 
c72d839
 
2477bc4
a4f48aa
 
 
7eff88c
 
0ee4f42
7eff88c
c72d839
 
7eff88c
0ee4f42
 
 
c72d839
 
 
 
7eff88c
0ee4f42
7eff88c
 
 
 
 
 
c72d839
0ee4f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c72d839
 
 
0ee4f42
 
 
 
 
31708ca
0ee4f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31708ca
0ee4f42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c72d839
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
Speech Recognition Module
Supports multiple ASR models including Whisper and Parakeet
Handles audio preprocessing and transcription
"""

import logging
import numpy as np
import os
from abc import ABC, abstractmethod

logger = logging.getLogger(__name__)

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from pydub import AudioSegment
import soundfile as sf

class ASRModel(ABC):
    """Base class for ASR models.

    Subclasses implement ``load_model`` and ``transcribe``; the shared
    ``preprocess_audio`` helper converts input audio to 16 kHz mono WAV.
    """

    @abstractmethod
    def load_model(self):
        """Load the ASR model"""
        pass

    @abstractmethod
    def transcribe(self, audio_path):
        """Transcribe audio to text"""
        pass

    @staticmethod
    def _wav_path_for(audio_path):
        """Return ``audio_path`` with its final extension replaced by ``.wav``.

        Only the trailing suffix is touched, so directory names that happen to
        contain ".mp3" are preserved. Paths already ending in ``.wav`` are
        returned unchanged.
        """
        if audio_path.endswith(".mp3"):
            # Slice off the suffix instead of str.replace(), which would also
            # rewrite any ".mp3" appearing earlier in the path.
            wav_path = audio_path[: -len(".mp3")] + ".wav"
        else:
            wav_path = audio_path
        if not wav_path.endswith(".wav"):
            wav_path = f"{os.path.splitext(wav_path)[0]}.wav"
        return wav_path

    def preprocess_audio(self, audio_path):
        """Convert audio to 16 kHz mono WAV and return the WAV file's path.

        Args:
            audio_path: path to an audio file readable by pydub's
                ``AudioSegment.from_file``.
        Returns:
            Path of the exported WAV file: ``audio_path`` with its extension
            replaced by ``.wav``.

        NOTE(review): when ``audio_path`` already ends in ``.wav`` the export
        overwrites the caller's original file in place with the converted
        audio.
        """
        logger.info("Converting audio format")
        audio = AudioSegment.from_file(audio_path)
        # Resample to 16 kHz and downmix to a single channel before export.
        processed_audio = audio.set_frame_rate(16000).set_channels(1)
        wav_path = self._wav_path_for(audio_path)
        processed_audio.export(wav_path, format="wav")
        logger.info(f"Audio converted to: {wav_path}")
        return wav_path


class WhisperModel(ASRModel):
    """Whisper ASR backend (``openai/whisper-large-v3`` via transformers)."""

    # The Whisper feature extractor pads/truncates its input to a fixed
    # 30-second window by default; audio beyond that is not transcribed here.
    _WINDOW_SECONDS = 30

    def __init__(self):
        # Model and processor are loaded lazily on the first transcribe() call.
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        """Load the Whisper model and processor and move the model to ``self.device``."""
        logger.info("Loading Whisper model")
        logger.info(f"Using device: {self.device}")

        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
            "openai/whisper-large-v3",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            use_safetensors=True
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
        logger.info("Whisper model loaded successfully")

    def transcribe(self, audio_path):
        """Transcribe audio using Whisper.

        Args:
            audio_path: path to the input audio file; it is first converted to
                16 kHz mono WAV by ``preprocess_audio``.
        Returns:
            The decoded transcription (English, special tokens stripped).
        """
        if self.model is None or self.processor is None:
            self.load_model()

        wav_path = self.preprocess_audio(audio_path)

        # Processing
        logger.info("Processing audio input")
        logger.debug("Loading audio data")
        audio_data, sample_rate = sf.read(wav_path)
        audio_data = audio_data.astype(np.float32)

        # The previous chunk_length_s/stride_length_s arguments belong to the
        # transformers ASR *pipeline*; the processor silently ignored them, so
        # long audio was (and still is) truncated. Make that visible.
        if len(audio_data) > self._WINDOW_SECONDS * sample_rate:
            logger.warning(
                "Audio is %.1f s long; only the first %d s fit in Whisper's "
                "input window and will be transcribed",
                len(audio_data) / sample_rate,
                self._WINDOW_SECONDS,
            )

        # Pass the rate actually read from the file so a mismatch is reported
        # by the feature extractor instead of silently producing bad features.
        inputs = self.processor(
            audio_data,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).to(self.device)

        # Transcription
        logger.info("Generating transcription")
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                language="en",
                task="transcribe",
                max_length=448,  # Whisper's decoder context is 448 positions
                no_repeat_ngram_size=3  # Prevent repetition in output
            )

        result = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        logger.info("Transcription completed successfully")
        return result


class ParakeetModel(ASRModel):
    """ASR backend using NVIDIA's ``parakeet-tdt-0.6b-v2`` checkpoint via NeMo."""

    def __init__(self):
        # Filled in lazily by load_model() the first time transcribe() runs.
        self.model = None

    def load_model(self):
        """Import NeMo's ASR collection and fetch the Parakeet checkpoint.

        Raises:
            ImportError: when NeMo cannot be imported; an install hint is
                logged before the exception is re-raised.
        """
        checkpoint = "nvidia/parakeet-tdt-0.6b-v2"
        try:
            # Imported here so NeMo is only required when Parakeet is used.
            import nemo.collections.asr as nemo_asr
            logger.info("Loading Parakeet model")
            loader = nemo_asr.models.ASRModel
            self.model = loader.from_pretrained(model_name=checkpoint)
            logger.info("Parakeet model loaded successfully")
        except ImportError:
            logger.error("Failed to import nemo_toolkit. Please install with: pip install -U 'nemo_toolkit[asr]'")
            raise

    def transcribe(self, audio_path):
        """Return the Parakeet transcription of the file at ``audio_path``.

        The audio is converted to 16 kHz mono WAV first, then passed to
        NeMo's ``transcribe`` as a one-element batch.
        """
        if self.model is None:
            self.load_model()

        prepared_wav = self.preprocess_audio(audio_path)

        # Transcription
        logger.info("Generating transcription with Parakeet")
        hypotheses = self.model.transcribe([prepared_wav])
        first_hypothesis = hypotheses[0]
        text = first_hypothesis.text
        logger.info("Transcription completed successfully")
        return text


class ASRFactory:
    """Factory that maps a backend name to a new ASR model instance."""

    @staticmethod
    def get_model(model_name="parakeet"):
        """
        Build the ASR backend named by ``model_name``.
        Args:
            model_name: "whisper" or "parakeet", matched case-insensitively.
                Any other name logs a warning and yields a Whisper backend.
        Returns:
            A new ASR model instance whose weights are not loaded yet
        """
        builders = {
            "whisper": WhisperModel,
            "parakeet": ParakeetModel,
        }
        builder = builders.get(model_name.lower())
        if builder is None:
            logger.warning(f"Unknown model: {model_name}, falling back to Whisper")
            builder = WhisperModel
        return builder()


def transcribe_audio(audio_path, model_name="parakeet"):
    """
    Convert an audio file to text using the specified ASR model.

    A new backend is created on every call, so its model weights are loaded
    again each time; callers transcribing many files may want to reuse an
    instance from ``ASRFactory.get_model`` directly.
    Args:
        audio_path: Path to input audio file
        model_name: Name of the ASR model to use (whisper or parakeet)
    Returns:
        Transcribed English text
    Raises:
        Exception: any error raised while creating, loading or running the
            model is logged with its traceback and re-raised unchanged.
    """
    # Lazy %-style arguments: formatting only happens if the record is emitted.
    logger.info("Starting transcription for: %s using %s model", audio_path, model_name)

    try:
        asr_model = ASRFactory.get_model(model_name)
        result = asr_model.transcribe(audio_path)
    except Exception as e:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("Transcription failed: %s", e)
        raise

    logger.info("transcription: %s", result)
    return result