{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "opwecKkHKCHY" }, "source": [ "# **Google drive Mounting**" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6ufRlY5NJuH1", "outputId": "cc1c14b9-1c16-4532-a1a2-88e0811c7200" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ] }, { "cell_type": "code", "source": [ "!pip install transformers\n", "!pip install huggingface_hub" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "S44RrN7O6Wx5", "outputId": "1bf2800d-0656-49b3-abc8-f825c5919a42" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages 
(from transformers) (0.4.5)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.6.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n", "Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.10/dist-packages (0.24.7)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (3.16.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2024.6.1)\n", "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (6.0.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.66.5)\n", 
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub) (2024.8.30)\n" ] } ] }, { "cell_type": "code", "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 160, "referenced_widgets": [ "c9e738e658ba48fba0483f17fe5994ea", "590009f6c10241c48e5056ef70034208", "064d839358e64573a6612f3e887fc02a", "76e514eb154d4798936ee6eeb94b44eb", "c2d186ca4bb04676905356a6b9e981dd", "eef83ea229504c67836f98f796fbe584", "0ba5c39bcab040d5abbf1274e4d0a678", "308c81cd6269445f88a1a5dff7803936", "83915d4ebfe448369051604d53f8f144", "dd1e22b30e40445ea97f94a4a76f2d5c", "f68095098d93426aa52a7a927dfe7f4c", "bb36edb1dfc9463ab434fe46e6f022d0", "d86c7801aa52404d82a54c7c6aef61f1", "a1a24381dd024b40b15a3de1527a0ab6", "b6e453cf38024520af85f252a3e044ce", "ff26aedbb28f4d738240e60e97d4ef1e", "05ddd82b414a406eaf31a1c56d578e8e", "e98290bb06f14143b3956bce1039d1d6", "d08f6dbec15d4e6d81e77597fde629af", "4dbd12f944074c70a74de39e5eea7b86", "8f80f5a39281421b9be35985893203c4", "3e3c248e2f474990bcf558711e866421", "053e56fb46254958852cf373e9f96f5e", "ae8f070d3d0445b89a8f2a797e79406f", "23163add522745188c04a3b040cb1057", "c96740311b7d42fe82393c057845e059", "d62565a066614ffdba7679817de48f16", "a8a4452336594979b246b98c05d6b911", "498b40387da1418bafe65b50ce69b94b", "0418608a12d54ebdac0a86ba9fcd71c6", 
"62c6b18e7be149fbb7add2effe27313f", "8827d4f55f2e433ba66a94206d69ed0f" ] }, "id": "mmvT1GKV6h74", "outputId": "5349e560-f7f3-4e78-ec53-bfbc39abd5c8" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
Step | \n", "Training Loss | \n", "
---|
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=225, training_loss=4.451516384548611, metrics={'train_runtime': 2131.93, 'train_samples_per_second': 0.418, 'train_steps_per_second': 0.106, 'total_flos': 58202800128000.0, 'train_loss': 4.451516384548611, 'epoch': 3.0})" ] }, "metadata": {}, "execution_count": 7 } ], "source": [ "# Quote requirement specifiers: unquoted, the shell parses `>=0.21.0` as output redirection\n", "# (pip's output lands in a file named '=0.21.0' and the version floor is dropped),\n", "# and `[torch]` is a glob pattern.\n", "!pip install \"transformers[torch]\"\n", "!pip install --upgrade \"accelerate>=0.21.0\"\n", "\n", "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "from transformers import TextDataset, DataCollatorForLanguageModeling\n", "from transformers import Trainer, TrainingArguments\n", "\n", "# Load the tokenizer\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", "\n", "# Tokenize the corpus into fixed-size 128-token blocks.\n", "# NOTE: TextDataset is deprecated upstream; consider migrating to the `datasets` library.\n", "train_path = '/content/drive/My Drive/Data/HUISGENOOT_corpus_v5.txt'\n", "train_dataset = TextDataset(\n", " tokenizer=tokenizer,\n", " file_path=train_path,\n", " block_size=128\n", ")\n", "\n", "# Create data collator; mlm=False selects causal (next-token) language modelling for GPT-2\n", "data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer, mlm=False\n", ")\n", "\n", "# Initialize the model from the pretrained GPT-2 checkpoint\n", "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n", "\n", "# Define the training arguments\n", "training_args = TrainingArguments(\n", " output_dir=\"./gpt2-afrikaans\",\n", " overwrite_output_dir=True,\n", " num_train_epochs=3,\n", " per_device_train_batch_size=4,\n", " # The recorded run only reached 225 steps, so no intermediate checkpoint is written;\n", " # the model is saved explicitly in the 'Model Saving' section.\n", " save_steps=10_000,\n", " save_total_limit=2,\n", " prediction_loss_only=True,\n", ")\n", "\n", "# Create Trainer instance\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " data_collator=data_collator,\n", " train_dataset=train_dataset,\n", ")\n", "\n", "# Start training\n", "trainer.train()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "v_wpeJ0oskwt", "outputId": "ef42e9c2-8ce2-4401-c38c-0c80a820280e" }, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.34.2)\n", "Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (24.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.2)\n", "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.4.0+cu121)\n", "Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.24.7)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.5)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.5)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (1.13.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.3)\n", "Requirement already satisfied: jinja2 in 
/usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate) (3.1.4)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.8)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.8.30)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" ] } ], "source": [ "!pip install --upgrade accelerate\n", "# Do not call exit() here: it terminates the kernel, so under Run All the 'Model Saving'\n", "# cell below fails because `model` and `tokenizer` (created by the training cell) are gone.\n", "# If accelerate must be upgraded with a restart, do it before the training cell instead.\n", "# exit()" ] }, { "cell_type": "markdown", "metadata": { "id": "oZv9fsDQvLny" }, "source": [ "# **Model Saving**" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5BJ_r8QEvQRo", "outputId": "9a9407e6-04ee-40ca-9bcd-a0c1d83aaf5c" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "('./gpt2-afrikaans/tokenizer_config.json',\n", " './gpt2-afrikaans/special_tokens_map.json',\n", " './gpt2-afrikaans/vocab.json',\n", " './gpt2-afrikaans/merges.txt',\n", " './gpt2-afrikaans/added_tokens.json')" ] }, "metadata": {}, "execution_count": 8 } ], "source": [ "# Save the trained model and its tokenizer side by side so they can be reloaded together.\n", "# NOTE: ./gpt2-afrikaans is on the Colab VM's local disk, not the mounted Drive;\n", "# copy it under /content/drive if it must survive a runtime reset.\n", "output_dir = \"./gpt2-afrikaans\"\n", "model.save_pretrained(output_dir)\n", "tokenizer.save_pretrained(output_dir)" ] }, { "cell_type": "markdown", "source": [ "# **Trained data Calling (11-50)**" ], 
"metadata": { "id": "BZQrtm7GPkCw" } }, { "cell_type": "code", "source": [ "# Load the preprocessed text from article 11-50\n", "# Kept under its own name: the next cell loads articles 51-550 into `corpus_51_550`.\n", "with open('/content/drive/My Drive/Data/HUISGENOOT_corpus_11_50.txt', 'r', encoding='utf-8') as file:\n", " corpus_11_50 = file.read()" ], "metadata": { "id": "TdEbPSbcPrQf" }, "execution_count": 9, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "MEjNk3HVk63S" }, "source": [ "# **Trained data Calling (51-550)**" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "SAYeLDaFlRUx" }, "outputs": [], "source": [ "# Load the preprocessed text from article 51-550\n", "with open('/content/drive/My Drive/Data/HUISGENOOT_corpus_51_550.txt', 'r', encoding='utf-8') as file:\n", " corpus_51_550 = file.read()" ] }, { "cell_type": "markdown", "metadata": { "id": "X7cKw4BLlJXB" }, "source": [ "# **Trained data Calling (551-1000)**" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "ArAb-xKvlQC5" }, "outputs": [], "source": [ "# Load the preprocessed text from article 551-1000\n", "with open('/content/drive/My Drive/Data/HUISGENOOT_corpus_551_1000.txt', 'r', encoding='utf-8') as file:\n", " corpus_551_1000 = file.read()" ] }, { "cell_type": "markdown", "metadata": { "id": "UUFYlYlMCLTs" }, "source": [ "# **Trained data Calling (1001 - 1800)**" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "id": "5jpK8uDPCTDf" }, "outputs": [], "source": [ "# Load the preprocessed text from article 1001 - 1800\n", "with open('/content/drive/My Drive/Data/HUISGENOOT_corpus_1001_1800.txt', 'r', encoding='utf-8') as file:\n", " corpus_1001_1800 = file.read()" ] }, { "cell_type": "markdown", "metadata": { "id": "bxZSXLWCvRs2" }, "source": [ "# **Model Testing**" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QFbrBB2mvcEq", "outputId": "8610e9c0-bcd3-4959-aa84-d5d1ce1897b8" }, "outputs": [ { "output_type": "stream", "name": "stderr", 
"text": [ "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n", "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Ek het 'n droom gehad in die hoe nie die klange oorstie die skitterklaas, sy het met gedak wat ons uit gekrag te het in die het vrous die beteken vir een eejoon, haar in vir nou koiΓ« nie. βHoy ek die klaas het ek op βn m\n" ] } ], "source": [ "from transformers import pipeline\n", "\n", "# Load the fine-tuned model\n", "model_path = \"./gpt2-afrikaans\" # Adjust the path accordingly\n", "tokenizer = GPT2Tokenizer.from_pretrained(model_path)\n", "model = GPT2LMHeadModel.from_pretrained(model_path)\n", "\n", "# Set up text generation pipeline\n", "text_generator = pipeline(\"text-generation\", model=model, tokenizer=tokenizer)\n", "\n", "# Generate text\n", "prompt = \"Ek het 'n droom gehad\"\n", "generated_text = text_generator(prompt, max_length=100, num_return_sequences=1)\n", "print(generated_text[0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IiuQv23_vz5a", "outputId": "63d890f0-19a9-4cc3-bfd4-f2117b003024" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Hou oud is jy wΓͺreld.\n", "\n", "\n", "Daar is jy wat ek voel by die jongbukkakeld met die kommer van my party aan die skaat. 
Ek is vermorkig by diΓ© raad met die kommunikings van die steeds om om te bordeljies en vermaaan kou, maar gekry in die maou in jou kon n\n" ] } ], "source": [ "prompt2 = \"Hou oud is jy\"\n", "generated_text2 = text_generator(prompt2, max_length=100, num_return_sequences=1)\n", "print(generated_text2[0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iuJ2BYxSnuHr", "outputId": "a951c912-b4f3-41ea-e243-fc41102443fe" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Vertel my van die nuutste berigte in Huisgenoot. \n", "\n", "\n", "Gwyneth Paltrow (@wiebestsellers) is en laatste maar nou beΒder vuil vir my lewe.\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "REDAACHT\n", "\n", "\n", "DIE GROOTTE, MELISSA BAKER, GEORGE HANSBURG\n", "\n", "\n", "\n", "\n", "\n", "\n", "LIEVELANDE RUMS\n", "\n", "\n", "JP\n" ] } ], "source": [ "prompt3 = \"Vertel my van die nuutste berigte in Huisgenoot.\"\n", "generated_text3 = text_generator(prompt3, max_length=100, num_return_sequences=1)\n", "print(generated_text3[0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VGPyy9iXvqrt", "outputId": "1727e499-9d58-4b79-99f1-f13013f4e3bd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Wat is jou gunsteling Huisgenoot artikel en hoekom?β he moet toe jou gee. (Huisgenoot verstaan van die RUSSIAN PETERS.) 
\n", "βHoe beter Jy uit gehad het van die RUSSIAN PETERS gebore het en word oor die RUSSIAN PETERS wanneer gesΓͺ die ook sy het\n" ] } ], "source": [ "prompt4 = \"Wat is jou gunsteling Huisgenoot artikel en hoekom?\"\n", "generated_text4 = text_generator(prompt4, max_length=100, num_return_sequences=1)\n", "print(generated_text4[0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "C5UX1XrpwsAj", "outputId": "01a5e6dd-0366-4681-e6e4-238b5a2af1df" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Wat is die naam van die tydskrif?β vertel nΓ³r die gesien het, kon verhouding van die tydskrif vir skryf oor die tydskrif van die die gekom nou in lewe.\n", "Dit was twee die opgemaar tydste tydste boeke opgeveel om een vrou daardie na haar\n" ] } ], "source": [ "prompt5 = \"Wat is die naam van die tydskrif?\"\n", "generated_text5 = text_generator(prompt5, max_length=100, num_return_sequences=1)\n", "print(generated_text5[0]['generated_text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-UVRto9Pyynr", "outputId": "1f069bfa-5876-44cb-b50b-4fb4ffd0ff6a" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Hoe het Huisgenoot oor belangrike historiese gebeurtenisse in Suid-Afrika berig? 
\n", "βEk geen uit altyd hoe verdoon om by die dae te minstel verhouding hulle geleur βn oor julle bly die diefe oor die jou van die bly.β \n", "Ek kan altyd\n" ] } ], "source": [ "prompt6 = \"Hoe het Huisgenoot oor belangrike historiese gebeurtenisse in Suid-Afrika berig?\"\n", "generated_text6 = text_generator(prompt6, max_length=100, num_return_sequences=1)\n", "print(generated_text6[0]['generated_text'])" ] }, { "cell_type": "markdown", "metadata": { "id": "4fG2XO052hqW" }, "source": [ "# **Downstream Applications**" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cja81Xwl7Rlb", "outputId": "1798ba44-f15d-42d5-8870-191832b8246e" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n", "Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForQuestionAnswering\n", "\n", "# Replace 'your-pretrained-model-name' with the actual name of the model\n", "model_name = \"bert-base-uncased\" # Example: using a public model\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = AutoModelForQuestionAnswering.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 360 }, "id": "HTBzzjkH7ic7", "outputId": "0acaf10b-d326-4919-c460-b66b07a8ba59" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Found existing installation: pyarrow 14.0.2\n", "Uninstalling pyarrow-14.0.2:\n", " Successfully uninstalled pyarrow-14.0.2\n", "Collecting pyarrow\n", " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", "Requirement already satisfied: numpy>=1.16.6 in /usr/local/lib/python3.10/dist-packages (from pyarrow) (1.26.4)\n", "Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m20.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: pyarrow\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n", "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed pyarrow-17.0.0\n" ] }, { "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { "pip_warning": { "packages": [ "pyarrow" ] }, "id": "c9bf11b6ea23464cb47ca59e72a33226" } }, "metadata": {} } ], "source": [ "!pip uninstall -y pyarrow\n", "!pip install pyarrow" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 177, "referenced_widgets": [ "82fc479120854d86bcef73db967b6ec9", "ee980916096b4394b9b716ee6320ba1f", "2c11bd034e084d24abe6c3439a9dabaa", "33ccc527ffe14cd9be6a45d724d7dc97", "fee3adc677204c23943c91274713f275", "9299999420a04ee4be882f09c44633dd", "1f8ac66132944750ae250b5aa01092a9", "95f6477dcd75400ca9cfeeeaddef894a", "f8bf348d24a64d4299965d1a8199ff4c", "45726faf5b32401f8626a8d0a7e84ed3", "967b128c873b48f1ac4da4ce320c327e", "365b5c3e87eb497388e4795a1bcbe9dd", "388648db5e91492a937b3732d70c8716", "9bde2384454945e8bd0cc7f4e422eaff", "d18fa32255984b2da53cf5c9e852123a", "a50b5dd90feb49328aa928e5d4a20b79", "ac0aa671b99b41f69112067d25a9a1d7", "fdefac5acb164025a5860476fff717c1", "0fcd4f975b9348e2975aa23e1419b2f4", "f5fd3e13eb944771ba56139a849140b0", "739ce12d24744cd9b289770cc9a2452a", "dfe84088930a4fd0b057eff9a79cdc24", "2d2bfbbb2894415a8ac6d4273c73355d", "a12e75ffaccc4526a9400862ef1518f1", "b04aa00ec6e44eb9a76f0d9ddbdfa60f", "b71982105a91405fbc1cb09949e87ddd", "0869b1756223458b9882de3f8c82656c", "02e79aa08d1243a18f1cbca9714e026f", "f3b8c57e5b24433d874d542b2a23eba3", "17c263b18f964b97a66ccf64ad4b4276", "89ed766f2e49481399fcc852212d3d70", "fef55075abd441539725f06e72b51c09", "4a2a8f28f8d644bcaf68b57e58f775e4", 
"406a13bcf1f347c3a99ecea78b6da02a", "e9b59bc3225e4ef9a2157343be55a63f", "6a6dfa5d095b4a2790dbfb07deb1e155", "cabc88a574894c8b894c84401f88d338", "a942c3c51c57460eb62ab0a9009e4c54", "8becb871229e4c83be573cc6fe97306b", "062f29fb07514400aafdbd398265cd78", "0fb6521f195943a696206063436a8b60", "e6c2d063173e41998fa5f1d71ccf4312", "5dc3f638c90f44d5830e5e6ef962b114", "7f2ca1c3768e4b329bfd11aed8e044ed", "d016dc6f603d4c63985e55d65924cab9", "2b317b02cf5b476282fc5f6e076bed05", "8793ff11134b4ca7bae2e2795a4812e4", "4da5f2eb5fe14ddd9379f99f2a79b6f1", "4d5fca8d30ab4b528cb5d1254091e066", "2da79f44eb6e420aaac7376d9f630db3", "687e25d32e564ced99477e133a8f301e", "0af71d029f894711bdb14fc668ba8c4c", "b7baa3d8749f44949b7a4507d79f2999", "d29745c1d66245929b3cab298082f319", "cf16e6cd310b4b6292e5c77c7a498bc3" ] }, "id": "HhJjGJiW6dRf", "outputId": "817f6173-8da0-472c-c319-38dd130f7f3c" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "README.md: 0%| | 0.00/7.62k [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "82fc479120854d86bcef73db967b6ec9" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "train-00000-of-00001.parquet: 0%| | 0.00/14.5M [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "365b5c3e87eb497388e4795a1bcbe9dd" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "validation-00000-of-00001.parquet: 0%| | 0.00/1.82M [00:00, ?B/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "2d2bfbbb2894415a8ac6d4273c73355d" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Generating train split: 0%| | 0/87599 [00:00, ? 
examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "406a13bcf1f347c3a99ecea78b6da02a" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Generating validation split: 0%| | 0/10570 [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "d016dc6f603d4c63985e55d65924cab9" } }, "metadata": {} } ], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"squad\")\n" ] }, { "source": [ "!pip install datasets" ], "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "gMbIfr5OgXbH", "outputId": "ea58c1e8-02dc-4850-d12c-201e996a158a" }, "execution_count": 15, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting datasets\n", " Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", "Collecting pyarrow>=15.0.0 (from datasets)\n", " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", "Collecting multiprocess (from datasets)\n", " Downloading 
multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n", "Requirement already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.7)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in 
/usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Downloading datasets-3.0.0-py3-none-any.whl (474 kB)\n", "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m474.3/474.3 kB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K 
\u001b[90mββββββββββββββββββββββββββββββββββββββββ\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m12.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n", " Attempting uninstall: pyarrow\n", " Found existing installation: pyarrow 14.0.2\n", " Uninstalling pyarrow-14.0.2:\n", " Successfully uninstalled pyarrow-14.0.2\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n", "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed datasets-3.0.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.5.0\n" ] }, { "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { "pip_warning": { "packages": [ "pyarrow" ] }, "id": "5d5ff7e285bc4463b45545f97359f075" } }, "metadata": {} } ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 81, "referenced_widgets": [ "1894be3923594d31ba6fef1c6ad4ec8d", "317346312e2246bcb183cd33869a27c6", "a3ad6ae3419e4a44896d0bc1ade36e5b", "42a469fd2be545c19f9cc65d60f119e2", "12c76fbfc877419a902cf89295009701", "050c950adf8e403e8e269629ecfdf34a", "88e6c19df9cf4a4095bcec449e22775c", "9c0105f5c1144007a4048a943f3717e1", "67e411a799b64f46b59184433861a6a9", "1593e36056f1404695f6d305a8cebc06", "12a4af2f44264df0ac2c604279f02d2b", "fd641b11b8b645e79bf4125a013d2372", "273643fec5784b5799592c9a91735cb7", "a0e9d0ada6144bec944af851ddfe57f4", "d8dd1ad6adb24a06a85d02131da5475c", "980363ab5ec84432845ab56b84444333", "5a10d39659df41fabe1e8ddf2c62d59d", "7c9f8d2397aa47efa1ab4cebfe5aad49", "46cb29e5d9de4f50bd9d9f4a0b592e25", 
"338a9d9310ff49fb929ba21b421c8f0e", "fdc4122a197341e0b07deba6b5b00514", "3bcadfb0f29c40c9b6e735ecd4473c90" ] }, "id": "A0k0Ne5SC-6x", "outputId": "058c888a-85c7-453f-f4d8-093ca1e112ec" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Map: 0%| | 0/87599 [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "1894be3923594d31ba6fef1c6ad4ec8d" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Map: 0%| | 0/10570 [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "fd641b11b8b645e79bf4125a013d2372" } }, "metadata": {} } ], "source": [
"def preprocess_function(examples):\n",
"    \"\"\"Tokenize question/context pairs and label answer spans with token indices.\n",
"\n",
"    Each pair becomes one window of at most 384 tokens; only the context is\n",
"    truncated. `start_positions` / `end_positions` get the indices of the first\n",
"    and last answer token, or 0 / 0 when the example has no answer or the answer\n",
"    is not fully inside the kept part of the context.\n",
"\n",
"    Uses the global `tokenizer`, which must be a fast tokenizer\n",
"    (`return_offsets_mapping` and `sequence_ids` require one).\n",
"    \"\"\"\n",
"    questions = [q.strip() for q in examples[\"question\"]]\n",
"    inputs = tokenizer(\n",
"        questions,\n",
"        examples[\"context\"],\n",
"        max_length=384,\n",
"        truncation=\"only_second\",\n",
"        return_offsets_mapping=True,\n",
"        padding=\"max_length\",\n",
"    )\n",
"\n",
"    offset_mapping = inputs.pop(\"offset_mapping\")\n",
"    answers = examples[\"answers\"]\n",
"    start_positions = []\n",
"    end_positions = []\n",
"\n",
"    for i, offset in enumerate(offset_mapping):\n",
"        answer = answers[i]\n",
"\n",
"        # No gold answer (e.g. SQuAD v2 style): label index 0 instead of crashing.\n",
"        if not answer[\"answer_start\"]:\n",
"            start_positions.append(0)\n",
"            end_positions.append(0)\n",
"            continue\n",
"\n",
"        start_char = answer[\"answer_start\"][0]\n",
"        end_char = start_char + len(answer[\"text\"][0])\n",
"        sequence_ids = inputs.sequence_ids(i)\n",
"\n",
"        # First and last token of the context (sequence id 1) in this window.\n",
"        context_start = sequence_ids.index(1)\n",
"        context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)\n",
"\n",
"        # The whole answer must survive truncation; a partially cut answer\n",
"        # would otherwise get a start label paired with an end label of 0.\n",
"        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n",
"            start_positions.append(0)\n",
"            end_positions.append(0)\n",
"            continue\n",
"\n",
"        # Scan context tokens only: question-token offsets index into the\n",
"        # question string and could spuriously equal start_char / end_char.\n",
"        # Comparing with <= / >= (not ==) also copes with answers that do not\n",
"        # start or end exactly on a token boundary.\n",
"        idx = context_start\n",
"        while idx <= context_end and offset[idx][0] <= start_char:\n",
"            idx += 1\n",
"        start_positions.append(idx - 1)\n",
"\n",
"        idx = context_end\n",
"        while idx >= context_start and offset[idx][1] >= end_char:\n",
"            idx -= 1\n",
"        end_positions.append(idx + 1)\n",
"\n",
"    inputs[\"start_positions\"] = start_positions\n",
"    inputs[\"end_positions\"] = end_positions\n",
"    return inputs\n",
"\n",
"tokenized_datasets = dataset.map(preprocess_function, batched=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vpMKmGG06ppx", "outputId": "d80635e7-b2bd-40c8-8b9b-31bdd3b306c3" }, "outputs": [], "source": [
"from transformers import TrainingArguments, Trainer\n",
"\n",
"training_args = TrainingArguments(\n",
"    output_dir=\"./results\",\n",
"    # `evaluation_strategy` is deprecated and is removed in transformers 4.46;\n",
"    # the install cell above is unpinned, so use the current name.\n",
"    eval_strategy=\"epoch\",\n",
"    learning_rate=2e-5,\n",
"    per_device_train_batch_size=16,\n",
"    per_device_eval_batch_size=16,\n",
"    num_train_epochs=3,\n",
"    weight_decay=0.01,\n",
")\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=tokenized_datasets[\"train\"],\n",
"    eval_dataset=tokenized_datasets[\"validation\"],\n",
")\n",
"\n",
"trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "JQ8MBHu_6teA" }, "outputs": [], "source": [ "results = trainer.evaluate()\n", "results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RsylWV9Ww37c" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "opwecKkHKCHY", "aV_gTrLgNdN6", "29rTY238NJJ5", "F2hy8IHkOTwB", "oZv9fsDQvLny", "bxZSXLWCvRs2", "Cd8o5uNsvc2C" ], "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" },
"language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "c9e738e658ba48fba0483f17fe5994ea": { "model_module": "@jupyter-widgets/controls", "model_name": "VBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "VBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "VBoxView", "box_style": "", "children": [ "IPY_MODEL_8f80f5a39281421b9be35985893203c4", "IPY_MODEL_3e3c248e2f474990bcf558711e866421", "IPY_MODEL_053e56fb46254958852cf373e9f96f5e", "IPY_MODEL_ae8f070d3d0445b89a8f2a797e79406f" ], "layout": "IPY_MODEL_0ba5c39bcab040d5abbf1274e4d0a678" } }, "590009f6c10241c48e5056ef70034208": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_308c81cd6269445f88a1a5dff7803936", "placeholder": "β", "style": "IPY_MODEL_83915d4ebfe448369051604d53f8f144", "value": "