# Source: FAPM_demo / esm / inverse_folding / multichain_util.py
# (Hugging Face file page: uploaded by wenkai, "Upload 31 files",
#  commit 3f0529e, verified; raw / history / blame; no virus; 6 kB)
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import biotite.structure
import numpy as np
import torch
from typing import Sequence, Tuple, List
from esm.inverse_folding.util import (
load_structure,
extract_coords_from_structure,
load_coords,
get_sequence_loss,
get_encoder_output,
)
def extract_coords_from_complex(structure: biotite.structure.AtomArray):
    """
    Split a multi-chain structure into per-chain backbone coordinates and
    native sequences.

    Args:
        structure: biotite AtomArray
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    coords, seqs = {}, {}
    for cid in biotite.structure.get_chains(structure):
        # Restrict the atom array to this chain before extracting.
        chain_atoms = structure[structure.chain_id == cid]
        chain_coords, chain_seq = extract_coords_from_structure(chain_atoms)
        coords[cid] = chain_coords
        seqs[cid] = chain_seq
    return coords, seqs
def load_complex_coords(fpath, chains):
    """
    Load a structure file and extract per-chain backbone coordinates and
    native sequences.

    Args:
        fpath: filepath to either pdb or cif file
        chains: the chain ids (the order matters for autoregressive model)
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    return extract_coords_from_complex(load_structure(fpath, chains))
def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Concatenate per-chain backbone coordinates into one array, with the
    target chain first and NaN-filled padding between consecutive chains.

    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for; it is placed
            first in the concatenation
        padding_length: Number of NaN padding positions inserted between
            concatenated chains
    Returns:
        An L_total x 3 x 3 array of N, CA, C coordinates, where
        L_total = sum of chain lengths + padding_length * (num_chains - 1).
        Only coordinates are returned; no sequence is produced here.
    """
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    # The remaining chains follow in the dictionary's iteration order, each
    # preceded by a block of NaN padding.
    coords_list = [coords[target_chain_id]]
    for chain_id, chain_coords in coords.items():
        if chain_id == target_chain_id:
            continue
        coords_list.append(pad_coords)
        coords_list.append(chain_coords)
    return np.concatenate(coords_list, axis=0)
def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10):
    """
    Samples sequence for one chain in a complex.

    All chains are concatenated with the target chain first. Only the target
    chain's positions are sampled; the other positions are fixed to padding
    tokens.

    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        temperature: Sampling temperature, passed through to `model.sample`
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    # Forward padding_length; it was previously accepted but ignored, so the
    # default of 10 was always used.
    all_coords = _concatenate_coords(coords, target_chain_id,
                                     padding_length=padding_length)
    device = next(model.parameters()).device
    # The target chain occupies the first target_chain_len positions: mask
    # them so they get sampled. Supply padding tokens for every other position
    # (other chains and inter-chain padding) to avoid unused sampling for speed.
    num_other = all_coords.shape[0] - target_chain_len
    padding_pattern = ['<mask>'] * target_chain_len + ['<pad>'] * num_other
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
                           temperature=temperature, device=device)
    return sampled[:target_chain_len]
def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
        target_seq, padding_length=10):
    """
    Scores sequence for one chain in a complex.

    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id whose sequence is scored
        target_seq: Target sequence for the target chain for scoring.
        padding_length: padding length in between chains
    Returns:
        Tuple (ll_fullseq, ll_withcoord)
        - ll_fullseq: Average log-likelihood over the full target chain
        - ll_withcoord: Average log-likelihood in target chain excluding those
          residues without coordinates
    """
    # Forward padding_length; it was previously accepted but ignored, so the
    # default of 10 was always used.
    all_coords = _concatenate_coords(coords, target_chain_id,
                                     padding_length=padding_length)
    loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
                                                  target_seq)
    not_padding = ~target_padding_mask
    ll_fullseq = -np.sum(loss * not_padding) / np.sum(not_padding)
    # Also calculate the average when excluding masked portions: a residue
    # counts only if all of its backbone coordinates are finite.
    # NOTE(review): assumes `loss` has one entry per target-chain residue, so
    # it aligns elementwise with coord_mask. Confirm against get_sequence_loss.
    coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord
def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id,
        padding_length=10):
    """
    Computes encoder output for the target chain in the context of the
    whole complex.

    All chains are concatenated with the target chain first and encoded
    together. Only the leading target-chain positions are returned.

    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id whose encoder output is returned
        padding_length: padding length in between chains
    Returns:
        The encoder output for the target chain only: the first
        coords[target_chain_id].shape[0] entries of the full-complex output.
        This is not a per-chain dictionary.
    """
    all_coords = _concatenate_coords(coords, target_chain_id,
                                     padding_length=padding_length)
    all_rep = get_encoder_output(model, alphabet, all_coords)
    # The target chain is placed first by _concatenate_coords, so its rows
    # are the leading ones.
    # NOTE(review): assumes get_encoder_output returns one row per position of
    # all_coords. Confirm against its implementation.
    target_chain_len = coords[target_chain_id].shape[0]
    return all_rep[:target_chain_len]