Andrej Janchevski committed on
Commit
2d03843
·
1 Parent(s): f74f48e

fix(coins): score full hit community instead of only valid answers

Browse files

Step 2 used to filter candidate entities to the set returned by
get_all_answers, so the link ranker only ever scored entities already
known to satisfy the query. Top-K therefore showed at most one row when
a community had a single valid answer, and the reported speedup hid the
fact that the model was not actually discriminating against negatives.

Now Step 2 mirrors rank_samples (experiments.py:782-786): the hit
community is the first one in step-1 order containing any valid answer,
and we score every entity in it minus the anchors. Predictions carry an
is_valid_answer flag so the UI can mark genuine hits with a check icon.
Scoring is mini-batched at 512 to bound CPU memory on large communities.

The OpenAPI CoinsPrediction schema gains the required is_valid_answer
boolean to keep clients in sync.

docs/api.yaml CHANGED
@@ -1175,7 +1175,7 @@ components:
1175
 
1176
  CoinsPrediction:
1177
  type: object
1178
- required: [rank, intra_community_rank, entity_id, entity_name, score]
1179
  properties:
1180
  rank:
1181
  type: integer
@@ -1200,6 +1200,14 @@ components:
1200
  type: number
1201
  format: float
1202
  example: 0.923
 
 
 
 
 
 
 
 
1203
 
1204
  CoinsTiming:
1205
  type: object
 
1175
 
1176
  CoinsPrediction:
1177
  type: object
1178
+ required: [rank, intra_community_rank, entity_id, entity_name, score, is_valid_answer]
1179
  properties:
1180
  rank:
1181
  type: integer
 
1200
  type: number
1201
  format: float
1202
  example: 0.923
1203
+ is_valid_answer:
1204
+ type: boolean
1205
+ description: |
1206
+ True if this entity actually satisfies the query in the KG (i.e. is a member of
1207
+ `get_all_answers`). Step 2 scores every entity in the hit community, so any
1208
+ prediction may be a non-answer the model surfaced — the flag lets the frontend
1209
+ mark the genuine hits.
1210
+ example: true
1211
 
1212
  CoinsTiming:
1213
  type: object
src/backend/api/services/coins_inference.py CHANGED
@@ -17,6 +17,11 @@ from api.utils import clean_entity_name, clean_relation_name
17
  from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
18
  from graph_completion.graphs.preprocess import QueryData
19
 
 
 
 
 
 
20
 
21
  def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
22
  anchors, variables, relations_map, top_k):
@@ -104,8 +109,11 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
104
  community_order = community_scores.argsort(descending=True) # [K]
105
  step1_ms = (time.perf_counter() - t1_start) * 1000.0
106
 
107
- # ---- Step 2: Search communities in descending step-1 score order ----
108
- # Mirrors rank_samples: get_all_answers pre-filters to KG-valid answers.
 
 
 
109
  t2_start = time.perf_counter()
110
 
111
  valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
@@ -114,68 +122,70 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
114
  "No entities in the knowledge graph satisfy this query"
115
  )
116
 
117
- valid_qi_mappeds, valid_candidates = [], []
 
 
 
118
  rank_c = 0 # 1-based rank of the hit community (0 = no hit)
119
  c_err = 0 # sum of sizes of communities with better step-1 score than the hit
120
  community_size = 0 # size of the hit community
 
121
 
122
  for rank_0indexed in range(num_communities):
123
  cid = int(community_order[rank_0indexed].item())
124
  c_entities_tensor = (community_membership == cid).nonzero(as_tuple=True)[0]
125
  c_size = int(c_entities_tensor.shape[0])
126
-
127
  c_entities = [int(e.item()) for e in c_entities_tensor]
128
- scored_candidates = [e for e in c_entities if e in valid_answers]
129
-
130
- for candidate in scored_candidates:
131
- # Build qi_m directly: anchor + resolved-variable positions stay fixed;
132
- # answer and phantom intersection nodes receive the candidate entity.
133
- qi_m = qi_skeleton.copy()
134
- qi_m.vs["e"] = [
135
- entities_skeleton[i] if entities_skeleton[i] != -1 else candidate
136
- for i in range(num_tree_nodes)
137
- ]
138
- valid_qi_mappeds.append(qi_m)
139
- valid_candidates.append(candidate)
140
-
141
- if valid_qi_mappeds:
142
  rank_c = rank_0indexed + 1
143
  community_size = c_size
 
144
  break
145
 
146
  c_err += c_size
147
 
148
- if not valid_qi_mappeds:
149
  raise InvalidRequestError(
150
  "No entities in the knowledge graph satisfy this query"
151
  )
152
 
153
- # Build batched QueryData and score in one forward pass
154
- n_valid = len(valid_qi_mappeds)
 
 
 
 
155
  e_batch, x_batch, c_batch, edge_attr_batch = [], [], [], []
156
  for i in range(num_tree_nodes):
157
- entities_i = pt.tensor([qm.vs[i]["e"] for qm in valid_qi_mappeds], dtype=pt.long, device=device)
 
 
 
158
  e_batch.append(entities_i)
159
  x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
160
  c_batch.append(community_membership[entities_i])
161
  for j in range(num_tree_edges):
162
- r_label = valid_qi_mappeds[0].es[j]["r"]
163
  if "p" in r_label:
164
  r_id = int(r_label[1:])
165
  edge_attr_batch.append(
166
- one_hot(pt.full([n_valid], r_id, dtype=pt.long, device=device), num_relations + 1).float()
167
  )
168
  else:
169
  edge_attr_batch.append(
170
- one_hot(pt.full([n_valid], num_relations, dtype=pt.long, device=device), num_relations + 1).float()
171
  )
172
 
173
  with pt.no_grad():
174
  batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
175
- q_emb, a_emb = embedder(batched_query)
176
- scores = link_ranker(q_emb, a_emb).view(-1) # ensure 1D even for batch_size=1
 
 
 
177
 
178
- k = min(top_k, n_valid)
179
  top_scores, top_indices = scores.topk(k)
180
  step2_ms = (time.perf_counter() - t2_start) * 1000.0
181
 
@@ -183,7 +193,7 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
183
  inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
184
  predictions = []
185
  for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
186
- entity_id = valid_candidates[idx]
187
  raw_name = str(inv_nodes.get(entity_id, entity_id))
188
  predictions.append({
189
  "intra_community_rank": intra_community_rank,
@@ -191,6 +201,7 @@ def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
191
  "entity_id": entity_id,
192
  "entity_name": clean_entity_name(raw_name, dataset_id),
193
  "score": round(float(score), 4),
 
194
  })
195
 
196
  total_ms = step1_ms + step2_ms
 
17
  from graph_completion.graphs.queries import Query, get_all_answers, get_node_cut_cache, query_edge_r_to_int
18
  from graph_completion.graphs.preprocess import QueryData
19
 
20
+ # Step-2 mini-batch cap: the link-ranker runs a full forward pass per candidate,
21
+ # and the hit community can hold thousands of nodes on Freebase. Splitting keeps
22
+ # CPU memory bounded without changing results.
23
+ SCORING_MINI_BATCH_SIZE = 512
24
+
25
 
26
  def coins_predict_inner(experiment, dataset_id, algorithm, query_structure_id,
27
  anchors, variables, relations_map, top_k):
 
109
  community_order = community_scores.argsort(descending=True) # [K]
110
  step1_ms = (time.perf_counter() - t1_start) * 1000.0
111
 
112
+ # ---- Step 2: Score every entity in the hit community ----
113
+ # Mirrors rank_samples (experiments.py:782-786): the hit community is the
114
+ # first one in step-1 order that contains *any* KG-valid answer; within it
115
+ # we score all entities (minus anchors) so the link ranker actually has to
116
+ # discriminate, instead of being handed only the known answers.
117
  t2_start = time.perf_counter()
118
 
119
  valid_answers = set(get_all_answers(qi_skeleton, query, adj_s_to_t))
 
122
  "No entities in the knowledge graph satisfy this query"
123
  )
124
 
125
+ anchor_entity_ids = {
126
+ entities_skeleton[i] for i in query.query_anchors if entities_skeleton[i] != -1
127
+ }
128
+
129
  rank_c = 0 # 1-based rank of the hit community (0 = no hit)
130
  c_err = 0 # sum of sizes of communities with better step-1 score than the hit
131
  community_size = 0 # size of the hit community
132
+ candidates = []
133
 
134
  for rank_0indexed in range(num_communities):
135
  cid = int(community_order[rank_0indexed].item())
136
  c_entities_tensor = (community_membership == cid).nonzero(as_tuple=True)[0]
137
  c_size = int(c_entities_tensor.shape[0])
 
138
  c_entities = [int(e.item()) for e in c_entities_tensor]
139
+
140
+ if any(e in valid_answers for e in c_entities):
 
 
 
 
 
 
 
 
 
 
 
 
141
  rank_c = rank_0indexed + 1
142
  community_size = c_size
143
+ candidates = [e for e in c_entities if e not in anchor_entity_ids]
144
  break
145
 
146
  c_err += c_size
147
 
148
+ if not candidates:
149
  raise InvalidRequestError(
150
  "No entities in the knowledge graph satisfy this query"
151
  )
152
 
153
+ n_candidates = len(candidates)
154
+ candidates_tensor = pt.tensor(candidates, dtype=pt.long, device=device)
155
+
156
+ # Build per-tree-node entity columns. Anchor / resolved-variable positions
157
+ # repeat the same id across the batch; the answer and phantom-i positions
158
+ # take each candidate.
159
  e_batch, x_batch, c_batch, edge_attr_batch = [], [], [], []
160
  for i in range(num_tree_nodes):
161
+ if entities_skeleton[i] == -1:
162
+ entities_i = candidates_tensor
163
+ else:
164
+ entities_i = pt.full([n_candidates], entities_skeleton[i], dtype=pt.long, device=device)
165
  e_batch.append(entities_i)
166
  x_batch.append(one_hot(node_types_tensor[entities_i], num_node_types).float())
167
  c_batch.append(community_membership[entities_i])
168
  for j in range(num_tree_edges):
169
+ r_label = qi_skeleton.es[j]["r"]
170
  if "p" in r_label:
171
  r_id = int(r_label[1:])
172
  edge_attr_batch.append(
173
+ one_hot(pt.full([n_candidates], r_id, dtype=pt.long, device=device), num_relations + 1).float()
174
  )
175
  else:
176
  edge_attr_batch.append(
177
+ one_hot(pt.full([n_candidates], num_relations, dtype=pt.long, device=device), num_relations + 1).float()
178
  )
179
 
180
  with pt.no_grad():
181
  batched_query = QueryData(query, e=e_batch, x=x_batch, c=c_batch, edge_attr=edge_attr_batch)
182
+ score_chunks = []
183
+ for chunk in batched_query.batch_split(SCORING_MINI_BATCH_SIZE):
184
+ q_emb, a_emb = embedder(chunk)
185
+ score_chunks.append(link_ranker(q_emb, a_emb).view(-1))
186
+ scores = pt.cat(score_chunks)
187
 
188
+ k = min(top_k, n_candidates)
189
  top_scores, top_indices = scores.topk(k)
190
  step2_ms = (time.perf_counter() - t2_start) * 1000.0
191
 
 
193
  inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
194
  predictions = []
195
  for intra_community_rank, (idx, score) in enumerate(zip(top_indices.tolist(), top_scores.tolist()), 1):
196
+ entity_id = candidates[idx]
197
  raw_name = str(inv_nodes.get(entity_id, entity_id))
198
  predictions.append({
199
  "intra_community_rank": intra_community_rank,
 
201
  "entity_id": entity_id,
202
  "entity_name": clean_entity_name(raw_name, dataset_id),
203
  "score": round(float(score), 4),
204
+ "is_valid_answer": entity_id in valid_answers,
205
  })
206
 
207
  total_ms = step1_ms + step2_ms
src/frontend/src/components/coins/PredictionList.vue CHANGED
@@ -19,7 +19,14 @@ function barWidth(score) {
19
  #{{ p.rank }}
20
  </div>
21
  <div class="pred-body">
22
- <div class="pred-name" :title="`entity id ${p.entity_id}`">{{ p.entity_name }}</div>
 
 
 
 
 
 
 
23
  <div class="pred-bar">
24
  <div class="pred-bar-fill" :style="{ width: barWidth(p.score) }"></div>
25
  </div>
@@ -55,6 +62,7 @@ function barWidth(score) {
55
  }
56
  .pred-body { min-width: 0; }
57
  .pred-name { font-weight: 600; word-break: break-word; }
 
58
  .pred-bar {
59
  width: 100%;
60
  height: 6px;
 
19
  #{{ p.rank }}
20
  </div>
21
  <div class="pred-body">
22
+ <div class="pred-name" :title="`entity id ${p.entity_id}`">
23
+ {{ p.entity_name }}
24
+ <i
25
+ v-if="p.is_valid_answer"
26
+ class="check circle icon valid-flag"
27
+ title="Known KG answer — this entity actually satisfies the query"
28
+ ></i>
29
+ </div>
30
  <div class="pred-bar">
31
  <div class="pred-bar-fill" :style="{ width: barWidth(p.score) }"></div>
32
  </div>
 
62
  }
63
  .pred-body { min-width: 0; }
64
  .pred-name { font-weight: 600; word-break: break-word; }
65
+ .pred-name .valid-flag { color: var(--primary-strong); margin-left: 0.35rem; font-size: 0.95em; }
66
  .pred-bar {
67
  width: 100%;
68
  height: 6px;