Doron Adler commited on
Commit
bffdb0a
1 Parent(s): 68eb283

Added inference examples using the .pt and .onnx models

Browse files
examples/example-onnx-infer.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #Tested with the following Python package versions:
3
+ #optimum 1.2.3.dev0
4
+ #transformers 4.21.0.dev0
5
+ #tokenizers 0.11.6
6
+
7
+ from transformers import AutoTokenizer
8
+ from optimum.onnxruntime import ORTModelForCausalLM
9
+ from optimum.pipelines import pipeline
10
+
11
+
12
+ def main():
13
+ model_name="Norod78/distilgpt2-base-pretrained-he"
14
+
15
+ prompt_text = "שלום, קוראים לי"
16
+ generated_max_length = 192
17
+
18
+ print("Loading model...")
19
+ model = ORTModelForCausalLM.from_pretrained(model_name)
20
+ print('Loading Tokenizer...')
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ text_generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
23
+
24
+ print("Generating text...")
25
+ result = text_generator(prompt_text, num_return_sequences=1, batch_size=1, do_sample=True, top_k=40, top_p=0.92, temperature = 1, repetition_penalty=5.0, max_length = generated_max_length)
26
+
27
+ print("result = " + str(result))
28
+
29
+ if __name__ == '__main__':
30
+ main()
examples/example-pt-infer.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+
4
+ def main():
5
+ model_name="Norod78/distilgpt2-base-pretrained-he"
6
+
7
+ prompt_text = "שלום, קוראים לי"
8
+ generated_max_length = 192
9
+
10
+ print("Loading model...")
11
+ model = AutoModelForCausalLM.from_pretrained(model_name)
12
+ print('Loading Tokenizer...')
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ text_generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
15
+
16
+ print("Generating text...")
17
+ result = text_generator(prompt_text, num_return_sequences=1, batch_size=1, do_sample=True, top_k=40, top_p=0.92, temperature = 1, repetition_penalty=5.0, max_length = generated_max_length)
18
+
19
+ print("result = " + str(result))
20
+
21
+ if __name__ == '__main__':
22
+ main()