File size: 15,111 Bytes
b3fe4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import math
import torch
import torch.nn as nn
from einops import rearrange, repeat

from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_func, apply_rotary_emb_kv_

class RelativePositionalEncoding(nn.Module):
    """T5-style bucketed relative position bias.

    Produces an additive attention bias of shape
    ``(1, n_heads, query_length, key_length)`` by mapping every pairwise
    distance ``key_position - query_position`` to one of
    ``relative_attention_num_buckets`` buckets and looking up a learned
    per-head scalar for that bucket.

    Args:
        relative_attention_num_buckets: number of distance buckets (rows of the
            learned embedding table).
        relative_attention_max_distance: distance at and beyond which all
            positions share the last (logarithmic) bucket.
        n_heads: number of attention heads (columns of the embedding table).
        max_sequence_length: size of the position range that random positions
            are drawn from when ``randomized_position`` is True.
        bidirectional: if True, positive and negative distances use separate
            halves of the buckets; if False, positive distances are clamped to 0.
        randomized_position: if True, use sorted random positions drawn from
            ``[0, max_sequence_length)`` (first one pinned to 0) instead of
            ``0 .. length - 1``.
    """

    def __init__(self, relative_attention_num_buckets, relative_attention_max_distance, n_heads, max_sequence_length, bidirectional=True, randomized_position=False):

        super().__init__()

        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.n_heads = n_heads
        self.max_sequence_length = max_sequence_length
        self.bidirectional = bidirectional
        self.randomized_position = randomized_position

        # One learned scalar per (bucket, head).
        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a torch.long Tensor with the same shape as relative_position, containing values in the range
            [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Upper half of the buckets is reserved for positive distances.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        # (log(0) = -inf for distance 0 is harmless: torch.where below picks the exact bucket there.)
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def _sample_positions(self, length, device):
        """Return ``length`` sorted distinct random positions in ``[0, max_sequence_length)``, first pinned to 0."""
        positions, _ = torch.sort(torch.randperm(self.max_sequence_length)[:length])
        if positions.numel() > 0:
            positions[0] = 0  # root the first element of the sequence
        return positions.to(device)

    def compute_bias(self, query_length, key_length, device=None):
        """Compute the binned relative position bias.

        Args:
            query_length: number of query positions.
            key_length: number of key positions.
            device: device for the result; defaults to the embedding's device.

        Returns:
            Tensor of shape ``(1, n_heads, query_length, key_length)``.

        Raises:
            ValueError: in randomized mode, if either length exceeds
                ``max_sequence_length`` (not enough distinct positions to draw).
        """
        if device is None:
            device = self.relative_attention_bias.weight.device

        if self.randomized_position:
            longest = max(query_length, key_length)
            if longest > self.max_sequence_length:
                raise ValueError(
                    f"Sequence length {longest} exceeds max_sequence_length="
                    f"{self.max_sequence_length}, which randomized positions require"
                )
            # NOTE(review): query and key positions are drawn independently, so in
            # self-attention a token's query position generally differs from its
            # key position — confirm this is intended.
            context_position = self._sample_positions(query_length, device)[:, None]
            memory_position = self._sample_positions(key_length, device)[None, :]
        else:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
            memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]

        relative_position = memory_position - context_position  # shape (query_length, key_length)

        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=self.bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(self, q, k=None, v=None):
        """Return ``(q, k, v, bias)`` with q/k/v untouched and the additive bias in q's dtype.

        Sequence lengths are read from dim 1 of q and k (presumably
        ``(batch, seqlen, ...)`` — confirm against callers); when k is None the
        key length equals the query length.
        """
        query_length = q.shape[1]
        key_length = k.shape[1] if k is not None else query_length
        bias = self.compute_bias(query_length, key_length, device=q.device).contiguous().to(q.dtype)

        return q, k, v, bias


class ALiBiPositionalEncoding(nn.Module):
    """ALiBi (Attention with Linear Biases) additive attention bias.

    A fixed, non-learned table of shape
    ``(1, num_heads, max_sequence_length, max_sequence_length)`` is built once.

    In ``'symetric'`` mode, the entry ``[0, h, i, j]`` is ``-slope_h * |j - i|``.
    In ``'asymetric'`` mode, the first ``num_heads // 2`` heads additionally mask
    keys after the query (``j > i``) with ``-inf``. The second half masks keys
    before it (``j < i``). Both halves share the slopes computed for
    ``num_heads // 2`` heads.

    Args:
        max_sequence_length: largest query/key length the table supports.
        num_heads: number of attention heads.
        mode: ``'symetric'`` or ``'asymetric'`` (spellings are part of the API).
        randomized_position: if True, each forward gathers the bias at sorted
            random positions drawn from ``[0, max_sequence_length)`` (first one
            pinned to 0) instead of ``0 .. length - 1``.
    """

    def __init__(self, max_sequence_length, num_heads, mode='symetric', randomized_position=False):

        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.num_heads = num_heads
        self.mode = mode
        self.randomized_position = randomized_position

        # A buffer follows module.to(device) once, instead of being copied
        # host->device on every forward. It is fully determined by the
        # constructor arguments, so it is kept out of the state_dict.
        self.register_buffer(
            "alibi_bias",
            self.build_alibi_bias_matrix(num_heads, max_sequence_length, mode),
            persistent=False,
        )

    @staticmethod
    def fill_with_neg_inf(t):
        """FP16-compatible function that fills a tensor with -inf."""
        return t.float().fill_(float("-inf")).type_as(t)

    def get_slopes(self, n):
        """Return the ``n`` per-head ALiBi slopes.

        For a power of two ``n``, the slopes are the geometric sequence
        ``r, r**2, ..., r**n`` with ``r = 2 ** (-8 / n)``.
        """

        def get_slopes_power_of_2(n):
            start = (2**(-2**-(math.log2(n)-3)))
            ratio = start
            return [start*ratio**i for i in range(n)]

        # In the paper, models are only trained with 2^a heads; this sequence has
        # good properties only for powers of 2. For other head counts, take the
        # slopes of the closest lower power of 2 and fill the remainder with every
        # other slope of the next power of 2.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return get_slopes_power_of_2(closest_power_of_2) + self.get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]

    def build_symetric_alibi_bias_matrix(self, num_heads, maxpos):
        """Return ``-slope_h * |j - i|`` as a ``(1, num_heads, maxpos, maxpos)`` tensor."""
        context_position = torch.arange(maxpos)[:, None]
        memory_position = torch.arange(maxpos)[None, :]

        relative_position = memory_position - context_position
        relative_position = torch.abs(relative_position).unsqueeze(0).expand(num_heads, -1, -1)

        slopes = torch.Tensor(self.get_slopes(num_heads)) * -1
        alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
        return alibi.view(1, num_heads, maxpos, maxpos)

    def build_asymetric_alibi_bias_matrix(self, num_heads, maxpos):
        """Return the symmetric distance penalty plus a -inf mask per head half.

        First half of the heads: keys with ``j > i`` are masked.
        Second half: keys with ``j < i`` are masked.

        Raises:
            ValueError: if ``num_heads`` is odd (heads are split in two halves).
        """
        if num_heads % 2:
            raise ValueError("ALiBi 'asymetric' mode requires an even number of heads, got " + str(num_heads))
        half = num_heads // 2

        _future_mask_right = torch.triu(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(half, 1, 1)
        _future_mask_left = torch.tril(self.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(half, 1, 1)

        nonsym_mask = torch.cat((_future_mask_right, _future_mask_left), dim=0).unsqueeze(0)
        slopes = torch.Tensor(self.get_slopes(half)) * -1

        context_position = torch.arange(maxpos)[:, None]
        memory_position = torch.arange(maxpos)[None, :]

        relative_position = memory_position - context_position
        relative_position = torch.abs(relative_position).unsqueeze(0).expand(half, -1, -1)

        alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
        alibi = alibi.view(1, half, maxpos, maxpos)
        # Both head halves reuse the same slopes; only their masks differ.
        alibi = alibi.repeat(1, 2, 1, 1)

        return alibi.view(1, num_heads, maxpos, maxpos) + nonsym_mask.view(1, num_heads, maxpos, maxpos)

    def build_alibi_bias_matrix(self, num_heads, maxpos, mode='symetric'):
        """Dispatch to the builder for ``mode``; raise ValueError for unknown modes."""
        if mode == 'symetric':
            return self.build_symetric_alibi_bias_matrix(num_heads, maxpos)
        elif mode == 'asymetric':
            return self.build_asymetric_alibi_bias_matrix(num_heads, maxpos)
        else:
            raise ValueError("ALiBi mode " + mode + " is not implemented.")

    def _sample_positions(self, length):
        """Return ``length`` sorted distinct random positions in ``[0, max_sequence_length)``, first pinned to 0."""
        positions, _ = torch.sort(torch.randperm(self.max_sequence_length)[:length])
        if positions.numel() > 0:
            positions[0] = 0  # ground the sequence at position 0
        return positions.to(self.alibi_bias.device)

    def forward(self, q, k=None, v=None):
        """Return ``(q, k, v, bias)`` with q/k/v untouched.

        The bias has shape ``(1, num_heads, query_length, key_length)`` and uses
        q's device and dtype. Lengths are read from dim 1 of q and k (presumably
        ``(batch, seqlen, ...)`` — confirm against callers). When k is None, the
        key length equals the query length.

        Raises:
            ValueError: if either length exceeds the precomputed table.
        """
        query_length = q.shape[1]
        key_length = k.shape[1] if k is not None else query_length
        max_query, max_key = self.alibi_bias.shape[-2:]
        if query_length > max_query or key_length > max_key:
            raise ValueError("Sequence length larger than allowed alibi bound")

        if self.randomized_position:
            query_positions = self._sample_positions(query_length)
            key_positions = self._sample_positions(key_length)
            # Broadcast (Lq, 1) against (1, Lk) to gather the full Lq x Lk grid;
            # two plain 1-D index tensors would be paired element-wise instead.
            bias = self.alibi_bias[:, :, query_positions[:, None], key_positions[None, :]]
        else:
            bias = self.alibi_bias[:, :, :query_length, :key_length]

        return q, k, v, bias.to(device=q.device, dtype=q.dtype).contiguous()

class RotaryPositionalEncoding(nn.Module):
    """Rotary position embedding (RoPE) applied through flash-attn's rotary kernels.

    cos/sin tables covering positions ``0 .. max_sequence_length - 1`` are built
    lazily on the first forward call. They are rebuilt whenever the input's
    device or dtype changes, or when the module is training and the tables were
    created under inference mode.

    Args:
        dim: rotary dimension; one inverse frequency is generated per even
            channel index (``arange(0, dim, 2)``).
        max_sequence_length: number of positions the cached tables cover.
        base: base of the inverse frequencies ``base ** (-i / dim)`` for even ``i``.
        interleaved: forwarded unchanged to the flash-attn rotary kernels.
        scale_base: if not None, the query tables are multiplied by a
            position-dependent per-frequency scale and the key tables are divided
            by it (this matches the xPos formulation). None disables scaling.
        randomized_position: stored for API parity with the other encodings.
            NOTE(review): it is not used anywhere in this class.
    """

    def __init__(self, dim,
        max_sequence_length,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        randomized_position=False):

        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.randomized_position = randomized_position

        self.dim = dim
        self.base = base
        self.interleaved = interleaved
        self.scale_base = scale_base

        inv_freq = self._compute_inv_freq()
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Per-frequency scale, only needed when scale_base is given.
        scale = (
            (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        # Lazily built tables; see _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return the fp32 inverse frequencies ``1 / base ** (arange(0, dim, 2) / dim)``."""
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """(Re)build the cos/sin tables if they are missing or stale.

        Tables are rebuilt when ``seqlen`` exceeds the cached length, when
        ``device`` or ``dtype`` differ from the cached tables, or when switching
        from inference mode to training. Tables end up with shape
        ``(seqlen, len(inv_freq))`` in ``dtype``. The key tables are only set
        when ``scale_base`` is given; otherwise they are None.
        """
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # Positions and frequencies are kept in fp32: the model may run in bf16/fp16,
            # where integer positions above 256 (bf16) / 2048 (fp16) are not exactly
            # representable and neighbouring positions would collapse onto the same angle.
            # Only the finished tables are cast to `dtype`.
            inv_freq = self._compute_inv_freq(device=device)
            t = torch.arange(seqlen, device=device, dtype=torch.float32)

            # Don't do einsum, it converts fp32 to fp16 under AMP
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
                self._cos_k_cached = None
                self._sin_k_cached = None
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                # (len(inv_freq),) ** (seqlen, 1) -> (seqlen, len(inv_freq))
                scale = self.scale.to(device=power.device) ** power[:, None]
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def _key_cos_sin(self):
        """Return the (cos, sin) tables for keys/values: the key tables if scaling is on, else the query tables."""
        cos = self._cos_cached if self._cos_k_cached is None else self._cos_k_cached
        sin = self._sin_cached if self._sin_k_cached is None else self._sin_k_cached
        return cos, sin

    def forward(self, q, k=None, v=None):
        """Rotate q (and k / v when given) by position and return ``(q, k, v, None)``.

        Dispatch on which of k / v are provided:

        * neither: ``q`` goes to ``apply_rotary_emb_qkv_`` (presumably a packed
          qkv tensor — TODO confirm the expected layout against flash-attn);
        * k only: ``q`` is rotated in place with ``apply_rotary_emb_func`` and
          ``k`` goes to ``apply_rotary_emb_kv_`` (presumably a packed kv tensor);
        * otherwise: q, k and v are each rotated with ``apply_rotary_emb_func``.

        The fourth element is always None (no additive attention bias).
        """
        # Always go through the staleness check so that a module moved to another
        # device/dtype, or switched from inference mode to training, rebuilds its
        # tables instead of reusing stale ones.
        self._update_cos_sin_cache(self.max_sequence_length, device=q.device, dtype=q.dtype)
        cos_k, sin_k = self._key_cos_sin()

        if k is None and v is None:
            q = apply_rotary_emb_qkv_(
                q,
                self._cos_cached,
                self._sin_cached,
                self._cos_k_cached,
                self._sin_k_cached,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )
        elif v is None and k is not None:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=0,
            )
            k = apply_rotary_emb_kv_(
                k,
                cos_k,
                sin_k,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )
        else:
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=0,
            )
            k = apply_rotary_emb_func(
                k,
                cos_k,
                sin_k,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )
            # NOTE(review): standard RoPE rotates only queries and keys; rotating the
            # values as well is unusual — confirm this is intended.
            v = apply_rotary_emb_func(
                v,
                cos_k,
                sin_k,
                interleaved=self.interleaved,
                seqlen_offsets=0,
            )

        return q, k, v, None