DONGJOONSHIN
commited on
Commit
โข
9e4703e
1
Parent(s):
3e6d04f
first
Browse files- config.json +45 -0
- klue-roberta-base-kornli.ipynb +854 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- training_args.bin +3 -0
- vocab.txt +0 -0
config.json
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "ehdwns1516/klue-roberta-base_sae",
|
3 |
+
"architectures": [
|
4 |
+
"RobertaForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"bos_token_id": 0,
|
8 |
+
"eos_token_id": 2,
|
9 |
+
"gradient_checkpointing": false,
|
10 |
+
"hidden_act": "gelu",
|
11 |
+
"hidden_dropout_prob": 0.1,
|
12 |
+
"hidden_size": 768,
|
13 |
+
"id2label": {
|
14 |
+
"0": "yes/no",
|
15 |
+
"1": "alternative",
|
16 |
+
"2": "wh- questions",
|
17 |
+
"3": "prohibitions",
|
18 |
+
"4": "requirements",
|
19 |
+
"5": "strong requirements"
|
20 |
+
},
|
21 |
+
"initializer_range": 0.02,
|
22 |
+
"intermediate_size": 3072,
|
23 |
+
"label2id": {
|
24 |
+
"yes/no": 0,
|
25 |
+
"alternative": 1,
|
26 |
+
"wh- questions": 2,
|
27 |
+
"prohibitions": 3,
|
28 |
+
"requirements": 4,
|
29 |
+
"strong requirements": 5
|
30 |
+
},
|
31 |
+
"layer_norm_eps": 1e-05,
|
32 |
+
"max_position_embeddings": 512,
|
33 |
+
"model_type": "roberta",
|
34 |
+
"num_attention_heads": 12,
|
35 |
+
"num_hidden_layers": 12,
|
36 |
+
"pad_token_id": 1,
|
37 |
+
"position_embedding_type": "absolute",
|
38 |
+
"problem_type": "single_label_classification",
|
39 |
+
"tokenizer_class": "BertTokenizer",
|
40 |
+
"torch_dtype": "float32",
|
41 |
+
"transformers_version": "4.9.2",
|
42 |
+
"type_vocab_size": 1,
|
43 |
+
"use_cache": true,
|
44 |
+
"vocab_size": 32000
|
45 |
+
}
|
klue-roberta-base-kornli.ipynb
ADDED
@@ -0,0 +1,854 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "4ef9f047",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# HuggingFace Transformers๋ฅผ ์ด์ฉํ NLI ๋ชจ๋ธ ํ์ต\n",
|
9 |
+
"\n",
|
10 |
+
"* ๋ณธ ๋
ธํธ๋ถ์์๋ klue/roberta-base ๋ชจ๋ธ์ kakaobrain์์ ์ ๊ณตํ๋ KorNLI ๋ฐ์ดํฐ์
์ ํ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ จํ๋ ์์ ์
๋๋ค. \n",
|
11 |
+
"\n",
|
12 |
+
"* KorNLI ๋ฐ์ดํฐ์
์ [kakaobrain](https://github.com/kakaobrain/KorNLUDatasets?fbclid=IwAR0LX_jem7qb6HUikflO-F6lPpfoefK9Yc0jSQIdSKdkX4s8SW1UvoVGc7I)์์ ์ ๊ณตํ๋ ๋ฐ์ดํฐ์
์ด๋ฉฐ, KLUE์ NLI ๋ฐ์ดํฐ์
๋ณด๋ค ๋ฐ์ดํฐ ํฌ๊ธฐ๊ฐ ๋ ์ปค์ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฝ๋๋ค. \n",
|
13 |
+
"\n",
|
14 |
+
"* ๋ชจ๋ ์์ค ์ฝ๋๋ [huggingface-notebook](https://github.com/huggingface/notebooks)์ ์ฐธ๊ณ ํ์์ต๋๋ค.\n",
|
15 |
+
"\n",
|
16 |
+
"* ๋ณธ ๋
ธํธ๋ถ์ ํ์ต ๋ด์ฉ ๋๋ถ๋ถ์ Huffon๋์ [klue-transformers-tutorial](https://github.com/Huffon/klue-transformers-tutorial)์ ์ฐธ๊ณ ํ์์ต๋๋ค.\n",
|
17 |
+
"\n",
|
18 |
+
"* ํ์ต์ ํตํด ์ป์ด์ง klue-roberta-base-kornli ๋ชจ๋ธ์ ์
๋ ฅ๋ ๋ ๋ฌธ์ฅ์ ์ถ๋ก ๊ด๊ณ๋ฅผ ์์ธกํ๋๋ฐ ์ฌ์ฉํ ์ ์๊ฒ ๋ฉ๋๋ค.\n",
|
19 |
+
"\n",
|
20 |
+
"* ๋
ธํธ๋ถ์ ํ๊ฒฝ์ ainize workspace ์
๋๋ค. [ainze](https://ainize.ai/) ํํ์ด์ง ์ค๋ฅธ์ชฝ ์๋จ์ github๋ก ์์ด๋ ์์ฑ ํ My space์์ workspace๋ฅผ ์์ฑํ ์ ์์ต๋๋ค. ์ ๊ฐ ์ฌ์ฉํ workspace์ GPU ํ๊ฒฝ์ tesla V100 32GB์
๋๋ค.\n",
|
21 |
+
"\n",
|
22 |
+
"\n",
|
23 |
+
"## ๋
ธํธ๋ถ ํ๊ฒฝ ์ค์ \n",
|
24 |
+
"\n",
|
25 |
+
"๋ชจ๋ธ์ ํ์ต์ํค๊ธฐ ์ ์ ์์ ์ ๋
ธํธ๋ถ ํ๊ฒฝ์ ํ์ธํ๊ณ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํฉ๋๋ค."
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 1,
|
31 |
+
"id": "9ddee489",
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [
|
34 |
+
{
|
35 |
+
"name": "stdout",
|
36 |
+
"output_type": "stream",
|
37 |
+
"text": [
|
38 |
+
"True\n"
|
39 |
+
]
|
40 |
+
}
|
41 |
+
],
|
42 |
+
"source": [
|
43 |
+
"import torch\n",
|
44 |
+
"use_cuda = torch.cuda.is_available()\n",
|
45 |
+
"print(use_cuda)"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "markdown",
|
50 |
+
"id": "68700205",
|
51 |
+
"metadata": {},
|
52 |
+
"source": [
|
53 |
+
"๋ชจ๋ธ์ ํ์ต์ํฌ ๋ GPU๋ฅผ ์ฌ์ฉํด์ผ ํ์ต์ ๋น ๋ฅด๊ฒ ํ ์ ์์ผ๋ฏ๋ก ํ์ต์ํค๊ธฐ ์ ์์ ์ ๋
ธํธ๋ถ ํ๊ฒฝ์์ GPU๊ฐ ์ ๋์ ํ๋์ง ํ์ธ์ ํฉ๋๋ค. True๊ฐ ์ถ๋ ฅ์ด ๋๋ค๋ฉด ํ์ตํ ๋ GPU๊ฐ ๋์ํ๋ ๊ฒ์ด๊ณ False๊ฐ ๋์ค๋ฉด ๋์ํ์ง ์๋ ๊ฒ์
๋๋ค. False๊ฐ ์ถ๋ ฅ์ด ๋์ ๋ค๋ฉด ๊ฒ์์ ํตํด GPU๋ฅผ ํ์ฑํ ์ํค๋ ๋ฐฉ๋ฒ์ ์ฐพ์ ํ์ฑํ ์ํค์๋ฉด ๋ฉ๋๋ค."
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 2,
|
59 |
+
"id": "49766327",
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [
|
62 |
+
{
|
63 |
+
"name": "stdout",
|
64 |
+
"output_type": "stream",
|
65 |
+
"text": [
|
66 |
+
"[name: \"/device:CPU:0\"\n",
|
67 |
+
"device_type: \"CPU\"\n",
|
68 |
+
"memory_limit: 268435456\n",
|
69 |
+
"locality {\n",
|
70 |
+
"}\n",
|
71 |
+
"incarnation: 4073418392972710111\n",
|
72 |
+
", name: \"/device:GPU:0\"\n",
|
73 |
+
"device_type: \"GPU\"\n",
|
74 |
+
"memory_limit: 19146211328\n",
|
75 |
+
"locality {\n",
|
76 |
+
" bus_id: 1\n",
|
77 |
+
" links {\n",
|
78 |
+
" }\n",
|
79 |
+
"}\n",
|
80 |
+
"incarnation: 16025480191014330435\n",
|
81 |
+
"physical_device_desc: \"device: 0, name: Tesla V100-DGXS-32GB, pci bus id: 0000:07:00.0, compute capability: 7.0\"\n",
|
82 |
+
"]\n"
|
83 |
+
]
|
84 |
+
}
|
85 |
+
],
|
86 |
+
"source": [
|
87 |
+
"from tensorflow.python.client import device_lib\n",
|
88 |
+
"print(device_lib.list_local_devices())"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "markdown",
|
93 |
+
"id": "cbad9661",
|
94 |
+
"metadata": {},
|
95 |
+
"source": [
|
96 |
+
"๋ณธ์ธ์ด ์์ฑํ๊ณ ์๋ ๋
ธํธ๋ถ GPU ํ๊ฒฝ์ ์ ์ ์๋ ์ฝ๋์
๋๋ค. ์ ๋ ainize workspace์์ Tesla V100 ํ๊ฒฝ์์ ๋
ธํธ๋ถ์ ํ
์คํธํ๊ธฐ ๋๋ฌธ์ Tesla V100-DGXS-32GB์ด๋ผ๊ณ ์ ๋์ค๋ค์."
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": 3,
|
102 |
+
"id": "4a699d02",
|
103 |
+
"metadata": {},
|
104 |
+
"outputs": [],
|
105 |
+
"source": [
|
106 |
+
"#!pip install -U transformers datasets scipy scikit-learn"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"cell_type": "markdown",
|
111 |
+
"id": "8c957b4c",
|
112 |
+
"metadata": {},
|
113 |
+
"source": [
|
114 |
+
"๋ชจ๋ธ ํ๋ จ์ ์ํ transformers์ ํ์ต ๋ฐ์ดํฐ์
๋ก๋๋ฅผ ์ํด datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ค์นํฉ๋๋ค. ๊ทธ ์ธ ๋ชจ๋ธ ์ฑ๋ฅ ๊ฒ์ฆ์ ์ํด scipy, scikit-learn ๋ํ ์ถ๊ฐ๋ก ์ค์นํด์ค๋๋ค.\n",
|
115 |
+
"\n",
|
116 |
+
"\n",
|
117 |
+
"## ๋ฌธ์ฅ ๋ถ๋ฅ ๋ชจ๋ธ ํ์ต"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": 4,
|
123 |
+
"id": "401dacc7",
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"import random\n",
|
128 |
+
"import logging\n",
|
129 |
+
"import torch\n",
|
130 |
+
"from IPython.display import display, HTML\n",
|
131 |
+
"import numpy as np\n",
|
132 |
+
"import pandas as pd\n",
|
133 |
+
"import datasets\n",
|
134 |
+
"from datasets import load_dataset, load_metric, ClassLabel, Sequence, list_datasets\n",
|
135 |
+
"from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "markdown",
|
140 |
+
"id": "0ae8dc56",
|
141 |
+
"metadata": {},
|
142 |
+
"source": [
|
143 |
+
"๏ฟฝ๏ฟฝ๏ฟฝํธ๋ถ์ ์คํํ๋๋ฐ ํ์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ๋ชจ๋ ์ํฌํธ ํด์ค๋๋ค."
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "code",
|
148 |
+
"execution_count": 5,
|
149 |
+
"id": "1d99beee",
|
150 |
+
"metadata": {},
|
151 |
+
"outputs": [],
|
152 |
+
"source": [
|
153 |
+
"model_checkpoint = \"klue/roberta-base\"\n",
|
154 |
+
"batch_size = 32\n",
|
155 |
+
"task = \"nli\""
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "markdown",
|
160 |
+
"id": "048072e8",
|
161 |
+
"metadata": {},
|
162 |
+
"source": [
|
163 |
+
"ํ์ต์ ํ์ํ ์ ๋ณด๋ฅผ ๋ณ์๋ก ๊ธฐ๋กํฉ๋๋ค.\n",
|
164 |
+
"\n",
|
165 |
+
"๋ณธ ๋
ธํธ๋ถ์ ์๋ klue-roberta-base ๋ชจ๋ธ์ ํ์ฉํ์ง๋ง, https://huggingface.co/klue ํ์ด์ง์์ ๋ ๋ค์ํ ์ฌ์ ํ์ต ์ธ์ด ๋ชจ๋ธ์ ํ์ธํ์ค ์ ์์ต๋๋ค.\n",
|
166 |
+
"\n",
|
167 |
+
"ํ์ต ํ์คํฌ๋ก๋ nli๋ฅผ, ๋ฐฐ์น ์ฌ์ด์ฆ๋ 32๋ก ์ง์ ํ๊ฒ ์ต๋๋ค."
|
168 |
+
]
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"cell_type": "code",
|
172 |
+
"execution_count": 6,
|
173 |
+
"id": "5e5c4151",
|
174 |
+
"metadata": {},
|
175 |
+
"outputs": [
|
176 |
+
{
|
177 |
+
"name": "stderr",
|
178 |
+
"output_type": "stream",
|
179 |
+
"text": [
|
180 |
+
"Reusing dataset kor_nlu (/workspace/.cache/huggingface/datasets/kor_nlu/nli/1.0.0/4facbba77df60b0658056ced2052633e681a50187b9428bd5752ebd59d332ba8)\n"
|
181 |
+
]
|
182 |
+
}
|
183 |
+
],
|
184 |
+
"source": [
|
185 |
+
"datasets = load_dataset(\"kor_nlu\", task)"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "markdown",
|
190 |
+
"id": "06d9dc6d",
|
191 |
+
"metadata": {},
|
192 |
+
"source": [
|
193 |
+
"์ ๋ ํ์ต์ ์ฌ์ฉํ ๋ฐ์ดํฐ์
์ kakaobrain์์ ์ ๊ณตํ๋ KorNLI ๋ฐ์ดํฐ์
์ ์ฌ์ฉํ ๊ฒ์ด๋ฏ๋ก HuggingFace datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋ฑ๋ก๋ kor_nlu์ nli ๋ฐ์ดํฐ์ ๋ค์ด๋ก๋ ํฉ๋๋ค."
|
194 |
+
]
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"execution_count": 7,
|
199 |
+
"id": "01457377",
|
200 |
+
"metadata": {},
|
201 |
+
"outputs": [
|
202 |
+
{
|
203 |
+
"data": {
|
204 |
+
"text/plain": [
|
205 |
+
"DatasetDict({\n",
|
206 |
+
" train: Dataset({\n",
|
207 |
+
" features: ['premise', 'hypothesis', 'label'],\n",
|
208 |
+
" num_rows: 550146\n",
|
209 |
+
" })\n",
|
210 |
+
" validation: Dataset({\n",
|
211 |
+
" features: ['premise', 'hypothesis', 'label'],\n",
|
212 |
+
" num_rows: 1570\n",
|
213 |
+
" })\n",
|
214 |
+
" test: Dataset({\n",
|
215 |
+
" features: ['premise', 'hypothesis', 'label'],\n",
|
216 |
+
" num_rows: 4954\n",
|
217 |
+
" })\n",
|
218 |
+
"})"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
"execution_count": 7,
|
222 |
+
"metadata": {},
|
223 |
+
"output_type": "execute_result"
|
224 |
+
}
|
225 |
+
],
|
226 |
+
"source": [
|
227 |
+
"datasets"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "markdown",
|
232 |
+
"id": "a8291fbf",
|
233 |
+
"metadata": {},
|
234 |
+
"source": [
|
235 |
+
"๋ค์ด๋ก๋ ํ ์ป์ด์ง datasets์ ๊ฐ์ฒด๋ฅผ ๋ณด๋ฉด KorNLI ๋ฐ์ดํฐ์๋ ํ๋ จ ๋ฐ์ดํฐ, ๊ฒ์ฆ ๋ฐ์ดํฐ, ํ
์คํธ ๋ฐ์ดํฐ๊ฐ ํฌํจ๋์ด ์๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค."
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"execution_count": 8,
|
241 |
+
"id": "07659e8c",
|
242 |
+
"metadata": {},
|
243 |
+
"outputs": [
|
244 |
+
{
|
245 |
+
"data": {
|
246 |
+
"text/plain": [
|
247 |
+
"{'premise': '๋ง์ ํ ์ฌ๋์ด ๊ณ ์ฅ๋ ๋นํ๊ธฐ ์๋ก ๋ฐ์ด์ค๋ฅธ๋ค.',\n",
|
248 |
+
" 'hypothesis': 'ํ ์ฌ๋์ด ๊ฒฝ์์ ์ํด ๋ง์ ํ๋ จ์ํค๊ณ ์๋ค.',\n",
|
249 |
+
" 'label': 1}"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
"execution_count": 8,
|
253 |
+
"metadata": {},
|
254 |
+
"output_type": "execute_result"
|
255 |
+
}
|
256 |
+
],
|
257 |
+
"source": [
|
258 |
+
"datasets[\"train\"][0]"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"cell_type": "markdown",
|
263 |
+
"id": "29d24878",
|
264 |
+
"metadata": {},
|
265 |
+
"source": [
|
266 |
+
"๊ฐ ๋ฐ์ดํฐ๋ ์์ ๊ฐ์ด ๋๊ฐ์ ๋ฌธ์ฅ๊ณผ ๋ ๋ฌธ์ฅ์ ์ถ๋ก ๊ด๊ณ๋ฅผ ๋ผ๋ฒจ๋ก ๊ฐ์ง๊ณ ์์ต๋๋ค."
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": 9,
|
272 |
+
"id": "33ce978b",
|
273 |
+
"metadata": {},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"def show_random_elements(dataset, num_examples=10):\n",
|
277 |
+
" assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
|
278 |
+
"\n",
|
279 |
+
" picks = []\n",
|
280 |
+
" \n",
|
281 |
+
" for _ in range(num_examples):\n",
|
282 |
+
" pick = random.randint(0, len(dataset)-1)\n",
|
283 |
+
"\n",
|
284 |
+
" # ์ด๋ฏธ ๋ฑ๋ก๋ ์์ ๊ฐ ๋ฝํ ๊ฒฝ์ฐ, ๋ค์ ์ถ์ถ\n",
|
285 |
+
" while pick in picks:\n",
|
286 |
+
" pick = random.randint(0, len(dataset)-1)\n",
|
287 |
+
"\n",
|
288 |
+
" picks.append(pick)\n",
|
289 |
+
"\n",
|
290 |
+
" # ์์๋ก ์ถ์ถ๋ ์ธ๋ฑ์ค๋ค๋ก ๊ตฌ์ฑ๋ ๋ฐ์ดํฐ ํ๋ ์ ์ ์ธ\n",
|
291 |
+
" df = pd.DataFrame(dataset[picks])\n",
|
292 |
+
"\n",
|
293 |
+
" for column, typ in dataset.features.items():\n",
|
294 |
+
" # ๋ผ๋ฒจ ํด๋์ค๋ฅผ ์คํธ๋ง์ผ๋ก ๋ณํ\n",
|
295 |
+
" if isinstance(typ, ClassLabel):\n",
|
296 |
+
" df[column] = df[column].transform(lambda i: typ.names[i])\n",
|
297 |
+
"\n",
|
298 |
+
" display(HTML(df.to_html()))"
|
299 |
+
]
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "markdown",
|
303 |
+
"id": "58c8a3aa",
|
304 |
+
"metadata": {},
|
305 |
+
"source": [
|
306 |
+
"๋ฐ์ดํฐ์
์ ์ ๋ฐ์ ์ผ๋ก ์ดํด๋ณด๊ธฐ ์ํด ์๊ฐํ ํจ์๋ฅผ ์ ์ํฉ๋๋ค."
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": 10,
|
312 |
+
"id": "d4500f78",
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [
|
315 |
+
{
|
316 |
+
"data": {
|
317 |
+
"text/html": [
|
318 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
319 |
+
" <thead>\n",
|
320 |
+
" <tr style=\"text-align: right;\">\n",
|
321 |
+
" <th></th>\n",
|
322 |
+
" <th>premise</th>\n",
|
323 |
+
" <th>hypothesis</th>\n",
|
324 |
+
" <th>label</th>\n",
|
325 |
+
" </tr>\n",
|
326 |
+
" </thead>\n",
|
327 |
+
" <tbody>\n",
|
328 |
+
" <tr>\n",
|
329 |
+
" <th>0</th>\n",
|
330 |
+
" <td>์ ๊ณ ๋์ ์ฌ์ฑ๋ค์ ํฐ ์๋๋ฐญ์์ ์ฌํ, ์ ์น, ๊ตํ ๋ชจ์์ด ๋ ์ ์๋ ๊ณณ์ ๊ทธ๋ ค์ ธ ์๋ค.</td>\n",
|
331 |
+
" <td>ํ ๋ฌด๋ฆฌ์ ์ฌ์ฑ์ด ๊ตํ ์งํ์ค์์ ์ ์น์ ํ์๋ฅผ ํ๊ณ ์๋ค.</td>\n",
|
332 |
+
" <td>contradiction</td>\n",
|
333 |
+
" </tr>\n",
|
334 |
+
" <tr>\n",
|
335 |
+
" <th>1</th>\n",
|
336 |
+
" <td>ํ ์ฌ๋์ด ์ฐจ ์ ๋ฐ์๊ฐ ๋ง์ ๊ณณ์์ ๋ฌผ๊ฑด์ ์ค๋๋ค.</td>\n",
|
337 |
+
" <td>์ฌ๋์ด ๋ฐ์ ์๋ค.</td>\n",
|
338 |
+
" <td>entailment</td>\n",
|
339 |
+
" </tr>\n",
|
340 |
+
" <tr>\n",
|
341 |
+
" <th>2</th>\n",
|
342 |
+
" <td>ํ์ ์
์ธ ๋ฅผ ์
์ ๋จ์๊ฐ ์ผ์ธ ํ
์ด๋ธ์์ ์ฌ์ ๋ง์ํธ์ ์์ ์๋ค.</td>\n",
|
343 |
+
" <td>์ปค๋ค๋ ์ธ๊ฐ์ด ์์ ์์๋ค.</td>\n",
|
344 |
+
" <td>neutral</td>\n",
|
345 |
+
" </tr>\n",
|
346 |
+
" <tr>\n",
|
347 |
+
" <th>3</th>\n",
|
348 |
+
" <td>์ค๋ ์ง์ ์กฐ๋ผ๋ฅผ ์
์ ๋ ๋ช
์ ์ง์์ด ์๋ฅ ์์
์ ํ๊ณ ์๋ค.</td>\n",
|
349 |
+
" <td>์ฌ๋๋ค์ด ์๊ณ ์๋ค.</td>\n",
|
350 |
+
" <td>contradiction</td>\n",
|
351 |
+
" </tr>\n",
|
352 |
+
" <tr>\n",
|
353 |
+
" <th>4</th>\n",
|
354 |
+
" <td>๋นจ๊ฐ ์ ์ง๋ฅผ ์
์ ์ถ๊ตฌ์ ์๊ฐ ๊ณต์ ๋์ง๋ ค๊ณ ํ๋ ํ์ ์ ์ง๋ฅผ ์
์ ์ถ๊ตฌ์ ์์ ์ธ์ฐ๊ณ ์๋ค.</td>\n",
|
355 |
+
" <td>์๋
๊ฐ ์นดํ์์ ์ปคํผ๋ฅผ ์ฃผ๋ฌธํ๋ค.</td>\n",
|
356 |
+
" <td>contradiction</td>\n",
|
357 |
+
" </tr>\n",
|
358 |
+
" <tr>\n",
|
359 |
+
" <th>5</th>\n",
|
360 |
+
" <td>๋
ธ๋ ์ํผ์ค์ ๋ถํ์ ์๋ค์ ์ ์ ์์ด๊ฐ ์ปค๋ค๋ ๋ฌผ์์ ๊ฑธ์ด๊ฐ๋ค.</td>\n",
|
361 |
+
" <td>ํ ์์ด๊ฐ ๊ทธ๋
์ ๊ฐ๋ฅผ ํธ์์์ ์ซ์๋ธ๋ค.</td>\n",
|
362 |
+
" <td>neutral</td>\n",
|
363 |
+
" </tr>\n",
|
364 |
+
" <tr>\n",
|
365 |
+
" <th>6</th>\n",
|
366 |
+
" <td>์ด๋ผํ ๋๋ ์ค๋ฅผ ์
์ ์ฌ์.</td>\n",
|
367 |
+
" <td>๊ทธ ๋๋ ์ค๋ ์ฝ๊ฐ ์์ํ๋ค.</td>\n",
|
368 |
+
" <td>neutral</td>\n",
|
369 |
+
" </tr>\n",
|
370 |
+
" <tr>\n",
|
371 |
+
" <th>7</th>\n",
|
372 |
+
" <td>๋ ๋ง๋ฆฌ์ ๋ง์ด ํ ์ฌ์๋ฅผ ์๋ ์ ํ์ฐ๊ณ ์๋ค.</td>\n",
|
373 |
+
" <td>์ฌ์๊ฐ ๋ง์ฐจ๋ฅผ ํ๊ณ ๋ง์ ํ๋ค.</td>\n",
|
374 |
+
" <td>entailment</td>\n",
|
375 |
+
" </tr>\n",
|
376 |
+
" <tr>\n",
|
377 |
+
" <th>8</th>\n",
|
378 |
+
" <td>๋์๋ก ๋ค๋ฎ์ธ ๋ฒฝ๋๋ด์ ์ง๋๊ฐ๋ ๋จ์์ ์ฌ์.</td>\n",
|
379 |
+
" <td>๋ ์ฌ๋์ด ๋ฒฝ์ ๊ฑธ๋ฆฐ ๋์๋ฅผ ์ง๋๊ฐ๋ค.</td>\n",
|
380 |
+
" <td>entailment</td>\n",
|
381 |
+
" </tr>\n",
|
382 |
+
" <tr>\n",
|
383 |
+
" <th>9</th>\n",
|
384 |
+
" <td>์ธ๋ ์ฌ์ฑ๋ค์ด ์ ํต๋ฌด์ฉ์ ์์ ์ก๊ณ ์๋ค.</td>\n",
|
385 |
+
" <td>์ธ๋ ์ฌ์ฑ๋ค์ ๊ฒฐํผ์์์ ์ ํต ์ถค์ ์ถ๋ค.</td>\n",
|
386 |
+
" <td>neutral</td>\n",
|
387 |
+
" </tr>\n",
|
388 |
+
" </tbody>\n",
|
389 |
+
"</table>"
|
390 |
+
],
|
391 |
+
"text/plain": [
|
392 |
+
"<IPython.core.display.HTML object>"
|
393 |
+
]
|
394 |
+
},
|
395 |
+
"metadata": {},
|
396 |
+
"output_type": "display_data"
|
397 |
+
}
|
398 |
+
],
|
399 |
+
"source": [
|
400 |
+
"show_random_elements(datasets[\"train\"])"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "markdown",
|
405 |
+
"id": "5dc9ddf3",
|
406 |
+
"metadata": {},
|
407 |
+
"source": [
|
408 |
+
"์์ ์ ์ํ ํจ์๋ฅผ ์ด์ฉํ์ฌ ์์์ ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ์ดํด๋ณด๋๋ก ํฉ๋๋ค.\n",
|
409 |
+
"\n",
|
410 |
+
"KorNLI์๋ ๊ด๊ณ ์ถ๋ก ์ ํ ๋ entailment, neutral ๊ทธ๋ฆฌ๊ณ contradiction ์ธ ๊ฐ์ ๋ผ๋ฒจ์ ๊ฐ์ง๊ณ ์์์ ์ ์ ์์ต๋๋ค.\n",
|
411 |
+
"\n",
|
412 |
+
"์ด๋ ๊ฒ ๋ฐ์ดํฐ๋ฅผ ์ดํด๋ณด๋ฉด ๊ฐ ๋ผ๋ฒจ์ ์ด๋ ํ ๋ฌธ์ฅ๋ค์ด ํด๋น๋๋์ง ์ ์ ์์ด ๊ฐ์ ์ตํ ์ ์๋๋ฐ์ ์ฅ์ ์ด ์์ต๋๋ค."
|
413 |
+
]
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"cell_type": "code",
|
417 |
+
"execution_count": 11,
|
418 |
+
"id": "e1dbaa2b",
|
419 |
+
"metadata": {},
|
420 |
+
"outputs": [],
|
421 |
+
"source": [
|
422 |
+
"metric = load_metric(\"glue\", \"qnli\")"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "markdown",
|
427 |
+
"id": "d112985e",
|
428 |
+
"metadata": {},
|
429 |
+
"source": [
|
430 |
+
"ํ๋ จ ๊ณผ์ ์ค ๋ชจ๋ธ์ ์ฑ๋ฅ ํ์
์ ์ํด ๋ฉํธ๋ฆญ์ ์ค์ ํฉ๋๋ค.\n",
|
431 |
+
"\n",
|
432 |
+
"datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ์๋ ์ด๋ฏธ ๊ตฌํ๋ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํ ์ ์๋ load_metric ํจ์๊ฐ ์์ด์ ์ฝ๊ฒ ์ค์ ์ด ๊ฐ๋ฅํฉ๋๋ค.\n",
|
433 |
+
"\n",
|
434 |
+
"๊ทธ ์ค GLUE ๋ฐ์ดํฐ์
์ qnli ํ์คํฌ์ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํ๋๋ก ํ๊ฒ ์ต๋๋ค."
|
435 |
+
]
|
436 |
+
},
|
437 |
+
{
|
438 |
+
"cell_type": "code",
|
439 |
+
"execution_count": 12,
|
440 |
+
"id": "7be908b6",
|
441 |
+
"metadata": {},
|
442 |
+
"outputs": [
|
443 |
+
{
|
444 |
+
"data": {
|
445 |
+
"text/plain": [
|
446 |
+
"(array([0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1,\n",
|
447 |
+
" 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,\n",
|
448 |
+
" 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0]),\n",
|
449 |
+
" array([0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0,\n",
|
450 |
+
" 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0,\n",
|
451 |
+
" 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1]))"
|
452 |
+
]
|
453 |
+
},
|
454 |
+
"execution_count": 12,
|
455 |
+
"metadata": {},
|
456 |
+
"output_type": "execute_result"
|
457 |
+
}
|
458 |
+
],
|
459 |
+
"source": [
|
460 |
+
"fake_preds = np.random.randint(0, 2, size=(64,))\n",
|
461 |
+
"fake_labels = np.random.randint(0, 2, size=(64,))\n",
|
462 |
+
"fake_preds, fake_labels"
|
463 |
+
]
|
464 |
+
},
|
465 |
+
{
|
466 |
+
"cell_type": "markdown",
|
467 |
+
"id": "7b7de386",
|
468 |
+
"metadata": {},
|
469 |
+
"source": [
|
470 |
+
"๋ฉํธ๋ฆญ์ด ์ ์์ ์ผ๋ก ์๋ํ๋ ๊ฒ์ ํ์ธํ๊ธฐ ์ํด ๋๋คํ ์์ธก ๊ฐ๊ณผ ๋ผ๋ฒจ ๊ฐ์ ์์ฑํฉ๋๋ค."
|
471 |
+
]
|
472 |
+
},
|
473 |
+
{
|
474 |
+
"cell_type": "code",
|
475 |
+
"execution_count": 13,
|
476 |
+
"id": "b3cce09f",
|
477 |
+
"metadata": {},
|
478 |
+
"outputs": [
|
479 |
+
{
|
480 |
+
"data": {
|
481 |
+
"text/plain": [
|
482 |
+
"{'accuracy': 0.5}"
|
483 |
+
]
|
484 |
+
},
|
485 |
+
"execution_count": 13,
|
486 |
+
"metadata": {},
|
487 |
+
"output_type": "execute_result"
|
488 |
+
}
|
489 |
+
],
|
490 |
+
"source": [
|
491 |
+
"metric.compute(predictions=fake_preds, references=fake_labels)"
|
492 |
+
]
|
493 |
+
},
|
494 |
+
{
|
495 |
+
"cell_type": "markdown",
|
496 |
+
"id": "a040c7e9",
|
497 |
+
"metadata": {},
|
498 |
+
"source": [
|
499 |
+
"์์ฑํ ๋๋ค ๊ฐ๋ค์ compute() ํจ์๋ฅผ ํตํด ์ ๋์ํ๋์ง ํ์ธํฉ๋๋ค."
|
500 |
+
]
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": 14,
|
505 |
+
"id": "6b7eb994",
|
506 |
+
"metadata": {},
|
507 |
+
"outputs": [],
|
508 |
+
"source": [
|
509 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
|
510 |
+
]
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"cell_type": "markdown",
|
514 |
+
"id": "439ef8aa",
|
515 |
+
"metadata": {},
|
516 |
+
"source": [
|
517 |
+
"์ด์ ํ์ต์ ํ์ฉํ ํ ํฌ๋์ด์ ๋ฅผ ๋ก๋ํด์ต๋๋ค."
|
518 |
+
]
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"cell_type": "code",
|
522 |
+
"execution_count": 15,
|
523 |
+
"id": "88f32f23",
|
524 |
+
"metadata": {},
|
525 |
+
"outputs": [
|
526 |
+
{
|
527 |
+
"data": {
|
528 |
+
"text/plain": [
|
529 |
+
"{'input_ids': [0, 1891, 3611, 2052, 3855, 2069, 3627, 1041, 2069, 4484, 2067, 2089, 2088, 1513, 2062, 18, 2, 1041, 2069, 1763, 3611, 2052, 8514, 2336, 7046, 7587, 5603, 17290, 2062, 18, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
|
530 |
+
]
|
531 |
+
},
|
532 |
+
"execution_count": 15,
|
533 |
+
"metadata": {},
|
534 |
+
"output_type": "execute_result"
|
535 |
+
}
|
536 |
+
],
|
537 |
+
"source": [
|
538 |
+
"tokenizer(\"ํ ์ฌ๋์ด ๊ฒฝ์์ ์ํด ๋ง์ ํ๋ จ์ํค๊ณ ์๋ค.\", \"๋ง์ ํ ์ฌ๋์ด ๊ณ ์ฅ๋ ๋นํ๊ธฐ ์๋ก ๋ฐ์ด์ค๋ฅธ๋ค.\")"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "markdown",
|
543 |
+
"id": "fcde8cd6",
|
544 |
+
"metadata": {},
|
545 |
+
"source": [
|
546 |
+
"๋ก๋๋ ํ ํฌ๋์ด์ ๊ฐ ๋ ๋ฌธ์ฅ์ ํ ํฐํํ๋ ๋ฐฉ์์ ํ์
ํ๊ธฐ ์ํด ๋ ๋ฌธ์ฅ์ ์
๋ ฅ ๊ฐ์ผ๋ก ๋ฃ์ด์ ํ์ธํด ๋ด
๋๋ค."
|
547 |
+
]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
"cell_type": "code",
|
551 |
+
"execution_count": 16,
|
552 |
+
"id": "6a927431",
|
553 |
+
"metadata": {},
|
554 |
+
"outputs": [
|
555 |
+
{
|
556 |
+
"name": "stdout",
|
557 |
+
"output_type": "stream",
|
558 |
+
"text": [
|
559 |
+
"Sentence 1: ๋ง์ ํ ์ฌ๋์ด ๊ณ ์ฅ๋ ๋นํ๊ธฐ ์๋ก ๋ฐ์ด์ค๋ฅธ๋ค.\n",
|
560 |
+
"Sentence 2: ํ ์ฌ๋์ด ๊ฒฝ์์ ์ํด ๋ง์ ํ๋ จ์ํค๊ณ ์๋ค.\n"
|
561 |
+
]
|
562 |
+
}
|
563 |
+
],
|
564 |
+
"source": [
|
565 |
+
"sentence1_key, sentence2_key = (\"premise\", \"hypothesis\")\n",
|
566 |
+
"print(f\"Sentence 1: {datasets['train'][0][sentence1_key]}\")\n",
|
567 |
+
"print(f\"Sentence 2: {datasets['train'][0][sentence2_key]}\")"
|
568 |
+
]
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"cell_type": "markdown",
|
572 |
+
"id": "0fb0ba8b",
|
573 |
+
"metadata": {},
|
574 |
+
"source": [
|
575 |
+
"์ด์ ์์ ๋ก๋ํ ๋ฐ์ดํฐ์
์์ ๊ฐ ๋ฌธ์ฅ์ ํด๋นํ๋ value๋ฅผ ๋ฝ์์ฃผ๊ธฐ ์ํ key๋ฅผ ์ ์ํฉ๋๋ค.\n",
|
576 |
+
"\n",
|
577 |
+
"KorNLI ๋ฐ์ดํฐ์
์์ ๋ ๋ฌธ์ฅ์ label์ด ๊ฐ๊ฐ premise์ hypothesis๋ผ๋ ์ด๋ฆ์ผ๋ก ์ ์๋ ๊ฒ์ ํ์ธํ์์ผ๋ฏ๋ก, ๋ ๋ฌธ์ฅ์ key ๋ ๋ง์ฐฌ๊ฐ์ง๋ก ๊ฐ๊ฐ premise, hypothesis๊ฐ ๋ฉ๋๋ค."
|
578 |
+
]
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"cell_type": "code",
|
582 |
+
"execution_count": 20,
|
583 |
+
"id": "4e020138",
|
584 |
+
"metadata": {},
|
585 |
+
"outputs": [],
|
586 |
+
"source": [
|
587 |
+
"def preprocess_function(examples):\n",
|
588 |
+
" print(examples)\n",
|
589 |
+
" return tokenizer(\n",
|
590 |
+
" examples[sentence1_key],\n",
|
591 |
+
" examples[sentence2_key],\n",
|
592 |
+
" truncation=True,\n",
|
593 |
+
" return_token_type_ids=False,\n",
|
594 |
+
" )"
|
595 |
+
]
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"cell_type": "markdown",
|
599 |
+
"id": "42c8f5a8",
|
600 |
+
"metadata": {},
|
601 |
+
"source": [
|
602 |
+
"์ด์ key ๋ ํ์ธ์ด ๋์์ผ๋, ๋ฐ์ดํฐ์
์์ ๊ฐ ์์ ๋ค์ ๋ฝ์์ ํ ํฐํ ํ ์ ์๋ ํจ์๋ฅผ ์์ ๊ฐ์ด ์ ์ํด์ค๋๋ค.\n",
|
603 |
+
"\n",
|
604 |
+
"ํด๋น ํจ์๋ ๋ชจ๋ธ์ ํ๋ จํ๊ธฐ ์์ ๋ฐ์ดํฐ์
์ ๋ฏธ๋ฆฌ ํ ํฐํ ์์ผ๋๋ ์์
์ ์ํ ์ฝ๋ฐฑ ํจ์๋ก ์ฌ์ฉ๋๊ฒ ๋ฉ๋๋ค.\n",
|
605 |
+
"\n",
|
606 |
+
"์ธ์๋ก ๋ฃ์ด์ฃผ๋ truncation๋ ๋ชจ๋ธ์ด ์
๋ ฅ ๋ฐ์ ์ ์๋ ์ต๋ ๊ธธ์ด ์ด์์ ํ ํฐ ์ํ์ค๊ฐ ๋ค์ด์ค๊ฒ ๋ ๊ฒฝ์ฐ, ์ต๋ ๊ธธ์ด ๊ธฐ์ค์ผ๋ก ์ํ์ค๋ฅผ ์๋ฅด๋ผ๋ ์๋ฏธ๋ฅผ ์ง๋๋๋ค.\n",
|
607 |
+
"\n",
|
608 |
+
"( * return_token_type_ids๋ ํ ํฌ๋์ด์ ๊ฐ token_type_ids๋ฅผ ๋ฐํํ๋๋ก ํ ๊ฒ์ธ์ง๋ฅผ ๊ฒฐ์ ํ๋ ์ธ์์
๋๋ค. transformers==4.7.0 ๊ธฐ์ค์ผ๋ก token_type_ids๊ฐ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ฐํ๋๋ฏ๋ก token_type_ids ์์ฒด๋ฅผ ์ฌ์ฉํ์ง ์๋ RoBERTa ๋ชจ๋ธ์ ํ์ฉํ๊ธฐ ์ํด ํด๋น ์ธ์๋ฅผ False๋ก ์ค์ ํด์ฃผ๋๋ก ํฉ๋๋ค.)"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"cell_type": "code",
|
613 |
+
"execution_count": 21,
|
614 |
+
"id": "ac8d69bd",
|
615 |
+
"metadata": {},
|
616 |
+
"outputs": [
|
617 |
+
{
|
618 |
+
"name": "stdout",
|
619 |
+
"output_type": "stream",
|
620 |
+
"text": [
|
621 |
+
"{'premise': ['๋ง์ ํ ์ฌ๋์ด ๊ณ ์ฅ๋ ๋นํ๊ธฐ ์๋ก ๋ฐ์ด์ค๋ฅธ๋ค.', '๋ง์ ํ ์ฌ๋์ด ๊ณ ์ฅ๋ ๋นํ๊ธฐ ์๋ก ๋ฐ์ด์ค๋ฅธ๋ค.', '๋ง์ ํ ์ฌ๋์ด ๊ณ ์ฅ๋ ๋นํ๊ธฐ ์๋ก ๋ฐ์ด์ค๋ฅธ๋ค.', '์นด๋ฉ๋ผ์ ์๊ณ ์์ ํ๋๋ ์์ด๋ค', '์นด๋ฉ๋ผ์ ์๊ณ ์์ ํ๋๋ ์์ด๋ค', '์นด๋ฉ๋ผ์ ์๊ณ ์์ ํ๋๋ ์์ด๋ค', 'ํ ์๋
์ด ๋นจ๊ฐ ๋ค๋ฆฌ ํ๊ฐ์ด๋ฐ ์ค์ผ์ดํธ๋ณด๋์ ๋ฐ์ด์ค๋ฅด๊ณ ์๋ค.', 'ํ ์๋
์ด ๋นจ๊ฐ ๋ค๋ฆฌ ํ๊ฐ์ด๋ฐ ์ค์ผ์ดํธ๋ณด๋์ ๋ฐ์ด์ค๋ฅด๊ณ ์๋ค.', 'ํ ์๋
์ด ๋นจ๊ฐ ๋ค๋ฆฌ ํ๊ฐ์ด๋ฐ ์ค์ผ์ดํธ๋ณด๋์ ๋ฐ์ด์ค๋ฅด๊ณ ์๋ค.', '๋์ด ๋ ๋จ์๊ฐ ์ปคํผ์์ ์์ ํ
์ด๋ธ์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ค๊ณ ์์ ์๊ณ ๋ฐ์ ์ ์
์ธ ๋ฅผ ์
์ ์ง์๋ค์ ๋ค์์ ๋ฏธ์๋ฅผ ์ง๊ณ ์๋ค.', '๋์ด ๋ ๋จ์๊ฐ ์ปคํผ์์ ์์ ํ
์ด๋ธ์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ค๊ณ ์์ ์๊ณ ๋ฐ์ ์ ์
์ธ ๋ฅผ ์
์ ์ง์๋ค์ ๋ค์์ ๋ฏธ์๋ฅผ ์ง๊ณ ์๋ค.', '๋์ด ๋ ๋จ์๊ฐ ์ปคํผ์์ ์์ ํ
์ด๋ธ์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ค๊ณ ์์ ์๊ณ ๋ฐ์ ์ ์
์ธ ๋ฅผ ์
์ ์ง์๋ค์ ๋ค์์ ๋ฏธ์๋ฅผ ์ง๊ณ ์๋ค.', '๊ธ๋ฐ ์ฌ์ ๋์ด ์๋ก ๊ปด์๊ณ ์๋ค.', '๊ธ๋ฐ ์ฌ์ ๋์ด ์๋ก ๊ปด์๊ณ ์๋ค.', '๊ธ๋ฐ ์ฌ์ ๋์ด ์๋ก ๊ปด์๊ณ ์๋ค.', '์๋น์ ์๋ ๋ช๋ช ์ฌ๋๋ค, ๊ทธ๋ค ์ค ํ ๋ช
์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', '์๋น์ ์๋ ๋ช๋ช ์ฌ๋๋ค, ๊ทธ๋ค ์ค ํ ๋ช
์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', '์๋น์ ์๋ ๋ช๋ช ์ฌ๋๋ค, ๊ทธ๋ค ์ค ํ ๋ช
์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', 'ํ ๋์ด๋ ๋จ์๊ฐ ์๋น์์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', 'ํ ๋์ด๋ ๋จ์๊ฐ ์๋น์์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', 'ํ ๋์ด๋ ๋จ์๊ฐ ์๋น์์ ์ค๋ ์ง ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', '๊ธ๋ฐ์ ๋จ์, ๊ทธ๋ฆฌ๊ณ ๊ณต๊ณต ๋ถ์๋์์ ์ ์ ๋ง์๋ ๊ฐ์ ์
์ธ .', '๊ธ๋ฐ์ ๋จ์, ๊ทธ๋ฆฌ๊ณ ๊ณต๊ณต ๋ถ์๋์์ ์ ์ ๋ง์๋ ๊ฐ์ ์
์ธ .', '๊ธ๋ฐ์ ๋จ์, ๊ทธ๋ฆฌ๊ณ ๊ณต๊ณต ๋ถ์๋์์ ์ ์ ๋ง์๋ ๊ฐ์ ์
์ธ .', '๋ฐฉ๊ธ ์ ์ฌ์ ๋จน๊ณ ์๋ณ์ธ์ฌ๋ฅผ ํ ๋ ์ฌ์.', '๋ฐฉ๊ธ ์ ์ฌ์ ๋จน๊ณ ์๋ณ์ธ์ฌ๋ฅผ ํ ๋ ์ฌ์.', '๋ฐฉ๊ธ ์ ์ฌ์ ๋จน๊ณ ์๋ณ์ธ์ฌ๋ฅผ ํ ๋ ์ฌ์.', '๋ ์ฌ์๊ฐ ์์ ์ด๋ฐ ์ฉ๊ธฐ๋ฅผ ๋ค๊ณ ํฌ์น์ ํ๋ค.', '๋ ์ฌ์๊ฐ ์์ ์ด๋ฐ ์ฉ๊ธฐ๋ฅผ ๋ค๊ณ ํฌ์น์ ํ๋ค.', '๋ ์ฌ์๊ฐ ์์ ์ด๋ฐ ์ฉ๊ธฐ๋ฅผ ๋ค๊ณ ํฌ์น์ ํ๋ค.', '๋ฆฌํ๋ฆฌ๊ทธ ํ์ด ์คํ ๊ฒฝ๊ธฐ์์ ๋ฒ ์ด์ค๋ก ๋ฏธ๋๋ฌ์ง๋ ์ฃผ์๋ฅผ ์ก์ผ๋ ค๊ณ ํ๋ค.', '๋ฆฌํ๋ฆฌ๊ทธ ํ์ด ์คํ ๊ฒฝ๊ธฐ์์ ๋ฒ ์ด์ค๋ก ๋ฏธ๋๋ฌ์ง๋ ์ฃผ์๋ฅผ ์ก์ผ๋ ค๊ณ ํ๋ค.', '๋ฆฌํ๋ฆฌ๊ทธ ํ์ด ์คํ ๊ฒฝ๊ธฐ์์ ๋ฒ ์ด์ค๋ก ๋ฏธ๋๋ฌ์ง๋ ์ฃผ์๋ฅผ ์ก์ผ๋ ค๊ณ ํ๋ค.', '์ด ํ๊ต๋ ๋ค๋ฅธ ๋ฌธํ๋ค์ด ํํฐ์์ ์ด๋ป๊ฒ ๋ค๋ฃจ์ด์ง๋์ง์ ๋ํ ๋ฏธ๊ตญ ๋ฌธํ๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ์ํด ํน๋ณํ ํ์ฌ๋ฅผ ์ด๊ณ ์๋ค.', '์ด ํ๊ต๋ ๋ค๋ฅธ ๋ฌธํ๋ค์ด ํํฐ์์ ์ด๋ป๊ฒ ๋ค๋ฃจ์ด์ง๋์ง์ ๋ํ ๋ฏธ๊ตญ ๋ฌธํ๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ์ํด ํน๋ณํ ํ์ฌ๋ฅผ ์ด๊ณ ์๋ค.', '์ด ํ๊ต๋ ๋ค๋ฅธ ๋ฌธํ๋ค์ด ํํฐ์์ ์ด๋ป๊ฒ ๋ค๋ฃจ์ด์ง๋์ง์ ๋ํ ๋ฏธ๊ตญ ๋ฌธํ๋ฅผ ๋ณด์ฌ์ฃผ๊ธฐ ์ํด ํน๋ณํ ํ์ฌ๋ฅผ ์ด๊ณ ์๋ค.', '๊ณ ๊ธ ํจ์
์๊ฐ์จ๋ค์ด ๋์์ ์ฌ๋๋ค ์์์ ์ ์ฐจ ๋ฐ์์ ๊ธฐ๋ค๋ฆฌ๊ณ ์๋ค.', '๊ณ ๊ธ ํจ์
์๊ฐ์จ๋ค์ด ๋์์ ์ฌ๋๋ค ์์์ ์ ์ฐจ ๋ฐ์์ ๊ธฐ๋ค๋ฆฌ๊ณ ์๋ค.', '๊ณ ๊ธ ํจ์
์๊ฐ์จ๋ค์ด ๋์์ ์ฌ๋๋ค ์์์ ์ ์ฐจ ๋ฐ์์ ๊ธฐ๋ค๋ฆฌ๊ณ ์๋ค.', 'ํด๋ณ์์ ์ฆ๊ธฐ๋ ๋จ์, ์ฌ์, ์์ด.'], 'hypothesis': ['ํ ์ฌ๋์ด ๊ฒฝ์์ ์ํด ๋ง์ ํ๋ จ์ํค๊ณ ์๋ค.', 'ํ ์ฌ๋์ด ์๋น์์ ์ค๋ฏ๋ ์ ์ฃผ๋ฌธํ๊ณ ์๋ค.', '์ฌ๋์ ์ผ์ธ์์ ๋ง์ ํ๊ณ ์๋ค.', '๊ทธ๋ค์ ๋ถ๋ชจ๋์ ๋ณด๊ณ ์๊ณ ์๋ค', '์์ด๋ค์ด ์๋ค', '์์ด๋ค์ด ์ผ๊ตด์ ์ฐํธ๋ฆฌ๊ณ ์๋ค', '์๋
์ ์ธ๋๋ฅผ ๋ฐ๋ผ ์ค์ผ์ดํธ๋ฅผ ํ๋ค.', '๊ทธ ์๋
์ ์ค์ผ์ดํธ๋ณด๋๋ฅผ ํ๋ ๋ฌ๊ธฐ๋ฅผ ๋ถ๋ฆฐ๋ค.', '์๋
์ด ์์ ์ฅ๋น๋ฅผ ์ฐฉ์ฉํ๊ณ ์๋ค.', '๋์ด ๋ ๋จ์๊ฐ ๋ธ์ด ํด๊ทผํ๊ธฐ๋ฅผ ๊ธฐ๋ค๋ฆฌ๋ฉด์ ์ฃผ์ค๋ฅผ ๋ง์ ๋ค.', 'ํ ์๋
์ด ํ๋ฒ๊ฑฐ๋ฅผ ๋ค์ง๋๋ค.', 'ํ ๋
ธ์ธ์ด ์์ ๊ฐ๊ฒ์ ์์ ์๋ค.', '๋ช๋ช ์ฌ์ฑ๋ค์ ํด๊ฐ ๋ ํฌ์น์ ํ๊ณ ์๋ค.', '์ฌ์๋ค์ด ์๊ณ ์๋ค.', '์ ์ ์ ๋ณด์ด๋ ์ฌ์๋ค์ด ์๋ค.', '์ฌ๋๋ค์ด ์ค๋ฏ๋ ์ ๋จน๊ณ ์๏ฟฝ๏ฟฝ.', '์ฌ๋๋ค์ด ํ๊ต ์ฑ
์์ ์์ ์๋ค.', '์๋๋ค์ด ์๋น์ ์๋ค.', '๋จ์๊ฐ ์ฃผ์ค๋ฅผ ๋ง์๊ณ ์๋ค.', '๋ ์ฌ์๊ฐ ์๋น์์ ์์ธ์ ๋ง์๊ณ ์๋ค.', '์๋น์์ ํ ๋จ์๊ฐ ์์ฌ๊ฐ ๋์ฐฉํ๊ธฐ๋ฅผ ๊ธฐ๋ค๋ฆฌ๊ณ ์๋ค.', '๊ณต์์ ๋ถ์๋์์ ๋ฌผ์ ๋ง์๋ ๊ธ๋ฐ์ ๋จ์.', '๊ฐ์ ์
์ธ ๋ฅผ ์
์ ๊ธ๋ฐ ๋จ์๊ฐ ๊ณต์์ ๋ฒค์น์์ ์ฑ
์ ์ฝ๊ณ ์๋ค.', '๋ถ์๋์์ ๋ฌผ์ ๋ง์๋ ๊ธ๋ฐ ๋จ์.', '์น๊ตฌ๋ค์ ์ ๋
์ํ์ ๊ฐ๋ ์ฑ์ฐ๊ณ ์๋ก๋ฅผ ๋
ธ๋ ค๋ณธ๋ค.', '์ด ์ฌ์ง์๋ ๋ ์ฌ์๊ฐ ์์ต๋๋ค.', '์น๊ตฌ๋ค์ 20๋
๋ง์ ์ฒ์ ๋ง๋ฌ๊ณ , ๋ฐ๋ผ์ก๋ ๋ฐ ์ฆ๊ฑฐ์ด ์๊ฐ์ ๋ณด๋๋ค.', '๋ ์๋งค๋ ๋ถ๋น๋ ์๋น์ ๊ฐ๋ก์ง๋ฌ ์๋ก๋ฅผ ๋ณด๊ณ ํฌ์น์ ๋๋๋ฉฐ ๋ ๋ค ๊ฐ ๊ฐ๋ฐฉ์ ์์ผ์ก์๋ค.', '๋ ๊ทธ๋ฃน์ ๋ผ์ด๋ฒ ๊ฐฑ๋จ์๋ค์ด ์๋ก๋ฅผ ๋ฐฐ์ ํ๋ค.', '๋ ์ฌ์๊ฐ ์๋ก ๊ปด์๋๋ค.', 'ํ ํ์ด ์น๋ฆฌ๋ฅผ ๊ฑฐ๋๊ธฐ ์ํด ๋์ ์ ์๋ํ๊ณ ์๋ค.', 'ํ ํ์ด ์ฃผ์๋ฅผ ๋ฐ๋๋ฆฌ๋ ค ํ๊ณ ์๋ค.', 'ํ ํ์ด ํ ์ฑ์์ ์ผ๊ตฌ๋ฅผ ํ๊ณ ์๋ค.', 'ํ ํ๊ต๊ฐ ๋๊ตฌ ๊ฒฝ๊ธฐ๋ฅผ ์ฃผ์ตํ๋ค.', 'ํ ๊ณ ๋ฑํ๊ต๊ฐ ํ์ฌ๋ฅผ ์ฃผ์ตํ๊ณ ์๋ค.', 'ํ ํ๊ต๊ฐ ํ์ฌ๋ฅผ ์ฃผ์ตํ๊ณ ์๋ค.', '์ฌ์๋ค์ ์ด๋ค ์ท์ ์
๋ ์๊ดํ์ง ์๋๋ค.', '์ฌ์๋ค์ด ์ ์ฐจ ์์์ ๊ธฐ๋ค๋ฆฌ๊ณ ์๋ค.', '์ฌ์ฑ๋ค์ ์ข์ ํจ์
๊ฐ๊ฐ์ ๊ฐ๋ ๊ฒ์ ์ฆ๊ธด๋ค.', '์ฌ๋ฆ๋ฐฉํ ๋ ํด๋ณ์์ ์๋ง ์๋น ์ ํจ๊ป ์๋ ์์ด.'], 'label': [1, 2, 0, 1, 0, 2, 2, 0, 1, 1, 2, 0, 1, 2, 0, 1, 2, 0, 0, 2, 1, 1, 2, 0, 2, 0, 1, 1, 2, 0, 1, 0, 2, 2, 1, 0, 2, 0, 1, 1]}\n"
|
622 |
+
]
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"data": {
|
626 |
+
"text/plain": [
|
627 |
+
"{'input_ids': [[0, 1041, 2069, 1763, 3611, 2052, 8514, 2336, 7046, 7587, 5603, 17290, 2062, 18, 2, 1891, 3611, 2052, 3855, 2069, 3627, 1041, 2069, 4484, 2067, 2089, 2088, 1513, 2062, 18, 2], [0, 1041, 2069, 1763, 3611, 2052, 8514, 2336, 7046, 7587, 5603, 17290, 2062, 18, 2, 1891, 3611, 2052, 5499, 27135, 3, 4867, 19521, 1513, 2062, 18, 2], [0, 1041, 2069, 1763, 3611, 2052, 8514, 2336, 7046, 7587, 5603, 17290, 2062, 18, 2, 3611, 2073, 8296, 27135, 1041, 2069, 11532, 1513, 2062, 18, 2], [0, 5677, 2170, 1474, 2088, 1284, 2069, 18882, 2259, 3651, 2031, 2, 636, 2031, 2073, 4267, 2098, 2069, 4530, 1474, 2088, 1513, 2062, 2], [0, 5677, 2170, 1474, 2088, 1284, 2069, 18882, 2259, 3651, 2031, 2, 3651, 7285, 1513, 2062, 2], [0, 5677, 2170, 1474, 2088, 1284, 2069, 18882, 2259, 3651, 2031, 2, 3651, 7285, 3977, 2069, 19556, 2088, 1513, 2062, 2], [0, 1891, 5950, 2052, 8013, 5035, 6853, 3832, 19090, 11507, 2170, 5603, 5667, 2088, 1513, 2062, 18, 2, 5950, 2073, 4543, 2138, 3653, 19090, 2138, 25695, 18, 2], [0, 1891, 5950, 2052, 8013, 5035, 6853, 3832, 19090, 11507, 2170, 5603, 5667, 2088, 1513, 2062, 18, 2, 636, 5950, 2073, 19090, 11507, 2138, 1761, 2259, 1087, 2015, 2138, 29121, 2062, 18, 2], [0, 1891, 5950, 2052, 8013, 5035, 6853, 3832, 19090, 11507, 2170, 5603, 5667, 2088, 1513, 2062, 18, 2, 5950, 2052, 4040, 5424, 2138, 7845, 19521, 1513, 2062, 18, 2], [0, 4358, 880, 3997, 2116, 18970, 2079, 1518, 2073, 6889, 2170, 11150, 12228, 2138, 882, 2088, 1379, 2227, 1513, 2088, 1124, 2073, 1245, 10727, 2138, 1511, 2073, 4070, 2031, 2073, 873, 27135, 5658, 2138, 1590, 2088, 1513, 2062, 18, 2, 4358, 880, 3997, 2116, 900, 2052, 8194, 31302, 2138, 5037, 31369, 12228, 2138, 23457, 18, 2], [0, 4358, 880, 3997, 2116, 18970, 2079, 1518, 2073, 6889, 2170, 11150, 12228, 2138, 882, 2088, 1379, 2227, 1513, 2088, 1124, 2073, 1245, 10727, 2138, 1511, 2073, 4070, 2031, 2073, 873, 27135, 5658, 2138, 1590, 2088, 1513, 2062, 18, 2, 1891, 5950, 2052, 14995, 2138, 6762, 2259, 2062, 18, 2], [0, 4358, 880, 3997, 2116, 18970, 2079, 1518, 2073, 6889, 2170, 11150, 12228, 2138, 882, 2088, 1379, 2227, 1513, 2088, 1124, 2073, 1245, 10727, 2138, 1511, 2073, 4070, 2031, 2073, 873, 27135, 5658, 2138, 1590, 2088, 1513, 2062, 18, 2, 1891, 4662, 2052, 1518, 2073, 6042, 2170, 1379, 2227, 1513, 2062, 18, 2], [0, 21459, 3883, 867, 2052, 4084, 15322, 2088, 1513, 2062, 18, 2, 7396, 3811, 2031, 2073, 6493, 904, 20957, 2069, 6159, 1513, 2062, 18, 2], [0, 21459, 3883, 867, 2052, 4084, 15322, 2088, 1513, 2062, 18, 2, 3883, 7285, 5883, 1513, 2062, 18, 2], [0, 21459, 3883, 867, 2052, 4084, 15322, 2088, 1513, 2062, 18, 2, 7689, 2069, 3783, 2259, 3883, 7285, 1513, 2062, 18, 2], [0, 5499, 2170, 1513, 2259, 7396, 3611, 2031, 16, 636, 2031, 1570, 1891, 1076, 2073, 11150, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2, 3611, 7285, 3, 1059, 2088, 1513, 2062, 18, 2], [0, 5499, 2170, 1513, 2259, 7396, 3611, 2031, 16, 636, 2031, 1570, 1891, 1076, 2073, 11150, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2, 3611, 7285, 3741, 7961, 2170, 1379, 2227, 1513, 2062, 18, 2], [0, 5499, 2170, 1513, 2259, 7396, 3611, 2031, 16, 636, 2031, 1570, 1891, 1076, 2073, 11150, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2, 6654, 7285, 5499, 2170, 1513, 2062, 18, 2], [0, 1891, 4358, 2778, 3997, 2116, 5499, 27135, 11150, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2, 3997, 2116, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2], [0, 1891, 4358, 2778, 3997, 2116, 5499, 27135, 11150, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2, 864, 3883, 2116, 5499, 27135, 6612, 2069, 5012, 2088, 1513, 2062, 18, 2], [0, 1891, 4358, 2778, 3997, 2116, 5499, 27135, 11150, 12228, 2138, 5012, 2088, 1513, 2062, 18, 2, 5499, 27135, 1891, 3997, 2116, 5067, 2116, 5082, 31302, 2138, 5037, 2088, 1513, 2062, 18, 2], [0, 21459, 2079, 3997, 16, 3673, 4437, 11790, 2104, 27135, 1299, 2069, 5012, 2259, 14008, 10727, 18, 2, 4599, 2079, 11790, 2104, 27135, 1093, 2069, 5012, 2259, 21459, 2079, 3997, 18, 2], [0, 21459, 2079, 3997, 16, 3673, 4437, 11790, 2104, 27135, 1299, 2069, 5012, 2259, 14008, 10727, 18, 2, 14008, 10727, 2138, 1511, 2073, 21459, 3997, 2116, 4599, 2079, 9262, 27135, 1644, 2069, 1508, 2088, 1513, 2062, 18, 2], [0, 21459, 2079, 3997, 16, 3673, 4437, 11790, 2104, 27135, 1299, 2069, 5012, 2259, 14008, 10727, 18, 2, 11790, 2104, 27135, 1093, 2069, 5012, 2259, 21459, 3997, 18, 2], [0, 9684, 5961, 2069, 1059, 2088, 17812, 2179, 2063, 2138, 1891, 864, 3883, 18, 2, 3949, 2031, 2073, 4750, 10160, 2069, 4983, 8140, 2088, 4084, 2138, 10631, 7471, 18, 2], [0, 9684, 5961, 2069, 1059, 2088, 17812, 2179, 2063, 2138, 1891, 864, 3883, 18, 2, 1504, 4035, 2170, 2259, 864, 3883, 2116, 1513, 2219, 3606, 18, 2], [0, 9684, 5961, 2069, 1059, 2088, 17812, 2179, 2063, 2138, 1891, 864, 3883, 18, 2, 3949, 2031, 2073, 3619, 2440, 1038, 2170, 3790, 5836, 2088, 16, 15856, 2259, 842, 7924, 3641, 2069, 5755, 2062, 18, 2], [0, 864, 3883, 2116, 4182, 11047, 6153, 2138, 882, 2088, 20957, 2069, 1902, 2062, 18, 2, 864, 9646, 2259, 26576, 2259, 5499, 2069, 19642, 4084, 2138, 4530, 20957, 2069, 4835, 2307, 867, 809, 558, 6840, 2069, 13955, 2741, 2886, 2062, 18, 2], [0, 864, 3883, 2116, 4182, 11047, 6153, 2138, 882, 2088, 20957, 2069, 1902, 2062, 18, 2, 864, 4063, 2079, 12574, 563, 2286, 2252, 7285, 4084, 2138, 10004, 2371, 2062, 18, 2], [0, 864, 3883, 2116, 4182, 11047, 6153, 2138, 882, 2088, 20957, 2069, 1902, 2062, 18, 2, 864, 3883, 2116, 4084, 15322, 2259, 2062, 18, 2], [0, 18023, 17665, 1823, 2052, 4082, 3682, 27135, 9763, 2200, 21371, 2259, 6788, 2138, 1523, 2279, 10554, 3605, 18, 2, 1891, 1823, 2052, 4644, 2138, 6387, 2015, 3627, 6300, 2069, 4703, 19521, 1513, 2062, 18, 2], [0, 18023, 17665, 1823, 2052, 4082, 3682, 27135, 9763, 2200, 21371, 2259, 6788, 2138, 1523, 2279, 10554, 3605, 18, 2, 1891, 1823, 2052, 6788, 2138, 18841, 2370, 6159, 1513, 2062, 18, 2], [0, 18023, 17665, 1823, 2052, 4082, 3682, 27135, 9763, 2200, 21371, 2259, 6788, 2138, 1523, 2279, 10554, 3605, 18, 2, 1891, 1823, 2052, 24820, 27135, 4878, 2138, 6159, 1513, 2062, 18, 2], [0, 1504, 3741, 2259, 3656, 3697, 7285, 7291, 27135, 3842, 5778, 4379, 18246, 2170, 3618, 3666, 3697, 2138, 3897, 2223, 2015, 3627, 4014, 2470, 3925, 2138, 1432, 2088, 1513, 2062, 18, 2, 1891, 3741, 2116, 7124, 3682, 2138, 6771, 4538, 18, 2], [0, 1504, 3741, 2259, 3656, 3697, 7285, 7291, 27135, 3842, 5778, 4379, 18246, 2170, 3618, 3666, 3697, 2138, 3897, 2223, 2015, 3627, 4014, 2470, 3925, 2138, 1432, 2088, 1513, 2062, 18, 2, 1891, 5868, 2116, 3925, 2138, 6771, 19521, 1513, 2062, 18, 2], [0, 1504, 3741, 2259, 3656, 3697, 7285, 7291, 27135, 3842, 5778, 4379, 18246, 2170, 3618, 3666, 3697, 2138, 3897, 2223, 2015, 3627, 4014, 2470, 3925, 2138, 1432, 2088, 1513, 2062, 18, 2, 1891, 3741, 2116, 3925, 2138, 6771, 19521, 1513, 2062, 18, 2], [0, 5399, 5179, 9019, 7285, 3763, 2079, 3611, 2031, 1438, 27135, 13533, 4044, 2112, 5037, 2088, 1513, 2062, 18, 2, 3883, 2031, 2073, 3711, 1451, 2069, 1511, 2778, 5468, 2205, 2118, 1380, 2259, 2062, 18, 2], [0, 5399, 5179, 9019, 7285, 3763, 2079, 3611, 2031, 1438, 27135, 13533, 4044, 2112, 5037, 2088, 1513, 2062, 18, 2, 3883, 7285, 13533, 1438, 27135, 5037, 2088, 1513, 2062, 18, 2], [0, 5399, 5179, 9019, 7285, 3763, 2079, 3611, 2031, 1438, 27135, 13533, 4044, 2112, 5037, 2088, 1513, 2062, 18, 2, 3811, 2031, 2073, 1560, 2073, 5179, 5700, 2069, 554, 2259, 575, 2069, 21589, 18, 2], [0, 9738, 27135, 5380, 2259, 3997, 16, 3883, 16, 3651, 18, 2, 4565, 2239, 2218, 904, 9738, 27135, 4122, 5091, 2522, 3655, 1513, 2259, 3651, 18, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}"
|
628 |
+
]
|
629 |
+
},
|
630 |
+
"execution_count": 21,
|
631 |
+
"metadata": {},
|
632 |
+
"output_type": "execute_result"
|
633 |
+
}
|
634 |
+
],
|
635 |
+
"source": [
|
636 |
+
"preprocess_function(datasets[\"train\"][:40])"
|
637 |
+
]
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"cell_type": "markdown",
|
641 |
+
"id": "7b8b48f7",
|
642 |
+
"metadata": {},
|
643 |
+
"source": [
|
644 |
+
"์์ ์ ์ํ process_function์ ์์ ๊ฐ์ด ์ฌ๋ฌ ๊ฐ์ ์์ ๋ฐ์ดํฐ๋ฅผ ๋ฐ์ ์๋ ์์ต๋๋ค."
|
645 |
+
]
|
646 |
+
},
|
647 |
+
{
|
648 |
+
"cell_type": "code",
|
649 |
+
"execution_count": null,
|
650 |
+
"id": "07593d57",
|
651 |
+
"metadata": {},
|
652 |
+
"outputs": [],
|
653 |
+
"source": [
|
654 |
+
"encoded_datasets = datasets.map(preprocess_function, batched=True)"
|
655 |
+
]
|
656 |
+
},
|
657 |
+
{
|
658 |
+
"cell_type": "markdown",
|
659 |
+
"id": "7bbf64ea",
|
660 |
+
"metadata": {},
|
661 |
+
"source": [
|
662 |
+
"์ด์ ์ ์๋ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ํ์ฉํด ๋ฐ์ดํฐ์
์ ๋ฏธ๋ฆฌ ํ ํฐํ์ํค๋ ์์
์ ์ํํฉ๋๋ค.\n",
|
663 |
+
"\n",
|
664 |
+
"datasets ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํตํด ์ป์ด์ง DatasetDict ๊ฐ์ฒด๋ map() ํจ์๋ฅผ ์ง์ํ๋ฏ๋ก, ์ ์๋ ์ ์ฒ๋ฆฌ ํจ์๋ฅผ ๋ฐ์ดํฐ์
ํ ํฐํ๋ฅผ ์ํ ์ฝ๋ฐฑ ํจ์๋ก map() ํจ์ ์ธ์๋ก ๋๊ฒจ์ฃผ๋ฉด ๋ฉ๋๋ค."
|
665 |
+
]
|
666 |
+
},
|
667 |
+
{
|
668 |
+
"cell_type": "code",
|
669 |
+
"execution_count": null,
|
670 |
+
"id": "840516c6",
|
671 |
+
"metadata": {},
|
672 |
+
"outputs": [],
|
673 |
+
"source": [
|
674 |
+
"num_labels = 3\n",
|
675 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)"
|
676 |
+
]
|
677 |
+
},
|
678 |
+
{
|
679 |
+
"cell_type": "markdown",
|
680 |
+
"id": "80281b6f",
|
681 |
+
"metadata": {},
|
682 |
+
"source": [
|
683 |
+
"ํ์ต์ ์ํ ๋ชจ๋ธ์ ๋ก๋ํฉ๋๋ค.\n",
|
684 |
+
"\n",
|
685 |
+
"์์ ์ดํด๋ณธ ๋ฐ์ ๊ฐ์ด KorNLI์๋ ์ด 3๊ฐ์ ํด๋์ค๊ฐ ์กด์ฌํ๋ฏ๋ก, 3๊ฐ์ ํด๋์ค๋ฅผ ์์ธกํ ์ ์๋ SequenceClassification ๊ตฌ์กฐ๋ก ๋ชจ๋ธ์ ๋ก๋ํฉ๋๋ค.\n",
|
686 |
+
"\n",
|
687 |
+
"๋ชจ๋ธ์ ๋ก๋ํ ๋ ๋ฐ์ํ๋ ๊ฒฝ๊ณ ๋ฌธ๊ตฌ๋ ๋ ๊ฐ์ง ์๋ฏธ๋ฅผ ์ง๋๋๋ค.\n",
|
688 |
+
"1. Masked Language Modeling ์ ์ํด ์กด์ฌํ๋ lm_head๊ฐ ํ์ฌ๋ ์ฌ์ฉ๋์ง ์๊ณ ์์์ ์๋ฏธํฉ๋๋ค.\n",
|
689 |
+
"2. ๋ฌธ์ฅ ๋ถ๋ฅ๋ฅผ ์ํ classifier ๋ ์ด์ด๋ฅผ ๋ฐฑ๋ณธ ๋ชจ๋ธ ๋ค์ ์ด์ด ๋ถ์์ผ๋ ์์ง ํ๋ จ์ด ๋์ง ์์์ผ๋ฏ๋ก, ํ์ต์ ์ํํด์ผ ํจ์ ์๋ฏธํฉ๋๋ค."
|
690 |
+
]
|
691 |
+
},
|
692 |
+
{
|
693 |
+
"cell_type": "code",
|
694 |
+
"execution_count": null,
|
695 |
+
"id": "ea50be8c",
|
696 |
+
"metadata": {},
|
697 |
+
"outputs": [],
|
698 |
+
"source": [
|
699 |
+
"def compute_metrics(eval_pred):\n",
|
700 |
+
" predictions, labels = eval_pred\n",
|
701 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
702 |
+
" return metric.compute(predictions=predictions, references=labels)"
|
703 |
+
]
|
704 |
+
},
|
705 |
+
{
|
706 |
+
"cell_type": "markdown",
|
707 |
+
"id": "d62aef1e",
|
708 |
+
"metadata": {},
|
709 |
+
"source": [
|
710 |
+
"๋ง์ง๋ง์ผ๋ก ์์ ์ ์ํ ๋ฉํธ๋ฆญ์ ๋ชจ๋ธ ์์ธก ๊ฒฐ๊ณผ์ ์ ์ฉํ๊ธฐ ์ํ ํจ์๋ฅผ ์ ์ํฉ๋๋ค.\n",
|
711 |
+
"\n",
|
712 |
+
"์
๋ ฅ์ผ๋ก ๋ค์ด์ค๋ eval_pred๋ EvalPrediction ๊ฐ์ฒด์ด๋ฉฐ, ๋ชจ๋ธ์ ํด๋์ค ๋ณ ์์ธก ๊ฐ๊ณผ ์ ๋ต ๊ฐ์ ์ง๋๋๋ค.\n",
|
713 |
+
"\n",
|
714 |
+
"ํด๋์ค ๋ณ ์์ธก ์ค ๊ฐ์ฅ ๋์ ๋ผ๋ฒจ์ argmax()๋ฅผ ํตํด ๋ฝ์๋ธ ํ, ์ ๋ต ๋ผ๋ฒจ๊ณผ ๋น๊ต๋ฅผ ํ๊ฒ ๋ฉ๋๋ค."
|
715 |
+
]
|
716 |
+
},
|
717 |
+
{
|
718 |
+
"cell_type": "code",
|
719 |
+
"execution_count": null,
|
720 |
+
"id": "09688900",
|
721 |
+
"metadata": {},
|
722 |
+
"outputs": [],
|
723 |
+
"source": [
|
724 |
+
"metric_name = \"accuracy\"\n",
|
725 |
+
"\n",
|
726 |
+
"args = TrainingArguments(\n",
|
727 |
+
" \"test-nli\",\n",
|
728 |
+
" evaluation_strategy=\"epoch\",\n",
|
729 |
+
" learning_rate=2e-5,\n",
|
730 |
+
" per_device_train_batch_size=batch_size,\n",
|
731 |
+
" per_device_eval_batch_size=batch_size,\n",
|
732 |
+
" num_train_epochs=5,\n",
|
733 |
+
" weight_decay=0.01,\n",
|
734 |
+
" load_best_model_at_end=True,\n",
|
735 |
+
" metric_for_best_model=metric_name,\n",
|
736 |
+
")"
|
737 |
+
]
|
738 |
+
},
|
739 |
+
{
|
740 |
+
"cell_type": "markdown",
|
741 |
+
"id": "9a768407",
|
742 |
+
"metadata": {},
|
743 |
+
"source": [
|
744 |
+
"์ด์ ์์ ์ ์ํ ์ ๋ณด๋ค์ ๋ฐํ์ผ๋ก transformers์์ ์ ๊ณตํ๋ Trainer ๊ฐ์ฒด๋ฅผ ํ์ฉํ๊ธฐ ์ํ ์ธ์ ๊ด๋ฆฌ ํด๋์ค๋ฅผ ์ด๊ธฐํํฉ๋๋ค.\n",
|
745 |
+
"\n",
|
746 |
+
"metric_name์ ์์ ์ป์ด์ง ๋ฉํธ๋ฆญ ํจ์๋ฅผ ํ์ฉํ์ ๋, ์๋์ ๊ฐ์ด dict ํ์์ผ๋ก ๊ฒฐ๊ณผ ๊ฐ์ด ๋ฐํ๋๋๋ฐ ์ฌ๊ธฐ์ ์ฐ๋ฆฌ๊ฐ ์ฌ์ฉํ key ๋ฅผ ์ ์ํด์ค๋ค๊ณ ์๊ฐํ์๋ฉด ๋ฉ๋๋ค.\n"
|
747 |
+
]
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"cell_type": "code",
|
751 |
+
"execution_count": null,
|
752 |
+
"id": "080c165e",
|
753 |
+
"metadata": {},
|
754 |
+
"outputs": [],
|
755 |
+
"source": [
|
756 |
+
"trainer = Trainer(\n",
|
757 |
+
" model,\n",
|
758 |
+
" args,\n",
|
759 |
+
" train_dataset=encoded_datasets[\"train\"],\n",
|
760 |
+
" eval_dataset=encoded_datasets[\"validation\"],\n",
|
761 |
+
" tokenizer=tokenizer,\n",
|
762 |
+
" compute_metrics=compute_metrics,\n",
|
763 |
+
")"
|
764 |
+
]
|
765 |
+
},
|
766 |
+
{
|
767 |
+
"cell_type": "markdown",
|
768 |
+
"id": "03a9041d",
|
769 |
+
"metadata": {},
|
770 |
+
"source": [
|
771 |
+
"์ด์ ๋ก๋ํ ๋ชจ๋ธ, ์ธ์ ๊ด๋ฆฌ ํด๋์ค, ๋ฐ์ดํฐ์
๋ฑ์ ์ด์ฉํ์ฌ Trainer๋ฅผ ์ด๊ธฐํ ํด์ค๋๋ค."
|
772 |
+
]
|
773 |
+
},
|
774 |
+
{
|
775 |
+
"cell_type": "code",
|
776 |
+
"execution_count": null,
|
777 |
+
"id": "5cdf8b98",
|
778 |
+
"metadata": {},
|
779 |
+
"outputs": [],
|
780 |
+
"source": [
|
781 |
+
"trainer.train()"
|
782 |
+
]
|
783 |
+
},
|
784 |
+
{
|
785 |
+
"cell_type": "markdown",
|
786 |
+
"id": "515e7b86",
|
787 |
+
"metadata": {},
|
788 |
+
"source": [
|
789 |
+
"์์ ์ด๊ธฐํํ trainer ๊ฐ์ฒด๋ฅผ ํ๋ จ์ํต๋๋ค. ๋ฐ์ดํฐ ์
์ ํฌ๊ธฐ๊ฐ ์ปค์ ๊ทธ๋ฐ์ง ์ ํ๊ฒฝ์์๋ ์ฝ 3์๊ฐ 30๋ถ ์ ๋ ๊ฑธ๋ ธ๋ค์."
|
790 |
+
]
|
791 |
+
},
|
792 |
+
{
|
793 |
+
"cell_type": "code",
|
794 |
+
"execution_count": null,
|
795 |
+
"id": "c8a0e594",
|
796 |
+
"metadata": {},
|
797 |
+
"outputs": [],
|
798 |
+
"source": [
|
799 |
+
"trainer.evaluate()"
|
800 |
+
]
|
801 |
+
},
|
802 |
+
{
|
803 |
+
"cell_type": "markdown",
|
804 |
+
"id": "0eba9443",
|
805 |
+
"metadata": {},
|
806 |
+
"source": [
|
807 |
+
"ํ์ตํ ๋ชจ๋ธ์ evalutate()๋ฅผ ํตํด ํ๊ฐ๋ฅผ ํ์ฌ ์ ํ๋๋ฅผ ํ์ธํด ๋ด
๋๋ค."
|
808 |
+
]
|
809 |
+
},
|
810 |
+
{
|
811 |
+
"cell_type": "code",
|
812 |
+
"execution_count": null,
|
813 |
+
"id": "0f57d941",
|
814 |
+
"metadata": {},
|
815 |
+
"outputs": [],
|
816 |
+
"source": [
|
817 |
+
"trainer.save_model(\"./\")"
|
818 |
+
]
|
819 |
+
},
|
820 |
+
{
|
821 |
+
"cell_type": "markdown",
|
822 |
+
"id": "69a124cf",
|
823 |
+
"metadata": {},
|
824 |
+
"source": [
|
825 |
+
"ํ์ตํ ๋ชจ๋ธ์ ๋์ค์ ์ฌ์ฉํ๊ธฐ ์ํด ์ ์ฅ ๊ฒฝ๋ก์ ์ ์ฅํฉ๋๋ค. ์ ๋ ํ์ฌ ํด๋์ ์ ์ฅํ์์ต๋๋ค.\n",
|
826 |
+
"\n",
|
827 |
+
"## ๋ง์น๋ฉฐ\n",
|
828 |
+
"\n",
|
829 |
+
"์ ๊ฐ ์ธ๊ณต์ง๋ฅ์ ์์ํ์ง ์ผ๋ง ๋์ง ์์ ๊ฐ๋
์ด ๋ ์กํ์๊ณ , ์ ๋ง ์ด๋ ค์ ํ์๋๋ฐ Huffon๋์ [github](https://github.com/Huffon)์ ์ฌ๋ผ๊ฐ์๋ ๋
ธํธ๋ถ์ ๋ณด๋ฉฐ ๋ง์ด ๋ฐฐ์ด ๊ฒ ๊ฐ์ต๋๋ค. ๋ณธ ๋
ธํธ๋ถ์ Huffon๋์ ์์ ์ฝ๋์ ์ค๋ช
์ ๋ฐ๋ผ์น ๊ฒ์ด ๋๋ถ๋ถ์ด์ง๋ง ๋์ผ๋ก ๋ณด๋ ๊ฒ๊ณผ ์ค์ ํด๋ณด๋ ๊ฒ์ ์ฐจ์ด๊ฐ ์์ฒญ ํฌ๋ค๋ ๊ฒ์ ๋๊ผ์ต๋๋ค."
|
830 |
+
]
|
831 |
+
}
|
832 |
+
],
|
833 |
+
"metadata": {
|
834 |
+
"kernelspec": {
|
835 |
+
"display_name": "Python 3",
|
836 |
+
"language": "python",
|
837 |
+
"name": "python3"
|
838 |
+
},
|
839 |
+
"language_info": {
|
840 |
+
"codemirror_mode": {
|
841 |
+
"name": "ipython",
|
842 |
+
"version": 3
|
843 |
+
},
|
844 |
+
"file_extension": ".py",
|
845 |
+
"mimetype": "text/x-python",
|
846 |
+
"name": "python",
|
847 |
+
"nbconvert_exporter": "python",
|
848 |
+
"pygments_lexer": "ipython3",
|
849 |
+
"version": "3.7.9"
|
850 |
+
}
|
851 |
+
},
|
852 |
+
"nbformat": 4,
|
853 |
+
"nbformat_minor": 5
|
854 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:714a296e4cf0c0f584471b4291f7b8a2af77675c9a4f98feba067c674209ed78
|
3 |
+
size 442575305
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bos_token": "[CLS]", "eos_token": "[SEP]", "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"do_lower_case": false, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "do_basic_tokenize": true, "never_split": null, "bos_token": "[CLS]", "eos_token": "[SEP]", "model_max_length": 512, "special_tokens_map_file": "/workspace/.cache/huggingface/transformers/9d0c87e44b00acfbfbae931b2e4068eb6311a0c3e71e23e5400bdf57cab4bfbf.70c17d6e4d492c8f24f5bb97ab56c7f272e947112c6faf9dd846da42ba13eb23", "name_or_path": "klue/roberta-base", "tokenizer_class": "BertTokenizer"}
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:66f689c35735c17dcdcddb56ac4d319208e4c9f41b16b09e71fe2d3fdec197dc
|
3 |
+
size 2607
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|