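"""
Stru4Path: structure-guided candidate retrieval for path-based reranking.

input: rg, a dict of the form
    {"Metapath": str | list of routes, "Restriction": {node_type: ...}}
output (candidate pool for path-based reranking; when the metapath is a single
node type, the top ``topk`` (default 100) text-retrieval results are used):
{
    "query": query,
    "pred_dict": {node_id: score},
    "ans_ids": ans_ids,
    'paths': {node_id: [node_ids_path]},
    'query_pattern': rg["Metapath"],
    'rg': rg,  # only included when candidates were found
}

"""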
| | import sys |
| | import os |
| | sys.path.append(os.path.dirname(os.path.abspath(os.getcwd()))) |
| |
|
| | from utils import combine_dicts, parse_metapath, get_scorer, get_text_retriever, fix_length |
| | from models.model import ModelForSTaRKQA |
| | import time |
| |
|
| |
|
| |
|
class Stru4Path(ModelForSTaRKQA):
    """Structure-guided retriever that expands seed nodes along metapaths.

    ``forward`` turns ``rg["Metapath"]`` into typed routes and validates them
    against the knowledge-base schema. It seeds each route with a text
    retriever and walks the route's relations to collect candidate nodes,
    together with the path that reached each one. Finally it merges the
    per-route pools and scores the merged pool with ``self.scorer``.
    """

    def __init__(self, dataset_name, text_retriever_name, scorer_name, skb, topk=100):
        """
        Args:
            dataset_name: dataset identifier. "prime" enables the
                triplet-based schema check; "mag" targets ``paper`` nodes and
                any other value targets ``product`` nodes.
            text_retriever_name: name forwarded to ``get_text_retriever``.
            scorer_name: name forwarded to ``get_scorer``.
            skb: knowledge-base object exposing ``node_type_lst``,
                ``rel_type_lst``, ``get_tuples`` and ``get_neighbor_nodes``.
            topk: number of text-retrieval results used when the metapath
                degenerates to a single node type.
        """
        super().__init__(skb)
        self.dataset_name = dataset_name
        self.text_retriever = get_text_retriever(dataset_name, text_retriever_name, skb)
        self.scorer = get_scorer(dataset_name, scorer_name=scorer_name, skb=skb)

        self.topk = topk
        self.node_type_list = skb.node_type_lst()
        self.edge_type_list = skb.rel_type_lst()
        if self.dataset_name == "prime":
            # Schema triplets (src_type, relation, dst_type); prime routes are
            # expected to already alternate node/relation types.
            self.tp_list = skb.get_tuples()
            self.target_type_list = skb.candidate_types
        else:
            # (src_type, dst_type) -> relation, used to interleave relation
            # names into node-only routes. If several relations share the same
            # pair, the last tuple returned by get_tuples() wins.
            self.tp_dict = {(tp[0], tp[-1]): tp[1] for tp in skb.get_tuples()}
            self.target_type_list = ['paper' if dataset_name == 'mag' else 'product']

        self.skb = skb
        # Number of seed nodes retrieved for the first node type of each route.
        self.ini_k = 5
        # Number of queries answered through structural (route-walking) retrieval.
        self.stru_count = 0
        # Per-route {node_id: path} dicts collected while walking routes.
        # Initialized here so get_candidates4route also works outside forward().
        self.paths = []

    def rg2routes(self, rg):
        """Extract the list of routes from ``rg["Metapath"]``.

        input: rg: {"Metapath": "", "Restriction": {}}
        output: routes: [['paper', 'author', 'paper'], ['paper', 'paper']]

        A list-valued metapath is returned as-is and a string is parsed with
        ``parse_metapath``. Any other type yields ``None``.
        """
        metapath = rg["Metapath"]
        if isinstance(metapath, list):
            return metapath
        if isinstance(metapath, str):
            return parse_metapath(metapath)
        return None

    def check_valid(self, routes, rg):
        """Filter ``routes`` down to those consistent with the KB schema.

        Args:
            routes: list of routes, each a list of type names.
            rg: unused; kept for interface compatibility.

        Returns:
            ``None`` if no route survives.
            ``1`` if ``routes`` is a single one-node route; the caller then
            falls back to plain text retrieval.
            Otherwise, the list of valid routes. For non-prime datasets,
            relation names are interleaved between node types, e.g.
            ['paper', 'author'] -> ['paper', <rel>, 'author'].
        """
        if not routes:
            return None

        if len(routes) == 1 and len(routes[0]) == 1:
            return 1

        # 1) The route must end at an answer type. Empty routes are dropped
        #    here rather than failing on route[-1].
        target_type_valid_routes = [
            route for route in routes if route and route[-1] in self.target_type_list
        ]

        # 2) Every element must be a known node or relation type.
        type_valid_routes = [
            route
            for route in target_type_valid_routes
            if all(
                elem in self.node_type_list or elem in self.edge_type_list
                for elem in route
            )
        ]

        # 3) Every hop must exist in the schema.
        relation_valid_routes = []
        for route in type_valid_routes:
            if self.dataset_name == "prime":
                # Routes alternate node/relation/node; check each triplet.
                triplets = [
                    (route[i], route[i + 1], route[i + 2])
                    for i in range(0, len(route) - 2, 2)
                ]
                if all(tp in self.tp_list for tp in triplets):
                    relation_valid_routes.append(route)
            else:
                new_route = self._insert_relations(route)
                if new_route is not None:
                    relation_valid_routes.append(new_route)

        return relation_valid_routes or None

    def _insert_relations(self, route):
        """Interleave schema relation names into a non-empty node-only route.

        Returns ``None`` if some consecutive (src, dst) pair is missing from
        ``self.tp_dict``. A single-node route comes back as ``[node]``.
        """
        pairs = list(zip(route, route[1:]))
        if not all(pair in self.tp_dict for pair in pairs):
            return None
        new_route = [route[0]]
        for src, dst in pairs:
            new_route.extend((self.tp_dict[(src, dst)], dst))
        return new_route

    @staticmethod
    def _restriction_text(restriction, node_type):
        """Return the restriction text for ``node_type`` ("" if absent or unusable).

        A string entry is used verbatim. A sequence of phrases is joined with
        single spaces so that neighbouring phrases do not run together.
        """
        try:
            value = restriction[node_type]
        except (KeyError, TypeError):
            return ""
        if isinstance(value, str):
            return value
        try:
            return " ".join(str(part) for part in value if part)
        except TypeError:
            return ""

    def get_candidates4route(self, query, q_id, route, restriction):
        """Walk ``route`` from text-retrieved seed nodes and return the end nodes.

        Args:
            query: natural-language query.
            q_id: query id forwarded to the text retriever.
            route: alternating [node_type, relation, node_type, ...] list.
            restriction: mapping node_type -> restriction text(s). The entry
                for the route's first node type is appended to the seed query.

        Returns:
            List of node ids reached at the end of the route, without
            duplicates, in the order they were first reached.

        Side effect:
            Appends a {end_node_id: [seed_id, ..., end_node_id]} dict to
            ``self.paths``. When several paths reach the same node, the last
            one visited is kept.
        """
        ini_node_type = route[0]
        extra_restr = self._restriction_text(restriction, ini_node_type)
        ini_dict = self.text_retriever.retrieve(
            query + " " + extra_restr,
            q_id=q_id,
            topk=self.ini_k,
            node_type=ini_node_type,
        )
        current_node_ids = list(ini_dict.keys())
        paths = {c_id: [c_id] for c_id in current_node_ids}

        # Even positions hold node types and odd positions hold relations, so
        # each hop follows route[hop + 1] from the current frontier.
        for hop in range(0, len(route) - 2, 2):
            edge_type = route[hop + 1]
            new_paths = {}
            for node_id in current_node_ids:
                # Iterate the neighbors exactly once, in case the KB returns
                # an iterator rather than a list.
                for neighbor_id in self.skb.get_neighbor_nodes(idx=node_id, edge_type=edge_type):
                    new_paths[neighbor_id] = paths[node_id] + [neighbor_id]
            paths = new_paths
            current_node_ids = list(new_paths)

        self.paths.append(paths)
        return current_node_ids

    def merge_candidate_pools(self, non_empty_candidates_lists):
        """Combine per-route candidate lists into a single pool.

        Returns the intersection of all lists when it is non-empty, and
        otherwise their union. The result is always a list; an empty input
        yields an empty list.
        """
        if not non_empty_candidates_lists:
            return []

        result = set(non_empty_candidates_lists[0])
        for lst in non_empty_candidates_lists[1:]:
            result.intersection_update(lst)

        if not result:
            # The routes share no candidate: fall back to the union of all pools.
            result = set().union(*non_empty_candidates_lists)

        return list(result)

    def get_mor_candidates(self, query, q_id, valid_routes, restriction):
        """Collect and merge candidates over all routes whose start type is restricted.

        A route is skipped when its first node type has no entry, or a
        falsy/empty entry, in ``restriction``.

        Returns:
            ``{node_id: -1}`` placeholder scores for the merged pool, or ``{}``
            if no route produced candidates.
        """
        candidates_pool = []
        for route in valid_routes:
            if restriction.get(route[0]):
                candidates_pool.append(
                    self.get_candidates4route(query, q_id, route, restriction)
                )

        non_empty_candidates_lists = [lst for lst in candidates_pool if lst]
        if not non_empty_candidates_lists:
            return {}

        # Merge only the non-empty pools. An empty pool would make the
        # intersection empty and always force the union fallback.
        candidates = self.merge_candidate_pools(non_empty_candidates_lists)
        if not candidates:
            return {}

        return dict.fromkeys(candidates, -1)

    @staticmethod
    def _empty_output(query, ans_ids, rg):
        """Build the result dict returned when no candidates are produced."""
        return {
            "query": query,
            "pred_dict": {},
            "ans_ids": ans_ids,
            'paths': {},
            'query_pattern': rg['Metapath'],
        }

    def forward(self, query, q_id, ans_ids, rg):
        """Retrieve and score candidates for one query.

        Args:
            query: natural-language query.
            q_id: query id forwarded to the retriever and the scorer.
            ans_ids: gold answer ids, copied into the output unchanged.
            rg: dict with "Metapath" (str or list of routes) and, for
                structural retrieval, "Restriction".

        Returns:
            Dict with "query", "pred_dict", "ans_ids", "paths" and
            "query_pattern", plus "rg" when candidates were found.
            ``pred_dict`` maps node id to score; ``paths`` maps node id to
            the path that reached it.

        Raises:
            ValueError: if the number of scored candidates differs from the
                number of paths.
        """
        self.paths = []

        routes = self.rg2routes(rg)
        valid_routes = self.check_valid(routes, rg)

        if valid_routes is None:
            return self._empty_output(query, ans_ids, rg)

        if valid_routes == 1:
            # Degenerate single-node metapath: plain text retrieval.
            # NOTE(review): always uses the first target type, even when the
            # route names a different candidate type (e.g. prime) — confirm.
            pred_dict = self.text_retriever.retrieve(
                query,
                q_id=q_id,
                topk=self.topk,
                node_type=f'{self.target_type_list[0]}',
            )
        else:
            if self.dataset_name != "prime":
                # Keep only the last two hops: node, rel, node, rel, node.
                valid_routes = [route[-5:] for route in valid_routes]
            restriction = rg.get("Restriction") or {}
            pred_dict = self.get_mor_candidates(query, q_id, valid_routes, restriction)
            self.stru_count += 1

        if not pred_dict:
            return self._empty_output(query, ans_ids, rg)

        if self.paths:
            self.paths = combine_dicts(self.paths, pred_dict=pred_dict)
        else:
            # No route was walked: every candidate is its own one-node path.
            self.paths = {node_id: [node_id] for node_id in pred_dict}

        pred_dict = self.scorer.score(query, q_id, list(pred_dict.keys()))

        if len(self.paths) != len(pred_dict):
            mismatched = sorted(set(self.paths) ^ set(pred_dict), key=str)
            raise ValueError(
                f"Length mismatch between paths and pred_dict: "
                f"{len(self.paths)}, {len(pred_dict)} "
                f"(ids present in only one of them: {mismatched[:20]})"
            )

        return {
            "query": query,
            "pred_dict": pred_dict,
            "ans_ids": ans_ids,
            'paths': self.paths,
            'query_pattern': rg['Metapath'],
            'rg': rg,
        }
| | |
| |
|
| | |
| |
|
| |
|