{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Clustering genre names with transformer embeddings\n",
    "\n",
    "Each genre string in `all_genres.csv` is embedded with a pretrained Hugging Face encoder, the embeddings are grouped with K-means, and the result is shown as a table with one column per cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple\n",
    "\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn.cluster import KMeans\n",
    "from transformers import (\n",
    "    AutoModel, AutoModelForMaskedLM, AutoTokenizer,\n",
    "    BertModel, BertTokenizer,\n",
    "    DistilBertModel, DistilBertTokenizer,\n",
    "    RobertaModel, RobertaTokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global configuration.\n",
    "# Hugging Face checkpoints that can serve as the text encoder. Every entry needs its\n",
    "# own comma: adjacent string literals are implicitly concatenated, so a missing comma\n",
    "# silently merges two names into one bogus entry.\n",
    "encoder_options = [\n",
    "    'distilbert-base-uncased',\n",
    "    'bert-base-uncased',\n",
    "    'bert-base-cased',\n",
    "    'roberta-base',\n",
    "    'xlm-roberta-base',\n",
    "]\n",
    "\n",
    "current_encoder = encoder_options[0]\n",
    "tokenizer = None  # set by the model-loading cell below\n",
    "model = None  # set by the model-loading cell below\n",
    "\n",
    "# Genre strings to cluster: the \"genre\" column of the CSV, in file order.\n",
    "genres = pd.read_csv(\"./all_genres.csv\")[\"genre\"].to_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight']\n",
      "- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# Load the tokenizer/encoder pair for `current_encoder`. Every model class here is a\n",
    "# bare encoder (AutoModel rather than AutoModelForMaskedLM for XLM-R), so its output\n",
    "# exposes the `last_hidden_state` that `embed_string` reads.\n",
    "ENCODER_CLASSES = {\n",
    "    'distilbert-base-uncased': (DistilBertTokenizer, DistilBertModel),\n",
    "    'bert-base-uncased': (BertTokenizer, BertModel),\n",
    "    'bert-base-cased': (BertTokenizer, BertModel),\n",
    "    'roberta-base': (RobertaTokenizer, RobertaModel),\n",
    "    'xlm-roberta-base': (AutoTokenizer, AutoModel),\n",
    "}\n",
    "\n",
    "if current_encoder not in ENCODER_CLASSES:\n",
    "    raise ValueError(f\"unknown encoder {current_encoder!r}; expected one of {list(ENCODER_CLASSES)}\")\n",
    "tokenizer_cls, model_cls = ENCODER_CLASSES[current_encoder]\n",
    "tokenizer = tokenizer_cls.from_pretrained(current_encoder)\n",
    "model = model_cls.from_pretrained(current_encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_string(texts: list[str] | None = None, pooling: str = 'last') -> np.ndarray:\n",
    "    \"\"\"Embed each text as one vector using the global `tokenizer` and `model`.\n",
    "\n",
    "    Args:\n",
    "        texts: strings to embed; defaults to the global `genres` list.\n",
    "        pooling: how one text's per-token hidden states are reduced to a vector:\n",
    "            'last'  - the final token (the default; NOTE(review): with the tokenizer's\n",
    "                      default special tokens this is presumably [SEP] / </s>),\n",
    "            'first' - the first token (presumably [CLS] / <s>),\n",
    "            'mean'  - the average over all tokens.\n",
    "\n",
    "    Returns:\n",
    "        float64 array with one row per text, in input order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `pooling` is unrecognised or there are no texts to embed.\n",
    "    \"\"\"\n",
    "    if texts is None:\n",
    "        texts = genres\n",
    "    if pooling not in ('last', 'first', 'mean'):\n",
    "        raise ValueError(f\"unknown pooling {pooling!r}; expected 'last', 'first' or 'mean'\")\n",
    "    if len(texts) == 0:\n",
    "        raise ValueError(\"no texts to embed\")\n",
    "\n",
    "    rows = []\n",
    "    # Inference only: no_grad avoids building an autograd graph for every forward pass.\n",
    "    with torch.no_grad():\n",
    "        for text in texts:\n",
    "            encoded_input = tokenizer(text, return_tensors='pt')\n",
    "            hidden = model(**encoded_input).last_hidden_state  # (1, seq_len, hidden_size)\n",
    "            if pooling == 'last':\n",
    "                vector = hidden[:, -1, :]\n",
    "            elif pooling == 'first':\n",
    "                vector = hidden[:, 0, :]\n",
    "            else:\n",
    "                vector = hidden.mean(dim=1)\n",
    "            rows.append(vector.flatten().cpu().numpy())\n",
    "    return np.array(rows, dtype=np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
"source": [
    "def gen_clusters(input_strs: np.ndarray, num_clusters: int, random_state: int | None = None) -> Tuple[KMeans, np.ndarray, float]:\n",
    "    \"\"\"Fit K-means to the embeddings and measure how tightly they cluster.\n",
    "\n",
    "    Args:\n",
    "        input_strs: 2-D array of embeddings, one row per string.\n",
    "        num_clusters: number of clusters K.\n",
    "        random_state: seed forwarded to KMeans; pass an int for reproducible runs.\n",
    "\n",
    "    Returns:\n",
    "        (fitted KMeans, cluster label of each row, total error), where the error is the\n",
    "        sum over rows of the Euclidean distance between the row and its assigned\n",
    "        centroid. `input_strs` is not modified.\n",
    "    \"\"\"\n",
    "    clustering_algo = KMeans(n_clusters=num_clusters, random_state=random_state)\n",
    "    predicted_labels = clustering_algo.fit_predict(input_strs)\n",
    "\n",
    "    # Subtract before squaring: np.square's second positional argument is the `out`\n",
    "    # buffer, so np.square(center, row) would overwrite the row instead of measuring\n",
    "    # the distance between them.\n",
    "    assigned_centers = clustering_algo.cluster_centers_[predicted_labels]\n",
    "    distances = np.linalg.norm(np.asarray(input_strs) - assigned_centers, axis=1)\n",
    "    cluster_error = float(distances.sum())\n",
    "\n",
    "    return clustering_algo, predicted_labels, cluster_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def view_clusters(predicted_clusters: np.ndarray, names: list[str] | None = None) -> pd.DataFrame:\n",
    "    \"\"\"Tabulate a clustering with one column per cluster.\n",
    "\n",
    "    Column `cluster_i` lists, in input order, every name whose label is i, padded with\n",
    "    '' so that all columns are as long as the largest cluster.\n",
    "\n",
    "    Args:\n",
    "        predicted_clusters: non-negative integer cluster label for each name.\n",
    "        names: the strings being labelled; defaults to the global `genres` list.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns cluster_0 .. cluster_<max label>. A label in that range\n",
    "        with no members (which can happen when labels come from `KMeans.predict`)\n",
    "        becomes an all-'' column; empty input gives an empty DataFrame.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `names` and `predicted_clusters` differ in length.\n",
    "    \"\"\"\n",
    "    if names is None:\n",
    "        names = genres\n",
    "    labels = [int(label) for label in predicted_clusters]\n",
    "    if len(labels) != len(names):\n",
    "        raise ValueError(f\"got {len(labels)} labels for {len(names)} names\")\n",
    "    if not labels:\n",
    "        return pd.DataFrame()\n",
    "\n",
    "    members = {cluster: [] for cluster in range(max(labels) + 1)}\n",
    "    for label, name in zip(labels, names):\n",
    "        members[label].append(name)\n",
    "\n",
    "    max_len = max(len(group) for group in members.values())\n",
    "    return pd.DataFrame({\n",
    "        f\"cluster_{cluster}\": group + [''] * (max_len - len(group))\n",
    "        for cluster, group in members.items()\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_new_genre(clustering_algo: KMeans, new_genre: str, recompute: bool = False) -> pd.DataFrame:\n",
    "    \"\"\"Append `new_genre` to the global `genres` list and tabulate the new clustering.\n",
    "\n",
    "    All genres, including the new one, are re-embedded with `embed_string`. With\n",
    "    `recompute=False` the already-fitted `clustering_algo` labels them; with\n",
    "    `recompute=True` a fresh KMeans with the same number of clusters is fitted instead\n",
    "    (that new model is discarded and `clustering_algo` is left unchanged).\n",
    "\n",
    "    Side effect: mutates the global `genres` list, so repeated calls keep appending.\n",
    "\n",
    "    Returns:\n",
    "        The `view_clusters` table for the updated labels.\n",
    "    \"\"\"\n",
    "    # In-place append; no `global` statement is needed because `genres` is not rebound.\n",
    "    genres.append(new_genre)\n",
    "    embedded_genres = embed_string()\n",
    "    if recompute:\n",
    "        _, labels, _ = gen_clusters(embedded_genres, clustering_algo.n_clusters)\n",
    "    else:\n",
    "        labels = clustering_algo.predict(embedded_genres)\n",
    "    return view_clusters(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": 
{},
   "outputs": [],
   "source": [
    "# Embed every genre, cluster the embeddings, and tabulate the clusters.\n",
    "NUM_CLUSTERS = 5\n",
    "\n",
    "embedded_genres = embed_string()\n",
    "assert embedded_genres.shape[0] == len(genres), \"expected one embedding per genre\"\n",
    "clustering_algo, predicted_labels, cluster_error = gen_clusters(embedded_genres, NUM_CLUSTERS)\n",
    "output_df = view_clusters(predicted_labels)\n",
    "output_df"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('ml_env')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "2434bee09bcd67f653a1f2d2df1f4f18cabf9d6c39b42950acaa6ef605d590bc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}