{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "e3b3364f" }, "source": [ "# Fine-tuning a RoBERTa model with Twitter data\n", "* Load, label and clean the tweets in `twitter_data.csv`\n", "* Classify text with the `facebook/roberta-hate-speech-dynabench-r4-target` model\n", "\n", "\n", "\n" ], "id": "e3b3364f" }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "7cfcb724" }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.tree import DecisionTreeClassifier" ], "id": "7cfcb724" }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "60ab1d26", "outputId": "ee4eb86f-3994-4391-e5c4-e2f363290977" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "[nltk_data] Downloading package stopwords to /root/nltk_data...\n", "[nltk_data] Package stopwords is already up-to-date!\n" ] } ], "source": [ "import re\n", "import string\n", "import nltk\n", "nltk.download('stopwords')\n", "from nltk.stem.snowball import SnowballStemmer" ], "id": "60ab1d26" }, { "cell_type": "markdown", "metadata": { "id": "538a8bf3" }, "source": [ "### Imports" ], "id": "538a8bf3" }, { "cell_type": "code", "execution_count": 25, "metadata": { "id": "bae03f72", "scrolled": true }, "outputs": [], "source": [ "stemmer = nltk.SnowballStemmer(\"english\")\n", "from nltk.corpus import stopwords\n", "import string\n", "stopword = set(stopwords.words(\"english\"))" ], "id": "bae03f72" }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 327 }, "id": "6de55c38", "outputId": "618d357e-2e8b-4300-cacb-c07b514913ce" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Unnamed: 0 count hate_speech offensive_language neither class \\\n", "0 0 3 0 0 3 2 \n", "1 1 3 0 3 0 1 \n", "2 2 3 0 3 0 1 \n", "3 3 3 0 2 1 1 \n", "4 4 6 0 6 0 1 \n", "\n", " tweet \n", "0 !!! 
RT @mayasolovely: As a woman you shouldn't... \n", "1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... \n", "2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... \n", "3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... \n", "4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0counthate_speechoffensive_languageneitherclasstweet
0030032!!! RT @mayasolovely: As a woman you shouldn't...
1130301!!!!! RT @mleew17: boy dats cold...tyga dwn ba...
2230301!!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...
3330211!!!!!!!!! RT @C_G_Anderson: @viva_based she lo...
4460601!!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df1", "summary": "{\n \"name\": \"df1\",\n \"rows\": 24783,\n \"fields\": [\n {\n \"column\": \"Unnamed: 0\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 7299,\n \"min\": 0,\n \"max\": 25296,\n \"num_unique_values\": 24783,\n \"samples\": [\n 2326,\n 16283,\n 19362\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"count\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 3,\n \"max\": 9,\n \"num_unique_values\": 5,\n \"samples\": [\n 6,\n 7,\n 9\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"hate_speech\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 7,\n \"num_unique_values\": 8,\n \"samples\": [\n 1,\n 6,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"offensive_language\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 9,\n \"num_unique_values\": 10,\n \"samples\": [\n 8,\n 3,\n 7\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"neither\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 9,\n \"num_unique_values\": 10,\n \"samples\": [\n 8,\n 0,\n 4\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"class\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 2,\n 1,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"tweet\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 24783,\n \"samples\": [\n \"934 8616\\ni got a missed call from yo bitch\",\n \"RT @KINGTUNCHI_: Fucking with a bad bitch you gone need some money lil homie!\",\n \"RT @eanahS__: @1inkkofrosess lol my credit ain't no where near good 
, but I know the right man for the job .. that ho nice though!\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 26 } ], "source": [ "df1 = pd.read_csv(\"twitter_data.csv\")\n", "df1 = df1.dropna()\n", "df1.head()" ], "id": "6de55c38" }, { "cell_type": "markdown", "metadata": { "id": "4c288f90" }, "source": [ "#### `.tolist()` converts NumPy arrays into Python lists." ], "id": "4c288f90" }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dd22b72e", "outputId": "90c4cb5a-7de8-4a19-8e7e-b227eac1b701" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "['Unnamed: 0', 'count', 'hate_speech', 'offensive_language', 'neither', 'class', 'tweet']\n" ] } ], "source": [ "print(df1.columns.tolist())\n" ], "id": "dd22b72e" }, { "cell_type": "markdown", "metadata": { "id": "f584d46b" }, "source": [ "- The `.map()` function applies a specified function to an iterable and returns the result.\n", "- We used the `.map` function to assign 0, 1, and 2 to \"Hate Speech Detected\", \"Offensive language detected\", and \"No hate and - - offensive speech\"" ], "id": "f584d46b" }, { "cell_type": "markdown", "source": [ "### Preprocess the Labels" ], "metadata": { "id": "MSIgr88pMz8x" }, "id": "MSIgr88pMz8x" }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "117eadd5" }, "outputs": [], "source": [ "# Translate the numeric class codes into their descriptive names.\n", "df1['labels'] = df1['class'].map({0:\"Hate Speech Detected\", 1:\"Offensive language detected\", 2:\"No hate and offensive speech\"})\n", "\n", "\n", "def unify_labels(row):\n", "    \"\"\"Collapse a row's descriptive label into one of two classes.\n", "\n", "    Returns 'Offensive or Hate Speech' when row['labels'] is\n", "    'Hate Speech Detected' or 'Offensive language detected',\n", "    and 'Not Hate' for anything else.\n", "    \"\"\"\n", "    merged = ('Hate Speech Detected', 'Offensive language detected')\n", "    return 'Offensive or Hate Speech' if row['labels'] in merged else 'Not Hate'\n", "\n", "\n", "# Overwrite the three-way labels with the merged two-way labels, row by row.\n", "df1['labels'] = df1.apply(unify_labels, axis=1)" ], "id": "117eadd5" }, { "cell_type": "code", 
"execution_count": 29, "metadata": { "id": "8fdf617f", "colab": { "base_uri": "https://localhost:8080/", "height": 486 }, "outputId": "e79d6a0a-e650-46a2-ad66-ee93d9d66f9a" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Unnamed: 0 count hate_speech offensive_language neither class \\\n", "0 0 3 0 0 3 2 \n", "1 1 3 0 3 0 1 \n", "2 2 3 0 3 0 1 \n", "3 3 3 0 2 1 1 \n", "4 4 6 0 6 0 1 \n", "\n", " tweet labels \n", "0 !!! RT @mayasolovely: As a woman you shouldn't... Not Hate \n", "1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... Offensive or Hate Speech \n", "2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... Offensive or Hate Speech \n", "3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... Offensive or Hate Speech \n", "4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... Offensive or Hate Speech " ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0counthate_speechoffensive_languageneitherclasstweetlabels
0030032!!! RT @mayasolovely: As a woman you shouldn't...Not Hate
1130301!!!!! RT @mleew17: boy dats cold...tyga dwn ba...Offensive or Hate Speech
2230301!!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...Offensive or Hate Speech
3330211!!!!!!!!! RT @C_G_Anderson: @viva_based she lo...Offensive or Hate Speech
4460601!!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...Offensive or Hate Speech
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df1", "summary": "{\n \"name\": \"df1\",\n \"rows\": 24783,\n \"fields\": [\n {\n \"column\": \"Unnamed: 0\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 7299,\n \"min\": 0,\n \"max\": 25296,\n \"num_unique_values\": 24783,\n \"samples\": [\n 2326,\n 16283,\n 19362\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"count\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 3,\n \"max\": 9,\n \"num_unique_values\": 5,\n \"samples\": [\n 6,\n 7,\n 9\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"hate_speech\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 7,\n \"num_unique_values\": 8,\n \"samples\": [\n 1,\n 6,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"offensive_language\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 9,\n \"num_unique_values\": 10,\n \"samples\": [\n 8,\n 3,\n 7\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"neither\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 9,\n \"num_unique_values\": 10,\n \"samples\": [\n 8,\n 0,\n 4\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"class\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 2,\n 1,\n 0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"tweet\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 24783,\n \"samples\": [\n \"934 8616\\ni got a missed call from yo bitch\",\n \"RT @KINGTUNCHI_: Fucking with a bad bitch you gone need some money lil homie!\",\n \"RT @eanahS__: @1inkkofrosess lol my credit ain't no where near good 
, but I know the right man for the job .. that ho nice though!\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"labels\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Offensive or Hate Speech\",\n \"Not Hate\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 29 } ], "source": [ "# Print a summary of the merged label column. The method has to be called:\n", "# a bare `.info` only references it and produces no output.\n", "df1['labels'].info()\n", "df1.head()" ], "id": "8fdf617f" }, { "cell_type": "markdown", "source": [ "### Import the second dataset" ], "metadata": { "id": "9DgbrPGdSk5O" }, "id": "9DgbrPGdSk5O" }, { "cell_type": "markdown", "metadata": { "id": "a420ba1c" }, "source": [ "### Formated to two tables of tweets and labels" ], "id": "a420ba1c" }, { "cell_type": "code", "execution_count": 30, "metadata": { "id": "5db5746b", "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "outputId": "1f2e2ba7-0288-4920-f852-3a4e4be5b3e4" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " tweet labels\n", "0 !!! RT @mayasolovely: As a woman you shouldn't... Not Hate\n", "1 !!!!! RT @mleew17: boy dats cold...tyga dwn ba... Offensive or Hate Speech\n", "2 !!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby... Offensive or Hate Speech\n", "3 !!!!!!!!! RT @C_G_Anderson: @viva_based she lo... Offensive or Hate Speech\n", "4 !!!!!!!!!!!!! RT @ShenikaRoberts: The shit you... Offensive or Hate Speech" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tweetlabels
0!!! RT @mayasolovely: As a woman you shouldn't...Not Hate
1!!!!! RT @mleew17: boy dats cold...tyga dwn ba...Offensive or Hate Speech
2!!!!!!! RT @UrKindOfBrand Dawg!!!! RT @80sbaby...Offensive or Hate Speech
3!!!!!!!!! RT @C_G_Anderson: @viva_based she lo...Offensive or Hate Speech
4!!!!!!!!!!!!! RT @ShenikaRoberts: The shit you...Offensive or Hate Speech
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df1", "summary": "{\n \"name\": \"df1\",\n \"rows\": 24783,\n \"fields\": [\n {\n \"column\": \"tweet\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 24783,\n \"samples\": [\n \"934 8616\\ni got a missed call from yo bitch\",\n \"RT @KINGTUNCHI_: Fucking with a bad bitch you gone need some money lil homie!\",\n \"RT @eanahS__: @1inkkofrosess lol my credit ain't no where near good , but I know the right man for the job .. that ho nice though!\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"labels\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Offensive or Hate Speech\",\n \"Not Hate\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 30 } ], "source": [ "# Keep only the text and its label (the selection was previously written twice);\n", "# fillna(0) replaces any missing value with 0.\n", "df1 = df1[['tweet', 'labels']].fillna(0)\n", "df1.head()" ], "id": "5db5746b" }, { "cell_type": "code", "source": [], "metadata": { "id": "O7ibjvL6LdKO" }, "id": "O7ibjvL6LdKO", "execution_count": 30, "outputs": [] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "604ec08e" }, "outputs": [], "source": [ "def clean(text):\n", "    \"\"\"Normalise one tweet and return the cleaned string.\n", "\n", "    The text is lower-cased; bracketed spans, URLs, angle-bracket tags,\n", "    punctuation, newlines and tokens containing a digit are removed; and\n", "    words found in the notebook-level `stopword` set are dropped.\n", "    \"\"\"\n", "    text = str(text).lower()\n", "    # Regex patterns are raw strings so their backslash escapes reach `re`\n", "    # untouched instead of being parsed as invalid Python string escapes.\n", "    text = re.sub(r'\\[.*?\\]', '', text)\n", "    text = re.sub(r'https?://\\S+|www\\.\\S+', '', text)\n", "    text = re.sub(r'<.*?>+', '', text)\n", "    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)\n", "    text = re.sub('\\n', '', text)\n", "    text = re.sub(r'\\w*\\d\\w*', \"\", text)\n", "    kept_words = [word for word in text.split() if word not in stopword]\n", "    return \" \".join(kept_words)\n", "\n", "\n", "# Apply the cleaning function to the 'tweet' column\n", "df1['tweet'] = df1['tweet'].apply(clean)\n" ], "id": "604ec08e" }, { "cell_type": "code", "source": [ "df1.head()" ], "metadata": { "colab": { "base_uri": 
"https://localhost:8080/", "height": 206 }, "id": "XuArptokP5u4", "outputId": "7f650ae3-4686-4585-8c70-e6bcdc0afc0e" }, "id": "XuArptokP5u4", "execution_count": 32, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " tweet labels\n", "0 rt mayasolovely woman shouldnt complain cleani... Not Hate\n", "1 rt boy dats coldtyga dwn bad cuffin dat hoe place Offensive or Hate Speech\n", "2 rt urkindofbrand dawg rt ever fuck bitch start... Offensive or Hate Speech\n", "3 rt cganderson vivabased look like tranny Offensive or Hate Speech\n", "4 rt shenikaroberts shit hear might true might f... Offensive or Hate Speech" ], "text/html": [ "\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
tweetlabels
0rt mayasolovely woman shouldnt complain cleani...Not Hate
1rt boy dats coldtyga dwn bad cuffin dat hoe placeOffensive or Hate Speech
2rt urkindofbrand dawg rt ever fuck bitch start...Offensive or Hate Speech
3rt cganderson vivabased look like trannyOffensive or Hate Speech
4rt shenikaroberts shit hear might true might f...Offensive or Hate Speech
\n", "
\n", "
\n", "\n", "
\n", " \n", "\n", " \n", "\n", " \n", "
\n", "\n", "\n", "
\n", " \n", "\n", "\n", "\n", " \n", "
\n", "\n", "
\n", "
\n" ], "application/vnd.google.colaboratory.intrinsic+json": { "type": "dataframe", "variable_name": "df1", "summary": "{\n \"name\": \"df1\",\n \"rows\": 24783,\n \"fields\": [\n {\n \"column\": \"tweet\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 24506,\n \"samples\": [\n \"didnt even get see baby today smh moms fault selfish bitch\",\n \"hoes got money mall ballin bitch buy something\",\n \"rt johnnyfootbali yeah kaepernick might biceps like greek god dude looks like conceived proboscis monkey\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"labels\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"Offensive or Hate Speech\",\n \"Not Hate\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" } }, "metadata": {}, "execution_count": 32 } ] }, { "cell_type": "code", "source": [ "# Use a pipeline as a high-level helper\n", "from transformers import pipeline\n", "\n", "pipe = pipeline(\"text-classification\", model=\"facebook/roberta-hate-speech-dynabench-r4-target\")" ], "metadata": { "id": "1ey-IvgELgAJ" }, "id": "1ey-IvgELgAJ", "execution_count": 33, "outputs": [] }, { "cell_type": "code", "source": [ "# Install necessary libraries (%pip installs into the running kernel's environment)\n", "%pip install transformers\n", "\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from transformers import TextClassificationPipeline, AutoTokenizer, AutoModelForSequenceClassification\n", "from datasets import Dataset\n", "from transformers import Trainer, TrainingArguments\n", "\n", "\n", "# Load your CSV file into a pandas DataFrame\n", "data = pd.read_csv(\"twitter_data.csv\")\n", "\n", "# Add the binary 'labels' column: 1 = hate speech, 0 = everything else.\n", "# 'class' holds integer codes (0 = hate speech, 1 = offensive language,\n", "# 2 = neither), so it is compared with the code 0. The earlier comparison\n", "# with the string 'hate_speech' never matched and labelled every tweet 0.\n", "data['labels'] = (data['class'] == 0).astype(int)\n", "\n", "# Split data into train and validation sets\n", 
"train_texts, val_texts, train_labels, val_labels = train_test_split(data[\"tweet\"], data[\"labels\"], test_size=0.2, random_state=42)\n", "\n", "# Load pre-trained tokenizer and model\n", "tokenizer = AutoTokenizer.from_pretrained(\"facebook/roberta-hate-speech-dynabench-r4-target\")\n", "model = AutoModelForSequenceClassification.from_pretrained(\"facebook/roberta-hate-speech-dynabench-r4-target\")\n", "\n", "# Tokenize the input texts\n", "train_encodings = tokenizer(list(train_texts), truncation=True, padding=True)\n", "val_encodings = tokenizer(list(val_texts), truncation=True, padding=True)\n", "\n", "# Convert the label Series to plain Python lists\n", "train_labels = list(train_labels)\n", "val_labels = list(val_labels)\n", "\n", "# Create datasets\n", "train_dataset = Dataset.from_dict({\"input_ids\": train_encodings[\"input_ids\"],\n", "                                   \"attention_mask\": train_encodings[\"attention_mask\"],\n", "                                   \"labels\": train_labels})\n", "\n", "val_dataset = Dataset.from_dict({\"input_ids\": val_encodings[\"input_ids\"],\n", "                                 \"attention_mask\": val_encodings[\"attention_mask\"],\n", "                                 \"labels\": val_labels})\n", "\n", "# Switch the model to training mode (returns the model, hence the repr below).\n", "# NOTE: this does not update any weights by itself; the actual optimisation\n", "# step (e.g. with the imported Trainer on train_dataset / val_dataset) is still TODO.\n", "model.train(True)\n", "\n", "# have to test model" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mq5qeFd1OemP", "outputId": "43815cba-e57a-42da-8f1a-9058032ca5da" }, "id": "mq5qeFd1OemP", "execution_count": 40, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.38.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.20.3)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) 
(24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.12.25)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.2)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.10.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.2.2)\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "RobertaForSequenceClassification(\n", " (roberta): RobertaModel(\n", " (embeddings): RobertaEmbeddings(\n", " (word_embeddings): Embedding(50265, 768, padding_idx=1)\n", " (position_embeddings): Embedding(514, 768, padding_idx=1)\n", " (token_type_embeddings): Embedding(1, 768)\n", " (LayerNorm): 
LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (encoder): RobertaEncoder(\n", " (layer): ModuleList(\n", " (0-11): 12 x RobertaLayer(\n", " (attention): RobertaAttention(\n", " (self): RobertaSelfAttention(\n", " (query): Linear(in_features=768, out_features=768, bias=True)\n", " (key): Linear(in_features=768, out_features=768, bias=True)\n", " (value): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (output): RobertaSelfOutput(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " (intermediate): RobertaIntermediate(\n", " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", " (intermediate_act_fn): GELUActivation()\n", " )\n", " (output): RobertaOutput(\n", " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", " (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " )\n", " )\n", " (classifier): RobertaClassificationHead(\n", " (dense): Linear(in_features=768, out_features=768, bias=True)\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (out_proj): Linear(in_features=768, out_features=2, bias=True)\n", " )\n", ")" ] }, "metadata": {}, "execution_count": 40 } ] }, { "cell_type": "code", "source": [ "# Inference should run in evaluation mode: an earlier cell switched the model\n", "# to training mode, which would leave dropout active while predicting.\n", "model.eval()\n", "text_classifier = pipeline(\"text-classification\", model=model, tokenizer=tokenizer)\n", "\n", "\n", "def test_model():\n", "    \"\"\"Interactively classify statements typed by the user.\n", "\n", "    Prompts until the user types 'exit' (case-insensitive) and prints the\n", "    pipeline's prediction (label and score) for every other input.\n", "    \"\"\"\n", "    while True:\n", "        statement = input(\"Enter a statement to test (or type 'exit' to quit): \")\n", "        if statement.lower() == 'exit':\n", "            break\n", "        predictions = text_classifier(statement)\n", "        print(predictions)" ], "metadata": { "id": "l6MAu_NnUdxG" }, "id": "l6MAu_NnUdxG", "execution_count": 43, "outputs": [] }, { 
"cell_type": "code", "source": [ "test_model()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YRnBVSW2UjyW", "outputId": "f723c664-95eb-406e-9178-46bee6a7f5af" }, "id": "YRnBVSW2UjyW", "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Enter a statement to test (or type 'exit' to quit): kill\n", "[{'label': 'nothate', 'score': 0.9961766004562378}]\n" ] } ] }, { "cell_type": "code", "source": [], "metadata": { "id": "jCFfdS6CUNN7" }, "id": "jCFfdS6CUNN7", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [], "metadata": { "id": "XOIePsq1N5Fo" }, "id": "XOIePsq1N5Fo" }, { "cell_type": "code", "source": [ "from sklearn.metrics import accuracy_score, classification_report\n", "\n", "# Score the RoBERTa classifier on the held-out validation split.\n", "# This cell used to reference clf / X_test / y_test, none of which is defined\n", "# anywhere in this notebook, so it failed with a NameError.\n", "y_test = list(val_labels)\n", "val_predictions = text_classifier(list(val_texts), truncation=True)\n", "# The pipeline reports 'nothate' for the negative class; any other label counts as 1.\n", "y_pred = [0 if prediction['label'] == 'nothate' else 1 for prediction in val_predictions]\n", "print(f\"Accuracy: {accuracy_score(y_test, y_pred)}\")\n", "print(classification_report(y_test, y_pred))" ], "metadata": { "id": "sdFRCXtGY3yI" }, "id": "sdFRCXtGY3yI", "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fb36a279" }, "outputs": [], "source": [], "id": "fb36a279" } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }