File size: 2,236 Bytes
6286e43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe50c4c
6286e43
 
 
fe50c4c
 
 
6286e43
fe50c4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6286e43
fe50c4c
 
 
 
 
 
 
 
 
 
 
6286e43
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
import gradio as gr

# Load the ONNX model and tokenizer once at import time, shared by all requests.
model_path = "model.onnx"  # exported encoder-decoder graph; must sit next to this script
translation_session = ort.InferenceSession(model_path)
# Tokenizer matching the exported model (downloads from the HF hub on first run).
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")

def translate_text(input_text):
    """Greedily translate English ``input_text`` to French with the ONNX model.

    Runs the exported encoder-decoder graph one decoder step at a time,
    always taking the argmax token (greedy decoding), until EOS is produced
    or the 512-token output cap is reached. No KV cache is used: the full
    decoder prefix is re-run on every step.

    Args:
        input_text: English source text (a single string).

    Returns:
        The decoded French translation as a string.
    """
    # Tokenize input text into numpy tensors, truncated to the model's max length.
    tokenized_input = translation_tokenizer(
        input_text, return_tensors="np", padding=True, truncation=True, max_length=512
    )

    # Prepare encoder inputs (the ONNX graph expects int64 ids/masks).
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)

    # Pick the decoder start token. BUGFIX: the previous
    # `cls_token_id or pad_token_id` used truthiness, so a valid CLS id of 0
    # would silently fall through to the pad token — compare to None instead.
    # (Marian tokenizers have no CLS token; decoding starts from the pad token.)
    decoder_start_token_id = translation_tokenizer.cls_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = translation_tokenizer.pad_token_id
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

    # Iteratively generate output tokens.
    translated_tokens = []
    for _ in range(512):  # hard cap on output length
        # Run one decoder step with the full prefix.
        outputs = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )

        # Argmax over the vocabulary at the last decoder position;
        # cast the numpy scalar to a plain int for comparison and decoding.
        next_token_id = int(np.argmax(outputs[0][0, -1, :]))

        # Stop BEFORE recording the end-of-sequence token so it never
        # reaches the decode step (previously it was appended first and
        # only hidden by skip_special_tokens).
        if next_token_id == translation_tokenizer.eos_token_id:
            break
        translated_tokens.append(next_token_id)

        # Extend the decoder prefix so the next step conditions on this token.
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )

    # Decode the accumulated token ids back to text.
    return translation_tokenizer.decode(translated_tokens, skip_special_tokens=True)

# Create a Gradio interface: a single text box in, the translated text out.
interface = gr.Interface(
    fn=translate_text,
    inputs="text",
    outputs="text",
    title="Frenchizer Translation Model",
    description="Translate text from English to French using an ONNX model."
)

# Launch the Gradio app (starts a local web server and blocks until shutdown).
interface.launch()