{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "MTL-Bert LREC for github.ipynb",
"provenance": [],
"collapsed_sections": [
"PXTixssZe1Rm"
],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "q639yW9Sunhy",
"colab_type": "text"
},
"source": [
"# Multi-Task Learning using AraBERT for Offensive Language Detection\n",
"\n",
"Notebook used for the OSACT4 shared task on offensive language detection (LREC 2020):\n",
"http://edinburghnlp.inf.ed.ac.uk/workshops/OSACT4/ \n",
"\n",
"Task Paper: https://www.aclweb.org/anthology/2020.osact-1.16/ \n",
"\n",
"Authors: Marc Djandji, Fady Baly, Wissam Antoun"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pYrVtpSu04iI",
"colab_type": "code",
"outputId": "1ad66211-61ad-4052-e2a8-e638008dc205",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 175
}
},
"source": [
"import os\n",
"import collections\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"from datetime import datetime\n",
"# Clone the BERT fork (branch `macrof1`) whose tokenization / modeling / optimization\n",
"# modules are imported below as `bert.*`.\n",
"!git clone -b macrof1 https://github.com/WissamAntoun/bert\n",
"\n",
"import bert.tokenization as tokenization\n",
"import bert.modeling as modeling\n",
"import bert.optimization as optimization\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"\n",
"# Authenticate the Colab user; presumably needed by the `gsutil cp gs://...` cells below.\n",
"from google.colab import auth, drive\n",
"auth.authenticate_user()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"
\n",
"The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.
\n",
"We recommend you upgrade now \n",
"or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x
magic:\n",
"more info.
\n"
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Cloning into 'bert'...\n",
"remote: Enumerating objects: 415, done.\u001b[K\n",
"Receiving objects: 0% (1/415) \rReceiving objects: 1% (5/415) \rReceiving objects: 2% (9/415) \rReceiving objects: 3% (13/415) \rReceiving objects: 4% (17/415) \rReceiving objects: 5% (21/415) \rReceiving objects: 6% (25/415) \rReceiving objects: 7% (30/415) \rReceiving objects: 8% (34/415) \rReceiving objects: 9% (38/415) \rReceiving objects: 10% (42/415) \rReceiving objects: 11% (46/415) \rReceiving objects: 12% (50/415) \rReceiving objects: 13% (54/415) \rReceiving objects: 14% (59/415) \rReceiving objects: 15% (63/415) \rReceiving objects: 16% (67/415) \rReceiving objects: 17% (71/415) \rReceiving objects: 18% (75/415) \rReceiving objects: 19% (79/415) \rReceiving objects: 20% (83/415) \rReceiving objects: 21% (88/415) \rReceiving objects: 22% (92/415) \rReceiving objects: 23% (96/415) \rReceiving objects: 24% (100/415) \rReceiving objects: 25% (104/415) \rReceiving objects: 26% (108/415) \rReceiving objects: 27% (113/415) \rReceiving objects: 28% (117/415) \rReceiving objects: 29% (121/415) \rReceiving objects: 30% (125/415) \rReceiving objects: 31% (129/415) \rReceiving objects: 32% (133/415) \rReceiving objects: 33% (137/415) \rReceiving objects: 34% (142/415) \rReceiving objects: 35% (146/415) \rReceiving objects: 36% (150/415) \rReceiving objects: 37% (154/415) \rReceiving objects: 38% (158/415) \rReceiving objects: 39% (162/415) \rReceiving objects: 40% (166/415) \rReceiving objects: 41% (171/415) \rReceiving objects: 42% (175/415) \rReceiving objects: 43% (179/415) \rReceiving objects: 44% (183/415) \rReceiving objects: 45% (187/415) \rReceiving objects: 46% (191/415) \rReceiving objects: 47% (196/415) \rReceiving objects: 48% (200/415) \rReceiving objects: 49% (204/415) \rReceiving objects: 50% (208/415) \rReceiving objects: 51% (212/415) \rReceiving objects: 52% (216/415) \rReceiving objects: 53% (220/415) \rReceiving objects: 54% (225/415) \rReceiving objects: 55% (229/415) \rReceiving objects: 56% (233/415) \rReceiving objects: 57% (237/415) 
\rReceiving objects: 58% (241/415) \rReceiving objects: 59% (245/415) \rReceiving objects: 60% (249/415) \rReceiving objects: 61% (254/415) \rReceiving objects: 62% (258/415) \rReceiving objects: 63% (262/415) \rReceiving objects: 64% (266/415) \rReceiving objects: 65% (270/415) \rReceiving objects: 66% (274/415) \rReceiving objects: 67% (279/415) \rReceiving objects: 68% (283/415) \rReceiving objects: 69% (287/415) \rReceiving objects: 70% (291/415) \rReceiving objects: 71% (295/415) \rReceiving objects: 72% (299/415) \rReceiving objects: 73% (303/415) \rReceiving objects: 74% (308/415) \rReceiving objects: 75% (312/415) \rReceiving objects: 76% (316/415) \rReceiving objects: 77% (320/415) \rReceiving objects: 78% (324/415) \rReceiving objects: 79% (328/415) \rReceiving objects: 80% (332/415) \rReceiving objects: 81% (337/415) \rReceiving objects: 82% (341/415) \rReceiving objects: 83% (345/415) \rReceiving objects: 84% (349/415) \rReceiving objects: 85% (353/415) \rReceiving objects: 86% (357/415) \rReceiving objects: 87% (362/415) \rremote: Total 415 (delta 0), reused 0 (delta 0), pack-reused 415\u001b[K\n",
"Receiving objects: 88% (366/415) \rReceiving objects: 89% (370/415) \rReceiving objects: 90% (374/415) \rReceiving objects: 91% (378/415) \rReceiving objects: 92% (382/415) \rReceiving objects: 93% (386/415) \rReceiving objects: 94% (391/415) \rReceiving objects: 95% (395/415) \rReceiving objects: 96% (399/415) \rReceiving objects: 97% (403/415) \rReceiving objects: 98% (407/415) \rReceiving objects: 99% (411/415) \rReceiving objects: 100% (415/415) \rReceiving objects: 100% (415/415), 420.19 KiB | 20.01 MiB/s, done.\n",
"Resolving deltas: 0% (0/234) \rResolving deltas: 1% (4/234) \rResolving deltas: 2% (5/234) \rResolving deltas: 3% (8/234) \rResolving deltas: 4% (10/234) \rResolving deltas: 5% (14/234) \rResolving deltas: 7% (17/234) \rResolving deltas: 8% (21/234) \rResolving deltas: 9% (22/234) \rResolving deltas: 10% (24/234) \rResolving deltas: 28% (67/234) \rResolving deltas: 30% (71/234) \rResolving deltas: 33% (78/234) \rResolving deltas: 34% (81/234) \rResolving deltas: 36% (86/234) \rResolving deltas: 38% (89/234) \rResolving deltas: 41% (96/234) \rResolving deltas: 52% (123/234) \rResolving deltas: 58% (138/234) \rResolving deltas: 60% (142/234) \rResolving deltas: 93% (219/234) \rResolving deltas: 94% (220/234) \rResolving deltas: 100% (234/234) \rResolving deltas: 100% (234/234), done.\n",
"WARNING:tensorflow:From /content/bert/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6POYz6YRJ-bv",
"colab_type": "code",
"outputId": "af952439-8ec7-441f-f204-7267ccf45238",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"# Print the name of GPU 0 to confirm which accelerator this runtime was given.\n",
"import pynvml\n",
"pynvml.nvmlInit()\n",
"handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
"device_name = pynvml.nvmlDeviceGetName(handle)\n",
"print(device_name)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"b'Tesla P4'\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7EOWk73XB2aD",
"colab_type": "code",
"outputId": "0d05ba75-3d86-4e23-b807-a4ea09c481da",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 118
}
},
"source": [
"# Mount Google Drive: later cells copy the fine-tuned checkpoint, arabert.zip,\n",
"# the preprocessing script and the test tweets from `My Drive`.\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n",
"\n",
"Enter your authorization code:\n",
"··········\n",
"Mounted at /content/drive\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "W10ugCcAEpYY",
"colab_type": "code",
"colab": {}
},
"source": [
"# Fetch the zipped best MTL checkpoint from Drive.\n",
"!cp /content/drive/'My Drive'/best_MTL_model/best_model.zip ./"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ej9HhYNVEvr9",
"colab_type": "code",
"outputId": "90d4b19c-5c90-45b1-eefe-bf94695d5cfb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 147
}
},
"source": [
"# The archive stores paths under content/, so files land in /content/content/best_model/.\n",
"!unzip /content/best_model.zip "
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Archive: /content/best_model.zip\n",
" creating: content/best_model/\n",
" inflating: content/best_model/events.out.tfevents.1582011537.9f8209d16eb2 \n",
" inflating: content/best_model/checkpoint \n",
" inflating: content/best_model/model.ckpt-1200.meta \n",
" inflating: content/best_model/graph.pbtxt \n",
" inflating: content/best_model/model.ckpt-1200.index \n",
" inflating: content/best_model/model.ckpt-1200.data-00000-of-00001 \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LkDDkSYcEjm5",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copy the step-1200 checkpoint shards into /content/Output_Dir/1250000/.\n",
"# No earlier cell creates that directory (the Output_Dir.zip unzip is commented out),\n",
"# so create it first; otherwise `cp` into a missing directory fails on a fresh runtime.\n",
"!mkdir -p /content/Output_Dir/1250000/\n",
"!cp -r /content/content/best_model/model.ckpt-1200.index /content/Output_Dir/1250000/\n",
"!cp -r /content/content/best_model/model.ckpt-1200.meta /content/Output_Dir/1250000/\n",
"!cp -r /content/content/best_model/model.ckpt-1200.data-00000-of-00001 /content/Output_Dir/1250000/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xowJMVAvFh8C",
"colab_type": "code",
"colab": {}
},
"source": [
"# Rename the copied step-1200 files to step 0 (model.ckpt-0.*).\n",
"# NOTE(review): presumably so the code that restores from Output_Dir/1250000 finds these\n",
"# fine-tuned weights under the checkpoint name it expects -- confirm against the loader.\n",
"!mv /content/Output_Dir/1250000/model.ckpt-1200.data-00000-of-00001 /content/Output_Dir/1250000/model.ckpt-0.data-00000-of-00001\n",
"!mv /content/Output_Dir/1250000/model.ckpt-1200.index /content/Output_Dir/1250000/model.ckpt-0.index\n",
"!mv /content/Output_Dir/1250000/model.ckpt-1200.meta /content/Output_Dir/1250000/model.ckpt-0.meta"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mjSgO2t3ZiVj",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copy arabert.zip from Drive; it is extracted in the next cell.\n",
"!cp -r /content/drive/'My Drive'/Bert/arabert.zip /content/\n",
"# # !cp -r /content/drive/'My Drive'/Trained_bert/Output_Dir.zip /content"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4sLMPRfDHygJ",
"colab_type": "code",
"colab": {}
},
"source": [
"# # !unzip Output_Dir.zip\n",
"# Extract arabert.zip (copied from Drive in the previous cell).\n",
"!unzip arabert.zip"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "aHP_MFYJ08XK",
"colab_type": "code",
"colab": {}
},
"source": [
"# ##use tf_arabert model, change path accordingly\n",
"# Paths to the TF AraBERT vocab file, initial checkpoint and BERT config file.\n",
"# Left blank in this copy: fill them in before running the tokenizer cell below,\n",
"# which reads BERT_VOCAB.\n",
"BERT_VOCAB= ''\n",
"BERT_INIT_CHKPNT = ''\n",
"BERT_CONFIG = ''"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "q0kHR7py08Uz",
"colab_type": "code",
"outputId": "310c4244-6c76-4a9c-be1e-31ad79990f1e",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 50
}
},
"source": [
"# Check the lower-casing flag against the checkpoint, then build the tokenizer from\n",
"# BERT_VOCAB using the same do_lower_case=True setting.\n",
"tokenization.validate_case_matches_checkpoint(True, BERT_INIT_CHKPNT)\n",
"tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB,\n",
" do_lower_case=True)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /content/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U6x0tLpWew0V",
"colab_type": "text"
},
"source": [
"### Data Preprocessing\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4ntkpNwOKj4P",
"colab_type": "code",
"outputId": "5c9851db-56db-495b-c67a-6d7bcbd77c54",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 521
}
},
"source": [
"# Versions pinned to those recorded in this cell's output. emoji>=2.0 removed\n",
"# get_emoji_regexp(), which remove_emoji() in the next cell uses.\n",
"!pip install pyarabic==0.6.6\n",
"!gsutil cp gs://bert_pretrain4/FarasaSegmenterJar.jar ./\n",
"!pip install emoji==0.5.4\n",
"\n",
"def install_java():\n",
"    \"\"\"Install OpenJDK 8 and point JAVA_HOME at it (presumably for the Farasa jar above).\"\"\"\n",
"    !apt-get install -y openjdk-8-jdk-headless -qq > /dev/null #install openjdk\n",
"    os.environ[\"JAVA_HOME\"] = \"/usr/lib/jvm/java-8-openjdk-amd64\"\n",
"    # NOTE(review): the recorded run printed OpenJDK 11 here, i.e. the `java` on PATH is\n",
"    # not the JDK that JAVA_HOME points to -- confirm which one Farasa ends up using.\n",
"    !java -version\n",
"install_java()\n",
"!pip install py4j==0.10.9\n",
"!pkill \"java\""
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting pyarabic\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b8/77/da852ee13bce3affc55b746cebc0fdc0fc48628dbc5898ce489112cd6bd1/PyArabic-0.6.6.tar.gz (101kB)\n",
"\r\u001b[K |███▎ | 10kB 31.5MB/s eta 0:00:01\r\u001b[K |██████▌ | 20kB 34.6MB/s eta 0:00:01\r\u001b[K |█████████▊ | 30kB 40.4MB/s eta 0:00:01\r\u001b[K |█████████████ | 40kB 39.7MB/s eta 0:00:01\r\u001b[K |████████████████▏ | 51kB 34.6MB/s eta 0:00:01\r\u001b[K |███████████████████▍ | 61kB 38.6MB/s eta 0:00:01\r\u001b[K |██████████████████████▋ | 71kB 28.0MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 81kB 25.5MB/s eta 0:00:01\r\u001b[K |█████████████████████████████ | 92kB 27.4MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 102kB 12.1MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: pyarabic\n",
" Building wheel for pyarabic (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pyarabic: filename=PyArabic-0.6.6-cp36-none-any.whl size=106208 sha256=7a3c3dcca57c1ec79e80a2db3f72f7234463bb8df9375c8d0a7ecce5730d95ee\n",
" Stored in directory: /root/.cache/pip/wheels/34/b5/2d/668d567e8c2b6f10309dbfaba5bfef6ea0b1c0f9f6fb37078f\n",
"Successfully built pyarabic\n",
"Installing collected packages: pyarabic\n",
"Successfully installed pyarabic-0.6.6\n",
"Copying gs://bert_pretrain4/FarasaSegmenterJar.jar...\n",
"/ [1 files][ 13.0 MiB/ 13.0 MiB] \n",
"Operation completed over 1 objects/13.0 MiB. \n",
"Collecting emoji\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/40/8d/521be7f0091fe0f2ae690cc044faf43e3445e0ff33c574eae752dd7e39fa/emoji-0.5.4.tar.gz (43kB)\n",
"\u001b[K |████████████████████████████████| 51kB 9.4MB/s \n",
"\u001b[?25hBuilding wheels for collected packages: emoji\n",
" Building wheel for emoji (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for emoji: filename=emoji-0.5.4-cp36-none-any.whl size=42176 sha256=e40cb86eb868f95d6934ce361da628a09b871a1cfefc23ddd22fc6d6fc234f8b\n",
" Stored in directory: /root/.cache/pip/wheels/2a/a9/0a/4f8e8cce8074232aba240caca3fade315bb49fac68808d1a9c\n",
"Successfully built emoji\n",
"Installing collected packages: emoji\n",
"Successfully installed emoji-0.5.4\n",
"openjdk version \"11.0.6\" 2020-01-14\n",
"OpenJDK Runtime Environment (build 11.0.6+10-post-Ubuntu-1ubuntu118.04.1)\n",
"OpenJDK 64-Bit Server VM (build 11.0.6+10-post-Ubuntu-1ubuntu118.04.1, mixed mode, sharing)\n",
"Collecting py4j\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)\n",
"\u001b[K |████████████████████████████████| 204kB 42.1MB/s \n",
"\u001b[?25hInstalling collected packages: py4j\n",
"Successfully installed py4j-0.10.9\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "iDBveNiY08Sn",
"colab_type": "code",
"outputId": "9697dcc4-3d9d-4a17-eb4f-5350b92b7d01",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 82
}
},
"source": [
"# get preprocessing script\n",
"# !gsutil cp gs://bert_pretrain4/preprocess_wiss.py ./\n",
"!cp -r /content/drive/'My Drive'/Bert/preprocess_wiss.py ./\n",
"from preprocess_wiss import preprocess\n",
"# download data\n",
"!gsutil cp -r gs://osact4/data ./\n",
"\n",
"import emoji\n",
"def remove_emoji(text):\n",
"    \"\"\"\n",
"    :param text: enter a string of words.\n",
"    :return: the input string with all emoji characters removed.\n",
"\n",
"    emoji.get_emoji_regexp() was deprecated in emoji 1.7 and removed in 2.0, where\n",
"    emoji.replace_emoji() takes its place; use whichever the installed version has.\n",
"    \"\"\"\n",
"    if hasattr(emoji, \"replace_emoji\"):\n",
"        return emoji.replace_emoji(text, replace=u'')\n",
"    return emoji.get_emoji_regexp().sub(u'', text)\n",
"\n",
"# Character map: Eastern Arabic digits -> Western digits, '٪' -> '%', 'ڤ' -> 'ف',\n",
"# and '_' / '|' -> space.\n",
"eastern_to_western = {\"٠\":\"0\",\"١\":\"1\",\"٢\":\"2\",\"٣\":\"3\",\"٤\":\"4\",\"٥\":\"5\",\"٦\":\"6\",\"٧\":\"7\",\"٨\":\"8\",\"٩\":\"9\",\"٪\":\"%\",\"_\":\" \",\"ڤ\":\"ف\",\"|\":\" \"}\n",
"trans_string = str.maketrans(eastern_to_western)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Copying gs://osact4/data/OSACT2020-sharedTask-dev.txt...\n",
"Copying gs://osact4/data/OSACT2020-sharedTask-train.txt...\n",
"/ [2 files][ 1.5 MiB/ 1.5 MiB] \n",
"Operation completed over 2 objects/1.5 MiB. \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "zlHvDLm5bDhz",
"colab_type": "code",
"colab": {}
},
"source": [
"# The gs://osact4/data copy only contains train and dev; fetch the test tweets from Drive.\n",
"!cp /content/drive/'My Drive'/Bert/OSACT2020-sharedTask-test-tweets.txt /content/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "qtE4KF0m08QZ",
"colab_type": "code",
"colab": {}
},
"source": [
"# Load the test tweets (one per line) and the tab-separated train/dev files.\n",
"with open('/content/OSACT2020-sharedTask-test-tweets.txt', 'r', encoding='utf-8') as f:\n",
"    lines = f.readlines()\n",
"test = pd.DataFrame(np.array(lines).reshape(-1,1),columns=['text'])\n",
"\n",
"train = pd.read_csv(\"./data/OSACT2020-sharedTask-train.txt\",sep=\"\\t\",header=None)\n",
"train.columns = [\"text\",\"offensive\",\"hate_speech\"]\n",
"\n",
"eval_ = pd.read_csv(\"./data/OSACT2020-sharedTask-dev.txt\",sep=\"\\t\",header=None)\n",
"eval_.columns = [\"text\",\"offensive\",\"hate_speech\"]\n",
"\n",
"def clean_tweet(x):\n",
"    \"\"\"Apply the shared text normalization to one raw tweet and return the result.\"\"\"\n",
"    x = remove_emoji(x)\n",
"    # The OSACT data encodes line breaks as the literal token <LF>. Replacing the empty\n",
"    # string instead would insert a space between every character of the tweet.\n",
"    x = x.replace(\"<LF>\", \" \")\n",
"    x = x.translate(trans_string)\n",
"    x = preprocess(x, True)\n",
"    return x.replace(\"\\\\\", \" \")\n",
"\n",
"def encode_labels(frame):\n",
"    \"\"\"Map the string labels of `frame` to '0'/'1' strings, in place.\n",
"\n",
"    NOT_OFF / NOT_HS must be replaced before the bare OFF / HS tags, which are substrings of them.\n",
"    \"\"\"\n",
"    frame[\"offensive\"] = frame[\"offensive\"].apply(lambda x : x.replace(\"NOT_OFF\", '0').replace(\"OFF\", '1'))\n",
"    frame[\"hate_speech\"] = frame[\"hate_speech\"].apply(lambda x : x.replace(\"NOT_HS\", '0').replace(\"HS\", '1'))\n",
"\n",
"for frame in (train, eval_):\n",
"    frame[\"text\"] = frame[\"text\"].apply(clean_tweet)\n",
"    encode_labels(frame)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WZRAYxwZbHgs",
"colab_type": "code",
"colab": {}
},
"source": [
"# Same normalization as train/dev. The OSACT data marks line breaks with the literal token\n",
"# <LF>; replacing the empty string instead would insert a space between every character.\n",
"test[\"text\"] = test[\"text\"].apply(lambda x : remove_emoji(x))\n",
"test[\"text\"] = test[\"text\"].apply(lambda x : x.replace(\"<LF>\",\" \"))\n",
"test[\"text\"] = test[\"text\"].apply(lambda x : x.translate(trans_string))\n",
"test[\"text\"] = test[\"text\"].apply(lambda x : preprocess(x, True))\n",
"test[\"text\"] = test[\"text\"].apply(lambda x : x.replace(\"\\\\\",\" \"))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GE89hIoy08OH",
"colab_type": "code",
"colab": {}
},
"source": [
"# Text column, and the two binary targets predicted jointly by the multi-task model.\n",
"DATA_COLUMN = 'text'\n",
"LABEL_COLUMNS = ['offensive', 'hate_speech']"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IqORf2akbJC5",
"colab_type": "code",
"colab": {}
},
"source": [
"# NOTE: these are aliases, not copies -- column assignments on x_train (e.g. the\n",
"# astype(np.uint8) cell below) also modify `train`.\n",
"x_train = train\n",
"x_val = eval_\n",
"x_test = test"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "6vkbh1xbG2CU",
"colab_type": "code",
"colab": {}
},
"source": [
"# Train + dev combined. The original row labels are kept, so index values repeat between\n",
"# the two parts (pass ignore_index=True if unique labels are needed).\n",
"x_train_w_eval = pd.concat([x_train, x_val], axis=0)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AXY_8w1yHHnm",
"colab_type": "code",
"outputId": "41008564-9944-4692-aafe-006fde367988",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"# Expected: (train rows + dev rows, 3 columns).\n",
"x_train_w_eval.shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(7839, 3)"
]
},
"metadata": {
"tags": []
},
"execution_count": 52
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PXTixssZe1Rm",
"colab_type": "text"
},
"source": [
"#### Downsampling"
]
},
{
"cell_type": "code",
"metadata": {
"id": "V0dC0rmkQ4em",
"colab_type": "code",
"colab": {}
},
"source": [
"# Labels are the strings '0'/'1' after encoding; cast to integers for the == 1 / == 0 filters below.\n",
"x_train['hate_speech'] = x_train['hate_speech'].astype(np.uint8)\n",
"x_train['offensive'] = x_train['offensive'].astype(np.uint8)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_dI4tsFeL_Oh",
"colab_type": "code",
"colab": {}
},
"source": [
"# Row labels of positive and negative training examples for each task.\n",
"hate_idx = x_train[x_train.hate_speech == 1].index\n",
"off_idx = x_train[x_train.offensive == 1].index\n",
"non_hate_indices = x_train[x_train.hate_speech == 0].index\n",
"non_off_indices = x_train[x_train.offensive == 0].index"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "m00cKvyZbLaI",
"colab_type": "code",
"outputId": "0709f0b7-97c0-40b4-ba61-7386689b8790",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 82
}
},
"source": [
"# Inspect the row labels of the hate-speech examples.\n",
"hate_idx"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Int64Index([ 2, 40, 66, 74, 85, 94, 182, 183, 200, 212,\n",
" ...\n",
" 6670, 6697, 6713, 6732, 6735, 6745, 6753, 6773, 6791, 6806],\n",
" dtype='int64', length=350)"
]
},
"metadata": {
"tags": []
},
"execution_count": 72
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "r8jtCoIzMKGc",
"colab_type": "code",
"colab": {}
},
"source": [
"# Row labels of training examples that are neither hate speech nor offensive.\n",
"# np.intersect1d returns the sorted intersection -- the same rows the former nested loop\n",
"# collected, in O(n log n) instead of O(n^2), without padding either array (the padding\n",
"# raised when there were more non-offensive than non-hate rows) and as ints, not floats.\n",
"matches = list(np.intersect1d(non_hate_indices, non_off_indices))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "SU9MPHKAMSbr",
"colab_type": "code",
"outputId": "5cc4b9e1-97cb-4e62-8250-5f7c0e34d4eb",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 98
}
},
"source": [
"# Training-set class counts; \"Neither\" = rows that are neither offensive nor hate speech.\n",
"n_hate = len(hate_idx)\n",
"n_offense = len(off_idx)\n",
"n_neither = len(matches)\n",
"\n",
"from prettytable import PrettyTable\n",
" \n",
"tab = PrettyTable()\n",
"\n",
"tab.field_names = [\"\",\"Train Offensive\", 'Train Hate', 'Train Neither']\n",
"tab.add_row([\"count\", n_offense, n_hate, n_neither])\n",
"print(tab)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"+-------+-----------------+------------+---------------+\n",
"| | Train Offensive | Train Hate | Train Neither |\n",
"+-------+-----------------+------------+---------------+\n",
"| count | 1371 | 350 | 5468 |\n",
"+-------+-----------------+------------+---------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6B-_l8dSrLFX",
"colab_type": "code",
"outputId": "01317df1-cb5c-4751-a839-a47cc79eb327",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"# Number of 'neither' rows the undersampling cell below keeps (n_neither / 1.5).\n",
"int(n_neither/1.5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"3645"
]
},
"metadata": {
"tags": []
},
"execution_count": 78
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ExZw9P4mMOfx",
"colab_type": "code",
"colab": {}
},
"source": [
"# Undersample the majority 'neither' class: keep every offensive and hate row plus a random\n",
"# two-thirds of the 'neither' rows. A fixed seed makes both the selection and the shuffle\n",
"# reproducible under Restart & Run All.\n",
"SAMPLING_SEED = 42\n",
"rng = np.random.RandomState(SAMPLING_SEED)\n",
"\n",
"n_hate = len(hate_idx)\n",
"n_offense = len(off_idx)\n",
"random_indices = rng.choice(matches, int(n_neither/1.5), replace=False)\n",
"\n",
"# NOTE(review): hate_idx and off_idx overlap (rows can be labelled both), so those rows\n",
"# appear twice in the sample -- confirm this duplication is intended.\n",
"under_sample_indices = np.concatenate([hate_idx, off_idx, random_indices])\n",
"under_sampled = x_train.loc[under_sample_indices]\n",
"\n",
"hate_data = x_train.loc[hate_idx] \n",
"\n",
"under_sampled = under_sampled.sample(frac=1, random_state=rng).reset_index(drop=True)\n",
"\n",
"# balanced = pd.concat([under_sampled, hate_data], axis=0)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "JnSUretjTheR",
"colab_type": "code",
"colab": {}
},
"source": [
"# Shuffle again and renumber rows 0..n-1; from here on x_train is the undersampled frame.\n",
"# NOTE(review): this shuffle is unseeded, so row order differs between runs.\n",
"x_train = under_sampled.sample(frac=1).reset_index(drop=True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "u3REqJ01R050",
"colab_type": "code",
"outputId": "4b8bf895-46fa-487b-a51d-a6731d06dc69",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 98
}
},
"source": [
"# Label distribution of the undersampled *training* set (the headers previously said \"Test\").\n",
"tot = len(under_sampled)\n",
"is_off = under_sampled['offensive'] == 1\n",
"is_hate = under_sampled['hate_speech'] == 1\n",
"percentage_off = is_off.sum()/tot\n",
"percentage_hate = is_hate.sum()/tot\n",
"# Count 'neither' directly: 1-(off+hate) double-subtracts rows labelled both.\n",
"percentage_non = (~is_off & ~is_hate).sum()/tot\n",
"tab = PrettyTable()\n",
"\n",
"tab.field_names = [\"\",\"Train Offensive\", 'Train Hate', 'Train Neither']\n",
"# round(p*100, 2) avoids float noise such as 54.879999999999995.\n",
"tab.add_row([\"Percentage\", round(percentage_off*100, 2), round(percentage_hate*100, 2), round(percentage_non*100, 2)])\n",
"print(tab)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"+------------+----------------+-----------+--------------------+\n",
"| | Test Offensive | Test Hate | Test Neither |\n",
"+------------+----------------+-----------+--------------------+\n",
"| Percentage | 32.07 | 13.05 | 54.879999999999995 |\n",
"+------------+----------------+-----------+--------------------+\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "PJr7IGxuMrmO",
"colab_type": "code",
"outputId": "92513b09-2f1f-41a0-d7eb-192ca35dea11",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
}
},
"source": [
"# Preview the undersampled, shuffled training frame.\n",
"x_train.head()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text | \n",
" offensive | \n",
" hate_speech | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" [مستخدم] يعني هياخدوا دوري ال+ أبطال السن +ه د... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" [مستخدم] مبروك يا دوك يا مبدع و+ عقبال ال+ ملي... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" [مستخدم] يا عبري +ة يا أيتها ال+ عاهر +ة فك +ي... | \n",
" 1 | \n",
" 1 | \n",
"
\n",
" \n",
" 3 | \n",
" يا بسم +ة ال+ صبح يا ف+ ألي و+ يا أملي يا نغم ... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" RT [مستخدم] : صباح ال+ زفت علي +كم [مستخدم] يا... | \n",
" 1 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text offensive hate_speech\n",
"0 [مستخدم] يعني هياخدوا دوري ال+ أبطال السن +ه د... 0 0\n",
"1 [مستخدم] مبروك يا دوك يا مبدع و+ عقبال ال+ ملي... 0 0\n",
"2 [مستخدم] يا عبري +ة يا أيتها ال+ عاهر +ة فك +ي... 1 1\n",
"3 يا بسم +ة ال+ صبح يا ف+ ألي و+ يا أملي يا نغم ... 0 0\n",
"4 RT [مستخدم] : صباح ال+ زفت علي +كم [مستخدم] يا... 1 0"
]
},
"metadata": {
"tags": []
},
"execution_count": 107
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "obgq5DAybFsc",
"colab_type": "code",
"outputId": "42600a63-68ca-4bdb-886d-9762508d6f3b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
}
},
"source": [
"# `balanced` is only built by a commented-out line in the downsampling cell, so referencing it\n",
"# raises NameError on a fresh kernel; preview the undersampled set that is actually used.\n",
"under_sampled.head()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text | \n",
" offensive | \n",
" hate_speech | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" RT [مستخدم] : يافاتنة كل ال+ مدن يا أرض ال+ نق... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" أي يا حظ ، أي يا سفر ، أي يا فلوس ، أي ياجواز ... | \n",
" 1 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" أنا حاول +ت كتير أتصور ميرور سيلفي و+ ب+ يطلع ... | \n",
" 1 | \n",
" 0 | \n",
"
\n",
" \n",
" 3 | \n",
" RT [مستخدم] : أقولك يا إيمان و+ لا يا نان +ا و... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" #عبدالرحمن ال+ خشت +ي اللهم يا سيدي و+ يا خالق... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text offensive hate_speech\n",
"0 RT [مستخدم] : يافاتنة كل ال+ مدن يا أرض ال+ نق... 0 0\n",
"1 أي يا حظ ، أي يا سفر ، أي يا فلوس ، أي ياجواز ... 1 0\n",
"2 أنا حاول +ت كتير أتصور ميرور سيلفي و+ ب+ يطلع ... 1 0\n",
"3 RT [مستخدم] : أقولك يا إيمان و+ لا يا نان +ا و... 0 0\n",
"4 #عبدالرحمن ال+ خشت +ي اللهم يا سيدي و+ يا خالق... 0 0"
]
},
"metadata": {
"tags": []
},
"execution_count": 40
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UNxWMU4BTW6r",
"colab_type": "text"
},
"source": [
"#### Class weighting\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wsUDrwJ8TdB1",
"colab_type": "code",
"colab": {}
},
"source": [
"from sklearn.utils import class_weight\n",
"import numpy as np\n",
"\n",
"# 'balanced' weights = n_samples / (n_classes * count of each class), returned in the order\n",
"# of np.unique(labels). classes and y are passed by keyword because scikit-learn >= 1.0\n",
"# no longer accepts them positionally.\n",
"labels_off= x_train['offensive'].values\n",
"labels_hate= x_train['hate_speech'].values\n",
"class_weights_off = class_weight.compute_class_weight('balanced',\n",
"                                                      classes=np.unique(labels_off),\n",
"                                                      y=labels_off)\n",
"class_weights_hate = class_weight.compute_class_weight('balanced',\n",
"                                                       classes=np.unique(labels_hate),\n",
"                                                       y=labels_hate)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Efa-LiTGUCHq",
"colab_type": "code",
"outputId": "16b0e57d-9a72-4894-bb65-7f1c37310026",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"# Offensive-task weights for classes [0, 1] (order of np.unique).\n",
"class_weights_off"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.62536576, 2.49416484])"
]
},
"metadata": {
"tags": []
},
"execution_count": 54
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aFSs-07HUFRe",
"colab_type": "code",
"outputId": "4138ce23-8f6e-4f7f-c6c0-495f1b3a1608",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"# Hate-speech-task weights for classes [0, 1] (order of np.unique).\n",
"class_weights_hate"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([0.52696872, 9.77 ])"
]
},
"metadata": {
"tags": []
},
"execution_count": 55
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8q6Zx3J0e69P",
"colab_type": "text"
},
"source": [
"### Model components"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YHvoreqj08MW",
"colab_type": "code",
"colab": {}
},
"source": [
"class InputExample(object):\n",
"    \"\"\"One example for sequence classification (a single text or a text pair).\n",
"\n",
"    Attributes mirror the constructor arguments:\n",
"      guid:   identifier for the example.\n",
"      text_a: untokenized text of the first sequence; the only text needed for\n",
"              single-sequence tasks.\n",
"      text_b: optional untokenized second sequence, used only for sequence-pair tasks.\n",
"      labels: optional labels, one per task; given for train/dev examples and\n",
"              left as None for unlabeled test examples.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, guid, text_a, text_b=None, labels=None):\n",
"        self.guid, self.text_a = guid, text_a\n",
"        self.text_b = text_b\n",
"        self.labels = labels\n",
"\n",
"\n",
"class InputFeatures(object):\n",
"    \"\"\"Model-ready features (ids, mask, segment ids, labels) for one example.\"\"\"\n",
"\n",
"    def __init__(self, input_ids, input_mask, segment_ids, label_ids, is_real_example=True):\n",
"        self.input_ids = input_ids\n",
"        self.input_mask = input_mask\n",
"        self.segment_ids = segment_ids\n",
"        # Deliberately a 1-tuple wrapping label_ids (equivalent to `label_ids,`).\n",
"        # NOTE(review): consumers presumably unwrap it with label_ids[0]; confirm before changing.\n",
"        self.label_ids = (label_ids,)\n",
"        self.is_real_example = is_real_example"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9eYaSmhZ08KA",
"colab_type": "code",
"colab": {}
},
"source": [
"def create_examples(df, labels_available=True, verbose=False, num_labels=2):\n",
"    \"\"\"Build one InputExample per DataFrame row, in row order.\n",
"\n",
"    Args:\n",
"      df: frame whose first column is the text and, when labels_available is True,\n",
"        whose remaining columns are the labels.\n",
"      labels_available: if False (e.g. the unlabeled test set), each example gets\n",
"        num_labels dummy zero labels.\n",
"      verbose: print each text/labels pair. Off by default: an unconditional print\n",
"        emits one line per row and floods the notebook output.\n",
"      num_labels: number of dummy labels used when labels_available is False.\n",
"\n",
"    Returns:\n",
"      list of InputExample, all with guid ''.\n",
"    \"\"\"\n",
"    examples = []\n",
"    for row in df.values:\n",
"        text_a = row[0]\n",
"        labels = row[1:] if labels_available else [0] * num_labels\n",
"        if verbose:\n",
"            print(text_a, labels)\n",
"        examples.append(InputExample(guid='', text_a=text_a, labels=labels))\n",
"    return examples"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NCd6CVex08F8",
"colab_type": "code",
"colab": {}
},
"source": [
"def convert_examples_to_features(examples, max_seq_length, tokenizer):\n",
"    \"\"\"Converts `InputExample`s into a list of `InputFeatures`.\n",
"\n",
"    Each example is tokenized, truncated so it fits `max_seq_length` together\n",
"    with the [CLS]/[SEP] markers, converted to ids and zero-padded. The work\n",
"    is delegated to `convert_single_example`, so this in-memory path and\n",
"    the TFRecord path (`file_based_convert_examples_to_features`) build\n",
"    identical features.\n",
"\n",
"    Args:\n",
"        examples: Iterable of `InputExample`s.\n",
"        max_seq_length: Fixed length every feature is padded/truncated to.\n",
"        tokenizer: Object exposing `tokenize` and `convert_tokens_to_ids`.\n",
"\n",
"    Returns:\n",
"        A list of `InputFeatures`, one per example, in input order.\n",
"    \"\"\"\n",
"    return [\n",
"        convert_single_example(ex_index, example, max_seq_length, tokenizer)\n",
"        for ex_index, example in enumerate(examples)\n",
"    ]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xXmO_n0808A4",
"colab_type": "code",
"colab": {}
},
"source": [
"class PaddingInputExample:\n",
"    \"\"\"Placeholder example that pads a dataset to a multiple of the batch size.\n",
"\n",
"    TPU eval/predict needs a fixed batch size. Padding the example count is\n",
"    preferred over dropping the last partial batch, which would lose outputs.\n",
"    A dedicated class is used instead of `None` so padding can never be\n",
"    silently confused with a missing value.\n",
"    \"\"\"\n",
"\n",
"\n",
"\n",
"def convert_single_example(ex_index, example, max_seq_length,\n",
"                           tokenizer, num_labels=2):\n",
"    \"\"\"Converts a single `InputExample` into a single `InputFeatures`.\n",
"\n",
"    Args:\n",
"        ex_index: Position of the example in its dataset (not used here).\n",
"        example: An `InputExample` or a `PaddingInputExample`.\n",
"        max_seq_length: Length every id/mask/segment list is padded to.\n",
"        tokenizer: Object exposing `tokenize` and `convert_tokens_to_ids`.\n",
"        num_labels: Number of zero label ids given to padding examples; it\n",
"            must match the number of labels carried by real examples.\n",
"\n",
"    Returns:\n",
"        An `InputFeatures` whose `label_ids` is a list of ints.\n",
"    \"\"\"\n",
"    if isinstance(example, PaddingInputExample):\n",
"        # Padding rows need one label per task, like real rows; a scalar\n",
"        # label cannot be written against the fixed-length label feature.\n",
"        return InputFeatures(\n",
"            input_ids=[0] * max_seq_length,\n",
"            input_mask=[0] * max_seq_length,\n",
"            segment_ids=[0] * max_seq_length,\n",
"            label_ids=[0] * num_labels,\n",
"            is_real_example=False)\n",
"\n",
"    tokens_a = tokenizer.tokenize(example.text_a)\n",
"    tokens_b = tokenizer.tokenize(example.text_b) if example.text_b else None\n",
"\n",
"    if tokens_b:\n",
"        # Account for [CLS], [SEP], [SEP] with \"- 3\"; truncates in place.\n",
"        _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)\n",
"    elif len(tokens_a) > max_seq_length - 2:\n",
"        # Account for [CLS] and [SEP] with \"- 2\".\n",
"        tokens_a = tokens_a[:max_seq_length - 2]\n",
"\n",
"    # BERT input convention:\n",
"    #   pair:   [CLS] a1 a2 [SEP] b1 b2 [SEP]   type_ids 0 0 0 0 1 1 1\n",
"    #   single: [CLS] a1 a2 [SEP]               type_ids 0 0 0 0\n",
"    # The [CLS] vector serves as the sentence representation, which only\n",
"    # makes sense because the whole model is fine-tuned.\n",
"    tokens = [\"[CLS]\"] + tokens_a + [\"[SEP]\"]\n",
"    segment_ids = [0] * len(tokens)\n",
"    if tokens_b:\n",
"        tokens += tokens_b + [\"[SEP]\"]\n",
"        segment_ids += [1] * (len(tokens_b) + 1)\n",
"\n",
"    input_ids = list(tokenizer.convert_tokens_to_ids(tokens))\n",
"    # 1 for real tokens, 0 for padding: only real tokens are attended to.\n",
"    input_mask = [1] * len(input_ids)\n",
"\n",
"    # Zero-pad up to the sequence length.\n",
"    padding = [0] * (max_seq_length - len(input_ids))\n",
"    input_ids += padding\n",
"    input_mask += padding\n",
"    segment_ids += padding\n",
"\n",
"    assert len(input_ids) == max_seq_length\n",
"    assert len(input_mask) == max_seq_length\n",
"    assert len(segment_ids) == max_seq_length\n",
"\n",
"    return InputFeatures(\n",
"        input_ids=input_ids,\n",
"        input_mask=input_mask,\n",
"        segment_ids=segment_ids,\n",
"        label_ids=[int(label) for label in example.labels],\n",
"        is_real_example=True)\n",
"\n",
"\n",
"def file_based_convert_examples_to_features(\n",
"        examples, max_seq_length, tokenizer, output_file):\n",
"    \"\"\"Converts `InputExample`s to features and writes them to a TFRecord file.\n",
"\n",
"    Each record holds `input_ids`, `input_mask`, `segment_ids` (each of\n",
"    length `max_seq_length`), `label_ids` and an `is_real_example` flag.\n",
"    The writer is closed even if conversion raises.\n",
"    \"\"\"\n",
"    def create_int_feature(values):\n",
"        return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))\n",
"\n",
"    with tf.python_io.TFRecordWriter(output_file) as writer:\n",
"        for ex_index, example in enumerate(examples):\n",
"            feature = convert_single_example(ex_index, example,\n",
"                                             max_seq_length, tokenizer)\n",
"\n",
"            label_ids = feature.label_ids\n",
"            # Unwrap a 1-tuple around the label list (produced by a stray\n",
"            # trailing comma in InputFeatures) and wrap bare scalars.\n",
"            if isinstance(label_ids, tuple) and len(label_ids) == 1:\n",
"                label_ids = label_ids[0]\n",
"            if not isinstance(label_ids, (list, tuple)):\n",
"                label_ids = [label_ids]\n",
"\n",
"            features = collections.OrderedDict()\n",
"            features[\"input_ids\"] = create_int_feature(feature.input_ids)\n",
"            features[\"input_mask\"] = create_int_feature(feature.input_mask)\n",
"            features[\"segment_ids\"] = create_int_feature(feature.segment_ids)\n",
"            features[\"is_real_example\"] = create_int_feature(\n",
"                [int(feature.is_real_example)])\n",
"            features[\"label_ids\"] = create_int_feature(label_ids)\n",
"\n",
"            tf_example = tf.train.Example(features=tf.train.Features(feature=features))\n",
"            writer.write(tf_example.SerializeToString())\n",
"\n",
"\n",
"def file_based_input_fn_builder(input_file, seq_length, is_training,\n",
"                                drop_remainder, num_labels=2):\n",
"    \"\"\"Creates an `input_fn` closure to be passed to an Estimator.\n",
"\n",
"    Args:\n",
"        input_file: Path of a TFRecord file written by\n",
"            `file_based_convert_examples_to_features`.\n",
"        seq_length: Length of the `input_ids`/`input_mask`/`segment_ids` features.\n",
"        is_training: If True, the dataset is repeated indefinitely and shuffled.\n",
"        drop_remainder: Whether to drop the last incomplete batch.\n",
"        num_labels: Length of the `label_ids` feature (one label per task).\n",
"\n",
"    Returns:\n",
"        `input_fn(params)`, which reads `params[\"batch_size\"]` and returns a\n",
"        batched `tf.data.Dataset` of feature dicts with int32 tensors.\n",
"    \"\"\"\n",
"    name_to_features = {\n",
"        \"input_ids\": tf.FixedLenFeature([seq_length], tf.int64),\n",
"        \"input_mask\": tf.FixedLenFeature([seq_length], tf.int64),\n",
"        \"segment_ids\": tf.FixedLenFeature([seq_length], tf.int64),\n",
"        \"label_ids\": tf.FixedLenFeature([num_labels], tf.int64),\n",
"        \"is_real_example\": tf.FixedLenFeature([], tf.int64),\n",
"    }\n",
"\n",
"    def _decode_record(record, name_to_features):\n",
"        \"\"\"Decodes a serialized record into a dict of tensors.\"\"\"\n",
"        example = tf.parse_single_example(record, name_to_features)\n",
"        # tf.Example only supports int64 but TPUs only support int32, so\n",
"        # every int64 feature is cast down.\n",
"        return {\n",
"            name: tf.cast(t, tf.int32) if t.dtype == tf.int64 else t\n",
"            for name, t in example.items()\n",
"        }\n",
"\n",
"    def input_fn(params):\n",
"        \"\"\"The actual input function.\"\"\"\n",
"        batch_size = params[\"batch_size\"]\n",
"\n",
"        # Training: repeat forever and shuffle. Eval/predict: one ordered pass.\n",
"        d = tf.data.TFRecordDataset(input_file)\n",
"        if is_training:\n",
"            d = d.repeat()\n",
"            d = d.shuffle(buffer_size=100)\n",
"\n",
"        # map + batch replaces the deprecated tf.contrib.data.map_and_batch;\n",
"        # static tf.data optimizations fuse the two automatically.\n",
"        d = d.map(lambda record: _decode_record(record, name_to_features))\n",
"        return d.batch(batch_size, drop_remainder=drop_remainder)\n",
"\n",
"    return input_fn\n",
"\n",
"\n",
"def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
"    \"\"\"Trims two token lists in place until their combined length is at most max_length.\n",
"\n",
"    One token is dropped at a time from the end of whichever list is currently\n",
"    longer (`tokens_b` on ties). Trimming the longer list first preserves more\n",
"    of a short sequence, whose individual tokens likely carry more information.\n",
"    \"\"\"\n",
"    while len(tokens_a) + len(tokens_b) > max_length:\n",
"        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b\n",
"        longer.pop()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "LV7xGtjtJOaK",
"colab_type": "text"
},
"source": [
"#### Creating the MTL model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "gMBZy46u4uY9",
"colab_type": "code",
"colab": {}
},
"source": [
"def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,\n",
"                 labels, num_labels, use_one_hot_embeddings, class_weights_off=1,\n",
"                 class_weights_Noff=1, class_weights_hate=1, class_weights_Nhate=1):\n",
"    \"\"\"Builds the multi-task classification graph on top of BERT.\n",
"\n",
"    One binary (2-way softmax) head per task sits on BERT's pooled output;\n",
"    all heads share the encoder.\n",
"\n",
"    Args:\n",
"        bert_config: Config passed to `modeling.BertModel`.\n",
"        is_training: Enables 0.1 dropout on the pooled output when True.\n",
"        input_ids, input_mask, segment_ids: Encoder inputs.\n",
"        labels: Integer labels with one column per task (task t uses labels[:, t]).\n",
"        num_labels: Number of tasks (heads). Every head is binary.\n",
"        use_one_hot_embeddings: Forwarded to `modeling.BertModel`.\n",
"        class_weights_off / class_weights_Noff: Loss weights of the positive /\n",
"            negative class of task 0 (offensive).\n",
"        class_weights_hate / class_weights_Nhate: Loss weights of the positive /\n",
"            negative class of task 1 (hate speech). Tasks beyond index 1 are\n",
"            unweighted. With all weights at their default of 1, each task\n",
"            loss is the plain mean cross-entropy.\n",
"\n",
"    Returns:\n",
"        (loss, per_example_loss, logits, probabilities):\n",
"        - loss: sum of the per-task losses.\n",
"        - per_example_loss: list with one weighted mean loss per task (one\n",
"          scalar per task, despite the name).\n",
"        - logits, probabilities: lists with one [batch, 2] tensor per task.\n",
"    \"\"\"\n",
"    model = modeling.BertModel(\n",
"        config=bert_config,\n",
"        is_training=is_training,\n",
"        input_ids=input_ids,\n",
"        input_mask=input_mask,\n",
"        token_type_ids=segment_ids,\n",
"        use_one_hot_embeddings=use_one_hot_embeddings)\n",
"\n",
"    output_layer = model.get_pooled_output()\n",
"    hidden_size = output_layer.shape[-1].value\n",
"\n",
"    # Every head is binary, so logits and one-hot labels have depth 2 no\n",
"    # matter how many tasks there are.\n",
"    num_classes = 2\n",
"\n",
"    # Output weights and biases for each task head.\n",
"    weights = {}\n",
"    biases = {}\n",
"    for task in range(num_labels):\n",
"        key = 'task_{}'.format(task)\n",
"        weights[key] = tf.get_variable(\n",
"            \"output_weights_{}\".format(task), [num_classes, hidden_size],\n",
"            initializer=tf.truncated_normal_initializer(stddev=0.02))\n",
"        biases[key] = tf.get_variable(\n",
"            \"output_bias_{}\".format(task), [num_classes],\n",
"            initializer=tf.zeros_initializer())\n",
"\n",
"    # Class weights ordered [negative, positive] to match the one-hot columns.\n",
"    task_class_weights = {\n",
"        0: [class_weights_Noff, class_weights_off],\n",
"        1: [class_weights_Nhate, class_weights_hate],\n",
"    }\n",
"\n",
"    with tf.variable_scope(\"loss\"):\n",
"        if is_training:\n",
"            # I.e., 0.1 dropout\n",
"            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)\n",
"\n",
"        labels = tf.cast(labels, tf.int32)  # (batch_size, num_labels)\n",
"\n",
"        logits = []\n",
"        probabilities = []\n",
"        per_example_loss = []\n",
"        for task in range(num_labels):\n",
"            key = 'task_{}'.format(task)\n",
"\n",
"            task_logits = tf.nn.bias_add(\n",
"                tf.matmul(output_layer, weights[key], transpose_b=True), biases[key])\n",
"            logits.append(task_logits)\n",
"            probabilities.append(tf.nn.softmax(task_logits, axis=-1))\n",
"\n",
"            labels_hot = tf.one_hot(labels[:, task], depth=num_classes, dtype=tf.float32)\n",
"\n",
"            # Each example is weighted by the weight of its true class.\n",
"            class_weights = tf.constant(\n",
"                task_class_weights.get(task, [1, 1]), dtype=tf.float32)\n",
"            example_weights = tf.reduce_sum(labels_hot * class_weights, axis=-1)\n",
"\n",
"            cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(\n",
"                logits=task_logits, labels=labels_hot)\n",
"            per_example_loss.append(tf.reduce_mean(cross_entropy * example_weights))\n",
"\n",
"        tf.logging.info(\"num_labels:{};logits:{};labels:{}\".format(num_labels, logits, labels))\n",
"\n",
"        loss = tf.math.reduce_sum(per_example_loss)\n",
"\n",
"    return (loss, per_example_loss, logits, probabilities)\n",
"\n",
"\n",
"def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,\n",
"                     num_train_steps, num_warmup_steps, use_tpu,\n",
"                     use_one_hot_embeddings):\n",
"    \"\"\"Returns a `model_fn` closure for `tf.estimator.Estimator`.\n",
"\n",
"    The returned function builds the multi-task model with `create_model`,\n",
"    optionally warm-starts it from `init_checkpoint`, and returns an\n",
"    `EstimatorSpec` for:\n",
"    - TRAIN: the loss and an optimizer from `optimization.create_optimizer`;\n",
"    - EVAL: per-class F1/accuracy for the offense and hate heads plus the mean loss;\n",
"    - PREDICT: per-task softmax probabilities.\n",
"\n",
"    NOTE(review): EVAL/PREDICT read `probabilities[0]` and `probabilities[1]`,\n",
"    so they assume num_labels == 2 (offense, hate).\n",
"    \"\"\"\n",
"\n",
"    def _metric_fn(per_example_loss, label_ids, probabilities, is_real_example):\n",
"        \"\"\"Builds per-class F1/accuracy metrics for the offense and hate heads.\n",
"\n",
"        Predictions are the rounded softmax probability of each class (a 0.5\n",
"        threshold). `is_real_example` is used as the metric weight, so padding\n",
"        examples do not count.\n",
"        \"\"\"\n",
"        label_ids = tf.cast(label_ids, tf.int32)  # (batch_size, num_labels)\n",
"        # Every head is binary, so each task's labels are one-hot with depth 2.\n",
"        labels_hot = [tf.one_hot(label_ids[:, task], depth=2, dtype=tf.int32)\n",
"                      for task in range(num_labels)]\n",
"\n",
"        # Column 0 = negative class (non-offense / non-hate), 1 = positive.\n",
"        label_cols = ['non-offense', 'offense', 'non-hate', 'hate']\n",
"        proba = [probabilities[0][:, 0], probabilities[0][:, 1],\n",
"                 probabilities[1][:, 0], probabilities[1][:, 1]]\n",
"        labels_separate = [labels_hot[0][:, 0], labels_hot[0][:, 1],\n",
"                           labels_hot[1][:, 0], labels_hot[1][:, 1]]\n",
"\n",
"        eval_dict = {}\n",
"        for label_name, class_proba, class_labels in zip(label_cols, proba, labels_separate):\n",
"            predictions = tf.math.round(class_proba)\n",
"            eval_dict[label_name + '_f1'] = tf.contrib.metrics.f1_score(\n",
"                class_labels, predictions, weights=is_real_example)\n",
"            eval_dict[label_name + '_accuracy'] = tf.compat.v1.metrics.accuracy(\n",
"                class_labels, predictions, weights=is_real_example)\n",
"\n",
"        eval_dict['eval_loss'] = tf.metrics.mean(values=per_example_loss)\n",
"        return eval_dict\n",
"\n",
"    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument\n",
"        \"\"\"The `model_fn` for the Estimator.\"\"\"\n",
"        tf.logging.info(\"*** Features ***\")\n",
"        for name in sorted(features.keys()):\n",
"            tf.logging.info(\" name = %s, shape = %s\" % (name, features[name].shape))\n",
"\n",
"        input_ids = features[\"input_ids\"]\n",
"        input_mask = features[\"input_mask\"]\n",
"        segment_ids = features[\"segment_ids\"]\n",
"        label_ids = features[\"label_ids\"]\n",
"        if \"is_real_example\" in features:\n",
"            is_real_example = tf.cast(features[\"is_real_example\"], dtype=tf.float32)\n",
"        else:\n",
"            # One flag per example: label_ids is (batch, num_labels), so only\n",
"            # its batch dimension is used.\n",
"            is_real_example = tf.ones(tf.shape(label_ids)[:1], dtype=tf.float32)\n",
"\n",
"        is_training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
"\n",
"        (total_loss, per_example_loss, logits, probabilities) = create_model(\n",
"            bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,\n",
"            num_labels, use_one_hot_embeddings)\n",
"\n",
"        tvars = tf.trainable_variables()\n",
"        initialized_variable_names = {}\n",
"        scaffold_fn = None\n",
"        if init_checkpoint:\n",
"            (assignment_map, initialized_variable_names\n",
"             ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)\n",
"            if use_tpu:\n",
"                # NOTE(review): EstimatorSpec expects a Scaffold object, not a\n",
"                # function; this branch only suits TPUEstimatorSpec(scaffold_fn=...).\n",
"                def tpu_scaffold():\n",
"                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
"                    return tf.train.Scaffold()\n",
"\n",
"                scaffold_fn = tpu_scaffold\n",
"            else:\n",
"                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)\n",
"\n",
"        tf.logging.info(\"**** Trainable Variables ****\")\n",
"        for var in tvars:\n",
"            init_string = \", *INIT_FROM_CKPT*\" if var.name in initialized_variable_names else \"\"\n",
"            # Debug level: one line per variable is too noisy for INFO.\n",
"            tf.logging.debug(\" name = %s, shape = %s%s\", var.name, var.shape, init_string)\n",
"\n",
"        if mode == tf.estimator.ModeKeys.TRAIN:\n",
"            train_op = optimization.create_optimizer(\n",
"                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)\n",
"            return tf.estimator.EstimatorSpec(\n",
"                mode=mode,\n",
"                loss=total_loss,\n",
"                train_op=train_op,\n",
"                scaffold=scaffold_fn)\n",
"\n",
"        if mode == tf.estimator.ModeKeys.EVAL:\n",
"            eval_metrics = _metric_fn(per_example_loss, label_ids, probabilities, is_real_example)\n",
"            return tf.estimator.EstimatorSpec(\n",
"                mode=mode,\n",
"                loss=total_loss,\n",
"                eval_metric_ops=eval_metrics,\n",
"                scaffold=scaffold_fn)\n",
"\n",
"        tf.logging.info(\"mode: %s probabilities: %s\", mode, probabilities)\n",
"        return tf.estimator.EstimatorSpec(\n",
"            mode=mode,\n",
"            predictions={\"probabilities_offense\": probabilities[0], 'probabilities_hate': probabilities[1]},\n",
"            scaffold=scaffold_fn)\n",
"\n",
"    return model_fn"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "VekZ_uAYfJdu",
"colab_type": "text"
},
"source": [
"### Training Preps"
]
},
{
"cell_type": "code",
"metadata": {
"id": "2nZETvAQ078Q",
"colab_type": "code",
"outputId": "7ec08981-3293-46aa-dd04-78d830d56caf",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"tokenizer = tokenization.FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=False)\n",
"\n",
"MAX_SEQ_LENGTH = 256\n",
"BATCH_SIZE = 16\n",
"LEARNING_RATE = 2e-5\n",
"NUM_TRAIN_EPOCHS = 5.0\n",
"WARMUP_PROPORTION = 0.1\n",
"SAVE_CHECKPOINTS_STEPS = 200\n",
"SAVE_SUMMARY_STEPS = 200\n",
"WORKING_DIR = './working'\n",
"\n",
"# exist_ok keeps this cell re-runnable (`!mkdir` errored when the directory\n",
"# already existed). The TFRecord writer creates each output file itself.\n",
"os.makedirs(WORKING_DIR, exist_ok=True)\n",
"\n",
"train_file = os.path.join(WORKING_DIR, \"train.tf_record\")\n",
"train_examples = create_examples(x_train)\n",
"file_based_convert_examples_to_features(train_examples, MAX_SEQ_LENGTH, tokenizer, train_file)\n",
"\n",
"eval_file = os.path.join(WORKING_DIR, \"eval.tf_record\")\n",
"eval_examples = create_examples(x_val)\n",
"file_based_convert_examples_to_features(eval_examples, MAX_SEQ_LENGTH, tokenizer, eval_file)\n",
"\n",
"test_file = os.path.join(WORKING_DIR, \"test.tf_record\")\n",
"# The test labels are placeholders (the notebook zeroes them in a later cell).\n",
"# Request zero labels explicitly so this cell does not depend on that cell.\n",
"test_examples = create_examples(x_test, labels_available=False)\n",
"file_based_convert_examples_to_features(test_examples, MAX_SEQ_LENGTH, tokenizer, test_file)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"mkdir: cannot create directory ‘working’: File exists\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "I8Nux7RvlC4l",
"colab_type": "code",
"outputId": "1bbb7517-bbf8-4a2b-9cf4-1f23958b6474",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
}
},
"source": [
"# First rows of the validation split.\n",
"x_val.head(n=5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text | \n",
" offensive | \n",
" hate_speech | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" في حاج +ات مينفعش نلفت نظرك +وا لي +ها زى ال+ ... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" : و+ عيون تنادي +نا تحايل في +نا و+ نقول يا عي... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" يا بلاد +ي يا أم ال+ بلاد يا بلاد +ي ب+ حبك يا... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 3 | \n",
" : يا رب يا قوي يا معين مدني ب+ ال+ قو +ة و+ ال... | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
" 4 | \n",
" : رحم ك+ الله يا صدام يا بطل و+ مقدام . | \n",
" 0 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text offensive hate_speech\n",
"0 في حاج +ات مينفعش نلفت نظرك +وا لي +ها زى ال+ ... 0 0\n",
"1 : و+ عيون تنادي +نا تحايل في +نا و+ نقول يا عي... 0 0\n",
"2 يا بلاد +ي يا أم ال+ بلاد يا بلاد +ي ب+ حبك يا... 0 0\n",
"3 : يا رب يا قوي يا معين مدني ب+ ال+ قو +ة و+ ال... 0 0\n",
"4 : رحم ك+ الله يا صدام يا بطل و+ مقدام . 0 0"
]
},
"metadata": {
"tags": []
},
"execution_count": 57
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "VOz-hJDelJTF",
"colab_type": "code",
"outputId": "7fed0ce0-9408-41e0-df00-5a798d63c702",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 191
}
},
"source": [
"# First rows of the test split.\n",
"x_test.head(n=5)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" text | \n",
" offensive | \n",
" hate_speech | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" اما أنت تقعد طول عمر ك+ لا مبدا و+ لا رأي ثابت... | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1 | \n",
" ب+ تخاف نسوان ك+ يزعل +وا و+ لا أي اه يا هلفوت... | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2 | \n",
" : يا عسانى نبقى يا عمري حبايب و+ حب +نا يكبر م... | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 3 | \n",
" : باقي ال+ بيان و+ ينو ما شفن +ه يا برهان و+ ر... | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 4 | \n",
" اللهم أنت ال+ شافي ال+ معافي اشفي +ه و+ جميع م... | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" text offensive hate_speech\n",
"0 اما أنت تقعد طول عمر ك+ لا مبدا و+ لا رأي ثابت... 0.0 0.0\n",
"1 ب+ تخاف نسوان ك+ يزعل +وا و+ لا أي اه يا هلفوت... 0.0 0.0\n",
"2 : يا عسانى نبقى يا عمري حبايب و+ حب +نا يكبر م... 0.0 0.0\n",
"3 : باقي ال+ بيان و+ ينو ما شفن +ه يا برهان و+ ر... 0.0 0.0\n",
"4 اللهم أنت ال+ شافي ال+ معافي اشفي +ه و+ جميع م... 0.0 0.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CIurzoBFk-eA",
"colab_type": "code",
"colab": {}
},
"source": [
"# The test labels are unknown: fill both label columns with 0.0 placeholders.\n",
"for label_col in ('offensive', 'hate_speech'):\n",
"    x_test[label_col] = np.zeros(len(x_test))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jOsj5MFm4sS5",
"colab_type": "code",
"outputId": "5f4926fd-439d-4ca5-c766-c4a6b92450a8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 82
}
},
"source": [
"n_train_examples = len(train_examples)\n",
"num_train_steps = int(n_train_examples / BATCH_SIZE * NUM_TRAIN_EPOCHS)\n",
"num_warmup_steps = int(num_train_steps * WARMUP_PROPORTION)\n",
"\n",
"tf.logging.info(\"***** Running training *****\")\n",
"for log_fmt, log_value in ((\" Num examples = %d\", n_train_examples),\n",
"                           (\" Batch size = %d\", BATCH_SIZE),\n",
"                           (\" Num steps = %d\", num_train_steps)):\n",
"    tf.logging.info(log_fmt, log_value)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:***** Running training *****\n",
"INFO:tensorflow: Num examples = 6839\n",
"INFO:tensorflow: Batch size = 16\n",
"INFO:tensorflow: Num steps = 2137\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AKLAN-yD073K",
"colab_type": "code",
"colab": {}
},
"source": [
"# Training repeats and shuffles its data; eval/test make one ordered pass.\n",
"# No split drops its last partial batch.\n",
"train_input_fn, eval_input_fn, test_input_fn = (\n",
"    file_based_input_fn_builder(\n",
"        input_file=record_path,\n",
"        seq_length=MAX_SEQ_LENGTH,\n",
"        is_training=training,\n",
"        drop_remainder=False)\n",
"    for record_path, training in ((train_file, True), (eval_file, False), (test_file, False)))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ye-aDlm7c33_",
"colab_type": "code",
"colab": {}
},
"source": [
"OUTPUT_DIR = \"/content/Output_Dir\"\n",
"# makedirs(exist_ok=True) is idempotent and creates missing parents, unlike\n",
"# the previous exists-check followed by os.mkdir.\n",
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
"\n",
"# Checkpoint the encoder is initialized from (BERT_INIT_CHKPNT is defined earlier).\n",
"checkpoint_name = BERT_INIT_CHKPNT\n",
"OUTPUT_DIR_PER_MODEL = OUTPUT_DIR + \"/arabert/\"\n",
"print(OUTPUT_DIR_PER_MODEL)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iE4XUL523q9L",
"colab_type": "code",
"colab": {}
},
"source": [
"# MakeDirs also creates parents and succeeds if the directory already exists.\n",
"tf.gfile.MakeDirs(OUTPUT_DIR_PER_MODEL)\n",
"\n",
"num_train_steps = int(len(train_examples) / BATCH_SIZE * NUM_TRAIN_EPOCHS)\n",
"num_warmup_steps = int(num_train_steps * WARMUP_PROPORTION)\n",
"# Batches per pass over the data (ceil: the last batch may be partial).\n",
"num_steps_per_epoch = (len(train_examples) + BATCH_SIZE - 1) // BATCH_SIZE\n",
"# One full pass over the eval set. Previously int(len/BS*EPOCHS)/BS, which\n",
"# covered only about a third of the eval examples.\n",
"num_eval_steps = (len(eval_examples) + BATCH_SIZE - 1) // BATCH_SIZE\n",
"\n",
"print(\"num train steps: {}\".format(num_train_steps))\n",
"print(\"num warmup steps: {}\".format(num_warmup_steps))\n",
"print(\"num_steps_per_epoch: {}\".format(num_steps_per_epoch))\n",
"\n",
"bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)\n",
"\n",
"model_fn = model_fn_builder(\n",
"    bert_config=bert_config,\n",
"    num_labels=len(LABEL_COLUMNS),\n",
"    init_checkpoint=checkpoint_name,\n",
"    learning_rate=LEARNING_RATE,\n",
"    num_train_steps=num_train_steps,\n",
"    num_warmup_steps=num_warmup_steps,\n",
"    use_tpu=False,\n",
"    use_one_hot_embeddings=False)\n",
"\n",
"run_config = tf.estimator.RunConfig(\n",
"    keep_checkpoint_max=100,\n",
"    model_dir=OUTPUT_DIR_PER_MODEL,\n",
"    save_summary_steps=SAVE_SUMMARY_STEPS,\n",
"    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS)\n",
"\n",
"estimator = tf.estimator.Estimator(\n",
"    model_fn=model_fn,\n",
"    config=run_config,\n",
"    params={\"batch_size\": BATCH_SIZE})\n",
"\n",
"print('Beginning Training!')\n",
"estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)\n",
"\n",
"print('Beginning Evaluation!')\n",
"\n",
"\n",
"def _checkpoint_step(index_path):\n",
"    \"\"\"Global step encoded in a checkpoint '.index' path, e.g. 'model.ckpt-200.index' gives 200.\"\"\"\n",
"    return int(index_path[:-len('.index')].split('-')[-1])\n",
"\n",
"\n",
"eval_model_files = tf.gfile.Glob(os.path.join(OUTPUT_DIR_PER_MODEL, '*index'))\n",
"for eval_checkpoint in tqdm(sorted(eval_model_files, key=_checkpoint_step)):\n",
"    estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_steps,\n",
"                       checkpoint_path=eval_checkpoint[:-len('.index')])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9_CNI-tefRnC",
"colab_type": "text"
},
"source": [
"### Evaluating"
]
},
{
"cell_type": "code",
"metadata": {
"id": "yK57NMAdHd8k",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"from sklearn.metrics import f1_score\n",
"\n",
"offensive_true = np.asarray(eval_[\"offensive\"], dtype=np.int32)\n",
"hatespeech_true = np.asarray(eval_[\"hate_speech\"], dtype=np.int32)\n",
"\n",
"\n",
"def _one_hot(labels, num_classes=2):\n",
"    \"\"\"One-hot encodes integer labels into an (n, num_classes) int32 array.\n",
"\n",
"    The width is fixed instead of being derived from `labels.max()`, so a\n",
"    split without any positive label still has a positive column.\n",
"    \"\"\"\n",
"    labels = np.asarray(labels, dtype=np.int64)\n",
"    encoded = np.zeros((labels.size, num_classes), dtype=np.int32)\n",
"    encoded[np.arange(labels.size), labels] = 1\n",
"    return encoded\n",
"\n",
"\n",
"def get_f1(preds, offensive_true, hatespeech_true, f1_averaging):\n",
"    \"\"\"Prints and returns per-class F1 scores for the offense and hate heads.\n",
"\n",
"    Args:\n",
"        preds: Iterable of prediction dicts (as yielded by `estimator.predict`)\n",
"            with 'probabilities_offense' and 'probabilities_hate' entries.\n",
"        offensive_true, hatespeech_true: Integer 0/1 gold labels in the same\n",
"            order as `preds`.\n",
"        f1_averaging: `average` argument of `sklearn.metrics.f1_score`, used for\n",
"            both the printed table and the returned values.\n",
"\n",
"    Returns:\n",
"        (f1_not_offensive, f1_not_hate_speech, f1_offensive, f1_hate_speech).\n",
"    \"\"\"\n",
"    from prettytable import PrettyTable\n",
"\n",
"    offense_output = []\n",
"    hate_speech_output = []\n",
"    for instance in preds:\n",
"        offense_output.append(list(instance['probabilities_offense']))\n",
"        hate_speech_output.append(list(instance['probabilities_hate']))\n",
"    # Rounding each class probability applies a 0.5 threshold per class.\n",
"    offense_output = np.round(np.asarray(offense_output)).astype(np.int32)\n",
"    hate_speech_output = np.round(np.asarray(hate_speech_output)).astype(np.int32)\n",
"\n",
"    offensive_true_one_hot = _one_hot(offensive_true)\n",
"    hatespeech_true_one_hot = _one_hot(hatespeech_true)\n",
"\n",
"    def score(true_one_hot, output, column):\n",
"        return f1_score(true_one_hot[:, column], output[:, column], average=f1_averaging)\n",
"\n",
"    f1_not_offensive = score(offensive_true_one_hot, offense_output, 0)\n",
"    f1_not_hate_speech = score(hatespeech_true_one_hot, hate_speech_output, 0)\n",
"    f1_offensive = score(offensive_true_one_hot, offense_output, 1)\n",
"    f1_hate_speech = score(hatespeech_true_one_hot, hate_speech_output, 1)\n",
"\n",
"    tab = PrettyTable()\n",
"    print('\\n')\n",
"    tab.field_names = [\"\",\"not_offensive f1\", 'not_hate_speech f1', 'offensive f1', 'hate_speech f1']\n",
"    tab.add_row([\"values\"] + [np.round(value, 4) for value in\n",
"                              (f1_not_offensive, f1_not_hate_speech, f1_offensive, f1_hate_speech)])\n",
"    print(tab)\n",
"\n",
"    return f1_not_offensive, f1_not_hate_speech, f1_offensive, f1_hate_speech"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LM35TL6MHh_X",
"colab_type": "code",
"outputId": "3d9fcb27-f368-4fa0-84bd-a7bd1be4dfed",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 980
}
},
"source": [
"# Score every saved checkpoint on the validation split, in step order.\n",
"eval_model_files = tf.gfile.Glob(os.path.join(OUTPUT_DIR_PER_MODEL, '*index'))\n",
"for eval_checkpoint in tqdm(sorted(eval_model_files, key=lambda x: int(x[0:-6].split('-')[-1]))):\n",
"    # [0:-6] strips the '.index' suffix, leaving the checkpoint prefix.\n",
"    result = estimator.predict(eval_input_fn, checkpoint_path=eval_checkpoint[0:-6])\n",
"    # Distinct names: the old unpacking bound f1_not_hate_speech twice and\n",
"    # lost the not-offensive score.\n",
"    f1_not_offensive, f1_not_hate_speech, f1_offensive, f1_hate_speech = get_f1(\n",
"        result, offensive_true, hatespeech_true, 'macro')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"\r 0%| | 0/1 [00:00, ?it/s]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:\n",
"The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
"For more information, please see:\n",
" * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
" * https://github.com/tensorflow/addons\n",
" * https://github.com/tensorflow/io (for I/O related ops)\n",
"If you depend on functionality not listed there, please file an issue.\n",
"\n",
"WARNING:tensorflow:From :179: map_and_batch (from tensorflow.contrib.data.python.ops.batching) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.experimental.map_and_batch(...)`.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/contrib/data/python/ops/batching.py:276: map_and_batch (from tensorflow.python.data.experimental.ops.batching) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.data.Dataset.map(map_func, num_parallel_calls)` followed by `tf.data.Dataset.batch(batch_size, drop_remainder)`. Static tf.data optimizations will take care of using the fused implementation.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/converters/directives.py:119: The name tf.parse_single_example is deprecated. Please use tf.io.parse_single_example instead.\n",
"\n",
"WARNING:tensorflow:From :159: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.cast` instead.\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:*** Features ***\n",
"INFO:tensorflow: name = input_ids, shape = (?, 256)\n",
"INFO:tensorflow: name = input_mask, shape = (?, 256)\n",
"INFO:tensorflow: name = is_real_example, shape = (?,)\n",
"INFO:tensorflow: name = label_ids, shape = (?, 2)\n",
"INFO:tensorflow: name = segment_ids, shape = (?, 256)\n",
"WARNING:tensorflow:From /content/bert/modeling.py:171: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
"\n",
"WARNING:tensorflow:From /content/bert/modeling.py:409: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
"\n",
"WARNING:tensorflow:From /content/bert/modeling.py:490: The name tf.assert_less_equal is deprecated. Please use tf.compat.v1.assert_less_equal instead.\n",
"\n",
"WARNING:tensorflow:From /content/bert/modeling.py:671: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use keras.layers.Dense instead.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/layers/core.py:187: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Please use `layer.__call__` method instead.\n",
"INFO:tensorflow:num_labels:2;logits:[];labels:Tensor(\"IteratorGetNext:3\", shape=(?, 2), dtype=int32)\n",
"INFO:tensorflow:num_labels:2;logits:[, ];labels:Tensor(\"IteratorGetNext:3\", shape=(?, 2), dtype=int32)\n",
"INFO:tensorflow:**** Trainable Variables ****\n",
"mode: infer probabilities: [, ]\n",
"INFO:tensorflow:Done calling model_fn.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/array_ops.py:1475: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Restoring parameters from /content/Output_Dir/1250000/model.ckpt-0\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\r100%|██████████| 1/1 [00:36<00:00, 36.97s/it]"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"\n",
"\n",
"+--------+------------------+--------------------+--------------+----------------+\n",
"| | not_offensive f1 | not_hate_speech f1 | offensive f1 | hate_speech f1 |\n",
"+--------+------------------+--------------------+--------------+----------------+\n",
"| values | 0.9015 | 0.8341 | 0.9015 | 0.8341 |\n",
"+--------+------------------+--------------------+--------------+----------------+\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "epyAnG_bfdl8",
"colab_type": "text"
},
"source": [
"#### Saving Best Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bb9a3uwyr_-2",
"colab_type": "code",
"colab": {}
},
"source": [
"# Pick the checkpoint to keep by its position in training-step order, i.e. the order\n",
"# in which the evaluation loop above prints the per-checkpoint F1 tables. The raw glob\n",
"# result has no guaranteed order, so it is sorted before indexing.\n",
"BEST_CHECKPOINT_POSITION = 7\n",
"eval_checkpoint = sorted(\n",
"    eval_model_files,\n",
"    key=lambda path: int(path[:-len('.index')].split('-')[-1]))[BEST_CHECKPOINT_POSITION]\n",
"checkpoint_path = eval_checkpoint[:-len('.index')]  # checkpoint prefix without '.index'\n",
"# Display the chosen checkpoint so the saved step can be checked before copying it.\n",
"checkpoint_path"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hCRbQfe8aNL8",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copy the selected checkpoint (index/meta/data files for checkpoint_path) plus the run's\n",
"# graph, 'checkpoint' and event files into /content/best_model/.\n",
"# NOTE(review): the copied 'checkpoint' file presumably still names the run's latest\n",
"# step rather than checkpoint_path; pass the explicit checkpoint path when restoring.\n",
"!mkdir -p /content/best_model\n",
"!cp {checkpoint_path}.index {checkpoint_path}.meta {checkpoint_path}.data-* /content/best_model/\n",
"!cp {OUTPUT_DIR_PER_MODEL}/graph.pbtxt {OUTPUT_DIR_PER_MODEL}/checkpoint /content/best_model/\n",
"!cp {OUTPUT_DIR_PER_MODEL}/events.out.tfevents.* /content/best_model/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rLfSGVZXadpL",
"colab_type": "code",
"outputId": "807463dd-cdfb-4b91-f47f-49d89461b558",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 131
}
},
"source": [
"!zip -r best_model.zip /content/best_model"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"updating: content/best_model/ (stored 0%)\n",
" adding: content/best_model/events.out.tfevents.1582011537.9f8209d16eb2 (deflated 92%)\n",
" adding: content/best_model/checkpoint (deflated 84%)\n",
" adding: content/best_model/model.ckpt-1200.meta (deflated 92%)\n",
" adding: content/best_model/graph.pbtxt (deflated 97%)\n",
" adding: content/best_model/model.ckpt-1200.index (deflated 69%)\n",
" adding: content/best_model/model.ckpt-1200.data-00000-of-00001 (deflated 25%)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sYv2cLAqbwaB",
"colab_type": "code",
"colab": {}
},
"source": [
"!cp /content/best_model.zip /content/drive/'My Drive'/best_MTL_model/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DP8q2FX3T3QY",
"colab_type": "text"
},
"source": [
"#### Predicting on the test set with the selected checkpoint"
]
},
{
"cell_type": "code",
"metadata": {
"id": "SanXbZg_g9uz",
"colab_type": "code",
"colab": {}
},
"source": [
"# Predict on the test set with the selected checkpoint. estimator.predict returns a\n",
"# generator, so nothing is computed until `result` is iterated by a later cell.\n",
"result = estimator.predict(input_fn=test_input_fn,\n",
"                           checkpoint_path=eval_checkpoint[:-len('.index')])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lZDpSp2EfhvM",
"colab_type": "text"
},
"source": [
"### Predicting on eval and test sets"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ErsiphU0sVjj",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_predictions(path_to_input_text, predictions, filename, true_labels, is_test_data):\n",
"    \"\"\"Turn estimator predictions into a labelled table saved as '<filename>.xlsx'.\n",
"\n",
"    Args:\n",
"        path_to_input_text: tab-separated file; the first field of line i is used\n",
"            as the text of the i-th prediction.\n",
"        predictions: iterable of dicts holding \"probabilities_offense\" and\n",
"            \"probabilities_hate\" arrays (as yielded by estimator.predict).\n",
"        filename: name of the Excel file to write, without the '.xlsx' extension.\n",
"        true_labels: table whose first two columns (by position) hold the gold\n",
"            offensive / hate-speech labels, row-aligned with the input file;\n",
"            unused when is_test_data is True.\n",
"        is_test_data: when True, only the predicted labels are output.\n",
"\n",
"    Side effects:\n",
"        Writes the offense probabilities to /content/test_results.tsv, the label\n",
"        table to result.txt (current working directory) and '<filename>.xlsx'.\n",
"\n",
"    Returns:\n",
"        DataFrame with columns text, offensive, hate_speech (test data) or\n",
"        text, offensive, true_off, hate_speech, true_hate (labelled data).\n",
"    \"\"\"\n",
"    import csv  # only needed for csv.QUOTE_NONE below\n",
"\n",
"    label_list = ['0', '1']\n",
"    label_list_2 = ['0', '1']\n",
"    output_predict_file = \"/content/test_results.tsv\"\n",
"    result_file = 'result.txt'\n",
"    if is_test_data:\n",
"        columns = [\"text\", \"offensive\", \"hate_speech\"]\n",
"    else:\n",
"        columns = [\"text\", \"offensive\", \"true_off\", \"hate_speech\", \"true_hate\"]\n",
"\n",
"    # Read every input line up front and close the file right away.\n",
"    with open(path_to_input_text, 'r', encoding='utf-8') as input_file:\n",
"        lines = input_file.readlines()\n",
"\n",
"    rows = []\n",
"    with tf.gfile.GFile(output_predict_file, \"w\") as writer:\n",
"        tf.logging.info(\"***** Predict results *****\")\n",
"\n",
"        for i, pred in enumerate(predictions):\n",
"            prediction_1 = pred[\"probabilities_offense\"]\n",
"            prediction_2 = pred[\"probabilities_hate\"]\n",
"            writer.write(\"\\t\".join(\n",
"                str(class_probability) for class_probability in prediction_1) + \"\\n\")\n",
"\n",
"            prediction_1 = prediction_1.tolist()\n",
"            prediction_2 = prediction_2.tolist()\n",
"            # first index of the largest probability, i.e. the argmax class\n",
"            pre_label_1 = label_list[prediction_1.index(max(prediction_1))]\n",
"            pre_label_2 = label_list_2[prediction_2.index(max(prediction_2))]\n",
"            text = lines[i].strip().split('\\t')[0]\n",
"            if is_test_data:\n",
"                rows.append([text, pre_label_1, pre_label_2])\n",
"            else:\n",
"                # str() so non-string gold labels don't raise TypeError when joined\n",
"                rows.append([text, pre_label_1, str(true_labels.iloc[i, 0]),\n",
"                             pre_label_2, str(true_labels.iloc[i, 1])])\n",
"\n",
"    with open(result_file, 'w', encoding='utf-8') as out:\n",
"        out.write(''.join('\\t'.join(row) + '\\n' for row in rows))\n",
"\n",
"    if rows:\n",
"        # Read back the very file just written. QUOTE_NONE keeps '\"' characters in\n",
"        # tweets from being parsed as CSV quoting (which can merge or split rows).\n",
"        preds = pd.read_csv(result_file, sep=\"\\t\", header=None, quoting=csv.QUOTE_NONE)\n",
"        preds.columns = columns\n",
"    else:\n",
"        # read_csv raises on an empty file, so build the empty table directly.\n",
"        preds = pd.DataFrame(columns=columns)\n",
"\n",
"    preds.to_excel('{}.xlsx'.format(filename))\n",
"\n",
"    return preds"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "oP_sa6g6tgQ6",
"colab_type": "code",
"colab": {}
},
"source": [
"eval_model_files"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WZxW2wgItrQe",
"colab_type": "code",
"colab": {}
},
"source": [
"# Predict on the dev set with the selected checkpoint and save the labels to\n",
"# eval_preds_1400.xlsx via get_predictions.\n",
"# Index into checkpoints sorted by training step; raw glob order is not guaranteed.\n",
"BEST_CHECKPOINT_POSITION = 7\n",
"eval_checkpoint = sorted(\n",
"    eval_model_files,\n",
"    key=lambda path: int(path[:-len('.index')].split('-')[-1]))[BEST_CHECKPOINT_POSITION]\n",
"result_eval = estimator.predict(input_fn=eval_input_fn,\n",
"                                checkpoint_path=eval_checkpoint[:-len('.index')])\n",
"val_labels = x_train_w_eval.iloc[:, 1:]\n",
"# NOTE(review): with is_test_data=True get_predictions ignores val_labels, so the sheet\n",
"# has no gold columns; before switching to False, confirm x_train_w_eval rows line up\n",
"# with the lines of the dev file.\n",
"preds_eval = get_predictions('/content/data/OSACT2020-sharedTask-dev.txt', result_eval,\n",
"                             'eval_preds_1400', val_labels, is_test_data=True)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "e76z2YW2htIA",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_predictions_final(predictions):\n",
"    \"\"\"Write the predicted label of every example, one per line, for each task.\n",
"\n",
"    Args:\n",
"        predictions: iterable of dicts holding \"probabilities_offense\" and\n",
"            \"probabilities_hate\" arrays (as yielded by estimator.predict).\n",
"\n",
"    Side effects:\n",
"        Writes offense probabilities to /content/test_results.tsv, NOT_OFF/OFF\n",
"        labels to result_off.txt and NOT_HS/HS labels to result_HS.txt (both in\n",
"        the current working directory).\n",
"\n",
"    Returns:\n",
"        The last prediction dict consumed, or None when `predictions` is empty.\n",
"    \"\"\"\n",
"    # Labels are written verbatim, so they must not carry stray whitespace\n",
"    # ('OFF ' with a trailing space would leak into result_off.txt).\n",
"    label_list = ['NOT_OFF', 'OFF']\n",
"    label_list_2 = ['NOT_HS', 'HS']\n",
"    output_predict_file = \"/content/test_results.tsv\"\n",
"    offense_labels = []\n",
"    hate_labels = []\n",
"    preds = None  # stays None if there are no predictions\n",
"    with tf.gfile.GFile(output_predict_file, \"w\") as writer:\n",
"        tf.logging.info(\"***** Predict results *****\")\n",
"\n",
"        for preds in predictions:\n",
"            prediction_1 = preds[\"probabilities_offense\"]\n",
"            prediction_2 = preds[\"probabilities_hate\"]\n",
"            writer.write(\"\\t\".join(\n",
"                str(class_probability) for class_probability in prediction_1) + \"\\n\")\n",
"\n",
"            prediction_1 = prediction_1.tolist()\n",
"            prediction_2 = prediction_2.tolist()\n",
"            # first index of the largest probability, i.e. the argmax class\n",
"            offense_labels.append(label_list[prediction_1.index(max(prediction_1))])\n",
"            hate_labels.append(label_list_2[prediction_2.index(max(prediction_2))])\n",
"\n",
"    with open('result_off.txt', 'w', encoding='utf-8') as f:\n",
"        f.write(''.join(label + '\\n' for label in offense_labels))\n",
"\n",
"    with open('result_HS.txt', 'w', encoding='utf-8') as f:\n",
"        f.write(''.join(label + '\\n' for label in hate_labels))\n",
"\n",
"    return preds"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AZwcI8VQrOcH",
"colab_type": "code",
"outputId": "e5e1aa55-c7ce-475f-9bf2-d1e48777ed16",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 33
}
},
"source": [
"x_test.shape"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(2000, 3)"
]
},
"metadata": {
"tags": []
},
"execution_count": 23
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BcmcDR4PnZOO",
"colab_type": "code",
"colab": {}
},
"source": [
"# Predict on the dev set (eval_input_fn) with the selected checkpoint and write the\n",
"# per-task label files result_off.txt / result_HS.txt via get_predictions_final.\n",
"# NOTE(review): switch to test_input_fn if these files are meant for the test set.\n",
"eval_model_files = tf.gfile.Glob(os.path.join(OUTPUT_DIR_PER_MODEL, '*index'))\n",
"# Index into checkpoints sorted by training step; raw glob order is not guaranteed.\n",
"BEST_CHECKPOINT_POSITION = 7\n",
"eval_checkpoint = sorted(\n",
"    eval_model_files,\n",
"    key=lambda path: int(path[:-len('.index')].split('-')[-1]))[BEST_CHECKPOINT_POSITION]\n",
"result_eval = estimator.predict(input_fn=eval_input_fn,\n",
"                                checkpoint_path=eval_checkpoint[:-len('.index')])\n",
"preds_eval = get_predictions_final(result_eval)\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nuRcSTiAPiL0",
"colab_type": "code",
"colab": {}
},
"source": [
"# get_predictions(..., 'eval_preds_1400', ...) above writes eval_preds_1400.xlsx to the\n",
"# current working directory; assumes Drive is mounted at /content/drive.\n",
"!cp eval_preds_1400.xlsx /content/drive/'My Drive'/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "cShj5A_9hhfn",
"colab_type": "code",
"colab": {}
},
"source": [
"# Convert the test-set predictions in `result` into readable labels and write\n",
"# tab-separated (text, offense label, hate label) lines to result.txt; offense\n",
"# probabilities go to /content/test_results.tsv.\n",
"# `result` is the generator from estimator.predict above and can be iterated only\n",
"# once, so re-run that cell before re-running this one.\n",
"label_list = ['non-offense', 'offense']\n",
"label_list_2 = ['non-hate', 'hate']\n",
"output_predict_file = \"/content/test_results.tsv\"\n",
"\n",
"# Read the tweets up front and close the file right away.\n",
"with open('/content/OSACT2020-sharedTask-test-tweets.txt', 'r', encoding='utf-8') as f:\n",
"    lines = f.readlines()\n",
"\n",
"output_rows = []\n",
"with tf.gfile.GFile(output_predict_file, \"w\") as writer:\n",
"    tf.logging.info(\"***** Predict results *****\")\n",
"\n",
"    for i, predictions in enumerate(result):\n",
"        prediction = predictions[\"probabilities_offense\"]\n",
"        prediction2 = predictions[\"probabilities_hate\"]\n",
"        writer.write(\"\\t\".join(\n",
"            str(class_probability) for class_probability in prediction) + \"\\n\")\n",
"\n",
"        prediction = prediction.tolist()\n",
"        prediction2 = prediction2.tolist()\n",
"        # first index of the largest probability, i.e. the argmax class\n",
"        pre_label = label_list[prediction.index(max(prediction))]\n",
"        pre_label2 = label_list_2[prediction2.index(max(prediction2))]\n",
"\n",
"        text = lines[i].strip().split('\\t')[0]\n",
"        output_rows.append(text + '\\t' + pre_label + '\\t' + pre_label2 + '\\n')\n",
"\n",
"# Join once instead of growing a string inside the loop.\n",
"str1 = ''.join(output_rows)\n",
"with open('result.txt', 'w', encoding='utf-8') as f:\n",
"    f.write(str1)\n"
],
"execution_count": 0,
"outputs": []
}
]
}