
Model description

PAIR (paper) is a flexible fine-tuning framework that improves the quality of protein representations for function prediction. PAIR uses a text decoder to guide the fine-tuning of a protein encoder, so that the learned representations capture the information contained in the diverse annotations in Swiss-Prot. This model is ESM-1b (repo) fine-tuned with PAIR.

Intended use

The model can be used for feature extraction in protein function prediction tasks.
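
As a rough illustration of how extracted features might feed a function prediction task, the sketch below transfers an annotation from the most similar reference protein using cosine similarity. The 4-dimensional vectors are toy stand-ins for real embeddings, the labels are hypothetical, and nearest-neighbour transfer is only one common choice, not necessarily the evaluation protocol used for PAIR.

```python
import torch
import torch.nn.functional as F

# Toy 4-dimensional vectors standing in for extracted protein features
# (real embeddings from the model are much wider).
reference_features = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
])
reference_labels = ["kinase", "transporter", "protease"]  # hypothetical annotations

# Feature vector of the protein we want to annotate (also a toy value).
query_feature = torch.tensor([[0.1, 0.9, 0.2, 0.0]])

# Cosine similarity between the query and every reference protein.
similarities = F.cosine_similarity(query_feature, reference_features, dim=1)

# Transfer the label of the most similar reference protein.
best = int(torch.argmax(similarities))
predicted_label = reference_labels[best]
print(predicted_label)
```

In practice the reference set would hold features of annotated proteins, and a trained classifier could replace the nearest-neighbour lookup.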

How to load the model for feature extraction?

import torch
from transformers import AutoTokenizer, AutoModel

# The tokenizer is loaded from the ESM-2 checkpoint; the model weights come from the PAIR repo
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModel.from_pretrained("h4duan/PAIR-esm1b").to("cuda")
protein = ["AETCZAO"]

def extract_feature(protein):
  ids = tokenizer(protein, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
  # The tokenizer already returns PyTorch tensors, so they only need to be moved to the GPU
  input_ids = ids['input_ids'].to("cuda")
  attention_mask = ids['attention_mask'].to("cuda")
  with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
  # Mean-pool over all token positions (a single sequence has no padding)
  return torch.mean(embedding_repr, dim=1)

feature = extract_feature(protein)
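
The snippet above assumes a CUDA device is available. A minimal sketch of a CPU fallback follows; the `device` variable and the stand-in `batch` dictionary are illustrative assumptions and not part of the original card.

```python
import torch

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the tokenizer output: a dict of tensors (the values are arbitrary).
batch = {
    "input_ids": torch.tensor([[0, 5, 6, 2]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
}

# Move every tensor in the batch to the selected device.
batch = {name: tensor.to(device) for name, tensor in batch.items()}
print(device, batch["input_ids"].device)
```

The same `device` could then be passed to `.to(...)` wherever the card uses `"cuda"`.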

How to extract features in batches?

proteins = ["AETCZAO","SKTZP"]

def extract_features_batch(proteins):
  ids = tokenizer(proteins, return_tensors="pt", padding=True, max_length=1024, truncation=True, return_attention_mask=True)
  # The tokenizer already returns PyTorch tensors, so they only need to be moved to the GPU
  input_ids = ids['input_ids'].to("cuda")
  attention_mask = ids['attention_mask'].to("cuda")
  with torch.no_grad():
    embedding_repr = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
  # Zero out padded positions, then average over the real tokens of each sequence
  attention_mask = attention_mask.unsqueeze(-1)
  attention_mask = attention_mask.expand(-1, -1, embedding_repr.size(-1))
  masked_embedding_repr = embedding_repr * attention_mask
  sum_embedding_repr = masked_embedding_repr.sum(dim=1)
  non_zero_count = attention_mask.sum(dim=1)
  mean_embedding_repr = sum_embedding_repr / non_zero_count
  return mean_embedding_repr

features = extract_features_batch(proteins)
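
The batch function masks padded positions before averaging. The toy sketch below shows why this matters: a plain mean over the length dimension lets padding vectors leak into the result for shorter sequences. The tensors are hand-written stand-ins for `last_hidden_state` and `attention_mask`, and `masked_mean` is a hypothetical helper that mirrors the pooling logic, not code from the PAIR repository.

```python
import torch

# Toy stand-in for last_hidden_state: 2 sequences, 3 positions, hidden size 2.
# The last position of the second sequence is padding; its vector is deliberately
# non-zero, as padded positions in a real model output generally are.
hidden = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]],
])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

def masked_mean(hidden, attention_mask):
    # (batch, length) -> (batch, length, 1) so the mask broadcasts over the hidden size.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    # Number of real tokens per sequence; clamp guards against division by zero.
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts

naive = hidden.mean(dim=1)                    # averages padding in as well
masked = masked_mean(hidden, attention_mask)  # averages real tokens only

print(naive)
print(masked)
```

The two results agree for the unpadded first sequence and differ for the padded second one, which is why the single-sequence function can use a plain mean while the batch function cannot.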

Model size: 652M params (Safetensors, tensor type F32)