Songyou committed on
Commit
9e93243
·
1 Parent(s): d6fdf05

add LLM files

Browse files
Files changed (7) hide show
  1. generate.py +286 -0
  2. utils/__init__.py +0 -0
  3. utils/chem.py +65 -0
  4. utils/file.py +29 -0
  5. utils/log.py +32 -0
  6. utils/plot.py +84 -0
  7. utils/torch_util.py +32 -0
generate.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pickle as pkl
3
+ import os
4
+ import argparse
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ import utils.chem as uc
10
+ import utils.torch_util as ut
11
+ import utils.log as ul
12
+ import utils.plot as up
13
+ import configuration.config_default as cfgd
14
+ import models.dataset as md
15
+ import preprocess.vocabulary as mv
16
+ import configuration.opts as opts
17
+ from models.transformer.module.decode import decode
18
+ from models.transformer.encode_decode.model import EncoderDecoder
19
+ # from models.seq2seq.model import Model
20
+
21
def prepare_input(opt):
    '''Check the generation input file and return the test file name.

    Reads the test CSV once so a missing or unparsable file fails fast here,
    before any model loading happens.

    :param opt: parsed options; uses ``opt.data_path`` and ``opt.test_file_name``.
    :return: the test file name (CSV basename without the ``.csv`` suffix).
    '''
    # TODO: molecule-to-molecule generation used to expand each input row over
    # all Delta_pki bins here and write a *_prepared.csv; restore that logic if
    # inputs without a Delta_pki column must be supported again.
    pd.read_csv(os.path.join(opt.data_path, opt.test_file_name + '.csv'), sep=",")
    return opt.test_file_name
37
+
38
class GenerateRunner():
    """Runs molecule generation: loads the vocabulary and a trained model,
    samples SMILES for every row of the test set, and writes the results
    to ``generated_molecules.csv`` in the save directory.
    """

    def __init__(self, opt):
        """Set up the output directory, logging, and the SMILES vocabulary.

        :param opt: parsed command-line options (see configuration.opts).
        """
        # self.save_path = os.path.join('experiments', opt.save_directory, opt.test_file_name,
        #                               f'evaluation_{opt.epoch}')
        path = Path(os.path.join(opt.save_directory))
        path.mkdir(parents=True, exist_ok=True)
        self.save_path = os.path.join(path)
        # Remember whether results already exist so generate() can skip work.
        self.exist_flag = Path(f'{self.save_path}/generated_molecules.csv').exists()
        self.overwrite = opt.overwrite
        self.dev_no = opt.dev_no
        # LOG is shared module-wide: sample() logs through it as well.
        global LOG
        LOG = ul.get_logger(name="generate",
                            log_path=os.path.join(self.save_path, 'generate.log'))
        LOG.info(opt)
        LOG.info("Save directory: {}".format(self.save_path))

        # Load vocabulary
        with open(os.path.join(opt.vocab_path, 'vocab.pkl'), "rb") as input_file:
            vocab = pkl.load(input_file)
        self.vocab = vocab
        self.tokenizer = mv.SMILESTokenizer()

    def initialize_dataloader(self, opt, vocab, test_file):
        """
        Initialize dataloader
        :param opt: parsed options; uses ``opt.data_path`` and ``opt.batch_size``.
        :param vocab: vocabulary
        :param test_file: test_file_name (CSV basename without extension)
        :return: an unshuffled torch DataLoader over the test set in prediction mode
        """

        # Read test
        data = pd.read_csv(os.path.join(opt.data_path, test_file + '.csv'), sep=",")
        dataset = md.Dataset(data=data, vocabulary=vocab, tokenizer=self.tokenizer, prediction_mode=True)
        dataloader = torch.utils.data.DataLoader(dataset, opt.batch_size,
                                                 shuffle=False, collate_fn=md.Dataset.collate_fn)
        return dataloader

    def generate(self, opt):
        """Sample ``opt.num_samples`` SMILES per test row and save them to CSV.

        :param opt: parsed options (model_choice, model_path, epoch, decode_type,
            num_samples, batch_size, ...).
        """
        if not self.overwrite and self.exist_flag:
            print('GENERATED MOL EXIST, SKIP GENERATING!')
            return
        # set device
        #device = ut.allocate_gpu()
        # torch.cuda.set_device(1)
        # current_device = torch.cuda.current_device()
        # print("current CUDA device index:", current_device)
        device = torch.device(f'cuda:{self.dev_no}')
        # Build the test dataloader.
        dataloader_test = self.initialize_dataloader(opt, self.vocab, opt.test_file_name)

        # Load model
        file_name = os.path.join(opt.model_path, f'model_{opt.epoch}.pt')
        if opt.model_choice == 'transformer':
            model = EncoderDecoder.load_from_file(file_name)
            model.to(device)
            model.eval()
        elif opt.model_choice == 'seq2seq':
            # NOTE(review): the seq2seq Model import is commented out at the top
            # of the file, so this branch currently raises NameError — confirm
            # whether seq2seq is still supported.
            model = Model.load_from_file(file_name, evaluation_mode=True)
            # move to GPU
            model.network.encoder.to(device)
            model.network.decoder.to(device)
        # TODO: could inputs exceed this length? If the model breaks, adjust the
        # length (rule of thumb: keep it a multiple of 2).
        max_len = cfgd.DATA_DEFAULT['max_sequence_length']
        df_list = []
        sampled_smiles_list = []
        for j, batch in enumerate(ul.progress_bar(dataloader_test, total=len(dataloader_test))):

            # df is the dataframe carrying the raw input rows of this batch
            src, source_length, _, src_mask, _, _, df = batch

            # Move to GPU
            src = src.to(device)
            src_mask = src_mask.to(device)
            smiles = self.sample(opt.model_choice, model, src, src_mask,
                                 source_length,
                                 opt.decode_type,
                                 num_samples=opt.num_samples,
                                 max_len=max_len,
                                 device=device)

            df_list.append(df)
            sampled_smiles_list.extend(smiles)

        # prepare dataframe: one column per generated sample, row-aligned with
        # the original test rows.
        data_sorted = pd.concat(df_list)
        sampled_smiles_list = np.array(sampled_smiles_list)

        for i in range(opt.num_samples):
            data_sorted['Predicted_smi_{}'.format(i + 1)] = sampled_smiles_list[:, i]

        result_path = os.path.join(self.save_path, "generated_molecules.csv")
        LOG.info("Save to {}".format(result_path))
        data_sorted.to_csv(result_path, index=False)

    def sample(self, model_choice, model, src, src_mask, source_length, decode_type, num_samples=10,
               max_len=cfgd.DATA_DEFAULT['max_sequence_length'],
               device=None):
        """Decode repeatedly until every source molecule has ``num_samples``
        unique, valid SMILES (or ``max_trials`` attempts have been spent).

        :param model_choice: 'transformer' or 'seq2seq'.
        :param model: the loaded model.
        :param src: encoded source sequences, property tokens first — assumes
            shape [batch, seq]; TODO confirm against the dataset collate_fn.
        :param src_mask: attention mask for ``src``.
        :param source_length: source sequence lengths (used by seq2seq only).
        :param decode_type: 'greedy' or sampling-based decoding.
        :param num_samples: unique valid samples desired per source molecule.
        :param max_len: maximum decoded sequence length.
        :param device: torch device to run on.
        :return: nested list [batch][num_samples] of generated SMILES strings.
        """
        batch_size = src.shape[0]
        num_valid_batch = np.zeros(batch_size)  # current number of unique and valid samples out of total sampled
        num_valid_batch_total = np.zeros(batch_size)  # current number of sampling times no matter unique or valid
        num_valid_batch_desired = np.asarray([num_samples] * batch_size)
        unique_set_num_samples = [set() for i in range(batch_size)]  # for each starting molecule
        batch_index = torch.LongTensor(range(batch_size))
        batch_index_current = torch.LongTensor(range(batch_size)).to(device)
        # TODO: this appears to be unused?
        start_mols = []
        # Original note said "zeros correspond to a token RDKit treats as valid",
        # NOTE(review): but the tensor is filled with ones — confirm which token
        # id is the pad/blank here.
        sequences_all = torch.ones((num_samples, batch_size, max_len))
        sequences_all = sequences_all.type(torch.LongTensor)
        max_trials = 100000  # Maximum trials for sampling
        current_trials = 0

        # Greedy decoding is a single deterministic attempt per molecule:
        # either it yields a valid SMILES or the slot stays empty.
        if decode_type == 'greedy':
            max_trials = 1

        # Set of unique starting molecules
        if src is not None:
            # NOTE(review): may need fixing — delta_value is not necessarily at
            # the first position.
            start_ind = len(cfgd.PROPERTIES)
            for ibatch in range(batch_size):
                source_smi = self.tokenizer.untokenize(self.vocab.decode(src[ibatch].tolist()[start_ind:]))
                source_smi = uc.get_canonical_smile(source_smi)
                if source_smi:
                    # Seed the unique set with the source so duplicates of it are
                    # rejected later. TODO: not quite right — this is an mmpdb
                    # fragment, not a complete SMILES.
                    unique_set_num_samples[ibatch].add(source_smi)
                    start_mols.append(source_smi)

        with torch.no_grad():
            if model_choice == 'seq2seq':
                encoder_outputs, decoder_hidden = model.network.encoder(src, source_length)
            while not all(num_valid_batch >= num_valid_batch_desired) and current_trials < max_trials:
                current_trials += 1

                # batch input for current trial
                if src is not None:
                    # On the first pass this selects the whole batch; later
                    # passes keep only the still-unfinished molecules.
                    src_current = src.index_select(0, batch_index_current)
                if src_mask is not None:
                    mask_current = src_mask.index_select(0, batch_index_current)
                batch_size = src_current.shape[0]

                # sample molecule
                if model_choice == 'transformer':
                    sequences = decode(model, src_current, mask_current, max_len, decode_type)
                    # Pad the time dimension so every trial has length max_len.
                    padding = (0, max_len-sequences.shape[1],
                               0, 0)
                    sequences = torch.nn.functional.pad(sequences, padding)
                elif model_choice == 'seq2seq':
                    sequences = self.sample_seq2seq(model, mask_current, batch_index_current, decoder_hidden,
                                                    encoder_outputs, max_len, device)
                else:
                    LOG.info('Specify transformer or seq2seq for model_choice')

                # Check valid and unique
                smiles = []
                is_valid_index = []
                batch_index_map = dict(zip(list(range(batch_size)), batch_index_current))
                # Valid, ibatch index is different from original, need map back
                for ibatch in range(batch_size):
                    seq = sequences[ibatch]
                    smi = self.tokenizer.untokenize(self.vocab.decode(seq.cpu().numpy()))
                    smi = uc.get_canonical_smile(smi)
                    smiles.append(smi)
                    # valid and not same as starting molecules
                    if uc.is_valid(smi):
                        is_valid_index.append(ibatch)
                        # total sampled times
                        num_valid_batch_total[batch_index_map[ibatch]] += 1

                # Check if duplicated and update num_valid_batch and unique
                for good_index in is_valid_index:
                    index_in_original_batch = batch_index_map[good_index]
                    if smiles[good_index] not in unique_set_num_samples[index_in_original_batch]:
                        unique_set_num_samples[index_in_original_batch].add(smiles[good_index])
                        num_valid_batch[index_in_original_batch] += 1

                        sequences_all[int(num_valid_batch[index_in_original_batch] - 1), index_in_original_batch, :] = \
                            sequences[good_index]

                not_completed_index = np.where(num_valid_batch < num_valid_batch_desired)[0]
                # Keep sampling only for source molecules that still need more.
                if len(not_completed_index) > 0:
                    batch_index_current = batch_index.index_select(0, torch.LongTensor(not_completed_index)).to(device)

        # Convert to SMILES
        smiles_list = []  # [batch, topk]
        seqs = np.asarray(sequences_all.numpy())
        # [num_sample, batch_size, max_len]
        batch_size = len(seqs[0])
        for ibatch in range(batch_size):
            topk_list = []
            for k in range(num_samples):
                seq = seqs[k, ibatch, :]
                topk_list.extend([self.tokenizer.untokenize(self.vocab.decode(seq))])
            smiles_list.append(topk_list)

        return smiles_list

    def sample_seq2seq(self, model, mask, batch_index_current, decoder_hidden, encoder_outputs, max_len, device):
        """Sample one token at a time (multinomial) from the seq2seq decoder.

        :param model: seq2seq model exposing ``model.network.decoder``.
        :param mask: source attention mask for the *full* batch.
        :param batch_index_current: indices of the still-unfinished molecules.
        :param decoder_hidden: initial decoder state(s) from the encoder.
        :param encoder_outputs: encoder outputs used for attention.
        :param max_len: number of decoding steps.
        :param device: torch device.
        :return: [batch, max_len] LongTensor of sampled token ids.
        """
        # batch size will change when some of the generated molecules are valid
        encoder_outputs_current = encoder_outputs.index_select(0, batch_index_current)
        batch_size = encoder_outputs_current.shape[0]

        # start token
        start_token = torch.zeros(batch_size, dtype=torch.long)
        start_token[:] = self.vocab["^"]
        decoder_input = start_token.to(device)
        sequences = []
        mask = torch.squeeze(mask, 1).to(device)

        # initial decoder hidden states — a tuple presumably means an LSTM
        # (h, c) pair, a plain tensor a GRU state; TODO confirm against the
        # decoder implementation.
        if isinstance(decoder_hidden, tuple):
            decoder_hidden_current = (decoder_hidden[0].index_select(1, batch_index_current),
                                      decoder_hidden[1].index_select(1, batch_index_current))
        else:
            decoder_hidden_current = decoder_hidden.index_select(1, batch_index_current)
        for i in range(max_len):
            logits, decoder_hidden_current = model.network.decoder(decoder_input.unsqueeze(1),
                                                                   decoder_hidden_current,
                                                                   encoder_outputs_current, mask)
            logits = logits.squeeze(1)
            probabilities = logits.softmax(dim=1)  # torch.Size([batch_size, vocab_size])
            topi = torch.multinomial(probabilities, 1)  # torch.Size([batch_size, 1])
            decoder_input = topi.view(-1).detach()
            sequences.append(decoder_input.view(-1, 1))

        sequences = torch.cat(sequences, 1)
        return sequences
270
+
271
def run_main():
    """Entry point: parse CLI options, validate the test input, run generation."""
    parser = argparse.ArgumentParser(
        description='generate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    opts.generate_opts(parser)
    args = parser.parse_args()

    # Validate the test CSV and resolve the file name before building the runner.
    args.test_file_name = prepare_input(args)

    GenerateRunner(args).generate(args)


if __name__ == "__main__":
    run_main()
utils/__init__.py ADDED
File without changes
utils/chem.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RDKit util functions.
3
+ """
4
+ import rdkit.Chem as rkc
5
+ from rdkit.Chem import AllChem
6
+ from rdkit import DataStructs
7
+
8
def disable_rdkit_logging():
    """
    Silence RDKit's verbose default logging.
    """
    import rdkit.RDLogger as rkl
    import rdkit.rdBase as rkrb

    rkl.logger().setLevel(rkl.ERROR)
    rkrb.DisableLog('rdApp.error')


# Quiet RDKit as soon as this module is imported.
disable_rdkit_logging()
21
+
22
def to_fp_ECFP(smi):
    """Return the ECFP (Morgan, radius-2) fingerprint of *smi*, or None for
    missing/unparsable input."""
    if not smi:
        return None
    mol = rkc.MolFromSmiles(smi)
    return None if mol is None else AllChem.GetMorganFingerprint(mol, 2)
28
+
29
def tanimoto_similarity_pool(args):
    """Adapter for multiprocessing pools: unpack an argument tuple and
    forward it to tanimoto_similarity."""
    return tanimoto_similarity(*args)
31
+
32
def tanimoto_similarity(smi1, smi2):
    """Return the Tanimoto similarity between two SMILES strings.

    :param smi1: first SMILES string (may be None/empty/non-string).
    :param smi2: second SMILES string (may be None/empty/non-string).
    :return: similarity in [0, 1], or None if either input is missing or
        does not parse to a molecule.
    """
    # isinstance() replaces the original type(x)==str check; a non-empty
    # string is already implied by truthiness, so len()>0 was redundant.
    fp1 = to_fp_ECFP(smi1) if isinstance(smi1, str) and smi1 else None
    fp2 = to_fp_ECFP(smi2) if isinstance(smi2, str) and smi2 else None

    if fp1 is not None and fp2 is not None:
        return DataStructs.TanimotoSimilarity(fp1, fp2)
    return None
43
+
44
def is_valid(smi):
    """Return 1 when *smi* parses into an RDKit molecule, 0 otherwise."""
    return int(bool(to_mol(smi)))
46
+
47
def to_mol(smi):
    """
    Creates a Mol object from a SMILES string.
    :param smi: SMILES string.
    :return: A Mol object or None if it's not valid.
    """
    # Guard clauses: only non-empty real strings (and not the pandas-style
    # 'nan' placeholder) are handed to RDKit; everything else yields None.
    if not isinstance(smi, str):
        return None
    if not smi or smi == 'nan':
        return None
    return rkc.MolFromSmiles(smi)
55
+
56
def get_canonical_smile(smile):
    """Return the canonical (non-isomeric) form of a SMILES string.

    :param smile: SMILES string; ``None`` or the literal string 'None'
        yield ``None``.
    :return: canonical SMILES, or ``None`` when the input is missing or
        does not parse.
    """
    # Guard against missing input: the original only screened the literal
    # string 'None', so passing a real None crashed MolFromSmiles.
    if smile is None or smile == 'None':
        return None
    mol = rkc.MolFromSmiles(smile)
    if mol is None:
        return None
    return rkc.MolToSmiles(mol, canonical=True, doRandom=False, isomericSmiles=False)
utils/file.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
def make_directory(file, is_dir=True):
    """Create a directory (and any missing parents).

    :param file: a directory path, or a file path when *is_dir* is False.
    :param is_dir: when False, only the parent directories of *file*
        are created.
    """
    # os.makedirs(exist_ok=True) replaces the original segment-by-segment
    # loop and avoids the exists()/makedirs() race between check and create.
    target = file if is_dir else os.path.dirname(file)
    if target and not os.path.exists(target):
        os.makedirs(target, exist_ok=True)
10
+
11
def get_parent_dir(file):
    """Return the parent directory of *file* ('' when there is none).

    :param file: a '/'-separated path.
    :return: the parent directory path; '' for a bare file name.
    """
    # Equivalent to the original split/join loop for well-formed paths,
    # delegated to the standard library.
    return os.path.dirname(file)
19
+
20
def chunkIt(seq, num):
    """Split *seq* into *num* chunks of roughly equal size.

    Boundaries are computed with float steps truncated to int, matching the
    original behavior (small sequences may yield empty leading chunks).
    """
    step = len(seq) / float(num)
    bounds = []
    pos = 0.0
    while pos < len(seq):
        bounds.append((int(pos), int(pos + step)))
        pos += step
    return [seq[lo:hi] for lo, hi in bounds]
utils/log.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import tqdm
3
+
4
+ import utils.file as uf
5
+
6
+
7
def get_logger(name, log_path, isMain=False, level=logging.INFO):
    """Create (or fetch) a configured logger.

    :param name: logger name; loggers are process-wide singletons per name.
    :param log_path: log file path — only used when *isMain* is True.
    :param isMain: when True, also attach a FileHandler writing to *log_path*.
    :param level: logging level, default INFO.
    :return: the configured logging.Logger.
    """
    formatter = logging.Formatter(
        fmt="%(asctime)s: %(module)s.%(funcName)s +%(lineno)s: %(levelname)-8s %(message)s",
        datefmt="%H:%M:%S"
    )

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Guard against duplicate handlers: logging.getLogger returns the same
    # object for the same name, and the original attached fresh handlers on
    # every call, duplicating each log line.
    if not logger.handlers:
        # Logging to console
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

        # Logging to a file
        if isMain:
            uf.make_directory(log_path, is_dir=False)
            file_handler = logging.FileHandler(log_path)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
29
+
30
+
31
def progress_bar(iterable, total, **kwargs):
    """Wrap *iterable* in an ASCII tqdm progress bar of length *total*."""
    bar = tqdm.tqdm(iterable=iterable, total=total, ascii=True, **kwargs)
    return bar
utils/plot.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+
4
+ import matplotlib as mpl
5
+ import matplotlib.pyplot as plt
6
+ from scipy.stats import gaussian_kde
7
+ mpl.use('Agg')
8
+
9
+
10
def hist_box(data_frame, field, name="hist_box", path="./", title=None):
    """Draw a histogram and a box plot of *field* side by side; save as PNG."""
    title = title or field
    fig, (ax_hist, ax_box) = plt.subplots(1, 2, figsize=(10, 4))

    data_frame[field].plot.hist(bins=100, title=title, ax=ax_hist)
    data_frame.boxplot(field, ax=ax_box)
    plt.title(title)
    plt.suptitle("")

    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
21
+
22
def hist(data_frame, field, name="hist", path="./", title=None):
    """Plot a simple histogram of *field* and save it as <path>/<name>.png."""
    title = title or field

    plt.hist(data_frame[field])
    plt.title(title)

    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
31
+
32
def hist_box_list(data_list, name="hist_box", path="./", title=None):
    """Draw a histogram and a box plot of *data_list* side by side; save as PNG."""
    fig, (ax_hist, ax_box) = plt.subplots(1, 2, figsize=(10, 4))

    ax_hist.hist(data_list, bins=100)
    ax_hist.set_title(title)
    ax_box.boxplot(data_list)
    ax_box.set_title(title)

    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
42
+
43
def scatter_hist(x, y, name, path, field=None, lims=None):
    """Scatter x vs y (density-colored) next to overlaid histograms; save PNG.

    :param x: first numeric series — assumed array-like supporting `+` with
        a numpy array (TODO confirm callers pass arrays, not lists).
    :param y: second numeric series, same length as x.
    :param name: output file name (without extension).
    :param path: output directory.
    :param field: label key; substrings 'delta'/'single' combined with
        'data'/'predict' select the axis labels. May be None.
    :param lims: shared axis limits [lo, hi]; derived from the data when None.
    """
    # Guard: the original crashed with `in None` when field was left at its
    # default, because the substring tests below ran unconditionally.
    if field is None:
        field = ""
    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    n = len(x)
    # Tiny jitter keeps gaussian_kde from failing on degenerate (constant) data.
    xy = np.vstack([x + 0.00001 * np.random.rand(n), y + 0.00001 * np.random.rand(n)])
    z = gaussian_kde(xy)(xy)
    axs[0].scatter(x, y, c=z, s=3, marker='o', alpha=0.2)
    lims = [np.min([axs[0].get_xlim(), axs[0].get_ylim()]), np.max([axs[0].get_xlim(), axs[0].get_ylim()])] if lims is None else lims
    # Identity line for reference.
    axs[0].plot(lims, lims, 'k-', alpha=0.75)
    axs[0].set_aspect('equal')
    axs[0].set_xlim(lims)
    axs[0].set_ylim(lims)
    xlabel = ""
    ylabel = ""
    if "delta" in field:
        if "data" in field:
            axs[0].set_xlabel(r'$\Delta LogD$ (experimental)')
            axs[0].set_ylabel(r'$\Delta LogD$ (calculated)')
            xlabel = 'Delta LogD (experimental)'
            ylabel = 'Delta LogD (calculated)'
        elif "predict" in field:
            axs[0].set_xlabel(r'$\Delta LogD$ (desirable)')
            axs[0].set_ylabel(r'$\Delta LogD$ (generated)')
            xlabel = 'Delta LogD (desirable)'
            ylabel = 'Delta LogD (generated)'
    if "single" in field:
        if "data" in field:
            xlabel, ylabel = 'LogD (experimental)', 'LogD (calculated)'
            axs[0].set_xlabel(xlabel)
            axs[0].set_ylabel(ylabel)
        elif "predict" in field:
            xlabel, ylabel = 'LogD (desirable)', 'LogD (generated)'
            axs[0].set_xlabel(xlabel)
            axs[0].set_ylabel(ylabel)
    bins = np.histogram(np.hstack((x, y)), bins=100)[1]  # get the bin edges
    kwargs = dict(histtype='stepfilled', alpha=0.3, density=False, bins=bins, stacked=False)
    axs[1].hist(x, **kwargs, color='b', label=xlabel)
    axs[1].hist(y, **kwargs, color='r', label=ylabel)
    plt.ylabel('Frequency')
    plt.legend(loc='upper left')
    plt.savefig(os.path.join(path, '{}.png'.format(name)), bbox_inches='tight')
    plt.close()
84
+
utils/torch_util.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch related util functions
3
+ """
4
+ import torch
5
+ import os
6
def allocate_gpu(id=None):
    '''
    Choose a usable GPU on this node.

    :param id: explicit device index; when given, no probing is done.
    :return: a torch.device — the requested/first usable CUDA device, or the
        CPU device when no GPU can be allocated.
    '''
    if id is not None:
        return torch.device("cuda:{}".format(str(id)))
    v = torch.empty(1)
    for i in range(8):  # probe up to 8 devices; adjust if the node has more
        dev_name = "cuda:{}".format(str(i))
        try:
            v = v.to(dev_name)
        except (RuntimeError, AssertionError):
            # Device busy, out of range, or CUDA unavailable — narrow catch
            # instead of the original blanket `except Exception: pass`.
            continue
        print("Allocating cuda:{}.".format(i))
        return v.device
    print("CUDA error: all CUDA-capable devices are busy or unavailable")
    # v was never moved, so this is the CPU device.
    return v.device
25
+
26
def allocate_gpu_multi(id=None):
    """Return a device for multi-GPU runs: cuda:1 when CUDA is available,
    otherwise the CPU.

    NOTE(review): the *id* parameter is ignored, and the two
    CUDA_VISIBLE_DEVICES assignments overwrite each other (the final value is
    '0' while the returned device is cuda:1) — confirm the intended mapping.
    """
    # NOTE(review): this assignment is immediately overwritten below and has
    # no lasting effect — presumably leftover experimentation.
    os.environ['CUDA_VISIBLE_DEVICES']='1'
    device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    os.environ['CUDA_VISIBLE_DEVICES']='0'
    device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    return device