import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load the tokenizer from the local folder
tokenizer_path = "./onnx_model"  # Path to the local tokenizer folder
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

# Load the ONNX model
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)


def translate(texts, max_length=512):
    """Translate a batch of texts with greedy decoding over the ONNX model.

    Args:
        texts: A string or a list of strings to translate.
        max_length: Maximum source length in tokens (longer inputs are
            truncated) and also the maximum number of decoding steps.

    Returns:
        A list of translated strings, one per input text.
    """
    # The tokenizer cannot build tensors for an empty batch.
    if isinstance(texts, (list, tuple)) and not texts:
        return []

    # Tokenize the input texts
    inputs = tokenizer(
        texts,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    batch_size = input_ids.shape[0]
    pad_id = tokenizer.pad_token_id
    eos_id = tokenizer.eos_token_id

    # Decoding starts from the pad token. NOTE(review): this assumes the
    # model's decoder_start_token_id equals pad_token_id, as is usual for
    # MarianMT; confirm against the exported model's config.
    decoder_input_ids = np.full((batch_size, 1), pad_id, dtype=np.int64)
    finished = np.zeros(batch_size, dtype=bool)  # rows that emitted EOS

    for _ in range(max_length):
        # No KV cache is used: the whole decoder prefix is re-fed every step.
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Assumes the first model output is the decoder logits, indexed as
        # (batch, decoder_len, vocab); take the last position's scores.
        next_token_logits = onnx_outputs[0][:, -1, :]

        # Greedy decoding: pick the highest-scoring token.
        next_tokens = np.argmax(next_token_logits, axis=-1).astype(np.int64)

        # Rows that already produced EOS are padded instead of extended.
        # Otherwise text generated after EOS ends up in their output,
        # because skip_special_tokens only removes the EOS itself.
        next_tokens = np.where(finished, pad_id, next_tokens)

        decoder_input_ids = np.concatenate(
            [decoder_input_ids, next_tokens[:, None]], axis=-1
        )

        finished |= next_tokens == eos_id
        if finished.all():
            break

    # Pad/EOS tokens are dropped by skip_special_tokens.
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)


# Gradio interface
def gradio_translate(input_text):
    """Translate multi-line text, treating each line as one sentence.

    Blank lines are kept as blank lines (and not sent to the model), so the
    output stays line-aligned with the input. Empty input returns "".
    """
    lines = (input_text or "").strip().splitlines()
    to_translate = [line for line in lines if line.strip()]
    if not to_translate:
        return ""

    translated = iter(translate(to_translate))
    # Join the translations into a single string with line breaks,
    # re-inserting blank lines where the input had them.
    return "\n".join(next(translated) if line.strip() else "" for line in lines)


# Create the Gradio interface
interface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
    outputs=gr.Textbox(lines=5, label="Translated Text"),
    title="ONNX English to French Translation",
    description="Translate English text to French using a MarianMT ONNX model.",
)

# Launch the Gradio app
interface.launch()