Mismatch Between Logits Output Size and Vocab Size in ESMplusplus Model for MLM Task

#10 by acow

I am using the ESMplusplus model and fine-tuning with the logits output by the model loaded via AutoModelForMaskedLM.from_pretrained. The size of the last dimension of the output logits is 64, which implies that the vocab_size should also be 64. However, when I print the vocab size of the ESM++ model, it is only 33. This creates a mismatch between the vocabulary size and the logits size, and logically the logits size in an MLM task should match the vocabulary size. Could you help me understand why this discrepancy occurs?

Synthyra org

Hi @acow ,

The vocab size really is 33; the original authors decided to pad the output with some extra tokens. We don't have a concrete reason for this, but our assumption is that it makes CUDA operations more efficient, since they typically run better on dimensions that are multiples of 8. For all intents and purposes, you can treat the model as though its output size is 33.
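
For example, here's a minimal sketch of what that looks like in practice (assuming you just want to drop the padded columns; the short sequence is only a toy example):

import torch
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True).eval()
tokenizer = model.tokenizer

input_ids = tokenizer.encode('MPEEVHL', return_tensors='pt', add_special_tokens=True)
with torch.no_grad():
    logits = model(input_ids=input_ids).logits       # shape (1, 9, 64): 7 residues + 2 special tokens

# The real vocabulary occupies the first 33 columns; the remaining columns are padding
vocab_logits = logits[..., :tokenizer.vocab_size]    # tokenizer.vocab_size == 33
print(logits.shape, vocab_logits.shape)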
Best,
Logan

If I want to index the 20 standard amino acids out of the 64 outputs, I should use the indices from the 33-token vocabulary, right?

Synthyra org

Hi @acow ,

You can use the tokenizer to look up the indices of specific amino acids as needed. But yes, the canonical vocabulary occupies the first 33 positions of the 64-dimensional vectors. Here are some examples:

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True).to(device).eval()
tokenizer = model.tokenizer

# You can get the same tokenizer without loading the model
tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')


sequence = 'MPEEVHLGEKEVETFAFQAEIAQLMSLIINTFYSNKEIFLWELISNASDALDKIRYESLTDPSKLDSGKELKIDIIPNTQEHTLTLVDTGIGMTKADLINNLGTIAKFQDQTEYLEEMQVKEVVEKHSQFLGYPITLYLEKEREKEISDGKAEEEKGEKEEENKDDEEKPKIEDVGSDEEDDSGKDKKKKTKKIKEKYIDQEELNKTKPIWTRNTEDITQEEYGEFYKSLTNDWKDHLAVRYFSVEEYVSRMKEIQKSIYYITGESKEQVANSAFVEQVWKRDSRVVYMTEPIDGYQLKEFDGKSLVSVTKEGLELPEDGEEKKRMEERKAKFENLCKFMKETLDKKVEMVTVSNRLVSSSCCIVTSTYSWTANMEQIMKA'
sequence_masked = 'MPE<mask>VHLGEKEVETFAFQAEIAQLMSLIINTFYSNKEIFLWELISNASDALDKIRYESLTDPSKLDSGKELKIDIIPNTQEHTLTLVDTGIGMTKADLINNLGTIAKFQDQTEYLEEMQVKEVVEKHSQFLGYPITLYLEKEREKEISDGKAEEEKGEKEEENKDDEEKPKIEDVGSDEEDDSGKDKKKKTKKIKEKYIDQEELNKTKPIWTRNTEDITQEEYGEFYKSLTNDWKDHLAVRYFSVEEYVSRMKEIQKSIYYITGESKEQVANSAFVEQVWKRDSRVVYMTEPIDGYQLKEFDGKSLVSVTKEGLELPEDGEEKKRMEERKAKFENLCKFMKETLDKKVEMVTVSNRLVSSSCCIVTSTYSWTANMEQIMKA'

input_ids = tokenizer.encode(sequence, return_tensors='pt', add_special_tokens=True)
print(f'Input ids: {input_ids[0][:10]}')

masked_input_ids = tokenizer.encode(sequence_masked, return_tensors='pt', add_special_tokens=True)
print(f'Masked input ids: {masked_input_ids[0][:10]}')

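# Forward pass on the unmasked and masked sequences; detach the logits and move them to CPU for inspection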
logits = model(input_ids=input_ids.to(device)).logits.detach().cpu()
masked_logits = model(input_ids=masked_input_ids.to(device)).logits.detach().cpu()
print(f'Logits: {logits.shape}')
Output:

Input ids: tensor([ 0, 20, 14,  9,  9,  7, 21,  4,  6,  9])
Masked input ids: tensor([ 0, 20, 14,  9, 32,  7, 21,  4,  6,  9])
Logits: torch.Size([1, 383, 64])
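
If you want the 20 standard amino acids specifically, one option is to collect their token ids and slice the logits down to just those columns (a small sketch reusing the tokenizer and logits from above; the single-letter codes are the standard amino acid alphabet):

# Token ids for the 20 standard amino acids
standard_aas = list('ACDEFGHIKLMNPQRSTVWY')
aa_ids = tokenizer.convert_tokens_to_ids(standard_aas)
print(dict(zip(standard_aas, aa_ids)))

# Keep only the amino acid columns of the 64-dimensional logits
aa_logits = logits[..., aa_ids]   # shape: (1, 383, 20)
print(aa_logits.shape)

The helper below then pulls out the probability of E at a given position: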
def E_at_position_i(logits, i):
    probabilities = logits.softmax(dim=-1)
    # You can easily remove the special tokens on the end, or leave them in and account for the indexing
    probabilities = probabilities[:, 1:-1, :]

    # You can use encode or convert_tokens_to_ids to get the token id for a given token
    token_for_E = tokenizer.encode('E', add_special_tokens=False)[0]
    print(f'Token for E: {token_for_E}')

    token_for_E = tokenizer.convert_tokens_to_ids('E')
    print(f'Token for E: {token_for_E}')

    # You can go backwards with decode or convert_ids_to_tokens
    string_for_E = tokenizer.decode(token_for_E)
    print(f'String for E: {string_for_E}')

    string_for_E = tokenizer.convert_ids_to_tokens(token_for_E)
    print(f'String for E: {string_for_E}')

    probs_for_E = probabilities[0, :, token_for_E]
    print(f'Probabilities for E: {probs_for_E.shape}')

    probs_for_E_at_position_i = probs_for_E[i]
    print(f'Probability for E at position {i}: {probs_for_E_at_position_i:.2f}')

E_at_position_i(logits, 3)
E_at_position_i(masked_logits, 3)
Output:

Token for E: 9
Token for E: 9
String for E: E
String for E: E
Probabilities for E: torch.Size([381])
Probability for E at position 3: 0.95
Token for E: 9
Token for E: 9
String for E: E
String for E: E
Probabilities for E: torch.Size([381])
Probability for E at position 3: 0.92

Luckily, in both the masked and unmasked cases we get a high probability for E at that position. This is expected, since there is an actual E at that position: MPEEVHL...
Be careful of zero-indexing in Python.
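
As a quick extension (same model and tokenizer as above; token position 4 corresponds to residue index 3 because of the leading special token), you can also decode the model's top prediction at the masked position:

# Most likely token at the masked position (residue 3 -> token index 4 due to the leading <cls> token)
probs = masked_logits.softmax(dim=-1)
top_id = probs[0, 4].argmax().item()
print(tokenizer.convert_ids_to_tokens(top_id))   # 'E', consistent with the 0.92 probability above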

Hope this helps!
-- Logan
