#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc.
# Licensed under the MIT License (MIT).
from __future__ import annotations
import os
import pathlib
import getpass
from typing import Any, Dict
import spaces
import gradio as gr
import torch
import torchaudio
from fairseq2.assets import InProcAssetMetadataProvider, asset_store
from huggingface_hub import snapshot_download, hf_hub_download
from seamless_communication.inference import Translator
from langlist_slt import (
    LANGUAGE_NAME_TO_CODE,
    S2TT_TARGET_LANGUAGE_NAMES,
    ASR_TARGET_LANGUAGE_NAMES,
)
# ============================================================
# SETTINGS
# ============================================================
user = getpass.getuser()
CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", f"/home/{user}/app/models"))
# Ensure the local model files exist (download the base model snapshot if missing).
if not CHECKPOINTS_PATH.exists():
    print(f"CHECKPOINTS_PATH {CHECKPOINTS_PATH} does not exist; downloading base model snapshot...")
    snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
else:
    print(f"Using existing CHECKPOINTS_PATH: {CHECKPOINTS_PATH}")
# Register assets (these `file://` paths should point to local files in CHECKPOINTS_PATH)
asset_store.env_resolvers.clear()
asset_store.env_resolvers.append(lambda: "demo")
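# With the "demo" environment active, fairseq2 prefers asset cards named "<name>@demo",
# so the entries below make "seamlessM4T_v2_large" and "vocoder_v2" resolve to the local
# checkpoint files in CHECKPOINTS_PATH instead of triggering a remote download.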
demo_metadata = [
    {
        "name": "seamlessM4T_v2_large@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
        "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
    },
    {
        "name": "vocoder_v2@demo",
        "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
    },
]
asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
DESCRIPTION = """
IIITH-SLT: Speech Translation demo for Indian languages using weakly labeled data

End-to-end speech translation for low-resource Indian languages, trained on weakly labeled data.

Supports ST models for:

- Bengali → Hindi
- Malayalam → Hindi
- Odia → Hindi
- Telugu → Hindi

Trained on the Shrutilipi-anuvaad dataset.

Paper: https://arxiv.org/pdf/2506.16251
"""
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
AUDIO_SAMPLE_RATE = 16000.0
MAX_INPUT_AUDIO_LENGTH = 60 # seconds
DEFAULT_TARGET_LANGUAGE = "Hindi"
# ============================================================
# FINETUNED MODEL CHECKPOINTS (HF repo & filenames)
# ============================================================
hf_repo = "SPL-IIITH/iiith-slt-finetuned-checkpoints"
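# Keys must match the source-language names coming from the Gradio dropdown
# (ASR_TARGET_LANGUAGE_NAMES); a language without an entry falls back to the base model.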
FINETUNED_MODEL_FILES = {
    "Telugu": "checkpoint_te_hi_v5.pt",
    "Malayalam": "checkpoint_ml_hi_v5.pt",
    "Odiya": "checkpoint_od_hi_v5.pt",
    "Bengali": "checkpoint_bn_hi_v5.pt",
}
# ============================================================
# AUDIO PREPROCESSING
# ============================================================
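# Resample the uploaded audio to 16 kHz, truncate it to MAX_INPUT_AUDIO_LENGTH seconds, and
# overwrite the file in place so the translator reads the normalized waveform.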
def preprocess_audio(input_audio: str) -> None:
    arr, org_sr = torchaudio.load(input_audio)
    if org_sr != AUDIO_SAMPLE_RATE:
        new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
    else:
        new_arr = arr
    max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
    if new_arr.shape[1] > max_length:
        new_arr = new_arr[:, :max_length]
        gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds are used.")
    torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
# ============================================================
# UTILS: load fine-tuned weights onto a translator instance
# ============================================================
def apply_finetuned_weights(translator: Translator, language: str, device: torch.device) -> Translator:
    """
    Given a translator instance (fresh base model), load the fine-tuned weights for `language`
    onto the translator's submodules. Uses map_location=device when loading the .pt file.
    """
    ckpt_filename = FINETUNED_MODEL_FILES.get(language)
    if not ckpt_filename:
        return translator

    ckpt_path = hf_hub_download(hf_repo, ckpt_filename)
    print(f"Applying fine-tuned checkpoint for {language} from {ckpt_path}")

    # Load the checkpoint safely: newer torch supports weights_only, but guard for compatibility.
    load_kwargs = {"map_location": device}
    try:
        saved = torch.load(ckpt_path, **load_kwargs, weights_only=True)  # type: ignore
    except TypeError:
        saved = torch.load(ckpt_path, **load_kwargs)

    # In many checkpoints the actual state dict is stored under the key "model".
    saved_model = saved.get("model", saved)

    def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        # Keep only keys with the given prefix and strip it so they match the submodule's own state_dict.
        return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}

    # Try loading the expected submodules. Wrap in try/except to avoid a full crash if keys differ.
    try:
        translator.model.speech_encoder_frontend.load_state_dict(_select_keys(saved_model, "model.speech_encoder_frontend."))
        translator.model.speech_encoder.load_state_dict(_select_keys(saved_model, "model.speech_encoder."))
        if getattr(translator.model, "text_decoder_frontend", None) is not None:
            translator.model.text_decoder_frontend.load_state_dict(_select_keys(saved_model, "model.text_decoder_frontend."))
        if getattr(translator.model, "text_decoder", None) is not None:
            translator.model.text_decoder.load_state_dict(_select_keys(saved_model, "model.text_decoder."))
        if getattr(translator.model, "final_proj", None) is not None:
            translator.model.final_proj.load_state_dict(_select_keys(saved_model, "model.final_proj."))
    except Exception as e:
        # If applying fails (e.g., due to quantization differences), print a warning and continue.
        print("Warning: failed to load some fine-tuned weights directly. Error:", e)
        # Optional fallback: recreate a non-quantized base model and apply the weights there. Kept simple for now.
    return translator
# ============================================================
# MAIN TRANSLATION FUNCTION (load model per request; quantize on CPU)
# ============================================================
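# On ZeroGPU Spaces, the @spaces.GPU decorator requests a GPU only for the duration of this
# call, which is presumably why the model is loaded (and, on CPU, quantized) per request
# rather than once at startup.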
@spaces.GPU
def run_s2tt(input_audio: str, source_language: str, target_language: str):
    # Preprocess the audio to the expected sample rate and maximum length.
    preprocess_audio(input_audio)

    # Decide device and dtype dynamically.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
        print("Device: GPU - loading base model in fp16.")
    else:
        device = torch.device("cpu")
        dtype = torch.float32
        print("Device: CPU - loading base model and applying dynamic int8 quantization to reduce memory usage.")

    translator = Translator(
        model_name_or_card="seamlessM4T_v2_large",
        vocoder_name_or_card=None,
        device=device,
        dtype=dtype,
        apply_mintox=False,
    )

    if device.type == "cpu":
        # Try dynamic quantization (reduces the memory used by Linear layers).
        try:
            translator.model = torch.quantization.quantize_dynamic(
                translator.model,
                {torch.nn.Linear},
                dtype=torch.qint8,
            )
            print("Dynamic quantization applied successfully on CPU (torch.quantization.quantize_dynamic).")
        except Exception as e:
            print("Dynamic quantization failed or is incompatible with this model; continuing without quantization. Error:", e)

    # Apply the fine-tuned weights for the chosen source language.
    translator = apply_finetuned_weights(translator, source_language, device)

    # Run prediction.
    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
    out_texts, _ = translator.predict(
        input=input_audio,
        task_str="S2TT",
        src_lang=source_language_code,
        tgt_lang=target_language_code,
    )

    # Optionally free memory here (Python may not release it immediately).
    # del translator
    return str(out_texts[0])
# ============================================================
# GRADIO UI
# ============================================================
with gr.Blocks() as demo_s2tt:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                input_audio = gr.Audio(label="Input speech", type="filepath")
                source_language = gr.Dropdown(
                    label="Source language",
                    choices=ASR_TARGET_LANGUAGE_NAMES,
                    value="Telugu",
                )
                target_language = gr.Dropdown(
                    label="Target language",
                    choices=S2TT_TARGET_LANGUAGE_NAMES,
                    value=DEFAULT_TARGET_LANGUAGE,
                )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text")

    btn.click(
        fn=run_s2tt,
        inputs=[input_audio, source_language, target_language],
        outputs=output_text,
        api_name="s2tt",
    )
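# api_name="s2tt" in btn.click above also exposes the handler as a named /s2tt endpoint for
# programmatic use through the Gradio client API.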
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Tabs():
        demo_s2tt.render()
if __name__ == "__main__":
    demo.queue(max_size=50).launch(share=True)