Mismatched outputs from encoders of `transformers` and `whisper`

#9
by JinchaoLove - opened

Hi there! I found that the outputs of the encoders from transformers and whisper do not match, may I ask why? Here is my code:

import torch
import whisper
import transformers as ppb
x = torch.randn(1, 80, 3000)  # random input feature
enc1 = ppb.models.whisper.modeling_whisper.WhisperEncoder.from_pretrained('openai/whisper-small')
enc2 = whisper.load_model('small').encoder
y1 = enc1(x)
y2 = enc2(x)
print(torch.sum(abs(y1.last_hidden_state - y2)))  # expected 0, but got > 1e6

Loading the WhisperEncoder directly with from_pretrained results in none of the pre-trained weights being loaded:

import transformers as ppb

enc1 = ppb.models.whisper.modeling_whisper.WhisperEncoder.from_pretrained('openai/whisper-small')
Warning message:
Some weights of WhisperEncoder were not initialized from the model checkpoint at openai/whisper-small and are newly initialized: ['model.layers.3.self_attn.v_proj.weight', 'model.layers.6.self_attn_layer_norm.weight', 'model.layers.0.self_attn_layer_norm.bias', 'model.layers.3.final_layer_norm.bias', 'model.layers.2.fc2.weight', 'model.layers.9.fc2.bias', 'model.layers.6.self_attn_layer_norm.bias', 'model.layers.6.self_attn.v_proj.bias', 'model.layers.10.self_attn.q_proj.bias', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.9.fc1.weight', 'model.layers.1.final_layer_norm.weight', 'model.layers.1.self_attn.q_proj.bias', 'model.layers.9.fc1.bias', 'model.layers.1.self_attn.q_proj.weight', 'model.conv2.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.11.self_attn.v_proj.bias', 'model.layers.3.final_layer_norm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.4.self_attn.out_proj.weight', 'model.layers.11.final_layer_norm.bias', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.final_layer_norm.bias', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.1.fc1.weight', 'model.layers.5.fc2.bias', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.8.self_attn.out_proj.bias', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.6.final_layer_norm.bias', 'model.layers.10.fc1.weight', 'model.layers.11.self_attn_layer_norm.bias', 'model.layers.6.fc1.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.10.final_layer_norm.weight', 'model.layers.7.self_attn.v_proj.bias', 'model.layers.1.self_attn_layer_norm.weight', 'model.layers.3.fc2.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.conv2.bias', 'model.layers.11.self_attn.out_proj.bias', 'model.layers.11.fc2.weight', 'model.layers.0.fc1.bias', 'model.layer_norm.bias', 'model.layers.10.self_attn_layer_norm.weight', 'model.layers.5.fc1.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.5.self_attn.out_proj.weight', 'model.layers.3.self_attn_layer_norm.bias', 'model.layers.3.fc1.weight', 'model.layers.1.self_attn.out_proj.weight', 'model.layers.4.final_layer_norm.bias', 'model.conv1.bias', 'model.layers.5.self_attn.out_proj.bias', 'model.layers.4.self_attn.out_proj.bias', 'model.layers.5.fc2.weight', 'model.layers.6.self_attn.out_proj.bias', 'model.layers.4.final_layer_norm.weight', 'model.layers.10.fc2.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.fc2.weight', 'model.layers.2.self_attn.q_proj.bias', 'model.layers.4.fc1.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.final_layer_norm.weight', 'model.layers.9.self_attn.q_proj.bias', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.0.fc1.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.4.fc1.bias', 'model.layers.7.self_attn.out_proj.weight', 'model.layers.11.fc2.bias', 'model.layers.2.self_attn_layer_norm.bias', 'model.layers.5.fc1.bias', 'model.layers.9.self_attn_layer_norm.bias', 'model.layers.6.fc1.bias', 'model.layers.9.self_attn.v_proj.bias', 'model.layers.6.fc2.weight', 'model.layers.11.final_layer_norm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.fc2.weight', 'model.layers.7.final_layer_norm.weight', 'model.layers.10.self_attn.out_proj.weight', 'model.layers.5.self_attn.q_proj.bias', 
'model.layers.10.self_attn.out_proj.bias', 'model.layers.11.fc1.bias', 'model.layers.2.fc1.weight', 'model.layers.2.final_layer_norm.weight', 'model.layers.7.final_layer_norm.bias', 'model.layers.3.self_attn.v_proj.bias', 'model.layers.4.self_attn.q_proj.bias', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.8.fc2.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.1.final_layer_norm.bias', 'model.layers.2.self_attn_layer_norm.weight', 'model.layers.5.final_layer_norm.weight', 'model.layers.8.self_attn_layer_norm.bias', 'model.layers.7.self_attn.q_proj.bias', 'model.layers.10.self_attn_layer_norm.bias', 'model.layers.5.self_attn.v_proj.bias', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.3.self_attn.out_proj.bias', 'model.layers.9.final_layer_norm.bias', 'model.conv1.weight', 'model.layers.10.fc1.bias', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.1.fc2.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.3.self_attn.out_proj.weight', 'model.layers.8.self_attn.out_proj.weight', 'model.layers.3.fc2.bias', 'model.layers.6.self_attn.q_proj.bias', 'model.layers.7.self_attn.out_proj.bias', 'model.layers.3.fc1.bias', 'model.layers.10.final_layer_norm.bias', 'model.layers.9.final_layer_norm.weight', 'model.layers.1.fc2.bias', 'model.layers.1.fc1.bias', 'model.layers.9.fc2.weight', 'model.layers.7.fc2.bias', 'model.layers.6.self_attn.v_proj.weight', 'model.layer_norm.weight', 'model.layers.8.fc1.bias', 'model.layers.8.self_attn_layer_norm.weight', 'model.layers.7.fc1.weight', 'model.layers.2.self_attn.out_proj.bias', 'model.layers.8.self_attn.v_proj.bias', 'model.layers.6.fc2.bias', 'model.layers.0.fc2.bias', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.8.final_layer_norm.weight', 'model.layers.11.self_attn.q_proj.bias', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.0.self_attn.q_proj.bias', 'model.layers.0.final_layer_norm.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.8.fc2.bias', 'model.layers.0.self_attn_layer_norm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.7.fc2.weight', 'model.layers.4.self_attn_layer_norm.weight', 'model.layers.6.self_attn.out_proj.weight', 'model.layers.11.self_attn_layer_norm.weight', 'model.layers.5.self_attn_layer_norm.weight', 'model.layers.4.self_attn.v_proj.bias', 'model.layers.5.final_layer_norm.bias', 'model.layers.4.fc2.bias', 'model.layers.9.self_attn.out_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.4.self_attn_layer_norm.bias', 'model.layers.10.fc2.bias', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.0.self_attn.out_proj.bias', 'model.layers.2.self_attn.out_proj.weight', 'model.layers.1.self_attn.out_proj.bias', 'model.layers.7.fc1.bias', 'model.layers.2.fc1.bias', 'model.layers.8.self_attn.q_proj.bias', 'model.layers.10.self_attn.v_proj.bias', 'model.layers.2.fc2.bias', 'model.layers.7.self_attn_layer_norm.bias', 'model.layers.11.self_attn.out_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.2.final_layer_norm.bias', 'model.layers.11.fc1.weight', 'model.layers.3.self_attn_layer_norm.weight', 'model.layers.0.self_attn.out_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.9.self_attn.out_proj.bias', 'model.layers.2.self_attn.v_proj.bias', 'model.layers.9.self_attn_layer_norm.weight', 'model.layers.3.self_attn.q_proj.bias', 'model.layers.1.self_attn_layer_norm.bias', 'model.layers.0.final_layer_norm.bias', 'model.layers.1.self_attn.v_proj.bias', 
'model.layers.0.self_attn.v_proj.bias', 'model.layers.7.self_attn_layer_norm.weight', 'model.embed_positions.weight', 'model.layers.5.self_attn_layer_norm.bias', 'model.layers.8.fc1.weight']
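
You can sanity-check this: since the encoder weights are freshly initialized on every call rather than read from the checkpoint, two direct loads will not even agree with each other. A minimal sketch (comparing conv1.weight is just one convenient tensor to look at):

import torch
from transformers.models.whisper.modeling_whisper import WhisperEncoder

enc_a = WhisperEncoder.from_pretrained('openai/whisper-small')
enc_b = WhisperEncoder.from_pretrained('openai/whisper-small')

# if the checkpoint weights had actually been loaded, both copies would be identical
print(torch.equal(enc_a.conv1.weight, enc_b.conv1.weight))  # prints False: both are randomly initialized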

Instead, we should load the full encoder-decoder weights with WhisperForConditionalGeneration and then extract the encoder module, which is the same logic used for the OpenAI implementation in the snippet below. When we do so, the maximum element-wise difference between the HF implementation and the OpenAI implementation is 8.5e-5, i.e. they agree to within numerical precision:

import torch
from transformers import WhisperForConditionalGeneration
import whisper

x = torch.randn(1, 80, 3000)  # random input feature

enc1 = WhisperForConditionalGeneration.from_pretrained('openai/whisper-small').model.encoder  # load the full seq2seq model, then take the encoder
enc2 = whisper.load_model('small').encoder  # same logic for the OpenAI implementation: load the full model, take the encoder

with torch.no_grad():
    y1 = enc1(x)
    y2 = enc2(x)

print(torch.max(abs(y1.last_hidden_state - y2)))

Print Output:

tensor(8.5831e-05)
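
If you only need the encoder, loading the encoder-decoder backbone with WhisperModel (or calling get_encoder() on the model above) should give you the same correctly-initialized module; a minimal sketch:

from transformers import WhisperModel

# loads the encoder-decoder backbone (no LM head) with the pre-trained weights,
# then takes just the encoder module
encoder = WhisperModel.from_pretrained('openai/whisper-small').encoder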

(We could probably fix this by adding `base_model_prefix = "model"` to the WhisperEncoder class, WDYT @sanchit-gandhi?)
