try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")
    # Define a dummy decorator that does nothing if 'spaces' isn't available
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            # Support both bare @spaces.GPU and parameterized @spaces.GPU(...) usage
            if len(args) == 1 and callable(args[0]) and not kwargs:
                print(f"Note: Dummy @GPU decorator used for function '{args[0].__name__}'.")
                return args[0]
            def decorator(func):
                # This dummy decorator just returns the original function
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator
    spaces = DummySpaces() # Create an instance of the dummy class
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TFAutoModelForSeq2SeqLM
import torch # Or import tensorflow as tf
import os
import math
# Requires a Gradio version that supports the spaces.GPU decorator when running on Spaces
from huggingface_hub import hf_hub_download

# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl-v2" # Use your actual model path
MAX_WORDS_PER_CHUNK = 40
BATCH_SIZE = 4 # Adjust based on GPU memory / desired throughput

# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID
    if HF_AUTH_TOKEN is None:
        print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
    else:
        print("HF_TOKEN found. Using token for model loading.")
else:
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None # Don't use token for local paths


# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )

    # --- Choose the correct model class ---
    # PyTorch (most common)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )
    model.to(device) # Move model to the determined device
    model.eval() # Set model to evaluation mode
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if your model is TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,
    #     trust_remote_code=False
    # )
    # # TF device placement is often automatic or managed via strategies
    # print("Using TensorFlow model.")

    print("Model and tokenizer loaded successfully.")

except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
         error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
         error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    # Raise error to prevent app launch if model loading fails
    raise RuntimeError(error_message)


# --- Helper Function for Chunking Sentences ---
def chunk_sentence(sentence, max_words):
    """Splits a sentence into chunks of max_words."""
    if not sentence or sentence.isspace():
        return []
    words = sentence.split() # Simple space splitting
    chunks = []
    current_chunk = []
    for word in words:
        current_chunk.append(word)
        if len(current_chunk) >= max_words:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
    if current_chunk: # Add any remaining words
        chunks.append(" ".join(current_chunk))
    return chunks
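
# Illustrative behaviour of chunk_sentence (for reference only, not executed):
#   chunk_sentence("one two three four five", 2) -> ["one two", "three four", "five"]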

# --- Define the BATCH translation function ---
# Add GPU decorator for Spaces (adjust duration if needed)
@spaces.GPU
def translate_batch(text_input):
    """
    Translates multi-line input text using batching and sentence chunking.
    Assumes auto-detection of language direction (no prefixes).
    """
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print(f"Received input block for batch translation.")

    # 1. Split input into potential sentences (lines) and clean
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."

    # 2. Chunk long sentences
    all_chunks = []
    for line in lines:
        sentence_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK)
        all_chunks.extend(sentence_chunks)

    if not all_chunks:
        return "[Info] No text chunks generated after processing input."

    print(f"Processing {len(all_chunks)} chunks in batches...")

    # 3. Process chunks in batches
    all_translations = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)

    for i in range(num_batches):
        batch_start = i * BATCH_SIZE
        batch_end = batch_start + BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_end]
        print(f"  Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)")

        # Tokenize the batch
        try:
            # PyTorch
            inputs = tokenizer(batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            # # TensorFlow
            # inputs = tokenizer(batch_chunks, return_tensors="tf", padding=True, truncation=True, max_length=512)

        except Exception as e:
            print(f"Error during batch tokenization: {e}")
            # Return partial results or a general error
            return "[Error] Tokenization failed for a batch."

        # Generate translations for the batch
        try:
            # PyTorch
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=1024,
                    num_beams=8,
                    early_stopping=False
                )
            # outputs shape: [batch_size, sequence_length]

            # # TensorFlow
            # outputs = model.generate(
            #     inputs['input_ids'],
            #     attention_mask=inputs['attention_mask'],
            #     max_length=512,
            #     num_beams=4,
            #     early_stopping=True
            # )
            # outputs is typically a tensor of shape [batch_size, sequence_length]

            # Decode the batch results
            batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_translations.extend(batch_translations)

        except Exception as e:
            print(f"Error during batch generation/decoding: {e}")
            # Return partial results or a general error
            return "[Error] Translation generation failed for a batch."

    # 4. Join translated chunks back together
    # Simple newline join; this may not perfectly preserve the original structure if a line was split into chunks.
    final_output = "\n".join(all_translations)
    print("Batch translation finished.")
    return final_output
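
# Optional quick sanity check when running this file directly (hypothetical example input;
# the function is normally invoked through the Gradio interface defined below):
#   print(translate_batch("Dzień dobry, jak się masz?\nThis is an English sentence."))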


# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
    lines=10, # Allow more lines for batch input
    label="Input Text (Polish or English - One sentence per line recommended)",
    placeholder="Enter text here. Longer sentences will be split into chunks (max 20 words)."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)

# Interface definition
interface = gr.Interface(
    fn=translate_batch,          # Use the batch function
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Batch ByT5 Translator (Auto-Detect, Chunking)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nLong sentences are automatically split into chunks of max {MAX_WORDS_PER_CHUNK} words.",
    article="Enter text (ideally one sentence per line). Click Submit to translate all lines.",
    allow_flagging="never"
)

# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for a public link if running locally, not needed on Spaces
    interface.launch()