# Spaces: Sleeping
import numpy as np | |
import onnxruntime as ort | |
from transformers import AutoTokenizer | |
import gradio as gr | |
# Load the ONNX model and tokenizer once at import time; translate_text reads
# these module-level objects on every call.
model_path = "model.onnx"
# Pass `providers` explicitly: onnxruntime builds that ship extra execution
# providers (e.g. onnxruntime-gpu) raise a ValueError when it is omitted
# (ORT >= 1.9). CPU is available in every build.
translation_session = ort.InferenceSession(
    model_path, providers=["CPUExecutionProvider"]
)
translation_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
def translate_text(input_text, max_new_tokens=512):
    """Translate English text to French by greedy decoding with the ONNX model.

    The ONNX graph is run once per generated token, each time with the full
    encoder inputs and the whole decoder prefix generated so far (no key/value
    cache is fed back), so the cost grows with the length of the output.

    Args:
        input_text: English source text. Presumably a single string coming
            from the Gradio textbox; the decoder prefix below is built for a
            batch of size 1.
        max_new_tokens: Maximum number of decoding steps (default 512, the
            previously hard-coded limit).

    Returns:
        The decoded translation with special tokens removed, or "" when the
        input is None or empty/whitespace-only.

    Raises:
        ValueError: If the tokenizer defines neither a CLS nor a pad token,
            so no decoder start token can be chosen.
    """
    # An empty textbox would otherwise be encoded as just the EOS token and
    # the model would still generate something; return nothing instead.
    if input_text is None or (isinstance(input_text, str) and not input_text.strip()):
        return ""

    # Tokenize input text (encoder side).
    tokenized_input = translation_tokenizer(
        input_text, return_tensors="np", padding=True, truncation=True, max_length=512
    )
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)

    # Prefer the CLS token as the decoder start token, falling back to the pad
    # token (for this Marian tokenizer cls_token_id is presumably None — TODO
    # confirm). Compare against None explicitly so a valid id of 0 is kept.
    decoder_start_token_id = translation_tokenizer.cls_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = translation_tokenizer.pad_token_id
    if decoder_start_token_id is None:
        raise ValueError("tokenizer defines neither cls_token_id nor pad_token_id")

    eos_token_id = translation_tokenizer.eos_token_id
    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)

    translated_tokens = []
    for _ in range(max_new_tokens):
        logits = translation_session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids,
            },
        )[0]
        # First output indexed as (batch, decoder position, vocab): take the
        # most likely token at the last decoder position (greedy decoding).
        next_token_id = int(np.argmax(logits[0, -1, :]))
        translated_tokens.append(next_token_id)
        if next_token_id == eos_token_id:
            break
        # Feed the chosen token back in as the next decoder input.
        decoder_input_ids = np.concatenate(
            [decoder_input_ids, np.array([[next_token_id]], dtype=np.int64)], axis=1
        )

    # skip_special_tokens drops the trailing EOS (and any pad) from the text.
    return translation_tokenizer.decode(translated_tokens, skip_special_tokens=True)
# Create a Gradio interface: one text box in, translated text out.
interface = gr.Interface(
    fn=translate_text,
    inputs="text",
    outputs="text",
    title="Frenchizer Translation Model",
    description="Translate text from English to French using an ONNX model.",
)

# Launch the Gradio app only when this file is executed as a script, so that
# importing the module (e.g. from tests or another app) does not start a server.
if __name__ == "__main__":
    interface.launch()