| import warnings |
| from typing import cast |
|
|
| import attr |
| import torch |
|
|
| from src.data.esm.models.function_decoder import FunctionTokenDecoder |
| from src.data.esm.models.vqvae import StructureTokenDecoder |
| from src.data.esm.sdk.api import ESMProtein, ESMProteinTensor |
| from src.data.esm.tokenization import TokenizerCollectionProtocol |
| from src.data.esm.tokenization.function_tokenizer import ( |
| InterProQuantizedTokenizer, |
| ) |
| from src.data.esm.tokenization.residue_tokenizer import ( |
| ResidueAnnotationsTokenizer, |
| ) |
| from src.data.esm.tokenization.sasa_tokenizer import ( |
| SASADiscretizingTokenizer, |
| ) |
| from src.data.esm.tokenization.sequence_tokenizer import ( |
| EsmSequenceTokenizer, |
| ) |
| from src.data.esm.tokenization.ss_tokenizer import ( |
| SecondaryStructureTokenizer, |
| ) |
| from src.data.esm.tokenization.structure_tokenizer import ( |
| StructureTokenizer, |
| ) |
| from src.data.esm.tokenization.tokenizer_base import EsmTokenizerBase |
| from src.data.esm.utils.constants import esm3 as C |
| from src.data.esm.utils.function.encode_decode import ( |
| decode_function_tokens, |
| decode_residue_annotation_tokens, |
| ) |
| from src.data.esm.utils.misc import maybe_list |
| from src.data.esm.utils.structure.protein_chain import ProteinChain |
| from src.data.esm.utils.types import FunctionAnnotation |
|
|
|
|
def decode_protein_tensor(
    input: ESMProteinTensor,
    tokenizers: TokenizerCollectionProtocol,
    structure_token_decoder: StructureTokenDecoder,
    function_token_decoder: FunctionTokenDecoder | None = None,
) -> ESMProtein:
    """Decode an ``ESMProteinTensor`` into an ``ESMProtein``.

    Before decoding, each token track (every attrs field except ``coordinates``
    and ``potential_sequence_of_concern``) is inspected with its first and last
    positions removed:

    * if everything that remains equals the track tokenizer's ``pad_token_id``
      (this includes the case where nothing remains), the track is treated as absent;
    * the ``structure`` track is additionally treated as absent if it contains
      any ``mask_token_id``.

    Args:
        input: Token-level protein. It is copied with ``attr.evolve`` first, so
            nulling out tracks does not modify the caller's object.
        tokenizers: Collection exposing one tokenizer attribute per track name.
        structure_token_decoder: Decoder used when structure tokens are present.
        function_token_decoder: Needed only when the function track is present.

    Returns:
        An ``ESMProtein``. When structure tokens survive the checks, coordinates,
        pLDDT and pTM come from ``decode_structure``; otherwise any provided
        ``coordinates`` are passed through with their first and last positions
        dropped. ``function_annotations`` is ``None`` when nothing was decoded.

    Raises:
        ValueError: If the function track is present but no
            ``function_token_decoder`` was supplied.
    """
    input = attr.evolve(input)

    # These fields are not token tracks, so the pad/mask checks do not apply.
    skipped_fields = ("coordinates", "potential_sequence_of_concern")
    for field in attr.fields(ESMProteinTensor):
        field_name = field.name
        if field_name in skipped_fields:
            continue
        track_tokens: torch.Tensor | None = getattr(input, field_name)
        if track_tokens is None:
            continue
        # Drop the first/last positions and flatten multi-column tracks.
        interior = track_tokens[1:-1].flatten()
        track_tokenizer = getattr(tokenizers, field_name)
        only_padding = torch.all(interior == track_tokenizer.pad_token_id)
        has_masked_structure = field_name == "structure" and torch.any(
            interior == track_tokenizer.mask_token_id
        )
        if only_padding or has_masked_structure:
            setattr(input, field_name, None)

    sequence = (
        decode_sequence(input.sequence, tokenizers.sequence)
        if input.sequence is not None
        else None
    )

    coordinates, plddt, ptm = None, None, None
    if input.structure is not None:
        # The decoded sequence (possibly None) is forwarded to the structure decoder.
        coordinates, plddt, ptm = decode_structure(
            structure_tokens=input.structure,
            structure_decoder=structure_token_decoder,
            structure_tokenizer=tokenizers.structure,
            sequence=sequence,
        )
    elif input.coordinates is not None:
        coordinates = input.coordinates[1:-1, ...]

    secondary_structure = None
    if input.secondary_structure is not None:
        secondary_structure = decode_secondary_structure(
            input.secondary_structure, tokenizers.secondary_structure
        )

    sasa = None
    if input.sasa is not None:
        sasa = decode_sasa(input.sasa, tokenizers.sasa)

    # Function-track and residue-track annotations are merged into one list.
    annotations = []
    if input.function is not None:
        if function_token_decoder is None:
            raise ValueError(
                "Cannot decode function annotations without a function token decoder"
            )
        annotations += decode_function_annotations(
            input.function,
            function_token_decoder=function_token_decoder,
            function_tokenizer=tokenizers.function,
        )
    if input.residue_annotations is not None:
        annotations += decode_residue_annotations(
            input.residue_annotations, tokenizers.residue_annotations
        )

    return ESMProtein(
        sequence=sequence,
        secondary_structure=secondary_structure,
        sasa=sasa,
        function_annotations=annotations or None,
        coordinates=coordinates,
        plddt=plddt,
        ptm=ptm,
        potential_sequence_of_concern=input.potential_sequence_of_concern,
    )
|
|
|
|
def _bos_eos_warn(msg: str, tensor: torch.Tensor, tok: EsmTokenizerBase):
    """Warn, without raising, when ``tensor`` is not framed by BOS/EOS tokens.

    Compares the first element with ``tok.bos_token_id`` and the last element
    with ``tok.eos_token_id``. Each mismatch emits one ``warnings.warn``; the
    tensor is never modified.

    Args:
        msg: Track name used as the prefix of the warning text.
        tensor: Token ids to check. Must be non-empty, and indexing it must
            yield scalars.
        tok: Tokenizer supplying ``bos_token_id`` and ``eos_token_id``.
    """
    starts_with_bos = bool(tensor[0] == tok.bos_token_id)
    ends_with_eos = bool(tensor[-1] == tok.eos_token_id)
    if not starts_with_bos:
        warnings.warn(
            f"{msg} does not start with BOS token, token is ignored. BOS={tok.bos_token_id} vs {tensor}"
        )
    if not ends_with_eos:
        warnings.warn(
            f"{msg} does not end with EOS token, token is ignored. EOS='{tok.eos_token_id}': {tensor}"
        )
|
|
|
|
def decode_sequence(
    sequence_tokens: torch.Tensor, sequence_tokenizer: EsmSequenceTokenizer, **kwargs
) -> str:
    """Turn sequence token ids into a compact amino-acid string.

    Warns if the tokens are not framed by BOS/EOS. The ids are decoded with
    ``sequence_tokenizer.decode`` (extra ``kwargs`` are forwarded). The result is
    then cleaned up in this order:

    1. spaces are removed;
    2. the tokenizer's mask token becomes ``C.MASK_STR_SHORT``;
    3. the CLS and EOS token strings are removed.
    """
    _bos_eos_warn("Sequence", sequence_tokens, sequence_tokenizer)
    decoded = sequence_tokenizer.decode(sequence_tokens, **kwargs)
    # Applied in order; the order matches the numbered list in the docstring.
    substitutions = (
        (" ", ""),
        (sequence_tokenizer.mask_token, C.MASK_STR_SHORT),
        (sequence_tokenizer.cls_token, ""),
        (sequence_tokenizer.eos_token, ""),
    )
    for needle, replacement in substitutions:
        decoded = decoded.replace(needle, replacement)
    return decoded
|
|
|
|
def decode_structure(
    structure_tokens: torch.Tensor,
    structure_decoder: StructureTokenDecoder,
    structure_tokenizer: StructureTokenizer,
    sequence: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Decode a single structure-token sequence into atom37 coordinates.

    Args:
        structure_tokens: 1-D tensor of structure tokens. It is expected to start
            with BOS and end with EOS; a warning is emitted otherwise.
        structure_decoder: Its ``decode`` output must contain ``"bb_pred"`` and
            may contain ``"plddt"`` and ``"ptm"``.
        structure_tokenizer: Supplies the BOS/EOS ids for the framing check.
        sequence: Optional sequence forwarded to
            ``ProteinChain.from_backbone_atom_coordinates``.

    Returns:
        ``(atom37_coordinates, plddt, ptm)``. The BOS/EOS positions are stripped
        from the coordinates and from ``plddt``. ``plddt`` and ``ptm`` are
        ``None`` when the decoder does not provide them. Returned tensors are
        detached from autograd and live on the CPU.

    Raises:
        ValueError: If ``structure_tokens`` is not 1-D.
    """
    if structure_tokens.dim() != 1:
        raise ValueError(
            f"Only one structure can be decoded at a time, got structure tokens of shape {structure_tokens.size()}"
        )
    # The decoder operates on batches; add a batch dimension of size 1.
    structure_tokens = structure_tokens.unsqueeze(0)
    _bos_eos_warn("Structure", structure_tokens[0], structure_tokenizer)

    decoder_output = structure_decoder.decode(structure_tokens)
    # Select the single batch entry and drop the BOS/EOS positions.
    bb_coords: torch.Tensor = decoder_output["bb_pred"][0, 1:-1, ...]
    bb_coords = bb_coords.detach().cpu()

    # Treat a key that is present but set to None the same as a missing key,
    # instead of crashing on the subscript.
    plddt = decoder_output["plddt"] if "plddt" in decoder_output else None
    if plddt is not None:
        plddt = plddt[0, 1:-1].detach().cpu()

    ptm = decoder_output["ptm"] if "ptm" in decoder_output else None
    if isinstance(ptm, torch.Tensor):
        # Match bb_coords/plddt: do not hand back a graph-attached,
        # accelerator-resident tensor.
        ptm = ptm.detach().cpu()

    chain = ProteinChain.from_backbone_atom_coordinates(bb_coords, sequence=sequence)
    chain = chain.infer_oxygen()
    return torch.tensor(chain.atom37_positions), plddt, ptm
|
|
|
|
def decode_secondary_structure(
    secondary_structure_tokens: torch.Tensor, ss_tokenizer: SecondaryStructureTokenizer
) -> str:
    """Decode secondary-structure tokens into a string.

    Warns if the tokens are not framed by BOS/EOS. The first and last positions
    are always dropped before ``ss_tokenizer.decode`` is called.
    """
    _bos_eos_warn("Secondary structure", secondary_structure_tokens, ss_tokenizer)
    return ss_tokenizer.decode(secondary_structure_tokens[1:-1])
|
|
|
|
def decode_sasa(
    sasa_tokens: torch.Tensor, sasa_tokenizer: SASADiscretizingTokenizer
) -> list[float]:
    """Convert a SASA track into a per-residue list of values.

    The first and last entries must both be ``0``; these are the placeholders at
    the BOS/EOS positions, and they are stripped.

    * Integer tensors (any of int8/int16/int32/int64/uint8) are treated as
      discrete tokens and decoded with ``sasa_tokenizer.decode_float``.
    * Any other dtype is treated as raw values and converted with ``maybe_list``,
      called with ``convert_nan_to_none=True``.

    Raises:
        ValueError: If the first or last entry is not 0.
    """
    if sasa_tokens[0] != 0:
        raise ValueError("SASA does not start with 0 corresponding to BOS token")
    if sasa_tokens[-1] != 0:
        raise ValueError("SASA does not end with 0 corresponding to EOS token")
    sasa_tokens = sasa_tokens[1:-1]

    # torch.uint8 is included so compact unsigned token tensors are decoded
    # rather than returned as raw ids (torch.long is an alias of torch.int64).
    integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    if sasa_tokens.dtype in integer_dtypes:
        # Discrete tokens: map them back to values via the tokenizer.
        sasa = sasa_tokenizer.decode_float(sasa_tokens)
    else:
        # Already continuous values: just convert to a Python list.
        sasa = cast(list[float], maybe_list(sasa_tokens, convert_nan_to_none=True))
    return sasa
|
|
|
|
def decode_function_annotations(
    function_annotation_tokens: torch.Tensor,
    function_token_decoder: FunctionTokenDecoder,
    function_tokenizer: InterProQuantizedTokenizer,
    **kwargs,
) -> list[FunctionAnnotation]:
    """Decode function-track tokens into ``FunctionAnnotation`` objects.

    This is a thin adapter over ``decode_function_tokens``. ``function_tokenizer``
    is passed on as that function's ``function_tokens_tokenizer`` argument, and
    any extra ``kwargs`` are forwarded unchanged.
    """
    return decode_function_tokens(
        function_annotation_tokens,
        function_token_decoder=function_token_decoder,
        function_tokens_tokenizer=function_tokenizer,
        **kwargs,
    )
|
|
|
|
def decode_residue_annotations(
    residue_annotation_tokens: torch.Tensor,
    residue_annotation_decoder: ResidueAnnotationsTokenizer,
) -> list[FunctionAnnotation]:
    """Decode residue-annotation tokens into ``FunctionAnnotation`` objects.

    This is a thin adapter over ``decode_residue_annotation_tokens``. Despite its
    name, ``residue_annotation_decoder`` is a ``ResidueAnnotationsTokenizer``; it
    is passed on as that function's ``residue_annotations_tokenizer`` argument.
    """
    return decode_residue_annotation_tokens(
        residue_annotations_token_ids=residue_annotation_tokens,
        residue_annotations_tokenizer=residue_annotation_decoder,
    )
|
|