ppisljar commited on
Commit
768ff29
1 Parent(s): 5d224b7

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +17 -10
infer.py CHANGED
@@ -8,13 +8,20 @@ tokenizer = AutoTokenizer.from_pretrained('google/byt5-small')
8
  sentence = "Kupil sem bicikel in mu zamenjal stol.".lower()
9
 
10
  ort_session = onnxruntime.InferenceSession("g2p_t5.onnx", providers=["CPUExecutionProvider"])
11
- input_ids = [sentence]
12
- input_encoding = tokenizer(
13
- input_ids, padding='longest', max_length=512, truncation=True, return_tensors='pt',
14
- )
15
- input_ids, attention_mask = input_encoding.input_ids, input_encoding.attention_mask
16
- ort_inputs = {'input_ids': input_ids.numpy()}
17
- ort_outs = ort_session.run(None, ort_inputs)
18
- generated_ids = [ort_outs[0]]
19
- generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
20
- print(generated_texts)
 
 
 
 
 
 
 
 
8
  sentence = "Kupil sem bicikel in mu zamenjal stol.".lower()
9
 
10
  ort_session = onnxruntime.InferenceSession("g2p_t5.onnx", providers=["CPUExecutionProvider"])
11
+
12
+
13
def g2p(sentence, onnx_session, tokenizer):
    """Convert a grapheme sentence to phonemes using the exported ONNX g2p model.

    Parameters:
        sentence: input text (str) to convert.
        onnx_session: onnxruntime.InferenceSession wrapping the exported
            g2p ByT5 model.
        tokenizer: the Hugging Face tokenizer matching the model
            (byt5 in the surrounding script).

    Returns:
        list[str]: decoded output text, one entry per input sentence
        (a single-element list here, since one sentence is passed).
    """
    encoding = tokenizer(
        [sentence], padding='longest', max_length=512, truncation=True,
        return_tensors='pt',
    )
    # BUG FIX: the original body called the module-level global `ort_session`
    # instead of the `onnx_session` parameter, so the session argument was
    # silently ignored. Use the parameter so the function is self-contained.
    # (The unused `attention_mask` local was also dropped.)
    ort_outs = onnx_session.run(None, {'input_ids': encoding.input_ids.numpy()})
    # The first ONNX output holds the generated token ids.
    return tokenizer.batch_decode([ort_outs[0]], skip_special_tokens=True)
24
+
25
+
26
# Run grapheme-to-phoneme conversion on the sample sentence and print it.
print(g2p(sentence, ort_session, tokenizer))