Huhujingjing committed on
Commit 2444fad
1 Parent(s): 9a0b5b4

Upload model

Files changed (2)
  1. configuration_gcn.py +33 -0
  2. modeling_gcn.py +90 -46
configuration_gcn.py ADDED
@@ -0,0 +1,33 @@
+ from transformers import PretrainedConfig
+ from typing import List
+ class GCNConfig(PretrainedConfig):
+     model_type = "gcn"
+
+     def __init__(
+         self,
+         input_feature: int=64,
+         emb_input: int=20,
+         hidden_size: int=64,
+         n_layers: int=6,
+         num_classes: int=1,
+
+         smiles: List[str] = None,
+         processor_class: str = "SmilesProcessor",
+         **kwargs,
+     ):
+
+         self.input_feature = input_feature  # the dimension of input feature
+         self.emb_input = emb_input  # the embedding dimension of input feature
+         self.hidden_size = hidden_size  # the hidden size of GCN
+         self.n_layers = n_layers  # the number of GCN layers
+         self.num_classes = num_classes  # the number of output classes
+
+         self.smiles = smiles  # process smiles
+         self.processor_class = processor_class
+
+         super().__init__(**kwargs)
+
+
+ if __name__ == "__main__":
+     gcn_config = GCNConfig(input_feature=64, emb_input=20, hidden_size=64, n_layers=6, num_classes=1, smiles=["C", "CC", "CCC"], processor_class="SmilesProcessor")
+     gcn_config.save_pretrained("custom-gcn")
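The added configuration_gcn.py follows the standard transformers custom-config pattern, so a directory written by save_pretrained can be reloaded later. A minimal sketch, assuming the file is importable and that a config was saved to "custom-gcn" as in the __main__ block above:

    # Sketch only: reload the config saved by the __main__ block above.
    from configuration_gcn import GCNConfig

    gcn_config = GCNConfig.from_pretrained("custom-gcn")
    print(gcn_config.hidden_size, gcn_config.n_layers)  # 64 6 with the defaults shown here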
modeling_gcn.py CHANGED
@@ -3,14 +3,45 @@ import torch.nn as nn
  import torch.nn.functional as F
  from torch_scatter import scatter
  from transformers import PreTrainedModel
- # from configuration_gcn import GCNConfig
+ from gcn_model.configuration_gcn import GCNConfig
  import torch
  from rdkit import Chem
  from rdkit.Chem import AllChem
  import torch
  from torch_geometric.data import Data
+ import os
+ from transformers import PretrainedConfig
+ from typing import List
+ from torch_geometric.loader import DataLoader
+ from tqdm import tqdm
+ import pandas as pd
+ from transformers import AutoModel
+ class GCNConfig(PretrainedConfig):
+     model_type = "gcn"
+
+     def __init__(
+         self,
+         input_feature: int=64,
+         emb_input: int=20,
+         hidden_size: int=64,
+         n_layers: int=6,
+         num_classes: int=1,
+
+         smiles: List[str] = None,
+         processor_class: str = "SmilesProcessor",
+         **kwargs,
+     ):
+
+         self.input_feature = input_feature  # the dimension of input feature
+         self.emb_input = emb_input  # the embedding dimension of input feature
+         self.hidden_size = hidden_size  # the hidden size of GCN
+         self.n_layers = n_layers  # the number of GCN layers
+         self.num_classes = num_classes  # the number of output classes

+         self.smiles = smiles  # process smiles
+         self.processor_class = processor_class

+         super().__init__(**kwargs)
  class SmilesDataset(torch.utils.data.Dataset):
      def __init__(self, smiles):
          self.smiles_list = smiles
@@ -145,36 +176,6 @@ class GCNNet(torch.nn.Module):

          return x.squeeze(-1)

-
- from transformers import PretrainedConfig
- from typing import List
- class GCNConfig(PretrainedConfig):
-     model_type = "gcn"
-
-     def __init__(
-         self,
-         input_feature: int=64,
-         emb_input: int=20,
-         hidden_size: int=64,
-         n_layers: int=6,
-         num_classes: int=1,
-         smiles: List[str] = None,
-         processor_class: str = "SmilesProcessor",
-         **kwargs,
-     ):
-
-         self.input_feature = input_feature  # the dimension of input feature
-         self.emb_input = emb_input  # the embedding dimension of input feature
-         self.hidden_size = hidden_size  # the hidden size of GCN
-         self.n_layers = n_layers  # the number of GCN layers
-         self.num_classes = num_classes  # the number of output classes
-
-         self.smiles = smiles  # process smiles
-         self.processor_class = processor_class
-
-         super().__init__(**kwargs)
-
-
  class GCNModel(PreTrainedModel):
      config_class = GCNConfig

@@ -192,27 +193,70 @@ class GCNModel(PreTrainedModel):
              smiles=config.smiles,
          )

+         self.gcn_model = None
+         self.dataset = None
+         self.output = None
+         self.data_loader = None
+         self.pred_data = None
+
      def forward(self, tensor):
          return self.model.forward_features(tensor)

-     def process_smiles(self, smiles):
-         return self.process.get_data(smiles)
+     # def process_smiles(self, smiles):
+     #     return self.process.get_data(smiles)

+     def predict_smiles(self, smiles, device: str='cpu', result_dir: str='./', **kwargs):


+         batch_size = kwargs.pop('batch_size', 1)
+         shuffle = kwargs.pop('shuffle', False)
+         drop_last = kwargs.pop('drop_last', False)
+         num_workers = kwargs.pop('num_workers', 0)
+
+         self.gcn_model = AutoModel.from_pretrained("Huhujingjing/custom-gcn", trust_remote_code=True).to(device)
+         self.gcn_model.eval()
+
+         self.dataset = self.process.get_data(smiles)
+         self.output = ""
+         self.output += ("predicted samples num: {}\n".format(len(self.dataset)))
+         self.output += ("predicted samples:{}\n".format(self.dataset[0]))
+         self.data_loader = DataLoader(self.dataset,
+                                       batch_size=batch_size,
+                                       shuffle=shuffle,
+                                       drop_last=drop_last,
+                                       num_workers=num_workers
+                                       )
+         self.pred_data = {
+             'smiles': [],
+             'pred': []
+         }
+
+         for batch in tqdm(self.data_loader):
+             batch = batch.to(device)
+             with torch.no_grad():
+                 self.pred_data['smiles'] += batch['smiles']
+                 self.pred_data['pred'] += self.gcn_model(batch).cpu().tolist()
+
+         pred = torch.tensor(self.pred_data['pred']).reshape(-1)
+         if device == 'cuda':
+             pred = pred.cpu().tolist()
+         self.pred_data['pred'] = pred
+         pred_df = pd.DataFrame(self.pred_data)
+         pred_df['pred'] = pred_df['pred'].apply(lambda x: round(x, 2))
+         self.output += ('-' * 40 + '\n' + 'predicted result: \n' + '{}\n'.format(pred_df))
+         self.output += ('-' * 40)
+
+         pred_df.to_csv(os.path.join(result_dir, 'gcn.csv'), index=False)
+         self.output += ('\nsave predicted result to {}\n'.format(os.path.join(result_dir, 'gcn.csv')))
+
+         return self.output
+

  if __name__ == "__main__":
-     pass
-     # gcn_config = GCNConfig(input_feature=64, emb_input=20, hidden_size=64, n_layers=6, num_classes=1, smiles=["C", "CC", "CCC"], processor_class="SmilesProcessor")
-     # gcn_config.save_pretrained("custom-gcn")
-     # gcn_config = GCNConfig.from_pretrained("custom-gcn")
-
-     # gcnd = GCNModel(gcn_config)
-     # gcnd.model.load_state_dict(torch.load(r'G:\Trans_MXM\gcn_model\gcn.pt'))
-     # gcnd.save_pretrained("custom-gcn")
-
-     # gcnd1 = GCNModelForMolecularPrediction(gcn_config)
-     #
-     # gcnd1.model.load_state_dict(torch.load(r'G:\Trans_MXM\gcn_model\gcn.pt'))
-     # gcnd1.save_pretrained("custom-gcn")
+     gcn_config = GCNConfig(input_feature=64, emb_input=20, hidden_size=64, n_layers=6, num_classes=1,
+                            smiles=["C", "CC", "CCC"], processor_class="SmilesProcessor")
+
+     gcnd = GCNModel(gcn_config)
+     gcnd.model.load_state_dict(torch.load(r'G:\Trans_MXM\gcn_model\gcn.pt'))
+     gcnd.save_pretrained("custom-gcn")

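For reference, a usage sketch of the predict_smiles helper added in this commit. It assumes the updated model has been uploaded as Huhujingjing/custom-gcn (the repo id hard-coded inside the method) and that torch, torch_geometric, torch_scatter, rdkit, pandas and tqdm are installed:

    # Sketch only: mirrors the calls introduced above; the repo id and output path come from the diff.
    from transformers import AutoModel

    model = AutoModel.from_pretrained("Huhujingjing/custom-gcn", trust_remote_code=True)
    report = model.predict_smiles(["C", "CC", "CCC"], device="cpu", result_dir="./", batch_size=2)
    print(report)  # predictions are also written to ./gcn.csv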