Spaces:
Running
Running
anhth commited on
Commit ·
8314c30
1
Parent(s): cc2b90a
Initial Commit
Browse files- .gitignore +2 -0
- app.py +120 -0
- colorizer.py +420 -0
- monarch_attn/__init__.py +1 -0
- monarch_attn/ma_history.py +44 -0
- monarch_attn/ma_torch.py +108 -0
- monarch_attn/ma_triton.py +788 -0
- monarch_attn/monarch_attention.py +66 -0
- requirements.txt +5 -0
- utils.py +29 -0
- weights/colorizer.pth +3 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
app.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.transforms import v2
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from colorizer import ColorComicNet, MODEL_CFG
|
| 6 |
+
from utils import smart_padding, remove_padding
|
| 7 |
+
|
| 8 |
+
# Define the transformation pipeline for the input image
|
| 9 |
+
TRANSFORM = v2.Compose([
|
| 10 |
+
v2.ToImage(),
|
| 11 |
+
v2.ToDtype(torch.float32, scale=True),
|
| 12 |
+
v2.Normalize(mean=[0.5], std=[0.5])
|
| 13 |
+
])
|
| 14 |
+
|
| 15 |
+
# Image preprocessing and postprocessing functions
|
| 16 |
+
def preprocess_image(image: Image.Image, divisor=16):
    """Convert a PIL image into a normalized, padded model input.

    Returns a (tensor, padding) pair: the tensor has shape (1, 3, H', W')
    with H'/W' padded up to a multiple of `divisor`, and `padding` is the
    descriptor smart_padding produced (needed later to undo the pad).
    """
    rgb = image.convert('RGB')
    batched = TRANSFORM(rgb)[None, ...]  # add batch axis -> (1, 3, H, W)
    return smart_padding(batched, divisor=divisor)
|
| 22 |
+
|
| 23 |
+
def postprocess_output(output_tensor, padding):
    """Postprocess the model output tensor into a numpy image.

    Args:
        output_tensor: model output in [-1, 1], shape (1, C, H, W).
        padding: padding descriptor from smart_padding, to be undone.

    Returns:
        float32 numpy array of shape (H, W, C) with values in [0, 1].
    """
    output_tensor = remove_padding(output_tensor, padding)
    output_tensor = (output_tensor + 1) / 2  # Scale back from [-1, 1] to [0, 1]
    # detach()/cpu() make the numpy conversion safe even when this is called
    # outside torch.no_grad() or after the model moves to an accelerator;
    # a grad-tracking or CUDA tensor would otherwise raise on .numpy().
    output_image = output_tensor.clamp(0, 1).squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # Shape: (H, W, C)
    return output_image
|
| 29 |
+
|
| 30 |
+
# Define the colorization function
|
| 31 |
+
def colorize_image(gray_image: Image.Image):
    """ Colorize a single grayscale image using the model.

    Args:
        gray_image: input PIL image (converted to RGB during preprocessing).

    Returns:
        float32 numpy array of shape (H, W, 3) with values in [0, 1].
    """
    with torch.no_grad():
        # Preprocess
        # divisor=64 pads H/W so the encoder/decoder striding round-trips
        # cleanly -- presumably matches the model's total stride; TODO confirm.
        input_tensor, padding = preprocess_image(gray_image, divisor=64)
        # Inference
        output = model(input_tensor)
        # Postprocess: undo padding and map back to [0, 1] numpy.
        output_image = postprocess_output(output, padding)
    return output_image
|
| 41 |
+
|
| 42 |
+
# Initialize the model once at import time and freeze it for inference.
model = ColorComicNet(**MODEL_CFG)  # MODEL_CFG is a kwargs dict; passing it
                                    # positionally would bind it to input_shape
                                    # and crash on input_shape[0].
# weights_only=True restricts unpickling to tensor data, preventing arbitrary
# code execution from a tampered checkpoint (a plain state_dict needs no more).
model.load_state_dict(torch.load("./weights/colorizer.pth", map_location=torch.device('cpu'), weights_only=True))
model.fuse()  # fold re-parameterizable branches into single convs for inference
model.eval()
|
| 47 |
+
|
| 48 |
+
# Create the Gradio interface
|
| 49 |
+
|
| 50 |
+
# Custom CSS: dark gradient background, centered container, orange-gradient
# primary button.
custom_css = """
body {
    background: linear-gradient(135deg, #1e1e2f, #2a2a40);
    color: white;
}
.gradio-container {
    max-width: 1000px !important;
    margin: auto;
}
.header {
    text-align: center;
    padding: 20px;
}
.header h1 {
    font-size: 2.2rem;
    margin-bottom: 5px;
}
.header p {
    color: #cfcfe0;
}
.button-primary {
    background: linear-gradient(90deg, #ff7a18, #ffb347);
    border: none;
    color: white;
    font-weight: bold;
}
"""

with gr.Blocks(css=custom_css) as demo:
    # Header
    with gr.Column(elem_classes="header"):
        gr.Markdown("# 🎨 Comic Colorization")
        gr.Markdown("Bring your grayscale comics to life with **ColorComicNet**")
    # Two-column layout: upload + trigger button on the left, result on the right.
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📥 Upload Grayscale Image",
                type="pil",
            )
            colorize_button = gr.Button(
                "✨ Colorize Image",
                elem_classes="button-primary"
            )

        with gr.Column(scale=1):
            output_image = gr.Image(
                label="📤 Colorized Result",
                type="numpy",
            )

    # Example section (disabled; example assets not shipped with the Space)
    # gr.Markdown("### 🖼️ Try an example")
    # examples = gr.Examples(
    #     examples=[
    #         ["example1.png"],
    #         ["example2.png"]
    #     ],
    #     inputs=input_image
    # )

    # Footer
    gr.Markdown("---")

    # Interaction: run the colorizer on button click.
    colorize_button.click(
        fn=colorize_image,
        inputs=input_image,
        outputs=output_image
    )

demo.launch()
|
colorizer.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numbers
|
| 5 |
+
from dataclasses import dataclass, asdict
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
class UpSample(nn.Module):
    """2x spatial upsampling: a 1x1 conv doubles the channel count, then
    PixelShuffle trades channels for resolution, so overall
    (B, C, H, W) -> (B, C/2, 2H, 2W)."""

    def __init__(self, filters=64):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, padding=0, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

    def forward(self, x):
        return self.pixel_shuffle(self.conv(x))
|
| 19 |
+
|
| 20 |
+
## DownSampling block
|
| 21 |
+
## DownSampling block
class DownSample(nn.Module):
    """2x spatial downsampling: a 1x1 conv halves the channel count, then
    PixelUnshuffle folds each 2x2 spatial patch into channels."""
    def __init__(self, filters=64):
        super().__init__()
        self.conv = nn.Conv2d(filters, filters // 2, kernel_size=1, stride=1, padding=0, bias=True)
        self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=2)

    def forward(self, x):
        """SHAPE (B, C, H, W) -> SHAPE (B, 2C, H/2, W/2).

        The conv yields C/2 channels and PixelUnshuffle multiplies channels
        by 4 (the original docstring's claim of C/4 output channels was wrong).
        """
        x = self.conv(x)
        x = self.pixel_unshuffle(x)
        return x
|
| 33 |
+
|
| 34 |
+
# Custom LayerNormalization
|
| 35 |
+
class BiasFree_LayerNorm(nn.Module):
    """Bias-free layer norm: scales by 1/std over the last dimension with a
    learnable per-channel gain. No mean subtraction and no additive bias;
    variance is biased (unbiased=False) with eps=1e-5."""

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        # Only single-dimension normalization is supported.
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        x = x.contiguous()
        variance = x.var(-1, keepdim=True, unbiased=False)
        scaled = x / torch.sqrt(variance + 1e-5)
        return scaled * self.weight
|
| 52 |
+
|
| 53 |
+
class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dimension with learnable gain and
    bias. Variance is biased (unbiased=False) with eps=1e-5."""

    def __init__(self, normalized_shape):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        # Only single-dimension normalization is supported.
        assert len(normalized_shape) == 1
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        x = x.contiguous()
        mean = x.mean(-1, keepdim=True)
        variance = x.var(-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(variance + 1e-5)
        return normalized * self.weight + self.bias
|
| 72 |
+
|
| 73 |
+
class LayerNorm(nn.Module):
    """ Layer Normalization supporting two types: BiasFree and WithBias.

    Normalizes the channel dimension of a 4D (B, C, H, W) feature map by
    flattening it to (B, H*W, C), applying the chosen 1D norm over C, and
    (when out_4d) restoring the 4D layout.
    """
    def __init__(self, dim, LayerNorm_type, out_4d=True):
        super().__init__()
        # Any value other than 'BiasFree' selects the with-bias variant.
        if LayerNorm_type =='BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)
        # When False, input is passed to the norm body as-is (no reshaping).
        self.out_4d = out_4d

    def to_3d(self, x):
        # Convert (B, C, H, W) to (B, H*W, C); 3D input passes through.
        if len(x.shape) == 3:
            return x
        elif len(x.shape) == 4:
            return rearrange(x, 'b c h w -> b (h w) c')
        else:
            raise ValueError("Input must be a 3D or 4D tensor")

    def to_4d(self, x, h, w):
        # Convert (B, H*W, C) back to (B, C, H, W); 4D input passes through.
        if len(x.shape) == 4:
            return x
        elif len(x.shape) == 3:
            return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        else:
            raise ValueError("Input must be a 3D or 4D tensor")

    def forward(self, x):
        if self.out_4d:
            h, w = x.shape[-2:]
            return self.to_4d(self.body(self.to_3d(x)), h, w)
        else:
            return self.body(x)
|
| 107 |
+
|
| 108 |
+
class RepConv3(nn.Module):
    """Re-parameterizable 3x3 convolution (RepVGG-style).

    During training five parallel branches are summed: 3x3, 1x1, 1x3, 3x1,
    and a sequential bias-free 1x1 -> 3x3 branch. `fuse()` collapses all of
    them into the single `reparam` 3x3 conv for deployment.
    """
    def __init__(self, in_channels, out_channels, groups, deploy=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.deploy = deploy
        # Target conv that receives the fused weights at deploy time.
        self.reparam = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
        if not deploy:
            self.conv_3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
            self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups)
            self.conv_1x3 = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 3), padding=(0, 1), groups=groups)
            self.conv_3x1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 1), padding=(1, 0), groups=groups)
            # Sequential branch: bias-free 1x1 followed by bias-free 3x3.
            self.conv_1x1_branch = nn.Conv2d(in_channels, in_channels, kernel_size=1, groups=groups, bias=False)
            self.conv_3x3_branch = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups, bias=False)
        else:
            self._delete_branches()

    def _delete_branches(self):
        # Drop the training-time branches so only `reparam` remains.
        for name in ['conv_3x3','conv_1x1','conv_1x3','conv_3x1', 'conv_1x1_branch', 'conv_3x3_branch']:
            if hasattr(self, name):
                delattr(self, name)

    def fuse(self, delete_branches=True):
        """Fold all parallel branches into `reparam`. No-op if already deployed."""
        if self.deploy:
            return
        # Extract weights and biases
        conv_3x3_w, conv_3x3_b = self.conv_3x3.weight, self.conv_3x3.bias
        conv_1x1_w, conv_1x1_b = self.conv_1x1.weight, self.conv_1x1.bias
        conv_1x3_w, conv_1x3_b = self.conv_1x3.weight, self.conv_1x3.bias
        conv_3x1_w, conv_3x1_b = self.conv_3x1.weight, self.conv_3x1.bias
        conv_1x1_branch_w, conv_3x3_branch_w = self.conv_1x1_branch.weight, self.conv_3x3_branch.weight
        # Pad the smaller kernels to 3x3 so they can be summed elementwise
        # (F.pad order: left, right, top, bottom on the kernel's spatial dims).
        conv_1x1_w_pad = F.pad(conv_1x1_w, [1, 1, 1, 1])
        conv_1x3_w_pad = F.pad(conv_1x3_w, [0, 0, 1, 1])
        conv_3x1_w_pad = F.pad(conv_3x1_w, [1, 1, 0, 0])
        if self.groups == 1:
            # Composing 1x1 then 3x3 equals one 3x3 whose weight is the 3x3
            # kernel convolved with the transposed 1x1 kernel.
            conv_1x1_3x3_w_pad = F.conv2d(conv_3x3_branch_w, conv_1x1_branch_w.permute(1, 0, 2, 3))
        else:
            # Grouped case: compose the two kernels one channel-group at a time.
            w_slices = []
            conv_1x1_branch_w_T = conv_1x1_branch_w.permute(1, 0, 2, 3)
            in_channels_per_group = self.in_channels // self.groups
            out_channels_per_group = self.out_channels // self.groups
            for g in range(self.groups):
                # Slice the transposed 1x1 weights for this group's channels
                conv_1x1_branch_w_T_slice = conv_1x1_branch_w_T[:, g*in_channels_per_group:(g+1)*in_channels_per_group, :, :]
                # Slice the 3x3 weights for this group's output channels
                conv_3x3_branch_w_slice = conv_3x3_branch_w[g*out_channels_per_group:(g+1)*out_channels_per_group, :, :, :]
                w_slices.append(F.conv2d(conv_3x3_branch_w_slice, conv_1x1_branch_w_T_slice))
            conv_1x1_3x3_w_pad = torch.cat(w_slices, dim=0)
        # Fuse weights and biases (the sequential branch is bias-free, so it
        # contributes no bias term).
        conv_w = conv_3x3_w + conv_1x1_w_pad + conv_1x3_w_pad + conv_3x1_w_pad + conv_1x1_3x3_w_pad
        if conv_3x3_b is None:
            conv_3x3_b = torch.zeros(self.out_channels, device=conv_w.device)
        conv_b = conv_3x3_b + conv_1x1_b + conv_1x3_b + conv_3x1_b
        self.reparam.weight.data.copy_(conv_w)
        self.reparam.bias.data.copy_(conv_b)
        # Delete the original branches
        if delete_branches:
            self._delete_branches()
        # Set deploy flag
        self.deploy = True

    def forward(self, x):
        if self.deploy:
            return self.reparam(x)
        else:
            return self.conv_3x3(x) + self.conv_1x1(x) + self.conv_1x3(x) + self.conv_3x1(x) + self.conv_3x3_branch(self.conv_1x1_branch(x))
|
| 176 |
+
|
| 177 |
+
from monarch_attn import MonarchAttention
|
| 178 |
+
|
| 179 |
+
@dataclass
class RepAttnConfig:
    # Configuration for RepAttn; field names mirror RepAttn.__init__ so the
    # instance can be expanded with asdict(...).
    dim: int                # channel width of the stage
    num_heads: int = 8
    block_size: int = 16    # Monarch attention block size
    num_steps: int = 2      # alternating-maximization iterations
    pad_type: str = "pre"   # pad location when the sequence is not a multiple of block_size
    impl: str = "torch"     # backend implementation selector
    deploy: bool = False    # start in fused (MonarchAttention) mode when True
|
| 188 |
+
|
| 189 |
+
class RepAttn(nn.Module):
    """ Re-parameterizable Attention Block using MonarchAttention as the core attention mechanism.

    Trains with standard softmax attention (`common_attn`) and switches to
    MonarchAttention at deploy time via `fuse()`. Attention operates over the
    channel dimension (per-head channels are the tokens, the H*W spatial
    positions are the features) -- transposed attention.
    """
    def __init__(self, dim, num_heads=8, block_size=14, num_steps=1, pad_type="pre", impl="torch", deploy=False):
        super().__init__()
        self.num_heads = num_heads
        # Single 1x1 conv produces q, k, v stacked along channels.
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.monarch_attn = MonarchAttention(
            block_size=block_size,
            num_steps=num_steps,
            pad_type=pad_type,
            impl=impl
        )
        if deploy:
            self.attn_fn = self.monarch_attn
        else:
            self.attn_fn = self.common_attn
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.deploy = deploy

    def common_attn(self, q, k, v):
        """ Scaled Dot-Product Attention over the last two dims of q/k/v. """
        scale = (q.shape[-1]) ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        return out

    @torch.no_grad()
    def fuse(self):
        # Swap in MonarchAttention for inference; idempotent.
        if not self.deploy:
            self.attn_fn = self.monarch_attn
            self.deploy = True

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=1)
        # (B, C, H, W) -> (B, heads, C/heads, H*W): channels act as tokens.
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        attn_out = self.attn_fn(q, k, v)
        attn_out = rearrange(attn_out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=H, w=W)
        out = self.proj(attn_out)
        return out
|
| 233 |
+
|
| 234 |
+
@dataclass
class FFNConfig:
    # Configuration for RepFFN; field names mirror RepFFN.__init__ so the
    # instance can be expanded with asdict(...).
    dim: int                    # input/output channel width
    expansion_factor: int = 1   # hidden width = dim * expansion_factor
    deploy: bool = False        # build RepConv3 branches pre-fused when True
|
| 239 |
+
|
| 240 |
+
class RepFFN(nn.Module):
    """Gated feed-forward block built from re-parameterizable convs.

    project_in lifts channels to the hidden width, a depthwise RepConv3 emits
    two gate halves, and a 1x1 conv projects the gated result back to `dim`.
    """
    def __init__(self, dim, expansion_factor=1, deploy=False):
        super().__init__()
        hidden_features = int(dim * expansion_factor)
        self.project_in = RepConv3(dim, hidden_features, groups=1, deploy=deploy)
        # Depthwise (groups=hidden_features) conv doubling channels for the gate.
        self.dwconv = RepConv3(hidden_features, hidden_features*2, groups=hidden_features, deploy=deploy)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1)

    @torch.no_grad()
    def fuse(self):
        # Collapse both RepConv3 modules into their single-conv form.
        self.project_in.fuse()
        self.dwconv.fuse()

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        # Gated activation: GELU(x1) modulated elementwise by x2.
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x
|
| 260 |
+
|
| 261 |
+
class SkipConnection(nn.Module):
    """Fuses a decoder feature map with its encoder skip feature by channel
    concatenation followed by a 1x1 projection back to `dim` channels."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim*2, dim, kernel_size=1)

    def forward(self, x1, x2):
        merged = torch.cat([x1, x2], dim=1)
        return self.conv(merged)
|
| 270 |
+
|
| 271 |
+
class RepTransformerBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + FFN(LN(x))."""
    def __init__(self, rep_attn_cfg: RepAttnConfig, ffn_cfg: FFNConfig, norm_type='WithBias'):
        super().__init__()
        self.rep_attn = RepAttn(**asdict(rep_attn_cfg))
        self.rep_ffn = RepFFN(**asdict(ffn_cfg))
        self.norm1 = LayerNorm(rep_attn_cfg.dim, norm_type)
        self.norm2 = LayerNorm(rep_attn_cfg.dim, norm_type)

    @torch.no_grad()
    def fuse(self):
        # Fuse both sub-modules for deployment.
        self.rep_attn.fuse()
        self.rep_ffn.fuse()

    def forward(self, x):
        x = x + self.rep_attn(self.norm1(x))
        x = x + self.rep_ffn(self.norm2(x))
        return x
|
| 288 |
+
|
| 289 |
+
class Block(nn.Module):
    """A sequential stack of `num_block` RepTransformerBlocks sharing one config."""
    def __init__(self, num_block, rep_attn_cfg: RepAttnConfig, ffn_cfg: FFNConfig, norm_type='WithBias'):
        super().__init__()
        self.num_block = num_block
        self.blocks = nn.ModuleList([
            RepTransformerBlock(rep_attn_cfg, ffn_cfg, norm_type) for _ in range(num_block)
        ])

    @torch.no_grad()
    def fuse(self):
        # Fuse every contained transformer block for deployment.
        for block in self.blocks:
            block.fuse()

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
|
| 306 |
+
|
| 307 |
+
class ColorComicNet(nn.Module):
    """ Main model implementation: a U-Net-shaped transformer colorizer.

    Pipeline: strided stem (4x downsample) -> transformer encoder stages with
    PixelUnshuffle downsampling between them -> bottleneck -> decoder stages
    with PixelShuffle upsampling and encoder skip connections -> 4x bilinear
    upsample -> conv head with a global residual connection to the input.
    """
    def __init__(self, input_shape=(3, 1024, 1024), output_channels=3, deploy=False, dims=[48, 96, 192, 384], num_blocks=[4, 6, 6, 8], num_heads=[1, 2, 2, 4], bias=True, last_act=None):
        # NOTE(review): mutable list defaults are shared across instances;
        # harmless here only as long as they are never mutated.
        super().__init__()
        assert len(dims) == len(num_blocks) == len(num_heads), "Length of dims, num_blocks and num_heads must be the same"
        self.input_shape = input_shape
        self.output_channels = output_channels
        self.deploy = deploy
        self.dims = dims
        self.num_blocks = num_blocks
        self.bias = bias
        self.num_heads = num_heads

        # Extractor: 7x7 stride-4 conv, so spatial dims shrink 4x immediately.
        self.stem = nn.Conv2d(input_shape[0], dims[0], kernel_size=7, stride=4, padding=3, bias=bias)

        # Encoder: one Block per stage, DownSample between stages.
        layers = []
        down_convs = []
        for idx in range(len(dims)):
            attn_cfg, ffn_cfg = self.build_cfg(dims[idx], num_heads[idx])
            block = Block(num_blocks[idx], attn_cfg, ffn_cfg, norm_type='WithBias')
            if idx < len(dims) - 1:
                down_convs.append(DownSample(dims[idx]))
            layers.append(block)
        self.bottleneck = layers[-1] # Last encoder layer as bottleneck
        self.encoder = nn.ModuleList(layers[:-1])
        self.downsample = nn.ModuleList(down_convs)

        # Decoder: mirrors the encoder, deepest stage first.
        layers = []
        up_convs = []
        skip_connections = []
        for idx in range(len(dims)-2, -1, -1):
            attn_cfg, ffn_cfg = self.build_cfg(dims[idx], num_heads[idx])
            # print(f"Decoder layer {idx}: shape {l_shape}")
            up_conv = UpSample(dims[idx+1])
            block = Block(num_blocks[idx], attn_cfg, ffn_cfg, norm_type='WithBias')
            layers.append(block)
            up_convs.append(up_conv)
            skip_connections.append(SkipConnection(dims[idx]))
        self.decoder = nn.ModuleList(layers)
        self.up_sample = nn.ModuleList(up_convs)
        self.skip = nn.ModuleList(skip_connections)

        # Head: rep-conv + GELU + 1x1 projection down to the output channels.
        self.head = nn.Sequential(
            RepConv3(dims[0], dims[0]//2, 1, deploy=deploy),
            nn.GELU(),
            nn.Conv2d(dims[0]//2, output_channels, kernel_size=1, bias=bias),
        )
        self.last_act = last_act if last_act is not None else nn.Identity()

    @torch.no_grad()
    def fuse(self):
        # Fuse every re-parameterizable sub-module for deployment.
        for block in self.encoder:
            block.fuse()
        self.bottleneck.fuse()
        for block in self.decoder:
            block.fuse()
        for conv in self.head:
            if isinstance(conv, RepConv3):
                conv.fuse()

    def build_cfg(self, dim, head):
        """Build the (attention, FFN) config pair for a stage of width `dim`."""
        # RepAttn config
        attn_cfg = RepAttnConfig(
            dim=dim,
            num_heads=head,
            block_size=12,
            num_steps=2,
            pad_type="pre",
            impl="torch",
            deploy=self.deploy
        )
        ## FFN config
        # NOTE(review): `deploy` is not forwarded to FFNConfig, so the FFN is
        # always built un-fused even when self.deploy is True -- confirm intended.
        ffn_cfg = FFNConfig(
            dim=dim,
            expansion_factor=1,
        )
        return attn_cfg, ffn_cfg

    def forward(self, x):
        """
        x: (B, C, H, W)
        """
        res = x  # global residual: the network predicts a correction to the input
        x = self.stem(x)
        feats = []
        for blk, down in zip(self.encoder, self.downsample):
            x = blk(x)
            feats.append(x)
            x = down(x)
        x = self.bottleneck(x)
        for blk, up, skip in zip(self.decoder, self.up_sample, self.skip):
            x = up(x)
            cur_feat = feats.pop()  # matching encoder feature (LIFO order)
            x = skip(x, cur_feat)
            x = blk(x)
        # Undo the stem's 4x stride before the head.
        x = F.interpolate(x, scale_factor=4, mode='bilinear')
        x = self.head(x) + res
        x = self.last_act(x)
        return x
|
| 410 |
+
|
| 411 |
+
# Example model configuration.
# NOTE(review): this is a kwargs dict -- construct the model with
# ColorComicNet(**MODEL_CFG); passing the dict positionally binds it to
# `input_shape` and fails on input_shape[0].
MODEL_CFG = {
    'input_shape': (3, 512, 512),
    'dims': [24, 48, 96, 192],
    'num_blocks': [1, 2, 2, 4],
    'num_heads': [1, 2, 4, 8],
    'bias': True,
    'last_act': nn.Tanh(),
    'deploy': False
}
|
monarch_attn/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .monarch_attention import MonarchAttention
|
monarch_attn/ma_history.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import sqrt
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
|
| 7 |
+
Tensor = torch.Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def monarch_matrix(L: Tensor, R: Tensor) -> Tensor:
    """Assemble the dense monarch matrix from its two block factors.

    With L of shape (j, k, l) and R of shape (k, j, i), the result M has
    shape (l*j, k*i) and entries M[(l j), (k i)] = L[j, k, l] * R[k, j, i].
    Implemented with pure tensor ops (broadcast product + reshape) rather
    than einsum/rearrange.
    """
    j, k, l = L.shape
    i = R.shape[-1]
    # out[l, j, k, i] = L[j, k, l] * R[k, j, i], built by broadcasting.
    out = L.permute(2, 0, 1).unsqueeze(-1) * R.permute(1, 0, 2).unsqueeze(0)
    return out.reshape(l * j, k * i)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def monarch_attention_history(q: Tensor, k: Tensor, T: int, B: int) -> list[Tensor]:
    """Run T alternating-maximization steps and record the dense monarch
    attention matrix after each R update.

    Args:
        q, k: (N, D) query and key matrices; assumes N divisible by B -- TODO confirm.
        T: number of alternating-maximization iterations.
        B: monarch block size.

    Returns:
        List of T dense (N, N) attention matrices, one per iteration.
    """
    N, D = q.shape
    M = N // B

    # Standard attention temperature.
    q = q / sqrt(D)

    # Reshape queries/keys into the monarch block layout.
    qb = rearrange(q, "(l j) v -> j l v", j=B)
    kb = rearrange(k, "(k i) v -> k i v", i=B)

    # Initialize the left factor as B stacked identity matrices.
    L = torch.stack(B * [torch.eye(M, device=q.device)])

    history = []

    # Alternating maximization for L, R
    for t in range(T):
        # R update
        aR = torch.einsum("jkl,jlv->kjv", L, qb)
        bR = torch.einsum("kjv,kiv->kji", aR, kb)
        cR = torch.einsum("jkl->kj", L)
        R = F.softmax(bR / cR[:, :, None], dim=2)

        history.append(monarch_matrix(L, R))

        # L update
        aL = torch.einsum("kji,kiv->jkv", R, kb)
        bL = torch.einsum("jkv,jlv->jkl", aL, qb)
        # Row-entropy correction; R comes from a softmax so log(R) is finite
        # unless a probability underflows to exactly 0 -- NOTE(review):
        # torch.special.xlogy would be the safe form.
        cL = torch.einsum("kji->jk", R * torch.log(R))
        L = F.softmax(bL - cL[:, :, None], dim=1)

    return history
|
monarch_attn/ma_torch.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import sqrt
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
Tensor = torch.Tensor
|
| 7 |
+
xlogy = torch.special.xlogy
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def al_cl_ref(ar, k, cr, sm_scale, mask, eps=1e-12):
    """R-step of monarch attention: rebuild the right factor r from the
    current query aggregate `ar`, then aggregate keys for the next L-step.

    Args:
        ar: block-aggregated queries from the previous L-step.
        k: blocked keys.
        cr: per-row normalizer from the previous L-step.
        sm_scale: softmax temperature (1/sqrt(D)).
        mask: boolean key-validity mask; False entries are excluded.
        eps: numerical guard for the divisions below.

    Returns:
        (al, cl): key aggregate under r, and r's row-entropy correction term.
    """
    r_hat = sm_scale * (ar @ k.transpose(-1, -2)).to(torch.float)
    r_hat = r_hat / (cr[..., :, None] + eps)
    # Masked-out keys get -inf logits so they receive zero probability.
    r_hat = r_hat + torch.where(mask[..., None, :], 0.0, -float("inf"))
    # Stabilized exponentiation; a fully-masked row has max == -inf, so it
    # collapses to zeros and is rescued by the +eps in the normalizer below.
    r_hat = torch.exp(
        r_hat - torch.clamp(torch.max(r_hat, dim=-1, keepdim=True).values, min=eps)
    )
    r = r_hat / (torch.sum(r_hat, dim=-1, keepdim=True) + eps)
    # Clamp away exact zeros before the entropy computation.
    r = torch.clamp(r, min=torch.finfo(r.dtype).tiny)

    cl = torch.sum(xlogy(r, r), dim=-1).transpose(-1, -2)
    al = sm_scale * (r.to(k.dtype) @ k).transpose(-2, -3)

    return al, cl
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def ar_cr_ref(al, q, cl, mask_t):
    """L-step of monarch attention: rebuild the left factor l from the key
    aggregate `al`, then aggregate queries for the next R-step.

    `mask_t` is the (transposed) validity mask; padded columns of l are zeroed
    after the softmax so they contribute nothing downstream.
    """
    l_hat = (al @ q.transpose(-1, -2)).to(torch.float)
    # Subtract the entropy correction computed during the R-step.
    l_hat = l_hat - cl[..., :, None]
    l = F.softmax(l_hat, dim=-2)
    l = mask_t[..., None, :] * l

    cr = torch.sum(l, dim=-1).transpose(-1, -2)
    ar = (l.to(q.dtype) @ q).transpose(-2, -3)

    return ar, cr
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def al_y_cl_ref(ar, k, v, cr, sm_scale, mask, eps=1e-12):
    """Final R-step: identical to al_cl_ref but additionally mixes the values
    `v` through the right factor, producing the per-block value summaries `y`
    consumed by z_ref.
    """
    r_hat = sm_scale * (ar @ k.transpose(-1, -2)).to(torch.float)
    r_hat = r_hat / (cr[..., :, None] + eps)
    # Masked-out keys get -inf logits so they receive zero probability.
    r_hat = r_hat + torch.where(mask[..., None, :], 0.0, -float("inf"))
    # Stabilized exponentiation; fully-masked rows collapse to zeros and are
    # rescued by the +eps in the normalizer below.
    r_hat = torch.exp(
        r_hat - torch.clamp(torch.max(r_hat, dim=-1, keepdim=True).values, min=eps)
    )
    r = r_hat / (torch.sum(r_hat, dim=-1, keepdim=True) + eps)
    # Clamp away exact zeros before the entropy computation.
    r = torch.clamp(r, min=torch.finfo(r.dtype).tiny)

    cl = torch.sum(xlogy(r, r), dim=-1).transpose(-1, -2)
    al = sm_scale * (r.to(k.dtype) @ k).transpose(-2, -3)
    y = (r.to(v.dtype) @ v).transpose(-2, -3)

    return al, y, cl
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def z_ref(al, q, cl, y):
    """Final readout: form the left-factor attention over blocks from the
    aggregated keys `al`, then apply it to the per-block value summaries `y`.

    Scores are computed in float32 for a stable softmax, then cast back to
    y's dtype for the final matmul.
    """
    scores = torch.matmul(q, al.transpose(-1, -2)).to(torch.float)
    scores = scores - cl[..., None, :]  # entropy correction from the R-step
    weights = F.softmax(scores, dim=-1)
    mixed = torch.matmul(weights.to(y.dtype), y)
    return mixed.transpose(-2, -3).contiguous()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def monarch_attention_torch(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Tensor | None,
    T: int,
    B: int,
    pre_pad: bool,
) -> Tensor:
    """Reference PyTorch implementation of monarch attention.

    Args:
        q, k: (E, H, N, D) queries/keys.
        v: (E, H, N, Dv) values.
        attn_mask: optional boolean key mask, padded/reshaped to (E, 1, M, B);
            True marks a valid key.
        T: number of alternating-maximization iterations.
        B: monarch block size; N is padded up to a multiple of B.
        pre_pad: pad at the front of the sequence instead of the back.

    Returns:
        (E, H, N, Dv) attention output with the padding stripped off.
    """
    E, H, N, D = q.shape
    _, _, _, Dv = v.shape
    M = (N + B - 1) // B
    N_padded = M * B

    sm_scale = 1 / sqrt(D)

    # Pad the sequence dimension to a multiple of B, at the front or the back.
    pad_t = (N_padded - N, 0) if pre_pad else (0, N_padded - N)
    pad_t_2d = (0, 0) + pad_t

    q = F.pad(q, pad_t_2d).view(E, H, M, B, D)
    k = F.pad(k, pad_t_2d).view(E, H, M, B, D)
    v = F.pad(v, pad_t_2d).view(E, H, M, B, Dv)

    # Initialize the iteration with raw queries and unit normalizers.
    ar = q
    cr = torch.ones(E, H, M, B, device=q.device, dtype=torch.float)
    q = q.transpose(-2, -3)

    # Validity mask: True on real (non-padding) positions.
    pad_offset = N_padded - N if pre_pad else 0
    range_n = torch.arange(M * B).view(M, B).to(q.device)
    mask = range_n >= pad_offset if pre_pad else range_n < N

    if attn_mask is not None:
        attn_mask = F.pad(attn_mask, pad_t).view(E, 1, M, B)
        mask = torch.logical_and(mask, attn_mask)

    # T-1 full alternating (R, L) updates, then a final step that also mixes
    # the values before the readout.
    for _ in range(T - 1):
        al, cl = al_cl_ref(ar, k, cr, sm_scale, mask)
        ar, cr = ar_cr_ref(al, q, cl, mask.mT)

    al, y, cl = al_y_cl_ref(ar, k, v, cr, sm_scale, mask)
    z = z_ref(al, q, cl, y)
    z = z.view(E, H, N_padded, Dv)

    # Strip the padding back off before returning.
    return z[..., N_padded - N :, :] if pre_pad else z[..., :N, :]
|
monarch_attn/ma_triton.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from math import sqrt
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
Tensor = torch.Tensor
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@triton.jit
def xlogx(x):
    """Elementwise x * log(x) with the convention 0 * log(0) = 0."""
    return tl.where(x == 0, 0.0, x * tl.log(x))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@triton.jit
def _al_cl_kernel(
    ar_ptr, stride_ar_e, stride_ar_h, stride_ar_m, stride_ar_b, stride_ar_d,
    k_ptr, stride_k_e, stride_k_h, stride_k_m, stride_k_b, stride_k_d,
    cr_ptr, stride_cr_e, stride_cr_h, stride_cr_m, stride_cr_b,
    al_ptr, stride_al_e, stride_al_h, stride_al_m, stride_al_b, stride_al_d,
    cl_ptr, stride_cl_e, stride_cl_h, stride_cl_m, stride_cl_b,
    mask_ptr, stride_mask_e, stride_mask_m, stride_mask_b,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,
    is_first_call: int,
    sm_scale: float,
    HAS_ATTN_MASK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
    EPS: tl.constexpr,
):
    """Right-factor update for one (e, h, m) tile.

    Builds the row-softmax factor R from ``sm_scale * ar @ k^T`` scaled by
    1/cr, then stores ``al = sm_scale * R @ k`` and the correction
    ``cl = sum_j R_ij * log(R_ij)`` consumed by the left update.

    On the first call ``ar`` aliases the original (unpadded) ``q`` storage,
    so its block offsets are shifted by ``pad_offset`` like ``k``'s.
    """
    # Decompose the flat program id into (e, h, m) tile coordinates.
    idx_ehm = tl.program_id(0)
    idx_eh = idx_ehm // M
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_m = idx_ehm % M

    # Number of leading pad positions when padding precedes the tokens.
    pad_offset = M * B - N if PRE_PAD else 0

    range_b = tl.arange(0, BLOCK_B)
    range_d = tl.arange(0, BLOCK_D)
    range_n = B * idx_m + range_b  # global sequence positions of this tile

    mask_b = range_b < B
    # Positions inside the real (non-pad) part of the sequence.
    pad_mask_b = mask_b & ((range_n >= pad_offset) if PRE_PAD else (range_n < N))
    k_mask_b = pad_mask_b
    mask_d = range_d < D

    if HAS_ATTN_MASK:
        # Further restrict key positions by the user-provided token mask,
        # which is stored in the original (unpadded) layout.
        mask_block_ptr = (
            mask_ptr
            + stride_mask_e * idx_e
            + stride_mask_m * idx_m
            + stride_mask_b * (range_b - pad_offset)
        )
        valid_token_mask = tl.load(
            mask_block_ptr,
            mask=pad_mask_b,
            other=0,
        )
        k_mask_b = pad_mask_b & valid_token_mask

    # Load ar (shift into padded layout only when it still aliases q).
    ar_block_ptr = (
        ar_ptr
        + stride_ar_e * idx_e
        + stride_ar_h * idx_h
        + stride_ar_m * idx_m
        + (
            stride_ar_b * (range_b - (pad_offset if is_first_call else 0))[:, None]
            + stride_ar_d * range_d[None, :]
        )
    )
    ar = tl.load(
        ar_block_ptr,
        mask=(pad_mask_b if is_first_call else mask_b)[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load k (always read from the original, unpadded storage).
    k_block_ptr = (
        k_ptr
        + stride_k_e * idx_e
        + stride_k_h * idx_h
        + stride_k_m * idx_m
        + (stride_k_b * (range_b - pad_offset)[:, None] + stride_k_d * range_d[None, :])
    )
    k = tl.load(
        k_block_ptr,
        mask=k_mask_b[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cr (per-row normalizer; initialized to ones by the driver).
    cr_block_ptr = (
        cr_ptr
        + stride_cr_e * idx_e
        + stride_cr_h * idx_h
        + stride_cr_m * idx_m
        + (stride_cr_b * range_b)
    )
    cr = tl.load(cr_block_ptr, mask=mask_b, other=1.0)

    # Attention matrix: scaled scores / cr, invalid keys masked to -inf,
    # then a numerically shifted row softmax.
    r = sm_scale * tl.dot(ar, tl.trans(k))
    r = r / (cr[:, None] + EPS)
    r = r + tl.where(k_mask_b[None, :], 0.0, float("-inf"))
    # NOTE(review): the row max is clamped below at EPS before the exp
    # shift — presumably a numerical-stability guard; confirm against the
    # torch reference, which uses a plain softmax.
    r = tl.exp(r - tl.clamp(tl.max(r, axis=1, keep_dims=True), EPS, float("inf")))
    r = r / (tl.sum(r, axis=1, keep_dims=True) + EPS)

    # Store cl = sum_j r_ij * log(r_ij) (negative row entropy of R).
    cl = tl.sum(xlogx(r), axis=1)
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_m * idx_m
        + (stride_cl_b * range_b)
    )
    tl.store(cl_block_ptr, cl, mask=mask_b)

    # Store al = sm_scale * R @ k, cast back to the working dtype.
    al = (sm_scale * tl.dot(r.to(k.dtype), k)).to(ar.dtype)
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_m * idx_m
        + (stride_al_b * range_b[:, None] + stride_al_d * range_d[None, :])
    )
    tl.store(
        al_block_ptr,
        al,
        mask=mask_b[:, None] & mask_d[None, :],
    )
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@triton.jit
def _ar_cr_kernel(
    al_ptr, stride_al_e, stride_al_h, stride_al_m, stride_al_b, stride_al_d,
    q_ptr, stride_q_e, stride_q_h, stride_q_m, stride_q_b, stride_q_d,
    cl_ptr, stride_cl_e, stride_cl_h, stride_cl_m, stride_cl_b,
    ar_ptr, stride_ar_e, stride_ar_h, stride_ar_m, stride_ar_b, stride_ar_d,
    cr_ptr, stride_cr_e, stride_cr_h, stride_cr_m, stride_cr_b,
    mask_ptr, stride_mask_e, stride_mask_m, stride_mask_b,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,
    HAS_ATTN_MASK: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
):
    """Left-factor update for one (e, h, b) slice (fixed intra-block index).

    Forms scores ``al @ q^T`` shifted by cl, column-softmaxes them, zeroes
    columns belonging to invalid query positions, and writes back
    ``ar = L @ q`` plus the row sums ``cr``.
    """
    # Decompose the flat program id into (e, h, b) slice coordinates.
    idx_ehb = tl.program_id(0)
    idx_eh = idx_ehb // B
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_b = idx_ehb % B

    # Number of leading pad positions when padding precedes the tokens.
    pad_offset = M * B - N if PRE_PAD else 0

    range_m = tl.arange(0, BLOCK_M)
    range_d = tl.arange(0, BLOCK_D)
    range_n = idx_b + B * range_m  # global positions covered by this slice

    mask_m = range_m < M
    # Query positions that hold real (non-pad) tokens.
    q_mask_m = mask_m & (range_n >= pad_offset if PRE_PAD else range_n < N)
    mask_d = range_d < D

    if HAS_ATTN_MASK:
        # Restrict query positions by the user-provided token mask.
        mask_block_ptr = (
            mask_ptr
            + stride_mask_e * idx_e
            + stride_mask_b * (idx_b - pad_offset)
            + stride_mask_m * range_m
        )
        valid_token_mask = tl.load(
            mask_block_ptr,
            mask=q_mask_m,
            other=0,
        )
        q_mask_m = q_mask_m & valid_token_mask

    # Load al (intermediate right output written by _al_cl_kernel).
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_b * idx_b
        + (stride_al_m * range_m[:, None] + stride_al_d * range_d[None, :])
    )
    al = tl.load(
        al_block_ptr,
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load q (original, unpadded storage — hence the pad_offset shift).
    q_block_ptr = (
        q_ptr
        + stride_q_e * idx_e
        + stride_q_h * idx_h
        + stride_q_b * (idx_b - pad_offset)
        + (stride_q_m * range_m[:, None] + stride_q_d * range_d[None, :])
    )
    q = tl.load(
        q_block_ptr,
        mask=q_mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cl (entropy correction produced by _al_cl_kernel).
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_b * idx_b
        + (stride_cl_m * range_m)
    )
    cl = tl.load(cl_block_ptr, mask=mask_m, other=0.0)

    # Attention matrix: cl-shifted scores, invalid rows masked, then a
    # shifted softmax over axis 0 (columns), finally zeroing columns of
    # invalid query positions.
    l = tl.dot(al, tl.trans(q))
    l = l - cl[:, None]
    l = l + tl.where(mask_m[:, None], 0.0, float("-inf"))
    l = tl.exp(l - tl.max(l, axis=0, keep_dims=True))
    l = l / tl.sum(l, axis=0, keep_dims=True)
    l = q_mask_m[None, :] * l

    # Store cr = per-row mass of L (normalizer for the next right update).
    cr = tl.sum(l, axis=1)
    cr_block_ptr = (
        cr_ptr
        + stride_cr_e * idx_e
        + stride_cr_h * idx_h
        + stride_cr_b * idx_b
        + (stride_cr_m * range_m)
    )
    tl.store(cr_block_ptr, cr, mask=mask_m)

    # Store ar = L @ q, cast back to the working dtype.
    ar = tl.dot(l.to(q.dtype), q).to(al.dtype)
    ar_block_ptr = (
        ar_ptr
        + stride_ar_e * idx_e
        + stride_ar_h * idx_h
        + stride_ar_b * idx_b
        + (stride_ar_m * range_m[:, None] + stride_ar_d * range_d[None, :])
    )
    tl.store(
        ar_block_ptr,
        ar,
        mask=mask_m[:, None] & mask_d[None, :],
    )
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
@triton.jit
def _al_y_cl_kernel(
    ar_ptr, stride_ar_e, stride_ar_h, stride_ar_m, stride_ar_b, stride_ar_d,
    k_ptr, stride_k_e, stride_k_h, stride_k_m, stride_k_b, stride_k_d,
    v_ptr, stride_v_e, stride_v_h, stride_v_m, stride_v_b, stride_v_d,
    cr_ptr, stride_cr_e, stride_cr_h, stride_cr_m, stride_cr_b,
    al_ptr, stride_al_e, stride_al_h, stride_al_m, stride_al_b, stride_al_d,
    y_ptr, stride_y_e, stride_y_h, stride_y_m, stride_y_b, stride_y_d,
    cl_ptr, stride_cl_e, stride_cl_h, stride_cl_m, stride_cl_b,
    mask_ptr, stride_mask_e, stride_mask_m, stride_mask_b,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,
    is_first_call: int,
    sm_scale: float,
    HAS_ATTN_MASK: tl.constexpr,
    BLOCK_B: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
    EPS: tl.constexpr,
):
    """Final right-factor update for one (e, h, m) tile.

    Identical to ``_al_cl_kernel`` (same R, al, cl computation) but also
    aggregates the values: stores ``y = R @ v`` for consumption by
    ``_z_kernel``. ``is_first_call`` is true when T == 1, in which case
    ``ar`` still aliases the unpadded ``q`` storage.
    """
    # Decompose the flat program id into (e, h, m) tile coordinates.
    idx_ehm = tl.program_id(0)
    idx_eh = idx_ehm // M
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_m = idx_ehm % M

    # Number of leading pad positions when padding precedes the tokens.
    pad_offset = M * B - N if PRE_PAD else 0

    range_b = tl.arange(0, BLOCK_B)
    range_d = tl.arange(0, BLOCK_D)
    range_n = B * idx_m + range_b  # global sequence positions of this tile

    mask_b = range_b < B
    # Positions inside the real (non-pad) part of the sequence.
    pad_mask_b = mask_b & ((range_n >= pad_offset) if PRE_PAD else range_n < N)
    k_mask_b = pad_mask_b
    mask_d = range_d < D

    if HAS_ATTN_MASK:
        # Further restrict key/value positions by the user token mask.
        mask_block_ptr = (
            mask_ptr
            + stride_mask_e * idx_e
            + stride_mask_m * idx_m
            + stride_mask_b * (range_b - pad_offset)
        )
        valid_token_mask = tl.load(
            mask_block_ptr,
            mask=pad_mask_b,
            other=0,
        )
        k_mask_b = pad_mask_b & valid_token_mask

    # Load ar (shift into padded layout only when it still aliases q).
    ar_block_ptr = (
        ar_ptr
        + stride_ar_e * idx_e
        + stride_ar_h * idx_h
        + stride_ar_m * idx_m
        + (
            stride_ar_b * (range_b - (pad_offset if is_first_call else 0))[:, None]
            + stride_ar_d * range_d[None, :]
        )
    )
    ar = tl.load(
        ar_block_ptr,
        mask=(pad_mask_b if is_first_call else mask_b)[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load k (original, unpadded storage).
    k_block_ptr = (
        k_ptr
        + stride_k_e * idx_e
        + stride_k_h * idx_h
        + stride_k_m * idx_m
        + (stride_k_b * (range_b - pad_offset)[:, None] + stride_k_d * range_d[None, :])
    )
    k = tl.load(
        k_block_ptr,
        mask=k_mask_b[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cr (per-row normalizer from the previous left update).
    cr_block_ptr = (
        cr_ptr
        + stride_cr_e * idx_e
        + stride_cr_h * idx_h
        + stride_cr_m * idx_m
        + (stride_cr_b * range_b)
    )
    cr = tl.load(cr_block_ptr, mask=mask_b, other=1.0)

    # Attention matrix (same recipe as _al_cl_kernel).
    r = sm_scale * tl.dot(ar, tl.trans(k))
    r = r / (cr[:, None] + EPS)
    r = r + tl.where(k_mask_b[None, :], 0.0, float("-inf"))
    # NOTE(review): row max clamped below at EPS before the exp shift —
    # mirror of _al_cl_kernel; confirm against the torch reference.
    r = tl.exp(r - tl.clamp(tl.max(r, axis=1, keep_dims=True), EPS, float("inf")))
    r = r / (tl.sum(r, axis=1, keep_dims=True) + EPS)

    # Store cl = sum_j r_ij * log(r_ij).
    cl = tl.sum(xlogx(r), axis=1)
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_m * idx_m
        + (stride_cl_b * range_b)
    )
    tl.store(cl_block_ptr, cl, mask=mask_b)

    # Store al = sm_scale * R @ k.
    al = (sm_scale * tl.dot(r.to(k.dtype), k)).to(ar.dtype)
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_m * idx_m
        + (stride_al_b * range_b[:, None] + stride_al_d * range_d[None, :])
    )
    tl.store(
        al_block_ptr,
        al,
        mask=mask_b[:, None] & mask_d[None, :],
    )

    # Load v (original, unpadded storage; same mask as k).
    v_block_ptr = (
        v_ptr
        + stride_v_e * idx_e
        + stride_v_h * idx_h
        + stride_v_m * idx_m
        + (stride_v_b * (range_b - pad_offset)[:, None] + stride_v_d * range_d[None, :])
    )
    v = tl.load(
        v_block_ptr,
        mask=k_mask_b[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Store y = R @ v (value aggregate consumed by _z_kernel).
    y = tl.dot(r.to(v.dtype), v).to(ar.dtype)
    y_block_ptr = (
        y_ptr
        + stride_y_e * idx_e
        + stride_y_h * idx_h
        + stride_y_m * idx_m
        + (stride_y_b * range_b[:, None] + stride_y_d * range_d[None, :])
    )
    tl.store(
        y_block_ptr,
        y,
        mask=mask_b[:, None] & mask_d[None, :],
    )
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@triton.jit
def _z_kernel(
    al_ptr, stride_al_e, stride_al_h, stride_al_m, stride_al_b, stride_al_d,
    q_ptr, stride_q_e, stride_q_h, stride_q_m, stride_q_b, stride_q_d,
    y_ptr, stride_y_e, stride_y_h, stride_y_m, stride_y_b, stride_y_d,
    cl_ptr, stride_cl_e, stride_cl_h, stride_cl_m, stride_cl_b,
    z_ptr, stride_z_e, stride_z_h, stride_z_m, stride_z_b, stride_z_d,
    H: int,
    M: int,
    B: int,
    D: int,
    N: int,
    BLOCK_M: tl.constexpr,
    BLOCK_D: tl.constexpr,
    PRE_PAD: tl.constexpr,
):
    """Final output for one (e, h, b) slice: z = softmax(q @ al^T - cl) @ y.

    Triton analogue of the ``z_ref`` torch reference; writes z directly
    into the original (unpadded) output layout, skipping pad positions.
    """
    # Decompose the flat program id into (e, h, b) slice coordinates.
    idx_ehb = tl.program_id(0)
    idx_eh = idx_ehb // B
    idx_e = idx_eh // H
    idx_h = idx_eh % H
    idx_b = idx_ehb % B

    # Number of leading pad positions when padding precedes the tokens.
    pad_offset = M * B - N if PRE_PAD else 0

    range_m = tl.arange(0, BLOCK_M)
    range_d = tl.arange(0, BLOCK_D)
    range_n = idx_b + B * range_m  # global positions covered by this slice

    mask_m = range_m < M
    # Query positions that hold real (non-pad) tokens.
    q_mask_m = mask_m & (range_n >= pad_offset if PRE_PAD else range_n < N)
    mask_d = range_d < D

    # Load al (final right output from _al_y_cl_kernel).
    al_block_ptr = (
        al_ptr
        + stride_al_e * idx_e
        + stride_al_h * idx_h
        + stride_al_b * idx_b
        + (stride_al_m * range_m[:, None] + stride_al_d * range_d[None, :])
    )
    al = tl.load(
        al_block_ptr,
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load q (original, unpadded storage — hence the pad_offset shift).
    q_block_ptr = (
        q_ptr
        + stride_q_e * idx_e
        + stride_q_h * idx_h
        + stride_q_b * (idx_b - pad_offset)
        + (stride_q_m * range_m[:, None] + stride_q_d * range_d[None, :])
    )
    q = tl.load(
        q_block_ptr,
        mask=q_mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Load cl (entropy correction for the left softmax).
    cl_block_ptr = (
        cl_ptr
        + stride_cl_e * idx_e
        + stride_cl_h * idx_h
        + stride_cl_b * idx_b
        + (stride_cl_m * range_m)
    )
    cl = tl.load(cl_block_ptr, mask=mask_m, other=0.0)

    # Attention matrix: cl-shifted scores with invalid columns masked,
    # then a numerically shifted row softmax.
    l = tl.dot(q, tl.trans(al))
    l = l - cl[None, :]
    l = l + tl.where(mask_m[None, :], 0.0, float("-inf"))
    l = tl.exp(l - tl.max(l, axis=1, keep_dims=True))
    l = l / tl.sum(l, axis=1, keep_dims=True)

    # Load y (value aggregate written by _al_y_cl_kernel).
    y_block_ptr = (
        y_ptr
        + stride_y_e * idx_e
        + stride_y_h * idx_h
        + stride_y_b * idx_b
        + (stride_y_m * range_m[:, None] + stride_y_d * range_d[None, :])
    )
    y = tl.load(
        y_block_ptr,
        mask=mask_m[:, None] & mask_d[None, :],
        other=0.0,
    )

    # Store z = L @ y into the unpadded output; pad rows are not written.
    z = tl.dot(l.to(y.dtype), y).to(al.dtype)
    z_block_ptr = (
        z_ptr
        + stride_z_e * idx_e
        + stride_z_h * idx_h
        + stride_z_b * (idx_b - pad_offset)
        + (stride_z_m * range_m[:, None] + stride_z_d * range_d[None, :])
    )
    tl.store(
        z_block_ptr,
        z,
        mask=q_mask_m[:, None] & mask_d[None, :],
    )
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def monarch_attention_triton(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Tensor | None,
    T: int,
    B: int,
    pre_pad: bool,
    eps: float = 0.0,
) -> Tensor:
    """Triton-backed Monarch attention over (E, H, N, D) query/key/value.

    Unlike the torch reference, no padded copies are materialized: the
    kernels index the original q/k/v storage via "virtual" (e, h, m, b, d)
    strides (the m stride is ``B * seq_stride``) and mask out pad
    positions on load/store.

    NOTE(review): ``y = torch.empty_like(al)`` and a shared BLOCK_D imply
    v's head dim equals D here, whereas the torch path tracks Dv
    separately — confirm this restriction is intended.

    Args:
        q, k, v: Tensors of shape (E, H, N, D).
        attn_mask: Optional (E, N) token-validity mask.
        T: Number of alternating update steps.
        B: Monarch block size.
        pre_pad: Whether (virtual) padding precedes the sequence.
        eps: Stabilizer added to denominators inside the kernels.

    Returns:
        Tensor of shape (E, H, N, D).
    """
    E, H, N, D = q.shape
    M = triton.cdiv(N, B)

    # Common positional-argument tail shared by all kernel launches.
    HMBDN = (H, M, B, D, N)

    # One program per (e, h, m) tile / per (e, h, b) slice respectively.
    grid_ehm = (E * H * M,)
    grid_ehb = (E * H * B,)

    # tl.dot needs tile sides of at least 16.
    BLOCK_B = max(triton.next_power_of_2(B), 16)
    BLOCK_M = max(triton.next_power_of_2(M), 16)
    BLOCK_D = max(triton.next_power_of_2(D), 16)

    sm_scale = 1 / sqrt(D)

    # Virtual (e, h, m, b, d) strides over the unpadded (e, h, n, d) data.
    q_strides = (q.stride(0), q.stride(1), B * q.stride(2), q.stride(2), q.stride(3))
    k_strides = (k.stride(0), k.stride(1), B * k.stride(2), k.stride(2), k.stride(3))
    v_strides = (v.stride(0), v.stride(1), B * v.stride(2), v.stride(2), v.stride(3))

    # Workspaces for the alternating factors (blocked, padded layout).
    ar = torch.empty(E, H, M, B, D, device=q.device, dtype=q.dtype)
    al = torch.empty_like(ar)

    ar_strides = (ar.stride(0), ar.stride(1), ar.stride(2), ar.stride(3), ar.stride(4))
    al_strides = (al.stride(0), al.stride(1), al.stride(2), al.stride(3), al.stride(4))

    # Normalizer / correction buffers, kept in float32.
    cr = torch.ones(E, H, M, B, device=q.device, dtype=torch.float)
    cl = torch.empty_like(cr)

    cr_strides = (cr.stride(0), cr.stride(1), cr.stride(2), cr.stride(3))
    cl_strides = (cl.stride(0), cl.stride(1), cl.stride(2), cl.stride(3))

    # Virtual (e, m, b) strides for the optional (E, N) attention mask.
    attn_mask_strides = (
        (attn_mask.stride(0), B * attn_mask.stride(1), attn_mask.stride(1))
        if attn_mask is not None
        else (0, 0, 0)
    )

    # T - 1 alternating right/left updates. On the very first right update
    # ar still aliases q, so the kernel applies the padding shift itself.
    for t in range(T - 1):
        is_first_call = t == 0
        _ar = q if is_first_call else ar
        _ar_strides = q_strides if is_first_call else ar_strides
        _al_cl_kernel[grid_ehm](
            _ar,
            *_ar_strides,
            k,
            *k_strides,
            cr,
            *cr_strides,
            al,
            *al_strides,
            cl,
            *cl_strides,
            attn_mask,
            *attn_mask_strides,
            *HMBDN,
            is_first_call=is_first_call,
            sm_scale=sm_scale,
            HAS_ATTN_MASK=attn_mask is not None,  # type: ignore
            BLOCK_B=BLOCK_B,  # type: ignore
            BLOCK_D=BLOCK_D,  # type: ignore
            PRE_PAD=pre_pad,  # type: ignore
            EPS=eps,  # type: ignore
        )

        _ar_cr_kernel[grid_ehb](
            al,
            *al_strides,
            q,
            *q_strides,
            cl,
            *cl_strides,
            ar,
            *ar_strides,
            cr,
            *cr_strides,
            attn_mask,
            *attn_mask_strides,
            *HMBDN,
            HAS_ATTN_MASK=attn_mask is not None,  # type: ignore
            BLOCK_M=BLOCK_M,  # type: ignore
            BLOCK_D=BLOCK_D,  # type: ignore
            PRE_PAD=pre_pad,  # type: ignore
        )

    # Final right update also aggregates the values into y.
    y = torch.empty_like(al)
    y_strides = (y.stride(0), y.stride(1), y.stride(2), y.stride(3), y.stride(4))

    # When T == 1 the loop above never ran, so ar still aliases q.
    is_first_call_y = T == 1
    _ar_y = q if is_first_call_y else ar
    _ar_y_strides = q_strides if is_first_call_y else ar_strides

    _al_y_cl_kernel[grid_ehm](
        _ar_y,
        *_ar_y_strides,
        k,
        *k_strides,
        v,
        *v_strides,
        cr,
        *cr_strides,
        al,
        *al_strides,
        y,
        *y_strides,
        cl,
        *cl_strides,
        attn_mask,
        *attn_mask_strides,
        *HMBDN,
        is_first_call=is_first_call_y,
        sm_scale=sm_scale,
        HAS_ATTN_MASK=attn_mask is not None,  # type: ignore
        BLOCK_B=BLOCK_B,  # type: ignore
        BLOCK_D=BLOCK_D,  # type: ignore
        PRE_PAD=pre_pad,  # type: ignore
        EPS=eps,  # type: ignore
    )

    # Output is written directly into the unpadded (E, H, N, D) layout.
    z = torch.empty_like(v)
    z_strides = (z.stride(0), z.stride(1), B * z.stride(2), z.stride(2), z.stride(3))

    _z_kernel[grid_ehb](
        al,
        *al_strides,
        q,
        *q_strides,
        y,
        *y_strides,
        cl,
        *cl_strides,
        z,
        *z_strides,
        *HMBDN,
        BLOCK_M=BLOCK_M,  # type: ignore
        BLOCK_D=BLOCK_D,  # type: ignore
        PRE_PAD=pre_pad,  # type: ignore
    )

    return z
|
monarch_attn/monarch_attention.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Callable
|
| 2 |
+
from enum import StrEnum
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from .ma_torch import monarch_attention_torch
|
| 8 |
+
|
| 9 |
+
Tensor = torch.Tensor
|
| 10 |
+
|
| 11 |
+
# Signature shared by all backends:
# (q, k, v, attn_mask, num_steps T, block_size B, pre_pad) -> output tensor.
MonarchAttentionFn = Callable[
    [Tensor, Tensor, Tensor, Tensor | None, int, int, bool], Tensor
]

# Registry of available backend implementations, keyed by name.
_IMPLEMENTATIONS: dict[str, MonarchAttentionFn] = {}


def register_impl(name: str, fn: MonarchAttentionFn) -> None:
    """Register a Monarch attention backend under ``name``."""
    _IMPLEMENTATIONS[name] = fn


register_impl("torch", monarch_attention_torch)

# The Triton backend is optional: register it only when triton imports.
try:
    from .ma_triton import monarch_attention_triton

    register_impl("triton", monarch_attention_triton)
except ImportError:
    pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class PadType(StrEnum):
    """Where padding is inserted when N is not a multiple of the block size."""

    pre = "pre"  # pad at the start of the sequence
    post = "post"  # pad at the end of the sequence
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class MonarchAttention(nn.Module):
    """Drop-in attention module backed by a Monarch factorization.

    Args:
        block_size: Monarch block size ``B``.
        num_steps: Number of alternating update steps ``T``.
        pad_type: A :class:`PadType` value — whether padding up to a
            multiple of ``block_size`` precedes or follows the sequence.
        impl: Name of a backend registered via ``register_impl``
            ("torch" is always available; "triton" when importable).

    Raises:
        ValueError: If ``impl`` is not a registered implementation.
    """

    def __init__(self, block_size, num_steps, pad_type, impl="torch"):
        super().__init__()
        self.block_size = block_size
        self.num_steps = num_steps
        self.pad_type = pad_type

        if impl not in _IMPLEMENTATIONS:
            available = ", ".join(sorted(_IMPLEMENTATIONS))
            raise ValueError(f"Unknown impl {impl!r}. Available: {available}")
        self._impl_fn = _IMPLEMENTATIONS[impl]

    def forward(self, query, key, value, attention_mask=None):
        """Apply Monarch attention to (E, H, N, D)-shaped q/k/v tensors."""
        return self._impl_fn(
            query,
            key,
            value,
            attention_mask,
            self.num_steps,
            self.block_size,
            self.pad_type == PadType.pre,
        )

    def get_matrix(self, query, key, attention_mask=None):
        """Materialize the effective (N, N) attention matrix.

        Runs a forward pass with an identity value matrix, so the output
        rows are exactly the attention weights.
        """
        batch_size, num_heads, seq_len, _ = query.shape
        # Match query's dtype: a plain torch.eye would be float32 and break
        # (or silently upcast) half-precision inputs inside the backend.
        value = torch.eye(seq_len, device=query.device, dtype=query.dtype).expand(
            batch_size, num_heads, seq_len, seq_len
        )
        return self.forward(query, key, value, attention_mask)
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
Pillow
|
| 4 |
+
einops
|
| 5 |
+
numpy
|
utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from PIL import Image
|
| 4 |
+
|
| 5 |
+
def smart_padding(image, divisor=16, value=1.0):
    """Pad the last two dims of ``image`` up to multiples of ``divisor``.

    The padding is split as evenly as possible between the two sides, with
    the extra pixel (for odd amounts) going to the right/bottom.

    Args:
        image: Tensor of shape (..., H, W).
        divisor: Target divisor for both spatial dimensions.
        value: Constant fill value for the padded border. Defaults to 1.0
            (white, for images normalized to [0, 1]) to preserve the
            previous hard-coded behavior.

    Returns:
        Tuple ``(padded_image, padding)`` where ``padding`` is the
        ``(left, right, top, bottom)`` tuple accepted by ``remove_padding``.
    """
    h, w = image.shape[-2:]
    pad_h = (divisor - h % divisor) % divisor
    pad_w = (divisor - w % divisor) % divisor
    left = pad_w // 2
    right = pad_w - left
    top = pad_h // 2
    bottom = pad_h - top
    padding = (left, right, top, bottom)
    padded_image = F.pad(image, padding, mode='constant', value=value)
    return padded_image, padding
|
| 17 |
+
|
| 18 |
+
def remove_padding(image, padding):
    """Crop away a border previously added by ``smart_padding``.

    Args:
        image: Tensor of shape (..., H, W).
        padding: The ``(left, right, top, bottom)`` tuple returned by
            ``smart_padding``.

    Returns:
        The central region of ``image`` with the border removed.
    """
    left, right, top, bottom = padding
    height, width = image.shape[-2:]
    # Absolute end indices handle zero amounts naturally (a negative index
    # of -0 would otherwise select nothing).
    return image[..., top:height - bottom, left:width - right]
|
weights/colorizer.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a8ed3573050cdd3fc9c9542ec9dd76e91143fd2052b1c379f87422c59ae60fc
|
| 3 |
+
size 32434763
|