anhth committed on
Commit
8314c30
·
1 Parent(s): cc2b90a

Initial Commit

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ *.pyc
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import v2
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from colorizer import ColorComicNet, MODEL_CFG
6
+ from utils import smart_padding, remove_padding
7
+
8
# Transformation pipeline applied to every input image:
# PIL -> tensor, uint8 -> float32 in [0, 1], then Normalize(mean=0.5, std=0.5)
# maps values to [-1, 1] (postprocess_output undoes this with (x + 1) / 2).
TRANSFORM = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5], std=[0.5])
])
14
+
15
# Image preprocessing and postprocessing functions
def preprocess_image(image: Image.Image, divisor=16):
    """Turn a PIL image into a padded, normalized model input.

    Returns a (1, 3, H', W') tensor whose spatial dims are multiples of
    `divisor`, plus the padding record needed to undo the padding later.
    """
    rgb = image.convert('RGB')
    batched = TRANSFORM(rgb).unsqueeze(0)  # (1, 3, H, W)
    padded, padding = smart_padding(batched, divisor=divisor)
    return padded, padding
22
+
23
def postprocess_output(output_tensor, padding):
    """Convert a model output tensor into an (H, W, C) numpy image in [0, 1].

    Args:
        output_tensor: (1, C, H', W') tensor in [-1, 1], possibly padded.
        padding: padding record from `smart_padding`, removed here.

    Returns:
        numpy.ndarray of shape (H, W, C) with values clamped to [0, 1].
    """
    output_tensor = remove_padding(output_tensor, padding)
    output_tensor = (output_tensor + 1) / 2  # Scale back from [-1, 1] to [0, 1]
    # detach()/cpu() so conversion also works for grad-enabled or GPU tensors,
    # not only inside torch.no_grad() on CPU.
    output_image = (
        output_tensor.detach().cpu().clamp(0, 1).squeeze(0).permute(1, 2, 0).numpy()
    )
    return output_image
29
+
30
# Define the colorization function
def colorize_image(gray_image: Image.Image):
    """Colorize a single grayscale image using the model.

    Args:
        gray_image: input PIL image, or None (Gradio passes None when the
            button is clicked without an uploaded image).

    Returns:
        (H, W, C) numpy array in [0, 1], or None if no image was provided.
    """
    # Guard: without this, preprocess_image crashes on None.convert('RGB')
    # when the user clicks "Colorize" before uploading anything.
    if gray_image is None:
        return None
    with torch.no_grad():
        # Preprocess (divisor=64 matches the model's total downsampling factor)
        input_tensor, padding = preprocess_image(gray_image, divisor=64)
        # Inference
        output = model(input_tensor)
        # Postprocess
        output_image = postprocess_output(output, padding)
    return output_image
41
+
42
# Initialize the model (CPU-only inference).
# NOTE: MODEL_CFG is a dict of constructor keyword arguments, so it must be
# unpacked with **. Passing it positionally binds the whole dict to
# `input_shape`, and ColorComicNet's stem then evaluates input_shape[0]
# (i.e. MODEL_CFG[0]) -> KeyError at startup.
model = ColorComicNet(**MODEL_CFG)
model.load_state_dict(torch.load("./weights/colorizer.pth", map_location=torch.device('cpu')))
model.fuse()  # collapse re-parameterizable branches into single convs for inference
model.eval()
47
+
48
# Create the Gradio interface

# Custom CSS injected into gr.Blocks below: dark gradient page background,
# a centered 1000px container, centered header text, and an orange-gradient
# primary button (class names referenced via elem_classes in the UI).
custom_css = """
body {
    background: linear-gradient(135deg, #1e1e2f, #2a2a40);
    color: white;
}
.gradio-container {
    max-width: 1000px !important;
    margin: auto;
}
.header {
    text-align: center;
    padding: 20px;
}
.header h1 {
    font-size: 2.2rem;
    margin-bottom: 5px;
}
.header p {
    color: #cfcfe0;
}
.button-primary {
    background: linear-gradient(90deg, #ff7a18, #ffb347);
    border: none;
    color: white;
    font-weight: bold;
}
"""
77
+
78
# Two-column Gradio app: grayscale PIL image in (left), colorized numpy
# array out (right); the button wires the two through colorize_image.
with gr.Blocks(css=custom_css) as demo:
    # Header
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🎨 Comic Colorization")
        gr.Markdown("Bring your grayscale comics to life with **ColorComicNet**")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # type="pil" so colorize_image receives a PIL.Image (or None)
            input_image = gr.Image(
                label="📥 Upload Grayscale Image",
                type="pil",
            )
            colorize_button = gr.Button(
                "✨ Colorize Image",
                elem_classes="button-primary"
            )

        with gr.Column(scale=1):
            # type="numpy" matches the (H, W, C) array returned by postprocess_output
            output_image = gr.Image(
                label="📤 Colorized Result",
                type="numpy",
            )

    # Example section
    # gr.Markdown("### 🖼️ Try an example")
    # examples = gr.Examples(
    #     examples=[
    #         ["example1.png"],
    #         ["example2.png"]
    #     ],
    #     inputs=input_image
    # )

    # Footer
    gr.Markdown("---")

    # Interaction
    colorize_button.click(
        fn=colorize_image,
        inputs=input_image,
        outputs=output_image
    )

demo.launch()
colorizer.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numbers
5
+ from dataclasses import dataclass, asdict
6
+ from einops import rearrange
7
+
8
class UpSample(nn.Module):
    """2x spatial upsampler.

    A 1x1 conv doubles the channels, then PixelShuffle trades 4x channels
    for 2x height/width, so the net effect is
    (B, C, H, W) -> (B, C/2, 2H, 2W).
    """
    def __init__(self, filters=64):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, padding=0, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        return self.pixel_shuffle(self.conv(x))
19
+
20
## DownSampling block
class DownSample(nn.Module):
    """2x spatial downsampler.

    A 1x1 conv halves the channels, then PixelUnshuffle folds each 2x2
    spatial patch into channels (x4), so the net effect is
    (B, C, H, W) -> (B, 2C, H/2, W/2).
    """
    def __init__(self, filters=64):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters // 2, kernel_size=1, stride=1, padding=0, bias=True)
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=2)

    def forward(self, x):
        """(B, C, H, W) -> (B, 2C, H/2, W/2).

        The original docstring claimed (B, C/4, H/2, W/2); the conv halves
        channels and PixelUnshuffle quadruples them, giving 2C.
        """
        x = self.conv(x)
        x = self.pixel_unshuffle(x)
        return x
33
+
34
# Custom LayerNormalization
class BiasFree_LayerNorm(nn.Module):
    """Bias-free layer norm over the last dimension.

    No mean subtraction and no additive bias: the input is divided by
    sqrt(var + 1e-5) (biased variance) and scaled by a learnable weight.
    """
    def __init__(self, normalized_shape):
        super().__init__()
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        x = x.contiguous()
        # biased variance over the last dim; no mean subtraction (bias-free)
        var = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(var + 1e-5) * self.weight
52
+
53
class WithBias_LayerNorm(nn.Module):
    """Layer norm over the last dimension with learnable scale and shift:
    (x - mean) / sqrt(var + 1e-5) * weight + bias."""
    def __init__(self, normalized_shape):
        super().__init__()
        shape = (
            (normalized_shape,)
            if isinstance(normalized_shape, numbers.Integral)
            else normalized_shape
        )
        shape = torch.Size(shape)
        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        x = x.contiguous()
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True, unbiased=False)  # biased variance
        return (x - mean) / torch.sqrt(var + 1e-5) * self.weight + self.bias
72
+
73
class LayerNorm(nn.Module):
    """Layer normalization wrapper supporting 'BiasFree' and 'WithBias' bodies.

    With out_4d=True, a (B, C, H, W) input is flattened to (B, H*W, C),
    normalized over the channel dimension, and restored to 4-D; with
    out_4d=False the input is normalized as-is over its last dimension.
    """
    def __init__(self, dim, LayerNorm_type, out_4d=True):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)
        self.out_4d = out_4d

    def to_3d(self, x):
        # (B, C, H, W) -> (B, H*W, C); 3-D inputs pass through unchanged
        ndim = len(x.shape)
        if ndim == 3:
            return x
        if ndim == 4:
            return rearrange(x, 'b c h w -> b (h w) c')
        raise ValueError("Input must be a 3D or 4D tensor")

    def to_4d(self, x, h, w):
        # (B, H*W, C) -> (B, C, H, W); 4-D inputs pass through unchanged
        ndim = len(x.shape)
        if ndim == 4:
            return x
        if ndim == 3:
            return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        raise ValueError("Input must be a 3D or 4D tensor")

    def forward(self, x):
        if not self.out_4d:
            return self.body(x)
        h, w = x.shape[-2:]
        return self.to_4d(self.body(self.to_3d(x)), h, w)
107
+
108
class RepConv3(nn.Module):
    """Re-parameterizable 3x3 convolution (RepVGG / DBB style).

    In training mode the output is the sum of five parallel branches
    (3x3, 1x1, 1x3, 3x1 convs plus a bias-free 1x1 -> 3x3 sequence);
    `fuse()` algebraically merges them into the single `reparam` 3x3 conv
    used in deploy mode, so inference runs one conv instead of six.
    """
    def __init__(self, in_channels, out_channels, groups, deploy=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.deploy = deploy
        # Always built so state dicts saved after fuse() load cleanly.
        self.reparam = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
        if not deploy:
            self.conv_3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
            self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups)
            self.conv_1x3 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3), padding=(0, 1), groups=groups)
            self.conv_3x1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 1), padding=(1, 0), groups=groups)
            # Sequential branch: bias-free so its fusion contributes no bias term.
            self.conv_1x1_branch = nn.Conv2d(in_channels, in_channels, kernel_size=1, groups=groups, bias=False)
            self.conv_3x3_branch = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups, bias=False)
        else:
            self._delete_branches()

    def _delete_branches(self):
        # Drop training-time branches (no-op for any that do not exist).
        for name in ['conv_3x3','conv_1x1','conv_1x3','conv_3x1', 'conv_1x1_branch', 'conv_3x3_branch']:
            if hasattr(self, name):
                delattr(self, name)

    @torch.no_grad()  # consistency with the other fuse() methods; weight merging must not build an autograd graph
    def fuse(self, delete_branches=True):
        """Merge all branches into `self.reparam` and switch to deploy mode."""
        if self.deploy:
            return
        # Extract weights and biases
        conv_3x3_w, conv_3x3_b = self.conv_3x3.weight, self.conv_3x3.bias
        conv_1x1_w, conv_1x1_b = self.conv_1x1.weight, self.conv_1x1.bias
        conv_1x3_w, conv_1x3_b = self.conv_1x3.weight, self.conv_1x3.bias
        conv_3x1_w, conv_3x1_b = self.conv_3x1.weight, self.conv_3x1.bias
        conv_1x1_branch_w, conv_3x3_branch_w = self.conv_1x1_branch.weight, self.conv_3x3_branch.weight
        # Pad the smaller kernels to 3x3 (F.pad order: left, right, top, bottom)
        conv_1x1_w_pad = F.pad(conv_1x1_w, [1, 1, 1, 1])
        conv_1x3_w_pad = F.pad(conv_1x3_w, [0, 0, 1, 1])
        conv_3x1_w_pad = F.pad(conv_3x1_w, [1, 1, 0, 0])
        # Collapse the 1x1 -> 3x3 sequence: convolving the 3x3 weight with the
        # transposed 1x1 weight yields the equivalent single 3x3 kernel.
        if self.groups == 1:
            conv_1x1_3x3_w_pad = F.conv2d(conv_3x3_branch_w, conv_1x1_branch_w.permute(1, 0, 2, 3))
        else:
            # Grouped case: do the same merge per group and concatenate the
            # resulting output-channel slices.
            w_slices = []
            conv_1x1_branch_w_T = conv_1x1_branch_w.permute(1, 0, 2, 3)
            in_channels_per_group = self.in_channels // self.groups
            out_channels_per_group = self.out_channels // self.groups
            for g in range(self.groups):
                # Slice the transposed 1x1 weights for this group's channels
                conv_1x1_branch_w_T_slice = conv_1x1_branch_w_T[:, g*in_channels_per_group:(g+1)*in_channels_per_group, :, :]
                # Slice the 3x3 weights for this group's output channels
                conv_3x3_branch_w_slice = conv_3x3_branch_w[g*out_channels_per_group:(g+1)*out_channels_per_group, :, :, :]
                w_slices.append(F.conv2d(conv_3x3_branch_w_slice, conv_1x1_branch_w_T_slice))
            conv_1x1_3x3_w_pad = torch.cat(w_slices, dim=0)
        # Fuse weights and biases (the sequential branch is bias-free).
        conv_w = conv_3x3_w + conv_1x1_w_pad + conv_1x3_w_pad + conv_3x1_w_pad + conv_1x1_3x3_w_pad
        if conv_3x3_b is None:
            conv_3x3_b = torch.zeros(self.out_channels, device=conv_w.device)
        conv_b = conv_3x3_b + conv_1x1_b + conv_1x3_b + conv_3x1_b
        self.reparam.weight.data.copy_(conv_w)
        self.reparam.bias.data.copy_(conv_b)
        # Delete the original branches
        if delete_branches:
            self._delete_branches()
        # Set deploy flag
        self.deploy = True

    def forward(self, x):
        if self.deploy:
            return self.reparam(x)
        else:
            return self.conv_3x3(x) + self.conv_1x1(x) + self.conv_1x3(x) + self.conv_3x1(x) + self.conv_3x3_branch(self.conv_1x1_branch(x))
176
+
177
+ from monarch_attn import MonarchAttention
178
+
179
@dataclass
class RepAttnConfig:
    """Constructor kwargs for RepAttn (unpacked via asdict() in RepTransformerBlock)."""
    dim: int                 # channel count of the attention block
    num_heads: int = 8       # heads used to split channels in RepAttn.forward
    block_size: int = 16     # MonarchAttention block size
    num_steps: int = 2       # MonarchAttention iteration count
    pad_type: str = "pre"    # MonarchAttention padding side
    impl: str = "torch"      # MonarchAttention backend (presumably "torch"/"triton" — confirm)
    deploy: bool = False     # True -> RepAttn uses MonarchAttention from the start
188
+
189
class RepAttn(nn.Module):
    """ Re-parameterizable Attention Block using MonarchAttention as the core attention mechanism.

    Trains with exact scaled dot-product attention (`common_attn`); `fuse()`
    swaps the attention callable to the MonarchAttention module for deployment.
    Note: q/k/v are reshaped to (b, head, c, h*w), so attention is computed
    across the channel dimension with the spatial axis as the inner product
    dimension (cf. the `scale` below using q.shape[-1] == h*w).
    """
    def __init__(self, dim, num_heads=8, block_size=14, num_steps=1, pad_type="pre", impl="torch", deploy=False):
        # NOTE(review): defaults here (block_size=14, num_steps=1) differ from
        # RepAttnConfig's defaults; the config values win when used together.
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)  # joint q/k/v projection
        self.monarch_attn = MonarchAttention(
            block_size=block_size,
            num_steps=num_steps,
            pad_type=pad_type,
            impl=impl
        )
        # Select the attention implementation once; forward just calls attn_fn.
        if deploy:
            self.attn_fn = self.monarch_attn
        else:
            self.attn_fn = self.common_attn
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)  # output projection
        self.deploy = deploy

    def common_attn(self, q, k, v):
        """ Scaled Dot-Product Attention """
        scale = (q.shape[-1]) ** -0.5  # scales by 1/sqrt(h*w) given the layout above
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        return out

    @torch.no_grad()
    def fuse(self):
        # Deployment switch: replace exact attention with MonarchAttention.
        if not self.deploy:
            self.attn_fn = self.monarch_attn
            self.deploy = True

    def forward(self, x):
        # x: (B, C, H, W) -> out: (B, C, H, W)
        B, C, H, W = x.shape
        qkv = self.qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=1)
        # Split channels into heads; flatten spatial dims into the last axis.
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        attn_out = self.attn_fn(q, k, v)
        attn_out = rearrange(attn_out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=H, w=W)
        out = self.proj(attn_out)
        return out
233
+
234
@dataclass
class FFNConfig:
    """Constructor kwargs for RepFFN (unpacked via asdict() in RepTransformerBlock)."""
    dim: int                    # input/output channel count
    expansion_factor: int = 1   # hidden width = int(dim * expansion_factor)
    deploy: bool = False        # build the RepConv3 layers already fused
239
+
240
class RepFFN(nn.Module):
    """Gated feed-forward block built from re-parameterizable convs.

    `project_in` lifts the input to the hidden width, a depthwise RepConv3
    produces two stacked halves that form a GELU gate, and a 1x1 conv
    projects back down to `dim` channels.
    """
    def __init__(self, dim, expansion_factor=1, deploy=False):
        super().__init__()
        hidden_features = int(dim * expansion_factor)
        self.project_in = RepConv3(dim, hidden_features, groups=1, deploy=deploy)
        self.dwconv = RepConv3(hidden_features, hidden_features*2, groups=hidden_features, deploy=deploy)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1)

    @torch.no_grad()
    def fuse(self):
        """Re-parameterize both RepConv3 layers for deployment."""
        self.project_in.fuse()
        self.dwconv.fuse()

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
260
+
261
class SkipConnection(nn.Module):
    """Merge a decoder feature map with its encoder counterpart:
    channel-concatenation followed by a 1x1 conv back to `dim` channels."""
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim*2, dim, kernel_size=1)

    def forward(self, x1, x2):
        return self.conv(torch.cat([x1, x2], dim=1))
270
+
271
class RepTransformerBlock(nn.Module):
    """Pre-norm transformer block: x + attn(norm1(x)), then x + ffn(norm2(x)).

    Both sub-modules are re-parameterizable; fuse() collapses them for
    deployment without changing the block's interface.
    """
    def __init__(self, rep_attn_cfg: RepAttnConfig, ffn_cfg: FFNConfig, norm_type='WithBias'):
        super().__init__()
        self.rep_attn = RepAttn(**asdict(rep_attn_cfg))
        self.rep_ffn = RepFFN(**asdict(ffn_cfg))
        self.norm1 = LayerNorm(rep_attn_cfg.dim, norm_type)
        self.norm2 = LayerNorm(rep_attn_cfg.dim, norm_type)

    @torch.no_grad()
    def fuse(self):
        # Re-parameterize both sub-modules for inference.
        self.rep_attn.fuse()
        self.rep_ffn.fuse()

    def forward(self, x):
        # Pre-norm residual layout.
        x = x + self.rep_attn(self.norm1(x))
        x = x + self.rep_ffn(self.norm2(x))
        return x
288
+
289
class Block(nn.Module):
    """A stage of `num_block` RepTransformerBlocks applied sequentially."""
    def __init__(self, num_block, rep_attn_cfg: RepAttnConfig, ffn_cfg: FFNConfig, norm_type='WithBias'):
        super().__init__()
        self.num_block = num_block
        self.blocks = nn.ModuleList(
            RepTransformerBlock(rep_attn_cfg, ffn_cfg, norm_type)
            for _ in range(num_block)
        )

    @torch.no_grad()
    def fuse(self):
        """Re-parameterize every contained transformer block."""
        for blk in self.blocks:
            blk.fuse()

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x
306
+
307
class ColorComicNet(nn.Module):
    """ Main model implementation.

    U-Net-shaped colorizer: a stride-4 stem, an encoder of transformer
    stages separated by DownSample, a bottleneck stage, a mirrored decoder
    with UpSample + SkipConnection, then a head whose output is added to
    the input as a global residual and passed through `last_act`.

    NOTE(review): dims/num_blocks/num_heads use mutable list defaults —
    harmless here because they are never mutated, but worth tightening.
    """
    def __init__(self, input_shape=(3, 1024, 1024), output_channels=3, deploy=False, dims=[48, 96, 192, 384], num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 2, 4], bias=True, last_act=None):
        super().__init__()
        assert len(dims) == len(num_blocks) == len(num_heads), "Length of dims, num_blocks and num_heads must be the same"
        self.input_shape = input_shape
        self.output_channels = output_channels
        self.deploy = deploy
        self.dims = dims
        self.num_blocks = num_blocks
        self.bias = bias
        self.num_heads = num_heads

        # Extractor: stride-4 stem; forward() upsamples by 4 to compensate.
        self.stem = nn.Conv2d(input_shape[0], dims[0], kernel_size=7, stride=4, padding=3, bias=bias)

        # Encoder: one Block per scale; DownSample between scales.
        layers = []
        down_convs = []
        for idx in range(len(dims)):
            attn_cfg, ffn_cfg = self.build_cfg(dims[idx], num_heads[idx])
            block = Block(num_blocks[idx], attn_cfg, ffn_cfg, norm_type='WithBias')
            if idx < len(dims) - 1:
                down_convs.append(DownSample(dims[idx]))
            layers.append(block)
        self.bottleneck = layers[-1] # Last encoder layer as bottleneck
        self.encoder = nn.ModuleList(layers[:-1])
        self.downsample = nn.ModuleList(down_convs)

        # Decoder: mirrors the encoder from deep to shallow scales.
        layers = []
        up_convs = []
        skip_connections = []
        for idx in range(len(dims)-2, -1, -1):
            attn_cfg, ffn_cfg = self.build_cfg(dims[idx], num_heads[idx])
            # print(f"Decoder layer {idx}: shape {l_shape}")
            up_conv = UpSample(dims[idx+1])  # halves channels: dims[idx+1] -> dims[idx]
            block = Block(num_blocks[idx], attn_cfg, ffn_cfg, norm_type='WithBias')
            layers.append(block)
            up_convs.append(up_conv)
            skip_connections.append(SkipConnection(dims[idx]))
        self.decoder = nn.ModuleList(layers)
        self.up_sample = nn.ModuleList(up_convs)
        self.skip = nn.ModuleList(skip_connections)

        # Head: re-parameterizable conv + GELU + 1x1 projection to output channels.
        self.head = nn.Sequential(
            RepConv3(dims[0], dims[0]//2, 1, deploy=deploy),
            nn.GELU(),
            nn.Conv2d(dims[0]//2, output_channels, kernel_size=1, bias=bias),
        )
        self.last_act = last_act if last_act is not None else nn.Identity()

    @torch.no_grad()
    def fuse(self):
        """Re-parameterize every fusable sub-module for deployment."""
        for block in self.encoder:
            block.fuse()
        self.bottleneck.fuse()
        for block in self.decoder:
            block.fuse()
        for conv in self.head:
            if isinstance(conv, RepConv3):
                conv.fuse()

    def build_cfg(self, dim, head):
        """Build the (attention, FFN) config pair for one scale."""
        # RepAttn config
        attn_cfg = RepAttnConfig(
            dim=dim,
            num_heads=head,
            block_size=12,
            num_steps=2,
            pad_type="pre",
            impl="torch",
            deploy=self.deploy
        )
        ## FFN config
        ffn_cfg = FFNConfig(
            dim=dim,
            expansion_factor=1,
        )
        return attn_cfg, ffn_cfg

    def forward(self, x):
        """
        x: (B, C, H, W)
        """
        res = x  # global residual added back after the head
        x = self.stem(x)
        feats = []
        # Encoder path: collect pre-downsample features for the skips.
        for blk, down in zip(self.encoder, self.downsample):
            x = blk(x)
            feats.append(x)
            x = down(x)
        x = self.bottleneck(x)
        # Decoder path: upsample, merge the matching encoder feature, refine.
        for blk, up, skip in zip(self.decoder, self.up_sample, self.skip):
            x = up(x)
            cur_feat = feats.pop()  # deepest remaining encoder feature
            x = skip(x, cur_feat)
            x = blk(x)
        # Undo the stem's stride-4 downsampling.
        x = F.interpolate(x, scale_factor=4, mode='bilinear')
        x = self.head(x) + res
        x = self.last_act(x)
        return x
410
+
411
# Example model configuration (constructor kwargs for ColorComicNet;
# output_channels is left at its default of 3).
MODEL_CFG = {
    'input_shape': (3, 512, 512),   # (C, H, W); only C feeds the stem conv
    'dims': [24, 48, 96, 192],      # channel widths per scale
    'num_blocks': [1, 2, 2, 4],     # transformer blocks per scale
    'num_heads': [1, 2, 4, 8],      # attention heads per scale
    'bias': True,                   # bias in stem/head convs
    'last_act': nn.Tanh(),          # final activation -> outputs in [-1, 1]
    'deploy': False                 # build with training branches; fuse() later
}
monarch_attn/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .monarch_attention import MonarchAttention
monarch_attn/ma_history.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import sqrt
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+
7
+ Tensor = torch.Tensor
8
+
9
+
10
def monarch_matrix(L: Tensor, R: Tensor) -> Tensor:
    """Materialize the dense matrix represented by Monarch factors.

    L has dims (j, k, l) and R has dims (k, j, i); the elementwise product
    out[l, j, k, i] = L[j, k, l] * R[k, j, i] is flattened to a
    (l*j) x (k*i) matrix.
    """
    out = torch.einsum("jkl,kji->ljki", L, R)
    d_l, d_j, d_k, d_i = out.shape
    return out.reshape(d_l * d_j, d_k * d_i)
13
+
14
+
15
def monarch_attention_history(q: Tensor, k: Tensor, T: int, B: int) -> list[Tensor]:
    """Run T alternating-maximization steps for the Monarch factors (L, R)
    of the attention matrix and return the dense monarch_matrix(L, R)
    snapshot recorded at each step.

    q, k: (N, D) with N divisible by B (block size); M = N // B blocks.
    """
    N, D = q.shape
    M = N // B

    q = q / sqrt(D)  # standard attention scaling folded into q

    # Block q by position-within-block (j) and k by block index (k).
    qb = rearrange(q, "(l j) v -> j l v", j=B)
    kb = rearrange(k, "(k i) v -> k i v", i=B)

    # Initialize L as B stacked identity matrices.
    L = torch.stack(B * [torch.eye(M, device=q.device)])

    history = []

    # Alternating maximization for L, R
    for t in range(T):
        # R update: softmax over key positions, normalized by L's column sums.
        aR = torch.einsum("jkl,jlv->kjv", L, qb)
        bR = torch.einsum("kjv,kiv->kji", aR, kb)
        cR = torch.einsum("jkl->kj", L)
        R = F.softmax(bR / cR[:, :, None], dim=2)

        history.append(monarch_matrix(L, R))

        # L update: softmax over blocks with an entropy correction from R.
        aL = torch.einsum("kji,kiv->jkv", R, kb)
        bL = torch.einsum("jkv,jlv->jkl", aL, qb)
        cL = torch.einsum("kji->jk", R * torch.log(R))  # sum of R*log(R) over i
        L = F.softmax(bL - cL[:, :, None], dim=1)

    return history
monarch_attn/ma_torch.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import sqrt
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ Tensor = torch.Tensor
7
+ xlogy = torch.special.xlogy
8
+
9
+
10
def al_cl_ref(ar, k, cr, sm_scale, mask, eps=1e-12):
    """One right-factor (R) update of Monarch attention (reference impl).

    Builds a row-normalized factor r from the query-side accumulator `ar`,
    the per-row normalizer `cr` and keys `k`, then returns:
      al: sm_scale * (r @ k) with the two block dims swapped (next L-update input)
      cl: per-row entropy term sum(r * log r), consumed by the next softmax
    `mask` marks valid (non-padded) key positions; invalid logits get -inf.
    """
    r_hat = sm_scale * (ar @ k.transpose(-1, -2)).to(torch.float)
    r_hat = r_hat / (cr[..., :, None] + eps)
    r_hat = r_hat + torch.where(mask[..., None, :], 0.0, -float("inf"))
    # Numerically-stable exp-normalize: subtract the (clamped) row max,
    # and guard the denominator with eps against all-masked rows.
    r_hat = torch.exp(
        r_hat - torch.clamp(torch.max(r_hat, dim=-1, keepdim=True).values, min=eps)
    )
    r = r_hat / (torch.sum(r_hat, dim=-1, keepdim=True) + eps)
    r = torch.clamp(r, min=torch.finfo(r.dtype).tiny)  # keep xlogy finite below

    cl = torch.sum(xlogy(r, r), dim=-1).transpose(-1, -2)
    al = sm_scale * (r.to(k.dtype) @ k).transpose(-2, -3)

    return al, cl
24
+
25
+
26
def ar_cr_ref(al, q, cl, mask_t):
    """One left-factor (L) update of Monarch attention (reference impl).

    Builds the column-softmax factor l from `al` and queries `q`, corrected
    by the entropy term `cl` from the preceding R step, zeroing padded query
    positions via `mask_t`. Returns:
      ar: (l @ q) with block dims swapped (input to the next R update)
      cr: row sums of l (normalizer for the next R update)
    """
    l_hat = (al @ q.transpose(-1, -2)).to(torch.float)
    l_hat = l_hat - cl[..., :, None]  # entropy correction from the R step
    l = F.softmax(l_hat, dim=-2)      # softmax over the block axis
    l = mask_t[..., None, :] * l      # zero out padded query positions

    cr = torch.sum(l, dim=-1).transpose(-1, -2)
    ar = (l.to(q.dtype) @ q).transpose(-2, -3)

    return ar, cr
36
+
37
+
38
def al_y_cl_ref(ar, k, v, cr, sm_scale, mask, eps=1e-12):
    """Final R update, fused with value aggregation.

    Same computation as al_cl_ref, but additionally returns
    y = r @ v — the value-side partial result consumed by z_ref.
    """
    r_hat = sm_scale * (ar @ k.transpose(-1, -2)).to(torch.float)
    r_hat = r_hat / (cr[..., :, None] + eps)
    r_hat = r_hat + torch.where(mask[..., None, :], 0.0, -float("inf"))
    # Stable exp-normalize with eps-guarded max and denominator (see al_cl_ref).
    r_hat = torch.exp(
        r_hat - torch.clamp(torch.max(r_hat, dim=-1, keepdim=True).values, min=eps)
    )
    r = r_hat / (torch.sum(r_hat, dim=-1, keepdim=True) + eps)
    r = torch.clamp(r, min=torch.finfo(r.dtype).tiny)  # keep xlogy finite

    cl = torch.sum(xlogy(r, r), dim=-1).transpose(-1, -2)
    al = sm_scale * (r.to(k.dtype) @ k).transpose(-2, -3)
    y = (r.to(v.dtype) @ v).transpose(-2, -3)

    return al, y, cl
53
+
54
+
55
def z_ref(al, q, cl, y):
    """Final Monarch-attention output.

    Forms the left attention factor from queries `q` and `al` (logits
    corrected by the entropy term `cl`, softmaxed over the last dim) and
    applies it to the value-side partial result `y`; the two block dims
    are swapped back before returning a contiguous tensor.
    """
    logits = (q @ al.transpose(-1, -2)).to(torch.float) - cl[..., None, :]
    weights = F.softmax(logits, dim=-1)
    return (weights.to(y.dtype) @ y).transpose(-2, -3).contiguous()
63
+
64
+
65
def monarch_attention_torch(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Tensor | None,
    T: int,
    B: int,
    pre_pad: bool,
) -> Tensor:
    """Monarch attention, pure-PyTorch reference implementation.

    q, k: (E, H, N, D); v: (E, H, N, Dv). The sequence is padded to a
    multiple of the block size B (on the left if `pre_pad`, else right),
    reshaped into M blocks of B tokens, run through T alternating
    factor updates, and the padding is sliced off the result.
    attn_mask, when given, marks valid tokens per (E, N) position.
    """
    E, H, N, D = q.shape
    _, _, _, Dv = v.shape
    M = (N + B - 1) // B   # number of blocks after padding
    N_padded = M * B

    sm_scale = 1 / sqrt(D)

    # Pad the sequence dim to N_padded on the chosen side.
    pad_t = (N_padded - N, 0) if pre_pad else (0, N_padded - N)
    pad_t_2d = (0, 0) + pad_t

    q = F.pad(q, pad_t_2d).view(E, H, M, B, D)
    k = F.pad(k, pad_t_2d).view(E, H, M, B, D)
    v = F.pad(v, pad_t_2d).view(E, H, M, B, Dv)

    # Initial accumulator/normalizer for the first R update.
    ar = q
    cr = torch.ones(E, H, M, B, device=q.device, dtype=torch.float)
    q = q.transpose(-2, -3)  # swap block dims for the L updates

    # Validity mask over the (M, B) token grid: padding positions are False.
    pad_offset = N_padded - N if pre_pad else 0
    range_n = torch.arange(M * B).view(M, B).to(q.device)
    mask = range_n >= pad_offset if pre_pad else range_n < N

    if attn_mask is not None:
        attn_mask = F.pad(attn_mask, pad_t).view(E, 1, M, B)
        mask = torch.logical_and(mask, attn_mask)

    # T-1 full alternating (R, L) updates...
    for _ in range(T - 1):
        al, cl = al_cl_ref(ar, k, cr, sm_scale, mask)
        ar, cr = ar_cr_ref(al, q, cl, mask.mT)

    # ...then a final R update fused with value aggregation, and the output.
    al, y, cl = al_y_cl_ref(ar, k, v, cr, sm_scale, mask)
    z = z_ref(al, q, cl, y)
    z = z.view(E, H, N_padded, Dv)

    # Strip the padding from whichever side it was added.
    return z[..., N_padded - N :, :] if pre_pad else z[..., :N, :]
monarch_attn/ma_triton.py ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import sqrt
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ Tensor = torch.Tensor
9
+
10
+
11
@triton.jit
def xlogx(x):
    # x * log(x) with the 0 * log(0) := 0 convention (entropy terms).
    return tl.where(x == 0, 0.0, x * tl.log(x))
14
+
15
+
16
@triton.jit
def _al_cl_kernel(
    # ar: (E, H, M, B, D) query-side accumulator
    ar_ptr,
    stride_ar_e,
    stride_ar_h,
    stride_ar_m,
    stride_ar_b,
    stride_ar_d,
    # k: keys, blocked layout
    k_ptr,
    stride_k_e,
    stride_k_h,
    stride_k_m,
    stride_k_b,
    stride_k_d,
    # cr: (E, H, M, B) per-row normalizer
    cr_ptr,
    stride_cr_e,
    stride_cr_h,
    stride_cr_m,
    stride_cr_b,
    # al: output accumulator for the following L update
    al_ptr,
    stride_al_e,
    stride_al_h,
    stride_al_m,
    stride_al_b,
    stride_al_d,
    # cl: output per-row entropy term
    cl_ptr,
    stride_cl_e,
    stride_cl_h,
    stride_cl_m,
    stride_cl_b,
    # optional validity mask over tokens
    mask_ptr,
    stride_mask_e,
    stride_mask_m,
    stride_mask_b,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,  # unpadded sequence length
    is_first_call: int,  # 1 when ar is still the (padded-layout) q tensor
    sm_scale: float,
    HAS_ATTN_MASK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
    EPS: tl.constexpr,
):
    # One program per (e, h, m): a single B x B right-factor tile.
    idx_ehm = tl.program_id(0)
    idx_eh = idx_ehm // M
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_m = idx_ehm % M

    # With PRE_PAD, the first (M*B - N) token slots are padding.
    pad_offset = M * B - N if PRE_PAD else 0

    range_b = tl.arange(0, BLOCK_B)
    range_d = tl.arange(0, BLOCK_D)
    range_n = B * idx_m + range_b  # absolute token index of each lane

    mask_b = range_b < B
    # pad_mask_b additionally excludes padding token slots.
    pad_mask_b = mask_b & ((range_n >= pad_offset) if PRE_PAD else (range_n < N))
    k_mask_b = pad_mask_b
    mask_d = range_d < D

    if HAS_ATTN_MASK:
        # Fold the user-provided token mask into the key-validity mask.
        mask_block_ptr = (
            mask_ptr
            + stride_mask_e * idx_e
            + stride_mask_m * idx_m
            + stride_mask_b * (range_b - pad_offset)
        )
        valid_token_mask = tl.load(
            mask_block_ptr,
            mask=pad_mask_b,
            other=0,
        )
        k_mask_b = pad_mask_b & valid_token_mask

    # Load ar (on the first call ar aliases the unpadded q, hence the offset shift)
    ar_block_ptr = (
        ar_ptr
        + stride_ar_e * idx_e
        + stride_ar_h * idx_h
        + stride_ar_m * idx_m
        + (
            stride_ar_b * (range_b - (pad_offset if is_first_call else 0))[:, None]
            + stride_ar_d * range_d[None, :]
        )
    )
    ar = tl.load(
        ar_block_ptr,
        mask=(pad_mask_b if is_first_call else mask_b)[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load k
    k_block_ptr = (
        k_ptr
        + stride_k_e * idx_e
        + stride_k_h * idx_h
        + stride_k_m * idx_m
        + (stride_k_b * (range_b - pad_offset)[:, None] + stride_k_d * range_d[None, :])
    )
    k = tl.load(
        k_block_ptr,
        mask=k_mask_b[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cr
    cr_block_ptr = (
        cr_ptr
        + stride_cr_e * idx_e
        + stride_cr_h * idx_h
        + stride_cr_m * idx_m
        + (stride_cr_b * range_b)
    )
    cr = tl.load(cr_block_ptr, mask=mask_b, other=1.0)

    # Attention matrix: stable exp-normalize, mirroring al_cl_ref in ma_torch.py
    r = sm_scale * tl.dot(ar, tl.trans(k))
    r = r / (cr[:, None] + EPS)
    r = r + tl.where(k_mask_b[None, :], 0.0, float("-inf"))
    r = tl.exp(r - tl.clamp(tl.max(r, axis=1, keep_dims=True), EPS, float("inf")))
    r = r / (tl.sum(r, axis=1, keep_dims=True) + EPS)

    # Store cl (per-row entropy term sum(r * log r))
    cl = tl.sum(xlogx(r), axis=1)
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_m * idx_m
        + (stride_cl_b * range_b)
    )
    tl.store(cl_block_ptr, cl, mask=mask_b)

    # Store al = sm_scale * (r @ k)
    al = (sm_scale * tl.dot(r.to(k.dtype), k)).to(ar.dtype)
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_m * idx_m
        + (stride_al_b * range_b[:, None] + stride_al_d * range_d[None, :])
    )
    tl.store(
        al_block_ptr,
        al,
        mask=mask_b[:, None] & mask_d[None, :],
    )
167
+
168
+
169
@triton.jit
def _ar_cr_kernel(
    # al: accumulator produced by _al_cl_kernel
    al_ptr,
    stride_al_e,
    stride_al_h,
    stride_al_m,
    stride_al_b,
    stride_al_d,
    # q: queries (transposed block layout)
    q_ptr,
    stride_q_e,
    stride_q_h,
    stride_q_m,
    stride_q_b,
    stride_q_d,
    # cl: entropy term from the preceding R update
    cl_ptr,
    stride_cl_e,
    stride_cl_h,
    stride_cl_m,
    stride_cl_b,
    # ar: output accumulator for the next R update
    ar_ptr,
    stride_ar_e,
    stride_ar_h,
    stride_ar_m,
    stride_ar_b,
    stride_ar_d,
    # cr: output normalizer for the next R update
    cr_ptr,
    stride_cr_e,
    stride_cr_h,
    stride_cr_m,
    stride_cr_b,
    # optional validity mask over tokens
    mask_ptr,
    stride_mask_e,
    stride_mask_m,
    stride_mask_b,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,  # unpadded sequence length
    HAS_ATTN_MASK: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
):
    # One program per (e, h, b): a single M x M left-factor tile.
    idx_ehb = tl.program_id(0)
    idx_eh = idx_ehb // B
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_b = idx_ehb % B

    pad_offset = M * B - N if PRE_PAD else 0

    range_m = tl.arange(0, BLOCK_M)
    range_d = tl.arange(0, BLOCK_D)
    # Absolute token index covered by this program: position b of every block.
    range_n = idx_b + B * range_m

    mask_m = range_m < M
    q_mask_m = mask_m & (range_n >= pad_offset if PRE_PAD else range_n < N)
    mask_d = range_d < D

    if HAS_ATTN_MASK:
        # Fold the user-provided token mask into the query-validity mask.
        mask_block_ptr = (
            mask_ptr
            + stride_mask_e * idx_e
            + stride_mask_b * (idx_b - pad_offset)
            + stride_mask_m * range_m
        )
        valid_token_mask = tl.load(
            mask_block_ptr,
            mask=q_mask_m,
            other=0,
        )
        q_mask_m = q_mask_m & valid_token_mask

    # Load al
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_b * idx_b
        + (stride_al_m * range_m[:, None] + stride_al_d * range_d[None, :])
    )
    al = tl.load(
        al_block_ptr,
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load q (offset shifted by pad_offset: q keeps its unpadded layout here)
    q_block_ptr = (
        q_ptr
        + stride_q_e * idx_e
        + stride_q_h * idx_h
        + stride_q_b * (idx_b - pad_offset)
        + (stride_q_m * range_m[:, None] + stride_q_d * range_d[None, :])
    )
    q = tl.load(
        q_block_ptr,
        mask=q_mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cl
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_b * idx_b
        + (stride_cl_m * range_m)
    )
    cl = tl.load(cl_block_ptr, mask=mask_m, other=0.0)

    # Attention matrix: entropy-corrected softmax over the block axis (axis 0),
    # then padded/invalid query columns are zeroed — mirrors ar_cr_ref.
    l = tl.dot(al, tl.trans(q))
    l = l - cl[:, None]
    l = l + tl.where(mask_m[:, None], 0.0, float("-inf"))
    l = tl.exp(l - tl.max(l, axis=0, keep_dims=True))
    l = l / tl.sum(l, axis=0, keep_dims=True)
    l = q_mask_m[None, :] * l

    # Store cr (row sums of l)
    cr = tl.sum(l, axis=1)
    cr_block_ptr = (
        cr_ptr
        + stride_cr_e * idx_e
        + stride_cr_h * idx_h
        + stride_cr_b * idx_b
        + (stride_cr_m * range_m)
    )
    tl.store(cr_block_ptr, cr, mask=mask_m)

    # Store ar = l @ q
    ar = tl.dot(l.to(q.dtype), q).to(al.dtype)
    ar_block_ptr = (
        ar_ptr
        + stride_ar_e * idx_e
        + stride_ar_h * idx_h
        + stride_ar_b * idx_b
        + (stride_ar_m * range_m[:, None] + stride_ar_d * range_d[None, :])
    )
    tl.store(
        ar_block_ptr,
        ar,
        mask=mask_m[:, None] & mask_d[None, :],
    )
314
+
315
+
316
@triton.jit
def _al_y_cl_kernel(
    ar_ptr,
    stride_ar_e,
    stride_ar_h,
    stride_ar_m,
    stride_ar_b,
    stride_ar_d,
    k_ptr,
    stride_k_e,
    stride_k_h,
    stride_k_m,
    stride_k_b,
    stride_k_d,
    v_ptr,
    stride_v_e,
    stride_v_h,
    stride_v_m,
    stride_v_b,
    stride_v_d,
    cr_ptr,
    stride_cr_e,
    stride_cr_h,
    stride_cr_m,
    stride_cr_b,
    al_ptr,
    stride_al_e,
    stride_al_h,
    stride_al_m,
    stride_al_b,
    stride_al_d,
    y_ptr,
    stride_y_e,
    stride_y_h,
    stride_y_m,
    stride_y_b,
    stride_y_d,
    cl_ptr,
    stride_cl_e,
    stride_cl_h,
    stride_cl_m,
    stride_cl_b,
    mask_ptr,
    stride_mask_e,
    stride_mask_m,
    stride_mask_b,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,
    is_first_call: int,
    sm_scale: float,
    HAS_ATTN_MASK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
    EPS: tl.constexpr,
):
    # Final left-factor update fused with the value product: one program per
    # (e, h, m) recomputes al and cl and additionally produces y = r @ v,
    # which _z_kernel later combines into the final output.
    # When is_first_call (i.e. T == 1) the "ar" buffer is really q in its
    # raw sequence layout, hence the pad-offset adjustment guarded below.
    idx_ehm = tl.program_id(0)
    idx_eh = idx_ehm // M
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_m = idx_ehm % M

    # With PRE_PAD, the first (M*B - N) virtual positions are padding.
    pad_offset = M * B - N if PRE_PAD else 0

    range_b = tl.arange(0, BLOCK_B)
    range_d = tl.arange(0, BLOCK_D)
    # Flat sequence positions inside block row m.
    range_n = B * idx_m + range_b

    mask_b = range_b < B
    pad_mask_b = mask_b & ((range_n >= pad_offset) if PRE_PAD else range_n < N)
    k_mask_b = pad_mask_b
    mask_d = range_d < D

    if HAS_ATTN_MASK:
        # Fold the user-supplied attention mask into the key-validity mask.
        mask_block_ptr = (
            mask_ptr
            + stride_mask_e * idx_e
            + stride_mask_m * idx_m
            + stride_mask_b * (range_b - pad_offset)
        )
        valid_token_mask = tl.load(
            mask_block_ptr,
            mask=pad_mask_b,
            other=0,
        )
        k_mask_b = pad_mask_b & valid_token_mask

    # Load ar (= q on the first call, so shift by pad_offset only then)
    ar_block_ptr = (
        ar_ptr
        + stride_ar_e * idx_e
        + stride_ar_h * idx_h
        + stride_ar_m * idx_m
        + (
            stride_ar_b * (range_b - (pad_offset if is_first_call else 0))[:, None]
            + stride_ar_d * range_d[None, :]
        )
    )
    ar = tl.load(
        ar_block_ptr,
        mask=(pad_mask_b if is_first_call else mask_b)[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load k
    k_block_ptr = (
        k_ptr
        + stride_k_e * idx_e
        + stride_k_h * idx_h
        + stride_k_m * idx_m
        + (stride_k_b * (range_b - pad_offset)[:, None] + stride_k_d * range_d[None, :])
    )
    k = tl.load(
        k_block_ptr,
        mask=k_mask_b[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cr (normalizer from the right update; 1.0 where out of range)
    cr_block_ptr = (
        cr_ptr
        + stride_cr_e * idx_e
        + stride_cr_h * idx_h
        + stride_cr_m * idx_m
        + (stride_cr_b * range_b)
    )
    cr = tl.load(cr_block_ptr, mask=mask_b, other=1.0)

    # Attention matrix: scaled ar–k scores divided by cr, row softmax over
    # keys with invalid keys sent to -inf; EPS guards both divisions.
    r = sm_scale * tl.dot(ar, tl.trans(k))
    r = r / (cr[:, None] + EPS)
    r = r + tl.where(k_mask_b[None, :], 0.0, float("-inf"))
    r = tl.exp(r - tl.clamp(tl.max(r, axis=1, keep_dims=True), EPS, float("inf")))
    r = r / (tl.sum(r, axis=1, keep_dims=True) + EPS)

    # Store cl (row sums of xlogx(r); xlogx is a helper defined elsewhere
    # in this module)
    cl = tl.sum(xlogx(r), axis=1)
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_m * idx_m
        + (stride_cl_b * range_b)
    )
    tl.store(cl_block_ptr, cl, mask=mask_b)

    # Store al (attention-weighted keys, rescaled by sm_scale)
    al = (sm_scale * tl.dot(r.to(k.dtype), k)).to(ar.dtype)
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_m * idx_m
        + (stride_al_b * range_b[:, None] + stride_al_d * range_d[None, :])
    )
    tl.store(
        al_block_ptr,
        al,
        mask=mask_b[:, None] & mask_d[None, :],
    )

    # Load v
    v_block_ptr = (
        v_ptr
        + stride_v_e * idx_e
        + stride_v_h * idx_h
        + stride_v_m * idx_m
        + (stride_v_b * (range_b - pad_offset)[:, None] + stride_v_d * range_d[None, :])
    )
    v = tl.load(
        v_block_ptr,
        mask=k_mask_b[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Store y (per-block attention output, consumed by _z_kernel)
    y = tl.dot(r.to(v.dtype), v).to(ar.dtype)
    y_block_ptr = (
        y_ptr
        + stride_y_e * idx_e
        + stride_y_h * idx_h
        + stride_y_m * idx_m
        + (stride_y_b * range_b[:, None] + stride_y_d * range_d[None, :])
    )
    tl.store(
        y_block_ptr,
        y,
        mask=mask_b[:, None] & mask_d[None, :],
    )
508
+
509
+
510
@triton.jit
def _z_kernel(
    al_ptr,
    stride_al_e,
    stride_al_h,
    stride_al_m,
    stride_al_b,
    stride_al_d,
    q_ptr,
    stride_q_e,
    stride_q_h,
    stride_q_m,
    stride_q_b,
    stride_q_d,
    y_ptr,
    stride_y_e,
    stride_y_h,
    stride_y_m,
    stride_y_b,
    stride_y_d,
    cl_ptr,
    stride_cl_e,
    stride_cl_h,
    stride_cl_m,
    stride_cl_b,
    z_ptr,
    stride_z_e,
    stride_z_h,
    stride_z_m,
    stride_z_b,
    stride_z_d,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,
    BLOCK_M: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
):
    # Output pass: one program per (e, h, b) forms the left attention from
    # al/q/cl and applies it to the intermediate y, writing the result z
    # back in the original (pad-adjusted) sequence layout.
    idx_ehb = tl.program_id(0)
    idx_eh = idx_ehb // B
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_b = idx_ehb % B

    # With PRE_PAD, the first (M*B - N) virtual positions are padding.
    pad_offset = M * B - N if PRE_PAD else 0

    range_m = tl.arange(0, BLOCK_M)
    range_d = tl.arange(0, BLOCK_D)
    # Flat sequence positions handled by this program (slot b of every block).
    range_n = idx_b + B * range_m

    mask_m = range_m < M
    # Rows that correspond to real (non-padding) tokens.
    q_mask_m = mask_m & (range_n >= pad_offset if PRE_PAD else range_n < N)
    mask_d = range_d < D

    # Load al
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_b * idx_b
        + (stride_al_m * range_m[:, None] + stride_al_d * range_d[None, :])
    )
    al = tl.load(
        al_block_ptr,
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load q (sequence index shifted by pad_offset when pre-padding)
    q_block_ptr = (
        q_ptr
        + stride_q_e * idx_e
        + stride_q_h * idx_h
        + stride_q_b * (idx_b - pad_offset)
        + (stride_q_m * range_m[:, None] + stride_q_d * range_d[None, :])
    )
    q = tl.load(
        q_block_ptr,
        mask=q_mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cl
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_b * idx_b
        + (stride_cl_m * range_m)
    )
    cl = tl.load(cl_block_ptr, mask=mask_m, other=0.0)

    # Attention matrix: q–al scores corrected by cl, row softmax over block
    # rows (stabilized by subtracting the row max).
    l = tl.dot(q, tl.trans(al))
    l = l - cl[None, :]
    l = l + tl.where(mask_m[None, :], 0.0, float("-inf"))
    l = tl.exp(l - tl.max(l, axis=1, keep_dims=True))
    l = l / tl.sum(l, axis=1, keep_dims=True)

    # Load y
    y_block_ptr = (
        y_ptr
        + stride_y_e * idx_e
        + stride_y_h * idx_h
        + stride_y_b * idx_b
        + (stride_y_m * range_m[:, None] + stride_y_d * range_d[None, :])
    )
    y = tl.load(
        y_block_ptr,
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Store z (only real token rows are written)
    z = tl.dot(l.to(y.dtype), y).to(al.dtype)
    z_block_ptr = (
        z_ptr
        + stride_z_e * idx_e
        + stride_z_h * idx_h
        + stride_z_b * (idx_b - pad_offset)
        + (stride_z_m * range_m[:, None] + stride_z_d * range_d[None, :])
    )
    tl.store(
        z_block_ptr,
        z,
        mask=q_mask_m[:, None] & mask_d[None, :],
    )
639
+
640
+
641
def monarch_attention_triton(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Tensor | None,
    T: int,
    B: int,
    pre_pad: bool,
    eps: float = 0.0,
) -> Tensor:
    """Monarch attention computed with the Triton kernels above.

    Args:
        q, k, v: (E, H, N, D) query/key/value tensors.
        attn_mask: optional token-validity mask, or None to disable masking.
        T: number of alternating factor-update steps.
        B: Monarch block size; the sequence is viewed as (M, B) blocks with
            M = ceil(N / B).
        pre_pad: if True, the virtual padding precedes the real tokens.
        eps: numerical-stability constant forwarded to the kernels as EPS.

    Returns:
        (E, H, N, D) attention output tensor z.
    """
    E, H, N, D = q.shape
    M = triton.cdiv(N, B)

    HMBDN = (H, M, B, D, N)

    # One program per (e, h, m) or per (e, h, b), depending on the kernel.
    grid_ehm = (E * H * M,)
    grid_ehb = (E * H * B,)

    # Triton block sizes: next power of two, floored at 16 (tl.dot minimum).
    BLOCK_B = max(triton.next_power_of_2(B), 16)
    BLOCK_M = max(triton.next_power_of_2(M), 16)
    BLOCK_D = max(triton.next_power_of_2(D), 16)

    sm_scale = 1 / sqrt(D)

    # View the (N,) sequence axis as (M, B) purely through strides: the
    # per-block-row stride is B tokens, so no reshape/copy is needed.
    q_strides = (q.stride(0), q.stride(1), B * q.stride(2), q.stride(2), q.stride(3))
    k_strides = (k.stride(0), k.stride(1), B * k.stride(2), k.stride(2), k.stride(3))
    v_strides = (v.stride(0), v.stride(1), B * v.stride(2), v.stride(2), v.stride(3))

    # Right and left Monarch factors.
    ar = torch.empty(E, H, M, B, D, device=q.device, dtype=q.dtype)
    al = torch.empty_like(ar)

    ar_strides = (ar.stride(0), ar.stride(1), ar.stride(2), ar.stride(3), ar.stride(4))
    al_strides = (al.stride(0), al.stride(1), al.stride(2), al.stride(3), al.stride(4))

    # Normalizers kept in float32; cr starts at 1 so the first left update
    # divides by one.
    cr = torch.ones(E, H, M, B, device=q.device, dtype=torch.float)
    cl = torch.empty_like(cr)

    cr_strides = (cr.stride(0), cr.stride(1), cr.stride(2), cr.stride(3))
    cl_strides = (cl.stride(0), cl.stride(1), cl.stride(2), cl.stride(3))

    # Mask strides use the same stride trick; zeros when there is no mask.
    attn_mask_strides = (
        (attn_mask.stride(0), B * attn_mask.stride(1), attn_mask.stride(1))
        if attn_mask is not None
        else (0, 0, 0)
    )

    # T - 1 alternating updates of the left (al, cl) and right (ar, cr)
    # factors; on the very first step the right factor is the raw queries.
    for t in range(T - 1):
        is_first_call = t == 0
        _ar = q if is_first_call else ar
        _ar_strides = q_strides if is_first_call else ar_strides
        _al_cl_kernel[grid_ehm](
            _ar,
            *_ar_strides,
            k,
            *k_strides,
            cr,
            *cr_strides,
            al,
            *al_strides,
            cl,
            *cl_strides,
            attn_mask,
            *attn_mask_strides,
            *HMBDN,
            is_first_call=is_first_call,
            sm_scale=sm_scale,
            HAS_ATTN_MASK=attn_mask is not None,  # type: ignore
            BLOCK_B=BLOCK_B,  # type: ignore
            BLOCK_D=BLOCK_D,  # type: ignore
            PRE_PAD=pre_pad,  # type: ignore
            EPS=eps,  # type: ignore
        )

        _ar_cr_kernel[grid_ehb](
            al,
            *al_strides,
            q,
            *q_strides,
            cl,
            *cl_strides,
            ar,
            *ar_strides,
            cr,
            *cr_strides,
            attn_mask,
            *attn_mask_strides,
            *HMBDN,
            HAS_ATTN_MASK=attn_mask is not None,  # type: ignore
            BLOCK_M=BLOCK_M,  # type: ignore
            BLOCK_D=BLOCK_D,  # type: ignore
            PRE_PAD=pre_pad,  # type: ignore
        )

    # Final left update fused with the value product.
    y = torch.empty_like(al)
    y_strides = (y.stride(0), y.stride(1), y.stride(2), y.stride(3), y.stride(4))

    # With T == 1 the loop above never ran, so ar still has to be read as q.
    is_first_call_y = T == 1
    _ar_y = q if is_first_call_y else ar
    _ar_y_strides = q_strides if is_first_call_y else ar_strides

    _al_y_cl_kernel[grid_ehm](
        _ar_y,
        *_ar_y_strides,
        k,
        *k_strides,
        v,
        *v_strides,
        cr,
        *cr_strides,
        al,
        *al_strides,
        y,
        *y_strides,
        cl,
        *cl_strides,
        attn_mask,
        *attn_mask_strides,
        *HMBDN,
        is_first_call=is_first_call_y,
        sm_scale=sm_scale,
        HAS_ATTN_MASK=attn_mask is not None,  # type: ignore
        BLOCK_B=BLOCK_B,  # type: ignore
        BLOCK_D=BLOCK_D,  # type: ignore
        PRE_PAD=pre_pad,  # type: ignore
        EPS=eps,  # type: ignore
    )

    # Output pass: combine the left attention with y into z (same layout as v).
    z = torch.empty_like(v)
    z_strides = (z.stride(0), z.stride(1), B * z.stride(2), z.stride(2), z.stride(3))

    _z_kernel[grid_ehb](
        al,
        *al_strides,
        q,
        *q_strides,
        y,
        *y_strides,
        cl,
        *cl_strides,
        z,
        *z_strides,
        *HMBDN,
        BLOCK_M=BLOCK_M,  # type: ignore
        BLOCK_D=BLOCK_D,  # type: ignore
        PRE_PAD=pre_pad,  # type: ignore
    )

    return z
monarch_attn/monarch_attention.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ from enum import StrEnum
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .ma_torch import monarch_attention_torch
8
+
9
Tensor = torch.Tensor

# Signature shared by every backend:
# (q, k, v, attn_mask, num_steps T, block_size B, pre_pad) -> output
MonarchAttentionFn = Callable[
    [Tensor, Tensor, Tensor, Tensor | None, int, int, bool], Tensor
]

# Backend registry keyed by implementation name ("torch", "triton", ...).
_IMPLEMENTATIONS: dict[str, MonarchAttentionFn] = {}
16
+
17
+
18
+ def register_impl(name: str, fn: MonarchAttentionFn) -> None:
19
+ _IMPLEMENTATIONS[name] = fn
20
+
21
+
22
register_impl("torch", monarch_attention_torch)

# The Triton backend is optional: register it only when its module (and the
# triton dependency it pulls in) imports cleanly.
try:
    from .ma_triton import monarch_attention_triton

    register_impl("triton", monarch_attention_triton)
except ImportError:
    pass
30
+
31
+
32
+ class PadType(StrEnum):
33
+ pre = "pre"
34
+ post = "post"
35
+
36
+
37
+ class MonarchAttention(nn.Module):
38
+
39
+ def __init__(self, block_size, num_steps, pad_type, impl="torch"):
40
+ super().__init__()
41
+ self.block_size = block_size
42
+ self.num_steps = num_steps
43
+ self.pad_type = pad_type
44
+
45
+ if impl not in _IMPLEMENTATIONS:
46
+ available = ", ".join(sorted(_IMPLEMENTATIONS))
47
+ raise ValueError(f"Unknown impl {impl!r}. Available: {available}")
48
+ self._impl_fn = _IMPLEMENTATIONS[impl]
49
+
50
+ def forward(self, query, key, value, attention_mask=None):
51
+ return self._impl_fn(
52
+ query,
53
+ key,
54
+ value,
55
+ attention_mask,
56
+ self.num_steps,
57
+ self.block_size,
58
+ self.pad_type == PadType.pre,
59
+ )
60
+
61
+ def get_matrix(self, query, key, attention_mask=None):
62
+ batch_size, num_heads, seq_len, _ = query.shape
63
+ value = torch.eye(seq_len, device=query.device).expand(
64
+ batch_size, num_heads, seq_len, seq_len
65
+ )
66
+ return self.forward(query, key, value, attention_mask)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
torch
torchvision
Pillow
einops
numpy
gradio
utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from PIL import Image
4
+
5
def smart_padding(image, divisor=16):
    """Symmetrically pad `image` so its H and W are multiples of `divisor`.

    Returns the padded tensor together with the `(left, right, top, bottom)`
    padding amounts, so the operation can be undone with `remove_padding`.
    """
    height, width = image.shape[-2:]
    # Amount needed to reach the next multiple (0 when already divisible).
    extra_h = -height % divisor
    extra_w = -width % divisor
    top = extra_h // 2
    bottom = extra_h - top
    left = extra_w // 2
    right = extra_w - left
    pads = (left, right, top, bottom)
    # F.pad consumes (left, right, top, bottom) for the last two dims;
    # the border is filled with the constant 1.0.
    return F.pad(image, pads, mode='constant', value=1.0), pads
17
+
18
def remove_padding(image, padding):
    """Crop away the `(left, right, top, bottom)` border added by `smart_padding`."""
    left, right, top, bottom = padding
    # A zero amount must map to "slice to the end", not `-0` (empty slice).
    h_stop = -bottom if bottom else image.shape[-2]
    w_stop = -right if right else image.shape[-1]
    return image[..., top:h_stop, left:w_stop]
weights/colorizer.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a8ed3573050cdd3fc9c9542ec9dd76e91143fd2052b1c379f87422c59ae60fc
3
+ size 32434763