diff --git "a/NoteBooks/Model.ipynb" "b/NoteBooks/Model.ipynb" new file mode 100644--- /dev/null +++ "b/NoteBooks/Model.ipynb" @@ -0,0 +1,3185 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Overview" + ], + "metadata": { + "id": "rfaK4bgJwoHD" + } + }, + { + "cell_type": "markdown", + "source": [ + "This notebook is used to train a matrix factorization model for recommendation.
\n", + "We'll consider the implicit features in the MovieLens100k dataset.
\n", + "We'll use tensorflow recommenders to achieve this." + ], + "metadata": { + "id": "MK2nPnXRwqLV" + } + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qA00wBE2Ntdm" + }, + "source": [ + "## Import TFRS\n", + "\n", + "First, install and import TFRS and needed packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "6yzAaM85Z12D" + }, + "outputs": [], + "source": [ + "!pip install -q tensorflow_recommenders" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "n3oYt3R6Nr9l" + }, + "outputs": [], + "source": [ + "from typing import Dict, Text\n", + "import tensorflow as tf\n", + "import tensorflow_recommenders as tfrs\n", + "# import urllib.request\n", + "# import zipfile\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "source": [ + "# python version: 3.10.11\n", + "tf.__version__, tfrs.__version__" + ], + "metadata": { + "id": "IXyCV5VaijJg", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f8be1658-7cf4-48b4-e5f4-46d619cbb418" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "('2.15.0', 'v0.7.3')" + ] + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# python version: 3.10.11\n", + "tf.__version__, tfrs.__version__" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RXbl9WKn3h6z", + "outputId": "07cef93f-4031-4ff8-e62a-02dbac0877e0" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "('2.15.0', 'v0.7.3')" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Load, prepare and split data" + ], + "metadata": { + "id": "XhM4W1t6keqc" + } + }, + { + "cell_type": "code", + "source": [ + "from sklearn.model_selection import train_test_split" + ], + "metadata": { + "id": "L2yFTSyV3MAt" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ratings = pd.read_csv('ratings.csv')\n", + "movies = pd.read_csv('movies.csv')" + ], + "metadata": { + "id": "RmZbNGNqOtv-" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ratings.head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "QPZtsEu4PVfE", + "outputId": "020692d9-8e32-4502-83c2-596cf49e3cd6" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " userId movieId rating timestamp\n", + "0 1 1 4.0 964982703\n", + "1 1 3 4.0 964981247\n", + "2 1 6 4.0 964982224\n", + "3 1 47 5.0 964983815\n", + "4 1 50 5.0 964982931" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
userIdmovieIdratingtimestamp
0114.0964982703
1134.0964981247
2164.0964982224
31475.0964983815
41505.0964982931
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "ratings" + } + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "code", + "source": [ + "ratings.info()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dDmrB9PFk7Qj", + "outputId": "11334f5d-42e7-4087-ba5a-5ddfd04f180a" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "RangeIndex: 100836 entries, 0 to 100835\n", + "Data columns (total 4 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 userId 100836 non-null int64 \n", + " 1 movieId 100836 non-null int64 \n", + " 2 rating 100836 non-null float64\n", + " 3 timestamp 100836 non-null int64 \n", + "dtypes: float64(1), int64(3)\n", + "memory usage: 3.1 MB\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "movies.head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "H-kSygb2k-DM", + "outputId": "49da1f1f-ad97-42a9-b71c-82d2f22204c0" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " movieId title \\\n", + "0 1 Toy Story (1995) \n", + "1 2 Jumanji (1995) \n", + "2 3 Grumpier Old Men (1995) \n", + "3 4 Waiting to Exhale (1995) \n", + "4 5 Father of the Bride Part II (1995) \n", + "\n", + " genres \n", + "0 Adventure|Animation|Children|Comedy|Fantasy \n", + "1 Adventure|Children|Fantasy \n", + "2 Comedy|Romance \n", + "3 Comedy|Drama|Romance \n", + "4 Comedy " + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
movieIdtitlegenres
01Toy Story (1995)Adventure|Animation|Children|Comedy|Fantasy
12Jumanji (1995)Adventure|Children|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama|Romance
45Father of the Bride Part II (1995)Comedy
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "movies", + "summary": "{\n \"name\": \"movies\",\n \"rows\": 9742,\n \"fields\": [\n {\n \"column\": \"movieId\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 52160,\n \"min\": 1,\n \"max\": 193609,\n \"num_unique_values\": 9742,\n \"samples\": [\n 45635,\n 1373,\n 7325\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"title\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 9737,\n \"samples\": [\n \"Teenage Mutant Ninja Turtles (2014)\",\n \"America's Sweethearts (2001)\",\n \"Cast Away (2000)\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"genres\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 951,\n \"samples\": [\n \"Crime|Mystery|Romance|Thriller\",\n \"Action|Adventure|Comedy|Western\",\n \"Crime|Drama|Musical\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 9 + } + ] + }, + { + "cell_type": "code", + "source": [ + "movies.info()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hYMFHFE0k_0G", + "outputId": "8d690028-7aa7-4c40-9d76-945efa164cd6" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "RangeIndex: 9742 entries, 0 to 9741\n", + "Data columns (total 3 columns):\n", + " # Column Non-Null Count Dtype \n", + "--- ------ -------------- ----- \n", + " 0 movieId 9742 non-null int64 \n", + " 1 title 9742 non-null object\n", + " 2 genres 9742 non-null object\n", + "dtypes: int64(1), object(2)\n", + "memory usage: 228.5+ KB\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "ratings = ratings.merge(movies, on='movieId', how='inner')[['userId', 'title', 'rating', 'timestamp']].rename(columns={'title': 'movieTitle'})\n", + "ratings.head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "ERk9t87Dy0O0", + "outputId": "4d9b0451-60db-4864-ab55-b0c355ccc692" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " userId movieTitle rating timestamp\n", + "0 1 Toy Story (1995) 4.0 964982703\n", + "1 5 Toy Story (1995) 4.0 847434962\n", + "2 7 Toy Story (1995) 4.5 1106635946\n", + "3 15 Toy Story (1995) 2.5 1510577970\n", + "4 17 Toy Story (1995) 4.5 1305696483" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
userIdmovieTitleratingtimestamp
01Toy Story (1995)4.0964982703
15Toy Story (1995)4.0847434962
27Toy Story (1995)4.51106635946
315Toy Story (1995)2.51510577970
417Toy Story (1995)4.51305696483
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "ratings" + } + }, + "metadata": {}, + "execution_count": 11 + } + ] + }, + { + "cell_type": "code", + "source": [ + "movies = movies.rename(columns={'title':'movieTitle'})\n", + "movies" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 423 + }, + "id": "7KiLC-gIRU7K", + "outputId": "899a0f40-eb70-4368-a07b-9ba5c54602cc" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " movieId movieTitle \\\n", + "0 1 Toy Story (1995) \n", + "1 2 Jumanji (1995) \n", + "2 3 Grumpier Old Men (1995) \n", + "3 4 Waiting to Exhale (1995) \n", + "4 5 Father of the Bride Part II (1995) \n", + "... ... ... \n", + "9737 193581 Black Butler: Book of the Atlantic (2017) \n", + "9738 193583 No Game No Life: Zero (2017) \n", + "9739 193585 Flint (2017) \n", + "9740 193587 Bungo Stray Dogs: Dead Apple (2018) \n", + "9741 193609 Andrew Dice Clay: Dice Rules (1991) \n", + "\n", + " genres \n", + "0 Adventure|Animation|Children|Comedy|Fantasy \n", + "1 Adventure|Children|Fantasy \n", + "2 Comedy|Romance \n", + "3 Comedy|Drama|Romance \n", + "4 Comedy \n", + "... ... \n", + "9737 Action|Animation|Comedy|Fantasy \n", + "9738 Animation|Comedy|Fantasy \n", + "9739 Drama \n", + "9740 Action|Animation \n", + "9741 Comedy \n", + "\n", + "[9742 rows x 3 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
movieIdmovieTitlegenres
01Toy Story (1995)Adventure|Animation|Children|Comedy|Fantasy
12Jumanji (1995)Adventure|Children|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama|Romance
45Father of the Bride Part II (1995)Comedy
............
9737193581Black Butler: Book of the Atlantic (2017)Action|Animation|Comedy|Fantasy
9738193583No Game No Life: Zero (2017)Animation|Comedy|Fantasy
9739193585Flint (2017)Drama
9740193587Bungo Stray Dogs: Dead Apple (2018)Action|Animation
9741193609Andrew Dice Clay: Dice Rules (1991)Comedy
\n", + "

9742 rows × 3 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "movies", + "summary": "{\n \"name\": \"movies\",\n \"rows\": 9742,\n \"fields\": [\n {\n \"column\": \"movieId\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 52160,\n \"min\": 1,\n \"max\": 193609,\n \"num_unique_values\": 9742,\n \"samples\": [\n 45635,\n 1373,\n 7325\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"movieTitle\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 9737,\n \"samples\": [\n \"Teenage Mutant Ninja Turtles (2014)\",\n \"America's Sweethearts (2001)\",\n \"Cast Away (2000)\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"genres\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 951,\n \"samples\": [\n \"Crime|Mystery|Romance|Thriller\",\n \"Action|Adventure|Comedy|Western\",\n \"Crime|Drama|Musical\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 12 + } + ] + }, + { + "cell_type": "code", + "source": [ + "ratings['userId'] = ratings['userId'].map(lambda id_int: str(id_int))\n", + "movies['movieId'] = movies['movieId'].map(lambda id_int: str(id_int))" + ], + "metadata": { + "id": "TFn8P4_Ms72Z" + }, + "execution_count": 13, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_valid , test = train_test_split(ratings, test_size=0.2, stratify=ratings['userId'], random_state=42)" + ], + "metadata": { + "id": "qPviRgX_ppSf" + }, + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train, valid = train_test_split(train_valid, test_size=0.1, stratify=train_valid['userId'], random_state=42)" + ], + "metadata": { + "id": "sYqrSSMHjv0F" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train.head()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "kLMvev6Kzn2j", + "outputId": "6acfc427-b924-4103-fd9c-d8240056e09a" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " userId movieTitle rating timestamp\n", + "65913 298 I Am Sam (2001) 0.5 1447598721\n", + "82997 68 Volcano (1997) 3.0 1269123535\n", + "61517 477 Waterboy, The (1998) 3.0 1200943122\n", + "81164 448 Green Lantern (2011) 1.5 1308418333\n", + "72010 57 Howards End (1992) 3.0 972174279" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
userIdmovieTitleratingtimestamp
65913298I Am Sam (2001)0.51447598721
8299768Volcano (1997)3.01269123535
61517477Waterboy, The (1998)3.01200943122
81164448Green Lantern (2011)1.51308418333
7201057Howards End (1992)3.0972174279
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "train", + "summary": "{\n \"name\": \"train\",\n \"rows\": 72601,\n \"fields\": [\n {\n \"column\": \"userId\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 610,\n \"samples\": [\n \"147\",\n \"353\",\n \"371\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"movieTitle\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 8631,\n \"samples\": [\n \"Doug's 1st Movie (1999)\",\n \"Mouse Hunt (1997)\",\n \"Glitter (2001)\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"rating\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1.042519934056567,\n \"min\": 0.5,\n \"max\": 5.0,\n \"num_unique_values\": 10,\n \"samples\": [\n 1.0,\n 3.0,\n 4.5\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"timestamp\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 216228203,\n \"min\": 828124615,\n \"max\": 1537799250,\n \"num_unique_values\": 63043,\n \"samples\": [\n 1339546356,\n 1485656568,\n 1179179328\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 16 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Cold Start Problem" + ], + "metadata": { + "id": "j31d4CXz_2q3" + } + }, + { + "cell_type": "markdown", + "source": [ + "For the cold start problem (new users with no history or guests with no accounts), we'll use aggregates about the movies to show the highest rated movies and most viewed movies (since we don't have the count of views, we'll use the count of ratings instead)\n", + "\n", + "\n", + "We'll create a custom class to handle this.
\n", + "We'll use thresholds to weed out movies with few ratings and movies with low ratings" + ], + "metadata": { + "id": "BzKY7RcT_6RF" + } + }, + { + "cell_type": "code", + "source": [ + "class MovieData:\n", + " def __init__(self, data, rating_threshold, count_threshold):\n", + " self.rating_threshold = rating_threshold\n", + " self.count_threshold = count_threshold\n", + " self.data = data\n", + "\n", + " def get_highest_rated(self, n=20):\n", + " # Return top n rated movies rated at least self.count_threhold times\n", + " ratings_count = self.data.groupby(['movieTitle'])['rating'].count()\n", + " popular_movies = ratings_count[ratings_count>self.count_threshold].index\n", + " highest_rated_movies = self.data[self.data['movieTitle'].isin(popular_movies)].groupby('movieTitle').mean('rating')['rating'].sort_values(ascending=False)[:n]\n", + " return highest_rated_movies\n", + "\n", + " def get_most_rated(self, n=20):\n", + " # Return top n most rated movies with average rating more than self.rating_threhold times\n", + " average_rating = self.data.groupby(['movieTitle'])['rating'].mean('rating')\n", + " popular_movies = average_rating[average_rating>self.rating_threshold].index\n", + " most_rated_movies = self.data[self.data['movieTitle'].isin(popular_movies)].groupby('movieTitle').count()['userId'].sort_values(ascending=False)[:n]\n", + " return most_rated_movies" + ], + "metadata": { + "id": "8CizyCdS8__j" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Data Preparation" + ], + "metadata": { + "id": "EaG_BCFbkjUD" + } + }, + { + "cell_type": "markdown", + "source": [ + "We'll create a tf dataset object for our train and test sets" + ], + "metadata": { + "id": "SoLYZEBuAfYx" + } + }, + { + "cell_type": "code", + "source": [ + "train_interaction_dataset = tf.data.Dataset.from_tensor_slices({'userId':train['userId'].values, 'movieTitle': train['movieTitle'].values})\n", + "valid_interaction_dataset = tf.data.Dataset.from_tensor_slices({'userId':valid['userId'].values, 'movieTitle': valid['movieTitle'].values})\n", + "test_interaction_dataset = tf.data.Dataset.from_tensor_slices({'userId':test['userId'].values, 'movieTitle': test['movieTitle'].values})\n", + "train_interaction_dataset" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "30a7be58-6d14-44ad-ae87-85c123c8e503", + "id": "h6bR0rsJ11m_" + }, + "execution_count": 18, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "<_TensorSliceDataset element_spec={'userId': TensorSpec(shape=(), dtype=tf.string, name=None), 'movieTitle': TensorSpec(shape=(), dtype=tf.string, name=None)}>" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ] + }, + { + "cell_type": "code", + "source": [ + "movie_dataset = tf.data.Dataset.from_tensor_slices(movies['movieTitle'].values)\n", + "movie_dataset" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "9ddba984-a6ae-4392-bc66-fe0f89115c9b", + "id": "sRXBt0o011nA" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "<_TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "FDUuUafmz29I" + }, + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "user_ids_vocabulary = tf.keras.layers.StringLookup(mask_token=None, name='users_lookup')\n", + "movie_titles_vocabulary = tf.keras.layers.StringLookup(mask_token=None, name='movies_lookup')" + ], + "metadata": { + "id": "wRnkJTnouAnc" + }, + "execution_count": 20, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "user_ids_vocabulary.adapt(train_interaction_dataset.map(lambda x: x['userId']))" + ], + "metadata": { + "id": "NHskzAhwuGsJ" + }, + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "movie_titles_vocabulary.adapt(movie_dataset.map(lambda x: x))" + ], + "metadata": { + "id": "UEQj2OyYMpjI" + }, + "execution_count": 22, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "n_users = user_ids_vocabulary.vocabulary_size()\n", + "n_movies = movie_titles_vocabulary.vocabulary_size()\n", + "n_users, n_movies" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "IWzCq8vwqKAO", + "outputId": "92933395-c60b-4139-bb6a-586be308f6fb" + }, + "execution_count": 23, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(611, 9738)" + ] + }, + "metadata": {}, + "execution_count": 23 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lrch6rVBOB9Q" + }, + "source": [ + "## Define a model\n", + "We will use matrix factorization model without context features.\n", + "We can define a TFRS model by inheriting from `tfrs.Model` and implementing the `compute_loss` method:" + ] + }, + { + "cell_type": "markdown", + "source": [ + "The task is a convenient object that wraps both the loss and the metrics" + ], + "metadata": { + "id": "9f2XJPNaAohY" + } + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "e5dNbDZwOIHR" + }, + "outputs": [], + "source": [ + "class MovieLensModel(tfrs.Model):\n", + " def __init__(self, user_model: tf.keras.Model, movie_model: tf.keras.Model, task: tfrs.tasks.Retrieval):\n", + " super().__init__()\n", + "\n", + " # Set up user and movie representations.\n", + " self.user_model = user_model\n", + " self.movie_model = movie_model\n", + "\n", + " # Set up a retrieval task.\n", + " self.task = task\n", + "\n", + " def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:\n", + " # Define how the loss is computed.\n", + " user_embeddings = self.user_model(features[\"userId\"])\n", + " movie_embeddings = self.movie_model(features[\"movieTitle\"])\n", + " return self.task(user_embeddings, movie_embeddings)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wdwtgUCEOI8y" + }, + "source": [ + "Define the two models and the retrieval task." + ] + }, + { + "cell_type": "code", + "source": [ + "movie_model = tf.keras.Sequential([\n", + " movie_titles_vocabulary,\n", + " tf.keras.layers.Embedding(n_movies, 64, name='movie_embedding')\n", + "], name='movie_model')" + ], + "metadata": { + "id": "AL0p56OGNeS1" + }, + "execution_count": 25, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "user_model = tf.keras.Sequential([\n", + " user_ids_vocabulary,\n", + " tf.keras.layers.Embedding(n_users, 64, name='user_embedding')\n", + "], name='user_model')" + ], + "metadata": { + "id": "90Ksu3nTNeS2" + }, + "execution_count": 26, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "ks is the k for top_k metrics. We use multiple ks" + ], + "metadata": { + "id": "WXM-rOyTAwNu" + } + }, + { + "cell_type": "code", + "source": [ + "task = tfrs.tasks.Retrieval(metrics=tfrs.metrics.FactorizedTopK(\n", + " candidates=movie_dataset.batch(128).map(movie_model),\n", + " ks = (1, 5, 10)\n", + " )\n", + ")" + ], + "metadata": { + "id": "ZpmNCMvgrq6y" + }, + "execution_count": 27, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BMV0HpzmJGWk" + }, + "source": [ + "## Fit and evaluate it.\n", + "\n", + "Create the model, train it, and generate predictions:\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [ + "model = MovieLensModel(user_model, movie_model, task)\n", + "model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.01))" + ], + "metadata": { + "id": "5tNGoZcVsRjb" + }, + "execution_count": 28, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model.fit(train_interaction_dataset.batch(4096), epochs=15, validation_data=valid_interaction_dataset.batch(1024))" + ], + "metadata": { + "id": "09g9rkwcvieN", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "1f80e7d9-cef0-40c2-c8e9-1b89c99e2c9f" + }, + "execution_count": 29, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/15\n", + "18/18 [==============================] - 33s 2s/step - factorized_top_k/top_1_categorical_accuracy: 4.1322e-05 - factorized_top_k/top_5_categorical_accuracy: 4.8209e-04 - factorized_top_k/top_10_categorical_accuracy: 8.4021e-04 - loss: 32982.6449 - regularization_loss: 0.0000e+00 - total_loss: 32982.6449 - val_factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - val_factorized_top_k/top_10_categorical_accuracy: 3.7189e-04 - val_loss: 6114.4116 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6114.4116\n", + "Epoch 2/15\n", + "18/18 [==============================] - 30s 2s/step - factorized_top_k/top_1_categorical_accuracy: 3.0303e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0027 - factorized_top_k/top_10_categorical_accuracy: 0.0052 - loss: 32968.2183 - regularization_loss: 0.0000e+00 - total_loss: 32968.2183 - val_factorized_top_k/top_1_categorical_accuracy: 3.7189e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0026 - val_factorized_top_k/top_10_categorical_accuracy: 0.0045 - val_loss: 6112.6470 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6112.6470\n", + "Epoch 3/15\n", + "18/18 [==============================] - 31s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0013 - factorized_top_k/top_5_categorical_accuracy: 0.0109 - factorized_top_k/top_10_categorical_accuracy: 0.0207 - loss: 32938.6277 - regularization_loss: 0.0000e+00 - total_loss: 32938.6277 - val_factorized_top_k/top_1_categorical_accuracy: 4.9585e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0053 - val_factorized_top_k/top_10_categorical_accuracy: 0.0102 - val_loss: 6106.6216 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6106.6216\n", + "Epoch 4/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0161 - factorized_top_k/top_10_categorical_accuracy: 0.0321 - loss: 32879.1801 - regularization_loss: 0.0000e+00 - total_loss: 32879.1801 - val_factorized_top_k/top_1_categorical_accuracy: 1.2396e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0071 - val_factorized_top_k/top_10_categorical_accuracy: 0.0140 - val_loss: 6094.2881 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6094.2881\n", + "Epoch 5/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0021 - factorized_top_k/top_5_categorical_accuracy: 0.0179 - factorized_top_k/top_10_categorical_accuracy: 0.0349 - loss: 32783.4482 - regularization_loss: 0.0000e+00 - total_loss: 32783.4482 - val_factorized_top_k/top_1_categorical_accuracy: 2.4792e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0079 - val_factorized_top_k/top_10_categorical_accuracy: 0.0164 - val_loss: 6075.7705 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6075.7705\n", + "Epoch 6/15\n", + "18/18 [==============================] - 30s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0023 - factorized_top_k/top_5_categorical_accuracy: 0.0186 - factorized_top_k/top_10_categorical_accuracy: 0.0356 - loss: 32655.7899 - regularization_loss: 0.0000e+00 - total_loss: 32655.7899 - val_factorized_top_k/top_1_categorical_accuracy: 3.7189e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0081 - val_factorized_top_k/top_10_categorical_accuracy: 0.0171 - val_loss: 6052.8936 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6052.8936\n", + "Epoch 7/15\n", + "18/18 [==============================] - 30s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0185 - factorized_top_k/top_10_categorical_accuracy: 0.0356 - loss: 32506.7054 - regularization_loss: 0.0000e+00 - total_loss: 32506.7054 - val_factorized_top_k/top_1_categorical_accuracy: 6.1981e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0079 - val_factorized_top_k/top_10_categorical_accuracy: 0.0188 - val_loss: 6028.0103 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6028.0103\n", + "Epoch 8/15\n", + "18/18 [==============================] - 34s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0020 - factorized_top_k/top_5_categorical_accuracy: 0.0183 - factorized_top_k/top_10_categorical_accuracy: 0.0351 - loss: 32347.6164 - regularization_loss: 0.0000e+00 - total_loss: 32347.6164 - val_factorized_top_k/top_1_categorical_accuracy: 8.6773e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0087 - val_factorized_top_k/top_10_categorical_accuracy: 0.0196 - val_loss: 6003.1934 - val_regularization_loss: 0.0000e+00 - val_total_loss: 6003.1934\n", + "Epoch 9/15\n", + "18/18 [==============================] - 31s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0179 - factorized_top_k/top_10_categorical_accuracy: 0.0353 - loss: 32187.7872 - regularization_loss: 0.0000e+00 - total_loss: 32187.7872 - val_factorized_top_k/top_1_categorical_accuracy: 0.0011 - val_factorized_top_k/top_5_categorical_accuracy: 0.0083 - val_factorized_top_k/top_10_categorical_accuracy: 0.0197 - val_loss: 5979.8662 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5979.8662\n", + "Epoch 10/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0178 - factorized_top_k/top_10_categorical_accuracy: 0.0352 - loss: 32033.2512 - regularization_loss: 0.0000e+00 - total_loss: 32033.2512 - val_factorized_top_k/top_1_categorical_accuracy: 8.6773e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0089 - val_factorized_top_k/top_10_categorical_accuracy: 0.0203 - val_loss: 5958.7803 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5958.7803\n", + "Epoch 11/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0179 - factorized_top_k/top_10_categorical_accuracy: 0.0352 - loss: 31887.0750 - regularization_loss: 0.0000e+00 - total_loss: 31887.0750 - val_factorized_top_k/top_1_categorical_accuracy: 8.6773e-04 - val_factorized_top_k/top_5_categorical_accuracy: 0.0088 - val_factorized_top_k/top_10_categorical_accuracy: 0.0197 - val_loss: 5940.1665 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5940.1665\n", + "Epoch 12/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0016 - factorized_top_k/top_5_categorical_accuracy: 0.0178 - factorized_top_k/top_10_categorical_accuracy: 0.0354 - loss: 31750.1961 - regularization_loss: 0.0000e+00 - total_loss: 31750.1961 - val_factorized_top_k/top_1_categorical_accuracy: 0.0014 - val_factorized_top_k/top_5_categorical_accuracy: 0.0088 - val_factorized_top_k/top_10_categorical_accuracy: 0.0202 - val_loss: 5923.9399 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5923.9399\n", + "Epoch 13/15\n", + "18/18 [==============================] - 31s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0018 - factorized_top_k/top_5_categorical_accuracy: 0.0181 - factorized_top_k/top_10_categorical_accuracy: 0.0360 - loss: 31622.3099 - regularization_loss: 0.0000e+00 - total_loss: 31622.3099 - val_factorized_top_k/top_1_categorical_accuracy: 0.0017 - val_factorized_top_k/top_5_categorical_accuracy: 0.0093 - val_factorized_top_k/top_10_categorical_accuracy: 0.0196 - val_loss: 5909.8633 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5909.8633\n", + "Epoch 14/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0178 - factorized_top_k/top_10_categorical_accuracy: 0.0359 - loss: 31502.5373 - regularization_loss: 0.0000e+00 - total_loss: 31502.5373 - val_factorized_top_k/top_1_categorical_accuracy: 0.0014 - val_factorized_top_k/top_5_categorical_accuracy: 0.0095 - val_factorized_top_k/top_10_categorical_accuracy: 0.0193 - val_loss: 5897.6528 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5897.6528\n", + "Epoch 15/15\n", + "18/18 [==============================] - 29s 2s/step - factorized_top_k/top_1_categorical_accuracy: 0.0017 - factorized_top_k/top_5_categorical_accuracy: 0.0180 - factorized_top_k/top_10_categorical_accuracy: 0.0361 - loss: 31389.8361 - regularization_loss: 0.0000e+00 - total_loss: 31389.8361 - val_factorized_top_k/top_1_categorical_accuracy: 0.0014 - val_factorized_top_k/top_5_categorical_accuracy: 0.0104 - val_factorized_top_k/top_10_categorical_accuracy: 0.0182 - val_loss: 5887.0361 - val_regularization_loss: 0.0000e+00 - val_total_loss: 5887.0361\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ] + }, + { + "cell_type": "code", + "source": [ + "model.evaluate(test_interaction_dataset.batch(1024))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kdJwU-w-3uDN", + "outputId": "848dbf08-5639-4cf2-e855-6b8da91d2f57" + }, + "execution_count": 30, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "20/20 [==============================] - 8s 407ms/step - factorized_top_k/top_1_categorical_accuracy: 9.4209e-04 - factorized_top_k/top_5_categorical_accuracy: 0.0080 - factorized_top_k/top_10_categorical_accuracy: 0.0175 - loss: 6631.0894 - regularization_loss: 0.0000e+00 - total_loss: 6631.0894\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[0.0009420864516869187,\n", + " 0.007982943207025528,\n", + " 0.017502974718809128,\n", + " 4515.97607421875,\n", + " 0,\n", + " 4515.97607421875]" + ] + }, + "metadata": {}, + "execution_count": 30 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# model.evaluate(test_interaction_dataset.batch(1024))" + ], + "metadata": { + "id": "6atHi6wqq9HD" + }, + "execution_count": 31, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Indexers" + ], + "metadata": { + "id": "Q3dGH_I7A_ZO" + } + }, + { + "cell_type": "markdown", + "source": [ + "Indexers use store the embedding of the possible candidates as keys. When it receives a query, it embeds the query and retrieves the closest keys.\n", + "\n", + "For our recommendation task, it stores the embeddings of movies and the embedding of users. When we want to recommend for a user, it gets the movies whose embedding are the most similar (using dot product) to the user." + ], + "metadata": { + "id": "FTzjq0qLBCEX" + } + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "neJAJVwbReNd" + }, + "outputs": [], + "source": [ + "# Use brute-force search to set up retrieval using the trained representations.\n", + "user_recommender = tfrs.layers.factorized_top_k.BruteForce(model.user_model, k=100)" + ] + }, + { + "cell_type": "code", + "source": [ + "user_recommender.index_from_dataset(\n", + " movie_dataset.batch(100).map(lambda title: (title, model.movie_model(title))))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "_GXqrR69UeWl", + "outputId": "6a7c9f13-83eb-4057-d42f-25525e259c4b" + }, + "execution_count": 33, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 33 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Get some recommendations.\n", + "_, titles = user_recommender(tf.constant([\"42\"]))\n", + "print(f\"Top 3 recommendations for user 42: {titles.shape}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gHnxsVx1UfmF", + "outputId": "bfbe4bb2-972f-49e1-8f9d-5910a76aae58" + }, + "execution_count": 34, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Top 3 recommendations for user 42: (1, 100)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "#### Item-Item recommendation" + ], + "metadata": { + "id": "VPI7iMrtBfvk" + } + }, + { + "cell_type": "markdown", + "source": [ + "For items similarity, we can use the embedding of movies as both query and keys" + ], + "metadata": { + "id": "ugcrEkD8BjJu" + } + }, + { + "cell_type": "code", + "source": [ + "movie_recommender = tfrs.layers.factorized_top_k.BruteForce(model.movie_model, k=100)" + ], + "metadata": { + "id": "mb_nnuziM73X" + }, + "execution_count": 35, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "movie_recommender.index_from_dataset(\n", + " movie_dataset.batch(100).map(lambda title: (title, model.movie_model(title))))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "kbX-Uob8NH9y", + "outputId": "7d9f8a2a-09c1-4cdb-ee56-37c6f1c14d14" + }, + "execution_count": 36, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 36 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Get some recommendations.\n", + "_, titles2 = movie_recommender(tf.constant([\"Freaky Friday (2003)\"]))\n", + "print(f\"Top 3 recommendations for movie 42: {titles2.shape}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Fy8ijm54PYuX", + "outputId": "c1683598-675a-43bf-e80b-e567ddf60fa5" + }, + "execution_count": 37, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Top 3 recommendations for movie 42: (1, 100)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Get some recommendations.\n", + "_, titles2 = movie_recommender(tf.constant([\"Freaky Friday (2003)\"]), k=25)\n", + "print(f\"Top 3 recommendations for movie 42: {titles2.shape}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ym-Ppz3OthFI", + "outputId": "ba083203-d097-4d1a-c08d-aa884188b777" + }, + "execution_count": 38, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Top 3 recommendations for movie 42: (1, 25)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Saving the models" + ], + "metadata": { + "id": "7d3vv_Kr7rQn" + } + }, + { + "cell_type": "code", + "source": [ + "user_recommender.save('user_model')\n", + "movie_recommender.save('movie_model')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1LVFawoR7uTY", + "outputId": "44e7c8d9-1b77-4463-d814-44f3e9b84c94" + }, + "execution_count": 39, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "tmp1 = tf.keras.models.load_model('user_model')\n", + "tmp2 = tf.keras.models.load_model('movie_model')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "x10fuwFg711U", + "outputId": "7e12ed41-c6cb-4f88-aab2-80a2af4d5307" + }, + "execution_count": 40, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n", + "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Get some recommendations.\n", + "_, titles = tmp1(tf.constant([\"42\"]))\n", + "print(f\"Top 3 recommendations for user 42: {titles.shape}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YcEvv-XP8GyX", + "outputId": "cc7621fc-b213-4aaf-a1bb-ea67f10b9232" + }, + "execution_count": 41, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Top 3 recommendations for user 42: (1, 100)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Get some recommendations.\n", + "_, titles2 = tmp2(tf.constant([\"Freaky Friday (2003)\"]))\n", + "print(f\"Top 3 recommendations for movie 42: {titles2[0, :10]}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4Yo5ZIsY8Jx3", + "outputId": "2f91dc46-b584-4e25-c886-1636cc40b68d" + }, + "execution_count": 42, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Top 3 recommendations for movie 42: [b'Freaky Friday (2003)' b'What Women Want (2000)'\n", + " b'Wedding Crashers (2005)' b'Shrek the Third (2007)'\n", + " b'Atlantis: The Lost Empire (2001)' b'Along Came Polly (2004)'\n", + " b'Princess Diaries, The (2001)'\n", + " b\"Hitchhiker's Guide to the Galaxy, The (2005)\" b'Holes (2003)'\n", + " b'Mean Girls (2004)']\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Explicit rating" + ], + "metadata": { + "id": "jkYqtqcNFKBV" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Data Prepatation" + ], + "metadata": { + "id": "X58OruuYFt2s" + } + }, + { + "cell_type": "code", + "source": [ + "train_rating_dataset = tf.data.Dataset.from_tensor_slices({'userId':train['userId'].values, 'movieTitle': train['movieTitle'].values, 'rating': train['rating'].values})\n", + "valid_rating_dataset = tf.data.Dataset.from_tensor_slices({'userId':valid['userId'].values, 'movieTitle': valid['movieTitle'].values, 'rating': valid['rating'].values})\n", + "test_rating_dataset = tf.data.Dataset.from_tensor_slices({'userId':test['userId'].values, 'movieTitle': test['movieTitle'].values, 'rating': test['rating'].values})\n", + "train_rating_dataset" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PuiVqj_vFOU1", + "outputId": "7015eeb6-a40c-470b-bd9a-cafe41de798a" + }, + "execution_count": 43, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "<_TensorSliceDataset element_spec={'userId': TensorSpec(shape=(), dtype=tf.string, name=None), 'movieTitle': TensorSpec(shape=(), dtype=tf.string, name=None), 'rating': TensorSpec(shape=(), dtype=tf.float64, name=None)}>" + ] + }, + "metadata": {}, + "execution_count": 43 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# user_ids_vocabulary.adapt(train_rating_dataset.map(lambda x: x['userId']))" + ], + "metadata": { + "id": "MabLW_CEMUWn" + }, + "execution_count": 44, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# movie_titles_vocabulary.adapt(movie_dataset.map(lambda x: x))" + ], + "metadata": { + "id": "1Ab9OgACMUWn" + }, + "execution_count": 45, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# n_users = user_ids_vocabulary.vocabulary_size()\n", + "# n_movies = movie_titles_vocabulary.vocabulary_size()\n", + "# n_users, n_movies" + ], + "metadata": { + "id": "OKp9tGC3MUWn" + }, + "execution_count": 46, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Define model" + ], + "metadata": { + "id": "lesKfVayFvli" + } + }, + { + "cell_type": "code", + "source": [ + "ranking_task = tfrs.tasks.Ranking(\n", + " loss = tf.keras.losses.MeanSquaredError(),\n", + " metrics=[tf.keras.metrics.RootMeanSquaredError()]\n", + " )" + ], + "metadata": { + "id": "OCw39lK8GZ-_" + }, + "execution_count": 47, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "rating_model = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(256, activation='relu'),\n", + " tf.keras.layers.Dense(64, activation='relu'),\n", + " tf.keras.layers.Dense(1)\n", + "], name='raing_model')" + ], + "metadata": { + "id": "3Rk14hLWGaYv" + }, + "execution_count": 48, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "explicit_movie_model = tf.keras.Sequential([\n", + " movie_titles_vocabulary,\n", + " tf.keras.layers.Embedding(n_movies, 64, name='movie_embedding')\n", + "], name='movie_model')" + ], + "metadata": { + "id": "kyTH12nfMMRR" + }, + "execution_count": 49, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "explicit_user_model = tf.keras.Sequential([\n", + " user_ids_vocabulary,\n", + " tf.keras.layers.Embedding(n_users, 64, name='user_embedding')\n", + "], name='user_model')" + ], + "metadata": { + "id": "aL27lNSsMMRR" + }, + "execution_count": 50, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class ExplicitMovieLensModel(tfrs.Model):\n", + " def __init__(self, user_model: tf.keras.Model, movie_model: tf.keras.Model, rating_model:tf.keras.Model, task: tfrs.tasks.Retrieval):\n", + " super().__init__()\n", + "\n", + " # Set up user and movie representations.\n", + " self.user_model = user_model\n", + " self.movie_model = movie_model\n", + "\n", + " # Compute predictions.\n", + " self.rating_model = rating_model\n", + "\n", + " # Set up a ranking task.\n", + " self.task = task\n", + "\n", + " def call(self, features: Dict[str, tf.Tensor]) -> tf.Tensor:\n", + " user_embeddings = self.user_model(features[\"userId\"])\n", + " movie_embeddings = self.movie_model(features[\"movieTitle\"])\n", + " return self.rating_model(tf.concat([user_embeddings, movie_embeddings], axis=1))\n", + "\n", + " def compute_loss(self, features: Dict[Text, tf.Tensor], training=False) -> tf.Tensor:\n", + " labels = features.pop(\"rating\")\n", + "\n", + " rating_predictions = self(features)\n", + "\n", + " # The task computes the loss and the metrics.\n", + " return self.task(labels=labels, predictions=rating_predictions)" + ], + "metadata": { + "id": "ShsW910MF4Qy" + }, + "execution_count": 51, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "explicit_model = ExplicitMovieLensModel(user_model=explicit_user_model, movie_model=explicit_movie_model, rating_model=rating_model, task=ranking_task)\n", + "explicit_model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.05))" + ], + "metadata": { + "id": "Fbw7_TETLy2w" + }, + "execution_count": 52, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "explicit_model.fit(train_rating_dataset.batch(4096), epochs=40, validation_data=valid_rating_dataset.batch(1024))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "U3wvyygHTte6", + "outputId": "7a6ae84e-b54f-42dc-83bd-ff86a79f0410" + }, + "execution_count": 53, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/40\n", + "18/18 [==============================] - 2s 67ms/step - root_mean_squared_error: 1.6292 - loss: 2.5479 - regularization_loss: 0.0000e+00 - total_loss: 2.5479 - val_root_mean_squared_error: 1.0355 - val_loss: 1.1405 - val_regularization_loss: 0.0000e+00 - val_total_loss: 1.1405\n", + "Epoch 2/40\n", + "18/18 [==============================] - 1s 69ms/step - root_mean_squared_error: 1.0238 - loss: 1.0466 - regularization_loss: 0.0000e+00 - total_loss: 1.0466 - val_root_mean_squared_error: 1.0144 - val_loss: 1.0892 - val_regularization_loss: 0.0000e+00 - val_total_loss: 1.0892\n", + "Epoch 3/40\n", + "18/18 [==============================] - 1s 72ms/step - root_mean_squared_error: 1.0014 - loss: 1.0015 - regularization_loss: 0.0000e+00 - total_loss: 1.0015 - val_root_mean_squared_error: 0.9911 - val_loss: 1.0299 - val_regularization_loss: 0.0000e+00 - val_total_loss: 1.0299\n", + "Epoch 4/40\n", + "18/18 [==============================] - 1s 76ms/step - root_mean_squared_error: 0.9802 - loss: 0.9601 - regularization_loss: 0.0000e+00 - total_loss: 0.9601 - val_root_mean_squared_error: 0.9713 - val_loss: 0.9761 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.9761\n", + "Epoch 5/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.9640 - loss: 0.9292 - regularization_loss: 0.0000e+00 - total_loss: 0.9292 - val_root_mean_squared_error: 0.9569 - val_loss: 0.9362 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.9362\n", + "Epoch 6/40\n", + "18/18 [==============================] - 1s 41ms/step - root_mean_squared_error: 0.9525 - loss: 0.9077 - regularization_loss: 0.0000e+00 - total_loss: 0.9077 - val_root_mean_squared_error: 0.9470 - val_loss: 0.9091 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.9091\n", + "Epoch 7/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.9434 - loss: 0.8908 - regularization_loss: 0.0000e+00 - total_loss: 0.8908 - val_root_mean_squared_error: 0.9385 - val_loss: 0.8879 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8879\n", + "Epoch 8/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.9343 - loss: 0.8737 - regularization_loss: 0.0000e+00 - total_loss: 0.8737 - val_root_mean_squared_error: 0.9305 - val_loss: 0.8693 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8693\n", + "Epoch 9/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.9255 - loss: 0.8575 - regularization_loss: 0.0000e+00 - total_loss: 0.8575 - val_root_mean_squared_error: 0.9236 - val_loss: 0.8542 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8542\n", + "Epoch 10/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.9178 - loss: 0.8433 - regularization_loss: 0.0000e+00 - total_loss: 0.8433 - val_root_mean_squared_error: 0.9179 - val_loss: 0.8419 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8419\n", + "Epoch 11/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.9108 - loss: 0.8304 - regularization_loss: 0.0000e+00 - total_loss: 0.8304 - val_root_mean_squared_error: 0.9129 - val_loss: 0.8314 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8314\n", + "Epoch 12/40\n", + "18/18 [==============================] - 1s 43ms/step - root_mean_squared_error: 0.9042 - loss: 0.8184 - regularization_loss: 0.0000e+00 - total_loss: 0.8184 - val_root_mean_squared_error: 0.9085 - val_loss: 0.8222 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8222\n", + "Epoch 13/40\n", + "18/18 [==============================] - 1s 55ms/step - root_mean_squared_error: 0.8980 - loss: 0.8071 - regularization_loss: 0.0000e+00 - total_loss: 0.8071 - val_root_mean_squared_error: 0.9045 - val_loss: 0.8141 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8141\n", + "Epoch 14/40\n", + "18/18 [==============================] - 1s 67ms/step - root_mean_squared_error: 0.8921 - loss: 0.7965 - regularization_loss: 0.0000e+00 - total_loss: 0.7965 - val_root_mean_squared_error: 0.9010 - val_loss: 0.8068 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8068\n", + "Epoch 15/40\n", + "18/18 [==============================] - 1s 71ms/step - root_mean_squared_error: 0.8866 - loss: 0.7866 - regularization_loss: 0.0000e+00 - total_loss: 0.7866 - val_root_mean_squared_error: 0.8978 - val_loss: 0.8002 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.8002\n", + "Epoch 16/40\n", + "18/18 [==============================] - 1s 41ms/step - root_mean_squared_error: 0.8814 - loss: 0.7774 - regularization_loss: 0.0000e+00 - total_loss: 0.7774 - val_root_mean_squared_error: 0.8949 - val_loss: 0.7943 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7943\n", + "Epoch 17/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8765 - loss: 0.7687 - regularization_loss: 0.0000e+00 - total_loss: 0.7687 - val_root_mean_squared_error: 0.8923 - val_loss: 0.7890 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7890\n", + "Epoch 18/40\n", + "18/18 [==============================] - 1s 43ms/step - root_mean_squared_error: 0.8719 - loss: 0.7606 - regularization_loss: 0.0000e+00 - total_loss: 0.7606 - val_root_mean_squared_error: 0.8899 - val_loss: 0.7841 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7841\n", + "Epoch 19/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8676 - loss: 0.7529 - regularization_loss: 0.0000e+00 - total_loss: 0.7529 - val_root_mean_squared_error: 0.8877 - val_loss: 0.7797 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7797\n", + "Epoch 20/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.8634 - loss: 0.7457 - regularization_loss: 0.0000e+00 - total_loss: 0.7457 - val_root_mean_squared_error: 0.8857 - val_loss: 0.7757 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7757\n", + "Epoch 21/40\n", + "18/18 [==============================] - 1s 41ms/step - root_mean_squared_error: 0.8596 - loss: 0.7389 - regularization_loss: 0.0000e+00 - total_loss: 0.7389 - val_root_mean_squared_error: 0.8839 - val_loss: 0.7721 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7721\n", + "Epoch 22/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8559 - loss: 0.7326 - regularization_loss: 0.0000e+00 - total_loss: 0.7326 - val_root_mean_squared_error: 0.8824 - val_loss: 0.7689 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7689\n", + "Epoch 23/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8525 - loss: 0.7267 - regularization_loss: 0.0000e+00 - total_loss: 0.7267 - val_root_mean_squared_error: 0.8809 - val_loss: 0.7660 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7660\n", + "Epoch 24/40\n", + "18/18 [==============================] - 1s 39ms/step - root_mean_squared_error: 0.8493 - loss: 0.7211 - regularization_loss: 0.0000e+00 - total_loss: 0.7211 - val_root_mean_squared_error: 0.8796 - val_loss: 0.7634 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7634\n", + "Epoch 25/40\n", + "18/18 [==============================] - 1s 48ms/step - root_mean_squared_error: 0.8462 - loss: 0.7158 - regularization_loss: 0.0000e+00 - total_loss: 0.7158 - val_root_mean_squared_error: 0.8785 - val_loss: 0.7610 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7610\n", + "Epoch 26/40\n", + "18/18 [==============================] - 1s 73ms/step - root_mean_squared_error: 0.8433 - loss: 0.7108 - regularization_loss: 0.0000e+00 - total_loss: 0.7108 - val_root_mean_squared_error: 0.8775 - val_loss: 0.7588 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7588\n", + "Epoch 27/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8405 - loss: 0.7061 - regularization_loss: 0.0000e+00 - total_loss: 0.7061 - val_root_mean_squared_error: 0.8766 - val_loss: 0.7568 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7568\n", + "Epoch 28/40\n", + "18/18 [==============================] - 1s 43ms/step - root_mean_squared_error: 0.8379 - loss: 0.7016 - regularization_loss: 0.0000e+00 - total_loss: 0.7016 - val_root_mean_squared_error: 0.8759 - val_loss: 0.7549 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7549\n", + "Epoch 29/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8354 - loss: 0.6975 - regularization_loss: 0.0000e+00 - total_loss: 0.6975 - val_root_mean_squared_error: 0.8752 - val_loss: 0.7533 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7533\n", + "Epoch 30/40\n", + "18/18 [==============================] - 1s 41ms/step - root_mean_squared_error: 0.8331 - loss: 0.6935 - regularization_loss: 0.0000e+00 - total_loss: 0.6935 - val_root_mean_squared_error: 0.8747 - val_loss: 0.7517 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7517\n", + "Epoch 31/40\n", + "18/18 [==============================] - 1s 41ms/step - root_mean_squared_error: 0.8309 - loss: 0.6898 - regularization_loss: 0.0000e+00 - total_loss: 0.6898 - val_root_mean_squared_error: 0.8742 - val_loss: 0.7502 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7502\n", + "Epoch 32/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8288 - loss: 0.6863 - regularization_loss: 0.0000e+00 - total_loss: 0.6863 - val_root_mean_squared_error: 0.8738 - val_loss: 0.7489 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7489\n", + "Epoch 33/40\n", + "18/18 [==============================] - 1s 43ms/step - root_mean_squared_error: 0.8268 - loss: 0.6830 - regularization_loss: 0.0000e+00 - total_loss: 0.6830 - val_root_mean_squared_error: 0.8735 - val_loss: 0.7476 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7476\n", + "Epoch 34/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8249 - loss: 0.6798 - regularization_loss: 0.0000e+00 - total_loss: 0.6798 - val_root_mean_squared_error: 0.8732 - val_loss: 0.7464 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7464\n", + "Epoch 35/40\n", + "18/18 [==============================] - 1s 40ms/step - root_mean_squared_error: 0.8231 - loss: 0.6768 - regularization_loss: 0.0000e+00 - total_loss: 0.6768 - val_root_mean_squared_error: 0.8729 - val_loss: 0.7452 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7452\n", + "Epoch 36/40\n", + "18/18 [==============================] - 1s 72ms/step - root_mean_squared_error: 0.8214 - loss: 0.6738 - regularization_loss: 0.0000e+00 - total_loss: 0.6738 - val_root_mean_squared_error: 0.8728 - val_loss: 0.7441 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7441\n", + "Epoch 37/40\n", + "18/18 [==============================] - 1s 39ms/step - root_mean_squared_error: 0.8197 - loss: 0.6710 - regularization_loss: 0.0000e+00 - total_loss: 0.6710 - val_root_mean_squared_error: 0.8726 - val_loss: 0.7430 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7430\n", + "Epoch 38/40\n", + "18/18 [==============================] - 1s 42ms/step - root_mean_squared_error: 0.8181 - loss: 0.6684 - regularization_loss: 0.0000e+00 - total_loss: 0.6684 - val_root_mean_squared_error: 0.8725 - val_loss: 0.7420 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7420\n", + "Epoch 39/40\n", + "18/18 [==============================] - 1s 41ms/step - root_mean_squared_error: 0.8165 - loss: 0.6658 - regularization_loss: 0.0000e+00 - total_loss: 0.6658 - val_root_mean_squared_error: 0.8724 - val_loss: 0.7411 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7411\n", + "Epoch 40/40\n", + "18/18 [==============================] - 1s 39ms/step - root_mean_squared_error: 0.8151 - loss: 0.6634 - regularization_loss: 0.0000e+00 - total_loss: 0.6634 - val_root_mean_squared_error: 0.8724 - val_loss: 0.7402 - val_regularization_loss: 0.0000e+00 - val_total_loss: 0.7402\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 53 + } + ] + }, + { + "cell_type": "code", + "source": [ + "explicit_model.evaluate(test_rating_dataset.batch(1024))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4JyRiwnqm8ab", + "outputId": "c4fad5b0-63fe-4aa1-d49a-85b292143e80" + }, + "execution_count": 54, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "20/20 [==============================] - 0s 7ms/step - root_mean_squared_error: 0.8827 - loss: 0.7832 - regularization_loss: 0.0000e+00 - total_loss: 0.7832\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[0.8827441334724426, 0.8436796069145203, 0, 0.8436796069145203]" + ] + }, + "metadata": {}, + "execution_count": 54 + } + ] + }, + { + "cell_type": "code", + "source": [ + "all_movies = movies['movieTitle'].unique().reshape(-1,1)" + ], + "metadata": { + "id": "V0EIvf7da1AI" + }, + "execution_count": 55, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Predict rating for all movies\n", + "preds = explicit_model({\"userId\": tf.tile([['42']], [9737, 1]), \"movieTitle\": all_movies})\n", + "preds" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KKSVNey9cZTw", + "outputId": "be179eeb-98be-4686-e764-fe1d34382bc9" + }, + "execution_count": 56, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 56 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# Sort movie titles from highest rated to lowest\n", + "tf.gather(all_movies, tf.squeeze(tf.argsort(preds, axis=0, direction='DESCENDING')))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "rkQPNaaTcKpA", + "outputId": "18415d2d-06f5-4038-d17a-8aae3f8029ab" + }, + "execution_count": 57, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 57 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# save model\n", + "explicit_model.save('explicit_model')" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hUkuvsLegqYH", + "outputId": "84b04a97-00cf-4d4d-e008-6e6a9a69ce8c" + }, + "execution_count": 58, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n", + "WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config.\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "explicit_loaded = tf.saved_model.load(\"explicit_model\")" + ], + "metadata": { + "id": "bK6rJ2hXhEeB" + }, + "execution_count": 59, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Note: Saved model takes input as 1-d array\n", + "preds = explicit_loaded({\"userId\": tf.tile(['42'], [9737]), \"movieTitle\": movies['movieTitle'].unique()})\n", + "tf.gather(all_movies, tf.squeeze(tf.argsort(preds, axis=0, direction='DESCENDING'))).numpy()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0MjPEu6pmuL_", + "outputId": "50226467-471b-4017-b796-f21b0b06634f" + }, + "execution_count": 60, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([[b'Wallace & Gromit: The Best of Aardman Animation (1996)'],\n", + " [b'Andalusian Dog, An (Chien andalou, Un) (1929)'],\n", + " [b'Come and See (Idi i smotri) (1985)'],\n", + " ...,\n", + " [b'Speed 2: Cruise Control (1997)'],\n", + " [b'Battlefield Earth (2000)'],\n", + " [b'Jason X (2002)']], dtype=object)" + ] + }, + "metadata": {}, + "execution_count": 60 + } + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "collapsed_sections": [ + "qA00wBE2Ntdm", + "zCxQ1CZcO2wh", + "XhM4W1t6keqc", + "j31d4CXz_2q3", + "EaG_BCFbkjUD", + "Lrch6rVBOB9Q", + "BMV0HpzmJGWk", + "Q3dGH_I7A_ZO", + "7d3vv_Kr7rQn", + "X58OruuYFt2s", + "r44DDjPWUsW6" + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file