Moo committed on
Commit
53f9185
•
1 Parent(s): 47102e9

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +40 -0
inference.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from transformers import AutoTokenizer, GPTJForCausalLM
4
+
5
def get_sent(sent: str, max_length: int = 786) -> str:
    """Generate a continuation for ``sent`` using the module-level model.

    The input is wrapped as ``[BOS]<sent>[EOS][BOS]``, encoded with the
    module-level ``tokenizer`` and passed to ``model.generate`` with
    sampling, 5 beams and a 4-gram no-repeat constraint.

    Args:
        sent: Text to build the prompt from.
        max_length: Value forwarded to ``generate`` as ``max_length``.
            Defaults to 786, the value that used to be hard-coded here.

    Returns:
        The decoded ids of the first returned sequence after dropping the
        first ``len(prompt ids)`` entries, with special tokens kept and
        surrounding whitespace stripped.

    Note:
        Relies on ``tokenizer`` and ``model`` existing as module globals;
        they are not parameters of this function.
    """
    input_text = '[BOS]' + sent + '[EOS][BOS]'

    # Tokenize once and reuse the ids for both the prompt length and the
    # input tensor instead of encoding the same text twice.
    prompt_ids = tokenizer.encode(input_text)
    input_length = len(prompt_ids)

    # Place the input on whatever device the model lives on rather than
    # hard-coding 'cuda'.
    input_ids = torch.tensor(prompt_ids).unsqueeze(0).to(model.device)

    output_sentence = model.generate(
        input_ids,
        do_sample=True,
        max_length=int(max_length),
        num_return_sequences=1,
        no_repeat_ngram_size=4,
        num_beams=5,
        early_stopping=True,
    )

    # Skip the leading prompt-length ids so only the remainder is decoded.
    generated_sequence = output_sentence[0].tolist()[input_length:]
    decoded_sent = tokenizer.decode(
        generated_sequence, skip_special_tokens=False
    ).strip()
    return decoded_sent
22
+
23
+
24
if __name__ == "__main__":
    # `tokenizer` and `model` are deliberately bound at module level:
    # get_sent() reads them as globals.
    tokenizer = AutoTokenizer.from_pretrained(
        'kakaobrain/kogpt',
        revision="KoGPT6B-ryan1.5b",
    )
    # Load half-precision weights from the current directory, then move the
    # model to the GPU.
    model = GPTJForCausalLM.from_pretrained('./', torch_dtype=torch.float16)
    model.cuda()

    # Sample prompt. Adjacent literals are concatenated with no separator,
    # which is the same value as a multi-line string with every newline
    # removed.
    prompt = (
        "λ²ˆμ•„μ›ƒμ΄ μ˜¨κ²ƒμ„ μΈμ •ν•˜κΈ° μ‹«μ–΄μš”."
        "2λ…„μ „λΆ€ν„° μš°μšΈμ¦μ— μ‹œλ‹¬λ ΈμŠ΅λ‹ˆλ‹€. 2λ…„λ™μ•ˆ μ €λ₯Ό μ•Œμ•„κ°€λ €κ³  μ• μ“°κ³  λ™μ‹œμ— μ €μ˜ 미래λ₯Ό μœ„ν•΄ λŠμž„μ—†μ΄ 자격증 곡뢀와 학업을 병행해 μ™”μ–΄μš”."
        "ν•˜μ§€λ§Œ λΆˆμ•ˆμ •ν•œ 심리 μƒνƒœλ‘œ μΈν•˜μ—¬ 제 꿈과 μˆ˜λ§Žμ€ κΈ°νšŒλ“€μ„ μž‘μ§€ λͺ»ν•˜κ³  λ„λ§μ³λ²„λ ΈμŠ΅λ‹ˆλ‹€. μ–΄μ§Έμ„œμΌκΉŒμš” μ™œ μ €λŠ” 남듀보닀 λ²„ν‹°λŠ” 힘이 μ•½ν•œκ±ΈκΉŒμš”."
        "λ‹€μ‹œ νž˜μ„λ‚΄μ„œ 도전도 해보고 μ—¬λŸ¬ 일을 해보며 λͺ¨λ“  μ—λ„ˆμ§€λ₯Ό μŸμ•„λΆ€μ—ˆλ”λ‹ˆ μ–΄λŠλ‚  ν•œκΈ€μ΄ 잘 νŒŒμ•…μ΄ μ•ˆλ˜κ³  κ°„λ‹¨ν•œ 결정을 λ‚΄λ¦¬λŠ” 것이 μ–΄λ €μ›Œ κ²°κ΅­ 직μž₯을 또 κ·Έλ§Œλ‘κ²Œ λ˜μ—ˆμŠ΅λ‹ˆλ‹€. λ‚˜μ•½ν•˜λ‹€κ³  ν•˜κΈ°μ—λŠ” λ³΄ν†΅μ‚¬λžŒλ“€μ΄ 버티기 μ–΄λ €μ›Œν•˜λŠ” 직쒅을 ν•˜κ³  μžˆκ±°λ“ μš”."
        "λ„ˆλ¬΄ κ³Όλ„ν•˜κ²Œ νž˜μ„ λ‚΄μ„œ μ—΄μ‹¬νžˆ ν•œ νƒ“μΌκΉŒμš”? μ΄μ œλŠ” 정말 배터리가 1%도 λ‚¨μ•„μžˆμ§€ μ•ŠλŠ” 것 κ°™μ•„μš”. 근데 μ‰¬λ €λ‹ˆκΉŒ λΆˆμ•ˆν•΄μ§€κ³  λΆ€λͺ¨λ‹˜κ»˜ λ―Έμ•ˆν•΄μš” μ•„ν”„μ§€λ§Œ μ•Šμ•˜μ–΄λ„ λ‚˜λ„ 자리 잘 μž‘μ•„μ„œ 잘 μ§€λƒˆμ„ ν…λ°μš”.. μ•žμœΌλ‘œ μ €λŠ” κ·Έλƒ₯ νœ΄μ‹μ„ μ·¨ν•˜λŠ” 것이 λ§žμ„κΉŒμš”? μ €λŠ” μ–΄λ–€ μƒνƒœμ— μžˆλŠ” κ±ΈκΉŒμš”?"
    )

    # Generate the reply, substitute one fixed term for another in it, and
    # print it behind a fixed label.
    reply = get_sent(prompt).replace('μ‚¬μš°λ‹˜', 'λ§ˆμΉ΄λ‹˜')
    print('λ‘œλ‹ˆν˜•:', reply)