# professional_comment / inference.py
# Commit 53f9185 by Moo: "Upload inference.py"
import torch
import transformers
from transformers import AutoTokenizer, GPTJForCausalLM
def _module_default(name: str, value):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        # Same exception type the original implicit global lookup raised,
        # but with an actionable message.
        raise NameError(
            f"get_sent() needs a {name!r}: pass it explicitly or define a "
            f"module-level {name!r} first"
        ) from None


def get_sent(
    sent: str,
    *,
    tokenizer=None,
    model=None,
    max_length: int = 786,
) -> str:
    """Generate a model reply to ``sent`` and return it as decoded text.

    The input is wrapped as ``[BOS]{sent}[EOS][BOS]`` and passed to
    ``model.generate`` with beam-search sampling (``num_beams=5``,
    ``do_sample=True``) and 4-gram repetition blocking. Only the tokens
    produced after the prompt are decoded. Special tokens are kept and
    surrounding whitespace is stripped.

    Args:
        sent: Input text to respond to.
        tokenizer: Object exposing ``encode``/``decode``. Defaults to the
            module-level ``tokenizer`` (created in the ``__main__`` block).
        model: Causal LM exposing ``generate`` and ``device``. Defaults to
            the module-level ``model``.
        max_length: Maximum total sequence length passed to ``generate``.
            It includes the prompt tokens, so long inputs leave less room for
            the reply. 786 is the original hard-coded value
            (NOTE(review): possibly meant to be 768 -- confirm).

    Returns:
        The decoded continuation, stripped of leading/trailing whitespace.

    Raises:
        NameError: if ``tokenizer``/``model`` is not passed and no
            module-level object of that name exists.
    """
    # Fall back to module globals so existing one-argument calls keep working.
    tokenizer = _module_default("tokenizer", tokenizer)
    model = _module_default("model", model)

    input_text = '[BOS]' + sent + '[EOS][BOS]'
    # Tokenize once; the prompt length is reused below to drop the echoed prompt.
    prompt_ids = tokenizer.encode(input_text)
    # Batch of one, created directly on the model's device rather than a
    # hard-coded 'cuda', so CPU or a specific GPU index also works.
    input_ids = torch.tensor([prompt_ids], device=model.device)

    output_ids = model.generate(
        input_ids,
        do_sample=True,
        max_length=int(max_length),
        num_return_sequences=1,
        no_repeat_ngram_size=4,
        num_beams=5,
        early_stopping=True,
    )
    # The returned sequence starts with the prompt; keep only the continuation.
    generated = output_ids[0].tolist()[len(prompt_ids):]
    return tokenizer.decode(generated, skip_special_tokens=False).strip()
if __name__ == "__main__":
    # Demo entry point. It loads the KoGPT tokenizer and the GPT-J checkpoint
    # stored in the current directory, then prints one generated reply.
    # `tokenizer` and `model` must stay module-level names: get_sent() reads
    # them as globals.
    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision="KoGPT6B-ryan1.5b")
    # nn.Module.cuda() moves the weights in place and returns the module itself.
    model = GPTJForCausalLM.from_pretrained('./', torch_dtype=torch.float16).cuda()

    # Sample user post (Korean) about burnout and depression.
    prompt = """
λ²ˆμ•„μ›ƒμ΄ μ˜¨κ²ƒμ„ μΈμ •ν•˜κΈ° μ‹«μ–΄μš”.
2λ…„μ „λΆ€ν„° μš°μšΈμ¦μ— μ‹œλ‹¬λ ΈμŠ΅λ‹ˆλ‹€. 2λ…„λ™μ•ˆ μ €λ₯Ό μ•Œμ•„κ°€λ €κ³  μ• μ“°κ³  λ™μ‹œμ— μ €μ˜ 미래λ₯Ό μœ„ν•΄ λŠμž„μ—†μ΄ 자격증 곡뢀와 학업을 병행해 μ™”μ–΄μš”.
ν•˜μ§€λ§Œ λΆˆμ•ˆμ •ν•œ 심리 μƒνƒœλ‘œ μΈν•˜μ—¬ 제 꿈과 μˆ˜λ§Žμ€ κΈ°νšŒλ“€μ„ μž‘μ§€ λͺ»ν•˜κ³  λ„λ§μ³λ²„λ ΈμŠ΅λ‹ˆλ‹€. μ–΄μ§Έμ„œμΌκΉŒμš” μ™œ μ €λŠ” 남듀보닀 λ²„ν‹°λŠ” 힘이 μ•½ν•œκ±ΈκΉŒμš”.
λ‹€μ‹œ νž˜μ„λ‚΄μ„œ 도전도 해보고 μ—¬λŸ¬ 일을 해보며 λͺ¨λ“  μ—λ„ˆμ§€λ₯Ό μŸμ•„λΆ€μ—ˆλ”λ‹ˆ μ–΄λŠλ‚  ν•œκΈ€μ΄ 잘 νŒŒμ•…μ΄ μ•ˆλ˜κ³  κ°„λ‹¨ν•œ 결정을 λ‚΄λ¦¬λŠ” 것이 μ–΄λ €μ›Œ κ²°κ΅­ 직μž₯을 또 κ·Έλ§Œλ‘κ²Œ λ˜μ—ˆμŠ΅λ‹ˆλ‹€. λ‚˜μ•½ν•˜λ‹€κ³  ν•˜κΈ°μ—λŠ” λ³΄ν†΅μ‚¬λžŒλ“€μ΄ 버티기 μ–΄λ €μ›Œν•˜λŠ” 직쒅을 ν•˜κ³  μžˆκ±°λ“ μš”.
λ„ˆλ¬΄ κ³Όλ„ν•˜κ²Œ νž˜μ„ λ‚΄μ„œ μ—΄μ‹¬νžˆ ν•œ νƒ“μΌκΉŒμš”? μ΄μ œλŠ” 정말 배터리가 1%도 λ‚¨μ•„μžˆμ§€ μ•ŠλŠ” 것 κ°™μ•„μš”. 근데 μ‰¬λ €λ‹ˆκΉŒ λΆˆμ•ˆν•΄μ§€κ³  λΆ€λͺ¨λ‹˜κ»˜ λ―Έμ•ˆν•΄μš” μ•„ν”„μ§€λ§Œ μ•Šμ•˜μ–΄λ„ λ‚˜λ„ 자리 잘 μž‘μ•„μ„œ 잘 μ§€λƒˆμ„ ν…λ°μš”.. μ•žμœΌλ‘œ μ €λŠ” κ·Έλƒ₯ νœ΄μ‹μ„ μ·¨ν•˜λŠ” 것이 λ§žμ„κΉŒμš”? μ €λŠ” μ–΄λ–€ μƒνƒœμ— μžˆλŠ” κ±ΈκΉŒμš”?
"""
    # Collapse to one line. Newlines are dropped, not replaced with spaces.
    prompt = prompt.replace('\n', '')

    # Generate, then swap every 'μ‚¬μš°λ‹˜' in the output for 'λ§ˆμΉ΄λ‹˜'.
    reply = get_sent(prompt).replace('μ‚¬μš°λ‹˜', 'λ§ˆμΉ΄λ‹˜')
    print('λ‘œλ‹ˆν˜•:', reply)