from typing import Union

import librosa
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor


class KyutaiSTTProcessor:
    """Processor for Kyutai Speech-to-Text model"""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.processor = None
        self.model_id = "kyutai/stt-2.6b-en"  # English-only model for better accuracy

        # Audio processing parameters
        self.sample_rate = 16000
        self.chunk_length_s = 30  # Process in 30-second chunks
        self.max_duration = 120  # Maximum 2 minutes of audio

    def load_model(self):
        """Lazy load the STT model"""
        if self.model is None:
            try:
                # Load processor and model
                self.processor = AutoProcessor.from_pretrained(self.model_id)

                # Half precision on GPU to keep VRAM usage low
                torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,
                )
                self.model.to(self.device)

                # Whisper-style generation hints; checkpoints that do not
                # define these options simply ignore them
                self.model.generation_config.language = "english"
                self.model.generation_config.task = "transcribe"
                self.model.generation_config.forced_decoder_ids = None
            except Exception as e:
                print(f"Failed to load STT model: {e}")
                raise

    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Preprocess audio file for transcription"""
        try:
            # Load audio as mono, keeping the native sample rate
            audio, sr = librosa.load(audio_path, sr=None, mono=True)

            # Resample if necessary
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

            # Limit duration
            max_samples = self.max_duration * self.sample_rate
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            # Peak-normalize (epsilon guards against all-zero input)
            audio = audio / (np.max(np.abs(audio)) + 1e-7)
            return audio
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise
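
    # Hypothetical helper, not in the original class: chunk_length_s is set in
    # __init__ but never used, so a minimal sketch of how 30-second chunking
    # could work (plain slicing, no overlap) is given here for illustration.
    def _chunk_audio(self, audio: np.ndarray) -> list:
        """Split audio into fixed-length chunks of chunk_length_s seconds."""
        chunk_samples = self.chunk_length_s * self.sample_rate
        return [audio[i:i + chunk_samples] for i in range(0, len(audio), chunk_samples)]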

    def transcribe(self, audio_input: Union[str, np.ndarray]) -> str:
        """Transcribe audio to text"""
        try:
            # Load model if not already loaded
            self.load_model()

            # Process audio input
            if isinstance(audio_input, str):
                audio = self.preprocess_audio(audio_input)
            else:
                audio = audio_input

            # Extract input features with the processor
            inputs = self.processor(
                audio,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
            )
            # Match the model's device and dtype (float16 on GPU)
            input_features = inputs["input_features"].to(self.device, dtype=self.model.dtype)

            # Generate transcription
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_features,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1,  # Greedy decoding for speed
                )

            # Decode transcription
            transcription = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )[0]

            # Clean up transcription
            return self._clean_transcription(transcription)
        except Exception as e:
            print(f"Transcription error: {e}")
            # Return a default description on error
            return "Create a unique digital monster companion"

    def _clean_transcription(self, text: str) -> str:
        """Clean up transcription output"""
        # Collapse extra whitespace
        text = " ".join(text.split())

        # Ensure proper capitalization
        if text and text[0].islower():
            text = text[0].upper() + text[1:]

        # Add terminal punctuation if missing
        if text and text[-1] not in ".!?":
            text += "."
        return text
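
    # Illustrative example: _clean_transcription("  hello   world") -> "Hello world."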

    def transcribe_streaming(self, audio_stream):
        """Streaming transcription (for future implementation)"""
        # Would handle real-time audio streams; raises until implemented
        raise NotImplementedError("Streaming transcription not yet implemented")

    def to(self, device: str):
        """Move model to specified device"""
        self.device = device
        if self.model is not None:
            self.model.to(device)

    def __del__(self):
        """Cleanup when object is destroyed"""
        try:
            if self.model is not None:
                del self.model
            if self.processor is not None:
                del self.processor
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # torch may already be torn down during interpreter shutdown
            pass
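

# Usage sketch (illustrative, not part of the original file): "sample.wav" is
# a hypothetical input path; any audio file librosa can read would work.
if __name__ == "__main__":
    stt = KyutaiSTTProcessor(device="cuda")  # falls back to CPU automatically
    text = stt.transcribe("sample.wav")
    print(text)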