import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np


def _prepare_dec_z(xp, dec_z, batch, dunits):
    """Return the decoder state shaped as (batch, dunits).

    Args:
        xp (module): Array module of the calling link (``self.xp``).
        dec_z (chainer.Variable | N-dimensional array | None): Decoder state.
            ``None`` is replaced by a float32 zero state.
        batch (int): Number of utterances.
        dunits (int): Dimension of the decoder state.

    Returns:
        chainer.Variable | N-dimensional array: Decoder state (utt x dunits).

    """
    if dec_z is None:
        return chainer.Variable(xp.zeros((batch, dunits), dtype=np.float32))
    return dec_z.reshape(batch, dunits)


def _uniform_attention(xp, enc_hs):
    """Build padded attention weights that are uniform over each utterance.

    Args:
        xp (module): Array module of the calling link (``self.xp``).
        enc_hs (list): Encoder outputs, one array per utterance; only the
            first axis length (number of frames) of each item is read.

    Returns:
        chainer.Variable: Padded attention weights (utt x frame).

    """
    weights = [
        chainer.Variable(
            xp.full(hh.shape[0], 1.0 / hh.shape[0], dtype=np.float32)
        )
        for hh in enc_hs
    ]
    return F.pad_sequence(weights)


def _weighted_sum(enc_h, w):
    """Sum encoder frames weighted by attention.

    Args:
        enc_h (chainer.Variable): Padded encoder outputs (utt x frame x hdim).
        w (chainer.Variable): Attention weights (utt x frame).

    Returns:
        chainer.Variable: Context vectors (utt x hdim).

    """
    return F.sum(enc_h * F.broadcast_to(F.expand_dims(w, 2), enc_h.shape), axis=1)


# dot product based attention
class AttDot(chainer.Chain):
    """Compute attention based on dot product.

    Args:
        eprojs (int | None): Dimension of input vectors from encoder.
        dunits (int | None): Dimension of input vectors for decoder.
        att_dim (int): Dimension of the space that encoder and decoder
            vectors are projected to before the dot product.

    """

    def __init__(self, eprojs, dunits, att_dim):
        super(AttDot, self).__init__()
        with self.init_scope():
            self.mlp_enc = L.Linear(eprojs, att_dim)
            self.mlp_dec = L.Linear(dunits, att_dim)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # Per-batch cache filled on the first call after reset().
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None

    def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
        """Compute AttDot forward layer.

        Args:
            enc_hs (chainer.Variable | N-dimensional array):
                Input variable from encoder.
            dec_z (chainer.Variable | N-dimensional array | None):
                Input variable of decoder. ``None`` is treated as a zero state.
            att_prev: Unused by this layer; accepted so that all attention
                layers share the same call signature.
            scaling (float): Scaling weight to make attention sharp.

        Returns:
            chainer.Variable: Weighted sum over frames.
            chainer.Variable: Attention weight.

        """
        batch = len(enc_hs)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = F.pad_sequence(enc_hs)  # utt x frame x hdim
            self.h_length = self.enc_h.shape[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = F.tanh(self.mlp_enc(self.enc_h, n_batch_axes=2))

        dec_z = _prepare_dec_z(self.xp, dec_z, batch, self.dunits)

        # <phi (h_t), psi (s)> for all t
        u = F.broadcast_to(
            F.expand_dims(F.tanh(self.mlp_dec(dec_z)), 1),
            self.pre_compute_enc_h.shape,
        )
        e = F.sum(self.pre_compute_enc_h * u, axis=2)  # utt x frame

        # Applying a minus-large-number filter
        # to make a probability value zero for a padded area
        # simply degrades the performance, and I gave up this implementation
        # Apply a scaling to make an attention sharp
        w = F.softmax(scaling * e)

        # weighted sum over frames
        # utt x hdim
        c = _weighted_sum(self.enc_h, w)

        return c, w


# location based attention
class AttLoc(chainer.Chain):
    """Compute location-based attention.

    Args:
        eprojs (int | None): Dimension of input vectors from encoder.
        dunits (int | None): Dimension of input vectors for decoder.
        att_dim (int): Dimension of input vectors for attention.
        aconv_chans (int): Number of channels of output arrays
            from convolutional layer.
        aconv_filts (int): Size of filters of convolutional layer.
            The kernel spans ``2 * aconv_filts + 1`` frames.

    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super(AttLoc, self).__init__()
        with self.init_scope():
            self.mlp_enc = L.Linear(eprojs, att_dim)
            self.mlp_dec = L.Linear(dunits, att_dim, nobias=True)
            self.mlp_att = L.Linear(aconv_chans, att_dim, nobias=True)
            # Padding of aconv_filts on the frame axis keeps the frame
            # length unchanged for the (1, 2 * aconv_filts + 1) kernel.
            self.loc_conv = L.Convolution2D(
                1, aconv_chans, ksize=(1, 2 * aconv_filts + 1), pad=(0, aconv_filts)
            )
            self.gvec = L.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # Per-batch cache filled on the first call after reset().
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None

    def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0):
        """Compute AttLoc forward layer.

        Args:
            enc_hs (chainer.Variable | N-dimensional array):
                Input variable from encoders.
            dec_z (chainer.Variable | N-dimensional array | None):
                Input variable of decoder. ``None`` is treated as a zero state.
            att_prev (chainer.Variable | None): Attention weight of the
                previous step. ``None`` is replaced by a uniform distribution
                over the frames of each utterance.
            scaling (float): Scaling weight to make attention sharp.

        Returns:
            chainer.Variable: Weighted sum over frames.
            chainer.Variable: Attention weight.

        """
        batch = len(enc_hs)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = F.pad_sequence(enc_hs)  # utt x frame x hdim
            self.h_length = self.enc_h.shape[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h, n_batch_axes=2)

        dec_z = _prepare_dec_z(self.xp, dec_z, batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev is None:
            att_prev = _uniform_attention(self.xp, enc_hs)

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.reshape(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = F.swapaxes(F.squeeze(att_conv, axis=2), 1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv, n_batch_axes=2)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = F.broadcast_to(
            F.expand_dims(self.mlp_dec(dec_z), 1), self.pre_compute_enc_h.shape
        )

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        # TODO(watanabe) use batch_matmul
        e = F.squeeze(
            self.gvec(
                F.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled),
                n_batch_axes=2,
            ),
            axis=2,
        )

        # Applying a minus-large-number filter
        # to make a probability value zero for a padded area
        # simply degrades the performance, and I gave up this implementation
        # Apply a scaling to make an attention sharp
        w = F.softmax(scaling * e)

        # weighted sum over frames
        # utt x hdim
        c = _weighted_sum(self.enc_h, w)

        return c, w


class NoAtt(chainer.Chain):
    """Compute non-attention layer.

    This layer is a dummy attention layer to be compatible with other
    attention-based models.

    """

    def __init__(self):
        super(NoAtt, self).__init__()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        # Cached context vector, reused while att_prev is passed back in.
        self.c = None

    def reset(self):
        """Reset states."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def __call__(self, enc_hs, dec_z, att_prev):
        """Compute NoAtt forward layer.

        Args:
            enc_hs (chainer.Variable | N-dimensional array):
                Input variable from encoders.
            dec_z: Dummy.
            att_prev (chainer.Variable | None): Attention weight.
                ``None`` is replaced by a uniform distribution over the
                frames of each utterance and the context is recomputed.

        Returns:
            chainer.Variable: Sum over frames.
            chainer.Variable: Attention weight.

        """
        # pre-compute all h outside the decoder loop
        # NOTE: this class never assigns pre_compute_enc_h, so the padding
        # below runs on every call.
        if self.pre_compute_enc_h is None:
            self.enc_h = F.pad_sequence(enc_hs)  # utt x frame x hdim
            self.h_length = self.enc_h.shape[1]

        if att_prev is None:
            # initialize attention weight with uniform dist.
            att_prev = _uniform_attention(self.xp, enc_hs)
            self.c = _weighted_sum(self.enc_h, att_prev)
        elif self.c is None:
            # No cached context yet (first call after reset() with weights
            # supplied by the caller): compute it instead of returning None.
            self.c = _weighted_sum(self.enc_h, att_prev)

        return self.c, att_prev


def att_for(args):
    """Returns an attention layer given the program arguments.

    Args:
        args (Namespace): The arguments. ``args.atype`` selects the layer
            ("dot", "location" or "noatt"); ``eprojs``, ``dunits``, ``adim``,
            ``aconv_chans`` and ``aconv_filts`` are read as needed.

    Returns:
        chainer.Chain: The corresponding attention module.

    Raises:
        NotImplementedError: If ``args.atype`` is not a supported type.

    """
    if args.atype == "dot":
        return AttDot(args.eprojs, args.dunits, args.adim)
    if args.atype == "location":
        return AttLoc(
            args.eprojs, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
        )
    if args.atype == "noatt":
        return NoAtt()
    raise NotImplementedError(
        "chainer supports only noatt, dot, and location attention."
    )