# English-to-French translation demo: MarianMT exported to ONNX, served with Gradio.
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr
# Load the tokenizer from the local folder
# (the folder must contain the MarianTokenizer vocab/config files saved
#  alongside the ONNX export).
tokenizer_path = "./onnx_model" # Path to the local tokenizer folder
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)
# Load the ONNX model
# NOTE(review): the model must expose "input_ids", "attention_mask" and
# "decoder_input_ids" inputs, as consumed by translate() below.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)
def translate(texts, max_length=512):
    """Translate a batch of texts with the ONNX MarianMT model (greedy decoding).

    Args:
        texts: A string or list of strings to translate.
        max_length: Maximum token length for both the (truncated) encoder
            input and the number of decoding steps.

    Returns:
        A list of translated strings, one per input text.
    """
    # Tokenize the input texts into padded numpy int64 arrays
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # MarianMT uses the pad token as the decoder start token
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)  # Per-sequence completion flags

    # Greedy decoding loop: one new token per iteration
    for _ in range(max_length):
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )
        # Logits at the last decoder position — shape (batch_size, vocab_size)
        next_token_logits = onnx_outputs[0][:, -1, :]
        # Greedy decoding: pick the highest-probability token per sequence
        next_tokens = np.argmax(next_token_logits, axis=-1)

        # FIX: sequences that already produced EOS must keep emitting pad tokens.
        # batch_decode(skip_special_tokens=True) does NOT cut text after EOS, so
        # without this mask, post-EOS model output leaks into finished sequences.
        next_tokens = np.where(eos_reached, tokenizer.pad_token_id, next_tokens)

        # Append the new tokens for the next decoding step
        decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=-1)

        # Mark sequences that just produced EOS; stop once all are done
        eos_reached |= next_tokens == tokenizer.eos_token_id
        if eos_reached.all():
            break

    # skip_special_tokens strips the leading pad, trailing pads and EOS
    translations = tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)
    return translations
# Gradio interface
def gradio_translate(input_text):
    """Gradio callback: translate the textbox contents line by line.

    Each line is treated as one sentence; the translated lines are
    re-joined with newlines in the same order.
    """
    sentences = input_text.strip().split("\n")
    return "\n".join(translate(sentences))
# Create the Gradio interface: multiline textbox in, multiline textbox out
interface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
    outputs=gr.Textbox(lines=5, label="Translated Text"),
    title="ONNX English to French Translation",
    description="Translate English text to French using a MarianMT ONNX model.",
)
# Launch the Gradio app (FIX: removed trailing "|" scrape artifact that broke the syntax)
interface.launch()