Model description
PAIR (paper) is a flexible fine-tuning framework for improving the quality of protein representations for function prediction. PAIR uses a text decoder to guide the fine-tuning of a protein encoder, so that the learned representations capture the information contained in the diverse set of annotations in Swiss-Prot. This model is ESM2-650M (repo) fine-tuned with PAIR.
Intended use
The model can be used for feature extraction in protein function prediction tasks.
How to load the model for feature extraction
import torch
from transformers import AutoTokenizer, AutoModel

# The tokenizer comes from the base ESM2 checkpoint; the weights are the PAIR fine-tuned ones.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModel.from_pretrained("h4duan/PAIR-esm2").to("cuda")

protein = ["AETCZAO"]

def extract_feature(protein):
    ids = tokenizer(protein, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
    # The tokenizer already returns tensors, so they only need to be moved to the GPU.
    input_ids = ids["input_ids"].to("cuda")
    attention_mask = ids["attention_mask"].to("cuda")
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    # Mean-pool over the token dimension to get one vector per sequence.
    return torch.mean(embedding_repr, dim=1)

feature = extract_feature(protein)
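The mean over the token dimension in extract_feature averages every position. That is safe for a single sequence, where no padding is added, but when sequences of different lengths are padded to a common length, the padding positions would be averaged in as well. The plain-Python sketch below illustrates the difference on made-up numbers. The names plain_mean and masked_mean and the toy values are illustrative assumptions, not part of PAIR or transformers.

```python
def plain_mean(token_vectors):
    """Average every token vector, padding positions included."""
    n = len(token_vectors)
    dim = len(token_vectors[0])
    return [sum(v[d] for v in token_vectors) / n for d in range(dim)]


def masked_mean(token_vectors, attention_mask):
    """Average only the token vectors whose mask entry is 1."""
    dim = len(token_vectors[0])
    kept = [v for v, m in zip(token_vectors, attention_mask) if m == 1]
    return [sum(v[d] for v in kept) / len(kept) for d in range(dim)]


# Toy "embeddings" (hypothetical values): 2 real tokens followed by
# 2 padding positions, with a hidden size of 2.
toy_embeddings = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]]
toy_mask = [1, 1, 0, 0]

print(plain_mean(toy_embeddings))             # padding positions leak into the result
print(masked_mean(toy_embeddings, toy_mask))  # only the two real tokens contribute
```

With these numbers the plain mean is pulled toward the padding values, while the masked mean depends only on the real tokens.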
How to extract features in batches
proteins = ["AETCZAO", "SKTZP"]

def extract_features_batch(proteins):
    ids = tokenizer(proteins, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
    input_ids = ids["input_ids"].to("cuda")
    attention_mask = ids["attention_mask"].to("cuda")
    with torch.no_grad():
        embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    # Zero out padding positions so they do not contribute to the mean.
    attention_mask = attention_mask.unsqueeze(-1)
    attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
    masked_embedding_repr = embedding_repr * attention_mask
    sum_embedding_repr = masked_embedding_repr.sum(dim=1)
    # Divide by the number of real (non-padding) tokens in each sequence.
    non_zero_count = attention_mask.sum(dim=1)
    mean_embedding_repr = sum_embedding_repr / non_zero_count
    return mean_embedding_repr

feature = extract_features_batch(proteins)
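For a list of sequences too large to fit on the GPU in one call, one option is to split it into fixed-size chunks and run the extraction function on each chunk. The sketch below is a minimal illustration of that idea. The helper names iter_batches and extract_in_chunks, the default batch_size of 8, and the stand-in fake_extract function are all assumptions made for the example, not part of PAIR or transformers.

```python
def iter_batches(sequences, batch_size):
    """Yield consecutive slices of at most `batch_size` sequences."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(sequences), batch_size):
        yield sequences[start:start + batch_size]


def extract_in_chunks(sequences, extract_fn, batch_size=8):
    """Call `extract_fn` on each chunk and collect the per-chunk results in order."""
    return [extract_fn(chunk) for chunk in iter_batches(sequences, batch_size)]


# Stand-in for extract_features_batch so the sketch runs without the model:
# it simply reports the length of each sequence in the chunk.
def fake_extract(chunk):
    return [len(seq) for seq in chunk]


chunks = extract_in_chunks(["AETCZAO", "SKTZP", "MKT"], fake_extract, batch_size=2)
print(chunks)
```

With the real model, extract_features_batch could be passed as extract_fn, and the per-chunk tensors could then be joined with torch.cat(results, dim=0) to obtain one feature matrix.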