| import logging |
| import os |
| import random |
| import threading |
| from pathlib import Path |
|
|
| import yaml |
| from django.conf import settings |
|
|
| from api.services.constants import COINS_CONFIG_SUFFIX, COINS_DATASET_META |
| from api.utils import clean_entity_name, clean_relation_name |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def _safe_load_lightning_checkpoint(cls, ckpt_path):
    """Load a PyTorch Lightning checkpoint without triggering DDP or deepcopy crashes.

    Bypasses ``load_from_checkpoint`` entirely: we torch.load the checkpoint
    to CPU, extract the hyper-parameters, construct the model, load the
    state_dict, and then move to the target device. This avoids:
    - DDP __setstate__/__getstate__ needing a process group
    - save_hyperparameters deepcopy crashing on pickled DDP/datamodule objects
    - CUDA OOM from loading the full checkpoint (with heavy hparams) onto GPU

    :param cls: the LightningModule subclass to instantiate.
    :param ckpt_path: filesystem path to the ``.ckpt`` file.
    :return: the constructed model, in eval mode, on ``settings.TORCH_DEVICE``.
    """
    import torch
    import torch.nn.parallel.distributed as _ddp_mod

    # Temporarily neuter DDP (un)pickling so torch.load can materialize a
    # checkpoint containing pickled DDP wrappers without a process group.
    _orig_set = _ddp_mod.DistributedDataParallel.__setstate__
    _orig_get = _ddp_mod.DistributedDataParallel.__getstate__
    _ddp_mod.DistributedDataParallel.__setstate__ = lambda self, state: self.__dict__.update(state)
    _ddp_mod.DistributedDataParallel.__getstate__ = lambda self: self.__dict__
    try:
        # weights_only=False: the hyper-parameters section contains arbitrary
        # pickled objects, not just tensors.
        ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    finally:
        _ddp_mod.DistributedDataParallel.__setstate__ = _orig_set
        _ddp_mod.DistributedDataParallel.__getstate__ = _orig_get

    hparams = ckpt.get("hyper_parameters", {})
    # Null out training-only helper objects so the constructor does not need them.
    hparams["train_metrics"] = None
    hparams["sampling_metrics"] = None
    hparams["visualization_tools"] = None

    # save_hyperparameters deepcopies its arguments, which crashes on the
    # pickled DDP/datamodule objects mentioned above — disable it during
    # construction only.
    _orig_save = cls.save_hyperparameters
    cls.save_hyperparameters = lambda self, *a, **kw: None
    try:
        model = cls(**hparams)
    finally:
        cls.save_hyperparameters = _orig_save

    # strict=False tolerates keys that are missing from or extra in the checkpoint.
    model.load_state_dict(ckpt["state_dict"], strict=False)
    del ckpt
    model.to(settings.TORCH_DEVICE)
    model.eval()
    return model
|
|
|
|
| |
| |
| |
| |
# Hugging Face Hub repo mirroring the on-disk checkpoint layout; overridable
# through the HF_CHECKPOINTS_REPO environment variable.
HF_CHECKPOINTS_REPO = os.environ.get("HF_CHECKPOINTS_REPO", "Bani57/checkpoints")


# Checkpoint subdirectories (relative to settings.CHECKPOINTS_ROOT) that must
# each contain at least one weight file for the HF Hub download to be skipped.
_CHECKPOINT_SUBDIRS = (
    Path("COINs-KGGeneration") / "graph_completion" / "checkpoints",
    Path("COINs-KGGeneration") / "graph_generation" / "checkpoints",
    Path("MultiProxAn") / "checkpoints",
)


# Query-sampler hyper-parameters shared by every dataset's loader config.
_SAMPLER_HPARS = {
    "query_structure": ["1p"],
    "num_negative_samples": 128,
    "num_neighbours": 10,
    "random_walk_length": 10,
    "context_radius": 2,
    "pagerank_importances": True,
    "walks_relation_specific": True,
}


# Per-dataset base config: Leiden community-detection resolution plus the
# Loader hyper-parameters. (A None leiden_resolution would be derived as
# 1 / num_nodes at load time — see _load_all_loaders.)
_DATASET_BASE = {
    "freebase": {
        "leiden_resolution": 5.0e-3,
        "loader_hpars": {
            "dataset_name": "freebase", "simulated": False,
            "sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
        },
    },
    "wordnet": {
        "leiden_resolution": 0.0,
        "loader_hpars": {
            "dataset_name": "wordnet", "simulated": False,
            "sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
        },
    },
    "nell": {
        "leiden_resolution": 2.0e-5,
        "loader_hpars": {
            "dataset_name": "nell", "simulated": False,
            "sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
        },
    },
}


# RNG seed recorded for each (dataset, algorithm) checkpoint; used by
# get_checkpoint_config when rebuilding the Loader for that checkpoint.
_CHECKPOINT_SEEDS = {
    ("freebase", "transe"): 4089853924, ("freebase", "distmult"): 4089853924,
    ("freebase", "complex"): 4089853924, ("freebase", "rotate"): 4089853924,
    ("freebase", "q2b"): 1503136574, ("freebase", "kbgat"): 123456789,
    ("wordnet", "transe"): 1919180054, ("wordnet", "distmult"): 1919180054,
    ("wordnet", "complex"): 1919180054, ("wordnet", "rotate"): 1919180054,
    ("wordnet", "q2b"): 3312854056, ("wordnet", "kbgat"): 123456789,
    ("nell", "transe"): 3192206669, ("nell", "distmult"): 3192206669,
    ("nell", "rotate"): 3192206669,
    ("nell", "complex"): 2409194445, ("nell", "q2b"): 3793326028, ("nell", "kbgat"): 123456789,
}


# Fallback per-dataset seeds for (dataset, algorithm) pairs that have no entry
# in _CHECKPOINT_SEEDS; also used for the shared metadata Loaders.
VANILLA_SEEDS = {
    "freebase": 4089853924,
    "wordnet": 1919180054,
    "nell": 3192206669,
}
|
|
|
|
def get_checkpoint_config(dataset_id, algorithm):
    """Assemble the config for one (dataset, algorithm) checkpoint.

    The algorithm-specific training seed is used when one is recorded in
    ``_CHECKPOINT_SEEDS``; otherwise the dataset's vanilla seed applies.
    """
    base = _DATASET_BASE[dataset_id]
    per_algo_seed = _CHECKPOINT_SEEDS.get((dataset_id, algorithm))
    if per_algo_seed is None:
        per_algo_seed = VANILLA_SEEDS[dataset_id]
    config = {"seed": per_algo_seed}
    config.update(base)
    return config
|
|
|
|
| def _adapt_shape_mismatches(ckpt_state_dict, model_state_dict): |
| """Fix ±1 shape mismatches between checkpoint and current model state_dicts. |
| |
| These arise from minor dataset version differences (e.g. NELL gained/lost one node type |
| or one relation between the training run and the current data load) or from update_state_dict |
| adding an extra padding column that wasn't there at training time. Safe to trim because the |
| removed slice corresponded to an absent type/relation that was never looked up at inference. |
| """ |
| import torch as pt |
| adapted = dict(ckpt_state_dict) |
| for key, model_w in model_state_dict.items(): |
| if key not in adapted: |
| continue |
| ckpt_w = adapted[key] |
| if ckpt_w.shape == model_w.shape: |
| continue |
| v = ckpt_w |
| for dim in range(min(len(model_w.shape), len(v.shape))): |
| if v.shape[dim] == model_w.shape[dim] + 1: |
| slc = [slice(None)] * len(v.shape) |
| slc[dim] = slice(None, model_w.shape[dim]) |
| v = v[tuple(slc)] |
| logger.debug("Trimmed %s dim %d: %s -> %s", key, dim, ckpt_w.shape, v.shape) |
| elif model_w.shape[dim] == v.shape[dim] + 1: |
| pad_shape = list(v.shape) |
| pad_shape[dim] = 1 |
| v = pt.cat([v, pt.zeros(pad_shape, dtype=v.dtype)], dim=dim) |
| logger.debug("Padded %s dim %d: %s -> %s", key, dim, ckpt_w.shape, v.shape) |
| adapted[key] = v |
| return adapted |
|
|
|
|
| def _adapt_mlp_bn_keys(state_dict): |
| """Rename MLP BatchNorm keys from torch_geometric 2.0.x to 2.3.x format. |
| |
| In torch_geometric 2.0.x, MLP stored BN directly as ``norms.N.weight``. |
| In 2.3.x the BN is wrapped in a ModuleList proxy, producing ``norms.N.module.weight``. |
| This affects all MLP BN parameters: weight, bias, running_mean, running_var, |
| num_batches_tracked. Renaming restores the trained BN statistics (running_mean / |
| running_var can differ substantially from the (0, 1) defaults, which is why loading |
| without this fix produced near-zero Q2B scores). |
| """ |
| import re |
| adapted = {} |
| _bn_pattern = re.compile(r"(\.norms\.\d+)\.(weight|bias|running_mean|running_var|num_batches_tracked)$") |
| for key, value in state_dict.items(): |
| new_key = _bn_pattern.sub(r"\1.module.\2", key) |
| adapted[new_key] = value |
| return adapted |
|
|
|
|
| def _adapt_kbgat_state_dict(ckpt_state_dict, model_state_dict): |
| """Adapt a KBGAT embedder checkpoint from torch_geometric 2.0.x to 2.3.x format. |
| |
| In torch_geometric 2.0.x, the final GATConv layer's out_channels was interpreted as |
| the *total* output width (divided by num_heads internally). In 2.3.x it is per-head, |
| so the weight matrices in the last conv are num_heads-times wider. |
| |
| **Accuracy impact of the repeat strategy**: Both new attention heads receive the full |
| original weight matrix, making them identical. Multi-head diversity is lost — the |
| model degenerates to single-head attention with doubled hidden width. Predictions |
| remain directionally correct (the trained linear transformation is preserved) but |
| may not match the published benchmark numbers exactly. |
| """ |
| import torch as pt |
| adapted = {} |
| for key, value in ckpt_state_dict.items(): |
| if key not in model_state_dict: |
| adapted[key] = value |
| continue |
| expected_shape = model_state_dict[key].shape |
| if value.shape == expected_shape: |
| adapted[key] = value |
| continue |
| |
| |
| v = value |
| for dim in range(len(expected_shape)): |
| if dim < len(v.shape) and expected_shape[dim] == 2 * v.shape[dim]: |
| v = pt.cat([v, v], dim=dim) |
| adapted[key] = v |
| return adapted |
|
|
|
|
| def _free_heavy_arrays(loader): |
| """Free memory-intensive arrays from a Loader that aren't needed for discovery endpoints.""" |
| loader.node_neighbours = None |
| loader.com_neighbours = None |
| loader.node_adjacency = None |
| loader.com_adjacency = None |
| loader.label_community_edge_freqs = None |
| loader.label_community_edge_freqs_index = None |
| loader.machines = None |
| loader.graph = None |
| loader.node_importances = None |
| loader.neighbour_importances = None |
| loader.out_degrees = None |
| loader.in_degrees = None |
| loader.degrees = None |
| loader.node_degree_type_freqs = None |
| loader.relation_freqs = None |
|
|
|
|
class SubgraphInfo:
    """Holds pre-computed sample subgraphs for a KG anomaly dataset."""

    # Single-attribute container; __slots__ avoids a per-instance __dict__.
    __slots__ = ("subgraphs",)

    def __init__(self, subgraphs):
        # List of subgraph dicts as produced by ModelRegistry._build_sample_subgraphs.
        self.subgraphs = subgraphs
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Maps a frontend query-structure key to the 4-tuple consumed by sample_query():
#   (research-code structure name,
#    {query-tree vertex index -> frontend anchor slot id},
#    {query-tree vertex index -> frontend intermediate-variable slot id},
#    {query-tree edge index -> frontend relation slot id})
_STRUCTURE_INFO = {
    "1p": ("1p", {0: "a"}, {}, {0: "r1"}),
    "2p": ("2p", {0: "a"}, {1: "v1"}, {0: "r1", 1: "r2"}),
    "3p": ("3p", {0: "a"}, {1: "v1", 2: "v2"}, {0: "r1", 1: "r2", 2: "r3"}),
    "2i": ("2i", {0: "a1", 2: "a2"}, {}, {0: "r1", 2: "r2"}),
    "3i": ("3i", {0: "a1", 2: "a2", 4: "a3"}, {}, {0: "r1", 2: "r2", 4: "r3"}),
    "ip": ("2i1p", {0: "a1", 2: "a2"}, {4: "v1"}, {0: "r1", 2: "r2", 4: "r3"}),
    "pi": ("1p2i", {0: "a1", 3: "a2"}, {1: "v1"}, {0: "r1", 1: "r2", 3: "r3"}),
}
|
|
|
|
| class ModelRegistry: |
| _instance = None |
|
|
    def __init__(self):
        # Discovered on-disk checkpoints: dataset_id -> list of algorithm/variant names.
        self.coins_checkpoints_available = {}
        self.graphgen_checkpoints_available = {}
        self.kg_anomaly_checkpoints_available = {}
        # Lightweight per-dataset metadata Loaders used by discovery endpoints.
        self.loaders = {}
        # dataset_id -> SubgraphInfo with pre-computed sample subgraphs.
        self.kg_anomaly_subgraphs = {}
        # Lock guarding inference; the owner label is logged on force-release.
        self._inference_lock = threading.Lock()
        self._inference_lock_owner = None
        # Lazy caches: (dataset_id, algorithm) -> Experiment,
        # (dataset_id, seed, leiden_resolution) -> Loader,
        # (dataset_id, model_type) -> model.
        self._coins_experiments = {}
        self._coins_loaders = {}
        self._graphgen_models = {}
        self._kg_anomaly_models = {}
|
|
| def force_release_inference_lock(self): |
| """Emergency release for a stuck inference lock (e.g. client disconnect).""" |
| if self._inference_lock.locked(): |
| self._inference_lock.release() |
| owner = self._inference_lock_owner |
| self._inference_lock_owner = None |
| logger.warning("Inference lock force-released (was held by: %s)", owner) |
| return True |
| return False |
|
|
    @classmethod
    def get(cls):
        """Return the singleton instance.

        Raises:
            RuntimeError: if ``initialize()`` has not been called yet.
        """
        if cls._instance is None:
            raise RuntimeError("ModelRegistry not initialized. Call initialize() first.")
        return cls._instance
|
|
    @classmethod
    def initialize(cls):
        """Create and publish the singleton (idempotent; no-op if already set).

        Downloads checkpoints, scans what is available on disk, builds the
        per-dataset metadata Loaders, and pre-computes sample subgraphs. The
        instance is assigned to ``cls._instance`` only after all steps ran, so
        ``get()`` never observes a half-built registry.
        """
        if cls._instance is not None:
            return
        instance = cls()
        instance._download_checkpoints()
        instance._scan_checkpoints()
        instance._load_all_loaders()
        instance._generate_sample_subgraphs()
        cls._instance = instance
        logger.info(
            "ModelRegistry initialized: coins=%s, multiproxan=%s, kg_anomaly=%s, loaders=%s",
            instance.is_coins_loaded(),
            instance.is_graphgen_loaded(),
            instance.is_kg_anomaly_loaded(),
            list(instance.loaders.keys()),
        )
|
|
| |
|
|
| def _download_checkpoints(self): |
| """Download checkpoints from Hugging Face Hub if not already present. |
| |
| The HF repo mirrors the on-disk layout under ``CHECKPOINTS_ROOT``, so a |
| single ``snapshot_download`` drops every file into its final location. |
| Idempotent: when all expected subdirs are populated we skip the |
| network round-trip. In production the entrypoint script also pre-warms |
| this download before gunicorn starts, so workers never block on it. |
| """ |
| if self._all_checkpoint_dirs_populated(): |
| logger.info("All checkpoint directories already populated, skipping HF Hub download") |
| return |
|
|
| try: |
| from huggingface_hub import snapshot_download |
| except ImportError: |
| logger.warning("huggingface_hub not installed, skipping checkpoint download") |
| return |
|
|
| target = Path(settings.CHECKPOINTS_ROOT) |
| target.mkdir(parents=True, exist_ok=True) |
| logger.info("Downloading checkpoints from HF Hub repo %s -> %s", HF_CHECKPOINTS_REPO, target) |
|
|
| try: |
| snapshot_download( |
| repo_id=HF_CHECKPOINTS_REPO, |
| repo_type="model", |
| local_dir=str(target), |
| local_dir_use_symlinks=False, |
| max_workers=4, |
| token=os.environ.get("HF_TOKEN"), |
| ) |
| except Exception: |
| logger.exception("Failed to download checkpoints from HF Hub, continuing with local files") |
|
|
| def _all_checkpoint_dirs_populated(self): |
| """True if every expected checkpoint subdir contains at least one weight file.""" |
| root = Path(settings.CHECKPOINTS_ROOT) |
| for sub in _CHECKPOINT_SUBDIRS: |
| dest_dir = root / sub |
| if not dest_dir.exists(): |
| return False |
| ckpt_files = list(dest_dir.glob("*.tar")) + list(dest_dir.glob("*.ckpt")) |
| if not ckpt_files: |
| return False |
| return True |
|
|
| |
|
|
    def _scan_checkpoints(self):
        """Populate the *_checkpoints_available maps from the files found on disk."""
        self._scan_coins_checkpoints()
        self._scan_graphgen_checkpoints()
        self._scan_kg_anomaly_checkpoints()
|
|
| def _scan_coins_checkpoints(self): |
| ckpt_dir = Path(settings.COINS_COMPLETION_DIR) / "checkpoints" |
| if not ckpt_dir.exists(): |
| logger.warning("COINs checkpoint dir not found: %s", ckpt_dir) |
| return |
| for path in ckpt_dir.glob("*.tar"): |
| parts = path.stem.rsplit("_", 1) |
| if len(parts) == 2: |
| dataset_id, algorithm = parts |
| self.coins_checkpoints_available.setdefault(dataset_id, []).append(algorithm) |
| logger.info("COINs checkpoints: %s", self.coins_checkpoints_available) |
|
|
| def _scan_graphgen_checkpoints(self): |
| ckpt_dir = Path(settings.MULTIPROXAN_DIR) / "checkpoints" |
| if not ckpt_dir.exists(): |
| logger.warning("MultiProxAn checkpoint dir not found: %s", ckpt_dir) |
| return |
| for path in ckpt_dir.glob("*.ckpt"): |
| name = path.stem |
| if name.endswith("_c"): |
| dataset_id = name[:-2] |
| self.graphgen_checkpoints_available.setdefault(dataset_id, []).append("continuous") |
| else: |
| self.graphgen_checkpoints_available.setdefault(name, []).append("discrete") |
| logger.info("MultiProxAn checkpoints: %s", self.graphgen_checkpoints_available) |
|
|
| def _scan_kg_anomaly_checkpoints(self): |
| ckpt_dir = Path(settings.DIGRESS_KG_DIR) / "checkpoints" |
| if not ckpt_dir.exists(): |
| logger.warning("DiGress KG checkpoint dir not found: %s", ckpt_dir) |
| return |
| for path in ckpt_dir.glob("*.ckpt"): |
| name = path.stem |
| if name.endswith("_correct"): |
| dataset_id = name[:-8] |
| self.kg_anomaly_checkpoints_available.setdefault(dataset_id, []).append("correct") |
| else: |
| self.kg_anomaly_checkpoints_available.setdefault(name, []).append("generate") |
| logger.info("DiGress KG checkpoints: %s", self.kg_anomaly_checkpoints_available) |
|
|
| |
|
|
    def _load_all_loaders(self):
        """Initialize one lightweight Loader per dataset for discovery endpoints.

        Loads dataset, name maps, train/val/test split, and graph indexes.
        Heavy arrays (node_neighbours, com_neighbours, adjacency dicts) are freed
        after startup to save memory. Full Loaders for inference are loaded on demand.
        """
        # The research code resolves data paths relative to the COINs repo root,
        # so chdir there for the duration of the loads and restore afterwards.
        coins_root = str(Path(settings.COINS_DATA_DIR).parent)
        original_cwd = os.getcwd()
        try:
            os.chdir(coins_root)
            from graph_completion.graphs.load_graph import Loader, LoaderHpars

            for dataset_id in _DATASET_BASE:
                seed = VANILLA_SEEDS[dataset_id]
                # The "transe" checkpoint uses the vanilla seed for every
                # dataset, so its config doubles as the generic loader config.
                config = get_checkpoint_config(dataset_id, "transe")
                try:
                    logger.info("Initializing Loader for %s (seed=%d)...", dataset_id, seed)
                    loader = LoaderHpars.from_dict(config["loader_hpars"]).make()

                    # A None resolution means "1 / number of nodes": load the raw
                    # dataset just long enough to count nodes, then release it.
                    leiden_resolution = config["leiden_resolution"]
                    if leiden_resolution is None:
                        dataset_obj = Loader.datasets[dataset_id]
                        dataset_obj.load_from_disk()
                        leiden_resolution = 1.0 / len(dataset_obj.node_data)
                        dataset_obj.unload_from_memory()

                    loader.load_graph(
                        seed=seed, device="cpu", val_size=0.01, test_size=0.02,
                        community_method="leiden", leiden_resolution=leiden_resolution,
                    )

                    # Guard against a stale cached machines array whose length no
                    # longer matches the community count produced by this load.
                    import numpy as np
                    expected_len = loader.num_communities + 2
                    if len(loader.machines) != expected_len:
                        logger.warning(
                            "Stale machines.npz for %s: len=%d but num_communities+2=%d; rebuilding",
                            dataset_id, len(loader.machines), expected_len,
                        )
                        loader.machines = np.zeros(expected_len, dtype=int)
                    self.loaders[dataset_id] = loader

                    # Also register under the inference-cache key so experiments
                    # built with the same (seed, resolution) reuse this Loader.
                    self._coins_loaders[(dataset_id, seed, leiden_resolution)] = loader
                    logger.info(
                        "Loader ready for %s: %d entities, %d relations, %d train triples",
                        dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
                    )
                except Exception:
                    # One broken dataset must not prevent the others from loading.
                    logger.exception("Failed to initialize Loader for %s", dataset_id)
        finally:
            os.chdir(original_cwd)
|
|
| |
|
|
| def get_loader(self, dataset_id): |
| """Return the metadata Loader for a dataset, or None.""" |
| return self.loaders.get(dataset_id) |
|
|
| def get_entity_count(self, dataset_id): |
| loader = self.loaders.get(dataset_id) |
| return loader.num_nodes if loader else 0 |
|
|
| def get_relation_count(self, dataset_id): |
| loader = self.loaders.get(dataset_id) |
| return loader.num_relations if loader else 0 |
|
|
| def get_inverted_name_maps(self, dataset_id): |
| """Return (inv_node_names, inv_node_types, inv_relation_names) Series for a dataset.""" |
| loader = self.loaders.get(dataset_id) |
| if loader is None: |
| return None, None, None |
| return loader.dataset.get_inverted_name_maps() |
|
|
| def search_entities(self, dataset_id, query=None, page=1, page_size=50): |
| """Search entities by substring, return paginated (id, name) list and total.""" |
| loader = self.loaders.get(dataset_id) |
| if loader is None: |
| return [], 0 |
| inv_nodes, _, _ = loader.dataset.get_inverted_name_maps() |
| items = [(int(idx), str(name)) for idx, name in inv_nodes.items()] |
| if query: |
| q = query.lower() |
| items = [ |
| (eid, name) for eid, name in items |
| if q in name.lower() or q in clean_entity_name(name, dataset_id).lower() |
| ] |
| total = len(items) |
| start = (max(1, page) - 1) * page_size |
| return items[start:start + page_size], total |
|
|
| def search_relations(self, dataset_id, query=None, page=1, page_size=50): |
| """Search relations by substring, return paginated (id, name) list and total.""" |
| loader = self.loaders.get(dataset_id) |
| if loader is None: |
| return [], 0 |
| _, _, inv_relations = loader.dataset.get_inverted_name_maps() |
| items = [(int(idx), str(name)) for idx, name in inv_relations.items()] |
| if query: |
| q = query.lower() |
| items = [ |
| (rid, name) for rid, name in items |
| if q in name.lower() or q in clean_relation_name(name, dataset_id).lower() |
| ] |
| total = len(items) |
| start = (max(1, page) - 1) * page_size |
| return items[start:start + page_size], total |
|
|
| def sample_triples(self, dataset_id, count=10, seed=None): |
| """Return random triples with resolved entity/relation names. |
| |
| When ``seed`` is provided, sampling is deterministic — the same |
| ``(dataset_id, count, seed)`` always yields the same triples. When |
| ``seed`` is None, uses the global RNG. |
| """ |
| loader = self.loaders.get(dataset_id) |
| if loader is None: |
| return [] |
| inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps() |
| edge_data = loader.train_edge_data |
| count = min(count, len(edge_data)) |
|
|
| rng = random.Random(seed) if seed is not None else random |
| indices = rng.sample(range(len(edge_data)), count) |
|
|
| result = [] |
| for i in indices: |
| row = edge_data.iloc[i] |
| h, r, t = int(row.s), int(row.r), int(row.t) |
| h_name = str(inv_nodes.get(h, h)) |
| r_name = str(inv_relations.get(r, r)) |
| t_name = str(inv_nodes.get(t, t)) |
| result.append({ |
| "head": {"id": h, "name": h_name, "label": clean_entity_name(h_name, dataset_id)}, |
| "relation": {"id": r, "name": r_name, "label": clean_relation_name(r_name, dataset_id)}, |
| "tail": {"id": t, "name": t_name, "label": clean_entity_name(t_name, dataset_id)}, |
| }) |
| return result |
|
|
| |
|
|
    def sample_query(self, dataset_id, query_structure, count=1, seed=None):
        """Sample structurally valid queries using Query.instantiate() from the research code.

        Picks random answer entities, walks the training graph backward via
        adj_t_to_s to produce fully-instantiated query trees, then extracts
        anchor entities and relation IDs mapped to frontend slot names.

        :param dataset_id: key into ``self.loaders``.
        :param query_structure: one of the keys of ``_STRUCTURE_INFO`` (e.g. "2p").
        :param count: number of queries to return.
        :param seed: optional seed for deterministic sampling.
        :return: list of dicts with "anchors", "relations", "target" and, when the
            structure has intermediate nodes, "variables"; empty list for an
            unknown dataset or structure.
        """
        import numpy as np

        loader = self.loaders.get(dataset_id)
        if loader is None:
            return []

        info = _STRUCTURE_INFO.get(query_structure)
        if info is None:
            return []
        structure_str, anchor_map, variable_map, relation_map = info

        # graph_indexes[2] holds the tail -> relation -> heads adjacency used to
        # walk backward from a chosen answer entity.
        adj_t_to_s = loader.graph_indexes[2]
        inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()

        # The research module must be imported with the COINs repo root as cwd.
        coins_root = str(Path(settings.COINS_DATA_DIR).parent)
        original_cwd = os.getcwd()
        try:
            os.chdir(coins_root)
            from graph_completion.graphs.queries import Query, query_edge_r_to_int
        finally:
            os.chdir(original_cwd)

        query = Query(structure_str)
        query.build_query_tree()

        answer_candidates = list(adj_t_to_s.keys())
        rng = random.Random(seed) if seed is not None else random

        # Query.instantiate draws from numpy's GLOBAL RNG; save, seed, and later
        # restore it so seeded calls are deterministic without disturbing other
        # callers of np.random.
        np_state = np.random.get_state()
        if seed is not None:
            np.random.seed(hash(seed) % (2**32))

        def ent(eid):
            # Resolve an entity id to {id, raw name, cleaned display label}.
            name = str(inv_nodes.get(eid, eid))
            return {"id": eid, "name": name, "label": clean_entity_name(name, dataset_id)}

        def rel(rid):
            # Resolve a relation id to {id, raw name, cleaned display label}.
            name = str(inv_relations.get(rid, rid))
            return {"id": rid, "name": name, "label": clean_relation_name(name, dataset_id)}

        results = []
        # Not every answer entity admits a valid instantiation; cap the retries.
        max_attempts = count * 200
        try:
            for _ in range(max_attempts):
                if len(results) >= count:
                    break
                answer = rng.choice(answer_candidates)
                qi = next(
                    query.instantiate(adj_t_to_s, loader.num_nodes, loader.num_relations, answer, sample=True),
                    None,
                )
                if qi is None:
                    continue

                # Align the instantiated graph with the canonical query tree so
                # vertex/edge indices line up with the _STRUCTURE_INFO maps.
                qi_mapped = qi.map_to_tree(query.query_tree)

                anchors = {}
                for tree_idx, frontend_id in anchor_map.items():
                    anchors[frontend_id] = ent(int(qi_mapped.vs[tree_idx]["e"]))

                variables = {}
                for tree_idx, frontend_id in variable_map.items():
                    variables[frontend_id] = ent(int(qi_mapped.vs[tree_idx]["e"]))

                relations = {}
                for edge_idx, frontend_id in relation_map.items():
                    rel_id = query_edge_r_to_int(qi_mapped.es[edge_idx]["r"])
                    relations[frontend_id] = rel(rel_id)

                target_id = int(qi_mapped.vs[query.query_answer]["e"])
                q = {"anchors": anchors, "relations": relations, "target": ent(target_id)}
                if variables:
                    q["variables"] = variables
                results.append(q)
        finally:
            # Always restore the global numpy RNG state.
            np.random.set_state(np_state)

        return results
|
|
| |
|
|
| def _generate_sample_subgraphs(self): |
| """Generate sample subgraphs for KG anomaly using the Loader's context subgraph DFS.""" |
| for dataset_id in COINS_DATASET_META: |
| loader = self.loaders.get(dataset_id) |
| if loader is None: |
| continue |
| try: |
| subgraphs = self._build_sample_subgraphs(dataset_id, loader) |
| self.kg_anomaly_subgraphs[dataset_id] = SubgraphInfo(subgraphs) |
| logger.info("Generated %d sample subgraphs for %s", len(subgraphs), dataset_id) |
| except Exception: |
| logger.exception("Failed to generate sample subgraphs for %s", dataset_id) |
|
|
    def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=40,
                                max_graph_size=10, seed=None):
        """Build sample subgraphs using the Sampler's DFS-based context subgraph partitioning.

        When ``seed`` is provided, the DFS iterates node indices in a shuffled order, so
        different seeds produce different partitions. Without a seed the order is
        deterministic (original research-code behaviour).
        """
        inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
        node_types = loader.dataset.node_data.type.values

        def entity_label(idx):
            # Fall back to "#<id>" for missing/NaN/blank names
            # (``raw != raw`` is the classic NaN test).
            raw = inv_nodes.get(idx)
            if raw is None or raw != raw or str(raw).strip() == "":
                return f"#{idx}"
            cleaned = clean_entity_name(str(raw), dataset_id)
            return cleaned if cleaned else f"#{idx}"

        def relation_label(idx):
            # Same fallback scheme as entity_label, with a "rel#" prefix.
            raw = inv_relations.get(idx)
            if raw is None or raw != raw or str(raw).strip() == "":
                return f"rel#{idx}"
            cleaned = clean_relation_name(str(raw), dataset_id)
            return cleaned if cleaned else f"rel#{idx}"

        # Oversample (5x) so enough candidates survive the filters below.
        samples = loader.sampler.get_context_subgraph_samples_dfs(
            max_graph_size, loader.graph_indexes, loader.num_nodes,
            max_samples=num_subgraphs * 5, seed=seed, disable_tqdm=True,
        )

        # Shuffle the candidates so the final selection varies with the seed.
        import random as _random
        rng = _random.Random(seed)
        samples = list(samples)
        rng.shuffle(samples)

        used_partitions = set()
        subgraphs = []
        for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
            if len(subgraphs) >= num_subgraphs:
                break
            # Each partition contributes to at most one sample subgraph.
            if subgraph_row in used_partitions or subgraph_col in used_partitions:
                continue
            # Skip near-empty subgraphs.
            if len(edges) < 3:
                continue

            # Bipartite sample iff the row and column partitions differ.
            is_bip = (subgraph_row != subgraph_col)
            if is_bip:
                # NOTE(review): the size/parity constraints below presumably
                # match what the downstream anomaly model expects (equal halves,
                # total node count divisible by 4) — confirm against that model.
                if len(nodes_row) != len(nodes_col) or len(nodes_row) < 2:
                    continue
                if (2 * len(nodes_row)) % 4 != 0:
                    continue
                sg_nodes = nodes_row + nodes_col
                row_size = len(nodes_row)
            else:
                if len(nodes_row) < 4 or len(nodes_row) % 2 != 0:
                    continue
                sg_nodes = nodes_row
                row_size = len(nodes_row)

            # Map global node ids to local indices within this subgraph.
            node_idx = {n: i for i, n in enumerate(sg_nodes)}

            nodes = []
            for n in sg_nodes:
                # Guard against ids beyond the node-type array; default type 0.
                type_id = int(node_types[n]) if n < len(node_types) else 0
                nodes.append({
                    "entity_id": n,
                    "entity_name": entity_label(n),
                    "type_id": type_id,
                })

            edge_list = []
            for h, r, t in edges:
                # Keep only edges whose endpoints both lie inside the subgraph.
                if h in node_idx and t in node_idx:
                    edge_list.append({
                        "source_idx": node_idx[h],
                        "target_idx": node_idx[t],
                        "relation_id": r,
                        "relation_name": relation_label(r),
                        "entity_name_source": entity_label(h),
                        "entity_name_target": entity_label(t),
                    })

            subgraphs.append({
                "id": f"sample_{len(subgraphs) + 1}",
                "num_nodes": len(nodes),
                "num_edges": len(edge_list),
                "is_bip": is_bip,
                "row_size": row_size,
                "nodes": nodes,
                "edges": edge_list,
            })
            used_partitions.add(subgraph_row)
            if is_bip:
                used_partitions.add(subgraph_col)

        # Drop the sampler's cached partition arrays — only needed while sampling.
        loader.sampler.context_subgraphs_nodes = None
        loader.sampler.context_subgraphs_edges = None

        return subgraphs
|
|
| |
|
|
    def _load_coins_experiment(self, dataset_id, algorithm):
        """Lazily load and cache a fully-prepared Experiment for (dataset_id, algorithm).

        Builds the research-code Experiment from its YAML config, prepares it
        (reusing a cached Loader when one exists for the same seed/resolution),
        loads and adapts the checkpoint state_dicts, and finally indexes the
        full KG adjacency for the discovery endpoints. Several third-party
        functions are temporarily monkey-patched during preparation; each patch
        is restored in a ``finally`` block.
        """
        key = (dataset_id, algorithm)
        if key in self._coins_experiments:
            return self._coins_experiments[key]

        config = get_checkpoint_config(dataset_id, algorithm)
        seed = config["seed"]
        leiden_resolution = config["leiden_resolution"]

        coins_root = str(Path(settings.COINS_DATA_DIR).parent)
        configs_dir = Path(settings.COINS_COMPLETION_DIR) / "configs"
        suffix = COINS_CONFIG_SUFFIX[algorithm]
        config_path = configs_dir / f"{dataset_id}{suffix}.yml"

        with open(config_path, "r", encoding="utf-8") as f:
            yaml_config = yaml.safe_load(f)

        # A None resolution means "1 / number of nodes": load the raw dataset
        # just long enough to count nodes, then release it.
        if leiden_resolution is None:
            original_cwd = os.getcwd()
            try:
                os.chdir(coins_root)
                from graph_completion.graphs.load_graph import Loader
                dataset_obj = Loader.datasets[dataset_id]
                dataset_obj.load_from_disk()
                leiden_resolution = 1.0 / len(dataset_obj.node_data)
                dataset_obj.unload_from_memory()
            finally:
                os.chdir(original_cwd)

        # YAML config overridden with runtime settings; train/test disabled
        # because this Experiment is used for inference only.
        hpars = yaml_config | {
            "seed": seed,
            "leiden_resolution": leiden_resolution,
            "device": settings.TORCH_DEVICE,
            "train": False,
            "test": False,
            "results_dir": str(Path(settings.COINS_COMPLETION_DIR) / "results"),
        }

        # Embedding width the model is built with, per the YAML config; used by
        # the torch.load wrapper below to expand narrower pretrained inits.
        target_embedding_dim = int(yaml_config.get("embedder_hpars", {}).get("embedding_dim", 100))

        original_cwd = os.getcwd()
        try:
            os.chdir(coins_root)
            from graph_completion.experiments import ExperimentHpars, update_state_dict
            experiment = ExperimentHpars.from_dict(hpars).make()

            # Wrap load_graph so the machines array is always rebuilt to the
            # freshly computed community count (guards against stale caches).
            _orig_load = experiment.loader.load_graph
            def _patched_load(*args, **kwargs):
                _orig_load(*args, **kwargs)
                import numpy as np
                experiment.loader.machines = np.zeros(
                    experiment.loader.num_communities + 2, dtype=int
                )
            experiment.loader.load_graph = _patched_load

            # During prepare():
            # - share_memory is a no-op (no multiprocessing here),
            # - torch.load is wrapped so pretrained init embeddings whose width
            #   is an integer divisor of the YAML embedding_dim get tiled up to
            #   the expected width instead of failing to load.
            import torch as _pt
            import torch.nn as _nn
            _orig_share_memory = _nn.Module.share_memory
            _nn.Module.share_memory = lambda self: self
            _orig_torch_load = _pt.load

            def _expand_transe_load(*args, **kwargs):
                state_dict = _orig_torch_load(*args, **kwargs)
                if not isinstance(state_dict, dict):
                    return state_dict
                # Only TransE-style init dicts carry entity_embeddings.weight.
                if not any(k.endswith("entity_embeddings.weight") for k in state_dict):
                    return state_dict
                sample = next(v for k, v in state_dict.items()
                              if k.endswith("entity_embeddings.weight") and hasattr(v, "shape"))
                src_dim = int(sample.shape[-1])
                if src_dim == target_embedding_dim or src_dim == 0:
                    return state_dict
                if target_embedding_dim % src_dim != 0:
                    logger.warning(
                        "TransE init dim %d not a divisor of YAML embedding_dim %d; "
                        "leaving init unchanged (load_state_dict may fail).",
                        src_dim, target_embedding_dim,
                    )
                    return state_dict
                n_repeats = target_embedding_dim // src_dim
                expanded = {}
                for key, value in state_dict.items():
                    if not hasattr(value, "shape") or value.ndim < 1:
                        expanded[key] = value
                        continue
                    # Entity tables tile along the last (embedding) axis;
                    # r_embeddings tiles along the first axis instead.
                    if key.endswith(("entity_embeddings.weight",
                                     "entity_embeddings_initial.weight",
                                     "r_embeddings_initial.weight")) and value.shape[-1] == src_dim:
                        expanded[key] = value.repeat(*([1] * (value.ndim - 1)), n_repeats)
                    elif key.endswith("r_embeddings.weight") and value.shape[0] == src_dim:
                        expanded[key] = value.repeat(n_repeats, *([1] * (value.ndim - 1)))
                    else:
                        expanded[key] = value
                logger.info("Expanded transe init from %dd to %dd (x%d repeat) for %s/%s",
                            src_dim, target_embedding_dim, n_repeats, dataset_id, algorithm)
                return expanded

            _pt.load = _expand_transe_load
            try:
                loader_key = (dataset_id, seed, leiden_resolution)
                if loader_key in self._coins_loaders:
                    cached_loader = self._coins_loaders[loader_key]
                    # Same stale-machines guard as in _load_all_loaders.
                    import numpy as np
                    expected_len = cached_loader.num_communities + 2
                    if len(cached_loader.machines) != expected_len:
                        cached_loader.machines = np.zeros(expected_len, dtype=int)
                    experiment.loader = cached_loader
                    # The cached Loader already holds the graph: make load_graph
                    # a no-op while prepare() runs so it isn't reloaded.
                    _orig_load_graph = cached_loader.load_graph
                    cached_loader.load_graph = lambda *args, **kwargs: None
                    try:
                        experiment.prepare()
                    finally:
                        cached_loader.load_graph = _orig_load_graph
                    logger.info("Reused shared Loader for %s seed=%d", dataset_id, seed)
                else:
                    experiment.prepare()
                    self._coins_loaders[loader_key] = experiment.loader
                    logger.info("Cached new Loader for %s seed=%d", dataset_id, seed)
            finally:
                # Always restore the patched torch functions.
                _nn.Module.share_memory = _orig_share_memory
                _pt.load = _orig_torch_load

            ckpt_path = (Path(settings.COINS_COMPLETION_DIR) / "checkpoints"
                         / f"{dataset_id}_{algorithm}.tar")
            import torch as pt
            ckpt = pt.load(str(ckpt_path), map_location=settings.TORCH_DEVICE)

            # Adapt the checkpoint to the current library/dataset versions:
            # BN key renames, relation-count updates, ±1 shape mismatches, and
            # (for KBGAT) the per-head GATConv width change.
            embedder_sd = _adapt_mlp_bn_keys(ckpt["embedder_state_dict"])
            embedder_sd = update_state_dict(embedder_sd, experiment.loader.num_relations)

            embedder_sd = _adapt_shape_mismatches(embedder_sd, experiment.embedder.state_dict())
            if algorithm == "kbgat":
                embedder_sd = _adapt_kbgat_state_dict(embedder_sd, experiment.embedder.state_dict())
            experiment.embedder.load_state_dict(embedder_sd, strict=False)
            experiment.link_ranker.load_state_dict(ckpt["link_ranker_state_dict"], strict=False)
            experiment.embedder.to(settings.TORCH_DEVICE).eval()
            experiment.link_ranker.to(settings.TORCH_DEVICE).eval()

            # Propagate the YAML loss margin into the rankers for non-margin
            # algorithms (transe/rotate already embed it in their scoring).
            loss_margin = yaml_config.get("embedding_loss_hpars", {}).get("margin")
            if loss_margin is not None and algorithm not in ("transe", "rotate"):
                for ranker in (experiment.link_ranker.link_ranker,
                               experiment.link_ranker.community_link_ranker):
                    if hasattr(ranker, "margin"):
                        ranker.margin = float(loss_margin)
                        logger.debug("Patched %s link_ranker margin to %s", algorithm, loss_margin)
        finally:
            os.chdir(original_cwd)

        # Index the FULL KG (train+val+test) adjacency in both directions for
        # the discovery endpoints.
        full_edge_data = experiment.loader.dataset.edge_data
        full_adj_s_to_t, full_adj_t_to_s = {}, {}
        for s, r, t in full_edge_data[["s", "r", "t"]].values:
            s, r, t = int(s), int(r), int(t)
            full_adj_s_to_t.setdefault(s, {}).setdefault(r, []).append(t)
            full_adj_t_to_s.setdefault(t, {}).setdefault(r, []).append(s)
        experiment.full_adj_s_to_t = full_adj_s_to_t
        experiment.full_adj_t_to_s = full_adj_t_to_s
        logger.info("Built full-KG adjacency for %s/%s (%d edges)",
                    dataset_id, algorithm, len(full_edge_data))

        self._coins_experiments[key] = experiment
        logger.info("COINs experiment ready: dataset=%s algorithm=%s", dataset_id, algorithm)
        return experiment
|
|
| |
|
|
| def _load_graphgen_model(self, dataset_id, model_type): |
| key = (dataset_id, model_type) |
| if key in self._graphgen_models: |
| return self._graphgen_models[key] |
|
|
| |
| if model_type == "discrete": |
| from diffusion_model_discrete import DiscreteDenoisingDiffusion as cls |
| else: |
| from diffusion_model import LiftedDenoisingDiffusion as cls |
|
|
| suffix = "_c" if model_type == "continuous" else "" |
| ckpt_path = Path(settings.MULTIPROXAN_DIR) / "checkpoints" / f"{dataset_id}{suffix}.ckpt" |
| if not ckpt_path.exists(): |
| from api.exceptions import ModelUnavailable |
| raise ModelUnavailable(f"Checkpoint not found: {ckpt_path.name}") |
|
|
| logger.info("Loading MultiProxAn model: dataset=%s model_type=%s", dataset_id, model_type) |
| model = _safe_load_lightning_checkpoint(cls, ckpt_path) |
| self._graphgen_models[key] = model |
| logger.info("MultiProxAn model ready: dataset=%s model_type=%s", dataset_id, model_type) |
| return model |
|
|
| def graphgen_generate_stream(self, dataset_id, model_type, sampling_mode, num_nodes, |
| diffusion_steps, chain_frames, multiprox_params): |
| """Return a generator of NDJSON dicts (progress + result). |
| |
| Lock acquisition and model loading happen eagerly so errors surface |
| as normal DRF exceptions. The returned generator releases the lock |
| in its ``finally`` block. |
| """ |
| from api.exceptions import InferenceBusy |
| from api.services.graphgen_inference import ( |
| encode_state_blob, run_multiprox_init, run_standard_generation, |
| ) |
| if not self._inference_lock.acquire(blocking=False): |
| raise InferenceBusy() |
| self._inference_lock_owner = f"graphgen_generate {dataset_id}/{model_type}/{sampling_mode}" |
| try: |
| model = self._load_graphgen_model(dataset_id, model_type) |
| except Exception: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
| raise |
|
|
| def _gen(): |
| try: |
| if sampling_mode == "standard": |
| for event in run_standard_generation( |
| model, num_nodes, diffusion_steps, chain_frames, dataset_id): |
| if event["type"] == "result": |
| event.update({ |
| "dataset_id": dataset_id, |
| "model_type": model_type, |
| "sampling_mode": sampling_mode, |
| }) |
| yield event |
| else: |
| n = multiprox_params["n"] |
| m = multiprox_params["m"] |
| t = multiprox_params["t"] |
| t_prime = multiprox_params["t_prime"] |
| gibbs_chain_freq = multiprox_params["gibbs_chain_freq"] |
| for event in run_multiprox_init( |
| model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id): |
| if event["type"] == "result": |
| state = event.pop("state") |
| state["model_type"] = model_type |
| event.update({ |
| "step": 0, |
| "round_complete": False, |
| "done": False, |
| "state": encode_state_blob(state), |
| }) |
| yield event |
| finally: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
|
|
| return _gen() |
|
|
| def graphgen_continue_stream(self, state_b64): |
| """Return a generator of NDJSON dicts (progress + result). |
| |
| Blob decoding, lock acquisition, and model loading happen eagerly so |
| errors surface as normal DRF exceptions. The returned generator |
| releases the lock in its ``finally`` block. |
| """ |
| from api.exceptions import InferenceBusy, InvalidRequestError |
| from api.services.graphgen_inference import ( |
| decode_state_blob, encode_state_blob, run_multiprox_step, |
| ) |
| try: |
| state = decode_state_blob(state_b64) |
| except ValueError as exc: |
| raise InvalidRequestError(str(exc)) |
|
|
| if not self._inference_lock.acquire(blocking=False): |
| raise InferenceBusy() |
| self._inference_lock_owner = f"graphgen_continue {state['dataset_id']}/{state['model_type']}" |
| try: |
| model = self._load_graphgen_model(state["dataset_id"], state["model_type"]) |
| except Exception: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
| raise |
|
|
| def _gen(): |
| try: |
| for event in run_multiprox_step(model, state, state["dataset_id"]): |
| if event["type"] == "result": |
| updated_state = event.pop("state") |
| event.update({ |
| "step": updated_state["step"], |
| "state": encode_state_blob(updated_state), |
| }) |
| yield event |
| finally: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
|
|
| return _gen() |
|
|
| |
|
|
    def _load_kg_anomaly_model(self, dataset_id, task):
        """Load the DiGress KG checkpoint for (dataset_id, task), cached.

        The KG checkpoint pickles only ``cfg`` via ``save_hyperparameters('cfg')``,
        so we must reconstruct ``dataset_infos``, ``extra_features`` and
        ``domain_features`` before constructing the model. Dims are inferred from
        state_dict shapes; kg_experiment comes from the matching COINs experiment.

        Parameters:
            dataset_id: selects both the checkpoint file and the matching
                COINs experiment.
            task: "correct" selects the ``*_correct`` checkpoint variant; any
                other value selects the base checkpoint.

        Returns:
            The loaded model on ``settings.TORCH_DEVICE``, in eval mode.

        Raises:
            ModelUnavailable: if the checkpoint file does not exist.
            RuntimeError: if the checkpoint lacks ``cfg`` in hyper_parameters.
        """
        key = (dataset_id, task)
        if key in self._kg_anomaly_models:
            return self._kg_anomaly_models[key]

        import torch
        import torch.nn.parallel.distributed as _ddp_mod

        suffix = "_correct" if task == "correct" else ""
        ckpt_path = Path(settings.DIGRESS_KG_DIR) / "checkpoints" / f"{dataset_id}{suffix}.ckpt"
        if not ckpt_path.exists():
            from api.exceptions import ModelUnavailable
            raise ModelUnavailable(f"KG anomaly checkpoint not found: {ckpt_path.name}")

        logger.info("Loading KG anomaly model: dataset=%s task=%s", dataset_id, task)

        # Temporarily neutralize DDP (un)pickling so torch.load does not need
        # an initialized process group; originals restored even on failure.
        _orig_set = _ddp_mod.DistributedDataParallel.__setstate__
        _orig_get = _ddp_mod.DistributedDataParallel.__getstate__
        _ddp_mod.DistributedDataParallel.__setstate__ = lambda self, state: self.__dict__.update(state)
        _ddp_mod.DistributedDataParallel.__getstate__ = lambda self: self.__dict__
        try:
            # weights_only=False: the checkpoint carries pickled non-tensor
            # objects (cfg). Load to CPU; the model moves to device at the end.
            ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
        finally:
            _ddp_mod.DistributedDataParallel.__setstate__ = _orig_set
            _ddp_mod.DistributedDataParallel.__getstate__ = _orig_get

        hparams = ckpt.get("hyper_parameters", {})
        # hparams may be a plain dict or an hparams namespace-like object.
        cfg = hparams.get("cfg") if isinstance(hparams, dict) else getattr(hparams, "cfg", None)
        if cfg is None:
            raise RuntimeError(f"KG anomaly checkpoint {ckpt_path.name} is missing 'cfg' in hyper_parameters")
        state_dict = ckpt["state_dict"]

        # Stamp the requested task onto the config; best-effort, since cfg
        # may be read-only (presumably a struct-locked OmegaConf — TODO confirm).
        try:
            cfg.model.task = task
        except Exception:
            pass

        # Infer layer dims from checkpoint weight shapes instead of trusting cfg.
        edim_output = state_dict["model.mlp_out_E.2.weight"].shape[0]
        input_dim_x = state_dict["model.mlp_in_X.0.weight"].shape[1]
        input_dim_e = state_dict["model.mlp_in_E.0.weight"].shape[1]
        input_dim_y = state_dict["model.mlp_in_y.0.weight"].shape[1]

        # The KG model needs the matching COINs experiment; the node-type
        # count (X output dim) comes from its loader.
        experiment = self._load_coins_experiment(dataset_id, "transe")
        xdim_output = experiment.loader.num_node_types

        if input_dim_e != edim_output:
            logger.warning(
                "Unexpected mlp_in_E dim %d != edim_output %d for %s/%s",
                input_dim_e, edim_output, dataset_id, task,
            )

        from graph_generation.src.diffusion.distributions import DistributionNodes
        from graph_generation.src.diffusion.extra_features import (
            DummyExtraFeatures, ExtraFeatures,
        )

        # Max node count heuristic: dataset names are assumed to end in
        # "_<max_nodes>" (TODO confirm); doubled for headroom, default 20.
        try:
            base_max = int(cfg.dataset.name.split("_")[-1])
        except (AttributeError, ValueError):
            base_max = 20
        max_num_nodes = base_max * 2

        # Uniform node-count histogram over [2, max_num_nodes]; 0- and 1-node
        # graphs are zeroed out.
        n_hist = torch.ones(max_num_nodes + 1)
        n_hist[:2] = 0
        nodes_dist = DistributionNodes(n_hist)

        # Minimal stand-ins for the datamodule / dataset-infos objects the
        # model constructor expects.
        class _MockDataModule:
            def __init__(self, kg_experiment, max_num_nodes):
                self.kg_experiment = kg_experiment
                self.max_num_nodes = max_num_nodes

        class _MockDatasetInfos:
            pass

        dataset_infos = _MockDatasetInfos()
        dataset_infos.datamodule = _MockDataModule(experiment, max_num_nodes)
        dataset_infos.input_dims = {"X": input_dim_x, "E": input_dim_e, "y": input_dim_y}
        dataset_infos.output_dims = {"X": xdim_output, "E": edim_output, "y": 0}
        dataset_infos.nodes_dist = nodes_dist
        dataset_infos.max_n_nodes = max_num_nodes
        # Uniform type marginals; presumably only their lengths matter at
        # inference time — TODO confirm against the model's sampling code.
        dataset_infos.node_types = torch.ones(xdim_output, dtype=torch.float32)
        dataset_infos.edge_types = torch.ones(edim_output, dtype=torch.float32)

        # Extra structural features only apply to discrete models trained
        # with them; everything else gets dummies.
        extra_features_type = getattr(cfg.model, "extra_features", None)
        if cfg.model.type == "discrete" and extra_features_type is not None:
            extra_features = ExtraFeatures(extra_features_type, dataset_info=dataset_infos)
        else:
            extra_features = DummyExtraFeatures()
        domain_features = DummyExtraFeatures()

        from diffusion_model_discrete_kg import DiscreteDenoisingDiffusionKG as cls

        # Disable save_hyperparameters during construction: it would deepcopy
        # the mocks / experiment passed in (see _safe_load_lightning_checkpoint
        # for the same workaround).
        _orig_save = cls.save_hyperparameters
        cls.save_hyperparameters = lambda self, *a, **kw: None
        try:
            model = cls(cfg, dataset_infos, None, None, None, extra_features, domain_features)
        finally:
            cls.save_hyperparameters = _orig_save

        # strict=False tolerates key mismatches between the checkpoint and the
        # reconstructed model; differences are logged at debug level.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.debug("KG anomaly state_dict missing keys: %d (e.g. %s)",
                         len(missing), missing[:3])
        if unexpected:
            logger.debug("KG anomaly state_dict unexpected keys: %d (e.g. %s)",
                         len(unexpected), unexpected[:3])

        del ckpt  # drop the CPU copy before moving the model to the device
        model.to(settings.TORCH_DEVICE)
        model.eval()
        self._kg_anomaly_models[key] = model
        logger.info("KG anomaly model ready: dataset=%s task=%s", dataset_id, task)
        return model
|
|
| def kg_anomaly_correct_stream(self, dataset_id, task, sampling_mode, subgraph, |
| diffusion_steps, chain_frames, multiprox_params): |
| """Return a generator of SSE event dicts for /kg-anomaly/correct.""" |
| from api.exceptions import InferenceBusy |
| from api.services.kg_anomaly_inference import ( |
| build_kg_tensors, encode_state_blob, |
| run_multiprox_correction_init, run_standard_correction, |
| ) |
| if not self._inference_lock.acquire(blocking=False): |
| raise InferenceBusy() |
| self._inference_lock_owner = f"kg_anomaly_correct {dataset_id}/{task}/{sampling_mode}" |
| try: |
| model = self._load_kg_anomaly_model(dataset_id, task) |
| loader = self.loaders.get(dataset_id) |
| tensors = build_kg_tensors(subgraph, loader, model) |
| except Exception: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
| raise |
|
|
| def _gen(): |
| try: |
| if sampling_mode == "standard": |
| for event in run_standard_correction( |
| model, tensors, dataset_id, task, loader, |
| diffusion_steps, chain_frames): |
| if event["type"] == "result": |
| event.update({ |
| "dataset_id": dataset_id, |
| "task": task, |
| "sampling_mode": sampling_mode, |
| }) |
| yield event |
| else: |
| n = multiprox_params["n"] |
| m = multiprox_params["m"] |
| t = multiprox_params["t"] |
| t_prime = multiprox_params["t_prime"] |
| gibbs_chain_freq = multiprox_params["gibbs_chain_freq"] |
| for event in run_multiprox_correction_init( |
| model, tensors, dataset_id, task, loader, |
| n, m, t, t_prime, gibbs_chain_freq): |
| if event["type"] == "result": |
| state = event.pop("state") |
| event.update({ |
| "dataset_id": dataset_id, |
| "task": task, |
| "sampling_mode": sampling_mode, |
| "step": 0, |
| "round_complete": False, |
| "done": False, |
| "state": encode_state_blob(state), |
| }) |
| yield event |
| finally: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
|
|
| return _gen() |
|
|
| def kg_anomaly_continue_stream(self, state_b64): |
| """Return a generator of SSE event dicts for /kg-anomaly/continue.""" |
| from api.exceptions import InferenceBusy, InvalidRequestError |
| from api.services.kg_anomaly_inference import ( |
| decode_state_blob, encode_state_blob, run_multiprox_correction_step, |
| ) |
| try: |
| state = decode_state_blob(state_b64) |
| except ValueError as exc: |
| raise InvalidRequestError(str(exc)) |
|
|
| if not self._inference_lock.acquire(blocking=False): |
| raise InferenceBusy() |
| self._inference_lock_owner = ( |
| f"kg_anomaly_continue {state['dataset_id']}/{state['task']}" |
| ) |
| try: |
| model = self._load_kg_anomaly_model(state["dataset_id"], state["task"]) |
| loader = self.loaders.get(state["dataset_id"]) |
| except Exception: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
| raise |
|
|
| def _gen(): |
| try: |
| for event in run_multiprox_correction_step(model, state, loader): |
| if event["type"] == "result": |
| updated_state = event.pop("state") |
| event.update({ |
| "dataset_id": updated_state["dataset_id"], |
| "task": updated_state["task"], |
| "step": updated_state["step"], |
| "state": encode_state_blob(updated_state), |
| }) |
| yield event |
| finally: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
|
|
| return _gen() |
|
|
| |
|
|
| def coins_predict(self, dataset_id, algorithm, query_structure_id, |
| anchors, variables, relations_map, top_k): |
| """Run COINs link prediction / query answering for a single query.""" |
| from api.exceptions import InferenceBusy |
| from api.services.coins_inference import coins_predict_inner |
| if not self._inference_lock.acquire(blocking=False): |
| raise InferenceBusy() |
| self._inference_lock_owner = f"coins_predict {dataset_id}/{algorithm}" |
| try: |
| experiment = self._load_coins_experiment(dataset_id, algorithm) |
| return coins_predict_inner( |
| experiment, dataset_id, algorithm, query_structure_id, |
| anchors, variables, relations_map, top_k, |
| ) |
| finally: |
| self._inference_lock_owner = None |
| self._inference_lock.release() |
|
|
| |
|
|
| def is_coins_loaded(self): |
| return bool(self.coins_checkpoints_available) |
|
|
| def is_graphgen_loaded(self): |
| return bool(self.graphgen_checkpoints_available) |
|
|
| def is_kg_anomaly_loaded(self): |
| return bool(self.kg_anomaly_checkpoints_available) |
|
|