File size: 3,795 Bytes
1a030c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from mamba_ssm.modules.mamba2_simple import Mamba2Simple
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
from einops import repeat


from image_utils import ImageDB, ImageBatch, RGBToModel
from image_utils import ModelToRGB

# --- Training schedule -------------------------------------------------------
epochs = 10_000  # number of optimisation steps (one random batch per step)
bs = 16          # images per batch

# --- Model size --------------------------------------------------------------
# Original configuration, kept for reference:
#   bs      = 16
#   d_model = 768
#   headdim = 64
#   n_layer = 4
d_model = 1024
headdim = 64
n_layer = 4

# Keyword arguments forwarded (via **OPTS) to every module and parameter
# constructor in this file so that they share one device and one dtype.
OPTS = dict(device="cuda", dtype=torch.bfloat16)

# Every MambaFlipFlop block contains two Mamba layers (one per scan order), so
# the effective number of Mamba layers is 2 * n_layer. This loosely mirrors an
# LLM block built from two Mamba layers: one in place of attention and one in
# place of the MLP.

# Checkpoint location; the file name encodes the model size and dtype.
weights_path = Path(
    f"data/image-flip-weights-{d_model}x{n_layer}-{str(OPTS['dtype'])}.bin")
print(f"Weight path is {str(weights_path)}")


class MambaWrap(nn.Module):
    """Pre-norm residual wrapper around a single ``Mamba2Simple`` mixer.

    The block computes ``x + mamba(layer_norm(x))``. Its width, head size,
    device and dtype come from the module-level ``d_model``, ``headdim`` and
    ``OPTS`` settings.
    """

    def __init__(self) -> None:
        super().__init__()
        # Construction order is kept (mixer first, then norm) so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.mamba = Mamba2Simple(d_model, headdim=headdim, **OPTS)
        self.norm = nn.LayerNorm(d_model, **OPTS)

    def forward(self, x):
        """Return ``x`` plus the mixer output computed on the normalised ``x``."""
        return x + self.mamba(self.norm(x))


class MambaFlipFlop(nn.Module):
    """Two residual Mamba blocks that see the sequence tail in opposite orders.

    ``forward`` runs ``mb_forward`` on the input as given, reverses the last
    ``n_values`` positions along dim 1, runs ``mb_backward`` on that reordered
    sequence, and finally reverses the tail again so the output has the same
    ordering as the input.

    Args:
        n_values: number of trailing sequence positions (dim 1) whose order is
            reversed between the two Mamba blocks.
    """

    def __init__(self, n_values) -> None:
        super().__init__()
        self.mb_forward = MambaWrap()
        self.mb_backward = MambaWrap()
        self.n_values = n_values

    def forward(self, x):
        """Apply the forward block, then the backward block on the flipped tail."""
        x = self.mb_forward(x)
        x = self.swap_order(x)
        x = self.mb_backward(x)
        # Second swap undoes the first one, restoring the original ordering.
        x = self.swap_order(x)
        return x

    def swap_order(self, x):
        """Reverse the last ``n_values`` positions of ``x`` along dim 1.

        The leading ``T - n_values`` positions are left where they are.
        Applying this method twice returns the original ordering.

        Raises:
            ValueError: if the sequence is shorter than ``n_values``. Index
                arithmetic would otherwise wrap around to negative positions
                and silently produce a sequence of the wrong length.
        """
        T = x.shape[1]
        split = T - self.n_values
        if split < 0:
            raise ValueError(
                f"sequence length {T} is shorter than n_values={self.n_values}")
        # Slice + flip stays on x's own device; building an index tensor with
        # torch.arange would create it on the CPU on every call.
        return torch.cat((x[:, :split], x[:, split:].flip(1)), dim=1)


class Model(nn.Module):
    """Image up-scaling model built from a stack of ``MambaFlipFlop`` blocks.

    The input sequence fed to the stack is the concatenation of
    a learned start token, the embedded ``im8`` tokens, and 64*64 positions
    produced by ``zoom``. The last 64*64 outputs are mapped back to RGB and
    compared against ``batch.im64`` with an MSE loss.
    """

    def __init__(self) -> None:
        super().__init__()
        n_pixels = 64 * 64
        # Attribute names and creation order are kept unchanged so existing
        # checkpoints still load and random initialisation is reproduced.
        self.from_rgb = RGBToModel(d_model, **OPTS)
        self.to_rgb = ModelToRGB(d_model, **OPTS)
        self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
        self.suffix = nn.Parameter(torch.randn(n_pixels, d_model, **OPTS))
        blocks = []
        for _ in range(n_layer):
            blocks.append(MambaFlipFlop(n_pixels))
        self.layers = nn.ModuleList(blocks)
        self.norm0 = nn.LayerNorm(d_model, **OPTS)

    def forward(self, batch: ImageBatch):
        """Predict ``im64`` from ``im8`` and attach the loss to the batch.

        The batch returned by ``batch.as_1d()`` is modified in place: ``im8``
        is replaced by its embedding, ``im64`` by the prediction, and ``loss``
        is set to the MSE between the prediction and the original ``im64``.
        The result of ``as_2d()`` on that batch is returned.
        """
        n_batch = batch.n_batch
        batch = batch.as_1d()
        batch.im8 = self.from_rgb(batch.im8)
        tokens = batch.im8

        start = self.s0.repeat(n_batch, 1, 1)  # one start token per sample
        queries = self.zoom(tokens)

        seq = torch.cat((start, tokens, queries), 1)
        seq = self.mamba(self.norm0(seq))

        # Only the trailing 64*64 positions are decoded back to RGB.
        prediction = self.to_rgb(seq[:, -64*64:])
        target = batch.im64
        batch.loss = F.mse_loss(prediction, target)
        batch.im64 = prediction
        return batch.as_2d()

    def zoom(self, im8):
        """Expand the 8x8 token grid to 64x64 positions and add ``suffix``.

        The tokens are viewed as an 8x8 grid, every cell is repeated 8 times
        along both grid axes, the grid is flattened to 64*64 positions, and
        the learned ``suffix`` parameter is added to each position.
        """
        n_batch, channels = im8.shape[0], im8.shape[-1]
        grid = im8.view(n_batch, 8, 8, channels)
        grid = repeat(grid, "B H W C -> B (H 8) (W 8) C")
        flat = grid.view(grid.shape[0], 64*64, grid.shape[-1])
        return flat + self.suffix

    def mamba(self, x):
        """Run ``x`` through every block in ``self.layers`` in order."""
        out = x
        for block in self.layers:
            out = block(out)
        return out


if __name__ == "__main__":
    # Training entry point: resume from `weights_path` when it exists, then run
    # `epochs` optimisation steps on random batches, checkpointing every 100
    # steps and once more at the end.

    def _save_checkpoint():
        # Write to a temporary sibling file first and rename it afterwards, so
        # an interruption in the middle of a save cannot corrupt the previous
        # checkpoint.
        tmp_path = weights_path.with_name(weights_path.name + ".tmp")
        torch.save(model.state_dict(), tmp_path)
        tmp_path.replace(weights_path)

    image_db = ImageDB(dtype=OPTS["dtype"])
    model = Model()
    if weights_path.exists():
        print(f"*** Load {str(weights_path)}")
        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        model.load_state_dict(
            torch.load(weights_path, map_location=OPTS["device"]))
    # Make sure the checkpoint directory exists up front; otherwise the first
    # save (after 100 steps) would fail and the training progress be lost.
    weights_path.parent.mkdir(parents=True, exist_ok=True)
    opt = torch.optim.AdamW(model.parameters(), fused=True)

    for e in (bar := tqdm(range(epochs))):
        b = model(image_db.random_batch(bs))
        b.loss.backward()
        opt.step()
        opt.zero_grad()
        bar.set_description(f'L:{b.loss.item():.4f}')
        # Periodic checkpoint; step 0 is skipped.
        if e and e % 100 == 0:
            _save_checkpoint()
    _save_checkpoint()