File size: 2,900 Bytes
c9d7167
 
65ed74c
c9d7167
 
65ed74c
6f64d0b
f80fc89
fc36581
65ed74c
 
 
c9d7167
f80fc89
0760540
f80fc89
0760540
 
 
 
 
 
c4b718f
0760540
 
 
 
f80fc89
0760540
 
 
 
 
 
 
 
 
f80fc89
0760540
 
 
 
 
 
 
c4b718f
 
 
0760540
c4b718f
0760540
 
 
 
 
65ed74c
 
f80fc89
 
 
 
 
 
0760540
 
65ed74c
0760540
f80fc89
 
0760540
c4b718f
65ed74c
c9d7167
0760540
c9d7167
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import onnxruntime as ort
from transformers import MarianTokenizer
import gradio as gr

# Load the MarianMT tokenizer from a local directory (must contain the
# tokenizer files exported alongside the ONNX model).
tokenizer_path = "./onnx_model"  # Path to the local tokenizer folder
tokenizer = MarianTokenizer.from_pretrained(tokenizer_path)

# Create an ONNX Runtime inference session over the exported seq2seq graph.
# NOTE(review): translate() feeds encoder inputs and decoder_input_ids to a
# single graph each step — assumes a combined encoder/decoder export; confirm.
onnx_model_path = "./model.onnx"
session = ort.InferenceSession(onnx_model_path)

def translate(texts, max_length=512):
    """Greedily translate a batch of strings with the MarianMT ONNX model.

    Args:
        texts: A string or list of strings to translate.
        max_length: Maximum token length for both the (truncated) source
            encoding and the generated target sequence.

    Returns:
        A list of translated strings, one per input sequence.
    """
    # Tokenize the input texts; ONNX Runtime expects int64 ids and masks.
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # MarianMT starts decoding from the pad token (its decoder start token).
    batch_size = input_ids.shape[0]
    decoder_input_ids = np.full((batch_size, 1), tokenizer.pad_token_id, dtype=np.int64)
    eos_reached = np.zeros(batch_size, dtype=bool)  # Per-sequence completion flags

    # Generate output tokens iteratively (greedy decoding).
    for _ in range(max_length):
        # Full decoder pass every step — this graph exposes no past-key-value
        # caching, so the whole prefix is re-decoded each iteration.
        onnx_outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )

        # Logits for the last generated position: shape (batch_size, vocab_size).
        next_token_logits = onnx_outputs[0][:, -1, :]

        # Greedy decoding: select the highest-scoring token per sequence.
        next_tokens = np.argmax(next_token_logits, axis=-1)

        # BUG FIX: sequences that already produced EOS must stop generating
        # real tokens — otherwise post-EOS argmax tokens (not special tokens)
        # leak into the decoded text. Force pad for finished sequences.
        next_tokens = np.where(eos_reached, tokenizer.pad_token_id, next_tokens)

        # Append the chosen tokens for the next decoding step.
        decoder_input_ids = np.concatenate([decoder_input_ids, next_tokens[:, None]], axis=-1)

        # Update completion flags; stop once every sequence has emitted EOS.
        eos_reached |= next_tokens == tokenizer.eos_token_id
        if eos_reached.all():
            break

    # skip_special_tokens drops the leading pad, the EOS, and any forced pads.
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True)

# Gradio interface
def gradio_translate(input_text):
    """Gradio callback: translate each line of the textbox independently.

    The input is split on newlines (one sentence per line), all lines are
    translated as a single batch, and the results are rejoined with newlines.
    """
    sentences = input_text.strip().split("\n")
    return "\n".join(translate(sentences))

# Build the Gradio UI: a multi-line input textbox wired to gradio_translate,
# with a matching output textbox for the translations.
interface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(lines=5, placeholder="Enter text to translate...", label="Input Text"),
    outputs=gr.Textbox(lines=5, label="Translated Text"),
    title="ONNX English to French Translation",
    description="Translate English text to French using a MarianMT ONNX model.",
)

# Launch the Gradio app (module-level side effect: starts the web server
# and blocks until shutdown).
interface.launch()