p1atdev commited on
Commit
f4168f8
·
1 Parent(s): 528a720

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_retnet.py +122 -0
  2. modeling_retnet.py +1491 -0
configuration_retnet.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/syncdoth/RetNet/blob/main/retnet/configuration_retnet.py
2
+
3
+ from dataclasses import dataclass
4
+ import json
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+
8
+
9
+ def load_config_from_json(config_file):
10
+ with open(config_file, "r") as f:
11
+ config = json.load(f)
12
+ config = RetNetConfig.from_dict(config)
13
+ return config
14
+
15
+
16
+ @dataclass
17
+ class RetNetConfig(PretrainedConfig):
18
+ model_type = "retnet"
19
+ initializer_range: float = 0.02
20
+ activation_fn: str = "gelu"
21
+ dropout: float = 0.0 # dropout probability
22
+ activation_dropout: float = 0.0 # dropout probability after activation in FFN.
23
+ drop_path_rate: float = 0.0
24
+ decoder_embed_dim: int = 768 # decoder embedding dimension
25
+ decoder_value_embed_dim: int = 1280 # decoder value embedding dimension
26
+ decoder_ffn_embed_dim: int = 1280 # decoder embedding dimension for FFN
27
+ decoder_layers: int = 12 # num decoder layers
28
+ decoder_retention_heads: int = 3 # num decoder retention heads
29
+ decoder_normalize_before: bool = True # apply layernorm before each decoder block
30
+ layernorm_embedding: bool = False # add layernorm to embedding
31
+ no_scale_embedding: bool = True # if True, dont scale embeddings
32
+ recurrent_chunk_size: int = 512
33
+ use_lm_decay: bool = False
34
+ use_glu: bool = True # use GLU instead of FFN
35
+ z_loss_coeff: float = 0.0 # coefficient for z loss: TODO: 1e-4
36
+ deepnorm: bool = False
37
+ subln: bool = True
38
+ use_ffn_rms_norm: bool = False
39
+ layernorm_eps: float = 1e-6
40
+ tie_word_embeddings: bool = False
41
+
42
+ def __init__(
43
+ self,
44
+ vocab_size: int = 50257,
45
+ initializer_range: float = 0.02,
46
+ is_decoder: bool = True,
47
+ pad_token_id: int = 0,
48
+ eos_token_id: int = 0,
49
+ output_retentions: bool = False,
50
+ use_cache: bool = True,
51
+ forward_impl: str = "parallel",
52
+ activation_fn: str = "gelu",
53
+ dropout: float = 0.0, # dropout probability
54
+ activation_dropout: float = 0.0, # dropout probability after activation in FFN.
55
+ drop_path_rate: float = 0.0,
56
+ decoder_embed_dim: int = 768, # decoder embedding dimension
57
+ decoder_value_embed_dim: int = 1280, # decoder value embedding dimension
58
+ decoder_ffn_embed_dim: int = 1280, # decoder embedding dimension for FFN
59
+ decoder_layers: int = 12, # num decoder layers
60
+ decoder_retention_heads: int = 3, # num decoder retention heads
61
+ decoder_normalize_before: bool = True, # apply layernorm before each decoder block
62
+ layernorm_embedding: bool = False, # add layernorm to embedding
63
+ no_scale_embedding: bool = True, # if True, dont scale embeddings
64
+ recurrent_chunk_size: int = 512,
65
+ use_glu: bool = True, # use GLU instead of FFN
66
+ z_loss_coeff: float = 0.0, # coefficient for z loss: TODO: 1e-4
67
+ use_lm_decay: bool = False,
68
+ deepnorm: bool = False,
69
+ subln: bool = True,
70
+ use_ffn_rms_norm: bool = False, # use RMSNorm instead of LayerNorm in FFN
71
+ layernorm_eps: float = 1e-6,
72
+ tie_word_embeddings: bool = False,
73
+ **kwargs
74
+ ):
75
+ self.vocab_size = vocab_size
76
+ self.initializer_range = initializer_range
77
+ self.output_retentions = output_retentions
78
+ self.use_lm_decay = use_lm_decay
79
+ self.use_glu = use_glu
80
+ self.z_loss_coeff = z_loss_coeff
81
+ # size related
82
+ self.decoder_embed_dim = decoder_embed_dim
83
+ self.decoder_value_embed_dim = decoder_value_embed_dim
84
+ self.decoder_retention_heads = decoder_retention_heads
85
+ self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
86
+ self.decoder_layers = decoder_layers
87
+ # normalization related
88
+ self.decoder_normalize_before = decoder_normalize_before
89
+ self.activation_fn = activation_fn
90
+ self.dropout = dropout
91
+ self.drop_path_rate = drop_path_rate
92
+ self.activation_dropout = activation_dropout
93
+ self.no_scale_embedding = no_scale_embedding
94
+ self.layernorm_embedding = layernorm_embedding
95
+ self.deepnorm = deepnorm
96
+ self.subln = subln
97
+ self.use_ffn_rms_norm = use_ffn_rms_norm
98
+ self.layernorm_eps = layernorm_eps
99
+ # Blockwise
100
+ self.recurrent_chunk_size = recurrent_chunk_size
101
+ self.forward_impl = forward_impl
102
+
103
+ if self.deepnorm:
104
+ self.decoder_normalize_before = False
105
+ self.subln = False
106
+ if self.subln:
107
+ self.decoder_normalize_before = True
108
+ self.deepnorm = False
109
+
110
+ super().__init__(
111
+ is_decoder=is_decoder,
112
+ pad_token_id=pad_token_id,
113
+ eos_token_id=eos_token_id,
114
+ use_cache=use_cache,
115
+ tie_word_embeddings=tie_word_embeddings,
116
+ **kwargs
117
+ )
118
+
119
+ def override(self, args):
120
+ for hp in self.__dict__.keys():
121
+ if getattr(args, hp, None) is not None:
122
+ self.__dict__[hp] = getattr(args, hp, None)
modeling_retnet.py ADDED
@@ -0,0 +1,1491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/syncdoth/RetNet/blob/main/retnet/modeling_retnet.py
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+
12
+ from torch import nn
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+ from transformers import top_k_top_p_filtering
15
+ from transformers.activations import ACT2FN
16
+ from transformers.modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+
20
+ from .configuration_retnet import RetNetConfig
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ # helper functions
26
+ def split_heads(tensors, bsz, seqlen, num_heads):
27
+ assert isinstance(tensors, (tuple, list))
28
+ return [x.view(bsz, seqlen, num_heads, -1).transpose(1, 2) for x in tensors]
29
+
30
+
31
+ def rotate_every_two(x):
32
+ x1 = x[:, :, :, ::2]
33
+ x2 = x[:, :, :, 1::2]
34
+ x = torch.stack((-x2, x1), dim=-1)
35
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\
36
+
37
+
38
+ def theta_shift(x, sin, cos):
39
+ return (x * cos) + (rotate_every_two(x) * sin)
40
+
41
+
42
+ def get_activation_fn(activation):
43
+ return ACT2FN[activation]
44
+
45
+
46
+ # Copied from https://github.com/huggingface/pytorch-image-models/blob/bbe798317fb26f063c18279827c038058e376479/timm/layers/drop.py#L137C1-L154C29
47
+ def drop_path(
48
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
49
+ ):
50
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
51
+
52
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
53
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
54
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
55
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
56
+ 'survival rate' as the argument.
57
+
58
+ """
59
+ if drop_prob == 0.0 or not training:
60
+ return x
61
+ keep_prob = 1 - drop_prob
62
+ shape = (x.shape[0],) + (1,) * (
63
+ x.ndim - 1
64
+ ) # work with diff dim tensors, not just 2D ConvNets
65
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
66
+ if keep_prob > 0.0 and scale_by_keep:
67
+ random_tensor.div_(keep_prob)
68
+ return x * random_tensor
69
+
70
+
71
+ class RMSNorm(nn.Module):
72
+ def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True):
73
+ super().__init__()
74
+ self.normalized_shape = dim
75
+ self.eps = eps
76
+ self.elementwise_affine = elementwise_affine
77
+ if self.elementwise_affine:
78
+ self.weight = nn.Parameter(torch.ones(dim))
79
+ else:
80
+ self.register_parameter("weight", None)
81
+
82
+ def _norm(self, x):
83
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
84
+
85
+ def forward(self, x):
86
+ output = self._norm(x.float()).type_as(x)
87
+ if self.weight is not None:
88
+ output = output * self.weight
89
+ return output
90
+
91
+
92
+ try:
93
+ from apex.normalization import FusedRMSNorm
94
+
95
+ RMSNorm = FusedRMSNorm # noqa
96
+
97
+ logger.info(
98
+ "Discovered apex.normalization.FusedRMSNorm - will use it instead of RMSNorm"
99
+ )
100
+ except ImportError:
101
+ # using the normal RMSNorm
102
+ pass
103
+ except Exception:
104
+ logger.warning("discovered apex but it failed to load, falling back to RMSNorm")
105
+ pass
106
+
107
+
108
+ class RetNetRelPos(nn.Module):
109
+ def __init__(self, config: RetNetConfig):
110
+ super().__init__()
111
+ self.config = config
112
+ num_heads = config.decoder_retention_heads
113
+
114
+ angle = 1.0 / (
115
+ 10000 ** torch.linspace(0, 1, config.decoder_embed_dim // num_heads // 2)
116
+ )
117
+ angle = angle.unsqueeze(-1).repeat(1, 2).flatten()
118
+ # decay (gamma)
119
+ if config.use_lm_decay:
120
+ # NOTE: alternative way described in the paper
121
+ s = torch.log(torch.tensor(1 / 32))
122
+ e = torch.log(torch.tensor(1 / 512))
123
+ decay = torch.log(1 - torch.exp(torch.linspace(s, e, num_heads))) # [h,]
124
+ else:
125
+ decay = torch.log(
126
+ 1 - 2 ** (-5 - torch.arange(num_heads, dtype=torch.float))
127
+ )
128
+ self.register_buffer("angle", angle)
129
+ self.register_buffer("decay", decay)
130
+ self.recurrent_chunk_size = config.recurrent_chunk_size
131
+
132
+ def forward(
133
+ self,
134
+ slen,
135
+ forward_impl="parallel",
136
+ recurrent_chunk_size=None,
137
+ retention_mask=None,
138
+ get_decay_scale=True,
139
+ ):
140
+ if forward_impl == "recurrent":
141
+ sin = torch.sin(self.angle * (slen - 1))
142
+ cos = torch.cos(self.angle * (slen - 1))
143
+ retention_rel_pos = ((sin, cos), self.decay.view(1, -1, 1, 1).exp())
144
+ elif forward_impl == "chunkwise":
145
+ if recurrent_chunk_size is None:
146
+ recurrent_chunk_size = self.recurrent_chunk_size
147
+ index = torch.arange(slen).to(self.decay)
148
+ sin = torch.sin(index[:, None] * self.angle[None, :])
149
+ cos = torch.cos(index[:, None] * self.angle[None, :])
150
+
151
+ block_index = torch.arange(recurrent_chunk_size).to(self.decay)
152
+ mask = torch.tril(
153
+ torch.ones(recurrent_chunk_size, recurrent_chunk_size)
154
+ ).to(self.decay)
155
+ mask = torch.masked_fill(
156
+ block_index[:, None] - block_index[None, :], ~mask.bool(), float("inf")
157
+ )
158
+ mask = torch.exp(mask * self.decay[:, None, None])
159
+ mask = torch.nan_to_num(mask)
160
+ mask = mask.unsqueeze(0) # [1, h, t, t]
161
+ # TODO: need to handle retention_mask
162
+ # scaling
163
+ value_inner_decay = mask[:, :, -1] / mask[:, :, -1].sum(
164
+ dim=-1, keepdim=True
165
+ )
166
+ value_inner_decay = value_inner_decay.unsqueeze(-1)
167
+ scale = mask.sum(dim=-1, keepdim=True).sqrt()
168
+ inner_mask = mask / scale
169
+
170
+ cross_decay = torch.exp(self.decay * recurrent_chunk_size)
171
+ query_inner_decay = torch.exp(self.decay[:, None] * (block_index + 1))
172
+ cross_decay = cross_decay[None, :, None, None]
173
+ query_inner_decay = query_inner_decay[None, :, :, None] / (
174
+ scale / mask[:, :, -1].sum(dim=-1)[:, :, None, None]
175
+ )
176
+ # decay_scale (used for kv cache)
177
+ if get_decay_scale:
178
+ decay_scale = self.compute_decay_scale(slen, retention_mask)
179
+ else:
180
+ decay_scale = None
181
+ retention_rel_pos = (
182
+ (sin, cos),
183
+ (
184
+ inner_mask,
185
+ cross_decay,
186
+ query_inner_decay,
187
+ value_inner_decay,
188
+ decay_scale,
189
+ ),
190
+ )
191
+ else: # parallel
192
+ index = torch.arange(slen).to(self.decay)
193
+ sin = torch.sin(index[:, None] * self.angle[None, :])
194
+ cos = torch.cos(index[:, None] * self.angle[None, :])
195
+ mask = torch.tril(torch.ones(slen, slen)).to(self.decay)
196
+ mask = torch.masked_fill(
197
+ index[:, None] - index[None, :], ~mask.bool(), float("inf")
198
+ )
199
+ mask = torch.exp(mask * self.decay[:, None, None])
200
+ mask = torch.nan_to_num(mask)
201
+ mask = mask.unsqueeze(0) # [1, h, t, t]
202
+ if retention_mask is not None:
203
+ # this is required for left padding
204
+ mask = mask * retention_mask.float().view(-1, 1, 1, slen).to(mask)
205
+
206
+ # scaling
207
+ mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
208
+ mask = torch.nan_to_num(mask, nan=0.0)
209
+ # decay_scale (used for kv cache)
210
+ if get_decay_scale:
211
+ decay_scale = self.compute_decay_scale(slen, retention_mask)
212
+ else:
213
+ decay_scale = None
214
+ # mask processing for intra decay
215
+ if retention_mask is not None:
216
+ max_non_zero = (
217
+ torch.cumsum(retention_mask, dim=-1).max(dim=-1).indices
218
+ ) # [b,]
219
+ intra_decay = mask[range(mask.shape[0]), :, max_non_zero]
220
+ else:
221
+ intra_decay = mask[:, :, -1]
222
+
223
+ retention_rel_pos = ((sin, cos), (mask, intra_decay, decay_scale))
224
+
225
+ return retention_rel_pos
226
+
227
+ def compute_decay_scale(self, slen, retention_mask=None):
228
+ exponent = torch.arange(slen, device=self.decay.device).float()
229
+ decay_scale = self.decay.exp().view(-1, 1) ** exponent.view(1, -1) # [h, t]
230
+ if retention_mask is not None:
231
+ seqlen = retention_mask.sum(dim=-1) # [b,]
232
+ bsz = seqlen.size(0)
233
+ decay_scale = decay_scale.unsqueeze(0).repeat(bsz, 1, 1) # [b, h, t]
234
+ for i, pos in enumerate(seqlen):
235
+ # the formula for decay_scale is `sum(gamma^i) for i in [0, slen).`
236
+ # Since the retention_mask is 0 for padding, we can set the decay_scale
237
+ # to 0 for the padding positions.
238
+ decay_scale[i, :, pos.item() :] = 0
239
+ else:
240
+ bsz = 1
241
+ decay_scale = decay_scale.sum(-1).view(bsz, -1, 1, 1) # [b, h, 1, 1]
242
+ return decay_scale
243
+
244
+
245
+ class MultiScaleRetention(nn.Module):
246
+ def __init__(
247
+ self,
248
+ config: RetNetConfig,
249
+ gate_fn="swish",
250
+ use_bias=False,
251
+ tensor_parallel=False,
252
+ ):
253
+ super().__init__()
254
+ self.config = config
255
+ self.embed_dim = config.decoder_embed_dim
256
+ self.value_dim = config.decoder_value_embed_dim
257
+ self.num_heads = config.decoder_retention_heads
258
+ self.head_dim = self.value_dim // self.num_heads
259
+ self.key_dim = self.embed_dim // self.num_heads
260
+ self.scaling = self.key_dim**-0.5
261
+
262
+ self.gate_fn = get_activation_fn(activation=str(gate_fn))
263
+
264
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
265
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=use_bias)
266
+ self.v_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
267
+ self.g_proj = nn.Linear(self.embed_dim, self.value_dim, bias=use_bias)
268
+
269
+ self.out_proj = nn.Linear(self.value_dim, self.embed_dim, bias=use_bias)
270
+
271
+ self.group_norm = RMSNorm(
272
+ self.head_dim, eps=config.layernorm_eps, elementwise_affine=False
273
+ )
274
+ self.reset_parameters()
275
+
276
+ if tensor_parallel:
277
+ self.decay_proj = nn.Linear(self.num_heads, self.num_heads, bias=False)
278
+ else:
279
+ self.decay_proj = None
280
+
281
+ def reset_parameters(self):
282
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=2**-2.5)
283
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=2**-2.5)
284
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=2**-2.5)
285
+ nn.init.xavier_uniform_(self.g_proj.weight, gain=2**-2.5)
286
+ nn.init.xavier_uniform_(self.out_proj.weight)
287
+
288
+ def parallel_retention(self, q, k, v, decay_mask):
289
+ """
290
+ q, # bsz * num_head * len * qk_dim
291
+ k, # bsz * num_head * len * qk_dim
292
+ v, # bsz * num_head * len * v_dim
293
+ decay_mask, # (1 or bsz) * num_head * len * len
294
+ """
295
+ decay_mask, intra_decay, scale = decay_mask
296
+ # just return retention_rel_pos projected
297
+ # TODO: for shardformer
298
+ if self.decay_proj is not None:
299
+ decay_mask = self.decay_proj(decay_mask.transpose(-1, -3)).transpose(-3, -1)
300
+
301
+ # [b, h, t, t]
302
+ retention = q @ k.transpose(-1, -2) # (scaled dot-product)
303
+ retention = retention * decay_mask
304
+
305
+ # invariant after normalization
306
+ retention = retention / retention.detach().sum(
307
+ dim=-1, keepdim=True
308
+ ).abs().clamp(min=1)
309
+
310
+ output = retention.type_as(v) @ v # [b, h, t, v_dim / h]
311
+ output = output.transpose(1, 2) # [b, t, h, v_dim / h]
312
+
313
+ if self.training: # skip cache
314
+ return output, None, retention
315
+
316
+ if self.decay_proj is not None:
317
+ intra_decay = self.decay_proj(intra_decay.transpose(-1, -2)).transpose(
318
+ -2, -1
319
+ )
320
+
321
+ # kv cache: [b, h, t, v_dim, qk_dim]
322
+ current_kv = k.unsqueeze(-2) * v.unsqueeze(-1)
323
+ intra_decay = intra_decay[:, :, :, None, None] # [b, h, t, 1, 1]
324
+ current_kv = (current_kv * intra_decay).sum(2) # [b, h, v_dim, qk_dim]
325
+
326
+ cache = {"prev_key_value": current_kv, "scale": scale}
327
+ return output, cache, retention
328
+
329
+ def recurrent_retention(
330
+ self, q, k, v, decay, past_key_value=None, retention_mask=None
331
+ ):
332
+ """
333
+ q, k, v, # bsz * num_head * 1 * qkv_dim
334
+ past_key_value:
335
+ - "prev_key_value" # bsz * num_head * v_dim * qk_dim
336
+ - "scale" # (1 or bsz) * num_head * 1 * 1
337
+ decay # (1 or bsz) * num_head * 1 * 1
338
+ retention_mask # bsz * 1
339
+ """
340
+ if retention_mask is not None:
341
+ retention_mask = retention_mask.float().view(-1, 1, 1, 1).to(decay)
342
+ else:
343
+ retention_mask = torch.ones(k.size(0), 1, 1, 1).to(decay)
344
+ # (b, h, v_dim, qk_dim)
345
+ current_kv = k * v.transpose(-1, -2) * retention_mask
346
+
347
+ if past_key_value is not None and "prev_key_value" in past_key_value:
348
+ prev_kv = past_key_value["prev_key_value"]
349
+ prev_scale = past_key_value["scale"]
350
+ scale = torch.where(retention_mask == 0, prev_scale, prev_scale * decay + 1)
351
+ # connect prev_kv and current_kv
352
+ # how much to decay prev_kv
353
+ decay_amount = prev_scale.sqrt() * decay / scale.sqrt()
354
+ decay_amount = torch.where(retention_mask == 0, 1, decay_amount)
355
+ prev_kv = prev_kv * decay_amount # decay prev_kv
356
+ current_kv = current_kv / scale.sqrt() # scale current_kv
357
+ current_kv = torch.nan_to_num(
358
+ current_kv, nan=0.0
359
+ ) # remove nan, scale might be 0
360
+
361
+ current_kv = prev_kv + current_kv
362
+ else:
363
+ scale = torch.ones_like(decay)
364
+ # when retention_mask is 0 at the beginning, setting scale to 1 will
365
+ # make the first retention to use the padding incorrectly. Hence,
366
+ # setting it to 0 here. This is a little ugly, so we might want to
367
+ # change this later. TODO: improve
368
+ scale = torch.where(retention_mask == 0, torch.zeros_like(decay), scale)
369
+
370
+ output = torch.sum(q * current_kv, dim=3).unsqueeze(1) # (b, 1, h, d_v)
371
+
372
+ cache = {"prev_key_value": current_kv, "scale": scale}
373
+ return output, cache
374
+
375
+ def chunkwise_retention(self, q, k, v, decay_mask):
376
+ """
377
+ q, k, v, # bsz * num_head * seqlen * qkv_dim
378
+ past_key_value:
379
+ - "prev_key_value" # bsz * num_head * v_dim * qk_dim
380
+ - "scale" # (1 or bsz) * num_head * 1 * 1
381
+ decay_mask, # 1 * num_head * chunk_size * chunk_size
382
+ cross_decay, # 1 * num_head * 1 * 1
383
+ inner_decay, # 1 * num_head * chunk_size * 1
384
+ """
385
+ # TODO: not working properly
386
+ (
387
+ decay_mask,
388
+ cross_decay,
389
+ query_inner_decay,
390
+ value_inner_decay,
391
+ decay_scale,
392
+ ) = decay_mask
393
+ bsz, _, tgt_len, _ = v.size()
394
+ chunk_len = decay_mask.size(-1)
395
+ assert tgt_len % chunk_len == 0
396
+ num_chunks = tgt_len // chunk_len
397
+
398
+ # [b, n_c, h, t_c, qkv_dim]
399
+ q = q.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(
400
+ 1, 2
401
+ )
402
+ k = k.view(bsz, self.num_heads, num_chunks, chunk_len, self.key_dim).transpose(
403
+ 1, 2
404
+ )
405
+ v = v.view(bsz, self.num_heads, num_chunks, chunk_len, self.head_dim).transpose(
406
+ 1, 2
407
+ )
408
+
409
+ k_t = k.transpose(-1, -2)
410
+
411
+ qk_mat = q @ k_t # [b, n_c, h, t_c, t_c]
412
+ qk_mat = qk_mat * decay_mask.unsqueeze(1)
413
+ inner_scale = qk_mat.detach().abs().sum(dim=-1, keepdim=True).clamp(min=1)
414
+ qk_mat = qk_mat / inner_scale
415
+ # [b, n_c, h, t_c, v_dim]
416
+ inner_output = torch.matmul(qk_mat, v)
417
+
418
+ # reduce kv in one chunk
419
+ # [b, n_c, h, qk_dim, v_dim]
420
+ kv = k_t @ (v * value_inner_decay)
421
+ # kv = kv.view(bsz, num_chunks, self.num_heads, self.key_dim, self.head_dim)
422
+
423
+ kv_recurrent = []
424
+ cross_scale = []
425
+ kv_state = torch.zeros(bsz, self.num_heads, self.key_dim, self.head_dim).to(v)
426
+ kv_scale = torch.ones(bsz, self.num_heads, 1, 1).to(v)
427
+
428
+ # accumulate kv by loop
429
+ for i in range(num_chunks):
430
+ kv_recurrent.append(kv_state / kv_scale)
431
+ cross_scale.append(kv_scale)
432
+ kv_state = kv_state * cross_decay + kv[:, i]
433
+ kv_scale = (
434
+ kv_state.detach()
435
+ .abs()
436
+ .sum(dim=-2, keepdim=True)
437
+ .max(dim=-1, keepdim=True)
438
+ .values.clamp(min=1)
439
+ )
440
+
441
+ kv_recurrent = torch.stack(kv_recurrent, dim=1)
442
+ cross_scale = torch.stack(cross_scale, dim=1)
443
+
444
+ all_scale = torch.maximum(inner_scale, cross_scale)
445
+ align_inner_scale = all_scale / inner_scale
446
+ align_cross_scale = all_scale / cross_scale
447
+
448
+ cross_output = (q * query_inner_decay.unsqueeze(1)) @ kv_recurrent
449
+ output = inner_output / align_inner_scale + cross_output / align_cross_scale
450
+ output = output.transpose(2, 3) # [b, n_c, t_c, h, v_dim]
451
+
452
+ cache = {"prev_key_value": kv_state.transpose(-2, -1), "scale": decay_scale}
453
+ return output, cache
454
+
455
+ def forward(
456
+ self,
457
+ hidden_states: torch.Tensor,
458
+ rel_pos: Tuple[Tuple[torch.Tensor]],
459
+ retention_mask: Optional[torch.Tensor] = None,
460
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
461
+ forward_impl: str = "parallel",
462
+ output_retentions: Optional[bool] = False,
463
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
464
+ B, T, H = hidden_states.size()
465
+ (sin, cos), decay_mask = rel_pos
466
+ # projections
467
+ q = self.q_proj(hidden_states)
468
+ k = self.k_proj(hidden_states)
469
+ v = self.v_proj(hidden_states)
470
+ g = self.g_proj(hidden_states)
471
+ # multi-head
472
+ q, k, v = split_heads((q, k, v), B, T, self.num_heads)
473
+ k *= self.scaling # for scaled dot product
474
+ # rotate
475
+ # NOTE: theta_shift has bug with mps device.
476
+ qr = theta_shift(q, sin, cos)
477
+ kr = theta_shift(k, sin, cos)
478
+
479
+ # retention
480
+ if forward_impl == "parallel":
481
+ retention_out, curr_kv, retention_weights = self.parallel_retention(
482
+ qr, kr, v, decay_mask
483
+ )
484
+ elif forward_impl == "recurrent":
485
+ retention_out, curr_kv = self.recurrent_retention(
486
+ qr,
487
+ kr,
488
+ v,
489
+ decay_mask,
490
+ past_key_value=past_key_value,
491
+ retention_mask=retention_mask,
492
+ )
493
+ elif forward_impl == "chunkwise":
494
+ retention_out, curr_kv = self.chunkwise_retention(qr, kr, v, decay_mask)
495
+ else:
496
+ raise ValueError(f"forward_impl {forward_impl} not supported.")
497
+
498
+ # concaat heads
499
+ normed = self.group_norm(retention_out).reshape(B, T, self.value_dim)
500
+ # out gate & proj
501
+ out = self.gate_fn(g) * normed
502
+ out = self.out_proj(out.type_as(hidden_states))
503
+
504
+ outputs = (out, curr_kv)
505
+ if output_retentions:
506
+ outputs += (retention_weights,) if forward_impl == "parallel" else (None,)
507
+ return outputs
508
+
509
+
510
+ class FeedForwardNetwork(nn.Module):
511
+ def __init__(
512
+ self,
513
+ embed_dim,
514
+ ffn_dim,
515
+ activation_fn,
516
+ dropout,
517
+ activation_dropout,
518
+ layernorm_eps,
519
+ subln=False,
520
+ use_rms_norm=False,
521
+ ):
522
+ super().__init__()
523
+ self.embed_dim = embed_dim
524
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
525
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
526
+ self.dropout_module = torch.nn.Dropout(dropout)
527
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
528
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
529
+ if subln:
530
+ if use_rms_norm:
531
+ self.ffn_layernorm = RMSNorm(ffn_dim, eps=layernorm_eps)
532
+ else:
533
+ self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps)
534
+ else:
535
+ self.ffn_layernorm = None
536
+
537
+ def reset_parameters(self):
538
+ self.fc1.reset_parameters()
539
+ self.fc2.reset_parameters()
540
+ if self.ffn_layernorm is not None:
541
+ self.ffn_layernorm.reset_parameters()
542
+
543
+ def forward(self, x):
544
+ x_shape = x.shape
545
+ x = x.reshape(-1, x.size(-1))
546
+ x = self.fc1(x)
547
+ x = self.activation_fn(x.float()).type_as(x)
548
+ x = self.activation_dropout_module(x)
549
+ if self.ffn_layernorm is not None:
550
+ x = self.ffn_layernorm(x)
551
+ x = self.fc2(x)
552
+ x = x.view(x_shape)
553
+ x = self.dropout_module(x)
554
+ return x
555
+
556
+
557
+ class GLU(nn.Module):
558
+ def __init__(
559
+ self,
560
+ embed_dim,
561
+ ffn_dim,
562
+ activation_fn,
563
+ dropout,
564
+ activation_dropout,
565
+ ):
566
+ super().__init__()
567
+ self.embed_dim = embed_dim
568
+ self.activation_fn = get_activation_fn(activation=str(activation_fn))
569
+ self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
570
+ self.dropout_module = torch.nn.Dropout(dropout)
571
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim, bias=False)
572
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim, bias=False)
573
+ self.gate = nn.Linear(self.embed_dim, ffn_dim, bias=False)
574
+
575
+ def reset_parameters(self):
576
+ self.fc1.reset_parameters()
577
+ self.fc2.reset_parameters()
578
+ self.gate.reset_parameters()
579
+
580
+ def forward(self, x):
581
+ x_shape = x.shape
582
+ x = x.reshape(-1, x.size(-1))
583
+ g = self.gate(x)
584
+ x = self.fc1(x)
585
+ x = self.activation_fn(x.float()).type_as(x) * g
586
+ x = self.activation_dropout_module(x)
587
+ x = self.fc2(x)
588
+ x = x.view(x_shape)
589
+ x = self.dropout_module(x)
590
+ return x
591
+
592
+
593
+ class DropPath(nn.Module):
594
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
595
+
596
+ def __init__(self, drop_prob=None):
597
+ super(DropPath, self).__init__()
598
+ self.drop_prob = drop_prob
599
+
600
+ def forward(self, x):
601
+ return drop_path(x, self.drop_prob, self.training)
602
+
603
+ def extra_repr(self):
604
+ return "p={}".format(self.drop_prob)
605
+
606
+
607
+ class RetNetDecoderLayer(nn.Module):
608
+ def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
609
+ super().__init__()
610
+ self.config = config
611
+ self.embed_dim = config.decoder_embed_dim
612
+ self.dropout_module = torch.nn.Dropout(config.dropout)
613
+
614
+ if config.drop_path_rate > 0:
615
+ drop_path_prob = np.linspace(
616
+ 0, config.drop_path_rate, config.decoder_layers
617
+ )[depth]
618
+ self.drop_path = DropPath(drop_path_prob)
619
+ else:
620
+ self.drop_path = None
621
+
622
+ self.retention = MultiScaleRetention(
623
+ config, use_bias=False, tensor_parallel=tensor_parallel
624
+ )
625
+
626
+ self.normalize_before = config.decoder_normalize_before
627
+
628
+ self.retention_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
629
+
630
+ self.ffn_dim = config.decoder_ffn_embed_dim
631
+
632
+ self.ffn = self.build_ffn()
633
+
634
+ self.final_layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
635
+
636
+ if config.deepnorm:
637
+ self.alpha = math.pow(2.0 * config.decoder_layers, 0.25)
638
+ else:
639
+ self.alpha = 1.0
640
+
641
+ def build_ffn(self):
642
+ if self.config.use_glu:
643
+ return GLU(
644
+ self.embed_dim,
645
+ self.ffn_dim,
646
+ self.config.activation_fn,
647
+ self.config.dropout,
648
+ self.config.activation_dropout,
649
+ )
650
+ else:
651
+ return FeedForwardNetwork(
652
+ self.embed_dim,
653
+ self.ffn_dim,
654
+ self.config.activation_fn,
655
+ self.config.dropout,
656
+ self.config.activation_dropout,
657
+ self.config.layernorm_eps,
658
+ self.config.subln,
659
+ self.config.use_ffn_rms_norm,
660
+ )
661
+
662
+ def residual_connection(self, x, residual):
663
+ return residual * self.alpha + x
664
+
665
+ def forward(
666
+ self,
667
+ hidden_states: torch.Tensor,
668
+ retention_rel_pos: Tuple[Tuple[torch.Tensor]],
669
+ retention_mask: Optional[torch.Tensor] = None,
670
+ forward_impl: str = "parallel",
671
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
672
+ output_retentions: Optional[bool] = False,
673
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor]]:
674
+ residual = hidden_states
675
+ if self.normalize_before:
676
+ hidden_states = self.retention_layer_norm(hidden_states)
677
+
678
+ msr_outs = self.retention(
679
+ hidden_states,
680
+ retention_rel_pos,
681
+ retention_mask=retention_mask,
682
+ past_key_value=past_key_value,
683
+ forward_impl=forward_impl,
684
+ output_retentions=output_retentions,
685
+ )
686
+ hidden_states = msr_outs[0]
687
+ curr_kv = msr_outs[1]
688
+
689
+ hidden_states = self.dropout_module(hidden_states)
690
+
691
+ if self.drop_path is not None:
692
+ hidden_states = self.drop_path(hidden_states)
693
+
694
+ hidden_states = self.residual_connection(hidden_states, residual)
695
+ if not self.normalize_before:
696
+ hidden_states = self.retention_layer_norm(hidden_states)
697
+
698
+ residual = hidden_states
699
+ if self.normalize_before:
700
+ hidden_states = self.final_layer_norm(hidden_states)
701
+
702
+ hidden_states = self.ffn(hidden_states)
703
+
704
+ if self.drop_path is not None:
705
+ hidden_states = self.drop_path(hidden_states)
706
+
707
+ hidden_states = self.residual_connection(hidden_states, residual)
708
+ if not self.normalize_before:
709
+ hidden_states = self.final_layer_norm(hidden_states)
710
+
711
+ outputs = (hidden_states, curr_kv)
712
+
713
+ if output_retentions:
714
+ outputs += (msr_outs[2],)
715
+ return outputs
716
+
717
+
718
+ class RetNetPreTrainedModel(PreTrainedModel):
719
+ # copied from LlamaPretrainedModel
720
+ config_class = RetNetConfig
721
+ base_model_prefix = "model"
722
+ supports_gradient_checkpointing = True
723
+ _no_split_modules = ["RetNetDecoderLayer"]
724
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
725
+
726
+ def _init_weights(self, module):
727
+ """
728
+ Following original retnet, weights are already initialized in their own
729
+ ways within their own init.
730
+ """
731
+ pass
732
+ # below is copied from LlamaPretrainedModel
733
+ # std = self.config.initializer_range
734
+ # if isinstance(module, nn.Linear):
735
+ # module.weight.data.normal_(mean=0.0, std=std)
736
+ # if module.bias is not None:
737
+ # module.bias.data.zero_()
738
+ # elif isinstance(module, nn.Embedding):
739
+ # module.weight.data.normal_(mean=0.0, std=std)
740
+ # if module.padding_idx is not None:
741
+ # module.weight.data[module.padding_idx].zero_()
742
+
743
+
744
+ @dataclass
745
+ class RetNetOutputWithPast(ModelOutput):
746
+ """
747
+ class for RetNet model's outputs that may also contain a past key/values (to speed up sequential decoding).
748
+
749
+ config:
750
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, decoder_embed_dim)`):
751
+ Sequence of hidden-states at the output of the last layer of the model.
752
+
753
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
754
+ decoder_embed_dim)` is output.
755
+ past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
756
+ - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim)
757
+ - "scale": shape=((1 or bsz) * num_head * 1 * 1)
758
+
759
+ Contains pre-computed hidden-states (key and values in the multi-scale retention blocks)
760
+ that can be used (see `past_key_values` input) to speed up sequential decoding.
761
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
762
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
763
+ one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`.
764
+
765
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
766
+ retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`):
767
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
768
+ sequence_length)`.
769
+
770
+ Retentions weights, used for visualization.
771
+
772
+ attentions (`tuple(torch.FloatTensor)`, *optional*, for backward compatibility. Same as retentions.
773
+ """
774
+
775
+ last_hidden_state: torch.FloatTensor = None
776
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None
777
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
778
+ retentions: Optional[Tuple[torch.FloatTensor]] = None
779
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
780
+
781
+
782
+ class RetNetModel(RetNetPreTrainedModel):
783
+ def __init__(
784
+ self,
785
+ config: RetNetConfig,
786
+ embed_tokens: nn.Embedding = None,
787
+ tensor_parallel: bool = False,
788
+ ):
789
+ super().__init__(config)
790
+ self.config = config
791
+
792
+ self.dropout_module = torch.nn.Dropout(config.dropout)
793
+
794
+ self.embed_dim = config.decoder_embed_dim
795
+ self.embed_scale = (
796
+ 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
797
+ )
798
+
799
+ if embed_tokens is None:
800
+ embed_tokens = nn.Embedding(
801
+ config.vocab_size, config.decoder_embed_dim, config.pad_token_id
802
+ )
803
+ self.embed_tokens = embed_tokens
804
+
805
+ if config.layernorm_embedding:
806
+ self.layernorm_embedding = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
807
+ else:
808
+ self.layernorm_embedding = None
809
+
810
+ self.layers = nn.ModuleList([])
811
+
812
+ for i in range(config.decoder_layers):
813
+ self.layers.append(
814
+ RetNetDecoderLayer(config, depth=i, tensor_parallel=tensor_parallel)
815
+ )
816
+
817
+ self.decoder_layers = len(self.layers)
818
+
819
+ if config.decoder_normalize_before:
820
+ self.layer_norm = RMSNorm(self.embed_dim, eps=config.layernorm_eps)
821
+ else:
822
+ self.layer_norm = None
823
+
824
+ self.retnet_rel_pos = RetNetRelPos(config)
825
+ self.recurrent_chunk_size = config.recurrent_chunk_size
826
+
827
+ if config.deepnorm:
828
+ init_scale = math.pow(8.0 * config.decoder_layers, 0.25)
829
+ for name, p in self.named_parameters():
830
+ if (
831
+ "fc1" in name
832
+ or "fc2" in name
833
+ or "out_proj" in name
834
+ or "v_proj" in name
835
+ ):
836
+ p.data.div_(init_scale)
837
+
838
+ if config.subln and not config.use_glu:
839
+ init_scale = math.sqrt(math.log(config.decoder_layers * 2))
840
+ for name, p in self.named_parameters():
841
+ if (
842
+ "fc1" in name
843
+ or "fc2" in name
844
+ or "out_proj" in name
845
+ or "v_proj" in name
846
+ ):
847
+ p.data.mul_(init_scale)
848
+
849
+ self.gradient_checkpointing = False
850
+ self.post_init()
851
+
852
+ def get_input_embeddings(self):
853
+ return self.embed_tokens
854
+
855
+ def set_input_embeddings(self, value):
856
+ self.embed_tokens = value
857
+
858
+ def forward_embedding(
859
+ self,
860
+ input_ids,
861
+ forward_impl,
862
+ inputs_embeds=None,
863
+ past_key_values=None,
864
+ ):
865
+ # Check if input_ids are within the range
866
+ if input_ids.max() >= self.config.vocab_size:
867
+ raise ValueError("All input_ids must be less than vocab_size")
868
+
869
+ # if past_key_values is not None:
870
+ if forward_impl == "recurrent":
871
+ input_ids = input_ids[:, -1:]
872
+
873
+ if inputs_embeds is None:
874
+ inputs_embeds = self.embed_tokens(input_ids)
875
+
876
+ embed = self.embed_scale * inputs_embeds
877
+
878
+ if self.layernorm_embedding is not None:
879
+ embed = self.layernorm_embedding(embed)
880
+
881
+ embed = self.dropout_module(embed)
882
+
883
+ return embed
884
+
885
+ def forward(
886
+ self,
887
+ input_ids: torch.LongTensor = None,
888
+ retention_mask: Optional[torch.Tensor] = None,
889
+ attention_mask: Optional[torch.Tensor] = None,
890
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None,
891
+ inputs_embeds: Optional[torch.FloatTensor] = None,
892
+ output_retentions: Optional[bool] = None,
893
+ output_attentions: Optional[bool] = None,
894
+ output_hidden_states: Optional[bool] = None,
895
+ use_cache: Optional[bool] = None,
896
+ return_dict: Optional[bool] = None,
897
+ forward_impl: Optional[str] = "parallel",
898
+ recurrent_chunk_size: Optional[int] = None,
899
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
900
+ ) -> Union[Tuple, RetNetOutputWithPast]:
901
+ if output_retentions is None and output_attentions is not None:
902
+ output_retentions = output_attentions
903
+ output_retentions = (
904
+ output_retentions
905
+ if output_retentions is not None
906
+ else self.config.output_retentions
907
+ )
908
+ output_hidden_states = (
909
+ output_hidden_states
910
+ if output_hidden_states is not None
911
+ else self.config.output_hidden_states
912
+ )
913
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
914
+
915
+ return_dict = (
916
+ return_dict if return_dict is not None else self.config.use_return_dict
917
+ )
918
+
919
+ # retrieve input_ids and inputs_embeds
920
+ if input_ids is not None and inputs_embeds is not None:
921
+ raise ValueError(
922
+ "You cannot specify both input_ids and inputs_embeds at the same time"
923
+ )
924
+ elif input_ids is not None:
925
+ batch_size, seq_length = input_ids.shape
926
+ elif inputs_embeds is not None:
927
+ batch_size, seq_length, _ = inputs_embeds.shape
928
+ else:
929
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
930
+
931
+ # embed tokens
932
+ if inputs_embeds is None:
933
+ inputs_embeds = self.forward_embedding(
934
+ input_ids, forward_impl, inputs_embeds, past_key_values
935
+ )
936
+
937
+ if retention_mask is None and attention_mask is not None:
938
+ retention_mask = attention_mask
939
+ if retention_mask is not None and forward_impl == "recurrent":
940
+ retention_mask = retention_mask[:, -1:]
941
+
942
+ hidden_states = inputs_embeds
943
+
944
+ # handling chunking here
945
+ if recurrent_chunk_size is None:
946
+ recurrent_chunk_size = self.recurrent_chunk_size
947
+ need_pad_for_chunkwise = (
948
+ forward_impl == "chunkwise" and seq_length % recurrent_chunk_size != 0
949
+ )
950
+ if need_pad_for_chunkwise:
951
+ padding_len = recurrent_chunk_size - seq_length % recurrent_chunk_size
952
+ slen = seq_length + padding_len
953
+ hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len))
954
+ else:
955
+ slen = seq_length
956
+ # relative position
957
+ if retention_rel_pos is None:
958
+ retention_rel_pos = self.retnet_rel_pos(
959
+ slen,
960
+ forward_impl=forward_impl,
961
+ recurrent_chunk_size=recurrent_chunk_size,
962
+ retention_mask=retention_mask,
963
+ get_decay_scale=not self.training,
964
+ )
965
+
966
+ # start running through the decoder layers
967
+ all_hidden_states = () if output_hidden_states else None
968
+ all_retentions = () if output_retentions else None
969
+ # layers * [bsz, num_head, qk_dim, decoder_embed_dim]
970
+ next_decoder_cache = () if use_cache else None
971
+
972
+ for idx, layer in enumerate(self.layers):
973
+ if output_hidden_states:
974
+ all_hidden_states += (hidden_states,)
975
+ past_key_value = (
976
+ past_key_values[idx] if past_key_values is not None else None
977
+ )
978
+
979
+ if self.gradient_checkpointing and self.training:
980
+
981
+ def create_custom_forward(module):
982
+ def custom_forward(*inputs):
983
+ return module(*inputs, output_retentions)
984
+
985
+ return custom_forward
986
+
987
+ layer_outputs = torch.utils.checkpoint.checkpoint(
988
+ create_custom_forward(layer),
989
+ hidden_states,
990
+ retention_rel_pos,
991
+ retention_mask,
992
+ forward_impl,
993
+ past_key_value,
994
+ )
995
+ else:
996
+ layer_outputs = layer(
997
+ hidden_states,
998
+ retention_rel_pos,
999
+ retention_mask=retention_mask,
1000
+ forward_impl=forward_impl,
1001
+ past_key_value=past_key_value,
1002
+ output_retentions=output_retentions,
1003
+ )
1004
+
1005
+ hidden_states = layer_outputs[0]
1006
+
1007
+ if use_cache:
1008
+ next_decoder_cache += (layer_outputs[1],)
1009
+
1010
+ if output_retentions:
1011
+ all_retentions += (layer_outputs[2],)
1012
+
1013
+ next_cache = next_decoder_cache if use_cache else None
1014
+
1015
+ if need_pad_for_chunkwise:
1016
+ hidden_states = hidden_states[:, :seq_length, :]
1017
+
1018
+ if self.layer_norm is not None:
1019
+ hidden_states = self.layer_norm(hidden_states)
1020
+
1021
+ # add hidden states from the last decoder layer
1022
+ if output_hidden_states:
1023
+ all_hidden_states += (hidden_states,)
1024
+
1025
+ if not return_dict:
1026
+ return tuple(
1027
+ v
1028
+ for v in [hidden_states, next_cache, all_hidden_states, all_retentions]
1029
+ if v is not None
1030
+ )
1031
+ return RetNetOutputWithPast(
1032
+ last_hidden_state=hidden_states,
1033
+ past_key_values=next_cache,
1034
+ hidden_states=all_hidden_states,
1035
+ retentions=all_retentions,
1036
+ attentions=all_retentions,
1037
+ )
1038
+
1039
+
1040
+ @dataclass
1041
+ class RetNetCausalLMOutputWithPast(ModelOutput):
1042
+ """
1043
+ class for RetNet causal language model (or autoregressive) outputs.
1044
+
1045
+ config:
1046
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1047
+ Language modeling loss (for next-token prediction).
1048
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1049
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1050
+ past_key_values (`List(Dict(str, torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1051
+ - "prev_key_value": shape=(bsz * num_head * v_dim * qk_dim)
1052
+ - "scale": shape=((1 or bsz) * num_head * 1 * 1)
1053
+
1054
+ Contains pre-computed hidden-states (key and values in the multi-scale retention blocks)
1055
+ that can be used (see `past_key_values` input) to speed up sequential decoding.
1056
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
1057
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
1058
+ one for the output of each layer) of shape `(batch_size, sequence_length, decoder_embed_dim)`.
1059
+
1060
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
1061
+ retentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_retentions=True` is passed or when `config.output_retentions=True`):
1062
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1063
+ sequence_length)`.
1064
+
1065
+ Retentions weights, used for visualization.
1066
+
1067
+ attentions (`tuple(torch.FloatTensor)`, *optional*, for backward compatibility. Same as retentions.
1068
+ """
1069
+
1070
+ loss: Optional[torch.FloatTensor] = None
1071
+ logits: torch.FloatTensor = None
1072
+ past_key_values: Optional[List[Dict[str, torch.FloatTensor]]] = None
1073
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
1074
+ retentions: Optional[Tuple[torch.FloatTensor]] = None
1075
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
1076
+
1077
+
1078
+ class RetNetForCausalLM(RetNetPreTrainedModel):
1079
+ def __init__(
1080
+ self,
1081
+ config: RetNetConfig,
1082
+ embed_tokens: nn.Embedding = None,
1083
+ tensor_parallel: bool = False,
1084
+ ) -> None:
1085
+ super().__init__(config)
1086
+ self.model = RetNetModel(
1087
+ config, embed_tokens=embed_tokens, tensor_parallel=tensor_parallel
1088
+ )
1089
+ self.lm_head = nn.Linear(
1090
+ config.decoder_embed_dim, config.vocab_size, bias=False
1091
+ )
1092
+ # init here
1093
+ torch.nn.init.normal_(
1094
+ self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5
1095
+ )
1096
+
1097
+ self.post_init()
1098
+
1099
+ def get_input_embeddings(self):
1100
+ return self.model.embed_tokens
1101
+
1102
+ def set_input_embeddings(self, value):
1103
+ self.model.embed_tokens = value
1104
+
1105
+ def get_output_embeddings(self):
1106
+ return self.lm_head
1107
+
1108
+ def set_output_embeddings(self, new_embeddings):
1109
+ self.lm_head = new_embeddings
1110
+
1111
+ def set_decoder(self, decoder):
1112
+ self.model = decoder
1113
+
1114
+ def get_decoder(self):
1115
+ return self.model
1116
+
1117
+ def forward(
1118
+ self,
1119
+ input_ids: torch.LongTensor = None,
1120
+ retention_mask: Optional[torch.Tensor] = None,
1121
+ attention_mask: Optional[torch.Tensor] = None,
1122
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1123
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1124
+ labels: Optional[torch.LongTensor] = None,
1125
+ use_cache: Optional[bool] = None,
1126
+ output_retentions: Optional[bool] = None,
1127
+ output_attentions: Optional[bool] = None,
1128
+ output_hidden_states: Optional[bool] = None,
1129
+ return_dict: Optional[bool] = None,
1130
+ forward_impl: Optional[str] = None,
1131
+ recurrent_chunk_size: Optional[int] = None,
1132
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
1133
+ ) -> Union[Tuple, RetNetCausalLMOutputWithPast]:
1134
+ if output_retentions is None and output_attentions is not None:
1135
+ output_retentions = output_attentions
1136
+ output_retentions = (
1137
+ output_retentions
1138
+ if output_retentions is not None
1139
+ else self.config.output_retentions
1140
+ )
1141
+ output_hidden_states = (
1142
+ output_hidden_states
1143
+ if output_hidden_states is not None
1144
+ else self.config.output_hidden_states
1145
+ )
1146
+ return_dict = (
1147
+ return_dict if return_dict is not None else self.config.use_return_dict
1148
+ )
1149
+ forward_impl = (
1150
+ forward_impl if forward_impl is not None else self.config.forward_impl
1151
+ )
1152
+ recurrent_chunk_size = (
1153
+ recurrent_chunk_size
1154
+ if recurrent_chunk_size is not None
1155
+ else self.config.recurrent_chunk_size
1156
+ )
1157
+
1158
+ if retention_mask is None and attention_mask is not None:
1159
+ retention_mask = attention_mask
1160
+
1161
+ outputs = self.model(
1162
+ input_ids,
1163
+ retention_mask=retention_mask,
1164
+ past_key_values=past_key_values,
1165
+ inputs_embeds=inputs_embeds,
1166
+ output_retentions=output_retentions,
1167
+ output_hidden_states=output_hidden_states,
1168
+ return_dict=return_dict,
1169
+ forward_impl=forward_impl,
1170
+ use_cache=use_cache,
1171
+ recurrent_chunk_size=recurrent_chunk_size,
1172
+ retention_rel_pos=retention_rel_pos,
1173
+ )
1174
+
1175
+ hidden_states = outputs[0]
1176
+ logits = self.lm_head(hidden_states)
1177
+
1178
+ loss = None
1179
+ if labels is not None:
1180
+ # Shift so that tokens < n predict n
1181
+ shift_logits = logits[..., :-1, :].contiguous()
1182
+ shift_labels = labels[..., 1:].contiguous()
1183
+ # Flatten the tokens
1184
+ loss_fct = nn.CrossEntropyLoss()
1185
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1186
+ shift_labels = shift_labels.view(-1)
1187
+ # Enable model parallelism
1188
+ shift_labels = shift_labels.to(shift_logits.device)
1189
+ loss = loss_fct(shift_logits, shift_labels)
1190
+
1191
+ if self.config.z_loss_coeff > 0:
1192
+ # z_loss from PaLM paper
1193
+ # z_loss = 1e-4 * log(log(z)), where z = sum(exp(logits))
1194
+ z_loss = torch.logsumexp(shift_logits, dim=-1).log().mean()
1195
+ loss += self.config.z_loss_coeff * z_loss
1196
+
1197
+ if not return_dict:
1198
+ output = (logits,) + outputs[1:]
1199
+ return (loss,) + output if loss is not None else output
1200
+
1201
+ return RetNetCausalLMOutputWithPast(
1202
+ loss=loss,
1203
+ logits=logits,
1204
+ past_key_values=outputs.past_key_values,
1205
+ hidden_states=outputs.hidden_states,
1206
+ retentions=outputs.retentions,
1207
+ attentions=outputs.retentions,
1208
+ )
1209
+
1210
+ def _crop_past_key_values(model, past_key_values, maximum_length):
1211
+ """Since retnet's kv do not have length, no need to crop. Just return"""
1212
+ return past_key_values
1213
+
1214
+ def prepare_inputs_for_generation(
1215
+ self,
1216
+ input_ids,
1217
+ past_key_values=None,
1218
+ attention_mask=None,
1219
+ inputs_embeds=None,
1220
+ **kwargs,
1221
+ ):
1222
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1223
+ if inputs_embeds is not None and past_key_values is None:
1224
+ model_inputs = {"inputs_embeds": inputs_embeds}
1225
+ else:
1226
+ model_inputs = {"input_ids": input_ids}
1227
+
1228
+ forward_impl = kwargs.get("forward_impl", "parallel")
1229
+ if past_key_values is not None:
1230
+ forward_impl = "recurrent"
1231
+
1232
+ model_inputs.update(
1233
+ {
1234
+ "past_key_values": past_key_values,
1235
+ "use_cache": kwargs.get("use_cache"),
1236
+ "attention_mask": attention_mask,
1237
+ "forward_impl": forward_impl,
1238
+ }
1239
+ )
1240
+ return model_inputs
1241
+
1242
+ @staticmethod
1243
+ def _reorder_cache(past_key_values, beam_idx):
1244
+ reordered_past = ()
1245
+ for layer_past in past_key_values: # dict
1246
+ layer_past_kv = layer_past["prev_key_value"] # [b, h, v_dim / h, qk_dim]
1247
+ layer_past_scale = layer_past["scale"] # [b, h, 1, 1]
1248
+ if layer_past_scale.size(0) > 1:
1249
+ # this means that retention_mask is not None, so the scale for
1250
+ # each batch is different. We need to select the correct scale then.
1251
+ # NOTE: during huggingface generate, it will generate attention_mask
1252
+ # if it is None, so this linke will always be true. Still, having
1253
+ # this line here for safety.
1254
+ layer_past_scale = layer_past_scale.index_select(0, beam_idx)
1255
+ reordered_past += (
1256
+ {
1257
+ "prev_key_value": layer_past_kv.index_select(0, beam_idx),
1258
+ "scale": layer_past_scale,
1259
+ },
1260
+ )
1261
+ return reordered_past
1262
+
1263
+ def sample_token(self, logit, do_sample=False, top_k=1, top_p=1.0, temperature=1.0):
1264
+ if not do_sample:
1265
+ return torch.argmax(logit, dim=-1, keepdim=True)
1266
+ filtered = top_k_top_p_filtering(logit / temperature, top_k=top_k, top_p=top_p)
1267
+ return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
1268
+
1269
+ @torch.inference_mode()
1270
+ def custom_generate(
1271
+ self,
1272
+ input_ids: torch.LongTensor = None,
1273
+ retention_mask: Optional[torch.Tensor] = None,
1274
+ attention_mask: Optional[torch.Tensor] = None,
1275
+ parallel_compute_prompt=True,
1276
+ max_new_tokens=20,
1277
+ bos_token_id=0,
1278
+ eos_token_id=0,
1279
+ do_sample=False,
1280
+ top_k=0,
1281
+ top_p=1.0,
1282
+ temperature=1.0,
1283
+ early_stopping=True,
1284
+ ):
1285
+ if retention_mask is None and attention_mask is not None:
1286
+ retention_mask = attention_mask
1287
+
1288
+ if input_ids is not None:
1289
+ if input_ids.shape[1] == 1:
1290
+ past_key_values = None
1291
+ elif parallel_compute_prompt:
1292
+ ret_mask = (
1293
+ retention_mask[:, :-1] if retention_mask is not None else None
1294
+ )
1295
+ outputs = self(
1296
+ input_ids[:, :-1],
1297
+ retention_mask=ret_mask,
1298
+ forward_impl="parallel",
1299
+ return_dict=True,
1300
+ use_cache=True,
1301
+ )
1302
+ past_key_values = outputs.past_key_values
1303
+ else:
1304
+ past_key_values = None
1305
+ for p_i in range(input_ids.shape[1] - 1):
1306
+ ret_mask = (
1307
+ retention_mask[:, : p_i + 1]
1308
+ if retention_mask is not None
1309
+ else None
1310
+ )
1311
+ outputs = self(
1312
+ input_ids[:, : p_i + 1],
1313
+ retention_mask=ret_mask,
1314
+ forward_impl="recurrent",
1315
+ past_key_values=past_key_values,
1316
+ return_dict=True,
1317
+ use_cache=True,
1318
+ )
1319
+ past_key_values = outputs.past_key_values
1320
+
1321
+ generated = input_ids
1322
+ else:
1323
+ generated = torch.tensor([[bos_token_id]]).to(self.lm_head.weight.device)
1324
+ past_key_values = None
1325
+
1326
+ for i in range(max_new_tokens):
1327
+ outputs = self(
1328
+ generated,
1329
+ retention_mask=retention_mask,
1330
+ forward_impl="recurrent",
1331
+ past_key_values=past_key_values,
1332
+ use_cache=True,
1333
+ return_dict=True,
1334
+ )
1335
+ logit = outputs.logits[:, -1, :] # [batch_size, vocab_size]
1336
+ past_key_values = outputs.past_key_values
1337
+ token = self.sample_token(
1338
+ logit,
1339
+ do_sample=do_sample,
1340
+ top_k=top_k,
1341
+ top_p=top_p,
1342
+ temperature=temperature,
1343
+ )
1344
+ generated = torch.cat([generated, token], dim=-1)
1345
+ if retention_mask is not None:
1346
+ retention_mask = torch.cat(
1347
+ [retention_mask, torch.ones_like(token)], dim=-1
1348
+ )
1349
+ if early_stopping and (token == eos_token_id).all():
1350
+ break
1351
+ return generated
1352
+
1353
+
1354
+ class RetNetForSequenceClassification(RetNetPreTrainedModel):
1355
+ def __init__(self, config, tensor_parallel=False):
1356
+ super().__init__(config)
1357
+ self.num_labels = config.num_labels
1358
+ self.model = RetNetModel(config, tensor_parallel=tensor_parallel)
1359
+ self.score = nn.Linear(config.decoder_embed_dim, self.num_labels, bias=False)
1360
+
1361
+ # Initialize weights and apply final processing
1362
+ self.post_init()
1363
+
1364
+ def get_input_embeddings(self):
1365
+ return self.model.embed_tokens
1366
+
1367
+ def set_input_embeddings(self, value):
1368
+ self.model.embed_tokens = value
1369
+
1370
+ def forward(
1371
+ self,
1372
+ input_ids: torch.LongTensor = None,
1373
+ retention_mask: Optional[torch.Tensor] = None,
1374
+ attention_mask: Optional[torch.Tensor] = None,
1375
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1376
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1377
+ labels: Optional[torch.LongTensor] = None,
1378
+ use_cache: Optional[bool] = None,
1379
+ output_retentions: Optional[bool] = None,
1380
+ output_attentions: Optional[bool] = None,
1381
+ output_hidden_states: Optional[bool] = None,
1382
+ return_dict: Optional[bool] = None,
1383
+ forward_impl: Optional[str] = None,
1384
+ recurrent_chunk_size: Optional[int] = None,
1385
+ retention_rel_pos: Optional[Tuple[torch.Tensor]] = None,
1386
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1387
+ if output_retentions is None and output_attentions is not None:
1388
+ output_retentions = output_attentions
1389
+ output_retentions = (
1390
+ output_retentions
1391
+ if output_retentions is not None
1392
+ else self.config.output_retentions
1393
+ )
1394
+ output_hidden_states = (
1395
+ output_hidden_states
1396
+ if output_hidden_states is not None
1397
+ else self.config.output_hidden_states
1398
+ )
1399
+ return_dict = (
1400
+ return_dict if return_dict is not None else self.config.use_return_dict
1401
+ )
1402
+ forward_impl = (
1403
+ forward_impl if forward_impl is not None else self.config.forward_impl
1404
+ )
1405
+ recurrent_chunk_size = (
1406
+ recurrent_chunk_size
1407
+ if recurrent_chunk_size is not None
1408
+ else self.config.recurrent_chunk_size
1409
+ )
1410
+
1411
+ if retention_mask is None and attention_mask is not None:
1412
+ retention_mask = attention_mask
1413
+
1414
+ outputs = self.model(
1415
+ input_ids,
1416
+ retention_mask=retention_mask,
1417
+ past_key_values=past_key_values,
1418
+ inputs_embeds=inputs_embeds,
1419
+ output_retentions=output_retentions,
1420
+ output_hidden_states=output_hidden_states,
1421
+ return_dict=return_dict,
1422
+ forward_impl=forward_impl,
1423
+ use_cache=use_cache,
1424
+ recurrent_chunk_size=recurrent_chunk_size,
1425
+ retention_rel_pos=retention_rel_pos,
1426
+ )
1427
+
1428
+ hidden_states = outputs[0]
1429
+ logits = self.score(hidden_states)
1430
+
1431
+ if input_ids is not None:
1432
+ batch_size = input_ids.shape[0]
1433
+ else:
1434
+ batch_size = inputs_embeds.shape[0]
1435
+
1436
+ if self.config.pad_token_id is None and batch_size != 1:
1437
+ raise ValueError(
1438
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1439
+ )
1440
+ if self.config.pad_token_id is None:
1441
+ sequence_lengths = -1
1442
+ else:
1443
+ if input_ids is not None:
1444
+ sequence_lengths = (
1445
+ torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1
1446
+ ).to(logits.device)
1447
+ else:
1448
+ sequence_lengths = -1
1449
+
1450
+ pooled_logits = logits[
1451
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1452
+ ]
1453
+
1454
+ loss = None
1455
+ if labels is not None:
1456
+ labels = labels.to(logits.device)
1457
+ if self.config.problem_type is None:
1458
+ if self.num_labels == 1:
1459
+ self.config.problem_type = "regression"
1460
+ elif self.num_labels > 1 and (
1461
+ labels.dtype == torch.long or labels.dtype == torch.int
1462
+ ):
1463
+ self.config.problem_type = "single_label_classification"
1464
+ else:
1465
+ self.config.problem_type = "multi_label_classification"
1466
+
1467
+ if self.config.problem_type == "regression":
1468
+ loss_fct = MSELoss()
1469
+ if self.num_labels == 1:
1470
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1471
+ else:
1472
+ loss = loss_fct(pooled_logits, labels)
1473
+ elif self.config.problem_type == "single_label_classification":
1474
+ loss_fct = CrossEntropyLoss()
1475
+ loss = loss_fct(
1476
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1477
+ )
1478
+ elif self.config.problem_type == "multi_label_classification":
1479
+ loss_fct = BCEWithLogitsLoss()
1480
+ loss = loss_fct(pooled_logits, labels)
1481
+ if not return_dict:
1482
+ output = (pooled_logits,) + outputs[1:]
1483
+ return ((loss,) + output) if loss is not None else output
1484
+
1485
+ return SequenceClassifierOutputWithPast(
1486
+ loss=loss,
1487
+ logits=pooled_logits,
1488
+ past_key_values=outputs.past_key_values,
1489
+ hidden_states=outputs.hidden_states,
1490
+ attentions=outputs.attentions,
1491
+ )