import modal
import uuid

MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"


def download_model():
    try:
        import nemo.collections.asr as nemo_asr  # type: ignore

        nemo_asr.models.ASRModel.from_pretrained(MODEL_NAME)
    except ImportError:
        pass
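
# download_model is executed at image build time via .run_function() below, so the
# Parakeet weights are baked into the image and containers skip the download at startup.
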
asr_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git", "ffmpeg")
    .pip_install(
        "torch",
        "librosa",
        "omegaconf",
        "lightning",
        "cuda-python>=12.3",
        "git+https://github.com/NVIDIA/multi-storage-client.git",
        "nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo@main",
        extra_options="-U",
        gpu="A10G",
    )
    .run_function(
        download_model,
        gpu="A10G",
    )
)
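
# asr_image.imports() defers these imports to container runtime, where the image's
# dependencies (NeMo, torch, librosa) are actually installed.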
with asr_image.imports():
    import nemo.collections.asr as nemo_asr  # type: ignore
    from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig  # type: ignore
    from nemo.collections.asr.parts.utils.streaming_utils import BatchedFrameASRTDT  # type: ignore
    from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_rnnt  # type: ignore
    import math
    import torch  # type: ignore
    from omegaconf import OmegaConf  # type: ignore
    import librosa  # type: ignore
    import os
app = modal.App(name="clipscript-asr-service")

# This must be the same volume object used in processing.py
upload_volume = modal.Volume.from_name(
    "clipscript-uploads", create_if_missing=True
)


# The upload volume is mounted at /data, where transcribe() resolves audio_filename;
# the class runs on the same GPU type the image was built for.
@app.cls(image=asr_image, gpu="A10G", volumes={"/data": upload_volume})
class ASR:
    @modal.enter()
    def startup(self):
        print("loading model...")
        self.model = nemo_asr.models.ASRModel.from_pretrained(MODEL_NAME)
        print("model loaded.")
        self.model.freeze()
        torch.set_grad_enabled(False)

        # Configure the preprocessor for buffered inference
        model_cfg = self.model._cfg
        OmegaConf.set_struct(model_cfg.preprocessor, False)
        model_cfg.preprocessor.dither = 0.0
        model_cfg.preprocessor.pad_to = 0
        OmegaConf.set_struct(model_cfg.preprocessor, True)

        # Set up decoding for the TDT model
        decoding_cfg = RNNTDecodingConfig()
        decoding_cfg.strategy = "greedy"  # TDT requires greedy decoding
        decoding_cfg.preserve_alignments = True
        decoding_cfg.fused_batch_size = -1
        if hasattr(self.model, "change_decoding_strategy"):
            self.model.change_decoding_strategy(decoding_cfg)

        # Timing parameters: seconds of audio per decoder timestep
        self.feature_stride = model_cfg.preprocessor["window_stride"]
        self.model_stride = 4  # TDT model stride
        self.model_stride_in_secs = self.feature_stride * self.model_stride

        # Buffered inference parameters
        self.chunk_len_in_secs = 15.0
        self.total_buffer_in_secs = 20.0
        self.batch_size = 64
        self.max_steps_per_timestep = 15

        # Number of decoder timesteps that fit in one chunk
        self.tokens_per_chunk = math.ceil(self.chunk_len_in_secs / self.model_stride_in_secs)
        print("ASR setup complete with buffered inference support.")
    def _get_audio_duration(self, audio_path: str) -> float:
        try:
            return librosa.get_duration(path=audio_path)
        except Exception:
            # Fallback: estimate from file size (rough approximation).
            # 16 kHz, 16-bit mono PCM is ~32,000 bytes per second.
            file_size = os.path.getsize(audio_path)
            return file_size / 32000

    def _simple_transcribe(self, audio_path: str) -> str:
        print("Using simple transcription...")
        output = self.model.transcribe([audio_path])
        if not output or not hasattr(output[0], "text"):
            return ""
        return output[0].text

    def _buffered_transcribe(self, audio_path: str) -> str:
        print("Using buffered transcription...")
        # Set up the TDT frame processor
        frame_asr = BatchedFrameASRTDT(
            asr_model=self.model,
            frame_len=self.chunk_len_in_secs,
            total_buffer=self.total_buffer_in_secs,
            batch_size=self.batch_size,
            max_steps_per_timestep=self.max_steps_per_timestep,
            stateful_decoding=False,
        )
        # Decode from the middle of the buffer: the chunk length plus half the surrounding context
        mid_delay = math.ceil(
            (self.chunk_len_in_secs + (self.total_buffer_in_secs - self.chunk_len_in_secs) / 2)
            / self.model_stride_in_secs
        )
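        # e.g. with 0.04 s per timestep: ceil((15.0 + 2.5) / 0.04) = 438 timesteps of delay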
        # Process with buffered inference
        hyps = get_buffered_pred_feat_rnnt(
            asr=frame_asr,
            tokens_per_chunk=self.tokens_per_chunk,
            delay=mid_delay,
            model_stride_in_secs=self.model_stride_in_secs,
            batch_size=self.batch_size,
            manifest=None,
            filepaths=[audio_path],
            accelerator="gpu",
        )
        # Extract the transcription
        if hyps:
            return hyps[0].text
        return ""

    @modal.method()
    def transcribe(
        self,
        audio_filename: str | None = None,
        audio_bytes: bytes | None = None,
        use_buffered: bool | None = None,
    ) -> dict[str, str]:
        audio_path = None
        temp_audio_path = None
        try:
            if audio_filename:
                audio_path = f"/data/{audio_filename}"
            elif audio_bytes:
                # When bytes are passed, they must be written to a file for librosa/NeMo to read.
                temp_audio_path = f"/tmp/input_{uuid.uuid4()}.wav"
                with open(temp_audio_path, "wb") as f:
                    f.write(audio_bytes)
                audio_path = temp_audio_path
            else:
                raise ValueError("Either 'audio_filename' or 'audio_bytes' must be provided.")

            if not os.path.exists(audio_path):
                return {"text": "", "error": f"Audio file not found at path: {audio_path}"}

            # Pick a transcription method based on duration unless the caller forced one
            if use_buffered is None:
                duration = self._get_audio_duration(audio_path)
                use_buffered = duration > 1200.0  # 20 minutes
                print(f"Audio duration: {duration:.1f}s, using {'buffered' if use_buffered else 'simple'} transcription")

            if use_buffered:
                text = self._buffered_transcribe(audio_path)
            else:
                text = self._simple_transcribe(audio_path)

            print("transcription complete.")
            return {"text": text, "error": ""}
        except Exception as e:
            print(f"Transcription error: {e}")
            return {"text": "", "error": str(e)}
        finally:
            if temp_audio_path and os.path.exists(temp_audio_path):
                os.remove(temp_audio_path)

    @modal.method()
    def transcribe_simple(self, audio_filename: str | None = None, audio_bytes: bytes | None = None) -> dict[str, str]:
        """Force simple transcription (for compatibility)."""
        return self.transcribe.local(audio_filename=audio_filename, audio_bytes=audio_bytes, use_buffered=False)

    @modal.method()
    def transcribe_buffered(self, audio_filename: str | None = None, audio_bytes: bytes | None = None) -> dict[str, str]:
        """Force buffered transcription."""
        return self.transcribe.local(audio_filename=audio_filename, audio_bytes=audio_bytes, use_buffered=True)
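

# Minimal usage sketch, not part of the service itself: invoke with `modal run` against
# this file. The default file name "sample.wav" is illustrative, and the audio is assumed
# to be a local file small enough to ship as bytes.
@app.local_entrypoint()
def main(audio_path: str = "sample.wav"):
    with open(audio_path, "rb") as f:
        result = ASR().transcribe.remote(audio_bytes=f.read())
    if result["error"]:
        print(f"Error: {result['error']}")
    else:
        print(result["text"])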