import logging
import re
from abc import ABC, abstractmethod
from functools import partial
from types import SimpleNamespace
from typing import Dict, List, Literal, Optional

import numpy as np
import torch
import tqdm
from datasets import Dataset
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BatchEncoding,
    DefaultDataCollator,
    T5EncoderModel,
    T5Tokenizer,
)
from transformers.modeling_outputs import BaseModelOutput

from .modality import Modality
from .eval_utils import ForwardHook, pool

logger = logging.getLogger(__name__)


class BioSeqTransformer(ABC):
    """
    Abstract class wrapping models that map biological sequences (DNA or protein) to embeddings.
    Modelled after SentenceTransformer (https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py).

    Args:
        model_name: Name or path of the pretrained model.
        layers: Layers to probe, given as a list of layer indices or as "mid" or "last".
        devices: List of device ids for inference. If CUDA is not available, the CPU is used.
        num_processes: Number of worker processes for data loading.
        max_seq_length: Maximum input sequence length; longer sequences are truncated.
        l2_norm: If True, embeddings are L2-normalized before they are returned.
        batch_size: Batch size for encoding; when multiple devices are used, this is the per-device batch size.
        pool_type: Pooling strategy to use. One of "mean", "max", "cls", "last".
    """

    def __init__(
        self,
        model_name: str,
        layers: Optional[List[int] | Literal["mid"] | Literal["last"]] = None,
        devices: List[int] = [0],
        num_processes: int = 16,
        max_seq_length: int = 1024,
        l2_norm: bool = False,
        batch_size: int = 128,
        pool_type: str = "mean",
    ):
        super().__init__()

        self.id = self.__class__.__name__
        self.hf_name = model_name
        self.encoder = self._load_model(model_name)
        if not hasattr(self.encoder, "config"):
            raise ValueError(
                'The model returned by `self._load_model()` must have a "config" attribute.'
            )
        self.config = self.encoder.config
        self.tokenizer = self._get_tokenizer(model_name)
        self.num_param = sum(p.numel() for p in self.encoder.parameters())
        self.data_collator = DefaultDataCollator()
        self.gpu_count = len(devices)
        self.l2_norm = l2_norm

        self.device = torch.device(
            f"cuda:{devices[0]}" if torch.cuda.is_available() else "cpu"
        )
        self.num_processes = num_processes
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.pool_type = pool_type

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder, device_ids=devices)
        self.encoder.to(self.device)
        self.encoder.eval()

        # Resolve the layer indices to probe and human-readable labels for them.
        mid_layer = self.num_layers // 2
        last_layer = self.num_layers - 1
        mid_layer_label = f"mid ({mid_layer})"
        last_layer_label = f"last ({last_layer})"

        if layers is None:
            logger.debug(f"Using default layers: {mid_layer_label}, {last_layer_label}")
            self.layers = [mid_layer, last_layer]
            self.layer_labels = [mid_layer_label, last_layer_label]
        elif layers == "mid":
            self.layers = [mid_layer]
            self.layer_labels = [mid_layer_label]
        elif layers == "last":
            self.layers = [last_layer]
            self.layer_labels = [last_layer_label]
        else:
            self.layers = layers
            self.layer_labels = [str(layer) for layer in layers]

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        """Returns the output embedding for the given batch with shape [batch, num_layers, D]."""
        outputs = self.encoder(**batch_dict, output_hidden_states=True)
        # Pool each probed layer's token embeddings into one vector per sequence.
        embeds = [outputs.hidden_states[layer] for layer in self.layers]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]

        embeds = torch.stack(embeds, dim=1)
        return embeds

    def _load_model(self, model_name):
        return AutoModel.from_pretrained(model_name, trust_remote_code=True)

    def _get_tokenizer(self, model_name):
        return AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    def _tokenize_func(
        self, tokenizer, examples: Dict[str, List], max_seq_length: int
    ) -> BatchEncoding:
        batch_dict = tokenizer(
            examples["input_seqs"],
            max_length=max_seq_length,
            padding=True,
            truncation=True,
        )
        return batch_dict

    @property
    def metadata(self) -> Dict:
        return {
            "hf_name": self.hf_name,
            "num_layers": self.num_layers,
            "num_params": self.num_param,
            "embed_dim": self.embed_dim,
        }

    @property
    @abstractmethod
    def num_layers(self) -> int:
        pass

    @property
    @abstractmethod
    def embed_dim(self) -> int:
        pass

    @property
    @abstractmethod
    def modality(self) -> Modality:
        pass

    @torch.no_grad()
    def encode(self, sequences, **kwargs) -> np.ndarray:
        """Returns the embeddings of the given sequences.

        Args:
            sequences (`List[str]`): List of sequences to encode.

        Returns:
            `np.ndarray`: Embeddings of shape [num_sequences, num_layers, embedding_dim],
            where num_layers is the number of probed layers.
        """
        dataset = Dataset.from_dict({"input_seqs": sequences})
        dataset.set_transform(
            partial(
                self._tokenize_func, self.tokenizer, max_seq_length=self.max_seq_length
            )
        )
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size * self.gpu_count,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_processes,
            collate_fn=self.data_collator,
            pin_memory=True,
        )

        if max(self.layers) >= self.num_layers:
            raise ValueError(
                f"Layer {max(self.layers)} is not available in the model. Choose a layer between 0 and {self.num_layers - 1}."
            )

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(
            data_loader, desc="encoding", mininterval=10, disable=len(sequences) < 128
        ):
            batch_dict = {k: v.to(self.device) for k, v in batch_dict.items()}

            embeds = self._encode_single_batch(batch_dict)

            if self.l2_norm:
                embeds = F.normalize(embeds, p=2, dim=-1)
            encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)
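

# Sketch: supporting a new Hugging Face encoder mainly means exposing the three
# abstract properties declared above; `_load_model`/`_get_tokenizer` already default
# to AutoModel/AutoTokenizer. The class name, checkpoint, and config attribute names
# below are hypothetical, shown only to illustrate the interface.
#
#   class MyDnaEncoder(BioSeqTransformer):
#       MODEL_NAMES = ["example-org/my-dna-encoder"]  # hypothetical checkpoint
#
#       @property
#       def modality(self) -> Modality:
#           return Modality.DNA
#
#       @property
#       def num_layers(self) -> int:
#           return self.config.num_hidden_layers
#
#       @property
#       def embed_dim(self) -> int:
#           return self.config.hidden_size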


class ESM(BioSeqTransformer):
    """ESM model from https://huggingface.co/docs/transformers/en/model_doc/esm"""

    MODEL_NAMES = [
        "facebook/esm2_t6_8M_UR50D",
        "facebook/esm2_t12_35M_UR50D",
        "facebook/esm2_t30_150M_UR50D",
        "facebook/esm2_t33_650M_UR50D",
        "facebook/esm2_t36_3B_UR50D",
        "facebook/esm2_t48_15B_UR50D",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size
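

# Usage sketch (assumes the checkpoint can be fetched from the Hugging Face Hub;
# the sequence is illustrative):
#
#   model = ESM("facebook/esm2_t6_8M_UR50D", layers="last", batch_size=8)
#   embeds = model.encode(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"])
#   # embeds.shape == (1, 1, model.embed_dim): one sequence, one probed layer.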


class ESM3(BioSeqTransformer):
    """ESM3 model from https://github.com/evolutionaryscale/esm"""

    MODEL_NAMES = ["esm3_sm_open_v1"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # ESM3's forward() does not return per-layer hidden states, so capture
        # them with forward hooks on the probed transformer blocks.
        self.hooks = [
            ForwardHook(self.encoder.transformer.blocks[layer]) for layer in self.layers
        ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size

    def _load_model(self, model_name):
        try:
            from esm.models.esm3 import ESM3 as ModelESM3
        except ImportError:
            raise ImportError(
                "ESM3 is not installed. Please install it with `pip install esm`."
            )
        model = ModelESM3.from_pretrained("esm3_sm_open_v1")
        # ESM3 is not a Hugging Face model, so synthesize the minimal config
        # the base class expects (layer count and hidden size).
        model.config = SimpleNamespace(
            num_hidden_layers=len(model.transformer.blocks),
            hidden_size=model.transformer.blocks[0].ffn[-1].out_features,
        )
        return model

    def _get_tokenizer(self, model_name):
        try:
            from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
        except ImportError:
            raise ImportError(
                "ESM3 is not installed. Please install it with `pip install esm`."
            )
        return EsmSequenceTokenizer()

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        _ = self.encoder.forward(sequence_tokens=batch_dict["input_ids"])
        # Read the per-layer activations captured by the forward hooks.
        embeds = [hook.output for hook in self.hooks]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]

        embeds = torch.stack(embeds, dim=1)
        embeds = embeds.to(torch.float32)
        return embeds
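

# The ForwardHook helper from .eval_utils is used above because ESM3 does not
# expose hidden states the way transformers models do. A minimal sketch of the
# same idea in plain PyTorch, assuming ForwardHook wraps `register_forward_hook`
# roughly like this:
#
#   class _Hook:
#       def __init__(self, module: torch.nn.Module):
#           self.output = None
#           module.register_forward_hook(self._capture)
#
#       def _capture(self, module, inputs, output):
#           self.output = output
#
#   hook = _Hook(model.transformer.blocks[6])
#   _ = model.forward(sequence_tokens=input_ids)
#   activations = hook.output  # that block's output for this forward pass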


class ProtT5(BioSeqTransformer):
    """ProtT5 model from https://github.com/agemagician/ProtTrans"""

    MODEL_NAMES = [
        "Rostlab/prot_t5_xl_uniref50",
        "Rostlab/prot_t5_xl_bfd",
        "Rostlab/prot_t5_xxl_uniref50",
        "Rostlab/prot_t5_xxl_bfd",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.num_layers

    @property
    def embed_dim(self) -> int:
        return self.config.d_model

    def _load_model(self, model_name):
        return T5EncoderModel.from_pretrained(model_name)

    def _get_tokenizer(self, model_name):
        return T5Tokenizer.from_pretrained(model_name, do_lower_case=False)

    def _tokenize_func(
        self, tokenizer, examples: Dict[str, List], max_seq_length: int
    ) -> BatchEncoding:
        example_sequences = examples["input_seqs"]

        # ProtT5 expects space-separated residues, with the rare amino acids
        # U, Z, O, and B mapped to X.
        example_sequences = [" ".join(sequence) for sequence in example_sequences]
        example_sequences = [
            re.sub(r"[UZOB]", "X", sequence) for sequence in example_sequences
        ]
        batch_dict = tokenizer(
            example_sequences,
            max_length=max_seq_length,
            padding=True,
            truncation=True,
            add_special_tokens=True,
        )

        return batch_dict
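

# Worked example of the preprocessing above: residues are space-separated so the
# tokenizer sees one token per amino acid, and the rare residues U, Z, O, B are
# mapped to X:
#
#   seq = "MKTUZOB"
#   spaced = " ".join(seq)               # "M K T U Z O B"
#   re.sub(r"[UZOB]", "X", spaced)       # "M K T X X X X"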


class ProGen(BioSeqTransformer):
    """ProGen models from https://github.com/salesforce/progen."""

    MODEL_NAMES = [
        "hugohrban/progen2-small",
        "hugohrban/progen2-medium",
        "hugohrban/progen2-base",
        "hugohrban/progen2-large",
        "hugohrban/progen2-xlarge",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.PROTEIN

    @property
    def num_layers(self) -> int:
        return self.config.n_layer

    @property
    def embed_dim(self) -> int:
        return self.config.embed_dim

    def _load_model(self, model_name):
        return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

    def _get_tokenizer(self, model_name_or_path):
        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        tokenizer.pad_token = "<|pad|>"
        return tokenizer

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        """Returns the output embedding for the given batch with shape [batch, num_layers, D]."""
        outputs: BaseModelOutput = self.encoder(
            input_ids=batch_dict["input_ids"],
            output_hidden_states=True,
            use_cache=False,
        )
        embeds = [outputs.hidden_states[layer] for layer in self.layers]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]

        embeds = torch.stack(embeds, dim=1)
        return embeds


class EvoModel(BioSeqTransformer):
    """Evo model from https://github.com/evo-design/evo."""

    MODEL_NAMES = [
        "togethercomputer/evo-1-8k-base",
        "togethercomputer/evo-1-131k-base",
    ]

    @property
    def modality(self) -> Modality:
        return Modality.DNA

    @property
    def num_layers(self) -> int:
        return self.config.num_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Evo does not return hidden states, so capture them with forward hooks.
        # For the final layer, hook the backbone's output norm; otherwise hook
        # the corresponding backbone block.
        self.hooks = []
        for layer in self.layers:
            if layer == self.num_layers - 1 or layer == -1:
                self.hooks.append(ForwardHook(self.encoder.backbone.norm))
            else:
                self.hooks.append(ForwardHook(self.encoder.backbone.blocks[layer]))

    def _load_model(self, model_name):
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=True, revision="1.1_fix"
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name, config=config, trust_remote_code=True, revision="1.1_fix"
        )
        return model

    def _get_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, revision="1.1_fix", trust_remote_code=True
        )

        tokenizer.add_special_tokens({"pad_token": "N"})
        return tokenizer

    def _encode_single_batch(self, batch_dict: Dict[str, Tensor]):
        _ = self.encoder(batch_dict["input_ids"], use_cache=False)
        embeds = [hook.output for hook in self.hooks]

        # Some hooked modules return tuples; keep only the hidden-state tensor.
        embeds = [x[0] if isinstance(x, tuple) else x for x in embeds]
        embeds = [
            pool(layer_embeds, batch_dict["attention_mask"], self.pool_type)
            for layer_embeds in embeds
        ]

        embeds = torch.stack(embeds, dim=1)
        embeds = embeds.to(torch.float32)
        return embeds
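

# Usage sketch for the DNA wrappers (assumes the pinned "1.1_fix" revision is
# available on the Hub; the sequences are illustrative):
#
#   model = EvoModel("togethercomputer/evo-1-8k-base", layers="last")
#   embeds = model.encode(["ACGTACGTACGT", "GGGCCCAAATTT"])
#   # embeds.shape == (2, 1, model.embed_dim)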


class NTModel(BioSeqTransformer):
    """Nucleotide Transformer from https://github.com/instadeepai/nucleotide-transformer"""

    MODEL_NAMES = [
        "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species",
        "InstaDeepAI/nucleotide-transformer-v2-100m-multi-species",
        "InstaDeepAI/nucleotide-transformer-v2-250m-multi-species",
        "InstaDeepAI/nucleotide-transformer-v2-500m-multi-species",
        "InstaDeepAI/nucleotide-transformer-2.5b-multi-species",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The tokenizer defines the model's maximum context length, so defer to it.
        self.max_seq_length = self.tokenizer.model_max_length

    @property
    def modality(self) -> Modality:
        return Modality.DNA

    @property
    def num_layers(self) -> int:
        return self.config.num_hidden_layers

    @property
    def embed_dim(self) -> int:
        return self.config.hidden_size

    def _load_model(self, model_name):
        return AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)