|
from mamba_ssm.modules.mamba2_simple import Mamba2Simple |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from tqdm.auto import tqdm |
|
from pathlib import Path |
|
from einops import repeat |
|
|
|
|
|
from image_utils import ImageDB, ImageBatch, RGBToModel |
|
from image_utils import ModelToRGB |
|
|
|
# --- Training hyper-parameters --------------------------------------------

epochs = 10_000  # number of optimizer steps (one random batch per step)

bs = 16  # batch size

# --- Model hyper-parameters -----------------------------------------------

d_model = 1024  # embedding width used by every Mamba block and LayerNorm

headdim = 64  # per-head dimension forwarded to Mamba2Simple

n_layer = 4  # number of stacked MambaFlipFlop layers

# Device/dtype applied uniformly to every module and parameter in the model.
OPTS = {

    'device': "cuda",

    'dtype': torch.bfloat16

}

# Checkpoint location; the filename encodes width, depth and dtype so that
# differently-configured runs do not clobber each other's weights.
weights_path = Path(

    f"data/image-flip-weights-{d_model}x{n_layer}-{str(OPTS['dtype'])}.bin")

print(f"Weight path is {str(weights_path)}")
|
|
|
|
|
class MambaWrap(nn.Module):
    """Pre-norm residual wrapper around a single Mamba2 block.

    Computes ``x + mamba(layer_norm(x))`` so the block can be stacked
    deeply without destabilizing the signal path.
    """

    def __init__(self) -> None:
        super().__init__()
        self.mamba = Mamba2Simple(d_model, **OPTS, headdim=headdim)
        self.norm = nn.LayerNorm(d_model, **OPTS)

    def forward(self, x):
        # Pre-norm residual: normalize, mix with Mamba, then add the skip path.
        return x + self.mamba(self.norm(x))
|
|
|
|
|
class MambaFlipFlop(nn.Module):
    """Two Mamba passes with the trailing ``n_values`` positions reversed in
    between, giving the second (causal) pass a right-to-left view of the
    output tokens.

    ``swap_order`` is an involution, so applying it before and after the
    backward pass restores the original token order.
    """

    def __init__(self, n_values) -> None:
        super().__init__()
        self.mb_forward = MambaWrap()
        self.mb_backward = MambaWrap()
        self.n_values = n_values

    def forward(self, x):
        out = self.mb_forward(x)
        out = self.swap_order(out)
        out = self.mb_backward(out)
        return self.swap_order(out)

    def swap_order(self, x):
        # Keep the leading positions untouched and reverse the last
        # ``n_values`` positions along the sequence axis.
        split = x.shape[1] - self.n_values
        prefix, tail = x[:, :split], x[:, split:]
        return torch.cat((prefix, tail.flip(1)), dim=1)
|
|
|
|
|
class Model(nn.Module): |
|
def __init__(self) -> None: |
|
super().__init__() |
|
self.from_rgb = RGBToModel(d_model, **OPTS) |
|
self.to_rgb = ModelToRGB(d_model, **OPTS) |
|
self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS)) |
|
self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS)) |
|
self.layers = nn.ModuleList([MambaFlipFlop(64*64) |
|
for _ in range(n_layer)]) |
|
self.norm0 = nn.LayerNorm(d_model, **OPTS) |
|
|
|
def forward(self, batch: ImageBatch): |
|
B = batch.n_batch |
|
batch = batch.as_1d() |
|
batch.im8 = self.from_rgb(batch.im8) |
|
|
|
s0 = self.s0.repeat(B, 1, 1) |
|
s1 = self.zoom(batch.im8) |
|
|
|
x = torch.cat((s0, batch.im8, s1), 1) |
|
x = self.norm0(x) |
|
x = self.mamba(x) |
|
x = x[:, -64*64:] |
|
y_hat = self.to_rgb(x) |
|
y_true = batch.im64 |
|
batch.loss = F.mse_loss(y_hat, y_true) |
|
batch.im64 = y_hat |
|
return batch.as_2d() |
|
|
|
def zoom(self, im8): |
|
im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1]) |
|
im8 = repeat(im8, "B H W C -> B (H 8) (W 8) C") |
|
im8 = im8.view(im8.shape[0], 64*64, im8.shape[-1]) |
|
im8 = im8 + self.suffix |
|
return im8 |
|
|
|
def mamba(self, x): |
|
for layer in self.layers: |
|
x = layer(x) |
|
return x |
|
|
|
|
|
if __name__ == "__main__":
    image_db = ImageDB(dtype=OPTS["dtype"])

    # Resume from the checkpoint when one exists.
    model = Model()
    if weights_path.exists():
        print(f"*** Load {str(weights_path)}")
        model.load_state_dict(torch.load(weights_path))

    opt = torch.optim.AdamW(model.parameters(), fused=True)

    progress = tqdm(range(epochs))
    for step in progress:
        batch = model(image_db.random_batch(bs))
        batch.loss.backward()
        opt.step()
        opt.zero_grad()
        progress.set_description(f'L:{batch.loss.item():.4f}')
        # Periodic checkpoint so a crash loses at most 100 steps.
        if step and step % 100 == 0:
            torch.save(model.state_dict(), weights_path)

    # Final checkpoint once training completes.
    torch.save(model.state_dict(), weights_path)
|
|