import torch

from tqdm import tqdm
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer
from torch.nn.functional import normalize


class StructureEncoder(torch.nn.Module):
    def __init__(self, config_path: str, out_dim: int, gradient_checkpointing: bool = False):
        """
        Args:
            config_path: Path to the config file
            out_dim: Output dimension of the structure representation
            gradient_checkpointing: Whether to use gradient checkpointing
        """
        super().__init__()
        config = EsmConfig.from_pretrained(config_path)
        self.model = EsmForMaskedLM(config)
        self.out = torch.nn.Linear(config.hidden_size, out_dim)

        # Set gradient checkpointing
        self.model.esm.encoder.gradient_checkpointing = gradient_checkpointing

        # Remove contact head
        self.model.esm.contact_head = None

        # Remove position embedding if the embedding type is ``rotary``
        if config.position_embedding_type == "rotary":
            self.model.esm.embeddings.position_embeddings = None

        self.tokenizer = EsmTokenizer.from_pretrained(config_path)

    def get_repr(self, proteins: list, batch_size: int = 64, verbose: bool = False) -> torch.Tensor:
        """
        Compute protein structure representations for the given proteins

        Args:
            proteins: A list of protein structural sequences
            batch_size: Batch size for inference
            verbose: Whether to print progress
        """
        device = next(self.parameters()).device

        protein_repr = []
        if verbose:
            iterator = tqdm(range(0, len(proteins), batch_size), desc="Computing protein embeddings")
        else:
            iterator = range(0, len(proteins), batch_size)

        for i in iterator:
            protein_inputs = self.tokenizer.batch_encode_plus(proteins[i:i + batch_size],
                                                              return_tensors="pt",
                                                              padding=True)
            protein_inputs = {k: v.to(device) for k, v in protein_inputs.items()}
            output, _ = self.forward(protein_inputs)
            protein_repr.append(output)

        protein_repr = torch.cat(protein_repr, dim=0)
        return normalize(protein_repr, dim=-1)

    def forward(self, inputs: dict, get_mask_logits: bool = False):
        """
        Encode protein structure into protein representation

        Args:
            inputs: A dictionary containing the following keys:
                - input_ids: [batch, seq_len]
                - attention_mask: [batch, seq_len]
            get_mask_logits: Whether to return the logits for masked tokens

        Returns:
            protein_repr: [batch, protein_repr_dim]
            mask_logits : [batch, seq_len, vocab_size]
        """
        last_hidden_state = self.model.esm(**inputs).last_hidden_state
        reprs = last_hidden_state[:, 0, :]
        reprs = self.out(reprs)

        # Get logits for masked tokens
        if get_mask_logits:
            mask_logits = self.model.lm_head(last_hidden_state)
        else:
            mask_logits = None

        return reprs, mask_logits
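

# Minimal usage sketch: the config/tokenizer path, the structural token strings,
# and out_dim=1024 are illustrative assumptions, not values required by the class.
# The path must point to a directory that EsmConfig/EsmTokenizer can load, and the
# strings must use the vocabulary of that tokenizer.
if __name__ == "__main__":
    encoder = StructureEncoder(config_path="path/to/esm_config", out_dim=1024)
    encoder.eval()

    # Each entry is one protein structural sequence accepted by the tokenizer
    structural_seqs = ["dpqvvcvvdd", "dplvvqwkdf"]

    with torch.no_grad():
        repr = encoder.get_repr(structural_seqs, batch_size=2, verbose=True)

    # Expected shape: [len(structural_seqs), out_dim]; rows are L2-normalized
    print(repr.shape)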