import biotite.structure
import numpy as np
import torch
from typing import Sequence, Tuple, List

from esm.inverse_folding.util import (
    load_structure,
    extract_coords_from_structure,
    load_coords,
    get_sequence_loss,
    get_encoder_output,
)
|
|
|
|
|
def extract_coords_from_complex(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: biotite AtomArray
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    coords = {}
    seqs = {}
    all_chains = biotite.structure.get_chains(structure)
    for chain_id in all_chains:
        chain = structure[structure.chain_id == chain_id]
        coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain)
    return coords, seqs
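
# Illustrative usage sketch, kept as a comment so it does not run on import:
# splitting a hypothetical two-chain complex into per-chain backbones. The
# file name "complex.pdb" and chain ids "A"/"B" are placeholder assumptions.
#
#   structure = load_structure("complex.pdb", ["A", "B"])
#   coords, seqs = extract_coords_from_complex(structure)
#   # coords["A"].shape == (len(seqs["A"]), 3, 3)  -- N, CA, C per residue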
|
|
|
|
|
def load_complex_coords(fpath, chains):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chains: the chain ids (the order matters for autoregressive model)
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    structure = load_structure(fpath, chains)
    return extract_coords_from_complex(structure)
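
# Illustrative sketch: the one-call equivalent of the load + extract steps
# above, again with a placeholder file path and chain ids.
#
#   coords, seqs = load_complex_coords("complex.pdb", chains=["A", "B"])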
|
|
|
|
|
def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        coords: L x 3 x 3 array for N, CA, C coordinates, a concatenation of
            the chains with NaN padding in between
    """
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # Place the target chain first, then append the remaining chains, each
    # preceded by a block of NaN padding coordinates.
    coords_list = [coords[target_chain_id]]
    for chain_id in coords:
        if chain_id == target_chain_id:
            continue
        coords_list.append(pad_coords)
        coords_list.append(coords[chain_id])
    coords_concatenated = np.concatenate(coords_list, axis=0)
    return coords_concatenated
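
# Illustrative sketch of the resulting shape: for a hypothetical complex with
# chains of lengths 100 (target "A") and 80 ("B") and the default
# padding_length of 10, the concatenated array has 100 + 10 + 80 = 190 rows,
# with the 10 padding rows filled with NaN.
#
#   all_coords = _concatenate_coords(coords, "A")
#   # all_coords.shape == (190, 3, 3)
#   # np.isnan(all_coords[100:110]).all()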
|
|
|
|
|
def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10):
    """
    Samples sequence for one chain in a complex.
    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        temperature: Sampling temperature (higher values give more diverse
            samples)
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    all_coords = _concatenate_coords(coords, target_chain_id,
            padding_length=padding_length)
    device = next(model.parameters()).device

    # Mask only the target chain so that its residues are sampled; all other
    # positions in the concatenated complex are treated as padding.
    padding_pattern = ['<pad>'] * all_coords.shape[0]
    for i in range(target_chain_len):
        padding_pattern[i] = '<mask>'
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
            temperature=temperature, device=device)
    sampled = sampled[:target_chain_len]
    return sampled
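
# Illustrative sketch (assumes the pretrained ESM-IF1 inverse folding model is
# loaded via esm.pretrained; the model name and file path are placeholder
# assumptions for this example):
#
#   import esm
#   model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
#   model = model.eval()
#   coords, seqs = load_complex_coords("complex.pdb", chains=["A", "B"])
#   sampled_seq = sample_sequence_in_complex(model, coords, "A",
#                                            temperature=0.1)
#   # sampled_seq has the same length as chain "A"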
|
|
|
|
|
def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
        target_seq, padding_length=10):
    """
    Scores sequence for one chain in a complex.
    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id of the chain to score
        target_seq: Target sequence for the target chain for scoring.
        padding_length: padding length in between chains
    Returns:
        Tuple (ll_fullseq, ll_withcoord)
        - ll_fullseq: Average log-likelihood over the full target chain
        - ll_withcoord: Average log-likelihood over the target chain,
          excluding residues without coordinates
    """
    all_coords = _concatenate_coords(coords, target_chain_id,
            padding_length=padding_length)

    loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
            target_seq)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(
            ~target_padding_mask)

    # Also compute the average restricted to residues of the target chain
    # that have finite backbone coordinates.
    coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord
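
# Illustrative sketch: scoring the native sequence of the target chain in the
# same hypothetical complex (model, alphabet, coords, and seqs as in the
# sampling sketch above):
#
#   ll_fullseq, ll_withcoord = score_sequence_in_complex(
#       model, alphabet, coords, "A", seqs["A"])
#   # Both are average per-residue log-likelihoods (higher is better);
#   # ll_withcoord ignores residues with missing backbone coordinates.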
|
|
|
|
|
def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
    """
    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to extract encoder output for
    Returns:
        Encoder output for the target chain only (the first L positions of
        the complex representation, where L is the target chain length)
    """
    all_coords = _concatenate_coords(coords, target_chain_id)
    all_rep = get_encoder_output(model, alphabet, all_coords)
    target_chain_len = coords[target_chain_id].shape[0]
    return all_rep[:target_chain_len]
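
# Illustrative sketch: per-residue encoder representations for the target
# chain, conditioned on the full complex (model, alphabet, and coords as in
# the sketches above):
#
#   rep = get_encoder_output_for_complex(model, alphabet, coords, "A")
#   # rep has one row per residue of chain "A"; the trailing dimension is the
#   # encoder hidden size.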
|
|