"""Scratch checks for the QA dataset and for ESM-IF1 structure encoding.

Running this file builds a ``QADataset`` from fixed local paths and prints
its first question/answer pair and its length, then runs ``encode`` on a
fixed list of PDB files found under ``/home/ubuntu/test_pdb``.
"""
import functools

import esm
import torch
from esm.inverse_folding.util import load_coords

from minigpt4.datasets.qa_dataset import QADataset

# Earlier caption-length analysis, kept for reference:
# import json
# from tqdm import tqdm
# import matplotlib.pyplot as plt
# import numpy as np
# f = open("/home/ubuntu/proteinedit-mm-clean/data/esm_subset/abstract.json", "r")
# ann = json.load(f)
# total = 0
# l_256 = 0
# l_384 = 0
# x = []
# for i in tqdm(range(0, len(ann))):
#     total += len(ann[i]["caption"].split())
#     if (len(ann[i]["caption"].split()) <= 256):
#         l_256 += 1
#     if (len(ann[i]["caption"].split()) <= 384):
#         l_384 += 1
#     x.append(len(ann[i]["caption"].split()))
# x = np.array(x)
# print("avg: ", str(total / len(ann)))
# print("below 256: ", str(l_256 / len(ann)))
# print("below 384: ", str(l_384 / len(ann)))
# plt.hist(x)
# plt.savefig("test.png")

# Sanity check of the QA dataset: show the first sample and the dataset size.
datasets_raw = QADataset(
    pdb_root="/home/ubuntu/pt/",
    seq_root="/home/ubuntu/seq/",
    ann_paths="/home/ubuntu/proteinchat/data/esm_subset/qa_all.json",
    dataset_description="/home/ubuntu/dataset.json",
    chain="A",
)
print(datasets_raw[0]["q_input"])
print(datasets_raw[0]["a_input"])
print(len(datasets_raw))

device = 'cuda'
# pdb_file = '/home/ubuntu/7md4.pdb'
# pdb_file = "/home/ubuntu/8t3r.pdb"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the ESM-IF1 model once, in eval mode on ``device``, and cache it."""
    model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
    return model.eval().to(device), alphabet


def encode(file):
    """Run the ESM-IF1 model on chain A of one PDB file and print diagnostics.

    Prints the native sequence read from the structure and the shape of the
    first element of the encoder output.

    Args:
        file: File name of a PDB file located in ``/home/ubuntu/test_pdb/``.

    Returns:
        ``encoder_out["encoder_out"][0]`` moved to ``device``.
    """
    pdb_file = f'/home/ubuntu/test_pdb/{file}'
    coords, native_seq = load_coords(pdb_file, "A")
    print(native_seq)

    # The model is cached so repeated calls do not reload the weights.
    model, _alphabet = _load_model()

    # NOTE(review): the unpacking below requires sample() to return a
    # (sequence, encoder_out) pair -- verify against the installed esm version.
    _sampled_seq, encoder_out = model.sample(
        coords, temperature=1, device=torch.device(device)
    )
    sample_protein = encoder_out["encoder_out"][0].to(device)
    print(sample_protein.shape)
    return sample_protein


# python -m pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-sparse -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
# python -m pip install torch-geometric

# torch.Size([1, 32, 2560])

# /home/ubuntu/test_pdb
# 1jj9.pdb 2cma.pdb 3lhj.pdb 5p11.pdb 6jzt.pdb
for _pdb_name in ('1jj9.pdb', '2cma.pdb', '3lhj.pdb', '5p11.pdb', '6jzt.pdb'):
    encode(_pdb_name)