Spaces:
Runtime error
Runtime error
benliang99
commited on
Commit
•
f058f94
1
Parent(s):
25dae2e
Added documentation, website details to readme.
Browse files- .ipynb_checkpoints/finetunehupd-checkpoint.ipynb +1207 -11
- README.md +2 -0
- app.py +1 -0
- finetunehupd.ipynb +8 -104
- finetunehupd.py +0 -92
.ipynb_checkpoints/finetunehupd-checkpoint.ipynb
CHANGED
@@ -2,29 +2,1225 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"id": "1df3c609-62a6-49c3-bcc6-29c520f9501c",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
10 |
-
"# -*- coding: utf-8 -*-\n",
|
11 |
-
"\"\"\"FinetuneHUPD.ipynb\n",
|
12 |
-
"\n",
|
13 |
-
"Automatically generated by Colaboratory.\n",
|
14 |
-
"\n",
|
15 |
-
"Original file is located at\n",
|
16 |
-
" https://colab.research.google.com/drive/17c2CQZx_kyD3-0fuQqv_pCMJ0Evd7fLN\n",
|
17 |
-
"\"\"\"\n",
|
18 |
-
"\n",
|
19 |
"# Pretty print\n",
|
20 |
"from pprint import pprint\n",
|
21 |
"# Datasets load_dataset function\n",
|
22 |
"from datasets import load_dataset\n",
|
23 |
"# Transformers Autokenizer\n",
|
24 |
-
"from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments, AdamW\n",
|
25 |
"from torch.utils.data import DataLoader\n",
|
26 |
"import torch"
|
27 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
}
|
29 |
],
|
30 |
"metadata": {
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
"id": "1df3c609-62a6-49c3-bcc6-29c520f9501c",
|
7 |
"metadata": {},
|
8 |
"outputs": [],
|
9 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
"# Pretty print\n",
|
11 |
"from pprint import pprint\n",
|
12 |
"# Datasets load_dataset function\n",
|
13 |
"from datasets import load_dataset\n",
|
14 |
"# Transformers Autokenizer\n",
|
15 |
+
"from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertTokenizerFast, Trainer, TrainingArguments, AdamW\n",
|
16 |
"from torch.utils.data import DataLoader\n",
|
17 |
"import torch"
|
18 |
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "code",
|
22 |
+
"execution_count": 3,
|
23 |
+
"id": "58167c28-eb27-4f82-b7d0-8216dbeaf650",
|
24 |
+
"metadata": {},
|
25 |
+
"outputs": [
|
26 |
+
{
|
27 |
+
"name": "stderr",
|
28 |
+
"output_type": "stream",
|
29 |
+
"text": [
|
30 |
+
"Found cached dataset hupd (C:/Users/calia/.cache/huggingface/datasets/HUPD___hupd/sample-5094df4de61ed3bc/0.0.0/6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142)\n"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"data": {
|
35 |
+
"application/vnd.jupyter.widget-view+json": {
|
36 |
+
"model_id": "a2f090474cb148548ce3eb73698fcc6c",
|
37 |
+
"version_major": 2,
|
38 |
+
"version_minor": 0
|
39 |
+
},
|
40 |
+
"text/plain": [
|
41 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
"metadata": {},
|
45 |
+
"output_type": "display_data"
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"name": "stdout",
|
49 |
+
"output_type": "stream",
|
50 |
+
"text": [
|
51 |
+
"Loading is done!\n"
|
52 |
+
]
|
53 |
+
}
|
54 |
+
],
|
55 |
+
"source": [
|
56 |
+
"# Use the tokenizer for DistilBert since we are using the pretrained base uncased model \n",
|
57 |
+
"tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n",
|
58 |
+
"\n",
|
59 |
+
"# Load Harvard USPTO Patents Dataset from January 2-16 dataset using the load_dataset function\n",
|
60 |
+
"# Split the data into a train and validation set. For this purpose, I have split the patents as follows:\n",
|
61 |
+
"# train: Jan 1 - Jan 22, val: Jan 22 - Jan 31.\n",
|
62 |
+
"dataset_dict = load_dataset('HUPD/hupd',\n",
|
63 |
+
" name='sample',\n",
|
64 |
+
" data_files=\"https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather\", \n",
|
65 |
+
" icpr_label=None,\n",
|
66 |
+
" train_filing_start_date='2016-01-01',\n",
|
67 |
+
" train_filing_end_date='2016-01-21',\n",
|
68 |
+
" val_filing_start_date='2016-01-22',\n",
|
69 |
+
" val_filing_end_date='2016-01-31',\n",
|
70 |
+
")\n",
|
71 |
+
"\n",
|
72 |
+
"print('Loading is done!')"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": 3,
|
78 |
+
"id": "e13c6ad1-a7f2-4806-80a2-e9c4655e1eed",
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [
|
81 |
+
{
|
82 |
+
"name": "stderr",
|
83 |
+
"output_type": "stream",
|
84 |
+
"text": [
|
85 |
+
"Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-9f7788eb9924fd62.arrow\n",
|
86 |
+
"Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-6c3687322fe5b556.arrow\n",
|
87 |
+
"Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-bd3b1eee4495f3ce.arrow\n"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"data": {
|
92 |
+
"application/vnd.jupyter.widget-view+json": {
|
93 |
+
"model_id": "",
|
94 |
+
"version_major": 2,
|
95 |
+
"version_minor": 0
|
96 |
+
},
|
97 |
+
"text/plain": [
|
98 |
+
"Map: 0%| | 0/9094 [00:00<?, ? examples/s]"
|
99 |
+
]
|
100 |
+
},
|
101 |
+
"metadata": {},
|
102 |
+
"output_type": "display_data"
|
103 |
+
}
|
104 |
+
],
|
105 |
+
"source": [
|
106 |
+
"# Label-to-index mapping for the decision status field\n",
|
107 |
+
"# Convert to binary, as we only care about whether or not a patent was rejected or accepted.\n",
|
108 |
+
"decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 0, 'CONT-REJECTED': 0, 'CONT-ACCEPTED': 0, 'CONT-PENDING': 0}\n",
|
109 |
+
"\n",
|
110 |
+
"# Helper function\n",
|
111 |
+
"def map_decision_to_string(example):\n",
|
112 |
+
" return {'decision': decision_to_str[example['decision']]}\n",
|
113 |
+
"\n",
|
114 |
+
"# Re-labeling/mapping.\n",
|
115 |
+
"train_set = dataset_dict['train'].map(map_decision_to_string)\n",
|
116 |
+
"val_set = dataset_dict['validation'].map(map_decision_to_string)\n",
|
117 |
+
"\n",
|
118 |
+
"# Focus on the abstract section and tokenize the text using the tokenizer. \n",
|
119 |
+
"_SECTION_ = 'abstract'\n",
|
120 |
+
"\n",
|
121 |
+
"# Training set\n",
|
122 |
+
"train_set = train_set.map(\n",
|
123 |
+
" lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),\n",
|
124 |
+
" batched=True)\n",
|
125 |
+
"\n",
|
126 |
+
"# Validation set\n",
|
127 |
+
"val_set = val_set.map(\n",
|
128 |
+
" lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),\n",
|
129 |
+
" batched=True)"
|
130 |
+
]
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "code",
|
134 |
+
"execution_count": 4,
|
135 |
+
"id": "b5c098be-019b-42ce-9b80-4f6de93ef6a3",
|
136 |
+
"metadata": {},
|
137 |
+
"outputs": [
|
138 |
+
{
|
139 |
+
"data": {
|
140 |
+
"text/plain": [
|
141 |
+
"Dataset({\n",
|
142 |
+
" features: ['patent_number', 'decision', 'title', 'abstract', 'claims', 'background', 'summary', 'description', 'cpc_label', 'ipc_label', 'filing_date', 'patent_issue_date', 'date_published', 'examiner_id', 'input_ids', 'attention_mask'],\n",
|
143 |
+
" num_rows: 16153\n",
|
144 |
+
"})"
|
145 |
+
]
|
146 |
+
},
|
147 |
+
"execution_count": 4,
|
148 |
+
"metadata": {},
|
149 |
+
"output_type": "execute_result"
|
150 |
+
}
|
151 |
+
],
|
152 |
+
"source": [
|
153 |
+
"train_set"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": 5,
|
159 |
+
"id": "1e5a5390-19fe-4a73-b913-e3c1e2c2a399",
|
160 |
+
"metadata": {},
|
161 |
+
"outputs": [
|
162 |
+
{
|
163 |
+
"data": {
|
164 |
+
"text/plain": [
|
165 |
+
"Dataset({\n",
|
166 |
+
" features: ['patent_number', 'decision', 'title', 'abstract', 'claims', 'background', 'summary', 'description', 'cpc_label', 'ipc_label', 'filing_date', 'patent_issue_date', 'date_published', 'examiner_id', 'input_ids', 'attention_mask'],\n",
|
167 |
+
" num_rows: 9094\n",
|
168 |
+
"})"
|
169 |
+
]
|
170 |
+
},
|
171 |
+
"execution_count": 5,
|
172 |
+
"metadata": {},
|
173 |
+
"output_type": "execute_result"
|
174 |
+
}
|
175 |
+
],
|
176 |
+
"source": [
|
177 |
+
"val_set"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "code",
|
182 |
+
"execution_count": 6,
|
183 |
+
"id": "4fb69db8-86e5-4c6c-8ac6-853d3e15fb93",
|
184 |
+
"metadata": {},
|
185 |
+
"outputs": [],
|
186 |
+
"source": [
|
187 |
+
"# Process the train and validation sets so that only labels, input ids, and attention masks are left.\n",
|
188 |
+
"train_set = train_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
|
189 |
+
"val_set = val_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
|
190 |
+
"\n",
|
191 |
+
"train_set = train_set.rename_column(\"decision\", \"labels\")\n",
|
192 |
+
"val_set = val_set.rename_column(\"decision\", \"labels\")"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "code",
|
197 |
+
"execution_count": 7,
|
198 |
+
"id": "c0d17213-4b14-418c-980c-0238236096c2",
|
199 |
+
"metadata": {},
|
200 |
+
"outputs": [
|
201 |
+
{
|
202 |
+
"data": {
|
203 |
+
"text/plain": [
|
204 |
+
"Dataset({\n",
|
205 |
+
" features: ['labels', 'input_ids', 'attention_mask'],\n",
|
206 |
+
" num_rows: 16153\n",
|
207 |
+
"})"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
"execution_count": 7,
|
211 |
+
"metadata": {},
|
212 |
+
"output_type": "execute_result"
|
213 |
+
}
|
214 |
+
],
|
215 |
+
"source": [
|
216 |
+
"train_set"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": 8,
|
222 |
+
"id": "da2f1c16-3ba4-4e56-9455-5cd838df4dcd",
|
223 |
+
"metadata": {},
|
224 |
+
"outputs": [
|
225 |
+
{
|
226 |
+
"data": {
|
227 |
+
"text/plain": [
|
228 |
+
"Dataset({\n",
|
229 |
+
" features: ['labels', 'input_ids', 'attention_mask'],\n",
|
230 |
+
" num_rows: 9094\n",
|
231 |
+
"})"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
"execution_count": 8,
|
235 |
+
"metadata": {},
|
236 |
+
"output_type": "execute_result"
|
237 |
+
}
|
238 |
+
],
|
239 |
+
"source": [
|
240 |
+
"val_set"
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": 9,
|
246 |
+
"id": "cfb35702-863d-4fec-83e1-44c4e5668156",
|
247 |
+
"metadata": {},
|
248 |
+
"outputs": [],
|
249 |
+
"source": [
|
250 |
+
"# Set the format to 'torch'\n",
|
251 |
+
"train_set.set_format(type='torch', \n",
|
252 |
+
" columns=['labels', 'input_ids', 'attention_mask'])\n",
|
253 |
+
"\n",
|
254 |
+
"val_set.set_format(type='torch', \n",
|
255 |
+
" columns=['labels', 'input_ids', 'attention_mask'])"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"cell_type": "code",
|
260 |
+
"execution_count": 11,
|
261 |
+
"id": "b3248182-fddb-46dc-addb-26981a881d99",
|
262 |
+
"metadata": {},
|
263 |
+
"outputs": [
|
264 |
+
{
|
265 |
+
"name": "stderr",
|
266 |
+
"output_type": "stream",
|
267 |
+
"text": [
|
268 |
+
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n",
|
269 |
+
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
270 |
+
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
271 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias', 'classifier.weight']\n",
|
272 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
273 |
+
]
|
274 |
+
},
|
275 |
+
{
|
276 |
+
"name": "stdout",
|
277 |
+
"output_type": "stream",
|
278 |
+
"text": [
|
279 |
+
"cuda\n",
|
280 |
+
"torch cuda is avail: \n",
|
281 |
+
"True\n"
|
282 |
+
]
|
283 |
+
}
|
284 |
+
],
|
285 |
+
"source": [
|
286 |
+
"# Use GPU acceleration. Make sure you have the right cuda and pytorch versions!\n",
|
287 |
+
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
288 |
+
"model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')\n",
|
289 |
+
"model.to(device)\n",
|
290 |
+
"print(device)\n",
|
291 |
+
"print(\"torch cuda is avail: \")\n",
|
292 |
+
"print(torch.cuda.is_available())"
|
293 |
+
]
|
294 |
+
},
|
295 |
+
{
|
296 |
+
"cell_type": "markdown",
|
297 |
+
"id": "abb2cf74-3cd5-4ca5-af0e-b0ee80627f2a",
|
298 |
+
"metadata": {},
|
299 |
+
"source": [
|
300 |
+
"HuggingFace Trainer"
|
301 |
+
]
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"execution_count": 12,
|
306 |
+
"id": "99947cf9-a6cd-490f-a81d-32f65fb3cd46",
|
307 |
+
"metadata": {},
|
308 |
+
"outputs": [],
|
309 |
+
"source": [
|
310 |
+
"training_args = TrainingArguments(\n",
|
311 |
+
" output_dir='./results/',\n",
|
312 |
+
" num_train_epochs=2,\n",
|
313 |
+
" per_device_train_batch_size=16,\n",
|
314 |
+
" per_device_eval_batch_size=16,\n",
|
315 |
+
" warmup_steps=500,\n",
|
316 |
+
" learning_rate=5e-5,\n",
|
317 |
+
" weight_decay=0.01,\n",
|
318 |
+
" logging_dir='./logs/',\n",
|
319 |
+
" logging_steps=10,\n",
|
320 |
+
")\n",
|
321 |
+
"\n",
|
322 |
+
"trainer = Trainer(\n",
|
323 |
+
" model=model,\n",
|
324 |
+
" args=training_args,\n",
|
325 |
+
" train_dataset=train_set,\n",
|
326 |
+
" eval_dataset=val_set,\n",
|
327 |
+
")"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": 13,
|
333 |
+
"id": "be865f1d-f29b-4306-8570-900386ac4570",
|
334 |
+
"metadata": {},
|
335 |
+
"outputs": [
|
336 |
+
{
|
337 |
+
"name": "stderr",
|
338 |
+
"output_type": "stream",
|
339 |
+
"text": [
|
340 |
+
"C:\\Users\\calia\\anaconda3\\envs\\ai-finetuning-project\\lib\\site-packages\\transformers\\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
341 |
+
" warnings.warn(\n",
|
342 |
+
"***** Running training *****\n",
|
343 |
+
" Num examples = 16153\n",
|
344 |
+
" Num Epochs = 2\n",
|
345 |
+
" Instantaneous batch size per device = 16\n",
|
346 |
+
" Total train batch size (w. parallel, distributed & accumulation) = 16\n",
|
347 |
+
" Gradient Accumulation steps = 1\n",
|
348 |
+
" Total optimization steps = 2020\n",
|
349 |
+
" Number of trainable parameters = 66955010\n"
|
350 |
+
]
|
351 |
+
},
|
352 |
+
{
|
353 |
+
"data": {
|
354 |
+
"text/html": [
|
355 |
+
"\n",
|
356 |
+
" <div>\n",
|
357 |
+
" \n",
|
358 |
+
" <progress value='2020' max='2020' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
359 |
+
" [2020/2020 11:47, Epoch 2/2]\n",
|
360 |
+
" </div>\n",
|
361 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
362 |
+
" <thead>\n",
|
363 |
+
" <tr style=\"text-align: left;\">\n",
|
364 |
+
" <th>Step</th>\n",
|
365 |
+
" <th>Training Loss</th>\n",
|
366 |
+
" </tr>\n",
|
367 |
+
" </thead>\n",
|
368 |
+
" <tbody>\n",
|
369 |
+
" <tr>\n",
|
370 |
+
" <td>10</td>\n",
|
371 |
+
" <td>0.692000</td>\n",
|
372 |
+
" </tr>\n",
|
373 |
+
" <tr>\n",
|
374 |
+
" <td>20</td>\n",
|
375 |
+
" <td>0.685100</td>\n",
|
376 |
+
" </tr>\n",
|
377 |
+
" <tr>\n",
|
378 |
+
" <td>30</td>\n",
|
379 |
+
" <td>0.684000</td>\n",
|
380 |
+
" </tr>\n",
|
381 |
+
" <tr>\n",
|
382 |
+
" <td>40</td>\n",
|
383 |
+
" <td>0.685100</td>\n",
|
384 |
+
" </tr>\n",
|
385 |
+
" <tr>\n",
|
386 |
+
" <td>50</td>\n",
|
387 |
+
" <td>0.678400</td>\n",
|
388 |
+
" </tr>\n",
|
389 |
+
" <tr>\n",
|
390 |
+
" <td>60</td>\n",
|
391 |
+
" <td>0.687300</td>\n",
|
392 |
+
" </tr>\n",
|
393 |
+
" <tr>\n",
|
394 |
+
" <td>70</td>\n",
|
395 |
+
" <td>0.681900</td>\n",
|
396 |
+
" </tr>\n",
|
397 |
+
" <tr>\n",
|
398 |
+
" <td>80</td>\n",
|
399 |
+
" <td>0.691100</td>\n",
|
400 |
+
" </tr>\n",
|
401 |
+
" <tr>\n",
|
402 |
+
" <td>90</td>\n",
|
403 |
+
" <td>0.683200</td>\n",
|
404 |
+
" </tr>\n",
|
405 |
+
" <tr>\n",
|
406 |
+
" <td>100</td>\n",
|
407 |
+
" <td>0.694100</td>\n",
|
408 |
+
" </tr>\n",
|
409 |
+
" <tr>\n",
|
410 |
+
" <td>110</td>\n",
|
411 |
+
" <td>0.673300</td>\n",
|
412 |
+
" </tr>\n",
|
413 |
+
" <tr>\n",
|
414 |
+
" <td>120</td>\n",
|
415 |
+
" <td>0.694100</td>\n",
|
416 |
+
" </tr>\n",
|
417 |
+
" <tr>\n",
|
418 |
+
" <td>130</td>\n",
|
419 |
+
" <td>0.669500</td>\n",
|
420 |
+
" </tr>\n",
|
421 |
+
" <tr>\n",
|
422 |
+
" <td>140</td>\n",
|
423 |
+
" <td>0.691100</td>\n",
|
424 |
+
" </tr>\n",
|
425 |
+
" <tr>\n",
|
426 |
+
" <td>150</td>\n",
|
427 |
+
" <td>0.683400</td>\n",
|
428 |
+
" </tr>\n",
|
429 |
+
" <tr>\n",
|
430 |
+
" <td>160</td>\n",
|
431 |
+
" <td>0.654900</td>\n",
|
432 |
+
" </tr>\n",
|
433 |
+
" <tr>\n",
|
434 |
+
" <td>170</td>\n",
|
435 |
+
" <td>0.684300</td>\n",
|
436 |
+
" </tr>\n",
|
437 |
+
" <tr>\n",
|
438 |
+
" <td>180</td>\n",
|
439 |
+
" <td>0.679300</td>\n",
|
440 |
+
" </tr>\n",
|
441 |
+
" <tr>\n",
|
442 |
+
" <td>190</td>\n",
|
443 |
+
" <td>0.662600</td>\n",
|
444 |
+
" </tr>\n",
|
445 |
+
" <tr>\n",
|
446 |
+
" <td>200</td>\n",
|
447 |
+
" <td>0.598400</td>\n",
|
448 |
+
" </tr>\n",
|
449 |
+
" <tr>\n",
|
450 |
+
" <td>210</td>\n",
|
451 |
+
" <td>0.717700</td>\n",
|
452 |
+
" </tr>\n",
|
453 |
+
" <tr>\n",
|
454 |
+
" <td>220</td>\n",
|
455 |
+
" <td>0.679100</td>\n",
|
456 |
+
" </tr>\n",
|
457 |
+
" <tr>\n",
|
458 |
+
" <td>230</td>\n",
|
459 |
+
" <td>0.677500</td>\n",
|
460 |
+
" </tr>\n",
|
461 |
+
" <tr>\n",
|
462 |
+
" <td>240</td>\n",
|
463 |
+
" <td>0.668800</td>\n",
|
464 |
+
" </tr>\n",
|
465 |
+
" <tr>\n",
|
466 |
+
" <td>250</td>\n",
|
467 |
+
" <td>0.678100</td>\n",
|
468 |
+
" </tr>\n",
|
469 |
+
" <tr>\n",
|
470 |
+
" <td>260</td>\n",
|
471 |
+
" <td>0.657500</td>\n",
|
472 |
+
" </tr>\n",
|
473 |
+
" <tr>\n",
|
474 |
+
" <td>270</td>\n",
|
475 |
+
" <td>0.707200</td>\n",
|
476 |
+
" </tr>\n",
|
477 |
+
" <tr>\n",
|
478 |
+
" <td>280</td>\n",
|
479 |
+
" <td>0.670300</td>\n",
|
480 |
+
" </tr>\n",
|
481 |
+
" <tr>\n",
|
482 |
+
" <td>290</td>\n",
|
483 |
+
" <td>0.659900</td>\n",
|
484 |
+
" </tr>\n",
|
485 |
+
" <tr>\n",
|
486 |
+
" <td>300</td>\n",
|
487 |
+
" <td>0.633300</td>\n",
|
488 |
+
" </tr>\n",
|
489 |
+
" <tr>\n",
|
490 |
+
" <td>310</td>\n",
|
491 |
+
" <td>0.676300</td>\n",
|
492 |
+
" </tr>\n",
|
493 |
+
" <tr>\n",
|
494 |
+
" <td>320</td>\n",
|
495 |
+
" <td>0.684800</td>\n",
|
496 |
+
" </tr>\n",
|
497 |
+
" <tr>\n",
|
498 |
+
" <td>330</td>\n",
|
499 |
+
" <td>0.673100</td>\n",
|
500 |
+
" </tr>\n",
|
501 |
+
" <tr>\n",
|
502 |
+
" <td>340</td>\n",
|
503 |
+
" <td>0.670500</td>\n",
|
504 |
+
" </tr>\n",
|
505 |
+
" <tr>\n",
|
506 |
+
" <td>350</td>\n",
|
507 |
+
" <td>0.657500</td>\n",
|
508 |
+
" </tr>\n",
|
509 |
+
" <tr>\n",
|
510 |
+
" <td>360</td>\n",
|
511 |
+
" <td>0.618100</td>\n",
|
512 |
+
" </tr>\n",
|
513 |
+
" <tr>\n",
|
514 |
+
" <td>370</td>\n",
|
515 |
+
" <td>0.670000</td>\n",
|
516 |
+
" </tr>\n",
|
517 |
+
" <tr>\n",
|
518 |
+
" <td>380</td>\n",
|
519 |
+
" <td>0.607400</td>\n",
|
520 |
+
" </tr>\n",
|
521 |
+
" <tr>\n",
|
522 |
+
" <td>390</td>\n",
|
523 |
+
" <td>0.656200</td>\n",
|
524 |
+
" </tr>\n",
|
525 |
+
" <tr>\n",
|
526 |
+
" <td>400</td>\n",
|
527 |
+
" <td>0.700000</td>\n",
|
528 |
+
" </tr>\n",
|
529 |
+
" <tr>\n",
|
530 |
+
" <td>410</td>\n",
|
531 |
+
" <td>0.644800</td>\n",
|
532 |
+
" </tr>\n",
|
533 |
+
" <tr>\n",
|
534 |
+
" <td>420</td>\n",
|
535 |
+
" <td>0.682800</td>\n",
|
536 |
+
" </tr>\n",
|
537 |
+
" <tr>\n",
|
538 |
+
" <td>430</td>\n",
|
539 |
+
" <td>0.668800</td>\n",
|
540 |
+
" </tr>\n",
|
541 |
+
" <tr>\n",
|
542 |
+
" <td>440</td>\n",
|
543 |
+
" <td>0.662600</td>\n",
|
544 |
+
" </tr>\n",
|
545 |
+
" <tr>\n",
|
546 |
+
" <td>450</td>\n",
|
547 |
+
" <td>0.647700</td>\n",
|
548 |
+
" </tr>\n",
|
549 |
+
" <tr>\n",
|
550 |
+
" <td>460</td>\n",
|
551 |
+
" <td>0.688600</td>\n",
|
552 |
+
" </tr>\n",
|
553 |
+
" <tr>\n",
|
554 |
+
" <td>470</td>\n",
|
555 |
+
" <td>0.682400</td>\n",
|
556 |
+
" </tr>\n",
|
557 |
+
" <tr>\n",
|
558 |
+
" <td>480</td>\n",
|
559 |
+
" <td>0.642900</td>\n",
|
560 |
+
" </tr>\n",
|
561 |
+
" <tr>\n",
|
562 |
+
" <td>490</td>\n",
|
563 |
+
" <td>0.726900</td>\n",
|
564 |
+
" </tr>\n",
|
565 |
+
" <tr>\n",
|
566 |
+
" <td>500</td>\n",
|
567 |
+
" <td>0.660400</td>\n",
|
568 |
+
" </tr>\n",
|
569 |
+
" <tr>\n",
|
570 |
+
" <td>510</td>\n",
|
571 |
+
" <td>0.649500</td>\n",
|
572 |
+
" </tr>\n",
|
573 |
+
" <tr>\n",
|
574 |
+
" <td>520</td>\n",
|
575 |
+
" <td>0.637200</td>\n",
|
576 |
+
" </tr>\n",
|
577 |
+
" <tr>\n",
|
578 |
+
" <td>530</td>\n",
|
579 |
+
" <td>0.669700</td>\n",
|
580 |
+
" </tr>\n",
|
581 |
+
" <tr>\n",
|
582 |
+
" <td>540</td>\n",
|
583 |
+
" <td>0.667100</td>\n",
|
584 |
+
" </tr>\n",
|
585 |
+
" <tr>\n",
|
586 |
+
" <td>550</td>\n",
|
587 |
+
" <td>0.617000</td>\n",
|
588 |
+
" </tr>\n",
|
589 |
+
" <tr>\n",
|
590 |
+
" <td>560</td>\n",
|
591 |
+
" <td>0.725300</td>\n",
|
592 |
+
" </tr>\n",
|
593 |
+
" <tr>\n",
|
594 |
+
" <td>570</td>\n",
|
595 |
+
" <td>0.656800</td>\n",
|
596 |
+
" </tr>\n",
|
597 |
+
" <tr>\n",
|
598 |
+
" <td>580</td>\n",
|
599 |
+
" <td>0.664600</td>\n",
|
600 |
+
" </tr>\n",
|
601 |
+
" <tr>\n",
|
602 |
+
" <td>590</td>\n",
|
603 |
+
" <td>0.702600</td>\n",
|
604 |
+
" </tr>\n",
|
605 |
+
" <tr>\n",
|
606 |
+
" <td>600</td>\n",
|
607 |
+
" <td>0.686300</td>\n",
|
608 |
+
" </tr>\n",
|
609 |
+
" <tr>\n",
|
610 |
+
" <td>610</td>\n",
|
611 |
+
" <td>0.668400</td>\n",
|
612 |
+
" </tr>\n",
|
613 |
+
" <tr>\n",
|
614 |
+
" <td>620</td>\n",
|
615 |
+
" <td>0.648200</td>\n",
|
616 |
+
" </tr>\n",
|
617 |
+
" <tr>\n",
|
618 |
+
" <td>630</td>\n",
|
619 |
+
" <td>0.628700</td>\n",
|
620 |
+
" </tr>\n",
|
621 |
+
" <tr>\n",
|
622 |
+
" <td>640</td>\n",
|
623 |
+
" <td>0.676700</td>\n",
|
624 |
+
" </tr>\n",
|
625 |
+
" <tr>\n",
|
626 |
+
" <td>650</td>\n",
|
627 |
+
" <td>0.652400</td>\n",
|
628 |
+
" </tr>\n",
|
629 |
+
" <tr>\n",
|
630 |
+
" <td>660</td>\n",
|
631 |
+
" <td>0.654300</td>\n",
|
632 |
+
" </tr>\n",
|
633 |
+
" <tr>\n",
|
634 |
+
" <td>670</td>\n",
|
635 |
+
" <td>0.640800</td>\n",
|
636 |
+
" </tr>\n",
|
637 |
+
" <tr>\n",
|
638 |
+
" <td>680</td>\n",
|
639 |
+
" <td>0.672000</td>\n",
|
640 |
+
" </tr>\n",
|
641 |
+
" <tr>\n",
|
642 |
+
" <td>690</td>\n",
|
643 |
+
" <td>0.636100</td>\n",
|
644 |
+
" </tr>\n",
|
645 |
+
" <tr>\n",
|
646 |
+
" <td>700</td>\n",
|
647 |
+
" <td>0.689100</td>\n",
|
648 |
+
" </tr>\n",
|
649 |
+
" <tr>\n",
|
650 |
+
" <td>710</td>\n",
|
651 |
+
" <td>0.691100</td>\n",
|
652 |
+
" </tr>\n",
|
653 |
+
" <tr>\n",
|
654 |
+
" <td>720</td>\n",
|
655 |
+
" <td>0.650300</td>\n",
|
656 |
+
" </tr>\n",
|
657 |
+
" <tr>\n",
|
658 |
+
" <td>730</td>\n",
|
659 |
+
" <td>0.655200</td>\n",
|
660 |
+
" </tr>\n",
|
661 |
+
" <tr>\n",
|
662 |
+
" <td>740</td>\n",
|
663 |
+
" <td>0.668400</td>\n",
|
664 |
+
" </tr>\n",
|
665 |
+
" <tr>\n",
|
666 |
+
" <td>750</td>\n",
|
667 |
+
" <td>0.659200</td>\n",
|
668 |
+
" </tr>\n",
|
669 |
+
" <tr>\n",
|
670 |
+
" <td>760</td>\n",
|
671 |
+
" <td>0.647800</td>\n",
|
672 |
+
" </tr>\n",
|
673 |
+
" <tr>\n",
|
674 |
+
" <td>770</td>\n",
|
675 |
+
" <td>0.662800</td>\n",
|
676 |
+
" </tr>\n",
|
677 |
+
" <tr>\n",
|
678 |
+
" <td>780</td>\n",
|
679 |
+
" <td>0.648500</td>\n",
|
680 |
+
" </tr>\n",
|
681 |
+
" <tr>\n",
|
682 |
+
" <td>790</td>\n",
|
683 |
+
" <td>0.656700</td>\n",
|
684 |
+
" </tr>\n",
|
685 |
+
" <tr>\n",
|
686 |
+
" <td>800</td>\n",
|
687 |
+
" <td>0.669400</td>\n",
|
688 |
+
" </tr>\n",
|
689 |
+
" <tr>\n",
|
690 |
+
" <td>810</td>\n",
|
691 |
+
" <td>0.607800</td>\n",
|
692 |
+
" </tr>\n",
|
693 |
+
" <tr>\n",
|
694 |
+
" <td>820</td>\n",
|
695 |
+
" <td>0.683200</td>\n",
|
696 |
+
" </tr>\n",
|
697 |
+
" <tr>\n",
|
698 |
+
" <td>830</td>\n",
|
699 |
+
" <td>0.663800</td>\n",
|
700 |
+
" </tr>\n",
|
701 |
+
" <tr>\n",
|
702 |
+
" <td>840</td>\n",
|
703 |
+
" <td>0.700900</td>\n",
|
704 |
+
" </tr>\n",
|
705 |
+
" <tr>\n",
|
706 |
+
" <td>850</td>\n",
|
707 |
+
" <td>0.648200</td>\n",
|
708 |
+
" </tr>\n",
|
709 |
+
" <tr>\n",
|
710 |
+
" <td>860</td>\n",
|
711 |
+
" <td>0.619400</td>\n",
|
712 |
+
" </tr>\n",
|
713 |
+
" <tr>\n",
|
714 |
+
" <td>870</td>\n",
|
715 |
+
" <td>0.649200</td>\n",
|
716 |
+
" </tr>\n",
|
717 |
+
" <tr>\n",
|
718 |
+
" <td>880</td>\n",
|
719 |
+
" <td>0.717500</td>\n",
|
720 |
+
" </tr>\n",
|
721 |
+
" <tr>\n",
|
722 |
+
" <td>890</td>\n",
|
723 |
+
" <td>0.669600</td>\n",
|
724 |
+
" </tr>\n",
|
725 |
+
" <tr>\n",
|
726 |
+
" <td>900</td>\n",
|
727 |
+
" <td>0.669700</td>\n",
|
728 |
+
" </tr>\n",
|
729 |
+
" <tr>\n",
|
730 |
+
" <td>910</td>\n",
|
731 |
+
" <td>0.683900</td>\n",
|
732 |
+
" </tr>\n",
|
733 |
+
" <tr>\n",
|
734 |
+
" <td>920</td>\n",
|
735 |
+
" <td>0.636900</td>\n",
|
736 |
+
" </tr>\n",
|
737 |
+
" <tr>\n",
|
738 |
+
" <td>930</td>\n",
|
739 |
+
" <td>0.656400</td>\n",
|
740 |
+
" </tr>\n",
|
741 |
+
" <tr>\n",
|
742 |
+
" <td>940</td>\n",
|
743 |
+
" <td>0.650000</td>\n",
|
744 |
+
" </tr>\n",
|
745 |
+
" <tr>\n",
|
746 |
+
" <td>950</td>\n",
|
747 |
+
" <td>0.617800</td>\n",
|
748 |
+
" </tr>\n",
|
749 |
+
" <tr>\n",
|
750 |
+
" <td>960</td>\n",
|
751 |
+
" <td>0.665600</td>\n",
|
752 |
+
" </tr>\n",
|
753 |
+
" <tr>\n",
|
754 |
+
" <td>970</td>\n",
|
755 |
+
" <td>0.642700</td>\n",
|
756 |
+
" </tr>\n",
|
757 |
+
" <tr>\n",
|
758 |
+
" <td>980</td>\n",
|
759 |
+
" <td>0.644000</td>\n",
|
760 |
+
" </tr>\n",
|
761 |
+
" <tr>\n",
|
762 |
+
" <td>990</td>\n",
|
763 |
+
" <td>0.688900</td>\n",
|
764 |
+
" </tr>\n",
|
765 |
+
" <tr>\n",
|
766 |
+
" <td>1000</td>\n",
|
767 |
+
" <td>0.654700</td>\n",
|
768 |
+
" </tr>\n",
|
769 |
+
" <tr>\n",
|
770 |
+
" <td>1010</td>\n",
|
771 |
+
" <td>0.645800</td>\n",
|
772 |
+
" </tr>\n",
|
773 |
+
" <tr>\n",
|
774 |
+
" <td>1020</td>\n",
|
775 |
+
" <td>0.609200</td>\n",
|
776 |
+
" </tr>\n",
|
777 |
+
" <tr>\n",
|
778 |
+
" <td>1030</td>\n",
|
779 |
+
" <td>0.602300</td>\n",
|
780 |
+
" </tr>\n",
|
781 |
+
" <tr>\n",
|
782 |
+
" <td>1040</td>\n",
|
783 |
+
" <td>0.618800</td>\n",
|
784 |
+
" </tr>\n",
|
785 |
+
" <tr>\n",
|
786 |
+
" <td>1050</td>\n",
|
787 |
+
" <td>0.643500</td>\n",
|
788 |
+
" </tr>\n",
|
789 |
+
" <tr>\n",
|
790 |
+
" <td>1060</td>\n",
|
791 |
+
" <td>0.611000</td>\n",
|
792 |
+
" </tr>\n",
|
793 |
+
" <tr>\n",
|
794 |
+
" <td>1070</td>\n",
|
795 |
+
" <td>0.645000</td>\n",
|
796 |
+
" </tr>\n",
|
797 |
+
" <tr>\n",
|
798 |
+
" <td>1080</td>\n",
|
799 |
+
" <td>0.641000</td>\n",
|
800 |
+
" </tr>\n",
|
801 |
+
" <tr>\n",
|
802 |
+
" <td>1090</td>\n",
|
803 |
+
" <td>0.595400</td>\n",
|
804 |
+
" </tr>\n",
|
805 |
+
" <tr>\n",
|
806 |
+
" <td>1100</td>\n",
|
807 |
+
" <td>0.635100</td>\n",
|
808 |
+
" </tr>\n",
|
809 |
+
" <tr>\n",
|
810 |
+
" <td>1110</td>\n",
|
811 |
+
" <td>0.611600</td>\n",
|
812 |
+
" </tr>\n",
|
813 |
+
" <tr>\n",
|
814 |
+
" <td>1120</td>\n",
|
815 |
+
" <td>0.600300</td>\n",
|
816 |
+
" </tr>\n",
|
817 |
+
" <tr>\n",
|
818 |
+
" <td>1130</td>\n",
|
819 |
+
" <td>0.618100</td>\n",
|
820 |
+
" </tr>\n",
|
821 |
+
" <tr>\n",
|
822 |
+
" <td>1140</td>\n",
|
823 |
+
" <td>0.617200</td>\n",
|
824 |
+
" </tr>\n",
|
825 |
+
" <tr>\n",
|
826 |
+
" <td>1150</td>\n",
|
827 |
+
" <td>0.633400</td>\n",
|
828 |
+
" </tr>\n",
|
829 |
+
" <tr>\n",
|
830 |
+
" <td>1160</td>\n",
|
831 |
+
" <td>0.597600</td>\n",
|
832 |
+
" </tr>\n",
|
833 |
+
" <tr>\n",
|
834 |
+
" <td>1170</td>\n",
|
835 |
+
" <td>0.619400</td>\n",
|
836 |
+
" </tr>\n",
|
837 |
+
" <tr>\n",
|
838 |
+
" <td>1180</td>\n",
|
839 |
+
" <td>0.584200</td>\n",
|
840 |
+
" </tr>\n",
|
841 |
+
" <tr>\n",
|
842 |
+
" <td>1190</td>\n",
|
843 |
+
" <td>0.600700</td>\n",
|
844 |
+
" </tr>\n",
|
845 |
+
" <tr>\n",
|
846 |
+
" <td>1200</td>\n",
|
847 |
+
" <td>0.657400</td>\n",
|
848 |
+
" </tr>\n",
|
849 |
+
" <tr>\n",
|
850 |
+
" <td>1210</td>\n",
|
851 |
+
" <td>0.569600</td>\n",
|
852 |
+
" </tr>\n",
|
853 |
+
" <tr>\n",
|
854 |
+
" <td>1220</td>\n",
|
855 |
+
" <td>0.575500</td>\n",
|
856 |
+
" </tr>\n",
|
857 |
+
" <tr>\n",
|
858 |
+
" <td>1230</td>\n",
|
859 |
+
" <td>0.617900</td>\n",
|
860 |
+
" </tr>\n",
|
861 |
+
" <tr>\n",
|
862 |
+
" <td>1240</td>\n",
|
863 |
+
" <td>0.610300</td>\n",
|
864 |
+
" </tr>\n",
|
865 |
+
" <tr>\n",
|
866 |
+
" <td>1250</td>\n",
|
867 |
+
" <td>0.570600</td>\n",
|
868 |
+
" </tr>\n",
|
869 |
+
" <tr>\n",
|
870 |
+
" <td>1260</td>\n",
|
871 |
+
" <td>0.545700</td>\n",
|
872 |
+
" </tr>\n",
|
873 |
+
" <tr>\n",
|
874 |
+
" <td>1270</td>\n",
|
875 |
+
" <td>0.656300</td>\n",
|
876 |
+
" </tr>\n",
|
877 |
+
" <tr>\n",
|
878 |
+
" <td>1280</td>\n",
|
879 |
+
" <td>0.554700</td>\n",
|
880 |
+
" </tr>\n",
|
881 |
+
" <tr>\n",
|
882 |
+
" <td>1290</td>\n",
|
883 |
+
" <td>0.598200</td>\n",
|
884 |
+
" </tr>\n",
|
885 |
+
" <tr>\n",
|
886 |
+
" <td>1300</td>\n",
|
887 |
+
" <td>0.606300</td>\n",
|
888 |
+
" </tr>\n",
|
889 |
+
" <tr>\n",
|
890 |
+
" <td>1310</td>\n",
|
891 |
+
" <td>0.600500</td>\n",
|
892 |
+
" </tr>\n",
|
893 |
+
" <tr>\n",
|
894 |
+
" <td>1320</td>\n",
|
895 |
+
" <td>0.569800</td>\n",
|
896 |
+
" </tr>\n",
|
897 |
+
" <tr>\n",
|
898 |
+
" <td>1330</td>\n",
|
899 |
+
" <td>0.604700</td>\n",
|
900 |
+
" </tr>\n",
|
901 |
+
" <tr>\n",
|
902 |
+
" <td>1340</td>\n",
|
903 |
+
" <td>0.628300</td>\n",
|
904 |
+
" </tr>\n",
|
905 |
+
" <tr>\n",
|
906 |
+
" <td>1350</td>\n",
|
907 |
+
" <td>0.602700</td>\n",
|
908 |
+
" </tr>\n",
|
909 |
+
" <tr>\n",
|
910 |
+
" <td>1360</td>\n",
|
911 |
+
" <td>0.583700</td>\n",
|
912 |
+
" </tr>\n",
|
913 |
+
" <tr>\n",
|
914 |
+
" <td>1370</td>\n",
|
915 |
+
" <td>0.623800</td>\n",
|
916 |
+
" </tr>\n",
|
917 |
+
" <tr>\n",
|
918 |
+
" <td>1380</td>\n",
|
919 |
+
" <td>0.670300</td>\n",
|
920 |
+
" </tr>\n",
|
921 |
+
" <tr>\n",
|
922 |
+
" <td>1390</td>\n",
|
923 |
+
" <td>0.622400</td>\n",
|
924 |
+
" </tr>\n",
|
925 |
+
" <tr>\n",
|
926 |
+
" <td>1400</td>\n",
|
927 |
+
" <td>0.590200</td>\n",
|
928 |
+
" </tr>\n",
|
929 |
+
" <tr>\n",
|
930 |
+
" <td>1410</td>\n",
|
931 |
+
" <td>0.587000</td>\n",
|
932 |
+
" </tr>\n",
|
933 |
+
" <tr>\n",
|
934 |
+
" <td>1420</td>\n",
|
935 |
+
" <td>0.555500</td>\n",
|
936 |
+
" </tr>\n",
|
937 |
+
" <tr>\n",
|
938 |
+
" <td>1430</td>\n",
|
939 |
+
" <td>0.561000</td>\n",
|
940 |
+
" </tr>\n",
|
941 |
+
" <tr>\n",
|
942 |
+
" <td>1440</td>\n",
|
943 |
+
" <td>0.514300</td>\n",
|
944 |
+
" </tr>\n",
|
945 |
+
" <tr>\n",
|
946 |
+
" <td>1450</td>\n",
|
947 |
+
" <td>0.553100</td>\n",
|
948 |
+
" </tr>\n",
|
949 |
+
" <tr>\n",
|
950 |
+
" <td>1460</td>\n",
|
951 |
+
" <td>0.692400</td>\n",
|
952 |
+
" </tr>\n",
|
953 |
+
" <tr>\n",
|
954 |
+
" <td>1470</td>\n",
|
955 |
+
" <td>0.605200</td>\n",
|
956 |
+
" </tr>\n",
|
957 |
+
" <tr>\n",
|
958 |
+
" <td>1480</td>\n",
|
959 |
+
" <td>0.548000</td>\n",
|
960 |
+
" </tr>\n",
|
961 |
+
" <tr>\n",
|
962 |
+
" <td>1490</td>\n",
|
963 |
+
" <td>0.672600</td>\n",
|
964 |
+
" </tr>\n",
|
965 |
+
" <tr>\n",
|
966 |
+
" <td>1500</td>\n",
|
967 |
+
" <td>0.531100</td>\n",
|
968 |
+
" </tr>\n",
|
969 |
+
" <tr>\n",
|
970 |
+
" <td>1510</td>\n",
|
971 |
+
" <td>0.610600</td>\n",
|
972 |
+
" </tr>\n",
|
973 |
+
" <tr>\n",
|
974 |
+
" <td>1520</td>\n",
|
975 |
+
" <td>0.580200</td>\n",
|
976 |
+
" </tr>\n",
|
977 |
+
" <tr>\n",
|
978 |
+
" <td>1530</td>\n",
|
979 |
+
" <td>0.571300</td>\n",
|
980 |
+
" </tr>\n",
|
981 |
+
" <tr>\n",
|
982 |
+
" <td>1540</td>\n",
|
983 |
+
" <td>0.644400</td>\n",
|
984 |
+
" </tr>\n",
|
985 |
+
" <tr>\n",
|
986 |
+
" <td>1550</td>\n",
|
987 |
+
" <td>0.558500</td>\n",
|
988 |
+
" </tr>\n",
|
989 |
+
" <tr>\n",
|
990 |
+
" <td>1560</td>\n",
|
991 |
+
" <td>0.624000</td>\n",
|
992 |
+
" </tr>\n",
|
993 |
+
" <tr>\n",
|
994 |
+
" <td>1570</td>\n",
|
995 |
+
" <td>0.659200</td>\n",
|
996 |
+
" </tr>\n",
|
997 |
+
" <tr>\n",
|
998 |
+
" <td>1580</td>\n",
|
999 |
+
" <td>0.580500</td>\n",
|
1000 |
+
" </tr>\n",
|
1001 |
+
" <tr>\n",
|
1002 |
+
" <td>1590</td>\n",
|
1003 |
+
" <td>0.649900</td>\n",
|
1004 |
+
" </tr>\n",
|
1005 |
+
" <tr>\n",
|
1006 |
+
" <td>1600</td>\n",
|
1007 |
+
" <td>0.608700</td>\n",
|
1008 |
+
" </tr>\n",
|
1009 |
+
" <tr>\n",
|
1010 |
+
" <td>1610</td>\n",
|
1011 |
+
" <td>0.595100</td>\n",
|
1012 |
+
" </tr>\n",
|
1013 |
+
" <tr>\n",
|
1014 |
+
" <td>1620</td>\n",
|
1015 |
+
" <td>0.592900</td>\n",
|
1016 |
+
" </tr>\n",
|
1017 |
+
" <tr>\n",
|
1018 |
+
" <td>1630</td>\n",
|
1019 |
+
" <td>0.584000</td>\n",
|
1020 |
+
" </tr>\n",
|
1021 |
+
" <tr>\n",
|
1022 |
+
" <td>1640</td>\n",
|
1023 |
+
" <td>0.607100</td>\n",
|
1024 |
+
" </tr>\n",
|
1025 |
+
" <tr>\n",
|
1026 |
+
" <td>1650</td>\n",
|
1027 |
+
" <td>0.565800</td>\n",
|
1028 |
+
" </tr>\n",
|
1029 |
+
" <tr>\n",
|
1030 |
+
" <td>1660</td>\n",
|
1031 |
+
" <td>0.568300</td>\n",
|
1032 |
+
" </tr>\n",
|
1033 |
+
" <tr>\n",
|
1034 |
+
" <td>1670</td>\n",
|
1035 |
+
" <td>0.572200</td>\n",
|
1036 |
+
" </tr>\n",
|
1037 |
+
" <tr>\n",
|
1038 |
+
" <td>1680</td>\n",
|
1039 |
+
" <td>0.597500</td>\n",
|
1040 |
+
" </tr>\n",
|
1041 |
+
" <tr>\n",
|
1042 |
+
" <td>1690</td>\n",
|
1043 |
+
" <td>0.602700</td>\n",
|
1044 |
+
" </tr>\n",
|
1045 |
+
" <tr>\n",
|
1046 |
+
" <td>1700</td>\n",
|
1047 |
+
" <td>0.692900</td>\n",
|
1048 |
+
" </tr>\n",
|
1049 |
+
" <tr>\n",
|
1050 |
+
" <td>1710</td>\n",
|
1051 |
+
" <td>0.597900</td>\n",
|
1052 |
+
" </tr>\n",
|
1053 |
+
" <tr>\n",
|
1054 |
+
" <td>1720</td>\n",
|
1055 |
+
" <td>0.538600</td>\n",
|
1056 |
+
" </tr>\n",
|
1057 |
+
" <tr>\n",
|
1058 |
+
" <td>1730</td>\n",
|
1059 |
+
" <td>0.599400</td>\n",
|
1060 |
+
" </tr>\n",
|
1061 |
+
" <tr>\n",
|
1062 |
+
" <td>1740</td>\n",
|
1063 |
+
" <td>0.704300</td>\n",
|
1064 |
+
" </tr>\n",
|
1065 |
+
" <tr>\n",
|
1066 |
+
" <td>1750</td>\n",
|
1067 |
+
" <td>0.580500</td>\n",
|
1068 |
+
" </tr>\n",
|
1069 |
+
" <tr>\n",
|
1070 |
+
" <td>1760</td>\n",
|
1071 |
+
" <td>0.595600</td>\n",
|
1072 |
+
" </tr>\n",
|
1073 |
+
" <tr>\n",
|
1074 |
+
" <td>1770</td>\n",
|
1075 |
+
" <td>0.583100</td>\n",
|
1076 |
+
" </tr>\n",
|
1077 |
+
" <tr>\n",
|
1078 |
+
" <td>1780</td>\n",
|
1079 |
+
" <td>0.569500</td>\n",
|
1080 |
+
" </tr>\n",
|
1081 |
+
" <tr>\n",
|
1082 |
+
" <td>1790</td>\n",
|
1083 |
+
" <td>0.603300</td>\n",
|
1084 |
+
" </tr>\n",
|
1085 |
+
" <tr>\n",
|
1086 |
+
" <td>1800</td>\n",
|
1087 |
+
" <td>0.564500</td>\n",
|
1088 |
+
" </tr>\n",
|
1089 |
+
" <tr>\n",
|
1090 |
+
" <td>1810</td>\n",
|
1091 |
+
" <td>0.592100</td>\n",
|
1092 |
+
" </tr>\n",
|
1093 |
+
" <tr>\n",
|
1094 |
+
" <td>1820</td>\n",
|
1095 |
+
" <td>0.617000</td>\n",
|
1096 |
+
" </tr>\n",
|
1097 |
+
" <tr>\n",
|
1098 |
+
" <td>1830</td>\n",
|
1099 |
+
" <td>0.656500</td>\n",
|
1100 |
+
" </tr>\n",
|
1101 |
+
" <tr>\n",
|
1102 |
+
" <td>1840</td>\n",
|
1103 |
+
" <td>0.563600</td>\n",
|
1104 |
+
" </tr>\n",
|
1105 |
+
" <tr>\n",
|
1106 |
+
" <td>1850</td>\n",
|
1107 |
+
" <td>0.624800</td>\n",
|
1108 |
+
" </tr>\n",
|
1109 |
+
" <tr>\n",
|
1110 |
+
" <td>1860</td>\n",
|
1111 |
+
" <td>0.686700</td>\n",
|
1112 |
+
" </tr>\n",
|
1113 |
+
" <tr>\n",
|
1114 |
+
" <td>1870</td>\n",
|
1115 |
+
" <td>0.572300</td>\n",
|
1116 |
+
" </tr>\n",
|
1117 |
+
" <tr>\n",
|
1118 |
+
" <td>1880</td>\n",
|
1119 |
+
" <td>0.587700</td>\n",
|
1120 |
+
" </tr>\n",
|
1121 |
+
" <tr>\n",
|
1122 |
+
" <td>1890</td>\n",
|
1123 |
+
" <td>0.583000</td>\n",
|
1124 |
+
" </tr>\n",
|
1125 |
+
" <tr>\n",
|
1126 |
+
" <td>1900</td>\n",
|
1127 |
+
" <td>0.601500</td>\n",
|
1128 |
+
" </tr>\n",
|
1129 |
+
" <tr>\n",
|
1130 |
+
" <td>1910</td>\n",
|
1131 |
+
" <td>0.559700</td>\n",
|
1132 |
+
" </tr>\n",
|
1133 |
+
" <tr>\n",
|
1134 |
+
" <td>1920</td>\n",
|
1135 |
+
" <td>0.610100</td>\n",
|
1136 |
+
" </tr>\n",
|
1137 |
+
" <tr>\n",
|
1138 |
+
" <td>1930</td>\n",
|
1139 |
+
" <td>0.571300</td>\n",
|
1140 |
+
" </tr>\n",
|
1141 |
+
" <tr>\n",
|
1142 |
+
" <td>1940</td>\n",
|
1143 |
+
" <td>0.549900</td>\n",
|
1144 |
+
" </tr>\n",
|
1145 |
+
" <tr>\n",
|
1146 |
+
" <td>1950</td>\n",
|
1147 |
+
" <td>0.589200</td>\n",
|
1148 |
+
" </tr>\n",
|
1149 |
+
" <tr>\n",
|
1150 |
+
" <td>1960</td>\n",
|
1151 |
+
" <td>0.634800</td>\n",
|
1152 |
+
" </tr>\n",
|
1153 |
+
" <tr>\n",
|
1154 |
+
" <td>1970</td>\n",
|
1155 |
+
" <td>0.584200</td>\n",
|
1156 |
+
" </tr>\n",
|
1157 |
+
" <tr>\n",
|
1158 |
+
" <td>1980</td>\n",
|
1159 |
+
" <td>0.557000</td>\n",
|
1160 |
+
" </tr>\n",
|
1161 |
+
" <tr>\n",
|
1162 |
+
" <td>1990</td>\n",
|
1163 |
+
" <td>0.602700</td>\n",
|
1164 |
+
" </tr>\n",
|
1165 |
+
" <tr>\n",
|
1166 |
+
" <td>2000</td>\n",
|
1167 |
+
" <td>0.669700</td>\n",
|
1168 |
+
" </tr>\n",
|
1169 |
+
" <tr>\n",
|
1170 |
+
" <td>2010</td>\n",
|
1171 |
+
" <td>0.607500</td>\n",
|
1172 |
+
" </tr>\n",
|
1173 |
+
" <tr>\n",
|
1174 |
+
" <td>2020</td>\n",
|
1175 |
+
" <td>0.631800</td>\n",
|
1176 |
+
" </tr>\n",
|
1177 |
+
" </tbody>\n",
|
1178 |
+
"</table><p>"
|
1179 |
+
],
|
1180 |
+
"text/plain": [
|
1181 |
+
"<IPython.core.display.HTML object>"
|
1182 |
+
]
|
1183 |
+
},
|
1184 |
+
"metadata": {},
|
1185 |
+
"output_type": "display_data"
|
1186 |
+
},
|
1187 |
+
{
|
1188 |
+
"name": "stderr",
|
1189 |
+
"output_type": "stream",
|
1190 |
+
"text": [
|
1191 |
+
"Saving model checkpoint to ./results/checkpoint-500\n",
|
1192 |
+
"Configuration saved in ./results/checkpoint-500\\config.json\n",
|
1193 |
+
"Model weights saved in ./results/checkpoint-500\\pytorch_model.bin\n",
|
1194 |
+
"Saving model checkpoint to ./results/checkpoint-1000\n",
|
1195 |
+
"Configuration saved in ./results/checkpoint-1000\\config.json\n",
|
1196 |
+
"Model weights saved in ./results/checkpoint-1000\\pytorch_model.bin\n",
|
1197 |
+
"Saving model checkpoint to ./results/checkpoint-1500\n",
|
1198 |
+
"Configuration saved in ./results/checkpoint-1500\\config.json\n",
|
1199 |
+
"Model weights saved in ./results/checkpoint-1500\\pytorch_model.bin\n",
|
1200 |
+
"Saving model checkpoint to ./results/checkpoint-2000\n",
|
1201 |
+
"Configuration saved in ./results/checkpoint-2000\\config.json\n",
|
1202 |
+
"Model weights saved in ./results/checkpoint-2000\\pytorch_model.bin\n",
|
1203 |
+
"\n",
|
1204 |
+
"\n",
|
1205 |
+
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
|
1206 |
+
"\n",
|
1207 |
+
"\n"
|
1208 |
+
]
|
1209 |
+
},
|
1210 |
+
{
|
1211 |
+
"data": {
|
1212 |
+
"text/plain": [
|
1213 |
+
"TrainOutput(global_step=2020, training_loss=0.6342116433795136, metrics={'train_runtime': 708.5025, 'train_samples_per_second': 45.598, 'train_steps_per_second': 2.851, 'total_flos': 4279491780980736.0, 'train_loss': 0.6342116433795136, 'epoch': 2.0})"
|
1214 |
+
]
|
1215 |
+
},
|
1216 |
+
"execution_count": 13,
|
1217 |
+
"metadata": {},
|
1218 |
+
"output_type": "execute_result"
|
1219 |
+
}
|
1220 |
+
],
|
1221 |
+
"source": [
|
1222 |
+
"trainer.train()"
|
1223 |
+
]
|
1224 |
}
|
1225 |
],
|
1226 |
"metadata": {
|
README.md
CHANGED
@@ -9,6 +9,8 @@ app_file: app.py
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
|
|
|
|
12 |
Link to huggingface app: https://huggingface.co/spaces/saccharinedreams/sentiment-analysis-app
|
13 |
|
14 |
|
|
|
9 |
pinned: false
|
10 |
---
|
11 |
|
12 |
+
Link to Google Site: https://sites.google.com/nyu.edu/finetunedpatentanalysis/app
|
13 |
+
Link to huggingface model repo: https://huggingface.co/saccharinedreams/finetuned-distilbert-base-uncased-for-hupd
|
14 |
Link to huggingface app: https://huggingface.co/spaces/saccharinedreams/sentiment-analysis-app
|
15 |
|
16 |
|
app.py
CHANGED
@@ -59,6 +59,7 @@ def sentiment_analysis(model, tokenizer):
|
|
59 |
# Title the Streamlit app 'Finetuned Harvard USPTO Patent Dataset (using DistilBert-Base-Uncased)'
|
60 |
st.title('Finetuned Sentiment Analysis for US Patents')
|
61 |
st.markdown('Link to the app - [sentiment-analysis-app](https://huggingface.co/spaces/saccharinedreams/sentiment-analysis-app)')
|
|
|
62 |
st.markdown('This model was finetuned on the Harvard USPTO Patent Dataset and uses Distilbert-Base-Uncased.')
|
63 |
|
64 |
abstracts = load_abstracts()
|
|
|
59 |
# Title the Streamlit app 'Finetuned Harvard USPTO Patent Dataset (using DistilBert-Base-Uncased)'
|
60 |
st.title('Finetuned Sentiment Analysis for US Patents')
|
61 |
st.markdown('Link to the app - [sentiment-analysis-app](https://huggingface.co/spaces/saccharinedreams/sentiment-analysis-app)')
|
62 |
+
st.markdown('Link to the model - [model repo](https://huggingface.co/saccharinedreams/finetuned-distilbert-base-uncased-for-hupd')
|
63 |
st.markdown('This model was finetuned on the Harvard USPTO Patent Dataset and uses Distilbert-Base-Uncased.')
|
64 |
|
65 |
abstracts = load_abstracts()
|
finetunehupd.ipynb
CHANGED
@@ -53,8 +53,12 @@
|
|
53 |
}
|
54 |
],
|
55 |
"source": [
|
|
|
56 |
"tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n",
|
57 |
"\n",
|
|
|
|
|
|
|
58 |
"dataset_dict = load_dataset('HUPD/hupd',\n",
|
59 |
" name='sample',\n",
|
60 |
" data_files=\"https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather\", \n",
|
@@ -100,6 +104,7 @@
|
|
100 |
],
|
101 |
"source": [
|
102 |
"# Label-to-index mapping for the decision status field\n",
|
|
|
103 |
"decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 0, 'CONT-REJECTED': 0, 'CONT-ACCEPTED': 0, 'CONT-PENDING': 0}\n",
|
104 |
"\n",
|
105 |
"# Helper function\n",
|
@@ -179,6 +184,7 @@
|
|
179 |
"metadata": {},
|
180 |
"outputs": [],
|
181 |
"source": [
|
|
|
182 |
"train_set = train_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
|
183 |
"val_set = val_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
|
184 |
"\n",
|
@@ -241,7 +247,7 @@
|
|
241 |
"metadata": {},
|
242 |
"outputs": [],
|
243 |
"source": [
|
244 |
-
"# Set the format\n",
|
245 |
"train_set.set_format(type='torch', \n",
|
246 |
" columns=['labels', 'input_ids', 'attention_mask'])\n",
|
247 |
"\n",
|
@@ -249,18 +255,6 @@
|
|
249 |
" columns=['labels', 'input_ids', 'attention_mask'])"
|
250 |
]
|
251 |
},
|
252 |
-
{
|
253 |
-
"cell_type": "code",
|
254 |
-
"execution_count": 10,
|
255 |
-
"id": "d7ac796a-9f6e-4213-960f-e17837c27d87",
|
256 |
-
"metadata": {},
|
257 |
-
"outputs": [],
|
258 |
-
"source": [
|
259 |
-
"# train_dataloader and val_data_loader\n",
|
260 |
-
"train_dataloader = DataLoader(train_set, batch_size=16)\n",
|
261 |
-
"val_dataloader = DataLoader(val_set, batch_size=16)"
|
262 |
-
]
|
263 |
-
},
|
264 |
{
|
265 |
"cell_type": "code",
|
266 |
"execution_count": 11,
|
@@ -289,6 +283,7 @@
|
|
289 |
}
|
290 |
],
|
291 |
"source": [
|
|
|
292 |
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
293 |
"model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')\n",
|
294 |
"model.to(device)\n",
|
@@ -1226,97 +1221,6 @@
|
|
1226 |
"source": [
|
1227 |
"trainer.train()"
|
1228 |
]
|
1229 |
-
},
|
1230 |
-
{
|
1231 |
-
"cell_type": "markdown",
|
1232 |
-
"id": "304e0d65-74cf-4945-978d-b9f56c5a83b1",
|
1233 |
-
"metadata": {},
|
1234 |
-
"source": [
|
1235 |
-
"PyTorch Training Loop"
|
1236 |
-
]
|
1237 |
-
},
|
1238 |
-
{
|
1239 |
-
"cell_type": "code",
|
1240 |
-
"execution_count": null,
|
1241 |
-
"id": "e56d14fb-dfde-40fa-9dfa-1187c2e09866",
|
1242 |
-
"metadata": {},
|
1243 |
-
"outputs": [],
|
1244 |
-
"source": [
|
1245 |
-
"# model.train()\n",
|
1246 |
-
"# optim = AdamW(model.parameters(), lr=5e-5)\n",
|
1247 |
-
"# num_training_epochs = 2\n",
|
1248 |
-
"\n",
|
1249 |
-
"# for epoch in range(num_training_epochs):\n",
|
1250 |
-
"# print(\"starting epoch: \" + str(epoch))\n",
|
1251 |
-
"# for batch in train_dataloader:\n",
|
1252 |
-
"# optim.zero_grad()\n",
|
1253 |
-
"# input_ids = batch['input_ids'].to(device)\n",
|
1254 |
-
"# attention_mask = batch['attention_mask'].to(device)\n",
|
1255 |
-
"# labels = batch['labels'].to(device)\n",
|
1256 |
-
"# outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
|
1257 |
-
"# loss = outputs[0]\n",
|
1258 |
-
"# loss.backward()\n",
|
1259 |
-
"# optim.step()\n",
|
1260 |
-
"# model.eval()"
|
1261 |
-
]
|
1262 |
-
},
|
1263 |
-
{
|
1264 |
-
"cell_type": "code",
|
1265 |
-
"execution_count": 5,
|
1266 |
-
"id": "9b496593-c0de-4ce2-95d5-d5d3bf09d93c",
|
1267 |
-
"metadata": {},
|
1268 |
-
"outputs": [
|
1269 |
-
{
|
1270 |
-
"data": {
|
1271 |
-
"text/plain": [
|
1272 |
-
"'The present invention relates to passive optical network (PON), and in particular, to an optical network terminal (ONT) in the PON system. In one embodiment, the optical network terminal includes a first interface coupled to a communications network, a second interface coupled to a network client and a processor including a memory coupled to the first interface and to the second interface, wherein the processor is capable of converting optical signals to electric signals, such that the network client can access the communications network.'"
|
1273 |
-
]
|
1274 |
-
},
|
1275 |
-
"execution_count": 5,
|
1276 |
-
"metadata": {},
|
1277 |
-
"output_type": "execute_result"
|
1278 |
-
}
|
1279 |
-
],
|
1280 |
-
"source": [
|
1281 |
-
"dataset_dict['train']['abstract'][0]"
|
1282 |
-
]
|
1283 |
-
},
|
1284 |
-
{
|
1285 |
-
"cell_type": "code",
|
1286 |
-
"execution_count": null,
|
1287 |
-
"id": "cd6f20fc-9874-465d-9781-505415db3ffd",
|
1288 |
-
"metadata": {},
|
1289 |
-
"outputs": [],
|
1290 |
-
"source": []
|
1291 |
-
},
|
1292 |
-
{
|
1293 |
-
"cell_type": "code",
|
1294 |
-
"execution_count": 6,
|
1295 |
-
"id": "6b6ad778-15aa-492a-9484-40106269e10d",
|
1296 |
-
"metadata": {},
|
1297 |
-
"outputs": [
|
1298 |
-
{
|
1299 |
-
"data": {
|
1300 |
-
"text/plain": [
|
1301 |
-
"'Embodiments of the invention provide a method of reading and verifying a tag based on inherent disorder during a manufacturing process. The method includes using a first reader to take a first reading of an inherent disorder feature of the tag, and using a second reader to take a second reading of the inherent disorder feature of the tag. The method further includes matching the first reading with the second reading, and determining one or more acceptance criteria, wherein at least one of the acceptance criteria is based on whether the first reading and the second reading match within a predetermined threshold. If the acceptance criteria are met, then the tag is accepted, and a fingerprint for the tag is recorded. The invention further provides a method of testing and characterizing a reader of inherent disorder tags during a manufacturing process. The method includes taking a reading of a known inherent disorder tag, using the reading to measure a characteristic of the reader, and storing the measured characteristic for use when reading inherent disorder tags.'"
|
1302 |
-
]
|
1303 |
-
},
|
1304 |
-
"execution_count": 6,
|
1305 |
-
"metadata": {},
|
1306 |
-
"output_type": "execute_result"
|
1307 |
-
}
|
1308 |
-
],
|
1309 |
-
"source": [
|
1310 |
-
"dataset_dict['train']['abstract'][1]"
|
1311 |
-
]
|
1312 |
-
},
|
1313 |
-
{
|
1314 |
-
"cell_type": "code",
|
1315 |
-
"execution_count": null,
|
1316 |
-
"id": "cf9d94a7-4d7e-4f6b-88f8-46c9855289f4",
|
1317 |
-
"metadata": {},
|
1318 |
-
"outputs": [],
|
1319 |
-
"source": []
|
1320 |
}
|
1321 |
],
|
1322 |
"metadata": {
|
|
|
53 |
}
|
54 |
],
|
55 |
"source": [
|
56 |
+
"# Use the tokenizer for DistilBert since we are using the pretrained base uncased model \n",
|
57 |
"tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n",
|
58 |
"\n",
|
59 |
+
"# Load Harvard USPTO Patents Dataset from January 2-16 dataset using the load_dataset function\n",
|
60 |
+
"# Split the data into a train and validation set. For this purpose, I have split the patents as follows:\n",
|
61 |
+
"# train: Jan 1 - Jan 22, val: Jan 22 - Jan 31.\n",
|
62 |
"dataset_dict = load_dataset('HUPD/hupd',\n",
|
63 |
" name='sample',\n",
|
64 |
" data_files=\"https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather\", \n",
|
|
|
104 |
],
|
105 |
"source": [
|
106 |
"# Label-to-index mapping for the decision status field\n",
|
107 |
+
"# Convert to binary, as we only care about whether or not a patent was rejected or accepted.\n",
|
108 |
"decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 0, 'CONT-REJECTED': 0, 'CONT-ACCEPTED': 0, 'CONT-PENDING': 0}\n",
|
109 |
"\n",
|
110 |
"# Helper function\n",
|
|
|
184 |
"metadata": {},
|
185 |
"outputs": [],
|
186 |
"source": [
|
187 |
+
"# Process the train and validation sets so that only labels, input ids, and attention masks are left.\n",
|
188 |
"train_set = train_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
|
189 |
"val_set = val_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n",
|
190 |
"\n",
|
|
|
247 |
"metadata": {},
|
248 |
"outputs": [],
|
249 |
"source": [
|
250 |
+
"# Set the format to 'torch'\n",
|
251 |
"train_set.set_format(type='torch', \n",
|
252 |
" columns=['labels', 'input_ids', 'attention_mask'])\n",
|
253 |
"\n",
|
|
|
255 |
" columns=['labels', 'input_ids', 'attention_mask'])"
|
256 |
]
|
257 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
{
|
259 |
"cell_type": "code",
|
260 |
"execution_count": 11,
|
|
|
283 |
}
|
284 |
],
|
285 |
"source": [
|
286 |
+
"# Use GPU acceleration. Make sure you have the right cuda and pytorch versions!\n",
|
287 |
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
288 |
"model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')\n",
|
289 |
"model.to(device)\n",
|
|
|
1221 |
"source": [
|
1222 |
"trainer.train()"
|
1223 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1224 |
}
|
1225 |
],
|
1226 |
"metadata": {
|
finetunehupd.py
DELETED
@@ -1,92 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
"""FinetuneHUPD.ipynb
|
3 |
-
|
4 |
-
Automatically generated by Colaboratory.
|
5 |
-
|
6 |
-
Original file is located at
|
7 |
-
https://colab.research.google.com/drive/17c2CQZx_kyD3-0fuQqv_pCMJ0Evd7fLN
|
8 |
-
"""
|
9 |
-
|
10 |
-
# Pretty print
|
11 |
-
from pprint import pprint
|
12 |
-
# Datasets load_dataset function
|
13 |
-
from datasets import load_dataset
|
14 |
-
# Transformers Autokenizer
|
15 |
-
from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments, AdamW
|
16 |
-
from torch.utils.data import DataLoader
|
17 |
-
import torch
|
18 |
-
|
19 |
-
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
|
20 |
-
|
21 |
-
dataset_dict = load_dataset('HUPD/hupd',
|
22 |
-
name='sample',
|
23 |
-
data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
|
24 |
-
icpr_label=None,
|
25 |
-
train_filing_start_date='2016-01-01',
|
26 |
-
train_filing_end_date='2016-01-31',
|
27 |
-
val_filing_start_date='2016-01-01',
|
28 |
-
val_filing_end_date='2016-01-31',
|
29 |
-
)
|
30 |
-
|
31 |
-
print('Loading is done!')
|
32 |
-
|
33 |
-
# Label-to-index mapping for the decision status field
|
34 |
-
decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 2, 'CONT-REJECTED': 3, 'CONT-ACCEPTED': 4, 'CONT-PENDING': 5}
|
35 |
-
|
36 |
-
# Helper function
|
37 |
-
def map_decision_to_string(example):
|
38 |
-
return {'decision': decision_to_str[example['decision']]}
|
39 |
-
|
40 |
-
# Re-labeling/mapping.
|
41 |
-
train_set = dataset_dict['train'].map(map_decision_to_string)
|
42 |
-
val_set = dataset_dict['validation'].map(map_decision_to_string)
|
43 |
-
|
44 |
-
# Focus on the abstract section and tokenize the text using the tokenizer.
|
45 |
-
_SECTION_ = 'abstract'
|
46 |
-
|
47 |
-
# Training set
|
48 |
-
train_set = train_set.map(
|
49 |
-
lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),
|
50 |
-
batched=True)
|
51 |
-
|
52 |
-
# Validation set
|
53 |
-
val_set = val_set.map(
|
54 |
-
lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),
|
55 |
-
batched=True)
|
56 |
-
|
57 |
-
# Set the format
|
58 |
-
train_set.set_format(type='torch',
|
59 |
-
columns=['input_ids', 'attention_mask', 'decision'])
|
60 |
-
|
61 |
-
val_set.set_format(type='torch',
|
62 |
-
columns=['input_ids', 'attention_mask', 'decision'])
|
63 |
-
|
64 |
-
#print(train_set['decision'])
|
65 |
-
|
66 |
-
# train_dataloader and val_data_loader
|
67 |
-
train_dataloader = DataLoader(train_set, batch_size=16)
|
68 |
-
val_dataloader = DataLoader(val_set, batch_size=16)
|
69 |
-
|
70 |
-
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
71 |
-
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
|
72 |
-
model.to(device)
|
73 |
-
print(device)
|
74 |
-
print("torch cuda is avail: ")
|
75 |
-
print(torch.cuda.is_available())
|
76 |
-
model.train()
|
77 |
-
optim = AdamW(model.parameters(), lr=5e-5)
|
78 |
-
num_training_epochs = 2
|
79 |
-
|
80 |
-
for epoch in range(num_training_epochs):
|
81 |
-
for batch in train_dataloader:
|
82 |
-
optim.zero_grad()
|
83 |
-
input_ids = batch['input_ids'].to(device)
|
84 |
-
attention_mask = batch['attention_mask'].to(device)
|
85 |
-
labels = batch['decision'].to(device)
|
86 |
-
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
|
87 |
-
loss = outputs[0]
|
88 |
-
loss.backward()
|
89 |
-
optim.step()
|
90 |
-
print("batch finished")
|
91 |
-
|
92 |
-
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|