{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"dotnet_interactive": {
"language": "csharp"
},
"polyglot_notebook": {
"kernelName": "csharp"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"g:\\IDEs and Modules\\Anaconda\\envs\\pytorch_gpu\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Found cached dataset funsd-layoutlmv3 (C:/Users/csara/.cache/huggingface/datasets/nielsr___funsd-layoutlmv3/funsd/1.0.0/0e3f4efdfd59aa1c3b4952c517894f7b1fc4d75c12ef01bcc8626a69e41c1bb9)\n",
"100%|██████████| 2/2 [00:00<00:00, 290.08it/s]\n"
]
}
],
"source": [
"from datasets import load_dataset \n",
"dataset = load_dataset(\"nielsr/funsd-layoutlmv3\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:12:08.141445Z",
"iopub.status.busy": "2023-03-03T19:12:08.140739Z",
"iopub.status.idle": "2023-03-03T19:12:24.061622Z",
"shell.execute_reply": "2023-03-03T19:12:24.060499Z",
"shell.execute_reply.started": "2023-03-03T19:12:08.141404Z"
},
"trusted": true
},
"outputs": [],
"source": [
"from transformers import AutoProcessor\n",
"from datasets.features import ClassLabel\n",
"tokenizer = AutoProcessor.from_pretrained(\"microsoft/layoutlmv3-base\", apply_ocr=False)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:12:49.955307Z",
"iopub.status.busy": "2023-03-03T19:12:49.954390Z",
"iopub.status.idle": "2023-03-03T19:12:49.962064Z",
"shell.execute_reply": "2023-03-03T19:12:49.961038Z",
"shell.execute_reply.started": "2023-03-03T19:12:49.955256Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def get_label_list(labels):\n",
"    \"\"\"Return the sorted list of distinct label values found across all label sequences.\n",
"\n",
"    Args:\n",
"        labels: iterable of label sequences (one sequence per example).\n",
"\n",
"    Returns:\n",
"        A sorted list containing each distinct label exactly once.\n",
"    \"\"\"\n",
"    return sorted(set().union(*labels))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:12:59.478905Z",
"iopub.status.busy": "2023-03-03T19:12:59.478538Z",
"iopub.status.idle": "2023-03-03T19:12:59.487011Z",
"shell.execute_reply": "2023-03-03T19:12:59.485535Z",
"shell.execute_reply.started": "2023-03-03T19:12:59.478872Z"
},
"trusted": true
},
"outputs": [],
"source": [
"image_column_name = \"image\"\n",
"text_column_name = \"tokens\"\n",
"boxes_column_name = \"bboxes\"\n",
"label_column_name = \"ner_tags\"\n",
"\n",
"features = dataset[\"train\"].features\n",
"column_names = dataset[\"train\"].column_names\n",
"\n",
"# Always go through label_column_name so the label vocabulary is derived from the\n",
"# same column that encoder() reads its word labels from.\n",
"label_feature = features[label_column_name].feature\n",
"if isinstance(label_feature, ClassLabel):\n",
"    # ClassLabel carries the label names directly.\n",
"    label_list = label_feature.names\n",
"else:\n",
"    # No ClassLabel metadata: collect the distinct label values from the training split.\n",
"    label_list = get_label_list(dataset[\"train\"][label_column_name])\n",
"\n",
"# Both mappings are built once, from whichever label_list was selected above.\n",
"id2label = dict(enumerate(label_list))\n",
"label2id = {label: idx for idx, label in id2label.items()}\n",
"num_labels = len(label_list)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:13:09.900559Z",
"iopub.status.busy": "2023-03-03T19:13:09.900179Z",
"iopub.status.idle": "2023-03-03T19:13:09.907248Z",
"shell.execute_reply": "2023-03-03T19:13:09.906077Z",
"shell.execute_reply.started": "2023-03-03T19:13:09.900526Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def encoder(examples):\n",
"    \"\"\"Encode a batch of examples with the processor held in `tokenizer`.\n",
"\n",
"    Args:\n",
"        examples: mapping with \"image\", \"tokens\" and \"bboxes\" entries plus the\n",
"            label column named by `label_column_name`.\n",
"\n",
"    Returns:\n",
"        The processor output, truncated and padded to the maximum length.\n",
"    \"\"\"\n",
"    return tokenizer(\n",
"        examples[\"image\"],\n",
"        examples[\"tokens\"],\n",
"        boxes=examples[\"bboxes\"],\n",
"        word_labels=examples[label_column_name],\n",
"        truncation=True,\n",
"        padding=\"max_length\",\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:13:22.537584Z",
"iopub.status.busy": "2023-03-03T19:13:22.537226Z",
"iopub.status.idle": "2023-03-03T19:13:28.944323Z",
"shell.execute_reply": "2023-03-03T19:13:28.943322Z",
"shell.execute_reply.started": "2023-03-03T19:13:22.537552Z"
},
"trusted": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at C:\\Users\\csara\\.cache\\huggingface\\datasets\\nielsr___funsd-layoutlmv3\\funsd\\1.0.0\\0e3f4efdfd59aa1c3b4952c517894f7b1fc4d75c12ef01bcc8626a69e41c1bb9\\cache-1a2006e093366773.arrow\n",
"Loading cached processed dataset at C:\\Users\\csara\\.cache\\huggingface\\datasets\\nielsr___funsd-layoutlmv3\\funsd\\1.0.0\\0e3f4efdfd59aa1c3b4952c517894f7b1fc4d75c12ef01bcc8626a69e41c1bb9\\cache-b7746bd565a79622.arrow\n"
]
}
],
"source": [
"from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D\n",
"\n",
"\n",
"# Explicit schema for the columns produced by encoder().\n",
"features = Features({\n",
"    \"pixel_values\": Array3D(dtype=\"float32\", shape=(3, 224, 224)),\n",
"    \"input_ids\": Sequence(feature=Value(dtype=\"int64\")),\n",
"    \"attention_mask\": Sequence(feature=Value(dtype=\"int64\")),\n",
"    \"bbox\": Array2D(dtype=\"int64\", shape=(512, 4)),\n",
"    \"labels\": Sequence(feature=Value(dtype=\"int64\")),\n",
"})\n",
"\n",
"\n",
"def _encode_split(split_name):\n",
"    \"\"\"Apply encoder() to one dataset split, dropping the raw columns in favour of the encoded ones.\"\"\"\n",
"    return dataset[split_name].map(\n",
"        encoder,\n",
"        batched=True,\n",
"        remove_columns=column_names,\n",
"        features=features,\n",
"    )\n",
"\n",
"\n",
"train_dataset = _encode_split(\"train\")\n",
"eval_dataset = _encode_split(\"test\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:13:40.875858Z",
"iopub.status.busy": "2023-03-03T19:13:40.875401Z",
"iopub.status.idle": "2023-03-03T19:13:40.882514Z",
"shell.execute_reply": "2023-03-03T19:13:40.881313Z",
"shell.execute_reply.started": "2023-03-03T19:13:40.875816Z"
},
"trusted": true
},
"outputs": [],
"source": [
"# Have both splits return torch tensors; the Trainer consumes train_dataset and\n",
"# eval_dataset alike, so they should be delivered in the same format.\n",
"train_dataset.set_format(\"torch\")\n",
"eval_dataset.set_format(\"torch\")\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:13:47.654814Z",
"iopub.status.busy": "2023-03-03T19:13:47.654437Z",
"iopub.status.idle": "2023-03-03T19:13:47.691746Z",
"shell.execute_reply": "2023-03-03T19:13:47.690779Z",
"shell.execute_reply.started": "2023-03-03T19:13:47.654780Z"
},
"trusted": true
},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:16:13.499216Z",
"iopub.status.busy": "2023-03-03T19:16:13.498129Z",
"iopub.status.idle": "2023-03-03T19:16:14.186549Z",
"shell.execute_reply": "2023-03-03T19:16:14.185533Z",
"shell.execute_reply.started": "2023-03-03T19:16:13.499170Z"
},
"trusted": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\csara\\AppData\\Local\\Temp\\ipykernel_21636\\3673091550.py:2: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
" metric = load_metric(\"seqeval\")\n"
]
}
],
"source": [
"from datasets import load_metric\n",
"metric = load_metric(\"seqeval\")\n",
"\n",
"import numpy as np\n",
"\n",
"return_entity_level_metrics = False\n",
"\n",
"def compute_metrics(p):\n",
"    \"\"\"Score token-classification predictions with the seqeval metric.\n",
"\n",
"    Args:\n",
"        p: pair of (predictions, labels); predictions are reduced with argmax over axis 2.\n",
"\n",
"    Returns:\n",
"        A dict with overall precision, recall, f1 and accuracy, or every metric entry\n",
"        flattened to \"<key>_<subkey>\" when `return_entity_level_metrics` is true.\n",
"    \"\"\"\n",
"    logits, gold_ids = p\n",
"    pred_ids = np.argmax(logits, axis=2)\n",
"\n",
"    # Remove ignored index (special tokens) and map ids to label names, row by row.\n",
"    true_predictions, true_labels = [], []\n",
"    for pred_row, gold_row in zip(pred_ids, gold_ids):\n",
"        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(pred_row, gold_row) if gold_id != -100]\n",
"        true_predictions.append([label_list[pred_id] for pred_id, _ in kept])\n",
"        true_labels.append([label_list[gold_id] for _, gold_id in kept])\n",
"\n",
"    results = metric.compute(predictions=true_predictions, references=true_labels)\n",
"\n",
"    if not return_entity_level_metrics:\n",
"        return {\n",
"            \"precision\": results[\"overall_precision\"],\n",
"            \"recall\": results[\"overall_recall\"],\n",
"            \"f1\": results[\"overall_f1\"],\n",
"            \"accuracy\": results[\"overall_accuracy\"],\n",
"        }\n",
"\n",
"    # Unpack nested dictionaries into a single flat mapping.\n",
"    final_results = {}\n",
"    for key, value in results.items():\n",
"        if isinstance(value, dict):\n",
"            final_results.update({f\"{key}_{n}\": v for n, v in value.items()})\n",
"        else:\n",
"            final_results[key] = value\n",
"    return final_results\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:16:23.761664Z",
"iopub.status.busy": "2023-03-03T19:16:23.760963Z",
"iopub.status.idle": "2023-03-03T19:16:28.559914Z",
"shell.execute_reply": "2023-03-03T19:16:28.558886Z",
"shell.execute_reply.started": "2023-03-03T19:16:23.761620Z"
},
"trusted": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of LayoutLMv3ForTokenClassification were not initialized from the model checkpoint at microsoft/layoutlmv3-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import LayoutLMv3ForTokenClassification\n",
"\n",
"model = LayoutLMv3ForTokenClassification.from_pretrained(\"microsoft/layoutlmv3-base\",\n",
" id2label=id2label,\n",
" label2id=label2id)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:16:31.406847Z",
"iopub.status.busy": "2023-03-03T19:16:31.405450Z",
"iopub.status.idle": "2023-03-03T19:16:31.608864Z",
"shell.execute_reply": "2023-03-03T19:16:31.607848Z",
"shell.execute_reply.started": "2023-03-03T19:16:31.406796Z"
},
"trusted": true
},
"outputs": [],
"source": [
"from transformers import TrainingArguments, Trainer\n",
"\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"test\",\n",
"    max_steps=1000,\n",
"    per_device_train_batch_size=2,\n",
"    per_device_eval_batch_size=3,\n",
"    learning_rate=1e-5,\n",
"    # Evaluate every 100 steps ...\n",
"    evaluation_strategy=\"steps\",\n",
"    eval_steps=100,\n",
"    # ... and reload the checkpoint with the best F1 once training finishes.\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model=\"f1\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:16:33.355517Z",
"iopub.status.busy": "2023-03-03T19:16:33.354801Z",
"iopub.status.idle": "2023-03-03T19:16:38.092937Z",
"shell.execute_reply": "2023-03-03T19:16:38.091904Z",
"shell.execute_reply.started": "2023-03-03T19:16:33.355477Z"
},
"trusted": true
},
"outputs": [],
"source": [
"from transformers.data.data_collator import default_data_collator\n",
"\n",
"# Initialize our Trainer\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=eval_dataset,\n",
" tokenizer=tokenizer,\n",
" data_collator=default_data_collator,\n",
" compute_metrics=compute_metrics,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:16:40.470684Z",
"iopub.status.busy": "2023-03-03T19:16:40.469983Z",
"iopub.status.idle": "2023-03-03T19:23:15.995022Z",
"shell.execute_reply": "2023-03-03T19:23:15.993467Z",
"shell.execute_reply.started": "2023-03-03T19:16:40.470647Z"
},
"trusted": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"g:\\IDEs and Modules\\Anaconda\\envs\\pytorch_gpu\\lib\\site-packages\\transformers\\optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n",
" 0%| | 0/1000 [00:00, ?it/s]g:\\IDEs and Modules\\Anaconda\\envs\\pytorch_gpu\\lib\\site-packages\\transformers\\modeling_utils.py:884: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
" warnings.warn(\n",
" \n",
" 10%|█ | 100/1000 [05:03<46:20, 3.09s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.6499422192573547, 'eval_precision': 0.7698232895333031, 'eval_recall': 0.8440139095876801, 'eval_f1': 0.8052132701421801, 'eval_accuracy': 0.8016165458219422, 'eval_runtime': 34.3294, 'eval_samples_per_second': 1.456, 'eval_steps_per_second': 0.495, 'epoch': 1.33}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 20%|██ | 200/1000 [11:12<38:47, 2.91s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.4936561584472656, 'eval_precision': 0.8225578102878717, 'eval_recall': 0.8658718330849479, 'eval_f1': 0.8436592449177153, 'eval_accuracy': 0.8246760965172947, 'eval_runtime': 35.9544, 'eval_samples_per_second': 1.391, 'eval_steps_per_second': 0.473, 'epoch': 2.67}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 30%|███ | 300/1000 [17:00<33:27, 2.87s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.4679512083530426, 'eval_precision': 0.8520765282314512, 'eval_recall': 0.907103825136612, 'eval_f1': 0.8787295476419633, 'eval_accuracy': 0.8576013312730298, 'eval_runtime': 35.098, 'eval_samples_per_second': 1.425, 'eval_steps_per_second': 0.484, 'epoch': 4.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 40%|████ | 400/1000 [22:50<29:56, 2.99s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5078234076499939, 'eval_precision': 0.8813806514341274, 'eval_recall': 0.9006458022851466, 'eval_f1': 0.890909090909091, 'eval_accuracy': 0.849994056816831, 'eval_runtime': 38.4245, 'eval_samples_per_second': 1.301, 'eval_steps_per_second': 0.442, 'epoch': 5.33}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 50%|█████ | 500/1000 [28:23<25:09, 3.02s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': 0.5405, 'learning_rate': 5e-06, 'epoch': 6.67}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 50%|█████ | 500/1000 [29:02<25:09, 3.02s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5313129425048828, 'eval_precision': 0.8702807357212003, 'eval_recall': 0.8931942374565326, 'eval_f1': 0.8815886246629075, 'eval_accuracy': 0.8495186021633186, 'eval_runtime': 37.6242, 'eval_samples_per_second': 1.329, 'eval_steps_per_second': 0.452, 'epoch': 6.67}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"g:\\IDEs and Modules\\Anaconda\\envs\\pytorch_gpu\\lib\\site-packages\\transformers\\modeling_utils.py:884: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
" warnings.warn(\n",
" \n",
" 60%|██████ | 600/1000 [36:52<19:39, 2.95s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5118691921234131, 'eval_precision': 0.8847441415590627, 'eval_recall': 0.9190263288623944, 'eval_f1': 0.9015594541910332, 'eval_accuracy': 0.8628313324616664, 'eval_runtime': 37.2452, 'eval_samples_per_second': 1.342, 'eval_steps_per_second': 0.456, 'epoch': 8.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 70%|███████ | 700/1000 [42:26<19:21, 3.87s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5196597576141357, 'eval_precision': 0.8799428299190091, 'eval_recall': 0.9175360158966717, 'eval_f1': 0.8983463035019456, 'eval_accuracy': 0.8707951979079995, 'eval_runtime': 38.9063, 'eval_samples_per_second': 1.285, 'eval_steps_per_second': 0.437, 'epoch': 9.33}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 80%|████████ | 800/1000 [49:13<12:57, 3.89s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5508859157562256, 'eval_precision': 0.8907603464870067, 'eval_recall': 0.9195230998509687, 'eval_f1': 0.9049132241505743, 'eval_accuracy': 0.8629501961250445, 'eval_runtime': 41.8293, 'eval_samples_per_second': 1.195, 'eval_steps_per_second': 0.406, 'epoch': 10.67}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
" 90%|█████████ | 900/1000 [55:10<04:35, 2.75s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5496693253517151, 'eval_precision': 0.8995098039215687, 'eval_recall': 0.9115747640337805, 'eval_f1': 0.9055020972119417, 'eval_accuracy': 0.8654463330559848, 'eval_runtime': 34.7336, 'eval_samples_per_second': 1.44, 'eval_steps_per_second': 0.489, 'epoch': 12.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [1:01:33<00:00, 3.32s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': 0.123, 'learning_rate': 0.0, 'epoch': 13.33}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \n",
"100%|██████████| 1000/1000 [1:02:05<00:00, 3.32s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'eval_loss': 0.5414420962333679, 'eval_precision': 0.8941005802707931, 'eval_recall': 0.9185295578738202, 'eval_f1': 0.9061504533202647, 'eval_accuracy': 0.8693688339474622, 'eval_runtime': 32.6105, 'eval_samples_per_second': 1.533, 'eval_steps_per_second': 0.521, 'epoch': 13.33}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [1:02:26<00:00, 3.75s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'train_runtime': 3746.26, 'train_samples_per_second': 0.534, 'train_steps_per_second': 0.267, 'train_loss': 0.3317529182434082, 'epoch': 13.33}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=1000, training_loss=0.3317529182434082, metrics={'train_runtime': 3746.26, 'train_samples_per_second': 0.534, 'train_steps_per_second': 0.267, 'train_loss': 0.3317529182434082, 'epoch': 13.33})"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:24:05.222740Z",
"iopub.status.busy": "2023-03-03T19:24:05.222132Z",
"iopub.status.idle": "2023-03-03T19:24:09.031814Z",
"shell.execute_reply": "2023-03-03T19:24:09.030506Z",
"shell.execute_reply.started": "2023-03-03T19:24:05.222692Z"
},
"trusted": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"g:\\IDEs and Modules\\Anaconda\\envs\\pytorch_gpu\\lib\\site-packages\\transformers\\modeling_utils.py:884: FutureWarning: The `device` argument is deprecated and will be removed in v5 of Transformers.\n",
" warnings.warn(\n",
"100%|██████████| 17/17 [00:50<00:00, 2.94s/it]\n"
]
},
{
"data": {
"text/plain": [
"{'eval_loss': 0.5414420962333679,\n",
" 'eval_precision': 0.8941005802707931,\n",
" 'eval_recall': 0.9185295578738202,\n",
" 'eval_f1': 0.9061504533202647,\n",
" 'eval_accuracy': 0.8693688339474622,\n",
" 'eval_runtime': 53.511,\n",
" 'eval_samples_per_second': 0.934,\n",
" 'eval_steps_per_second': 0.318,\n",
" 'epoch': 13.33}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluate()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"trainer.save_model(\"trained_model\")\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'LayoutLMv3ForTokenClassification' is not defined",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[4], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m model \u001b[39m=\u001b[39m LayoutLMv3ForTokenClassification\u001b[39m.\u001b[39mfrom_pretrained(\u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mG:/Form understanding in Noisy scanned documents/trained_model\u001b[39m\u001b[39m\"\u001b[39m)\n",
"\u001b[1;31mNameError\u001b[0m: name 'LayoutLMv3ForTokenClassification' is not defined"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details."
]
}
],
"source": [
"# Imported here as well so this cell also works on a fresh kernel, where the\n",
"# training cells (and their imports) have not been run.\n",
"from transformers import LayoutLMv3ForTokenClassification\n",
"\n",
"# Reload the fine-tuned weights from the directory written by trainer.save_model(\"trained_model\").\n",
"# A path relative to the working directory keeps the notebook portable across machines.\n",
"model = LayoutLMv3ForTokenClassification.from_pretrained(\"trained_model\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"trusted": true
},
"outputs": [],
"source": [
"# Use the GPU when one is available and fall back to the CPU otherwise, instead of\n",
"# failing on machines without CUDA.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:43:35.233219Z",
"iopub.status.busy": "2023-03-03T19:43:35.232609Z",
"iopub.status.idle": "2023-03-03T19:43:35.335952Z",
"shell.execute_reply": "2023-03-03T19:43:35.334111Z",
"shell.execute_reply.started": "2023-03-03T19:43:35.233171Z"
},
"trusted": true
},
"outputs": [],
"source": [
"example = dataset[\"test\"][2]\n",
"print(example.keys())\n",
"\n",
"image = example[\"image\"]\n",
"words = example[\"tokens\"]\n",
"boxes = example[\"bboxes\"]\n",
"word_labels = example[\"ner_tags\"]\n",
"\n",
"encoding = tokenizer(image, words, boxes=boxes, word_labels=word_labels, return_tensors=\"pt\")\n",
"# Follow the model to whatever device it currently lives on instead of assuming CUDA.\n",
"encoding = encoding.to(model.device)\n",
"for k, v in encoding.items():\n",
"    print(k, v.shape)\n",
"\n",
"# Inference only: make sure dropout is off and no gradients are tracked.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    outputs = model(**encoding)\n",
"\n",
"logits = outputs.logits\n",
"logits.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T19:43:41.414753Z",
"iopub.status.busy": "2023-03-03T19:43:41.413702Z",
"iopub.status.idle": "2023-03-03T19:43:41.427778Z",
"shell.execute_reply": "2023-03-03T19:43:41.424566Z",
"shell.execute_reply.started": "2023-03-03T19:43:41.414704Z"
},
"trusted": true
},
"outputs": [],
"source": [
"# Predicted label id for every token position.\n",
"predictions = logits.argmax(-1).squeeze().tolist()\n",
"# Label ids produced by the processor from word_labels. The visualisation cells\n",
"# below filter on `labels` (dropping -100 entries), so it has to be defined here.\n",
"labels = encoding[\"labels\"].squeeze().tolist()\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
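"source": [
"## Visualize predictions\n",
"\n",
"Map the predicted label ids back to label names, rescale the token boxes to pixel coordinates, and draw them on the page image."
]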
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T20:05:26.719238Z",
"iopub.status.busy": "2023-03-03T20:05:26.718647Z",
"iopub.status.idle": "2023-03-03T20:05:26.731470Z",
"shell.execute_reply": "2023-03-03T20:05:26.730335Z",
"shell.execute_reply.started": "2023-03-03T20:05:26.719199Z"
},
"trusted": true
},
"outputs": [],
"source": [
"def unnormalize_box(bbox, width, height):\n",
"    \"\"\"Rescale a box whose coordinates are expressed in thousandths to image units.\n",
"\n",
"    Entries 0 and 2 are scaled by `width`, entries 1 and 3 by `height`.\n",
"\n",
"    Returns:\n",
"        A new four-element list with the rescaled coordinates.\n",
"    \"\"\"\n",
"    scales = (width, height, width, height)\n",
"    return [scale * (bbox[i] / 1000) for i, scale in enumerate(scales)]\n",
"\n",
"token_boxes = encoding.bbox.squeeze().tolist()\n",
"width, height = image.size\n",
"\n",
"# Keep only the positions whose label is not the ignore index (-100), translating\n",
"# ids to label names and boxes to image coordinates in a single pass.\n",
"true_predictions, true_labels, true_boxes = [], [], []\n",
"for pred, label, box in zip(predictions, labels, token_boxes):\n",
"    if label == -100:\n",
"        continue\n",
"    true_predictions.append(model.config.id2label[pred])\n",
"    true_labels.append(model.config.id2label[label])\n",
"    true_boxes.append(unnormalize_box(box, width, height))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2023-03-03T20:05:31.060497Z",
"iopub.status.busy": "2023-03-03T20:05:31.060140Z",
"iopub.status.idle": "2023-03-03T20:05:31.161089Z",
"shell.execute_reply": "2023-03-03T20:05:31.158110Z",
"shell.execute_reply.started": "2023-03-03T20:05:31.060465Z"
},
"trusted": true
},
"outputs": [],
"source": [
"from PIL import ImageDraw, ImageFont\n",
"\n",
"draw = ImageDraw.Draw(image)\n",
"\n",
"font = ImageFont.load_default()\n",
"\n",
"def iob_to_label(label):\n",
"    \"\"\"Strip the two-character prefix (e.g. B- or I-) from a tag; return 'other' when nothing is left.\"\"\"\n",
"    return label[2:] or 'other'\n",
"\n",
"# Outline/text colour used for each predicted label.\n",
"label2color = {\n",
"    'question': 'blue',\n",
"    'answer': 'green',\n",
"    'header': 'orange',\n",
"    'other': 'violet',\n",
"}\n",
"\n",
"for prediction, box in zip(true_predictions, true_boxes):\n",
"    predicted_label = iob_to_label(prediction).lower()\n",
"    color = label2color[predicted_label]\n",
"    draw.rectangle(box, outline=color)\n",
"    # Place the label text slightly right of and above the box's first corner.\n",
"    text_anchor = (box[0] + 10, box[1] - 10)\n",
"    draw.text(text_anchor, text=predicted_label, fill=color, font=font)\n",
"\n",
"image"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch_gpu",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"polyglot_notebook": {
"kernelInfo": {
"defaultKernelName": "csharp",
"items": [
{
"aliases": [],
"name": "csharp"
}
]
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}