{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import opendatasets as od" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading mpst-movie-plot-synopses-with-tags.zip to .\\mpst-movie-plot-synopses-with-tags\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 28.8M/28.8M [00:07<00:00, 3.81MB/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "od.download('https://www.kaggle.com/datasets/cryptexcode/mpst-movie-plot-synopses-with-tags')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "\n", "# Build the path portably: a hard-coded Windows separator only works on Windows\n", "# (and is an invalid escape sequence inside a normal Python string literal).\n", "df = pd.read_csv(os.path.join('mpst-movie-plot-synopses-with-tags', 'mpst_full_data.csv'))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
imdb_idtitleplot_synopsistagssplitsynopsis_source
0tt0057603I tre volti della pauraNote: this synopsis is for the orginal Italian...cult, horror, gothic, murder, atmospherictrainimdb
1tt1733125Dungeons & Dragons: The Book of Vile DarknessTwo thousand years ago, Nhagruul the Foul, a s...violencetrainimdb
2tt0033045The Shop Around the CornerMatuschek's, a gift store in Budapest, is the ...romantictestimdb
3tt0113862Mr. Holland's OpusGlenn Holland, not a morning person by anyone'...inspiring, romantic, stupid, feel-goodtrainimdb
4tt0086250ScarfaceIn May 1980, a Cuban man named Tony Montana (A...cruelty, murder, dramatic, cult, violence, atm...valimdb
\n", "
" ], "text/plain": [ " imdb_id title \\\n", "0 tt0057603 I tre volti della paura \n", "1 tt1733125 Dungeons & Dragons: The Book of Vile Darkness \n", "2 tt0033045 The Shop Around the Corner \n", "3 tt0113862 Mr. Holland's Opus \n", "4 tt0086250 Scarface \n", "\n", " plot_synopsis \\\n", "0 Note: this synopsis is for the orginal Italian... \n", "1 Two thousand years ago, Nhagruul the Foul, a s... \n", "2 Matuschek's, a gift store in Budapest, is the ... \n", "3 Glenn Holland, not a morning person by anyone'... \n", "4 In May 1980, a Cuban man named Tony Montana (A... \n", "\n", " tags split synopsis_source \n", "0 cult, horror, gothic, murder, atmospheric train imdb \n", "1 violence train imdb \n", "2 romantic test imdb \n", "3 inspiring, romantic, stupid, feel-good train imdb \n", "4 cruelty, murder, dramatic, cult, violence, atm... val imdb " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install gpt-2-simple" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['imdb_id', 'title', 'plot_synopsis', 'tags', 'split',\n", " 'synopsis_source'],\n", " dtype='object')" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.model_selection import train_test_split\n", "from tensorflow.keras.preprocessing.text import Tokenizer\n", "from tensorflow.keras.preprocessing.sequence import pad_sequences\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Embedding, LSTM, Dense, Flatten\n", "from sklearn.preprocessing import MultiLabelBinarizer" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "df = 
df[['title', 'plot_synopsis', 'tags']]" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "tokenizer = Tokenizer()\n", "tokenizer.fit_on_texts(df['title'])\n", "title_sequences = tokenizer.texts_to_sequences(df['title'])\n", "max_title_length = max(len(seq) for seq in title_sequences)\n", "title_sequences = pad_sequences(title_sequences, maxlen=max_title_length)" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "tags = [tag.split(', ') for tag in df['tags']]\n", "mlb = MultiLabelBinarizer()\n", "tags = mlb.fit_transform(tags)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "tokenizer_json = tokenizer.to_json()\n", "with open('tokenizer.json', 'w') as json_file:\n", " json_file.write(tokenizer_json)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(title_sequences, tags, test_size=0.2, random_state=42)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "vocab_size = len(tokenizer.word_index) + 1\n", "embedding_dim = 100" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train on 11862 samples, validate on 2966 samples\n", "Epoch 1/15\n", "11862/11862 [==============================] - 10s 826us/sample - loss: 0.1911 - accuracy: 0.9457 - val_loss: 0.1417 - val_accuracy: 0.9569\n", "Epoch 2/15\n", "11862/11862 [==============================] - 11s 887us/sample - loss: 0.1390 - accuracy: 0.9583 - val_loss: 0.1416 - val_accuracy: 0.9569\n", "Epoch 3/15\n", "11862/11862 [==============================] - 11s 941us/sample - loss: 0.1388 - accuracy: 0.9583 - val_loss: 0.1415 - val_accuracy: 0.9569\n", "Epoch 4/15\n", "11862/11862 [==============================] - 11s 916us/sample - loss: 0.1367 - accuracy: 0.9583 
- val_loss: 0.1420 - val_accuracy: 0.9568\n", "Epoch 5/15\n", "11862/11862 [==============================] - 11s 906us/sample - loss: 0.1310 - accuracy: 0.9595 - val_loss: 0.1433 - val_accuracy: 0.9567\n", "Epoch 6/15\n", "11862/11862 [==============================] - 11s 909us/sample - loss: 0.1248 - accuracy: 0.9608 - val_loss: 0.1444 - val_accuracy: 0.9569\n", "Epoch 7/15\n", "11862/11862 [==============================] - 11s 911us/sample - loss: 0.1184 - accuracy: 0.9624 - val_loss: 0.1461 - val_accuracy: 0.9564\n", "Epoch 8/15\n", "11862/11862 [==============================] - 11s 948us/sample - loss: 0.1123 - accuracy: 0.9649 - val_loss: 0.1484 - val_accuracy: 0.9562\n", "Epoch 9/15\n", "11862/11862 [==============================] - 11s 916us/sample - loss: 0.1069 - accuracy: 0.9668 - val_loss: 0.1509 - val_accuracy: 0.9552\n", "Epoch 10/15\n", "11862/11862 [==============================] - 11s 921us/sample - loss: 0.1021 - accuracy: 0.9682 - val_loss: 0.1537 - val_accuracy: 0.9550\n", "Epoch 11/15\n", "11862/11862 [==============================] - 11s 932us/sample - loss: 0.0978 - accuracy: 0.9692 - val_loss: 0.1566 - val_accuracy: 0.9541\n", "Epoch 12/15\n", "11862/11862 [==============================] - 11s 927us/sample - loss: 0.0937 - accuracy: 0.9700 - val_loss: 0.1591 - val_accuracy: 0.9540\n", "Epoch 13/15\n", "11862/11862 [==============================] - 11s 927us/sample - loss: 0.0896 - accuracy: 0.9710 - val_loss: 0.1621 - val_accuracy: 0.9536\n", "Epoch 14/15\n", "11862/11862 [==============================] - 11s 954us/sample - loss: 0.0857 - accuracy: 0.9719 - val_loss: 0.1660 - val_accuracy: 0.9536\n", "Epoch 15/15\n", "11862/11862 [==============================] - 12s 1ms/sample - loss: 0.0820 - accuracy: 0.9729 - val_loss: 0.1690 - val_accuracy: 0.9538\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "model = Sequential()\n", 
"model.add(Embedding(vocab_size, embedding_dim, input_length=max_title_length))\n", "model.add(LSTM(100))\n", "model.add(Dense(tags.shape[1], activation='sigmoid'))\n", "\n", "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n", "\n", "\n", "model.fit(X_train, y_train, batch_size=64, epochs=15, validation_data=(X_test, y_test))\n" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "model.save('story_gen.h5')" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "title = \"A oversized t-shirt\"\n" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "# texts_to_sequences expects a list of texts; passing the bare string would tokenize it\n", "# character by character. Pad to the training length, and use a new name so the\n", "# training matrix `title_sequences` is not overwritten.\n", "new_title_sequence = pad_sequences(tokenizer.texts_to_sequences([title]), maxlen=max_title_length)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictions = model.predict(new_title_sequence)\n", "# Decode the sigmoid outputs back to tag names using the same 0.5 threshold as below.\n", "mlb.inverse_transform((predictions > 0.5).astype(int))" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input Title: Spider Man\n", "Predicted Tags: [('murder',)]\n" ] } ], "source": [ "from tensorflow.keras.models import load_model\n", "from tensorflow.keras.preprocessing.text import tokenizer_from_json\n", "\n", "# Reload the persisted tokenizer and model, then tag an example title end to end.\n", "with open('tokenizer.json', 'r') as f:\n", "    tokenizer = tokenizer_from_json(f.read())\n", "\n", "model = load_model('story_gen.h5')\n", "\n", "example_title = \"Spider Man\"\n", "\n", "# Wrap in a list (one text) and pad to the length the model was trained on.\n", "example_sequence = tokenizer.texts_to_sequences([example_title])\n", "example_sequence = pad_sequences(example_sequence, maxlen=max_title_length)\n", "\n", "predictions = model.predict(example_sequence)\n", "\n", "predicted_tags = mlb.inverse_transform((predictions > 0.5).astype(int))\n", "\n", "print(\"Input Title:\", example_title)\n", "print(\"Predicted Tags:\", predicted_tags)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { 
"name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }