"""
We want to be able to assign a small user text entry to one of our clusters.
"""
import pickle
from collections import Counter

import joblib
import torch
from transformers import GPT2Tokenizer, GPT2Model
### inference demo
# Embed one user text with GPT-2, predict its k-means cluster, then look up
# the stored cluster assignment of a sample mbid for comparison.

# Load the pretrained GPT-2 tokenizer and model.
GPT_tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding=True)
# Assign a pad token so the tokenizer call below can be made with padding=True.
GPT_tokenizer.pad_token = '[PAD]'
GPT_model = GPT2Model.from_pretrained('gpt2')

# Example user request to place into a cluster.
user_example = "we are looking to make some music! please point us to a lovely cluster where we can hear lovely sounds. I like the cranberries."

# Tokenize the request into PyTorch tensors.
encoded_input = GPT_tokenizer(user_example, return_tensors="pt", padding=True, truncation=True)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    outputs = GPT_model(**encoded_input)
    # Use the hidden state at sequence position 0 as the text representation.
    # NOTE(review): GPT-2 has no [CLS] token and is a causal model, so this
    # vector presumably reflects only the first token of the input. Confirm it
    # matches how the k-means training embeddings were produced before
    # changing the pooling.
    cls_embedding = outputs.last_hidden_state[:, 0, :].numpy()

# Load the fitted k-means model and predict the cluster for the embedding.
kmeans_model = joblib.load('GPT_128k_means_model.joblib')
example_cluster = kmeans_model.predict(cls_embedding)
print(example_cluster)

# Load the saved cluster assignments (presumably a dict keyed by mbid).
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('mbid_GPT_128_clusters.pickle', 'rb') as f:
    mbid_clusters = pickle.load(f)
print(type(mbid_clusters))
# print(mbid_clusters)

# Look up the stored cluster for one sample mbid via .get().
sample_mbid = 'bd57a71ece2912664f5e267166a2a1fb'
cluster_assignment = mbid_clusters.get(sample_mbid)
print(cluster_assignment)
# cluster_distribution = Counter(mbid_clusters.values()) | |
# print(cluster_distribution) | |
# # Load the KMeans model | |
# kmeans_model = joblib.load('GPT_512k_means_model.joblib') | |
# # Load the cluster assignments from the pickle file | |
# with open('mbid_GPT_512_clusters.pickle', 'rb') as f: | |
# mbid_clusters = pickle.load(f) | |
# # Now you can access the KMeans model and cluster assignments | |
# # For example, to get the cluster assignments for a specific mbid: | |
# sample_mbid = '2a0a712b4b00f3df2d4fa50fe21f43cb' | |
# cluster_assignment = mbid_clusters.get(sample_mbid) | |
# # To get the distribution of clusters | |
# from collections import Counter | |
# cluster_distribution = Counter(mbid_clusters.values()) | |
# # print(cluster_distribution) | |
# # To check if each article is assigned a cluster | |
# total_articles = len(mbid_clusters) | |
# articles_with_cluster = sum(1 for cluster in mbid_clusters.values() if cluster is not None) | |
# print(f"Total articles: {total_articles}") | |
# print(f"Articles with assigned clusters: {articles_with_cluster}") | |
# # To check different clusters | |
# # Replace 'cluster_number' with the cluster number you want to inspect | |
# cluster_number = 0 | |
# articles_in_cluster = [mbid for mbid, cluster in mbid_clusters.items() if cluster == cluster_number] | |
# #print(f"Articles in cluster {cluster_number}: {articles_in_cluster}") | |
# # for cluster in mbid_clusters: | |
# import joblib | |
# import numpy as np | |
# # vectorizer | |
# from sklearn.feature_extraction.text import HashingVectorizer | |
# # load cluster data pickle file, kmeans model, and vectorizer model | |
# clusters = joblib.load("clusters_data.pickle") | |
# vectorizer = joblib.load("vectorizer.joblib") | |
# kmeans = joblib.load("best_kmeans_model.joblib") | |
# # an example to try | |
# user_example = ["make me and my friends a cool song!"] | |
# # vectorize user example | |
# vectorized_example = vectorizer.transform(user_example) | |
# print(vectorized_example) | |
# # assign a cluster: result is cluster 497 | |
# example_cluster = kmeans.predict(vectorized_example) | |
# print(example_cluster) | |
# # print(type(clusters[497])) | |
# # print(len(clusters[497])) | |
# # print(clusters[497][1]) | |
# # Get the number of data points assigned to each cluster | |
# num_assigned = [len(cluster_data) for cluster_data in clusters.values()] | |
# # Compute mean and standard deviation of the number of data points per cluster | |
# mean_assigned = np.mean(num_assigned) | |
# std_assigned = np.std(num_assigned) | |
# print(f"Mean number of data points per cluster: {mean_assigned}") | |
# print(f"Standard deviation of number of data points per cluster: {std_assigned}") | |
# # Mean number of data points per cluster: 9.694656488549619 | |
# # Standard deviation of number of data points per cluster: 21.820754225240147 | |
# # get a view of some of the clusters | |
# num_samples = 3 | |
# # # Print a short version of some clusters | |
# # for cluster_label, cluster_data in clusters.items(): | |
# # print(f"Cluster {cluster_label}:") | |
# # for i, (mbid, text) in enumerate(cluster_data[:num_samples], 1): | |
# # print(f"Sample {i}: {text[:100]}...") # Print only the first 100 characters of each text | |
# # print() # Add a blank line between clusters |