| | import warnings |
| | from typing import cast |
| |
|
| | import attr |
| | import torch |
| |
|
| | from src.data.esm.models.function_decoder import FunctionTokenDecoder |
| | from src.data.esm.models.vqvae import StructureTokenDecoder |
| | from src.data.esm.sdk.api import ESMProtein, ESMProteinTensor |
| | from src.data.esm.tokenization import TokenizerCollectionProtocol |
| | from src.data.esm.tokenization.function_tokenizer import ( |
| | InterProQuantizedTokenizer, |
| | ) |
| | from src.data.esm.tokenization.residue_tokenizer import ( |
| | ResidueAnnotationsTokenizer, |
| | ) |
| | from src.data.esm.tokenization.sasa_tokenizer import ( |
| | SASADiscretizingTokenizer, |
| | ) |
| | from src.data.esm.tokenization.sequence_tokenizer import ( |
| | EsmSequenceTokenizer, |
| | ) |
| | from src.data.esm.tokenization.ss_tokenizer import ( |
| | SecondaryStructureTokenizer, |
| | ) |
| | from src.data.esm.tokenization.structure_tokenizer import ( |
| | StructureTokenizer, |
| | ) |
| | from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase |
| | from src.data.esm.utils.constants import esm3 as C |
| | from src.data.esm.utils.function.encode_decode import ( |
| | decode_function_tokens, |
| | decode_residue_annotation_tokens, |
| | ) |
| | from src.data.esm.utils.misc import maybe_list |
| | from src.data.esm.utils.structure.protein_chain import ProteinChain |
| | from src.data.esm.utils.types import FunctionAnnotation |
| |
|
| |
|
def decode_protein_tensor(
    input: ESMProteinTensor,
    tokenizers: TokenizerCollectionProtocol,
    structure_token_decoder: StructureTokenDecoder,
    function_token_decoder: FunctionTokenDecoder | None = None,
) -> ESMProtein:
    """Decode a tokenized protein (`ESMProteinTensor`) back into an `ESMProtein`.

    Each populated track is decoded by its matching tokenizer/decoder; tracks
    that consist entirely of padding (or, for structure, contain any mask
    token) are treated as absent.

    Raises:
        ValueError: if the function track is present but no
            ``function_token_decoder`` was supplied.
    """
    # Shallow-copy so mutations below never touch the caller's object.
    input = attr.evolve(input)

    sequence = None
    secondary_structure = None
    sasa = None
    coordinates = None
    function_annotations: list[FunctionAnnotation] = []

    # Tracks that have no per-track tokenizer and are handled separately.
    skipped_tracks = ("coordinates", "potential_sequence_of_concern")

    for field in attr.fields(ESMProteinTensor):
        name = field.name
        if name in skipped_tracks:
            continue
        track_tokens: torch.Tensor | None = getattr(input, name)
        if track_tokens is None:
            continue
        # Drop BOS/EOS before inspecting the payload tokens.
        inner = track_tokens[1:-1].flatten()
        track_tokenizer = getattr(tokenizers, name)
        # All-padding tracks carry no information: blank them out.
        if torch.all(inner == track_tokenizer.pad_token_id):
            setattr(input, name, None)
        # A structure track containing mask tokens cannot be decoded either.
        if name == "structure" and torch.any(inner == track_tokenizer.mask_token_id):
            setattr(input, name, None)

    if input.sequence is not None:
        sequence = decode_sequence(input.sequence, tokenizers.sequence)

    plddt, ptm = None, None
    if input.structure is not None:
        # Decoding structure tokens also yields confidence estimates.
        coordinates, plddt, ptm = decode_structure(
            structure_tokens=input.structure,
            structure_decoder=structure_token_decoder,
            structure_tokenizer=tokenizers.structure,
            sequence=sequence,
        )
    elif input.coordinates is not None:
        # No structure tokens: fall back to raw coordinates, minus BOS/EOS rows.
        coordinates = input.coordinates[1:-1, ...]

    if input.secondary_structure is not None:
        secondary_structure = decode_secondary_structure(
            input.secondary_structure, tokenizers.secondary_structure
        )

    if input.sasa is not None:
        sasa = decode_sasa(input.sasa, tokenizers.sasa)

    if input.function is not None:
        if function_token_decoder is None:
            raise ValueError(
                "Cannot decode function annotations without a function token decoder"
            )
        function_annotations.extend(
            decode_function_annotations(
                input.function,
                function_token_decoder=function_token_decoder,
                function_tokenizer=tokenizers.function,
            )
        )

    if input.residue_annotations is not None:
        function_annotations.extend(
            decode_residue_annotations(
                input.residue_annotations, tokenizers.residue_annotations
            )
        )

    return ESMProtein(
        sequence=sequence,
        secondary_structure=secondary_structure,
        sasa=sasa,
        function_annotations=function_annotations if function_annotations else None,
        coordinates=coordinates,
        plddt=plddt,
        ptm=ptm,
        potential_sequence_of_concern=input.potential_sequence_of_concern,
    )
| |
|
| |
|
def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
    """Warn when *tensor* is missing its BOS (first) or EOS (last) sentinel."""
    starts_with_bos = tensor[0] == tok.bos_token_id
    ends_with_eos = tensor[-1] == tok.eos_token_id
    if not starts_with_bos:
        warnings.warn(
            f"{msg} does not start with BOS token, token is ignored. BOS={tok.bos_token_id} vs {tensor}"
        )
    if not ends_with_eos:
        warnings.warn(
            f"{msg} does not end with EOS token, token is ignored. EOS='{tok.eos_token_id}': {tensor}"
        )
| |
|
| |
|
def decode_sequence(
    sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs
) -> str:
    """Decode sequence token ids into an amino-acid string.

    Tokenizer artifacts are stripped: whitespace and the CLS/EOS markers are
    removed, and mask tokens are collapsed to the short mask symbol.
    """
    _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
    decoded = sequence_tokenizer.decode(sequence_tokens, **kwargs)
    decoded = decoded.replace(" ", "")
    decoded = decoded.replace(sequence_tokenizer.mask_token, C.MASK_STR_SHORT)
    for special in (sequence_tokenizer.cls_token, sequence_tokenizer.eos_token):
        decoded = decoded.replace(special, "")
    return decoded
| |
|
| |
|
def decode_structure(
    structure_tokens: torch.Tensor,
    structure_decoder: StructureTokenDecoder,
    structure_tokenizer: StructureTokenizer,
    sequence: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Decode a single structure-token sequence into atom37 coordinates.

    Returns:
        ``(coordinates, plddt, ptm)`` where ``plddt``/``ptm`` are ``None``
        when the decoder does not produce them.

    Raises:
        ValueError: if more than one structure (a batched tensor) is given.
    """
    if structure_tokens.dim() != 1:
        raise ValueError(
            f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}"
        )
    # Decoder expects a batch dimension.
    structure_tokens = structure_tokens.unsqueeze(0)
    _bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer)

    decoder_output = structure_decoder.decode(structure_tokens)

    # Drop the batch dim and the BOS/EOS positions from the backbone prediction.
    bb_coords: torch.Tensor = decoder_output["bb_pred"][0, 1:-1, ...].detach().cpu()

    if "plddt" in decoder_output:
        plddt = decoder_output["plddt"][0, 1:-1].detach().cpu()
    else:
        plddt = None

    ptm = decoder_output["ptm"] if "ptm" in decoder_output else None

    # Rebuild a full atom37 representation (with inferred oxygens) from the
    # predicted backbone atoms.
    chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
    chain = chain.infer_oxygen()
    return torch.tensor(chain.atom37_positions), plddt, ptm
| |
|
| |
|
def decode_secondary_structure(
    secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer
) -> str:
    """Decode secondary-structure token ids (BOS/EOS stripped) into a string."""
    _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
    return ss_tokenizer.decode(secondary_structure_tokens[1:-1])
| |
|
| |
|
| | def decode_sasa( |
| | sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer |
| | ) -> list[float]: |
| | if sasa_tokens[0] != 0: |
| | raise ValueError("SASA does not start with 0 corresponding to BOS token") |
| | if sasa_tokens[-1] != 0: |
| | raise ValueError("SASA does not end with 0 corresponding to EOS token") |
| | sasa_tokens = sasa_tokens[1:-1] |
| | if sasa_tokens.dtype in [ |
| | torch.int8, |
| | torch.int16, |
| | torch.int32, |
| | torch.int64, |
| | torch.long, |
| | ]: |
| | |
| | |
| | sasa = sasa_tokenizer.decode_float(sasa_tokens) |
| | else: |
| | |
| | sasa = cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True)) |
| |
|
| | return sasa |
| |
|
| |
|
def decode_function_annotations(
    function_annotation_tokens: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokenizer: InterProQuantizedTokenizer,
    **kwargs,
) -> list[FunctionAnnotation]:
    """Decode function-track tokens into `FunctionAnnotation` records.

    Thin wrapper that delegates to the shared encode/decode utility.
    """
    return decode_function_tokens(
        function_annotation_tokens,
        function_token_decoder=function_token_decoder,
        function_tokens_tokenizer=function_tokenizer,
        **kwargs,
    )
| |
|
| |
|
def decode_residue_annotations(
    residue_annotation_tokens: torch.Tensor,
    residue_annotation_decoder: ResidueAnnotationsTokenizer,
) -> list[FunctionAnnotation]:
    """Decode residue-annotation tokens into `FunctionAnnotation` records.

    Thin wrapper that delegates to the shared encode/decode utility.
    """
    return decode_residue_annotation_tokens(
        residue_annotations_token_ids=residue_annotation_tokens,
        residue_annotations_tokenizer=residue_annotation_decoder,
    )
| |
|