"""Generalized relation-aware sparse-dense matrix multiplication (RSPMM).

Each `autograd.Function` below binds one (sum, mul) operator pair to the
matching forward/backward kernels in the compiled `rspmm` extension. The
kernels expect edges sorted by the (node_in, node_out) composite key, so
`generalized_rspmm` sorts them before dispatching.
"""

import os
import sys

import torch
import torch.backends.openmp
from torch import autograd
from torch.utils import cpp_extension

module = sys.modules[__name__]


class RSPMMAddMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        node_in, node_out = edge_index
        key = node_in * (node_out.max() + 1) + node_out
        assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"

        if input.device.type == "cuda":
            forward = rspmm.rspmm_add_mul_forward_cuda
        else:
            forward = rspmm.rspmm_add_mul_forward_cpu
        output = forward(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = rspmm.rspmm_add_mul_backward_cuda
        else:
            backward = rspmm.rspmm_add_mul_backward_cpu
        weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        # edge_index and edge_type are integer-valued and receive no gradient
        return None, None, weight_grad, relation_grad, input_grad


class RSPMMMinMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        node_in, node_out = edge_index
        key = node_in * (node_out.max() + 1) + node_out
        assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"

        if input.device.type == "cuda":
            forward = rspmm.rspmm_min_mul_forward_cuda
        else:
            forward = rspmm.rspmm_min_mul_forward_cpu
        output = forward(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = rspmm.rspmm_min_mul_backward_cuda
        else:
            backward = rspmm.rspmm_min_mul_backward_cpu
        weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad


class RSPMMMaxMulFunction(autograd.Function):

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        node_in, node_out = edge_index
        key = node_in * (node_out.max() + 1) + node_out
        assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"

        if input.device.type == "cuda":
            forward = rspmm.rspmm_max_mul_forward_cuda
        else:
            forward = rspmm.rspmm_max_mul_forward_cpu
        output = forward(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = rspmm.rspmm_max_mul_backward_cuda
        else:
            backward = rspmm.rspmm_max_mul_backward_cpu
        weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad


class RSPMMAddAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        node_in, node_out = edge_index
        key = node_in * (node_out.max() + 1) + node_out
        assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"

        if input.device.type == "cuda":
            forward = rspmm.rspmm_add_add_forward_cuda
        else:
            forward = rspmm.rspmm_add_add_forward_cpu
        output = forward(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = rspmm.rspmm_add_add_backward_cuda
        else:
            backward = rspmm.rspmm_add_add_backward_cpu
        weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad
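

# Reference sketch (not part of the original module): a pure-PyTorch
# equivalent of the add-mul forward in RSPMMAddMulFunction above, handy for
# sanity-checking the compiled kernels on small graphs. It assumes output
# rows are indexed by node_in (the major key of the sort the kernels require)
# and that input and output both have shape (num_node, dim); if the kernels
# index rows by node_out instead, swap the two endpoints below.
#
#     def rspmm_add_mul_reference(edge_index, edge_type, edge_weight, relation, input):
#         node_in, node_out = edge_index
#         message = relation[edge_type] * input[node_out]      # (num_edge, dim)
#         message = edge_weight.unsqueeze(-1) * message
#         index = node_in.unsqueeze(-1).expand_as(message)
#         return torch.zeros_like(input).scatter_add_(0, index, message)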


class RSPMMMinAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        node_in, node_out = edge_index
        key = node_in * (node_out.max() + 1) + node_out
        assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"

        if input.device.type == "cuda":
            forward = rspmm.rspmm_min_add_forward_cuda
        else:
            forward = rspmm.rspmm_min_add_forward_cpu
        output = forward(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = rspmm.rspmm_min_add_backward_cuda
        else:
            backward = rspmm.rspmm_min_add_backward_cpu
        weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad


class RSPMMMaxAddFunction(autograd.Function):

    @staticmethod
    def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
        node_in, node_out = edge_index
        key = node_in * (node_out.max() + 1) + node_out
        assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"

        if input.device.type == "cuda":
            forward = rspmm.rspmm_max_add_forward_cuda
        else:
            forward = rspmm.rspmm_max_add_forward_cpu
        output = forward(edge_index, edge_type, edge_weight, relation, input)
        ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
        return output

    @staticmethod
    def backward(ctx, output_grad):
        if output_grad.device.type == "cuda":
            backward = rspmm.rspmm_max_add_backward_cuda
        else:
            backward = rspmm.rspmm_max_add_backward_cpu
        weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
        return None, None, weight_grad, relation_grad, input_grad


def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul="mul"):
    # dispatch by name to the Function class for this (sum, mul) pair
    name = "RSPMM%s%sFunction" % (sum.capitalize(), mul.capitalize())
    if not hasattr(module, name):
        raise ValueError("No generalized rspmm implementation found for summation `%s` and multiplication `%s`"
                         % (sum, mul))
    Function = getattr(module, name)

    # sort edges by the (node_in, node_out) composite key, as the kernels require
    node_in, node_out = edge_index
    key = node_in * (node_out.max() + 1) + node_out
    order = key.argsort()

    return Function.apply(edge_index[:, order], edge_type[order], edge_weight[order], relation, input)


def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
    if extra_cflags is None:
        extra_cflags = ["-Ofast"]
        if torch.backends.openmp.is_available():
            extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
        else:
            extra_cflags.append("-DAT_PARALLEL_NATIVE")
    if extra_cuda_cflags is None:
        if torch.cuda.is_available():
            extra_cuda_cflags = ["-O3"]
            extra_cflags.append("-DCUDA_OP")
        else:
            # no CUDA available: drop .cu sources so the build stays CPU-only
            new_sources = []
            for source in sources:
                if not cpp_extension._is_cuda_file(source):
                    new_sources.append(source)
            sources = new_sources

    return cpp_extension.load(name, sources, extra_cflags, extra_cuda_cflags, **kwargs)


print("Loading rspmm extension. This may take a while...")
path = os.path.join(os.path.dirname(__file__), "source")
rspmm = load_extension("rspmm", [os.path.join(path, "rspmm.cpp"), os.path.join(path, "rspmm.cu")])
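

# Usage sketch (illustrative; the shapes are assumptions inferred from how the
# kernels are invoked above, not documented in this file): edge_index is
# (2, num_edge) int64, edge_type is (num_edge,) int64, edge_weight is
# (num_edge,) float, relation is (num_relation, dim), input is (num_node, dim).
# Edges need not be pre-sorted, since generalized_rspmm sorts them itself, and
# gradients flow back to edge_weight, relation, and input.
#
#     num_node, num_relation, dim = 4, 2, 8
#     edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
#     edge_type = torch.tensor([0, 1, 0])
#     edge_weight = torch.ones(3, requires_grad=True)
#     relation = torch.randn(num_relation, dim, requires_grad=True)
#     input = torch.randn(num_node, dim, requires_grad=True)
#     output = generalized_rspmm(edge_index, edge_type, edge_weight,
#                                relation, input, sum="add", mul="mul")
#     output.sum().backward()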