Kernels
danieldk HF Staff commited on
Commit
cd5a62f
·
verified ·
1 Parent(s): ff7ab44

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +124 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn, einsum
3
+
4
+ from kernels.benchmark import Benchmark
5
+
6
+
7
+ class TriMulReference(nn.Module):
8
+ """Reference implementation of Triangle Multiplicative Module."""
9
+
10
+ def __init__(self, dim: int, hidden_dim: int):
11
+ super().__init__()
12
+ self.norm = nn.LayerNorm(dim)
13
+ self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
14
+ self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
15
+ self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
16
+ self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
17
+ self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
18
+ self.to_out_norm = nn.LayerNorm(hidden_dim)
19
+ self.to_out = nn.Linear(hidden_dim, dim, bias=False)
20
+
21
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
22
+ x = self.norm(x)
23
+
24
+ left = self.left_proj(x)
25
+ right = self.right_proj(x)
26
+
27
+ mask = mask.unsqueeze(-1)
28
+ left = left * mask
29
+ right = right * mask
30
+
31
+ left_gate = self.left_gate(x).sigmoid()
32
+ right_gate = self.right_gate(x).sigmoid()
33
+ out_gate = self.out_gate(x).sigmoid()
34
+
35
+ left = left * left_gate
36
+ right = right * right_gate
37
+
38
+ out = einsum("... i k d, ... j k d -> ... i j d", left, right)
39
+
40
+ out = self.to_out_norm(out)
41
+ out = out * out_gate
42
+ return self.to_out(out)
43
+
44
+
45
+ class TrimulGpumodeBenchmark(Benchmark):
46
+ seed: int = 42
47
+
48
+ def setup(self):
49
+ # Note: hidden_dim must be 128 (kernel constraint)
50
+ batch_size = 1
51
+ seq_len = 128
52
+ dim = 128
53
+ hidden_dim = 128
54
+
55
+ self.config = {"dim": dim, "hidden_dim": hidden_dim}
56
+
57
+ self.input_tensor = torch.randn(
58
+ batch_size, seq_len, seq_len, dim, device="cuda", dtype=torch.float32
59
+ )
60
+ self.mask = torch.ones(
61
+ batch_size, seq_len, seq_len, device="cuda", dtype=torch.float32
62
+ )
63
+
64
+ self.weights = {
65
+ "norm.weight": torch.ones(dim, device="cuda", dtype=torch.float32),
66
+ "norm.bias": torch.zeros(dim, device="cuda", dtype=torch.float32),
67
+ "left_proj.weight": torch.randn(
68
+ hidden_dim, dim, device="cuda", dtype=torch.float32
69
+ )
70
+ * 0.02,
71
+ "right_proj.weight": torch.randn(
72
+ hidden_dim, dim, device="cuda", dtype=torch.float32
73
+ )
74
+ * 0.02,
75
+ "left_gate.weight": torch.randn(
76
+ hidden_dim, dim, device="cuda", dtype=torch.float32
77
+ )
78
+ * 0.02,
79
+ "right_gate.weight": torch.randn(
80
+ hidden_dim, dim, device="cuda", dtype=torch.float32
81
+ )
82
+ * 0.02,
83
+ "out_gate.weight": torch.randn(
84
+ hidden_dim, dim, device="cuda", dtype=torch.float32
85
+ )
86
+ * 0.02,
87
+ "to_out_norm.weight": torch.ones(
88
+ hidden_dim, device="cuda", dtype=torch.float32
89
+ ),
90
+ "to_out_norm.bias": torch.zeros(
91
+ hidden_dim, device="cuda", dtype=torch.float32
92
+ ),
93
+ "to_out.weight": torch.randn(
94
+ dim, hidden_dim, device="cuda", dtype=torch.float32
95
+ )
96
+ * 0.02,
97
+ }
98
+
99
+ self.out = torch.empty(
100
+ batch_size, seq_len, seq_len, dim, device="cuda", dtype=torch.float32
101
+ )
102
+
103
+ def benchmark_base(self):
104
+ data = (self.input_tensor, self.mask, self.weights, self.config)
105
+ self.out = self.kernel.kernel_global(data)
106
+
107
+ def verify_base(self) -> torch.Tensor:
108
+ ref = TriMulReference(
109
+ dim=self.config["dim"], hidden_dim=self.config["hidden_dim"]
110
+ ).cuda()
111
+
112
+ ref.norm.weight = nn.Parameter(self.weights["norm.weight"])
113
+ ref.norm.bias = nn.Parameter(self.weights["norm.bias"])
114
+ ref.left_proj.weight = nn.Parameter(self.weights["left_proj.weight"])
115
+ ref.right_proj.weight = nn.Parameter(self.weights["right_proj.weight"])
116
+ ref.left_gate.weight = nn.Parameter(self.weights["left_gate.weight"])
117
+ ref.right_gate.weight = nn.Parameter(self.weights["right_gate.weight"])
118
+ ref.out_gate.weight = nn.Parameter(self.weights["out_gate.weight"])
119
+ ref.to_out_norm.weight = nn.Parameter(self.weights["to_out_norm.weight"])
120
+ ref.to_out_norm.bias = nn.Parameter(self.weights["to_out_norm.bias"])
121
+ ref.to_out.weight = nn.Parameter(self.weights["to_out.weight"])
122
+
123
+ with torch.no_grad():
124
+ return ref(self.input_tensor, self.mask)