Kernels
sae
elephantmipt commited on
Commit
8f19e61
·
verified ·
1 Parent(s): a38f7ad

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -33,7 +33,7 @@ hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss
33
  "B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim"
34
 
35
  loss: torch.Tensor = top_k_kernel(
36
- indices: torch.Tensor, # [B, K]
37
  weight: torch.Tensor, # [F, D]
38
  vals: torch.Tensor, # [B, K]
39
  bias: torch.Tensor, # [D]
@@ -41,7 +41,7 @@ loss: torch.Tensor = top_k_kernel(
41
  )
42
 
43
  loss: torch.Tensor = hierarchical_top_k_kernel(
44
- indices: torch.Tensor, # [B, K]
45
  weight: torch.Tensor, # [F, D]
46
  vals: torch.Tensor, # [B, K]
47
  bias: torch.Tensor, # [D]
 
33
  "B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim"
34
 
35
  loss: torch.Tensor = top_k_kernel(
36
+ indices: torch.Tensor, # [B, K]
37
  weight: torch.Tensor, # [F, D]
38
  vals: torch.Tensor, # [B, K]
39
  bias: torch.Tensor, # [D]
 
41
  )
42
 
43
  loss: torch.Tensor = hierarchical_top_k_kernel(
44
+ indices: torch.Tensor, # [B, K]
45
  weight: torch.Tensor, # [F, D]
46
  vals: torch.Tensor, # [B, K]
47
  bias: torch.Tensor, # [D]