from transformers import (
    DistilBertTokenizer, DistilBertModel,
    BertTokenizer, BertModel,
    RobertaTokenizer, RobertaModel,
    AutoTokenizer, AutoModel,
)
import gradio as gr
import pandas as pd
import numpy as np
from typing import Tuple
from sklearn.cluster import KMeans

# global variables
encoder_options = [
    'distilbert-base-uncased',
    'bert-base-uncased',
    'bert-base-cased',
    'roberta-base',
    'xlm-roberta-base',
]
tokenizer = None
model = None
# expects a CSV with a "genre" column, one genre per row
genres = pd.read_csv("./all_genres.csv")
genres = set(genres["genre"].to_list())


def update_models(current_encoder: str) -> None:
    """Load the tokenizer and encoder selected in the UI."""
    global model, tokenizer
    if current_encoder == 'distilbert-base-uncased':
        tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-uncased'
        )
        model = DistilBertModel.from_pretrained('distilbert-base-uncased')
    elif current_encoder == 'bert-base-uncased':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertModel.from_pretrained('bert-base-uncased')
    elif current_encoder == 'bert-base-cased':
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        model = BertModel.from_pretrained('bert-base-cased')
    elif current_encoder == 'roberta-base':
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaModel.from_pretrained('roberta-base')
    elif current_encoder == 'xlm-roberta-base':
        tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
        # use the bare encoder (not the masked-LM head) so the forward pass
        # exposes last_hidden_state like the other models
        model = AutoModel.from_pretrained('xlm-roberta-base')


def embed_string() -> np.ndarray:
    """Embed every genre string with the currently selected encoder."""
    output = []
    for text in genres:
        encoded_input = tokenizer(text, return_tensors='pt')
        # forward pass
        new_output = model(**encoded_input)
        to_append = new_output.last_hidden_state
        # use the final token's hidden state as the sentence embedding
        to_append = to_append[:, -1, :]
        to_append = to_append.flatten().detach().cpu().numpy()
        output.append(to_append)
    # stack the per-genre vectors into a single (num_genres, hidden_size) array
    np_output = np.zeros((len(output), output[0].shape[0]))
    for i, vector in enumerate(output):
        np_output[i, :] = vector
    return np_output


def gen_clusters(
    input_strs: np.ndarray,
    num_clusters: int,
) -> Tuple[KMeans, np.ndarray, float]:
    """Cluster the embeddings and total each point's distance to its center."""
    clustering_algo = KMeans(n_clusters=num_clusters)
    predicted_labels = clustering_algo.fit_predict(input_strs)
    cluster_error = 0.0
    for i, predicted_label in enumerate(predicted_labels):
        predicted_center = clustering_algo.cluster_centers_[predicted_label, :]
        # Euclidean distance between the point and its assigned cluster center
        new_error = np.sqrt(
            np.sum(np.square(predicted_center - input_strs[i]))
        )
        cluster_error += new_error
    return clustering_algo, predicted_labels, cluster_error


def view_clusters(predicted_clusters: np.ndarray) -> pd.DataFrame:
    """Arrange the genres into one DataFrame column per cluster."""
    mappings = dict()
    for predicted_cluster, genre in zip(predicted_clusters, genres):
        curr_mapping = mappings.get(predicted_cluster, [])
        curr_mapping.append(genre)
        mappings[predicted_cluster] = curr_mapping
    output_df = pd.DataFrame()
    max_len = max(len(x) for x in mappings.values())
    max_cluster = max(predicted_clusters)
    for i in range(max_cluster + 1):
        new_column_name = f"cluster_{i}"
        new_column_data = mappings[i]
        # pad shorter clusters so every column has the same length
        new_column_data.extend([''] * (max_len - len(new_column_data)))
        output_df[new_column_name] = new_column_data
    return output_df


def add_new_genre(
    new_genre: str = "",
    num_clusters: int = 5,
) -> Tuple[pd.DataFrame, float]:
    """Optionally add a new genre, then re-embed, re-cluster, and report the error."""
    global genres
    if new_genre != "":
        genres.add(new_genre)
    embedded_genres = embed_string()
    _, predicted_labels, error = gen_clusters(embedded_genres, num_clusters)
    output_df = view_clusters(predicted_labels)
    return output_df, error


if __name__ == "__main__":
    with gr.Blocks() as demo:
        current_encoder = gr.Radio(encoder_options, label="Encoder")
        current_encoder.change(fn=update_models, inputs=current_encoder)
        new_genre_input = gr.Textbox(value="", label="New Genre")
        num_clusters_input = gr.Number(
            value=5, precision=0, label="Clusters"
        )
        output_clustering = gr.DataFrame()
        output_error = gr.Number(label="Clustering Error", interactive=False)
        encode_button = gr.Button(value="Run")
        encode_button.click(
            fn=add_new_genre,
            inputs=[new_genre_input, num_clusters_input],
            outputs=[output_clustering, output_error],
        )
    demo.launch()
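
# A minimal sketch of the expected "./all_genres.csv" input (assumed layout:
# a single "genre" column, one genre per row; the genre values below are
# hypothetical placeholders, not part of the original dataset):
#
#   pd.DataFrame(
#       {"genre": ["action", "comedy", "drama", "horror", "romance"]}
#   ).to_csv("./all_genres.csv", index=False)
#
# Usage note: `model` and `tokenizer` start as None and are only set when an
# encoder is chosen in the radio group, so pick an encoder before clicking
# "Run".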