Spaces:
Running
Running
# # Copyright (c) Meta Platforms, Inc. and affiliates. | |
# # | |
# # This source code is licensed under the MIT license found in the | |
# # LICENSE file in the root directory of this source tree. | |
# | |
# import biotite.structure | |
# import numpy as np | |
# import torch | |
# from typing import Sequence, Tuple, List | |
# | |
# from esm.inverse_folding.util import ( | |
# load_structure, | |
# extract_coords_from_structure, | |
# load_coords, | |
# get_sequence_loss, | |
# get_encoder_output, | |
# ) | |
# | |
# | |
# def extract_coords_from_complex(structure: biotite.structure.AtomArray): | |
# """ | |
# Args: | |
# structure: biotite AtomArray | |
# Returns: | |
# Tuple (coords, seqs)
# - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C | |
# coordinates representing the backbone of each chain | |
# - seqs: Dictionary mapping chain ids to native sequences of each chain | |
# """ | |
# coords = {} | |
# seqs = {} | |
# all_chains = biotite.structure.get_chains(structure) | |
# for chain_id in all_chains: | |
# chain = structure[structure.chain_id == chain_id] | |
# coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain) | |
# return coords, seqs | |
# | |
# | |
# def load_complex_coords(fpath, chains): | |
# """ | |
# Args: | |
# fpath: filepath to either pdb or cif file | |
# chains: the chain ids (the order matters for autoregressive model) | |
# Returns: | |
# Tuple (coords, seqs)
# - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C | |
# coordinates representing the backbone of each chain | |
# - seqs: Dictionary mapping chain ids to native sequences of each chain | |
# """ | |
# structure = load_structure(fpath, chains) | |
# return extract_coords_from_complex(structure) | |
# | |
# | |
# def _concatenate_coords(coords, target_chain_id, padding_length=10): | |
# """ | |
# Args: | |
# coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C | |
# coordinates representing the backbone of each chain | |
# target_chain_id: The chain id to sample sequences for | |
# padding_length: Length of padding between concatenated chains | |
# Returns: | |
# coords: a single L x 3 x 3 array for N, CA, C coordinates, the
# concatenation of all chains with the target chain first and
# padding_length rows of NaN padding inserted between chains.
# Only the coordinates are returned; no sequence is produced here.
# """ | |
# pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32) | |
# # For best performance, put the target chain first in concatenation. | |
# coords_list = [coords[target_chain_id]] | |
# for chain_id in coords: | |
# if chain_id == target_chain_id: | |
# continue | |
# coords_list.append(pad_coords) | |
# coords_list.append(coords[chain_id]) | |
# coords_concatenated = np.concatenate(coords_list, axis=0) | |
# return coords_concatenated | |
# | |
# | |
# def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1., | |
# padding_length=10): | |
# """ | |
# Samples sequence for one chain in a complex. | |
# Args: | |
# model: An instance of the GVPTransformer model | |
# coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C | |
# coordinates representing the backbone of each chain | |
# target_chain_id: The chain id to sample sequences for | |
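# temperature: sampling temperature, passed through to model.sample
# padding_length: padding length in between chains
# (NOTE: currently not forwarded to _concatenate_coords, so its
# default padding length of 10 is always used)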
# Returns: | |
# Sampled sequence for the target chain | |
# """ | |
# target_chain_len = coords[target_chain_id].shape[0] | |
# all_coords = _concatenate_coords(coords, target_chain_id) | |
# device = next(model.parameters()).device | |
# | |
# # Supply padding tokens for other chains to avoid unused sampling for speed | |
# padding_pattern = ['<pad>'] * all_coords.shape[0] | |
# for i in range(target_chain_len): | |
# padding_pattern[i] = '<mask>' | |
# sampled = model.sample(all_coords, partial_seq=padding_pattern, | |
# temperature=temperature, device=device) | |
# sampled = sampled[:target_chain_len] | |
# return sampled | |
# | |
# | |
# def score_sequence_in_complex(model, alphabet, coords, target_chain_id, | |
# target_seq, padding_length=10): | |
# """ | |
# Scores sequence for one chain in a complex. | |
# Args: | |
# model: An instance of the GVPTransformer model | |
# alphabet: Alphabet for the model | |
# coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C | |
# coordinates representing the backbone of each chain | |
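# target_chain_id: The chain id whose sequence is scored
# target_seq: Target sequence for the target chain for scoring.
# padding_length: padding length in between chains
# (NOTE: currently not forwarded to _concatenate_coords, so its
# default padding length of 10 is always used)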
# Returns: | |
# Tuple (ll_fullseq, ll_withcoord) | |
# - ll_fullseq: Average log-likelihood over the full target chain | |
# - ll_withcoord: Average log-likelihood in target chain excluding those | |
# residues without coordinates | |
# """ | |
# all_coords = _concatenate_coords(coords, target_chain_id) | |
# | |
# loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords, | |
# target_seq) | |
# ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum( | |
# ~target_padding_mask) | |
# | |
# # Also calculate average when excluding masked portions | |
# coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2)) | |
# ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask) | |
# return ll_fullseq, ll_withcoord | |
# | |
# | |
# def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id): | |
# """ | |
# Args: | |
# model: An instance of the GVPTransformer model | |
# alphabet: Alphabet for the model | |
# coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C | |
# coordinates representing the backbone of each chain | |
# target_chain_id: The chain id to return the encoder output for
# Returns:
# Encoder output for the target chain only: the first target_chain_len
# positions of the encoder output computed on the concatenated complex
# """ | |
# all_coords = _concatenate_coords(coords, target_chain_id) | |
# all_rep = get_encoder_output(model, alphabet, all_coords) | |
# target_chain_len = coords[target_chain_id].shape[0] | |
# return all_rep[:target_chain_len] | |