File size: 2,439 Bytes
85ab89d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# import json
# from tqdm import tqdm
# import matplotlib.pyplot as plt
# import numpy as np

# f = open("/home/ubuntu/proteinedit-mm-clean/data/esm_subset/abstract.json", "r")
# ann = json.load(f)

# total = 0
# l_256 = 0
# l_384 = 0
# x = []
# for i in tqdm(range(0, len(ann))):
#     total += len(ann[i]["caption"].split())
#     if (len(ann[i]["caption"].split()) <= 256):
#         l_256 += 1
#     if (len(ann[i]["caption"].split()) <= 384):
#         l_384 += 1

#     x.append(len(ann[i]["caption"].split()))

# x = np.array(x)
# print("avg: ", str(total / len(ann)))
# print("below 256: ", str(l_256 / len(ann)))
# print("below 384: ", str(l_384 / len(ann)))
# plt.hist(x)
# plt.savefig("test.png")

from minigpt4.datasets.qa_dataset import QADataset 

# Build the QA dataset from the local structure, sequence and annotation
# files, restricted to chain "A".
datasets_raw = QADataset(
    pdb_root="/home/ubuntu/pt/",
    seq_root="/home/ubuntu/seq/",
    ann_paths="/home/ubuntu/proteinchat/data/esm_subset/qa_all.json",
    dataset_description="/home/ubuntu/dataset.json",
    chain="A",
)

# Sanity check: print the "q_input" and "a_input" entries of the first sample
# (the sample is looked up once per entry), followed by the dataset size.
for _entry_key in ("q_input", "a_input"):
    print(datasets_raw[0][_entry_key])
print(len(datasets_raw))





import esm
import torch
from esm.inverse_folding.util import load_coords

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines where CUDA is not available instead of failing on the first `.to()`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# pdb_file = '/home/ubuntu/7md4.pdb'
# pdb_file = "/home/ubuntu/8t3r.pdb"
# Loaded (model, alphabet) pairs keyed by device string, so the weights are
# read and moved to the device only once per process.
_IF1_MODEL_CACHE = {}


def _load_if1_model():
    """Return the cached ESM-IF1 ``(model, alphabet)`` pair for ``device``."""
    if device not in _IF1_MODEL_CACHE:
        model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
        _IF1_MODEL_CACHE[device] = (model.eval().to(device), alphabet)
    return _IF1_MODEL_CACHE[device]


def encode(file):
    """Encode chain "A" of a PDB file with the ESM-IF1 inverse-folding model.

    The native sequence read from the file and the shape of the resulting
    encoder tensor are printed.

    Args:
        file: File name (not a full path) inside ``/home/ubuntu/test_pdb/``.

    Returns:
        The first element of ``encoder_out["encoder_out"]`` as produced by
        ``model.sample``, moved to ``device``.
    """
    pdb_file = f'/home/ubuntu/test_pdb/{file}'
    coords, native_seq = load_coords(pdb_file, "A")
    print(native_seq)

    # Reuse the already-loaded model rather than reloading it on every call.
    model, _alphabet = _load_if1_model()

    # Nothing is backpropagated here, so skip autograd bookkeeping.
    # NOTE(review): unpacking two values assumes this environment's `sample`
    # returns (sequence, encoder output) -- confirm against the installed esm.
    with torch.no_grad():
        _sampled_seq, encoder_out = model.sample(coords, temperature=1,
                                                 device=torch.device(device))
    sample_protein = encoder_out["encoder_out"][0].to(device)

    print(sample_protein.shape)
    return sample_protein

# python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-geometric

# torch.Size([1, 32, 2560])


# /home/ubuntu/test_pdb
# 1jj9.pdb  2cma.pdb  3lhj.pdb  5p11.pdb  6jzt.pdb
# Run the encoder once per test structure, in the same order as before.
for _pdb_name in ('1jj9.pdb', '2cma.pdb', '3lhj.pdb', '5p11.pdb', '6jzt.pdb'):
    encode(_pdb_name)