p1atdev committed on
Commit
4371fee
1 Parent(s): f814b50

Upload 2 files

Files changed (2)
  1. configuration_retnet.py +122 -0
  2. modeling_retnet.py +1313 -4
configuration_retnet.py ADDED
@@ -0,0 +1,122 @@
1
+ # by syncdoth: https://github.com/syncdoth/RetNet/blob/main/retnet/configuration_retnet.py
2
+
3
+ from dataclasses import dataclass
4
+ import json
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ def load_config_from_json(config_file):
10
+ with open(config_file, "r") as f:
11
+ config = json.load(f)
12
+ config = RetNetConfig.from_dict(config)
13
+ return config
14
+
15
+
16
+ @dataclass
17
+ class RetNetConfig(PretrainedConfig):
18
+ model_type = "retnet"
19
+ initializer_range: float = 0.02
20
+ activation_fn: str = "gelu"
21
+ dropout: float = 0.0 # dropout probability
22
+ activation_dropout: float = 0.0 # dropout probability after activation in FFN.
23
+ drop_path_rate: float = 0.0
24
+ decoder_embed_dim: int = 768 # decoder embedding dimension
25
+ decoder_value_embed_dim: int = 1280 # decoder value embedding dimension
26
+ decoder_ffn_embed_dim: int = 1280 # decoder embedding dimension for FFN
27
+ decoder_layers: int = 12 # num decoder layers
28
+ decoder_retention_heads: int = 3 # num decoder retention heads
29
+ decoder_normalize_before: bool = True # apply layernorm before each decoder block
30
+ layernorm_embedding: bool = False # add layernorm to embedding
31
+ no_scale_embedding: bool = True # if True, don't scale embeddings
32
+ recurrent_chunk_size: int = 512
33
+ use_lm_decay: bool = False
34
+ use_glu: bool = True # use GLU instead of FFN
35
+ z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4
36
+ deepnorm: bool = False
37
+ subln: bool = True
38
+ use_ffn_rms_norm: bool = False
39
+ layernorm_eps: float = 1e-6
40
+ tie_word_embeddings: bool = False
41
+
42
+ def __init__(
43
+ self,
44
+ vocab_size: int = 50257,
45
+ initializer_range: float = 0.02,
46
+ is_decoder: bool = True,
47
+ pad_token_id: int = 0,
48
+ eos_token_id: int = 0,
49
+ output_retentions: bool = False,
50
+ use_cache: bool = True,
51
+ forward_impl: str = "parallel",
52
+ activation_fn: str = "gelu",
53
+ dropout: float = 0.0, # dropout probability
54
+ activation_dropout: float = 0.0, # dropout probability after activation in FFN.
55
+ drop_path_rate: float = 0.0,
56
+ decoder_embed_dim: int = 768, # decoder embedding dimension
57
+ decoder_value_embed_dim: int = 1280, # decoder value embedding dimension
58
+ decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN
59
+ decoder_layers: int = 12, # num decoder layers
60
+ decoder_retention_heads: int = 3, # num decoder retention heads
61
+ decoder_normalize_before: bool = True, # apply layernorm before each decoder block
62
+ layernorm_embedding: bool = False, # add layernorm to embedding
63
+ no_scale_embedding: bool = True, # if True, don't scale embeddings
64
+ recurrent_chunk_size: int = 512,
65
+ use_glu: bool = True, # use GLU instead of FFN
66
+ z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4
67
+ use_lm_decay: bool = False,
68
+ deepnorm: bool = False,
69
+ subln: bool = True,
70
+ use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN
71
+ layernorm_eps: float = 1e-6,
72
+ tie_word_embeddings: bool = False,
73
+ **kwargs
74
+ ):
75
+ self.vocab_size = vocab_size
76
+ self.initializer_range = initializer_range
77
+ self.output_retentions = output_retentions
78
+ self.use_lm_decay = use_lm_decay
79
+ self.use_glu = use_glu
80
+ self.z_loss_coeff = z_loss_coeff
81
+ # size related
82
+ self.decoder_embed_dim = decoder_embed_dim
83
+ self.decoder_value_embed_dim = decoder_value_embed_dim
84
+ self.decoder_retention_heads = decoder_retention_heads
85
+ self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
86
+ self.decoder_layers = decoder_layers
87
+ # normalization related
88
+ self.decoder_normalize_before = decoder_normalize_before
89
+ self.activation_fn = activation_fn
90
+ self.dropout = dropout
91
+ self.drop_path_rate = drop_path_rate
92
+ self.activation_dropout = activation_dropout
93
+ self.no_scale_embedding = no_scale_embedding
94
+ self.layernorm_embedding = layernorm_embedding
95
+ self.deepnorm = deepnorm
96
+ self.subln = subln
97
+ self.use_ffn_rms_norm = use_ffn_rms_norm
98
+ self.layernorm_eps = layernorm_eps
99
+ # Blockwise
100
+ self.recurrent_chunk_size = recurrent_chunk_size
101
+ self.forward_impl = forward_impl
102
+
103
+ if self.deepnorm:
104
+ self.decoder_normalize_before = False
105
+ self.subln = False
106
+ if self.subln:
107
+ self.decoder_normalize_before = True
108
+ self.deepnorm = False
109
+
110
+ super().__init__(
111
+ is_decoder=is_decoder,
112
+ pad_token_id=pad_token_id,
113
+ eos_token_id=eos_token_id,
114
+ use_cache=use_cache,
115
+ tie_word_embeddings=tie_word_embeddings,
116
+ **kwargs
117
+ )
118
+
119
+ def override(self, args):
120
+ for hp in self.__dict__.keys():
121
+ if getattr(args, hp, None) is not None:
122
+ self.__dict__[hp] = getattr(args, hp, None)
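
A minimal usage sketch (not part of this commit) of how the configuration defined above might be built directly or loaded from JSON; the file name below is a hypothetical placeholder.

from configuration_retnet import RetNetConfig, load_config_from_json

# construct a config directly, overriding a few of the defaults defined above
config = RetNetConfig(vocab_size=50257, decoder_layers=12, decoder_retention_heads=3)

# or load one from a JSON file (hypothetical path); from_dict is inherited from PretrainedConfig
# config = load_config_from_json("retnet_config.json")
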
modeling_retnet.py CHANGED
@@ -1,8 +1,1317 @@
1
- from RetNet.retnet.modeling_retnet import RetNetForCausalLM, RetNetConfig
2

3

4
- LightNovelIntroRetNetModelConfig = RetNetConfig
5

6

7
- class LightNovelIntroRetNetModel(RetNetForCausalLM):
8
- config_class = LightNovelIntroRetNetModelConfig
1
+ # by syncdoth: https://github.com/syncdoth/RetNet/blob/main/retnet/modeling_retnet.py
2
 
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple, Union
6
 
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from timm.models.layers import drop_path
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+ from transformers import top_k_top_p_filtering
15
+ from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
 
19
+ try:
20
+ from apex.normalization import FusedLayerNorm as LayerNorm
21
+ except ModuleNotFoundError:
22
+ from torch.nn import LayerNorm
23
 
24
+ from .configuration_retnet import RetNetConfig
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ # helper functions
30
+ def split_heads(tensors, bsz, seqlen, num_heads):
31
+ assert isinstance(tensors, (tuple, list))
32
+ return [x.view(bsz, seqlen, num_heads, -1).transpose(1, 2) for x in tensors]
33
+
34
+
35
+ def rotate_every_two(x):
36
+ x1 = x[:, :, :, ::2]
37
+ x2 = x[:, :, :, 1::2]
38
+ x = torch.stack((-x2, x1), dim=-1)
39
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
40
+
41
+
42
+ def theta_shift(x, sin, cos):
43
+ return (x * cos) + (rotate_every_two(x) * sin)
44
+
45
+
46
+ def get_activation_fn(activation):
47
+ if activation == "relu":
48
+ return F.relu
49
+ elif activation == "gelu":
50
+ return F.gelu
51
+ elif activation == "swish":
52
+ return F.silu
53
+ else:
54
+ raise NotImplementedError
55
+
56
+
57
+ class RMSNorm(nn.Module):
58
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
59
+ super().__init__()
60
+ self.normalized_shape = dim
61
+ self.eps = eps
62
+ self.elementwise_affine = elementwise_affine
63
+ if self.elementwise_affine:
64
+ self.weight = nn.Parameter(torch.ones(dim))
65
+ else:
66
+ self.register_parameter("weight", None)
67
+
68
+ def _norm(self, x):
69
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
70
+
71
+ def forward(self, x):
72
+ output = self._norm(x.float()).type_as(x)
73
+ if self.weight is not None:
74
+ output = output * self.weight
75
+ return output
76
+
77
+
78
+ class RetNetRelPos(nn.Module):
79
+ def __init__(self, config: RetNetConfig):
80
+ super().__init__()
81
+ self.config = config
82
+ num_heads = config.decoder_retention_heads
83
+
84
+ angle = 1.0 / (
85
+ 10000 ** torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2)
86
+ )
87
+ angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
88
+ # decay (gamma)
89
+ if config.use_lm_decay:
90
+ # NOTE: alternative way described in the paper
91
+ s = torch.log(torch.tensor(1 / 32))
92
+ e = torch.log(torch.tensor(1 / 512))
93
+ decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,]
94
+ else:
95
+ decay = torch.log(
96
+ 1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float))
97
+ )
98
+ self.register_buffer("angle", angle)
99
+ self.register_buffer("decay", decay)
100
+ self.recurrent_chunk_size = config.recurrent_chunk_size
101
+
102
+ def forward(
103
+ self,
104
+ slen,
105
+ forward_impl="parallel",
106
+ recurrent_chunk_size=None,
107
+ retention_mask=None,
108
+ get_decay_scale=True,
109
+ ):
110
+ if forward_impl == "recurrent":
111
+ sin = torch.sin(self.angle * (slen - 1))
112
+ cos = torch.cos(self.angle * (slen - 1))
113
+ retention_rel_pos = ((sin, cos), self.decay.view(1, -1, 1, 1).exp())
114
+ elif forward_impl == "chunkwise":
115
+ if recurrent_chunk_size is None:
116
+ recurrent_chunk_size = self.recurrent_chunk_size
117
+ index = torch.arange(slen).to(self.decay)
118
+ sin = torch.sin(index[:, None] * self.angle[None, :])
119
+ cos = torch.cos(index[:, None] * self.angle[None, :])
120
+
121
+ block_index = torch.arange(recurrent_chunk_size).to(self.decay)
122
+ mask = torch.tril(
123
+ torch.ones(recurrent_chunk_size, recurrent_chunk_size)
124
+ ).to(self.decay)
125
+ mask = torch.masked_fill(
126
+ block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf")
127
+ )
128
+ mask = torch.exp(mask * self.decay[:, None, None])
129
+ mask = torch.nan_to_num(mask)
130
+ mask = mask.unsqueeze(0) # [1, h, t, t]
131
+ # TODO: need to handle retention_mask
132
+ # scaling
133
+ value_inner_decay = mask[:, :, -1] / mask[:, :, -1].sum(
134
+ dim=-1, keepdim=True
135
+ )
136
+ value_inner_decay = value_inner_decay.unsqueeze(-1)
137
+ scale = mask.sum(dim=-1, keepdim=True).sqrt()
138
+ inner_mask = mask / scale
139
+
140
+ cross_decay = torch.exp(self.decay * recurrent_chunk_size)
141
+ query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
142
+ cross_decay = cross_decay[None, :, None, None]
143
+ query_inner_decay = query_inner_decay[None, :, :, None] / (
144
+ scale / mask[:, :, -1].sum(dim=-1)[:, :, None, None]
145
+ )
146
+ # decay_scale (used for kv cache)
147
+ if get_decay_scale:
148
+ decay_scale = self.compute_decay_scale(slen, retention_mask)
149
+ else:
150
+ decay_scale = None
151
+ retention_rel_pos = (
152
+ (sin, cos),
153
+ (
154
+ inner_mask,
155
+ cross_decay,
156
+ query_inner_decay,
157
+ value_inner_decay,
158
+ decay_scale,
159
+ ),
160
+ )
161
+ else: # parallel
162
+ index = torch.arange(slen).to(self.decay)
163
+ sin = torch.sin(index[:, None] * self.angle[None, :])
164
+ cos = torch.cos(index[:, None] * self.angle[None, :])
165
+ mask = torch.tril(torch.ones(slen, slen)).to(self.decay)
166
+ mask = torch.masked_fill(
167
+ index[:, None] - index[None, :], ~mask.bool(), float("inf")
168
+ )
169
+ mask = torch.exp(mask * self.decay[:, None, None])
170
+ mask = torch.nan_to_num(mask)
171
+ mask = mask.unsqueeze(0) # [1, h, t, t]
172
+ if retention_mask is not None:
173
+ # this is required for left padding
174
+ mask = mask * retention_mask.float().view(-1, 1, 1, slen).to(mask)
175
+
176
+ # scaling
177
+ mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
178
+ mask = torch.nan_to_num(mask, nan=0.0)
179
+ # decay_scale (used for kv cache)
180
+ if get_decay_scale:
181
+ decay_scale = self.compute_decay_scale(slen, retention_mask)
182
+ else:
183
+ decay_scale = None
184
+ # mask processing for intra decay
185
+ if retention_mask is not None:
186
+ max_non_zero = (
187
+ torch.cumsum(retention_mask, dim=-1).max(dim=-1).indices
188
+ ) # [b,]
189
+ intra_decay = mask[range(mask.shape[0]), :, max_non_zero]
190
+ else:
191
+ intra_decay = mask[:, :, -1]
192
+
193
+ retention_rel_pos = ((sin, cos), (mask, intra_decay, decay_scale))
194
+
195
+ return retention_rel_pos
196
+
197
+ def compute_decay_scale(self, slen, retention_mask=None):
198
+ exponent = torch.arange(slen, device=self.decay.device).float()
199
+ decay_scale = self.decay.exp().view(-1, 1) ** exponent.view(1, -1) # [h, t]
200
+ if retention_mask is not None:
201
+ seqlen = retention_mask.sum(dim=-1) # [b,]
202
+ bsz = seqlen.size(0)
203
+ decay_scale = decay_scale.unsqueeze(0).repeat(bsz, 1, 1) # [b, h, t]
204
+ for i, pos in enumerate(seqlen):
205
+ # the formula for decay_scale is `sum(gamma^i) for i in [0, slen).`
206
+ # Since the retention_mask is 0 for padding, we can set the decay_scale
207
+ # to 0 for the padding positions.
208
+ decay_scale[i, :, pos.item() :] = 0
209
+ else:
210
+ bsz = 1
211
+ decay_scale = decay_scale.sum(-1).view(bsz, -1, 1, 1) # [b, h, 1, 1]
212
+ return decay_scale
213
+
214
+
215
+ class MultiScaleRetention(nn.Module):
216
+ def __init__(
217
+ self,
218
+ config: RetNetConfig,
219
+ gate_fn="swish",
220
+ use_bias=False,
221
+ tensor_parallel=False,
222
+ ):
223
+ super().__init__()
224
+ self.config = config
225
+ self.embed_dim = config.decoder_embed_dim
226
+ self.value_dim = config.decoder_value_embed_dim
227
+ self.num_heads = config.decoder_retention_heads
228
+ self.head_dim = self.value_dim // self.num_heads
229
+ self.key_dim = self.embed_dim // self.num_heads
230
+ self.scaling = self.key_dim**-0.5
231
+
232
+ self.gate_fn = get_activation_fn(activation=str(gate_fn))
233
+
234
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
235
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
236
+ self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
237
+ self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
238
+
239
+ self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias)
240
+
241
+ self.group_norm = RMSNorm(
242
+ self.head_dim, eps=config.layernorm_eps, elementwise_affine=False
243
+ )
244
+ self.reset_parameters()
245
+
246
+ if tensor_parallel:
247
+ self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False)
248
+ else:
249
+ self.decay_proj = None
250
+
251
+ def reset_parameters(self):
252
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
253
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
254
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
255
+ nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
256
+ nn.init.xavier_uniform_(self.out_proj.weight)
257
+
258
+ def parallel_retention(self, q, k, v, decay_mask):
259
+ """
260
+ q, # bsz * num_head * len * qk_dim
261
+ k, # bsz * num_head * len * qk_dim
262
+ v, # bsz * num_head * len * v_dim
263
+ decay_mask, # (1 or bsz) * num_head * len * len
264
+ """
265
+ decay_mask, intra_decay, scale = decay_mask
266
+ # just return retention_rel_pos projected
267
+ # TODO: for shardformer
268
+ if self.decay_proj is not None:
269
+ decay_mask = self.decay_proj(decay_mask.transpose(-1, -3)).transpose(-3, -1)
270
+
271
+ # [b, h, t, t]
272
+ retention = q @ k.transpose(-1, -2) # (scaled dot-product)
273
+ retention = retention * decay_mask
274
+
275
+ # invariant after normalization
276
+ retention = retention / retention.detach().sum(
277
+ dim=-1, keepdim=True
278
+ ).abs().clamp(min=1)
279
+
280
+ output = retention @ v # [b, h, t, v_dim / h]
281
+ output = output.transpose(1, 2) # [b, t, h, v_dim / h]
282
+
283
+ if self.training: # skip cache
284
+ return output, None, retention
285
+
286
+ if self.decay_proj is not None:
287
+ intra_decay = self.decay_proj(intra_decay.transpose(-1, -2)).transpose(
288
+ -2, -1
289
+ )
290
+
291
+ # kv cache: [b, h, t, v_dim, qk_dim]
292
+ current_kv = k.unsqueeze(-2) * v.unsqueeze(-1)
293
+ intra_decay = intra_decay[:, :, :, None, None] # [b, h, t, 1, 1]
294
+ current_kv = (current_kv * intra_decay).sum(2) # [b, h, v_dim, qk_dim]
295
+
296
+ cache = {"prev_key_value": current_kv, "scale": scale}
297
+ return output, cache, retention
298
+
299
+ def recurrent_retention(
300
+ self, q, k, v, decay, past_key_value=None, retention_mask=None
301
+ ):
302
+ """
303
+ q, k, v, # bsz * num_head * 1 * qkv_dim
304
+ past_key_value:
305
+ - "prev_key_value" # bsz * num_head * v_dim * qk_dim
306
+ - "scale" # (1 or bsz) * num_head * 1 * 1
307
+ decay # (1 or bsz) * num_head * 1 * 1
308
+ retention_mask # bsz * 1
309
+ """
310
+ if retention_mask is not None:
311
+ retention_mask = retention_mask.float().view(-1, 1, 1, 1).to(decay)
312
+ else:
313
+ retention_mask = torch.ones(k.size(0), 1, 1, 1).to(decay)
314
+ # (b, h, v_dim, qk_dim)
315
+ current_kv = k * v.transpose(-1, -2) * retention_mask
316
+
317
+ if past_key_value is not None and "prev_key_value" in past_key_value:
318
+ prev_kv = past_key_value["prev_key_value"]
319
+ prev_scale = past_key_value["scale"]
320
+ scale = torch.where(retention_mask == 0, prev_scale, prev_scale * decay + 1)
321
+ # connect prev_kv and current_kv
322
+ # how much to decay prev_kv
323
+ decay_amount = prev_scale.sqrt() * decay / scale.sqrt()
324
+ decay_amount = torch.where(retention_mask == 0, 1, decay_amount)
325
+ prev_kv = prev_kv * decay_amount # decay prev_kv
326
+ current_kv = current_kv / scale.sqrt() # scale current_kv
327
+ current_kv = torch.nan_to_num(
328
+ current_kv, nan=0.0
329
+ ) # remove nan, scale might be 0
330
+
331
+ current_kv = prev_kv + current_kv
332
+ else:
333
+ scale = torch.ones_like(decay)
334
+ # when retention_mask is 0 at the beginning, setting scale to 1 will
335
+ # make the first retention to use the padding incorrectly. Hence,
336
+ # setting it to 0 here. This is a little ugly, so we might want to
337
+ # change this later. TODO: improve
338
+ scale = torch.where(retention_mask == 0, torch.zeros_like(decay), scale)
339
+
340
+ output = torch.sum(q * current_kv, dim=3).unsqueeze(1) # (b, 1, h, d_v)
341
+
342
+ cache = {"prev_key_value": current_kv, "scale": scale}
343
+ return output, cache
344
+
345
+ def chunkwise_retention(self, q, k, v, decay_mask):
346
+ """
347
+ q, k, v, # bsz * num_head * seqlen * qkv_dim
348
+ past_key_value:
349
+ - "prev_key_value" # bsz * num_head * v_dim * qk_dim
350
+ - "scale" # (1 or bsz) * num_head * 1 * 1
351
+ decay_mask, # 1 * num_head * chunk_size * chunk_size
352
+ cross_decay, # 1 * num_head * 1 * 1
353
+ inner_decay, # 1 * num_head * chunk_size * 1
354
+ """
355
+ # TODO: not working properly
356
+ (
357
+ decay_mask,
358
+ cross_decay,
359
+ query_inner_decay,
360
+ value_inner_decay,
361
+ decay_scale,
362
+ ) = decay_mask
363
+ bsz, _, tgt_len, _ = v.size()
364
+ chunk_len = decay_mask.size(-1)
365
+ assert tgt_len % chunk_len == 0
366
+ num_chunks = tgt_len // chunk_len
367
+
368
+ # [b, n_c, h, t_c, qkv_dim]
369
+ q = q.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(
370
+ 1, 2
371
+ )
372
+ k = k.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(
373
+ 1, 2
374
+ )
375
+ v = v.view(bsz, self.num_heads, num_chunks, chunk_len, self.head_dim).transpose(
376
+ 1, 2
377
+ )
378
+
379
+ k_t = k.transpose(-1, -2)
380
+
381
+ qk_mat = q @ k_t # [b, n_c, h, t_c, t_c]
382
+ qk_mat = qk_mat * decay_mask.unsqueeze(1)
383
+ inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
384
+ qk_mat = qk_mat / inner_scale
385
+ # [b, n_c, h, t_c, v_dim]
386
+ inner_output = torch.matmul(qk_mat, v)
387
+
388
+ # reduce kv in one chunk
389
+ # [b, n_c, h, qk_dim, v_dim]
390
+ kv = k_t @ (v * value_inner_decay)
391
+ # kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
392
+
393
+ kv_recurrent = []
394
+ cross_scale = []
395
+ kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
396
+ kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
397
+
398
+ # accumulate kv by loop
399
+ for i in range(num_chunks):
400
+ kv_recurrent.append(kv_state / kv_scale)
401
+ cross_scale.append(kv_scale)
402
+ kv_state = kv_state * cross_decay + kv[:, i]
403
+ kv_scale = (
404
+ kv_state.detach()
405
+ .abs()
406
+ .sum(dim=-2, keepdim=True)
407
+ .max(dim=-1, keepdim=True)
408
+ .values.clamp(min=1)
409
+ )
410
+
411
+ kv_recurrent = torch.stack(kv_recurrent, dim=1)
412
+ cross_scale = torch.stack(cross_scale, dim=1)
413
+
414
+ all_scale = torch.maximum(inner_scale, cross_scale)
415
+ align_inner_scale = all_scale / inner_scale
416
+ align_cross_scale = all_scale / cross_scale
417
+
418
+ cross_output = (q * query_inner_decay.unsqueeze(1)) @ kv_recurrent
419
+ output = inner_output / align_inner_scale + cross_output / align_cross_scale
420
+ output = output.transpose(2, 3) # [b, n_c, t_c, h, v_dim]
421
+
422
+ cache = {"prev_key_value": kv_state.transpose(-2, -1), "scale": decay_scale}
423
+ return output, cache
424
+
425
+ def forward(
426
+ self,
427
+ hidden_states: torch.Tensor,
428
+ rel_pos: Tuple[Tuple[torch.Tensor]],
429
+ retention_mask: Optional[torch.Tensor] = None,
430
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
431
+ forward_impl: str = "parallel",
432
+ output_retentions: Optional[bool] = False,
433
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
434
+ B, T, H = hidden_states.size()
435
+ (sin, cos), decay_mask = rel_pos
436
+ # projections
437
+ q = self.q_proj(hidden_states)
438
+ k = self.k_proj(hidden_states)
439
+ v = self.v_proj(hidden_states)
440
+ g = self.g_proj(hidden_states)
441
+ # multi-head
442
+ q, k, v = split_heads((q, k, v), B, T, self.num_heads)
443
+ k *= self.scaling # for scaled dot product
444
+ # rotate
445
+ # NOTE: theta_shift has a bug on the mps device.
446
+ qr = theta_shift(q, sin, cos)
447
+ kr = theta_shift(k, sin, cos)
448
+
449
+ # retention
450
+ if forward_impl == "parallel":
451
+ retention_out, curr_kv, retention_weights = self.parallel_retention(
452
+ qr, kr, v, decay_mask
453
+ )
454
+ elif forward_impl == "recurrent":
455
+ retention_out, curr_kv = self.recurrent_retention(
456
+ qr,
457
+ kr,
458
+ v,
459
+ decay_mask,
460
+ past_key_value=past_key_value,
461
+ retention_mask=retention_mask,
462
+ )
463
+ elif forward_impl == "chunkwise":
464
+ retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask)
465
+ else:
466
+ raise ValueError(f"forward_impl {forward_impl} not supported.")
467
+
468
+ # concat heads
469
+ normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
470
+ # out gate & proj
471
+ out = self.gate_fn(g) * normed
472
+ out = self.out_proj(out)
473
+
474
+ outputs = (out, curr_kv)
475
+ if output_retentions:
476
+ outputs += (retention_weights,) if forward_impl == "parallel" else (None,)
477
+ return outputs
478
+
479
+
480
+ class FeedForwardNetwork(nn.Module):
481
+ def __init__(
482
+ self,
483
+ embed_dim,
484
+ ffn_dim,
485
+ activation_fn,
486
+ dropout,
487
+ activation_dropout,
488
+ layernorm_eps,
489
+ subln=False,
490
+ use_rms_norm=False,
491
+ ):
492
+ super().__init__()
493
+ self.embed_dim = embed_dim
494
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
495
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
496
+ self.dropout_module = torch.nn.Dropout(dropout)
497
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
498
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
499
+ if subln:
500
+ if use_rms_norm:
501
+ self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
502
+ else:
503
+ self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
504
+ else:
505
+ self.ffn_layernorm = None
506
+
507
+ def reset_parameters(self):
508
+ self.fc1.reset_parameters()
509
+ self.fc2.reset_parameters()
510
+ if self.ffn_layernorm is not None:
511
+ self.ffn_layernorm.reset_parameters()
512
+
513
+ def forward(self, x):
514
+ x_shape = x.shape
515
+ x = x.reshape(-1, x.size(-1))
516
+ x = self.fc1(x)
517
+ x = self.activation_fn(x.float()).type_as(x)
518
+ x = self.activation_dropout_module(x)
519
+ if self.ffn_layernorm is not None:
520
+ x = self.ffn_layernorm(x)
521
+ x = self.fc2(x)
522
+ x = x.view(x_shape)
523
+ x = self.dropout_module(x)
524
+ return x
525
+
526
+
527
+ class GLU(nn.Module):
528
+ def __init__(
529
+ self,
530
+ embed_dim,
531
+ ffn_dim,
532
+ activation_fn,
533
+ dropout,
534
+ activation_dropout,
535
+ ):
536
+ super().__init__()
537
+ self.embed_dim = embed_dim
538
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
539
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
540
+ self.dropout_module = torch.nn.Dropout(dropout)
541
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
542
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
543
+ self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
544
+
545
+ def reset_parameters(self):
546
+ self.fc1.reset_parameters()
547
+ self.fc2.reset_parameters()
548
+ self.gate.reset_parameters()
549
+
550
+ def forward(self, x):
551
+ x_shape = x.shape
552
+ x = x.reshape(-1, x.size(-1))
553
+ g = self.gate(x)
554
+ x = self.fc1(x)
555
+ x = self.activation_fn(x.float()).type_as(x) * g
556
+ x = self.activation_dropout_module(x)
557
+ x = self.fc2(x)
558
+ x = x.view(x_shape)
559
+ x = self.dropout_module(x)
560
+ return x
561
+
562
+
563
+ class DropPath(nn.Module):
564
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
565
+
566
+ def __init__(self, drop_prob=None):
567
+ super(DropPath, self).__init__()
568
+ self.drop_prob = drop_prob
569
+
570
+ def forward(self, x):
571
+ return drop_path(x, self.drop_prob, self.training)
572
+
573
+ def extra_repr(self):
574
+ return "p={}".format(self.drop_prob)
575
+
576
+
577
+ class RetNetDecoderLayer(nn.Module):
578
+ def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
579
+ super().__init__()
580
+ self.config = config
581
+ self.embed_dim = config.decoder_embed_dim
582
+ self.dropout_module = torch.nn.Dropout(config.dropout)
583
+
584
+ if config.drop_path_rate > 0:
585
+ drop_path_prob = np.linspace(
586
+ 0, config.drop_path_rate, config.decoder_layers
587
+ )[depth]
588
+ self.drop_path = DropPath(drop_path_prob)
589
+ else:
590
+ self.drop_path = None
591
+
592
+ self.retention = MultiScaleRetention(
593
+ config, use_bias=False, tensor_parallel=tensor_parallel
594
+ )
595
+
596
+ self.normalize_before = config.decoder_normalize_before
597
+
598
+ self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
599
+
600
+ self.ffn_dim = config.decoder_ffn_embed_dim
601
+
602
+ self.ffn = self.build_ffn()
603
+
604
+ self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
605
+
606
+ if config.deepnorm:
607
+ self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
608
+ else:
609
+ self.alpha = 1.0
610
+
611
+ def build_ffn(self):
612
+ if self.config.use_glu:
613
+ return GLU(
614
+ self.embed_dim,
615
+ self.ffn_dim,
616
+ self.config.activation_fn,
617
+ self.config.dropout,
618
+ self.config.activation_dropout,
619
+ )
620
+ else:
621
+ return FeedForwardNetwork(
622
+ self.embed_dim,
623
+ self.ffn_dim,
624
+ self.config.activation_fn,
625
+ self.config.dropout,
626
+ self.config.activation_dropout,
627
+ self.config.layernorm_eps,
628
+ self.config.subln,
629
+ self.config.use_ffn_rms_norm,
630
+ )
631
+
632
+ def residual_connection(self, x, residual):
633
+ return residual * self.alpha + x
634
+
635
+ def forward(
636
+ self,
637
+ hidden_states: torch.Tensor,
638
+ retention_rel_pos: Tuple[Tuple[torch.Tensor]],
639
+ retention_mask: Optional[torch.Tensor] = None,
640
+ forward_impl: str = "parallel",
641
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
642
+ output_retentions: Optional[bool] = False,
643
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
644
+ residual = hidden_states
645
+ if self.normalize_before:
646
+ hidden_states = self.retention_layer_norm(hidden_states)
647
+
648
+ msr_outs = self.retention(
649
+ hidden_states,
650
+ retention_rel_pos,
651
+ retention_mask=retention_mask,
652
+ past_key_value=past_key_value,
653
+ forward_impl=forward_impl,
654
+ output_retentions=output_retentions,
655
+ )
656
+ hidden_states = msr_outs[0]
657
+ curr_kv = msr_outs[1]
658
+
659
+ hidden_states = self.dropout_module(hidden_states)
660
+
661
+ if self.drop_path is not None:
662
+ hidden_states = self.drop_path(hidden_states)
663
+
664
+ hidden_states = self.residual_connection(hidden_states, residual)
665
+ if not self.normalize_before:
666
+ hidden_states = self.retention_layer_norm(hidden_states)
667
+
668
+ residual = hidden_states
669
+ if self.normalize_before:
670
+ hidden_states = self.final_layer_norm(hidden_states)
671
+
672
+ hidden_states = self.ffn(hidden_states)
673
+
674
+ if self.drop_path is not None:
675
+ hidden_states = self.drop_path(hidden_states)
676
+
677
+ hidden_states = self.residual_connection(hidden_states, residual)
678
+ if not self.normalize_before:
679
+ hidden_states = self.final_layer_norm(hidden_states)
680
+
681
+ outputs = (hidden_states, curr_kv)
682
+
683
+ if output_retentions:
684
+ outputs += (msr_outs[2],)
685
+ return outputs
686
+
687
+
688
+ class RetNetPreTrainedModel(PreTrainedModel):
689
+ # copied from LlamaPretrainedModel
690
+ config_class = RetNetConfig
691
+ base_model_prefix = "model"
692
+ supports_gradient_checkpointing = True
693
+ _no_split_modules = ["RetNetDecoderLayer"]
694
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
695
+
696
+ def _init_weights(self, module):
697
+ """
698
+ Following original retnet, weights are already initialized in their own
699
+ ways within their own init.
700
+ """
701
+ pass
702
+ # below is copied from LlamaPretrainedModel
703
+ # std = self.config.initializer_range
704
+ # if isinstance(module, nn.Linear):
705
+ # module.weight.data.normal_(mean=0.0, std=std)
706
+ # if module.bias is not None:
707
+ # module.bias.data.zero_()
708
+ # elif isinstance(module, nn.Embedding):
709
+ # module.weight.data.normal_(mean=0.0, std=std)
710
+ # if module.padding_idx is not None:
711
+ # module.weight.data[module.padding_idx].zero_()
712
+
713
+
714
+ @dataclass
715
+ class RetNetOutputWithPast(ModelOutput):
716
+ """
717
+ Class for the RetNet model's outputs that may also contain past key/values (to speed up sequential decoding).
718
+
719
+ config:
720
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, decoder_embed_dim)`):
721
+ Sequence of hidden-states at the output of the last layer of the model.
722
+
723
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
724
+ decoder_embed_dim)` is output.
725
+ past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
726
+ - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim)
727
+ - "scale": shape=((1 or bsz) * num_head * 1 * 1)
728
+
729
+ Contains pre-computed hidden-states (key and values in the multi-scale retention blocks)
730
+ that can be used (see `past_key_values` input) to speed up sequential decoding.
731
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
732
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
733
+ one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`.
734
+
735
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
736
+ retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`):
737
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
738
+ sequence_length)`.
739
+
740
+ Retentions weights, used for visualization.
741
+
742
+ attentions (`tuple(torch.FloatTensor)`, *optional*): kept for backward compatibility; same as `retentions`.
743
+ """
744
+
745
+ last_hidden_state: torch.FloatTensor = None
746
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None
747
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
748
+ retentions: Optional[Tuple[torch.FloatTensor]] = None
749
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
750
+
751
+
752
+ class RetNetModel(RetNetPreTrainedModel):
753
+ def __init__(
754
+ self,
755
+ config: RetNetConfig,
756
+ embed_tokens: nn.Embedding = None,
757
+ tensor_parallel: bool = False,
758
+ ):
759
+ super().__init__(config)
760
+ self.config = config
761
+
762
+ self.dropout_module = torch.nn.Dropout(config.dropout)
763
+
764
+ self.embed_dim = config.decoder_embed_dim
765
+ self.embed_scale = (
766
+ 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
767
+ )
768
+
769
+ if embed_tokens is None:
770
+ embed_tokens = nn.Embedding(
771
+ config.vocab_size, config.decoder_embed_dim, config.pad_token_id
772
+ )
773
+ self.embed_tokens = embed_tokens
774
+
775
+ if config.layernorm_embedding:
776
+ self.layernorm_embedding = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
777
+ else:
778
+ self.layernorm_embedding = None
779
+
780
+ self.layers = nn.ModuleList([])
781
+
782
+ for i in range(config.decoder_layers):
783
+ self.layers.append(
784
+ RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
785
+ )
786
+
787
+ self.decoder_layers = len(self.layers)
788
+
789
+ if config.decoder_normalize_before:
790
+ self.layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
791
+ else:
792
+ self.layer_norm = None
793
+
794
+ self.retnet_rel_pos = RetNetRelPos(config)
795
+ self.recurrent_chunk_size = config.recurrent_chunk_size
796
+
797
+ if config.deepnorm:
798
+ init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
799
+ for name, p in self.named_parameters():
800
+ if (
801
+ "fc1" in name
802
+ or "fc2" in name
803
+ or "out_proj" in name
804
+ or "v_proj" in name
805
+ ):
806
+ p.data.div_(init_scale)
807
+
808
+ if config.subln and not config.use_glu:
809
+ init_scale = math.sqrt(math.log(config.decoder_layers * 2))
810
+ for name, p in self.named_parameters():
811
+ if (
812
+ "fc1" in name
813
+ or "fc2" in name
814
+ or "out_proj" in name
815
+ or "v_proj" in name
816
+ ):
817
+ p.data.mul_(init_scale)
818
+
819
+ self.gradient_checkpointing = False
820
+ self.post_init()
821
+
822
+ def get_input_embeddings(self):
823
+ return self.embed_tokens
824
+
825
+ def set_input_embeddings(self, value):
826
+ self.embed_tokens = value
827
+
828
+ def forward_embedding(
829
+ self,
830
+ input_ids,
831
+ forward_impl,
832
+ inputs_embeds=None,
833
+ past_key_values=None,
834
+ ):
835
+ # if past_key_values is not None:
836
+ if forward_impl == "recurrent":
837
+ input_ids = input_ids[:, -1:]
838
+
839
+ if inputs_embeds is None:
840
+ inputs_embeds = self.embed_tokens(input_ids)
841
+
842
+ embed = self.embed_scale * inputs_embeds
843
+
844
+ if self.layernorm_embedding is not None:
845
+ embed = self.layernorm_embedding(embed)
846
+
847
+ embed = self.dropout_module(embed)
848
+
849
+ return embed
850
+
851
+ def forward(
852
+ self,
853
+ input_ids: torch.LongTensor = None,
854
+ retention_mask: Optional[torch.Tensor] = None,
855
+ attention_mask: Optional[torch.Tensor] = None,
856
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None,
857
+ inputs_embeds: Optional[torch.FloatTensor] = None,
858
+ output_retentions: Optional[bool] = None,
859
+ output_attentions: Optional[bool] = None,
860
+ output_hidden_states: Optional[bool] = None,
861
+ use_cache: Optional[bool] = None,
862
+ return_dict: Optional[bool] = None,
863
+ forward_impl: Optional[str] = "parallel",
864
+ recurrent_chunk_size: Optional[int] = None,
865
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
866
+ ) -> Union[Tuple, RetNetOutputWithPast]:
867
+ if output_retentions is None and output_attentions is not None:
868
+ output_retentions = output_attentions
869
+ output_retentions = (
870
+ output_retentions
871
+ if output_retentions is not None
872
+ else self.config.output_retentions
873
+ )
874
+ output_hidden_states = (
875
+ output_hidden_states
876
+ if output_hidden_states is not None
877
+ else self.config.output_hidden_states
878
+ )
879
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
880
+
881
+ return_dict = (
882
+ return_dict if return_dict is not None else self.config.use_return_dict
883
+ )
884
+
885
+ # retrieve input_ids and inputs_embeds
886
+ if input_ids is not None and inputs_embeds is not None:
887
+ raise ValueError(
888
+ "You cannot specify both input_ids and inputs_embeds at the same time"
889
+ )
890
+ elif input_ids is not None:
891
+ batch_size, seq_length = input_ids.shape
892
+ elif inputs_embeds is not None:
893
+ batch_size, seq_length, _ = inputs_embeds.shape
894
+ else:
895
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
896
+
897
+ # embed tokens
898
+ if inputs_embeds is None:
899
+ inputs_embeds = self.forward_embedding(
900
+ input_ids, forward_impl, inputs_embeds, past_key_values
901
+ )
902
+
903
+ if retention_mask is None and attention_mask is not None:
904
+ retention_mask = attention_mask
905
+ if retention_mask is not None and forward_impl == "recurrent":
906
+ retention_mask = retention_mask[:, -1:]
907
+
908
+ hidden_states = inputs_embeds
909
+
910
+ # handling chunking here
911
+ if recurrent_chunk_size is None:
912
+ recurrent_chunk_size = self.recurrent_chunk_size
913
+ need_pad_for_chunkwise = (
914
+ forward_impl == "chunkwise" and seq_length % recurrent_chunk_size != 0
915
+ )
916
+ if need_pad_for_chunkwise:
917
+ padding_len = recurrent_chunk_size - seq_length % recurrent_chunk_size
918
+ slen = seq_length + padding_len
919
+ hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len))
920
+ else:
921
+ slen = seq_length
922
+ # relative position
923
+ if retention_rel_pos is None:
924
+ retention_rel_pos = self.retnet_rel_pos(
925
+ slen,
926
+ forward_impl=forward_impl,
927
+ recurrent_chunk_size=recurrent_chunk_size,
928
+ retention_mask=retention_mask,
929
+ get_decay_scale=not self.training,
930
+ )
931
+
932
+ # start running through the decoder layers
933
+ all_hidden_states = () if output_hidden_states else None
934
+ all_retentions = () if output_retentions else None
935
+ # layers * [bsz, num_head, qk_dim, decoder_embed_dim]
936
+ next_decoder_cache = () if use_cache else None
937
+
938
+ for idx, layer in enumerate(self.layers):
939
+ if output_hidden_states:
940
+ all_hidden_states += (hidden_states,)
941
+ past_key_value = (
942
+ past_key_values[idx] if past_key_values is not None else None
943
+ )
944
+
945
+ if self.gradient_checkpointing and self.training:
946
+
947
+ def create_custom_forward(module):
948
+ def custom_forward(*inputs):
949
+ return module(*inputs, output_retentions)
950
+
951
+ return custom_forward
952
+
953
+ layer_outputs = torch.utils.checkpoint.checkpoint(
954
+ create_custom_forward(layer),
955
+ hidden_states,
956
+ retention_rel_pos,
957
+ retention_mask,
958
+ forward_impl,
959
+ past_key_value,
960
+ )
961
+ else:
962
+ layer_outputs = layer(
963
+ hidden_states,
964
+ retention_rel_pos,
965
+ retention_mask=retention_mask,
966
+ forward_impl=forward_impl,
967
+ past_key_value=past_key_value,
968
+ output_retentions=output_retentions,
969
+ )
970
+
971
+ hidden_states = layer_outputs[0]
972
+
973
+ if use_cache:
974
+ next_decoder_cache += (layer_outputs[1],)
975
+
976
+ if output_retentions:
977
+ all_retentions += (layer_outputs[2],)
978
+
979
+ next_cache = next_decoder_cache if use_cache else None
980
+
981
+ if need_pad_for_chunkwise:
982
+ hidden_states = hidden_states[:, :seq_length, :]
983
+
984
+ if self.layer_norm is not None:
985
+ hidden_states = self.layer_norm(hidden_states)
986
+
987
+ # add hidden states from the last decoder layer
988
+ if output_hidden_states:
989
+ all_hidden_states += (hidden_states,)
990
+
991
+ if not return_dict:
992
+ return tuple(
993
+ v
994
+ for v in [hidden_states, next_cache, all_hidden_states, all_retentions]
995
+ if v is not None
996
+ )
997
+ return RetNetOutputWithPast(
998
+ last_hidden_state=hidden_states,
999
+ past_key_values=next_cache,
1000
+ hidden_states=all_hidden_states,
1001
+ retentions=all_retentions,
1002
+ attentions=all_retentions,
1003
+ )
1004
+
1005
+
1006
+ @dataclass
1007
+ class RetNetCausalLMOutputWithPast(ModelOutput):
1008
+ """
1009
+ class for RetNet causal language model (or autoregressive) outputs.
1010
+
1011
+ config:
1012
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1013
+ Language modeling loss (for next-token prediction).
1014
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1015
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1016
+ past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1017
+ - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim)
1018
+ - "scale": shape=((1 or bsz) * num_head * 1 * 1)
1019
+
1020
+ Contains pre-computed hidden-states (key and values in the multi-scale retention blocks)
1021
+ that can be used (see `past_key_values` input) to speed up sequential decoding.
1022
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1023
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1024
+ one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`.
1025
+
1026
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1027
+ retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`):
1028
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1029
+ sequence_length)`.
1030
+
1031
+ Retentions weights, used for visualization.
1032
+
1033
+ attentions (`tuple(torch.FloatTensor)`, *optional*): kept for backward compatibility; same as `retentions`.
1034
+ """
1035
+
1036
+ loss: Optional[torch.FloatTensor] = None
1037
+ logits: torch.FloatTensor = None
1038
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None
1039
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1040
+ retentions: Optional[Tuple[torch.FloatTensor]] = None
1041
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1042
+
1043
+
1044
+ class RetNetForCausalLM(RetNetPreTrainedModel):
1045
+ def __init__(
1046
+ self,
1047
+ config: RetNetConfig,
1048
+ embed_tokens: nn.Embedding = None,
1049
+ tensor_parallel: bool = False,
1050
+ ) -> None:
1051
+ super().__init__(config)
1052
+ self.model = RetNetModel(
1053
+ config, embed_tokens=embed_tokens, tensor_parallel=tensor_parallel
1054
+ )
1055
+ self.lm_head = nn.Linear(
1056
+ config.decoder_embed_dim, config.vocab_size, bias=False
1057
+ )
1058
+ # init here
1059
+ torch.nn.init.normal_(
1060
+ self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5
1061
+ )
1062
+
1063
+ self.post_init()
1064
+
1065
+ def get_input_embeddings(self):
1066
+ return self.model.embed_tokens
1067
+
1068
+ def set_input_embeddings(self, value):
1069
+ self.model.embed_tokens = value
1070
+
1071
+ def get_output_embeddings(self):
1072
+ return self.lm_head
1073
+
1074
+ def set_output_embeddings(self, new_embeddings):
1075
+ self.lm_head = new_embeddings
1076
+
1077
+ def set_decoder(self, decoder):
1078
+ self.model = decoder
1079
+
1080
+ def get_decoder(self):
1081
+ return self.model
1082
+
1083
+ def forward(
1084
+ self,
1085
+ input_ids: torch.LongTensor = None,
1086
+ retention_mask: Optional[torch.Tensor] = None,
1087
+ attention_mask: Optional[torch.Tensor] = None,
1088
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1089
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1090
+ labels: Optional[torch.LongTensor] = None,
1091
+ use_cache: Optional[bool] = None,
1092
+ output_retentions: Optional[bool] = None,
1093
+ output_attentions: Optional[bool] = None,
1094
+ output_hidden_states: Optional[bool] = None,
1095
+ return_dict: Optional[bool] = None,
1096
+ forward_impl: Optional[str] = None,
1097
+ recurrent_chunk_size: Optional[int] = None,
1098
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
1099
+ ) -> Union[Tuple, RetNetCausalLMOutputWithPast]:
1100
+ if output_retentions is None and output_attentions is not None:
1101
+ output_retentions = output_attentions
1102
+ output_retentions = (
1103
+ output_retentions
1104
+ if output_retentions is not None
1105
+ else self.config.output_retentions
1106
+ )
1107
+ output_hidden_states = (
1108
+ output_hidden_states
1109
+ if output_hidden_states is not None
1110
+ else self.config.output_hidden_states
1111
+ )
1112
+ return_dict = (
1113
+ return_dict if return_dict is not None else self.config.use_return_dict
1114
+ )
1115
+ forward_impl = (
1116
+ forward_impl if forward_impl is not None else self.config.forward_impl
1117
+ )
1118
+ recurrent_chunk_size = (
1119
+ recurrent_chunk_size
1120
+ if recurrent_chunk_size is not None
1121
+ else self.config.recurrent_chunk_size
1122
+ )
1123
+
1124
+ if retention_mask is None and attention_mask is not None:
1125
+ retention_mask = attention_mask
1126
+
1127
+ outputs = self.model(
1128
+ input_ids,
1129
+ retention_mask=retention_mask,
1130
+ past_key_values=past_key_values,
1131
+ inputs_embeds=inputs_embeds,
1132
+ output_retentions=output_retentions,
1133
+ output_hidden_states=output_hidden_states,
1134
+ return_dict=return_dict,
1135
+ forward_impl=forward_impl,
1136
+ use_cache=use_cache,
1137
+ recurrent_chunk_size=recurrent_chunk_size,
1138
+ retention_rel_pos=retention_rel_pos,
1139
+ )
1140
+
1141
+ hidden_states = outputs[0]
1142
+ logits = self.lm_head(hidden_states)
1143
+
1144
+ loss = None
1145
+ if labels is not None:
1146
+ # Shift so that tokens < n predict n
1147
+ shift_logits = logits[..., :-1, :].contiguous()
1148
+ shift_labels = labels[..., 1:].contiguous()
1149
+ # Flatten the tokens
1150
+ loss_fct = nn.CrossEntropyLoss()
1151
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1152
+ shift_labels = shift_labels.view(-1)
1153
+ # Enable model parallelism
1154
+ shift_labels = shift_labels.to(shift_logits.device)
1155
+ loss = loss_fct(shift_logits, shift_labels)
1156
+
1157
+ if self.config.z_loss_coeff > 0:
1158
+ # z_loss from PaLM paper
1159
+ # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits))
1160
+ z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean()
1161
+ loss += self.config.z_loss_coeff * z_loss
1162
+
1163
+ if not return_dict:
1164
+ output = (logits,) + outputs[1:]
1165
+ return (loss,) + output if loss is not None else output
1166
+
1167
+ return RetNetCausalLMOutputWithPast(
1168
+ loss=loss,
1169
+ logits=logits,
1170
+ past_key_values=outputs.past_key_values,
1171
+ hidden_states=outputs.hidden_states,
1172
+ retentions=outputs.retentions,
1173
+ attentions=outputs.retentions,
1174
+ )
1175
+
1176
+ def _crop_past_key_values(model, past_key_values, maximum_length):
1177
+ """Since retnet's kv do not have length, no need to crop. Just return"""
1178
+ return past_key_values
1179
+
1180
+ def prepare_inputs_for_generation(
1181
+ self,
1182
+ input_ids,
1183
+ past_key_values=None,
1184
+ attention_mask=None,
1185
+ inputs_embeds=None,
1186
+ **kwargs,
1187
+ ):
1188
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1189
+ if inputs_embeds is not None and past_key_values is None:
1190
+ model_inputs = {"inputs_embeds": inputs_embeds}
1191
+ else:
1192
+ model_inputs = {"input_ids": input_ids}
1193
+
1194
+ forward_impl = kwargs.get("forward_impl", "parallel")
1195
+ if past_key_values is not None:
1196
+ forward_impl = "recurrent"
1197
+
1198
+ model_inputs.update(
1199
+ {
1200
+ "past_key_values": past_key_values,
1201
+ "use_cache": kwargs.get("use_cache"),
1202
+ "attention_mask": attention_mask,
1203
+ "forward_impl": forward_impl,
1204
+ }
1205
+ )
1206
+ return model_inputs
1207
+
1208
+ @staticmethod
1209
+ def _reorder_cache(past_key_values, beam_idx):
1210
+ reordered_past = ()
1211
+ for layer_past in past_key_values: # dict
1212
+ layer_past_kv = layer_past["prev_key_value"] # [b, h, v_dim / h, qk_dim]
1213
+ layer_past_scale = layer_past["scale"] # [b, h, 1, 1]
1214
+ if layer_past_scale.size(0) > 1:
1215
+ # this means that retention_mask is not None, so the scale for
1216
+ # each batch is different. We need to select the correct scale then.
1217
+ # NOTE: during huggingface generate, it will generate attention_mask
1218
+ # if it is None, so this line will always be true. Still, having
1219
+ # this line here for safety.
1220
+ layer_past_scale = layer_past_scale.index_select(0, beam_idx)
1221
+ reordered_past += (
1222
+ {
1223
+ "prev_key_value": layer_past_kv.index_select(0, beam_idx),
1224
+ "scale": layer_past_scale,
1225
+ },
1226
+ )
1227
+ return reordered_past
1228
+
1229
+ def sample_token(self, logit, do_sample=False, top_k=1, top_p=1.0, temperature=1.0):
1230
+ if not do_sample:
1231
+ return torch.argmax(logit, dim=-1, keepdim=True)
1232
+ filtered = top_k_top_p_filtering(logit / temperature, top_k=top_k, top_p=top_p)
1233
+ return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
1234
+
1235
+ @torch.inference_mode()
1236
+ def custom_generate(
1237
+ self,
1238
+ input_ids: torch.LongTensor = None,
1239
+ retention_mask: Optional[torch.Tensor] = None,
1240
+ attention_mask: Optional[torch.Tensor] = None,
1241
+ parallel_compute_prompt=True,
1242
+ max_new_tokens=20,
1243
+ bos_token_id=0,
1244
+ eos_token_id=0,
1245
+ do_sample=False,
1246
+ top_k=0,
1247
+ top_p=1.0,
1248
+ temperature=1.0,
1249
+ early_stopping=True,
1250
+ ):
1251
+ if retention_mask is None and attention_mask is not None:
1252
+ retention_mask = attention_mask
1253
+
1254
+ if input_ids is not None:
1255
+ if input_ids.shape[1] == 1:
1256
+ past_key_values = None
1257
+ elif parallel_compute_prompt:
1258
+ ret_mask = (
1259
+ retention_mask[:, :-1] if retention_mask is not None else None
1260
+ )
1261
+ outputs = self(
1262
+ input_ids[:, :-1],
1263
+ retention_mask=ret_mask,
1264
+ forward_impl="parallel",
1265
+ return_dict=True,
1266
+ use_cache=True,
1267
+ )
1268
+ past_key_values = outputs.past_key_values
1269
+ else:
1270
+ past_key_values = None
1271
+ for p_i in range(input_ids.shape[1] - 1):
1272
+ ret_mask = (
1273
+ retention_mask[:, : p_i + 1]
1274
+ if retention_mask is not None
1275
+ else None
1276
+ )
1277
+ outputs = self(
1278
+ input_ids[:, : p_i + 1],
1279
+ retention_mask=ret_mask,
1280
+ forward_impl="recurrent",
1281
+ past_key_values=past_key_values,
1282
+ return_dict=True,
1283
+ use_cache=True,
1284
+ )
1285
+ past_key_values = outputs.past_key_values
1286
+
1287
+ generated = input_ids
1288
+ else:
1289
+ generated = torch.tensor([[bos_token_id]]).to(self.lm_head.weight.device)
1290
+ past_key_values = None
1291
+
1292
+ for i in range(max_new_tokens):
1293
+ outputs = self(
1294
+ generated,
1295
+ retention_mask=retention_mask,
1296
+ forward_impl="recurrent",
1297
+ past_key_values=past_key_values,
1298
+ use_cache=True,
1299
+ return_dict=True,
1300
+ )
1301
+ logit = outputs.logits[:, -1, :] # [batch_size, vocab_size]
1302
+ past_key_values = outputs.past_key_values
1303
+ token = self.sample_token(
1304
+ logit,
1305
+ do_sample=do_sample,
1306
+ top_k=top_k,
1307
+ top_p=top_p,
1308
+ temperature=temperature,
1309
+ )
1310
+ generated = torch.cat([generated, token], dim=-1)
1311
+ if retention_mask is not None:
1312
+ retention_mask = torch.cat(
1313
+ [retention_mask, torch.ones_like(token)], dim=-1
1314
+ )
1315
+ if early_stopping and (token == eos_token_id).all():
1316
+ break
1317
+ return generated
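
A minimal end-to-end sketch (not part of this commit) of how the two uploaded files might be wired together. It assumes the optional dependencies imported above are available (timm, and a transformers version that still provides top_k_top_p_filtering); the tiny dimensions are arbitrary placeholders chosen so the retention heads divide the embedding sizes evenly.

import torch

from configuration_retnet import RetNetConfig
from modeling_retnet import RetNetForCausalLM

config = RetNetConfig(
    vocab_size=1000,
    decoder_layers=2,
    decoder_embed_dim=128,
    decoder_value_embed_dim=256,
    decoder_ffn_embed_dim=256,
    decoder_retention_heads=2,
)
model = RetNetForCausalLM(config).eval()  # eval mode so parallel_retention returns the kv cache

prompt = torch.randint(0, config.vocab_size, (1, 16))
# custom_generate runs a parallel forward over the prompt, then recurrent decoding per new token
generated = model.custom_generate(prompt, max_new_tokens=8, do_sample=False)
print(generated.shape)  # up to (1, 16 + 8); may stop earlier at eos_token_id
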