Dofla committed on
Commit db27dd8
1 Parent(s): a933464

Upload Optimisation_Model (3).ipynb


GPT-2 model fine-tuned with a dataset of tweets about COVID

Files changed (1)
  1. Optimisation_Model (3).ipynb +619 -0
Optimisation_Model (3).ipynb ADDED
@@ -0,0 +1,619 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/"
9
+ },
10
+ "id": "-C4hOJzui0GC",
11
+ "outputId": "53b258f1-0b19-4c1e-95a2-1c45e053be4d"
12
+ },
13
+ "outputs": [
14
+ {
15
+ "output_type": "stream",
16
+ "name": "stdout",
17
+ "text": [
18
+ "Mounted at /content/drive\n"
19
+ ]
20
+ }
21
+ ],
22
+ "source": [
23
+ "from google.colab import drive\n",
24
+ "drive.mount('/content/drive')"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 1,
30
+ "metadata": {
31
+ "colab": {
32
+ "base_uri": "https://localhost:8080/"
33
+ },
34
+ "id": "C9_nGmZplkwn",
35
+ "outputId": "27ea388c-1391-466e-e7f9-1155c9b1880b"
36
+ },
37
+ "outputs": [
38
+ {
39
+ "output_type": "stream",
40
+ "name": "stdout",
41
+ "text": [
42
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m465.3/465.3 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
43
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m950.8/950.8 kB\u001b[0m \u001b[31m15.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
44
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
45
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
46
+ "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
47
+ "tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.5 which is incompatible.\u001b[0m\u001b[31m\n",
48
+ "\u001b[0mCollecting clean-text\n",
49
+ " Downloading clean_text-0.6.0-py3-none-any.whl (11 kB)\n",
50
+ "Collecting emoji<2.0.0,>=1.0.0 (from clean-text)\n",
51
+ " Downloading emoji-1.7.0.tar.gz (175 kB)\n",
52
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m175.4/175.4 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
53
+ "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
54
+ "Collecting ftfy<7.0,>=6.0 (from clean-text)\n",
55
+ " Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)\n",
56
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.4/53.4 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
57
+ "\u001b[?25hRequirement already satisfied: wcwidth<0.3.0,>=0.2.12 in /usr/local/lib/python3.10/dist-packages (from ftfy<7.0,>=6.0->clean-text) (0.2.13)\n",
58
+ "Building wheels for collected packages: emoji\n",
59
+ " Building wheel for emoji (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
60
+ " Created wheel for emoji: filename=emoji-1.7.0-py3-none-any.whl size=171033 sha256=0223548bdb25e799ce902bbd67d1aca1a2f166851e765cdbd400528ff496aa31\n",
61
+ " Stored in directory: /root/.cache/pip/wheels/31/8a/8c/315c9e5d7773f74b33d5ed33f075b49c6eaeb7cedbb86e2cf8\n",
62
+ "Successfully built emoji\n",
63
+ "Installing collected packages: emoji, ftfy, clean-text\n",
64
+ "Successfully installed clean-text-0.6.0 emoji-1.7.0 ftfy-6.1.3\n",
65
+ "Collecting unicode\n",
66
+ " Downloading unicode-2.9-py2.py3-none-any.whl (14 kB)\n",
67
+ "Installing collected packages: unicode\n",
68
+ "Successfully installed unicode-2.9\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "!pip install -q --upgrade keras-nlp\n",
74
+ "!pip install -q --upgrade keras # Upgrade to Keras 3.\n",
75
+ "!pip install clean-text\n",
76
+ "!pip install unicode\n",
77
+ "# !pip install keras==2.15.0\n"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 5,
83
+ "metadata": {
84
+ "id": "Ik6OXLR8lkzt"
85
+ },
86
+ "outputs": [],
87
+ "source": [
88
+ "import os\n",
89
+ "\n",
90
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # or \"tensorflow\" or \"torch\"\n",
91
+ "\n",
92
+ "import keras_nlp\n",
93
+ "import keras\n",
94
+ "\n",
95
+ "# améliore la rapidité d'entrainement\n",
96
+ "keras.mixed_precision.set_global_policy(\"mixed_float16\")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 6,
102
+ "metadata": {
103
+ "colab": {
104
+ "base_uri": "https://localhost:8080/"
105
+ },
106
+ "id": "U3-rKDHxtFxS",
107
+ "outputId": "64ce0f1f-1563-41f8-ddda-dba344b6a16c"
108
+ },
109
+ "outputs": [
110
+ {
111
+ "output_type": "stream",
112
+ "name": "stdout",
113
+ "text": [
114
+ "<class 'pandas.core.frame.DataFrame'>\n",
115
+ "RangeIndex: 44955 entries, 0 to 44954\n",
116
+ "Data columns (total 1 columns):\n",
117
+ " # Column Non-Null Count Dtype \n",
118
+ "--- ------ -------------- ----- \n",
119
+ " 0 OriginalTweet 44955 non-null object\n",
120
+ "dtypes: object(1)\n",
121
+ "memory usage: 351.3+ KB\n"
122
+ ]
123
+ }
124
+ ],
125
+ "source": [
126
+ "import pandas as pd\n",
127
+ "\n",
128
+ "df = pd.concat([\n",
129
+ " pd.read_csv('/content/drive/MyDrive/Corona_NLP_test.csv')[['OriginalTweet']],\n",
130
+ " pd.read_csv('/content/drive/MyDrive/Corona_NLP_train.csv', encoding='latin-1', on_bad_lines='skip')[['OriginalTweet']]\n",
131
+ "\n",
132
+ "])\n",
133
+ "df.reset_index(inplace=True, drop=True)\n",
134
+ "df.info()"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {
140
+ "id": "vxunnbfhFKqj"
141
+ },
142
+ "source": [
143
+ "# **Pré-traitement**"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": 7,
149
+ "metadata": {
150
+ "colab": {
151
+ "base_uri": "https://localhost:8080/"
152
+ },
153
+ "id": "P3g_M1s_tGqj",
154
+ "outputId": "8f22765c-ce40-42a4-d860-80155e732acb"
155
+ },
156
+ "outputs": [
157
+ {
158
+ "output_type": "stream",
159
+ "name": "stdout",
160
+ "text": [
161
+ "0\n",
162
+ "Tous les éléments sont uniques.\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "print(df['OriginalTweet'].isnull().sum())\n",
168
+ "df.dropna(subset=['OriginalTweet'], how='any', inplace=True)\n",
169
+ "df.reset_index(inplace=True, drop=True)\n",
170
+ "\n",
171
+ "unique_elements = not df['OriginalTweet'].duplicated().any()\n",
172
+ "\n",
173
+ "if unique_elements:\n",
174
+ " print(\"Tous les éléments sont uniques.\")\n",
175
+ "else:\n",
176
+ " print(\"Certains éléments ne sont pas uniques.\")"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 8,
182
+ "metadata": {
183
+ "id": "RbFJC39ytGsZ",
184
+ "colab": {
185
+ "base_uri": "https://localhost:8080/"
186
+ },
187
+ "outputId": "b820a9b7-2622-498f-ed29-56309e781749"
188
+ },
189
+ "outputs": [
190
+ {
191
+ "output_type": "stream",
192
+ "name": "stderr",
193
+ "text": [
194
+ "WARNING:root:Since the GPL-licensed package `unidecode` is not installed, using Python's `unicodedata` package which yields worse results.\n"
195
+ ]
196
+ }
197
+ ],
198
+ "source": [
199
+ "import cleantext\n",
200
+ "import re\n",
201
+ "# corrige les caractères unicode mal formaté\n",
202
+ "# converti les caracteres non ASCII en leur équivalent exemple : é devient e\n",
203
+ "\n",
204
+ "df['OriginalTweet'] = df['OriginalTweet'].apply(lambda x: cleantext.clean(x,\n",
205
+ " fix_unicode=True,\n",
206
+ " to_ascii=True,\n",
207
+ " lower=True,\n",
208
+ " no_emails=True,\n",
209
+ " no_urls=True,\n",
210
+ " ))\n",
211
+ "df['OriginalTweet'] = df['OriginalTweet'].str.replace('<url>', '')\n",
212
+ "# traitement des character speciaux\n",
213
+ "df['OriginalTweet'] = df['OriginalTweet'].apply(lambda x: re.sub(r'\\W', ' ', x))\n",
214
+ "# on enleve les espaces en trop\n",
215
+ "df['OriginalTweet'] = df['OriginalTweet'].str.replace(r'\\s{2,}', ' ', regex=True)\n"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "metadata": {
222
+ "id": "bH-ODHW4tKvu"
223
+ },
224
+ "outputs": [],
225
+ "source": [
226
+ "df.info()"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 40,
232
+ "metadata": {
233
+ "colab": {
234
+ "base_uri": "https://localhost:8080/",
235
+ "height": 385
236
+ },
237
+ "id": "Vb3n637UlqjX",
238
+ "outputId": "73c15e07-6bc7-49be-da1d-72a4ea49bb77"
239
+ },
240
+ "outputs": [
241
+ {
242
+ "output_type": "display_data",
243
+ "data": {
244
+ "text/plain": [
245
+ "\u001b[1mPreprocessor: \"gpt2_causal_lm_preprocessor_8\"\u001b[0m\n"
246
+ ],
247
+ "text/html": [
248
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Preprocessor: \"gpt2_causal_lm_preprocessor_8\"</span>\n",
249
+ "</pre>\n"
250
+ ]
251
+ },
252
+ "metadata": {}
253
+ },
254
+ {
255
+ "output_type": "display_data",
256
+ "data": {
257
+ "text/plain": [
258
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
259
+ "┃\u001b[1m \u001b[0m\u001b[1mTokenizer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Vocab #\u001b[0m\u001b[1m \u001b[0m┃\n",
260
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
261
+ "│ gpt2_tokenizer (\u001b[38;5;33mGPT2Tokenizer\u001b[0m) │ \u001b[38;5;34m50,257\u001b[0m │\n",
262
+ "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n"
263
+ ],
264
+ "text/html": [
265
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
266
+ "┃<span style=\"font-weight: bold\"> Tokenizer (type) </span>┃<span style=\"font-weight: bold\"> Vocab # </span>┃\n",
267
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
268
+ "│ gpt2_tokenizer (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GPT2Tokenizer</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">50,257</span> │\n",
269
+ "└────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘\n",
270
+ "</pre>\n"
271
+ ]
272
+ },
273
+ "metadata": {}
274
+ },
275
+ {
276
+ "output_type": "display_data",
277
+ "data": {
278
+ "text/plain": [
279
+ "\u001b[1mModel: \"gpt2_causal_lm_8\"\u001b[0m\n"
280
+ ],
281
+ "text/html": [
282
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"gpt2_causal_lm_8\"</span>\n",
283
+ "</pre>\n"
284
+ ]
285
+ },
286
+ "metadata": {}
287
+ },
288
+ {
289
+ "output_type": "display_data",
290
+ "data": {
291
+ "text/plain": [
292
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
293
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
294
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
295
+ "│ padding_mask (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
296
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
297
+ "│ token_ids (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
298
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
299
+ "│ gpt2_backbone (\u001b[38;5;33mGPT2Backbone\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1280\u001b[0m) │ \u001b[38;5;34m774,030,080\u001b[0m │ padding_mask[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
300
+ "│ │ │ │ token_ids[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
301
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
302
+ "│ token_embedding │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m50257\u001b[0m) │ \u001b[38;5;34m64,328,960\u001b[0m │ gpt2_backbone[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
303
+ "│ (\u001b[38;5;33mReversibleEmbedding\u001b[0m) │ │ │ │\n",
304
+ "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n"
305
+ ],
306
+ "text/html": [
307
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
308
+ "┃<span style=\"font-weight: bold\"> Layer (type) </span>┃<span style=\"font-weight: bold\"> Output Shape </span>┃<span style=\"font-weight: bold\"> Param # </span>┃<span style=\"font-weight: bold\"> Connected to </span>┃\n",
309
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
310
+ "│ padding_mask (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
311
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
312
+ "│ token_ids (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">InputLayer</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> │ - │\n",
313
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
314
+ "│ gpt2_backbone (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">GPT2Backbone</span>) │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">1280</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">774,030,080</span> │ padding_mask[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>], │\n",
315
+ "│ │ │ │ token_ids[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
316
+ "├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤\n",
317
+ "│ token_embedding │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">50257</span>) │ <span style=\"color: #00af00; text-decoration-color: #00af00\">64,328,960</span> │ gpt2_backbone[<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>][<span style=\"color: #00af00; text-decoration-color: #00af00\">0</span>] │\n",
318
+ "│ (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">ReversibleEmbedding</span>) │ │ │ │\n",
319
+ "└───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘\n",
320
+ "</pre>\n"
321
+ ]
322
+ },
323
+ "metadata": {}
324
+ },
325
+ {
326
+ "output_type": "display_data",
327
+ "data": {
328
+ "text/plain": [
329
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m774,030,080\u001b[0m (2.88 GB)\n"
330
+ ],
331
+ "text/html": [
332
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">774,030,080</span> (2.88 GB)\n",
333
+ "</pre>\n"
334
+ ]
335
+ },
336
+ "metadata": {}
337
+ },
338
+ {
339
+ "output_type": "display_data",
340
+ "data": {
341
+ "text/plain": [
342
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m774,030,080\u001b[0m (2.88 GB)\n"
343
+ ],
344
+ "text/html": [
345
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">774,030,080</span> (2.88 GB)\n",
346
+ "</pre>\n"
347
+ ]
348
+ },
349
+ "metadata": {}
350
+ },
351
+ {
352
+ "output_type": "display_data",
353
+ "data": {
354
+ "text/plain": [
355
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
356
+ ],
357
+ "text/html": [
358
+ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
359
+ "</pre>\n"
360
+ ]
361
+ },
362
+ "metadata": {}
363
+ }
364
+ ],
365
+ "source": [
366
+ "# initialisation pour l'utilisation du modèle GPT-2\n",
367
+ "import pandas as pd\n",
368
+ "# from tensorflow.keras.preprocessing.text import Tokenizer\n",
369
+ "from keras_nlp.models import GPT2CausalLMPreprocessor, GPT2CausalLM\n",
370
+ "\n",
371
+ "preprocessor = GPT2CausalLMPreprocessor.from_preset(\n",
372
+ " \"gpt2_large_en\",\n",
373
+ " sequence_length=128,\n",
374
+ ")\n",
375
+ "\n",
376
+ "gpt2_lm = GPT2CausalLM.from_preset(\n",
377
+ " \"gpt2_large_en\",\n",
378
+ " preprocessor=preprocessor\n",
379
+ ")\n",
380
+ "\n",
381
+ "gpt2_lm.summary()\n"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "execution_count": null,
387
+ "metadata": {
388
+ "id": "PtHNonkjlqlV"
389
+ },
390
+ "outputs": [],
391
+ "source": [
392
+ "print(df)"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": 12,
398
+ "metadata": {
399
+ "colab": {
400
+ "base_uri": "https://localhost:8080/"
401
+ },
402
+ "id": "d08uzWyPlqpJ",
403
+ "outputId": "18f3b943-5267-4a33-9391-e5f106fcd857"
404
+ },
405
+ "outputs": [
406
+ {
407
+ "output_type": "stream",
408
+ "name": "stdout",
409
+ "text": [
410
+ "['trending new yorkers encounter empty supermarket shelves pictured wegmans in brooklyn sold out online grocers foodkick maxdelivery as coronavirus fearing shoppers stock up ', 'when i couldn t find hand sanitizer at fred meyer i turned to amazon but 114 97 for a 2 pack of purell check out how coronavirus concerns are driving up prices ', 'find out how you can protect yourself and loved ones from coronavirus ', ' panic buying hits newyork city as anxious shoppers stock up on food medical supplies after healthcare worker in her 30s becomes bigapple 1st confirmed coronavirus patient or a bloomberg staged event qanon qanon2018 qanon2020 election2020 cdc ', ' toiletpaper dunnypaper coronavirus coronavirusaustralia coronavirusupdate covid_19 9news corvid19 7newsmelb dunnypapergate costco one week everyone buying baby milk powder the next everyone buying up toilet paper ', 'do you remember the last time you paid 2 99 a gallon for regular gas in los angeles prices at the pump are going down a look at how the coronavirus is impacting prices 4pm abc7 ', 'voting in the age of coronavirus hand sanitizer supertuesday ', ' drtedros we can t stop covid19 without protecting healthworkers prices of surgical masks have increased six fold n95 respirators have more than trebled gowns cost twice as much drtedros coronavirus', 'hi twitter i am a pharmacist i sell hand sanitizer for a living or i do when any exists like masks it is sold the fuck out everywhere should you be worried no use soap should you visit twenty pharmacies looking for the last bottle no pharmacies are full of sick people ', 'anyone been in a supermarket over the last few days went to do my normal shop last night is the sight that greeted me barmy btw what s so special about tinned tomatoes covid_19 dublin ']\n"
411
+ ]
412
+ }
413
+ ],
414
+ "source": [
415
+ "text_df = list(df['OriginalTweet'].astype(str).values)\n",
416
+ "print(text_df[:10])"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": 41,
422
+ "metadata": {
423
+ "colab": {
424
+ "base_uri": "https://localhost:8080/"
425
+ },
426
+ "id": "dawoMENVFphM",
427
+ "outputId": "48232747-b2d8-4ea3-f148-fa7f7f1ad7b8"
428
+ },
429
+ "outputs": [
430
+ {
431
+ "output_type": "stream",
432
+ "name": "stdout",
433
+ "text": [
434
+ "Epoch 1/5\n",
435
+ "\u001b[1m1405/1405\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1236s\u001b[0m 568ms/step - accuracy: 0.3385 - loss: 1.2210\n",
436
+ "Epoch 2/5\n",
437
+ "\u001b[1m1405/1405\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m425s\u001b[0m 302ms/step - accuracy: 0.4055 - loss: 1.0146\n",
438
+ "Epoch 3/5\n",
439
+ "\u001b[1m1405/1405\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m430s\u001b[0m 305ms/step - accuracy: 0.4726 - loss: 0.8486\n",
440
+ "Epoch 4/5\n",
441
+ "\u001b[1m1405/1405\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m430s\u001b[0m 305ms/step - accuracy: 0.5486 - loss: 0.6875\n",
442
+ "Epoch 5/5\n",
443
+ "\u001b[1m1405/1405\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m431s\u001b[0m 306ms/step - accuracy: 0.6267 - loss: 0.5447\n"
444
+ ]
445
+ }
446
+ ],
447
+ "source": [
448
+ "import tensorflow as tf\n",
449
+ "\n",
450
+ "\n",
451
+ "nb_epochs = 5\n",
452
+ "with tf.device('/device:GPU:0'):\n",
453
+ "\n",
454
+ " learning_rate = keras.optimizers.schedules.ExponentialDecay(\n",
455
+ " initial_learning_rate=5e-5,\n",
456
+ " decay_steps=len(text_df) * nb_epochs,\n",
457
+ " decay_rate=0.96,\n",
458
+ " staircase=True\n",
459
+ " )\n",
460
+ "\n",
461
+ " # l'otpmisation PolynomialDecay peut être utilisé à la place de Exponential\n",
462
+ " # Decay (la précision n'est pas mauvaise)\n",
463
+ "\n",
464
+ " # initial_learning_rate peut être mis à 1e-3\n",
465
+ " # learning_rate = keras.optimizers.schedules.PolynomialDecay(\n",
466
+ " # initial_learning_rate=5e-5,\n",
467
+ " # decay_steps=len(text_df) * nb_epochs,\n",
468
+ " # end_learning_rate=0.0,\n",
469
+ " # )\n",
470
+ "\n",
471
+ " # on va crée un callbacks dans le cas ou l'entrainement devient régréssant\n",
472
+ " early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=nb_epochs, restore_best_weights=True)\n",
473
+ "\n",
474
+ " loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
475
+ " gpt2_lm.compile(\n",
476
+ " optimizer=keras.optimizers.Adam(learning_rate=learning_rate),\n",
477
+ " loss=loss,\n",
478
+ " weighted_metrics=[\"accuracy\"],\n",
479
+ " )\n",
480
+ "\n",
481
+ " gpt2_lm.fit(\n",
482
+ " x=text_df,\n",
483
+ " epochs=nb_epochs,\n",
484
+ " # batch_size = 32,\n",
485
+ " callbacks=[early_stopping]\n",
486
+ " )"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "markdown",
491
+ "source": [
492
+ "## Sauvegarde du modèle\n",
493
+ "\n",
494
+ "On sauvegarde le model pour éviter de relancer toutes les cellules."
495
+ ],
496
+ "metadata": {
497
+ "id": "fT70Qhsf3fCh"
498
+ }
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "source": [
503
+ "from tensorflow.keras.models import load_model, save_model\n",
504
+ "\n",
505
+ "save_model(gpt2_lm, '/content/drive/MyDrive/my_model.keras')"
506
+ ],
507
+ "metadata": {
508
+ "id": "Hki3ab9jIowZ"
509
+ },
510
+ "execution_count": 44,
511
+ "outputs": []
512
+ },
513
+ {
514
+ "cell_type": "markdown",
515
+ "source": [
516
+ "# Réutilisation d'un modèle sauvegardé"
517
+ ],
518
+ "metadata": {
519
+ "id": "7wRf7dmvI9OV"
520
+ }
521
+ },
522
+ {
523
+ "cell_type": "code",
524
+ "source": [
525
+ "gpt2_lm = load_model('/content/drive/MyDrive/my_model.keras')"
526
+ ],
527
+ "metadata": {
528
+ "id": "vBMqnFGGJUZg"
529
+ },
530
+ "execution_count": null,
531
+ "outputs": []
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": 43,
536
+ "metadata": {
537
+ "colab": {
538
+ "base_uri": "https://localhost:8080/"
539
+ },
540
+ "id": "33qOkv9FlqrE",
541
+ "outputId": "607200e4-94ab-4557-91e2-25f5d7494b2f"
542
+ },
543
+ "outputs": [
544
+ {
545
+ "output_type": "stream",
546
+ "name": "stdout",
547
+ "text": [
548
+ "\n",
549
+ "GPT-2 output:\n",
550
+ "The symptoms of the covid 19 virus can be so severe that a healthcare worker may be too scared to come into work to make money the government needs to step in and save these workers now \n",
551
+ "___________________________________________________________\n",
552
+ "\n",
553
+ "GPT-2 output:\n",
554
+ "The symptoms of the covid19 coronavirus vary from person to person i have observed that when i have left my house to go to the supermarket the streets have been full of people i have been told that it has been like this every day this week\n",
555
+ "___________________________________________________________\n",
556
+ "\n",
557
+ "GPT-2 output:\n",
558
+ "The symptoms of the covid2019 covid_19 pandemic vary from person to person and may include fever coughing chest congestion shortness of breath muscle aches and pains in various parts of your body learn more about what to look for in your next health exam \n",
559
+ "___________________________________________________________\n",
560
+ "\n",
561
+ "GPT-2 output:\n",
562
+ "The symptoms of the covid2019 coronavirus vary from person to person but the virus is generally passed in contact with someone who has it sneezing or talking back coughing into a tissue or using a cart or a shopping cart \n",
563
+ "___________________________________________________________\n",
564
+ "\n",
565
+ "GPT-2 output:\n",
566
+ "The symptoms of the covid19 coronavirus are typically mild and self contained adults and children should not be encouraged to engage in activities or go to the beach children under 18 are advised to stay at home \n",
567
+ "___________________________________________________________\n"
568
+ ]
569
+ }
570
+ ],
571
+ "source": [
572
+ "# model gpt-2 de base\n",
573
+ "for i in range(5):\n",
574
+ "# Ne pas laisser d'espace avant le \" ça un impacte énorme !\n",
575
+ " seed_text = \"The symptoms of the covid\"\n",
576
+ " # on peu rajouter le paramètre max_length si on veut\n",
577
+ " generated_text = gpt2_lm.generate(seed_text, max_length=100)\n",
578
+ " print(\"\\nGPT-2 output:\")\n",
579
+ " print(generated_text)\n",
580
+ " print('___________________________________________________________')"
581
+ ]
582
+ },
583
+ {
584
+ "cell_type": "code",
585
+ "execution_count": null,
586
+ "metadata": {
587
+ "id": "mMGwwyCa6tnt"
588
+ },
589
+ "outputs": [],
590
+ "source": []
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": null,
595
+ "metadata": {
596
+ "id": "K50lvyaJ6tc6"
597
+ },
598
+ "outputs": [],
599
+ "source": []
600
+ }
601
+ ],
602
+ "metadata": {
603
+ "accelerator": "GPU",
604
+ "colab": {
605
+ "gpuType": "A100",
606
+ "provenance": [],
607
+ "machine_shape": "hm"
608
+ },
609
+ "kernelspec": {
610
+ "display_name": "Python 3",
611
+ "name": "python3"
612
+ },
613
+ "language_info": {
614
+ "name": "python"
615
+ }
616
+ },
617
+ "nbformat": 4,
618
+ "nbformat_minor": 0
619
+ }