# slo_g2p_byt5 / infer.py
# Last change: "Update infer.py" by ppisljar (verified commit 768ff29).
import onnxruntime
import torch
from transformers import AutoTokenizer
# Module-level resources used by the demo call at the bottom of this script.

# Tokenizer of the google/byt5-small checkpoint (presumably the base model the
# ONNX graph was fine-tuned from -- TODO confirm).
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

# Example input. It is lower-cased before inference; presumably the model was
# trained on lower-case text -- confirm.
sentence = "Kupil sem bicikel in mu zamenjal stol."
sentence = sentence.lower()

# The relative path resolves against the current working directory;
# inference is restricted to the CPU execution provider.
ort_session = onnxruntime.InferenceSession(
    "g2p_t5.onnx",
    providers=["CPUExecutionProvider"],
)
def g2p(sentence, onnx_session, tokenizer):
    """Run ``onnx_session`` on one sentence and decode the model output.

    The name (and the ``g2p_t5.onnx`` model this script loads) suggests
    grapheme-to-phoneme conversion -- confirm against the model card.

    Args:
        sentence: Input text. It is passed to the tokenizer unchanged; any
            normalisation (e.g. the ``.lower()`` the script applies to its
            example) is the caller's responsibility.
        onnx_session: Session used for inference (an
            ``onnxruntime.InferenceSession`` in this script). Only its
            ``run(output_names, input_feed)`` method is used.
        tokenizer: Hugging Face tokenizer used both to encode ``sentence``
            and to decode the model output.

    Returns:
        The list produced by ``tokenizer.batch_decode`` for the first model
        output, with special tokens skipped.
    """
    batch = [sentence]
    encoding = tokenizer(
        batch,
        padding="longest",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    # Only input_ids are fed. With a single sentence and padding="longest"
    # there is no padding, so the mask would be all ones anyway. Presumably
    # the exported graph declares no attention_mask input -- TODO confirm
    # against the ONNX model's inputs.
    ort_inputs = {"input_ids": encoding.input_ids.numpy()}
    # Use the caller-supplied session rather than a module-level global, so
    # g2p works when imported and honours whichever session it is given.
    ort_outs = onnx_session.run(None, ort_inputs)
    # NOTE(review): wrapping ort_outs[0] in a list treats it as ONE sequence
    # of token ids, i.e. this assumes the graph returns a 1-D id array --
    # confirm against the exported model's output shape.
    generated_ids = [ort_outs[0]]
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
if __name__ == "__main__":
    # Demo: run the example sentence through the model and print the decoded
    # output. Guarded so that importing this module (e.g. to reuse g2p) does
    # not run inference or print as a side effect.
    result = g2p(sentence, ort_session, tokenizer)
    print(result)