|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
from os.path import exists |
|
from os.path import join as pjoin |
|
|
|
import plotly.graph_objects as go |
|
import torch |
|
import transformers |
|
from datasets import load_from_disk |
|
from plotly.io import read_json |
|
from tqdm import tqdm |
|
|
|
from .dataset_utils import EMBEDDING_FIELD |
|
|
|
|
|
def sentence_mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sentence, weighted by the mask.

    Args:
        model_output: indexable model output; element 0 is used as the
            token embeddings (assumed batch x tokens x dim -- TODO confirm)
        attention_mask (torch.Tensor): per-token mask; it is given a trailing
            axis and expanded to the shape of the token embeddings
    Returns:
        torch.Tensor: masked sum of the token embeddings over dimension 1,
            divided by the (clamped) sum of the mask over dimension 1
    """
    token_embeddings = model_output[0]
    # Stretch the mask over the embedding axis so it can weight every feature.
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = (token_embeddings * mask).sum(1)
    # The clamp keeps a fully masked row from dividing by zero.
    token_counts = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / token_counts
|
|
|
|
|
class Embeddings:
    """Sentence embeddings of a text dataset and a cluster tree built on them.

    Results (embeddings dataset, node list, tree figure) are written under
    `cache_path` and re-read from there when `use_cache` is set.
    """

    def __init__(
        self,
        dstats=None,
        text_dset=None,
        text_field_name="text",
        cache_path="",
        use_cache=False,
    ):
        """Item embeddings and clustering

        Loads the tokenizer and model named by `self.model_name` and moves
        the model to the first CUDA device when one is available.

        Args:
            dstats: optional object exposing `text_dset`, `our_text_field`
                and `cache_path`; when given, those values take precedence
                over the three matching arguments below
            text_dset: dataset to embed (used only when `dstats` is None)
            text_field_name (string): name of the text column
                (used only when `dstats` is None)
            cache_path (string): directory for cached artefacts
                (used only when `dstats` is None)
            use_cache (bool): re-use cached artefacts when they exist
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_name = "sentence-transformers/all-mpnet-base-v2"
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
            self.device
        )
        self.text_dset = text_dset if dstats is None else dstats.text_dset
        self.text_field_name = (
            text_field_name if dstats is None else dstats.our_text_field
        )
        self.cache_path = cache_path if dstats is None else dstats.cache_path
        # Cached embeddings dataset (written with `save_to_disk`).
        self.embeddings_dset_fid = pjoin(self.cache_path, "embeddings_dset")
        self.embeddings_dset = None
        # Cached (node_list, nid_map) pair (written with `torch.save`).
        self.node_list_fid = pjoin(self.cache_path, "node_list.th")
        self.node_list = None
        self.nid_map = None
        # Cached tree figure (written with `write_json`).
        self.fig_tree_fid = pjoin(self.cache_path, "node_figure.json")
        self.fig_tree = None
        self.cached_clusters = {}
        self.use_cache = use_cache

    def compute_sentence_embeddings(self, sentences):
        """
        Takes a list of sentences and computes their embeddings
        using self.tokenizer and self.model (with output dimension D)
        followed by mean pooling of the token representations and normalization
        Args:
            sentences ([string]): list of N input sentences
        Returns:
            torch.Tensor: sentence embeddings, dimension NxD
        """
        batch = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        with torch.no_grad():
            model_output = self.model(**batch)
            sentence_embeds = sentence_mean_pooling(
                model_output, batch["attention_mask"]
            )
            # Scale each row to unit L2 norm.
            sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
            return sentence_embeds

    def make_embeddings(self):
        """
        Batch computes the embeddings of the Dataset self.text_dset,
        using the field self.text_field_name as input.
        Returns:
            Dataset: HF dataset object in which the text column has been
                replaced by an EMBEDDING_FIELD column holding the
                embeddings (list of floats)
        """

        def batch_embed_sentences(sentences):
            embeds = self.compute_sentence_embeddings(sentences[self.text_field_name])
            return {EMBEDDING_FIELD: [embed.tolist() for embed in embeds]}

        self.embeddings_dset = self.text_dset.map(
            batch_embed_sentences,
            batched=True,
            batch_size=32,
            remove_columns=[self.text_field_name],
        )

        return self.embeddings_dset

    def make_text_embeddings(self):
        """Load embeddings dataset from cache or compute it.

        A freshly computed dataset is always saved to
        `self.embeddings_dset_fid`, whether or not `use_cache` is set.
        """
        if self.use_cache and exists(self.embeddings_dset_fid):
            self.embeddings_dset = load_from_disk(self.embeddings_dset_fid)
        else:
            self.embeddings_dset = self.make_embeddings()
            self.embeddings_dset.save_to_disk(self.embeddings_dset_fid)

    def make_hierarchical_clustering(
        self,
        batch_size=1000,
        approx_neighbors=1000,
        min_cluster_size=10,
    ):
        """Build (or load from cache) the cluster tree and its figure.

        Populates `self.node_list`, `self.nid_map` (node "nid" -> position in
        `self.node_list`) and `self.fig_tree`, and writes both artefacts to
        the cache directory when they had to be computed.

        Args:
            batch_size (int): forwarded to `fast_cluster`
            approx_neighbors (int): forwarded to `fast_cluster`
            min_cluster_size (int): forwarded to `fast_cluster`
        """
        if self.use_cache and exists(self.node_list_fid):
            # NOTE(security): torch.load unpickles the file; only point
            # cache_path at directories whose contents are trusted.
            self.node_list, self.nid_map = torch.load(self.node_list_fid)
        else:
            self.make_text_embeddings()
            embeddings = torch.Tensor(self.embeddings_dset[EMBEDDING_FIELD])
            self.node_list = fast_cluster(
                embeddings, batch_size, approx_neighbors, min_cluster_size
            )
            self.nid_map = {
                node["nid"]: index for index, node in enumerate(self.node_list)
            }
            torch.save((self.node_list, self.nid_map), self.node_list_fid)
        if self.use_cache and exists(self.fig_tree_fid):
            self.fig_tree = read_json(self.fig_tree_fid)
        else:
            self.fig_tree = make_tree_plot(
                self.node_list, self.nid_map, self.text_dset, self.text_field_name
            )
            self.fig_tree.write_json(self.fig_tree_fid)

    def _child_indices(self, node_index):
        """Return the node_list positions of a node's children.

        Child nids that are missing from `self.nid_map` are skipped.
        """
        return [
            self.nid_map[nid]
            for nid in self.node_list[node_index]["children_ids"]
            if nid in self.nid_map
        ]

    def find_cluster_beam(self, sentence, beam_size=20):
        """
        This function finds the `beam_size` leaf clusters that are closest to the
        proposed sentence and returns the full path from the root to the cluster
        along with the dot product between the sentence embedding and the
        cluster centroid

        `self.node_list` and `self.nid_map` must already be populated; paths
        are expressed as positions in `self.node_list`, starting at 0.
        Args:
            sentence (string): input sentence for which to find clusters
            beam_size (int): this is a beam size algorithm to explore the tree
        Returns:
            [([int], float)]: list of (path_from_root, score) sorted by score
        """
        embed = self.compute_sentence_embeddings([sentence])[0].to("cpu")

        def centroid_score(node_index):
            return torch.dot(embed, self.node_list[node_index]["centroid"]).item()

        active_paths = [([0], centroid_score(0))]
        finished_paths = []
        while active_paths:
            # Score every child of every active path, then keep the best
            # `beam_size` expansions across the whole beam.
            candidates = [
                (beam_id, child, centroid_score(child))
                for beam_id, (path, _) in enumerate(active_paths)
                for child in self._child_indices(path[-1])
            ]
            candidates.sort(key=lambda cand: cand[2], reverse=True)
            still_active = []
            for beam_id, child, score in candidates[:beam_size]:
                path = active_paths[beam_id][0] + [child]
                # A path is finished once its last node has no known children.
                if self._child_indices(child):
                    still_active.append((path, score))
                else:
                    finished_paths.append((path, score))
            active_paths = still_active
        return sorted(
            finished_paths,
            key=lambda x: x[-1],
            reverse=True,
        )[:beam_size]
|
|
|
|
|
def prepare_merges(embeddings, batch_size=1000, approx_neighbors=1000, low_thres=0.5):
    """
    Prepares an initial list of merges for hierarchical
    clustering. First compute the `approx_neighbors` nearest neighbors,
    then propose a merge for any two points that are closer than `low_thres`

    Note that if a point has more than `approx_neighbors` neighbors
    closer than `low_thres`, this approach will miss some of those merges;
    a warning is printed for every batch in which that may have happened.

    Args:
        embeddings (torch.Tensor): Tensor of sentence embeddings - dimension NxD
        batch_size (int): compute nearest neighbors of `batch_size` points at a time
        approx_neighbors (int): only keep `approx_neighbors` nearest neighbors of a
            point; capped at N, and independent of `batch_size`
        low_thres (float): only return merges where the dot product is greater than `low_thres`
    Returns:
        torch.LongTensor: proposed merges ([i, j] with i>j) - dimension: Mx2
        torch.Tensor: merge scores - dimension M
    """
    n_examples = embeddings.shape[0]
    if n_examples == 0:
        # Nothing to compare: return empty results of the documented shapes.
        return (
            torch.empty((0, 2), dtype=torch.long, device=embeddings.device),
            torch.empty(0, dtype=embeddings.dtype, device=embeddings.device),
        )
    # topk cannot return more entries than there are columns to choose from.
    n_neighbors = min(approx_neighbors, n_examples)
    example_ids = torch.arange(n_examples, device=embeddings.device)
    top_val_batches = []
    top_idx_batches = []
    n_batches = math.ceil(n_examples / batch_size)
    for b in tqdm(range(n_batches)):
        start = b * batch_size
        cos_scores = torch.mm(embeddings[start : start + batch_size], embeddings.t())
        # Keep only pairs [i, j] with i > j: blank out the diagonal and
        # everything to its right so each pair is proposed once.
        batch_ids = example_ids[start : start + cos_scores.shape[0]]
        cos_scores.masked_fill_(example_ids[None, :] >= batch_ids[:, None], -1)
        top_val_large, top_idx_large = cos_scores.topk(
            k=n_neighbors, dim=-1, largest=True
        )
        top_val_batches.append(top_val_large)
        top_idx_batches.append(top_idx_large)
        # If even the weakest retained neighbor passes the threshold, closer
        # candidates may have been cut off by the neighbor limit.
        max_neighbor_dist = top_val_large[:, -1].max().item()
        if max_neighbor_dist > low_thres:
            print(
                f"WARNING: with the current set of neireast neighbor, the farthest is {max_neighbor_dist}"
            )

    top_val_all = torch.cat(top_val_batches, dim=0)
    top_idx_all = torch.cat(top_idx_batches, dim=0)
    keep = top_val_all > low_thres
    # Row index of every retained score, shaped like the score matrix so the
    # same boolean mask applies whatever `batch_size` is.
    source_ids = example_ids[:, None].expand(n_examples, n_neighbors)
    all_merges = torch.cat(
        [
            source_ids[keep][:, None],
            top_idx_all[keep][:, None],
        ],
        dim=1,
    )
    all_merge_scores = top_val_all[keep]

    return (all_merges, all_merge_scores)
|
|
|
|
|
def merge_nodes(nodes, current_thres, previous_thres, all_merges, all_merge_scores):
    """
    Apply every proposed merge whose score lies in
    (current_thres, previous_thres] to the forest stored in `nodes`.

    For each such pair the current roots of both endpoints are looked up.
    If they differ, they are either folded together (when at least one is an
    internal node and the lower of their merge thresholds equals
    `current_thres`) or joined under a newly appended parent node.

    Args:
        nodes ([dict]): list of dicts representing the current set of nodes;
            it is extended in place
        current_thres (float): merge all nodes closer than current_thres
        previous_thres (float): nodes closer than previous_thres are already merged
        all_merges (torch.LongTensor): proposed merges ([i, j] with i>j) - dimension: Mx2
        all_merge_scores (torch.Tensor): merge scores - dimension M
    Returns:
        [dict]: extended list with the newly created internal nodes
    """

    def root_of(index):
        # Climb parent links until reaching a node without a parent.
        current = nodes[index]
        while current["parent_id"] != -1:
            current = nodes[current["parent_id"]]
        return current

    in_band = (all_merge_scores <= previous_thres) & (all_merge_scores > current_thres)
    if not in_band.any().item():
        return nodes

    for first, second in all_merges[in_band].tolist():
        root_a = root_of(first)
        root_b = root_of(second)
        if root_a["nid"] == root_b["nid"]:
            # Both endpoints already live in the same tree.
            continue

        lowest_thres = min(root_a["merge_threshold"], root_b["merge_threshold"])
        if (root_a["depth"] + root_b["depth"]) > 0 and lowest_thres == current_thres:
            # Fold the root with the smaller nid into the one with the
            # larger nid instead of stacking another level on top.
            if root_a["nid"] < root_b["nid"]:
                absorbed, keeper = root_a, root_b
            else:
                absorbed, keeper = root_b, root_a
            keeper["depth"] = max(keeper["depth"], absorbed["depth"])
            keeper["weight"] += absorbed["weight"]
            if absorbed["depth"] > 0:
                keeper["children_ids"].extend(absorbed["children_ids"])
            else:
                keeper["children_ids"].append(absorbed["nid"])
            for cid in absorbed["children_ids"]:
                nodes[cid]["parent_id"] = keeper["nid"]
            absorbed["parent_id"] = keeper["nid"]
        else:
            parent_nid = len(nodes)
            parent = {
                "nid": parent_nid,
                "parent_id": -1,
                "depth": max(root_a["depth"], root_b["depth"]) + 1,
                "weight": root_a["weight"] + root_b["weight"],
                "children": [],
                "children_ids": [root_a["nid"], root_b["nid"]],
                "example_ids": [],
                "merge_threshold": current_thres,
            }
            root_a["parent_id"] = parent_nid
            root_b["parent_id"] = parent_nid
            nodes.append(parent)
    return nodes
|
|
|
|
|
def finalize_node(node, nodes, min_cluster_size):
    """Post-process nodes to sort children by descending weight,
    get full list of leaves in the sub-tree, and direct links
    to the children nodes, then recurses to all children.

    Children whose weight is below `min_cluster_size` are dropped from
    `node["children"]`; their example ids are still counted in
    `node["example_ids"]`.

    Args:
        node (dict): node to finalize (modified in place)
        nodes ([dict]): full node list, indexed by the ids in "children_ids"
        min_cluster_size (int): minimum weight for a child to be kept
    Returns:
        dict: the same `node`, with "children" and "example_ids" filled in
    Raises:
        AssertionError: if the node's weight differs from the number of
            example ids collected for it
    """
    children = [
        finalize_node(nodes[cid], nodes, min_cluster_size)
        for cid in node["children_ids"]
    ]
    children.sort(key=lambda child: child["weight"], reverse=True)
    if node["depth"] > 0:
        # Internal nodes gather the examples of all children, including the
        # small ones that are filtered out of "children" just below.
        node["example_ids"] = [
            eid for child in children for eid in child["example_ids"]
        ]
    node["children"] = [
        child for child in children if child["weight"] >= min_cluster_size
    ]
    assert node["weight"] == len(node["example_ids"]), (
        f"node {node['nid']}: weight {node['weight']} does not match "
        f"{len(node['example_ids'])} collected example ids"
    )
    return node
|
|
|
|
|
def fast_cluster(
    embeddings,
    batch_size=1000,
    approx_neighbors=1000,
    min_cluster_size=10,
    low_thres=0.5,
):
    """
    Computes an approximate hierarchical clustering based on example
    embeddings. The join criterion is min clustering, i.e. two clusters
    are joined if any pair of their descendants are closer than a threshold

    The approximate comes from the fact that only the `approx_neighbors` nearest
    neighbors of an example are considered for merges

    Merges are applied in ten bands of decreasing similarity, with lower
    bounds going from 0.95 down to 0.5 in steps of 0.05.

    Args:
        embeddings (torch.Tensor): example embeddings - dimension NxD
        batch_size (int): batch size for the neighbor search (capped at N)
        approx_neighbors (int): neighbors kept per example (capped at N)
        min_cluster_size (int): clusters lighter than this are left out of the tree
        low_thres (float): minimum dot product for a merge to be proposed
    Returns:
        [dict]: nodes of the tree in pre-order, root first; each node carries
            "centroid", "centroid_dot_prods" and "sorted_examples_centroid"
            in addition to the structural fields
    """
    n_examples = embeddings.shape[0]
    batch_size = min(n_examples, batch_size)
    # The neighbor search cannot return more neighbors than there are examples.
    approx_neighbors = min(n_examples, approx_neighbors)
    all_merges, all_merge_scores = prepare_merges(
        embeddings, batch_size, approx_neighbors, low_thres
    )

    nodes = [
        {
            "nid": nid,
            "parent_id": -1,
            "depth": 0,
            "weight": 1,
            "children": [],
            "children_ids": [],
            "example_ids": [nid],
            "merge_threshold": 1.0,
        }
        for nid in range(n_examples)
    ]

    # The first band has no upper bound: dot products of (near-)identical
    # unit vectors can round to slightly more than 1.0 and would otherwise
    # fall outside every band.
    p_thres = float("inf")
    for i in range(10):
        c_thres = 0.95 - i * 0.05
        nodes = merge_nodes(nodes, c_thres, p_thres, all_merges, all_merge_scores)
        p_thres = c_thres

    root_children = [
        node
        for node in nodes
        if node["parent_id"] == -1 and node["weight"] >= min_cluster_size
    ]
    root = {
        "nid": len(nodes),
        "parent_id": -1,
        # default=0 covers the case where no cluster reached min_cluster_size.
        "depth": max((node["depth"] for node in root_children), default=0) + 1,
        "weight": sum(node["weight"] for node in root_children),
        "children": [],
        "children_ids": [node["nid"] for node in root_children],
        "example_ids": [],
        "merge_threshold": -1.0,
    }
    nodes.append(root)
    for node in root_children:
        node["parent_id"] = root["nid"]

    tree = finalize_node(root, nodes, min_cluster_size)
    node_list = []

    def rec_map_nodes(node):
        # Pre-order walk so that the root ends up at position 0.
        node_list.append(node)
        for child in node["children"]:
            rec_map_nodes(child)

    rec_map_nodes(tree)

    for node in node_list:
        node_embeds = embeddings[node["example_ids"]]
        centroid = node_embeds.sum(dim=0)
        norm = centroid.norm()
        # A node without examples sums to the zero vector; leave it as is
        # rather than dividing by zero and producing NaNs.
        if norm > 0:
            centroid = centroid / norm
        node["centroid"] = centroid
        node["centroid_dot_prods"] = torch.mv(node_embeds, node["centroid"])
        node["sorted_examples_centroid"] = sorted(
            [
                (eid, edp.item())
                for eid, edp in zip(node["example_ids"], node["centroid_dot_prods"])
            ],
            key=lambda x: x[1],
            reverse=True,
        )
    return node_list
|
|
|
|
|
def make_tree_plot(node_list, nid_map, text_dset, text_field_name):
    """
    Makes a graphical representation of the tree encoded
    in node-list. The hover label for each node shows the number
    of descendants and the 5 examples that are closest to the centroid

    Each node gains "label", "X" and "Y" entries as a side effect.

    Args:
        node_list ([dict]): tree nodes, root first; every node must be
            reachable from the root through "children"
        nid_map (dict): not read by this function; kept so the signature
            stays the same for existing callers
        text_dset: dataset indexed by example id
        text_field_name (string): name of the text column in `text_dset`
    Returns:
        plotly.graph_objects.Figure: one trace for the edges, one for the nodes
    """
    for nid, node in enumerate(node_list):
        if "label" in node:
            # An existing label is kept untouched.
            continue
        # Keyed by text so repeated sentences are only listed once.
        node_examples = {}
        for sid, score in node["sorted_examples_centroid"]:
            node_examples[text_dset[sid][text_field_name]] = score
            if len(node_examples) >= 5:
                break
        example_lines = [
            # The ellipsis is added only when the text was actually cut.
            f" {score:.2f} > {txt[:64]}" + ("..." if len(txt) > 64 else "")
            for txt, score in node_examples.items()
        ]
        node["label"] = (
            f"{nid:2d} - {node['weight']:5d} items <br>" + "<br>".join(example_lines)
        )

    labels = [node["label"] for node in node_list]

    root = node_list[0]
    root["X"] = 0
    root["Y"] = 0

    def rec_make_coordinates(node):
        total_weight = 0
        # Width taken by examples that belong to no displayed child; it is
        # shared evenly between the displayed children.
        add_weight = len(node["example_ids"]) - sum(
            child["weight"] for child in node["children"]
        )
        for child in node["children"]:
            child["X"] = node["X"] + total_weight
            child["Y"] = node["Y"] - 1
            total_weight += child["weight"] + add_weight / len(node["children"])
            rec_make_coordinates(child)

    rec_make_coordinates(root)

    Xn = []
    Yn = []
    Xe = []
    Ye = []
    for node in node_list:
        Xn.append(node["X"])
        Yn.append(node["Y"])
        for child in node["children"]:
            # None separates consecutive segments of the single line trace.
            Xe += [node["X"], child["X"], None]
            Ye += [node["Y"], child["Y"], None]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=Xe,
            y=Ye,
            mode="lines",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=Xn,
            y=Yn,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1),
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    return fig
|
|