import numpy as np

SEP = '[SEP]'
CLS = '[CLS]'
MASK = '[MASK]'


def drop_bad_inds(arr, left_drop, right_drop):
    """Remove flagged positions from the last two axes of a 4-d array.

    Args:
        arr: Array indexed as (n_layer, n_head, n_left_text, n_right_text).
        left_drop: Boolean mask over axis 2; True marks a position to drop.
        right_drop: Boolean mask over axis 3; True marks a position to drop.

    Returns:
        The array with the flagged axis-2 positions removed. The flagged
        axis-3 positions are removed only when the length of axis 3 equals
        ``len(right_drop)``; otherwise axis 3 is returned untouched.

    Note:
        Prints the shape of the incoming array to stdout.
    """
    arr = np.asarray(arr)
    # Coerce so plain Python lists of bools work with the `~` inversion below.
    left_drop = np.asarray(left_drop, dtype=bool)
    right_drop = np.asarray(right_drop, dtype=bool)

    print("Shape of arr: ", arr.shape)
    arr = arr[:, :, ~left_drop, :]

    # Keys and queries don't match the mask in the final dimension, so only
    # filter that axis when its length agrees with the mask.
    if arr.shape[-1] == len(right_drop):
        arr = arr[:, :, :, ~right_drop]

    return arr


def strip_attention(attention):
    """Return a copy of an attention response without CLS and SEP tokens.

    Args:
        attention: Mapping from a sentence-combination key to a dict with the
            entries 'left_text', 'right_text', 'att', 'keys' and 'queries'.

    Returns:
        A new dict with the same keys. In each entry, 'att', 'left_text' and
        'right_text' have the CLS/SEP positions removed and are plain lists.

    NOTE: 'keys' and 'queries' are passed through unchanged; they are not
    currently stripped.
    """
    attention_out = {}

    for name, resp in attention.items():
        left_tokens = np.array(resp['left_text'])
        right_tokens = np.array(resp['right_text'])
        att = np.array(resp['att'])

        left_drop = (left_tokens == CLS) | (left_tokens == SEP)
        right_drop = (right_tokens == CLS) | (right_tokens == SEP)

        att_out = drop_bad_inds(att, left_drop, right_drop)
        left_out = left_tokens[~left_drop]
        right_out = right_tokens[~right_drop]

        # Sanity check: the stripped weights must line up with the tokens.
        assert att_out.shape[2] == len(left_out)
        assert att_out.shape[3] == len(right_out)

        attention_out[name] = {
            'att': att_out.tolist(),
            'keys': resp['keys'],
            'queries': resp['queries'],
            'left_text': left_out.tolist(),
            'right_text': right_out.tolist(),
        }

    return attention_out


def _mask_tokens(tokens, indices):
    """Return `tokens` as a list with the entries at `indices` set to MASK."""
    # dtype=object keeps whole strings: a fixed-width unicode array would
    # silently truncate '[MASK]' when every token is shorter than it.
    masked = np.array(tokens, dtype=object)
    masked[indices] = MASK
    return masked.tolist()


def mask_attention(deets, maskA, maskB):
    """Replace selected tokens in the first two entries of `deets` with MASK.

    Deets have form:

        tokens_a, tokens_b, query_tensor.data.numpy(),
        key_tensor.data.numpy(), attn_tensor.data.numpy()

    Args:
        deets: Sequence whose first two items are token lists; any further
            items are carried over untouched.
        maskA: Indices into tokens_a to replace with MASK.
        maskB: Indices into tokens_b to replace with MASK.

    Returns:
        `deets` with the masked token lists in positions 0 and 1. A mutable
        `deets` (e.g. a list) is updated in place and returned; a tuple
        cannot be assigned to, so a new tuple is returned instead.
    """
    tokens_a = _mask_tokens(deets[0], maskA)
    tokens_b = _mask_tokens(deets[1], maskB)

    if isinstance(deets, tuple):
        return (tokens_a, tokens_b) + tuple(deets[2:])

    deets[0] = tokens_a
    deets[1] = tokens_b
    return deets