{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "### Data Preparation" ], "metadata": { "id": "ga8c1nhja4Qy" } }, { "cell_type": "code", "source": [ "!pip install opendatasets" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O7NczD5abI6o", "outputId": "422faa21-1ee0-4582-9315-4c2b01f4518d" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting opendatasets\n", " Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from opendatasets) (4.66.5)\n", "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (from opendatasets) (1.6.17)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from opendatasets) (8.1.7)\n", "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (1.16.0)\n", "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2024.8.30)\n", "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.8.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.32.3)\n", "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (8.0.4)\n", "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.0.7)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (6.1.0)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle->opendatasets) (0.5.1)\n", "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle->opendatasets) (1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.10)\n", "Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)\n", "Installing collected packages: opendatasets\n", "Successfully installed opendatasets-0.1.22\n" ] } ] }, { "cell_type": "code", "source": [ "import opendatasets as od\n", "od.download('https://www.kaggle.com/datasets/hassaanidrees/medinfo?select=MedInfo2019-QA-Medications.xlsx')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7QSxa8cRbIug", "outputId": "088ef3d5-b3fc-4860-8928-bb872ff83ab5" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Dataset URL: https://www.kaggle.com/datasets/hassaanidrees/medinfo\n", "Downloading medinfo.zip to ./medinfo\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 159k/159k [00:00<00:00, 480kB/s]" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "\n" ] } ] }, { "cell_type": "code", "source": [ "# Import pandas for data 
analysis\n", "import pandas as pd\n", "df = pd.read_excel(\"/content/medinfo/MedInfo2019-QA-Medications.xlsx\")\n", "df = df[['Question','Answer']]" ], "metadata": { "id": "sooD64r3bIDJ" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "df.head() #show first five rows" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "eRneQPLAqAJL", "outputId": "d1772f7e-8edd-4687-9c1a-c3102e86138e" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Question \\\n", "0 how does rivatigmine and otc sleep medicine in... \n", "1 how does valium affect the brain \n", "2 what is morphine \n", "3 what are the milligrams for oxycodone e \n", "4 81% aspirin contain resin and shellac in it. ? \n", "\n", " Answer \n", "0 tell your doctor and pharmacist what prescript... \n", "1 Diazepam is a benzodiazepine that exerts anxio... \n", "2 Morphine is a pain medication of the opiate fa... \n", "3 … 10 mg … 20 mg … 40 mg … 80 mg ... \n", "4 Inactive Ingredients Ingredient Name " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
QuestionAnswer
0how does rivatigmine and otc sleep medicine in...tell your doctor and pharmacist what prescript...
1how does valium affect the brainDiazepam is a benzodiazepine that exerts anxio...
2what is morphineMorphine is a pain medication of the opiate fa...
3what are the milligrams for oxycodone e… 10 mg … 20 mg … 40 mg … 80 mg ...
481% aspirin contain resin and shellac in it. ?Inactive Ingredients Ingredient Name
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df", "summary": "{\n \"name\": \"df\",\n \"rows\": 690,\n \"fields\": [\n {\n \"column\": \"Question\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 651,\n \"samples\": [\n \"how is marijuana used\",\n \"tudorza pressair is what schedule drug\",\n \"how long does ecstasy or mda leave your body\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Answer\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 652,\n \"samples\": [\n \"Marijuana is best known as a drug that people smoke or eat to get high. It is derived from the plant Cannabis sativa. Possession of marijuana is illegal under federal law. Medical marijuana refers to using marijuana to treat certain medical conditions. In the United States, about half of the states have legalized marijuana for medical use.\",\n \"Color - GRAY, Shape - CAPSULE (biconvex), Score - no score, Size - 12mm, Imprint Code - m10\",\n \"Quantity: 60; Per Unit: $4.68 \\u2013 $15.91; Price: $280.99 \\u2013 $954.47\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 4 } ] }, { "cell_type": "code", "source": [ "df.Question[0]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "4SEkJJHwqBwo", "outputId": "7aeec0ad-b51a-44fa-f2e1-5a93b61246d5" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'how does rivatigmine and otc sleep medicine interact'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "code", "source": [ "df.Answer[0]" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 105 }, "id": "qTllg8a-qGXW", "outputId": "a6b8bca7-135e-4e26-e0ff-a2a1424bc45c" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "\"tell your doctor and pharmacist what prescription and nonprescription medications, vitamins, nutritional supplements, and herbal products you are taking or plan to take. Be sure to mention any of the following: antihistamines; aspirin and other nonsteroidal anti-inflammatory medications (NSAIDs) such as ibuprofen (Advil, Motrin) and naproxen (Aleve, Naprosyn); bethanechol (Duvoid, Urecholine); ipratropium (Atrovent, in Combivent, DuoNeb); and medications for Alzheimer's disease, glaucoma, irritable bowel disease, motion sickness, ulcers, or urinary problems. 
Your doctor may need to change the doses of your medications or monitor you carefully for side effects.\"" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 6 } ] }, { "cell_type": "code", "source": [ "df.shape # 690 rows | 2 cols" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xs_qECG1qIW5", "outputId": "678a409c-9164-48f4-803e-501d3dff3c96" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(690, 2)" ] }, "metadata": {}, "execution_count": 7 } ] }, { "cell_type": "code", "source": [ "!pip install cleantext" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "LPvkkbdbrNp-", "outputId": "938e6a8d-fb4b-4112-9a0e-3139146e56eb" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting cleantext\n", " Downloading cleantext-1.1.4-py3-none-any.whl.metadata (3.5 kB)\n", "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from cleantext) (3.8.1)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (8.1.7)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (1.4.2)\n", "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (2024.5.15)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (4.66.5)\n", "Downloading cleantext-1.1.4-py3-none-any.whl (4.9 kB)\n", "Installing collected packages: cleantext\n", "Successfully installed cleantext-1.1.4\n" ] } ] }, { "cell_type": "code", "source": [ "import cleantext\n", "\n", "# Function to clean text data by removing unwanted characters and formatting\n", "def clean(textdata):\n", " cleaned_text = []\n", " for i in textdata:\n", " cleaned_text.append(cleantext.clean(str(i), extra_spaces=True, lowercase=True, stopwords=False, stemming=False, numbers=True, punct=True, clean_all = True))\n", "\n", " return cleaned_text" ], "metadata": { "id": "dws3d49Lqv1b" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Apply the clean function to the questions and answers columns\n", "\n", "df.Question = list(clean(df.Question))\n", "df.Answer = list(clean(df.Answer))" ], "metadata": { "id": "H1ia-jFqrIsG" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Save the cleaned data into a new CSV file & save\n", "df.to_csv(\"cleaned_med_QA_data.csv\", index=False)" ], "metadata": { "id": "HcB15JQirImk" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### GPT-2 Model" ], "metadata": { "id": "zw5mkpmueML4" } }, { "cell_type": "code", "source": [ "!pip install datasets" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "QhgGKgZ-rYAY", "outputId": "f2334a48-2745-42b5-f5fd-929ca58e1ed6", "collapsed": true }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting datasets\n", " Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.0)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", "Collecting pyarrow>=15.0.0 (from datasets)\n", " Downloading 
pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", "Collecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n", "Requirement already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.6)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.8)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from 
python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Downloading datasets-3.0.0-py3-none-any.whl (474 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m474.3/474.3 kB\u001b[0m \u001b[31m32.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n", " Attempting uninstall: pyarrow\n", " Found existing installation: pyarrow 14.0.2\n", " Uninstalling pyarrow-14.0.2:\n", " Successfully uninstalled pyarrow-14.0.2\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n", "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed datasets-3.0.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.5.0\n" ] }, { "output_type": "display_data", "data": { "application/vnd.colab-display-data+json": { "pip_warning": { "packages": [ "pyarrow" ] }, "id": "a6cd6efad93b4c4cb5a29a91b023de8a" } }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments\n", "import torch\n", "from datasets import load_dataset\n", "\n", "# Load the GPT-2 model and tokenizer\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n", "model = GPT2LMHeadModel.from_pretrained('gpt2')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "xgGgvCqerk-1", "outputId": "e338ee7f-c898-41c4-b1f6-036f115d3735" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n" ] } ] }, { "cell_type": "code", "source": [ "# Set the padding token for the tokenizer to be the end-of-sequence token\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "# Maximum sequence length that GPT-2 can handle\n", "max_length = tokenizer.model_max_length\n", "print(max_length)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EeiMYkpCrp62", "outputId": "e8b0118b-1694-4d9e-d666-e791b083f63f" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "1024\n" ] } ] }, { "cell_type": "code", "source": [ "# Load the cleaned QA dataset as a training set using the 'datasets' library\n", "dataset = load_dataset('csv', data_files={'train': 'cleaned_med_QA_data.csv'}, split='train')" ], "metadata": { "id": "MW5Ad0exrry3" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#Function to tokenize questions and answers and prepare them for the model\n", "def tokenize_function(examples):\n", " '''1. Combine each question and answer into a single input string\n", " 2. Tokenize the combined text using the GPT-2 tokenizer\n", " 3. Set the labels to be the same as the input_ids (shifted to predict the next word)\n", " 4. Return the tokenized output. '''\n", "\n", " combined_text = [str(q) + \" \" + str(a) for q, a in zip(examples['Question'], examples['Answer'])]\n", " tokenized_output = tokenizer(combined_text, padding='max_length', truncation=True, max_length=128)\n", "\n", " # Set the labels to be the same as the input_ids (shifted to predict the next word)\n", " tokenized_output['labels'] = tokenized_output['input_ids'].copy()\n", "\n", " return tokenized_output\n", "\n", "# Tokenize the entire dataset\n", "tokenized_dataset = dataset.map(tokenize_function, batched=True)" ], "metadata": { "id": "99rfOROKr-M0" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Define training arguments for the GPT-2 model\n", "training_args = TrainingArguments(\n", " output_dir='./results', # Directory to save model outputs\n", " num_train_epochs=20, # Train for 50 epochs\n", " per_device_train_batch_size=16, # Batch size during training\n", " per_device_eval_batch_size=32, # Batch size during evaluation\n", " warmup_steps=500, # Warmup steps for learning rate scheduler\n", " weight_decay=0.01, # Weight decay for regularization\n", " logging_dir='./logs', # Directory for saving logs\n", " logging_steps=10, # Log every 10 steps\n", " save_steps=1000, # Save model checkpoints every 1000 steps\n", ")\n", "\n", "# Trainer class to handle training process\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_dataset,\n", " tokenizer=tokenizer,\n", ")\n", "\n", "# Train the model\n", "trainer.train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "TQGJ16yJsCBc", "outputId": "ec5b1ae4-83c1-4117-95fe-3aae63fc0f75", "collapsed": true }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "
\n", " \n", " \n", " [880/880 08:45, Epoch 20/20]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
105.891800
205.497900
304.671300
403.751500
503.016000
602.633300
702.360800
802.079000
902.145600
1002.150100
1102.069300
1202.000300
1301.919900
1401.954000
1501.928500
1601.832900
1701.921300
1802.043500
1901.827400
2001.687700
2101.782400
2201.959600
2301.810500
2401.706800
2501.662200
2601.783900
2701.567300
2801.695100
2901.681800
3001.657400
3101.684000
3201.494700
3301.556800
3401.648300
3501.529300
3601.421200
3701.483900
3801.588400
3901.442200
4001.524600
4101.469100
4201.412900
4301.388300
4401.414400
4501.368200
4601.374900
4701.336500
4801.294900
4901.231700
5001.287600
5101.248500
5201.220700
5301.335700
5401.094200
5501.151400
5601.215000
5701.235600
5801.139800
5901.119600
6001.148000
6101.057300
6201.039700
6301.081300
6400.960300
6501.026400
6601.049900
6700.967600
6800.902100
6900.950900
7000.998500
7101.043500
7200.877700
7300.818800
7400.949500
7501.032200
7600.813600
7700.871600
7800.877400
7900.952400
8000.819600
8100.852700
8200.848300
8300.834200
8400.900900
8500.830800
8600.864700
8700.842200
8800.865000

" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=880, training_loss=1.5622584277933294, metrics={'train_runtime': 525.9662, 'train_samples_per_second': 26.237, 'train_steps_per_second': 1.673, 'total_flos': 901457510400000.0, 'train_loss': 1.5622584277933294, 'epoch': 20.0})" ] }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "code", "source": [ "# Save the model\n", "trainer.save_model('med_info_model')" ], "metadata": { "id": "4UrH8iP0u6Cp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Testing" ], "metadata": { "id": "VhXRJT6jeTuz" } }, { "cell_type": "code", "source": [ "# Function to generate a response based on a user prompt (testing the model)\n", "def generate_response(prompt):\n", " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to('cuda')\n", " outputs = model.generate(inputs, max_length=150, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)\n", "\n", " # Decode the generated output\n", " response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n", "\n", " # Remove the prompt from the response\n", " if response.startswith(prompt):\n", " response = response[len(prompt):].strip() # Remove the prompt from the response\n", "\n", " return response" ], "metadata": { "id": "JbMs8UuSu5_R" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Example conversation\n", "user_input = \"what is desonide ointment used for\"\n", "bot_response = generate_response(user_input)\n", "print(\"Bot Response:\", bot_response)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qsHAT1-uxC4_", "outputId": "89b73c5f-0ae9-449d-8eb4-3df1a7c146bb" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Bot Response: desonide ointment is used to treat a variety of conditions it is used to treat allergies and other skin conditions it is also used to treat certain types of infections it is also used to treat skin infections caused by bacteria that are on skin desonide is in a class of medications called antimicrobials it works by killing bacteria that cause skin infections desonide is in a class of medications called antibiotics it works by killing bacteria that cause skin infections\n" ] } ] }, { "cell_type": "code", "source": [ "# Copying the model to Google Drive (optional)\n", "import shutil\n", "\n", "# Path to the file in Colab\n", "colab_file_path = '/content/med_info_model/model.safetensors'\n", "\n", "# Path to your Google Drive\n", "drive_file_path = '/content/drive/MyDrive'\n", "\n", "# Copy the file\n", "shutil.copy(colab_file_path, drive_file_path)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "aP4IEboMxDWG", "outputId": "c00d1d74-e389-4de4-a151-d20736b6bccd" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "'/content/drive/MyDrive/model.safetensors'" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" } }, "metadata": {}, "execution_count": 22 } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "uKYwYe5XyXgx" }, "execution_count": null, "outputs": [] } ] }