ak0601 committed on
Commit
ad3e56b
1 Parent(s): 8e7903b

Upload 3 files

Files changed (3)
  1. model.ipynb +444 -0
  2. story_gen.h5 +3 -0
  3. tokenizer.json +0 -0
model.ipynb ADDED
@@ -0,0 +1,444 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import opendatasets as od"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading mpst-movie-plot-synopses-with-tags.zip to .\\mpst-movie-plot-synopses-with-tags\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 28.8M/28.8M [00:07<00:00, 3.81MB/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "od.download('https://www.kaggle.com/datasets/cryptexcode/mpst-movie-plot-synopses-with-tags')"
+ ]
+ },
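+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's note (added cell):* `opendatasets` fetches Kaggle datasets through the Kaggle API, so it typically prompts for a Kaggle username and API key (or reads a local `kaggle.json`) before the download above succeeds."
+ ]
+ },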
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "df = pd.read_csv('mpst-movie-plot-synopses-with-tags/mpst_full_data.csv')  # forward slash works on Windows and POSIX alike"
+ ]
+ },
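+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's sketch (added cell):* a quick sanity check on the loaded frame (row count, nulls, distinct tag count) before modeling. These are standard pandas calls; the cell was not in the original notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Basic sanity checks on the MPST data (editor's addition)\n",
+ "print(df.shape)                      # (rows, columns)\n",
+ "print(df.isna().sum())               # missing values per column\n",
+ "print(df['tags'].str.split(', ').explode().nunique())  # distinct tags"
+ ]
+ },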
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>imdb_id</th>\n",
+ " <th>title</th>\n",
+ " <th>plot_synopsis</th>\n",
+ " <th>tags</th>\n",
+ " <th>split</th>\n",
+ " <th>synopsis_source</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>tt0057603</td>\n",
+ " <td>I tre volti della paura</td>\n",
+ " <td>Note: this synopsis is for the orginal Italian...</td>\n",
+ " <td>cult, horror, gothic, murder, atmospheric</td>\n",
+ " <td>train</td>\n",
+ " <td>imdb</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>tt1733125</td>\n",
+ " <td>Dungeons &amp; Dragons: The Book of Vile Darkness</td>\n",
+ " <td>Two thousand years ago, Nhagruul the Foul, a s...</td>\n",
+ " <td>violence</td>\n",
+ " <td>train</td>\n",
+ " <td>imdb</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>tt0033045</td>\n",
+ " <td>The Shop Around the Corner</td>\n",
+ " <td>Matuschek's, a gift store in Budapest, is the ...</td>\n",
+ " <td>romantic</td>\n",
+ " <td>test</td>\n",
+ " <td>imdb</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>tt0113862</td>\n",
+ " <td>Mr. Holland's Opus</td>\n",
+ " <td>Glenn Holland, not a morning person by anyone'...</td>\n",
+ " <td>inspiring, romantic, stupid, feel-good</td>\n",
+ " <td>train</td>\n",
+ " <td>imdb</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>tt0086250</td>\n",
+ " <td>Scarface</td>\n",
+ " <td>In May 1980, a Cuban man named Tony Montana (A...</td>\n",
+ " <td>cruelty, murder, dramatic, cult, violence, atm...</td>\n",
+ " <td>val</td>\n",
+ " <td>imdb</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " imdb_id title \\\n",
+ "0 tt0057603 I tre volti della paura \n",
+ "1 tt1733125 Dungeons & Dragons: The Book of Vile Darkness \n",
+ "2 tt0033045 The Shop Around the Corner \n",
+ "3 tt0113862 Mr. Holland's Opus \n",
+ "4 tt0086250 Scarface \n",
+ "\n",
+ " plot_synopsis \\\n",
+ "0 Note: this synopsis is for the orginal Italian... \n",
+ "1 Two thousand years ago, Nhagruul the Foul, a s... \n",
+ "2 Matuschek's, a gift store in Budapest, is the ... \n",
+ "3 Glenn Holland, not a morning person by anyone'... \n",
+ "4 In May 1980, a Cuban man named Tony Montana (A... \n",
+ "\n",
+ " tags split synopsis_source \n",
+ "0 cult, horror, gothic, murder, atmospheric train imdb \n",
+ "1 violence train imdb \n",
+ "2 romantic test imdb \n",
+ "3 inspiring, romantic, stupid, feel-good train imdb \n",
+ "4 cruelty, murder, dramatic, cult, violence, atm... val imdb "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install gpt-2-simple"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Index(['imdb_id', 'title', 'plot_synopsis', 'tags', 'split',\n",
+ " 'synopsis_source'],\n",
+ " dtype='object')"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.columns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from tensorflow.keras.preprocessing.text import Tokenizer\n",
+ "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
+ "from tensorflow.keras.models import Sequential\n",
+ "from tensorflow.keras.layers import Embedding, LSTM, Dense, Flatten\n",
+ "from sklearn.preprocessing import MultiLabelBinarizer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = df[['title', 'plot_synopsis', 'tags']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = Tokenizer()\n",
+ "tokenizer.fit_on_texts(df['title'])\n",
+ "title_sequences = tokenizer.texts_to_sequences(df['title'])\n",
+ "max_title_length = max(len(seq) for seq in title_sequences)\n",
+ "title_sequences = pad_sequences(title_sequences, maxlen=max_title_length)"
+ ]
+ },
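+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's sketch (added cell):* a quick look at what the tokenizer produced. `word_index` maps words to integer ids (0 is reserved), and `pad_sequences` left-pads by default. This check cell was not in the original notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect the fitted tokenizer and the padded sequences (editor's addition)\n",
+ "print('vocabulary size:', len(tokenizer.word_index))\n",
+ "print('max title length:', max_title_length)\n",
+ "print('example:', df['title'].iloc[0], '->', title_sequences[0])"
+ ]
+ },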
+ {
+ "cell_type": "code",
+ "execution_count": 64,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tags = [tag.split(', ') for tag in df['tags']]\n",
+ "mlb = MultiLabelBinarizer()\n",
+ "tags = mlb.fit_transform(tags)"
+ ]
+ },
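+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's sketch (added cell):* `MultiLabelBinarizer` turns each comma-separated tag list into a fixed-length 0/1 vector, one column per distinct tag. This check cell was not in the original notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Inspect the binarized labels (editor's addition)\n",
+ "print('number of distinct tags:', len(mlb.classes_))\n",
+ "print('label matrix shape:', tags.shape)\n",
+ "print('first row tags:', mlb.inverse_transform(tags[:1]))"
+ ]
+ },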
+ {
+ "cell_type": "code",
+ "execution_count": 65,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer_json = tokenizer.to_json()\n",
+ "with open('tokenizer.json', 'w') as json_file:\n",
+ " json_file.write(tokenizer_json)"
+ ]
+ },
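+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's sketch (added cell):* it is worth verifying that the saved tokenizer round-trips; Keras ships `tokenizer_from_json` for exactly this. The cell was not in the original notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reload the tokenizer and confirm it matches the in-memory one (editor's addition)\n",
+ "from tensorflow.keras.preprocessing.text import tokenizer_from_json\n",
+ "with open('tokenizer.json') as f:\n",
+ " reloaded = tokenizer_from_json(f.read())\n",
+ "assert reloaded.word_index == tokenizer.word_index"
+ ]
+ },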
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_train, X_test, y_train, y_test = train_test_split(title_sequences, tags, test_size=0.2, random_state=42)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vocab_size = len(tokenizer.word_index) + 1\n",
+ "embedding_dim = 100"
+ ]
+ },
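+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's note (added cell):* the `+ 1` accounts for index 0, which Keras' `Tokenizer` never assigns to a word and `pad_sequences` uses for padding, so the `Embedding` layer needs `len(word_index) + 1` rows."
+ ]
+ },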
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train on 11862 samples, validate on 2966 samples\n",
+ "Epoch 1/15\n",
+ "11862/11862 [==============================] - 10s 826us/sample - loss: 0.1911 - accuracy: 0.9457 - val_loss: 0.1417 - val_accuracy: 0.9569\n",
+ "Epoch 2/15\n",
+ "11862/11862 [==============================] - 11s 887us/sample - loss: 0.1390 - accuracy: 0.9583 - val_loss: 0.1416 - val_accuracy: 0.9569\n",
+ "Epoch 3/15\n",
+ "11862/11862 [==============================] - 11s 941us/sample - loss: 0.1388 - accuracy: 0.9583 - val_loss: 0.1415 - val_accuracy: 0.9569\n",
+ "Epoch 4/15\n",
+ "11862/11862 [==============================] - 11s 916us/sample - loss: 0.1367 - accuracy: 0.9583 - val_loss: 0.1420 - val_accuracy: 0.9568\n",
+ "Epoch 5/15\n",
+ "11862/11862 [==============================] - 11s 906us/sample - loss: 0.1310 - accuracy: 0.9595 - val_loss: 0.1433 - val_accuracy: 0.9567\n",
+ "Epoch 6/15\n",
+ "11862/11862 [==============================] - 11s 909us/sample - loss: 0.1248 - accuracy: 0.9608 - val_loss: 0.1444 - val_accuracy: 0.9569\n",
+ "Epoch 7/15\n",
+ "11862/11862 [==============================] - 11s 911us/sample - loss: 0.1184 - accuracy: 0.9624 - val_loss: 0.1461 - val_accuracy: 0.9564\n",
+ "Epoch 8/15\n",
+ "11862/11862 [==============================] - 11s 948us/sample - loss: 0.1123 - accuracy: 0.9649 - val_loss: 0.1484 - val_accuracy: 0.9562\n",
+ "Epoch 9/15\n",
+ "11862/11862 [==============================] - 11s 916us/sample - loss: 0.1069 - accuracy: 0.9668 - val_loss: 0.1509 - val_accuracy: 0.9552\n",
+ "Epoch 10/15\n",
+ "11862/11862 [==============================] - 11s 921us/sample - loss: 0.1021 - accuracy: 0.9682 - val_loss: 0.1537 - val_accuracy: 0.9550\n",
+ "Epoch 11/15\n",
+ "11862/11862 [==============================] - 11s 932us/sample - loss: 0.0978 - accuracy: 0.9692 - val_loss: 0.1566 - val_accuracy: 0.9541\n",
+ "Epoch 12/15\n",
+ "11862/11862 [==============================] - 11s 927us/sample - loss: 0.0937 - accuracy: 0.9700 - val_loss: 0.1591 - val_accuracy: 0.9540\n",
+ "Epoch 13/15\n",
+ "11862/11862 [==============================] - 11s 927us/sample - loss: 0.0896 - accuracy: 0.9710 - val_loss: 0.1621 - val_accuracy: 0.9536\n",
+ "Epoch 14/15\n",
+ "11862/11862 [==============================] - 11s 954us/sample - loss: 0.0857 - accuracy: 0.9719 - val_loss: 0.1660 - val_accuracy: 0.9536\n",
+ "Epoch 15/15\n",
+ "11862/11862 [==============================] - 12s 1ms/sample - loss: 0.0820 - accuracy: 0.9729 - val_loss: 0.1690 - val_accuracy: 0.9538\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "<keras.callbacks.History at 0x1cc31c0b250>"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model = Sequential()\n",
+ "model.add(Embedding(vocab_size, embedding_dim, input_length=max_title_length))  # map token ids to dense vectors\n",
+ "model.add(LSTM(100))  # encode the title sequence into a single 100-d state\n",
+ "model.add(Dense(tags.shape[1], activation='sigmoid'))  # one independent probability per tag (multi-label)\n",
+ "\n",
+ "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
+ "\n",
+ "model.fit(X_train, y_train, batch_size=64, epochs=15, validation_data=(X_test, y_test))\n"
+ ]
+ },
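+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's sketch (added cell):* the log above shows `val_loss` bottoming out around epoch 3 (0.1415) and rising steadily afterwards while training loss keeps falling, i.e. the model starts overfitting. A standard remedy, not used in the original notebook, is early stopping on validation loss:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Stop when val_loss stops improving and restore the best weights (editor's addition)\n",
+ "from tensorflow.keras.callbacks import EarlyStopping\n",
+ "early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)\n",
+ "model.fit(X_train, y_train, batch_size=64, epochs=15,\n",
+ " validation_data=(X_test, y_test), callbacks=[early_stop])"
+ ]
+ },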
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.save('story_gen.h5')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 59,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "title = \"An oversized t-shirt\"\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "title_sequences = tokenizer.texts_to_sequences([title])  # list-wrap: a bare string is split per character\n",
+ "title_sequences = pad_sequences(title_sequences, maxlen=max_title_length)  # match the model's input length"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "predictions = model.predict(title_sequences)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input Title: Spider Man\n",
+ "Predicted Tags: [('murder',)]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tensorflow.keras.models import load_model\n",
+ "from tensorflow.keras.preprocessing.text import tokenizer_from_json  # used below; this import was missing\n",
+ "with open('tokenizer.json', 'r') as f:\n",
+ " tokenizer = tokenizer_from_json(f.read())\n",
+ "\n",
+ "model = load_model('story_gen.h5')\n",
+ "\n",
+ "example_title = \"Spider Man\"\n",
+ "\n",
+ "example_sequence = tokenizer.texts_to_sequences([example_title])\n",
+ "example_sequence = pad_sequences(example_sequence, maxlen=max_title_length)\n",
+ "\n",
+ "predictions = model.predict(example_sequence)\n",
+ "\n",
+ "predicted_tags = mlb.inverse_transform((predictions > 0.5).astype(int))\n",
+ "\n",
+ "print(\"Input Title:\", example_title)\n",
+ "print(\"Predicted Tags:\", predicted_tags)"
+ ]
+ },
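+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Editor's sketch (added cell):* a fixed 0.5 threshold often yields zero or one tag for rare labels, as in the `[('murder',)]` output above. An alternative, not in the original notebook, is to take the k highest-probability tags instead."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Top-k decoding instead of a hard threshold (editor's addition)\n",
+ "import numpy as np\n",
+ "k = 3\n",
+ "top_idx = np.argsort(predictions[0])[::-1][:k]   # indices of the k largest scores\n",
+ "top_tags = [(mlb.classes_[i], float(predictions[0][i])) for i in top_idx]\n",
+ "print(\"Top\", k, \"tags:\", top_tags)"
+ ]
+ },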
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "base",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.9"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
story_gen.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb406194f44af2207635abf33c6c54cbbdb4e06ed14bdb3ef434b98fa806ecfb
+ size 15675428
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff