from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import torch
import os


class FusOnTokenizer:
    """
    FusOnTokenizer class: a wrapper around AutoTokenizer
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying tokenizer.
        This allows calls like .tokenize() and .save_pretrained() to be
        forwarded to the tokenizer.
        """
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnTokenizer object callable, delegating to the
        tokenizer's __call__ method.
        """
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_directory):
        self.tokenizer.save_pretrained(save_directory)

    def load_tokenizer(self, load_directory):
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)


class FusOnpLM:
    """
    FusOn-pLM class: a wrapper around AutoModelForMaskedLM
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D',
                 ckpt_path=None, mlm_head=False):
        if ckpt_path is not None:
            self.load_model(ckpt_path, mlm_head)
        else:
            # Load the pre-trained model and tokenizer
            self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
            self.tokenizer = FusOnTokenizer(pretrained_path)
        self.n_layers = self.count_encoder_layers()

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying model.
        This allows calls like .to(), .train(), and .eval() to be forwarded
        to the model.
        """
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnpLM object callable, delegating to the model's
        __call__ method.
        """
        return self.model(*args, **kwargs)

    def _encoder_layers(self):
        """
        Return the encoder layers. AutoModelForMaskedLM nests the encoder
        under .esm, while AutoModel (no MLM head) exposes it directly.
        """
        base = self.model.esm if hasattr(self.model, 'esm') else self.model
        return base.encoder.layer

    def freeze_model(self):
        """
        Freezes all parameters in the model.
        """
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True,
                               unfreeze_key=True, unfreeze_value=True):
        """
        Unfreezes specific parts of the final n layers in the model's encoder.

        Args:
            n_unfrozen_layers (int): Number of final layers to unfreeze.
            unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
            unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
            unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
        """
        for i, layer in enumerate(self._encoder_layers()):
            if (self.n_layers - i) <= n_unfrozen_layers:  # only the last n layers
                if unfreeze_query:
                    self._unfreeze_parameters(layer.attention.self.query)
                if unfreeze_key:
                    self._unfreeze_parameters(layer.attention.self.key)
                if unfreeze_value:
                    self._unfreeze_parameters(layer.attention.self.value)

    def _unfreeze_parameters(self, module):
        """
        Helper method to unfreeze parameters in a given module.

        Args:
            module (nn.Module): The module whose parameters are to be unfrozen.
        """
        for param in module.parameters():
            param.requires_grad = True

    def count_encoder_layers(self):
        """
        Count the number of encoder layers in the model.
""" return len(self.model.esm.encoder.layer) def save_model(self, save_directory, optimizer=None): # Save the model and tokenizer self.model.save_pretrained(save_directory) self.tokenizer.save_pretrained(save_directory) # If an optimizer is provided, save its state dict if optimizer is not None: optimizer_path = os.path.join(save_directory, "optimizer.pt") torch.save(optimizer.state_dict(), optimizer_path) def load_model(self, load_directory, mlm_head): # Load a checkpoint of the model either with or without an MLM head if mlm_head: self.model = AutoModelForMaskedLM.from_pretrained(load_directory) else: # Load the model and tokenizer from a directory self.model = AutoModel.from_pretrained(load_directory) self.tokenizer = AutoTokenizer.from_pretrained(load_directory)