{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ghZvHQRQ35xE", "outputId": "c3f512e2-1272-4681-f7bf-8e68500f06ab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.18.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.3)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.25.2)\n", "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (14.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.2)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec[http]<=2024.2.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.3)\n", "Requirement already satisfied: huggingface-hub>=0.19.4 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.20.3)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from 
datasets) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.19.4->datasets) (4.10.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2024.2.2)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.4)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from 
python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n", "Requirement already satisfied: trl in /usr/local/lib/python3.10/dist-packages (0.8.1)\n", "Requirement already satisfied: torch>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from trl) (2.2.1+cu121)\n", "Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from trl) (4.40.0.dev0)\n", "Requirement already satisfied: numpy>=1.18.2 in /usr/local/lib/python3.10/dist-packages (from trl) (1.25.2)\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from trl) (0.28.0)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from trl) (2.18.0)\n", "Requirement already satisfied: tyro>=0.5.11 in /usr/local/lib/python3.10/dist-packages (from trl) (0.7.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.13.3)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) 
(12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->trl) (12.4.99)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (0.20.3)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from 
transformers>=4.31.0->trl) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (0.15.2)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (0.4.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl) (4.66.2)\n", "Requirement already satisfied: docstring-parser>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (0.16)\n", "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (13.7.1)\n", "Requirement already satisfied: shtab>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl) (1.7.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate->trl) (5.9.5)\n", "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (14.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (0.70.16)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->trl) (3.9.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in 
/usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl) (4.0.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl) (2024.2.2)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4.0->trl) (2.1.5)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in 
/usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl) (2023.4)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.4.0->trl) (1.3.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl) (0.1.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets->trl) (1.16.0)\n", "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.10.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (24.0)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.2.1+cu121)\n", "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.40.0.dev0)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.2)\n", "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.28.0)\n", "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.2)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.20.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (3.13.3)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from 
huggingface-hub>=0.17.0->peft) (2023.6.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2.31.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.3)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages 
(from torch>=1.13.0->peft) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.13.0->peft) (12.4.99)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.12.25)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.15.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.5)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2024.2.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n", "Requirement already satisfied: wandb==0.16.3 in /usr/local/lib/python3.10/dist-packages (0.16.3)\n", "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from 
wandb==0.16.3) (8.1.7)\n", "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (3.1.43)\n", "Requirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (2.31.0)\n", "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (5.9.5)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (1.44.0)\n", "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (0.4.0)\n", "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (6.0.1)\n", "Requirement already satisfied: setproctitle in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (1.3.3)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (67.7.2)\n", "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3) (3.20.3)\n", "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb==0.16.3) (1.16.0)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb==0.16.3) (4.0.11)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from 
requests<3,>=2.0.0->wandb==0.16.3) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3) (2024.2.2)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb==0.16.3) (5.0.1)\n", "Requirement already satisfied: huggingface_hub==0.20.3 in /usr/local/lib/python3.10/dist-packages (0.20.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (3.13.3)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (2023.6.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (4.66.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (6.0.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (4.10.0)\n", "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3) (24.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub==0.20.3) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub==0.20.3) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub==0.20.3) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub==0.20.3) 
(2024.2.2)\n", "Collecting git+https://github.com/huggingface/transformers.git\n", " Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-gbui18gs\n", " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-gbui18gs\n", " Resolved https://github.com/huggingface/transformers.git to commit 096f304695f7e7b169b031f7814352e900ad71c4\n", " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (3.13.3)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (0.20.3)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (0.15.2)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) (0.4.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0) 
(4.66.2)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.40.0.dev0) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.40.0.dev0) (4.10.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.40.0.dev0) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.40.0.dev0) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.40.0.dev0) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers==4.40.0.dev0) (2024.2.2)\n", "Requirement already satisfied: bitsandbytes in /usr/local/lib/python3.10/dist-packages (0.43.0)\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.28.0)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (2.2.1+cu121)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from bitsandbytes) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (24.0)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0.1)\n", "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.20.3)\n", "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate) (0.4.2)\n", "Requirement 
already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.13.3)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in 
/usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch->bitsandbytes) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->bitsandbytes) (12.4.99)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub->accelerate) (4.66.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->bitsandbytes) (2.1.5)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub->accelerate) (2024.2.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->bitsandbytes) (1.3.0)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.2.1+cu121)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages 
(from torch) (3.13.3)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in 
/usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.4.99)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" ] } ], "source": [ "!pip install datasets\n", "!pip install trl\n", "!pip install peft\n", "!pip install wandb==0.16.3\n", "!pip install huggingface_hub==0.20.3\n", "!pip install git+https://github.com/huggingface/transformers.git\n", "!pip install bitsandbytes accelerate\n", "!pip install torch" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5tDNlPUH5ig-", "outputId": "0de413bb-9ad2-4d3f-fe16-ce348c8630f6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'starcoder2'...\n", "remote: Enumerating objects: 44, done.\u001b[K\n", "remote: Counting objects: 2% (1/44)\u001b[K\rremote: Counting objects: 4% (2/44)\u001b[K\rremote: Counting objects: 6% (3/44)\u001b[K\rremote: Counting objects: 9% (4/44)\u001b[K\rremote: Counting objects: 11% (5/44)\u001b[K\rremote: Counting objects: 13% (6/44)\u001b[K\rremote: Counting objects: 15% (7/44)\u001b[K\rremote: Counting objects: 18% (8/44)\u001b[K\rremote: Counting objects: 20% (9/44)\u001b[K\rremote: Counting objects: 22% (10/44)\u001b[K\rremote: Counting objects: 25% (11/44)\u001b[K\rremote: Counting objects: 27% (12/44)\u001b[K\rremote: Counting objects: 29% (13/44)\u001b[K\rremote: Counting objects: 31% (14/44)\u001b[K\rremote: Counting objects: 34% (15/44)\u001b[K\rremote: Counting objects: 36% (16/44)\u001b[K\rremote: Counting objects: 38% 
(17/44)\u001b[K\rremote: Counting objects: 40% (18/44)\u001b[K\rremote: Counting objects: 43% (19/44)\u001b[K\rremote: Counting objects: 45% (20/44)\u001b[K\rremote: Counting objects: 47% (21/44)\u001b[K\rremote: Counting objects: 50% (22/44)\u001b[K\rremote: Counting objects: 52% (23/44)\u001b[K\rremote: Counting objects: 54% (24/44)\u001b[K\rremote: Counting objects: 56% (25/44)\u001b[K\rremote: Counting objects: 59% (26/44)\u001b[K\rremote: Counting objects: 61% (27/44)\u001b[K\rremote: Counting objects: 63% (28/44)\u001b[K\rremote: Counting objects: 65% (29/44)\u001b[K\rremote: Counting objects: 68% (30/44)\u001b[K\rremote: Counting objects: 70% (31/44)\u001b[K\rremote: Counting objects: 72% (32/44)\u001b[K\rremote: Counting objects: 75% (33/44)\u001b[K\rremote: Counting objects: 77% (34/44)\u001b[K\rremote: Counting objects: 79% (35/44)\u001b[K\rremote: Counting objects: 81% (36/44)\u001b[K\rremote: Counting objects: 84% (37/44)\u001b[K\rremote: Counting objects: 86% (38/44)\u001b[K\rremote: Counting objects: 88% (39/44)\u001b[K\rremote: Counting objects: 90% (40/44)\u001b[K\rremote: Counting objects: 93% (41/44)\u001b[K\rremote: Counting objects: 95% (42/44)\u001b[K\rremote: Counting objects: 97% (43/44)\u001b[K\rremote: Counting objects: 100% (44/44)\u001b[K\rremote: Counting objects: 100% (44/44), done.\u001b[K\n", "remote: Compressing objects: 100% (41/41), done.\u001b[K\n", "remote: Total 44 (delta 19), reused 9 (delta 2), pack-reused 0\u001b[K\n", "Receiving objects: 100% (44/44), 21.08 KiB | 3.51 MiB/s, done.\n", "Resolving deltas: 100% (19/19), done.\n" ] } ], "source": [ "!git clone https://github.com/bigcode-project/starcoder2.git" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 556, "referenced_widgets": [ "2f3feb1169a94ae1a9cc9c8a07025339", "67d7882371354f5384b34ca079ff4a7e", "8fa955bb306048708d08cb74bec63fa9", "868a7e6273fe48c1a74b09d04d20ffc7", 
"d9e34b5c6dff4f7cbc0d44e6cd5be9b7", "2d661189af574afa91c3413393ec9473", "a76ee308ac3743aba2930a44180d060d", "7a0233222a7b42e8942de013d9e24e3d", "e00e5bad7f94413989c937fa091302fd", "58eff06c7e21423a9aa2b95ef8d214da", "1487d69479034d33ab497dd89f09b2b5", "58089e7274a14f16b79560d0848220b2", "e1fb6135b1764582b2415de45e5b4ba6", "6cf33df65db94dfeb44fa3a98b66f36e", "dc9f05e0dcda49b3974172e63f035d19", "08c9e9fdbcf547b291e04c54d6017fd0", "42f60dd18e9344b3a11c690191813036", "7d2a6c9a82dd4952bda67daf9613400a", "c7bc78b2dd6f49c0b3a18696bf62875c", "289538f3b1c842629c175748757b39ea", "94c1d94166c34989a6377352e38052b7", "d50526ae51eb42b895684684da0e3fee", "22bcf3e361904aa782ab8a8dc5581182", "ca0e5eab40f548d69981dd36725c5231", "878ac121e7d14aa591a84408af4cc030", "f5d03c6fa484439d8b9845032d91b458", "4951af5d91814da2b15b1f581e930d93", "d0bce1de6c68452a953617f8ed05ac47", "a014618e7d5a41ada7ec516c320180e1", "ab8a4661717d4ccd9199fdad333d96d8", "7a07be7145d84ecc98e527aefbfe7497", "d503472502ab4cab92d52ea7818e4101", "7f39a657a30c4e4e887ab06d65c27127", "c5d32741646d4128b05342cc98566471", "2d91c0ad37824a659ac61fc6ac7d7205", "8bec6880b9e24c17a736441f36b5e343", "55f208052abc4caaaa981c65fc7b6fc3", "9a4b90b52b1744309a3c4d6277e3636c", "f820b30f34134404a6dba98be2cd3a0d", "795350c486cc4c3b9e131879eae902e1", "3d273a73cbff4cbd9d8b7e60eea7f723", "66125ad5301f42dea4d26592b6b7178e", "d55acc6a03a64c91b2af067976f47b51", "daaf12b7316943eca28d44c5e873eb78", "f4c1b562a4ee4120b4c64ab0916c14dd", "0d5f76cbb3454dae9d058459c9c11bcb", "e83d18063350487481bcb4b1dc821d14", "737406625d254033bd159113ca729dad", "701615d1abe649a68c5ff65c7a3d0faf", "f42bd79b1e6c4b2599d01c1fa891920a", "c33ae9e14d5c41b59f8b45161f6faa52", "3e38f7d06eb840e0b734f38d82b5c816", "c49886e0f7624260a972fadd7bf77a69", "eea077a9ee19459c86ee34d2b23900de", "5fd1bb46157c41748d4f061ab8d75bfb", "5d786630032e489c89e2b9cf7210123d", "33f1381a8bec4b948db3f06dc246ad07", "154f247e92cd4b7cbfa4f90ceca4f1d5", "df8be82050774e458f92e19a84faea58", 
"07bbb73ea5a34b72848ffff4f6eb62eb", "6eae916351f34239bf00f8cc08b85b7b", "7583943dc9144252b515e80415f53117", "51e96f64d31844248bab6ef35c3bb336", "7e237dccec4148f9b01b657c64f11eca", "d6be2d40b83b4474902ff541851a2130", "9ddff15fea8944a29b203f7362d6d315", "2b2a93e5eeae49f3a5ad89be3e6ec2e6", "7031271458bd4612a80c9bd6d54990b5", "01f1e79284e64975b50b8cb5622368e6", "2773fbb3f41048b2b9d177edfb15d842", "6adeeb7ea0284fd29b8df2b05fb331f5", "028683b9eade4b98b80ba6492765e5d8", "b49472e20ba04ce88350f757acf2cf1d", "4f1d66b394f546909e237ad17c61fdd6", "cf12ede2afe244b9a7f302a13d8b9431", "24107f323d994281ab99b310290fb7b7", "dec5fc7d3d4c495398efb2217dba2569" ] }, "id": "6Wmtztbk5rDG", "outputId": "319874de-a453-438f-aaee-cc7b30ea4806" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2f3feb1169a94ae1a9cc9c8a07025339", "version_major": 2, "version_minor": 0 }, "text/plain": [ "tokenizer_config.json: 0%| | 0.00/7.88k [00:00=1.4.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (2.2.1+cu121)\n", "Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (4.40.0.dev0)\n", "Requirement already satisfied: numpy>=1.18.2 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (1.25.2)\n", "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages 
(from trl==0.8.2.dev0) (0.28.0)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (2.18.0)\n", "Requirement already satisfied: tyro>=0.5.11 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (0.7.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (3.13.3)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (2023.6.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from 
torch>=1.4.0->trl==0.8.2.dev0) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->trl==0.8.2.dev0) (12.4.99)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (0.20.3)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in 
/usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (0.15.2)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (0.4.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (4.66.2)\n", "Requirement already satisfied: docstring-parser>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.8.2.dev0) (0.16)\n", "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.8.2.dev0) (13.7.1)\n", "Requirement already satisfied: shtab>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.8.2.dev0) (1.7.1)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate->trl==0.8.2.dev0) (5.9.5)\n", "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (14.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (1.5.3)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (0.70.16)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (3.9.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from 
aiohttp->datasets->trl==0.8.2.dev0) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (4.0.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (2024.2.2)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl==0.8.2.dev0) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl==0.8.2.dev0) (2.16.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4.0->trl==0.8.2.dev0) (2.1.5)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages 
(from pandas->datasets->trl==0.8.2.dev0) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl==0.8.2.dev0) (2023.4)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.4.0->trl==0.8.2.dev0) (1.3.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl==0.8.2.dev0) (0.1.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets->trl==0.8.2.dev0) (1.16.0)\n", "Building wheels for collected packages: trl\n", " Building wheel for trl (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for trl: filename=trl-0.8.2.dev0-py3-none-any.whl size=238373 sha256=ffa2f194d60abca5186380c50afb3496403e9a219629e4294a7c3dc041069957\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-bvjt6eyh/wheels/22/0e/42/319b77b2648bb6140ef2b08b0478ede9ca3cc7879fcd022d36\n", "Successfully built trl\n", "Installing collected packages: trl\n", " Attempting uninstall: trl\n", " Found existing installation: trl 0.8.1\n", " Uninstalling trl-0.8.1:\n", " Successfully uninstalled trl-0.8.1\n", "Successfully installed trl-0.8.2.dev0\n", "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.10.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (24.0)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.2.1+cu121)\n", 
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.40.0.dev0)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.2)\n", "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.28.0)\n", "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.2)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.20.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (3.13.3)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2023.6.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2.31.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (4.10.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.3)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: 
nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.19.3)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.2.0)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.13.0->peft) (12.4.99)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.12.25)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.15.2)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.5)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from 
requests->huggingface-hub>=0.17.0->peft) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2024.2.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n" ] } ], "source": [ "!pip install git+https://github.com/huggingface/trl.git\n", "!pip install peft" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "id": "9nDiRQ-D7FV_" }, "outputs": [], "source": [
 "import argparse\n",
 "import multiprocessing\n",
 "import os\n",
 "\n",
 "import torch\n",
 "import transformers\n",
 "from accelerate import PartialState\n",
 "from datasets import load_dataset\n",
 "from peft import LoraConfig\n",
 "from transformers import (\n",
 "    AutoModelForCausalLM,\n",
 "    BitsAndBytesConfig,\n",
 "    logging,\n",
 "    set_seed,\n",
 ")\n",
 "from trl import SFTTrainer"
 ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "dJlPQ31k7OeH" }, "outputs": [], "source": [
 "# Training configuration. `parse_args(args=[])` parses an empty list instead of\n",
 "# sys.argv (which holds the Jupyter kernel's own flags), so every value below\n",
 "# comes from its `default=`. Boolean switches use BooleanOptionalAction\n",
 "# (--fp16 / --no-fp16): with `type=bool`, any non-empty string, even \"False\",\n",
 "# would parse as True.\n",
 "parser = argparse.ArgumentParser()\n",
 "parser.add_argument(\"--model_id\", type=str, default=\"bigcode/starcoder2-3b\")\n",
 "parser.add_argument(\"--dataset_name\", type=str, default=\"bigcode/the-stack-smol\")\n",
 "parser.add_argument(\"--subset\", type=str, default=\"data/html\")\n",
 "parser.add_argument(\"--split\", type=str, default=\"train\")\n",
 "parser.add_argument(\"--dataset_text_field\", type=str, default=\"content\")\n",
 "\n",
 "parser.add_argument(\"--max_seq_length\", type=int, default=260)\n",
 "parser.add_argument(\"--max_steps\", type=int, default=500)\n",
 "parser.add_argument(\"--micro_batch_size\", type=int, default=1)\n",
 "parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=4)\n",
 "parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n",
 "parser.add_argument(\"--fp16\", action=argparse.BooleanOptionalAction, default=True)\n",
 "\n",
 "parser.add_argument(\"--attention_dropout\", type=float, default=0.1)\n",
 "parser.add_argument(\"--learning_rate\", type=float, default=2e-4)\n",
 "parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine\")\n",
 "parser.add_argument(\"--warmup_steps\", type=int, default=100)\n",
 "parser.add_argument(\"--seed\", type=int, default=0)\n",
 "parser.add_argument(\"--output_dir\", type=str, default=\"Trisha_StarCoder2_HTML\")\n",
 "# CPU worker processes used by load_dataset (not GPU-related); 2 suits a small Colab VM\n",
 "parser.add_argument(\"--num_proc\", type=int, default=2)\n",
 "parser.add_argument(\"--push_to_hub\", action=argparse.BooleanOptionalAction, default=True)\n",
 "\n",
 "args = parser.parse_args(args=[])"
 ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VjfP9QhY7gNe", "outputId": "5f1a4281-7aa8-4706-8be6-80b24946d26b" }, "outputs": [ { "data": { "text/plain": [ "Namespace(model_id='bigcode/starcoder2-3b', dataset_name='bigcode/the-stack-smol', subset='data/html', split='train', dataset_text_field='content', max_seq_length=260, max_steps=500, micro_batch_size=1, gradient_accumulation_steps=4, weight_decay=0.01, fp16=True, attention_dropout=0.1, learning_rate=0.0002, lr_scheduler_type='cosine', warmup_steps=100, seed=0, output_dir='Trisha_StarCoder2_HTML', num_proc=2, push_to_hub=True)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "args" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "-9lrGDar7h-9" }, "outputs": [], "source": [
 "import locale\n",
 "\n",
 "# Force UTF-8 as the preferred encoding for the rest of the session.\n",
 "# NOTE(review): presumably a workaround for Colab's \"A UTF-8 locale is required\"\n",
 "# error on later `!` shell commands — confirm it is still needed.\n",
 "def getpreferredencoding(do_setlocale = True):\n",
 "    return \"UTF-8\"\n",
 "\n",
 "locale.getpreferredencoding = getpreferredencoding"
 ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "id": "HmbEI81R7qtP" }, "outputs": [],
"source": [
 "def print_trainable_parameters(model):\n",
 "    \"\"\"\n",
 "    Prints the number of trainable parameters in the model.\n",
 "\n",
 "    bitsandbytes packs 4-bit weights into `Params4bit` tensors whose `numel()`\n",
 "    counts storage elements rather than weights (two 4-bit values per byte), so\n",
 "    those counts are scaled back up; otherwise a 4-bit model reports roughly half\n",
 "    its real size. Call it on the PEFT-wrapped model to see only the adapter as\n",
 "    trainable.\n",
 "    \"\"\"\n",
 "    trainable_params = 0\n",
 "    all_param = 0\n",
 "    for _, param in model.named_parameters():\n",
 "        num_params = param.numel()\n",
 "        if param.__class__.__name__ == \"Params4bit\":\n",
 "            # bytes per storage element * two 4-bit weights per byte\n",
 "            num_params *= 2 * param.element_size()\n",
 "        all_param += num_params\n",
 "        if param.requires_grad:\n",
 "            trainable_params += num_params\n",
 "    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0\n",
 "    print(\n",
 "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}\"\n",
 "    )"
 ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "IB7SCdav7rqF" }, "outputs": [], "source": [
 "def main(args):\n",
 "    \"\"\"Fine-tune `args.model_id` with 4-bit QLoRA on the configured dataset.\n",
 "\n",
 "    Trains for `args.max_steps` steps with SFTTrainer, saves the LoRA adapter to\n",
 "    `{output_dir}/final_checkpoint/` and, if `args.push_to_hub`, pushes the\n",
 "    trainer output to the Hugging Face Hub.\n",
 "    \"\"\"\n",
 "    # Seed before the model is wrapped: SFTTrainer initialises the LoRA weights\n",
 "    # when it applies `peft_config`, so seeding here makes that init reproducible.\n",
 "    set_seed(args.seed)\n",
 "\n",
 "    # config\n",
 "    # Match the 4-bit compute dtype to the AMP mode: fp16 training (e.g. on a T4,\n",
 "    # which has no native bfloat16 support) should not compute in bfloat16.\n",
 "    compute_dtype = torch.float16 if args.fp16 else torch.bfloat16\n",
 "    bnb_config = BitsAndBytesConfig(\n",
 "        load_in_4bit=True,\n",
 "        bnb_4bit_quant_type=\"nf4\",\n",
 "        bnb_4bit_compute_dtype=compute_dtype,\n",
 "    )\n",
 "    lora_config = LoraConfig(\n",
 "        r=8,\n",
 "        target_modules=[\n",
 "            \"q_proj\",\n",
 "            \"o_proj\",\n",
 "            \"k_proj\",\n",
 "            \"v_proj\",\n",
 "            \"gate_proj\",\n",
 "            \"up_proj\",\n",
 "            \"down_proj\",\n",
 "        ],\n",
 "        task_type=\"CAUSAL_LM\",\n",
 "    )\n",
 "\n",
 "    # load model and dataset\n",
 "    token = os.environ.get(\"HF_TOKEN\", None)\n",
 "    model = AutoModelForCausalLM.from_pretrained(\n",
 "        args.model_id,\n",
 "        quantization_config=bnb_config,\n",
 "        device_map={\"\": PartialState().process_index},\n",
 "        attention_dropout=args.attention_dropout,\n",
 "        token=token,\n",
 "    )\n",
 "\n",
 "    data = load_dataset(\n",
 "        args.dataset_name,\n",
 "        data_dir=args.subset,\n",
 "        split=args.split,\n",
 "        token=token,\n",
 "        num_proc=args.num_proc or multiprocessing.cpu_count(),\n",
 "    )\n",
 "\n",
 "    # setup the trainer; because `peft_config` is given, SFTTrainer wraps `model`\n",
 "    # in a PeftModel and exposes the wrapped model as `trainer.model`\n",
 "    trainer = SFTTrainer(\n",
 "        model=model,\n",
 "        train_dataset=data,\n",
 "        max_seq_length=args.max_seq_length,\n",
 "        args=transformers.TrainingArguments(\n",
 "            per_device_train_batch_size=args.micro_batch_size,\n",
 "            gradient_accumulation_steps=args.gradient_accumulation_steps,\n",
 "            warmup_steps=args.warmup_steps,\n",
 "            max_steps=args.max_steps,\n",
 "            learning_rate=args.learning_rate,\n",
 "            lr_scheduler_type=args.lr_scheduler_type,\n",
 "            weight_decay=args.weight_decay,\n",
 "            fp16=args.fp16,\n",
 "            logging_strategy=\"steps\",\n",
 "            logging_steps=10,\n",
 "            output_dir=args.output_dir,\n",
 "            optim=\"paged_adamw_8bit\",\n",
 "            seed=args.seed,\n",
 "            run_name=f\"train-{args.model_id.split('/')[-1]}\",\n",
 "            report_to=\"wandb\",\n",
 "        ),\n",
 "        peft_config=lora_config,\n",
 "        dataset_text_field=args.dataset_text_field,\n",
 "    )\n",
 "    # Count on the wrapped model so only the LoRA adapter shows as trainable.\n",
 "    print_trainable_parameters(trainer.model)\n",
 "\n",
 "    # launch\n",
 "    print(\"Training...\")\n",
 "    trainer.train()\n",
 "\n",
 "    print(\"Saving the last checkpoint of the model\")\n",
 "    # Save via trainer.model (the PeftModel) so only the LoRA adapter is written;\n",
 "    # the bare `model` handle is the base model and writes a multi-GB\n",
 "    # model.safetensors into output_dir, which push_to_hub would then upload.\n",
 "    trainer.model.save_pretrained(os.path.join(args.output_dir, \"final_checkpoint/\"))\n",
 "    if args.push_to_hub:\n",
 "        trainer.push_to_hub(\"Upload model\")\n",
 "    print(\"Finetuning Completed! 💥\")"
 ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "812fe30232e54288925c45d6afccab40", "12a34cf1ef314151b0e1ff02409aa13f", "6f9dd57dbb7c4381acd00c971a397d3e", "9c69f684a11c4dd2b67d8b9ea33526a4", "526bca7b47964979bb48da5a789cef1a", "f8dab055c4884770aee67d734a790fb6", "ce8c71126884425d8b68a2f596215f91", "a8d9fe2fff524fbaadaa797b5f5a36fd", "2c2c04246e524dd2bc2d42ed747a87a3", "8e10f8be1a0644669192f70deb60033a", "6e652e4f78db4b91a53c287fda5c9dfa", "bffec89db1754c0bb92ce19dc565d9e1", "6c96a011faad4d2d9b699da0d0cc5811", "29bacbb3352f4af6a0f5eaa83cc34e37", "84da53eb21904acd97236cf0e2b0b656", "1b4624099ca047c0841f4066f5951247", "12f7afb8996d415788791bd67a491843", "01ca3249f1e04926b2eceb0d40695f4e", "4132701b8a324cf580279d2f1391009c", "34ce7a0f16ca4686893b03ae2d492a1e", "95386b5600c54e58b3287faa2a6e808f", "830de09f69434f3cb0dabd20325d489d", "25307eb5938a4196bf2ba52c733f432f", "d7de8b2a512845beb76dafaf286bcc6d", "182e0a42e18f4817a7127106dede449b", "ef0c036aeeeb48b5a849c0e61e2b18df", "ae4f0c4953984756af72cb71c5d3e580", "212b6fa9dc28483fa9677d60afd7b13d",
"f0d248ad73964a8db6f249471e38329b", "6ccbd35589fc4ce28ab9e35514373dac", "120368ef5e734367b8ed92664652db66", "d399680217624d6db12d4d0f762f8d5d", "9c0aee65a76f4b8e885a64a1642bacc1", "2e7e51152e194fb1995be1873c44ff02", "b2f38b5029264fffae63c728f73ff5b5", "4f1bfa5f84784a48b33ccc1b1d2a77d7", "e2ccaf5bf91c4029a180a3deb932c7e5", "c03def91475343b38c7e1fa5949c27d3", "2a9eb0aed91840afa3af51897716fbe5", "8972d16220994379ab91647b2c068d18", "481a30187f0843e2894e09bf87bd45ea", "1160e809b79a4b9b806cb1182837ba03", "f1c117aed8104724b06b1ab4f99dc547", "64c8778757b74caba5d15f3e331b1459", "77f1135856204da69cb8b76029444a76", "43ccd5be73a04381a684d44d5daf43e6", "f6f1db8a8dc1402a9909fb5606bf26ee", "76d0c46a87ee40ee96087a7be844ce15", "fe35adc4c108457fab9d968c1c514173", "0e31e047caa24b86bc916ad5f3c86f0b", "a2be2b04563548e19251350edc3c700a", "81f4ef08a25a4a6c8edb8f5da1992bc6", "f5d2b55abac84ecea1c8a36b80a181f3", "d9c775da0ca5449da906a353305ad6db", "ef41220d98ae46d084fbfa77e6e91a23" ] }, "id": "5a81RLwA79kO", "outputId": "3e8a1f5b-ee4b-40fa-8225-3e66e4b54450" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trainable params: 151369728 || all params: 1591200768 || trainable%: 9.5129245186488\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "812fe30232e54288925c45d6afccab40", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/10000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.3" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /content/wandb/run-20240402_070510-cszj19um" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run train-starcoder2-3b to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/kiit-21051861/huggingface" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/kiit-21051861/huggingface/runs/cszj19um" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'loss': 2.3982, 'grad_norm': 0.7019279599189758, 'learning_rate': 2e-05, 'epoch': 0.0}\n", "{'loss': 2.4836, 'grad_norm': 1.1548597812652588, 'learning_rate': 4e-05, 'epoch': 0.01}\n", "{'loss': 2.7578, 'grad_norm': 1.14144766330719, 'learning_rate': 6e-05, 'epoch': 0.01}\n", "{'loss': 2.3468, 'grad_norm': 0.6640823483467102, 'learning_rate': 8e-05, 'epoch': 0.02}\n", "{'loss': 2.1096, 'grad_norm': 1.3852239847183228, 'learning_rate': 0.0001, 'epoch': 0.02}\n", "{'loss': 1.9092, 'grad_norm': 0.9558876752853394, 'learning_rate': 0.00012, 'epoch': 0.02}\n", "{'loss': 1.6677, 'grad_norm': 2.00368595123291, 'learning_rate': 0.00014, 'epoch': 0.03}\n", "{'loss': 1.4821, 'grad_norm': 1.3382819890975952, 'learning_rate': 0.00015800000000000002, 'epoch': 0.03}\n", "{'loss': 1.4721, 'grad_norm': 3.084728717803955, 'learning_rate': 0.00017800000000000002, 'epoch': 0.04}\n", "{'loss': 1.5986, 'grad_norm': 1.746261715888977, 'learning_rate': 0.00019800000000000002, 'epoch': 0.04}\n", "{'loss': 2.0097, 'grad_norm': 1.3398728370666504, 'learning_rate': 0.00019975027964162702, 'epoch': 0.04}\n", "{'loss': 1.514, 'grad_norm': 1.3517955541610718, 'learning_rate': 0.0001988886498744505, 'epoch': 0.05}\n", "{'loss': 1.4133, 'grad_norm': 0.9783771634101868, 'learning_rate': 0.00019741733869698495, 'epoch': 0.05}\n", "{'loss': 1.284, 'grad_norm': 4.398829460144043, 'learning_rate': 0.0001953454172319001, 'epoch': 0.06}\n", "{'loss': 1.4829, 'grad_norm': 1.1008867025375366, 'learning_rate': 0.00019268565956401208, 
'epoch': 0.06}\n", "{'loss': 1.1862, 'grad_norm': 1.4246320724487305, 'learning_rate': 0.0001894544639838025, 'epoch': 0.06}\n", "{'loss': 1.4006, 'grad_norm': 1.6339792013168335, 'learning_rate': 0.00018567175188650498, 'epoch': 0.07}\n", "{'loss': 1.1882, 'grad_norm': 1.117507815361023, 'learning_rate': 0.00018136084495007872, 'epoch': 0.07}\n", "{'loss': 1.3363, 'grad_norm': 2.1591148376464844, 'learning_rate': 0.00017654832134930882, 'epoch': 0.08}\n", "{'loss': 1.4555, 'grad_norm': 1.2217527627944946, 'learning_rate': 0.00017126385189252053, 'epoch': 0.08}\n", "{'loss': 1.1622, 'grad_norm': 4.409976005554199, 'learning_rate': 0.0001655400170911794, 'epoch': 0.08}\n", "{'loss': 1.1353, 'grad_norm': 1.0652554035186768, 'learning_rate': 0.00015941210629020388, 'epoch': 0.09}\n", "{'loss': 1.2232, 'grad_norm': 2.09572696685791, 'learning_rate': 0.00015291790009741907, 'epoch': 0.09}\n", "{'loss': 1.0229, 'grad_norm': 0.940676212310791, 'learning_rate': 0.00014609743745354624, 'epoch': 0.1}\n", "{'loss': 1.0389, 'grad_norm': 1.5075112581253052, 'learning_rate': 0.00013899276877881884, 'epoch': 0.1}\n", "{'loss': 1.1253, 'grad_norm': 1.4180775880813599, 'learning_rate': 0.00013164769671815862, 'epoch': 0.1}\n", "{'loss': 1.339, 'grad_norm': 1.5058542490005493, 'learning_rate': 0.00012410750608330388, 'epoch': 0.11}\n", "{'loss': 1.1879, 'grad_norm': 1.1217693090438843, 'learning_rate': 0.0001164186846568863, 'epoch': 0.11}\n", "{'loss': 1.0624, 'grad_norm': 1.1402180194854736, 'learning_rate': 0.00010862863657979237, 'epoch': 0.12}\n", "{'loss': 1.0309, 'grad_norm': 1.4527842998504639, 'learning_rate': 0.00010078539008887114, 'epoch': 0.12}\n", "{'loss': 1.2608, 'grad_norm': 2.6497609615325928, 'learning_rate': 9.293730140688336e-05, 'epoch': 0.12}\n", "{'loss': 0.9981, 'grad_norm': 1.0792790651321411, 'learning_rate': 8.51327566103077e-05, 'epoch': 0.13}\n", "{'loss': 1.1313, 'grad_norm': 1.34031343460083, 'learning_rate': 7.741987331308964e-05, 'epoch': 0.13}\n", 
"{'loss': 1.1587, 'grad_norm': 1.6864031553268433, 'learning_rate': 6.984620400555044e-05, 'epoch': 0.14}\n", "{'loss': 1.2093, 'grad_norm': 1.1277486085891724, 'learning_rate': 6.245844287747168e-05, 'epoch': 0.14}\n", "{'loss': 1.1832, 'grad_norm': 1.079222559928894, 'learning_rate': 5.53021379328879e-05, 'epoch': 0.14}\n", "{'loss': 1.3527, 'grad_norm': 1.864291787147522, 'learning_rate': 4.842141017149526e-05, 'epoch': 0.15}\n", "{'loss': 1.1799, 'grad_norm': 2.1111037731170654, 'learning_rate': 4.185868156801694e-05, 'epoch': 0.15}\n", "{'loss': 1.2483, 'grad_norm': 2.539325475692749, 'learning_rate': 3.565441352662211e-05, 'epoch': 0.16}\n", "{'loss': 1.0669, 'grad_norm': 1.9620802402496338, 'learning_rate': 2.9846857422914433e-05, 'epoch': 0.16}\n", "{'loss': 1.1064, 'grad_norm': 1.5506763458251953, 'learning_rate': 2.4471818771481648e-05, 'epoch': 0.16}\n", "{'loss': 1.0647, 'grad_norm': 1.7290705442428589, 'learning_rate': 1.9562436472991552e-05, 'epoch': 0.17}\n", "{'loss': 1.0844, 'grad_norm': 0.6486456394195557, 'learning_rate': 1.5148978501849642e-05, 'epoch': 0.17}\n", "{'loss': 1.2124, 'grad_norm': 1.4418706893920898, 'learning_rate': 1.1258655294071685e-05, 'epoch': 0.18}\n", "{'loss': 1.121, 'grad_norm': 0.826423704624176, 'learning_rate': 7.915451985897382e-06, 'epoch': 0.18}\n", "{'loss': 1.1724, 'grad_norm': 1.253172755241394, 'learning_rate': 5.13998053744954e-06, 'epoch': 0.18}\n", "{'loss': 1.1511, 'grad_norm': 0.8592811822891235, 'learning_rate': 2.949352653145754e-06, 'epoch': 0.19}\n", "{'loss': 1.1563, 'grad_norm': 1.1218185424804688, 'learning_rate': 1.3570742823504567e-06, 'epoch': 0.19}\n", "{'loss': 1.2377, 'grad_norm': 1.4387160539627075, 'learning_rate': 3.7296235070587435e-07, 'epoch': 0.2}\n", "{'loss': 1.13, 'grad_norm': 0.7837572693824768, 'learning_rate': 3.0842355210336515e-09, 'epoch': 0.2}\n", "{'train_runtime': 1106.6799, 'train_samples_per_second': 1.807, 'train_steps_per_second': 0.452, 'train_loss': 1.3810003318786621, 
'epoch': 0.2}\n", "Saving the last checkpoint of the model\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bffec89db1754c0bb92ce19dc565d9e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "model.safetensors: 0%| | 0.00/2.25G [00:00