from collections import defaultdict
from typing import Dict
from typing import List

import torch

from espnet.nets.pytorch_backend.rnn.attentions import AttAdd
from espnet.nets.pytorch_backend.rnn.attentions import AttCov
from espnet.nets.pytorch_backend.rnn.attentions import AttCovLoc
from espnet.nets.pytorch_backend.rnn.attentions import AttDot
from espnet.nets.pytorch_backend.rnn.attentions import AttForward
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA
from espnet.nets.pytorch_backend.rnn.attentions import AttLoc
from espnet.nets.pytorch_backend.rnn.attentions import AttLoc2D
from espnet.nets.pytorch_backend.rnn.attentions import AttLocRec
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadAdd
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadDot
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadLoc
from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadMultiResLoc
from espnet.nets.pytorch_backend.rnn.attentions import NoAtt
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

from espnet2.train.abs_espnet_model import AbsESPnetModel


@torch.no_grad()
def calculate_all_attentions(
    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
"""Derive the outputs from the all attention layers |
|
|
|
Args: |
|
model: |
|
batch: same as forward |
|
Returns: |
|
return_dict: A dict of a list of tensor. |
|
key_names x batch x (D1, D2, ...) |
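
    Examples:
        >>> # ``model`` and ``batch`` are assumed to be built elsewhere,
        >>> # e.g. by an espnet2 task class and its collate function.
        >>> att_dict = calculate_all_attentions(model, batch)
        >>> for name, att_list in att_dict.items():
        ...     print(name, att_list[0].shape)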

    """
    bs = len(next(iter(batch.values())))
    # Every entry in the batch must share the same batch size
    assert all(len(v) == bs for v in batch.values()), {
        k: v.shape for k, v in batch.items()
    }
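
    # ``register_forward_hook`` is standard PyTorch: the hook fires after each
    # module's forward() and receives (module, input, output). One hook is
    # attached to every submodule; the attention types are filtered inside it.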
    outputs = {}
    handles = {}
    for name, modu in model.named_modules():

        def hook(module, input, output, name=name):
            if isinstance(module, MultiHeadedAttention):
                # Transformer attention keeps its weights on the module as
                # ``attn`` with shape (B, Head, Tout, Tin); later calls
                # overwrite earlier ones
                outputs[name] = module.attn.detach().cpu()
            elif isinstance(module, AttLoc2D):
                c, w = output
                # w accumulates the previous attentions: (B, nprev, Tin);
                # keep the latest one: (B, Tin)
                att_w = w[:, -1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, (AttCov, AttCovLoc)):
                c, w = output
                assert isinstance(w, list), type(w)
                # w is a list of the previous attentions; keep the latest
                att_w = w[-1].detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(module, AttLocRec):
                # output: (c, (att_w, (att_h, att_c)))
                c, (w, (att_h, att_c)) = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttMultiHeadDot,
                    AttMultiHeadAdd,
                    AttMultiHeadLoc,
                    AttMultiHeadMultiResLoc,
                ),
            ):
                c, w = output
                # w holds one attention per head: nhead x (B, Tin)
                assert isinstance(w, list), type(w)
                att_w = [_w.detach().cpu() for _w in w]
                outputs.setdefault(name, []).append(att_w)
            elif isinstance(
                module,
                (
                    AttAdd,
                    AttDot,
                    AttForward,
                    AttForwardTA,
                    AttLoc,
                    NoAtt,
                ),
            ):
                c, w = output
                att_w = w.detach().cpu()
                outputs.setdefault(name, []).append(att_w)

        handle = modu.register_forward_hook(hook)
        handles[name] = handle

    # Feed the batch one sample at a time so that the hooks record
    # per-sample attention weights
    keys = [k for k in batch if not k.endswith("_lengths")]

    return_dict = defaultdict(list)
    for ibatch in range(bs):
        # Build a 1-sample mini-batch, trimming each tensor to its own
        # valid length when a corresponding "<key>_lengths" entry exists
        _sample = {
            k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]]
            if k + "_lengths" in batch
            else batch[k][ibatch, None]
            for k in keys
        }

        # Carry over the length entries as 1-element tensors
        _sample.update(
            {
                k + "_lengths": batch[k + "_lengths"][ibatch, None]
                for k in keys
                if k + "_lengths" in batch
            }
        )
        model(**_sample)

        # Gather what the hooks recorded for this sample
        for name, output in outputs.items():
            if isinstance(output, list):
                if isinstance(output[0], list):
                    # Multi-head RNN attention: n_steps x nhead x (1, Tin)
                    # -> (nhead, n_steps, Tin)
                    output = torch.stack(
                        [
                            torch.cat([o[idx] for o in output], dim=0)
                            for idx in range(len(output[0]))
                        ],
                        dim=0,
                    )
                else:
                    # Single-head RNN attention: n_steps x (1, Tin)
                    # -> (n_steps, Tin)
                    output = torch.cat(output, dim=0)
            else:
                # Transformer attention: (1, Head, Tout, Tin)
                # -> (Head, Tout, Tin)
                output = output.squeeze(0)
            return_dict[name].append(output)
        # Clear the buffer before the next sample
        outputs.clear()

    # Detach all the hooks so the model is left unmodified
    for _, handle in handles.items():
        handle.remove()
    return dict(return_dict)
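

# A minimal usage sketch (all names below are hypothetical; any model whose
# forward() runs the attention modules imported above behaves the same):
#
#   batch = {
#       "speech": speech, "speech_lengths": speech_lengths,
#       "text": text, "text_lengths": text_lengths,
#   }
#   att_dict = calculate_all_attentions(model, batch)
#   att_dict["decoder.att_list.0"][0]  # attention weights for sample 0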