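"""MambaBit: a bit-level Mamba language model (inference script).

Text is encoded as a stream of single bits (vocabulary size 2) and the model
predicts one bit at a time. The script loads a trained checkpoint and
generates a continuation of a prompt, either by naively re-scanning the whole
sequence (O(N^2)) or by carrying recurrent state with InferenceParams (O(N)).
"""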

import sys

import torch
import torch.nn as nn
from torch import Tensor
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from tqdm.auto import tqdm

dim_model = 4096  # hidden width shared by every Mamba block
n_vocab = 2       # bit-level model: the only tokens are 0 and 1
n_layers = 4


@torch.no_grad()
def string_to_bits(text: str, _cache=[]) -> Tensor:
    # The mutable default argument is used deliberately as a one-shot cache
    # for the 256x8 byte-to-bit lookup table (most significant bit first).
    if not _cache:
        all_values = torch.arange(0, 256)
        bits = [((all_values & (1 << i)) != 0).int() for i in range(7, -1, -1)]
        _cache.append(torch.stack(bits).mT)  # shape (256, 8)
    bits_tensor = _cache[0]
    # bytearray() gives torch.frombuffer a writable buffer, avoiding a warning.
    raw = torch.frombuffer(bytearray(text.encode()), dtype=torch.uint8).int()
    return bits_tensor[raw].long().ravel()

@torch.no_grad()
def bits_to_string(bits: Tensor):
    # A 2D tensor is treated as a batch of bit sequences.
    if bits.dim() == 2:
        return [bits_to_string(t) for t in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0
    # Weights 128, 64, ..., 1 reassemble each group of 8 bits (MSB first).
    factors = torch.tensor([2 ** i for i in range(7, -1, -1)], device=bits.device)
    as_bytes = (bits.view(-1, 8) * factors).sum(-1)
    return ''.join(chr(x) for x in as_bytes)
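
# Illustrative round trip: the two helpers are inverses, e.g.
# bits_to_string(string_to_bits("Hi")) == "Hi".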

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, dim_model)
        self.emb.weight.data *= 0.001  # shrink the initial embeddings toward zero

    def forward(self, x):
        return self.emb(x)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.decoder = nn.Linear(dim_model, n_vocab, bias=False)
        self.decoder.weight.data *= 0.001  # same near-zero init as the encoder

    def forward(self, x):
        x = self.norm(x)
        x = self.decoder(x)
        return x

class MambaBit(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList([Mamba(dim_model) for _ in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x):
        x = self.enc(x)
        for layer in self.layers:
            x = layer(x)
        x = self.dec(x)
        return x

class MambaBitWithInference(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        # layer_idx lets each Mamba block locate its own cached state inside
        # the shared InferenceParams object.
        self.layers = nn.ModuleList([Mamba(dim_model, layer_idx=i) for i in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        x = self.enc(x)
        for layer in self.layers:
            x = layer(x, inference_params=inference_params)
        x = self.dec(x)
        return x

# Test generation with the O(N^2) cacheless, stateless algorithm: the full
# sequence is re-processed from scratch for every generated bit.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8  # 8 bits per generated character
    for _ in tqdm(range(process)):
        y = m(x)
        new = y[:, -1:].argmax(-1)  # greedy decoding: most likely next bit
        x = torch.cat((x, new), 1)
    return bits_to_string(x)

# Test generation in O(N) by reusing state: the prompt is processed once, then
# each new bit is fed through the model alone while the recurrent state is
# carried across calls in InferenceParams.
@torch.no_grad()
def test_n(m: MambaBitWithInference, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8

    inference_params = InferenceParams(
        max_seqlen=x.numel() + process,
        max_batch_size=1)

    y = m(x, inference_params=inference_params)
    new = y[:, -1:].argmax(-1)
    for _ in tqdm(range(process)):
        x = torch.cat((x, new), 1)
        # A nonzero seqlen_offset makes the Mamba blocks take the single-step
        # path against the cached state; keep it at the number of bits
        # already processed.
        inference_params.seqlen_offset = x.numel() - 1
        y = m(new, inference_params=inference_params)
        new = y[:, -1:].argmax(-1)
    return bits_to_string(x)

def run():
    mamba_bit = MambaBitWithInference().cuda()
    mamba_bit.load_state_dict(torch.load("mamba_bit.bin"))

    prompt = "FIRST CITIZE" if len(sys.argv) != 2 else sys.argv[1]
    # test_n2 is O(N^2); test_n is O(N), though InferenceParams is not well
    # documented upstream.
    s = test_n(mamba_bit, prompt, chars=256)[0]
    print(s)

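# Assumed invocation: python <this script> [PROMPT]; a trained checkpoint
# named "mamba_bit.bin" must exist in the working directory.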
if __name__ == "__main__":
    run()