from functools import reduce

import torch
from torch_geometric.data import Data
from torch_scatter import scatter_add


def edge_match(edge_index, query_index):
    """Find every edge of ``edge_index`` that equals a column of ``query_index``.

    Both arguments are integer tensors with one edge/query per column (the
    callers in this file pass ``(2, n)`` and ``(2, q)`` tensors holding
    ``(node, relation)`` pairs).  Each column is hashed to a single integer,
    the edge hashes are sorted, and the matches of every query are located
    with two binary searches.

    Returns:
        A tuple ``(edge_id, num_match)``: ``edge_id`` holds the column
        positions (into ``edge_index``) of the matched edges, grouped query
        by query, and ``num_match[i]`` is the number of edges matched by
        query ``i``.
    """
    # O((n + q)logn) time
    # O(n) memory
    # edge_index: big underlying graph
    # query_index: edges to match

    if edge_index.shape[1] == 0:
        # nothing can match an empty graph; max(dim=1) below would raise
        no_edge = torch.zeros(0, dtype=torch.long, device=edge_index.device)
        no_match = torch.zeros(query_index.shape[1], dtype=torch.long, device=query_index.device)
        return no_edge, no_match

    # preparing unique hashing of edges, base: (max_node, max_relation) + 1
    # The base has to cover the queries as well: a query id larger than the
    # largest id in the graph would otherwise hash onto an unrelated edge
    # (e.g. with base (10, 5) the query (0, 5) collides with the edge (1, 0)).
    max_index = edge_index.max(dim=1)[0]
    if query_index.shape[1] > 0:
        max_index = torch.max(max_index, query_index.max(dim=1)[0])
    base = max_index + 1
    # we will map edges to long ints, so we need to make sure the maximum product is less than MAX_LONG_INT
    # idea: max number of edges = num_nodes * num_relations
    # e.g. for a graph of 10 nodes / 5 relations, edge IDs 0...9 mean all possible outgoing edge types from node 0
    # given a tuple (h, r), we will search for all other existing edges starting from head h
    assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max
    scale = base.cumprod(0)
    scale = scale[-1] // scale

    # hash both the original edge index and the query index to unique integers
    edge_hash = (edge_index * scale.unsqueeze(-1)).sum(dim=0)
    edge_hash, order = edge_hash.sort()
    query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0)

    # matched ranges: [start[i], end[i])
    start = torch.bucketize(query_hash, edge_hash)
    end = torch.bucketize(query_hash, edge_hash, right=True)
    # num_match shows how many edges satisfy the (h, r) pattern for each query in the batch
    num_match = end - start

    # generate the corresponding ranges: a flat 0..total-1 counter shifted so
    # that the segment of query i starts at start[i] in the sorted edge list
    offset = num_match.cumsum(0) - num_match
    sorted_pos = torch.arange(num_match.sum(), device=edge_index.device)
    sorted_pos = sorted_pos + (start - offset).repeat_interleave(num_match)

    return order[sorted_pos], num_match


def _sample_from_mask(mask, num_negative):
    """Draw ``num_negative`` column indices per row among the True entries of ``mask``."""
    # column ids of all True entries, row after row
    candidate = mask.nonzero()[:, 1]
    num_candidate = mask.sum(dim=-1)
    rand = torch.rand(len(mask), num_negative, device=mask.device)
    # uniform position inside each row's own candidate list ...
    index = (rand * num_candidate.unsqueeze(-1)).long()
    # ... shifted to that row's segment in the flattened candidate list
    # NOTE: a row without any candidate yields an index that falls into the
    # next row's segment (or past the end for the last row).
    index = index + (num_candidate.cumsum(0) - num_candidate).unsqueeze(-1)
    return candidate[index]


def negative_sampling(data, batch, num_negative, strict=True):
    """Build a batch of one positive plus ``num_negative`` negatives per triple.

    ``batch`` is unpacked via ``batch.t()`` into head, tail and relation
    indices.  The first half of the batch gets corrupted tails, the second
    half gets corrupted heads.  With ``strict=True`` the negatives are drawn
    only from entities allowed by ``strict_negative_mask``; otherwise they
    are drawn uniformly from ``range(data.num_nodes)``.

    Returns:
        A tensor of shape ``(batch_size, num_negative + 1, 3)`` whose last
        axis is ``(head, tail, relation)``; entry 0 along the second axis is
        the positive triple.
    """
    batch_size = len(batch)
    half = batch_size // 2
    pos_h_index, pos_t_index, pos_r_index = batch.t()

    # strict negative sampling vs random negative sampling
    if strict:
        t_mask, h_mask = strict_negative_mask(data, batch)
        # draw samples for negative tails, then for negative heads
        neg_t_index = _sample_from_mask(t_mask[:half], num_negative)
        neg_h_index = _sample_from_mask(h_mask[half:], num_negative)
    else:
        neg_index = torch.randint(data.num_nodes, (batch_size, num_negative), device=batch.device)
        neg_t_index, neg_h_index = neg_index[:half], neg_index[half:]

    h_index = pos_h_index.unsqueeze(-1).repeat(1, num_negative + 1)
    t_index = pos_t_index.unsqueeze(-1).repeat(1, num_negative + 1)
    r_index = pos_r_index.unsqueeze(-1).repeat(1, num_negative + 1)
    t_index[:half, 1:] = neg_t_index
    h_index[half:, 1:] = neg_h_index

    return torch.stack([h_index, t_index, r_index], dim=-1)


def all_negative(data, batch):
    """Pair every triple of ``batch`` with every entity as tail and as head.

    Returns:
        A tuple ``(t_batch, h_batch)`` of tensors shaped
        ``(batch_size, data.num_nodes, 3)`` with ``(head, tail, relation)``
        on the last axis.  ``t_batch`` keeps the head fixed and enumerates
        all tails; ``h_batch`` keeps the tail fixed and enumerates all heads.
    """
    pos_h_index, pos_t_index, pos_r_index = batch.t()
    r_index = pos_r_index.unsqueeze(-1).expand(-1, data.num_nodes)
    all_index = torch.arange(data.num_nodes, device=batch.device)

    # generate all negative tails for this batch
    h_index, t_index = torch.meshgrid(pos_h_index, all_index, indexing="ij")  # indexing "xy" would return transposed
    t_batch = torch.stack([h_index, t_index, r_index], dim=-1)

    # generate all negative heads for this batch
    t_index, h_index = torch.meshgrid(pos_t_index, all_index, indexing="ij")
    h_batch = torch.stack([h_index, t_index, r_index], dim=-1)

    return t_batch, h_batch


def _truth_mask(data, known_row, answer_row, known_index, rel_index, answer_index, device):
    """Mask (True = allowed negative) hiding every known answer of each query.

    ``known_row`` / ``answer_row`` select the rows of ``data.edge_index`` that
    hold the fixed entity and the entity being predicted.
    """
    # (entity, relation) pairs of the whole graph and of the current batch
    edge_index = torch.stack([data.edge_index[known_row], data.edge_type])
    query_index = torch.stack([known_index, rel_index])
    # search for all true answers of every query in the batch
    edge_id, num_truth = edge_match(edge_index, query_index)
    # build an index from the found edges
    truth_index = data.edge_index[answer_row, edge_id]
    sample_id = torch.arange(len(num_truth), device=device).repeat_interleave(num_truth)
    mask = torch.ones(len(num_truth), data.num_nodes, dtype=torch.bool, device=device)
    # assign 0s to the mask with the found true answers
    mask[sample_id, truth_index] = 0
    # the positive answer of the batch triple itself is never a negative
    mask.scatter_(1, answer_index.unsqueeze(-1), 0)
    return mask


def strict_negative_mask(data, batch):
    """Compute which entities may be used as negatives for each batch triple.

    Returns:
        A tuple ``(t_mask, h_mask)`` of boolean tensors shaped
        ``(batch_size, data.num_nodes)``.  An entry is False when the entity
        is a known true tail (resp. head) for the triple's
        ``(head, relation)`` (resp. ``(tail, relation)``) in ``data``, or is
        the triple's own tail (resp. head).
    """
    # this function makes sure that for a given (h, r) batch we will NOT sample true tails as random negatives
    # similarly, for a given (t, r) we will NOT sample existing true heads as random negatives
    pos_h_index, pos_t_index, pos_r_index = batch.t()

    # part I: sample hard negative tails -- match on (head, relation), mask the tails
    t_mask = _truth_mask(data, 0, 1, pos_h_index, pos_r_index, pos_t_index, batch.device)

    # part II: sample hard negative heads -- match on (tail, relation), mask the heads
    h_mask = _truth_mask(data, 1, 0, pos_t_index, pos_r_index, pos_h_index, batch.device)

    return t_mask, h_mask


def compute_ranking(pred, target, mask=None):
    """Return the 1-based rank of ``target`` among the scores in ``pred``.

    The rank is one plus the number of competing entries whose score is
    greater than or equal to the target's score (ties count against the
    target).  With ``mask`` given (filtered ranking) only entries where
    ``mask`` is True compete.  Without a mask (unfiltered ranking) every
    entry except the target itself competes.
    """
    pos_pred = pred.gather(-1, target.unsqueeze(-1))
    not_worse = pos_pred <= pred
    if mask is not None:
        # filtered ranking
        not_worse = not_worse & mask
    else:
        # unfiltered ranking: the target always satisfies pos_pred <= pred
        # against itself; drop it so that the best possible rank is 1, not 2
        not_worse.scatter_(-1, target.unsqueeze(-1), 0)
    return torch.sum(not_worse, dim=-1) + 1


def _incidence(node_index, edge_type, num_nodes, num_rels, device):
    """Sparse (node, relation) incidence matrix and its degree-normalised transpose."""
    pairs = torch.vstack([node_index, edge_type]).T.unique(dim=0)  # (num_edges, 2)
    # number of distinct relations incident to each node
    degree = scatter_add(torch.ones_like(pairs[:, 1]), pairs[:, 0])
    # sanity check: every node listed in `pairs` contributed to its own degree
    assert not (degree[pairs[:, 0]] == 0).any()

    normalised_t = torch.sparse_coo_tensor(
        torch.flip(pairs, dims=[1]).T,
        torch.ones(pairs.shape[0], device=device) / degree[pairs[:, 0]],
        (num_rels, num_nodes)
    )
    incidence = torch.sparse_coo_tensor(
        pairs.T,
        torch.ones(pairs.shape[0], device=device),
        (num_nodes, num_rels)
    )
    return incidence, normalised_t


def build_relation_graph(graph):
    """Attach a graph over relations to ``graph`` as ``graph.relation_graph``.

    Two relations are connected when they share a node, with one of four
    edge types depending on the role of that node: 0 head-to-head,
    1 tail-to-tail, 2 head-to-tail, 3 tail-to-head.  The resulting ``Data``
    object has ``graph.num_relations`` nodes.

    Returns:
        The same ``graph`` object, with ``relation_graph`` set.
    """
    # expect the graph is already with inverse edges

    edge_index, edge_type = graph.edge_index, graph.edge_type
    num_nodes, num_rels = graph.num_nodes, graph.num_relations
    device = edge_index.device

    Eh, EhT = _incidence(edge_index[0], edge_type, num_nodes, num_rels, device)
    Et, EtT = _incidence(edge_index[1], edge_type, num_nodes, num_rels, device)

    # (num_rels, num_rels) relation-to-relation adjacencies, in edge-type order
    adjacency = [
        torch.sparse.mm(EhT, Eh).coalesce(),  # 0: head to head
        torch.sparse.mm(EtT, Et).coalesce(),  # 1: tail to tail
        torch.sparse.mm(EhT, Et).coalesce(),  # 2: head to tail
        torch.sparse.mm(EtT, Eh).coalesce(),  # 3: tail to head
    ]

    rel_edge_index = []
    rel_edge_type = []
    for type_id, adj in enumerate(adjacency):
        indices = adj.indices()
        rel_edge_index.append(indices)
        # type ids must live on the same device as the indices, otherwise the
        # concatenation below fails for graphs stored on an accelerator
        rel_edge_type.append(
            torch.full((indices.shape[1],), type_id, dtype=torch.long, device=indices.device)
        )

    rel_graph = Data(
        edge_index=torch.cat(rel_edge_index, dim=1),
        edge_type=torch.cat(rel_edge_type, dim=0),
        num_nodes=num_rels,
        num_relations=4
    )

    graph.relation_graph = rel_graph
    return graph