ZhaohanM committed
Commit 7c46397
1 Parent(s): 5b5fc06
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ step_300_model.bin filter=lfs diff=lfs merge=lfs -text
+ disgenet_latest.csv filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/gda_api-checkpoint.py ADDED
@@ -0,0 +1,60 @@
(Jupyter notebook checkpoint copy; its 60 lines are identical to gda_api.py, added below.)
.ipynb_checkpoints/model-checkpoint.sh ADDED
@@ -0,0 +1,24 @@
(Jupyter notebook checkpoint copy; its 24 lines are identical to model.sh, added below.)
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,5 @@
(Jupyter notebook checkpoint copy; its 5 lines are identical to requirements.txt, added below.)
data/pretrain/disgenet_latest.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1eddc359c7671bcb71a2975e96846c2dde66a4e60b886e47ff62a3f6b28868d0
+ size 1121691139
gda_api.py ADDED
@@ -0,0 +1,60 @@
+ # -*- coding: utf-8 -*-
+ import gradio as gr
+ import pandas as pd
+ import os
+ import subprocess
+
+ def predict_top_100_genes(disease_id):
+     # Initialize paths
+     input_csv_path = '/data/downstream/{}_disease.csv'.format(disease_id)
+     output_csv_path = '/data/downstream/{}_top100.csv'.format(disease_id)
+
+     # Check if the output CSV already exists
+     if not os.path.exists(output_csv_path):
+         # Build the per-disease input CSV if the output file doesn't exist yet
+         df = pd.read_csv('/data/pretrain/disgenet_latest.csv')
+         df = df[df['proteinSeq'].notna()]
+         desired_diseaseDes = df[df['diseaseId'] == disease_id]['diseaseDes'].iloc[0]
+         related_proteins = df[df['diseaseDes'] == desired_diseaseDes]['proteinSeq'].unique()
+         df['score'] = df['proteinSeq'].isin(related_proteins).astype(int)
+         new_df = pd.DataFrame({
+             'diseaseId': disease_id,
+             'diseaseDes': desired_diseaseDes,
+             'geneSymbol': df['geneSymbol'],
+             'proteinSeq': df['proteinSeq'],
+             'score': df['score']
+         }).drop_duplicates().reset_index(drop=True)
+
+         new_df.to_csv(input_csv_path, index=False)
+
+         # Call the model script only if the output CSV does not exist
+         script_path = 'model.sh'
+         subprocess.run(['bash', script_path, input_csv_path, output_csv_path], check=True)
+
+     # Read the model output file (or the existing file) to get the top 100 genes
+     output_df = pd.read_csv(output_csv_path)
+     # Select only the required columns and rename them
+     result_df = output_df[['geneSymbol', 'Prediction_score']].rename(columns={'geneSymbol': 'Gene', 'Prediction_score': 'Score'}).head(100)
+
+     return result_df
+
+
+ iface = gr.Interface(
+     fn=predict_top_100_genes,
+     inputs=gr.Textbox(lines=1, placeholder="Enter Disease ID Here...", label="Disease ID"),
+     outputs=gr.Dataframe(label="Predicted Top 100 Related Genes"),
+     title="Gene Disease Association Prediction",
+     description=(
+         "This AI model ranks 16,733 candidate genes and returns the top 100 predicted to be associated with a given disease."
+         " To get started, you need a Disease ID (UMLS CUI), which can be obtained from the DisGeNET database. "
+         "\n\n**Steps to Obtain a Disease ID from DisGeNET:**\n"
+         "1. Visit the DisGeNET website: [https://www.disgenet.org/search](https://www.disgenet.org/search).\n"
+         "2. Use the search bar to enter your disease of interest. For instance, if you're interested in 'Alzheimer's Disease', type 'Alzheimer's Disease' into the search bar.\n"
+         "3. From the search results, identify the disease you're researching. The Disease ID (UMLS CUI) is listed alongside each disease name, e.g. C0002395.\n"
+         "4. Enter the Disease ID into the input box below and submit.\n\n"
+         "The DisGeNET database contains all known gene-disease associations and the supporting evidence. It can also be used in the reverse direction, to find the diseases associated with a given gene.\n"
+         "\n**The model will take about 18 minutes to run inference on a new disease.**\n"
+     )
+ )
+
+ iface.launch(share=True)
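
gda_api.py wraps predict_top_100_genes in a gr.Interface and launches it with share=True, so the app can also be queried programmatically. Below is a minimal sketch using the gradio_client package; the URL and the "/predict" endpoint name are assumptions (a single-function Interface normally exposes "/predict", but the actual URL depends on where the app is deployed).

```python
# Minimal sketch: query the running Gradio app from Python.
# Assumptions: the app is reachable at APP_URL and exposes the default
# "/predict" endpoint created by gr.Interface; adjust both if they differ.
from gradio_client import Client

APP_URL = "http://127.0.0.1:7860"  # or the public share / Space URL (assumed)

client = Client(APP_URL)
result = client.predict(
    "C0002395",           # Disease ID (UMLS CUI), e.g. Alzheimer's Disease
    api_name="/predict",  # default endpoint name for a single-function Interface
)
print(result)  # a DataFrame-like payload with Gene / Score columns
```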
model.sh ADDED
@@ -0,0 +1,24 @@
+ #!/bin/bash
+
+ input_csv_path="$1"
+ output_csv_path="$2"
+ max_depth=6
+ device='cuda:0'
+ model_path_list=(
+     "../../save_model_ckp/gda_infoNCE_2024_GPU3090" \
+ )
+
+ cd ../src/finetune/
+ for save_model_path in ${model_path_list[@]}; do
+     num_leaves=$((2**($max_depth-1)))
+     python finetune.py \
+         --input_csv_path $input_csv_path \
+         --output_csv_path $output_csv_path \
+         --save_model_path $save_model_path \
+         --device $device \
+         --batch_size 128 \
+         --step "300" \
+         --use_pooled \
+         --num_leaves $num_leaves \
+         --max_depth $max_depth
+ done
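
One detail in model.sh worth spelling out is the num_leaves arithmetic: with max_depth=6 the script passes num_leaves = 2**(6-1) = 32 to finetune.py, which keeps the leaf count below the 2**max_depth = 64 ceiling implied by LightGBM's depth constraint. A small illustrative sketch of the same relationship (the constant 6 simply mirrors the script's setting):

```python
# Illustrative only: reproduce model.sh's num_leaves computation in Python.
max_depth = 6                      # value hard-coded in model.sh
num_leaves = 2 ** (max_depth - 1)  # 32, stays under the 2**max_depth = 64 ceiling

# These are the two tree-shape parameters forwarded to finetune.py / LightGBM.
print({"max_depth": max_depth, "num_leaves": num_leaves})
```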
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ lightgbm
+ pytorch-metric-learning
+ torch
+ transformers
+ PyTDC
save_model_ckp/gda_infoNCE_2024_GPU3090/step_300_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:504129ccb1c717366e843df99e73d629b5c0bac0603deb8dbc6fb9b5479387b7
+ size 3131981635
src/finetune/.ipynb_checkpoints/finetune-checkpoint.py ADDED
@@ -0,0 +1,416 @@
(Jupyter notebook checkpoint copy; its 416 lines are identical to src/finetune/finetune.py, added below.)
src/finetune/finetune.py ADDED
@@ -0,0 +1,416 @@
+ import argparse
+ import os
+ import random
+ import string
+ import sys
+ import pandas as pd
+ from datetime import datetime
+
+ sys.path.append("../")
+ import numpy as np
+ import torch
+ import lightgbm as lgb
+ import sklearn.metrics as metrics
+ from sklearn.utils import class_weight
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import accuracy_score, precision_recall_curve, f1_score, precision_recall_fscore_support, roc_auc_score
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+ from transformers import EsmTokenizer, EsmForMaskedLM, BertModel, BertTokenizer, AutoTokenizer, EsmModel
+ from utils.downstream_disgenet import DisGeNETProcessor
+ from utils.metric_learning_models import GDA_Metric_Learning
+
+ def parse_config():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-f')
+     parser.add_argument("--step", type=int, default=0)
+     parser.add_argument(
+         "--save_model_path",
+         type=str,
+         default=None,
+         help="path of the pretrained disease model located",
+     )
+     parser.add_argument(
+         "--prot_encoder_path",
+         type=str,
+         default="facebook/esm2_t33_650M_UR50D",
+         # "facebook/galactica-6.7b", "Rostlab/prot_bert", "facebook/esm2_t33_650M_UR50D"
+         help="path/name of protein encoder model located",
+     )
+     parser.add_argument(
+         "--disease_encoder_path",
+         type=str,
+         default="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
+         help="path/name of textual pre-trained language model",
+     )
+     parser.add_argument("--reduction_factor", type=int, default=8)
+     parser.add_argument(
+         "--loss",
+         help="{ms_loss|infoNCE|cosine_loss|circle_loss|triplet_loss}",
+         default="infoNCE",
+     )
+     parser.add_argument(
+         "--input_feature_save_path",
+         type=str,
+         default="../../data/processed_disease",
+         help="path of tokenized training data",
+     )
+     parser.add_argument(
+         "--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}"
+     )
+     parser.add_argument("--batch_size", type=int, default=256)
+     parser.add_argument("--patience", type=int, default=5)
+     parser.add_argument("--num_leaves", type=int, default=5)
+     parser.add_argument("--max_depth", type=int, default=5)
+     parser.add_argument("--lr", type=float, default=0.35)
+     parser.add_argument("--dropout", type=float, default=0.1)
+     parser.add_argument("--test", type=int, default=0)
+     parser.add_argument("--use_miner", action="store_true")
+     parser.add_argument("--miner_margin", default=0.2, type=float)
+     parser.add_argument("--freeze_prot_encoder", action="store_true")
+     parser.add_argument("--freeze_disease_encoder", action="store_true")
+     parser.add_argument("--use_adapter", action="store_true")
+     parser.add_argument("--use_pooled", action="store_true")
+     parser.add_argument("--device", type=str, default="cpu")
+     parser.add_argument(
+         "--use_both_feature",
+         help="use the both features of gnn_feature_v1_samples and pretrained models",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--use_v1_feature_only",
+         help="use the features of gnn_feature_v1_samples only",
+         action="store_true",
+     )
+     parser.add_argument(
+         "--save_path_prefix",
+         type=str,
+         default="../../save_model_ckp/finetune/",
+         help="save the result in which directory",
+     )
+     parser.add_argument(
+         "--save_name", default="fine_tune", type=str, help="the name of the saved file"
+     )
+     # Add argument for input CSV file path
+     parser.add_argument("--input_csv_path", type=str, required=True, help="Path to the input CSV file.")
+
+     # Add argument for output CSV file path
+     parser.add_argument("--output_csv_path", type=str, required=True, help="Path to the output CSV file.")
+     return parser.parse_args()
+
+ def get_feature(model, dataloader, args):
+     x = list()
+     y = list()
+     with torch.no_grad():
+         for step, batch in tqdm(enumerate(dataloader)):
+             prot_input_ids, prot_attention_mask, dis_input_ids, dis_attention_mask, y1 = batch
+             prot_input = {
+                 'input_ids': prot_input_ids.to(args.device),
+                 'attention_mask': prot_attention_mask.to(args.device)
+             }
+             dis_input = {
+                 'input_ids': dis_input_ids.to(args.device),
+                 'attention_mask': dis_attention_mask.to(args.device)
+             }
+             feature_output = model.predict(prot_input, dis_input)
+             x1 = feature_output.cpu().numpy()
+             x.append(x1)
+             y.append(y1.cpu().numpy())
+     x = np.concatenate(x, axis=0)
+     y = np.concatenate(y, axis=0)
+     return x, y
+
+
+ def encode_pretrained_feature(args, disGeNET):
+     input_feat_file = os.path.join(
+         args.input_feature_save_path,
+         f"{args.model_short}_{args.step}_use_{'pooled' if args.use_pooled else 'cls'}_feat.npz",
+     )
+
+     if os.path.exists(input_feat_file):
+         print(f"load prior feature data from {input_feat_file}.")
+         loaded = np.load(input_feat_file)
+         x_train, y_train = loaded["x_train"], loaded["y_train"]
+         x_valid, y_valid = loaded["x_valid"], loaded["y_valid"]
+         # x_test, y_test = loaded["x_test"], loaded["y_test"]
+
+         prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         # prot_tokenizer = BertTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         print("prot_tokenizer", len(prot_tokenizer))
+         disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
+         print("disease_tokenizer", len(disease_tokenizer))
+
+         prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
+         # prot_model = BertModel.from_pretrained(args.prot_encoder_path)
+         disease_model = BertModel.from_pretrained(args.disease_encoder_path)
+
+         if args.save_model_path:
+             model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)
+
+             if args.use_adapter:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"prot_adapter_step_{args.step}"
+                 )
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"disease_adapter_step_{args.step}"
+                 )
+                 model.load_adapters(prot_model_path, disease_model_path)
+             else:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )  # , f"step_{args.step}_model.bin"
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )
+                 model.non_adapters(prot_model_path, disease_model_path)
+
+             model = model.to(args.device)
+             prot_model = model.prot_encoder
+             disease_model = model.disease_encoder
+             print(f"loaded prior model {args.save_model_path}.")
+
+         def collate_fn_batch_encoding(batch):
+             query1, query2, scores = zip(*batch)
+
+             query_encodings1 = prot_tokenizer.batch_encode_plus(
+                 list(query1),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             query_encodings2 = disease_tokenizer.batch_encode_plus(
+                 list(query2),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             scores = torch.tensor(list(scores))
+             attention_mask1 = query_encodings1["attention_mask"].bool()
+             attention_mask2 = query_encodings2["attention_mask"].bool()
+
+             return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
+
+         test_examples = disGeNET.get_test_examples(args.test)
+         print(f"get test examples: {len(test_examples)}")
+
+         test_dataloader = DataLoader(
+             test_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         print(f"dataset loaded: test-{len(test_examples)}")
+
+         x_test, y_test = get_feature(model, test_dataloader, args)
+
+     else:
+         prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         # prot_tokenizer = BertTokenizer.from_pretrained(args.prot_encoder_path, do_lower_case=False)
+         print("prot_tokenizer", len(prot_tokenizer))
+         disease_tokenizer = BertTokenizer.from_pretrained(args.disease_encoder_path)
+         print("disease_tokenizer", len(disease_tokenizer))
+
+         prot_model = EsmModel.from_pretrained(args.prot_encoder_path)
+         # prot_model = BertModel.from_pretrained(args.prot_encoder_path)
+         disease_model = BertModel.from_pretrained(args.disease_encoder_path)
+
+         if args.save_model_path:
+             model = GDA_Metric_Learning(prot_model, disease_model, 1280, 768, args)
+
+             if args.use_adapter:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"prot_adapter_step_{args.step}"
+                 )
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"disease_adapter_step_{args.step}"
+                 )
+                 model.load_adapters(prot_model_path, disease_model_path)
+             else:
+                 prot_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )  # , f"step_{args.step}_model.bin"
+                 disease_model_path = os.path.join(
+                     args.save_model_path, f"step_{args.step}_model.bin"
+                 )
+                 model.non_adapters(prot_model_path, disease_model_path)
+
+             model = model.to(args.device)
+             prot_model = model.prot_encoder
+             disease_model = model.disease_encoder
+             print(f"loaded prior model {args.save_model_path}.")
+
+         def collate_fn_batch_encoding(batch):
+             query1, query2, scores = zip(*batch)
+
+             query_encodings1 = prot_tokenizer.batch_encode_plus(
+                 list(query1),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             query_encodings2 = disease_tokenizer.batch_encode_plus(
+                 list(query2),
+                 max_length=512,
+                 padding="max_length",
+                 truncation=True,
+                 add_special_tokens=True,
+                 return_tensors="pt",
+             )
+             scores = torch.tensor(list(scores))
+             attention_mask1 = query_encodings1["attention_mask"].bool()
+             attention_mask2 = query_encodings2["attention_mask"].bool()
+
+             return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
+
+         train_examples = disGeNET.get_train_examples(args.test)
+         print(f"get training examples: {len(train_examples)}")
+         valid_examples = disGeNET.get_val_examples(args.test)
+         print(f"get validation examples: {len(valid_examples)}")
+         test_examples = disGeNET.get_test_examples(args.test)
+         print(f"get test examples: {len(test_examples)}")
+
+         train_dataloader = DataLoader(
+             train_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         valid_dataloader = DataLoader(
+             valid_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         test_dataloader = DataLoader(
+             test_examples,
+             batch_size=args.batch_size,
+             shuffle=False,
+             collate_fn=collate_fn_batch_encoding,
+         )
+         print(f"dataset loaded: train-{len(train_examples)}; valid-{len(valid_examples)}; test-{len(test_examples)}")
+
+         x_train, y_train = get_feature(model, train_dataloader, args)
+         x_valid, y_valid = get_feature(model, valid_dataloader, args)
+         x_test, y_test = get_feature(model, test_dataloader, args)
+
+         # Save input feature to reduce encoding time
+         np.savez_compressed(
+             input_feat_file,
+             x_train=x_train,
+             y_train=y_train,
+             x_valid=x_valid,
+             y_valid=y_valid,
+         )
+         print(f"save input feature into {input_feat_file}")
+         # Save input feature to reduce encoding time
+     return x_train, y_train, x_valid, y_valid, x_test, y_test
+
+
+ def train(args):
+     # defining parameters
+     if args.save_model_path:
+         args.model_short = (
+             args.save_model_path.split("/")[-1]
+         )
+         print(f"model name {args.model_short}")
+
+     else:
+         args.model_short = (
+             args.disease_encoder_path.split("/")[-1]
+         )
+         print(f"model name {args.model_short}")
+
+     # disGeNET = DisGeNETProcessor()
+     disGeNET = DisGeNETProcessor(input_csv_path=args.input_csv_path)
+
+
+     x_train, y_train, x_valid, y_valid, x_test, y_test = encode_pretrained_feature(args, disGeNET)
+
+     print("train: ", x_train.shape, y_train.shape)
+     print("valid: ", x_valid.shape, y_valid.shape)
+     print("test: ", x_test.shape, y_test.shape)
+
+     params = {
+         "task": "train",  # "predict" train
+         "boosting": "gbdt",  # The options are "gbdt" (traditional Gradient Boosting Decision Tree), "rf" (Random Forest), "dart" (Dropouts meet Multiple Additive Regression Trees), or "goss" (Gradient-based One-Side Sampling). The default is "gbdt".
+         "objective": "binary",
+         "num_leaves": args.num_leaves,
+         "early_stopping_round": 30,
+         "max_depth": args.max_depth,
+         "learning_rate": args.lr,
+         "metric": "binary_logloss",  # "metric": "l2", "binary_logloss", "auc"
+         "verbose": 1,
+     }
+
+     lgb_train = lgb.Dataset(x_train, y_train)
+     lgb_valid = lgb.Dataset(x_valid, y_valid)
+     lgb_eval = lgb.Dataset(x_test, y_test, reference=lgb_train)
+
+     # fitting the model
+     model = lgb.train(
+         params, train_set=lgb_train, valid_sets=lgb_valid)
+
+     # prediction
+     valid_y_pred = model.predict(x_valid)
+     test_y_pred = model.predict(x_test)
+
+     # predict liver fibrosis
+     predictions_df = pd.DataFrame(test_y_pred, columns=["Prediction_score"])
+     # data_test = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/test_tdc.csv')
+     data_test = pd.read_csv(args.input_csv_path)
+     predictions = pd.concat([data_test, predictions_df], axis=1)
+     # filtered_dataset = test_dataset_with_predictions[test_dataset_with_predictions['diseaseId'] == 'C0009714']
+     predictions.sort_values(by='Prediction_score', ascending=False, inplace=True)
+     top_100_predictions = predictions.head(100)
+     top_100_predictions.to_csv(args.output_csv_path, index=False)
+
+     # Accuracy
+     y_pred = model.predict(x_test, num_iteration=model.best_iteration)
+     y_pred[y_pred >= 0.5] = 1
+     y_pred[y_pred < 0.5] = 0
+     accuracy = accuracy_score(y_test, y_pred)
+
+     # AUC
+     valid_roc_auc_score = metrics.roc_auc_score(y_valid, valid_y_pred)
+     valid_average_precision_score = metrics.average_precision_score(
+         y_valid, valid_y_pred
+     )
+     test_roc_auc_score = metrics.roc_auc_score(y_test, test_y_pred)
+     test_average_precision_score = metrics.average_precision_score(y_test, test_y_pred)
+
+     # AUPR
+     valid_aupr = metrics.average_precision_score(y_valid, valid_y_pred)
+     test_aupr = metrics.average_precision_score(y_test, test_y_pred)
+
+     # Fmax
+     valid_precision, valid_recall, valid_thresholds = precision_recall_curve(y_valid, valid_y_pred)
+     valid_fmax = (2 * valid_precision * valid_recall / (valid_precision + valid_recall)).max()
+     test_precision, test_recall, test_thresholds = precision_recall_curve(y_test, test_y_pred)
+     test_fmax = (2 * test_precision * test_recall / (test_precision + test_recall)).max()
+
+     # F1
+     valid_f1 = f1_score(y_valid, valid_y_pred >= 0.5)
+     test_f1 = f1_score(y_test, test_y_pred >= 0.5)
+
+
+ if __name__ == "__main__":
+     args = parse_config()
+     if torch.cuda.is_available():
+         print("cuda is available.")
+         print(f"current device {args}.")
+     else:
+         args.device = "cpu"
+     timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
+     random_str = "".join([random.choice(string.ascii_lowercase) for n in range(6)])
+     best_model_dir = (
+         f"{args.save_path_prefix}{args.save_name}_{timestamp_str}_{random_str}/"
+     )
+     os.makedirs(best_model_dir)
+     args.save_name = best_model_dir
+     train(args)
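
Among the metrics computed in train(), the least standard is Fmax: the precision-recall curve is swept over all score thresholds and the best F1 value is kept. A toy, self-contained sketch of that calculation (the labels and scores below are synthetic, not project data):

```python
# Toy sketch of the Fmax metric used in train(): the maximum F1 score over
# all thresholds of the precision-recall curve. Labels/scores are synthetic.
import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([1, 1, 0, 1, 0, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)
fmax = (2 * precision * recall / (precision + recall)).max()
print(f"Fmax = {fmax:.3f}")  # F1 at the best-performing threshold
```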
src/utils/downstream_disgenet.py ADDED
@@ -0,0 +1,148 @@
+ import json
+ import sys
+ import os
+ import torch
+ from utils.data_loader import GDA_Dataset
+ from sklearn.model_selection import train_test_split
+ from sklearn.model_selection import KFold
+ import numpy as np
+ import pandas as pd
+
+ sys.path.append("../")
+
+ class DisGeNETProcessor:
+     def __init__(self, input_csv_path, data_dir="/nfs/dpa_pretrain/data/downstream/"):
+         train_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/train.csv')
+         valid_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/valid.csv')
+         test_data = pd.read_csv(input_csv_path)
+
+         # test_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/GDA_Data/test.csv')
+         # valid_data, test_data = train_test_split(valid_data, test_size=1/3, random_state=42)
+         # train_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/test/train.csv')
+         # valid_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/test/valid.csv')
+
+
+         # train_data = pd.read_csv('/nfs/dpa_pretrain/data/downstream/disgenet_finetune.csv')
+         # train_data, valid_data = train_test_split(train_data, test_size=0.2, random_state=42)
+         # valid_data, test_data = train_test_split(valid_data, test_size=1/3, random_state=42)
+
+         # alzheimer and stomach dataset use [["proteinSeq", "diseaseDes", "Y"]].dropna()
+
+         self.name = "DisGeNET"
+         self.train_dataset_df = train_data[["proteinSeq", "diseaseDes", "score"]].dropna()
+         self.val_dataset_df = valid_data[["proteinSeq", "diseaseDes", "score"]].dropna()
+         self.test_dataset_df = test_data[["proteinSeq", "diseaseDes", "score"]].dropna()
+         # self.test_dataset_df = test_data[["proteinSeq", "diseaseDes", "Y"]].dropna()
+
+
+     def get_train_examples(self, test=False):
+         """get training examples
+
+         Args:
+             test (bool, optional): test can be int or bool. If test>1, will take test as the number of test examples. Defaults to False.
+
+         Returns:
+             _type_: _description_
+         """
+         if test == 1:  # Small testing set, to reduce the running time
+             return (
+                 self.train_dataset_df["proteinSeq"].values[:4096],
+                 self.train_dataset_df["diseaseDes"].values[:4096],
+                 self.train_dataset_df["score"].values[:4096],
+             )
+         elif test > 1:
+             return (
+                 self.train_dataset_df["proteinSeq"].values[:test],
+                 self.train_dataset_df["diseaseDes"].values[:test],
+                 self.train_dataset_df["score"].values[:test],
+             )
+         else:
+             return GDA_Dataset((
+                 self.train_dataset_df["proteinSeq"].values,
+                 self.train_dataset_df["diseaseDes"].values,
+                 self.train_dataset_df["score"].values,
+             ))
+
+     def get_val_examples(self, test=False):
+         """get validation examples
+
+         Args:
+             test (bool, optional): test can be int or bool. If test>1, will take test as the number of test examples. Defaults to False.
+
+         Returns:
+             _type_: _description_
+
+         """
+         if test == 1:  # Small testing set, to reduce the running time
+             return (
+                 self.val_dataset_df["proteinSeq"].values[:1024],
+                 self.val_dataset_df["diseaseDes"].values[:1024],
+                 self.val_dataset_df["score"].values[:1024],
+             )
+         elif test > 1:
+             return (
+                 self.val_dataset_df["proteinSeq"].values[:test],
+                 self.val_dataset_df["diseaseDes"].values[:test],
+                 self.val_dataset_df["score"].values[:test],
+             )
+         else:
+             return GDA_Dataset((
+                 self.val_dataset_df["proteinSeq"].values,
+                 self.val_dataset_df["diseaseDes"].values,
+                 self.val_dataset_df["score"].values,
+             ))
+
+     # def get_test_examples(self, test=False):
+     #     """get test examples
+
+     #     Args:
+     #         test (bool, optional): test can be int or bool. If test>1, will take test as the number of test examples. Defaults to False.
+
+     #     Returns:
+     #         _type_: _description_
+     #     """
+     #     if test == 1:  # Small testing set, to reduce the running time
+     #         return (
+     #             self.test_dataset_df["proteinSeq"].values[:1024],
+     #             self.test_dataset_df["diseaseDes"].values[:1024],
+     #             self.test_dataset_df["Y"].values[:1024],
+     #         )
+     #     elif test > 1:
+     #         return (
+     #             self.test_dataset_df["proteinSeq"].values[:test],
+     #             self.test_dataset_df["diseaseDes"].values[:test],
+     #             self.test_dataset_df["Y"].values[:test],
+     #         )
+     #     else:
+     #         return GDA_Dataset((
+     #             self.test_dataset_df["proteinSeq"].values,
+     #             self.test_dataset_df["diseaseDes"].values,
+     #             self.test_dataset_df["Y"].values,
+     #         ))
+     def get_test_examples(self, test=False):
+         """get test examples
+
+         Args:
+             test (bool, optional): test can be int or bool. If test>1, will take test as the number of test examples. Defaults to False.
+
+         Returns:
+             _type_: _description_
+         """
+         if test == 1:  # Small testing set, to reduce the running time
+             return (
+                 self.test_dataset_df["proteinSeq"].values[:1024],
+                 self.test_dataset_df["diseaseDes"].values[:1024],
+                 self.test_dataset_df["score"].values[:1024],
+             )
+         elif test > 1:
+             return (
+                 self.test_dataset_df["proteinSeq"].values[:test],
+                 self.test_dataset_df["diseaseDes"].values[:test],
+                 self.test_dataset_df["score"].values[:test],
+             )
+         else:
+             return GDA_Dataset((
+                 self.test_dataset_df["proteinSeq"].values,
+                 self.test_dataset_df["diseaseDes"].values,
+                 self.test_dataset_df["score"].values,
+             ))
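
DisGeNETProcessor reads the per-disease file given by --input_csv_path as its test split and keeps only the proteinSeq, diseaseDes and score columns (the train/valid splits come from fixed paths under /nfs/...). Below is a minimal sketch of a CSV that satisfies that schema; the rows are invented placeholders, not real DisGeNET records.

```python
# Minimal sketch of the per-disease input CSV consumed as the "test" split by
# DisGeNETProcessor (and produced by gda_api.py). Rows are invented placeholders.
import pandas as pd

toy = pd.DataFrame({
    "diseaseId": ["C0002395", "C0002395"],
    "diseaseDes": ["Alzheimer's Disease", "Alzheimer's Disease"],
    "geneSymbol": ["GENE1", "GENE2"],
    "proteinSeq": ["MKTAYIAKQR", "MEEPQSDPSV"],   # toy amino-acid strings
    "score": [1, 0],                              # 1 = known association
})
toy.to_csv("C0002395_disease.csv", index=False)

# The processor then selects exactly these three columns and drops NaNs:
print(toy[["proteinSeq", "diseaseDes", "score"]].dropna())
```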
src/utils/metric_learning_models.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ sys.path.append("../")
6
+ from pytorch_metric_learning.distances import CosineSimilarity
7
+ from pytorch_metric_learning.reducers import ThresholdReducer
8
+ from pytorch_metric_learning.regularizers import LpRegularizer
9
+ from pytorch_metric_learning import losses
10
+
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import functional as F
15
+ from pytorch_metric_learning import losses, miners
16
+ from torch.cuda.amp import autocast
17
+ from torch.nn import Module
18
+ from tqdm import tqdm
19
+ from utils.gd_model import GDANet
20
+ from torch.nn import MultiheadAttention
21
+
22
+ from transformers import BertModel
23
+ from transformers import EsmModel, EsmConfig
24
+
25
+ LOGGER = logging.getLogger(__name__)
26
+
27
+ class FusionModule(nn.Module):
28
+ def __init__(self, out_dim, num_head, dropout= 0.1):
29
+ super(FusionModule, self).__init__()
30
+ """FusionModule.
31
+
32
+ Args:
33
+ dropout= 0.1 is defaut
34
+ out_dim: model output dimension
35
+ num_head = 8: Multi-head Attention
36
+ """
37
+
38
+ self.out_dim = out_dim
39
+ self.num_head = num_head
40
+
41
+ self.WqS = nn.Linear(out_dim, out_dim)
42
+ self.WkS = nn.Linear(out_dim, out_dim)
43
+ self.WvS = nn.Linear(out_dim, out_dim)
44
+
45
+ self.WqT = nn.Linear(out_dim, out_dim)
46
+ self.WkT = nn.Linear(out_dim, out_dim)
47
+ self.WvT = nn.Linear(out_dim, out_dim)
48
+ self.multi_head_attention = nn.MultiheadAttention(out_dim, num_head, dropout=dropout)
49
+
50
+ def forward(self, zs, zt):
51
+ # nn.MultiheadAttention The input representation is (token_length, batch_size, out_dim)
52
+ # zs = protein_representation.permute(1, 0, 2)
53
+ # zt = disease_representation.permute(1, 0, 2)
54
+
55
+ # Compute query, key and value representations
56
+ qs = self.WqS(zs)
57
+ ks = self.WkS(zs)
58
+ vs = self.WvS(zs)
59
+
60
+ qt = self.WqT(zt)
61
+ kt = self.WkT(zt)
62
+ vt = self.WvT(zt)
63
+
64
+ #self.multi_head_attention() The function returns two values: the representation and the attention weight matrix, computed after multiple attentions. In this case, we only care about the computed representation and not the attention weight matrix, so "_" is used to indicate that we do not intend to use or store the second return value.
65
+ zs_attention1, _ = self.multi_head_attention(qs, ks, vs)
66
+ zs_attention2, _ = self.multi_head_attention(qs, kt, vt)
67
+ zt_attention1, _ = self.multi_head_attention(qt, kt, vt)
68
+ zt_attention2, _ = self.multi_head_attention(qt, ks, vs)
69
+
70
+ protein_fused = 0.5 * (zs_attention1 + zs_attention2)
71
+ dis_fused = 0.5 * (zt_attention1 + zt_attention2)
72
+
73
+ return protein_fused, dis_fused
74
+
75
+ class CrossAttentionBlock(nn.Module):
76
+
77
+ def __init__(self, hidden_dim, num_heads):
78
+ super(CrossAttentionBlock, self).__init__()
79
+ if hidden_dim % num_heads != 0:
80
+ raise ValueError(
81
+ "The hidden size (%d) is not a multiple of the number of attention "
82
+ "heads (%d)" % (hidden_dim, num_heads))
83
+ self.hidden_dim = hidden_dim
84
+ self.num_heads = num_heads
85
+ self.head_size = hidden_dim // num_heads
86
+
87
+ self.query1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
88
+ self.key1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
89
+ self.value1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
90
+
91
+ self.query2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
92
+ self.key2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
93
+ self.value2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
94
+
95
+ def _alpha_from_logits(self, logits, mask_row, mask_col, inf=1e6):
96
+ N, L1, L2, H = logits.shape
97
+ mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
98
+ mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
99
+ mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)
100
+
101
+ logits = torch.where(mask_pair, logits, logits - inf)
102
+ alpha = torch.softmax(logits, dim=2)
103
+ mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
104
+ alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
105
+ return alpha
106
+
107
+ def _heads(self, x, n_heads, n_ch):
108
+ s = list(x.size())[:-1] + [n_heads, n_ch]
109
+ return x.view(*s)
110
+
111
+ def forward(self, input1, input2, mask1, mask2):
112
+ query1 = self._heads(self.query1(input1), self.num_heads, self.head_size)
113
+ key1 = self._heads(self.key1(input1), self.num_heads, self.head_size)
114
+ query2 = self._heads(self.query2(input2), self.num_heads, self.head_size)
115
+ key2 = self._heads(self.key2(input2), self.num_heads, self.head_size)
116
+ logits11 = torch.einsum('blhd, bkhd->blkh', query1, key1)
117
+ logits12 = torch.einsum('blhd, bkhd->blkh', query1, key2)
118
+ logits21 = torch.einsum('blhd, bkhd->blkh', query2, key1)
119
+ logits22 = torch.einsum('blhd, bkhd->blkh', query2, key2)
120
+
121
+ alpha11 = self._alpha_from_logits(logits11, mask1, mask1)
122
+ alpha12 = self._alpha_from_logits(logits12, mask1, mask2)
123
+ alpha21 = self._alpha_from_logits(logits21, mask2, mask1)
124
+ alpha22 = self._alpha_from_logits(logits22, mask2, mask2)
125
+
126
+ value1 = self._heads(self.value1(input1), self.num_heads, self.head_size)
127
+ value2 = self._heads(self.value2(input2), self.num_heads, self.head_size)
128
+ output1 = (torch.einsum('blkh, bkhd->blhd', alpha11, value1).flatten(-2) +
129
+ torch.einsum('blkh, bkhd->blhd', alpha12, value2).flatten(-2)) / 2
130
+ output2 = (torch.einsum('blkh, bkhd->blhd', alpha21, value1).flatten(-2) +
131
+ torch.einsum('blkh, bkhd->blhd', alpha22, value2).flatten(-2)) / 2
132
+
133
+ return output1, output2
134
+
135
+ class GDA_Metric_Learning(GDANet):
136
+ def __init__(
137
+ self, prot_encoder, disease_encoder, prot_out_dim, disease_out_dim, args
138
+ ):
139
+ """Constructor for the model.
140
+
141
+ Args:
142
+ prot_encoder (_type_): Protein encoder.
143
+ disease_encoder (_type_): Disease Textual encoder.
144
+ prot_out_dim (_type_): Dimension of the Protein encoder.
145
+ disease_out_dim (_type_): Dimension of the Disease encoder.
146
+ args (_type_): _description_
147
+ """
148
+ super(GDA_Metric_Learning, self).__init__(
149
+ prot_encoder,
150
+ disease_encoder,
151
+ )
152
+ self.prot_encoder = prot_encoder
153
+ self.disease_encoder = disease_encoder
154
+ self.loss = args.loss
155
+ self.use_miner = args.use_miner
156
+ self.miner_margin = args.miner_margin
157
+ self.agg_mode = args.agg_mode
158
+ self.prot_reg = nn.Linear(prot_out_dim, 1024)
159
+ # self.prot_reg = nn.Linear(prot_out_dim, disease_out_dim)
160
+ self.dis_reg = nn.Linear(disease_out_dim, 1024)
161
+ # self.prot_adapter_name = None
162
+ # self.disease_adapter_name = None
163
+
164
+ self.fusion_layer = FusionModule(1024, num_head=8)
165
+ self.cross_attention_layer = CrossAttentionBlock(1024, 8)
166
+
167
+ # # MMP Prediction Heads
168
+ # self.prot_pred_head = nn.Sequential(
169
+ # nn.Linear(disease_out_dim, disease_out_dim),
170
+ # nn.ReLU(),
171
+ # nn.Linear(disease_out_dim, 1280) #vocabulary size : prot model tokenize length 30 446
172
+ # )
173
+ # self.dise_pred_head = nn.Sequential(
174
+ # nn.Linear(disease_out_dim, disease_out_dim),
175
+ # nn.ReLU(),
176
+ # nn.Linear(disease_out_dim, 768) #vocabulary size : disease model tokenize length 30522
177
+ # )
178
+
179
+ if self.use_miner:
180
+ self.miner = miners.TripletMarginMiner(
181
+ margin=args.miner_margin, type_of_triplets="all"
182
+ )
183
+ else:
184
+ self.miner = None
185
+
186
+ if self.loss == "ms_loss":
187
+ self.loss = losses.MultiSimilarityLoss(
188
+ alpha=2, beta=50, base=0.5
189
+ ) # 1,2,3; 40,50,60
190
+ #1_40=1.5141 50=1.4988 60=1.4905 2_60=1.1786 50=1.1874 40=1.2008 3_40=1.1146 50=1.1012
191
+ elif self.loss == "circle_loss":
192
+ self.loss = losses.CircleLoss(
193
+ m=0.4, gamma=80
194
+ )
195
+ elif self.loss == "triplet_loss":
196
+ self.loss = losses.TripletMarginLoss(
197
+ margin=0.05, swap=False, smooth_loss=False,
198
+ triplets_per_anchor="all")
199
+ # distance = CosineSimilarity(),
200
+ # reducer = ThresholdReducer(high=0.3),
201
+ # embedding_regularizer = LpRegularizer() )
202
+
203
+ elif self.loss == "infoNCE":
204
+ self.loss = losses.NTXentLoss(
205
+ temperature=0.07
206
+ ) # The MoCo paper uses 0.07, while SimCLR uses 0.5.
207
+ elif self.loss == "lifted_structure_loss":
208
+ self.loss = losses.LiftedStructureLoss(
209
+ neg_margin=1, pos_margin=0
210
+ )
211
+ elif self.loss == "nca_loss":
212
+ self.loss = losses.NCALoss(
213
+ softmax_scale=1
214
+ )
215
+ self.fusion = False
216
+ # self.stack = False
217
+ self.dropout = torch.nn.Dropout(args.dropout)
218
+ print("miner:", self.miner)
219
+ print("loss:", self.loss)
220
+
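+ # This constructor only reads args.loss, args.use_miner, args.miner_margin,
+ # args.agg_mode and args.dropout. For quick experiments a stand-in namespace
+ # would be enough; the values below are illustrative assumptions, not the
+ # settings used for the released checkpoint:
+ #
+ #   from types import SimpleNamespace
+ #   args = SimpleNamespace(loss="infoNCE", use_miner=False, miner_margin=0.2,
+ #                          agg_mode="mean_all_tok", dropout=0.1)
+ #   model = GDA_Metric_Learning(prot_encoder, disease_encoder,
+ #                               prot_out_dim=1280, disease_out_dim=768, args=args)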
221
+ # def add_fusion(self):
222
+ # adapter_setup = Fuse("prot_adapter", "disease_adapter")
223
+ # self.prot_encoder.add_fusion(adapter_setup)
224
+ # self.prot_encoder.set_active_adapters(adapter_setup)
225
+ # self.prot_encoder.train_fusion(adapter_setup)
226
+ # self.disease_encoder.add_fusion(adapter_setup)
227
+ # self.disease_encoder.set_active_adapters(adapter_setup)
228
+ # self.disease_encoder.train_fusion(adapter_setup)
229
+ # self.fusion = True
230
+
231
+ # def add_stack_gda(self, reduction_factor):
232
+ # self.add_gda_adapters(reduction_factor=reduction_factor)
233
+ # # adapter_setup = Fuse("prot_adapter", "disease_adapter")
234
+ # self.prot_encoder.active_adapters = Stack(
235
+ # self.prot_adapter_name, self.gda_adapter_name
236
+ # )
237
+ # self.disease_encoder.active_adapters = Stack(
238
+ # self.disease_adapter_name, self.gda_adapter_name
239
+ # )
240
+ # print("stacked adapters loaded.")
241
+ # self.stack = True
242
+
243
+ # def load_adapters(
244
+ # self,
245
+ # prot_model_path,
246
+ # disease_model_path,
247
+ # prot_adapter_name="prot_adapter",
248
+ # disease_adapter_name="disease_adapter",
249
+ # ):
250
+ # if os.path.exists(prot_model_path):
251
+ # print(f"loading prot adapter from: {prot_model_path}")
252
+ # self.prot_adapter_name = prot_adapter_name
253
+ # self.prot_encoder.load_adapter(prot_model_path, load_as=prot_adapter_name)
254
+ # self.prot_encoder.set_active_adapters(prot_adapter_name)
255
+ # print(f"load protein adapters from: {prot_model_path} {prot_adapter_name}")
256
+ # else:
257
+ # print(f"{prot_model_path} not exits")
258
+
259
+ # if os.path.exists(disease_model_path):
260
+ # print(f"loading prot adapter from: {disease_model_path}")
261
+ # self.disease_adapter_name = disease_adapter_name
262
+ # self.disease_encoder.load_adapter(
263
+ # disease_model_path, load_as=disease_adapter_name
264
+ # )
265
+ # self.disease_encoder.set_active_adapters(disease_adapter_name)
266
+ # print(
267
+ # f"load disease adapters from: {disease_model_path} {disease_adapter_name}"
268
+ # )
269
+ # else:
270
+ # print(f"{disease_model_path} not exits")
271
+
272
+ def non_adapters(
273
+ self,
274
+ prot_model_path,
275
+ disease_model_path,
276
+
277
+ ):
278
+ if os.path.exists(prot_model_path):
279
+ # Load the entire model for prot_model
280
+ prot_model = torch.load(prot_model_path)
281
+ # Set the prot_encoder to the loaded model
282
+ self.prot_encoder = prot_model.prot_encoder
283
+ print(f"load protein from: {prot_model_path}")
284
+ else:
285
+ print(f"{prot_model_path} not exits")
286
+
287
+ if os.path.exists(disease_model_path):
288
+ # Load the entire model for disease_model
289
+ disease_model = torch.load(disease_model_path)
290
+ # Set the disease_encoder to the loaded model
291
+ self.disease_encoder = disease_model.disease_encoder
292
+ print(f"load disease from: {disease_model_path}")
293
+
294
+ else:
295
+ print(f"{disease_model_path} not exits")
296
+
297
+
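+ # non_adapters expects whole-model checkpoints saved with torch.save(model, path)
+ # and reuses their encoders. Illustrative call (file names are placeholders):
+ #
+ #   model.non_adapters("prot_checkpoint.bin", "disease_checkpoint.bin")
+ #
+ # torch.load is called without map_location, so loading a GPU-saved checkpoint
+ # on a CPU-only machine may need this method adjusted.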
298
+ # def add_gda_adapters(
299
+ # self,
300
+ # gda_adapter_name="gda_adapter",
301
+ # reduction_factor=16,
302
+ # ):
303
+ # """Initialise adapters
304
+
305
+ # Args:
306
+ # prot_adapter_name (str, optional): _description_. Defaults to "prot_adapter".
307
+ # disease_adapter_name (str, optional): _description_. Defaults to "disease_adapter".
308
+ # reduction_factor (int, optional): _description_. Defaults to 16.
309
+ # """
310
+ # adapter_config = AdapterConfig.load(
311
+ # "pfeiffer", reduction_factor=reduction_factor
312
+ # )
313
+ # self.gda_adapter_name = gda_adapter_name
314
+ # self.prot_encoder.add_adapter(gda_adapter_name, config=adapter_config)
315
+ # self.prot_encoder.train_adapter([gda_adapter_name])
316
+ # self.disease_encoder.add_adapter(gda_adapter_name, config=adapter_config)
317
+ # self.disease_encoder.train_adapter([gda_adapter_name])
318
+
319
+ # def init_adapters(
320
+ # self,
321
+ # prot_adapter_name="gda_prot_adapter",
322
+ # disease_adapter_name="gda_disease_adapter",
323
+ # reduction_factor=16,
324
+ # ):
325
+ # """Initialise adapters
326
+
327
+ # Args:
328
+ # prot_adapter_name (str, optional): _description_. Defaults to "prot_adapter".
329
+ # disease_adapter_name (str, optional): _description_. Defaults to "disease_adapter".
330
+ # reduction_factor (int, optional): _description_. Defaults to 16.
331
+ # """
332
+ # adapter_config = AdapterConfig.load(
333
+ # "pfeiffer", reduction_factor=reduction_factor
334
+ # )
335
+
336
+ # self.prot_adapter_name = prot_adapter_name
337
+ # self.disease_adapter_name = disease_adapter_name
338
+ # self.prot_encoder.add_adapter(prot_adapter_name, config=adapter_config)
339
+ # self.prot_encoder.train_adapter([prot_adapter_name])
340
+ # self.disease_encoder.add_adapter(disease_adapter_name, config=adapter_config)
341
+ # self.disease_encoder.train_adapter([disease_adapter_name])
342
+ # print(f"adapter modules initialized")
343
+
344
+ # def save_adapters(self, save_path_prefix, total_step):
345
+ # """Save adapters into file.
346
+
347
+ # Args:
348
+ # save_path_prefix (string): saving path prefix.
349
+ # total_step (int): total step number.
350
+ # """
351
+ # prot_save_dir = os.path.join(
352
+ # save_path_prefix, f"prot_adapter_step_{total_step}"
353
+ # )# adapter
354
+ # disease_save_dir = os.path.join(
355
+ # save_path_prefix, f"disease_adapter_step_{total_step}"
356
+ # )
357
+ # os.makedirs(prot_save_dir, exist_ok=True)
358
+ # os.makedirs(disease_save_dir, exist_ok=True)
359
+ # self.prot_encoder.save_adapter(prot_save_dir, self.prot_adapter_name)
360
+ # prot_head_save_path = os.path.join(prot_save_dir, "prot_head.bin")
361
+ # torch.save(self.prot_reg, prot_head_save_path)
362
+ # self.disease_encoder.save_adapter(disease_save_dir, self.disease_adapter_name)
363
+ # disease_head_save_path = os.path.join(prot_save_dir, "disease_head.bin")
364
+ # torch.save(self.prot_reg, disease_head_save_path)
365
+ # if self.fusion:
366
+ # self.prot_encoder.save_all_adapters(prot_save_dir)
367
+ # self.disease_encoder.save_all_adapters(disease_save_dir)
368
+
369
+ def predict(self, query_toks1, query_toks2):
370
+ """
371
+ Encode one batch of protein/disease token pairs, fuse them with cross-attention,
372
+ and return the concatenated pooled embeddings of shape (batch_size, 2 * 1024).
373
+ """
374
+ # Extract input_ids and attention_mask for protein
375
+ prot_input_ids = query_toks1["input_ids"]
376
+ prot_attention_mask = query_toks1["attention_mask"]
377
+
378
+ # Extract input_ids and attention_mask for dis
379
+ dis_input_ids = query_toks2["input_ids"]
380
+ dis_attention_mask = query_toks2["attention_mask"]
381
+
382
+ # Process inputs through encoders
383
+ last_hidden_state1 = self.prot_encoder(
384
+ input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
385
+ ).last_hidden_state
386
+ last_hidden_state1 = self.prot_reg(last_hidden_state1)
387
+
388
+ last_hidden_state2 = self.disease_encoder(
389
+ input_ids=dis_input_ids, attention_mask=dis_attention_mask, return_dict=True
390
+ ).last_hidden_state
391
+ last_hidden_state2 = self.dis_reg(last_hidden_state2)
392
+ # Apply the cross-attention layer
393
+ prot_fused, dis_fused = self.cross_attention_layer(
394
+ last_hidden_state1, last_hidden_state2, prot_attention_mask, dis_attention_mask
395
+ )
396
+
397
+ # last_hidden_state1 = self.prot_encoder(
398
+ # query_toks1, return_dict=True
399
+ # ).last_hidden_state
400
+ # last_hidden_state1 = self.prot_reg(
401
+ # last_hidden_state1
402
+ # ) # transform the prot embedding into the same dimension as the disease embedding
403
+ # last_hidden_state2 = self.disease_encoder(
404
+ # query_toks2, return_dict=True
405
+ # ).last_hidden_state
406
+ # last_hidden_state2 = self.dis_reg(
407
+ # last_hidden_state2
408
+ # ) # transform the disease embedding into 1024
409
+
410
+ # Apply the fusion layer and Recovery of representational shape
411
+ # prot_fused, dis_fused = self.fusion_layer(last_hidden_state1, last_hidden_state2)
412
+
413
+ if self.agg_mode == "cls":
414
+ query_embed1 = prot_fused[:, 0] # query : [batch_size, hidden]
415
+ query_embed2 = dis_fused[:, 0] # query : [batch_size, hidden]
416
+ elif self.agg_mode == "mean_all_tok":
417
+ query_embed1 = prot_fused.mean(1) # query : [batch_size, hidden]
418
+ query_embed2 = dis_fused.mean(1) # query : [batch_size, hidden]
419
+ elif self.agg_mode == "mean":
420
+ query_embed1 = (
421
+ prot_fused * query_toks1["attention_mask"].unsqueeze(-1)
422
+ ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
423
+ query_embed2 = (
424
+ dis_fused * query_toks2["attention_mask"].unsqueeze(-1)
425
+ ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
426
+ else:
427
+ raise NotImplementedError()
428
+
429
+ query_embed = torch.cat([query_embed1, query_embed2], dim=1)
430
+ return query_embed
431
+
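+ # With the 1024-dim prot_reg / dis_reg projections above, predict() returns a
+ # (batch_size, 2048) tensor: the pooled protein embedding and the pooled
+ # disease embedding concatenated along the feature dimension.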
432
+ def forward(self, query_toks1, query_toks2, labels):
433
+ """
434
+ Encode and fuse a batch of protein/disease token pairs and return the
435
+ metric-learning loss computed over the pooled embeddings.
436
+ """
437
+ # Extract input_ids and attention_mask for protein
438
+ prot_input_ids = query_toks1["input_ids"]
439
+ prot_attention_mask = query_toks1["attention_mask"]
440
+
441
+ # Extract input_ids and attention_mask for dis
442
+ dis_input_ids = query_toks2["input_ids"]
443
+ dis_attention_mask = query_toks2["attention_mask"]
444
+
445
+ # Process inputs through encoders
446
+ last_hidden_state1 = self.prot_encoder(
447
+ input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
448
+ ).last_hidden_state
449
+ last_hidden_state1 = self.prot_reg(last_hidden_state1)
450
+
451
+ last_hidden_state2 = self.disease_encoder(
452
+ input_ids=dis_input_ids, attention_mask=dis_attention_mask, return_dict=True
453
+ ).last_hidden_state
454
+ last_hidden_state2 = self.dis_reg(last_hidden_state2)
455
+ # Apply the cross-attention layer
456
+ prot_fused, dis_fused = self.cross_attention_layer(
457
+ last_hidden_state1, last_hidden_state2, prot_attention_mask, dis_attention_mask
458
+ )
459
+ # last_hidden_state1 = self.prot_encoder(
460
+ # query_toks1, return_dict=True
461
+ # ).last_hidden_state
462
+
463
+ # last_hidden_state1 = self.prot_reg(
464
+ # last_hidden_state1
465
+ # ) # transform the prot embedding into the same dimension as the disease embedding
466
+ # last_hidden_state2 = self.disease_encoder(
467
+ # query_toks2, return_dict=True
468
+ # ).last_hidden_state
469
+ # last_hidden_state2 = self.dis_reg(
470
+ # last_hidden_state2
471
+ # ) # transform the disease embedding into 1024
472
+
473
+ # # Apply the fusion layer and Recovery of representational shape
474
+ # prot_fused, dis_fused = self.fusion_layer(last_hidden_state1, last_hidden_state2)
475
+ if self.agg_mode == "cls":
476
+ query_embed1 = prot_fused[:, 0] # query : [batch_size, hidden]
477
+ query_embed2 = dis_fused[:, 0] # query : [batch_size, hidden]
478
+ elif self.agg_mode == "mean_all_tok":
479
+ query_embed1 = prot_fused.mean(1) # query : [batch_size, hidden]
480
+ query_embed2 = dis_fused.mean(1) # query : [batch_size, hidden]
481
+ elif self.agg_mode == "mean":
482
+ query_embed1 = (
483
+ prot_fused * query_toks1["attention_mask"].unsqueeze(-1)
484
+ ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
485
+ query_embed2 = (
486
+ dis_fused * query_toks2["attention_mask"].unsqueeze(-1)
487
+ ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
488
+ else:
489
+ raise NotImplementedError()
490
+
491
+ # print("query_embed1 =", query_embed1.shape, "query_embed2 =", query_embed2.shape)
492
+ query_embed = torch.cat([query_embed1, query_embed2], dim=0)
493
+ # print("query_embed =", len(query_embed))
494
+
495
+ labels = torch.cat([torch.arange(len(labels)), torch.arange(len(labels))], dim=0)
496
+
497
+ if self.use_miner:
498
+ hard_pairs = self.miner(query_embed, labels)
499
+ return self.loss(query_embed, labels, hard_pairs)# + loss_mmp
500
+ else:
501
+ loss = self.loss(query_embed, labels)# + loss_mmp
502
+ # print('loss :', loss)
503
+ return loss
504
+
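+ # In forward(), query_embed stacks the protein embeddings on top of the disease
+ # embeddings (dim=0) and labels become [0..N-1, 0..N-1], so protein i and
+ # disease i share a label and are treated as a positive pair by the metric
+ # losses (with NTXentLoss this is InfoNCE over matched pairs). For a batch of
+ # 3 pairs the labels are simply [0, 1, 2, 0, 1, 2].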
505
+ def get_embeddings(self, mentions, batch_size=1024):
506
+ """
507
+ Compute all embeddings from mention tokens.
508
+ """
509
+ embedding_table = []
510
+ with torch.no_grad():
511
+ for start in tqdm(range(0, len(mentions), batch_size)):
512
+ end = min(start + batch_size, len(mentions))
513
+ batch = mentions[start:end]
514
+ batch_embedding = self.vectorizer(batch)
515
+ batch_embedding = batch_embedding.cpu()
516
+ embedding_table.append(batch_embedding)
517
+ embedding_table = torch.cat(embedding_table, dim=0)
518
+ return embedding_table
519
+
520
+
521
+
522
+ class DDA_Metric_Learning(Module):
523
+ def __init__(self, disease_encoder, args):
524
+ """Constructor for the model.
525
+
526
+ Args:
527
+ disease_encoder (_type_): disease encoder.
528
+ args (_type_): _description_
529
+ """
530
+ super(DDA_Metric_Learning, self).__init__()
531
+ self.disease_encoder = disease_encoder
532
+ self.loss = args.loss
533
+ self.use_miner = args.use_miner
534
+ self.miner_margin = args.miner_margin
535
+ self.agg_mode = args.agg_mode
536
+ self.disease_adapter_name = None
537
+ if self.use_miner:
538
+ self.miner = miners.TripletMarginMiner(
539
+ margin=args.miner_margin, type_of_triplets="all"
540
+ )
541
+ else:
542
+ self.miner = None
543
+
544
+ if self.loss == "ms_loss":
545
+ self.loss = losses.MultiSimilarityLoss(
546
+ alpha=1, beta=60, base=0.5
547
+ ) # 1,2,3; 40,50,60
548
+ elif self.loss == "circle_loss":
549
+ self.loss = losses.CircleLoss()
550
+ elif self.loss == "triplet_loss":
551
+ self.loss = losses.TripletMarginLoss()
552
+ elif self.loss == "infoNCE":
553
+ self.loss = losses.NTXentLoss(
554
+ temperature=0.07
555
+ ) # The MoCo paper uses 0.07, while SimCLR uses 0.5.
556
+ elif self.loss == "lifted_structure_loss":
557
+ self.loss = losses.LiftedStructureLoss()
558
+ elif self.loss == "nca_loss":
559
+ self.loss = losses.NCALoss()
560
+ self.reg = None
561
+ self.cls = None
562
+ self.dropout = torch.nn.Dropout(args.dropout)
563
+ print("miner:", self.miner)
564
+ print("loss:", self.loss)
565
+
566
+ def add_classification_head(self, disease_out_dim=768, out_dim=2):
567
+ """Add regression head.
568
+
569
+ Args:
570
+ disease_out_dim (int, optional): disease encoder output dimension. Defaults to 768.
571
+ out_dim (int, optional): output dimension. Defaults to 2.
573
+ """
574
+ self.cls = nn.Linear(disease_out_dim * 2, out_dim)
575
+
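+ # The head consumes the concatenated pair embedding produced by predict():
+ # two 768-dim disease embeddings -> a 1536-dim vector -> out_dim logits.
+ # Minimal sketch (x1 and x2 are assumed to be tokenized disease inputs):
+ #
+ #   model.add_classification_head(disease_out_dim=768, out_dim=2)
+ #   logits = model.cls(model.predict(x1, x2))  # shape: (batch_size, 2)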
576
+ def load_disease_adapter(
577
+ self,
578
+ disease_model_path,
579
+ disease_adapter_name="disease_adapter",
580
+ ):
581
+ if os.path.exists(disease_model_path):
582
+ self.disease_adapter_name = disease_adapter_name
583
+ self.disease_encoder.load_adapter(
584
+ disease_model_path, load_as=disease_adapter_name
585
+ )
586
+ self.disease_encoder.set_active_adapters(disease_adapter_name)
587
+ print(
588
+ f"load disease adapters from: {disease_model_path} {disease_adapter_name}"
589
+ )
590
+ else:
591
+ print(f"{disease_adapter_name} not exits")
592
+
593
+ def init_adapters(
594
+ self,
595
+ disease_adapter_name="disease_adapter",
596
+ reduction_factor=16,
597
+ ):
598
+ """Initialise adapters
599
+
600
+ Args:
601
+ disease_adapter_name (str, optional): _description_. Defaults to "disease_adapter".
602
+ reduction_factor (int, optional): _description_. Defaults to 16.
603
+ """
604
+ adapter_config = AdapterConfig.load(
605
+ "pfeiffer", reduction_factor=reduction_factor
606
+ )
607
+ self.disease_adapter_name = disease_adapter_name
608
+ self.disease_encoder.add_adapter(disease_adapter_name, config=adapter_config)
609
+ self.disease_encoder.train_adapter([disease_adapter_name])
610
+
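+ # Background note (general adapter-transformers behaviour, not specific to this
+ # repository): in the "pfeiffer" config, reduction_factor sets the adapter
+ # bottleneck to roughly hidden_size / reduction_factor, e.g. 768 / 16 = 48 for
+ # a BERT-base disease encoder.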
611
+ def save_adapters(self, save_path_prefix, total_step):
612
+ """Save adapters into file.
613
+
614
+ Args:
615
+ save_path_prefix (string): saving path prefix.
616
+ total_step (int): total step number.
617
+ """
618
+ disease_save_dir = os.path.join(
619
+ save_path_prefix, f"disease_adapter_step_{total_step}"
620
+ )
621
+ os.makedirs(disease_save_dir, exist_ok=True)
622
+ self.disease_encoder.save_adapter(disease_save_dir, self.disease_adapter_name)
623
+
624
+ def predict(self, x1, x2):
625
+ """
626
+ query : (N, h), candidates : (N, topk, h)
627
+ output : (N, topk)
628
+
629
+ """
630
+ if self.agg_mode == "cls":
631
+ x1 = self.disease_encoder(x1).last_hidden_state[:, 0]
632
+ x2 = self.disease_encoder(x2).last_hidden_state[:, 0]
633
+ x = torch.cat((x1, x2), 1)
634
+ return x
635
+ else:
636
+ x1 = self.disease_encoder(x1).last_hidden_state.mean(1) # query : [batch_size, hidden]
637
+ x2 = self.disease_encoder(x2).last_hidden_state.mean(1) # query : [batch_size, hidden]
638
+ x = torch.cat((x1, x2), 1)
639
+ return x
640
+
641
+ def module_predict(self, x1, x2):
642
+ """
643
+ query : (N, h), candidates : (N, topk, h)
644
+ output : (N, topk)
645
+
646
+ """
647
+ if self.agg_mode == "cls":
648
+ x1 = self.disease_encoder.module(x1).last_hidden_state[:, 0]
649
+ x2 = self.disease_encoder.module(x2).last_hidden_state[:, 0]
650
+ x = torch.cat((x1, x2), 1)
651
+ return x
652
+ else:
653
+ x1 = self.disease_encoder.module(x1).last_hidden_state.mean(1) # query : [batch_size, hidden]
654
+ x2 = self.disease_encoder.module(x2).last_hidden_state.mean(1) # query : [batch_size, hidden]
655
+ x = torch.cat((x1, x2), 1)
656
+ return x
657
+
658
+ @autocast()
659
+ def forward(self, query_toks1, query_toks2, labels):
660
+ """
661
+ query : (N, h), candidates : (N, topk, h)
662
+ output : (N, topk)
663
+ """
664
+ last_hidden_state1 = self.disease_encoder(
665
+ **query_toks1, return_dict=True
666
+ ).last_hidden_state
667
+ last_hidden_state2 = self.disease_encoder(
668
+ **query_toks2, return_dict=True
669
+ ).last_hidden_state
670
+ if self.agg_mode == "cls":
671
+ query_embed1 = last_hidden_state1[:, 0] # query : [batch_size, hidden]
672
+ query_embed2 = last_hidden_state2[:, 0] # query : [batch_size, hidden]
673
+ elif self.agg_mode == "mean_all_tok":
674
+ query_embed1 = last_hidden_state1.mean(1) # query : [batch_size, hidden]
675
+ query_embed2 = last_hidden_state2.mean(1) # query : [batch_size, hidden]
676
+ elif self.agg_mode == "mean":
677
+ query_embed1 = (
678
+ last_hidden_state1 * query_toks1["attention_mask"].unsqueeze(-1)
679
+ ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
680
+ query_embed2 = (
681
+ last_hidden_state2 * query_toks2["attention_mask"].unsqueeze(-1)
682
+ ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
683
+ else:
684
+ raise NotImplementedError()
685
+ query_embed = torch.cat([query_embed1, query_embed2], dim=0)
686
+
687
+ labels = torch.cat([labels, labels], dim=0)
688
+ if self.use_miner:
689
+ hard_pairs = self.miner(query_embed, labels)
690
+ print('miner used')
691
+ return self.loss(query_embed, labels, hard_pairs)
692
+ else:
693
+ print('no miner')
694
+ return self.loss(query_embed, labels)
695
+
696
+
697
+ class PPI_Metric_Learning(Module):
698
+ def __init__(self, prot_encoder, args):
699
+ """Constructor for the model.
700
+
701
+ Args:
702
+ prot_encoder (_type_): Protein encoder.
706
+ args (_type_): _description_
707
+ """
708
+ super(PPI_Metric_Learning, self).__init__()
709
+ self.prot_encoder = prot_encoder
710
+ self.loss = args.loss
711
+ self.use_miner = args.use_miner
712
+ self.miner_margin = args.miner_margin
713
+ self.agg_mode = args.agg_mode
714
+ self.prot_adapter_name = None
715
+ if self.use_miner:
716
+ self.miner = miners.TripletMarginMiner(
717
+ margin=args.miner_margin, type_of_triplets="all"
718
+ )
719
+ else:
720
+ self.miner = None
721
+
722
+ if self.loss == "ms_loss":
723
+ self.loss = losses.MultiSimilarityLoss(
724
+ alpha=1, beta=60, base=0.5
725
+ ) # 1,2,3; 40,50,60
726
+ elif self.loss == "circle_loss":
727
+ self.loss = losses.CircleLoss()
728
+ elif self.loss == "triplet_loss":
729
+ self.loss = losses.TripletMarginLoss()
730
+ elif self.loss == "infoNCE":
731
+ self.loss = losses.NTXentLoss(
732
+ temperature=0.07
733
+ ) # The MoCo paper uses 0.07, while SimCLR uses 0.5.
734
+ elif self.loss == "lifted_structure_loss":
735
+ self.loss = losses.LiftedStructureLoss()
736
+ elif self.loss == "nca_loss":
737
+ self.loss = losses.NCALoss()
738
+ self.reg = None
739
+ self.cls = None
740
+ self.dropout = torch.nn.Dropout(args.dropout)
741
+ print("miner:", self.miner)
742
+ print("loss:", self.loss)
743
+
744
+ def add_classification_head(self, prot_out_dim=1024, out_dim=2):
745
+ """Add regression head.
746
+
747
+ Args:
748
+ prot_out_dim (int, optional): protein encoder output dimension. Defaults to 1024.
749
+ out_dim (int, optional): output dimension. Defaults to 2.
752
+ """
753
+ self.cls = nn.Linear(prot_out_dim + prot_out_dim, out_dim)
754
+
755
+ def load_prot_adapter(
756
+ self,
757
+ prot_model_path,
758
+ prot_adapter_name="prot_adapter",
759
+ ):
760
+ if os.path.exists(prot_model_path):
761
+ self.prot_adapter_name = prot_adapter_name
762
+ self.prot_encoder.load_adapter(prot_model_path, load_as=prot_adapter_name)
763
+ self.prot_encoder.set_active_adapters(prot_adapter_name)
764
+ print(f"load protein adapters from: {prot_model_path} {prot_adapter_name}")
765
+ else:
766
+ print(f"{prot_model_path} not exits")
767
+
768
+ def init_adapters(
769
+ self,
770
+ prot_adapter_name="prot_adapter",
771
+ reduction_factor=16,
772
+ ):
773
+ """Initialise adapters
774
+
775
+ Args:
776
+ prot_adapter_name (str, optional): _description_. Defaults to "prot_adapter".
777
+ reduction_factor (int, optional): _description_. Defaults to 16.
778
+ """
779
+ adapter_config = AdapterConfig.load(
780
+ "pfeiffer", reduction_factor=reduction_factor
781
+ )
782
+ self.prot_adapter_name = prot_adapter_name
783
+ self.prot_encoder.add_adapter(prot_adapter_name, config=adapter_config)
784
+ self.prot_encoder.train_adapter([prot_adapter_name])
785
+
786
+ def save_adapters(self, save_path_prefix, total_step):
787
+ """Save adapters into file.
788
+
789
+ Args:
790
+ save_path_prefix (string): saving path prefix.
791
+ total_step (int): total step number.
792
+ """
793
+ prot_save_dir = os.path.join(
794
+ save_path_prefix, f"prot_adapter_step_{total_step}"
795
+ )
796
+ os.makedirs(prot_save_dir, exist_ok=True)
797
+ self.prot_encoder.save_adapter(prot_save_dir, self.prot_adapter_name)
798
+
799
+ def predict(self, x1, x2):
800
+ """
801
+ query : (N, h), candidates : (N, topk, h)
802
+ output : (N, topk)
803
+
804
+ """
805
+
806
+ if self.agg_mode == "cls":
807
+ x1 = self.prot_encoder(x1).last_hidden_state[:, 0]
808
+ x2 = self.prot_encoder(x2).last_hidden_state[:, 0]
809
+ x = torch.cat((x1, x2), 1)
810
+ return x
811
+ else:
812
+ x1 = self.prot_encoder(x1).last_hidden_state.mean(1) # query : [batch_size, hidden]
813
+ x2 = self.prot_encoder(x2).last_hidden_state.mean(1) # query : [batch_size, hidden]
814
+ x = torch.cat((x1, x2), 1)
815
+ return x
816
+
817
+ def module_predict(self, x1, x2):
818
+ """
819
+ query : (N, h), candidates : (N, topk, h)
820
+ output : (N, topk)
821
+
822
+ """
823
+ if self.agg_mode == "cls":
824
+ x1 = self.prot_encoder.module(x1).last_hidden_state[:, 0]
825
+ x2 = self.prot_encoder.module(x2).last_hidden_state[:, 0]
826
+ x = torch.cat((x1, x2), 1)
827
+ return x
828
+ else:
829
+ x1 = self.prot_encoder.module(x1).last_hidden_state.mean(1) # query : [batch_size, hidden]
830
+ x2 = self.prot_encoder.module(x2).last_hidden_state.mean(1) # query : [batch_size, hidden]
831
+ x = torch.cat((x1, x2), 1)
832
+ return x
833
+
834
+ @autocast()
835
+ def forward(self, query_toks1, query_toks2, labels):
836
+ """
837
+ query : (N, h), candidates : (N, topk, h)
838
+ output : (N, topk)
839
+ """
840
+ last_hidden_state1 = self.prot_encoder(
841
+ **query_toks1, return_dict=True
842
+ ).last_hidden_state
843
+ last_hidden_state2 = self.prot_encoder(
844
+ **query_toks2, return_dict=True
845
+ ).last_hidden_state
846
+ if self.agg_mode == "cls":
847
+ query_embed1 = last_hidden_state1[:, 0] # query : [batch_size, hidden]
848
+ query_embed2 = last_hidden_state2[:, 0] # query : [batch_size, hidden]
849
+ elif self.agg_mode == "mean_all_tok":
850
+ query_embed1 = last_hidden_state1.mean(1) # query : [batch_size, hidden]
851
+ query_embed2 = last_hidden_state2.mean(1) # query : [batch_size, hidden]
852
+ elif self.agg_mode == "mean":
853
+ query_embed1 = (
854
+ last_hidden_state1 * query_toks1["attention_mask"].unsqueeze(-1)
855
+ ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
856
+ query_embed2 = (
857
+ last_hidden_state2 * query_toks2["attention_mask"].unsqueeze(-1)
858
+ ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
859
+ else:
860
+ raise NotImplementedError()
861
+ query_embed = torch.cat([query_embed1, query_embed2], dim=0)
862
+
863
+ labels = torch.cat([labels, labels], dim=0)
864
+ if self.use_miner:
865
+ hard_pairs = self.miner(query_embed, labels)
866
+ return self.loss(query_embed, labels, hard_pairs)
867
+ else:
868
+ return self.loss(query_embed, labels)
869
+