Kernels
sae
elephantmipt committed on
Commit
a38f7ad
·
verified ·
1 Parent(s): a262a48

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. README.md +4 -3
  2. example.py +5 -4
README.md CHANGED
@@ -93,7 +93,8 @@ B = 2048
93
  K = 256
94
  F = 1024 * 128
95
  D = 1024
96
- warmup = 5
 
97
  dtype = torch.float32
98
 
99
  vals = None
@@ -126,7 +127,7 @@ def zero_grad():
126
  torch.cuda.empty_cache()
127
 
128
 
129
- for i in range(100 + warmup):
130
  init_parameters()
131
  start_kernel = torch.cuda.Event(enable_timing=True)
132
  end_kernel = torch.cuda.Event(enable_timing=True)
@@ -143,7 +144,7 @@ for i in range(100 + warmup):
143
  loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
144
  loss_vanilla.backward()
145
  end_vanilla.record()
146
- if i >= warmup:
147
  torch.cuda.synchronize()
148
  timing_kernel.append(start_kernel.elapsed_time(end_kernel))
149
  timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
 
93
  K = 256
94
  F = 1024 * 128
95
  D = 1024
96
+ WARMUP = 5
97
+ NUM_ITER = 100
98
  dtype = torch.float32
99
 
100
  vals = None
 
127
  torch.cuda.empty_cache()
128
 
129
 
130
+ for i in range(NUM_ITER + WARMUP):
131
  init_parameters()
132
  start_kernel = torch.cuda.Event(enable_timing=True)
133
  end_kernel = torch.cuda.Event(enable_timing=True)
 
144
  loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
145
  loss_vanilla.backward()
146
  end_vanilla.record()
147
+ if i >= WARMUP:
148
  torch.cuda.synchronize()
149
  timing_kernel.append(start_kernel.elapsed_time(end_kernel))
150
  timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
example.py CHANGED
@@ -31,7 +31,8 @@ B = 2048
31
  K = 256
32
  F = 1024 * 128
33
  D = 1024
34
- warmup = 5
 
35
  dtype = torch.float32
36
 
37
  vals = None
@@ -64,7 +65,7 @@ def zero_grad():
64
  torch.cuda.empty_cache()
65
 
66
 
67
- for i in range(100 + warmup):
68
  init_parameters()
69
  start_kernel = torch.cuda.Event(enable_timing=True)
70
  end_kernel = torch.cuda.Event(enable_timing=True)
@@ -81,7 +82,7 @@ for i in range(100 + warmup):
81
  loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
82
  loss_vanilla.backward()
83
  end_vanilla.record()
84
- if i >= warmup:
85
  torch.cuda.synchronize()
86
  timing_kernel.append(start_kernel.elapsed_time(end_kernel))
87
  timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
@@ -97,4 +98,4 @@ else:
97
 
98
  print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
99
  print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
100
- print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")
 
31
  K = 256
32
  F = 1024 * 128
33
  D = 1024
34
+ WARMUP = 5
35
+ NUM_ITER = 100
36
  dtype = torch.float32
37
 
38
  vals = None
 
65
  torch.cuda.empty_cache()
66
 
67
 
68
+ for i in range(NUM_ITER + WARMUP):
69
  init_parameters()
70
  start_kernel = torch.cuda.Event(enable_timing=True)
71
  end_kernel = torch.cuda.Event(enable_timing=True)
 
82
  loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target)
83
  loss_vanilla.backward()
84
  end_vanilla.record()
85
+ if i >= WARMUP:
86
  torch.cuda.synchronize()
87
  timing_kernel.append(start_kernel.elapsed_time(end_kernel))
88
  timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla))
 
98
 
99
  print(f"🦎 Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} ± {np.std(timing_kernel):.4f} ms")
100
  print(f"🔥 Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} ± {np.std(timing_vanilla):.4f} ms")
101
+ print(f"🚀 Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x")