Spaces:
Runtime error
Runtime error
Commit
·
8fd248e
1
Parent(s):
66cf393
revert changes
Browse files
app.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
3 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
4 |
-
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster
|
5 |
from sklearn.cluster import KMeans
|
6 |
import numpy as np
|
7 |
|
@@ -37,38 +35,58 @@ def embed_words(words, model_name):
|
|
37 |
embeddings = embedder(words)
|
38 |
return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
|
39 |
|
40 |
-
def
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
clusters
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def display_clusters(clusters):
|
57 |
-
for i, words in clusters
|
58 |
st.markdown(f"### Group {i+1}")
|
59 |
st.write(", ".join(words))
|
60 |
|
61 |
def main():
|
62 |
st.title("NYT Connections Solver")
|
63 |
st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
|
64 |
-
st.write("Select an embedding model
|
65 |
|
|
|
66 |
model_name = st.selectbox("Select Embedding Model", list(models.keys()))
|
67 |
-
clustering_method = st.selectbox("Select Clustering Method", ['K-means', 'Cosine Similarity'])
|
68 |
|
69 |
if st.button("Generate Clusters"):
|
70 |
with st.spinner("Generating clusters..."):
|
71 |
-
clusters =
|
72 |
display_clusters(clusters)
|
73 |
|
74 |
if __name__ == "__main__":
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import pipeline
|
|
|
|
|
3 |
from sklearn.cluster import KMeans
|
4 |
import numpy as np
|
5 |
|
|
|
35 |
embeddings = embedder(words)
|
36 |
return np.array([np.mean(embedding[0], axis=0) for embedding in embeddings])
|
37 |
|
38 |
+
def iterative_clustering(words, model_name):
    """Greedily split ``words`` into groups of four using repeated K-means.

    Each round clusters the words that are still unassigned, keeps at most
    four words per cluster (in input order), asks
    ``select_most_cohesive_cluster`` for the tightest four-word cluster,
    stores it, and removes its words before the next round. The loop ends
    when fewer than four words remain; any such leftovers are not returned.

    Args:
        words: Sequence of words to group.
        model_name: Name of the embedding model, forwarded to ``embed_words``.

    Returns:
        A list of word groups (each a list of four words), in the order in
        which they were selected.
    """
    words = list(words)
    if len(words) < 4:
        # Nothing can be grouped; skip the embedding call entirely.
        return []

    # Embed once up front instead of once per round: the words themselves do
    # not change between rounds, only which of them are still unassigned.
    # NOTE(review): assumes embed_words returns one row per input word and
    # that a word's row does not depend on the other words passed - confirm.
    all_embeddings = np.asarray(embed_words(words, model_name))

    # Track unassigned words by position rather than by value so that a word
    # appearing twice in the input is handled as two separate entries.
    remaining = list(range(len(words)))
    grouped_words = []

    while len(remaining) >= 4:
        embeddings = all_embeddings[remaining]
        n_clusters = min(4, len(remaining) // 4)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(embeddings)

        # Parallel maps: the words shown to the user and their input positions.
        clusters = {i: [] for i in range(kmeans.n_clusters)}
        positions = {i: [] for i in range(kmeans.n_clusters)}
        for position, label in zip(remaining, kmeans.labels_):
            if len(clusters[label]) < 4:
                clusters[label].append(words[position])
                positions[label].append(position)

        # Select the most cohesive cluster
        best_cluster, best_idx = select_most_cohesive_cluster(clusters, kmeans, embeddings)
        if best_cluster is None:
            # No cluster reached four words; stop rather than loop forever.
            break

        # Store the best cluster and remove exactly those entries
        grouped_words.append(best_cluster)
        taken = set(positions[best_idx])
        remaining = [position for position in remaining if position not in taken]

    return grouped_words
|
58 |
+
|
59 |
+
def select_most_cohesive_cluster(clusters, kmeans_model, embeddings):
    """Pick the four-word cluster whose points lie closest to their centroid.

    Only clusters holding exactly four words are considered. For each one,
    the score is the mean Euclidean distance between the centroid and every
    embedding row carrying that cluster's label; the lowest score wins.

    Args:
        clusters: Mapping of cluster index to its list of words.
        kmeans_model: Fitted model exposing ``labels_`` and
            ``cluster_centers_``.
        embeddings: Array of embedding rows aligned with ``labels_``.

    Returns:
        ``(words, index)`` for the winning cluster, or ``(None, -1)`` when no
        cluster has exactly four words.
    """
    labels = np.asarray(kmeans_model.labels_)
    winner, winner_idx = None, -1
    lowest = float('inf')

    for idx, members in clusters.items():
        if len(members) != 4:
            continue
        # Rows of every point assigned to this label.
        points = embeddings[labels == idx]
        offsets = points - kmeans_model.cluster_centers_[idx]
        spread = np.mean(np.linalg.norm(offsets, axis=1))
        if spread < lowest:
            lowest, winner, winner_idx = spread, members, idx

    return winner, winner_idx
|
73 |
|
74 |
def display_clusters(clusters):
    """Render each word group as a numbered heading followed by its words.

    Args:
        clusters: Iterable of word groups; each group is an iterable of
            strings.
    """
    for i, group in enumerate(clusters):
        heading = f"### Group {i+1}"
        joined = ", ".join(group)
        st.markdown(heading)
        st.write(joined)
|
78 |
|
79 |
def main():
    """Build the Streamlit page: intro text, model picker, and cluster output."""
    st.title("NYT Connections Solver")
    st.write("This app demonstrates solving the NYT Connections game using word embeddings and clustering.")
    st.write("Select an embedding model from the dropdown menu and click 'Generate Clusters' to see the grouped words.")

    # Dropdown menu for selecting the embedding model
    available_models = list(models.keys())
    model_name = st.selectbox("Select Embedding Model", available_models)

    # Nothing more to render until the user asks for clusters.
    if not st.button("Generate Clusters"):
        return

    # NOTE(review): the scraped source lost its indentation; the display call
    # is assumed to sit inside the spinner block - confirm against app.py.
    with st.spinner("Generating clusters..."):
        clusters = iterative_clustering(mock_words, model_name)
        display_clusters(clusters)
|
91 |
|
92 |
if __name__ == "__main__":
|