# --- START OF MODIFIED app.py ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    # Minimal stand-in so the @spaces.GPU decorator is a no-op outside a Space.
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Handle both the bare @spaces.GPU form (called with the function
            # itself) and the parameterized @spaces.GPU(...) form.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                func = args[0]
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import math
# --- Configuration ---
MODEL_PATH = "gregniuki/mandarin_thai_ipa"
BATCH_SIZE = 8
# --- Device Setup ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Model Loading ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model.to(device)
    model.eval()
    print(f"Model and tokenizer loaded successfully on device: {device}")
except Exception as e:
    raise RuntimeError(f"FATAL Error loading model/tokenizer: {e}") from e

# --- Helper Function for Chunking ---
def chunk_text(text, max_chars):
    """Split text into chunks of at most max_chars, preferring to break at punctuation."""
    if not text or text.isspace():
        return []
    text = text.strip()
    if len(text) <= max_chars:
        return [text]
    chunks, current_index = [], 0
    while current_index < len(text):
        potential_end_index = min(current_index + max_chars, len(text))
        actual_end_index = potential_end_index
        # Unless this window reaches the end of the text, back up to the last
        # sentence- or clause-final punctuation mark inside it, if any.
        if potential_end_index < len(text):
            punctuation = ".!?。!?,,"
            best_split_pos = -1
            for punc in punctuation:
                best_split_pos = max(best_split_pos, text.rfind(punc, current_index, potential_end_index))
            if best_split_pos != -1:
                actual_end_index = best_split_pos + 1  # include the punctuation mark
        chunk = text[current_index:actual_end_index]
        if chunk and not chunk.isspace():
            chunks.append(chunk.strip())
        current_index = actual_end_index
        if current_index >= len(text):
            break
    return [c for c in chunks if c]
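
# Quick illustration of the chunking behavior (hypothetical input, shown as a
# comment so it is not executed at import time). With max_chars=5, the split
# backs up to the last full-width period inside each 5-character window:
#   chunk_text("你好。世界很大。我们走吧。", 5)
#   -> ['你好。', '世界很大。', '我们走吧。']
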
# --- Main Processing Function ---
@spaces.GPU
def translate_batch(
    text_input,
    max_chars_per_chunk,
    repetition_penalty,
    token_multiplier,  # --- MODIFICATION: Added new parameter for the multiplier ---
    decoding_strategy,
    num_beams,
    length_penalty,
    use_early_stopping,
    temperature,
    top_p,
    progress=gr.Progress(track_tqdm=True),
):
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to process."
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."
    max_chars_per_chunk = int(max_chars_per_chunk)
    all_chunks = []
    for line in lines:
        all_chunks.extend(chunk_text(line, max_chars_per_chunk))
    if not all_chunks:
        return "[Info] No text chunks generated after processing input."
    # Base kwargs; the strategy-specific settings below extend or override them.
    generation_kwargs = {"repetition_penalty": repetition_penalty, "do_sample": False}
    if decoding_strategy == "Beam Search":
        generation_kwargs.update({
            "num_beams": int(num_beams),
            "length_penalty": length_penalty,
            "early_stopping": use_early_stopping,
        })
    elif decoding_strategy == "Sampling":
        generation_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "num_beams": 1,
        })
    print(f"Processing {len(all_chunks)} chunks with strategy: {decoding_strategy}. Args: {generation_kwargs}")
    print(f"  Using token_multiplier: {token_multiplier}")  # Log the multiplier
    all_ipa_outputs = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)
    for i in progress.tqdm(range(num_batches), desc="Processing Batches"):
        batch_start, batch_end = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        try:
            inputs = tokenizer(
                batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(device)
            # --- MODIFICATION: Dynamic max_new_tokens using the slider value ---
            max_input_length = inputs["input_ids"].shape[1]
            generation_kwargs["max_new_tokens"] = min(int(max_input_length * token_multiplier) + 10, 512)
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs, **inputs)
            batch_ipa = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_ipa_outputs.extend(batch_ipa)
        except Exception as e:
            print(f"Error during batch {i+1} processing: {e}")
            all_ipa_outputs.extend([f"[Error in batch {i+1}]"] * len(batch_chunks))
    return "\n".join(all_ipa_outputs)
# --- UI Helper Function for Dynamic Controls ---
def update_decoding_ui(strategy):
    """Show the control group matching the selected decoding strategy, hide the other."""
    if strategy == "Beam Search":
        return gr.update(visible=True), gr.update(visible=False)
    elif strategy == "Sampling":
        return gr.update(visible=False), gr.update(visible=True)

# --- Gradio UI using Blocks ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🇹🇭🇨🇳 Advanced Mandarin & Thai to IPA Converter
        Get the International Phonetic Alphabet (IPA) for Chinese (Mandarin) or Thai text. The model automatically detects the language.
        This interface provides advanced controls for tuning the output quality.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_textbox = gr.Textbox(lines=10, label="Input Text (Mandarin or Thai)", placeholder="Enter text here (e.g., 你好世界 or สวัสดีครับ).")
            submit_button = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            output_textbox = gr.Textbox(lines=10, label="IPA Output", interactive=False)
    with gr.Accordion("Generation & Chunking Controls", open=False):
        with gr.Row():
            max_chars_slider = gr.Slider(minimum=1, maximum=512, step=1, value=36, label="Max Characters per Chunk", info="Splits long lines into smaller pieces for the model.")
            repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.1, label="Repetition Penalty", info=">1.0 discourages repeated words. Prevents stuttering.")
            # --- MODIFICATION: Added the new slider for token multiplier ---
            token_multiplier_slider = gr.Slider(
                minimum=1.0, maximum=4.0, step=0.1, value=2.0,
                label="Output Token Multiplier",
                info="Safety factor for output length. Increase if output is cut off."
            )
        decoding_strategy_radio = gr.Radio(["Beam Search", "Sampling"], value="Beam Search", label="Decoding Strategy", info="Choose the method for generating text.")
        with gr.Group(visible=True) as beam_search_group:
            with gr.Row():
                num_beams_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Beams", info="More beams can yield better quality but are slower.")
                length_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Penalty", info=">1.0 encourages longer output; <1.0 for shorter.")
                early_stopping_checkbox = gr.Checkbox(value=True, label="Early Stopping", info="Stop when a sentence is complete. Recommended.")
        with gr.Group(visible=False) as sampling_group:
            with gr.Row():
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.8, label="Temperature", info="Controls randomness. Lower is more predictable.")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (Nucleus Sampling)", info="Considers a smaller, more probable set of next words.")
    gr.Markdown(f"**Model:** `{MODEL_PATH}` | **Batch Size:** `{BATCH_SIZE}` | **Device:** `{str(device).upper()}`")

    # --- Event Listeners ---
    decoding_strategy_radio.change(fn=update_decoding_ui, inputs=decoding_strategy_radio, outputs=[beam_search_group, sampling_group])
    submit_button.click(
        fn=translate_batch,
        inputs=[
            input_textbox,
            max_chars_slider,
            repetition_penalty_slider,
            token_multiplier_slider,  # --- MODIFICATION: Added slider to inputs list ---
            decoding_strategy_radio,
            num_beams_slider,
            length_penalty_slider,
            early_stopping_checkbox,
            temperature_slider,
            top_p_slider,
        ],
        outputs=output_textbox,
    )

# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
# --- END OF FILE ---