Huhujingjing committed on
Commit 940d605
Parent(s): 1b735ec

Upload model

Files changed (4)
  1. config.json +24 -0
  2. configuration_mxm.py +46 -0
  3. modeling_mxm.py +1041 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "architectures": [
+     "MXMModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_mxm.MXMConfig",
+     "AutoModel": "modeling_mxm.MXMModel"
+   },
+   "cutoff": 5.0,
+   "dim": 128,
+   "envelope_exponent": 5,
+   "model_type": "mxm",
+   "n_layer": 6,
+   "num_radial": 6,
+   "num_spherical": 7,
+   "processor_class": "SmilesProcessor",
+   "smiles": [
+     "C",
+     "CC",
+     "CCC"
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.29.2"
+ }
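
Since this config carries an auto_map, the checkpoint is meant to be loaded through the Auto classes with remote code enabled. A minimal sketch, assuming the repository id Huhujingjing/custom-mxm that predict_smiles in modeling_mxm.py below also uses:

    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True)
    model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True)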
configuration_mxm.py ADDED
@@ -0,0 +1,46 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class MXMConfig(PretrainedConfig):
+     model_type = "mxm"
+
+     def __init__(
+         self,
+         dim: int = 128,
+         n_layer: int = 6,
+         cutoff: float = 5.0,
+         num_spherical: int = 7,
+         num_radial: int = 6,
+         envelope_exponent: int = 5,
+         smiles: List[str] = None,
+         processor_class: str = "SmilesProcessor",
+         **kwargs,
+     ):
+         self.dim = dim                              # the dimension of the input features
+         self.n_layer = n_layer                      # the number of GCN layers
+         self.cutoff = cutoff                        # the cutoff distance for neighbor searching
+         self.num_spherical = num_spherical          # the number of spherical harmonics
+         self.num_radial = num_radial                # the number of radial basis functions
+         self.envelope_exponent = envelope_exponent  # the envelope exponent
+
+         self.smiles = smiles                        # SMILES strings to process
+         self.processor_class = processor_class
+
+         super().__init__(**kwargs)
+
+
+ if __name__ == "__main__":
+     mxm_config = MXMConfig(
+         dim=128,
+         n_layer=6,
+         cutoff=5.0,
+         num_spherical=7,
+         num_radial=6,
+         envelope_exponent=5,
+         smiles=["C", "CC", "CCC"],
+         processor_class="SmilesProcessor",
+     )
+     mxm_config.save_pretrained("custom-mxm")
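
The auto_map block in config.json above is what save_pretrained writes once the custom classes have been registered for the Auto API. A hedged sketch of that registration step (it is not part of this commit; MXMModel comes from modeling_mxm.py below):

    MXMConfig.register_for_auto_class()             # maps the config to "AutoConfig"
    MXMModel.register_for_auto_class("AutoModel")   # maps the model to "AutoModel"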
modeling_mxm.py ADDED
@@ -0,0 +1,1041 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import Parameter, Sequential, ModuleList, Linear
+
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+
+ from transformers import PretrainedConfig
+ from transformers import PreTrainedModel
+ from transformers import AutoModel
+
+ from torch_geometric.data import Data
+ from torch_geometric.loader import DataLoader
+ from torch_geometric.utils import remove_self_loops, add_self_loops, sort_edge_index
+ from torch_scatter import scatter
+ from torch_geometric.nn import global_add_pool, radius
+ from torch_sparse import SparseTensor
+
+ from mxm_model.configuration_mxm import MXMConfig
+
+ from tqdm import tqdm
+ import numpy as np
+ import pandas as pd
+ from typing import List
+ import math
+ import inspect
+ from operator import itemgetter
+ from collections import OrderedDict
+ from math import sqrt, pi as PI
+ from scipy.optimize import brentq
+ from scipy import special as sp
+
+ try:
+     import sympy as sym
+ except ImportError:
+     sym = None
+
+
+ class SmilesDataset(torch.utils.data.Dataset):
+     def __init__(self, smiles):
+         self.smiles_list = smiles
+         self.data_list = []
+
+     def __len__(self):
+         return len(self.data_list)
+
+     def __getitem__(self, idx):
+         return self.data_list[idx]
+
+     def get_data(self, smiles):
+         self.smiles_list = smiles
+         # bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
+         types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'S': 4}
+
+         for i in range(len(self.smiles_list)):
+             # Convert the SMILES string into an RDKit molecule object
+             mol = Chem.MolFromSmiles(self.smiles_list[i])
+             if mol is None:
+                 print("failed to create Mol object:", self.smiles_list[i])
+             else:
+                 # RDKit hides hydrogens by default, but they strongly affect the real
+                 # 3D geometry, so add them with Chem.AddHs() before computing conformers.
+                 mol3d = Chem.AddHs(mol)
+                 if mol3d is None:
+                     print("failed to create mol3d object:", self.smiles_list[i])
+                 else:
+                     AllChem.EmbedMolecule(mol3d, randomSeed=1)  # generate a 3D conformer
+
+                     N = mol3d.GetNumAtoms()
+                     # Get the atomic coordinates
+                     if mol3d.GetNumConformers() > 0:
+                         conformer = mol3d.GetConformer()
+                         pos = conformer.GetPositions()
+                         pos = torch.tensor(pos, dtype=torch.float)
+
+                         type_idx = []
+                         for atom in mol3d.GetAtoms():
+                             type_idx.append(types[atom.GetSymbol()])
+
+                         row, col, edge_type = [], [], []
+                         for bond in mol3d.GetBonds():
+                             start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
+                             row += [start, end]
+                             col += [end, start]
+                             # edge_type += 2 * [bonds[bond.GetBondType()]]
+
+                         edge_index = torch.tensor([row, col], dtype=torch.long)
+
+                         perm = (edge_index[0] * N + edge_index[1]).argsort()
+                         edge_index = edge_index[:, perm]
+
+                         x = torch.tensor(type_idx).to(torch.float)
+
+                         data = Data(x=x, pos=pos, edge_index=edge_index, smiles=self.smiles_list[i])
+
+                         self.data_list.append(data)
+                     else:
+                         print("failed to create conformer:", self.smiles_list[i])
+         return self.data_list
+
+
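A minimal sketch of the dataset in isolation (the SMILES strings are illustrative): get_data returns torch_geometric Data objects carrying atom types x, coordinates pos, bond connectivity edge_index, and the source SMILES.

    dataset = SmilesDataset(smiles=[])
    data_list = dataset.get_data(["C", "CC", "CCC"])
    print(data_list[0])   # roughly: Data(x=[5], edge_index=[2, 8], pos=[5, 3], smiles='C')
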
+ class EMA:
+     def __init__(self, model, decay):
+         self.decay = decay
+         self.shadow = {}
+         self.original = {}
+
+         # Register model parameters
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 self.shadow[name] = param.data.clone()
+
+     def __call__(self, model, num_updates=99999):
+         decay = min(self.decay, (1.0 + num_updates) / (10.0 + num_updates))
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 assert name in self.shadow
+                 new_average = \
+                     (1.0 - decay) * param.data + decay * self.shadow[name]
+                 self.shadow[name] = new_average.clone()
+
+     def assign(self, model):
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 assert name in self.shadow
+                 self.original[name] = param.data.clone()
+                 param.data = self.shadow[name]
+
+     def resume(self, model):
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 assert name in self.shadow
+                 param.data = self.original[name]
+
+
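The EMA helper follows the usual shadow-weight workflow; a hedged sketch (model, loader, train_step, and evaluate are placeholders, not names from this file):

    ema = EMA(model, decay=0.999)
    for step, batch in enumerate(loader):
        train_step(model, batch)          # forward / backward / optimizer step
        ema(model, num_updates=step)      # fold the new weights into the shadow copy
    ema.assign(model)                     # swap the EMA weights in for evaluation
    evaluate(model)
    ema.resume(model)                     # restore the raw training weights
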
+ def MLP(channels):
+     return Sequential(*[
+         Sequential(Linear(channels[i - 1], channels[i]), SiLU())
+         for i in range(1, len(channels))])
+
+
+ class Res(nn.Module):
+     def __init__(self, dim):
+         super(Res, self).__init__()
+
+         self.mlp = MLP([dim, dim, dim])
+
+     def forward(self, m):
+         m1 = self.mlp(m)
+         m_out = m1 + m
+         return m_out
+
+
+ def compute_idx(pos, edge_index):
+     pos_i = pos[edge_index[0]]
+     pos_j = pos[edge_index[1]]
+
+     d_ij = torch.norm(abs(pos_j - pos_i), dim=-1, keepdim=False).unsqueeze(-1) + 1e-5
+     v_ji = (pos_i - pos_j) / d_ij
+
+     unique, counts = torch.unique(edge_index[0], sorted=True, return_counts=True)  # get central values
+     full_index = torch.arange(0, edge_index[0].size()[0]).cuda().int()  # init full index
+
+     # Compute 1
+     repeat = torch.repeat_interleave(counts, counts)
+     counts_repeat1 = torch.repeat_interleave(full_index, repeat)  # 0,...,0,1,...,1,...
+
+     # Compute 2
+     split = torch.split(full_index, counts.tolist())  # split full index
+     index2 = list(edge_index[0].data.cpu().numpy())  # get repeat index
+     counts_repeat2 = torch.cat(itemgetter(*index2)(split), dim=0)  # 0,1,2,...,0,1,2,...
+
+     # Compute angle embeddings
+     v1 = v_ji[counts_repeat1.long()]
+     v2 = v_ji[counts_repeat2.long()]
+
+     angle = (v1 * v2).sum(-1).unsqueeze(-1)
+     angle = torch.clamp(angle, min=-1.0, max=1.0) + 1e-6 + 1.0
+
+     return counts_repeat1.long(), counts_repeat2.long(), angle
+
+
+ def Jn(r, n):
+     return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)
+
+
+ def Jn_zeros(n, k):
+     zerosj = np.zeros((n, k), dtype='float32')
+     zerosj[0] = np.arange(1, k + 1) * np.pi
+     points = np.arange(1, k + n) * np.pi
+     racines = np.zeros(k + n - 1, dtype='float32')
+     for i in range(1, n):
+         for j in range(k + n - 1 - i):
+             foo = brentq(Jn, points[j], points[j + 1], (i, ))
+             racines[j] = foo
+         points = racines
+         zerosj[i][:k] = racines[:k]
+
+     return zerosj
+
+
+ def spherical_bessel_formulas(n):
+     x = sym.symbols('x')
+
+     f = [sym.sin(x) / x]
+     a = sym.sin(x) / x
+     for i in range(1, n):
+         b = sym.diff(a, x) / x
+         f += [sym.simplify(b * (-x)**i)]
+         a = sym.simplify(b)
+     return f
+
+
+ def bessel_basis(n, k):
+     zeros = Jn_zeros(n, k)
+     normalizer = []
+     for order in range(n):
+         normalizer_tmp = []
+         for i in range(k):
+             normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2]
+         normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5
+         normalizer += [normalizer_tmp]
+
+     f = spherical_bessel_formulas(n)
+     x = sym.symbols('x')
+     bess_basis = []
+     for order in range(n):
+         bess_basis_tmp = []
+         for i in range(k):
+             bess_basis_tmp += [
+                 sym.simplify(normalizer[order][i] *
+                              f[order].subs(x, zeros[order, i] * x))
+             ]
+         bess_basis += [bess_basis_tmp]
+     return bess_basis
+
+
+ def sph_harm_prefactor(k, m):
+     return ((2 * k + 1) * np.math.factorial(k - abs(m)) /
+             (4 * np.pi * np.math.factorial(k + abs(m))))**0.5
+
+
+ def associated_legendre_polynomials(k, zero_m_only=True):
+     z = sym.symbols('z')
+     P_l_m = [[0] * (j + 1) for j in range(k)]
+
+     P_l_m[0][0] = 1
+     if k > 0:
+         P_l_m[1][0] = z
+
+     for j in range(2, k):
+         P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -
+                                     (j - 1) * P_l_m[j - 2][0]) / j)
+     if not zero_m_only:
+         for i in range(1, k):
+             P_l_m[i][i] = sym.simplify((1 - 2 * i) * P_l_m[i - 1][i - 1])
+             if i + 1 < k:
+                 P_l_m[i + 1][i] = sym.simplify(
+                     (2 * i + 1) * z * P_l_m[i][i])
+             for j in range(i + 2, k):
+                 P_l_m[j][i] = sym.simplify(
+                     ((2 * j - 1) * z * P_l_m[j - 1][i] -
+                      (i + j - 1) * P_l_m[j - 2][i]) / (j - i))
+
+     return P_l_m
+
+
+ def real_sph_harm(k, zero_m_only=True, spherical_coordinates=True):
+     if not zero_m_only:
+         S_m = [0]
+         C_m = [1]
+         for i in range(1, k):
+             x = sym.symbols('x')
+             y = sym.symbols('y')
+             S_m += [x * S_m[i - 1] + y * C_m[i - 1]]
+             C_m += [x * C_m[i - 1] - y * S_m[i - 1]]
+
+     P_l_m = associated_legendre_polynomials(k, zero_m_only)
+     if spherical_coordinates:
+         theta = sym.symbols('theta')
+         z = sym.symbols('z')
+         for i in range(len(P_l_m)):
+             for j in range(len(P_l_m[i])):
+                 if type(P_l_m[i][j]) != int:
+                     P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
+         if not zero_m_only:
+             phi = sym.symbols('phi')
+             for i in range(len(S_m)):
+                 S_m[i] = S_m[i].subs(x,
+                                      sym.sin(theta) * sym.cos(phi)).subs(
+                                          y,
+                                          sym.sin(theta) * sym.sin(phi))
+             for i in range(len(C_m)):
+                 C_m[i] = C_m[i].subs(x,
+                                      sym.sin(theta) * sym.cos(phi)).subs(
+                                          y,
+                                          sym.sin(theta) * sym.sin(phi))
+
+     Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)]
+     for i in range(k):
+         Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])
+
+     if not zero_m_only:
+         for i in range(1, k):
+             for j in range(1, i + 1):
+                 Y_func_l_m[i][j] = sym.simplify(
+                     2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
+         for i in range(1, k):
+             for j in range(1, i + 1):
+                 Y_func_l_m[i][-j] = sym.simplify(
+                     2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])
+
+     return Y_func_l_m
+
+
+ class BesselBasisLayer(torch.nn.Module):
+     def __init__(self, num_radial, cutoff, envelope_exponent=6):
+         super(BesselBasisLayer, self).__init__()
+         self.cutoff = cutoff
+         self.envelope = Envelope(envelope_exponent)
+
+         self.freq = torch.nn.Parameter(torch.Tensor(num_radial))
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         # Replaces the in-place variants:
+         # torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
+         # self.freq = torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)
+
+         # Build a temporary tensor holding 1..num_radial ...
+         tmp_tensor = torch.arange(1, self.freq.numel() + 1, dtype=self.freq.dtype, device=self.freq.device)
+
+         # ... then scale it by pi and assign the result to self.freq
+         self.freq.data = torch.mul(tmp_tensor, PI)
+
+     def forward(self, dist):
+         dist = dist.unsqueeze(-1) / self.cutoff
+         return self.envelope(dist) * (self.freq * dist).sin()
+
+
+ class SiLU(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, input):
+         return silu(input)
+
+
+ def silu(input):
+     return input * torch.sigmoid(input)
+
+
+ class Envelope(torch.nn.Module):
+     def __init__(self, exponent):
+         super(Envelope, self).__init__()
+         self.p = exponent
+         self.a = -(self.p + 1) * (self.p + 2) / 2
+         self.b = self.p * (self.p + 2)
+         self.c = -self.p * (self.p + 1) / 2
+
+     def forward(self, x):
+         p, a, b, c = self.p, self.a, self.b, self.c
+         x_pow_p0 = x.pow(p)
+         x_pow_p1 = x_pow_p0 * x
+         env_val = 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p1 * x
+
+         zero = torch.zeros_like(x)
+         return torch.where(x < 1, env_val, zero)
+
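The envelope is the DimeNet-style polynomial cutoff: for a scaled distance x = d / cutoff it evaluates 1/x + a*x^p + b*x^(p+1) + c*x^(p+2) below the cutoff and returns exactly 0 at or beyond it. With exponent p = 5 the coefficients are a = -21, b = 35, c = -15, so the value decays to 1 + a + b + c = 0 at x = 1. A quick illustrative check:

    env = Envelope(exponent=5)
    x = torch.tensor([0.25, 0.5, 0.99, 1.0, 1.5])
    print(env(x))   # positive below the cutoff, ~0 near it, exactly 0 from x = 1 on
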
+
+ class SphericalBasisLayer(torch.nn.Module):
+     def __init__(self, num_spherical, num_radial, cutoff=5.0,
+                  envelope_exponent=5):
+         super(SphericalBasisLayer, self).__init__()
+         assert num_radial <= 64
+         self.num_spherical = num_spherical
+         self.num_radial = num_radial
+         self.cutoff = cutoff
+         self.envelope = Envelope(envelope_exponent)
+
+         bessel_forms = bessel_basis(num_spherical, num_radial)
+         sph_harm_forms = real_sph_harm(num_spherical)
+         self.sph_funcs = []
+         self.bessel_funcs = []
+
+         x, theta = sym.symbols('x theta')
+         modules = {'sin': torch.sin, 'cos': torch.cos}
+         for i in range(num_spherical):
+             if i == 0:
+                 sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)
+                 self.sph_funcs.append(lambda x: torch.zeros_like(x) + sph1)
+             else:
+                 sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)
+                 self.sph_funcs.append(sph)
+             for j in range(num_radial):
+                 bessel = sym.lambdify([x], bessel_forms[i][j], modules)
+                 self.bessel_funcs.append(bessel)
+
+     def forward(self, dist, angle, idx_kj):
+         dist = dist / self.cutoff
+         rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)
+         rbf = self.envelope(dist).unsqueeze(-1) * rbf
+
+         cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)
+
+         n, k = self.num_spherical, self.num_radial
+         out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)
+         return out
+
+
+
+ msg_special_args = set([
+     'edge_index',
+     'edge_index_i',
+     'edge_index_j',
+     'size',
+     'size_i',
+     'size_j',
+ ])
+
+ aggr_special_args = set([
+     'index',
+     'dim_size',
+ ])
+
+ update_special_args = set([])
+
+
+ class MessagePassing(torch.nn.Module):
+     r"""Base class for creating message passing layers
+
+     .. math::
+         \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
+         \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
+         \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
+
+     where :math:`\square` denotes a differentiable, permutation invariant
+     function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
+     and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
+     MLPs.
+     See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html>`__
+     for the accompanying tutorial.
+
+     Args:
+         aggr (string, optional): The aggregation scheme to use
+             (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`).
+             (default: :obj:`"add"`)
+         flow (string, optional): The flow direction of message passing
+             (:obj:`"source_to_target"` or :obj:`"target_to_source"`).
+             (default: :obj:`"source_to_target"`)
+         node_dim (int, optional): The axis along which to propagate.
+             (default: :obj:`0`)
+     """
+     def __init__(self, aggr='add', flow='target_to_source', node_dim=0):
+         super(MessagePassing, self).__init__()
+
+         self.aggr = aggr
+         assert self.aggr in ['add', 'mean', 'max']
+
+         self.flow = flow
+         assert self.flow in ['source_to_target', 'target_to_source']
+
+         self.node_dim = node_dim
+         assert self.node_dim >= 0
+
+         self.__msg_params__ = inspect.signature(self.message).parameters
+         self.__msg_params__ = OrderedDict(self.__msg_params__)
+
+         self.__aggr_params__ = inspect.signature(self.aggregate).parameters
+         self.__aggr_params__ = OrderedDict(self.__aggr_params__)
+         self.__aggr_params__.popitem(last=False)
+
+         self.__update_params__ = inspect.signature(self.update).parameters
+         self.__update_params__ = OrderedDict(self.__update_params__)
+         self.__update_params__.popitem(last=False)
+
+         msg_args = set(self.__msg_params__.keys()) - msg_special_args
+         aggr_args = set(self.__aggr_params__.keys()) - aggr_special_args
+         update_args = set(self.__update_params__.keys()) - update_special_args
+
+         self.__args__ = set().union(msg_args, aggr_args, update_args)
+
+     def __set_size__(self, size, index, tensor):
+         if not torch.is_tensor(tensor):
+             pass
+         elif size[index] is None:
+             size[index] = tensor.size(self.node_dim)
+         elif size[index] != tensor.size(self.node_dim):
+             raise ValueError(
+                 (f'Encountered node tensor with size '
+                  f'{tensor.size(self.node_dim)} in dimension {self.node_dim}, '
+                  f'but expected size {size[index]}.'))
+
+     def __collect__(self, edge_index, size, kwargs):
+         i, j = (0, 1) if self.flow == "target_to_source" else (1, 0)
+         ij = {"_i": i, "_j": j}
+
+         out = {}
+         for arg in self.__args__:
+             if arg[-2:] not in ij.keys():
+                 out[arg] = kwargs.get(arg, inspect.Parameter.empty)
+             else:
+                 idx = ij[arg[-2:]]
+                 data = kwargs.get(arg[:-2], inspect.Parameter.empty)
+
+                 if data is inspect.Parameter.empty:
+                     out[arg] = data
+                     continue
+
+                 if isinstance(data, tuple) or isinstance(data, list):
+                     assert len(data) == 2
+                     self.__set_size__(size, 1 - idx, data[1 - idx])
+                     data = data[idx]
+
+                 if not torch.is_tensor(data):
+                     out[arg] = data
+                     continue
+
+                 self.__set_size__(size, idx, data)
+                 out[arg] = data.index_select(self.node_dim, edge_index[idx])
+
+         size[0] = size[1] if size[0] is None else size[0]
+         size[1] = size[0] if size[1] is None else size[1]
+
+         # Add special message arguments.
+         out['edge_index'] = edge_index
+         out['edge_index_i'] = edge_index[i]
+         out['edge_index_j'] = edge_index[j]
+         out['size'] = size
+         out['size_i'] = size[i]
+         out['size_j'] = size[j]
+
+         # Add special aggregate arguments.
+         out['index'] = out['edge_index_i']
+         out['dim_size'] = out['size_i']
+
+         return out
+
+     def __distribute__(self, params, kwargs):
+         out = {}
+         for key, param in params.items():
+             data = kwargs[key]
+             if data is inspect.Parameter.empty:
+                 if param.default is inspect.Parameter.empty:
+                     raise TypeError(f'Required parameter {key} is empty.')
+                 data = param.default
+             out[key] = data
+         return out
+
+     def propagate(self, edge_index, size=None, **kwargs):
+         r"""The initial call to start propagating messages.
+
+         Args:
+             edge_index (Tensor): The indices of a general (sparse) assignment
+                 matrix with shape :obj:`[N, M]` (can be directed or
+                 undirected).
+             size (list or tuple, optional): The size :obj:`[N, M]` of the
+                 assignment matrix. If set to :obj:`None`, the size will be
+                 automatically inferred and assumed to be quadratic.
+                 (default: :obj:`None`)
+             **kwargs: Any additional data which is needed to construct and
+                 aggregate messages, and to update node embeddings.
+         """
+
+         size = [None, None] if size is None else size
+         size = [size, size] if isinstance(size, int) else size
+         size = size.tolist() if torch.is_tensor(size) else size
+         size = list(size) if isinstance(size, tuple) else size
+         assert isinstance(size, list)
+         assert len(size) == 2
+
+         kwargs = self.__collect__(edge_index, size, kwargs)
+
+         msg_kwargs = self.__distribute__(self.__msg_params__, kwargs)
+
+         m = self.message(**msg_kwargs)
+         aggr_kwargs = self.__distribute__(self.__aggr_params__, kwargs)
+         m = self.aggregate(m, **aggr_kwargs)
+
+         update_kwargs = self.__distribute__(self.__update_params__, kwargs)
+         m = self.update(m, **update_kwargs)
+
+         return m
+
+     def message(self, x_j):  # pragma: no cover
+         r"""Constructs messages to node :math:`i` in analogy to
+         :math:`\phi_{\mathbf{\Theta}}` for each edge in
+         :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and
+         :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
+         Can take any argument which was initially passed to :meth:`propagate`.
+         In addition, tensors passed to :meth:`propagate` can be mapped to the
+         respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or
+         :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
+         """
+
+         return x_j
+
+     def aggregate(self, inputs, index, dim_size):  # pragma: no cover
+         r"""Aggregates messages from neighbors as
+         :math:`\square_{j \in \mathcal{N}(i)}`.
+
+         By default, delegates call to scatter functions that support
+         "add", "mean" and "max" operations specified in :meth:`__init__` by
+         the :obj:`aggr` argument.
+         """
+
+         return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
+
+     def update(self, inputs):  # pragma: no cover
+         r"""Updates node embeddings in analogy to
+         :math:`\gamma_{\mathbf{\Theta}}` for each node
+         :math:`i \in \mathcal{V}`.
+         Takes in the output of aggregation as first argument and any argument
+         which was initially passed to :meth:`propagate`.
+         """
+
+         return inputs
+
+
+ class MXMNet(nn.Module):
+     def __init__(self, dim=128, n_layer=6, cutoff=5.0, num_spherical=7, num_radial=6, envelope_exponent=5):
+         super(MXMNet, self).__init__()
+
+         self.dim = dim
+         self.n_layer = n_layer
+         self.cutoff = cutoff
+
+         self.embeddings = nn.Parameter(torch.ones((5, self.dim)))
+
+         self.rbf_l = BesselBasisLayer(16, 5, envelope_exponent)
+         self.rbf_g = BesselBasisLayer(16, self.cutoff, envelope_exponent)
+         self.sbf = SphericalBasisLayer(num_spherical, num_radial, 5, envelope_exponent)
+
+         self.rbf_g_mlp = MLP([16, self.dim])
+         self.rbf_l_mlp = MLP([16, self.dim])
+
+         self.sbf_1_mlp = MLP([num_spherical * num_radial, self.dim])
+         self.sbf_2_mlp = MLP([num_spherical * num_radial, self.dim])
+
+         self.global_layers = torch.nn.ModuleList()
+         for layer in range(self.n_layer):
+             self.global_layers.append(Global_MP(self.dim))
+
+         self.local_layers = torch.nn.ModuleList()
+         for layer in range(self.n_layer):
+             self.local_layers.append(Local_MP(self.dim))
+
+         self.init()
+
+     def init(self):
+         stdv = math.sqrt(3)
+         self.embeddings.data.uniform_(-stdv, stdv)
+
+     def indices(self, edge_index, num_nodes):
+         row, col = edge_index
+
+         value = torch.arange(row.size(0), device=row.device)
+         adj_t = SparseTensor(row=col, col=row, value=value,
+                              sparse_sizes=(num_nodes, num_nodes))
+
+         # Compute the node indices for two-hop angles
+         adj_t_row = adj_t[row]
+         num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)
+
+         idx_i = col.repeat_interleave(num_triplets)
+         idx_j = row.repeat_interleave(num_triplets)
+         idx_k = adj_t_row.storage.col()
+         mask = idx_i != idx_k
+         idx_i_1, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]
+
+         idx_kj = adj_t_row.storage.value()[mask]
+         idx_ji_1 = adj_t_row.storage.row()[mask]
+
+         # Compute the node indices for one-hop angles
+         adj_t_col = adj_t[col]
+
+         num_pairs = adj_t_col.set_value(None).sum(dim=1).to(torch.long)
+         idx_i_2 = row.repeat_interleave(num_pairs)
+         idx_j1 = col.repeat_interleave(num_pairs)
+         idx_j2 = adj_t_col.storage.col()
+
+         idx_ji_2 = adj_t_col.storage.row()
+         idx_jj = adj_t_col.storage.value()
+
+         return idx_i_1, idx_j, idx_k, idx_kj, idx_ji_1, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2
+
+     def forward_features(self, data):
+         x = data.x
+         edge_index = data.edge_index
+         pos = data.pos
+         batch = data.batch
+         # Initialize node embeddings
+         h = torch.index_select(self.embeddings, 0, x.long())
+
+         # Local layer ---------------------------------------------------------
+         # Get the edges and pairwise distances in the local layer
+         edge_index_l, _ = remove_self_loops(edge_index)  # edge index with self-loops removed
+         j_l, i_l = edge_index_l
+         dist_l = (pos[i_l] - pos[j_l]).pow(2).sum(dim=-1).sqrt()  # distance between the two nodes
+
+         # Global layer --------------------------------------------------------
+         # Get the edges and pairwise distances in the global layer;
+         # radius() returns the edge indices of all node pairs within the cutoff distance
+         row, col = radius(pos, pos, self.cutoff, batch, batch, max_num_neighbors=500)
+         edge_index_g = torch.stack([row, col], dim=0)
+         edge_index_g, _ = remove_self_loops(edge_index_g)
+         j_g, i_g = edge_index_g
+         dist_g = (pos[i_g] - pos[j_g]).pow(2).sum(dim=-1).sqrt()
+
+         # Compute the node indices for defining the angles
+         idx_i_1, idx_j, idx_k, idx_kj, idx_ji, idx_i_2, idx_j1, idx_j2, idx_jj, idx_ji_2 = self.indices(edge_index_l, num_nodes=h.size(0))
+
+         # Compute the two-hop angles
+         pos_ji_1, pos_kj = pos[idx_j] - pos[idx_i_1], pos[idx_k] - pos[idx_j]
+         a = (pos_ji_1 * pos_kj).sum(dim=-1)
+         b = torch.cross(pos_ji_1, pos_kj).norm(dim=-1)
+         angle_1 = torch.atan2(b, a)
+
+         # Compute the one-hop angles
+         pos_ji_2, pos_jj = pos[idx_j1] - pos[idx_i_2], pos[idx_j2] - pos[idx_j1]
+         a = (pos_ji_2 * pos_jj).sum(dim=-1)
+         b = torch.cross(pos_ji_2, pos_jj).norm(dim=-1)
+         angle_2 = torch.atan2(b, a)
+
+         # Get the RBF and SBF embeddings
+         rbf_g = self.rbf_g(dist_g)
+         rbf_l = self.rbf_l(dist_l)
+         sbf_1 = self.sbf(dist_l, angle_1, idx_kj)
+         sbf_2 = self.sbf(dist_l, angle_2, idx_jj)
+
+         rbf_g = self.rbf_g_mlp(rbf_g)
+         rbf_l = self.rbf_l_mlp(rbf_l)
+         sbf_1 = self.sbf_1_mlp(sbf_1)
+         sbf_2 = self.sbf_2_mlp(sbf_2)
+
+         # Perform the message passing schemes
+         node_sum = 0
+
+         for layer in range(self.n_layer):
+             h = self.global_layers[layer](h, rbf_g, edge_index_g)
+             h, t = self.local_layers[layer](h, rbf_l, sbf_1, sbf_2, idx_kj, idx_ji, idx_jj, idx_ji_2, edge_index_l)
+             node_sum += t
+
+         # Readout
+         output = global_add_pool(node_sum, batch)
+         return output.view(-1)
+
+     def loss(self, pred, label):
+         pred, label = pred.reshape(-1), label.reshape(-1)
+         return F.mse_loss(pred, label)
+
+
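A hedged end-to-end sketch of MXMNet on freshly embedded molecules (untrained weights, illustrative SMILES; batching with the torch_geometric DataLoader supplies the data.batch vector that forward_features reads):

    dataset = SmilesDataset(smiles=[])
    loader = DataLoader(dataset.get_data(["CCO", "CCN"]), batch_size=2)
    model = MXMNet(dim=128, n_layer=6, cutoff=5.0)
    for batch in loader:
        pred = model.forward_features(batch)   # one scalar prediction per molecule
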
+ class Global_MP(MessagePassing):
+
+     def __init__(self, dim):
+         super(Global_MP, self).__init__()
+         self.dim = dim
+
+         self.h_mlp = MLP([self.dim, self.dim])
+
+         self.res1 = Res(self.dim)
+         self.res2 = Res(self.dim)
+         self.res3 = Res(self.dim)
+         self.mlp = MLP([self.dim, self.dim])
+
+         self.x_edge_mlp = MLP([self.dim * 3, self.dim])
+         self.linear = nn.Linear(self.dim, self.dim, bias=False)
+
+     def forward(self, h, edge_attr, edge_index):
+         edge_index, _ = add_self_loops(edge_index, num_nodes=h.size(0))
+
+         res_h = h
+
+         # Integrate the Cross Layer Mapping inside the Global Message Passing
+         h = self.h_mlp(h)
+
+         # Message Passing operation
+         h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr)
+
+         # Update function f_u
+         h = self.res1(h)
+         h = self.mlp(h) + res_h
+         h = self.res2(h)
+         h = self.res3(h)
+
+         # Message Passing operation
+         h = self.propagate(edge_index, x=h, num_nodes=h.size(0), edge_attr=edge_attr)
+
+         return h
+
+     def message(self, x_i, x_j, edge_attr, edge_index, num_nodes):
+         num_edge = edge_attr.size()[0]
+
+         x_edge = torch.cat((x_i[:num_edge], x_j[:num_edge], edge_attr), -1)
+         x_edge = self.x_edge_mlp(x_edge)
+
+         x_j = torch.cat((self.linear(edge_attr) * x_edge, x_j[num_edge:]), dim=0)
+
+         return x_j
+
+     def update(self, aggr_out):
+         return aggr_out
+
+
+ class Local_MP(torch.nn.Module):
+     def __init__(self, dim):
+         super(Local_MP, self).__init__()
+         self.dim = dim
+
+         self.h_mlp = MLP([self.dim, self.dim])
+
+         self.mlp_kj = MLP([3 * self.dim, self.dim])
+         self.mlp_ji_1 = MLP([3 * self.dim, self.dim])
+         self.mlp_ji_2 = MLP([self.dim, self.dim])
+         self.mlp_jj = MLP([self.dim, self.dim])
+
+         self.mlp_sbf1 = MLP([self.dim, self.dim, self.dim])
+         self.mlp_sbf2 = MLP([self.dim, self.dim, self.dim])
+         self.lin_rbf1 = nn.Linear(self.dim, self.dim, bias=False)
+         self.lin_rbf2 = nn.Linear(self.dim, self.dim, bias=False)
+
+         self.res1 = Res(self.dim)
+         self.res2 = Res(self.dim)
+         self.res3 = Res(self.dim)
+
+         self.lin_rbf_out = nn.Linear(self.dim, self.dim, bias=False)
+
+         self.h_mlp = MLP([self.dim, self.dim])
+
+         self.y_mlp = MLP([self.dim, self.dim, self.dim, self.dim])
+         self.y_W = nn.Linear(self.dim, 1)
+
+     def forward(self, h, rbf, sbf1, sbf2, idx_kj, idx_ji_1, idx_jj, idx_ji_2, edge_index, num_nodes=None):
+         res_h = h
+
+         # Integrate the Cross Layer Mapping inside the Local Message Passing
+         h = self.h_mlp(h)
+
+         # Message Passing 1
+         j, i = edge_index
+         m = torch.cat([h[i], h[j], rbf], dim=-1)
+
+         m_kj = self.mlp_kj(m)
+         m_kj = m_kj * self.lin_rbf1(rbf)
+         m_kj = m_kj[idx_kj] * self.mlp_sbf1(sbf1)
+         m_kj = scatter(m_kj, idx_ji_1, dim=0, dim_size=m.size(0), reduce='add')
+
+         m_ji_1 = self.mlp_ji_1(m)
+
+         m = m_ji_1 + m_kj
+
+         # Message Passing 2 (index jj denotes j'i in the main paper)
+         m_jj = self.mlp_jj(m)
+         m_jj = m_jj * self.lin_rbf2(rbf)
+         m_jj = m_jj[idx_jj] * self.mlp_sbf2(sbf2)
+         m_jj = scatter(m_jj, idx_ji_2, dim=0, dim_size=m.size(0), reduce='add')
+
+         m_ji_2 = self.mlp_ji_2(m)
+
+         m = m_ji_2 + m_jj
+
+         # Aggregation
+         m = self.lin_rbf_out(rbf) * m
+         h = scatter(m, i, dim=0, dim_size=h.size(0), reduce='add')
+
+         # Update function f_u
+         h = self.res1(h)
+         h = self.h_mlp(h) + res_h
+         h = self.res2(h)
+         h = self.res3(h)
+
+         # Output Module
+         y = self.y_mlp(h)
+         y = self.y_W(y)
+
+         return h, y
+
+
+ class MXMModel(PreTrainedModel):
+     config_class = MXMConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.model = MXMNet(
+             dim=config.dim,
+             n_layer=config.n_layer,
+             cutoff=config.cutoff,
+             num_spherical=config.num_spherical,
+             num_radial=config.num_radial,
+             envelope_exponent=config.envelope_exponent,
+         )
+         self.process = SmilesDataset(
+             smiles=config.smiles,
+         )
+
+         self.mxm_model = None
+         self.dataset = None
+         self.output = None
+         self.data_loader = None
+         self.pred_data = None
+
+     def forward(self, tensor):
+         return self.model.forward_features(tensor)
+
+     def SmilesProcessor(self, smiles):
+         return self.process.get_data(smiles)
+
+     def predict_smiles(self, smiles, device: str = 'cpu', result_dir: str = './', **kwargs):
+         batch_size = kwargs.pop('batch_size', 1)
+         shuffle = kwargs.pop('shuffle', False)
+         drop_last = kwargs.pop('drop_last', False)
+         num_workers = kwargs.pop('num_workers', 0)
+
+         self.mxm_model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True).to(device)
+         self.mxm_model.eval()
+
+         self.dataset = self.process.get_data(smiles)
+         self.output = ""
+         self.output += ("predicted samples num: {}\n".format(len(self.dataset)))
+         self.output += ("predicted samples: {}\n".format(self.dataset[0]))
+         self.data_loader = DataLoader(self.dataset,
+                                       batch_size=batch_size,
+                                       shuffle=shuffle,
+                                       drop_last=drop_last,
+                                       num_workers=num_workers
+                                       )
+         self.pred_data = {
+             'smiles': [],
+             'pred': []
+         }
+
+         for batch in tqdm(self.data_loader):
+             batch = batch.to(device)
+             with torch.no_grad():
+                 self.pred_data['smiles'] += batch['smiles']
+                 self.pred_data['pred'] += self.mxm_model(batch).cpu().tolist()
+
+         pred = torch.tensor(self.pred_data['pred']).reshape(-1)
+         if device == 'cuda':
+             pred = pred.cpu().tolist()
+         self.pred_data['pred'] = pred
+         pred_df = pd.DataFrame(self.pred_data)
+         pred_df['pred'] = pred_df['pred'].apply(lambda x: round(x, 2))
+         self.output += ('-' * 40 + '\n' + 'predicted result: \n' + '{}\n'.format(pred_df))
+         self.output += ('-' * 40)
+
+         pred_df.to_csv(os.path.join(result_dir, 'gcn.csv'), index=False)
+         self.output += ('\nsave predicted result to {}\n'.format(os.path.join(result_dir, 'gcn.csv')))
+
+         return self.output
+
+
+ if __name__ == "__main__":
+     mxm_config = MXMConfig(
+         dim=128,
+         n_layer=6,
+         cutoff=5.0,
+         num_spherical=7,
+         num_radial=6,
+         envelope_exponent=5,
+         smiles=["C", "CC", "CCC"],
+         processor_class="SmilesProcessor"
+     )
+     # mxm_config.save_pretrained("custom-mxm")
+
+     mxmd = MXMModel(mxm_config)
+     mxmd.model.load_state_dict(torch.load(r'G:\Trans_MXM\mxm_model\mxm.pt'))
+     mxmd.save_pretrained("custom-mxm")
+
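After the upload, the whole pipeline can be driven through the hub; a hedged usage sketch (repo id as used inside predict_smiles above):

    from transformers import AutoModel

    model = AutoModel.from_pretrained("Huhujingjing/custom-mxm", trust_remote_code=True)
    print(model.predict_smiles(["C", "CC", "CCC"], device="cpu", batch_size=2))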
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:be586213d89c6bd2b8f29c6b508e3b3c6dbe5f3dbaf73422596c396c3fdae8b1
+ size 14813979