{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "0-7S1J6Jq7nc" }, "source": [ "# Fine-Tuning BERT as a `RewardModel`\n", "\n", "1. First, install `transformers`, `trl`, and `codecarbon`." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Fx7pg9eT62-d", "outputId": "00506e81-1d52-4d25-a90e-547d2a923e1d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/7.5 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.1/7.5 MB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:04\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/7.5 MB\u001b[0m \u001b[31m45.3 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m80.4 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.5/7.5 MB\u001b[0m \u001b[31m59.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m110.0/110.0 kB\u001b[0m \u001b[31m17.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.8/179.8 kB\u001b[0m \u001b[31m24.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m35.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m116.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m251.2/251.2 kB\u001b[0m \u001b[31m35.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.3/519.3 kB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.4/66.4 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m8.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m25.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m18.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip install transformers trl codecarbon -q" ] }, { "cell_type": "markdown", "metadata": { "id": "Y6xzGtxPrMaF" }, "source": [ "2. Download the `reward-aira-dataset-comparisons` dataset from the Hub."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 774, "referenced_widgets": [ "f45e7c737c464f0d959184ceb5398af8", "c3c0695f98784127afc9271fa8e76b04", "caa5b35f5d484342a1f3c77dfd745e15", "c0bfc0285996404eb6e270a9cedb8a98", "6e865ee06fec4728979476624379cbd4", "b7e5236ea8d04870936a007b49730d6a", "77c59c65323043dfacf4d53513502e00", "2c04534f8ee4478c8b4235027e02986a", "671a5338d9ad4796b87523285ca101f5", "d5e32441b5534c6d92627ff568887312", "67df59f69c474f2aa7125564ebd0a523", "355a0bcdd553414eaea64c04d1149697", "01ab2ce6e77c4fd891d79a1a783ad7c1", "e1c692274d624d89b9392a52e8c3446b", "39bdf3ea275e439bb047bd34cfcbe416", "390a29abcd5e49a9b53cd6197f00bafc", "321ca13a13aa406dbccb8780a4cec29c", "64001368d68545579af007d59a09d03e", "6499256b11a940bc8d72b65b521015e6", "277f6e6884574dd688296f477f21edef", "8adc0992233c43f782a41065381829e5", "e638493411d24afe8e81ca2bda175a6b", "6389c5cbe33043ceb9891e77e8593d06", "5c7a7880cc734ed79b6c1ff6de9c4036", "3a6cd49593ee4647af4262f15852e35b", "61d0076b4ef54ea3b878fb2db3930f59", "ab2651d737854c84bba9b90b8947aced", "faa296ab1e7d4974aaa1be7e9389cf61", "7798407a6bc64199bc9f96ff7fa6f5c2", "cd4916e919074fdcba49311ee4a65631", "e9e2d8f878b94b64a26d1dd61f4a2a5b", "afd5ee16af434495ae4ef9c4f1009d20", "1c9eaf27d6f34cff887f6e7801226257", "d5cb1261ea1e437fa84da0d7e4422d16", "abd21429017440368bef2ac9f054079c", "b12d59aa4f0244b1a35fc63f086467ee", "078515a012b94209bafa6b20826785d4", "3a1406a921ff4e11b2b25cad3eb6f24e", "228cb5e6187145bfa93c618d31371124", "69461a0331784fa78dcc4557b5cbce6c", "79bd1b6112c54829af0e6d5c970d6a4d", "beb06a39b975411dbb879d59ca02acb4", "12bcd35c01ba4599b6780582ebc40f11", "b4d40262a5e44b96ae2a79ec9b3f2a0e", "69c1f7061d664dd3976502a8a21ce30e", "61e1e5a51c7e4ed7ba9cd3052fcccfa4", "73719a2b12b9414d8b106bd38afce4f7", "fca62001171743669370fa043fef4b67", "3e4116671a434acebfc84035b5756c89", "feb4735f216b4e1297a941cad1dbc141", "1bb2df10968240e09efcd2fd10a3a208", 
"156f0254cf2c459d9a0efeac9f5de5c1", "e9db9cd5b17f4bf39d3946986956c335", "6fd00edfc3fe473f91f36c3a8b2ce900", "397f851d00934dd7b71fd6f9538c1c21" ] }, "id": "DtCgCgEr62C9", "outputId": "5ca419b2-07e0-4317-f8fd-41b1823cf302" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f45e7c737c464f0d959184ceb5398af8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading readme: 0%| | 0.00/464 [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "355a0bcdd553414eaea64c04d1149697", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data files: 0%| | 0/1 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6389c5cbe33043ceb9891e77e8593d06", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading data: 0%| | 0.00/31.2M [00:00, ?B/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5cb1261ea1e437fa84da0d7e4422d16", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Extracting data files: 0%| | 0/1 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "69c1f7061d664dd3976502a8a21ce30e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating english split: 0%| | 0/32675 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Dataset loaded.\n" ] }, { "data": { "text/html": [ "\n", "
\n", " | instruction | \n", "chosen_response | \n", "rejected_response | \n", "
---|---|---|---|
0 | \n", "I'm going to be in Rome next month and I would... | \n", "That's great to hear that you'll be visiting R... | \n", "Rome is a city filled with incredible attracti... | \n", "
1 | \n", "How do I determine the value of an antique item. | \n", "Determining the value of an antique item can b... | \n", "To determine the value of an antique item, res... | \n", "
2 | \n", "Does feature engineering always improve model ... | \n", "Feature engineering is a critical step in mach... | \n", "To optimize the performance of machine learnin... | \n", "
3 | \n", "What is the difference between an oval and a c... | \n", "An oval is a stretched shape that is longer in... | \n", "The main difference between an oval and a circ... | \n", "
4 | \n", "What is the history of the Suez Canal. | \n", "The Suez Canal is an artificial sea-level wate... | \n", "The Suez Canal is an artificial sea - level wa... | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
32670 | \n", "Is it true that the Earth’s core is made of iron. | \n", "Yes, it is generally accepted that the Earth's... | \n", "Yes, it is true that the Earth's core is made ... | \n", "
32671 | \n", "I'm doing a science experiment but I can't see... | \n", "Of course! I'd be happy to help you troublesho... | \n", "Yes, I can certainly help you troubleshoot the... | \n", "
32672 | \n", "How much exercise do I need to do each day. | \n", "The amount of exercise you need to do each day... | \n", "It depends on your lifestyle and fitness goals... | \n", "
32673 | \n", "What are the most common mistakes in financial... | \n", "There are several common mistakes that people ... | \n", "Not budgeting for unexpected expenses. 2 Not p... | \n", "
32674 | \n", "I'm planning a camping trip for this weekend, ... | \n", "Of course! I'd be happy to help you find some ... | \n", "I suggest researching campsites that are close... | \n", "
32675 rows × 3 columns
\n", "Step | \n", "Training Loss | \n", "
---|---|
200 | \n", "0.042300 | \n", "
400 | \n", "0.008500 | \n", "
600 | \n", "0.006900 | \n", "
800 | \n", "0.005000 | \n", "
1000 | \n", "0.001200 | \n", "
1200 | \n", "0.000900 | \n", "
"
],
"text/plain": [
"