---
language:
- ko
tags:
- gpt2
license: cc-by-nc-sa-4.0
---

## Example
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(
    "CheonggyeMountain-Sherpa/kogpt-trinity-punct-wrapper",
    revision="punct_wrapper-related_words-overfit",  # or punct_wrapper-related_words-minevalloss
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    mask_token="<mask>",
)
model = AutoModelForCausalLM.from_pretrained(
    "CheonggyeMountain-Sherpa/kogpt-trinity-punct-wrapper",
    revision="punct_wrapper-related_words-overfit",  # or punct_wrapper-related_words-minevalloss
    pad_token_id=tokenizer.eos_token_id,
).to(device=device)
model.eval()

prompt = "석양이 보이는 경치"  # "a view where the sunset can be seen"
# The "@...@<usr>\n" wrapper is the prompt format this example uses.
wrapped_prompt = f"@{prompt}@<usr>\n"
with torch.no_grad():
    tokens = tokenizer.encode(wrapped_prompt, return_tensors="pt").to(device=device)
    gen_tokens = model.generate(
        tokens,
        max_length=64,  # total length in tokens, prompt included
        repetition_penalty=2.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        # top_k / top_p are sampling filters: they only take effect when
        # sampling is enabled (do_sample=True here or in the generation config).
        top_k=16,
        top_p=0.8,
    )
    # Decode only the newly generated tokens, dropping the prompt.
    generated = tokenizer.decode(gen_tokens[0][len(tokens[0]):])

print(generated)
# ํ•ด๊ฐ€ ์ง€๊ณ  ์žˆ์„ ๋ฌด๋ ต
# ๋‚˜๋Š” ์„์–‘์„ ๋ณด๋Ÿฌ ๊ฐ„๋‹ค
# ๋ถ‰์€ ํ•˜๋Š˜๊ณผ ํ•˜์–€ ๊ตฌ๋ฆ„์ด ๋‚˜๋ฅผ ๋ฐ˜๊ฒจ์ค„ ๊ฒƒ ๊ฐ™์•„์„œ๋ฆฌ
# ํ•˜์ง€๋งŒ ๋‚ด๊ฐ€ ๋ณธ ํ•ด๋Š” ์ €๋ฌผ์–ด๋งŒ ๊ฐ€๊ณ 
# ๊ตฌ๋ฆ„๋งˆ์ € ์ž์ทจ๋ฅผ ๊ฐ์ถ˜ ์–ด๋‘ ๋งŒ์ด ๋‚จ์•„์žˆ์„ ๋ฟ์ด๋„ค
# ๋‚ด๊ฐ€ ํƒ„ ๋ฐฐ๋Š” ๋ณด์ด์ง€๋„ ์•Š๊ณ 
```
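The `generate()` call above sets `repetition_penalty`, `top_k`, and `top_p`. The snippet below is a simplified, pure-Python sketch of what those three settings do to a single step of next-token scores. It is not the `transformers` implementation, and the function names and toy logits are hypothetical. Note that in `transformers` the top-k and top-p filters are sampling warpers, so they only change the output when sampling is enabled (`do_sample=True`).

```python
import math
import random
from typing import Dict, List


def apply_repetition_penalty(logits: List[float], seen_ids: List[int],
                             penalty: float) -> List[float]:
    """Lower the score of every token id that already appears in the context.

    Positive logits are divided by `penalty` and negative logits are multiplied
    by it, so any penalty > 1 makes a repeated token less likely.
    """
    out = list(logits)
    for tok in set(seen_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_top_p_filter(probs: List[float], top_k: int,
                       top_p: float) -> Dict[int, float]:
    """Return {token_id: prob} for the tokens that survive both filters."""
    # 1) top-k: keep the k most probable token ids, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    k_total = sum(probs[i] for i in ranked)
    # 2) top-p (nucleus): keep the smallest prefix of the renormalized
    #    top-k list whose cumulative probability reaches top_p.
    kept: List[int] = []
    cumulative = 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i] / k_total
        if cumulative >= top_p:
            break
    # 3) Renormalize the survivors so they form a distribution again.
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


# Hypothetical toy example: a 5-token vocabulary.
logits = [2.0, 1.0, 0.5, -1.0, 3.0]
history = [4, 3]  # token ids already in the context
penalized = apply_repetition_penalty(logits, history, penalty=2.0)
print(penalized)  # [2.0, 1.0, 0.5, -2.0, 1.5]
filtered = top_k_top_p_filter(softmax(penalized), top_k=3, top_p=0.8)
print(sorted(filtered))  # [0, 4]
# Sample the next token id from the filtered distribution.
rng = random.Random(0)
next_id = rng.choices(list(filtered), weights=list(filtered.values()))[0]
print(next_id)
```

The penalty step runs before the filters here, and top-p is applied to the renormalized top-k survivors, which roughly matches the order in which `transformers` applies its logits processors and warpers. With greedy decoding, only the penalty step matters: the highest-scoring token after the penalty is chosen directly.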