ClipScript / asr.py
muzzz's picture
fix
b224be5
import modal
import uuid
MODEL_NAME = "nvidia/parakeet-tdt-0.6b-v2"
def download_model():
try:
import nemo.collections.asr as nemo_asr # type: ignore
nemo_asr.models.ASRModel.from_pretrained(MODEL_NAME)
except ImportError:
pass
asr_image = (
modal.Image.debian_slim(python_version="3.12")
.apt_install("git", "ffmpeg")
.pip_install(
"torch",
"librosa",
"omegaconf",
"lightning",
"cuda-python>=12.3",
"git+https://github.com/NVIDIA/multi-storage-client.git",
"nemo_toolkit[asr] @ git+https://github.com/NVIDIA/NeMo@main",
extra_options="-U",
gpu="A10G",
)
.run_function(
download_model,
gpu="A10G",
)
)
with asr_image.imports():
import nemo.collections.asr as nemo_asr # type: ignore
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig # type: ignore
from nemo.collections.asr.parts.utils.streaming_utils import BatchedFrameASRTDT # type: ignore
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_rnnt # type: ignore
import math
import torch # type: ignore
from omegaconf import OmegaConf # type: ignore
import librosa # type: ignore
import os
app = modal.App(name="clipscript-asr-service")
# This must be the same volume object used in processing.py
upload_volume = modal.Volume.from_name(
"clipscript-uploads", create_if_missing=True
)
@app.cls(
image=asr_image,
gpu="A10G",
scaledown_window=600,
volumes={"/data": upload_volume}, # Mount the shared volume
)
class ASR:
@modal.enter()
def startup(self):
print("loading model...")
self.model = nemo_asr.models.ASRModel.from_pretrained(MODEL_NAME)
print("model loaded.")
self.model.freeze()
torch.set_grad_enabled(False)
# Configure for buffered inference
model_cfg = self.model._cfg
OmegaConf.set_struct(model_cfg.preprocessor, False)
model_cfg.preprocessor.dither = 0.0
model_cfg.preprocessor.pad_to = 0
OmegaConf.set_struct(model_cfg.preprocessor, True)
# Setup decoding for TDT model
decoding_cfg = RNNTDecodingConfig()
decoding_cfg.strategy = "greedy" # TDT requires greedy
decoding_cfg.preserve_alignments = True
decoding_cfg.fused_batch_size = -1
if hasattr(self.model, 'change_decoding_strategy'):
self.model.change_decoding_strategy(decoding_cfg)
# Calculate timing parameters
self.feature_stride = model_cfg.preprocessor['window_stride']
self.model_stride = 4 # TDT model stride
self.model_stride_in_secs = self.feature_stride * self.model_stride
# Buffered inference parameters
self.chunk_len_in_secs = 15.0
self.total_buffer_in_secs = 20.0
self.batch_size = 64
self.max_steps_per_timestep = 15
# Calculate chunk parameters
self.tokens_per_chunk = math.ceil(self.chunk_len_in_secs / self.model_stride_in_secs)
print("ASR setup complete with buffered inference support.")
def _get_audio_duration(self, audio_path: str) -> float:
try:
duration = librosa.get_duration(path=audio_path)
return duration
except Exception:
# Fallback: estimate from file size (rough approximation)
file_size = os.path.getsize(audio_path)
# Rough estimate: 16kHz, 16-bit mono = ~32KB per second
return file_size / 32000
def _simple_transcribe(self, audio_path: str) -> str:
print("Using simple transcription...")
output = self.model.transcribe([audio_path])
if not output or not hasattr(output[0], "text"):
return ""
return output[0].text
def _buffered_transcribe(self, audio_path: str) -> str:
print("Using buffered transcription...")
# Setup TDT frame processor
frame_asr = BatchedFrameASRTDT(
asr_model=self.model,
frame_len=self.chunk_len_in_secs,
total_buffer=self.total_buffer_in_secs,
batch_size=self.batch_size,
max_steps_per_timestep=self.max_steps_per_timestep,
stateful_decoding=False,
)
# Calculate delay for TDT
mid_delay = math.ceil((self.chunk_len_in_secs + (self.total_buffer_in_secs - self.chunk_len_in_secs) / 2) / self.model_stride_in_secs)
# Process with buffered inference
hyps = get_buffered_pred_feat_rnnt(
asr=frame_asr,
tokens_per_chunk=self.tokens_per_chunk,
delay=mid_delay,
model_stride_in_secs=self.model_stride_in_secs,
batch_size=self.batch_size,
manifest=None,
filepaths=[audio_path],
accelerator='gpu',
)
# Extract transcription
if hyps and len(hyps) > 0:
return hyps[0].text
return ""
@modal.method()
def transcribe(self, audio_filename: str = None, audio_bytes: bytes = None, use_buffered: bool | None = None) -> dict[str, str]:
audio_path = None
temp_audio_path = None
try:
if audio_filename:
audio_path = f"/data/{audio_filename}"
elif audio_bytes:
# When bytes are passed, they must be written to a file for librosa/nemo to read.
temp_audio_path = f"/tmp/input_{uuid.uuid4()}.wav"
with open(temp_audio_path, "wb") as f:
f.write(audio_bytes)
audio_path = temp_audio_path
else:
raise ValueError("Either 'audio_filename' or 'audio_bytes' must be provided.")
if not os.path.exists(audio_path):
return {"text": "", "error": f"Audio file not found at path: {audio_path}"}
# Determine transcription method
if use_buffered is None:
duration = self._get_audio_duration(audio_path)
use_buffered = duration > 1200.0 # 20 minutes
print(f"Audio duration: {duration:.1f}s, using {'buffered' if use_buffered else 'simple'} transcription")
if use_buffered:
text = self._buffered_transcribe(audio_path)
else:
text = self._simple_transcribe(audio_path)
print("transcription complete.")
return {"text": text, "error": ""}
except Exception as e:
print(f"Transcription error: {e}")
return {"text": "", "error": str(e)}
finally:
if temp_audio_path and os.path.exists(temp_audio_path):
os.remove(temp_audio_path)
@modal.method()
def transcribe_simple(self, audio_filename: str = None, audio_bytes: bytes = None) -> dict[str, str]:
"""Force simple transcription (for compatibility)"""
return self.transcribe(audio_filename=audio_filename, audio_bytes=audio_bytes, use_buffered=False)
@modal.method()
def transcribe_buffered(self, audio_filename: str = None, audio_bytes: bytes = None) -> dict[str, str]:
"""Force buffered transcription"""
return self.transcribe(audio_filename=audio_filename, audio_bytes=audio_bytes, use_buffered=True)