liwii committed
Commit 5473c52
Parent(s): 31c0541

Training in progress, epoch 1
added_tokens.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "<pad>": 0,
+   "<unk>": 1,
+   "[CLS]": 2,
+   "[MASK]": 4,
+   "[SEP]": 3
+ }
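These IDs mirror the `added_tokens_decoder` block in tokenizer_config.json below. A quick sanity check, as a sketch: it assumes the base checkpoint on the Hub, and passes `trust_remote_code=True` because the tokenizer's `auto_map` (see tokenizer_config.json) points at a custom class.

from transformers import AutoTokenizer

# Sketch: confirm the special-token IDs above round-trip through the tokenizer.
tok = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
print(tok.pad_token_id, tok.unk_token_id)                       # 0 1
print(tok.convert_tokens_to_ids(["[CLS]", "[SEP]", "[MASK]"]))  # [2, 3, 4]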
config.json ADDED
@@ -0,0 +1,36 @@
+ {
+   "_name_or_path": "line-corporation/line-distilbert-base-japanese",
+   "activation": "gelu",
+   "architectures": [
+     "ConsistentSentenceClassifier"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "contradiction",
+     "1": "neutral",
+     "2": "entailment"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "contradiction": 0,
+     "entailment": 2,
+     "neutral": 1
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "output_hidden_states": true,
+   "pad_token_id": 0,
+   "problem_type": "single_label_classification",
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": true,
+   "tie_weights_": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.34.0",
+   "vocab_size": 32768
+ }
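`architectures` names the custom `ConsistentSentenceClassifier` defined in utils.py below, so `AutoModelForSequenceClassification` will not rebuild this head on its own; the label maps, though, can be read straight off the config. A minimal sketch, assuming the checkpoint files sit in the current directory:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(".")      # local checkpoint directory (assumed)
print(cfg.id2label)                        # {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
print(cfg.dim, cfg.n_layers, cfg.n_heads)  # 768 6 12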
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0882f733fa2213d98351829cca6db388f9664784abc99d90cb70d48beaaf16e6
+ size 274758317
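This is a Git LFS pointer rather than the weights themselves; the roughly 275 MB binary is fetched on checkout. The `oid` makes an integrity check straightforward, as in this sketch (assuming pytorch_model.bin has already been downloaded):

import hashlib

# Sketch: verify a downloaded pytorch_model.bin against the LFS pointer's sha256.
h = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
assert h.hexdigest() == "0882f733fa2213d98351829cca6db388f9664784abc99d90cb70d48beaaf16e6"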
special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "bos_token": "[CLS]",
+   "cls_token": "[CLS]",
+   "eos_token": "[SEP]",
+   "mask_token": "[MASK]",
+   "pad_token": "<pad>",
+   "sep_token": "[SEP]",
+   "unk_token": "<unk>"
+ }
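The map assigns roles to tokens that already exist in added_tokens.json; `bos_token`/`cls_token` alias "[CLS]" and `eos_token`/`sep_token` alias "[SEP]". Reusing the `tok` from the sketch after added_tokens.json:

# Sketch: roles alias existing tokens rather than introducing new ones.
print(tok.special_tokens_map["bos_token"], tok.cls_token)  # [CLS] [CLS]
print(tok.special_tokens_map["eos_token"], tok.sep_token)  # [SEP] [SEP]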
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bcfafc8c0662d9c8f39621a64c74260f2ad120310c8dd24886de2dddaf599b4e
+ size 439391
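Another LFS pointer, here for the SentencePiece model that backs the subword step (`subword_tokenizer_type` is "sentencepiece" in tokenizer_config.json below). It can be inspected directly; a sketch, with the vocabulary size expected to line up with config.json:

import sentencepiece as spm

# Sketch: load the SentencePiece model behind the subword tokenizer.
sp = spm.SentencePieceProcessor(model_file="spiece.model")
print(sp.vocab_size())  # expected: 32768, matching config.json's vocab_size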
tokenizer_config.json ADDED
@@ -0,0 +1,76 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "4": {
+       "content": "[MASK]",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "auto_map": {
+     "AutoTokenizer": [
+       "line-corporation/line-distilbert-base-japanese--distilbert_japanese_tokenizer.DistilBertJapaneseTokenizer",
+       null
+     ]
+   },
+   "bos_token": "[CLS]",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "do_subword_tokenize": true,
+   "do_word_tokenize": true,
+   "eos_token": "[SEP]",
+   "jumanpp_kwargs": null,
+   "keep_accents": true,
+   "mask_token": "[MASK]",
+   "mecab_kwargs": {
+     "mecab_dic": "unidic_lite"
+   },
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "<pad>",
+   "remove_space": true,
+   "sep_token": "[SEP]",
+   "subword_tokenizer_type": "sentencepiece",
+   "sudachi_kwargs": null,
+   "tokenize_chinese_chars": false,
+   "tokenizer_class": "BertJapaneseTokenizer",
+   "tokenizer_file": null,
+   "unk_token": "<unk>",
+   "word_tokenizer_type": "mecab"
+ }
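Two details worth noting. Word tokenization is MeCab with the unidic_lite dictionary, so loading this tokenizer needs extra packages (presumably fugashi and unidic-lite, plus sentencepiece for the subword step). And `model_max_length` is the ~1e30 "unset" sentinel, which is why the notebook below logs "Asking to pad to max_length but no maximum length is provided ... Default to no padding." A hedged loading sketch:

from transformers import AutoTokenizer

# Assumed prerequisites: pip install fugashi unidic-lite sentencepiece
tok = AutoTokenizer.from_pretrained(
    "line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
print(tok.tokenize("消火栓から水が勢いよく噴き出しています。"))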
train-v1.1.json ADDED
The diff for this file is too large to render. See raw diff
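The diff is not rendered, but `load_dataset` in utils.py below parses the file line by line with `json.loads`, which implies JSON Lines records with `sentence1`, `sentence2`, and a string `label`. A hypothetical record for illustration (sentences taken from the printed sample below):

{"sentence1": "川を小さなボートが進んで行きます。", "sentence2": "川を豪華客船が進んでいきます。", "label": "contradiction"}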
 
train_factual_consistency.ipynb ADDED
@@ -0,0 +1,212 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "b12ae8a3-9e08-402c-894c-31697fad6c56",
+    "metadata": {},
+    "outputs": [
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "54d7e7ee895949c4a025acf2c9640f96",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "from huggingface_hub import notebook_login\n",
+     "\n",
+     "notebook_login()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "160c80c1-0ca4-45df-8171-87cd3c88a223",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "\n",
+     "from transformers import (\n",
+     "    AutoTokenizer,\n",
+     "    DataCollatorWithPadding,\n",
+     "    Trainer,\n",
+     "    TrainingArguments,\n",
+     ")\n",
+     "from utils import ConsistentSentenceClassifier, get_metrics, load_dataset"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "25800588-5d42-4524-9dc6-a6a0c180b8b0",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       " text label\n",
+       "512 カーキ色の服を着た男性が、口元にリンゴを当てています。[SEP]カーキ色の服を着た男性が、口... 0\n",
+       "513 男性がグラウンドでボールを投げています。[SEP]白い髯を生やした男性がボールを投げています。 1\n",
+       "514 椅子に座った子供が、手づかみで食事をしています。[SEP]椅子に座った子供が手づかみで、食事... 2\n",
+       "515 プロペラ機が何台も駐機しています。[SEP]プロペラ機が何台も連なって飛んでいます。 0\n",
+       "516 消火栓から水が勢いよく噴き出しています。[SEP]水が噴き出している消火栓の水を浴びるように... 1\n",
+       "517 冷蔵庫のないキッチンにナイフとフォークが置かれています。[SEP]冷蔵庫の置かれたキッチンに... 0\n",
+       "518 うみでサーフィンをしているひとがいます。[SEP]黒いウェットスーツを着た人がサーフボードに... 1\n",
+       "519 池から白い鳥が飛び立っています。[SEP]森にある水の上を鳥が飛んでいます。 1\n",
+       "520 丈夫なビーチパラソルが立っています。[SEP]ビーチパラソルの支柱が折れ曲がっています。 0\n",
+       "521 白髪の男性が少女から花束を受け取っています。[SEP]花束を持った男性の前に多くの子供たちが... 1\n",
+       " text label\n",
+       "0 赤いひとつの傘に、二人の人が入っています。[SEP]歩道を歩く通行人が傘をさして歩いています。 1\n",
+       "1 川を小さなボートが進んで行きます。[SEP]川を豪華客船が進んでいきます。 0\n",
+       "2 ゲレンデのこぶでスキージャンプしています。[SEP]雪上でモーグルを楽しむ水色のウェアを着た女性。 1\n",
+       "3 黒いお皿に乗っているピザをカットしています。[SEP]黒い皿の上にピザが盛られています。 2\n",
+       "4 女性が目を細めて携帯電話で話をしています。[SEP]目を細めた女性が携帯電話で話をしています。 2\n",
+       "5 バナナやパパイヤなどの果物が売られている。[SEP]台の上にはバナナなどの青果が並べられています。 1\n",
+       "6 ヘッドライトを点灯させた白いバスが駐車場に止まっています。[SEP]ライトを点灯させているバ... 2\n",
+       "7 水面の上に、カイトサーフィンの凧が揚がっています。[SEP]海の上に水上スポーツ用の凧が揚が... 1\n",
+       "8 ホットドッグを野外で食べている人たちです。[SEP]家の中でホットドッグを食べている。 0\n",
+       "9 草が生い茂っている所に、3頭のゾウがいます。[SEP]草むらの中に三頭のゾウが立っているとこ... 1\n"
+      ]
+     },
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "014ef81bb16c41a383f86e2ddc5ca383",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "Map:   0%|          | 0/19561 [00:00<?, ? examples/s]"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     },
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.\n",
+       "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
+      ]
+     },
+     {
+      "data": {
+       "application/vnd.jupyter.widget-view+json": {
+        "model_id": "7024f08501734cd188d3b8c6dc8495eb",
+        "version_major": 2,
+        "version_minor": 0
+       },
+       "text/plain": [
+        "Map:   0%|          | 0/512 [00:00<?, ? examples/s]"
+       ]
+      },
+      "metadata": {},
+      "output_type": "display_data"
+     }
+    ],
+    "source": [
+     "tokenizer = AutoTokenizer.from_pretrained(\"line-corporation/line-distilbert-base-japanese\")\n",
+     "dataset = load_dataset('train-v1.1.json')\n",
+     "tokenized_dataset = dataset.map(\n",
+     "    lambda examples: tokenizer(examples[\"text\"], padding='max_length', truncation=True), batched=True\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "6bc83d4c-378c-4313-b641-8ead0c02f715",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "WARNING:root:XRT configuration not detected. Defaulting to preview PJRT runtime. To silence this warning and continue using PJRT, explicitly set PJRT_DEVICE to a supported device or configure XRT. To disable default device selection, set PJRT_SELECT_DEFAULT_DEVICE=0\n",
+       "WARNING:root:For more information about the status of PJRT, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md\n",
+       "WARNING:root:Defaulting to PJRT_DEVICE=CPU\n"
+      ]
+     }
+    ],
+    "source": [
+     "model = ConsistentSentenceClassifier(\n",
+     "    freeze_bert=True)\n",
+     "\n",
+     "training_args = TrainingArguments(\n",
+     "    output_dir=\"../factual-consistency-classification-ja-avgpool\",\n",
+     "    learning_rate=1e-4,\n",
+     "    per_device_train_batch_size=64,\n",
+     "    per_device_eval_batch_size=8,\n",
+     "    num_train_epochs=30,\n",
+     "    weight_decay=0.02,\n",
+     "    evaluation_strategy=\"epoch\",\n",
+     "    eval_accumulation_steps=4,\n",
+     "    save_strategy=\"epoch\",\n",
+     "    load_best_model_at_end=True,\n",
+     "    save_total_limit=5,\n",
+     "    push_to_hub=True,\n",
+     ")\n",
+     "\n",
+     "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
+     "trainer = Trainer(\n",
+     "    model=model,\n",
+     "    args=training_args,\n",
+     "    train_dataset=tokenized_dataset[\"train\"],\n",
+     "    eval_dataset=tokenized_dataset[\"test\"],\n",
+     "    tokenizer=tokenizer,\n",
+     "    data_collator=data_collator,\n",
+     "    compute_metrics=get_metrics(),\n",
+     ")\n",
+     "\n",
+     "trainer.train()\n",
+     "trainer.push_to_hub('factual-consistency-classification-ja-avgpool')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "a6eb93f7-5a38-49a2-be0d-e42267e23a0a",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "environment": {
+    "kernel": "python3",
+    "name": "pytorch-gpu.2-0.m112",
+    "type": "gcloud",
+    "uri": "gcr.io/deeplearning-platform-release/pytorch-gpu.2-0:m112"
+   },
+   "kernelspec": {
+    "display_name": "Python 3",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.12"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
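After training, inference is a single forward pass over a "[SEP]"-joined sentence pair. A sketch reusing the `tokenizer` and `model` from the cells above; the kwargs are passed explicitly because BertJapaneseTokenizer emits token_type_ids, which DistilBERT's forward does not accept:

import torch

pair = "川を小さなボートが進んで行きます。[SEP]川を豪華客船が進んでいきます。"
inputs = tokenizer(pair, return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"])
print(model.config.id2label[out.logits.argmax(-1).item()])  # e.g. "contradiction"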
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:85ab236f32cbef6bf6bf1c471ec41a7226363e3945c5bf62fbf3728eca74dee1
+ size 4155
utils.py ADDED
@@ -0,0 +1,123 @@
+ import json
+ import pandas as pd
+ import datasets
+ import numpy as np
+ import evaluate
+ import torch
+ from transformers import AutoModel, DistilBertForSequenceClassification
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from typing import Optional
+
+ SEP_TOKEN = '[SEP]'
+ LABEL2ID = {'entailment': 2, 'neutral': 1, 'contradiction': 0}
+ ID2LABEL = {2: 'entailment', 1: 'neutral', 0: 'contradiction'}
+
+ def format_dataset(arr):
+     text = [el['sentence1'] + SEP_TOKEN + el['sentence2'] for el in arr]
+     label = [LABEL2ID[el['label']] for el in arr]
+     new_df = pd.DataFrame({'text': text, 'label': label})
+     return new_df.sample(frac=1, random_state=42).reset_index(drop=True)
+
+ # Load dataset
+ def load_dataset(path):
+     train_array = []
+     with open(path) as f:
+         for line in f.readlines():
+             if line:
+                 train_array.append(json.loads(line))
+     df = format_dataset(train_array)
+     # Split dataset into train and val
+     df_train = df.iloc[512:, :]
+     # We do not need much test data
+     df_test = df.iloc[:512, :]
+     print(df_train[:10])
+     print(df_test[:10])
+
+     factual_consistency_dataset = datasets.dataset_dict.DatasetDict()
+     factual_consistency_dataset["train"] = datasets.dataset_dict.Dataset.from_pandas(
+         df_train[["text", "label"]])
+     factual_consistency_dataset["test"] = datasets.dataset_dict.Dataset.from_pandas(
+         df_test[["text", "label"]])
+
+     return factual_consistency_dataset
+
+
+ class ConsistentSentenceClassifier(DistilBertForSequenceClassification):
+
+     def __init__(self, freeze_bert=True):
+         base_model = AutoModel.from_pretrained(
+             'line-corporation/line-distilbert-base-japanese', num_labels=3)
+
+         config = base_model.config
+         super(ConsistentSentenceClassifier, self).__init__(config=config)
+         config.num_labels = 3
+         config.id2label = ID2LABEL
+         config.label2id = LABEL2ID
+         config.problem_type = "single_label_classification"
+
+         self.distilbert = base_model
+
+         if not freeze_bert:
+             return
+
+         for param in self.distilbert.parameters():
+             param.requires_grad = False
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         distilbert_output = self.distilbert(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
+         pooled_output = torch.mean(hidden_state, dim=1)  # (bs, dim)
+         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
+         pooled_output = torch.nn.ReLU()(pooled_output)  # (bs, dim)
+         pooled_output = self.dropout(pooled_output)  # (bs, dim)
+         logits = self.classifier(pooled_output)  # (bs, num_labels)
+
+         loss = None
+         if labels is not None:
+             loss_fct = torch.nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + distilbert_output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=distilbert_output.hidden_states,
+             attentions=distilbert_output.attentions,
+         )
+
+
+ # Set up evaluation metric
+ def get_metrics():
+     metric = evaluate.load("accuracy")
+
+     def compute_metrics(eval_pred):
+         predictions, labels = eval_pred
+         preds = predictions[0].argmax(axis=1)
+         return metric.compute(predictions=preds, references=labels)
+
+     return compute_metrics
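The only behavioral change relative to the stock head is the pooling: DistilBertForSequenceClassification takes the first ([CLS]) position of the hidden states, while this subclass mean-pools over the whole sequence, hence the "-avgpool" suffix on the output directory. In miniature:

import torch

hidden_state = torch.randn(2, 8, 768)  # (bs, seq_len, dim)
cls_pooled = hidden_state[:, 0]        # stock DistilBERT head: [CLS] position
avg_pooled = hidden_state.mean(dim=1)  # this model: mean over the sequence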