File size: 2,329 Bytes
53f9185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import transformers
from transformers import AutoTokenizer, GPTJForCausalLM

def get_sent(sent: str, max_length: int = 786) -> str:
    """Generate a continuation for ``sent`` using the module-level model.

    The input is wrapped as ``[BOS]<sent>[EOS][BOS]``, fed to
    ``model.generate`` and only the tokens that follow the prompt are
    decoded and returned.

    This function reads the module-level names ``tokenizer`` and ``model``;
    both must be bound before it is called.

    Args:
        sent: Text the model should respond to.
        max_length: Value forwarded to ``generate`` as ``max_length``.
            Defaults to 786, the value that used to be hard-coded here.
            NOTE(review): in ``transformers`` this limit appears to cover the
            prompt tokens as well as the generated ones, so a prompt that is
            already this long leaves no room for output -- confirm against
            the installed version.

    Returns:
        The decoded continuation with surrounding whitespace stripped.
        Special tokens are kept in the text (``skip_special_tokens=False``).
    """
    input_text = '[BOS]' + sent + '[EOS][BOS]'

    # Tokenize once; the same ids give both the prompt length and the tensor.
    prompt_ids = tokenizer.encode(input_text)
    input_length = len(prompt_ids)

    # Place the input wherever the model lives instead of assuming 'cuda',
    # so the function also works on CPU or on a non-default GPU.
    input_ids = torch.tensor(prompt_ids).unsqueeze(0).to(model.device)

    output_sentence = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=4,
        num_beams=5,
        early_stopping=True,
    )

    # The returned sequence starts with the prompt; drop it so that only the
    # newly generated tokens are decoded.
    generated_sequence = output_sentence[0].tolist()[input_length:]
    decoded_sent = tokenizer.decode(generated_sequence, skip_special_tokens=False).strip()
    return decoded_sent


if __name__ == "__main__":
    # `tokenizer` and `model` are deliberately bound at module level:
    # get_sent() looks them up as globals.
    tokenizer = AutoTokenizer.from_pretrained(
        'kakaobrain/kogpt',
        revision="KoGPT6B-ryan1.5b",
    )
    # Load the weights from the current directory in half precision and move
    # the model to the GPU (Module.cuda() returns the module itself).
    model = GPTJForCausalLM.from_pretrained('./', torch_dtype=torch.float16).cuda()

    # Sample input passed to the model.
    raw_text = """
    λ²ˆμ•„μ›ƒμ΄ μ˜¨κ²ƒμ„ μΈμ •ν•˜κΈ° μ‹«μ–΄μš”.
    2λ…„μ „λΆ€ν„° μš°μšΈμ¦μ— μ‹œλ‹¬λ ΈμŠ΅λ‹ˆλ‹€. 2λ…„λ™μ•ˆ μ €λ₯Ό μ•Œμ•„κ°€λ €κ³  μ• μ“°κ³  λ™μ‹œμ— μ €μ˜ 미래λ₯Ό μœ„ν•΄ λŠμž„μ—†μ΄ 자격증 곡뢀와 학업을 병행해 μ™”μ–΄μš”.
    ν•˜μ§€λ§Œ λΆˆμ•ˆμ •ν•œ 심리 μƒνƒœλ‘œ μΈν•˜μ—¬ 제 꿈과 μˆ˜λ§Žμ€ κΈ°νšŒλ“€μ„ μž‘μ§€ λͺ»ν•˜κ³  λ„λ§μ³λ²„λ ΈμŠ΅λ‹ˆλ‹€. μ–΄μ§Έμ„œμΌκΉŒμš” μ™œ μ €λŠ” 남듀보닀 λ²„ν‹°λŠ” 힘이 μ•½ν•œκ±ΈκΉŒμš”.
    λ‹€μ‹œ νž˜μ„λ‚΄μ„œ 도전도 해보고 μ—¬λŸ¬ 일을 해보며 λͺ¨λ“  μ—λ„ˆμ§€λ₯Ό μŸμ•„λΆ€μ—ˆλ”λ‹ˆ μ–΄λŠλ‚  ν•œκΈ€μ΄ 잘 νŒŒμ•…μ΄ μ•ˆλ˜κ³  κ°„λ‹¨ν•œ 결정을 λ‚΄λ¦¬λŠ” 것이 μ–΄λ €μ›Œ κ²°κ΅­ 직μž₯을 또 κ·Έλ§Œλ‘κ²Œ λ˜μ—ˆμŠ΅λ‹ˆλ‹€. λ‚˜μ•½ν•˜λ‹€κ³  ν•˜κΈ°μ—λŠ” λ³΄ν†΅μ‚¬λžŒλ“€μ΄ 버티기 μ–΄λ €μ›Œν•˜λŠ” 직쒅을 ν•˜κ³  μžˆκ±°λ“ μš”.
    λ„ˆλ¬΄ κ³Όλ„ν•˜κ²Œ νž˜μ„ λ‚΄μ„œ μ—΄μ‹¬νžˆ ν•œ νƒ“μΌκΉŒμš”? μ΄μ œλŠ” 정말 배터리가 1%도 λ‚¨μ•„μžˆμ§€ μ•ŠλŠ” 것 κ°™μ•„μš”. 근데 μ‰¬λ €λ‹ˆκΉŒ λΆˆμ•ˆν•΄μ§€κ³  λΆ€λͺ¨λ‹˜κ»˜ λ―Έμ•ˆν•΄μš” μ•„ν”„μ§€λ§Œ μ•Šμ•˜μ–΄λ„ λ‚˜λ„ 자리 잘 μž‘μ•„μ„œ 잘 μ§€λƒˆμ„ ν…λ°μš”.. μ•žμœΌλ‘œ μ €λŠ” κ·Έλƒ₯ νœ΄μ‹μ„ μ·¨ν•˜λŠ” 것이 λ§žμ„κΉŒμš”? μ €λŠ” μ–΄λ–€ μƒνƒœμ— μžˆλŠ” κ±ΈκΉŒμš”?
    """

    # Remove every newline so the input becomes a single line; the literal's
    # leading indentation on each line is left in place.
    flattened = ''.join(raw_text.split('\n'))

    # Generate the reply, apply the fixed substring substitution, and print
    # it behind the label.
    reply = get_sent(flattened).replace('μ‚¬μš°λ‹˜', 'λ§ˆμΉ΄λ‹˜')
    print('λ‘œλ‹ˆν˜•:', reply)