jharrison27 commited on
Commit
8fd248e
1 Parent(s): 66cf393

revert changes

Browse files
Files changed (1) hide show
  1. app.py +39 -21
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
- from sklearn.metrics.pairwise import cosine_similarity
4
- from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
5
  from sklearn.cluster import KMeans
6
  import numpy as np
7
 
@@ -37,38 +35,58 @@ def embed_words(words, model_name):
37
  embeddings = embedder(words)
38
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
39
 
40
- def cluster_words(words, model_name, method):
41
- embeddings = embed_words(words, model_name)
42
- if method == 'Cosine Similarity':
43
- # Use cosine similarity and hierarchical clustering
44
- sim_matrix = cosine_similarity(embeddings)
45
- Z = linkage(sim_matrix, 'average', metric='cosine')
46
- labels = fcluster(Z, t=4, criterion='maxclust')
47
- elif method == 'K-means':
48
- # Use K-means clustering
49
- kmeans = KMeans(n_clusters=4, random_state=0).fit(embeddings)
50
- labels = kmeans.labels_ + 1
51
- clusters = {i: [] for i in range(1, 5)}
52
- for word, label in zip(words, labels):
53
- clusters[label].append(word)
54
- return clusters
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def display_clusters(clusters):
57
- for i, words in clusters.items():
58
  st.markdown(f"### Group {i+1}")
59
  st.write(", ".join(words))
60
 
61
  def main():
62
  st.title("NYT Connections Solver")
63
  st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
64
- st.write("Select an embedding model and a clustering method from the dropdown menus, then click 'Generate Clusters' to see the grouped words.")
65
 
 
66
  model_name = st.selectbox("Select Embedding Model", list(models.keys()))
67
- clustering_method = st.selectbox("Select Clustering Method", ['K-means', 'Cosine Similarity'])
68
 
69
  if st.button("Generate Clusters"):
70
  with st.spinner("Generating clusters..."):
71
- clusters = cluster_words(mock_words, model_name, clustering_method)
72
  display_clusters(clusters)
73
 
74
  if __name__ == "__main__":
 
1
  import streamlit as st
2
  from transformers import pipeline
 
 
3
  from sklearn.cluster import KMeans
4
  import numpy as np
5
 
 
35
  embeddings = embedder(words)
36
  return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
37
 
38
+ def iterative_clustering(words, model_name):
39
+ remaining_words = words[:]
40
+ grouped_words = []
41
+
42
+ while len(remaining_words) >= 4:
43
+ embeddings = embed_words(remaining_words, model_name)
44
+ kmeans = KMeans(n_clusters=min(4, len(remaining_words) // 4), random_state=0).fit(embeddings)
45
+ clusters = {i: [] for i in range(kmeans.n_clusters)}
46
+ for word, label in zip(remaining_words, kmeans.labels_):
47
+ if len(clusters[label]) < 4:
48
+ clusters[label].append(word)
49
+
50
+ # Select the most cohesive cluster
51
+ best_cluster, best_idx = select_most_cohesive_cluster(clusters, kmeans, embeddings)
52
+
53
+ # Store the best cluster and remove those words
54
+ grouped_words.append(best_cluster)
55
+ remaining_words = [word for word in remaining_words if word not in best_cluster]
56
+
57
+ return grouped_words
58
+
59
+ def select_most_cohesive_cluster(clusters, kmeans_model, embeddings):
60
+ min_distance = float('inf')
61
+ best_cluster = None
62
+ best_idx = -1
63
+ for idx, cluster in clusters.items():
64
+ if len(cluster) == 4:
65
+ cluster_embeddings = embeddings[[i for i, label in enumerate(kmeans_model.labels_) if label == idx]]
66
+ centroid = kmeans_model.cluster_centers_[idx]
67
+ distance = np.mean(np.linalg.norm(cluster_embeddings - centroid, axis=1))
68
+ if distance < min_distance:
69
+ min_distance = distance
70
+ best_cluster = cluster
71
+ best_idx = idx
72
+ return best_cluster, best_idx
73
 
74
  def display_clusters(clusters):
75
+ for i, words in enumerate(clusters):
76
  st.markdown(f"### Group {i+1}")
77
  st.write(", ".join(words))
78
 
79
  def main():
80
  st.title("NYT Connections Solver")
81
  st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
82
+ st.write("Select an embedding model from the dropdown menu and click 'Generate Clusters' to see the grouped words.")
83
 
84
+ # Dropdown menu for selecting the embedding model
85
  model_name = st.selectbox("Select Embedding Model", list(models.keys()))
 
86
 
87
  if st.button("Generate Clusters"):
88
  with st.spinner("Generating clusters..."):
89
+ clusters = iterative_clustering(mock_words, model_name)
90
  display_clusters(clusters)
91
 
92
  if __name__ == "__main__":