Metin committed on
Commit
fdfe8da
·
1 Parent(s): 20b7f29
Files changed (7)
  1. src/config.py +0 -6
  2. src/demo.py +0 -11
  3. src/embedding.py +0 -1
  4. src/gnn.py +5 -22
  5. src/heuristic.py +1 -25
  6. src/utils.py +2 -30
  7. src/visualization.py +1 -1
src/config.py CHANGED
@@ -56,7 +56,6 @@ class Config(BaseSettings):
     }
 
     COLOR_MAPPING: dict[str, str] = {
-        # STEM & Natural Sciences -> Emerald (#06d6a0)
         'Biology': '#06d6a0',
         'Chemistry': '#06d6a0',
         'Earth_and_environment': '#06d6a0',
@@ -65,7 +64,6 @@ class Config(BaseSettings):
         'STEM': '#06d6a0',
         'Space': '#06d6a0',
 
-        # Geography & Places -> Ocean Blue (#118ab2)
         'Africa': '#118ab2',
         'Americas': '#118ab2',
         'Asia': '#118ab2',
@@ -73,7 +71,6 @@ class Config(BaseSettings):
         'Oceania': '#118ab2',
         'Geographical': '#118ab2',
 
-        # Arts, Entertainment & Culture -> Bubblegum Pink (#ef476f)
         'Entertainment': '#ef476f',
         'Fashion': '#ef476f',
         'Films': '#ef476f',
@@ -83,7 +80,6 @@ class Config(BaseSettings):
         'Visual_arts': '#ef476f',
         'Literature': '#ef476f',
 
-        # Tech, Engineering & Infrastructure -> Dark Teal (#073b4c)
         'Architecture': '#073b4c',
         'Computing': '#073b4c',
         'Engineering': '#073b4c',
@@ -92,7 +88,6 @@ class Config(BaseSettings):
         'Transportation': '#073b4c',
         'Video_games': '#073b4c',
 
-        # Society, Humanities & Lifestyle -> Coral Glow (#f78c6b)
         'Biography': '#f78c6b',
         'Food_and_drink': '#f78c6b',
         'Linguistics': '#f78c6b',
@@ -101,7 +96,6 @@ class Config(BaseSettings):
         'Society': '#f78c6b',
         'Sports': '#f78c6b',
 
-        # Institutions, History & Governance -> Royal Gold (#ffd166)
         'Business_and_economics': '#ffd166',
         'Education': '#ffd166',
         'History': '#ffd166',
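
Note that the deleted comments were the only record of the palette grouping (STEM -> emerald #06d6a0, geography -> ocean blue #118ab2, arts -> pink #ef476f, tech -> dark teal #073b4c, society -> coral #f78c6b, institutions -> gold #ffd166). A minimal sketch of how the same grouping could be kept in code rather than comments; the _PALETTE helper and its group names are hypothetical, not part of this repository:

# Hypothetical alternative: derive COLOR_MAPPING from explicit palette groups,
# so the thematic grouping survives without inline comments.
_PALETTE: dict[str, tuple[str, list[str]]] = {
    "stem": ("#06d6a0", ["Biology", "Chemistry", "Earth_and_environment", "STEM", "Space"]),
    "geography": ("#118ab2", ["Africa", "Americas", "Asia", "Oceania", "Geographical"]),
    "arts": ("#ef476f", ["Entertainment", "Fashion", "Films", "Visual_arts", "Literature"]),
    "tech": ("#073b4c", ["Architecture", "Computing", "Engineering", "Transportation", "Video_games"]),
    "society": ("#f78c6b", ["Biography", "Food_and_drink", "Linguistics", "Society", "Sports"]),
    "institutions": ("#ffd166", ["Business_and_economics", "Education", "History"]),
}

COLOR_MAPPING: dict[str, str] = {
    label: color for color, labels in _PALETTE.values() for label in labels
}

Either form yields the same flat COLOR_MAPPING; the comprehension just keeps the category-to-colour relationship queryable in code.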
src/demo.py CHANGED
@@ -85,17 +85,6 @@ if "setup_complete" not in st.session_state:
     st.session_state.setup_complete = True
 
 
-    # node_styles = [
-    #     NodeStyle("PERSON", "#FF7F3E", "name", "person"),
-    #     NodeStyle("POST", "#2A629A", "content", "description"),
-    # ]
-
-    # edge_styles = [
-    #     EdgeStyle("FOLLOWS", caption="label", directed=True),
-    #     EdgeStyle("POSTED", caption="label", directed=True),
-    #     EdgeStyle("QUOTES", caption="label", directed=True),
-    # ]
-
     node_styles = get_node_styles()
     edge_styles = get_edge_styles()
 
src/embedding.py CHANGED
@@ -9,7 +9,6 @@ from transformers import AutoModel, AutoTokenizer
 
 class Embedder:
     def __init__(self, path):
-        # time.sleep(1)
         self.model_name_or_path = path
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
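
Only a leftover time.sleep(1) is dropped here. For orientation, a minimal sketch of how an embedder built on AutoModel/AutoTokenizer typically turns an article's text into the fixed-size vector consumed downstream (the new_embedding in gnn.py); the SketchEmbedder class, its embed() method, and the mean-pooling step are assumptions, not code from this file:

import torch
from transformers import AutoModel, AutoTokenizer

class SketchEmbedder:
    # Hypothetical sibling of src/embedding.py's Embedder: same constructor shape,
    # plus an assumed embed() that mean-pools the last hidden state.
    def __init__(self, path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path).to(self.device).eval()

    @torch.no_grad()
    def embed(self, text: str) -> torch.Tensor:
        batch = self.tokenizer(text, return_tensors="pt", truncation=True).to(self.device)
        hidden = self.model(**batch).last_hidden_state           # (1, seq_len, hidden_dim)
        mask = batch["attention_mask"].unsqueeze(-1)              # (1, seq_len, 1)
        return ((hidden * mask).sum(1) / mask.sum(1)).squeeze(0)  # (hidden_dim,), e.g. 768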
src/gnn.py CHANGED
@@ -14,8 +14,6 @@ class GNNClassifier(torch.nn.Module):
         self.layers = layers
         self.output_dim = output_dim
 
-        # IMPROVEMENT 1: Reduce to 2 layers to prevent over-smoothing
-        # If you really need 3 layers, you must add Residual Connections (x = x + conv(x))
         if layers == 2:
             self.conv1 = GCNConv(input_dim, hidden_dim)
             self.conv2 = GCNConv(hidden_dim, output_dim)
@@ -27,15 +25,10 @@ class GNNClassifier(torch.nn.Module):
     def forward(self, data):
         x, edge_index = data.x, data.edge_index
 
-        # Layer 1
         x = self.conv1(x, edge_index)
         x = F.relu(x)
-
-        # IMPROVEMENT 2: Higher Dropout (0.5 is standard for citation networks)
-        # This prevents the model from relying too much on specific neighbor connections
         x = F.dropout(x, p=self.dropout_rate, training=self.training)
 
-        # Layer 2
         x = self.conv2(x, edge_index)
 
         if self.layers == 3:
@@ -63,8 +56,8 @@ def load_data(version: str = "undirected"):
 def infer_new_node(
     base_data: Data,
     model: torch.nn.Module,
-    new_embedding,  # shape (768,) list/np array/torch
-    referenced_titles: list[str],  # titles the user selected
+    new_embedding,
+    referenced_titles: list[str],
     title_to_id: dict[str, int],
     label_mapping: dict[str, int],
     device: torch.device,
@@ -73,11 +66,9 @@ def infer_new_node(
 ):
     model.eval()
 
-    # Move model to device
     model = model.to(device)
     base_data = base_data.to(device)
 
-    # --- 1) Prepare new node feature ---
     x_old = base_data.x
     new_x = torch.tensor(new_embedding, dtype=x_old.dtype).view(1, -1)
     new_x = new_x.to(device)
@@ -85,7 +76,6 @@ def infer_new_node(
 
     new_id = x.size(0) - 1
 
-    # --- 2) Build new edges that attach the node ---
     src_list = []
     tgt_list = []
 
@@ -94,19 +84,13 @@ def infer_new_node(
             continue
         old_id = title_to_id[t]
 
-        # If you want new node to be influenced by referenced nodes in 1 hop,
-        # you need edges old -> new (incoming to new).
         src_list.append(old_id)
         tgt_list.append(new_id)
 
-        # Optional: also add new -> old to make it undirected / symmetric
         if make_undirected_for_new_node:
             src_list.append(new_id)
             tgt_list.append(old_id)
 
-    # If the user picked nothing, the node is isolated; GCNConv can still work
-    # because it adds self-loops by default, but performance may be weak.
-
     if len(src_list) > 0 and use_edges:
         new_edges = torch.tensor([src_list, tgt_list], dtype=torch.long)
         new_edges = new_edges.to(device)
@@ -114,19 +98,18 @@ def infer_new_node(
     else:
         edge_index = base_data.edge_index
 
-    # --- 3) Run inference on the augmented graph ---
     data_aug = Data(x=x, edge_index=edge_index).to(device)
 
     with torch.no_grad():
-        out = model(data_aug)  # your model returns raw logits
+        out = model(data_aug)
         log_probs = F.log_softmax(out, dim=1)
-        log_probs = log_probs[new_id]  # get log-probs for the new node only
+        log_probs = log_probs[new_id]
         pred_id = int(torch.argmax(log_probs).item())
 
     inv_label_mapping = {v: k for k, v in label_mapping.items()}
     pred_label = inv_label_mapping[pred_id]
 
-    probs = log_probs.exp().detach().cpu()  # convert log-probs -> probs
+    probs = log_probs.exp().detach().cpu()
 
     columns = ["Class", "Score"]
     result_df = pd.DataFrame(
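
With the comments gone, the inference recipe in infer_new_node is easiest to see end to end: append the new article's embedding as one extra row of x, wire that row to the nodes it references (optionally in both directions), run the GCN on the augmented graph, and read the softmax of the new row only. A self-contained toy sketch of that mechanic; TinyGCN, the random features, and the reference ids are illustrative stand-ins, not the repository's model or data:

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# Toy stand-in for GNNClassifier: a 2-layer GCN returning raw logits.
class TinyGCN(torch.nn.Module):
    def __init__(self, in_dim: int, hid_dim: int, out_dim: int):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, out_dim)

    def forward(self, data: Data) -> torch.Tensor:
        h = F.relu(self.conv1(data.x, data.edge_index))
        return self.conv2(h, data.edge_index)

torch.manual_seed(0)
x = torch.randn(5, 8)                                     # 5 existing nodes, 8-dim features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])   # existing edges

# "New article": one extra feature row, attached symmetrically to the nodes it references.
new_x = torch.randn(1, 8)
x_aug = torch.cat([x, new_x], dim=0)
new_id = x_aug.size(0) - 1
refs = [0, 3]                                             # stand-ins for referenced node ids
extra = torch.tensor([refs + [new_id] * len(refs),
                      [new_id] * len(refs) + refs])
edge_aug = torch.cat([edge_index, extra], dim=1)

model = TinyGCN(in_dim=8, hid_dim=16, out_dim=3).eval()
with torch.no_grad():
    logits = model(Data(x=x_aug, edge_index=edge_aug))
    probs = F.softmax(logits[new_id], dim=0)              # class distribution for the new row only
print(probs)

The deleted comment also explained why an article with no selected references still gets a prediction: GCNConv adds self-loops by default, so the isolated row is classified from its own embedding alone, just with weaker context.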
src/heuristic.py CHANGED
@@ -13,70 +13,46 @@ def predict_topic_nth_degree(
     is_weighted: bool = False,
     decay_factor: float = 1.0,
 ) -> Optional[str]:
-    """
-    Predicts topic based on neighbors up to n-degrees away.
-
-    Args:
-        max_depth: How many hops to traverse (1 = direct neighbors, 2 = neighbors of neighbors).
-        decay_factor: Multiplier for distance. 1.0 = no decay.
-                      0.5 means a neighbor at depth 2 has half the voting power of depth 1.
-    """
-
-    # 1. Setup BFS
-    # Queue stores: (current_node_name, current_depth)
     queue = deque()
 
-    # We maintain a visited set to avoid cycles and processing the same node twice
     visited = set()
     visited.add(new_article_title)
 
-    # 2. Initialize BFS with the "Virtual" First Hop
-    # We iterate the input list 'edges' manually because the new article isn't in G.
     for ref in edges:
         if ref in G and ref not in visited:
             visited.add(ref)
-            queue.append((ref, 1))  # Depth 1
+            queue.append((ref, 1))
 
     if not queue:
         return None
 
     topic_scores = defaultdict(float)
 
-    # 3. Process BFS
     while queue:
         current_node, current_depth = queue.popleft()
 
-        # --- Score Calculation ---
         node_data = G.nodes[current_node]
         topic = node_data.get("label")
 
         if topic:
-            # Determine base weight
            if is_weighted:
                 neighbor_embedding = node_data["embedding"]
-                # Calculate similarity
                 base_score = cosine_similarity(
                     [new_article_embedding], [neighbor_embedding]
                 )[0][0]
             else:
                 base_score = 1.0
 
-            # Apply Distance Decay
-            # Formula: Score * (decay ^ (depth - 1))
-            # Depth 1: Score * 1
-            # Depth 2: Score * decay
             weighted_score = base_score * (decay_factor ** (current_depth - 1))
 
             topic_scores[topic] += weighted_score
 
-        # --- Expand to next level if within limit ---
         if current_depth < max_depth:
             for neighbor in G.neighbors(current_node):
                 if neighbor not in visited:
                     visited.add(neighbor)
                     queue.append((neighbor, current_depth + 1))
 
-    # 4. Determine Winner
     if not topic_scores:
         return None
 
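
The deleted docstring carried the scoring rule: each reachable neighbor votes for its label with weight base_score * decay_factor ** (depth - 1), where base_score is either 1.0 or the cosine similarity between the new article and that neighbor. A tiny worked sketch of the vote with made-up neighbors; the topics, scores, and depths below are illustrative, not data from the graph:

from collections import defaultdict

# Distance-decayed voting, as in predict_topic_nth_degree:
# score = base_score * decay_factor ** (depth - 1), summed per topic.
votes = [                     # (topic, base_score, depth) for hypothetical neighbors
    ("Biology", 1.0, 1),
    ("Biology", 1.0, 2),
    ("History", 1.0, 1),
]
decay_factor = 0.5

topic_scores = defaultdict(float)
for topic, base_score, depth in votes:
    topic_scores[topic] += base_score * (decay_factor ** (depth - 1))

print(dict(topic_scores))                        # {'Biology': 1.5, 'History': 1.0}
print(max(topic_scores, key=topic_scores.get))   # 'Biology'

With decay_factor=1.0 every hop counts equally; with 0.5 a depth-2 neighbor has half the voting power of a direct neighbor, exactly as the removed docstring described.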
src/utils.py CHANGED
@@ -48,42 +48,20 @@ def gather_neighbors(
 
 
 def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
-    """
-    Returns the neighbors of a node within a given depth in a format
-    compatible with Cytoscape-style visualizers.
-
-    Args:
-        graph (nx.Graph): The source NetworkX graph.
-        start_node: The title/ID of the node to start from.
-        depth (int): How many hops (degrees of separation) to traverse.
-
-    Returns:
-        dict: A dictionary containing 'nodes' and 'edges' formatted for the visualizer.
-    """
-
-    # 1. Create a subgraph of neighbors within the specified depth
-    # If the node doesn't exist, return empty structure or raise error
     if start_node not in graph:
         return {"nodes": [], "edges": []}
 
     subgraph = nx.ego_graph(graph, start_node, radius=depth)
 
-    # 2. Prepare data structures
     nodes_data = []
     edges_data = []
 
-    # Helper to map actual node names (titles) to integer IDs required by the format
-    # The example uses 1-based integers for IDs.
     node_to_id_map = {}
     current_id = 1
 
-    # 3. Process Nodes
     for node in subgraph.nodes():
-        # Assign an integer ID
         node_to_id_map[node] = current_id
 
-        # Get attributes (safely default if label is missing)
-        # We ignore 'embedding' as requested
         node_attrs = subgraph.nodes[node]
         label = node_attrs.get("label", "Unknown")
 
@@ -91,24 +69,20 @@ def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
             "data": {
                 "id": current_id,
                 "label": label,
-                "name": str(node),  # Using the node title/ID as 'name'
+                "name": str(node),
             }
         }
         nodes_data.append(node_obj)
         current_id += 1
 
-    # 4. Process Edges
-    # Edge IDs usually need to be unique strings or integers.
-    # We continue the counter from where nodes left off to ensure uniqueness.
     edge_id_counter = current_id
 
     for u, v in subgraph.edges():
         source_id = node_to_id_map[u]
         target_id = node_to_id_map[v]
 
-        # Get edge attributes if they exist (e.g., relationship type)
         edge_attrs = subgraph.edges[u, v]
-        edge_label = edge_attrs.get("label", "CITES")  # Default label if none exists
+        edge_label = edge_attrs.get("label", "CITES")
 
         edge_obj = {
             "data": {
@@ -121,7 +95,6 @@ def get_neighbors_for_visualizer(graph: nx.Graph, start_node, depth=1):
         edges_data.append(edge_obj)
         edge_id_counter += 1
 
-    # 5. Return the final structure
     return {"nodes": nodes_data, "edges": edges_data}
 
 
@@ -136,5 +109,4 @@ if __name__ == "__main__":
 
     neighbors = gather_neighbors(graph, test_title, test_references, depth=2)
 
-    # print(f"References for '{test_title}': {test_references}")
     print(f"Neighbors of '{test_title}': {neighbors}")
src/visualization.py CHANGED
@@ -4,7 +4,7 @@ from src.config import config
 def get_node_styles() -> list[NodeStyle]:
     node_styles = []
     for class_name in config.ICON_MAPPING.keys():
-        color = config.COLOR_MAPPING.get(class_name, "#888888")  # Default gray if not found
+        color = config.COLOR_MAPPING.get(class_name, "#888888")
         icon = config.ICON_MAPPING.get(class_name, None)
         node_styles.append(NodeStyle(
             label=class_name,