|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Functions for building the input features for the AlphaFold model.""" |
|
|
|
|
|
import os |
|
|
from typing import Mapping, Optional, Sequence |
|
|
from absl import logging |
|
|
from alphafold.common import residue_constants |
|
|
from alphafold.data import parsers |
|
|
from alphafold.data import templates |
|
|
from alphafold.data.tools import hhblits |
|
|
from alphafold.data.tools import hhsearch |
|
|
from alphafold.data.tools import jackhmmer |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
# Type alias for a single example's input features, keyed by feature name.
FeatureDict = Mapping[str, np.ndarray]
|
|
|
|
|
|
|
|
def make_sequence_features(
    sequence: str, description: str, num_res: int) -> FeatureDict:
  """Builds the sequence-derived input features for a single query.

  Args:
    sequence: Amino-acid sequence of the query (one-letter codes).
    description: FASTA description line for the query.
    num_res: Number of residues, i.e. len(sequence).

  Returns:
    A feature dict with one-hot `aatype`, per-residue index/length arrays and
    the raw sequence/description as object arrays.
  """
  aatype_one_hot = residue_constants.sequence_to_onehot(
      sequence=sequence,
      mapping=residue_constants.restype_order_with_x,
      map_unknown_to_x=True)
  return {
      'aatype': aatype_one_hot,
      'between_segment_residues': np.zeros((num_res,), dtype=np.int32),
      'domain_name': np.array([description.encode('utf-8')],
                              dtype=np.object_),
      'residue_index': np.arange(num_res, dtype=np.int32),
      # seq_length is replicated once per residue.
      'seq_length': np.full(num_res, num_res, dtype=np.int32),
      'sequence': np.array([sequence.encode('utf-8')], dtype=np.object_),
  }
|
|
|
|
|
|
|
|
def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
  """Builds MSA input features, deduplicating sequences across all MSAs.

  Args:
    msas: One or more MSAs; each is a sequence of aligned strings. The first
      sequence of the first MSA determines the residue count.
    deletion_matrices: Per-MSA deletion matrices, parallel to `msas`.

  Returns:
    A feature dict with the integer-encoded MSA, its deletion counts and the
    (replicated) number of retained alignments.

  Raises:
    ValueError: If `msas` is empty or any MSA contains no sequences.
  """
  if not msas:
    raise ValueError('At least one MSA must be provided.')

  encoded_rows = []
  kept_deletions = []
  seen = set()
  for msa_index, msa in enumerate(msas):
    if not msa:
      raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
    for sequence_index, sequence in enumerate(msa):
      # Keep only the first occurrence of each aligned sequence.
      if sequence in seen:
        continue
      seen.add(sequence)
      encoded_rows.append(
          [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
      kept_deletions.append(deletion_matrices[msa_index][sequence_index])

  num_res = len(msas[0][0])
  return {
      'deletion_matrix_int': np.array(kept_deletions, dtype=np.int32),
      'msa': np.array(encoded_rows, dtype=np.int32),
      'num_alignments': np.array(
          [len(encoded_rows)] * num_res, dtype=np.int32),
  }
|
|
|
|
|
|
|
|
class DataPipeline:
  """Runs the alignment tools and assembles the input features."""

  def __init__(self,
               jackhmmer_binary_path: str,
               hhblits_binary_path: str,
               hhsearch_binary_path: str,
               uniref90_database_path: str,
               mgnify_database_path: str,
               bfd_database_path: Optional[str],
               uniclust30_database_path: Optional[str],
               small_bfd_database_path: Optional[str],
               pdb70_database_path: str,
               template_featurizer: templates.TemplateHitFeaturizer,
               use_small_bfd: bool,
               mgnify_max_hits: int = 501,
               uniref_max_hits: int = 10000):
    """Initializes the data pipeline, creating the alignment tool runners.

    Args:
      jackhmmer_binary_path: Path to the jackhmmer executable.
      hhblits_binary_path: Path to the HHblits executable.
      hhsearch_binary_path: Path to the HHsearch executable.
      uniref90_database_path: UniRef90 database, searched with jackhmmer.
      mgnify_database_path: MGnify database, searched with jackhmmer.
      bfd_database_path: Full BFD database for HHblits; may be None when
        `use_small_bfd` is True.
      uniclust30_database_path: Uniclust30 database for HHblits; may be None
        when `use_small_bfd` is True.
      small_bfd_database_path: Reduced BFD database for jackhmmer; may be None
        when `use_small_bfd` is False.
      pdb70_database_path: PDB70 database, searched with HHsearch.
      template_featurizer: Turns template search hits into template features.
      use_small_bfd: If True, search the reduced BFD with jackhmmer instead of
        the full BFD + Uniclust30 with HHblits.
      mgnify_max_hits: Maximum number of MGnify sequences kept for features.
      uniref_max_hits: Maximum number of UniRef90 sequences passed to HHsearch.
    """
    self._use_small_bfd = use_small_bfd
    self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=uniref90_database_path)
    # Only one of the two BFD runners is created, matching the search mode.
    if use_small_bfd:
      self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
          binary_path=jackhmmer_binary_path,
          database_path=small_bfd_database_path)
    else:
      self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
          binary_path=hhblits_binary_path,
          databases=[bfd_database_path, uniclust30_database_path])
    self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
        binary_path=jackhmmer_binary_path,
        database_path=mgnify_database_path)
    self.hhsearch_pdb70_runner = hhsearch.HHSearch(
        binary_path=hhsearch_binary_path,
        databases=[pdb70_database_path])
    self.template_featurizer = template_featurizer
    self.mgnify_max_hits = mgnify_max_hits
    self.uniref_max_hits = uniref_max_hits

  def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
    """Runs alignment tools on the input sequence and creates features.

    Args:
      input_fasta_path: Path to a FASTA file holding exactly one sequence.
      msa_output_dir: Directory where the raw MSA and template search results
        are written for inspection.

    Returns:
      A feature dict combining sequence, MSA and template features.

    Raises:
      ValueError: If the FASTA file does not contain exactly one sequence.
    """
    with open(input_fasta_path) as f:
      input_fasta_str = f.read()
    input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
    if len(input_seqs) != 1:
      # Fixed message: the old text claimed 'More than one input sequence
      # found' even when the file contained zero sequences.
      raise ValueError(
          f'Exactly one input sequence expected in {input_fasta_path}, '
          f'got {len(input_seqs)}.')
    input_sequence = input_seqs[0]
    input_description = input_descs[0]
    num_res = len(input_sequence)

    jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
        input_fasta_path)[0]
    jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
        input_fasta_path)[0]

    # HHsearch consumes a3m, so convert the UniRef90 Stockholm output,
    # capping it at uniref_max_hits sequences.
    uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
        jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits)
    hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)

    # Persist raw search outputs so users can inspect or reuse them.
    uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
    with open(uniref90_out_path, 'w') as f:
      f.write(jackhmmer_uniref90_result['sto'])

    mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
    with open(mgnify_out_path, 'w') as f:
      f.write(jackhmmer_mgnify_result['sto'])

    pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr')
    with open(pdb70_out_path, 'w') as f:
      f.write(hhsearch_result)

    uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
        jackhmmer_uniref90_result['sto'])
    mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
        jackhmmer_mgnify_result['sto'])
    hhsearch_hits = parsers.parse_hhr(hhsearch_result)
    # Only MGnify is truncated here; the UniRef90 MSA used for features comes
    # from the full Stockholm parse (the uniref_max_hits cap above only
    # applied to the copy handed to HHsearch).
    mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
    mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]

    if self._use_small_bfd:
      jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
          input_fasta_path)[0]

      # NOTE(review): the content written here is Stockholm ('sto') although
      # the file name carries an .a3m extension — kept for compatibility with
      # any existing consumers; confirm before renaming to .sto.
      bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m')
      with open(bfd_out_path, 'w') as f:
        f.write(jackhmmer_small_bfd_result['sto'])

      bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
          jackhmmer_small_bfd_result['sto'])
    else:
      hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(
          input_fasta_path)

      bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
      with open(bfd_out_path, 'w') as f:
        f.write(hhblits_bfd_uniclust_result['a3m'])

      bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
          hhblits_bfd_uniclust_result['a3m'])

    templates_result = self.template_featurizer.get_templates(
        query_sequence=input_sequence,
        query_pdb_code=None,
        query_release_date=None,
        hits=hhsearch_hits)

    sequence_features = make_sequence_features(
        sequence=input_sequence,
        description=input_description,
        num_res=num_res)

    msa_features = make_msa_features(
        msas=(uniref90_msa, bfd_msa, mgnify_msa),
        deletion_matrices=(uniref90_deletion_matrix,
                           bfd_deletion_matrix,
                           mgnify_deletion_matrix))

    logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
    logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
    logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
    logging.info('Final (deduplicated) MSA size: %d sequences.',
                 msa_features['num_alignments'][0])
    logging.info('Total number of templates (NB: this can include bad '
                 'templates and is later filtered to top 4): %d.',
                 templates_result.features['template_domain_names'].shape[0])

    return {**sequence_features, **msa_features, **templates_result.features}
|
|
|