feat(kganomaly): add streaming denoising backend with KG-likelihood metric
Browse files
- New endpoints emit Server-Sent Events via api/renderers.py so the
diffusion reverse process streams progress + frame previews instead of
blocking on a single response.
- kg_likelihood.py exposes a per-step mean log-sigmoid score from the
frozen KG embedder + link ranker; kg_anomaly_inference.py logs it on
frame boundaries and attaches kg_log_likelihood / kg_log_likelihood_step
to progress events so the UI can render a "denoising-is-working"
trace alongside step-duration sparklines.
- registry._build_sample_subgraphs now accepts a seed (so each request
gets a different DFS partition), shuffles candidate (row, col) pairs,
and rejects ill-shaped bipartite samples whose halves can't form valid
inpaint quadrants. Sampler.get_context_subgraph_samples_dfs gains the
matching seed parameter in the research code.
- api.yaml + backend README document the new optional progress fields.
- docs/api.yaml +11 -0
- src/backend/README.md +5 -0
- src/backend/api/renderers.py +18 -0
- src/backend/api/services/kg_anomaly_inference.py +357 -138
- src/backend/api/services/kg_likelihood.py +79 -0
- src/backend/api/services/registry.py +60 -10
- src/backend/api/views/kg_anomaly.py +64 -12
- src/research/COINs-KGGeneration/graph_completion/graphs/preprocess.py +7 -1
|
@@ -1806,3 +1806,14 @@ components:
|
|
| 1806 |
total:
|
| 1807 |
type: integer
|
| 1808 |
description: Total steps in the stage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1806 |
total:
|
| 1807 |
type: integer
|
| 1808 |
description: Total steps in the stage
|
| 1809 |
+
kg_log_likelihood:
|
| 1810 |
+
type: number
|
| 1811 |
+
nullable: true
|
| 1812 |
+
description: >
|
| 1813 |
+
Mean log-sigmoid score from the frozen KG embedder + link ranker
|
| 1814 |
+
applied to the edges currently present in the argmax reconstruction.
|
| 1815 |
+
Higher = cleaner. Present only on frame-boundary events.
|
| 1816 |
+
kg_log_likelihood_step:
|
| 1817 |
+
type: integer
|
| 1818 |
+
nullable: true
|
| 1819 |
+
description: Step index that `kg_log_likelihood` corresponds to.
|
|
@@ -111,6 +111,11 @@ event: progress
|
|
| 111 |
data: {"type":"progress","phase":"denoise","step":42,"total_steps":500,"elapsed_ms":2100}
|
| 112 |
```
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
**`event: preview`** — base64 PNG of the graph's current state, emitted at key frames:
|
| 115 |
```
|
| 116 |
event: preview
|
|
|
|
| 111 |
data: {"type":"progress","phase":"denoise","step":42,"total_steps":500,"elapsed_ms":2100}
|
| 112 |
```
|
| 113 |
|
| 114 |
+
KG-anomaly progress events additionally carry an optional `kg_log_likelihood`
|
| 115 |
+
(float) + `kg_log_likelihood_step` (int) on frame boundaries — the mean
|
| 116 |
+
log-sigmoid score from the frozen KG embedder + link ranker on the edges
|
| 117 |
+
currently present in the argmax reconstruction. Higher = cleaner.
|
| 118 |
+
|
| 119 |
**`event: preview`** — base64 PNG of the graph's current state, emitted at key frames:
|
| 120 |
```
|
| 121 |
event: preview
|
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rest_framework.renderers import BaseRenderer
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class EventStreamRenderer(BaseRenderer):
    """Placeholder renderer that advertises ``text/event-stream`` to DRF.

    Its only job is to let DRF's Accept-header content negotiation succeed
    for Server-Sent-Events clients. The streaming views hand back a
    StreamingHttpResponse themselves, so ``render`` is not expected to be
    called to produce the actual payload.
    """

    # Negotiation metadata read by DRF.
    media_type = "text/event-stream"
    format = "sse"

    # No charset is declared and the payload is flagged as binary.
    charset = None
    render_style = "binary"

    def render(self, data, accepted_media_type=None, renderer_context=None):
        """Return an empty byte string; the real body is streamed by the view."""
        return b""
|
|
@@ -1,13 +1,27 @@
|
|
| 1 |
import base64
|
|
|
|
| 2 |
import io
|
|
|
|
|
|
|
|
|
|
| 3 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
from api.services.graphgen_inference import (
|
| 9 |
_frames_to_gif_b64, _pil_to_b64,
|
| 10 |
)
|
|
|
|
|
|
|
| 11 |
|
| 12 |
STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 13 |
REQUIRED_STATE_KEYS = {
|
|
@@ -76,7 +90,7 @@ def build_kg_tensors(subgraph, loader, model):
|
|
| 76 |
X_c[0, i] = int(communities[eid])
|
| 77 |
|
| 78 |
n_nodes = torch.tensor([n], dtype=torch.long)
|
| 79 |
-
is_bip = torch.tensor([
|
| 80 |
node_mask = torch.ones(1, n, dtype=torch.bool)
|
| 81 |
|
| 82 |
return {
|
|
@@ -90,14 +104,16 @@ def _to_device(t, device):
|
|
| 90 |
return t.to(device) if isinstance(t, torch.Tensor) else t
|
| 91 |
|
| 92 |
|
| 93 |
-
def apply_edge_noise(model, tensors, task, noise_level, seed=None
|
|
|
|
| 94 |
"""Forward-diffuse the given subgraph's edges at t = noise_level * T.
|
| 95 |
|
| 96 |
For task="correct", only edges inside the inpaint mask (the second half of
|
| 97 |
nodes) are noised, matching what the correction endpoint will regenerate.
|
| 98 |
For task="generate", every edge slot is noised.
|
| 99 |
|
| 100 |
-
Returns a new list of
|
|
|
|
| 101 |
"""
|
| 102 |
from graph_generation.src.utils import get_inpaint_mask
|
| 103 |
from graph_generation.src.diffusion import diffusion_utils
|
|
@@ -137,6 +153,24 @@ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
|
|
| 137 |
E_mixed = E_noised * inpaint_mask + E * (~inpaint_mask)
|
| 138 |
E_int = E_mixed[0].argmax(dim=-1).cpu()
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
edges = []
|
| 141 |
for i in range(n):
|
| 142 |
for j in range(n):
|
|
@@ -145,9 +179,15 @@ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
|
|
| 145 |
cls = int(E_int[i, j])
|
| 146 |
if cls == 0:
|
| 147 |
continue
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
return edges
|
| 152 |
|
| 153 |
|
|
@@ -155,7 +195,7 @@ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
|
|
| 155 |
# Change detection
|
| 156 |
# ---------------------------------------------------------------------------
|
| 157 |
|
| 158 |
-
def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
|
| 159 |
"""Compute before/after edge diff for a directed KG subgraph.
|
| 160 |
|
| 161 |
original_E_int / corrected_E_int: 2-D int tensors (n, n) where 0 = no edge
|
|
@@ -163,6 +203,13 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
|
|
| 163 |
"""
|
| 164 |
_, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
edges = []
|
| 167 |
summary = {"added": 0, "removed": 0, "modified": 0, "unchanged": 0}
|
| 168 |
|
|
@@ -182,7 +229,7 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
|
|
| 182 |
edges.append({
|
| 183 |
"source_idx": i, "target_idx": j, "change": "unchanged",
|
| 184 |
"relation_id": c - 1,
|
| 185 |
-
"relation_name":
|
| 186 |
})
|
| 187 |
continue
|
| 188 |
if o == 0 and c > 0:
|
|
@@ -190,23 +237,23 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
|
|
| 190 |
edges.append({
|
| 191 |
"source_idx": i, "target_idx": j, "change": "added",
|
| 192 |
"relation_id": c - 1,
|
| 193 |
-
"relation_name":
|
| 194 |
})
|
| 195 |
elif o > 0 and c == 0:
|
| 196 |
summary["removed"] += 1
|
| 197 |
edges.append({
|
| 198 |
"source_idx": i, "target_idx": j, "change": "removed",
|
| 199 |
"original_relation_id": o - 1,
|
| 200 |
-
"original_relation_name":
|
| 201 |
})
|
| 202 |
else:
|
| 203 |
summary["modified"] += 1
|
| 204 |
edges.append({
|
| 205 |
"source_idx": i, "target_idx": j, "change": "modified",
|
| 206 |
"original_relation_id": o - 1,
|
| 207 |
-
"original_relation_name":
|
| 208 |
"relation_id": c - 1,
|
| 209 |
-
"relation_name":
|
| 210 |
})
|
| 211 |
|
| 212 |
return {"edges": edges, "summary": summary}
|
|
@@ -216,50 +263,117 @@ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
|
|
| 216 |
# Rendering
|
| 217 |
# ---------------------------------------------------------------------------
|
| 218 |
|
| 219 |
-
def
|
| 220 |
-
s = str(
|
| 221 |
-
if
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
else:
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
else:
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
|
| 250 |
-
def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=None):
|
| 251 |
"""Render a directed KG subgraph as a PIL image using networkx + PIL.
|
| 252 |
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
"""
|
| 255 |
-
|
| 256 |
-
from PIL import Image, ImageDraw, ImageFont
|
| 257 |
-
|
| 258 |
-
inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 259 |
|
| 260 |
e = E_int.cpu().tolist()
|
| 261 |
xi = X_index.cpu().tolist()
|
| 262 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
G = nx.DiGraph()
|
| 264 |
for i in range(num_nodes):
|
| 265 |
G.add_node(i)
|
|
@@ -270,105 +384,162 @@ def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=No
|
|
| 270 |
if int(e[i][j]) > 0:
|
| 271 |
G.add_edge(i, j, rel=int(e[i][j]) - 1)
|
| 272 |
|
| 273 |
-
pos = nx.spring_layout(G, seed=42)
|
| 274 |
-
|
| 275 |
-
# Build change lookup: (i, j) -> change_type
|
| 276 |
-
change_lookup = {}
|
| 277 |
if changes is not None:
|
| 278 |
-
for
|
| 279 |
-
|
|
|
|
| 280 |
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
| 286 |
|
|
|
|
| 287 |
img = Image.new("RGB", (size, size), "white")
|
| 288 |
draw = ImageDraw.Draw(img)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
try:
|
| 290 |
-
font = ImageFont.truetype("arial.ttf",
|
| 291 |
-
small_font = ImageFont.truetype("arial.ttf", 9)
|
| 292 |
except (OSError, IOError):
|
| 293 |
font = ImageFont.load_default()
|
| 294 |
-
small_font = font
|
| 295 |
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
-
#
|
| 299 |
-
#
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
else:
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
#
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
| 334 |
draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
|
| 335 |
-
fill=
|
| 336 |
-
eid = int(xi[
|
| 337 |
-
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
return img
|
| 341 |
|
| 342 |
|
| 343 |
-
def
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
draw.
|
| 355 |
|
| 356 |
|
| 357 |
-
def
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
|
| 374 |
# ---------------------------------------------------------------------------
|
|
@@ -406,17 +577,21 @@ def run_standard_correction(model, tensors, dataset_id, task, loader,
|
|
| 406 |
E_given = tensors["E_given"].to(device)
|
| 407 |
y_given = tensors["y_given"].to(device)
|
| 408 |
X_index = tensors["X_index"].to(device)
|
|
|
|
| 409 |
is_bip = tensors["is_bip"].to(device)
|
| 410 |
n_nodes = tensors["n_nodes"].to(device)
|
| 411 |
node_mask = tensors["node_mask"].to(device)
|
| 412 |
n_max = n_nodes.item()
|
|
|
|
| 413 |
|
| 414 |
inpaint_mask = _build_inpaint_mask(
|
| 415 |
task, node_mask, is_bip, model.Edim_output, device)
|
|
|
|
| 416 |
|
| 417 |
original_E_int = E_given[0].argmax(dim=-1).long() # (n, n)
|
| 418 |
original_img = render_kg_subgraph(
|
| 419 |
-
original_E_int, n_max, X_index[0], dataset_id, loader,
|
|
|
|
| 420 |
|
| 421 |
model_T = model.T
|
| 422 |
step_stride = max(1, model_T // diffusion_steps)
|
|
@@ -451,17 +626,28 @@ def run_standard_correction(model, tensors, dataset_id, task, loader,
|
|
| 451 |
}
|
| 452 |
if is_frame:
|
| 453 |
frame = render_kg_subgraph(
|
| 454 |
-
E_int_prev, n_max, X_index[0], dataset_id, loader
|
|
|
|
| 455 |
gif_frames.append(frame)
|
| 456 |
event["preview"] = _pil_to_b64(frame)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
yield event
|
| 458 |
|
| 459 |
X_final, E_final = _collapse_final_kg(model, X, E, y, node_mask)
|
| 460 |
|
| 461 |
corrected_E_int = E_final[0]
|
| 462 |
-
changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
|
| 463 |
corrected_img = render_kg_subgraph(
|
| 464 |
-
corrected_E_int, n_max, X_index[0], dataset_id, loader,
|
|
|
|
| 465 |
|
| 466 |
elapsed_ms = int((time.time() - t0) * 1000)
|
| 467 |
yield {
|
|
@@ -493,9 +679,11 @@ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
|
|
| 493 |
|
| 494 |
inpaint_mask = _build_inpaint_mask(
|
| 495 |
task, node_mask, is_bip, model.Edim_output, device)
|
|
|
|
| 496 |
original_E_int = E_given[0].argmax(dim=-1).long()
|
| 497 |
original_img = render_kg_subgraph(
|
| 498 |
-
original_E_int, n_max, X_index[0], dataset_id, loader,
|
|
|
|
| 499 |
|
| 500 |
t0 = time.time()
|
| 501 |
# Sample initial noise for each of M Gibbs chains
|
|
@@ -524,9 +712,10 @@ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
|
|
| 524 |
agg_y = torch.median(y_ens.float(), dim=1).values
|
| 525 |
X_int, E_int = _collapse_final_kg(model, X_given, agg_E, agg_y, node_mask)
|
| 526 |
corrected_E_int = E_int[0]
|
| 527 |
-
changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
|
| 528 |
preview_img = render_kg_subgraph(
|
| 529 |
-
corrected_E_int, n_max, X_index[0], dataset_id, loader,
|
|
|
|
| 530 |
elapsed_ms = int((time.time() - t0) * 1000)
|
| 531 |
|
| 532 |
state = {
|
|
@@ -535,11 +724,13 @@ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
|
|
| 535 |
"y": y_ens.cpu(),
|
| 536 |
"n_nodes": n_nodes.cpu(),
|
| 537 |
"dataset_id": dataset_id,
|
|
|
|
| 538 |
"task": task,
|
| 539 |
"X_index": X_index.cpu(),
|
| 540 |
"X_c": X_c.cpu(),
|
| 541 |
"is_bip": is_bip.cpu(),
|
| 542 |
"original_E_int": original_E_int.cpu(),
|
|
|
|
| 543 |
"T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
|
| 544 |
"gibbs_chain_freq": gibbs_chain_freq,
|
| 545 |
"inner_step": 0, "step": 0,
|
|
@@ -562,9 +753,12 @@ def run_multiprox_correction_step(model, state, loader):
|
|
| 562 |
E = state["E"].to(device)
|
| 563 |
y = state["y"].to(device)
|
| 564 |
X_index = state["X_index"].to(device)
|
|
|
|
| 565 |
is_bip = state["is_bip"].to(device)
|
| 566 |
n_nodes = state["n_nodes"].to(device)
|
| 567 |
original_E_int = state["original_E_int"].to(device)
|
|
|
|
|
|
|
| 568 |
|
| 569 |
T = state["T"]
|
| 570 |
n = state["n"]
|
|
@@ -578,6 +772,7 @@ def run_multiprox_correction_step(model, state, loader):
|
|
| 578 |
n_max = int(n_nodes.item())
|
| 579 |
node_mask = torch.ones(1, n_max, dtype=torch.bool, device=device)
|
| 580 |
inpaint_mask = _build_inpaint_mask(task, node_mask, is_bip, model.Edim_output, device)
|
|
|
|
| 581 |
|
| 582 |
fixed_t_norm = t * torch.ones((1, 1), dtype=torch.float, device=device)
|
| 583 |
fixed_s_norm = fixed_t_norm - (1.0 / T)
|
|
@@ -608,8 +803,9 @@ def run_multiprox_correction_step(model, state, loader):
|
|
| 608 |
prev_y = torch.median(y.float(), dim=1).values
|
| 609 |
_, prev_Ei = _collapse_final_kg(model, X_given, prev_E, prev_y, node_mask)
|
| 610 |
preview_img = render_kg_subgraph(
|
| 611 |
-
prev_Ei[0], n_max, X_index[0], dataset_id, loader
|
| 612 |
-
|
|
|
|
| 613 |
"type": "progress",
|
| 614 |
"phase": "gibbs",
|
| 615 |
"step": i + 1,
|
|
@@ -617,6 +813,16 @@ def run_multiprox_correction_step(model, state, loader):
|
|
| 617 |
"elapsed_ms": int((time.time() - t0) * 1000),
|
| 618 |
"preview": _pil_to_b64(preview_img),
|
| 619 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 620 |
|
| 621 |
new_inner_step = inner_step + steps_this_call
|
| 622 |
round_complete = new_inner_step >= m
|
|
@@ -649,21 +855,34 @@ def run_multiprox_correction_step(model, state, loader):
|
|
| 649 |
"elapsed_ms": int((time.time() - t0) * 1000),
|
| 650 |
}
|
| 651 |
if is_frame:
|
|
|
|
| 652 |
event["preview"] = _pil_to_b64(render_kg_subgraph(
|
| 653 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
yield event
|
| 655 |
|
| 656 |
X_int, E_int = _collapse_final_kg(model, cur_X, cur_E, cur_y, node_mask)
|
| 657 |
|
| 658 |
corrected_E_int = E_int[0]
|
| 659 |
-
changes = compute_changes(
|
| 660 |
corrected_img = render_kg_subgraph(
|
| 661 |
-
corrected_E_int, n_max, X_index[0], dataset_id, loader,
|
|
|
|
| 662 |
elapsed_ms = int((time.time() - t0) * 1000)
|
| 663 |
|
| 664 |
updated_state = {
|
| 665 |
**state,
|
| 666 |
"E": E.cpu(), "y": y.cpu(),
|
|
|
|
| 667 |
"step": new_step, "inner_step": new_inner_step,
|
| 668 |
}
|
| 669 |
yield {
|
|
|
|
| 1 |
import base64
|
| 2 |
+
import faulthandler
|
| 3 |
import io
|
| 4 |
+
import logging
|
| 5 |
+
import math
|
| 6 |
+
import sys
|
| 7 |
import time
|
| 8 |
+
import traceback
|
| 9 |
+
|
| 10 |
+
faulthandler.enable(file=sys.stderr, all_threads=True)
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
import torch
|
| 15 |
import torch.nn.functional as F
|
| 16 |
+
import numpy as np
|
| 17 |
+
import networkx as nx
|
| 18 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 19 |
|
| 20 |
from api.services.graphgen_inference import (
|
| 21 |
_frames_to_gif_b64, _pil_to_b64,
|
| 22 |
)
|
| 23 |
+
from api.services.kg_likelihood import kg_edge_log_likelihood
|
| 24 |
+
from api.utils import clean_entity_name, clean_relation_name
|
| 25 |
|
| 26 |
STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 27 |
REQUIRED_STATE_KEYS = {
|
|
|
|
| 90 |
X_c[0, i] = int(communities[eid])
|
| 91 |
|
| 92 |
n_nodes = torch.tensor([n], dtype=torch.long)
|
| 93 |
+
is_bip = torch.tensor([bool(subgraph.get("is_bip", False))], dtype=torch.bool)
|
| 94 |
node_mask = torch.ones(1, n, dtype=torch.bool)
|
| 95 |
|
| 96 |
return {
|
|
|
|
| 104 |
return t.to(device) if isinstance(t, torch.Tensor) else t
|
| 105 |
|
| 106 |
|
| 107 |
+
def apply_edge_noise(model, tensors, task, noise_level, seed=None,
|
| 108 |
+
loader=None, dataset_id=None, nodes=None):
|
| 109 |
"""Forward-diffuse the given subgraph's edges at t = noise_level * T.
|
| 110 |
|
| 111 |
For task="correct", only edges inside the inpaint mask (the second half of
|
| 112 |
nodes) are noised, matching what the correction endpoint will regenerate.
|
| 113 |
For task="generate", every edge slot is noised.
|
| 114 |
|
| 115 |
+
Returns a new list of edge dicts enriched with cleaned relation/entity
|
| 116 |
+
names when ``loader``/``dataset_id``/``nodes`` are supplied.
|
| 117 |
"""
|
| 118 |
from graph_generation.src.utils import get_inpaint_mask
|
| 119 |
from graph_generation.src.diffusion import diffusion_utils
|
|
|
|
| 153 |
E_mixed = E_noised * inpaint_mask + E * (~inpaint_mask)
|
| 154 |
E_int = E_mixed[0].argmax(dim=-1).cpu()
|
| 155 |
|
| 156 |
+
inv_relations = None
|
| 157 |
+
if loader is not None and dataset_id is not None:
|
| 158 |
+
_, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 159 |
+
|
| 160 |
+
def node_name(idx):
|
| 161 |
+
if nodes is not None and 0 <= idx < len(nodes):
|
| 162 |
+
return nodes[idx].get("entity_name") or f"#{nodes[idx].get('entity_id', idx)}"
|
| 163 |
+
return f"#{idx}"
|
| 164 |
+
|
| 165 |
+
def relation_label(rid):
|
| 166 |
+
if inv_relations is None:
|
| 167 |
+
return None
|
| 168 |
+
raw = inv_relations.get(rid)
|
| 169 |
+
if raw is None or raw != raw or str(raw).strip() == "":
|
| 170 |
+
return f"rel#{rid}"
|
| 171 |
+
cleaned = clean_relation_name(str(raw), dataset_id)
|
| 172 |
+
return cleaned if cleaned else f"rel#{rid}"
|
| 173 |
+
|
| 174 |
edges = []
|
| 175 |
for i in range(n):
|
| 176 |
for j in range(n):
|
|
|
|
| 179 |
cls = int(E_int[i, j])
|
| 180 |
if cls == 0:
|
| 181 |
continue
|
| 182 |
+
rel_id = cls - 1
|
| 183 |
+
edge = {
|
| 184 |
+
"source_idx": i, "target_idx": j, "relation_id": rel_id,
|
| 185 |
+
}
|
| 186 |
+
if inv_relations is not None:
|
| 187 |
+
edge["relation_name"] = relation_label(rel_id)
|
| 188 |
+
edge["entity_name_source"] = node_name(i)
|
| 189 |
+
edge["entity_name_target"] = node_name(j)
|
| 190 |
+
edges.append(edge)
|
| 191 |
return edges
|
| 192 |
|
| 193 |
|
|
|
|
| 195 |
# Change detection
|
| 196 |
# ---------------------------------------------------------------------------
|
| 197 |
|
| 198 |
+
def compute_changes(original_E_int, corrected_E_int, num_nodes, loader, dataset_id):
|
| 199 |
"""Compute before/after edge diff for a directed KG subgraph.
|
| 200 |
|
| 201 |
original_E_int / corrected_E_int: 2-D int tensors (n, n) where 0 = no edge
|
|
|
|
| 203 |
"""
|
| 204 |
_, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 205 |
|
| 206 |
+
def rel_name(idx):
|
| 207 |
+
raw = inv_relations.get(idx)
|
| 208 |
+
if raw is None or raw != raw or str(raw).strip() == "":
|
| 209 |
+
return f"rel#{idx}"
|
| 210 |
+
cleaned = clean_relation_name(str(raw), dataset_id)
|
| 211 |
+
return cleaned if cleaned else f"rel#{idx}"
|
| 212 |
+
|
| 213 |
edges = []
|
| 214 |
summary = {"added": 0, "removed": 0, "modified": 0, "unchanged": 0}
|
| 215 |
|
|
|
|
| 229 |
edges.append({
|
| 230 |
"source_idx": i, "target_idx": j, "change": "unchanged",
|
| 231 |
"relation_id": c - 1,
|
| 232 |
+
"relation_name": rel_name(c - 1),
|
| 233 |
})
|
| 234 |
continue
|
| 235 |
if o == 0 and c > 0:
|
|
|
|
| 237 |
edges.append({
|
| 238 |
"source_idx": i, "target_idx": j, "change": "added",
|
| 239 |
"relation_id": c - 1,
|
| 240 |
+
"relation_name": rel_name(c - 1),
|
| 241 |
})
|
| 242 |
elif o > 0 and c == 0:
|
| 243 |
summary["removed"] += 1
|
| 244 |
edges.append({
|
| 245 |
"source_idx": i, "target_idx": j, "change": "removed",
|
| 246 |
"original_relation_id": o - 1,
|
| 247 |
+
"original_relation_name": rel_name(o - 1),
|
| 248 |
})
|
| 249 |
else:
|
| 250 |
summary["modified"] += 1
|
| 251 |
edges.append({
|
| 252 |
"source_idx": i, "target_idx": j, "change": "modified",
|
| 253 |
"original_relation_id": o - 1,
|
| 254 |
+
"original_relation_name": rel_name(o - 1),
|
| 255 |
"relation_id": c - 1,
|
| 256 |
+
"relation_name": rel_name(c - 1),
|
| 257 |
})
|
| 258 |
|
| 259 |
return {"edges": edges, "summary": summary}
|
|
|
|
| 263 |
# Rendering
|
| 264 |
# ---------------------------------------------------------------------------
|
| 265 |
|
| 266 |
+
def _truncate_label(s, limit):
|
| 267 |
+
s = str(s)
|
| 268 |
+
return s if len(s) <= limit else s[: limit - 1] + "…"
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# Green-palette anchors matching the site's primary colour: pale mint -> vivid
|
| 272 |
+
# green -> deep forest. Used to colour nodes by the normalized-Laplacian
|
| 273 |
+
# eigenvector.
|
| 274 |
+
_GREEN_LOW = (212, 237, 218)
|
| 275 |
+
_GREEN_MID = (82, 180, 120)
|
| 276 |
+
_GREEN_HIGH = (22, 80, 50)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def _green_rgb(t):
|
| 280 |
+
"""Map t in [-1, 1] to an (r, g, b) tuple on a three-anchor green gradient."""
|
| 281 |
+
t = max(-1.0, min(1.0, float(t)))
|
| 282 |
+
if t < 0:
|
| 283 |
+
w = t + 1.0 # 0 at -1, 1 at 0
|
| 284 |
+
a, b = _GREEN_LOW, _GREEN_MID
|
| 285 |
else:
|
| 286 |
+
w = t # 0 at 0, 1 at 1
|
| 287 |
+
a, b = _GREEN_MID, _GREEN_HIGH
|
| 288 |
+
return (
|
| 289 |
+
int(a[0] + (b[0] - a[0]) * w),
|
| 290 |
+
int(a[1] + (b[1] - a[1]) * w),
|
| 291 |
+
int(a[2] + (b[2] - a[2]) * w),
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _quad_bezier(p0, p1, p2, steps=20):
|
| 296 |
+
out = []
|
| 297 |
+
for k in range(steps + 1):
|
| 298 |
+
t = k / steps
|
| 299 |
+
u = 1 - t
|
| 300 |
+
x = u * u * p0[0] + 2 * u * t * p1[0] + t * t * p2[0]
|
| 301 |
+
y = u * u * p0[1] + 2 * u * t * p1[1] + t * t * p2[1]
|
| 302 |
+
out.append((x, y))
|
| 303 |
+
return out
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def _draw_arrowhead(draw, tip, direction, color):
|
| 307 |
+
ux, uy = direction
|
| 308 |
+
angle = math.atan2(uy, ux)
|
| 309 |
+
ah_len = 9
|
| 310 |
+
ah_angle = math.radians(25)
|
| 311 |
+
x, y = tip
|
| 312 |
+
x1 = x - ah_len * math.cos(angle - ah_angle)
|
| 313 |
+
y1 = y - ah_len * math.sin(angle - ah_angle)
|
| 314 |
+
x2 = x - ah_len * math.cos(angle + ah_angle)
|
| 315 |
+
y2 = y - ah_len * math.sin(angle + ah_angle)
|
| 316 |
+
draw.polygon([(x, y), (x1, y1), (x2, y2)], fill=color)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _draw_curve(draw, p0, p2, color, width=2, dashed=False, rad=0.2):
    """Stroke a quadratic Bézier arc between ``p0`` and ``p2``.

    The control point is the chord midpoint pushed sideways (perpendicular to
    the chord) by ``rad`` times the chord length. The arc is sampled with 24
    steps via ``_quad_bezier``. With ``dashed`` set, only every other sampled
    segment is drawn; otherwise the whole polyline is drawn in one call.

    Returns the ``(tx, ty)`` direction from the second-to-last sample to
    ``p2``, for orienting an arrowhead. It is divided by its length clamped
    to at least 1.0, so very short final segments are not scaled up to unit
    length.
    """
    chord_dx = p2[0] - p0[0]
    chord_dy = p2[1] - p0[1]
    mid_x = (p0[0] + p2[0]) / 2
    mid_y = (p0[1] + p2[1]) / 2

    # Chord rotated by 90 degrees gives the direction of the bulge.
    perp_x, perp_y = -chord_dy, chord_dx
    perp_len = max(1.0, math.hypot(perp_x, perp_y))
    control = (
        mid_x + perp_x / perp_len * rad * math.hypot(chord_dx, chord_dy),
        mid_y + perp_y / perp_len * rad * math.hypot(chord_dx, chord_dy),
    )
    samples = _quad_bezier(p0, control, p2, steps=24)

    if not dashed:
        draw.line(samples, fill=color, width=width, joint="curve")
    else:
        # Draw segments 0-1, 2-3, ... and leave the ones in between blank.
        for start in range(0, len(samples) - 1, 2):
            draw.line([samples[start], samples[start + 1]], fill=color, width=width)

    # Approximate the end tangent from the last sampled segment.
    prev_x, prev_y = samples[-2]
    tan_x = p2[0] - prev_x
    tan_y = p2[1] - prev_y
    tan_len = max(1.0, math.hypot(tan_x, tan_y))
    return (tan_x / tan_len, tan_y / tan_len)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def _bipartite_layout(row_nodes, col_nodes):
|
| 342 |
+
"""Two-column layout: row nodes on the left, col nodes on the right, evenly spaced."""
|
| 343 |
+
pos = {}
|
| 344 |
+
|
| 345 |
+
def place(nodes, x):
|
| 346 |
+
k = len(nodes)
|
| 347 |
+
for i, n in enumerate(nodes):
|
| 348 |
+
y = 1.0 - 2.0 * (i + 1) / (k + 1) # evenly spaced in [-1, 1], top-down
|
| 349 |
+
pos[n] = (x, y)
|
| 350 |
+
|
| 351 |
+
place(row_nodes, -1.0)
|
| 352 |
+
place(col_nodes, 1.0)
|
| 353 |
+
return pos
|
| 354 |
|
| 355 |
|
| 356 |
+
def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=None, is_bip=False):
|
| 357 |
"""Render a directed KG subgraph as a PIL image using networkx + PIL.
|
| 358 |
|
| 359 |
+
Ports the improvements from KnowledgeGraphVisualization.visualize_non_molecule
|
| 360 |
+
without importing matplotlib (which conflicts with torch on Windows): isolated
|
| 361 |
+
singleton filter, spring_layout(k=1), coolwarm node colouring from the
|
| 362 |
+
normalized-Laplacian eigenvector, and curved edges. When `changes` is
|
| 363 |
+
provided, per-edge CHANGE_COLORS override the default grey and "removed"
|
| 364 |
+
edges are drawn dashed. When `is_bip`, uses a two-column layout that
|
| 365 |
+
separates the row/col partitions and visually marks the inpaint quadrants.
|
| 366 |
"""
|
| 367 |
+
inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
e = E_int.cpu().tolist()
|
| 370 |
xi = X_index.cpu().tolist()
|
| 371 |
|
| 372 |
+
change_lookup = {}
|
| 373 |
+
if changes is not None:
|
| 374 |
+
for entry in changes.get("edges", []):
|
| 375 |
+
change_lookup[(entry["source_idx"], entry["target_idx"])] = entry["change"]
|
| 376 |
+
|
| 377 |
G = nx.DiGraph()
|
| 378 |
for i in range(num_nodes):
|
| 379 |
G.add_node(i)
|
|
|
|
| 384 |
if int(e[i][j]) > 0:
|
| 385 |
G.add_edge(i, j, rel=int(e[i][j]) - 1)
|
| 386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
if changes is not None:
|
| 388 |
+
for (i, j), ct in change_lookup.items():
|
| 389 |
+
if ct == "removed" and not G.has_edge(i, j):
|
| 390 |
+
G.add_edge(i, j, rel=None)
|
| 391 |
|
| 392 |
+
# Bipartite: keep every node so the row/col structure stays visible.
|
| 393 |
+
# Community: drop isolated singletons so the spring layout focuses on structure.
|
| 394 |
+
if is_bip:
|
| 395 |
+
graph = G.copy()
|
| 396 |
+
else:
|
| 397 |
+
components = [G.subgraph(c).copy() for c in nx.connected_components(G.to_undirected())]
|
| 398 |
+
components = [c for c in components if c.number_of_nodes() > 1]
|
| 399 |
+
graph = nx.compose_all(components) if components else G
|
| 400 |
|
| 401 |
+
size = 520
|
| 402 |
img = Image.new("RGB", (size, size), "white")
|
| 403 |
draw = ImageDraw.Draw(img)
|
| 404 |
+
|
| 405 |
+
if graph.number_of_nodes() == 0:
|
| 406 |
+
return img
|
| 407 |
+
|
| 408 |
try:
|
| 409 |
+
font = ImageFont.truetype("arial.ttf", 12)
|
|
|
|
| 410 |
except (OSError, IOError):
|
| 411 |
font = ImageFont.load_default()
|
|
|
|
| 412 |
|
| 413 |
+
if is_bip:
|
| 414 |
+
row_nodes = [n for n in graph.nodes() if n < num_nodes // 2]
|
| 415 |
+
col_nodes = [n for n in graph.nodes() if n >= num_nodes // 2]
|
| 416 |
+
pos = _bipartite_layout(row_nodes, col_nodes)
|
| 417 |
+
else:
|
| 418 |
+
pos = nx.spring_layout(graph, k=1, iterations=100, seed=42)
|
| 419 |
|
| 420 |
+
# Normalized Laplacian eigenvector for node colouring. Use torch.linalg.eigh
|
| 421 |
+
# rather than numpy.linalg.eigh — on Windows, numpy's MKL DLLs conflict with
|
| 422 |
+
# torch's (Windows code 0xc06d007f), and torch is already healthy in-process.
|
| 423 |
+
try:
|
| 424 |
+
L = nx.normalized_laplacian_matrix(graph.to_undirected()).toarray()
|
| 425 |
+
L_t = torch.from_numpy(L).to(torch.float64)
|
| 426 |
+
_, U_t = torch.linalg.eigh(L_t)
|
| 427 |
+
U = U_t.numpy()
|
| 428 |
+
eigen_dim = 1 if U.shape[1] > 1 else 0
|
| 429 |
+
vec = U[:, eigen_dim]
|
| 430 |
+
m_abs = max(abs(vec.min()), abs(vec.max()), 1e-9)
|
| 431 |
+
vec_norm = vec / m_abs # now in [-1, 1]
|
| 432 |
+
except Exception:
|
| 433 |
+
logger.warning(
|
| 434 |
+
"eigenvector colouring failed; using flat colour:\n%s",
|
| 435 |
+
traceback.format_exc(),
|
| 436 |
+
)
|
| 437 |
+
vec_norm = np.zeros(graph.number_of_nodes())
|
| 438 |
+
|
| 439 |
+
node_list = list(graph.nodes())
|
| 440 |
+
node_color = {n: _green_rgb(vec_norm[k]) for k, n in enumerate(node_list)}
|
| 441 |
+
|
| 442 |
+
xs, ys = zip(*pos.values())
|
| 443 |
+
x_min, x_max = min(xs), max(xs)
|
| 444 |
+
y_min, y_max = min(ys), max(ys)
|
| 445 |
+
x_span = (x_max - x_min) or 1.0
|
| 446 |
+
y_span = (y_max - y_min) or 1.0
|
| 447 |
+
margin = 55
|
| 448 |
+
scale = (size - 2 * margin) / max(x_span, y_span)
|
| 449 |
+
cx = size / 2 - (x_min + x_span / 2) * scale
|
| 450 |
+
cy = size / 2 + (y_min + y_span / 2) * scale # flip y so "up" is up
|
| 451 |
+
|
| 452 |
+
def to_px(p):
|
| 453 |
+
return (cx + p[0] * scale, cy - p[1] * scale)
|
| 454 |
+
|
| 455 |
+
pixel_pos = {n: to_px(pos[n]) for n in graph.nodes()}
|
| 456 |
+
|
| 457 |
+
node_r = 12
|
| 458 |
+
|
| 459 |
+
# Edges first so nodes overlay them
|
| 460 |
+
for (i, j) in graph.edges():
|
| 461 |
+
ct = change_lookup.get((i, j))
|
| 462 |
+
dashed = (ct == "removed")
|
| 463 |
+
if ct is not None:
|
| 464 |
+
color = CHANGE_COLORS.get(ct, "#6b6b6b")
|
| 465 |
else:
|
| 466 |
+
color = "#6b6b6b"
|
| 467 |
+
p0 = pixel_pos[i]
|
| 468 |
+
p2 = pixel_pos[j]
|
| 469 |
+
# Shorten endpoints so curve doesn't overlap node circles
|
| 470 |
+
dx, dy = p2[0] - p0[0], p2[1] - p0[1]
|
| 471 |
+
dist = max(1.0, math.hypot(dx, dy))
|
| 472 |
+
ux, uy = dx / dist, dy / dist
|
| 473 |
+
sx, sy = p0[0] + ux * node_r, p0[1] + uy * node_r
|
| 474 |
+
ex, ey = p2[0] - ux * node_r, p2[1] - uy * node_r
|
| 475 |
+
tangent = _draw_curve(draw, (sx, sy), (ex, ey), color,
|
| 476 |
+
width=2, dashed=dashed, rad=0.2)
|
| 477 |
+
_draw_arrowhead(draw, (ex, ey), tangent, color)
|
| 478 |
+
|
| 479 |
+
for n in graph.nodes():
|
| 480 |
+
x, y = pixel_pos[n]
|
| 481 |
+
r, g, b = node_color[n]
|
| 482 |
draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
|
| 483 |
+
fill=(r, g, b), outline="#333333", width=1)
|
| 484 |
+
eid = int(xi[n]) if n < len(xi) else n
|
| 485 |
+
raw = inv_nodes.get(eid)
|
| 486 |
+
if raw is None or raw != raw or str(raw).strip() == "":
|
| 487 |
+
label = f"#{eid}"
|
| 488 |
+
else:
|
| 489 |
+
cleaned = clean_entity_name(str(raw), dataset_id)
|
| 490 |
+
label = cleaned if cleaned else f"#{eid}"
|
| 491 |
+
label = _truncate_label(label, 20)
|
| 492 |
+
_draw_text_with_bg(draw, (x + node_r + 3, y - 7), label, font, fill="#111111")
|
| 493 |
|
| 494 |
return img
|
| 495 |
|
| 496 |
|
| 497 |
+
def _draw_text_with_bg(draw, xy, text, font, fill):
|
| 498 |
+
"""Draw text with a semi-opaque white pad so labels stay readable over lines."""
|
| 499 |
+
try:
|
| 500 |
+
bbox = draw.textbbox(xy, text, font=font)
|
| 501 |
+
pad = 1
|
| 502 |
+
draw.rectangle(
|
| 503 |
+
[bbox[0] - pad, bbox[1] - pad, bbox[2] + pad, bbox[3] + pad],
|
| 504 |
+
fill="white",
|
| 505 |
+
)
|
| 506 |
+
except Exception:
|
| 507 |
+
pass
|
| 508 |
+
draw.text(xy, text, font=font, fill=fill)
|
| 509 |
|
| 510 |
|
| 511 |
+
def render_sample_subgraph_b64(subgraph, loader, dataset_id):
    """Turn a sample-subgraph payload dict into the output of ``_pil_to_b64``.

    The payload's edge list is converted to a dense ``(n, n)`` integer matrix
    (0 = no edge, ``relation_id + 1`` otherwise) and drawn with
    ``render_kg_subgraph``, the same renderer the inference paths use, so
    thumbnails match the before/after images.

    Args:
        subgraph: dict with ``"nodes"`` (each carrying ``"entity_id"``),
            ``"edges"`` (each carrying ``"source_idx"``, ``"target_idx"``,
            ``"relation_id"``) and optionally ``"is_bip"``.
        loader: dataset loader forwarded to ``render_kg_subgraph``.
        dataset_id: dataset identifier forwarded to ``render_kg_subgraph``.

    Returns:
        ``None`` when the subgraph has no nodes, otherwise whatever
        ``_pil_to_b64`` returns for the rendered image.

    Raises:
        Any exception raised while building or rendering; it is logged with
        its traceback and then re-raised unchanged.
    """
    nodes = subgraph["nodes"]
    edges = subgraph["edges"]
    n = len(nodes)
    if not n:
        return None

    try:
        adjacency = torch.zeros(n, n, dtype=torch.long)
        for edge in edges:
            src = int(edge["source_idx"])
            tgt = int(edge["target_idx"])
            rel = int(edge["relation_id"])
            # Edges pointing outside the node list are silently dropped.
            if not (0 <= src < n and 0 <= tgt < n):
                continue
            adjacency[src, tgt] = rel + 1

        entity_ids = torch.tensor(
            [int(node["entity_id"]) for node in nodes], dtype=torch.long)
        bipartite = bool(subgraph.get("is_bip", False))
        rendered = render_kg_subgraph(
            adjacency, n, entity_ids, dataset_id, loader,
            changes=None, is_bip=bipartite)
        return _pil_to_b64(rendered)
    except Exception:
        logger.error(
            "render_sample_subgraph_b64 failed: dataset=%s n=%d\n%s",
            dataset_id, n, traceback.format_exc(),
        )
        raise
|
| 543 |
|
| 544 |
|
| 545 |
# ---------------------------------------------------------------------------
|
|
|
|
| 577 |
E_given = tensors["E_given"].to(device)
|
| 578 |
y_given = tensors["y_given"].to(device)
|
| 579 |
X_index = tensors["X_index"].to(device)
|
| 580 |
+
X_c = tensors["X_c"].to(device)
|
| 581 |
is_bip = tensors["is_bip"].to(device)
|
| 582 |
n_nodes = tensors["n_nodes"].to(device)
|
| 583 |
node_mask = tensors["node_mask"].to(device)
|
| 584 |
n_max = n_nodes.item()
|
| 585 |
+
kg_experiment = getattr(model, "kg_experiment", None)
|
| 586 |
|
| 587 |
inpaint_mask = _build_inpaint_mask(
|
| 588 |
task, node_mask, is_bip, model.Edim_output, device)
|
| 589 |
+
is_bip_bool = bool(is_bip.item())
|
| 590 |
|
| 591 |
original_E_int = E_given[0].argmax(dim=-1).long() # (n, n)
|
| 592 |
original_img = render_kg_subgraph(
|
| 593 |
+
original_E_int, n_max, X_index[0], dataset_id, loader,
|
| 594 |
+
changes=None, is_bip=is_bip_bool)
|
| 595 |
|
| 596 |
model_T = model.T
|
| 597 |
step_stride = max(1, model_T // diffusion_steps)
|
|
|
|
| 626 |
}
|
| 627 |
if is_frame:
|
| 628 |
frame = render_kg_subgraph(
|
| 629 |
+
E_int_prev, n_max, X_index[0], dataset_id, loader,
|
| 630 |
+
is_bip=is_bip_bool)
|
| 631 |
gif_frames.append(frame)
|
| 632 |
event["preview"] = _pil_to_b64(frame)
|
| 633 |
+
if kg_experiment is not None:
|
| 634 |
+
ll = kg_edge_log_likelihood(
|
| 635 |
+
E_int_prev, X_given[0], X_index[0], X_c[0], kg_experiment)
|
| 636 |
+
if ll is not None:
|
| 637 |
+
event["kg_log_likelihood"] = ll
|
| 638 |
+
event["kg_log_likelihood_step"] = emitted
|
| 639 |
+
logger.info(
|
| 640 |
+
"[kg-anomaly] denoise step=%d/%d kg_log_lik=%.4f",
|
| 641 |
+
emitted, total_loop_steps, ll)
|
| 642 |
yield event
|
| 643 |
|
| 644 |
X_final, E_final = _collapse_final_kg(model, X, E, y, node_mask)
|
| 645 |
|
| 646 |
corrected_E_int = E_final[0]
|
| 647 |
+
changes = compute_changes(original_E_int, corrected_E_int, n_max, loader, dataset_id)
|
| 648 |
corrected_img = render_kg_subgraph(
|
| 649 |
+
corrected_E_int, n_max, X_index[0], dataset_id, loader,
|
| 650 |
+
changes=changes, is_bip=is_bip_bool)
|
| 651 |
|
| 652 |
elapsed_ms = int((time.time() - t0) * 1000)
|
| 653 |
yield {
|
|
|
|
| 679 |
|
| 680 |
inpaint_mask = _build_inpaint_mask(
|
| 681 |
task, node_mask, is_bip, model.Edim_output, device)
|
| 682 |
+
is_bip_bool = bool(is_bip.item())
|
| 683 |
original_E_int = E_given[0].argmax(dim=-1).long()
|
| 684 |
original_img = render_kg_subgraph(
|
| 685 |
+
original_E_int, n_max, X_index[0], dataset_id, loader,
|
| 686 |
+
changes=None, is_bip=is_bip_bool)
|
| 687 |
|
| 688 |
t0 = time.time()
|
| 689 |
# Sample initial noise for each of M Gibbs chains
|
|
|
|
| 712 |
agg_y = torch.median(y_ens.float(), dim=1).values
|
| 713 |
X_int, E_int = _collapse_final_kg(model, X_given, agg_E, agg_y, node_mask)
|
| 714 |
corrected_E_int = E_int[0]
|
| 715 |
+
changes = compute_changes(original_E_int, corrected_E_int, n_max, loader, dataset_id)
|
| 716 |
preview_img = render_kg_subgraph(
|
| 717 |
+
corrected_E_int, n_max, X_index[0], dataset_id, loader,
|
| 718 |
+
changes=changes, is_bip=is_bip_bool)
|
| 719 |
elapsed_ms = int((time.time() - t0) * 1000)
|
| 720 |
|
| 721 |
state = {
|
|
|
|
| 724 |
"y": y_ens.cpu(),
|
| 725 |
"n_nodes": n_nodes.cpu(),
|
| 726 |
"dataset_id": dataset_id,
|
| 727 |
+
"is_bip": bool(is_bip_bool),
|
| 728 |
"task": task,
|
| 729 |
"X_index": X_index.cpu(),
|
| 730 |
"X_c": X_c.cpu(),
|
| 731 |
"is_bip": is_bip.cpu(),
|
| 732 |
"original_E_int": original_E_int.cpu(),
|
| 733 |
+
"prev_E_int": corrected_E_int.cpu(),
|
| 734 |
"T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
|
| 735 |
"gibbs_chain_freq": gibbs_chain_freq,
|
| 736 |
"inner_step": 0, "step": 0,
|
|
|
|
| 753 |
E = state["E"].to(device)
|
| 754 |
y = state["y"].to(device)
|
| 755 |
X_index = state["X_index"].to(device)
|
| 756 |
+
X_c = state["X_c"].to(device)
|
| 757 |
is_bip = state["is_bip"].to(device)
|
| 758 |
n_nodes = state["n_nodes"].to(device)
|
| 759 |
original_E_int = state["original_E_int"].to(device)
|
| 760 |
+
prev_E_int = state.get("prev_E_int", state["original_E_int"]).to(device)
|
| 761 |
+
kg_experiment = getattr(model, "kg_experiment", None)
|
| 762 |
|
| 763 |
T = state["T"]
|
| 764 |
n = state["n"]
|
|
|
|
| 772 |
n_max = int(n_nodes.item())
|
| 773 |
node_mask = torch.ones(1, n_max, dtype=torch.bool, device=device)
|
| 774 |
inpaint_mask = _build_inpaint_mask(task, node_mask, is_bip, model.Edim_output, device)
|
| 775 |
+
is_bip_bool = bool(is_bip.item())
|
| 776 |
|
| 777 |
fixed_t_norm = t * torch.ones((1, 1), dtype=torch.float, device=device)
|
| 778 |
fixed_s_norm = fixed_t_norm - (1.0 / T)
|
|
|
|
| 803 |
prev_y = torch.median(y.float(), dim=1).values
|
| 804 |
_, prev_Ei = _collapse_final_kg(model, X_given, prev_E, prev_y, node_mask)
|
| 805 |
preview_img = render_kg_subgraph(
|
| 806 |
+
prev_Ei[0], n_max, X_index[0], dataset_id, loader,
|
| 807 |
+
is_bip=is_bip_bool)
|
| 808 |
+
event = {
|
| 809 |
"type": "progress",
|
| 810 |
"phase": "gibbs",
|
| 811 |
"step": i + 1,
|
|
|
|
| 813 |
"elapsed_ms": int((time.time() - t0) * 1000),
|
| 814 |
"preview": _pil_to_b64(preview_img),
|
| 815 |
}
|
| 816 |
+
if kg_experiment is not None:
|
| 817 |
+
ll = kg_edge_log_likelihood(
|
| 818 |
+
prev_Ei[0], X_given[0], X_index[0], X_c[0], kg_experiment)
|
| 819 |
+
if ll is not None:
|
| 820 |
+
event["kg_log_likelihood"] = ll
|
| 821 |
+
event["kg_log_likelihood_step"] = i + 1
|
| 822 |
+
logger.info(
|
| 823 |
+
"[kg-anomaly] gibbs step=%d/%d kg_log_lik=%.4f",
|
| 824 |
+
i + 1, steps_this_call, ll)
|
| 825 |
+
yield event
|
| 826 |
|
| 827 |
new_inner_step = inner_step + steps_this_call
|
| 828 |
round_complete = new_inner_step >= m
|
|
|
|
| 855 |
"elapsed_ms": int((time.time() - t0) * 1000),
|
| 856 |
}
|
| 857 |
if is_frame:
|
| 858 |
+
refine_E_int = discrete_s.E[0].long()
|
| 859 |
event["preview"] = _pil_to_b64(render_kg_subgraph(
|
| 860 |
+
refine_E_int, n_max, X_index[0], dataset_id, loader,
|
| 861 |
+
is_bip=is_bip_bool))
|
| 862 |
+
if kg_experiment is not None:
|
| 863 |
+
ll = kg_edge_log_likelihood(
|
| 864 |
+
refine_E_int, X_given[0], X_index[0], X_c[0], kg_experiment)
|
| 865 |
+
if ll is not None:
|
| 866 |
+
event["kg_log_likelihood"] = ll
|
| 867 |
+
event["kg_log_likelihood_step"] = j + 1
|
| 868 |
+
logger.info(
|
| 869 |
+
"[kg-anomaly] refine step=%d/%d kg_log_lik=%.4f",
|
| 870 |
+
j + 1, P, ll)
|
| 871 |
yield event
|
| 872 |
|
| 873 |
X_int, E_int = _collapse_final_kg(model, cur_X, cur_E, cur_y, node_mask)
|
| 874 |
|
| 875 |
corrected_E_int = E_int[0]
|
| 876 |
+
changes = compute_changes(prev_E_int, corrected_E_int, n_max, loader, dataset_id)
|
| 877 |
corrected_img = render_kg_subgraph(
|
| 878 |
+
corrected_E_int, n_max, X_index[0], dataset_id, loader,
|
| 879 |
+
changes=changes, is_bip=is_bip_bool)
|
| 880 |
elapsed_ms = int((time.time() - t0) * 1000)
|
| 881 |
|
| 882 |
updated_state = {
|
| 883 |
**state,
|
| 884 |
"E": E.cpu(), "y": y.cpu(),
|
| 885 |
+
"prev_E_int": corrected_E_int.cpu(),
|
| 886 |
"step": new_step, "inner_step": new_inner_step,
|
| 887 |
}
|
| 888 |
yield {
|
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-step KG link-prediction log-likelihood for the denoising loop.
|
| 2 |
+
|
| 3 |
+
Wraps the math of `KGLikelihoodMetric.update` (see research/COINs-KGGeneration
|
| 4 |
+
.../metrics/abstract_metrics.py) in a one-shot, stateless helper. We query
|
| 5 |
+
the frozen KG embedder + link ranker on the edges currently present in the
|
| 6 |
+
argmax reconstruction and return their mean log-sigmoid score — a non-positive
|
| 7 |
+
higher-is-better value that rises as the graph becomes cleaner.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.nn.functional import logsigmoid, one_hot
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def kg_edge_log_likelihood(E_int, X, X_index, X_c, kg_experiment):
    """Mean log-sigmoid link-ranker score over edges currently present.

    E_int: (n, n) long tensor. 0 = no edge; otherwise class = relation_id + 1.
    X: (n, num_node_types) one-hot node types (unbatched, float).
    X_index: (n,) long dataset-global entity ids (unbatched).
    X_c: (n,) long community ids (unbatched).
    kg_experiment: COINs experiment exposing .embedder, .link_ranker,
        .loader.num_relations, .mini_batch_size, .device.

    Self-loops (diagonal entries of E_int) are excluded from the score.

    Returns a Python float (mean log-sigmoid score per edge, so never above
    0) or None if no edges are present, the resulting mean is not finite, or
    the scoring pass fails for any reason — including the research package
    not being importable. Failures are logged as warnings and never raised,
    so a broken metric cannot interrupt the caller.
    """
    try:
        # Structural check first: an empty graph needs neither the research
        # imports nor any attribute of the experiment.
        nz = E_int.nonzero(as_tuple=False)
        nz = nz[nz[:, 0] != nz[:, 1]]  # drop self-loops
        if nz.numel() == 0:
            return None

        # Imported lazily and inside the try so an unavailable research
        # package degrades to "no metric" instead of raising.
        from graph_completion.graphs.preprocess import QueryData
        from graph_completion.graphs.queries import Query

        embedder = kg_experiment.embedder
        link_ranker = kg_experiment.link_ranker.link_ranker
        num_relations = kg_experiment.loader.num_relations
        kg_device = kg_experiment.device
        mini_batch_size = kg_experiment.mini_batch_size

        s, t = nz[:, 0], nz[:, 1]
        r = E_int[s, t] - 1  # back from edge class to relation id

        c_s, c_t = X_c[s].long(), X_c[t].long()

        # Make edges of the same community pair contiguous: sort by source
        # community, then stable-sort by target community. The resulting
        # order is by c_t with ties broken by c_s.
        # NOTE(review): presumably the embedder relies on this grouping when
        # batching — confirm against KGLikelihoodMetric.update.
        s_sort = torch.argsort(c_s)
        t_sort = torch.sort(c_t[s_sort], stable=True).indices
        order = s_sort[t_sort]  # single permutation == v[s_sort][t_sort]

        e = [X_index[s].long()[order], X_index[t].long()[order]]
        x = [X[s].float()[order], X[t].float()[order]]
        c = [c_s[order], c_t[order]]
        edge_attr = [one_hot(r[order], num_relations + 1).float()]

        q = Query("1p")
        q.build_query_tree()
        query_data = QueryData(q, e=e, x=x, c=c, edge_attr=edge_attr).to(kg_device)

        scores = []
        with torch.no_grad():
            for qd_batch in query_data.batch_split(mini_batch_size):
                q_emb, a_emb = embedder(qd_batch)
                scores.append(link_ranker(q_emb, a_emb))
        scores = torch.cat(scores, dim=0).view(-1)
        if scores.numel() == 0:
            return None

        mean_ll = logsigmoid(scores).mean()
        if not torch.isfinite(mean_ll):
            # NaN/inf would end up in the progress event payload, where the
            # field is documented as a nullable number; report "no value".
            logger.warning("[kg-likelihood] skipped: non-finite score")
            return None
        return float(mean_ll.item())
    except Exception as exc:
        logger.warning("[kg-likelihood] skipped: %s", exc)
        return None
|
|
@@ -706,28 +706,73 @@ class ModelRegistry:
|
|
| 706 |
except Exception:
|
| 707 |
logger.exception("Failed to generate sample subgraphs for %s", dataset_id)
|
| 708 |
|
| 709 |
-
def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=
|
| 710 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 711 |
inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 712 |
node_types = loader.dataset.node_data.type.values
|
| 713 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
# Use the Sampler's DFS partitioning to get context subgraphs
|
| 715 |
samples = loader.sampler.get_context_subgraph_samples_dfs(
|
| 716 |
max_graph_size, loader.graph_indexes, loader.num_nodes,
|
| 717 |
-
max_samples=num_subgraphs * 5, disable_tqdm=True,
|
| 718 |
)
|
| 719 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
subgraphs = []
|
| 721 |
for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
|
| 722 |
if len(subgraphs) >= num_subgraphs:
|
| 723 |
break
|
|
|
|
|
|
|
| 724 |
if len(edges) < 3:
|
| 725 |
continue
|
| 726 |
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
sg_nodes = nodes_row + nodes_col
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
|
| 732 |
node_idx = {n: i for i, n in enumerate(sg_nodes)}
|
| 733 |
|
|
@@ -736,7 +781,7 @@ class ModelRegistry:
|
|
| 736 |
type_id = int(node_types[n]) if n < len(node_types) else 0
|
| 737 |
nodes.append({
|
| 738 |
"entity_id": n,
|
| 739 |
-
"entity_name":
|
| 740 |
"type_id": type_id,
|
| 741 |
})
|
| 742 |
|
|
@@ -747,18 +792,23 @@ class ModelRegistry:
|
|
| 747 |
"source_idx": node_idx[h],
|
| 748 |
"target_idx": node_idx[t],
|
| 749 |
"relation_id": r,
|
| 750 |
-
"relation_name":
|
| 751 |
-
"entity_name_source":
|
| 752 |
-
"entity_name_target":
|
| 753 |
})
|
| 754 |
|
| 755 |
subgraphs.append({
|
| 756 |
"id": f"sample_{len(subgraphs) + 1}",
|
| 757 |
"num_nodes": len(nodes),
|
| 758 |
"num_edges": len(edge_list),
|
|
|
|
|
|
|
| 759 |
"nodes": nodes,
|
| 760 |
"edges": edge_list,
|
| 761 |
})
|
|
|
|
|
|
|
|
|
|
| 762 |
|
| 763 |
# Free the partitioning data stored on the sampler
|
| 764 |
loader.sampler.context_subgraphs_nodes = None
|
|
|
|
| 706 |
except Exception:
|
| 707 |
logger.exception("Failed to generate sample subgraphs for %s", dataset_id)
|
| 708 |
|
| 709 |
+
def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=40,
|
| 710 |
+
max_graph_size=10, seed=None):
|
| 711 |
+
"""Build sample subgraphs using the Sampler's DFS-based context subgraph partitioning.
|
| 712 |
+
|
| 713 |
+
When ``seed`` is provided, the DFS iterates node indices in a shuffled order, so
|
| 714 |
+
different seeds produce different partitions. Without a seed the order is
|
| 715 |
+
deterministic (original research-code behaviour).
|
| 716 |
+
"""
|
| 717 |
inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 718 |
node_types = loader.dataset.node_data.type.values
|
| 719 |
|
| 720 |
+
def entity_label(idx):
|
| 721 |
+
raw = inv_nodes.get(idx)
|
| 722 |
+
if raw is None or raw != raw or str(raw).strip() == "": # None / NaN / empty
|
| 723 |
+
return f"#{idx}"
|
| 724 |
+
cleaned = clean_entity_name(str(raw), dataset_id)
|
| 725 |
+
return cleaned if cleaned else f"#{idx}"
|
| 726 |
+
|
| 727 |
+
def relation_label(idx):
|
| 728 |
+
raw = inv_relations.get(idx)
|
| 729 |
+
if raw is None or raw != raw or str(raw).strip() == "":
|
| 730 |
+
return f"rel#{idx}"
|
| 731 |
+
cleaned = clean_relation_name(str(raw), dataset_id)
|
| 732 |
+
return cleaned if cleaned else f"rel#{idx}"
|
| 733 |
+
|
| 734 |
# Use the Sampler's DFS partitioning to get context subgraphs
|
| 735 |
samples = loader.sampler.get_context_subgraph_samples_dfs(
|
| 736 |
max_graph_size, loader.graph_indexes, loader.num_nodes,
|
| 737 |
+
max_samples=num_subgraphs * 5, seed=seed, disable_tqdm=True,
|
| 738 |
)
|
| 739 |
|
| 740 |
+
# Randomly pick (row, col) partition pairs so each sample is structurally
|
| 741 |
+
# distinct from the others. Without shuffling, the DFS returns samples in
|
| 742 |
+
# nested (k, l) order, which means the first N samples all reuse
|
| 743 |
+
# partition 0's nodes. Shuffling + a disjoint-partitions guard gives 5
|
| 744 |
+
# different subgraphs each call.
|
| 745 |
+
import random as _random
|
| 746 |
+
rng = _random.Random(seed)
|
| 747 |
+
samples = list(samples)
|
| 748 |
+
rng.shuffle(samples)
|
| 749 |
+
|
| 750 |
+
used_partitions = set()
|
| 751 |
subgraphs = []
|
| 752 |
for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
|
| 753 |
if len(subgraphs) >= num_subgraphs:
|
| 754 |
break
|
| 755 |
+
if subgraph_row in used_partitions or subgraph_col in used_partitions:
|
| 756 |
+
continue
|
| 757 |
if len(edges) < 3:
|
| 758 |
continue
|
| 759 |
|
| 760 |
+
is_bip = (subgraph_row != subgraph_col)
|
| 761 |
+
if is_bip:
|
| 762 |
+
# Inpaint mask math assumes balanced halves (n/4, n/2, 3n/4 split).
|
| 763 |
+
# Only accept bipartite samples where row/col are the same size and
|
| 764 |
+
# divisible by 4, so the four quadrants are well-defined.
|
| 765 |
+
if len(nodes_row) != len(nodes_col) or len(nodes_row) < 2:
|
| 766 |
+
continue
|
| 767 |
+
if (2 * len(nodes_row)) % 4 != 0:
|
| 768 |
+
continue
|
| 769 |
sg_nodes = nodes_row + nodes_col
|
| 770 |
+
row_size = len(nodes_row)
|
| 771 |
+
else:
|
| 772 |
+
if len(nodes_row) < 4 or len(nodes_row) % 2 != 0:
|
| 773 |
+
continue
|
| 774 |
+
sg_nodes = nodes_row
|
| 775 |
+
row_size = len(nodes_row)
|
| 776 |
|
| 777 |
node_idx = {n: i for i, n in enumerate(sg_nodes)}
|
| 778 |
|
|
|
|
| 781 |
type_id = int(node_types[n]) if n < len(node_types) else 0
|
| 782 |
nodes.append({
|
| 783 |
"entity_id": n,
|
| 784 |
+
"entity_name": entity_label(n),
|
| 785 |
"type_id": type_id,
|
| 786 |
})
|
| 787 |
|
|
|
|
| 792 |
"source_idx": node_idx[h],
|
| 793 |
"target_idx": node_idx[t],
|
| 794 |
"relation_id": r,
|
| 795 |
+
"relation_name": relation_label(r),
|
| 796 |
+
"entity_name_source": entity_label(h),
|
| 797 |
+
"entity_name_target": entity_label(t),
|
| 798 |
})
|
| 799 |
|
| 800 |
subgraphs.append({
|
| 801 |
"id": f"sample_{len(subgraphs) + 1}",
|
| 802 |
"num_nodes": len(nodes),
|
| 803 |
"num_edges": len(edge_list),
|
| 804 |
+
"is_bip": is_bip,
|
| 805 |
+
"row_size": row_size,
|
| 806 |
"nodes": nodes,
|
| 807 |
"edges": edge_list,
|
| 808 |
})
|
| 809 |
+
used_partitions.add(subgraph_row)
|
| 810 |
+
if is_bip:
|
| 811 |
+
used_partitions.add(subgraph_col)
|
| 812 |
|
| 813 |
# Free the partitioning data stored on the sampler
|
| 814 |
loader.sampler.context_subgraphs_nodes = None
|
|
@@ -1,12 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from rest_framework.response import Response
|
| 2 |
from rest_framework.views import APIView
|
| 3 |
|
| 4 |
from api.exceptions import InvalidRequestError, ModelUnavailable, NotFoundError
|
|
|
|
| 5 |
from api.services.constants import KG_ANOMALY_DATASET_META
|
| 6 |
-
from api.services.kg_anomaly_inference import
|
|
|
|
|
|
|
| 7 |
from api.services.registry import ModelRegistry
|
| 8 |
from api.views.graph_generation import _streaming_sse_response
|
| 9 |
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class KgAnomalyDatasetsView(APIView):
|
| 12 |
def get(self, request):
|
|
@@ -29,14 +39,29 @@ class KgAnomalySampleSubgraphsView(APIView):
|
|
| 29 |
raise NotFoundError(f"Dataset '{dataset_id}' not found")
|
| 30 |
|
| 31 |
registry = ModelRegistry.get()
|
| 32 |
-
|
| 33 |
-
if
|
| 34 |
raise NotFoundError(f"No sample subgraphs available for dataset '{dataset_id}'")
|
| 35 |
|
| 36 |
count = int(request.query_params.get("count", 5))
|
| 37 |
count = max(1, min(10, count))
|
| 38 |
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
noise_level_raw = request.query_params.get("noise_level")
|
| 42 |
if noise_level_raw is not None:
|
|
@@ -56,16 +81,39 @@ class KgAnomalySampleSubgraphsView(APIView):
|
|
| 56 |
raise ModelUnavailable(
|
| 57 |
f"No '{task}' checkpoint available for dataset '{dataset_id}'")
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
|
| 65 |
for i, sg in enumerate(subgraphs):
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
return Response({
|
| 71 |
"dataset_id": dataset_id,
|
|
@@ -99,6 +147,8 @@ def _validate_subgraph(subgraph):
|
|
| 99 |
|
| 100 |
|
| 101 |
class KgAnomalyCorrectView(APIView):
|
|
|
|
|
|
|
| 102 |
def post(self, request):
|
| 103 |
data = request.data
|
| 104 |
registry = ModelRegistry.get()
|
|
@@ -166,6 +216,8 @@ class KgAnomalyCorrectView(APIView):
|
|
| 166 |
|
| 167 |
|
| 168 |
class KgAnomalyContinueView(APIView):
|
|
|
|
|
|
|
| 169 |
def post(self, request):
|
| 170 |
state_b64 = request.data.get("state")
|
| 171 |
if not state_b64 or not isinstance(state_b64, str):
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import random
|
| 3 |
+
import traceback
|
| 4 |
+
|
| 5 |
+
from rest_framework.renderers import JSONRenderer
|
| 6 |
from rest_framework.response import Response
|
| 7 |
from rest_framework.views import APIView
|
| 8 |
|
| 9 |
from api.exceptions import InvalidRequestError, ModelUnavailable, NotFoundError
|
| 10 |
+
from api.renderers import EventStreamRenderer
|
| 11 |
from api.services.constants import KG_ANOMALY_DATASET_META
|
| 12 |
+
from api.services.kg_anomaly_inference import (
|
| 13 |
+
apply_edge_noise, build_kg_tensors, render_sample_subgraph_b64,
|
| 14 |
+
)
|
| 15 |
from api.services.registry import ModelRegistry
|
| 16 |
from api.views.graph_generation import _streaming_sse_response
|
| 17 |
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
|
| 21 |
class KgAnomalyDatasetsView(APIView):
|
| 22 |
def get(self, request):
|
|
|
|
| 39 |
raise NotFoundError(f"Dataset '{dataset_id}' not found")
|
| 40 |
|
| 41 |
registry = ModelRegistry.get()
|
| 42 |
+
loader = registry.loaders.get(dataset_id)
|
| 43 |
+
if loader is None:
|
| 44 |
raise NotFoundError(f"No sample subgraphs available for dataset '{dataset_id}'")
|
| 45 |
|
| 46 |
count = int(request.query_params.get("count", 5))
|
| 47 |
count = max(1, min(10, count))
|
| 48 |
|
| 49 |
+
seed_raw = request.query_params.get("seed")
|
| 50 |
+
# Random per-request seed when the caller doesn't pin one, so each call
|
| 51 |
+
# produces a different DFS node-order shuffle and therefore different partitions.
|
| 52 |
+
seed = int(seed_raw) if seed_raw is not None else random.randrange(2**31)
|
| 53 |
+
|
| 54 |
+
# Fresh DFS partitioning per request — see Sampler.get_context_subgraph_samples_dfs
|
| 55 |
+
# in research/COINs-KGGeneration. The registry shuffles sample pairs and
|
| 56 |
+
# enforces disjoint (row, col) partitions per sample, so the returned
|
| 57 |
+
# subgraphs are all structurally distinct.
|
| 58 |
+
logger.info("[sample-subgraphs] building fresh pool for %s (seed=%d)", dataset_id, seed)
|
| 59 |
+
pool = registry._build_sample_subgraphs(
|
| 60 |
+
dataset_id, loader, num_subgraphs=count, seed=seed,
|
| 61 |
+
)
|
| 62 |
+
subgraphs = [dict(sg) for sg in pool[:count]]
|
| 63 |
+
for i, sg in enumerate(subgraphs):
|
| 64 |
+
sg["id"] = f"sample_{i + 1}"
|
| 65 |
|
| 66 |
noise_level_raw = request.query_params.get("noise_level")
|
| 67 |
if noise_level_raw is not None:
|
|
|
|
| 81 |
raise ModelUnavailable(
|
| 82 |
f"No '{task}' checkpoint available for dataset '{dataset_id}'")
|
| 83 |
|
| 84 |
+
logger.info("[sample-subgraphs] loading kg-anomaly model: %s/%s", dataset_id, task)
|
| 85 |
+
try:
|
| 86 |
+
model = registry._load_kg_anomaly_model(dataset_id, task)
|
| 87 |
+
except Exception:
|
| 88 |
+
logger.error("[sample-subgraphs] model load failed:\n%s", traceback.format_exc())
|
| 89 |
+
raise
|
| 90 |
+
logger.info("[sample-subgraphs] model ready, noising %d subgraphs", len(subgraphs))
|
| 91 |
|
| 92 |
for i, sg in enumerate(subgraphs):
|
| 93 |
+
try:
|
| 94 |
+
tensors = build_kg_tensors(sg, loader, model)
|
| 95 |
+
sg["edges"] = apply_edge_noise(
|
| 96 |
+
model, tensors, task, noise_level, seed + i,
|
| 97 |
+
loader=loader, dataset_id=dataset_id, nodes=sg["nodes"])
|
| 98 |
+
sg["num_edges"] = len(sg["edges"])
|
| 99 |
+
logger.info("[sample-subgraphs] noised subgraph %d/%d", i + 1, len(subgraphs))
|
| 100 |
+
except Exception:
|
| 101 |
+
logger.error(
|
| 102 |
+
"[sample-subgraphs] noise failed on subgraph %d:\n%s",
|
| 103 |
+
i, traceback.format_exc(),
|
| 104 |
+
)
|
| 105 |
+
raise
|
| 106 |
+
|
| 107 |
+
for i, sg in enumerate(subgraphs):
|
| 108 |
+
try:
|
| 109 |
+
sg["image"] = render_sample_subgraph_b64(sg, loader, dataset_id)
|
| 110 |
+
logger.info("[sample-subgraphs] rendered subgraph %d/%d", i + 1, len(subgraphs))
|
| 111 |
+
except Exception:
|
| 112 |
+
logger.error(
|
| 113 |
+
"[sample-subgraphs] render failed on subgraph %d:\n%s",
|
| 114 |
+
i, traceback.format_exc(),
|
| 115 |
+
)
|
| 116 |
+
raise
|
| 117 |
|
| 118 |
return Response({
|
| 119 |
"dataset_id": dataset_id,
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
class KgAnomalyCorrectView(APIView):
|
| 150 |
+
renderer_classes = [EventStreamRenderer, JSONRenderer]
|
| 151 |
+
|
| 152 |
def post(self, request):
|
| 153 |
data = request.data
|
| 154 |
registry = ModelRegistry.get()
|
|
|
|
| 216 |
|
| 217 |
|
| 218 |
class KgAnomalyContinueView(APIView):
|
| 219 |
+
renderer_classes = [EventStreamRenderer, JSONRenderer]
|
| 220 |
+
|
| 221 |
def post(self, request):
|
| 222 |
state_b64 = request.data.get("state")
|
| 223 |
if not state_b64 or not isinstance(state_b64, str):
|
|
@@ -742,6 +742,7 @@ class Sampler:
|
|
| 742 |
def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
|
| 743 |
num_nodes: int, allow_disc: bool = False,
|
| 744 |
max_samples: int = 0,
|
|
|
|
| 745 |
disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
|
| 746 |
_, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
|
| 747 |
assignment = -np.ones(num_nodes, dtype=int)
|
|
@@ -753,7 +754,12 @@ class Sampler:
|
|
| 753 |
progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
|
| 754 |
disable=disable_tqdm)
|
| 755 |
|
| 756 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 757 |
if max_subgraphs > 0 and subgraph >= max_subgraphs:
|
| 758 |
break
|
| 759 |
if assignment[i] >= 0:
|
|
|
|
| 742 |
def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
|
| 743 |
num_nodes: int, allow_disc: bool = False,
|
| 744 |
max_samples: int = 0,
|
| 745 |
+
seed: Optional[int] = None,
|
| 746 |
disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
|
| 747 |
_, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
|
| 748 |
assignment = -np.ones(num_nodes, dtype=int)
|
|
|
|
| 754 |
progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
|
| 755 |
disable=disable_tqdm)
|
| 756 |
|
| 757 |
+
# When seed is given, iterate in a shuffled node order so different seeds produce
|
| 758 |
+
# different partitions. Default order reproduces the deterministic behaviour.
|
| 759 |
+
node_order = (np.random.default_rng(seed).permutation(num_nodes)
|
| 760 |
+
if seed is not None else range(num_nodes))
|
| 761 |
+
|
| 762 |
+
for i in node_order:
|
| 763 |
if max_subgraphs > 0 and subgraph >= max_subgraphs:
|
| 764 |
break
|
| 765 |
if assignment[i] >= 0:
|