bhavana committed on
Commit 4ba628e · 1 Parent(s): 6a6728e

Add application file

Files changed (3)
  1. app.py +194 -0
  2. langlist_slt.py +26 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env python
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ from __future__ import annotations
+
+ import getpass
+ import os
+ import pathlib
+ from functools import lru_cache
+ from typing import Any, Dict
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torchaudio
+ from fairseq2.assets import InProcAssetMetadataProvider, asset_store
+ from huggingface_hub import snapshot_download
+ from seamless_communication.inference import Translator
+
+ from langlist_slt import (
+     LANGUAGE_NAME_TO_CODE,
+     S2TT_TARGET_LANGUAGE_NAMES,
+     ASR_TARGET_LANGUAGE_NAMES,
+ )
+ os.environ["GRADIO_TEMP_DIR"] = "/data/bhavana/app/tmp"
+ os.makedirs(os.environ["GRADIO_TEMP_DIR"], exist_ok=True)
+ user = getpass.getuser()  # note: not portable on Windows; currently unused
+
+ CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/data/bhavana/app/models"))
+ if not CHECKPOINTS_PATH.exists():
+     snapshot_download(repo_id="facebook/hf-seamless-m4t-medium", repo_type="model", local_dir=CHECKPOINTS_PATH)
+
+ # Resolve asset cards against a local "demo" environment.
+ asset_store.env_resolvers.clear()
+ asset_store.env_resolvers.append(lambda: "demo")
+ demo_metadata = [
+     {
+         "name": "seamlessM4T_v2_medium@demo",
+         "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_medium.pt",
+         "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
+     },
+     {
+         "name": "vocoder_v2@demo",
+         "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
+     },
+ ]
+ asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
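+
+ # With the "demo" resolver above, fairseq2 expands a card name like "vocoder_v2"
+ # to "vocoder_v2@demo" and serves it from the in-process metadata registered
+ # here. Note that load_translator_for_language below requests
+ # "seamlessM4T_v2_large", which has no "@demo" card registered, so it appears to
+ # be resolved via its default card (i.e., fetched remotely) instead.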
+
+ DESCRIPTION = """\
+ # IIITH-SLT
+
+ End-to-End Speech Translation demo for low-resource Indian languages using weakly
+ labeled data. Supports ST models for Bengali-Hindi, Malayalam-Hindi, Odia-Hindi,
+ and Telugu-Hindi, trained on the Shrutilipi-anuvaad dataset.
+
+ [Paper](https://arxiv.org/pdf/2506.16251)
+ """
+
+ CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
+
+ AUDIO_SAMPLE_RATE = 16000.0
+ MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
+ DEFAULT_TARGET_LANGUAGE = "Hindi"
+
+ if torch.cuda.is_available():
+     device = torch.device("cuda:0")
+     dtype = torch.float16
+ else:
+     device = torch.device("cpu")
+     dtype = torch.float32
+
+ # Per-source-language fine-tuned checkpoints (X -> Hindi).
+ FINETUNED_MODEL_MAP = {
+     "Telugu": "/data/aishwarya/seamless_communication/src/seamless_communication/cli/m4t_original/finetune_original/icassp/checkpoint_te_hi_v5.pt",
+     "Malayalam": "/data/aishwarya/seamless_communication/src/seamless_communication/cli/m4t_original/finetune_original/icassp/checkpoint_ml_hi_v5.pt",
+     "Odia": "/data/aishwarya/seamless_communication/src/seamless_communication/cli/m4t_original/finetune_original/icassp/checkpoint_od_hi_v5.pt",
+     "Bengali": "/data/aishwarya/seamless_communication/src/seamless_communication/cli/m4t_original/finetune_original/icassp/checkpoint_bn_hi_v5.pt",
+     # Add more languages here as checkpoints become available.
+ }
+
+
+
+ @lru_cache(maxsize=None)
+ def load_translator_for_language(language: str) -> Translator:
+     # Load the base model (Meta's SeamlessM4T).
+     translator = Translator(
+         model_name_or_card="seamlessM4T_v2_large",
+         vocoder_name_or_card=None,
+         device=device,
+         dtype=dtype,
+         apply_mintox=False,
+     )
+
+     # Overlay the language-specific fine-tuned weights, if a checkpoint exists.
+     ckpt_path = FINETUNED_MODEL_MAP.get(language)
+     if ckpt_path and os.path.exists(ckpt_path):
+         print(f"Loading fine-tuned checkpoint for {language} from {ckpt_path}")
+         saved_model = torch.load(ckpt_path, map_location=device)["model"]
+         saved_model = {k.replace("module.", ""): v for k, v in saved_model.items()}
+
+         def _select_keys(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
+             return {key.replace(prefix, ""): value for key, value in state_dict.items() if key.startswith(prefix)}
+
+         translator.model.speech_encoder_frontend.load_state_dict(_select_keys(saved_model, "model.speech_encoder_frontend."))
+         translator.model.speech_encoder.load_state_dict(_select_keys(saved_model, "model.speech_encoder."))
+
+         assert translator.model.text_decoder_frontend is not None
+         translator.model.text_decoder_frontend.load_state_dict(_select_keys(saved_model, "model.text_decoder_frontend."))
+
+         assert translator.model.text_decoder is not None
+         translator.model.text_decoder.load_state_dict(_select_keys(saved_model, "model.text_decoder."))
+
+         assert translator.model.final_proj is not None
+         translator.model.final_proj.load_state_dict(_select_keys(saved_model, "model.final_proj."))
+
+     return translator
+
+
+ def preprocess_audio(input_audio: str) -> None:
+     """Resample to 16 kHz and truncate over-long input, overwriting the file in place."""
+     arr, org_sr = torchaudio.load(input_audio)
+     new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=int(AUDIO_SAMPLE_RATE))
+     max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
+     if new_arr.shape[1] > max_length:
+         new_arr = new_arr[:, :max_length]
+         gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds are used.")
+     torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
+
+
+ def run_s2tt(input_audio: str, source_language: str, target_language: str) -> str:
+     preprocess_audio(input_audio)
+
+     translator = load_translator_for_language(source_language)
+
+     source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
+     target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
+     # Speech-to-text translation; no vocoder is loaded, so there is no speech output.
+     out_texts, _ = translator.predict(
+         input=input_audio,
+         task_str="S2TT",
+         src_lang=source_language_code,
+         tgt_lang=target_language_code,
+     )
+     return str(out_texts[0])
+
+
+ with gr.Blocks() as demo_s2tt:
+     with gr.Row():
+         with gr.Column():
+             with gr.Group():
+                 input_audio = gr.Audio(label="Input speech", type="filepath")
+                 source_language = gr.Dropdown(
+                     label="Source language",
+                     choices=ASR_TARGET_LANGUAGE_NAMES,
+                     value="Telugu",
+                 )
+                 target_language = gr.Dropdown(
+                     label="Target language",
+                     choices=S2TT_TARGET_LANGUAGE_NAMES,
+                     value=DEFAULT_TARGET_LANGUAGE,
+                 )
+                 btn = gr.Button("Translate")
+         with gr.Column():
+             output_text = gr.Textbox(label="Translated text")
+
+     gr.Examples(
+         examples=[],
+         inputs=[input_audio, source_language, target_language],
+         outputs=output_text,
+         fn=run_s2tt,
+         cache_examples=CACHE_EXAMPLES,
+         api_name=False,
+     )
+
+     btn.click(
+         fn=run_s2tt,
+         inputs=[input_audio, source_language, target_language],
+         outputs=output_text,
+         api_name="s2tt",
+     )
+
+
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(
+         value="Duplicate Space for private use",
+         elem_id="duplicate-button",
+         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+     )
+
+     with gr.Tabs():
+         with gr.Tab(label="S2TT"):
+             demo_s2tt.render()
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=50).launch(share=True)
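Because the Translate button is registered with api_name="s2tt", the app can also be driven programmatically. A minimal sketch using gradio_client (assumed pinned to a version compatible with gradio~=4.5; the endpoint URL and audio file name below are placeholders):

from gradio_client import Client

client = Client("http://localhost:7860")  # or the share URL printed at launch
result = client.predict(
    "sample_te.wav",  # input_audio: hypothetical 16 kHz Telugu recording
    "Telugu",         # source_language
    "Hindi",          # target_language
    api_name="/s2tt",
)
print(result)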
langlist_slt.py ADDED
@@ -0,0 +1,26 @@
+ # Language dict
+ language_code_to_name = {
+     "ben": "Bengali",
+     "hin": "Hindi",
+     "mal": "Malayalam",
+     "tel": "Telugu",
+     "odi": "Odia",
+ }
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
+
+ # Source languages: S2ST / S2TT / ASR do not need a source language;
+ # T2TT / T2ST use this list.
+ text_source_language_codes = [
+     "tel",
+     "odi",
+     "mal",
+     "ben",
+     "hin",
+ ]
+ TEXT_SOURCE_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in text_source_language_codes])
+
+ # S2TT / T2TT / ASR
+ S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
+ ASR_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio~=4.5.0
+ omegaconf~=2.3.0
+ torch~=2.1.0
+ torchaudio~=2.1.0
+ fairseq2~=0.2.0
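+ # app.py also imports seamless_communication and huggingface_hub.
+ # huggingface_hub comes in as a dependency of gradio, but seamless_communication
+ # is typically installed from source, e.g.:
+ #   pip install git+https://github.com/facebookresearch/seamless_communication.git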