LTEnjoy committed
Commit 619ec19
1 Parent(s): d578510

Add application file
Dockerfile CHANGED
@@ -5,8 +5,6 @@ FROM continuumio/anaconda3:main
 
 WORKDIR /code
 
-WORKDIR /code
-
 COPY ./requirements.txt /code/requirements.txt
 
 RUN apt-get update
bin/README.md ADDED
@@ -0,0 +1 @@
+# Place the Foldseek binary file here
demo/__init__.py ADDED
File without changes
demo/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
+import sys
+
+sys.path += []
+
+import argparse
+
+
+def main():
+    pass
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = get_args()
+    main()
demo/modules/compute_score.py ADDED
@@ -0,0 +1,113 @@
+import gradio as gr
+import torch
+
+from .init_model import model
+from utils.foldseek_util import get_struc_seq
+
+
+def compute_seq_text_score(input_1: str, input_2: str):
+    with torch.no_grad():
+        protein_embedding = model.get_protein_repr([input_1])
+        text_embedding = model.get_text_repr([input_2])
+        score = text_embedding @ protein_embedding.T / model.temperature
+
+    return f"{score.item():.4f}"
+
+
+def compute_struc_text_score(input_1: str, input_2: str):
+    with torch.no_grad():
+        protein_embedding = model.get_structure_repr([input_1])
+        text_embedding = model.get_text_repr([input_2])
+        score = text_embedding @ protein_embedding.T / model.temperature
+
+    return f"{score.item():.4f}"
+
+
+def compute_seq_struc_score(input_1: str, input_2: str):
+    with torch.no_grad():
+        protein_embedding_1 = model.get_protein_repr([input_1])
+        protein_embedding_2 = model.get_structure_repr([input_2])
+        score = protein_embedding_1 @ protein_embedding_2.T / model.temperature
+
+    return f"{score.item():.4f}"
+
+
+# Parse the uploaded structure file and return the sequence
+def pdb2seq(file):
+    parsed_seqs = get_struc_seq("/sujin/bin/foldseek", file)
+
+    for seqs in parsed_seqs.values():
+        return seqs[0]
+
+
+# Parse the uploaded structure file and return the foldseek sequence
+def pdb2foldseek(file):
+    parsed_seqs = get_struc_seq("/sujin/bin/foldseek", file)
+
+    for seqs in parsed_seqs.values():
+        return seqs[1].lower()
+
+
+# Build the block for computing protein-text similarity
+def build_score_computation():
+    gr.Markdown("# Compute similarity score between two modalities")
+    with gr.Row(equal_height=True):
+        with gr.Column():
+            # Compute similarity score between sequence and text
+            with gr.Tab("sequence - text"):
+                with gr.Row():
+                    seq_text_input_1 = gr.Textbox(label="sequence")
+
+                    # Provide an upload button to upload a pdb file
+                    upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", scale=0)
+                    upload_btn.upload(pdb2seq, inputs=[upload_btn], outputs=[seq_text_input_1])
+
+                seq_text_input_2 = gr.Textbox(label="text")
+                seq_text_examples = gr.Examples(examples=[["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK", "Proteins with zinc bindings."],
+                                                          ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT", "Proteins locating at cell membrane."],
+                                                          ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE", "Human represents the name assigned to the organism responsible for the protein sequence."]],
+                                                inputs=[seq_text_input_1, seq_text_input_2])
+                seq_text_btn = gr.Button(value="Compute")
+
+            # Compute similarity score between structure and text
+            with gr.Tab("structure - text"):
+                with gr.Row():
+                    struc_text_input_1 = gr.Textbox(label="structure")
+
+                    # Provide an upload button to upload a pdb file
+                    upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", scale=0)
+                    upload_btn.upload(pdb2foldseek, inputs=[upload_btn], outputs=[struc_text_input_1])
+
+                struc_text_input_2 = gr.Textbox(label="text")
+                struc_text_examples = gr.Examples(examples=[["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd", "Proteins with zinc bindings."],
+                                                            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd", "Proteins locating at cell membrane."],
+                                                            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd", "Human represents the name assigned to the organism responsible for the protein sequence."]],
+                                                  inputs=[struc_text_input_1, struc_text_input_2])
+                struc_text_btn = gr.Button(value="Compute")
+
+            # Compute similarity score between sequence and structure
+            with gr.Tab("sequence - structure"):
+                with gr.Row():
+                    seq_struc_input_1 = gr.Textbox(label="sequence")
+
+                    # Provide an upload button to upload a pdb file
+                    upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", scale=0)
+                    upload_btn.upload(pdb2seq, inputs=[upload_btn], outputs=[seq_struc_input_1])
+
+                with gr.Row():
+                    seq_struc_input_2 = gr.Textbox(label="structure")
+
+                    # Provide an upload button to upload a pdb file
+                    upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", scale=0)
+                    upload_btn.upload(pdb2foldseek, inputs=[upload_btn], outputs=[seq_struc_input_2])
+
+                seq_struc_examples = gr.Examples(examples=[["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK", "dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
+                                                           ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT", "dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
+                                                           ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE", "dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]],
+                                                 inputs=[seq_struc_input_1, seq_struc_input_2])
+                seq_struc_btn = gr.Button(value="Compute")
+
+    similarity_score = gr.Label(label="similarity score")
+    seq_text_btn.click(fn=compute_seq_text_score, inputs=[seq_text_input_1, seq_text_input_2], outputs=[similarity_score])
+    struc_text_btn.click(fn=compute_struc_text_score, inputs=[struc_text_input_1, struc_text_input_2], outputs=[similarity_score])
+    seq_struc_btn.click(fn=compute_seq_struc_score, inputs=[seq_struc_input_1, seq_struc_input_2], outputs=[similarity_score])
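Each `Compute` button above reports the same quantity: the dot product of two L2-normalized embeddings divided by the learned temperature. A minimal standalone sketch of that scoring rule (the `toy_encode` helper is hypothetical, standing in for `model.get_*_repr`):

```python
import torch

def toy_encode(x: str, dim: int = 8) -> torch.Tensor:
    # Hypothetical stand-in for model.get_protein_repr / get_text_repr:
    # returns one L2-normalized embedding for the input string.
    torch.manual_seed(len(x))
    return torch.nn.functional.normalize(torch.randn(1, dim), dim=-1)

temperature = 0.07                                 # plays the role of model.temperature
a = toy_encode("MSATAEQNARNPK")                    # truncated sequence, illustrative only
b = toy_encode("Proteins with zinc bindings.")
score = a @ b.T / temperature                      # cosine similarity / temperature
print(f"{score.item():.4f}")
```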
demo/modules/init_model.py ADDED
@@ -0,0 +1,77 @@
+import faiss
+import pandas as pd
+import os
+
+from utils.constants import sequence_level
+from model.ProtTrek.protrek_trimodal_model import ProTrekTrimodalModel
+
+
+def load_model():
+    config = {
+        "protein_config": "weights/ProTrek_35M_UniRef50/esm2_t12_35M_UR50D",
+        "text_config": "weights/ProTrek_35M_UniRef50/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
+        "structure_config": "weights/ProTrek_35M_UniRef50/foldseek_t12_35M",
+        "load_protein_pretrained": False,
+        "load_text_pretrained": False,
+        "from_checkpoint": "weights/ProTrek_35M_UniRef50/ProTrek_35M_UniRef50.pt"
+    }
+
+    model = ProTrekTrimodalModel(**config)
+    model.eval()
+    return model
+
+
+def load_index():
+    index_dir = "weights/faiss_index/faiss_index_ProTrek_35M_UniRef50"
+    all_index = {}
+
+    # Load protein sequence index
+    index_path = f"{index_dir}/sequence.index"
+    sequence_index = faiss.read_index(index_path)
+
+    id_path = f"{index_dir}/sequence_ids.tsv"
+    uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
+
+    all_index["sequence"] = {"index": sequence_index, "ids": uniprot_ids}
+
+    # Load protein structure index
+    index_path = f"{index_dir}/structure.index"
+    structure_index = faiss.read_index(index_path)
+
+    id_path = f"{index_dir}/structure_ids.tsv"
+    uniprot_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
+
+    all_index["structure"] = {"index": structure_index, "ids": uniprot_ids}
+
+    # Load text index
+    all_index["text"] = {}
+    text_dir = f"{index_dir}/text"
+
+    # Remove "Taxonomic lineage" from sequence_level. This is a special case which we don't need to index.
+    valid_subsections = set()
+    sequence_level.add("Global")
+    for subsection in sequence_level:
+        index_path = f"{text_dir}/{subsection.replace(' ', '_')}.index"
+        if not os.path.exists(index_path):
+            continue
+
+        text_index = faiss.read_index(index_path)
+
+        id_path = f"{text_dir}/{subsection.replace(' ', '_')}_ids.tsv"
+        text_ids = pd.read_csv(id_path, sep="\t", header=None).values.flatten()
+
+        all_index["text"][subsection] = {"index": text_index, "ids": text_ids}
+        valid_subsections.add(subsection)
+
+    return all_index, valid_subsections
+
+
+device = "cuda"
+
+print("Loading model...")
+model = load_model()
+model.to(device)
+
+print("Loading index...")
+all_index, valid_subsections = load_index()
+print("Done...")
demo/modules/search.py ADDED
@@ -0,0 +1,163 @@
+import gradio as gr
+import torch
+import pandas as pd
+
+from utils.foldseek_util import get_struc_seq
+from .init_model import model, all_index
+
+
+# Samples for input
+samples = [
+    ["Proteins with zinc bindings."],
+    ["Proteins locating at cell membrane."],
+    ["Protein that serves as an enzyme."]
+]
+
+# Choices for subsection type
+# valid_subsections = {"Function", "Subcellular location", "Protein names", "Sequence similarities", "GO annotation", "Global"}
+valid_subsections = all_index["text"].keys()
+# Sort the subsections
+valid_subsections = sorted(valid_subsections)
+
+
+def clear_results():
+    return ""
+
+
+# Search from database
+def search(input: str, topk: int, input_type: str, query_type: str, subsection_type: str):
+    input_modality = input_type.split(" ")[-1].replace("sequence", "protein")
+    with torch.no_grad():
+        input_embedding = getattr(model, f"get_{input_modality}_repr")([input]).cpu().numpy()
+
+    output_modality = query_type.split(" ")[-1]
+    if output_modality == "text":
+        index = all_index["text"][subsection_type]["index"]
+        ids = all_index["text"][subsection_type]["ids"]
+
+    else:
+        index = all_index[output_modality]["index"]
+        ids = all_index[output_modality]["ids"]
+
+    scores, ranks = index.search(input_embedding, topk)
+    scores = scores / model.temperature.item()
+
+    # Get topk ids
+    topk_ids = []
+    for rank in ranks[0]:
+        now_id = ids[rank]
+        if query_type == "text":
+            topk_ids.append(now_id)
+        else:
+            # Provide link to uniprot website
+            topk_ids.append(f"[{now_id}](https://www.uniprot.org/uniprotkb/{now_id})")
+
+    df = pd.DataFrame({"Id": topk_ids, "Matching score": scores[0]})
+    output = df.to_markdown()
+
+    return output
+
+
+def change_input_type(choice: str):
+    # Change examples if input type is changed
+    global samples
+    if choice == "text":
+        samples = [
+            ["Proteins with zinc bindings."],
+            ["Proteins locating at cell membrane."],
+            ["Protein that serves as an enzyme."]
+        ]
+
+    elif choice == "protein sequence":
+        samples = [
+            ["MSATAEQNARNPKGKGGFARTVSQRKRKRLFLIGGALAVLAVAVGLMLTAFNQDIRFFRTPADLTEQDMTSGARFRLGGLVEEGSVSRTGSELRFTVTDTIKTVKVVFEGIPPDLFREGQGVVAEGRFGSDGLFRADNVLAKHDENYVPKDLADSLKKKGVWEGK"],
+            ["MITLDWEKANGLITTVVQDATTKQVLMVAYMNQESLAKTMATGETWFWSRSRKTLWHKGATSGNIQTVKTIAVDCDADTLLVTVDPAGPACHTGHISCFYRHYPEGKDLT"],
+            ["MDLKQYVSEVQDWPKPGVSFKDITTIMDNGEAYGYATDKIVEYAKDRDVDIVVGPEARGFIIGCPVAYSMGIGFAPVRKEGKLPREVIRYEYDLEYGTNVLTMHKDAIKPGQRVLITDDLLATGGTIEAAIKLVEKLGGIVVGIAFIIELKYLNGIEKIKDYDVMSLISYDE"]
+        ]
+
+    elif choice == "protein structure":
+        samples = [
+            ["dddddddddddddddpdpppvcppvnvvvvvvvvvvvvvvvvvvvvvvvvvvqdpqdedeqvrddpcqqpvqhkhkykafwappqwdddpqkiwtwghnppgiaieieghdappqddhrfikifiaghdpvrhtygdhidtdddpddddvvnvvvcvvvvndpdd"],
+            ["dddadcpvpvqkakefeaeppprdtadiaiagpvqvvvcvvpqwhwgqdpvvrdidgqcpvpvqiwrwddwdaddnrryiytythtpahsdpvrhvhpppadvvgpddpd"],
+            ["dplvvqwdwdaqpphhpdtdthcvscvvppvslvvqlvvvlvvcvvqvaqeeeeepdqrcsnrvsscvvvvhyywykyfpppddaawdwdwdddppgitiiithlpseaaageyeyegaeqalqprvlrvvvrcvvnnyddaeyeyqeyevcrvncvsvvvhhydyvyydpd"]
+        ]
+
+    # Set visibility of upload button
+    if choice == "text":
+        visible = False
+    else:
+        visible = True
+
+    return samples, "", gr.update(visible=visible)
+
+
+# Load example from dataset
+def load_example(example_id):
+    return samples[example_id][0]
+
+
+# Change the visibility of subsection type
+def subsection_visibility(query_type: str):
+    if query_type == "text":
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
+
+# Parse the uploaded structure file
+def parse_pdb_file(input_type, file):
+    parsed_seqs = get_struc_seq("bin/foldseek", file)
+
+    for seqs in parsed_seqs.values():
+        if input_type == "protein sequence":
+            return seqs[0]
+        else:
+            return seqs[1].lower()
+
+
+# Build the block for text to protein
+def build_search_module():
+    gr.Markdown("# Search from Swiss-Prot database (the whole UniProt database will be supported soon)")
+    with gr.Row(equal_height=True):
+        with gr.Column():
+            # Set input type
+            input_type = gr.Radio(["protein sequence", "protein structure", "text"], label="Input type (e.g. 'text' means searching based on text descriptions)", value="text")
+
+            with gr.Row():
+                # Set query type
+                query_type = gr.Radio(["protein sequence", "protein structure", "text"], label="Query type (e.g. 'protein sequence' means returning qualified protein sequences)", value="protein sequence")
+
+                # If the query type is "text", provide an option to choose the subsection of text
+                subsection_type = gr.Dropdown(list(valid_subsections), label="Subsection of text", value="Function",
+                                              scale=0, interactive=True, visible=False)
+
+                # Add event listener to query type
+                query_type.change(fn=subsection_visibility, inputs=[query_type], outputs=[subsection_type])
+
+            with gr.Row():
+                # Input box
+                input = gr.Text(label="Input")
+
+                # Provide an upload button to upload a pdb file
+                upload_btn = gr.UploadButton(label="Upload .pdb/.cif file", scale=0, visible=False)
+                upload_btn.upload(parse_pdb_file, inputs=[input_type, upload_btn], outputs=[input])
+
+            # Choose topk results
+            topk = gr.Slider(1, 100, 5, step=1, label="Retrieve top k results")
+
+            # Provide examples
+            examples = gr.Dataset(samples=samples, components=[input], type="index", label="Input examples")
+
+            # Add click event to examples
+            examples.click(fn=load_example, inputs=[examples], outputs=input)
+
+            # Change examples based on input type
+            input_type.change(fn=change_input_type, inputs=[input_type], outputs=[examples, input, upload_btn])
+
+            with gr.Row():
+                t2p_btn = gr.Button(value="Search")
+                clear_btn = gr.Button(value="Clear")
+
+    results = gr.Markdown(label="results")
+    t2p_btn.click(fn=search, inputs=[input, topk, input_type, query_type, subsection_type], outputs=results)
+    clear_btn.click(fn=clear_results, outputs=results)
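`search` dispatches on strings: the radio choice is reduced to a modality name and resolved to an encoder method via `getattr`, so the three options map onto `get_protein_repr`, `get_structure_repr`, and `get_text_repr` without branching. The mapping rule in isolation:

```python
def modality_of(input_type: str) -> str:
    # "protein sequence" -> "protein", "protein structure" -> "structure", "text" -> "text"
    return input_type.split(" ")[-1].replace("sequence", "protein")

for choice in ("protein sequence", "protein structure", "text"):
    print(f"{choice!r} -> model.get_{modality_of(choice)}_repr")
```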
demo/run.py ADDED
@@ -0,0 +1,22 @@
+import sys
+root_dir = __file__.rsplit("/", 2)[0]
+if root_dir not in sys.path:
+    sys.path.append(root_dir)
+
+import gradio as gr
+
+from modules.search import build_search_module
+from modules.compute_score import build_score_computation
+
+
+# Build demo
+with gr.Blocks() as demo:
+    build_search_module()
+    build_score_computation()
+
+
+if __name__ == '__main__':
+    # args = get_args()
+
+    # Run demo
+    demo.launch()
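Prepending the repository root to `sys.path` lets the bare `modules.*` and `model.*` imports resolve when the script is launched directly (e.g. `python demo/run.py`) rather than as a package. `demo.launch()` then serves the Gradio app on Gradio's default local address.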
model/ProtTrek/protein_encoder.py ADDED
@@ -0,0 +1,95 @@
+import torch
+
+from tqdm import tqdm
+from torch.nn.functional import normalize
+from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
+
+
+class ProteinEncoder(torch.nn.Module):
+    def __init__(self,
+                 config_path: str,
+                 out_dim: int,
+                 load_pretrained: bool = True,
+                 gradient_checkpointing: bool = False):
+        """
+        Args:
+            config_path: Path to the config file
+
+            out_dim: Output dimension of the protein representation
+
+            load_pretrained: Whether to load pretrained weights
+
+            gradient_checkpointing: Whether to use gradient checkpointing
+        """
+        super().__init__()
+        config = EsmConfig.from_pretrained(config_path)
+        if load_pretrained:
+            self.model = EsmForMaskedLM.from_pretrained(config_path)
+        else:
+            self.model = EsmForMaskedLM(config)
+        self.out = torch.nn.Linear(config.hidden_size, out_dim)
+
+        # Set gradient checkpointing
+        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing
+
+        # Remove contact head
+        self.model.esm.contact_head = None
+
+        # Remove position embedding if the embedding type is ``rotary``
+        if config.position_embedding_type == "rotary":
+            self.model.esm.embeddings.position_embeddings = None
+
+        self.tokenizer = EsmTokenizer.from_pretrained(config_path)
+
+    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
+        """
+        Compute protein representation for the given proteins
+        Args:
+            proteins: A list of protein sequences
+            batch_size: Batch size for inference
+            verbose: Whether to print progress
+        """
+        device = next(self.parameters()).device
+
+        protein_repr = []
+        if verbose:
+            iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
+        else:
+            iterator = range(0, len(proteins), batch_size)
+
+        for i in iterator:
+            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
+                                                              return_tensors="pt",
+                                                              padding=True)
+            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
+            output, _ = self.forward(protein_inputs)
+
+            protein_repr.append(output)
+
+        protein_repr = torch.cat(protein_repr, dim=0)
+        return normalize(protein_repr, dim=-1)
+
+    def forward(self, inputs: dict, get_mask_logits: bool = False):
+        """
+        Encode protein sequence into protein representation
+        Args:
+            inputs: A dictionary containing the following keys:
+                - input_ids: [batch, seq_len]
+                - attention_mask: [batch, seq_len]
+            get_mask_logits: Whether to return the logits for masked tokens
+
+        Returns:
+            protein_repr: [batch, protein_repr_dim]
+            mask_logits : [batch, seq_len, vocab_size]
+        """
+        last_hidden_state = self.model.esm(**inputs).last_hidden_state
+        reprs = last_hidden_state[:, 0, :]
+        reprs = self.out(reprs)
+
+        # Get logits for masked tokens
+        if get_mask_logits:
+            mask_logits = self.model.lm_head(last_hidden_state)
+        else:
+            mask_logits = None
+
+        return reprs, mask_logits
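`get_repr` batches the inputs, takes the first-token hidden state through the linear projection `self.out`, and L2-normalizes the result. A hedged usage sketch; `facebook/esm2_t12_35M_UR50D` is the public ESM-2 checkpoint name of this size and is an assumption here, since the demo loads its own copy from `weights/`:

```python
import torch

# Assumes protein_encoder.py is importable and the checkpoint can be fetched.
encoder = ProteinEncoder("facebook/esm2_t12_35M_UR50D", out_dim=1024)
encoder.eval()

with torch.no_grad():
    reprs = encoder.get_repr(["MSATAEQNARNPK", "MITLDWEKANGL"], batch_size=2)

print(reprs.shape)           # torch.Size([2, 1024])
print(reprs.norm(dim=-1))    # ~1.0 per row: representations are unit-norm
```

Note that `self.out` is freshly initialized in this sketch; meaningful scores require the trained ProTrek weights.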
model/ProtTrek/protrek_trimodal_model.py ADDED
@@ -0,0 +1,853 @@
+import torch
+import torch.distributed as dist
+import torchmetrics
+import json
+import math
+import numpy as np
+import os
+import copy
+import faiss
+import time
+import pandas as pd
+import random
+
+from tqdm import tqdm
+from .protein_encoder import ProteinEncoder
+from .structure_encoder import StructureEncoder
+from .text_encoder import TextEncoder
+from ..abstract_model import AbstractModel
+from ..model_interface import register_model
+from utils.mpr import MultipleProcessRunnerSimplifier
+from torch.nn.functional import normalize, cross_entropy
+from utils.constants import residue_level, sequence_level
+from sklearn.metrics import roc_auc_score
+
+
+def multilabel_cross_entropy(logits, labels):
+    """
+    Compute cross entropy loss for multilabel classification. See "https://arxiv.org/pdf/2208.02955.pdf"
+    Args:
+        logits: [num_samples, num_classes]
+        labels: [num_samples, num_classes]
+    """
+
+    loss = 0
+    for pred, label in zip(logits, labels):
+        pos_logits = pred[label == 1]
+        neg_logits = pred[label == 0]
+
+        diff = neg_logits.unsqueeze(-1) - pos_logits
+        loss += torch.log(1 + torch.exp(diff).sum())
+
+    return loss / len(logits)
+
+    # pred = (1 - 2 * labels) * logits
+    # pred_neg = pred - labels * 1e12
+    # pred_pos = pred - (1 - labels) * 1e12
+    #
+    # zeros = torch.zeros_like(logits[..., :1], dtype=logits.dtype)
+    # pred_neg = torch.cat([pred_neg, zeros], dim=-1)
+    # pred_pos = torch.cat([pred_pos, zeros], dim=-1)
+    #
+    # neg_loss = torch.logsumexp(pred_neg, dim=-1)
+    # pos_loss = torch.logsumexp(pred_pos, dim=-1)
+    #
+    # return (neg_loss + pos_loss).mean()
+
+
+@register_model
+class ProTrekTrimodalModel(AbstractModel):
+    def __init__(self,
+                 protein_config: str,
+                 text_config: str,
+                 structure_config: str = None,
+                 repr_dim: int = 1024,
+                 temperature: float = 0.07,
+                 load_protein_pretrained: bool = True,
+                 load_text_pretrained: bool = True,
+                 use_mlm_loss: bool = False,
+                 use_zlpr_loss: bool = False,
+                 use_saprot: bool = False,
+                 gradient_checkpointing: bool = False,
+                 **kwargs):
+        """
+        Args:
+            protein_config: Path to the config file for protein sequence encoder
+
+            text_config: Path to the config file for text encoder
+
+            structure_config: Path to the config file for structure encoder
+
+            repr_dim: Output dimension of the protein and text representation
+
+            temperature: Temperature for softmax
+
+            load_protein_pretrained: Whether to load pretrained weights for protein encoder
+
+            load_text_pretrained: Whether to load pretrained weights for text encoder
+
+            use_mlm_loss: Whether to use masked language modeling loss
+
+            use_zlpr_loss: Whether to use zlpr loss. See "https://arxiv.org/pdf/2208.02955.pdf"
+
+            use_saprot: Whether to use SaProt as protein encoder
+
+            gradient_checkpointing: Whether to use gradient checkpointing for protein encoder
+        """
+        self.protein_config = protein_config
+        self.structure_config = structure_config
+        self.text_config = text_config
+        self.repr_dim = repr_dim
+        self.temperature = temperature
+        self.load_protein_pretrained = load_protein_pretrained
+        self.load_text_pretrained = load_text_pretrained
+        self.use_mlm_loss = use_mlm_loss
+        self.use_zlpr_loss = use_zlpr_loss
+        self.use_saprot = use_saprot
+        self.gradient_checkpointing = gradient_checkpointing
+        super().__init__(**kwargs)
+
+    def initialize_metrics(self, stage: str) -> dict:
+        return_dict = {
+            f"{stage}_protein_text_acc": torchmetrics.Accuracy(),
+            f"{stage}_text_protein_acc": torchmetrics.Accuracy(),
+        }
+
+        if self.use_mlm_loss:
+            return_dict[f"{stage}_protein_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
+            if self.structure_config is not None:
+                return_dict[f"{stage}_structure_mask_acc"] = torchmetrics.Accuracy(ignore_index=-1)
+
+        if self.structure_config is not None:
+            return_dict[f"{stage}_structure_protein_acc"] = torchmetrics.Accuracy()
+            return_dict[f"{stage}_structure_text_acc"] = torchmetrics.Accuracy()
+            return_dict[f"{stage}_text_structure_acc"] = torchmetrics.Accuracy()
+            return_dict[f"{stage}_protein_structure_acc"] = torchmetrics.Accuracy()
+
+        return return_dict
+
+    def initialize_model(self):
+        # Initialize encoders
+        self.protein_encoder = ProteinEncoder(self.protein_config,
+                                              self.repr_dim,
+                                              self.load_protein_pretrained,
+                                              self.gradient_checkpointing)
+
+        self.text_encoder = TextEncoder(self.text_config,
+                                        self.repr_dim,
+                                        self.load_text_pretrained,
+                                        self.gradient_checkpointing)
+
+        # Learnable temperature
+        self.temperature = torch.nn.Parameter(torch.tensor(self.temperature))
+
+        # self.model is used for saving and loading
+        self.model = torch.nn.ParameterList([self.temperature,
+                                             self.protein_encoder,
+                                             self.text_encoder])
+
+        # If the structure encoder is specified
+        if self.structure_config is not None:
+            self.structure_encoder = StructureEncoder(self.structure_config, self.repr_dim)
+            self.model.append(self.structure_encoder)
+
+    def get_text_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
+        return self.text_encoder.get_repr(texts, batch_size, verbose)
+
+    def get_structure_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
+        return self.structure_encoder.get_repr(proteins, batch_size, verbose)
+
+    def get_protein_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
+        return self.protein_encoder.get_repr(proteins, batch_size, verbose)
+
+    def forward(self, protein_inputs: dict, text_inputs: dict, structure_inputs: dict = None):
+        """
+        Args:
+            protein_inputs: A dictionary for protein encoder
+            structure_inputs: A dictionary for structure encoder
+            text_inputs: A dictionary for text encoder
+        """
+        protein_repr, protein_mask_logits = self.protein_encoder(protein_inputs, self.use_mlm_loss)
+        text_repr = self.text_encoder(text_inputs)
+
+        outputs = [text_repr, protein_repr, protein_mask_logits]
+
+        if self.structure_config is not None:
+            structure_repr, structure_mask_logits = self.structure_encoder(structure_inputs, self.use_mlm_loss)
+            outputs += [structure_repr, structure_mask_logits]
+
+        return outputs
+
+    def loss_func(self, stage: str, outputs, labels):
+        if self.structure_config is not None:
+            text_repr, protein_repr, protein_mask_logits, structure_repr, structure_mask_logits = outputs
+        else:
+            text_repr, protein_repr, protein_mask_logits = outputs
+
+        device = text_repr.device
+
+        text_repr = normalize(text_repr, dim=-1)
+        protein_repr = normalize(protein_repr, dim=-1)
+
+        # Gather representations from all GPUs
+        all_protein_repr = self.all_gather(protein_repr).view(-1, protein_repr.shape[-1]).detach()
+        all_text_repr = self.all_gather(text_repr).view(-1, text_repr.shape[-1]).detach()
+
+        if self.structure_config is not None:
+            structure_repr = normalize(structure_repr, dim=-1)
+            all_structure_repr = self.all_gather(structure_repr).view(-1, structure_repr.shape[-1]).detach()
+
+        # text_idx = labels["text_idx"]
+        # text_candidates = labels["text_candidates"]
+        #
+        # # Gather all text ids
+        # text_inds = self.all_gather(text_idx).flatten()
+        # # Create text classification labels
+        # text_labels = torch.zeros(len(text_candidates), len(text_inds), dtype=int).to(device)
+        # for i, candidate in enumerate(text_candidates):
+        #     for j, idx in enumerate(text_inds):
+        #         if idx.item() in candidate:
+        #             text_labels[i, j] = 1
+        #
+        # # Gather text labels from all GPUs
+        # text_labels = self.all_gather(text_labels).view(-1, text_labels.shape[-1])
+        #
+        # # Protein classification labels are the transpose of text labels
+        # protein_labels = text_labels.T
+
+        # Batch size
+        rank = dist.get_rank()
+        bs = text_repr.shape[0]
+
+        # Get current labels
+        # protein_labels = protein_labels[rank * bs: rank * bs + bs]
+        # text_labels = text_labels[rank * bs: rank * bs + bs]
+
+        # Create classification labels between structure and sequence
+        bs_labels = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(device)
+
+        if self.structure_config is not None:
+            pairs = {
+                "protein": ["structure", "text"],
+                "structure": ["protein", "text"],
+                "text": ["protein", "structure"]
+            }
+        else:
+            pairs = {
+                "protein": ["text"],
+                "text": ["protein"]
+            }
+
+        loss_list = []
+        for k, values in pairs.items():
+            for v in values:
+                # Only calculate the similarity for the current batch
+                sim = torch.matmul(eval(f"{k}_repr"), eval(f"all_{v}_repr").T).div(self.temperature)
+
+                # if k == "text":
+                #     if self.use_zlpr_loss:
+                #         loss = multilabel_cross_entropy(sim, protein_labels)
+                #     else:
+                #         loss = cross_entropy(sim, bs_labels)
+                #
+                #     pred = []
+                #     for s, l in zip(sim, protein_labels):
+                #         n_label = l.sum()
+                #         topk = torch.topk(s, k=n_label).indices
+                #         if l[topk].sum() == n_label:
+                #             pred.append(1)
+                #         else:
+                #             pred.append(0)
+                #
+                #     pred = torch.tensor(pred).to(device)
+                #     label = torch.ones_like(pred)
+                #     self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(pred.detach(), label)
+                #     # if v == "protein":
+                #     #     acc = self.metrics[stage][f"{stage}_{k}_{v}_acc"].compute()
+                #     #     print(f"{stage}_{k}_{v}_acc: {acc:.4f}")
+                #
+                # elif v == "text":
+                #     if self.use_zlpr_loss:
+                #         loss = multilabel_cross_entropy(sim, text_labels)
+                #     else:
+                #         loss = cross_entropy(sim, bs_labels)
+                #
+                #     pred = []
+                #     for s, l in zip(sim, text_labels):
+                #         n_label = l.sum()
+                #         topk = torch.topk(s, k=n_label).indices
+                #         if l[topk].sum() == n_label:
+                #             pred.append(1)
+                #         else:
+                #             pred.append(0)
+                #
+                #     pred = torch.tensor(pred).to(device)
+                #     label = torch.ones_like(pred)
+                #     # if k == "protein":
+                #     #     acc = pred.sum() / len(pred)
+                #     #     print(f"{stage}_{k}_{v}_acc: {acc:.4f}")
+                #     self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(pred.detach(), label)
+                #
+                # else:
+                #     loss = cross_entropy(sim, bs_labels)
+                #     self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
+
+                loss = cross_entropy(sim, bs_labels)
+                self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(sim.detach(), bs_labels)
+                loss_list.append(loss)
+
+        # Masked language modeling loss
+        if self.use_mlm_loss:
+            k_label = [("protein", labels["seq_labels"])]
+            if self.structure_config is not None:
+                k_label.append(("structure", labels["struc_labels"]))
+
+            for k, label in k_label:
+                logits = eval(f"{k}_mask_logits")
+                # merge the first and second dimension of logits
+                logits = logits.view(-1, logits.shape[-1])
+                label = label.flatten().to(device)
+                mlm_loss = cross_entropy(logits, label, ignore_index=-1)
+                loss_list.append(mlm_loss)
+                self.metrics[stage][f"{stage}_{k}_mask_acc"].update(logits.detach(), label)
+
+        loss = sum(loss_list) / len(loss_list)
+
+        if stage == "train":
+            log_dict = self.get_log_dict("train")
+            log_dict["train_loss"] = loss
+            self.log_info(log_dict)
+
+            # Reset train metrics
+            self.reset_metrics("train")
+
+        return loss
+
+    def _get_protein_indices(self):
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        if self.use_saprot:
+            proteins = []
+            for sub_dict in self.uniprot2label.values():
+                aa_seq = sub_dict["seq"]
+                foldseek_seq = sub_dict["foldseek"]
+                assert len(aa_seq) == len(foldseek_seq)
+                seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
+                proteins.append(seq)
+
+        else:
+            proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]
+
+        span = math.ceil(len(proteins) / world_size)
+        sub_proteins = proteins[rank * span: (rank + 1) * span]
+
+        # Display the progress bar on the rank 0 process
+        verbose = self.trainer.local_rank == 0
+        # Get protein representations
+        sub_protein_repr = self.protein_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
+        protein_repr = self.padded_gather(sub_protein_repr)
+
+        # Construct faiss index
+        d = protein_repr.shape[-1]
+        protein_indices = faiss.IndexFlatIP(d)
+        protein_indices.add(protein_repr.cpu().numpy())
+        return protein_indices
+
+    def _get_structure_indices(self):
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        proteins = [sub_dict["foldseek"] for sub_dict in self.uniprot2label.values()]
+        span = math.ceil(len(proteins) / world_size)
+        sub_proteins = proteins[rank * span: (rank + 1) * span]
+
+        # Display the progress bar on the rank 0 process
+        verbose = self.trainer.local_rank == 0
+        # Get protein representations
+        sub_protein_repr = self.structure_encoder.get_repr(sub_proteins, batch_size=1, verbose=verbose)
+        protein_repr = self.padded_gather(sub_protein_repr)
+
+        # Construct faiss index
+        d = protein_repr.shape[-1]
+        structure_indices = faiss.IndexFlatIP(d)
+        structure_indices.add(protein_repr.cpu().numpy())
+        return structure_indices
+
+    def _get_text_indices(self):
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        # Display the progress bar on the rank 0 process
+        verbose = self.trainer.local_rank == 0
+        if verbose:
+            iterator = tqdm(self.label2text.keys(), desc="Get text representations")
+        else:
+            iterator = self.label2text.keys()
+
+        text_embeddings = {}
+        for subsection in iterator:
+            if subsection == "Total":
+                continue
+
+            texts = []
+            for text_list in self.label2text[subsection].values():
+                # Only use the first text for efficiency
+                texts.append(text_list[0:1])
+
+            span = math.ceil(len(texts) / world_size)
+            texts = texts[rank * span: (rank + 1) * span]
+            embeddings = []
+            for text_list in texts:
+                text_repr = self.text_encoder.get_repr(text_list)
+                mean_repr = text_repr.mean(dim=0, keepdim=True)
+                norm_repr = torch.nn.functional.normalize(mean_repr, dim=-1)
+                embeddings.append(norm_repr)
+
+            if len(embeddings) > 0:
+                embeddings = torch.cat(embeddings, dim=0)
+            else:
+                embeddings = torch.zeros(0, self.repr_dim, dtype=self.dtype, device=self.device)
+
+            text_repr = self.padded_gather(embeddings)
+            text_embeddings[subsection] = text_repr
+
+        # Aggregate text embeddings for global retrieval
+        total_embeddings = []
+        for idx in self.label2text["Total"].values():
+            subsection, i = idx.split("|")
+            total_embeddings.append(text_embeddings[subsection][int(i)])
+
+        text_embeddings["Total"] = torch.stack(total_embeddings)
+
+        # Construct faiss index
+        text_indices = {}
+        for subsection, text_repr in text_embeddings.items():
+            d = text_repr.shape[-1]
+            text_indices[subsection] = faiss.IndexFlatIP(d)
+            text_indices[subsection].add(text_repr.cpu().numpy())
+
+        return text_indices
+
+    def _protein2text(self, modality: str, protein_indices, text_indices: dict):
+        def do(process_id, idx, row, writer):
+            subsection, uniprot_id, prob_idx, label = row
+
+            # Retrieve ranking results
+            p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
+            text_inds = text_indices[subsection]
+            sim_scores, rank_inds = text_inds.search(p_embedding, text_inds.ntotal)
+            sim_scores, rank_inds = sim_scores[0], rank_inds[0]
+
+            # Calculate Average Precision (AP)
+            ranks = []
+            label = set(label)
+            for i, rk in enumerate(rank_inds):
+                # Find the rank of this label in all labels
+                if rk in label:
+                    ranks.append(i + 1)
+
+            ranks = np.array(ranks)
+            ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])
+
+            # Calculate Mean Reciprocal Rank (MRR)
+            best_rank = ranks[0]
+            mrr = 1 / best_rank
+
+            # Calculate the AUC
+            true_labels = np.zeros_like(sim_scores)
+            true_labels[ranks - 1] = 1
+            if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
+                auc = 0
+            else:
+                auc = roc_auc_score(true_labels, sim_scores)
+
+            output = json.dumps([ap, mrr, auc])
+            writer.write(output + "\n")
+
+        inputs = []
+        swissprot_subsections = set()
+        for subsection in text_indices.keys():
+            for i, (uniprot_id, labels) in enumerate(self.uniprot2label.items()):
+                if uniprot_id in self.swissprot_ids:
+                    if subsection in labels:
+                        swissprot_subsections.add(subsection)
+                        label = labels[subsection]
+                        inputs.append((subsection, uniprot_id, i, label))
+
+        # Randomly shuffle the inputs
+        random.seed(20000812)
+        random.shuffle(inputs)
+
+        # Split inputs into chunks for parallel processing
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        span = math.ceil(len(inputs) / world_size)
+        sub_inputs = inputs[rank * span: (rank + 1) * span]
+
+        # Display the progress bar on the rank 0 process
+        verbose = self.trainer.local_rank == 0
+        if verbose:
+            print("Evaluating on each subsection...")
+        tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
+        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
+                                              return_results=True)
+        outputs = mpr.run()
+        os.remove(tmp_path)
+
+        # Aggregate results
+        tensor_outputs = []
+        for output in outputs:
+            ap, mrr, auc = json.loads(output)
+            tensor_outputs.append([float(ap), float(mrr), float(auc)])
+
+        tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
+        tensor_outputs = self.padded_gather(tensor_outputs)
+
+        # Record results
+        avg_results = {}
+        for subsection in swissprot_subsections:
+            avg_results[subsection] = {"map": [],
+                                       "mrr": [],
+                                       "auc": []}
+
+        for input, output in zip(inputs, tensor_outputs):
+            ap, mrr, auc = output
+            subsection, _, _, _ = input
+
+            avg_results[subsection]["map"].append(ap.cpu().item())
+            avg_results[subsection]["mrr"].append(mrr.cpu().item())
+            avg_results[subsection]["auc"].append(auc.cpu().item())
+
+        results = {
+            f"{modality}2Text_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
+            f"{modality}2Text_Total_map": np.mean(avg_results["Total"]["map"]),
+            f"{modality}2Text_Total_auc": np.mean(avg_results["Total"]["auc"]),
+        }
+
+        # Average the precision and recall for each level
+        for level, labels in [("residue-level", residue_level),
+                              ("sequence-level", sequence_level),
+                              ("all", residue_level | sequence_level)]:
+
+            mrrs = []
+            maps = []
+            aucs = []
+            for subsection in labels:
+                if subsection in avg_results:
+                    mrrs.append(np.mean(avg_results[subsection]["mrr"]))
+                    maps.append(np.mean(avg_results[subsection]["map"]))
+                    aucs.append(np.mean(avg_results[subsection]["auc"]))
+
+            results[f"{modality}2Text_{level}_mrr"] = np.mean(mrrs)
+            results[f"{modality}2Text_{level}_map"] = np.mean(maps)
+            results[f"{modality}2Text_{level}_auc"] = np.mean(aucs)
+
+        return results
+
+    def _text2protein(self, modality: str, protein_indices, text_indices: dict):
+        def do(process_id, idx, row, writer):
+            subsection, text_id, label = row
+
+            # Retrieve ranking results
+            t_embedding = text_indices[subsection].reconstruct(text_id).reshape(1, -1)
+            sim_scores, rank_inds = protein_indices.search(t_embedding, protein_indices.ntotal)
+            sim_scores, rank_inds = sim_scores[0], rank_inds[0]
+
+            # Calculate Average Precision (AP)
+            ranks = []
+            label = set(label)
+            for i, rk in enumerate(rank_inds):
+                # Find the rank of this label in all labels
+                if rk in label:
+                    ranks.append(i + 1)
+
+            ranks = np.array(ranks)
+            ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])
+
+            # Calculate Mean Reciprocal Rank (MRR)
+            best_rank = ranks[0]
+            mrr = 1 / best_rank
+
+            # Calculate the AUC
+            true_labels = np.zeros_like(sim_scores)
+            true_labels[ranks - 1] = 1
+            if true_labels.sum() == 0 or true_labels.sum() == true_labels.shape[0]:
+                auc = 0
+            else:
+                auc = roc_auc_score(true_labels, sim_scores)
+
+            output = json.dumps([ap, mrr, auc])
+            writer.write(output + "\n")
+
+        text2label = {}
+        swissprot_subsections = set()
+        for i, (uniprot_id, subsections) in enumerate(self.uniprot2label.items()):
+            # Only evaluate the texts in Swiss-Prot
+            if uniprot_id not in self.swissprot_ids:
+                continue
+
+            for subsection, text_ids in subsections.items():
+                if subsection == "seq" or subsection == "foldseek":
+                    continue
+
+                swissprot_subsections.add(subsection)
+                if subsection not in text2label:
+                    text2label[subsection] = {}
+
+                for text_id in text_ids:
+                    text2label[subsection][text_id] = text2label[subsection].get(text_id, []) + [i]
+
+        inputs = []
+        for subsection in swissprot_subsections:
+            for i, (text_id, label) in enumerate(text2label[subsection].items()):
+                inputs.append((subsection, text_id, label))
+
+        # Randomly shuffle the inputs
+        random.seed(20000812)
+        random.shuffle(inputs)
+
+        # Split inputs into chunks for parallel processing
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        span = math.ceil(len(inputs) / world_size)
+        sub_inputs = inputs[rank * span: (rank + 1) * span]
+
+        # Display the progress bar on the rank 0 process
+        verbose = self.trainer.local_rank == 0
+        if verbose:
+            print("Evaluating on each text...")
+
+        # Add time stamp to the temporary file name to avoid conflicts
+        tmp_path = f"/sujin/PycharmProjects/Pretraining/{time.time()}_{rank}.tsv"
+        mpr = MultipleProcessRunnerSimplifier(sub_inputs, do, save_path=tmp_path, n_process=8, verbose=verbose,
+                                              return_results=True)
+        outputs = mpr.run()
+        os.remove(tmp_path)
+
+        # Aggregate results
+        tensor_outputs = []
+        for output in outputs:
+            ap, mrr, auc = json.loads(output)
+            tensor_outputs.append([float(ap), float(mrr), float(auc)])
+
+        tensor_outputs = torch.tensor(tensor_outputs, dtype=torch.float32, device=self.device)
+        tensor_outputs = self.padded_gather(tensor_outputs)
+
+        # Record results
+        avg_results = {}
+        for subsection in swissprot_subsections:
+            avg_results[subsection] = {"map": [],
+                                       "mrr": [],
+                                       "auc": []}
+
+        for input, output in zip(inputs, tensor_outputs):
+            ap, mrr, auc = output
+            subsection, _, _ = input
+
+            avg_results[subsection]["map"].append(ap.cpu().item())
+            avg_results[subsection]["mrr"].append(mrr.cpu().item())
+            avg_results[subsection]["auc"].append(auc.cpu().item())
+
+        results = {
+            f"Text2{modality}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
+            f"Text2{modality}_Total_map": np.mean(avg_results["Total"]["map"]),
+            f"Text2{modality}_Total_auc": np.mean(avg_results["Total"]["auc"]),
+        }
+
+        # Average the precision and recall for each level
+        for level, labels in [("residue-level", residue_level),
+                              ("sequence-level", sequence_level),
+                              ("all", residue_level | sequence_level)]:
+
+            mrrs = []
+            maps = []
+            aucs = []
+            for subsection in labels:
+                if subsection in avg_results:
+                    mrrs.append(np.mean(avg_results[subsection]["mrr"]))
+                    maps.append(np.mean(avg_results[subsection]["map"]))
+                    aucs.append(np.mean(avg_results[subsection]["auc"]))
+
+            results[f"Text2{modality}_{level}_mrr"] = np.mean(mrrs)
+            results[f"Text2{modality}_{level}_map"] = np.mean(maps)
+            results[f"Text2{modality}_{level}_auc"] = np.mean(aucs)
+
+        return results
+
+    def retrieval_eval(self) -> dict:
+        # Get protein representations
+        protein_indices = self._get_protein_indices()
+
+        # Get structure representations
+        # if self.structure_config is not None:
+        #     structure_embeddings = self._get_structure_embeddings()
+
+        # Get text representations
+        text_indices = self._get_text_indices()
+
+        # Retrieve texts for each protein
+        results = {}
+        results.update(self._protein2text("Sequence", protein_indices, text_indices))
+        # if self.structure_config is not None:
+        #     results.update(self._protein2text("Structure", structure_embeddings, text_embeddings))
+        #     results.update(self._text2protein("Structure", structure_embeddings, text_embeddings))
+
+        # Retrieve proteins for each text
+        results.update(self._text2protein("Sequence", protein_indices, text_indices))
+
+        return results
+
+    def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
+        while True:
+            masked_tokens = copy.copy(tokens)
+            labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)
+            vocab = [k for k in tokenizer.get_vocab().keys()]
+
+            for i in range(len(tokens)):
+                token = tokens[i]
+
+                prob = random.random()
+                if prob < mask_ratio:
+                    prob /= mask_ratio
+                    labels[i + 1] = tokenizer.convert_tokens_to_ids(token)
+
+                    if prob < 0.8:
+                        # 80% chance to change to mask token
+                        if self.use_saprot:
+                            token = "#" + token[-1]
+                        else:
+                            token = tokenizer.mask_token
+                    elif prob < 0.9:
+                        # 10% chance to change to random token
+                        token = random.choice(vocab)
+                    else:
+                        # 10% chance to keep current token
+                        pass
+
+                    masked_tokens[i] = token
+
+            # Check if there is at least one masked token
+            if (labels != -1).any():
+                return masked_tokens, labels
+
+    def mlm_eval(self) -> float:
+        world_size = dist.get_world_size()
+        rank = dist.get_rank()
+
+        if self.use_saprot:
+            proteins = []
+            for sub_dict in self.uniprot2label.values():
+                aa_seq = sub_dict["seq"]
+                foldseek_seq = sub_dict["foldseek"]
+                assert len(aa_seq) == len(foldseek_seq)
+                seq = "".join([a + b for a, b in zip(aa_seq, foldseek_seq)])
+                proteins.append(seq)
+
+        else:
+            proteins = [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]
+
+        span = math.ceil(len(proteins) / world_size)
+        sub_proteins = proteins[rank * span: (rank + 1) * span]
+
+        # Display the progress bar on the rank 0 process
+        if self.trainer.local_rank == 0:
+            iterator = tqdm(sub_proteins, desc="Computing mlm...")
+        else:
+            iterator = sub_proteins
+
+        total = torch.tensor([0], dtype=torch.long, device=self.device)
+        correct = torch.tensor([0], dtype=torch.long, device=self.device)
+        for seq in iterator:
+            tokens = self.protein_encoder.tokenizer.tokenize(seq)
+            masked_tokens, labels = self._apply_bert_mask(tokens, self.protein_encoder.tokenizer, 0.15)
+            seq = " ".join(masked_tokens)
+
+            inputs = self.protein_encoder.tokenizer(seq, return_tensors="pt")
+            inputs = {k: v.to(self.device) for k, v in inputs.items()}
+            _, logits = self.protein_encoder(inputs, get_mask_logits=True)
+
+            logits = logits.squeeze(0)
+            labels = labels.to(self.device)
+
+            selector = labels != -1
+            preds = logits.argmax(dim=-1)[selector]
+            labels = labels[selector]
+
+            total += len(preds)
+            correct += (preds == labels).sum()
+
+        # Gather all results
+        total = self.padded_gather(total).sum()
+        correct = self.padded_gather(correct).sum()
+
+        acc = correct / total
+        return acc.cpu().item()
+
+    def _load_eval_data(self, stage):
+        # Load the data
+        lmdb_dir = eval(f"self.trainer.datamodule.{stage}_lmdb")
+        uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
+        label2text_path = os.path.join(lmdb_dir, "label2text.json")
+        swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")
+
+        self.uniprot2label = json.load(open(uniprot2label_path, "r"))
+        self.label2text = json.load(open(label2text_path, "r"))
+        self.swissprot_ids = set(pd.read_csv(swissprot_id_path, sep="\t", header=None).values.flatten().tolist())
+        self.k = 3
+
+    def on_test_start(self):
+        self._load_eval_data("test")
+
+        log_dict = self.retrieval_eval()
+        log_dict = {"test_" + k: v for k, v in log_dict.items()}
+        if self.use_mlm_loss:
+            log_dict["test_mask_acc"] = self.mlm_eval()
+        self.log_info(log_dict)
+        print(log_dict)
+
+    def on_validation_start(self):
+        # Clear the cache
+        torch.cuda.empty_cache()
+
+        self._load_eval_data("valid")
+
+        log_dict = self.retrieval_eval()
+        log_dict = {"valid_" + k: v for k, v in log_dict.items()}
+        if self.use_mlm_loss:
+            log_dict["valid_mask_acc"] = self.mlm_eval()
+        self.log_info(log_dict)
+
+        self.check_save_condition(self.step, mode="max")
+
+    def test_step(self, batch, batch_idx):
+        return
+
+    def validation_step(self, batch, batch_idx):
+        return
+
+    def on_train_epoch_end(self):
+        super().on_train_epoch_end()
+        # Re-sample the subset of the training data
+        if self.trainer.datamodule.train_dataset.fixed_dataset_num is not None:
+            self.trainer.datamodule.train_dataset.sample_subset()
+
+    # def test_epoch_end(self, outputs):
+    #     log_dict = self.get_log_dict("test")
+    #     log_dict["test_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
+    #
+    #     print(log_dict)
+    #     self.log_info(log_dict)
+    #
+    #     self.reset_metrics("test")
+    #
+    # def validation_epoch_end(self, outputs):
+    #     log_dict = self.get_log_dict("valid")
+    #     log_dict["valid_loss"] = torch.cat(self.all_gather(outputs), dim=-1).mean()
+    #
+    #     self.log_info(log_dict)
+    #     self.reset_metrics("valid")
+    #     self.check_save_condition(log_dict["valid_loss"], mode="min")
model/ProtTrek/structure_encoder.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+
+ from tqdm import tqdm
+ from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
+ from torch.nn.functional import normalize
+
+
+ class StructureEncoder(torch.nn.Module):
+     def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
+         """
+         Args:
+             config_path: Path to the config file
+
+             out_dim: Output dimension of the structure representation
+
+             gradient_checkpointing: Whether to use gradient checkpointing
+         """
+         super().__init__()
+         config = EsmConfig.from_pretrained(config_path)
+         self.model = EsmForMaskedLM(config)
+         self.out = torch.nn.Linear(config.hidden_size, out_dim)
+
+         # Set gradient checkpointing
+         self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing
+
+         # Remove contact head
+         self.model.esm.contact_head = None
+
+         # Remove position embedding if the embedding type is ``rotary``
+         if config.position_embedding_type == "rotary":
+             self.model.esm.embeddings.position_embeddings = None
+
+         self.tokenizer = EsmTokenizer.from_pretrained(config_path)
+
+     def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
+         """
+         Compute protein structure representations for the given proteins
+         Args:
+             proteins: A list of protein structural sequences
+             batch_size: Batch size for inference
+             verbose: Whether to print progress
+         """
+         device = next(self.parameters()).device
+
+         protein_repr = []
+         if verbose:
+             iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
+         else:
+             iterator = range(0, len(proteins), batch_size)
+
+         for i in iterator:
+             protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
+                                                               return_tensors="pt",
+                                                               padding=True)
+             protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
+             output, _ = self.forward(protein_inputs)
+
+             protein_repr.append(output)
+
+         protein_repr = torch.cat(protein_repr, dim=0)
+         return normalize(protein_repr, dim=-1)
+
+     def forward(self, inputs: dict, get_mask_logits: bool = False):
+         """
+         Encode protein structure into protein representation
+         Args:
+             inputs: A dictionary containing the following keys:
+                 - input_ids: [batch, seq_len]
+                 - attention_mask: [batch, seq_len]
+             get_mask_logits: Whether to return the logits for masked tokens
+
+         Returns:
+             protein_repr: [batch, protein_repr_dim]
+             mask_logits : [batch, seq_len, vocab_size]
+         """
+         last_hidden_state = self.model.esm(**inputs).last_hidden_state
+         reprs = last_hidden_state[:, 0, :]
+         reprs = self.out(reprs)
+
+         # Get logits for masked tokens
+         if get_mask_logits:
+             mask_logits = self.model.lm_head(last_hidden_state)
+         else:
+             mask_logits = None
+
+         return reprs, mask_logits
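A minimal usage sketch for StructureEncoder. The config directory name and the 3Di strings below are illustrative assumptions, not part of this commit; the inputs are foldseek structural sequences such as those produced by utils.foldseek_util.get_struc_seq:

    import torch
    from model.ProtTrek.structure_encoder import StructureEncoder

    # config_path is a hypothetical local directory holding an ESM-style config + tokenizer
    encoder = StructureEncoder(config_path="weights/structure_encoder", out_dim=1024)
    encoder.eval()
    with torch.no_grad():
        # foldseek 3Di structural sequences (lower-case vocabulary)
        reprs = encoder.get_repr(["dpvnvvvd", "pynwrqhg"], batch_size=2)
    print(reprs.shape)  # torch.Size([2, 1024]); rows are L2-normalized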
model/ProtTrek/text_encoder.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+
+ from tqdm import tqdm
+ from torch.nn.functional import normalize
+ from transformers import BertConfig, BertModel, BertTokenizer
+
+
+ class TextEncoder(torch.nn.Module):
+     def __init__(self,
+                  config_path: str,
+                  out_dim: int,
+                  load_pretrained: bool = True,
+                  gradient_checkpointing: bool = False):
+         """
+         Args:
+             config_path: Path to the config file
+
+             out_dim: Output dimension of the text representation
+
+             load_pretrained: Whether to load pretrained weights
+
+             gradient_checkpointing: Whether to enable gradient checkpointing
+         """
+         super().__init__()
+         config = BertConfig.from_pretrained(config_path)
+         if load_pretrained:
+             self.model = BertModel.from_pretrained(config_path, add_pooling_layer=False)
+         else:
+             self.model = BertModel(config, add_pooling_layer=False)
+         self.out = torch.nn.Linear(config.hidden_size, out_dim)
+
+         # Set gradient checkpointing
+         self.model.encoder.gradient_checkpointing = gradient_checkpointing
+
+         self.tokenizer = BertTokenizer.from_pretrained(config_path)
+
+     def get_repr(self, texts: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
+         """
+         Compute text representations for the given texts
+         Args:
+             texts: A list of strings
+             batch_size: Batch size for inference
+             verbose: Whether to print progress
+         """
+         device = next(self.parameters()).device
+
+         text_repr = []
+         if verbose:
+             iterator = tqdm(range(0, len(texts), batch_size), desc="Computing text embeddings")
+         else:
+             iterator = range(0, len(texts), batch_size)
+
+         for i in iterator:
+             text_inputs = self.tokenizer.batch_encode_plus(texts[i: i + batch_size],
+                                                            return_tensors="pt",
+                                                            truncation=True,
+                                                            max_length=512,
+                                                            padding=True)
+             text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
+             output = self(text_inputs)
+
+             text_repr.append(output)
+
+         text_repr = torch.cat(text_repr, dim=0)
+         return normalize(text_repr, dim=-1)
+
+     def forward(self, inputs: dict):
+         """
+         Encode text into text representation
+         Args:
+             inputs: A dictionary containing the following keys:
+                 - input_ids: [batch, seq_len]
+                 - attention_mask: [batch, seq_len]
+                 - token_type_ids: [batch, seq_len]
+
+         Returns:
+             text_repr: [batch, text_repr_dim]
+         """
+         reprs = self.model(**inputs).last_hidden_state[:, 0, :]
+         reprs = self.out(reprs)
+         return reprs
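A matching sketch for TextEncoder. Since BertConfig/BertModel/BertTokenizer all accept hub names as well as local paths, any BERT-style checkpoint works here; the name below is an assumption for illustration, not the checkpoint this repo ships:

    import torch
    from model.ProtTrek.text_encoder import TextEncoder

    encoder = TextEncoder(config_path="bert-base-uncased", out_dim=1024, load_pretrained=True)
    encoder.eval()
    with torch.no_grad():
        reprs = encoder.get_repr(["Proteins with zinc bindings."])
    print(reprs.shape)  # torch.Size([1, 1024]); normalized, so dot products are cosine similarities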
model/abstract_model.py ADDED
@@ -0,0 +1,401 @@
+ import torch
+ import abc
+ import os
+ import copy
+
+ import pytorch_lightning as pl
+ from utils.lr_scheduler import *
+ from torch import distributed as dist
+
+
+ class AbstractModel(pl.LightningModule):
+     def __init__(self,
+                  lr_scheduler_kwargs: dict = None,
+                  optimizer_kwargs: dict = None,
+                  save_path: str = None,
+                  from_checkpoint: str = None,
+                  load_prev_scheduler: bool = False,
+                  save_weights_only: bool = True):
+         """
+         Args:
+             lr_scheduler_kwargs: Kwargs for the learning rate scheduler
+
+             optimizer_kwargs: Kwargs for the optimizer
+
+             save_path: Path to save the trained model
+
+             from_checkpoint: Path to a checkpoint to load the model from
+
+             load_prev_scheduler: Whether to load the previous scheduler state from the checkpoint
+
+             save_weights_only: Whether to save only the weights or also the optimizer and lr_scheduler
+         """
+         super().__init__()
+         self.initialize_model()
+
+         self.metrics = {}
+         for stage in ["train", "valid", "test"]:
+             stage_metrics = self.initialize_metrics(stage)
+             # Register metrics as attributes
+             for metric_name, metric in stage_metrics.items():
+                 setattr(self, metric_name, metric)
+
+             self.metrics[stage] = stage_metrics
+
+         if lr_scheduler_kwargs is None:
+             # Default lr_scheduler
+             self.lr_scheduler_kwargs = {
+                 "class": "ConstantLRScheduler",
+                 "init_lr": 0,
+             }
+             print("No lr_scheduler_kwargs provided. The default learning rate is 0.")
+
+         else:
+             self.lr_scheduler_kwargs = lr_scheduler_kwargs
+
+         if optimizer_kwargs is None:
+             # Default optimizer
+             self.optimizer_kwargs = {
+                 "class": "AdamW",
+                 "betas": (0.9, 0.98),
+                 "weight_decay": 0.01,
+             }
+             print("No optimizer_kwargs provided. The default optimizer is AdamW.")
+         else:
+             self.optimizer_kwargs = optimizer_kwargs
+         self.init_optimizers()
+
+         self.save_path = save_path
+         self.save_weights_only = save_weights_only
+
+         # temp_step is used for accumulating gradients
+         self.temp_step = 0
+         self.step = 0
+         self.epoch = 0
+
+         self.load_prev_scheduler = load_prev_scheduler
+         self.from_checkpoint = from_checkpoint
+         if from_checkpoint:
+             self.load_checkpoint(from_checkpoint)
+
+     @abc.abstractmethod
+     def initialize_model(self) -> None:
+         """
+         All model initialization should be done here
+         Note that the whole model must be named as "self.model" for model saving and loading
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def forward(self, *args, **kwargs):
+         """
+         Forward propagation
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def initialize_metrics(self, stage: str) -> dict:
+         """
+         Initialize metrics for each stage
+         Args:
+             stage: "train", "valid" or "test"
+
+         Returns:
+             A dictionary of metrics for the stage. Keys are metric names and values are metric objects
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
+         """
+         Args:
+             stage: "train", "valid" or "test"
+             outputs: model outputs for calculating loss
+             labels: labels for calculating loss
+
+         Returns:
+             loss
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     def load_weights(model, weights):
+         model_dict = model.state_dict()
+
+         unused_params = []
+         missed_params = list(model_dict.keys())
+
+         for k, v in weights.items():
+             if k in model_dict.keys():
+                 model_dict[k] = v
+                 missed_params.remove(k)
+
+             else:
+                 unused_params.append(k)
+
+         if len(missed_params) > 0:
+             print(f"\033[31mSome weights of {type(model).__name__} were not "
+                   f"initialized from the model checkpoint: {missed_params}\033[0m")
+
+         if len(unused_params) > 0:
+             print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")
+
+         model.load_state_dict(model_dict)
+
+     def optimizer_step(
+             self,
+             epoch: int,
+             batch_idx: int,
+             optimizer,
+             optimizer_closure=None,
+     ) -> None:
+         super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)
+
+         self.temp_step += 1
+         if self.temp_step == self.trainer.accumulate_grad_batches:
+             self.step += 1
+             self.temp_step = 0
+
+     # For pytorch-lightning 1.9.5
+     # def optimizer_step(
+     #         self,
+     #         epoch: int,
+     #         batch_idx: int,
+     #         optimizer,
+     #         optimizer_idx: int = 0,
+     #         optimizer_closure=None,
+     #         on_tpu: bool = False,
+     #         using_native_amp: bool = False,
+     #         using_lbfgs: bool = False,
+     # ) -> None:
+     #     super().optimizer_step(
+     #         epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
+     #     )
+     #     self.temp_step += 1
+     #     if self.temp_step == self.trainer.accumulate_grad_batches:
+     #         self.step += 1
+     #         self.temp_step = 0
+
+     def on_train_epoch_end(self):
+         self.epoch += 1
+
+     def training_step(self, batch, batch_idx):
+         inputs, labels = batch
+
+         # optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-4, weight_decay=0.01, betas=(0.9, 0.98))
+         # for _ in range(1000):
+         #     outputs = self(**inputs)
+         #     loss = self.loss_func('train', outputs, labels)
+         #     loss.backward()
+         #     optimizer.step()
+         #     optimizer.zero_grad()
+         #
+         # raise
+
+         outputs = self(**inputs)
+         loss = self.loss_func('train', outputs, labels)
+
+         self.log("loss", loss, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         inputs, labels = batch
+         outputs = self(**inputs)
+         loss = self.loss_func('valid', outputs, labels)
+         self.valid_outputs.append(loss)
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         inputs, labels = batch
+         outputs = self(**inputs)
+
+         loss = self.loss_func('test', outputs, labels)
+         self.test_outputs.append(loss)
+         return loss
+
+     def on_train_start(self) -> None:
+         # Load previous scheduler
+         if getattr(self, "prev_schechuler", None) is not None:
+             try:
+                 self.step = self.prev_schechuler["global_step"]
+                 self.epoch = self.prev_schechuler["epoch"]
+                 self.best_value = self.prev_schechuler["best_value"]
+                 self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
+                 print(f"Previous training global step: {self.step}")
+                 print(f"Previous training epoch: {self.epoch}")
+                 print(f"Previous best value: {self.best_value}")
+                 print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")
+
+                 # Load optimizer state
+                 if hasattr(self.trainer.strategy, "deepspeed_engine"):
+                     # For DeepSpeed strategy
+                     try:
+                         self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
+                     except Exception as e:
+                         print(e)
+
+                 else:
+                     # For DDP strategy
+                     self.optimizer.load_state_dict(self.prev_schechuler["optimizer"])
+
+             except Exception as e:
+                 print(e)
+                 raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False")
+
+     def on_validation_epoch_start(self) -> None:
+         setattr(self, "valid_outputs", [])
+
+     def on_test_epoch_start(self) -> None:
+         setattr(self, "test_outputs", [])
+
+     def load_checkpoint(self, from_checkpoint: str) -> None:
+         """
+         Args:
+             from_checkpoint: Path to checkpoint.
+         """
+
+         # If ``from_checkpoint`` is a directory, load the checkpoint in it
+         if os.path.isdir(from_checkpoint):
+             basename = os.path.basename(from_checkpoint)
+             from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")
+
+         state_dict = torch.load(from_checkpoint, map_location=self.device)
+         self.load_weights(self.model, state_dict["model"])
+
+         if self.load_prev_scheduler:
+             state_dict.pop("model")
+             self.prev_schechuler = state_dict
+
+     def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
+         """
+         Save model to save_path
+         Args:
+             save_path: Path to save model
+             save_info: Other info to save
+             save_weights_only: Whether to only save model weights
+         """
+         dir = os.path.dirname(save_path)
+         os.makedirs(dir, exist_ok=True)
+
+         state_dict = {} if save_info is None else save_info
+         state_dict["model"] = self.model.state_dict()
+
+         # Convert model weights to fp32
+         for k, v in state_dict["model"].items():
+             state_dict["model"][k] = v.float()
+
+         if not save_weights_only:
+             state_dict["global_step"] = self.step
+             state_dict["epoch"] = self.epoch
+             state_dict["best_value"] = getattr(self, "best_value", None)
+             state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()
+
+             # If not using DeepSpeed, save optimizer state
+             if not hasattr(self.trainer.strategy, "deepspeed_engine"):
+                 state_dict["optimizer"] = self.optimizers().optimizer.state_dict()
+
+         torch.save(state_dict, save_path)
+
+     def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
+         """
+         Check whether to save the model. If save_path is not None and now_value is the best, save the model.
+         Args:
+             now_value: Current metric value
+             mode: "min" or "max", meaning whether the lower the better or the higher the better
+             save_info: Other info to save
+         """
+
+         assert mode in ["min", "max"], "mode should be 'min' or 'max'"
+
+         if self.save_path is not None:
+             # In case there are variables to be included in the save path
+             save_path = eval(f"f'{self.save_path}'")
+
+             dir = os.path.dirname(save_path)
+             os.makedirs(dir, exist_ok=True)
+
+             # Check whether to save the model
+             best_value = getattr(self, "best_value", None)
+             if best_value is not None:
+                 if (mode == "min" and now_value >= best_value) or (mode == "max" and now_value <= best_value):
+                     return
+
+             setattr(self, "best_value", now_value)
+
+             # For DeepSpeed strategy
+             if hasattr(self.trainer.strategy, "deepspeed_engine"):
+                 if not self.save_weights_only:
+                     self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")
+
+                 # Save a complete checkpoint
+                 if dist.get_rank() == 0:
+                     basename = os.path.basename(save_path)
+                     ckpt_path = os.path.join(save_path, f"{basename}.pt")
+                     self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)
+
+             # For normal situation
+             else:
+                 if dist.get_rank() == 0:
+                     self.save_checkpoint(save_path, save_info, self.save_weights_only)
+
+     def reset_metrics(self, stage) -> None:
+         """
+         Reset metrics for the given stage
+         Args:
+             stage: "train", "valid" or "test"
+         """
+         for metric in self.metrics[stage].values():
+             metric.reset()
+
+     def get_log_dict(self, stage: str) -> dict:
+         """
+         Get the log dict for the stage
+         Args:
+             stage: "train", "valid" or "test"
+
+         Returns:
+             A dictionary of metrics for the stage. Keys are metric names and values are metric values
+         """
+         return {name: metric.compute() for name, metric in self.metrics[stage].items()}
+
+     def log_info(self, info: dict) -> None:
+         """
+         Record metrics during training and testing
+         Args:
+             info: dict of metrics
+         """
+         if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
+             info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
+             info["epoch"] = self.epoch
+             self.logger.log_metrics(info, step=self.step)
+
+     def init_optimizers(self):
+         copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)
+
+         # No decay for layer norm and bias
+         no_decay = ['LayerNorm.weight', 'bias']
+         weight_decay = copy_optimizer_kwargs.pop("weight_decay")
+
+         optimizer_grouped_parameters = [
+             {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+              'weight_decay': weight_decay},
+             {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+              'weight_decay': 0.0}
+         ]
+
+         optimizer_cls = eval(f"torch.optim.{copy_optimizer_kwargs.pop('class')}")
+         self.optimizer = optimizer_cls(optimizer_grouped_parameters,
+                                        lr=self.lr_scheduler_kwargs['init_lr'],
+                                        **copy_optimizer_kwargs)
+
+         tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
+         lr_scheduler = tmp_kwargs.pop("class")
+         self.lr_scheduler = eval(lr_scheduler)(self.optimizer, **tmp_kwargs)
+
+     def configure_optimizers(self):
+         return {"optimizer": self.optimizer,
+                 "lr_scheduler": {"scheduler": self.lr_scheduler,
+                                  "interval": "step",
+                                  "frequency": 1}
+                 }
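A minimal sketch of the contract a subclass has to fulfil. ToyModel, its metric, and the kwargs-style batch are illustrative assumptions; note that training_step unpacks each batch as (inputs, labels) and calls self(**inputs), so inputs must be a dict of keyword arguments:

    import torch
    import torchmetrics
    from model.abstract_model import AbstractModel

    class ToyModel(AbstractModel):
        def initialize_model(self) -> None:
            # The whole network must be registered as ``self.model`` for saving/loading
            self.model = torch.nn.Linear(16, 2)

        def initialize_metrics(self, stage: str) -> dict:
            return {f"{stage}_acc": torchmetrics.Accuracy(task="multiclass", num_classes=2)}

        def forward(self, x):
            return self.model(x)

        def loss_func(self, stage, outputs, labels):
            getattr(self, f"{stage}_acc").update(outputs, labels)
            return torch.nn.functional.cross_entropy(outputs, labels)

    model = ToyModel()  # default AdamW optimizer and ConstantLRScheduler kick in

A dataloader for this sketch would then yield tuples like ({"x": features}, labels).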
model/model_interface.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ import yaml
+ import glob
+
+
+ # register all available models through *_model.py files
+ # def construct_model():
+ #     model_dir = os.path.dirname(__file__)
+ #
+ #     # lists all model files
+ #     model_list = []
+ #     for root, _, names in os.walk(model_dir):
+ #         for name in names:
+ #             if name.endswith('_model.py'):
+ #                 sub_dirs = root.replace(model_dir, '').split(os.sep)
+ #                 model_list.append((sub_dirs, name[:-3]))
+ #
+ #     # load model_config.yaml, controlling which models are to be loaded
+ #     model_config = yaml.safe_load(open(f"{model_dir}/model_config.yaml", "r"))
+ #
+ #     if model_config["verbose"]:
+ #         print("*" * 30 + " Loading model " + "*" * 30)
+ #
+ #     # register models
+ #     for sub_dirs, name in model_list:
+ #         if name in model_config["models"]:
+ #             if len(sub_dirs) > 1:
+ #                 cmd = f"from {'.'.join(sub_dirs)} import {name}"
+ #             else:
+ #                 cmd = f"from . import {name}"
+ #
+ #             exec(cmd)
+ #
+ #             if model_config["verbose"]:
+ #                 info = f"Loaded model: {name}"
+ #                 print(f"\033[32m{info}\033[0m")
+ #         else:
+ #             if model_config["verbose"]:
+ #                 info = f"Skipped model: {name}"
+ #                 print(f"\033[31m{info}\033[0m")
+ #
+ #     if model_config["verbose"]:
+ #         print("*" * 75)
+ #
+ #
+ # # register function as a wrapper for all models
+ # def register_model(cls):
+ #     model_dict[cls.__name__] = cls
+ #     return cls
+ #
+ #
+ # model_dict = {}
+ # construct_model()
+ #
+ #
+ # class ModelInterface:
+ #     @classmethod
+ #     def get_available_models(cls):
+ #         return model_dict.keys()
+ #
+ #     @classmethod
+ #     def init_model(cls, model: str, **kwargs):
+ #         """
+ #         Args:
+ #             model   : Class name of the model you want to use. Must be in model_dict.keys()
+ #             **kwargs: Kwargs for model initialization
+ #
+ #         Returns: Corresponding model
+ #         """
+ #         assert model in model_dict.keys(), f"class {model} doesn't exist!"
+ #         return model_dict[model](**kwargs)
+
+
+ ########################################################################
+ #                              Version 2                               #
+ ########################################################################
+ # register function as a wrapper for all models
+ def register_model(cls):
+     global now_cls
+     now_cls = cls
+     return cls
+
+
+ now_cls = None
+
+
+ class ModelInterface:
+     @classmethod
+     def init_model(cls, model_py_path: str, **kwargs):
+         """
+         Args:
+             model_py_path: Path of the .py file of the model you want to use.
+             **kwargs: Kwargs for model initialization
+
+         Returns: Corresponding model
+         """
+         sub_dirs = model_py_path.split(os.sep)
+         cmd = f"from {'.' + '.'.join(sub_dirs[:-1])} import {sub_dirs[-1]}"
+         exec(cmd)
+
+         return now_cls(**kwargs)
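A sketch of how the version-2 registry appears intended to be used (the file and class names are hypothetical): decorating a class in a model file sets the module-level now_cls when that file is imported, and init_model performs exactly that import, relative to the model package, before instantiating:

    # model/ProtTrek/protrek_model.py (hypothetical model file)
    from ..model_interface import register_model
    from ..abstract_model import AbstractModel

    @register_model
    class ProtTrekModel(AbstractModel):
        ...

    # caller side: path is relative to the model package, without the .py suffix
    from model.model_interface import ModelInterface
    model = ModelInterface.init_model("ProtTrek/protrek_model", save_path=None)

Because now_cls is a single global, the most recently imported model file wins; each init_model call should therefore be matched to the model it instantiates.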
utils/constants.py ADDED
@@ -0,0 +1,54 @@
+ import itertools
+
+
+ aa_set = {"A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"}
+ aa_list = ["A", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y"]
+
+ foldseek_seq_vocab = "ACDEFGHIKLMNPQRSTVWY#"
+ foldseek_struc_vocab = "pynwrqhgdlvtmfsaeikc#"
+
+ struc_unit = "abcdefghijklmnopqrstuvwxyz"
+
+
+ def create_vocab(size: int) -> dict:
+     """
+     Args:
+         size: Size of the vocabulary
+
+     Returns:
+         vocab: Vocabulary
+     """
+     token_len = 1
+     while size > len(struc_unit) ** token_len:
+         token_len += 1
+
+     vocab = {}
+     for i, token in enumerate(itertools.product(struc_unit, repeat=token_len)):
+         vocab[i] = "".join(token)
+         if len(vocab) == size:
+             vocab[i + 1] = "#"
+             return vocab
+
+
+ # ProTrek
+ residue_level = {"Active site", "Binding site", "Site", "DNA binding", "Natural variant", "Mutagenesis",
+                  "Transmembrane", "Topological domain", "Intramembrane", "Signal peptide", "Propeptide",
+                  "Transit peptide",
+                  "Chain", "Peptide", "Modified residue", "Lipidation", "Glycosylation", "Disulfide bond",
+                  "Cross-link",
+                  "Domain", "Repeat", "Compositional bias", "Region", "Coiled coil", "Motif"}
+
+ sequence_level = {"Function", "Miscellaneous", "Caution", "Catalytic activity", "Cofactor", "Activity regulation",
+                   "Biophysicochemical properties", "Pathway", "Involvement in disease", "Allergenic properties",
+                   "Toxic dose", "Pharmaceutical use", "Disruption phenotype", "Subcellular location",
+                   "Post-translational modification", "Subunit", "Domain (non-positional annotation)",
+                   "Sequence similarities", "RNA Editing", "Tissue specificity", "Developmental stage", "Induction",
+                   "Biotechnology", "Polymorphism", "GO annotation", "Proteomes", "Protein names", "Gene names",
+                   "Organism", "Taxonomic lineage", "Virus host"}
+
+ raw_text_level = {"Function", "Subunit", "Tissue specificity", "Disruption phenotype", "Post-translational modification",
+                   "Induction", "Miscellaneous", "Sequence similarities", "Developmental stage",
+                   "Domain (non-positional annotation)", "Activity regulation", "Caution", "Polymorphism", "Toxic dose",
+                   "Allergenic properties", "Pharmaceutical use", "Cofactor", "Biophysicochemical properties",
+                   "Subcellular location", "RNA Editing"}
utils/foldseek_util.py ADDED
@@ -0,0 +1,121 @@
+ import os
+ import time
+ import json
+ import numpy as np
+ import re
+ import sys
+ sys.path.append(".")
+
+
+ # Get structural seqs from pdb file
+ def get_struc_seq(foldseek,
+                   path,
+                   chains: list = None,
+                   process_id: int = 0,
+                   plddt_mask: bool = False,
+                   plddt_threshold: float = 70.,
+                   foldseek_verbose: bool = False) -> dict:
+     """
+     Args:
+         foldseek: Binary executable file of foldseek
+
+         path: Path to pdb file
+
+         chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
+
+         process_id: Process ID for temporary files. This is used for parallel processing.
+
+         plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.
+
+         plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.
+
+         foldseek_verbose: If True, foldseek will print verbose messages.
+
+     Returns:
+         seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
+         (seq, struc_seq, combined_seq).
+     """
+     assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
+     assert os.path.exists(path), f"PDB file not found: {path}"
+
+     tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"
+     if foldseek_verbose:
+         cmd = f"{foldseek} structureto3didescriptor --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
+     else:
+         cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}"
+     os.system(cmd)
+
+     seq_dict = {}
+     name = os.path.basename(path)
+     with open(tmp_save_path, "r") as r:
+         for i, line in enumerate(r):
+             desc, seq, struc_seq = line.split("\t")[:3]
+
+             # Mask low plddt
+             if plddt_mask:
+                 plddts = extract_plddt(path)
+                 assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"
+
+                 # Mask regions with plddt < threshold
+                 indices = np.where(plddts < plddt_threshold)[0]
+                 np_seq = np.array(list(struc_seq))
+                 np_seq[indices] = "#"
+                 struc_seq = "".join(np_seq)
+
+             name_chain = desc.split(" ")[0]
+             chain = name_chain.replace(name, "").split("_")[-1]
+
+             if chains is None or chain in chains:
+                 if chain not in seq_dict:
+                     combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
+                     seq_dict[chain] = (seq, struc_seq, combined_seq)
+
+     os.remove(tmp_save_path)
+     os.remove(tmp_save_path + ".dbtype")
+     return seq_dict
+
+
+ def extract_plddt(pdb_path: str) -> np.ndarray:
+     """
+     Extract plddt scores from pdb file.
+     Args:
+         pdb_path: Path to pdb file.
+
+     Returns:
+         plddts: plddt scores.
+     """
+     with open(pdb_path, "r") as r:
+         plddt_dict = {}
+         for line in r:
+             line = re.sub(' +', ' ', line).strip()
+             splits = line.split(" ")
+
+             if splits[0] == "ATOM":
+                 # If position < 1000
+                 if len(splits[4]) == 1:
+                     pos = int(splits[5])
+
+                 # If position >= 1000, the blank will be removed, e.g. "A 999" -> "A1000"
+                 # So the length of splits[4] is not 1
+                 else:
+                     pos = int(splits[4][1:])
+
+                 plddt = float(splits[-2])
+
+                 if pos not in plddt_dict:
+                     plddt_dict[pos] = [plddt]
+                 else:
+                     plddt_dict[pos].append(plddt)
+
+     plddts = np.array([np.mean(v) for v in plddt_dict.values()])
+     return plddts
+
+
+ if __name__ == '__main__':
+     foldseek = "/sujin/bin/foldseek"
+     # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
+     test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
+     # plddt scores are parsed directly from the pdb file, so no separate json file is needed
+     res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
+     print(res["A"][1].lower())
utils/lr_scheduler.py ADDED
@@ -0,0 +1,187 @@
+ import math
+
+ from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
+
+
+ class ConstantLRScheduler(_LRScheduler):
+     def __init__(self,
+                  optimizer,
+                  last_epoch: int = -1,
+                  verbose: bool = False,
+                  init_lr: float = 0.,
+                  ):
+         """
+         This is an implementation of a constant learning rate scheduler.
+         Args:
+             optimizer: Optimizer
+
+             last_epoch: The index of last epoch. Default: -1
+
+             verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
+
+             init_lr: Initial learning rate
+         """
+
+         self.init_lr = init_lr
+         super().__init__(optimizer, last_epoch, verbose)
+
+     def state_dict(self):
+         state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
+         return state_dict
+
+     def load_state_dict(self, state_dict):
+         self.__dict__.update(state_dict)
+
+     def get_lr(self):
+         if not self._get_lr_called_within_step:
+             raise RuntimeError(
+                 "To get the last learning rate computed by the scheduler, use "
+                 "get_last_lr()"
+             )
+
+         return [self.init_lr for group in self.optimizer.param_groups]
+
+
+ class CosineAnnealingLRScheduler(_LRScheduler):
+     def __init__(self,
+                  optimizer,
+                  last_epoch: int = -1,
+                  verbose: bool = False,
+                  init_lr: float = 0.,
+                  max_lr: float = 4e-4,
+                  final_lr: float = 4e-5,
+                  warmup_steps: int = 2000,
+                  cosine_steps: int = 10000,
+                  ):
+         """
+         This is an implementation of a cosine annealing learning rate scheduler.
+         Args:
+             optimizer: Optimizer
+
+             last_epoch: The index of last epoch. Default: -1
+
+             verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
+
+             init_lr: Initial learning rate
+
+             max_lr: Maximum learning rate after warmup
+
+             final_lr: Final learning rate after decay
+
+             warmup_steps: Number of steps for warmup
+
+             cosine_steps: Number of steps for cosine annealing
+         """
+
+         self.init_lr = init_lr
+         self.max_lr = max_lr
+         self.final_lr = final_lr
+         self.warmup_steps = warmup_steps
+         self.cosine_steps = cosine_steps
+         super(CosineAnnealingLRScheduler, self).__init__(optimizer, last_epoch, verbose)
+
+     def state_dict(self):
+         state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
+         return state_dict
+
+     def load_state_dict(self, state_dict):
+         self.__dict__.update(state_dict)
+
+     def get_lr(self):
+         if not self._get_lr_called_within_step:
+             raise RuntimeError(
+                 "To get the last learning rate computed by the scheduler, use "
+                 "get_last_lr()"
+             )
+
+         step_no = self.last_epoch
+
+         if step_no <= self.warmup_steps:
+             lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
+
+         else:
+             lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) \
+                  * (1 + math.cos(math.pi * (step_no - self.warmup_steps) / self.cosine_steps))
+
+         return [lr for group in self.optimizer.param_groups]
+
+
+ class Esm2LRScheduler(_LRScheduler):
+     def __init__(self,
+                  optimizer,
+                  last_epoch: int = -1,
+                  verbose: bool = False,
+                  init_lr: float = 0.,
+                  max_lr: float = 4e-4,
+                  final_lr: float = 4e-5,
+                  warmup_steps: int = 2000,
+                  start_decay_after_n_steps: int = 500000,
+                  end_decay_after_n_steps: int = 5000000,
+                  on_use: bool = True,
+                  ):
+         """
+         This is an implementation of ESM2's learning rate scheduler.
+         Args:
+             optimizer: Optimizer
+
+             last_epoch: The index of last epoch. Default: -1
+
+             verbose: If ``True``, prints a message to stdout for each update. Default: ``False``
+
+             init_lr: Initial learning rate
+
+             max_lr: Maximum learning rate after warmup
+
+             final_lr: Final learning rate after decay
+
+             warmup_steps: Number of steps for warmup
+
+             start_decay_after_n_steps: Start decay after this number of steps
+
+             end_decay_after_n_steps: End decay after this number of steps
+
+             on_use: Whether to use this scheduler. If ``False``, the scheduler will not change the learning rate
+                     and will only use the ``init_lr``. Default: ``True``
+         """
+
+         self.init_lr = init_lr
+         self.max_lr = max_lr
+         self.final_lr = final_lr
+         self.warmup_steps = warmup_steps
+         self.start_decay_after_n_steps = start_decay_after_n_steps
+         self.end_decay_after_n_steps = end_decay_after_n_steps
+         self.on_use = on_use
+         super(Esm2LRScheduler, self).__init__(optimizer, last_epoch, verbose)
+
+     def state_dict(self):
+         state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
+         return state_dict
+
+     def load_state_dict(self, state_dict):
+         self.__dict__.update(state_dict)
+
+     def get_lr(self):
+         if not self._get_lr_called_within_step:
+             raise RuntimeError(
+                 "To get the last learning rate computed by the scheduler, use "
+                 "get_last_lr()"
+             )
+
+         step_no = self.last_epoch
+         if not self.on_use:
+             return [base_lr for base_lr in self.base_lrs]
+
+         if step_no <= self.warmup_steps:
+             lr = self.init_lr + step_no / self.warmup_steps * (self.max_lr - self.init_lr)
+
+         elif step_no <= self.start_decay_after_n_steps:
+             lr = self.max_lr
+
+         elif step_no <= self.end_decay_after_n_steps:
+             portion = (step_no - self.start_decay_after_n_steps) / (self.end_decay_after_n_steps - self.start_decay_after_n_steps)
+             lr = self.max_lr - portion * (self.max_lr - self.final_lr)
+
+         else:
+             lr = self.final_lr
+
+         return [lr for group in self.optimizer.param_groups]
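A small sketch of driving one of these schedulers by hand (the single dummy parameter is an assumption for illustration): the first warmup_steps steps move the learning rate linearly from init_lr to max_lr, then the chosen decay takes over:

    import torch
    from utils.lr_scheduler import Esm2LRScheduler

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=0.)
    scheduler = Esm2LRScheduler(optimizer, init_lr=0., max_lr=4e-4, final_lr=4e-5,
                                warmup_steps=2000, start_decay_after_n_steps=500000,
                                end_decay_after_n_steps=5000000)

    for _ in range(2000):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())  # ~[4e-4] at the end of warmup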
utils/mpr.py ADDED
@@ -0,0 +1,397 @@
+ import abc
+ import os
+ import time
+ import sys
+
+
+ from tqdm import tqdm
+ from math import ceil
+
+
+ class MultipleProcessRunner:
+     """
+     Abstract class for running tasks with multiple processes
+     There are three abstract methods that should be implemented:
+         1. __len__()     : return the length of data
+         2. _target()     : target function for each process
+         3. _aggregate()  : aggregate results from each process
+     """
+
+     def __init__(self,
+                  data,
+                  save_path=None,
+                  n_process=1,
+                  verbose=True,
+                  total_only=True,
+                  log_step=1,
+                  start_method='fork'):
+         """
+         Args:
+             data      : data to be processed that can be sliced
+
+             save_path : final output path
+
+             n_process : number of processes
+
+             verbose   : if True, display progress bar
+
+             total_only: if True, only the total progress bar is displayed
+
+             log_step  : for the total progress bar, the next log will be printed when
+                         ``current iteration`` - ``last log iteration`` >= log_step
+
+             start_method: start method for multiprocessing
+         """
+         self.data = data
+         self.save_path = save_path
+         self.n_process = n_process
+         self.verbose = verbose
+         self.total_only = total_only
+         self.log_step = log_step
+         self.start_method = start_method
+
+         # get terminal width to format output
+         try:
+             self.terminal_y = os.get_terminal_size()[0]
+
+         except Exception as e:
+             print(e)
+             print("Can't get terminal size, set terminal_y = None")
+             self.terminal_y = None
+
+     def _s2hms(self, seconds: float):
+         """
+         Convert the second format of time into the hour:minute:second format
+         """
+         m, s = divmod(seconds, 60)
+         h, m = divmod(m, 60)
+
+         return "%02d:%02d:%02d" % (h, m, s)
+
+     def _display_time(self, st_time, now, total):
+         ed_time = time.time()
+         running_time = ed_time - st_time
+         rest_time = running_time * (total - now) / now
+         iter_sec = f"{now / running_time:.2f}it/s" if now > running_time else f"{running_time / now:.2f}s/it"
+
+         return f' [{self._s2hms(running_time)} < {self._s2hms(rest_time)}, {iter_sec}]'
+
+     def _display_bar(self, now, total, length):
+         now = now if now <= total else total
+         num = now * length // total
+         progress_bar = '[' + '#' * num + '_' * (length - num) + ']'
+         return progress_bar
+
+     def _display_all(self, now, total, desc, st_time):
+         # make a progress bar
+         length = 50
+         progress_bar = self._display_bar(now, total, length)
+         time_display = self._display_time(st_time, now, total)
+
+         display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'
+
+         # Clean a line
+         width = self.terminal_y if self.terminal_y is not None else 100
+         num_space = width - len(display)
+         if num_space > 0:
+             display += ' ' * num_space
+         else:
+             length += num_space
+             progress_bar = self._display_bar(now, total, length)
+             display = f'{desc}{progress_bar} {int(now / total * 100):02d}% {now}/{total}{time_display}'
+
+         # Set color
+         display = f"\033[31m{display}\033[0m"
+
+         return display
+
+     # Print progress bar at a specific position in the terminal
+     def terminal_progress_bar(self,
+                               process_id: int,
+                               now: int,
+                               total: int,
+                               desc: str = ''):
+         """
+         Args:
+             process_id: process id
+             now: current iteration number
+             total: total iteration number
+             desc: description
+         """
+         st_time = self.process_st_time[process_id]
+
+         # Aggregate total information
+         self.counts[process_id] = now
+         self._total_display(self.process_st_time["total"])
+
+         if not self.total_only:
+             process_display = self._display_all(now, total, desc, st_time)
+             if self.terminal_y is not None:
+                 sys.stdout.write(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8")
+                 sys.stdout.flush()
+             else:
+                 print(f"\x1b7\x1b[{process_id + 1};{0}f{process_display}\x1b8", flush=True)
+
+     # Print global information
+     def _total_display(self, st_time):
+         if self.total_display_callable.value == 1:
+             self.total_display_callable.value = 0
+
+             cnt = sum([self.counts[i] for i in range(self.n_process)])
+             if cnt - self.last_cnt.value >= self.log_step:
+                 total_display = self._display_all(cnt, self.__len__(), "Total: ", st_time)
+                 self.last_cnt.value = cnt
+
+                 x = self.n_process + 1 if not self.total_only else 0
+                 # if self.terminal_y is not None:
+                 #     sys.stdout.write(f"\x1b7\x1b[{x};{0}f{total_display}\x1b8")
+                 #     sys.stdout.flush()
+                 # else:
+                 #     print(f"\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True)
+                 print(f"\r\x1b7\x1b[{x};{0}f{total_display}\x1b8", flush=True, end="")
+
+             self.total_display_callable.value = 1
+
+     def run(self):
+         """
+         This function is used to run a multi-process task
+         Returns: the result of the function '_aggregate()'
+         """
+
+         import multiprocess as mp
+         mp.set_start_method(self.start_method, force=True)
+
+         # total number of data that is already processed
+         self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})
+
+         # record start time for each process
+         self.process_st_time = {"total": time.time()}
+
+         # set a lock to call total number display
+         self.total_display_callable = mp.Value('d', 1)
+
+         # Save last log iteration number
+         self.last_cnt = mp.Value('d', 0)
+
+         num_per_process = ceil(self.__len__() / self.n_process)
+
+         if self.save_path is not None:
+             file_name, suffix = os.path.splitext(self.save_path)
+
+         process_list = []
+         sub_paths = []
+         for i in range(self.n_process):
+             st = i * num_per_process
+             ed = st + num_per_process
+
+             # construct slice and sub path for sub process
+             data_slice = self.data[st: ed]
+
+             sub_path = None
+             # Create a directory to save sub-results
+             if self.save_path is not None:
+                 save_dir = f"{file_name}{suffix}_temp"
+                 os.makedirs(save_dir, exist_ok=True)
+                 sub_path = f"{save_dir}/temp_{i}{suffix}"
+
+             # construct sub process
+             input_args = (i, data_slice, sub_path)
+             self.process_st_time[i] = time.time()
+             p = mp.Process(target=self._target, args=input_args)
+             p.start()
+
+             process_list.append(p)
+             sub_paths.append(sub_path)
+
+         for p in process_list:
+             p.join()
+
+         # aggregate results and remove temporary directory
+         results = self._aggregate(self.save_path, sub_paths)
+         if self.save_path is not None:
+             save_dir = f"{file_name}{suffix}_temp"
+             os.rmdir(save_dir)
+
+         return results
+
+     def parallel_run(self):
+         import multiprocess as mp
+         from joblib import Parallel, delayed
+
+         # total number of data that is already processed
+         self.counts = mp.Manager().dict({i: 0 for i in range(self.n_process)})
+
+         # record start time for each process
+         self.process_st_time = {"total": time.time()}
+
+         # set a lock to call total number display
+         self.total_display_callable = mp.Value('d', 1)
+
+         # Save last log iteration number
+         self.last_cnt = mp.Value('d', 0)
+
+         num_per_process = ceil(self.__len__() / self.n_process)
+
+         if self.save_path is not None:
+             file_name, suffix = os.path.splitext(self.save_path)
+
+         sub_paths = []
+         input_arg_list = []
+         for i in range(self.n_process):
+             st = i * num_per_process
+             ed = st + num_per_process
+
+             # construct slice and sub path for sub process
+             data_slice = self.data[st: ed]
+
+             sub_path = None
+             # Create a directory to save sub-results
+             if self.save_path is not None:
+                 save_dir = f"{file_name}{suffix}_temp"
+                 os.makedirs(save_dir, exist_ok=True)
+                 sub_path = f"{save_dir}/temp_{i}{suffix}"
+
+             # construct sub process
+             input_args = (i, data_slice, sub_path)
+             self.process_st_time[i] = time.time()
+
+             sub_paths.append(sub_path)
+             input_arg_list.append(input_args)
+
+         # Start parallel processing; each arg tuple is unpacked into (process_id, data, sub_path)
+         Parallel(n_jobs=self.n_process)(delayed(self._target)(*input_args) for input_args in input_arg_list)
+
+         # aggregate results and remove temporary directory
+         results = self._aggregate(self.save_path, sub_paths)
+         if self.save_path is not None:
+             save_dir = f"{file_name}{suffix}_temp"
+             os.rmdir(save_dir)
+
+         return results
+
+     @abc.abstractmethod
+     def _aggregate(self, final_path: str, sub_paths):
+         """
+         This function is used to aggregate results from sub processes into one file
+
+         Args:
+             final_path: path to save final results
+             sub_paths : list of sub paths
+
+         Returns: None or desirable results specified by the user
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def _target(self, process_id, data, sub_path):
+         """
+         The main body to operate on data in one process
+
+         Args:
+             process_id: process id
+             data      : data slice
+             sub_path  : sub path to save results
+         """
+         raise NotImplementedError
+
+     @abc.abstractmethod
+     def __len__(self):
+         raise NotImplementedError
+
+
+ class MultipleProcessRunnerSimplifier(MultipleProcessRunner):
+     """
+     A simplified version of MultipleProcessRunner.
+     Users only need to implement the function 'do', which is then automatically executed
+     in every iteration after calling the function 'run'.
+     If 'save_path' is specified, a file will be opened at 'sub_path' into which
+     the user can write results, and the results will be aggregated into 'save_path'.
+
+     The procedure would be like:
+         ...
+         with open(sub_path, 'w') as w:
+             for i, d in enumerate(data):
+                 self.do(process_id, i, d, w)  # You can write results into the file.
+         ...
+
+     The 'do' function should be like:
+         def do(process_id, idx, data, writer):
+             ...
+
+     If 'save_path' is None, the argument 'writer' will be set to None.
+     """
+
+     def __init__(self,
+                  data,
+                  do,
+                  save_path=None,
+                  n_process=1,
+                  verbose=True,
+                  total_only=True,
+                  log_step=1,
+                  return_results=False,
+                  start_method='fork'):
+
+         super().__init__(data=data,
+                          save_path=save_path,
+                          n_process=n_process,
+                          verbose=verbose,
+                          total_only=total_only,
+                          log_step=log_step,
+                          start_method=start_method)
+         self.do = do
+         self.return_results = return_results
+
+     def run(self):
+         self.start_time = time.time()
+         return super().run()
+
+     def _aggregate(self, final_path: str, sub_paths):
+         results = []
+
+         w = open(final_path, 'w') if final_path is not None else None
+
+         if self.verbose:
+             iterator = tqdm(enumerate(sub_paths), "Aggregating results...")
+         else:
+             iterator = enumerate(sub_paths)
+
+         for i, sub_path in iterator:
+             if sub_path is None and self.return_results:
+                 sub_path = f"MultipleProcessRunnerSimplifier_{self.start_time}_{i}.tmp"
+
+             if sub_path is not None:
+                 with open(sub_path, 'r') as r:
+                     for line in r:
+                         if w is not None:
+                             w.write(line)
+
+                         if self.return_results:
+                             results.append(line[:-1])
+
+                 os.remove(sub_path)
+
+         return results
+
+     def _target(self, process_id, data, sub_path):
+         if sub_path is None and self.return_results:
+             sub_path = f"MultipleProcessRunnerSimplifier_{self.start_time}_{process_id}.tmp"
+
+         w = open(sub_path, 'w') if sub_path is not None else None
+         for i, d in enumerate(data):
+             self.do(process_id, i, d, w)
+             if self.verbose:
+                 self.terminal_progress_bar(process_id, i + 1, len(data), f"Process{process_id} running...")
+
+         if w is not None:
+             w.close()
+
+     def __len__(self):
+         return len(self.data)