File size: 6,156 Bytes
89650c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#pragma once

#include <tuple>

#include <torch/extension.h>
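// Previous include path, kept for reference; the <ATen/native/...> header below is the one in use.
// NOTE(review): presumably needed only for PyTorch versions that ship the header at this location - confirm.
//#include <ATen/SparseTensorUtils.h>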
#include <ATen/native/SparseTensorUtils.h>

// Declarations of the rspmm operator entry points.
//
// This namespace block contains prototypes only; no definitions appear in this
// header. The operator functions follow the naming scheme
//
//     rspmm_<A>_<B>_<forward|backward>_<cpu|cuda>
//
// with <A> in {add, min, max} and <B> in {mul, add}.
// NOTE(review): presumably <A> is the reduction applied over edges and <B> the
// operator combining `relation` with `input` - confirm against the definitions.
//
// Every forward function takes
//     (edge_index, edge_type, edge_weight, relation, input)
// and returns a single Tensor.
//
// Every backward function takes the same five tensors followed by
//     (output, output_grad)
// and returns a tuple of three Tensors.
// NOTE(review): presumably these are gradients w.r.t. three of the forward
// operands; which ones, and in which order, is fixed by the definitions -
// TODO confirm before relying on the tuple positions.
//
// Tensor shapes, dtypes and device requirements are not expressed in these
// signatures; see the definitions and the *_check functions.
namespace at {

// NOTE(review): this using-directive lives in a header, so every translation
// unit that includes it can name at::sparse entities without the `sparse::`
// qualifier inside namespace at. Left untouched because files outside this
// view may depend on it.
using namespace at::sparse;

/// Argument check for the forward entry points.
/// Receives `c` plus the five forward operands wrapped as TensorArg.
/// NOTE(review): `c` presumably identifies the calling function for error
/// reporting; what is validated and how failures are signalled is decided by
/// the definition - confirm there.
void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
                         const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg);

/// Argument check for the backward entry points.
/// Same operands as rspmm_forward_check, followed by the forward output and
/// the gradient of that output, all wrapped as TensorArg.
/// NOTE(review): validation rules and failure behaviour are defined elsewhere.
void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
                          const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg,
                          const TensorArg &output_arg, const TensorArg &output_grad_arg);

/// Produces a Tensor from `index` and `size`.
/// NOTE(review): the name suggests converting an index array into a
/// pointer/offset array (CSR-style) with `size` entries - confirm against the
/// definition. `size` is a plain `int`.
Tensor ind2ptr(const Tensor &index, int size);

// ---- CPU entry points ------------------------------------------------------

/// add / mul variant, forward, CPU.
Tensor rspmm_add_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// add / mul variant, backward, CPU. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cpu(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// min / mul variant, forward, CPU.
Tensor rspmm_min_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// min / mul variant, backward, CPU. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cpu(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// max / mul variant, forward, CPU.
Tensor rspmm_max_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// max / mul variant, backward, CPU. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cpu(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// add / add variant, forward, CPU.
Tensor rspmm_add_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// add / add variant, backward, CPU. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cpu(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// min / add variant, forward, CPU.
Tensor rspmm_min_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// min / add variant, backward, CPU. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cpu(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// max / add variant, forward, CPU.
Tensor rspmm_max_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// max / add variant, backward, CPU. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cpu(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

// ---- CUDA entry points -----------------------------------------------------
// Declared only when the CUDA_OP macro is defined at preprocessing time.
// The signatures mirror the CPU declarations one-to-one; only the `_cuda`
// suffix differs.
#ifdef CUDA_OP
/// add / mul variant, forward, CUDA.
Tensor rspmm_add_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// add / mul variant, backward, CUDA. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cuda(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// min / mul variant, forward, CUDA.
Tensor rspmm_min_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// min / mul variant, backward, CUDA. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cuda(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// max / mul variant, forward, CUDA.
Tensor rspmm_max_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// max / mul variant, backward, CUDA. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cuda(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// add / add variant, forward, CUDA.
Tensor rspmm_add_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// add / add variant, backward, CUDA. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cuda(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// min / add variant, forward, CUDA.
Tensor rspmm_min_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// min / add variant, backward, CUDA. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cuda(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);

/// max / add variant, forward, CUDA.
Tensor rspmm_max_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
                                 const Tensor &relation, const Tensor &input);

/// max / add variant, backward, CUDA. Returns three Tensors (see header note).
std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cuda(
        const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
        const Tensor &input, const Tensor &output, const Tensor &output_grad);
#endif

} // namespace at