from abc import ABC

import numpy as np
from tqdm.auto import tqdm

# Import the abstract base class, falling back to a relative import when this
# module is loaded as part of a package.
try:
    from src.attention_flow_abstract import AttentionFlow
except Exception:
    from ..src.attention_flow_abstract import AttentionFlow


class AttentionRollout(AttentionFlow, ABC):
    def compute_flows(self, attentions_list, desc="", output_hidden_states=False, num_cpus=0):
        """
        Compute the attention rollout for each example.

        :param attentions_list: list of attention maps
            (#examples, #layers, #sent_len, #sent_len)
        :param desc: description shown on the tqdm progress bar
        :param output_hidden_states: if True, keep the rollout for every layer;
            otherwise keep only the last layer's rollout
        :param num_cpus: unused in this implementation
        :return: list of per-example joint-attention (rollout) matrices
        """
        attentions_rollouts = []
        for i in tqdm(range(len(attentions_list)), desc=desc):
            joint_attentions = self.compute_joint_attention(attentions_list[i])
            if output_hidden_states:
                attentions_rollouts.append(joint_attentions)
            else:
                # [[-1]] keeps only the last layer while preserving the layer
                # axis, so the result stays 3-dimensional.
                attentions_rollouts.append(joint_attentions[[-1]])
        return attentions_rollouts

    def compute_joint_attention(self, att_mat):
        """Roll attention out across layers: joint_attentions[i] maps the
        tokens at layer i back to the input tokens."""
        res_att_mat = self.pre_process(att_mat)
        joint_attentions = np.zeros_like(res_att_mat)
        layers = joint_attentions.shape[0]
        joint_attentions[0] = res_att_mat[0]
        for i in range(1, layers):
            # Propagate attention through layer i by multiplying its map with
            # the accumulated attention of all layers below it.
            joint_attentions[i] = res_att_mat[i] @ joint_attentions[i - 1]
        return joint_attentions
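
# --- Illustrative sketch (not part of the original module). The `pre_process`
# hook called above is inherited from AttentionFlow and not shown here; in the
# attention-rollout method of Abnar & Zuidema (2020) it conventionally mixes
# the (head-averaged) attention with the identity matrix to account for
# residual connections and then re-normalizes the rows. A minimal stand-alone
# version, under those assumptions and with the hypothetical name
# `rollout_pre_process`, could look like this:

def rollout_pre_process(att_mat):
    # att_mat: (#layers, #sent_len, #sent_len), already averaged over heads.
    att = att_mat + np.eye(att_mat.shape[-1])[None, ...]  # add residual links
    return att / att.sum(axis=-1, keepdims=True)          # rows sum to 1 again

# For example, on random row-stochastic attention maps:
#   att = np.random.rand(12, 16, 16)
#   att /= att.sum(axis=-1, keepdims=True)
#   pre = rollout_pre_process(att)  # still (12, 16, 16), each row sums to 1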