# Source: MotionGPT / mGPT / archs / tools / transformer_layers.py
# (repository page header: bill-jiang, commit "Init" 4409449, 9.94 kB)
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
from torch import Tensor
# Taken from https://github.com/joeynmt/joeynmt/blob/fb66afcbe1beef9acd59283bcc084c4d4c1e6343/joeynmt/transformer_layers.py
# pylint: disable=arguments-differ
class MultiHeadedAttention(nn.Module):
    """
    Multi-Head Attention module from "Attention is All You Need"

    Implementation modified from OpenNMT-py.
    https://github.com/OpenNMT/OpenNMT-py
    """

    def __init__(self, num_heads: int, size: int, dropout: float = 0.1):
        """
        Create a multi-headed attention layer.

        :param num_heads: the number of heads
        :param size: model size (must be divisible by num_heads)
        :param dropout: probability of dropping a unit
        """
        super().__init__()

        assert size % num_heads == 0, (
            "model size ({}) must be divisible by num_heads ({})".format(
                size, num_heads))

        self.head_size = head_size = size // num_heads
        self.model_size = size
        self.num_heads = num_heads

        # Separate input projections for keys, values and queries.
        self.k_layer = nn.Linear(size, num_heads * head_size)
        self.v_layer = nn.Linear(size, num_heads * head_size)
        self.q_layer = nn.Linear(size, num_heads * head_size)

        # Projection applied after the heads have been concatenated again.
        self.output_layer = nn.Linear(size, size)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None):
        """
        Computes multi-headed attention.

        :param k: keys [B, M, D] with M being the sentence length.
        :param v: values [B, M, D]
        :param q: query [B, M, D]
        :param mask: optional mask [B, 1, M] or [B, M, M]; True / non-zero
            entries mark key positions that may be attended to, False / zero
            entries are excluded. Boolean masks and 0/1 masks of other
            dtypes are both accepted.
        :return: attention output [B, query_len, D]
        """
        batch_size = k.size(0)
        num_heads = self.num_heads
        head_size = self.head_size

        def split_heads(projected: Tensor) -> Tensor:
            # [B, len, H * head_size] -> [B, H, len, head_size]
            return projected.view(batch_size, -1, num_heads,
                                  head_size).transpose(1, 2)

        # project the queries (q), keys (k), and values (v) and split heads
        k = split_heads(self.k_layer(k))
        v = split_heads(self.v_layer(v))
        q = split_heads(self.q_layer(q))

        # scaled dot-product scores: [B, H, query_len, key_len]
        q = q / math.sqrt(head_size)
        scores = torch.matmul(q, k.transpose(2, 3))

        # apply the mask (if we have one); a head dimension is inserted so
        # that e.g. [B, 1, M] becomes [B, 1, 1, M] and broadcasts over heads.
        if mask is not None:
            # logical_not (rather than bitwise ``~``) treats any non-zero
            # entry as "keep", so integer and float masks work like bool ones.
            blocked = torch.logical_not(mask).unsqueeze(1)
            # NOTE: a query row whose keys are all blocked has only -inf
            # scores, for which the softmax below yields NaN.
            scores = scores.masked_fill(blocked, float('-inf'))

        # normalise the scores and apply attention dropout
        attention = self.dropout(self.softmax(scores))

        # get context vector (select values with attention): [B, H, query_len,
        # head_size], then merge the heads back to [B, query_len, D]
        context = torch.matmul(attention, v)
        context = context.transpose(1, 2).contiguous().view(
            batch_size, -1, num_heads * head_size)

        # one output vector per query position
        return self.output_layer(context)
# pylint: disable=arguments-differ
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-forward layer

    Normalises the input, projects it to ff_size and back down to
    input_size, and adds the un-normalised input as a residual.
    """

    def __init__(self, input_size: int, ff_size: int, dropout: float = 0.1):
        """
        Initializes position-wise feed-forward layer.

        :param input_size: dimensionality of the input.
        :param ff_size: dimensionality of intermediate representation
        :param dropout: dropout probability, applied after the activation
            and again after the output projection
        """
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_size, eps=1e-6)
        self.pwff_layer = nn.Sequential(
            nn.Linear(input_size, ff_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_size, input_size),
            nn.Dropout(dropout),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply layer norm, the two-layer feed-forward net and the residual.

        :param x: input tensor whose last dimension is input_size
        :return: tensor of the same shape as ``x``
        """
        transformed = self.pwff_layer(self.layer_norm(x))
        # The residual uses the raw input, not the normalised one.
        return transformed + x
# pylint: disable=arguments-differ
class PositionalEncoding(nn.Module):
    """
    Pre-compute position encodings (PE).

    In forward pass, this adds the position-encodings to the
    input for as many time steps as necessary.

    Implementation based on OpenNMT-py.
    https://github.com/OpenNMT/OpenNMT-py
    """

    def __init__(self, size: int = 0, max_len: int = 5000):
        """
        Positional Encoding with maximum length max_len

        :param size: encoding dimensionality; must be even. The default of 0
            is only a placeholder: the frequency term divides by ``size``,
            so it fails with a ZeroDivisionError.
        :param max_len: number of positions for which encodings are
            pre-computed
        :raises ValueError: if ``size`` is odd
        """
        if size % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(size))
        super().__init__()

        # positions as a column vector: [max_len, 1]
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # one frequency per sin/cos pair: 10000^(-2i / size), shape [size / 2]
        div_term = torch.exp(
            torch.arange(0, size, 2, dtype=torch.float) *
            -(math.log(10000.0) / size))
        angles = position * div_term  # [max_len, size / 2]

        pe = torch.zeros(max_len, size)
        pe[:, 0::2] = torch.sin(angles)  # even channels
        pe[:, 1::2] = torch.cos(angles)  # odd channels
        pe = pe.unsqueeze(0)  # shape: [1, max_len, size]

        # Registered as a buffer: follows the module across devices and is
        # stored in the state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)
        self.dim = size

    def forward(self, emb):
        """Embed inputs.

        Args:
            emb (FloatTensor): Sequence of word vectors
                ``(batch_size, seq_len, self.dim)``; the table is sliced
                along dimension 1, so ``seq_len`` should not exceed
                ``max_len``.

        Returns:
            ``emb`` with the encodings of the first ``seq_len`` positions
            added (broadcast over the batch dimension).
        """
        # Add position encodings
        return emb + self.pe[:, :emb.size(1)]
class TransformerEncoderLayer(nn.Module):
    """
    One Transformer encoder layer has a Multi-head attention layer plus
    a position-wise feed-forward layer.
    """

    def __init__(self,
                 size: int = 0,
                 ff_size: int = 0,
                 num_heads: int = 0,
                 dropout: float = 0.1):
        """
        A single Transformer layer.

        :param size: model dimensionality
        :param ff_size: size of the feed-forward intermediate layer
        :param num_heads: number of attention heads
        :param dropout: dropout probability used in the attention, the
            feed-forward block and on the attention output
        """
        super().__init__()

        self.layer_norm = nn.LayerNorm(size, eps=1e-6)
        self.src_src_att = MultiHeadedAttention(num_heads,
                                                size,
                                                dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(size,
                                                    ff_size=ff_size,
                                                    dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.size = size

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """
        Forward pass for a single transformer encoder layer.

        First applies layer norm, then self attention,
        then dropout with residual connection (adding the input to the result),
        and then a position-wise feed-forward layer.

        :param x: layer input
        :param mask: input mask passed to the self-attention; ``None``
            (the default) attends to every position
        :return: output tensor
        """
        x_norm = self.layer_norm(x)
        # self-attention: keys, values and queries all come from the input
        attended = self.src_src_att(x_norm, x_norm, x_norm, mask)
        # the residual connection uses the un-normalised input
        h = self.dropout(attended) + x
        return self.feed_forward(h)
class TransformerDecoderLayer(nn.Module):
    """
    Transformer decoder layer.

    Consists of self-attention, source-attention, and feed-forward.
    """

    def __init__(self,
                 size: int = 0,
                 ff_size: int = 0,
                 num_heads: int = 0,
                 dropout: float = 0.1):
        """
        Represents a single Transformer decoder layer.

        It attends to the source representation and the previous decoder states.

        :param size: model dimensionality
        :param ff_size: size of the feed-forward intermediate layer
        :param num_heads: number of heads
        :param dropout: dropout to apply to input
        """
        super().__init__()
        self.size = size

        self.trg_trg_att = MultiHeadedAttention(num_heads,
                                                size,
                                                dropout=dropout)
        self.src_trg_att = MultiHeadedAttention(num_heads,
                                                size,
                                                dropout=dropout)

        self.feed_forward = PositionwiseFeedForward(size,
                                                    ff_size=ff_size,
                                                    dropout=dropout)

        self.x_layer_norm = nn.LayerNorm(size, eps=1e-6)
        self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6)

        self.dropout = nn.Dropout(dropout)

    # pylint: disable=arguments-differ
    def forward(self,
                x: Tensor = None,
                memory: Tensor = None,
                src_mask: Tensor = None,
                trg_mask: Tensor = None) -> Tensor:
        """
        Forward pass of a single Transformer decoder layer.

        :param x: inputs
        :param memory: source representations
        :param src_mask: source mask
        :param trg_mask: target mask (so as to not condition on future steps)
        :return: output tensor
        """
        # decoder/target self-attention with residual connection
        x_norm = self.x_layer_norm(x)
        self_attended = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask)
        h1 = self.dropout(self_attended) + x

        # source-target attention: keys and values come from the memory,
        # queries from the normalised decoder state
        h1_norm = self.dec_layer_norm(h1)
        h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask)

        # final position-wise feed-forward layer on the second residual sum
        return self.feed_forward(self.dropout(h2) + h1)