# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import biotite.structure
import numpy as np
import torch
from typing import Sequence, Tuple, List

from esm.inverse_folding.util import (
    load_structure,
    extract_coords_from_structure,
    load_coords,
    get_sequence_loss,
    get_encoder_output,
)


def extract_coords_from_complex(structure: biotite.structure.AtomArray):
    """
    Extract backbone coordinates and native sequences for every chain in a
    multi-chain structure.

    Args:
        structure: biotite AtomArray
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    coords = {}
    seqs = {}
    all_chains = biotite.structure.get_chains(structure)
    for chain_id in all_chains:
        chain = structure[structure.chain_id == chain_id]
        coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain)
    return coords, seqs


def load_complex_coords(fpath, chains):
    """
    Load backbone coordinates and sequences for the given chains of a complex.

    Args:
        fpath: filepath to either pdb or cif file
        chains: the chain ids (the order matters for autoregressive model)
    Returns:
        Tuple (coords, seqs)
        - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
          coordinates representing the backbone of each chain
        - seqs: Dictionary mapping chain ids to native sequences of each chain
    """
    structure = load_structure(fpath, chains)
    return extract_coords_from_complex(structure)


def _concatenate_coords(coords, target_chain_id, padding_length=10):
    """
    Concatenate the backbone coordinates of all chains into a single array,
    with NaN padding between chains and the target chain placed first.

    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        L_total x 3 x 3 array of N, CA, C coordinates: the target chain
        followed by every other chain, separated by `padding_length` rows of
        NaN coordinates.
    """
    # NaN rows act as "no coordinates" separators between chains.
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    coords_list = [coords[target_chain_id]]
    for chain_id in coords:
        if chain_id == target_chain_id:
            continue
        coords_list.append(pad_coords)
        coords_list.append(coords[chain_id])
    return np.concatenate(coords_list, axis=0)


def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
        padding_length=10):
    """
    Samples sequence for one chain in a complex.

    Args:
        model: An instance of the GVPTransformer model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        temperature: Sampling temperature
        padding_length: padding length in between chains
    Returns:
        Sampled sequence for the target chain
    """
    target_chain_len = coords[target_chain_id].shape[0]
    # Forward padding_length so the parameter is actually honored.
    all_coords = _concatenate_coords(coords, target_chain_id, padding_length)
    device = next(model.parameters()).device

    # Supply padding tokens for other chains to avoid unused sampling for
    # speed; positions in the target chain are left as mask tokens to be
    # filled in by the model. (Tokens restored — they had been stripped as
    # angle-bracket markup.)
    padding_pattern = ['<pad>'] * all_coords.shape[0]
    for i in range(target_chain_len):
        padding_pattern[i] = '<mask>'
    sampled = model.sample(all_coords, partial_seq=padding_pattern,
            temperature=temperature, device=device)
    # Only the target chain (placed first by _concatenate_coords) is returned.
    sampled = sampled[:target_chain_len]
    return sampled


def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
        target_seq, padding_length=10):
    """
    Scores sequence for one chain in a complex.

    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        target_seq: Target sequence for the target chain for scoring.
        padding_length: padding length in between chains
    Returns:
        Tuple (ll_fullseq, ll_withcoord)
        - ll_fullseq: Average log-likelihood over the full target chain
        - ll_withcoord: Average log-likelihood in target chain excluding those
          residues without coordinates
    """
    # Forward padding_length so the parameter is actually honored.
    all_coords = _concatenate_coords(coords, target_chain_id, padding_length)

    loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
            target_seq)
    ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(
            ~target_padding_mask)

    # Also calculate average when excluding masked portions
    coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
    ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
    return ll_fullseq, ll_withcoord


def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
    """
    Args:
        model: An instance of the GVPTransformer model
        alphabet: Alphabet for the model
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
    Returns:
        Encoder output (representations) for the target chain only
    """
    all_coords = _concatenate_coords(coords, target_chain_id)
    all_rep = get_encoder_output(model, alphabet, all_coords)
    target_chain_len = coords[target_chain_id].shape[0]
    # Target chain comes first in the concatenation, so its rows lead.
    return all_rep[:target_chain_len]