import torch
import numpy as np
import librosa
from typing import Union
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor


class KyutaiSTTProcessor:
    """Processor for the Kyutai speech-to-text model."""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.processor = None
        self.torch_dtype = torch.float32  # set per device in load_model()
        self.model_id = "kyutai/stt-2.6b-en"  # English-only model for better accuracy

        # Audio processing parameters
        self.sample_rate = 16000
        self.chunk_length_s = 30  # Process in 30-second chunks
        self.max_duration = 120   # Maximum 2 minutes of audio

    def load_model(self):
        """Lazily load the STT model on first use."""
        if self.model is None:
            try:
                # Load processor and model
                self.processor = AutoProcessor.from_pretrained(self.model_id)

                # Half precision on GPU keeps VRAM usage low
                self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=self.torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,
                )
                self.model.to(self.device)

                # Pin generation to English transcription and clear any
                # forced decoder ids inherited from the checkpoint
                self.model.generation_config.language = "english"
                self.model.generation_config.task = "transcribe"
                self.model.generation_config.forced_decoder_ids = None
            except Exception as e:
                print(f"Failed to load STT model: {e}")
                raise

    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Preprocess an audio file for transcription."""
        try:
            # Load as a mono waveform at its native sample rate
            audio, sr = librosa.load(audio_path, sr=None, mono=True)

            # Resample if necessary
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

            # Truncate to the maximum supported duration
            max_samples = self.max_duration * self.sample_rate
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            # Peak-normalize; the epsilon guards against division by zero on silence
            audio = audio / (np.max(np.abs(audio)) + 1e-7)
            return audio
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def transcribe(self, audio_input: Union[str, np.ndarray]) -> str:
        """Transcribe audio to text."""
        try:
            # Load model if not already loaded
            self.load_model()

            # Accept either a file path or an already-preprocessed waveform
            if isinstance(audio_input, str):
                audio = self.preprocess_audio(audio_input)
            else:
                audio = audio_input

            # Extract input features, matching the model's device and dtype
            # (a float16 model rejects float32 features)
            inputs = self.processor(
                audio,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
            ).to(self.device, self.torch_dtype)

            # Generate transcription
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1,  # Greedy decoding for speed
                )

            # Decode token ids back to text
            transcription = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )[0]

            return self._clean_transcription(transcription)
        except Exception as e:
            print(f"Transcription error: {e}")
            # Fall back to a default prompt so downstream stages still run
            return "Create a unique digital monster companion"

    def _clean_transcription(self, text: str) -> str:
        """Clean up transcription output."""
        # Collapse repeated whitespace
        text = " ".join(text.split())

        # Capitalize the first character
        if text and text[0].islower():
            text = text[0].upper() + text[1:]

        # Add terminal punctuation if missing
        if text and text[-1] not in ".!?":
            text += "."
        return text

    def transcribe_streaming(self, audio_stream):
        """Streaming transcription (not yet implemented)."""
        # This would handle real-time audio streams
        raise NotImplementedError("Streaming transcription not yet implemented")

    def to(self, device: str):
        """Move the model to the specified device."""
        self.device = device
        if self.model is not None:
            self.model.to(device)

    def __del__(self):
        """Release model resources when the object is garbage-collected."""
        if self.model is not None:
            del self.model
        if self.processor is not None:
            del self.processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
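

# Minimal usage sketch (an assumption for illustration, not part of the
# original module): "sample.wav" is a hypothetical input file. Demonstrates
# lazy model loading and the file-path entry point to transcribe().
if __name__ == "__main__":
    stt = KyutaiSTTProcessor(device="cuda")  # falls back to CPU if CUDA is absent
    print(stt.transcribe("sample.wav"))      # hypothetical example file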