import gzip
import json
import os
import pickle
import random
import string
from os.path import exists
from types import SimpleNamespace
from typing import List

import biotite.structure
import numpy as np
import pandas as pd
import torch
from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP
from biopandas.pdb import PandasPdb
from biotite.sequence import ProteinSequence
from biotite.structure import get_chains
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from torch_cluster import radius_graph, knn_graph


# Data locations; "NA" marks resources that are unavailable in this setup.
AF2_DATA_PATH = './data.files/af2.files/'
AF2_REP_DATA_PATH = "NA"

ESM_MODEL_SIZE = '650M'
ESM_DATA_PATH = './data.files/esm.files/'

MSA_DATA_PATH_ARCHIVE = './data.files/gMVP.MSA/'
MSA_DATA_PATH = './data.files/MSA/'

PAE_DATA_PATH = 'NA'

MSA_ATTN_DATA_PATH = './data.files/esm.MSA/'
NUM_THREADS = 42

# Pre-pickled ESM language model, alphabet converter and amino-acid embedding tables.
with open(f'./utils/LANGUAGE_MODEL.{ESM_MODEL_SIZE}.pkl', 'rb') as f:
    LANGUAGE_MODEL = pickle.load(f)
with open(f'./utils/ALPHABET_CONVERTER.{ESM_MODEL_SIZE}.pkl', 'rb') as f:
    ALPHABET_CONVERTER = pickle.load(f)
with open(f'./utils/ESM_AA_EMBEDDING_DICT.{ESM_MODEL_SIZE}.pkl', 'rb') as f:
    ESM_AA_EMBEDDING_DICT = pickle.load(f)
with open('./utils/ESM_AA_EMBEDDING_DICT.esm1b.pkl', 'rb') as f:
    ESM1b_AA_EMBEDDING_DICT = pickle.load(f)

with open('./utils/AA_5_DIM_EMBED.pkl', 'rb') as f:
    AA_5DIM_EMBED = pickle.load(f)

# Token vocabulary of the ESM alphabet, in tokenizer index order.
ESM_TOKENS = ['<cls>', '<pad>', '<eos>', '<unk>',
              'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
              'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C',
              'X', 'B', 'U', 'Z', 'O', '.', '-',
              '<null_1>', '<mask>']

AA_DICT = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
           'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C',
           'X', 'B', 'U', 'Z', 'O', '<mask>']
AA_DICT_HUMAN = ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D',
                 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C']
# DSSP secondary-structure codes ('P' is the PPII helix emitted by newer mkdssp).
DSSP_DICT = ['H', 'B', 'E', 'G', 'I', 'T', 'S', '-', 'P']
# Post-translational modification type -> column index.
PTM_DICT = {'ac': 0, 'ga': 1, 'gl': 2, 'm1': 3, 'm2': 4, 'm3': 5, 'me': 6, 'p': 7, 'sm': 8, 'ub': 9}


class Mutation:
    """
    A mutation object that stores the information of a mutation.
    Can specify max_len of sequence to crop the sequence.
    Can specify af2_file to ignore the input sequence and use the AF2 sequence instead.
    """
    def __init__(self, uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig,
                 ref_aa, alt_aa, max_len=2251, af2_file=None):
        self.seq = None
        self.seq_start = None
        self.seq_end = None
        self.seq_start_orig = None
        self.seq_end_orig = None
        self.pos = None
        self.uniprot_id = None
        self.af2_file = None
        self.af2_rep_file_prefix = None
        self.af2_seq_index = None
        self.msa_seq_index = None
        self.esm_seq_index = None
        self.af2_rep_index = None
        self.ref_aa = None
        self.alt_aa = None
        self.ESM_prefix = None
        self.crop = False
        self.seq_len = None
        self.seq_len_orig = None
        self.max_len = max_len
        self.half_max_len = max_len // 2
        self.set_af2_fragment_idx(seq_orig, seq_orig_len, uniprot_id, pos_orig, af2_file)
        self.transcript_id = transcript_id
        self.set_ref_alt_aa(ref_aa, alt_aa)
        self.init_af2_file_idx()
        self.crop_fn()

    def set_af2_fragment_idx(self, seq_orig, seq_orig_len, uniprot_id, pos_orig, af2_file):
        self.seq_len_orig = seq_orig_len
        # Multiple variants at one site may be encoded as 'pos1;pos2;...'.
        if isinstance(pos_orig, str):
            pos_orig = np.array([int(i) for i in pos_orig.split(';')])
        else:
            pos_orig = np.array([int(pos_orig)])
        if af2_file is None or pd.isna(af2_file):
            if uniprot_id.find('-F') != -1:
                # uniprot_id already names an AlphaFold DB fragment, e.g. 'Q8WZ42-F3'.
                idx = int(uniprot_id.split('-F')[-1])
                uniprot_id = uniprot_id.split('-F')[0]
                seq_start = 1
                seq_end = seq_orig_len
                self.seq_start_orig = seq_start
                self.seq_end_orig = seq_end
                seq = seq_orig
                pos = pos_orig
                self.ESM_prefix = f'{uniprot_id}-F{idx}'
                # AlphaFold DB fragment files cover fixed 1400-residue windows.
                seq_len = 1400
                self.af2_rep_file_prefix = f'{AF2_REP_DATA_PATH}/{uniprot_id}-F{idx}/{uniprot_id}-F{idx}'
            else:
                self.ESM_prefix = f'{uniprot_id}'
                if seq_orig_len > 2700:
                    # AlphaFold DB splits proteins longer than 2700 residues into
                    # 1400-residue fragments with a 200-residue step; pick the
                    # fragment whose window contains the first variant position.
                    idx = min(max(1, pos_orig[0] // 200 - 2), seq_orig_len // 200 - 5)
                    seq_start = (idx - 1) * 200 + 1
                    seq_end = min((idx + 6) * 200, seq_orig_len)
                    self.seq_start_orig = seq_start
                    self.seq_end_orig = seq_end
                    seq = seq_orig[seq_start - 1:seq_end]
                    # Re-express variant positions relative to the fragment.
                    pos = pos_orig - seq_start + 1
                    seq_len = seq_end - seq_start + 1
                    seq_start = 1
                    seq_end = seq_len
                else:
                    idx = 1
                    seq_start = 1
                    seq_end = seq_orig_len
                    self.seq_start_orig = seq_start
                    self.seq_end_orig = seq_end
                    seq_len = seq_orig_len
                    seq = seq_orig
                    pos = pos_orig
                if uniprot_id == "Q8WZ42":
                    # Titin is special-cased: its representations are stored per fragment.
                    self.ESM_prefix = f'{uniprot_id}-F{idx}'
                if seq_orig_len >= 7000:
                    self.af2_rep_file_prefix = f'{AF2_REP_DATA_PATH}/{uniprot_id}-F{idx}/{uniprot_id}-F{idx}'
                else:
                    self.af2_rep_file_prefix = f'{AF2_REP_DATA_PATH}/{uniprot_id}/{uniprot_id}'
            self.seq = seq
            self.seq_start = seq_start
            self.seq_end = seq_end
            self.seq_len = seq_len
            self.pos = pos
            self.uniprot_id = uniprot_id
            self.af2_file = f'{AF2_DATA_PATH}/AF-{uniprot_id}-F{idx}-model_v4.pdb.gz'
        else:
            # An explicit AF2 file was given: trust the input sequence as-is.
            self.af2_file = af2_file
            self.ESM_prefix = uniprot_id
            self.seq = seq_orig
            self.seq_start = 1
            self.seq_end = seq_orig_len
            self.seq_start_orig = self.seq_start
            self.seq_end_orig = self.seq_end
            self.seq_len = seq_orig_len
            self.pos = pos_orig
            self.uniprot_id = uniprot_id

    def set_ref_alt_aa(self, ref_aa, alt_aa):
        # Multi-site records use ';'-separated ref/alt amino acids.
        if ";" in ref_aa or ";" in alt_aa:
            self.ref_aa = np.array(ref_aa.split(';'))
            self.alt_aa = np.array(alt_aa.split(';'))
        else:
            self.ref_aa = np.array([ref_aa])
            self.alt_aa = np.array([alt_aa])

    def init_af2_file_idx(self):
        if not exists(self.af2_file):
            print(f'Warning: {self.uniprot_id} AF2 file not found: {self.af2_file}')
            self.af2_file = None
        self.af2_seq_index = None

    def crop_fn(self):
        # Crop the working sequence to at most max_len residues around the
        # first variant position.
        seq_len = self.seq_len
        pos = self.pos
        seq_start = self.seq_start
        seq_end = self.seq_end
        seq = self.seq
        if seq_len >= self.max_len:
            if pos[0] <= self.half_max_len:
                # Variant near the N-terminus: keep the first max_len residues.
                seq_start = 1
                seq_end = self.max_len
                seq = seq[:self.max_len]
                seq_len = self.max_len
            elif seq_len - pos[0] <= self.max_len - self.half_max_len:
                # Variant near the C-terminus: keep the last max_len residues.
                seq_start = seq_len - self.max_len + 1
                seq_end = seq_len
                seq = seq[seq_start - 1:]
                pos = pos - seq_start + 1
                seq_len = self.max_len
            else:
                # Otherwise center the window on the variant.
                seq_start = pos[0] - self.half_max_len
                seq_end = pos[0] + self.max_len - self.half_max_len - 1
                seq = seq[seq_start - 1:seq_end]
                pos = pos - seq_start + 1
                seq_len = self.max_len
            self.crop = True
        self.seq = seq
        self.seq_start = seq_start
        self.seq_end = seq_end
        self.seq_len = seq_len
        self.pos = pos

    def set_af2_seq_index(self, idx):
        self.af2_seq_index = idx

    def set_msa_seq_index(self, idx):
        self.msa_seq_index = idx

    def set_esm_seq_index(self, idx):
        self.esm_seq_index = idx

    def set_af2_rep_index(self, idx):
        self.af2_rep_index = idx
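

# Illustrative sketch (not part of the original pipeline): how Mutation windows
# a long protein. The uniprot_id, transcript_id and sequence below are made up;
# with seq_orig_len > 2700 and no AF2 file on disk, the constructor picks a
# fragment window, prints a missing-file warning, and crop_fn then trims the
# fragment to max_len residues centered on the variant, re-indexing `pos`.
def _demo_mutation_window():
    seq = 'A' * 3000  # hypothetical 3000-residue protein
    mut = Mutation('P00000', 'ENST00000000000', seq, 3000, 1500, 'A', 'V', max_len=251)
    # Fragment 5 covers original residues 801..2200; after cropping, the
    # working sequence is 251 residues with the variant at position 126.
    print(mut.seq_start_orig, mut.seq_end_orig)  # 801 2200
    print(mut.seq_len, mut.pos)                  # 251 [126]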


class RandomPointMutation(Mutation):
    """A Mutation at a uniformly random site with a random alternative amino acid."""
    def __init__(self, uniprot_id, transcript_id, seq_orig, seq_orig_len, max_len=2251):
        pos_orig = np.random.randint(1, seq_orig_len + 1)
        ref_aa = seq_orig[pos_orig - 1]
        alt_aa = np.random.choice(list("ACDEFGHIKLMNPQRSTVWY"))
        super().__init__(uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig, ref_aa, alt_aa, max_len)


class MaskPredictPointMutation(Mutation):
    """Mutation variant for masked-residue prediction; a pos_orig of None or 0
    means 'pick a random site'."""
    def __init__(self, uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig,
                 ref_aa, alt_aa, max_len=2251, af2_file=None):
        if pos_orig is None or pos_orig == 0:
            pos_orig = np.random.randint(1, seq_orig_len + 1)
        self.ESM_prefix = None
        self.max_len = max_len
        self.half_max_len = max_len // 2
        super().__init__(uniprot_id, transcript_id, seq_orig, seq_orig_len, pos_orig,
                         ref_aa, alt_aa, max_len=max_len, af2_file=af2_file)

    def init_af2_file_idx(self):
        if not exists(self.af2_file):
            print(f'Warning: {self.uniprot_id} AF2 file not found: {self.af2_file}')
            self.af2_file = None
            self.af2_seq_index = None


def convert_to_onesite(dataset: pd.DataFrame):
    """Collapse per-variant rows into one row per (uniprotID, position) site,
    joining ref/alt amino acids and any score columns with ';'."""
    if 'ref_aa' not in dataset.columns:
        dataset['ref_aa'] = dataset['ref']
    if 'alt_aa' not in dataset.columns:
        dataset['alt_aa'] = dataset['alt']
    dataset_onesite = dataset.copy(deep=True)
    dataset_onesite = dataset_onesite.drop_duplicates(subset=['uniprotID', 'pos.orig'])
    score_cols = [col for col in dataset.columns if col.startswith('score')]
    confidence_cols = [col for col in dataset.columns if col.startswith('confidence.score')]
    if len(confidence_cols) == 0:
        # No per-variant confidences: default every confidence to 1.
        confidence_cols = [f'confidence.score.{i}' for i in range(len(score_cols))]
        for col in confidence_cols:
            dataset[col] = 1
            dataset_onesite[col] = 1
    for i in dataset_onesite.index:
        subdataset = dataset[(dataset['uniprotID'] == dataset_onesite.loc[i, 'uniprotID'])
                             & (dataset['pos.orig'] == dataset_onesite.loc[i, 'pos.orig'])]
        dataset_onesite.loc[i, 'ref_aa'] = ';'.join(subdataset['ref_aa'].values)
        dataset_onesite.loc[i, 'alt_aa'] = ';'.join(subdataset['alt_aa'].values)
        if len(score_cols) > 0:
            for col in score_cols:
                dataset_onesite.loc[i, col] = ';'.join(subdataset[col].values.astype('str'))
        if len(confidence_cols) > 0:
            for col in confidence_cols:
                dataset_onesite.loc[i, col] = ';'.join(subdataset[col].values.astype('str'))
    return dataset_onesite
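

# A minimal sketch of what convert_to_onesite does, on a toy DataFrame (every
# value below is made up). Two variants at the same site collapse into one row
# with ';'-joined ref/alt amino acids and scores.
def _demo_convert_to_onesite():
    toy = pd.DataFrame({
        'uniprotID': ['P00000', 'P00000', 'P00001'],
        'pos.orig': [42, 42, 7],
        'ref': ['A', 'A', 'G'],
        'alt': ['V', 'T', 'R'],
        'score': [0.1, 0.9, 0.5],
    })
    onesite = convert_to_onesite(toy)
    # -> the (P00000, 42) row has ref_aa 'A;A', alt_aa 'V;T', score '0.1;0.9'
    print(onesite[['uniprotID', 'pos.orig', 'ref_aa', 'alt_aa', 'score']])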


def load_structure(fpath, chain=None):
    """
    Args:
        fpath: filepath to either pdb or cif file (optionally gzipped)
        chain: the chain id or list of chain ids to load
    Returns:
        biotite.structure.AtomArray
    """
    if fpath.endswith('cif'):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('cif.gz'):
        with gzip.open(fpath, 'rt') as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
        structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith('pdb'):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    elif fpath.endswith('pdb.gz'):
        with gzip.open(fpath, 'rt') as fin:
            pdbf = pdb.PDBFile.read(fin)
        structure = pdb.get_structure(pdbf, model=1)
    else:
        raise ValueError("Invalid file extension")
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError('No chains found in the input file.')
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]
    for chain in chain_ids:
        if chain not in all_chains:
            raise ValueError(f'Chain {chain} not found in input file')
    chain_filter = [a.chain_id in chain_ids for a in structure]
    structure = structure[chain_filter]
    return structure


def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        coords: an L x 5 x 3 array of N, C, O, CA, CB coordinates
    """
    coords = get_atom_coords_residue_wise(["N", "C", "O", "CA", "CB"], structure)
    return coords
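

# A minimal usage sketch, assuming a local AlphaFold model exists at the
# hypothetical path below: load the structure and pull residue-wise backbone
# coordinates. Atoms a residue lacks (e.g. CB of glycine) come back as NaN.
def _demo_load_coords(pdb_gz='./data.files/af2.files/AF-P00000-F1-model_v4.pdb.gz'):
    structure = load_structure(pdb_gz)
    coords = extract_coords_from_structure(structure)  # (L, 5, 3) for N, C, O, CA, CB
    print(coords.shape, np.isnan(coords[:, 4]).any())  # glycines have no CB
    return coords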


def extract_sidechain_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        coords: an L x 31 x 3 array of side-chain atom coordinates
    """
    coords = get_atom_coords_residue_wise(['CD', 'CD1', 'CD2', 'CE', 'CE1',
                                           'CE2', 'CE3', 'CG', 'CG1', 'CG2',
                                           'CH2', 'CZ', 'CZ2', 'CZ3', 'ND1',
                                           'ND2', 'NE', 'NE1', 'NE2', 'NH1',
                                           'NH2', 'NZ', 'OD1', 'OD2', 'OE1',
                                           'OE2', 'OG', 'OG1', 'OH', 'SD',
                                           'SG'],
                                          structure)
    return coords


def extract_residues_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        seq: the extracted one-letter amino-acid sequence
    """
    residue_identities = get_residues(structure)[1]
    seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return seq


def get_atom_coords_residue_wise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Example for atoms argument: ["N", "O", "CA", "C", "CB"]
    """
    def filterfn(s, axis=None):
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        filter_sum = filters.sum(0)
        if not np.all(filter_sum <= np.ones(filters.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        # Atoms absent from a residue are marked with NaN coordinates.
        coords[filter_sum == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)


def get_mutations(uniprot_id, transcript_id, seq, seq_orig_len,
                  pos_orig, ref_aa, alt_aa, max_len=1400, af2_file=None):
    mutation = Mutation(uniprot_id, transcript_id, seq, seq_orig_len, pos_orig, ref_aa, alt_aa, max_len, af2_file)
    if mutation.af2_file is None:
        print(
            f"No AF2 file found for this mutation "
            f"{mutation.uniprot_id}:{mutation.ref_aa}:{mutation.pos}:{mutation.alt_aa}. Skipping..."
        )
        return False
    else:
        return mutation


def get_random_point_mutations(uniprot_id, transcript_id, seq, seq_orig_len,
                               pos_orig, ref_aa, alt_aa, score):
    # A score of -1 marks an unlabeled record: draw a random point mutation instead.
    if score == -1:
        point_mutation = RandomPointMutation(uniprot_id, transcript_id, seq, seq_orig_len)
    else:
        point_mutation = Mutation(uniprot_id, transcript_id, seq, seq_orig_len, pos_orig, ref_aa, alt_aa)
    if point_mutation.af2_file is None:
        return False
    else:
        return point_mutation


def get_mask_predict_point_mutations(uniprot_id, transcript_id, seq, seq_orig_len,
                                     pos_orig, ref_aa, alt_aa, max_len=2251, af2_file=None):
    point_mutation = MaskPredictPointMutation(uniprot_id, transcript_id, seq, seq_orig_len,
                                              pos_orig, ref_aa, alt_aa, max_len, af2_file)
    if point_mutation.af2_file is None:
        print(
            f"No AF2 file found for this mutation "
            f"{point_mutation.uniprot_id}:{point_mutation.ref_aa}:{point_mutation.pos}:{point_mutation.alt_aa}. Skipping..."
        )
        return False
    else:
        return point_mutation


def get_coords_from_af2(af2_file, add_sidechain=False):
    pdb_path = af2_file
    structure = load_structure(pdb_path)
    af2_coords = extract_coords_from_structure(structure)
    if add_sidechain:
        af2_coords_sidechain = extract_sidechain_from_structure(structure)
        af2_coords = np.concatenate([af2_coords, af2_coords_sidechain], axis=1)
    return af2_coords


def get_plddt_from_af2(af2_file):
    # AF2 stores per-residue pLDDT in the B-factor column.
    pdb_file = PandasPdb().read_pdb(af2_file)
    pdb_file = pdb_file.df['ATOM'].drop_duplicates(subset=['residue_number'])
    plddt = pdb_file['b_factor'].values
    return plddt


def get_dssp_from_af2(af2_file):
    p = PDBParser()
    with gzip.open(af2_file, 'rt') as f:
        structure = p.get_structure('', f)
    model = structure[0]
    # DSSP needs a plain-text PDB file on disk, so decompress to a temp file.
    # Seeding with the file hash makes the temp name deterministic per input.
    random.seed(hash(af2_file))
    tmpfile = '/share/descartes/Users/gz2294/tmp/' + ''.join(random.choices(string.ascii_letters, k=5)) + '.pdb'
    with open(tmpfile, 'w') as f:
        f.write(gzip.open(af2_file, 'rt').read())
    dssp = DSSP(model, tmpfile, file_type="PDB", dssp="/share/descartes/Users/gz2294/miniconda3/bin/mkdssp")
    os.remove(tmpfile)
    # Features: one-hot secondary structure, relative ASA, and phi/psi in radians.
    dssp = pd.DataFrame(dssp)
    sec_struc = np.eye(len(DSSP_DICT), dtype=np.float32)[[DSSP_DICT.index(i) for i in dssp.iloc[:, 2].values]]
    return np.concatenate([sec_struc,
                           dssp.iloc[:, 3].values[:, None],
                           dssp.iloc[:, 4].values[:, None] / 180 * np.pi,
                           dssp.iloc[:, 5].values[:, None] / 180 * np.pi], axis=1)


def get_ptm_from_mutation(mutation: Mutation, ptm_ref):
    # Build an L x |PTM_DICT| indicator matrix of known PTM sites on the
    # (possibly fragmented and cropped) working sequence.
    uniprotID = mutation.uniprot_id
    ptm_ref = ptm_ref[ptm_ref['uniprotID'] == uniprotID].copy()
    seq = mutation.seq
    # Convert 1-based original positions to 0-based indices in the working sequence.
    ptm_ref['pos'] = ptm_ref['pos'] - mutation.seq_start_orig - mutation.seq_start + 1
    ptm_ref = ptm_ref[ptm_ref['pos'] >= 0]
    ptm_ref = ptm_ref[ptm_ref['pos'] < mutation.seq_len]
    ptm_mat = np.zeros([mutation.seq_len, len(PTM_DICT)])
    for i in ptm_ref.index:
        # Only keep sites whose reference residue matches the sequence.
        if ptm_ref['ref'].loc[i] == seq[ptm_ref['pos'].loc[i]]:
            ptm_mat[ptm_ref['pos'].loc[i], PTM_DICT[ptm_ref['type'].loc[i]]] = 1
    return ptm_mat
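

# A minimal sketch of the PTM matrix construction. The Mutation interface is
# stubbed with SimpleNamespace (only the attributes get_ptm_from_mutation
# reads); the PTM reference rows are made up. The phospho-serine at original
# position 6 maps to 0-based row 5 and matches the sequence, so that cell is 1;
# the site at position 200 falls outside the window and is dropped.
def _demo_ptm_matrix():
    stub = SimpleNamespace(uniprot_id='P00000', seq='MKTAYSAKQR',
                           seq_start_orig=1, seq_start=1, seq_len=10)
    ptm_ref = pd.DataFrame({'uniprotID': ['P00000', 'P00000'],
                            'pos': [6, 200],  # 1-based original positions
                            'ref': ['S', 'K'],
                            'type': ['p', 'ub']})
    mat = get_ptm_from_mutation(stub, ptm_ref)
    print(mat.shape, mat[5, PTM_DICT['p']])  # (10, 10) 1.0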


def get_knn_graphs_from_af2(af2_coords, radius=None, max_neighbors=None, loop=False, gpu_id=None):
    # Column 3 of af2_coords is the CA atom. Note that `radius` is only used as
    # a switch here: None builds a fully connected graph, anything else a kNN graph.
    CA_coord = af2_coords[:, 3]
    if radius is None:
        edge_index = np.indices((af2_coords.shape[0], af2_coords.shape[0])).reshape(2, -1)
        if not loop:
            edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    else:
        if max_neighbors is None:
            max_neighbors = af2_coords.shape[0] + 1
        with torch.no_grad():
            CA_coord = torch.from_numpy(CA_coord)
            edge_index = knn_graph(
                x=CA_coord.to(f'cuda:{gpu_id}') if gpu_id is not None and torch.cuda.is_available() else CA_coord,
                loop=loop,
                k=max_neighbors,
                num_workers=NUM_THREADS,
            ).detach().cpu().numpy()
        del CA_coord
    return edge_index


def get_radius_graphs_from_af2(af2_coords, radius, loop=False, gpu_id=None):
    CA_coord = af2_coords[:, 3]
    # Effectively unlimited neighbors: the radius alone decides connectivity.
    max_neighbors = af2_coords.shape[0] + 1
    with torch.no_grad():
        CA_coord = torch.from_numpy(CA_coord)
        edge_index = radius_graph(
            x=CA_coord.to(f'cuda:{gpu_id}') if gpu_id is not None and torch.cuda.is_available() else CA_coord,
            r=radius,
            loop=loop,
            max_num_neighbors=max_neighbors,
            num_workers=NUM_THREADS,
        ).detach().cpu().numpy()
    del CA_coord
    return edge_index


def get_radius_knn_graphs_from_af2(af2_coords, center_nodes, radius, max_neighbors, loop=False, gpu_id=None):
    # Restrict a kNN graph to the radius-neighborhood of the given center nodes.
    CA_coord = af2_coords[:, 3]
    with torch.no_grad():
        CA_coord = torch.from_numpy(CA_coord)
        edge_index = radius_graph(
            x=CA_coord.to(f'cuda:{gpu_id}') if gpu_id is not None and torch.cuda.is_available() else CA_coord,
            r=radius,
            loop=loop,
            max_num_neighbors=af2_coords.shape[0] + 1,
            num_workers=NUM_THREADS,
        ).detach().cpu().numpy()
        # Keep only radius edges that start at a center node.
        edge_index_radius = edge_index[:, np.isin(edge_index[0], center_nodes)]
        edge_index = knn_graph(
            x=CA_coord.to(f'cuda:{gpu_id}') if gpu_id is not None and torch.cuda.is_available() else CA_coord,
            loop=loop,
            k=max_neighbors,
            num_workers=NUM_THREADS,
        ).detach().cpu().numpy()
    del CA_coord
    # Keep kNN edges whose endpoints both touch the radius neighborhood.
    edge_index = edge_index[:, np.isin(edge_index[0], edge_index_radius.flatten())
                            & np.isin(edge_index[1], edge_index_radius.flatten())]
    return edge_index


def get_graphs_from_neighbor(af2_coords, max_neighbors=None, loop=False):
    # Sequence-adjacency graph: connect residues within max_neighbors / 2
    # positions of each other along the chain.
    nodes = af2_coords.shape[0]
    if max_neighbors is None:
        max_neighbors = nodes + 1
    edge_graph = np.ones((nodes, nodes))
    edge_graph *= np.tri(nodes, k=int(np.floor(max_neighbors / 2))) \
        * np.tri(nodes, k=int(np.floor(max_neighbors / 2))).T
    edge_index = np.array(np.where(edge_graph == 1))
    if not loop:
        edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    return edge_index
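

# A small sketch contrasting the graph builders on random CA coordinates
# (the coordinates and sizes are arbitrary). Column 3 of af2_coords is CA,
# matching extract_coords_from_structure; recall that passing any non-None
# `radius` to get_knn_graphs_from_af2 merely selects the kNN branch.
def _demo_graphs():
    rng = np.random.default_rng(0)
    fake_coords = rng.normal(size=(20, 5, 3)).astype(np.float32)
    full = get_knn_graphs_from_af2(fake_coords)                        # fully connected
    knn = get_knn_graphs_from_af2(fake_coords, radius=8, max_neighbors=5)
    seq_band = get_graphs_from_neighbor(fake_coords, max_neighbors=4)  # |i - j| <= 2 band
    print(full.shape, knn.shape, seq_band.shape)  # (2, E) edge_index arrays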


def get_embedding_from_esm2(protein, check_mode=True, seq_start=None, seq_end=None):
    if isinstance(protein, str):
        file_path = f"{ESM_DATA_PATH}/{protein}.representations.layer.48.npy"
        if os.path.exists(file_path):
            if check_mode:
                return True
            wt_orig = np.load(file_path)
            # Row 0 is the BOS token, so 1-based residue i sits at row i.
            batch_tokens = wt_orig[max(0, seq_start):
                                   min(wt_orig.shape[0] - 1, seq_end + 1)]
        else:
            if check_mode:
                return False
            batch_tokens = np.zeros([seq_end - seq_start + 1, 5120 if ESM_MODEL_SIZE == "15B" else 1280])
    elif isinstance(protein, np.ndarray):
        batch_tokens = protein[max(0, seq_start):
                               min(protein.shape[0] - 1, seq_end + 1)]
    else:
        raise ValueError("protein must be either a string of uniprotID or a numpy array")
    return batch_tokens
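

# A sketch of the slicing convention, using a dummy array instead of a real
# saved ESM file: row 0 of a stored representation corresponds to the BOS
# token, so 1-based residue i sits at row i and [seq_start : seq_end + 1]
# selects residues seq_start..seq_end.
def _demo_esm_slice():
    L, dim = 10, 4
    dummy = np.arange((L + 2) * dim, dtype=np.float32).reshape(L + 2, dim)  # BOS + L residues + EOS
    window = get_embedding_from_esm2(dummy, check_mode=False, seq_start=3, seq_end=5)
    print(window.shape)  # (3, 4): rows 3..5 = residues 3..5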


def get_esm_dict_from_uniprot(uniprotID):
    file_path = f"{ESM_DATA_PATH}/{uniprotID}.representations.layer.48.npy"
    wt_orig = np.load(file_path)
    return wt_orig


def get_af2_single_rep_dict_from_prefix(uniprotID_prefix, filter=False):
    # AF2 single (per-residue) representation, saved under a ColabFold-style name.
    file_path = f"{uniprotID_prefix}_single_repr_rank_001_alphafold2_ptm_model_1_seed_000.npy"
    wt_orig = np.load(file_path)
    return wt_orig


def get_af2_pairwise_rep_dict_from_prefix(uniprotID_prefix):
    # AF2 pair representation, saved under a ColabFold-style name.
    file_path = f"{uniprotID_prefix}_pair_repr_rank_001_alphafold2_ptm_model_1_seed_000.npy"
    wt_orig = np.load(file_path)
    return wt_orig


def get_embedding_from_esm1b(protein, check_mode=True, seq_start=None, seq_end=None):
    if isinstance(protein, str):
        file_path = f"/share/vault/Users/gz2294/Data/DMS/ClinVar.HGMD.PrimateAI.syn/esm1b.embedding.uniprotIDs/{protein}.representations.layer.48.npy"
        if os.path.exists(file_path):
            if check_mode:
                return True
            wt_orig = np.load(file_path)
            batch_tokens = wt_orig[max(0, seq_start):
                                   min(wt_orig.shape[0] - 1, seq_end + 1)]
        else:
            if check_mode:
                return False
            # ESM-1b representations are 1280-dimensional.
            batch_tokens = np.zeros([seq_end - seq_start + 1, 1280])
    elif isinstance(protein, np.ndarray):
        batch_tokens = protein[max(0, seq_start):
                               min(protein.shape[0] - 1, seq_end + 1)]
    else:
        raise ValueError("protein must be either a string of uniprotID or a numpy array")
    return batch_tokens


def get_embedding_from_onehot(seq, seq_start=None, seq_end=None, return_idx=False, aa_dict=None, return_onehot_mat=False):
    if aa_dict is None:
        idx = [AA_DICT.index(aa) for aa in seq]
        protein = np.eye(len(AA_DICT))[idx]
        one_hot_mat = np.eye(len(AA_DICT))
    else:
        idx = [aa_dict.index(aa) for aa in seq]
        protein = np.eye(len(aa_dict))[idx]
        one_hot_mat = np.eye(len(aa_dict))
    if seq_start is not None and seq_end is not None:
        batch_tokens = protein[max(0, seq_start - 1): min(protein.shape[0], seq_end)]
    else:
        batch_tokens = protein
    if return_idx:
        if return_onehot_mat:
            return batch_tokens, np.array(idx), one_hot_mat
        else:
            return batch_tokens, np.array(idx)
    else:
        if return_onehot_mat:
            return batch_tokens, one_hot_mat
        else:
            return batch_tokens


def get_embedding_from_esm_onehot(seq, seq_start=None, seq_end=None, return_idx=False, aa_dict=None, return_onehot_mat=False):
    # Returns ESM token indices (not one-hot vectors); <cls>/<eos> are added
    # around the sequence when the default alphabet is used.
    if aa_dict is None:
        idx = [ESM_TOKENS.index('<cls>')] + [ESM_TOKENS.index(aa) for aa in seq] + [ESM_TOKENS.index('<eos>')]
        protein = np.array(idx)
    else:
        idx = [aa_dict.index(aa) for aa in seq]
        protein = np.array(idx)
    if seq_start is not None and seq_end is not None:
        batch_tokens = protein[max(0, seq_start - 1): min(protein.shape[0], seq_end)]
    else:
        batch_tokens = protein
    if return_idx:
        if return_onehot_mat:
            return batch_tokens, np.array(idx), None
        else:
            return batch_tokens, np.array(idx)
    else:
        if return_onehot_mat:
            return batch_tokens, None
        else:
            return batch_tokens


def get_embedding_from_5dim(seq, seq_start=None, seq_end=None):
    protein = np.array([AA_5DIM_EMBED[aa] for aa in seq])
    if seq_start is not None and seq_end is not None:
        batch_tokens = protein[max(0, seq_start - 1): min(protein.shape[0], seq_end)]
    else:
        batch_tokens = protein
    return batch_tokens


def get_embedding_from_onehot_nonzero(seq, seq_start=None, seq_end=None, return_idx=False,
                                      aa_dict=None, min_prob=0.001, return_onehot_mat=False):
    # Soft one-hot: every non-special letter keeps a probability floor of
    # min_prob, while '<mask>'/'<pad>' rows stay hard one-hot.
    if aa_dict is None:
        aa_dict = AA_DICT
    one_hot_mat = np.eye(len(aa_dict))
    n_special_tok = 0
    for special_tok in ['<mask>', '<pad>']:
        if special_tok in aa_dict:
            # Mark special-token rows/columns with sentinels (-1, 2) before rescaling.
            one_hot_mat[aa_dict.index(special_tok), :] = -1
            one_hot_mat[:, aa_dict.index(special_tok)] = -1
            one_hot_mat[aa_dict.index(special_tok), aa_dict.index(special_tok)] = 2
            n_special_tok += 1
    one_hot_mat[one_hot_mat == 0] = min_prob
    one_hot_mat[one_hot_mat == 1] = 1 - min_prob * (len(aa_dict) - n_special_tok)
    one_hot_mat[one_hot_mat == -1] = 0
    one_hot_mat[one_hot_mat == 2] = 1
    idx = [aa_dict.index(aa) for aa in seq]
    protein = one_hot_mat[idx]
    if seq_start is not None and seq_end is not None:
        batch_tokens = protein[max(0, seq_start - 1): min(protein.shape[0], seq_end)]
    else:
        batch_tokens = protein
    if return_idx:
        if return_onehot_mat:
            return batch_tokens, np.array(idx), one_hot_mat
        else:
            return batch_tokens, np.array(idx)
    else:
        if return_onehot_mat:
            return batch_tokens, one_hot_mat
        else:
            return batch_tokens
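

# A sketch of the soft one-hot on the default alphabet: each standard letter
# keeps most of its mass on its own column, spreads min_prob to the other
# non-special letters, and the '<mask>' row stays an exact one-hot.
def _demo_soft_onehot():
    tokens, mat = get_embedding_from_onehot_nonzero('LAG', return_onehot_mat=True)
    print(tokens.shape)                         # (3, 26): one row per residue over AA_DICT
    print(tokens[0, AA_DICT.index('L')])        # 0.975 = 1 - min_prob * 25
    print(mat[AA_DICT.index('<mask>')].max())   # 1.0: the '<mask>' row is hard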


def get_conservation_from_msa(mutation: Mutation, check_mode=False):
    transcript = mutation.transcript_id
    seq = mutation.seq
    seq_start = mutation.seq_start_orig
    seq_end = mutation.seq_end_orig
    if seq_start is None:
        seq_start = 1
    if seq_end is None:
        seq_end = len(seq)
    msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    if not os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle'):
        matched_line = False
    else:
        with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as file:
            msa_mat = pickle.load(file)
        # Only trust the MSA if its query sequence matches the working sequence.
        msa_seq = ''.join(msa_alphabet[msa_mat[seq_start - 1:seq_end, 0].astype(int)])
        if mutation.crop:
            msa_seq = msa_seq[mutation.seq_start - 1:mutation.seq_end]
        matched_line = msa_seq == seq
    if matched_line:
        if check_mode:
            return True
        conservation = msa_mat[seq_start - 1:seq_end, 1:41]
    else:
        if check_mode:
            return False
        conservation = np.zeros([seq_end - seq_start + 1, 40])
    if mutation.crop:
        conservation = conservation[mutation.seq_start - 1:mutation.seq_end]
    return conservation


def get_msa_dict_from_transcript_archive(transcript):
    msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    if pd.isna(transcript) or not os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle'):
        msa_seq = ''
        conservation = np.zeros([0, 20])
        msa = np.zeros([0, 200])
    else:
        with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as file:
            msa_mat = pickle.load(file)
        msa_seq = ''.join(msa_alphabet[msa_mat[:, 0].astype(int)])
        conservation = msa_mat[:, 1:21]
        msa = msa_mat[:, 21:221]
    return msa_seq, conservation, msa


def get_msa_dict_from_transcript(uniprotID):
    msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    if pd.isna(uniprotID) or not os.path.exists(f'{MSA_DATA_PATH}/{uniprotID}_MSA.npy'):
        msa_seq = ''
        conservation = np.zeros([0, 20])
        msa = np.zeros([0, 199])
    else:
        msa_mat = np.load(f'{MSA_DATA_PATH}/{uniprotID}_MSA.npy')
        msa_seq = ''.join(msa_alphabet[msa_mat[:, 0].astype(int)])
        # Per-position amino-acid frequencies across the aligned sequences.
        conservation = np.eye(21)[msa_mat.astype(int)].mean(axis=1)[:, :20]
        msa = msa_mat
    return msa_seq, conservation, msa


def get_confidence_from_af2file(af2file, pLDDT):
    uniprotID = af2file.split('/')[-1].split('.')[0].split('-model')[0]
    if pd.isna(uniprotID) or not os.path.exists(f'{PAE_DATA_PATH}/{uniprotID[3:6]}/{uniprotID}-predicted_aligned_error_v4.json.gz'):
        # No PAE file: approximate the aligned error from per-residue pLDDT.
        pae = (200 - pLDDT[None, :] - pLDDT[:, None]) / 4 if pLDDT is not None else None
    else:
        with gzip.open(f'{PAE_DATA_PATH}/{uniprotID[3:6]}/{uniprotID}-predicted_aligned_error_v4.json.gz', 'rt') as f:
            pae = json.load(f)
        pae = np.array(pae[0]['predicted_aligned_error'])
    return pae
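

# A worked sketch of the pLDDT-based fallback in get_confidence_from_af2file:
# when no PAE file is found, the aligned error between residues i and j is
# approximated as (200 - pLDDT_i - pLDDT_j) / 4, small when both residues are
# confident and larger for low-confidence pairs. The file name is hypothetical
# and, with PAE_DATA_PATH set to 'NA', always triggers the fallback.
def _demo_pae_fallback():
    plddt = np.array([90.0, 50.0, 30.0])
    pae = get_confidence_from_af2file('AF-P00000-F1-model_v4.pdb.gz', plddt)
    print(pae[0, 1])  # (200 - 90 - 50) / 4 = 15.0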


def get_msa(mutation: Mutation, check_mode=False):
    transcript = mutation.transcript_id
    seq = mutation.seq
    seq_start = mutation.seq_start_orig
    seq_end = mutation.seq_end_orig
    if seq_start is None:
        seq_start = 1
    if seq_end is None:
        seq_end = len(seq)
    msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    if not os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle'):
        matched_line = False
    else:
        with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as file:
            msa_mat = pickle.load(file)
        # Only trust the MSA if its query sequence matches the working sequence.
        msa_seq = ''.join(msa_alphabet[msa_mat[seq_start - 1:seq_end, 0].astype(int)])
        if mutation.crop:
            msa_seq = msa_seq[mutation.seq_start - 1:mutation.seq_end]
        matched_line = msa_seq == seq
    if matched_line:
        if check_mode:
            return True
        msa = msa_mat[seq_start - 1:seq_end, 21:221]
    else:
        if check_mode:
            return False
        msa = np.zeros([seq_end - seq_start + 1, 200])
    if mutation.crop:
        msa = msa[mutation.seq_start - 1:mutation.seq_end]
    return msa


def get_logits_from_esm2(protein, check_mode=True, seq_start=None, seq_end=None):
    if isinstance(protein, str):
        file_path = f"{ESM_DATA_PATH}/{protein}.logits.npy"
        if os.path.exists(file_path):
            if check_mode:
                return True
            wt_orig = np.load(file_path)
            batch_tokens = wt_orig[max(0, seq_start):
                                   min(wt_orig.shape[0] - 1, seq_end + 1)]
        else:
            if check_mode:
                return False
            batch_tokens = np.zeros([seq_end - seq_start + 1, 32])
    elif isinstance(protein, np.ndarray):
        batch_tokens = protein[max(0, seq_start):
                               min(protein.shape[0] - 1, seq_end + 1)]
    else:
        raise ValueError("protein must be either a string of uniprotID or a numpy array")
    return batch_tokens


def get_attn_from_msa(transcript, seq, check_mode=False, seq_start=None, seq_end=None):
    NUM_LAYERS = 6
    msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    if isinstance(transcript, str):
        if pd.isna(transcript) \
                or not os.path.exists(f'{MSA_DATA_PATH}/{transcript}.pickle') \
                or not os.path.exists(f'{MSA_ATTN_DATA_PATH}/{transcript}.row_attentions.pt'):
            matched_line = False
        else:
            with open(os.path.join(MSA_DATA_PATH, transcript + '.pickle'), 'rb') as file:
                msa_mat = pickle.load(file)
            if seq_start is None:
                seq_start = 1
            if seq_end is None:
                seq_end = len(seq)
            msa_seq = ''.join(msa_alphabet[msa_mat[seq_start - 1:seq_end, 0].astype(int)])
            matched_line = msa_seq == seq
        if matched_line:
            if check_mode:
                return True
            msa_row_attns = torch.load(
                os.path.join(MSA_ATTN_DATA_PATH, transcript + '.row_attentions.pt')).detach().numpy()
            msa_contacts = torch.load(os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')).detach().numpy()
            # Keep the row attentions of the last NUM_LAYERS layers plus the
            # contact map, stacked as (L, L, NUM_LAYERS * 12 heads + 1).
            msa_row_attns = msa_row_attns[:, (12 - NUM_LAYERS):, :, seq_start - 1:seq_end, seq_start - 1:seq_end]
            msa_contacts = msa_contacts[:, seq_start - 1:seq_end, seq_start - 1:seq_end]
            msa_pairwise = np.concatenate([msa_row_attns.reshape(-1, msa_row_attns.shape[-2], msa_row_attns.shape[-1]),
                                           msa_contacts], axis=0).transpose((1, 2, 0))
        else:
            if check_mode:
                return False
            msa_pairwise = np.zeros([seq_end - seq_start + 1, seq_end - seq_start + 1, NUM_LAYERS * 12 + 1])
    elif isinstance(transcript, tuple):
        # Pre-loaded (row_attentions, contacts) arrays, e.g. from load_attn_from_msa.
        msa_row_attns = transcript[0]
        msa_contacts = transcript[1]
        if msa_row_attns is not None and msa_contacts is not None:
            msa_row_attns = msa_row_attns[:, (12 - NUM_LAYERS):, :, seq_start - 1:seq_end, seq_start - 1:seq_end]
            msa_contacts = msa_contacts[:, seq_start - 1:seq_end, seq_start - 1:seq_end]
            msa_pairwise = np.concatenate([msa_row_attns.reshape(-1, msa_row_attns.shape[-2], msa_row_attns.shape[-1]),
                                           msa_contacts], axis=0).transpose((1, 2, 0))
        else:
            msa_pairwise = np.zeros([seq_end - seq_start + 1, seq_end - seq_start + 1, NUM_LAYERS * 12 + 1])
    else:
        raise ValueError("transcript must be either a string of transcriptID"
                         " or a tuple of msa_row_attns and msa_contacts")
    return msa_pairwise


def get_contacts_from_msa(mutation: Mutation, check_mode=False):
    transcript = mutation.transcript_id
    seq = mutation.seq
    seq_start = mutation.seq_start
    seq_end = mutation.seq_end
    msa_alphabet = np.array(list('ACDEFGHIKLMNPQRSTVWYU'))
    if pd.isna(transcript) \
            or not os.path.exists(f'{MSA_DATA_PATH_ARCHIVE}/{transcript}.pickle') \
            or not os.path.exists(f'{MSA_ATTN_DATA_PATH}/{transcript}.contacts.pt'):
        matched_line = False
    else:
        with open(os.path.join(MSA_DATA_PATH_ARCHIVE, transcript + '.pickle'), 'rb') as file:
            msa_mat = pickle.load(file)
        if seq_start is None:
            seq_start = 1
        if seq_end is None:
            seq_end = len(seq)
        msa_seq = ''.join(msa_alphabet[msa_mat[seq_start - 1:seq_end, 0].astype(int)])
        matched_line = msa_seq == seq
    if matched_line:
        if check_mode:
            return True
        msa_contacts = torch.load(os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')).detach().numpy()
        msa_contacts = msa_contacts[:, seq_start - 1:seq_end, seq_start - 1:seq_end]
        msa_pairwise = msa_contacts.transpose((1, 2, 0))
    else:
        # Fall back to ESM-predicted contacts when the MSA does not match.
        if not os.path.exists(f'{ESM_DATA_PATH}/{mutation.ESM_prefix}.contacts.npy'):
            if check_mode:
                return False
            msa_pairwise = np.zeros([seq_end - seq_start + 1, seq_end - seq_start + 1, 1])
        else:
            if check_mode:
                return True
            msa_pairwise = np.load(f'{ESM_DATA_PATH}/{mutation.ESM_prefix}.contacts.npy')
            msa_pairwise = np.expand_dims(msa_pairwise[seq_start - 1:seq_end, seq_start - 1:seq_end], axis=2)
    return msa_pairwise


def get_contacts_from_msa_by_identifier(identifier):
    # identifier format: 'transcript:seq:seq_start:seq_end'.
    str_split = identifier.split(":")
    transcript = str_split[0]
    seq = str_split[1]
    seq_start = int(str_split[2])
    seq_end = int(str_split[3])
    check_mode = False
    # get_contacts_from_msa expects a Mutation-like object, so wrap the parsed
    # fields in a stand-in namespace. ESM_prefix is assumed to equal the first
    # identifier field; adjust if your ESM contact files are keyed differently.
    stub = SimpleNamespace(transcript_id=transcript, seq=seq,
                           seq_start=seq_start, seq_end=seq_end,
                           ESM_prefix=transcript)
    return get_contacts_from_msa(stub, check_mode)


def load_embedding_from_esm2(protein):
    file_path = f"{ESM_DATA_PATH}/{protein}.representations.layer.48.npy"
    assert os.path.exists(file_path)
    return np.load(file_path)


def load_logits_from_esm2(protein):
    file_path = f"{ESM_DATA_PATH}/{protein}.logits.npy"
    assert os.path.exists(file_path)
    return np.load(file_path)


def load_attn_from_msa(transcript):
    if os.path.exists(os.path.join(MSA_ATTN_DATA_PATH, transcript + '.row_attentions.pt')) and \
            os.path.exists(os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')):
        msa_row_attns = torch.load(os.path.join(MSA_ATTN_DATA_PATH, transcript + '.row_attentions.pt')).detach().numpy()
        msa_contacts = torch.load(os.path.join(MSA_ATTN_DATA_PATH, transcript + '.contacts.pt')).detach().numpy()
        return msa_row_attns, msa_contacts
    else:
        return None, None


def _test_load():
    test_file = pd.read_csv('/share/terra/Users/gz2294/ld1/Data/DMS/ClinVar.HGMD.PrimateAI.syn/training.csv',
                            index_col=0)
    # Pick a long protein to exercise the fragment/crop logic.
    idx = np.where(test_file['sequence.len.orig'] == 4753)[0][0]
    point_mutation = get_mutations(test_file['uniprotID'].iloc[idx],
                                   test_file['ENST'].iloc[idx],
                                   test_file['wt.orig'].iloc[idx],
                                   test_file['sequence.len.orig'].iloc[idx],
                                   test_file['pos.orig'].iloc[idx],
                                   test_file['ref'].iloc[idx],
                                   test_file['alt'].iloc[idx])
    coords = get_coords_from_af2(point_mutation.af2_file)
    CA_coord = coords[:, 3]
    embed_data = get_embedding_from_esm2(point_mutation.uniprot_id, False,
                                         point_mutation.seq_start, point_mutation.seq_end)
    coev_strength = get_attn_from_msa(point_mutation.transcript_id, point_mutation.seq, False,
                                      point_mutation.seq_start, point_mutation.seq_end)
    # Fully connected graph without self-loops; edge features are the pairwise
    # MSA attention / contact channels.
    edge_index = np.indices((coords.shape[0], coords.shape[0])).reshape(2, -1)
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    edge_attr = coev_strength[edge_index[0], edge_index[1], :]
    # Per-residue orientation vectors relative to CA.
    CA_CB = coords[:, [4]] - coords[:, [3]]
    CA_C = coords[:, [1]] - coords[:, [3]]
    CA_O = coords[:, [2]] - coords[:, [3]]
    CA_N = coords[:, [0]] - coords[:, [3]]
    nodes_vector = np.concatenate([CA_CB, CA_C, CA_O, CA_N], axis=1)
    features = dict(
        pos=torch.from_numpy(CA_coord), x=torch.from_numpy(embed_data),
        edge_index=torch.from_numpy(edge_index), edge_attr=torch.from_numpy(edge_attr).to(torch.float),
        node_vec_attr=torch.from_numpy(nodes_vector).transpose(1, 2)
    )
    from torch_geometric.data import Data
    map_data = Data(**features)
    return map_data


if __name__ == '__main__':
    print(_test_load())