alsubari committed on
Commit
5defb24
1 Parent(s): b3250c4

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +41 -0
inference.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Inference script for the alsubari/aragpt2-mega-pos-msa model.

Runs two instruction-style generations over one sample Arabic sentence:
part-of-speech classification and full grammatical analysis (i'rab).
"""

import torch
from transformers import GPT2Tokenizer
from arabert.preprocess import ArabertPreprocessor
from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel
from pyarabic.araby import strip_tashkeel
import pyarabic.trans

model_name = 'alsubari/aragpt2-mega-pos-msa'

# Fall back to CPU so the script still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Reuse model_name instead of repeating the literal (it was previously unused).
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)

# Normalize input text the same way the base AraGPT2 model expects.
arabert_prep = ArabertPreprocessor(model_name='aubmindlab/aragpt2-mega')

# Instruction prompts: [0] full grammatical analysis (i'rab),
# [1] part-of-speech classification of the sentence's words.
prml = ['اعراب الجملة :', ' صنف الكلمات من الجملة :']

text = 'تعلَّمْ من أخطائِكَ'
# Remove diacritics (tashkeel) before preprocessing.
text = arabert_prep.preprocess(strip_tashkeel(text))

generation_args = {
    'pad_token_id': tokenizer.eos_token_id,
    'max_length': 256,
    'num_beams': 20,
    'no_repeat_ngram_size': 3,
    'top_k': 20,
    # Nucleus sampling: keep only the smallest token set whose cumulative
    # probability exceeds 0.1 (the old comment claiming "all tokens with
    # non-zero probability" described top_p=1.0, not 0.1).
    'top_p': 0.1,
    'do_sample': True,
    'repetition_penalty': 2.0,
}


def _generate_answer(instruction: str, sentence: str) -> str:
    """Build the instruction prompt, run generation, and return only the
    text following the 'Answer:' marker."""
    prompt = f'<|startoftext|>Instruction: {instruction} {sentence}<|pad|>Answer:'
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    output_ids = model.generate(input_ids=input_ids, **generation_args)
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return decoded.split('Answer:')[1]


# POS classification (prompt [1]); wrap Latin-script runs in <token> tags so
# Arabic tokens and their English tags are visually delimited.
answer_pose = pyarabic.trans.delimite_language(
    _generate_answer(prml[1], text), start='<token>', end='</token>'
)
print(answer_pose)  # e.g. <token>تعلم : تعلم</token> : Verb <token>من : من</token> : Relative pronoun ...

# Full i'rab analysis (prompt [0]).
print(_generate_answer(prml[0], text))  # e.g. تعلم : تعلم : فعل ، مفرد المخاطب للمذكر ...