{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "### Data Preparation" ], "metadata": { "id": "ga8c1nhja4Qy" } }, { "cell_type": "code", "source": [ "# Pinned to the release this notebook was run with; %pip installs into the running kernel's environment\n", "%pip install opendatasets==0.1.22" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O7NczD5abI6o", "outputId": "422faa21-1ee0-4582-9315-4c2b01f4518d" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting opendatasets\n", " Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from opendatasets) (4.66.5)\n", "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (from opendatasets) (1.6.17)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from opendatasets) (8.1.7)\n", "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (1.16.0)\n", "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2024.8.30)\n", "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.8.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.32.3)\n", "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (8.0.4)\n", "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.0.7)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (6.1.0)\n", "Requirement 
already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle->opendatasets) (0.5.1)\n", "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle->opendatasets) (1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.10)\n", "Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)\n", "Installing collected packages: opendatasets\n", "Successfully installed opendatasets-0.1.22\n" ] } ] }, { "cell_type": "code", "source": [ "# Download the MedInfo dataset from Kaggle into ./medinfo\n", "import opendatasets as od\n", "od.download('https://www.kaggle.com/datasets/hassaanidrees/medinfo?select=MedInfo2019-QA-Medications.xlsx')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7QSxa8cRbIug", "outputId": "088ef3d5-b3fc-4860-8928-bb872ff83ab5" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Dataset URL: https://www.kaggle.com/datasets/hassaanidrees/medinfo\n", "Downloading medinfo.zip to ./medinfo\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 159k/159k [00:00<00:00, 480kB/s]" ] }, { "output_type": "stream", "name": "stdout", "text": [ "\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "\n" ] } ] }, { "cell_type": "code", "source": [ "# Load the question/answer spreadsheet with pandas\n", "import pandas as pd\n", "\n", "# Relative path: the download above extracts into ./medinfo, so this does not depend on /content\n", "DATA_PATH = \"medinfo/MedInfo2019-QA-Medications.xlsx\"\n", "df = pd.read_excel(DATA_PATH)\n", "\n", "# Keep only the question and answer columns\n", "df = df[['Question','Answer']]" ], "metadata": { "id": "sooD64r3bIDJ" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "df.head()  # show the first five rows" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "eRneQPLAqAJL", "outputId": 
"d1772f7e-8edd-4687-9c1a-c3102e86138e" }, "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Question \\\n", "0 how does rivatigmine and otc sleep medicine in... \n", "1 how does valium affect the brain \n", "2 what is morphine \n", "3 what are the milligrams for oxycodone e \n", "4 81% aspirin contain resin and shellac in it. ? \n", "\n", " Answer \n", "0 tell your doctor and pharmacist what prescript... \n", "1 Diazepam is a benzodiazepine that exerts anxio... \n", "2 Morphine is a pain medication of the opiate fa... \n", "3 … 10 mg … 20 mg … 40 mg … 80 mg ... \n", "4 Inactive Ingredients Ingredient Name " ], "text/html": [ "\n", "
\n", " | Question | \n", "Answer | \n", "
---|---|---|
0 | \n", "how does rivatigmine and otc sleep medicine in... | \n", "tell your doctor and pharmacist what prescript... | \n", "
1 | \n", "how does valium affect the brain | \n", "Diazepam is a benzodiazepine that exerts anxio... | \n", "
2 | \n", "what is morphine | \n", "Morphine is a pain medication of the opiate fa... | \n", "
3 | \n", "what are the milligrams for oxycodone e | \n", "… 10 mg … 20 mg … 40 mg … 80 mg ... | \n", "
4 | \n", "81% aspirin contain resin and shellac in it. ? | \n", "Inactive Ingredients Ingredient Name | \n", "
Step | \n", "Training Loss | \n", "
---|---|
10 | \n", "5.891800 | \n", "
20 | \n", "5.497900 | \n", "
30 | \n", "4.671300 | \n", "
40 | \n", "3.751500 | \n", "
50 | \n", "3.016000 | \n", "
60 | \n", "2.633300 | \n", "
70 | \n", "2.360800 | \n", "
80 | \n", "2.079000 | \n", "
90 | \n", "2.145600 | \n", "
100 | \n", "2.150100 | \n", "
110 | \n", "2.069300 | \n", "
120 | \n", "2.000300 | \n", "
130 | \n", "1.919900 | \n", "
140 | \n", "1.954000 | \n", "
150 | \n", "1.928500 | \n", "
160 | \n", "1.832900 | \n", "
170 | \n", "1.921300 | \n", "
180 | \n", "2.043500 | \n", "
190 | \n", "1.827400 | \n", "
200 | \n", "1.687700 | \n", "
210 | \n", "1.782400 | \n", "
220 | \n", "1.959600 | \n", "
230 | \n", "1.810500 | \n", "
240 | \n", "1.706800 | \n", "
250 | \n", "1.662200 | \n", "
260 | \n", "1.783900 | \n", "
270 | \n", "1.567300 | \n", "
280 | \n", "1.695100 | \n", "
290 | \n", "1.681800 | \n", "
300 | \n", "1.657400 | \n", "
310 | \n", "1.684000 | \n", "
320 | \n", "1.494700 | \n", "
330 | \n", "1.556800 | \n", "
340 | \n", "1.648300 | \n", "
350 | \n", "1.529300 | \n", "
360 | \n", "1.421200 | \n", "
370 | \n", "1.483900 | \n", "
380 | \n", "1.588400 | \n", "
390 | \n", "1.442200 | \n", "
400 | \n", "1.524600 | \n", "
410 | \n", "1.469100 | \n", "
420 | \n", "1.412900 | \n", "
430 | \n", "1.388300 | \n", "
440 | \n", "1.414400 | \n", "
450 | \n", "1.368200 | \n", "
460 | \n", "1.374900 | \n", "
470 | \n", "1.336500 | \n", "
480 | \n", "1.294900 | \n", "
490 | \n", "1.231700 | \n", "
500 | \n", "1.287600 | \n", "
510 | \n", "1.248500 | \n", "
520 | \n", "1.220700 | \n", "
530 | \n", "1.335700 | \n", "
540 | \n", "1.094200 | \n", "
550 | \n", "1.151400 | \n", "
560 | \n", "1.215000 | \n", "
570 | \n", "1.235600 | \n", "
580 | \n", "1.139800 | \n", "
590 | \n", "1.119600 | \n", "
600 | \n", "1.148000 | \n", "
610 | \n", "1.057300 | \n", "
620 | \n", "1.039700 | \n", "
630 | \n", "1.081300 | \n", "
640 | \n", "0.960300 | \n", "
650 | \n", "1.026400 | \n", "
660 | \n", "1.049900 | \n", "
670 | \n", "0.967600 | \n", "
680 | \n", "0.902100 | \n", "
690 | \n", "0.950900 | \n", "
700 | \n", "0.998500 | \n", "
710 | \n", "1.043500 | \n", "
720 | \n", "0.877700 | \n", "
730 | \n", "0.818800 | \n", "
740 | \n", "0.949500 | \n", "
750 | \n", "1.032200 | \n", "
760 | \n", "0.813600 | \n", "
770 | \n", "0.871600 | \n", "
780 | \n", "0.877400 | \n", "
790 | \n", "0.952400 | \n", "
800 | \n", "0.819600 | \n", "
810 | \n", "0.852700 | \n", "
820 | \n", "0.848300 | \n", "
830 | \n", "0.834200 | \n", "
840 | \n", "0.900900 | \n", "
850 | \n", "0.830800 | \n", "
860 | \n", "0.864700 | \n", "
870 | \n", "0.842200 | \n", "
880 | \n", "0.865000 | \n", "
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=880, training_loss=1.5622584277933294, metrics={'train_runtime': 525.9662, 'train_samples_per_second': 26.237, 'train_steps_per_second': 1.673, 'total_flos': 901457510400000.0, 'train_loss': 1.5622584277933294, 'epoch': 20.0})" ] }, "metadata": {}, "execution_count": 13 } ] }, { "cell_type": "code", "source": [ "# Save the model\n", "trainer.save_model('med_info_model')" ], "metadata": { "id": "4UrH8iP0u6Cp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "### Testing" ], "metadata": { "id": "VhXRJT6jeTuz" } }, { "cell_type": "code", "source": [ "def generate_response(prompt, max_new_tokens=150, **generate_kwargs):\n", "    \"\"\"Generate the model's reply to `prompt` and return it without the echoed prompt.\n", "\n", "    Args:\n", "        prompt: Text to feed to the model.\n", "        max_new_tokens: Maximum number of tokens generated after the prompt.\n", "        **generate_kwargs: Extra options forwarded to `model.generate`; they\n", "            override the defaults set in this function.\n", "\n", "    Returns:\n", "        The decoded reply with surrounding whitespace stripped.\n", "    \"\"\"\n", "    # Training may have left dropout enabled; switch to inference mode before generating\n", "    model.eval()\n", "\n", "    # Follow the model's device instead of assuming a GPU is present\n", "    device = next(model.parameters()).device\n", "    # Calling the tokenizer (rather than .encode) also yields the attention mask\n", "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n", "    prompt_ids = inputs[\"input_ids\"]\n", "\n", "    options = {\n", "        \"max_new_tokens\": max_new_tokens,  # budget excludes the prompt, unlike max_length\n", "        \"num_return_sequences\": 1,\n", "        \"pad_token_id\": tokenizer.eos_token_id,\n", "    }\n", "    options.update(generate_kwargs)\n", "    outputs = model.generate(prompt_ids, attention_mask=inputs.get(\"attention_mask\"), **options)\n", "\n", "    # Remove the prompt by token position when the output echoes it; comparing decoded\n", "    # text alone misses prompts whose spacing changes when decoded\n", "    generated = outputs[0]\n", "    prompt_len = prompt_ids.shape[-1]\n", "    if generated[:prompt_len].tolist() == prompt_ids[0].tolist():\n", "        response = tokenizer.decode(generated[prompt_len:], skip_special_tokens=True)\n", "    else:\n", "        response = tokenizer.decode(generated, skip_special_tokens=True)\n", "        if response.startswith(prompt):\n", "            response = response[len(prompt):]\n", "\n", "    return response.strip()" ], "metadata": { "id": "JbMs8UuSu5_R" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "# Example conversation\n", "user_input = \"what is desonide ointment used for\"\n", "bot_response = generate_response(user_input)\n", "print(\"Bot Response:\", bot_response)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "qsHAT1-uxC4_", "outputId": "89b73c5f-0ae9-449d-8eb4-3df1a7c146bb" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Bot Response: desonide ointment is used to treat a variety of conditions it is used to treat allergies and other skin conditions it is also used 
to treat certain types of infections it is also used to treat skin infections caused by bacteria that are on skin desonide is in a class of medications called antimicrobials it works by killing bacteria that cause skin infections desonide is in a class of medications called antibiotics it works by killing bacteria that cause skin infections\n" ] } ] }, { "cell_type": "code", "source": [ "# Copy the saved model folder to Google Drive (optional)\n", "import os\n", "import shutil\n", "\n", "# Folder written by trainer.save_model('med_info_model')\n", "colab_model_dir = '/content/med_info_model'\n", "\n", "# Destination folder inside the mounted Google Drive\n", "drive_root = '/content/drive/MyDrive'\n", "drive_model_dir = os.path.join(drive_root, 'med_info_model')\n", "\n", "# Fail with an actionable message when Drive has not been mounted yet\n", "if not os.path.isdir(drive_root):\n", "    raise FileNotFoundError(f\"{drive_root} not found - mount Google Drive before running this cell\")\n", "\n", "# Copy the whole folder rather than only model.safetensors, so the files saved\n", "# next to the weights travel with them; dirs_exist_ok allows re-running the cell\n", "shutil.copytree(colab_model_dir, drive_model_dir, dirs_exist_ok=True)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 36 }, "id": "aP4IEboMxDWG", "outputId": "c00d1d74-e389-4de4-a151-d20736b6bccd" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [], "metadata": { "id": "uKYwYe5XyXgx" }, "execution_count": null, "outputs": [] } ] }