Upload folder using huggingface_hub
Browse files
README.md
CHANGED
|
@@ -33,7 +33,7 @@ hierarchical_top_k_kernel = flex.triton_hierarchical_sae_loss
|
|
| 33 |
"B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim"
|
| 34 |
|
| 35 |
loss: torch.Tensor = top_k_kernel(
|
| 36 |
-
|
| 37 |
weight: torch.Tensor, # [F, D]
|
| 38 |
vals: torch.Tensor, # [B, K]
|
| 39 |
bias: torch.Tensor, # [D]
|
|
@@ -41,7 +41,7 @@ loss: torch.Tensor = top_k_kernel(
|
|
| 41 |
)
|
| 42 |
|
| 43 |
loss: torch.Tensor = hierarchical_top_k_kernel(
|
| 44 |
-
|
| 45 |
weight: torch.Tensor, # [F, D]
|
| 46 |
vals: torch.Tensor, # [B, K]
|
| 47 |
bias: torch.Tensor, # [D]
|
|
|
|
| 33 |
"B -- batch size, K -- top-k, F -- dictionary size, D -- model hidden dim"
|
| 34 |
|
| 35 |
loss: torch.Tensor = top_k_kernel(
|
| 36 |
+
indices: torch.Tensor, # [B, K]
|
| 37 |
weight: torch.Tensor, # [F, D]
|
| 38 |
vals: torch.Tensor, # [B, K]
|
| 39 |
bias: torch.Tensor, # [D]
|
|
|
|
| 41 |
)
|
| 42 |
|
| 43 |
loss: torch.Tensor = hierarchical_top_k_kernel(
|
| 44 |
+
indices: torch.Tensor, # [B, K]
|
| 45 |
weight: torch.Tensor, # [F, D]
|
| 46 |
vals: torch.Tensor, # [B, K]
|
| 47 |
bias: torch.Tensor, # [D]
|