meg-huggingface commited on
Commit
0b7eeeb
·
1 Parent(s): f9936fb

Updating from rollback

Browse files
data_measurements/embeddings.py CHANGED
@@ -20,12 +20,14 @@ import plotly.graph_objects as go
20
  import torch
21
  import transformers
22
  from datasets import load_from_disk
 
23
  from tqdm import tqdm
24
 
25
- from .dataset_utils import EMBEDDING_FIELD, OUR_TEXT_FIELD
26
 
27
 
28
  def sentence_mean_pooling(model_output, attention_mask):
 
29
  token_embeddings = model_output[
30
  0
31
  ] # First element of model_output contains all token embeddings
@@ -38,46 +40,46 @@ def sentence_mean_pooling(model_output, attention_mask):
38
 
39
 
40
  class Embeddings:
41
- def __init__(self, dstats, use_cache=False):
 
 
 
 
 
 
 
42
  """Item embeddings and clustering"""
43
  self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  self.node_list = None
45
  self.nid_map = None
46
- self.embeddings_dset = None
47
  self.fig_tree = None
48
  self.cached_clusters = {}
49
- self.dstats = dstats
50
- self.cache_path = dstats.cache_path
51
- self.node_list_fid = pjoin(self.cache_path, "node_list.th")
52
  self.use_cache = use_cache
53
- self.tokenizer = transformers.AutoTokenizer.from_pretrained(
54
- "sentence-transformers/all-mpnet-base-v2"
55
- )
56
- self.model = transformers.AutoModel.from_pretrained(
57
- "sentence-transformers/all-mpnet-base-v2"
58
- ).to(self.device)
59
-
60
- def make_text_embeddings(self):
61
- embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
62
- if self.use_cache and exists(embeddings_dset_fid):
63
- self.embeddings_dset = load_from_disk(embeddings_dset_fid)
64
- else:
65
- self.embeddings_dset = self.make_embeddings()
66
- self.embeddings_dset.save_to_disk(embeddings_dset_fid)
67
-
68
- def make_hierarchical_clustering(self):
69
- if self.use_cache and exists(self.node_list_fid):
70
- self.node_list = torch.load(self.node_list_fid)
71
- else:
72
- self.make_text_embeddings()
73
- self.node_list = self.fast_cluster(self.embeddings_dset, EMBEDDING_FIELD)
74
- torch.save(self.node_list, self.node_list_fid)
75
- self.nid_map = dict(
76
- [(node["nid"], nid) for nid, node in enumerate(self.node_list)]
77
- )
78
- self.fig_tree = make_tree_plot(self.node_list, self.dstats.text_dset)
79
 
80
  def compute_sentence_embeddings(self, sentences):
 
 
 
 
 
 
 
 
 
81
  batch = self.tokenizer(
82
  sentences, padding=True, truncation=True, return_tensors="pt"
83
  )
@@ -91,212 +93,70 @@ class Embeddings:
91
  return sentence_embeds
92
 
93
  def make_embeddings(self):
 
 
 
 
 
 
 
 
94
  def batch_embed_sentences(sentences):
95
  return {
96
  EMBEDDING_FIELD: [
97
  embed.tolist()
98
  for embed in self.compute_sentence_embeddings(
99
- sentences[OUR_TEXT_FIELD]
100
  )
101
  ]
102
  }
103
 
104
- text_dset_embeds = self.dstats.text_dset.map(
105
  batch_embed_sentences,
106
  batched=True,
107
  batch_size=32,
108
- remove_columns=[self.dstats.our_text_field],
109
- )
110
-
111
- return text_dset_embeds
112
-
113
- @staticmethod
114
- def prepare_merges(embeddings, batch_size, low_thres=0.5):
115
- top_idx_pre = torch.cat(
116
- [torch.LongTensor(range(embeddings.shape[0]))[:, None]] * batch_size, dim=1
117
- )
118
- top_val_all = torch.Tensor(0, batch_size)
119
- top_idx_all = torch.LongTensor(0, batch_size)
120
- n_batches = math.ceil(len(embeddings) / batch_size)
121
- for b in tqdm(range(n_batches)):
122
- cos_scores = torch.mm(
123
- embeddings[b * batch_size : (b + 1) * batch_size], embeddings.t()
124
- )
125
- for i in range(cos_scores.shape[0]):
126
- cos_scores[i, (b * batch_size) + i :] = -1
127
- top_val_large, top_idx_large = cos_scores.topk(
128
- k=batch_size, dim=-1, largest=True
129
- )
130
- top_val_all = torch.cat([top_val_all, top_val_large], dim=0)
131
- top_idx_all = torch.cat([top_idx_all, top_idx_large], dim=0)
132
-
133
- all_merges = torch.cat(
134
- [
135
- top_idx_pre[top_val_all > low_thres][:, None],
136
- top_idx_all[top_val_all > low_thres][:, None],
137
- ],
138
- dim=1,
139
  )
140
- all_merge_scores = top_val_all[top_val_all > low_thres]
141
- return (all_merges, all_merge_scores)
142
 
143
- @staticmethod
144
- def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
145
- merge_ids = (all_merge_scores <= previous_thres) * (
146
- all_merge_scores > current_thres
147
- )
148
- merges = all_merges[merge_ids]
149
- for a, b in merges.tolist():
150
- node_a = nodes[a]
151
- while node_a["parent_id"] != -1:
152
- node_a = nodes[node_a["parent_id"]]
153
- node_b = nodes[b]
154
- while node_b["parent_id"] != -1:
155
- node_b = nodes[node_b["parent_id"]]
156
- if node_a["nid"] == node_b["nid"]:
157
- continue
158
- else:
159
- # merge if threshold allows
160
- if (node_a["depth"] + node_b["depth"]) > 0 and min(
161
- node_a["merge_threshold"], node_b["merge_threshold"]
162
- ) == current_thres:
163
- merge_to = None
164
- merge_from = None
165
- if node_a["nid"] < node_b["nid"]:
166
- merge_from = node_a
167
- merge_to = node_b
168
- if node_a["nid"] > node_b["nid"]:
169
- merge_from = node_b
170
- merge_to = node_a
171
- merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
172
- merge_to["weight"] += merge_from["weight"]
173
- merge_to["children_ids"] += (
174
- merge_from["children_ids"]
175
- if merge_from["depth"] > 0
176
- else [merge_from["nid"]]
177
- )
178
- for cid in merge_from["children_ids"]:
179
- nodes[cid]["parent_id"] = merge_to["nid"]
180
- merge_from["parent_id"] = merge_to["nid"]
181
- # else new node
182
- else:
183
- new_nid = len(nodes)
184
- new_node = {
185
- "nid": new_nid,
186
- "parent_id": -1,
187
- "depth": max(node_a["depth"], node_b["depth"]) + 1,
188
- "weight": node_a["weight"] + node_b["weight"],
189
- "children": [],
190
- "children_ids": [node_a["nid"], node_b["nid"]],
191
- "example_ids": [],
192
- "merge_threshold": current_thres,
193
- }
194
- node_a["parent_id"] = new_nid
195
- node_b["parent_id"] = new_nid
196
- nodes += [new_node]
197
- return nodes
198
 
199
- def finalize_node(self, node, nodes, min_cluster_size):
200
- node["children"] = sorted(
201
- [
202
- self.finalize_node(nodes[cid], nodes, min_cluster_size)
203
- for cid in node["children_ids"]
204
- ],
205
- key=lambda x: x["weight"],
206
- reverse=True,
207
- )
208
- if node["depth"] > 0:
209
- node["example_ids"] = [
210
- eid for child in node["children"] for eid in child["example_ids"]
211
- ]
212
- node["children"] = [
213
- child for child in node["children"] if child["weight"] >= min_cluster_size
214
- ]
215
- assert node["weight"] == len(node["example_ids"]), print(node)
216
- return node
217
 
218
- def fast_cluster(
219
  self,
220
- text_dset_embeds,
221
- embedding_field,
222
  batch_size=1000,
 
223
  min_cluster_size=10,
224
- low_thres=0.5,
225
  ):
226
- embeddings = torch.Tensor(text_dset_embeds[embedding_field])
227
- batch_size = min(embeddings.shape[0], batch_size)
228
- all_merges, all_merge_scores = self.prepare_merges(
229
- embeddings, batch_size, low_thres
230
- )
231
- # prepare leaves
232
- nodes = [
233
- {
234
- "nid": nid,
235
- "parent_id": -1,
236
- "depth": 0,
237
- "weight": 1,
238
- "children": [],
239
- "children_ids": [],
240
- "example_ids": [nid],
241
- "merge_threshold": 1.0,
242
- }
243
- for nid in range(embeddings.shape[0])
244
- ]
245
- # one level per threshold range
246
- for i in range(10):
247
- p_thres = 1 - i * 0.05
248
- c_thres = 0.95 - i * 0.05
249
- nodes = self.merge_nodes(
250
- nodes, c_thres, p_thres, all_merges, all_merge_scores
251
  )
252
- # make root
253
- root_children = [
254
- node
255
- for node in nodes
256
- if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
257
- ]
258
- root = {
259
- "nid": len(nodes),
260
- "parent_id": -1,
261
- "depth": max([node["depth"] for node in root_children]) + 1,
262
- "weight": sum([node["weight"] for node in root_children]),
263
- "children": [],
264
- "children_ids": [node["nid"] for node in root_children],
265
- "example_ids": [],
266
- "merge_threshold": -1.0,
267
- }
268
- nodes += [root]
269
- for node in root_children:
270
- node["parent_id"] = root["nid"]
271
- # finalize tree
272
- tree = self.finalize_node(root, nodes, min_cluster_size)
273
- node_list = []
274
-
275
- def rec_map_nodes(node, node_list):
276
- node_list += [node]
277
- for child in node["children"]:
278
- rec_map_nodes(child, node_list)
279
-
280
- rec_map_nodes(tree, node_list)
281
- # get centroids and distances
282
- for node in node_list:
283
- node_embeds = embeddings[node["example_ids"]]
284
- node["centroid"] = node_embeds.sum(dim=0)
285
- node["centroid"] /= node["centroid"].norm()
286
- node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
287
- node["sorted_examples_centroid"] = sorted(
288
- [
289
- (eid, edp.item())
290
- for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
291
- ],
292
- key=lambda x: x[1],
293
- reverse=True,
294
  )
295
- return node_list
296
 
297
  def find_cluster_beam(self, sentence, beam_size=20):
298
  """
299
- This function finds the `beam_size` lef clusters that are closest to the
300
  proposed sentence and returns the full path from the root to the cluster
301
  along with the dot product between the sentence embedding and the
302
  cluster centroid
@@ -365,25 +225,268 @@ class Embeddings:
365
  )[:beam_size]
366
 
367
 
368
- def make_tree_plot(node_list, text_dset):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
370
 
371
  for nid, node in enumerate(node_list):
 
 
 
 
 
 
372
  node["label"] = node.get(
373
  "label",
374
  f"{nid:2d} - {node['weight']:5d} items <br>"
375
  + "<br>".join(
376
  [
377
- "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
378
- for txt in list(
379
- set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
380
- )[:5]
381
  ]
382
  ),
383
  )
384
 
385
  # make plot nodes
386
- # TODO: something more efficient than set to remove duplicates
387
  labels = [node["label"] for node in node_list]
388
 
389
  root = node_list[0]
 
20
  import torch
21
  import transformers
22
  from datasets import load_from_disk
23
+ from plotly.io import read_json
24
  from tqdm import tqdm
25
 
26
+ from .dataset_utils import EMBEDDING_FIELD
27
 
28
 
29
  def sentence_mean_pooling(model_output, attention_mask):
30
+ """Mean pooling of token embeddings for a sentence."""
31
  token_embeddings = model_output[
32
  0
33
  ] # First element of model_output contains all token embeddings
 
40
 
41
 
42
  class Embeddings:
43
+ def __init__(
44
+ self,
45
+ dstats=None,
46
+ text_dset=None,
47
+ text_field_name="text",
48
+ cache_path="",
49
+ use_cache=False,
50
+ ):
51
  """Item embeddings and clustering"""
52
  self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
53
+ self.model_name = "sentence-transformers/all-mpnet-base-v2"
54
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
55
+ self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
56
+ self.device
57
+ )
58
+ self.text_dset = text_dset if dstats is None else dstats.text_dset
59
+ self.text_field_name = (
60
+ text_field_name if dstats is None else dstats.our_text_field
61
+ )
62
+ self.cache_path = cache_path if dstats is None else dstats.cache_path
63
+ self.embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
64
+ self.embeddings_dset = None
65
+ self.node_list_fid = pjoin(self.cache_path, "node_list.th")
66
  self.node_list = None
67
  self.nid_map = None
68
+ self.fig_tree_fid = pjoin(self.cache_path, "node_figure.json")
69
  self.fig_tree = None
70
  self.cached_clusters = {}
 
 
 
71
  self.use_cache = use_cache
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def compute_sentence_embeddings(self, sentences):
74
+ """
75
+ Takes a list of sentences and computes their embeddings
76
+ using self.tokenizer and self.model (with output dimension D)
77
+ followed by mean pooling of the token representations and normalization
78
+ Args:
79
+ sentences ([string]): list of N input sentences
80
+ Returns:
81
+ torch.Tensor: sentence embeddings, dimension NxD
82
+ """
83
  batch = self.tokenizer(
84
  sentences, padding=True, truncation=True, return_tensors="pt"
85
  )
 
93
  return sentence_embeds
94
 
95
  def make_embeddings(self):
96
+ """
97
+ Batch computes the embeddings of the Dataset self.text_dset,
98
+ using the field self.text_field_name as input.
99
+ Returns:
100
+ Dataset: HF dataset object with a single EMBEDDING_FIELD field
101
+ corresponding to the embeddings (list of floats)
102
+ """
103
+
104
  def batch_embed_sentences(sentences):
105
  return {
106
  EMBEDDING_FIELD: [
107
  embed.tolist()
108
  for embed in self.compute_sentence_embeddings(
109
+ sentences[self.text_field_name]
110
  )
111
  ]
112
  }
113
 
114
+ self.embeddings_dset = self.text_dset.map(
115
  batch_embed_sentences,
116
  batched=True,
117
  batch_size=32,
118
+ remove_columns=[self.text_field_name],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  )
 
 
120
 
121
+ return self.embeddings_dset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ def make_text_embeddings(self):
124
+ """Load embeddings dataset from cache or compute it."""
125
+ if self.use_cache and exists(self.embeddings_dset_fid):
126
+ self.embeddings_dset = load_from_disk(self.embeddings_dset_fid)
127
+ else:
128
+ self.embeddings_dset = self.make_embeddings()
129
+ self.embeddings_dset.save_to_disk(self.embeddings_dset_fid)
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ def make_hierarchical_clustering(
132
  self,
 
 
133
  batch_size=1000,
134
+ approx_neighbors=1000,
135
  min_cluster_size=10,
 
136
  ):
137
+ if self.use_cache and exists(self.node_list_fid):
138
+ self.node_list, self.nid_map = torch.load(self.node_list_fid)
139
+ else:
140
+ self.make_text_embeddings()
141
+ embeddings = torch.Tensor(self.embeddings_dset[EMBEDDING_FIELD])
142
+ self.node_list = fast_cluster(
143
+ embeddings, batch_size, approx_neighbors, min_cluster_size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
+ self.nid_map = dict(
146
+ [(node["nid"], nid) for nid, node in enumerate(self.node_list)]
147
+ )
148
+ torch.save((self.node_list, self.nid_map), self.node_list_fid)
149
+ if self.use_cache and exists(self.fig_tree_fid):
150
+ self.fig_tree = read_json(self.fig_tree_fid)
151
+ else:
152
+ self.fig_tree = make_tree_plot(
153
+ self.node_list, self.text_dset, self.text_field_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  )
155
+ self.fig_tree.write_json(self.fig_tree_fid)
156
 
157
  def find_cluster_beam(self, sentence, beam_size=20):
158
  """
159
+ This function finds the `beam_size` leaf clusters that are closest to the
160
  proposed sentence and returns the full path from the root to the cluster
161
  along with the dot product between the sentence embedding and the
162
  cluster centroid
 
225
  )[:beam_size]
226
 
227
 
228
+ def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
229
+ """
230
+ Prepares an initial list of merges for hierarchical
231
+ clustering. First compute the `approx_neighbors` nearest neighbors,
232
+ then propose a merge for any two points that are closer than `low_thres`
233
+
234
+ Note that if a point has more than `approx_neighbors` neighbors
235
+ closer than `low_thres`, this approach will miss some of those merges
236
+
237
+ Args:
238
+ embeddings (toch.Tensor): Tensor of sentence embeddings - dimension NxD
239
+ batch_size (int): compute nearest neighbors of `batch_size` points at a time
240
+ approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a point
241
+ low_thres (float): only return merges where the dot product is greater than `low_thres`
242
+ Returns:
243
+ torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
244
+ torch.Tensor: merge scores - dimension M
245
+ """
246
+ top_idx_pre = torch.cat(
247
+ [torch.LongTensor(range(embeddings.shape[0]))[:, None]] * batch_size, dim=1
248
+ )
249
+ top_val_all = torch.Tensor(0, approx_neighbors)
250
+ top_idx_all = torch.LongTensor(0, approx_neighbors)
251
+ n_batches = math.ceil(len(embeddings) / batch_size)
252
+ for b in tqdm(range(n_batches)):
253
+ # TODO: batch across second dimension
254
+ cos_scores = torch.mm(
255
+ embeddings[b * batch_size : (b + 1) * batch_size], embeddings.t()
256
+ )
257
+ for i in range(cos_scores.shape[0]):
258
+ cos_scores[i, (b * batch_size) + i :] = -1
259
+ top_val_large, top_idx_large = cos_scores.topk(
260
+ k=approx_neighbors, dim=-1, largest=True
261
+ )
262
+ top_val_all = torch.cat([top_val_all, top_val_large], dim=0)
263
+ top_idx_all = torch.cat([top_idx_all, top_idx_large], dim=0)
264
+ max_neighbor_dist = top_val_large[:, -1].max().item()
265
+ if max_neighbor_dist > low_thres:
266
+ print(
267
+ f"WARNING: with the current set of neireast neighbor, the farthest is {max_neighbor_dist}"
268
+ )
269
+
270
+ all_merges = torch.cat(
271
+ [
272
+ top_idx_pre[top_val_all > low_thres][:, None],
273
+ top_idx_all[top_val_all > low_thres][:, None],
274
+ ],
275
+ dim=1,
276
+ )
277
+ all_merge_scores = top_val_all[top_val_all > low_thres]
278
+
279
+ return (all_merges, all_merge_scores)
280
+
281
+
282
+ def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
283
+ """
284
+ Merge all nodes if the max dot product between any of their descendants
285
+ is greater than current_thres.
286
+
287
+ Args:
288
+ nodes ([dict]): list of dicts representing the current set of nodes
289
+ current_thres (float): merge all nodes closer than current_thres
290
+ previous_thres (float): nodes closer than previous_thres are already merged
291
+ all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
292
+ all_merge_scores (torch.Tensor): merge scores - dimension M
293
+ Returns:
294
+ [dict]: extended list with the newly created internal nodes
295
+ """
296
+ merge_ids = (all_merge_scores <= previous_thres) * (
297
+ all_merge_scores > current_thres
298
+ )
299
+ if merge_ids.sum().item() > 0:
300
+ merges = all_merges[merge_ids]
301
+ for a, b in merges.tolist():
302
+ node_a = nodes[a]
303
+ while node_a["parent_id"] != -1:
304
+ node_a = nodes[node_a["parent_id"]]
305
+ node_b = nodes[b]
306
+ while node_b["parent_id"] != -1:
307
+ node_b = nodes[node_b["parent_id"]]
308
+ if node_a["nid"] == node_b["nid"]:
309
+ continue
310
+ else:
311
+ # merge if threshold allows
312
+ if (node_a["depth"] + node_b["depth"]) > 0 and min(
313
+ node_a["merge_threshold"], node_b["merge_threshold"]
314
+ ) == current_thres:
315
+ merge_to = None
316
+ merge_from = None
317
+ if node_a["nid"] < node_b["nid"]:
318
+ merge_from = node_a
319
+ merge_to = node_b
320
+ if node_a["nid"] > node_b["nid"]:
321
+ merge_from = node_b
322
+ merge_to = node_a
323
+ merge_to["depth"] = max(merge_to["depth"], merge_from["depth"])
324
+ merge_to["weight"] += merge_from["weight"]
325
+ merge_to["children_ids"] += (
326
+ merge_from["children_ids"]
327
+ if merge_from["depth"] > 0
328
+ else [merge_from["nid"]]
329
+ )
330
+ for cid in merge_from["children_ids"]:
331
+ nodes[cid]["parent_id"] = merge_to["nid"]
332
+ merge_from["parent_id"] = merge_to["nid"]
333
+ # else new node
334
+ else:
335
+ new_nid = len(nodes)
336
+ new_node = {
337
+ "nid": new_nid,
338
+ "parent_id": -1,
339
+ "depth": max(node_a["depth"], node_b["depth"]) + 1,
340
+ "weight": node_a["weight"] + node_b["weight"],
341
+ "children": [],
342
+ "children_ids": [node_a["nid"], node_b["nid"]],
343
+ "example_ids": [],
344
+ "merge_threshold": current_thres,
345
+ }
346
+ node_a["parent_id"] = new_nid
347
+ node_b["parent_id"] = new_nid
348
+ nodes += [new_node]
349
+ return nodes
350
+
351
+
352
+ def finalize_node(node, nodes, min_cluster_size):
353
+ """Post-process nodes to sort children by descending weight,
354
+ get full list of leaves in the sub-tree, and direct links
355
+ to the cildren nodes, then recurses to all children.
356
+
357
+ Nodes with fewer than `min_cluster_size` descendants are collapsed
358
+ into a single leaf.
359
+ """
360
+ node["children"] = sorted(
361
+ [
362
+ finalize_node(nodes[cid], nodes, min_cluster_size)
363
+ for cid in node["children_ids"]
364
+ ],
365
+ key=lambda x: x["weight"],
366
+ reverse=True,
367
+ )
368
+ if node["depth"] > 0:
369
+ node["example_ids"] = [
370
+ eid for child in node["children"] for eid in child["example_ids"]
371
+ ]
372
+ node["children"] = [
373
+ child for child in node["children"] if child["weight"] >= min_cluster_size
374
+ ]
375
+ assert node["weight"] == len(node["example_ids"]), print(node)
376
+ return node
377
+
378
+
379
+ def fast_cluster(
380
+ embeddings,
381
+ batch_size=1000,
382
+ approx_neighbors=1000,
383
+ min_cluster_size=10,
384
+ low_thres=0.5,
385
+ ):
386
+ """
387
+ Computes an approximate hierarchical clustering based on example
388
+ embeddings. The join criterion is min clustering, i.e. two clusters
389
+ are joined if any pair of their descendants are closer than a threshold
390
+
391
+ The approximate comes from the fact that only the `approx_neighbors` nearest
392
+ neighbors of an example are considered for merges
393
+ """
394
+ batch_size = min(embeddings.shape[0], batch_size)
395
+ all_merges, all_merge_scores = prepare_merges(
396
+ embeddings, batch_size, approx_neighbors, low_thres
397
+ )
398
+ # prepare leaves
399
+ nodes = [
400
+ {
401
+ "nid": nid,
402
+ "parent_id": -1,
403
+ "depth": 0,
404
+ "weight": 1,
405
+ "children": [],
406
+ "children_ids": [],
407
+ "example_ids": [nid],
408
+ "merge_threshold": 1.0,
409
+ }
410
+ for nid in range(embeddings.shape[0])
411
+ ]
412
+ # one level per threshold range
413
+ for i in range(10):
414
+ p_thres = 1 - i * 0.05
415
+ c_thres = 0.95 - i * 0.05
416
+ nodes = merge_nodes(nodes, c_thres, p_thres, all_merges, all_merge_scores)
417
+ # make root
418
+ root_children = [
419
+ node
420
+ for node in nodes
421
+ if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
422
+ ]
423
+ root = {
424
+ "nid": len(nodes),
425
+ "parent_id": -1,
426
+ "depth": max([node["depth"] for node in root_children]) + 1,
427
+ "weight": sum([node["weight"] for node in root_children]),
428
+ "children": [],
429
+ "children_ids": [node["nid"] for node in root_children],
430
+ "example_ids": [],
431
+ "merge_threshold": -1.0,
432
+ }
433
+ nodes += [root]
434
+ for node in root_children:
435
+ node["parent_id"] = root["nid"]
436
+ # finalize tree
437
+ tree = finalize_node(root, nodes, min_cluster_size)
438
+ node_list = []
439
+
440
+ def rec_map_nodes(node, node_list):
441
+ node_list += [node]
442
+ for child in node["children"]:
443
+ rec_map_nodes(child, node_list)
444
+
445
+ rec_map_nodes(tree, node_list)
446
+ # get centroids and distances
447
+ for node in node_list:
448
+ node_embeds = embeddings[node["example_ids"]]
449
+ node["centroid"] = node_embeds.sum(dim=0)
450
+ node["centroid"] /= node["centroid"].norm()
451
+ node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
452
+ node["sorted_examples_centroid"] = sorted(
453
+ [
454
+ (eid, edp.item())
455
+ for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
456
+ ],
457
+ key=lambda x: x[1],
458
+ reverse=True,
459
+ )
460
+ return node_list
461
+
462
+
463
+ def make_tree_plot(node_list, text_dset, text_field_name):
464
+ """
465
+ Makes a graphical representation of the tree encoded
466
+ in node-list. The hover label for each node shows the number
467
+ of descendants and the 5 examples that are closest to the centroid
468
+ """
469
  nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
470
 
471
  for nid, node in enumerate(node_list):
472
+ # get list of
473
+ node_examples = {}
474
+ for sid, score in node["sorted_examples_centroid"]:
475
+ node_examples[text_dset[sid][text_field_name]] = score
476
+ if len(node_examples) >= 5:
477
+ break
478
  node["label"] = node.get(
479
  "label",
480
  f"{nid:2d} - {node['weight']:5d} items <br>"
481
  + "<br>".join(
482
  [
483
+ f" {score:.2f} > {txt[:64]}" + ("..." if len(txt) >= 63 else "")
484
+ for txt, score in node_examples.items()
 
 
485
  ]
486
  ),
487
  )
488
 
489
  # make plot nodes
 
490
  labels = [node["label"] for node in node_list]
491
 
492
  root = node_list[0]
data_measurements/streamlit_utils.py CHANGED
@@ -21,7 +21,6 @@ from st_aggrid import AgGrid, GridOptionsBuilder
21
 
22
  from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
23
 
24
-
25
  def sidebar_header():
26
  st.sidebar.markdown(
27
  """
@@ -167,7 +166,11 @@ def expander_text_lengths(dstats, column_id):
167
  st.markdown(
168
  "### Here is the relative frequency of different text lengths in your dataset:"
169
  )
170
- st.plotly_chart(dstats.fig_tok_length, use_container_width=True)
 
 
 
 
171
  st.markdown(
172
  "The average length of text instances is **"
173
  + str(dstats.avg_length)
@@ -175,19 +178,11 @@ def expander_text_lengths(dstats, column_id):
175
  + str(dstats.std_length)
176
  + "**."
177
  )
178
-
179
- start_id_show_lengths = st.slider(
180
- f"Show the shortest sentences{column_id} starting at:",
181
- 0,
182
- dstats.num_uniq_lengths,
183
- value=0,
184
- step=1,
185
- )
186
-
187
  # This is quite a large file and is breaking our ability to navigate the app development.
188
  # Just passing if it's not already there for launch v0
189
  if dstats.length_df is not None:
190
- st.dataframe(dstats.length_df[dstats.length_df["length"] == start_id_show_lengths].set_index("length"))
 
191
 
192
 
193
  ### Third, use a sentence embedding model
@@ -285,17 +280,7 @@ def expander_text_duplicates(dstats, column_id):
285
  if dstats.dup_counts_df is None or dstats.dup_counts_df.empty:
286
  st.write("There are no duplicates in this dataset! 🥳")
287
  else:
288
- gb = GridOptionsBuilder.from_dataframe(dstats.dup_counts_df)
289
- gb.configure_column(
290
- f"text{column_id}",
291
- wrapText=True,
292
- resizable=True,
293
- autoHeight=True,
294
- min_column_width=85,
295
- use_container_width=True,
296
- )
297
- go = gb.build()
298
- AgGrid(dstats.dup_counts_df, gridOptions=go)
299
 
300
 
301
  def expander_npmi_description(min_vocab):
 
21
 
22
  from .dataset_utils import HF_DESC_FIELD, HF_FEATURE_FIELD, HF_LABEL_FIELD
23
 
 
24
  def sidebar_header():
25
  st.sidebar.markdown(
26
  """
 
166
  st.markdown(
167
  "### Here is the relative frequency of different text lengths in your dataset:"
168
  )
169
+ #TODO: figure out more elegant way to do this:
170
+ try:
171
+ st.image(dstats.fig_tok_length_png)
172
+ except:
173
+ st.pyplot(dstats.fig_tok_length, use_container_width=True)
174
  st.markdown(
175
  "The average length of text instances is **"
176
  + str(dstats.avg_length)
 
178
  + str(dstats.std_length)
179
  + "**."
180
  )
 
 
 
 
 
 
 
 
 
181
  # This is quite a large file and is breaking our ability to navigate the app development.
182
  # Just passing if it's not already there for launch v0
183
  if dstats.length_df is not None:
184
+ start_id_show_lengths= st.selectbox("Show examples of length:", sorted(dstats.length_df["length"].unique().tolist()))
185
+ st.table(dstats.length_df[dstats.length_df["length"] == start_id_show_lengths].set_index("length"))
186
 
187
 
188
  ### Third, use a sentence embedding model
 
280
  if dstats.dup_counts_df is None or dstats.dup_counts_df.empty:
281
  st.write("There are no duplicates in this dataset! 🥳")
282
  else:
283
+ st.dataframe(dstats.dup_counts_df.reset_index(drop=True))
 
 
 
 
 
 
 
 
 
 
284
 
285
 
286
  def expander_npmi_description(min_vocab):