File size: 2,925 Bytes
1a030c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.modules.mamba2_simple import Mamba2Simple
from mamba_ssm.modules.mamba2 import Mamba2
import torch
from torch import Tensor
import random
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
from einops import rearrange, repeat
from typing import Optional


from image_utils import ImageDB, ImageBatch, RGBToModel
from image_utils import ModelToRGB
from torch_utils import model_numel

epochs = 10_000
bs = 16
d_model = 768
weights_path = Path(f"data/weights-{d_model}.bin")

OPTS = {
    'device': "cuda",
    'dtype': torch.float32
}


class MambaWrap(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mamba = Mamba2Simple(d_model, **OPTS, headdim=64)
        self.norm = nn.LayerNorm(d_model, **OPTS)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = self.mamba(x)
        x = residual + x
        return x


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.from_rgb = RGBToModel(d_model, **OPTS)
        self.to_rgb = ModelToRGB(d_model, **OPTS)
        self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
        self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS))
        self.layers = nn.ModuleList([MambaWrap() for _ in range(4)])
        self.norm0 = nn.LayerNorm(d_model, **OPTS)

    def forward(self, batch: ImageBatch):
        B = batch.n_batch
        batch = batch.as_1d()
        batch.im8 = self.from_rgb(batch.im8)

        s0 = self.s0.repeat(B, 1, 1)
        s1 = self.zoom(batch.im8)

        x = torch.cat((s0, batch.im8, s1), 1)
        x = self.norm0(x)
        x = self.mamba(x)
        x = x[:, -64*64:]
        y_hat = self.to_rgb(x)
        y_true = batch.im64
        batch.loss = F.mse_loss(y_hat, y_true)
        batch.im64 = y_hat
        return batch.as_2d()

    def zoom(self, im8):
        im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1])
        im8 = repeat(
            im8, "B H W C -> B (H 8) (W 8) C").view(im8.shape[0], 64*64, im8.shape[-1])
        im8 = im8 + self.suffix
        return im8

    def mamba(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

if __name__ == "__main_":
    image_db = ImageDB(dtype=OPTS["dtype"])
    model = Model()
    if weights_path.exists():
        print(f"*** Load {weights_path:s}")
        model.load_state_dict(torch.load(weights_path))
    opt = torch.optim.AdamW(model.parameters(), fused=True)
    
    for e in (bar := tqdm(range(epochs))):
        b = model(image_db.random_batch(bs))
        b.loss.backward()
        opt.step()
        opt.zero_grad()
        bar.set_description(f'L:{b.loss.item():.4f}')
        if e and e % 100 == 0:
            torch.save(model.state_dict(), weights_path)
    torch.save(model.state_dict(), weights_path)