sstyli09 committed on
Commit
ff86447
1 Parent(s): 2668044

Upload 7 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ emails.csv filter=lfs diff=lfs merge=lfs -text
all_email_embeddings.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b51d8a15d0a7aec5b8342802d846f2356355f4d2e49d503c0c334c3b98ec998b
+ size 317890688
app.py ADDED
@@ -0,0 +1,2038 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "60db3983-f124-48d8-bddb-1a96f6c60b86",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Enron Email Dataset Analysis\n",
9
+ "\n",
10
+ "## Introduction\n",
11
+ "The Enron scandal was a significant event that highlighted the need for transparency and accountability in corporate America. Utilizing the Enron Email Dataset, this notebook seeks to aid investigators by analyzing communications among Enron's senior management.\n",
12
+ "\n",
13
+ "## Dataset Acquisition and Preparation\n",
14
+ "\n",
15
+ "### Installing Necessary Libraries\n",
16
+ "To work with the Enron dataset, certain Python libraries are necessary. Here we ensure `kaggle` is installed for dataset downloading.\n"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": 62,
22
+ "id": "be2d6edc-02f5-48b0-9a47-094cbc58c837",
23
+ "metadata": {
24
+ "collapsed": true,
25
+ "jupyter": {
26
+ "outputs_hidden": true
27
+ }
28
+ },
29
+ "outputs": [
30
+ {
31
+ "name": "stdout",
32
+ "output_type": "stream",
33
+ "text": [
34
+ "Requirement already satisfied: kaggle in c:\\users\\stylianos\\myenv\\lib\\site-packages (1.6.12)\n",
35
+ "Requirement already satisfied: six>=1.10 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (1.16.0)\n",
36
+ "Requirement already satisfied: certifi>=2023.7.22 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (2024.2.2)\n",
37
+ "Requirement already satisfied: python-dateutil in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (2.9.0.post0)\n",
38
+ "Requirement already satisfied: requests in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (2.31.0)\n",
39
+ "Requirement already satisfied: tqdm in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (4.66.2)\n",
40
+ "Requirement already satisfied: python-slugify in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (8.0.4)\n",
41
+ "Requirement already satisfied: urllib3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (2.2.1)\n",
42
+ "Requirement already satisfied: bleach in c:\\users\\stylianos\\myenv\\lib\\site-packages (from kaggle) (6.1.0)\n",
43
+ "Requirement already satisfied: webencodings in c:\\users\\stylianos\\myenv\\lib\\site-packages (from bleach->kaggle) (0.5.1)\n",
44
+ "Requirement already satisfied: text-unidecode>=1.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from python-slugify->kaggle) (1.3)\n",
45
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->kaggle) (3.3.2)\n",
46
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->kaggle) (3.7)\n",
47
+ "Requirement already satisfied: colorama in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tqdm->kaggle) (0.4.6)\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "! pip install kaggle"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "id": "414af69c-f2b6-4c8e-92ea-7af62e377a6a",
58
+ "metadata": {},
59
+ "source": [
60
+ "## Configuring Kaggle API\n",
61
+ "To access the Kaggle dataset, we set up the Kaggle API credentials."
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 63,
67
+ "id": "ec52b055-4a88-463f-a10e-967391e2b6bb",
68
+ "metadata": {},
69
+ "outputs": [
70
+ {
71
+ "data": {
72
+ "text/plain": [
73
+ "'C:\\\\Users\\\\Stylianos\\\\.kaggle\\\\kaggle.json'"
74
+ ]
75
+ },
76
+ "execution_count": 63,
77
+ "metadata": {},
78
+ "output_type": "execute_result"
79
+ }
80
+ ],
81
+ "source": [
82
+ "import os\n",
83
+ "from shutil import copyfile\n",
84
+ "\n",
85
+ "# Create a .kaggle directory in your home folder\n",
86
+ "kaggle_dir = os.path.join(os.path.expanduser('~'), '.kaggle')\n",
87
+ "os.makedirs(kaggle_dir, exist_ok=True)\n",
88
+ "\n",
89
+ "# Copy the kaggle.json to the .kaggle directory\n",
90
+ "copyfile('kaggle.json', os.path.join(kaggle_dir, 'kaggle.json'))\n",
91
+ "\n",
92
+ "# Only for Unix-based systems: Make sure permissions for the file are set properly\n",
93
+ "# For Windows, this step can be skipped, as file permissions work differently\n"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "markdown",
98
+ "id": "65be3ff6-bcfd-4418-8060-0e3f43486a52",
99
+ "metadata": {},
100
+ "source": [
101
+ "## Downloading the Dataset\n",
102
+ "With the credentials in place, we download the Enron dataset directly from Kaggle."
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 64,
108
+ "id": "f8bf248d-8026-4b0f-a29c-c0f1795cdd64",
109
+ "metadata": {},
110
+ "outputs": [
111
+ {
112
+ "name": "stdout",
113
+ "output_type": "stream",
114
+ "text": [
115
+ "Dataset URL: https://www.kaggle.com/datasets/wcukierski/enron-email-dataset\n",
116
+ "License(s): copyright-authors\n",
117
+ "enron-email-dataset.zip: Skipping, found more recently modified local copy (use --force to force download)\n"
118
+ ]
119
+ }
120
+ ],
121
+ "source": [
122
+ "! kaggle datasets download -d wcukierski/enron-email-dataset"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "bc0975e5-df3f-499a-881a-d2d35a7ca0ab",
128
+ "metadata": {},
129
+ "source": [
130
+ "## Unzipping the Dataset\n",
131
+ "We then extract the dataset from the downloaded zip file for further analysis."
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": 65,
137
+ "id": "49985613-825d-4f33-8f71-f1368f2cf470",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "import zipfile\n",
142
+ "\n",
143
+ "# The path to the zip file\n",
144
+ "zip_path = 'enron-email-dataset.zip' \n",
145
+ "\n",
146
+ "# Unzip the dataset\n",
147
+ "with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
148
+ " zip_ref.extractall() \n"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "id": "0559ce79-4610-4cbe-8246-2236b8e52f8b",
154
+ "metadata": {},
155
+ "source": [
156
+ "## Loading the Data into a DataFrame\n",
157
+ "For ease of manipulation and analysis, we load the dataset into a pandas DataFrame."
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 66,
163
+ "id": "b6838795-3d1c-4074-ad94-8298837f2b3d",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "import pandas as pd\n",
168
+ "\n",
169
+ "# Read the file\n",
170
+ "df = pd.read_csv('emails.csv')\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "id": "1b38cc03-20b8-4b97-9b0a-9b96420831d1",
176
+ "metadata": {},
177
+ "source": [
178
+ "## Initial Data Exploration\n",
179
+ "A preliminary examination of the dataset provides insight into the structure of the data we will be working with."
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": 67,
185
+ "id": "99d41b77-be4a-4ea6-9df6-351d6a213021",
186
+ "metadata": {},
187
+ "outputs": [
188
+ {
189
+ "data": {
190
+ "text/html": [
191
+ "<div>\n",
192
+ "<style scoped>\n",
193
+ " .dataframe tbody tr th:only-of-type {\n",
194
+ " vertical-align: middle;\n",
195
+ " }\n",
196
+ "\n",
197
+ " .dataframe tbody tr th {\n",
198
+ " vertical-align: top;\n",
199
+ " }\n",
200
+ "\n",
201
+ " .dataframe thead th {\n",
202
+ " text-align: right;\n",
203
+ " }\n",
204
+ "</style>\n",
205
+ "<table border=\"1\" class=\"dataframe\">\n",
206
+ " <thead>\n",
207
+ " <tr style=\"text-align: right;\">\n",
208
+ " <th></th>\n",
209
+ " <th>file</th>\n",
210
+ " <th>message</th>\n",
211
+ " </tr>\n",
212
+ " </thead>\n",
213
+ " <tbody>\n",
214
+ " <tr>\n",
215
+ " <th>0</th>\n",
216
+ " <td>allen-p/_sent_mail/1.</td>\n",
217
+ " <td>Message-ID: &lt;18782981.1075855378110.JavaMail.e...</td>\n",
218
+ " </tr>\n",
219
+ " <tr>\n",
220
+ " <th>1</th>\n",
221
+ " <td>allen-p/_sent_mail/10.</td>\n",
222
+ " <td>Message-ID: &lt;15464986.1075855378456.JavaMail.e...</td>\n",
223
+ " </tr>\n",
224
+ " <tr>\n",
225
+ " <th>2</th>\n",
226
+ " <td>allen-p/_sent_mail/100.</td>\n",
227
+ " <td>Message-ID: &lt;24216240.1075855687451.JavaMail.e...</td>\n",
228
+ " </tr>\n",
229
+ " <tr>\n",
230
+ " <th>3</th>\n",
231
+ " <td>allen-p/_sent_mail/1000.</td>\n",
232
+ " <td>Message-ID: &lt;13505866.1075863688222.JavaMail.e...</td>\n",
233
+ " </tr>\n",
234
+ " <tr>\n",
235
+ " <th>4</th>\n",
236
+ " <td>allen-p/_sent_mail/1001.</td>\n",
237
+ " <td>Message-ID: &lt;30922949.1075863688243.JavaMail.e...</td>\n",
238
+ " </tr>\n",
239
+ " </tbody>\n",
240
+ "</table>\n",
241
+ "</div>"
242
+ ],
243
+ "text/plain": [
244
+ " file message\n",
245
+ "0 allen-p/_sent_mail/1. Message-ID: <18782981.1075855378110.JavaMail.e...\n",
246
+ "1 allen-p/_sent_mail/10. Message-ID: <15464986.1075855378456.JavaMail.e...\n",
247
+ "2 allen-p/_sent_mail/100. Message-ID: <24216240.1075855687451.JavaMail.e...\n",
248
+ "3 allen-p/_sent_mail/1000. Message-ID: <13505866.1075863688222.JavaMail.e...\n",
249
+ "4 allen-p/_sent_mail/1001. Message-ID: <30922949.1075863688243.JavaMail.e..."
250
+ ]
251
+ },
252
+ "execution_count": 67,
253
+ "metadata": {},
254
+ "output_type": "execute_result"
255
+ }
256
+ ],
257
+ "source": [
258
+ "df.head()"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 68,
264
+ "id": "e029e381-8842-4292-a264-1209a79feea9",
265
+ "metadata": {},
266
+ "outputs": [
267
+ {
268
+ "data": {
269
+ "text/plain": [
270
+ "file object\n",
271
+ "message object\n",
272
+ "dtype: object"
273
+ ]
274
+ },
275
+ "execution_count": 68,
276
+ "metadata": {},
277
+ "output_type": "execute_result"
278
+ }
279
+ ],
280
+ "source": [
281
+ "df.dtypes"
282
+ ]
283
+ },
284
+ {
285
+ "cell_type": "code",
286
+ "execution_count": 69,
287
+ "id": "27187b0a-be3b-4971-afc5-e106407daf88",
288
+ "metadata": {},
289
+ "outputs": [
290
+ {
291
+ "data": {
292
+ "text/plain": [
293
+ "file 0\n",
294
+ "message 0\n",
295
+ "dtype: int64"
296
+ ]
297
+ },
298
+ "execution_count": 69,
299
+ "metadata": {},
300
+ "output_type": "execute_result"
301
+ }
302
+ ],
303
+ "source": [
304
+ "df.isnull().sum()"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": 70,
310
+ "id": "78ac00b4-7af6-4cfc-91c0-90b6f5ea485d",
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "# Sample 20% of the data\n",
315
+ "df_sampled = df.sample(frac=0.2, random_state=1) # 'random_state' for reproducibility\n",
316
+ "df_sampled.reset_index(drop=True, inplace=True)"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "markdown",
321
+ "id": "23d4d91d-b2d1-4af4-a16d-1c81150b77eb",
322
+ "metadata": {},
323
+ "source": [
324
+ "## Parsing the Emails\n",
325
+ "\n",
326
+ "### Importing Required Modules\n",
327
+ "We import the `email` module necessary for parsing email content and the `tqdm` module for progress indication during the processing.\n"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 71,
333
+ "id": "1a6f445f-6d38-4997-a217-19e337e96f94",
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "import email"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 72,
343
+ "id": "25d57e39-8fc7-44ca-9d6b-16bd7476535e",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "from tqdm.notebook import tqdm"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "markdown",
352
+ "id": "a65ce18d-ecf2-4481-899c-40a333d4ee17",
353
+ "metadata": {},
354
+ "source": [
355
+ "## Email Parsing\n",
356
+ "The dataset consists of raw email data, which requires parsing to extract useful information such as subject lines, senders, and message bodies."
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": 73,
362
+ "id": "ce454495-3b3f-4f45-bbee-e05884c4d2d3",
363
+ "metadata": {},
364
+ "outputs": [],
365
+ "source": [
366
+ "# create list of email objects\n",
367
+ "emails = list(map(email.parser.Parser().parsestr,df_sampled['message']))\n",
368
+ "\n",
369
+ "# extract headings such as subject, from, to etc..\n",
370
+ "headings = emails[0].keys()\n",
371
+ "\n",
372
+ "# Goes through each email and grabs info for each key\n",
373
+ "# doc['From'] grabs who sent email in all emails\n",
374
+ "for key in headings:\n",
375
+ " df_sampled[key] = [doc[key] for doc in emails]"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "id": "8ef441dd-d9e7-454c-ad5f-c285cebc0868",
381
+ "metadata": {},
382
+ "source": [
383
+ "## Preview of Parsed Emails\n",
384
+ "After parsing and extracting the information, we can now preview the data to ensure it contains the desired structured fields."
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 74,
390
+ "id": "cad7d210-68f2-4289-99d3-4f558c503e3b",
391
+ "metadata": {},
392
+ "outputs": [
393
+ {
394
+ "data": {
395
+ "text/html": [
396
+ "<div>\n",
397
+ "<style scoped>\n",
398
+ " .dataframe tbody tr th:only-of-type {\n",
399
+ " vertical-align: middle;\n",
400
+ " }\n",
401
+ "\n",
402
+ " .dataframe tbody tr th {\n",
403
+ " vertical-align: top;\n",
404
+ " }\n",
405
+ "\n",
406
+ " .dataframe thead th {\n",
407
+ " text-align: right;\n",
408
+ " }\n",
409
+ "</style>\n",
410
+ "<table border=\"1\" class=\"dataframe\">\n",
411
+ " <thead>\n",
412
+ " <tr style=\"text-align: right;\">\n",
413
+ " <th></th>\n",
414
+ " <th>file</th>\n",
415
+ " <th>message</th>\n",
416
+ " <th>Message-ID</th>\n",
417
+ " <th>Date</th>\n",
418
+ " <th>From</th>\n",
419
+ " <th>To</th>\n",
420
+ " <th>Subject</th>\n",
421
+ " <th>Mime-Version</th>\n",
422
+ " <th>Content-Type</th>\n",
423
+ " <th>Content-Transfer-Encoding</th>\n",
424
+ " <th>X-From</th>\n",
425
+ " <th>X-To</th>\n",
426
+ " <th>X-cc</th>\n",
427
+ " <th>X-bcc</th>\n",
428
+ " <th>X-Folder</th>\n",
429
+ " <th>X-Origin</th>\n",
430
+ " <th>X-FileName</th>\n",
431
+ " </tr>\n",
432
+ " </thead>\n",
433
+ " <tbody>\n",
434
+ " <tr>\n",
435
+ " <th>0</th>\n",
436
+ " <td>jones-t/all_documents/634.</td>\n",
437
+ " <td>Message-ID: &lt;17820178.1075846925335.JavaMail.e...</td>\n",
438
+ " <td>&lt;17820178.1075846925335.JavaMail.evans@thyme&gt;</td>\n",
439
+ " <td>Tue, 4 Jan 2000 08:20:00 -0800 (PST)</td>\n",
440
+ " <td>tana.jones@enron.com</td>\n",
441
+ " <td>alicia.goodrow@enron.com</td>\n",
442
+ " <td>Re: Dinner</td>\n",
443
+ " <td>1.0</td>\n",
444
+ " <td>text/plain; charset=us-ascii</td>\n",
445
+ " <td>7bit</td>\n",
446
+ " <td>Tana Jones</td>\n",
447
+ " <td>Alicia Goodrow</td>\n",
448
+ " <td></td>\n",
449
+ " <td></td>\n",
450
+ " <td>\\Tanya_Jones_Dec2000\\Notes Folders\\All documents</td>\n",
451
+ " <td>JONES-T</td>\n",
452
+ " <td>tjones.nsf</td>\n",
453
+ " </tr>\n",
454
+ " <tr>\n",
455
+ " <th>1</th>\n",
456
+ " <td>mann-k/all_documents/5690.</td>\n",
457
+ " <td>Message-ID: &lt;29110382.1075845717882.JavaMail.e...</td>\n",
458
+ " <td>&lt;29110382.1075845717882.JavaMail.evans@thyme&gt;</td>\n",
459
+ " <td>Tue, 15 May 2001 11:03:00 -0700 (PDT)</td>\n",
460
+ " <td>kay.mann@enron.com</td>\n",
461
+ " <td>sheila.tweed@enron.com</td>\n",
462
+ " <td>Re: Override letter</td>\n",
463
+ " <td>1.0</td>\n",
464
+ " <td>text/plain; charset=us-ascii</td>\n",
465
+ " <td>7bit</td>\n",
466
+ " <td>Kay Mann</td>\n",
467
+ " <td>Sheila Tweed</td>\n",
468
+ " <td></td>\n",
469
+ " <td></td>\n",
470
+ " <td>\\Kay_Mann_June2001_1\\Notes Folders\\All documents</td>\n",
471
+ " <td>MANN-K</td>\n",
472
+ " <td>kmann.nsf</td>\n",
473
+ " </tr>\n",
474
+ " <tr>\n",
475
+ " <th>2</th>\n",
476
+ " <td>dasovich-j/sent/423.</td>\n",
477
+ " <td>Message-ID: &lt;6812040.1075843194135.JavaMail.ev...</td>\n",
478
+ " <td>&lt;6812040.1075843194135.JavaMail.evans@thyme&gt;</td>\n",
479
+ " <td>Thu, 28 Sep 2000 08:59:00 -0700 (PDT)</td>\n",
480
+ " <td>jeff.dasovich@enron.com</td>\n",
481
+ " <td>christine.piesco@oracle.com</td>\n",
482
+ " <td>Teams</td>\n",
483
+ " <td>1.0</td>\n",
484
+ " <td>text/plain; charset=us-ascii</td>\n",
485
+ " <td>7bit</td>\n",
486
+ " <td>Jeff Dasovich</td>\n",
487
+ " <td>Christine.Piesco@oracle.com</td>\n",
488
+ " <td></td>\n",
489
+ " <td></td>\n",
490
+ " <td>\\Jeff_Dasovich_Dec2000\\Notes Folders\\Sent</td>\n",
491
+ " <td>DASOVICH-J</td>\n",
492
+ " <td>jdasovic.nsf</td>\n",
493
+ " </tr>\n",
494
+ " <tr>\n",
495
+ " <th>3</th>\n",
496
+ " <td>kaminski-v/var/63.</td>\n",
497
+ " <td>Message-ID: &lt;21547648.1075856642126.JavaMail.e...</td>\n",
498
+ " <td>&lt;21547648.1075856642126.JavaMail.evans@thyme&gt;</td>\n",
499
+ " <td>Mon, 9 Oct 2000 01:23:00 -0700 (PDT)</td>\n",
500
+ " <td>tanya.tamarchenko@enron.com</td>\n",
501
+ " <td>vince.kaminski@enron.com</td>\n",
502
+ " <td>Re: FYI: UK Var issues</td>\n",
503
+ " <td>1.0</td>\n",
504
+ " <td>text/plain; charset=us-ascii</td>\n",
505
+ " <td>7bit</td>\n",
506
+ " <td>Tanya Tamarchenko</td>\n",
507
+ " <td>Vince J Kaminski</td>\n",
508
+ " <td></td>\n",
509
+ " <td></td>\n",
510
+ " <td>\\Vincent_Kaminski_Jun2001_5\\Notes Folders\\Var</td>\n",
511
+ " <td>Kaminski-V</td>\n",
512
+ " <td>vkamins.nsf</td>\n",
513
+ " </tr>\n",
514
+ " <tr>\n",
515
+ " <th>4</th>\n",
516
+ " <td>mann-k/_sent_mail/3208.</td>\n",
517
+ " <td>Message-ID: &lt;12684200.1075846107179.JavaMail.e...</td>\n",
518
+ " <td>&lt;12684200.1075846107179.JavaMail.evans@thyme&gt;</td>\n",
519
+ " <td>Fri, 13 Oct 2000 01:50:00 -0700 (PDT)</td>\n",
520
+ " <td>kay.mann@enron.com</td>\n",
521
+ " <td>lisa.bills@enron.com, ben.jacoby@enron.com</td>\n",
522
+ " <td>Change Order #5--Pleasanton Transformer</td>\n",
523
+ " <td>1.0</td>\n",
524
+ " <td>text/plain; charset=us-ascii</td>\n",
525
+ " <td>7bit</td>\n",
526
+ " <td>Kay Mann</td>\n",
527
+ " <td>Lisa Bills, Ben Jacoby</td>\n",
528
+ " <td></td>\n",
529
+ " <td></td>\n",
530
+ " <td>\\Kay_Mann_June2001_4\\Notes Folders\\'sent mail</td>\n",
531
+ " <td>MANN-K</td>\n",
532
+ " <td>kmann.nsf</td>\n",
533
+ " </tr>\n",
534
+ " </tbody>\n",
535
+ "</table>\n",
536
+ "</div>"
537
+ ],
538
+ "text/plain": [
539
+ " file \\\n",
540
+ "0 jones-t/all_documents/634. \n",
541
+ "1 mann-k/all_documents/5690. \n",
542
+ "2 dasovich-j/sent/423. \n",
543
+ "3 kaminski-v/var/63. \n",
544
+ "4 mann-k/_sent_mail/3208. \n",
545
+ "\n",
546
+ " message \\\n",
547
+ "0 Message-ID: <17820178.1075846925335.JavaMail.e... \n",
548
+ "1 Message-ID: <29110382.1075845717882.JavaMail.e... \n",
549
+ "2 Message-ID: <6812040.1075843194135.JavaMail.ev... \n",
550
+ "3 Message-ID: <21547648.1075856642126.JavaMail.e... \n",
551
+ "4 Message-ID: <12684200.1075846107179.JavaMail.e... \n",
552
+ "\n",
553
+ " Message-ID \\\n",
554
+ "0 <17820178.1075846925335.JavaMail.evans@thyme> \n",
555
+ "1 <29110382.1075845717882.JavaMail.evans@thyme> \n",
556
+ "2 <6812040.1075843194135.JavaMail.evans@thyme> \n",
557
+ "3 <21547648.1075856642126.JavaMail.evans@thyme> \n",
558
+ "4 <12684200.1075846107179.JavaMail.evans@thyme> \n",
559
+ "\n",
560
+ " Date From \\\n",
561
+ "0 Tue, 4 Jan 2000 08:20:00 -0800 (PST) tana.jones@enron.com \n",
562
+ "1 Tue, 15 May 2001 11:03:00 -0700 (PDT) kay.mann@enron.com \n",
563
+ "2 Thu, 28 Sep 2000 08:59:00 -0700 (PDT) jeff.dasovich@enron.com \n",
564
+ "3 Mon, 9 Oct 2000 01:23:00 -0700 (PDT) tanya.tamarchenko@enron.com \n",
565
+ "4 Fri, 13 Oct 2000 01:50:00 -0700 (PDT) kay.mann@enron.com \n",
566
+ "\n",
567
+ " To \\\n",
568
+ "0 alicia.goodrow@enron.com \n",
569
+ "1 sheila.tweed@enron.com \n",
570
+ "2 christine.piesco@oracle.com \n",
571
+ "3 vince.kaminski@enron.com \n",
572
+ "4 lisa.bills@enron.com, ben.jacoby@enron.com \n",
573
+ "\n",
574
+ " Subject Mime-Version \\\n",
575
+ "0 Re: Dinner 1.0 \n",
576
+ "1 Re: Override letter 1.0 \n",
577
+ "2 Teams 1.0 \n",
578
+ "3 Re: FYI: UK Var issues 1.0 \n",
579
+ "4 Change Order #5--Pleasanton Transformer 1.0 \n",
580
+ "\n",
581
+ " Content-Type Content-Transfer-Encoding X-From \\\n",
582
+ "0 text/plain; charset=us-ascii 7bit Tana Jones \n",
583
+ "1 text/plain; charset=us-ascii 7bit Kay Mann \n",
584
+ "2 text/plain; charset=us-ascii 7bit Jeff Dasovich \n",
585
+ "3 text/plain; charset=us-ascii 7bit Tanya Tamarchenko \n",
586
+ "4 text/plain; charset=us-ascii 7bit Kay Mann \n",
587
+ "\n",
588
+ " X-To X-cc X-bcc \\\n",
589
+ "0 Alicia Goodrow \n",
590
+ "1 Sheila Tweed \n",
591
+ "2 Christine.Piesco@oracle.com \n",
592
+ "3 Vince J Kaminski \n",
593
+ "4 Lisa Bills, Ben Jacoby \n",
594
+ "\n",
595
+ " X-Folder X-Origin X-FileName \n",
596
+ "0 \\Tanya_Jones_Dec2000\\Notes Folders\\All documents JONES-T tjones.nsf \n",
597
+ "1 \\Kay_Mann_June2001_1\\Notes Folders\\All documents MANN-K kmann.nsf \n",
598
+ "2 \\Jeff_Dasovich_Dec2000\\Notes Folders\\Sent DASOVICH-J jdasovic.nsf \n",
599
+ "3 \\Vincent_Kaminski_Jun2001_5\\Notes Folders\\Var Kaminski-V vkamins.nsf \n",
600
+ "4 \\Kay_Mann_June2001_4\\Notes Folders\\'sent mail MANN-K kmann.nsf "
601
+ ]
602
+ },
603
+ "execution_count": 74,
604
+ "metadata": {},
605
+ "output_type": "execute_result"
606
+ }
607
+ ],
608
+ "source": [
609
+ "df_sampled.head()"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "markdown",
614
+ "id": "fa5ee392-d072-44c8-8218-261afbdc4e2f",
615
+ "metadata": {},
616
+ "source": [
617
+ "## Text Embedding with BERT\n",
618
+ "\n",
619
+ "### Installing and Importing Libraries\n",
620
+ "For this section, we install and import the necessary libraries for working with the BERT model.\n"
621
+ ]
622
+ },
623
+ {
624
+ "cell_type": "code",
625
+ "execution_count": 75,
626
+ "id": "5c77db4c-1d89-4777-bfc1-e0fcaadf83de",
627
+ "metadata": {
628
+ "collapsed": true,
629
+ "jupyter": {
630
+ "outputs_hidden": true
631
+ }
632
+ },
633
+ "outputs": [
634
+ {
635
+ "name": "stdout",
636
+ "output_type": "stream",
637
+ "text": [
638
+ "Requirement already satisfied: transformers in c:\\users\\stylianos\\myenv\\lib\\site-packages (4.39.3)\n",
639
+ "Requirement already satisfied: datasets in c:\\users\\stylianos\\myenv\\lib\\site-packages (2.18.0)\n",
640
+ "Requirement already satisfied: filelock in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (3.13.4)\n",
641
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (0.22.2)\n",
642
+ "Requirement already satisfied: numpy>=1.17 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (1.26.4)\n",
643
+ "Requirement already satisfied: packaging>=20.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (24.0)\n",
644
+ "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (6.0.1)\n",
645
+ "Requirement already satisfied: regex!=2019.12.17 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (2023.12.25)\n",
646
+ "Requirement already satisfied: requests in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (2.31.0)\n",
647
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (0.15.2)\n",
648
+ "Requirement already satisfied: safetensors>=0.4.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (0.4.2)\n",
649
+ "Requirement already satisfied: tqdm>=4.27 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (4.66.2)\n",
650
+ "Requirement already satisfied: pyarrow>=12.0.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (15.0.2)\n",
651
+ "Requirement already satisfied: pyarrow-hotfix in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (0.6)\n",
652
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (0.3.8)\n",
653
+ "Requirement already satisfied: pandas in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (2.2.2)\n",
654
+ "Requirement already satisfied: xxhash in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (3.4.1)\n",
655
+ "Requirement already satisfied: multiprocess in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (0.70.16)\n",
656
+ "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets) (2024.2.0)\n",
657
+ "Requirement already satisfied: aiohttp in c:\\users\\stylianos\\myenv\\lib\\site-packages (from datasets) (3.9.4)\n",
658
+ "Requirement already satisfied: aiosignal>=1.1.2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n",
659
+ "Requirement already satisfied: attrs>=17.3.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n",
660
+ "Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n",
661
+ "Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n",
662
+ "Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n",
663
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.11.0)\n",
664
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (3.3.2)\n",
665
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (3.7)\n",
666
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (2.2.1)\n",
667
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (2024.2.2)\n",
668
+ "Requirement already satisfied: colorama in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tqdm>=4.27->transformers) (0.4.6)\n",
669
+ "Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pandas->datasets) (2.9.0.post0)\n",
670
+ "Requirement already satisfied: pytz>=2020.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
671
+ "Requirement already satisfied: tzdata>=2022.7 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pandas->datasets) (2024.1)\n",
672
+ "Requirement already satisfied: six>=1.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n"
673
+ ]
674
+ }
675
+ ],
676
+ "source": [
677
+ "!pip install transformers datasets"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": 76,
683
+ "id": "f4332ec3-ebab-47c5-b23f-5f1baf48e884",
684
+ "metadata": {},
685
+ "outputs": [],
686
+ "source": [
687
+ "from transformers import BertModel, BertTokenizer\n",
688
+ "import torch\n",
689
+ "from tqdm import tqdm"
690
+ ]
691
+ },
692
+ {
693
+ "cell_type": "markdown",
694
+ "id": "7fb756ee-a131-467c-be1b-32e6a33490d7",
695
+ "metadata": {},
696
+ "source": [
697
+ "## Loading the Pre-trained BERT Model and Tokenizer\n",
698
+ "We load a pre-trained BERT model and its corresponding tokenizer. This will allow us to convert the email text into a format that BERT can understand."
699
+ ]
700
+ },
701
+ {
702
+ "cell_type": "code",
703
+ "execution_count": 77,
704
+ "id": "c977cf03-f678-4cbb-b7e6-7db77fe6b32f",
705
+ "metadata": {
706
+ "collapsed": true,
707
+ "jupyter": {
708
+ "outputs_hidden": true
709
+ }
710
+ },
711
+ "outputs": [
712
+ {
713
+ "data": {
714
+ "text/plain": [
715
+ "BertModel(\n",
716
+ " (embeddings): BertEmbeddings(\n",
717
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
718
+ " (position_embeddings): Embedding(512, 768)\n",
719
+ " (token_type_embeddings): Embedding(2, 768)\n",
720
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
721
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
722
+ " )\n",
723
+ " (encoder): BertEncoder(\n",
724
+ " (layer): ModuleList(\n",
725
+ " (0-11): 12 x BertLayer(\n",
726
+ " (attention): BertAttention(\n",
727
+ " (self): BertSelfAttention(\n",
728
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
729
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
730
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
731
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
732
+ " )\n",
733
+ " (output): BertSelfOutput(\n",
734
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
735
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
736
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
737
+ " )\n",
738
+ " )\n",
739
+ " (intermediate): BertIntermediate(\n",
740
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
741
+ " (intermediate_act_fn): GELUActivation()\n",
742
+ " )\n",
743
+ " (output): BertOutput(\n",
744
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
745
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
746
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
747
+ " )\n",
748
+ " )\n",
749
+ " )\n",
750
+ " )\n",
751
+ " (pooler): BertPooler(\n",
752
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
753
+ " (activation): Tanh()\n",
754
+ " )\n",
755
+ ")"
756
+ ]
757
+ },
758
+ "execution_count": 77,
759
+ "metadata": {},
760
+ "output_type": "execute_result"
761
+ }
762
+ ],
763
+ "source": [
764
+ "# Load pre-trained model tokenizer\n",
765
+ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
766
+ "\n",
767
+ "# Load pre-trained model\n",
768
+ "model = BertModel.from_pretrained('bert-base-uncased')\n",
769
+ "model.eval() "
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "markdown",
774
+ "id": "d50278dd-5429-47db-a0fc-6bed78bbc743",
775
+ "metadata": {},
776
+ "source": [
777
+ "## Tokenization of Emails\n",
778
+ "We tokenize a small batch of emails to prepare them for embedding with the model"
779
+ ]
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "execution_count": 78,
784
+ "id": "436eff2d-6740-45f1-bbe3-db46ef28d1a5",
785
+ "metadata": {},
786
+ "outputs": [
787
+ {
788
+ "data": {
789
+ "text/plain": [
790
+ "{'input_ids': tensor([[ 101, 4471, 1011, ..., 0, 0, 0],\n",
791
+ " [ 101, 4471, 1011, ..., 0, 0, 0],\n",
792
+ " [ 101, 4471, 1011, ..., 0, 0, 0],\n",
793
+ " [ 101, 4471, 1011, ..., 4092, 1010, 102],\n",
794
+ " [ 101, 4471, 1011, ..., 2241, 2006, 102]]), 'token_type_ids': tensor([[0, 0, 0, ..., 0, 0, 0],\n",
795
+ " [0, 0, 0, ..., 0, 0, 0],\n",
796
+ " [0, 0, 0, ..., 0, 0, 0],\n",
797
+ " [0, 0, 0, ..., 0, 0, 0],\n",
798
+ " [0, 0, 0, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
799
+ " [1, 1, 1, ..., 0, 0, 0],\n",
800
+ " [1, 1, 1, ..., 0, 0, 0],\n",
801
+ " [1, 1, 1, ..., 1, 1, 1],\n",
802
+ " [1, 1, 1, ..., 1, 1, 1]])}"
803
+ ]
804
+ },
805
+ "execution_count": 78,
806
+ "metadata": {},
807
+ "output_type": "execute_result"
808
+ }
809
+ ],
810
+ "source": [
811
+ "# Tokenize a small batch of emails\n",
812
+ "sample_texts = df_sampled['message'][:5].tolist() # Take the first 5 messages\n",
813
+ "encoded_input = tokenizer(sample_texts, return_tensors='pt', padding=True, truncation=True)\n",
814
+ "\n",
815
+ "# Display the tokenized output\n",
816
+ "encoded_input"
817
+ ]
818
+ },
819
+ {
820
+ "cell_type": "markdown",
821
+ "id": "d6171c4f-208d-4c0e-a4b8-ff272dab1a5b",
822
+ "metadata": {},
823
+ "source": [
824
+ "## Generating Embeddings\n",
825
+ "With the data tokenized, we then pass it through the BERT model to obtain embeddings for each email."
826
+ ]
827
+ },
828
+ {
829
+ "cell_type": "code",
830
+ "execution_count": 79,
831
+ "id": "42f6b3b7-db06-4ad0-86cb-ed0dbf59dd77",
832
+ "metadata": {},
833
+ "outputs": [
834
+ {
835
+ "data": {
836
+ "text/plain": [
837
+ "(5, 768)"
838
+ ]
839
+ },
840
+ "execution_count": 79,
841
+ "metadata": {},
842
+ "output_type": "execute_result"
843
+ }
844
+ ],
845
+ "source": [
846
+ "import torch\n",
847
+ "\n",
848
+ "# Disable gradient calculations for performance\n",
849
+ "with torch.no_grad():\n",
850
+ " # Forward pass, get model output\n",
851
+ " model_output = model(**encoded_input)\n",
852
+ "\n",
853
+ "# Take the mean of the last layer hidden-states to get a single vector embedding per email\n",
854
+ "embeddings = model_output.last_hidden_state.mean(dim=1)\n",
855
+ "\n",
856
+ "# Move the embeddings to the CPU and convert to numpy for easier handling\n",
857
+ "embeddings = embeddings.cpu().numpy()\n",
858
+ "\n",
859
+ "# Display the embeddings shape\n",
860
+ "embeddings.shape"
861
+ ]
862
+ },
863
+ {
864
+ "cell_type": "code",
865
+ "execution_count": 80,
866
+ "id": "440d9e13-dac2-4c3d-a763-c544f136ceea",
867
+ "metadata": {},
868
+ "outputs": [
869
+ {
870
+ "name": "stdout",
871
+ "output_type": "stream",
872
+ "text": [
873
+ "Total number of emails: 103480\n",
874
+ "Average length of emails: 2707.048473134905\n"
875
+ ]
876
+ }
877
+ ],
878
+ "source": [
879
+ "# Check the number of emails\n",
880
+ "number_of_emails = df_sampled['message'].shape[0]\n",
881
+ "print(f\"Total number of emails: {number_of_emails}\")\n",
882
+ "\n",
883
+ "# Check the average length of the emails\n",
884
+ "average_length = df_sampled['message'].str.len().mean()\n",
885
+ "print(f\"Average length of emails: {average_length}\")\n"
886
+ ]
887
+ },
888
+ {
889
+ "cell_type": "markdown",
890
+ "id": "b9f10e59-1435-4262-9688-d6a8f712dd68",
891
+ "metadata": {},
892
+ "source": [
893
+ "## Batch Processing for Embeddings\n",
894
+ "\n",
895
+ "### Importing NumPy\n",
896
+ "We import `numpy` to work with arrays efficiently, which will be essential for handling the embeddings.\n"
897
+ ]
898
+ },
899
+ {
900
+ "cell_type": "code",
901
+ "execution_count": 81,
902
+ "id": "1c6dfb72-1976-418a-9d71-396a2513dda6",
903
+ "metadata": {},
904
+ "outputs": [],
905
+ "source": [
906
+ "from tqdm import tqdm"
907
+ ]
908
+ },
909
+ {
910
+ "cell_type": "code",
911
+ "execution_count": 82,
912
+ "id": "91de2f6a-66e8-4eba-a814-a8ebb5a00044",
913
+ "metadata": {},
914
+ "outputs": [],
915
+ "source": [
916
+ "import numpy as np"
917
+ ]
918
+ },
919
+ {
920
+ "cell_type": "markdown",
921
+ "id": "822ea130-8434-4763-8c0c-2302c806c0f3",
922
+ "metadata": {},
923
+ "source": [
924
+ "## Note on Embeddings Generation\n",
925
+ "\n",
926
+ "Due to the extensive computation time required to process and generate embeddings for the dataset, the operation was performed once and the results were saved to disk. This approach allows us to preserve the computational results and avoid re-running the same expensive computation, especially when the kernel is restarted or when further analysis is needed at a later stage.\n",
927
+ "\n",
928
+ "The batch processing cell, which generates the embeddings, took approximately 24 hours to complete. After the embeddings were generated, they were saved to the disk using `numpy.save('all_email_embeddings.npy', all_embeddings_array)`. In subsequent sessions, we can simply load the pre-computed embeddings from the disk using `numpy.load('all_email_embeddings.npy')`, which significantly reduces the startup time for our analysis and enables us to proceed with further data processing and model training without delay.\n"
929
+ ]
930
+ },
931
+ {
932
+ "cell_type": "raw",
933
+ "id": "28b35ff3-b6f5-4c76-87ea-6cab7a458bca",
934
+ "metadata": {},
935
+ "source": [
936
+ "#batch_size = 64\n",
937
+ "\n",
938
+ "# Initialize a list to store the embeddings\n",
939
+ "#all_embeddings = []\n",
940
+ "\n",
941
+ "# Process the entire sampled data in batches\n",
942
+ "#for i in tqdm(range(0, len(df_sampled), batch_size), desc=\"Batch Processing\"):\n",
943
+ " # Select the batch of texts\n",
944
+ "# batch_texts = df_sampled['message'][i:i+batch_size].tolist()\n",
945
+ " \n",
946
+ " # Tokenize the batch\n",
947
+ "# encoded_batch = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True)\n",
948
+ " \n",
949
+ " # Generate embeddings\n",
950
+ "# with torch.no_grad():\n",
951
+ "# batch_output = model(**encoded_batch)\n",
952
+ " \n",
953
+ " # Calculate the mean of the last hidden states to get one embedding per email\n",
954
+ "# batch_embeddings = batch_output.last_hidden_state.mean(dim=1).cpu().numpy()\n",
955
+ " \n",
956
+ " # Append the embeddings to the list\n",
957
+ "# all_embeddings.extend(batch_embeddings)\n",
958
+ " \n",
959
+ " # Save embeddings of the batch to disk\n",
960
+ "# np.save(f'embeddings_batch_{i}.npy', batch_embeddings)\n",
961
+ "\n",
962
+ "# Convert the list of all embeddings to a NumPy array\n",
963
+ "#all_embeddings_array = np.array(all_embeddings)\n",
964
+ "\n",
965
+ "# Save the complete array of embeddings to disk\n",
966
+ "#np.save('all_email_embeddings.npy', all_embeddings_array)\n"
967
+ ]
968
+ },
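For reference, a cleaned-up, runnable rendering of the commented-out batch loop above (a sketch under the same assumptions: `tokenizer`, `model`, and `df_sampled` as defined earlier in the notebook; the original run executed this once and saved the result as `all_email_embeddings.npy`):

```python
import numpy as np
import torch
from tqdm import tqdm

batch_size = 64
all_embeddings = []

# Embed the sampled emails in batches with the frozen BERT encoder
for i in tqdm(range(0, len(df_sampled), batch_size), desc="Batch Processing"):
    batch_texts = df_sampled['message'][i:i + batch_size].tolist()
    encoded_batch = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        batch_output = model(**encoded_batch)
    # Mean-pool the last hidden states to get one 768-dimensional vector per email
    batch_embeddings = batch_output.last_hidden_state.mean(dim=1).cpu().numpy()
    all_embeddings.extend(batch_embeddings)

# Stack all batch embeddings and persist them for later sessions
all_embeddings_array = np.array(all_embeddings)
np.save('all_email_embeddings.npy', all_embeddings_array)
```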
969
+ {
970
+ "cell_type": "markdown",
971
+ "id": "2d0a6fc7-d35a-4634-8910-fee5d30d4cdc",
972
+ "metadata": {},
973
+ "source": [
974
+ "## Loading Embeddings"
975
+ ]
976
+ },
977
+ {
978
+ "cell_type": "code",
979
+ "execution_count": 83,
980
+ "id": "315b92fe-a904-4a1b-b25f-e9f9944df465",
981
+ "metadata": {},
982
+ "outputs": [],
983
+ "source": [
984
+ "all_embeddings_array = np.load('all_email_embeddings.npy')"
985
+ ]
986
+ },
987
+ {
988
+ "cell_type": "markdown",
989
+ "id": "ea1b8585-0409-4c1b-bf4e-2243ac519bd5",
990
+ "metadata": {},
991
+ "source": [
992
+ "## Verifying Embeddings Shape\n",
993
+ "Finally, we print the shape of the embeddings array to confirm its dimensions."
994
+ ]
995
+ },
996
+ {
997
+ "cell_type": "code",
998
+ "execution_count": 84,
999
+ "id": "b6b23b32-5702-4343-a9b0-75d1ceeb1bdd",
1000
+ "metadata": {},
1001
+ "outputs": [
1002
+ {
1003
+ "name": "stdout",
1004
+ "output_type": "stream",
1005
+ "text": [
1006
+ "(103480, 768)\n"
1007
+ ]
1008
+ }
1009
+ ],
1010
+ "source": [
1011
+ "print(all_embeddings_array.shape)\n"
1012
+ ]
1013
+ },
1014
+ {
1015
+ "cell_type": "code",
1016
+ "execution_count": 85,
1017
+ "id": "a1239e7e-c9a5-4240-9d06-bff0446f98f9",
1018
+ "metadata": {},
1019
+ "outputs": [],
1020
+ "source": [
1021
+ "import torch"
1022
+ ]
1023
+ },
1024
+ {
1025
+ "cell_type": "code",
1026
+ "execution_count": 86,
1027
+ "id": "38df9298-22b8-491e-965a-2f84ce12118f",
1028
+ "metadata": {},
1029
+ "outputs": [],
1030
+ "source": [
1031
+ "from transformers import BertTokenizer, BertForSequenceClassification\n",
1032
+ "from datasets import load_dataset\n",
1033
+ "import numpy as np"
1034
+ ]
1035
+ },
1036
+ {
1037
+ "cell_type": "markdown",
1038
+ "id": "a4bf8938-3268-42fd-a791-50cacee6c389",
1039
+ "metadata": {},
1040
+ "source": [
1041
+ "## Efficiency in Tokenization\n",
1042
+ "\n",
1043
+ "### Loading Pre-trained Models for Tokenization\n",
1044
+ "For the purpose of this analysis, we utilize BERT models for tokenizing our dataset. Below is the code used to load a pre-trained BERT model which will be used for tokenization."
1045
+ ]
1046
+ },
1047
+ {
1048
+ "cell_type": "code",
1049
+ "execution_count": 87,
1050
+ "id": "23554539-9122-466e-b5de-53858715943e",
1051
+ "metadata": {},
1052
+ "outputs": [],
1053
+ "source": [
1054
+ "from transformers import BertForMaskedLM\n"
1055
+ ]
1056
+ },
1057
+ {
1058
+ "cell_type": "code",
1059
+ "execution_count": 88,
1060
+ "id": "76b5f211-46ae-41b8-9809-0ac2427134fd",
1061
+ "metadata": {
1062
+ "collapsed": true,
1063
+ "jupyter": {
1064
+ "outputs_hidden": true
1065
+ }
1066
+ },
1067
+ "outputs": [
1068
+ {
1069
+ "name": "stderr",
1070
+ "output_type": "stream",
1071
+ "text": [
1072
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
1073
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1074
+ ]
1075
+ }
1076
+ ],
1077
+ "source": [
1078
+ "# Load pre-trained model\n",
1079
+ "model = BertForSequenceClassification.from_pretrained('bert-base-uncased')"
1080
+ ]
1081
+ },
1082
+ {
1083
+ "cell_type": "markdown",
1084
+ "id": "184a5109-cd05-491b-b4e4-db6ad802fe74",
1085
+ "metadata": {},
1086
+ "source": [
1087
+ "## Ensuring Data Consistency\n",
1088
+ "Before proceeding with tokenization, we ensure all messages are of string type, to avoid any issues during the tokenization process"
1089
+ ]
1090
+ },
1091
+ {
1092
+ "cell_type": "code",
1093
+ "execution_count": 89,
1094
+ "id": "2565eb04-37b4-4195-9f92-b82a948f0598",
1095
+ "metadata": {},
1096
+ "outputs": [],
1097
+ "source": [
1098
+ "df['message'] = df['message'].astype(str)"
1099
+ ]
1100
+ },
1101
+ {
1102
+ "cell_type": "code",
1103
+ "execution_count": 90,
1104
+ "id": "a3cea0d3-8a4c-4951-ba46-f8742efb84c9",
1105
+ "metadata": {},
1106
+ "outputs": [],
1107
+ "source": [
1108
+ "from tqdm.auto import tqdm\n",
1109
+ "tqdm.pandas(desc=\"Tokenizing messages\")"
1110
+ ]
1111
+ },
1112
+ {
1113
+ "cell_type": "markdown",
1114
+ "id": "32450c8d-9eb9-47e0-a2b1-9f0b5e801a10",
1115
+ "metadata": {},
1116
+ "source": [
1117
+ "## Tokenization Process\n",
1118
+ "### Saving Tokenized Data for Reuse\n",
1119
+ "To enhance efficiency, the tokenized data is saved to disk after processing. This enables us to reuse the tokenized data without having to repeat the tokenization process, which is particularly helpful when dealing with large datasets or in cases where the session is interrupted, and the kernel needs to be restarted."
1120
+ ]
1121
+ },
1122
+ {
1123
+ "cell_type": "raw",
1124
+ "id": "8983cb95-3809-4d36-bcf8-b364692796de",
1125
+ "metadata": {},
1126
+ "source": [
1127
+ "# Ensure all messages are strings\n",
1128
+ "\n",
1129
+ "#df['message'] = df['message'].astype(str)\n",
1130
+ " \n",
1131
+ "# Tokenize all messages in the DataFrame\n",
1132
+ "\n",
1133
+ "#tokenized_inputs = df['message'].progress_map(\n",
1134
+ "\n",
1135
+ "# lambda x: tokenizer(x, padding='max_length', truncation=True, max_length=512)\n",
1136
+ "\n",
1137
+ "#)\n",
1138
+ " \n",
1139
+ "# Now, extract the tokenized inputs into their own columns\n",
1140
+ "\n",
1141
+ "#df['input_ids'] = tokenized_inputs.apply(lambda x: x['input_ids'])\n",
1142
+ "\n",
1143
+ "#df['attention_mask'] = tokenized_inputs.apply(lambda x: x['attention_mask'])\n"
1144
+ ]
1145
+ },
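For completeness, a minimal sketch (assumed, not shown in the visible cells of the original run) of how the tokenized Series could have been written to disk so that the load cell below finds `tokenized_data.pkl`:

```python
import pickle

# Persist the tokenized messages produced by progress_map above
# (assumes `tokenized_inputs` is the pandas Series of tokenizer outputs)
with open('tokenized_data.pkl', 'wb') as file:
    pickle.dump(tokenized_inputs, file)
```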
1146
+ {
1147
+ "cell_type": "code",
1148
+ "execution_count": 91,
1149
+ "id": "cacc9d45-fee6-4cdf-a3ed-827a98b959e4",
1150
+ "metadata": {},
1151
+ "outputs": [],
1152
+ "source": [
1153
+ "import pickle"
1154
+ ]
1155
+ },
1156
+ {
1157
+ "cell_type": "code",
1158
+ "execution_count": 92,
1159
+ "id": "5c2a0225-e0d2-42f9-abc6-e025c7513bd9",
1160
+ "metadata": {},
1161
+ "outputs": [],
1162
+ "source": [
1163
+ "# Load tokenized data from disk\n",
1164
+ "with open('tokenized_data.pkl', 'rb') as file:\n",
1165
+ " tokenized_inputs = pickle.load(file)"
1166
+ ]
1167
+ },
1168
+ {
1169
+ "cell_type": "markdown",
1170
+ "id": "b32e9cde-700c-472f-9996-dd6727455911",
1171
+ "metadata": {},
1172
+ "source": [
1173
+ "## Verifying the Loaded Data\n",
1174
+ "After loading the pre-tokenized data, we verify its structure to ensure it's in the expected format for further processing."
1175
+ ]
1176
+ },
1177
+ {
1178
+ "cell_type": "code",
1179
+ "execution_count": 93,
1180
+ "id": "04c85c85-da8b-4691-9cda-a3bf8b658b95",
1181
+ "metadata": {},
1182
+ "outputs": [
1183
+ {
1184
+ "name": "stdout",
1185
+ "output_type": "stream",
1186
+ "text": [
1187
+ "RangeIndex(start=0, stop=517401, step=1)\n"
1188
+ ]
1189
+ }
1190
+ ],
1191
+ "source": [
1192
+ "# Print out the keys of the loaded tokenized data\n",
1193
+ "print(tokenized_inputs.keys())"
1194
+ ]
1195
+ },
1196
+ {
1197
+ "cell_type": "code",
1198
+ "execution_count": 94,
1199
+ "id": "8a06a698-468b-415b-9f41-0a1fd4699e2c",
1200
+ "metadata": {},
1201
+ "outputs": [
1202
+ {
1203
+ "name": "stdout",
1204
+ "output_type": "stream",
1205
+ "text": [
1206
+ "0 [input_ids, token_type_ids, attention_mask]\n",
1207
+ "1 [input_ids, token_type_ids, attention_mask]\n",
1208
+ "2 [input_ids, token_type_ids, attention_mask]\n",
1209
+ "3 [input_ids, token_type_ids, attention_mask]\n",
1210
+ "4 [input_ids, token_type_ids, attention_mask]\n",
1211
+ "Name: message, dtype: object\n"
1212
+ ]
1213
+ }
1214
+ ],
1215
+ "source": [
1216
+ "# Display the first few rows of the tokenized data to understand its structure\n",
1217
+ "print(tokenized_inputs.head())"
1218
+ ]
1219
+ },
1220
+ {
1221
+ "cell_type": "code",
1222
+ "execution_count": 95,
1223
+ "id": "ec5cf903-ed83-4e31-bfe4-a0394911d24f",
1224
+ "metadata": {},
1225
+ "outputs": [],
1226
+ "source": [
1227
+ "# Each element in the series is a dictionary, so let's convert the Series to a list of dictionaries\n",
1228
+ "list_of_dicts = tokenized_inputs.tolist()\n",
1229
+ " \n",
1230
+ "# Now, we'll extract the input_ids, token_type_ids, and attention_mask from these dictionaries\n",
1231
+ "input_ids = [d['input_ids'] for d in list_of_dicts]\n",
1232
+ "token_type_ids = [d['token_type_ids'] for d in list_of_dicts]\n",
1233
+ "attention_mask = [d['attention_mask'] for d in list_of_dicts]\n",
1234
+ " \n",
1235
+ "# Create the dictionary for the dataset\n",
1236
+ "tokenized_dict = {\n",
1237
+ " 'input_ids': input_ids,\n",
1238
+ " 'token_type_ids': token_type_ids,\n",
1239
+ " 'attention_mask': attention_mask\n",
1240
+ "}"
1241
+ ]
1242
+ },
1243
+ {
1244
+ "cell_type": "code",
1245
+ "execution_count": 96,
1246
+ "id": "3c4d02aa-8028-4299-9e41-5e5014c86284",
1247
+ "metadata": {},
1248
+ "outputs": [],
1249
+ "source": [
1250
+ "from transformers import Trainer, TrainingArguments"
1251
+ ]
1252
+ },
1253
+ {
1254
+ "cell_type": "code",
1255
+ "execution_count": 97,
1256
+ "id": "170052d2-73a7-4767-8b2b-8691d28abd91",
1257
+ "metadata": {},
1258
+ "outputs": [
1259
+ {
1260
+ "data": {
1261
+ "text/plain": [
1262
+ "file object\n",
1263
+ "message object\n",
1264
+ "dtype: object"
1265
+ ]
1266
+ },
1267
+ "execution_count": 97,
1268
+ "metadata": {},
1269
+ "output_type": "execute_result"
1270
+ }
1271
+ ],
1272
+ "source": [
1273
+ "df.dtypes"
1274
+ ]
1275
+ },
1276
+ {
1277
+ "cell_type": "markdown",
1278
+ "id": "603bcf4f-f926-432a-8183-1d476c3e0963",
1279
+ "metadata": {},
1280
+ "source": [
1281
+ "## Preparing the Data for Model Training\n",
1282
+ "\n",
1283
+ "### Creating a Custom Dataset\n",
1284
+ "The tokenized inputs are organized into a structured format suitable for loading into our model. We define a custom dataset class to handle the BERT input formats:"
1285
+ ]
1286
+ },
1287
+ {
1288
+ "cell_type": "code",
1289
+ "execution_count": 98,
1290
+ "id": "92153c5d-ec17-459f-ba2e-804642495e06",
1291
+ "metadata": {},
1292
+ "outputs": [],
1293
+ "source": [
1294
+ "from torch.utils.data import Dataset\n",
1295
+ " \n",
1296
+ "class EnronDataset(Dataset):\n",
1297
+ "\n",
1298
+ " def __init__(self, encodings):\n",
1299
+ "\n",
1300
+ " self.encodings = encodings\n",
1301
+ " \n",
1302
+ " def __getitem__(self, idx):\n",
1303
+ "\n",
1304
+ " # Retrieve the data at the given index.\n",
1305
+ "\n",
1306
+ " item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
1307
+ "\n",
1308
+ " return item\n",
1309
+ " \n",
1310
+ " def __len__(self):\n",
1311
+ "\n",
1312
+ " # Return the number of items in the dataset.\n",
1313
+ "\n",
1314
+ " return len(self.encodings['input_ids'])"
1315
+ ]
1316
+ },
1317
+ {
1318
+ "cell_type": "code",
1319
+ "execution_count": 99,
1320
+ "id": "c14e2c5f-61a4-4cc0-9013-a57b202c0528",
1321
+ "metadata": {},
1322
+ "outputs": [
1323
+ {
1324
+ "name": "stdout",
1325
+ "output_type": "stream",
1326
+ "text": [
1327
+ "Requirement already satisfied: scikit-learn in c:\\users\\stylianos\\myenv\\lib\\site-packages (1.4.2)\n",
1328
+ "Requirement already satisfied: numpy>=1.19.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from scikit-learn) (1.26.4)\n",
1329
+ "Requirement already satisfied: scipy>=1.6.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from scikit-learn) (1.13.0)\n",
1330
+ "Requirement already satisfied: joblib>=1.2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from scikit-learn) (1.4.0)\n",
1331
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from scikit-learn) (3.4.0)\n"
1332
+ ]
1333
+ }
1334
+ ],
1335
+ "source": [
1336
+ "!pip install scikit-learn"
1337
+ ]
1338
+ },
1339
+ {
1340
+ "cell_type": "code",
1341
+ "execution_count": 100,
1342
+ "id": "e395521f-3b26-488a-9d04-d8079e778d2c",
1343
+ "metadata": {},
1344
+ "outputs": [],
1345
+ "source": [
1346
+ "from sklearn.model_selection import train_test_split"
1347
+ ]
1348
+ },
1349
+ {
1350
+ "cell_type": "markdown",
1351
+ "id": "5a34cee7-8cfe-4f81-8ad5-a3ac76999d81",
1352
+ "metadata": {},
1353
+ "source": [
1354
+ "## Splitting Data and Creating Datasets\n",
1355
+ "\n",
1356
+ "To evaluate the model's performance effectively, we split the tokenized data into training and evaluation datasets. Using the `train_test_split` function, we allocate a portion of the data for evaluation to monitor the model's ability to generalize to new, unseen data."
1357
+ ]
1358
+ },
1359
+ {
1360
+ "cell_type": "code",
1361
+ "execution_count": 101,
1362
+ "id": "c10ed619-0062-4d49-96f2-a6434708e46f",
1363
+ "metadata": {},
1364
+ "outputs": [],
1365
+ "source": [
1366
+ "# Split the tokenized data into training and evaluation datasets\n",
1367
+ "\n",
1368
+ "input_ids_train, input_ids_eval, token_type_ids_train, token_type_ids_eval, attention_mask_train, attention_mask_eval = train_test_split(\n",
1369
+ "\n",
1370
+ " tokenized_dict['input_ids'],\n",
1371
+ "\n",
1372
+ " tokenized_dict['token_type_ids'],\n",
1373
+ "\n",
1374
+ " tokenized_dict['attention_mask'],\n",
1375
+ "\n",
1376
+ " test_size=0.1,\n",
1377
+ "\n",
1378
+ " random_state=42\n",
1379
+ "\n",
1380
+ ")\n",
1381
+ " \n",
1382
+ "# Create the datasets\n",
1383
+ "\n",
1384
+ "train_dataset = EnronDataset({\n",
1385
+ "\n",
1386
+ " 'input_ids': input_ids_train,\n",
1387
+ "\n",
1388
+ " 'token_type_ids': token_type_ids_train,\n",
1389
+ "\n",
1390
+ " 'attention_mask': attention_mask_train\n",
1391
+ "\n",
1392
+ "})\n",
1393
+ " \n",
1394
+ "eval_dataset = EnronDataset({\n",
1395
+ "\n",
1396
+ " 'input_ids': input_ids_eval,\n",
1397
+ "\n",
1398
+ " 'token_type_ids': token_type_ids_eval,\n",
1399
+ "\n",
1400
+ " 'attention_mask': attention_mask_eval\n",
1401
+ "\n",
1402
+ "})\n",
1403
+ "\n"
1404
+ ]
1405
+ },
1406
+ {
1407
+ "cell_type": "code",
1408
+ "execution_count": 102,
1409
+ "id": "ff9ec406-6d02-4cac-942e-9f24c180f855",
1410
+ "metadata": {
1411
+ "collapsed": true,
1412
+ "jupyter": {
1413
+ "outputs_hidden": true
1414
+ }
1415
+ },
1416
+ "outputs": [
1417
+ {
1418
+ "name": "stdout",
1419
+ "output_type": "stream",
1420
+ "text": [
1421
+ "Requirement already satisfied: tensorboard in c:\\users\\stylianos\\myenv\\lib\\site-packages (2.16.2)\n",
1422
+ "Requirement already satisfied: absl-py>=0.4 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (2.1.0)\n",
1423
+ "Requirement already satisfied: grpcio>=1.48.2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (1.62.1)\n",
1424
+ "Requirement already satisfied: markdown>=2.6.8 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (3.6)\n",
1425
+ "Requirement already satisfied: numpy>=1.12.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (1.26.4)\n",
1426
+ "Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (5.26.1)\n",
1427
+ "Requirement already satisfied: setuptools>=41.0.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (69.5.1)\n",
1428
+ "Requirement already satisfied: six>1.9 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (1.16.0)\n",
1429
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (0.7.2)\n",
1430
+ "Requirement already satisfied: werkzeug>=1.0.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tensorboard) (3.0.2)\n",
1431
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.5)\n"
1432
+ ]
1433
+ }
1434
+ ],
1435
+ "source": [
1436
+ "!pip install tensorboard"
1437
+ ]
1438
+ },
1439
+ {
1440
+ "cell_type": "code",
1441
+ "execution_count": 103,
1442
+ "id": "272bdc6b-aee1-453a-a330-3850a7e999ce",
1443
+ "metadata": {},
1444
+ "outputs": [],
1445
+ "source": [
1446
+ "def generate_question(email_content, tokenizer, model):\n",
1447
+ " input_text = \"generate question: \" + email_content\n",
1448
+ " input_ids = tokenizer.encode(input_text, return_tensors=\"pt\")\n",
1449
+ " outputs = model.generate(input_ids, max_length=64, num_return_sequences=3)\n",
1450
+ " questions = [tokenizer.decode(output_ids, skip_special_tokens=True) for output_ids in outputs]\n",
1451
+ " return questions"
1452
+ ]
1453
+ },
1454
+ {
1455
+ "cell_type": "raw",
1456
+ "id": "f29fddf7-07a1-4e83-891d-28861ddb0a2a",
1457
+ "metadata": {},
1458
+ "source": [
1459
+ "!pip install transformers"
1460
+ ]
1461
+ },
1462
+ {
1463
+ "cell_type": "raw",
1464
+ "id": "2e387129-9b10-47cf-86ed-c402342bcb84",
1465
+ "metadata": {},
1466
+ "source": [
1467
+ "!pip install transformers[torch]"
1468
+ ]
1469
+ },
1470
+ {
1471
+ "cell_type": "raw",
1472
+ "id": "83eec78e-9c60-4fbe-b7ba-45a9f814f593",
1473
+ "metadata": {},
1474
+ "source": [
1475
+ "!pip install accelerate"
1476
+ ]
1477
+ },
1478
+ {
1479
+ "cell_type": "raw",
1480
+ "id": "b2144dc2-4f95-434c-9dd5-0179ffab3e42",
1481
+ "metadata": {},
1482
+ "source": [
1483
+ "!pip install accelerate>=0.21.0"
1484
+ ]
1485
+ },
1486
+ {
1487
+ "cell_type": "raw",
1488
+ "id": "f1e17511-1e47-4eb1-b0e3-7507c7eb4d5b",
1489
+ "metadata": {},
1490
+ "source": [
1491
+ "!pip install --upgrade transformers"
1492
+ ]
1493
+ },
1494
+ {
1495
+ "cell_type": "raw",
1496
+ "id": "eea55d1f-9924-4508-8d26-b906e96a2d44",
1497
+ "metadata": {},
1498
+ "source": [
1499
+ "!pip install torch==2.2.2\n"
1500
+ ]
1501
+ },
1502
+ {
1503
+ "cell_type": "markdown",
1504
+ "id": "7543b245-a533-4093-9890-cf13ed2beade",
1505
+ "metadata": {},
1506
+ "source": [
1507
+ "## Model Training Configuration\n",
1508
+ "\n",
1509
+ "### Metrics Definition\n",
1510
+ "We define a `compute_metrics` function to calculate accuracy, precision, recall, and F1 score using scikit-learn's metrics. These metrics will help us evaluate the model's performance.\n"
1511
+ ]
1512
+ },
1513
+ {
1514
+ "cell_type": "markdown",
1515
+ "id": "d1c798ed-06e1-4355-9766-a3b34d9385e3",
1516
+ "metadata": {},
1517
+ "source": [
1518
+ "### Training Arguments Setup\n",
1519
+ "We then set up the training arguments, specifying the output directories for checkpoints and TensorBoard logs, along with the batch size, learning rate, and other training options."
1520
+ ]
1521
+ },
1522
+ {
1523
+ "cell_type": "markdown",
1524
+ "id": "7f1afd22-ef8e-4153-b016-53c41e512167",
1525
+ "metadata": {},
1526
+ "source": [
1527
+ "### Trainer Initialization\n",
1528
+ "The Trainer class from the transformers library is used to handle the training process. It takes our datasets, training arguments, and compute_metrics function to manage the training loop."
1529
+ ]
1530
+ },
1531
+ {
1532
+ "cell_type": "code",
1533
+ "execution_count": 104,
1534
+ "id": "ae11c83a-d824-4f62-8af8-0c82dd50844c",
1535
+ "metadata": {
1536
+ "collapsed": true,
1537
+ "jupyter": {
1538
+ "outputs_hidden": true
1539
+ }
1540
+ },
1541
+ "outputs": [
1542
+ {
1543
+ "name": "stderr",
1544
+ "output_type": "stream",
1545
+ "text": [
1546
+ "C:\\Users\\Stylianos\\myenv\\Lib\\site-packages\\accelerate\\accelerator.py:436: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: \n",
1547
+ "dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)\n",
1548
+ " warnings.warn(\n"
1549
+ ]
1550
+ }
1551
+ ],
1552
+ "source": [
1553
+ "import transformers\n",
1554
+ "from sklearn.metrics import precision_recall_fscore_support, accuracy_score\n",
1555
+ "from transformers import Trainer, TrainingArguments\n",
1556
+ " \n",
1557
+ "# Define the compute_metrics function\n",
1558
+ "def compute_metrics(pred):\n",
1559
+ " labels = pred.label_ids\n",
1560
+ " preds = pred.predictions.argmax(-1)\n",
1561
+ " precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')\n",
1562
+ " acc = accuracy_score(labels, preds)\n",
1563
+ " return {\n",
1564
+ " 'accuracy': acc,\n",
1565
+ " 'f1': f1,\n",
1566
+ " 'precision': precision,\n",
1567
+ " 'recall': recall\n",
1568
+ " }\n",
1569
+ " \n",
1570
+ "from transformers import TrainingArguments\n",
1571
+ " \n",
1572
+ "# Set up the training arguments\n",
1573
+ "training_args = TrainingArguments(\n",
1574
+ " output_dir='./results', # Output directory for model checkpoints\n",
1575
+ " logging_dir='./logs', # Directory for TensorBoard logs\n",
1576
+ " per_device_train_batch_size=8, # Adjust batch size according to your system\n",
1577
+ " per_device_eval_batch_size=8,\n",
1578
+ " num_train_epochs=3, # Adjust number of epochs according to your needs\n",
1579
+ " warmup_steps=500,\n",
1580
+ " weight_decay=0.01,\n",
1581
+ " evaluation_strategy=\"steps\",\n",
1582
+ " logging_steps=10, # Log metrics every 10 steps\n",
1583
+ " save_steps=500, # Save checkpoint every 500 steps\n",
1584
+ " eval_steps=500, # Run evaluation every 500 steps\n",
1585
+ " load_best_model_at_end=True # Load the best model at the end of training\n",
1586
+ ")\n",
1587
+ " \n",
1588
+ "# Initialize the Trainer\n",
1589
+ "trainer = Trainer(\n",
1590
+ " model=model, \n",
1591
+ " args=training_args,\n",
1592
+ " train_dataset=train_dataset, \n",
1593
+ " eval_dataset=eval_dataset, \n",
1594
+ " compute_metrics=compute_metrics \n",
1595
+ ")"
1596
+ ]
1597
+ },
1598
+ {
1599
+ "cell_type": "markdown",
1600
+ "id": "33e483e3-7cbf-472d-b217-39634f412d37",
1601
+ "metadata": {},
1602
+ "source": [
1603
+ "## Model Training and Evaluation\n",
1604
+ "With everything set up, we start the training process and evaluate our model."
1605
+ ]
1606
+ },
1607
+ {
1608
+ "cell_type": "raw",
1609
+ "id": "5c97538c-dd6d-466a-afea-bd66a990d8d0",
1610
+ "metadata": {},
1611
+ "source": [
1612
+ "# Train\n",
1613
+ "trainer.train()"
1614
+ ]
1615
+ },
1616
+ {
1617
+ "cell_type": "raw",
1618
+ "id": "43251b47-e422-48fa-9308-19d0c8b57165",
1619
+ "metadata": {},
1620
+ "source": [
1621
+ "# Evaluate\n",
1622
+ "trainer.evaluate()"
1623
+ ]
1624
+ },
1625
+ {
1626
+ "cell_type": "raw",
1627
+ "id": "241655f4-444b-407f-ae3a-65a2255137e2",
1628
+ "metadata": {},
1629
+ "source": [
1630
+ "# Save the fine-tuned model\n",
1631
+ "model.save_pretrained('my_fine_tuned_model')"
1632
+ ]
1633
+ },
1634
+ {
1635
+ "cell_type": "code",
1636
+ "execution_count": 105,
1637
+ "id": "7efc5916-c05e-4582-be10-ed228a0e4548",
1638
+ "metadata": {},
1639
+ "outputs": [
1640
+ {
1641
+ "name": "stdout",
1642
+ "output_type": "stream",
1643
+ "text": [
1644
+ "The tensorboard extension is already loaded. To reload it, use:\n",
1645
+ " %reload_ext tensorboard\n"
1646
+ ]
1647
+ },
1648
+ {
1649
+ "data": {
1650
+ "text/plain": [
1651
+ "Reusing TensorBoard on port 6006 (pid 11724), started 2:36:13 ago. (Use '!kill 11724' to kill it.)"
1652
+ ]
1653
+ },
1654
+ "metadata": {},
1655
+ "output_type": "display_data"
1656
+ },
1657
+ {
1658
+ "data": {
1659
+ "text/html": [
1660
+ "\n",
1661
+ " <iframe id=\"tensorboard-frame-1c80317fa3b1799d\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
1662
+ " </iframe>\n",
1663
+ " <script>\n",
1664
+ " (function() {\n",
1665
+ " const frame = document.getElementById(\"tensorboard-frame-1c80317fa3b1799d\");\n",
1666
+ " const url = new URL(\"/\", window.location);\n",
1667
+ " const port = 6006;\n",
1668
+ " if (port) {\n",
1669
+ " url.port = port;\n",
1670
+ " }\n",
1671
+ " frame.src = url;\n",
1672
+ " })();\n",
1673
+ " </script>\n",
1674
+ " "
1675
+ ],
1676
+ "text/plain": [
1677
+ "<IPython.core.display.HTML object>"
1678
+ ]
1679
+ },
1680
+ "metadata": {},
1681
+ "output_type": "display_data"
1682
+ }
1683
+ ],
1684
+ "source": [
1685
+ "%load_ext tensorboard\n",
1686
+ "\n",
1687
+ "%tensorboard --logdir logs"
1688
+ ]
1689
+ },
1690
+ {
1691
+ "cell_type": "markdown",
1692
+ "id": "ca643b67-8797-4455-b5ca-a610363e6c4f",
1693
+ "metadata": {},
1694
+ "source": [
1695
+ "## Interactive Question-Answering Interface with Gradio\n",
1696
+ "\n",
1697
+ "### Setting up the Interface\n",
1698
+ "To facilitate user interaction with our model, we deploy a Gradio interface. This enables users to pose questions about the Enron case and receive answers based on the data our model has been trained on.\n",
1699
+ "\n",
1700
+ "### Installation of Gradio\n",
1701
+ "First, we ensure Gradio is installed to create the interactive web application.\n",
1702
+ "\n",
1703
+ "### Preparing the Model for the Interface\n",
1704
+ "We ensure that our model is in evaluation mode and load the tokenizer necessary for processing user inputs.\n",
1705
+ "\n",
1706
+ "### Defining Helper Functions\n",
1707
+ "We define functions to handle the question-answering process and logging of feedback, which can be used for future improvements.\n",
1708
+ "\n",
1709
+ "### Creating the Gradio Interface\n",
1710
+ "A Gradio interface is created with input fields for questions and context, and output for the model's answers. We also provide a feedback mechanism to gather user responses on the accuracy of the answers."
1711
+ ]
1712
+ },
1713
+ {
1714
+ "cell_type": "code",
1715
+ "execution_count": 106,
1716
+ "id": "e81ab0a8-5b0c-424a-b6fd-8a4be3685ab2",
1717
+ "metadata": {
1718
+ "collapsed": true,
1719
+ "jupyter": {
1720
+ "outputs_hidden": true
1721
+ }
1722
+ },
1723
+ "outputs": [
1724
+ {
1725
+ "name": "stdout",
1726
+ "output_type": "stream",
1727
+ "text": [
1728
+ "Requirement already satisfied: gradio in c:\\users\\stylianos\\myenv\\lib\\site-packages (4.26.0)\n",
1729
+ "Requirement already satisfied: aiofiles<24.0,>=22.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (23.2.1)\n",
1730
+ "Requirement already satisfied: altair<6.0,>=4.2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (5.3.0)\n",
1731
+ "Requirement already satisfied: fastapi in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.110.1)\n",
1732
+ "Requirement already satisfied: ffmpy in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.3.2)\n",
1733
+ "Requirement already satisfied: gradio-client==0.15.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.15.1)\n",
1734
+ "Requirement already satisfied: httpx>=0.24.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.27.0)\n",
1735
+ "Requirement already satisfied: huggingface-hub>=0.19.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.22.2)\n",
1736
+ "Requirement already satisfied: importlib-resources<7.0,>=1.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (6.4.0)\n",
1737
+ "Requirement already satisfied: jinja2<4.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (3.1.3)\n",
1738
+ "Requirement already satisfied: markupsafe~=2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (2.1.5)\n",
1739
+ "Requirement already satisfied: matplotlib~=3.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (3.8.4)\n",
1740
+ "Requirement already satisfied: numpy~=1.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (1.26.4)\n",
1741
+ "Requirement already satisfied: orjson~=3.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (3.10.0)\n",
1742
+ "Requirement already satisfied: packaging in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (24.0)\n",
1743
+ "Requirement already satisfied: pandas<3.0,>=1.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (2.2.2)\n",
1744
+ "Requirement already satisfied: pillow<11.0,>=8.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (10.3.0)\n",
1745
+ "Requirement already satisfied: pydantic>=2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (2.7.0)\n",
1746
+ "Requirement already satisfied: pydub in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.25.1)\n",
1747
+ "Requirement already satisfied: python-multipart>=0.0.9 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.0.9)\n",
1748
+ "Requirement already satisfied: pyyaml<7.0,>=5.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (6.0.1)\n",
1749
+ "Requirement already satisfied: ruff>=0.2.2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.3.7)\n",
1750
+ "Requirement already satisfied: semantic-version~=2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (2.10.0)\n",
1751
+ "Requirement already satisfied: tomlkit==0.12.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.12.0)\n",
1752
+ "Requirement already satisfied: typer<1.0,>=0.9 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (0.12.3)\n",
1753
+ "Requirement already satisfied: typing-extensions~=4.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (4.11.0)\n",
1754
+ "Requirement already satisfied: uvicorn>=0.14.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio) (0.29.0)\n",
1755
+ "Requirement already satisfied: fsspec in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio-client==0.15.1->gradio) (2024.2.0)\n",
1756
+ "Requirement already satisfied: websockets<12.0,>=10.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from gradio-client==0.15.1->gradio) (11.0.3)\n",
1757
+ "Requirement already satisfied: jsonschema>=3.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from altair<6.0,>=4.2.0->gradio) (4.21.1)\n",
1758
+ "Requirement already satisfied: toolz in c:\\users\\stylianos\\myenv\\lib\\site-packages (from altair<6.0,>=4.2.0->gradio) (0.12.1)\n",
1759
+ "Requirement already satisfied: anyio in c:\\users\\stylianos\\myenv\\lib\\site-packages (from httpx>=0.24.1->gradio) (4.3.0)\n",
1760
+ "Requirement already satisfied: certifi in c:\\users\\stylianos\\myenv\\lib\\site-packages (from httpx>=0.24.1->gradio) (2024.2.2)\n",
1761
+ "Requirement already satisfied: httpcore==1.* in c:\\users\\stylianos\\myenv\\lib\\site-packages (from httpx>=0.24.1->gradio) (1.0.5)\n",
1762
+ "Requirement already satisfied: idna in c:\\users\\stylianos\\myenv\\lib\\site-packages (from httpx>=0.24.1->gradio) (3.7)\n",
1763
+ "Requirement already satisfied: sniffio in c:\\users\\stylianos\\myenv\\lib\\site-packages (from httpx>=0.24.1->gradio) (1.3.1)\n",
1764
+ "Requirement already satisfied: h11<0.15,>=0.13 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.14.0)\n",
1765
+ "Requirement already satisfied: filelock in c:\\users\\stylianos\\myenv\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (3.13.4)\n",
1766
+ "Requirement already satisfied: requests in c:\\users\\stylianos\\myenv\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (2.31.0)\n",
1767
+ "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from huggingface-hub>=0.19.3->gradio) (4.66.2)\n",
1768
+ "Requirement already satisfied: contourpy>=1.0.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from matplotlib~=3.0->gradio) (1.2.1)\n",
1769
+ "Requirement already satisfied: cycler>=0.10 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from matplotlib~=3.0->gradio) (0.12.1)\n",
1770
+ "Requirement already satisfied: fonttools>=4.22.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from matplotlib~=3.0->gradio) (4.51.0)\n",
1771
+ "Requirement already satisfied: kiwisolver>=1.3.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from matplotlib~=3.0->gradio) (1.4.5)\n",
1772
+ "Requirement already satisfied: pyparsing>=2.3.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from matplotlib~=3.0->gradio) (3.1.2)\n",
1773
+ "Requirement already satisfied: python-dateutil>=2.7 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from matplotlib~=3.0->gradio) (2.9.0.post0)\n",
1774
+ "Requirement already satisfied: pytz>=2020.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n",
1775
+ "Requirement already satisfied: tzdata>=2022.7 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pandas<3.0,>=1.0->gradio) (2024.1)\n",
1776
+ "Requirement already satisfied: annotated-types>=0.4.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pydantic>=2.0->gradio) (0.6.0)\n",
1777
+ "Requirement already satisfied: pydantic-core==2.18.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from pydantic>=2.0->gradio) (2.18.1)\n",
1778
+ "Requirement already satisfied: click>=8.0.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (8.1.7)\n",
1779
+ "Requirement already satisfied: shellingham>=1.3.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (1.5.4)\n",
1780
+ "Requirement already satisfied: rich>=10.11.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (13.7.1)\n",
1781
+ "Requirement already satisfied: starlette<0.38.0,>=0.37.2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from fastapi->gradio) (0.37.2)\n",
1782
+ "Requirement already satisfied: colorama in c:\\users\\stylianos\\myenv\\lib\\site-packages (from click>=8.0.0->typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (0.4.6)\n",
1783
+ "Requirement already satisfied: attrs>=22.2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (23.2.0)\n",
1784
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (2023.12.1)\n",
1785
+ "Requirement already satisfied: referencing>=0.28.4 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.34.0)\n",
1786
+ "Requirement already satisfied: rpds-py>=0.7.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio) (0.18.0)\n",
1787
+ "Requirement already satisfied: six>=1.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio) (1.16.0)\n",
1788
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from rich>=10.11.0->typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (3.0.0)\n",
1789
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from rich>=10.11.0->typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (2.17.2)\n",
1790
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (3.3.2)\n",
1791
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->huggingface-hub>=0.19.3->gradio) (2.2.1)\n",
1792
+ "Requirement already satisfied: mdurl~=0.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.9->typer[all]<1.0,>=0.9; sys_platform != \"emscripten\"->gradio) (0.1.2)\n"
1793
+ ]
1794
+ },
1795
+ {
1796
+ "name": "stderr",
1797
+ "output_type": "stream",
1798
+ "text": [
1799
+ "WARNING: typer 0.12.3 does not provide the extra 'all'\n"
1800
+ ]
1801
+ }
1802
+ ],
1803
+ "source": [
1804
+ "!pip install gradio"
1805
+ ]
1806
+ },
1807
+ {
1808
+ "cell_type": "code",
1809
+ "execution_count": 107,
1810
+ "id": "02af9a8a-9c26-4209-82ba-8f479f16897c",
1811
+ "metadata": {},
1812
+ "outputs": [],
1813
+ "source": [
1814
+ "import torch\n",
1815
+ "from transformers import AutoModelForCausalLM, AutoTokenizer"
1816
+ ]
1817
+ },
1818
+ {
1819
+ "cell_type": "code",
1820
+ "execution_count": 108,
1821
+ "id": "b977b6a7-229a-4516-8843-298b29734a37",
1822
+ "metadata": {
1823
+ "collapsed": true,
1824
+ "jupyter": {
1825
+ "outputs_hidden": true
1826
+ }
1827
+ },
1828
+ "outputs": [
1829
+ {
1830
+ "data": {
1831
+ "text/plain": [
1832
+ "BertForSequenceClassification(\n",
1833
+ " (bert): BertModel(\n",
1834
+ " (embeddings): BertEmbeddings(\n",
1835
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
1836
+ " (position_embeddings): Embedding(512, 768)\n",
1837
+ " (token_type_embeddings): Embedding(2, 768)\n",
1838
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1839
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1840
+ " )\n",
1841
+ " (encoder): BertEncoder(\n",
1842
+ " (layer): ModuleList(\n",
1843
+ " (0-11): 12 x BertLayer(\n",
1844
+ " (attention): BertAttention(\n",
1845
+ " (self): BertSelfAttention(\n",
1846
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
1847
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
1848
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
1849
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1850
+ " )\n",
1851
+ " (output): BertSelfOutput(\n",
1852
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
1853
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1854
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1855
+ " )\n",
1856
+ " )\n",
1857
+ " (intermediate): BertIntermediate(\n",
1858
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
1859
+ " (intermediate_act_fn): GELUActivation()\n",
1860
+ " )\n",
1861
+ " (output): BertOutput(\n",
1862
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
1863
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1864
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1865
+ " )\n",
1866
+ " )\n",
1867
+ " )\n",
1868
+ " )\n",
1869
+ " (pooler): BertPooler(\n",
1870
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
1871
+ " (activation): Tanh()\n",
1872
+ " )\n",
1873
+ " )\n",
1874
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1875
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
1876
+ ")"
1877
+ ]
1878
+ },
1879
+ "execution_count": 108,
1880
+ "metadata": {},
1881
+ "output_type": "execute_result"
1882
+ }
1883
+ ],
1884
+ "source": [
1885
+ "model.eval()"
1886
+ ]
1887
+ },
1888
+ {
1889
+ "cell_type": "code",
1890
+ "execution_count": 109,
1891
+ "id": "243ab6b3-4db1-4445-8f13-927500cebe3b",
1892
+ "metadata": {
1893
+ "collapsed": true,
1894
+ "jupyter": {
1895
+ "outputs_hidden": true
1896
+ }
1897
+ },
1898
+ "outputs": [
1899
+ {
1900
+ "name": "stdout",
1901
+ "output_type": "stream",
1902
+ "text": [
1903
+ "Requirement already satisfied: transformers in c:\\users\\stylianos\\myenv\\lib\\site-packages (4.39.3)\n",
1904
+ "Requirement already satisfied: filelock in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (3.13.4)\n",
1905
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (0.22.2)\n",
1906
+ "Requirement already satisfied: numpy>=1.17 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (1.26.4)\n",
1907
+ "Requirement already satisfied: packaging>=20.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (24.0)\n",
1908
+ "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (6.0.1)\n",
1909
+ "Requirement already satisfied: regex!=2019.12.17 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (2023.12.25)\n",
1910
+ "Requirement already satisfied: requests in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (2.31.0)\n",
1911
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (0.15.2)\n",
1912
+ "Requirement already satisfied: safetensors>=0.4.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (0.4.2)\n",
1913
+ "Requirement already satisfied: tqdm>=4.27 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from transformers) (4.66.2)\n",
1914
+ "Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (2024.2.0)\n",
1915
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.11.0)\n",
1916
+ "Requirement already satisfied: colorama in c:\\users\\stylianos\\myenv\\lib\\site-packages (from tqdm>=4.27->transformers) (0.4.6)\n",
1917
+ "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (3.3.2)\n",
1918
+ "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (3.7)\n",
1919
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (2.2.1)\n",
1920
+ "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\stylianos\\myenv\\lib\\site-packages (from requests->transformers) (2024.2.2)\n"
1921
+ ]
1922
+ }
1923
+ ],
1924
+ "source": [
1925
+ "!pip install transformers"
1926
+ ]
1927
+ },
1928
+ {
1929
+ "cell_type": "code",
1930
+ "execution_count": 110,
1931
+ "id": "6db70dfb-7a0f-4a3d-8e7e-a11a2cd823f2",
1932
+ "metadata": {},
1933
+ "outputs": [
1934
+ {
1935
+ "name": "stdout",
1936
+ "output_type": "stream",
1937
+ "text": [
1938
+ "Running on local URL: http://127.0.0.1:7862\n",
1939
+ "\n",
1940
+ "To create a public link, set `share=True` in `launch()`.\n"
1941
+ ]
1942
+ },
1943
+ {
1944
+ "data": {
1945
+ "text/html": [
1946
+ "<div><iframe src=\"http://127.0.0.1:7862/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
1947
+ ],
1948
+ "text/plain": [
1949
+ "<IPython.core.display.HTML object>"
1950
+ ]
1951
+ },
1952
+ "metadata": {},
1953
+ "output_type": "display_data"
1954
+ },
1955
+ {
1956
+ "data": {
1957
+ "text/plain": []
1958
+ },
1959
+ "execution_count": 110,
1960
+ "metadata": {},
1961
+ "output_type": "execute_result"
1962
+ }
1963
+ ],
1964
+ "source": [
1965
+ "import gradio as gr\n",
1966
+ "from transformers import pipeline\n",
1967
+ " \n",
1968
+ "# Load a pre-trained question-answering pipeline\n",
1969
+ "qa_pipeline = pipeline('question-answering', model='distilbert-base-uncased-distilled-squad')\n",
1970
+ " \n",
1971
+ "def answer_question(question, context):\n",
1972
+ " # Use the pre-trained pipeline to answer questions\n",
1973
+ " result = qa_pipeline({'question': question, 'context': context})\n",
1974
+ " return result['answer']\n",
1975
+ " \n",
1976
+ "def log_feedback(question, context, answer, correct, feedback):\n",
1977
+ " print(f\"Question: {question}\")\n",
1978
+ " print(f\"Context: {context}\")\n",
1979
+ " print(f\"Answer: {answer}\")\n",
1980
+ " print(f\"Correct: {correct}\")\n",
1981
+ " print(f\"Feedback: {feedback}\")\n",
1982
+ " # Here you can add code to save feedback to a file or a database\n",
1983
+ " \n",
1984
+ "# Define the context about the Enron scandal\n",
1985
+ "enron_context = \"The Enron scandal was an accounting scandal involving Enron Corporation, an American energy company based in Houston, Texas. When news of widespread fraud within the company became public in October 2001, the company declared bankruptcy, and its accounting firm, Arthur Andersen, was effectively dissolved.\"\n",
1986
+ " \n",
1987
+ "# Create the Gradio interface\n",
1988
+ "iface = gr.Interface(\n",
1989
+ " fn=answer_question,\n",
1990
+ " inputs=[\n",
1991
+ " gr.Textbox(lines=2, placeholder=\"Enter a question about the Enron case\"),\n",
1992
+ " gr.Textbox(value=enron_context, lines=10, placeholder=\"Context for the question\", label=\"Context\")\n",
1993
+ " ],\n",
1994
+ " outputs=[gr.Text(label=\"Answer\")],\n",
1995
+ " title=\"Enron Case Question Answering System\",\n",
1996
+ " description=\"This interface uses a pre-trained model to answer your questions about the Enron scandal. Provide the question and context and get your answer.\",\n",
1997
+ " examples=[[\"What was the Enron scandal?\", enron_context], [\"What happened to Arthur Andersen?\", enron_context]]\n",
1998
+ ")\n",
1999
+ " \n",
2000
+ "# Add a feedback form\n",
2001
+ "feedback_form = gr.Interface(\n",
2002
+ " fn=log_feedback,\n",
2003
+ " inputs=[\n",
2004
+ " gr.Textbox(label=\"Question Asked\"),\n",
2005
+ " gr.Textbox(label=\"Context Given\"),\n",
2006
+ " gr.Textbox(label=\"Answer Provided\"),\n",
2007
+ " gr.Radio(choices=[\"Yes\", \"No\"], label=\"Was the answer correct?\"),\n",
2008
+ " gr.Textbox(label=\"Additional Feedback\")\n",
2009
+ " ],\n",
2010
+ " outputs=[]\n",
2011
+ ")\n",
2012
+ " \n",
2013
+ "iface.launch()"
2014
+ ]
2015
+ }
2016
+ ],
2017
+ "metadata": {
2018
+ "kernelspec": {
2019
+ "display_name": "Python 3 (ipykernel)",
2020
+ "language": "python",
2021
+ "name": "python3"
2022
+ },
2023
+ "language_info": {
2024
+ "codemirror_mode": {
2025
+ "name": "ipython",
2026
+ "version": 3
2027
+ },
2028
+ "file_extension": ".py",
2029
+ "mimetype": "text/x-python",
2030
+ "name": "python",
2031
+ "nbconvert_exporter": "python",
2032
+ "pygments_lexer": "ipython3",
2033
+ "version": "3.12.3"
2034
+ }
2035
+ },
2036
+ "nbformat": 4,
2037
+ "nbformat_minor": 5
2038
+ }
emails.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a605f0a2ca4a4feaa557eb0f2be825914906dca816066d7e58983e7e2e8bd274
3
+ size 1426122219
enron-email-dataset.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccd26396c2dd927416aae38eabb3de493dd377e3bb78811210ba2aebec942526
3
+ size 375294957
kaggle.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"username":"sstylianou","key":"ff0443948f5735a140eafededda239b8"}
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ kaggle
2
+ pandas
3
+ numpy
4
+ tqdm
5
+ torch==2.2.2
6
+ transformers[torch]
7
+ datasets
8
+ accelerate>=0.21.0
9
+ scikit-learn
10
+ tensorboard
11
+ gradio
tokenized_data.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5701333fd993ada85ad3863c1518a25c9a06a8791422dceb989ec0ef940e4245
3
+ size 1821667917