# ONNX English-to-French translation demo (Gradio app for a Hugging Face Space).
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load the MarianMT tokenizer from the local folder exported alongside the ONNX model.
tokenizer_path = "./onnx_model"  # Path to the local tokenizer folder
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

# Create an ONNX Runtime inference session for the exported seq2seq model.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
def translate(texts, max_length=512):
    """Translate a batch of texts with greedy decoding over the ONNX seq2seq model.

    Args:
        texts: A string or list of strings to translate (one sentence each).
        max_length: Maximum number of source tokens kept by truncation and the
            maximum number of target tokens generated.

    Returns:
        A list of translated strings, one per input text.
    """
    # Tokenize into padded int64 arrays, the dtype ONNX Runtime expects.
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # MarianMT uses the pad token as the decoder start token.
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)  # Track which sequences have finished

    # Greedy decoding: one token per iteration until every sequence emits EOS.
    for _ in range(max_length):
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Logits for the last decoder position only: shape (batch_size, vocab_size).
        next_token_logits = onnx_outputs[0][:, -1, :]

        # Greedy choice: the highest-probability token per sequence.
        next_tokens = np.argmax(next_token_logits, axis=-1)

        # Bug fix: once a sequence has produced EOS, keep feeding it pad tokens
        # instead of whatever argmax yields, so no stray text appears after the
        # sentence when longer batch-mates are still generating.
        next_tokens = np.where(eos_reached, tokenizer.pad_token_id, next_tokens)

        # Append this step's tokens for the next decoder pass.
        decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=-1)
        eos_reached |= next_tokens == tokenizer.eos_token_id

        # Stop as soon as every sequence in the batch has finished.
        if eos_reached.all():
            break

    # skip_special_tokens drops the start/pad/eos tokens from the decoded text.
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
# Gradio interface
def gradio_translate(input_text):
    """Gradio callback: translate newline-separated sentences and rejoin them."""
    # The textbox holds one sentence per line; strip surrounding whitespace first.
    sentences = input_text.strip().split("\n")
    translated_lines = translate(sentences)
    # Re-assemble the per-line translations into a single output string.
    return "\n".join(translated_lines)
# Build the Gradio UI around the translation callback.
input_box = gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text")
output_box = gr.Textbox(lines=5, label="Translated Text")
interface = gr.Interface(
    fn=gradio_translate,
    inputs=input_box,
    outputs=output_box,
    title="ONNX English to French Translation",
    description="Translate English text to French using a MarianMT ONNX model.",
)

# Launch the Gradio app.
interface.launch()