---
tags:
  - antibody language model
  - antibody
  - protein language model
base_model: Rostlab/prot_t5_xl_uniref50
license: mit
---

# IgT5 unpaired model

The model is pretrained on protein and antibody sequences using a masked language modeling (MLM) objective. It was introduced in the paper *Large scale paired antibody language models*.

The model is finetuned from ProtT5 using unpaired antibody sequences from the Observed Antibody Space.

## Use

The encoder part of the model and the tokeniser can be loaded using the `transformers` library:

```python
from transformers import T5EncoderModel, T5Tokenizer

tokeniser = T5Tokenizer.from_pretrained("Exscientia/IgT5_unpaired", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Exscientia/IgT5_unpaired")
```
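
Optionally, the model can be put in evaluation mode and moved to a GPU before extracting embeddings. A minimal sketch (the device handling is an assumption, not part of the card's workflow):

```python
import torch

# disable dropout so embeddings are deterministic
model = model.eval()

# optional: move the model to a GPU if one is available
# (the tokenised inputs below must then be moved to the same device)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
```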

The tokeniser is used to prepare batch inputs:

```python
# single chain sequences
sequences = [
    "EVVMTQSPASLSVSPGERATLSCRARASLGISTDLAWYQQRPGQAPRLLIYGASTRATGIPARFSGSGSGTEFTLTISSLQSEDSAVYYCQQYSNWPLTFGGGTKVEIK",
    "ALTQPASVSGSPGQSITISCTGTSSDVGGYNYVSWYQQHPGKAPKLMIYDVSKRPSGVSNRFSGSKSGNTASLTISGLQSEDEADYYCNSLTSISTWVFGGGTKLTVL"
]

# the tokeniser expects input of the form ["E V V M...", "A L T Q..."]
sequences = [' '.join(sequence) for sequence in sequences]

tokens = tokeniser.batch_encode_plus(
    sequences,
    add_special_tokens=True,
    padding=True,  # pad to the longest sequence in the batch
    return_tensors="pt",
    return_special_tokens_mask=True
)
```

Note that the tokeniser adds a `</s>` token at the end of each sequence and pads using the `<pad>` token. For example, a batch containing the sequences `E V V M` and `A L` will be tokenised to `E V V M </s>` and `A L </s> <pad> <pad>`.
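
This can be checked by converting the encoded ids back to tokens; a short illustrative sketch:

```python
# inspect the tokens produced for each sequence in the batch
for ids in tokens["input_ids"].tolist():
    print(tokeniser.convert_ids_to_tokens(ids))
# prints the amino-acid tokens followed by '</s>',
# with '<pad>' tokens appended to the shorter sequence
```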

Sequence embeddings are generated by feeding the tokens through the model:

```python
output = model(
    input_ids=tokens['input_ids'],
    attention_mask=tokens['attention_mask']
)

residue_embeddings = output.last_hidden_state
```
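
Gradients are not needed for embedding extraction, so the same forward pass can also be wrapped in `torch.no_grad()`; a sketch under that assumption:

```python
import torch

# avoid building a computation graph when only embeddings are needed
with torch.no_grad():
    output = model(
        input_ids=tokens["input_ids"],
        attention_mask=tokens["attention_mask"]
    )

# one embedding per token (including special and padding tokens),
# of shape (batch_size, padded_length, hidden_size)
residue_embeddings = output.last_hidden_state
```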

To obtain a sequence representation, the residue embeddings can be averaged over like so:

```python
import torch

# mask special tokens before summing over embeddings
residue_embeddings[tokens["special_tokens_mask"] == 1] = 0
sequence_embeddings_sum = residue_embeddings.sum(1)

# average embedding by dividing sum by sequence lengths
sequence_lengths = torch.sum(tokens["special_tokens_mask"] == 0, dim=1)
sequence_embeddings = sequence_embeddings_sum / sequence_lengths.unsqueeze(1)
```
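
As an illustrative use of the pooled embeddings (not prescribed by the model card), the two example sequences can be compared by cosine similarity:

```python
import torch.nn.functional as F

# cosine similarity between the two example sequence embeddings
similarity = F.cosine_similarity(
    sequence_embeddings[0], sequence_embeddings[1], dim=0
)
print(float(similarity))
```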