# --- START OF MODIFIED app.py ---

try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces()
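    # On Hugging Face Spaces the real @spaces.GPU decorator requests GPU time for the
    # wrapped function; this no-op stand-in keeps the script runnable locally.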

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import math

# --- Configuration ---
MODEL_PATH = "gregniuki/mandarin_thai_ipa"
BATCH_SIZE = 8
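# Chunks are translated in groups of BATCH_SIZE; larger batches run faster on GPU but use more memory.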

# --- Device Setup ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Model Loading ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model.to(device)
    model.eval()
    print(f"Model and tokenizer loaded successfully on device: {device}")
except Exception as e:
    raise RuntimeError(f"FATAL: error loading model/tokenizer: {e}") from e


# --- Helper Function for Chunking ---
def chunk_text(text, max_chars):
    if not text or text.isspace():
        return []
    text = text.strip()
    if len(text) <= max_chars:
        return [text]
    chunks, current_index = [], 0
    while current_index < len(text):
        potential_end_index = min(current_index + max_chars, len(text))
        actual_end_index = potential_end_index
        if potential_end_index < len(text):
            punctuation = ".!?。!?,,"
            best_split_pos = -1
            for punc in punctuation:
                best_split_pos = max(best_split_pos, text.rfind(punc, current_index, potential_end_index))
            if best_split_pos != -1:
                actual_end_index = best_split_pos + 1
        chunk = text[current_index:actual_end_index]
        if chunk and not chunk.isspace():
            chunks.append(chunk.strip())
        current_index = actual_end_index
        if current_index >= len(text):
            break
    return [c for c in chunks if c]
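
# Illustrative example: chunk_text("你好。世界!", 3) splits at the fullwidth
# period and returns ["你好。", "世界!"].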

# --- Main Processing Function ---
@spaces.GPU
def translate_batch(
    text_input,
    max_chars_per_chunk,
    repetition_penalty,
    token_multiplier, # --- MODIFICATION: Added new parameter for the multiplier ---
    decoding_strategy,
    num_beams,
    length_penalty,
    use_early_stopping,
    temperature,
    top_p,
    progress=gr.Progress(track_tqdm=True)
):
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to process."

    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."
    
    max_chars_per_chunk = int(max_chars_per_chunk)
    all_chunks = []
    for line in lines:
        all_chunks.extend(chunk_text(line, max_chars_per_chunk))

    if not all_chunks:
        return "[Info] No text chunks generated after processing input."

    generation_kwargs = {"repetition_penalty": repetition_penalty, "do_sample": False}

    if decoding_strategy == "Beam Search":
        generation_kwargs.update({
            "num_beams": int(num_beams), "length_penalty": length_penalty, "early_stopping": use_early_stopping,
        })
    elif decoding_strategy == "Sampling":
        generation_kwargs.update({
            "do_sample": True, "temperature": temperature, "top_p": top_p, "num_beams": 1,
        })
    
    print(f"Processing {len(all_chunks)} chunks with strategy: {decoding_strategy}. Args: {generation_kwargs}")
    print(f"  Using token_multiplier: {token_multiplier}") # Log the multiplier

    all_ipa_outputs = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)

    for i in progress.tqdm(range(num_batches), desc="Processing Batches"):
        batch_start, batch_end = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        
        try:
            inputs = tokenizer(
                batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(device)
            
            # --- MODIFICATION: Dynamic max_new_tokens using the slider value ---
            max_input_length = inputs["input_ids"].shape[1]
            generation_kwargs["max_new_tokens"] = min(int(max_input_length * token_multiplier) + 10, 512)
            
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs, **inputs)
            
            batch_ipa = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_ipa_outputs.extend(batch_ipa)

        except Exception as e:
            print(f"Error during batch {i+1} processing: {e}")
            all_ipa_outputs.extend([f"[Error in batch {i+1}]"] * len(batch_chunks))

    return "\n".join(all_ipa_outputs)

# --- UI Helper Function for Dynamic Controls ---
def update_decoding_ui(strategy):
    if strategy == "Beam Search":
        return gr.update(visible=True), gr.update(visible=False)
    elif strategy == "Sampling":
        return gr.update(visible=False), gr.update(visible=True)
    # Fallback: leave both groups unchanged if an unexpected value arrives.
    return gr.update(), gr.update()

# --- Gradio UI using Blocks ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🇹🇭🇨🇳 Advanced Mandarin & Thai to IPA Converter
        Convert Chinese (Mandarin) or Thai text to the International Phonetic Alphabet (IPA). The model detects the language automatically.
        This interface provides advanced controls for tuning the output quality.
        """
    )
    
    with gr.Row():
        with gr.Column(scale=1):
            input_textbox = gr.Textbox(lines=10, label="Input Text (Mandarin or Thai)", placeholder="Enter text here (e.g., 你好世界 or สวัสดีครับ).")
            submit_button = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            output_textbox = gr.Textbox(lines=10, label="IPA Output", interactive=False)

    with gr.Accordion("Generation & Chunking Controls", open=False):
        with gr.Row():
            max_chars_slider = gr.Slider(minimum=1, maximum=512, step=1, value=36, label="Max Characters per Chunk", info="Splits long lines into smaller pieces for the model.")
            repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.1, label="Repetition Penalty", info=">1.0 discourages repeated words. Prevents stuttering.")
            # --- MODIFICATION: Added the new slider for token multiplier ---
            token_multiplier_slider = gr.Slider(
                minimum=1.0, maximum=4.0, step=0.1, value=2.0,
                label="Output Token Multiplier",
                info="Safety factor for output length. Increase if output is cut off."
            )
            
        decoding_strategy_radio = gr.Radio(["Beam Search", "Sampling"], value="Beam Search", label="Decoding Strategy", info="Choose the method for generating text.")

        with gr.Group(visible=True) as beam_search_group:
            with gr.Row():
                num_beams_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Beams", info="More beams can yield better quality but are slower.")
                length_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Penalty", info=">1.0 encourages longer output; <1.0 for shorter.")
                early_stopping_checkbox = gr.Checkbox(value=True, label="Early Stopping", info="Stop when a sentence is complete. Recommended.")

        with gr.Group(visible=False) as sampling_group:
            with gr.Row():
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.8, label="Temperature", info="Controls randomness. Lower is more predictable.")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (Nucleus Sampling)", info="Considers a smaller, more probable set of next words.")
    
    gr.Markdown(f"**Model:** `{MODEL_PATH}` | **Batch Size:** `{BATCH_SIZE}` | **Device:** `{str(device).upper()}`")

    # --- Event Listeners ---
    decoding_strategy_radio.change(fn=update_decoding_ui, inputs=decoding_strategy_radio, outputs=[beam_search_group, sampling_group])
    
    submit_button.click(
        fn=translate_batch,
        inputs=[
            input_textbox,
            max_chars_slider,
            repetition_penalty_slider,
            token_multiplier_slider, # --- MODIFICATION: Added slider to inputs list ---
            decoding_strategy_radio,
            num_beams_slider,
            length_penalty_slider,
            early_stopping_checkbox,
            temperature_slider,
            top_p_slider,
        ],
        outputs=output_textbox
    )

# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
# --- END OF FILE ---