{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "a940b50d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Python Platform: macOS-13.4-arm64-arm-64bit\n", "Tensor Flow Version: 2.12.0\n", "Keras Version: 2.12.0\n", "\n", "Python 3.10.9 (main, Mar 1 2023, 12:20:14) [Clang 14.0.6 ]\n", "Pandas 2.0.3\n", "Scikit-Learn 1.3.0\n", "SciPy 1.11.1\n", "GPU is available\n" ] } ], "source": [ "import sys\n", "import tensorflow.keras\n", "import pandas as pd\n", "import sklearn as sk\n", "import scipy as sp\n", "import tensorflow as tf\n", "import platform\n", "print(f\"Python Platform: {platform.platform()}\")\n", "print(f\"Tensor Flow Version: {tf.__version__}\")\n", "print(f\"Keras Version: {tensorflow.keras.__version__}\")\n", "print()\n", "print(f\"Python {sys.version}\")\n", "print(f\"Pandas {pd.__version__}\")\n", "print(f\"Scikit-Learn {sk.__version__}\")\n", "print(f\"SciPy {sp.__version__}\")\n", "gpu = len(tf.config.list_physical_devices('GPU'))>0\n", "print(\"GPU is\", \"available\" if gpu else \"NOT AVAILABLE\")" ] }, { "cell_type": "markdown", "id": "ca1dbcff", "metadata": {}, "source": [ "**Based on the paper:** [*PubMed 200k RCT: a Dataset for Sequential Sentence Classification in Medical Abstracts*](https://arxiv.org/pdf/1612.05251.pdf) (2017)." ] }, { "cell_type": "code", "execution_count": 5, "id": "865602b3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "fatal: destination path 'pubmed-rct' already exists and is not an empty directory.\n", "\u001b[34mPubMed_200k_RCT\u001b[m\u001b[m\n", "\u001b[34mPubMed_200k_RCT_numbers_replaced_with_at_sign\u001b[m\u001b[m\n", "\u001b[34mPubMed_20k_RCT\u001b[m\u001b[m\n", "\u001b[34mPubMed_20k_RCT_numbers_replaced_with_at_sign\u001b[m\u001b[m\n", "README.md\n", "helper_functions.py\n" ] } ], "source": [ "!git clone https://github.com/Franck-Dernoncourt/pubmed-rct.git\n", "!ls pubmed-rct" ] }, { "cell_type": "code", 
"execution_count": 6, "id": "49999d88", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dev.txt test.txt train.txt train.zip\r\n" ] } ], "source": [ "!ls pubmed-rct/PubMed_200k_RCT_numbers_replaced_with_at_sign/" ] }, { "cell_type": "code", "execution_count": 8, "id": "0eb48880", "metadata": {}, "outputs": [], "source": [ "data_dir=\"pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "0ea6b9e4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/dev.txt',\n", " 'pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/train.txt',\n", " 'pubmed-rct/PubMed_20k_RCT_numbers_replaced_with_at_sign/test.txt']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "filenames= [data_dir+ filename for filename in os.listdir(data_dir)]\n", "filenames" ] }, { "cell_type": "code", "execution_count": 10, "id": "4e578219", "metadata": {}, "outputs": [], "source": [ "def get_lines(filename):\n", " with open(filename) as f:\n", " return f.readlines()" ] }, { "cell_type": "code", "execution_count": 11, "id": "2a709290", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['###24293578\\n',\n", " 'OBJECTIVE\\tTo investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( OA ) .\\n',\n", " 'METHODS\\tA total of @ patients with primary knee OA were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .\\n',\n", " 'METHODS\\tOutcome measures included pain reduction and improvement in function scores and systemic inflammation markers .\\n',\n", " 'METHODS\\tPain was assessed using the visual analog pain scale ( @-@ mm ) 
.\\n',\n", " 'METHODS\\tSecondary outcome measures included the Western Ontario and McMaster Universities Osteoarthritis Index scores , patient global assessment ( PGA ) of the severity of knee OA , and @-min walk distance ( @MWD ) .\\n',\n", " 'METHODS\\tSerum levels of interleukin @ ( IL-@ ) , IL-@ , tumor necrosis factor ( TNF ) - , and high-sensitivity C-reactive protein ( hsCRP ) were measured .\\n',\n", " 'RESULTS\\tThere was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , PGA , and @MWD at @ weeks .\\n',\n", " 'RESULTS\\tThe mean difference between treatment arms ( @ % CI ) was @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; and @ ( @-@ @ ) , p < @ , respectively .\\n',\n", " 'RESULTS\\tFurther , there was a clinically relevant reduction in the serum levels of IL-@ , IL-@ , TNF - , and hsCRP at @ weeks in the intervention group when compared to the placebo group .\\n',\n", " 'RESULTS\\tThese differences remained significant at @ weeks .\\n',\n", " 'RESULTS\\tThe Outcome Measures in Rheumatology Clinical Trials-Osteoarthritis Research Society International responder rate was @ % in the intervention group and @ % in the placebo group ( p < @ ) .\\n',\n", " 'CONCLUSIONS\\tLow-dose oral prednisolone had both a short-term and a longer sustained effect resulting in less knee pain , better physical function , and attenuation of systemic inflammation in older patients with knee OA ( ClinicalTrials.gov identifier NCT@ ) .\\n',\n", " '\\n',\n", " '###24854809\\n',\n", " 'BACKGROUND\\tEmotional eating is associated with overeating and the development of obesity .\\n',\n", " 'BACKGROUND\\tYet , empirical evidence for individual ( trait ) differences in emotional eating and cognitive mechanisms that contribute to eating during sad mood remain equivocal .\\n',\n", " 'OBJECTIVE\\tThe aim of this study was to test if attention bias for food moderates the effect of self-reported 
emotional eating during sad mood ( vs neutral mood ) on actual food intake .\\n',\n", " 'OBJECTIVE\\tIt was expected that emotional eating is predictive of elevated attention for food and higher food intake after an experimentally induced sad mood and that attentional maintenance on food predicts food intake during a sad versus a neutral mood .\\n',\n", " 'METHODS\\tParticipants ( N = @ ) were randomly assigned to one of the two experimental mood induction conditions ( sad/neutral ) .\\n']" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_lines=get_lines(data_dir+'train.txt')\n", "train_lines[:20]" ] }, { "cell_type": "code", "execution_count": 12, "id": "c870e8a5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "210040" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train_lines)" ] }, { "cell_type": "markdown", "id": "6828c4df", "metadata": {}, "source": [ "# Data Representation\n", "\n", "[{ line number:0\n", " text:\n", " target:Background\n", " total_lines:11\n", " }]" ] }, { "cell_type": "code", "execution_count": 13, "id": "fd07d2de", "metadata": {}, "outputs": [], "source": [ "def pre_processing_data(filename):\n", " input_lines=get_lines(filename)\n", " abstract_lines=\"\"\n", " abstract_samples=[]\n", " \n", " for line in input_lines:\n", " if line.startswith(\"###\"):\n", " abstract_id=line\n", " abstract_lines=\"\"\n", " elif line.isspace():\n", " abstract_line_split=abstract_lines.splitlines()\n", " for abstract_line_number,abstract_line in enumerate(abstract_line_split):\n", " line_data={}\n", " target_split=abstract_line.split(\"\\t\")\n", " line_data[\"target\"]=target_split[0]\n", " line_data[\"text\"]=target_split[1].lower()\n", " line_data[\"line_number\"]=abstract_line_number\n", " line_data[\"total_lines\"]=len(abstract_line_split)-1\n", " abstract_samples.append(line_data)\n", " else:\n", " abstract_lines+=line\n", " return 
abstract_samples" ] }, { "cell_type": "code", "execution_count": 14, "id": "076e7b10", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 301 ms, sys: 47.9 ms, total: 349 ms\n", "Wall time: 359 ms\n" ] } ], "source": [ "%%time\n", "train_samples= pre_processing_data(data_dir+\"train.txt\")\n", "val_samples=pre_processing_data(data_dir+\"dev.txt\")\n", "test_samples=pre_processing_data(data_dir+\"test.txt\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "5c234a1e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'target': 'OBJECTIVE',\n", " 'text': 'to investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( oa ) .',\n", " 'line_number': 0,\n", " 'total_lines': 11},\n", " {'target': 'METHODS',\n", " 'text': 'a total of @ patients with primary knee oa were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .',\n", " 'line_number': 1,\n", " 'total_lines': 11},\n", " {'target': 'METHODS',\n", " 'text': 'outcome measures included pain reduction and improvement in function scores and systemic inflammation markers .',\n", " 'line_number': 2,\n", " 'total_lines': 11},\n", " {'target': 'METHODS',\n", " 'text': 'pain was assessed using the visual analog pain scale ( @-@ mm ) .',\n", " 'line_number': 3,\n", " 'total_lines': 11},\n", " {'target': 'METHODS',\n", " 'text': 'secondary outcome measures included the western ontario and mcmaster universities osteoarthritis index scores , patient global assessment ( pga ) of the severity of knee oa , and @-min walk distance ( @mwd ) .',\n", " 'line_number': 4,\n", " 'total_lines': 11},\n", " {'target': 'METHODS',\n", " 'text': 'serum levels of interleukin @ ( il-@ ) , il-@ , tumor necrosis factor ( tnf ) - , and 
high-sensitivity c-reactive protein ( hscrp ) were measured .',\n", " 'line_number': 5,\n", " 'total_lines': 11},\n", " {'target': 'RESULTS',\n", " 'text': 'there was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , pga , and @mwd at @ weeks .',\n", " 'line_number': 6,\n", " 'total_lines': 11},\n", " {'target': 'RESULTS',\n", " 'text': 'the mean difference between treatment arms ( @ % ci ) was @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; and @ ( @-@ @ ) , p < @ , respectively .',\n", " 'line_number': 7,\n", " 'total_lines': 11},\n", " {'target': 'RESULTS',\n", " 'text': 'further , there was a clinically relevant reduction in the serum levels of il-@ , il-@ , tnf - , and hscrp at @ weeks in the intervention group when compared to the placebo group .',\n", " 'line_number': 8,\n", " 'total_lines': 11},\n", " {'target': 'RESULTS',\n", " 'text': 'these differences remained significant at @ weeks .',\n", " 'line_number': 9,\n", " 'total_lines': 11},\n", " {'target': 'RESULTS',\n", " 'text': 'the outcome measures in rheumatology clinical trials-osteoarthritis research society international responder rate was @ % in the intervention group and @ % in the placebo group ( p < @ ) .',\n", " 'line_number': 10,\n", " 'total_lines': 11},\n", " {'target': 'CONCLUSIONS',\n", " 'text': 'low-dose oral prednisolone had both a short-term and a longer sustained effect resulting in less knee pain , better physical function , and attenuation of systemic inflammation in older patients with knee oa ( clinicaltrials.gov identifier nct@ ) .',\n", " 'line_number': 11,\n", " 'total_lines': 11},\n", " {'target': 'BACKGROUND',\n", " 'text': 'emotional eating is associated with overeating and the development of obesity .',\n", " 'line_number': 0,\n", " 'total_lines': 10},\n", " {'target': 'BACKGROUND',\n", " 'text': 'yet , empirical evidence for individual ( trait ) differences in emotional eating and 
cognitive mechanisms that contribute to eating during sad mood remain equivocal .',\n", " 'line_number': 1,\n", " 'total_lines': 10}]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_samples[:14]" ] }, { "cell_type": "code", "execution_count": 16, "id": "bffbe3ea", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "train_df=pd.DataFrame(train_samples)\n", "val_df=pd.DataFrame(val_samples)\n", "test_df=pd.DataFrame(test_samples)" ] }, { "cell_type": "code", "execution_count": 17, "id": "b9d2b25f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
targettextline_numbertotal_lines
0OBJECTIVEto investigate the efficacy of @ weeks of dail...011
1METHODSa total of @ patients with primary knee oa wer...111
2METHODSoutcome measures included pain reduction and i...211
3METHODSpain was assessed using the visual analog pain...311
4METHODSsecondary outcome measures included the wester...411
5METHODSserum levels of interleukin @ ( il-@ ) , il-@ ...511
6RESULTSthere was a clinically relevant reduction in t...611
7RESULTSthe mean difference between treatment arms ( @...711
8RESULTSfurther , there was a clinically relevant redu...811
9RESULTSthese differences remained significant at @ we...911
10RESULTSthe outcome measures in rheumatology clinical ...1011
11CONCLUSIONSlow-dose oral prednisolone had both a short-te...1111
12BACKGROUNDemotional eating is associated with overeating...010
13BACKGROUNDyet , empirical evidence for individual ( trai...110
\n", "
" ], "text/plain": [ " target text \\\n", "0 OBJECTIVE to investigate the efficacy of @ weeks of dail... \n", "1 METHODS a total of @ patients with primary knee oa wer... \n", "2 METHODS outcome measures included pain reduction and i... \n", "3 METHODS pain was assessed using the visual analog pain... \n", "4 METHODS secondary outcome measures included the wester... \n", "5 METHODS serum levels of interleukin @ ( il-@ ) , il-@ ... \n", "6 RESULTS there was a clinically relevant reduction in t... \n", "7 RESULTS the mean difference between treatment arms ( @... \n", "8 RESULTS further , there was a clinically relevant redu... \n", "9 RESULTS these differences remained significant at @ we... \n", "10 RESULTS the outcome measures in rheumatology clinical ... \n", "11 CONCLUSIONS low-dose oral prednisolone had both a short-te... \n", "12 BACKGROUND emotional eating is associated with overeating... \n", "13 BACKGROUND yet , empirical evidence for individual ( trai... \n", "\n", " line_number total_lines \n", "0 0 11 \n", "1 1 11 \n", "2 2 11 \n", "3 3 11 \n", "4 4 11 \n", "5 5 11 \n", "6 6 11 \n", "7 7 11 \n", "8 8 11 \n", "9 9 11 \n", "10 10 11 \n", "11 11 11 \n", "12 0 10 \n", "13 1 10 " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.head(14)" ] }, { "cell_type": "code", "execution_count": 18, "id": "c2a36602", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "target\n", "METHODS 59353\n", "RESULTS 57953\n", "CONCLUSIONS 27168\n", "BACKGROUND 21727\n", "OBJECTIVE 13839\n", "Name: count, dtype: int64" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df.target.value_counts()" ] }, { "cell_type": "code", "execution_count": 19, "id": "9cdfddf2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_df.total_lines.plot.hist()" ] }, { "cell_type": "code", "execution_count": 20, "id": "fe6844bc", "metadata": {}, "outputs": [], "source": [ "train_sentences=train_df['text'].tolist()\n", "val_sentences=val_df['text'].tolist()\n", "test_sentences=test_df['text'].tolist()" ] }, { "cell_type": "code", "execution_count": 21, "id": "65973451", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['to investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( oa ) .',\n", " 'a total of @ patients with primary knee oa were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .',\n", " 'outcome measures included pain reduction and improvement in function scores and systemic inflammation markers .',\n", " 'pain was assessed using the visual analog pain scale ( @-@ mm ) .',\n", " 'secondary outcome measures included the western ontario and mcmaster universities osteoarthritis index scores , patient global assessment ( pga ) of the severity of knee oa , and @-min walk distance ( @mwd ) .',\n", " 'serum levels of interleukin @ ( il-@ ) , il-@ , tumor necrosis factor ( tnf ) - , and high-sensitivity c-reactive protein ( hscrp ) were measured .',\n", " 'there was a clinically relevant reduction in the intervention group compared to the placebo group for knee pain , physical function , pga , and @mwd at @ weeks .',\n", " 'the mean difference between treatment arms ( @ % ci ) was @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; @ ( @-@ @ ) , p < @ ; and @ ( @-@ @ ) , p < @ , respectively .',\n", " 'further , there was a clinically relevant reduction in the serum levels of il-@ , il-@ , tnf - , and hscrp at @ weeks in the intervention group when compared to the 
placebo group .',\n", " 'these differences remained significant at @ weeks .']" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_sentences[:10]" ] }, { "cell_type": "markdown", "id": "08bb7874", "metadata": {}, "source": [ "# Text to numbers " ] }, { "cell_type": "code", "execution_count": 22, "id": "ca820e30", "metadata": {}, "outputs": [], "source": [ "#one hot encode labels\n", "from sklearn.preprocessing import OneHotEncoder\n", "one_hot_encoder=OneHotEncoder(sparse_output=False)\n", "train_labels_one_hot=one_hot_encoder.fit_transform(train_df[\"target\"].to_numpy().reshape(-1,1))\n", "val_labels_one_hot=one_hot_encoder.transform(val_df[\"target\"].to_numpy().reshape(-1,1))\n", "test_label_one_hot=one_hot_encoder.transform(test_df[\"target\"].to_numpy().reshape(-1,1))\n" ] }, { "cell_type": "code", "execution_count": 23, "id": "fda6a6c1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf.constant(train_labels_one_hot)" ] }, { "cell_type": "code", "execution_count": 44, "id": "45f80b74", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder\n", "labelencoder=LabelEncoder()\n", "train_labels_encoder=labelencoder.fit_transform(train_df[\"target\"].to_numpy())\n", "val_labels_encoder=labelencoder.transform(val_df[\"target\"].to_numpy())\n", "test_label_encoder=labelencoder.transform(test_df[\"target\"].to_numpy())\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "7bd78a9b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3, 2, 2, ..., 4, 1, 1])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_labels_encoder" ] }, { "cell_type": "code", "execution_count": 25, "id": "d0a80828", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['BACKGROUND', 'CONCLUSIONS', 'METHODS', 'OBJECTIVE', 
'RESULTS'],\n", " dtype=object)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_names=labelencoder.classes_\n", "class_names" ] }, { "cell_type": "markdown", "id": "49e9e8b8", "metadata": {}, "source": [ "# Model 0" ] }, { "cell_type": "code", "execution_count": 26, "id": "f55fb274", "metadata": {}, "outputs": [], "source": [ "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.naive_bayes import MultinomialNB\n", "from sklearn.pipeline import Pipeline\n", "\n", "model_0= Pipeline (\n", "[\n", " (\"Tdidf\", TfidfVectorizer()),\n", " (\"Naive-bayes(classification)\", MultinomialNB())\n", "])\n" ] }, { "cell_type": "code", "execution_count": 27, "id": "5289a2e2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('Tdidf', TfidfVectorizer()),\n",
       "                ('Naive-bayes(classification)', MultinomialNB())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('Tdidf', TfidfVectorizer()),\n", " ('Naive-bayes(classification)', MultinomialNB())])" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_0.fit(train_sentences,train_labels_encoder)" ] }, { "cell_type": "code", "execution_count": 28, "id": "91ce1c82", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7218323844829869" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_0.score(val_sentences,val_labels_encoder)" ] }, { "cell_type": "code", "execution_count": 29, "id": "004a4f8e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([4, 1, 3, ..., 4, 4, 1])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "baseline_preds=model_0.predict(val_sentences)\n", "baseline_preds" ] }, { "cell_type": "code", "execution_count": 30, "id": "02c365fc", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score,precision_recall_fscore_support\n", "def calculate_results(y_true,y_pred):\n", " model_accuracy=accuracy_score(y_true,y_pred)*100\n", " model_precision,recall_score,f1_score, _ = precision_recall_fscore_support(y_true,y_pred,average=\"weighted\")\n", " #Compute precision, recall, F-measure and support for each class.\n", " results={ \"accuracy\":model_accuracy,\n", " \"precision\":model_precision*100,\n", " \"recall\":recall_score*100,\n", " \"F1-score\":f1_score*100}\n", " return results\n", "\n" ] }, { "cell_type": "code", "execution_count": 31, "id": "bfdd9abb", "metadata": {}, "outputs": [], "source": [ "model_0_reults=calculate_results(val_labels_encoder,baseline_preds)" ] }, { "cell_type": "code", "execution_count": 32, "id": "bed2ff0c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 72.1832384482987,\n", " 'precision': 71.86466952323352,\n", " 'recall': 72.1832384482987,\n", " 'F1-score': 69.89250353450294}" ] }, 
"execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_0_reults" ] }, { "cell_type": "markdown", "id": "266b14e1", "metadata": {}, "source": [ "# Model-1 Conv1D with token embeddings" ] }, { "cell_type": "code", "execution_count": 33, "id": "71f0e787", "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n", "tokenize=TextVectorization(max_tokens=None,standardize='lower_and_strip_punctuation',\n", " ngrams=None,\n", " output_mode='int',\n", " output_sequence_length=None\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "id": "b5d49a8b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['to',\n", " 'investigate',\n", " 'the',\n", " 'efficacy',\n", " 'of',\n", " '@',\n", " 'weeks',\n", " 'of',\n", " 'daily',\n", " 'low-dose',\n", " 'oral',\n", " 'prednisolone',\n", " 'in',\n", " 'improving',\n", " 'pain',\n", " ',',\n", " 'mobility',\n", " ',',\n", " 'and',\n", " 'systemic',\n", " 'low-grade',\n", " 'inflammation',\n", " 'in',\n", " 'the',\n", " 'short',\n", " 'term',\n", " 'and',\n", " 'whether',\n", " 'the',\n", " 'effect',\n", " 'would',\n", " 'be',\n", " 'sustained',\n", " 'at',\n", " '@',\n", " 'weeks',\n", " 'in',\n", " 'older',\n", " 'adults',\n", " 'with',\n", " 'moderate',\n", " 'to',\n", " 'severe',\n", " 'knee',\n", " 'osteoarthritis',\n", " '(',\n", " 'oa',\n", " ')',\n", " '.']" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_sentences[0].split()" ] }, { "cell_type": "code", "execution_count": 35, "id": "48a0aeb8", "metadata": {}, "outputs": [], "source": [ "a=0\n", "for i in train_sentences:\n", " a=a+(len(i.split()))\n", " " ] }, { "cell_type": "code", "execution_count": 36, "id": "8cd69f1d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "26" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"avg_no_words=round(a/len(train_sentences))\n", "avg_no_words" ] }, { "cell_type": "code", "execution_count": 37, "id": "7b904ee0", "metadata": {}, "outputs": [], "source": [ "import numpy as np" ] }, { "cell_type": "code", "execution_count": 38, "id": "1adbb05f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(296, 26.338269273494777)" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lenghts=[len(sentence.split()) for sentence in train_sentences]\n", "avg=np.mean(lenghts)\n", "max(lenghts),avg" ] }, { "cell_type": "code", "execution_count": 39, "id": "921c1c25", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.hist(lenghts,bins=20);" ] }, { "cell_type": "code", "execution_count": 40, "id": "6b2085ab", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "55.0" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# what sentence length covers 95% of the samples\n", "percentile=np.percentile(lenghts,95)\n", "percentile" ] }, { "cell_type": "code", "execution_count": 41, "id": "959ae373", "metadata": {}, "outputs": [], "source": [ "# 55 sentence length is the highest\n", "max_vocab_words=64000\n", "max_length=55\n", "text_tokenize=TextVectorization(\n", " max_tokens=max_vocab_words,\n", " output_mode='int',\n", " output_sequence_length=max_length\n", " )" ] }, { "cell_type": "code", "execution_count": 42, "id": "eee5dbfc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-06-13 17:41:34.825717: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n" ] } ], "source": [ "text_tokenize.adapt(train_sentences)" ] }, { "cell_type": "code", "execution_count": 43, "id": "118c5439", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original sentence use of a diclofenac sodium suppository ( @mg ) was allowed for all patients at any time after surgery , and the diclofenac sodium suppository usage was assessed .\n", "\n", "lengh of text 29\n", "\n", "Vectorized text [[ 87 4 8 2868 764 6735 68 10 2583 11 62 12 15 262\n", " 63 21 115 3 2 2868 764 6735 2637 10 113 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n" ] } ], "source": [ "import random\n", "random_sentence=random.choice(train_sentences)\n", "print(f\"Original sentence {random_sentence}\\n\")\n", "print(f\"lengh of text {len (random_sentence.split())}\\n\")\n", "print(f\"Vectorized text {text_tokenize([random_sentence])}\")" ] }, { 
"cell_type": "code", "execution_count": 44, "id": "4684e07b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of words in vocab 64000\n", "Top 5 most frquent words ['', '[UNK]', 'the', 'and', 'of']\n", "Top 5 least frquent words ['andbehavior', 'andat', 'andapplication', 'ancovamean', 'ancovaadjusted']\n" ] } ], "source": [ "text_vocab=text_tokenize.get_vocabulary()\n", "print(f\"Number of words in vocab { len(text_vocab)}\")\n", "print (f\"Top 5 most frquent words {text_vocab[:5]}\")\n", "print (f\"Top 5 least frquent words {text_vocab[-5:]}\")\n" ] }, { "cell_type": "code", "execution_count": 45, "id": "526ef3d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'name': 'text_vectorization_1',\n", " 'trainable': True,\n", " 'dtype': 'string',\n", " 'batch_input_shape': (None,),\n", " 'max_tokens': 64000,\n", " 'standardize': 'lower_and_strip_punctuation',\n", " 'split': 'whitespace',\n", " 'ngrams': None,\n", " 'output_mode': 'int',\n", " 'output_sequence_length': 55,\n", " 'pad_to_max_tokens': False,\n", " 'sparse': False,\n", " 'ragged': False,\n", " 'vocabulary': None,\n", " 'idf_weights': None,\n", " 'encoding': 'utf-8',\n", " 'vocabulary_size': 64000}" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text_tokenize.get_config()" ] }, { "cell_type": "code", "execution_count": 46, "id": "0871ad6c", "metadata": {}, "outputs": [], "source": [ "token_embedding=tf.keras.layers.Embedding(input_dim=len(text_vocab),\n", " output_dim=128,\n", " mask_zero=False\n", ")" ] }, { "cell_type": "code", "execution_count": 47, "id": "a31d9736", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original sentence use of a diclofenac sodium suppository ( @mg ) was allowed for all patients at any time after surgery , and the diclofenac sodium suppository usage was assessed .\n", "\n", "Vectorized sentence [[ 87 4 8 2868 764 6735 68 10 2583 11 62 12 
15 262\n", " 63 21 115 3 2 2868 764 6735 2637 10 113 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n", "\n", "Embedded sentence [[[ 0.04358969 0.00608678 -0.02475305 ... 0.02248098 -0.00060068\n", " -0.02649772]\n", " [ 0.01575196 -0.01732622 -0.03623853 ... 0.04600047 0.02817501\n", " 0.01843751]\n", " [ 0.02404206 -0.04182126 0.01860196 ... -0.02644104 0.04451705\n", " 0.03283714]\n", " ...\n", " [ 0.02993966 -0.0383271 0.00891661 ... 0.04420439 -0.01741441\n", " -0.01684239]\n", " [ 0.02993966 -0.0383271 0.00891661 ... 0.04420439 -0.01741441\n", " -0.01684239]\n", " [ 0.02993966 -0.0383271 0.00891661 ... 0.04420439 -0.01741441\n", " -0.01684239]]]\n", "Embedded sentence shape: (1, 55, 128)\n" ] } ], "source": [ "print(f\"Original sentence {random_sentence}\\n\")\n", "print(f\"Vectorized sentence {text_tokenize([random_sentence])}\\n\")\n", "print(f\"Embedded sentence {token_embedding(text_tokenize([random_sentence]))}\")\n", "print(f\"Embedded sentence shape: {token_embedding(text_tokenize([random_sentence])).shape}\")\n" ] }, { "cell_type": "markdown", "id": "c2e466a4", "metadata": {}, "source": [ "The tf.data.Dataset.from_tensor_slices() function is a utility in TensorFlow that creates a tf.data.Dataset object from one or more input tensors. This function is particularly useful for creating an input pipeline for training or inference tasks. 
" ] }, { "cell_type": "code", "execution_count": 48, "id": "79270549", "metadata": {}, "outputs": [], "source": [ "train_dataset=tf.data.Dataset.from_tensor_slices((train_sentences,train_labels_one_hot))\n", "val_dataset=tf.data.Dataset.from_tensor_slices((val_sentences,val_labels_one_hot))\n", "test_dataset=tf.data.Dataset.from_tensor_slices((test_sentences,test_label_one_hot))" ] }, { "cell_type": "code", "execution_count": 49, "id": "2ee4c907", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<_TensorSliceDataset element_spec=(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(5,), dtype=tf.float64, name=None))>" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset" ] }, { "cell_type": "code", "execution_count": 50, "id": "4cec70fc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<_PrefetchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None, 5), dtype=tf.float64, name=None))>" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset = train_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n", "val_dataset = val_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n", "test_dataset = test_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n", "\n", "train_dataset" ] }, { "cell_type": "code", "execution_count": 51, "id": "17e8f7a7", "metadata": {}, "outputs": [], "source": [ "inputs=tf.keras.layers.Input(shape=(1,), dtype=\"string\")\n", "text_vector=text_tokenize(inputs)\n", "text_embedding=token_embedding(text_vector)\n", "x=tf.keras.layers.Conv1D(64,kernel_size=5,padding=\"same\",activation=\"relu\")(text_embedding)\n", "x=tf.keras.layers.GlobalAveragePooling1D()(x)\n", "outputs=tf.keras.layers.Dense(len(class_names),activation=\"softmax\")(x)\n", "model_1=tf.keras.Model(inputs,outputs)" ] }, { "cell_type": "code", "execution_count": 52, "id": "69b5a4f4", "metadata": {}, "outputs": [], 
"source": [ "model_1.compile(loss=tf.keras.losses.categorical_crossentropy,\n", " optimizer=tf.keras.optimizers.legacy.Adam(),\n", " metrics=[\"accuracy\"])\n", " " ] }, { "cell_type": "code", "execution_count": 53, "id": "b31bba3f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " input_1 (InputLayer) [(None, 1)] 0 \n", " \n", " text_vectorization_1 (TextV (None, 55) 0 \n", " ectorization) \n", " \n", " embedding (Embedding) (None, 55, 128) 8192000 \n", " \n", " conv1d (Conv1D) (None, 55, 64) 41024 \n", " \n", " global_average_pooling1d (G (None, 64) 0 \n", " lobalAveragePooling1D) \n", " \n", " dense (Dense) (None, 5) 325 \n", " \n", "=================================================================\n", "Total params: 8,233,349\n", "Trainable params: 8,233,349\n", "Non-trainable params: 0\n", "_________________________________________________________________\n" ] } ], "source": [ "model_1.summary()\n" ] }, { "cell_type": "code", "execution_count": 54, "id": "be4038e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "562/562 [==============================] - 27s 47ms/step - loss: 0.9223 - accuracy: 0.6340 - val_loss: 0.6906 - val_accuracy: 0.7307\n", "Epoch 2/3\n", "562/562 [==============================] - 24s 44ms/step - loss: 0.6571 - accuracy: 0.7563 - val_loss: 0.6362 - val_accuracy: 0.7689\n", "Epoch 3/3\n", "562/562 [==============================] - 27s 48ms/step - loss: 0.6151 - accuracy: 0.7751 - val_loss: 0.5962 - val_accuracy: 0.7836\n" ] } ], "source": [ "model_1_history = model_1.fit(train_dataset,\n", " steps_per_epoch=int(0.1 * len(train_dataset)),\n", " epochs=3,\n", " validation_data=val_dataset,\n", " validation_steps=int(0.1 * len(val_dataset)))" ] }, { 
"cell_type": "code", "execution_count": 56, "id": "52cd33b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "945/945 [==============================] - 2s 2ms/step\n" ] }, { "data": { "text/plain": [ "array([[4.7636923e-01, 1.2716618e-01, 5.3757947e-02, 3.2397968e-01,\n", " 1.8726962e-02],\n", " [3.8611493e-01, 3.2393026e-01, 1.3540406e-02, 2.6583520e-01,\n", " 1.0579119e-02],\n", " [1.3623297e-01, 6.9641406e-03, 1.8214269e-03, 8.5494816e-01,\n", " 3.3247936e-05],\n", " ...,\n", " [1.3169743e-05, 6.1650638e-04, 8.4461336e-04, 5.1141324e-06,\n", " 9.9852055e-01],\n", " [5.5400979e-02, 4.4984341e-01, 8.8301055e-02, 7.4640714e-02,\n", " 3.3181381e-01],\n", " [1.5156463e-01, 7.0157272e-01, 4.8869491e-02, 4.7299001e-02,\n", " 5.0694138e-02]], dtype=float32)" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_probs=model_1.predict(val_dataset)\n", "model_probs" ] }, { "cell_type": "code", "execution_count": 57, "id": "2af9cf78", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_1_pred=tf.argmax(model_probs,axis=1)\n", "model_1_pred" ] }, { "cell_type": "code", "execution_count": 58, "id": "5058b7a1", "metadata": {}, "outputs": [], "source": [ "model_1_results=calculate_results(val_labels_encoder,model_1_pred)" ] }, { "cell_type": "code", "execution_count": 59, "id": "8be33cf9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 78.7104461803257,\n", " 'precision': 78.3333456187092,\n", " 'recall': 78.7104461803257,\n", " 'F1-score': 78.41304194274046}" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_1_results" ] }, { "cell_type": "code", "execution_count": 60, "id": "c733c9fc", "metadata": {}, "outputs": [], "source": [ "# Model-2 Using Pretrained Model\n", "import tensorflow_hub as hub\n", 
"tf_hub_embedding_layer=hub.KerasLayer(\"https://tfhub.dev/google/universal-sentence-encoder/4\",\n", " trainable=False)" ] }, { "cell_type": "code", "execution_count": 36, "id": "61745f58", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random training sentence:\n", "missing data were imputed using multiple imputation .\n", "\n" ] }, { "ename": "NameError", "evalue": "name 'tf_hub_embedding_layer' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[36], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m random_training_sentence \u001b[38;5;241m=\u001b[39m random\u001b[38;5;241m.\u001b[39mchoice(train_sentences)\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRandom training sentence:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mrandom_training_sentence\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m use_embedded_sentence \u001b[38;5;241m=\u001b[39m \u001b[43mtf_hub_embedding_layer\u001b[49m([random_training_sentence])\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSentence after embedding:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00muse_embedded_sentence[\u001b[38;5;241m0\u001b[39m][:\u001b[38;5;241m30\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLength of sentence 
embedding:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(use_embedded_sentence[\u001b[38;5;241m0\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", "\u001b[0;31mNameError\u001b[0m: name 'tf_hub_embedding_layer' is not defined" ] } ], "source": [ "import random\n", "random_training_sentence = random.choice(train_sentences)\n", "print(f\"Random training sentence:\\n{random_training_sentence}\\n\")\n", "use_embedded_sentence = tf_hub_embedding_layer([random_training_sentence])\n", "print(f\"Sentence after embedding:\\n{use_embedded_sentence[0][:30]}\\n\")\n", "print(f\"Length of sentence embedding:\\n{len(use_embedded_sentence[0])}\")\n", " " ] }, { "cell_type": "code", "execution_count": 62, "id": "583d8417", "metadata": {}, "outputs": [], "source": [ "inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)\n", "pre_trained_embedding=tf_hub_embedding_layer(inputs)\n", "x=tf.keras.layers.Dense(128,activation=\"relu\")(pre_trained_embedding)\n", "outputs=tf.keras.layers.Dense(len(class_names),activation=\"softmax\")(x)\n", "model_2=tf.keras.Model(inputs,outputs)" ] }, { "cell_type": "code", "execution_count": 63, "id": "fa1a7665", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_1\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " input_2 (InputLayer) [(None,)] 0 \n", " \n", " keras_layer (KerasLayer) (None, 512) 256797824 \n", " \n", " dense_1 (Dense) (None, 128) 65664 \n", " \n", " dense_2 (Dense) (None, 5) 645 \n", " \n", "=================================================================\n", "Total params: 256,864,133\n", "Trainable params: 66,309\n", "Non-trainable params: 256,797,824\n", "_________________________________________________________________\n" ] } ], "source": [ 
"model_2.summary()\n" ] }, { "cell_type": "code", "execution_count": 64, "id": "6653f2a6", "metadata": {}, "outputs": [], "source": [ "model_2.compile(loss=\"categorical_crossentropy\",\n", " optimizer=tf.keras.optimizers.legacy.Adam(),\n", " metrics=[\"accuracy\"])" ] }, { "cell_type": "code", "execution_count": 65, "id": "106a4b02", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "562/562 [==============================] - 49s 86ms/step - loss: 0.9129 - accuracy: 0.6509 - val_loss: 0.7937 - val_accuracy: 0.6912\n", "Epoch 2/3\n", "562/562 [==============================] - 45s 80ms/step - loss: 0.7674 - accuracy: 0.7020 - val_loss: 0.7538 - val_accuracy: 0.7061\n", "Epoch 3/3\n", "562/562 [==============================] - 61s 109ms/step - loss: 0.7503 - accuracy: 0.7122 - val_loss: 0.7382 - val_accuracy: 0.7131\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit feature extractor model for 3 epochs\n", "model_2.fit(train_dataset,\n", " steps_per_epoch=int(0.1 * len(train_dataset)),\n", " epochs=3,\n", " validation_data=val_dataset,\n", " validation_steps=int(0.1 * len(val_dataset)))" ] }, { "cell_type": "code", "execution_count": 66, "id": "7dd2f572", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "945/945 [==============================] - 82s 86ms/step\n" ] }, { "data": { "text/plain": [ "array([[4.4031423e-01, 3.4722412e-01, 2.5081097e-03, 1.9981924e-01,\n", " 1.0134370e-02],\n", " [3.3964318e-01, 4.9833187e-01, 3.6500413e-03, 1.5515944e-01,\n", " 3.2154536e-03],\n", " [2.6505971e-01, 1.5719941e-01, 1.9075207e-02, 5.1727098e-01,\n", " 4.1394755e-02],\n", " ...,\n", " [1.8576552e-03, 6.5741059e-03, 5.2557763e-02, 9.0209109e-04,\n", " 9.3810833e-01],\n", " [4.2318460e-03, 4.7142524e-02, 1.9696368e-01, 1.3423875e-03,\n", " 7.5031954e-01],\n", " [1.6839372e-01, 2.3914343e-01, 5.2446765e-01, 
5.4206112e-03,\n", " 6.2574625e-02]], dtype=float32)" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_2_pred_probs = model_2.predict(val_dataset)\n", "model_2_pred_probs" ] }, { "cell_type": "code", "execution_count": 68, "id": "8b855f9b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_2_preds = tf.argmax(model_2_pred_probs, axis=1)\n", "model_2_preds" ] }, { "cell_type": "code", "execution_count": 69, "id": "98ac4f0e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 71.39216205481266,\n", " 'precision': 71.42823673632186,\n", " 'recall': 71.39216205481266,\n", " 'F1-score': 71.10159521803406}" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_2_results = calculate_results(val_labels_encoder,\n", " model_2_preds)\n", "model_2_results\n", " " ] }, { "cell_type": "markdown", "id": "a70d9099", "metadata": {}, "source": [ "# model 3-conv 1D using character embeddings\n" ] }, { "cell_type": "code", "execution_count": 70, "id": "dfea44d8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['to investigate the efficacy of @ weeks of daily low-dose oral prednisolone in improving pain , mobility , and systemic low-grade inflammation in the short term and whether the effect would be sustained at @ weeks in older adults with moderate to severe knee osteoarthritis ( oa ) .',\n", " 'a total of @ patients with primary knee oa were randomized @:@ ; @ received @ mg/day of prednisolone and @ received placebo for @ weeks .',\n", " 'outcome measures included pain reduction and improvement in function scores and systemic inflammation markers .',\n", " 'pain was assessed using the visual analog pain scale ( @-@ mm ) .',\n", " 'secondary outcome measures included the western ontario and mcmaster universities osteoarthritis index scores , patient 
global assessment ( pga ) of the severity of knee oa , and @-min walk distance ( @mwd ) .']" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_sentences[:5]" ] }, { "cell_type": "code", "execution_count": 35, "id": "6938a248", "metadata": {}, "outputs": [ { "ename": "NameError", "evalue": "name 'random_training_sentence' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[35], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msplit_to_char\u001b[39m(text):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;28mlist\u001b[39m(text))\n\u001b[0;32m----> 4\u001b[0m split_to_char(\u001b[43mrandom_training_sentence\u001b[49m)\n", "\u001b[0;31mNameError\u001b[0m: name 'random_training_sentence' is not defined" ] } ], "source": [ "def split_to_char(text):\n", " return \" \" .join(list(text))\n", "\n", "split_to_char(random_training_sentence)" ] }, { "cell_type": "code", "execution_count": 72, "id": "2a52b468", "metadata": {}, "outputs": [], "source": [ "train_chars=[split_to_char(sentence) for sentence in train_sentences]\n", "val_chars=[split_to_char(sentence) for sentence in val_sentences]\n", "test_chars=[split_to_char(sentence) for sentence in test_sentences]" ] }, { "cell_type": "code", "execution_count": 73, "id": "0e2a6583", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['t o i n v e s t i g a t e t h e e f f i c a c y o f @ w e e k s o f d a i l y l o w - d o s e o r a l p r e d n i s o l o n e i n i m p r o v i n g p a i n , m o b i l i t y , a n d s y s t e m i c l o w - g r a d e i n f l a m m a t i o n i n t h e s h o r t t e r m a n 
d w h e t h e r t h e e f f e c t w o u l d b e s u s t a i n e d a t @ w e e k s i n o l d e r a d u l t s w i t h m o d e r a t e t o s e v e r e k n e e o s t e o a r t h r i t i s ( o a ) .',\n", " 'a t o t a l o f @ p a t i e n t s w i t h p r i m a r y k n e e o a w e r e r a n d o m i z e d @ : @ ; @ r e c e i v e d @ m g / d a y o f p r e d n i s o l o n e a n d @ r e c e i v e d p l a c e b o f o r @ w e e k s .',\n", " 'o u t c o m e m e a s u r e s i n c l u d e d p a i n r e d u c t i o n a n d i m p r o v e m e n t i n f u n c t i o n s c o r e s a n d s y s t e m i c i n f l a m m a t i o n m a r k e r s .',\n", " 'p a i n w a s a s s e s s e d u s i n g t h e v i s u a l a n a l o g p a i n s c a l e ( @ - @ m m ) .',\n", " 's e c o n d a r y o u t c o m e m e a s u r e s i n c l u d e d t h e w e s t e r n o n t a r i o a n d m c m a s t e r u n i v e r s i t i e s o s t e o a r t h r i t i s i n d e x s c o r e s , p a t i e n t g l o b a l a s s e s s m e n t ( p g a ) o f t h e s e v e r i t y o f k n e e o a , a n d @ - m i n w a l k d i s t a n c e ( @ m w d ) .']" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_chars[:5]" ] }, { "cell_type": "code", "execution_count": 74, "id": "4795ada4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(149.3662574983337,)" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "char_length=[len(sentence) for sentence in train_sentences]\n", "mean_length=np.mean(char_length)\n", "mean_length," ] }, { "cell_type": "code", "execution_count": 75, "id": "92d8585d", "metadata": {}, "outputs": [], "source": [ "char_percentile =np.percentile(char_length,95)" ] }, { "cell_type": "code", "execution_count": 76, "id": "8c709205", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "290.0" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "char_percentile # most characters 
are 290 or less" ] }, { "cell_type": "code", "execution_count": 77, "id": "c6390ae9", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.hist(char_length,bins=7);" ] }, { "cell_type": "code", "execution_count": 78, "id": "3d6fc1d4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'abcdefghijklmnopqrstuvwxyz0123456789!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~'" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import string\n", "alphabet= string.ascii_lowercase+string.digits+string.punctuation\n", "alphabet" ] }, { "cell_type": "code", "execution_count": 79, "id": "f64d31f8", "metadata": {}, "outputs": [], "source": [ "max_length_tokens=len(alphabet)+2\n", "char_vectorization=TextVectorization( \n", " max_tokens=max_length_tokens,\n", " output_mode='int',\n", " output_sequence_length=int(char_percentile)\n", " )" ] }, { "cell_type": "code", "execution_count": 80, "id": "988b1891", "metadata": {}, "outputs": [], "source": [ "char_vectorization.adapt(train_chars)" ] }, { "cell_type": "code", "execution_count": 81, "id": "21c2c1c7", "metadata": {}, "outputs": [], "source": [ "char_vocab=char_vectorization.get_vocabulary()" ] }, { "cell_type": "code", "execution_count": 82, "id": "38fb5281", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['', '[UNK]', 'e', 't', 'i', 'a', 'n', 'o', 'r', 's']" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "char_vocab[:10]" ] }, { "cell_type": "code", "execution_count": 83, "id": "1920ef57", "metadata": {}, "outputs": [], "source": [ "char_embedding=tf.keras.layers.Embedding(input_dim=max_length_tokens,\n", " output_dim=25,\n", " mask_zero=None\n", ")" ] }, { "cell_type": "code", "execution_count": 85, "id": "dc372178", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original text: \n", " p a r e c o x i b s o d i u m a n a l g e s i a r e d u c e s t h e r a t e o f p o d a n d p o c d i n e l d e r l y p a t i e n t s w i t h n e u r o p r o t e c t i v e 
e f f e c t s .\n", "\n", "Vectorized Text [[14 5 8 2 11 7 24 4 22 9 7 10 4 16 15 5 6 5 12 18 2 9 4 5\n", " 8 2 10 16 11 2 9 3 13 2 8 5 3 2 7 17 14 7 10 5 6 10 14 7\n", " 11 10 4 6 2 12 10 2 8 12 19 14 5 3 4 2 6 3 9 20 4 3 13 6\n", " 2 16 8 7 14 8 7 3 2 11 3 4 21 2 2 17 17 2 11 3 9 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", " 0 0]]\n", "\n", "Embedded Text [[[ 0.04685961 0.03016916 0.0077801 ... -0.04590077 0.01184995\n", " -0.04785159]\n", " [ 0.00771859 0.00036714 -0.01702696 ... 0.02873543 -0.03228469\n", " -0.04450686]\n", " [ 0.02558073 0.00907476 0.02239788 ... 0.01857281 -0.01394367\n", " 0.00124573]\n", " ...\n", " [ 0.04484713 0.03263881 -0.00276654 ... -0.00651003 0.0304917\n", " -0.02310417]\n", " [ 0.04484713 0.03263881 -0.00276654 ... -0.00651003 0.0304917\n", " -0.02310417]\n", " [ 0.04484713 0.03263881 -0.00276654 ... 
-0.00651003 0.0304917\n", " -0.02310417]]]\n", "(1, 290, 25)\n" ] } ], "source": [ "random_char=random.choice(train_chars)\n", "print(f\"Original text: \\n {random_char}\\n\")\n", "print(f\"Vectorized Text {char_vectorization([random_char])}\\n\")\n", "char_embed=char_embedding(char_vectorization([random_char]))\n", "print(f\"Embedded Text {char_embed}\")\n", "print(char_embed.shape)" ] }, { "cell_type": "code", "execution_count": 86, "id": "c967a2ad", "metadata": {}, "outputs": [], "source": [ "inputs=tf.keras.layers.Input(shape=(1,), dtype=\"string\")\n", "char_vector=char_vectorization(inputs)\n", "embedding=char_embedding(char_vector)\n", "x=tf.keras.layers.Conv1D(64,5,padding=\"same\",activation=\"relu\") (embedding)\n", "x=tf.keras.layers.GlobalMaxPool1D()(x)\n", "outputs=tf.keras.layers.Dense(len(class_names),activation=\"softmax\")(x)\n", "model_3=tf.keras.Model(inputs,outputs)" ] }, { "cell_type": "code", "execution_count": 87, "id": "fd688358", "metadata": {}, "outputs": [], "source": [ "model_3.compile(loss=\"categorical_crossentropy\",\n", " optimizer=tf.keras.optimizers.legacy.Adam(),\n", " metrics=[\"accuracy\"])" ] }, { "cell_type": "code", "execution_count": 88, "id": "cea8bb5d", "metadata": {}, "outputs": [], "source": [ "train_char_dataset=tf.data.Dataset.from_tensor_slices((train_chars,train_labels_one_hot)).batch(32).prefetch(tf.data.AUTOTUNE)\n", "val_char_dataset=tf.data.Dataset.from_tensor_slices((val_chars,val_labels_one_hot)).batch(32).prefetch(tf.data.AUTOTUNE)\n", "test_char_dataset=tf.data.Dataset.from_tensor_slices((test_chars,test_label_one_hot)).batch(32).prefetch(tf.data.AUTOTUNE)" ] }, { "cell_type": "code", "execution_count": 89, "id": "bc952f28", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<_PrefetchDataset element_spec=(TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None, 5), dtype=tf.float64, name=None))>" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "train_char_dataset" ] }, { "cell_type": "code", "execution_count": 90, "id": "6cd72c25", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "562/562 [==============================] - 15s 25ms/step - loss: 1.2866 - accuracy: 0.4727 - val_loss: 1.0617 - val_accuracy: 0.5904\n", "Epoch 2/3\n", "562/562 [==============================] - 14s 25ms/step - loss: 1.0089 - accuracy: 0.5985 - val_loss: 0.9363 - val_accuracy: 0.6307\n", "Epoch 3/3\n", "562/562 [==============================] - 16s 28ms/step - loss: 0.9260 - accuracy: 0.6371 - val_loss: 0.8641 - val_accuracy: 0.6636\n" ] } ], "source": [ "model_3_history=model_3.fit(train_char_dataset,epochs=3,steps_per_epoch=(int(0.1*(len(train_char_dataset)))),\n", " validation_data=val_char_dataset,validation_steps=(int(0.1*len(val_char_dataset))))\n", " " ] }, { "cell_type": "code", "execution_count": 91, "id": "b48b0136", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "945/945 [==============================] - 4s 4ms/step\n" ] } ], "source": [ "model_3_probs=model_3.predict(val_char_dataset)" ] }, { "cell_type": "code", "execution_count": 92, "id": "0bc4c4d0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_3_pred=tf.argmax(model_3_probs,axis=1)\n", "model_3_pred" ] }, { "cell_type": "code", "execution_count": 93, "id": "6112d216", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 65.76194889447902,\n", " 'precision': 65.44382499645309,\n", " 'recall': 65.76194889447902,\n", " 'F1-score': 64.57778080685951}" ] }, "execution_count": 93, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_3_results=calculate_results(val_labels_encoder,model_3_pred)\n", "model_3_results" ] }, { "cell_type": "markdown", "id": "1119a3ad", "metadata": {}, "source": [ "# Model_4 (mixing model_1 
and model_3)" ] }, { "cell_type": "markdown", "id": "2cf1085f", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": 94, "id": "6545be9f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tf_hub_embedding_layer" ] }, { "cell_type": "code", "execution_count": 95, "id": "e2c516cf", "metadata": {}, "outputs": [], "source": [ "#model_1\n", "inputs=tf.keras.layers.Input(shape=[],dtype=tf.string)\n", "embedding=tf_hub_embedding_layer(inputs)\n", "output=tf.keras.layers.Dense(128,activation=\"relu\")(embedding)\n", "pre_trained_model=tf.keras.Model(inputs,output)" ] }, { "cell_type": "code", "execution_count": 96, "id": "6a6d4427", "metadata": {}, "outputs": [], "source": [ "#model_3\n", "char_input=tf.keras.layers.Input(shape=(1,), dtype=\"string\")\n", "char_vectors=char_vectorization(char_input)\n", "char_embed=char_embedding(char_vectors)\n", "char_bi_lstm=tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(25))(char_embed)\n", "char_model=tf.keras.Model(char_input,char_bi_lstm)" ] }, { "cell_type": "code", "execution_count": 97, "id": "380fc7b9", "metadata": {}, "outputs": [], "source": [ "concatenated_model=tf.keras.layers.Concatenate(name=\"concatenated_layers\")([pre_trained_model.output,char_model.output])" ] }, { "cell_type": "code", "execution_count": 98, "id": "b2e591a8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ "concatenated_model" ] }, { "cell_type": "markdown", "id": "29da904d", "metadata": {}, "source": [ "**According to the paper**\n", "\n", "For regularization, dropout\n", "with a rate of 0.5 is applied to the character-enhanced token embeddings and before the label\n", "prediction layer."
] }, { "cell_type": "code", "execution_count": 99, "id": "2f73f124", "metadata": {}, "outputs": [], "source": [ "# Per the paper: dropout (rate 0.5) on the concatenated token + character embeddings\n", "# and again right before the label-prediction layer. The Dense layer must consume `dropout`;\n", "# otherwise the first Dropout is a dead branch that never reaches the model output.\n", "dropout=tf.keras.layers.Dropout(0.5)(concatenated_model)\n", "concatenated_output=tf.keras.layers.Dense(128,activation=\"relu\")(dropout)\n", "final_dropout=tf.keras.layers.Dropout(0.5)(concatenated_output)\n", "output_layer = tf.keras.layers.Dense(len(class_names), activation=\"softmax\")(final_dropout)" ] }, { "cell_type": "code", "execution_count": 100, "id": "2c7a3582", "metadata": {}, "outputs": [], "source": [ "model_4=tf.keras.Model(inputs=[pre_trained_model.input,char_model.input],\n", " outputs=output_layer,\n", " )" ] }, { "cell_type": "code", "execution_count": 101, "id": "dab9fa40", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_5\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_5 (InputLayer) [(None, 1)] 0 [] \n", " \n", " input_4 (InputLayer) [(None,)] 0 [] \n", " \n", " text_vectorization_2 (TextVect (None, 290) 0 ['input_5[0][0]'] \n", " orization) \n", " \n", " keras_layer (KerasLayer) (None, 512) 256797824 ['input_4[0][0]'] \n", " \n", " embedding_1 (Embedding) (None, 290, 25) 1750 ['text_vectorization_2[1][0]'] \n", " \n", " dense_4 (Dense) (None, 128) 65664 ['keras_layer[1][0]'] \n", " \n", " bidirectional (Bidirectional) (None, 50) 10200 ['embedding_1[1][0]'] \n", " \n", " concatenated_layers (Concatena (None, 178) 0 ['dense_4[0][0]', \n", " te) 'bidirectional[0][0]'] \n", " \n", " dense_5 (Dense) (None, 128) 22912 
['concatenated_layers[0][0]'] \n", " \n", " dropout_1 (Dropout) (None, 128) 0 ['dense_5[0][0]'] \n", " \n", " dense_6 (Dense) (None, 5) 645 ['dropout_1[0][0]'] \n", " \n",
"==================================================================================================\n", "Total params: 256,898,995\n", "Trainable params: 101,171\n", "Non-trainable params: 256,797,824\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "model_4.summary()" ] }, { "cell_type": "code", "execution_count": 102, "id": "31a93688", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tensorflow.keras.utils import plot_model\n", "plot_model(model_4)" ] }, { "cell_type": "code", "execution_count": 103, "id": "1e746b48", "metadata": {}, "outputs": [], "source": [ "model_4.compile(loss=\"categorical_crossentropy\",\n", " optimizer=tf.keras.optimizers.legacy.Adam(), \n", " metrics=[\"accuracy\"])\n", " " ] }, { "cell_type": "code", "execution_count": 106, "id": "914ca68b", "metadata": {}, "outputs": [], "source": [ "# Dataset\n", "train_char_token=tf.data.Dataset.from_tensor_slices((train_sentences,train_chars))\n", "train_char_labels=tf.data.Dataset.from_tensor_slices(train_labels_one_hot)\n", "combined_train_dataset=tf.data.Dataset.zip((train_char_token,train_char_labels))\n", "\n", "combined_train_dataset=combined_train_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n", "\n", "val_char_token=tf.data.Dataset.from_tensor_slices((val_sentences,val_chars))\n", "val_char_labels=tf.data.Dataset.from_tensor_slices(val_labels_one_hot)\n", "combined_val_dataset=tf.data.Dataset.zip((val_char_token,val_char_labels))\n", "\n", "combined_val_dataset=combined_val_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n", "\n" ] }, { "cell_type": "code", "execution_count": 107, "id": "9dcca50a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(<_PrefetchDataset element_spec=((TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None,), dtype=tf.string, 
name=None)), TensorSpec(shape=(None, 5), dtype=tf.float64, name=None))>,\n", " <_PrefetchDataset element_spec=((TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None)), TensorSpec(shape=(None, 5), dtype=tf.float64, name=None))>)" ] }, "execution_count": 107, "metadata": {}, "output_type": "execute_result" } ], "source": [ "combined_train_dataset,combined_val_dataset" ] }, { "cell_type": "code", "execution_count": 108, "id": "a3c13e34", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "562/562 [==============================] - 145s 249ms/step - loss: 0.9091 - accuracy: 0.6467 - val_loss: 0.7813 - val_accuracy: 0.6981\n", "Epoch 2/3\n", "562/562 [==============================] - 111s 198ms/step - loss: 0.7352 - accuracy: 0.7166 - val_loss: 0.6956 - val_accuracy: 0.7350\n", "Epoch 3/3\n", "562/562 [==============================] - 109s 194ms/step - loss: 0.7026 - accuracy: 0.7276 - val_loss: 0.6708 - val_accuracy: 0.7500\n" ] } ], "source": [ "model_4_history = model_4.fit(combined_train_dataset, \n", " steps_per_epoch=int(0.1 * len(combined_train_dataset)),\n", " epochs=3,\n", " validation_data=combined_val_dataset,\n", " validation_steps=int(0.1 * len(combined_val_dataset)))" ] }, { "cell_type": "code", "execution_count": 109, "id": "23709bb3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "945/945 [==============================] - 135s 141ms/step\n" ] } ], "source": [ "model_4_prob=model_4.predict(combined_val_dataset)" ] }, { "cell_type": "code", "execution_count": 110, "id": "ecb21881", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 110, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_4_pred=tf.argmax(model_4_prob,axis=1)\n", "model_4_pred" ] }, { "cell_type": "code", "execution_count": 111, "id": "f6af1add", "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "{'accuracy': 74.31815172779028,\n", " 'precision': 74.56791634210045,\n", " 'recall': 74.31815172779028,\n", " 'F1-score': 74.33453044400244}" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_4_results=calculate_results(val_labels_encoder,model_4_pred)\n", "model_4_results" ] }, { "cell_type": "markdown", "id": "4aeab3be", "metadata": {}, "source": [ "# Model_5 (Model_4 + Positional Embeddings: where the sentence appears in an abstract)" ] }, { "cell_type": "code", "execution_count": 112, "id": "2b81f8f2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
targettextline_numbertotal_lines
0OBJECTIVEto investigate the efficacy of @ weeks of dail...011
1METHODSa total of @ patients with primary knee oa wer...111
2METHODSoutcome measures included pain reduction and i...211
3METHODSpain was assessed using the visual analog pain...311
4METHODSsecondary outcome measures included the wester...411
...............
180035RESULTSfor the absolute change in percent atheroma vo...711
180036RESULTSfor pav , a significantly greater percentage o...811
180037RESULTSboth strategies had acceptable side effect pro...911
180038CONCLUSIONScompared with standard statin monotherapy , th...1011
180039CONCLUSIONS( plaque regression with cholesterol absorptio...1111
\n", "

180040 rows × 4 columns

\n", "
" ], "text/plain": [ " target text \\\n", "0 OBJECTIVE to investigate the efficacy of @ weeks of dail... \n", "1 METHODS a total of @ patients with primary knee oa wer... \n", "2 METHODS outcome measures included pain reduction and i... \n", "3 METHODS pain was assessed using the visual analog pain... \n", "4 METHODS secondary outcome measures included the wester... \n", "... ... ... \n", "180035 RESULTS for the absolute change in percent atheroma vo... \n", "180036 RESULTS for pav , a significantly greater percentage o... \n", "180037 RESULTS both strategies had acceptable side effect pro... \n", "180038 CONCLUSIONS compared with standard statin monotherapy , th... \n", "180039 CONCLUSIONS ( plaque regression with cholesterol absorptio... \n", "\n", " line_number total_lines \n", "0 0 11 \n", "1 1 11 \n", "2 2 11 \n", "3 3 11 \n", "4 4 11 \n", "... ... ... \n", "180035 7 11 \n", "180036 8 11 \n", "180037 9 11 \n", "180038 10 11 \n", "180039 11 11 \n", "\n", "[180040 rows x 4 columns]" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df" ] }, { "cell_type": "code", "execution_count": 113, "id": "8c657ce9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAk0AAAGdCAYAAAAPLEfqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAqEElEQVR4nO3dfXAUdZ7H8U8emPCUCQZIQo5AsoJglqciQJjz4RbJMki0RLAKFCVi1MMNHBCRhz0XxLU2CCWCB8huuRKtE0H2xF3JAbIBwnlGkGDkoZaILG7gwoSokIFoHsj0/eFmljGoP8ZgD+H9qpoqpvubns90tZWPPT2dMMuyLAEAAOA7hdsdAAAA4GpAaQIAADBAaQIAADBAaQIAADBAaQIAADBAaQIAADBAaQIAADBAaQIAADAQaXeA1sLn86miokLR0dEKCwuzOw4AADBgWZbOnTunxMREhYd/97kkSlMLqaioUFJSkt0xAABAEE6cOKHu3bt/5wylqYVER0dL+nqnO51Om9MAAAATXq9XSUlJ/t/j34XS1EKaPpJzOp2UJgAArjIml9ZwITgAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAIABShMAAICBSLsDwEzyvAK7I1y2Txdn2h0BAIAWQ2nCFUPRAwC0Jnw8BwAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYIDSBAAAYCDS7gBAKEmeV2B3hMv26eJMuyMAwDWBM00AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGQqY0LV68WGFhYZo5c6Z/WW1trXJyctS5c2d17NhR48ePV2VlZcDPlZeXKzMzU+3bt1dcXJyeeOIJXbhwIWBm165dGjx4sKKiotSrVy/l5+c3e/1Vq1YpOTlZbdu2VXp6uvbu3Xsl3iYAALhKhURp+uCDD/Tb3/5WAwYMCFg+a9Ysvf3229q4caOKiopUUVGhcePG+dc3NjYqMzNT9fX1eu+99/TKK68oPz9fCxYs8M8cP35cmZmZGjFihEpLSzVz5kw9/PDD2rZtm39mw4YNys3N1cKFC7V//34NHDhQbrdbp0+fvvJvHgAAXBXCLMuy7Axw/vx5DR48WKtXr9YzzzyjQYMGafny5aqurlbXrl21bt063XPPPZKkI0eO6MYbb1RxcbGGDx+uLVu26I477lBFRYXi4+MlSWvWrNHcuXNVVVUlh8OhuXPnqqCgQIcOHfK/5sSJE3X27Flt3bpVkpSenq6hQ4dq5cqVkiSfz6ekpCRNnz5d8+bNM3ofXq9XMTExqq6ultPpbMldJElKnlfQ4ttE6/Dp4ky7IwDAVetyfn/bfqYpJydHmZmZysjICFheUlKihoaGgOV9+/ZVjx49VFxcLEkqLi5W//79/YVJktxut7xerw4fPuyf+ea23W63fxv19fUqKSkJmAkPD1dGRoZ/5lLq6urk9XoDHgAAoPWKtPPF169fr/379+uDDz5ots7j8cjhcKhTp04By+Pj4+XxePwzFxempvVN675rxuv16quvvtKZM2fU2Nh4yZkjR458a/a8vDwtWrTI7I0CAICrnm1nmk6cOKEZM2botddeU9u2be2KEbT58+erurra/zhx4oTdkQAAwBVkW2kqKSnR6dOnNXjwYEVGRioyMlJFRUV64YUXFBkZqfj4eNXX1+vs2bMBP1dZWamEhARJUkJCQrNv0zU9/74Zp9Opdu3aqUuXLoqIiLjkTNM2LiUqKkpOpzPgAQAAWi/bStPIkSN18OB
BlZaW+h9DhgzRpEmT/P9u06aNCgsL/T9TVlam8vJyuVwuSZLL5dLBgwcDvuW2fft2OZ1Opaam+mcu3kbTTNM2HA6H0tLSAmZ8Pp8KCwv9MwAAALZd0xQdHa1+/foFLOvQoYM6d+7sX56dna3c3FzFxsbK6XRq+vTpcrlcGj58uCRp1KhRSk1N1QMPPKAlS5bI4/HoySefVE5OjqKioiRJU6dO1cqVKzVnzhw99NBD2rFjh9544w0VFPzj22i5ubnKysrSkCFDNGzYMC1fvlw1NTWaMmXKj7Q3AABAqLP1QvDv8/zzzys8PFzjx49XXV2d3G63Vq9e7V8fERGhzZs367HHHpPL5VKHDh2UlZWlp59+2j+TkpKigoICzZo1SytWrFD37t310ksvye12+2cmTJigqqoqLViwQB6PR4MGDdLWrVubXRwOAACuXbbfp6m14D5NsAv3aQKA4F1V92kCAAC4GlCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADFCaAAAADNhaml588UUNGDBATqdTTqdTLpdLW7Zs8a+vra1VTk6OOnfurI4dO2r8+PGqrKwM2EZ5ebkyMzPVvn17xcXF6YknntCFCxcCZnbt2qXBgwcrKipKvXr1Un5+frMsq1atUnJystq2bav09HTt3bv3irxnAABwdbK1NHXv3l2LFy9WSUmJ9u3bp9tuu0133XWXDh8+LEmaNWuW3n77bW3cuFFFRUWqqKjQuHHj/D/f2NiozMxM1dfX67333tMrr7yi/Px8LViwwD9z/PhxZWZmasSIESotLdXMmTP18MMPa9u2bf6ZDRs2KDc3VwsXLtT+/fs1cOBAud1unT59+sfbGQAAIKSFWZZl2R3iYrGxsVq6dKnuuecede3aVevWrdM999wjSTpy5IhuvPFGFRcXa/jw4dqyZYvuuOMOVVRUKD4+XpK0Zs0azZ07V1VVVXI4HJo7d64KCgp06NAh/2tMnDhRZ8+e1datWyVJ6enpGjp0qFauXClJ8vl8SkpK0vTp0zVv3jyj3F6vVzExMaqurpbT6WzJXSJJSp5X0OLbROvw6eJMuyMAwFXrcn5/h8w1TY2NjVq/fr1qamrkcrlUUlKihoYGZWRk+Gf69u2rHj16qLi4WJJUXFys/v37+wuTJLndbnm9Xv/ZquLi4oBtNM00baO+vl4lJSUBM+Hh4crIyPDPAAAARNod4ODBg3K5XKqtrVXHjh21adMmpaamqrS0VA6HQ506dQqYj4+Pl8fjkSR5PJ6AwtS0vmndd814vV599dVXOnPmjBobGy85c+TIkW/NXVdXp7q6Ov9zr9d7eW8cAABcVWwvTX369FFpaamqq6v1hz/8QVlZWSoqKrI71vfKy8vTokWL7I4BXJUf3fKRIoCrke0fzzkcDvXq1UtpaWnKy8vTwIEDtWLFCiUkJKi+vl5nz54NmK+srFRCQoIkKSEhodm36Zqef9+M0+lUu3bt1KVLF0VERFxypmkblzJ//nxVV1f7Hyd
OnAjq/QMAgKuD7aXpm3w+n+rq6pSWlqY2bdqosLDQv66srEzl5eVyuVySJJfLpYMHDwZ8y2379u1yOp1KTU31z1y8jaaZpm04HA6lpaUFzPh8PhUWFvpnLiUqKsp/q4SmBwAAaL1s/Xhu/vz5uv3229WjRw+dO3dO69at065du7Rt2zbFxMQoOztbubm5io2NldPp1PTp0+VyuTR8+HBJ0qhRo5SamqoHHnhAS5Yskcfj0ZNPPqmcnBxFRUVJkqZOnaqVK1dqzpw5euihh7Rjxw698cYbKij4x0caubm5ysrK0pAhQzRs2DAtX75cNTU1mjJlii37BQAAhB5bS9Pp06c1efJknTp1SjExMRowYIC2bdumn//855Kk559/XuHh4Ro/frzq6urkdru1evVq/89HRERo8+bNeuyxx+RyudShQwdlZWXp6aef9s+kpKSooKBAs2bN0ooVK9S9e3e99NJLcrvd/pkJEyaoqqpKCxYskMfj0aBBg7R169ZmF4cDAIBrV8jdp+lqxX2aAHNcCA4gVFyV92kCAAAIZZQmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA5QmAAAAA0GVpr/+9a8tnQMAACCkBVWaevXqpREjRug///M/VVtb29KZAAAAQk5QpWn//v0aMGCAcnNzlZCQoH/913/V3r17WzobAABAyAiqNA0aNEgrVqxQRUWFXn75ZZ06dUo333yz+vXrp2XLlqmqqqqlcwIAANjqB10IHhkZqXHjxmnjxo169tln9cknn2j27NlKSkrS5MmTderUqZbKCQAAYKsfVJr27dunX/ziF+rWrZuWLVum2bNn69ixY9q+fbsqKip01113tVROAAAAW0UG80PLli3T2rVrVVZWpjFjxujVV1/VmDFjFB7+dQdLSUlRfn6+kpOTWzIrAACAbYIqTS+++KIeeughPfjgg+rWrdslZ+Li4vT73//+B4UDAAAIFUGVpqNHj37vjMPhUFZWVjCbBwAACDlBXdO0du1abdy4sdnyjRs36pVXXvnBoQAAAEJNUKUpLy9PXbp0abY8Li5Ov/nNb35wKAAAgFATVGkqLy9XSkpKs+U9e/ZUeXn5Dw4FAAAQaoIqTXFxcTpw4ECz5R999JE6d+78g0MBAACEmqBK07333qt/+7d/086dO9XY2KjGxkbt2LFDM2bM0MSJE1s6IwAAgO2C+vbcr3/9a3366acaOXKkIiO/3oTP59PkyZO5pgkAALRKQZUmh8OhDRs26Ne//rU++ugjtWvXTv3791fPnj1bOh8AAEBICKo0Nbnhhht0ww03tFQWAACAkBVUaWpsbFR+fr4KCwt1+vRp+Xy+gPU7duxokXAAAAChIqjSNGPGDOXn5yszM1P9+vVTWFhYS+cCAAAIKUGVpvXr1+uNN97QmDFjWjoPAABASArqlgMOh0O9evVq6SwAAAAhK6jS9Pjjj2vFihWyLKul8wAAAISkoD6ee/fdd7Vz505t2bJFP/3pT9WmTZuA9W+++WaLhAMAAAgVQZWmTp066e67727pLAAAACErqNK0du3als4BAAAQ0oK6pkmSLly4oD//+c/67W9/q3PnzkmSKioqdP78+RYLBwAAECqCOtP0t7/9TaNHj1Z5ebnq6ur085//XNHR0Xr22WdVV1enNWvWtHROAAAAWwV1pmnGjBkaMmSIzpw5o3bt2vmX33333SosLGyxcAAAAKEiqDNN//M//6P33ntPDocjYHlycrL+7//+r0WCAQAAhJKgzjT5fD41NjY2W37y5ElFR0f/4FAAAAChJqjSNGrUKC1fvtz/PCwsTOfPn9fChQv50yoAAKBVCurjueeee05ut1upqamqra3Vfffdp6NHj6pLly56/fXXWzojAACA7YIqTd27d9dHH32k9evX68CBAzp//ryys7M1adKkgAvDAQAAWougSpMkRUZG6v7772/JLAA
AACErqNL06quvfuf6yZMnBxUGAAAgVAVVmmbMmBHwvKGhQV9++aUcDofat29PaQIAAK1OUN+eO3PmTMDj/PnzKisr080338yF4AAAoFUK+m/PfVPv3r21ePHiZmehAAAAWoMWK03S1xeHV1RUtOQmAQAAQkJQ1zT96U9/CnhuWZZOnTqllStX6qabbmqRYAAAAKEkqNI0duzYgOdhYWHq2rWrbrvtNj333HMtkQsAACCkBFWafD5fS+cAAAAIaS16TRMAAEBrFdSZptzcXOPZZcuWBfMSAAAAISWo0vThhx/qww8/VENDg/r06SNJ+vjjjxUREaHBgwf758LCwlomJQAAgM2CKk133nmnoqOj9corr+i6666T9PUNL6dMmaJbbrlFjz/+eIuGBAAAsFtQ1zQ999xzysvL8xcmSbruuuv0zDPP8O05AADQKgVVmrxer6qqqpotr6qq0rlz535wKAAAgFATVGm6++67NWXKFL355ps6efKkTp48qf/6r/9Sdna2xo0b19IZAQAAbBfUNU1r1qzR7Nmzdd9996mhoeHrDUVGKjs7W0uXLm3RgAAAAKEgqNLUvn17rV69WkuXLtWxY8ckSddff706dOjQouEAAABCxQ+6ueWpU6d06tQp9e7dWx06dJBlWS2VCwAAIKQEVZo+//xzjRw5UjfccIPGjBmjU6dOSZKys7O53QAAAGiVgipNs2bNUps2bVReXq727dv7l0+YMEFbt25tsXAAAAChIqhrmt555x1t27ZN3bt3D1jeu3dv/e1vf2uRYAAAAKEkqDNNNTU1AWeYmnzxxReKior6waEAAABCTVCl6ZZbbtGrr77qfx4WFiafz6clS5ZoxIgRLRYOAAAgVARVmpYsWaLf/e53uv3221VfX685c+aoX79+2r17t5599lnj7eTl5Wno0KGKjo5WXFycxo4dq7KysoCZ2tpa5eTkqHPnzurYsaPGjx+vysrKgJny8nJlZmaqffv2iouL0xNPPKELFy4EzOzatUuDBw9WVFSUevXqpfz8/GZ5Vq1apeTkZLVt21bp6enau3ev+U4BAACtWlClqV+/fvr44491880366677lJNTY3GjRunDz/8UNdff73xdoqKipSTk6P3339f27dvV0NDg0aNGqWamhr/zKxZs/T2229r48aNKioqUkVFRcBdxxsbG5WZman6+nq99957euWVV5Sfn68FCxb4Z44fP67MzEyNGDFCpaWlmjlzph5++GFt27bNP7Nhwwbl5uZq4cKF2r9/vwYOHCi3263Tp08Hs4sAAEArE2Zd5s2VGhoaNHr0aK1Zs0a9e/du0TBVVVWKi4tTUVGRbr31VlVXV6tr165at26d7rnnHknSkSNHdOONN6q4uFjDhw/Xli1bdMcdd6iiokLx8fGSvr5j+dy5c1VVVSWHw6G5c+eqoKBAhw4d8r/WxIkTdfbsWf+3/dLT0zV06FCtXLlSkuTz+ZSUlKTp06dr3rx535vd6/UqJiZG1dXVcjqdLbpfJCl5XkGLbxOwy6eLM+2OAACSLu/392WfaWrTpo0OHDgQdLjvUl1dLUmKjY2VJJWUlKihoUEZGRn+mb59+6pHjx4qLi6WJBUXF6t///7+wiRJbrdbXq9Xhw8f9s9cvI2mmaZt1NfXq6SkJGAmPDxcGRkZ/plvqqurk9frDXgAAIDWK6iP5+6//379/ve/b9EgPp9PM2fO1E033aR+/fpJkjwejxwOhzp16hQwGx8fL4/H45+5uDA1rW9a910zXq9XX331lT777DM1NjZecqZpG9+Ul5enmJgY/yMpKSm4Nw4AAK4KQd2n6cKFC3r55Zf15z//WWlpac3+5tyyZcsue5s5OTk6dOiQ3n333WAi/ejmz5+v3Nxc/3Ov10txAgCgFbus0vTXv/5VycnJOnTokAYPHixJ+vjjjwNmwsLCLjvEtGnTtHnzZu3evTvghpkJCQmqr6/X2bNnA842VVZWKiEhwT/zzW+
5NX277uKZb37jrrKyUk6nU+3atVNERIQiIiIuOdO0jW+KiorinlQAAFxDLuvjud69e+uzzz7Tzp07tXPnTsXFxWn9+vX+5zt37tSOHTuMt2dZlqZNm6ZNmzZpx44dSklJCViflpamNm3aqLCw0L+srKxM5eXlcrlckiSXy6WDBw8GfMtt+/btcjqdSk1N9c9cvI2mmaZtOBwOpaWlBcz4fD4VFhb6ZwAAwLXtss40ffOLdlu2bAm4PcDlysnJ0bp16/THP/5R0dHR/uuHYmJi1K5dO8XExCg7O1u5ubmKjY2V0+nU9OnT5XK5NHz4cEnSqFGjlJqaqgceeEBLliyRx+PRk08+qZycHP+ZoKlTp2rlypWaM2eOHnroIe3YsUNvvPGGCgr+8Y203NxcZWVlaciQIRo2bJiWL1+umpoaTZkyJej3BwAAWo+grmlqcpl3K2jmxRdflCT97Gc/C1i+du1aPfjgg5Kk559/XuHh4Ro/frzq6urkdru1evVq/2xERIQ2b96sxx57TC6XSx06dFBWVpaefvpp/0xKSooKCgo0a9YsrVixQt27d9dLL70kt9vtn5kwYYKqqqq0YMECeTweDRo0SFu3bm12cTgAALg2XdZ9miIiIuTxeNS1a1dJUnR0tA4cONDsY7VrEfdpAsxxnyYAoeJyfn9f9sdzDz74oP9jr9raWk2dOrXZt+fefPPNy4wMAAAQ2i6rNGVlZQU8v//++1s0DAAAQKi6rNK0du3aK5UDAAAgpAV1R3AAAIBrDaUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAAKUJAADAQKTdAQBce5LnFdgd4bJ9ujjT7ggAbMaZJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAOUJgAAAAO2lqbdu3frzjvvVGJiosLCwvTWW28FrLcsSwsWLFC3bt3Url07ZWRk6OjRowEzX3zxhSZNmiSn06lOnTopOztb58+fD5g5cOCAbrnlFrVt21ZJSUlasmRJsywbN25U37591bZtW/Xv31///d//3eLvFwAAXL1sLU01NTUaOHCgVq1adcn1S5Ys0QsvvKA1a9Zoz5496tChg9xut2pra/0zkyZN0uHDh7V9+3Zt3rxZu3fv1qOPPupf7/V6NWrUKPXs2VMlJSVaunSpnnrqKf3ud7/zz7z33nu69957lZ2drQ8//FBjx47V2LFjdejQoSv35gEAwFUlzLIsy+4QkhQWFqZNmzZp7Nixkr4+y5SYmKjHH39cs2fPliRVV1crPj5e+fn5mjhxov7yl78oNTVVH3zwgYYMGSJJ2rp1q8aMGaOTJ08qMTFRL774ov793/9dHo9HDodDkjRv3jy99dZbOnLkiCRpwoQJqqmp0ebNm/15hg8frkGDBmnNmjVG+b1er2JiYlRdXS2n09lSu8UveV5Bi28TgLlPF2faHQHAFXA5v79D9pqm48ePy+PxKCMjw78sJiZG6enpKi4uliQVFxerU6dO/sIkSRkZGQoPD9eePXv8M7feequ/MEmS2+1WWVmZzpw545+5+HWaZppe51Lq6urk9XoDHgAAoPUK2dLk8XgkSfHx8QHL4+P
j/es8Ho/i4uIC1kdGRio2NjZg5lLbuPg1vm2maf2l5OXlKSYmxv9ISkq63LcIAACuIiFbmkLd/PnzVV1d7X+cOHHC7kgAAOAKCtnSlJCQIEmqrKwMWF5ZWelfl5CQoNOnTwesv3Dhgr744ouAmUtt4+LX+LaZpvWXEhUVJafTGfAAAACtV8iWppSUFCUkJKiwsNC/zOv1as+ePXK5XJIkl8uls2fPqqSkxD+zY8cO+Xw+paen+2d2796thoYG/8z27dvVp08fXXfddf6Zi1+naabpdQAAAGwtTefPn1dpaalKS0slfX3xd2lpqcrLyxUWFqaZM2fqmWee0Z/+9CcdPHhQkydPVmJiov8bdjfeeKNGjx6tRx55RHv37tX//u//atq0aZo4caISExMlSffdd58cDoeys7N1+PBhbdiwQStWrFBubq4/x4wZM7R161Y999xzOnLkiJ566int27dP06ZN+7F3CQAACFGRdr74vn37NGLECP/zpiKTlZWl/Px8zZkzRzU1NXr00Ud19uxZ3Xzzzdq6davatm3r/5nXXntN06ZN08iRIxUeHq7x48frhRde8K+PiYnRO++8o5ycHKWlpalLly5asGBBwL2c/vmf/1nr1q3Tk08+qV/+8pfq3bu33nrrLfXr1+9H2AsAAOBqEDL3abracZ8moHXjPk1A69Qq7tMEAAAQSihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABihNAAAABiLtDgAAV4PkeQV2R7hsny7OtDsC0KpwpgkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMAApQkAAMBApN0BAABXRvK8ArsjXLZPF2faHQH4VpxpAgAAMEBpAgAAMEBp+oZVq1YpOTlZbdu2VXp6uvbu3Wt3JAAAEAIoTRfZsGGDcnNztXDhQu3fv18DBw6U2+3W6dOn7Y4GAABsRmm6yLJly/TII49oypQpSk1N1Zo1a9S+fXu9/PLLdkcDAAA249tzf1dfX6+SkhLNnz/fvyw8PFwZGRkqLi5uNl9XV6e6ujr/8+rqakmS1+u9Ivl8dV9eke0CQCjpMWuj3REu26FFbrsj4Ado+r1tWdb3zlKa/u6zzz5TY2Oj4uPjA5bHx8fryJEjzebz8vK0aNGiZsuTkpKuWEYAQOiJWW53ArSEc+fOKSYm5jtnKE1Bmj9/vnJzc/3PfT6fvvjiC3Xu3FlhYWEt+lper1dJSUk6ceKEnE5ni267tWFfmWNfmWNfmWNfmWNfXZ4rtb8sy9K5c+eUmJj4vbOUpr/r0qWLIiIiVFlZGbC8srJSCQkJzeajoqIUFRUVsKxTp05XMqKcTif/YRliX5ljX5ljX5ljX5ljX12eK7G/vu8MUxMuBP87h8OhtLQ0FRYW+pf5fD4VFhbK5XLZmAwAAIQCzjRdJDc3V1lZWRoyZIiGDRum5cuXq6amRlOmTLE7GgAAsBml6SITJkxQVVWVFixYII/Ho0GDBmnr1q3NLg7/sUVFRWnhwoXNPg5Ec+wrc+wrc+wrc+wrc+yryxMK+yvMMvmOHQAAwDWOa5oAAAAMUJoAAAAMUJoAAAAMUJoAAAAMUJpC3Kp
Vq5ScnKy2bdsqPT1de/futTtSSHrqqacUFhYW8Ojbt6/dsULC7t27deeddyoxMVFhYWF66623AtZblqUFCxaoW7duateunTIyMnT06FF7wtrs+/bVgw8+2Ow4Gz16tD1hbZaXl6ehQ4cqOjpacXFxGjt2rMrKygJmamtrlZOTo86dO6tjx44aP358sxsIXwtM9tXPfvazZsfW1KlTbUpsnxdffFEDBgzw38DS5XJpy5Yt/vV2H1OUphC2YcMG5ebmauHChdq/f78GDhwot9ut06dP2x0tJP30pz/VqVOn/I93333X7kghoaamRgMHDtSqVasuuX7JkiV64YUXtGbNGu3Zs0cdOnSQ2+1WbW3tj5zUft+3ryRp9OjRAcfZ66+//iMmDB1FRUXKycnR+++/r+3bt6uhoUGjRo1STU2Nf2bWrFl6++23tXHjRhUVFamiokLjxo2zMbU9TPaVJD3yyCMBx9aSJUtsSmyf7t27a/HixSopKdG+fft022236a677tLhw4clhcAxZSFkDRs2zMrJyfE/b2xstBITE628vDwbU4WmhQsXWgMHDrQ7RsiTZG3atMn/3OfzWQkJCdbSpUv9y86ePWtFRUVZr7/+ug0JQ8c395VlWVZWVpZ111132ZIn1J0+fdqSZBUVFVmW9fVx1KZNG2vjxo3+mb/85S+WJKu4uNiumCHhm/vKsizrX/7lX6wZM2bYFyqEXXfdddZLL70UEscUZ5pCVH19vUpKSpSRkeFfFh4eroyMDBUXF9uYLHQdPXpUiYmJ+slPfqJJkyapvLzc7kgh7/jx4/J4PAHHWUxMjNLT0znOvsWuXbsUFxenPn366LHHHtPnn39ud6SQUF1dLUmKjY2VJJWUlKihoSHg2Orbt6969OhxzR9b39xXTV577TV16dJF/fr10/z58/Xll1/aES9kNDY2av369aqpqZHL5QqJY4o7goeozz77TI2Njc3uRh4fH68jR47YlCp0paenKz8/X3369NGpU6e0aNEi3XLLLTp06JCio6PtjheyPB6PJF3yOGtah38YPXq0xo0bp5SUFB07dky//OUvdfvtt6u4uFgRERF2x7ONz+fTzJkzddNNN6lfv36Svj62HA5Hsz9kfq0fW5faV5J03333qWfPnkpMTNSBAwc0d+5clZWV6c0337QxrT0OHjwol8ul2tpadezYUZs2bVJqaqpKS0ttP6YoTWgVbr/9dv+/BwwYoPT0dPXs2VNvvPGGsrOzbUyG1mTixIn+f/fv318DBgzQ9ddfr127dmnkyJE2JrNXTk6ODh06xHWEBr5tXz366KP+f/fv31/dunXTyJEjdezYMV1//fU/dkxb9enTR6WlpaqurtYf/vAHZWVlqaioyO5YkrgQPGR16dJFERERzb4VUFlZqYSEBJtSXT06deqkG264QZ988ondUUJa07HEcRacn/zkJ+rSpcs1fZxNmzZNmzdv1s6dO9W9e3f/8oSEBNXX1+vs2bMB89fysfVt++pS0tPTJemaPLYcDod69eqltLQ05eXlaeDAgVqxYkVIHFOUphDlcDiUlpamwsJC/zKfz6fCwkK5XC4bk10dzp8/r2PHjqlbt252RwlpKSkpSkhICDjOvF6v9uzZw3Fm4OTJk/r888+vyePMsixNmzZNmzZt0o4dO5SSkhKwPi0tTW3atAk4tsrKylReXn7NHVvft68upbS0VJKuyWPrm3w+n+rq6kLimOLjuRCWm5urrKwsDRkyRMOGDdPy5ctVU1OjKVOm2B0t5MyePVt33nmnevbsqYqKCi1cuFARERG699577Y5mu/Pnzwf83+rx48dVWlqq2NhY9ejRQzNnztQzzzyj3r17KyUlRb/61a+UmJiosWPH2hfaJt+1r2JjY7Vo0SKNHz9eCQkJOnbsmObMmaNevXrJ7XbbmNoeOTk5Wrdunf74xz8qOjraf01JTEyM2rVrp5iYGGVnZys3N1exsbFyOp2aPn26XC6Xhg8
fbnP6H9f37atjx45p3bp1GjNmjDp37qwDBw5o1qxZuvXWWzVgwACb0/+45s+fr9tvv109evTQuXPntG7dOu3atUvbtm0LjWPqR/mOHoL2H//xH1aPHj0sh8NhDRs2zHr//fftjhSSJkyYYHXr1s1yOBzWP/3TP1kTJkywPvnkE7tjhYSdO3dakpo9srKyLMv6+rYDv/rVr6z4+HgrKirKGjlypFVWVmZvaJt817768ssvrVGjRlldu3a12rRpY/Xs2dN65JFHLI/HY3dsW1xqP0my1q5d65/56quvrF/84hfWddddZ7Vv3966++67rVOnTtkX2ibft6/Ky8utW2+91YqNjbWioqKsXr16WU888YRVXV1tb3AbPPTQQ1bPnj0th8Nhde3a1Ro5cqT1zjvv+NfbfUyFWZZl/Tj1DAAA4OrFNU0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAGKE0AAAAG/h9OqcxjzMXKQgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_df.line_number.plot.hist()" ] }, { "cell_type": "code", "execution_count": 114, "id": "99c45c9e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "15.0" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.percentile(train_df.line_number,98)" ] }, { "cell_type": "code", "execution_count": 115, "id": "acac96ff", "metadata": {}, "outputs": [], "source": [ "train_line_number_one_hot=tf.one_hot(train_df[\"line_number\"].to_numpy(),depth=15)\n", "val_line_number_one_hot=tf.one_hot(val_df[\"line_number\"].to_numpy(),depth=15)\n", "test_line_number_one_hot=tf.one_hot(test_df[\"line_number\"].to_numpy(),depth=15)\n" ] }, { "cell_type": "code", "execution_count": 116, "id": "8982b552", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 116, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_line_number_one_hot[:10]" ] }, { "cell_type": "code", "execution_count": 117, "id": "2c05ba32", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "total_lines\n", "11 24468\n", "10 23639\n", "12 22113\n", "9 19400\n", "13 18438\n", "14 14610\n", "8 12285\n", "15 10768\n", "7 7464\n", "16 7429\n", "17 5202\n", "6 3353\n", "18 3344\n", "19 2480\n", "20 1281\n", "5 1146\n", "21 770\n", "22 759\n", "23 264\n", "4 215\n", "24 200\n", "25 182\n", "26 81\n", "28 58\n", "3 32\n", "30 31\n", "27 28\n", "Name: count, dtype: int64" ] }, "execution_count": 117, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_df[\"total_lines\"].value_counts()" ] }, { "cell_type": "code", "execution_count": 118, "id": "87b40999", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "20.0" ] }, "execution_count": 118, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.percentile(train_df.total_lines,98)" ] }, { "cell_type": "code", "execution_count": 119, "id": "210c84b4", "metadata": 
{}, "outputs": [], "source": [ "train_total_lines=tf.one_hot(train_df[\"total_lines\"].to_numpy(),depth=20)\n", "valid_total_lines=tf.one_hot(val_df[\"total_lines\"].to_numpy(),depth=20)\n", "test_total_lines=tf.one_hot(test_df[\"total_lines\"].to_numpy(),depth=20)" ] }, { "cell_type": "code", "execution_count": 120, "id": "34d63477", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 120, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_total_lines[:5]" ] }, { "cell_type": "code", "execution_count": 121, "id": "2e75261c", "metadata": {}, "outputs": [], "source": [ "# Model 5 branches: token embeddings + character embeddings + positional one-hot features.\n", "\n", "# 1. Token branch: raw sentence string -> tf_hub_embedding_layer (512-d) -> Dense(128)\n", "token_inputs=tf.keras.layers.Input(shape=[],dtype=\"string\")\n", "embedding_layer=tf_hub_embedding_layer(token_inputs)\n", "token_outputs=tf.keras.layers.Dense(128,activation=\"relu\")(embedding_layer)\n", "first_model=tf.keras.Model(token_inputs,token_outputs)\n", "\n", "# 2. Char branch: character-level text -> char_vectorization -> char_embedding -> BiLSTM(32)\n", "char_inputs=tf.keras.layers.Input(shape=(1,),dtype=\"string\")\n", "char_tokens=char_vectorization(char_inputs)\n", "char_embedded=char_embedding(char_tokens)\n", "char_outputs=tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32))(char_embedded)\n", "second_model=tf.keras.Model(char_inputs,char_outputs)\n", "\n", "# 3. Line-number branch. tf.one_hot returns float32 by default, so the Input is\n", "#    declared float32 to match the tensors actually fed to it.\n", "#    NOTE: with depth=15, any line_number >= 15 is encoded as an all-zero vector.\n", "line_number=tf.keras.layers.Input(shape=(15,), dtype=tf.float32)\n", "x=tf.keras.layers.Dense(32,activation=\"relu\")(line_number)\n", "line_number_model=tf.keras.Model(line_number,x)\n", "\n", "# 4. Total-lines branch: float32 one-hot of depth 20 (total_lines >= 20 -> all-zero vector).\n", "total_lines_inputs = tf.keras.layers.Input(shape=(20,), dtype=tf.float32)\n", "y = tf.keras.layers.Dense(32, activation=\"relu\")(total_lines_inputs)\n", "total_line_model = tf.keras.Model(total_lines_inputs,y)\n", "\n", "# 5. Fuse the token and char branches, regularize with dropout, then append the\n", "#    positional features before the 5-class softmax head.\n", "combined_model=tf.keras.layers.Concatenate()([first_model.output,second_model.output])\n", "z = tf.keras.layers.Dense(256, activation=\"relu\")(combined_model)\n", "z = tf.keras.layers.Dropout(0.5)(z)\n", "\n", 
"z = tf.keras.layers.Concatenate()([line_number_model.output,total_line_model.output,z])\n", "output_layer = tf.keras.layers.Dense(5, activation=\"softmax\", name=\"output_layer\")(z)\n" ] }, { "cell_type": "code", "execution_count": 122, "id": "cb4d33bc", "metadata": {}, "outputs": [], "source": [ "model_5 = tf.keras.Model(inputs=[line_number_model.input,\n", " total_line_model.input,\n", " first_model.input, \n", " second_model.input],\n", " outputs=output_layer)\n", " " ] }, { "cell_type": "code", "execution_count": 123, "id": "f2b0c51f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"model_10\"\n", "__________________________________________________________________________________________________\n", " Layer (type) Output Shape Param # Connected to \n", "==================================================================================================\n", " input_7 (InputLayer) [(None, 1)] 0 [] \n", " \n", " input_6 (InputLayer) [(None,)] 0 [] \n", " \n", " text_vectorization_2 (TextVect (None, 290) 0 ['input_7[0][0]'] \n", " orization) \n", " \n", " keras_layer (KerasLayer) (None, 512) 256797824 ['input_6[0][0]'] \n", " \n", " embedding_1 (Embedding) (None, 290, 25) 1750 ['text_vectorization_2[2][0]'] \n", " \n", " dense_7 (Dense) (None, 128) 65664 ['keras_layer[2][0]'] \n", " \n", " bidirectional_1 (Bidirectional (None, 64) 14848 ['embedding_1[2][0]'] \n", " ) \n", " \n", " concatenate (Concatenate) (None, 192) 0 ['dense_7[0][0]', \n", " 'bidirectional_1[0][0]'] \n", " \n", " input_8 (InputLayer) [(None, 15)] 0 [] \n", " \n", " input_9 (InputLayer) [(None, 20)] 0 [] \n", " \n", " dense_10 (Dense) (None, 256) 49408 ['concatenate[0][0]'] \n", " \n", " dense_8 (Dense) (None, 32) 512 ['input_8[0][0]'] \n", " \n", " dense_9 (Dense) (None, 32) 672 ['input_9[0][0]'] \n", " \n", " dropout_2 (Dropout) (None, 256) 0 ['dense_10[0][0]'] \n", " \n", " concatenate_1 (Concatenate) (None, 320) 0 ['dense_8[0][0]', \n", " 
'dense_9[0][0]', \n", " 'dropout_2[0][0]'] \n", " \n", " output_layer (Dense) (None, 5) 1605 ['concatenate_1[0][0]'] \n", " \n", "==================================================================================================\n", "Total params: 256,932,283\n", "Trainable params: 134,459\n", "Non-trainable params: 256,797,824\n", "__________________________________________________________________________________________________\n" ] } ], "source": [ "model_5.summary()\n" ] }, { "cell_type": "code", "execution_count": 124, "id": "4da74e81", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "" ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tensorflow.keras.utils import plot_model\n", "plot_model(model_5,show_shapes=True)" ] }, { "cell_type": "code", "execution_count": 125, "id": "3ed8b66a", "metadata": {}, "outputs": [], "source": [ "# Label smoothing (0.2) softens the one-hot targets, discouraging over-confident predictions (a regularizer that can reduce overfitting)\n", "model_5.compile(loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.2),\n", " optimizer=tf.keras.optimizers.legacy.Adam(),\n", " metrics=[\"accuracy\"])" ] }, { "cell_type": "code", "execution_count": 126, "id": "8c88bd63", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(<_PrefetchDataset element_spec=((TensorSpec(shape=(None, 15), dtype=tf.float32, name=None), TensorSpec(shape=(None, 20), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None)), TensorSpec(shape=(None, 5), dtype=tf.float64, name=None))>,\n", " <_PrefetchDataset element_spec=((TensorSpec(shape=(None, 15), dtype=tf.float32, name=None), TensorSpec(shape=(None, 20), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None), TensorSpec(shape=(None,), dtype=tf.string, name=None)), TensorSpec(shape=(None, 5), dtype=tf.float64, name=None))>)" ] }, "execution_count": 126, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "# Create training and validation datasets (all four kinds of inputs)\n", "train_pos_char_token_data = tf.data.Dataset.from_tensor_slices((train_line_number_one_hot,\n", " train_total_lines,\n", " train_sentences,\n", " train_chars)) \n", "train_pos_char_token_labels = tf.data.Dataset.from_tensor_slices(train_labels_one_hot) \n", "train_pos_char_token_dataset = tf.data.Dataset.zip((train_pos_char_token_data, train_pos_char_token_labels))\n", "train_pos_char_token_dataset = train_pos_char_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE)\n", "\n", "# Validation dataset\n", "val_pos_char_token_data = tf.data.Dataset.from_tensor_slices((val_line_number_one_hot,\n", " valid_total_lines,\n", " val_sentences,\n", " val_chars))\n", "val_pos_char_token_labels = tf.data.Dataset.from_tensor_slices(val_labels_one_hot)\n", "val_pos_char_token_dataset = tf.data.Dataset.zip((val_pos_char_token_data, val_pos_char_token_labels))\n", "val_pos_char_token_dataset = val_pos_char_token_dataset.batch(32).prefetch(tf.data.AUTOTUNE) # turn into batches and prefetch appropriately\n", "\n", "# Check input shapes\n", "train_pos_char_token_dataset, val_pos_char_token_dataset" ] }, { "cell_type": "code", "execution_count": 127, "id": "2294d0cd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/3\n", "562/562 [==============================] - 178s 306ms/step - loss: 1.0938 - accuracy: 0.7248 - val_loss: 0.9813 - val_accuracy: 0.8082\n", "Epoch 2/3\n", "562/562 [==============================] - 171s 305ms/step - loss: 0.9613 - accuracy: 0.8168 - val_loss: 0.9476 - val_accuracy: 0.8261\n", "Epoch 3/3\n", "562/562 [==============================] - 163s 290ms/step - loss: 0.9449 - accuracy: 0.8263 - val_loss: 0.9347 - val_accuracy: 0.8314\n" ] } ], "source": [ "history_model_5 = model_5.fit(train_pos_char_token_dataset,\n", " steps_per_epoch=int(0.1 * len(train_pos_char_token_dataset)),\n", " epochs=3,\n", " 
validation_data=val_pos_char_token_dataset,\n", " validation_steps=int(0.1 * len(val_pos_char_token_dataset)))" ] }, { "cell_type": "code", "execution_count": 128, "id": "a4c0425e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "945/945 [==============================] - 225s 236ms/step\n" ] } ], "source": [ "model_5_pred=model_5.predict(val_pos_char_token_dataset)" ] }, { "cell_type": "code", "execution_count": 129, "id": "290114c1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 129, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Turn prediction probabilities into prediction classes\n", "model_5_preds = tf.argmax(model_5_pred, axis=1)\n", "model_5_preds\n" ] }, { "cell_type": "code", "execution_count": 130, "id": "07f11921", "metadata": {}, "outputs": [], "source": [ "model_5_results=calculate_results(val_labels_encoder,model_5_preds)" ] }, { "cell_type": "code", "execution_count": 131, "id": "5625f0bf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'accuracy': 83.50655368727658,\n", " 'precision': 83.40022533730422,\n", " 'recall': 83.50655368727658,\n", " 'F1-score': 83.35969828392814}" ] }, "execution_count": 131, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model_5_results" ] }, { "cell_type": "markdown", "id": "0c8faca1", "metadata": {}, "source": [ "# Comparing Models" ] }, { "cell_type": "code", "execution_count": 132, "id": "cbd01f7a", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "all_model_results=pd.DataFrame(\n", " { \n", " \"baseline\":model_0_reults,\n", " \"Model_1\": model_1_results,\n", " \"Model_2\": model_2_results,\n", " \"Model_3\": model_3_results,\n", " \"Model_4\": model_4_results,\n", " \"Model_5\": model_5_results})" ] }, { "cell_type": "code", "execution_count": 133, "id": "30d4cc61", "metadata": {}, "outputs": [], "source": [ "all_model_results=all_model_results.transpose()\n" ] }, { 
"cell_type": "code", "execution_count": 134, "id": "3444df75", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
accuracyprecisionrecallF1-score
baseline72.18323871.86467072.18323869.892504
Model_178.71044678.33334678.71044678.413042
Model_271.39216271.42823771.39216271.101595
Model_365.76194965.44382565.76194964.577781
Model_474.31815274.56791674.31815274.334530
Model_583.50655483.40022583.50655483.359698
\n", "
" ], "text/plain": [ " accuracy precision recall F1-score\n", "baseline 72.183238 71.864670 72.183238 69.892504\n", "Model_1 78.710446 78.333346 78.710446 78.413042\n", "Model_2 71.392162 71.428237 71.392162 71.101595\n", "Model_3 65.761949 65.443825 65.761949 64.577781\n", "Model_4 74.318152 74.567916 74.318152 74.334530\n", "Model_5 83.506554 83.400225 83.506554 83.359698" ] }, "execution_count": 134, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_model_results" ] }, { "cell_type": "code", "execution_count": 135, "id": "1ec359e5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 135, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "all_model_results.plot(kind=\"bar\",figsize=(10,5)).legend(bbox_to_anchor=(1.0,1.0))" ] }, { "cell_type": "code", "execution_count": 136, "id": "32950eb0", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAzYAAAJyCAYAAAAW8VWPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAxNklEQVR4nO3df5TWdZ3//8cgMCAwQ2DMwHEQLBUxTcUNp9y1iJXMSleOW0abpv3YYinl9EPa1E+moZ5dJXdB0xC1DXXdNVf7oRQlbQmmaPZTYgsdCme0NhhBGUTm+0dfJyfBdYZhrnk5t9s51znM+7rmmif1PO/jnfd1XVPV3t7eHgAAgIINqPQAAAAAu0vYAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxBlZ6gD+3Y8eObNiwISNGjEhVVVWlxwEAACqkvb09Tz75ZMaNG5cBA178mkyfC5sNGzakoaGh0mMAAAB9xPr167Pvvvu+6GP6XNiMGDEiyR+Hr6mpqfA0AABApbS2tqahoaGjEV5Mnwub515+VlNTI2wAAICX9BYVHx4AAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8QZWeoC+asI5X6/0CBX1yMUnVHoEAAB4yVyxAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAoXpfC5tlnn825556biRMnZujQoXnVq16Vz33uc2lvb+94THt7e84777yMHTs2Q4cOzfTp07N27doeHxwAAOA5XQqbSy65JFdeeWX+9V//Nb/4xS9yySWX5NJLL82//Mu/dDzm0ksvzRVXXJGrrroq9957b4YNG5YZM2Zk69atPT48AABAkgzsyoPvueeenHjiiTnhhBOSJBMmTMiNN96YH/7wh0n+eLVmwYIF+cxnPpMTTzwxSXLDDTekrq4ut912W971rnf18PgAAABdvGLz+te/PsuXL88vf/nLJMlDDz2U73//+zn++OOTJOvWrUtzc3OmT5/e8T21tbWZOnVqVq5cudPnbGtrS2tra6cbAABAV3Tpis0555yT1tbWTJo0KXvttVeeffbZXHTRRZk1a1aSpLm5OUlSV1fX6fvq6uo67vtz8+fPz2c/+9nuzA4AAJCki1ds/v3f/z1f+cpXsnTp0jzwwAO5/vrr80//9E+5/vrruz3AvHnzsmnTpo7b+vXru/1cAABA/9SlKzaf+MQncs4553S8V+bQQw/No48+mvnz5+e0005LfX19kqSlpSVjx47t+L6WlpYcfvjhO33O6urqVFdXd3N8AACALl6xeeqppzJgQOdv2WuvvbJjx44kycSJE1NfX5/ly5d33N/a2pp77703jY2NPTAuAADAC3Xpis3b3/72XHTRRRk/fnwOOeSQPPjgg7nssstyxhlnJEmqqqpy1lln5cILL8wBBxyQiRMn5txzz824ceNy0kkn7Yn5AQAAuhY2//Iv/5Jzzz03H/nIR/L4449n3Lhx+dCHPpTzzj
uv4zGf/OQns2XLlnzwgx/Mxo0bc8wxx+TOO+/MkCFDenx4AACAJKlqb29vr/QQz9fa2pra2tps2rQpNTU1FZtjwjlfr9jP7gseufiESo8AAEA/15U26NJ7bAAAAPoiYQMAABSvS++xgf7EyxG9HBEAKIcrNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUbWOkBAPqqCed8vdIjVNwjF59Q6REA4CVxxQYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4g2s9AAA0FdNOOfrlR6h4h65+IRKjwDwkrhiAwAAFE/YAAAAxetS2EyYMCFVVVUvuM2ePTtJsnXr1syePTujR4/O8OHDM3PmzLS0tOyRwQEAAJ7TpbC577778thjj3XcvvWtbyVJTjnllCTJ2WefnTvuuCO33HJLVqxYkQ0bNuTkk0/u+akBAACep0sfHvDKV76y09cXX3xxXvWqV+XYY4/Npk2bsnjx4ixdujTTpk1LkixZsiQHH3xwVq1alaOPPrrnpgYAAHiebr/HZtu2bfm3f/u3nHHGGamqqsrq1avzzDPPZPr06R2PmTRpUsaPH5+VK1fu8nna2trS2tra6QYAANAV3Q6b2267LRs3bszpp5+eJGlubs7gwYMzcuTITo+rq6tLc3PzLp9n/vz5qa2t7bg1NDR0dyQAAKCf6vbvsVm8eHGOP/74jBs3brcGmDdvXubOndvxdWtrq7gBAPoEv8vI7zKiHN0Km0cffTTf/va3c+utt3Ycq6+vz7Zt27Jx48ZOV21aWlpSX1+/y+eqrq5OdXV1d8YAAABI0s2Xoi1ZsiRjxozJCSf8qeCnTJmSQYMGZfny5R3H1qxZk6ampjQ2Nu7+pAAAALvQ5Ss2O3bsyJIlS3Laaadl4MA/fXttbW3OPPPMzJ07N6NGjUpNTU3mzJmTxsZGn4gGAADsUV0Om29/+9tpamrKGWec8YL7Lr/88gwYMCAzZ85MW1tbZsyYkUWLFvXIoAAAALvS5bA57rjj0t7evtP7hgwZkoULF2bhwoW7PRgAAMBL1e2PewYAAOgrhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFC8gZUeAAAA+qoJ53y90iNU3CMXn1DpEV4SV2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACK1+Ww+e1vf5v3vOc9GT16dIYOHZpDDz00999/f8f97e3tOe+88zJ27NgMHTo006dPz9q1a3t0aAAAgOfrUtj84Q9/yBve8IYMGjQo3/zmN/Pzn/88//zP/5xXvOIVHY+59NJLc8UVV+Sqq67Kvffem2HDhmXGjBnZunVrjw8PAACQJAO78uBLLrkkDQ0NWbJkScexiRMndvy5vb09CxYsyGc+85mceOKJSZIbbrghdXV1ue222/Kud72rh8YGAAD4ky5dsbn99ttz1FFH5ZRTTs
mYMWNyxBFH5Jprrum4f926dWlubs706dM7jtXW1mbq1KlZuXLlTp+zra0tra2tnW4AAABd0aWw+fWvf50rr7wyBxxwQO666658+MMfzkc/+tFcf/31SZLm5uYkSV1dXafvq6ur67jvz82fPz+1tbUdt4aGhu78PQAAgH6sS2GzY8eOHHnkkfn85z+fI444Ih/84AfzgQ98IFdddVW3B5g3b142bdrUcVu/fn23nwsAAOifuhQ2Y8eOzeTJkzsdO/jgg9PU1JQkqa+vT5K0tLR0ekxLS0vHfX+uuro6NTU1nW4AAABd0aWwecMb3pA1a9Z0OvbLX/4y++23X5I/fpBAfX19li9f3nF/a2tr7r333jQ2NvbAuAAAAC/UpU9FO/vss/P6178+n//85/O3f/u3+eEPf5irr746V199dZKkqqoqZ511Vi688MIccMABmThxYs4999yMGzcuJ5100p6YHwAAoGth8xd/8Rf56le/mnnz5uWCCy7IxIkTs2DBgsyaNavjMZ/85CezZcuWfPCDH8zGjRtzzDHH5M4778yQIUN6fHgAAICki2GTJG9729vytre9bZf3V1VV5YILLsgFF1ywW4MBAAC8VF16jw0AAEBfJGwAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4XQqb//f//l+qqqo63SZNmtRx/9atWzN79uyMHj06w4cPz8yZM9PS0tLjQwMAADxfl6/YHHLIIXnsscc6bt///vc77jv77LNzxx135JZbbsmKFSuyYcOGnHzyyT06MAAAwJ8b2OVvGDgw9fX1Lzi+adOmLF68OEuXLs20adOSJEuWLMnBBx+cVatW5eijj97p87W1taWtra3j69bW1q6OBAAA9HNdvmKzdu3ajBs3Lvvvv39mzZqVpqamJMnq1avzzDPPZPr06R2PnTRpUsaPH5+VK1fu8vnmz5+f2trajltDQ0M3/hoAAEB/1qWwmTp1aq677rrceeedufLKK7Nu3br85V/+ZZ588sk0Nzdn8ODBGTlyZKfvqaurS3Nz8y6fc968edm0aVPHbf369d36iwAAAP1Xl16Kdvzxx3f8+bDDDsvUqVOz33775d///d8zdOjQbg1QXV2d6urqbn0vAABAspsf9zxy5MgceOCB+Z//+Z/U19dn27Zt2bhxY6fHtLS07PQ9OQAAAD1lt8Jm8+bN+dWvfpWxY8dmypQpGTRoUJYvX95x/5o1a9LU1JTGxsbdHhQAAGBXuvRStI9//ON5+9vfnv322y8bNmzI+eefn7322iunnnpqamtrc+aZZ2bu3LkZNWpUampqMmfOnDQ2Nu7yE9EAAAB6QpfC5je/+U1OPfXU/P73v88rX/nKHHPMMVm1alVe+cpXJkkuv/zyDBgwIDNnzkxbW1tmzJiRRYsW7ZHBAQAAntOlsLnpppte9P4hQ4Zk4cKFWbhw4W4NBQAA0BW79R4bAACAvkDYAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE
/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxdutsLn44otTVVWVs846q+PY1q1bM3v27IwePTrDhw/PzJkz09LSsrtzAgAA7FK3w+a+++7LF7/4xRx22GGdjp999tm54447csstt2TFihXZsGFDTj755N0eFAAAYFe6FTabN2/OrFmzcs011+QVr3hFx/FNmzZl8eLFueyyyzJt2rRMmTIlS5YsyT333JNVq1b12NAAAADP162wmT17dk444YRMnz690/HVq1fnmWee6XR80qRJGT9+fFauXLnT52pra0tra2unGwAAQFcM7Oo33HTTTXnggQdy3333veC+5ubmDB48OCNHjux0vK6uLs3NzTt9vvnz5+ezn/1sV8cAAADo0KUrNuvXr8/HPvaxfOUrX8mQIUN6ZIB58+Zl06ZNHbf169f3yPMCAAD9R5fCZvXq1Xn88cdz5JFHZuDAgRk4cGBWrFiRK664IgMHDkxdXV22bduWjRs3dvq+lpaW1NfX7/Q5q6urU1NT0+kGAADQFV16Kdqb3/zm/OQnP+l07H3ve18mTZqUT33qU2loaMigQYOyfPnyzJw5M0myZs2aNDU1pbGxseemBgAAeJ4uhc2IESPymte8ptOxYcOGZfTo0R3HzzzzzMydOzejRo1KTU1N5syZk8bGxhx99NE9NzUAAMDzdPnDA/4vl19+eQYMGJCZM2emra0tM2bMyKJFi3r6xwAAAHTY7bC5++67O309ZMiQLFy4MAsXLtzdpwYAAHhJuvV7bAAAAPoSYQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMXrUthceeWVOeyww1JTU5Oampo0Njbmm9/8Zsf9W7duzezZszN69OgMHz48M2fOTEtLS48PDQAA8HxdCpt99903F198cVavXp37778/06ZNy4knnpif/exnSZKzzz47d9xxR2655ZasWLEiGzZsyMknn7xHBgcAAHjOwK48+O1vf3unry+66KJceeWVWbVqVfbdd98sXrw4S5cuzbRp05IkS5YsycEHH5xVq1bl6KOP3ulztrW1pa2trePr1tbWrv4dAACAfq7b77F59tlnc9NNN2XLli1pbGzM6tWr88wzz2T69Okdj5k0aVLGjx+flStX7vJ55s+fn9ra2o5bQ0NDd0cCAAD6qS6HzU9+8pMMHz481dXV+fu///t89atfzeTJk9Pc3JzBgwdn5MiRnR5fV1eX5ubmXT7fvHnzsmnTpo7b+vXru/yXAAAA+rcuvRQtSQ466KD86Ec/yqZNm/If//EfOe2007JixYpuD1BdXZ3q6upufz8AAECXw2bw4MF59atfnSSZMmVK7rvvvnzhC1/IO9/5zmzbti0bN27sdNWmpaUl9fX1PTYwAADAn9vt32OzY8eOtLW1ZcqUKRk0aFCWL1/ecd+aNWvS1NSUxsbG3f0xAAAAu9SlKzbz5s3L8ccfn/Hjx+fJJ5/M0qVLc/fdd+euu+5KbW1tzj
zzzMydOzejRo1KTU1N5syZk8bGxl1+IhoAAEBP6FLYPP7443nve9+bxx57LLW1tTnssMNy11135a//+q+TJJdffnkGDBiQmTNnpq2tLTNmzMiiRYv2yOAAAADP6VLYLF68+EXvHzJkSBYuXJiFCxfu1lAAAABdsdvvsQEAAKg0YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABSvS2Ezf/78/MVf/EVGjBiRMWPG5KSTTsqaNWs6PWbr1q2ZPXt2Ro8eneHDh2fmzJlpaWnp0aEBAACer0ths2LFisyePTurVq3Kt771rTzzzDM57rjjsmXLlo7HnH322bnjjjtyyy23ZMWKFdmwYUNOPvnkHh8cAADgOQO78uA777yz09fXXXddxowZk9WrV+ev/uqvsmnTpixevDhLly7NtGnTkiRLlizJwQcfnFWrVuXoo4/uuckBAAD+f7v1HptNmzYlSUaNGpUkWb16dZ555plMnz694zGTJk3K+PHjs3Llyp0+R1tbW1pbWzvdAAAAuqLbYbNjx46cddZZecMb3pDXvOY1SZLm5uYMHjw4I0eO7PTYurq6NDc37/R55s+fn9ra2o5bQ0NDd0cCAAD6qW6HzezZs/PTn/40N910024NMG/evGzatKnjtn79+t16PgAAoP/p0ntsnvMP//AP+drXvpbvfe972XfffTuO19fXZ9u2bdm4cWOnqzYtLS2pr6/f6XNVV1enurq6O2MAAAAk6eIVm/b29vzDP/xDvvrVr+Y73/lOJk6c2On+KVOmZNCgQVm+fHnHsTVr1qSpqSmNjY09MzEAAMCf6dIVm9mzZ2fp0qX5r//6r4wYMaLjfTO1tbUZOnRoamtrc+aZZ2bu3LkZNWpUampqMmfOnDQ2NvpENAAAYI/pUthceeWVSZI3vvGNnY4vWbIkp59+epLk8ssvz4ABAzJz5sy0tbVlxowZWbRoUY8MCwAAsDNdCpv29vb/8zFDhgzJwoULs3Dhwm4PBQAA0BW79XtsAAAA+gJhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDxhA0AAFA8YQMAABRP2AAAAMUTNgAAQPGEDQAAUDxhAwAAFE/YAAAAxety2Hzve9/L29/+9owbNy5VVVW57bbbOt3f3t6e8847L2PHjs3QoUMzffr0rF27tqfmBQAAeIEuh82WLVvy2te+NgsXLtzp/ZdeemmuuOKKXHXVVbn33nszbNiwzJgxI1u3bt3tYQEAAHZmYFe/4fjjj8/xxx+/0/va29uzYMGCfOYzn8mJJ56YJLnhhhtSV1eX2267Le9617t2b1oAAICd6NH32Kxbty7Nzc2ZPn16x7Ha2tpMnTo1K1eu3On3tL
W1pbW1tdMNAACgK3o0bJqbm5MkdXV1nY7X1dV13Pfn5s+fn9ra2o5bQ0NDT44EAAD0AxX/VLR58+Zl06ZNHbf169dXeiQAAKAwPRo29fX1SZKWlpZOx1taWjru+3PV1dWpqanpdAMAAOiKHg2biRMnpr6+PsuXL+841tramnvvvTeNjY09+aMAAAA6dPlT0TZv3pz/+Z//6fh63bp1+dGPfpRRo0Zl/PjxOeuss3LhhRfmgAMOyMSJE3Puuedm3LhxOemkk3pybgAAgA5dDpv7778/b3rTmzq+njt3bpLktNNOy3XXXZdPfvKT2bJlSz74wQ9m48aNOeaYY3LnnXdmyJAhPTc1AADA83Q5bN74xjemvb19l/dXVVXlggsuyAUXXLBbgwEAALxUFf9UNAAAgN0lbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOIJGwAAoHjCBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2wAAIDiCRsAAKB4wgYAACiesAEAAIonbAAAgOLtsbBZuHBhJkyYkCFDhmTq1Kn54Q9/uKd+FAAA0M/tkbC5+eabM3fu3Jx//vl54IEH8trXvjYzZszI448/vid+HAAA0M/tkbC57LLL8oEPfCDve9/7Mnny5Fx11VXZe++9c+211+6JHwcAAPRzA3v6Cbdt25bVq1dn3rx5HccGDBiQ6dOnZ+XKlS94fFtbW9ra2jq+3rRpU5KktbW1p0frkh1tT1X051dapf/37wvsgB3o7zuQ2AM7YAfsgB2wA5Xdged+dnt7+//52B4Pm9/97nd59tlnU1dX1+l4XV1dHn744Rc8fv78+fnsZz/7guMNDQ09PRpdULug0hNQaXaAxB5gB7AD9I0dePLJJ1NbW/uij+nxsOmqefPmZe7cuR1f79ixI//7v/+b0aNHp6qqqoKTVU5ra2saGhqyfv361NTUVHocKsAOYAewAyT2ADvQ3t6eJ598MuPGjfs/H9vjYbPPPvtkr732SktLS6fjLS0tqa+vf8Hjq6urU11d3enYyJEje3qsItXU1PTLBeZP7AB2ADtAYg/o3zvwf12peU6Pf3jA4MGDM2XKlCxfvrzj2I4dO7J8+fI0Njb29I8DAADYMy9Fmzt3bk477bQcddRRed3rXpcFCxZky5Yted/73rcnfhwAANDP7ZGweec735knnngi5513Xpqbm3P44YfnzjvvfMEHCrBz1dXVOf/881/wEj36DzuAHcAOkNgD7EBXVLW/lM9OAwAA6MP2yC/oBAAA6E3CBgAAKJ6wAQAAiidsAACA4gkbAACgeMIGAAAonrABAACKJ2z6oHXr1uVb3/pWfvrTn1Z6FPqIP/zhD7nhhhsqPQYVMG3atDz66KOVHoNetGPHjl0eb2pq6uVpqJTt27fn29/+dr74xS/mySefTJJs2LAhmzdvrvBk7Gnt7e1Zt25dtm/fniTZtm1bbr755txwww353e9+V+Hp+ja/oLPCPvKRj+TSSy/N8OHD8/TTT+fv/u7v8tWvfjXt7e2pqqrKsccem9tvvz3Dhw+v9KhU0EMPPZQjjzwyzz77bKVHYQ+5/fbbd3r85JNPzhe+8IU0NDQkSd7xjnf05lj0otbW1rz//e/PHXfckZqamnzoQx/K+eefn7322itJ0tLSknHjxjkP9AOPPvpo3vKWt6SpqSltbW355S9/mf333z8f+9jH0tbWlquuuqrSI7KHrFmzJjNmzMj69e
uz//77Z9myZTnllFPy8MMPp729PXvvvXfuueeeHHDAAZUetU8SNhW211575bHHHsuYMWPy6U9/Ol/+8pdzww03ZOrUqXnwwQdz2mmn5ZRTTsn8+fMrPSp7UGtr64ve/+Mf/zjHHnus/6B5GRswYECqqqryYqfkqqoqO/Ay9rGPfSx33nlnLrroomzcuDEXXnhhXvOa1+TWW2/N4MGD09LSkrFjx+7yig4vHyeddFJGjBiRxYsXZ/To0XnooYey//775+67784HPvCBrF27ttIjsoecdNJJaW9vz4UXXphrr702d911Vw488MDccsst2bFjR0455ZTU1tbmy1/+cqVH7ZOETYUNGDAgzc3NGTNmTA499NB8+tOfzqmnntpx/+23355PfOITWbNmTQWnZE977j9qd+W5K3j+o/bl6/jjj89ee+2Va6+9NmPGjOk4PmjQoDz00EOZPHlyBaejN+y33365/vrr88Y3vjFJ8rvf/S4nnHBCRo4cmdtvvz0bN250xaafGD16dO65554cdNBBGTFiREfYPPLII5k8eXKeeuqpSo/IHjJmzJgsW7Yshx9+eLZs2ZIRI0bke9/7Xo455pgkyT333JNTTz3VS5R3YWClByAd/0Hb3Nycww47rNN9r33ta7N+/fpKjEUvGjFiRP7xH/8xU6dO3en9a9euzYc+9KFenore9M1vfjOXX355jjrqqCxatChve9vbKj0SveyJJ57Ifvvt1/H1Pvvsk29/+9uZMWNG3vrWt+ZLX/pSBaejN+3YsWOnAfub3/wmI0aMqMBE9JbNmzdn1KhRSZJhw4Zl2LBhGTt2bMf9DQ0NaWlpqdR4fZ6w6QPOPffc7L333hkwYEA2bNiQQw45pOO+3//+9xk2bFgFp6M3HHnkkUmSY489dqf3jxw58kVfosTLw9lnn503velNmTVrVu64445cfvnllR6JXjR+/Pj84he/yMSJEzuOjRgxIsuWLctxxx2Xv/mbv6ngdPSm4447LgsWLMjVV1+d5I//ALp58+acf/75eetb31rh6diTxo0bl6ampowfPz5Jcumll3a6iv/EE0/kFa94RaXG6/N8KlqF/dVf/VXWrFmTBx98MJMnT37BpcVvfOMbnUKHl6d3v/vdGTJkyC7vr6+vz/nnn9+LE1Ephx9+eO6///5UVVXl8MMPF7T9yHHHHZclS5a84Pjw4cNz1113veg5gpeXf/7nf84PfvCDTJ48OVu3bs273/3uTJgwIb/97W9zySWXVHo89qDp06fn4Ycf7vj6wx/+cKerdMuWLev4x1BeyHts+rhf//rXGTx4cPbdd99KjwL0sttvvz3f/e53M2/evE7/YsfL0x/+8IcXXLV/vieffDIPPPDALq/s8vKyffv23HTTTfnxj3+czZs358gjj8ysWbMydOjQSo9GBa1bty5Dhgzp9PI0/kTYFObQQw/NN77xjY6PfqV/sgfYAewA4DzQmffYFOaRRx7JM888U+kxqDB7gB3ADry8rV27Nt/97nfz+OOPv+Ajvs8777wKTUVf4zzQmbABAOhDrrnmmnz4wx/OPvvsk/r6+k6/DqCqqkrYwC4IGwCAPuTCCy/MRRddlE996lOVHgWK4lPRAAD6kD/84Q855ZRTKj0GFEfYAAD0IaecckqWLVtW6TGgOF6KBgDQh7z61a/Oueeem1WrVuXQQw/NoEGDOt3/0Y9+tEKTQd/m454Ls3Tp0px44okZNmxYpUehguwBdgA78PI1ceLEXd5XVVWVX//61704DX2Z80BnwqaCrrjiipf8WP868/JlD7AD2AHAeWD3CZsKerF/kXk+/zrz8mYPsAPYAcB5YPcJGwCACps7d24+97nPZdiwYZk7d+6LPvayyy7rpamgLD48oI/Ztm1b1q1bl1e96lUZOND/Pf2VPcAOYAf6lwcffLDjN8g/+OCDu3zc839ZJy9/zgNd4+Oe+4innnoqZ555Zvbee+8ccsghaWpqSpLMmTMnF1
98cYWno7fYA+wAdqB/+u53v5uRI0d2/HlXt+985zuVHZRe4TzQPcKmj5g3b14eeuih3H333RkyZEjH8enTp+fmm2+u4GT0JnuAHcAOAM4D3eOaVh9x22235eabb87RRx/d6TLzIYcckl/96lcVnIzeZA+wA9iB/unkk09+yY+99dZb9+Ak9AXOA90jbPqIJ554ImPGjHnB8S1btng9bT9iD7AD2IH+qba2ttIj0Ic4D3SPsOkjjjrqqHz961/PnDlzkvzpzYFf+tKX0tjYWMnR6EX2ADuAHeiflixZUukR6EOcB7pH2PQRn//853P88cfn5z//ebZv354vfOEL+fnPf5577rknK1asqPR49BJ7gB3ADpAk27dvz913351f/epXefe7350RI0Zkw4YNqampyfDhwys9HnuY80D3+PCAPuKYY47Jj370o2zfvj2HHnpoli1bljFjxmTlypWZMmVKpcejl9gD7AB2gEcffTSHHnpoTjzxxMyePTtPPPFEkuSSSy7Jxz/+8QpPR29wHugev6ATAKAPOemkkzJixIgsXrw4o0ePzkMPPZT9998/d999dz7wgQ9k7dq1lR4R+iQvRaug1tbWl/zYmpqaPTgJlWQPsAPYAZ7vv//7v3PPPfdk8ODBnY5PmDAhv/3tbys0FXua88DuEzYVNHLkyJf8yRbPPvvsHp6GSrEH2AHsAM+3Y8eOnf7//Jvf/CYjRoyowET0BueB3SdsKui73/1ux58feeSRnHPOOTn99NM7Pu1i5cqVuf766zN//vxKjUgvsAfYAewAz3fcccdlwYIFufrqq5P88ROxNm/enPPPPz9vfetbKzwde4rzQA9op0+YNm1a+9KlS19w/Ctf+Ur7scce2/sDURH2ADuAHWD9+vXtkydPbj/44IPbBw4c2H700Ue3jx49uv2ggw5qb2lpqfR49ALnge7x4QF9xN57752HHnooBxxwQKfjv/zlL3P44YfnqaeeqtBk9CZ7gB3ADpD88eOeb7755jz00EPZvHlzjjzyyMyaNStDhw6t9Gj0AueB7vFxz31EQ0NDrrnmmhcc/9KXvpSGhoYKTEQl2APsAHaAJBk4cGBmzZqVSy+9NIsWLcr73/9+UdOPOA90jys2fcQ3vvGNzJw5M69+9aszderUJMkPf/jDrF27Nv/5n//pNbX9hD3ADmAHuP7667PPPvvkhBNOSJJ88pOfzNVXX53JkyfnxhtvzH777VfhCdnTnAe6R9j0Ib/5zW+yaNGiPPzww0mSgw8+OH//93+vzPsZe4AdwA70bwcddFCuvPLKTJs2LStXrsyb3/zmLFiwIF/72tcycODA3HrrrZUekV7gPNB1wgYAoA/Ze++98/DDD2f8+PH51Kc+lcceeyw33HBDfvazn+WNb3xjnnjiiUqPCH2Sj3vuQzZu3JjFixfnF7/4RZLkkEMOyRlnnJHa2toKT0ZvsgfYAexA/zZ8+PD8/ve/z/jx47Ns2bLMnTs3STJkyJA8/fTTFZ6O3uI80HWu2PQR999/f2bMmJGhQ4fmda97XZLkvvvuy9NPP51ly5blyCOPrPCE9AZ7gB3ADjBr1qw8/PDDOeKII3LjjTemqakpo0ePzu23355Pf/rT+elPf1rpEdnDnAe6R9j0EX/5l3+ZV7/61bnmmmsycOAfL6Rt374973//+/PrX/863/ve9yo8Ib3BHmAHsANs3Lgxn/nMZ7J+/fp8+MMfzlve8pYkyfnnn5/BgwfnH//xHys8IXua80D3CJs+YujQoXnwwQczadKkTsd//vOf56ijjvJ55f2EPcAOYAcA54Hu8Xts+oiampo0NTW94Pj69eszYsSICkxEJdgD7AB2gOc89dRTefjhh/PjH/+4042XP+eB7vHhAX3EO9/5zpx55pn5p3/6p7z+9a9PkvzgBz/IJz7xiZx66qkVno7eYg+wA9gBnnjiiZx++um58847d3r/s88+28sT0ducB7qpnT6hra2t/aMf/W
j74MGD2wcMGNBeVVXVXl1d3X7WWWe1b926tdLj0UvsAXYAO8C73/3u9je84Q3t9913X/uwYcPaly1b1v7lL3+5/aCDDmr/2te+Vunx6AXOA93jPTZ9zFNPPZVf/epXSZJXvepV2XvvvSs8EZVgD7AD2IH+a+zYsfmv//qvvO51r0tNTU3uv//+HHjggbn99ttz6aWX5vvf/36lR6SXOA90jZeiVdgZZ5zxkh537bXX7uFJqCR7gB3ADvCcLVu2ZMyYMUmSV7ziFXniiSdy4IEH5tBDD80DDzxQ4enYk5wHdo+wqbDrrrsu++23X4444oi4eNZ/2QPsAHaA5xx00EFZs2ZNJkyYkNe+9rX54he/mAkTJuSqq67K2LFjKz0ee5DzwO7xUrQKmz17dm688cbst99+ed/73pf3vOc9GTVqVKXHopfZA+wAdoDn/Nu//Vu2b9+e008/PatXr85b3vKW/P73v8/gwYNz/fXX553vfGelR2QPcR7YPcKmD2hra8utt96aa6+9Nvfcc09OOOGEnHnmmTnuuONSVVVV6fHoJfYAO4Ad4M+1t7fn6aefzsMPP5zx48dnn332qfRI7GHOA90nbPqYRx99NNddd11uuOGGbN++PT/72c8yfPjwSo9FL7MH2AHsQP+2ePHiXH755Vm7dm2S5IADDshZZ52V97///RWejN7kPNA13mPTxwwYMCBVVVVpb2/3OfX9mD3ADmAH+q/zzjsvl112WebMmZPGxsYkycqVK3P22WenqakpF1xwQYUnpLc4D3TNgEoPwB8vOd54443567/+6xx44IH5yU9+kn/9139NU1OTKu9H7AF2ADtAklx55ZW55pprMn/+/LzjHe/IO97xjsyfPz9XX311Fi1aVOnx2MOcB7rPFZsK+8hHPpKbbropDQ0NOeOMM3LjjTd6/Ww/ZA+wA9gBnvPMM8/kqKOOesHxKVOmZPv27RWYiN7iPLB7vMemwgYMGJDx48fniCOOeNE3hN166629OBW9zR5gB7ADPGfOnDkZNGhQLrvssk7HP/7xj+fpp5/OwoULKzQZe5rzwO5xxabC3vve9/qEC+wBdgA70M/NnTu3489VVVX50pe+lGXLluXoo49Oktx7771pamrKe9/73kqNSC9wHtg9rtgAAFTYm970ppf0uKqqqnznO9/Zw9NAmYQNAABQPJ+KBgAAFE/YAAAAxRM2AABA8YQNAABQPGEDAAAUT9gAAADFEzYAAEDx/j8lPLuyRRI47wAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "all_model_results.sort_values(\"F1-score\", ascending=False)[\"F1-score\"].plot(kind=\"bar\", figsize=(10, 7));" ] }, { "cell_type": "code", "execution_count": 137, "id": "bf2c88d5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:Found untraced functions such as lstm_cell_4_layer_call_fn, lstm_cell_4_layer_call_and_return_conditional_losses, lstm_cell_5_layer_call_fn, lstm_cell_5_layer_call_and_return_conditional_losses while saving (showing 4 of 4). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: skimlit_final_model/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: skimlit_final_model/assets\n" ] } ], "source": [ "model_5.save(\"skimlit_final_model\") " ] }, { "cell_type": "code", "execution_count": 3, "id": "4c31f7fe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Metal device set to: Apple M2\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2023-07-08 19:38:08.228766: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz\n" ] } ], "source": [ "loaded_model=tf.keras.models.load_model(\"skimlit_final_model\")" ] }, { "cell_type": "code", "execution_count": 27, "id": "33b66792", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'abstract': 'This RCT examined the efficacy of a manualized social intervention for children with HFASDs. Participants were randomly assigned to treatment or wait-list conditions. Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language. A response-cost program was applied to reduce problem behaviors and foster skills acquisition. 
Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures). Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents. High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity. Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.',\n", " 'source': 'https://pubmed.ncbi.nlm.nih.gov/20232240/',\n", " 'details': 'RCT of a manualized social treatment for high-functioning autism spectrum disorders'},\n", " {'abstract': \"Postpartum depression (PPD) is the most prevalent mood disorder associated with childbirth. No single cause of PPD has been identified, however the increased risk of nutritional deficiencies incurred through the high nutritional requirements of pregnancy may play a role in the pathology of depressive symptoms. Three nutritional interventions have drawn particular interest as possible non-invasive and cost-effective prevention and/or treatment strategies for PPD; omega-3 (n-3) long chain polyunsaturated fatty acids (LCPUFA), vitamin D and overall diet. We searched for meta-analyses of randomised controlled trials (RCT's) of nutritional interventions during the perinatal period with PPD as an outcome, and checked for any trials published subsequently to the meta-analyses. Fish oil: Eleven RCT's of prenatal fish oil supplementation RCT's show null and positive effects on PPD symptoms. Vitamin D: no relevant RCT's were identified, however seven observational studies of maternal vitamin D levels with PPD outcomes showed inconsistent associations. Diet: Two Australian RCT's with dietary advice interventions in pregnancy had a positive and null result on PPD. With the exception of fish oil, few RCT's with nutritional interventions during pregnancy assess PPD. 
Further research is needed to determine whether nutritional intervention strategies during pregnancy can protect against symptoms of PPD. Given the prevalence of PPD and ease of administering PPD measures, we recommend future prenatal nutritional RCT's include PPD as an outcome.\",\n", " 'source': 'https://pubmed.ncbi.nlm.nih.gov/28012571/',\n", " 'details': 'Formatting removed (can be used to compare model to actual example)'},\n", " {'abstract': 'Mental illness, including depression, anxiety and bipolar disorder, accounts for a significant proportion of global disability and poses a substantial social, economic and heath burden. Treatment is presently dominated by pharmacotherapy, such as antidepressants, and psychotherapy, such as cognitive behavioural therapy; however, such treatments avert less than half of the disease burden, suggesting that additional strategies are needed to prevent and treat mental disorders. There are now consistent mechanistic, observational and interventional data to suggest diet quality may be a modifiable risk factor for mental illness. This review provides an overview of the nutritional psychiatry field. It includes a discussion of the neurobiological mechanisms likely modulated by diet, the use of dietary and nutraceutical interventions in mental disorders, and recommendations for further research. Potential biological pathways related to mental disorders include inflammation, oxidative stress, the gut microbiome, epigenetic modifications and neuroplasticity. Consistent epidemiological evidence, particularly for depression, suggests an association between measures of diet quality and mental health, across multiple populations and age groups; these do not appear to be explained by other demographic, lifestyle factors or reverse causality. 
Our recently published intervention trial provides preliminary clinical evidence that dietary interventions in clinically diagnosed populations are feasible and can provide significant clinical benefit. Furthermore, nutraceuticals including n-3 fatty acids, folate, S-adenosylmethionine, N-acetyl cysteine and probiotics, among others, are promising avenues for future research. Continued research is now required to investigate the efficacy of intervention studies in large cohorts and within clinically relevant populations, particularly in patients with schizophrenia, bipolar and anxiety disorders.',\n", " 'source': 'https://pubmed.ncbi.nlm.nih.gov/28942748/',\n", " 'details': 'Effect of nutrition on mental health'},\n", " {'abstract': \"Hepatitis C virus (HCV) and alcoholic liver disease (ALD), either alone or in combination, count for more than two thirds of all liver diseases in the Western world. There is no safe level of drinking in HCV-infected patients and the most effective goal for these patients is total abstinence. Baclofen, a GABA(B) receptor agonist, represents a promising pharmacotherapy for alcohol dependence (AD). Previously, we performed a randomized clinical trial (RCT), which demonstrated the safety and efficacy of baclofen in patients affected by AD and cirrhosis. The goal of this post-hoc analysis was to explore baclofen's effect in a subgroup of alcohol-dependent HCV-infected cirrhotic patients. Any patient with HCV infection was selected for this analysis. Among the 84 subjects randomized in the main trial, 24 alcohol-dependent cirrhotic patients had a HCV infection; 12 received baclofen 10mg t.i.d. and 12 received placebo for 12-weeks. With respect to the placebo group (3/12, 25.0%), a significantly higher number of patients who achieved and maintained total alcohol abstinence was found in the baclofen group (10/12, 83.3%; p=0.0123). 
Furthermore, in the baclofen group, compared to placebo, there was a significantly higher increase in albumin values from baseline (p=0.0132) and a trend toward a significant reduction in INR levels from baseline (p=0.0716). In conclusion, baclofen was safe and significantly more effective than placebo in promoting alcohol abstinence, and improving some Liver Function Tests (LFTs) (i.e. albumin, INR) in alcohol-dependent HCV-infected cirrhotic patients. Baclofen may represent a clinically relevant alcohol pharmacotherapy for these patients.\",\n", " 'source': 'https://pubmed.ncbi.nlm.nih.gov/22244707/',\n", " 'details': 'Baclofen promotes alcohol abstinence in alcohol dependent cirrhotic patients with hepatitis C virus (HCV) infection'}]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import json\n", "with open(\"skimlit_example_abstracts.json\", \"r\") as f:\n", " example_abstracts = json.load(f)\n", "\n", "example_abstracts" ] }, { "cell_type": "code", "execution_count": 29, "id": "a0998e3e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/ujjwalbansal/anaconda3/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "data": { "text/plain": [ "['This RCT examined the efficacy of a manualized social intervention for children with HFASDs.',\n", " 'Participants were randomly assigned to treatment or wait-list conditions.',\n", " 'Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language.',\n", " 'A response-cost program was applied to reduce problem behaviors and foster skills acquisition.',\n", " 'Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).',\n", " 'Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.',\n", " 'High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.',\n", " 'Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.']" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from spacy.lang.en import English\n", "nlp = English() \n", "sentencizer = nlp.add_pipe(\"sentencizer\") \n", "\n", "doc = nlp(example_abstracts[0][\"abstract\"]) \n", "abstract_lines = [str(sent) for sent in list(doc.sents)] \n", "abstract_lines" ] }, { "cell_type": "code", "execution_count": 32, "id": "47a6506a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'text': 'This RCT examined the efficacy of a manualized social intervention for children with HFASDs.',\n", " 'line_number': 0,\n", " 'total_lines': 7},\n", " {'text': 'Participants were randomly assigned to treatment or wait-list conditions.',\n", " 'line_number': 1,\n", " 'total_lines': 7},\n", " {'text': 'Treatment included instruction and therapeutic activities targeting social skills, face-emotion 
recognition, interest expansion, and interpretation of non-literal language.',\n", " 'line_number': 2,\n", " 'total_lines': 7},\n", " {'text': 'A response-cost program was applied to reduce problem behaviors and foster skills acquisition.',\n", " 'line_number': 3,\n", " 'total_lines': 7},\n", " {'text': 'Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).',\n", " 'line_number': 4,\n", " 'total_lines': 7},\n", " {'text': 'Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.',\n", " 'line_number': 5,\n", " 'total_lines': 7},\n", " {'text': 'High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.',\n", " 'line_number': 6,\n", " 'total_lines': 7},\n", " {'text': 'Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.',\n", " 'line_number': 7,\n", " 'total_lines': 7}]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "total_lines_in_sample = len(abstract_lines)\n", "\n", "sample_lines = []\n", "for i, line in enumerate(abstract_lines):\n", " sample_dict = {}\n", " sample_dict[\"text\"] = str(line)\n", " sample_dict[\"line_number\"] = i\n", " sample_dict[\"total_lines\"] = total_lines_in_sample - 1\n", " sample_lines.append(sample_dict)\n", "sample_lines\n" ] }, { "cell_type": "code", "execution_count": 33, "id": "f4fab058", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_abstract_line_numbers = [line[\"line_number\"] for line in sample_lines]\n", "test_abstract_line_numbers_one_hot = tf.one_hot(test_abstract_line_numbers, depth=15) \n", "test_abstract_line_numbers_one_hot\n" ] }, { "cell_type": "code", "execution_count": 34, "id": "58215d49", "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_abstract_total_lines = [line[\"total_lines\"] for line in sample_lines]\n", "test_abstract_total_lines_one_hot = tf.one_hot(test_abstract_total_lines, depth=20)\n", "test_abstract_total_lines_one_hot" ] }, { "cell_type": "code", "execution_count": 37, "id": "08f1e56a", "metadata": {}, "outputs": [], "source": [ "def split_to_char(text):\n", " return \" \" .join(list(text))" ] }, { "cell_type": "code", "execution_count": 38, "id": "64d42218", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['T h i s R C T e x a m i n e d t h e e f f i c a c y o f a m a n u a l i z e d s o c i a l i n t e r v e n t i o n f o r c h i l d r e n w i t h H F A S D s .',\n", " 'P a r t i c i p a n t s w e r e r a n d o m l y a s s i g n e d t o t r e a t m e n t o r w a i t - l i s t c o n d i t i o n s .',\n", " 'T r e a t m e n t i n c l u d e d i n s t r u c t i o n a n d t h e r a p e u t i c a c t i v i t i e s t a r g e t i n g s o c i a l s k i l l s , f a c e - e m o t i o n r e c o g n i t i o n , i n t e r e s t e x p a n s i o n , a n d i n t e r p r e t a t i o n o f n o n - l i t e r a l l a n g u a g e .',\n", " 'A r e s p o n s e - c o s t p r o g r a m w a s a p p l i e d t o r e d u c e p r o b l e m b e h a v i o r s a n d f o s t e r s k i l l s a c q u i s i t i o n .',\n", " 'S i g n i f i c a n t t r e a t m e n t e f f e c t s w e r e f o u n d f o r f i v e o f s e v e n p r i m a r y o u t c o m e m e a s u r e s ( p a r e n t r a t i n g s a n d d i r e c t c h i l d m e a s u r e s ) .',\n", " 'S e c o n d a r y m e a s u r e s b a s e d o n s t a f f r a t i n g s ( t r e a t m e n t g r o u p o n l y ) c o r r o b o r a t e d g a i n s r e p o r t e d b y p a r e n t s .',\n", " 'H i g h l e v e l s o f p a r e n t , c h i l d a n d s t a f f s a t i s f a c t i o n w e r e r e p o r t e d , a l o n g w i t h h i g h l e v e 
l s o f t r e a t m e n t f i d e l i t y .',\n", " 'S t a n d a r d i z e d e f f e c t s i z e e s t i m a t e s w e r e p r i m a r i l y i n t h e m e d i u m a n d l a r g e r a n g e s a n d f a v o r e d t h e t r e a t m e n t g r o u p .']" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "abstract_chars = [split_to_char(sentence) for sentence in abstract_lines]\n", "abstract_chars" ] }, { "cell_type": "code", "execution_count": 40, "id": "535753c3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1/1 [==============================] - 2s 2s/step\n" ] }, { "data": { "text/plain": [ "array([[0.26708567, 0.09977566, 0.02257598, 0.5738482 , 0.03671453],\n", " [0.0666267 , 0.03083183, 0.71341974, 0.10463963, 0.08448218],\n", " [0.14207463, 0.06278381, 0.53325343, 0.1764009 , 0.08548719],\n", " [0.08821893, 0.13287878, 0.5818414 , 0.07643546, 0.12062543],\n", " [0.05683414, 0.11416066, 0.38816544, 0.04979956, 0.39104018],\n", " [0.03436745, 0.11669327, 0.5302272 , 0.04525217, 0.27345994],\n", " [0.02947308, 0.14869098, 0.08467548, 0.028472 , 0.7086885 ],\n", " [0.01927791, 0.1428847 , 0.27512696, 0.03265918, 0.5300512 ]],\n", " dtype=float32)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Make predictions on sample abstract features\n", "test_abstract_pred_probs = loaded_model.predict(x=(test_abstract_line_numbers_one_hot,\n", " test_abstract_total_lines_one_hot,\n", " tf.constant(abstract_lines),\n", " tf.constant(abstract_chars)))\n", "test_abstract_pred_probs" ] }, { "cell_type": "code", "execution_count": 41, "id": "99564562", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_abstract_preds = tf.argmax(test_abstract_pred_probs, axis=1)\n", "test_abstract_preds" ] }, { "cell_type": "code", "execution_count": 45, "id": 
"ae4c4cc4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['OBJECTIVE',\n", " 'METHODS',\n", " 'METHODS',\n", " 'METHODS',\n", " 'RESULTS',\n", " 'METHODS',\n", " 'RESULTS',\n", " 'RESULTS']" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_abstract_pred_classes = [labelencoder.classes_[i] for i in test_abstract_preds]\n", "test_abstract_pred_classes" ] }, { "cell_type": "code", "execution_count": 47, "id": "2fb163f4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OBJECTIVE: This RCT examined the efficacy of a manualized social intervention for children with HFASDs.\n", "METHODS: Participants were randomly assigned to treatment or wait-list conditions.\n", "METHODS: Treatment included instruction and therapeutic activities targeting social skills, face-emotion recognition, interest expansion, and interpretation of non-literal language.\n", "METHODS: A response-cost program was applied to reduce problem behaviors and foster skills acquisition.\n", "RESULTS: Significant treatment effects were found for five of seven primary outcome measures (parent ratings and direct child measures).\n", "METHODS: Secondary measures based on staff ratings (treatment group only) corroborated gains reported by parents.\n", "RESULTS: High levels of parent, child and staff satisfaction were reported, along with high levels of treatment fidelity.\n", "RESULTS: Standardized effect size estimates were primarily in the medium and large ranges and favored the treatment group.\n" ] } ], "source": [ "# Print each abstract sentence next to its predicted section label.\n", "# strict=True raises if the prediction and sentence lists ever differ in length.\n", "for pred_class, line in zip(test_abstract_pred_classes, abstract_lines, strict=True):\n", "    print(f\"{pred_class}: {line}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10 (tensorflow)", "language": "python", "name": "tensorflow" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }