# ProteinGPT-Llama3 / utils.py
# Last commit by EdwardoSunny: "finished" (85ab89d)
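# Disabled: word-count statistics (average, fraction <= 256 / <= 384 words, histogram) for the captions in abstract.json.
# import json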
# from tqdm import tqdm
# import matplotlib.pyplot as plt
# import numpy as np
# f = open("/home/ubuntu/proteinedit-mm-clean/data/esm_subset/abstract.json", "r")
# ann = json.load(f)
# total = 0
# l_256 = 0
# l_384 = 0
# x = []
# for i in tqdm(range(0, len(ann))):
# total += len(ann[i]["caption"].split())
# if (len(ann[i]["caption"].split()) <= 256):
# l_256 += 1
# if (len(ann[i]["caption"].split()) <= 384):
# l_384 += 1
# x.append(len(ann[i]["caption"].split()))
# x = np.array(x)
# print("avg: ", str(total / len(ann)))
# print("below 256: ", str(l_256 / len(ann)))
# print("below 384: ", str(l_384 / len(ann)))
# plt.hist(x)
# plt.savefig("test.png")
from minigpt4.datasets.qa_dataset import QADataset

# Smoke check of the QA dataset: build it from machine-specific local paths, show the
# first example's question/answer inputs, and report how many examples it holds.
datasets_raw = QADataset(pdb_root="/home/ubuntu/pt/",
                         seq_root="/home/ubuntu/seq/",
                         ann_paths="/home/ubuntu/proteinchat/data/esm_subset/qa_all.json",
                         dataset_description="/home/ubuntu/dataset.json",
                         chain="A")
# Fetch the example once: indexing may load data from disk, and if __getitem__ samples
# randomly, two separate lookups could print a question and an answer from different draws.
first_example = datasets_raw[0]
print(first_example["q_input"])
print(first_example["a_input"])
print(len(datasets_raw))
import esm
import torch
from esm.inverse_folding.util import load_coords
# Device used for ESM-IF1 inference. Fall back to CPU when no CUDA device is visible,
# so `.to(device)` does not crash on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# pdb_file = '/home/ubuntu/7md4.pdb'
# pdb_file = "/home/ubuntu/8t3r.pdb"
# Per-device cache of (model, alphabet), so the pretrained ESM-IF1 weights are loaded once.
_IF_MODEL_CACHE = {}


def _load_if_model(device_name):
    """Return the eval-mode ESM-IF1 model and its alphabet on *device_name*, loading once."""
    if device_name not in _IF_MODEL_CACHE:
        model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
        _IF_MODEL_CACHE[device_name] = (model.eval().to(device_name), alphabet)
    return _IF_MODEL_CACHE[device_name]


def encode(file, *, pdb_root='/home/ubuntu/test_pdb', chain="A", temperature=1):
    """Run ESM-IF1 on one chain of a PDB file and return its encoder output.

    Args:
        file: PDB file name, resolved as ``f"{pdb_root}/{file}"``.
        pdb_root: Directory containing the PDB files.
        chain: Chain identifier passed to ``load_coords``.
        temperature: Sampling temperature forwarded to ``model.sample``.

    Returns:
        ``encoder_out["encoder_out"][0]`` moved to the module-level ``device``.

    Side effects:
        Prints the native sequence read from the PDB file and the shape of
        the returned tensor.
    """
    pdb_file = f'{pdb_root}/{file}'
    coords, native_seq = load_coords(pdb_file, chain)
    print(native_seq)
    model, _alphabet = _load_if_model(device)
    # NOTE(review): upstream esm's GVPTransformerModel.sample returns only the sampled
    # sequence string; unpacking (seq, encoder_out) assumes a locally modified esm — confirm.
    with torch.no_grad():  # inference only; avoid building an autograd graph
        _sampled_seq, encoder_out = model.sample(coords, temperature=temperature,
                                                 device=torch.device(device))
    sample_protein = encoder_out["encoder_out"][0].to(device)
    print(sample_protein.shape)
    return sample_protein
# Optional PyTorch Geometric extras, using wheels built for torch 2.3.0 + CUDA 12.1:
# python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-geometric
# Tensor shape noted during development: torch.Size([1, 32, 2560])
# PDB files available in /home/ubuntu/test_pdb:
# 1jj9.pdb 2cma.pdb 3lhj.pdb 5p11.pdb 6jzt.pdb
# Encode each test structure in turn (same files, same order).
for pdb_name in ('1jj9.pdb', '2cma.pdb', '3lhj.pdb', '5p11.pdb', '6jzt.pdb'):
    encode(pdb_name)