hassaanik commited on
Commit
02b9964
·
verified ·
1 Parent(s): 95d4dfe

Upload 34 files

Browse files
Files changed (34) hide show
  1. Data/cleaned_QA_data.csv +0 -0
  2. Data/cleaned_med_QA_data.csv +0 -0
  3. Notebooks/Couselling_from_Scratch.ipynb +0 -0
  4. Notebooks/Medication_from_Scratch.ipynb +1348 -0
  5. app.py +118 -0
  6. backend/__init__.py +0 -0
  7. backend/__pycache__/__init__.cpython-312.pyc +0 -0
  8. backend/__pycache__/utils.cpython-312.pyc +0 -0
  9. backend/models/diabetes_model/random_forest_modelf.joblib +3 -0
  10. backend/models/diabetes_model/standard_scaler.joblib +3 -0
  11. backend/models/medication_classification_model/age_scaler.pkl +3 -0
  12. backend/models/medication_classification_model/knn_model.pkl +3 -0
  13. backend/models/medication_classification_model/label_encoders.pkl +3 -0
  14. backend/models/medication_classification_model/medication_encoder.pkl +3 -0
  15. backend/models/medication_info/config.json +39 -0
  16. backend/models/medication_info/generation_config.json +6 -0
  17. backend/models/medication_info/merges.txt +0 -0
  18. backend/models/medication_info/model.safetensors +3 -0
  19. backend/models/medication_info/special_tokens_map.json +24 -0
  20. backend/models/medication_info/tokenizer_config.json +22 -0
  21. backend/models/medication_info/training_args.bin +3 -0
  22. backend/models/medication_info/vocab.json +0 -0
  23. backend/models/mental_health_model/config.json +39 -0
  24. backend/models/mental_health_model/generation_config.json +6 -0
  25. backend/models/mental_health_model/merges.txt +0 -0
  26. backend/models/mental_health_model/model.safetensors +3 -0
  27. backend/models/mental_health_model/special_tokens_map.json +24 -0
  28. backend/models/mental_health_model/tokenizer_config.json +22 -0
  29. backend/models/mental_health_model/training_args.bin +3 -0
  30. backend/models/mental_health_model/vocab.json +0 -0
  31. backend/utils.py +125 -0
  32. frontend/index.html +80 -0
  33. frontend/script.js +87 -0
  34. frontend/styles.css +89 -0
Data/cleaned_QA_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
Data/cleaned_med_QA_data.csv ADDED
The diff for this file is too large to render. See raw diff
 
Notebooks/Couselling_from_Scratch.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Notebooks/Medication_from_Scratch.ipynb ADDED
@@ -0,0 +1,1348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "markdown",
21
+ "source": [
22
+ "### Data Preparation"
23
+ ],
24
+ "metadata": {
25
+ "id": "ga8c1nhja4Qy"
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "source": [
31
+ "!pip install opendatasets"
32
+ ],
33
+ "metadata": {
34
+ "colab": {
35
+ "base_uri": "https://localhost:8080/"
36
+ },
37
+ "id": "O7NczD5abI6o",
38
+ "outputId": "422faa21-1ee0-4582-9315-4c2b01f4518d"
39
+ },
40
+ "execution_count": 1,
41
+ "outputs": [
42
+ {
43
+ "output_type": "stream",
44
+ "name": "stdout",
45
+ "text": [
46
+ "Collecting opendatasets\n",
47
+ " Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)\n",
48
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from opendatasets) (4.66.5)\n",
49
+ "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (from opendatasets) (1.6.17)\n",
50
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from opendatasets) (8.1.7)\n",
51
+ "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (1.16.0)\n",
52
+ "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2024.8.30)\n",
53
+ "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.8.2)\n",
54
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.32.3)\n",
55
+ "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (8.0.4)\n",
56
+ "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (2.0.7)\n",
57
+ "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle->opendatasets) (6.1.0)\n",
58
+ "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle->opendatasets) (0.5.1)\n",
59
+ "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle->opendatasets) (1.3)\n",
60
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.3.2)\n",
61
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle->opendatasets) (3.10)\n",
62
+ "Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)\n",
63
+ "Installing collected packages: opendatasets\n",
64
+ "Successfully installed opendatasets-0.1.22\n"
65
+ ]
66
+ }
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "import opendatasets as od\n",
73
+ "od.download('https://www.kaggle.com/datasets/hassaanidrees/medinfo?select=MedInfo2019-QA-Medications.xlsx')"
74
+ ],
75
+ "metadata": {
76
+ "colab": {
77
+ "base_uri": "https://localhost:8080/"
78
+ },
79
+ "id": "7QSxa8cRbIug",
80
+ "outputId": "088ef3d5-b3fc-4860-8928-bb872ff83ab5"
81
+ },
82
+ "execution_count": 2,
83
+ "outputs": [
84
+ {
85
+ "output_type": "stream",
86
+ "name": "stdout",
87
+ "text": [
88
+ "Dataset URL: https://www.kaggle.com/datasets/hassaanidrees/medinfo\n",
89
+ "Downloading medinfo.zip to ./medinfo\n"
90
+ ]
91
+ },
92
+ {
93
+ "output_type": "stream",
94
+ "name": "stderr",
95
+ "text": [
96
+ "100%|██████████| 159k/159k [00:00<00:00, 480kB/s]"
97
+ ]
98
+ },
99
+ {
100
+ "output_type": "stream",
101
+ "name": "stdout",
102
+ "text": [
103
+ "\n"
104
+ ]
105
+ },
106
+ {
107
+ "output_type": "stream",
108
+ "name": "stderr",
109
+ "text": [
110
+ "\n"
111
+ ]
112
+ }
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "source": [
118
+ "# Import pandas for data analysis\n",
119
+ "import pandas as pd\n",
120
+ "df = pd.read_excel(\"/content/medinfo/MedInfo2019-QA-Medications.xlsx\")\n",
121
+ "df = df[['Question','Answer']]"
122
+ ],
123
+ "metadata": {
124
+ "id": "sooD64r3bIDJ"
125
+ },
126
+ "execution_count": 3,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "df.head() #show first five rows"
133
+ ],
134
+ "metadata": {
135
+ "colab": {
136
+ "base_uri": "https://localhost:8080/",
137
+ "height": 206
138
+ },
139
+ "id": "eRneQPLAqAJL",
140
+ "outputId": "d1772f7e-8edd-4687-9c1a-c3102e86138e"
141
+ },
142
+ "execution_count": null,
143
+ "outputs": [
144
+ {
145
+ "output_type": "execute_result",
146
+ "data": {
147
+ "text/plain": [
148
+ " Question \\\n",
149
+ "0 how does rivatigmine and otc sleep medicine in... \n",
150
+ "1 how does valium affect the brain \n",
151
+ "2 what is morphine \n",
152
+ "3 what are the milligrams for oxycodone e \n",
153
+ "4 81% aspirin contain resin and shellac in it. ? \n",
154
+ "\n",
155
+ " Answer \n",
156
+ "0 tell your doctor and pharmacist what prescript... \n",
157
+ "1 Diazepam is a benzodiazepine that exerts anxio... \n",
158
+ "2 Morphine is a pain medication of the opiate fa... \n",
159
+ "3 … 10 mg … 20 mg … 40 mg … 80 mg ... \n",
160
+ "4 Inactive Ingredients Ingredient Name "
161
+ ],
162
+ "text/html": [
163
+ "\n",
164
+ " <div id=\"df-d79eadfb-a1cc-4af0-87f3-9921298edcfe\" class=\"colab-df-container\">\n",
165
+ " <div>\n",
166
+ "<style scoped>\n",
167
+ " .dataframe tbody tr th:only-of-type {\n",
168
+ " vertical-align: middle;\n",
169
+ " }\n",
170
+ "\n",
171
+ " .dataframe tbody tr th {\n",
172
+ " vertical-align: top;\n",
173
+ " }\n",
174
+ "\n",
175
+ " .dataframe thead th {\n",
176
+ " text-align: right;\n",
177
+ " }\n",
178
+ "</style>\n",
179
+ "<table border=\"1\" class=\"dataframe\">\n",
180
+ " <thead>\n",
181
+ " <tr style=\"text-align: right;\">\n",
182
+ " <th></th>\n",
183
+ " <th>Question</th>\n",
184
+ " <th>Answer</th>\n",
185
+ " </tr>\n",
186
+ " </thead>\n",
187
+ " <tbody>\n",
188
+ " <tr>\n",
189
+ " <th>0</th>\n",
190
+ " <td>how does rivatigmine and otc sleep medicine in...</td>\n",
191
+ " <td>tell your doctor and pharmacist what prescript...</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>1</th>\n",
195
+ " <td>how does valium affect the brain</td>\n",
196
+ " <td>Diazepam is a benzodiazepine that exerts anxio...</td>\n",
197
+ " </tr>\n",
198
+ " <tr>\n",
199
+ " <th>2</th>\n",
200
+ " <td>what is morphine</td>\n",
201
+ " <td>Morphine is a pain medication of the opiate fa...</td>\n",
202
+ " </tr>\n",
203
+ " <tr>\n",
204
+ " <th>3</th>\n",
205
+ " <td>what are the milligrams for oxycodone e</td>\n",
206
+ " <td>… 10 mg … 20 mg … 40 mg … 80 mg ...</td>\n",
207
+ " </tr>\n",
208
+ " <tr>\n",
209
+ " <th>4</th>\n",
210
+ " <td>81% aspirin contain resin and shellac in it. ?</td>\n",
211
+ " <td>Inactive Ingredients Ingredient Name</td>\n",
212
+ " </tr>\n",
213
+ " </tbody>\n",
214
+ "</table>\n",
215
+ "</div>\n",
216
+ " <div class=\"colab-df-buttons\">\n",
217
+ "\n",
218
+ " <div class=\"colab-df-container\">\n",
219
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d79eadfb-a1cc-4af0-87f3-9921298edcfe')\"\n",
220
+ " title=\"Convert this dataframe to an interactive table.\"\n",
221
+ " style=\"display:none;\">\n",
222
+ "\n",
223
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
224
+ " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
225
+ " </svg>\n",
226
+ " </button>\n",
227
+ "\n",
228
+ " <style>\n",
229
+ " .colab-df-container {\n",
230
+ " display:flex;\n",
231
+ " gap: 12px;\n",
232
+ " }\n",
233
+ "\n",
234
+ " .colab-df-convert {\n",
235
+ " background-color: #E8F0FE;\n",
236
+ " border: none;\n",
237
+ " border-radius: 50%;\n",
238
+ " cursor: pointer;\n",
239
+ " display: none;\n",
240
+ " fill: #1967D2;\n",
241
+ " height: 32px;\n",
242
+ " padding: 0 0 0 0;\n",
243
+ " width: 32px;\n",
244
+ " }\n",
245
+ "\n",
246
+ " .colab-df-convert:hover {\n",
247
+ " background-color: #E2EBFA;\n",
248
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
249
+ " fill: #174EA6;\n",
250
+ " }\n",
251
+ "\n",
252
+ " .colab-df-buttons div {\n",
253
+ " margin-bottom: 4px;\n",
254
+ " }\n",
255
+ "\n",
256
+ " [theme=dark] .colab-df-convert {\n",
257
+ " background-color: #3B4455;\n",
258
+ " fill: #D2E3FC;\n",
259
+ " }\n",
260
+ "\n",
261
+ " [theme=dark] .colab-df-convert:hover {\n",
262
+ " background-color: #434B5C;\n",
263
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
264
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
265
+ " fill: #FFFFFF;\n",
266
+ " }\n",
267
+ " </style>\n",
268
+ "\n",
269
+ " <script>\n",
270
+ " const buttonEl =\n",
271
+ " document.querySelector('#df-d79eadfb-a1cc-4af0-87f3-9921298edcfe button.colab-df-convert');\n",
272
+ " buttonEl.style.display =\n",
273
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
274
+ "\n",
275
+ " async function convertToInteractive(key) {\n",
276
+ " const element = document.querySelector('#df-d79eadfb-a1cc-4af0-87f3-9921298edcfe');\n",
277
+ " const dataTable =\n",
278
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
279
+ " [key], {});\n",
280
+ " if (!dataTable) return;\n",
281
+ "\n",
282
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
283
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
284
+ " + ' to learn more about interactive tables.';\n",
285
+ " element.innerHTML = '';\n",
286
+ " dataTable['output_type'] = 'display_data';\n",
287
+ " await google.colab.output.renderOutput(dataTable, element);\n",
288
+ " const docLink = document.createElement('div');\n",
289
+ " docLink.innerHTML = docLinkHtml;\n",
290
+ " element.appendChild(docLink);\n",
291
+ " }\n",
292
+ " </script>\n",
293
+ " </div>\n",
294
+ "\n",
295
+ "\n",
296
+ "<div id=\"df-862caeb9-15bf-47b1-b083-8b9307722b80\">\n",
297
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-862caeb9-15bf-47b1-b083-8b9307722b80')\"\n",
298
+ " title=\"Suggest charts\"\n",
299
+ " style=\"display:none;\">\n",
300
+ "\n",
301
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
302
+ " width=\"24px\">\n",
303
+ " <g>\n",
304
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
305
+ " </g>\n",
306
+ "</svg>\n",
307
+ " </button>\n",
308
+ "\n",
309
+ "<style>\n",
310
+ " .colab-df-quickchart {\n",
311
+ " --bg-color: #E8F0FE;\n",
312
+ " --fill-color: #1967D2;\n",
313
+ " --hover-bg-color: #E2EBFA;\n",
314
+ " --hover-fill-color: #174EA6;\n",
315
+ " --disabled-fill-color: #AAA;\n",
316
+ " --disabled-bg-color: #DDD;\n",
317
+ " }\n",
318
+ "\n",
319
+ " [theme=dark] .colab-df-quickchart {\n",
320
+ " --bg-color: #3B4455;\n",
321
+ " --fill-color: #D2E3FC;\n",
322
+ " --hover-bg-color: #434B5C;\n",
323
+ " --hover-fill-color: #FFFFFF;\n",
324
+ " --disabled-bg-color: #3B4455;\n",
325
+ " --disabled-fill-color: #666;\n",
326
+ " }\n",
327
+ "\n",
328
+ " .colab-df-quickchart {\n",
329
+ " background-color: var(--bg-color);\n",
330
+ " border: none;\n",
331
+ " border-radius: 50%;\n",
332
+ " cursor: pointer;\n",
333
+ " display: none;\n",
334
+ " fill: var(--fill-color);\n",
335
+ " height: 32px;\n",
336
+ " padding: 0;\n",
337
+ " width: 32px;\n",
338
+ " }\n",
339
+ "\n",
340
+ " .colab-df-quickchart:hover {\n",
341
+ " background-color: var(--hover-bg-color);\n",
342
+ " box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
343
+ " fill: var(--button-hover-fill-color);\n",
344
+ " }\n",
345
+ "\n",
346
+ " .colab-df-quickchart-complete:disabled,\n",
347
+ " .colab-df-quickchart-complete:disabled:hover {\n",
348
+ " background-color: var(--disabled-bg-color);\n",
349
+ " fill: var(--disabled-fill-color);\n",
350
+ " box-shadow: none;\n",
351
+ " }\n",
352
+ "\n",
353
+ " .colab-df-spinner {\n",
354
+ " border: 2px solid var(--fill-color);\n",
355
+ " border-color: transparent;\n",
356
+ " border-bottom-color: var(--fill-color);\n",
357
+ " animation:\n",
358
+ " spin 1s steps(1) infinite;\n",
359
+ " }\n",
360
+ "\n",
361
+ " @keyframes spin {\n",
362
+ " 0% {\n",
363
+ " border-color: transparent;\n",
364
+ " border-bottom-color: var(--fill-color);\n",
365
+ " border-left-color: var(--fill-color);\n",
366
+ " }\n",
367
+ " 20% {\n",
368
+ " border-color: transparent;\n",
369
+ " border-left-color: var(--fill-color);\n",
370
+ " border-top-color: var(--fill-color);\n",
371
+ " }\n",
372
+ " 30% {\n",
373
+ " border-color: transparent;\n",
374
+ " border-left-color: var(--fill-color);\n",
375
+ " border-top-color: var(--fill-color);\n",
376
+ " border-right-color: var(--fill-color);\n",
377
+ " }\n",
378
+ " 40% {\n",
379
+ " border-color: transparent;\n",
380
+ " border-right-color: var(--fill-color);\n",
381
+ " border-top-color: var(--fill-color);\n",
382
+ " }\n",
383
+ " 60% {\n",
384
+ " border-color: transparent;\n",
385
+ " border-right-color: var(--fill-color);\n",
386
+ " }\n",
387
+ " 80% {\n",
388
+ " border-color: transparent;\n",
389
+ " border-right-color: var(--fill-color);\n",
390
+ " border-bottom-color: var(--fill-color);\n",
391
+ " }\n",
392
+ " 90% {\n",
393
+ " border-color: transparent;\n",
394
+ " border-bottom-color: var(--fill-color);\n",
395
+ " }\n",
396
+ " }\n",
397
+ "</style>\n",
398
+ "\n",
399
+ " <script>\n",
400
+ " async function quickchart(key) {\n",
401
+ " const quickchartButtonEl =\n",
402
+ " document.querySelector('#' + key + ' button');\n",
403
+ " quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
404
+ " quickchartButtonEl.classList.add('colab-df-spinner');\n",
405
+ " try {\n",
406
+ " const charts = await google.colab.kernel.invokeFunction(\n",
407
+ " 'suggestCharts', [key], {});\n",
408
+ " } catch (error) {\n",
409
+ " console.error('Error during call to suggestCharts:', error);\n",
410
+ " }\n",
411
+ " quickchartButtonEl.classList.remove('colab-df-spinner');\n",
412
+ " quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
413
+ " }\n",
414
+ " (() => {\n",
415
+ " let quickchartButtonEl =\n",
416
+ " document.querySelector('#df-862caeb9-15bf-47b1-b083-8b9307722b80 button');\n",
417
+ " quickchartButtonEl.style.display =\n",
418
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
419
+ " })();\n",
420
+ " </script>\n",
421
+ "</div>\n",
422
+ "\n",
423
+ " </div>\n",
424
+ " </div>\n"
425
+ ],
426
+ "application/vnd.google.colaboratory.intrinsic+json": {
427
+ "type": "dataframe",
428
+ "variable_name": "df",
429
+ "summary": "{\n \"name\": \"df\",\n \"rows\": 690,\n \"fields\": [\n {\n \"column\": \"Question\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 651,\n \"samples\": [\n \"how is marijuana used\",\n \"tudorza pressair is what schedule drug\",\n \"how long does ecstasy or mda leave your body\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Answer\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 652,\n \"samples\": [\n \"Marijuana is best known as a drug that people smoke or eat to get high. It is derived from the plant Cannabis sativa. Possession of marijuana is illegal under federal law. Medical marijuana refers to using marijuana to treat certain medical conditions. In the United States, about half of the states have legalized marijuana for medical use.\",\n \"Color - GRAY, Shape - CAPSULE (biconvex), Score - no score, Size - 12mm, Imprint Code - m10\",\n \"Quantity: 60; Per Unit: $4.68 \\u2013 $15.91; Price: $280.99 \\u2013 $954.47\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
430
+ }
431
+ },
432
+ "metadata": {},
433
+ "execution_count": 4
434
+ }
435
+ ]
436
+ },
437
+ {
438
+ "cell_type": "code",
439
+ "source": [
440
+ "df.Question[0]"
441
+ ],
442
+ "metadata": {
443
+ "colab": {
444
+ "base_uri": "https://localhost:8080/",
445
+ "height": 36
446
+ },
447
+ "id": "4SEkJJHwqBwo",
448
+ "outputId": "7aeec0ad-b51a-44fa-f2e1-5a93b61246d5"
449
+ },
450
+ "execution_count": null,
451
+ "outputs": [
452
+ {
453
+ "output_type": "execute_result",
454
+ "data": {
455
+ "text/plain": [
456
+ "'how does rivatigmine and otc sleep medicine interact'"
457
+ ],
458
+ "application/vnd.google.colaboratory.intrinsic+json": {
459
+ "type": "string"
460
+ }
461
+ },
462
+ "metadata": {},
463
+ "execution_count": 5
464
+ }
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "source": [
470
+ "df.Answer[0]"
471
+ ],
472
+ "metadata": {
473
+ "colab": {
474
+ "base_uri": "https://localhost:8080/",
475
+ "height": 105
476
+ },
477
+ "id": "qTllg8a-qGXW",
478
+ "outputId": "a6b8bca7-135e-4e26-e0ff-a2a1424bc45c"
479
+ },
480
+ "execution_count": null,
481
+ "outputs": [
482
+ {
483
+ "output_type": "execute_result",
484
+ "data": {
485
+ "text/plain": [
486
+ "\"tell your doctor and pharmacist what prescription and nonprescription medications, vitamins, nutritional supplements, and herbal products you are taking or plan to take. Be sure to mention any of the following: antihistamines; aspirin and other nonsteroidal anti-inflammatory medications (NSAIDs) such as ibuprofen (Advil, Motrin) and naproxen (Aleve, Naprosyn); bethanechol (Duvoid, Urecholine); ipratropium (Atrovent, in Combivent, DuoNeb); and medications for Alzheimer's disease, glaucoma, irritable bowel disease, motion sickness, ulcers, or urinary problems. Your doctor may need to change the doses of your medications or monitor you carefully for side effects.\""
487
+ ],
488
+ "application/vnd.google.colaboratory.intrinsic+json": {
489
+ "type": "string"
490
+ }
491
+ },
492
+ "metadata": {},
493
+ "execution_count": 6
494
+ }
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "source": [
500
+ "df.shape # 690 rows | 2 cols"
501
+ ],
502
+ "metadata": {
503
+ "colab": {
504
+ "base_uri": "https://localhost:8080/"
505
+ },
506
+ "id": "xs_qECG1qIW5",
507
+ "outputId": "678a409c-9164-48f4-803e-501d3dff3c96"
508
+ },
509
+ "execution_count": null,
510
+ "outputs": [
511
+ {
512
+ "output_type": "execute_result",
513
+ "data": {
514
+ "text/plain": [
515
+ "(690, 2)"
516
+ ]
517
+ },
518
+ "metadata": {},
519
+ "execution_count": 7
520
+ }
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "source": [
526
+ "!pip install cleantext"
527
+ ],
528
+ "metadata": {
529
+ "colab": {
530
+ "base_uri": "https://localhost:8080/"
531
+ },
532
+ "id": "LPvkkbdbrNp-",
533
+ "outputId": "938e6a8d-fb4b-4112-9a0e-3139146e56eb"
534
+ },
535
+ "execution_count": null,
536
+ "outputs": [
537
+ {
538
+ "output_type": "stream",
539
+ "name": "stdout",
540
+ "text": [
541
+ "Collecting cleantext\n",
542
+ " Downloading cleantext-1.1.4-py3-none-any.whl.metadata (3.5 kB)\n",
543
+ "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from cleantext) (3.8.1)\n",
544
+ "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (8.1.7)\n",
545
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (1.4.2)\n",
546
+ "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (2024.5.15)\n",
547
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->cleantext) (4.66.5)\n",
548
+ "Downloading cleantext-1.1.4-py3-none-any.whl (4.9 kB)\n",
549
+ "Installing collected packages: cleantext\n",
550
+ "Successfully installed cleantext-1.1.4\n"
551
+ ]
552
+ }
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "source": [
558
+ "import cleantext\n",
559
+ "\n",
560
+ "# Function to clean text data by removing unwanted characters and formatting\n",
561
+ "def clean(textdata):\n",
562
+ " cleaned_text = []\n",
563
+ " for i in textdata:\n",
564
+ " cleaned_text.append(cleantext.clean(str(i), extra_spaces=True, lowercase=True, stopwords=False, stemming=False, numbers=True, punct=True, clean_all = True))\n",
565
+ "\n",
566
+ " return cleaned_text"
567
+ ],
568
+ "metadata": {
569
+ "id": "dws3d49Lqv1b"
570
+ },
571
+ "execution_count": null,
572
+ "outputs": []
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "source": [
577
+ "# Apply the clean function to the questions and answers columns\n",
578
+ "\n",
579
+ "df.Question = list(clean(df.Question))\n",
580
+ "df.Answer = list(clean(df.Answer))"
581
+ ],
582
+ "metadata": {
583
+ "id": "H1ia-jFqrIsG"
584
+ },
585
+ "execution_count": null,
586
+ "outputs": []
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "source": [
591
+ "# Save the cleaned data into a new CSV file & save\n",
592
+ "df.to_csv(\"cleaned_med_QA_data.csv\", index=False)"
593
+ ],
594
+ "metadata": {
595
+ "id": "HcB15JQirImk"
596
+ },
597
+ "execution_count": null,
598
+ "outputs": []
599
+ },
600
+ {
601
+ "cell_type": "markdown",
602
+ "source": [
603
+ "### GPT-2 Model"
604
+ ],
605
+ "metadata": {
606
+ "id": "zw5mkpmueML4"
607
+ }
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "source": [
612
+ "!pip install datasets"
613
+ ],
614
+ "metadata": {
615
+ "colab": {
616
+ "base_uri": "https://localhost:8080/",
617
+ "height": 1000
618
+ },
619
+ "id": "QhgGKgZ-rYAY",
620
+ "outputId": "f2334a48-2745-42b5-f5fd-929ca58e1ed6",
621
+ "collapsed": true
622
+ },
623
+ "execution_count": null,
624
+ "outputs": [
625
+ {
626
+ "output_type": "stream",
627
+ "name": "stdout",
628
+ "text": [
629
+ "Collecting datasets\n",
630
+ " Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)\n",
631
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.0)\n",
632
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n",
633
+ "Collecting pyarrow>=15.0.0 (from datasets)\n",
634
+ " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n",
635
+ "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
636
+ " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
637
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n",
638
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n",
639
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n",
640
+ "Collecting xxhash (from datasets)\n",
641
+ " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
642
+ "Collecting multiprocess (from datasets)\n",
643
+ " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n",
644
+ "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n",
645
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n",
646
+ "Requirement already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.6)\n",
647
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n",
648
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n",
649
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n",
650
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
651
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n",
652
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
653
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n",
654
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n",
655
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
656
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2)\n",
657
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n",
658
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.8)\n",
659
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n",
660
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n",
661
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
662
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n",
663
+ "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
664
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
665
+ "Downloading datasets-3.0.0-py3-none-any.whl (474 kB)\n",
666
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m474.3/474.3 kB\u001b[0m \u001b[31m32.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
667
+ "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
668
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
669
+ "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n",
670
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
671
+ "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
672
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m14.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
673
+ "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
674
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
675
+ "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n",
676
+ " Attempting uninstall: pyarrow\n",
677
+ " Found existing installation: pyarrow 14.0.2\n",
678
+ " Uninstalling pyarrow-14.0.2:\n",
679
+ " Successfully uninstalled pyarrow-14.0.2\n",
680
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
681
+ "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n",
682
+ "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n",
683
+ "\u001b[0mSuccessfully installed datasets-3.0.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.5.0\n"
684
+ ]
685
+ },
686
+ {
687
+ "output_type": "display_data",
688
+ "data": {
689
+ "application/vnd.colab-display-data+json": {
690
+ "pip_warning": {
691
+ "packages": [
692
+ "pyarrow"
693
+ ]
694
+ },
695
+ "id": "a6cd6efad93b4c4cb5a29a91b023de8a"
696
+ }
697
+ },
698
+ "metadata": {}
699
+ }
700
+ ]
701
+ },
702
+ {
703
+ "cell_type": "code",
704
+ "source": [
705
+ "from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments\n",
706
+ "import torch\n",
707
+ "from datasets import load_dataset\n",
708
+ "\n",
709
+ "# Load the GPT-2 model and tokenizer\n",
710
+ "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
711
+ "model = GPT2LMHeadModel.from_pretrained('gpt2')"
712
+ ],
713
+ "metadata": {
714
+ "colab": {
715
+ "base_uri": "https://localhost:8080/"
716
+ },
717
+ "id": "xgGgvCqerk-1",
718
+ "outputId": "e338ee7f-c898-41c4-b1f6-036f115d3735"
719
+ },
720
+ "execution_count": null,
721
+ "outputs": [
722
+ {
723
+ "output_type": "stream",
724
+ "name": "stderr",
725
+ "text": [
726
+ "/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
727
+ " warnings.warn(\n"
728
+ ]
729
+ }
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "source": [
735
+ "# Set the padding token for the tokenizer to be the end-of-sequence token\n",
736
+ "tokenizer.pad_token = tokenizer.eos_token\n",
737
+ "\n",
738
+ "# Maximum sequence length that GPT-2 can handle\n",
739
+ "max_length = tokenizer.model_max_length\n",
740
+ "print(max_length)"
741
+ ],
742
+ "metadata": {
743
+ "colab": {
744
+ "base_uri": "https://localhost:8080/"
745
+ },
746
+ "id": "EeiMYkpCrp62",
747
+ "outputId": "e8b0118b-1694-4d9e-d666-e791b083f63f"
748
+ },
749
+ "execution_count": null,
750
+ "outputs": [
751
+ {
752
+ "output_type": "stream",
753
+ "name": "stdout",
754
+ "text": [
755
+ "1024\n"
756
+ ]
757
+ }
758
+ ]
759
+ },
760
+ {
761
+ "cell_type": "code",
762
+ "source": [
763
+ "# Load the cleaned QA dataset as a training set using the 'datasets' library\n",
764
+ "dataset = load_dataset('csv', data_files={'train': 'cleaned_med_QA_data.csv'}, split='train')"
765
+ ],
766
+ "metadata": {
767
+ "id": "MW5Ad0exrry3"
768
+ },
769
+ "execution_count": null,
770
+ "outputs": []
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "source": [
775
+ "#Function to tokenize questions and answers and prepare them for the model\n",
776
+ "def tokenize_function(examples):\n",
777
+ " '''1. Combine each question and answer into a single input string\n",
778
+ " 2. Tokenize the combined text using the GPT-2 tokenizer\n",
779
+ " 3. Set the labels to be the same as the input_ids (shifted to predict the next word)\n",
780
+ " 4. Return the tokenized output. '''\n",
781
+ "\n",
782
+ " combined_text = [str(q) + \" \" + str(a) for q, a in zip(examples['Question'], examples['Answer'])]\n",
783
+ " tokenized_output = tokenizer(combined_text, padding='max_length', truncation=True, max_length=128)\n",
784
+ "\n",
785
+ " # Set the labels to be the same as the input_ids (shifted to predict the next word)\n",
786
+ " tokenized_output['labels'] = tokenized_output['input_ids'].copy()\n",
787
+ "\n",
788
+ " return tokenized_output\n",
789
+ "\n",
790
+ "# Tokenize the entire dataset\n",
791
+ "tokenized_dataset = dataset.map(tokenize_function, batched=True)"
792
+ ],
793
+ "metadata": {
794
+ "id": "99rfOROKr-M0"
795
+ },
796
+ "execution_count": null,
797
+ "outputs": []
798
+ },
799
+ {
800
+ "cell_type": "code",
801
+ "source": [
802
+ "# Define training arguments for the GPT-2 model\n",
803
+ "training_args = TrainingArguments(\n",
804
+ " output_dir='./results', # Directory to save model outputs\n",
805
+ " num_train_epochs=20, # Train for 50 epochs\n",
806
+ " per_device_train_batch_size=16, # Batch size during training\n",
807
+ " per_device_eval_batch_size=32, # Batch size during evaluation\n",
808
+ " warmup_steps=500, # Warmup steps for learning rate scheduler\n",
809
+ " weight_decay=0.01, # Weight decay for regularization\n",
810
+ " logging_dir='./logs', # Directory for saving logs\n",
811
+ " logging_steps=10, # Log every 10 steps\n",
812
+ " save_steps=1000, # Save model checkpoints every 1000 steps\n",
813
+ ")\n",
814
+ "\n",
815
+ "# Trainer class to handle training process\n",
816
+ "trainer = Trainer(\n",
817
+ " model=model,\n",
818
+ " args=training_args,\n",
819
+ " train_dataset=tokenized_dataset,\n",
820
+ " tokenizer=tokenizer,\n",
821
+ ")\n",
822
+ "\n",
823
+ "# Train the model\n",
824
+ "trainer.train()"
825
+ ],
826
+ "metadata": {
827
+ "colab": {
828
+ "base_uri": "https://localhost:8080/",
829
+ "height": 1000
830
+ },
831
+ "id": "TQGJ16yJsCBc",
832
+ "outputId": "ec5b1ae4-83c1-4117-95fe-3aae63fc0f75",
833
+ "collapsed": true
834
+ },
835
+ "execution_count": null,
836
+ "outputs": [
837
+ {
838
+ "output_type": "display_data",
839
+ "data": {
840
+ "text/plain": [
841
+ "<IPython.core.display.HTML object>"
842
+ ],
843
+ "text/html": [
844
+ "\n",
845
+ " <div>\n",
846
+ " \n",
847
+ " <progress value='880' max='880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
848
+ " [880/880 08:45, Epoch 20/20]\n",
849
+ " </div>\n",
850
+ " <table border=\"1\" class=\"dataframe\">\n",
851
+ " <thead>\n",
852
+ " <tr style=\"text-align: left;\">\n",
853
+ " <th>Step</th>\n",
854
+ " <th>Training Loss</th>\n",
855
+ " </tr>\n",
856
+ " </thead>\n",
857
+ " <tbody>\n",
858
+ " <tr>\n",
859
+ " <td>10</td>\n",
860
+ " <td>5.891800</td>\n",
861
+ " </tr>\n",
862
+ " <tr>\n",
863
+ " <td>20</td>\n",
864
+ " <td>5.497900</td>\n",
865
+ " </tr>\n",
866
+ " <tr>\n",
867
+ " <td>30</td>\n",
868
+ " <td>4.671300</td>\n",
869
+ " </tr>\n",
870
+ " <tr>\n",
871
+ " <td>40</td>\n",
872
+ " <td>3.751500</td>\n",
873
+ " </tr>\n",
874
+ " <tr>\n",
875
+ " <td>50</td>\n",
876
+ " <td>3.016000</td>\n",
877
+ " </tr>\n",
878
+ " <tr>\n",
879
+ " <td>60</td>\n",
880
+ " <td>2.633300</td>\n",
881
+ " </tr>\n",
882
+ " <tr>\n",
883
+ " <td>70</td>\n",
884
+ " <td>2.360800</td>\n",
885
+ " </tr>\n",
886
+ " <tr>\n",
887
+ " <td>80</td>\n",
888
+ " <td>2.079000</td>\n",
889
+ " </tr>\n",
890
+ " <tr>\n",
891
+ " <td>90</td>\n",
892
+ " <td>2.145600</td>\n",
893
+ " </tr>\n",
894
+ " <tr>\n",
895
+ " <td>100</td>\n",
896
+ " <td>2.150100</td>\n",
897
+ " </tr>\n",
898
+ " <tr>\n",
899
+ " <td>110</td>\n",
900
+ " <td>2.069300</td>\n",
901
+ " </tr>\n",
902
+ " <tr>\n",
903
+ " <td>120</td>\n",
904
+ " <td>2.000300</td>\n",
905
+ " </tr>\n",
906
+ " <tr>\n",
907
+ " <td>130</td>\n",
908
+ " <td>1.919900</td>\n",
909
+ " </tr>\n",
910
+ " <tr>\n",
911
+ " <td>140</td>\n",
912
+ " <td>1.954000</td>\n",
913
+ " </tr>\n",
914
+ " <tr>\n",
915
+ " <td>150</td>\n",
916
+ " <td>1.928500</td>\n",
917
+ " </tr>\n",
918
+ " <tr>\n",
919
+ " <td>160</td>\n",
920
+ " <td>1.832900</td>\n",
921
+ " </tr>\n",
922
+ " <tr>\n",
923
+ " <td>170</td>\n",
924
+ " <td>1.921300</td>\n",
925
+ " </tr>\n",
926
+ " <tr>\n",
927
+ " <td>180</td>\n",
928
+ " <td>2.043500</td>\n",
929
+ " </tr>\n",
930
+ " <tr>\n",
931
+ " <td>190</td>\n",
932
+ " <td>1.827400</td>\n",
933
+ " </tr>\n",
934
+ " <tr>\n",
935
+ " <td>200</td>\n",
936
+ " <td>1.687700</td>\n",
937
+ " </tr>\n",
938
+ " <tr>\n",
939
+ " <td>210</td>\n",
940
+ " <td>1.782400</td>\n",
941
+ " </tr>\n",
942
+ " <tr>\n",
943
+ " <td>220</td>\n",
944
+ " <td>1.959600</td>\n",
945
+ " </tr>\n",
946
+ " <tr>\n",
947
+ " <td>230</td>\n",
948
+ " <td>1.810500</td>\n",
949
+ " </tr>\n",
950
+ " <tr>\n",
951
+ " <td>240</td>\n",
952
+ " <td>1.706800</td>\n",
953
+ " </tr>\n",
954
+ " <tr>\n",
955
+ " <td>250</td>\n",
956
+ " <td>1.662200</td>\n",
957
+ " </tr>\n",
958
+ " <tr>\n",
959
+ " <td>260</td>\n",
960
+ " <td>1.783900</td>\n",
961
+ " </tr>\n",
962
+ " <tr>\n",
963
+ " <td>270</td>\n",
964
+ " <td>1.567300</td>\n",
965
+ " </tr>\n",
966
+ " <tr>\n",
967
+ " <td>280</td>\n",
968
+ " <td>1.695100</td>\n",
969
+ " </tr>\n",
970
+ " <tr>\n",
971
+ " <td>290</td>\n",
972
+ " <td>1.681800</td>\n",
973
+ " </tr>\n",
974
+ " <tr>\n",
975
+ " <td>300</td>\n",
976
+ " <td>1.657400</td>\n",
977
+ " </tr>\n",
978
+ " <tr>\n",
979
+ " <td>310</td>\n",
980
+ " <td>1.684000</td>\n",
981
+ " </tr>\n",
982
+ " <tr>\n",
983
+ " <td>320</td>\n",
984
+ " <td>1.494700</td>\n",
985
+ " </tr>\n",
986
+ " <tr>\n",
987
+ " <td>330</td>\n",
988
+ " <td>1.556800</td>\n",
989
+ " </tr>\n",
990
+ " <tr>\n",
991
+ " <td>340</td>\n",
992
+ " <td>1.648300</td>\n",
993
+ " </tr>\n",
994
+ " <tr>\n",
995
+ " <td>350</td>\n",
996
+ " <td>1.529300</td>\n",
997
+ " </tr>\n",
998
+ " <tr>\n",
999
+ " <td>360</td>\n",
1000
+ " <td>1.421200</td>\n",
1001
+ " </tr>\n",
1002
+ " <tr>\n",
1003
+ " <td>370</td>\n",
1004
+ " <td>1.483900</td>\n",
1005
+ " </tr>\n",
1006
+ " <tr>\n",
1007
+ " <td>380</td>\n",
1008
+ " <td>1.588400</td>\n",
1009
+ " </tr>\n",
1010
+ " <tr>\n",
1011
+ " <td>390</td>\n",
1012
+ " <td>1.442200</td>\n",
1013
+ " </tr>\n",
1014
+ " <tr>\n",
1015
+ " <td>400</td>\n",
1016
+ " <td>1.524600</td>\n",
1017
+ " </tr>\n",
1018
+ " <tr>\n",
1019
+ " <td>410</td>\n",
1020
+ " <td>1.469100</td>\n",
1021
+ " </tr>\n",
1022
+ " <tr>\n",
1023
+ " <td>420</td>\n",
1024
+ " <td>1.412900</td>\n",
1025
+ " </tr>\n",
1026
+ " <tr>\n",
1027
+ " <td>430</td>\n",
1028
+ " <td>1.388300</td>\n",
1029
+ " </tr>\n",
1030
+ " <tr>\n",
1031
+ " <td>440</td>\n",
1032
+ " <td>1.414400</td>\n",
1033
+ " </tr>\n",
1034
+ " <tr>\n",
1035
+ " <td>450</td>\n",
1036
+ " <td>1.368200</td>\n",
1037
+ " </tr>\n",
1038
+ " <tr>\n",
1039
+ " <td>460</td>\n",
1040
+ " <td>1.374900</td>\n",
1041
+ " </tr>\n",
1042
+ " <tr>\n",
1043
+ " <td>470</td>\n",
1044
+ " <td>1.336500</td>\n",
1045
+ " </tr>\n",
1046
+ " <tr>\n",
1047
+ " <td>480</td>\n",
1048
+ " <td>1.294900</td>\n",
1049
+ " </tr>\n",
1050
+ " <tr>\n",
1051
+ " <td>490</td>\n",
1052
+ " <td>1.231700</td>\n",
1053
+ " </tr>\n",
1054
+ " <tr>\n",
1055
+ " <td>500</td>\n",
1056
+ " <td>1.287600</td>\n",
1057
+ " </tr>\n",
1058
+ " <tr>\n",
1059
+ " <td>510</td>\n",
1060
+ " <td>1.248500</td>\n",
1061
+ " </tr>\n",
1062
+ " <tr>\n",
1063
+ " <td>520</td>\n",
1064
+ " <td>1.220700</td>\n",
1065
+ " </tr>\n",
1066
+ " <tr>\n",
1067
+ " <td>530</td>\n",
1068
+ " <td>1.335700</td>\n",
1069
+ " </tr>\n",
1070
+ " <tr>\n",
1071
+ " <td>540</td>\n",
1072
+ " <td>1.094200</td>\n",
1073
+ " </tr>\n",
1074
+ " <tr>\n",
1075
+ " <td>550</td>\n",
1076
+ " <td>1.151400</td>\n",
1077
+ " </tr>\n",
1078
+ " <tr>\n",
1079
+ " <td>560</td>\n",
1080
+ " <td>1.215000</td>\n",
1081
+ " </tr>\n",
1082
+ " <tr>\n",
1083
+ " <td>570</td>\n",
1084
+ " <td>1.235600</td>\n",
1085
+ " </tr>\n",
1086
+ " <tr>\n",
1087
+ " <td>580</td>\n",
1088
+ " <td>1.139800</td>\n",
1089
+ " </tr>\n",
1090
+ " <tr>\n",
1091
+ " <td>590</td>\n",
1092
+ " <td>1.119600</td>\n",
1093
+ " </tr>\n",
1094
+ " <tr>\n",
1095
+ " <td>600</td>\n",
1096
+ " <td>1.148000</td>\n",
1097
+ " </tr>\n",
1098
+ " <tr>\n",
1099
+ " <td>610</td>\n",
1100
+ " <td>1.057300</td>\n",
1101
+ " </tr>\n",
1102
+ " <tr>\n",
1103
+ " <td>620</td>\n",
1104
+ " <td>1.039700</td>\n",
1105
+ " </tr>\n",
1106
+ " <tr>\n",
1107
+ " <td>630</td>\n",
1108
+ " <td>1.081300</td>\n",
1109
+ " </tr>\n",
1110
+ " <tr>\n",
1111
+ " <td>640</td>\n",
1112
+ " <td>0.960300</td>\n",
1113
+ " </tr>\n",
1114
+ " <tr>\n",
1115
+ " <td>650</td>\n",
1116
+ " <td>1.026400</td>\n",
1117
+ " </tr>\n",
1118
+ " <tr>\n",
1119
+ " <td>660</td>\n",
1120
+ " <td>1.049900</td>\n",
1121
+ " </tr>\n",
1122
+ " <tr>\n",
1123
+ " <td>670</td>\n",
1124
+ " <td>0.967600</td>\n",
1125
+ " </tr>\n",
1126
+ " <tr>\n",
1127
+ " <td>680</td>\n",
1128
+ " <td>0.902100</td>\n",
1129
+ " </tr>\n",
1130
+ " <tr>\n",
1131
+ " <td>690</td>\n",
1132
+ " <td>0.950900</td>\n",
1133
+ " </tr>\n",
1134
+ " <tr>\n",
1135
+ " <td>700</td>\n",
1136
+ " <td>0.998500</td>\n",
1137
+ " </tr>\n",
1138
+ " <tr>\n",
1139
+ " <td>710</td>\n",
1140
+ " <td>1.043500</td>\n",
1141
+ " </tr>\n",
1142
+ " <tr>\n",
1143
+ " <td>720</td>\n",
1144
+ " <td>0.877700</td>\n",
1145
+ " </tr>\n",
1146
+ " <tr>\n",
1147
+ " <td>730</td>\n",
1148
+ " <td>0.818800</td>\n",
1149
+ " </tr>\n",
1150
+ " <tr>\n",
1151
+ " <td>740</td>\n",
1152
+ " <td>0.949500</td>\n",
1153
+ " </tr>\n",
1154
+ " <tr>\n",
1155
+ " <td>750</td>\n",
1156
+ " <td>1.032200</td>\n",
1157
+ " </tr>\n",
1158
+ " <tr>\n",
1159
+ " <td>760</td>\n",
1160
+ " <td>0.813600</td>\n",
1161
+ " </tr>\n",
1162
+ " <tr>\n",
1163
+ " <td>770</td>\n",
1164
+ " <td>0.871600</td>\n",
1165
+ " </tr>\n",
1166
+ " <tr>\n",
1167
+ " <td>780</td>\n",
1168
+ " <td>0.877400</td>\n",
1169
+ " </tr>\n",
1170
+ " <tr>\n",
1171
+ " <td>790</td>\n",
1172
+ " <td>0.952400</td>\n",
1173
+ " </tr>\n",
1174
+ " <tr>\n",
1175
+ " <td>800</td>\n",
1176
+ " <td>0.819600</td>\n",
1177
+ " </tr>\n",
1178
+ " <tr>\n",
1179
+ " <td>810</td>\n",
1180
+ " <td>0.852700</td>\n",
1181
+ " </tr>\n",
1182
+ " <tr>\n",
1183
+ " <td>820</td>\n",
1184
+ " <td>0.848300</td>\n",
1185
+ " </tr>\n",
1186
+ " <tr>\n",
1187
+ " <td>830</td>\n",
1188
+ " <td>0.834200</td>\n",
1189
+ " </tr>\n",
1190
+ " <tr>\n",
1191
+ " <td>840</td>\n",
1192
+ " <td>0.900900</td>\n",
1193
+ " </tr>\n",
1194
+ " <tr>\n",
1195
+ " <td>850</td>\n",
1196
+ " <td>0.830800</td>\n",
1197
+ " </tr>\n",
1198
+ " <tr>\n",
1199
+ " <td>860</td>\n",
1200
+ " <td>0.864700</td>\n",
1201
+ " </tr>\n",
1202
+ " <tr>\n",
1203
+ " <td>870</td>\n",
1204
+ " <td>0.842200</td>\n",
1205
+ " </tr>\n",
1206
+ " <tr>\n",
1207
+ " <td>880</td>\n",
1208
+ " <td>0.865000</td>\n",
1209
+ " </tr>\n",
1210
+ " </tbody>\n",
1211
+ "</table><p>"
1212
+ ]
1213
+ },
1214
+ "metadata": {}
1215
+ },
1216
+ {
1217
+ "output_type": "execute_result",
1218
+ "data": {
1219
+ "text/plain": [
1220
+ "TrainOutput(global_step=880, training_loss=1.5622584277933294, metrics={'train_runtime': 525.9662, 'train_samples_per_second': 26.237, 'train_steps_per_second': 1.673, 'total_flos': 901457510400000.0, 'train_loss': 1.5622584277933294, 'epoch': 20.0})"
1221
+ ]
1222
+ },
1223
+ "metadata": {},
1224
+ "execution_count": 13
1225
+ }
1226
+ ]
1227
+ },
1228
+ {
1229
+ "cell_type": "code",
1230
+ "source": [
1231
+ "# Save the model\n",
1232
+ "trainer.save_model('med_info_model')"
1233
+ ],
1234
+ "metadata": {
1235
+ "id": "4UrH8iP0u6Cp"
1236
+ },
1237
+ "execution_count": null,
1238
+ "outputs": []
1239
+ },
1240
+ {
1241
+ "cell_type": "markdown",
1242
+ "source": [
1243
+ "### Testing"
1244
+ ],
1245
+ "metadata": {
1246
+ "id": "VhXRJT6jeTuz"
1247
+ }
1248
+ },
1249
+ {
1250
+ "cell_type": "code",
1251
+ "source": [
1252
+ "# Function to generate a response based on a user prompt (testing the model)\n",
1253
+ "def generate_response(prompt):\n",
1254
+ " inputs = tokenizer.encode(prompt, return_tensors=\"pt\").to('cuda')\n",
1255
+ " outputs = model.generate(inputs, max_length=150, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id)\n",
1256
+ "\n",
1257
+ " # Decode the generated output\n",
1258
+ " response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
1259
+ "\n",
1260
+ " # Remove the prompt from the response\n",
1261
+ " if response.startswith(prompt):\n",
1262
+ " response = response[len(prompt):].strip() # Remove the prompt from the response\n",
1263
+ "\n",
1264
+ " return response"
1265
+ ],
1266
+ "metadata": {
1267
+ "id": "JbMs8UuSu5_R"
1268
+ },
1269
+ "execution_count": null,
1270
+ "outputs": []
1271
+ },
1272
+ {
1273
+ "cell_type": "code",
1274
+ "source": [
1275
+ "# Example conversation\n",
1276
+ "user_input = \"what is desonide ointment used for\"\n",
1277
+ "bot_response = generate_response(user_input)\n",
1278
+ "print(\"Bot Response:\", bot_response)"
1279
+ ],
1280
+ "metadata": {
1281
+ "colab": {
1282
+ "base_uri": "https://localhost:8080/"
1283
+ },
1284
+ "id": "qsHAT1-uxC4_",
1285
+ "outputId": "89b73c5f-0ae9-449d-8eb4-3df1a7c146bb"
1286
+ },
1287
+ "execution_count": null,
1288
+ "outputs": [
1289
+ {
1290
+ "output_type": "stream",
1291
+ "name": "stdout",
1292
+ "text": [
1293
+ "Bot Response: desonide ointment is used to treat a variety of conditions it is used to treat allergies and other skin conditions it is also used to treat certain types of infections it is also used to treat skin infections caused by bacteria that are on skin desonide is in a class of medications called antimicrobials it works by killing bacteria that cause skin infections desonide is in a class of medications called antibiotics it works by killing bacteria that cause skin infections\n"
1294
+ ]
1295
+ }
1296
+ ]
1297
+ },
1298
+ {
1299
+ "cell_type": "code",
1300
+ "source": [
1301
+ "# Copying the model to Google Drive (optional)\n",
1302
+ "import shutil\n",
1303
+ "\n",
1304
+ "# Path to the file in Colab\n",
1305
+ "colab_file_path = '/content/med_info_model/model.safetensors'\n",
1306
+ "\n",
1307
+ "# Path to your Google Drive\n",
1308
+ "drive_file_path = '/content/drive/MyDrive'\n",
1309
+ "\n",
1310
+ "# Copy the file\n",
1311
+ "shutil.copy(colab_file_path, drive_file_path)"
1312
+ ],
1313
+ "metadata": {
1314
+ "colab": {
1315
+ "base_uri": "https://localhost:8080/",
1316
+ "height": 36
1317
+ },
1318
+ "id": "aP4IEboMxDWG",
1319
+ "outputId": "c00d1d74-e389-4de4-a151-d20736b6bccd"
1320
+ },
1321
+ "execution_count": null,
1322
+ "outputs": [
1323
+ {
1324
+ "output_type": "execute_result",
1325
+ "data": {
1326
+ "text/plain": [
1327
+ "'/content/drive/MyDrive/model.safetensors'"
1328
+ ],
1329
+ "application/vnd.google.colaboratory.intrinsic+json": {
1330
+ "type": "string"
1331
+ }
1332
+ },
1333
+ "metadata": {},
1334
+ "execution_count": 22
1335
+ }
1336
+ ]
1337
+ },
1338
+ {
1339
+ "cell_type": "code",
1340
+ "source": [],
1341
+ "metadata": {
1342
+ "id": "uKYwYe5XyXgx"
1343
+ },
1344
+ "execution_count": null,
1345
+ "outputs": []
1346
+ }
1347
+ ]
1348
+ }
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, jsonify, request, send_from_directory
2
+ from backend.utils import (
3
+ generate_counseling_response,
4
+ generate_medication_response,
5
+ classify_diabetes,
6
+ classify_medicine,
7
+ get_llama_response
8
+ )
9
+ import os
10
+
11
+ app = Flask(__name__, static_folder='frontend', template_folder='frontend')
12
+
13
+ # Serve the main HTML file for the frontend
14
+ @app.route('/')
15
+ def index():
16
+ return send_from_directory(app.static_folder, 'index.html')
17
+
18
+
19
+ # Serve the CSS files
20
+ @app.route('/styles.css')
21
+ def styles():
22
+ return send_from_directory(app.static_folder, 'styles.css')
23
+
24
+
25
+ # Serve the JavaScript files
26
+ @app.route('/script.js')
27
+ def script():
28
+ return send_from_directory(app.static_folder, 'script.js')
29
+
30
+
31
+ # Route for Counseling Model
32
+ @app.route('/api/counseling', methods=['POST'])
33
+ def counseling():
34
+ data = request.json
35
+ question = data.get('question')
36
+ if not question:
37
+ return jsonify({"error": "Question is required."}), 400
38
+
39
+ response = generate_counseling_response(question)
40
+ return jsonify({"response": response})
41
+
42
+
43
+ # Route for Medication Info Model
44
+ @app.route('/api/medication', methods=['POST'])
45
+ def medication():
46
+ data = request.json
47
+ question = data.get('question')
48
+ if not question:
49
+ return jsonify({"error": "Question is required."}), 400
50
+
51
+ response = generate_medication_response(question)
52
+ return jsonify({"response": response})
53
+
54
+
55
+ # Route for Diabetes Classification
56
+ @app.route('/api/diabetes_classification', methods=['POST'])
57
+ def diabetes_classification():
58
+ data = request.json
59
+
60
+ # Extract input features
61
+ glucose = data.get('glucose')
62
+ bmi = data.get('bmi')
63
+ age = data.get('age')
64
+
65
+ # Validate input data
66
+ if glucose is None or bmi is None or age is None:
67
+ return jsonify({"error": "Please provide glucose, bmi, and age."}), 400
68
+
69
+ result = classify_diabetes(glucose, bmi, age)
70
+ return jsonify({"result": result})
71
+
72
+
73
+ # Route for Medicine Classification
74
+ @app.route('/api/medicine_classification', methods=['POST'])
75
+ def medicine_classification():
76
+ data = request.json
77
+
78
+ # Extract input features
79
+ age = data.get('age')
80
+ gender = data.get('gender')
81
+ blood_type = data.get('blood_type')
82
+ medical_condition = data.get('medical_condition')
83
+ test_results = data.get('test_results')
84
+
85
+ # Validate input data
86
+ if not (age and gender and blood_type and medical_condition and test_results):
87
+ return jsonify({"error": "Please provide Age, Gender, Blood Type, Medical Condition, and Test Results."}), 400
88
+
89
+ # Prepare the new data as a DataFrame
90
+ new_data = {
91
+ 'Age': [int(age)],
92
+ 'Gender': [gender],
93
+ 'Blood Type': [blood_type],
94
+ 'Medical Condition': [medical_condition],
95
+ 'Test Results': [test_results]
96
+ }
97
+
98
+ # Call the classification function
99
+ medicine = classify_medicine(new_data)
100
+ return jsonify({"medicine": medicine[0]})
101
+
102
+
103
+
104
+ # Route for General Chat (Llama 3.1 API using Groq Cloud)
105
+ @app.route('/api/general', methods=['POST'])
106
+ def general_chat():
107
+ data = request.json
108
+ question = data.get('question')
109
+ if not question:
110
+ return jsonify({"error": "Question is required."}), 400
111
+
112
+ # Get formatted response from LLaMA 3.1 hosted on Groq Cloud
113
+ llama_response = get_llama_response(question)
114
+ return jsonify({"response": llama_response})
115
+
116
+
117
+ if __name__ == '__main__':
118
+ app.run(debug=True)
backend/__init__.py ADDED
File without changes
backend/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (139 Bytes). View file
 
backend/__pycache__/utils.cpython-312.pyc ADDED
Binary file (5.4 kB). View file
 
backend/models/diabetes_model/random_forest_modelf.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a85f94f5f6d3042ac7b03513b38ff6472eb3bbdaa5d9d218734398f86f32a2b0
3
+ size 2412153
backend/models/diabetes_model/standard_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06c474ca12ad0be6d4364a5c3e791799deae88319690b39853b21321472e9483
3
+ size 671
backend/models/medication_classification_model/age_scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11e23987d039e63910799d5b655f1854e7a771f976aa9616e283bc560cf8a05a
3
+ size 927
backend/models/medication_classification_model/knn_model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1f2d56e390c7f035c379fac579590e938b8d6559d294c86107981923d3c1a45
3
+ size 4493526
backend/models/medication_classification_model/label_encoders.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:054ccfdced7e079ad079278f88e4694501e734aa63c2b7b2704c76eda9157a89
3
+ size 1576
backend/models/medication_classification_model/medication_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67594b193e6f4a44315776b4694012fd38e5ec157ea649091d3758135c3b2dfb
3
+ size 597
backend/models/medication_info/config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "gpt2",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "bos_token_id": 50256,
9
+ "embd_pdrop": 0.1,
10
+ "eos_token_id": 50256,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 1024,
15
+ "n_embd": 768,
16
+ "n_head": 12,
17
+ "n_inner": null,
18
+ "n_layer": 12,
19
+ "n_positions": 1024,
20
+ "reorder_and_upcast_attn": false,
21
+ "resid_pdrop": 0.1,
22
+ "scale_attn_by_inverse_layer_idx": false,
23
+ "scale_attn_weights": true,
24
+ "summary_activation": null,
25
+ "summary_first_dropout": 0.1,
26
+ "summary_proj_to_labels": true,
27
+ "summary_type": "cls_index",
28
+ "summary_use_proj": true,
29
+ "task_specific_params": {
30
+ "text-generation": {
31
+ "do_sample": true,
32
+ "max_length": 50
33
+ }
34
+ },
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.44.2",
37
+ "use_cache": true,
38
+ "vocab_size": 50257
39
+ }
backend/models/medication_info/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.44.2"
6
+ }
backend/models/medication_info/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
backend/models/medication_info/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:796a9391740bb9884e37b9f11b4c0d9f57c06941f057fa6345536d13c771e810
3
+ size 497774208
backend/models/medication_info/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
backend/models/medication_info/tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "50256": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ }
13
+ },
14
+ "bos_token": "<|endoftext|>",
15
+ "clean_up_tokenization_spaces": true,
16
+ "eos_token": "<|endoftext|>",
17
+ "errors": "replace",
18
+ "model_max_length": 1024,
19
+ "pad_token": "<|endoftext|>",
20
+ "tokenizer_class": "GPT2Tokenizer",
21
+ "unk_token": "<|endoftext|>"
22
+ }
backend/models/medication_info/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3765ef35134499414f1f0ec4f0439ae47cdf23380f5535f5092007be173d31c
3
+ size 5112
backend/models/medication_info/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
backend/models/mental_health_model/config.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "gpt2",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "GPT2LMHeadModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "bos_token_id": 50256,
9
+ "embd_pdrop": 0.1,
10
+ "eos_token_id": 50256,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 1024,
15
+ "n_embd": 768,
16
+ "n_head": 12,
17
+ "n_inner": null,
18
+ "n_layer": 12,
19
+ "n_positions": 1024,
20
+ "reorder_and_upcast_attn": false,
21
+ "resid_pdrop": 0.1,
22
+ "scale_attn_by_inverse_layer_idx": false,
23
+ "scale_attn_weights": true,
24
+ "summary_activation": null,
25
+ "summary_first_dropout": 0.1,
26
+ "summary_proj_to_labels": true,
27
+ "summary_type": "cls_index",
28
+ "summary_use_proj": true,
29
+ "task_specific_params": {
30
+ "text-generation": {
31
+ "do_sample": true,
32
+ "max_length": 50
33
+ }
34
+ },
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.44.2",
37
+ "use_cache": true,
38
+ "vocab_size": 50257
39
+ }
backend/models/mental_health_model/generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "transformers_version": "4.44.2"
6
+ }
backend/models/mental_health_model/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
backend/models/mental_health_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eea02c798a4efdadb3ecf163a411a12f393d9ac30c9f5019348a65a666dabdbc
3
+ size 497774208
backend/models/mental_health_model/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
backend/models/mental_health_model/tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "50256": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ }
13
+ },
14
+ "bos_token": "<|endoftext|>",
15
+ "clean_up_tokenization_spaces": true,
16
+ "eos_token": "<|endoftext|>",
17
+ "errors": "replace",
18
+ "model_max_length": 1024,
19
+ "pad_token": "<|endoftext|>",
20
+ "tokenizer_class": "GPT2Tokenizer",
21
+ "unk_token": "<|endoftext|>"
22
+ }
backend/models/mental_health_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fe4d3b4fb9feb46fdfc5a116e608dc00a98241105ed35a3cc8d220ee6e20886
3
+ size 5112
backend/models/mental_health_model/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
backend/utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # backend/utils.py
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
+ from langchain_groq import ChatGroq
4
+ import torch
5
+ import requests
6
+ import joblib
7
+ import pandas as pd
8
+
9
+ # Load the trained model and tokenizer : Counselling
10
+ counseling_model = GPT2LMHeadModel.from_pretrained('backend\\models\\mental_health_model')
11
+ counselling_tokenizer = GPT2Tokenizer.from_pretrained('backend\\models\\mental_health_model')
12
+
13
+ # Load the trained model and tokenizer : Medication
14
+ medication_model = GPT2LMHeadModel.from_pretrained('backend\\models\\medication_info')
15
+ medication_tokenizer = GPT2Tokenizer.from_pretrained('backend\\models\\medication_info')
16
+
17
+ # Load the trained Random Forest model and StandardScaler
18
+ diabetes_model = joblib.load('backend\\models\\diabetes_model\\random_forest_modelf.joblib')
19
+ diabetes_scaler = joblib.load('backend\\models\\diabetes_model\\standard_scaler.joblib')
20
+
21
+ # Load the model, encoders, and scaler
22
+ knn = joblib.load('backend\\models\\medication_classification_model\\knn_model.pkl')
23
+ label_encoders = joblib.load('backend\\models\\medication_classification_model\\label_encoders.pkl')
24
+ age_scaler = joblib.load('backend\\models\\medication_classification_model\\age_scaler.pkl')
25
+ medication_encoder = joblib.load('backend\\models\\medication_classification_model\\medication_encoder.pkl')
26
+
27
+
28
+
29
+
30
+ # Diabetes Classifier
31
+ def classify_diabetes(glucose, bmi, age):
32
+ # Normalize the input features
33
+ input_features = [[glucose, bmi, age]]
34
+ input_features_norm = diabetes_scaler.transform(input_features)
35
+
36
+ # Make predictions
37
+ prediction = diabetes_model.predict(input_features_norm)[0]
38
+ prediction_probability = diabetes_model.predict_proba(input_features_norm)[0] * 100
39
+
40
+ diabetic_probability = prediction_probability[prediction].item()
41
+
42
+ if prediction == 0:
43
+ result = "Non Diabetic"
44
+ else:
45
+ result = "Diabetic"
46
+
47
+ # Format the output as: "Non Diabetic | 72%"
48
+ formatted_result = f"{result} | {diabetic_probability:.1f}%"
49
+ return formatted_result
50
+
51
+
52
+ # Medicine Classifier
53
+ def classify_medicine(new_data):
54
+ # Convert dictionary to DataFrame
55
+ new_data_df = pd.DataFrame(new_data)
56
+
57
+ # Encode the new data using the saved label encoders
58
+ for column in ['Gender', 'Blood Type', 'Medical Condition', 'Test Results']:
59
+ new_data_df[column] = label_encoders[column].transform(new_data_df[column])
60
+
61
+ # Normalize the 'Age' column in the new data
62
+ new_data_df['Age'] = age_scaler.transform(new_data_df[['Age']])
63
+
64
+ # Make predictions
65
+ predictions = knn.predict(new_data_df)
66
+
67
+ # Decode the predictions back to the original medication names
68
+ predicted_medications = medication_encoder.inverse_transform(predictions)
69
+
70
+ return predicted_medications
71
+
72
+
73
+ # Generate Counseling Response
74
+ def generate_counseling_response(prompt):
75
+ inputs = counselling_tokenizer.encode(prompt, return_tensors="pt")
76
+ outputs = counseling_model.generate(inputs, max_length=150, num_return_sequences=1, pad_token_id=counselling_tokenizer.eos_token_id)
77
+
78
+ # Decode the generated output
79
+ response = counselling_tokenizer.decode(outputs[0], skip_special_tokens=True)
80
+
81
+ # Remove the prompt from the response
82
+ if response.startswith(prompt):
83
+ response = response[len(prompt):].strip() # Remove the prompt from the response
84
+
85
+ return response
86
+
87
+
88
+ # Generate Medication Response
89
+ def generate_medication_response(prompt):
90
+ inputs = medication_tokenizer.encode(prompt, return_tensors="pt")
91
+ outputs = medication_model.generate(inputs, max_length=150, num_return_sequences=1, pad_token_id=medication_tokenizer.eos_token_id)
92
+
93
+ # Decode the generated output
94
+ response = medication_tokenizer.decode(outputs[0], skip_special_tokens=True)
95
+
96
+ # Remove the prompt from the response
97
+ if response.startswith(prompt):
98
+ response = response[len(prompt):].strip() # Remove the prompt from the response
99
+
100
+ return response
101
+
102
+
103
+ # Llama 3.1 Integration as a General Tab
104
+ llm = ChatGroq(
105
+ temperature=0,
106
+ groq_api_key='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',
107
+ model_name="llama-3.1-70b-versatile"
108
+ )
109
+
110
+ def get_llama_response(prompt):
111
+ try:
112
+ response = llm.invoke(prompt)
113
+ formatted_response = format_response(response.content)
114
+ return formatted_response
115
+ except Exception as e:
116
+ return f"Error: {str(e)}"
117
+
118
+ def format_response(response):
119
+ # Add line breaks and make it easier to read
120
+ response = response.replace("**", "").replace("*", "").replace(" ", "\n").strip()
121
+ lines = response.split("\n")
122
+ formatted_response = ""
123
+ for line in lines:
124
+ formatted_response += f"<p>{line}</p>"
125
+ return formatted_response
frontend/index.html ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>AHA</title>
7
+ <link rel="stylesheet" href="/styles.css">
8
+ </head>
9
+ <body>
10
+ <div class="container">
11
+ <h1>AI Health Assistant</h1>
12
+ <div class="tabs">
13
+ <button class="tab-button" onclick="showTab('counseling')">Counseling Chat</button>
14
+ <button class="tab-button" onclick="showTab('medication')">Medication Chat</button>
15
+ <button class="tab-button" onclick="showTab('general')">General Chat</button>
16
+ <button class="tab-button" onclick="showTab('diabetes')">Diabetes Classification</button>
17
+ <button class="tab-button" onclick="showTab('medicine')">Medicine Classification</button>
18
+ </div>
19
+ <div id="counseling" class="tab-content">
20
+ <textarea id="counseling-question" placeholder="Ask your health problem here..."></textarea>
21
+ <button onclick="submitCounseling()">Generate</button>
22
+ <p id="counseling-response"></p>
23
+ </div>
24
+ <div id="medication" class="tab-content">
25
+ <textarea id="medication-question" placeholder="Ask your medicine here..."></textarea>
26
+ <button onclick="submitMedication()">Generate</button>
27
+ <p id="medication-response"></p>
28
+ </div>
29
+ <div id="diabetes" class="tab-content">
30
+ <input type="number" id="glucose" placeholder="Glucose Level" required>
31
+ <input type="number" id="bmi" placeholder="BMI" required>
32
+ <input type="number" id="age-diabetes" placeholder="Age" required>
33
+ <button onclick="submitDiabetes()">Submit</button>
34
+ <p id="diabetes-response"></p>
35
+ </div>
36
+ <div id="medicine" class="tab-content">
37
+ <input type="number" id="age" placeholder="Age" required>
38
+ <select id="gender" required>
39
+ <option value="" disabled selected>Gender</option>
40
+ <option value="Male">Male</option>
41
+ <option value="Female">Female</option>
42
+ </select>
43
+ <select id="blood-type" required>
44
+ <option value="" disabled selected>Blood Group</option>
45
+ <option value="A+">A+</option>
46
+ <option value="A-">A-</option>
47
+ <option value="B+">B+</option>
48
+ <option value="B-">B-</option>
49
+ <option value="AB+">AB+</option>
50
+ <option value="AB-">AB-</option>
51
+ <option value="O+">O+</option>
52
+ <option value="O-">O-</option>
53
+ </select>
54
+ <select id="medical-condition" required>
55
+ <option value="" disabled selected>Medical Condition</option>
56
+ <option value="Cancer">Cancer</option>
57
+ <option value="Arthritis">Arthritis</option>
58
+ <option value="Diabetes">Diabetes</option>
59
+ <option value="Hypertension">Hypertension</option>
60
+ <option value="Obesity">Obesity</option>
61
+ <option value="Asthma">Asthma</option>
62
+ </select>
63
+ <select id="test-results" required>
64
+ <option value="" disabled selected>Test Results</option>
65
+ <option value="Normal">Normal</option>
66
+ <option value="Abnormal">Abnormal</option>
67
+ <option value="Inconclusive">Inconclusive</option>
68
+ </select>
69
+ <button onclick="submitMedicine()">Submit</button>
70
+ <p id="medicine-response"></p>
71
+ </div>
72
+ <div id="general" class="tab-content">
73
+ <textarea id="general-question" placeholder="Ask your question here..."></textarea>
74
+ <button onclick="submitGeneral()">Generate</button>
75
+ <p id="general-response"></p>
76
+ </div>
77
+ </div>
78
+ <script src="/script.js"></script>
79
+ </body>
80
+ </html>
frontend/script.js ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ function showTab(tabId) {
2
+ const tabs = document.querySelectorAll('.tab-content');
3
+ tabs.forEach(tab => {
4
+ if (tab.id === tabId) {
5
+ tab.classList.add('active');
6
+ } else {
7
+ tab.classList.remove('active');
8
+ }
9
+ });
10
+ }
11
+
12
+ function submitCounseling() {
13
+ const question = document.getElementById('counseling-question').value;
14
+ fetch('/api/counseling', {
15
+ method: 'POST',
16
+ headers: { 'Content-Type': 'application/json' },
17
+ body: JSON.stringify({ question })
18
+ })
19
+ .then(response => response.json())
20
+ .then(data => {
21
+ document.getElementById('counseling-response').innerText = data.response;
22
+ })
23
+ .catch(error => console.error('Error:', error));
24
+ }
25
+
26
+ function submitMedication() {
27
+ const question = document.getElementById('medication-question').value;
28
+ fetch('/api/medication', {
29
+ method: 'POST',
30
+ headers: { 'Content-Type': 'application/json' },
31
+ body: JSON.stringify({ question })
32
+ })
33
+ .then(response => response.json())
34
+ .then(data => {
35
+ document.getElementById('medication-response').innerText = data.response;
36
+ })
37
+ .catch(error => console.error('Error:', error));
38
+ }
39
+
40
+ function submitDiabetes() {
41
+ const glucose = document.getElementById('glucose').value;
42
+ const bmi = document.getElementById('bmi').value;
43
+ const age = document.getElementById('age-diabetes').value;
44
+ fetch('/api/diabetes_classification', {
45
+ method: 'POST',
46
+ headers: { 'Content-Type': 'application/json' },
47
+ body: JSON.stringify({ glucose, bmi, age })
48
+ })
49
+ .then(response => response.json())
50
+ .then(data => {
51
+ document.getElementById('diabetes-response').innerText = data.result;
52
+ })
53
+ .catch(error => console.error('Error:', error));
54
+ }
55
+
56
+ function submitMedicine() {
57
+ const age = document.getElementById('age').value;
58
+ const gender = document.getElementById('gender').value;
59
+ const bloodType = document.getElementById('blood-type').value;
60
+ const medicalCondition = document.getElementById('medical-condition').value;
61
+ const testResults = document.getElementById('test-results').value;
62
+
63
+ fetch('/api/medicine_classification', {
64
+ method: 'POST',
65
+ headers: { 'Content-Type': 'application/json' },
66
+ body: JSON.stringify({ age, gender, blood_type: bloodType, medical_condition: medicalCondition, test_results: testResults })
67
+ })
68
+ .then(response => response.json())
69
+ .then(data => {
70
+ document.getElementById('medicine-response').innerText = data.medicine;
71
+ })
72
+ .catch(error => console.error('Error:', error));
73
+ }
74
+
75
+ function submitGeneral() {
76
+ const question = document.getElementById('general-question').value;
77
+ fetch('/api/general', {
78
+ method: 'POST',
79
+ headers: { 'Content-Type': 'application/json' },
80
+ body: JSON.stringify({ question })
81
+ })
82
+ .then(response => response.json())
83
+ .then(data => {
84
+ document.getElementById('general-response').innerText = data.response;
85
+ })
86
+ .catch(error => console.error('Error:', error));
87
+ }
frontend/styles.css ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ font-family: Arial, sans-serif;
3
+ margin: 0;
4
+ padding: 0;
5
+ background: #121212; /* Dark background for body */
6
+ color: #e0e0e0; /* Light text color */
7
+ }
8
+
9
+ .container {
10
+ width: 80%;
11
+ margin: 0 auto;
12
+ padding: 20px;
13
+ }
14
+
15
+ h1 {
16
+ text-align: center;
17
+ color: #ffffff; /* Light color for headers */
18
+ }
19
+
20
+ .tabs {
21
+ display: flex;
22
+ justify-content: space-around;
23
+ margin-bottom: 20px;
24
+ }
25
+
26
+ .tab-button {
27
+ padding: 10px 20px;
28
+ background: #054c66; /* Dark background for tab buttons */
29
+ color: #ffffff; /* Light text color */
30
+ border: none;
31
+ cursor: pointer;
32
+ border-radius: 5px;
33
+ transition: background 0.3s;
34
+ }
35
+
36
+ .tab-button:hover {
37
+ background: #226f90; /* Slightly lighter background on hover */
38
+ }
39
+
40
+ .tab-content {
41
+ display: none;
42
+ animation: fadeIn 0.5s;
43
+ }
44
+
45
+ .tab-content.active {
46
+ display: block;
47
+ }
48
+
49
+ textarea, input, select {
50
+ display: block;
51
+ width: 100%;
52
+ padding: 10px;
53
+ margin-bottom: 10px;
54
+ border: 1px solid #444; /* Darker border color */
55
+ border-radius: 5px;
56
+ background: #1e1e1e; /* Dark background for inputs */
57
+ color: #e0e0e0; /* Light text color */
58
+ transition: border-color 0.3s;
59
+ }
60
+
61
+ textarea:focus, input:focus, select:focus {
62
+ border-color: #00b3ff; /* Highlight border on focus */
63
+ }
64
+
65
+ button {
66
+ padding: 10px 20px;
67
+ background: #0b979e; /* Dark green background */
68
+ color: #fff;
69
+ border: none;
70
+ cursor: pointer;
71
+ border-radius: 5px;
72
+ transition: background 0.3s;
73
+ }
74
+
75
+ button:hover {
76
+ background: #417a7d; /* Slightly darker green on hover */
77
+ }
78
+
79
+ p {
80
+ font-size: 1.2em;
81
+ color: #e0e0e0; /* Light text color for paragraphs */
82
+ margin-top: 10px;
83
+ }
84
+
85
+ /* Keyframes for fade-in animation */
86
+ @keyframes fadeIn {
87
+ from { opacity: 0; }
88
+ to { opacity: 1; }
89
+ }