Spaces:
Running
Running
# import json | |
# from tqdm import tqdm | |
# import matplotlib.pyplot as plt | |
# import numpy as np | |
# f = open("/home/ubuntu/proteinedit-mm-clean/data/esm_subset/abstract.json", "r") | |
# ann = json.load(f) | |
# total = 0 | |
# l_256 = 0 | |
# l_384 = 0 | |
# x = [] | |
# for i in tqdm(range(0, len(ann))): | |
# total += len(ann[i]["caption"].split()) | |
# if (len(ann[i]["caption"].split()) <= 256): | |
# l_256 += 1 | |
# if (len(ann[i]["caption"].split()) <= 384): | |
# l_384 += 1 | |
# x.append(len(ann[i]["caption"].split())) | |
# x = np.array(x) | |
# print("avg: ", str(total / len(ann))) | |
# print("below 256: ", str(l_256 / len(ann))) | |
# print("below 384: ", str(l_384 / len(ann))) | |
# plt.hist(x) | |
# plt.savefig("test.png") | |
from minigpt4.datasets.qa_dataset import QADataset

# Build the QA dataset over the ESM subset and sanity-check its contents.
datasets_raw = QADataset(
    pdb_root="/home/ubuntu/pt/",
    seq_root="/home/ubuntu/seq/",
    ann_paths="/home/ubuntu/proteinchat/data/esm_subset/qa_all.json",
    dataset_description="/home/ubuntu/dataset.json",
    chain="A",
)
# Index the dataset once and print fields from that single item. The
# original called datasets_raw[0] twice. If __getitem__ loads files or
# samples a QA pair (assumption - QADataset is not visible here), the two
# calls could repeat work or print a question and answer from different
# calls.
first_item = datasets_raw[0]
print(first_item["q_input"])
print(first_item["a_input"])
print(len(datasets_raw))
import esm | |
import torch | |
from esm.inverse_folding.util import load_coords | |
# Use the GPU when one is visible. Otherwise fall back to CPU, so the
# script runs (slowly) instead of failing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# pdb_file = '/home/ubuntu/7md4.pdb' | |
# pdb_file = "/home/ubuntu/8t3r.pdb" | |
import functools


@functools.lru_cache(maxsize=None)
def _load_if_model(dev):
    """Return the esm_if1_gvp4_t16_142M_UR50 model in eval mode on ``dev``.

    Results are cached per device string. Repeated ``encode`` calls then
    reuse one copy of the weights instead of reloading the checkpoint each
    time.
    """
    model, _alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    return model.eval().to(dev)


def encode(file, pdb_dir='/home/ubuntu/test_pdb', chain="A"):
    """Run the ESM-IF1 model over one chain of a PDB file.

    Prints the chain's native sequence, then the shape of the encoder
    output tensor.

    Args:
        file: PDB file name, joined onto ``pdb_dir``.
        pdb_dir: Directory holding the PDB files. The default is the
            original hard-coded location.
        chain: Chain id passed to ``load_coords``. The default is ``"A"``,
            as before.

    Returns:
        ``encoder_out["encoder_out"][0]``, moved to the module-level
        ``device``.
    """
    pdb_file = f'{pdb_dir}/{file}'
    coords, native_seq = load_coords(pdb_file, chain)
    print(native_seq)
    model = _load_if_model(device)
    # Inference only: disable autograd bookkeeping to save memory.
    # NOTE(review): sample() is unpacked here as (sequence, encoder_out).
    # Upstream fair-esm's sample() returns only the sequence string, so this
    # presumably relies on a patched esm package. Confirm.
    with torch.no_grad():
        _sampled_seq, encoder_out = model.sample(
            coords, temperature=1, device=torch.device(device)
        )
    sample_protein = encoder_out["encoder_out"][0].to(device)
    print(sample_protein.shape)
    return sample_protein
# python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
# python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
# python -m pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
# python -m pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html | |
# python -m pip install torch-geometric | |
# torch.Size([1, 32, 2560]) | |
# /home/ubuntu/test_pdb | |
# 1jj9.pdb 2cma.pdb 3lhj.pdb 5p11.pdb 6jzt.pdb | |
# Encode every structure in the test set, in a fixed order.
for pdb_name in ('1jj9.pdb', '2cma.pdb', '3lhj.pdb', '5p11.pdb', '6jzt.pdb'):
    encode(pdb_name)