import gzip
import json
import math
import os
from os.path import exists
from os.path import join as pjoin

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import transformers
from datasets import load_dataset
from huggingface_hub import HfApi
from tqdm import tqdm

# NOTE: `prepare_clustering_dataset` is called in `set_features_dataset` below but
# had no import; it is assumed here to come from a local helper module.
from dataset_utils import prepare_clustering_dataset

pd.options.display.max_colwidth = 256

# Local directory where intermediate artifacts (embeddings, trees) are cached.
_CACHE_DIR = "cache_dir"

_DEFAULT_MODEL = "sentence-transformers/all-mpnet-base-v2"

# Cap on the number of candidate merges kept after scoring all pairs.
_MAX_MERGE = 20_000_000


def sentence_mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings into one sentence vector, ignoring padded positions."""
    token_embeddings = model_output[0]  # first element: token-level hidden states
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    # Sum the unmasked token vectors, then divide by the (clamped) number of real tokens.
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
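

# Minimal sanity-check sketch for `sentence_mean_pooling` (illustrative only, not part
# of the pipeline): with constant hidden states, masked positions must not change the
# pooled mean. The tuple mimics a transformers model output, whose first element holds
# the token-level hidden states.
def _demo_sentence_mean_pooling():
    hidden = torch.ones(2, 4, 8)  # (batch=2, seq_len=4, hidden_dim=8)
    mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
    pooled = sentence_mean_pooling((hidden,), mask)
    assert pooled.shape == (2, 8)
    assert torch.allclose(pooled, torch.ones(2, 8))
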

def get_examplars(example_ids, centroid, embeddings, dset, n_examplars):
    """Return the `n_examplars` examples closest to the cluster centroid, with scores."""
    example_embeds = embeddings[example_ids]
    example_scores = torch.mv(example_embeds, centroid)
    s_scores, s_ids = example_scores.sort(dim=-1, descending=True)
    examplars = [
        (example_ids[i.item()], s.item())
        for i, s in zip(s_ids[:n_examplars], s_scores[:n_examplars])
    ]
    res = []
    for eid, score in examplars:
        dct = dict(dset[eid])
        dct["score"] = score
        res += [dct]
    return res
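

# Illustrative sketch of `get_examplars` on toy data (not part of the pipeline): any
# indexable collection of row dicts works for `dset`, so a plain list stands in for
# the `datasets` object used in the real pipeline.
def _demo_get_examplars():
    embeddings = torch.nn.functional.normalize(torch.randn(5, 8), dim=-1)
    centroid = torch.nn.functional.normalize(torch.randn(8), dim=-1)
    dset = [{"field": f"example {i}"} for i in range(5)]
    top = get_examplars([0, 2, 4], centroid, embeddings, dset, n_examplars=2)
    # Two rows, each a copy of the original dict plus its similarity score.
    assert len(top) == 2 and all("score" in row for row in top)
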

def pretty_order(nodes, node_ids):
    """Order sibling ids so that the heaviest nodes end up in the middle of the row."""
    sorted_ids = sorted(node_ids, key=lambda nid: nodes[nid]["weight"])
    sorted_a = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 0]
    sorted_b = [nid for i, nid in enumerate(sorted_ids) if i % 2 == 1]
    sorted_b.reverse()
    return sorted_a + sorted_b
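

# Quick illustration of `pretty_order` (not part of the pipeline): with weights 1..5
# the heaviest node lands in the middle, which keeps tree plots visually balanced.
def _demo_pretty_order():
    nodes = {nid: {"weight": w} for nid, w in enumerate([3, 1, 5, 2, 4])}
    order = pretty_order(nodes, [0, 1, 2, 3, 4])
    assert [nodes[nid]["weight"] for nid in order] == [1, 3, 5, 4, 2]
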

def make_tree_plot(node_list, root_id, max_depth=-1):
    """Draw the cluster hierarchy rooted at `root_id` as a plotly line/marker figure."""
    plot_nodes = [{} for _ in node_list]

    root = {
        "parent_id": -1,
        "node_id": root_id,
        "label": node_list[root_id]["hover_text"],
        "weight": node_list[root_id]["weight"],
        "num_leaves": 0,
        "children_ids": node_list[root_id]["children_ids"],
        "Xmin": 0,
        "Y": 0,
    }
    plot_nodes[root_id] = root

    root_depth = node_list[root_id]["depth"]

    # Lay out the subtree: each leaf takes one unit of width, and each internal node
    # is centered above its children.
    def rec_make_coordinates(node):
        total_weight = 0
        recurse = (max_depth == -1) or (
            node_list[node["node_id"]]["depth"] - root_depth < max_depth - 1
        )
        for cid in node["children_ids"]:
            plot_nodes[cid] = {
                "parent_id": node["node_id"],
                "node_id": cid,
                "label": node_list[cid]["hover_text"],
                "weight": node_list[cid]["weight"],
                "children_ids": node_list[cid]["children_ids"] if recurse else [],
                "Xmin": node["Xmin"] + total_weight,
                "Y": node["Y"] - 1,
            }
            # A node truncated at max_depth counts as a single leaf.
            plot_nodes[cid]["num_leaves"] = (
                1 if len(plot_nodes[cid]["children_ids"]) == 0 else 0
            )
            rec_make_coordinates(plot_nodes[cid])
            total_weight += plot_nodes[cid]["num_leaves"]
            node["num_leaves"] += plot_nodes[cid]["num_leaves"]
        node["Xmax"] = node["Xmin"] + node["num_leaves"]
        node["X"] = node["Xmin"] + (node["num_leaves"] / 2)

    rec_make_coordinates(root)

    subtree_nodes = [node for node in plot_nodes if len(node) > 0]
    labels = [node["label"] for node in subtree_nodes]

    # Node coordinates and parent-child edge segments (None breaks the line trace).
    Xn = []
    Yn = []
    Xe = []
    Ye = []
    for node in subtree_nodes:
        Xn += [node["X"]]
        Yn += [node["Y"]]
        for cid in node["children_ids"]:
            child = plot_nodes[cid]
            Xe += [node["X"], child["X"], None]
            Ye += [node["Y"], child["Y"], None]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=Xe,
            y=Ye,
            mode="lines",
            name="",
            line=dict(color="rgb(210,210,210)", width=1),
            hoverinfo="none",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=Xn,
            y=Yn,
            mode="markers",
            name="nodes",
            marker=dict(
                symbol="circle-dot",
                size=18,
                color="#6175c1",
                line=dict(color="rgb(50,50,50)", width=1),
            ),
            text=labels,
            hoverinfo="text",
            opacity=0.8,
        )
    )
    fig.layout.showlegend = False
    return fig
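

# Illustrative usage of `make_tree_plot` on a hand-built three-node hierarchy (not
# part of the pipeline); each node only needs the fields the plotting code reads.
def _demo_make_tree_plot():
    node_list = [
        {"hover_text": "root", "weight": 2, "children_ids": [1, 2], "depth": 0},
        {"hover_text": "left", "weight": 1, "children_ids": [], "depth": 1},
        {"hover_text": "right", "weight": 1, "children_ids": [], "depth": 1},
    ]
    fig = make_tree_plot(node_list, root_id=0)
    fig.show()  # opens the interactive plot in a browser / notebook
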

class ClusteringBuilder:
    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        model_name=_DEFAULT_MODEL,
    ):
        """Builds item embeddings and a hierarchical clustering of a dataset field."""
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.cache_path_list = [
            _CACHE_DIR,
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
        ]
        self.cache_path = pjoin(*self.cache_path_list)
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name

    def set_model(self):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name)
        self.model = transformers.AutoModel.from_pretrained(self.model_name).to(
            self.device
        )

    def set_features_dataset(self, use_streaming, use_auth_token, use_dataset):
        dset, dset_path = prepare_clustering_dataset(
            dataset_name=self.dataset_name,
            input_field_path=self.input_field_path,
            label_name=self.label_name,
            config_name=self.config_name,
            split_name=self.split_name,
            num_rows=self.num_rows,
            use_streaming=use_streaming,
            use_auth_token=use_auth_token,
            use_dataset=use_dataset,
        )
        self.features_dset = dset

    def compute_feature_embeddings(self, sentences):
        """Embed a batch of strings and L2-normalize so dot products are cosines."""
        batch = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        batch = {k: v.to(self.device) for k, v in batch.items()}
        with torch.no_grad():
            model_output = self.model(**batch)
            sentence_embeds = sentence_mean_pooling(
                model_output, batch["attention_mask"]
            )
            sentence_embeds /= sentence_embeds.norm(dim=-1, keepdim=True)
        return sentence_embeds

    def set_embeddings_dataset(self):
        def batch_embed(examples):
            return {
                "embedding": [
                    embed.tolist()
                    for embed in self.compute_feature_embeddings(examples["field"])
                ]
            }

        if not exists(self.cache_path):
            # The cache path is several directories deep, so create parents too.
            os.makedirs(self.cache_path)

        self.embeddings_dset = self.features_dset.map(
            batch_embed,
            batched=True,
            batch_size=32,
            cache_file_name=pjoin(self.cache_path, "embeddings_dset"),
        )

    def prepare_embeddings(
        self,
        use_streaming=True,
        use_auth_token=None,
        use_dataset=None,
    ):
        self.set_model()
        self.set_features_dataset(use_streaming, use_auth_token, use_dataset)
        self.set_embeddings_dataset()

    def prepare_merges(self, batch_size, low_thres):
        """Score all embedding pairs blockwise; keep candidate merges above `low_thres`."""
        self.embeddings = torch.Tensor(self.embeddings_dset["embedding"])
        all_indices = torch.empty((0, 2), dtype=torch.long)
        all_scores = torch.empty((0,))
        n_batches = math.ceil(self.embeddings_dset.num_rows / batch_size)
        for a in range(n_batches):
            for b in tqdm(range(a, n_batches)):
                cos_scores = torch.mm(
                    self.embeddings[a * batch_size : (a + 1) * batch_size],
                    self.embeddings[b * batch_size : (b + 1) * batch_size].t(),
                )
                if a == b:
                    # Keep only the strict upper triangle so each pair is scored once.
                    cos_scores = cos_scores.triu(diagonal=1)
                merge_indices = torch.nonzero(cos_scores > low_thres)
                merge_indices[:, 0] += a * batch_size
                merge_indices[:, 1] += b * batch_size
                merge_scores = cos_scores[cos_scores > low_thres]
                all_indices = torch.cat([all_indices, merge_indices], dim=0)
                all_scores = torch.cat([all_scores, merge_scores], dim=0)
        # Keep at most _MAX_MERGE candidate pairs, best first.
        self.sorted_scores, sorted_score_ids = all_scores.sort(dim=0, descending=True)
        self.sorted_scores = self.sorted_scores[:_MAX_MERGE]
        sorted_score_ids = sorted_score_ids[:_MAX_MERGE]
        self.sorted_indices = all_indices[sorted_score_ids]
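
    # Standalone sketch of the blockwise candidate search used by `prepare_merges`
    # (illustrative only): scoring a small matrix of normalized embeddings block by
    # block and keeping above-threshold pairs should agree with scoring the full
    # similarity matrix at once.
    @staticmethod
    def _demo_blockwise_candidates(num_rows=10, batch_size=4, low_thres=0.5):
        embeddings = torch.nn.functional.normalize(torch.randn(num_rows, 8), dim=-1)
        pairs = []
        n_batches = math.ceil(num_rows / batch_size)
        for a in range(n_batches):
            for b in range(a, n_batches):
                block = torch.mm(
                    embeddings[a * batch_size : (a + 1) * batch_size],
                    embeddings[b * batch_size : (b + 1) * batch_size].t(),
                )
                if a == b:
                    block = block.triu(diagonal=1)
                idx = torch.nonzero(block > low_thres)
                idx[:, 0] += a * batch_size
                idx[:, 1] += b * batch_size
                pairs += [tuple(p.tolist()) for p in idx]
        full = torch.mm(embeddings, embeddings.t()).triu(diagonal=1)
        expected = {tuple(p.tolist()) for p in torch.nonzero(full > low_thres)}
        assert set(pairs) == expected
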
    def make_starting_nodes(self, identical_threshold):
        """Collapse near-identical pairs (score >= identical_threshold) into leaf nodes."""
        identical_indices = self.sorted_indices[
            self.sorted_scores >= identical_threshold
        ]
        # Stable double sort: order pairs by first id, breaking ties by second id.
        identical_inter = identical_indices[
            identical_indices[:, 1].sort(stable=True).indices
        ]
        identical_sorted = identical_inter[
            identical_inter[:, 0].sort(stable=True).indices
        ]
        self.parents = {}
        for a_pre, b_pre in identical_sorted:
            a = a_pre.item()
            b = b_pre.item()
            # Walk up to a's current root so duplicate chains stay flat.
            while self.parents.get(a, -1) != -1:
                a = self.parents[a]
            self.parents[b] = a
        self.duplicates = {}
        for a, b in self.parents.items():
            self.duplicates[b] = self.duplicates.get(b, []) + [a]
        self.nodes = {}
        for node_id in range(self.features_dset.num_rows):
            if node_id in self.parents:
                continue
            self.nodes[node_id] = {
                "node_id": node_id,
                "parent_id": -1,
                "children": [],
                "children_ids": [],
                "example_ids": [node_id],
                "weight": 1,
                "merge_threshold": 0.98,
                "depth": 0,
            }
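
    # Small sketch of the parent-chasing used throughout this class (illustrative
    # only): following `parents` links until a node has no recorded parent yields
    # the representative of its duplicate group.
    @staticmethod
    def _demo_find_root(parents=None, start=3):
        parents = {3: 2, 2: 0} if parents is None else parents  # chain: 3 -> 2 -> 0
        node = start
        while parents.get(node, -1) != -1:
            node = parents[node]
        return node  # 0 for the default chain
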
    def make_merge_nodes(self, identical_threshold, thres_step):
        """Agglomerate nodes by decreasing similarity, lowering the threshold in steps."""
        new_node_id = self.features_dset.num_rows
        current_thres = identical_threshold
        depth = 1
        merge_ids = self.sorted_indices[self.sorted_scores < identical_threshold]
        merge_scores = self.sorted_scores[self.sorted_scores < identical_threshold]
        for (node_id_a, node_id_b), merge_score in tqdm(
            zip(merge_ids, merge_scores), total=len(merge_ids)
        ):
            if merge_score.item() < current_thres:
                current_thres -= thres_step
            # Find each node's current root; keep `parents` keyed by ints (the
            # original tensor keys were never found again by int lookups) and
            # path-compress without ever creating a self-loop.
            id_a = node_id_a.item()
            merge_a = id_a
            while self.parents.get(merge_a, -1) != -1:
                merge_a = self.parents[merge_a]
            if merge_a != id_a:
                self.parents[id_a] = merge_a
            id_b = node_id_b.item()
            merge_b = id_b
            while self.parents.get(merge_b, -1) != -1:
                merge_b = self.parents[merge_b]
            if merge_b != id_b:
                self.parents[id_b] = merge_b
            if merge_a == merge_b:
                continue
            # Make node_a the higher (more recently created) id of the two roots.
            merge_b, merge_a = sorted([merge_a, merge_b])
            node_a = self.nodes[merge_a]
            node_b = self.nodes[merge_b]
            if (
                node_a["depth"] > 0
                and min(node_a["merge_threshold"], node_b["merge_threshold"])
                == current_thres
            ):
                # Same threshold band: absorb node_b into the existing merge node_a.
                node_a["depth"] = max(node_a["depth"], node_b["depth"])
                node_a["weight"] += node_b["weight"]
                node_a["children_ids"] += (
                    node_b["children_ids"]
                    if node_b["depth"] > 0
                    else [node_b["node_id"]]
                )
                for cid in node_b["children_ids"]:
                    self.nodes[cid]["parent_id"] = node_a["node_id"]
                    self.parents[cid] = node_a["node_id"]
                node_b["parent_id"] = node_a["node_id"]
                self.parents[node_b["node_id"]] = node_a["node_id"]
            else:
                # New threshold band: create a fresh parent over both roots.
                new_nid = new_node_id
                new_node_id += 1
                new_node = {
                    "node_id": new_nid,
                    "parent_id": -1,
                    "children_ids": [node_a["node_id"], node_b["node_id"]],
                    "example_ids": [],
                    "weight": node_a["weight"] + node_b["weight"],
                    "merge_threshold": current_thres,
                    "depth": max(node_a["depth"], node_b["depth"]) + 1,
                }
                depth = max(depth, new_node["depth"])
                node_a["parent_id"] = new_nid
                node_b["parent_id"] = new_nid
                self.parents[node_a["node_id"]] = new_nid
                self.parents[node_b["node_id"]] = new_nid
                self.parents[id_a] = new_nid
                self.parents[id_b] = new_nid
                self.nodes[new_nid] = new_node
        return new_node_id

    def collapse_nodes(self, node, min_weight):
        """Prune children lighter than `min_weight`, gathering their examples as extras."""
        children = [
            self.collapse_nodes(self.nodes[cid], min_weight)
            for cid in node["children_ids"]
            if self.nodes[cid]["weight"] >= min_weight
        ]
        extras = [
            lid
            for cid in node["children_ids"]
            if self.nodes[cid]["weight"] < min_weight
            for lid in self.collapse_nodes(self.nodes[cid], min_weight)["example_ids"]
        ] + node["example_ids"]
        extras_embed = (
            torch.cat(
                [self.embeddings[eid][None, :] for eid in extras],
                dim=0,
            ).sum(dim=0)
            if len(extras) > 0
            else torch.zeros(self.embeddings.shape[-1])
        )
        if len(children) == 0:
            node["extras"] = extras
            node["children_ids"] = []
            node["example_ids"] = extras
            node["embedding_sum"] = extras_embed
        elif len(children) == 1:
            # A single surviving child is spliced out: its children move up a level.
            node["extras"] = extras + children[0]["extras"]
            node["children_ids"] = children[0]["children_ids"]
            node["example_ids"] = extras + children[0]["example_ids"]
            node["embedding_sum"] = extras_embed + children[0]["embedding_sum"]
        else:
            node["extras"] = extras
            node["children_ids"] = [child["node_id"] for child in children]
            node["example_ids"] = extras + [
                eid for child in children for eid in child["example_ids"]
            ]
            node["embedding_sum"] = (
                extras_embed
                + torch.cat(
                    [child["embedding_sum"][None, :] for child in children],
                    dim=0,
                ).sum(dim=0)
            )
        assert (
            len(node["example_ids"]) == node["weight"]
        ), f"stuck at {node['node_id']} - {len(node['example_ids'])} - {node['weight']}"
        return node

    def finalize_node(self, node, parent_id, n_examplars, with_labels):
        """Renumber `node` into `self.tree_node_list`, computing centroid and exemplars."""
        new_node_id = len(self.tree_node_list)
        new_node = {
            "node_id": new_node_id,
            "parent_id": parent_id,
            "depth": 0
            if parent_id == -1
            else self.tree_node_list[parent_id]["depth"] + 1,
            "merged_at": node["merge_threshold"],
            "weight": node["weight"],
            "is_extra": False,
        }
        self.tree_node_list += [new_node]
        centroid = node["embedding_sum"] / node["embedding_sum"].norm()
        new_node["centroid"] = centroid.tolist()
        new_node["examplars"] = get_examplars(
            node["example_ids"],
            centroid,
            self.embeddings,
            self.features_dset,
            n_examplars,
        )
        label_counts = {}
        if with_labels:
            for eid in node["example_ids"]:
                label = self.features_dset[eid]["label"]
                label_counts[label] = label_counts.get(label, 0) + 1
            new_node["label_counts"] = sorted(
                label_counts.items(), key=lambda x: x[1], reverse=True
            )
        if len(node["children_ids"]) == 0:
            new_node["children_ids"] = []
        else:
            children = [
                self.nodes[cid]
                for cid in pretty_order(self.nodes, node["children_ids"])
            ]
            children_ids = [
                self.finalize_node(child, new_node_id, n_examplars, with_labels)
                for child in children
            ]
            new_node["children_ids"] = children_ids
        if len(node["extras"]) > 0:
            # Examples absorbed from pruned children get their own "extra" leaf.
            extra_node = {
                "node_id": len(self.tree_node_list),
                "parent_id": new_node_id,
                "depth": new_node["depth"] + 1,
                "merged_at": node["merge_threshold"],
                "weight": len(node["extras"]),
                "is_extra": True,
                "centroid": new_node["centroid"],
                "examplars": get_examplars(
                    node["extras"],
                    centroid,
                    self.embeddings,
                    self.features_dset,
                    n_examplars,
                ),
            }
            self.tree_node_list += [extra_node]
            label_counts = {}
            if with_labels:
                for eid in node["extras"]:
                    label = self.features_dset[eid]["label"]
                    label_counts[label] = label_counts.get(label, 0) + 1
                extra_node["label_counts"] = sorted(
                    label_counts.items(), key=lambda x: x[1], reverse=True
                )
            extra_node["children_ids"] = []
            new_node["children_ids"] += [extra_node["node_id"]]
        return new_node_id

    def make_hover_text(self, num_examples=5, text_width=64, with_labels=False):
        """Pre-render each node's hover text: size, threshold, top exemplars, labels."""
        for nid, node in enumerate(self.tree_node_list):
            line_list = [
                f"Node {nid:3d} - {node['weight']:6d} items - Linking threshold: {node['merged_at']:.2f}"
            ]
            for examplar in node["examplars"][:num_examples]:
                line_list += [
                    f"{examplar['ids']:6d}:{examplar['score']:.2f} - {examplar['field'][:text_width]}"
                    + (f" - {examplar['label']}" if with_labels else "")
                ]
            if with_labels:
                line_list += ["Label distribution"]
                for label, count in node["label_counts"]:
                    line_list += [f"    - label: {label} - {count} items"]
            node["hover_text"] = "<br>".join(line_list)

    def build_tree(
        self,
        batch_size=10000,
        low_thres=0.5,
        identical_threshold=0.95,
        thres_step=0.05,
        min_weight=10,
        n_examplars=25,
        hover_examples=5,
        hover_text_width=64,
    ):
        """Run the full pipeline: score pairs, merge, prune, and render hover text."""
        self.prepare_merges(batch_size, low_thres)
        self.make_starting_nodes(identical_threshold)

        root_node_id = self.make_merge_nodes(identical_threshold, thres_step)
        # Tie any remaining forest roots together under a single synthetic root.
        top_nodes = [node for node in self.nodes.values() if node["parent_id"] == -1]
        root_node = {
            "node_id": root_node_id,
            "parent_id": -1,
            "children_ids": [node["node_id"] for node in top_nodes],
            "example_ids": [],
            "weight": sum([node["weight"] for node in top_nodes]),
            "merge_threshold": -1.0,
            "depth": 1 + max([node["depth"] for node in top_nodes]),
        }
        for node in top_nodes:
            node["parent_id"] = root_node_id
        self.nodes[root_node_id] = root_node
        _ = self.collapse_nodes(root_node, min_weight)
        self.tree_node_list = []
        self.finalize_node(
            root_node,
            -1,
            n_examplars,
            with_labels=(self.label_name is not None),
        )
        self.make_hover_text(
            num_examples=hover_examples,
            text_width=hover_text_width,
            with_labels=(self.label_name is not None),
        )

    def push_to_hub(self, use_auth_token=None, file_name=None):
        """Serialize the tree as gzipped JSON lines and upload it to the Hub."""
        path_list = self.cache_path_list
        name = "tree" if file_name is None else file_name
        tree_file = pjoin(pjoin(*path_list), f"{name}.jsonl.gz")
        with gzip.open(tree_file, "w") as fout:
            for node in tqdm(self.tree_node_list):
                _ = fout.write((json.dumps(node) + "\n").encode("utf-8"))
        api = HfApi()
        file_loc = api.upload_file(
            path_or_fileobj=tree_file,
            path_in_repo=pjoin(pjoin(*path_list[1:]), f"{name}.jsonl.gz"),
            repo_id="yjernite/datasets_clusters",
            token=use_auth_token,
            repo_type="dataset",
        )
        return file_loc
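

# Illustrative end-to-end usage of `ClusteringBuilder` (a sketch, not executed at
# import time): the dataset, field path, and label below are hypothetical placeholders,
# and `prepare_embeddings` downloads both the dataset and the sentence-transformer.
def _demo_build_clustering():
    builder = ClusteringBuilder(
        dataset_name="ag_news",       # hypothetical example dataset
        config_name=None,             # defaults to "default"
        split_name=None,              # defaults to "train"
        input_field_path=("text",),   # path to the text field to embed
        label_name="label",
        num_rows=1000,
    )
    builder.prepare_embeddings(use_streaming=True)
    builder.build_tree(batch_size=1000, min_weight=5)
    return builder.tree_node_list
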

class Clustering:
    def __init__(
        self,
        dataset_name,
        config_name,
        split_name,
        input_field_path,
        label_name,
        num_rows,
        n_examplars=10,
        model_name=_DEFAULT_MODEL,
        file_name=None,
        max_depth_subtree=3,
    ):
        """Load a pre-computed cluster tree from the Hub and serve cached views of it."""
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split_name = split_name
        self.input_field_path = input_field_path
        self.label_name = label_name
        self.num_rows = num_rows
        self.model_name = model_name
        self.n_examplars = n_examplars
        self.file_name = "tree" if file_name is None else file_name
        self.repo_path_list = [
            dataset_name.replace("/", "---"),
            f"{'default' if config_name is None else config_name}",
            f"{'train' if split_name is None else split_name}",
            f"field-{'->'.join(input_field_path)}-label-{label_name}",
            f"{num_rows}_rows",
            model_name.replace("/", "---"),
            f"{self.file_name}.jsonl.gz",
        ]
        self.repo_path = pjoin(*self.repo_path_list)
        self.node_list = load_dataset(
            "yjernite/datasets_clusters", data_files=[self.repo_path]
        )["train"]
        # One lazy cache dict per node for the plots and tables computed below.
        self.node_reps = [{} for _ in self.node_list]
        self.max_depth_subtree = max_depth_subtree

    def set_full_tree(self):
        # `dict.get(key, default)` would rebuild the plot on every call (the default
        # argument is evaluated eagerly), so check membership instead.
        if "tree" not in self.node_reps[0]:
            self.node_reps[0]["tree"] = make_tree_plot(self.node_list, 0)

    def get_full_tree(self):
        self.set_full_tree()
        return self.node_reps[0]["tree"]

    def set_node_subtree(self, node_id):
        if "subtree" not in self.node_reps[node_id]:
            self.node_reps[node_id]["subtree"] = make_tree_plot(
                self.node_list, node_id, self.max_depth_subtree
            )

    def get_node_subtree(self, node_id):
        self.set_node_subtree(node_id)
        return self.node_reps[node_id]["subtree"]

    def set_node_examplars(self, node_id):
        if "examplars" not in self.node_reps[node_id]:
            self.node_reps[node_id]["examplars"] = pd.DataFrame(
                [
                    {
                        "id": exple["ids"],
                        "score": exple["score"],
                        "field": exple["field"],
                        "label": exple.get("label", "N/A"),
                    }
                    for exple in self.node_list[node_id]["examplars"]
                ][: self.n_examplars]
            )

    def get_node_examplars(self, node_id):
        self.set_node_examplars(node_id)
        return self.node_reps[node_id]["examplars"]

    def set_node_label_chart(self, node_id):
        if "label_chart" not in self.node_reps[node_id]:
            self.node_reps[node_id]["label_chart"] = px.pie(
                values=[ct for lab, ct in self.node_list[node_id]["label_counts"]],
                names=[
                    f"Label {lab}"
                    for lab, ct in self.node_list[node_id]["label_counts"]
                ],
                color_discrete_sequence=px.colors.sequential.Rainbow,
                width=400,
                height=400,
            )

    def get_node_label_chart(self, node_id):
        self.set_node_label_chart(node_id)
        return self.node_reps[node_id]["label_chart"]
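

# Illustrative usage of `Clustering` (a sketch; arguments are hypothetical and must
# match a tree previously pushed with `ClusteringBuilder.push_to_hub`):
def _demo_explore_clusters():
    clusters = Clustering(
        dataset_name="ag_news",       # hypothetical example dataset
        config_name=None,
        split_name=None,
        input_field_path=("text",),
        label_name="label",
        num_rows=1000,
    )
    clusters.get_full_tree().show()        # whole hierarchy
    print(clusters.get_node_examplars(0))  # exemplar table for the root node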