Andrej Janchevski committed on
Commit
bc4fc5c
·
1 Parent(s): 469d523

feat(kganomaly): add streaming denoising backend with KG-likelihood metric

Browse files

- New endpoints emit Server-Sent Events via api/renderers.py so the
diffusion reverse process streams progress + frame previews instead of
blocking on a single response.
- kg_likelihood.py exposes a per-step mean log-sigmoid score from the
frozen KG embedder + link ranker; kg_anomaly_inference.py logs it on
frame boundaries and attaches kg_log_likelihood / kg_log_likelihood_step
to progress events so the UI can render a "denoising-is-working"
trace alongside step-duration sparklines.
- registry._build_sample_subgraphs now accepts a seed (so each request
gets a different DFS partition), shuffles candidate (row, col) pairs,
and rejects ill-shaped bipartite samples whose halves can't form valid
inpaint quadrants. Sampler.get_context_subgraph_samples_dfs gains the
matching seed parameter in the research code.
- api.yaml + backend README document the new optional progress fields.

docs/api.yaml CHANGED
@@ -1806,3 +1806,14 @@ components:
1806
  total:
1807
  type: integer
1808
  description: Total steps in the stage
 
 
 
 
 
 
 
 
 
 
 
 
1806
  total:
1807
  type: integer
1808
  description: Total steps in the stage
1809
+ kg_log_likelihood:
1810
+ type: number
1811
+ nullable: true
1812
+ description: >
1813
+ Mean log-sigmoid score from the frozen KG embedder + link ranker
1814
+ applied to the edges currently present in the argmax reconstruction.
1815
+ Higher = cleaner. Present only on frame-boundary events.
1816
+ kg_log_likelihood_step:
1817
+ type: integer
1818
+ nullable: true
1819
+ description: Step index that `kg_log_likelihood` corresponds to.
src/backend/README.md CHANGED
@@ -111,6 +111,11 @@ event: progress
111
  data: {"type":"progress","phase":"denoise","step":42,"total_steps":500,"elapsed_ms":2100}
112
  ```
113
 
 
 
 
 
 
114
  **`event: preview`** — base64 PNG of the graph's current state, emitted at key frames:
115
  ```
116
  event: preview
 
111
  data: {"type":"progress","phase":"denoise","step":42,"total_steps":500,"elapsed_ms":2100}
112
  ```
113
 
114
+ KG-anomaly progress events additionally carry an optional `kg_log_likelihood`
115
+ (float) + `kg_log_likelihood_step` (int) on frame boundaries — the mean
116
+ log-sigmoid score from the frozen KG embedder + link ranker on the edges
117
+ currently present in the argmax reconstruction. Higher = cleaner.
118
+
119
  **`event: preview`** — base64 PNG of the graph's current state, emitted at key frames:
120
  ```
121
  event: preview
src/backend/api/renderers.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rest_framework.renderers import BaseRenderer
2
+
3
+
4
class EventStreamRenderer(BaseRenderer):
    """Renderer advertising ``text/event-stream`` to DRF content negotiation.

    The streaming views hand back a StreamingHttpResponse themselves, so this
    renderer never serializes anything; it exists solely so clients sending
    ``Accept: text/event-stream`` pass DRF's Accept-header negotiation.
    """

    media_type = "text/event-stream"
    format = "sse"
    charset = None
    render_style = "binary"

    def render(self, data, accepted_media_type=None, renderer_context=None):
        # The SSE body comes from the StreamingHttpResponse, never from here.
        return b""
src/backend/api/services/kg_anomaly_inference.py CHANGED
@@ -1,13 +1,27 @@
1
  import base64
 
2
  import io
 
 
 
3
  import time
 
 
 
 
 
4
 
5
  import torch
6
  import torch.nn.functional as F
 
 
 
7
 
8
  from api.services.graphgen_inference import (
9
  _frames_to_gif_b64, _pil_to_b64,
10
  )
 
 
11
 
12
  STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
13
  REQUIRED_STATE_KEYS = {
@@ -76,7 +90,7 @@ def build_kg_tensors(subgraph, loader, model):
76
  X_c[0, i] = int(communities[eid])
77
 
78
  n_nodes = torch.tensor([n], dtype=torch.long)
79
- is_bip = torch.tensor([n > 20], dtype=torch.bool)
80
  node_mask = torch.ones(1, n, dtype=torch.bool)
81
 
82
  return {
@@ -90,14 +104,16 @@ def _to_device(t, device):
90
  return t.to(device) if isinstance(t, torch.Tensor) else t
91
 
92
 
93
- def apply_edge_noise(model, tensors, task, noise_level, seed=None):
 
94
  """Forward-diffuse the given subgraph's edges at t = noise_level * T.
95
 
96
  For task="correct", only edges inside the inpaint mask (the second half of
97
  nodes) are noised, matching what the correction endpoint will regenerate.
98
  For task="generate", every edge slot is noised.
99
 
100
- Returns a new list of {source_idx, target_idx, relation_id} dicts.
 
101
  """
102
  from graph_generation.src.utils import get_inpaint_mask
103
  from graph_generation.src.diffusion import diffusion_utils
@@ -137,6 +153,24 @@ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
137
  E_mixed = E_noised * inpaint_mask + E * (~inpaint_mask)
138
  E_int = E_mixed[0].argmax(dim=-1).cpu()
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  edges = []
141
  for i in range(n):
142
  for j in range(n):
@@ -145,9 +179,15 @@ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
145
  cls = int(E_int[i, j])
146
  if cls == 0:
147
  continue
148
- edges.append({
149
- "source_idx": i, "target_idx": j, "relation_id": cls - 1,
150
- })
 
 
 
 
 
 
151
  return edges
152
 
153
 
@@ -155,7 +195,7 @@ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
155
  # Change detection
156
  # ---------------------------------------------------------------------------
157
 
158
- def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
159
  """Compute before/after edge diff for a directed KG subgraph.
160
 
161
  original_E_int / corrected_E_int: 2-D int tensors (n, n) where 0 = no edge
@@ -163,6 +203,13 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
163
  """
164
  _, _, inv_relations = loader.dataset.get_inverted_name_maps()
165
 
 
 
 
 
 
 
 
166
  edges = []
167
  summary = {"added": 0, "removed": 0, "modified": 0, "unchanged": 0}
168
 
@@ -182,7 +229,7 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
182
  edges.append({
183
  "source_idx": i, "target_idx": j, "change": "unchanged",
184
  "relation_id": c - 1,
185
- "relation_name": str(inv_relations.get(c - 1, c - 1)),
186
  })
187
  continue
188
  if o == 0 and c > 0:
@@ -190,23 +237,23 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
190
  edges.append({
191
  "source_idx": i, "target_idx": j, "change": "added",
192
  "relation_id": c - 1,
193
- "relation_name": str(inv_relations.get(c - 1, c - 1)),
194
  })
195
  elif o > 0 and c == 0:
196
  summary["removed"] += 1
197
  edges.append({
198
  "source_idx": i, "target_idx": j, "change": "removed",
199
  "original_relation_id": o - 1,
200
- "original_relation_name": str(inv_relations.get(o - 1, o - 1)),
201
  })
202
  else:
203
  summary["modified"] += 1
204
  edges.append({
205
  "source_idx": i, "target_idx": j, "change": "modified",
206
  "original_relation_id": o - 1,
207
- "original_relation_name": str(inv_relations.get(o - 1, o - 1)),
208
  "relation_id": c - 1,
209
- "relation_name": str(inv_relations.get(c - 1, c - 1)),
210
  })
211
 
212
  return {"edges": edges, "summary": summary}
@@ -216,50 +263,117 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
216
  # Rendering
217
  # ---------------------------------------------------------------------------
218
 
219
- def _format_entity_label(dataset_id, name):
220
- s = str(name)
221
- if dataset_id == "freebase":
222
- s = s.replace("/m/", "")
223
- elif dataset_id == "wordnet":
224
- s = s.split(".")[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  else:
226
- if "concept" in s:
227
- parts = s.split(":")
228
- s = parts[-2] if "new" in s and len(parts) >= 2 else parts[-1]
229
- if len(s) > 14:
230
- s = s[:13] + "…"
231
- return s
232
-
233
-
234
- def _format_relation_label(dataset_id, name):
235
- s = str(name)
236
- if dataset_id == "freebase":
237
- parts = s.split(".")
238
- s = ".".join(["_".join(p.split("/")[-2:]) for p in parts])
239
- elif dataset_id == "wordnet":
240
- s = s[1:] if s.startswith("_") else s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  else:
242
- if "concept" in s:
243
- parts = s.split(":")
244
- s = parts[-2] if "new" in s and len(parts) >= 2 else parts[-1]
245
- if len(s) > 16:
246
- s = s[:15] + "…"
247
- return s
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
 
250
- def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=None):
251
  """Render a directed KG subgraph as a PIL image using networkx + PIL.
252
 
253
- Does not use matplotlib (same reason as graphgen_inference: Windows thread safety).
 
 
 
 
 
 
254
  """
255
- import networkx as nx
256
- from PIL import Image, ImageDraw, ImageFont
257
-
258
- inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
259
 
260
  e = E_int.cpu().tolist()
261
  xi = X_index.cpu().tolist()
262
 
 
 
 
 
 
263
  G = nx.DiGraph()
264
  for i in range(num_nodes):
265
  G.add_node(i)
@@ -270,105 +384,162 @@ def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=No
270
  if int(e[i][j]) > 0:
271
  G.add_edge(i, j, rel=int(e[i][j]) - 1)
272
 
273
- pos = nx.spring_layout(G, seed=42)
274
-
275
- # Build change lookup: (i, j) -> change_type
276
- change_lookup = {}
277
  if changes is not None:
278
- for entry in changes.get("edges", []):
279
- change_lookup[(entry["source_idx"], entry["target_idx"])] = entry["change"]
 
280
 
281
- size = 500
282
- margin = 50
283
- scale = (size - 2 * margin) / 2
284
- cx, cy = size / 2, size / 2
285
- pixel_pos = {k: (cx + v[0] * scale, cy + v[1] * scale) for k, v in pos.items()}
 
 
 
286
 
 
287
  img = Image.new("RGB", (size, size), "white")
288
  draw = ImageDraw.Draw(img)
 
 
 
 
289
  try:
290
- font = ImageFont.truetype("arial.ttf", 11)
291
- small_font = ImageFont.truetype("arial.ttf", 9)
292
  except (OSError, IOError):
293
  font = ImageFont.load_default()
294
- small_font = font
295
 
296
- node_r = 10
 
 
 
 
 
297
 
298
- # Draw edges first (so nodes overlay them)
299
- # Include "removed" edges from change_lookup even if not in G
300
- all_edges = set((i, j) for i, j in G.edges())
301
- if changes is not None:
302
- for (i, j), ct in change_lookup.items():
303
- if ct == "removed":
304
- all_edges.add((i, j))
305
-
306
- for (i, j) in all_edges:
307
- change_type = change_lookup.get((i, j))
308
- color = CHANGE_COLORS.get(change_type, "#444444") if changes is not None else "#444444"
309
- dashed = (change_type == "removed")
310
- x0, y0 = pixel_pos[i]
311
- x1, y1 = pixel_pos[j]
312
- # Shorten line to not overlap node circles
313
- dx, dy = x1 - x0, y1 - y0
314
- dist = max(1.0, (dx * dx + dy * dy) ** 0.5)
315
- ux, uy = dx / dist, dy / dist
316
- sx, sy = x0 + ux * node_r, y0 + uy * node_r
317
- ex, ey = x1 - ux * node_r, y1 - uy * node_r
318
- if dashed:
319
- _draw_dashed(draw, (sx, sy), (ex, ey), color, width=2, dash=6)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  else:
321
- draw.line([(sx, sy), (ex, ey)], fill=color, width=2)
322
- # Arrowhead
323
- _draw_arrowhead(draw, (ex, ey), (ux, uy), color)
324
- # Relation label
325
- if (i, j) in G.edges():
326
- rel_id = G.edges[(i, j)]["rel"]
327
- rel_name = _format_relation_label(dataset_id, inv_relations.get(rel_id, rel_id))
328
- mx, my = (sx + ex) / 2, (sy + ey) / 2
329
- draw.text((mx + 3, my - 5), rel_name, fill=color, font=small_font)
330
-
331
- # Draw nodes
332
- for i in range(num_nodes):
333
- x, y = pixel_pos[i]
 
 
 
334
  draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
335
- fill="#2ecc71", outline="#1a7a42")
336
- eid = int(xi[i]) if i < len(xi) else i
337
- label = _format_entity_label(dataset_id, inv_nodes.get(eid, eid))
338
- draw.text((x + node_r + 2, y - 6), label, fill="#111111", font=font)
 
 
 
 
 
 
339
 
340
  return img
341
 
342
 
343
- def _draw_arrowhead(draw, tip, direction, color):
344
- import math
345
- ux, uy = direction
346
- angle = math.atan2(uy, ux)
347
- ah_len = 7
348
- ah_angle = math.radians(25)
349
- x, y = tip
350
- x1 = x - ah_len * math.cos(angle - ah_angle)
351
- y1 = y - ah_len * math.sin(angle - ah_angle)
352
- x2 = x - ah_len * math.cos(angle + ah_angle)
353
- y2 = y - ah_len * math.sin(angle + ah_angle)
354
- draw.polygon([(x, y), (x1, y1), (x2, y2)], fill=color)
355
 
356
 
357
- def _draw_dashed(draw, start, end, color, width=2, dash=6):
358
- x0, y0 = start
359
- x1, y1 = end
360
- dx, dy = x1 - x0, y1 - y0
361
- dist = max(1.0, (dx * dx + dy * dy) ** 0.5)
362
- steps = int(dist // dash)
363
- ux, uy = dx / dist, dy / dist
364
- for k in range(steps):
365
- if k % 2 == 1:
366
- continue
367
- sx = x0 + ux * dash * k
368
- sy = y0 + uy * dash * k
369
- ex = x0 + ux * dash * min(k + 1, steps)
370
- ey = y0 + uy * dash * min(k + 1, steps)
371
- draw.line([(sx, sy), (ex, ey)], fill=color, width=width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
 
374
  # ---------------------------------------------------------------------------
@@ -406,17 +577,21 @@ def run_standard_correction(model, tensors, dataset_id, task, loader,
406
  E_given = tensors["E_given"].to(device)
407
  y_given = tensors["y_given"].to(device)
408
  X_index = tensors["X_index"].to(device)
 
409
  is_bip = tensors["is_bip"].to(device)
410
  n_nodes = tensors["n_nodes"].to(device)
411
  node_mask = tensors["node_mask"].to(device)
412
  n_max = n_nodes.item()
 
413
 
414
  inpaint_mask = _build_inpaint_mask(
415
  task, node_mask, is_bip, model.Edim_output, device)
 
416
 
417
  original_E_int = E_given[0].argmax(dim=-1).long() # (n, n)
418
  original_img = render_kg_subgraph(
419
- original_E_int, n_max, X_index[0], dataset_id, loader, changes=None)
 
420
 
421
  model_T = model.T
422
  step_stride = max(1, model_T // diffusion_steps)
@@ -451,17 +626,28 @@ def run_standard_correction(model, tensors, dataset_id, task, loader,
451
  }
452
  if is_frame:
453
  frame = render_kg_subgraph(
454
- E_int_prev, n_max, X_index[0], dataset_id, loader)
 
455
  gif_frames.append(frame)
456
  event["preview"] = _pil_to_b64(frame)
 
 
 
 
 
 
 
 
 
457
  yield event
458
 
459
  X_final, E_final = _collapse_final_kg(model, X, E, y, node_mask)
460
 
461
  corrected_E_int = E_final[0]
462
- changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
463
  corrected_img = render_kg_subgraph(
464
- corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
 
465
 
466
  elapsed_ms = int((time.time() - t0) * 1000)
467
  yield {
@@ -493,9 +679,11 @@ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
493
 
494
  inpaint_mask = _build_inpaint_mask(
495
  task, node_mask, is_bip, model.Edim_output, device)
 
496
  original_E_int = E_given[0].argmax(dim=-1).long()
497
  original_img = render_kg_subgraph(
498
- original_E_int, n_max, X_index[0], dataset_id, loader, changes=None)
 
499
 
500
  t0 = time.time()
501
  # Sample initial noise for each of M Gibbs chains
@@ -524,9 +712,10 @@ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
524
  agg_y = torch.median(y_ens.float(), dim=1).values
525
  X_int, E_int = _collapse_final_kg(model, X_given, agg_E, agg_y, node_mask)
526
  corrected_E_int = E_int[0]
527
- changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
528
  preview_img = render_kg_subgraph(
529
- corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
 
530
  elapsed_ms = int((time.time() - t0) * 1000)
531
 
532
  state = {
@@ -535,11 +724,13 @@ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
535
  "y": y_ens.cpu(),
536
  "n_nodes": n_nodes.cpu(),
537
  "dataset_id": dataset_id,
 
538
  "task": task,
539
  "X_index": X_index.cpu(),
540
  "X_c": X_c.cpu(),
541
  "is_bip": is_bip.cpu(),
542
  "original_E_int": original_E_int.cpu(),
 
543
  "T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
544
  "gibbs_chain_freq": gibbs_chain_freq,
545
  "inner_step": 0, "step": 0,
@@ -562,9 +753,12 @@ def run_multiprox_correction_step(model, state, loader):
562
  E = state["E"].to(device)
563
  y = state["y"].to(device)
564
  X_index = state["X_index"].to(device)
 
565
  is_bip = state["is_bip"].to(device)
566
  n_nodes = state["n_nodes"].to(device)
567
  original_E_int = state["original_E_int"].to(device)
 
 
568
 
569
  T = state["T"]
570
  n = state["n"]
@@ -578,6 +772,7 @@ def run_multiprox_correction_step(model, state, loader):
578
  n_max = int(n_nodes.item())
579
  node_mask = torch.ones(1, n_max, dtype=torch.bool, device=device)
580
  inpaint_mask = _build_inpaint_mask(task, node_mask, is_bip, model.Edim_output, device)
 
581
 
582
  fixed_t_norm = t * torch.ones((1, 1), dtype=torch.float, device=device)
583
  fixed_s_norm = fixed_t_norm - (1.0 / T)
@@ -608,8 +803,9 @@ def run_multiprox_correction_step(model, state, loader):
608
  prev_y = torch.median(y.float(), dim=1).values
609
  _, prev_Ei = _collapse_final_kg(model, X_given, prev_E, prev_y, node_mask)
610
  preview_img = render_kg_subgraph(
611
- prev_Ei[0], n_max, X_index[0], dataset_id, loader)
612
- yield {
 
613
  "type": "progress",
614
  "phase": "gibbs",
615
  "step": i + 1,
@@ -617,6 +813,16 @@ def run_multiprox_correction_step(model, state, loader):
617
  "elapsed_ms": int((time.time() - t0) * 1000),
618
  "preview": _pil_to_b64(preview_img),
619
  }
 
 
 
 
 
 
 
 
 
 
620
 
621
  new_inner_step = inner_step + steps_this_call
622
  round_complete = new_inner_step >= m
@@ -649,21 +855,34 @@ def run_multiprox_correction_step(model, state, loader):
649
  "elapsed_ms": int((time.time() - t0) * 1000),
650
  }
651
  if is_frame:
 
652
  event["preview"] = _pil_to_b64(render_kg_subgraph(
653
- discrete_s.E[0].long(), n_max, X_index[0], dataset_id, loader))
 
 
 
 
 
 
 
 
 
 
654
  yield event
655
 
656
  X_int, E_int = _collapse_final_kg(model, cur_X, cur_E, cur_y, node_mask)
657
 
658
  corrected_E_int = E_int[0]
659
- changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
660
  corrected_img = render_kg_subgraph(
661
- corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
 
662
  elapsed_ms = int((time.time() - t0) * 1000)
663
 
664
  updated_state = {
665
  **state,
666
  "E": E.cpu(), "y": y.cpu(),
 
667
  "step": new_step, "inner_step": new_inner_step,
668
  }
669
  yield {
 
1
  import base64
2
+ import faulthandler
3
  import io
4
+ import logging
5
+ import math
6
+ import sys
7
  import time
8
+ import traceback
9
+
10
+ faulthandler.enable(file=sys.stderr, all_threads=True)
11
+
12
+ logger = logging.getLogger(__name__)
13
 
14
  import torch
15
  import torch.nn.functional as F
16
+ import numpy as np
17
+ import networkx as nx
18
+ from PIL import Image, ImageDraw, ImageFont
19
 
20
  from api.services.graphgen_inference import (
21
  _frames_to_gif_b64, _pil_to_b64,
22
  )
23
+ from api.services.kg_likelihood import kg_edge_log_likelihood
24
+ from api.utils import clean_entity_name, clean_relation_name
25
 
26
  STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
27
  REQUIRED_STATE_KEYS = {
 
90
  X_c[0, i] = int(communities[eid])
91
 
92
  n_nodes = torch.tensor([n], dtype=torch.long)
93
+ is_bip = torch.tensor([bool(subgraph.get("is_bip", False))], dtype=torch.bool)
94
  node_mask = torch.ones(1, n, dtype=torch.bool)
95
 
96
  return {
 
104
  return t.to(device) if isinstance(t, torch.Tensor) else t
105
 
106
 
107
+ def apply_edge_noise(model, tensors, task, noise_level, seed=None,
108
+ loader=None, dataset_id=None, nodes=None):
109
  """Forward-diffuse the given subgraph's edges at t = noise_level * T.
110
 
111
  For task="correct", only edges inside the inpaint mask (the second half of
112
  nodes) are noised, matching what the correction endpoint will regenerate.
113
  For task="generate", every edge slot is noised.
114
 
115
+ Returns a new list of edge dicts enriched with cleaned relation/entity
116
+ names when ``loader``/``dataset_id``/``nodes`` are supplied.
117
  """
118
  from graph_generation.src.utils import get_inpaint_mask
119
  from graph_generation.src.diffusion import diffusion_utils
 
153
  E_mixed = E_noised * inpaint_mask + E * (~inpaint_mask)
154
  E_int = E_mixed[0].argmax(dim=-1).cpu()
155
 
156
+ inv_relations = None
157
+ if loader is not None and dataset_id is not None:
158
+ _, _, inv_relations = loader.dataset.get_inverted_name_maps()
159
+
160
+ def node_name(idx):
161
+ if nodes is not None and 0 <= idx < len(nodes):
162
+ return nodes[idx].get("entity_name") or f"#{nodes[idx].get('entity_id', idx)}"
163
+ return f"#{idx}"
164
+
165
+ def relation_label(rid):
166
+ if inv_relations is None:
167
+ return None
168
+ raw = inv_relations.get(rid)
169
+ if raw is None or raw != raw or str(raw).strip() == "":
170
+ return f"rel#{rid}"
171
+ cleaned = clean_relation_name(str(raw), dataset_id)
172
+ return cleaned if cleaned else f"rel#{rid}"
173
+
174
  edges = []
175
  for i in range(n):
176
  for j in range(n):
 
179
  cls = int(E_int[i, j])
180
  if cls == 0:
181
  continue
182
+ rel_id = cls - 1
183
+ edge = {
184
+ "source_idx": i, "target_idx": j, "relation_id": rel_id,
185
+ }
186
+ if inv_relations is not None:
187
+ edge["relation_name"] = relation_label(rel_id)
188
+ edge["entity_name_source"] = node_name(i)
189
+ edge["entity_name_target"] = node_name(j)
190
+ edges.append(edge)
191
  return edges
192
 
193
 
 
195
  # Change detection
196
  # ---------------------------------------------------------------------------
197
 
198
+ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader, dataset_id):
199
  """Compute before/after edge diff for a directed KG subgraph.
200
 
201
  original_E_int / corrected_E_int: 2-D int tensors (n, n) where 0 = no edge
 
203
  """
204
  _, _, inv_relations = loader.dataset.get_inverted_name_maps()
205
 
206
+ def rel_name(idx):
207
+ raw = inv_relations.get(idx)
208
+ if raw is None or raw != raw or str(raw).strip() == "":
209
+ return f"rel#{idx}"
210
+ cleaned = clean_relation_name(str(raw), dataset_id)
211
+ return cleaned if cleaned else f"rel#{idx}"
212
+
213
  edges = []
214
  summary = {"added": 0, "removed": 0, "modified": 0, "unchanged": 0}
215
 
 
229
  edges.append({
230
  "source_idx": i, "target_idx": j, "change": "unchanged",
231
  "relation_id": c - 1,
232
+ "relation_name": rel_name(c - 1),
233
  })
234
  continue
235
  if o == 0 and c > 0:
 
237
  edges.append({
238
  "source_idx": i, "target_idx": j, "change": "added",
239
  "relation_id": c - 1,
240
+ "relation_name": rel_name(c - 1),
241
  })
242
  elif o > 0 and c == 0:
243
  summary["removed"] += 1
244
  edges.append({
245
  "source_idx": i, "target_idx": j, "change": "removed",
246
  "original_relation_id": o - 1,
247
+ "original_relation_name": rel_name(o - 1),
248
  })
249
  else:
250
  summary["modified"] += 1
251
  edges.append({
252
  "source_idx": i, "target_idx": j, "change": "modified",
253
  "original_relation_id": o - 1,
254
+ "original_relation_name": rel_name(o - 1),
255
  "relation_id": c - 1,
256
+ "relation_name": rel_name(c - 1),
257
  })
258
 
259
  return {"edges": edges, "summary": summary}
 
263
  # Rendering
264
  # ---------------------------------------------------------------------------
265
 
266
+ def _truncate_label(s, limit):
267
+ s = str(s)
268
+ return s if len(s) <= limit else s[: limit - 1] + ""
269
+
270
+
271
+ # Green-palette anchors matching the site's primary colour: pale mint -> vivid
272
+ # green -> deep forest. Used to colour nodes by the normalized-Laplacian
273
+ # eigenvector.
274
+ _GREEN_LOW = (212, 237, 218)
275
+ _GREEN_MID = (82, 180, 120)
276
+ _GREEN_HIGH = (22, 80, 50)
277
+
278
+
279
+ def _green_rgb(t):
280
+ """Map t in [-1, 1] to an (r, g, b) tuple on a three-anchor green gradient."""
281
+ t = max(-1.0, min(1.0, float(t)))
282
+ if t < 0:
283
+ w = t + 1.0 # 0 at -1, 1 at 0
284
+ a, b = _GREEN_LOW, _GREEN_MID
285
  else:
286
+ w = t # 0 at 0, 1 at 1
287
+ a, b = _GREEN_MID, _GREEN_HIGH
288
+ return (
289
+ int(a[0] + (b[0] - a[0]) * w),
290
+ int(a[1] + (b[1] - a[1]) * w),
291
+ int(a[2] + (b[2] - a[2]) * w),
292
+ )
293
+
294
+
295
+ def _quad_bezier(p0, p1, p2, steps=20):
296
+ out = []
297
+ for k in range(steps + 1):
298
+ t = k / steps
299
+ u = 1 - t
300
+ x = u * u * p0[0] + 2 * u * t * p1[0] + t * t * p2[0]
301
+ y = u * u * p0[1] + 2 * u * t * p1[1] + t * t * p2[1]
302
+ out.append((x, y))
303
+ return out
304
+
305
+
306
+ def _draw_arrowhead(draw, tip, direction, color):
307
+ ux, uy = direction
308
+ angle = math.atan2(uy, ux)
309
+ ah_len = 9
310
+ ah_angle = math.radians(25)
311
+ x, y = tip
312
+ x1 = x - ah_len * math.cos(angle - ah_angle)
313
+ y1 = y - ah_len * math.sin(angle - ah_angle)
314
+ x2 = x - ah_len * math.cos(angle + ah_angle)
315
+ y2 = y - ah_len * math.sin(angle + ah_angle)
316
+ draw.polygon([(x, y), (x1, y1), (x2, y2)], fill=color)
317
+
318
+
319
def _draw_curve(draw, p0, p2, color, width=2, dashed=False, rad=0.2):
    """Draw a quadratic Bézier from p0 to p2 curved by `rad` (fraction of chord length)."""
    mid_x = (p0[0] + p2[0]) / 2
    mid_y = (p0[1] + p2[1]) / 2
    dx, dy = p2[0] - p0[0], p2[1] - p0[1]
    # Control point sits perpendicular to the chord, offset by rad * chord length.
    perp_x, perp_y = -dy, dx
    perp_len = max(1.0, math.hypot(perp_x, perp_y))
    ctrl = (mid_x + perp_x / perp_len * rad * math.hypot(dx, dy),
            mid_y + perp_y / perp_len * rad * math.hypot(dx, dy))
    pts = _quad_bezier(p0, ctrl, p2, steps=24)
    if dashed:
        # Draw every other segment only, producing a dashed appearance.
        for k in range(0, len(pts) - 1, 2):
            draw.line([pts[k], pts[k + 1]], fill=color, width=width)
    else:
        draw.line(pts, fill=color, width=width, joint="curve")
    # Unit tangent at p2 (caller uses it to orient the arrowhead). Note the
    # max(1.0, ...) clamp mirrors the chord handling: sub-pixel segments are
    # not normalized to unit length.
    prev_x, prev_y = pts[-2]
    tx, ty = p2[0] - prev_x, p2[1] - prev_y
    tlen = max(1.0, math.hypot(tx, ty))
    return (tx / tlen, ty / tlen)
339
+
340
+
341
+ def _bipartite_layout(row_nodes, col_nodes):
342
+ """Two-column layout: row nodes on the left, col nodes on the right, evenly spaced."""
343
+ pos = {}
344
+
345
+ def place(nodes, x):
346
+ k = len(nodes)
347
+ for i, n in enumerate(nodes):
348
+ y = 1.0 - 2.0 * (i + 1) / (k + 1) # evenly spaced in [-1, 1], top-down
349
+ pos[n] = (x, y)
350
+
351
+ place(row_nodes, -1.0)
352
+ place(col_nodes, 1.0)
353
+ return pos
354
 
355
 
356
+ def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=None, is_bip=False):
357
  """Render a directed KG subgraph as a PIL image using networkx + PIL.
358
 
359
+ Ports the improvements from KnowledgeGraphVisualization.visualize_non_molecule
360
+ without importing matplotlib (which conflicts with torch on Windows): isolated
361
+ singleton filter, spring_layout(k=1), green-gradient node colouring from the
362
+ normalized-Laplacian eigenvector, and curved edges. When `changes` is
363
+ provided, per-edge CHANGE_COLORS override the default grey and "removed"
364
+ edges are drawn dashed. When `is_bip`, uses a two-column layout that
365
+ separates the row/col partitions and visually marks the inpaint quadrants.
366
  """
367
+ inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
 
 
 
368
 
369
  e = E_int.cpu().tolist()
370
  xi = X_index.cpu().tolist()
371
 
372
+ change_lookup = {}
373
+ if changes is not None:
374
+ for entry in changes.get("edges", []):
375
+ change_lookup[(entry["source_idx"], entry["target_idx"])] = entry["change"]
376
+
377
  G = nx.DiGraph()
378
  for i in range(num_nodes):
379
  G.add_node(i)
 
384
  if int(e[i][j]) > 0:
385
  G.add_edge(i, j, rel=int(e[i][j]) - 1)
386
 
 
 
 
 
387
  if changes is not None:
388
+ for (i, j), ct in change_lookup.items():
389
+ if ct == "removed" and not G.has_edge(i, j):
390
+ G.add_edge(i, j, rel=None)
391
 
392
+ # Bipartite: keep every node so the row/col structure stays visible.
393
+ # Community: drop isolated singletons so the spring layout focuses on structure.
394
+ if is_bip:
395
+ graph = G.copy()
396
+ else:
397
+ components = [G.subgraph(c).copy() for c in nx.connected_components(G.to_undirected())]
398
+ components = [c for c in components if c.number_of_nodes() > 1]
399
+ graph = nx.compose_all(components) if components else G
400
 
401
+ size = 520
402
  img = Image.new("RGB", (size, size), "white")
403
  draw = ImageDraw.Draw(img)
404
+
405
+ if graph.number_of_nodes() == 0:
406
+ return img
407
+
408
  try:
409
+ font = ImageFont.truetype("arial.ttf", 12)
 
410
  except (OSError, IOError):
411
  font = ImageFont.load_default()
 
412
 
413
+ if is_bip:
414
+ row_nodes = [n for n in graph.nodes() if n < num_nodes // 2]
415
+ col_nodes = [n for n in graph.nodes() if n >= num_nodes // 2]
416
+ pos = _bipartite_layout(row_nodes, col_nodes)
417
+ else:
418
+ pos = nx.spring_layout(graph, k=1, iterations=100, seed=42)
419
 
420
+ # Normalized Laplacian eigenvector for node colouring. Use torch.linalg.eigh
421
+ # rather than numpy.linalg.eigh on Windows, numpy's MKL DLLs conflict with
422
+ # torch's (Windows code 0xc06d007f), and torch is already healthy in-process.
423
+ try:
424
+ L = nx.normalized_laplacian_matrix(graph.to_undirected()).toarray()
425
+ L_t = torch.from_numpy(L).to(torch.float64)
426
+ _, U_t = torch.linalg.eigh(L_t)
427
+ U = U_t.numpy()
428
+ eigen_dim = 1 if U.shape[1] > 1 else 0
429
+ vec = U[:, eigen_dim]
430
+ m_abs = max(abs(vec.min()), abs(vec.max()), 1e-9)
431
+ vec_norm = vec / m_abs # now in [-1, 1]
432
+ except Exception:
433
+ logger.warning(
434
+ "eigenvector colouring failed; using flat colour:\n%s",
435
+ traceback.format_exc(),
436
+ )
437
+ vec_norm = np.zeros(graph.number_of_nodes())
438
+
439
+ node_list = list(graph.nodes())
440
+ node_color = {n: _green_rgb(vec_norm[k]) for k, n in enumerate(node_list)}
441
+
442
+ xs, ys = zip(*pos.values())
443
+ x_min, x_max = min(xs), max(xs)
444
+ y_min, y_max = min(ys), max(ys)
445
+ x_span = (x_max - x_min) or 1.0
446
+ y_span = (y_max - y_min) or 1.0
447
+ margin = 55
448
+ scale = (size - 2 * margin) / max(x_span, y_span)
449
+ cx = size / 2 - (x_min + x_span / 2) * scale
450
+ cy = size / 2 + (y_min + y_span / 2) * scale # flip y so "up" is up
451
+
452
+ def to_px(p):
453
+ return (cx + p[0] * scale, cy - p[1] * scale)
454
+
455
+ pixel_pos = {n: to_px(pos[n]) for n in graph.nodes()}
456
+
457
+ node_r = 12
458
+
459
+ # Edges first so nodes overlay them
460
+ for (i, j) in graph.edges():
461
+ ct = change_lookup.get((i, j))
462
+ dashed = (ct == "removed")
463
+ if ct is not None:
464
+ color = CHANGE_COLORS.get(ct, "#6b6b6b")
465
  else:
466
+ color = "#6b6b6b"
467
+ p0 = pixel_pos[i]
468
+ p2 = pixel_pos[j]
469
+ # Shorten endpoints so curve doesn't overlap node circles
470
+ dx, dy = p2[0] - p0[0], p2[1] - p0[1]
471
+ dist = max(1.0, math.hypot(dx, dy))
472
+ ux, uy = dx / dist, dy / dist
473
+ sx, sy = p0[0] + ux * node_r, p0[1] + uy * node_r
474
+ ex, ey = p2[0] - ux * node_r, p2[1] - uy * node_r
475
+ tangent = _draw_curve(draw, (sx, sy), (ex, ey), color,
476
+ width=2, dashed=dashed, rad=0.2)
477
+ _draw_arrowhead(draw, (ex, ey), tangent, color)
478
+
479
+ for n in graph.nodes():
480
+ x, y = pixel_pos[n]
481
+ r, g, b = node_color[n]
482
  draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
483
+ fill=(r, g, b), outline="#333333", width=1)
484
+ eid = int(xi[n]) if n < len(xi) else n
485
+ raw = inv_nodes.get(eid)
486
+ if raw is None or raw != raw or str(raw).strip() == "":
487
+ label = f"#{eid}"
488
+ else:
489
+ cleaned = clean_entity_name(str(raw), dataset_id)
490
+ label = cleaned if cleaned else f"#{eid}"
491
+ label = _truncate_label(label, 20)
492
+ _draw_text_with_bg(draw, (x + node_r + 3, y - 7), label, font, fill="#111111")
493
 
494
  return img
495
 
496
 
497
+ def _draw_text_with_bg(draw, xy, text, font, fill):
498
+ """Draw text with a semi-opaque white pad so labels stay readable over lines."""
499
+ try:
500
+ bbox = draw.textbbox(xy, text, font=font)
501
+ pad = 1
502
+ draw.rectangle(
503
+ [bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad],
504
+ fill="white",
505
+ )
506
+ except Exception:
507
+ pass
508
+ draw.text(xy, text, font=font, fill=fill)
509
 
510
 
511
def render_sample_subgraph_b64(subgraph, loader, dataset_id):
    """Render a sample subgraph dict (API payload shape) to a PNG data URI.

    Delegates to the same renderer used for inference outputs so thumbnails
    stay visually consistent with the before/after images produced later.
    Returns ``None`` for an empty subgraph; logs and re-raises on failure.
    """
    node_list = subgraph["nodes"]
    num = len(node_list)
    if num == 0:
        return None

    try:
        adjacency = torch.zeros(num, num, dtype=torch.long)
        for edge in subgraph["edges"]:
            src_i = int(edge["source_idx"])
            tgt_i = int(edge["target_idx"])
            rel_id = int(edge["relation_id"])
            if 0 <= src_i < num and 0 <= tgt_i < num:
                # Class 0 is reserved for "no edge", hence the +1 shift.
                adjacency[src_i, tgt_i] = rel_id + 1

        entity_ids = torch.tensor(
            [int(nd["entity_id"]) for nd in node_list], dtype=torch.long)
        bipartite = bool(subgraph.get("is_bip", False))
        image = render_kg_subgraph(
            adjacency, num, entity_ids, dataset_id, loader,
            changes=None, is_bip=bipartite)
        return _pil_to_b64(image)
    except Exception:
        logger.error(
            "render_sample_subgraph_b64 failed: dataset=%s n=%d\n%s",
            dataset_id, num, traceback.format_exc(),
        )
        raise
543
 
544
 
545
  # ---------------------------------------------------------------------------
 
577
  E_given = tensors["E_given"].to(device)
578
  y_given = tensors["y_given"].to(device)
579
  X_index = tensors["X_index"].to(device)
580
+ X_c = tensors["X_c"].to(device)
581
  is_bip = tensors["is_bip"].to(device)
582
  n_nodes = tensors["n_nodes"].to(device)
583
  node_mask = tensors["node_mask"].to(device)
584
  n_max = n_nodes.item()
585
+ kg_experiment = getattr(model, "kg_experiment", None)
586
 
587
  inpaint_mask = _build_inpaint_mask(
588
  task, node_mask, is_bip, model.Edim_output, device)
589
+ is_bip_bool = bool(is_bip.item())
590
 
591
  original_E_int = E_given[0].argmax(dim=-1).long() # (n, n)
592
  original_img = render_kg_subgraph(
593
+ original_E_int, n_max, X_index[0], dataset_id, loader,
594
+ changes=None, is_bip=is_bip_bool)
595
 
596
  model_T = model.T
597
  step_stride = max(1, model_T // diffusion_steps)
 
626
  }
627
  if is_frame:
628
  frame = render_kg_subgraph(
629
+ E_int_prev, n_max, X_index[0], dataset_id, loader,
630
+ is_bip=is_bip_bool)
631
  gif_frames.append(frame)
632
  event["preview"] = _pil_to_b64(frame)
633
+ if kg_experiment is not None:
634
+ ll = kg_edge_log_likelihood(
635
+ E_int_prev, X_given[0], X_index[0], X_c[0], kg_experiment)
636
+ if ll is not None:
637
+ event["kg_log_likelihood"] = ll
638
+ event["kg_log_likelihood_step"] = emitted
639
+ logger.info(
640
+ "[kg-anomaly] denoise step=%d/%d kg_log_lik=%.4f",
641
+ emitted, total_loop_steps, ll)
642
  yield event
643
 
644
  X_final, E_final = _collapse_final_kg(model, X, E, y, node_mask)
645
 
646
  corrected_E_int = E_final[0]
647
+ changes = compute_changes(original_E_int, corrected_E_int, n_max, loader, dataset_id)
648
  corrected_img = render_kg_subgraph(
649
+ corrected_E_int, n_max, X_index[0], dataset_id, loader,
650
+ changes=changes, is_bip=is_bip_bool)
651
 
652
  elapsed_ms = int((time.time() - t0) * 1000)
653
  yield {
 
679
 
680
  inpaint_mask = _build_inpaint_mask(
681
  task, node_mask, is_bip, model.Edim_output, device)
682
+ is_bip_bool = bool(is_bip.item())
683
  original_E_int = E_given[0].argmax(dim=-1).long()
684
  original_img = render_kg_subgraph(
685
+ original_E_int, n_max, X_index[0], dataset_id, loader,
686
+ changes=None, is_bip=is_bip_bool)
687
 
688
  t0 = time.time()
689
  # Sample initial noise for each of M Gibbs chains
 
712
  agg_y = torch.median(y_ens.float(), dim=1).values
713
  X_int, E_int = _collapse_final_kg(model, X_given, agg_E, agg_y, node_mask)
714
  corrected_E_int = E_int[0]
715
+ changes = compute_changes(original_E_int, corrected_E_int, n_max, loader, dataset_id)
716
  preview_img = render_kg_subgraph(
717
+ corrected_E_int, n_max, X_index[0], dataset_id, loader,
718
+ changes=changes, is_bip=is_bip_bool)
719
  elapsed_ms = int((time.time() - t0) * 1000)
720
 
721
  state = {
 
724
  "y": y_ens.cpu(),
725
  "n_nodes": n_nodes.cpu(),
726
  "dataset_id": dataset_id,
727
+ "is_bip": bool(is_bip_bool),
728
  "task": task,
729
  "X_index": X_index.cpu(),
730
  "X_c": X_c.cpu(),
731
  "is_bip": is_bip.cpu(),
732
  "original_E_int": original_E_int.cpu(),
733
+ "prev_E_int": corrected_E_int.cpu(),
734
  "T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
735
  "gibbs_chain_freq": gibbs_chain_freq,
736
  "inner_step": 0, "step": 0,
 
753
  E = state["E"].to(device)
754
  y = state["y"].to(device)
755
  X_index = state["X_index"].to(device)
756
+ X_c = state["X_c"].to(device)
757
  is_bip = state["is_bip"].to(device)
758
  n_nodes = state["n_nodes"].to(device)
759
  original_E_int = state["original_E_int"].to(device)
760
+ prev_E_int = state.get("prev_E_int", state["original_E_int"]).to(device)
761
+ kg_experiment = getattr(model, "kg_experiment", None)
762
 
763
  T = state["T"]
764
  n = state["n"]
 
772
  n_max = int(n_nodes.item())
773
  node_mask = torch.ones(1, n_max, dtype=torch.bool, device=device)
774
  inpaint_mask = _build_inpaint_mask(task, node_mask, is_bip, model.Edim_output, device)
775
+ is_bip_bool = bool(is_bip.item())
776
 
777
  fixed_t_norm = t * torch.ones((1, 1), dtype=torch.float, device=device)
778
  fixed_s_norm = fixed_t_norm - (1.0 / T)
 
803
  prev_y = torch.median(y.float(), dim=1).values
804
  _, prev_Ei = _collapse_final_kg(model, X_given, prev_E, prev_y, node_mask)
805
  preview_img = render_kg_subgraph(
806
+ prev_Ei[0], n_max, X_index[0], dataset_id, loader,
807
+ is_bip=is_bip_bool)
808
+ event = {
809
  "type": "progress",
810
  "phase": "gibbs",
811
  "step": i + 1,
 
813
  "elapsed_ms": int((time.time() - t0) * 1000),
814
  "preview": _pil_to_b64(preview_img),
815
  }
816
+ if kg_experiment is not None:
817
+ ll = kg_edge_log_likelihood(
818
+ prev_Ei[0], X_given[0], X_index[0], X_c[0], kg_experiment)
819
+ if ll is not None:
820
+ event["kg_log_likelihood"] = ll
821
+ event["kg_log_likelihood_step"] = i + 1
822
+ logger.info(
823
+ "[kg-anomaly] gibbs step=%d/%d kg_log_lik=%.4f",
824
+ i + 1, steps_this_call, ll)
825
+ yield event
826
 
827
  new_inner_step = inner_step + steps_this_call
828
  round_complete = new_inner_step >= m
 
855
  "elapsed_ms": int((time.time() - t0) * 1000),
856
  }
857
  if is_frame:
858
+ refine_E_int = discrete_s.E[0].long()
859
  event["preview"] = _pil_to_b64(render_kg_subgraph(
860
+ refine_E_int, n_max, X_index[0], dataset_id, loader,
861
+ is_bip=is_bip_bool))
862
+ if kg_experiment is not None:
863
+ ll = kg_edge_log_likelihood(
864
+ refine_E_int, X_given[0], X_index[0], X_c[0], kg_experiment)
865
+ if ll is not None:
866
+ event["kg_log_likelihood"] = ll
867
+ event["kg_log_likelihood_step"] = j + 1
868
+ logger.info(
869
+ "[kg-anomaly] refine step=%d/%d kg_log_lik=%.4f",
870
+ j + 1, P, ll)
871
  yield event
872
 
873
  X_int, E_int = _collapse_final_kg(model, cur_X, cur_E, cur_y, node_mask)
874
 
875
  corrected_E_int = E_int[0]
876
+ changes = compute_changes(prev_E_int, corrected_E_int, n_max, loader, dataset_id)
877
  corrected_img = render_kg_subgraph(
878
+ corrected_E_int, n_max, X_index[0], dataset_id, loader,
879
+ changes=changes, is_bip=is_bip_bool)
880
  elapsed_ms = int((time.time() - t0) * 1000)
881
 
882
  updated_state = {
883
  **state,
884
  "E": E.cpu(), "y": y.cpu(),
885
+ "prev_E_int": corrected_E_int.cpu(),
886
  "step": new_step, "inner_step": new_inner_step,
887
  }
888
  yield {
src/backend/api/services/kg_likelihood.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Per-step KG link-prediction log-likelihood for the denoising loop.
2
+
3
+ Wraps the math of `KGLikelihoodMetric.update` (see research/COINs-KGGeneration
4
+ .../metrics/abstract_metrics.py) in a one-shot, stateless helper. We query
5
+ the frozen KG embedder + link ranker on the edges currently present in the
6
+ argmax reconstruction and return their mean log-sigmoid score — a positive
7
+ higher-is-better value that rises as the graph becomes cleaner.
8
+ """
9
+
10
+ import logging
11
+
12
+ import torch
13
+ from torch.nn.functional import logsigmoid, one_hot
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
def kg_edge_log_likelihood(E_int, X, X_index, X_c, kg_experiment):
    """Mean log-sigmoid link-ranker score over the edges currently present.

    Args:
        E_int: (n, n) long tensor. 0 = no edge; otherwise class = relation_id + 1.
        X: (n, num_node_types) one-hot node types (unbatched, float).
        X_index: (n,) long dataset-global entity ids (unbatched).
        X_c: (n,) long community ids (unbatched).
        kg_experiment: COINs experiment exposing .embedder, .link_ranker,
            .loader.num_relations, .mini_batch_size, .device.

    Returns:
        A Python float — the per-edge mean log-sigmoid score. Log-sigmoid is
        always <= 0; the value is higher (closer to 0) the cleaner the graph.
        Returns None when no off-diagonal edges are present or when the scoring
        pass fails for any reason, including when the COINs research package
        is unavailable.
    """
    try:
        # Lazy imports *inside* the guard: a missing or broken research
        # package degrades to "no metric" (None) instead of raising into the
        # streaming loop — matching the "fails for any reason" contract.
        from graph_completion.graphs.preprocess import QueryData
        from graph_completion.graphs.queries import Query

        embedder = kg_experiment.embedder
        link_ranker = kg_experiment.link_ranker.link_ranker
        num_relations = kg_experiment.loader.num_relations
        kg_device = kg_experiment.device
        mini_batch_size = kg_experiment.mini_batch_size

        # Off-diagonal edges only: self-loops are reconstruction artifacts
        # with no link-prediction meaning.
        nz = E_int.nonzero(as_tuple=False)
        if nz.numel() == 0:
            return None
        nz = nz[nz[:, 0] != nz[:, 1]]
        if nz.numel() == 0:
            return None
        s, t = nz[:, 0], nz[:, 1]
        r = E_int[s, t] - 1  # undo the "relation_id + 1" edge-class shift

        e_s, e_t = X_index[s].long(), X_index[t].long()
        x_s, x_t = X[s].float(), X[t].float()
        c_s, c_t = X_c[s].long(), X_c[t].long()

        # Group edges so identical community pairs are contiguous — the
        # embedder batches by community pair. Sorting by c_s and then
        # stable-sorting by c_t yields (c_t, c_s)-primary order, which groups
        # equal pairs just as well as a (c_s, c_t) lexicographic sort would.
        s_sort = torch.argsort(c_s)
        t_sort = torch.sort(c_t[s_sort], stable=True).indices

        def pick(v):
            return v[s_sort][t_sort]

        e = [pick(e_s), pick(e_t)]
        x = [pick(x_s), pick(x_t)]
        c = [pick(c_s), pick(c_t)]
        r = pick(r)
        edge_attr = [one_hot(r, num_relations + 1).float()]

        q = Query("1p")
        q.build_query_tree()
        query_data = QueryData(q, e=e, x=x, c=c, edge_attr=edge_attr).to(kg_device)

        scores = []
        with torch.no_grad():
            for qd_batch in query_data.batch_split(mini_batch_size):
                q_emb, a_emb = embedder(qd_batch)
                scores.append(link_ranker(q_emb, a_emb))
        scores = torch.cat(scores, dim=0).view(-1)
        if scores.numel() == 0:
            return None
        return float(logsigmoid(scores).mean().item())
    except Exception as exc:
        logging.getLogger(__name__).warning("[kg-likelihood] skipped: %s", exc)
        return None
src/backend/api/services/registry.py CHANGED
@@ -706,28 +706,73 @@ class ModelRegistry:
706
  except Exception:
707
  logger.exception("Failed to generate sample subgraphs for %s", dataset_id)
708
 
709
- def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=20, max_graph_size=10):
710
- """Build sample subgraphs using the Sampler's DFS-based context subgraph partitioning."""
 
 
 
 
 
 
711
  inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
712
  node_types = loader.dataset.node_data.type.values
713
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
  # Use the Sampler's DFS partitioning to get context subgraphs
715
  samples = loader.sampler.get_context_subgraph_samples_dfs(
716
  max_graph_size, loader.graph_indexes, loader.num_nodes,
717
- max_samples=num_subgraphs * 5, disable_tqdm=True,
718
  )
719
 
 
 
 
 
 
 
 
 
 
 
 
720
  subgraphs = []
721
  for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
722
  if len(subgraphs) >= num_subgraphs:
723
  break
 
 
724
  if len(edges) < 3:
725
  continue
726
 
727
- if subgraph_row == subgraph_col:
728
- sg_nodes = nodes_row
729
- else:
 
 
 
 
 
 
730
  sg_nodes = nodes_row + nodes_col
 
 
 
 
 
 
731
 
732
  node_idx = {n: i for i, n in enumerate(sg_nodes)}
733
 
@@ -736,7 +781,7 @@ class ModelRegistry:
736
  type_id = int(node_types[n]) if n < len(node_types) else 0
737
  nodes.append({
738
  "entity_id": n,
739
- "entity_name": str(inv_nodes.get(n, n)),
740
  "type_id": type_id,
741
  })
742
 
@@ -747,18 +792,23 @@ class ModelRegistry:
747
  "source_idx": node_idx[h],
748
  "target_idx": node_idx[t],
749
  "relation_id": r,
750
- "relation_name": str(inv_relations.get(r, r)),
751
- "entity_name_source": str(inv_nodes.get(h, h)),
752
- "entity_name_target": str(inv_nodes.get(t, t)),
753
  })
754
 
755
  subgraphs.append({
756
  "id": f"sample_{len(subgraphs) + 1}",
757
  "num_nodes": len(nodes),
758
  "num_edges": len(edge_list),
 
 
759
  "nodes": nodes,
760
  "edges": edge_list,
761
  })
 
 
 
762
 
763
  # Free the partitioning data stored on the sampler
764
  loader.sampler.context_subgraphs_nodes = None
 
706
  except Exception:
707
  logger.exception("Failed to generate sample subgraphs for %s", dataset_id)
708
 
709
+ def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=40,
710
+ max_graph_size=10, seed=None):
711
+ """Build sample subgraphs using the Sampler's DFS-based context subgraph partitioning.
712
+
713
+ When ``seed`` is provided, the DFS iterates node indices in a shuffled order, so
714
+ different seeds produce different partitions. Without a seed the order is
715
+ deterministic (original research-code behaviour).
716
+ """
717
  inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
718
  node_types = loader.dataset.node_data.type.values
719
 
720
+ def entity_label(idx):
721
+ raw = inv_nodes.get(idx)
722
+ if raw is None or raw != raw or str(raw).strip() == "": # None / NaN / empty
723
+ return f"#{idx}"
724
+ cleaned = clean_entity_name(str(raw), dataset_id)
725
+ return cleaned if cleaned else f"#{idx}"
726
+
727
+ def relation_label(idx):
728
+ raw = inv_relations.get(idx)
729
+ if raw is None or raw != raw or str(raw).strip() == "":
730
+ return f"rel#{idx}"
731
+ cleaned = clean_relation_name(str(raw), dataset_id)
732
+ return cleaned if cleaned else f"rel#{idx}"
733
+
734
  # Use the Sampler's DFS partitioning to get context subgraphs
735
  samples = loader.sampler.get_context_subgraph_samples_dfs(
736
  max_graph_size, loader.graph_indexes, loader.num_nodes,
737
+ max_samples=num_subgraphs * 5, seed=seed, disable_tqdm=True,
738
  )
739
 
740
+ # Randomly pick (row, col) partition pairs so each sample is structurally
741
+ # distinct from the others. Without shuffling, the DFS returns samples in
742
+ # nested (k, l) order, which means the first N samples all reuse
743
+ # partition 0's nodes. Shuffling + a disjoint-partitions guard gives 5
744
+ # different subgraphs each call.
745
+ import random as _random
746
+ rng = _random.Random(seed)
747
+ samples = list(samples)
748
+ rng.shuffle(samples)
749
+
750
+ used_partitions = set()
751
  subgraphs = []
752
  for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
753
  if len(subgraphs) >= num_subgraphs:
754
  break
755
+ if subgraph_row in used_partitions or subgraph_col in used_partitions:
756
+ continue
757
  if len(edges) < 3:
758
  continue
759
 
760
+ is_bip = (subgraph_row != subgraph_col)
761
+ if is_bip:
762
+ # Inpaint mask math assumes balanced halves (n/4, n/2, 3n/4 split).
763
+ # Only accept bipartite samples where row/col are the same size and
764
+ # divisible by 4, so the four quadrants are well-defined.
765
+ if len(nodes_row) != len(nodes_col) or len(nodes_row) < 2:
766
+ continue
767
+ if (2 * len(nodes_row)) % 4 != 0:
768
+ continue
769
  sg_nodes = nodes_row + nodes_col
770
+ row_size = len(nodes_row)
771
+ else:
772
+ if len(nodes_row) < 4 or len(nodes_row) % 2 != 0:
773
+ continue
774
+ sg_nodes = nodes_row
775
+ row_size = len(nodes_row)
776
 
777
  node_idx = {n: i for i, n in enumerate(sg_nodes)}
778
 
 
781
  type_id = int(node_types[n]) if n < len(node_types) else 0
782
  nodes.append({
783
  "entity_id": n,
784
+ "entity_name": entity_label(n),
785
  "type_id": type_id,
786
  })
787
 
 
792
  "source_idx": node_idx[h],
793
  "target_idx": node_idx[t],
794
  "relation_id": r,
795
+ "relation_name": relation_label(r),
796
+ "entity_name_source": entity_label(h),
797
+ "entity_name_target": entity_label(t),
798
  })
799
 
800
  subgraphs.append({
801
  "id": f"sample_{len(subgraphs) + 1}",
802
  "num_nodes": len(nodes),
803
  "num_edges": len(edge_list),
804
+ "is_bip": is_bip,
805
+ "row_size": row_size,
806
  "nodes": nodes,
807
  "edges": edge_list,
808
  })
809
+ used_partitions.add(subgraph_row)
810
+ if is_bip:
811
+ used_partitions.add(subgraph_col)
812
 
813
  # Free the partitioning data stored on the sampler
814
  loader.sampler.context_subgraphs_nodes = None
src/backend/api/views/kg_anomaly.py CHANGED
@@ -1,12 +1,22 @@
 
 
 
 
 
1
  from rest_framework.response import Response
2
  from rest_framework.views import APIView
3
 
4
  from api.exceptions import InvalidRequestError, ModelUnavailable, NotFoundError
 
5
  from api.services.constants import KG_ANOMALY_DATASET_META
6
- from api.services.kg_anomaly_inference import apply_edge_noise, build_kg_tensors
 
 
7
  from api.services.registry import ModelRegistry
8
  from api.views.graph_generation import _streaming_sse_response
9
 
 
 
10
 
11
  class KgAnomalyDatasetsView(APIView):
12
  def get(self, request):
@@ -29,14 +39,29 @@ class KgAnomalySampleSubgraphsView(APIView):
29
  raise NotFoundError(f"Dataset '{dataset_id}' not found")
30
 
31
  registry = ModelRegistry.get()
32
- sg_info = registry.kg_anomaly_subgraphs.get(dataset_id)
33
- if sg_info is None:
34
  raise NotFoundError(f"No sample subgraphs available for dataset '{dataset_id}'")
35
 
36
  count = int(request.query_params.get("count", 5))
37
  count = max(1, min(10, count))
38
 
39
- subgraphs = [dict(sg) for sg in sg_info.subgraphs[:count]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  noise_level_raw = request.query_params.get("noise_level")
42
  if noise_level_raw is not None:
@@ -56,16 +81,39 @@ class KgAnomalySampleSubgraphsView(APIView):
56
  raise ModelUnavailable(
57
  f"No '{task}' checkpoint available for dataset '{dataset_id}'")
58
 
59
- seed_raw = request.query_params.get("seed")
60
- seed = int(seed_raw) if seed_raw is not None else None
61
-
62
- loader = registry.loaders[dataset_id]
63
- model = registry._load_kg_anomaly_model(dataset_id, task)
 
 
64
 
65
  for i, sg in enumerate(subgraphs):
66
- offset_seed = None if seed is None else seed + i
67
- tensors = build_kg_tensors(sg, loader, model)
68
- sg["edges"] = apply_edge_noise(model, tensors, task, noise_level, offset_seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  return Response({
71
  "dataset_id": dataset_id,
@@ -99,6 +147,8 @@ def _validate_subgraph(subgraph):
99
 
100
 
101
  class KgAnomalyCorrectView(APIView):
 
 
102
  def post(self, request):
103
  data = request.data
104
  registry = ModelRegistry.get()
@@ -166,6 +216,8 @@ class KgAnomalyCorrectView(APIView):
166
 
167
 
168
  class KgAnomalyContinueView(APIView):
 
 
169
  def post(self, request):
170
  state_b64 = request.data.get("state")
171
  if not state_b64 or not isinstance(state_b64, str):
 
1
+ import logging
2
+ import random
3
+ import traceback
4
+
5
+ from rest_framework.renderers import JSONRenderer
6
  from rest_framework.response import Response
7
  from rest_framework.views import APIView
8
 
9
  from api.exceptions import InvalidRequestError, ModelUnavailable, NotFoundError
10
+ from api.renderers import EventStreamRenderer
11
  from api.services.constants import KG_ANOMALY_DATASET_META
12
+ from api.services.kg_anomaly_inference import (
13
+ apply_edge_noise, build_kg_tensors, render_sample_subgraph_b64,
14
+ )
15
  from api.services.registry import ModelRegistry
16
  from api.views.graph_generation import _streaming_sse_response
17
 
18
+ logger = logging.getLogger(__name__)
19
+
20
 
21
  class KgAnomalyDatasetsView(APIView):
22
  def get(self, request):
 
39
  raise NotFoundError(f"Dataset '{dataset_id}' not found")
40
 
41
  registry = ModelRegistry.get()
42
+ loader = registry.loaders.get(dataset_id)
43
+ if loader is None:
44
  raise NotFoundError(f"No sample subgraphs available for dataset '{dataset_id}'")
45
 
46
  count = int(request.query_params.get("count", 5))
47
  count = max(1, min(10, count))
48
 
49
+ seed_raw = request.query_params.get("seed")
50
+ # Random per-request seed when the caller doesn't pin one, so each call
51
+ # produces a different DFS node-order shuffle and therefore different partitions.
52
+ seed = int(seed_raw) if seed_raw is not None else random.randrange(2**31)
53
+
54
+ # Fresh DFS partitioning per request — see Sampler.get_context_subgraph_samples_dfs
55
+ # in research/COINs-KGGeneration. The registry shuffles sample pairs and
56
+ # enforces disjoint (row, col) partitions per sample, so the returned
57
+ # subgraphs are all structurally distinct.
58
+ logger.info("[sample-subgraphs] building fresh pool for %s (seed=%d)", dataset_id, seed)
59
+ pool = registry._build_sample_subgraphs(
60
+ dataset_id, loader, num_subgraphs=count, seed=seed,
61
+ )
62
+ subgraphs = [dict(sg) for sg in pool[:count]]
63
+ for i, sg in enumerate(subgraphs):
64
+ sg["id"] = f"sample_{i + 1}"
65
 
66
  noise_level_raw = request.query_params.get("noise_level")
67
  if noise_level_raw is not None:
 
81
  raise ModelUnavailable(
82
  f"No '{task}' checkpoint available for dataset '{dataset_id}'")
83
 
84
+ logger.info("[sample-subgraphs] loading kg-anomaly model: %s/%s", dataset_id, task)
85
+ try:
86
+ model = registry._load_kg_anomaly_model(dataset_id, task)
87
+ except Exception:
88
+ logger.error("[sample-subgraphs] model load failed:\n%s", traceback.format_exc())
89
+ raise
90
+ logger.info("[sample-subgraphs] model ready, noising %d subgraphs", len(subgraphs))
91
 
92
  for i, sg in enumerate(subgraphs):
93
+ try:
94
+ tensors = build_kg_tensors(sg, loader, model)
95
+ sg["edges"] = apply_edge_noise(
96
+ model, tensors, task, noise_level, seed + i,
97
+ loader=loader, dataset_id=dataset_id, nodes=sg["nodes"])
98
+ sg["num_edges"] = len(sg["edges"])
99
+ logger.info("[sample-subgraphs] noised subgraph %d/%d", i + 1, len(subgraphs))
100
+ except Exception:
101
+ logger.error(
102
+ "[sample-subgraphs] noise failed on subgraph %d:\n%s",
103
+ i, traceback.format_exc(),
104
+ )
105
+ raise
106
+
107
+ for i, sg in enumerate(subgraphs):
108
+ try:
109
+ sg["image"] = render_sample_subgraph_b64(sg, loader, dataset_id)
110
+ logger.info("[sample-subgraphs] rendered subgraph %d/%d", i + 1, len(subgraphs))
111
+ except Exception:
112
+ logger.error(
113
+ "[sample-subgraphs] render failed on subgraph %d:\n%s",
114
+ i, traceback.format_exc(),
115
+ )
116
+ raise
117
 
118
  return Response({
119
  "dataset_id": dataset_id,
 
147
 
148
 
149
  class KgAnomalyCorrectView(APIView):
150
+ renderer_classes = [EventStreamRenderer, JSONRenderer]
151
+
152
  def post(self, request):
153
  data = request.data
154
  registry = ModelRegistry.get()
 
216
 
217
 
218
  class KgAnomalyContinueView(APIView):
219
+ renderer_classes = [EventStreamRenderer, JSONRenderer]
220
+
221
  def post(self, request):
222
  state_b64 = request.data.get("state")
223
  if not state_b64 or not isinstance(state_b64, str):
src/research/COINs-KGGeneration/graph_completion/graphs/preprocess.py CHANGED
@@ -742,6 +742,7 @@ class Sampler:
742
  def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
743
  num_nodes: int, allow_disc: bool = False,
744
  max_samples: int = 0,
 
745
  disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
746
  _, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
747
  assignment = -np.ones(num_nodes, dtype=int)
@@ -753,7 +754,12 @@ class Sampler:
753
  progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
754
  disable=disable_tqdm)
755
 
756
- for i in range(num_nodes):
 
 
 
 
 
757
  if max_subgraphs > 0 and subgraph >= max_subgraphs:
758
  break
759
  if assignment[i] >= 0:
 
742
  def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
743
  num_nodes: int, allow_disc: bool = False,
744
  max_samples: int = 0,
745
+ seed: Optional[int] = None,
746
  disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
747
  _, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
748
  assignment = -np.ones(num_nodes, dtype=int)
 
754
  progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
755
  disable=disable_tqdm)
756
 
757
+ # When seed is given, iterate in a shuffled node order so different seeds produce
758
+ # different partitions. Default order reproduces the deterministic behaviour.
759
+ node_order = (np.random.default_rng(seed).permutation(num_nodes)
760
+ if seed is not None else range(num_nodes))
761
+
762
+ for i in node_order:
763
  if max_subgraphs > 0 and subgraph >= max_subgraphs:
764
  break
765
  if assignment[i] >= 0: