Model description

PAIR (paper) is a flexible fine-tuning framework that improves the quality of protein representations for function prediction. PAIR uses a text decoder to guide the fine-tuning of a protein encoder, so that the learned representations capture information from the diverse annotations in Swiss-Prot. This model is ESM2-650M (repo) fine-tuned with PAIR.
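The general idea of decoder-guided fine-tuning can be illustrated with a toy PyTorch sketch. A protein encoder produces per-residue states, a text decoder cross-attends to them while predicting annotation tokens, and the decoder's loss back-propagates into the encoder. Everything here is a hypothetical stand-in chosen for illustration: `ToyEncoderTextDecoder`, the layer and vocabulary sizes, the random token IDs, and the next-token objective. It is not the PAIR implementation and uses none of its weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyEncoderTextDecoder(nn.Module):
    """Hypothetical toy model of text-guided encoder fine-tuning (not PAIR's code)."""

    def __init__(self, protein_vocab=33, text_vocab=100, d_model=64, nhead=4):
        super().__init__()
        self.protein_embed = nn.Embedding(protein_vocab, d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2)
        self.text_embed = nn.Embedding(text_vocab, d_model)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, batch_first=True), num_layers=2)
        self.lm_head = nn.Linear(d_model, text_vocab)

    def forward(self, protein_ids, text_ids):
        # Per-residue protein representations produced by the encoder.
        memory = self.encoder(self.protein_embed(protein_ids))
        # Teacher forcing: the decoder reads annotation tokens 0..n-2 and predicts 1..n-1.
        tgt = self.text_embed(text_ids[:, :-1])
        n = tgt.size(1)
        # Causal mask: -inf above the diagonal blocks attention to future tokens.
        causal_mask = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
        hidden = self.decoder(tgt, memory, tgt_mask=causal_mask)
        logits = self.lm_head(hidden)
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               text_ids[:, 1:].reshape(-1))

torch.manual_seed(0)
toy = ToyEncoderTextDecoder()
optimizer = torch.optim.Adam(toy.parameters(), lr=1e-3)
protein_ids = torch.randint(0, 33, (2, 20))  # 2 fake proteins, 20 residues each
text_ids = torch.randint(0, 100, (2, 12))    # 2 fake tokenized annotations
loss = toy(protein_ids, text_ids)
loss.backward()    # the text loss also produces gradients for the encoder
optimizer.step()
```

The point of the sketch is the gradient path. Because the decoder attends to the encoder's outputs, minimizing the text loss also updates the encoder's weights.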

Intended use

The model can be used to extract protein features for function prediction tasks.
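One common way to use such features is to train a small classifier on top of them while the encoder stays frozen. The sketch below is an illustration under assumptions and is not part of the PAIR release. It replaces real embeddings with synthetic, class-shifted random vectors of width 1280, which is assumed to match PAIR-esm2 because that is the hidden size of ESM2-650M. It also uses hypothetical binary labels and trains a plain linear probe. In practice the features would come from the batched extraction function in this README.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size = 1280  # ESM2-650M hidden size (assumed unchanged by PAIR fine-tuning)
num_proteins = 64

# Synthetic stand-ins: real inputs would be pooled PAIR-esm2 embeddings.
labels = torch.randint(0, 2, (num_proteins,))  # hypothetical binary function labels
features = torch.randn(num_proteins, hidden_size) + labels.unsqueeze(1).float()

probe = nn.Linear(hidden_size, 2)  # linear classifier on frozen features
optimizer = torch.optim.Adam(probe.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(probe(features), labels)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    train_acc = (probe(features).argmax(dim=1) == labels).float().mean().item()
```

Training accuracy on synthetic data only checks the plumbing. Real function labels would need a held-out split, and a linear probe is just one possible downstream model.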

How to load the model for feature extraction?

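```python
import torch
from transformers import AutoTokenizer, AutoModel

# PAIR-esm2 uses the tokenizer of the base ESM2-650M checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModel.from_pretrained("h4duan/PAIR-esm2")
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
protein = ["AETCZAO"]

def extract_feature(protein):
  ids = tokenizer(protein, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
  # return_tensors="pt" already returns tensors; only move them to the model's device.
  input_ids = ids["input_ids"].to(model.device)
  attention_mask = ids["attention_mask"].to(model.device)
  with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
  # Average over the token axis: one vector per protein, shape (1, hidden_size).
  return embedding_repr.mean(dim=1)

feature = extract_feature(protein)
```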
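Pooling must be taken over the token dimension. `torch.mean` with no `dim` argument collapses the entire `last_hidden_state` into a single number instead of one vector per protein. The difference can be checked with a random stand-in tensor, without downloading the model. The shape `(1, 9, 1280)` is an assumption: 7 residues plus one start and one end token added by the tokenizer, and a hidden size of 1280.

```python
import torch

torch.manual_seed(0)
# Stand-in for last_hidden_state: 1 protein, 9 token positions, hidden size 1280.
hidden = torch.randn(1, 9, 1280)
collapsed = torch.mean(hidden)    # no dim: averages every element into a 0-d tensor
per_protein = hidden.mean(dim=1)  # dim=1: averages over tokens, shape (1, 1280)
print(collapsed.shape, per_protein.shape)  # torch.Size([]) torch.Size([1, 1280])
```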

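How to extract features in batches?

When several proteins are tokenized together, shorter sequences are padded to the length of the longest one, so the average must skip padding positions. The function below uses the attention mask to average over real tokens only.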

```python
import torch

proteins = ["AETCZAO", "SKTZP"]

def extract_features_batch(proteins):
  ids = tokenizer(proteins, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
  input_ids = ids["input_ids"].to(model.device)
  attention_mask = ids["attention_mask"].to(model.device)
  with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
  # Broadcast the mask to (batch, seq_len, hidden_size) and zero out padding positions.
  attention_mask = attention_mask.unsqueeze(-1)
  attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
  masked_embedding_repr = embedding_repr * attention_mask
  sum_embedding_repr = masked_embedding_repr.sum(dim=1)
  # Number of non-padding tokens per protein (repeated across the hidden dimension).
  non_zero_count = attention_mask.sum(dim=1)
  # Masked mean: one vector per protein, shape (batch, hidden_size).
  mean_embedding_repr = sum_embedding_repr / non_zero_count
  return mean_embedding_repr

features = extract_features_batch(proteins)
```
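The masked mean in `extract_features_batch` can be checked without downloading the model. The hypothetical helper `masked_mean` below restates the same computation on random tensors and compares the result with a plain mean over each unpadded sequence. It differs from the original in two ways: it uses broadcasting in place of `expand`, and it adds a `clamp` guard against rows with no real tokens.

```python
import torch

def masked_mean(hidden, mask):
    """Average hidden states over positions where mask == 1 (illustrative helper)."""
    mask = mask.unsqueeze(-1).to(hidden.dtype)  # (B, T, 1), broadcasts over hidden size
    summed = (hidden * mask).sum(dim=1)         # (B, H): padding contributes zero
    counts = mask.sum(dim=1).clamp(min=1)       # (B, 1): number of real tokens per row
    return summed / counts

torch.manual_seed(0)
hidden = torch.randn(2, 5, 4)           # 2 sequences padded to 5 tokens, hidden size 4
mask = torch.tensor([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])  # the second sequence ends in 2 padding positions
pooled = masked_mean(hidden, mask)
```

Broadcasting a `(B, T, 1)` mask avoids materializing the expanded copy. The result is the same as in the function above, up to floating-point rounding.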