wenkai commited on
Commit
4fa6436
1 Parent(s): 5b023d0

Delete FAPM_inference.py

Browse files
Files changed (1) hide show
  1. FAPM_inference.py +0 -76
FAPM_inference.py DELETED
@@ -1,76 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import pandas as pd
4
- import torch.nn.functional as F
5
- from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
6
- # from lavis.models.base_model import FAPMConfig
7
- # from lavis.models.blip2_models.blip2_opt import Blip2ProteinOPT
8
- import random
9
- from lavis.models.base_model import FAPMConfig
10
-
11
- prop = True
12
-
13
- # model = Blip2ProteinOPT(config=FAPMConfig(), esm_size='3b')
14
- # model.load_checkpoint('/cluster/home/wenkai/LAVIS/lavis/output/BLIP2/Pretrain_stage2/20240327081/checkpoint_2.pth')
15
- model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
16
- model.load_checkpoint('model/checkpoint_mf2.pth')
17
- # model.from_pretrained('/cluster/home/wenkai/FAPM_model/mf')
18
- model.to('cuda')
19
-
20
- # esm_emb = torch.load('/cluster/home/wenkai/LAVIS/data/pretrain/ipr_domain_emb_esm2_3b/Gp49.pt')['representations'][36]
21
- esm_emb = torch.load('data/emb_esm2_3b/P18281.pt')['representations'][36]
22
- esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
23
- samples = {'name': ['P18281'],
24
- 'image': torch.unsqueeze(esm_emb, dim=0),
25
- 'text_input': ['actin monomer binding'],
26
- 'prompt': ['Acanthamoeba']}
27
- prediction = model.generate(samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1., repetition_penalty=1.0)
28
- print(f"Text Prediction: {prediction}")
29
-
30
-
31
- if prop == True:
32
- from data.evaluate_data.utils import Ontology
33
- import difflib
34
- import re
35
-
36
- # godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
37
- godb = Ontology(f'data/go1.4-basic.obo', with_rels=True)
38
-
39
- go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
40
- go_des.columns = ['id', 'text']
41
- go_des = go_des.dropna()
42
- go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
43
- go_obo_set = set(go_des['id'].tolist())
44
- go_des['text'] = go_des['text'].apply(lambda x: x.lower())
45
- GO_dict = dict(zip(go_des['text'], go_des['id']))
46
- Func_dict = dict(zip(go_des['id'], go_des['text']))
47
-
48
- # terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
49
- terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
50
- choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
51
- choices = {x.lower(): x for x in choices_mf}
52
-
53
- pred_terms_list = []
54
- pred_go_list = []
55
- prop_annotations = []
56
- for x in prediction:
57
- x = [eval(i) for i in x.split('; ')]
58
- pred_terms = []
59
- pred_go = []
60
- annot_set = set()
61
- for i in x:
62
- txt = i[0]
63
- prob = i[1]
64
- sim_list = difflib.get_close_matches(txt.lower(), choices, n=1, cutoff=0.9)
65
- if len(sim_list) > 0:
66
- pred_terms.append((sim_list[0], prob))
67
- pred_go.append((GO_dict[sim_list[0]], prob))
68
- annot_set |= godb.get_anchestors(GO_dict[sim_list[0]])
69
- pred_terms_list.append(pred_terms)
70
- pred_go_list.append(pred_go)
71
- annots = list(annot_set)
72
- prop_annotations.append(annots)
73
-
74
- print(f"Predictions of GO terms: \n{pred_terms_list} \nPredictions of GO id: \n{pred_go_list} \nPredictions of GO id propgated: \n{prop_annotations}")
75
-
76
-