jinysun committed
Commit 8c7cadd · 1 Parent(s): febd0c4

Delete screen.py

Files changed (1)
  1. screen.py +0 -120
screen.py DELETED
@@ -1,120 +0,0 @@
- import os
- import pandas as pd
-
- import torch
- from torch.nn import functional as F
- from transformers import AutoTokenizer
-
- from util.utils import *
-
- from tqdm import tqdm
- from train import markerModel
-
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = '0 '
-
- device_count = torch.cuda.device_count()
- device_biomarker = torch.device('cuda' if torch.cuda.is_available() else "cpu")
-
- device = torch.device('cpu')
- a_model_name = 'DeepChem/ChemBERTa-10M-MLM'
- d_model_name = 'DeepChem/ChemBERTa-10M-MTR'
-
- tokenizer = AutoTokenizer.from_pretrained(a_model_name)
- d_tokenizer = AutoTokenizer.from_pretrained(d_model_name)
-
- #-- biomarker Model
- ##-- hyper param config file Load --##
- config = load_hparams('config/predict.json')
- config = DictX(config)
- model = markerModel(config.d_model_name, config.p_model_name,
-                     config.lr, config.dropout, config.layer_features, config.loss_fn, config.layer_limit, config.pretrained['chem'], config.pretrained['prot'])
-
- model = markerModel.load_from_checkpoint(config.load_checkpoint, strict=False)
- model.eval()
- model.freeze()
-
- if device_biomarker.type == 'cuda':
-     model = torch.nn.DataParallel(model)
-
- def get_marker(drug_inputs, prot_inputs):
-     output_preds = model(drug_inputs, prot_inputs)
-
-     predict = torch.squeeze(output_preds).tolist()
-
-     # output_preds = torch.relu(output_preds)
-     # predict = torch.tanh(output_preds)
-     # predict = predict.squeeze(dim=1).tolist()
-
-     return predict
-
-
- def marker_prediction(smiles, aas):
-     try:
-         aas_input = []
-         for ass_data in aas:
-             aas_input.append(' '.join(list(ass_data)))
-
-         a_inputs = tokenizer(smiles, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
-         # d_inputs = tokenizer(smiles, truncation=True, return_tensors="pt")
-         a_input_ids = a_inputs['input_ids'].to(device)
-         a_attention_mask = a_inputs['attention_mask'].to(device)
-         a_inputs = {'input_ids': a_input_ids, 'attention_mask': a_attention_mask}
-
-         d_inputs = d_tokenizer(aas_input, padding='max_length', max_length=510, truncation=True, return_tensors="pt")
-         # p_inputs = prot_tokenizer(aas_input, truncation=True, return_tensors="pt")
-         d_input_ids = d_inputs['input_ids'].to(device)
-         d_attention_mask = d_inputs['attention_mask'].to(device)
-         d_inputs = {'input_ids': d_input_ids, 'attention_mask': d_attention_mask}
-
-         output_predict = get_marker(a_inputs, d_inputs)
-
-         output_list = [{'acceptor': smiles[i], 'donor': aas[i], 'predict': output_predict[i]} for i in range(0, len(aas))]
-
-         return output_list
-
-     except Exception as e:
-         print(e)
-         return {'Error_message': e}
-
-
- def smiles_aas_test(file):
-
-     batch_size = 80
-     try:
-         datas = []
-         marker_list = []
-         marker_datas = []
-
-         smiles_aas = pd.read_csv(file)
-
-         ## -- 1 to 1 pair predict check -- ##
-         for data in smiles_aas.values:
-             mola = Chem.MolFromSmiles(data[2])
-             data[2] = Chem.MolToSmiles(mola, canonical=True)
-             mola = Chem.MolFromSmiles(data[1])
-             data[1] = Chem.MolToSmiles(mola, canonical=True)
-             marker_datas.append([data[2], data[1]])
-             if len(marker_datas) == batch_size:
-                 marker_list.append(list(marker_datas))
-                 marker_datas.clear()
-
-         if len(marker_datas) != 0:
-             marker_list.append(list(marker_datas))
-             marker_datas.clear()
-
-         for marker_datas in tqdm(marker_list, total=len(marker_list)):
-             smiles_d, smiles_a = zip(*marker_datas)
-             output_pred = marker_prediction(list(smiles_d), list(smiles_a))
-             if len(datas) == 0:
-                 datas = output_pred
-             else:
-                 datas = datas + output_pred
-         datas = pd.DataFrame(datas)
-
-         return datas
-
-     except Exception as e:
-         print(e)
-         return {'Error_message': e}
-
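Note: the deleted screen.py called Chem.MolFromSmiles / Chem.MolToSmiles without importing rdkit itself, so Chem presumably came in through the from util.utils import * wildcard. Its entry point was smiles_aas_test(file), which reads a CSV of SMILES pairs, canonicalizes each pair with RDKit, runs predictions in batches of 80, and returns a pandas DataFrame of acceptor / donor / predict rows. Below is a minimal usage sketch, assuming the file is restored alongside config/predict.json and the checkpoint it references; pairs.csv is a hypothetical input file, not part of this repository.

# Hypothetical driver for the deleted screen.py (not part of this commit).
# screen.py builds and loads the model at import time, so config/predict.json
# and its checkpoint must exist; "pairs.csv" is an assumed file name.
import pandas as pd

from screen import smiles_aas_test

predictions = smiles_aas_test("pairs.csv")

# On success this is a DataFrame with 'acceptor', 'donor' and 'predict'
# columns; on failure screen.py returns a dict with an 'Error_message' key.
if isinstance(predictions, pd.DataFrame):
    predictions.to_csv("predictions.csv", index=False)
else:
    print(predictions)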