Benjamin S Liang committed
Commit cfabf1c
1 Parent(s): f59b608

Added finetuned model

Checkpoints in /results/ folder

.ipynb_checkpoints/finetunehupd-checkpoint.ipynb ADDED
@@ -0,0 +1,51 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1df3c609-62a6-49c3-bcc6-29c520f9501c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# -*- coding: utf-8 -*-\n",
+ "\"\"\"FinetuneHUPD.ipynb\n",
+ "\n",
+ "Automatically generated by Colaboratory.\n",
+ "\n",
+ "Original file is located at\n",
+ "    https://colab.research.google.com/drive/17c2CQZx_kyD3-0fuQqv_pCMJ0Evd7fLN\n",
+ "\"\"\"\n",
+ "\n",
+ "# Pretty print\n",
+ "from pprint import pprint\n",
+ "# Datasets load_dataset function\n",
+ "from datasets import load_dataset\n",
+ "# Transformers AutoTokenizer\n",
+ "from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments, AdamW\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torch"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
.ipynb_checkpoints/finetunehupd-checkpoint.py ADDED
@@ -0,0 +1,92 @@
+ # -*- coding: utf-8 -*-
+ """FinetuneHUPD.ipynb
+ 
+ Automatically generated by Colaboratory.
+ 
+ Original file is located at
+     https://colab.research.google.com/drive/17c2CQZx_kyD3-0fuQqv_pCMJ0Evd7fLN
+ """
+ 
+ # Pretty print
+ from pprint import pprint
+ # Datasets load_dataset function
+ from datasets import load_dataset
+ # Transformers AutoTokenizer
+ from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments, AdamW
+ from torch.utils.data import DataLoader
+ import torch
+ 
+ tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
+ 
+ dataset_dict = load_dataset('HUPD/hupd',
+     name='sample',
+     data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
+     icpr_label=None,
+     train_filing_start_date='2016-01-01',
+     train_filing_end_date='2016-01-31',
+     val_filing_start_date='2016-01-01',
+     val_filing_end_date='2016-01-31',
+ )
+ 
+ print('Loading is done!')
+ 
+ # Label-to-index mapping for the decision status field
+ decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 2, 'CONT-REJECTED': 3, 'CONT-ACCEPTED': 4, 'CONT-PENDING': 5}
+ 
+ # Helper function
+ def map_decision_to_string(example):
+     return {'decision': decision_to_str[example['decision']]}
+ 
+ # Re-labeling/mapping.
+ train_set = dataset_dict['train'].map(map_decision_to_string)
+ val_set = dataset_dict['validation'].map(map_decision_to_string)
+ 
+ # Focus on the abstract section and tokenize the text using the tokenizer.
+ _SECTION_ = 'abstract'
+ 
+ # Training set
+ train_set = train_set.map(
+     lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),
+     batched=True)
+ 
+ # Validation set
+ val_set = val_set.map(
+     lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),
+     batched=True)
+ 
+ # Set the format
+ train_set.set_format(type='torch',
+     columns=['input_ids', 'attention_mask', 'decision'])
+ 
+ val_set.set_format(type='torch',
+     columns=['input_ids', 'attention_mask', 'decision'])
+ 
+ #print(train_set['decision'])
+ 
+ # train_dataloader and val_dataloader
+ train_dataloader = DataLoader(train_set, batch_size=16)
+ val_dataloader = DataLoader(val_set, batch_size=16)
+ 
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=6)  # head must match the six decision classes above
+ model.to(device)
+ print(device)
+ print("torch cuda is avail: ")
+ print(torch.cuda.is_available())
+ model.train()
+ optim = AdamW(model.parameters(), lr=5e-5)
+ num_training_epochs = 2
+ 
+ for epoch in range(num_training_epochs):
+     for batch in train_dataloader:
+         optim.zero_grad()
+         input_ids = batch['input_ids'].to(device)
+         attention_mask = batch['attention_mask'].to(device)
+         labels = batch['decision'].to(device)
+         outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
+         loss = outputs[0]
+         loss.backward()
+         optim.step()
+         print("batch finished")
+ 
+ model.eval()
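
The script above ends by switching the model to eval mode but never scores the validation split. A minimal evaluation sketch, reusing model, device, and val_dataloader from the script; the accuracy bookkeeping is an editorial assumption, not part of the committed file:

# Hypothetical evaluation pass (not in the commit): accuracy on val_dataloader.
model.eval()
correct, total = 0, 0
with torch.no_grad():  # no gradients needed at evaluation time
    for batch in val_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['decision'].to(device)
        logits = model(input_ids, attention_mask=attention_mask).logits
        correct += (logits.argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
print(f"validation accuracy: {correct / total:.4f}")
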
finetunehupd.ipynb ADDED
@@ -0,0 +1,1293 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "1df3c609-62a6-49c3-bcc6-29c520f9501c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pretty print\n",
+ "from pprint import pprint\n",
+ "# Datasets load_dataset function\n",
+ "from datasets import load_dataset\n",
+ "# Transformers AutoTokenizer\n",
+ "from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertTokenizerFast, Trainer, TrainingArguments, AdamW\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "58167c28-eb27-4f82-b7d0-8216dbeaf650",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Found cached dataset hupd (C:/Users/calia/.cache/huggingface/datasets/HUPD___hupd/sample-5094df4de61ed3bc/0.0.0/6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "345008775bf549b5a548948949710507",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading is done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n",
+ "\n",
+ "dataset_dict = load_dataset('HUPD/hupd',\n",
+ "    name='sample',\n",
+ "    data_files=\"https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather\",\n",
+ "    icpr_label=None,\n",
+ "    train_filing_start_date='2016-01-01',\n",
+ "    train_filing_end_date='2016-01-21',\n",
+ "    val_filing_start_date='2016-01-22',\n",
+ "    val_filing_end_date='2016-01-31',\n",
+ ")\n",
+ "\n",
+ "print('Loading is done!')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e13c6ad1-a7f2-4806-80a2-e9c4655e1eed",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-9f7788eb9924fd62.arrow\n",
+ "Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-6c3687322fe5b556.arrow\n",
+ "Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-bd3b1eee4495f3ce.arrow\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/9094 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Label-to-index mapping for the decision status field\n",
+ "decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 0, 'CONT-REJECTED': 0, 'CONT-ACCEPTED': 0, 'CONT-PENDING': 0}\n",
+ "\n",
+ "# Helper function\n",
+ "def map_decision_to_string(example):\n",
+ "    return {'decision': decision_to_str[example['decision']]}\n",
+ "\n",
+ "# Re-labeling/mapping.\n",
+ "train_set = dataset_dict['train'].map(map_decision_to_string)\n",
+ "val_set = dataset_dict['validation'].map(map_decision_to_string)\n",
+ "\n",
+ "# Focus on the abstract section and tokenize the text using the tokenizer.\n",
+ "_SECTION_ = 'abstract'\n",
+ "\n",
+ "# Training set\n",
+ "train_set = train_set.map(\n",
+ "    lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),\n",
+ "    batched=True)\n",
+ "\n",
+ "# Validation set\n",
+ "val_set = val_set.map(\n",
+ "    lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),\n",
+ "    batched=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "b5c098be-019b-42ce-9b80-4f6de93ef6a3",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ "    features: ['patent_number', 'decision', 'title', 'abstract', 'claims', 'background', 'summary', 'description', 'cpc_label', 'ipc_label', 'filing_date', 'patent_issue_date', 'date_published', 'examiner_id', 'input_ids', 'attention_mask'],\n",
+ "    num_rows: 16153\n",
+ "})"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "1e5a5390-19fe-4a73-b913-e3c1e2c2a399",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ "    features: ['patent_number', 'decision', 'title', 'abstract', 'claims', 'background', 'summary', 'description', 'cpc_label', 'ipc_label', 'filing_date', 'patent_issue_date', 'date_published', 'examiner_id', 'input_ids', 'attention_mask'],\n",
+ "    num_rows: 9094\n",
+ "})"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "val_set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "4fb69db8-86e5-4c6c-8ac6-853d3e15fb93",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_set = train_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
+ "val_set = val_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
+ "\n",
+ "train_set = train_set.rename_column(\"decision\", \"labels\")\n",
+ "val_set = val_set.rename_column(\"decision\", \"labels\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c0d17213-4b14-418c-980c-0238236096c2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ "    features: ['labels', 'input_ids', 'attention_mask'],\n",
+ "    num_rows: 16153\n",
+ "})"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "da2f1c16-3ba4-4e56-9455-5cd838df4dcd",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset({\n",
+ "    features: ['labels', 'input_ids', 'attention_mask'],\n",
+ "    num_rows: 9094\n",
+ "})"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "val_set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "cfb35702-863d-4fec-83e1-44c4e5668156",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set the format\n",
+ "train_set.set_format(type='torch',\n",
+ "    columns=['labels', 'input_ids', 'attention_mask'])\n",
+ "\n",
+ "val_set.set_format(type='torch',\n",
+ "    columns=['labels', 'input_ids', 'attention_mask'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "d7ac796a-9f6e-4213-960f-e17837c27d87",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# train_dataloader and val_dataloader\n",
+ "train_dataloader = DataLoader(train_set, batch_size=16)\n",
+ "val_dataloader = DataLoader(val_set, batch_size=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "b3248182-fddb-46dc-addb-26981a881d99",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n",
+ "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias', 'classifier.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "cuda\n",
+ "torch cuda is avail: \n",
+ "True\n"
+ ]
+ }
+ ],
+ "source": [
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
+ "model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')\n",
+ "model.to(device)\n",
+ "print(device)\n",
+ "print(\"torch cuda is avail: \")\n",
+ "print(torch.cuda.is_available())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "abb2cf74-3cd5-4ca5-af0e-b0ee80627f2a",
+ "metadata": {},
+ "source": [
+ "HuggingFace Trainer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "99947cf9-a6cd-490f-a81d-32f65fb3cd46",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "training_args = TrainingArguments(\n",
+ "    output_dir='./results/',\n",
+ "    num_train_epochs=2,\n",
+ "    per_device_train_batch_size=16,\n",
+ "    per_device_eval_batch_size=16,\n",
+ "    warmup_steps=500,\n",
+ "    learning_rate=5e-5,\n",
+ "    weight_decay=0.01,\n",
+ "    logging_dir='./logs/',\n",
+ "    logging_steps=10,\n",
+ ")\n",
+ "\n",
+ "trainer = Trainer(\n",
+ "    model=model,\n",
+ "    args=training_args,\n",
+ "    train_dataset=train_set,\n",
+ "    eval_dataset=val_set,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "be865f1d-f29b-4306-8570-900386ac4570",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\calia\\anaconda3\\envs\\ai-finetuning-project\\lib\\site-packages\\transformers\\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+ " warnings.warn(\n",
+ "***** Running training *****\n",
+ " Num examples = 16153\n",
+ " Num Epochs = 2\n",
+ " Instantaneous batch size per device = 16\n",
+ " Total train batch size (w. parallel, distributed & accumulation) = 16\n",
+ " Gradient Accumulation steps = 1\n",
+ " Total optimization steps = 2020\n",
+ " Number of trainable parameters = 66955010\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='2020' max='2020' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [2020/2020 11:47, Epoch 2/2]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Step</th>\n",
+ " <th>Training Loss</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
375
+ " <td>10</td>\n",
376
+ " <td>0.692000</td>\n",
377
+ " </tr>\n",
378
+ " <tr>\n",
379
+ " <td>20</td>\n",
380
+ " <td>0.685100</td>\n",
381
+ " </tr>\n",
382
+ " <tr>\n",
383
+ " <td>30</td>\n",
384
+ " <td>0.684000</td>\n",
385
+ " </tr>\n",
386
+ " <tr>\n",
387
+ " <td>40</td>\n",
388
+ " <td>0.685100</td>\n",
389
+ " </tr>\n",
390
+ " <tr>\n",
391
+ " <td>50</td>\n",
392
+ " <td>0.678400</td>\n",
393
+ " </tr>\n",
394
+ " <tr>\n",
395
+ " <td>60</td>\n",
396
+ " <td>0.687300</td>\n",
397
+ " </tr>\n",
398
+ " <tr>\n",
399
+ " <td>70</td>\n",
400
+ " <td>0.681900</td>\n",
401
+ " </tr>\n",
402
+ " <tr>\n",
403
+ " <td>80</td>\n",
404
+ " <td>0.691100</td>\n",
405
+ " </tr>\n",
406
+ " <tr>\n",
407
+ " <td>90</td>\n",
408
+ " <td>0.683200</td>\n",
409
+ " </tr>\n",
410
+ " <tr>\n",
411
+ " <td>100</td>\n",
412
+ " <td>0.694100</td>\n",
413
+ " </tr>\n",
414
+ " <tr>\n",
415
+ " <td>110</td>\n",
416
+ " <td>0.673300</td>\n",
417
+ " </tr>\n",
418
+ " <tr>\n",
419
+ " <td>120</td>\n",
420
+ " <td>0.694100</td>\n",
421
+ " </tr>\n",
422
+ " <tr>\n",
423
+ " <td>130</td>\n",
424
+ " <td>0.669500</td>\n",
425
+ " </tr>\n",
426
+ " <tr>\n",
427
+ " <td>140</td>\n",
428
+ " <td>0.691100</td>\n",
429
+ " </tr>\n",
430
+ " <tr>\n",
431
+ " <td>150</td>\n",
432
+ " <td>0.683400</td>\n",
433
+ " </tr>\n",
434
+ " <tr>\n",
435
+ " <td>160</td>\n",
436
+ " <td>0.654900</td>\n",
437
+ " </tr>\n",
438
+ " <tr>\n",
439
+ " <td>170</td>\n",
440
+ " <td>0.684300</td>\n",
441
+ " </tr>\n",
442
+ " <tr>\n",
443
+ " <td>180</td>\n",
444
+ " <td>0.679300</td>\n",
445
+ " </tr>\n",
446
+ " <tr>\n",
447
+ " <td>190</td>\n",
448
+ " <td>0.662600</td>\n",
449
+ " </tr>\n",
450
+ " <tr>\n",
451
+ " <td>200</td>\n",
452
+ " <td>0.598400</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <td>210</td>\n",
456
+ " <td>0.717700</td>\n",
457
+ " </tr>\n",
458
+ " <tr>\n",
459
+ " <td>220</td>\n",
460
+ " <td>0.679100</td>\n",
461
+ " </tr>\n",
462
+ " <tr>\n",
463
+ " <td>230</td>\n",
464
+ " <td>0.677500</td>\n",
465
+ " </tr>\n",
466
+ " <tr>\n",
467
+ " <td>240</td>\n",
468
+ " <td>0.668800</td>\n",
469
+ " </tr>\n",
470
+ " <tr>\n",
471
+ " <td>250</td>\n",
472
+ " <td>0.678100</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <td>260</td>\n",
476
+ " <td>0.657500</td>\n",
477
+ " </tr>\n",
478
+ " <tr>\n",
479
+ " <td>270</td>\n",
480
+ " <td>0.707200</td>\n",
481
+ " </tr>\n",
482
+ " <tr>\n",
483
+ " <td>280</td>\n",
484
+ " <td>0.670300</td>\n",
485
+ " </tr>\n",
486
+ " <tr>\n",
487
+ " <td>290</td>\n",
488
+ " <td>0.659900</td>\n",
489
+ " </tr>\n",
490
+ " <tr>\n",
491
+ " <td>300</td>\n",
492
+ " <td>0.633300</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <td>310</td>\n",
496
+ " <td>0.676300</td>\n",
497
+ " </tr>\n",
498
+ " <tr>\n",
499
+ " <td>320</td>\n",
500
+ " <td>0.684800</td>\n",
501
+ " </tr>\n",
502
+ " <tr>\n",
503
+ " <td>330</td>\n",
504
+ " <td>0.673100</td>\n",
505
+ " </tr>\n",
506
+ " <tr>\n",
507
+ " <td>340</td>\n",
508
+ " <td>0.670500</td>\n",
509
+ " </tr>\n",
510
+ " <tr>\n",
511
+ " <td>350</td>\n",
512
+ " <td>0.657500</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <td>360</td>\n",
516
+ " <td>0.618100</td>\n",
517
+ " </tr>\n",
518
+ " <tr>\n",
519
+ " <td>370</td>\n",
520
+ " <td>0.670000</td>\n",
521
+ " </tr>\n",
522
+ " <tr>\n",
523
+ " <td>380</td>\n",
524
+ " <td>0.607400</td>\n",
525
+ " </tr>\n",
526
+ " <tr>\n",
527
+ " <td>390</td>\n",
528
+ " <td>0.656200</td>\n",
529
+ " </tr>\n",
530
+ " <tr>\n",
531
+ " <td>400</td>\n",
532
+ " <td>0.700000</td>\n",
533
+ " </tr>\n",
534
+ " <tr>\n",
535
+ " <td>410</td>\n",
536
+ " <td>0.644800</td>\n",
537
+ " </tr>\n",
538
+ " <tr>\n",
539
+ " <td>420</td>\n",
540
+ " <td>0.682800</td>\n",
541
+ " </tr>\n",
542
+ " <tr>\n",
543
+ " <td>430</td>\n",
544
+ " <td>0.668800</td>\n",
545
+ " </tr>\n",
546
+ " <tr>\n",
547
+ " <td>440</td>\n",
548
+ " <td>0.662600</td>\n",
549
+ " </tr>\n",
550
+ " <tr>\n",
551
+ " <td>450</td>\n",
552
+ " <td>0.647700</td>\n",
553
+ " </tr>\n",
554
+ " <tr>\n",
555
+ " <td>460</td>\n",
556
+ " <td>0.688600</td>\n",
557
+ " </tr>\n",
558
+ " <tr>\n",
559
+ " <td>470</td>\n",
560
+ " <td>0.682400</td>\n",
561
+ " </tr>\n",
562
+ " <tr>\n",
563
+ " <td>480</td>\n",
564
+ " <td>0.642900</td>\n",
565
+ " </tr>\n",
566
+ " <tr>\n",
567
+ " <td>490</td>\n",
568
+ " <td>0.726900</td>\n",
569
+ " </tr>\n",
570
+ " <tr>\n",
571
+ " <td>500</td>\n",
572
+ " <td>0.660400</td>\n",
573
+ " </tr>\n",
574
+ " <tr>\n",
575
+ " <td>510</td>\n",
576
+ " <td>0.649500</td>\n",
577
+ " </tr>\n",
578
+ " <tr>\n",
579
+ " <td>520</td>\n",
580
+ " <td>0.637200</td>\n",
581
+ " </tr>\n",
582
+ " <tr>\n",
583
+ " <td>530</td>\n",
584
+ " <td>0.669700</td>\n",
585
+ " </tr>\n",
586
+ " <tr>\n",
587
+ " <td>540</td>\n",
588
+ " <td>0.667100</td>\n",
589
+ " </tr>\n",
590
+ " <tr>\n",
591
+ " <td>550</td>\n",
592
+ " <td>0.617000</td>\n",
593
+ " </tr>\n",
594
+ " <tr>\n",
595
+ " <td>560</td>\n",
596
+ " <td>0.725300</td>\n",
597
+ " </tr>\n",
598
+ " <tr>\n",
599
+ " <td>570</td>\n",
600
+ " <td>0.656800</td>\n",
601
+ " </tr>\n",
602
+ " <tr>\n",
603
+ " <td>580</td>\n",
604
+ " <td>0.664600</td>\n",
605
+ " </tr>\n",
606
+ " <tr>\n",
607
+ " <td>590</td>\n",
608
+ " <td>0.702600</td>\n",
609
+ " </tr>\n",
610
+ " <tr>\n",
611
+ " <td>600</td>\n",
612
+ " <td>0.686300</td>\n",
613
+ " </tr>\n",
614
+ " <tr>\n",
615
+ " <td>610</td>\n",
616
+ " <td>0.668400</td>\n",
617
+ " </tr>\n",
618
+ " <tr>\n",
619
+ " <td>620</td>\n",
620
+ " <td>0.648200</td>\n",
621
+ " </tr>\n",
622
+ " <tr>\n",
623
+ " <td>630</td>\n",
624
+ " <td>0.628700</td>\n",
625
+ " </tr>\n",
626
+ " <tr>\n",
627
+ " <td>640</td>\n",
628
+ " <td>0.676700</td>\n",
629
+ " </tr>\n",
630
+ " <tr>\n",
631
+ " <td>650</td>\n",
632
+ " <td>0.652400</td>\n",
633
+ " </tr>\n",
634
+ " <tr>\n",
635
+ " <td>660</td>\n",
636
+ " <td>0.654300</td>\n",
637
+ " </tr>\n",
638
+ " <tr>\n",
639
+ " <td>670</td>\n",
640
+ " <td>0.640800</td>\n",
641
+ " </tr>\n",
642
+ " <tr>\n",
643
+ " <td>680</td>\n",
644
+ " <td>0.672000</td>\n",
645
+ " </tr>\n",
646
+ " <tr>\n",
647
+ " <td>690</td>\n",
648
+ " <td>0.636100</td>\n",
649
+ " </tr>\n",
650
+ " <tr>\n",
651
+ " <td>700</td>\n",
652
+ " <td>0.689100</td>\n",
653
+ " </tr>\n",
654
+ " <tr>\n",
655
+ " <td>710</td>\n",
656
+ " <td>0.691100</td>\n",
657
+ " </tr>\n",
658
+ " <tr>\n",
659
+ " <td>720</td>\n",
660
+ " <td>0.650300</td>\n",
661
+ " </tr>\n",
662
+ " <tr>\n",
663
+ " <td>730</td>\n",
664
+ " <td>0.655200</td>\n",
665
+ " </tr>\n",
666
+ " <tr>\n",
667
+ " <td>740</td>\n",
668
+ " <td>0.668400</td>\n",
669
+ " </tr>\n",
670
+ " <tr>\n",
671
+ " <td>750</td>\n",
672
+ " <td>0.659200</td>\n",
673
+ " </tr>\n",
674
+ " <tr>\n",
675
+ " <td>760</td>\n",
676
+ " <td>0.647800</td>\n",
677
+ " </tr>\n",
678
+ " <tr>\n",
679
+ " <td>770</td>\n",
680
+ " <td>0.662800</td>\n",
681
+ " </tr>\n",
682
+ " <tr>\n",
683
+ " <td>780</td>\n",
684
+ " <td>0.648500</td>\n",
685
+ " </tr>\n",
686
+ " <tr>\n",
687
+ " <td>790</td>\n",
688
+ " <td>0.656700</td>\n",
689
+ " </tr>\n",
690
+ " <tr>\n",
691
+ " <td>800</td>\n",
692
+ " <td>0.669400</td>\n",
693
+ " </tr>\n",
694
+ " <tr>\n",
695
+ " <td>810</td>\n",
696
+ " <td>0.607800</td>\n",
697
+ " </tr>\n",
698
+ " <tr>\n",
699
+ " <td>820</td>\n",
700
+ " <td>0.683200</td>\n",
701
+ " </tr>\n",
702
+ " <tr>\n",
703
+ " <td>830</td>\n",
704
+ " <td>0.663800</td>\n",
705
+ " </tr>\n",
706
+ " <tr>\n",
707
+ " <td>840</td>\n",
708
+ " <td>0.700900</td>\n",
709
+ " </tr>\n",
710
+ " <tr>\n",
711
+ " <td>850</td>\n",
712
+ " <td>0.648200</td>\n",
713
+ " </tr>\n",
714
+ " <tr>\n",
715
+ " <td>860</td>\n",
716
+ " <td>0.619400</td>\n",
717
+ " </tr>\n",
718
+ " <tr>\n",
719
+ " <td>870</td>\n",
720
+ " <td>0.649200</td>\n",
721
+ " </tr>\n",
722
+ " <tr>\n",
723
+ " <td>880</td>\n",
724
+ " <td>0.717500</td>\n",
725
+ " </tr>\n",
726
+ " <tr>\n",
727
+ " <td>890</td>\n",
728
+ " <td>0.669600</td>\n",
729
+ " </tr>\n",
730
+ " <tr>\n",
731
+ " <td>900</td>\n",
732
+ " <td>0.669700</td>\n",
733
+ " </tr>\n",
734
+ " <tr>\n",
735
+ " <td>910</td>\n",
736
+ " <td>0.683900</td>\n",
737
+ " </tr>\n",
738
+ " <tr>\n",
739
+ " <td>920</td>\n",
740
+ " <td>0.636900</td>\n",
741
+ " </tr>\n",
742
+ " <tr>\n",
743
+ " <td>930</td>\n",
744
+ " <td>0.656400</td>\n",
745
+ " </tr>\n",
746
+ " <tr>\n",
747
+ " <td>940</td>\n",
748
+ " <td>0.650000</td>\n",
749
+ " </tr>\n",
750
+ " <tr>\n",
751
+ " <td>950</td>\n",
752
+ " <td>0.617800</td>\n",
753
+ " </tr>\n",
754
+ " <tr>\n",
755
+ " <td>960</td>\n",
756
+ " <td>0.665600</td>\n",
757
+ " </tr>\n",
758
+ " <tr>\n",
759
+ " <td>970</td>\n",
760
+ " <td>0.642700</td>\n",
761
+ " </tr>\n",
762
+ " <tr>\n",
763
+ " <td>980</td>\n",
764
+ " <td>0.644000</td>\n",
765
+ " </tr>\n",
766
+ " <tr>\n",
767
+ " <td>990</td>\n",
768
+ " <td>0.688900</td>\n",
769
+ " </tr>\n",
770
+ " <tr>\n",
771
+ " <td>1000</td>\n",
772
+ " <td>0.654700</td>\n",
773
+ " </tr>\n",
774
+ " <tr>\n",
775
+ " <td>1010</td>\n",
776
+ " <td>0.645800</td>\n",
777
+ " </tr>\n",
778
+ " <tr>\n",
779
+ " <td>1020</td>\n",
780
+ " <td>0.609200</td>\n",
781
+ " </tr>\n",
782
+ " <tr>\n",
783
+ " <td>1030</td>\n",
784
+ " <td>0.602300</td>\n",
785
+ " </tr>\n",
786
+ " <tr>\n",
787
+ " <td>1040</td>\n",
788
+ " <td>0.618800</td>\n",
789
+ " </tr>\n",
790
+ " <tr>\n",
791
+ " <td>1050</td>\n",
792
+ " <td>0.643500</td>\n",
793
+ " </tr>\n",
794
+ " <tr>\n",
795
+ " <td>1060</td>\n",
796
+ " <td>0.611000</td>\n",
797
+ " </tr>\n",
798
+ " <tr>\n",
799
+ " <td>1070</td>\n",
800
+ " <td>0.645000</td>\n",
801
+ " </tr>\n",
802
+ " <tr>\n",
803
+ " <td>1080</td>\n",
804
+ " <td>0.641000</td>\n",
805
+ " </tr>\n",
806
+ " <tr>\n",
807
+ " <td>1090</td>\n",
808
+ " <td>0.595400</td>\n",
809
+ " </tr>\n",
810
+ " <tr>\n",
811
+ " <td>1100</td>\n",
812
+ " <td>0.635100</td>\n",
813
+ " </tr>\n",
814
+ " <tr>\n",
815
+ " <td>1110</td>\n",
816
+ " <td>0.611600</td>\n",
817
+ " </tr>\n",
818
+ " <tr>\n",
819
+ " <td>1120</td>\n",
820
+ " <td>0.600300</td>\n",
821
+ " </tr>\n",
822
+ " <tr>\n",
823
+ " <td>1130</td>\n",
824
+ " <td>0.618100</td>\n",
825
+ " </tr>\n",
826
+ " <tr>\n",
827
+ " <td>1140</td>\n",
828
+ " <td>0.617200</td>\n",
829
+ " </tr>\n",
830
+ " <tr>\n",
831
+ " <td>1150</td>\n",
832
+ " <td>0.633400</td>\n",
833
+ " </tr>\n",
834
+ " <tr>\n",
835
+ " <td>1160</td>\n",
836
+ " <td>0.597600</td>\n",
837
+ " </tr>\n",
838
+ " <tr>\n",
839
+ " <td>1170</td>\n",
840
+ " <td>0.619400</td>\n",
841
+ " </tr>\n",
842
+ " <tr>\n",
843
+ " <td>1180</td>\n",
844
+ " <td>0.584200</td>\n",
845
+ " </tr>\n",
846
+ " <tr>\n",
847
+ " <td>1190</td>\n",
848
+ " <td>0.600700</td>\n",
849
+ " </tr>\n",
850
+ " <tr>\n",
851
+ " <td>1200</td>\n",
852
+ " <td>0.657400</td>\n",
853
+ " </tr>\n",
854
+ " <tr>\n",
855
+ " <td>1210</td>\n",
856
+ " <td>0.569600</td>\n",
857
+ " </tr>\n",
858
+ " <tr>\n",
859
+ " <td>1220</td>\n",
860
+ " <td>0.575500</td>\n",
861
+ " </tr>\n",
862
+ " <tr>\n",
863
+ " <td>1230</td>\n",
864
+ " <td>0.617900</td>\n",
865
+ " </tr>\n",
866
+ " <tr>\n",
867
+ " <td>1240</td>\n",
868
+ " <td>0.610300</td>\n",
869
+ " </tr>\n",
870
+ " <tr>\n",
871
+ " <td>1250</td>\n",
872
+ " <td>0.570600</td>\n",
873
+ " </tr>\n",
874
+ " <tr>\n",
875
+ " <td>1260</td>\n",
876
+ " <td>0.545700</td>\n",
877
+ " </tr>\n",
878
+ " <tr>\n",
879
+ " <td>1270</td>\n",
880
+ " <td>0.656300</td>\n",
881
+ " </tr>\n",
882
+ " <tr>\n",
883
+ " <td>1280</td>\n",
884
+ " <td>0.554700</td>\n",
885
+ " </tr>\n",
886
+ " <tr>\n",
887
+ " <td>1290</td>\n",
888
+ " <td>0.598200</td>\n",
889
+ " </tr>\n",
890
+ " <tr>\n",
891
+ " <td>1300</td>\n",
892
+ " <td>0.606300</td>\n",
893
+ " </tr>\n",
894
+ " <tr>\n",
895
+ " <td>1310</td>\n",
896
+ " <td>0.600500</td>\n",
897
+ " </tr>\n",
898
+ " <tr>\n",
899
+ " <td>1320</td>\n",
900
+ " <td>0.569800</td>\n",
901
+ " </tr>\n",
902
+ " <tr>\n",
903
+ " <td>1330</td>\n",
904
+ " <td>0.604700</td>\n",
905
+ " </tr>\n",
906
+ " <tr>\n",
907
+ " <td>1340</td>\n",
908
+ " <td>0.628300</td>\n",
909
+ " </tr>\n",
910
+ " <tr>\n",
911
+ " <td>1350</td>\n",
912
+ " <td>0.602700</td>\n",
913
+ " </tr>\n",
914
+ " <tr>\n",
915
+ " <td>1360</td>\n",
916
+ " <td>0.583700</td>\n",
917
+ " </tr>\n",
918
+ " <tr>\n",
919
+ " <td>1370</td>\n",
920
+ " <td>0.623800</td>\n",
921
+ " </tr>\n",
922
+ " <tr>\n",
923
+ " <td>1380</td>\n",
924
+ " <td>0.670300</td>\n",
925
+ " </tr>\n",
926
+ " <tr>\n",
927
+ " <td>1390</td>\n",
928
+ " <td>0.622400</td>\n",
929
+ " </tr>\n",
930
+ " <tr>\n",
931
+ " <td>1400</td>\n",
932
+ " <td>0.590200</td>\n",
933
+ " </tr>\n",
934
+ " <tr>\n",
935
+ " <td>1410</td>\n",
936
+ " <td>0.587000</td>\n",
937
+ " </tr>\n",
938
+ " <tr>\n",
939
+ " <td>1420</td>\n",
940
+ " <td>0.555500</td>\n",
941
+ " </tr>\n",
942
+ " <tr>\n",
943
+ " <td>1430</td>\n",
944
+ " <td>0.561000</td>\n",
945
+ " </tr>\n",
946
+ " <tr>\n",
947
+ " <td>1440</td>\n",
948
+ " <td>0.514300</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <td>1450</td>\n",
952
+ " <td>0.553100</td>\n",
953
+ " </tr>\n",
954
+ " <tr>\n",
955
+ " <td>1460</td>\n",
956
+ " <td>0.692400</td>\n",
957
+ " </tr>\n",
958
+ " <tr>\n",
959
+ " <td>1470</td>\n",
960
+ " <td>0.605200</td>\n",
961
+ " </tr>\n",
962
+ " <tr>\n",
963
+ " <td>1480</td>\n",
964
+ " <td>0.548000</td>\n",
965
+ " </tr>\n",
966
+ " <tr>\n",
967
+ " <td>1490</td>\n",
968
+ " <td>0.672600</td>\n",
969
+ " </tr>\n",
970
+ " <tr>\n",
971
+ " <td>1500</td>\n",
972
+ " <td>0.531100</td>\n",
973
+ " </tr>\n",
974
+ " <tr>\n",
975
+ " <td>1510</td>\n",
976
+ " <td>0.610600</td>\n",
977
+ " </tr>\n",
978
+ " <tr>\n",
979
+ " <td>1520</td>\n",
980
+ " <td>0.580200</td>\n",
981
+ " </tr>\n",
982
+ " <tr>\n",
983
+ " <td>1530</td>\n",
984
+ " <td>0.571300</td>\n",
985
+ " </tr>\n",
986
+ " <tr>\n",
987
+ " <td>1540</td>\n",
988
+ " <td>0.644400</td>\n",
989
+ " </tr>\n",
990
+ " <tr>\n",
991
+ " <td>1550</td>\n",
992
+ " <td>0.558500</td>\n",
993
+ " </tr>\n",
994
+ " <tr>\n",
995
+ " <td>1560</td>\n",
996
+ " <td>0.624000</td>\n",
997
+ " </tr>\n",
998
+ " <tr>\n",
999
+ " <td>1570</td>\n",
1000
+ " <td>0.659200</td>\n",
1001
+ " </tr>\n",
1002
+ " <tr>\n",
1003
+ " <td>1580</td>\n",
1004
+ " <td>0.580500</td>\n",
1005
+ " </tr>\n",
1006
+ " <tr>\n",
1007
+ " <td>1590</td>\n",
1008
+ " <td>0.649900</td>\n",
1009
+ " </tr>\n",
1010
+ " <tr>\n",
1011
+ " <td>1600</td>\n",
1012
+ " <td>0.608700</td>\n",
1013
+ " </tr>\n",
1014
+ " <tr>\n",
1015
+ " <td>1610</td>\n",
1016
+ " <td>0.595100</td>\n",
1017
+ " </tr>\n",
1018
+ " <tr>\n",
1019
+ " <td>1620</td>\n",
1020
+ " <td>0.592900</td>\n",
1021
+ " </tr>\n",
1022
+ " <tr>\n",
1023
+ " <td>1630</td>\n",
1024
+ " <td>0.584000</td>\n",
1025
+ " </tr>\n",
1026
+ " <tr>\n",
1027
+ " <td>1640</td>\n",
1028
+ " <td>0.607100</td>\n",
1029
+ " </tr>\n",
1030
+ " <tr>\n",
1031
+ " <td>1650</td>\n",
1032
+ " <td>0.565800</td>\n",
1033
+ " </tr>\n",
1034
+ " <tr>\n",
1035
+ " <td>1660</td>\n",
1036
+ " <td>0.568300</td>\n",
1037
+ " </tr>\n",
1038
+ " <tr>\n",
1039
+ " <td>1670</td>\n",
1040
+ " <td>0.572200</td>\n",
1041
+ " </tr>\n",
1042
+ " <tr>\n",
1043
+ " <td>1680</td>\n",
1044
+ " <td>0.597500</td>\n",
1045
+ " </tr>\n",
1046
+ " <tr>\n",
1047
+ " <td>1690</td>\n",
1048
+ " <td>0.602700</td>\n",
1049
+ " </tr>\n",
1050
+ " <tr>\n",
1051
+ " <td>1700</td>\n",
1052
+ " <td>0.692900</td>\n",
1053
+ " </tr>\n",
1054
+ " <tr>\n",
1055
+ " <td>1710</td>\n",
1056
+ " <td>0.597900</td>\n",
1057
+ " </tr>\n",
1058
+ " <tr>\n",
1059
+ " <td>1720</td>\n",
1060
+ " <td>0.538600</td>\n",
1061
+ " </tr>\n",
1062
+ " <tr>\n",
1063
+ " <td>1730</td>\n",
1064
+ " <td>0.599400</td>\n",
1065
+ " </tr>\n",
1066
+ " <tr>\n",
1067
+ " <td>1740</td>\n",
1068
+ " <td>0.704300</td>\n",
1069
+ " </tr>\n",
1070
+ " <tr>\n",
1071
+ " <td>1750</td>\n",
1072
+ " <td>0.580500</td>\n",
1073
+ " </tr>\n",
1074
+ " <tr>\n",
1075
+ " <td>1760</td>\n",
1076
+ " <td>0.595600</td>\n",
1077
+ " </tr>\n",
1078
+ " <tr>\n",
1079
+ " <td>1770</td>\n",
1080
+ " <td>0.583100</td>\n",
1081
+ " </tr>\n",
1082
+ " <tr>\n",
1083
+ " <td>1780</td>\n",
1084
+ " <td>0.569500</td>\n",
1085
+ " </tr>\n",
1086
+ " <tr>\n",
1087
+ " <td>1790</td>\n",
1088
+ " <td>0.603300</td>\n",
1089
+ " </tr>\n",
1090
+ " <tr>\n",
1091
+ " <td>1800</td>\n",
1092
+ " <td>0.564500</td>\n",
1093
+ " </tr>\n",
1094
+ " <tr>\n",
1095
+ " <td>1810</td>\n",
1096
+ " <td>0.592100</td>\n",
1097
+ " </tr>\n",
1098
+ " <tr>\n",
1099
+ " <td>1820</td>\n",
1100
+ " <td>0.617000</td>\n",
1101
+ " </tr>\n",
1102
+ " <tr>\n",
1103
+ " <td>1830</td>\n",
1104
+ " <td>0.656500</td>\n",
1105
+ " </tr>\n",
1106
+ " <tr>\n",
1107
+ " <td>1840</td>\n",
1108
+ " <td>0.563600</td>\n",
1109
+ " </tr>\n",
1110
+ " <tr>\n",
1111
+ " <td>1850</td>\n",
1112
+ " <td>0.624800</td>\n",
1113
+ " </tr>\n",
1114
+ " <tr>\n",
1115
+ " <td>1860</td>\n",
1116
+ " <td>0.686700</td>\n",
1117
+ " </tr>\n",
1118
+ " <tr>\n",
1119
+ " <td>1870</td>\n",
1120
+ " <td>0.572300</td>\n",
1121
+ " </tr>\n",
1122
+ " <tr>\n",
1123
+ " <td>1880</td>\n",
1124
+ " <td>0.587700</td>\n",
1125
+ " </tr>\n",
1126
+ " <tr>\n",
1127
+ " <td>1890</td>\n",
1128
+ " <td>0.583000</td>\n",
1129
+ " </tr>\n",
1130
+ " <tr>\n",
1131
+ " <td>1900</td>\n",
1132
+ " <td>0.601500</td>\n",
1133
+ " </tr>\n",
1134
+ " <tr>\n",
1135
+ " <td>1910</td>\n",
1136
+ " <td>0.559700</td>\n",
1137
+ " </tr>\n",
1138
+ " <tr>\n",
1139
+ " <td>1920</td>\n",
1140
+ " <td>0.610100</td>\n",
1141
+ " </tr>\n",
1142
+ " <tr>\n",
1143
+ " <td>1930</td>\n",
1144
+ " <td>0.571300</td>\n",
1145
+ " </tr>\n",
1146
+ " <tr>\n",
1147
+ " <td>1940</td>\n",
1148
+ " <td>0.549900</td>\n",
1149
+ " </tr>\n",
1150
+ " <tr>\n",
1151
+ " <td>1950</td>\n",
1152
+ " <td>0.589200</td>\n",
1153
+ " </tr>\n",
1154
+ " <tr>\n",
1155
+ " <td>1960</td>\n",
1156
+ " <td>0.634800</td>\n",
1157
+ " </tr>\n",
1158
+ " <tr>\n",
1159
+ " <td>1970</td>\n",
1160
+ " <td>0.584200</td>\n",
1161
+ " </tr>\n",
1162
+ " <tr>\n",
1163
+ " <td>1980</td>\n",
1164
+ " <td>0.557000</td>\n",
1165
+ " </tr>\n",
1166
+ " <tr>\n",
1167
+ " <td>1990</td>\n",
1168
+ " <td>0.602700</td>\n",
1169
+ " </tr>\n",
1170
+ " <tr>\n",
1171
+ " <td>2000</td>\n",
1172
+ " <td>0.669700</td>\n",
1173
+ " </tr>\n",
1174
+ " <tr>\n",
1175
+ " <td>2010</td>\n",
1176
+ " <td>0.607500</td>\n",
1177
+ " </tr>\n",
1178
+ " <tr>\n",
1179
+ " <td>2020</td>\n",
1180
+ " <td>0.631800</td>\n",
1181
+ " </tr>\n",
1182
+ " </tbody>\n",
1183
+ "</table><p>"
1184
+ ],
1185
+ "text/plain": [
1186
+ "<IPython.core.display.HTML object>"
1187
+ ]
1188
+ },
1189
+ "metadata": {},
1190
+ "output_type": "display_data"
1191
+ },
1192
+ {
1193
+ "name": "stderr",
1194
+ "output_type": "stream",
1195
+ "text": [
1196
+ "Saving model checkpoint to ./results/checkpoint-500\n",
1197
+ "Configuration saved in ./results/checkpoint-500\\config.json\n",
1198
+ "Model weights saved in ./results/checkpoint-500\\pytorch_model.bin\n",
1199
+ "Saving model checkpoint to ./results/checkpoint-1000\n",
1200
+ "Configuration saved in ./results/checkpoint-1000\\config.json\n",
1201
+ "Model weights saved in ./results/checkpoint-1000\\pytorch_model.bin\n",
1202
+ "Saving model checkpoint to ./results/checkpoint-1500\n",
1203
+ "Configuration saved in ./results/checkpoint-1500\\config.json\n",
1204
+ "Model weights saved in ./results/checkpoint-1500\\pytorch_model.bin\n",
1205
+ "Saving model checkpoint to ./results/checkpoint-2000\n",
1206
+ "Configuration saved in ./results/checkpoint-2000\\config.json\n",
1207
+ "Model weights saved in ./results/checkpoint-2000\\pytorch_model.bin\n",
1208
+ "\n",
1209
+ "\n",
1210
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
1211
+ "\n",
1212
+ "\n"
1213
+ ]
1214
+ },
1215
+ {
1216
+ "data": {
1217
+ "text/plain": [
1218
+ "TrainOutput(global_step=2020, training_loss=0.6342116433795136, metrics={'train_runtime': 708.5025, 'train_samples_per_second': 45.598, 'train_steps_per_second': 2.851, 'total_flos': 4279491780980736.0, 'train_loss': 0.6342116433795136, 'epoch': 2.0})"
1219
+ ]
1220
+ },
1221
+ "execution_count": 13,
1222
+ "metadata": {},
1223
+ "output_type": "execute_result"
1224
+ }
1225
+ ],
1226
+ "source": [
1227
+ "trainer.train()"
1228
+ ]
1229
+ },
1230
+ {
1231
+ "cell_type": "markdown",
1232
+ "id": "304e0d65-74cf-4945-978d-b9f56c5a83b1",
1233
+ "metadata": {},
1234
+ "source": [
1235
+ "PyTorch Training Loop"
1236
+ ]
1237
+ },
1238
+ {
1239
+ "cell_type": "code",
1240
+ "execution_count": null,
1241
+ "id": "e56d14fb-dfde-40fa-9dfa-1187c2e09866",
1242
+ "metadata": {},
1243
+ "outputs": [],
1244
+ "source": [
1245
+ "# model.train()\n",
1246
+ "# optim = AdamW(model.parameters(), lr=5e-5)\n",
1247
+ "# num_training_epochs = 2\n",
1248
+ "\n",
1249
+ "# for epoch in range(num_training_epochs):\n",
1250
+ "# print(\"starting epoch: \" + str(epoch))\n",
1251
+ "# for batch in train_dataloader:\n",
1252
+ "# optim.zero_grad()\n",
1253
+ "# input_ids = batch['input_ids'].to(device)\n",
1254
+ "# attention_mask = batch['attention_mask'].to(device)\n",
1255
+ "# labels = batch['labels'].to(device)\n",
1256
+ "# outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
1257
+ "# loss = outputs[0]\n",
1258
+ "# loss.backward()\n",
1259
+ "# optim.step()\n",
1260
+ "# model.eval()"
1261
+ ]
1262
+ },
1263
+ {
1264
+ "cell_type": "code",
1265
+ "execution_count": null,
1266
+ "id": "9b496593-c0de-4ce2-95d5-d5d3bf09d93c",
1267
+ "metadata": {},
1268
+ "outputs": [],
1269
+ "source": []
1270
+ }
1271
+ ],
1272
+ "metadata": {
1273
+ "kernelspec": {
1274
+ "display_name": "Python 3 (ipykernel)",
1275
+ "language": "python",
1276
+ "name": "python3"
1277
+ },
1278
+ "language_info": {
1279
+ "codemirror_mode": {
1280
+ "name": "ipython",
1281
+ "version": 3
1282
+ },
1283
+ "file_extension": ".py",
1284
+ "mimetype": "text/x-python",
1285
+ "name": "python",
1286
+ "nbconvert_exporter": "python",
1287
+ "pygments_lexer": "ipython3",
1288
+ "version": "3.9.16"
1289
+ }
1290
+ },
1291
+ "nbformat": 4,
1292
+ "nbformat_minor": 5
1293
+ }
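
Per the commit message, the Trainer checkpoints land in ./results/, and the training log above shows checkpoint-500 through checkpoint-2000 being written. A minimal sketch of reloading the final checkpoint for inference; the example abstract and variable names are assumptions, not part of the commit:

# Hypothetical inference snippet (not in the commit): reload the last
# Trainer checkpoint and classify a single abstract.
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('./results/checkpoint-2000')
model.eval()

abstract = "A battery electrode comprising ..."  # placeholder input
inputs = tokenizer(abstract, truncation=True, padding='max_length', return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits
# Under the notebook's binary mapping, 1 = ACCEPTED and 0 = every other decision.
print(logits.argmax(dim=-1).item())
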
finetunehupd.py ADDED
@@ -0,0 +1,92 @@
+ # -*- coding: utf-8 -*-
+ """FinetuneHUPD.ipynb
+ 
+ Automatically generated by Colaboratory.
+ 
+ Original file is located at
+     https://colab.research.google.com/drive/17c2CQZx_kyD3-0fuQqv_pCMJ0Evd7fLN
+ """
+ 
+ # Pretty print
+ from pprint import pprint
+ # Datasets load_dataset function
+ from datasets import load_dataset
+ # Transformers AutoTokenizer
+ from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments, AdamW
+ from torch.utils.data import DataLoader
+ import torch
+ 
+ tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
+ 
+ dataset_dict = load_dataset('HUPD/hupd',
+     name='sample',
+     data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
+     icpr_label=None,
+     train_filing_start_date='2016-01-01',
+     train_filing_end_date='2016-01-31',
+     val_filing_start_date='2016-01-01',
+     val_filing_end_date='2016-01-31',
+ )
+ 
+ print('Loading is done!')
+ 
+ # Label-to-index mapping for the decision status field
+ decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 2, 'CONT-REJECTED': 3, 'CONT-ACCEPTED': 4, 'CONT-PENDING': 5}
+ 
+ # Helper function
+ def map_decision_to_string(example):
+     return {'decision': decision_to_str[example['decision']]}
+ 
+ # Re-labeling/mapping.
+ train_set = dataset_dict['train'].map(map_decision_to_string)
+ val_set = dataset_dict['validation'].map(map_decision_to_string)
+ 
+ # Focus on the abstract section and tokenize the text using the tokenizer.
+ _SECTION_ = 'abstract'
+ 
+ # Training set
+ train_set = train_set.map(
+     lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),
+     batched=True)
+ 
+ # Validation set
+ val_set = val_set.map(
+     lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),
+     batched=True)
+ 
+ # Set the format
+ train_set.set_format(type='torch',
+     columns=['input_ids', 'attention_mask', 'decision'])
+ 
+ val_set.set_format(type='torch',
+     columns=['input_ids', 'attention_mask', 'decision'])
+ 
+ #print(train_set['decision'])
+ 
+ # train_dataloader and val_dataloader
+ train_dataloader = DataLoader(train_set, batch_size=16)
+ val_dataloader = DataLoader(val_set, batch_size=16)
+ 
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=6)  # head must match the six decision classes above
+ model.to(device)
+ print(device)
+ print("torch cuda is avail: ")
+ print(torch.cuda.is_available())
+ model.train()
+ optim = AdamW(model.parameters(), lr=5e-5)
+ num_training_epochs = 2
+ 
+ for epoch in range(num_training_epochs):
+     for batch in train_dataloader:
+         optim.zero_grad()
+         input_ids = batch['input_ids'].to(device)
+         attention_mask = batch['attention_mask'].to(device)
+         labels = batch['decision'].to(device)
+         outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
+         loss = outputs[0]
+         loss.backward()
+         optim.step()
+         print("batch finished")
+ 
+ model.eval()
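
The notebook's Trainer run logs only training loss, and this script's manual loop likewise never reports a metric. One way to get validation numbers out of the Trainer setup is a compute_metrics callback; this is a sketch of the standard pattern, not something the commit contains:

# Hypothetical addition (not in the commit): wire validation accuracy into
# the notebook's Trainer via compute_metrics.
import numpy as np

def compute_metrics(eval_pred):
    logits, labels = eval_pred          # Trainer passes raw logits and label ids
    preds = np.argmax(logits, axis=-1)  # predicted class per example
    return {'accuracy': float((preds == labels).mean())}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=val_set,
    compute_metrics=compute_metrics,
)
trainer.evaluate()  # returns a dict including 'eval_accuracy'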