njcad committed on
Commit
cf85eee
1 Parent(s): 0677043

Files necessary for running k-means inference on a user prompt

Browse files
GPT_128k_means_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:377a377d10735695ba3f923def032b65ad62e0c5eaf960cccb55b310aa123fe1
3
+ size 465849
cluster_inference.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Assign a short piece of user-entered text to one of our precomputed k-means clusters.
3
+ """
4
+
5
+ import joblib
6
+ import pickle
7
+ from transformers import GPT2Tokenizer, GPT2Model
8
+ import torch
9
+
10
+ ### inference demo
11
+
12
+ # load the GPT model
13
+ GPT_tokenizer = GPT2Tokenizer.from_pretrained('gpt2', padding=True)
14
+ GPT_tokenizer.pad_token = '[PAD]'
15
+ GPT_model = GPT2Model.from_pretrained('gpt2')
16
+
17
+ # set some user example
18
+ user_example = "we are looking to make some music! please point us to a lovely cluster where we can hear lovely sounds. I like the cranberries."
19
+
20
+ # tokenize the input
21
+ encoded_input = GPT_tokenizer(user_example, return_tensors="pt", padding=True, truncation=True)
22
+
23
+ # generate the embeddings
24
+ with torch.no_grad():
25
+
26
+ # get outputs from GPT model
27
+ outputs = GPT_model(**encoded_input)
28
+
29
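+ # use the last-layer hidden state at position 0 (the first token) as the sequence representation; NOTE(review): GPT-2 has no [CLS] token — confirm this matches how the k-means model was trained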
30
+ cls_embedding = outputs.last_hidden_state[:, 0, :].numpy()
31
+
32
+ # load the kmeans model
33
+ kmeans_model = joblib.load('GPT_128k_means_model.joblib')
34
+
35
+ # do inference
36
+ example_cluster = kmeans_model.predict(cls_embedding)
37
+ print(example_cluster)
38
+
39
+
40
+
41
+ from collections import Counter
42
+ with open('mbid_GPT_128_clusters.pickle', 'rb') as f:
43
+ mbid_clusters = pickle.load(f)
44
+
45
+
46
+ print(type(mbid_clusters))
47
+ # print(mbid_clusters)
48
+ sample_mbid = 'bd57a71ece2912664f5e267166a2a1fb'
49
+ cluster_assignment = mbid_clusters.get(sample_mbid)
50
+ print(cluster_assignment)
51
+
52
+ # cluster_distribution = Counter(mbid_clusters.values())
53
+ # print(cluster_distribution)
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+ # # Load the KMeans model
68
+ # kmeans_model = joblib.load('GPT_512k_means_model.joblib')
69
+
70
+ # # Load the cluster assignments from the pickle file
71
+ # with open('mbid_GPT_512_clusters.pickle', 'rb') as f:
72
+ # mbid_clusters = pickle.load(f)
73
+
74
+ # # Now you can access the KMeans model and cluster assignments
75
+ # # For example, to get the cluster assignments for a specific mbid:
76
+ # sample_mbid = '2a0a712b4b00f3df2d4fa50fe21f43cb'
77
+ # cluster_assignment = mbid_clusters.get(sample_mbid)
78
+
79
+ # # To get the distribution of clusters
80
+ # from collections import Counter
81
+ # cluster_distribution = Counter(mbid_clusters.values())
82
+ # # print(cluster_distribution)
83
+
84
+ # # To check if each article is assigned a cluster
85
+ # total_articles = len(mbid_clusters)
86
+ # articles_with_cluster = sum(1 for cluster in mbid_clusters.values() if cluster is not None)
87
+
88
+ # print(f"Total articles: {total_articles}")
89
+ # print(f"Articles with assigned clusters: {articles_with_cluster}")
90
+
91
+ # # To check different clusters
92
+ # # Replace 'cluster_number' with the cluster number you want to inspect
93
+ # cluster_number = 0
94
+ # articles_in_cluster = [mbid for mbid, cluster in mbid_clusters.items() if cluster == cluster_number]
95
+ # #print(f"Articles in cluster {cluster_number}: {articles_in_cluster}")
96
+ # # for cluster in mbid_clusters:
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+ # import joblib
128
+ # import numpy as np
129
+
130
+ # # vectorizer
131
+ # from sklearn.feature_extraction.text import HashingVectorizer
132
+
133
+ # # load cluster data pickle file, kmeans model, and vectorizer model
134
+ # clusters = joblib.load("clusters_data.pickle")
135
+ # vectorizer = joblib.load("vectorizer.joblib")
136
+ # kmeans = joblib.load("best_kmeans_model.joblib")
137
+
138
+ # # an example to try
139
+ # user_example = ["make me and my friends a cool song!"]
140
+
141
+ # # vectorize user example
142
+ # vectorized_example = vectorizer.transform(user_example)
143
+ # print(vectorized_example)
144
+
145
+ # # assign a cluster: result is cluster 497
146
+ # example_cluster = kmeans.predict(vectorized_example)
147
+ # print(example_cluster)
148
+
149
+ # # print(type(clusters[497]))
150
+ # # print(len(clusters[497]))
151
+ # # print(clusters[497][1])
152
+
153
+
154
+ # # Get the number of data points assigned to each cluster
155
+ # num_assigned = [len(cluster_data) for cluster_data in clusters.values()]
156
+
157
+ # # Compute mean and standard deviation of the number of data points per cluster
158
+ # mean_assigned = np.mean(num_assigned)
159
+ # std_assigned = np.std(num_assigned)
160
+
161
+ # print(f"Mean number of data points per cluster: {mean_assigned}")
162
+ # print(f"Standard deviation of number of data points per cluster: {std_assigned}")
163
+
164
+ # # Mean number of data points per cluster: 9.694656488549619
165
+ # # Standard deviation of number of data points per cluster: 21.820754225240147
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
+
176
+
177
+ # # get a view of some of the clusters
178
+ # num_samples = 3
179
+
180
+ # # # Print a short version of some clusters
181
+ # # for cluster_label, cluster_data in clusters.items():
182
+ # # print(f"Cluster {cluster_label}:")
183
+ # # for i, (mbid, text) in enumerate(cluster_data[:num_samples], 1):
184
+ # # print(f"Sample {i}: {text[:100]}...") # Print only the first 100 characters of each text
185
+ # # print() # Add a blank line between clusters
mbid_GPT_128_clusters.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b86a2b4e34a33765fa38f6b840c565909d1c71884c358974de60afd7cc2225f9
3
+ size 3287814