|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torchaudio |
|
from matplotlib.animation import FuncAnimation |
|
|
|
def l2_normalize(matrix):
    """
    L2 normalize the matrix along its rows.

    Each row is divided by its Euclidean norm. Rows whose norm is zero are
    returned as zeros instead of producing NaN values from a 0/0 division.

    Parameters:
    matrix (numpy.ndarray): The 2-D input matrix.

    Returns:
    numpy.ndarray: The L2 normalized matrix (same shape as the input).
    """
    l2_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Divide all-zero rows by 1 so they stay zero rather than becoming NaN.
    safe_norms = np.where(l2_norms == 0, 1, l2_norms)
    normalized_matrix = matrix / safe_norms
    return normalized_matrix
|
|
|
|
|
def z_normalize(matrix):
    """
    Z-normalize the matrix along its rows (mean=0 and std=1).

    Z-normalization is also known as "standardization", and derives from the z-score:

        Z = (X - mean) / std

    After Z-normalization each row has mean=0 and (population) std=1.
    Constant rows have std=0; they are returned as zeros instead of NaN.

    Parameters:
    matrix (numpy.ndarray): The 2-D input matrix.

    Returns:
    numpy.ndarray: The Z normalized matrix (same shape as the input).
    """
    mean = np.mean(matrix, axis=1, keepdims=True)
    std = np.std(matrix, axis=1, keepdims=True)
    # Divide constant rows by 1 so (x - mean) == 0 yields 0 rather than NaN.
    safe_std = np.where(std == 0, 1, std)
    normalized_matrix = (matrix - mean) / safe_std
    return normalized_matrix
|
|
|
|
|
def l2_normalize_tensors(tensor_tuple):
    """
    L2-normalize every tensor of a tuple over its last two dimensions.

    Each tensor is cast to float32 and divided by the Frobenius norm of its
    last two axes. A 1e-7 epsilon is added to the norm to avoid division by zero.

    Parameters:
    tensor_tuple (tuple of torch.Tensor): N tensors, each of shape (1, k, 30, 30).

    Returns:
    tuple of torch.Tensor: The N L2-normalized tensors, in input order.
    """
    def _normalize_one(t):
        t = t.float()
        frob = torch.linalg.norm(t, dim=(-2, -1), keepdim=True)
        return t / (frob + 1e-7)

    return tuple(_normalize_one(t) for t in tensor_tuple)
|
|
|
|
|
def z_normalize_tensors(tensor_tuple):
    """
    Z-normalize every tensor of a tuple over its last two dimensions.

    Each tensor is cast to float32, centred by the mean of its last two axes and
    divided by their standard deviation plus 1e-7. ``Tensor.std`` is used with its
    default settings, i.e. the Bessel-corrected (unbiased) estimate.

    Parameters:
    tensor_tuple (tuple of torch.Tensor): N tensors, each of shape (1, k, 30, 30).

    Returns:
    tuple of torch.Tensor: The N Z-normalized tensors, in input order.
    """
    reduce_dims = (-2, -1)
    results = []
    for raw in tensor_tuple:
        x = raw.float()
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        spread = x.std(dim=reduce_dims, keepdim=True) + 1e-7
        results.append(centered / spread)
    return tuple(results)
|
|
|
|
|
def apply_temperature_to_attention_tensors(tensor_tuple, temperature=1.0):
    """
    Applies temperature scaling to the attention weights in each tensor in a tuple.

    For every tensor, the last two axes are flattened together. The values are
    divided by ``temperature``, a softmax is taken over the flattened axis, and the
    result is reshaped back. Each map along the leading axes (e.g. batch and head)
    therefore sums to 1 over its last two dimensions.

    Parameters:
    tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors with at
        least two dimensions, typically of shape (1, k, 30, 30).
    temperature (float): Non-zero temperature controlling the sharpness of the
        attention weights (lower is sharper). Default is 1.0.

    Returns:
    tuple of torch.Tensor: A tuple containing N float32 tensors with scaled
    attention weights, each with the same shape as its input.
    """
    scaled_attention_tensors = []
    for tensor in tensor_tuple:
        tensor = tensor.float()
        # Merge only the last two axes so the softmax is computed independently for
        # every leading index; hard-coding a batch of 1 would mix batch items.
        flattened_tensor = tensor.flatten(start_dim=-2)
        scaled_attention = F.softmax(flattened_tensor / temperature, dim=-1)
        scaled_attention_tensors.append(scaled_attention.reshape(tensor.shape))
    return tuple(scaled_attention_tensors)
|
|
|
|
|
def shorten_att(tensor_tuple, length=30):
    """
    Crop each attention tensor of a tuple to its leading ``length x length`` window.

    Slices axes 2 and 3 of every tensor (``t[:, :, :length, :length]``), so each
    tensor must have at least four dimensions.

    Parameters:
    tensor_tuple (tuple of torch.Tensor): Tensors to crop.
    length (int): Number of leading positions kept on axes 2 and 3. Default is 30.

    Returns:
    tuple of torch.Tensor: The cropped tensors (slicing views), in input order.
    """
    return tuple(t[:, :, :length, :length] for t in tensor_tuple)
|
|
|
|
|
def keep_top_k(matrix, k=6):
    """
    Keep only the top k values in each row, set the rest to 0.

    Parameters:
    matrix (numpy.ndarray): The 2-D input matrix.
    k (int): The number of top values to keep in each row. ``k == 0`` returns an
        all-zero matrix; ``k`` at least the number of columns keeps every value.

    Returns:
    numpy.ndarray: A new matrix with the same shape and dtype holding the kept values.

    Raises:
    ValueError: If ``k`` is negative.
    """
    matrix = np.asarray(matrix)
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    result_matrix = np.zeros_like(matrix)
    if k == 0:
        # Must be handled explicitly: the slice ``[:, -0:]`` would select every column.
        return result_matrix
    if k >= matrix.shape[1]:
        # Every value is kept (argpartition would reject kth out of bounds for k > ncols).
        return matrix.copy()
    topk_indices_per_row = np.argpartition(matrix, -k, axis=1)[:, -k:]
    # Scatter the selected entries of every row at once instead of looping over rows.
    np.put_along_axis(result_matrix, topk_indices_per_row,
                      np.take_along_axis(matrix, topk_indices_per_row, axis=1),
                      axis=1)
    return result_matrix
|
|
|
|
|
def _plot_spectrogram_and_conv(x_spec, x_conv):
    """Show the input spectrogram next to seven pre-encoder (conv) channels in a 2x4 grid."""
    import matplotlib.pyplot as plt

    plt.figure(figsize=(15, 10))
    plt.subplot(2, 4, 1)
    plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower')
    plt.title("spectrogram")
    plt.xlabel('time step')
    plt.ylabel('frequency bin')
    for position, ch in enumerate((0, 42, 80, 11, 20, 77, 90), start=2):
        plt.subplot(2, 4, position)
        plt.imshow(x_conv[0][:, :, ch].detach().numpy().T,
                   aspect='auto',
                   origin='lower')
        # Only the first conv panel names the operation in its title.
        plt.title(f"conv(spec), ch={ch}" if position == 2 else f"ch={ch}")
        plt.xlabel('time step')
        plt.ylabel('F')
    plt.tight_layout()
    plt.show()


def _plot_encoder_hidden_states(enc_hs_all):
    """Plot hidden dims 21 and 127 of encoder blocks 0-2 (2x3 grid, latent k vs. t)."""
    import matplotlib.pyplot as plt

    for block in range(3):
        for row, dim in enumerate((21, 127)):
            plt.subplot(2, 3, row * 3 + block + 1)
            plt.imshow(enc_hs_all[block][0][:, :, dim].detach().numpy().T)
            plt.colorbar(orientation='horizontal')
            title = f'B{block}, d{dim}'
            plt.title('ENC_HS ' + title if (block, row) == (0, 0) else title)
            plt.ylabel('latent k')
            plt.xlabel('t')
    plt.tight_layout()
    plt.show()


def _plot_latent_grid(enc_hs):
    """Plot (dim x time) maps of the first 25 latents of one encoder output in a 5x5 grid."""
    import matplotlib.pyplot as plt

    data = enc_hs[0].detach().numpy()
    fig, axs = plt.subplots(5, 5, figsize=(10, 9))
    for k, ax in enumerate(axs.flatten()):
        ax.imshow(data[:, k, :].T, cmap='viridis')
        ax.set_title(f'k={k}')
        ax.set_xlabel('Time step')
        ax.set_ylabel('Dim')
    plt.tight_layout()
    plt.show()


def _plot_pre_decoder_projection(enc_hs_proj):
    """Plot the pre-decoder projection of the encoder output for each of 13 channels."""
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(1, 13, figsize=(26, 8))
    data = enc_hs_proj[0].detach().numpy()
    for ch, ax in enumerate(axs):
        ax.imshow(np.rot90(data[ch]), cmap='viridis')
        ax.set_title(f'ch: {ch}')
        ax.set_xlabel('Time step')
        ax.set_ylabel('Dim')
    plt.suptitle(
        'linear projection of encoder outputs by channel, which is conditioning for enc-dec cross attention',
        y=0.1,
        fontsize=12)
    plt.tight_layout(rect=[0, 0.1, 1, 1])
    plt.show()


def _plot_enc_hs_time_slices(enc_hs):
    """Plot (latent k x d) slices of one encoder output at t = 0, 10, 20, 30 (2x2 grid)."""
    import matplotlib.pyplot as plt

    for position, t in enumerate((0, 10, 20, 30), start=1):
        plt.subplot(2, 2, position)
        plt.imshow(enc_hs[0][t, :, :].detach().numpy(), aspect='auto')
        plt.title(f'enc_hs, t={t}')
        plt.ylabel('latent k')
        plt.xlabel('d')
    plt.tight_layout()
    plt.show()


def _plot_cosine_similarities(enc_hs_last):
    """Plot t x t, k x k and d x d cosine-similarity matrices of the last encoder hidden state."""
    import matplotlib.pyplot as plt
    from einops import rearrange
    from sklearn.metrics.pairwise import cosine_similarity

    views = (
        ('1 t k d -> t (k d)', "enc hs, t x t cos_sim"),
        ('1 t k d -> k (t d)', "enc hs, k x k cos_sim"),
        ('1 t k d -> d (k t)', "enc hs, d x d cos_sim"),
    )
    for position, (pattern, title) in enumerate(views, start=1):
        plt.subplot(1, 3, position)
        features = rearrange(enc_hs_last, pattern).detach().numpy()
        plt.imshow(cosine_similarity(features))
        plt.title(title)
    plt.tight_layout()
    plt.show()


def _plot_latent_array(latents):
    """Show the encoder's learned latent array (latent k x d)."""
    import matplotlib.pyplot as plt

    plt.imshow(latents.detach().numpy())
    plt.title('latent array')
    plt.xlabel('d')
    plt.ylabel('latent k')
    plt.show()


def _plot_cross_attention_sums(catt):
    """Plot catt[block][0] summed over its first two axes for blocks 0-2 (latent k vs. conv channel)."""
    import matplotlib.pyplot as plt

    for block in range(3):
        plt.subplot(3, 1, block + 1)
        summed = catt[block][0].sum(dim=0).sum(dim=0)
        plt.imshow(summed.detach().numpy())
        plt.title(f'block={block}')
        plt.ylabel('latent k')
        plt.xlabel('conv channel')
    plt.tight_layout()
    plt.show()


def _draw_catt_head(ax, catt, t, head, block=2):
    """Redraw cross-attention map catt[block][0][t, head] (latent k vs. conv channel) on ``ax``."""
    ax.clear()
    ax.imshow(catt[block][0][t, head, :, :].detach().numpy())
    ax.set_title(f'block={block}, t={t}, head={head}')
    ax.set_ylabel('latent k')
    ax.set_xlabel('conv channel')


def _save_cross_attention_animation(catt):
    """Save 'animation.gif': heads 3 and 5 of block-2 cross-attention over t = 0..109."""
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))

    def update(t):
        _draw_catt_head(ax1, catt, t, head=3)
        _draw_catt_head(ax2, catt, t, head=5)
        fig.tight_layout()

    anim = FuncAnimation(fig, update, frames=range(0, 110), interval=200)
    anim.save('animation.gif', writer='pillow', fps=5)
    # Close the figure so its last frame does not pop up again at the next plt.show().
    plt.close(fig)


def _save_combined_attention_animation(catt, att):
    """Save 'combined_animation_smaller_att.gif': two cross-attention heads above 8 self-attention heads."""
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(12, 18))
    # A single 3x8 grid: the two cross-attention panels span a full row each, and the
    # eight self-attention panels share the bottom row. All axes are created once and
    # reused by every frame, rather than being re-added (and piling up) per frame.
    grid = fig.add_gridspec(3, 8, height_ratios=[1, 1, 0.5])
    ax_catt3 = fig.add_subplot(grid[0, :])
    ax_catt5 = fig.add_subplot(grid[1, :])
    att_axes = [fig.add_subplot(grid[2, i]) for i in range(8)]

    def combined_update_smaller_att(t):
        _draw_catt_head(ax_catt3, catt, t, head=3)
        _draw_catt_head(ax_catt5, catt, t, head=5)
        for head, ax in enumerate(att_axes):
            ax.clear()
            ax.imshow(att[0][1][t, head, :, :].detach().numpy(), cmap='viridis')
            ax.set_title(f't={t}, head={head}')
            ax.set_xlabel('k')
            ax.set_ylabel('k')
            ax.axis('square')
        fig.tight_layout()

    combined_anim_smaller_att = FuncAnimation(fig, combined_update_smaller_att,
                                              frames=range(0, 110), interval=200)
    combined_anim_smaller_att.save('combined_animation_smaller_att.gif', writer='pillow', fps=5)
    plt.close(fig)


def _plot_attention_sums(att, layers, axis_label):
    """Plot att[block][layer] summed over axes 1 then 0; columns are blocks 0-2, rows the two layers."""
    import matplotlib.pyplot as plt

    for block in range(3):
        for row, layer in enumerate(layers):
            plt.subplot(2, 3, row * 3 + block + 1)
            summed = att[block][layer].sum(dim=1).sum(dim=0)
            plt.imshow(summed.detach().numpy(), origin='upper')
            plt.title(f'B{block}L{layer}')
            plt.xlabel(axis_label)
            plt.ylabel(axis_label)
    plt.tight_layout()
    plt.show()


def _plot_self_attention_over_time(att, block, time_steps=(30, 50, 100)):
    """Show the 8 heads of att[block][1] at selected time steps (rows = time steps, columns = heads)."""
    import matplotlib.pyplot as plt

    data = att[block][1].detach().numpy()
    # squeeze=False keeps ``axs`` two-dimensional even for a single time step.
    fig, axs = plt.subplots(len(time_steps), 8, figsize=(16, 6), squeeze=False)
    for row, t in enumerate(time_steps):
        for head in range(8):
            ax = axs[row, head]
            ax.imshow(data[t, head, :, :], cmap='viridis')
            ax.set_title(f't={t}, head={head}')
            ax.set_xlabel('k')
            ax.set_ylabel('k')
    plt.suptitle(
        f'latent transformer block={block}, last layer self-attention over time',
        y=0,
        fontsize=12)
    plt.tight_layout()
    plt.show()


def _remove_values_after_eos(catt_np, pred_ids, max_k, eos_id=1):
    """
    Zero the decoder cross-attention steps that follow the first EOS of each channel.

    Parameters:
    catt_np (numpy.ndarray): Cross-attention indexed as [k, head, decoder step, encoder frame];
        modified in place.
    pred_ids (torch.Tensor or numpy.ndarray): Predicted ids indexed as [0, k, decoder step].
    max_k (int): Number of channels k to process.
    eos_id (int): Token id treated as EOS. Default is 1.

    Returns:
    tuple: (catt_np, seq_lengths), where seq_lengths[k] counts the steps up to and
    including the first EOS (the full length when channel k has no EOS).
    """
    if torch.is_tensor(pred_ids):
        pred_ids = pred_ids.detach().cpu().numpy()
    num_steps = pred_ids.shape[-1]
    seq_lengths = np.zeros(max_k, dtype=np.int32)
    for k in range(max_k):
        eos_steps = np.flatnonzero(pred_ids[0, k] == eos_id)
        length = int(eos_steps[0]) + 1 if eos_steps.size else num_steps
        catt_np[k, :, length:, :] = 0
        seq_lengths[k] = length
    return catt_np, seq_lengths


def _analyze_decoder(model, enc_hs_proj, label):
    """Run the decoder with teacher forcing on ``label``; plot its attentions, hidden states and logits."""
    import matplotlib.pyplot as plt

    with torch.no_grad():
        dec_input_ids = model.shift_right_fn(label)
        dec_inputs_embeds = model.embed_tokens(dec_input_ids)
        dec_output = model.decoder(inputs_embeds=dec_inputs_embeds,
                                   encoder_hidden_states=enc_hs_proj,
                                   output_attentions=True,
                                   output_hidden_states=True,
                                   return_dict=True)
        dec_att, dec_catt = dec_output.attentions, dec_output.cross_attentions
        dec_last_hs = dec_output.last_hidden_state
        logits = model.lm_head(dec_last_hs)
    # Greedy ids; assumes logits are (batch, channel k, decoder step, vocab) — TODO confirm.
    pred_ids = torch.argmax(logits, dim=3)

    # Decoder self-attention summed over the first axis (heads, presumably).
    plt.subplot(1, 2, 1)
    plt.imshow(dec_att[5][0].sum(dim=0).detach().numpy())
    plt.title('decoder attention, layer5')
    plt.xlabel('decoder time step')
    plt.ylabel('decoder time step')
    plt.subplot(1, 2, 2)
    plt.imshow(dec_att[7][0].sum(dim=0).detach().numpy())
    plt.title('decoder attention, final layer')
    plt.xlabel('decoder step')
    plt.show()

    layer = 4
    data, _ = _remove_values_after_eos(dec_catt[layer].detach().numpy(), pred_ids, max_k=13)
    # Plot the first 256 decoder steps of every channel; steps after EOS were zeroed above.
    max_steps = 256
    fig, axs = plt.subplots(13, 6, figsize=(21, 39))
    for k in range(13):
        for head in range(6):
            ax = axs[k, head]
            ax.imshow(data[k, head, :max_steps, :].T, aspect='auto', cmap='viridis')
            ax.set_title(f'Layer {layer}, k={k}, head={head}')
            ax.set_xlabel('Decoder step')
            ax.set_ylabel('Encoder frame')
    plt.tight_layout()
    plt.show()

    for position, k in enumerate((2, 12), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(dec_last_hs[0][k].detach().numpy(), origin='upper')
        plt.colorbar(orientation='horizontal')
        plt.title(f'decoder last hidden state, k={k}')
        plt.xlabel('hidden dim')
        plt.ylabel('time step')
    plt.show()

    # The logits are transposed before plotting: rows are vocab entries, columns time steps.
    k = 6
    plt.imshow(logits[0][k][0:200, :].detach().numpy().T, origin='upper')
    plt.title('lm head output')
    plt.xlabel('time step')
    plt.ylabel('vocab dim')
    plt.show()

    logits_sm = torch.softmax(logits, dim=3)
    plt.imshow(logits_sm[0][k][:255, :].detach().numpy().T, origin='upper')
    plt.title('lm head softmax')
    plt.xlabel('time step')
    plt.ylabel('vocab dim')
    plt.show()

    k = 10
    print(pred_ids[0, k, :])


def test_case_forward_enc_perceiver_tf_dec_multi_t5(
        label=None,
        audio_path='piano.wav',
        checkpoint_path="../logs/ymt3/ptf_mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k/checkpoints/model.ckpt"):
    """
    Visually inspect a YourMT3 model with a Perceiver-TF encoder and a multi-T5 decoder.

    The function builds the model from the project config, loads a checkpoint and runs
    the spectrogram, pre-encoder, encoder and pre-decoder on one audio file. It then
    shows many matplotlib figures (``plt.show()``) and writes 'animation.gif' and
    'combined_animation_smaller_att.gif' to the current directory. The imported
    ``audio_cfg``/``model_cfg`` dicts are modified in place.

    Parameters:
    label (torch.Tensor or None): Target token ids used for teacher-forced decoding.
        When None, the decoder analysis is skipped.
    audio_path (str): Audio file loaded with ``torchaudio.load``.
    checkpoint_path (str): Checkpoint holding a 'state_dict' entry.
    """
    from model.ymt3 import YourMT3
    from config.config import audio_cfg, model_cfg

    model_cfg["encoder_type"] = "perceiver-tf"
    model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True
    model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 26
    model_cfg["decoder_type"] = "multi-t5"
    audio_cfg["codec"] = "spec"
    audio_cfg["hop_length"] = 300
    model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
    model.eval()

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Recent torch versions default to weights_only=True, which may reject a full
    # training checkpoint — TODO confirm for the torch version in use.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Drop parameters whose key mentions 'pitchshift'; strict=False tolerates missing keys.
    state_dict = {
        k: v for k, v in checkpoint['state_dict'].items() if 'pitchshift' not in k
    }
    model.load_state_dict(state_dict, strict=False)

    x, _ = torchaudio.load(audio_path)
    x = x.unsqueeze(0)  # prepend a batch axis to torchaudio's (channels, frames) waveform

    # Inference only: no autograd graph is needed for the visualizations below.
    with torch.no_grad():
        x_spec = model.spectrogram(x)
        x_conv = model.pre_encoder(x_spec)
        output = model.encoder(inputs_embeds=x_conv,
                               output_hidden_states=True,
                               output_attentions=True)
        enc_hs_all = output["hidden_states"]
        att = output["attentions"]
        catt = output["cross_attentions"]
        enc_hs_last = enc_hs_all[2]
        enc_hs_proj = model.pre_decoder(enc_hs_last)

    _plot_spectrogram_and_conv(x_spec, x_conv)
    _plot_encoder_hidden_states(enc_hs_all)
    _plot_latent_grid(enc_hs_last)
    _plot_pre_decoder_projection(enc_hs_proj)
    _plot_enc_hs_time_slices(enc_hs_last)
    _plot_cosine_similarities(enc_hs_last)
    _plot_latent_array(model.encoder.latent_array.latents)

    _plot_cross_attention_sums(catt)
    _save_cross_attention_animation(catt)
    _save_combined_attention_animation(catt, att)

    _plot_attention_sums(att, layers=(0, 1), axis_label='latent k')
    for block in range(3):
        _plot_self_attention_over_time(att, block)
    _plot_attention_sums(att, layers=(2, 3), axis_label='t')

    if label is None:
        print("No `label` token ids given; skipping the decoder analysis.")
        return
    _analyze_decoder(model, enc_hs_proj, label)
|
|
|
|
|
|
|
|
|
|