wenkai commited on
Commit
b7b6da7
1 Parent(s): 4fa6436

Upload 2 files

Browse files
Files changed (2) hide show
  1. FAPM_inference.py +86 -0
  2. README.md +75 -0
FAPM_inference.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pandas as pd
4
+ import torch.nn.functional as F
5
+ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
6
+ # from lavis.models.base_model import FAPMConfig
7
+ # from lavis.models.blip2_models.blip2_opt import Blip2ProteinOPT
8
+ import random
9
+ from lavis.models.base_model import FAPMConfig
10
+ import argparse
11
+
12
+ prop = True
13
+
14
+ if __name__ == '__main__':
15
+ parser = argparse.ArgumentParser(description='FAPM')
16
+ parser.add_argument('--model_path', type=str, help='Dataset path')
17
+ parser.add_argument('--example_path', type=str, help='Example protein path')
18
+ parser.add_argument('--device', type=str, default='cuda', help='Which gpu to use if any (default: cuda)')
19
+ parser.add_argument('--prompt', type=str, default='none', help='Input prompt for protein function prediction')
20
+ parser.add_argument('--ground_truth', type=str, default='none', help='ground truth function')
21
+ args = parser.parse_args()
22
+ test_sdf_paths = args.model_path
23
+
24
+ # model = Blip2ProteinOPT(config=FAPMConfig(), esm_size='3b')
25
+ # model.load_checkpoint('/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20240327081/checkpoint_2.pth')
26
+ model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
27
+ model.load_checkpoint(args.model_path)
28
+ model.to(args.device)
29
+
30
+ # esm_emb = torch.load('/cluster/home/wenkai/LAVIS/data/pretrain/ipr_domain_emb_esm2_3b/Gp49.pt')['representations'][36]
31
+ esm_emb = torch.load(args.example_path)['representations'][36]
32
+ esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
33
+ samples = {'name': ['P18281'],
34
+ 'image': torch.unsqueeze(esm_emb, dim=0),
35
+ 'text_input': [args.ground_truth],
36
+ 'prompt': [args.prompt]}
37
+ prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
38
+ print(f"Text Prediction: {prediction}")
39
+
40
+
41
+ if prop == True:
42
+ from data.evaluate_data.utils import Ontology
43
+ import difflib
44
+ import re
45
+
46
+ # godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
47
+ godb = Ontology(f'data/go1.4-basic.obo', with_rels=True)
48
+
49
+ go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
50
+ go_des.columns = ['id', 'text']
51
+ go_des = go_des.dropna()
52
+ go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
53
+ go_obo_set = set(go_des['id'].tolist())
54
+ go_des['text'] = go_des['text'].apply(lambda x: x.lower())
55
+ GO_dict = dict(zip(go_des['text'], go_des['id']))
56
+ Func_dict = dict(zip(go_des['id'], go_des['text']))
57
+
58
+ # terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
59
+ terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
60
+ choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
61
+ choices = {x.lower(): x for x in choices_mf}
62
+
63
+ pred_terms_list = []
64
+ pred_go_list = []
65
+ prop_annotations = []
66
+ for x in prediction:
67
+ x = [eval(i) for i in x.split('; ')]
68
+ pred_terms = []
69
+ pred_go = []
70
+ annot_set = set()
71
+ for i in x:
72
+ txt = i[0]
73
+ prob = i[1]
74
+ sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
75
+ if len(sim_list) > 0:
76
+ pred_terms.append((sim_list[0], prob))
77
+ pred_go.append((GO_dict[sim_list[0]], prob))
78
+ annot_set |= godb.get_anchestors(GO_dict[sim_list[0]])
79
+ pred_terms_list.append(pred_terms)
80
+ pred_go_list.append(pred_go)
81
+ annots = list(annot_set)
82
+ prop_annotations.append(annots)
83
+
84
+ print(f"Predictions of GO terms: \n{pred_terms_list} \nPredictions of GO id: \n{pred_go_list} \nPredictions of GO id propgated: \n{prop_annotations}")
85
+
86
+
README.md ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Introduction
2
+ <p align="center">
3
+ <br>
4
+ <img src="assets/FAPM.png"/>
5
+ <br>
6
+ <p>
7
+
8
+ ## Installation
9
+
10
+ 1. (Optional) Creating conda environment
11
+
12
+ ```bash
13
+ conda create -n lavis python=3.8
14
+ conda activate lavis
15
+ ```
16
+
17
+ 2. for development, you may build from source
18
+
19
+ ```bash
20
+ git clone https://github.com/xiangwenkai/FAPM.git
21
+ cd FAPM
22
+ pip install -e .
23
+
24
+ pip install Biopython
25
+ pip install fair-esm
26
+ ```
27
+
28
+ ### Datasets
29
+ #### 1.raw dataset
30
+ Raw data are avaliable at *https://ftp.uniprot.org/pub/databases/uniprot/previous_releases/release-2023_04/knowledgebase/*, this file is very large and need to be processed to get its name, sequence, GO label, function description and prompt.
31
+ The domain level protein dataset we used are avaliable at *https://ftp.ebi.ac.uk/pub/databases/interpro/releases/95.0/protein2ipr.dat.gz*
32
+ In this respository, We provide the experimental train/val/test sets of Swiss-Prot, which are avaliable at data/swissprot_exp
33
+ #### 2.ESM2 embeddings
34
+ Source code for ESM2 embeddings generation: *https://github.com/facebookresearch/esm*
35
+ The generation command:
36
+ ```bash
37
+ python esm_scripts/extract.py esm2_t33_3B_UR50D you_path/protein.fasta you_path_to_save_embedding_files --repr_layers 36 --truncation_seq_length 1024 --include per_tok
38
+ ```
39
+ The default path to save embedding files in this respository is **data/emb_esm2_3b**
40
+
41
+ ## Pretraining language models
42
+ Source: *https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B*
43
+
44
+ ## Training
45
+ data config: lavis/configs/datasets/protein/GO_defaults_cap.yaml
46
+ stage1 config: lavis/projects/blip2/train/protein_pretrain_stage1.yaml
47
+ stage1 training command: run_scripts/blip2/train/protein_pretrain_domain_stage1.sh
48
+ stage2 config: lavis/projects/blip2/train/protein_pretrain_stage2.yaml
49
+ stage2 training/finetuning command: run_scripts/blip2/train/protein_pretrain_domain_stage2.sh
50
+
51
+ ## Trained models
52
+ You can download our trained models from drive: *https://drive.google.com/drive/folders/1aA0eSYxNw3DvrU5GU1Cu-4q2kIxxAGSE?usp=drive_link*
53
+
54
+ ## Testing
55
+ config: lavis/projects/blip2/eval/caption_protein_eval.yaml
56
+ command: run_scripts/blip2/eval/eval_cap_protein.sh
57
+
58
+ ## Inference example
59
+ ```
60
+ python FAPM_inference.py \
61
+ --model_path model/checkpoint_mf2.pth \
62
+ --example_path data/emb_esm2_3b/P18281.pt \
63
+ --device cuda \
64
+ --prompt Acanthamoeba
65
+ ```
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+