{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "d090c366-23e5-4221-a868-f290eefcedc2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"google/boolq\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "a6bad310-9514-4468-bdca-673b30dfd473", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "tokenizer=AutoTokenizer.from_pretrained(\"bert-base-uncased\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "013559ce-c991-4836-922c-5f9201265c66", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['question', 'answer', 'passage'],\n", " num_rows: 9427\n", " })\n", " validation: Dataset({\n", " features: ['question', 'answer', 'passage'],\n", " num_rows: 3270\n", " })\n", "})" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "38aac997-3d15-4e61-b80c-c1a4fff0b525", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'question': 'do iran and afghanistan speak the same language',\n", " 'answer': True,\n", " 'passage': 'Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. 
It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.'}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[\"train\"][0]" ] }, { "cell_type": "code", "execution_count": 5, "id": "f4d214cd-2fef-4778-bc3a-cb4e1c907515", "metadata": {}, "outputs": [], "source": [ "def encode_question_context_pairs(example):\n", "    \"\"\"Tokenize one BoolQ example as a BERT (question, passage) sentence pair.\n", "\n", "    Passing question and passage as a text pair (rather than joining them with a\n", "    literal \"[SEP]\" string) makes the tokenizer emit token_type_ids of 1 for the\n", "    passage, so BERT sees the segment-B embedding it was pre-trained with.\n", "    truncation=\"only_second\" trims only the passage, never the question.\n", "\n", "    Returns the tokenizer output plus \"labels\": a one-element float list\n", "    (1.0 for a True answer, 0.0 otherwise) for the num_labels=1 model head.\n", "    \"\"\"\n", "    inputs = tokenizer(\n", "        example[\"question\"],\n", "        example[\"passage\"],\n", "        truncation=\"only_second\",\n", "    )\n", "    inputs[\"labels\"] = [float(bool(example[\"answer\"]))]\n", "    return inputs" ] }, { "cell_type": "code", "execution_count": 6, "id": "6fa2aa41-6286-4a69-ba23-90482d98f494", "metadata": {}, "outputs": [], "source": [ "train_dataset = dataset[\"train\"].map(\n", "    encode_question_context_pairs,\n", "    remove_columns=dataset[\"train\"].column_names,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "309bee55-b698-4c66-990d-beb00ac52746", "metadata": {}, "outputs": [], "source": [ "# Drop the validation split's own raw columns (not the train split's).\n", "validation_dataset = dataset[\"validation\"].map(\n", "    encode_question_context_pairs,\n", "    remove_columns=dataset[\"validation\"].column_names,\n", ")" ] }, { "cell_type": "code", "execution_count": 9, "id": "85c9ccea-f788-4025-b185-c32c6fa51c46", "metadata": {}, 
"outputs": [], "source": [ "# tokenizer(\"question\",\"answer\",max_length=512,padding=\"max_length\",truncation=\"only_second\",)" ] }, { "cell_type": "code", "execution_count": 10, "id": "30a82635-f956-404d-a95e-db753f7e07b7", "metadata": {}, "outputs": [], "source": [ "from transformers import DataCollatorWithPadding\n", "\n", "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": 11, "id": "22d43e81-1739-443f-95fb-ee98b10a3a0b", "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "accuracy = evaluate.load(\"accuracy\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "23fa9362-aa3d-4155-85a5-6caa6635c9f8", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " predictions = np.where(predictions<0.5,0,1)\n", " return accuracy.compute(predictions=predictions, references=labels)" ] }, { "cell_type": "code", "execution_count": 13, "id": "e476c76f-21b6-4844-a6a5-29f18b4f6099", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", "\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", " \"bert-base-uncased\", num_labels=1,\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "5a359a0d-7563-4f4e-b4d4-03e6c601fc2f", "metadata": {}, "outputs": [], "source": [ "training_args = TrainingArguments(\n", " output_dir=\"./\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " 
num_train_epochs=4,\n", "    weight_decay=0.01,\n", "    evaluation_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    load_best_model_at_end=True,\n", "    gradient_accumulation_steps=4,\n", "    logging_steps=50,\n", "    seed=42,\n", "    adam_beta1=0.9,\n", "    adam_beta2=0.999,\n", "    adam_epsilon=1e-08,\n", "    report_to=\"tensorboard\",\n", "    push_to_hub=True,\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_dataset,\n", "    eval_dataset=validation_dataset,\n", "    tokenizer=tokenizer,\n", "    data_collator=data_collator,\n", "    compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "id": "0bc0fca5-d298-40d3-a80b-035a05fe6e1f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('./tokenizer_config.json',\n", " './special_tokens_map.json',\n", " './vocab.txt',\n", " './added_tokens.json',\n", " './tokenizer.json')" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Save only the tokenizer up front. Calling model.save_pretrained() here, before\n", "# trainer.train(), would write the newly initialised (untrained) classifier head\n", "# to output_dir; the Trainer saves the trained weights itself (per-epoch\n", "# checkpoints via save_strategy, and again on trainer.push_to_hub()).\n", "tokenizer.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": null, "id": "c96926e2-04c1-4e33-b83f-dc2b9c4d5b08", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [588/588 31:17, Epoch 3.99/4]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracy
00.2317000.2198120.656881
20.1741000.1967690.712232

\n", "

\n", " \n", " \n", " [ 89/205 00:23 < 00:31, 3.73 it/s]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "75e96eb2-0d8e-4e5f-8844-6abce16bd1cb", "metadata": {}, "outputs": [], "source": [ "kwargs = {\n", " \"dataset_tags\": \"google/boolq\",\n", " \"dataset\": \"boolq\", # a 'pretty' name for the training dataset\n", " \"language\": \"en\",\n", " \"model_name\": \"Bert Base Uncased Boolean Question Answer model\", # a 'pretty' name for your model\n", " \"finetuned_from\": \"bert-base-uncased\",\n", " \"tasks\": \"text-classification\",\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "ba5e73bd-d154-43ce-a869-f0f57045a386", "metadata": {}, "outputs": [], "source": [ "trainer.push_to_hub(**kwargs)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }