{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "0e7385a4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['title', 'body', 'summary', 'type', 'tags', 'url'],\n", " num_rows: 358868\n", " })\n", " validation: Dataset({\n", " features: ['title', 'body', 'summary', 'type', 'tags', 'url'],\n", " num_rows: 11000\n", " })\n", " test: Dataset({\n", " features: ['title', 'body', 'summary', 'type', 'tags', 'url'],\n", " num_rows: 11000\n", " })\n", "})" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from datasets import load_dataset\n", "\n", "ds = load_dataset(\"thaisum\")\n", "ds" ] }, { "cell_type": "code", "execution_count": null, "id": "337b3bc6", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from datasets import DatasetDict \n", "\n", "dataset = load_dataset('csv', data_files='thaisum.csv')\n", "ds_train_devtest = dataset['train'].train_test_split(test_size=0.05, seed=42)\n", "ds_devtest = ds_train_devtest['test'].train_test_split(test_size=0.5, seed=42)\n", "\n", "\n", "ds_thai_news = DatasetDict({\n", " 'train': ds_train_devtest['train'],\n", " 'valid': ds_devtest['train'],\n", " 'test': ds_devtest['test']\n", "})\n", "ds_thai_news" ] }, { "cell_type": "code", "execution_count": 2, "id": "286cbb13-5fff-4291-bdd7-3e4ddf972228", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\transformers\\utils\\generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", " _torch_pytree._register_pytree_node(\n", "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\transformers\\utils\\generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", " _torch_pytree._register_pytree_node(\n", "C:\\Users\\preec\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.12_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python312\\site-packages\\transformers\\utils\\generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", " _torch_pytree._register_pytree_node(\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig\n", "import torch\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "mt5_config = AutoConfig.from_pretrained(\n", " \"../mt5-base-thaisum-text-summarization\",\n", " local_files_only=True,\n", " max_length=140,\n", " min_length=40,\n", " length_penalty=1.2,\n", " no_repeat_ngram_size=2,\n", " num_beams=15,\n", ")\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"../mt5-base-thaisum-text-summarization\", local_files_only=True)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\"../mt5-base-thaisum-text-summarization\", local_files_only=True).to(device)" ] }, { "cell_type": "code", "execution_count": 3, "id": "ebfdf213", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Map: 100%|██████████| 11000/11000 [00:17<00:00, 622.18 examples/s]\n" ] } ], "source": [ "from transformers import DataCollatorForSeq2Seq\n", "data_collator = DataCollatorForSeq2Seq(\n", " tokenizer,\n", " model=model,\n", " return_tensors=\"pt\")\n", "\n", "def tokenize_data(data):\n", "\n", " input_feature = tokenizer(data[\"body\"], truncation=True, max_length=512)\n", " label = tokenizer(data[\"summary\"], truncation=True, max_length=140)\n", " return {\n", " \"input_ids\": input_feature[\"input_ids\"],\n", " \"attention_mask\": input_feature[\"attention_mask\"],\n", " \"labels\": label[\"input_ids\"],\n", " }\n", "\n", "token_ds_thai_news = ds.map(\n", " tokenize_data,\n", " remove_columns=['title', 'body', 'summary', 'type', 'tags', 'url'],\n", " batched=True,\n", " batch_size=64)" ] }, { "cell_type": "code", "execution_count": 4, "id": "a01f4771", "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "import numpy as np\n", "def tokenize_sentence(arg):\n", " encoded_arg = tokenizer(arg)\n", " return tokenizer.convert_ids_to_tokens(encoded_arg.input_ids)\n", "\n", "def metrics_func(eval_arg):\n", " preds, labels = eval_arg\n", " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", " text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", " text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", " return rouge_metric.compute(\n", " predictions=text_preds,\n", " references=text_labels,\n", " tokenizer=tokenize_sentence\n", " )\n", "rouge_metric = evaluate.load(\"rouge\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "5d0f286b", "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainingArguments\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir = \"..\",\n", " log_level = \"error\",\n", " num_train_epochs = 6,\n", " learning_rate = 5e-4,\n", " warmup_steps = 5000,\n", " weight_decay=0.01,\n", " per_device_train_batch_size = 8,\n", " per_device_eval_batch_size = 1,\n", " gradient_accumulation_steps = 4,\n", " evaluation_strategy = \"steps\",\n", " eval_steps = 100,\n", " predict_with_generate=True,\n", " generation_max_length = 140,\n", " save_steps = 3000,\n", " logging_steps = 10,\n", " push_to_hub = False,\n", " remove_unused_columns=False\n", ")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "33e02416", "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer\n", "trainer = Seq2SeqTrainer(\n", " model = model,\n", " args = training_args,\n", " data_collator = data_collator,\n", " compute_metrics = metrics_func,\n", " train_dataset = token_ds_thai_news[\"train\"],\n", " eval_dataset = token_ds_thai_news[\"valid\"].select(range(30)),\n", " tokenizer = tokenizer,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "1048d26c", "metadata": {}, "outputs": [], "source": [ "import os\n", "from transformers import AutoModelForSeq2SeqLM\n", "\n", "os.makedirs(\"./trained_for_summarization\", exist_ok=True)\n", "if hasattr(trainer.model, \"module\"):\n", " trainer.model.module.save_pretrained(\"./trained_for_summarization\")\n", "else:\n", " trainer.model.save_pretrained(\"./trained_for_summarization\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }