File size: 1,786 Bytes
a0806ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# -*- coding: utf-8 -*-

from fla.layers import (ABCAttention, Attention, BasedLinearAttention,
                        BitAttention, DeltaNet, GatedLinearAttention,
                        GatedSlotAttention, HGRN2Attention, HGRNAttention,
                        LinearAttention, MultiScaleRetention,
                        ReBasedLinearAttention)
from fla.models import (ABCForCausalLM, ABCModel, BitNetForCausalLM,
                        BitNetModel, DeltaNetForCausalLM, DeltaNetModel,
                        GLAForCausalLM, GLAModel, GSAForCausalLM, GSAModel,
                        HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
                        LinearAttentionForCausalLM, LinearAttentionModel,
                        RetNetForCausalLM, RetNetModel, RWKV6ForCausalLM,
                        RWKV6Model, TransformerForCausalLM, TransformerModel)

# Public API of this package.
#
# Every entry must be a name that is actually bound in this module: a
# wildcard import (`from <package> import *`) looks up each string listed in
# `__all__` and raises AttributeError for any that is missing.
#
# 'HGRNModel' and the kernel helpers 'chunk_gla', 'chunk_retention',
# 'fused_chunk_based', 'fused_chunk_gla' and 'fused_chunk_retention' were
# previously listed here without a matching import in this module, which broke
# wildcard imports, so they are omitted.
# NOTE(review): if they are meant to be exported, re-add them together with
# the corresponding imports — confirm which submodules provide them.
__all__ = [
    # Names imported from fla.layers
    'ABCAttention',
    'Attention',
    'BasedLinearAttention',
    'BitAttention',
    'DeltaNet',
    'HGRNAttention',
    'HGRN2Attention',
    'GatedLinearAttention',
    'GatedSlotAttention',
    'LinearAttention',
    'MultiScaleRetention',
    'ReBasedLinearAttention',
    # Names imported from fla.models
    'ABCForCausalLM',
    'ABCModel',
    'BitNetForCausalLM',
    'BitNetModel',
    'DeltaNetForCausalLM',
    'DeltaNetModel',
    'HGRNForCausalLM',
    'HGRN2ForCausalLM',
    'HGRN2Model',
    'GLAForCausalLM',
    'GLAModel',
    'GSAForCausalLM',
    'GSAModel',
    'LinearAttentionForCausalLM',
    'LinearAttentionModel',
    'RetNetForCausalLM',
    'RetNetModel',
    'RWKV6ForCausalLM',
    'RWKV6Model',
    'TransformerForCausalLM',
    'TransformerModel',
]

# Version string, exposed as the module-level attribute `__version__`.
__version__ = '0.1'