fix(coins): score full hit community instead of only valid answers
Step 2 used to filter candidate entities to the set returned by
get_all_answers, so the link ranker only ever scored entities already
known to satisfy the query. Top-K therefore showed at most one row when
a community had a single valid answer, and the reported speedup hid the
fact that the model was not actually discriminating against negatives.
Now Step 2 mirrors rank_samples (experiments.py:782-786): the hit
community is the first one in step-1 order containing any valid answer,
and we score every entity in it minus the anchors. Predictions carry an
is_valid_answer flag so the UI can mark genuine hits with a check icon.
Scoring is mini-batched at 512 to bound CPU memory on large communities.
The OpenAPI CoinsPrediction schema gains the required is_valid_answer
boolean to keep clients in sync.
|
@@ -1175,7 +1175,7 @@ components:
|
|
| 1175 |
|
| 1176 |
CoinsPrediction:
|
| 1177 |
type: object
|
| 1178 |
-
required: [rank, intra_community_rank, entity_id, entity_name, score]
|
| 1179 |
properties:
|
| 1180 |
rank:
|
| 1181 |
type: integer
|
|
@@ -1200,6 +1200,14 @@ components:
|
|
| 1200 |
type: number
|
| 1201 |
format: float
|
| 1202 |
example: 0.923
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1203 |
|
| 1204 |
CoinsTiming:
|
| 1205 |
type: object
|
|
|
|
| 1175 |
|
| 1176 |
CoinsPrediction:
|
| 1177 |
type: object
|
| 1178 |
+
required: [rank, intra_community_rank, entity_id, entity_name, score, is_valid_answer]
|
| 1179 |
properties:
|
| 1180 |
rank:
|
| 1181 |
type: integer
|
|
|
|
| 1200 |
type: number
|
| 1201 |
format: float
|
| 1202 |
example: 0.923
|
| 1203 |
+
is_valid_answer:
|
| 1204 |
+
type: boolean
|
| 1205 |
+
description: |
|
| 1206 |
+
True if this entity actually satisfies the query in the KG (i.e. is a member of
|
| 1207 |
+
`get_all_answers`). Step 2 scores every entity in the hit community, so any
|
| 1208 |
+
prediction may be a non-answer the model surfaced — the flag lets the frontend
|
| 1209 |
+
mark the genuine hits.
|
| 1210 |
+
example: true
|
| 1211 |
|
| 1212 |
CoinsTiming:
|
| 1213 |
type: object
|
|
@@ -17,6 +17,11 @@ from api.utils import clean_entity_name, clean_relation_name
|
|
| 17 |
from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
|
| 18 |
from graph_completion.graphs.preprocess import QueryData
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
|
| 22 |
anchors, variables, relations_map, top_k):
|
|
@@ -104,8 +109,11 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
|
|
| 104 |
community_order = community_scores.argsort(descending=True) # [K]
|
| 105 |
step1_ms = (time.perf_counter() - t1_start) * 1000.0
|
| 106 |
|
| 107 |
-
# ---- Step 2:
|
| 108 |
-
# Mirrors rank_samples:
|
|
|
|
|
|
|
|
|
|
| 109 |
t2_start = time.perf_counter()
|
| 110 |
|
| 111 |
valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
|
|
@@ -114,68 +122,70 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
|
|
| 114 |
"No entities in the knowledge graph satisfy this query"
|
| 115 |
)
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
| 118 |
rank_c = 0 # 1-based rank of the hit community (0 = no hit)
|
| 119 |
c_err = 0 # sum of sizes of communities with better step-1 score than the hit
|
| 120 |
community_size = 0 # size of the hit community
|
|
|
|
| 121 |
|
| 122 |
for rank_0indexed in range(num_communities):
|
| 123 |
cid = int(community_order[rank_0indexed].item())
|
| 124 |
c_entities_tensor = (community_membership == cid).nonzero(as_tuple=True)[0]
|
| 125 |
c_size = int(c_entities_tensor.shape[0])
|
| 126 |
-
|
| 127 |
c_entities = [int(e.item()) for e in c_entities_tensor]
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
for candidate in scored_candidates:
|
| 131 |
-
# Build qi_m directly: anchor + resolved-variable positions stay fixed;
|
| 132 |
-
# answer and phantom intersection nodes receive the candidate entity.
|
| 133 |
-
qi_m = qi_skeleton.copy()
|
| 134 |
-
qi_m.vs["e"] = [
|
| 135 |
-
entities_skeleton[i] if entities_skeleton[i] != -1 else candidate
|
| 136 |
-
for i in range(num_tree_nodes)
|
| 137 |
-
]
|
| 138 |
-
valid_qi_mappeds.append(qi_m)
|
| 139 |
-
valid_candidates.append(candidate)
|
| 140 |
-
|
| 141 |
-
if valid_qi_mappeds:
|
| 142 |
rank_c = rank_0indexed + 1
|
| 143 |
community_size = c_size
|
|
|
|
| 144 |
break
|
| 145 |
|
| 146 |
c_err += c_size
|
| 147 |
|
| 148 |
-
if not
|
| 149 |
raise InvalidRequestError(
|
| 150 |
"No entities in the knowledge graph satisfy this query"
|
| 151 |
)
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
e_batch, x_batch, c_batch, edge_attr_batch = [], [], [], []
|
| 156 |
for i in range(num_tree_nodes):
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
| 158 |
e_batch.append(entities_i)
|
| 159 |
x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
|
| 160 |
c_batch.append(community_membership[entities_i])
|
| 161 |
for j in range(num_tree_edges):
|
| 162 |
-
r_label =
|
| 163 |
if "p" in r_label:
|
| 164 |
r_id = int(r_label[1:])
|
| 165 |
edge_attr_batch.append(
|
| 166 |
-
one_hot(pt.full([
|
| 167 |
)
|
| 168 |
else:
|
| 169 |
edge_attr_batch.append(
|
| 170 |
-
one_hot(pt.full([
|
| 171 |
)
|
| 172 |
|
| 173 |
with pt.no_grad():
|
| 174 |
batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
k = min(top_k,
|
| 179 |
top_scores, top_indices = scores.topk(k)
|
| 180 |
step2_ms = (time.perf_counter() - t2_start) * 1000.0
|
| 181 |
|
|
@@ -183,7 +193,7 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
|
|
| 183 |
inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
|
| 184 |
predictions = []
|
| 185 |
for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
|
| 186 |
-
entity_id =
|
| 187 |
raw_name = str(inv_nodes.get(entity_id, entity_id))
|
| 188 |
predictions.append({
|
| 189 |
"intra_community_rank": intra_community_rank,
|
|
@@ -191,6 +201,7 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
|
|
| 191 |
"entity_id": entity_id,
|
| 192 |
"entity_name": clean_entity_name(raw_name, dataset_id),
|
| 193 |
"score": round(float(score), 4),
|
|
|
|
| 194 |
})
|
| 195 |
|
| 196 |
total_ms = step1_ms + step2_ms
|
|
|
|
| 17 |
from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
|
| 18 |
from graph_completion.graphs.preprocess import QueryData
|
| 19 |
|
| 20 |
+
# Step-2 mini-batch cap: the link-ranker runs a full forward pass per candidate,
|
| 21 |
+
# and the hit community can hold thousands of nodes on Freebase. Splitting keeps
|
| 22 |
+
# CPU memory bounded without changing results.
|
| 23 |
+
SCORING_MINI_BATCH_SIZE = 512
|
| 24 |
+
|
| 25 |
|
| 26 |
def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
|
| 27 |
anchors, variables, relations_map, top_k):
|
|
|
|
| 109 |
community_order = community_scores.argsort(descending=True) # [K]
|
| 110 |
step1_ms = (time.perf_counter() - t1_start) * 1000.0
|
| 111 |
|
| 112 |
+
# ---- Step 2: Score every entity in the hit community ----
|
| 113 |
+
# Mirrors rank_samples (experiments.py:782-786): the hit community is the
|
| 114 |
+
# first one in step-1 order that contains *any* KG-valid answer; within it
|
| 115 |
+
# we score all entities (minus anchors) so the link ranker actually has to
|
| 116 |
+
# discriminate, instead of being handed only the known answers.
|
| 117 |
t2_start = time.perf_counter()
|
| 118 |
|
| 119 |
valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
|
|
|
|
| 122 |
"No entities in the knowledge graph satisfy this query"
|
| 123 |
)
|
| 124 |
|
| 125 |
+
anchor_entity_ids = {
|
| 126 |
+
entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != -1
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
rank_c = 0 # 1-based rank of the hit community (0 = no hit)
|
| 130 |
c_err = 0 # sum of sizes of communities with better step-1 score than the hit
|
| 131 |
community_size = 0 # size of the hit community
|
| 132 |
+
candidates = []
|
| 133 |
|
| 134 |
for rank_0indexed in range(num_communities):
|
| 135 |
cid = int(community_order[rank_0indexed].item())
|
| 136 |
c_entities_tensor = (community_membership == cid).nonzero(as_tuple=True)[0]
|
| 137 |
c_size = int(c_entities_tensor.shape[0])
|
|
|
|
| 138 |
c_entities = [int(e.item()) for e in c_entities_tensor]
|
| 139 |
+
|
| 140 |
+
if any(e in valid_answers for e in c_entities):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
rank_c = rank_0indexed + 1
|
| 142 |
community_size = c_size
|
| 143 |
+
candidates = [e for e in c_entities if e not in anchor_entity_ids]
|
| 144 |
break
|
| 145 |
|
| 146 |
c_err += c_size
|
| 147 |
|
| 148 |
+
if not candidates:
|
| 149 |
raise InvalidRequestError(
|
| 150 |
"No entities in the knowledge graph satisfy this query"
|
| 151 |
)
|
| 152 |
|
| 153 |
+
n_candidates = len(candidates)
|
| 154 |
+
candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)
|
| 155 |
+
|
| 156 |
+
# Build per-tree-node entity columns. Anchor / resolved-variable positions
|
| 157 |
+
# repeat the same id across the batch; the answer and phantom-i positions
|
| 158 |
+
# take each candidate.
|
| 159 |
e_batch, x_batch, c_batch, edge_attr_batch = [], [], [], []
|
| 160 |
for i in range(num_tree_nodes):
|
| 161 |
+
if entities_skeleton[i] == -1:
|
| 162 |
+
entities_i = candidates_tensor
|
| 163 |
+
else:
|
| 164 |
+
entities_i = pt.full([n_candidates], entities_skeleton[i], dtype=pt.long, device=device)
|
| 165 |
e_batch.append(entities_i)
|
| 166 |
x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
|
| 167 |
c_batch.append(community_membership[entities_i])
|
| 168 |
for j in range(num_tree_edges):
|
| 169 |
+
r_label = qi_skeleton.es[j]["r"]
|
| 170 |
if "p" in r_label:
|
| 171 |
r_id = int(r_label[1:])
|
| 172 |
edge_attr_batch.append(
|
| 173 |
+
one_hot(pt.full([n_candidates], r_id, dtype=pt.long, device=device), num_relations + 1).float()
|
| 174 |
)
|
| 175 |
else:
|
| 176 |
edge_attr_batch.append(
|
| 177 |
+
one_hot(pt.full([n_candidates], num_relations, dtype=pt.long, device=device), num_relations + 1).float()
|
| 178 |
)
|
| 179 |
|
| 180 |
with pt.no_grad():
|
| 181 |
batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
|
| 182 |
+
score_chunks = []
|
| 183 |
+
for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
|
| 184 |
+
q_emb, a_emb = embedder(chunk)
|
| 185 |
+
score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
|
| 186 |
+
scores = pt.cat(score_chunks)
|
| 187 |
|
| 188 |
+
k = min(top_k, n_candidates)
|
| 189 |
top_scores, top_indices = scores.topk(k)
|
| 190 |
step2_ms = (time.perf_counter() - t2_start) * 1000.0
|
| 191 |
|
|
|
|
| 193 |
inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
|
| 194 |
predictions = []
|
| 195 |
for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
|
| 196 |
+
entity_id = candidates[idx]
|
| 197 |
raw_name = str(inv_nodes.get(entity_id, entity_id))
|
| 198 |
predictions.append({
|
| 199 |
"intra_community_rank": intra_community_rank,
|
|
|
|
| 201 |
"entity_id": entity_id,
|
| 202 |
"entity_name": clean_entity_name(raw_name, dataset_id),
|
| 203 |
"score": round(float(score), 4),
|
| 204 |
+
"is_valid_answer": entity_id in valid_answers,
|
| 205 |
})
|
| 206 |
|
| 207 |
total_ms = step1_ms + step2_ms
|
|
@@ -19,7 +19,14 @@ function barWidth(score) {
|
|
| 19 |
#{{ p.rank }}
|
| 20 |
</div>
|
| 21 |
<div class="pred-body">
|
| 22 |
-
<div class="pred-name" :title="`entity id ${p.entity_id}`">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
<div class="pred-bar">
|
| 24 |
<div class="pred-bar-fill" :style="{ width: barWidth(p.score) }"></div>
|
| 25 |
</div>
|
|
@@ -55,6 +62,7 @@ function barWidth(score) {
|
|
| 55 |
}
|
| 56 |
.pred-body { min-width: 0; }
|
| 57 |
.pred-name { font-weight: 600; word-break: break-word; }
|
|
|
|
| 58 |
.pred-bar {
|
| 59 |
width: 100%;
|
| 60 |
height: 6px;
|
|
|
|
| 19 |
#{{ p.rank }}
|
| 20 |
</div>
|
| 21 |
<div class="pred-body">
|
| 22 |
+
<div class="pred-name" :title="`entity id ${p.entity_id}`">
|
| 23 |
+
{{ p.entity_name }}
|
| 24 |
+
<i
|
| 25 |
+
v-if="p.is_valid_answer"
|
| 26 |
+
class="check circle icon valid-flag"
|
| 27 |
+
title="Known KG answer — this entity actually satisfies the query"
|
| 28 |
+
></i>
|
| 29 |
+
</div>
|
| 30 |
<div class="pred-bar">
|
| 31 |
<div class="pred-bar-fill" :style="{ width: barWidth(p.score) }"></div>
|
| 32 |
</div>
|
|
|
|
| 62 |
}
|
| 63 |
.pred-body { min-width: 0; }
|
| 64 |
.pred-name { font-weight: 600; word-break: break-word; }
|
| 65 |
+
.pred-name .valid-flag { color: var(--primary-strong); margin-left: 0.35rem; font-size: 0.95em; }
|
| 66 |
.pred-bar {
|
| 67 |
width: 100%;
|
| 68 |
height: 6px;
|