{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "1df3c609-62a6-49c3-bcc6-29c520f9501c", "metadata": {}, "outputs": [], "source": [ "# Pretty print\n", "from pprint import pprint\n", "# Datasets load_dataset function\n", "from datasets import load_dataset\n", "# Transformers Autokenizer\n", "from transformers import AutoTokenizer, DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertTokenizerFast, Trainer, TrainingArguments, AdamW\n", "from torch.utils.data import DataLoader\n", "import torch" ] }, { "cell_type": "code", "execution_count": 3, "id": "58167c28-eb27-4f82-b7d0-8216dbeaf650", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset hupd (C:/Users/calia/.cache/huggingface/datasets/HUPD___hupd/sample-5094df4de61ed3bc/0.0.0/6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a2f090474cb148548ce3eb73698fcc6c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Loading is done!\n" ] } ], "source": [ "tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')\n", "\n", "dataset_dict = load_dataset('HUPD/hupd',\n", " name='sample',\n", " data_files=\"https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather\", \n", " icpr_label=None,\n", " train_filing_start_date='2016-01-01',\n", " train_filing_end_date='2016-01-21',\n", " val_filing_start_date='2016-01-22',\n", " val_filing_end_date='2016-01-31',\n", ")\n", "\n", "print('Loading is done!')" ] }, { "cell_type": "code", "execution_count": 3, "id": "e13c6ad1-a7f2-4806-80a2-e9c4655e1eed", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-9f7788eb9924fd62.arrow\n", "Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-6c3687322fe5b556.arrow\n", "Loading cached processed dataset at C:\\Users\\calia\\.cache\\huggingface\\datasets\\HUPD___hupd\\sample-5094df4de61ed3bc\\0.0.0\\6920d2def8fd7767046c0470603357f76866e5a09c97e19571896bfdca521142\\cache-bd3b1eee4495f3ce.arrow\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/9094 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Label-to-index mapping for the decision status field\n", "decision_to_str = {'REJECTED': 0, 'ACCEPTED': 1, 'PENDING': 0, 'CONT-REJECTED': 0, 'CONT-ACCEPTED': 0, 'CONT-PENDING': 0}\n", "\n", "# Helper function\n", "def map_decision_to_string(example):\n", " return {'decision': decision_to_str[example['decision']]}\n", "\n", "# Re-labeling/mapping.\n", "train_set = dataset_dict['train'].map(map_decision_to_string)\n", "val_set = dataset_dict['validation'].map(map_decision_to_string)\n", "\n", "# Focus on the abstract section and tokenize the text using the tokenizer. 
\n", "_SECTION_ = 'abstract'\n", "\n", "# Training set\n", "train_set = train_set.map(\n", " lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),\n", " batched=True)\n", "\n", "# Validation set\n", "val_set = val_set.map(\n", " lambda e: tokenizer((e[_SECTION_]), truncation=True, padding='max_length'),\n", " batched=True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "b5c098be-019b-42ce-9b80-4f6de93ef6a3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['patent_number', 'decision', 'title', 'abstract', 'claims', 'background', 'summary', 'description', 'cpc_label', 'ipc_label', 'filing_date', 'patent_issue_date', 'date_published', 'examiner_id', 'input_ids', 'attention_mask'],\n", " num_rows: 16153\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_set" ] }, { "cell_type": "code", "execution_count": 5, "id": "1e5a5390-19fe-4a73-b913-e3c1e2c2a399", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['patent_number', 'decision', 'title', 'abstract', 'claims', 'background', 'summary', 'description', 'cpc_label', 'ipc_label', 'filing_date', 'patent_issue_date', 'date_published', 'examiner_id', 'input_ids', 'attention_mask'],\n", " num_rows: 9094\n", "})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val_set" ] }, { "cell_type": "code", "execution_count": 6, "id": "4fb69db8-86e5-4c6c-8ac6-853d3e15fb93", "metadata": {}, "outputs": [], "source": [ "train_set = train_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n", "val_set = val_set.remove_columns([\"patent_number\", \"title\", \"abstract\", \"claims\", \"background\", \"summary\", \"description\", \"cpc_label\", \"ipc_label\", \"filing_date\", \"patent_issue_date\", \"date_published\", \"examiner_id\"])\n", "\n", "train_set = train_set.rename_column(\"decision\", \"labels\")\n", "val_set = val_set.rename_column(\"decision\", \"labels\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "c0d17213-4b14-418c-980c-0238236096c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['labels', 'input_ids', 'attention_mask'],\n", " num_rows: 16153\n", "})" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_set" ] }, { "cell_type": "code", "execution_count": 8, "id": "da2f1c16-3ba4-4e56-9455-5cd838df4dcd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['labels', 'input_ids', 'attention_mask'],\n", " num_rows: 9094\n", "})" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val_set" ] }, { "cell_type": "code", "execution_count": 9, "id": "cfb35702-863d-4fec-83e1-44c4e5668156", "metadata": {}, "outputs": [], "source": [ "# Set the format\n", "train_set.set_format(type='torch', \n", " columns=['labels', 'input_ids', 'attention_mask'])\n", "\n", "val_set.set_format(type='torch', \n", " columns=['labels', 'input_ids', 'attention_mask'])" ] }, { "cell_type": "code", "execution_count": 10, "id": "d7ac796a-9f6e-4213-960f-e17837c27d87", "metadata": {}, "outputs": [], "source": [ "# train_dataloader and val_data_loader\n", "train_dataloader = DataLoader(train_set, batch_size=16)\n", "val_dataloader = 
{ "cell_type": "code", "execution_count": 11, "id": "b3248182-fddb-46dc-addb-26981a881d99", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n", "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "cuda\n", "torch cuda is avail: \n", "True\n" ] } ], "source": [ "# Use the GPU if one is available\n", "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", "# DistilBERT with a freshly initialized two-class classification head (num_labels defaults to 2)\n", "model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')\n", "model.to(device)\n", "print(device)\n", "print(\"torch cuda is avail: \")\n", "print(torch.cuda.is_available())" ] },
{ "cell_type": "markdown", "id": "abb2cf74-3cd5-4ca5-af0e-b0ee80627f2a", "metadata": {}, "source": [ "HuggingFace Trainer" ] },
{ "cell_type": "code", "execution_count": 12, "id": "99947cf9-a6cd-490f-a81d-32f65fb3cd46", "metadata": {}, "outputs": [], "source": [ "training_args = TrainingArguments(\n", "    output_dir='./results/',\n", "    num_train_epochs=2,\n", "    per_device_train_batch_size=16,\n", "    per_device_eval_batch_size=16,\n", "    warmup_steps=500,\n", "    learning_rate=5e-5,\n", "    weight_decay=0.01,\n", "    logging_dir='./logs/',\n", "    logging_steps=10,\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_set,\n", "    eval_dataset=val_set,\n", ")" ] },
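{ "cell_type": "markdown", "id": "f1a2b3c4-0005-4abc-8def-000000000005", "metadata": {}, "source": [ "As configured, the Trainer logs only the training loss. If validation accuracy is wanted as well, a `compute_metrics` callback can be passed when constructing the Trainer. The cell below is an illustrative sketch added in editing (not used in the original run); it assumes the standard Transformers contract where the callback receives an `EvalPrediction` holding the model logits and label ids." ] },
{ "cell_type": "code", "execution_count": null, "id": "f1a2b3c4-0006-4abc-8def-000000000006", "metadata": {}, "outputs": [], "source": [ "# Sketch: accuracy metric for Trainer evaluation\n", "import numpy as np\n", "\n", "def compute_metrics(eval_pred):\n", "    logits, labels = eval_pred\n", "    preds = np.argmax(logits, axis=-1)\n", "    return {'accuracy': float((preds == labels).mean())}\n", "\n", "# Pass it when constructing the Trainer, e.g.:\n", "# trainer = Trainer(model=model, args=training_args,\n", "#                   train_dataset=train_set, eval_dataset=val_set,\n", "#                   compute_metrics=compute_metrics)" ] },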
{ "cell_type": "code", "execution_count": 13, "id": "be865f1d-f29b-4306-8570-900386ac4570", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\calia\\anaconda3\\envs\\ai-finetuning-project\\lib\\site-packages\\transformers\\optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 16153\n", " Num Epochs = 2\n", " Instantaneous batch size per device = 16\n", " Total train batch size (w. parallel, distributed & accumulation) = 16\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 2020\n", " Number of trainable parameters = 66955010\n" ] }, { "data": { "text/html": [ "\n",
"Step | Training Loss\n",
"---|---\n",
"10 | 0.692000\n",
"500 | 0.660400\n",
"1000 | 0.654700\n",
"1500 | 0.531100\n",
"2000 | 0.669700\n",
"2020 | 0.631800\n",
"(loss logged every 10 steps; the 196 intermediate rows are omitted here; values drift noisily from about 0.69 at the start down into the 0.55-0.67 range by the end of training)\n",
"
"
],
"text/plain": [
"