#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import copy
import torch
from torch import nn
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class EncoderLayer(nn.Module):
    """Encoder layer module.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.
        MultiHeadedAttention self_attn: self attention module
        (a RelPositionMultiHeadedAttention instance can be used instead
        when relative positional encoding is enabled)
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward:
        feed forward module
    :param espnet.nets.pytorch_backend.transformer.convolution.
        ConvolutionModule conv_module:
        convolution module (may be None, in which case the convolution
        block is skipped)
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output;
        if True, an additional linear layer is applied,
        i.e. x -> x + linear(concat(x, att(x)));
        if False, no additional linear is applied, i.e. x -> x + att(x)
    :param bool macaron_style: whether to use macaron style for PositionwiseFeedForward

    A minimal usage sketch is appended at the bottom of this file.
    """
def __init__(
self,
size,
self_attn,
feed_forward,
conv_module,
dropout_rate,
normalize_before=True,
concat_after=False,
macaron_style=False,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.ff_scale = 1.0
self.conv_module = conv_module
self.macaron_style = macaron_style
        self.norm_ff = LayerNorm(size)  # for the FFN module
self.norm_mha = LayerNorm(size) # for the MHA module
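        # macaron style sandwiches the attention block between two half-step
        # feed-forward modules; their outputs are scaled by 0.5 below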
if self.macaron_style:
self.feed_forward_macaron = copy.deepcopy(feed_forward)
self.ff_scale = 0.5
            # for another FFN module in macaron style
self.norm_ff_macaron = LayerNorm(size)
if self.conv_module is not None:
self.norm_conv = LayerNorm(size) # for the CNN module
self.norm_final = LayerNorm(size) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        :param x_input: encoded source features; either a torch.Tensor
            (batch, max_time_in, size) or, when relative positional encoding
            is used, a tuple (x, pos_emb) of the features and their
            positional embedding
        :param torch.Tensor mask: mask for x (batch, 1, max_time_in)
        :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
        :rtype: Tuple[torch.Tensor, torch.Tensor]
            (the first element is the tuple (x, pos_emb) when pos_emb was given)
        """
if isinstance(x_input, tuple):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None
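        # pos_emb is provided only when relative positional encoding is used;
        # it is passed through to the self-attention module below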
# whether to use macaron style
if self.macaron_style:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
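        # when a cache of previous outputs is given, only the newest frame is
        # used as the attention query; the full sequence still provides keys/values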
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if pos_emb is not None:
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
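        # concat_after concatenates the attention input and output and projects
        # back to `size`; otherwise the attention output is added after dropout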
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x = residual + self.dropout(self.conv_module(x))
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
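        # if a cache was given, only the newest frame was processed above;
        # prepend the cached frames so the output covers the full sequence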
if cache is not None:
x = torch.cat([cache, x], dim=1)
if pos_emb is not None:
return (x, pos_emb), mask
else:
return x, mask
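

# ---------------------------------------------------------------------------
# Usage sketch: a minimal smoke test showing how this layer can be wired up.
# It assumes the stock ESPnet MultiHeadedAttention and PositionwiseFeedForward
# modules referenced in the class docstring; conv_module is left as None so
# the convolution block is skipped, and the sizes below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from espnet.nets.pytorch_backend.transformer.attention import (
        MultiHeadedAttention,
    )
    from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
        PositionwiseFeedForward,
    )

    size, n_heads, ff_units, dropout = 256, 4, 1024, 0.1
    layer = EncoderLayer(
        size=size,
        self_attn=MultiHeadedAttention(n_heads, size, dropout),
        feed_forward=PositionwiseFeedForward(size, ff_units, dropout),
        conv_module=None,
        dropout_rate=dropout,
        normalize_before=True,
        concat_after=False,
        macaron_style=False,
    )
    x = torch.randn(2, 50, size)  # (batch, max_time_in, size)
    mask = torch.ones(2, 1, 50, dtype=torch.bool)  # (batch, 1, max_time_in)
    y, y_mask = layer(x, mask)
    print(y.shape)  # -> torch.Size([2, 50, 256])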