Spaces:
Sleeping
Sleeping
Clean up
Browse files- src/config.py +0 -6
- src/demo.py +0 -11
- src/embedding.py +0 -1
- src/gnn.py +5 -22
- src/heuristic.py +1 -25
- src/utils.py +2 -30
- src/visualization.py +1 -1
src/config.py
CHANGED
|
@@ -56,7 +56,6 @@ class Config(BaseSettings):
|
|
| 56 |
}
|
| 57 |
|
| 58 |
COLOR_MAPPING: dict[str, str] = {
|
| 59 |
-
# STEM & Natural Sciences -> Emerald (#06d6a0)
|
| 60 |
'Biology': '#06d6a0',
|
| 61 |
'Chemistry': '#06d6a0',
|
| 62 |
'Earth_and_environment': '#06d6a0',
|
|
@@ -65,7 +64,6 @@ class Config(BaseSettings):
|
|
| 65 |
'STEM': '#06d6a0',
|
| 66 |
'Space': '#06d6a0',
|
| 67 |
|
| 68 |
-
# Geography & Places -> Ocean Blue (#118ab2)
|
| 69 |
'Africa': '#118ab2',
|
| 70 |
'Americas': '#118ab2',
|
| 71 |
'Asia': '#118ab2',
|
|
@@ -73,7 +71,6 @@ class Config(BaseSettings):
|
|
| 73 |
'Oceania': '#118ab2',
|
| 74 |
'Geographical': '#118ab2',
|
| 75 |
|
| 76 |
-
# Arts, Entertainment & Culture -> Bubblegum Pink (#ef476f)
|
| 77 |
'Entertainment': '#ef476f',
|
| 78 |
'Fashion': '#ef476f',
|
| 79 |
'Films': '#ef476f',
|
|
@@ -83,7 +80,6 @@ class Config(BaseSettings):
|
|
| 83 |
'Visual_arts': '#ef476f',
|
| 84 |
'Literature': '#ef476f',
|
| 85 |
|
| 86 |
-
# Tech, Engineering & Infrastructure -> Dark Teal (#073b4c)
|
| 87 |
'Architecture': '#073b4c',
|
| 88 |
'Computing': '#073b4c',
|
| 89 |
'Engineering': '#073b4c',
|
|
@@ -92,7 +88,6 @@ class Config(BaseSettings):
|
|
| 92 |
'Transportation': '#073b4c',
|
| 93 |
'Video_games': '#073b4c',
|
| 94 |
|
| 95 |
-
# Society, Humanities & Lifestyle -> Coral Glow (#f78c6b)
|
| 96 |
'Biography': '#f78c6b',
|
| 97 |
'Food_and_drink': '#f78c6b',
|
| 98 |
'Linguistics': '#f78c6b',
|
|
@@ -101,7 +96,6 @@ class Config(BaseSettings):
|
|
| 101 |
'Society': '#f78c6b',
|
| 102 |
'Sports': '#f78c6b',
|
| 103 |
|
| 104 |
-
# Institutions, History & Governance -> Royal Gold (#ffd166)
|
| 105 |
'Business_and_economics': '#ffd166',
|
| 106 |
'Education': '#ffd166',
|
| 107 |
'History': '#ffd166',
|
|
|
|
| 56 |
}
|
| 57 |
|
| 58 |
COLOR_MAPPING: dict[str, str] = {
|
|
|
|
| 59 |
'Biology': '#06d6a0',
|
| 60 |
'Chemistry': '#06d6a0',
|
| 61 |
'Earth_and_environment': '#06d6a0',
|
|
|
|
| 64 |
'STEM': '#06d6a0',
|
| 65 |
'Space': '#06d6a0',
|
| 66 |
|
|
|
|
| 67 |
'Africa': '#118ab2',
|
| 68 |
'Americas': '#118ab2',
|
| 69 |
'Asia': '#118ab2',
|
|
|
|
| 71 |
'Oceania': '#118ab2',
|
| 72 |
'Geographical': '#118ab2',
|
| 73 |
|
|
|
|
| 74 |
'Entertainment': '#ef476f',
|
| 75 |
'Fashion': '#ef476f',
|
| 76 |
'Films': '#ef476f',
|
|
|
|
| 80 |
'Visual_arts': '#ef476f',
|
| 81 |
'Literature': '#ef476f',
|
| 82 |
|
|
|
|
| 83 |
'Architecture': '#073b4c',
|
| 84 |
'Computing': '#073b4c',
|
| 85 |
'Engineering': '#073b4c',
|
|
|
|
| 88 |
'Transportation': '#073b4c',
|
| 89 |
'Video_games': '#073b4c',
|
| 90 |
|
|
|
|
| 91 |
'Biography': '#f78c6b',
|
| 92 |
'Food_and_drink': '#f78c6b',
|
| 93 |
'Linguistics': '#f78c6b',
|
|
|
|
| 96 |
'Society': '#f78c6b',
|
| 97 |
'Sports': '#f78c6b',
|
| 98 |
|
|
|
|
| 99 |
'Business_and_economics': '#ffd166',
|
| 100 |
'Education': '#ffd166',
|
| 101 |
'History': '#ffd166',
|
src/demo.py
CHANGED
|
@@ -85,17 +85,6 @@ if "setup_complete" not in st.session_state:
|
|
| 85 |
st.session_state.setup_complete = True
|
| 86 |
|
| 87 |
|
| 88 |
-
# node_styles = [
|
| 89 |
-
# NodeStyle("PERSON", "#FF7F3E", "name", "person"),
|
| 90 |
-
# NodeStyle("POST", "#2A629A", "content", "description"),
|
| 91 |
-
# ]
|
| 92 |
-
|
| 93 |
-
# edge_styles = [
|
| 94 |
-
# EdgeStyle("FOLLOWS", caption="label", directed=True),
|
| 95 |
-
# EdgeStyle("POSTED", caption="label", directed=True),
|
| 96 |
-
# EdgeStyle("QUOTES", caption="label", directed=True),
|
| 97 |
-
# ]
|
| 98 |
-
|
| 99 |
node_styles = get_node_styles()
|
| 100 |
edge_styles = get_edge_styles()
|
| 101 |
|
|
|
|
| 85 |
st.session_state.setup_complete = True
|
| 86 |
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
node_styles = get_node_styles()
|
| 89 |
edge_styles = get_edge_styles()
|
| 90 |
|
src/embedding.py
CHANGED
|
@@ -9,7 +9,6 @@ from transformers import AutoModel, AutoTokenizer
|
|
| 9 |
|
| 10 |
class Embedder:
|
| 11 |
def __init__(self, path):
|
| 12 |
-
# time.sleep(1)
|
| 13 |
self.model_name_or_path = path
|
| 14 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 15 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
|
|
|
|
| 9 |
|
| 10 |
class Embedder:
|
| 11 |
def __init__(self, path):
|
|
|
|
| 12 |
self.model_name_or_path = path
|
| 13 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
|
src/gnn.py
CHANGED
|
@@ -14,8 +14,6 @@ class GNNClassifier(torch.nn.Module):
|
|
| 14 |
self.layers = layers
|
| 15 |
self.output_dim = output_dim
|
| 16 |
|
| 17 |
-
# IMPROVEMENT 1: Reduce to 2 layers to prevent over-smoothing
|
| 18 |
-
# If you really need 3 layers, you must add Residual Connections (x = x + conv(x))
|
| 19 |
if layers == 2:
|
| 20 |
self.conv1 = GCNConv(input_dim, hidden_dim)
|
| 21 |
self.conv2 = GCNConv(hidden_dim, output_dim)
|
|
@@ -27,15 +25,10 @@ class GNNClassifier(torch.nn.Module):
|
|
| 27 |
def forward(self, data):
|
| 28 |
x, edge_index = data.x, data.edge_index
|
| 29 |
|
| 30 |
-
# Layer 1
|
| 31 |
x = self.conv1(x, edge_index)
|
| 32 |
x = F.relu(x)
|
| 33 |
-
|
| 34 |
-
# IMPROVEMENT 2: Higher Dropout (0.5 is standard for citation networks)
|
| 35 |
-
# This prevents the model from relying too much on specific neighbor connections
|
| 36 |
x = F.dropout(x, p=self.dropout_rate, training=self.training)
|
| 37 |
|
| 38 |
-
# Layer 2
|
| 39 |
x = self.conv2(x, edge_index)
|
| 40 |
|
| 41 |
if self.layers == 3:
|
|
@@ -63,8 +56,8 @@ def load_data(version: str = "undirected"):
|
|
| 63 |
def infer_new_node(
|
| 64 |
base_data: Data,
|
| 65 |
model: torch.nn.Module,
|
| 66 |
-
new_embedding,
|
| 67 |
-
referenced_titles: list[str],
|
| 68 |
title_to_id: dict[str, int],
|
| 69 |
label_mapping: dict[str, int],
|
| 70 |
device: torch.device,
|
|
@@ -73,11 +66,9 @@ def infer_new_node(
|
|
| 73 |
):
|
| 74 |
model.eval()
|
| 75 |
|
| 76 |
-
# Move model to device
|
| 77 |
model = model.to(device)
|
| 78 |
base_data = base_data.to(device)
|
| 79 |
|
| 80 |
-
# --- 1) Prepare new node feature ---
|
| 81 |
x_old = base_data.x
|
| 82 |
new_x = torch.tensor(new_embedding, dtype=x_old.dtype).view(1, -1)
|
| 83 |
new_x = new_x.to(device)
|
|
@@ -85,7 +76,6 @@ def infer_new_node(
|
|
| 85 |
|
| 86 |
new_id = x.size(0) - 1
|
| 87 |
|
| 88 |
-
# --- 2) Build new edges that attach the node ---
|
| 89 |
src_list = []
|
| 90 |
tgt_list = []
|
| 91 |
|
|
@@ -94,19 +84,13 @@ def infer_new_node(
|
|
| 94 |
continue
|
| 95 |
old_id = title_to_id[t]
|
| 96 |
|
| 97 |
-
# If you want new node to be influenced by referenced nodes in 1 hop,
|
| 98 |
-
# you need edges old -> new (incoming to new).
|
| 99 |
src_list.append(old_id)
|
| 100 |
tgt_list.append(new_id)
|
| 101 |
|
| 102 |
-
# Optional: also add new -> old to make it undirected / symmetric
|
| 103 |
if make_undirected_for_new_node:
|
| 104 |
src_list.append(new_id)
|
| 105 |
tgt_list.append(old_id)
|
| 106 |
|
| 107 |
-
# If the user picked nothing, the node is isolated; GCNConv can still work
|
| 108 |
-
# because it adds self-loops by default, but performance may be weak.
|
| 109 |
-
|
| 110 |
if len(src_list) > 0 and use_edges:
|
| 111 |
new_edges = torch.tensor([src_list, tgt_list], dtype=torch.long)
|
| 112 |
new_edges = new_edges.to(device)
|
|
@@ -114,19 +98,18 @@ def infer_new_node(
|
|
| 114 |
else:
|
| 115 |
edge_index = base_data.edge_index
|
| 116 |
|
| 117 |
-
# --- 3) Run inference on the augmented graph ---
|
| 118 |
data_aug = Data(x=x, edge_index=edge_index).to(device)
|
| 119 |
|
| 120 |
with torch.no_grad():
|
| 121 |
-
out = model(data_aug)
|
| 122 |
log_probs = F.log_softmax(out, dim=1)
|
| 123 |
-
log_probs = log_probs[new_id]
|
| 124 |
pred_id = int(torch.argmax(log_probs).item())
|
| 125 |
|
| 126 |
inv_label_mapping = {v: k for k, v in label_mapping.items()}
|
| 127 |
pred_label = inv_label_mapping[pred_id]
|
| 128 |
|
| 129 |
-
probs = log_probs.exp().detach().cpu()
|
| 130 |
|
| 131 |
columns = ["Class", "Score"]
|
| 132 |
result_df = pd.DataFrame(
|
|
|
|
| 14 |
self.layers = layers
|
| 15 |
self.output_dim = output_dim
|
| 16 |
|
|
|
|
|
|
|
| 17 |
if layers == 2:
|
| 18 |
self.conv1 = GCNConv(input_dim, hidden_dim)
|
| 19 |
self.conv2 = GCNConv(hidden_dim, output_dim)
|
|
|
|
| 25 |
def forward(self, data):
|
| 26 |
x, edge_index = data.x, data.edge_index
|
| 27 |
|
|
|
|
| 28 |
x = self.conv1(x, edge_index)
|
| 29 |
x = F.relu(x)
|
|
|
|
|
|
|
|
|
|
| 30 |
x = F.dropout(x, p=self.dropout_rate, training=self.training)
|
| 31 |
|
|
|
|
| 32 |
x = self.conv2(x, edge_index)
|
| 33 |
|
| 34 |
if self.layers == 3:
|
|
|
|
| 56 |
def infer_new_node(
|
| 57 |
base_data: Data,
|
| 58 |
model: torch.nn.Module,
|
| 59 |
+
new_embedding,
|
| 60 |
+
referenced_titles: list[str],
|
| 61 |
title_to_id: dict[str, int],
|
| 62 |
label_mapping: dict[str, int],
|
| 63 |
device: torch.device,
|
|
|
|
| 66 |
):
|
| 67 |
model.eval()
|
| 68 |
|
|
|
|
| 69 |
model = model.to(device)
|
| 70 |
base_data = base_data.to(device)
|
| 71 |
|
|
|
|
| 72 |
x_old = base_data.x
|
| 73 |
new_x = torch.tensor(new_embedding, dtype=x_old.dtype).view(1, -1)
|
| 74 |
new_x = new_x.to(device)
|
|
|
|
| 76 |
|
| 77 |
new_id = x.size(0) - 1
|
| 78 |
|
|
|
|
| 79 |
src_list = []
|
| 80 |
tgt_list = []
|
| 81 |
|
|
|
|
| 84 |
continue
|
| 85 |
old_id = title_to_id[t]
|
| 86 |
|
|
|
|
|
|
|
| 87 |
src_list.append(old_id)
|
| 88 |
tgt_list.append(new_id)
|
| 89 |
|
|
|
|
| 90 |
if make_undirected_for_new_node:
|
| 91 |
src_list.append(new_id)
|
| 92 |
tgt_list.append(old_id)
|
| 93 |
|
|
|
|
|
|
|
|
|
|
| 94 |
if len(src_list) > 0 and use_edges:
|
| 95 |
new_edges = torch.tensor([src_list, tgt_list], dtype=torch.long)
|
| 96 |
new_edges = new_edges.to(device)
|
|
|
|
| 98 |
else:
|
| 99 |
edge_index = base_data.edge_index
|
| 100 |
|
|
|
|
| 101 |
data_aug = Data(x=x, edge_index=edge_index).to(device)
|
| 102 |
|
| 103 |
with torch.no_grad():
|
| 104 |
+
out = model(data_aug)
|
| 105 |
log_probs = F.log_softmax(out, dim=1)
|
| 106 |
+
log_probs = log_probs[new_id]
|
| 107 |
pred_id = int(torch.argmax(log_probs).item())
|
| 108 |
|
| 109 |
inv_label_mapping = {v: k for k, v in label_mapping.items()}
|
| 110 |
pred_label = inv_label_mapping[pred_id]
|
| 111 |
|
| 112 |
+
probs = log_probs.exp().detach().cpu()
|
| 113 |
|
| 114 |
columns = ["Class", "Score"]
|
| 115 |
result_df = pd.DataFrame(
|
src/heuristic.py
CHANGED
|
@@ -13,70 +13,46 @@ def predict_topic_nth_degree(
|
|
| 13 |
is_weighted: bool = False,
|
| 14 |
decay_factor: float = 1.0,
|
| 15 |
) -> Optional[str]:
|
| 16 |
-
"""
|
| 17 |
-
Predicts topic based on neighbors up to n-degrees away.
|
| 18 |
-
|
| 19 |
-
Args:
|
| 20 |
-
max_depth: How many hops to traverse (1 = direct neighbors, 2 = neighbors of neighbors).
|
| 21 |
-
decay_factor: Multiplier for distance. 1.0 = no decay.
|
| 22 |
-
0.5 means a neighbor at depth 2 has half the voting power of depth 1.
|
| 23 |
-
"""
|
| 24 |
-
|
| 25 |
-
# 1. Setup BFS
|
| 26 |
-
# Queue stores: (current_node_name, current_depth)
|
| 27 |
queue = deque()
|
| 28 |
|
| 29 |
-
# We maintain a visited set to avoid cycles and processing the same node twice
|
| 30 |
visited = set()
|
| 31 |
visited.add(new_article_title)
|
| 32 |
|
| 33 |
-
# 2. Initialize BFS with the "Virtual" First Hop
|
| 34 |
-
# We iterate the input list 'edges' manually because the new article isn't in G.
|
| 35 |
for ref in edges:
|
| 36 |
if ref in G and ref not in visited:
|
| 37 |
visited.add(ref)
|
| 38 |
-
queue.append((ref, 1))
|
| 39 |
|
| 40 |
if not queue:
|
| 41 |
return None
|
| 42 |
|
| 43 |
topic_scores = defaultdict(float)
|
| 44 |
|
| 45 |
-
# 3. Process BFS
|
| 46 |
while queue:
|
| 47 |
current_node, current_depth = queue.popleft()
|
| 48 |
|
| 49 |
-
# --- Score Calculation ---
|
| 50 |
node_data = G.nodes[current_node]
|
| 51 |
topic = node_data.get("label")
|
| 52 |
|
| 53 |
if topic:
|
| 54 |
-
# Determine base weight
|
| 55 |
if is_weighted:
|
| 56 |
neighbor_embedding = node_data["embedding"]
|
| 57 |
-
# Calculate similarity
|
| 58 |
base_score = cosine_similarity(
|
| 59 |
[new_article_embedding], [neighbor_embedding]
|
| 60 |
)[0][0]
|
| 61 |
else:
|
| 62 |
base_score = 1.0
|
| 63 |
|
| 64 |
-
# Apply Distance Decay
|
| 65 |
-
# Formula: Score * (decay ^ (depth - 1))
|
| 66 |
-
# Depth 1: Score * 1
|
| 67 |
-
# Depth 2: Score * decay
|
| 68 |
weighted_score = base_score * (decay_factor ** (current_depth - 1))
|
| 69 |
|
| 70 |
topic_scores[topic] += weighted_score
|
| 71 |
|
| 72 |
-
# --- Expand to next level if within limit ---
|
| 73 |
if current_depth < max_depth:
|
| 74 |
for neighbor in G.neighbors(current_node):
|
| 75 |
if neighbor not in visited:
|
| 76 |
visited.add(neighbor)
|
| 77 |
queue.append((neighbor, current_depth + 1))
|
| 78 |
|
| 79 |
-
# 4. Determine Winner
|
| 80 |
if not topic_scores:
|
| 81 |
return None
|
| 82 |
|
|
|
|
| 13 |
is_weighted: bool = False,
|
| 14 |
decay_factor: float = 1.0,
|
| 15 |
) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
queue = deque()
|
| 17 |
|
|
|
|
| 18 |
visited = set()
|
| 19 |
visited.add(new_article_title)
|
| 20 |
|
|
|
|
|
|
|
| 21 |
for ref in edges:
|
| 22 |
if ref in G and ref not in visited:
|
| 23 |
visited.add(ref)
|
| 24 |
+
queue.append((ref, 1))
|
| 25 |
|
| 26 |
if not queue:
|
| 27 |
return None
|
| 28 |
|
| 29 |
topic_scores = defaultdict(float)
|
| 30 |
|
|
|
|
| 31 |
while queue:
|
| 32 |
current_node, current_depth = queue.popleft()
|
| 33 |
|
|
|
|
| 34 |
node_data = G.nodes[current_node]
|
| 35 |
topic = node_data.get("label")
|
| 36 |
|
| 37 |
if topic:
|
|
|
|
| 38 |
if is_weighted:
|
| 39 |
neighbor_embedding = node_data["embedding"]
|
|
|
|
| 40 |
base_score = cosine_similarity(
|
| 41 |
[new_article_embedding], [neighbor_embedding]
|
| 42 |
)[0][0]
|
| 43 |
else:
|
| 44 |
base_score = 1.0
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
weighted_score = base_score * (decay_factor ** (current_depth - 1))
|
| 47 |
|
| 48 |
topic_scores[topic] += weighted_score
|
| 49 |
|
|
|
|
| 50 |
if current_depth < max_depth:
|
| 51 |
for neighbor in G.neighbors(current_node):
|
| 52 |
if neighbor not in visited:
|
| 53 |
visited.add(neighbor)
|
| 54 |
queue.append((neighbor, current_depth + 1))
|
| 55 |
|
|
|
|
| 56 |
if not topic_scores:
|
| 57 |
return None
|
| 58 |
|
src/utils.py
CHANGED
|
@@ -48,42 +48,20 @@ def gather_neighbors(
|
|
| 48 |
|
| 49 |
|
| 50 |
def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
|
| 51 |
-
"""
|
| 52 |
-
Returns the neighbors of a node within a given depth in a format
|
| 53 |
-
compatible with Cytoscape-style visualizers.
|
| 54 |
-
|
| 55 |
-
Args:
|
| 56 |
-
graph (nx.Graph): The source NetworkX graph.
|
| 57 |
-
start_node: The title/ID of the node to start from.
|
| 58 |
-
depth (int): How many hops (degrees of separation) to traverse.
|
| 59 |
-
|
| 60 |
-
Returns:
|
| 61 |
-
dict: A dictionary containing 'nodes' and 'edges' formatted for the visualizer.
|
| 62 |
-
"""
|
| 63 |
-
|
| 64 |
-
# 1. Create a subgraph of neighbors within the specified depth
|
| 65 |
-
# If the node doesn't exist, return empty structure or raise error
|
| 66 |
if start_node not in graph:
|
| 67 |
return {"nodes": [], "edges": []}
|
| 68 |
|
| 69 |
subgraph = nx.ego_graph(graph, start_node, radius=depth)
|
| 70 |
|
| 71 |
-
# 2. Prepare data structures
|
| 72 |
nodes_data = []
|
| 73 |
edges_data = []
|
| 74 |
|
| 75 |
-
# Helper to map actual node names (titles) to integer IDs required by the format
|
| 76 |
-
# The example uses 1-based integers for IDs.
|
| 77 |
node_to_id_map = {}
|
| 78 |
current_id = 1
|
| 79 |
|
| 80 |
-
# 3. Process Nodes
|
| 81 |
for node in subgraph.nodes():
|
| 82 |
-
# Assign an integer ID
|
| 83 |
node_to_id_map[node] = current_id
|
| 84 |
|
| 85 |
-
# Get attributes (safely default if label is missing)
|
| 86 |
-
# We ignore 'embedding' as requested
|
| 87 |
node_attrs = subgraph.nodes[node]
|
| 88 |
label = node_attrs.get("label", "Unknown")
|
| 89 |
|
|
@@ -91,24 +69,20 @@ def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
|
|
| 91 |
"data": {
|
| 92 |
"id": current_id,
|
| 93 |
"label": label,
|
| 94 |
-
"name": str(node),
|
| 95 |
}
|
| 96 |
}
|
| 97 |
nodes_data.append(node_obj)
|
| 98 |
current_id += 1
|
| 99 |
|
| 100 |
-
# 4. Process Edges
|
| 101 |
-
# Edge IDs usually need to be unique strings or integers.
|
| 102 |
-
# We continue the counter from where nodes left off to ensure uniqueness.
|
| 103 |
edge_id_counter = current_id
|
| 104 |
|
| 105 |
for u, v in subgraph.edges():
|
| 106 |
source_id = node_to_id_map[u]
|
| 107 |
target_id = node_to_id_map[v]
|
| 108 |
|
| 109 |
-
# Get edge attributes if they exist (e.g., relationship type)
|
| 110 |
edge_attrs = subgraph.edges[u, v]
|
| 111 |
-
edge_label = edge_attrs.get("label", "CITES")
|
| 112 |
|
| 113 |
edge_obj = {
|
| 114 |
"data": {
|
|
@@ -121,7 +95,6 @@ def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
|
|
| 121 |
edges_data.append(edge_obj)
|
| 122 |
edge_id_counter += 1
|
| 123 |
|
| 124 |
-
# 5. Return the final structure
|
| 125 |
return {"nodes": nodes_data, "edges": edges_data}
|
| 126 |
|
| 127 |
|
|
@@ -136,5 +109,4 @@ if __name__ == "__main__":
|
|
| 136 |
|
| 137 |
neighbors = gather_neighbors(graph, test_title, test_references, depth=2)
|
| 138 |
|
| 139 |
-
# print(f"References for '{test_title}': {test_references}")
|
| 140 |
print(f"Neighbors of '{test_title}': {neighbors}")
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
if start_node not in graph:
|
| 52 |
return {"nodes": [], "edges": []}
|
| 53 |
|
| 54 |
subgraph = nx.ego_graph(graph, start_node, radius=depth)
|
| 55 |
|
|
|
|
| 56 |
nodes_data = []
|
| 57 |
edges_data = []
|
| 58 |
|
|
|
|
|
|
|
| 59 |
node_to_id_map = {}
|
| 60 |
current_id = 1
|
| 61 |
|
|
|
|
| 62 |
for node in subgraph.nodes():
|
|
|
|
| 63 |
node_to_id_map[node] = current_id
|
| 64 |
|
|
|
|
|
|
|
| 65 |
node_attrs = subgraph.nodes[node]
|
| 66 |
label = node_attrs.get("label", "Unknown")
|
| 67 |
|
|
|
|
| 69 |
"data": {
|
| 70 |
"id": current_id,
|
| 71 |
"label": label,
|
| 72 |
+
"name": str(node),
|
| 73 |
}
|
| 74 |
}
|
| 75 |
nodes_data.append(node_obj)
|
| 76 |
current_id += 1
|
| 77 |
|
|
|
|
|
|
|
|
|
|
| 78 |
edge_id_counter = current_id
|
| 79 |
|
| 80 |
for u, v in subgraph.edges():
|
| 81 |
source_id = node_to_id_map[u]
|
| 82 |
target_id = node_to_id_map[v]
|
| 83 |
|
|
|
|
| 84 |
edge_attrs = subgraph.edges[u, v]
|
| 85 |
+
edge_label = edge_attrs.get("label", "CITES")
|
| 86 |
|
| 87 |
edge_obj = {
|
| 88 |
"data": {
|
|
|
|
| 95 |
edges_data.append(edge_obj)
|
| 96 |
edge_id_counter += 1
|
| 97 |
|
|
|
|
| 98 |
return {"nodes": nodes_data, "edges": edges_data}
|
| 99 |
|
| 100 |
|
|
|
|
| 109 |
|
| 110 |
neighbors = gather_neighbors(graph, test_title, test_references, depth=2)
|
| 111 |
|
|
|
|
| 112 |
print(f"Neighbors of '{test_title}': {neighbors}")
|
src/visualization.py
CHANGED
|
@@ -4,7 +4,7 @@ from src.config import config
|
|
| 4 |
def get_node_styles() -> list[NodeStyle]:
|
| 5 |
node_styles = []
|
| 6 |
for class_name in config.ICON_MAPPING.keys():
|
| 7 |
-
color = config.COLOR_MAPPING.get(class_name, "#888888")
|
| 8 |
icon = config.ICON_MAPPING.get(class_name, None)
|
| 9 |
node_styles.append(NodeStyle(
|
| 10 |
label=class_name,
|
|
|
|
| 4 |
def get_node_styles() -> list[NodeStyle]:
|
| 5 |
node_styles = []
|
| 6 |
for class_name in config.ICON_MAPPING.keys():
|
| 7 |
+
color = config.COLOR_MAPPING.get(class_name, "#888888")
|
| 8 |
icon = config.ICON_MAPPING.get(class_name, None)
|
| 9 |
node_styles.append(NodeStyle(
|
| 10 |
label=class_name,
|