Spaces:
Runtime error
Runtime error
File size: 4,847 Bytes
f4dac30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Positional Encoding Module."""
import math
import torch
def _pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Drop the legacy ``pe`` entry from a state dict before it is loaded.

    Note:
        ``self.pe`` was saved until v.0.5.2 and has been omitted since, so the
        ``"pe"`` item under ``prefix`` is removed from ``state_dict`` for
        backward compatibility. The remaining arguments are accepted to match
        the hook signature and are not used.

    """
    # ``pop`` with a default does nothing when the key is absent.
    state_dict.pop(prefix + "pe", None)
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal positional encoding.

    The encoding table is cached in ``self.pe`` as a plain attribute (it is
    neither a parameter nor a registered buffer) and is rebuilt on demand when
    an input longer than the cached table is seen.

    :param int d_model: embedding dim (odd values are supported)
    :param float dropout_rate: dropout rate
    :param int max_len: number of positions pre-computed at construction;
        longer inputs extend the table
    :param reverse: whether to reverse the input position
    """

    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.reverse = reverse
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        # Pre-compute the table for ``max_len`` positions; only the size of
        # dim 1, the dtype and the device of this dummy tensor are used.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x):
        """Make sure ``self.pe`` covers ``x.size(1)`` positions.

        If the cached table is already long enough it is only converted to the
        dtype/device of ``x`` when they differ; otherwise a new table of
        exactly ``x.size(1)`` positions is built.

        Args:
            x (torch.Tensor): Tensor whose dim 1 gives the required length and
                whose dtype/device the table is converted to.

        """
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        length = x.size(1)
        pe = torch.zeros(length, self.d_model)
        if self.reverse:
            # Descending positions: length - 1, ..., 1, 0.
            position = torch.arange(length - 1, -1, -1.0, dtype=torch.float32)
        else:
            position = torch.arange(0, length, dtype=torch.float32)
        position = position.unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles; for even d_model this slice keeps every column.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        self.pe = pe.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        The input is scaled by ``sqrt(d_model)`` before the encoding is added,
        and dropout is applied to the sum.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)
class ScaledPositionalEncoding(PositionalEncoding):
    """Scaled positional encoding module.

    The positional encoding is multiplied by a learnable scalar ``alpha``
    (initialised to 1.0) before being added to the input; unlike the parent
    class, the input itself is not scaled.

    See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Initialize class.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length
        """
        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def reset_parameters(self):
        """Reset ``alpha`` to 1.0."""
        # Fill in place so the parameter keeps its current dtype and device;
        # assigning a freshly created tensor to ``.data`` would replace both.
        with torch.no_grad():
            self.alpha.fill_(1.0)

    def forward(self, x):
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)

        """
        self.extend_pe(x)
        x = x + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(x)
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    Unlike the parent class, the positional encoding is returned next to the
    scaled input instead of being added to it.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Set up the parent encoding with reversed positions.

        :param int d_model: embedding dim
        :param float dropout_rate: dropout rate
        :param int max_len: maximum input length
        """
        super().__init__(
            d_model=d_model,
            dropout_rate=dropout_rate,
            max_len=max_len,
            reverse=True,
        )

    def forward(self, x):
        """Scale the input and return it together with the positional encoding.

        Dropout is applied separately to the scaled input and to the encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)

        Returns:
            torch.Tensor: x. Its shape is (batch, time, ...)
            torch.Tensor: pos_emb. Its shape is (1, time, ...)

        """
        self.extend_pe(x)
        n_frames = x.size(1)
        scaled = x * self.xscale
        # NOTE(review): the cached table is built in descending position order
        # for the cached length, and this takes its first ``time`` rows, so
        # when the cache is longer than the input the rows are the highest
        # cached positions rather than ``time - 1 .. 0``. Confirm this is
        # intended before changing it; trained models may depend on it.
        pos_emb = self.pe[:, :n_frames]
        return self.dropout(scaled), self.dropout(pos_emb)
|