# HF_wiki_kmeans / cluster_inference.py
# Uploaded by: njcad
# Commit message: "files necessary for running k means inference on a user prompt"
# Commit: cf85eee (verified)
"""
we want to be able to assign a small user text entry to one of our clusters.
"""
import joblib
import pickle
from transformers import GPT2Tokenizer, GPT2Model
import torch
### inference demo
# Load the pretrained GPT-2 tokenizer and encoder.
# GPT-2 ships without a padding token, so reuse its end-of-text token as the
# pad token (the documented idiom); '[PAD]' is not in GPT-2's vocabulary.
# `padding` is an argument of the tokenizer *call*, not of `from_pretrained`,
# so it is passed when encoding below instead.
GPT_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
GPT_tokenizer.pad_token = GPT_tokenizer.eos_token
GPT_model = GPT2Model.from_pretrained('gpt2')
# from_pretrained already returns the model in eval mode; made explicit so
# dropout is guaranteed to be off and embeddings are deterministic.
GPT_model.eval()

# Example prompt to classify.
user_example = "we are looking to make some music! please point us to a lovely cluster where we can hear lovely sounds. I like the cranberries."

# Tokenize into PyTorch tensors; prompts longer than the model's maximum
# input length are truncated.
encoded_input = GPT_tokenizer(user_example, return_tensors="pt", padding=True, truncation=True)

# Compute hidden states without building an autograd graph.
with torch.no_grad():
    outputs = GPT_model(**encoded_input)
    # last_hidden_state is (batch, seq_len, hidden_size); [:, 0, :] keeps the
    # hidden state at position 0 for every prompt -> (batch, hidden_size).
    # NOTE: GPT-2 has no [CLS] token. Because its self-attention is causal
    # (left-to-right), the position-0 state depends ONLY on the first token
    # of the prompt, not on the rest of the text.
    # TODO(review): this pooling is kept unchanged because the k-means
    # centroids must live in the same embedding space used at training time.
    # Confirm how the training embeddings were pooled before switching to
    # e.g. attention-mask-weighted mean pooling.
    cls_embedding = outputs.last_hidden_state[:, 0, :].numpy()

# Load the fitted k-means model (presumably a scikit-learn KMeans; verify)
# and assign the prompt embedding to a cluster.
# NOTE: joblib files are pickle-based; only load model files you trust.
kmeans_model = joblib.load('GPT_128k_means_model.joblib')
example_cluster = kmeans_model.predict(cls_embedding)
print(example_cluster)
from collections import Counter

# Identifier of one article whose cluster we want to inspect.
sample_mbid = 'bd57a71ece2912664f5e267166a2a1fb'

# Read the saved article-id -> cluster-label mapping.
# NOTE: pickle can run arbitrary code while loading; only open trusted files.
with open('mbid_GPT_128_clusters.pickle', 'rb') as cluster_file:
    mbid_clusters = pickle.load(cluster_file)

# Show the container type to confirm the mapping's structure (expected to
# support .get(), e.g. a dict -- verify against the pickle's producer).
print(type(mbid_clusters))

# .get() yields None rather than raising when the id is not in the mapping.
cluster_assignment = mbid_clusters.get(sample_mbid)
print(cluster_assignment)
# cluster_distribution = Counter(mbid_clusters.values())
# print(cluster_distribution)
# # Load the KMeans model
# kmeans_model = joblib.load('GPT_512k_means_model.joblib')
# # Load the cluster assignments from the pickle file
# with open('mbid_GPT_512_clusters.pickle', 'rb') as f:
# mbid_clusters = pickle.load(f)
# # Now you can access the KMeans model and cluster assignments
# # For example, to get the cluster assignments for a specific mbid:
# sample_mbid = '2a0a712b4b00f3df2d4fa50fe21f43cb'
# cluster_assignment = mbid_clusters.get(sample_mbid)
# # To get the distribution of clusters
# from collections import Counter
# cluster_distribution = Counter(mbid_clusters.values())
# # print(cluster_distribution)
# # To check if each article is assigned a cluster
# total_articles = len(mbid_clusters)
# articles_with_cluster = sum(1 for cluster in mbid_clusters.values() if cluster is not None)
# print(f"Total articles: {total_articles}")
# print(f"Articles with assigned clusters: {articles_with_cluster}")
# # To check different clusters
# # Replace 'cluster_number' with the cluster number you want to inspect
# cluster_number = 0
# articles_in_cluster = [mbid for mbid, cluster in mbid_clusters.items() if cluster == cluster_number]
# #print(f"Articles in cluster {cluster_number}: {articles_in_cluster}")
# # for cluster in mbid_clusters:
# import joblib
# import numpy as np
# # vectorizer
# from sklearn.feature_extraction.text import HashingVectorizer
# # load cluster data pickle file, kmeans model, and vectorizer model
# clusters = joblib.load("clusters_data.pickle")
# vectorizer = joblib.load("vectorizer.joblib")
# kmeans = joblib.load("best_kmeans_model.joblib")
# # an example to try
# user_example = ["make me and my friends a cool song!"]
# # vectorize user example
# vectorized_example = vectorizer.transform(user_example)
# print(vectorized_example)
# # assign a cluster: result is cluster 497
# example_cluster = kmeans.predict(vectorized_example)
# print(example_cluster)
# # print(type(clusters[497]))
# # print(len(clusters[497]))
# # print(clusters[497][1])
# # Get the number of data points assigned to each cluster
# num_assigned = [len(cluster_data) for cluster_data in clusters.values()]
# # Compute mean and standard deviation of the number of data points per cluster
# mean_assigned = np.mean(num_assigned)
# std_assigned = np.std(num_assigned)
# print(f"Mean number of data points per cluster: {mean_assigned}")
# print(f"Standard deviation of number of data points per cluster: {std_assigned}")
# # Mean number of data points per cluster: 9.694656488549619
# # Standard deviation of number of data points per cluster: 21.820754225240147
# # get a view of some of the clusters
# num_samples = 3
# # # Print a short version of some clusters
# # for cluster_label, cluster_data in clusters.items():
# # print(f"Cluster {cluster_label}:")
# # for i, (mbid, text) in enumerate(cluster_data[:num_samples], 1):
# # print(f"Sample {i}: {text[:100]}...") # Print only the first 100 characters of each text
# # print() # Add a blank line between clusters