Upload folder using huggingface_hub
Browse files- README.md +4 -3
- example.py +5 -4
README.md
CHANGED
|
@@ -93,7 +93,8 @@ B = 2048
|
|
| 93 |
K = 256
|
| 94 |
F = 1024 * 128
|
| 95 |
D = 1024
|
| 96 |
-
|
|
|
|
| 97 |
dtype = torch.float32
|
| 98 |
|
| 99 |
vals = None
|
|
@@ -126,7 +127,7 @@ def zero_grad():
|
|
| 126 |
torch.cuda.empty_cache()
|
| 127 |
|
| 128 |
|
| 129 |
-
for i in range(
|
| 130 |
init_parameters()
|
| 131 |
start_kernel = torch.cuda.Event(enable_timing=True)
|
| 132 |
end_kernel = torch.cuda.Event(enable_timing=True)
|
|
@@ -143,7 +144,7 @@ for i in range(100 + warmup):
|
|
| 143 |
loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 144 |
loss_vanilla.backward()
|
| 145 |
end_vanilla.record()
|
| 146 |
-
if i >=
|
| 147 |
torch.cuda.synchronize()
|
| 148 |
timing_kernel.append(start_kernel.elapsed_time(end_kernel))
|
| 149 |
timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
|
|
|
|
| 93 |
K = 256
|
| 94 |
F = 1024 * 128
|
| 95 |
D = 1024
|
| 96 |
+
WARMUP = 5
|
| 97 |
+
NUM_ITER = 100
|
| 98 |
dtype = torch.float32
|
| 99 |
|
| 100 |
vals = None
|
|
|
|
| 127 |
torch.cuda.empty_cache()
|
| 128 |
|
| 129 |
|
| 130 |
+
for i in range(NUM_ITER + WARMUP):
|
| 131 |
init_parameters()
|
| 132 |
start_kernel = torch.cuda.Event(enable_timing=True)
|
| 133 |
end_kernel = torch.cuda.Event(enable_timing=True)
|
|
|
|
| 144 |
loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 145 |
loss_vanilla.backward()
|
| 146 |
end_vanilla.record()
|
| 147 |
+
if i >= WARMUP:
|
| 148 |
torch.cuda.synchronize()
|
| 149 |
timing_kernel.append(start_kernel.elapsed_time(end_kernel))
|
| 150 |
timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
|
example.py
CHANGED
|
@@ -31,7 +31,8 @@ B = 2048
|
|
| 31 |
K = 256
|
| 32 |
F = 1024 * 128
|
| 33 |
D = 1024
|
| 34 |
-
|
|
|
|
| 35 |
dtype = torch.float32
|
| 36 |
|
| 37 |
vals = None
|
|
@@ -64,7 +65,7 @@ def zero_grad():
|
|
| 64 |
torch.cuda.empty_cache()
|
| 65 |
|
| 66 |
|
| 67 |
-
for i in range(
|
| 68 |
init_parameters()
|
| 69 |
start_kernel = torch.cuda.Event(enable_timing=True)
|
| 70 |
end_kernel = torch.cuda.Event(enable_timing=True)
|
|
@@ -81,7 +82,7 @@ for i in range(100 + warmup):
|
|
| 81 |
loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 82 |
loss_vanilla.backward()
|
| 83 |
end_vanilla.record()
|
| 84 |
-
if i >=
|
| 85 |
torch.cuda.synchronize()
|
| 86 |
timing_kernel.append(start_kernel.elapsed_time(end_kernel))
|
| 87 |
timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
|
|
@@ -97,4 +98,4 @@ else:
|
|
| 97 |
|
| 98 |
print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
|
| 99 |
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
|
| 100 |
-
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
|
|
|
|
| 31 |
K = 256
|
| 32 |
F = 1024 * 128
|
| 33 |
D = 1024
|
| 34 |
+
WARMUP = 5
|
| 35 |
+
NUM_ITER = 100
|
| 36 |
dtype = torch.float32
|
| 37 |
|
| 38 |
vals = None
|
|
|
|
| 65 |
torch.cuda.empty_cache()
|
| 66 |
|
| 67 |
|
| 68 |
+
for i in range(NUM_ITER + WARMUP):
|
| 69 |
init_parameters()
|
| 70 |
start_kernel = torch.cuda.Event(enable_timing=True)
|
| 71 |
end_kernel = torch.cuda.Event(enable_timing=True)
|
|
|
|
| 82 |
loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
|
| 83 |
loss_vanilla.backward()
|
| 84 |
end_vanilla.record()
|
| 85 |
+
if i >= WARMUP:
|
| 86 |
torch.cuda.synchronize()
|
| 87 |
timing_kernel.append(start_kernel.elapsed_time(end_kernel))
|
| 88 |
timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
|
|
|
|
| 98 |
|
| 99 |
print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
|
| 100 |
print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
|
| 101 |
+
print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
|